# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
End-to-end session-based recommendations with PyTorch
In recent years, several deep learning-based algorithms have been proposed for recommender systems, and their adoption in industry deployments has grown steeply. In particular, NLP-inspired approaches have been successfully adapted to sequential and session-based recommendation problems, which are important for many domains such as e-commerce, news, and streaming media. Session-Based Recommender Systems (SBRS) have been proposed to model the sequence of interactions within the current user session, where a session is a short sequence of user interactions typically bounded by user inactivity. They have recently gained popularity due to their ability to capture short-term and contextual user preferences towards items.
The field of NLP has evolved significantly within the last decade, particularly due to the increased usage of deep learning. As a result, state-of-the-art NLP approaches have inspired RecSys practitioners and researchers to adapt those architectures, especially for sequential and session-based recommendation problems. Here, we leverage one of the state-of-the-art Transformer-based architectures, XLNet, with the Masked Language Modeling (MLM) training technique (see our tutorial for details) to train a session-based model.
In this end-to-end session-based recommender model example, we use the Transformers4Rec library, which leverages the popular HuggingFace Transformers NLP library and makes it possible to experiment with cutting-edge implementations of such architectures for sequential and session-based recommendation problems. For detailed explanations of the building blocks of the Transformers4Rec meta-architecture, visit the getting-started-session-based and tutorial example notebooks.
1. Model definition using Transformers4Rec
In the previous notebook, we created sequential features and saved our processed data frames as parquet files. Now we use these processed parquet files to train a session-based recommendation model with the XLNet architecture.
1.1 Get the schema
The library uses a schema format to configure the input features and automatically creates the necessary layers. This protobuf text file contains the description of each input feature by defining: the name, the type, the number of elements of a list column, the cardinality of a categorical feature, and the min and max values of each feature. In addition, the annotation field contains tags, such as those specifying the continuous and categorical features, the target column, or the item_id feature, among others.
We create the schema object by reading the processed train parquet file generated by the NVTabular pipeline in the previous notebook, 01-ETL-with-NVTabular.
import os
INPUT_DATA_DIR = os.environ.get("INPUT_DATA_DIR", "/workspace/data")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", f"{INPUT_DATA_DIR}/preproc_sessions_by_day")
from merlin.schema import Schema
from merlin.io import Dataset
train = Dataset(os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"))
schema = train.schema
/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'
warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
warnings.warn(
We can select the subset of features we want to use for training the model by their tags or their names.
schema = schema.select_by_name(
    ['item_id-list', 'category-list', 'product_recency_days_log_norm-list', 'et_dayofweek_sin-list']
)
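Alternatively, the same subsetting can be done by tag rather than by name. A minimal sketch using the Tags enum from merlin.schema (which columns each tag matches depends on how they were tagged during preprocessing):
from merlin.schema import Tags

# For example, keep only the categorical columns, or select the item-id column.
categorical_schema = schema.select_by_tag(Tags.CATEGORICAL)
item_id_schema = schema.select_by_tag(Tags.ITEM_ID)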
We can print out the schema.
schema
 | name | tags | dtype | is_list | is_ragged | properties.num_buckets | properties.freq_threshold | properties.max_size | properties.start_index | properties.cat_path | properties.embedding_sizes.cardinality | properties.embedding_sizes.dimension | properties.domain.min | properties.domain.max | properties.domain.name | properties.value_count.min | properties.value_count.max |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | item_id-list | (Tags.ITEM_ID, Tags.CATEGORICAL, Tags.ID, Tags... | DType(name='int64', element_type=<ElementType.... | True | True | NaN | 0.0 | 0.0 | 1.0 | .//categories/unique.item_id.parquet | 52741.0 | 512.0 | 0.0 | 52740.0 | item_id | 0 | 20 |
1 | category-list | (Tags.CATEGORICAL, Tags.LIST) | DType(name='int64', element_type=<ElementType.... | True | True | NaN | 0.0 | 0.0 | 1.0 | .//categories/unique.category.parquet | 336.0 | 42.0 | 0.0 | 335.0 | category | 0 | 20 |
2 | product_recency_days_log_norm-list | (Tags.LIST, Tags.CONTINUOUS) | DType(name='float32', element_type=<ElementTyp... | True | True | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 0 | 20 |
3 | et_dayofweek_sin-list | (Tags.LIST, Tags.CONTINUOUS) | DType(name='float32', element_type=<ElementTyp... | True | True | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 0 | 20 |
1.2 Define the end-to-end Session-based Transformer-based recommendation model
Defining the end-to-end session-based recommendation model requires four steps:
1. Instantiate the TabularSequenceFeatures input module from the schema. It prepares the embedding tables for categorical variables and projects continuous features, if specified. In addition, the module provides different aggregation methods (e.g. 'concat', 'elementwise-sum') to merge the input features and generate the sequence of interaction embeddings. The module also supports language modeling tasks to prepare masked labels for training and evaluation (e.g. 'mlm' for masked language modeling).
2. Define one or multiple prediction tasks. For this demo, we use NextItemPredictionTask with masked language modeling: during training, randomly selected items are masked and predicted using the unmasked items of the sequence; at inference time, the model always predicts the next item to be interacted with.
3. Construct a transformer_config based on the architectures provided by the Hugging Face Transformers framework.
4. Link the transformer body to the inputs and the prediction tasks to obtain the final PyTorch Model class.
For more details about the features supported by each sub-module, please check out the library documentation page.
from transformers4rec import torch as tr
max_sequence_length, d_model = 20, 320
# Define input module to process tabular input-features and to prepare masked inputs
input_module = tr.TabularSequenceFeatures.from_schema(
    schema,
    max_sequence_length=max_sequence_length,
    continuous_projection=64,
    aggregation="concat",
    d_output=d_model,
    masking="mlm",
)
# Define Next item prediction-task
prediction_task = tr.NextItemPredictionTask(weight_tying=True)
# Define the config of the XLNet Transformer architecture
transformer_config = tr.XLNetConfig.build(
    d_model=d_model, n_head=8, n_layer=2, total_seq_length=max_sequence_length
)
# Get the end-to-end model
model = transformer_config.to_torch_model(input_module, prediction_task)
Projecting inputs of NextItemPredictionTask to '64' as weight tying requires the input dimension '320' to be equal to the item-id embedding dimension '64'
You can print out the model structure by uncommenting the line below.
#model
1.3. Daily Fine-Tuning: Training over a time window
Now that the model is defined, we are going to launch training. For that, Transformers4Rec extends the HF Transformers Trainer class to adapt the evaluation loop for the session-based recommendation task and the calculation of ranking metrics. The original train() method is not modified, meaning that we leverage the efficient training implementation from that library, which manages, for example, half-precision (FP16) training.
Set the training arguments
An additional argument, data_loader_engine, is defined to automatically load the features needed for training using the schema. The default value is merlin for optimized GPU-based data loading. Optionally, a PyarrowDataLoader (pyarrow) can also be used as a basic option, but it is slower and works only for small datasets, as the full data is loaded into CPU memory.
training_args = tr.trainer.T4RecTrainingArguments(
    output_dir="./tmp",
    max_sequence_length=20,
    data_loader_engine='merlin',
    num_train_epochs=10,
    dataloader_drop_last=False,
    per_device_train_batch_size=384,
    per_device_eval_batch_size=512,
    learning_rate=0.0005,
    fp16=True,
    report_to=[],
    logging_steps=200,
)
Instantiate the trainer
recsys_trainer = tr.Trainer(
    model=model,
    args=training_args,
    schema=schema,
    compute_metrics=True,
)
Using amp fp16 backend
Launch daily training and evaluation
In this demo, we will use the fit_and_evaluate method, which allows us to conduct time-based fine-tuning by iteratively training and evaluating over a sliding time window: at each iteration, we use the training data of a specific time index \(t\) to train the model and then evaluate on the validation data of the next index \(t + 1\). In particular, we set the start time index to 178 and the end time index to 180.
from transformers4rec.torch.utils.examples_utils import fit_and_evaluate
OT_results = fit_and_evaluate(recsys_trainer, start_time_index=178, end_time_index=180, input_dir=OUTPUT_DIR)
***** Running training *****
Num examples = 28800
Num Epochs = 10
Instantaneous batch size per device = 384
Total train batch size (w. parallel, distributed & accumulation) = 384
Gradient Accumulation steps = 1
Total optimization steps = 750
***** Launch training for day 178: *****
Step | Training Loss |
---|---|
200 | 7.681100 |
400 | 6.656800 |
600 | 6.372000 |
Saving model checkpoint to ./tmp/checkpoint-500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Training completed. Do not forget to share your model on huggingface.co/models =)
***** Running training *****
Num examples = 20736
Num Epochs = 10
Instantaneous batch size per device = 384
Total train batch size (w. parallel, distributed & accumulation) = 384
Gradient Accumulation steps = 1
Total optimization steps = 540
***** Evaluation results for day 179:*****
eval_/next-item/avg_precision@10 = 0.08119537681341171
eval_/next-item/avg_precision@20 = 0.0857219472527504
eval_/next-item/ndcg@10 = 0.11199340969324112
eval_/next-item/ndcg@20 = 0.12857995927333832
eval_/next-item/recall@10 = 0.20809248089790344
eval_/next-item/recall@20 = 0.27398842573165894
***** Launch training for day 179: *****
Step | Training Loss |
---|---|
200 | 6.805900 |
400 | 6.250300 |
Saving model checkpoint to ./tmp/checkpoint-500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Training completed. Do not forget to share your model on huggingface.co/models =)
***** Running training *****
Num examples = 16896
Num Epochs = 10
Instantaneous batch size per device = 384
Total train batch size (w. parallel, distributed & accumulation) = 384
Gradient Accumulation steps = 1
Total optimization steps = 440
***** Evaluation results for day 180:*****
eval_/next-item/avg_precision@10 = 0.06664632260799408
eval_/next-item/avg_precision@20 = 0.07125943899154663
eval_/next-item/ndcg@10 = 0.09318044036626816
eval_/next-item/ndcg@20 = 0.110211081802845
eval_/next-item/recall@10 = 0.17855477333068848
eval_/next-item/recall@20 = 0.24568764865398407
***** Launch training for day 180: *****
Step | Training Loss |
---|---|
200 | 6.608300 |
400 | 6.030500 |
Training completed. Do not forget to share your model on huggingface.co/models =)
***** Evaluation results for day 181:*****
eval_/next-item/avg_precision@10 = 0.13680869340896606
eval_/next-item/avg_precision@20 = 0.14374792575836182
eval_/next-item/ndcg@10 = 0.18158714473247528
eval_/next-item/ndcg@20 = 0.2070869356393814
eval_/next-item/recall@10 = 0.3181818127632141
eval_/next-item/recall@20 = 0.4202226400375366
Visualize the average of metrics over time
OT_results is a dictionary of evaluation scores (ranking metrics) indexed by the given start and end time indices. Since in this example we evaluate on days 179, 180, and 181, each metric holds three values in its list, one for each day.
OT_results
{'indexed_by_time_eval_/next-item/avg_precision@10': [0.08119537681341171,
0.06664632260799408,
0.13680869340896606],
'indexed_by_time_eval_/next-item/avg_precision@20': [0.0857219472527504,
0.07125943899154663,
0.14374792575836182],
'indexed_by_time_eval_/next-item/ndcg@10': [0.11199340969324112,
0.09318044036626816,
0.18158714473247528],
'indexed_by_time_eval_/next-item/ndcg@20': [0.12857995927333832,
0.110211081802845,
0.2070869356393814],
'indexed_by_time_eval_/next-item/recall@10': [0.20809248089790344,
0.17855477333068848,
0.3181818127632141],
'indexed_by_time_eval_/next-item/recall@20': [0.27398842573165894,
0.24568764865398407,
0.4202226400375366]}
import numpy as np
# take the average of metric values over time
avg_results = {k: np.mean(v) for k,v in OT_results.items()}
for key in sorted(avg_results.keys()):
    print(" %s = %s" % (key, str(avg_results[key])))
indexed_by_time_eval_/next-item/avg_precision@10 = 0.09488346427679062
indexed_by_time_eval_/next-item/avg_precision@20 = 0.10024310400088628
indexed_by_time_eval_/next-item/ndcg@10 = 0.12892033159732819
indexed_by_time_eval_/next-item/ndcg@20 = 0.14862599223852158
indexed_by_time_eval_/next-item/recall@10 = 0.23494302233060202
indexed_by_time_eval_/next-item/recall@20 = 0.3132995714743932
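Beyond printing the averages, a quick plot makes the day-to-day variation easier to see. A minimal sketch, assuming matplotlib is available in the environment (it is not imported elsewhere in this notebook):
import matplotlib.pyplot as plt

days = [179, 180, 181]  # the evaluation days used above
for name, scores in OT_results.items():
    # Shorten the long metric keys for the legend, e.g. 'recall@20'.
    plt.plot(days, scores, marker="o", label=name.split("/")[-1])
plt.xticks(days)
plt.xlabel("day index")
plt.ylabel("metric value")
plt.legend(fontsize="small")
plt.show()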
Trace the model
We serve the model with the PyTorch backend that is used to execute TorchScript models. All models created in PyTorch using the Python API must be traced/scripted to produce a TorchScript model. To trace the model, we use the torch.jit.trace API, which takes the model (as a Python function or torch.nn.Module) and an example input that will be passed to the function while tracing.
import os
import torch
import cudf
from merlin.io import Dataset
from nvtabular import Workflow
from merlin.systems.dag import Ensemble
from merlin.systems.dag.ops.pytorch import PredictPyTorch
from merlin.systems.dag.ops.workflow import TransformWorkflow
from merlin.table import TensorTable, TorchColumn
from merlin.table.conversions import convert_col
Create a dict of tensors to feed as example inputs to torch.jit.trace().
df = cudf.read_parquet(os.path.join(INPUT_DATA_DIR, "./preproc_sessions_by_day/178/train.parquet"), columns=model.input_schema.column_names)
table = TensorTable.from_df(df.iloc[:100])
for column in table.columns:
    table[column] = convert_col(table[column], TorchColumn)
model_input_dict = table.to_dict()
Let’s now add a top_k parameter to the model so that we can return the top-k item ids with the highest scores when we serve the model on Triton Inference Server.
model.top_k = 20
Let us now trace the model.
model.eval()
traced_model = torch.jit.trace(model, model_input_dict, strict=True)
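As an optional sanity check (not part of the original pipeline), you can feed the same example batch to both the eager model and the traced module; they should produce the same top-k scores and item ids:
# Both calls take the same dict that was used for tracing.
with torch.no_grad():
    eager_out = model(model_input_dict)
    traced_out = traced_model(model_input_dict)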
Let’s check out the item_id-list column in the model_input_dict dictionary. The column is represented as values and offsets, with the offsets dictating which values belong to which example.
model_input_dict['item_id-list__values']
tensor([ 604, 878, 742, 90, 4777, 1583, 3446, 8083, 3446, 4018,
742, 4777, 184, 12288, 2065, 430, 6, 30, 6, 157,
1987, 2590, 10855, 8217, 4210, 8711, 4242, 81, 112, 4242,
5732, 6810, 33, 6, 73, 664, 2312, 7124, 9113, 445,
1157, 774, 685, 430, 1945, 475, 597, 289, 166, 29,
342, 289, 33, 423, 166, 480, 2772, 288, 962, 4001,
2050, 3274, 499, 1219, 395, 1636, 11839, 10714, 11107, 289,
166, 650, 1085, 302, 88, 650, 214, 304, 177, 317,
423, 6, 3818, 931, 2186, 1085, 206, 687, 3831, 687,
202, 20, 43, 20, 2034, 10457, 20, 21, 1126, 2815,
4210, 34264, 830, 774, 620, 2050, 1987, 1079, 4713, 1336,
661, 289, 430, 863, 4829, 5786, 19156, 17270, 23365, 4209,
3651, 1037, 4770, 224, 277, 1020, 650, 166, 1354, 206,
1889, 2473, 1697, 997, 480, 774, 3841, 4316, 3841, 1230,
3841, 1697, 3809, 475, 981, 804, 313, 613, 1219, 1334,
1941, 2888, 2626, 1334, 2689, 804, 475, 981, 313, 804,
1219, 206, 651, 429, 605, 101, 413, 1965, 627, 814,
627, 814, 4713, 3675, 2789, 3769, 1283, 9540, 1251, 313,
685, 2497, 395, 845, 3462, 2713, 5077, 388, 340, 297,
388, 9641, 46, 61, 822, 61, 602, 2270, 719, 3274,
2556, 2, 61, 24, 96, 61, 423, 61, 475, 3460,
2693, 3460, 3044, 2556, 3988, 992, 1603, 122, 2704, 2787,
3135, 550, 516, 44, 1551, 2702, 206, 1762, 31, 474,
481, 198, 474, 2704, 2393, 1025, 20033, 72, 1334, 224,
3460, 4774, 2050, 6485, 15953, 422, 1488, 2346, 4470, 2548,
571, 1770, 1324, 453, 837, 123, 638, 4759, 3552, 6825,
2740, 5347, 5390, 1169, 4100, 1230, 804, 3588, 2449, 185,
16, 643, 274, 686, 18092, 10457, 609, 2969, 3480, 2969,
37, 609, 2969, 609, 22, 8312, 257, 37, 22, 5361,
7186, 7380, 6052, 7256, 13404, 557, 160, 1664, 4375, 3484,
685, 651, 445, 429, 445, 774, 651, 4284, 1738, 3855,
225, 210, 7245, 6731, 771, 1987, 157, 804, 442, 804,
2091, 1169, 2091, 1169, 3484, 4375, 3484, 445, 429, 430,
423, 1697, 1393, 1798, 2753, 206, 1153, 21588, 2189, 3704,
4463, 5816, 7557, 507, 1797, 814, 627, 2016, 855, 1889,
224, 597, 37, 597, 16533, 10255, 2, 2651, 4028, 2556,
2788, 6379, 1830, 1070, 30, 312, 445, 1085, 1569, 2222,
2664, 1950, 2098, 1672, 224, 4336, 651, 997, 1157, 830,
800, 597, 1085, 12430, 415, 12430, 651, 800, 1756, 1378,
1413, 633, 2034, 7932, 6034, 6360, 4662, 576, 4662, 576,
4662], device='cuda:0')
And here are the offsets:
model_input_dict['item_id-list__offsets']
tensor([ 0, 12, 14, 16, 19, 21, 23, 26, 32, 35, 45, 47, 49, 51,
56, 59, 61, 66, 69, 72, 76, 78, 80, 83, 87, 91, 94, 96,
98, 100, 102, 104, 107, 111, 113, 122, 124, 128, 130, 133, 136, 141,
143, 161, 164, 167, 172, 176, 180, 182, 184, 191, 193, 196, 199, 201,
209, 214, 234, 237, 243, 245, 265, 267, 274, 276, 281, 289, 295, 298,
300, 305, 307, 312, 314, 317, 320, 324, 327, 331, 335, 337, 339, 343,
345, 348, 351, 354, 356, 360, 362, 364, 367, 374, 376, 381, 383, 386,
388, 396, 401], device='cuda:0', dtype=torch.int32)
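To make the values/offsets representation concrete, here is a small illustration (hypothetical helper code, not part of the serving pipeline) that slices one session out of the flat values tensor:
values = model_input_dict["item_id-list__values"]
offsets = model_input_dict["item_id-list__offsets"]

# offsets[i] and offsets[i + 1] bound example i, so given the offsets printed
# above the first session spans values[0:12].
first_session_items = values[offsets[0]:offsets[1]]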
Let’s create a folder that we can store the exported models and the config files.
import shutil
ens_model_path = os.environ.get("ens_model_path", f"{INPUT_DATA_DIR}/models")
# Make sure we start with a clean export directory
if os.path.isdir(ens_model_path):
    shutil.rmtree(ens_model_path)
os.mkdir(ens_model_path)
workflow = Workflow.load(os.path.join(INPUT_DATA_DIR, "workflow_etl"))
torch_op = workflow.input_schema.column_names >> TransformWorkflow(workflow) >> PredictPyTorch(
    traced_model, model.input_schema, model.output_schema
)
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
warnings.warn(
The last step is to create the ensemble artifacts that Triton Inference Server can consume. To make these artifacts, we import the Ensemble class. The class is responsible for interpreting the graph and exporting the correct files for the server.
When we create an Ensemble object, we supply the graph and a schema representing the starting input of the graph. The inputs to the ensemble graph are the inputs to the first operator of our graph. After we create the Ensemble, we export the graph, supplying an export path to the ensemble.export function. This returns an ensemble config, which represents the entire inference pipeline, and a list of node-specific configs.
ensemble = Ensemble(torch_op, workflow.input_schema)
ens_config, node_configs = ensemble.export(ens_model_path)
/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'session_id', which is not being used by any downstream operator in the ensemble graph.
warnings.warn(
/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'item_id-count', which is not being used by any downstream operator in the ensemble graph.
warnings.warn(
/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'day_index', which is not being used by any downstream operator in the ensemble graph.
warnings.warn(
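Before starting Triton, it can be useful to verify what was written to disk. A small sketch using only the standard library; Triton expects one sub-directory per model, each containing a config.pbtxt and a numbered version folder:
# Walk the export directory and print each model folder with its files.
for root, dirs, files in sorted(os.walk(ens_model_path)):
    print(root, files)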
2. Serving Ensemble Model to the Triton Inference Server
NVIDIA Triton Inference Server (TIS) simplifies the deployment of AI models at scale in production. TIS provides a cloud and edge inferencing solution optimized for both CPUs and GPUs. It supports a number of different machine learning frameworks such as TensorFlow and PyTorch.
The last step of a machine learning (ML)/deep learning (DL) pipeline is to deploy the ETL workflow and the saved model to production. In the production setting, we want to transform the input data exactly as was done during training (ETL). We need to apply the same mean/std normalization to continuous features and use the same categorical mapping to convert the categories to continuous integers before we use the DL model for a prediction. Therefore, we deploy the NVTabular workflow together with the PyTorch model as an ensemble model to Triton Inference Server. The ensemble model guarantees that the same transformations are applied to the raw inputs.
In this section, you will learn how to:
deploy saved NVTabular and PyTorch models to Triton Inference Server
send requests for predictions and get responses
2.1 Starting Triton Server
It is time to deploy all the models as an ensemble model to the Triton Inference Server (TIS). After we export the ensemble, we are ready to start TIS. You can start the Triton server by using the following command in your terminal:
tritonserver --model-repository=<ensemble_export_path>
For the --model-repository argument, specify the same path as the export_path that you specified previously in the ensemble.export method. This command launches the server and loads all the models. Once all the models are loaded successfully, you should see a READY status printed in the terminal for each loaded model.
2.2. Connect to the Triton Inference Server and check if the server is alive
import tritonhttpclient
try:
    triton_client = tritonhttpclient.InferenceServerClient(url="localhost:8000", verbose=True)
    print("client created.")
except Exception as e:
    print("channel creation failed: " + str(e))
triton_client.is_server_live()
client created.
GET /v2/health/live, headers None
<HTTPSocketPoolResponse status=200 headers={'content-length': '0', 'content-type': 'text/plain'}>
/usr/local/lib/python3.8/dist-packages/tritonhttpclient/__init__.py:31: DeprecationWarning: The package `tritonhttpclient` is deprecated and will be removed in a future version. Please use instead `tritonclient.http`
warnings.warn(
True
2.3. Load raw data for inference
We select the last 50 interactions and filter out sessions with fewer than 2 interactions.
import pandas as pd
interactions_merged_df = pd.read_parquet(os.path.join(INPUT_DATA_DIR, "interactions_merged_df.parquet"))
interactions_merged_df = interactions_merged_df.sort_values('timestamp')
batch = interactions_merged_df[-50:]
sessions_to_use = batch.session_id.value_counts()
filtered_batch = batch[batch.session_id.isin(sessions_to_use[sessions_to_use.values>1].index.values)]
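As a quick optional check, we can count how many sessions remain after filtering; each one will produce one row of predictions in the response below:
print(filtered_batch.session_id.nunique())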
2.4. Send the request to the Triton server
triton_client.get_model_repository_index()
POST /v2/repository/index, headers None
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '188'}>
bytearray(b'[{"name":"0_transformworkflowtriton","version":"1","state":"READY"},{"name":"1_predictpytorchtriton","version":"1","state":"READY"},{"name":"executor_model","version":"1","state":"READY"}]')
[{'name': '0_transformworkflowtriton', 'version': '1', 'state': 'READY'},
{'name': '1_predictpytorchtriton', 'version': '1', 'state': 'READY'},
{'name': 'executor_model', 'version': '1', 'state': 'READY'}]
If all models are loaded successfully, you should see a READY status next to each model.
from merlin.systems.triton.utils import send_triton_request
response = send_triton_request(workflow.input_schema, filtered_batch, model.output_schema.column_names)
print(response)
{'item_id_scores': array([[ 7.7149506, 7.3862066, 7.3664827, 7.2885814, 6.8977885,
6.853054 , 6.6473446, 6.5615396, 6.5239835, 6.504697 ,
6.4961066, 6.482874 , 6.478437 , 6.4518657, 6.403399 ,
6.3845167, 6.372211 , 6.3541226, 6.2738447, 6.2641015],
[ 7.8986144, 7.575227 , 6.9020343, 6.8903823, 6.839163 ,
6.833202 , 6.761392 , 6.7598944, 6.6728435, 6.6523266,
6.640926 , 6.6118965, 6.470592 , 6.419156 , 6.4112453,
6.4044685, 6.386283 , 6.3791957, 6.362708 , 6.32642 ],
[ 3.2538116, 3.1455884, 3.1179721, 3.0379417, 2.9422204,
2.9090347, 2.8573742, 2.8552375, 2.8385096, 2.8210163,
2.8062954, 2.7926047, 2.7883859, 2.7677355, 2.7579772,
2.7548594, 2.7234986, 2.7234783, 2.7092912, 2.7081766],
[ 7.1264834, 6.9747195, 6.686076 , 6.55559 , 6.554246 ,
6.4562297, 6.409179 , 6.303057 , 6.1918006, 6.18303 ,
6.1193542, 6.107653 , 6.0896993, 6.0872216, 6.0593266,
6.0522294, 5.960239 , 5.953763 , 5.9512544, 5.93457 ],
[ 8.472156 , 8.323568 , 8.279997 , 8.109349 , 8.0277195,
7.8910418, 7.743165 , 7.7279477, 7.692122 , 7.6912303,
7.6711187, 7.5660324, 7.5545163, 7.524415 , 7.4610972,
7.424488 , 7.4123936, 7.411586 , 7.3956637, 7.334862 ],
[10.485718 , 10.375805 , 10.131399 , 9.950816 , 9.917899 ,
9.840376 , 9.77128 , 9.676795 , 9.599488 , 9.544931 ,
9.512772 , 9.488843 , 9.455492 , 9.438552 , 9.430434 ,
9.37854 , 9.362637 , 9.329395 , 9.303702 , 9.279214 ],
[ 5.836771 , 5.828289 , 5.7385616, 5.6870475, 5.5709176,
5.5614986, 5.5448956, 5.3459535, 5.344148 , 5.310749 ,
5.2611475, 5.2548814, 5.1640043, 5.1203575, 5.0081887,
4.987002 , 4.976223 , 4.90389 , 4.8933997, 4.86463 ],
[ 8.174524 , 7.422059 , 7.3884916, 7.272735 , 7.1977425,
7.0858173, 6.645329 , 6.5097084, 6.470001 , 6.453685 ,
6.2901287, 6.1282573, 6.1203957, 6.0265145, 6.0052614,
5.9081306, 5.880647 , 5.78306 , 5.7507606, 5.7427897],
[ 6.7380514, 6.735104 , 6.6285214, 6.5611916, 6.527247 ,
6.4949493, 6.4821672, 6.474738 , 6.4419727, 6.4345016,
6.3947067, 6.352811 , 6.314824 , 6.293556 , 6.288764 ,
6.2727513, 6.2707367, 6.2422667, 6.2342215, 6.196748 ],
[ 5.6076183, 5.472957 , 5.412912 , 5.3116655, 5.177665 ,
5.112378 , 5.085245 , 4.9710717, 4.940496 , 4.93859 ,
4.8099985, 4.760807 , 4.7229834, 4.7145324, 4.707185 ,
4.667753 , 4.6421194, 4.630641 , 4.6156597, 4.5743294],
[ 8.147796 , 7.6630974, 7.5877905, 7.4908724, 7.4732018,
7.404136 , 7.352083 , 7.3242574, 7.3009863, 7.242015 ,
7.1836567, 7.1416435, 7.138858 , 7.134639 , 7.111288 ,
7.0651016, 7.0538983, 7.0510435, 7.044665 , 7.035279 ],
[ 3.2990756, 3.1885252, 3.1802373, 3.0592136, 3.045468 ,
3.0039759, 2.9578297, 2.9415162, 2.93026 , 2.9297981,
2.8965662, 2.8944998, 2.877764 , 2.8689 , 2.8668582,
2.8634927, 2.78628 , 2.767284 , 2.7598662, 2.7465074],
[ 8.852634 , 8.734692 , 8.606397 , 8.578839 , 8.453225 ,
8.416724 , 8.407235 , 8.343583 , 8.280216 , 8.269384 ,
8.20807 , 8.177954 , 8.15201 , 8.136789 , 8.103194 ,
7.873078 , 7.852006 , 7.849255 , 7.8436575, 7.7874575],
[ 8.121338 , 7.9475265, 7.9030285, 7.8839293, 7.850343 ,
7.8212285, 7.7303505, 7.655741 , 7.5204306, 7.4822707,
7.4198914, 7.3865814, 7.3747654, 7.3459187, 7.3127384,
7.2247725, 7.1183944, 7.111621 , 7.0977235, 7.089139 ],
[ 8.760972 , 8.427143 , 8.166194 , 8.004948 , 7.831641 ,
7.8019366, 7.784935 , 7.6579127, 7.5422807, 7.531015 ,
7.48702 , 7.448603 , 7.3950186, 7.355933 , 7.3510966,
7.2111845, 7.183003 , 7.166962 , 7.1576095, 7.1405106]],
dtype=float32), 'item_ids': array([[ 127, 1245, 1271, 4074, 161, 532, 3928, 19290, 11874,
1446, 346, 9285, 3452, 3334, 10987, 479, 7555, 3206,
14633, 13677],
[ 2693, 3460, 10401, 2393, 9415, 9285, 14213, 2404, 13750,
10987, 8084, 183, 2889, 5110, 16662, 10962, 7865, 14401,
4158, 3211],
[ 4278, 3225, 5889, 953, 7577, 1420, 4185, 4591, 573,
10287, 3514, 7923, 5623, 3269, 873, 2433, 15077, 5110,
6357, 3593],
[10401, 573, 3211, 2404, 3452, 2763, 4014, 786, 3296,
6030, 4380, 7481, 2067, 10962, 183, 809, 4651, 2806,
2575, 8084],
[10987, 9415, 9285, 479, 7998, 7555, 2475, 2404, 8084,
2693, 2661, 1102, 786, 14343, 3460, 161, 13677, 2889,
2393, 1600],
[ 8084, 13677, 14633, 12754, 14213, 3452, 13821, 2984, 11816,
2475, 1600, 4074, 13749, 9285, 17882, 10401, 10673, 1964,
11874, 9801],
[ 224, 1453, 1889, 620, 2556, 520, 741, 1219, 633,
2050, 6344, 2651, 2852, 1039, 3841, 375, 2473, 3225,
4204, 2980],
[ 620, 633, 1889, 1625, 1453, 1219, 4962, 6344, 224,
2034, 1413, 2980, 741, 1908, 597, 1085, 4770, 2305,
1647, 1334],
[ 573, 14505, 3365, 9020, 5932, 7470, 7055, 6344, 10962,
11937, 6488, 13821, 19263, 14213, 633, 8553, 10123, 10401,
1453, 15439],
[15148, 13277, 1381, 24893, 11689, 37268, 7454, 1547, 2305,
30765, 46450, 18630, 8747, 1085, 1889, 25487, 7506, 620,
14220, 8291],
[19263, 13821, 15289, 11874, 8058, 17882, 7159, 12754, 6473,
14213, 16233, 573, 3452, 12059, 1600, 8084, 10401, 13749,
15439, 9801],
[ 8813, 1889, 464, 1855, 520, 127, 185, 4204, 637,
31, 784, 2789, 4713, 6461, 11874, 840, 177, 5224,
5818, 6473],
[ 7865, 5932, 7055, 14213, 9020, 4541, 8084, 6488, 8553,
14505, 6473, 10962, 4204, 4136, 16233, 13821, 10401, 9459,
9801, 15439],
[10401, 6488, 5932, 7865, 10962, 14213, 8553, 8084, 9020,
573, 4204, 13821, 3175, 15439, 4136, 10123, 11874, 3452,
17882, 9236],
[ 2693, 3460, 14213, 7865, 5932, 6488, 4204, 8084, 10962,
8553, 1219, 9020, 4541, 13750, 9236, 1453, 4136, 2980,
10401, 2651]])}
The response contains scores (logits) for each of the returned items and the returned item ids.
response.keys()
dict_keys(['item_id_scores', 'item_ids'])
Just as we requested by setting the top_k parameter, only 20 predictions are returned per session.
response['item_ids'].shape
(15, 20)
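The returned ids are the encoded categorical values produced by the NVTabular workflow. As a hedged sketch of how you could map them back to the original item ids (the mapping-file location comes from properties.cat_path in the schema above and is relative to where the workflow was fit; the parquet layout may vary across NVTabular versions):
# Row position in the mapping parquet corresponds to the encoded id.
mapping = cudf.read_parquet("./categories/unique.item_id.parquet")
top1_encoded = response["item_ids"][:, 0]  # best prediction per session
print(mapping["item_id"].iloc[top1_encoded])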
This is the end of the tutorial. You successfully:
performed feature engineering with NVTabular
trained a Transformer-based session-based recommendation model with Transformers4Rec
deployed the trained model to Triton Inference Server, sent requests, and received responses from the server
References
Merlin Transformers4Rec: https://github.com/NVIDIA-Merlin/Transformers4Rec
Merlin NVTabular: https://github.com/NVIDIA-Merlin/NVTabular/tree/stable/nvtabular
Merlin Dataloader: https://github.com/NVIDIA-Merlin/dataloader
Triton Inference Server: https://github.com/triton-inference-server