# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Embedding ops."""

import tensorflow as tf

from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import resource_loader
from tensorflow.python.ops import resource_variable_ops

ops = tf.load_op_library(resource_loader.get_path_to_datafile('_embedding_lookup_ops.so'))


def read_var_no_copy(res_var):
  """Reads the value of a resource variable without copying the underlying buffer."""
  resource_variable_ops.variable_accessed(res_var)
  return ops.read_variable_no_copy(res_var.handle, dtype=res_var.dtype)


@tf.RegisterGradient("ReadVariableNoCopy")
def _read_grad(_, grad):
  """Gradient for read op(no copy)."""
  return grad
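
# Illustrative sketch (not part of the library): `read_var_no_copy` is intended as a
# drop-in replacement for reading a `tf.Variable` when the consumer only needs a
# read-only view, so the copy made by the standard `ReadVariableOp` can be skipped.
# Assuming the compiled `_embedding_lookup_ops.so` is available:
#
#   var = tf.Variable(tf.random.uniform([100, 8]))
#   value = read_var_no_copy(var)        # dense tensor view of the variable's buffer
#   rows = tf.gather(value, [0, 3, 7])   # usable like any other dense tensor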


def embedding_lookup(param, ids, combiner=None):
  """Looks up embeddings for the given `ids` from an embedding tensor.

  Args:
    param (Tensor): A single tensor representing the complete embedding tensor.
    ids (Tensor): A 2D `int32` or `int64` `Tensor` containing the ids to be looked up
        in `param`. `RaggedTensor` and `SparseTensor` are also supported.
    combiner (string or None): Reduction method, one of ['sum', 'mean'] or None. Default None.

  Returns:
    Tensor: A `Tensor` with the same type as the tensors in `param`.

  .. note::
      When combiner is None, the returned tensor has shape ``shape(ids) + shape(param)[1]``.
      Otherwise, embeddings from the same row are reduced and the returned tensor has shape
      ``shape(ids)[0] + shape(param)[1]``.

  Note that when `ids` is a `RaggedTensor`, its values and row_splits are the col_index and
  row_index of a CSR-format hotness matrix, so it can be constructed directly.

  Raises:
    TypeError: If `param` is not a tensor.
    ValueError: If `ids` is not a 2D tensor.
  """
  if not tf.is_tensor(param):
    raise TypeError("param must be Tensor")
  if ids.get_shape().ndims != 2:
    raise ValueError("Only support 2D input")

  if combiner is None:
    return tf.nn.embedding_lookup(param, ids)

  if isinstance(ids, ragged_tensor.RaggedTensor):
    # assuming no empty sample. tf.shape may fail on earlier TF versions with ragged input
    try:
      dim_0 = tf.shape(ids, out_type=tf.int32)[0] if ids.shape[0] is None else ids.shape[0]
    except:  # pylint: disable=bare-except
      dim_0 = tf.shape(ids.row_splits,
                       out_type=tf.int32)[0] - 1 if ids.shape[0] is None else ids.shape[0]
    num_input = tf.shape(
        ids.values, out_type=tf.int32)[0] if ids.values.shape[0] is None else ids.values.shape[0]
    if dim_0 == num_input:
      return tf.nn.embedding_lookup(param, ids.values)
    return ops.embedding_lookup_variable_hotness(read_var_no_copy(param), ids.values,
                                                 ids.row_splits, combiner)

  if isinstance(ids, tf.SparseTensor):
    # sparse input is ordered but may not be right-ragged, so we generate the offsets here.
    # avoid d2h copy in eager mode by using the SparseTensor's shape directly
    dim_0 = tf.shape(ids, out_type=tf.int32)[0] if ids.shape[0] is None else ids.shape[0]
    num_input = tf.shape(
        ids.values, out_type=tf.int32)[0] if ids.values.shape[0] is None else ids.values.shape[0]
    if dim_0 == num_input:
      return tf.nn.embedding_lookup(param, ids.values)
    # use custom op to avoid bad XLA behavior and d2h copy caused by searchsorted
    row_splits = ops.row_to_split(ids.indices, dim_0)
    # we really want ids.values and row_splits to have the same dtype to simplify things.
    # since max(row_splits) here is likely ~total hotness, int32 should be ok
    # TODO(Deyu): fuse this cast into the row_to_split op above and make it always int32
    return ops.embedding_lookup_variable_hotness(read_var_no_copy(param), ids.values,
                                                 tf.cast(row_splits, dtype=ids.values.dtype),
                                                 combiner)

  dim1 = tf.shape(ids, out_type=tf.int32)[1] if ids.shape[1] is None else ids.shape[1]
  if dim1 == 1:
    return tf.nn.embedding_lookup(param, tf.squeeze(ids, [1]))
  if combiner == 'sum':
    return tf.reduce_sum(tf.nn.embedding_lookup(param, ids), axis=1)
  return tf.reduce_mean(tf.nn.embedding_lookup(param, ids), axis=1)
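

# Illustrative usage sketch (not part of the library), assuming the compiled custom op
# library is available. With combiner=None each id maps to one embedding row; with a
# combiner, embeddings within a row of `ids` are reduced, and ragged or sparse inputs take
# the variable-hotness custom-op path:
#
#   params = tf.Variable(tf.random.uniform([100, 8]))
#   dense_ids = tf.constant([[1, 2], [3, 4]])
#   out = embedding_lookup(params, dense_ids)                      # shape [2, 2, 8]
#   ragged_ids = tf.ragged.constant([[1, 5, 9], [2]], dtype=tf.int64)
#   pooled = embedding_lookup(params, ragged_ids, combiner='sum')  # shape [2, 8]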
@tf.RegisterGradient("EmbeddingLookupVariableHotness") def _embedding_lookup_variable_hotness_grad(op, grad): """The gradients for `embedding_lookup_variable_hotness`. Args: op (object): The `embedding_lookup_variable_hotness` `Operation` that we are differentiating, which we can use to find the inputs and outputs of the original op. grad (Tensor): Gradient with respect to the output of `embedding_lookup_variable_hotness`. Returns: IndexedSlices: A `IndexedSlices` contain sparse gradients with respect to the embedding parameter of `embedding_lookup_variable_hotness`. """ param_shape = tf.shape(op.inputs[0]) flat_ids = tf.reshape(op.inputs[1], [-1]) offsets = op.inputs[2] unique_ids, unique_grad = ops.embedding_lookup_variable_hotness_grad( flat_ids, offsets, grad, op.inputs[0], combiner=op.get_attr('combiner')) return (tf.IndexedSlices(unique_grad, unique_ids, param_shape), None, None) def integer_lookup(table, count, keys, capacity): resource_variable_ops.variable_accessed(table) resource_variable_ops.variable_accessed(count) return ops.integer_lookup(table.handle, count.handle, keys, capacity, count.dtype)