"""
 Copyright (c) 2023, NVIDIA CORPORATION.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
         http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import tensorflow as tf
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.framework import dtypes
from tensorflow.nn import embedding_lookup
from hierarchical_parameter_server.core import lookup_ops


class SparseLookupLayer(tf.keras.layers.Layer):
    """
    Abbreviated as ``hps.SparseLookupLayer(*args, **kwargs)``.
    This is a wrapper class for the HPS sparse lookup layer, which performs
    the same function as ``tf.nn.embedding_lookup_sparse``. Note that ``ps_config_file``
    and ``global_batch_size`` should be specified in the constructor if you want
    to use implicit HPS initialization.

    Parameters
    ----------
    model_name: str
            The name of the model that has embedding tables.
    table_id: int
            The index of the embedding table for the model specified by
            model_name.
    emb_vec_size: int
            The embedding vector size for the embedding table specified
            by model_name and table_id.
    emb_vec_dtype:
            The data type of embedding vectors, which must be ``tf.float32``.
    ps_config_file: str
            The JSON configuration file for HPS initialization.
    global_batch_size: int
            The global batch size for HPS that is deployed on multiple GPUs.

    Examples
    --------
    .. code-block:: python

        import hierarchical_parameter_server as hps

        sparse_lookup_layer = hps.SparseLookupLayer(model_name=args.model_name,
                                                    table_id=args.table_id,
                                                    emb_vec_size=args.embed_vec_size,
                                                    emb_vec_dtype=tf.float32,
                                                    ps_config_file=args.ps_config_file,
                                                    global_batch_size=args.global_batch_size)

        @tf.function
        def _infer_step(inputs):
            embedding_vector = sparse_lookup_layer(sp_ids=inputs,
                                                   sp_weights=None,
                                                   combiner="mean")
            ...

        for i, (inputs, labels) in enumerate(dataset):
            _infer_step(inputs)
    """
    def __init__(
        self,
        model_name,
        table_id,
        emb_vec_size,
        emb_vec_dtype,
        ps_config_file="",
        global_batch_size=1,
        **kwargs,
    ):
        super(SparseLookupLayer, self).__init__(**kwargs)
        self.model_name = model_name
        self.table_id = table_id
        self.emb_vec_size = emb_vec_size
        self.emb_vec_dtype = emb_vec_dtype
        self.ps_config_file = ps_config_file
        self.global_batch_size = global_batch_size

    def call(self, sp_ids, sp_weights, name=None, combiner=None, max_norm=None):
        """
        Looks up embeddings for the given IDs and weights from the HPS embedding table.
        This op assumes that there is at least one ID for each row in the dense tensor
        represented by ``sp_ids`` (i.e., there are no rows with empty features), and that
        all the indices of ``sp_ids`` are in canonical row-major order. The ``sp_ids``
        and ``sp_weights`` (if not ``None``) are rank-2 ``SparseTensor`` objects.
        Embeddings are always aggregated along the last dimension.

        If an ID value cannot be found in the HPS, the default embedding vector,
        which can be specified in the HPS configuration JSON file, is retrieved.

        Parameters
        ----------
        sp_ids:
            N x M ``SparseTensor`` of ``int32`` or ``int64`` IDs where N is typically batch size
            and M is arbitrary.
        sp_weights:
            Either a ``SparseTensor`` of float or double weights, or ``None`` to
            indicate all weights should be taken to be `1`. If specified, ``sp_weights``
            must have exactly the same shape and indices as ``sp_ids``.
        combiner:
            A string that specifies the reduction op:

            ``"sum"``
                Computes the weighted sum of the embedding results for each row.
            ``"mean"``
                Computes the weighted sum divided by the total weight.
            ``"sqrtn"``
                Computes the weighted sum divided by the square root of the sum of the
                squares of the weights.

            The default value is ``"mean"``.
        max_norm:
            If not ``None``, each embedding is clipped if its L2 norm is larger
            than this value before combining.

        Returns
        -------
        emb_vector: ``tf.Tensor`` of float32
            A dense tensor representing the combined embeddings for the
            sparse IDs. For each row in the dense tensor represented by ``sp_ids``, the op
            looks up the embeddings for all IDs in that row, multiplies them by the
            corresponding weight, and combines these embeddings as specified.

            In other words, if

            .. code-block:: python

                shape(sp_ids) = shape(sp_weights) = [d0, d1]

            then

            .. code-block:: python

                shape(output) = [d0, self.emb_vec_size]

            For instance, if ``self.emb_vec_size`` is 16, and ``sp_ids`` / ``sp_weights`` are

            .. code-block:: python

                [0, 0]: id 1, weight 2.0
                [0, 1]: id 3, weight 0.5
                [1, 0]: id 0, weight 1.0
                [2, 3]: id 1, weight 3.0

            with ``combiner = "mean"``, then the output is a 3x16 matrix where

            .. code-block:: python

                output[0, :] = (vector_for_id_1 * 2.0 + vector_for_id_3 * 0.5) / (2.0 + 0.5)
                output[1, :] = (vector_for_id_0 * 1.0) / 1.0
                output[2, :] = (vector_for_id_1 * 3.0) / 3.0
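
            For reference, ``sp_ids`` and ``sp_weights`` matching this example could be
            constructed with the standard TensorFlow API as follows (a sketch; the dense
            shape ``[3, 4]`` is illustrative):

            .. code-block:: python

                import tensorflow as tf

                # Both tensors must share exactly the same indices and dense shape.
                indices = [[0, 0], [0, 1], [1, 0], [2, 3]]
                sp_ids = tf.sparse.SparseTensor(
                    indices=indices, values=[1, 3, 0, 1], dense_shape=[3, 4])
                sp_weights = tf.sparse.SparseTensor(
                    indices=indices, values=[2.0, 0.5, 1.0, 3.0], dense_shape=[3, 4])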

        Raises
        ------
            TypeError: If ``sp_ids`` is not a ``SparseTensor``, or if ``sp_weights`` is
                neither ``None`` nor ``SparseTensor``.
            ValueError: If ``combiner`` is not one of {``"mean"``, ``"sqrtn"``, ``"sum"``}.
        """
        # Validate the combiner and the sparse inputs
        if combiner is None:
            combiner = "mean"
        if combiner not in ("mean", "sqrtn", "sum"):
            raise ValueError(f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}")
        if not isinstance(sp_ids, sparse_tensor.SparseTensor):
            raise TypeError(f"sp_ids must be SparseTensor, got {type(sp_ids)}")
        ignore_weights = sp_weights is None
        if not ignore_weights:
            if not isinstance(sp_weights, sparse_tensor.SparseTensor):
                raise TypeError(
                    f"sp_weights must be either None or SparseTensor, got {type(sp_weights)}"
                )
            sp_ids.values.get_shape().assert_is_compatible_with(sp_weights.values.get_shape())
            sp_ids.indices.get_shape().assert_is_compatible_with(sp_weights.indices.get_shape())
            sp_ids.dense_shape.get_shape().assert_is_compatible_with(
                sp_weights.dense_shape.get_shape()
            )
            # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
            # sp_weights have equal indices and shapes.
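        # Flatten the sparse IDs: segment_ids maps each ID to its row in the dense
        # tensor, and unique() deduplicates IDs so each one is queried from HPS once;
        # idx maps every original ID back to its position in the deduplicated list.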
        segment_ids = sp_ids.indices[:, 0]
        ids = sp_ids.values
        ids, idx = array_ops.unique(ids)
        # Query HPS for embeddings
        embeddings = lookup_ops.lookup(
            ids=ids,
            model_name=self.model_name,
            table_id=self.table_id,
            emb_vec_size=self.emb_vec_size,
            emb_vec_dtype=self.emb_vec_dtype,
            ps_config_file=self.ps_config_file,
            global_batch_size=self.global_batch_size,
            max_norm=max_norm,
        )
        # Handle weights and combiner
        if not ignore_weights:
            if segment_ids.dtype != dtypes.int32:
                segment_ids = math_ops.cast(segment_ids, dtypes.int32)
            weights = sp_weights.values
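            # Expand the deduplicated embeddings back to one row per original ID.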
            embeddings = array_ops.gather(embeddings, idx)
            original_dtype = embeddings.dtype
            if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
                # Cast low-precision embeddings to float32 during the computation to
                # avoid numerical issues.
                embeddings = math_ops.cast(embeddings, dtypes.float32)
            if weights.dtype != embeddings.dtype:
                weights = math_ops.cast(weights, embeddings.dtype)
            # Reshape weights to allow broadcast
            ones_shape = array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0)
            ones = array_ops.ones(ones_shape, dtype=dtypes.int32)
            bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], 0)
            orig_weights_shape = weights.get_shape()
            weights = array_ops.reshape(weights, bcast_weights_shape)
            # Set the weight shape, since after reshaping to bcast_weights_shape,
            # the shape becomes None.
            if embeddings.get_shape().ndims is not None:
                weights.set_shape(
                    orig_weights_shape.concatenate(
                        [1 for _ in range(embeddings.get_shape().ndims - 1)]
                    )
                )
            embeddings *= weights
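            # Reduce the weighted embeddings to one vector per row of sp_ids:
            # "sum" keeps the weighted sum, "mean" divides by the total weight,
            # and "sqrtn" divides by sqrt(sum of squared weights).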
            if combiner == "sum":
                embeddings = math_ops.segment_sum(embeddings, segment_ids)
            elif combiner == "mean":
                embeddings = math_ops.segment_sum(embeddings, segment_ids)
                weight_sum = math_ops.segment_sum(weights, segment_ids)
                embeddings = math_ops.div_no_nan(embeddings, weight_sum)
            elif combiner == "sqrtn":
                embeddings = math_ops.segment_sum(embeddings, segment_ids)
                weights_squared = math_ops.pow(weights, 2)
                weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
                weight_sum_sqrt = math_ops.sqrt(weight_sum)
                embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt)
            else:
                assert False, "Unrecognized combiner"
            if embeddings.dtype != original_dtype:
                embeddings = math_ops.cast(embeddings, original_dtype)
        else:
            if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
                segment_ids = math_ops.cast(segment_ids, dtypes.int32)
            assert idx is not None
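            # Without weights, the fused sparse_segment_* ops gather the unique
            # embeddings through idx and reduce them per row in a single step.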
            if combiner == "sum":
                embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids)
            elif combiner == "mean":
                embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids)
            elif combiner == "sqrtn":
                embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx, segment_ids)
            else:
                assert False, "Unrecognized combiner"
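        # Restore the static output shape: one combined vector of size emb_vec_size
        # per row of sp_ids.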
        output_shape = [sp_ids.get_shape()[0], self.emb_vec_size]
        embeddings.set_shape(output_shape)
        return embeddings