
Deploy SavedModel using HPS with Triton TensorFlow Backend

Overview

This notebook demonstrates how to deploy a SavedModel that uses HPS for embedding lookup with the Triton TensorFlow backend. It also shows how to apply TF-TRT optimization to a SavedModel whose embedding lookup is based on HPS. We recommend running hierarchical_parameter_server_demo.ipynb before working through this notebook.

For more details about HPS APIs, please refer to HPS APIs. For more details about HPS, please refer to HugeCTR Hierarchical Parameter Server (HPS).

Installation

Get HPS from NGC

The HPS Python module is preinstalled in the 22.12 and later Merlin TensorFlow Container: nvcr.io/nvidia/merlin/merlin-tensorflow:22.12.

After launching the container, you can confirm that the required library is available by running the following command.

$ python3 -c "import hierarchical_parameter_server as hps"

The Triton TensorFlow backend is also available in this container.
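
The same availability check can be done from Python without triggering the import itself. The sketch below is only an illustration: the helper name module_available is hypothetical, and the printed result depends on the environment it runs in.

```python
import importlib.util


def module_available(name):
    """Return True if the top-level module `name` can be located, without importing it."""
    return importlib.util.find_spec(name) is not None


# Inside the Merlin TensorFlow container both lookups are expected to succeed;
# elsewhere the result simply reflects what is installed.
for name in ("hierarchical_parameter_server", "tensorflow"):
    print(name, "available:", module_available(name))
```

Using find_spec avoids the side effects of importing a large package just to test for its presence.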

Configurations

First, we specify the required configurations, such as the arguments for generating the dataset and the paths for saving the model and its parameters. This notebook uses a deep neural network (DNN) model with one embedding table and several dense layers. Note that the model takes two inputs: the key tensor (one-hot) and the dense feature tensor.

import hierarchical_parameter_server as hps
import os
import numpy as np
import tensorflow as tf
import struct

args = dict()

args["gpu_num"] = 1                               # the number of available GPUs
args["iter_num"] = 10                             # the number of training iterations
args["slot_num"] = 5                              # the number of feature fields in this embedding layer
args["embed_vec_size"] = 16                       # the dimension of embedding vectors
args["global_batch_size"] = 1024                  # the global batch size across all GPUs
args["max_vocabulary_size"] = 50000
args["vocabulary_range_per_slot"] = [[0,10000],[10000,20000],[20000,30000],[30000,40000],[40000,50000]]
args["dense_dim"] = 10

args["dense_model_path"] = "hps_tf_triton_dense.model"
args["ps_config_file"] = "hps_tf_triton.json"
args["embedding_table_path"] = "hps_tf_triton_sparse_0.model"
args["saved_path"] = "hps_tf_triton_tf_saved_model"
args["np_key_type"] = np.int64
args["np_vector_type"] = np.float32
args["tf_key_type"] = tf.int64
args["tf_vector_type"] = tf.float32


os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(args["gpu_num"])))
[INFO] hierarchical_parameter_server is imported
/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.12) or chardet (3.0.4) doesn't match a supported version!
  warnings.warn("urllib3 ({}) or chardet ({}) doesn't match a supported "
def generate_random_samples(num_samples, vocabulary_range_per_slot, dense_dim, key_dtype = args["np_key_type"]):
    keys = list()
    for vocab_range in vocabulary_range_per_slot:
        keys_per_slot = np.random.randint(low=vocab_range[0], high=vocab_range[1], size=(num_samples, 1), dtype=key_dtype)
        keys.append(keys_per_slot)
    keys = np.concatenate(np.array(keys), axis = 1)
    dense_features = np.random.random((num_samples, dense_dim)).astype(np.float32)
    labels = np.random.randint(low=0, high=2, size=(num_samples, 1))
    return keys, dense_features, labels

def tf_dataset(keys, dense_features, labels, batchsize):
    dataset = tf.data.Dataset.from_tensor_slices((keys, dense_features, labels))
    dataset = dataset.batch(batchsize, drop_remainder=True)
    return dataset
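
The configuration above gives each slot its own key range, and np.random.randint treats the upper bound as exclusive, so the ranges read as half-open intervals. The sketch below checks that such ranges are contiguous and cover the whole vocabulary. The helper check_slot_ranges is a hypothetical name, and the requirement that the ranges tile the key space is an assumption about the intent of this configuration, not something HPS enforces.

```python
def check_slot_ranges(ranges, max_vocabulary_size):
    """Check that half-open [low, high) ranges are contiguous and cover [0, max_vocabulary_size)."""
    expected_low = 0
    for low, high in ranges:
        # Each range must start where the previous one ended and must not be empty.
        if low != expected_low or high <= low:
            return False
        expected_low = high
    return expected_low == max_vocabulary_size


# Values copied from the args dictionary above.
vocabulary_range_per_slot = [[0, 10000], [10000, 20000], [20000, 30000], [30000, 40000], [40000, 50000]]
print(check_slot_ranges(vocabulary_range_per_slot, 50000))
```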

Train with native TF layers

We define the model graph for training with native TF layers, i.e., tf.nn.embedding_lookup and tf.keras.layers.Dense. The embedding weights are stored in a tf.Variable. After training, we extract the trained weights of the embedding table. The dense layers are saved as a separate model graph, which can be loaded directly during inference.

class TrainModel(tf.keras.models.Model):
    def __init__(self,
                 init_tensors,
                 slot_num,
                 embed_vec_size,
                 dense_dim,
                 **kwargs):
        super(TrainModel, self).__init__(**kwargs)
        
        self.slot_num = slot_num
        self.embed_vec_size = embed_vec_size
        self.dense_dim = dense_dim
        self.init_tensors = init_tensors
        self.params = tf.Variable(initial_value=tf.concat(self.init_tensors, axis=0))
        self.concat = tf.keras.layers.Concatenate(axis=1, name="concatenate")
        self.fc_1 = tf.keras.layers.Dense(units=256, activation=None,
                                                 kernel_initializer="ones",
                                                 bias_initializer="zeros",
                                                 name='fc_1')
        self.fc_2 = tf.keras.layers.Dense(units=1, activation=None,
                                                 kernel_initializer="ones",
                                                 bias_initializer="zeros",
                                                 name='fc_2')

    def call(self, inputs):
        keys, dense_features = inputs[0], inputs[1]
        embedding_vector = tf.nn.embedding_lookup(params=self.params, ids=keys)
        embedding_vector = tf.reshape(embedding_vector, shape=[-1, self.slot_num * self.embed_vec_size])
        concated_features = self.concat([embedding_vector, dense_features])
        logit = self.fc_2(self.fc_1(concated_features))
        return logit

    def summary(self):
        inputs = [tf.keras.Input(shape=(self.slot_num, ), dtype=args["tf_key_type"]),
                  tf.keras.Input(shape=(self.dense_dim, ), dtype=tf.float32)]
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()
def train(args):
    init_tensors = np.ones(shape=[args["max_vocabulary_size"], args["embed_vec_size"]], dtype=args["np_vector_type"])
    
    model = TrainModel(init_tensors, args["slot_num"], args["embed_vec_size"], args["dense_dim"])
    model.summary()
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
    
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logit = model(inputs)
            loss = loss_fn(labels, logit)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return logit, loss

    keys, dense_features, labels = generate_random_samples(args["global_batch_size"]  * args["iter_num"], args["vocabulary_range_per_slot"], args["dense_dim"])
    dataset = tf_dataset(keys, dense_features, labels, args["global_batch_size"])
    for i, (keys, dense_features, labels) in enumerate(dataset):
        inputs = [keys, dense_features]
        _, loss = _train_step(inputs, labels)
        print("-"*20, "Step {}, loss: {}".format(i, loss),  "-"*20)

    return model
trained_model = train(args)
weights_list = trained_model.get_weights()
embedding_weights = weights_list[-1]
dense_model = tf.keras.Model(trained_model.get_layer("concatenate").input,
                             trained_model.get_layer("fc_2").output)
dense_model.summary()
dense_model.save(args["dense_model_path"])
2022-11-23 01:36:13.919938: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-23 01:36:14.444040: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30991 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
WARNING:tensorflow:The following Variables were used in a Lambda layer's call (tf.compat.v1.nn.embedding_lookup), but are not present in its tracked objects:   <tf.Variable 'Variable:0' shape=(50000, 16) dtype=float32>. This is a strong indication that the Lambda layer should be rewritten as a subclassed Layer.
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 5)]          0           []                               
                                                                                                  
 tf.compat.v1.nn.embedding_look  (None, 5, 16)       0           ['input_1[0][0]']                
 up (TFOpLambda)                                                                                  
                                                                                                  
 tf.reshape (TFOpLambda)        (None, 80)           0           ['tf.compat.v1.nn.embedding_looku
                                                                 p[0][0]']                        
                                                                                                  
 input_2 (InputLayer)           [(None, 10)]         0           []                               
                                                                                                  
 concatenate (Concatenate)      (None, 90)           0           ['tf.reshape[0][0]',             
                                                                  'input_2[0][0]']                
                                                                                                  
 fc_1 (Dense)                   (None, 256)          23296       ['concatenate[0][0]']            
                                                                                                  
 fc_2 (Dense)                   (None, 1)            257         ['fc_1[0][0]']                   
                                                                                                  
==================================================================================================
Total params: 23,553
Trainable params: 23,553
Non-trainable params: 0
__________________________________________________________________________________________________
-------------------- Step 0, loss: 10934.333984375 --------------------
-------------------- Step 1, loss: 9218.0703125 --------------------
-------------------- Step 2, loss: 7060.255859375 --------------------
-------------------- Step 3, loss: 5094.876953125 --------------------
-------------------- Step 4, loss: 3605.475830078125 --------------------
-------------------- Step 5, loss: 2593.270751953125 --------------------
-------------------- Step 6, loss: 1741.0677490234375 --------------------
-------------------- Step 7, loss: 1045.5091552734375 --------------------
-------------------- Step 8, loss: 541.4227905273438 --------------------
-------------------- Step 9, loss: 242.8596649169922 --------------------
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_3 (InputLayer)           [(None, 80)]         0           []                               
                                                                                                  
 input_2 (InputLayer)           [(None, 10)]         0           []                               
                                                                                                  
 concatenate (Concatenate)      (None, 90)           0           ['input_3[0][0]',                
                                                                  'input_2[0][0]']                
                                                                                                  
 fc_1 (Dense)                   (None, 256)          23296       ['concatenate[1][0]']            
                                                                                                  
 fc_2 (Dense)                   (None, 1)            257         ['fc_1[1][0]']                   
                                                                                                  
==================================================================================================
Total params: 23,553
Trainable params: 23,553
Non-trainable params: 0
__________________________________________________________________________________________________
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel.
INFO:tensorflow:Assets written to: hps_tf_triton_dense.model/assets
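
The parameter counts in the summaries above can be reproduced by hand from the configuration. The sketch below is a plain arithmetic check; dense_params is a hypothetical helper that counts the kernel and bias of a fully connected layer.

```python
slot_num, embed_vec_size, dense_dim = 5, 16, 10  # values from the args dictionary


def dense_params(in_features, units):
    """Weights plus biases of a fully connected layer."""
    return in_features * units + units


# The reshaped embeddings (5 * 16 = 80) are concatenated with the 10 dense features.
concat_width = slot_num * embed_vec_size + dense_dim
fc_1 = dense_params(concat_width, 256)
fc_2 = dense_params(256, 1)
print(concat_width, fc_1, fc_2, fc_1 + fc_2)
```

The total matches the 23,553 parameters reported by Keras. The embedding table itself (50000 x 16 values) does not appear in that total, which is consistent with the warning above that the variable is not tracked by the Lambda layer.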

Create the inference graph with HPS LookupLayer

To use HPS in the inference stage, we need to create an inference model graph that is almost identical to the training graph, except that tf.nn.embedding_lookup is replaced by hps.LookupLayer. The trained dense model graph can be loaded directly, while the embedding weights must be converted to the format required by HPS.

We can then save the inference model graph, which will be ready to be loaded for inference deployment. Please note that the inference SavedModel that leverages HPS will be deployed with the Triton TensorFlow backend, thus implicit initialization of HPS should be enabled by specifying ps_config_file and global_batch_size in the constructor of hps.LookupLayer. For more information, please refer to HPS Initialize.

To this end, we need to create a JSON configuration file that specifies the details of the embedding tables for the models to be deployed. Here we deploy a single model with one embedding table, although HPS can also serve multiple models with multiple embedding tables.

%%writefile hps_tf_triton.json
{
    "supportlonglong": true,
    "models": [{
        "model": "hps_tf_triton",
        "sparse_files": ["/hugectr/hps_tf/notebooks/model_repo/hps_tf_triton/hps_tf_triton_sparse_0.model"],
        "num_of_worker_buffer_in_pool": 3,
        "embedding_table_names":["sparse_embedding1"],
        "embedding_vecsize_per_table": [16],
        "maxnum_catfeature_query_per_table_per_sample": [5],
        "default_value_for_each_table": [1.0],
        "deployed_device_list": [0],
        "max_batch_size": 1024,
        "cache_refresh_percentage_per_iteration": 0.2,
        "hit_rate_threshold": 1.0,
        "gpucacheper": 1.0,
        "gpucache": true
        }
    ]
}
Writing hps_tf_triton.json
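
Several values in this JSON file appear to mirror the training arguments (embedding vector size, number of slots, batch size). The sketch below inlines a trimmed copy of the file and checks those values against each other. The helper check_model_entry and the specific consistency rules are assumptions made for illustration; they are not part of the HPS API.

```python
import json

# Trimmed copy of hps_tf_triton.json, inlined so the sketch is self-contained.
config = json.loads("""
{
    "supportlonglong": true,
    "models": [{
        "model": "hps_tf_triton",
        "sparse_files": ["/hugectr/hps_tf/notebooks/model_repo/hps_tf_triton/hps_tf_triton_sparse_0.model"],
        "embedding_table_names": ["sparse_embedding1"],
        "embedding_vecsize_per_table": [16],
        "maxnum_catfeature_query_per_table_per_sample": [5],
        "default_value_for_each_table": [1.0],
        "deployed_device_list": [0],
        "max_batch_size": 1024
    }]
}
""")

PER_TABLE_KEYS = [
    "sparse_files",
    "embedding_table_names",
    "embedding_vecsize_per_table",
    "maxnum_catfeature_query_per_table_per_sample",
    "default_value_for_each_table",
]


def check_model_entry(model, slot_num, embed_vec_size, global_batch_size):
    """Return a list of human-readable inconsistencies (empty if none are found)."""
    problems = []
    lengths = {key: len(model[key]) for key in PER_TABLE_KEYS}
    if len(set(lengths.values())) != 1:
        problems.append("per-table lists differ in length: {}".format(lengths))
    if model["embedding_vecsize_per_table"][0] != embed_vec_size:
        problems.append("embedding vector size does not match embed_vec_size")
    if model["maxnum_catfeature_query_per_table_per_sample"][0] != slot_num:
        problems.append("keys per sample does not match slot_num")
    if model["max_batch_size"] < global_batch_size:
        problems.append("max_batch_size is smaller than global_batch_size")
    return problems


problems = check_model_entry(config["models"][0], slot_num=5, embed_vec_size=16, global_batch_size=1024)
print(problems or "no inconsistencies found")
```
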
triton_model_repo = "/hugectr/hps_tf/notebooks/model_repo/hps_tf_triton/"

class InferenceModel(tf.keras.models.Model):
    def __init__(self,
                 slot_num,
                 embed_vec_size,
                 dense_dim,
                 dense_model_path,
                 **kwargs):
        super(InferenceModel, self).__init__(**kwargs)
        
        self.slot_num = slot_num
        self.embed_vec_size = embed_vec_size
        self.dense_dim = dense_dim
        self.lookup_layer = hps.LookupLayer(model_name = "hps_tf_triton", 
                                            table_id = 0,
                                            emb_vec_size = self.embed_vec_size,
                                            emb_vec_dtype = args["tf_vector_type"],
                                            ps_config_file = triton_model_repo + args["ps_config_file"],
                                            global_batch_size = args["global_batch_size"],
                                            name = "lookup")
        self.dense_model = tf.keras.models.load_model(dense_model_path)

    def call(self, inputs):
        keys, dense_features = inputs[0], inputs[1]
        embedding_vector = self.lookup_layer(keys)
        embedding_vector = tf.reshape(embedding_vector, shape=[-1, self.slot_num * self.embed_vec_size])
        logit = self.dense_model([embedding_vector, dense_features])
        return logit

    def summary(self):
        inputs = [tf.keras.Input(shape=(self.slot_num, ), dtype=args["tf_key_type"]),
                  tf.keras.Input(shape=(self.dense_dim, ), dtype=tf.float32)]
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()
def create_and_save_inference_graph(args): 
    model = InferenceModel(args["slot_num"], args["embed_vec_size"], args["dense_dim"], args["dense_model_path"])
    model.summary()
    _ = model([tf.keras.Input(shape=(args["slot_num"], ), dtype=args["tf_key_type"]),
               tf.keras.Input(shape=(args["dense_dim"], ), dtype=tf.float32)])
    model.save(args["saved_path"])
def convert_to_sparse_model(embeddings_weights, embedding_table_path, embedding_vec_size):
    # Write the keys and the embedding vectors as two raw binary files in one directory
    os.makedirs(embedding_table_path, exist_ok=True)
    with open("{}/key".format(embedding_table_path), 'wb') as key_file, \
        open("{}/emb_vector".format(embedding_table_path), 'wb') as vec_file:
        for key in range(embeddings_weights.shape[0]):
            vec = embeddings_weights[key]
            # 'q' packs the key as a signed 64-bit integer, 'f' packs each element as float32
            key_struct = struct.pack('q', key)
            vec_struct = struct.pack(str(embedding_vec_size) + "f", *vec)
            key_file.write(key_struct)
            vec_file.write(vec_struct)
convert_to_sparse_model(embedding_weights, args["embedding_table_path"], args["embed_vec_size"])
create_and_save_inference_graph(args)
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_4 (InputLayer)           [(None, 5)]          0           []                               
                                                                                                  
 lookup (LookupLayer)           (None, 5, 16)        0           ['input_4[0][0]']                
                                                                                                  
 tf.reshape_1 (TFOpLambda)      (None, 80)           0           ['lookup[0][0]']                 
                                                                                                  
 input_5 (InputLayer)           [(None, 10)]         0           []                               
                                                                                                  
 model_1 (Functional)           (None, 1)            23553       ['tf.reshape_1[0][0]',           
                                                                  'input_5[0][0]']                
                                                                                                  
==================================================================================================
Total params: 23,553
Trainable params: 23,553
Non-trainable params: 0
__________________________________________________________________________________________________
INFO:tensorflow:Assets written to: hps_tf_triton_tf_saved_model/assets
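
The sparse model written by convert_to_sparse_model consists of two raw binary files, key and emb_vector. The sketch below writes a toy table in the same layout and reads it back, which can help when inspecting a converted table. The function names, the toy table, and the directory name toy_sparse_0.model are made up for illustration; only the 'q' and 'f' struct codes and the two file names come from the code above.

```python
import os
import struct
import tempfile


def write_sparse_model(table, path, vec_size):
    """Write keys (int64) and vectors (float32) in the same layout as convert_to_sparse_model."""
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "key"), "wb") as key_file, \
            open(os.path.join(path, "emb_vector"), "wb") as vec_file:
        for key, vec in enumerate(table):
            key_file.write(struct.pack("q", key))
            vec_file.write(struct.pack("{}f".format(vec_size), *vec))


def read_sparse_model(path, vec_size):
    """Read the two files back into a {key: vector} dictionary."""
    with open(os.path.join(path, "key"), "rb") as key_file:
        raw_keys = key_file.read()
    with open(os.path.join(path, "emb_vector"), "rb") as vec_file:
        raw_vecs = vec_file.read()
    num_keys = len(raw_keys) // struct.calcsize("q")
    keys = struct.unpack("{}q".format(num_keys), raw_keys)
    flat = struct.unpack("{}f".format(num_keys * vec_size), raw_vecs)
    return {key: list(flat[i * vec_size:(i + 1) * vec_size]) for i, key in enumerate(keys)}


# Toy table with values that are exactly representable in float32.
table = [[0.5, 1.0, 1.5, 2.0], [-1.0, 0.25, 0.0, 4.0], [3.0, 3.0, 3.0, 3.0]]

with tempfile.TemporaryDirectory() as tmp_dir:
    table_dir = os.path.join(tmp_dir, "toy_sparse_0.model")
    write_sparse_model(table, table_dir, vec_size=4)
    key_file_size = os.path.getsize(os.path.join(table_dir, "key"))
    restored = read_sparse_model(table_dir, vec_size=4)

print(key_file_size, restored[1])
```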

Deploy SavedModel using HPS with Triton TensorFlow Backend

In order to deploy the inference SavedModel with the Triton TensorFlow backend, we need to create the model repository and define the config.pbtxt first. Please note that some required portions (i.e., the input and output tensors) of the model configuration are generated automatically by Triton (see Auto-Generated Model Configuration), so you do NOT need to specify them explicitly in config.pbtxt.

!mkdir -p model_repo/hps_tf_triton/1
!mv hps_tf_triton_tf_saved_model model_repo/hps_tf_triton/1/model.savedmodel
!mv hps_tf_triton_sparse_0.model model_repo/hps_tf_triton
!mv hps_tf_triton.json model_repo/hps_tf_triton
%%writefile model_repo/hps_tf_triton/config.pbtxt
name: "hps_tf_triton"
platform: "tensorflow_savedmodel"
max_batch_size:1024
input [
  {
    name: "input_6"
    data_type: TYPE_INT64
    dims: [5]
  },
  {
    name: "input_7"
    data_type: TYPE_FP32
    dims: [10]
  }
]
output [
  {
    name: "output_1"
    data_type: TYPE_FP32
    dims: [1]
  }
]
version_policy: {
        specific:{versions: 1}
},
instance_group [
  {
    count: 1
    kind : KIND_GPU
    gpus: [0]
  }
]
Writing model_repo/hps_tf_triton/config.pbtxt
!tree model_repo/hps_tf_triton
model_repo/hps_tf_triton
├── 1
│   └── model.savedmodel
│       ├── assets
│       ├── keras_metadata.pb
│       ├── saved_model.pb
│       └── variables
│           ├── variables.data-00000-of-00001
│           └── variables.index
├── config.pbtxt
├── hps_tf_triton.json
└── hps_tf_triton_sparse_0.model
    ├── emb_vector
    └── key

5 directories, 8 files
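
Before starting the server, it can be convenient to verify the repository layout programmatically rather than by reading the tree output. The sketch below checks for a handful of entries taken from the listing above; the helper missing_entries and the choice of which entries to check are illustrative assumptions.

```python
from pathlib import Path

# Entries taken from the tree listing above, relative to the model directory.
EXPECTED_ENTRIES = [
    "config.pbtxt",
    "hps_tf_triton.json",
    "hps_tf_triton_sparse_0.model/key",
    "hps_tf_triton_sparse_0.model/emb_vector",
    "1/model.savedmodel/saved_model.pb",
]


def missing_entries(model_dir):
    """Return the expected entries that do not exist under model_dir."""
    root = Path(model_dir)
    return [entry for entry in EXPECTED_ENTRIES if not (root / entry).exists()]


# An empty list means every expected entry was found.
print(missing_entries("model_repo/hps_tf_triton"))
```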

We can then launch the Triton inference server using the TensorFlow backend. Please note that LD_PRELOAD is utilized to load the custom TensorFlow operations (i.e., HPS related operations) into Triton. For more information, please refer to TensorFlow Custom Operations in Triton.

Note: Since Jupyter does not support background processes, please launch the Triton server separately, outside of this notebook, using the following command.

LD_PRELOAD=/usr/local/lib/python3.8/dist-packages/merlin_hps-1.0.0-py3.8-linux-x86_64.egg/hierarchical_parameter_server/lib/libhierarchical_parameter_server.so tritonserver --model-repository=/hugectr/hps_tf/notebooks/model_repo --backend-config=tensorflow,version=2 --load-model=hps_tf_triton --model-control-mode=explicit
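
If the server is started from a Python script rather than a shell, the same command can be assembled as an environment plus an argument list. The sketch below only builds and prints the command. The helper build_triton_launch is a hypothetical name, and the library path is copied from the command above and may differ in other container versions.

```python
import os
import shlex

# Path of the HPS library, copied from the launch command above (may differ per container).
HPS_LIB = ("/usr/local/lib/python3.8/dist-packages/merlin_hps-1.0.0-py3.8-linux-x86_64.egg/"
           "hierarchical_parameter_server/lib/libhierarchical_parameter_server.so")


def build_triton_launch(model_repo, model_name, hps_lib=HPS_LIB):
    """Return (env, argv) for starting tritonserver with the HPS ops preloaded."""
    env = dict(os.environ, LD_PRELOAD=hps_lib)
    argv = [
        "tritonserver",
        "--model-repository={}".format(model_repo),
        "--backend-config=tensorflow,version=2",
        "--load-model={}".format(model_name),
        "--model-control-mode=explicit",
    ]
    return env, argv


env, argv = build_triton_launch("/hugectr/hps_tf/notebooks/model_repo", "hps_tf_triton")
print("LD_PRELOAD={} {}".format(env["LD_PRELOAD"], shlex.join(argv)))
# To start the server from a standalone script (not from a notebook cell):
#   import subprocess; server = subprocess.Popen(argv, env=env)
```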

We can then send the requests to the Triton inference server using the HTTP client. Please note that HPS will be initialized implicitly when the first request is processed at the server side, and the latency can be higher than that of later requests.

!curl localhost:8000/v2/models/hps_tf_triton/config
{"name":"hps_tf_triton","platform":"tensorflow_savedmodel","backend":"tensorflow","version_policy":{"specific":{"versions":[1]}},"max_batch_size":1024,"input":[{"name":"input_6","data_type":"TYPE_INT64","format":"FORMAT_NONE","dims":[5],"is_shape_tensor":false,"allow_ragged_batch":false,"optional":false},{"name":"input_7","data_type":"TYPE_FP32","format":"FORMAT_NONE","dims":[10],"is_shape_tensor":false,"allow_ragged_batch":false,"optional":false}],"output":[{"name":"output_1","data_type":"TYPE_FP32","dims":[1],"label_filename":"","is_shape_tensor":false}],"batch_input":[],"batch_output":[],"optimization":{"priority":"PRIORITY_DEFAULT","input_pinned_memory":{"enable":true},"output_pinned_memory":{"enable":true},"gather_kernel_buffer_threshold":0,"eager_batching":false},"dynamic_batching":{"preferred_batch_size":[1024],"max_queue_delay_microseconds":0,"preserve_ordering":false,"priority_levels":0,"default_priority_level":0,"priority_queue_policy":{}},"instance_group":[{"name":"hps_tf_triton_0","kind":"KIND_GPU","count":1,"gpus":[0],"secondary_devices":[],"profile":[],"passive":false,"host_policy":""}],"default_model_filename":"model.savedmodel","cc_model_filenames":{},"metric_tags":{},"parameters":{},"model_warmup":[]}
import os
num_gpu = 1
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(num_gpu)))

import numpy as np
import tritonclient.http as httpclient
from tritonclient.utils import np_to_triton_dtype


def send_inference_requests(num_requests, num_samples):
    triton_client = httpclient.InferenceServerClient(url="localhost:8000", verbose=True)
    triton_client.is_server_live()
    triton_client.get_model_repository_index()

    for i in range(num_requests):
        print("--------------------------Request {}--------------------------".format(i))
        key_tensor, dense_tensor, _ = generate_random_samples(num_samples, args["vocabulary_range_per_slot"], args["dense_dim"])

        inputs = [
            httpclient.InferInput("input_6", 
                                  key_tensor.shape,
                                  np_to_triton_dtype(np.int64)),
            httpclient.InferInput("input_7", 
                                  dense_tensor.shape,
                                  np_to_triton_dtype(np.float32)),
        ]

        inputs[0].set_data_from_numpy(key_tensor)
        inputs[1].set_data_from_numpy(dense_tensor)
        outputs = [
            httpclient.InferRequestedOutput("output_1")
        ]

        # print("Input key tensor is \n{}".format(key_tensor))
        # print("Input dense tensor is \n{}".format(dense_tensor))
        model_name = "hps_tf_triton"
        with httpclient.InferenceServerClient("localhost:8000") as client:
            response = client.infer(model_name,
                                    inputs,
                                    outputs=outputs)
            result = response.get_response()

            print("Response details:\n{}".format(result))
send_inference_requests(num_requests = 5, num_samples = 128)
GET /v2/health/live, headers None
<HTTPSocketPoolResponse status=200 headers={'content-length': '0', 'content-type': 'text/plain'}>
POST /v2/repository/index, headers None

<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '56'}>
bytearray(b'[{"name":"hps_tf_triton","version":"1","state":"READY"}]')
--------------------------Request 0--------------------------
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '1', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [128, 1], 'parameters': {'binary_data_size': 512}}]}
--------------------------Request 1--------------------------
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '1', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [128, 1], 'parameters': {'binary_data_size': 512}}]}
--------------------------Request 2--------------------------
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '1', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [128, 1], 'parameters': {'binary_data_size': 512}}]}
--------------------------Request 3--------------------------
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '1', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [128, 1], 'parameters': {'binary_data_size': 512}}]}
--------------------------Request 4--------------------------
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '1', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [128, 1], 'parameters': {'binary_data_size': 512}}]}
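
The binary_data_size of 512 reported in each response is consistent with an FP32 tensor of shape [128, 1]. The sketch below reproduces that figure, and the corresponding input sizes, from the shapes and element widths; tensor_nbytes is a hypothetical helper.

```python
import struct


def tensor_nbytes(shape, struct_code):
    """Number of bytes of a dense tensor with the given shape and element type."""
    count = 1
    for dim in shape:
        count *= dim
    # "<" selects standard sizes: 4 bytes for "f" (float32), 8 bytes for "q" (int64).
    return count * struct.calcsize("<" + struct_code)


num_samples, slot_num, dense_dim = 128, 5, 10
output_bytes = tensor_nbytes([num_samples, 1], "f")
key_bytes = tensor_nbytes([num_samples, slot_num], "q")
dense_bytes = tensor_nbytes([num_samples, dense_dim], "f")
print(output_bytes, key_bytes, dense_bytes)
```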

Deploy TF-TRT SavedModel using HPS with Triton TensorFlow Backend

We can leverage TF-TRT to optimize the inference SavedModel above. The hps.LookupLayer falls back to TF ops, while a TensorRT engine is built to execute the dense network. The optimized TF-TRT SavedModel can still be deployed with the Triton TensorFlow backend.

The TF-TRT SavedModel is placed in the folder "model_repo/hps_tf_triton/2/", and the config.pbtxt file is updated accordingly to load version 2 of the inference model, i.e., the TF-TRT optimized one.
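
The notebook does not show the configuration update itself. One way to do it is a small text substitution on the version_policy entry, sketched below on an inlined snippet of config.pbtxt. The helper set_served_version is hypothetical, and the regular expression assumes the file contains exactly one "versions:" field, as the configuration written earlier does.

```python
import re


def set_served_version(config_text, version):
    """Return config_text with the single `versions:` field set to `version`."""
    new_text, count = re.subn(r"(versions\s*:\s*)\d+", r"\g<1>{}".format(version), config_text)
    if count != 1:
        raise ValueError("expected exactly one 'versions:' field, found {}".format(count))
    return new_text


# Snippet of the version policy from config.pbtxt, inlined for illustration.
snippet = "version_policy: {\n        specific:{versions: 1}\n},"
updated = set_served_version(snippet, 2)
print(updated)
```

In practice the function would be applied to the contents of model_repo/hps_tf_triton/config.pbtxt and the result written back before the model is reloaded.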

# Build TF-TRT SavedModel
from tensorflow.python.compiler.tensorrt import trt_convert as trt

ORIGINAL_MODEL_PATH = "model_repo/hps_tf_triton/1/model.savedmodel"
NEW_MODEL_PATH = "model_repo/hps_tf_triton/2/model.savedmodel"

# Instantiate the TF-TRT converter
converter = trt.TrtGraphConverterV2(
   input_saved_model_dir=ORIGINAL_MODEL_PATH,
   precision_mode=trt.TrtPrecisionMode.FP32
)

# Convert the model into TRT compatible segments
trt_func = converter.convert()
converter.summary()

keys, dense_features, _ = generate_random_samples(args["global_batch_size"], args["vocabulary_range_per_slot"], args["dense_dim"])
keys  = tf.convert_to_tensor(keys)
dense_features = tf.convert_to_tensor(dense_features)
def input_fn():
   yield [keys, dense_features]

converter.build(input_fn=input_fn)
converter.save(output_saved_model_dir=NEW_MODEL_PATH)
INFO:tensorflow:Linked TensorRT version: (8, 4, 2)
INFO:tensorflow:Loaded TensorRT version: (8, 4, 2)
INFO:tensorflow:Clearing prior device assignments in loaded saved model
2022-11-23 01:37:22.924379: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 1
2022-11-23 01:37:22.924537: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2022-11-23 01:37:22.928272: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30991 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
INFO:tensorflow:Clearing prior device assignments in loaded saved model
INFO:tensorflow:Automatic mixed precision has been deactivated.
2022-11-23 01:37:23.028482: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 1
2022-11-23 01:37:23.028568: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2022-11-23 01:37:23.031909: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30991 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
2022-11-23 01:37:23.048593: W tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc:198] Calibration with FP32 or FP16 is not implemented. Falling back to use_calibration = False.Note that the default value of use_calibration is True.
2022-11-23 01:37:23.049761: W tensorflow/compiler/tf2tensorrt/segment/segment.cc:952] 

################################################################################
TensorRT unsupported/non-converted OP Report:
	- NoOp -> 2x
	- Placeholder -> 2x
	- Identity -> 1x
	- Init -> 1x
	- Lookup -> 1x
	- Reshape -> 1x
--------------------------------------------------------------------------------
	- Total nonconverted OPs: 8
	- Total nonconverted OP Types: 6
For more information see https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html#supported-ops.
################################################################################

2022-11-23 01:37:23.049860: W tensorflow/compiler/tf2tensorrt/segment/segment.cc:1280] The environment variable TF_TRT_MAX_ALLOWED_ENGINES=20 has no effect since there are only 1 TRT Engines with  at least minimum_segment_size=3 nodes.
2022-11-23 01:37:23.049893: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:799] Number of TensorRT candidate segments: 1
2022-11-23 01:37:23.050667: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:916] Replaced segment 0 consisting of 9 nodes by TRTEngineOp_000_000.
TRTEngineOP Name                 Device        # Nodes # Inputs      # Outputs     Input DTypes       Output Dtypes      Input Shapes       Output Shapes     
================================================================================================================================================================
TRTEngineOp_000_000              device:GPU:0  10      2             1             ['float32', 'f ... ['float32']        [[-1, 80], [-1 ... [[-1, 1]]         

	- BiasAdd: 2x
	- ConcatV2: 1x
	- Const: 5x
	- MatMul: 2x

================================================================================================================================================================
[*] Total number of TensorRT engines: 1
[*] % of OPs Converted: 50.00% [10/20]

=====================================================HPS Parse====================================================
[HCTR][01:37:23.329][INFO][RK0][main]: dense_file is not specified using default: 
[HCTR][01:37:23.329][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
[HCTR][01:37:23.329][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][01:37:23.329][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][01:37:23.329][INFO][RK0][main]: refresh_interval is not specified using default: 0
====================================================HPS Create====================================================
[HCTR][01:37:23.329][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][01:37:23.329][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][01:37:23.329][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][01:37:23.329][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][01:37:23.329][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][01:37:23.745][INFO][RK0][main]: Table: hps_et.hps_tf_triton.sparse_embedding1; cached 50000 / 50000 embeddings in volatile database (HashMapBackend); load: 50000 / 18446744073709551615 (0.00%).
[HCTR][01:37:23.745][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][01:37:23.745][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][01:37:23.753][INFO][RK0][main]: Model name: hps_tf_triton
[HCTR][01:37:23.753][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][01:37:23.753][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][01:37:23.753][INFO][RK0][main]: Use I64 input key: True
[HCTR][01:37:23.753][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][01:37:23.753][INFO][RK0][main]: The size of thread pool: 80
[HCTR][01:37:23.753][INFO][RK0][main]: The size of worker memory pool: 3
[HCTR][01:37:23.753][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][01:37:23.753][INFO][RK0][main]: The refresh percentage : 0.200000
[HCTR][01:37:23.778][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][01:37:23.814][INFO][RK0][main]: EC initialization for model: "hps_tf_triton", num_tables: 1
[HCTR][01:37:23.814][INFO][RK0][main]: EC initialization on device: 0
[HCTR][01:37:23.815][INFO][RK0][main]: Creating lookup session for hps_tf_triton on device: 0
2022-11-23 01:37:23.818078: I tensorflow/compiler/tf2tensorrt/common/utils.cc:104] Linked TensorRT version: 8.4.2
2022-11-23 01:37:23.818150: I tensorflow/compiler/tf2tensorrt/common/utils.cc:106] Loaded TensorRT version: 8.4.2
2022-11-23 01:37:28.749149: I tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:1275] [TF-TRT] Sparse compute capability is enabled.
2022-11-23 01:37:28.814132: E tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc:86] DefaultLogger 1: [wrapper.cpp::CublasWrapper::85] Error Code 1: Cublas (Could not initialize cublas. Please check CUDA installation.)
2022-11-23 01:37:28.817575: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:1061] TF-TRT Warning: Engine creation for TRTEngineOp_000_000 failed. The native segment will be used instead. Reason: INTERNAL: Failed to build TensorRT engine
2022-11-23 01:37:28.817694: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:894] TF-TRT Warning: Engine retrieval for input shapes: [[1024,80], [1024,10]] failed. Running native segment for TRTEngineOp_000_000
2022-11-23 01:37:28.823806: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:894] TF-TRT Warning: Engine retrieval for input shapes: [[1024,80], [1024,10]] failed. Running native segment for TRTEngineOp_000_000
INFO:tensorflow:Assets written to: model_repo/hps_tf_triton/2/model.savedmodel/assets
%%writefile model_repo/hps_tf_triton/config.pbtxt
name: "hps_tf_triton"
platform: "tensorflow_savedmodel"
max_batch_size:1024
input [
  {
    name: "input_6"
    data_type: TYPE_INT64
    dims: [5]
  },
  {
    name: "input_7"
    data_type: TYPE_FP32
    dims: [10]
  }
]
output [
  {
    name: "output_1"
    data_type: TYPE_FP32
    dims: [1]
  }
]
version_policy: {
        specific:{versions: 2}
},
instance_group [
  {
    count: 1
    kind : KIND_GPU
    gpus: [0]
  }
]
Overwriting model_repo/hps_tf_triton/config.pbtxt
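Because max_batch_size is greater than zero in the configuration above, Triton treats the first dimension of every input as an implicit batch dimension, so dims: [5] describes tensors of shape [batch, 5]. The following is a minimal sketch of a client-side sanity check under that assumption. The helper name check_request_shape and the two constants are hypothetical and only mirror the values written to config.pbtxt; they are not part of HPS or Triton.

```python
# Values copied from config.pbtxt above (hypothetical constants for illustration).
MAX_BATCH_SIZE = 1024
INPUT_DIMS = {"input_6": [5], "input_7": [10]}


def check_request_shape(name, shape):
    """Return True if `shape` is [batch, *dims] with 1 <= batch <= MAX_BATCH_SIZE."""
    batch, *rest = shape
    return 1 <= batch <= MAX_BATCH_SIZE and rest == INPUT_DIMS[name]


# A request with 128 samples fits within the configured limits.
print(check_request_shape("input_6", [128, 5]))
print(check_request_shape("input_7", [128, 10]))
```

A check like this can catch shape mismatches before a request is sent, rather than relying on the error returned by the server.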
!tree model_repo/hps_tf_triton
model_repo/hps_tf_triton
├── 1
│   └── model.savedmodel
│       ├── assets
│       ├── keras_metadata.pb
│       ├── saved_model.pb
│       └── variables
│           ├── variables.data-00000-of-00001
│           └── variables.index
├── 2
│   └── model.savedmodel
│       ├── assets
│       │   └── trt-serialized-engine.TRTEngineOp_000_000
│       ├── saved_model.pb
│       └── variables
│           ├── variables.data-00000-of-00001
│           └── variables.index
├── config.pbtxt
├── hps_tf_triton.json
└── hps_tf_triton_sparse_0.model
    ├── emb_vector
    └── key

9 directories, 12 files
# Release the GPU memory occupied by TensorFlow and Keras
from numba import cuda
cuda.select_device(0)
cuda.close()

We can then launch the Triton inference server with the TensorFlow backend in the background, using the same command as before. Remember to kill the previous tritonserver process completely before launching it again; otherwise, out-of-memory errors may occur.

Once the Triton server has launched successfully, we can send requests to it using the HTTP client again.
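Since the server is started in the background, it can help to wait until it reports readiness before sending requests. The sketch below polls the readiness endpoint of the HTTP inference protocol that Triton implements (/v2/health/ready). The helper name wait_until_ready, the default URL (port 8000, as used by the curl command in this notebook), and the timeout values are assumptions made for illustration.

```python
import time
import urllib.error
import urllib.request


def wait_until_ready(url="http://localhost:8000/v2/health/ready",
                     timeout_s=60.0, interval_s=1.0):
    """Poll `url` until it answers HTTP 200; return False once the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(url, timeout=interval_s) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # The server is not reachable or not ready yet.
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)
```

Calling wait_until_ready() before send_inference_requests avoids sending requests to a server that is still loading the model.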

!curl localhost:8000/v2/models/hps_tf_triton/config
{"name":"hps_tf_triton","platform":"tensorflow_savedmodel","backend":"tensorflow","version_policy":{"specific":{"versions":[2]}},"max_batch_size":1024,"input":[{"name":"input_6","data_type":"TYPE_INT64","format":"FORMAT_NONE","dims":[5],"is_shape_tensor":false,"allow_ragged_batch":false,"optional":false},{"name":"input_7","data_type":"TYPE_FP32","format":"FORMAT_NONE","dims":[10],"is_shape_tensor":false,"allow_ragged_batch":false,"optional":false}],"output":[{"name":"output_1","data_type":"TYPE_FP32","dims":[1],"label_filename":"","is_shape_tensor":false}],"batch_input":[],"batch_output":[],"optimization":{"priority":"PRIORITY_DEFAULT","input_pinned_memory":{"enable":true},"output_pinned_memory":{"enable":true},"gather_kernel_buffer_threshold":0,"eager_batching":false},"dynamic_batching":{"preferred_batch_size":[1024],"max_queue_delay_microseconds":0,"preserve_ordering":false,"priority_levels":0,"default_priority_level":0,"priority_queue_policy":{}},"instance_group":[{"name":"hps_tf_triton_0","kind":"KIND_GPU","count":1,"gpus":[0],"secondary_devices":[],"profile":[],"passive":false,"host_policy":""}],"default_model_filename":"model.savedmodel","cc_model_filenames":{},"metric_tags":{},"parameters":{},"model_warmup":[]}
send_inference_requests(num_requests = 5, num_samples = 128)
GET /v2/health/live, headers None
<HTTPSocketPoolResponse status=200 headers={'content-length': '0', 'content-type': 'text/plain'}>
POST /v2/repository/index, headers None

<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '56'}>
bytearray(b'[{"name":"hps_tf_triton","version":"2","state":"READY"}]')
--------------------------Request 0--------------------------
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '2', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [128, 1], 'parameters': {'binary_data_size': 512}}]}
--------------------------Request 1--------------------------
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '2', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [128, 1], 'parameters': {'binary_data_size': 512}}]}
--------------------------Request 2--------------------------
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '2', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [128, 1], 'parameters': {'binary_data_size': 512}}]}
--------------------------Request 3--------------------------
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '2', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [128, 1], 'parameters': {'binary_data_size': 512}}]}
--------------------------Request 4--------------------------
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '2', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [128, 1], 'parameters': {'binary_data_size': 512}}]}
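The binary_data_size reported in each response can be cross-checked against the output shape and datatype: 128 samples, one FP32 value each, at 4 bytes per value. The sketch below reproduces this arithmetic with the standard library only. The helper expected_binary_size and the small datatype table are hypothetical names introduced for this check.

```python
import math
import struct

# Map of Triton datatype names to struct format codes (only the types used in this notebook).
STRUCT_FORMAT = {"FP32": "<f", "INT64": "<q"}


def expected_binary_size(shape, datatype):
    """Number of bytes needed for a dense tensor of the given shape and datatype."""
    return math.prod(shape) * struct.calcsize(STRUCT_FORMAT[datatype])


print(expected_binary_size([128, 1], "FP32"))  # 512, matching binary_data_size above
```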