# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================

Getting Started MovieLens: Serving a HugeCTR Model

In this notebook, we show how to do inference with our trained deep learning recommender model using Triton Inference Server. We deploy the NVTabular workflow and the HugeCTR model together as an ensemble: for each request, Triton Inference Server feeds the input data through the NVTabular workflow and passes its output to the HugeCTR model.

Getting Started

We need to write the configuration files that point Triton to the stored model weights and the network configuration.
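For reference, the files written in this notebook and the trained model artifacts sit in the Triton model repository as follows (these paths are the ones used throughout this example):

/model/
|-- ps.json
|-- data/                         # parquet data used for test requests
`-- movielens_hugectr/
    |-- config.pbtxt
    |-- 0_sparse_1900.model/      # sparse (embedding) weights
    |-- _dense_1900.model         # dense weights
    `-- 1/
        `-- movielens.json        # network configuration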

%%writefile /model/movielens_hugectr/config.pbtxt
name: "movielens_hugectr"
backend: "hugectr"
max_batch_size: 64
input [
  {
    name: "DES"
    data_type: TYPE_FP32
    dims: [ -1 ]
  },
  {
    name: "CATCOLUMN"
    data_type: TYPE_INT64
    dims: [ -1 ]
  },
  {
    name: "ROWINDEX"
    data_type: TYPE_INT32
    dims: [ -1 ]
  }
]
output [
  {
    name: "OUTPUT0"
    data_type: TYPE_FP32
    dims: [ -1 ]
  }
]
instance_group [
  {
    count: 1
    kind: KIND_GPU
    gpus: [0]
  }
]

parameters [
  {
  key: "config"
  value: { string_value: "/model/movielens_hugectr/1/movielens.json" }
  },
  {
  key: "gpucache"
  value: { string_value: "true" }
  },
  {
  key: "hit_rate_threshold"
  value: { string_value: "0.8" }
  },
  {
  key: "gpucacheper"
  value: { string_value: "0.5" }
  },
  {
  key: "label_dim"
  value: { string_value: "1" }
  },
  {
  key: "slots"
  value: { string_value: "3" }
  },
  {
  key: "cat_feature_num"
  value: { string_value: "4" }
  },
  {
  key: "des_feature_num"
  value: { string_value: "0" }
  },
  {
  key: "max_nnz"
  value: { string_value: "2" }
  },
  {
  key: "embedding_vector_size"
  value: { string_value: "16" }
  },
  {
  key: "embeddingkey_long_type"
  value: { string_value: "true" }
  }
]
Overwriting /model/movielens_hugectr/config.pbtxt
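As an informal cross-check of these values (assuming, as this example suggests, that cat_feature_num is the maximum number of embedding keys per sample: one each for userId and movieId, plus up to max_nnz keys for the multi-hot genres column):

slots = 3                          # three categorical features: userId, movieId, genres
des_feature_num = 0                # no continuous (dense) features in this model
max_nnz = 2                        # genres is multi-hot with up to two keys
cat_feature_num = 1 + 1 + max_nnz  # assumed: max keys per sample
assert cat_feature_num == 4        # matches the value in config.pbtxt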
%%writefile /model/ps.json
{
    "supportlonglong":true,
    "models":[
        {
            "model":"movielens_hugectr",
            "sparse_files":["/model/movielens_hugectr/0_sparse_1900.model"],
            "dense_file":"/model/movielens_hugectr/_dense_1900.model",
            "network_file":"/model/movielens_hugectr/1/movielens.json",
            "num_of_worker_buffer_in_pool": "1",
            "num_of_refresher_buffer_in_pool": "1",
            "cache_refresh_percentage_per_iteration": "0.2",
            "deployed_device_list":["0"],
            "max_batch_size":"64",
            "default_value_for_each_table":["0.0","0.0"],
            "hit_rate_threshold":"0.9",
            "gpucacheper":"0.5",
            "gpucache":"true"
        }
    ]  
}
Overwriting /model/ps.json
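As a quick sanity check before launching the server, you can verify that the file we just wrote parses as valid JSON and references the expected model:

import json

with open("/model/ps.json") as f:
    ps_conf = json.load(f)
print(ps_conf["models"][0]["model"])  # movielens_hugectr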

Let’s import required libraries.

import tritonclient.http as httpclient
import cudf
import numpy as np

Load Models on Triton Inference Server

At this stage, you should launch the Triton Inference Server Docker container with the following command:

docker run -it --gpus=all -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ${PWD}:/model nvcr.io/nvidia/merlin/merlin-hugectr:latest

For production use, refer to the Merlin containers from the NVIDIA GPU Cloud (NGC) catalog and specify a tag rather than latest.

After you start the container, start Triton Inference Server with the following command:

tritonserver --model-repository=<path_to_models> --backend-config=hugectr,ps=<path_to_models>/ps.json --model-control-mode=explicit

Note: The model repository path is /model/. Because we started the server with --model-control-mode=explicit, the models haven’t been loaded yet; we will request that Triton load the saved model through a client, which we initialize next. The network JSON file referenced by the configuration is /model/movielens_hugectr/1/movielens.json.

# disable warnings
import warnings

warnings.filterwarnings("ignore")
import tritonhttpclient

try:
    triton_client = tritonhttpclient.InferenceServerClient(url="localhost:8000", verbose=True)
    print("client created.")
except Exception as e:
    print("channel creation failed: " + str(e))
client created.
/usr/local/lib/python3.8/dist-packages/tritonhttpclient/__init__.py:31: DeprecationWarning: The package `tritonhttpclient` is deprecated and will be removed in a future version. Please use instead `tritonclient.http`
  warnings.warn(
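The warning above names the replacement module. An equivalent client can be created with the maintained tritonclient.http package (a drop-in sketch; the cells below keep using the deprecated import so that their outputs match):

import tritonclient.http as tritonhttp  # maintained replacement for tritonhttpclient

triton_client = tritonhttp.InferenceServerClient(url="localhost:8000", verbose=True)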
triton_client.is_server_live()
GET /v2/health/live, headers None
<HTTPSocketPoolResponse status=200 headers={'content-length': '0', 'content-type': 'text/plain'}>
True
triton_client.get_model_repository_index()
POST /v2/repository/index, headers None

<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '46'}>
bytearray(b'[{"name":"data"},{"name":"movielens_hugectr"}]')
[{'name': 'data'}, {'name': 'movielens_hugectr'}]

Let’s load our model on the Triton Server.

%%time

triton_client.load_model(model_name="movielens_hugectr")
POST /v2/repository/models/movielens_hugectr/load, headers None

<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '0'}>
Loaded model 'movielens_hugectr'
CPU times: user 2.6 ms, sys: 2.57 ms, total: 5.17 ms
Wall time: 3.62 s
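Since the server runs in explicit model-control mode, it can be worth confirming that the model finished loading before sending requests; the client exposes a readiness check, which should return True once loading has completed:

triton_client.is_model_ready("movielens_hugectr")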

Let’s send a request to the Inference Server and print out the response. Since our example does not use continuous (dense) columns, the only inputs below are the categorical columns.

import pandas as pd
df = pd.read_parquet("/model/data/valid/part_0.parquet")
df.head()
   userId  movieId         genres  rating
0   32187      520         [2, 6]     1.0
1   67974        8        [1, 14]     0.0
2   41311     1026         [1, 7]     0.0
3    5951      336         [2, 4]     1.0
4   16913      335  [3, 8, 11, 4]     1.0
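The HugeCTR backend takes the batch in a CSR-like form: CATCOLUMN is the flattened list of (shifted) embedding keys, and ROWINDEX holds one running offset per slot per sample, starting with a leading 0. Here is a tiny illustration with made-up key values, two samples, and the three slots used in this example:

# Hypothetical keys for two samples; real requests also shift each key into
# its embedding table's global index range first.
samples = [
    {"userId": [10], "movieId": [200], "genres": [5, 7]},
    {"userId": [11], "movieId": [201], "genres": [6]},
]
catcolumn, rowindex = [], [0]
for s in samples:
    for slot in ("userId", "movieId", "genres"):
        catcolumn.extend(s[slot])        # flatten the keys slot by slot
        rowindex.append(len(catcolumn))  # running offset after each slot
print(catcolumn)  # [10, 200, 5, 7, 11, 201, 6]
print(rowindex)   # [0, 1, 2, 4, 5, 6, 7]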
%%writefile ./wdl2predict.py
from tritonclient.utils import *
import tritonclient.http as httpclient
import numpy as np
import pandas as pd
import sys

model_name = 'movielens_hugectr'
CATEGORICAL_COLUMNS = ["userId", "movieId", "genres"]
CONTINUOUS_COLUMNS = []
LABEL_COLUMNS = ['label']
# Cardinalities of the three embedding tables (userId, movieId, genres).
emb_size_array = [162542, 29434, 20]
# Offsets that shift each table's keys into a disjoint global index range.
shift = np.insert(np.cumsum(emb_size_array), 0, 0)[:-1]
df = pd.read_parquet("/model/data/valid/part_0.parquet")
test_df = df.head(10)

# Build the CSR row pointers (ROWINDEX) for 10 samples x 3 slots.
# userId and movieId contribute one key per slot; the third slot (genres)
# is hardcoded here to two keys per sample, i.e. the configured max_nnz.
rp_lst = [0]
cur = 0
for i in range(1, 31):
    cur += 2 if i % 3 == 0 else 1
    rp_lst.append(cur)

with httpclient.InferenceServerClient("localhost:8000") as client:
    # Shift each categorical column into its embedding table's key range.
    test_df.iloc[:, :2] = test_df.iloc[:, :2] + shift[:2]
    test_df.iloc[:, 2] = test_df.iloc[:, 2].apply(lambda x: [e + shift[2] for e in x])
    # Flatten all categorical keys of the batch into a single CATCOLUMN array.
    embedding_columns = np.array([list(np.hstack(np.hstack(test_df[CATEGORICAL_COLUMNS].values)))], dtype='int64')
    # No continuous features in this model, so DES is empty.
    dense_features = np.array([[]], dtype='float32')
    row_ptrs = np.array([rp_lst], dtype='int32')

    inputs = [httpclient.InferInput("DES", dense_features.shape, np_to_triton_dtype(dense_features.dtype)),
              httpclient.InferInput("CATCOLUMN", embedding_columns.shape, np_to_triton_dtype(embedding_columns.dtype)),
              httpclient.InferInput("ROWINDEX", row_ptrs.shape, np_to_triton_dtype(row_ptrs.dtype))]

    inputs[0].set_data_from_numpy(dense_features)
    inputs[1].set_data_from_numpy(embedding_columns)
    inputs[2].set_data_from_numpy(row_ptrs)
    outputs = [httpclient.InferRequestedOutput("OUTPUT0")]

    response = client.infer(model_name, inputs, request_id=str(1), outputs=outputs)

    result = response.get_response()
    print(result)
    print("Prediction Result:")
    print(response.as_numpy("OUTPUT0"))
Overwriting ./wdl2predict.py
!python3 ./wdl2predict.py
/usr/local/lib/python3.8/dist-packages/pandas/core/indexing.py:1851: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self._setitem_single_column(loc, val, pi)
/usr/local/lib/python3.8/dist-packages/pandas/core/indexing.py:1773: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self._setitem_single_column(ilocs[0], value, pi)
Traceback (most recent call last):
  File "./wdl2predict.py", line 50, in <module>
    response = client.infer(model_name,
  File "/usr/local/lib/python3.8/dist-packages/tritonclient/http/__init__.py", line 1256, in infer
    _raise_if_error(response)
  File "/usr/local/lib/python3.8/dist-packages/tritonclient/http/__init__.py", line 64, in _raise_if_error
    raise error
tritonclient.utils.InferenceServerException: The CATCOLUMN input sample size in request is not match with configuration. The input sample size to be an integer multiple of the configuration.
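The request fails because the hardcoded row pointers assume every genres list holds exactly two keys (the configured max_nnz), while the validation data has variable-length lists; row 4 above, for instance, carries four genre IDs, so the flattened CATCOLUMN no longer lines up with the sizes the configuration implies. One way to keep CATCOLUMN and ROWINDEX mutually consistent is to derive the pointers from the data itself inside the script. This is a sketch, not a guaranteed fix: it assumes the backend accepts variable nnz per sample, and otherwise the genres lists would also need truncating to max_nnz:

# Derive row pointers from the actual genres lengths instead of hardcoding them.
rp_lst = [0]
cur = 0
for genres in test_df["genres"]:
    for n in (1, 1, len(genres)):  # keys per slot: userId, movieId, genres
        cur += n
        rp_lst.append(cur)
row_ptrs = np.array([rp_lst], dtype="int32")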