# Source code for transformers4rec.torch.trainer

# type: ignore

# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import collections.abc
import inspect
import random
import re
from collections.abc import Sized
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.cuda.amp import autocast
from torch.optim import Optimizer
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from transformers import Trainer as BaseTrainer
from transformers.optimization import TYPE_TO_SCHEDULER_FUNCTION
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_pt_utils import find_batch_size
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalLoopOutput, SchedulerType
from transformers.utils import logging

from merlin_standard_lib import Schema

from ..config.trainer import T4RecTrainingArguments
from .model.base import Model
from .model.prediction_task import NextItemPredictionTask
from .utils.data_utils import T4RecDataLoader
from .utils.torch_utils import nested_concat, nested_detach, nested_numpify, nested_truncate

logger = logging.get_logger(__name__)

class Trainer(BaseTrainer):
    """
    An :class:`~transformers.Trainer` specialized for sequential recommendation
    (session-based and sequential recommendation).

    Parameters
    ----------
    model: Model
        The Model defined using Transformers4Rec api.
    args: T4RecTrainingArguments
        The training arguments needed to setup training and evaluation experiments.
    schema: Optional[Dataset.schema], optional
        The schema object including features to use and their properties.
        by default None
    train_dataset_or_path: Optional[Union[str, Dataset]], optional
        Path of parquet files or DataSet to use for training.
        by default None
    eval_dataset_or_path: Optional[Union[str, Dataset]], optional
        Path of parquet files or DataSet to use for evaluation.
        by default None
    test_dataset_or_path: Optional[Union[str, Dataset]], optional
        Path of parquet files or DataSet to use for testing.
        by default None
    train_dataloader: Optional[DataLoader], optional
        The data generator to use for training.
        by default None
    eval_dataloader: Optional[DataLoader], optional
        The data generator to use for evaluation.
        by default None
    test_dataloader: Optional[DataLoader], optional
        The data generator to use for testing.
        by default None
    callbacks: Optional[List[TrainerCallback]], optional
        Extra callbacks forwarded to the HF Trainer. The list is copied, so the
        caller's list is never modified. by default None
    compute_metrics: Optional[bool], optional
        Whether to compute metrics defined by Model class or not.
        by default None
    incremental_logging: bool
        Whether to enable incremental logging or not. If True, it ensures that
        global steps are incremented over many `trainer.train()` calls, so that
        train and eval metrics steps do not overlap and can be seen properly
        in reports like W&B and Tensorboard
    """

    # Class-level default (name-mangled like the instance attribute), so the
    # `log_predictions_callback` getter returns None until a callback is assigned.
    __log_predictions_callback: Optional[Callable] = None

    def __init__(
        self,
        model: Model,
        args: T4RecTrainingArguments,
        schema: Schema = None,
        train_dataset_or_path=None,
        eval_dataset_or_path=None,
        test_dataset_or_path=None,
        train_dataloader: Optional[DataLoader] = None,
        eval_dataloader: Optional[DataLoader] = None,
        test_dataloader: Optional[DataLoader] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        compute_metrics=None,
        incremental_logging: bool = False,
        **kwargs,
    ):
        # The HF Trainer receives a sized placeholder dataset. The actual batches
        # come from the loaders built in the get_*_dataloader() methods.
        mock_dataset = DatasetMock()
        # Work on a private copy. Appending to a shared default list (or to the
        # caller's list) would leak callbacks bound to other Trainer instances.
        callbacks = list(callbacks) if callbacks is not None else []

        self.incremental_logging = incremental_logging
        if self.incremental_logging:
            self.past_global_steps = 0
            callbacks.append(IncrementalLoggingCallback(self))

        super(Trainer, self).__init__(
            model=model,
            args=args,
            train_dataset=mock_dataset,
            eval_dataset=mock_dataset,
            callbacks=callbacks,
            **kwargs,
        )

        self.compute_metrics = compute_metrics
        self.train_dataset_or_path = train_dataset_or_path
        self.eval_dataset_or_path = eval_dataset_or_path
        self.test_dataset_or_path = test_dataset_or_path
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.test_dataloader = test_dataloader
        self.schema = schema

        # Set global_rank and global_size if DDP is used
        if self.args.local_rank != -1:
            self.device = self.local_rank = self.args.local_rank
            self.global_size = self.args.world_size
        else:
            self.device = self.local_rank = None
            self.global_size = None

    def get_train_dataloader(self):
        """
        Set the train dataloader to use by Trainer.
        It supports user defined data-loader set as an attribute in the constructor.
        When the attribute is None, The data-loader is defined using train_dataset
        and the `data_loader_engine` specified in Training Arguments.
        """
        if self.train_dataloader is not None:
            return self.train_dataloader

        assert self.schema is not None, "schema is required to generate Train Dataloader"
        return T4RecDataLoader.parse(self.args.data_loader_engine).from_schema(
            self.schema,
            self.train_dataset_or_path,
            self.args.per_device_train_batch_size,
            max_sequence_length=self.args.max_sequence_length,
            drop_last=self.args.dataloader_drop_last,
            shuffle=True,
            shuffle_buffer_size=self.args.shuffle_buffer_size,
            global_rank=self.local_rank,
            global_size=self.global_size,
            device=self.device,
        )

    def get_eval_dataloader(self, eval_dataset=None):
        """
        Set the eval dataloader to use by Trainer.
        It supports user defined data-loader set as an attribute in the constructor.
        When the attribute is None, The data-loader is defined using `eval_dataset`
        (when given) or `eval_dataset_or_path`, and the `data_loader_engine`
        specified in Training Arguments.

        Raises
        ------
        ValueError
            If no eval dataloader, `eval_dataset` or `eval_dataset_or_path` is available.
        """
        if self.eval_dataloader is not None:
            return self.eval_dataloader

        # Some HF versions forward their own `self.eval_dataset` (our DatasetMock)
        # here. Only a real dataset/path passed by the user overrides the one given
        # at construction.
        if eval_dataset is None or isinstance(eval_dataset, DatasetMock):
            eval_dataset = self.eval_dataset_or_path
        if eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        assert self.schema is not None, "schema is required to generate Eval Dataloader"
        return T4RecDataLoader.parse(self.args.data_loader_engine).from_schema(
            self.schema,
            eval_dataset,
            self.args.per_device_eval_batch_size,
            max_sequence_length=self.args.max_sequence_length,
            drop_last=self.args.dataloader_drop_last,
            shuffle=False,
            shuffle_buffer_size=self.args.shuffle_buffer_size,
            global_rank=self.local_rank,
            global_size=self.global_size,
            device=self.device,
        )

    def get_test_dataloader(self, test_dataset=None):
        """
        Set the test dataloader to use by Trainer.
        It supports user defined data-loader set as an attribute in the constructor.
        When the attribute is None, The data-loader is defined using `test_dataset`
        (when given) or `test_dataset_or_path`, and the `data_loader_engine`
        specified in Training Arguments.
        """
        if self.test_dataloader is not None:
            return self.test_dataloader

        if test_dataset is None and self.test_dataset_or_path is None:
            raise ValueError("Trainer: test requires an test_dataset.")
        test_dataset = test_dataset if test_dataset is not None else self.test_dataset_or_path
        assert self.schema is not None, "schema is required to generate Test Dataloader"
        return T4RecDataLoader.parse(self.args.data_loader_engine).from_schema(
            self.schema,
            test_dataset,
            self.args.per_device_eval_batch_size,
            max_sequence_length=self.args.max_sequence_length,
            drop_last=self.args.dataloader_drop_last,
            shuffle=False,
            shuffle_buffer_size=self.args.shuffle_buffer_size,
            global_rank=self.local_rank,
            global_size=self.global_size,
            device=self.device,
        )

    def num_examples(self, dataloader: DataLoader):
        """
        Overriding :obj:`Trainer.num_examples()` method because
        the data loaders for this project do not return the dataset size,
        but the number of steps. So we estimate the dataset size here
        by multiplying the number of steps * batch size
        """
        return len(dataloader) * dataloader._batch_size

    def reset_lr_scheduler(self) -> None:
        """
        Resets the LR scheduler of the previous :obj:`Trainer.train()` call,
        so that a new LR scheduler one is created by the next :obj:`Trainer.train()` call.
        This is important for LR schedules like `get_linear_schedule_with_warmup()`
        which decays LR to 0 in the end of the train
        """
        self.lr_scheduler = None

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Create the LR scheduler, unless one already exists, and return it.

        Unlike the HF default, this passes `num_cycles`
        (`learning_rate_num_cosine_cycles_by_epoch * num_train_epochs`) to
        schedulers that accept it.

        Parameters
        ----------
        num_training_steps: int
            Total number of training steps.
        optimizer: Optional[torch.optim.Optimizer]
            Optimizer to schedule. by default `self.optimizer`
        """
        # flexibility in scheduler with num_cycles as hyperparams
        if self.lr_scheduler is None:
            self.lr_scheduler = self.get_scheduler(
                self.args.lr_scheduler_type,
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.warmup_steps,
                num_training_steps=num_training_steps,
                num_cycles=self.args.learning_rate_num_cosine_cycles_by_epoch
                * self.args.num_train_epochs,
            )
        return self.lr_scheduler

    # Override the method get_scheduler to accept num_cycle params ?
    # The advantage is to use the unified HF API with many scheduler
    # we can also send a PR to HF ?
    @staticmethod
    def get_scheduler(
        name: Union[str, SchedulerType],
        optimizer: Optimizer,
        num_warmup_steps: Optional[int] = None,
        num_training_steps: Optional[int] = None,
        num_cycles: Optional[float] = 0.5,
    ):
        """
        Unified API to get any scheduler from its name.

        Parameters
        ----------
        name: (:obj:`str` or `:obj:`SchedulerType`)
            The name of the scheduler to use.
        optimizer: (:obj:`torch.optim.Optimizer`)
            The optimizer that will be used during training.
        num_warmup_steps: (:obj:`int`, `optional`)
            The number of warm-up steps to perform. This is not required by all schedulers
            (hence the argument being optional), the function will raise an error if it's
            unset and the scheduler type requires it.
        num_training_steps: (:obj:`int`, `optional`)
            The number of training steps to do. This is not required by all schedulers
            (hence the argument being optional), the function will raise an error if it's
            unset and the scheduler type requires it.
        num_cycles: (:obj:`float`, `optional`)
            The number of waves in the cosine schedule / hard restarts
            to use for cosine scheduler. Only passed to schedulers whose
            signature has a `num_cycles` parameter.
        """
        name = SchedulerType(name)
        schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
        if name == SchedulerType.CONSTANT:
            return schedule_func(optimizer)

        # All other schedulers require `num_warmup_steps`
        if num_warmup_steps is None:
            raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

        if name == SchedulerType.CONSTANT_WITH_WARMUP:
            return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

        # All other schedulers require `num_training_steps`
        if num_training_steps is None:
            raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

        if "num_cycles" in inspect.signature(schedule_func).parameters:
            return schedule_func(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                num_cycles=num_cycles,
            )

        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
        )

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Overriding :obj:`Trainer.compute_loss()`
        To allow for passing the targets to the model's forward method
        How the loss is computed by Trainer. By default, all Transformers4Rec models return
        a dictionary of three elements {'loss', 'predictions', and 'labels}

        Raises
        ------
        ValueError
            If the model outputs do not contain a "loss" key.
        """
        inputs, targets = inputs
        outputs = model(inputs, targets=targets, training=True)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if "loss" not in outputs:
            raise ValueError(
                "The model did not return a loss from the inputs, only the following keys: "
                f"{','.join(outputs.keys())}. "
                f"For reference, the inputs it received are {','.join(inputs.keys())}."
            )

        loss = outputs["loss"]

        return (loss, outputs) if return_outputs else loss

    def prediction_step(
        self,
        model: torch.nn.Module,
        inputs: Dict[str, torch.Tensor],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
        training: bool = False,
        testing: bool = True,
    ) -> Tuple[
        Optional[float],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[Dict[str, Any]],
    ]:
        """
        Overriding :obj:`Trainer.prediction_step()`
        to provide more flexibility to unpack results from the model,
        like returning labels that are not exactly one input feature model.

        Returns
        -------
        Tuple of (loss, predictions, labels, other_outputs). When `testing` is False,
        loss and labels are None and predictions are the detached model outputs.
        `other_outputs` is currently always None.
        """
        inputs = self._prepare_inputs(inputs)

        inputs, targets = inputs

        with torch.no_grad():
            if self._use_cuda_amp:
                with autocast():
                    outputs = model(inputs, targets=targets, training=training, testing=testing)
            else:
                outputs = model(inputs, targets=targets, training=training, testing=testing)

            if testing:
                loss = outputs["loss"].mean().detach()
                labels = nested_detach(outputs["labels"])
                predictions = nested_detach(outputs["predictions"])
            else:
                loss, labels = None, None
                predictions = nested_detach(outputs)

        if prediction_loss_only:
            return (loss, None, None, None)

        # TODO: define metadata dict in the model for logging
        other_outputs = None

        return (loss, predictions, labels, other_outputs)

    @property
    def _use_cuda_amp(self):
        """
        Check for CUDA AMP that is compatible with versions of the transformers package
        before and after version 4.20 (which renamed the property `use_amp` to `use_cuda_amp`).
        Returns False when the Trainer exposes neither attribute.
        """
        try:
            return self.use_cuda_amp
        except AttributeError:
            return getattr(self, "use_amp", False)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: Optional[str] = "eval",
    ) -> EvalLoopOutput:
        """
        Overriding :obj:`Trainer.evaluation_loop()` (shared by :obj:`Trainer.evaluate()`
        and :obj:`Trainer.predict()`) to provide more flexibility to work with streaming
        metrics (computed at each eval batch) and to log with the outputs of the model
        (e.g. prediction scores, prediction metadata, attention weights)

        Parameters
        ----------
        dataloader: DataLoader
            DataLoader object to use to iterate over evaluation data
        description: str
            Parameter to describe the evaluation experiment.
            e.g: `Prediction`, `test`
        prediction_loss_only: Optional[bool]
            Whether or not to return the loss only.
            by default None
        ignore_keys: Optional[List[str]]
            Columns not accepted by the ``model.forward()`` method
            are automatically removed.
            by default None
        metric_key_prefix: Optional[str]
            Prefix to use when logging evaluation metrics.
            by default `eval`

        Returns
        -------
        EvalLoopOutput
            Gathered predictions, labels and metrics. With ``description == "Prediction"``
            no loss, labels or metrics are computed.

        Raises
        ------
        ValueError
            If the dataloader's dataset is not sized, or if `predict_top_k` is used
            together with a model-level `top_k`.
        """
        prediction_loss_only = (
            prediction_loss_only
            if prediction_loss_only is not None
            else self.args.prediction_loss_only
        )

        # Trainer.predict() uses the "Prediction" description: no targets-based
        # loss/labels/metrics are computed in that mode.
        testing = description != "Prediction"

        # set the model
        model = self.model
        # reset metrics for the dataset (Train, Valid or Test)
        if self.compute_metrics:
            model.reset_metrics()

        if not isinstance(dataloader.dataset, collections.abc.Sized):
            raise ValueError("dataset must implement __len__")

        batch_size = dataloader._batch_size

        logger.info("***** Running %s *****", description)
        logger.info(" Batch size = %d", batch_size)

        if metric_key_prefix == "train" and self.args.eval_steps_on_train_set:
            num_examples = self.args.eval_steps_on_train_set * batch_size
        else:
            num_examples = self.num_examples(dataloader)

        logger.info(" Num sessions (examples) = %d", num_examples)

        model.eval()

        self.callback_handler.eval_dataloader = dataloader

        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        # Will be useful when we have an iterable dataset so don't know its length.
        observed_num_examples = 0

        # Iterate over dataloader
        for step, inputs in enumerate(dataloader):
            # Limits the number of evaluation steps on train set (which is usually larger).
            # Checked before counting the batch, so a batch that is skipped is not
            # added to the observed number of examples.
            if (
                metric_key_prefix == "train"
                and self.args.eval_steps_on_train_set > 0
                and step + 1 > self.args.eval_steps_on_train_set
            ):
                break

            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size

            loss, preds, labels, outputs = self.prediction_step(
                model,
                inputs,
                prediction_loss_only,
                ignore_keys=ignore_keys,
                testing=testing,
            )

            # Updates metrics
            # TODO: compute metrics each N eval_steps to speedup evaluation
            metrics_results_detailed = None
            # Same truthiness test as reset_metrics()/compute_metrics() above, so that
            # compute_metrics=False also disables the per-step accumulation.
            if self.compute_metrics and testing:
                if step % self.args.compute_metrics_each_n_steps == 0:
                    metrics_results_detailed = model.calculate_metrics(preds, labels)

            # Update containers on host
            if loss is not None:
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = (
                    losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
                )
            if labels is not None:
                labels = self._pad_across_processes(labels)
                labels = self._nested_gather(labels)
                labels_host = (
                    labels
                    if labels_host is None
                    else nested_concat(labels_host, labels, padding_index=0)
                )
            if preds is not None and self.args.predict_top_k > 0:
                if self.model.top_k:
                    raise ValueError(
                        "you cannot set top_k argument in the model class and the, "
                        "predict_top_k in the trainer at the same time. Please ensure setting "
                        "only predict_top_k"
                    )
                # get outputs of next-item scores
                if isinstance(preds, dict):
                    assert any(
                        isinstance(x, NextItemPredictionTask) for x in model.prediction_tasks
                    ), "Top-k prediction is specific to NextItemPredictionTask, " "Please ensure `self.args.predict_top_k == 0` "
                    pred_next_item = preds["next-item"]
                else:
                    assert isinstance(
                        model.prediction_tasks[0], NextItemPredictionTask
                    ), "Top-k prediction is specific to NextItemPredictionTask, " "Please ensure `self.args.predict_top_k == 0` "
                    pred_next_item = preds

                preds_sorted_item_scores, preds_sorted_item_ids = torch.topk(
                    pred_next_item, k=self.args.predict_top_k, dim=-1
                )
                self._maybe_log_predictions(
                    labels,
                    preds_sorted_item_ids,
                    preds_sorted_item_scores,
                    metrics_results_detailed,
                    metric_key_prefix,
                )
                # The output predictions will be a tuple with the ranked top-n item ids,
                # and item recommendation scores
                if isinstance(preds, dict):
                    preds["next-item"] = (preds_sorted_item_ids, preds_sorted_item_scores)
                else:
                    preds = (preds_sorted_item_ids, preds_sorted_item_scores)

            if preds is not None:
                preds_host = preds if preds_host is None else nested_concat(preds_host, preds)

            self.control = self.callback_handler.on_prediction_step(
                self.args, self.state, self.control
            )

            # Gather all tensors and put them back on the CPU
            # if we have done enough accumulation steps.
            if (
                self.args.eval_accumulation_steps is not None
                and (step + 1) % self.args.eval_accumulation_steps == 0
            ):
                all_losses, all_preds, all_labels = self._accumulate_host_on_cpu(
                    losses_host, preds_host, labels_host, all_losses, all_preds, all_labels
                )
                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host = None, None, None

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        all_losses, all_preds, all_labels = self._accumulate_host_on_cpu(
            losses_host, preds_host, labels_host, all_losses, all_preds, all_labels
        )

        # The data loaders for this project do not return the dataset size,
        # so the number of samples is the number of examples actually observed.
        num_samples = observed_num_examples
        # Number of losses has been rounded to a multiple of batch_size
        # and in a distributed training, the number of
        # samplers has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)

        # Get metrics :
        metrics = {}
        # Computing the metrics results as the average of all steps
        if self.compute_metrics and testing:
            streaming_metrics_results = model.compute_metrics(mode=metric_key_prefix)
            streaming_metrics_results_flattened = process_metrics(
                streaming_metrics_results, prefix=metric_key_prefix + "_/"
            )

            metrics = {**metrics, **streaming_metrics_results_flattened}

        # all_losses is None when no batch produced a loss (e.g. an empty dataloader)
        if testing and all_losses is not None:
            metrics[f"{metric_key_prefix}_/loss"] = all_losses.mean().item()

        return EvalLoopOutput(
            predictions=all_preds,
            label_ids=all_labels,
            metrics=metrics,
            num_samples=num_examples,
        )

    @staticmethod
    def _accumulate_host_on_cpu(
        losses_host, preds_host, labels_host, all_losses, all_preds, all_labels
    ):
        """
        Convert the on-device accumulators to NumPy and append them to the CPU containers.

        Returns the updated ``(all_losses, all_preds, all_labels)`` tuple.
        """
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = (
                losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = (
                labels
                if all_labels is None
                else nested_concat(all_labels, labels, padding_index=0)
            )
        if preds_host is not None:
            preds = nested_numpify(preds_host)
            all_preds = preds if all_preds is None else nested_concat(all_preds, preds)
        return all_losses, all_preds, all_labels

    def _save_model_and_checkpoint(self, save_model_class=False):
        """
        Save the serialized model + trainer and random states.

        Parameters
        ----------
        save_model_class: Optional[bool]
            Whether to save the Model class or not.
            by default False

        Raises
        ------
        ImportError
            If `save_model_class` is True and cloudpickle is not installed.
        """
        import os

        logger.info("Saving model...")
        output_dir = os.path.join(
            self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        )

        # save model parameters
        self._save_checkpoint(self.model, trial=None, metrics=None)
        # save the serialized model
        if save_model_class:
            # TODO : fix serialization of DatasetSchema object
            try:
                import cloudpickle
            except ImportError as err:
                raise ImportError("cloudpickle is required to save model class") from err

            # Written into the checkpoint-<global_step> directory under the file name
            # that load_model_trainer_states_from_checkpoint() reads back.
            with open(os.path.join(output_dir, "t4rec_model_class.pkl"), "wb") as f:
                cloudpickle.dump(self.model, f)

    def load_model_trainer_states_from_checkpoint(self, checkpoint_path, model=None):
        """
        This method loads the checkpoints states of the model, trainer and random states.
        If model is None the serialized model class is loaded from checkpoint.
        It does not loads the optimizer and LR scheduler states (for that call trainer.train()
        with resume_from_checkpoint argument for a complete load)

        Parameters
        ----------
        checkpoint_path: str
            Path to the checkpoint directory.
        model: Optional[Model]
            Model class used by Trainer. by default None

        Raises
        ------
        ImportError
            If `model` is None and cloudpickle is not installed.
        """
        import os

        if model is None:
            try:
                import cloudpickle
            except ImportError as err:
                raise ImportError("cloudpickle is required to load model class") from err
            logger.info("Loading model class")
            # NOTE: unpickling runs arbitrary code; only load trusted checkpoints.
            with open(os.path.join(checkpoint_path, "t4rec_model_class.pkl"), "rb") as f:
                model = cloudpickle.load(f)

        self.model = model
        logger.info("Loading weights of previously trained model")
        # Restoring model weights
        self.model.load_state_dict(torch.load(os.path.join(checkpoint_path, "pytorch_model.bin")))
        # Restoring random state
        rng_file = os.path.join(checkpoint_path, "rng_state.pth")
        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        # Checkpoints written without CUDA have no "cuda" entry.
        if torch.cuda.is_available() and "cuda" in checkpoint_rng_state:
            torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
        # Restoring AMP scaler (HF Trainer stores its state as "scaler.pt")
        if self._use_cuda_amp:
            self.scaler.load_state_dict(torch.load(os.path.join(checkpoint_path, "scaler.pt")))

    @property
    def log_predictions_callback(self) -> Callable:
        """Callback used to log predictions, or None when none has been assigned."""
        return self.__log_predictions_callback

    @log_predictions_callback.setter
    def log_predictions_callback(self, var: Callable):
        self.__log_predictions_callback = var

    def _maybe_log_predictions(
        self,
        labels: Optional[torch.Tensor],
        pred_item_ids: torch.Tensor,
        pred_item_scores: torch.Tensor,
        metrics: Optional[Dict[str, np.ndarray]],
        metric_key_prefix: str,
    ):
        """
        If --log_predictions is enabled, calls a callback function to
        log predicted item ids, scores, metadata and metrics.

        Parameters
        ----------
        labels: Optional[torch.Tensor]
            True labels. None in "Prediction" mode.
        pred_item_ids: torch.Tensor
            The predicted items ids. if top_k is set:
            we return to top-k items for each
            next-item prediction.
        pred_item_scores: torch.Tensor
            The prediction scores, if top_k is set:
            we return to top-k predictions for each
            next-item prediction.
        metrics: Optional[Dict[str, np.ndarray]]
            Dictionary of metrics computed by Model.
        metric_key_prefix: str
            Prefix to use when logging evaluation metrics.
            by default `eval`
        """
        # TODO Add pred_metadata: Dict[str, torch.Tensor],

        if self.args.log_predictions and self.log_predictions_callback is not None:
            # Converting torch Tensors to NumPy and callback predictions logging function
            self.log_predictions_callback(
                labels=labels.cpu().numpy() if labels is not None else None,
                pred_item_ids=pred_item_ids.cpu().numpy(),
                # Cast because scores are float16 when --fp16 is used
                pred_item_scores=pred_item_scores.cpu().numpy().astype(np.float32),
                metrics=metrics,
                dataset_type=metric_key_prefix,
            )

    def _increment_past_global_steps(self, current_global_step: int):
        """Add the global steps of a finished train() call to the logging offset."""
        self.past_global_steps += current_global_step

    def _get_general_global_step(self) -> int:
        """Global step used for logging when incremental logging is enabled."""
        general_global_step = self.past_global_steps
        # IncrementalLoggingCallback.on_train_end folds state.global_step into
        # past_global_steps, so the current steps are only added while training.
        # NOTE(review): this condition was unreadable in the reviewed copy and has
        # been restored as `self.model.training`; confirm against upstream.
        if self.model.training:
            general_global_step += self.state.global_step

        return general_global_step

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` through the HF callbacks after normalizing metric names.

        Keys starting with "eval/" are renamed to start with "eval_", and occurrences
        of "train/" are removed. With incremental logging, the reported step is the
        global step accumulated over all train() calls.
        """
        # Ensuring that eval metrics are prefixed as "eval_" so that the HF integration loggers
        # do not prefix metrics names with 'train/' (as 'train/' is always added when not eval)
        logs = {re.sub("^eval/", "eval_", k).replace("train/", ""): v for k, v in logs.items()}

        if not self.incremental_logging:
            super().log(logs)
        else:
            # If Incremental logging is enabled, ensures that global steps are always
            # incremented after train() calls
            # so that metrics are logger with no overlap on W&B and Tensorboard
            if self.state.epoch is not None:
                logs["epoch"] = round(self.state.epoch, 2)

            # As state.global_step is also used for the learning rate schedules,
            # we create a copy only for logging
            state_copy = deepcopy(self.state)
            state_copy.global_step = self._get_general_global_step()

            output = {**logs, **{"step": state_copy.global_step}}
            self.state.log_history.append(output)
            self.control = self.callback_handler.on_log(self.args, state_copy, self.control, logs)
def process_metrics(metrics, prefix="", to_cpu=True):
    """
    Flatten a (possibly nested) dictionary of metric values.

    Nested dictionaries are flattened recursively. Only the leaf keys are kept
    (intermediate keys are not part of the output name), each prepended with
    `prefix`. When two leaves share a key, the one visited last wins.

    Parameters
    ----------
    metrics: dict
        Mapping of metric names to values or to nested mappings.
    prefix: str
        String prepended to every leaf key. by default ""
    to_cpu: bool
        If True, each leaf value is converted with ``.cpu().numpy().item()``;
        otherwise leaf values are returned unchanged. by default True

    Returns
    -------
    dict
        Flat mapping from ``f"{prefix}{leaf_key}"`` to the processed value.
    """
    flat = {}
    for key, value in metrics.items():
        if not isinstance(value, dict):
            flat[f"{prefix}{key}"] = value.cpu().numpy().item() if to_cpu else value
            continue
        flat.update(process_metrics(value, prefix=prefix, to_cpu=to_cpu))
    return flat
class IncrementalLoggingCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that updates the Trainer's logging
    state on specific hooks, in support of incremental logging.

    Parameters
    ----------
    trainer: Trainer
        The Trainer whose past global steps are incremented at the end of
        every ``train()`` call.
    """

    def __init__(self, trainer: Trainer):
        self.trainer = trainer

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        # Nothing to update when a train() call starts.
        return None

    def on_train_end(self, args, state, control, model=None, **kwargs):
        # Fold the steps of the train() call that just finished into the
        # trainer's running total used for logging.
        finished_steps = state.global_step
        self.trainer._increment_past_global_steps(finished_steps)

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        # Intentionally a no-op: no evaluation is triggered at epoch end.
        return None
class DatasetMock(Dataset, Sized):
    """
    Placeholder dataset handed to the HF Trainer so that it sees a sized dataset.

    It holds no data. Batches are obtained via the generated or user-provided
    data loaders.

    Parameters
    ----------
    nsteps: int
        Value reported by ``len()``. by default 1
    """

    def __init__(self, nsteps=1):
        self.nsteps = nsteps

    def __len__(self):
        size = self.nsteps
        return size