# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
Serving a Session-based Recommendation model with Torch Backend
This notebook is created using the latest stable merlin-pytorch container.
Before starting this notebook, we expect that you have already executed the 01-ETL-with-NVTabular.ipynb and 02-session-based-XLNet-with-PyT.ipynb notebooks and saved the NVT workflow and the trained session-based model.
In this notebook, you are going to learn how to serve a trained Transformer-based PyTorch model on NVIDIA Triton Inference Server (TIS) with the Torch backend using the Merlin Systems library. One common way to do inference with a trained model is to use TorchScript, an intermediate representation of a PyTorch model that can be run in Python as well as in a high-performance environment like C++. TorchScript is the recommended model format for scaled inference and deployment. The TIS PyTorch (LibTorch) backend is designed to run TorchScript models using the PyTorch C++ API.
Triton Inference Server (TIS) simplifies the deployment of AI models at scale in production. TIS provides a cloud and edge inferencing solution optimized for both CPUs and GPUs. It supports a number of different machine learning frameworks such as TensorFlow and PyTorch.
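To make the TorchScript workflow concrete, below is a minimal, purely illustrative sketch (not part of this notebook's pipeline) that traces a toy module, saves it, and loads it back; the toy module and the toy_model.pt file name are made up for the example.
import torch
# Trace a toy module to TorchScript; the saved file could then be loaded back
# in Python (torch.jit.load) or from the LibTorch C++ API that the Triton
# PyTorch backend uses.
toy = torch.nn.Linear(4, 2).eval()
example_input = torch.randn(1, 4)
traced_toy = torch.jit.trace(toy, example_input)
torch.jit.save(traced_toy, "toy_model.pt")  # hypothetical file name
reloaded = torch.jit.load("toy_model.pt")
assert torch.equal(traced_toy(example_input), reloaded(example_input))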
Import required libraries
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import cudf
import glob
import numpy as np
import pandas as pd
import torch
from transformers4rec import torch as tr
from merlin.io import Dataset
from merlin.core.dispatch import make_df
from merlin.systems.dag import Ensemble
from merlin.systems.dag.ops.pytorch import PredictPyTorch
from merlin.systems.dag.ops.workflow import TransformWorkflow
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'
warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
We define the paths for the input data, the output directory, and the saved model.
INPUT_DATA_DIR = os.environ.get("INPUT_DATA_DIR", "/workspace/data")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", f"{INPUT_DATA_DIR}/sessions_by_day")
model_path = os.environ.get("model_path", f"{INPUT_DATA_DIR}/saved_model")
Set the schema object
We create the schema object by reading the processed train parquet file.
from merlin.schema import Schema
from merlin.io import Dataset
train = Dataset(os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"))
schema = train.schema
We need to load the saved model to be able to serve it on TIS.
import cloudpickle
loaded_model = cloudpickle.load(
open(os.path.join(model_path, "t4rec_model_class.pkl"), "rb")
)
Switch the model to eval mode. We call model.eval() before tracing to set dropout and batch normalization layers to evaluation mode; failing to do this might yield inconsistent inference results.
model = loaded_model.cuda()
model.eval()
Model(
(heads): ModuleList(
(0): Head(
(body): SequentialBlock(
(0): TabularSequenceFeatures(
(to_merge): ModuleDict(
(continuous_module): SequentialBlock(
(0): ContinuousFeatures(
(filter_features): FilterFeatures()
(_aggregation): ConcatFeatures()
)
(1): SequentialBlock(
(0): DenseBlock(
(0): Linear(in_features=2, out_features=64, bias=True)
(1): ReLU(inplace=True)
)
)
(2): AsTabular()
)
(categorical_module): SequenceEmbeddingFeatures(
(filter_features): FilterFeatures()
(embedding_tables): ModuleDict(
(item_id-list): Embedding(495, 64, padding_idx=0)
(category-list): Embedding(172, 64, padding_idx=0)
)
)
)
(_aggregation): ConcatFeatures()
(projection_module): SequentialBlock(
(0): DenseBlock(
(0): Linear(in_features=192, out_features=100, bias=True)
(1): ReLU(inplace=True)
)
)
(_masking): MaskedLanguageModeling()
)
(1): SequentialBlock(
(0): DenseBlock(
(0): Linear(in_features=100, out_features=64, bias=True)
(1): ReLU(inplace=True)
)
)
(2): TransformerBlock(
(transformer): XLNetModel(
(word_embedding): Embedding(1, 64)
(layer): ModuleList(
(0): XLNetLayer(
(rel_attn): XLNetRelativeAttention(
(layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(ff): XLNetFeedForward(
(layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)
(layer_1): Linear(in_features=64, out_features=256, bias=True)
(layer_2): Linear(in_features=256, out_features=64, bias=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(dropout): Dropout(p=0.3, inplace=False)
)
(1): XLNetLayer(
(rel_attn): XLNetRelativeAttention(
(layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(ff): XLNetFeedForward(
(layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)
(layer_1): Linear(in_features=64, out_features=256, bias=True)
(layer_2): Linear(in_features=256, out_features=64, bias=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(dropout): Dropout(p=0.3, inplace=False)
)
)
(dropout): Dropout(p=0.3, inplace=False)
)
(masking): MaskedLanguageModeling()
)
)
(prediction_task_dict): ModuleDict(
(next-item): NextItemPredictionTask(
(sequence_summary): SequenceSummary(
(summary): Identity()
(activation): Identity()
(first_dropout): Identity()
(last_dropout): Identity()
)
(metrics): ModuleList(
(0): NDCGAt()
(1): RecallAt()
)
(loss): CrossEntropyLoss()
(embeddings): SequenceEmbeddingFeatures(
(filter_features): FilterFeatures()
(embedding_tables): ModuleDict(
(item_id-list): Embedding(495, 64, padding_idx=0)
(category-list): Embedding(172, 64, padding_idx=0)
)
)
(item_embedding_table): Embedding(495, 64, padding_idx=0)
(masking): MaskedLanguageModeling()
(pre): Block(
(module): NextItemPredictionTask(
(item_embedding_table): Embedding(495, 64, padding_idx=0)
)
)
)
)
)
)
)
Trace the model
We serve the model with the PyTorch backend, which is used to execute TorchScript models. All models created in PyTorch using the Python API must be traced/scripted to produce a TorchScript model. For tracing the model, we use the torch.jit.trace API, which takes the model as a Python function or torch.nn.Module and an example input that will be passed to the function while tracing.
train_paths = os.path.join(OUTPUT_DIR, f"{1}/train.parquet")
dataset = Dataset(train_paths)
Create a dict of tensors to feed as example inputs to torch.jit.trace().
import pandas as pd
from merlin.table import TensorTable, TorchColumn
from merlin.table.conversions import convert_col
df = cudf.read_parquet(train_paths, columns=model.input_schema.column_names)
table = TensorTable.from_df(df.loc[:100])
for column in table.columns:
table[column] = convert_col(table[column], TorchColumn)
model_input_dict = table.to_dict()
model_input_dict['item_id-list__values']
tensor([306, 5, 40, 17, 43, 20, 69, 8, 57, 137, 35, 37, 85, 65,
5, 28, 9, 153, 74, 53, 15, 173, 59, 32, 11, 21, 23, 23,
9, 15, 12, 69, 37, 16, 6, 22, 39, 20, 22, 95, 40, 7,
25, 32, 17, 8, 26, 32, 33, 18, 12, 10, 41, 14, 28, 56,
30, 21, 16, 42, 13, 83, 65, 46, 105, 38, 11, 3, 3, 14,
9, 36, 116, 15, 15, 23, 8, 16, 68, 151, 60, 18, 48, 19,
16, 4, 37, 246, 169, 21, 16, 116, 27, 4, 19, 76, 6, 31,
153, 38, 35, 11, 38, 3, 73, 38, 74, 6, 7, 12, 18, 10,
54, 11, 29, 5, 24, 11, 20, 3, 17, 42, 26, 24, 30, 26,
62, 89, 12, 38, 18, 3, 10, 18, 15, 131, 19, 6, 51, 60,
10, 3, 14, 22, 21, 39, 44, 221, 88, 14, 16, 80, 5, 16,
21, 81, 27, 8, 20, 49, 32, 83, 49, 19, 3, 17, 8, 10,
29, 62, 94, 38, 15, 11, 12, 16, 10, 31, 7, 53, 3, 42,
38, 25, 5, 62, 20, 73, 48, 6, 12, 19, 15, 38, 30, 9,
82, 31, 49, 64, 22, 38, 10, 56, 11, 13, 3, 14, 39, 18,
47, 65, 18, 15, 9, 74, 3, 50, 37, 22, 66, 47, 23, 17,
8, 21, 35, 7, 12, 16, 21, 26, 31, 13, 20, 9, 193, 49,
9, 62, 51, 45, 90, 14, 47, 9, 73, 16, 3, 62, 24, 82,
7, 14, 37, 29, 26, 42, 6, 90, 3, 10, 33, 7, 7, 10,
31, 10, 12, 21, 55, 25, 21, 3, 20, 24, 25, 4, 3, 52,
5, 5, 10, 12, 37, 162, 31, 5, 119, 5, 24, 65, 4, 10,
46, 86, 5, 58, 15, 48, 66, 14, 23, 12, 13, 6, 48, 8,
22, 95, 5, 42, 86, 108, 26, 7, 80, 54, 63, 12, 147, 177,
17, 18, 24, 15, 40, 5, 40, 7, 6, 63, 4, 18, 123, 33,
36, 25, 40, 18, 16, 10, 18, 26, 21, 59, 44, 12, 28, 30,
134, 7, 21, 8, 7, 32, 41, 60, 52, 25, 36, 6, 45, 39,
16, 20, 95, 8, 56, 53, 48, 17, 14, 3, 46, 35, 17, 12,
30, 8, 5, 54, 75, 96, 4, 43, 8, 61, 4, 8, 34, 30,
34, 49, 29, 92, 6, 28, 26, 22, 46, 20, 11, 14, 13, 75,
22, 21, 17, 166, 4, 87, 5, 11, 37, 26, 23],
device='cuda:0')
traced_model = torch.jit.trace(model, model_input_dict, strict=True)
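As an optional sanity check, a minimal sketch (assuming the model returns a tensor, or a dict/tuple of tensors, of next-item scores in eval mode) is to compare the traced model's output against the eager model's output on the same example batch:
with torch.no_grad():
    eager_out = model(model_input_dict)
    traced_out = traced_model(model_input_dict)
# torch.testing.assert_close handles tensors as well as dicts/tuples of tensors
torch.testing.assert_close(eager_out, traced_out, rtol=1e-4, atol=1e-4)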
Generate the model input and output schemas to feed into the PredictPyTorch operator below.
input_schema = model.input_schema
output_schema = model.output_schema
input_schema
  | name | tags | dtype | is_list | is_ragged | properties.value_count.min | properties.value_count.max | properties.num_buckets | properties.freq_threshold | properties.max_size | properties.cat_path | properties.embedding_sizes.cardinality | properties.embedding_sizes.dimension | properties.domain.min | properties.domain.max | properties.domain.name |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | weekday_sin-list | (Tags.LIST, Tags.CONTINUOUS) | DType(name='float32', element_type=<ElementTyp... | True | True | 2 | 16 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
1 | age_days-list | (Tags.LIST, Tags.CONTINUOUS) | DType(name='float32', element_type=<ElementTyp... | True | True | 2 | 16 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2 | item_id-list | (Tags.CATEGORICAL, Tags.LIST, Tags.ITEM, Tags.ID) | DType(name='int64', element_type=<ElementType.... | True | True | 2 | 16 | NaN | 0.0 | 0.0 | .//categories/unique.item_id.parquet | 495.0 | 52.0 | 0.0 | 494.0 | item_id |
3 | category-list | (Tags.CATEGORICAL, Tags.LIST) | DType(name='int64', element_type=<ElementType.... | True | True | 2 | 16 | NaN | 0.0 | 0.0 | .//categories/unique.category.parquet | 172.0 | 29.0 | 0.0 | 171.0 | category |
Let’s create a folder where we can store the exported models and the config files.
import shutil
ens_model_path = os.environ.get("ens_model_path", f"{INPUT_DATA_DIR}/models")
# Make sure we start with a clean model export directory
if os.path.isdir(ens_model_path):
shutil.rmtree(ens_model_path)
os.mkdir(ens_model_path)
We want to serve the NVTabular workflow and our trained session-based model together as an ensemble on the Triton Inference Server. That way we can send raw requests to Triton and get back item scores per session. For that we need to load our saved workflow first.
from nvtabular.workflow import Workflow
workflow = Workflow.load(os.path.join(INPUT_DATA_DIR, "workflow_etl"))
print(workflow.input_schema.column_names)
['item_id', 'category', 'day', 'age_days', 'weekday_sin', 'session_id']
For transforming the raw input features during inference, we use the TransformWorkflow operator, which ensures the workflow is correctly saved and packaged with the required config so the server will know how to load it. We use the PredictPyTorch operator, which takes a PyTorch model and packages it correctly for Triton to run it on the PyTorch backend.
torch_op = workflow.input_schema.column_names >> TransformWorkflow(workflow) >> PredictPyTorch(
traced_model, input_schema, output_schema
)
ensemble = Ensemble(torch_op, workflow.input_schema)
/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'session_id', which is not being used by any downstream operator in the ensemble graph.
warnings.warn(
/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'day-first', which is not being used by any downstream operator in the ensemble graph.
warnings.warn(
The last step is to create the ensemble artifacts that Triton Inference Server can consume. To make these artifacts, we import the Ensemble class. The class is responsible for interpreting the graph and exporting the correct files for the server.
When we create an Ensemble object, we supply the graph and a schema representing the starting input of the graph. The inputs to the ensemble graph are the inputs to the first operator of our graph. After we create the Ensemble, we export the graph, supplying an export path to the ensemble.export function. This returns an ensemble config, which represents the entire inference pipeline, and a list of node-specific configs.
ens_config, node_configs = ensemble.export(ens_model_path)
ensemble.input_schema
  | name | tags | dtype | is_list | is_ragged |
---|---|---|---|---|---|
0 | item_id | () | DType(name='int32', element_type=<ElementType.... | False | False |
1 | category | () | DType(name='int32', element_type=<ElementType.... | False | False |
2 | day | () | DType(name='int64', element_type=<ElementType.... | False | False |
3 | age_days | () | DType(name='float32', element_type=<ElementTyp... | False | False |
4 | weekday_sin | () | DType(name='float32', element_type=<ElementTyp... | False | False |
5 | session_id | () | DType(name='int64', element_type=<ElementType.... | False | False |
Starting Triton Server
It is time to deploy all the models as an ensemble model to the Triton Inference Server (TIS). After we export the ensemble, we are ready to start TIS. You can start the Triton server by using the following command in your terminal:
tritonserver --model-repository=<ensemble_export_path>
For the --model-repository argument, specify the same path as the export_path that you specified previously in the ensemble.export method. This command launches the server and loads all the models. Once all the models are loaded successfully, you should see a READY status printed out in the terminal for each loaded model.
import tritonclient.http as client
# Create a triton client
try:
triton_client = client.InferenceServerClient(url="localhost:8000", verbose=True)
print("client created.")
except Exception as e:
print("channel creation failed: " + str(e))
client created.
After we create the client and verify that it is connected to the server instance, we can communicate with the server and ensure all the models are loaded correctly.
# ensure triton is in a good state
triton_client.is_server_live()
triton_client.get_model_repository_index()
GET /v2/health/live, headers None
<HTTPSocketPoolResponse status=200 headers={'content-length': '0', 'content-type': 'text/plain'}>
POST /v2/repository/index, headers None
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '188'}>
bytearray(b'[{"name":"0_transformworkflowtriton","version":"1","state":"READY"},{"name":"1_predictpytorchtriton","version":"1","state":"READY"},{"name":"executor_model","version":"1","state":"READY"}]')
[{'name': '0_transformworkflowtriton', 'version': '1', 'state': 'READY'},
{'name': '1_predictpytorchtriton', 'version': '1', 'state': 'READY'},
{'name': 'executor_model', 'version': '1', 'state': 'READY'}]
Send request to Triton and get the response
The last step of a machine learning (ML)/deep learning (DL) pipeline is to deploy the model to production and get responses for a given query or set of queries. In this section, we generate a dataframe that we can send as a request to TIS. We send the raw dataframe, and in a production setting the input data must be transformed the same way as during training (ETL): the same mean/std are applied to the continuous features and the same categorical mappings convert the categories to contiguous integers before the deployed DL model makes a prediction. This is exactly what the TransformWorkflow operator in our ensemble does for us.
Let’s generate a dataframe with raw input values. We can send this dataframe to Triton as a request.
NUM_ROWS = 1000
long_tailed_item_distribution = np.clip(np.random.lognormal(3., 1., int(NUM_ROWS)).astype(np.int32), 1, 50000)
# generate random item interaction features
df = pd.DataFrame(np.random.randint(70000, 90000, int(NUM_ROWS)), columns=['session_id'])
df['item_id'] = long_tailed_item_distribution
# generate category mapping for each item-id
df['category'] = pd.cut(df['item_id'], bins=334, labels=np.arange(1, 335)).astype(np.int32)
df['age_days'] = np.random.uniform(0, 1, int(NUM_ROWS)).astype(np.float32)
df['weekday_sin']= np.random.uniform(0, 1, int(NUM_ROWS)).astype(np.float32)
# generate day mapping for each session
map_day = dict(zip(df.session_id.unique(), np.random.randint(1, 10, size=(df.session_id.nunique()))))
df['day'] = df.session_id.map(map_day)
print(df.head(2))
session_id item_id category age_days weekday_sin day
0 79856 3 2 0.327276 0.080060 2
1 74117 6 4 0.012172 0.147716 1
Once our models are successfully loaded on TIS, we can easily send a request to TIS and get a response for our query with the send_triton_request utility function.
from merlin.systems.triton.utils import send_triton_request
response = send_triton_request(workflow.input_schema, df, output_schema.column_names, endpoint="localhost:8001")
response
{'next-item': array([[-3.9399953, -2.632081 , -4.2211075, ..., -3.6699016, -3.673493 ,
-3.1244578],
[-3.940445 , -2.6335964, -4.2203593, ..., -3.671566 , -3.6745713,
-3.1240335],
[-3.9393594, -2.6300201, -4.222065 , ..., -3.6674871, -3.672068 ,
-3.1251097],
...,
[-3.9396427, -2.6304667, -4.2218847, ..., -3.6677885, -3.6724825,
-3.1250875],
[-3.939829 , -2.6316376, -4.221267 , ..., -3.6693997, -3.6732295,
-3.1245873],
[-3.9399223, -2.631995 , -4.2210817, ..., -3.669589 , -3.6734715,
-3.1244512]], dtype=float32)}
response['next-item'].shape
(28, 495)
We get back one prediction row per session in the request dataframe (the NVTabular workflow in the ensemble groups the interactions by session_id). Each row in the response['next-item'] array contains the logit values for every item id in the catalog, including one logit for the null/OOV/padded item: the first score in each row corresponds to that padded, OOV, or null item. Note that we don't have OOV or null items in our synthetically generated dataset.
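As a minimal post-processing sketch (an illustrative assumption, not part of the original pipeline), we can rank the returned logits to get the top-k encoded item ids per session; mapping these encoded ids back to raw item ids would additionally require the unique.item_id.parquet mapping produced by NVTabular.
k = 5
logits = response['next-item']
# argsort the negated logits to rank items from highest to lowest score
top_k_encoded_items = np.argsort(-logits, axis=1)[:, :k]
print(top_k_encoded_items.shape)  # (num_sessions, k)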
This is the end of this suite of examples. You successfully performed feature engineering with NVTabular, trained a Transformer-based session-based recommendation model with Transformers4Rec, deployed the saved workflow and the trained model to Triton Inference Server with the Torch backend, sent requests, and got responses from the server. If you would like to learn how to serve a Transformers4Rec model with the Python backend, please visit this example.