SparseOperationKit Optimizer Scope
- class sparse_operation_kit.core.context_scope.OptimizerScope(trainable_variables)
A context manager used together with TensorFlow optimizers. It is only needed when TensorFlow native optimizers are used.
Abbreviated as sok.OptimizerScope(variables).
- Parameters
trainable_variables (list, tuple) – a list or tuple of trainable tf.Variable.
- Returns
context_manager – used to switch handles for embedding variables.
- Return type
context_manager
Example
    import tensorflow as tf
    import sparse_operation_kit as sok

    with strategy.scope():
        model = ...
        emb_opt = tf.keras.optimizers.Adam(...)
        other_opt = tf.keras.optimizers.Adam(...)

    @tf.function
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logits = model(inputs)
            loss = loss_fn(logits, labels)
        # Separate the embedding variables from the remaining (dense) variables.
        emb_vars, other_vars = sok.split_embedding_variable_from_others(model.trainable_variables)
        emb_grads, other_grads = tape.gradient(loss, [emb_vars, other_vars])
        # Apply the embedding gradients inside the scope.
        with sok.OptimizerScope(emb_vars):
            emb_opt.apply_gradients(zip(emb_grads, emb_vars),
                                    experimental_aggregate_gradients=False)
        # The other variables are updated outside the scope.
        other_opt.apply_gradients(zip(other_grads, other_vars))
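The description says the returned context manager is "used to switch handles for embedding variables". As a rough illustration of that general pattern only, the pure-Python sketch below swaps an attribute on entry and restores it on exit. It is not SparseOperationKit's implementation: `FakeVariable`, its `handle` and `optimizer_handle` attributes, and `swap_handles` are all hypothetical names invented for this sketch.

```python
from contextlib import contextmanager


class FakeVariable:
    """Hypothetical stand-in for an embedding variable (not a SOK or TF class)."""

    def __init__(self, name):
        self.name = name
        self.handle = f"{name}/default_handle"
        self.optimizer_handle = f"{name}/optimizer_handle"


@contextmanager
def swap_handles(variables):
    """Temporarily point each variable's handle at its optimizer handle."""
    if not isinstance(variables, (list, tuple)):
        raise TypeError("variables must be a list or tuple")
    saved = [v.handle for v in variables]
    for v in variables:
        v.handle = v.optimizer_handle
    try:
        yield variables
    finally:
        # Restore the original handles even if the body raises.
        for v, original in zip(variables, saved):
            v.handle = original


emb_vars = [FakeVariable("emb0"), FakeVariable("emb1")]
with swap_handles(emb_vars):
    inside = [v.handle for v in emb_vars]  # handles as seen inside the scope
after = [v.handle for v in emb_vars]       # handles after the scope exits
```

The `try`/`finally` is the point of the sketch: whatever happens inside the `with` block, the variables leave the scope in the state they entered it, which is why the optimizer call is wrapped in a scope instead of mutating the variables permanently.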
Notes
This context manager may no longer be used in the next release.