# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

Scaling Criteo: Triton Inference with TensorFlow

Overview

The last step is to deploy the ETL workflow and the saved model to production. In a production setting, we want to transform the input data exactly as during training (ETL): apply the same mean/std normalization to the continuous features and use the same categorical mapping to convert the categories to contiguous integers before the deep learning model makes a prediction. Therefore, we deploy the NVTabular workflow together with the TensorFlow model as an ensemble model to Triton Inference Server. The ensemble model guarantees that the same transformations are applied to the raw inputs.

[Figure: ensemble of the NVTabular workflow and the TensorFlow model deployed to Triton Inference Server (triton-tf.png)]

Learning objectives

In this notebook, we learn how to deploy our models to production:

  • Use NVTabular to generate config and model files for Triton Inference Server

  • Deploy an ensemble of NVTabular workflow and TensorFlow model

  • Send an example request to Triton Inference Server

Inference with Triton and TensorFlow

First, we need to generate the Triton Inference Server configurations and save the models in the correct format. In the previous notebooks, 02-ETL-with-NVTabular and 03-Training-with-TF, we saved the NVTabular workflow and the TensorFlow model to disk. We will now load them.

Saving Ensemble Model for Triton Inference Server

import os

import tensorflow as tf
import nvtabular as nvt

# Paths to the data and to the NVTabular workflow saved in the previous notebooks
BASE_DIR = os.environ.get("BASE_DIR", "/raid/data/criteo")
input_path = os.environ.get("INPUT_DATA_DIR", os.path.join(BASE_DIR, "test_dask/output"))
workflow = nvt.Workflow.load(os.path.join(input_path, "workflow"))
/usr/local/lib/python3.8/dist-packages/nvtabular/workflow/workflow.py:373: UserWarning: Loading workflow generated with nvtabular version 0.10.0+123.g44d3c3e8.dirty - but we are running nvtabular 1.2.2+4.gebf56ca0f. This might cause issues
  warnings.warn(
model = tf.keras.models.load_model(os.path.join(input_path, "model.savedmodel"))
2022-07-14 23:15:34.019787: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-07-14 23:15:36.054064: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 46898 MB memory:  -> device: 0, name: Quadro RTX 8000, pci bus id: 0000:15:00.0, compute capability: 7.5
2022-07-14 23:15:36.054715: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 46890 MB memory:  -> device: 1, name: Quadro RTX 8000, pci bus id: 0000:2d:00.0, compute capability: 7.5

TensorFlow expects integer inputs as the int32 datatype. Therefore, we need to set the NVTabular output datatypes for the categorical features to int32.

# Cast the categorical output columns (names starting with "C") to int32
for key in workflow.output_dtypes.keys():
    if key.startswith("C"):
        workflow.output_dtypes[key] = "int32"
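
As an optional sanity check, we can print the dtypes of the categorical columns to confirm the override took effect:

# Optional check: the categorical columns should now report int32
print({k: v for k, v in workflow.output_dtypes.items() if k.startswith("C")})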

NVTabular provides a convenience function, export_tensorflow_ensemble, to export the ensemble model for Triton Inference Server.

from nvtabular.inference.triton import export_tensorflow_ensemble
export_tensorflow_ensemble(model, workflow, "criteo", "/tmp/model/models/", ["label"])
WARNING:absl:Function `_wrapped_model` contains input name(s) C1, C10, C11, C12, C13, C14, C15, C16, C17, C18, C19, C2, C20, C21, C22, C23, C24, C25, C26, C3, C4, C5, C6, C7, C8, C9, I1, I10, I11, I12, I13, I2, I3, I4, I5, I6, I7, I8, I9 with unsupported characters which will be renamed to c1, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c2, c20, c21, c22, c23, c24, c25, c26, c3, c4, c5, c6, c7, c8, c9, i1, i10, i11, i12, i13, i2, i3, i4, i5, i6, i7, i8, i9 in the SavedModel.
INFO:tensorflow:Assets written to: /tmp/model/models/criteo_tf/1/model.savedmodel/assets
INFO:tensorflow:Assets written to: /tmp/model/models/criteo_tf/1/model.savedmodel/assets
WARNING:absl:<keras.saving.saved_model.load.DenseFeatures object at 0x7f0638513520> has the same name 'DenseFeatures' as a built-in Keras object. Consider renaming <class 'keras.saving.saved_model.load.DenseFeatures'> to avoid naming conflicts when loading with `tf.keras.models.load_model`. If renaming is not possible, pass the object in the `custom_objects` parameter of the load function.
/usr/local/lib/python3.8/dist-packages/nvtabular/inference/triton/ensemble.py:85: UserWarning: TF model expects int32 for column C1, but workflow  is producing type int64. Overriding dtype in NVTabular workflow.
  warnings.warn(
... (the same dtype-override warning repeats for each remaining categorical column C2-C26 and, with float32 vs. float64, for each continuous column I1-I13)

We can take a look at the generated files.

!tree /tmp/model
/tmp/model
└── models
    ├── criteo
    │   ├── 1
    │   └── config.pbtxt
    ├── criteo_nvt
    │   ├── 1
    │   │   ├── model.py
    │   │   └── workflow
    │   │       ├── categories
    │   │       │   ├── unique.C1.parquet
    │   │       │   ├── unique.C10.parquet
    │   │       │   ├── unique.C11.parquet
    │   │       │   ├── unique.C12.parquet
    │   │       │   ├── unique.C13.parquet
    │   │       │   ├── unique.C14.parquet
    │   │       │   ├── unique.C15.parquet
    │   │       │   ├── unique.C16.parquet
    │   │       │   ├── unique.C17.parquet
    │   │       │   ├── unique.C18.parquet
    │   │       │   ├── unique.C19.parquet
    │   │       │   ├── unique.C2.parquet
    │   │       │   ├── unique.C20.parquet
    │   │       │   ├── unique.C21.parquet
    │   │       │   ├── unique.C22.parquet
    │   │       │   ├── unique.C23.parquet
    │   │       │   ├── unique.C24.parquet
    │   │       │   ├── unique.C25.parquet
    │   │       │   ├── unique.C26.parquet
    │   │       │   ├── unique.C3.parquet
    │   │       │   ├── unique.C4.parquet
    │   │       │   ├── unique.C5.parquet
    │   │       │   ├── unique.C6.parquet
    │   │       │   ├── unique.C7.parquet
    │   │       │   ├── unique.C8.parquet
    │   │       │   └── unique.C9.parquet
    │   │       ├── metadata.json
    │   │       └── workflow.pkl
    │   └── config.pbtxt
    └── criteo_tf
        ├── 1
        │   └── model.savedmodel
        │       ├── assets
        │       ├── keras_metadata.pb
        │       ├── saved_model.pb
        │       └── variables
        │           ├── variables.data-00000-of-00001
        │           └── variables.index
        └── config.pbtxt

12 directories, 36 files
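
We can also inspect the generated ensemble configuration directly; a minimal sketch that prints the beginning of the config file exported above:

# Print the start of the generated ensemble config (path from the export above)
with open("/tmp/model/models/criteo/config.pbtxt", "r") as f:
    print(f.read()[:1000])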

Loading Ensemble Model with Triton Inference Server

So far, we have only saved the models for Triton Inference Server. Triton was started in explicit model-control mode, which means that we need to send an explicit load request before Triton serves the ensemble model.
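
If Triton Inference Server is not running yet, it can be started against the exported model repository in explicit mode, for example (adjust the repository path to your setup):

tritonserver --model-repository=/tmp/model/models/ --model-control-mode=explicit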

First, we restart this notebook to free the GPU memory.

# import IPython

# app = IPython.Application.instance()
# app.kernel.do_shutdown(True)

We define the BASE_DIR again.

import os

BASE_DIR = os.environ.get("BASE_DIR", "/raid/data/criteo")

We connect to the Triton Inference Server.

import tritonclient.grpc as grpc_client

try:
    triton_client = grpc_client.InferenceServerClient(url="localhost:8001", verbose=True)
    print("client created.")
except Exception as e:
    print("channel creation failed: " + str(e))
client created.

We suppress warnings.

import warnings

warnings.filterwarnings("ignore")

We check if the server is alive.

triton_client.is_server_live()
is_server_live, metadata ()

live: true
True
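
In addition to liveness, the client can check readiness, i.e. whether the server can accept inference requests:

# Returns True once the server is ready to accept inference requests
triton_client.is_server_ready()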

We check the available models in the model repository:

  • criteo: Ensemble

  • criteo_nvt: NVTabular workflow

  • criteo_tf: TensorFlow model

triton_client.get_model_repository_index()
get_model_repository_index, metadata ()

models {
  name: "criteo"
}
models {
  name: "criteo_nvt"
}
models {
  name: "criteo_tf"
}

We load the ensemble model.

%%time

triton_client.load_model(model_name="criteo")
load_model, metadata ()
override files omitted:
model_name: "criteo"

Loaded model 'criteo'
CPU times: user 13.5 ms, sys: 8.86 ms, total: 22.4 ms
Wall time: 41.9 s
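
Optionally, we can confirm that the ensemble model is ready to serve requests:

# Returns True once the "criteo" ensemble is ready for inference
triton_client.is_model_ready(model_name="criteo")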

Example Request to Triton Inference Server

Now that the models are loaded, we can create a sample request. We read an example raw batch for inference.

# Get dataframe library - cudf or pandas
from merlin.core.dispatch import get_lib

df_lib = get_lib()

# read an example batch of raw data for inference
batch_path = os.path.join(BASE_DIR, "converted/criteo")
batch = df_lib.read_parquet(os.path.join(batch_path, "*.parquet"), num_rows=3)
batch = batch[[x for x in batch.columns if x != "label"]]
print(batch)
     I1   I2    I3    I4    I5  I6  I7  I8  I9  I10  ...        C17  \
0     5  110  <NA>    16  <NA>   1   0  14   7    1  ... -771205462   
1    32    3     5  <NA>     1   0   0  61   5    0  ... -771205462   
2  <NA>  233     1   146     1   0   0  99   7    0  ... -771205462   

          C18         C19         C20         C21        C22        C23  \
0 -1206449222 -1793932789 -1014091992   351689309  632402057 -675152885   
1 -1578429167 -1793932789   -20981661 -1556988767 -924717482  391309800   
2  1653545869 -1793932789 -1014091992   351689309  632402057 -675152885   

          C24         C25         C26  
0  2091868316   809724924  -317696227  
1  1966410890 -1726799382 -1218975401  
2   883538181   -10139646  -317696227  

[3 rows x 39 columns]

We prepare the batch for inference by using the correct column names and datatypes, matching the datatypes of our dataframe.

batch.dtypes
I1     int32
I2     int32
I3     int32
I4     int32
I5     int32
I6     int32
I7     int32
I8     int32
I9     int32
I10    int32
I11    int32
I12    int32
I13    int32
C1     int32
C2     int32
C3     int32
C4     int32
C5     int32
C6     int32
C7     int32
C8     int32
C9     int32
C10    int32
C11    int32
C12    int32
C13    int32
C14    int32
C15    int32
C16    int32
C17    int32
C18    int32
C19    int32
C20    int32
C21    int32
C22    int32
C23    int32
C24    int32
C25    int32
C26    int32
dtype: object
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype
import numpy as np

inputs = []

col_names = list(batch.columns)
col_dtypes = [np.int32] * len(col_names)

# Build one InferInput per column; values_host assumes a cuDF dataframe
# (with pandas, use .values instead)
for i, col in enumerate(batch.columns):
    d = batch[col].fillna(0).values_host.astype(col_dtypes[i])
    d = d.reshape(len(d), 1)
    inputs.append(grpcclient.InferInput(col_names[i], d.shape, np_to_triton_dtype(col_dtypes[i])))
    inputs[i].set_data_from_numpy(d)

We send the request to the Triton server and collect the response.

# placeholder variables for the requested output
outputs = [grpcclient.InferRequestedOutput("output")]

# make the request with the gRPC client we created earlier
response = triton_client.infer("criteo", inputs, request_id="1", outputs=outputs)

print("predicted click probability:\n", response.as_numpy("output"))
infer, metadata ()
model_name: "criteo"
id: "1"
inputs {
  name: "I1"
  datatype: "INT32"
  shape: 3
  shape: 1
}
... (an identical inputs block follows for each of the remaining 38 columns, I2-I13 and C1-C26)
outputs {
  name: "output"
}
raw_input_contents: "\005\000\000\000 \000\000\000\000\000\000\000"
... (38 further raw_input_contents entries omitted)

model_name: "criteo"
model_version: "1"
id: "1"
parameters {
  key: "sequence_end"
  value {
    bool_param: false
  }
}
parameters {
  key: "sequence_id"
  value {
    int64_param: 0
  }
}
parameters {
  key: "sequence_start"
  value {
    bool_param: false
  }
}
outputs {
  name: "output"
  datatype: "FP32"
  shape: 3
  shape: 1
}
raw_output_contents: "Dd\217<$r\233<\241\231u<"

predicted click probability:
 [[0.01750387]
 [0.01897532]
 [0.01499024]]
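
The model returns one click probability per input row. As a sketch of downstream use, the probabilities could be thresholded into binary click predictions (the 0.5 cutoff is an arbitrary choice for illustration):

import numpy as np

probs = response.as_numpy("output")     # shape (3, 1), one probability per row
preds = (probs > 0.5).astype(np.int8)   # 0.5 is an arbitrary illustrative threshold
print(preds.ravel())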

Let’s unload the models. Each model of the ensemble needs to be unloaded individually.

triton_client.unload_model(model_name="criteo")
triton_client.unload_model(model_name="criteo_nvt")
triton_client.unload_model(model_name="criteo_tf")
unload_model, metadata ()
model_name: "criteo"
parameters {
  key: "unload_dependents"
  value {
    bool_param: false
  }
}

Unloaded model 'criteo'
unload_model, metadata ()
model_name: "criteo_nvt"
parameters {
  key: "unload_dependents"
  value {
    bool_param: false
  }
}

Unloaded model 'criteo_nvt'
unload_model, metadata ()
model_name: "criteo_tf"
parameters {
  key: "unload_dependents"
  value {
    bool_param: false
  }
}

Unloaded model 'criteo_tf'