TF Distributed Embedding

Wrapper classes to build a model-parallel embedding layer entirely with TensorFlow’s API. It uses tf.distribute.Strategy to perform the communication among different GPUs.

class sparse_operation_kit.embeddings.tf_distributed_embedding.TFDistributedEmbedding(*args, **kwargs)

This embedding layer distributes its embedding parameters across multiple GPUs. It leverages tf.distribute.Strategy for communication, so it must be used together with a tf.distribute.Strategy.

Parameters

  • vocabulary_size (integer) – the first dimension of the embedding variable, whose shape is [vocabulary_size, embedding_vec_size].

  • embedding_vec_size (integer) – the second dimension of the embedding variable, whose shape is [vocabulary_size, embedding_vec_size].

  • initializer (string or numpy.array = 'GlorotNormal') – when it is a string, it names the initializer used to generate the initial values; when it is a numpy.array, its shape must be [vocabulary_size, embedding_vec_size] and it is used directly as the initial value (see the sketch after this list).

  • comm_options (tf.distribute.experimental.CommunicationOptions = None) – see TensorFlow’s documentation for tf.distribute.experimental.CommunicationOptions.

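The initializer argument can also carry pre-trained values. A minimal sketch, assuming SparseOperationKit is importable as sparse_operation_kit and that strategy is a tf.distribute.Strategy as in the example below; the sizes and the random table here are illustrative only:

import numpy as np
from sparse_operation_kit.embeddings.tf_distributed_embedding import TFDistributedEmbedding

vocabulary_size, embedding_vec_size = 1024, 16

# Illustrative pre-trained table; the shape must be [vocabulary_size, embedding_vec_size].
pretrained_table = np.random.uniform(
    low=-0.05, high=0.05,
    size=(vocabulary_size, embedding_vec_size)).astype(np.float32)

with strategy.scope():
    embedding_layer = TFDistributedEmbedding(vocabulary_size=vocabulary_size,
                                             embedding_vec_size=embedding_vec_size,
                                             initializer=pretrained_table)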

strategy = ...

with strategy.scope():
    embedding_layer = TFDistributedEmbedding(vocabulary_size, embedding_vec_size,
                                             initializer='GlorotNormal',
                                             comm_options=None)

@tf.function
def _train_step(inputs, labels):
    emb_vectors = embedding_layer(inputs)
    ...

for i, (inputs, labels) in enumerate(dataset):
    strategy.run(_train_step, args=(inputs, labels))

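The strategy and dataset placeholders in the example above can be filled in several ways. A minimal sketch, assuming a single-machine multi-GPU setup with tf.distribute.MirroredStrategy and a toy tf.data pipeline; the shapes and sizes are illustrative only:

import tensorflow as tf

# One replica per visible GPU; the strategy also handles the collective
# communication used by TFDistributedEmbedding.
strategy = tf.distribute.MirroredStrategy()

# Toy data: integer lookup keys (tf.int64) and binary labels, batched and then
# distributed so that each replica receives its own shard of every batch.
keys = tf.random.uniform(shape=(1024, 10), maxval=100, dtype=tf.int64)
labels = tf.random.uniform(shape=(1024, 1), maxval=2, dtype=tf.int64)
dataset = tf.data.Dataset.from_tensor_slices((keys, labels)).batch(64)
dataset = strategy.experimental_distribute_dataset(dataset)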

Note: Currently, the variables created by this class cannot be correctly saved to files.


The forward logic of this wrapper class, executed when the layer is called on a batch of keys.

Parameters

  • inputs (tf.Tensor) – the lookup keys, stored in a tf.Tensor with dtype tf.int32 or tf.int64.

Returns

  replica_output – the embedding vectors on each replica, with dtype tf.float32.

Return type

  tf.Tensor
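A minimal check of this contract, assuming strategy and embedding_layer were built as in the example above; the key values are illustrative only:

import tensorflow as tf

@tf.function
def _forward(keys):
    # Lookup keys must use dtype tf.int32 or tf.int64.
    return embedding_layer(keys)

keys = tf.constant([[0, 1, 2], [3, 4, 5]], dtype=tf.int64)

# Run the forward pass on every replica. With a single GPU, strategy.run returns
# the float32 tf.Tensor itself; with several GPUs it returns a per-replica
# container holding one float32 tensor per replica.
replica_output = strategy.run(_forward, args=(keys,))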