# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

Serving Ranking Models With Merlin Systems

This notebook is created using the latest stable merlin-tensorflow container. This Jupyter notebook example demonstrates how to deploy a ranking model to Triton Inference Server (TIS) and generate prediction results for a given query. As a prerequisite, the ranking model must be trained and saved with Merlin Models. Please read the README for the instructions.

Overview

NVIDIA Merlin is an open source framework that accelerates and scales end-to-end recommender system pipelines. The Merlin framework is broken up into several subcomponents, including Merlin-Core, Merlin-Models, NVTabular, and Merlin-Systems. Merlin Systems is the focus of this example.

The purpose of the Merlin Systems library is to make it easy for Merlin users to quickly deploy their recommender systems from development to Triton Inference Server. We extended the same user-friendly API users are accustomed to in NVTabular and leveraged it to accommodate deploying recommender system components to TIS.

Before continuing with this notebook, ensure that you have a working NVTabular workflow and a trained model stored in an accessible location. Merlin Systems takes the data preprocessing workflow defined in NVTabular and loads it into Triton Inference Server as a model. It then does the same for the trained model. The rest of this notebook takes a closer look at how Merlin Systems simplifies deploying to TIS.

Learning objectives

In this notebook, we learn how to deploy an NVTabular Workflow and a trained TensorFlow model from Merlin Models to Triton.

  • Load NVTabular Workflow

  • Load Pre-trained Merlin Models model

  • Create Ensemble Graph

  • Export Ensemble Graph

  • Run Tritonserver

  • Send Request to Tritonserver

Dataset

We use synthetic train and test datasets, generated to mimic the real Ali-CCP: Alibaba Click and Conversion Prediction dataset, to build our recommender system ranking models. To see how the data is transformed with NVTabular and how a DLRM model is trained with Merlin Models, check out the 04-Exporting-ranking-models.ipynb example notebook, which is a prerequisite for this notebook.

It is important to note that the steps taken in this notebook are general and can be applied to any workflow and model.

Tools

  • NVTabular

  • Merlin Models

  • Merlin Systems

  • Triton Inference Server

Install Required Libraries

Install TensorFlow so we can read the saved model from disk.

#!pip install tensorflow-gpu

Load an NVTabular Workflow

First, we load the nvtabular.Workflow that we created in the prerequisite example notebook.

import os
os.environ["TF_GPU_ALLOCATOR"]="cuda_malloc_async"
from nvtabular.workflow import Workflow

input_path = os.environ.get("INPUT_FOLDER", "/workspace/data/")

workflow_stored_path = os.path.join(input_path, "workflow")

workflow = Workflow.load(workflow_stored_path)

After we load the workflow, we remove the label columns from its inputs. This removes all columns with the TARGET tag from the workflow. We do this because an inference pipeline should require only the features needed to predict, not the labels used to train.

from merlin.schema.tags import Tags

label_columns = workflow.output_schema.select_by_tag(Tags.TARGET).column_names
workflow.remove_inputs(label_columns)
<nvtabular.workflow.workflow.Workflow at 0x7ff5b42c9850>
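To make the effect of this step concrete, here is a minimal plain-Python sketch of the same idea. It does not use the Merlin API: the `toy_schema` dictionary, the `"target"` tag string, and the `click` column are hypothetical stand-ins for a workflow schema and `Tags.TARGET`.

```python
# Toy stand-in for a schema: column name -> set of tags (hypothetical data).
toy_schema = {
    "user_id": {"user", "categorical"},
    "item_id": {"item", "categorical"},
    "click": {"target"},
}


def select_by_tag(schema, tag):
    """Return the names of the columns that carry the given tag."""
    return [name for name, tags in schema.items() if tag in tags]


def remove_columns(schema, names):
    """Return a copy of the schema without the named columns."""
    return {name: tags for name, tags in schema.items() if name not in names}


label_cols = select_by_tag(toy_schema, "target")
serving_schema = remove_columns(toy_schema, label_cols)

print(label_cols)            # ['click']
print(list(serving_schema))  # ['user_id', 'item_id']
```

Only the feature columns remain, which is what a serving request is expected to provide.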

Load the TensorFlow Model

After loading the workflow, we load the model. This model was trained with the output of the workflow from the Exporting Ranking Models example from Merlin Models.

First, we need to import the Merlin Models library. Loading a TensorFlow model that is based on custom subclasses requires the subclass definitions. Otherwise, TensorFlow cannot load the model correctly.

import merlin.models.tf as mm
2022-07-12 17:18:23.722737: I tensorflow/core/platform/cpu_feature_guard.cc:152] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-07-12 17:18:25.872447: I tensorflow/core/common_runtime/gpu/gpu_process_state.cc:214] Using CUDA malloc Async allocator for GPU: 0
2022-07-12 17:18:25.872791: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 16255 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB-LS, pci bus id: 0000:89:00.0, compute capability: 7.0
import tensorflow as tf
tf_model_path = os.path.join(input_path, "dlrm")

model = tf.keras.models.load_model(tf_model_path)
2022-07-12 17:18:27.313679: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.

Create the Ensemble Graph

After we have both the model and the workflow loaded, we can create the ensemble graph. You define the graph yourself; its purpose is to describe the path that data takes through your full system. In this example we only serve a workflow with a model, but you can add other components that help you meet your business logic requirements.

Because this example has two components—a model and a workflow—we require two operators. These operators, also known as inference operators, are meant to abstract away all the “hard parts” of loading a specific component, such as a workflow or model, into Triton Inference Server.

The following code block shows how to use two inference operators:

TransformWorkflow
This operator ensures that the workflow is correctly saved and packaged with the required config so the server will know how to load it.
PredictTensorflow
This operator does the same for the model that we loaded earlier.

Let’s give it a try.

from merlin.systems.dag.ops.workflow import TransformWorkflow
from merlin.systems.dag.ops.tensorflow import PredictTensorflow

serving_operators = workflow.input_schema.column_names >> TransformWorkflow(workflow) >> PredictTensorflow(model)
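The `>>` chaining used above can look unusual at first. The following toy sketch shows one way such chaining can be built in plain Python with `__rshift__` and `__rrshift__`. It is an illustration of the general idea only, not Merlin's implementation; the `Node` and `Op` classes and the two lambda "operators" are hypothetical.

```python
class Node:
    """Toy graph node: an ordered list of steps applied to selected columns."""

    def __init__(self, columns, steps=()):
        self.columns = list(columns)
        self.steps = list(steps)

    def __rshift__(self, op):
        # node >> op: return a new node with `op` appended to the steps.
        return Node(self.columns, self.steps + [op])

    def run(self, row):
        # Keep only the selected columns, then apply each step in order.
        data = {name: row[name] for name in self.columns}
        for op in self.steps:
            data = op(data)
        return data


class Op:
    """Toy operator wrapping a function from dict to dict."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, data):
        return self.fn(data)

    def __rrshift__(self, columns):
        # list_of_columns >> op: Python falls back to this method because
        # the built-in list type does not implement >>.
        return Node(columns, [self])


# Hypothetical stand-ins for a preprocessing workflow and a model.
preprocess = Op(lambda d: {k: v + 1 for k, v in d.items()})
predict = Op(lambda d: {"score": sum(d.values()) / 10})

graph = ["user_id", "item_id"] >> preprocess >> predict
result = graph.run({"user_id": 1, "item_id": 2, "click": 0})
print(result)  # {'score': 0.5}
```

The column list on the left selects the inputs, and each `>>` appends one more step, so the expression reads in the same order that data flows.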

Export Graph as Ensemble

The last step is to create the ensemble artifacts that Triton Inference Server can consume. To make these artifacts, we import the Ensemble class. The class is responsible for interpreting the graph and exporting the correct files for the server.

The following cells first display the workflow’s output schema and then create the Ensemble. A Schema is a collection of ColumnSchema objects, one for each column; the workflow’s input schema describes the inputs that the ensemble expects.

When you create an Ensemble object, you supply the graph and a schema representing the starting input of the graph. The inputs to the ensemble graph are the inputs to the first operator of your graph.

After you have created the Ensemble you export the graph, supplying an export path for the Ensemble.export function.

This returns an ensemble config which represents the entire inference pipeline and a list of node-specific configs.

Let’s take a look below.

workflow.output_schema
name tags dtype is_list is_ragged properties.num_buckets properties.freq_threshold properties.max_size properties.start_index properties.cat_path properties.domain.min properties.domain.max properties.domain.name properties.embedding_sizes.cardinality properties.embedding_sizes.dimension
0 user_id (Tags.USER, Tags.USER_ID, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.user_id.parquet 0 755 user_id 755 65
1 item_id (Tags.ITEM_ID, Tags.ITEM, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.item_id.parquet 0 772 item_id 772 66
2 item_category (Tags.ITEM, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.item_category.parquet 0 772 item_category 772 66
3 item_shop (Tags.ITEM, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.item_shop.parquet 0 772 item_shop 772 66
4 item_brand (Tags.ITEM, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.item_brand.parquet 0 772 item_brand 772 66
5 user_shops (Tags.USER, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.user_shops.parquet 0 755 user_shops 755 65
6 user_profile (Tags.USER, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.user_profile.parquet 0 67 user_profile 67 17
7 user_group (Tags.USER, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.user_group.parquet 0 13 user_group 13 16
8 user_gender (Tags.USER, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.user_gender.parquet 0 3 user_gender 3 16
9 user_age (Tags.USER, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.user_age.parquet 0 8 user_age 8 16
10 user_consumption_2 (Tags.USER, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.user_consumption_2.parquet 0 4 user_consumption_2 4 16
11 user_is_occupied (Tags.USER, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.user_is_occupied.parquet 0 3 user_is_occupied 3 16
12 user_geography (Tags.USER, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.user_geography.parquet 0 5 user_geography 5 16
13 user_intentions (Tags.USER, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.user_intentions.parquet 0 755 user_intentions 755 65
14 user_brands (Tags.USER, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.user_brands.parquet 0 755 user_brands 755 65
15 user_categories (Tags.USER, Tags.CATEGORICAL) int64 False False None 0 0 0 .//categories/unique.user_categories.parquet 0 755 user_categories 755 65
from merlin.systems.dag.ensemble import Ensemble
import numpy as np

ensemble = Ensemble(serving_operators, workflow.input_schema)

export_path = os.path.join(input_path, "ensemble")

ens_conf, node_confs = ensemble.export(export_path)

Display the path to the directory with the ensemble.

export_path
'/models/examples/ensemble'

Verification of Ensemble Artifacts

After we export the ensemble, we can check the export path for the graph’s artifacts. Each directory name consists of an ordering number followed by an operator identifier, such as 0_transformworkflow, 1_predicttensorflow, and so on.

Inside each of those directories, the export method writes a config.pbtxt file and a directory with a number. The number indicates the version and begins at 1. The artifacts for each operator are found inside the version folder. These artifacts vary depending on the operator in use.

Install the seedir python package so we can view some of the directory contents.

# install seedir
!pip install seedir
Requirement already satisfied: seedir in /usr/local/lib/python3.8/dist-packages (0.3.1)
Requirement already satisfied: natsort in /usr/local/lib/python3.8/dist-packages (from seedir) (8.1.0)
Requirement already satisfied: emoji in /usr/local/lib/python3.8/dist-packages (from seedir) (1.7.0)
import seedir as sd

sd.seedir(export_path, style='lines', itemlimit=10, depthlimit=3, exclude_folders='.ipynb_checkpoints', sort=True)
ensemble/
├─0_transformworkflow/
│ ├─1/
│ │ ├─__pycache__/
│ │ ├─model.py
│ │ └─workflow/
│ └─config.pbtxt
├─1_predicttensorflow/
│ ├─1/
│ │ └─model.savedmodel/
│ └─config.pbtxt
└─ensemble_model/
  ├─1/
  └─config.pbtxt
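The layout described above (a config.pbtxt file plus a numbered version directory per model) can also be checked programmatically. The following is a minimal sketch using only the standard library; `check_model_repository` is a hypothetical helper, not part of Merlin Systems or Triton, and the demo runs against a throwaway directory rather than the real export path.

```python
import tempfile
from pathlib import Path


def check_model_repository(repo_path):
    """Map each model directory name to a list of layout problems (empty if none)."""
    report = {}
    model_dirs = sorted(
        p for p in Path(repo_path).iterdir()
        if p.is_dir() and not p.name.startswith(".")  # skip .ipynb_checkpoints etc.
    )
    for model_dir in model_dirs:
        issues = []
        if not (model_dir / "config.pbtxt").is_file():
            issues.append("missing config.pbtxt")
        versions = [d for d in model_dir.iterdir() if d.is_dir() and d.name.isdigit()]
        if not versions:
            issues.append("missing numeric version directory")
        report[model_dir.name] = issues
    return report


# Demo on a temporary directory that mimics part of the layout shown above.
with tempfile.TemporaryDirectory() as tmp:
    good = Path(tmp) / "0_transformworkflow"
    (good / "1").mkdir(parents=True)
    (good / "config.pbtxt").write_text("")
    (Path(tmp) / "incomplete_model").mkdir()  # hypothetical broken entry
    report = check_model_repository(tmp)

print(report)
```

Calling `check_model_repository(export_path)` on the exported ensemble should report an empty list of problems for each of the three model directories.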

Starting Triton Inference Server

After we export the ensemble, we are ready to start the Triton Inference Server. The server is installed in all the Merlin inference containers. If you are not using one of our containers, then ensure it is installed in your environment. For more information, see the Triton Inference Server documentation.

You can start the server by running the following command:

tritonserver --model-repository=/workspace/data/ensemble --backend-config=tensorflow,version=2

For the --model-repository argument, specify the same value as the export_path that you specified previously in the ensemble.export method.
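One way to keep the two paths in sync is to build the command from the same variable that was passed to the export step. The sketch below assumes that approach; `build_tritonserver_command` is a hypothetical helper, and the flags are the ones shown in the command above.

```python
import os
import shlex


def build_tritonserver_command(export_path):
    """Return the tritonserver argument list for the given model repository."""
    return [
        "tritonserver",
        f"--model-repository={export_path}",
        "--backend-config=tensorflow,version=2",
    ]


# Same defaults as earlier in this notebook.
input_path = os.environ.get("INPUT_FOLDER", "/workspace/data/")
export_path = os.path.join(input_path, "ensemble")

command = build_tritonserver_command(export_path)
print(shlex.join(command))
```

The resulting list can be passed to `subprocess.Popen` if you prefer to launch the server from Python rather than from a terminal.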

After you run the tritonserver command, wait until your terminal shows messages like the following example:

I0414 18:29:50.741833 4067 grpc_server.cc:4421] Started GRPCInferenceService at 0.0.0.0:8001
I0414 18:29:50.742197 4067 http_server.cc:3113] Started HTTPService at 0.0.0.0:8000
I0414 18:29:50.783470 4067 http_server.cc:178] Started Metrics Service at 0.0.0.0:8002
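Instead of watching the terminal, a script can poll the server until it responds. The following sketch assumes the HTTP service is reachable at localhost:8000, as in the log lines above, and uses the /v2/health/live endpoint. The function names, timeout, and polling interval are illustrative choices.

```python
import time
import urllib.error
import urllib.request


def http_live_probe(base_url="http://localhost:8000"):
    """Return True if GET /v2/health/live answers with HTTP 200."""
    try:
        with urllib.request.urlopen(f"{base_url}/v2/health/live", timeout=2) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        # Connection refused, timeout, or HTTP error: the server is not live yet.
        return False


def wait_until(probe, timeout_s=60.0, interval_s=1.0, sleep=time.sleep):
    """Call probe() until it returns True or timeout_s seconds have elapsed."""
    deadline = time.monotonic() + timeout_s
    while True:
        if probe():
            return True
        if time.monotonic() >= deadline:
            return False
        sleep(interval_s)
```

For example, `wait_until(http_live_probe, timeout_s=120)` returns True as soon as the server reports that it is live, or False if that does not happen within two minutes.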

Retrieving Recommendations from Triton Inference Server

Now that our server is running, we can send requests to it. This request is composed of values that correspond to the request schema that was created when we exported the ensemble graph.

In the code below, we create a request and send it to Triton. We then analyze the response to show the full experience.

First we need to ensure that we have a client connected to the server that we started. To do this, we use the Triton HTTP client library.

import tritonclient.http as client

# Create a triton client
try:
    triton_client = client.InferenceServerClient(url="localhost:8000", verbose=True)
    print("client created.")
except Exception as e:
    print("channel creation failed: " + str(e))
client created.

After we create the client and verify that it is connected to the server instance, we can communicate with the server and ensure that all the models are loaded correctly.

# ensure triton is in a good state
triton_client.is_server_live()
triton_client.get_model_repository_index()
GET /v2/health/live, headers None
<HTTPSocketPoolResponse status=200 headers={'content-length': '0', 'content-type': 'text/plain'}>
POST /v2/repository/index, headers None

<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '179'}>
bytearray(b'[{"name":"0_transformworkflow","version":"1","state":"READY"},{"name":"1_predicttensorflow","version":"1","state":"READY"},{"name":"ensemble_model","version":"1","state":"READY"}]')
[{'name': '0_transformworkflow', 'version': '1', 'state': 'READY'},
 {'name': '1_predicttensorflow', 'version': '1', 'state': 'READY'},
 {'name': 'ensemble_model', 'version': '1', 'state': 'READY'}]
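Rather than reading the index by eye, the same check can be automated. The sketch below works on the list of dictionaries shown above; `models_not_ready` is a hypothetical helper, and the expected model names are taken from this example's export.

```python
def models_not_ready(index):
    """Return the names of models whose state is anything other than READY."""
    return [m["name"] for m in index if m.get("state") != "READY"]


# The repository index returned by the server in this example.
index = [
    {"name": "0_transformworkflow", "version": "1", "state": "READY"},
    {"name": "1_predicttensorflow", "version": "1", "state": "READY"},
    {"name": "ensemble_model", "version": "1", "state": "READY"},
]

expected = {"0_transformworkflow", "1_predicttensorflow", "ensemble_model"}
missing = expected - {m["name"] for m in index}
not_ready = models_not_ready(index)

print("missing:", sorted(missing))
print("not ready:", not_ready)
```

Both lists are empty here, which means every exported model is present and loaded.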

After verifying that the server loaded the models correctly, we take some of the original validation data and send it as an inference request to the server. The valid folder is generated when the 04-Exporting-ranking-models.ipynb notebook is run.

The df_lib object is cudf if a GPU is available and pandas otherwise.

from merlin.core.dispatch import get_lib

df_lib = get_lib()

original_data_path = os.environ.get("INPUT_FOLDER", "/workspace/data/")

# read in data for request
batch = df_lib.read_parquet(
    os.path.join(original_data_path,"valid", "part.0.parquet"), num_rows=3, columns=workflow.input_schema.column_names
)
batch
__null_dask_index__  user_id  item_id  item_category  item_shop  item_brand  user_shops  user_profile  user_group  user_gender  user_age  user_consumption_2  user_is_occupied  user_geography  user_intentions  user_brands  user_categories
700000                    55        5             19       1305         450        2981             3           1            1         1                   1                 1               1              863         1482              156
700001                    53        8             33       2283         787        2871             3           1            1         1                   1                 1               1              831         1427              150
700002                    10        6             24       1631         562         497             1           1            1         1                   1                 1               1              144          247               26

After we isolate our batch, we convert the dataframe representation into inputs for Triton. We also declare the outputs that we expect to receive from the model.

from merlin.systems.triton import convert_df_to_triton_input
import tritonclient.grpc as grpcclient
# create inputs and outputs

inputs = convert_df_to_triton_input(workflow.input_schema, batch, grpcclient.InferInput)

output_cols = ensemble.graph.output_schema.column_names

outputs = [
    grpcclient.InferRequestedOutput(col)
    for col in output_cols
]
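To give a feel for what such a request carries, the sketch below builds the JSON body used by the HTTP/REST form of the KServe v2 inference protocol, which Triton implements. It is not what `convert_df_to_triton_input` does internally. The `[n_rows, 1]` shape and the `INT64` datatype are assumptions based on the int64 columns and the (3, 1) response shape in this example, and only two of the sixteen input columns are shown.

```python
import json


def build_infer_request(columns, output_names, request_id="1"):
    """Build a v2 inference request body from a dict of column name -> values."""
    n_rows = len(next(iter(columns.values())))
    inputs = [
        {
            "name": name,
            "shape": [n_rows, 1],   # assumption: one value per row
            "datatype": "INT64",    # assumption: all columns are int64
            "data": list(values),
        }
        for name, values in columns.items()
    ]
    outputs = [{"name": name} for name in output_names]
    return {"id": request_id, "inputs": inputs, "outputs": outputs}


# Two columns from the batch above; a real request needs every input column.
columns = {"user_id": [55, 53, 10], "item_id": [5, 8, 6]}
body = build_infer_request(columns, ["click/binary_classification_task"])
print(json.dumps(body, indent=2))
```

Each input is described by its name, shape, datatype, and data, and the outputs list names the tensors that the client wants back.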

Now that our inputs and outputs are created, we can open a gRPC client connection to the server and send the inference request to the ensemble model.

# send request to tritonserver
with grpcclient.InferenceServerClient("localhost:8001") as grpc_client:
    response = grpc_client.infer("ensemble_model", inputs, request_id="1", outputs=outputs)

When the server completes the inference request, it returns a response. This response is parsed to get the desired predictions.

# access individual response columns to get values back.
for col in ensemble.graph.output_schema.column_names:
    print(col, response.as_numpy(col), response.as_numpy(col).shape)
click/binary_classification_task [[0.49801633]
 [0.49624982]
 [0.49931964]] (3, 1)
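A common next step is to order candidates by their predicted score. The sketch below does this with the item_id values from the request batch and the scores printed above. `rank_by_score` is a hypothetical helper. Because the three rows in this example belong to different users, the ordering is only illustrative; in practice you would typically score several candidate items for one user.

```python
def rank_by_score(item_ids, predictions):
    """Pair each item with its score and sort from highest to lowest score."""
    # Each prediction is a one-element row, mirroring the (3, 1) response shape.
    scores = [row[0] for row in predictions]
    return sorted(zip(item_ids, scores), key=lambda pair: pair[1], reverse=True)


# item_id values from the request batch and the scores printed above.
item_ids = [5, 8, 6]
predictions = [[0.49801633], [0.49624982], [0.49931964]]

ranked = rank_by_score(item_ids, predictions)
for item_id, score in ranked:
    print(item_id, score)
```

With these values, item 6 comes first, followed by item 5 and then item 8.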

Summary

This sample notebook started with an exported DLRM model and workflow. We saw how to create an ensemble graph, verify the ensemble artifacts in the file system, and then put the ensemble into production with Triton Inference Server. Finally, we sent a simple inference request to the server and printed the response.