#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Optional, Union
import tensorflow as tf
from merlin_standard_lib import Schema, Tag
from merlin_standard_lib.utils.doc_utils import docstring_parameter
from ..block.base import Block, SequentialBlock
from ..block.mlp import MLPBlock
from ..masking import MaskSequence, masking_registry
from ..tabular.base import (
    TABULAR_MODULE_PARAMS_DOCSTRING,
    AsTabular,
    TabularAggregationType,
    TabularBlock,
    TabularTransformationType,
)
from ..utils import tf_utils
from . import embedding
from .tabular import TABULAR_FEATURES_PARAMS_DOCSTRING, TabularFeatures


@docstring_parameter(
    tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING,
    embedding_features_parameters=embedding.EMBEDDING_FEATURES_PARAMS_DOCSTRING,
)
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class SequenceEmbeddingFeatures(embedding.EmbeddingFeatures):
    """Input block for embedding-lookups for categorical features. This module produces 3-D tensors,
    this is useful for sequential models like transformers.
    Parameters
    ----------
    {embedding_features_parameters}
    padding_idx: int
        The symbol to use for padding.
    {tabular_module_parameters}
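
    Examples
    --------
    A minimal sketch; the feature name, vocabulary size and dimensions below
    are assumptions made for the example, not part of the API:

    >>> import tensorflow as tf
    >>> import transformers4rec.tf as tr
    >>> table = tr.TableConfig(vocabulary_size=100, dim=8, name="item")
    >>> features = tr.SequenceEmbeddingFeatures(
    ...     feature_config=dict(item_id=tr.FeatureConfig(table)),
    ...     item_id="item_id",
    ... )
    >>> embeddings = features(dict(item_id=tf.constant([[1, 2, 0], [3, 0, 0]])))
    >>> embeddings["item_id"].shape  # 3-D: (batch size, sequence length, dim)
    TensorShape([2, 3, 8])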
    """
    def __init__(
        self,
        feature_config: Dict[str, embedding.FeatureConfig],
        item_id: Optional[str] = None,
        mask_zero: bool = True,
        padding_idx: int = 0,
        pre: Optional[TabularTransformationType] = None,
        post: Optional[TabularTransformationType] = None,
        aggregation: Optional[TabularAggregationType] = None,
        schema: Optional[Schema] = None,
        name: Optional[str] = None,
        embedding_tables: Optional[Dict] = None,  # None, not a shared mutable dict, as default
        **kwargs
    ):
        super().__init__(
            feature_config,
            item_id=item_id,
            pre=pre,
            post=post,
            aggregation=aggregation,
            schema=schema,
            name=name,
            embedding_tables=embedding_tables,
            **kwargs,
        )
        self.padding_idx = padding_idx
        self.mask_zero = mask_zero

    def lookup_feature(self, name, val, **kwargs):
        # Always request sequence output so each feature keeps its time dimension.
        return super(SequenceEmbeddingFeatures, self).lookup_feature(
            name, val, output_sequence=True
        )

    def compute_call_output_shape(self, input_shapes):
        batch_size = self.calculate_batch_size_from_input_shapes(input_shapes)
        sequence_length = input_shapes[list(self.feature_config.keys())[0]][1]

        output_shapes = {}
        for name in input_shapes:
            output_shapes[name] = tf.TensorShape(
                [batch_size, sequence_length, self.feature_config[name].table.dim]
            )

        return output_shapes

    def compute_mask(self, inputs, mask=None):
        if not self.mask_zero:
            return None

        outputs = {}
        for key, val in inputs.items():
            outputs[key] = tf.not_equal(val, self.padding_idx)

        return outputs
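
    # With the defaults (mask_zero=True, padding_idx=0), an input of
    # ``dict(item_id=[[1, 2, 0]])`` yields the mask ``dict(item_id=[[True, True, False]])``,
    # which Keras propagates to downstream layers.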

    def get_config(self):
        config = super().get_config()
        config["mask_zero"] = self.mask_zero
        config["padding_idx"] = self.padding_idx

        return config


@docstring_parameter(
    tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING,
    tabular_features_parameters=TABULAR_FEATURES_PARAMS_DOCSTRING,
)
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class TabularSequenceFeatures(TabularFeatures):
    """Input module that combines different types of features to a sequence: continuous,
    categorical & text.
    Parameters
    ----------
    {tabular_features_parameters}
    projection_module: BlockOrModule, optional
        Module that's used to project the output of this module, typically done by an MLPBlock.
    masking: MaskSequence, optional
         Masking to apply to the inputs.
    {tabular_module_parameters}
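
    Examples
    --------
    A minimal sketch; ``schema`` is assumed to be a ``merlin_standard_lib``
    ``Schema`` loaded elsewhere, the dimensions are illustrative, and the
    masking name "clm" (causal language modeling) is assumed to be registered:

    >>> import transformers4rec.tf as tr
    >>> inputs = tr.TabularSequenceFeatures.from_schema(
    ...     schema,
    ...     max_sequence_length=20,
    ...     d_output=64,
    ...     masking="clm",
    ... )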
    """
    EMBEDDING_MODULE_CLASS = SequenceEmbeddingFeatures
    def __init__(
        self,
        continuous_layer: Optional[TabularBlock] = None,
        categorical_layer: Optional[TabularBlock] = None,
        text_embedding_layer: Optional[TabularBlock] = None,
        projection_block: Optional[Block] = None,
        masking: Optional[MaskSequence] = None,
        pre: Optional[TabularTransformationType] = None,
        post: Optional[TabularTransformationType] = None,
        aggregation: Optional[TabularAggregationType] = None,
        name=None,
        **kwargs
    ):
        super().__init__(
            continuous_layer=continuous_layer,
            categorical_layer=categorical_layer,
            text_embedding_layer=text_embedding_layer,
            pre=pre,
            post=post,
            aggregation=aggregation,
            name=name,
            **kwargs,
        )
        self.projection_block = projection_block
        self.set_masking(masking)

    @classmethod
    def from_schema(  # type: ignore
        cls,
        schema: Schema,
        continuous_tags=(Tag.CONTINUOUS,),
        categorical_tags=(Tag.CATEGORICAL,),
        aggregation=None,
        max_sequence_length=None,
        continuous_projection=None,
        projection=None,
        d_output=None,
        masking=None,
        **kwargs
    ) -> "TabularSequenceFeatures":
        """Instantiates ``TabularFeatures`` from a ``DatasetSchema``
        Parameters
        ----------
        schema : DatasetSchema
            Dataset schema
        continuous_tags : Optional[Union[DefaultTags, list, str]], optional
            Tags to filter the continuous features, by default Tag.CONTINUOUS
        categorical_tags : Optional[Union[DefaultTags, list, str]], optional
            Tags to filter the categorical features, by default Tag.CATEGORICAL
        aggregation : Optional[str], optional
            Feature aggregation option, by default None
        automatic_build : bool, optional
            Automatically infers input size from features, by default True
        max_sequence_length : Optional[int], optional
            Maximum sequence length for list features by default None
        continuous_projection : Optional[Union[List[int], int]], optional
            If set, concatenate all numerical features and project them by a number of MLP layers
            The argument accepts a list with the dimensions of the MLP layers, by default None
        projection: Optional[torch.nn.Module, BuildableBlock], optional
            If set, project the aggregated embeddings vectors into hidden dimension vector space,
            by default None
        d_output: Optional[int], optional
            If set, init a MLPBlock as projection module to project embeddings vectors,
            by default None
        masking: Optional[Union[str, MaskSequence]], optional
            If set, Apply masking to the input embeddings and compute masked labels, It requires
            a categorical_module including an item_id column, by default None
        Returns
        -------
        TabularFeatures:
            Returns ``TabularFeatures`` from a dataset schema"""
        output = super().from_schema(
            schema=schema,
            continuous_tags=continuous_tags,
            categorical_tags=categorical_tags,
            aggregation=aggregation,
            max_sequence_length=max_sequence_length,
            continuous_projection=continuous_projection,
            **kwargs,
        )
        if d_output and projection:
            raise ValueError("You cannot specify both d_output and projection at the same time")
        if (projection or masking or d_output) and not aggregation:
            # TODO: print a warning here for clarity
            output.set_aggregation("concat")
        if d_output and not projection:
            projection = MLPBlock([d_output])
        if projection:
            output.projection_block = projection
        if isinstance(masking, str):
            masking = masking_registry.parse(masking)(**kwargs)
        if masking and not getattr(output, "item_id", None):
            raise ValueError(
                "For masking, a categorical module including an item_id column is required."
            )
        output.set_masking(masking)

        return output
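
    # Sketch of an equivalent call that passes an explicit projection block
    # instead of ``d_output`` (``tr`` stands for ``transformers4rec.tf``):
    #
    #     inputs = tr.TabularSequenceFeatures.from_schema(
    #         schema, aggregation="concat", projection=tr.MLPBlock([128, 64])
    #     )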

    def project_continuous_features(self, dimensions):
        if isinstance(dimensions, int):
            dimensions = [dimensions]

        continuous = self.continuous_layer
        continuous.set_aggregation("concat")
        continuous = SequentialBlock(
            [continuous, MLPBlock(dimensions), AsTabular("continuous_projection")]
        )
        self.to_merge["continuous_layer"] = continuous

        return self
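
    # For example, ``inputs.project_continuous_features([64])`` concatenates all
    # continuous features, projects them with ``MLPBlock([64])`` and re-registers
    # the result under the key "continuous_projection".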

    def call(self, inputs, training=True):
        outputs = super(TabularSequenceFeatures, self).call(inputs)

        if self.masking or self.projection_block:
            outputs = self.aggregation(outputs)

        if self.projection_block:
            outputs = self.projection_block(outputs)

        if self.masking:
            outputs = self.masking(
                outputs, item_ids=self.to_merge["categorical_layer"].item_seq, training=training
            )

        return outputs
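
    # The forward pass thus flows: per-feature dict -> (aggregation) ->
    # dense sequence tensor -> (projection_block) -> (masking, which uses the
    # item-id sequence to compute masked labels during training).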

    def compute_call_output_shape(self, input_shape):
        output_shapes = {}
        for layer in self.merge_values:
            output_shapes.update(layer.compute_output_shape(input_shape))

        return output_shapes

    def compute_output_shape(self, input_shapes):
        output_shapes = super().compute_output_shape(input_shapes)
        if self.projection_block:
            output_shapes = self.projection_block.compute_output_shape(output_shapes)

        return output_shapes

    @property
    def masking(self):
        return self._masking

    def set_masking(self, value):
        self._masking = value

    @property
    def item_id(self) -> Optional[str]:
        if "categorical_layer" in self.to_merge_dict:
            return getattr(self.to_merge_dict["categorical_layer"], "item_id", None)

        return None

    @property
    def item_embedding_table(self):
        if "categorical_layer" in self.to_merge_dict:
            return getattr(self.to_merge_dict["categorical_layer"], "item_embedding_table", None)

        return None

    def get_config(self):
        config = super().get_config()
        config = tf_utils.maybe_serialize_keras_objects(
            self, config, ["projection_block", "masking"]
        )

        return config

    @classmethod
    def from_config(cls, config, **kwargs):
        config = tf_utils.maybe_deserialize_keras_objects(config, ["projection_block", "masking"])

        return super().from_config(config)
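
    # Serialization round-trip sketch (illustrative): ``get_config``/``from_config``
    # preserve the projection block and the masking configuration.
    #
    #     config = inputs.get_config()
    #     restored = TabularSequenceFeatures.from_config(config)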


TabularFeaturesType = Union[TabularSequenceFeatures, TabularFeatures]