# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ================================
# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
Exporting Ranking Models#
This notebook is created using the latest stable merlin-tensorflow container.
In this example notebook, we demonstrate how to export (save) an NVTabular workflow and a ranking model for model deployment with the Merlin Systems library.
Learning Objectives:
Export NVTabular workflow for model deployment
Export TensorFlow DLRM model for model deployment
Load saved NVTabular Workflow
Load saved trained Merlin Models model
Create Ensemble Graph
Export Ensemble Graph
Deploy model on Triton Inference Server
We will follow the steps below:
Prepare the data with NVTabular and export NVTabular workflow
Train a DLRM model with Merlin Models and export the trained model
Launch Triton server and deploy trained models on Triton
Send request to Triton and receive back the response
Importing Libraries#
Let’s start with importing the libraries that we’ll use in this notebook.
import os
os.environ["TF_GPU_ALLOCATOR"]="cuda_malloc_async"
import nvtabular as nvt
from nvtabular.ops import *
import numpy as np
from merlin.models.utils.example_utils import workflow_fit_transform
from merlin.schema.tags import Tags
import merlin.models.tf as mm
from merlin.io.dataset import Dataset
import tensorflow as tf
2023-06-28 21:03:00.600621: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.
/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/torch.py:43: UserWarning: PyTorch dtype mappings did not load successfully due to an error: No module named 'torch'
warn(f"PyTorch dtype mappings did not load successfully due to an error: {exc.msg}")
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.
[INFO]: sparse_operation_kit is imported
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.base has been moved to tensorflow.python.trackable.base. The old module will be deleted in version 2.11.
[SOK INFO] Import /usr/local/lib/python3.8/dist-packages/merlin_sok-1.2.0-py3.8-linux-x86_64.egg/sparse_operation_kit/lib/libsok_experiment.so
[SOK INFO] Import /usr/local/lib/python3.8/dist-packages/merlin_sok-1.2.0-py3.8-linux-x86_64.egg/sparse_operation_kit/lib/libsok_experiment.so
[SOK INFO] Initialize finished, communication tool: horovod
2023-06-28 21:03:07.070258: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2023-06-28 21:03:07.070303: I tensorflow/core/common_runtime/gpu/gpu_process_state.cc:226] Using CUDA malloc Async allocator for GPU: 0
2023-06-28 21:03:07.070448: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1638] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 16249 MB memory: -> device: 0, name: Quadro GV100, pci bus id: 0000:2d:00.0, compute capability: 7.0
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Feature Engineering with NVTabular#
We use the synthetic train and test datasets generated by mimicking the real Ali-CCP: Alibaba Click and Conversion Prediction dataset to build our recommender system ranking models.
If you would like to use the real Ali-CCP dataset instead, you can download the training and test datasets from tianchi.aliyun.com. You can then use the get_aliccp() function to curate the raw csv files and save them as parquet files.
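For reference, a minimal sketch of that path is shown below. The directory path is hypothetical, and the exact signature of get_aliccp() may differ across Merlin versions, so check the merlin.datasets.ecommerce module before using it.
# Illustrative only: curate the raw Ali-CCP CSV files into parquet with get_aliccp().
# RAW_ALICCP_DIR is a hypothetical path; point it at the extracted raw dataset.
from merlin.datasets.ecommerce import get_aliccp
RAW_ALICCP_DIR = "/workspace/data/aliccp_raw"
train_raw, valid_raw = get_aliccp(RAW_ALICCP_DIR)
In this notebook, however, we continue with the synthetic data: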
from merlin.datasets.synthetic import generate_data
DATA_FOLDER = os.environ.get("DATA_FOLDER", "/workspace/data/")
NUM_ROWS = os.environ.get("NUM_ROWS", 1000000)
SYNTHETIC_DATA = eval(os.environ.get("SYNTHETIC_DATA", "True"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 512))
if SYNTHETIC_DATA:
    train, valid = generate_data("aliccp-raw", int(NUM_ROWS), set_sizes=(0.8, 0.2))
    # save the datasets as parquet files
    train.to_ddf().to_parquet(os.path.join(DATA_FOLDER, "train"))
    valid.to_ddf().to_parquet(os.path.join(DATA_FOLDER, "valid"))
Let’s define our input and output paths.
train_path = os.path.join(DATA_FOLDER, "train", "*.parquet")
valid_path = os.path.join(DATA_FOLDER, "valid", "*.parquet")
output_path = os.path.join(DATA_FOLDER, "processed")
After we execute the fit() and transform() functions on the raw dataset, applying the operators defined in the NVTabular workflow pipeline below, the processed parquet files are saved to output_path.
%%time
category_temp_directory = os.path.join(DATA_FOLDER, "categories")
user_id = ["user_id"] >> Categorify(out_path=category_temp_directory) >> TagAsUserID()
item_id = ["item_id"] >> Categorify(out_path=category_temp_directory) >> TagAsItemID()
targets = ["click"] >> AddMetadata(tags=[Tags.BINARY_CLASSIFICATION, "target"])
item_features = ["item_category", "item_shop", "item_brand"] >> Categorify(out_path=category_temp_directory) >> TagAsItemFeatures()
user_features = (
    [
        "user_shops",
        "user_profile",
        "user_group",
        "user_gender",
        "user_age",
        "user_consumption_2",
        "user_is_occupied",
        "user_geography",
        "user_intentions",
        "user_brands",
        "user_categories",
    ]
    >> Categorify(out_path=category_temp_directory)
    >> TagAsUserFeatures()
)
outputs = user_id + item_id + item_features + user_features + targets
workflow = nvt.Workflow(outputs)
train_dataset = nvt.Dataset(train_path)
valid_dataset = nvt.Dataset(valid_path)
workflow.fit(train_dataset)
workflow.transform(train_dataset).to_parquet(output_path=output_path + "/train/")
workflow.transform(valid_dataset).to_parquet(output_path=output_path + "/valid/")
CPU times: user 2.61 s, sys: 1.09 s, total: 3.7 s
Wall time: 3.68 s
We save the NVTabular workflow under the DATA_FOLDER directory.
workflow.save(os.path.join(DATA_FOLDER, "workflow"))
Let’s check out our saved workflow model folder.
!pip install seedir
Requirement already satisfied: seedir in /usr/local/lib/python3.8/dist-packages (0.4.2)
Requirement already satisfied: natsort in /usr/local/lib/python3.8/dist-packages (from seedir) (8.4.0)
import seedir as sd
sd.seedir(
    DATA_FOLDER,
    style="lines",
    itemlimit=10,
    depthlimit=3,
    exclude_folders=".ipynb_checkpoints",
    sort=True,
)
data/
├─categories/
│ └─categories/
│ ├─meta.item_brand.parquet
│ ├─meta.item_category.parquet
│ ├─meta.item_id.parquet
│ ├─meta.item_shop.parquet
│ ├─meta.user_age.parquet
│ ├─meta.user_brands.parquet
│ ├─meta.user_categories.parquet
│ ├─meta.user_consumption_2.parquet
│ ├─meta.user_gender.parquet
│ └─meta.user_geography.parquet
├─dlrm/
│ ├─.merlin/
│ │ ├─input_schema.json
│ │ └─output_schema.json
│ ├─assets/
│ ├─fingerprint.pb
│ ├─keras_metadata.pb
│ ├─saved_model.pb
│ └─variables/
│ ├─variables.data-00000-of-00001
│ └─variables.index
├─processed/
│ ├─train/
│ │ ├─.merlin/
│ │ ├─_file_list.txt
│ │ ├─_metadata
│ │ ├─_metadata.json
│ │ ├─part_0.parquet
│ │ └─schema.pbtxt
│ └─valid/
│ ├─.merlin/
│ ├─_file_list.txt
│ ├─_metadata
│ ├─_metadata.json
│ ├─part_0.parquet
│ └─schema.pbtxt
├─train/
│ └─part.0.parquet
├─valid/
│ └─part.0.parquet
└─workflow/
├─categories/
│ ├─unique.item_brand.parquet
│ ├─unique.item_category.parquet
│ ├─unique.item_id.parquet
│ ├─unique.item_shop.parquet
│ ├─unique.user_age.parquet
│ ├─unique.user_brands.parquet
│ ├─unique.user_categories.parquet
│ ├─unique.user_consumption_2.parquet
│ ├─unique.user_gender.parquet
│ └─unique.user_geography.parquet
├─metadata.json
└─workflow.pkl
Build and Train a DLRM model#
In this example, we build, train, and export a Deep Learning Recommendation Model (DLRM) architecture. To learn more about how to train different deep learning models, how to easily transition from one model to another, and the seamless integration between data preparation and model training, visit the 03-Exploring-different-models.ipynb notebook.
The NVTabular workflow above exports a schema file, schema.pbtxt, of our processed dataset. To learn more about the schema object, the schema file, and tags, you can explore 02-Merlin-Models-and-NVTabular-integration.ipynb.
# define train and valid dataset objects
train = Dataset(os.path.join(output_path, "train", "*.parquet"))
valid = Dataset(os.path.join(output_path, "valid", "*.parquet"))
# define schema object
schema = train.schema
target_column = schema.select_by_tag(Tags.TARGET).column_names[0]
target_column
'click'
model = mm.DLRMModel(
    schema,
    embedding_dim=64,
    bottom_block=mm.MLPBlock([128, 64]),
    top_block=mm.MLPBlock([128, 64, 32]),
    prediction_tasks=mm.BinaryOutput(target_column),
)
%%time
model.compile("adam", run_eagerly=False, metrics=[tf.keras.metrics.AUC()])
model.fit(train, validation_data=valid, batch_size=BATCH_SIZE)
2023-06-28 21:03:36.828993: I tensorflow/core/common_runtime/executor.cc:1209] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype int32
[[{{node Placeholder/_0}}]]
1563/1563 [==============================] - ETA: 0s - loss: 0.6932 - auc: 0.4998 - regularization_loss: 0.0000e+00 - loss_batch: 0.6932
2023-06-28 21:04:40.190967: I tensorflow/core/common_runtime/executor.cc:1209] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype int32
[[{{node Placeholder/_0}}]]
1563/1563 [==============================] - 69s 38ms/step - loss: 0.6932 - auc: 0.4998 - regularization_loss: 0.0000e+00 - loss_batch: 0.6932 - val_loss: 0.6931 - val_auc: 0.5000 - val_regularization_loss: 0.0000e+00 - val_loss_batch: 0.6932
CPU times: user 1min 51s, sys: 14.1 s, total: 2min 5s
Wall time: 1min 11s
<keras.callbacks.History at 0x7f74b2a4b1c0>
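Optionally, before exporting, you can compute the validation metrics once more with the standard Keras evaluate API (a quick sketch; the metric names follow the compile call above).
# Optional: evaluate the trained model on the validation set (a sketch)
eval_metrics = model.evaluate(valid, batch_size=BATCH_SIZE, return_dict=True)
print(eval_metrics)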
Save model#
model.save(os.path.join(DATA_FOLDER, "dlrm"))
WARNING:absl:Found untraced functions such as model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, prepare_list_features_layer_call_fn, prepare_list_features_layer_call_and_return_conditional_losses, dense_9_layer_call_fn while saving (showing 5 of 96). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /workspace/data/dlrm/assets
INFO:tensorflow:Assets written to: /workspace/data/dlrm/assets
We have the NVTabular workflow and DLRM model exported; now it is time to move on to the next step: model deployment with Merlin Systems.
Deploying the model with Merlin Systems#
The last step of a machine learning (ML)/deep learning (DL) pipeline is to deploy the ETL workflow and saved model to production. In the production setting, we want to transform the input data the same way it was transformed during training (ETL). We need to apply the same mean/std for continuous features and use the same categorical mapping to convert the categories to continuous integers before we use the DL model for a prediction. Therefore, we deploy the NVTabular workflow together with the TensorFlow model as an ensemble model to Triton Inference Server using the Merlin Systems library. The ensemble model guarantees that the same transformation is applied to the raw inputs.
In the next steps, we will learn how to deploy NVTabular workflow and the trained DLRM model into Triton Inference Server with Merlin Systems library. NVIDIA Triton Inference Server (TIS) simplifies the deployment of AI models at scale in production. TIS provides a cloud and edge inferencing solution optimized for both CPUs and GPUs. It supports a number of different machine learning frameworks such as TensorFlow and PyTorch.
First, we load the nvtabular.Workflow that we created earlier in this example.
from nvtabular.workflow import Workflow
workflow = Workflow.load(os.path.join(DATA_FOLDER, "workflow"))
After we load the workflow, we remove the label columns from its inputs. This removes all columns with the TARGET tag from the workflow. We do this because, when creating an inference pipeline, the workflow should only require the features needed for prediction, not training.
from merlin.schema.tags import Tags
label_columns = workflow.output_schema.select_by_tag(Tags.TARGET).column_names
workflow.remove_inputs(label_columns)
<nvtabular.workflow.workflow.Workflow at 0x7f74b290f550>
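As a quick check, we can confirm that the target column no longer appears in the workflow's request-time inputs (a one-line sketch):
# 'click' should be absent now that TARGET-tagged columns were removed
print(workflow.input_schema.column_names)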
After loading the workflow, we load the model. This model was trained with the output of the workflow that we fit earlier in this example.
Note that the Merlin Models library must be imported before loading the model: loading a TensorFlow model that is built from custom subclasses requires the subclass definitions to be available; otherwise, TensorFlow cannot load the model correctly.
tf_model_path = os.path.join(DATA_FOLDER, "dlrm")
model = tf.keras.models.load_model(tf_model_path)
Create the Ensemble Graph#
After we have both the model and the workflow loaded, we can create the ensemble graph. You define this graph yourself; its goal is to illustrate the path of data through your full system. In this example we only serve a workflow followed by a model, but you can add other components that help you meet your business logic requirements.
Because this example has two components—a model and a workflow—we require two operators. These operators, also known as inference operators, are meant to abstract away all the “hard parts” of loading a specific component, such as a workflow or model, into Triton Inference Server.
The following code block shows how to use two inference operators:
TransformWorkflow: this operator ensures that the workflow is correctly saved and packaged with the required config so the server will know how to load it.
PredictTensorflow: this operator does the same for the TensorFlow model that we loaded above.
Let’s give it a try.
from merlin.systems.dag.ops.workflow import TransformWorkflow
from merlin.systems.dag.ops.tensorflow import PredictTensorflow
serving_operators = workflow.input_schema.column_names >> TransformWorkflow(workflow) >> PredictTensorflow(model)
WARNING:absl:Found untraced functions such as model_context_2_layer_call_fn, model_context_2_layer_call_and_return_conditional_losses, prepare_list_features_2_layer_call_fn, prepare_list_features_2_layer_call_and_return_conditional_losses, dense_9_layer_call_fn while saving (showing 5 of 96). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpomjyo5xq/assets
INFO:tensorflow:Assets written to: /tmp/tmpomjyo5xq/assets
Export Graph as Ensemble#
The last step is to create the ensemble artifacts that Triton Inference Server can consume. To make these artifacts, we import the Ensemble class. The class is responsible for interpreting the graph and exporting the correct files for the server.
After you run the following cell, you can inspect the workflow's output schema, which is a Schema object made up of a ColumnSchema for each column.
When you are creating an Ensemble object, you supply the graph and a schema representing the starting input of the graph. The inputs to the ensemble graph are the inputs to the first operator of your graph.
After you have created the Ensemble you export the graph, supplying an export path for the Ensemble.export function.
This returns an ensemble config which represents the entire inference pipeline and a list of node-specific configs.
Let’s take a look below.
workflow.output_schema
 | name | tags | dtype | is_list | is_ragged | properties.num_buckets | properties.freq_threshold | properties.max_size | properties.cat_path | properties.domain.min | properties.domain.max | properties.domain.name | properties.embedding_sizes.cardinality | properties.embedding_sizes.dimension |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | user_id | (Tags.CATEGORICAL, Tags.ID, Tags.USER) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.u... | 0 | 772 | user_id | 773 | 66 |
1 | item_id | (Tags.CATEGORICAL, Tags.ITEM, Tags.ID) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.i... | 0 | 789 | item_id | 790 | 67 |
2 | item_category | (Tags.CATEGORICAL, Tags.ITEM) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.i... | 0 | 789 | item_category | 790 | 67 |
3 | item_shop | (Tags.CATEGORICAL, Tags.ITEM) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.i... | 0 | 789 | item_shop | 790 | 67 |
4 | item_brand | (Tags.CATEGORICAL, Tags.ITEM) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.i... | 0 | 789 | item_brand | 790 | 67 |
5 | user_shops | (Tags.CATEGORICAL, Tags.USER) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.u... | 0 | 772 | user_shops | 773 | 66 |
6 | user_profile | (Tags.CATEGORICAL, Tags.USER) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.u... | 0 | 73 | user_profile | 74 | 18 |
7 | user_group | (Tags.CATEGORICAL, Tags.USER) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.u... | 0 | 13 | user_group | 14 | 16 |
8 | user_gender | (Tags.CATEGORICAL, Tags.USER) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.u... | 0 | 4 | user_gender | 5 | 16 |
9 | user_age | (Tags.CATEGORICAL, Tags.USER) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.u... | 0 | 8 | user_age | 9 | 16 |
10 | user_consumption_2 | (Tags.CATEGORICAL, Tags.USER) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.u... | 0 | 5 | user_consumption_2 | 6 | 16 |
11 | user_is_occupied | (Tags.CATEGORICAL, Tags.USER) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.u... | 0 | 4 | user_is_occupied | 5 | 16 |
12 | user_geography | (Tags.CATEGORICAL, Tags.USER) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.u... | 0 | 6 | user_geography | 7 | 16 |
13 | user_intentions | (Tags.CATEGORICAL, Tags.USER) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.u... | 0 | 772 | user_intentions | 773 | 66 |
14 | user_brands | (Tags.CATEGORICAL, Tags.USER) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.u... | 0 | 772 | user_brands | 773 | 66 |
15 | user_categories | (Tags.CATEGORICAL, Tags.USER) | DType(name='int64', element_type=<ElementType.... | False | False | None | 0 | 0 | /workspace/data/categories/categories/unique.u... | 0 | 772 | user_categories | 773 | 66 |
from merlin.systems.dag.ensemble import Ensemble
ensemble = Ensemble(serving_operators, workflow.input_schema)
export_path = os.path.join(DATA_FOLDER, "ensemble")
ens_conf, node_confs = ensemble.export(export_path)
WARNING:absl:Found untraced functions such as model_context_2_layer_call_fn, model_context_2_layer_call_and_return_conditional_losses, prepare_list_features_2_layer_call_fn, prepare_list_features_2_layer_call_and_return_conditional_losses, dense_9_layer_call_fn while saving (showing 5 of 96). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /workspace/data/ensemble/1_predicttensorflowtriton/1/model.savedmodel/assets
INFO:tensorflow:Assets written to: /workspace/data/ensemble/1_predicttensorflowtriton/1/model.savedmodel/assets
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Display the path to the directory with the ensemble.
print(export_path)
/workspace/data/ensemble
Verification of Ensemble Artifacts#
After we export the ensemble, we can check the export path for the graph's artifacts. The directory structure represents an ordering number followed by an operator identifier such as 0_transformworkflowtriton, 1_predicttensorflowtriton, and so on.
Inside each of those directories, the export method writes a config.pbtxt file and a directory with a number. The number indicates the version and begins at 1. The artifacts for each operator are found inside the version folder. These artifacts vary depending on the operator in use.
We use the seedir python package, which we installed earlier, to view some of the directory contents.
sd.seedir(export_path, style='lines', itemlimit=10, depthlimit=3, exclude_folders='.ipynb_checkpoints', sort=True)
ensemble/
├─0_transformworkflowtriton/
│ ├─1/
│ │ ├─model.py
│ │ └─workflow/
│ └─config.pbtxt
├─1_predicttensorflowtriton/
│ ├─1/
│ │ └─model.savedmodel/
│ └─config.pbtxt
└─executor_model/
├─1/
│ ├─ensemble/
│ └─model.py
└─config.pbtxt
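If you are curious what one of those configs contains, you can simply print it (a small sketch; the folder name matches the tree above):
# Optional: print the Triton config generated for the workflow operator
config_path = os.path.join(export_path, "0_transformworkflowtriton", "config.pbtxt")
with open(config_path) as f:
    print(f.read())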
Starting Triton Server#
After we export the ensemble, we are ready to start the Triton Inference Server. The server is installed in all the Merlin inference containers. If you are not using one of our containers, then ensure it is installed in your environment. For more information, see the Triton Inference Server documentation.
You can start the server by running the following command:
tritonserver --model-repository=/workspace/data/ensemble
For the --model-repository argument, specify the same value as the export_path that you specified previously in the ensemble.export method.
After you run the tritonserver command, wait until your terminal shows messages like the following example:
I0414 18:29:50.741833 4067 grpc_server.cc:4421] Started GRPCInferenceService at 0.0.0.0:8001
I0414 18:29:50.742197 4067 http_server.cc:3113] Started HTTPService at 0.0.0.0:8000
I0414 18:29:50.783470 4067 http_server.cc:178] Started Metrics Service at 0.0.0.0:8002
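If you are driving Triton from this notebook instead of watching the terminal, one way to block until the server is ready is to poll its liveness endpoint with the HTTP client (a sketch, assuming the server is reachable on the default port 8000):
# Poll Triton's liveness endpoint until the server reports it is up (or we give up)
import time
import tritonclient.http as triton_http

def wait_for_triton(url="localhost:8000", retries=60, delay=2.0):
    http_client = triton_http.InferenceServerClient(url=url)
    for _ in range(retries):
        try:
            if http_client.is_server_live():
                return True
        except Exception:
            pass
        time.sleep(delay)
    return False

print("Triton is live:", wait_for_triton())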
Retrieving Recommendations from Triton Inference Server#
Now that our server is running, we can send requests to it. This request is composed of values that correspond to the request schema that was created when we exported the ensemble graph.
In the code below, we create a request and send it to Triton. We then analyze the response to show the full experience.
First we need to ensure that we have a client connected to the server that we started. To do this, we use the Triton HTTP client library.
import tritonclient.http as client
# Create a triton client
try:
    triton_client = client.InferenceServerClient(url="localhost:8000", verbose=True)
    print("client created.")
except Exception as e:
    print("channel creation failed: " + str(e))
client created.
After we create the client and verify it is connected to the server instance, we can communicate with the server and ensure all the models are loaded correctly.
# ensure triton is in a good state
triton_client.is_server_live()
triton_client.get_model_repository_index()
GET /v2/health/live, headers None
<HTTPSocketPoolResponse status=200 headers={'content-length': '0', 'content-type': 'text/plain'}>
POST /v2/repository/index, headers None
<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '191'}>
bytearray(b'[{"name":"0_transformworkflowtriton","version":"1","state":"READY"},{"name":"1_predicttensorflowtriton","version":"1","state":"READY"},{"name":"executor_model","version":"1","state":"READY"}]')
[{'name': '0_transformworkflowtriton', 'version': '1', 'state': 'READY'},
{'name': '1_predicttensorflowtriton', 'version': '1', 'state': 'READY'},
{'name': 'executor_model', 'version': '1', 'state': 'READY'}]
After verifying the models are correctly loaded by the server, we use some original, raw validation data and send it as an inference request to the server.
The df_lib object is cudf if a GPU is available and pandas otherwise.
from merlin.core.dispatch import get_lib
df_lib = get_lib()
# read in data for request
batch = df_lib.read_parquet(
    os.path.join(DATA_FOLDER, "valid", "part.0.parquet"),
    columns=workflow.input_schema.column_names,
).head(3)
batch
__null_dask_index__ | user_id | item_id | item_category | item_shop | item_brand | user_shops | user_profile | user_group | user_gender | user_age | user_consumption_2 | user_is_occupied | user_geography | user_intentions | user_brands | user_categories
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
800000 | 25 | 26 | 85 | 5936 | 2045 | 1670 | 2 | 1 | 1 | 1 | 1 | 1 | 1 | 484 | 830 | 88 |
800001 | 28 | 13 | 41 | 2850 | 982 | 1879 | 2 | 1 | 1 | 1 | 1 | 1 | 1 | 544 | 934 | 98 |
800002 | 9 | 2 | 4 | 238 | 82 | 557 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 162 | 277 | 30 |
After we isolate our batch, we convert the dataframe representation into inputs for Triton. We also declare the outputs that we expect to receive from the model.
from merlin.systems.triton import convert_df_to_triton_input
import tritonclient.grpc as grpcclient
# create inputs and outputs
inputs = convert_df_to_triton_input(workflow.input_schema, batch, grpcclient.InferInput)
output_cols = ensemble.graph.output_schema.column_names
print(output_cols)
outputs = [
    grpcclient.InferRequestedOutput(col)
    for col in output_cols
]
['click/binary_output']
Now that our inputs and outputs are created, we can use the triton_client that we created earlier to send the inference request.
# send request to tritonserver
with grpcclient.InferenceServerClient("localhost:8001") as client:
    response = client.infer("executor_model", inputs, request_id="1", outputs=outputs)
When the server completes the inference request, it returns a response, i.e., a click likelihood for each row in the request. This response is parsed to get the desired predictions.
predictions = response.as_numpy('click/binary_output')
print(predictions)
[[0.5002032]
[0.5001995]
[0.5001995]]
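To relate these scores back to the rows we sent, you can attach them to the request batch (a small sketch that should work with both cudf and pandas):
# Join the predicted click probabilities back onto the request rows (illustrative)
scored = batch.copy()
scored["click_probability"] = predictions.ravel()
print(scored[["user_id", "item_id", "click_probability"]])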
Summary#
This sample notebook started with data preprocessing and model training. We learned how to create an ensemble graph, verify the ensemble artifacts in the file system, and then put the ensemble into production with Triton Inference Server. Finally, we sent a simple inference request to the server and printed the response.