# Source code for transformers4rec.torch.model.base

#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import inspect
from collections import defaultdict
from types import SimpleNamespace
from typing import Callable, Dict, Iterable, List, Optional, Type, Union, cast

import numpy as np
import torch
import torchmetrics as tm
from tqdm import tqdm
from transformers.modeling_utils import SequenceSummary

from merlin_standard_lib import Schema, Tag
from merlin_standard_lib.registry import camelcase_to_snakecase

from ..block.base import BlockBase, BlockOrModule, BlockType
from ..features.base import InputBlock
from ..features.sequence import TabularFeaturesType
from ..typing import TabularData, TensorOrTabularData
from ..utils.torch_utils import LossMixin, MetricsMixin


def name_fn(name, inp):
    """Build a ``"<name>/<inp>"`` identifier.

    Returns ``None`` when ``name`` is falsy (e.g. ``None`` or ``""``).
    """
    if not name:
        return None
    return "/".join((name, inp))


class PredictionTask(torch.nn.Module, LossMixin, MetricsMixin):
    """Individual prediction-task of a model.

    Parameters
    ----------
    loss: torch.nn.Module
        The loss to use during training of this task.
    metrics: Iterable[tm.Metric], optional
        The metrics to calculate during training & evaluation.
    target_name: str, optional
        Name of the target, this is needed when there are multiple targets.
    task_name: str, optional
        Name of the prediction task, if not provided a name will be automatically constructed based
        on the target-name & class-name.
    forward_to_prediction_fn: Callable[[torch.Tensor], torch.Tensor]
        Function applied to the task output before it is passed to the metrics.
    task_block: BlockType
        Module to transform input tensor before computing predictions.
    pre: BlockType
        Module to compute the predictions probabilities.
    summary_type: str
        This is used to summarize a sequence into a single tensor. Accepted values are:
            - `"last"` -- Take the last token hidden state (like XLNet)
            - `"first"` -- Take the first token hidden state (like Bert)
            - `"mean"` -- Take the mean of all tokens hidden states
            - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
            - `"attn"` -- Not implemented now, use multi-head attention
    """

    def __init__(
        self,
        loss: torch.nn.Module,
        metrics: Optional[Iterable[tm.Metric]] = None,
        target_name: Optional[str] = None,
        task_name: Optional[str] = None,
        forward_to_prediction_fn: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        task_block: Optional[BlockType] = None,
        pre: Optional[BlockType] = None,
        summary_type: str = "last",
    ):
        super().__init__()
        # SequenceSummary only reads attributes from its config, so a namespace suffices.
        self.sequence_summary = SequenceSummary(
            SimpleNamespace(summary_type=summary_type)  # type: ignore
        )  # noqa
        self.target_name = target_name
        self.forward_to_prediction_fn = forward_to_prediction_fn
        self.set_metrics(metrics)
        self.loss = loss
        self.pre = pre
        self.task_block = task_block
        self._task_name = task_name

    def build(
        self,
        body: BlockType,
        input_size,
        inputs: Optional[InputBlock] = None,
        device=None,
        task_block: Optional[BlockType] = None,
        pre=None,
    ):
        """Build the task's sub-modules.

        The method will be called when block is converted to a model, i.e when linked to
        prediction head.

        Parameters
        ----------
        body: BlockType
            The model block linked to this task (not used by this base implementation).
        input_size:
            Input size used to build ``task_block`` (and ``pre`` when there is no task block).
        inputs: InputBlock, optional
            Not used by this base implementation.
        device: optional
            When given, the task and its metrics are moved to this device.
        task_block: BlockType, optional
            When given, replaces the task block passed to the constructor.
        pre: BlockType, optional
            When given, replaces the ``pre`` block passed to the constructor.
        """
        if task_block:
            # TODO: What to do when `self.task_block is not None`?
            self.task_block = task_block
        if pre:
            # TODO: What to do when `self.pre is not None`?
            self.pre = pre

        # Build task block
        pre_input_size = input_size
        if self.task_block:
            if isinstance(self.task_block, torch.nn.Module):
                # Copy so that a module shared between tasks is not shared after building.
                self.task_block = copy.deepcopy(self.task_block)
            else:
                self.task_block = self.task_block.build(input_size)
                pre_input_size = self.task_block.output_size()  # type: ignore

        if self.pre:
            if isinstance(self.pre, torch.nn.Module):
                self.pre = copy.deepcopy(self.pre)
            else:
                self.pre = self.pre.build(pre_input_size)

        if device:
            self.to(device)
            for metric in self.metrics:
                metric.to(device)
        self.built = True

    def forward(self, inputs, **kwargs):
        """Apply sequence summary (for rank-3 inputs), then ``task_block`` and ``pre``."""
        x = inputs

        # Rank-3 inputs are reduced with `sequence_summary`
        # (presumably batch x seq x dim -- TODO confirm).
        if len(x.size()) == 3:
            x = self.sequence_summary(x)

        if self.task_block:
            x = self.task_block(x)

        if self.pre:
            x = self.pre(x)

        return x

    @property
    def task_name(self):
        """Explicit ``task_name`` if given, else ``<target_name>/<snake_case class name>``."""
        if self._task_name:
            return self._task_name

        base_name = camelcase_to_snakecase(self.__class__.__name__)

        return name_fn(self.target_name, base_name) if self.target_name else base_name

    def child_name(self, name):
        """Prefix ``name`` with this task's name."""
        return name_fn(self.task_name, name)

    def set_metrics(self, metrics):
        """Register ``metrics`` as sub-modules (``None`` gives an empty list)."""
        self.metrics = torch.nn.ModuleList(metrics)

    def compute_loss(
        self,
        inputs: Union[torch.Tensor, TabularData],
        targets: Union[torch.Tensor, TabularData],
        compute_metrics: bool = True,
        training: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        """Run the task on ``inputs`` and return ``self.loss(predictions, targets)``.

        When ``targets`` is a dict and ``target_name`` is set, the matching entry is used.
        With ``compute_metrics`` the metrics are also called on this batch, reusing the
        predictions (``forward=False``).
        """
        if isinstance(targets, dict) and self.target_name:
            targets = targets[self.target_name]

        predictions = self(inputs, training=training)
        loss = self.loss(predictions, targets)

        if compute_metrics:
            self.calculate_metrics(predictions, targets, mode="train", forward=False)

        return loss

    def calculate_metrics(  # type: ignore
        self,
        predictions: Union[torch.Tensor, TabularData],
        targets: Union[torch.Tensor, TabularData],
        mode: str = "val",
        forward: bool = True,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Call every metric on this batch and return ``{metric_name: value}``.

        Parameters
        ----------
        predictions:
            Task output, or task inputs when ``forward`` is True (the task is run first).
        targets:
            Targets; when a dict and ``target_name`` is set, the matching entry is used.
        mode: str
            Not used by this implementation.
        forward: bool
            Whether to run the task on ``predictions`` before computing the metrics.
        """
        if isinstance(targets, dict) and self.target_name:
            targets = targets[self.target_name]

        outputs = {}
        if forward:
            predictions = self(predictions)
        predictions = self.forward_to_prediction_fn(cast(torch.Tensor, predictions))

        # Local import, presumably to avoid a circular import with `prediction_task`.
        from .prediction_task import BinaryClassificationTask

        binary_metric_types = tuple(type(x) for x in BinaryClassificationTask.DEFAULT_METRICS)
        for metric in self.metrics:
            # Convert per metric: the int cast must not leak into the targets seen by
            # the metrics that follow a binary-classification metric.
            metric_targets = targets
            if isinstance(metric, binary_metric_types):
                metric_targets = cast(torch.Tensor, targets).int()
            outputs[self.metric_name(metric)] = metric(predictions, metric_targets)

        return outputs

    def compute_metrics(self, **kwargs):
        """Return ``{metric_name: metric.compute()}`` for every metric."""
        return {self.metric_name(metric): metric.compute() for metric in self.metrics}

    def metric_name(self, metric: tm.Metric) -> str:
        """Name of ``metric``: the snake-cased metric class name prefixed by the task name."""
        return self.child_name(camelcase_to_snakecase(metric.__class__.__name__))

    def reset_metrics(self):
        """Call ``reset()`` on every metric."""
        for metric in self.metrics:
            metric.reset()

    def to_head(self, body, inputs=None, **kwargs) -> "Head":
        """Wrap this task and ``body`` into a ``Head``."""
        return Head(body, self, inputs=inputs, **kwargs)

    def to_model(self, body, inputs=None, **kwargs) -> "Model":
        """Wrap this task and ``body`` into a single-head ``Model``.

        NOTE(review): ``kwargs`` are forwarded to both ``Head`` and ``Model``.
        """
        return Model(Head(body, self, inputs=inputs, **kwargs), **kwargs)
class Model(torch.nn.Module, LossMixin, MetricsMixin):
    """Model class that can aggregate one or multiple heads.

    Parameters
    ----------
    head: Head
        One or more heads of the model.
    head_weights: List[float], optional
        Weight-value to use for each head.
    head_reduction: str, optional
        How to reduce the losses into a single tensor when multiple heads are used.
    optimizer: Type[torch.optim.Optimizer]
        Optimizer-class to use during fitting
    name: str, optional
        Name of the model.
    """

    def __init__(
        self,
        *head: Head,
        head_weights: Optional[List[float]] = None,
        head_reduction: str = "mean",
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
        name=None,
    ):
        """Validate ``head_weights`` and register the heads as sub-modules.

        Raises
        ------
        ValueError
            If ``head_weights`` is given but is not a list, or its length differs from the
            number of heads.
        """
        if head_weights:
            if not isinstance(head_weights, list):
                raise ValueError("`head_weights` must be a list")
            if not len(head_weights) == len(head):
                raise ValueError(
                    "`head_weights` needs to have the same length " "as the number of heads"
                )

        super().__init__()

        self.name = name
        self.heads = torch.nn.ModuleList(head)
        # Without explicit weights, every head contributes equally to the loss.
        self.head_weights = head_weights or [1.0] * len(head)
        self.head_reduction = head_reduction
        self.optimizer = optimizer

    def forward(self, inputs: TensorOrTabularData, training=True, **kwargs):
        """Run all heads on ``inputs``; a single output is returned unwrapped."""
        # TODO: Optimize this
        outputs = {}
        for head in self.heads:
            outputs.update(
                head(inputs, call_body=True, training=training, always_output_dict=True)
            )

        if len(outputs) == 1:
            return outputs[list(outputs.keys())[0]]

        return outputs

    def compute_loss(self, inputs, targets, compute_metrics=True, **kwargs) -> torch.Tensor:
        """Weighted per-head losses reduced with the ``head_reduction`` tensor method."""
        losses = []

        for i, head in enumerate(self.heads):
            loss = head.compute_loss(
                inputs, targets, call_body=True, compute_metrics=compute_metrics, **kwargs
            )
            losses.append(loss * self.head_weights[i])

        loss_tensor = torch.stack(losses)

        # e.g. "mean" -> loss_tensor.mean(), "sum" -> loss_tensor.sum()
        return getattr(loss_tensor, self.head_reduction)()

    def calculate_metrics(  # type: ignore
        self, inputs, targets, mode="val", call_body=True, forward=True, **kwargs
    ) -> Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]]:
        """Merge the per-batch metric outputs of every head."""
        outputs = {}
        for head in self.heads:
            outputs.update(
                head.calculate_metrics(
                    inputs, targets, mode=mode, call_body=call_body, forward=forward, **kwargs
                )
            )

        return outputs

    def compute_metrics(self, mode=None) -> Dict[str, Union[float, torch.Tensor]]:
        """Merge the accumulated metric results of every head."""
        metrics = {}
        for head in self.heads:
            metrics.update(head.compute_metrics(mode=mode))

        return metrics

    def reset_metrics(self):
        """Reset the metrics of every head."""
        for head in self.heads:
            head.reset_metrics()

    def to_lightning(self):
        """Wrap this model in a ``pytorch_lightning.LightningModule``.

        The wrapper delegates ``forward`` to this model, logs ``train_loss`` in
        ``training_step`` and instantiates ``self.optimizer`` over this model's
        parameters with ``lr=1e-3``.
        """
        import pytorch_lightning as pl

        parent_self = self

        class BlockWithHeadLightning(pl.LightningModule):
            def __init__(self):
                super(BlockWithHeadLightning, self).__init__()
                self.parent = parent_self

            def forward(self, inputs, *args, **kwargs):
                return self.parent(inputs, *args, **kwargs)

            def training_step(self, batch, batch_idx):
                # `batch` is unpacked as compute_loss arguments, presumably (inputs, targets).
                loss = self.parent.compute_loss(*batch)
                self.log("train_loss", loss)

                return loss

            def configure_optimizers(self):
                optimizer = self.parent.optimizer(self.parent.parameters(), lr=1e-3)

                return optimizer

        return BlockWithHeadLightning()

    def fit(
        self,
        dataloader,
        optimizer=torch.optim.Adam,
        eval_dataloader=None,
        num_epochs=1,
        amp=False,
        train=True,
        verbose=True,
    ):
        """Run the model over ``dataloader`` for ``num_epochs`` epochs.

        Parameters
        ----------
        dataloader:
            A ``torch.utils.data.DataLoader`` (its ``.dataset`` is iterated) or any iterable
            yielding ``(inputs, targets)`` batches.
        optimizer:
            Optimizer class (instantiated over ``self.parameters()``) or optimizer instance.
        eval_dataloader: optional
            Evaluated and printed after each epoch when ``verbose`` is True.
        num_epochs: int
            Number of passes over ``dataloader``.
        amp: bool
            Compute the loss under ``torch.cuda.amp.autocast()``.
        train: bool
            When False, gradients are disabled and no optimizer step is taken.
        verbose: bool
            Show a progress bar and print metrics after each epoch.

        Returns
        -------
        np.ndarray
            Mean loss of each epoch.
        """
        if isinstance(dataloader, torch.utils.data.DataLoader):
            dataset = dataloader.dataset
        else:
            dataset = dataloader

        if inspect.isclass(optimizer):
            optimizer = optimizer(self.parameters())

        self.train(mode=train)
        epoch_losses = []
        with torch.set_grad_enabled(mode=train):
            for _ in range(num_epochs):
                # Start every epoch with fresh metric state. Otherwise the printed training
                # metrics would also cover previous epochs and the batches of `evaluate`.
                self.reset_metrics()
                losses = []
                batch_iterator = enumerate(iter(dataset))
                if verbose:
                    batch_iterator = tqdm(batch_iterator)
                for _batch_idx, (x, y) in batch_iterator:
                    if amp:
                        with torch.cuda.amp.autocast():
                            loss = self.compute_loss(x, y)
                    else:
                        loss = self.compute_loss(x, y)

                    losses.append(float(loss))

                    if train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                if verbose:
                    print(self.compute_metrics(mode="train"))
                    if eval_dataloader:
                        print(self.evaluate(eval_dataloader, verbose=False))

                epoch_losses.append(np.mean(losses))

        return np.array(epoch_losses)

    def evaluate(self, dataloader, verbose=True, mode="eval"):
        """Compute the metrics over all batches of ``dataloader`` and return them.

        ``dataloader`` is handled like in ``fit``. Metrics are reset first, so the result
        covers exactly the batches of ``dataloader``.
        """
        if isinstance(dataloader, torch.utils.data.DataLoader):
            dataset = dataloader.dataset
        else:
            dataset = dataloader

        batch_iterator = enumerate(iter(dataset))
        if verbose:
            batch_iterator = tqdm(batch_iterator)

        self.reset_metrics()
        # No backward pass happens here, so don't build autograd graphs (this also
        # matters when called from `fit`, where gradients are enabled).
        with torch.no_grad():
            for _batch_idx, (x, y) in batch_iterator:
                self.calculate_metrics(x, y, mode=mode)

        return self.compute_metrics(mode=mode)

    def _get_name(self):
        """Override of ``torch.nn.Module._get_name``: use ``name`` when it is set."""
        if self.name:
            return self.name

        return super(Model, self)._get_name()
def _output_metrics(metrics): # If there is only a single head with metrics, returns just those metrics if len(metrics) == 1 and isinstance(metrics[list(metrics.keys())[0]], dict): return metrics[list(metrics.keys())[0]] return metrics