SparseOperationKit Optimizer Utils


Splits embedding variables from all other variables.

Abbreviated as sok.split_embedding_variable_from_others(variables).

Embedding variables are created automatically along with embedding layers. Because the aggregation for embedding variables differs from that for other variables, the two groups must be split so that the optimizer can process them in different ways.
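The following pure-Python sketch illustrates one way aggregation can differ between the two groups. It is an illustration under stated assumptions, not SOK's implementation: it assumes that ordinary variables are replicated and have their gradients summed across replicas, while each replica keeps the gradient for its own embedding shard. The gradient values and the names `replica_grads`, `aggregate_dense`, and `local_embedding_grads` are hypothetical.

```python
from typing import Dict, List

# Hypothetical per-replica gradients for a 2-replica job (plain floats for brevity).
replica_grads: List[Dict[str, float]] = [
    {"dense/kernel": 0.25, "embedding_shard": 1.0},   # replica 0
    {"dense/kernel": 0.75, "embedding_shard": -2.0},  # replica 1
]


def aggregate_dense(grads: List[Dict[str, float]], name: str) -> float:
    """Sum one replicated variable's gradient across all replicas."""
    return sum(g[name] for g in grads)


def local_embedding_grads(grads: List[Dict[str, float]], name: str) -> List[float]:
    """Keep each replica's embedding gradient separate (no cross-replica reduction)."""
    return [g[name] for g in grads]


# Assumption: ordinary variables are reduced, embedding shards are not.
dense_grad = aggregate_dense(replica_grads, "dense/kernel")
embedding_grads = local_embedding_grads(replica_grads, "embedding_shard")

print(dense_grad)       # 1.0
print(embedding_grads)  # [1.0, -2.0]
```

Because the two groups follow different update paths in this sketch, the caller has to know which variables belong to which group before applying gradients.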


Parameters:

variables (list, tuple) – a list or tuple of trainable tf.Variable objects.


Returns:

  • embedding_variables (tuple) – all embedding variables in the input variable-list.

  • other_variables (tuple) – all normal variables in the input variable-list.
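To make the input and output contract concrete, here is a minimal pure-Python sketch of the partitioning described above. It is not SOK's code: `FakeVariable`, `FakeEmbeddingVariable`, and `split_by_type` are hypothetical stand-ins, and identifying embedding variables by their type is an assumption made only for this illustration.

```python
from typing import Sequence, Tuple


class FakeVariable:
    """Hypothetical stand-in for an ordinary trainable tf.Variable."""

    def __init__(self, name: str):
        self.name = name


class FakeEmbeddingVariable(FakeVariable):
    """Hypothetical stand-in for a variable created by an embedding layer."""


def split_by_type(variables: Sequence[FakeVariable]) -> Tuple[tuple, tuple]:
    """Partition variables into (embedding_variables, other_variables).

    Assumption: embedding variables are recognized by their type.
    Input order is preserved within each group.
    """
    if not isinstance(variables, (list, tuple)):
        raise ValueError("variables must be a list or tuple")
    embedding_variables, other_variables = [], []
    for v in variables:
        if isinstance(v, FakeEmbeddingVariable):
            embedding_variables.append(v)
        else:
            other_variables.append(v)
    return tuple(embedding_variables), tuple(other_variables)


variables = [
    FakeEmbeddingVariable("embedding_table"),
    FakeVariable("dense/kernel"),
    FakeVariable("dense/bias"),
]
emb_vars, other_vars = split_by_type(variables)
print([v.name for v in emb_vars])    # ['embedding_table']
print([v.name for v in other_vars])  # ['dense/kernel', 'dense/bias']
```

Both results are tuples, matching the documented return types, and every input variable lands in exactly one of the two groups.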


import tensorflow as tf
import sparse_operation_kit as sok

class Model(tf.keras.models.Model):
    def __init__(self, *args, **kwargs):
        super(Model, self).__init__(*args, **kwargs)

        self.embedding_layer = sok.DistributedEmbedding(...)
        self.dense_layer = tf.keras.layers.Dense(units=1, ...)

    def call(self, inputs, training=True):
        vectors = self.embedding_layer(inputs, training)
        out = self.dense_layer(vectors)
        return out

with strategy.scope():
    model = Model()

loss_fn = ...

def _train_step(inputs, labels):
    with tf.GradientTape() as tape:
        out = model(inputs)
        loss = loss_fn(out, labels)
    # Split the trainable variables so each group can be handled by the optimizer separately.
    emb_vars, other_vars = sok.split_embedding_variable_from_others(model.trainable_variables)

for step, (inputs, labels) in enumerate(dataset):
    strategy.run(_train_step, args=(inputs, labels))