# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
End-to-end session-based recommendation with Transformers4Rec
In recent years, several deep learning-based algorithms have been proposed for recommendation systems, and their adoption in industry deployments has grown steeply. In particular, NLP-inspired approaches have been successfully adapted for sequential and session-based recommendation problems, which are important for many domains like e-commerce, news and streaming media. Session-Based Recommender Systems (SBRS) model the sequence of interactions within the current user session, where a session is a short sequence of user interactions typically bounded by user inactivity. They have recently gained popularity due to their ability to capture short-term or contextual user preferences towards items.
The field of NLP has evolved significantly within the last decade, particularly due to the increased usage of deep learning. As a result, state-of-the-art NLP approaches have inspired RecSys practitioners and researchers to adapt those architectures, especially for sequential and session-based recommendation problems. Here, we leverage one of the state-of-the-art Transformer-based architectures, XLNet, with the Masked Language Modeling (MLM) training technique (see our tutorial for details) to train a session-based model.
In this end-to-end session-based recommender model example, we use the Transformers4Rec library, which leverages the popular HuggingFace Transformers NLP library and makes it possible to experiment with cutting-edge implementations of such architectures for sequential and session-based recommendation problems. For detailed explanations of the building blocks of the Transformers4Rec meta-architecture, visit the getting-started-session-based and tutorial example notebooks.
3. Model definition using Transformers4Rec
In the previous notebook, we created sequential features and saved the processed data frames as parquet files. We now use these parquet files to train a session-based recommendation model with the XLNet architecture.
3.1 Get the schema
The library uses a schema format to configure the input features and automatically creates the necessary layers. This protobuf text file describes each input feature by defining the name, the type, the number of elements of a list column, the cardinality of a categorical feature, and the min and max values of each feature. In addition, the annotation field contains tags that mark, for example, continuous and categorical features, the target column, or the item_id feature.
from merlin_standard_lib import Schema
SCHEMA_PATH = "schema_demo.pb"
schema = Schema().from_proto_text(SCHEMA_PATH)
!cat $SCHEMA_PATH
feature {
name: "session_id"
type: INT
int_domain {
name: "session_id"
min: 1
max: 9249733
is_categorical: false
}
annotation {
tag: "groupby_col"
}
}
feature {
name: "item_id-list_seq"
value_count {
min: 2
max: 185
}
type: INT
int_domain {
name: "item_id/list"
min: 1
max: 52742
is_categorical: true
}
annotation {
tag: "item_id"
tag: "list"
tag: "categorical"
tag: "item"
}
}
feature {
name: "category-list_seq"
value_count {
min: 2
max: 185
}
type: INT
int_domain {
name: "category-list_seq"
min: 1
max: 337
is_categorical: true
}
annotation {
tag: "list"
tag: "categorical"
tag: "item"
}
}
feature {
name: "product_recency_days_log_norm-list_seq"
value_count {
min: 2
max: 185
}
type: FLOAT
float_domain {
name: "product_recency_days_log_norm-list_seq"
min: -2.9177291
max: 1.5231701
}
annotation {
tag: "continuous"
tag: "list"
}
}
feature {
name: "et_dayofweek_sin-list_seq"
value_count {
min: 2
max: 185
}
type: FLOAT
float_domain {
name: "et_dayofweek_sin-list_seq"
min: 0.7421683
max: 0.9995285
}
annotation {
tag: "continuous"
tag: "time"
tag: "list"
}
}
We can select the subset of features we want to use for training the model by their tags or their names.
schema = schema.select_by_name(
['item_id-list_seq', 'category-list_seq', 'product_recency_days_log_norm-list_seq', 'et_dayofweek_sin-list_seq']
)
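To make the idea of selecting features by tag or by name concrete, here is a minimal pure-Python sketch. It does not use the library: the `FEATURES` list and the helpers `pick_by_name` and `pick_by_tag` are hypothetical stand-ins that only mirror the names and tags printed in the schema above.

```python
# Illustrative stand-in for a schema: one dict per feature, mirroring the
# names and tags printed above. This is NOT the merlin_standard_lib API.
FEATURES = [
    {"name": "session_id", "tags": {"groupby_col"}},
    {"name": "item_id-list_seq", "tags": {"item_id", "list", "categorical", "item"}},
    {"name": "category-list_seq", "tags": {"list", "categorical", "item"}},
    {"name": "product_recency_days_log_norm-list_seq", "tags": {"continuous", "list"}},
    {"name": "et_dayofweek_sin-list_seq", "tags": {"continuous", "time", "list"}},
]


def pick_by_name(features, names):
    """Keep the features whose name is listed, preserving schema order."""
    wanted = set(names)
    return [f for f in features if f["name"] in wanted]


def pick_by_tag(features, tag):
    """Keep the features annotated with the given tag."""
    return [f for f in features if tag in f["tags"]]


categorical = [f["name"] for f in pick_by_tag(FEATURES, "categorical")]
continuous = [f["name"] for f in pick_by_tag(FEATURES, "continuous")]
print(categorical)
print(continuous)
```

Selecting by tag is convenient when the feature list changes often, while selecting by name, as done in the cell above, makes the model inputs explicit.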
3.2 Define the end-to-end Session-based Transformer-based recommendation model
Defining the end-to-end session-based recommendation model requires four steps:
First, instantiate the TabularSequenceFeatures input module from the schema to prepare the embedding tables of categorical variables and project continuous features, if specified. In addition, the module provides different aggregation methods (e.g. ‘concat’, ‘elementwise-sum’) to merge input features and generate the sequence of interaction embeddings. The module also supports language modeling tasks to prepare masked labels for training and evaluation (e.g. ‘mlm’ for masked language modeling).
Next, define one or multiple prediction tasks. For this demo, we use NextItemPredictionTask with masked language modeling: during training, randomly selected items are masked and predicted using the unmasked sequence items. For inference, it is meant to always predict the next item to be interacted with.
Then, construct a transformer_config based on the architectures provided by the Hugging Face Transformers framework.
Finally, link the transformer body to the inputs and the prediction tasks to get the final PyTorch Model class.
For more details about the features supported by each sub-module, please check out the library documentation page.
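The masking idea behind the prediction task can be illustrated with a small standalone sketch. It is not the library's implementation: the `MASK` placeholder, the `mask_prob` default, and the rule of hiding only the last item when scoring a held-out session are assumptions made for illustration.

```python
import random

MASK = "[MASK]"  # hypothetical placeholder for a hidden item


def mask_for_training(session, mask_prob=0.2, rng=random):
    """Randomly hide items, keeping at least one hidden and one visible item."""
    if len(session) < 2:
        raise ValueError("a session needs at least 2 interactions")
    hidden = [rng.random() < mask_prob for _ in session]
    if not any(hidden):
        hidden[rng.randrange(len(session))] = True
    if all(hidden):
        hidden[rng.randrange(len(session))] = False
    inputs = [MASK if h else item for item, h in zip(session, hidden)]
    labels = [item if h else None for item, h in zip(session, hidden)]
    return inputs, labels


def mask_last_item(session):
    """Hide only the last item, so the model has to predict the next interaction."""
    if len(session) < 2:
        raise ValueError("a session needs at least 2 interactions")
    inputs = session[:-1] + [MASK]
    labels = [None] * (len(session) - 1) + [session[-1]]
    return inputs, labels


train_inputs, train_labels = mask_for_training([10, 20, 30, 40], rng=random.Random(0))
eval_inputs, eval_labels = mask_last_item([10, 20, 30, 40])
print(train_inputs, train_labels)
print(eval_inputs, eval_labels)
```

Labels are `None` at visible positions, so a loss would only be computed where an item was hidden.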
from transformers4rec import torch as tr

max_sequence_length, d_model = 20, 320

# Define input module to process tabular input-features and to prepare masked inputs
input_module = tr.TabularSequenceFeatures.from_schema(
    schema,
    max_sequence_length=max_sequence_length,
    continuous_projection=64,
    aggregation="concat",
    d_output=d_model,
    masking="mlm",
)

# Define Next item prediction-task
prediction_task = tr.NextItemPredictionTask(hf_format=True, weight_tying=True)

# Define the config of the XLNet Transformer architecture
transformer_config = tr.XLNetConfig.build(
    d_model=d_model, n_head=8, n_layer=2, total_seq_length=max_sequence_length
)

# Get the end-to-end model
model = transformer_config.to_torch_model(input_module, prediction_task)
Projecting inputs of NextItemPredictionTask to'64' As weight tying requires the input dimension '320' to be equal to the item-id embedding dimension '64'
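The message above reflects weight tying: the item scores are computed against the item-id embedding table itself, so the 320-dimensional transformer output first has to be projected to the 64-dimensional embedding space. The toy sketch below illustrates that flow with tiny made-up sizes (4 instead of 320, 2 instead of 64, 3 items); all values are illustrative assumptions, not weights from the model.

```python
import math


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def relu(vector):
    return [max(0.0, x) for x in vector]


def log_softmax(scores):
    """Numerically stable log-softmax over a list of scores."""
    top = max(scores)
    log_norm = top + math.log(sum(math.exp(s - top) for s in scores))
    return [s - log_norm for s in scores]


# Toy item embedding table: 3 items x 2 dimensions (shared with the input side).
item_embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# Toy projection from the hidden size (4) down to the embedding size (2).
projection = [[0.5, 0.0, 0.5, 0.0], [0.0, 0.5, 0.0, 0.5]]

hidden = [1.0, 2.0, 3.0, 4.0]                    # transformer output at one position
projected = relu(matvec(projection, hidden))     # now matches the embedding size
scores = matvec(item_embeddings, projected)      # tied weights: reuse the embedding table
log_probs = log_softmax(scores)
print(projected, scores)
```

Because the output layer reuses the embedding table, no separate item-sized output matrix has to be learned.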
model
Model(
(heads): ModuleList(
(0): Head(
(body): SequentialBlock(
(0): TabularSequenceFeatures(
(_aggregation): ConcatFeatures()
(to_merge): ModuleDict(
(continuous_module): SequentialBlock(
(0): ContinuousFeatures(
(filter_features): FilterFeatures()
(_aggregation): ConcatFeatures()
)
(1): SequentialBlock(
(0): DenseBlock(
(0): Linear(in_features=2, out_features=64, bias=True)
(1): ReLU(inplace=True)
)
)
(2): AsTabular()
)
(categorical_module): SequenceEmbeddingFeatures(
(filter_features): FilterFeatures()
(embedding_tables): ModuleDict(
(item_id-list_seq): Embedding(52743, 64, padding_idx=0)
(category-list_seq): Embedding(338, 64, padding_idx=0)
)
)
)
(projection_module): SequentialBlock(
(0): DenseBlock(
(0): Linear(in_features=192, out_features=320, bias=True)
(1): ReLU(inplace=True)
)
)
(_masking): MaskedLanguageModeling()
)
(1): TansformerBlock(
(transformer): XLNetModel(
(word_embedding): Embedding(1, 320)
(layer): ModuleList(
(0): XLNetLayer(
(rel_attn): XLNetRelativeAttention(
(layer_norm): LayerNorm((320,), eps=0.03, elementwise_affine=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(ff): XLNetFeedForward(
(layer_norm): LayerNorm((320,), eps=0.03, elementwise_affine=True)
(layer_1): Linear(in_features=320, out_features=1280, bias=True)
(layer_2): Linear(in_features=1280, out_features=320, bias=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(dropout): Dropout(p=0.3, inplace=False)
)
(1): XLNetLayer(
(rel_attn): XLNetRelativeAttention(
(layer_norm): LayerNorm((320,), eps=0.03, elementwise_affine=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(ff): XLNetFeedForward(
(layer_norm): LayerNorm((320,), eps=0.03, elementwise_affine=True)
(layer_1): Linear(in_features=320, out_features=1280, bias=True)
(layer_2): Linear(in_features=1280, out_features=320, bias=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(dropout): Dropout(p=0.3, inplace=False)
)
)
(dropout): Dropout(p=0.3, inplace=False)
)
(masking): MaskedLanguageModeling()
)
)
(prediction_task_dict): ModuleDict(
(next-item): NextItemPredictionTask(
(sequence_summary): SequenceSummary(
(summary): Identity()
(activation): Identity()
(first_dropout): Identity()
(last_dropout): Identity()
)
(metrics): ModuleList(
(0): NDCGAt()
(1): AvgPrecisionAt()
(2): RecallAt()
)
(loss): NLLLoss()
(embeddings): SequenceEmbeddingFeatures(
(filter_features): FilterFeatures()
(embedding_tables): ModuleDict(
(item_id-list_seq): Embedding(52743, 64, padding_idx=0)
(category-list_seq): Embedding(338, 64, padding_idx=0)
)
)
(item_embedding_table): Embedding(52743, 64, padding_idx=0)
(masking): MaskedLanguageModeling()
(task_block): SequentialBlock(
(0): DenseBlock(
(0): Linear(in_features=320, out_features=64, bias=True)
(1): ReLU(inplace=True)
)
)
(pre): Block(
(module): NextItemPredictionTask(
(item_embedding_table): Embedding(52743, 64, padding_idx=0)
(log_softmax): LogSoftmax(dim=-1)
)
)
)
)
)
)
)
3.3. Daily Fine-Tuning: Training over a time window
Now that the model is defined, we are going to launch training. For that, Transformers4Rec extends the HF Transformers Trainer class to adapt the evaluation loop to the session-based recommendation task and to the calculation of ranking metrics. The original train() method is not modified, meaning that we leverage the efficient training implementation from that library, which manages, for example, half-precision (FP16) training.
Set training arguments
An additional argument, data_loader_engine, is defined to automatically load the features needed for training using the schema. The default value is nvtabular, for optimized GPU-based data loading. Optionally, a PyarrowDataLoader (pyarrow) can be used as a basic option, but it is slower and works only for small datasets, as the full data is loaded into CPU memory.
training_args = tr.trainer.T4RecTrainingArguments(
    output_dir="./tmp",
    max_sequence_length=20,
    data_loader_engine='nvtabular',
    num_train_epochs=10,
    dataloader_drop_last=False,
    per_device_train_batch_size=384,
    per_device_eval_batch_size=512,
    learning_rate=0.0005,
    fp16=True,
    report_to=[],
    logging_steps=200,
)
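As a quick sanity check on these arguments, the number of optimization steps reported in the training logs can be reproduced by hand. The sketch below assumes a single device and one gradient accumulation step, as the logs state; the helper name `total_optimization_steps` is hypothetical.

```python
def total_optimization_steps(num_examples, batch_size, num_epochs, drop_last=False):
    """Steps per epoch (last partial batch kept unless drop_last) times epochs."""
    full_batches, remainder = divmod(num_examples, batch_size)
    steps_per_epoch = full_batches + (1 if remainder and not drop_last else 0)
    return steps_per_epoch * num_epochs


# Example counts taken from the training logs of this notebook.
print(total_optimization_steps(28800, 384, 10))  # 750
print(total_optimization_steps(20736, 384, 10))  # 540
print(total_optimization_steps(16896, 384, 10))  # 440
```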
Instantiate the trainer
recsys_trainer = tr.Trainer(
    model=model,
    args=training_args,
    schema=schema,
    compute_metrics=True,
)
Using amp fp16 backend
Launch daily training and evaluation
In this demo, we use the fit_and_evaluate method, which performs time-based fine-tuning by iteratively training and evaluating over a sliding time window: at each iteration, we train the model on the data of a specific time index \(t\), then evaluate on the validation data of the next index \(t + 1\). Here, we set the start time index to 178 and the end time index to 180.
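The sliding-window schedule described above can be sketched in a few lines. This is only an illustration of the control flow, not the code of fit_and_evaluate: `train_fn` and `eval_fn` are hypothetical callables standing in for the trainer, and the metric values are dummies.

```python
def sliding_window_fit_and_evaluate(train_fn, eval_fn, start_time_index, end_time_index):
    """Train on time index t, evaluate on t + 1, and collect metrics per time index."""
    results = {}
    for t in range(start_time_index, end_time_index + 1):
        train_fn(t)                      # fine-tune on the data of day t
        metrics = eval_fn(t + 1)         # evaluate on the data of day t + 1
        for name, value in metrics.items():
            results.setdefault(name, []).append(value)
    return results


# Stub callables that only record the order of calls.
calls = []


def train_fn(t):
    calls.append(("train", t))


def eval_fn(t):
    calls.append(("eval", t))
    return {"recall@10": 0.0}  # dummy metric value


results = sliding_window_fit_and_evaluate(train_fn, eval_fn, 178, 180)
print(calls)
```

With a start index of 178 and an end index of 180, the loop trains on days 178 to 180 and evaluates on days 179 to 181, which matches the evaluation days reported in the logs.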
from transformers4rec.torch.utils.examples_utils import fit_and_evaluate
OT_results = fit_and_evaluate(recsys_trainer, start_time_index=178, end_time_index=180, input_dir='./preproc_sessions_by_day')
***** Running training *****
Num examples = 28800
Num Epochs = 10
Instantaneous batch size per device = 384
Total train batch size (w. parallel, distributed & accumulation) = 384
Gradient Accumulation steps = 1
Total optimization steps = 750
***** Launch training for day 178: *****
Step | Training Loss |
---|---|
200 | 7.721300 |
400 | 6.645300 |
600 | 6.346600 |
Saving model checkpoint to ./tmp/checkpoint-500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Training completed. Do not forget to share your model on huggingface.co/models =)
***** Evaluation results for day 179:*****
eval/next-item/avg_precision@10 = 0.08169878274202347
eval/next-item/avg_precision@20 = 0.08574782311916351
eval/next-item/ndcg@10 = 0.11046259850263596
eval/next-item/ndcg@20 = 0.12532013654708862
eval/next-item/recall@10 = 0.20192678272724152
eval/next-item/recall@20 = 0.2608863115310669
***** Launch training for day 179: *****
***** Running training *****
Num examples = 20736
Num Epochs = 10
Instantaneous batch size per device = 384
Total train batch size (w. parallel, distributed & accumulation) = 384
Gradient Accumulation steps = 1
Total optimization steps = 540
Step | Training Loss |
---|---|
200 | 6.845400 |
400 | 6.425500 |
Saving model checkpoint to ./tmp/checkpoint-500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Training completed. Do not forget to share your model on huggingface.co/models =)
***** Running training *****
Num examples = 16896
Num Epochs = 10
Instantaneous batch size per device = 384
Total train batch size (w. parallel, distributed & accumulation) = 384
Gradient Accumulation steps = 1
Total optimization steps = 440
***** Evaluation results for day 180:*****
eval/next-item/avg_precision@10 = 0.06733544170856476
eval/next-item/avg_precision@20 = 0.07138364017009735
eval/next-item/ndcg@10 = 0.09150198847055435
eval/next-item/ndcg@20 = 0.10659514367580414
eval/next-item/recall@10 = 0.16969697177410126
eval/next-item/recall@20 = 0.2293706238269806
***** Launch training for day 180: *****
Step | Training Loss |
---|---|
200 | 6.880100 |
400 | 6.469300 |
Training completed. Do not forget to share your model on huggingface.co/models =)
***** Evaluation results for day 181:*****
eval/next-item/avg_precision@10 = 0.12588053941726685
eval/next-item/avg_precision@20 = 0.13366079330444336
eval/next-item/ndcg@10 = 0.17167863249778748
eval/next-item/ndcg@20 = 0.200278639793396
eval/next-item/recall@10 = 0.3200370967388153
eval/next-item/recall@20 = 0.4313543736934662
Visualize the average over time metrics
OT_results is a dictionary that maps each evaluation metric to a list of scores, one per evaluated time index between the given start and end time_index. Since this example evaluates on days 179, 180 and 181, each list contains three values, one for each day.
OT_results
{'indexed_by_time_eval/next-item/avg_precision@10': [0.08169878274202347,
0.06733544170856476,
0.12588053941726685],
'indexed_by_time_eval/next-item/avg_precision@20': [0.08574782311916351,
0.07138364017009735,
0.13366079330444336],
'indexed_by_time_eval/next-item/ndcg@10': [0.11046259850263596,
0.09150198847055435,
0.17167863249778748],
'indexed_by_time_eval/next-item/ndcg@20': [0.12532013654708862,
0.10659514367580414,
0.200278639793396],
'indexed_by_time_eval/next-item/recall@10': [0.20192678272724152,
0.16969697177410126,
0.3200370967388153],
'indexed_by_time_eval/next-item/recall@20': [0.2608863115310669,
0.2293706238269806,
0.4313543736934662]}
import numpy as np
# take the average of metric values over time
avg_results = {k: np.mean(v) for k,v in OT_results.items()}
for key in sorted(avg_results.keys()):
    print(" %s = %s" % (key, str(avg_results[key])))
indexed_by_time_eval/next-item/avg_precision@10 = 0.09163825462261836
indexed_by_time_eval/next-item/avg_precision@20 = 0.09693075219790141
indexed_by_time_eval/next-item/ndcg@10 = 0.12454773982365926
indexed_by_time_eval/next-item/ndcg@20 = 0.1440646400054296
indexed_by_time_eval/next-item/recall@10 = 0.2305536170800527
indexed_by_time_eval/next-item/recall@20 = 0.3072037696838379
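For readers unfamiliar with these metric names, the sketch below shows the usual top-k definitions for the case where each session has exactly one relevant item (the item to predict). This single-target setting and the helper names are assumptions; the library's exact implementation may differ in details.

```python
import math


def rank_in_top_k(target, ranked_items, k):
    """1-based rank of the target within the top-k list, or None if it is absent."""
    top_k = ranked_items[:k]
    return top_k.index(target) + 1 if target in top_k else None


def recall_at_k(target, ranked_items, k):
    return 1.0 if rank_in_top_k(target, ranked_items, k) else 0.0


def ndcg_at_k(target, ranked_items, k):
    rank = rank_in_top_k(target, ranked_items, k)
    return 1.0 / math.log2(rank + 1) if rank else 0.0


def avg_precision_at_k(target, ranked_items, k):
    rank = rank_in_top_k(target, ranked_items, k)
    return 1.0 / rank if rank else 0.0


ranked = [898, 2214, 1169, 958, 2814]  # a made-up ranked list of item ids
print(recall_at_k(1169, ranked, 5))         # 1.0
print(ndcg_at_k(1169, ranked, 5))           # 0.5
print(avg_precision_at_k(1169, ranked, 2))  # 0.0
```

Recall only checks whether the target appears in the top-k list, while NDCG and average precision also reward placing it closer to the top.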
Save the model
recsys_trainer._save_model_and_checkpoint(save_model_class=True)
Export the preprocessing workflow and model in the format required by Triton server
NVTabular’s export_pytorch_ensemble() function creates the model files and config files to be served by Triton Inference Server.
from nvtabular.inference.triton import export_pytorch_ensemble
from nvtabular.workflow import Workflow

workflow = Workflow.load("workflow_etl")

export_pytorch_ensemble(
    model,
    workflow,
    sparse_max=recsys_trainer.get_train_dataloader().dataset.sparse_max,
    name="t4r_pytorch",
    model_path="/workspace/TF4Rec/models/",
    label_columns=[],
)
4. Serving Ensemble Model to the Triton Inference Server
NVIDIA Triton Inference Server (TIS) simplifies the deployment of AI models at scale in production. TIS provides a cloud and edge inferencing solution optimized for both CPUs and GPUs. It supports a number of different machine learning frameworks such as TensorFlow and PyTorch.
The last step of a machine learning (ML)/deep learning (DL) pipeline is to deploy the ETL workflow and the saved model to production. In the production setting, we want to transform the input data as done during training (ETL). We need to apply the same mean/std for continuous features and use the same categorical mapping to convert the categories to contiguous integers before we use the DL model for a prediction. Therefore, we deploy the NVTabular workflow with the PyTorch model as an ensemble model to Triton Inference Server. The ensemble model guarantees that the same transformation is applied to the raw inputs.
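The reason for shipping the ETL step together with the model can be shown with a small standalone sketch: statistics and category ids are fitted once on training data and then reused unchanged on incoming requests. This is not NVTabular code; the feature names `recency_days` and `category`, the helper names, and the convention of mapping unseen categories to 0 are assumptions for illustration.

```python
import math


def fit_standardizer(values):
    """Compute the mean and standard deviation once, on the training data."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return {"mean": mean, "std": math.sqrt(variance)}


def fit_category_ids(categories):
    """Assign contiguous integer ids starting at 1; 0 is kept for unseen values."""
    ids = {}
    for category in categories:
        ids.setdefault(category, len(ids) + 1)
    return ids


def transform(row, stats, category_ids):
    """Apply the stored training-time statistics and mapping to one raw row."""
    return {
        "recency_days": (row["recency_days"] - stats["mean"]) / stats["std"],
        "category": category_ids.get(row["category"], 0),
    }


# Fit on (made-up) training data, then reuse the fitted state at serving time.
stats = fit_standardizer([1.0, 2.0, 3.0])
category_ids = fit_category_ids(["shoes", "bags", "shoes"])
request = transform({"recency_days": 2.0, "category": "hats"}, stats, category_ids)
print(request)
```

If the serving side recomputed its own statistics or ids, the same raw value would reach the model as a different number than during training, which is what the ensemble avoids.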
In this section, you will learn how to
deploy saved NVTabular and PyTorch models to Triton Inference Server
send requests for predictions and get responses.
4.1. Pull and Start Inference Container
At this point, before connecting to the Triton Server, we launch the inference docker container and then load the ensemble t4r_pytorch to the inference server. This is done with the commands below:
Launch the docker container
docker run -it --gpus device=0 -p 8000:8000 -p 8001:8001 -p 8002:8002 -v <path_to_saved_models>:/workspace/models/ nvcr.io/nvidia/merlin/merlin-inference:21.09
This command mounts your local model-repository folder, which includes the models saved in the previous cell, to the /workspace/models directory in the merlin-inference docker container.
Start triton server
After you have started the merlin-inference container, you can start Triton server with the command below. You need to provide the correct path to the models folder.
tritonserver --model-repository=<path_to_models> --model-control-mode=explicit
Note: The model-repository path for our example is /workspace/models. The models haven’t been loaded yet. Below, we will request the Triton server to load the saved ensemble model.
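The server can take a moment to start, so it may help to wait for it before sending requests. The sketch below polls the server's HTTP readiness endpoint (`/v2/health/ready`) using only the standard library. The helper names, the retry count, and the delay are assumptions; the URL assumes the port mapping from the docker command above.

```python
import time
import urllib.request


def http_ready_probe(base_url="http://localhost:8000", timeout=2.0):
    """Return True if GET /v2/health/ready answers with HTTP 200."""
    try:
        with urllib.request.urlopen(base_url + "/v2/health/ready", timeout=timeout) as resp:
            return resp.status == 200
    except OSError:  # covers refused connections, timeouts and HTTP errors
        return False


def wait_until_ready(probe=http_ready_probe, retries=30, delay_seconds=1.0, sleep=time.sleep):
    """Call the probe until it succeeds or the retries are exhausted."""
    for attempt in range(retries):
        if probe():
            return True
        if attempt < retries - 1:
            sleep(delay_seconds)
    return False
```

Calling `wait_until_ready()` with the defaults polls for roughly half a minute; the probe is passed in as a parameter so that the retry logic can be exercised without a running server.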
Connect to the Triton Inference Server and check if the server is alive
import tritonhttpclient

# Create an HTTP client for the Triton server started above
try:
    triton_client = tritonhttpclient.InferenceServerClient(url="localhost:8000", verbose=True)
    print("client created.")
except Exception as e:
    print("channel creation failed: " + str(e))

# Returns True when the server is live
triton_client.is_server_live()
client created.
GET /v2/health/live, headers None
<HTTPSocketPoolResponse status=200 headers={'content-length': '0', 'content-type': 'text/plain'}>
True
Load raw data for inference
We select the last 50 interactions and filter out sessions with fewer than 2 interactions.
import pandas as pd
interactions_merged_df = pd.read_parquet("/raid/data/yoochoose/interactions_merged_df.parquet")
interactions_merged_df = interactions_merged_df.sort_values('timestamp')
batch = interactions_merged_df[-50:]
sessions_to_use = batch.session_id.value_counts()
filtered_batch = batch[batch.session_id.isin(sessions_to_use[sessions_to_use.values>1].index.values)]
Send the request to triton server
triton_client.get_model_repository_index()
POST /v2/repository/index, headers None
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '77'}>
bytearray(b'[{"name":"t4r_pytorch"},{"name":"t4r_pytorch_nvt"},{"name":"t4r_pytorch_pt"}]')
[{'name': 't4r_pytorch'},
{'name': 't4r_pytorch_nvt'},
{'name': 't4r_pytorch_pt'}]
Load the ensemble model to triton
If all models are loaded successfully, you should see a successfully loaded status next to each model name on your terminal.
triton_client.load_model(model_name="t4r_pytorch")
POST /v2/repository/models/t4r_pytorch/load, headers None
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '0'}>
Loaded model 't4r_pytorch'
import nvtabular.inference.triton as nvt_triton
import tritonclient.grpc as grpcclient

# Convert the raw DataFrame columns into Triton input tensors
inputs = nvt_triton.convert_df_to_triton_input(filtered_batch.columns, filtered_batch, grpcclient.InferInput)

# Request the output tensor of the ensemble model
output_names = ["output"]
outputs = []
for col in output_names:
    outputs.append(grpcclient.InferRequestedOutput(col))

MODEL_NAME_NVT = "t4r_pytorch"

# Send the request over gRPC and print each requested output
with grpcclient.InferenceServerClient("localhost:8001") as client:
    response = client.infer(MODEL_NAME_NVT, inputs, outputs=outputs)
    for col in output_names:
        print(col, ':\n', response.as_numpy(col))
output :
[[-13.53608 -14.02154 -8.15927 ... -13.912806 -14.316951
-13.758053 ]
[-15.546847 -16.178349 -7.881463 ... -16.058243 -17.085182
-15.761725 ]
[-12.496786 -13.111265 -8.736879 ... -13.031836 -13.436075
-12.741135 ]
...
[-14.425283 -14.728777 -7.756508 ... -14.73007 -15.161329
-14.437494 ]
[-15.366516 -15.427296 -7.3262033 ... -15.448423 -15.94982
-15.064197 ]
[-11.908236 -12.42782 -8.78612 ... -12.316145 -12.594669
-12.181059 ]]
Visualise top-k predictions
from transformers4rec.torch.utils.examples_utils import visualize_response
visualize_response(filtered_batch, response, top_k=5, session_col='session_id')
- Top-5 predictions for session `11257991`: 2365 || 260 || 196 || 33 || 1169
- Top-5 predictions for session `11270119`: 898 || 2214 || 1169 || 958 || 2814
- Top-5 predictions for session `11311424`: 898 || 2214 || 1987 || 958 || 1169
- Top-5 predictions for session `11336059`: 260 || 196 || 1169 || 2365 || 33
- Top-5 predictions for session `11394056`: 863 || 2365 || 196 || 33 || 1169
- Top-5 predictions for session `11399751`: 1987 || 2214 || 958 || 2814 || 1169
- Top-5 predictions for session `11401481`: 898 || 157 || 2214 || 620 || 2814
- Top-5 predictions for session `11421333`: 1126 || 196 || 127 || 1169 || 1987
- Top-5 predictions for session `11425751`: 196 || 33 || 958 || 1169 || 1987
- Top-5 predictions for session `11445777`: 33 || 196 || 1987 || 1126 || 1169
- Top-5 predictions for session `11457123`: 184 || 1169 || 313 || 33 || 863
- Top-5 predictions for session `11467406`: 898 || 958 || 2214 || 1169 || 1987
- Top-5 predictions for session `11493827`: 863 || 2365 || 196 || 33 || 1169
- Top-5 predictions for session `11528554`: 1061 || 33 || 863 || 1020 || 1169
- Top-5 predictions for session `11561822`: 127 || 1126 || 196 || 1169 || 1987
As you may have noticed, we first got the prediction results (logits) from the trained model head, and then used the handy utility function visualize_response to extract the top-k encoded item-ids from the logits. In other words, we generated recommended items for a given session.
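Extracting top-k items from one row of scores boils down to picking the k largest entries. The sketch below shows this with the standard library; it is not the code of visualize_response, the scores are made up, and it assumes that the column index of a score corresponds to the encoded item id.

```python
import heapq


def top_k_items(scores, k):
    """Return the k column indices with the highest scores, best first."""
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)


# One made-up row of scores, one score per encoded item id.
row = [-13.5, -14.0, -8.2, -13.9, -7.9, -13.8]
print(top_k_items(row, 3))  # [4, 2, 0]
```

Mapping the encoded ids back to the original item ids would require the categorical mapping produced during preprocessing.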
This is the end of the tutorial. You successfully
performed feature engineering with NVTabular
trained Transformer-based session-based recommendation models with Transformers4Rec
deployed a trained model to Triton Inference Server, sent requests and received responses from the server.
Unload models
triton_client.unload_model(model_name="t4r_pytorch")
triton_client.unload_model(model_name="t4r_pytorch_nvt")
triton_client.unload_model(model_name="t4r_pytorch_pt")
References
Merlin Transformers4rec: https://github.com/NVIDIA-Merlin/Transformers4Rec
Merlin NVTabular: https://github.com/NVIDIA-Merlin/NVTabular/tree/main/nvtabular
Triton inference server: https://github.com/triton-inference-server