Source code for transformers4rec.torch.ranking_metric

#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Adapted from source code: https://github.com/karlhigley/ranking-metrics-torch
from abc import abstractmethod

import torch
import torchmetrics as tm
from merlin.models.utils.registry import Registry
from torchmetrics.utilities.data import dim_zero_cat

from .utils import torch_utils

ranking_metrics_registry = Registry.class_registry("torch.ranking_metrics")


[docs]class RankingMetric(tm.Metric): """ Metric wrapper for computing ranking metrics@K for session-based task. Parameters ---------- top_ks : list, default [2, 5]) list of cutoffs labels_onehot : bool Enable transform the labels to one-hot representation """ def __init__(self, top_ks=None, labels_onehot=False): super(RankingMetric, self).__init__() self.top_ks = top_ks or [2, 5] if not isinstance(self.top_ks, (list, tuple)): self.top_ks = [self.top_ks] self.labels_onehot = labels_onehot # Store the mean of the batch metrics (for each cut-off at topk) self.add_state("metric_mean", default=[], dist_reduce_fx="cat")
[docs] def update(self, preds: torch.Tensor, target: torch.Tensor, **kwargs): # type: ignore # Computing the metrics at different cut-offs if self.labels_onehot: target = torch_utils.tranform_label_to_onehot(target, preds.size(-1)) metric = self._metric( self.top_ks, preds.view(-1, preds.size(-1)), target.view(-1, target.size(-1)) ) self.metric_mean.append(metric) # type: ignore
[docs] def compute(self): # Computing the mean of the batch metrics (for each cut-off at topk) return dim_zero_cat(self.metric_mean).mean(0)
@abstractmethod def _metric(self, ks: torch.Tensor, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Compute a ranking metric over a predictions and one-hot targets. This method should be overridden by subclasses. """
[docs]@ranking_metrics_registry.register_with_multiple_names("precision_at", "precision") class PrecisionAt(RankingMetric): def __init__(self, top_ks=None, labels_onehot=False): super(PrecisionAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot) def _metric(self, ks: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: """Compute precision@K for each of the provided cutoffs Parameters ---------- ks : torch.Tensor or list list of cutoffs scores : torch.Tensor predicted item scores labels : torch.Tensor true item labels Returns ------- torch.Tensor: list of precisions at cutoffs """ ks, scores, labels = torch_utils.check_inputs(ks, scores, labels) _, _, topk_labels = torch_utils.extract_topk(ks, scores, labels) precisions = torch_utils.create_output_placeholder(scores, ks) for index, k in enumerate(ks): precisions[:, index] = torch.sum(topk_labels[:, : int(k)], dim=1) / float(k) return precisions
[docs]@ranking_metrics_registry.register_with_multiple_names("recall_at", "recall") class RecallAt(RankingMetric): def __init__(self, top_ks=None, labels_onehot=False): super(RecallAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot) def _metric(self, ks: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: """Compute recall@K for each of the provided cutoffs Parameters ---------- ks : torch.Tensor or list list of cutoffs scores : torch.Tensor predicted item scores labels : torch.Tensor true item labels Returns ------- torch.Tensor: list of recalls at cutoffs """ ks, scores, labels = torch_utils.check_inputs(ks, scores, labels) _, _, topk_labels = torch_utils.extract_topk(ks, scores, labels) recalls = torch_utils.create_output_placeholder(scores, ks) # Compute recalls at K num_relevant = torch.sum(labels, dim=-1) rel_indices = (num_relevant != 0).nonzero().squeeze(dim=1) rel_count = num_relevant[rel_indices] if rel_indices.shape[0] > 0: for index, k in enumerate(ks): rel_labels = topk_labels[rel_indices, : int(k)] recalls[rel_indices, index] = torch.div( torch.sum(rel_labels, dim=-1), rel_count ).to( dtype=torch.float32 ) # Ensuring type is double, because it can be float if --fp16 return recalls
[docs]@ranking_metrics_registry.register_with_multiple_names("avg_precision_at", "avg_precision", "map") class AvgPrecisionAt(RankingMetric): def __init__(self, top_ks=None, labels_onehot=False): super(AvgPrecisionAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot) self.precision_at = PrecisionAt(top_ks)._metric def _metric(self, ks: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: """Compute average precision at K for provided cutoffs Parameters ---------- ks : torch.Tensor or list list of cutoffs scores : torch.Tensor 2-dim tensor of predicted item scores labels : torch.Tensor 2-dim tensor of true item labels Returns ------- torch.Tensor: list of average precisions at cutoffs """ ks, scores, labels = torch_utils.check_inputs(ks, scores, labels) topk_scores, _, topk_labels = torch_utils.extract_topk(ks, scores, labels) avg_precisions = torch_utils.create_output_placeholder(scores, ks) # Compute average precisions at K num_relevant = torch.sum(labels, dim=1) max_k = max(ks) precisions = self.precision_at(list(range(1, max_k + 1)), topk_scores, topk_labels) rel_precisions = precisions * topk_labels for index, k in enumerate(ks): total_prec = rel_precisions[:, : int(k)].sum(dim=1) avg_precisions[:, index] = total_prec / num_relevant.clamp(min=1, max=k).to( dtype=torch.float32, device=scores.device ) # Ensuring type is double, because it can be float if --fp16 return avg_precisions
[docs]@ranking_metrics_registry.register_with_multiple_names("dcg_at", "dcg") class DCGAt(RankingMetric): def __init__(self, top_ks=None, labels_onehot=False): super(DCGAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot) def _metric( self, ks: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor, log_base: int = 2 ) -> torch.Tensor: """Compute discounted cumulative gain at K for provided cutoffs (ignoring ties) Parameters ---------- ks : torch.Tensor or list list of cutoffs scores : torch.Tensor predicted item scores labels : torch.Tensor true item labels Returns ------- torch.Tensor : list of discounted cumulative gains at cutoffs """ ks, scores, labels = torch_utils.check_inputs(ks, scores, labels) topk_scores, topk_indices, topk_labels = torch_utils.extract_topk(ks, scores, labels) dcgs = torch_utils.create_output_placeholder(scores, ks) # Compute discounts discount_positions = torch.arange(max(ks)).to(device=scores.device, dtype=torch.float32) discount_log_base = torch.log( torch.Tensor([log_base]).to(device=scores.device, dtype=torch.float32) ).item() discounts = 1 / (torch.log(discount_positions + 2) / discount_log_base) # Compute DCGs at K for index, k in enumerate(ks): dcgs[:, index] = torch.sum( (topk_labels[:, :k] * discounts[:k].repeat(topk_labels.shape[0], 1)), dim=1 ).to( dtype=torch.float32, device=scores.device ) # Ensuring type is double, because it can be float if --fp16 return dcgs
[docs]@ranking_metrics_registry.register_with_multiple_names("ndcg_at", "ndcg") class NDCGAt(RankingMetric): def __init__(self, top_ks=None, labels_onehot=False): super(NDCGAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot) self.dcg_at = DCGAt(top_ks)._metric def _metric( self, ks: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor, log_base: int = 2 ) -> torch.Tensor: """Compute normalized discounted cumulative gain at K for provided cutoffs (ignoring ties) Parameters ---------- ks : torch.Tensor or list list of cutoffs scores : torch.Tensor predicted item scores labels : torch.Tensor true item labels Returns ------- torch.Tensor : list of discounted cumulative gains at cutoffs """ ks, scores, labels = torch_utils.check_inputs(ks, scores, labels) topk_scores, topk_indices, topk_labels = torch_utils.extract_topk(ks, scores, labels) # ndcgs = _create_output_placeholder(scores, ks) #TODO track if this line is needed # Compute discounted cumulative gains gains = self.dcg_at(ks, topk_scores, topk_labels) normalizing_gains = self.dcg_at(ks, topk_labels, topk_labels) # Prevent divisions by zero relevant_pos = (normalizing_gains != 0).nonzero(as_tuple=True) irrelevant_pos = (normalizing_gains == 0).nonzero(as_tuple=True) gains[irrelevant_pos] = 0 gains[relevant_pos] /= normalizing_gains[relevant_pos] return gains
[docs]@ranking_metrics_registry.register_with_multiple_names("mrr_at", "mrr") class MeanReciprocalRankAt(RankingMetric): def __init__(self, top_ks=None, labels_onehot=False): super(MeanReciprocalRankAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot) def _metric( self, ks: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor, log_base: int = 2 ) -> torch.Tensor: """Compute mean recipricol rank at K for provided cutoffs (ignoring ties) Parameters ---------- ks : torch.Tensor or list list of cutoffs scores : torch.Tensor predicted item scores labels : torch.Tensor true item labels Returns ------- torch.Tensor : list of mean recipricol rank at cutoffs """ ks, scores, labels = torch_utils.check_inputs(ks, scores, labels) topk_scores, topk_indices, topk_labels = torch_utils.extract_topk(ks, scores, labels) results = torch.zeros(scores.shape[0], len(ks)).to( device=scores.device, dtype=torch.float32 ) for index, k in enumerate(ks): values, _ = (topk_labels[:, :k] / (torch.arange(k) + 1).to(device=scores.device)).max( dim=1 ) results[:, index] = values return results