SparseOperationKit Optimizer Utils
- sparse_operation_kit.optimizers.utils.split_embedding_variable_from_others(variables)
This function is used to split embedding variables from other variables.
Abbreviated as sok.split_embedding_variable_from_others(variables).
Embedding variables are created automatically along with embedding layers. Because embedding variables are aggregated differently from other variables, they must be separated from the other variables so that the optimizer can process each group in its own way.
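A minimal sketch of the splitting contract described above, written in plain Python so it runs without TensorFlow or SOK. This page does not say how SOK recognises an embedding variable, so the `is_embedding` predicate and the `FakeVariable` class are hypothetical stand-ins. Only the contract comes from this page: the input is a list or tuple, the output is two tuples, and every input variable lands in exactly one of them. Keeping the input order is a design choice of this sketch.

```python
from typing import Callable, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_by_predicate(
    variables: Sequence[T], is_embedding: Callable[[T], bool]
) -> Tuple[Tuple[T, ...], Tuple[T, ...]]:
    """Partition variables into (embedding_variables, other_variables).

    Hypothetical illustration, not the SOK implementation. Input order is kept.
    """
    if not isinstance(variables, (list, tuple)):
        raise TypeError("variables must be a list or tuple")
    embedding_variables = tuple(v for v in variables if is_embedding(v))
    other_variables = tuple(v for v in variables if not is_embedding(v))
    return embedding_variables, other_variables


class FakeVariable:
    """Hypothetical stand-in for tf.Variable, used only for illustration."""

    def __init__(self, name: str, is_embedding: bool = False):
        self.name = name
        self.is_embedding = is_embedding

    def __repr__(self) -> str:
        return f"FakeVariable({self.name!r})"


variables = [
    FakeVariable("embedding/table", is_embedding=True),
    FakeVariable("dense/kernel"),
    FakeVariable("dense/bias"),
]
emb_vars, other_vars = split_by_predicate(variables, lambda v: v.is_embedding)
print([v.name for v in emb_vars])    # ['embedding/table']
print([v.name for v in other_vars])  # ['dense/kernel', 'dense/bias']
```

Returning tuples rather than lists mirrors the documented return types and prevents callers from accidentally mutating either group.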
- Parameters
variables (list, tuple) – a list or tuple of trainable tf.Variable objects.
- Returns
embedding_variables (tuple) – all embedding variables in the input variable-list.
other_variables (tuple) – all normal variables in the input variable-list.
Example
    import tensorflow as tf
    import sparse_operation_kit as sok

    class Model(tf.keras.models.Model):
        def __init__(self, *args, **kwargs):
            super(Model, self).__init__(*args, **kwargs)
            # The embedding layer owns the embedding variables.
            self.embedding_layer = sok.DistributedEmbedding(...)
            self.dense_layer = tf.keras.layers.Dense(units=1, ...)

        def call(self, inputs, training=True):
            vectors = self.embedding_layer(inputs, training)
            out = self.dense_layer(vectors)
            return out

    # `strategy` and `dataset` are assumed to be defined elsewhere.
    with strategy.scope():
        model = Model()

    loss_fn = ...

    @tf.function
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            out = model(inputs)
            loss = loss_fn(out, labels)
        # Separate the embedding variables from the dense-layer variables
        # so that each group can be processed in its own way.
        emb_vars, other_vars = sok.split_embedding_variable_from_others(model.trainable_variables)
        ...

    for step, (inputs, labels) in enumerate(dataset):
        strategy.run(_train_step, args=(inputs, labels))
        ...
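This page states only that embedding variables are aggregated differently from other variables. The plain-Python sketch below illustrates one common reason, under an assumption this page does not confirm: embedding tables are sharded across replicas (model parallel), so each replica updates only the shard it owns with its local gradients, while dense variables are replicated (data parallel), so their gradients are averaged across replicas before every replica applies the same update. It uses neither TensorFlow nor SOK; `sgd_apply`, `train_step`, and the variable names are hypothetical.

```python
from typing import Dict, List

Params = Dict[str, float]
Grads = Dict[str, float]


def sgd_apply(params: Params, grads: Grads, lr: float) -> None:
    """Plain SGD, in place: params[name] -= lr * grad."""
    for name, grad in grads.items():
        params[name] -= lr * grad


def train_step(
    replica_params: List[Params],
    replica_emb_grads: List[Grads],
    replica_dense_grads: List[Grads],
    lr: float,
) -> None:
    num_replicas = len(replica_params)
    # Dense (data-parallel) variables: average the gradients across replicas
    # (an all-reduce), then every replica applies the same update.
    names = replica_dense_grads[0].keys()
    mean_dense = {
        n: sum(g[n] for g in replica_dense_grads) / num_replicas for n in names
    }
    for params, emb_grads in zip(replica_params, replica_emb_grads):
        sgd_apply(params, mean_dense, lr)
        # Embedding shards (assumed model-parallel): each replica updates only
        # the shard it owns, using its own gradients, with no aggregation.
        sgd_apply(params, emb_grads, lr)


replicas = [
    {"dense/kernel": 1.0, "emb_shard_0": 1.0},
    {"dense/kernel": 1.0, "emb_shard_1": 1.0},
]
train_step(
    replicas,
    replica_emb_grads=[{"emb_shard_0": 0.5}, {"emb_shard_1": 1.0}],
    replica_dense_grads=[{"dense/kernel": 0.25}, {"dense/kernel": 0.75}],
    lr=0.5,
)
print(replicas)
# [{'dense/kernel': 0.75, 'emb_shard_0': 0.75}, {'dense/kernel': 0.75, 'emb_shard_1': 0.5}]
```

If both groups went through the same averaging step, each embedding shard would receive gradients scaled by 1/num_replicas (and other replicas' gradients for names they do not own), which is why splitting the variable list before applying updates matters.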