SparseOperationKit Optimizer Scope

class sparse_operation_kit.core.context_scope.OptimizerScope(trainable_variables)

A context manager used together with TensorFlow optimizers. It is needed only when a TensorFlow native optimizer is used to update embedding variables.

Abbreviated as sok.OptimizerScope(variables).

Parameters

trainable_variables (list, tuple) – a list or tuple of trainable tf.Variable objects.

Returns

context_manager – used to switch handles for embedding variables.

Return type

context_manager
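The Returns entry says the context manager switches handles for embedding variables. The plain-Python sketch below illustrates only the general save/swap/restore pattern such a manager follows. `HypotheticalEmbeddingVariable`, its `handle`/`tf_handle` attributes, and `optimizer_scope_sketch` are hypothetical stand-ins, not SOK's implementation.

```python
from contextlib import contextmanager


class HypotheticalEmbeddingVariable:
    """Hypothetical stand-in for an embedding variable that owns two handles."""

    def __init__(self, name):
        self.name = name
        self.handle = f"{name}:sok_handle"    # handle used outside the scope
        self.tf_handle = f"{name}:tf_handle"  # handle a TF optimizer could update


@contextmanager
def optimizer_scope_sketch(variables):
    """Hypothetical: swap each variable's handle, then restore it on exit."""
    saved = [v.handle for v in variables]
    try:
        for v in variables:
            v.handle = v.tf_handle  # expose the handle the optimizer expects
        yield variables
    finally:
        for v, h in zip(variables, saved):
            v.handle = h  # restore even if the body raises


emb_vars = (HypotheticalEmbeddingVariable("emb_0"), HypotheticalEmbeddingVariable("emb_1"))
with optimizer_scope_sketch(emb_vars):
    inside = [v.handle for v in emb_vars]
after = [v.handle for v in emb_vars]
print(inside)  # ['emb_0:tf_handle', 'emb_1:tf_handle']
print(after)   # ['emb_0:sok_handle', 'emb_1:sok_handle']
```

Restoring in a `finally` clause means the original handles come back even if the optimizer step inside the `with` block raises.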

Example

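import tensorflow as tf
import sparse_operation_kit as sok

strategy = tf.distribute.MirroredStrategy()  # or another tf.distribute strategy

with strategy.scope():
    model = ...  # a Keras model containing SOK embedding layers

    emb_opt = tf.keras.optimizers.Adam(...)    # updates the embedding variables
    other_opt = tf.keras.optimizers.Adam(...)  # updates all remaining variables

loss_fn = ...  # user-defined loss function

@tf.function
def _train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs)
        loss = loss_fn(logits, labels)

    # Separate SOK embedding variables from the ordinary dense variables.
    emb_vars, other_vars = sok.split_embedding_variable_from_others(model.trainable_variables)
    emb_grads, other_grads = tape.gradient(loss, [emb_vars, other_vars])

    # Only the embedding variables are updated inside the scope.
    with sok.OptimizerScope(emb_vars):
        emb_opt.apply_gradients(zip(emb_grads, emb_vars),
                                experimental_aggregate_gradients=False)

    other_opt.apply_gradients(zip(other_grads, other_vars))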
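In the example, only the embedding half of the variable split is passed to sok.OptimizerScope, while the other variables are updated by the second optimizer outside the scope. The sketch below shows that order-preserving partition in plain Python, assuming a hypothetical name-based predicate; `split_by_predicate` is illustrative only and does not reproduce how sok.split_embedding_variable_from_others identifies embedding variables.

```python
def split_by_predicate(variables, is_embedding):
    """Hypothetical: partition variables into (embedding, other), keeping order."""
    emb, other = [], []
    for v in variables:
        (emb if is_embedding(v) else other).append(v)
    return emb, other


# Hypothetical variable names; a name prefix stands in for the real check.
names = ["emb_0", "dense/kernel", "emb_1", "dense/bias"]
emb_vars, other_vars = split_by_predicate(names, lambda n: n.startswith("emb_"))
print(emb_vars)    # ['emb_0', 'emb_1']
print(other_vars)  # ['dense/kernel', 'dense/bias']
```

Keeping the original order in each half matters because the gradients returned by tape.gradient are paired with the variables positionally through zip.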

Notes

This context manager may be removed in the next release.