#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Union
import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer
from merlin.models.tf.core.base import Block, MetricOrMetrics
from merlin.models.tf.prediction_tasks.base import PredictionTask
from merlin.models.tf.utils.tf_utils import (
    maybe_deserialize_keras_objects,
    maybe_serialize_keras_objects,
)
from merlin.models.utils.schema_utils import categorical_cardinalities
from merlin.schema import Schema, Tags
[docs]@tf.keras.utils.register_keras_serializable(package="merlin.models")
class BinaryClassificationTask(PredictionTask):
    """
    Prediction task for binary classification.
    Parameters
    ----------
    target: Union[str, Schema], optional
        The name of the target. If a Schema is provided, the target is inferred from the schema.
    task_name: str, optional
        The name of the task.
    task_block: Block, optional
        The block to use for the task.
    """
    # Default loss to use
    DEFAULT_LOSS = "binary_crossentropy"
    # Default metrics to use
    DEFAULT_METRICS = (
        tf.keras.metrics.Precision,
        tf.keras.metrics.Recall,
        tf.keras.metrics.BinaryAccuracy,
        tf.keras.metrics.AUC,
    )
[docs]    def __init__(
        self,
        target: Optional[Union[str, Schema]] = None,
        task_name: Optional[str] = None,
        task_block: Optional[Layer] = None,
        **kwargs,
    ):
        if isinstance(target, Schema):
            target_name = target.select_by_tag(Tags.BINARY_CLASSIFICATION)
            if not target_name.column_names:
                raise ValueError(
                    "Binary classification task requires a column with a ",
                    "`Tags.BINARY_CLASSIFICATION` tag.",
                )
            elif len(target_name.column_names) > 1:
                raise ValueError(
                    "Binary classification task requires a single column with a ",
                    "`Tags.BINARY_CLASSIFICATION` tag. ",
                    "Found {} columns. ".format(len(target_name.column_names)),
                    "Please specify the column name with the `target` argument.",
                )
            target_name = target_name.column_names[0]
        else:
            target_name = target if target else kwargs.pop("target_name", None)
        output_layer = kwargs.pop("output_layer", None)
        super().__init__(
            target_name=target_name,
            task_name=task_name,
            task_block=task_block,
            **kwargs,
        )
        self.output_layer = output_layer or tf.keras.layers.Dense(
            1, activation="linear", name="output_layer"
        )
        # To ensure that the output is always fp32, avoiding numerical
        # instabilities with mixed_float16 (fp16) policy
        self.output_activation = tf.keras.layers.Activation(
            "sigmoid", dtype="float32", name="prediction"
        ) 
[docs]    def call(self, inputs, training=False, **kwargs):
        return self.output_activation(self.output_layer(inputs)) 
[docs]    def compute_output_shape(self, input_shape):
        return self.output_layer.compute_output_shape(input_shape) 
[docs]    def get_config(self):
        config = super().get_config()
        config = maybe_serialize_keras_objects(
            self,
            config,
            {"output_layer": tf.keras.layers.serialize},
        )
        return config 
[docs]    @classmethod
    def from_config(cls, config):
        config = maybe_deserialize_keras_objects(
            config, ["output_layer"], tf.keras.layers.deserialize
        )
        return super().from_config(config)  
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class CategFeaturePrediction(Block):
    """Block that predicts a categorical feature. num_classes is inferred from the"""
    def __init__(
        self,
        schema: Schema,
        feature_name: Optional[str] = None,
        bias_initializer="zeros",
        kernel_initializer="random_normal",
        activation=None,
        **kwargs,
    ):
        super(CategFeaturePrediction, self).__init__(**kwargs)
        self.bias_initializer = bias_initializer
        self.kernel_initializer = kernel_initializer
        self.feature_name = feature_name or schema.select_by_tag(Tags.ITEM_ID).column_names[0]
        self.num_classes = categorical_cardinalities(schema)[self.feature_name]
        self.activation = activation
        # To ensure that the output is always fp32, avoiding numerical
        # instabilities with mixed_float16 policy
        self.output_activation = tf.keras.layers.Activation(
            activation, dtype="float32", name="predictions"
        )
    def build(self, input_shape):
        self.output_layer = Dense(
            units=self.num_classes,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            name=f"{self.feature_name}-prediction",
            activation="linear",
        )
        return super().build(input_shape)
    def call(self, inputs, training=False, **kwargs) -> tf.Tensor:
        return self.output_activation(self.output_layer(inputs))
    def compute_output_shape(self, input_shape):
        return input_shape[:-1] + (self.num_classes,)
[docs]@tf.keras.utils.register_keras_serializable(package="merlin.models")
class MultiClassClassificationTask(PredictionTask):
    """
    Prediction task for multi-class classification.
    Parameters
    ----------
    target_name : Optional[str], optional
        Label name, by default None
    task_name: str, optional
        The name of the task.
    task_block: Block, optional
        The block to use for the task.
    """
    DEFAULT_LOSS = "categorical_crossentropy"
    DEFAULT_METRICS: MetricOrMetrics = (tf.keras.metrics.Accuracy,)
[docs]    def __init__(
        self,
        target_name: Optional[str] = None,
        task_name: Optional[str] = None,
        task_block: Optional[Layer] = None,
        pre: Optional[Block] = None,
        **kwargs,
    ):
        super().__init__(
            target_name=target_name,
            task_name=task_name,
            task_block=task_block,
            pre=pre,
            **kwargs,
        ) 
[docs]    @classmethod
    def from_schema(
        cls,
        schema: Schema,
        feature_name: str = Tags.ITEM_ID,
        bias_initializer="zeros",
        kernel_initializer="random_normal",
        extra_pre: Optional[Block] = None,
        **kwargs,
    ) -> "MultiClassClassificationTask":
        """Create from Schema."""
        pre = CategFeaturePrediction(
            schema,
            feature_name,
            bias_initializer=bias_initializer,
            kernel_initializer=kernel_initializer,
        )
        if extra_pre:
            pre = pre.connect(extra_pre)
        return cls(
            pre=pre,
            **kwargs,
        ) 
[docs]    def call(self, inputs, training=False, **kwargs):
        return inputs