Source code for transformers4rec.tf.model.prediction_task

import logging
from typing import Dict, Optional, Sequence, Type, Union

import tensorflow as tf
from tensorflow.keras.layers import Layer

from ..block.mlp import MLPBlock
from ..ranking_metric import AvgPrecisionAt, NDCGAt, RecallAt
from ..utils.tf_utils import maybe_deserialize_keras_objects, maybe_serialize_keras_objects
from .base import PredictionTask


def name_fn(name, inp):
    return "/".join([name, inp]) if name else None
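
# For example: name_fn("click", "logit") returns "click/logit", while
# name_fn(None, "logit") returns None, leaving the child layer unnamed.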

MetricOrMetricClass = Union[tf.keras.metrics.Metric, Type[tf.keras.metrics.Metric]]

LOG = logging.getLogger("transformers4rec")


@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class BinaryClassificationTask(PredictionTask):
    DEFAULT_LOSS = tf.keras.losses.BinaryCrossentropy()
    DEFAULT_METRICS = (
        tf.keras.metrics.Precision,
        tf.keras.metrics.Recall,
        tf.keras.metrics.BinaryAccuracy,
        tf.keras.metrics.AUC,
    )

    def __init__(
        self,
        target_name: Optional[str] = None,
        task_name: Optional[str] = None,
        task_block: Optional[Layer] = None,
        loss=DEFAULT_LOSS,
        metrics: Sequence[MetricOrMetricClass] = DEFAULT_METRICS,
        summary_type="first",
        **kwargs,
    ):
        super().__init__(
            loss=loss,
            metrics=list(metrics),
            target_name=target_name,
            task_name=task_name,
            summary_type=summary_type,
            task_block=task_block,
            **kwargs,
        )
        self.pre = tf.keras.layers.Dense(1, activation="sigmoid", name=self.child_name("logit"))
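

# Usage sketch (illustrative, not part of the module). Assuming a `body`
# block built elsewhere and the `Head`/`Model` wrappers exposed by
# transformers4rec.tf, the task can be attached like so:
#
#     import transformers4rec.tf as tr
#
#     task = BinaryClassificationTask(target_name="click")
#     model = tr.Model(tr.Head(body, task))
#     model.compile(optimizer="adam")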


@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class RegressionTask(PredictionTask):
    DEFAULT_LOSS = tf.keras.losses.MeanSquaredError()
    DEFAULT_METRICS = (tf.keras.metrics.RootMeanSquaredError,)

    def __init__(
        self,
        target_name: Optional[str] = None,
        task_name: Optional[str] = None,
        task_block: Optional[Layer] = None,
        loss=DEFAULT_LOSS,
        metrics=DEFAULT_METRICS,
        summary_type="first",
        **kwargs,
    ):
        super().__init__(
            loss=loss,
            metrics=metrics,
            target_name=target_name,
            task_name=task_name,
            summary_type=summary_type,
            task_block=task_block,
            **kwargs,
        )
        self.pre = tf.keras.layers.Dense(1, name=self.child_name("logit"))
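

# Construction sketch (illustrative): the regression task mirrors the binary
# one but emits a single unbounded value, so no activation is applied to the
# output Dense layer.
#
#     task = RegressionTask(target_name="watch_time")
#     # defaults: loss=MeanSquaredError(), metrics=(RootMeanSquaredError,)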


@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class NextItemPredictionTask(PredictionTask):
    """Next-item prediction task.

    Parameters
    ----------
    loss:
        Loss function. Defaults to SparseCategoricalCrossentropy.
    metrics:
        List of RankingMetrics to be evaluated.
    prediction_metrics:
        List of Keras metrics used to summarize the predictions.
    label_metrics:
        List of Keras metrics used to summarize the labels.
    loss_metrics:
        List of Keras metrics used to summarize the loss.
    name:
        Optional task name.
    target_dim: int
        Dimension of the target.
    weight_tying: bool
        Whether the item-id embedding table weights are shared with the
        prediction network layer.
    item_embedding_table: tf.Variable
        Variable of the embedding table for the item.
    softmax_temperature: float
        Softmax temperature, used to reduce model overconfidence, so that
        softmax(logits / T). A value of 1.0 reduces to the regular softmax.
    """

    DEFAULT_LOSS = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    DEFAULT_METRICS = (
        # default metrics assume labels are int encoded
        NDCGAt(top_ks=[10, 20], labels_onehot=True),
        AvgPrecisionAt(top_ks=[10, 20], labels_onehot=True),
        RecallAt(top_ks=[10, 20], labels_onehot=True),
    )

    def __init__(
        self,
        loss=DEFAULT_LOSS,
        metrics=DEFAULT_METRICS,
        target_name: Optional[str] = None,
        task_name: Optional[str] = None,
        task_block: Optional[Layer] = None,
        weight_tying: bool = True,
        target_dim: Optional[int] = None,
        softmax_temperature: float = 1,
        padding_idx: int = 0,
        **kwargs,
    ):
        super().__init__(
            loss=loss,
            metrics=metrics,
            target_name=target_name,
            task_name=task_name,
            task_block=task_block,
            **kwargs,
        )
        self.weight_tying = weight_tying
        self.target_dim = target_dim
        self.softmax_temperature = softmax_temperature
        self.padding_idx = padding_idx
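
    # Construction sketch (illustrative): the default metrics can be
    # overridden with any of the ranking metrics imported above, e.g. to
    # evaluate recall at a single cut-off only:
    #
    #     task = NextItemPredictionTask(
    #         weight_tying=True,
    #         metrics=[RecallAt(top_ks=[10], labels_onehot=True)],
    #     )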

    def build(self, input_shape, body, inputs=None):
        # Retrieve the embedding module to get the name of the item-id column
        # and its related embedding table
        if not len(input_shape) == 3 or isinstance(input_shape, dict):
            raise ValueError(
                "NextItemPredictionTask needs a 3-dim vector as input, found: " f"{input_shape}"
            )
        if not inputs:
            inputs = body.inputs
        if not getattr(inputs, "item_id", None):
            raise ValueError(
                "For the item prediction task a categorical_module "
                "including an item_id column is required."
            )
        self.embeddings = inputs.categorical_layer
        if not self.target_dim:
            self.target_dim = self.embeddings.item_embedding_table.shape[0]

        self.item_embedding_table = None  # stays None when weight tying is disabled
        if self.weight_tying:
            self.item_embedding_table = self.embeddings.item_embedding_table
            item_dim = self.item_embedding_table.shape[1]
            if input_shape[-1] != item_dim and not self.task_block:
                LOG.warning(
                    f"Projecting inputs of NextItemPredictionTask to '{item_dim}', "
                    f"as weight tying requires the input dimension '{input_shape[-1]}' "
                    f"to be equal to the item-id embedding dimension '{item_dim}'."
                )
                # project input tensors to the same dimension as the item-id embeddings
                self.task_block = MLPBlock([item_dim])

        # Retrieve the masking if used in the model block
        self.masking = inputs.masking
        if self.masking:
            self.padding_idx = self.masking.padding_idx

        self.pre = _NextItemPredictionTask(
            target_dim=self.target_dim,
            weight_tying=self.weight_tying,
            item_embedding_table=self.item_embedding_table,
            softmax_temperature=self.softmax_temperature,
        )

        return super().build(input_shape)
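
    # For instance (illustrative): with a transformer body producing hidden
    # states of size 64 and an item-id embedding table of shape
    # (n_items, 128), weight tying triggers the warning above and inserts
    # MLPBlock([128]), so the tied matmul in `_NextItemPredictionTask` is
    # dimensionally valid.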

    def call(self, inputs, **kwargs):
        if isinstance(inputs, (tuple, list)):
            inputs = inputs[0]
        x = inputs

        if self.task_block:
            x = self.task_block(x)

        # retrieve labels from masking
        if self.masking:
            labels = self.masking.masked_targets
        else:
            labels = self.embeddings.item_seq

        # remove vectors of padded items
        trg_flat = tf.reshape(labels, (-1,))
        non_pad_mask = trg_flat != self.padding_idx
        x = self.remove_pad_3d(x, non_pad_mask)

        # compute prediction probabilities
        x = self.pre(x)

        return x
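
    # Shape walk-through (illustrative): for inputs of shape
    # (batch, seqlen, d), flattening yields (batch * seqlen, d); after
    # dropping positions whose label equals `padding_idx`, `self.pre` maps
    # the remaining (n_valid, d) rows to (n_valid, target_dim)
    # log-probabilities.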

    def remove_pad_3d(self, inp_tensor, non_pad_mask):
        # inp_tensor: (n_batch x seqlen x emb_dim)
        inp_tensor = tf.reshape(inp_tensor, (-1, inp_tensor.shape[-1]))
        inp_tensor_fl = tf.boolean_mask(
            inp_tensor, tf.broadcast_to(tf.expand_dims(non_pad_mask, 1), tf.shape(inp_tensor))
        )
        out_tensor = tf.reshape(inp_tensor_fl, (-1, inp_tensor.shape[1]))
        return out_tensor
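
    # A minimal standalone sketch of the same masking, on toy tensors
    # (illustrative):
    #
    #     x = tf.random.uniform((2, 3, 4))             # (batch, seqlen, emb_dim)
    #     labels = tf.constant([[5, 0, 7], [0, 2, 0]])
    #     mask = tf.reshape(labels, (-1,)) != 0        # padding_idx == 0
    #     # remove_pad_3d(x, mask) keeps the 3 non-padded rows -> shape (3, 4)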

    def compute_loss(  # type: ignore
        self,
        inputs,
        targets=None,
        compute_metrics: bool = True,
        call_task: bool = True,
        sample_weight: Optional[tf.Tensor] = None,
        **kwargs,
    ) -> tf.Tensor:
        if isinstance(targets, dict) and self.target_name:
            targets = targets[self.target_name]

        predictions = inputs
        if call_task:
            predictions = self(inputs)

        # retrieve labels from masking
        if self.masking:
            targets = self.masking.masked_targets
        else:
            targets = self.embeddings.item_seq

        # flatten labels and remove padding index
        targets = tf.reshape(targets, (-1,))
        non_pad_mask = targets != self.padding_idx
        targets = tf.boolean_mask(targets, non_pad_mask)

        loss = self.loss(y_true=targets, y_pred=predictions, sample_weight=sample_weight)

        if compute_metrics:
            update_ops = self.calculate_metrics(predictions, targets, forward=False, loss=loss)
            update_ops = [x for x in update_ops if x is not None]

            with tf.control_dependencies(update_ops):
                return tf.identity(loss)

        return loss

    def calculate_metrics(
        self, predictions, targets=None, sample_weight=None, forward=True, loss=None
    ):
        if isinstance(targets, dict) and self.target_name:
            targets = targets[self.target_name]

        if forward:
            predictions = self(predictions)

        # retrieve labels from masking
        if self.masking:
            targets = self.masking.masked_targets

        # flatten labels and remove padding index
        targets = tf.reshape(targets, -1)
        non_pad_mask = targets != self.padding_idx
        targets = tf.boolean_mask(targets, non_pad_mask)

        update_ops = []

        for metric in self.eval_metrics:
            update_ops.append(
                metric.update_state(y_true=targets, y_pred=predictions, sample_weight=sample_weight)
            )

        for metric in self.prediction_metrics:
            update_ops.append(metric.update_state(predictions, sample_weight=sample_weight))

        for metric in self.label_metrics:
            update_ops.append(metric.update_state(targets, sample_weight=sample_weight))

        for metric in self.loss_metrics:
            if not loss:
                loss = self.loss(y_true=targets, y_pred=predictions, sample_weight=sample_weight)
            update_ops.append(metric.update_state(loss, sample_weight=sample_weight))

        return update_ops

    def metric_results(self, mode: Optional[str] = None) -> Dict[str, tf.Tensor]:
        metrics = {metric.name: metric.result() for metric in self.eval_metrics}
        topks = {metric.name: metric.top_ks for metric in self.eval_metrics}
        # explode metrics for each cut-off in top_ks
        results = {}
        for name, metric in metrics.items():
            for measure, k in zip(metric, topks[name]):
                results[f"{name}_{k}"] = measure
        return results
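

# For example (illustrative): with the default NDCGAt(top_ks=[10, 20])
# metric, metric_results() contributes entries keyed "<metric.name>_10" and
# "<metric.name>_20", one scalar tensor per cut-off.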


@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class _NextItemPredictionTask(tf.keras.layers.Layer):
    """Predict the interacted item-id probabilities.

    - During inference, the task consists of predicting the next item.
    - During training, the class supports the following language modeling
      tasks: Causal LM and Masked LM.

    P.S.: we are planning to support Permutation LM and Replacement Token
    Detection in a future release.

    Parameters
    ----------
    target_dim: int
        Dimension of the target.
    weight_tying: bool
        Whether the item-id embedding table weights are shared with the
        prediction network layer.
    item_embedding_table: tf.Variable
        Variable of the embedding table for the item.
    softmax_temperature: float
        Softmax temperature, used to reduce model overconfidence, so that
        softmax(logits / T). A value of 1.0 reduces to the regular softmax.
    """

    def __init__(
        self,
        target_dim: int,
        weight_tying: bool = True,
        item_embedding_table: Optional[tf.Variable] = None,
        softmax_temperature: float = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.target_dim = target_dim
        self.weight_tying = weight_tying
        self.item_embedding_table = item_embedding_table
        self.softmax_temperature = softmax_temperature

        if self.weight_tying:
            if item_embedding_table is None:
                raise ValueError(
                    "For the item prediction task with weight tying, "
                    "the embedding table of item_id is required."
                )
            self.output_layer_bias = self.add_weight(
                name="output_layer_bias",
                shape=(self.target_dim,),
                initializer=tf.keras.initializers.Zeros(),
            )
        else:
            self.output_layer = tf.keras.layers.Dense(
                units=self.target_dim,
                kernel_initializer="random_normal",
                bias_initializer="zeros",
                name="logits",
            )

    @classmethod
    def from_config(cls, config):
        config = maybe_deserialize_keras_objects(config, ["output_layer"])
        return super().from_config(config)

    def get_config(self):
        config = super().get_config()
        config = maybe_serialize_keras_objects(self, config, ["output_layer"])
        config["target_dim"] = self.target_dim
        config["weight_tying"] = self.weight_tying
        config["softmax_temperature"] = self.softmax_temperature
        return config

    def call(self, inputs: tf.Tensor, **kwargs):
        if self.weight_tying:
            logits = tf.matmul(inputs, tf.transpose(self.item_embedding_table))
            logits = tf.nn.bias_add(logits, self.output_layer_bias)
        else:
            logits = self.output_layer(inputs)

        if self.softmax_temperature:
            # Softmax temperature to reduce model overconfidence
            # and better calibrate probabilities and accuracy
            logits = logits / self.softmax_temperature

        predictions = tf.nn.log_softmax(logits, axis=-1)

        return predictions

    def _get_name(self) -> str:
        return "_NextItemPredictionTask"
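

# The weight-tied scoring above, reduced to a standalone sketch (illustrative
# shapes only):
#
#     hidden = tf.random.uniform((7, 128))                  # (n_valid, item_dim)
#     table = tf.random.uniform((1000, 128))                # (target_dim, item_dim)
#     logits = tf.matmul(hidden, tf.transpose(table))       # (7, 1000)
#     log_probs = tf.nn.log_softmax(logits / 2.0, axis=-1)  # temperature T = 2.0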