# type: ignore
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections
import inspect
import random
import re
from collections.abc import Sized
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from torch.cuda.amp import autocast
from torch.optim import Optimizer
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from transformers import Trainer as BaseTrainer
from transformers.optimization import TYPE_TO_SCHEDULER_FUNCTION
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_pt_utils import (
find_batch_size,
nested_concat,
nested_numpify,
nested_truncate,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalLoopOutput, SchedulerType
from transformers.utils import logging
from merlin_standard_lib import Schema
from ..config.trainer import T4RecTrainingArguments
from .model.base import Model
from .utils.data_utils import T4RecDataLoader
# Module-level logger, obtained through the `transformers.utils.logging` helpers
# imported above.
logger = logging.get_logger(__name__)
class Trainer(BaseTrainer):
    """
    An :class:`~transformers.Trainer` specialized for sequential recommendation,
    including session-based and sequential recommendation.

    Parameters
    ----------
    model: Model
        The Model defined using the Transformers4Rec API.
    args: T4RecTrainingArguments
        The training arguments needed to set up training and evaluation
        experiments.
    schema: Optional[Schema], optional
        The schema object including features to use and their properties.
        by default None
    train_dataset_or_path: Optional[Union[str, Dataset]], optional
        Path of parquet files or Dataset to use for training.
        by default None
    eval_dataset_or_path: Optional[Union[str, Dataset]], optional
        Path of parquet files or Dataset to use for evaluation.
        by default None
    train_dataloader: Optional[DataLoader], optional
        The data generator to use for training.
        by default None
    eval_dataloader: Optional[DataLoader], optional
        The data generator to use for evaluation.
        by default None
    callbacks: Optional[List[TrainerCallback]], optional
        Additional callbacks forwarded to the HF Trainer. The given list is
        copied, never mutated.
        by default None
    compute_metrics: Optional[bool], optional
        Whether to compute metrics defined by Model class or not.
        by default None
    incremental_logging: bool
        Whether to enable incremental logging or not. If True, it ensures that
        global steps are incremented over many `trainer.train()` calls, so that
        train and eval metrics steps do not overlap and can be seen properly
        in reports like W&B and Tensorboard
    """

    def __init__(
        self,
        model: Model,
        args: T4RecTrainingArguments,
        schema: Schema = None,
        train_dataset_or_path=None,
        eval_dataset_or_path=None,
        train_dataloader: Optional[DataLoader] = None,
        eval_dataloader: Optional[DataLoader] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        compute_metrics=None,
        incremental_logging: bool = False,
        **kwargs,
    ):
        # The HF Trainer expects sized datasets; the real data is served by the
        # dataloaders built in get_train_dataloader() / get_eval_dataloader().
        mock_dataset = DatasetMock()
        hf_model = HFWrapper(model)

        # Work on a private copy: appending to a shared default (or to the
        # caller's list) would leak callbacks across Trainer instances.
        callbacks = list(callbacks) if callbacks is not None else []

        self.incremental_logging = incremental_logging
        # Global steps accumulated by previous train() calls (incremental logging).
        self.past_global_steps = 0
        if self.incremental_logging:
            callbacks.append(IncrementalLoggingCallback(self))

        super(Trainer, self).__init__(
            model=hf_model,
            args=args,
            train_dataset=mock_dataset,
            eval_dataset=mock_dataset,
            callbacks=callbacks,
            **kwargs,
        )

        self.compute_metrics = compute_metrics
        self.train_dataset_or_path = train_dataset_or_path
        self.eval_dataset_or_path = eval_dataset_or_path
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.schema = schema
        # Optional user callback for prediction logging (see log_predictions_callback).
        self.__log_predictions_callback = None

    def get_train_dataloader(self):
        """
        Set the train dataloader to use by Trainer.

        It supports a user-defined data loader set as an attribute in the constructor.
        When that attribute is None, the data loader is built from
        ``train_dataset_or_path`` and the ``data_loader_engine`` specified in the
        training arguments.

        Raises
        ------
        ValueError
            If neither a train dataloader nor ``train_dataset_or_path`` is available.
        """
        if self.train_dataloader is not None:
            return self.train_dataloader

        if self.train_dataset_or_path is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        assert self.schema is not None, "schema is required to generate Train Dataloader"
        return T4RecDataLoader.parse(self.args.data_loader_engine).from_schema(
            self.schema,
            self.train_dataset_or_path,
            self.args.per_device_train_batch_size,
            max_sequence_length=self.args.max_sequence_length,
            drop_last=self.args.dataloader_drop_last,
            shuffle=True,
            shuffle_buffer_size=self.args.shuffle_buffer_size,
        )

    def get_eval_dataloader(self, eval_dataset=None):
        """
        Set the eval dataloader to use by Trainer.

        It supports a user-defined data loader set as an attribute in the constructor.
        When that attribute is None, the data loader is built from ``eval_dataset``
        (or ``eval_dataset_or_path`` when it is not given) and the
        ``data_loader_engine`` specified in the training arguments.

        Parameters
        ----------
        eval_dataset: Optional[Union[str, Dataset]]
            Path of parquet files or Dataset to evaluate on. When None, or when the
            internal ``DatasetMock`` placeholder is passed back, ``eval_dataset_or_path``
            is used instead. by default None

        Raises
        ------
        ValueError
            If no evaluation data is available.
        """
        if self.eval_dataloader is not None:
            return self.eval_dataloader

        # The placeholder handed to the HF Trainer carries no data, so treat it
        # as "not provided".
        if eval_dataset is None or isinstance(eval_dataset, DatasetMock):
            eval_dataset = self.eval_dataset_or_path
        if eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        assert self.schema is not None, "schema is required to generate Eval Dataloader"
        return T4RecDataLoader.parse(self.args.data_loader_engine).from_schema(
            self.schema,
            eval_dataset,
            self.args.per_device_eval_batch_size,
            max_sequence_length=self.args.max_sequence_length,
            drop_last=self.args.dataloader_drop_last,
            shuffle=False,
            shuffle_buffer_size=self.args.shuffle_buffer_size,
        )

    def num_examples(self, dataloader: DataLoader):
        """
        Overriding :obj:`Trainer.num_examples()` method because
        the data loaders for this project do not return the dataset size,
        but the number of steps. So the dataset size is estimated here
        as the number of steps * batch size.
        """
        return len(dataloader) * self._get_dataloader_batch_size(dataloader)

    @staticmethod
    def _get_dataloader_batch_size(dataloader) -> int:
        """
        Return the batch size of ``dataloader``.

        Transformers4Rec loaders expose it as ``_batch_size``; plain
        :class:`torch.utils.data.DataLoader` objects expose ``batch_size``.
        """
        batch_size = getattr(dataloader, "_batch_size", None)
        if batch_size is None:
            batch_size = dataloader.batch_size
        return batch_size

    def reset_lr_scheduler(self) -> None:
        """
        Resets the LR scheduler of the previous :obj:`Trainer.train()` call,
        so that a new LR scheduler is created by the next :obj:`Trainer.train()` call.
        This is important for LR schedules like `get_linear_schedule_with_warmup()`
        which decay the LR to 0 at the end of training.
        """
        self.lr_scheduler = None

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Create the LR scheduler (only if none exists yet) and return it.

        Unlike the HF default, ``num_cycles`` is forwarded to schedulers that
        accept it, computed as
        ``learning_rate_num_cosine_cycles_by_epoch * num_train_epochs``.

        Parameters
        ----------
        num_training_steps: int
            Total number of training steps.
        optimizer: Optional[torch.optim.Optimizer]
            Optimizer to schedule; defaults to ``self.optimizer``.
        """
        if self.lr_scheduler is None:
            self.lr_scheduler = self.get_scheduler(
                self.args.lr_scheduler_type,
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.warmup_steps,
                num_training_steps=num_training_steps,
                num_cycles=self.args.learning_rate_num_cosine_cycles_by_epoch
                * self.args.num_train_epochs,
            )
        return self.lr_scheduler

    # Overrides get_scheduler to accept the num_cycles param while keeping the
    # unified HF API over the many available schedulers.
    @staticmethod
    def get_scheduler(
        name: Union[str, SchedulerType],
        optimizer: Optimizer,
        num_warmup_steps: Optional[int] = None,
        num_training_steps: Optional[int] = None,
        num_cycles: Optional[float] = 0.5,
    ):
        """
        Unified API to get any scheduler from its name.

        Parameters
        ----------
        name: (:obj:`str` or `:obj:`SchedulerType`)
            The name of the scheduler to use.
        optimizer: (:obj:`torch.optim.Optimizer`)
            The optimizer that will be used during training.
        num_warmup_steps: (:obj:`int`, `optional`)
            The number of warmup steps to do. This is not required by all schedulers
            (hence the argument being optional),
            the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps: (:obj:`int`, `optional`)
            The number of training steps to do. This is not required by all schedulers
            (hence the argument being optional),
            the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles: (:obj:`float`, `optional`)
            The number of waves in the cosine schedule /
            hard restarts to use for cosine scheduler. Only forwarded to
            schedulers whose signature accepts ``num_cycles``.

        Raises
        ------
        ValueError
            If a required step count is missing for the chosen scheduler.
        """
        name = SchedulerType(name)
        schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
        if name == SchedulerType.CONSTANT:
            return schedule_func(optimizer)

        # All other schedulers require `num_warmup_steps`
        if num_warmup_steps is None:
            raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

        if name == SchedulerType.CONSTANT_WITH_WARMUP:
            return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

        # All other schedulers require `num_training_steps`
        if num_training_steps is None:
            raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

        if "num_cycles" in inspect.signature(schedule_func).parameters:
            return schedule_func(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                num_cycles=num_cycles,
            )

        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
        )

    def prediction_step(
        self,
        model: torch.nn.Module,
        inputs: Dict[str, torch.Tensor],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[Dict[str, Any]],
    ]:
        """
        Overriding :obj:`Trainer.prediction_step()`
        to provide more flexibility to unpack results from the model,
        like returning labels that are not exactly one input feature.

        Returns
        -------
        Tuple
            ``(loss, predictions, labels, other_outputs)``, where the last three are
            None when ``prediction_loss_only`` is True. ``other_outputs`` is
            currently always None. ``ignore_keys`` is accepted for API
            compatibility but not used.
        """
        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            if self.use_amp:
                with autocast():
                    outputs = model(inputs, training=False)
            else:
                outputs = model(inputs, training=False)

        loss = outputs["loss"].mean().detach()
        if prediction_loss_only:
            return (loss, None, None, None)

        predictions = outputs["predictions"].detach()
        labels = outputs["labels"].detach()

        # TODO: define a metadata dict in the model for logging extra outputs
        other_outputs = None

        return (loss, predictions, labels, other_outputs)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: Optional[str] = "eval",
    ) -> EvalLoopOutput:
        """
        Overriding :obj:`Trainer.evaluation_loop()`
        (shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`)
        to provide more flexibility to work with streaming metrics
        (computed at each eval batch) and
        to log the outputs of the model
        (e.g. prediction scores, prediction metadata, attention weights).

        Parameters
        ----------
        dataloader: DataLoader
            DataLoader object to use to iterate over evaluation data
        description: str
            Parameter to describe the evaluation experiment.
            e.g: `Prediction`, `test`
        prediction_loss_only: Optional[bool]
            Whether or not to return the loss only.
            by default None (falls back to ``args.prediction_loss_only``)
        ignore_keys: Optional[List[str]]
            Forwarded to :obj:`prediction_step`.
            by default None
        metric_key_prefix: Optional[str]
            Prefix to use when logging evaluation metrics.
            by default `eval`

        Returns
        -------
        EvalLoopOutput
            ``predictions`` is a tuple (top-k item ids, top-k scores) when
            ``args.predict_top_k > 0``, else None; ``metrics`` holds the streaming
            metrics plus ``<prefix>/loss`` when a loss was computed.

        Raises
        ------
        ValueError
            If the dataloader's dataset does not implement ``__len__``.
        """
        prediction_loss_only = (
            prediction_loss_only
            if prediction_loss_only is not None
            else self.args.prediction_loss_only
        )

        # Call the Transformers4Rec model directly (unwrap HFWrapper)
        model = self.model.module

        # Reset the streaming metrics for the dataset (Train, Valid or Test)
        if self.compute_metrics:
            model.reset_metrics()

        if not isinstance(dataloader.dataset, Sized):
            raise ValueError("dataset must implement __len__")
        batch_size = self._get_dataloader_batch_size(dataloader)

        logger.info("***** Running %s *****", description)
        logger.info(" Batch size = %d", batch_size)

        # Max number of batches evaluated on the train set (usually larger);
        # 0 means no limit.
        max_steps = 0
        if metric_key_prefix == "train" and self.args.eval_steps_on_train_set:
            max_steps = max(self.args.eval_steps_on_train_set, 0)

        if max_steps > 0:
            num_examples = max_steps * batch_size
        else:
            num_examples = self.num_examples(dataloader)
        logger.info(" Num sessions (examples) = %d", num_examples)

        model.eval()
        self.callback_handler.eval_dataloader = dataloader

        # Containers on device, flushed to CPU every `eval_accumulation_steps`
        losses_host = None
        preds_item_ids_scores_host = None
        labels_host = None
        # Final containers on CPU (NumPy)
        all_losses = None
        all_preds_item_ids_scores = None
        all_labels = None
        # The loaders only report steps, so the number of examples is counted here
        observed_num_examples = 0

        for step, inputs in enumerate(dataloader):
            # Stop before consuming the batch, so it is not counted as observed
            if max_steps > 0 and step >= max_steps:
                break

            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size

            loss, preds, labels, _ = self.prediction_step(
                model, inputs, prediction_loss_only, ignore_keys=ignore_keys
            )

            # Update the streaming metrics every `compute_metrics_each_n_steps` batches
            metrics_results_detailed = None
            if self.compute_metrics and step % self.args.compute_metrics_each_n_steps == 0:
                metrics_results_detailed = model.calculate_metrics(
                    preds, labels, mode=metric_key_prefix, forward=False, call_body=False
                )

            # Update containers on host
            if loss is not None:
                # One entry per example so that the final mean is per example
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = (
                    losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
                )
            if labels is not None:
                labels = self._pad_across_processes(labels)
                labels = self._nested_gather(labels)
                labels_host = (
                    labels
                    if labels_host is None
                    else nested_concat(labels_host, labels, padding_index=0)
                )
            if preds is not None and self.args.predict_top_k > 0:
                preds_sorted_item_scores, preds_sorted_item_ids = torch.topk(
                    preds, k=self.args.predict_top_k, dim=-1
                )
                self._maybe_log_predictions(
                    labels,
                    preds_sorted_item_ids,
                    preds_sorted_item_scores,
                    metrics_results_detailed,
                    metric_key_prefix,
                )
                # The output predictions are a tuple with the ranked top-k item ids
                # and their recommendation scores
                preds_item_ids_scores = (preds_sorted_item_ids, preds_sorted_item_scores)
                preds_item_ids_scores_host = (
                    preds_item_ids_scores
                    if preds_item_ids_scores_host is None
                    else nested_concat(preds_item_ids_scores_host, preds_item_ids_scores)
                )

            self.control = self.callback_handler.on_prediction_step(
                self.args, self.state, self.control
            )

            # Move the accumulated tensors to the CPU after enough accumulation steps
            if (
                self.args.eval_accumulation_steps is not None
                and (step + 1) % self.args.eval_accumulation_steps == 0
            ):
                (
                    all_losses,
                    all_labels,
                    all_preds_item_ids_scores,
                ) = self._move_eval_containers_to_cpu(
                    losses_host,
                    labels_host,
                    preds_item_ids_scores_host,
                    all_losses,
                    all_labels,
                    all_preds_item_ids_scores,
                )
                # Set back to None to begin a new accumulation
                losses_host, preds_item_ids_scores_host, labels_host = None, None, None

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        all_losses, all_labels, all_preds_item_ids_scores = self._move_eval_containers_to_cpu(
            losses_host,
            labels_host,
            preds_item_ids_scores_host,
            all_losses,
            all_labels,
            all_preds_item_ids_scores,
        )

        # Losses were repeated `batch_size` times per step and, in distributed
        # runs, samplers may pad to a multiple of batch_size, so truncate to the
        # observed number of examples. A count of 0 means batch sizes could not
        # be inferred: keep everything rather than truncating to empty.
        num_samples = observed_num_examples
        if num_samples > 0:
            if all_losses is not None:
                all_losses = all_losses[:num_samples]
            if all_preds_item_ids_scores is not None:
                all_preds_item_ids_scores = nested_truncate(
                    all_preds_item_ids_scores, num_samples
                )
            if all_labels is not None:
                all_labels = nested_truncate(all_labels, num_samples)

        metrics = {}
        # Streaming metrics results, averaged over the evaluated steps by the model
        if self.compute_metrics:
            streaming_metrics_results = model.compute_metrics(mode=metric_key_prefix)
            metrics.update(
                process_metrics(streaming_metrics_results, prefix=metric_key_prefix + "/")
            )
        # No loss is available when the dataloader yielded no batch
        if all_losses is not None:
            metrics[f"{metric_key_prefix}/loss"] = all_losses.mean().item()

        return EvalLoopOutput(
            predictions=all_preds_item_ids_scores,
            label_ids=all_labels,
            metrics=metrics,
            num_samples=num_examples,
        )

    @staticmethod
    def _move_eval_containers_to_cpu(
        losses_host, labels_host, preds_host, all_losses, all_labels, all_preds
    ):
        """
        Convert the device containers to NumPy and append them to the CPU containers.

        Returns
        -------
        Tuple
            The updated ``(all_losses, all_labels, all_preds)``.
        """
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = (
                losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = (
                labels if all_labels is None else nested_concat(all_labels, labels, padding_index=0)
            )
        if preds_host is not None:
            preds = nested_numpify(preds_host)
            all_preds = preds if all_preds is None else nested_concat(all_preds, preds)
        return all_losses, all_labels, all_preds

    def _save_model_and_checkpoint(self, save_model_class=False):
        """
        Save the serialized model + trainer and random states.

        Parameters
        ----------
        save_model_class: Optional[bool]
            Whether to also pickle the Model class (requires ``cloudpickle``).
            by default False

        Raises
        ------
        ValueError
            If ``save_model_class`` is True and ``cloudpickle`` is not installed.
        """
        import os

        try:
            import cloudpickle
        except ImportError:
            cloudpickle = None

        logger.info("Saving model...")
        output_dir = os.path.join(
            self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        )

        # Save model parameters
        self._save_checkpoint(self.model, trial=None, metrics=None)

        # Save the serialized model
        if save_model_class:
            # TODO : fix serialization of DatasetSchema object
            if cloudpickle is None:
                raise ValueError("cloudpickle is required to save model class")

            with open(os.path.join(output_dir, "model_class.pkl"), "wb") as out:
                cloudpickle.dump(self.model.module, out)

    def load_model_trainer_states_from_checkpoint(self, checkpoint_path, model=None):
        """
        This method loads the checkpoint states of the model, trainer and random states.

        If model is None, the serialized model class is loaded from the checkpoint.
        It does not load the optimizer and LR scheduler states (for that call
        trainer.train() with the resume_from_checkpoint argument for a complete load).

        Parameters
        ----------
        checkpoint_path: str
            Path to the checkpoint directory.
        model: Optional[Model]
            Model class used by Trainer. by default None

        Raises
        ------
        ImportError
            If ``model`` is None and ``cloudpickle`` is not installed.
        """
        import os

        if model is None:
            try:
                import cloudpickle
            except ImportError as err:
                raise ImportError("cloudpickle is required to load model class") from err
            logger.info("Loading model class")
            # NOTE: unpickling runs arbitrary code; only load trusted checkpoints.
            with open(os.path.join(checkpoint_path, "model_class.pkl"), "rb") as model_file:
                model = cloudpickle.load(model_file)

        self.model = HFWrapper(model)
        logger.info("Loading weights of previously trained model")
        # Restoring model weights; map to CPU first so GPU-saved checkpoints load
        # anywhere (load_state_dict copies into the parameters' own devices).
        self.model.load_state_dict(
            torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"), map_location="cpu")
        )

        # Restoring random state
        rng_file = os.path.join(checkpoint_path, "rng_state.pth")
        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        # The CUDA RNG state can only be restored when CUDA is available and saved
        if torch.cuda.is_available() and "cuda" in checkpoint_rng_state:
            torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])

        # Restoring AMP scaler
        if self.use_amp:
            self.scaler.load_state_dict(torch.load(os.path.join(checkpoint_path, "scaler.pt")))

    @property
    def log_predictions_callback(self) -> Callable:
        """Callback used by ``--log_predictions``; None when not set."""
        return self.__log_predictions_callback

    @log_predictions_callback.setter
    def log_predictions_callback(self, var: Callable):
        self.__log_predictions_callback = var

    def _maybe_log_predictions(
        self,
        labels: torch.Tensor,
        pred_item_ids: torch.Tensor,
        pred_item_scores: torch.Tensor,
        metrics: Dict[str, np.ndarray],
        metric_key_prefix: str,
    ):
        """
        If --log_predictions is enabled, calls a callback function to
        log predicted item ids, scores, metadata and metrics.

        Parameters
        ----------
        labels: torch.Tensor
            True labels.
        pred_item_ids: torch.Tensor
            The predicted items ids. If top_k is set,
            the top-k items are returned for each next-item prediction.
        pred_item_scores: torch.Tensor
            The prediction scores. If top_k is set,
            the top-k scores are returned for each next-item prediction.
        metrics: Dict[str, np.ndarray]
            Dictionary of metrics computed by Model.
        metric_key_prefix: str
            Prefix to use when logging evaluation metrics.
            by default `eval`
        """
        # TODO Add pred_metadata: Dict[str, torch.Tensor],
        if self.args.log_predictions and self.log_predictions_callback is not None:
            # Converting torch Tensors to NumPy and calling the logging function
            self.log_predictions_callback(
                labels=labels.cpu().numpy(),
                pred_item_ids=pred_item_ids.cpu().numpy(),
                # Cast because scores are float16 when --fp16
                pred_item_scores=pred_item_scores.cpu().numpy().astype(np.float32),
                metrics=metrics,
                dataset_type=metric_key_prefix,
            )

    def _increment_past_global_steps(self, current_global_step: int):
        """Add the global steps of a finished train() call to the logging offset."""
        self.past_global_steps += current_global_step

    def _get_general_global_step(self) -> int:
        """Global step used for incremental logging (past steps + current ones while training)."""
        general_global_step = self.past_global_steps
        if self.model.training:
            general_global_step += self.state.global_step
        return general_global_step

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log metrics after normalizing their names for the HF integration loggers.

        Eval metrics are renamed from ``eval/`` to ``eval_`` and ``train/`` is
        stripped, so that the loggers do not prefix names with 'train/' (which
        they add whenever a metric is not an eval one). With incremental
        logging, the logged step is offset by the steps of previous train() calls.
        """
        logs = {re.sub("^eval/", "eval_", k).replace("train/", ""): v for k, v in logs.items()}

        if not self.incremental_logging:
            super().log(logs)
            return

        # Incremental logging: global steps keep increasing across train() calls,
        # so that metrics do not overlap on W&B and Tensorboard
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)

        # state.global_step also drives the LR schedules, so only a copy is changed
        state_copy = deepcopy(self.state)
        state_copy.global_step = self._get_general_global_step()

        output = {**logs, **{"step": state_copy.global_step}}
        self.state.log_history.append(output)
        self.control = self.callback_handler.on_log(self.args, state_copy, self.control, logs)
def process_metrics(metrics, prefix="", to_cpu=True):
    """
    Flatten a (possibly nested) dict of metrics into a single-level dict.

    Nested dicts are flattened recursively. Only leaf keys are kept: parent keys
    are not part of the resulting names, which are ``f"{prefix}{leaf_key}"``.

    Parameters
    ----------
    metrics: dict
        Mapping of metric names to values or to nested dicts of metrics.
    prefix: str
        String prepended to every leaf key. by default ""
    to_cpu: bool
        If True, torch tensors and NumPy values are converted to Python scalars
        (single-element values only); other values are kept as-is.
        by default True

    Returns
    -------
    dict
        The flattened metrics.
    """
    metrics_proc = {}
    for root_key, root_value in metrics.items():
        if isinstance(root_value, dict):
            metrics_proc.update(process_metrics(root_value, prefix=prefix, to_cpu=to_cpu))
            continue
        value = root_value
        if to_cpu:
            if isinstance(value, torch.Tensor):
                # detach() so tensors that still require grad can be converted
                value = value.detach().cpu().numpy().item()
            elif isinstance(value, (np.ndarray, np.generic)):
                value = value.item()
        metrics_proc[f"{prefix}{root_key}"] = value
    return metrics_proc
class IncrementalLoggingCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that updates the Trainer's state
    on specific hooks, for the purpose of incremental logging.

    Parameters
    ----------
    trainer: Trainer
        The trainer whose past-global-steps counter is advanced.
    """

    def __init__(self, trainer: Trainer):
        self.trainer = trainer

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """No-op hook."""

    def on_train_end(self, args, state, control, model=None, **kwargs):
        """Add the global steps of the finished ``train()`` call to the trainer's offset."""
        finished_steps = state.global_step
        self.trainer._increment_past_global_steps(finished_steps)

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        """No-op hook (evaluation on the eval set at epoch end is currently disabled)."""
class DatasetMock(Dataset, Sized):
    """
    Placeholder dataset telling the HF Trainer that the dataset is sized.

    The actual data is obtained through the generated/provided data loader;
    this object only reports a length.

    Parameters
    ----------
    nsteps: int
        Value returned by ``len()``. by default 1
    """

    def __init__(self, nsteps=1):
        self.nsteps = nsteps

    def __len__(self):
        length = self.nsteps
        return length
class HFWrapper(torch.nn.Module):
    """
    Adapt a Transformers4Rec model to the forward signature used by the HF Trainer.

    Keyword arguments given to ``forward`` are packed into a single dict, which
    is passed as the first positional argument of the wrapped model, followed
    by any positional arguments.

    Parameters
    ----------
    model:
        The wrapped model, stored as ``self.module``.
    """

    def __init__(self, model):
        super(HFWrapper, self).__init__()
        self.module = model

    def forward(self, *args, **kwargs):
        features = kwargs
        extra_positional = args
        return self.module(features, *extra_positional)