# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================
Getting Started MovieLens: Serving a HugeCTR Model
In this notebook, we show how to run inference with our trained deep learning recommender model using Triton Inference Server. In this example, we deploy the NVTabular workflow and the HugeCTR model with Triton Inference Server as an ensemble. For each request, Triton Inference Server feeds the input data through the NVTabular workflow and passes its output to the HugeCTR model.
Getting Started
We need to write configuration files with the stored model weights and model configuration.
%%writefile /model/movielens_hugectr/config.pbtxt
# Triton model configuration for the "movielens_hugectr" model, served by the
# "hugectr" backend. This cell writes it to /model/movielens_hugectr/config.pbtxt.
name: "movielens_hugectr"
backend: "hugectr"
max_batch_size: 64
# Request inputs. The client script in this notebook fills them as follows:
#   DES       - dense (continuous) features, float32 (empty in this example)
#   CATCOLUMN - flattened categorical keys, int64
#   ROWINDEX  - row offsets into CATCOLUMN, int32
input [
{
name: "DES"
data_type: TYPE_FP32
dims: [ -1 ]
},
{
name: "CATCOLUMN"
data_type: TYPE_INT64
dims: [ -1 ]
},
{
name: "ROWINDEX"
data_type: TYPE_INT32
dims: [ -1 ]
}
]
# Single float32 output tensor; the client prints it as the prediction result.
output [
{
name: "OUTPUT0"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
# One model instance, placed on GPU 0.
instance_group [
{
count: 1
kind : KIND_GPU
gpus:[0]
}
]
# Backend-specific settings; every value is passed as a string.
parameters [
{
# Same path that /model/ps.json lists as this model's "network_file".
key: "config"
value: { string_value: "/model/movielens_hugectr/1/movielens.json" }
},
{
key: "gpucache"
value: { string_value: "true" }
},
{
# NOTE(review): /model/ps.json sets "hit_rate_threshold" to "0.9" for this
# model while this file uses "0.8" - confirm which value is intended.
key: "hit_rate_threshold"
value: { string_value: "0.8" }
},
{
key: "gpucacheper"
value: { string_value: "0.5" }
},
{
key: "label_dim"
value: { string_value: "1" }
},
{
# NOTE(review): presumably one slot per categorical column
# (userId, movieId, genres) - verify against the training configuration.
key: "slots"
value: { string_value: "3" }
},
{
# The server rejected the recorded request because the CATCOLUMN length was
# not an integer multiple of the configured sample size.
# NOTE(review): presumably that size is this value - confirm.
key: "cat_feature_num"
value: { string_value: "4" }
},
{
# This example has no continuous columns, so no dense features are expected.
key: "des_feature_num"
value: { string_value: "0" }
},
{
key: "max_nnz"
value: { string_value: "2" }
},
{
key: "embedding_vector_size"
value: { string_value: "16" }
},
{
# NOTE(review): presumably selects 64-bit embedding keys, matching the
# TYPE_INT64 CATCOLUMN input above - confirm.
key: "embeddingkey_long_type"
value: { string_value: "true" }
}
]
Overwriting /model/movielens_hugectr/config.pbtxt
%%writefile /model/ps.json
{
"supportlonglong":true,
"models":[
{
"model":"movielens_hugectr",
"sparse_files":["/model/movielens_hugectr/0_sparse_1900.model"],
"dense_file":"/model/movielens_hugectr/_dense_1900.model",
"network_file":"/model/movielens_hugectr/1/movielens.json",
"num_of_worker_buffer_in_pool": "1",
"num_of_refresher_buffer_in_pool": "1",
"cache_refresh_percentage_per_iteration": "0.2",
"deployed_device_list":["0"],
"max_batch_size":"64",
"default_value_for_each_table":["0.0","0.0"],
"hit_rate_threshold":"0.9",
"gpucacheper":"0.5",
"gpucache":"true"
}
]
}
Overwriting /model/ps.json
Let’s import required libraries.
import tritonclient.grpc as httpclient
import cudf
import numpy as np
Load Models on Triton Inference Server
At this stage, you should launch the Triton Inference Server docker container with the following script:
docker run -it --gpus=all -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ${PWD}:/model nvcr.io/nvidia/merlin/merlin-hugectr:latest
For production use, refer to the Merlin containers from the NVIDIA GPU Cloud (NGC) catalog and specify a tag rather than `latest`.
After you start the container, start Triton Inference Server with the following command:
tritonserver --model-repository=<path_to_models> --backend-config=hugectr,ps=<path_to_models>/ps.json --model-control-mode=explicit
Note: The model-repository path is `/model/`. The models haven’t been loaded yet. We can request Triton Inference Server to load the saved ensemble. We initialize a Triton client. The path for the JSON file is `/model/movielens_hugectr/1/movielens.json`.
# Silence library warnings so the notebook output stays readable.
import warnings

warnings.filterwarnings("ignore")

# The standalone `tritonhttpclient` package is deprecated in favour of
# `tritonclient.http`; the alias keeps the `tritonhttpclient` name available
# to any code that still refers to it.
import tritonclient.http as tritonhttpclient

# Create the HTTP client for the Triton server on localhost:8000.
# verbose=True makes the client print every request and response.
try:
    triton_client = tritonhttpclient.InferenceServerClient(url="localhost:8000", verbose=True)
    print("client created.")
except Exception as e:
    # Report the failure instead of raising so the cell output shows the reason.
    print("channel creation failed: " + str(e))
client created.
/usr/local/lib/python3.8/dist-packages/tritonhttpclient/__init__.py:31: DeprecationWarning: The package `tritonhttpclient` is deprecated and will be removed in a future version. Please use instead `tritonclient.http`
warnings.warn(
triton_client.is_server_live()
GET /v2/health/live, headers None
<HTTPSocketPoolResponse status=200 headers={'content-length': '0', 'content-type': 'text/plain'}>
True
triton_client.get_model_repository_index()
POST /v2/repository/index, headers None
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '46'}>
bytearray(b'[{"name":"data"},{"name":"movielens_hugectr"}]')
[{'name': 'data'}, {'name': 'movielens_hugectr'}]
Let’s load our model to Triton Server.
%%time
triton_client.load_model(model_name="movielens_hugectr")
POST /v2/repository/models/movielens_hugectr/load, headers None
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '0'}>
Loaded model 'movielens_hugectr'
CPU times: user 2.6 ms, sys: 2.57 ms, total: 5.17 ms
Wall time: 3.62 s
Let’s send a request to Inference Server and print out the response. Since in our example above we do not have continuous columns, below our only inputs are categorical columns.
import pandas as pd
df = pd.read_parquet("/model/data/valid/part_0.parquet")
df.head()
userId | movieId | genres | rating | |
---|---|---|---|---|
0 | 32187 | 520 | [2, 6] | 1.0 |
1 | 67974 | 8 | [1, 14] | 0.0 |
2 | 41311 | 1026 | [1, 7] | 0.0 |
3 | 5951 | 336 | [2, 4] | 1.0 |
4 | 16913 | 335 | [3, 8, 11, 4] | 1.0 |
%%writefile ./wdl2predict.py
"""Send the first ten validation rows to the ``movielens_hugectr`` model.

Reads ``/model/data/valid/part_0.parquet``, shifts each categorical id by the
cumulative size of the preceding embedding tables, flattens the ids into the
``CATCOLUMN`` input, builds the matching ``ROWINDEX`` offsets, and prints the
response returned by the Triton server at ``localhost:8000``.
"""
from tritonclient.utils import np_to_triton_dtype
import tritonclient.http as httpclient
import numpy as np
import pandas as pd
import sys

model_name = 'movielens_hugectr'
CATEGORICAL_COLUMNS = ["userId", "movieId", "genres"]
CONTINUOUS_COLUMNS = []
LABEL_COLUMNS = ['label']
emb_size_array = [162542, 29434, 20]
# Per-column id offset: the summed sizes of all preceding entries
# (0 for the first column).
shift = np.insert(np.cumsum(emb_size_array), 0, 0)[:-1]

df = pd.read_parquet("/model/data/valid/part_0.parquet")
# Work on an owned copy: assigning into the result of head() otherwise
# triggers pandas' SettingWithCopyWarning.
test_df = df.head(10).copy()

# Shift the two single-valued columns, then every id inside each genres list.
for column, offset in zip(CATEGORICAL_COLUMNS[:2], shift[:2]):
    test_df[column] = test_df[column] + offset
genre_offset = shift[2]
test_df["genres"] = test_df["genres"].apply(lambda ids: [e + genre_offset for e in ids])

# Flatten the ids row by row and record the running offset after every slot,
# so ROWINDEX always agrees with the number of keys actually placed in
# CATCOLUMN (a genres list may hold any number of ids).
keys = []
row_offsets = [0]
for user_id, movie_id, genre_ids in test_df[CATEGORICAL_COLUMNS].itertuples(index=False, name=None):
    for slot_values in ([user_id], [movie_id], genre_ids):
        keys.extend(slot_values)
        row_offsets.append(len(keys))

# NOTE(review): the recorded run of this script was rejected by the server
# because the CATCOLUMN length was not an integer multiple of the configured
# sample size. Rows whose genres list does not hold exactly two ids still
# produce such a request - confirm how variable-length genres must be encoded
# for this model configuration.
embedding_columns = np.array([keys], dtype='int64')
dense_features = np.array([[]], dtype='float32')  # no continuous columns
row_ptrs = np.array([row_offsets], dtype='int32')

inputs = []
for input_name, array in (
    ("DES", dense_features),
    ("CATCOLUMN", embedding_columns),
    ("ROWINDEX", row_ptrs),
):
    infer_input = httpclient.InferInput(input_name, array.shape, np_to_triton_dtype(array.dtype))
    infer_input.set_data_from_numpy(array)
    inputs.append(infer_input)
outputs = [httpclient.InferRequestedOutput("OUTPUT0")]

with httpclient.InferenceServerClient("localhost:8000") as client:
    response = client.infer(model_name, inputs, request_id=str(1), outputs=outputs)
    result = response.get_response()
    print(result)
    print("Prediction Result:")
    print(response.as_numpy("OUTPUT0"))
Overwriting ./wdl2predict.py
!python3 ./wdl2predict.py
/usr/local/lib/python3.8/dist-packages/pandas/core/indexing.py:1851: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
self._setitem_single_column(loc, val, pi)
/usr/local/lib/python3.8/dist-packages/pandas/core/indexing.py:1773: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
self._setitem_single_column(ilocs[0], value, pi)
Traceback (most recent call last):
File "./wdl2predict.py", line 50, in <module>
response = client.infer(model_name,
File "/usr/local/lib/python3.8/dist-packages/tritonclient/http/__init__.py", line 1256, in infer
_raise_if_error(response)
File "/usr/local/lib/python3.8/dist-packages/tritonclient/http/__init__.py", line 64, in _raise_if_error
raise error
tritonclient.utils.InferenceServerException: The CATCOLUMN input sample size in request is not match with configuration. The input sample size to be an integer multiple of the configuration.