Source code for merlin.models.tf.metrics.ranking

#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Adapted from source code: https://github.com/karlhigley/ranking-metrics-torch
from typing import List, Optional, Sequence

import tensorflow as tf
from keras.utils import losses_utils, metrics_utils
from keras.utils.tf_utils import is_tensor_or_variable
from tensorflow.keras import backend
from tensorflow.keras.metrics import Mean
from tensorflow.keras.metrics import get as get_metric

from merlin.models.tf.metrics import metrics_registry
from merlin.models.tf.utils.tf_utils import extract_topk

METRIC_PARAMETERS_DOCSTRING = """
    y_true : tf.Tensor
        A tensor with shape (batch_size, n_items) corresponding to
        the multi-hot representation of true labels.
    y_pred : tf.Tensor
        A tensor with shape (batch_size, n_items) corresponding to
        the prediction scores.
    label_relevant_counts: tf.Tensor
        A 1D tensor (batch size) which contains the total number of relevant items
        for the example. This is necessary when setting `pre_sorting=True` on the
        ranking metrics classes (e.g. RecallAt(5, pre_sorted=True)), as extract_topk
        is used to extract only the top-k predictions and corresponding labels,
        potentially losing other relevant items not among top-k. In such cases,
        the label_relevant_counts will contain the total relevant counts per example.
    k : int
        The cut-off for ranking metrics
"""


def recall_at(
    y_true: tf.Tensor,
    y_pred: tf.Tensor,
    label_relevant_counts: Optional[tf.Tensor] = None,
    k: int = 5,
) -> tf.Tensor:
    """
    Computes Recall@K metric
    ----------
    {METRIC_PARAMETERS_DOCSTRING}
    """
    rel_count = tf.clip_by_value(label_relevant_counts, clip_value_min=1, clip_value_max=float(k))

    rel_labels = tf.reduce_sum(y_true[:, :k], axis=-1)
    results = tf.cast(
        tf.math.divide_no_nan(rel_labels, rel_count),
        backend.floatx(),
    )
    return results


def precision_at(
    y_true: tf.Tensor,
    y_pred: tf.Tensor,
    label_relevant_counts: Optional[tf.Tensor] = None,
    k: int = 5,
) -> tf.Tensor:
    """
    Computes Precision@K metric
    Parameters
    ----------
    {METRIC_PARAMETERS_DOCSTRING}
    """
    results = tf.cast(tf.reduce_mean(y_true[:, : int(k)], axis=-1), backend.floatx())
    return results


def average_precision_at(
    y_true: tf.Tensor,
    y_pred: tf.Tensor,
    label_relevant_counts: Optional[tf.Tensor] = None,
    k: int = 5,
) -> tf.Tensor:
    """
    Computes Mean Average Precision (MAP) @K
    Parameters
    ----------
    {METRIC_PARAMETERS_DOCSTRING}
    """
    # Computing the precision from 1 to k range
    precisions = tf.stack([precision_at(y_true, y_pred, k=_k) for _k in range(1, k + 1)], axis=-1)
    # Keeping only the precision at the position of relevant items
    rel_precisions = precisions * y_true[:, :k]

    total_prec = tf.reduce_sum(rel_precisions, axis=-1)
    total_relevant_topk = tf.clip_by_value(
        label_relevant_counts, clip_value_min=1, clip_value_max=float(k)
    )

    results = tf.cast(tf.math.divide_no_nan(total_prec, total_relevant_topk), backend.floatx())

    return results


def dcg_at(
    y_true: tf.Tensor,
    y_pred: tf.Tensor,
    label_relevant_counts: Optional[tf.Tensor] = None,
    k: int = 5,
    log_base: int = 2,
) -> tf.Tensor:

    """
    Compute discounted cumulative gain @K (ignoring ties)
    Parameters
    ----------
    {METRIC_PARAMETERS_DOCSTRING}
    """

    # Compute discounts
    discount_positions = tf.cast(tf.range(k), backend.floatx())
    discount_log_base = tf.math.log(tf.convert_to_tensor([log_base], dtype=backend.floatx()))

    discounts = 1 / (tf.math.log(discount_positions + 2) / discount_log_base)
    m = y_true[:, :k] * tf.repeat(tf.expand_dims(discounts[:k], 0), tf.shape(y_true)[0], axis=0)

    results = tf.cast(tf.reduce_sum(m, axis=-1), backend.floatx())
    return results


def ndcg_at(
    y_true: tf.Tensor,
    y_pred: tf.Tensor,
    label_relevant_counts: Optional[tf.Tensor] = None,
    k: int = 5,
    log_base: int = 2,
) -> tf.Tensor:
    """
    Compute normalized discounted cumulative gain @K (ignoring ties)
    Parameters
    ----------
    {METRIC_PARAMETERS_DOCSTRING}
    log_base : int
        Base of the log discount where relevant items are rankied. Defaults to 2
    """
    gains = dcg_at(y_true, y_pred, k=k, log_base=log_base)
    perfect_labels_sorting = tf.cast(
        tf.cast(tf.expand_dims(tf.range(k), 0), label_relevant_counts.dtype)  # type: ignore
        < tf.expand_dims(label_relevant_counts, -1),
        backend.floatx(),
    )
    ideal_gains = dcg_at(perfect_labels_sorting, perfect_labels_sorting, k=k, log_base=log_base)

    results = tf.cast(tf.math.divide_no_nan(gains, ideal_gains), backend.floatx())
    return results


def mrr_at(
    y_true: tf.Tensor,
    y_pred: tf.Tensor,
    label_relevant_counts: Optional[tf.Tensor] = None,
    k: int = 5,
) -> tf.Tensor:
    """
    Compute MRR
    ----------
    {METRIC_PARAMETERS_DOCSTRING}
    """

    first_rel_position = tf.cast(tf.argmax(y_true, axis=-1) + 1, backend.floatx())
    relevant_mask = tf.reduce_max(y_true[:, : int(k)], axis=-1)

    rel_position = first_rel_position * relevant_mask
    results = tf.cast(tf.math.divide_no_nan(1.0, rel_position), backend.floatx())

    return results


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class RankingMetric(Mean):
    def __init__(self, fn, k=5, pre_sorted=True, name=None, dtype=None, **kwargs):
        if name is not None:
            name = f"{name}_{k}"
        super().__init__(name=name, dtype=dtype)
        self._fn = fn
        self.k = k
        self.pre_sorted = pre_sorted
        self._fn_kwargs = kwargs

    def update_state(
        self,
        y_true: tf.Tensor,
        y_pred: tf.Tensor,
        label_relevant_counts: Optional[tf.Tensor] = None,
        sample_weight: Optional[tf.Tensor] = None,
    ):
        y_true, y_pred = self.check_cast_inputs(y_true, y_pred)
        (
            [
                y_true,
                y_pred,
            ],
            sample_weight,
        ) = metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [y_true, y_pred], sample_weight
        )
        y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true)

        tf.debugging.assert_greater_equal(
            tf.shape(y_true)[1],
            self.k,
            f"The ranking metric ({self.name}) cutoff ({self.k}) cannot be smaller than "
            f"the number of predictions per example",
        )

        y_pred, y_true, label_relevant_counts = self._maybe_sort_top_k(
            y_pred, y_true, label_relevant_counts
        )

        ag_fn = tf.__internal__.autograph.tf_convert(
            self._fn, tf.__internal__.autograph.control_status_ctx()
        )

        matches = ag_fn(
            y_true,
            y_pred,
            label_relevant_counts=label_relevant_counts,
            k=self.k,
            **self._fn_kwargs,
        )
        return super().update_state(matches, sample_weight=sample_weight)

    def _maybe_sort_top_k(self, y_pred, y_true, label_relevant_counts: tf.Tensor = None):
        if not self.pre_sorted:
            y_pred, y_true, label_relevant_counts = extract_topk(self.k, y_pred, y_true)
        else:
            if label_relevant_counts is None:
                raise Exception(
                    "If y_true was pre-sorted (and truncated to top-k) you must "
                    "provide label_relevant_counts argument."
                )
            label_relevant_counts = tf.cast(label_relevant_counts, self._dtype)

        return y_pred, y_true, label_relevant_counts

    def check_cast_inputs(self, labels, predictions):
        tf.assert_equal(
            tf.rank(predictions), 2, f"predictions must be 2-D tensor (got {predictions.shape})"
        )
        tf.assert_equal(tf.rank(labels), 2, f"labels must be 2-D tensor (got {labels.shape})")

        predictions.get_shape().assert_is_compatible_with(labels.get_shape())

        return tf.cast(labels, self._dtype), tf.cast(predictions, self._dtype)

    def get_config(self):
        config = {}

        if type(self) is RankingMetric:
            # Only include function argument when the object is of a subclass.
            config["fn"] = self._fn
            config["k"] = self.k
            config["pre_sorted"] = self.pre_sorted

            for k, v in self._fn_kwargs.items():
                config[k] = backend.eval(v) if is_tensor_or_variable(v) else v
            base_config = super(RankingMetric, self).get_config()
            return dict(list(base_config.items()) + list(config.items()))

        return {}

    @classmethod
    def from_config(cls, config):
        fn = config.pop("fn", None)
        k = config.pop("k", None)
        pre_sorted = config.pop("pre_sorted", None)
        if cls is RankingMetric:
            return cls(get_metric(fn), k=k, pre_sorted=pre_sorted, **config)
        return super(RankingMetric, cls).from_config(config)


[docs]@metrics_registry.register_with_multiple_names("recall_at", "recall") class RecallAt(RankingMetric):
[docs] def __init__(self, k=10, pre_sorted=False, name="recall_at"): super().__init__(recall_at, k=k, pre_sorted=pre_sorted, name=name)
@metrics_registry.register_with_multiple_names("precision_at", "precision") class PrecisionAt(RankingMetric): def __init__(self, k=10, pre_sorted=False, name="precision_at"): super().__init__(precision_at, k=k, pre_sorted=pre_sorted, name=name)
[docs]@metrics_registry.register_with_multiple_names("map_at", "map") class AvgPrecisionAt(RankingMetric):
[docs] def __init__(self, k=10, pre_sorted=False, name="map_at"): super().__init__(average_precision_at, k=k, pre_sorted=pre_sorted, name=name)
@metrics_registry.register_with_multiple_names("mrr_at", "mrr") class MRRAt(RankingMetric): def __init__(self, k=10, pre_sorted=False, name="mrr_at"): super().__init__(mrr_at, k=k, pre_sorted=pre_sorted, name=name)
[docs]@metrics_registry.register_with_multiple_names("ndcg_at", "ndcg") class NDCGAt(RankingMetric):
[docs] def __init__(self, k=10, pre_sorted=False, name="ndcg_at"): super().__init__(ndcg_at, k=k, pre_sorted=pre_sorted, name=name)
[docs]def ranking_metrics(top_ks: Sequence[int], **kwargs) -> Sequence[RankingMetric]: metrics: List[RankingMetric] = [] for k in top_ks: metrics.extend([RecallAt(k), MRRAt(k), NDCGAt(k), AvgPrecisionAt(k), PrecisionAt(k)]) return metrics