http://developer.download.nvidia.com/notebooks/dlsw-notebooks/merlin_hugectr_hps-hps-tensorflow-triton-deployment/nvidia_logo.png

HPS TensorRT Plugin Demo for TensorFlow Trained Model

Overview

This notebook demonstrates how to build and deploy the HPS-integrated TensorRT engine for the model trained with TensorFlow.

For more details about HPS, please refer to HugeCTR Hierarchical Parameter Server (HPS).

Installation

Use NGC

The HPS TensorRT plugin is preinstalled in the Merlin TensorFlow container, release 23.05 and later: nvcr.io/nvidia/merlin/merlin-tensorflow:23.05.

You can check the existence of the required libraries by running the following Python code after launching this container.

import ctypes
# Path of the HPS TensorRT plugin shared library shipped in the Merlin container.
plugin_lib_name = "/usr/local/hps_trt/lib/libhps_plugin.so"
# ctypes.CDLL raises OSError if the library is missing or cannot be loaded, so a
# silent run confirms the plugin is present. RTLD_GLOBAL makes its symbols
# visible to shared libraries loaded afterwards.
plugin_handle = ctypes.CDLL(plugin_lib_name, mode=ctypes.RTLD_GLOBAL)

Configurations

First, we specify the required configurations, e.g., the arguments needed for generating the dataset, the model parameters, and the paths for saving the model. We will use a DLRM model, which has one embedding table, bottom MLP layers, an interaction layer, and top MLP layers. Please note that the input to the embedding layer is a dense int32 key tensor.

import os
import numpy as np
import tensorflow as tf
import struct

args = dict()

args["gpu_num"] = 1                               # the number of available GPUs
args["iter_num"] = 50                             # the number of training iterations
args["slot_num"] = 26                             # the number of feature fields in this embedding layer
args["embed_vec_size"] = 128                      # the dimension of embedding vectors
args["dense_dim"] = 13                            # the dimension of dense features
args["global_batch_size"] = 1024                  # the global batch size across all GPUs
args["vocab_size_per_slot"] = 10000               # the number of distinct keys in each slot
# The embedding table holds one contiguous key range per slot:
# slot i owns keys [i * vocab_size_per_slot, (i + 1) * vocab_size_per_slot).
# Deriving both values from slot_num keeps them consistent when slot_num changes
# (26 slots x 10000 keys = 260000 rows).
args["max_vocabulary_size"] = args["slot_num"] * args["vocab_size_per_slot"]
args["vocabulary_range_per_slot"] = [
    [i * args["vocab_size_per_slot"], (i + 1) * args["vocab_size_per_slot"]]
    for i in range(args["slot_num"])
]
args["combiner"] = "mean"
# NOTE: dlrm_tf.json, the ONNX graph surgery, the TensorRT optimization profile
# and config.pbtxt later in this notebook hard-code slot_num (26),
# embed_vec_size (128), dense_dim (13) and the batch size (1024); keep them in
# sync when changing these values.

args["ps_config_file"] = "dlrm_tf.json"
args["embedding_table_path"] = "dlrm_tf_sparse.model"
args["saved_path"] = "dlrm_tf_saved_model"
args["np_key_type"] = np.int32
args["np_vector_type"] = np.float32
args["tf_key_type"] = tf.int32
args["tf_vector_type"] = tf.float32

# Expose only the first gpu_num GPUs to this process (e.g. "0" for one GPU).
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(args["gpu_num"])))
2023-01-03 07:47:02.443077: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
def generate_random_samples(num_samples, vocabulary_range_per_slot, dense_dim, key_dtype = args["np_key_type"]):
    """Generate a synthetic dataset of categorical keys, dense features and labels.

    Args:
        num_samples: number of samples (rows) to generate.
        vocabulary_range_per_slot: one ``[low, high)`` pair per slot; the keys
            of slot ``i`` are drawn uniformly from that half-open range.
        dense_dim: number of dense (numerical) features per sample.
        key_dtype: NumPy integer dtype of the generated keys.

    Returns:
        A tuple ``(keys, numerical_features, labels)`` where ``keys`` has shape
        ``(num_samples, num_slots)`` and dtype ``key_dtype``,
        ``numerical_features`` is a float32 array of shape
        ``(num_samples, dense_dim)`` with values in ``[0, 1)``, and ``labels``
        is an integer array of shape ``(num_samples, 1)`` holding 0 or 1.
    """
    # One (num_samples, 1) column of keys per slot, sampled in slot order.
    key_columns = [
        np.random.randint(low=vocab_range[0], high=vocab_range[1], size=(num_samples, 1), dtype=key_dtype)
        for vocab_range in vocabulary_range_per_slot
    ]
    # Concatenate the columns directly. Wrapping them in np.array() first would
    # only build a redundant (num_slots, num_samples, 1) copy.
    keys = np.concatenate(key_columns, axis=1)
    numerical_features = np.random.random((num_samples, dense_dim)).astype(np.float32)
    labels = np.random.randint(low=0, high=2, size=(num_samples, 1))
    return keys, numerical_features, labels

def tf_dataset(keys, numerical_features, labels, batchsize):
    """Wrap the in-memory arrays in a batched ``tf.data.Dataset``.

    Each element is a ``(keys, numerical_features, labels)`` tuple of
    ``batchsize`` rows. A trailing partial batch is discarded, so every batch
    has exactly ``batchsize`` rows.
    """
    samples = (keys, numerical_features, labels)
    return tf.data.Dataset.from_tensor_slices(samples).batch(
        batchsize, drop_remainder=True
    )

Train with native TF layers

We define the model graph for training with native TF layers, i.e., tf.nn.embedding_lookup, tf.keras.layers.Dense and so on. We can then train the model and extract the trained weights of the embedding table.

class MLP(tf.keras.layers.Layer):
    """A stack of ``tf.keras.layers.Dense`` layers applied in sequence.

    Args:
        arch: layer widths, one Dense layer per entry (must be non-empty).
        activation: activation of every layer except the last.
        out_activation: activation of the last layer.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """
    def __init__(self,
                 arch,
                 activation='relu',
                 out_activation=None,
                 **kwargs):
        super(MLP, self).__init__(**kwargs)
        # Sub-layers are named "<layer name>_<index>". self.name equals
        # kwargs['name'] when one is given, and falls back to the auto-generated
        # Keras name otherwise; kwargs['name'] raised KeyError in that case.
        self.layers = []
        for index, units in enumerate(arch[:-1]):
            self.layers.append(tf.keras.layers.Dense(units, activation=activation, name="{}_{}".format(self.name, index)))
        self.layers.append(tf.keras.layers.Dense(arch[-1], activation=out_activation, name="{}_{}".format(self.name, len(arch) - 1)))

    def call(self, inputs, training=True):
        # `training` is accepted for Keras compatibility but is not used.
        x = inputs
        for layer in self.layers:
            x = layer(x)
        return x

class SecondOrderFeatureInteraction(tf.keras.layers.Layer):
    """Pairwise dot-product interaction between feature vectors (DLRM style).

    For ``num_feas`` feature vectors per sample, returns the
    ``num_feas * (num_feas - 1) / 2`` dot products between distinct pairs.
    """
    def __init__(self):
        super().__init__()

    def call(self, inputs, num_feas):
        # Gram matrix of the per-sample feature vectors, flattened row-major to
        # (batch, num_feas * num_feas).
        gram = tf.matmul(inputs, inputs, transpose_b=True)
        flat_gram = tf.reshape(gram, (-1, num_feas * num_feas))
        # Keep the strictly upper-triangular entries (row < col), ordered by
        # column first and then by row.
        upper_positions = [
            row * num_feas + col
            for col in range(1, num_feas)
            for row in range(col)
        ]
        return tf.gather(flat_gram, tf.constant(upper_positions), axis=1)

class DLRM(tf.keras.models.Model):
    """DLRM: embedding lookup and bottom MLP, pairwise interaction, top MLP.

    Args:
        init_tensors: initial value of the embedding table. It is concatenated
            along axis 0 into the single trainable variable ``self.params``.
        embed_vec_size: embedding vector width. It must equal ``arch_bot[-1]``,
            because the bottom-MLP output is concatenated with the embeddings.
        slot_num: number of categorical keys per sample.
        dense_dim: number of dense (numerical) features per sample.
        arch_bot: layer widths of the bottom MLP.
        arch_top: layer widths of the top MLP; the last width is the output size.
        **kwargs: forwarded to ``tf.keras.models.Model`` (e.g. ``name``).
    """
    def __init__(self,
                 init_tensors,
                 embed_vec_size,
                 slot_num,
                 dense_dim,
                 arch_bot,
                 arch_top,
                 **kwargs):
        super(DLRM, self).__init__(**kwargs)

        self.init_tensors = init_tensors
        # The whole embedding table is one trainable variable.
        self.params = tf.Variable(initial_value=tf.concat(self.init_tensors, axis=0))

        self.embed_vec_size = embed_vec_size
        self.slot_num = slot_num
        self.dense_dim = dense_dim

        self.bot_nn = MLP(arch_bot, name = "bottom", out_activation='relu')
        self.top_nn = MLP(arch_top, name = "top", out_activation='sigmoid')
        self.interaction_op = SecondOrderFeatureInteraction()

        self.interaction_out_dim = self.slot_num * (self.slot_num+1) // 2
        self.reshape_layer1 = tf.keras.layers.Reshape((1, arch_bot[-1]), name = "reshape1")
        self.concat1 = tf.keras.layers.Concatenate(axis=1, name = "concat1")
        self.concat2 = tf.keras.layers.Concatenate(axis=1, name = "concat2")

    def call(self, inputs, training=True):
        # inputs: dict with "keys" (batch, slot_num) and "numerical_features"
        # (batch, dense_dim), as declared in summary() below.
        categorical_features = inputs["keys"]
        numerical_features = inputs["numerical_features"]

        embedding_vector = tf.nn.embedding_lookup(params=self.params, ids=categorical_features)
        dense_x = self.bot_nn(numerical_features)
        # Treat the bottom-MLP output as one extra "feature" alongside the
        # slot_num embeddings, so the interaction sees slot_num + 1 vectors.
        concat_features = self.concat1([embedding_vector, self.reshape_layer1(dense_x)])

        Z = self.interaction_op(concat_features, self.slot_num+1)
        z = self.concat2([dense_x, Z])
        logit = self.top_nn(z)
        return logit

    def summary(self):
        # Trace call() on symbolic Inputs through a throwaway functional Model,
        # so the printed summary shows per-layer output shapes. The key dtype is
        # read from the notebook-level `args` dict.
        inputs = {"keys": tf.keras.Input(shape=(self.slot_num, ), dtype=args["tf_key_type"], name="keys"),
                  "numerical_features": tf.keras.Input(shape=(self.dense_dim, ), dtype=tf.float32, name="numerical_features")}
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()
 def train(args):
    init_tensors = np.ones(shape=[args["max_vocabulary_size"], args["embed_vec_size"]], dtype=args["np_vector_type"])
    
    model = DLRM(init_tensors, args["embed_vec_size"], args["slot_num"], args["dense_dim"],
                arch_bot = [512, 256, args["embed_vec_size"]],
                arch_top = [1024, 1024, 512, 256, 1],
                name = "dlrm")
    model.summary()
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
    loss_fn = tf.keras.losses.BinaryCrossentropy()
    
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logit = model(inputs)
            loss = loss_fn(labels, logit)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss, logit

    keys, numerical_features, labels = generate_random_samples(args["global_batch_size"]  * args["iter_num"], args["vocabulary_range_per_slot"], args["dense_dim"], args["np_key_type"])
    dataset = tf_dataset(keys, numerical_features, labels, args["global_batch_size"])
    for i, (keys, numerical_features, labels) in enumerate(dataset):
        inputs = {"keys": keys, "numerical_features": numerical_features}
        loss, logit = _train_step(inputs, labels)
        print("-"*20, "Step {}, loss: {}".format(i, loss),  "-"*20)

    return model
trained_model = train(args)
weights_list = trained_model.get_weights()
# Read the embedding table from the model's `params` variable by name, rather
# than assuming it is the last entry of get_weights(). That ordering is a Keras
# implementation detail.
embedding_weights = trained_model.params.numpy()
trained_model.save(args["saved_path"])
WARNING:tensorflow:The following Variables were used in a Lambda layer's call (tf.compat.v1.nn.embedding_lookup), but are not present in its tracked objects:   <tf.Variable 'Variable:0' shape=(260000, 128) dtype=float32>. This is a strong indication that the Lambda layer should be rewritten as a subclassed Layer.
2023-01-03 07:47:04.662936: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-03 07:47:05.163322: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1637] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31004 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 numrical_features (InputLayer)  [(None, 13)]        0           []                               
                                                                                                  
 bottom (MLP)                   (None, 128)          171392      ['numrical_features[0][0]']      
                                                                                                  
 keys (InputLayer)              [(None, 26)]         0           []                               
                                                                                                  
 tf.compat.v1.nn.embedding_look  (None, 26, 128)     0           ['keys[0][0]']                   
 up (TFOpLambda)                                                                                  
                                                                                                  
 reshape1 (Reshape)             (None, 1, 128)       0           ['bottom[0][0]']                 
                                                                                                  
 concat1 (Concatenate)          (None, 27, 128)      0           ['tf.compat.v1.nn.embedding_looku
                                                                 p[0][0]',                        
                                                                  'reshape1[0][0]']               
                                                                                                  
 second_order_feature_interacti  (None, 351)         0           ['concat1[0][0]']                
 on (SecondOrderFeatureInteract                                                                   
 ion)                                                                                             
                                                                                                  
 concat2 (Concatenate)          (None, 479)          0           ['bottom[0][0]',                 
                                                                  'second_order_feature_interactio
                                                                 n[0][0]']                        
                                                                                                  
 top (MLP)                      (None, 1)            2197505     ['concat2[0][0]']                
                                                                                                  
==================================================================================================
Total params: 2,368,897
Trainable params: 2,368,897
Non-trainable params: 0
__________________________________________________________________________________________________
-------------------- Step 0, loss: 37.65617752075195 --------------------
-------------------- Step 1, loss: 355382208.0 --------------------
-------------------- Step 2, loss: 1477252.125 --------------------
-------------------- Step 3, loss: 1909784.0 --------------------
-------------------- Step 4, loss: 507233.625 --------------------
-------------------- Step 5, loss: 3365.16259765625 --------------------
-------------------- Step 6, loss: 268666.8125 --------------------
-------------------- Step 7, loss: 1651924.25 --------------------
-------------------- Step 8, loss: 101612.2421875 --------------------
-------------------- Step 9, loss: 202236.03125 --------------------
-------------------- Step 10, loss: 66273.078125 --------------------
-------------------- Step 11, loss: 48022.0703125 --------------------
-------------------- Step 12, loss: 73294.34375 --------------------
-------------------- Step 13, loss: 27559.2265625 --------------------
-------------------- Step 14, loss: 91068.890625 --------------------
-------------------- Step 15, loss: 2597.23193359375 --------------------
-------------------- Step 16, loss: 144474.921875 --------------------
-------------------- Step 17, loss: 78464.734375 --------------------
-------------------- Step 18, loss: 3672.82666015625 --------------------
-------------------- Step 19, loss: 618.08203125 --------------------
-------------------- Step 20, loss: 9272.6083984375 --------------------
-------------------- Step 21, loss: 12456.2373046875 --------------------
-------------------- Step 22, loss: 278.4212646484375 --------------------
-------------------- Step 23, loss: 1402.9920654296875 --------------------
-------------------- Step 24, loss: 1270.646484375 --------------------
-------------------- Step 25, loss: 24.113067626953125 --------------------
-------------------- Step 26, loss: 1.259315848350525 --------------------
-------------------- Step 27, loss: 0.8482578992843628 --------------------
-------------------- Step 28, loss: 0.6932626366615295 --------------------
-------------------- Step 29, loss: 0.8282428979873657 --------------------
-------------------- Step 30, loss: 0.8464676141738892 --------------------
-------------------- Step 31, loss: 0.7537561655044556 --------------------
-------------------- Step 32, loss: 0.7004175782203674 --------------------
-------------------- Step 33, loss: 820.9553833007812 --------------------
-------------------- Step 34, loss: 0.9226191639900208 --------------------
-------------------- Step 35, loss: 1.0884783267974854 --------------------
-------------------- Step 36, loss: 0.965215265750885 --------------------
-------------------- Step 37, loss: 0.7583897709846497 --------------------
-------------------- Step 38, loss: 0.737635612487793 --------------------
-------------------- Step 39, loss: 0.8850047588348389 --------------------
-------------------- Step 40, loss: 0.8901631236076355 --------------------
-------------------- Step 41, loss: 0.7397541999816895 --------------------
-------------------- Step 42, loss: 0.7126980423927307 --------------------
-------------------- Step 43, loss: 0.8687698245048523 --------------------
-------------------- Step 44, loss: 0.8077254295349121 --------------------
-------------------- Step 45, loss: 0.6924254298210144 --------------------
-------------------- Step 46, loss: 0.8170604109764099 --------------------
-------------------- Step 47, loss: 0.7409390211105347 --------------------
-------------------- Step 48, loss: 0.6984725594520569 --------------------
-------------------- Step 49, loss: 0.7613609433174133 --------------------
WARNING:absl:Found untraced functions such as bottom_0_layer_call_fn, bottom_0_layer_call_and_return_conditional_losses, bottom_1_layer_call_fn, bottom_1_layer_call_and_return_conditional_losses, bottom_2_layer_call_fn while saving (showing 5 of 16). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: dlrm_tf_saved_model/assets
INFO:tensorflow:Assets written to: dlrm_tf_saved_model/assets
# Release the occupied GPU memory by TensorFlow and Keras
# NOTE(review): presumably so the TensorRT engine build later in this notebook
# has the device memory available. TensorFlow should not be used in this
# process after the context is closed; confirm.
from numba import cuda
cuda.select_device(0)  # make GPU 0 numba's current device
cuda.close()           # close the CUDA context(s) of the current thread

Build the HPS-integrated TensorRT engine

In order to use HPS in the inference stage, we first need to convert the embedding weights to the format required by HPS and create a JSON configuration file for HPS.

Then we convert the TF SavedModel to ONNX and use the ONNX GraphSurgeon tool to replace the native TF embedding lookup layer with a placeholder for the HPS TensorRT plugin layer.

After that, we can build the TensorRT engine, which comprises the HPS TensorRT plugin layer and the dense network.

Step1: Prepare sparse model and JSON configuration file for HPS

Please note that the keys in the dlrm_tf_sparse.model/key file are stored as int64, while the HPS TensorRT plugin currently only supports int32 keys when loading them into memory. There is no overflow, since all key values lie in the range [0, 260000).

def convert_to_sparse_model(embeddings_weights, embedding_table_path, embedding_vec_size):
    """Write an embedding table to disk in the sparse-model layout used by HPS.

    Two binary files are created in ``embedding_table_path`` (the directory and
    any missing parents are created if needed):

    * ``key``: the row indices ``0 .. N-1`` as native-endian int64.
    * ``emb_vector``: the rows of ``embeddings_weights`` as native-endian
      float32, ``embedding_vec_size`` values per key, in key order.

    Args:
        embeddings_weights: 2-D array-like of shape ``(N, embedding_vec_size)``;
            row ``k`` is the embedding vector of key ``k``.
        embedding_table_path: output directory.
        embedding_vec_size: expected width of every embedding vector.

    Raises:
        ValueError: if ``embeddings_weights`` is not 2-D or its width differs
            from ``embedding_vec_size``.
    """
    weights = np.asarray(embeddings_weights)
    if weights.ndim != 2 or weights.shape[1] != embedding_vec_size:
        raise ValueError(
            "expected embedding weights of shape (N, {}), got {}".format(
                embedding_vec_size, weights.shape))
    # os.makedirs replaces `mkdir -p` run through os.system: no subprocess and
    # no shell interpretation of the path.
    os.makedirs(embedding_table_path, exist_ok=True)
    # Vectorized equivalent of packing each key with struct 'q' (native int64)
    # and each row with struct '<n>f' (native float32), one row at a time.
    # tofile always writes in C (row-major) order.
    np.arange(weights.shape[0], dtype=np.int64).tofile(os.path.join(embedding_table_path, "key"))
    weights.astype(np.float32).tofile(os.path.join(embedding_table_path, "emb_vector"))
# Write the trained embedding table to args["embedding_table_path"]
# (dlrm_tf_sparse.model), the path listed under "sparse_files" in dlrm_tf.json.
convert_to_sparse_model(embedding_weights, args["embedding_table_path"], args["embed_vec_size"])
%%writefile dlrm_tf.json
{
    "supportlonglong": false,
    "models": [{
        "model": "dlrm",
        "sparse_files": ["dlrm_tf_sparse.model"],
        "num_of_worker_buffer_in_pool": 3,
        "embedding_table_names":["sparse_embedding0"],
        "embedding_vecsize_per_table": [128],
        "maxnum_catfeature_query_per_table_per_sample": [26],
        "default_value_for_each_table": [1.0],
        "deployed_device_list": [0],
        "max_batch_size": 1024,
        "cache_refresh_percentage_per_iteration": 0.2,
        "hit_rate_threshold": 1.0,
        "gpucacheper": 1.0,
        "gpucache": true
        }
    ]
}
Overwriting dlrm_tf.json

Step2: Convert to ONNX and do ONNX graph surgery

# covert TF SavedModel to ONNX
!python -m tf2onnx.convert --saved-model dlrm_tf_saved_model --output dlrm_tf.onnx
2023-01-03 07:47:30.112237: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
/usr/lib/python3.8/runpy.py:127: RuntimeWarning: 'tf2onnx.convert' found in sys.modules after import of package 'tf2onnx', but prior to execution of 'tf2onnx.convert'; this may result in unpredictable behaviour
  warn(RuntimeWarning(msg))
2023-01-03 07:47:32,420 - WARNING - ***IMPORTANT*** Installed protobuf is not cpp accelerated. Conversion will be extremely slow. See https://github.com/onnx/tensorflow-onnx/issues/1557
2023-01-03 07:47:33,012 - WARNING - '--tag' not specified for saved_model. Using --tag serve
2023-01-03 07:47:35,656 - INFO - Signatures found in model: [serving_default].
2023-01-03 07:47:35,656 - WARNING - '--signature_def' not specified, using first signature: serving_default
2023-01-03 07:47:35,656 - INFO - Output names: ['output_1']
2023-01-03 07:47:41,580 - INFO - Using tensorflow=2.10.0, onnx=1.13.0, tf2onnx=1.13.0/2c1db5
2023-01-03 07:47:41,580 - INFO - Using opset <onnx, 13>
2023-01-03 07:47:42,881 - INFO - Computed 0 values for constant folding
2023-01-03 07:47:43,913 - INFO - Optimizing ONNX model
2023-01-03 07:47:44,115 - INFO - After optimization: Cast -3 (3->0), Concat -1 (3->2), Const -15 (35->20), Identity -2 (2->0), Shape -1 (1->0), Slice -1 (1->0), Squeeze -1 (1->0), Unsqueeze -3 (3->0)
2023-01-03 07:47:45,507 - INFO - 
2023-01-03 07:47:45,507 - INFO - Successfully converted TensorFlow model dlrm_tf_saved_model to ONNX
2023-01-03 07:47:45,508 - INFO - Model inputs: ['keys', 'numerical_features']
2023-01-03 07:47:45,508 - INFO - Model outputs: ['output_1']
2023-01-03 07:47:45,508 - INFO - ONNX model is saved at dlrm_tf.onnx
# ONNX graph surgery to insert the HPS TensorRT plugin placeholder
import onnx_graphsurgeon as gs
from onnx import shape_inference
import numpy as np
import onnx

graph = gs.import_onnx(onnx.load("dlrm_tf.onnx"))
saved = []  # new graph inputs, in order: categorical_features, numerical_features

# Expected name of the node produced by the DLRM's tf.nn.embedding_lookup.
embedding_node_name = "StatefulPartitionedCall/dlrm/embedding_lookup"

# Iterate over a snapshot, because the loop appends the HPS node to graph.nodes.
for node in list(graph.nodes):
    if node.name == embedding_node_name:
        # New int32 graph input with 26 keys per sample; batch dimension is dynamic.
        categorical_features = gs.Variable(name="categorical_features", dtype=np.int32, shape=("unknown", 26))
        # The string attributes are explicitly NUL-terminated. NOTE(review):
        # presumably the plugin reads them as C strings; confirm.
        hps_node = gs.Node(op="HPS_TRT", attrs={"ps_config_file": "dlrm_tf.json\0", "model_name": "dlrm\0", "table_id": 0, "emb_vec_size": 128},
                           inputs=[categorical_features], outputs=[node.outputs[0]])
        graph.nodes.append(hps_node)
        saved.append(categorical_features)
        # Detach the original lookup so that graph.cleanup() can drop it.
        node.outputs.clear()
if len(saved) != 1:
    raise RuntimeError(
        "expected exactly one node named {!r} in dlrm_tf.onnx, found {}".format(
            embedding_node_name, len(saved)))

for graph_input in graph.inputs:
    if graph_input.name == "numerical_features":
        saved.append(graph_input)
if len(saved) != 2:
    raise RuntimeError("graph input 'numerical_features' not found in dlrm_tf.onnx")
# The original "keys" input is replaced by categorical_features.
graph.inputs = saved

graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "dlrm_tf_with_hps.onnx")

Step3: Build the TensorRT engine

# build the TensorRT engine based on dlrm_tf_with_hps.onnx
import tensorrt as trt
import ctypes

# Load the HPS plugin library with RTLD_GLOBAL before parsing the ONNX model,
# which contains the custom "HPS_TRT" op.
plugin_lib_name = "/usr/local/hps_trt/lib/libhps_plugin.so"
handle = ctypes.CDLL(plugin_lib_name, mode=ctypes.RTLD_GLOBAL)

TRT_LOGGER = trt.Logger(trt.Logger.INFO)
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

def build_engine_from_onnx(onnx_model_path):
    """Parse ``onnx_model_path`` and build a serialized TensorRT engine.

    A single optimization profile covers batch sizes 1 to 1024 for both
    inputs. It is optimized for 1024, the ``max_batch_size`` in dlrm_tf.json.

    Args:
        onnx_model_path: path to the ONNX model containing the HPS_TRT node.

    Returns:
        The serialized engine returned by ``builder.build_serialized_network``.

    Raises:
        RuntimeError: if the ONNX model cannot be parsed or the build fails.
    """
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser, \
            builder.create_builder_config() as builder_config:
        # Read the model in a `with` block so the file handle is always closed.
        with open(onnx_model_path, 'rb') as model:
            parsed = parser.parse(model.read())
        if not parsed:
            errors = "\n".join(str(parser.get_error(idx)) for idx in range(parser.num_errors))
            raise RuntimeError("Failed to parse ONNX model {}:\n{}".format(onnx_model_path, errors))

        profile = builder.create_optimization_profile()
        # (min, opt, max) shapes; the max batch is kept equal to max_batch_size in dlrm_tf.json.
        profile.set_shape("categorical_features", (1, 26), (1024, 26), (1024, 26))
        profile.set_shape("numerical_features", (1, 13), (1024, 13), (1024, 13))
        builder_config.add_optimization_profile(profile)
        engine = builder.build_serialized_network(network, builder_config)
        # build_serialized_network signals failure by returning None.
        if engine is None:
            raise RuntimeError("Failed to build the TensorRT engine from {}".format(onnx_model_path))
        return engine

serialized_engine = build_engine_from_onnx("dlrm_tf_with_hps.onnx")
with open("dlrm_tf_with_hps.trt", "wb") as fout:
    fout.write(serialized_engine)
print("Succesfully build the TensorRT engine")
[01/03/2023-07:47:52] [TRT] [I] [MemUsageChange] Init CUDA: CPU +268, GPU +0, now: CPU 1636, GPU 497 (MiB)
[01/03/2023-07:47:54] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +170, GPU +46, now: CPU 1860, GPU 543 (MiB)
[01/03/2023-07:47:54] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage. See `CUDA_MODULE_LOADING` in https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
[01/03/2023-07:47:54] [TRT] [W] onnx2trt_utils.cpp:377: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[01/03/2023-07:47:54] [TRT] [I] No importer registered for op: HPS_TRT. Attempting to import as plugin.
[01/03/2023-07:47:54] [TRT] [I] Searching for plugin: HPS_TRT, plugin_version: 1, plugin_namespace: 
=====================================================HPS Parse====================================================
[HCTR][07:47:54.544][INFO][RK0][main]: dense_file is not specified using default: 
[HCTR][07:47:54.545][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
[HCTR][07:47:54.545][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][07:47:54.545][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][07:47:54.545][INFO][RK0][main]: refresh_interval is not specified using default: 0
[HCTR][07:47:54.545][INFO][RK0][main]: use_static_table is not specified using default: 0
====================================================HPS Create====================================================
[HCTR][07:47:54.545][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][07:47:54.545][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][07:47:54.545][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][07:47:54.545][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][07:47:54.545][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][07:47:54.545][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:47:54.968][INFO][RK0][main]: Table: hps_et.dlrm.sparse_embedding0; cached 260000 / 260000 embeddings in volatile database (HashMapBackend); load: 260000 / 18446744073709551615 (0.00%).
[HCTR][07:47:54.978][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][07:47:54.978][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][07:47:54.985][INFO][RK0][main]: Model name: dlrm
[HCTR][07:47:54.985][INFO][RK0][main]: Max batch size: 1024
[HCTR][07:47:54.985][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][07:47:54.985][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][07:47:54.985][INFO][RK0][main]: Use static table: False
[HCTR][07:47:54.985][INFO][RK0][main]: Use I64 input key: False
[HCTR][07:47:54.985][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][07:47:54.985][INFO][RK0][main]: The size of thread pool: 80
[HCTR][07:47:54.985][INFO][RK0][main]: The size of worker memory pool: 3
[HCTR][07:47:54.985][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][07:47:54.985][INFO][RK0][main]: The refresh percentage : 0.200000
[HCTR][07:47:55.028][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][07:47:55.028][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:47:55.165][INFO][RK0][main]: EC initialization for model: "dlrm", num_tables: 1
[HCTR][07:47:55.165][INFO][RK0][main]: EC initialization on device: 0
[HCTR][07:47:55.197][INFO][RK0][main]: Creating lookup session for dlrm on device: 0
[01/03/2023-07:47:55] [TRT] [I] Successfully created plugin: HPS_TRT
[01/03/2023-07:47:55] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +88, GPU +46, now: CPU 6129, GPU 835 (MiB)
[01/03/2023-07:47:55] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +115, GPU +52, now: CPU 6244, GPU 887 (MiB)
[01/03/2023-07:47:55] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[01/03/2023-07:49:04] [TRT] [I] Total Activation Memory: 34103362048
[01/03/2023-07:49:04] [TRT] [I] Detected 2 inputs and 1 output network tensors.
[01/03/2023-07:49:05] [TRT] [I] Total Host Persistent Memory: 416
[01/03/2023-07:49:05] [TRT] [I] Total Device Persistent Memory: 0
[01/03/2023-07:49:05] [TRT] [I] Total Scratch Memory: 45142016
[01/03/2023-07:49:05] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 75 MiB
[01/03/2023-07:49:05] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 3 steps to complete.
[01/03/2023-07:49:05] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 0.013464ms to assign 3 blocks to 3 nodes requiring 58774016 bytes.
[01/03/2023-07:49:05] [TRT] [I] Total Activation Memory: 58774016
[01/03/2023-07:49:05] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 6533, GPU 1035 (MiB)
[01/03/2023-07:49:05] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 6533, GPU 1043 (MiB)
[01/03/2023-07:49:05] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +0, GPU +32, now: CPU 0, GPU 32 (MiB)
Succesfully build the TensorRT engine

Deploy HPS-integrated TensorRT engine on Triton

In order to deploy the TensorRT engine with the Triton TensorRT backend, we need to create the model repository and define the config.pbtxt first.

!mkdir -p model_repo/dlrm_tf_with_hps/1
!mv dlrm_tf_with_hps.trt model_repo/dlrm_tf_with_hps/1
%%writefile model_repo/dlrm_tf_with_hps/config.pbtxt

platform: "tensorrt_plan"
default_model_filename: "dlrm_tf_with_hps.trt"
backend: "tensorrt"
max_batch_size: 0
input [
  {
    name: "categorical_features"
    data_type: TYPE_INT32
    dims: [-1,26]
  },
  {
    name: "numerical_features"
    data_type: TYPE_FP32
    dims: [-1,13]
  }
]
output [
  {
      name: "output_1"
      data_type: TYPE_FP32
      dims: [-1,1]
  }
]
instance_group [
  {
    count: 1
    kind: KIND_GPU
    gpus:[0]

  }
]
Writing model_repo/dlrm_tf_with_hps/config.pbtxt
!tree model_repo/dlrm_tf_with_hps
model_repo/dlrm_tf_with_hps
├── 1
│   └── dlrm_tf_with_hps.trt
└── config.pbtxt

1 directory, 2 files

We can then launch the Triton inference server using the TensorRT backend. Please note that LD_PRELOAD is utilized to load the custom TensorRT plugin (i.e., HPS TensorRT plugin) into Triton.

Note: Since background processes are not supported by Jupyter, please launch the Triton server independently in the background using the following command.

LD_PRELOAD=/usr/local/hps_trt/lib/libhps_plugin.so tritonserver --model-repository=/hugectr/hps_trt/notebooks/model_repo/ --load-model=dlrm_tf_with_hps --model-control-mode=explicit

If tritonserver starts successfully, you should see a log similar to the following:

+----------+--------------------------------+--------------------------------+
| Backend  | Path                           | Config                         |
+----------+--------------------------------+--------------------------------+
| tensorrt | /opt/tritonserver/backends/ten | {"cmdline":{"auto-complete-con |
|          | sorrt/libtriton_tensorrt.so    | fig":"true","min-compute-capab |
|          |                                | ility":"6.000000","backend-dir |
|          |                                | ectory":"/opt/tritonserver/bac |
|          |                                | kends","default-max-batch-size |
|          |                                | ":"4"}}                        |
|          |                                |                                |
+----------+--------------------------------+--------------------------------+

+------------------+---------+--------+
| Model            | Version | Status |
+------------------+---------+--------+
| dlrm_tf_with_hps | 1       | READY  |
+------------------+---------+--------+

We can then send the requests to the Triton inference server using the HTTP client.

import os
import shutil
import numpy as np
import tritonclient.http as httpclient
from tritonclient.utils import *

BATCH_SIZE = 1024
model_name = "dlrm_tf_with_hps"

# Random request batch: 26 int32 keys drawn from [0, 260000) and 13 float32
# dense features per sample, matching the inputs declared in config.pbtxt.
categorical_feature = np.random.randint(0,260000,size=(BATCH_SIZE,26)).astype(np.int32)
numerical_feature = np.random.random((BATCH_SIZE, 13)).astype(np.float32)

# Build one InferInput per model input and attach its payload.
inputs = []
for tensor_name, tensor_data, tensor_dtype in (
    ("categorical_features", categorical_feature, np.int32),
    ("numerical_features", numerical_feature, np.float32),
):
    infer_input = httpclient.InferInput(tensor_name,
                                        tensor_data.shape,
                                        np_to_triton_dtype(tensor_dtype))
    infer_input.set_data_from_numpy(tensor_data)
    inputs.append(infer_input)

outputs = [httpclient.InferRequestedOutput("output_1")]

# Send a single HTTP inference request to the local Triton server.
with httpclient.InferenceServerClient("localhost:8000") as client:
    response = client.infer(model_name, inputs, outputs=outputs)
    result = response.get_response()
    prediction = response.as_numpy("output_1")
    print("Prediction result is \n{}".format(prediction))
    print("Response details:\n{}".format(result))
Prediction result is 
[[0.6404852]
 [0.6404852]
 [0.6404852]
 ...
 [0.6404852]
 [0.6404852]
 [0.6404852]]
Response details:
{'model_name': 'dlrm_tf_with_hps', 'model_version': '1', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [1024, 1], 'parameters': {'binary_data_size': 4096}}]}