# type: ignore
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections
import inspect
import random
import re
from collections.abc import Sized
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from torch.cuda.amp import autocast
from torch.optim import Optimizer
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from transformers import Trainer as BaseTrainer
from transformers.optimization import TYPE_TO_SCHEDULER_FUNCTION
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_pt_utils import find_batch_size
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalLoopOutput, SchedulerType
from transformers.utils import logging
from merlin_standard_lib import Schema
from ..config.trainer import T4RecTrainingArguments
from .model.base import Model
from .model.prediction_task import NextItemPredictionTask
from .utils.data_utils import T4RecDataLoader
from .utils.torch_utils import nested_concat, nested_detach, nested_numpify, nested_truncate
# Module-level logger, obtained through the `transformers.utils.logging` helper imported above.
logger = logging.get_logger(__name__)
class Trainer(BaseTrainer):
    """
    An :class:`~transformers.Trainer` specialized for sequential recommendation
    (including session-based and sequential recommendation).
    Parameters
    ----------
    model: Model
        The Model defined using Transformers4Rec api.
    args: T4RecTrainingArguments
        The training arguments needed to setup training and evaluation
        experiments.
    schema: Optional[Dataset.schema], optional
        The schema object including features to use and their properties.
        by default None
    train_dataset_or_path: Optional[Union[str, Dataset]], optional
        Path of parquet files or DataSet to use for training.
        by default None
    eval_dataset_or_path: Optional[Union[str, Dataset]], optional
        Path of parquet files or DataSet to use for evaluation.
        by default None
    train_dataloader: Optional[DataLoader], optional
        The data generator to use for training.
        by default None
    eval_dataloader: Optional[DataLoader], optional
        The data generator to use for evaluation.
        by default None
    compute_metrics: Optional[bool], optional
        Whether to compute metrics defined by Model class or not.
        by default None
    incremental_logging: bool
        Whether to enable incremental logging or not. If True, it ensures that
        global steps are incremented over many `trainer.train()` calls, so that
        train and eval metrics steps do not overlap and can be seen properly
        in reports like W&B and Tensorboard
    """
    def __init__(
        self,
        model: Model,
        args: T4RecTrainingArguments,
        schema: Schema = None,
        train_dataset_or_path=None,
        eval_dataset_or_path=None,
        test_dataset_or_path=None,
        train_dataloader: Optional[DataLoader] = None,
        eval_dataloader: Optional[DataLoader] = None,
        test_dataloader: Optional[DataLoader] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        compute_metrics=None,
        incremental_logging: bool = False,
        **kwargs,
    ):
        """
        Build the trainer and register the data sources it reads from.

        Parameters
        ----------
        model: Model
            The model to train / evaluate.
        args: T4RecTrainingArguments
            The training arguments.
        schema: Schema, optional
            Schema used to build the data loaders when none is supplied.
        train_dataset_or_path, eval_dataset_or_path, test_dataset_or_path: optional
            Data sources stored on the trainer and handed to the data-loader
            factory by the ``get_*_dataloader`` methods.
        train_dataloader, eval_dataloader, test_dataloader: Optional[DataLoader]
            Ready-made data loaders; when set they take precedence over the
            corresponding ``*_dataset_or_path``.
        callbacks: Optional[List[TrainerCallback]]
            Extra callbacks forwarded to the base class. The list is copied,
            so the caller's object is never modified.
        compute_metrics: optional
            Stored on the trainer and used as a flag by the evaluation loop.
        incremental_logging: bool
            When True an ``IncrementalLoggingCallback`` is registered and
            ``past_global_steps`` is initialised to 0.
        **kwargs
            Forwarded unchanged to the base class constructor.
        """
        # Work on a private copy: the previous `callbacks=[]` default was a single
        # list shared by every Trainer instance, so each instance created with
        # `incremental_logging=True` added one more callback to all later ones
        # (and a list passed by the caller was mutated in place).
        callbacks = list(callbacks) if callbacks is not None else []

        # Object given to the base class for both dataset slots; presumably a
        # placeholder, since the actual data is reached through the
        # `*_dataset_or_path` / `*_dataloader` attributes set below.
        mock_dataset = DatasetMock()

        # Set before the base constructor runs so the attribute already exists
        # when the callback (which receives `self`) is created.
        self.incremental_logging = incremental_logging
        if self.incremental_logging:
            self.past_global_steps = 0
            callbacks.append(IncrementalLoggingCallback(self))

        super(Trainer, self).__init__(
            model=model,
            args=args,
            train_dataset=mock_dataset,
            eval_dataset=mock_dataset,
            callbacks=callbacks,
            **kwargs,
        )

        self.compute_metrics = compute_metrics
        self.train_dataset_or_path = train_dataset_or_path
        self.eval_dataset_or_path = eval_dataset_or_path
        self.test_dataset_or_path = test_dataset_or_path
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.test_dataloader = test_dataloader
        self.schema = schema

        # Set global_rank and global_size if DDP is used
        if self.args.local_rank != -1:
            self.device = self.local_rank = self.args.local_rank
            self.global_size = self.args.world_size
        else:
            self.device = self.local_rank = None
            self.global_size = None
[docs]    def get_train_dataloader(self):
        """
        Set the train dataloader to use by Trainer.
        It supports user defined data-loader set as an attribute in the constructor.
        When the attribute is None, The data-loader is defined using train_dataset
        and the `data_loader_engine` specified in Training Arguments.
        """
        if self.train_dataloader is not None:
            return self.train_dataloader
        assert self.schema is not None, "schema is required to generate Train Dataloader"
        return T4RecDataLoader.parse(self.args.data_loader_engine).from_schema(
            self.schema,
            self.train_dataset_or_path,
            self.args.per_device_train_batch_size,
            max_sequence_length=self.args.max_sequence_length,
            drop_last=self.args.dataloader_drop_last,
            shuffle=True,
            shuffle_buffer_size=self.args.shuffle_buffer_size,
            global_rank=self.local_rank,
            global_size=self.global_size,
            device=self.device,
        ) 
[docs]    def get_eval_dataloader(self, eval_dataset=None):
        """
        Set the eval dataloader to use by Trainer.
        It supports user defined data-loader set as an attribute in the constructor.
        When the attribute is None, The data-loader is defined using eval_dataset
        and the `data_loader_engine` specified in Training Arguments.
        """
        if self.eval_dataloader is not None:
            return self.eval_dataloader
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        assert self.schema is not None, "schema is required to generate Eval Dataloader"
        return T4RecDataLoader.parse(self.args.data_loader_engine).from_schema(
            self.schema,
            self.eval_dataset_or_path,
            self.args.per_device_eval_batch_size,
            max_sequence_length=self.args.max_sequence_length,
            drop_last=self.args.dataloader_drop_last,
            shuffle=False,
            shuffle_buffer_size=self.args.shuffle_buffer_size,
            global_rank=self.local_rank,
            global_size=self.global_size,
            device=self.device,
        ) 
[docs]    def get_test_dataloader(self, test_dataset=None):
        """
        Set the test dataloader to use by Trainer.
        It supports user defined data-loader set as an attribute in the constructor.
        When the attribute is None, The data-loader is defined using test_dataset
        and the `data_loader_engine` specified in Training Arguments.
        """
        if self.test_dataloader is not None:
            return self.test_dataloader
        if test_dataset is None and self.test_dataset_or_path is None:
            raise ValueError("Trainer: test requires an test_dataset.")
        test_dataset = test_dataset if test_dataset is not None else self.test_dataset_or_path
        assert self.schema is not None, "schema is required to generate Test Dataloader"
        return T4RecDataLoader.parse(self.args.data_loader_engine).from_schema(
            self.schema,
            test_dataset,
            self.args.per_device_eval_batch_size,
            max_sequence_length=self.args.max_sequence_length,
            drop_last=self.args.dataloader_drop_last,
            shuffle=False,
            shuffle_buffer_size=self.args.shuffle_buffer_size,
            global_rank=self.local_rank,
            global_size=self.global_size,
            device=self.device,
        ) 
[docs]    def num_examples(self, dataloader: DataLoader):
        """
        Overriding :obj:`Trainer.num_examples()` method because
        the data loaders for this project do not return the dataset size,
        but the number of steps. So we estimate the dataset size here
        by multiplying the number of steps * batch size
        """
        """
        if dataloader == self.get_train_dataloader():
            batch_size = self.args.per_device_train_batch_size
        else:
            batch_size = self.args.per_device_eval_batch_size
        """
        return len(dataloader) * dataloader._batch_size 
[docs]    def reset_lr_scheduler(self) -> None:
        """
        Resets the LR scheduler of the previous :obj:`Trainer.train()` call,
        so that a new LR scheduler one is created by the next :obj:`Trainer.train()` call.
        This is important for LR schedules like `get_linear_schedule_with_warmup()`
        which decays LR to 0 in the end of the train
        """
        self.lr_scheduler = None 
[docs]    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        # flexibility in scheduler with num_cycles as hyperparams
        if self.lr_scheduler is None:
            self.lr_scheduler = self.get_scheduler(
                self.args.lr_scheduler_type,
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.warmup_steps,
                num_training_steps=num_training_steps,
                num_cycles=self.args.learning_rate_num_cosine_cycles_by_epoch
                * self.args.num_train_epochs,
            ) 
    # Override the method get_scheduler to accept num_cycle params ?
    # The advantage is to use the unified HF API with many scheduler
    # we can also send a PR to HF ?
[docs]    @staticmethod
    def get_scheduler(
        name: Union[str, SchedulerType],
        optimizer: Optimizer,
        num_warmup_steps: Optional[int] = None,
        num_training_steps: Optional[int] = None,
        num_cycles: Optional[int] = 0.5,
    ):
        """
        Unified API to get any scheduler from its name.
        Parameters
        ----------
        name: (:obj:`str` or `:obj:`SchedulerType`)
            The name of the scheduler to use.
        optimizer: (:obj:`torch.optim.Optimizer`)
            The optimizer that will be used during training.
        num_warmup_steps: (:obj:`int`, `optional`)
            The number of warm-up steps to perform. This is not required by all schedulers
            (hence the argument being optional),
            the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps: (:obj:`int`, `optional`)
            The number of training steps to do. This is not required by all schedulers
            (hence the argument being optional),
            the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles: (:obj:`int`, `optional`)
            The number of waves in the cosine schedule /
            hard restarts to use for cosine scheduler
        """
        name = SchedulerType(name)
        schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
        if name == SchedulerType.CONSTANT:
            return schedule_func(optimizer)
        # All other schedulers require `num_warmup_steps`
        if num_warmup_steps is None:
            raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
        if name == SchedulerType.CONSTANT_WITH_WARMUP:
            return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
        # All other schedulers require `num_training_steps`
        if num_training_steps is None:
            raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
        if "num_cycles" in inspect.signature(schedule_func).parameters:
            return schedule_func(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                num_cycles=num_cycles,
            )
        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
        ) 
[docs]    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Overriding :obj:`Trainer.compute_loss()`
        To allow for passing the targets to the model's forward method
        How the loss is computed by Trainer. By default, all Transformers4Rec models return
        a dictionary of three elements {'loss', 'predictions', and 'labels}
        """
        inputs, targets = inputs
        outputs = model(inputs, targets=targets, training=True)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        if "loss" not in outputs:
            raise ValueError(
                "The model did not return a loss from the inputs, only the following keys: "
                f"{','.join(outputs.keys())}. "
                "For reference, the inputs it received are {','.join(inputs.keys())}."
            )
        loss = outputs["loss"]
        return (loss, outputs) if return_outputs else loss 
[docs]    def prediction_step(
        self,
        model: torch.nn.Module,
        inputs: Dict[str, torch.Tensor],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
        training: bool = False,
        testing: bool = True,
    ) -> Tuple[
        Optional[float],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[Dict[str, Any]],
    ]:
        """
        Overriding :obj:`Trainer.prediction_step()`
        to provide more flexibility to unpack results from the model,
        like returning labels that are not exactly one input feature
        model
        """
        inputs = self._prepare_inputs(inputs)
        inputs, targets = inputs
        with torch.no_grad():
            if self._use_cuda_amp:
                with autocast():
                    outputs = model(inputs, targets=targets, training=training, testing=testing)
            else:
                outputs = model(inputs, targets=targets, training=training, testing=testing)
        if testing:
            loss = outputs["loss"].mean().detach()
            labels = nested_detach(outputs["labels"])
            predictions = nested_detach(outputs["predictions"])
        else:
            loss, labels = None, None
            predictions = nested_detach(outputs)
        if prediction_loss_only:
            return (loss, None, None, None)
        # TODO: define metadata dict in the model for logging
        # other_outputs = {
        #    k: v.detach() if isinstance(v, torch.Tensor) else v
        #    for k, v in outputs.items()
        #    if k not in ignore_keys + ["loss", "predictions", "labels"]
        # }
        other_outputs = None
        return (loss, predictions, labels, other_outputs) 
    @property
    def _use_cuda_amp(self):
        """
        Tell whether CUDA AMP is enabled, whichever name the installed
        transformers package uses for the flag: `use_cuda_amp` (version 4.20
        and later) or `use_amp` (earlier versions).
        """
        # Prefer the newer attribute name; fall back to the older one.
        if hasattr(self, "use_cuda_amp"):
            return self.use_cuda_amp
        return self.use_amp
[docs]    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: Optional[str] = "eval",
    ) -> EvalLoopOutput:
        """
        Overriding :obj:`Trainer.prediction_loop()`
        (shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`)
        to provide more flexibility to work with streaming metrics
        (computed at each eval batch) and
        to log with the outputs of the model
        (e.g. prediction scores, prediction metadata, attention weights)
        Parameters
        ----------
        dataloader: DataLoader
            DataLoader object to use to iterate over evaluation data
        description: str
            Parameter to describe the evaluation experiment.
            e.g: `Prediction`, `test`
        prediction_loss_only: Optional[bool]
            Whether or not to return the loss only.
            by default None
        ignore_keys: Optional[List[str]]
            Columns not accepted by the ``model.forward()`` method
            are automatically removed.
            by default None
        metric_key_prefix: Optional[str]
            Prefix to use when logging evaluation metrics.
            by default `eval`
        """
        prediction_loss_only = (
            prediction_loss_only
            if prediction_loss_only is not None
            else self.args.prediction_loss_only
        )
        if description == "Prediction":
            testing = False
        else:
            testing = True
        # set the model
        model = self.model
        # reset metrics for the dataset (Train, Valid or Test)
        if self.compute_metrics:
            model.reset_metrics()
        if not isinstance(dataloader.dataset, collections.abc.Sized):
            raise ValueError("dataset must implement __len__")
        batch_size = dataloader._batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Batch size = %d", batch_size)
        preds_host: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        if metric_key_prefix == "train" and self.args.eval_steps_on_train_set:
            num_examples = self.args.eval_steps_on_train_set * batch_size
        else:
            num_examples = self.num_examples(dataloader)
        logger.info("  Num sessions (examples) = %d", num_examples)
        model.eval()
        self.callback_handler.eval_dataloader = dataloader
        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        # Will be useful when we have an iterable dataset so don't know its length.
        observed_num_examples = 0
        # Iterate over dataloader
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
            # Limits the number of evaluation steps on train set (which is usually larger)
            if (
                metric_key_prefix == "train"
                and self.args.eval_steps_on_train_set > 0
                and step + 1 > self.args.eval_steps_on_train_set
            ):
                break
            loss, preds, labels, outputs = self.prediction_step(
                model,
                inputs,
                prediction_loss_only,
                ignore_keys=ignore_keys,
                testing=testing,
            )
            # Updates metrics
            # TODO: compute metrics each N eval_steps to speedup evaluation
            metrics_results_detailed = None
            if self.compute_metrics is not None and testing:
                if step % self.args.compute_metrics_each_n_steps == 0:
                    metrics_results_detailed = model.calculate_metrics(preds, labels)
            # Update containers on host
            if loss is not None:
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = (
                    losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
                )
            if labels is not None:
                labels = self._pad_across_processes(labels)
                labels = self._nested_gather(labels)
                labels_host = (
                    labels
                    if labels_host is None
                    else nested_concat(labels_host, labels, padding_index=0)
                )
            if (
                preds is not None
                and any(isinstance(x, NextItemPredictionTask) for x in model.prediction_tasks)
                and (self.args.predict_top_k or self.model.top_k)
            ):
                # get outputs of next-item scores
                if isinstance(preds, dict):
                    pred_next_item = preds["next-item"]
                else:
                    pred_next_item = preds
                preds_sorted_item_scores = None
                preds_sorted_item_ids = None
                if self.model.top_k is not None and isinstance(pred_next_item, (list, tuple)):
                    preds_sorted_item_scores, preds_sorted_item_ids = pred_next_item
                    if self.args.predict_top_k:
                        if self.args.predict_top_k > self.model.top_k:
                            raise ValueError(
                                "The args.predict_top_k should not be larger than model.top_k. "
                                "The model.top_k is available to support inference (e.g. when "
                                "serving with Triton Inference Server) to return only the top-k "
                                "predicted items ids and their scores."
                                "When doing offline predictions with `trainer.predict(), "
                                "if you set model.top_k, the model will also limit the number of "
                                "predictions output from trainer.predict(). "
                                "In that case, you want either to reduce args.predict_top_k or "
                                "increase model.top_k, so that args.predict_top_k is "
                                "not larger than model.top_k."
                            )
                        preds_sorted_item_scores = preds_sorted_item_scores[
                            :, : self.args.predict_top_k
                        ]
                        preds_sorted_item_ids = preds_sorted_item_ids[:, : self.args.predict_top_k]
                elif self.args.predict_top_k:
                    preds_sorted_item_scores, preds_sorted_item_ids = torch.topk(
                        pred_next_item, k=self.args.predict_top_k, dim=-1
                    )
                if preds_sorted_item_scores is not None:
                    self._maybe_log_predictions(
                        labels,
                        preds_sorted_item_ids,
                        preds_sorted_item_scores,
                        # outputs["pred_metadata"],
                        metrics_results_detailed,
                        metric_key_prefix,
                    )
                    # The output predictions will be a tuple with the ranked top-n item ids,
                    # and item recommendation scores
                    if isinstance(preds, dict):
                        preds["next-item"] = (
                            preds_sorted_item_ids,
                            preds_sorted_item_scores,
                        )
                    else:
                        preds = (
                            preds_sorted_item_ids,
                            preds_sorted_item_scores,
                        )
            preds_host = (
                preds
                if preds_host is None
                else nested_concat(
                    preds_host,
                    preds,
                )
            )
            self.control = self.callback_handler.on_prediction_step(
                self.args, self.state, self.control
            )
            # Gather all tensors and put them back on the CPU
            # if we have done enough accumulation steps.
            if (
                self.args.eval_accumulation_steps is not None
                and (step + 1) % self.args.eval_accumulation_steps == 0
            ):
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = (
                        losses
                        if all_losses is None
                        else np.concatenate((all_losses, losses), axis=0)
                    )
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels
                        if all_labels is None
                        else nested_concat(all_labels, labels, padding_index=0)
                    )
                if preds_host is not None:
                    preds = nested_numpify(preds_host)
                    all_preds = (
                        preds
                        if all_preds is None
                        else nested_concat(
                            all_preds,
                            preds,
                        )
                    )
                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host = None, None, None
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")
        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = (
                losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = (
                labels if all_labels is None else nested_concat(all_labels, labels, padding_index=0)
            )
        if preds_host is not None:
            preds_host = nested_numpify(preds_host)
            all_preds = (
                preds_host
                if all_preds is None
                else nested_concat(
                    all_preds,
                    preds_host,
                )
            )
        # Get Number of samples :
        # the data loaders for this project do not return the dataset size,
        num_samples = observed_num_examples
        # Number of losses has been rounded to a multiple of batch_size
        # and in a distributed training, the number of
        # samplers has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
        # Get metrics :
        metrics = {}
        # Computing the metrics results as the average of all steps
        if self.compute_metrics and testing:
            streaming_metrics_results = model.compute_metrics(mode=metric_key_prefix)
            streaming_metrics_results_flattened = process_metrics(
                streaming_metrics_results, prefix=metric_key_prefix + "_/"
            )
            metrics = {**metrics, **streaming_metrics_results_flattened}
        if testing:
            metrics[f"{metric_key_prefix}_/loss"] = all_losses.mean().item()
        return EvalLoopOutput(
            predictions=all_preds,
            label_ids=all_labels,
            metrics=metrics,
            num_samples=num_examples,
        ) 
    def _save_model_and_checkpoint(self, save_model_class=False):
        """
        Write a checkpoint through ``_save_checkpoint`` and, on request,
        also serialize the model object itself.

        Parameters
        ----------
        save_model_class: Optional[bool]
            When True, ``model.save()`` is additionally called with the
            checkpoint directory of the current global step.
            by default False
        """
        import os

        logger.info("Saving model...")

        # Directory name follows the `<prefix>-<global_step>` checkpoint convention.
        base_dir = self.args.output_dir
        checkpoint_dir = os.path.join(base_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")

        # save model parameters
        self._save_checkpoint(self.model, trial=None, metrics=None)

        if not save_model_class:
            return

        # save the serialized model
        # TODO : fix serialization of DatasetSchema object
        self.model.save(checkpoint_dir)
[docs]    def load_model_trainer_states_from_checkpoint(self, checkpoint_path, model=None):
        """
        This method loads the checkpoints states of the model, trainer and random states.
        If model is None the serialized model class is loaded from checkpoint.
        It does not loads the optimizer and LR scheduler states (for that call trainer.train()
        with resume_from_checkpoint argument for a complete load)
        Parameters
        ----------
        checkpoint_path: str
            Path to the checkpoint directory.
        model: Optional[Model]
            Model class used by Trainer. by default None
        """
        import os
        if model is None:
            try:
                import cloudpickle
            except ImportError:
                raise ImportError("cloudpickle is required to load model class")
            logger.info("Loading model class")
            model = cloudpickle.load(
                open(os.path.join(checkpoint_path, "t4rec_model_class.pkl"), "rb")
            )
        self.model = model
        logger.info("Loading weights of previously trained model")
        # Restoring model weights
        self.model.load_state_dict(
            # torch.load(os.path.join(training_args.output_dir, "pytorch_model.bin"))
            torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
        )
        # Restoring random state
        rng_file = os.path.join(checkpoint_path, "rng_state.pth")
        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
        # Restoring AMP scaler
        if self._use_cuda_amp:
            self.scaler.load_state_dict(torch.load(os.path.join(checkpoint_path, "scaler.pt"))) 
    @property
    def log_predictions_callback(self) -> Optional[Callable]:
        """
        Callback used to log predictions, or None when none has been set.

        Returning None (instead of raising ``AttributeError``) lets callers
        test the property with ``is not None`` before any callback is assigned.
        """
        try:
            return self.__log_predictions_callback
        except AttributeError:
            return None

    @log_predictions_callback.setter
    def log_predictions_callback(self, var: Callable):
        # Stores the function later invoked with the predictions to log.
        self.__log_predictions_callback = var
    def _maybe_log_predictions(
        self,
        labels: Optional[torch.Tensor],
        pred_item_ids: torch.Tensor,
        pred_item_scores: torch.Tensor,
        metrics: Optional[Dict[str, np.ndarray]],
        metric_key_prefix: str,
    ):
        """
        If --log_predictions is enabled, calls a callback function to
        log predicted item ids, scores, metadata and metrics.

        Nothing happens when ``args.log_predictions`` is off or when no
        prediction-logging callback has been set.

        Parameters
        ----------
        labels: Optional[torch.Tensor]
            True labels, or None when the caller has none; None is passed
            through to the callback as is.
        pred_item_ids: torch.Tensor
            The predicted items ids. if top_k is set:
            we return to top-k items for each
            next-item prediction.
        pred_item_scores: torch.Tensor
            The prediction scores, if top_k is set:
            we return to top-k predictions for each
            next-item prediction.
        metrics: Optional[Dict[str, np.ndarray]]
            Dictionary of metrics computed by Model, forwarded unchanged.
        metric_key_prefix: str
            Passed to the callback as ``dataset_type``.
        """
        # TODO Add pred_metadata: Dict[str, torch.Tensor],

        if not self.args.log_predictions:
            return

        # `getattr` with a default also covers the case where the callback was
        # never assigned and reading the property raises AttributeError.
        callback = getattr(self, "log_predictions_callback", None)
        if callback is None:
            return

        # Converting torch Tensors to NumPy and callback predictions logging function
        # preds_metadata = {k: v.cpu().numpy() for k, v in pred_metadata.items()}
        callback(
            labels=labels.cpu().numpy() if labels is not None else None,
            pred_item_ids=pred_item_ids.cpu().numpy(),
            pred_item_scores=pred_item_scores.cpu()
            .numpy()
            .astype(np.float32),  # Because it is float16 when --fp16
            # preds_metadata=preds_metadata,
            metrics=metrics,
            dataset_type=metric_key_prefix,
        )
    def _increment_past_global_steps(self, current_global_step: int):
        self.past_global_steps += current_global_step
    def _get_general_global_step(self) -> int:
        general_global_step = self.past_global_steps
        if self.model.training:
            general_global_step += self.state.global_step
        return general_global_step
[docs]    def log(self, logs: Dict[str, float]) -> None:
        # Ensuring that eval metrics are prefixed as "eval_" so that the HF integration loggers
        # do not prefix metrics names with 'train/' (as 'train/' is always added when not eval)
        logs = {re.sub("^eval/", "eval_", k).replace("train/", ""): v for k, v in logs.items()}
        if not self.incremental_logging:
            super().log(logs)
        else:
            # If Incremental logging is enabled, ensures that global steps are always
            # incremented after train() calls
            # so that metrics are logger with no overlap on W&B and Tensorboard
            if self.state.epoch is not None:
                logs["epoch"] = round(self.state.epoch, 2)
            # As state.global_step is also used for the learning rate schedules,
            # we create a copy only for logging
            state_copy = deepcopy(self.state)
            state_copy.global_step = self._get_general_global_step()
            output = {**logs, **{"step": state_copy.global_step}}
            self.state.log_history.append(output)
            self.control = self.callback_handler.on_log(self.args, state_copy, self.control, logs)  
def process_metrics(metrics, prefix="", to_cpu=True):
    """
    Flatten a (possibly nested) dictionary of metrics into a single level.

    Nested dictionaries are flattened recursively. The keys of the nested
    dictionaries themselves are dropped: only leaf keys are kept, each one
    prepended with ``prefix``. If two leaves share the same key, the one
    visited last wins.

    Parameters
    ----------
    metrics: dict
        Dictionary whose values are either leaf metrics or dictionaries
        of metrics.
    prefix: str
        String prepended to every leaf key, by default "".
    to_cpu: bool
        If True (default), each leaf is converted to a Python scalar with
        ``.detach().cpu().numpy().item()``, so leaves are expected to be
        single-element tensors. If False, leaves are kept as they are.

    Returns
    -------
    dict
        Flat dictionary mapping prefixed leaf keys to metric values.
    """
    metrics_proc = {}
    for root_key, root_value in metrics.items():
        if isinstance(root_value, dict):
            # Merging in place avoids rebuilding the accumulator at each nested dict
            metrics_proc.update(process_metrics(root_value, prefix=prefix, to_cpu=to_cpu))
        elif to_cpu:
            # detach() is needed because numpy() raises on tensors that require grad
            metrics_proc[f"{prefix}{root_key}"] = root_value.detach().cpu().numpy().item()
        else:
            metrics_proc[f"{prefix}{root_key}"] = root_value
    return metrics_proc
class IncrementalLoggingCallback(TrainerCallback):
    """
    An :class:`~transformers.TrainerCallback` that changes the state of the Trainer
    on specific hooks for the purpose of the incremental logging

    Parameters
    ----------
    trainer: Trainer
        The trainer whose past global steps counter is updated
        when a training run ends.
    """

    def __init__(self, trainer: Trainer):
        self.trainer = trainer

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """No-op hook: nothing is changed when training begins."""

    def on_train_end(self, args, state, control, model=None, **kwargs):
        """Add the global steps of the train() call that just ended to the trainer."""
        # Increments the global steps for logging with the global steps of the last train()
        self.trainer._increment_past_global_steps(state.global_step)

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        """No-op hook: nothing is done when an epoch ends."""
        # An evaluation on the eval set (self.trainer.evaluate()) is currently
        # disabled here
class DatasetMock(Dataset, Sized):
    """
    Mock to inform HF Trainer that the dataset is sized,
    and can be obtained via the generated/provided data loader

    Parameters
    ----------
    nsteps: int
        Value reported as the length of the dataset, by default 1.
    """

    def __init__(self, nsteps=1):
        self.nsteps = nsteps

    def __len__(self):
        """Return the ``nsteps`` value given at construction."""
        return self.nsteps