# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

End-to-end session-based recommendations with PyTorch

In recent years, several deep learning-based algorithms have been proposed for recommender systems, and their adoption in industry deployments has been growing steeply. In particular, NLP-inspired approaches have been successfully adapted for sequential and session-based recommendation problems, which are important for many domains like e-commerce, news, and streaming media. Session-Based Recommender Systems (SBRS) have been proposed to model the sequence of interactions within the current user session, where a session is a short sequence of user interactions typically bounded by user inactivity. They have recently gained popularity due to their ability to capture short-term or contextual user preferences towards items.

The field of NLP has evolved significantly within the last decade, particularly due to the increased usage of deep learning. As a result, state-of-the-art NLP approaches have inspired RecSys practitioners and researchers to adapt those architectures, especially for sequential and session-based recommendation problems. Here, we leverage one of the state-of-the-art Transformer-based architectures, XLNet, with the Masked Language Modeling (MLM) training technique (see our tutorial for details) to train a session-based model.

In this end-to-end session-based recommender model example, we use the Transformers4Rec library, which leverages the popular HuggingFace Transformers NLP library and makes it possible to experiment with cutting-edge implementations of such architectures for sequential and session-based recommendation problems. For detailed explanations of the building blocks of the Transformers4Rec meta-architecture, see the getting-started-session-based and tutorial example notebooks.

1. Model definition using Transformers4Rec

In the previous notebook, we created sequential features and saved the processed data frames as parquet files. Now we use these parquet files to train a session-based recommendation model with the XLNet architecture.

1.1 Get the schema

The library uses a schema format to configure the input features and automatically create the necessary layers. This protobuf text file describes each input feature by defining its name, its type, the number of elements of a list column, the cardinality of a categorical feature, and the min and max values of each feature. In addition, the annotation field contains tags that, among other things, mark features as continuous or categorical and identify the target column or the item_id feature.

We create the schema object by reading the processed train parquet file generated by the NVTabular pipeline in the previous notebook, 01-ETL-with-NVTabular.

import os

from merlin.io import Dataset
from merlin.schema import Schema

INPUT_DATA_DIR = os.environ.get("INPUT_DATA_DIR", "/workspace/data")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", f"{INPUT_DATA_DIR}/preproc_sessions_by_day")

# Read the processed parquet file; the schema is attached to the dataset
train = Dataset(os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"))
schema = train.schema
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'
  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
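Each column in the schema carries this metadata. For example, a small sketch inspecting one column (assuming the schema object supports indexing by column name; the property keys shown match the table printed further below):

# Look up a single column and print the metadata attached to it
col = schema["item_id-list"]
print(col.tags)                       # tags, e.g. Tags.LIST, Tags.CATEGORICAL, ...
print(col.dtype)                      # int64
print(col.properties["domain"])       # e.g. {'min': 0, 'max': 52740, 'name': 'item_id'}
print(col.properties["value_count"])  # e.g. {'min': 20, 'max': 20}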

We can select the subset of features we want to use for training the model by their tags or their names.

schema = schema.select_by_name(
    ['item_id-list', 'category-list', 'product_recency_days_log_norm-list', 'et_dayofweek_sin-list']
)
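Selection by tag works analogously, using the schema's select_by_tag method. For instance (a small sketch; the exact set of columns returned depends on the tags assigned during preprocessing):

from merlin.schema import Tags

# e.g., keep every feature tagged as categorical
categorical_schema = schema.select_by_tag(Tags.CATEGORICAL)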

We can print out the schema.

schema
name tags dtype is_list is_ragged properties.num_buckets properties.freq_threshold properties.max_size properties.start_index properties.cat_path properties.embedding_sizes.cardinality properties.embedding_sizes.dimension properties.domain.min properties.domain.max properties.domain.name properties.value_count.min properties.value_count.max
0 item_id-list (Tags.LIST, Tags.ID, Tags.CATEGORICAL, Tags.IT... DType(name='int64', element_type=<ElementType.... True False NaN 0.0 0.0 1.0 .//categories/unique.item_id.parquet 52741.0 512.0 0.0 52740.0 item_id 20 20
1 category-list (Tags.LIST, Tags.CATEGORICAL) DType(name='int64', element_type=<ElementType.... True False NaN 0.0 0.0 1.0 .//categories/unique.category.parquet 336.0 42.0 0.0 335.0 category 20 20
2 product_recency_days_log_norm-list (Tags.LIST, Tags.CONTINUOUS) DType(name='float32', element_type=<ElementTyp... True False NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 20 20
3 et_dayofweek_sin-list (Tags.LIST, Tags.CONTINUOUS) DType(name='float32', element_type=<ElementTyp... True False NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 20 20

1.2 Define the end-to-end session-based, Transformer-based recommendation model

Defining the end-to-end session-based recommendation model requires four steps:

  1. Instantiate a TabularSequenceFeatures input module from the schema to prepare the embedding tables for categorical features and to project continuous features, if specified. In addition, the module provides different aggregation methods (e.g., 'concat', 'elementwise-sum') to merge the input features and generate the sequence of interaction embeddings. The module also supports language modeling tasks to prepare masked labels for training and evaluation (e.g., 'mlm' for masked language modeling).

  2. Next, we need to define one or more prediction tasks. For this demo, we use NextItemPredictionTask with masked language modeling: during training, randomly selected items are masked and predicted using the unmasked items of the sequence. At inference time, the model always predicts the next item to be interacted with.

  3. Then we construct a transformer_config based on one of the architectures provided by the Hugging Face Transformers framework.

  4. Finally, we link the transformer body to the input module and the prediction tasks to get the final PyTorch Model class.

For more details about the features supported by each sub-module, please check out the library documentation page.

from transformers4rec import torch as tr

max_sequence_length, d_model = 20, 320
# Define input module to process tabular input-features and to prepare masked inputs
input_module = tr.TabularSequenceFeatures.from_schema(
    schema,
    max_sequence_length=max_sequence_length,
    continuous_projection=64,
    aggregation="concat",
    d_output=d_model,
    masking="mlm",
)

# Define the next-item prediction task
prediction_task = tr.NextItemPredictionTask(weight_tying=True)

# Define the config of the XLNet Transformer architecture
transformer_config = tr.XLNetConfig.build(
    d_model=d_model, n_head=8, n_layer=2, total_seq_length=max_sequence_length
)

# Get the end-to-end model 
model = transformer_config.to_torch_model(input_module, prediction_task)
Projecting inputs of NextItemPredictionTask to'64' As weight tying requires the input dimension '320' to be equal to the item-id embedding dimension '64'

You can print out the model structure by uncommenting the line below.

#model

1.3 Daily Fine-Tuning: Training over a time window

Now that the model is defined, we are going to launch training. For that, Transformers4Rec extends the HF Transformers Trainer class to adapt the evaluation loop for the session-based recommendation task and the calculation of ranking metrics. The original train() method is not modified, meaning that we leverage the efficient training implementation from that library, which manages, for example, half-precision (FP16) training.

Set the training arguments

An additional argument, data_loader_engine, specifies the data loader that automatically loads the features needed for training using the schema. The default value is merlin for optimized GPU-based data loading. Optionally, a PyarrowDataLoader (pyarrow) can be used as a basic alternative, but it is slower and works only for small datasets, as the full data is loaded into CPU memory.

training_args = tr.trainer.T4RecTrainingArguments(
    output_dir="./tmp",
    max_sequence_length=20,
    data_loader_engine='merlin',
    num_train_epochs=10,
    dataloader_drop_last=False,
    per_device_train_batch_size=384,
    per_device_eval_batch_size=512,
    learning_rate=0.0005,
    fp16=True,
    report_to=[],
    logging_steps=200,
)

Instantiate the trainer

recsys_trainer = tr.Trainer(
    model=model,
    args=training_args,
    schema=schema,
    compute_metrics=True)
Using amp fp16 backend

Launch daily training and evaluation

In this demo, we will use the fit_and_evaluate method, which allows us to conduct time-based fine-tuning by iteratively training and evaluating over a sliding time window: at each iteration, we use the training data of a specific time index \(t\) to train the model and then evaluate on the validation data of the next index \(t + 1\). In particular, we set the start time index to 178 and the end time index to 180.
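Under the hood, fit_and_evaluate roughly performs the loop sketched below. This is a simplified sketch, not the helper's exact implementation: it assumes the per-day directory layout {OUTPUT_DIR}/{t}/train.parquet and {OUTPUT_DIR}/{t + 1}/valid.parquet produced by the previous notebook and the trainer's train_dataset_or_path/eval_dataset_or_path attributes, and the real helper additionally manages trainer state between days:

def sliding_window_fit_and_evaluate(trainer, start_time_index, end_time_index, input_dir):
    # Collect one value per evaluated day for each ranking metric
    results = {}
    for t in range(start_time_index, end_time_index + 1):
        # Train on the sessions of day t ...
        trainer.train_dataset_or_path = f"{input_dir}/{t}/train.parquet"
        trainer.train()
        # ... then evaluate on the sessions of day t + 1
        trainer.eval_dataset_or_path = f"{input_dir}/{t + 1}/valid.parquet"
        metrics = trainer.evaluate(metric_key_prefix="eval")
        for name, value in metrics.items():
            results.setdefault(f"indexed_by_time_{name}", []).append(value)
    return results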

from transformers4rec.torch.utils.examples_utils import fit_and_evaluate
OT_results = fit_and_evaluate(recsys_trainer, start_time_index=178, end_time_index=180, input_dir=OUTPUT_DIR)
***** Launch training for day 178: *****
***** Running training *****
  Num examples = 28800
  Num Epochs = 10
  Instantaneous batch size per device = 384
  Total train batch size (w. parallel, distributed & accumulation) = 768
  Gradient Accumulation steps = 1
  Total optimization steps = 750
/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
[750/750 01:41, Epoch 10/10]
Step Training Loss
200 7.712600
400 6.600600
600 6.329100

Saving model checkpoint to ./tmp/checkpoint-500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '


Training completed. Do not forget to share your model on huggingface.co/models =)
[6/6 02:14]
***** Running training *****
  Num examples = 20736
  Num Epochs = 10
  Instantaneous batch size per device = 384
  Total train batch size (w. parallel, distributed & accumulation) = 768
  Gradient Accumulation steps = 1
  Total optimization steps = 540
***** Evaluation results for day 179:*****

 eval_/next-item/avg_precision@10 = 0.08841836452484131
 eval_/next-item/avg_precision@20 = 0.09261580556631088
 eval_/next-item/ndcg@10 = 0.11992594599723816
 eval_/next-item/ndcg@20 = 0.1354396641254425
 eval_/next-item/recall@10 = 0.2204238921403885
 eval_/next-item/recall@20 = 0.28131020069122314

***** Launch training for day 179: *****
[540/540 01:12, Epoch 10/10]
Step Training Loss
200 6.902200
400 6.516700

Saving model checkpoint to ./tmp/checkpoint-500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '


Training completed. Do not forget to share your model on huggingface.co/models =)


***** Running training *****
  Num examples = 16896
  Num Epochs = 10
  Instantaneous batch size per device = 384
  Total train batch size (w. parallel, distributed & accumulation) = 768
  Gradient Accumulation steps = 1
  Total optimization steps = 440
***** Evaluation results for day 180:*****

 eval_/next-item/avg_precision@10 = 0.05934491753578186
 eval_/next-item/avg_precision@20 = 0.06326105445623398
 eval_/next-item/ndcg@10 = 0.08262269198894501
 eval_/next-item/ndcg@20 = 0.09686349332332611
 eval_/next-item/recall@10 = 0.15617716312408447
 eval_/next-item/recall@20 = 0.21212121844291687

***** Launch training for day 180: *****
[440/440 00:59, Epoch 10/10]
Step Training Loss
200 6.998900
400 6.566000

Training completed. Do not forget to share your model on huggingface.co/models =)
***** Evaluation results for day 181:*****

 eval_/next-item/avg_precision@10 = 0.13744957745075226
 eval_/next-item/avg_precision@20 = 0.14369292557239532
 eval_/next-item/ndcg@10 = 0.17851395905017853
 eval_/next-item/ndcg@20 = 0.20183274149894714
 eval_/next-item/recall@10 = 0.31168830394744873
 eval_/next-item/recall@20 = 0.40630799531936646

Visualize the average of metrics over time

OT_results is a dictionary of evaluation scores (ranking metrics) collected over the given start and end time_index. Since in this example we evaluate on days 179, 180, and 181, each metric's list contains three values, one for each day.

OT_results
{'indexed_by_time_eval_/next-item/avg_precision@10': [0.08841836452484131,
  0.05934491753578186,
  0.13744957745075226],
 'indexed_by_time_eval_/next-item/avg_precision@20': [0.09261580556631088,
  0.06326105445623398,
  0.14369292557239532],
 'indexed_by_time_eval_/next-item/ndcg@10': [0.11992594599723816,
  0.08262269198894501,
  0.17851395905017853],
 'indexed_by_time_eval_/next-item/ndcg@20': [0.1354396641254425,
  0.09686349332332611,
  0.20183274149894714],
 'indexed_by_time_eval_/next-item/recall@10': [0.2204238921403885,
  0.15617716312408447,
  0.31168830394744873],
 'indexed_by_time_eval_/next-item/recall@20': [0.28131020069122314,
  0.21212121844291687,
  0.40630799531936646]}
import numpy as np
# take the average of metric values over time
avg_results = {k: np.mean(v) for k, v in OT_results.items()}
for key in sorted(avg_results.keys()):
    print(" %s = %s" % (key, str(avg_results[key])))
 indexed_by_time_eval_/next-item/avg_precision@10 = 0.09507095317045848
 indexed_by_time_eval_/next-item/avg_precision@20 = 0.0998565951983134
 indexed_by_time_eval_/next-item/ndcg@10 = 0.12702086567878723
 indexed_by_time_eval_/next-item/ndcg@20 = 0.14471196631590524
 indexed_by_time_eval_/next-item/recall@10 = 0.2294297864039739
 indexed_by_time_eval_/next-item/recall@20 = 0.2999131381511688

Save the model

recsys_trainer._save_model_and_checkpoint(save_model_class=True)
Saving model checkpoint to ./tmp/checkpoint-440
Trainer.model is not a `PreTrainedModel`, only saving its state dict.

Export the preprocessing workflow and model in the format required by Triton Inference Server:

NVTabular's export_pytorch_ensemble() function enables us to create the model files and config files to be served to Triton Inference Server. The sparse_features_max dictionary below maps every list feature to its maximum sequence length (20, the session length used during preprocessing); it is passed as sparse_max so that the ragged list columns can be represented as fixed-length inputs when serving.

x_cat_names = ['item_id-list', 'category-list']
x_cont_names = ['product_recency_days_log_norm-list', 'et_dayofweek_sin-list']

# Maximum sequence length for each list feature (the session length used in preprocessing)
sparse_features_max = {
    fname: 20
    for fname in x_cat_names + x_cont_names
}

sparse_features_max
{'item_id-list': 20,
 'category-list': 20,
 'product_recency_days_log_norm-list': 20,
 'et_dayofweek_sin-list': 20}
from nvtabular.inference.triton import export_pytorch_ensemble
from nvtabular.workflow import Workflow

workflow = Workflow.load(os.path.join(INPUT_DATA_DIR, "workflow_etl"))
model_path = os.path.join(INPUT_DATA_DIR, "models")

export_pytorch_ensemble(
    model,
    workflow,
    sparse_max=sparse_features_max,
    name="t4r_pytorch",
    model_path=model_path,
    label_columns=[],
)
/usr/local/lib/python3.8/dist-packages/nvtabular/inference/triton/ensemble.py:171: UserWarning: PyTorch model expects int64 for column session_id, but workflow  is producing type DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))). Overriding dtype in NVTabular workflow.
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/nvtabular/inference/triton/ensemble.py:171: UserWarning: PyTorch model expects int32 for column item_id-count, but workflow  is producing type DType(name='int32', element_type=<ElementType.Int: 'int'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))). Overriding dtype in NVTabular workflow.
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/nvtabular/inference/triton/ensemble.py:171: UserWarning: PyTorch model expects int64 for column item_id-list, but workflow  is producing type DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))). Overriding dtype in NVTabular workflow.
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/nvtabular/inference/triton/ensemble.py:171: UserWarning: PyTorch model expects float32 for column et_dayofweek_sin-list, but workflow  is producing type DType(name='float32', element_type=<ElementType.Float: 'float'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))). Overriding dtype in NVTabular workflow.
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/nvtabular/inference/triton/ensemble.py:171: UserWarning: PyTorch model expects float32 for column product_recency_days_log_norm-list, but workflow  is producing type DType(name='float32', element_type=<ElementType.Float: 'float'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))). Overriding dtype in NVTabular workflow.
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/nvtabular/inference/triton/ensemble.py:171: UserWarning: PyTorch model expects int64 for column category-list, but workflow  is producing type DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))). Overriding dtype in NVTabular workflow.
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/nvtabular/inference/triton/ensemble.py:171: UserWarning: PyTorch model expects int64 for column day_index, but workflow  is producing type DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=None)). Overriding dtype in NVTabular workflow.
  warnings.warn(

2. Serving Ensemble Model to the Triton Inference Server

NVIDIA Triton Inference Server (TIS) simplifies the deployment of AI models at scale in production. TIS provides a cloud and edge inferencing solution optimized for both CPUs and GPUs. It supports a number of different machine learning frameworks such as TensorFlow and PyTorch.

The last step of a machine learning (ML)/deep learning (DL) pipeline is to deploy the ETL workflow and the saved model to production. In production, we want to transform the input data exactly as it was transformed during training (ETL): we need to apply the same mean/std normalization to continuous features and the same categorical mapping to convert categories to contiguous integer IDs before the DL model makes a prediction. Therefore, we deploy the NVTabular workflow together with the PyTorch model as an ensemble model to Triton Inference Server. The ensemble (named t4r_pytorch below) chains the workflow stage (t4r_pytorch_nvt) and the model stage (t4r_pytorch_pt), guaranteeing that the same transformation is applied to the raw inputs.

In this section, you will learn how to:

  • deploy the saved NVTabular workflow and PyTorch model to Triton Inference Server

  • send requests for predictions and get responses.

2.1. Pull and Start Inference Container

At this point, we start the Triton Inference Server (TIS).

Start Triton server
You can start the Triton server with the command below; be sure to provide the correct path to the models folder.

tritonserver --model-repository=<path_to_models> --model-control-mode=explicit

Note: The model-repository path for our example is /workspace/data/models/. Because the server runs with --model-control-mode=explicit, the models have not been loaded yet. Below, we will request the Triton server to load the saved ensemble model.

2.2. Connect to the Triton Inference Server and check if the server is alive

import tritonhttpclient
try:
    triton_client = tritonhttpclient.InferenceServerClient(url="localhost:8000", verbose=True)
    print("client created.")
except Exception as e:
    print("channel creation failed: " + str(e))
triton_client.is_server_live()
client created.
GET /v2/health/live, headers None
<HTTPSocketPoolResponse status=200 headers={'content-length': '0', 'content-type': 'text/plain'}>
True

2.3. Load raw data for inference

We select the last 50 interactions and filter out sessions with fewer than 2 interactions.

import pandas as pd

# Load the raw interactions and keep the 50 most recent ones
interactions_merged_df = pd.read_parquet(os.path.join(INPUT_DATA_DIR, "interactions_merged_df.parquet"))
interactions_merged_df = interactions_merged_df.sort_values('timestamp')
batch = interactions_merged_df[-50:]

# Keep only sessions with at least two interactions
sessions_to_use = batch.session_id.value_counts()
filtered_batch = batch[batch.session_id.isin(sessions_to_use[sessions_to_use.values > 1].index.values)]

2.4. Load the ensemble model to Triton

The models should load successfully before we send a request to TIS. If all models are loaded successfully, you should see a "successfully loaded" status next to each model name in your terminal.

triton_client.load_model(model_name="t4r_pytorch")
POST /v2/repository/models/t4r_pytorch/load, headers None
{}
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '0'}>
Loaded model 't4r_pytorch'

2.5. Send the request to the Triton server

triton_client.get_model_repository_index()
POST /v2/repository/index, headers None

<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '167'}>
bytearray(b'[{"name":"t4r_pytorch","version":"1","state":"READY"},{"name":"t4r_pytorch_nvt","version":"1","state":"READY"},{"name":"t4r_pytorch_pt","version":"1","state":"READY"}]')
[{'name': 't4r_pytorch', 'version': '1', 'state': 'READY'},
 {'name': 't4r_pytorch_nvt', 'version': '1', 'state': 'READY'},
 {'name': 't4r_pytorch_pt', 'version': '1', 'state': 'READY'}]

If all models were loaded successfully, you should see a READY state next to each model.

import nvtabular.inference.triton as nvt_triton
import tritonclient.grpc as grpcclient

# Convert the filtered dataframe into Triton input tensors
inputs = nvt_triton.convert_df_to_triton_input(filtered_batch.columns, filtered_batch, grpcclient.InferInput)

output_names = ["output"]
outputs = [grpcclient.InferRequestedOutput(col) for col in output_names]

MODEL_NAME_NVT = "t4r_pytorch"

# Send the request to the ensemble model and print the returned logits
with grpcclient.InferenceServerClient("localhost:8001") as client:
    response = client.infer(MODEL_NAME_NVT, inputs, outputs=outputs)
    for col in output_names:
        print(col, ':\n', response.as_numpy(col))
output :
 [[-12.359397  -12.863064   -8.659327  ... -12.52291   -13.386017
  -12.252247 ]
 [-16.44956   -16.089582   -8.681267  ... -17.113033  -18.37918
  -16.107119 ]
 [-13.414572  -13.681402   -8.590441  ... -14.162938  -15.2169285
  -13.981831 ]
 ...
 [-17.573406  -16.371202   -9.535273  ... -17.32026   -18.294775
  -16.001776 ]
 [-13.201723  -13.092476   -8.6168995 ... -13.6306    -14.307053
  -12.9177685]
 [-17.229147  -16.887306   -9.369176  ... -17.403896  -18.368675
  -16.291914 ]]
  • Visualize top-k predictions

from transformers4rec.torch.utils.examples_utils import visualize_response
visualize_response(filtered_batch, response, top_k=5, session_col='session_id')
- Top-5 predictions for session `11457123`: 1761 || 186 || 2651 || 2383 || 1987

- Top-5 predictions for session `11467406`: 4136 || 224 || 2774 || 2693 || 2759

- Top-5 predictions for session `11528554`: 135 || 183 || 1697 || 1359 || 1340

- Top-5 predictions for session `11336059`: 2556 || 2651 || 186 || 6989 || 7284

- Top-5 predictions for session `11445777`: 2789 || 5591 || 2891 || 2759 || 4541

- Top-5 predictions for session `11493827`: 6510 || 4136 || 4204 || 4155 || 4541

- Top-5 predictions for session `11425751`: 2788 || 2556 || 224 || 1987 || 2050

- Top-5 predictions for session `11399751`: 3841 || 2214 || 2556 || 224 || 2651

- Top-5 predictions for session `11311424`: 6461 || 4713 || 4136 || 9285 || 4155

- Top-5 predictions for session `11257991`: 5932 || 620 || 1334 || 633 || 224

- Top-5 predictions for session `11561822`: 2956 || 6488 || 8084 || 4713 || 4136

- Top-5 predictions for session `11421333`: 224 || 9285 || 6488 || 11389 || 8084

- Top-5 predictions for session `11270119`: 5932 || 6488 || 4136 || 4541 || 4204

- Top-5 predictions for session `11401481`: 4204 || 11389 || 6488 || 4136 || 224

- Top-5 predictions for session `11394056`: 2759 || 5591 || 4541 || 4204 || 4136

As you noticed, we first obtained the prediction results (logits) from the trained model head, and then, using the handy utility function visualize_response, we extracted the top-k encoded item ids from those logits. In other words, we generated recommended items for each session.
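If you want to inspect the raw predictions yourself, the top-k item ids can be recovered directly from the logits. A minimal sketch (assuming, as visualize_response does, that each column index of the logit matrix corresponds to an encoded item id):

import numpy as np

logits = response.as_numpy("output")  # shape: (num_sessions, num_items)
top_k = 5
# Sort each row by descending score and keep the first top_k column indices
top_k_items = np.argsort(-logits, axis=1)[:, :top_k]
print(top_k_items)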

This is the end of the tutorial. You successfully

  • performed feature engineering with NVTabular

  • trained a Transformer-based session-based recommendation model with Transformers4Rec

  • deployed the trained model to Triton Inference Server, sent requests, and received responses from the server.

Unload models

triton_client.unload_model(model_name="t4r_pytorch")
triton_client.unload_model(model_name="t4r_pytorch_nvt")
triton_client.unload_model(model_name="t4r_pytorch_pt")
POST /v2/repository/models/t4r_pytorch/unload, headers None
{"parameters":{"unload_dependents":false}}
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '0'}>
Loaded model 't4r_pytorch'
POST /v2/repository/models/t4r_pytorch_nvt/unload, headers None
{"parameters":{"unload_dependents":false}}
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '0'}>
Loaded model 't4r_pytorch_nvt'
POST /v2/repository/models/t4r_pytorch_pt/unload, headers None
{"parameters":{"unload_dependents":false}}
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '0'}>
Loaded model 't4r_pytorch_pt'