SparseOperationKit Optimizer Utils

sparse_operation_kit.optimizers.utils.split_embedding_variable_from_others(variables)

This function is used to split embedding variables from other variables.

Abbreviated as sok.split_embedding_variable_from_others(variables).

Embedding variables are created automatically along with embedding layers. Because gradient aggregation for embedding variables is different from that of other variables, the embedding variables need to be split from the others so that the optimizer can process each group in a different way, as illustrated below.

Parameters

variables (list or tuple) – a list or tuple of trainable tf.Variable objects.

Returns

  • embedding_variables (tuple) – all embedding variables in the input variable list.

  • other_variables (tuple) – all other (non-embedding) trainable variables in the input variable list.
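
As a minimal sketch of a call, assuming model is a tf.keras model that contains at least one SOK embedding layer (as in the example below):

import sparse_operation_kit as sok

emb_vars, other_vars = sok.split_embedding_variable_from_others(model.trainable_variables)

# emb_vars:   the variables created by the SOK embedding layers
# other_vars: the remaining trainable variables (e.g. dense-layer kernels and biases)
assert len(emb_vars) + len(other_vars) == len(model.trainable_variables)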

Example

import tensorflow as tf
import sparse_operation_kit as sok

class Model(tf.keras.models.Model):
    def __init__(self, *args, **kwargs):
        super(Model, self).__init__(*args, **kwargs)

        self.embedding_layer = sok.DistributedEmbedding(...)
        self.dense_layer = tf.keras.layers.Dense(units=1, ...)

    def call(self, inputs, training=True):
        vectors = self.embedding_layer(inputs, training)
        out = self.dense_layer(vectors)
        return out

with strategy.scope():
    model = Model()

loss_fn = ...

@tf.function
def _train_step(inputs, labels):
    with tf.GradientTape() as tape:
        out = model(inputs)
        loss = loss_fn(out, labels)
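    # split the trainable variables so that each group can be handled differently by its optimizer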
    emb_vars, other_vars = sok.split_embedding_variable_from_others(model.trainable_variables)
    ...

for step, (inputs, labels) in enumerate(dataset):
    strategy.run(_train_step, args=(inputs, labels))
    ...
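
The elided body of _train_step above is intentionally left open. One common completion, shown here only as a sketch, computes the two gradient groups and hands each to its own optimizer; emb_optimizer and other_optimizer are hypothetical names for two separately constructed tf.keras optimizer instances and are not part of this API. The experimental_aggregate_gradients=False flag (available on TF 2.x Keras optimizers) skips TF's cross-replica aggregation for the embedding gradients, reflecting that their aggregation is handled differently, as described above.

@tf.function
def _train_step(inputs, labels):
    with tf.GradientTape() as tape:
        out = model(inputs)
        loss = loss_fn(out, labels)
    emb_vars, other_vars = sok.split_embedding_variable_from_others(model.trainable_variables)
    # gradients come back in the same structure as the variable groups
    emb_grads, other_grads = tape.gradient(loss, [emb_vars, other_vars])
    # embedding variables: skip TF's cross-replica aggregation
    emb_optimizer.apply_gradients(zip(emb_grads, emb_vars),
                                  experimental_aggregate_gradients=False)
    # all other variables: apply gradients in the usual way
    other_optimizer.apply_gradients(zip(other_grads, other_vars))
    return loss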