SparseOperationKit Dynamic Variable

class sparse_operation_kit.dynamic_variable.DynamicVariable(*args, **kwargs)

Abbreviated as sok.DynamicVariable.

A variable that allocates memory dynamically.

Parameters:
  • dimension (int) – The last dimension of this variable (that is, the embedding vector size of the embedding table).

  • initializer (string) –

    a string specifying how to initialize this variable.

    Currently, only “random” or the string form of a float value (meaning a constant initializer, e.g. “13”) is supported. Default value is “random”.

  • var_type (string) –

    a string specifying whether to use DET (“hbm”) or HKV (“hybrid”) as the backend.

    Default value is “hbm”.

    When DET is the backend, it reserves two key values as the empty value and the reclaim value of its hash table. If an input key equals either of these values, the program will crash.

    If the key type is signed, the empty value is std::numeric_limits<T>::max() and the reclaim value is std::numeric_limits<T>::min(), where T is the key type.

    If the key type is unsigned, the empty value is std::numeric_limits<T>::max() and the reclaim value is std::numeric_limits<T>::max() - 1.

    When HKV is the backend, only tf.int64 is supported as key_type.

    When HKV is the backend, set init_capacity and max_capacity to powers of 2.

  • key_type (dtype) – the data type of the indices. Unlike a static TensorFlow variable, this variable is allocated dynamically and contains a hash table, so the data type of the indices must be specified in order to construct that hash table. Default value is tf.int64.

  • dtype (dtype) – the data type of the values. Default value is tf.float32.
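
The reserved-key and capacity rules listed under var_type can be checked on the host before any lookup. The following is a minimal sketch in NumPy. The helper names (det_reserved_keys, find_reserved_keys, is_power_of_two, check_hkv_config) are hypothetical and not part of SOK, and using np.int64 to stand for tf.int64 is an assumption made for illustration.

```python
import numpy as np

# Hypothetical helpers, not part of SOK: they only encode the var_type rules above.


def det_reserved_keys(key_dtype):
    """Return (empty_key, reclaim_key) reserved by the DET ("hbm") backend."""
    info = np.iinfo(key_dtype)
    if np.issubdtype(key_dtype, np.signedinteger):
        # Signed keys: empty = max(), reclaim = min().
        return info.max, info.min
    # Unsigned keys: empty = max(), reclaim = max() - 1.
    return info.max, info.max - 1


def find_reserved_keys(keys, key_dtype=np.int64):
    """Return the entries of keys that collide with a DET reserved value."""
    keys = np.asarray(keys, dtype=key_dtype)
    reserved = np.array(det_reserved_keys(key_dtype), dtype=key_dtype)
    return keys[np.isin(keys, reserved)]


def is_power_of_two(n):
    """True if n is a positive power of 2."""
    return n > 0 and (n & (n - 1)) == 0


def check_hkv_config(key_dtype, init_capacity, max_capacity):
    """Raise ValueError if an HKV ("hybrid") setup breaks the documented rules."""
    if np.dtype(key_dtype) != np.dtype(np.int64):
        raise ValueError("the HKV backend only supports int64 keys")
    for name, cap in (("init_capacity", init_capacity), ("max_capacity", max_capacity)):
        if not is_power_of_two(cap):
            raise ValueError(f"{name}={cap} is not a power of 2")
```

For example, running find_reserved_keys(indices.numpy(), np.int64) on a batch before calling tf.nn.embedding_lookup would let a caller reject a colliding key up front rather than hit the crash described above.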

Example

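import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
import sparse_operation_kit as sok

# A dynamic variable with 3-dimensional embedding vectors; every new row
# is initialized to the constant 13.0.
v = sok.DynamicVariable(dimension=3, initializer="13")
print("v.shape:", v.shape)
print("v.size:", v.size)

# Keys need not be contiguous or bounded by a preset vocabulary size:
# 2**40 is a valid int64 key.
indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)

# Looking up keys that are not yet in the table allocates rows for them,
# so shape and size are printed again for comparison.
embedding = tf.nn.embedding_lookup(v, indices)
print("embedding:", embedding)
print("v.shape:", v.shape)
print("v.size:", v.size)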
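
To make the dynamic-allocation behaviour in the example concrete, here is a toy, dict-backed model in plain Python and NumPy. It is not SOK's implementation, which is backed by the DET or HKV hash table. The model assumes, as the before/after v.size printouts suggest, that looking up an unseen key inserts a freshly initialized row. The class name ToyDynamicTable is hypothetical, and the normal distribution used for “random” is an arbitrary stand-in because the actual random initializer is not specified here.

```python
import numpy as np


class ToyDynamicTable:
    """Hypothetical dict-backed stand-in for a dynamically allocated embedding
    table. Not SOK code: it only mimics insert-on-lookup behaviour."""

    def __init__(self, dimension, initializer="random", seed=0):
        self.dimension = dimension
        self.initializer = initializer
        self._rng = np.random.default_rng(seed)
        self._rows = {}  # key -> np.ndarray of shape (dimension,)

    def _new_row(self):
        if self.initializer == "random":
            # Arbitrary choice for this sketch; SOK's "random" may differ.
            return self._rng.standard_normal(self.dimension).astype(np.float32)
        # Any other string is read as a float constant, e.g. "13" -> 13.0.
        return np.full(self.dimension, float(self.initializer), dtype=np.float32)

    def lookup(self, keys):
        out = []
        for k in np.asarray(keys).tolist():
            if k not in self._rows:  # unseen key: allocate a row on first use
                self._rows[k] = self._new_row()
            out.append(self._rows[k])
        return np.stack(out)

    @property
    def size(self):
        # In this toy model, size counts the keys currently stored.
        return len(self._rows)


table = ToyDynamicTable(dimension=3, initializer="13")
print(table.size)                        # no keys stored yet
print(table.lookup([0, 1, 2**40]))       # three rows of 13.0
print(table.size)                        # one row per distinct key looked up
```

Memory grows with the number of distinct keys actually seen, not with the largest key value, which is why a key such as 2**40 costs no more than key 0.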

sparse_operation_kit.dynamic_variable.export(var)

Abbreviated as sok.export.

Export the indices and values of the given variable as tensors.

Parameters:

var (sok.DynamicVariable) – The variable from which to extract the indices and values.

Returns:

  • indices (tf.Tensor) – The indices of the given variable.

  • values (tf.Tensor) – The values of the given variable.

sparse_operation_kit.dynamic_variable.assign(var, indices, values)

Abbreviated as sok.assign.

Assign the given indices and values tensors to the target variable.

Parameters:
  • var (sok.DynamicVariable) – The target variable of assign.

  • indices (tf.Tensor) – The indices to be assigned to the variable.

  • values (tf.Tensor) – The values to be assigned to the variable.

Returns:

variable

Return type:

sok.DynamicVariable
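
Together, export and assign allow a table to be copied out and restored, for instance around a checkpoint. Below is a minimal sketch of the file-handling half. It assumes the exported tensors have been converted to NumPy arrays with .numpy() and that values holds one row of length dimension per key; the exact shape is not stated above, so treat that as an assumption. save_table and load_table are hypothetical helpers built on numpy.savez and numpy.load, not SOK functions.

```python
import numpy as np


def save_table(path, indices, values):
    """Hypothetical helper: write exported indices/values to an .npz file."""
    indices = np.asarray(indices)
    values = np.asarray(values)
    # Assumed layout: one row of values per key in indices.
    if values.shape[0] != indices.shape[0]:
        raise ValueError("indices and values must have the same number of rows")
    np.savez(path, indices=indices, values=values)


def load_table(path):
    """Hypothetical helper: read back the (indices, values) pair."""
    with np.load(path) as data:
        return data["indices"], data["values"]
```

With SOK available, the round trip would look roughly like indices, values = sok.export(v), then save_table("table.npz", indices.numpy(), values.numpy()), and later idx, val = load_table("table.npz") followed by sok.assign(v, tf.convert_to_tensor(idx), tf.convert_to_tensor(val)). Treat this wiring as a sketch rather than a verified recipe.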

sparse_operation_kit.optimizer.OptimizerWrapper(optimizer)

Abbreviated as sok.OptimizerWrapper.

This is a wrapper for TensorFlow optimizers that enables them to update sok.DynamicVariable.

Parameters:

optimizer (TensorFlow optimizer) – The original TensorFlow optimizer.

Example

import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
import sparse_operation_kit as sok

v = sok.DynamicVariable(dimension=3, initializer="13")

indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)

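# Record the lookup so that gradients with respect to v can be computed.
with tf.GradientTape() as tape:
    embedding = tf.nn.embedding_lookup(v, indices)
    print("embedding:", embedding)
    loss = tf.reduce_sum(embedding)

grads = tape.gradient(loss, [v])

# Wrap a standard Keras optimizer so that it can update the DynamicVariable.
optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
optimizer = sok.OptimizerWrapper(optimizer)
optimizer.apply_gradients(zip(grads, [v]))

# Look up the same keys again to see the updated rows.
embedding = tf.nn.embedding_lookup(v, indices)
print("embedding:", embedding)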
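
The values the example should print follow from plain SGD arithmetic. The first lookup returns 13.0 everywhere. Because loss = reduce_sum(embedding), every looked-up element has a gradient of 1, and with learning_rate=1.0 each entry becomes 13.0 - 1.0 * 1 = 12.0. The NumPy sketch below reproduces that arithmetic on a dict-backed table. sparse_sgd_step is a hypothetical helper and is not how OptimizerWrapper is implemented; it only shows the update that SGD implies for the rows that were looked up.

```python
import numpy as np


def sparse_sgd_step(table, keys, row_grads, learning_rate):
    """Hypothetical helper: apply w <- w - lr * g to each looked-up row of a
    dict-backed table. A repeated key receives each of its gradients in turn,
    which for plain SGD equals subtracting their sum once."""
    for key, g in zip(keys, row_grads):
        table[key] = table[key] - learning_rate * g
    return table


keys = [0, 1, 2**40]
# State after the first lookup with initializer="13".
table = {k: np.full(3, 13.0, dtype=np.float32) for k in keys}
# loss = reduce_sum(embedding): every element of every looked-up row has gradient 1.
row_grads = np.ones((len(keys), 3), dtype=np.float32)

sparse_sgd_step(table, keys, row_grads, learning_rate=1.0)
print(np.stack([table[k] for k in keys]))  # every entry is now 12.0
```

Only rows whose keys appear in indices are touched. Under plain SGD, the second print in the example should therefore show 12.0 in every position.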