Source code for sparse_operation_kit.embeddings.tf_distributed_embedding

# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.ops import collective_ops
from tensorflow.python.framework import ops

class TFDistributedEmbedding(tf.keras.layers.Layer):
    """
    This Embedding layer will distribute embedding parameters to multiple GPUs.
    It leverages tf.distribute.Strategy to do the communication, so that
    tf.distribute.Strategy must be used.

    Parameters
    ----------
    vocabulary_size: integer
        the first dimension of variable whose shape is
        [vocabulary_size, embedding_vec_size].
    embedding_vec_size: integer
        the second dimension of variable whose shape is
        [vocabulary_size, embedding_vec_size].
    initializer: string, numpy.array = 'GlorotNormal'
        When it's string, it specifies the initializer used to generate
        initial values. When it's numpy.array, its shape must be
        [vocabulary_size, embedding_vec_size], and will be used as the
        initial value.
    comm_options: tf.distribute.experimental.CommunicationOptions = None
        see TF's docs
    **kwargs
        forwarded unchanged to ``tf.keras.layers.Layer.__init__``.

    Raises
    ------
    ValueError
        if a non-string ``initializer`` does not have shape
        [vocabulary_size, embedding_vec_size].
    RuntimeError
        if the layer is created outside of a tf.distribute.Strategy scope.

    Examples
    --------
    .. code-block:: python

        strategy = ...

        with strategy.scope():
            embedding_layer = TFDistributedEmbedding(vocabulary_size,
                                                     embedding_vec_size,
                                                     initializer)
            ...

        @tf.function
        def _train_step(inputs, labels):
            emb_vectors = embedding_layer(inputs)
            ...

        for i, (inputs, labels) in enumerate(dataset):
            strategy.run(_train_step, args=(inputs, labels))

    Notes
    -----
    Currently, the variables created by this class can not be correctly saved to files.
    """

    def __init__(
        self,
        vocabulary_size,
        embedding_vec_size,
        initializer="GlorotNormal",
        comm_options=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self._vocabulary_size = vocabulary_size
        self._embedding_vec_size = embedding_vec_size
        # Per-instance id; used below to derive a distinct collective
        # instance key for each layer's broadcast.
        self._uid = ops.uid()

        expected_shape = (self._vocabulary_size, self._embedding_vec_size)
        if isinstance(initializer, str):
            self._initial_value = tf.keras.initializers.get(initializer)(shape=expected_shape)
        else:
            if initializer.shape != expected_shape:
                raise ValueError(
                    "The shape of initializer must be [vocabulary_size, embedding_vec_size.]"
                )
            self._initial_value = initializer

        self._comm_options = comm_options

        # Fail before any variable is created when no strategy is active.
        if not tf.distribute.has_strategy():
            raise RuntimeError("This layer must be created under tf.distribute.Strategy.Scope().")

        self._embedding_weights = tf.Variable(
            initial_value=self._initial_value, dtype=tf.float32, name="EmbeddingWeights"
        )
        # NOTE: the weights are not broadcast across replicas here; see
        # ``broadcast_variables`` if replicas need to be synchronised.

    @property
    def embedding_weights(self):
        """The ``tf.Variable`` holding the embedding table of this layer."""
        return self._embedding_weights

    @tf.function
    def broadcast_variables(self):
        """
        Overwrite the embedding weights of every replica with those of replica 0.

        Must be called in a replica context. It is a no-op when only one
        replica is in sync.
        """
        replica_ctx = tf.distribute.get_replica_context()
        g_replica_id = replica_ctx.replica_id_in_sync_group
        num_replicas = replica_ctx.num_replicas_in_sync
        if num_replicas == 1:
            return

        variable = tf.identity(self._embedding_weights)
        # Sender and receivers must agree on group/instance keys; the
        # instance key is offset by the layer uid so layers do not collide.
        instance_key = 2 + self._uid
        if 0 == g_replica_id:
            values = collective_ops.broadcast_send(
                variable,
                variable.shape,
                variable.dtype,
                group_size=num_replicas,
                group_key=2,
                instance_key=instance_key,
                timeout=5,
            )
        else:
            values = collective_ops.broadcast_recv(
                variable.shape,
                variable.dtype,
                group_size=num_replicas,
                group_key=2,
                instance_key=instance_key,
                timeout=5,
            )
        self._embedding_weights.assign(values)

    def _condition(self, gathered_inputs, replica_ctx):
        """
        Build the boolean mask of keys that the current replica looks up.

        A key is selected on this replica when
        ``key % num_replicas_in_sync == replica_id_in_sync_group``.
        """
        global_replica_id = replica_ctx.replica_id_in_sync_group
        global_replica_id = tf.cast(global_replica_id, gathered_inputs.dtype)
        num_devices = replica_ctx.num_replicas_in_sync
        condition = gathered_inputs % num_devices == global_replica_id
        return condition

    def call(self, inputs):
        """
        The forward logic of this wrapper class.

        Parameters
        ----------
        inputs: tf.Tensor
            keys are stored in tf.Tensor with dtype tf.int32 or tf.int64

        Returns
        -------
        replica_output: tf.Tensor
            embedding vectors on each replica, with dtype tf.float32 and
            shape ``inputs.shape + [embedding_vec_size]``.

        Raises
        ------
        RuntimeError
            if called in cross-replica context.
        """
        if tf.distribute.in_cross_replica_context():
            raise RuntimeError(
                "The forward propagation of TFDistributedEmbedding "
                "cannot be called in cross_replica_context."
            )

        replica_ctx = tf.distribute.get_replica_context()
        global_replica_id = replica_ctx.replica_id_in_sync_group

        inputs_shape = tf.shape(inputs)
        replica_size = tf.size(inputs)
        replica_inputs = tf.reshape(inputs, [replica_size])
        replica_inputs = tf.identity(replica_inputs)

        # all-gather for each replica along batch dim
        gathered_inputs = replica_ctx.all_gather(
            value=replica_inputs, axis=0, options=self._comm_options
        )

        # select inputs for each replica
        condition = self._condition(gathered_inputs, replica_ctx)
        replica_indices = tf.where(condition)
        replica_selected_inputs = tf.gather_nd(gathered_inputs, replica_indices)

        # embedding lookup
        replica_vectors = tf.nn.embedding_lookup(
            params=self._embedding_weights, ids=replica_selected_inputs
        )

        # all-gather embedding vectors for each replica
        gathered_vectors = replica_ctx.all_gather(
            value=replica_vectors, axis=0, options=self._comm_options
        )
        gathered_indices = replica_ctx.all_gather(
            value=replica_indices, axis=0, options=self._comm_options
        )
        # tf.where on a 1-D mask yields [N, 1]; drop only that trailing axis
        # so a single gathered index stays a vector instead of a scalar.
        gathered_indices = tf.squeeze(gathered_indices, axis=-1)

        # reorder embedding vectors: position k of the argsort points at the
        # gathered vector that belongs to the k-th gathered key.
        sorted_gathered_indices = tf.argsort(gathered_indices)

        gathered_inputs_size = replica_ctx.all_gather(
            value=tf.expand_dims(replica_size, axis=0), axis=0, options=self._comm_options
        )
        # Always work on a 1-D vector of per-replica key counts.
        gathered_inputs_size = tf.reshape(gathered_inputs_size, [-1])

        # Offset of this replica's keys = number of keys of all lower-ranked
        # replicas (the sum over an empty slice is 0 for replica 0).
        begin = tf.math.reduce_sum(gathered_inputs_size[:global_replica_id])
        begin = tf.expand_dims(begin, axis=0)
        size = tf.slice(gathered_inputs_size, begin=[global_replica_id], size=[1])
        replica_output_indices = tf.slice(sorted_gathered_indices, begin=begin, size=size)

        # select replica's vectors
        replica_output = tf.gather(gathered_vectors, replica_output_indices)

        # Reshape with the dynamic shape so unknown (None) input dimensions
        # work, then re-attach whatever static shape information is known.
        output_shape = tf.concat([inputs_shape, [self._embedding_vec_size]], axis=0)
        replica_output = tf.reshape(replica_output, output_shape)
        replica_output.set_shape(inputs.get_shape().concatenate([self._embedding_vec_size]))
        return replica_output
if __name__ == "__main__":
    # Small smoke test: look up 40 keys in a 40 x 4 table on two GPUs and
    # print the per-replica outputs and gradients.
    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

    # Keys 0..39 arranged as 8 samples of 5 keys each.
    input_data = tf.reshape(tf.range(40), [8, 5])

    # Row i of the table is [4i, 4i+1, 4i+2, 4i+3], so outputs are easy to check.
    initial_value = tf.reshape(tf.range(40 * 4, dtype=tf.float32), shape=(40, 4))

    global_batch_size = 4

    strategy = tf.distribute.MirroredStrategy()

    def _dataset_fn(input_context):
        """Build the per-input-pipeline dataset of key batches."""
        replica_bs = input_context.get_per_replica_batch_size(global_batch_size)
        # NOTE(review): the dataset source was missing from this statement;
        # restored from `input_data`, which is otherwise unused — confirm.
        dataset = tf.data.Dataset.from_tensor_slices(input_data)
        dataset = dataset.batch(replica_bs)
        dataset = dataset.repeat(2)
        dataset = dataset.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
        return dataset

    dataset = strategy.distribute_datasets_from_function(_dataset_fn)

    with strategy.scope():
        embedding_layer = TFDistributedEmbedding(
            vocabulary_size=40, embedding_vec_size=4, initializer=initial_value
        )

    @tf.function
    def _step(inputs):
        """Run one forward pass and return the outputs with their gradients."""
        with tf.GradientTape() as tape:
            outputs = embedding_layer(inputs)
        grads = tape.gradient(outputs, embedding_layer._embedding_weights)
        return outputs, grads

    for step, inputs in enumerate(dataset):
        # The layer must run in replica context, hence strategy.run.
        outputs = strategy.run(_step, args=(inputs,))
        print(f"Iteration: {step}")

        for out in outputs:
            print(out)