SparseOperationKit Experiment API

Initialize

sparse_operation_kit.experiment.init(comm_tool='horovod')[source]

Abbreviated as sok.experiment.init.

This function initializes SparseOperationKit (SOK).

SOK uses every GPU that is visible to the current CPU process. To select which GPU(s) this process uses, set CUDA_VISIBLE_DEVICES or call tf.config.set_visible_devices before launching the TensorFlow runtime and calling this function.

Currently, the experiment API supports only horovod as the communication tool, so horovod.init must be called before initializing SOK.

Example code for doing initialization:

import tensorflow as tf
import horovod.tensorflow as hvd
import sparse_operation_kit.experiment as sok

hvd.init()
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")

sok.init()

Parameters

comm_tool (string) – a string to specify which communication tool to use. Default value is “horovod”.

Return type

None
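
The description above also allows selecting the GPU through CUDA_VISIBLE_DEVICES. The following is a minimal sketch of that route, not part of SOK or Horovod: the helper name pin_gpu_from_env is hypothetical, and it assumes the launcher exposes the local rank in an environment variable such as HOROVOD_LOCAL_RANK or OMPI_COMM_WORLD_LOCAL_RANK. Check which variable your launcher actually sets.

```python
import os

# Hypothetical helper for illustration; not an SOK or Horovod function.
def pin_gpu_from_env(environ,
                     rank_vars=("HOROVOD_LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK")):
    """Set CUDA_VISIBLE_DEVICES to the local rank found in `environ`, if any."""
    for name in rank_vars:  # assumed variable names, launcher dependent
        if name in environ:
            environ["CUDA_VISIBLE_DEVICES"] = str(int(environ[name]))
            return environ["CUDA_VISIBLE_DEVICES"]
    return None  # no local rank found; the environment is left untouched

# In a real script, call this before importing tensorflow:
#     pin_gpu_from_env(os.environ)
demo_env = {"OMPI_COMM_WORLD_LOCAL_RANK": "1"}
print(pin_gpu_from_env(demo_env))
```

With this approach each process sees a single GPU, so any later indexing into the list of physical devices should use index 0 rather than hvd.local_rank().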

Lookup

sparse_operation_kit.experiment.distributed_variable.Variable(*args, **kwargs)[source]

Abbreviated as sok.experiment.Variable.

This is a helper function that creates a model-parallel variable. There are two use cases:

Distributed Variable:

import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
from sparse_operation_kit import experiment as sok

hvd.init()
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")  # nopep8

sok.init()

# If there are 2 GPUs in total, the shape on GPU0 will be [2, 3] and the shape
# on GPU1 will be [2, 3]
v = sok.Variable(np.arange(4 * 3).reshape(4, 3), dtype=tf.float32)

# GPU0 output: [[0, 1, 2],
#               [6, 7, 8]]
# GPU1 output: [[3,  4,  5],
#               [9, 10, 11]]
print(v)

Localized Variable:

import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
from sparse_operation_kit import experiment as sok

hvd.init()
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")  # nopep8

sok.init()

# If there are 2 GPUs in total, the shape on GPU0 will be [5, 3] and the shape
# on GPU1 will be [0, 3]
v = sok.Variable(
    np.arange(5 * 3).reshape(5, 3), dtype=tf.float32, mode="localized:0"
)
print(v.shape)

As shown in the two examples above, when you need to store different parts of a variable on different GPUs (that is, allocating a model-parallel variable), this function can help you allocate the required memory on each GPU.
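
The row placement in the two examples can be modelled in plain NumPy. This is only a conceptual sketch: the function local_shard is hypothetical, and it assumes that the "distributed" mode places row i on GPU i % num_gpus, which is what the printed output of the first example suggests. It does not reflect how SOK allocates memory internally.

```python
import numpy as np

# Hypothetical model of which rows each GPU holds; not SOK code.
def local_shard(full, rank, num_gpus, mode="distributed"):
    """Return the rows of `full` that GPU `rank` would hold."""
    if mode == "distributed":
        # Assumption: row i lives on GPU i % num_gpus (round-robin).
        return full[rank::num_gpus]
    if mode.startswith("localized:"):
        target = int(mode.split(":", 1)[1])
        # All rows live on the target GPU; other GPUs hold an empty [0, dim] slice.
        return full if rank == target else full[:0]
    raise ValueError(f"unknown mode: {mode}")

full = np.arange(4 * 3).reshape(4, 3)
for rank in range(2):
    print("distributed, GPU", rank, local_shard(full, rank, 2).tolist())

big = np.arange(5 * 3).reshape(5, 3)
for rank in range(2):
    print("localized:0, GPU", rank, local_shard(big, rank, 2, "localized:0").shape)
```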

Parameters
  • args – compatible with tf.Variable.

  • kwargs – compatible with tf.Variable.

  • mode (string) – a string to specify which model-parallel mode to use. Default value is “distributed”, which selects the Distributed Variable described above. The other option is “localized:#”, which selects the Localized Variable, where # is the index of the GPU that holds the variable. See the examples above.

Returns

variable – a tf.Variable that represents a part of the model-parallel variable.

Return type

tf.Variable

sparse_operation_kit.experiment.lookup.lookup_sparse(params, sp_ids, hotness, combiners)[source]

Abbreviated as sok.experiment.lookup_sparse.

Perform a fused sparse lookup on the given embedding params. This function is similar to tf.nn.embedding_lookup_sparse, with two differences:

  • It can do distributed lookup.

  • It can accept multiple params and multiple sp_ids and perform the lookups in a single fused call, which improves performance.

Parameters
  • params (list, tuple) – a list or tuple of trainable sok.Variable.

  • sp_ids (list, tuple) – a list or tuple of tf.SparseTensor or tf.RaggedTensor.

  • hotness (list, tuple) – a list or tuple of int to specify the max hotness of each lookup.

  • combiners (list, tuple) – a list or tuple of string to specify the combiner of each lookup.

Returns

emb_vec – a list of tf.Tensor (the results of the lookups, one per entry in params).

Return type

list

Example

import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
from sparse_operation_kit import experiment as sok

hvd.init()
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")

sok.init()

v1 = sok.Variable(np.arange(17 * 3).reshape(17, 3), dtype=tf.float32)
v2 = sok.Variable(np.arange(7 * 5).reshape(7, 5), dtype=tf.float32)

indices1 = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
    values=[1, 1, 3, 4, 5],
    dense_shape=[2, 3]
)
indices2 = tf.SparseTensor(
    indices=[[0, 0], [1, 0], [1, 1]],
    values=[1, 2, 3],
    dense_shape=[2, 2]
)

embeddings = sok.lookup_sparse(
    [v1, v2], [indices1, indices2], hotness=[3, 2], combiners=["sum", "sum"]
)
print(embeddings[0])
print(embeddings[1])
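
To check the result by hand, the same lookups can be reproduced on a single process with NumPy. This is a reference sketch only: the helper sum_lookup is hypothetical, and it assumes the "sum" combiner adds the selected rows in the same way as tf.nn.embedding_lookup_sparse.

```python
import numpy as np

# Hypothetical single-process reference; not SOK code.
def sum_lookup(table, ids_per_row):
    """Sum the embedding rows selected by each list of ids."""
    return np.stack([table[ids].sum(axis=0) for ids in ids_per_row])

v1 = np.arange(17 * 3, dtype=np.float32).reshape(17, 3)
v2 = np.arange(7 * 5, dtype=np.float32).reshape(7, 5)

# Same ids as indices1 and indices2 above, written as one list per batch row.
ids1 = [[1, 1], [3, 4, 5]]
ids2 = [[1], [2, 3]]

ref1 = sum_lookup(v1, ids1)  # shape (2, 3)
ref2 = sum_lookup(v2, ids2)  # shape (2, 5)
print(ref1)
print(ref2)
```

Under that assumption, the first row of ref1 is twice row 1 of v1, because id 1 appears twice in the first sample.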

DynamicVariable

class sparse_operation_kit.experiment.dynamic_variable.DynamicVariable(*args, **kwargs)[source]

Abbreviated as sok.experiment.DynamicVariable.

A variable that allocates memory dynamically.

Parameters
  • dimension (int) – The last dimension of this variable (that is, the embedding vector size of the embedding table).

  • initializer (string) – a string to specify how to initialize this variable. Currently, the only supported values are “random” and the string form of a float value (meaning a constant initializer). Default value is “random”.

  • key_type (dtype) – specify the data type of indices. Unlike a static tensorflow variable, this variable is dynamically allocated and contains a hash table. The data type of indices must therefore be specified to construct the hash table. Default value is tf.int64.

  • dtype (dtype) – specify the data type of values. Default value is tf.float32.
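
The idea of a hash-table-backed variable can be illustrated with a small dictionary-based model. This is a conceptual sketch under stated assumptions, not SOK's implementation: the class ToyDynamicTable is hypothetical, it assumes that rows are created the first time a key is looked up, and the value range used for "random" is an arbitrary choice.

```python
import numpy as np

# Hypothetical stand-in for a hash-table-backed embedding variable; not SOK code.
class ToyDynamicTable:
    def __init__(self, dimension, initializer="random", dtype=np.float32):
        self.dimension = dimension
        self.initializer = initializer
        self.dtype = dtype
        self.rows = {}  # key -> embedding vector, filled on first access

    def _new_row(self):
        if self.initializer == "random":
            # Arbitrary range chosen for this sketch only.
            return np.random.uniform(-0.05, 0.05, self.dimension).astype(self.dtype)
        # Any other string is parsed as a float and used as a constant fill value.
        return np.full(self.dimension, float(self.initializer), dtype=self.dtype)

    def lookup(self, keys):
        for key in keys:
            if key not in self.rows:  # assumed behaviour: allocate on first lookup
                self.rows[key] = self._new_row()
        return np.stack([self.rows[key] for key in keys])

    @property
    def size(self):
        return len(self.rows)

table = ToyDynamicTable(dimension=3, initializer="13")
print("size before lookup:", table.size)
out = table.lookup([0, 1, 2**40])
print("embedding shape:", out.shape, "size after lookup:", table.size)
```

Because keys are hashed rather than used as row offsets, a key as large as 2**40 costs no more memory than key 0.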

Example

import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
from sparse_operation_kit import experiment as sok

v = sok.DynamicVariable(dimension=3, initializer="13")
print("v.shape:", v.shape)
print("v.size:", v.size)

indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)

embedding = tf.nn.embedding_lookup(v, indices)
print("embedding:", embedding)
print("v.shape:", v.shape)
print("v.size:", v.size)

sparse_operation_kit.experiment.dynamic_variable.export(var)[source]

Abbreviated as sok.experiment.export.

Export the indices and values tensors from the given variable.

Parameters

var (sok.DynamicVariable) – The variable to extract indices and values from.

Returns

  • indices (tf.Tensor) – The indices of the given variable.

  • values (tf.Tensor) – The values of the given variable.

sparse_operation_kit.experiment.dynamic_variable.assign(var, indices, values)[source]

Abbreviated as sok.experiment.assign.

Assign the indices and values tensors to the target variable.

Parameters
  • var (sok.DynamicVariable) – The target variable of assign.

  • indices (tf.Tensor) – indices to be assigned to the variable.

  • values (tf.Tensor) – values to be assigned to the variable.

Returns

variable

Return type

sok.DynamicVariable
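
The relationship between export and assign can be pictured as a round trip through two parallel arrays. The sketch below uses a plain dictionary instead of a sok.DynamicVariable; the functions toy_export and toy_assign are hypothetical, and the sorted key order is a choice made for this sketch, since the order returned by the real export is not specified here.

```python
import numpy as np

# Hypothetical dictionary-based model of export/assign; not SOK code.
def toy_export(rows):
    """Return parallel arrays of keys and vectors from a dict-backed table."""
    keys = np.array(sorted(rows), dtype=np.int64)
    values = np.stack([rows[int(k)] for k in keys])
    return keys, values

def toy_assign(rows, indices, values):
    """Write values[i] under key indices[i], creating or overwriting entries."""
    for key, vec in zip(indices.tolist(), values):
        rows[key] = np.array(vec)
    return rows

source = {0: np.full(3, 13.0), 2**40: np.full(3, 7.0)}
indices, values = toy_export(source)
target = toy_assign({}, indices, values)
print(indices.tolist(), values.shape, sorted(target) == sorted(source))
```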

sparse_operation_kit.experiment.optimizer.OptimizerWrapper(optimizer)[source]

Abbreviated as sok.experiment.OptimizerWrapper.

This is a wrapper around a tensorflow optimizer that enables it to update sok.DynamicVariable.

Parameters

optimizer (tensorflow optimizer) – The original tensorflow optimizer.

Example

import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
from sparse_operation_kit import experiment as sok

v = sok.DynamicVariable(dimension=3, initializer="13")

indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)

with tf.GradientTape() as tape:
    embedding = tf.nn.embedding_lookup(v, indices)
    print("embedding:", embedding)
    loss = tf.reduce_sum(embedding)

grads = tape.gradient(loss, [v])

optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
optimizer = sok.OptimizerWrapper(optimizer)
optimizer.apply_gradients(zip(grads, [v]))

embedding = tf.nn.embedding_lookup(v, indices)
print("embedding:", embedding)
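
The arithmetic behind this example can be worked through in NumPy. This sketch assumes that the wrapped optimizer applies a plain SGD step to the rows that were looked up; it models the update only and does not call SOK.

```python
import numpy as np

# Rows created by initializer="13" for the three looked-up keys.
keys = [0, 1, 2**40]
rows = {key: np.full(3, 13.0, dtype=np.float32) for key in keys}

# loss = sum of all looked-up elements, so the gradient of each element is 1.
grads = {key: np.ones(3, dtype=np.float32) for key in keys}

learning_rate = 1.0
for key in keys:
    # Plain SGD step, applied only to the rows that received a gradient.
    rows[key] -= learning_rate * grads[key]

print(rows[0])
```

Under that assumption, every looked-up element moves from 13 to 13 - 1.0 * 1 = 12.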