# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================
# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
Serve Recommendations from the HugeCTR Model#
This notebook is created using the latest stable merlin-hugectr container.
Overview#
In this notebook, we show how to perform inference with our trained deep learning recommender model using Triton Inference Server. We deploy the NVTabular workflow and the HugeCTR model together as an ensemble: for each request, Triton Inference Server feeds the input data through the NVTabular workflow and then passes its output to the HugeCTR model.
Getting Started#
We need to write configuration files that point to the stored model weights and the model configuration.
Let us first place all of our model files in a directory that the scripts we generate below can access.
import os
import json
# path to preprocessed data
INPUT_DATA_DIR = os.environ.get(
"INPUT_DATA_DIR", os.path.expanduser("/workspace/nvt-examples/movielens/data/")
)
# path to saved model
MODEL_DIR = os.path.join(INPUT_DATA_DIR, "model/movielens_hugectr")
file_to_write = """
name: "movielens_hugectr"
backend: "hugectr"
max_batch_size: 64
input [
{
name: "DES"
data_type: TYPE_FP32
dims: [ -1 ]
},
{
name: "CATCOLUMN"
data_type: TYPE_INT64
dims: [ -1 ]
},
{
name: "ROWINDEX"
data_type: TYPE_INT32
dims: [ -1 ]
}
]
output [
{
name: "OUTPUT0"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: 1
kind : KIND_GPU
gpus:[0]
}
]
parameters [
{
key: "config"
value: { string_value: "$MODEL_DIR/1/movielens.json" }
},
{
key: "gpucache"
value: { string_value: "true" }
},
{
key: "hit_rate_threshold"
value: { string_value: "0.8" }
},
{
key: "gpucacheper"
value: { string_value: "0.5" }
},
{
key: "label_dim"
value: { string_value: "1" }
},
{
key: "slots"
value: { string_value: "3" }
},
{
key: "cat_feature_num"
value: { string_value: "4" }
},
{
key: "des_feature_num"
value: { string_value: "0" }
},
{
key: "max_nnz"
value: { string_value: "2" }
},
{
key: "embedding_vector_size"
value: { string_value: "16" }
},
{
key: "embeddingkey_long_type"
value: { string_value: "true" }
}
]
"""
with open(os.path.join(MODEL_DIR, "config.pbtxt"), "w", encoding="utf-8") as f:
    f.write(file_to_write.replace("$MODEL_DIR", MODEL_DIR))
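If you would like to verify what we just wrote, you can read the generated config.pbtxt back and print it:
# read the generated Triton model configuration back for a quick sanity check
with open(os.path.join(MODEL_DIR, "config.pbtxt"), encoding="utf-8") as f:
    print(f.read())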
config = {
    "supportlonglong": True,
    "models": [
        {
            "model": "movielens_hugectr",
            "sparse_files": [f"{MODEL_DIR}/0_sparse_1900.model"],
            "dense_file": f"{MODEL_DIR}/_dense_1900.model",
            "network_file": f"{MODEL_DIR}/1/movielens.json",
            "num_of_worker_buffer_in_pool": "1",
            "num_of_refresher_buffer_in_pool": "1",
            "cache_refresh_percentage_per_iteration": "0.2",
            "deployed_device_list": ["0"],
            "max_batch_size": "64",
            "default_value_for_each_table": ["0.0", "0.0"],
            "hit_rate_threshold": "0.9",
            "gpucacheper": "0.5",
            "maxnum_catfeature_query_per_table_per_sample": ["162542", "56632", "12"],
            "embedding_vecsize_per_table": ["16", "16", "16"],
            "gpucache": "true",
        }
    ],
}
with open(os.path.join(MODEL_DIR, "ps.json"), "w", encoding="utf-8") as f:
    json.dump(config, f)
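The parameter server configuration references the sparse and dense model files on disk. To confirm that they are where ps.json expects them, we can list the contents of the model directory (the exact file names, such as 0_sparse_1900.model, depend on the training run that produced them):
# walk the model directory and print its layout
for root, _, files in os.walk(MODEL_DIR):
    print(root)
    for name in files:
        print("    " + name)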
Let’s import required libraries.
import tritonclient.http as httpclient
import cudf
import numpy as np
Load Models on Triton Server#
In the running Docker container, you can start Triton Inference Server with the following command:
tritonserver --model-repository=<path_to_models> --backend-config=hugectr,ps=<path_to_models>/ps.json --model-control-mode=explicit
Because we pass the --model-control-mode=explicit flag, the model will not be loaded at startup; we will load it explicitly below.
Note: The model repository path is /root/nvt-examples/movielens/data/model, and the path to the parameter server JSON file is /root/nvt-examples/movielens/data/model/movielens_hugectr/ps.json. The models have not been loaded yet; we first initialize a Triton client and then ask the server to load the saved model.
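For reference, with the paths from the note above, the concrete command looks like this:
tritonserver --model-repository=/root/nvt-examples/movielens/data/model --backend-config=hugectr,ps=/root/nvt-examples/movielens/data/model/movielens_hugectr/ps.json --model-control-mode=explicit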
# disable warnings
import warnings

warnings.filterwarnings("ignore")
import tritonclient.http as tritonhttpclient

try:
    triton_client = tritonhttpclient.InferenceServerClient(url="localhost:8000", verbose=True)
    print("client created.")
except Exception as e:
    print("channel creation failed: " + str(e))
client created.
triton_client.is_server_live()
GET /v2/health/live, headers None
<HTTPSocketPoolResponse status=200 headers={'content-length': '0', 'content-type': 'text/plain'}>
True
triton_client.get_model_repository_index()
POST /v2/repository/index, headers None
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '30'}>
bytearray(b'[{"name":"movielens_hugectr"}]')
[{'name': 'movielens_hugectr'}]
Let’s load our model to Triton Server.
%%time
triton_client.load_model(model_name="movielens_hugectr")
POST /v2/repository/models/movielens_hugectr/load, headers None
{}
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '0'}>
Loaded model 'movielens_hugectr'
CPU times: user 3.99 ms, sys: 0 ns, total: 3.99 ms
Wall time: 1.04 s
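Before sending inference requests, we can optionally confirm that the model reports ready; is_model_ready is part of the same client API:
# check that the newly loaded model is ready to serve requests
triton_client.is_model_ready("movielens_hugectr")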
Let's send a request to the Triton Inference Server and print out the response. The HugeCTR backend takes its input in a CSR-like layout: DES carries the dense (continuous) features, CATCOLUMN carries the flattened categorical keys, and ROWINDEX carries the row offsets into CATCOLUMN. Since our example has no continuous columns, DES is empty and the only real inputs are the categorical columns.
file_to_write = f"""
from tritonclient.utils import np_to_triton_dtype
import tritonclient.http as httpclient
import numpy as np
import pandas as pd

model_name = 'movielens_hugectr'
CATEGORICAL_COLUMNS = ["userId", "movieId", "genres"]
CONTINUOUS_COLUMNS = []
LABEL_COLUMNS = ['label']

# cardinalities of the three embedding tables; each column's keys are
# shifted so that all tables share one global key space
emb_size_array = [162542, 29434, 20]
shift = np.insert(np.cumsum(emb_size_array), 0, 0)[:-1]

df = pd.read_parquet('{INPUT_DATA_DIR}/valid/part_0.parquet')
test_df = df.head(10).copy()

# build the CSR row pointers: every sample contributes one userId key,
# one movieId key, and two genres keys (max_nnz = 2), i.e. 3 slots over
# 10 samples = 30 offsets plus the leading 0
rp_lst = [0]
cur = 0
for i in range(1, 31):
    if i % 3 == 0:
        cur += 2
    else:
        cur += 1
    rp_lst.append(cur)

with httpclient.InferenceServerClient("localhost:8000") as client:
    # shift each categorical column into its table's key range
    test_df.iloc[:, :2] = test_df.iloc[:, :2] + shift[:2]
    test_df.iloc[:, 2] = test_df.iloc[:, 2].apply(lambda x: [e + shift[2] for e in x])
    # flatten the categorical keys of all samples into one int64 vector
    embedding_columns = np.array([list(np.hstack(np.hstack(test_df[CATEGORICAL_COLUMNS].values)))], dtype='int64')
    # no continuous features in this model, so DES stays empty
    dense_features = np.array([[]], dtype='float32')
    row_ptrs = np.array([rp_lst], dtype='int32')

    inputs = [
        httpclient.InferInput("DES", dense_features.shape, np_to_triton_dtype(dense_features.dtype)),
        httpclient.InferInput("CATCOLUMN", embedding_columns.shape, np_to_triton_dtype(embedding_columns.dtype)),
        httpclient.InferInput("ROWINDEX", row_ptrs.shape, np_to_triton_dtype(row_ptrs.dtype)),
    ]
    inputs[0].set_data_from_numpy(dense_features)
    inputs[1].set_data_from_numpy(embedding_columns)
    inputs[2].set_data_from_numpy(row_ptrs)
    outputs = [httpclient.InferRequestedOutput("OUTPUT0")]
    response = client.infer(model_name, inputs, request_id=str(1), outputs=outputs)

    result = response.get_response()
    print(result)
    print("Prediction Result:")
    print(response.as_numpy("OUTPUT0"))
"""
with open("wdl2predict.py", "w", encoding="utf-8") as f:
f.write(file_to_write)
!python3 ./wdl2predict.py
{'id': '1', 'model_name': 'movielens_hugectr', 'model_version': '1', 'parameters': {'NumSample': 10, 'DeviceID': 0}, 'outputs': [{'name': 'OUTPUT0', 'datatype': 'FP32', 'shape': [10], 'parameters': {'binary_data_size': 40}}]}
Prediction Result:
[0.5346206 0.49736455 0.2987379 0.6282493 0.7548654 0.59079504
0.55132014 0.90419775 0.47409508 0.5124942 ]
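When we are done, we can unload the model from the same client; this works because the server runs with explicit model control:
# unload the model to free GPU resources on the server
triton_client.unload_model(model_name="movielens_hugectr")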