
HPS TensorRT Plugin Demo for HugeCTR Trained Model

Overview

This notebook demonstrates how to build and deploy an HPS-integrated TensorRT engine for a model trained with HugeCTR.

For more details about HPS, please refer to HugeCTR Hierarchical Parameter Server (HPS).

Installation

Use NGC

The HPS TensorRT plugin is preinstalled in the Merlin HugeCTR container, version 23.01 and later: nvcr.io/nvidia/merlin/merlin-hugectr:23.01.

After launching the container, you can verify that the required library is present by running the following Python code.

import ctypes
plugin_lib_name = "/usr/local/hps_trt/lib/libhps_plugin.so"
plugin_handle = ctypes.CDLL(plugin_lib_name, mode=ctypes.RTLD_GLOBAL)
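
Optionally, you can also confirm that loading the library registered the HPS_TRT plugin with TensorRT (a quick sketch; it assumes the TensorRT Python bindings are available in the container):

import tensorrt as trt

# Loading libhps_plugin.so registers the HPS_TRT creator, so it should show up
# in the global TensorRT plugin registry.
trt.init_libnvinfer_plugins(trt.Logger(trt.Logger.INFO), "")
print([c.name for c in trt.get_plugin_registry().plugin_creator_list])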

Data Generation

HugeCTR provides a tool to generate synthetic datasets. The data generator can produce datasets in several file formats and with different key distributions. For this notebook, we generate one-hot Parquet datasets with a power-law key distribution:

import hugectr
from hugectr.tools import DataGeneratorParams, DataGenerator

data_generator_params = DataGeneratorParams(
  format = hugectr.DataReaderType_t.Parquet,
  label_dim = 1,
  dense_dim = 13,
  num_slot = 26,
  i64_input_key = True,
  nnz_array = [1 for _ in range(26)],
  source = "./data_parquet/file_list.txt",
  eval_source = "./data_parquet/file_list_test.txt",
  slot_size_array = [10000 for _ in range(26)],
  check_type = hugectr.Check_t.Non,
  dist_type = hugectr.Distribution_t.PowerLaw,
  power_law_type = hugectr.PowerLaw_t.Short,
  num_files = 16,
  eval_num_files = 4,
  num_samples_per_file = 40960)
data_generator = DataGenerator(data_generator_params)
data_generator.generate()
[HCTR][05:12:08.561][INFO][RK0][main]: Generate Parquet dataset
[HCTR][05:12:08.561][INFO][RK0][main]: train data folder: ./data_parquet, eval data folder: ./data_parquet, slot_size_array: 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, nnz array: 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, #files for train: 16, #files for eval: 4, #samples per file: 40960, Use power law distribution: 1, alpha of power law: 1.3
[HCTR][05:12:08.564][INFO][RK0][main]: ./data_parquet exist
[HCTR][05:12:08.568][INFO][RK0][main]: ./data_parquet/train/gen_0.parquet
[HCTR][05:12:10.204][INFO][RK0][main]: ./data_parquet/train/gen_1.parquet
[HCTR][05:12:10.455][INFO][RK0][main]: ./data_parquet/train/gen_2.parquet
[HCTR][05:12:10.709][INFO][RK0][main]: ./data_parquet/train/gen_3.parquet
[HCTR][05:12:10.957][INFO][RK0][main]: ./data_parquet/train/gen_4.parquet
[HCTR][05:12:11.196][INFO][RK0][main]: ./data_parquet/train/gen_5.parquet
[HCTR][05:12:11.437][INFO][RK0][main]: ./data_parquet/train/gen_6.parquet
[HCTR][05:12:11.681][INFO][RK0][main]: ./data_parquet/train/gen_7.parquet
[HCTR][05:12:11.920][INFO][RK0][main]: ./data_parquet/train/gen_8.parquet
[HCTR][05:12:12.171][INFO][RK0][main]: ./data_parquet/train/gen_9.parquet
[HCTR][05:12:12.411][INFO][RK0][main]: ./data_parquet/train/gen_10.parquet
[HCTR][05:12:12.650][INFO][RK0][main]: ./data_parquet/train/gen_11.parquet
[HCTR][05:12:12.885][INFO][RK0][main]: ./data_parquet/train/gen_12.parquet
[HCTR][05:12:13.120][INFO][RK0][main]: ./data_parquet/train/gen_13.parquet
[HCTR][05:12:13.341][INFO][RK0][main]: ./data_parquet/train/gen_14.parquet
[HCTR][05:12:13.577][INFO][RK0][main]: ./data_parquet/train/gen_15.parquet
[HCTR][05:12:13.818][INFO][RK0][main]: ./data_parquet/file_list.txt done!
[HCTR][05:12:13.827][INFO][RK0][main]: ./data_parquet/val/gen_0.parquet
[HCTR][05:12:14.066][INFO][RK0][main]: ./data_parquet/val/gen_1.parquet
[HCTR][05:12:14.299][INFO][RK0][main]: ./data_parquet/val/gen_2.parquet
[HCTR][05:12:14.537][INFO][RK0][main]: ./data_parquet/val/gen_3.parquet
[HCTR][05:12:14.751][INFO][RK0][main]: ./data_parquet/file_list_test.txt done!
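
Before training, it can help to inspect one of the generated files and confirm the schema (a quick sketch; it assumes pandas with a Parquet engine is available in the container):

import pandas as pd

# Each generated file should hold 40960 samples with 40 columns:
# 1 label + 13 dense + 26 categorical, matching DataGeneratorParams above.
df = pd.read_parquet("./data_parquet/train/gen_0.parquet")
print(df.shape)
print(df.dtypes)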

Train with HugeCTR

We can train a DLRM model using the HugeCTR Python API. The trained sparse and dense model files are saved separately, and the model graph is dumped into a JSON file.

%%writefile train.py
import hugectr
from mpi4py import MPI


solver = hugectr.CreateSolver(
    model_name="dlrm",
    max_eval_batches=160,
    batchsize_eval=1024,
    batchsize=1024,
    lr=0.001,
    vvgpu=[[0]],
    repeat_dataset=True,
    use_mixed_precision=True,
    use_cuda_graph=True,
    scaler=1024,
    i64_input_key=True,
)
reader = hugectr.DataReaderParams(
    data_reader_type=hugectr.DataReaderType_t.Parquet,
    source=["./data_parquet/file_list.txt"],
    eval_source="./data_parquet/file_list_test.txt",
    slot_size_array=[10000 for _ in range(26)],
    check_type=hugectr.Check_t.Non,
)
optimizer = hugectr.CreateOptimizer(
    optimizer_type=hugectr.Optimizer_t.Adam,
    update_type=hugectr.Update_t.Global,
    beta1=0.9,
    beta2=0.999,
    epsilon=0.0001,
)

model = hugectr.Model(solver, reader, optimizer)
model.add(
    hugectr.Input(
        label_dim=1,
        label_name="label",
        dense_dim=13,
        dense_name="numerical_features",
        data_reader_sparse_param_array=[hugectr.DataReaderSparseParam("keys", 1, True, 26)],
    )
)
model.add(
    hugectr.SparseEmbedding(
        embedding_type=hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
        workspace_size_per_gpu_in_mb=5000,
        embedding_vec_size=128,
        combiner="mean",
        sparse_embedding_name="sparse_embedding1",
        bottom_name="keys",
        optimizer=optimizer,
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.MLP,
        bottom_names=["numerical_features"],
        top_names=["mlp1"],
        num_outputs=[512, 256, 128],
        act_type=hugectr.Activation_t.Relu,
        use_bias=True,
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.Interaction,
        bottom_names=["mlp1", "sparse_embedding1"],
        top_names=["interaction1"],
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.MLP,
        bottom_names=["interaction1"],
        top_names=["mlp2"],
        num_outputs=[1024, 1024, 512, 256, 1],
        use_bias=True,
        activations=[
            hugectr.Activation_t.Relu,
            hugectr.Activation_t.Relu,
            hugectr.Activation_t.Relu,
            hugectr.Activation_t.Relu,
            hugectr.Activation_t.Non,
        ],
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.BinaryCrossEntropyLoss,
        bottom_names=["mlp2", "label"],
        top_names=["loss"],
    )
)
model.graph_to_json("dlrm_hugectr_graph.json")
model.compile()
model.summary()
model.fit(max_iter=1200, display=200, eval_interval=1000, snapshot=1000, snapshot_prefix="dlrm_hugectr")
Writing train.py
!python3 train.py
--------------------------------------------------------------------------
An error occurred while trying to map in the address of a function.
  Function Name: cuIpcOpenMemHandle_v2
  Error string:  /usr/lib/x86_64-linux-gnu/libcuda.so.1: undefined symbol: cuIpcOpenMemHandle_v2
CUDA-aware support is disabled.
--------------------------------------------------------------------------
HugeCTR Version: 4.1
====================================================Model Init=====================================================
[HCTR][05:12:24.539][INFO][RK0][main]: Initialize model: dlrm
[HCTR][05:12:24.539][INFO][RK0][main]: Global seed is 2950905596
[HCTR][05:12:24.542][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][05:12:26.698][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][05:12:26.698][INFO][RK0][main]: Start all2all warmup
[HCTR][05:12:26.698][INFO][RK0][main]: End all2all warmup
[HCTR][05:12:26.699][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][05:12:26.700][INFO][RK0][main]: Device 0: Tesla V100-SXM2-32GB
[HCTR][05:12:26.705][INFO][RK0][main]: num of DataReader workers for train: 1
[HCTR][05:12:26.705][INFO][RK0][main]: num of DataReader workers for eval: 1
[HCTR][05:12:26.782][INFO][RK0][main]: Vocabulary size: 260000
[HCTR][05:12:26.782][INFO][RK0][main]: max_vocabulary_size_per_gpu_=3413333
[HCTR][05:12:26.791][INFO][RK0][main]: Graph analysis to resolve tensor dependency
[HCTR][05:12:26.795][INFO][RK0][main]: Save the model graph to dlrm_hugectr_graph.json successfully
===================================================Model Compile===================================================
[HCTR][05:12:27.772][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][05:12:27.781][INFO][RK0][main]: gpu0 init embedding done
[HCTR][05:12:27.783][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][05:12:27.785][INFO][RK0][main]: Warm-up done
===================================================Model Summary===================================================
[HCTR][05:12:27.785][INFO][RK0][main]: Model structure on each GPU
Label                                   Dense                         Sparse                        
label                                   numerical_features             keys                          
(1024,1)                                (1024,13)                               
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type                              Input Name                    Output Name                   Output Shape                  
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash      keys                          sparse_embedding1             (1024,26,128)                 
------------------------------------------------------------------------------------------------------------------
MLP                                     numerical_features            mlp1                          (1024,128)                    
------------------------------------------------------------------------------------------------------------------
Interaction                             mlp1                          interaction1                  (1024,480)                    
                                        sparse_embedding1                                                                         
------------------------------------------------------------------------------------------------------------------
MLP                                     interaction1                  mlp2                          (1024,1)                      
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss                  mlp2                          loss                                                        
                                        label                                                                                     
------------------------------------------------------------------------------------------------------------------
=====================================================Model Fit=====================================================
[HCTR][05:12:27.785][INFO][RK0][main]: Use non-epoch mode with number of iterations: 1200
[HCTR][05:12:27.785][INFO][RK0][main]: Training batchsize: 1024, evaluation batchsize: 1024
[HCTR][05:12:27.785][INFO][RK0][main]: Evaluation interval: 1000, snapshot interval: 1000
[HCTR][05:12:27.785][INFO][RK0][main]: Dense network trainable: True
[HCTR][05:12:27.785][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][05:12:27.785][INFO][RK0][main]: Use mixed precision: True, scaler: 1024.000000, use cuda graph: True
[HCTR][05:12:27.785][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][05:12:27.785][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][05:12:27.785][INFO][RK0][main]: Training source file: ./data_parquet/file_list.txt
[HCTR][05:12:27.785][INFO][RK0][main]: Evaluation source file: ./data_parquet/file_list_test.txt
[HCTR][05:12:31.522][INFO][RK0][main]: Iter: 200 Time(200 iters): 3.72017s Loss: 0.693168 lr:0.001
[HCTR][05:12:35.188][INFO][RK0][main]: Iter: 400 Time(200 iters): 3.64947s Loss: 0.694016 lr:0.001
[HCTR][05:12:38.814][INFO][RK0][main]: Iter: 600 Time(200 iters): 3.60927s Loss: 0.69323 lr:0.001
[HCTR][05:12:42.432][INFO][RK0][main]: Iter: 800 Time(200 iters): 3.60078s Loss: 0.693079 lr:0.001
[HCTR][05:12:46.050][INFO][RK0][main]: Iter: 1000 Time(200 iters): 3.60162s Loss: 0.693134 lr:0.001
[HCTR][05:12:46.206][INFO][RK0][main]: Evaluation, AUC: 0.498656
[HCTR][05:12:46.206][INFO][RK0][main]: Eval Time for 160 iters: 0.156138s
[HCTR][05:12:46.206][INFO][RK0][main]: Using Local file system backend.
[HCTR][05:12:46.272][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][05:12:47.456][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][05:12:47.958][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][05:12:47.958][INFO][RK0][main]: Using Local file system backend.
[HCTR][05:12:56.286][INFO][RK0][main]: Done
[HCTR][05:12:56.840][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][05:12:56.840][INFO][RK0][main]: Using Local file system backend.
[HCTR][05:13:06.514][INFO][RK0][main]: Done
[HCTR][05:13:06.555][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HCTR][05:13:06.561][INFO][RK0][main]: Using Local file system backend.
[HCTR][05:13:06.693][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][05:13:06.694][INFO][RK0][main]: Using Local file system backend.
[HCTR][05:13:06.823][INFO][RK0][main]: Dumping dense optimizer states to file, successful
[HCTR][05:13:10.414][INFO][RK0][main]: Finish 1200 iterations with batchsize: 1024 in 42.63s.
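
The snapshot at iteration 1000 produces the artifacts consumed in the next section. Here is a minimal check that they were written (file names follow the snapshot_prefix convention used in model.fit above):

import os

# dlrm_hugectr0_sparse_1000.model is a directory holding the embedding table;
# the dense model and graph JSON are single files.
for path in ["dlrm_hugectr_graph.json",
             "dlrm_hugectr_dense_1000.model",
             "dlrm_hugectr0_sparse_1000.model"]:
    print(path, "exists:", os.path.exists(path))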

Build the HPS-integrated TensorRT engine

The saved sparse model dlrm_hugectr0_sparse_1000.model is already in the format that HPS requires. In order to use HPS in the inference stage, we need to create a JSON configuration file for HPS.

Then we convert the saved dense model dlrm_hugectr_dense_1000.model to ONNX using hugectr2onnx, and employ the ONNX GraphSurgeon tool to replace the input embedding vectors with the placeholder of the HPS TensorRT plugin layer.

After that, we can build the TensorRT engine, which consists of the HPS TensorRT plugin layer and the dense network.

Step 1: Prepare the JSON configuration file for HPS

Please note that the keys in the dlrm_hugectr0_sparse_1000.model/key file are stored as int64, while the HPS TensorRT plugin only supports int32 when loading the keys into memory. There is no overflow here because the key values range from 0 to 260000.
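
If you want to confirm that the keys actually fit into int32 before disabling supportlonglong, you can scan the key file directly (a hedged sanity check; it assumes the key file is a flat binary array of int64 values, as noted above):

import numpy as np

# Keys are stored as int64 on disk; confirm they fit within the int32 range
# required by the HPS TensorRT plugin when supportlonglong is false.
keys = np.fromfile("dlrm_hugectr0_sparse_1000.model/key", dtype=np.int64)
print(keys.min(), keys.max())
assert keys.min() >= 0 and keys.max() <= np.iinfo(np.int32).max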

%%writefile dlrm_hugectr.json
{
    "supportlonglong": false,
    "models": [{
        "model": "dlrm",
        "sparse_files": ["dlrm_hugectr0_sparse_1000.model"],
        "num_of_worker_buffer_in_pool": 3,
        "embedding_table_names":["sparse_embedding0"],
        "embedding_vecsize_per_table": [128],
        "maxnum_catfeature_query_per_table_per_sample": [26],
        "default_value_for_each_table": [1.0],
        "deployed_device_list": [0],
        "max_batch_size": 1024,
        "cache_refresh_percentage_per_iteration": 0.2,
        "hit_rate_threshold": 1.0,
        "gpucacheper": 1.0,
        "gpucache": true
        }
    ]
}
Writing dlrm_hugectr.json
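
A quick parse catches JSON mistakes now rather than later, during the HPS Parse phase of the engine build (a minimal check):

import json

# Malformed JSON would only surface when TensorRT loads the HPS plugin, so
# validating it up front saves a debugging round trip.
with open("dlrm_hugectr.json") as f:
    hps_config = json.load(f)
print(hps_config["models"][0]["model"], hps_config["models"][0]["sparse_files"])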

Step 2: Convert to ONNX and perform ONNX graph surgery

# hugectr2onnx
import hugectr2onnx
hugectr2onnx.converter.convert(onnx_model_path = "dlrm_hugectr_dense.onnx",
                               graph_config = "dlrm_hugectr_graph.json",
                               dense_model = "dlrm_hugectr_dense_1000.model",
                               convert_embedding = False)
[HUGECTR2ONNX][INFO]: Converting Data layer to ONNX
Skip sparse embedding layers in converted ONNX model
[HUGECTR2ONNX][INFO]: Converting DistributedSlotSparseEmbeddingHash layer to ONNX
[HUGECTR2ONNX][INFO]: Converting MLP layer to ONNX
[HUGECTR2ONNX][INFO]: Converting Interaction layer to ONNX
[HUGECTR2ONNX][INFO]: Converting MLP layer to ONNX
[HUGECTR2ONNX][INFO]: Converting Sigmoid layer to ONNX
[HUGECTR2ONNX][INFO]: The model is checked!
[HUGECTR2ONNX][INFO]: The model is saved at dlrm_hugectr_dense.onnx
# ONNX graph surgery to insert the HPS TensorRT plugin placeholder
import onnx_graphsurgeon as gs
from onnx import shape_inference
import numpy as np
import onnx

graph = gs.import_onnx(onnx.load("dlrm_hugectr_dense.onnx"))
saved = []

# Replace the "sparse_embedding1" input with an HPS_TRT plugin node that looks
# up embedding vectors for the int32 categorical keys.
for i in graph.inputs:
    if i.name == "sparse_embedding1":
        categorical_features = gs.Variable(name="categorical_features", dtype=np.int32, shape=("unknown_1", 26))
        node = gs.Node(op="HPS_TRT", attrs={"ps_config_file": "dlrm_hugectr.json\0", "model_name": "dlrm\0", "table_id": 0, "emb_vec_size": 128}, inputs=[categorical_features], outputs=[i])
        graph.nodes.append(node)
        saved.append(categorical_features)
    elif i.name == "numerical_features":
        i.shape = ("unknown_2", 13)
        saved.append(i)

# The graph inputs become the plugin's categorical keys plus the dense features.
graph.inputs = saved

graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "dlrm_hugectr_with_hps.onnx")
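
You can verify that the surgery took effect by listing the inputs of the saved model (a small check; 1 and 6 are the ONNX TensorProto type codes for FLOAT and INT32, respectively):

import onnx

# Expect exactly two inputs after surgery: categorical_features (INT32 = 6)
# and numerical_features (FLOAT = 1).
modified = onnx.load("dlrm_hugectr_with_hps.onnx")
print([(i.name, i.type.tensor_type.elem_type) for i in modified.graph.input])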

Step 3: Build the TensorRT engine

# build the TensorRT engine based on dlrm_hugectr_with_hps.onnx
import tensorrt as trt
import ctypes

plugin_lib_name = "/usr/local/hps_trt/lib/libhps_plugin.so"
handle = ctypes.CDLL(plugin_lib_name, mode=ctypes.RTLD_GLOBAL)

TRT_LOGGER = trt.Logger(trt.Logger.INFO)
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

def build_engine_from_onnx(onnx_model_path):
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser, builder.create_builder_config() as builder_config:
        # Parse the ONNX model; the HPS_TRT node is imported as a plugin
        with open(onnx_model_path, "rb") as model:
            parser.parse(model.read())

        # Optimization profile covering batch sizes from 1 to 1024
        profile = builder.create_optimization_profile()
        profile.set_shape("categorical_features", (1, 26), (1024, 26), (1024, 26))
        profile.set_shape("numerical_features", (1, 13), (1024, 13), (1024, 13))
        builder_config.add_optimization_profile(profile)
        return builder.build_serialized_network(network, builder_config)

serialized_engine = build_engine_from_onnx("dlrm_hugectr_with_hps.onnx")
with open("dlrm_hugectr_with_hps.trt", "wb") as fout:
    fout.write(serialized_engine)
print("Succesfully build the TensorRT engine")
[12/14/2022-05:13:31] [TRT] [I] [MemUsageChange] Init CUDA: CPU +262, GPU +0, now: CPU 1014, GPU 886 (MiB)
[12/14/2022-05:13:33] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +170, GPU +46, now: CPU 1239, GPU 932 (MiB)
[12/14/2022-05:13:33] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage. See `CUDA_MODULE_LOADING` in https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
[12/14/2022-05:13:33] [TRT] [W] onnx2trt_utils.cpp:377: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[12/14/2022-05:13:33] [TRT] [I] No importer registered for op: HPS_TRT. Attempting to import as plugin.
[12/14/2022-05:13:33] [TRT] [I] Searching for plugin: HPS_TRT, plugin_version: 1, plugin_namespace: 
=====================================================HPS Parse====================================================
[HCTR][05:13:33.812][INFO][RK0][main]: dense_file is not specified using default: 
[HCTR][05:13:33.812][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
[HCTR][05:13:33.812][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][05:13:33.812][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][05:13:33.812][INFO][RK0][main]: refresh_interval is not specified using default: 0
[HCTR][05:13:33.812][INFO][RK0][main]: use_static_table is not specified using default: 0
====================================================HPS Create====================================================
[HCTR][05:13:33.813][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][05:13:33.813][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][05:13:33.813][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][05:13:33.813][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][05:13:33.813][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][05:13:33.813][INFO][RK0][main]: Using Local file system backend.
[HCTR][05:13:36.189][INFO][RK0][main]: Table: hps_et.dlrm.sparse_embedding0; cached 239950 / 239950 embeddings in volatile database (HashMapBackend); load: 239950 / 18446744073709551615 (0.00%).
[HCTR][05:13:36.196][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][05:13:36.196][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][05:13:36.205][INFO][RK0][main]: Model name: dlrm
[HCTR][05:13:36.205][INFO][RK0][main]: Max batch size: 1024
[HCTR][05:13:36.205][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][05:13:36.205][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][05:13:36.205][INFO][RK0][main]: Use static table: False
[HCTR][05:13:36.205][INFO][RK0][main]: Use I64 input key: False
[HCTR][05:13:36.205][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][05:13:36.205][INFO][RK0][main]: The size of thread pool: 80
[HCTR][05:13:36.205][INFO][RK0][main]: The size of worker memory pool: 3
[HCTR][05:13:36.205][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][05:13:36.205][INFO][RK0][main]: The refresh percentage : 0.200000
[HCTR][05:13:36.270][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][05:13:36.270][INFO][RK0][main]: Using Local file system backend.
[HCTR][05:13:36.419][INFO][RK0][main]: EC initialization for model: "dlrm", num_tables: 1
[HCTR][05:13:36.419][INFO][RK0][main]: EC initialization on device: 0
[HCTR][05:13:36.440][INFO][RK0][main]: Creating lookup session for dlrm on device: 0
[12/14/2022-05:13:36] [TRT] [I] Successfully created plugin: HPS_TRT
[12/14/2022-05:13:37] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +335, GPU +146, now: CPU 5763, GPU 1314 (MiB)
[12/14/2022-05:13:37] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +116, GPU +54, now: CPU 5879, GPU 1368 (MiB)
[12/14/2022-05:13:37] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[12/14/2022-05:13:37] [TRT] [W] Using kFASTER_DYNAMIC_SHAPES_0805 preview feature.
[12/14/2022-05:13:52] [TRT] [I] Total Activation Memory: 34118830080
[12/14/2022-05:13:52] [TRT] [I] Detected 2 inputs and 1 output network tensors.
[12/14/2022-05:13:52] [TRT] [I] Total Host Persistent Memory: 20304
[12/14/2022-05:13:52] [TRT] [I] Total Device Persistent Memory: 10752
[12/14/2022-05:13:52] [TRT] [I] Total Scratch Memory: 32505856
[12/14/2022-05:13:52] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 16 MiB, GPU 4628 MiB
[12/14/2022-05:13:52] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 16 steps to complete.
[12/14/2022-05:13:52] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 0.09284ms to assign 4 blocks to 16 nodes requiring 48099840 bytes.
[12/14/2022-05:13:52] [TRT] [I] Total Activation Memory: 48099840
[12/14/2022-05:13:52] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 6321, GPU 1580 (MiB)
[12/14/2022-05:13:52] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +1, GPU +10, now: CPU 6322, GPU 1590 (MiB)
[12/14/2022-05:13:52] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +8, GPU +16, now: CPU 8, GPU 16 (MiB)
Successfully built the TensorRT engine
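
Before moving on to Triton, you can deserialize the engine locally to confirm that it loads outside the builder (a sketch; deserialization takes the same plugin-creation path that Triton will take later):

import ctypes
import tensorrt as trt

# The plugin library must be loaded before deserialization so that the
# HPS_TRT creator can be found (already done earlier in this notebook).
ctypes.CDLL("/usr/local/hps_trt/lib/libhps_plugin.so", mode=ctypes.RTLD_GLOBAL)
runtime = trt.Runtime(trt.Logger(trt.Logger.INFO))
with open("dlrm_hugectr_with_hps.trt", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
print("Engine loaded:", engine is not None)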

Deploy HPS-integrated TensorRT engine on Triton

In order to deploy the TensorRT engine with the Triton TensorRT backend, we need to create the model repository and define the config.pbtxt first.

!mkdir -p model_repo/dlrm_hugectr_with_hps/1
!mv dlrm_hugectr_with_hps.trt model_repo/dlrm_hugectr_with_hps/1
%%writefile model_repo/dlrm_hugectr_with_hps/config.pbtxt

platform: "tensorrt_plan"
default_model_filename: "dlrm_hugectr_with_hps.trt"
backend: "tensorrt"
max_batch_size: 0
input [
  {
    name: "categorical_features"
    data_type: TYPE_INT32
    dims: [-1,26]
  },
  {
    name: "numerical_features"
    data_type: TYPE_FP32
    dims: [-1,13]
  }
]
output [
  {
      name: "label"
      data_type: TYPE_FP32
      dims: [-1,1]
  }
]
instance_group [
  {
    count: 1
    kind: KIND_GPU
    gpus: [0]
  }
]
Writing model_repo/dlrm_hugectr_with_hps/config.pbtxt
!tree model_repo/dlrm_hugectr_with_hps
model_repo/dlrm_hugectr_with_hps
├── 1
│   └── dlrm_hugectr_with_hps.trt
└── config.pbtxt

1 directory, 2 files

We can then launch the Triton Inference Server using the TensorRT backend. Note that LD_PRELOAD is used to load the custom TensorRT plugin (i.e., the HPS TensorRT plugin) into Triton.

Note: Since background processes are not supported by Jupyter, please run the following command in a separate terminal to launch the Triton Server in the background.

LD_PRELOAD=/usr/local/hps_trt/lib/libhps_plugin.so tritonserver --model-repository=/hugectr/hps_trt/notebooks/model_repo/ --load-model=dlrm_hugectr_with_hps --model-control-mode=explicit

If the server started successfully, you should see a log similar to the following:

+----------+--------------------------------+--------------------------------+
| Backend  | Path                           | Config                         |
+----------+--------------------------------+--------------------------------+
| tensorrt | /opt/tritonserver/backends/ten | {"cmdline":{"auto-complete-con |
|          | sorrt/libtriton_tensorrt.so    | fig":"true","min-compute-capab |
|          |                                | ility":"6.000000","backend-dir |
|          |                                | ectory":"/opt/tritonserver/bac |
|          |                                | kends","default-max-batch-size |
|          |                                | ":"4"}}                        |
|          |                                |                                |
+----------+--------------------------------+--------------------------------+

+-----------------------+---------+--------+
| Model                 | Version | Status |
+-----------------------+---------+--------+
| dlrm_hugectr_with_hps | 1       | READY  |
+-----------------------+---------+--------+
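
Before sending inference requests, you can poll the server and model state with the Triton HTTP client (a minimal liveness check; it assumes the server listens on the default localhost:8000):

import tritonclient.http as httpclient

# Both calls should return True once the model shows READY in the server log.
client = httpclient.InferenceServerClient("localhost:8000")
print("Server live:", client.is_server_live())
print("Model ready:", client.is_model_ready("dlrm_hugectr_with_hps"))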

We can then construct and send inference requests with randomly generated inputs using the HTTP client.

import os
import shutil
import numpy as np
import tritonclient.http as httpclient
from tritonclient.utils import *

BATCH_SIZE = 1024

categorical_feature = np.random.randint(0,260000,size=(BATCH_SIZE,26)).astype(np.int32)
numerical_feature = np.random.random((BATCH_SIZE, 13)).astype(np.float32)

inputs = [
    httpclient.InferInput("categorical_features", 
                          categorical_feature.shape,
                          np_to_triton_dtype(np.int32)),
    httpclient.InferInput("numerical_features", 
                          numerical_feature.shape,
                          np_to_triton_dtype(np.float32)),                          
]
inputs[0].set_data_from_numpy(categorical_feature)
inputs[1].set_data_from_numpy(numerical_feature)


outputs = [
    httpclient.InferRequestedOutput("label")
]

model_name = "dlrm_hugectr_with_hps"

with httpclient.InferenceServerClient("localhost:8000") as client:
    response = client.infer(model_name,
                            inputs,
                            outputs=outputs)
    result = response.get_response()
    
    print("Prediction result is \n{}".format(response.as_numpy("label")))
    print("Response details:\n{}".format(result))
Prediction result is 
[[1.        ]
 [0.49642828]
 [0.52846366]
 ...
 [0.99999994]
 [0.9999992 ]
 [0.9999905 ]]
Response details:
{'model_name': 'dlrm_hugectr_with_hps', 'model_version': '1', 'outputs': [{'name': 'label', 'datatype': 'FP32', 'shape': [1024, 1], 'parameters': {'binary_data_size': 4096}}]}