Source code for sparse_operation_kit.embeddings.distributed_embedding

#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from sparse_operation_kit.core import EmbeddingVariable
from sparse_operation_kit.core import SparseEmbeddingLayerHandle
from sparse_operation_kit.embeddings import embedding_ops
import tensorflow as tf


class DistributedEmbedding(tf.keras.layers.Layer):
    """
    Abbreviated as ``sok.DistributedEmbedding(*args, **kwargs)``.

    This is a wrapper class for a distributed sparse embedding layer.
    It can be used to create a sparse embedding layer which distributes
    keys to each GPU based on `gpu_id = key % gpu_num`.

    Parameters
    ----------
    combiner: string
            it is used to specify how to combine embedding vectors within each slot.
            Can be `Mean` or `Sum`.
    max_vocabulary_size_per_gpu: integer
            the first dimension of the embedding variable whose shape is
            [max_vocabulary_size_per_gpu, embedding_vec_size].
    embedding_vec_size: integer
            the second dimension of the embedding variable whose shape is
            [max_vocabulary_size_per_gpu, embedding_vec_size].
    slot_num: integer
            the number of feature-fields which will be processed at the same time in
            each iteration, where all feature-fields produce embedding vectors
            of the same dimension.
    max_nnz: integer
            the maximum number of valid keys in each slot (feature-field).
    max_feature_num: integer = slot_num * max_nnz
            the maximum number of valid keys in each sample. It can be used to
            save GPU memory when this statistic is known. By default, it is equal
            to :math:`max\_feature\_num=slot\_num*max\_nnz`.
    use_hashtable: boolean = True
            whether to use `Hashtable` in ``EmbeddingVariable``. If `True`, a
            Hashtable will be created for dynamic insertion. Otherwise, the input keys
            will be used as the indices for the embedding vector lookup, so that input keys
            must be in the range ``[0, max_vocabulary_size_per_gpu * gpu_num)``.
    key_dtype: tf.dtypes = tf.int64
            the data type of input keys. By default, it is `tf.int64`.
    embedding_initializer: string or an instance of `tf.keras.initializers.Initializer`
            the initializer used to generate the initial value for the embedding variable.
            By default, it will use `random_uniform` where ``minval=-0.05, maxval=0.05``.

    Examples
    --------
    .. code-block:: python

        initializer = tf.keras.initializers.RandomUniform()  # or "random_uniform"

        emb_layer = sok.DistributedEmbedding(combiner, max_vocabulary_size_per_gpu,
                                             embedding_vec_size, slot_num, max_nnz,
                                             embedding_initializer=initializer)

        @tf.function
        def _train_step(inputs, labels):
            emb_vectors = emb_layer(inputs)
            ...

        for i, (inputs, labels) in enumerate(dataset):
            _train_step(inputs, labels)
    """

    def __init__(
        self,
        combiner,
        max_vocabulary_size_per_gpu,
        embedding_vec_size,
        slot_num,
        max_nnz,
        max_feature_num=1,
        use_hashtable=True,
        key_dtype=None,
        embedding_initializer=None,
        **kwargs
    ):
        super(DistributedEmbedding, self).__init__(**kwargs)

        self.combiner = combiner
        self.max_vocabulary_size_per_gpu = max_vocabulary_size_per_gpu
        self.embedding_vec_size = embedding_vec_size
        self.slot_num = slot_num
        self.max_nnz = max_nnz
        self.max_feature_num = max_feature_num

        if self._dtype_policy.variable_dtype is None:
            # in TF1 and the policy is not set,
            # therefore the variable dtype and compute dtype should be fp32
            from tensorflow.python.keras.mixed_precision import experimental as mixed_precision

            self._dtype_policy = mixed_precision.Policy("float32")

        self.var = EmbeddingVariable.CreateInstances(
            shape=[self.max_vocabulary_size_per_gpu, self.embedding_vec_size],
            trainable=True,
            use_hashtable=use_hashtable,
            dtype=self._dtype_policy.variable_dtype,
            key_dtype=key_dtype,
            initializer=embedding_initializer,
        )

        self.emb_layer = SparseEmbeddingLayerHandle(
            self.var,
            input_dispatcher="all_gather_dispatcher",
            input_dispatcher_subsequent_ops=["csr_conversion_distributed"],
            embedding_executor="distributed",
            output_dispatcher="reduce_scatter_dispatcher",
            slot_num=self.slot_num,
            max_nnz=self.max_nnz,
            max_feature_num=self.max_feature_num,
            combiner=self.combiner,
            compute_dtype=self._dtype_policy.compute_dtype,
        )

    @property
    def embedding_variable(self):
        return self.var

    def get_config(self):
        config = super(DistributedEmbedding, self).get_config()
        config.update({})
        return config

    def build(self, input_shape):
        pass

    # @tf.function
    def call(self, inputs, training=True):
        """
        The forward logic of this wrapper class.

        Parameters
        ----------
        inputs: tf.sparse.SparseTensor
                keys are stored in SparseTensor.values. SparseTensor.dense_shape is 2-dim
                and denotes [batchsize * slot_num, max_nnz]. Therefore, the rank of
                SparseTensor.indices must be 2, where each entry denotes
                [row-index, column-index] in the corresponding dense tensor.
        training: boolean
                whether it is used in training or not.

        Returns
        -------
        emb_vector: tf.float
                the embedding vectors for the input keys. Its shape is
                *[batchsize, slot_num, embedding_vec_size]*
        """
        emb_vector = embedding_ops.embedding_lookup_sparse(
            embedding_variable=self.var, sp_ids=inputs, slot_num=self.slot_num, training=training
        )
        return emb_vector
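

# The call() docstring above fixes the expected input layout: a 2-D SparseTensor
# whose dense_shape is [batchsize * slot_num, max_nnz]. The block below is a
# minimal, single-GPU sketch of building such an input and running a lookup.
# The sizes, key values, and the `sok.Init(global_batch_size=...)` call are
# illustrative assumptions, not part of this module.
if __name__ == "__main__":
    import sparse_operation_kit as sok

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        # SOK is assumed to be initialized inside the strategy scope before any
        # embedding layer is created.
        sok.Init(global_batch_size=2)
        emb_layer = sok.DistributedEmbedding(
            combiner="sum",
            max_vocabulary_size_per_gpu=1024,
            embedding_vec_size=4,
            slot_num=3,
            max_nnz=2,
        )

    # A batch of 2 samples, each with 3 slots and at most 2 valid keys per slot,
    # so dense_shape = [batchsize * slot_num, max_nnz] = [6, 2].
    inputs = tf.sparse.SparseTensor(
        indices=[[0, 0], [1, 0], [1, 1], [2, 0], [3, 0], [4, 0], [5, 0], [5, 1]],
        values=tf.constant([7, 12, 45, 3, 9, 21, 30, 8], dtype=tf.int64),
        dense_shape=[6, 2],
    )

    @tf.function
    def _lookup(sp_ids):
        return emb_layer(sp_ids, training=False)

    # emb_vectors has shape [batchsize, slot_num, embedding_vec_size] = [2, 3, 4].
    emb_vectors = _lookup(inputs)
    print(emb_vectors)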