#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from source code: https://github.com/karlhigley/ranking-metrics-torch
from typing import List, Optional, Sequence, Union
import tensorflow as tf
from keras.utils import losses_utils, metrics_utils
from tensorflow.keras import backend
from tensorflow.keras.metrics import Mean, Metric
from tensorflow.keras.metrics import get as get_metric
from merlin.models.tf.metrics import metrics_registry
from merlin.models.tf.utils.tf_utils import extract_topk
METRIC_PARAMETERS_DOCSTRING = """
y_true : tf.Tensor
A tensor with shape (batch_size, n_items) corresponding to
the multi-hot representation of true labels.
y_pred : tf.Tensor
A tensor with shape (batch_size, n_items) corresponding to
the prediction scores.
label_relevant_counts: tf.Tensor
A 1D tensor (batch size) which contains the total number of relevant items
        for each example. This is necessary when setting `pre_sorted=True` on the
        ranking metric classes (e.g. RecallAt(5, pre_sorted=True)), as extract_topk
is used to extract only the top-k predictions and corresponding labels,
potentially losing other relevant items not among top-k. In such cases,
the label_relevant_counts will contain the total relevant counts per example.
k : int
The cut-off for ranking metrics
"""
def recall_at(
y_true: tf.Tensor,
y_pred: tf.Tensor,
label_relevant_counts: Optional[tf.Tensor] = None,
k: int = 5,
) -> tf.Tensor:
"""
    Computes Recall@K metric

    Parameters
    ----------
{METRIC_PARAMETERS_DOCSTRING}
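
    Examples
    --------
    An illustrative sketch (assuming eager execution), with labels already sorted
    by prediction score and truncated to the top-2 items, and 3 relevant items in
    total for the example; note that the denominator is clipped at ``k``:

    >>> y_true = tf.constant([[1.0, 0.0]])
    >>> y_pred = tf.constant([[0.9, 0.7]])
    >>> float(recall_at(y_true, y_pred, tf.constant([3.0]), k=2)[0])
    0.5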
"""
rel_count = tf.clip_by_value(label_relevant_counts, clip_value_min=1, clip_value_max=float(k))
rel_labels = tf.reduce_sum(y_true[:, :k], axis=-1)
results = tf.cast(
tf.math.divide_no_nan(rel_labels, rel_count),
backend.floatx(),
)
return results
def precision_at(
y_true: tf.Tensor,
y_pred: tf.Tensor,
label_relevant_counts: Optional[tf.Tensor] = None,
k: int = 5,
) -> tf.Tensor:
"""
Computes Precision@K metric
Parameters
----------
{METRIC_PARAMETERS_DOCSTRING}
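
    Examples
    --------
    An illustrative sketch (assuming eager execution), with labels already sorted
    by prediction score; two of the top-4 items are relevant:

    >>> y_true = tf.constant([[1.0, 0.0, 1.0, 0.0]])
    >>> y_pred = tf.constant([[0.9, 0.8, 0.7, 0.6]])
    >>> float(precision_at(y_true, y_pred, k=4)[0])
    0.5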
"""
results = tf.cast(tf.reduce_mean(y_true[:, : int(k)], axis=-1), backend.floatx())
return results
def average_precision_at(
y_true: tf.Tensor,
y_pred: tf.Tensor,
label_relevant_counts: Optional[tf.Tensor] = None,
k: int = 5,
) -> tf.Tensor:
"""
Computes Mean Average Precision (MAP) @K
Parameters
----------
{METRIC_PARAMETERS_DOCSTRING}
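
    Examples
    --------
    An illustrative sketch (assuming eager execution), with labels already sorted
    by prediction score and 2 relevant items in total (ranked 2nd and 4th), so
    AP@4 = (P@2 + P@4) / 2:

    >>> y_true = tf.constant([[0.0, 1.0, 0.0, 1.0]])
    >>> y_pred = tf.constant([[0.9, 0.8, 0.7, 0.6]])
    >>> float(average_precision_at(y_true, y_pred, tf.constant([2.0]), k=4)[0])
    0.5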
"""
# Computing the precision from 1 to k range
precisions = tf.stack([precision_at(y_true, y_pred, k=_k) for _k in range(1, k + 1)], axis=-1)
# Keeping only the precision at the position of relevant items
rel_precisions = precisions * y_true[:, :k]
total_prec = tf.reduce_sum(rel_precisions, axis=-1)
total_relevant_topk = tf.clip_by_value(
label_relevant_counts, clip_value_min=1, clip_value_max=float(k)
)
results = tf.cast(tf.math.divide_no_nan(total_prec, total_relevant_topk), backend.floatx())
return results
def dcg_at(
y_true: tf.Tensor,
y_pred: tf.Tensor,
label_relevant_counts: Optional[tf.Tensor] = None,
k: int = 5,
log_base: int = 2,
) -> tf.Tensor:
"""
    Computes Discounted Cumulative Gain (DCG) @K (ignoring ties)
Parameters
----------
{METRIC_PARAMETERS_DOCSTRING}
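    log_base : int
        Base of the logarithmic discount applied to the rank of relevant items. Defaults to 2

    Examples
    --------
    An illustrative sketch (assuming eager execution); the first two ranked items
    are relevant, so DCG@3 = 1/log2(2) + 1/log2(3):

    >>> y_true = tf.constant([[1.0, 1.0, 0.0]])
    >>> y_pred = tf.constant([[0.9, 0.8, 0.7]])
    >>> round(float(dcg_at(y_true, y_pred, k=3)[0]), 4)
    1.6309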
"""
# Compute discounts
discount_positions = tf.cast(tf.range(k), backend.floatx())
discount_log_base = tf.math.log(tf.convert_to_tensor([log_base], dtype=backend.floatx()))
discounts = 1 / (tf.math.log(discount_positions + 2) / discount_log_base)
m = y_true[:, :k] * tf.repeat(tf.expand_dims(discounts[:k], 0), tf.shape(y_true)[0], axis=0)
results = tf.cast(tf.reduce_sum(m, axis=-1), backend.floatx())
return results
def ndcg_at(
y_true: tf.Tensor,
y_pred: tf.Tensor,
label_relevant_counts: Optional[tf.Tensor] = None,
k: int = 5,
log_base: int = 2,
) -> tf.Tensor:
"""
    Computes Normalized Discounted Cumulative Gain (NDCG) @K (ignoring ties)
Parameters
----------
{METRIC_PARAMETERS_DOCSTRING}
    log_base : int
        Base of the logarithmic discount applied to the rank of relevant items. Defaults to 2
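
    Examples
    --------
    An illustrative sketch (assuming eager execution); the single relevant item is
    ranked 2nd, so NDCG@3 = (1/log2(3)) / (1/log2(2)):

    >>> y_true = tf.constant([[0.0, 1.0, 0.0]])
    >>> y_pred = tf.constant([[0.9, 0.8, 0.7]])
    >>> round(float(ndcg_at(y_true, y_pred, tf.constant([1.0]), k=3)[0]), 4)
    0.6309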
"""
gains = dcg_at(y_true, y_pred, k=k, log_base=log_base)
perfect_labels_sorting = tf.cast(
tf.cast(tf.expand_dims(tf.range(k), 0), label_relevant_counts.dtype) # type: ignore
< tf.expand_dims(label_relevant_counts, -1),
backend.floatx(),
)
ideal_gains = dcg_at(perfect_labels_sorting, perfect_labels_sorting, k=k, log_base=log_base)
results = tf.cast(tf.math.divide_no_nan(gains, ideal_gains), backend.floatx())
return results
def mrr_at(
y_true: tf.Tensor,
y_pred: tf.Tensor,
label_relevant_counts: Optional[tf.Tensor] = None,
k: int = 5,
) -> tf.Tensor:
"""
    Computes Mean Reciprocal Rank (MRR) @K

    Parameters
    ----------
{METRIC_PARAMETERS_DOCSTRING}
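
    Examples
    --------
    An illustrative sketch (assuming eager execution); the first relevant item is
    ranked 2nd, so the reciprocal rank is 1/2:

    >>> y_true = tf.constant([[0.0, 1.0, 0.0, 0.0]])
    >>> y_pred = tf.constant([[0.9, 0.8, 0.7, 0.6]])
    >>> float(mrr_at(y_true, y_pred, k=4)[0])
    0.5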
"""
first_rel_position = tf.cast(tf.argmax(y_true, axis=-1) + 1, backend.floatx())
relevant_mask = tf.reduce_max(y_true[:, : int(k)], axis=-1)
rel_position = first_rel_position * relevant_mask
results = tf.cast(tf.math.divide_no_nan(1.0, rel_position), backend.floatx())
return results
class TopkMetricWithLabelRelevantCountsMixin:
@property
def label_relevant_counts(self) -> tf.Tensor:
return self._label_relevant_counts
@label_relevant_counts.setter
def label_relevant_counts(self, new_value: tf.Tensor):
self._label_relevant_counts = new_value
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class TopkMetric(Mean, TopkMetricWithLabelRelevantCountsMixin):
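    """Keras metric that wraps a top-k ranking metric function (e.g. ``recall_at``)
    and averages its per-example values across batches.

    Examples
    --------
    An illustrative sketch (assuming eager execution). With ``pre_sorted=False``
    (the default of the subclasses below), raw prediction scores and multi-hot
    labels can be passed directly and are sorted/truncated to the top-k internally:

    >>> metric = RecallAt(k=2)
    >>> _ = metric.update_state(
    ...     tf.constant([[0.0, 1.0, 0.0, 1.0]]),
    ...     tf.constant([[0.5, 0.2, 0.1, 0.8]]),
    ... )
    >>> float(metric.result())
    0.5
    """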
def __init__(self, fn, k=5, pre_sorted=True, name=None, dtype=None, **kwargs):
if name is not None:
name = f"{name}_{k}"
super().__init__(name=name, dtype=dtype)
self._fn = fn
self.k = k
self._pre_sorted = pre_sorted
self._fn_kwargs = kwargs
self.label_relevant_counts = None
@property
def pre_sorted(self):
return self._pre_sorted
@pre_sorted.setter
def pre_sorted(self, new_value):
self._pre_sorted = new_value
def update_state(
self,
y_true: tf.Tensor,
y_pred: tf.Tensor,
sample_weight: Optional[tf.Tensor] = None,
):
y_true, y_pred = self.check_cast_inputs(y_true, y_pred)
(
[y_true, y_pred],
sample_weight,
) = metrics_utils.ragged_assert_compatible_and_get_flat_values(
[y_true, y_pred], sample_weight
)
y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true)
tf.debugging.assert_greater_equal(
tf.shape(y_true)[1],
self.k,
f"The TopkMetric {self.name} cutoff ({self.k}) cannot be smaller than "
f"the number of predictions per example",
)
y_pred, y_true, label_relevant_counts = self._maybe_sort_top_k(
y_pred, y_true, self.label_relevant_counts
)
ag_fn = tf.__internal__.autograph.tf_convert(
self._fn, tf.__internal__.autograph.control_status_ctx()
)
matches = ag_fn(
y_true,
y_pred,
label_relevant_counts=label_relevant_counts,
k=self.k,
**self._fn_kwargs,
)
return super().update_state(matches, sample_weight=sample_weight)
def _maybe_sort_top_k(self, y_pred, y_true, label_relevant_counts: tf.Tensor = None):
if not self.pre_sorted:
y_pred, y_true, label_relevant_counts = extract_topk(self.k, y_pred, y_true)
else:
if label_relevant_counts is None:
raise Exception(
"If y_true was pre-sorted (and truncated to top-k) you must "
"provide label_relevant_counts argument."
)
label_relevant_counts = tf.cast(label_relevant_counts, self._dtype)
return y_pred, y_true, label_relevant_counts
def check_cast_inputs(self, labels, predictions):
tf.assert_equal(
tf.rank(predictions), 2, f"predictions must be 2-D tensor (got {predictions.shape})"
)
tf.assert_equal(tf.rank(labels), 2, f"labels must be 2-D tensor (got {labels.shape})")
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
return tf.cast(labels, self._dtype), tf.cast(predictions, self._dtype)
def get_config(self):
config = {}
if type(self) is TopkMetric:
            # Only include the function argument when the object is a TopkMetric itself, not a subclass.
config["fn"] = self._fn
config["k"] = self.k
config["pre_sorted"] = self.pre_sorted
for k, v in self._fn_kwargs.items():
config[k] = backend.eval(v) if tf.is_tensor(v) or isinstance(v, tf.Variable) else v
base_config = super(TopkMetric, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
return {}
@classmethod
def from_config(cls, config):
fn = config.pop("fn", None)
k = config.pop("k", None)
pre_sorted = config.pop("pre_sorted", None)
if cls is TopkMetric:
return cls(get_metric(fn), k=k, pre_sorted=pre_sorted, **config)
return super(TopkMetric, cls).from_config(config)
@metrics_registry.register_with_multiple_names("recall_at", "recall")
class RecallAt(TopkMetric):
    def __init__(self, k=10, pre_sorted=False, name="recall_at"):
super().__init__(recall_at, k=k, pre_sorted=pre_sorted, name=name)
@metrics_registry.register_with_multiple_names("precision_at", "precision")
class PrecisionAt(TopkMetric):
def __init__(self, k=10, pre_sorted=False, name="precision_at"):
super().__init__(precision_at, k=k, pre_sorted=pre_sorted, name=name)
@metrics_registry.register_with_multiple_names("map_at", "map")
class AvgPrecisionAt(TopkMetric):
    def __init__(self, k=10, pre_sorted=False, name="map_at"):
super().__init__(average_precision_at, k=k, pre_sorted=pre_sorted, name=name)
@metrics_registry.register_with_multiple_names("mrr_at", "mrr")
class MRRAt(TopkMetric):
def __init__(self, k=10, pre_sorted=False, name="mrr_at"):
super().__init__(mrr_at, k=k, pre_sorted=pre_sorted, name=name)
@metrics_registry.register_with_multiple_names("ndcg_at", "ndcg")
class NDCGAt(TopkMetric):
    def __init__(self, k=10, pre_sorted=False, name="ndcg_at"):
super().__init__(ndcg_at, k=k, pre_sorted=pre_sorted, name=name)
class TopKMetricsAggregator(Metric, TopkMetricWithLabelRelevantCountsMixin):
"""Aggregator for top-k metrics (TopkMetric) that is optimized
to sort top-k predictions only once for all metrics.
*topk_metrics : TopkMetric
Multiple arguments with TopkMetric instances
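
    Examples
    --------
    An illustrative sketch (assuming eager execution):

    >>> aggregator = TopKMetricsAggregator(RecallAt(2), MRRAt(2))
    >>> _ = aggregator.update_state(
    ...     tf.constant([[0.0, 1.0, 0.0, 1.0]]),
    ...     tf.constant([[0.5, 0.2, 0.1, 0.8]]),
    ... )
    >>> {name: float(value) for name, value in aggregator.result().items()}
    {'recall_at_2': 0.5, 'mrr_at_2': 1.0}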
"""
def __init__(self, *topk_metrics: TopkMetric):
super(TopKMetricsAggregator, self).__init__()
        assert len(topk_metrics) > 0, "At least one TopkMetric instance should be provided"
assert all(
isinstance(m, TopkMetric) for m in topk_metrics
), "All provided metrics should inherit from TopkMetric"
self.topk_metrics = topk_metrics
# Setting the `pre_sorted` of topk metrics so that
# prediction scores are not sorted again for each metric
for metric in self.topk_metrics:
metric.pre_sorted = True
self.k = max([m.k for m in self.topk_metrics])
self.label_relevant_counts = None
def update_state(
self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: Optional[tf.Tensor] = None
):
        # Extracting sorted top-k prediction scores and labels only ONCE
        # so that sorting does not need to happen for each individual metric
        # (as the top-k metrics have been set with pre_sorted=True in this constructor)
y_pred, y_true, label_relevant_counts_from_targets = extract_topk(self.k, y_pred, y_true)
# If label_relevant_counts is not set by a block (e.g. TopKIndexBlock) that
# has already extracted the top-k predictions, it is assumed that
# y_true contains all items
label_relevant_counts = self.label_relevant_counts
if label_relevant_counts is None:
label_relevant_counts = label_relevant_counts_from_targets
for metric in self.topk_metrics:
# Sets the label_relevant_counts using a property,
# as metric.update_state() should have standard signature
# for better compatibility with Keras
metric.label_relevant_counts = label_relevant_counts
metric.update_state(y_true, y_pred, sample_weight)
def result(self):
outputs = {}
for metric in self.topk_metrics:
outputs[metric.name] = metric.result()
return outputs
@classmethod
def default_metrics(cls, top_ks: Sequence[int], **kwargs) -> Sequence[TopkMetric]:
"""Returns an TopKMetricsAggregator instance with the default top-k metrics
at the cut-offs defined in top_ks
Parameters
----------
top_ks : Sequence[int]
List with the cut-offs for top-k metrics (e.g. [5,10,50])
Returns
-------
        Sequence[TopkMetric]
            A single-element list containing a TopKMetricsAggregator instance with the
            default top-k metrics at the predefined cut-offs
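
        Examples
        --------
        An illustrative sketch; all default metrics are wrapped in a single aggregator:

        >>> metrics = TopKMetricsAggregator.default_metrics(top_ks=[5, 10])
        >>> len(metrics)
        1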
"""
metrics: List[TopkMetric] = []
for k in top_ks:
metrics.extend([RecallAt(k), MRRAt(k), NDCGAt(k), AvgPrecisionAt(k), PrecisionAt(k)])
# Using Top-k metrics aggregator provides better performance than having top-k
# metrics computed separately, as prediction scores are sorted only once for all metrics
aggregator = cls(*metrics)
return [aggregator]
def filter_topk_metrics(
metrics: Sequence[Metric],
) -> List[Union[TopkMetric, TopKMetricsAggregator]]:
"""Returns only top-k metrics from the list of metrics
Parameters
----------
metrics : List[Metric]
List of metrics
Returns
-------
List[Union[TopkMetric, TopKMetricsAggregator]]
List with the top-k metrics in the list of input metrics
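
    Examples
    --------
    An illustrative sketch; non-top-k metrics such as ``tf.keras.metrics.AUC`` are dropped:

    >>> filtered = filter_topk_metrics([RecallAt(10), tf.keras.metrics.AUC()])
    >>> [type(metric).__name__ for metric in filtered]
    ['RecallAt']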
"""
    topk_metrics = [
        metric
        for metric in metrics
        if isinstance(metric, (TopkMetric, TopKMetricsAggregator))
    ]
    return topk_metrics