[1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
End-to-end session-based recommendation with Transformers4Rec
In recent years, several deep learning-based algorithms have been proposed for recommendation systems, and their adoption in industry deployments has been growing steeply. In particular, NLP-inspired approaches have been successfully adapted for sequential and session-based recommendation problems, which are important in many domains such as e-commerce, news and streaming media. Session-Based Recommender Systems (SBRS) have been proposed to model the sequence of interactions within the current user session, where a session is a short sequence of user interactions typically bounded by user inactivity. They have recently gained popularity due to their ability to capture short-term or contextual user preferences towards items.
The field of NLP has evolved significantly within the last decade, particularly due to the increased usage of deep learning. As a result, state-of-the-art NLP approaches have inspired RecSys practitioners and researchers to adapt those architectures, especially for sequential and session-based recommendation problems. Here, we leverage one of the state-of-the-art Transformer-based architectures, XLNet, with the Masked Language Modeling (MLM) training technique (see our tutorial for details) to train a session-based model.
In this end-to-end session-based recommender model example, we use the Transformers4Rec library, which leverages the popular HuggingFace Transformers NLP library and makes it possible to experiment with cutting-edge implementations of such architectures for sequential and session-based recommendation problems. For detailed explanations of the building blocks of the Transformers4Rec meta-architecture, visit the getting-started-session-based and tutorial example notebooks.
1. Setup
1.1. Import Libraries and Define Data Input and Output Paths
[2]:
import os
import glob
import numpy as np
import gc
import cudf
import cupy
import nvtabular as nvt
[3]:
DATA_FOLDER = "/workspace/data/"
FILENAME_PATTERN = 'yoochoose-clicks.dat'
DATA_PATH = os.path.join(DATA_FOLDER, FILENAME_PATTERN)
OUTPUT_FOLDER = "./yoochoose_transformed"
OVERWRITE = False
1.2. Download the data
In this notebook we use the YOOCHOOSE dataset, which contains a collection of sessions from a retailer. Each session encapsulates the click events that the user performed in that session.
The dataset is available on Kaggle. You need to download it and copy it to the DATA_FOLDER path. Note that we only use the yoochoose-clicks.dat file.
1.3. Load and clean raw data
[4]:
interactions_df = cudf.read_csv(DATA_PATH, sep=',',
names=['session_id','timestamp', 'item_id', 'category'],
dtype=['int', 'datetime64[s]', 'int', 'int'])
Remove consecutive repeated interactions within the same session
[5]:
print("Count with in-session repeated interactions: {}".format(len(interactions_df)))
# Sorts the dataframe by session and timestamp, to remove consecutive repetitions
interactions_df.timestamp = interactions_df.timestamp.astype(int)
interactions_df = interactions_df.sort_values(['session_id', 'timestamp'])
past_ids = interactions_df['item_id'].shift(1).fillna()
session_past_ids = interactions_df['session_id'].shift(1).fillna()
# Keep only interactions that are not consecutive repeats of the same item within a session
interactions_df = interactions_df[~((interactions_df['session_id'] == session_past_ids) & (interactions_df['item_id'] == past_ids))]
print("Count after removed in-session repeated interactions: {}".format(len(interactions_df)))
Count with in-session repeated interactions: 33003944
Count after removed in-session repeated interactions: 28971543
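The shift-and-compare logic above can be illustrated without a GPU. The sketch below is a plain-Python illustration, assuming rows are `(session_id, timestamp, item_id)` tuples; `remove_consecutive_repeats` is a hypothetical helper, not part of cuDF or NVTabular. As in the cell above, each row is compared with the row immediately before it in the sorted order, so a row is dropped only when both the session and the item match.

```python
def remove_consecutive_repeats(rows):
    """Hypothetical helper: keep a row unless it repeats the previous row's item in the same session."""
    # Sort by session, then time, mirroring sort_values(['session_id', 'timestamp'])
    ordered = sorted(rows, key=lambda r: (r[0], r[1]))
    kept = []
    prev_session, prev_item = None, None  # plays the role of the shift(1) columns
    for session_id, ts, item_id in ordered:
        if not (session_id == prev_session and item_id == prev_item):
            kept.append((session_id, ts, item_id))
        prev_session, prev_item = session_id, item_id
    return kept


clicks = [
    (1, 10, 100), (1, 11, 100), (1, 12, 200), (1, 13, 100),
    (2, 5, 100), (2, 6, 100),
]
print(remove_consecutive_repeats(clicks))
# [(1, 10, 100), (1, 12, 200), (1, 13, 100), (2, 5, 100)]
```

The later revisit of item 100 at timestamp 13 is kept, and so is the first click of session 2 even though it has the same item as the previous row: only consecutive repeats inside one session are removed.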
Create a new feature with the timestamp at which each item was first seen
[6]:
items_first_ts_df = interactions_df.groupby('item_id').agg({'timestamp': 'min'}).reset_index().rename(columns={'timestamp': 'itemid_ts_first'})
interactions_merged_df = interactions_df.merge(items_first_ts_df, on=['item_id'], how='left')
interactions_merged_df.head()
[6]:
 | session_id | timestamp | item_id | category | itemid_ts_first |
---|---|---|---|---|---|
0 | 549 | 1396774534 | 214714927 | 0 | 1396334996 |
1 | 549 | 1396774556 | 214517450 | 0 | 1396329825 |
2 | 549 | 1396774617 | 214714929 | 0 | 1396341783 |
3 | 549 | 1396774647 | 214518555 | 0 | 1396327272 |
4 | 549 | 1396774664 | 214639297 | 0 | 1396353119 |
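The group-by-minimum followed by a left merge can be pictured as a two-step dictionary operation. This is a plain-Python sketch with a hypothetical helper name and an assumed row layout (dicts with `item_id` and `timestamp` keys), meant only to show what `itemid_ts_first` contains.

```python
def add_item_first_seen(rows):
    """Hypothetical helper: attach each item's earliest timestamp to every interaction with it.

    rows: list of dicts with 'item_id' and 'timestamp' keys (assumed layout).
    """
    # Equivalent of groupby('item_id').agg({'timestamp': 'min'})
    first_seen = {}
    for row in rows:
        item, ts = row["item_id"], row["timestamp"]
        first_seen[item] = min(first_seen.get(item, ts), ts)
    # Equivalent of merge(..., on=['item_id'], how='left')
    return [{**row, "itemid_ts_first": first_seen[row["item_id"]]} for row in rows]


interactions = [
    {"item_id": 7, "timestamp": 50},
    {"item_id": 9, "timestamp": 30},
    {"item_id": 7, "timestamp": 20},
]
for row in add_item_first_seen(interactions):
    print(row)
```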
[7]:
# free gpu memory
del interactions_df, past_ids, session_past_ids, items_first_ts_df
gc.collect()
[7]:
0
2. Define a preprocessing workflow with NVTabular
NVTabular is a feature engineering and preprocessing library for tabular data designed to quickly and easily manipulate terabyte scale datasets used to train deep learning based recommender systems. It provides a high level abstraction to simplify code and accelerates computation on the GPU using the RAPIDS cuDF library.
NVTabular supports different feature engineering transformations required by deep learning (DL) models, such as categorical encoding and numerical feature normalization. It also supports feature engineering and generating sequential features.
More information about the supported features can be found in the NVTabular documentation.
2.1 Feature engineering: Create and Transform item features
In this cell, we define three transformation ops:
Encoding categorical variables using the Categorify() op. We set start_index to 1, so that encoded null values start from 1 instead of 0, because we reserve 0 for padding the sequence features.
Deriving temporal features from the timestamp and computing their cyclical representation using a custom lambda function.
Computing the item recency in days using a custom Op. Note that item recency is defined as the difference between the first occurrence of the item in the dataset and the actual date of the item interaction.
For more ETL workflow examples, visit NVTabular example notebooks.
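To see why reserving 0 matters, the sketch below encodes items as contiguous integers starting at 1 and pads sessions with 0. It is a plain-Python illustration with hypothetical helper names (`fit_categorify`, `encode_and_pad`), not NVTabular's implementation; in particular, the order in which codes are assigned (first appearance) is an assumption.

```python
def fit_categorify(values, start_index=1):
    """Hypothetical helper: assign contiguous integer codes, beginning at start_index.

    Codes follow first-appearance order here; NVTabular's own ordering may differ.
    """
    vocab = {}
    for value in values:
        if value not in vocab:
            vocab[value] = start_index + len(vocab)
    return vocab


def encode_and_pad(sequence, vocab, max_len, pad_value=0):
    """Hypothetical helper: encode a session, truncate it to max_len and right-pad with pad_value."""
    encoded = [vocab[value] for value in sequence[:max_len]]
    return encoded + [pad_value] * (max_len - len(encoded))


vocab = fit_categorify(["shoes", "hat", "shoes", "bag"])
print(vocab)                                       # {'shoes': 1, 'hat': 2, 'bag': 3}
print(encode_and_pad(["hat", "shoes"], vocab, 4))  # [2, 1, 0, 0]
```

Because no real item is ever encoded as 0, a padded position cannot be confused with an item, which is what allows the embedding tables shown later in the model summary to use `padding_idx=0`.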
[8]:
# Encodes categorical features as contiguous integers
cat_feats = nvt.ColumnSelector(['session_id', 'category', 'item_id']) >> nvt.ops.Categorify(start_index=1)
# create time features
session_ts = nvt.ColumnSelector(['timestamp'])
session_time = (
session_ts >>
nvt.ops.LambdaOp(lambda col: cudf.to_datetime(col, unit='s')) >>
nvt.ops.Rename(name = 'event_time_dt')
)
sessiontime_weekday = (
session_time >>
nvt.ops.LambdaOp(lambda col: col.dt.weekday) >>
nvt.ops.Rename(name ='et_dayofweek')
)
# Derive cyclical features: define a custom helper applied through a lambda
def get_cycled_feature_value_sin(col, max_value):
    value_scaled = (col + 0.000001) / max_value
    value_sin = np.sin(2*np.pi*value_scaled)
    return value_sin
weekday_sin = sessiontime_weekday >> (lambda col: get_cycled_feature_value_sin(col+1, 7)) >> nvt.ops.Rename(name = 'et_dayofweek_sin')
# Compute Item recency: Define a custom Op
class ItemRecency(nvt.ops.Operator):
    def transform(self, columns, gdf):
        for column in columns.names:
            col = gdf[column]
            item_first_timestamp = gdf['itemid_ts_first']
            delta_days = (col - item_first_timestamp) / (60*60*24)
            gdf[column + "_age_days"] = delta_days * (delta_days >=0)
        return gdf
    def output_column_names(self, columns):
        return nvt.ColumnSelector([column + "_age_days" for column in columns.names])
    def dependencies(self):
        return ["itemid_ts_first"]
recency_features = session_ts >> ItemRecency()
# Apply standardization to this continuous feature
recency_features_norm = recency_features >> nvt.ops.LogOp() >> nvt.ops.Normalize() >> nvt.ops.Rename(name='product_recency_days_log_norm')
time_features = (
session_time +
sessiontime_weekday +
weekday_sin +
recency_features_norm
)
features = nvt.ColumnSelector(['timestamp', 'session_id']) + cat_feats + time_features
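The per-value arithmetic behind these temporal features can be checked on a CPU with the standard library. The sketch below reproduces the sine encoding and the clipped age in days from the cell above; `log_normalize` assumes that `LogOp` computes log(x + 1) and that `Normalize` is a z-score, which is our reading rather than a quote of NVTabular's code (it also uses the population standard deviation, which may differ from the library's estimator). All helper names are hypothetical.

```python
import math


def weekday_sin_value(weekday):
    """Hypothetical helper: same arithmetic as get_cycled_feature_value_sin(weekday + 1, 7)."""
    value_scaled = (weekday + 1 + 0.000001) / 7
    return math.sin(2 * math.pi * value_scaled)


def item_age_days(timestamp, item_first_timestamp):
    """Hypothetical helper: days since the item was first seen, clipped at 0 as in ItemRecency."""
    delta_days = (timestamp - item_first_timestamp) / (60 * 60 * 24)
    return delta_days * (delta_days >= 0)


def log_normalize(values):
    """Hypothetical helper: log(x + 1) then z-score, assumed to approximate LogOp >> Normalize."""
    logged = [math.log(v + 1) for v in values]
    mean = sum(logged) / len(logged)
    std = math.sqrt(sum((x - mean) ** 2 for x in logged) / len(logged))
    return [(x - mean) / std for x in logged]


print(round(weekday_sin_value(2), 4))                          # 0.4339
print(round(item_age_days(1396774534, 1396334996), 2))         # 5.09
```

The first value matches the 0.43388295 entries in the preprocessed output further down, and the second uses the `timestamp` and `itemid_ts_first` of the first row in the merged table above.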
2.2 Define the preprocessing of sequential features
Once the item features are generated, this cell groups the interactions at the session level and sorts them by time. We additionally truncate all sessions to their first 20 interactions and filter out sessions with fewer than 2 interactions.
[9]:
# Define Groupby Operator
groupby_features = features >> nvt.ops.Groupby(
groupby_cols=["session_id"],
sort_cols=["timestamp"],
aggs={
'item_id': ["list", "count"],
'category': ["list"],
'timestamp': ["first"],
'event_time_dt': ["first"],
'et_dayofweek_sin': ["list"],
'product_recency_days_log_norm': ["list"]
},
name_sep="-")
# Truncate sequence features to the first 20 interacted items
SESSIONS_MAX_LENGTH = 20
groupby_features_list = groupby_features['item_id-list', 'category-list', 'et_dayofweek_sin-list', 'product_recency_days_log_norm-list']
groupby_features_truncated = groupby_features_list >> nvt.ops.ListSlice(0, SESSIONS_MAX_LENGTH) >> nvt.ops.Rename(postfix = '_seq')
# Calculate session day index based on 'event_time_dt-first' column
day_index = ((groupby_features['event_time_dt-first']) >>
nvt.ops.LambdaOp(lambda col: (col - col.min()).dt.days +1) >>
nvt.ops.Rename(f = lambda col: "day_index")
)
# Select features for training
selected_features = groupby_features['session_id', 'item_id-count'] + groupby_features_truncated + day_index
# Filter out sessions with less than 2 interactions
MINIMUM_SESSION_LENGTH = 2
filtered_sessions = selected_features >> nvt.ops.Filter(f=lambda df: df["item_id-count"] >= MINIMUM_SESSION_LENGTH)
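The session-level grouping above can be traced on a toy input. The sketch below is a plain-Python approximation of the `Groupby`, `ListSlice` and `Filter` steps, assuming `(session_id, timestamp, item_id)` rows; `build_sessions` is a hypothetical helper and handles only the item column.

```python
from collections import defaultdict

SESSIONS_MAX_LENGTH = 20
MINIMUM_SESSION_LENGTH = 2


def build_sessions(rows, max_len=SESSIONS_MAX_LENGTH, min_len=MINIMUM_SESSION_LENGTH):
    """Hypothetical helper: group (session_id, timestamp, item_id) rows into one record per session."""
    grouped = defaultdict(list)
    for session_id, timestamp, item_id in rows:        # Groupby(groupby_cols=["session_id"])
        grouped[session_id].append((timestamp, item_id))
    sessions = []
    for session_id, events in grouped.items():
        events.sort()                                   # sort_cols=["timestamp"]
        items = [item_id for _, item_id in events]
        count = len(items)                              # 'item_id-count', taken before truncation
        if count >= min_len:                            # Filter on 'item_id-count'
            sessions.append({
                "session_id": session_id,
                "item_id-count": count,
                "item_id-list_seq": items[:max_len],    # ListSlice(0, max_len)
                "timestamp-first": events[0][0],
            })
    return sessions


rows = [(1, 30, "c"), (1, 10, "a"), (1, 20, "b"), (1, 50, "e"), (1, 40, "d"), (2, 5, "x")]
print(build_sessions(rows, max_len=3))
# [{'session_id': 1, 'item_id-count': 5, 'item_id-list_seq': ['a', 'b', 'c'], 'timestamp-first': 10}]
```

In this sketch the count is taken before truncation, so it can exceed `max_len`, and session 2 is dropped because it has a single interaction.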
Avoid Numba low occupancy warnings
[ ]:
from numba import config
config.CUDA_LOW_OCCUPANCY_WARNINGS = 0
2.3 Execute NVTabular workflow
Once we have defined the general workflow (filtered_sessions), we provide our cuDF dataframe to the nvt.Dataset class, which is optimized to split data into chunks that fit in device memory and to handle the calculation of complex global statistics. Then, we execute the pipeline that fits and transforms the data to get the desired output features.
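The reason for separating `fit` from `transform` is that global statistics, such as the mean and standard deviation used by `Normalize`, must see every partition before any partition can be transformed. Below is a minimal sketch of that two-pass pattern, using a count/sum/sum-of-squares accumulator over hypothetical in-memory chunks; it illustrates the idea only and is not how NVTabular or Dask compute their statistics internally.

```python
import math


def fit_mean_std(chunks):
    """Hypothetical helper: accumulate count, sum and sum of squares chunk by chunk, then finalize."""
    n = total = total_sq = 0.0
    for chunk in chunks:  # each chunk stands in for one partition that fits in memory
        n += len(chunk)
        total += sum(chunk)
        total_sq += sum(x * x for x in chunk)
    mean = total / n
    variance = max(total_sq / n - mean * mean, 0.0)  # guard against tiny negative rounding errors
    return mean, math.sqrt(variance)


def apply_stats(chunk, mean, std):
    """Hypothetical helper: second pass, applying the global statistics to any chunk."""
    return [(x - mean) / std for x in chunk]


chunks = [[1.0, 2.0], [3.0, 4.0, 5.0]]
mean, std = fit_mean_std(chunks)
print(mean, round(std, 4))  # 3.0 1.4142
normalized = [apply_stats(c, mean, std) for c in chunks]
```

The sum-of-squares form is compact but can lose precision for large values; a pairwise or Welford-style variance update is the more robust choice for real data.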
[ ]:
dataset = nvt.Dataset(interactions_merged_df)
workflow = nvt.Workflow(filtered_sessions)
# Learn the feature statistics needed by the preprocessing workflow
workflow.fit(dataset)
# Apply the preprocessing workflow to the dataset and convert the resulting Dask-cuDF dataframe to a cuDF dataframe
sessions_gdf = workflow.transform(dataset).compute()
Let’s print the head of our preprocessed dataset. Notice that each example (row) is now a session, and the sequential features with respect to user interactions were converted to lists of matching length.
[11]:
sessions_gdf.head()
[11]:
 | session_id | item_id-count | item_id-list_seq | category-list_seq | et_dayofweek_sin-list_seq | product_recency_days_log_norm-list_seq | day_index |
---|---|---|---|---|---|---|---|
0 | 2 | 200 | [2223, 2125, 1800, 123, 3030, 1861, 1076, 1285... | [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... | [1.1285199e-06, 1.1285199e-06, 1.1285199e-06, ... | [-1.1126341, -0.9665389, -0.1350116, -0.127809... | 27 |
1 | 3 | 200 | [26562, 35137, 19260, 46449, 29027, 39096, 272... | [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... | [0.43388295, 0.43388295, 0.43388295, 0.4338829... | [0.40848607, 0.39331725, 0.5418466, -3.0278225... | 58 |
2 | 4 | 200 | [23212, 30448, 16468, 2052, 22490, 31097, 6243... | [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... | [0.9749277, 0.9749277, 0.9749277, 0.9749277, 0... | [0.6801631, 0.7174695, 0.7185285, 0.7204116, 0... | 71 |
3 | 5 | 200 | [230, 451, 732, 1268, 2014, 567, 497, 439, 338... | [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, ... | [0.43388295, 0.43388295, 0.43388295, 0.4338829... | [1.3680888, -0.6530481, -0.69314253, -0.590593... | 149 |
4 | 6 | 200 | [23, 70, 160, 70, 90, 742, 851, 359, 734, 878,... | [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... | [0.43388295, 0.43388295, 0.43388295, 0.4338829... | [1.3714824, 1.3715883, 1.3715737, 1.3715955, 1... | 149 |
Save the preprocessing workflow
[12]:
workflow.save('workflow_etl')
2.4 Export pre-processed data by day
In this example we split the preprocessed parquet files by day, to allow for temporal training and evaluation. There will be a folder for each day, with three parquet files within each folder: train.parquet, valid.parquet and test.parquet.
Note: the dataset has a single categorical feature (category), but it is inconsistent over time. All interactions before day 84 (2014-06-23) have the same value for that feature, whereas many other categories are introduced afterwards. Thus, for this demo we save only the last five days (day_index 178 to 182).
[13]:
sessions_gdf = sessions_gdf[sessions_gdf.day_index>=178]
[14]:
from transformers4rec.data.preprocessing import save_time_based_splits
save_time_based_splits(data=nvt.Dataset(sessions_gdf),
output_dir= "./preproc_sessions_by_day",
partition_col='day_index',
timestamp_col='session_id',
)
Creating time-based splits: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.64it/s]
[15]:
from transformers4rec.torch.utils.examples_utils import list_files
list_files('./preproc_sessions_by_day')
preproc_sessions_by_day/
180/
test.parquet
valid.parquet
train.parquet
181/
test.parquet
valid.parquet
train.parquet
179/
test.parquet
valid.parquet
train.parquet
182/
test.parquet
valid.parquet
train.parquet
178/
test.parquet
valid.parquet
train.parquet
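A per-day split like the one listed above can be sketched as "partition by day, order by a time column, cut into three parts". The sketch below is a hypothetical illustration only: the 10% validation and 10% test fractions, and ordering by `timestamp_col`, are assumptions about what `save_time_based_splits` does, not a description of its source. Both helper names are made up.

```python
def split_day(sessions, timestamp_key, val_size=0.1, test_size=0.1):
    """Hypothetical helper: order one day's sessions by timestamp_key and cut them into three parts."""
    ordered = sorted(sessions, key=lambda s: s[timestamp_key])
    n = len(ordered)
    n_test, n_val = int(n * test_size), int(n * val_size)
    n_train = n - n_val - n_test
    return {
        "train": ordered[:n_train],
        "valid": ordered[n_train:n_train + n_val],
        "test": ordered[n_train + n_val:],
    }


def split_by_day(sessions, partition_key="day_index", timestamp_key="session_id"):
    """Hypothetical helper: one entry per day, mirroring partition_col / timestamp_col above."""
    by_day = {}
    for session in sessions:
        by_day.setdefault(session[partition_key], []).append(session)
    return {day: split_day(rows, timestamp_key) for day, rows in sorted(by_day.items())}


sessions = [{"session_id": i, "day_index": 178 + i % 2} for i in range(40)]
splits = split_by_day(sessions)
for day, parts in splits.items():
    print(day, {name: len(rows) for name, rows in parts.items()})
# 178 {'train': 16, 'valid': 2, 'test': 2}
# 179 {'train': 16, 'valid': 2, 'test': 2}
```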
[16]:
# free gpu memory
del sessions_gdf
gc.collect()
[16]:
0
3. Model definition using Transformers4Rec
3.1 Get the schema
The library uses a schema format to configure the input features and automatically create the necessary layers. This protobuf text file describes each input feature by defining its name, its type, the number of elements of a list column, the cardinality of a categorical feature, and the min and max values of each feature. In addition, the annotation field contains tags, for example to mark continuous and categorical features, the target column or the item_id feature, among others.
[17]:
from merlin_standard_lib import Schema
SCHEMA_PATH = "schema_demo.pb"
schema = Schema().from_proto_text(SCHEMA_PATH)
!cat $SCHEMA_PATH
feature {
name: "item_id-list_seq"
value_count {
min: 2
max: 185
}
type: INT
int_domain {
name: "item_id/list"
min: 1
max: 52742
is_categorical: true
}
annotation {
tag: "item_id"
tag: "list"
tag: "categorical"
tag: "item"
}
}
feature {
name: "session_id"
type: INT
int_domain {
name: "session_id"
min: 1
max: 9249733
is_categorical: false
}
annotation {
tag: "groupby_col"
}
}
feature {
name: "category-list_seq"
value_count {
min: 2
max: 185
}
type: INT
int_domain {
name: "category-list_seq"
min: 1
max: 337
is_categorical: true
}
annotation {
tag: "list"
tag: "categorical"
tag: "item"
}
}
feature {
name: "product_recency_days_log_norm-list_seq"
value_count {
min: 2
max: 185
}
type: FLOAT
float_domain {
name: "product_recency_days_log_norm-list_seq"
min: -2.9177291
max: 1.5231701
}
annotation {
tag: "continuous"
tag: "list"
}
}
feature {
name: "et_dayofweek_sin-list_seq"
value_count {
min: 2
max: 185
}
type: FLOAT
float_domain {
name: "et_dayofweek_sin-list_seq"
min: 0.7421683
max: 0.9995285
}
annotation {
tag: "time"
tag: "list"
}
}
We can select the subset of features we want to use for training the model by their tags or their names.
[18]:
schema = schema.select_by_name(
['item_id-list_seq', 'category-list_seq', 'product_recency_days_log_norm-list_seq', 'et_dayofweek_sin-list_seq']
)
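Tag-based selection can be pictured as a filter over the annotation tags. The sketch below copies the tags from schema_demo.pb into a plain dictionary; `select_by_tag` and `select_by_name` here are hypothetical stand-ins for illustration, not the merlin_standard_lib Schema API.

```python
# Tags copied from the annotation blocks in schema_demo.pb above
schema_tags = {
    "item_id-list_seq": {"item_id", "list", "categorical", "item"},
    "session_id": {"groupby_col"},
    "category-list_seq": {"list", "categorical", "item"},
    "product_recency_days_log_norm-list_seq": {"continuous", "list"},
    "et_dayofweek_sin-list_seq": {"time", "list"},
}


def select_by_tag(tags_by_feature, tag):
    """Hypothetical stand-in: names of the features annotated with the given tag, in schema order."""
    return [name for name, tags in tags_by_feature.items() if tag in tags]


def select_by_name(tags_by_feature, names):
    """Hypothetical stand-in: keep only the requested features, ignoring unknown names."""
    return {name: tags_by_feature[name] for name in names if name in tags_by_feature}


print(select_by_tag(schema_tags, "categorical"))  # ['item_id-list_seq', 'category-list_seq']
print(select_by_tag(schema_tags, "continuous"))   # ['product_recency_days_log_norm-list_seq']
print(list(select_by_name(schema_tags, ["session_id", "missing"])))  # ['session_id']
```

Only product_recency_days_log_norm-list_seq carries the continuous tag, while et_dayofweek_sin-list_seq is tagged time; this may explain why the continuous branch in the model summary further down starts with a Linear layer with in_features=1.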
3.2 Define the end-to-end Session-based Transformer-based recommendation model
For a session-based recommendation model, the end-to-end model definition requires four steps:
1. Instantiate the TabularSequenceFeatures input module from the schema to prepare the embedding tables of categorical variables and project continuous features, if specified. In addition, the module provides different aggregation methods (e.g. ‘concat’, ‘elementwise-sum’) to merge input features and generate the sequence of interaction embeddings. The module also supports language modeling tasks to prepare masked labels for training and evaluation (e.g. ‘mlm’ for masked language modeling).
2. Define one or multiple prediction tasks. For this demo, we use NextItemPredictionTask with Masked Language Modeling: during training, randomly selected items are masked and predicted using the unmasked items of the sequence. For inference, the model always predicts the next item to be interacted with.
3. Construct a transformer_config based on the architectures provided by the Hugging Face Transformers framework.
4. Link the transformer body to the inputs and the prediction tasks to get the final PyTorch Model class.
For more details about the features supported by each sub-module, please check out the library documentation page.
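The masking idea in step 2 can be shown on a single padded session. The sketch below is a generic illustration of MLM-style label preparation, not Transformers4Rec's `MaskedLanguageModeling` code: the `MASK = -1` placeholder id, the `mask_prob` value and the `mlm_mask` helper are all hypothetical choices for this example.

```python
import random

PAD = 0
MASK = -1  # hypothetical placeholder id, used only in this sketch


def mlm_mask(sequence, mask_prob=0.5, seed=None):
    """Hypothetical helper: randomly hide non-padded items; labels keep hidden items, PAD elsewhere."""
    rng = random.Random(seed)
    inputs, labels = [], []
    for item in sequence:
        if item != PAD and rng.random() < mask_prob:
            inputs.append(MASK)   # the model does not see this item...
            labels.append(item)   # ...but must predict it
        else:
            inputs.append(item)
            labels.append(PAD)    # no loss at unmasked or padded positions
    return inputs, labels


session = [12, 7, 33, 5, PAD, PAD]
masked_inputs, mlm_labels = mlm_mask(session, mask_prob=0.5, seed=42)
print(masked_inputs)
print(mlm_labels)
```

Because every masked position keeps its original item as the label, the model learns from both left and right context during training, which is the main difference from purely left-to-right next-item training.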
[19]:
from transformers4rec import torch as tr
max_sequence_length, d_model = 20, 320
# Define input module to process tabular input-features and to prepare masked inputs
input_module = tr.TabularSequenceFeatures.from_schema(
schema,
max_sequence_length=max_sequence_length,
continuous_projection=64,
aggregation="concat",
d_output=d_model,
masking="mlm",
)
# Define Next item prediction-task
prediction_task = tr.NextItemPredictionTask(hf_format=True, weight_tying=True)
# Define the config of the XLNet Transformer architecture
transformer_config = tr.XLNetConfig.build(
d_model=d_model, n_head=8, n_layer=2, total_seq_length=max_sequence_length
)
#Get the end-to-end model
model = transformer_config.to_torch_model(input_module, prediction_task)
Projecting inputs of NextItemPredictionTask to'64' As weight tying requires the input dimension '320' to be equal to the item-id embedding dimension '64'
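The warning above comes from weight tying: the output layer reuses the item-id embedding table (52743 x 64 in the model summary below), so the hidden state must be 64-dimensional, while the Transformer produces d_model = 320. The summary also shows the input side: two 64-dimensional embeddings plus a 64-dimensional continuous projection give 64 + 64 + 64 = 192 features, which are projected to 320. The toy sketch below shows how tied output scores are formed, ignoring any output bias and using made-up 2-dimensional embeddings in place of the real 64-dimensional ones; the helper names are hypothetical and this is not the library's code.

```python
import math


def tied_output_scores(hidden, item_embeddings):
    """Hypothetical helper: score each item by the dot product of the hidden state with its embedding row."""
    return [sum(h * e for h, e in zip(hidden, row)) for row in item_embeddings]


def log_softmax(scores):
    """Numerically stable log-softmax (subtract the max before exponentiating)."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


# Toy table: row 0 is the padding index, rows 1-3 are items (real table: 52743 x 64)
item_embeddings = [
    [0.0, 0.0],
    [1.0, 0.0],
    [0.0, 1.0],
    [0.7, 0.7],
]
hidden = [2.0, 0.5]  # stands in for the task_block output (Linear 320 -> 64)
log_probs = log_softmax(tied_output_scores(hidden, item_embeddings))
best_item = max(range(len(log_probs)), key=lambda i: log_probs[i])
print(best_item)  # 1
```

Tying saves a second 52743 x 64 output matrix and makes the input and output item representations share one space, at the cost of forcing the extra 320 -> 64 projection.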
[20]:
model
[20]:
Model(
(heads): ModuleList(
(0): Head(
(body): SequentialBlock(
(0): TabularSequenceFeatures(
(_aggregation): ConcatFeatures()
(to_merge): ModuleDict(
(continuous_module): SequentialBlock(
(0): ContinuousFeatures(
(filter_features): FilterFeatures()
(_aggregation): ConcatFeatures()
)
(1): SequentialBlock(
(0): DenseBlock(
(0): Linear(in_features=1, out_features=64, bias=True)
(1): ReLU(inplace=True)
)
)
(2): AsTabular()
)
(categorical_module): SequenceEmbeddingFeatures(
(filter_features): FilterFeatures()
(embedding_tables): ModuleDict(
(item_id-list_seq): Embedding(52743, 64, padding_idx=0)
(category-list_seq): Embedding(338, 64, padding_idx=0)
)
)
)
(projection_module): SequentialBlock(
(0): DenseBlock(
(0): Linear(in_features=192, out_features=320, bias=True)
(1): ReLU(inplace=True)
)
)
(_masking): MaskedLanguageModeling()
)
(1): TansformerBlock(
(transformer): XLNetModel(
(word_embedding): Embedding(1, 320)
(layer): ModuleList(
(0): XLNetLayer(
(rel_attn): XLNetRelativeAttention(
(layer_norm): LayerNorm((320,), eps=0.03, elementwise_affine=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(ff): XLNetFeedForward(
(layer_norm): LayerNorm((320,), eps=0.03, elementwise_affine=True)
(layer_1): Linear(in_features=320, out_features=1280, bias=True)
(layer_2): Linear(in_features=1280, out_features=320, bias=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(dropout): Dropout(p=0.3, inplace=False)
)
(1): XLNetLayer(
(rel_attn): XLNetRelativeAttention(
(layer_norm): LayerNorm((320,), eps=0.03, elementwise_affine=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(ff): XLNetFeedForward(
(layer_norm): LayerNorm((320,), eps=0.03, elementwise_affine=True)
(layer_1): Linear(in_features=320, out_features=1280, bias=True)
(layer_2): Linear(in_features=1280, out_features=320, bias=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(dropout): Dropout(p=0.3, inplace=False)
)
)
(dropout): Dropout(p=0.3, inplace=False)
)
(masking): MaskedLanguageModeling()
)
)
(prediction_task_dict): ModuleDict(
(next-item): NextItemPredictionTask(
(sequence_summary): SequenceSummary(
(summary): Identity()
(activation): Identity()
(first_dropout): Identity()
(last_dropout): Identity()
)
(metrics): ModuleList(
(0): NDCGAt()
(1): AvgPrecisionAt()
(2): RecallAt()
)
(loss): NLLLoss()
(embeddings): SequenceEmbeddingFeatures(
(filter_features): FilterFeatures()
(embedding_tables): ModuleDict(
(item_id-list_seq): Embedding(52743, 64, padding_idx=0)
(category-list_seq): Embedding(338, 64, padding_idx=0)
)
)
(item_embedding_table): Embedding(52743, 64, padding_idx=0)
(masking): MaskedLanguageModeling()
(task_block): SequentialBlock(
(0): DenseBlock(
(0): Linear(in_features=320, out_features=64, bias=True)
(1): ReLU(inplace=True)
)
)
(pre): Block(
(module): NextItemPredictionTask(
(item_embedding_table): Embedding(52743, 64, padding_idx=0)
(log_softmax): LogSoftmax(dim=-1)
)
)
)
)
)
)
)
3.3. Daily Fine-Tuning: Training over a time window
Now that the model is defined, we launch training. For that, Transformers4Rec extends the HF Transformers Trainer class to adapt the evaluation loop to the session-based recommendation task and the calculation of ranking metrics. The original train() method is not modified, meaning that we leverage the efficient training implementation from that library, which manages, for example, half-precision (FP16) training.
Set training arguments
An additional argument, data_loader_engine, is defined to automatically load the features needed for training using the schema. The default value is nvtabular, for optimized GPU-based data loading. Optionally, a PyarrowDataLoader (pyarrow) can be used as a basic option, but it is slower and works only for small datasets, as the full data is loaded into CPU memory.
[21]:
training_args = tr.trainer.T4RecTrainingArguments(
output_dir="./tmp",
max_sequence_length=20,
data_loader_engine='nvtabular',
num_train_epochs=10,
dataloader_drop_last=False,
per_device_train_batch_size = 384,
per_device_eval_batch_size = 512,
learning_rate=0.0005,
fp16=True,
report_to = [],
logging_steps=200
)
Instantiate the trainer
[22]:
recsys_trainer = tr.Trainer(
model=model,
args=training_args,
schema=schema,
compute_metrics=True)
Using amp fp16 backend
Launch daily training and evaluation
In this demo, we use the fit_and_evaluate method, which conducts time-based fine-tuning by iteratively training and evaluating over a sliding time window: at each iteration, we train the model on the training data of a specific time index \(t\), then evaluate it on the validation data of the next index \(t + 1\). In the cell below, both start_time_index and end_time_index are set to 178, so the model is trained on day 178 and evaluated on day 179; raising end_time_index extends the window to later days.
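The sliding window can be sketched as a loop over time indices. The function below is a hypothetical illustration, not the source of `fit_and_evaluate`: it assumes the folder layout listed earlier and takes caller-supplied `train_fn` and `eval_fn` callables, which in a real run would keep updating the same model so that each day fine-tunes the previous one.

```python
def fit_and_evaluate_sketch(train_fn, eval_fn, start_time_index, end_time_index, input_dir):
    """Hypothetical helper: train on day t, evaluate on day t + 1, for t in [start, end]."""
    results = {}
    for t in range(start_time_index, end_time_index + 1):
        train_fn(f"{input_dir}/{t}/train.parquet")
        metrics = eval_fn(f"{input_dir}/{t + 1}/valid.parquet")
        for name, value in metrics.items():
            results.setdefault(name, []).append(value)  # one value per evaluated day
    return results


calls = []


def fake_train(path):
    calls.append(("train", path))


def fake_eval(path):
    calls.append(("eval", path))
    return {"recall@10": 0.2}


results = fit_and_evaluate_sketch(fake_train, fake_eval, 178, 180, "./preproc_sessions_by_day")
for call in calls:
    print(call)
```

Since each iteration evaluates on day t + 1, and the exported folders cover days 178 to 182, the largest usable end_time_index in this notebook would be 181.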
[23]:
from transformers4rec.torch.utils.examples_utils import fit_and_evaluate
aot_results = fit_and_evaluate(recsys_trainer, start_time_index=178, end_time_index=178, input_dir='./preproc_sessions_by_day')
***** Running training *****
Num examples = 28800
Num Epochs = 10
Instantaneous batch size per device = 384
Total train batch size (w. parallel, distributed & accumulation) = 1536
Gradient Accumulation steps = 1
Total optimization steps = 750
***** Launch training for day 178: *****
Step | Training Loss |
---|---|
200 | 7.807600 |
400 | 6.737500 |
600 | 6.462300 |
Saving model checkpoint to ./tmp/checkpoint-500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Training completed. Do not forget to share your model on huggingface.co/models =)
***** Evaluation results for day 179:*****
eval/next-item/avg_precision_@10 = 0.06432348489761353
eval/next-item/avg_precision_@20 = 0.06879562139511108
eval/next-item/ndcg_@10 = 0.09142451733350754
eval/next-item/ndcg_@20 = 0.10751541703939438
eval/next-item/recall_@10 = 0.17764931917190552
eval/next-item/recall_@20 = 0.24123314023017883
Print the metrics averaged over time
[24]:
mean_results = {k: np.mean(v) for k,v in aot_results.items()}
for key in sorted(mean_results.keys()):
    print(" %s = %s" % (key, str(mean_results[key])))
AOT_eval/next-item/avg_precision@10 = 0.06432348489761353
AOT_eval/next-item/avg_precision@20 = 0.06879562139511108
AOT_eval/next-item/ndcg@10 = 0.09142451733350754
AOT_eval/next-item/ndcg@20 = 0.10751541703939438
AOT_eval/next-item/recall@10 = 0.17764931917190552
AOT_eval/next-item/recall@20 = 0.24123314023017883
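For reference, the standard definitions of these metrics when each session has a single held-out target item fit in a few lines. This is a plain-Python sketch under that single-target assumption with hypothetical helper names; the library's NDCGAt, AvgPrecisionAt and RecallAt modules may differ in details such as tie handling. With one relevant item at 0-based rank r < k, AP@k = 1/(r+1) ≤ NDCG@k = 1/log2(r+2) ≤ Recall@k = 1, which is consistent with the ordering of the averaged values above.

```python
import math


def target_rank(scores, target):
    """Hypothetical helper: 0-based rank of the target; ties are broken in the target's favor."""
    return sum(1 for s in scores if s > scores[target])


def recall_at_k(rank, k):
    return 1.0 if rank < k else 0.0


def ndcg_at_k(rank, k):
    # One relevant item: the ideal DCG is 1, so NDCG = 1 / log2(rank + 2)
    return 1.0 / math.log2(rank + 2) if rank < k else 0.0


def avg_precision_at_k(rank, k):
    # One relevant item: AP is the precision at its position, 1 / (rank + 1)
    return 1.0 / (rank + 1) if rank < k else 0.0


scores = [0.1, 0.9, 0.4, 0.7]  # toy scores for items 0..3
rank = target_rank(scores, target=2)
print(rank)                                                         # 2
print(recall_at_k(rank, 2), recall_at_k(rank, 10))                  # 0.0 1.0
print(ndcg_at_k(rank, 10), round(avg_precision_at_k(rank, 10), 4))  # 0.5 0.3333
```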
Save the model
[25]:
recsys_trainer._save_model_and_checkpoint(save_model_class=True)
Saving model checkpoint to ./tmp/checkpoint-750
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Export the preprocessing workflow and model in the format required by Triton server
NVTabular’s export_pytorch_ensemble() function enables us to create the model files and config files to be served by Triton Inference Server.
[26]:
from nvtabular.inference.triton import export_pytorch_ensemble
export_pytorch_ensemble(
model,
workflow,
sparse_max=recsys_trainer.get_train_dataloader().dataset.sparse_max,
name= "t4r_pytorch",
model_path= "/workspace/TF4Rec/models/",
label_columns =[],
)
4. Serving Ensemble Model to the Triton Inference Server
NVIDIA Triton Inference Server (TIS) simplifies the deployment of AI models at scale in production. TIS provides a cloud and edge inferencing solution optimized for both CPUs and GPUs. It supports a number of different machine learning frameworks such as TensorFlow and PyTorch.
The last step of a machine learning (ML)/deep learning (DL) pipeline is to deploy the ETL workflow and the saved model to production. In the production setting, we want to transform the input data as done during training (ETL). We need to apply the same mean/std for continuous features and use the same categorical mapping to convert the categories to contiguous integers before we use the DL model for a prediction. Therefore, we deploy the NVTabular workflow with the PyTorch model as an ensemble model to Triton Inference Server. The ensemble model guarantees that the same transformation is applied to the raw inputs.
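To make the "same transformation" requirement concrete: a single-row request has a standard deviation of 0 and its own order of categories, so recomputing statistics at serving time would either fail or produce codes the model never saw. The hypothetical `TinyWorkflow` class below (not NVTabular's API) stores the categorical mapping and the mean/std at fit time and reuses them for every request, which is the property the ensemble provides.

```python
import math


class TinyWorkflow:
    """Hypothetical stand-in for a saved preprocessing workflow (not the NVTabular API)."""

    def fit(self, items, values):
        # Categorical mapping learned once on training data; 0 stays free for padding
        self.vocab = {}
        for item in items:
            self.vocab.setdefault(item, len(self.vocab) + 1)
        # Continuous statistics learned once on training data
        self.mean = sum(values) / len(values)
        self.std = math.sqrt(sum((v - self.mean) ** 2 for v in values) / len(values))
        return self

    def transform(self, items, values):
        # Serving reuses the stored mapping and statistics instead of recomputing them
        codes = [self.vocab[item] for item in items]
        scaled = [(v - self.mean) / self.std for v in values]
        return codes, scaled


tiny_workflow = TinyWorkflow().fit(["a", "b", "a", "c"], [1.0, 2.0, 3.0, 2.0])
print(tiny_workflow.transform(["c"], [2.0]))  # ([3], [0.0])
```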
In this section, you will learn how to:
- deploy saved NVTabular and PyTorch models to Triton Inference Server
- send requests for predictions and get responses.
4.1. Pull and Start Inference Container
At this point, before connecting to the Triton Server, we launch the inference docker container and then load the ensemble t4r_pytorch into the inference server. This is done with the commands below:
Launch the docker container
docker run -it --gpus device=0 -p 8000:8000 -p 8001:8001 -p 8002:8002 -v <path_to_saved_models>:/workspace/models/ nvcr.io/nvidia/merlin/merlin-inference:21.09
This command mounts your local model-repository folder, which includes the models saved in the previous cells, to the /workspace/models directory in the merlin-inference docker container.
Start Triton server: after you start the merlin-inference container, you can start Triton server with the command below. You need to provide the correct path to the models folder.
tritonserver --model-repository=<path_to_models> --model-control-mode=explicit
Note: The model-repository path for our example is /workspace/models. The models haven’t been loaded yet. Below, we request the Triton server to load the saved ensemble model.
Connect to the Triton Inference Server and check if the server is alive
[17]:
import tritonhttpclient
try:
    triton_client = tritonhttpclient.InferenceServerClient(url="localhost:8000", verbose=True)
    print("client created.")
except Exception as e:
    print("channel creation failed: " + str(e))
triton_client.is_server_live()
client created.
GET /v2/health/live, headers None
<HTTPSocketPoolResponse status=200 headers={'content-length': '0', 'content-type': 'text/plain'}>
[17]:
True
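The verbose log above shows the HTTP calls the client makes (GET /v2/health/live, and later POST /v2/repository/index and POST /v2/repository/models/t4r_pytorch/load), which follow Triton's HTTP/REST API based on the KServe v2 protocol. As a hedged alternative for environments without tritonhttpclient, the liveness check can be written with the standard library only; `triton_endpoint` and `is_server_live` are hypothetical helper names.

```python
import urllib.request


def triton_endpoint(base_url, *parts):
    """Hypothetical helper: build a URL such as http://localhost:8000/v2/health/live."""
    return base_url.rstrip("/") + "/v2/" + "/".join(parts)


def is_server_live(base_url="http://localhost:8000", timeout=2.0):
    """Hypothetical helper: True if GET /v2/health/live answers 200, False if the server is unreachable."""
    url = triton_endpoint(base_url, "health", "live")
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return response.status == 200
    except OSError:  # URLError, refused connections and timeouts are all OSError subclasses
        return False


print(triton_endpoint("http://localhost:8000", "repository", "models", "t4r_pytorch", "load"))
# http://localhost:8000/v2/repository/models/t4r_pytorch/load
print(is_server_live())
```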
Load raw data for inference
We select the last 50 interactions and filter out sessions with fewer than 2 interactions.
[19]:
interactions_merged_df=interactions_merged_df.sort_values('timestamp')
batch = interactions_merged_df[-50:]
sessions_to_use = batch.session_id.value_counts()
filtered_batch = batch[batch.session_id.isin(sessions_to_use[sessions_to_use.values>1].index.values)]
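The selection above keeps a time window first and filters sessions second, so a session that started before the window keeps only its in-window clicks. A plain-Python sketch of the same two steps, assuming `(session_id, timestamp, item_id)` rows and a hypothetical helper name:

```python
from collections import Counter


def last_n_multi_click_sessions(rows, n=50, min_interactions=2):
    """Hypothetical helper: keep the n most recent rows, then drop sessions
    with fewer than min_interactions rows inside that window."""
    recent_rows = sorted(rows, key=lambda r: r[1])[-n:]
    counts = Counter(session_id for session_id, _, _ in recent_rows)
    return [r for r in recent_rows if counts[r[0]] >= min_interactions]


rows = [(1, 1, "a"), (2, 2, "b"), (2, 3, "c"), (3, 4, "d"), (3, 5, "e"), (4, 6, "f")]
print(last_n_multi_click_sessions(rows, n=5))
# [(2, 2, 'b'), (2, 3, 'c'), (3, 4, 'd'), (3, 5, 'e')]
```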
Send a request to the Triton server to list the models in its repository
[20]:
triton_client.get_model_repository_index()
POST /v2/repository/index, headers None
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '77'}>
bytearray(b'[{"name":"t4r_pytorch"},{"name":"t4r_pytorch_nvt"},{"name":"t4r_pytorch_pt"}]')
[20]:
[{'name': 't4r_pytorch'},
{'name': 't4r_pytorch_nvt'},
{'name': 't4r_pytorch_pt'}]
Load the ensemble model to Triton
If all models are loaded successfully, you should see a successfully loaded status next to each model name in your terminal.
[21]:
triton_client.load_model(model_name="t4r_pytorch")
POST /v2/repository/models/t4r_pytorch/load, headers None
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '0'}>
Loaded model 't4r_pytorch'
[22]:
import nvtabular.inference.triton as nvt_triton
import tritonclient.grpc as grpcclient
inputs = nvt_triton.convert_df_to_triton_input(filtered_batch.columns, filtered_batch, grpcclient.InferInput)
output_names = ["output"]
outputs = []
for col in output_names:
    outputs.append(grpcclient.InferRequestedOutput(col))
MODEL_NAME_NVT = "t4r_pytorch"
with grpcclient.InferenceServerClient("localhost:8001") as client:
    # Request the "output" tensor explicitly
    response = client.infer(MODEL_NAME_NVT, inputs, outputs=outputs)
    print(col, ':\n', response.as_numpy(col))
output :
[[-13.53608 -14.02154 -8.15927 ... -13.912806 -14.316951
-13.758053 ]
[-15.546847 -16.178349 -7.881463 ... -16.058243 -17.085182
-15.761725 ]
[-12.496786 -13.111265 -8.736879 ... -13.031836 -13.436075
-12.741135 ]
...
[-14.425283 -14.728777 -7.756508 ... -14.73007 -15.161329
-14.437494 ]
[-15.366516 -15.427296 -7.3262033 ... -15.448423 -15.94982
-15.064197 ]
[-11.908236 -12.42782 -8.78612 ... -12.316145 -12.594669
-12.181059 ]]
Visualise top-k predictions
[23]:
from transformers4rec.torch.utils.examples_utils import visualize_response
visualize_response(filtered_batch, response, top_k=5, session_col='session_id')
- Top-5 predictions for session `11257991`: 2365 || 260 || 196 || 33 || 1169
- Top-5 predictions for session `11270119`: 898 || 2214 || 1169 || 958 || 2814
- Top-5 predictions for session `11311424`: 898 || 2214 || 1987 || 958 || 1169
- Top-5 predictions for session `11336059`: 260 || 196 || 1169 || 2365 || 33
- Top-5 predictions for session `11394056`: 863 || 2365 || 196 || 33 || 1169
- Top-5 predictions for session `11399751`: 1987 || 2214 || 958 || 2814 || 1169
- Top-5 predictions for session `11401481`: 898 || 157 || 2214 || 620 || 2814
- Top-5 predictions for session `11421333`: 1126 || 196 || 127 || 1169 || 1987
- Top-5 predictions for session `11425751`: 196 || 33 || 958 || 1169 || 1987
- Top-5 predictions for session `11445777`: 33 || 196 || 1987 || 1126 || 1169
- Top-5 predictions for session `11457123`: 184 || 1169 || 313 || 33 || 863
- Top-5 predictions for session `11467406`: 898 || 958 || 2214 || 1169 || 1987
- Top-5 predictions for session `11493827`: 863 || 2365 || 196 || 33 || 1169
- Top-5 predictions for session `11528554`: 1061 || 33 || 863 || 1020 || 1169
- Top-5 predictions for session `11561822`: 127 || 1126 || 196 || 1169 || 1987
As you can see, we first got prediction results (logits) from the trained model head, and then used the handy utility function visualize_response to extract the top-k encoded item IDs from the logits. In short, we generated recommended items for each given session.
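visualize_response handles this step for you; as a hedged illustration of the underlying operation, the sketch below picks the k highest logits per session and maps them back through an inverse vocabulary. Excluding index 0 is our own choice, because the item-id embedding uses padding_idx=0; we do not claim visualize_response does the same. The predicted ids are Categorify-encoded, so recovering raw YOOCHOOSE item ids requires the inverse of the category mapping fitted by the workflow; the dictionary below is a made-up example, and both helper names are hypothetical.

```python
import heapq


def top_k_items(logits, k=5, exclude=(0,)):
    """Hypothetical helper: indices of the k highest logits, skipping excluded ids (0 = padding)."""
    candidates = (i for i in range(len(logits)) if i not in exclude)
    return heapq.nlargest(k, candidates, key=lambda i: logits[i])


def decode(encoded_ids, inverse_vocab):
    """Hypothetical helper: map encoded ids back to raw ids with a caller-supplied inverse mapping."""
    return [inverse_vocab.get(i) for i in encoded_ids]


logits = [5.0, -13.5, -8.2, -13.9, -7.9]
top = top_k_items(logits, k=2)
print(top)                                              # [4, 2]
print(decode(top, {2: "raw_item_B", 4: "raw_item_D"}))  # ['raw_item_D', 'raw_item_B']
```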
This is the end of the tutorial. You successfully:
- performed feature engineering with NVTabular
- trained Transformer-based session-based recommendation models with Transformers4Rec
- deployed a trained model to Triton Inference Server, sent requests and got responses from the server.
Unload models
[ ]:
triton_client.unload_model(model_name="t4r_pytorch")
triton_client.unload_model(model_name="t4r_pytorch_nvt")
triton_client.unload_model(model_name="t4r_pytorch_pt")
References
Merlin Transformers4rec: https://github.com/NVIDIA-Merlin/Transformers4Rec
Merlin NVTabular: https://github.com/NVIDIA-Merlin/NVTabular/tree/main/nvtabular
Triton inference server: https://github.com/triton-inference-server