#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from source code: https://github.com/karlhigley/ranking-metrics-torch
from abc import abstractmethod
from typing import List, Optional

import numpy as np
import tensorflow as tf
from merlin_standard_lib import Registry

from .utils import tf_utils
# Registry for the TF ranking metric classes in this module; the metric classes
# are decorated with ``ranking_metrics_registry.register_with_multiple_names(...)``,
# presumably so they can be looked up by any of their string aliases.
ranking_metrics_registry = Registry("tf.ranking_metrics")
# Shared numpydoc text describing the ``scores`` / ``labels`` parameters of the
# metric methods. NOTE(review): it is not substituted automatically — a docstring
# that contains the literal ``{METRIC_PARAMETERS_DOCSTRING}`` placeholder shows it
# verbatim unless something explicitly formats that docstring.
METRIC_PARAMETERS_DOCSTRING = """
    scores : tf.Tensor
        scores of predicted item-ids.
    labels : tf.Tensor
        true item-ids labels.
"""
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class RankingMetric(tf.keras.metrics.Metric):
    """
    Metric wrapper for computing ranking metrics@K for session-based task.

    Subclasses implement ``_metric``, which writes one value per example and
    cut-off into ``self.accumulator`` (shape ``[num_rows, len(top_ks)]``).
    ``update_state`` stores a snapshot of the accumulator for every batch, and
    ``result`` returns the per-cut-off mean over all stored rows.

    Parameters
    ----------
    top_ks : list, default [2, 5]
        list of cutoffs
    labels_onehot : bool
        Enable transform the encoded labels to one-hot representation
    """

    def __init__(
        self,
        name=None,
        dtype=None,
        top_ks: Optional[List[int]] = None,
        labels_onehot: bool = False,
        **kwargs,
    ):
        super(RankingMetric, self).__init__(name=name, dtype=dtype, **kwargs)
        # Subclasses forward ``top_ks=None`` by default, so resolve the documented
        # default here. Plain ints keep ``get_config`` serializable even when an
        # array or tensor of cut-offs is passed in.
        self.top_ks = [2, 5] if top_ks is None else [int(k) for k in top_ks]
        self.labels_onehot = labels_onehot
        # Per-batch snapshots of the accumulator (tracked by Keras in a ListWrapper).
        self.metric_mean: List[tf.Tensor] = []
        self.accumulator = tf.Variable(
            tf.zeros(shape=[1, len(self.top_ks)]),
            trainable=False,
            # Row dimension left unknown so ``_build`` can resize it every batch.
            shape=tf.TensorShape([None, len(self.top_ks)]),
        )

    def get_config(self):
        """Return the Keras config extended with ``top_ks`` and ``labels_onehot``."""
        config = {"top_ks": self.top_ks, "labels_onehot": self.labels_onehot}
        base_config = super(RankingMetric, self).get_config()
        return {**base_config, **config}

    def _build(self, shape):
        """Reset ``self.accumulator`` to a zero matrix with ``shape[0]`` rows."""
        self.accumulator.assign(tf.zeros([shape[0], len(self.top_ks)]))

    def update_state(self, y_true: tf.Tensor, y_pred: tf.Tensor, **kwargs):
        """
        Compute the metric at every cut-off for one batch and store the result.

        ``y_pred`` is flattened to ``[-1, num_items]``. Extra keyword arguments
        (e.g. ``sample_weight``) are accepted but ignored.
        """
        num_items = tf.shape(y_pred)[-1]
        scores = tf.reshape(y_pred, [-1, num_items])
        # Size the accumulator from the flattened scores so its rows line up
        # with the rows ``_metric`` writes into.
        self._build(shape=tf.shape(scores))
        if self.labels_onehot:
            y_true = tf_utils.tranform_label_to_onehot(y_true, num_items)
        self._metric(scores=scores, labels=y_true)
        # Store a snapshot: appending the variable itself would make every entry
        # alias the same (latest) values, so ``result`` would only see the last batch.
        self.metric_mean.append(tf.identity(self.accumulator))

    def result(self):
        """Return the mean metric for each cut-off (zeros before any update)."""
        if not self.metric_mean:
            return tf.zeros([len(self.top_ks)])
        return tf.reduce_mean(tf.concat(self.metric_mean, axis=0), axis=0)

    def reset_state(self):
        """Drop all stored batch results."""
        self.metric_mean = []

    @abstractmethod
    def _metric(self, scores: tf.Tensor, labels: tf.Tensor, **kwargs):
        """
        Update `self.accumulator` with the ranking metric of
        prediction scores and one-hot labels for different cut-offs `ks`.
        This method should be overridden by subclasses.

        Parameters
        ----------
        scores : tf.Tensor
            scores of predicted item-ids.
        labels : tf.Tensor
            true item-ids labels.
        """
        raise NotImplementedError

    def metric_fn(self, scores: tf.Tensor, labels: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Compute ranking metric over predictions and one-hot targets for different cut-offs.

        Unlike ``update_state``, nothing is stored in ``metric_mean``.

        Parameters
        ----------
        scores : tf.Tensor
            scores of predicted item-ids.
        labels : tf.Tensor
            true item-ids labels.

        Returns
        -------
        tf.Tensor
            ``[num_rows, len(top_ks)]`` metric values. This is a snapshot, so
            later calls on the same metric do not modify it.
        """
        scores = tf.reshape(scores, [-1, tf.shape(scores)[-1]])
        self._build(shape=tf.shape(scores))
        self._metric(scores=scores, labels=labels, **kwargs)
        return tf.identity(self.accumulator)
@ranking_metrics_registry.register_with_multiple_names("precision_at", "precision")
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class PrecisionAt(RankingMetric):
    """Precision@K: share of the top-K scored items that are relevant, per example."""

    def __init__(self, top_ks=None, labels_onehot=False, **kwargs):
        super(PrecisionAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot, **kwargs)

    def _metric(self, scores: tf.Tensor, labels: tf.Tensor, **kwargs) -> None:
        """
        Compute precision@K for each provided cutoff in ks and write it into
        ``self.accumulator`` (one row per example, one column per cutoff).

        Parameters
        ----------
        scores : tf.Tensor
            scores of predicted item-ids.
        labels : tf.Tensor
            true item-ids labels.
        """
        ks = tf.convert_to_tensor(self.top_ks)
        ks, scores, labels = check_inputs(ks, scores, labels)
        _, _, topk_labels = tf_utils.extract_topk(ks, scores, labels)
        bs = tf.shape(scores)[0]
        # Row ids do not depend on the cutoff, so build them once.
        rows_ids = tf.expand_dims(tf.range(bs, dtype=tf.int64), 1)
        for index, cutoff in enumerate(ks):
            k = int(cutoff)
            # Cast to the accumulator's float32 so non-float32 labels still scatter.
            hits = tf.cast(tf.reduce_sum(topk_labels[:, :k], axis=1), tf.float32)
            col_ids = tf.fill([bs, 1], tf.constant(index, dtype=tf.int64))
            self.accumulator.scatter_nd_update(
                indices=tf.concat([rows_ids, col_ids], axis=1),
                updates=hits / float(k),
            )
@ranking_metrics_registry.register_with_multiple_names("recall_at", "recall")
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class RecallAt(RankingMetric):
    """Recall@K: fraction of an example's relevant items found in its top-K."""

    def __init__(self, top_ks=None, labels_onehot=False, **kwargs):
        super(RecallAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot, **kwargs)

    def _metric(self, scores: tf.Tensor, labels: tf.Tensor, **kwargs) -> None:
        """
        Compute recall@K for each provided cutoff in ks and write it into
        ``self.accumulator``.

        Parameters
        ----------
        scores : tf.Tensor
            scores of predicted item-ids.
        labels : tf.Tensor
            true item-ids labels.
        """
        ks = tf.convert_to_tensor(self.top_ks)
        ks, scores, labels = check_inputs(ks, scores, labels)
        _, _, topk_labels = tf_utils.extract_topk(ks, scores, labels)
        # Recall is only defined for rows with at least one relevant item; the
        # other rows keep the zeros written by ``_build``.
        num_relevant = tf.reduce_sum(labels, axis=-1)
        rel_indices = tf.where(num_relevant != 0)  # [n_rel, 1] int64 row ids
        n_rel = tf.shape(rel_indices)[0]
        if n_rel == 0:
            return
        # Both operands cast to float32 (the accumulator dtype), e.g. for fp16 labels.
        rel_count = tf.cast(tf.gather_nd(num_relevant, rel_indices), tf.float32)
        rel_topk_labels = tf.cast(tf.gather_nd(topk_labels, rel_indices), tf.float32)
        for index, cutoff in enumerate(ks):
            recall_k = tf.reduce_sum(rel_topk_labels[:, : int(cutoff)], axis=-1) / rel_count
            col_ids = tf.fill([n_rel, 1], tf.constant(index, dtype=tf.int64))
            self.accumulator.scatter_nd_update(
                indices=tf.concat([rel_indices, col_ids], axis=1),
                updates=recall_k,
            )
@ranking_metrics_registry.register_with_multiple_names("avg_precision_at", "avg_precision", "map")
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class AvgPrecisionAt(RankingMetric):
    """AP@K: precision@i summed over relevant ranks i <= K, normalised by min(#relevant, K)."""

    def __init__(self, top_ks=None, labels_onehot=False, **kwargs):
        super(AvgPrecisionAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot, **kwargs)
        # Precision at every rank 1..max_k is needed. Plain Python ints avoid
        # turning an eager tensor back into an int inside ``range``.
        max_k = max(int(k) for k in self.top_ks)
        self.precision_at = PrecisionAt(top_ks=list(range(1, max_k + 1))).metric_fn

    def _metric(self, scores: tf.Tensor, labels: tf.Tensor, **kwargs) -> None:
        """
        Compute average precision @K for provided cutoff in ks and write it into
        ``self.accumulator``.

        Parameters
        ----------
        scores : tf.Tensor
            scores of predicted item-ids.
        labels : tf.Tensor
            true item-ids labels.
        """
        ks = tf.convert_to_tensor(self.top_ks)
        ks, scores, labels = check_inputs(ks, scores, labels)
        topk_scores, _, topk_labels = tf_utils.extract_topk(ks, scores, labels)
        topk_labels = tf.cast(topk_labels, tf.float32)
        num_relevant = tf.cast(tf.reduce_sum(labels, axis=-1), tf.float32)
        # Column i of the nested result is precision@(i + 1); multiplying by the
        # labels keeps only the ranks that hold a relevant item.
        rel_precisions = self.precision_at(topk_scores, topk_labels) * topk_labels
        bs = tf.shape(scores)[0]
        rows_ids = tf.expand_dims(tf.range(bs, dtype=tf.int64), 1)
        for index, cutoff in enumerate(ks):
            k = int(cutoff)
            total_prec = tf.reduce_sum(rel_precisions[:, :k], axis=1)
            # Normalise by min(num_relevant, k), floored at 1 so rows without a
            # relevant item yield 0 rather than 0/0.
            denom = tf.clip_by_value(num_relevant, clip_value_min=1.0, clip_value_max=float(k))
            col_ids = tf.fill([bs, 1], tf.constant(index, dtype=tf.int64))
            self.accumulator.scatter_nd_update(
                indices=tf.concat([rows_ids, col_ids], axis=1),
                updates=total_prec / denom,
            )
@ranking_metrics_registry.register_with_multiple_names("dcg_at", "dcg")
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class DCGAt(RankingMetric):
    """DCG@K: relevance of the top-K items, discounted by the log of their rank."""

    def __init__(self, top_ks=None, labels_onehot=False, **kwargs):
        super(DCGAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot, **kwargs)

    def _metric(
        self, scores: tf.Tensor, labels: tf.Tensor, log_base: int = 2, **kwargs
    ) -> None:
        """
        Compute discounted cumulative gain @K for each provided cutoff in ks
        (ignoring ties) and write it into ``self.accumulator``.

        Parameters
        ----------
        scores : tf.Tensor
            scores of predicted item-ids.
        labels : tf.Tensor
            true item-ids labels.
        log_base : int, default 2
            base of the logarithm used for the rank discount.
        """
        ks = tf.convert_to_tensor(self.top_ks)
        ks, scores, labels = check_inputs(ks, scores, labels)
        _, _, topk_labels = tf_utils.extract_topk(ks, scores, labels)
        # The accumulator is float32; cast once up front (covers fp16/int labels).
        topk_labels = tf.cast(topk_labels, tf.float32)

        # discounts[i] = 1 / log_base(i + 2): the first ranked item has weight 1.
        max_k = tf.reduce_max(ks)
        positions = tf.cast(tf.range(max_k), tf.float32)
        log_of_base = tf.math.log(tf.cast(log_base, tf.float32))
        discounts = 1.0 / (tf.math.log(positions + 2.0) / log_of_base)

        bs = tf.shape(scores)[0]
        rows_ids = tf.expand_dims(tf.range(bs, dtype=tf.int64), 1)
        for index, cutoff in enumerate(ks):
            k = int(cutoff)
            # ``discounts[:k]`` broadcasts across the batch dimension.
            dcg = tf.reduce_sum(topk_labels[:, :k] * discounts[:k], axis=1)
            col_ids = tf.fill([bs, 1], tf.constant(index, dtype=tf.int64))
            self.accumulator.scatter_nd_update(
                indices=tf.concat([rows_ids, col_ids], axis=1), updates=dcg
            )
@ranking_metrics_registry.register_with_multiple_names("ndcg_at", "ndcg")
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class NDCGAt(RankingMetric):
    """NDCG@K: DCG@K divided by the DCG@K of an ideal ranking of the labels."""

    def __init__(self, top_ks=None, labels_onehot=False, **kwargs):
        super(NDCGAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot, **kwargs)
        # Nested DCG metric that shares this metric's resolved cut-offs.
        self.dcg_at = DCGAt(top_ks=self.top_ks).metric_fn

    def _metric(
        self, scores: tf.Tensor, labels: tf.Tensor, log_base: int = 2, **kwargs
    ) -> None:
        """
        Compute normalized discounted cumulative gain @K for each provided cutoffs in ks
        (ignoring ties) and write it into ``self.accumulator``.

        Parameters
        ----------
        scores : tf.Tensor
            scores of predicted item-ids.
        labels : tf.Tensor
            true item-ids labels.
        log_base : int, default 2
            base of the logarithm used for the rank discount.
        """
        ks = tf.convert_to_tensor(self.top_ks)
        ks, scores, labels = check_inputs(ks, scores, labels)
        topk_scores, _, topk_labels = tf_utils.extract_topk(ks, scores, labels)
        # DCG of the predicted ranking. Snapshot it, because ``metric_fn`` may hand
        # back the nested metric's accumulator, which the next call overwrites.
        gains = tf.identity(
            self.dcg_at(labels=topk_labels, scores=topk_scores, log_base=log_base)
        )
        # Ideal DCG: rank *all* labels by their own relevance, so relevant items
        # the model left out of the top-K still lower the score.
        ideal_gains = tf.identity(
            self.dcg_at(labels=labels, scores=labels, log_base=log_base)
        )
        # divide_no_nan yields 0 where ideal_gains == 0 (rows without any relevant
        # item, assuming non-negative labels) instead of dividing by zero.
        self.accumulator.assign(tf.math.divide_no_nan(gains, ideal_gains))
def process_metrics(metrics, prefix=""):
    """
    Flatten metric results into a ``{name: value}`` dictionary.

    Metrics exposing a non-empty ``top_ks`` produce one entry per cut-off, keyed
    ``f"{prefix}{short_name}@{k}"``. Here ``short_name`` is ``metric.name`` up to its
    first underscore (``"recall_at"`` -> ``"recall"``, ``"avg_precision_at"`` ->
    ``"avg"``). Other metrics are stored whole under ``f"{prefix}{metric.name}"``.

    Parameters
    ----------
    metrics : iterable
        Objects with a ``name`` attribute and a ``result()`` method.
    prefix : str, default ""
        Prepended to every key, so results from different runs/modes don't collide.

    Returns
    -------
    dict
        Mapping from key to the (indexed) metric result.
    """
    metrics_proc = {}
    for metric in metrics:
        results = metric.result()
        top_ks = getattr(metric, "top_ks", None)
        # Explicit length check: ``top_ks`` may be a numpy array, whose truth
        # value is ambiguous and raises ValueError.
        if top_ks is not None and len(top_ks) > 0:
            short_name = metric.name.split("_")[0]
            for i, k in enumerate(top_ks):
                metrics_proc[f"{prefix}{short_name}@{k}"] = results[i]
        else:
            metrics_proc[f"{prefix}{metric.name}"] = results
    return metrics_proc