SparseOperationKit Experiment API
Initialize
- sparse_operation_kit.experiment.init(comm_tool='horovod')[source]
- Abbreviated as sok.experiment.init.
- This function initializes SparseOperationKit (SOK). SOK will leverage all GPUs visible to the current CPU process. Please set CUDA_VISIBLE_DEVICES or tf.config.set_visible_devices to specify which GPU(s) this process uses before launching the TensorFlow runtime and calling this function.
- Currently, the experiment API only supports horovod as the communication tool, so horovod.init must be called before initializing SOK.
- Example code for initialization:

```python
import tensorflow as tf
import horovod.tensorflow as hvd
import sparse_operation_kit.experiment as sok

hvd.init()

gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")

sok.init()
```

- Parameters
- comm_tool (string) – a string to specify which communication tool to use. Default value is “horovod”. 
- Return type
- None 
 
Lookup
- sparse_operation_kit.experiment.distributed_variable.Variable(*args, **kwargs)[source]
- Abbreviated as sok.experiment.Variable.
- This is a helper function to generate a model-parallel variable. There are two use cases:
- Distributed Variable:

```python
import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
from sparse_operation_kit import experiment as sok

hvd.init()
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")  # nopep8
sok.init()

# If there are 2 GPUs in total, the shape on GPU0 will be [2, 3] and the shape
# on GPU1 will be [2, 3]
v = sok.Variable(np.arange(4 * 3).reshape(4, 3), dtype=tf.float32)

# GPU0 output: [[0, 1, 2]
#               [6, 7, 8]]
# GPU1 output: [[3, 4, 5]
#               [9, 10, 11]]
print(v)
```

- Localized Variable:

```python
import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
from sparse_operation_kit import experiment as sok

hvd.init()
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")  # nopep8
sok.init()

# If there are 2 GPUs in total, the shape on GPU0 will be [5, 3] and the shape
# on GPU1 will be [0, 3]
v = sok.Variable(
    np.arange(5 * 3).reshape(5, 3), dtype=tf.float32, mode="localized:0"
)
print(v.shape)
```

- As shown in the two examples above, when you need to store different parts of a variable on different GPUs (that is, to allocate a model-parallel variable), this function can help you allocate the required memory on each GPU.
- Parameters
- args – compatible with tf.Variable. 
- kwargs – compatible with tf.Variable. 
- mode (string) – a string to specify which model-parallel mode to use. Default value is “distributed”, which stands for the Distributed Variable mentioned above. Another option is “localized:#”, which stands for the Localized Variable, where # indicates which GPU you want to put this variable on. See the examples above for details. 
 
- Returns
- variable – a tf.Variable that represents a part of the model-parallel variable. 
- Return type
- tf.Variable 
 
- sparse_operation_kit.experiment.lookup.lookup_sparse(params, sp_ids, hotness, combiners)[source]
- Abbreviated as sok.experiment.lookup_sparse.
- Perform fused sparse lookups on the given embedding params. This function is similar to tf.nn.embedding_lookup_sparse, but with two differences:
- It can do distributed lookup. 
- It can accept multiple params and multiple sp_ids to do a fused lookup at once, which brings performance benefits. 
 - Parameters
- params (list, tuple) – a list or tuple of trainable sok.Variable. 
- sp_ids (list, tuple) – a list or tuple of tf.SparseTensor or tf.RaggedTensor (a ragged sketch follows the example below). 
- hotness (list, tuple) – a list or tuple of int to specify the max hotness of each lookup. 
- combiners (list, tuple) – a list or tuple of string to specify the combiner of each lookup. 
 
- Returns
- emb_vec – a list of tf.Tensor (the results of the lookup). 
- Return type
- list 
- Example

```python
import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
from sparse_operation_kit import experiment as sok

hvd.init()
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")
sok.init()

v1 = sok.Variable(np.arange(17 * 3).reshape(17, 3), dtype=tf.float32)
v2 = sok.Variable(np.arange(7 * 5).reshape(7, 5), dtype=tf.float32)

indices1 = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
    values=[1, 1, 3, 4, 5],
    dense_shape=[2, 3],
)
indices2 = tf.SparseTensor(
    indices=[[0, 0], [1, 0], [1, 1]],
    values=[1, 2, 3],
    dense_shape=[2, 2],
)

embeddings = sok.lookup_sparse(
    [v1, v2], [indices1, indices2], hotness=[3, 2], combiners=["sum", "sum"]
)
print(embeddings[0])
print(embeddings[1])
```
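Because sp_ids also accepts tf.RaggedTensor, the same fused lookup can be driven by ragged inputs. Below is a minimal sketch of that variant, reusing v1 and v2 from the example above; the ragged rows are assumed to hold the per-sample ids, and the id dtype may need to match the variables' key type.

```python
# Ragged equivalents of indices1 / indices2 from the example above:
# row i of each RaggedTensor holds the ids for sample i.
ragged1 = tf.ragged.constant([[1, 1], [3, 4, 5]], dtype=tf.int64)
ragged2 = tf.ragged.constant([[1], [2, 3]], dtype=tf.int64)

embeddings = sok.lookup_sparse(
    [v1, v2], [ragged1, ragged2], hotness=[3, 2], combiners=["sum", "sum"]
)
print(embeddings[0])
print(embeddings[1])
```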
DynamicVariable
- class sparse_operation_kit.experiment.dynamic_variable.DynamicVariable(*args, **kwargs)[source]
- Abbreviated as sok.experiment.DynamicVariable.
- A variable that allocates memory dynamically.
- Parameters
- dimension (int) – The last dimension of this variable (that is, the embedding vector size of the embedding table). 
- initializer (string) – a string to specify how to initialize this variable. Currently, only “random” or the string form of a float value (meaning a constant initializer) is supported. Default value is “random”. 
- key_type (dtype) – specify the data type of indices. Unlike a static TensorFlow variable, this variable is dynamically allocated and contains a hash table inside it, so the data type of indices must be specified to construct the hash table. Default value is tf.int64. A sketch combining key_type and dtype follows the example below. 
- dtype (dtype) – specify the data type of values. Default value is tf.float32. 
 
- Example

```python
import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
from sparse_operation_kit import experiment as sok

v = sok.DynamicVariable(dimension=3, initializer="13")
print("v.shape:", v.shape)
print("v.size:", v.size)

indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)

embedding = tf.nn.embedding_lookup(v, indices)
print("embedding:", embedding)
print("v.shape:", v.shape)
print("v.size:", v.size)
```
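As referenced in the parameter list above, key_type and dtype can be set explicitly when the defaults do not fit. A brief sketch, assuming 32-bit keys are sufficient for the id space (the variable name and values are illustrative only):

```python
import tensorflow as tf
from sparse_operation_kit import experiment as sok

# Hash table keyed by int32 ids, storing float32 embedding vectors of size 8,
# initialized with the constant value 0.5.
v = sok.DynamicVariable(
    dimension=8, initializer="0.5", key_type=tf.int32, dtype=tf.float32
)

ids = tf.convert_to_tensor([7, 42], dtype=tf.int32)
emb = tf.nn.embedding_lookup(v, ids)
print("emb:", emb)
```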
- sparse_operation_kit.experiment.dynamic_variable.export(var)[source]
- Abbreviated as sok.experiment.export.
- Export the indices and values tensors from the given variable. A usage sketch follows the return values below.
- Parameters
- var (sok.DynamicVariable) – The variable from which to extract indices and values. 
- Returns
- indices (tf.Tensor) – The indices of the given variable. 
- values (tf.Tensor) – The values of the given variable. 
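A minimal sketch of exporting a DynamicVariable that has been populated by a lookup, as in the DynamicVariable example above (the printed shapes are illustrative; the ordering of the exported keys is not guaranteed):

```python
import tensorflow as tf
from sparse_operation_kit import experiment as sok

v = sok.DynamicVariable(dimension=3, initializer="13")

# Looking up keys inserts them into the dynamic variable's hash table.
keys = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)
_ = tf.nn.embedding_lookup(v, keys)

# Extract the stored keys and their embedding vectors.
indices, values = sok.export(v)
print("indices:", indices)  # 3 keys
print("values:", values)    # one 3-dim embedding vector per key
```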
 
 
- sparse_operation_kit.experiment.dynamic_variable.assign(var, indices, values)[source]
- Abbreviated as sok.experiment.assign.
- Assign the indices and values tensors to the target variable. A usage sketch follows the return type below.
- Parameters
- var (sok.DynamicVariable) – The target variable of assign. 
- indices (tf.Tensor) – indices to be assigned to the variable. 
- values (tf.Tensor) – values to be assigned to the variable. 
 
- Returns
- variable – the target variable after the assignment. 
- Return type
- sok.DynamicVariable 
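A minimal sketch that pairs sok.export with sok.assign, for example to copy one dynamic table into another; this is illustrative only, and assumes the two variables share the same dimension and key type:

```python
import tensorflow as tf
from sparse_operation_kit import experiment as sok

src = sok.DynamicVariable(dimension=3, initializer="13")
dst = sok.DynamicVariable(dimension=3, initializer="random")

# Populate the source table by looking up a few keys.
keys = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)
_ = tf.nn.embedding_lookup(src, keys)

# Copy the exported contents into the destination table.
indices, values = sok.export(src)
dst = sok.assign(dst, indices, values)
print("dst.size:", dst.size)
```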
 
- sparse_operation_kit.experiment.optimizer.OptimizerWrapper(optimizer)[source]
- Abbreviated as sok.experiment.OptimizerWrapper.
- This is a wrapper for a TensorFlow optimizer so that it can update sok.DynamicVariable.
- Parameters
- optimizer (TensorFlow optimizer) – The original TensorFlow optimizer. 
- Example

```python
import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
from sparse_operation_kit import experiment as sok

v = sok.DynamicVariable(dimension=3, initializer="13")
indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)

with tf.GradientTape() as tape:
    embedding = tf.nn.embedding_lookup(v, indices)
    print("embedding:", embedding)
    loss = tf.reduce_sum(embedding)
grads = tape.gradient(loss, [v])

optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
optimizer = sok.OptimizerWrapper(optimizer)
optimizer.apply_gradients(zip(grads, [v]))

embedding = tf.nn.embedding_lookup(v, indices)
print("embedding:", embedding)
```