Source code for transformers4rec.tf.model.base

#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
from collections import defaultdict
from types import SimpleNamespace
from typing import Dict, List, Optional, Text, Type, Union

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.python.framework import ops
from tensorflow.python.keras.utils import generic_utils
from transformers import TFSequenceSummary

from merlin_standard_lib import Schema, Tag

from ..features.sequence import TabularFeaturesType
from ..ranking_metric import process_metrics
from ..utils.tf_utils import (
    LossMixin,
    MetricsMixin,
    maybe_deserialize_keras_objects,
    maybe_serialize_keras_objects,
)


def name_fn(name, inp):
    """Join ``name`` and ``inp`` with a ``/`` separator.

    Returns ``None`` when ``name`` is falsy (e.g. ``None`` or an empty string),
    otherwise the string ``"<name>/<inp>"``.
    """
    if not name:
        return None

    return "/".join((name, inp))


# A metric can be given either as an already-instantiated Keras metric or as a
# Keras metric class; `PredictionTask._create_metrics` instantiates the latter.
MetricOrMetricClass = Union[tf.keras.metrics.Metric, Type[tf.keras.metrics.Metric]]


class PredictionTask(Layer, LossMixin, MetricsMixin):
    """Keras layer holding the loss, metrics and output transforms of one task.

    The layer optionally summarizes 3-D inputs with a ``TFSequenceSummary``,
    then applies the optional ``task_block`` and ``pre`` layers. It also knows
    how to compute its loss and update its metric groups.
    """

    def __init__(
        self,
        loss: tf.keras.losses.Loss,
        target_name: Optional[str] = None,
        task_name: Optional[str] = None,
        metrics: Optional[List[MetricOrMetricClass]] = None,
        pre: Optional[Layer] = None,
        task_block: Optional[Layer] = None,
        prediction_metrics: Optional[List[tf.keras.metrics.Metric]] = None,
        label_metrics: Optional[List[tf.keras.metrics.Metric]] = None,
        loss_metrics: Optional[List[tf.keras.metrics.Metric]] = None,
        name: Optional[Text] = None,
        summary_type="last",
        **kwargs,
    ) -> None:
        """Initializes the task.

        Parameters
        ----------
        loss:
            Loss function (required; there is no default).
        target_name:
            Key used to pick this task's targets when targets are given as a dict.
            Also used as a prefix of the default task name.
        task_name:
            Optional explicit task name; overrides the derived default.
        metrics:
            List of Keras metrics (instances or classes) to be evaluated.
        pre:
            Optional layer applied last in ``call``.
        task_block:
            Optional layer applied in ``call`` before ``pre``.
        prediction_metrics:
            List of Keras metrics used to summarize the predictions.
        label_metrics:
            List of Keras metrics used to summarize the labels.
        loss_metrics:
            List of Keras metrics used to summarize the loss.
        name:
            Optional layer name, forwarded to the Keras ``Layer`` constructor.
        summary_type:
            ``summary_type`` handed to ``TFSequenceSummary``.
        """
        super().__init__(name=name, **kwargs)
        self.target_name = target_name
        # TFSequenceSummary only needs an object exposing `summary_type`.
        self.sequence_summary = TFSequenceSummary(
            SimpleNamespace(summary_type=summary_type)  # type: ignore
        )  # noqa
        self.pre = pre
        self.task_block = task_block
        self.loss = loss
        self._task_name = task_name

        create_metrics = self._create_metrics
        self.eval_metrics = create_metrics(metrics) if metrics else []
        self.prediction_metrics = create_metrics(prediction_metrics) if prediction_metrics else []
        self.label_metrics = create_metrics(label_metrics) if label_metrics else []
        self.loss_metrics = create_metrics(loss_metrics) if loss_metrics else []

    def call(self, inputs, training=False, **kwargs):
        """Transform ``inputs`` into this task's predictions.

        Rank-3 inputs are first reduced with ``self.sequence_summary``; then
        ``task_block`` and ``pre`` are applied, in that order, when they are set.
        ``training`` and ``**kwargs`` are accepted but not used here.
        """
        x = inputs

        if len(x.shape) == 3:
            x = self.sequence_summary(x)

        if self.task_block:
            x = self.task_block(x)

        if self.pre:
            x = self.pre(x)

        return x

    def _create_metrics(self, metrics: List[MetricOrMetricClass]) -> List[tf.keras.metrics.Metric]:
        """Return metric instances, instantiating any entry given as a class.

        Classes are instantiated with a name derived from the snake-cased class
        name, prefixed through ``child_name``.
        """
        outputs = []
        for metric in metrics:
            if not isinstance(metric, tf.keras.metrics.Metric):
                metric = metric(name=self.child_name(generic_utils.to_snake_case(metric.__name__)))
            outputs.append(metric)

        return outputs

    @property
    def task_name(self):
        """Explicit task name if one was given, else a name derived from the class.

        The derived name is the snake-cased class name, prefixed with
        ``target_name`` and a slash when ``target_name`` is set.
        """
        if self._task_name:
            return self._task_name

        base_name = generic_utils.to_snake_case(self.__class__.__name__)

        return name_fn(self.target_name, base_name) if self.target_name else base_name

    def child_name(self, name):
        """Prefix ``name`` with this task's name."""
        return name_fn(self.task_name, name)

    def compute_loss(  # type: ignore
        self,
        inputs,
        targets,
        training: bool = False,
        call_task: bool = True,
        compute_metrics=True,
        sample_weight: Optional[tf.Tensor] = None,
        **kwargs,
    ) -> tf.Tensor:
        """Compute the task loss and optionally update the metrics.

        Parameters
        ----------
        inputs:
            Inputs of the task, or ready-made predictions when ``call_task`` is False.
        targets:
            Targets; when a dict and ``target_name`` is set, the entry for
            ``target_name`` is used.
        training:
            Forwarded to the layer call when ``call_task`` is True.
        call_task:
            Whether to run this layer on ``inputs`` to obtain the predictions.
        compute_metrics:
            Whether to update the metrics; if so, the returned loss is tied to
            the metric updates through control dependencies.
        sample_weight:
            Optional sample weights handed to the loss function.
        """
        if isinstance(targets, dict) and self.target_name:
            targets = targets[self.target_name]

        predictions = inputs
        if call_task:
            predictions = self(inputs, training=training, **kwargs)

        loss = self.loss(y_true=targets, y_pred=predictions, sample_weight=sample_weight)

        if compute_metrics:
            update_ops = self.calculate_metrics(predictions, targets, forward=False, loss=loss)

            # `update_state` may return None; only real ops can be dependencies.
            update_ops = [x for x in update_ops if x is not None]

            with tf.control_dependencies(update_ops):
                return tf.identity(loss)

        return loss

    def repr_add(self):
        """Return extra ``(label, value)`` pairs describing this task (its loss)."""
        return [("loss", self.loss)]

    def calculate_metrics(self, predictions, targets, sample_weight=None, forward=True, loss=None):
        """Update every metric group and return the list of update results.

        Parameters
        ----------
        predictions:
            Predictions, or task inputs when ``forward`` is True.
        targets:
            Targets; when a dict and ``target_name`` is set, the entry for
            ``target_name`` is used.
        sample_weight:
            Optional sample weights forwarded to every ``update_state`` call.
        forward:
            Whether to run this layer on ``predictions`` first.
        loss:
            Pre-computed loss for the loss metrics; computed here when ``None``
            and loss metrics are configured.
        """
        if isinstance(targets, dict) and self.target_name:
            targets = targets[self.target_name]

        if forward:
            predictions = self(predictions)

        update_ops = []

        for metric in self.eval_metrics:
            update_ops.append(
                metric.update_state(y_true=targets, y_pred=predictions, sample_weight=sample_weight)
            )

        for metric in self.prediction_metrics:
            update_ops.append(metric.update_state(predictions, sample_weight=sample_weight))

        for metric in self.label_metrics:
            update_ops.append(metric.update_state(targets, sample_weight=sample_weight))

        # Compare against None explicitly: truth-testing a tensor would treat a
        # 0.0 loss as missing and is not allowed on symbolic tensors.
        if self.loss_metrics and loss is None:
            loss = self.loss(y_true=targets, y_pred=predictions, sample_weight=sample_weight)

        for metric in self.loss_metrics:
            update_ops.append(metric.update_state(loss, sample_weight=sample_weight))

        return update_ops

    def metric_results(self, mode: Optional[str] = None):
        """Return ``{metric name: current result}`` for all metrics of this layer.

        ``mode`` is accepted but not used here.
        """
        return {metric.name: metric.result() for metric in self.metrics}

    def reset_metrics(self):
        """Reset the state of all metrics of this layer."""
        for metric in self.metrics:
            metric.reset_state()

    def to_head(self, body, inputs=None, **kwargs) -> "Head":
        """Wrap ``body`` and this task in a ``Head``."""
        return Head(body, self, inputs=inputs, **kwargs)

    def to_model(self, body, inputs=None, **kwargs) -> "Model":
        """Wrap ``body`` and this task in a ``Head`` and that head in a ``Model``."""
        # NOTE(review): the same **kwargs go to both Head and Model -- confirm intended.
        return Model(Head(body, self, inputs=inputs, **kwargs), **kwargs)

    @classmethod
    def from_config(cls, config):
        """Build a task from a config, deserializing the Keras objects it contains."""
        config = maybe_deserialize_keras_objects(
            config,
            {
                "pre": tf.keras.layers.deserialize,
                "loss": tf.keras.losses.deserialize,
                "metrics": tf.keras.metrics.deserialize,
                "prediction_metrics": tf.keras.metrics.deserialize,
                "label_metrics": tf.keras.metrics.deserialize,
                "loss_metrics": tf.keras.metrics.deserialize,
            },
        )

        return super().from_config(config)

    def get_config(self):
        """Return the layer config extended with this task's own settings."""
        config = super().get_config()
        # NOTE(review): `task_block` is not part of the serialized attributes -- confirm intended.
        config = maybe_serialize_keras_objects(
            self,
            config,
            ["metrics", "prediction_metrics", "label_metrics", "loss_metrics", "loss", "pre"],
        )

        config["summary_type"] = self.sequence_summary.summary_type
        if self.target_name:
            config["target_name"] = self.target_name
        if self._task_name:
            config["task_name"] = self._task_name

        return config
class BaseModel(tf.keras.Model, LossMixin, abc.ABC):
    """Abstract Keras model whose train and test steps rely on ``compute_loss``."""

    def train_step(self, inputs):
        """Custom train step using the `compute_loss` method."""
        features, targets = inputs

        with tf.GradientTape() as grad_tape:
            loss = self.compute_loss(features, targets, training=True)

            # Losses registered on the model (e.g. regularizers) join the objective.
            regularization_loss = sum(self.losses)

            total_loss = loss + regularization_loss

        variables = self.trainable_variables
        grads = grad_tape.gradient(total_loss, variables)
        self.optimizer.apply_gradients(zip(grads, variables))

        logs = process_metrics(self.metrics, prefix="train_")
        logs["loss"] = loss
        logs["regularization_loss"] = regularization_loss
        logs["total_loss"] = total_loss

        return logs

    def test_step(self, inputs):
        """Custom test step using the `compute_loss` method."""
        loss = self.compute_loss(*inputs, training=False)

        # Losses registered on the model (e.g. regularizers) are reported as well.
        regularization_loss = sum(self.losses)

        total_loss = loss + regularization_loss

        logs = process_metrics(self.metrics, prefix="eval_")
        logs["loss"] = loss
        logs["regularization_loss"] = regularization_loss
        logs["total_loss"] = total_loss

        return logs
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class Model(BaseModel):
    """Model made of one or more heads whose losses are combined by weight."""

    def __init__(
        self, *head: Head, head_weights: Optional[List[float]] = None, name=None, **kwargs
    ):
        """Create a model from heads.

        Parameters
        ----------
        *head:
            The heads of the model.
        head_weights:
            Optional list or tuple with one loss weight per head; every head
            gets weight 1.0 when omitted.
        name:
            Optional model name, forwarded to the Keras constructor.

        Raises
        ------
        ValueError
            If ``head_weights`` is not a list/tuple or its length differs from
            the number of heads.
        """
        if head_weights:
            if not isinstance(head_weights, (list, tuple)):
                raise ValueError("`head_weights` must be a list or tuple")
            if len(head_weights) != len(head):
                raise ValueError(
                    "`head_weights` needs to have the same length " "as the number of heads"
                )

        super().__init__(name=name, **kwargs)

        self.heads = head
        self.head_weights = tuple(head_weights or [1.0] * len(head))

    def call(self, inputs, **kwargs):
        """Run every head on ``inputs`` and merge their output dicts.

        When the merged dict has exactly one entry, that entry's value is
        returned instead of the dict.
        """
        # TODO: Optimize this
        outputs = {}
        for head in self.heads:
            outputs.update(head(inputs, call_body=True, always_output_dict=True))

        if len(outputs) == 1:
            return next(iter(outputs.values()))

        return outputs

    def compute_loss(  # type: ignore
        self, inputs, targets, training: bool = False, compute_metrics=True, **kwargs
    ) -> tf.Tensor:
        """Return the weighted sum of the losses of all heads.

        ``call_body`` may be passed through ``**kwargs`` (default True) and is
        applied to every head; the remaining keyword arguments are forwarded to
        each head's ``compute_loss``.
        """
        # Pop once, outside the loop: popping per head would hand the caller's
        # value to the first head only and silently use True for the others.
        call_body = kwargs.pop("call_body", True)

        # NOTE(review): `training` is not forwarded to the heads -- confirm intended.
        losses = tuple(
            head.compute_loss(
                inputs,
                targets,
                call_body=call_body,
                compute_metrics=compute_metrics,
                **kwargs,
            )
            for head in self.heads
        )

        with ops.name_scope("merge_losses", values=losses + self.head_weights):
            weighted_losses = [
                tf.math.multiply(loss, head_weight)
                for loss, head_weight in zip(losses, self.head_weights)
            ]
            return tf.add_n(weighted_losses)

    def metric_results(self, mode=None):
        """Return the metric results of each head.

        A list with one entry per head, or that single entry when there is
        exactly one head.
        """
        outputs = [head.metric_results(mode=mode) for head in self.heads]

        if len(outputs) == 1:
            outputs = outputs[0]

        return outputs

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild a model from the dict produced by ``get_config``."""
        # Work on a shallow copy so the caller's config dict is left untouched.
        config = dict(config)
        heads = [
            tf.keras.utils.deserialize_keras_object(h, custom_objects=custom_objects)
            for h in config.pop("heads")
        ]

        return cls(*heads, **config)

    def get_config(self):
        """Return the head weights and the serialized heads."""
        return {
            "head_weights": self.head_weights,
            "heads": [tf.keras.utils.serialize_keras_object(h) for h in self.heads],
        }
def _output_metrics(metrics): if len(metrics) == 1: return metrics[list(metrics.keys())[0]] return metrics