Source code for transformers4rec.torch.utils.examples_utils

import gc
import glob
import os

import numpy as np
import torch


def list_files(startpath):
    """
    Util function to print the nested structure of a directory.

    Each directory is printed as ``name/`` followed by its files, indented by
    four spaces per nesting level below ``startpath``. Directories and files
    are visited in sorted order so the output is deterministic.

    Parameters
    ----------
    startpath: str
        Root of the directory tree to print. A trailing path separator is
        accepted.
    """
    for root, dirs, files in os.walk(startpath):
        # Depth relative to `startpath`. relpath is robust to a trailing
        # separator and to `startpath` occurring again deeper in the path
        # (str.replace would strip every occurrence, not just the prefix).
        rel = os.path.relpath(root, startpath)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = " " * 4 * level
        # normpath drops a trailing separator so basename is never empty.
        print("{}{}/".format(indent, os.path.basename(os.path.normpath(root))))
        subindent = " " * 4 * (level + 1)
        # Sorting `dirs` in place also fixes the order os.walk descends in.
        dirs.sort()
        for f in sorted(files):
            print("{}{}".format(subindent, f))
def visualize_response(batch, response, top_k, session_col="session_id"):
    """
    Util function to extract top-k encoded item-ids from logits and print them,
    ordered from highest to lowest score.

    Parameters
    ----------
    batch : cudf.DataFrame
        the batch of raw data sent to triton server. The unique values of
        ``session_col`` are paired, in order, with the rows of the predictions
        (``zip`` stops at the shorter of the two).
    response: tritonclient.grpc.InferResult
        the response returned by grpc client. Its ``"output"`` tensor is
        indexed as a 2-D array with one row of item scores per session.
    top_k: int
        the `top_k` top items to retrieve from predictions. Values larger than
        the number of items are clamped to the number of items.
    session_col: str
        name of the session id column in ``batch``.

    Raises
    ------
    ValueError
        If ``top_k`` is not a positive integer.
    """
    if top_k < 1:
        # With top_k == 0 the slice `[:, -0:]` would select every item.
        raise ValueError("top_k must be a positive integer, got %s" % top_k)
    sessions = batch[session_col].drop_duplicates().values
    predictions = response.as_numpy("output")
    k = min(top_k, predictions.shape[1])
    # argpartition only guarantees which k items are the largest, not their
    # order, so sort the selected candidates by descending score afterwards.
    candidates = np.argpartition(predictions, -k, axis=1)[:, -k:]
    candidate_scores = np.take_along_axis(predictions, candidates, axis=1)
    order = np.argsort(-candidate_scores, axis=1, kind="stable")
    top_preds = np.take_along_axis(candidates, order, axis=1)
    for session, next_items in zip(sessions, top_preds):
        print(
            "- Top-%s predictions for session `%s`: %s\n"
            % (k, session, " || ".join([str(e) for e in next_items]))
        )
def fit_and_evaluate(trainer, start_time_index, end_time_index, input_dir):
    """
    Util function for time-window based fine-tuning using the T4rec Trainer class.
    Iteratively train using data of a given index and evaluate on the validation
    data of the following index.

    Parameters
    ----------
    trainer: Trainer
        The T4rec Trainer instance. On every iteration its
        ``train_dataset_or_path`` / ``eval_dataset_or_path`` are overwritten and
        its learning-rate scheduler is reset before training.
    start_time_index: int
        The start index for training, it should match the partitions of the
        data directory
    end_time_index: int
        The end index for training (inclusive), it should match the partitions
        of the data directory. The last iteration evaluates on the partition
        ``end_time_index + 1``.
    input_dir: str
        The input directory where the parquet files were saved based on
        partition column. Each partition ``<index>/`` is expected to hold
        ``train.parquet`` and ``valid.parquet``; a missing file yields an empty
        path list (``glob`` matches nothing).

    Returns
    -------
    aot_metrics: dict
        Maps ``"AOT_" + metric_name`` (with ``"_at_"`` replaced by ``"@"``) to
        the list of that metric's values, one entry per evaluated day in
        chronological order. Only metrics whose name contains ``"at_"`` are
        collected. Average each list (e.g. with ``np.mean``) to obtain the
        average-over-time ranking metrics.
    """
    aot_metrics = {}
    for time_index in range(start_time_index, end_time_index + 1):
        # 1. Set data: train on day `time_index`, evaluate on the following day
        time_index_train = time_index
        time_index_eval = time_index + 1
        train_paths = glob.glob(os.path.join(input_dir, f"{time_index_train}/train.parquet"))
        eval_paths = glob.glob(os.path.join(input_dir, f"{time_index_eval}/valid.parquet"))

        # 2. Train on train data of time_index
        print("\n***** Launch training for day %s: *****" % time_index)
        trainer.train_dataset_or_path = train_paths
        trainer.reset_lr_scheduler()
        trainer.train()

        # 3. Evaluate on valid data of time_index+1
        trainer.eval_dataset_or_path = eval_paths
        eval_metrics = trainer.evaluate(metric_key_prefix="eval")
        print("\n***** Evaluation results for day %s:*****\n" % time_index_eval)
        for key in sorted(eval_metrics.keys()):
            if "at_" in key:
                print(" %s = %s" % (key.replace("at_", "@"), str(eval_metrics[key])))
                # Use the same normalized name for lookup and storage so each
                # day's value is appended instead of replacing the list.
                aot_key = "AOT_" + key.replace("_at_", "@")
                aot_metrics.setdefault(aot_key, []).append(eval_metrics[key])

        # free GPU for next day training
        wipe_memory()
    return aot_metrics
def wipe_memory():
    """
    Release unreferenced memory between training runs.

    First runs Python's garbage collector so that unreachable objects
    (including tensors) are freed, then asks PyTorch to return cached,
    unused CUDA memory held by its allocator.
    """
    # Order matters: collect first so the freed tensors' blocks are
    # actually unused when the CUDA cache is emptied.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()