
HPS TensorRT Plugin Benchmark for TensorFlow Large Model

Overview

This notebook demonstrates how to benchmark the HPS-integrated TensorRT engine for the TensorFlow large model.

For more details about HPS, please refer to HugeCTR Hierarchical Parameter Server (HPS).

  1. Create the TF model

  2. Build the HPS-integrated TensorRT engine

  3. Benchmark HPS-integrated TensorRT engine on Triton

  4. Benchmark HPS-integrated TensorRT engine on Grace Hopper

Installation

Use NGC

The HPS TensorRT plugin is preinstalled in Merlin TensorFlow containers 23.05 and later, for example nvcr.io/nvidia/merlin/merlin-tensorflow:23.05.

You can check that the required library is available by running the following Python code after launching the container.

import ctypes
plugin_lib_name = "/usr/local/hps_trt/lib/libhps_plugin.so"
plugin_handle = ctypes.CDLL(plugin_lib_name, mode=ctypes.RTLD_GLOBAL)
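If the container or the library path differs, ctypes.CDLL raises OSError. The following sketch is a slightly more defensive version of the check above. It uses a hypothetical helper, try_load_plugin, that reports a missing file or an unloadable library instead of failing with a traceback.

```python
import ctypes
import os

# Plugin location inside the Merlin TensorFlow container (path taken from the check above).
DEFAULT_PLUGIN_PATH = "/usr/local/hps_trt/lib/libhps_plugin.so"


def try_load_plugin(path: str = DEFAULT_PLUGIN_PATH):
    """Return a ctypes handle to the plugin library, or None if it cannot be loaded."""
    if not os.path.isfile(path):
        print(f"Plugin library not found: {path}")
        return None
    try:
        # RTLD_GLOBAL makes the library's symbols visible to libraries loaded afterwards.
        return ctypes.CDLL(path, mode=ctypes.RTLD_GLOBAL)
    except OSError as err:
        # The file exists but could not be loaded, for example because a dependency is missing.
        print(f"Failed to load {path}: {err}")
        return None


handle = try_load_plugin()
print("HPS plugin loaded" if handle is not None else "HPS plugin unavailable")
```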

1. Create the TF model

We define the model graph with native TF layers, such as tf.nn.embedding_lookup and tf.keras.layers.Dense. The embedding lookup layer is only a placeholder here. Later, the HPS plugin replaces it so that the 147GB embedding table can be looked up efficiently.
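To make the shapes concrete, the following NumPy sketch reproduces what the placeholder lookup and the layers after it compute. It uses toy sizes, which are assumptions chosen for illustration. tf.nn.embedding_lookup gathers one table row per key, tf.math.reduce_mean averages over the slots, and the result is concatenated with the numerical features. With the real sizes this gives 128 + 13 = 141 features, which matches concat1 in the model summary below. Only the gather step is replaced by the HPS plugin later. The mean and the dense layers remain part of the TensorRT network.

```python
import numpy as np

# Toy sizes (assumptions); the notebook uses 26 slots, 128-dim vectors and 13 numerical features.
vocab_size, embed_vec_size, slot_num, dense_dim, batch = 10, 4, 3, 2, 5

table = np.arange(vocab_size * embed_vec_size, dtype=np.float32).reshape(vocab_size, embed_vec_size)
rng = np.random.default_rng(0)
keys = rng.integers(0, vocab_size, size=(batch, slot_num))   # categorical features
dense = rng.random((batch, dense_dim), dtype=np.float32)     # numerical features

# embedding_lookup: one table row per key -> (batch, slot_num, embed_vec_size)
embedding_vector = np.take(table, keys, axis=0)
# reduce_mean over the slot axis -> (batch, embed_vec_size)
reduced = embedding_vector.mean(axis=1)
# concatenate with the numerical features -> (batch, embed_vec_size + dense_dim)
concat_features = np.concatenate([reduced, dense], axis=1)

print(embedding_vector.shape, reduced.shape, concat_features.shape)  # (5, 3, 4) (5, 4) (5, 6)
```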

import numpy as np
import tensorflow as tf
2023-06-08 02:22:18.552734: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
class LiteModel(tf.keras.models.Model):
    def __init__(self,
                 init_tensors,
                 embed_vec_size,
                 slot_num,
                 dense_dim,
                 **kwargs):
        super(LiteModel, self).__init__(**kwargs)
        
        self.init_tensors = init_tensors
        self.params = tf.Variable(initial_value=tf.concat(self.init_tensors, axis=0))
        
        self.embed_vec_size = embed_vec_size
        self.slot_num = slot_num
        self.dense_dim = dense_dim

        self.concat1 = tf.keras.layers.Concatenate(axis=1, name = "concat1")
        self.fc1 = tf.keras.layers.Dense(1024, activation=None, name="fc1")
        self.fc2 = tf.keras.layers.Dense(256, activation=None, name="fc2")
        self.fc3 = tf.keras.layers.Dense(1, activation=None, name="fc3")            
    
    def call(self, inputs, training=True):
        categorical_features = inputs["categorical_features"]
        numerical_features = inputs["numerical_features"]
        
        embedding_vector = tf.nn.embedding_lookup(params=self.params, ids=categorical_features)
        reduced_embedding = tf.math.reduce_mean(embedding_vector, axis=1, keepdims=False)
        concat_features = self.concat1([reduced_embedding, numerical_features])
        
        logit = self.fc3(self.fc2(self.fc1(concat_features)))
        return logit

    def summary(self):
        inputs = {"categorical_features": tf.keras.Input(shape=(self.slot_num, ), dtype=tf.int32, name="categorical_features"), 
                  "numerical_features": tf.keras.Input(shape=(self.dense_dim, ), dtype=tf.float32, name="numrical_features")}
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()
# Placeholder embedding table; the real embedding table is 147GB
init_tensors = np.ones(shape=[10000, 128], dtype=np.float32)
model = LiteModel(init_tensors, 128, 26, 13, name = "dlrm")
model.summary()
categorical_features = np.random.randint(0,100, (4096,26))
numerical_features = np.random.random((4096, 13))
inputs = {"categorical_features": categorical_features, "numerical_features": numerical_features}
model(inputs)
model.save("3fc_light.savedmodel")

# Release the occupied GPU memory by TensorFlow and Keras
from numba import cuda
cuda.select_device(0)
cuda.close()
WARNING:tensorflow:The following Variables were used in a Lambda layer's call (tf.compat.v1.nn.embedding_lookup_1), but are not present in its tracked objects:   <tf.Variable 'Variable:0' shape=(10000, 128) dtype=float32>. This is a strong indication that the Lambda layer should be rewritten as a subclassed Layer.
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 categorical_features (InputLay  [(None, 26)]        0           []                               
 er)                                                                                              
                                                                                                  
 tf.compat.v1.nn.embedding_look  (None, 26, 128)     0           ['categorical_features[0][0]']   
 up_1 (TFOpLambda)                                                                                
                                                                                                  
 tf.math.reduce_mean_1 (TFOpLam  (None, 128)         0           ['tf.compat.v1.nn.embedding_looku
 bda)                                                            p_1[0][0]']                      
                                                                                                  
 numrical_features (InputLayer)  [(None, 13)]        0           []                               
                                                                                                  
 concat1 (Concatenate)          (None, 141)          0           ['tf.math.reduce_mean_1[0][0]',  
                                                                  'numrical_features[0][0]']      
                                                                                                  
 fc1 (Dense)                    (None, 1024)         145408      ['concat1[0][0]']                
                                                                                                  
 fc2 (Dense)                    (None, 256)          262400      ['fc1[0][0]']                    
                                                                                                  
 fc3 (Dense)                    (None, 1)            257         ['fc2[0][0]']                    
                                                                                                  
==================================================================================================
Total params: 408,065
Trainable params: 408,065
Non-trainable params: 0
__________________________________________________________________________________________________
INFO:tensorflow:Assets written to: 3fc_light.savedmodel/assets
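You can check the parameter counts in the summary by hand. A Dense layer with n inputs and u units has n * u kernel weights plus u biases. The 10000 x 128 placeholder table does not appear in the total. As the warning above suggests, this is because the table variable is used inside a TFOpLambda layer that does not track it. The check below uses a hypothetical helper, dense_params.

```python
def dense_params(in_features: int, units: int, use_bias: bool = True) -> int:
    """Trainable parameters of a Keras Dense layer: kernel weights plus an optional bias."""
    return in_features * units + (units if use_bias else 0)


# fc1 input width: 128 (reduced embedding) + 13 (numerical features) = 141
fc1 = dense_params(128 + 13, 1024)
fc2 = dense_params(1024, 256)
fc3 = dense_params(256, 1)
print(fc1, fc2, fc3, fc1 + fc2 + fc3)  # 145408 262400 257 408065
```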

2. Build the HPS-integrated TensorRT engine

To use HPS in the inference stage, we create a JSON configuration file and use the 147GB embedding table files in HPS format.

Then we convert the TF SavedModel to ONNX. Next, we use the ONNX GraphSurgeon tool to replace the native TF embedding lookup layer with a placeholder for the HPS TensorRT plugin layer.

After that, we can build the TensorRT engine, which consists of the HPS TensorRT plugin layer and the dense network.

Step1: Prepare the 147GB embedding table

The instructions below assume that you need to train a 147GB DLRM model from scratch with DeepLearningExamples, using the Criteo 1TB dataset.

1.1 Train a 147GB model from scratch

Please refer to the Quick Start Guide for DLRM, which contains the following important steps:

  1. Build the training docker container.

  2. Preprocess the training dataset.
    You have two options. You can use the Criteo 1TB dataset to train a 147GB model, which requires Spark for data preprocessing. Alternatively, you can use a synthetic dataset, which skips preprocessing, for fast verification.
    2.1 Preprocess the Criteo 1TB dataset
    A. In step 6 (Preprocess the data), change the frequency limit to 2. This generates a training dataset for the 147GB embedding model file.
    B. Make sure that the final_output_dir path matches the input data path used in Part 3, Step2: Prepare the benchmark input data.
    Note: The frequency limit filters out categorical values that appear fewer than n times in the whole dataset and sets them to 0. Change this variable to 1 to enable it. The default frequency limit in the script is 15. You can set any value by changing the line OPTS="--frequency_limit 8".

    2.2 Generate synthetic data in the same format as Criteo
    Downloading and preprocessing the Criteo 1TB dataset requires a lot of time and disk space. For this reason, a synthetic dataset generator that roughly matches the characteristics of Criteo 1TB is provided, so you can benchmark quickly.

  3. Run the training and save a model checkpoint, if you haven't already completed these steps.
    Note: If the model is saved successfully, each categorical feature of the Criteo data is exported as a feature_*.npy file, which holds the embedding table for that feature. In the next step, you merge these npy files into a complete binary embedding table for HPS.
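The frequency limit in step 2.1 works as follows. Values of a categorical feature that occur fewer than the limit across the whole dataset are replaced with 0. A lower limit therefore keeps more distinct values and produces a larger embedding table. This is why a limit of 2 is used to obtain the 147GB model. The pure-Python sketch below uses a hypothetical helper, apply_frequency_limit, and shows only the thresholding. The DeepLearningExamples Spark scripts also remap the surviving values, which is not reproduced here.

```python
from collections import Counter
from typing import List


def apply_frequency_limit(values: List[int], frequency_limit: int) -> List[int]:
    """Replace values that occur fewer than `frequency_limit` times with 0."""
    counts = Counter(values)
    return [v if counts[v] >= frequency_limit else 0 for v in values]


column = [7, 7, 3, 9, 9, 9, 5]
print(apply_frequency_limit(column, 2))  # [7, 7, 0, 9, 9, 9, 0]
```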

1.2 Get the embedding model file in HPS format

After training the model and obtaining the 147GB embedding model file, you need to convert the embedding table file into an HPS-format embedding file. You only need to complete the first three steps of the Quick Start Guide to obtain this file:

  1. Build the Merlin HPS docker container.

  2. Run the training docker container built during the training stage.

  3. Convert the model checkpoint into a Triton model repository.

You will then find a folder named sparse under your deploy_path, which is the path provided in steps 2 and 3 for format conversion. This folder contains the two HPS-format embedding table files, emb_vector and key.
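The conversion scripts from the Quick Start Guide produce these files for you. To illustrate what the HPS format contains, here is a minimal sketch that merges per-feature tables, such as the feature_*.npy files from training, into one key file and one emb_vector file. The sketch makes the following assumptions:

* Both files are raw, headerless binaries.
* Keys are int64 and vectors are float32, stored in the same order.
* Keys follow a hypothetical scheme in which each feature's keys are offset by the row counts of the preceding features. The actual conversion script may assign keys differently.

The tables are written one at a time, so a memory-mapped input such as np.load(path, mmap_mode="r") never needs to be held in memory all at once.

```python
import os
from typing import Iterable

import numpy as np


def write_hps_sparse_model(tables: Iterable[np.ndarray], out_dir: str) -> int:
    """Hypothetical merge of per-feature tables into HPS-style `key` and `emb_vector` files.

    Returns the total number of rows (keys) written.
    """
    os.makedirs(out_dir, exist_ok=True)
    offset, vec_size = 0, None
    with open(os.path.join(out_dir, "key"), "wb") as key_file, \
         open(os.path.join(out_dir, "emb_vector"), "wb") as vec_file:
        for table in tables:
            if table.ndim != 2 or (vec_size is not None and table.shape[1] != vec_size):
                raise ValueError("all tables must be 2-D with the same embedding vector size")
            vec_size = table.shape[1]
            rows = table.shape[0]
            # Assumed key scheme: consecutive int64 keys, offset per feature.
            np.arange(offset, offset + rows, dtype=np.int64).tofile(key_file)
            # Vectors are stored row by row as float32, in the same order as the keys.
            np.ascontiguousarray(table, dtype=np.float32).tofile(vec_file)
            offset += rows
    return offset
```

For example, write_hps_sparse_model((np.load(p, mmap_mode="r") for p in paths), "dlrm_sparse.model") streams each table once. Here, paths is a hypothetical sorted list of the feature_*.npy files.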

  1. Copy the two files emb_vector and key to the deployment path (model_repo/hps_model/1/dlrm_sparse.model).

!tree -n model_repo/hps_model/1/dlrm_sparse.model
model_repo/hps_model/1/dlrm_sparse.model
├── emb_vector
└── key

0 directories, 2 files
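Each key in the key file corresponds to one float32 vector of embed_vec_size values in emb_vector, so the two file sizes must agree. The sanity check below assumes raw, headerless binaries with int64 keys, as stated in Step2, and float32 values. check_hps_sparse_model is a hypothetical helper, not part of HPS.

```python
import os

KEY_BYTES = 8    # int64 keys
FLOAT_BYTES = 4  # float32 embedding values (assumption)


def check_hps_sparse_model(model_dir: str, embed_vec_size: int = 128) -> int:
    """Return the number of keys, or raise ValueError if `key` and `emb_vector` sizes disagree."""
    key_size = os.path.getsize(os.path.join(model_dir, "key"))
    vec_size = os.path.getsize(os.path.join(model_dir, "emb_vector"))
    if key_size % KEY_BYTES:
        raise ValueError(f"key file size {key_size} is not a multiple of {KEY_BYTES}")
    num_keys = key_size // KEY_BYTES
    expected = num_keys * embed_vec_size * FLOAT_BYTES
    if vec_size != expected:
        raise ValueError(f"emb_vector has {vec_size} bytes, expected {expected} for {num_keys} keys")
    return num_keys
```

For example, check_hps_sparse_model("model_repo/hps_model/1/dlrm_sparse.model") returns the number of rows in the deployed table.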

Step2: Prepare JSON configuration file for HPS

Please note that the dlrm_sparse.model/key file stores keys as int64, while the HPS TensorRT plugin currently only supports int32 when loading the keys into memory. This is why supportlonglong is set to false in the configuration below.
Note: To benchmark every batch size from the smallest to the largest with one configuration, we set max_batch_size to 65536. For better performance, configure each batch size separately. For instance, a configuration with max_batch_size set to 32 is then used only for batches of 32 samples.
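Because the keys are loaded as int32 here, every stored int64 key must fit in the int32 range, or it will not survive the narrowing. The hypothetical helper below checks this on a raw int64 key file. It uses a memory map and scans the file in chunks, so a multi-GB key file is not read into memory at once.

```python
import numpy as np

INT32_MIN = np.iinfo(np.int32).min
INT32_MAX = np.iinfo(np.int32).max


def keys_fit_in_int32(key_path: str, chunk_keys: int = 1 << 24) -> bool:
    """Return True if every int64 key in the (non-empty) file lies in the int32 range."""
    keys = np.memmap(key_path, dtype=np.int64, mode="r")
    for start in range(0, keys.shape[0], chunk_keys):
        chunk = keys[start:start + chunk_keys]
        if chunk.min() < INT32_MIN or chunk.max() > INT32_MAX:
            return False
    return True
```

For example, keys_fit_in_int32("model_repo/hps_model/1/dlrm_sparse.model/key") should return True before you rely on supportlonglong being false.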

%%writefile light.json
{
    "supportlonglong": false,
    "models": [{
        "model": "light",
        "sparse_files": ["/hugectr/hps_trt/notebooks/model_repo/hps_model/1/dlrm_sparse.model"],
        "num_of_worker_buffer_in_pool": 1,
        "num_of_refresher_buffer_in_pool": 0,
        "embedding_table_names":["sparse_embedding1"],
        "embedding_vecsize_per_table": [128],
        "maxnum_catfeature_query_per_table_per_sample": [26],
        "default_value_for_each_table": [1.0],
        "deployed_device_list": [0],
        "max_batch_size": 65536,
        "cache_refresh_percentage_per_iteration": 0.0,
        "hit_rate_threshold": 1.0,
        "gpucacheper": 0.00,
        "gpucache": true,
        "init_ec": false,
        "embedding_cache_type": "static",
        "enable_pagelock": true,
        "use_context_stream": true
        }
    ]
}
Writing light.json
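%%writefile does not validate its content. A syntax slip in light.json, such as = instead of : between a key and its value, therefore only surfaces later, when HPS parses the file during the engine build. A hypothetical validate_hps_config helper can catch such errors early with json.load. It also checks that each per-table list has one entry per embedding table, which is how these fields are laid out in the configuration above.

```python
import json

# Fields of the HPS configuration above that hold one entry per embedding table.
PER_TABLE_FIELDS = (
    "sparse_files",
    "embedding_vecsize_per_table",
    "maxnum_catfeature_query_per_table_per_sample",
    "default_value_for_each_table",
)


def validate_hps_config(path: str) -> None:
    """Raise if the file is not valid JSON or a per-table list has the wrong length."""
    with open(path) as f:
        # json.JSONDecodeError (a ValueError subclass) is raised for syntax errors,
        # e.g. '"enable_pagelock" = true' instead of '"enable_pagelock": true'.
        config = json.load(f)
    for model in config["models"]:
        num_tables = len(model["embedding_table_names"])
        for field in PER_TABLE_FIELDS:
            if len(model[field]) != num_tables:
                raise ValueError(
                    f"{model['model']}: {field} has {len(model[field])} entries, "
                    f"expected {num_tables}"
                )
```

Calling validate_hps_config("light.json") right after writing the file raises immediately if something is wrong.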

Step3: Convert to ONNX and do ONNX graph surgery

# convert TF SavedModel to ONNX
!python -m tf2onnx.convert --saved-model 3fc_light.savedmodel --output 3fc_light.onnx
2023-06-08 02:28:29.577492: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
/usr/lib/python3.8/runpy.py:127: RuntimeWarning: 'tf2onnx.convert' found in sys.modules after import of package 'tf2onnx', but prior to execution of 'tf2onnx.convert'; this may result in unpredictable behaviour
  warn(RuntimeWarning(msg))
2023-06-08 02:28:32,462 - WARNING - ***IMPORTANT*** Installed protobuf is not cpp accelerated. Conversion will be extremely slow. See https://github.com/onnx/tensorflow-onnx/issues/1557
2023-06-08 02:28:36,132 - WARNING - '--tag' not specified for saved_model. Using --tag serve
2023-06-08 02:28:36,928 - INFO - Signatures found in model: [serving_default].
2023-06-08 02:28:36,928 - WARNING - '--signature_def' not specified, using first signature: serving_default
2023-06-08 02:28:36,928 - INFO - Output names: ['output_1']
2023-06-08 02:28:37,440 - INFO - Using tensorflow=2.11.0, onnx=1.14.0, tf2onnx=1.14.0/8f8d49
2023-06-08 02:28:37,440 - INFO - Using opset <onnx, 15>
2023-06-08 02:28:37,459 - INFO - Computed 0 values for constant folding
2023-06-08 02:28:37,482 - INFO - Optimizing ONNX model
2023-06-08 02:28:37,541 - INFO - After optimization: Const -3 (7->4), Identity -2 (2->0)
2023-06-08 02:28:37,781 - INFO - 
2023-06-08 02:28:37,781 - INFO - Successfully converted TensorFlow model 3fc_light.savedmodel to ONNX
2023-06-08 02:28:37,781 - INFO - Model inputs: ['categorical_features', 'numerical_features']
2023-06-08 02:28:37,781 - INFO - Model outputs: ['output_1']
2023-06-08 02:28:37,781 - INFO - ONNX model is saved at 3fc_light.onnx
# ONNX graph surgery to insert the HPS TensorRT plugin placeholder
import numpy as np
import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("3fc_light.onnx"))
saved = []

for node in graph.nodes:
    if node.name == "StatefulPartitionedCall/dlrm/embedding_lookup":
        # New int32 graph input for the plugin (keys are loaded as int32, see light.json)
        categorical_features = gs.Variable(name="categorical_features", dtype=np.int32, shape=("unknown", 26))
        # The HPS_TRT node takes over the output tensor of the native lookup;
        # string attributes are passed with a trailing null character
        hps_node = gs.Node(op="HPS_TRT", attrs={"ps_config_file": "light.json\0", "model_name": "light\0", "table_id": 0, "emb_vec_size": 128},
                           inputs=[categorical_features], outputs=[node.outputs[0]])
        graph.nodes.append(hps_node)
        saved.append(categorical_features)
        # Detach the native lookup so that cleanup() removes it
        node.outputs.clear()
# Keep numerical_features; the original int64 categorical input is dropped
for i in graph.inputs:
    if i.name == "numerical_features":
        saved.append(i)
graph.inputs = saved

# Remove the now-unused nodes and tensors, then sort the graph topologically
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "3fc_light_with_hps.onnx")
[W] colored module is not installed, will not use colors when logging. To enable colors, please install the colored module: python3 -m pip install colored
[W] Found distinct tensors that share the same name:
[id: 139822124016208] Variable (categorical_features): (shape=('unknown', 26), dtype=<class 'numpy.int32'>)
[id: 139821990953120] Variable (categorical_features): (shape=['unk__6', 26], dtype=int64)
Note: Producer node(s) of first tensor:
[]
Producer node(s) of second tensor:
[]
[W] colored module is not installed, will not use colors when logging. To enable colors, please install the colored module: python3 -m pip install colored
[W] Found distinct tensors that share the same name:
[id: 139821990953120] Variable (categorical_features): (shape=['unk__6', 26], dtype=int64)
[id: 139822124016208] Variable (categorical_features): (shape=('unknown', 26), dtype=<class 'numpy.int32'>)
Note: Producer node(s) of first tensor:
[]
Producer node(s) of second tensor:
[]

Step4: Build the TensorRT engine

import ctypes

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.INFO)
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

def create_hps_plugin_creator():
    # Load the plugin library globally so that TensorRT can find the HPS_TRT plugin creator
    plugin_lib_name = "/usr/local/hps_trt/lib/libhps_plugin.so"
    handle = ctypes.CDLL(plugin_lib_name, mode=ctypes.RTLD_GLOBAL)

    trt.init_libnvinfer_plugins(TRT_LOGGER, "")
    plg_registry = trt.get_plugin_registry()

    # List the registered plugin creators whose names start with "H" (HPS_TRT should appear)
    for plugin_creator in plg_registry.plugin_creator_list:
        if plugin_creator.name[0] == "H":
            print(plugin_creator.name)

    hps_plugin_creator = plg_registry.get_plugin_creator("HPS_TRT", "1", "")
    return hps_plugin_creator

def build_engine_from_onnx(onnx_model_path):
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser, builder.create_builder_config() as builder_config:
        with open(onnx_model_path, "rb") as model:
            if not parser.parse(model.read()):
                for i in range(parser.num_errors):
                    print(parser.get_error(i))
                raise RuntimeError(f"Failed to parse {onnx_model_path}")
        print(network.num_layers)

        builder_config.set_flag(trt.BuilderFlag.FP16)
        # Dynamic batch dimension: (min, opt, max) shapes for each input
        profile = builder.create_optimization_profile()
        profile.set_shape("categorical_features", (1, 26), (1024, 26), (65536, 26))
        profile.set_shape("numerical_features", (1, 13), (1024, 13), (65536, 13))
        builder_config.add_optimization_profile(profile)

        engine = builder.build_serialized_network(network, builder_config)
        if engine is None:
            raise RuntimeError("Failed to build the TensorRT engine")
        return engine

create_hps_plugin_creator()
serialized_engine = build_engine_from_onnx("3fc_light_with_hps.onnx")
with open("dynamic_3fc_light.trt", "wb") as fout:
    fout.write(serialized_engine)
HPS_TRT
[06/08/2023-02:37:03] [TRT] [I] [MemUsageChange] Init CUDA: CPU +974, GPU +0, now: CPU 2531, GPU 661 (MiB)
[06/08/2023-02:37:09] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +336, GPU +74, now: CPU 2943, GPU 735 (MiB)
[06/08/2023-02:37:09] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
[06/08/2023-02:37:09] [TRT] [I] No importer registered for op: HPS_TRT. Attempting to import as plugin.
[06/08/2023-02:37:09] [TRT] [I] Searching for plugin: HPS_TRT, plugin_version: 1, plugin_namespace: 
=====================================================HPS Parse====================================================
[HCTR][02:37:09.116][INFO][RK0][main]: fuse_embedding_table is not specified using default: 0
[HCTR][02:37:09.116][INFO][RK0][main]: dense_file is not specified using default: 
[HCTR][02:37:09.117][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][02:37:09.117][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][02:37:09.117][INFO][RK0][main]: refresh_interval is not specified using default: 0
[HCTR][02:37:09.117][INFO][RK0][main]: use_static_table is not specified using default: 0
[HCTR][02:37:09.117][INFO][RK0][main]: use_hctr_cache_implementation is not specified using default: 1
[HCTR][02:37:09.117][INFO][RK0][main]: HPS plugin uses context stream for model light: True
====================================================HPS Create====================================================
[HCTR][02:37:09.117][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][02:37:09.117][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][02:37:09.117][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][02:37:09.117][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][02:37:09.117][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][02:37:09.177][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][02:37:09.177][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][02:37:09.177][INFO][RK0][main]: Model name: light
[HCTR][02:37:09.177][INFO][RK0][main]: Max batch size: 65536
[HCTR][02:37:09.177][INFO][RK0][main]: Fuse embedding tables: False
[HCTR][02:37:09.177][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][02:37:09.177][INFO][RK0][main]: Use static table: False
[HCTR][02:37:09.177][INFO][RK0][main]: Use I64 input key: False
[HCTR][02:37:09.177][INFO][RK0][main]: The size of worker memory pool: 1
[HCTR][02:37:09.177][INFO][RK0][main]: The size of refresh memory pool: 0
[HCTR][02:37:09.177][INFO][RK0][main]: The refresh percentage : 0.000000
[HCTR][02:38:09.140][INFO][RK0][main]: Initialize the embedding cache by by inserting the same size model file with embedding cache from beginning
[HCTR][02:38:09.140][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][02:38:09.141][INFO][RK0][main]: EC initialization on device 0 for hps_et.light.sparse_embedding1
[HCTR][03:41:57.227][INFO][RK0][main]: LookupSession i64_input_key: False
[HCTR][03:41:57.227][INFO][RK0][main]: Creating lookup session for light on device: 0
[06/08/2023-03:41:57] [TRT] [I] Successfully created plugin: HPS_TRT
9
[06/08/2023-03:41:57] [TRT] [I] BuilderFlag::kTF32 is set but hardware does not support TF32. Disabling TF32.
[06/08/2023-03:41:57] [TRT] [I] Graph optimization time: 0.0088975 seconds.
[06/08/2023-03:41:57] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +1, GPU +8, now: CPU 2952, GPU 12129 (MiB)
[06/08/2023-03:41:57] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +349, GPU +190, now: CPU 3301, GPU 12319 (MiB)
[06/08/2023-03:41:57] [TRT] [W] TensorRT was linked against cuDNN 8.9.0 but loaded cuDNN 8.7.0
[06/08/2023-03:41:57] [TRT] [I] BuilderFlag::kTF32 is set but hardware does not support TF32. Disabling TF32.
[06/08/2023-03:41:57] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[06/08/2023-03:42:03] [TRT] [I] Detected 2 inputs and 1 output network tensors.
[06/08/2023-03:42:03] [TRT] [I] Total Host Persistent Memory: 16672
[06/08/2023-03:42:03] [TRT] [I] Total Device Persistent Memory: 0
[06/08/2023-03:42:03] [TRT] [I] Total Scratch Memory: 0
[06/08/2023-03:42:03] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 3 MiB, GPU 1248 MiB
[06/08/2023-03:42:03] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 10 steps to complete.
[06/08/2023-03:42:03] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 0.040758ms to assign 3 blocks to 10 nodes requiring 905970176 bytes.
[06/08/2023-03:42:03] [TRT] [I] Total Activation Memory: 905969664
[06/08/2023-03:42:03] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 3302, GPU 12397 (MiB)
[06/08/2023-03:42:03] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +10, now: CPU 3302, GPU 12407 (MiB)
[06/08/2023-03:42:03] [TRT] [W] TensorRT was linked against cuDNN 8.9.0 but loaded cuDNN 8.7.0
[06/08/2023-03:42:03] [TRT] [W] TensorRT encountered issues when converting weights between types and that could affect accuracy.
[06/08/2023-03:42:03] [TRT] [W] If this is not the desired behavior, please modify the weights or retrain with regularization to adjust the magnitude of the weights.
[06/08/2023-03:42:03] [TRT] [W] Check verbose logs for the list of affected weights.
[06/08/2023-03:42:03] [TRT] [W] - 2 weights are affected by this issue: Detected subnormal FP16 values.
[06/08/2023-03:42:03] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +0, GPU +4, now: CPU 0, GPU 4 (MiB)
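TensorRT only accepts input shapes inside the optimization profile, so the engine built above serves batch sizes from 1 to 65536. We assume that the HPS lookup is similarly bounded by max_batch_size in light.json. Under that assumption, a small hypothetical helper can flag unsupported benchmark batch sizes before you start Triton.

```python
from typing import Iterable, List, Tuple

# (min, opt, max) batch dimension of the optimization profile in build_engine_from_onnx.
PROFILE_BATCH = (1, 1024, 65536)
# "max_batch_size" in light.json (assumed to also bound the HPS lookup).
HPS_MAX_BATCH_SIZE = 65536


def unsupported_batch_sizes(batch_sizes: Iterable[int],
                            profile: Tuple[int, int, int] = PROFILE_BATCH,
                            hps_max: int = HPS_MAX_BATCH_SIZE) -> List[int]:
    """Return the batch sizes outside [profile min, min(profile max, hps_max)]."""
    low, _, high = profile
    upper = min(high, hps_max)
    return [b for b in batch_sizes if not low <= b <= upper]


benchmark_sizes = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]
print(unsupported_batch_sizes(benchmark_sizes))  # []
```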

3. Benchmark HPS-integrated TensorRT engine on Triton

Step1: Create the model repository

To benchmark the TensorRT engine with the Triton TensorRT backend, we first need to create the model repository and define config.pbtxt.

!mkdir -p model_repo/dynamic_3fc_lite_hps_trt/1
!mv dynamic_3fc_light.trt model_repo/dynamic_3fc_lite_hps_trt/1
%%writefile model_repo/dynamic_3fc_lite_hps_trt/config.pbtxt

platform: "tensorrt_plan"
default_model_filename: "dynamic_3fc_light.trt"
backend: "tensorrt"
max_batch_size: 0
input [
  {
    name: "categorical_features"
    data_type: TYPE_INT32
    dims: [-1,26]
  },
  {
    name: "numerical_features"
    data_type: TYPE_FP32
    dims: [-1,13]
  }
]
output [
  {
    name: "output_1"
    data_type: TYPE_FP32
    dims: [-1,1]
  }
]
instance_group [
  {
    count: 1
    kind: KIND_GPU
    gpus: [0]
  }
]
Overwriting model_repo/dynamic_3fc_lite_hps_trt/config.pbtxt
!tree -n model_repo/dynamic_3fc_lite_hps_trt
model_repo/dynamic_3fc_lite_hps_trt
├── 1
│   └── dynamic_3fc_light.trt
└── config.pbtxt

1 directory, 2 files

Step2: Prepare the benchmark input data

To benchmark with the Triton Performance Analyzer, we need to prepare input data in the required format, based on the Criteo dataset.

In Part 2, section 1.1, we created the binary dataset in final_output_dir. We provide a script that converts this binary data into a JSON format that the Triton Performance Analyzer can consume. The script requires DeepLearningExamples again, so make sure you have added it to your $PYTHONPATH.

If you have not, run export PYTHONPATH=/DeepLearningExamples/TensorFlow2/Recommendation/DLRM_and_DCNv2, replacing the path with the correct one for your workspace.

# Create a directory to store the JSON-format data
!mkdir -p ./perf_data

Note that the following script takes several minutes to finish.

# Run the Python script to convert the binary data to JSON format
!python spark2json.py --result-path ./perf_data --dataset_path /path/to/your/binary_split_converted_data --num-benchmark-samples 2000000

Remember to replace --dataset_path with the correct path in your workspace, and set --num-benchmark-samples to the number of samples you want to use.

!tree -n perf_data
perf_data
├── 1024.json
├── 16384.json
├── 2048.json
├── 256.json
├── 32768.json
├── 4096.json
├── 512.json
├── 65536.json
└── 8192.json

0 directories, 9 files
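perf_analyzer reads --input-data files in a JSON format in which each entry of the "data" list is one request (step). Each tensor can be given as a flat "content" list together with its "shape". For a quick smoke test before the Criteo-based files are ready, the hypothetical generator below writes random data in that layout. The random keys are not Criteo data, so latencies measured with them are not comparable to the results below. Keep max_key within the key range of your table.

```python
import json
import random


def write_synthetic_perf_data(path: str, batch_size: int, num_steps: int,
                              slot_num: int = 26, dense_dim: int = 13,
                              max_key: int = 10000, seed: int = 0) -> None:
    """Write random perf_analyzer input data: one entry in "data" per request."""
    rng = random.Random(seed)
    steps = []
    for _ in range(num_steps):
        steps.append({
            "categorical_features": {
                "content": [rng.randrange(max_key) for _ in range(batch_size * slot_num)],
                "shape": [batch_size, slot_num],
            },
            "numerical_features": {
                "content": [rng.random() for _ in range(batch_size * dense_dim)],
                "shape": [batch_size, dense_dim],
            },
        })
    with open(path, "w") as f:
        json.dump({"data": steps}, f)
```

For example, write_synthetic_perf_data("perf_data/256.json", batch_size=256, num_steps=100) produces a small file that benchmark.sh can use in place of the real one.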

Step3: Launch the Triton inference server

We can then launch the Triton inference server with the TensorRT backend. Note that LD_PRELOAD is used to load the custom TensorRT plugin (the HPS TensorRT plugin) into Triton.

Note: Because Jupyter does not support background processes, launch the Triton server separately in the background with the following command.

LD_PRELOAD=/usr/local/hps_trt/lib/libhps_plugin.so tritonserver --model-repository=/hugectr/hps_trt/notebooks/model_repo/ --load-model=dynamic_3fc_lite_hps_trt --model-control-mode=explicit

If tritonserver starts successfully, you should see a log similar to the following:

+--------------------------+---------+--------+
| Model                    | Version | Status |
+--------------------------+---------+--------+
| dynamic_3fc_lite_hps_trt | 1       | READY  |
+--------------------------+---------+--------+

Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002
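Instead of watching the log, you can also poll the server from Python. Triton's HTTP endpoint implements the KServe v2 inference protocol. The protocol includes GET /v2/health/ready for the server and GET /v2/models/<name>/ready for a model, and both return HTTP 200 when ready. The minimal sketch below uses only the standard library. model_ready_url and is_ready are hypothetical helpers.

```python
import urllib.request


def model_ready_url(host: str, model_name: str, port: int = 8000) -> str:
    # KServe v2 model readiness endpoint served by Triton's HTTP service
    return f"http://{host}:{port}/v2/models/{model_name}/ready"


def is_ready(url: str, timeout: float = 2.0) -> bool:
    """Return True if the endpoint answers HTTP 200, False on any error or timeout."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except OSError:
        # URLError, HTTPError, refused connections and timeouts are all OSError subclasses.
        return False
```

For example, is_ready(model_ready_url("localhost", "dynamic_3fc_lite_hps_trt")) returns True once the table above reports READY.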

We can then benchmark the online inference performance of this HPS-integrated engine with the 147GB embedding table.

Step4: Run the benchmark

The compute infer time reported on the server side is the end-to-end inference latency of the engine. It covers both the HPS embedding lookup and the TensorRT forward pass of the dense layers.
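To compare compute infer across batch sizes without reading each report by hand, you can extract the server-side breakdown with a regular expression. The component names are taken from the perf_analyzer output below. parse_server_latency is a hypothetical helper. For the batch size 256 run, the components add up to 550 usec, which matches the reported 551 usec average up to rounding.

```python
import re

# Component names as they appear in perf_analyzer's "Avg request latency" line.
_COMPONENT = re.compile(r"(overhead|queue|compute input|compute infer|compute output) (\d+) usec")


def parse_server_latency(line: str) -> dict:
    """Map each server-side latency component to its value in usec."""
    return {name: int(value) for name, value in _COMPONENT.findall(line)}


line = ("Avg request latency: 551 usec (overhead 21 usec + queue 40 usec + compute input 38 usec "
        "+ compute infer 343 usec + compute output 108 usec)")
parts = parse_server_latency(line)
print(parts["compute infer"], sum(parts.values()))  # 343 550
```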

%%writefile benchmark.sh

batch_size=(256 512 1024 2048 4096 8192 16384 32768 65536)

model_name=("dynamic_3fc_lite_hps_trt")

for b in ${batch_size[*]};
do
  for m in ${model_name[*]};
  do
    echo $b $m
    perf_analyzer -m ${m} -u localhost:8000 --input-data perf_data/${b}.json --shape categorical_features:${b},26 --shape numerical_features:${b},13
  done
done
Overwriting benchmark.sh
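With max_batch_size: 0 in config.pbtxt, Triton does not batch requests for this model. The --shape flags in benchmark.sh therefore send the whole [b, 26] and [b, 13] tensors as a single request, which is why perf_analyzer reports "Batch size: 1". As a result, the reported infer/sec counts requests, not samples. The sketch below converts request throughput to samples per second with a hypothetical helper, samples_per_second. The throughput values are copied from the benchmark log below.

```python
def samples_per_second(infer_per_sec: float, batch_size: int) -> float:
    """Convert perf_analyzer request throughput to samples per second."""
    # Each request carries batch_size samples, so multiply the request rate by the batch size.
    return infer_per_sec * batch_size


# Client-side throughput from the benchmark log below: {batch size: infer/sec}
reported = {256: 1163.14, 512: 785.143, 1024: 450.826, 2048: 295.01, 4096: 195.42, 8192: 106.097}
for batch_size, infer_per_sec in reported.items():
    print(f"{batch_size:>6}: {samples_per_second(infer_per_sec, batch_size):,.0f} samples/sec")
```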
!bash benchmark.sh > result.log
256 dynamic_3fc_lite_hps_trt
 Successfully read data for 1 stream/streams with 25600 step/steps.
*** Measurement Settings ***
  Batch size: 1
  Service Kind: Triton
  Using "time_windows" mode for stabilization
  Measurement window: 5000 msec
  Using synchronous calls for inference
  Stabilizing using average latency

Request concurrency: 1
  Client: 
    Request count: 20941
    Throughput: 1163.14 infer/sec
    Avg latency: 851 usec (standard deviation 1184 usec)
    p50 latency: 799 usec
    p90 latency: 922 usec
    p95 latency: 977 usec
    p99 latency: 1190 usec
    Avg HTTP time: 846 usec (send/recv 85 usec + response wait 761 usec)
  Server: 
    Inference count: 20941
    Execution count: 20941
    Successful request count: 20941
    Avg request latency: 551 usec (overhead 21 usec + queue 40 usec + compute input 38 usec + compute infer 343 usec + compute output 108 usec)

Inferences/Second vs. Client Average Batch Latency
Concurrency: 1, throughput: 1163.14 infer/sec, latency 851 usec
512 dynamic_3fc_lite_hps_trt
 Successfully read data for 1 stream/streams with 12800 step/steps.
*** Measurement Settings ***
  Batch size: 1
  Service Kind: Triton
  Using "time_windows" mode for stabilization
  Measurement window: 5000 msec
  Using synchronous calls for inference
  Stabilizing using average latency

Request concurrency: 1
  Client: 
    Request count: 14135
    Throughput: 785.143 infer/sec
    Avg latency: 1264 usec (standard deviation 286 usec)
    p50 latency: 1236 usec
    p90 latency: 1340 usec
    p95 latency: 1374 usec
    p99 latency: 1476 usec
    Avg HTTP time: 1258 usec (send/recv 92 usec + response wait 1166 usec)
  Server: 
    Inference count: 14135
    Execution count: 14135
    Successful request count: 14135
    Avg request latency: 889 usec (overhead 27 usec + queue 44 usec + compute input 42 usec + compute infer 619 usec + compute output 156 usec)

Inferences/Second vs. Client Average Batch Latency
Concurrency: 1, throughput: 785.143 infer/sec, latency 1264 usec
1024 dynamic_3fc_lite_hps_trt
 Successfully read data for 1 stream/streams with 6400 step/steps.
*** Measurement Settings ***
  Batch size: 1
  Service Kind: Triton
  Using "time_windows" mode for stabilization
  Measurement window: 5000 msec
  Using synchronous calls for inference
  Stabilizing using average latency

Request concurrency: 1
  Client: 
    Request count: 8116
    Throughput: 450.826 infer/sec
    Avg latency: 2206 usec (standard deviation 391 usec)
    p50 latency: 2183 usec
    p90 latency: 2321 usec
    p95 latency: 2368 usec
    p99 latency: 2486 usec
    Avg HTTP time: 2199 usec (send/recv 118 usec + response wait 2081 usec)
  Server: 
    Inference count: 8116
    Execution count: 8116
    Successful request count: 8116
    Avg request latency: 1632 usec (overhead 45 usec + queue 73 usec + compute input 106 usec + compute infer 1173 usec + compute output 234 usec)

Inferences/Second vs. Client Average Batch Latency
Concurrency: 1, throughput: 450.826 infer/sec, latency 2206 usec
2048 dynamic_3fc_lite_hps_trt
 Successfully read data for 1 stream/streams with 3200 step/steps.
*** Measurement Settings ***
  Batch size: 1
  Service Kind: Triton
  Using "time_windows" mode for stabilization
  Measurement window: 5000 msec
  Using synchronous calls for inference
  Stabilizing using average latency

Request concurrency: 1
  Client: 
    Request count: 5311
    Throughput: 295.01 infer/sec
    Avg latency: 3377 usec (standard deviation 459 usec)
    p50 latency: 3349 usec
    p90 latency: 3486 usec
    p95 latency: 3530 usec
    p99 latency: 3820 usec
    Avg HTTP time: 3370 usec (send/recv 155 usec + response wait 3215 usec)
  Server: 
    Inference count: 5311
    Execution count: 5311
    Successful request count: 5311
    Avg request latency: 2591 usec (overhead 50 usec + queue 76 usec + compute input 162 usec + compute infer 2068 usec + compute output 234 usec)

Inferences/Second vs. Client Average Batch Latency
Concurrency: 1, throughput: 295.01 infer/sec, latency 3377 usec
4096 dynamic_3fc_lite_hps_trt
 Successfully read data for 1 stream/streams with 1600 step/steps.
*** Measurement Settings ***
  Batch size: 1
  Service Kind: Triton
  Using "time_windows" mode for stabilization
  Measurement window: 5000 msec
  Using synchronous calls for inference
  Stabilizing using average latency

Request concurrency: 1
  Client: 
    Request count: 3518
    Throughput: 195.42 infer/sec
    Avg latency: 5109 usec (standard deviation 380 usec)
    p50 latency: 5068 usec
    p90 latency: 5242 usec
    p95 latency: 5316 usec
    p99 latency: 5741 usec
    Avg HTTP time: 5104 usec (send/recv 171 usec + response wait 4933 usec)
  Server: 
    Inference count: 3518
    Execution count: 3518
    Successful request count: 3518
    Avg request latency: 4134 usec (overhead 38 usec + queue 48 usec + compute input 138 usec + compute infer 3742 usec + compute output 167 usec)

Inferences/Second vs. Client Average Batch Latency
Concurrency: 1, throughput: 195.42 infer/sec, latency 5109 usec
8192 dynamic_3fc_lite_hps_trt
 Successfully read data for 1 stream/streams with 800 step/steps.
*** Measurement Settings ***
  Batch size: 1
  Service Kind: Triton
  Using "time_windows" mode for stabilization
  Measurement window: 5000 msec
  Using synchronous calls for inference
  Stabilizing using average latency

Request concurrency: 1
  Client: 
    Request count: 1910
    Throughput: 106.097 infer/sec
    Avg latency: 9412 usec (standard deviation 553 usec)
    p50 latency: 9384 usec
    p90 latency: 9529 usec
    p95 latency: 9581 usec
    p99 latency: 10106 usec
    Avg HTTP time: 9406 usec (send/recv 294 usec + response wait 9112 usec)
  Server: 
    Inference count: 1910
    Execution count: 1910
    Successful request count: 1910
    Avg request latency: 7674 usec (overhead 42 usec + queue 54 usec + compute input 267 usec + compute infer 7179 usec + compute output 130 usec)

Inferences/Second vs. Client Average Batch Latency
Concurrency: 1, throughput: 106.097 infer/sec, latency 9412 usec
16384 dynamic_3fc_lite_hps_trt
 Successfully read data for 1 stream/streams with 400 step/steps.
*** Measurement Settings ***
  Batch size: 1
  Service Kind: Triton
  Using "time_windows" mode for stabilization
  Measurement window: 5000 msec
  Using synchronous calls for inference
  Stabilizing using average latency

Request concurrency: 1
  Client: 
    Request count: 992
    Throughput: 55.1033 infer/sec
    Avg latency: 18132 usec (standard deviation 726 usec)
    p50 latency: 18051 usec
    p90 latency: 18257 usec
    p95 latency: 18330 usec
    p99 latency: 23069 usec
    Avg HTTP time: 18125 usec (send/recv 1278 usec + response wait 16847 usec)
  Server: 
    Inference count: 992
    Execution count: 992
    Successful request count: 992
    Avg request latency: 14999 usec (overhead 29 usec + queue 56 usec + compute input 476 usec + compute infer 14234 usec + compute output 203 usec)

Inferences/Second vs. Client Average Batch Latency
Concurrency: 1, throughput: 55.1033 infer/sec, latency 18132 usec
32768 dynamic_3fc_lite_hps_trt
 Successfully read data for 1 stream/streams with 200 step/steps.
*** Measurement Settings ***
  Batch size: 1
  Service Kind: Triton
  Using "time_windows" mode for stabilization
  Measurement window: 5000 msec
  Using synchronous calls for inference
  Stabilizing using average latency

Request concurrency: 1
  Client: 
    Request count: 515
    Throughput: 28.6081 infer/sec
    Avg latency: 34878 usec (standard deviation 927 usec)
    p50 latency: 34734 usec
    p90 latency: 35143 usec
    p95 latency: 35288 usec
    p99 latency: 40804 usec
    Avg HTTP time: 34872 usec (send/recv 2584 usec + response wait 32288 usec)
  Server: 
    Inference count: 516
    Execution count: 516
    Successful request count: 516
    Avg request latency: 29340 usec (overhead 33 usec + queue 55 usec + compute input 870 usec + compute infer 28111 usec + compute output 270 usec)

Inferences/Second vs. Client Average Batch Latency
Concurrency: 1, throughput: 28.6081 infer/sec, latency 34878 usec
65536 dynamic_3fc_lite_hps_trt
 Successfully read data for 1 stream/streams with 100 step/steps.
*** Measurement Settings ***
  Batch size: 1
  Service Kind: Triton
  Using "time_windows" mode for stabilization
  Measurement window: 5000 msec
  Using synchronous calls for inference
  Stabilizing using average latency

Request concurrency: 1
  Client: 
    Request count: 253
    Throughput: 14.053 infer/sec
    Avg latency: 71063 usec (standard deviation 1570 usec)
    p50 latency: 70749 usec
    p90 latency: 71666 usec
    p95 latency: 73226 usec
    p99 latency: 77979 usec
    Avg HTTP time: 71058 usec (send/recv 5092 usec + response wait 65966 usec)
  Server: 
    Inference count: 253
    Execution count: 253
    Successful request count: 253
    Avg request latency: 60716 usec (overhead 38 usec + queue 58 usec + compute input 1804 usec + compute infer 58482 usec + compute output 333 usec)

Inferences/Second vs. Client Average Batch Latency
Concurrency: 1, throughput: 14.053 infer/sec, latency 71063 usec
%%writefile ./summary.py
import re
from argparse import ArgumentParser

# The server-side "Avg request latency" line of every perf_analyzer run contains
# exactly one "compute infer <N> usec" term; this script collects those values.
log_pattern = {
    "inference_benchmark": {
        "result_log": r"compute infer (\d+\.?\d*) usec",
    },
}


def extract_result_from_log(log_path):
    """Return the "compute infer" latencies (usec) found in the log, in file order."""
    job_log_pattern = log_pattern["inference_benchmark"]
    results = []
    with open(log_path, "r", errors="ignore") as f:
        for line in f:
            match = re.search(job_log_pattern["result_log"], line)
            if match is not None:
                results.append(float(match.group(1)))
    return results


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--log_path", required=True)
    args = parser.parse_args()

    # Batch sizes in the order in which they were benchmarked.
    batch_sizes = ["256", "512", "1024", "2048", "4096", "8192", "16384", "32768", "65536"]
    perf_result = extract_result_from_log(args.log_path)

    separator = "-----------------------------------------------------------------------------------------"
    print("Inference Latency (usec)")
    print(separator)
    print("batch_size\tresult \t")
    print(separator)
    # zip() stops at the shorter list instead of raising an IndexError when the
    # log holds more results than there are batch sizes.
    for batch_size, latency in zip(batch_sizes, perf_result):
        print("{}\t\t{}\t\t".format(batch_size, latency))
        print(separator)
Overwriting ./summary.py
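summary.py keeps only the "compute infer" part of each run. The server-side "Avg request latency" line in the perf_analyzer logs above also splits each request into overhead, queue, compute input, compute infer and compute output. The following is a minimal sketch, assuming exactly the line format shown in those logs, that parses the full breakdown into a dictionary; the function name parse_latency_breakdown is hypothetical and not part of summary.py.

```python
import re

# Matches the server-side summary line printed in the logs above, e.g.
# "Avg request latency: 1632 usec (overhead 45 usec + queue 73 usec + ...)".
TOTAL_RE = re.compile(r"Avg request latency: (\d+) usec \((.*)\)")
# Matches one "<component name> <value> usec" term inside the parentheses.
PART_RE = re.compile(r"([a-z ]+?) (\d+) usec")


def parse_latency_breakdown(line):
    """Return {"total": ..., "overhead": ..., ...} in usec, or None if no match."""
    match = TOTAL_RE.search(line)
    if match is None:
        return None
    result = {"total": int(match.group(1))}
    for term in match.group(2).split(" + "):
        part = PART_RE.fullmatch(term.strip())
        if part is not None:
            result[part.group(1)] = int(part.group(2))
    return result


if __name__ == "__main__":
    line = ("Avg request latency: 1632 usec (overhead 45 usec + queue 73 usec + "
            "compute input 106 usec + compute infer 1173 usec + compute output 234 usec)")
    breakdown = parse_latency_breakdown(line)
    print(breakdown)
    # The components can differ from the total by a few usec because of rounding.
    print(sum(v for k, v in breakdown.items() if k != "total"))  # 1631
```

Applied to every line of result.log, this gives a per-run view of where server time goes; in the runs above, compute infer dominates at every batch size.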

Summarize the result

!python ./summary.py --log_path result.log
Inference Latency (usec)
-----------------------------------------------------------------------------------------
batch_size	result 	
-----------------------------------------------------------------------------------------
256		343.0		
-----------------------------------------------------------------------------------------
512		619.0		
-----------------------------------------------------------------------------------------
1024		1173.0		
-----------------------------------------------------------------------------------------
2048		2068.0		
-----------------------------------------------------------------------------------------
4096		3742.0		
-----------------------------------------------------------------------------------------
8192		7179.0		
-----------------------------------------------------------------------------------------
16384		14234.0		
-----------------------------------------------------------------------------------------
32768		28111.0		
-----------------------------------------------------------------------------------------
65536		58482.0		
-----------------------------------------------------------------------------------------
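The table reports latency per request. A rough per-sample throughput can be derived as batch_size / latency, assuming that each request carries one full batch of batch_size samples. The per-batch-size runs above and the near-linear growth of latency with batch size both suggest this. The sketch below uses the compute-infer values copied from the summary. It reflects server compute time only and ignores HTTP, queueing and input/output copies, so treat it as an upper-bound estimate.

```python
# Server-side "compute infer" latency (usec) per batch size, copied from the
# summary table above.
compute_infer_usec = {
    256: 343.0,
    512: 619.0,
    1024: 1173.0,
    2048: 2068.0,
    4096: 3742.0,
    8192: 7179.0,
    16384: 14234.0,
    32768: 28111.0,
    65536: 58482.0,
}


def samples_per_second(batch_size, latency_usec):
    """Samples processed per second, assuming one request = one full batch."""
    return batch_size / (latency_usec * 1e-6)


for batch_size, latency in compute_infer_usec.items():
    throughput = samples_per_second(batch_size, latency)
    print(f"{batch_size:>6}  {latency:>8.0f} usec  {throughput / 1e6:6.2f} M samples/s")
```

By this estimate, throughput rises from roughly 0.75 million samples/s at batch size 256 to about 1.12 to 1.17 million samples/s for batch sizes of 8192 and above. Beyond that point, larger batches mainly add latency.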

4. Benchmark for ARM64 or Grace+Hopper systems

Our prebuilt Grace-optimized ARM64 images are still being tested and are therefore not yet available on NGC. This will change soon. If you want to benchmark on ARM64, in particular on a system equipped with an NVIDIA Grace CPU, you can build a compatible docker image yourself by following the steps below.

Some steps offer two mutually exclusive alternatives:

  • Option A (portable ARM64 HugeCTR): If you follow the Option A instructions, you build the standard version of HugeCTR for ARM64 platforms. The resulting binaries are more portable, but may not let you get the most out of your Grace+Hopper hardware.

  • Option B (G+H optimized HugeCTR): If you follow the Option B instructions instead, you build and run a HugeCTR variant that maximizes DLRM throughput on Grace+Hopper systems. This requires slight changes to the host system configuration, which you need root access to apply.

Step 1: Build the NVIDIA Merlin docker images

Use the following instructions on your ARM system to download the Merlin repository and build the merlin-base and merlin-hugectr docker images required for the benchmark.

  • Option A (portable ARM64 HugeCTR):

    git clone https://github.com/NVIDIA-Merlin/Merlin.git
    cd Merlin/docker
    docker build -t nvcr.io/nvstaging/merlin/merlin-base:24.04 -f dockerfile.merlin.ctr .
    docker build -t nvcr.io/nvstaging/merlin/merlin-hugectr:24.04 -f dockerfile.ctr .
    cd ../..
    
  • Option B (G+H optimized HugeCTR):

    git clone https://github.com/NVIDIA-Merlin/Merlin.git
    cd Merlin/docker
    sed -i -e 's/" -DENABLE_INFERENCE=ON/" -DUSE_HUGE_PAGES=ON -DENABLE_INFERENCE=ON/g' dockerfile.merlin
    docker build -t nvcr.io/nvstaging/merlin/merlin-base:24.04 -f dockerfile.merlin.ctr .
    docker build -t nvcr.io/nvstaging/merlin/merlin-hugectr:24.04 -f dockerfile.ctr .
    cd ../..
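The Option B sed command inserts -DUSE_HUGE_PAGES=ON in front of -DENABLE_INFERENCE=ON wherever the literal text `" -DENABLE_INFERENCE=ON` occurs in the edited dockerfile. This presumably enables huge-page support in the HugeCTR build. sed exits without any error when the pattern does not occur, so it is worth confirming which dockerfiles actually contain the new flag before building. The Python sketch below mirrors the substitution and adds a check helper. The helper names and the sample CMake line are hypothetical and are not copied from the Merlin dockerfiles.

```python
import pathlib

# Same literal pattern and replacement as the sed command in Option B.
PATTERN = '" -DENABLE_INFERENCE=ON'
REPLACEMENT = '" -DUSE_HUGE_PAGES=ON -DENABLE_INFERENCE=ON'


def add_huge_pages_flag(text):
    """Mirror sed 's/PATTERN/REPLACEMENT/g': replace every occurrence."""
    return text.replace(PATTERN, REPLACEMENT)


def files_with_flag(directory, flag="-DUSE_HUGE_PAGES=ON"):
    """List the dockerfile* files in `directory` that contain `flag`."""
    directory = pathlib.Path(directory)
    if not directory.is_dir():
        return []
    return sorted(
        path.name
        for path in directory.glob("dockerfile*")
        if flag in path.read_text(errors="ignore")
    )


if __name__ == "__main__":
    # Hypothetical CMake invocation, for illustration only.
    line = 'cmake -DCMAKE_BUILD_TYPE=Release -DSM="90" -DENABLE_INFERENCE=ON ..'
    print(add_huge_pages_flag(line))
    # cmake -DCMAKE_BUILD_TYPE=Release -DSM="90" -DUSE_HUGE_PAGES=ON -DENABLE_INFERENCE=ON ..
```

Run files_with_flag(".") from Merlin/docker after the sed step. An empty list means the substitution did not match anything.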
    

Step 2: Prepare host system for running the docker container

  • Option A (portable ARM64 HugeCTR): No action required.

  • Option B (G+H optimized HugeCTR): Adjust your Grace+Hopper system configuration to increase the number of huge memory pages (2048 kB each) available for the benchmark.

    echo '180000' | sudo tee /sys/devices/system/node/node0/hugepages/hugepages-2048kB/nr_hugepages
    

    Piping into sudo tee is required because in "sudo echo ... > file" the redirection is performed by your unprivileged shell and fails. This may take a while.

    In addition, you can reuse the light.json configuration file from Prepare JSON configuration file for HPS.
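The value written to nr_hugepages is a page count, not a size. With the 2048 kB page size in the sysfs path, 180000 pages correspond to 180000 × 2 MiB = 360,000 MiB, or about 351.6 GiB reserved on NUMA node 0. The kernel can allocate fewer pages than requested when memory is fragmented, so it is worth reading the counter back. The following is a minimal sketch; the helper names are hypothetical, and the sysfs path is the one used in the command above.

```python
from pathlib import Path

# Per-node huge page directory used in the command above.
HUGEPAGES_DIR = Path("/sys/devices/system/node/node0/hugepages/hugepages-2048kB")


def hugepage_reservation_gib(num_pages, page_size_kib=2048):
    """Memory reserved by `num_pages` huge pages of `page_size_kib` KiB, in GiB."""
    return num_pages * page_size_kib / (1024 * 1024)


def read_hugepage_counter(name, directory=HUGEPAGES_DIR):
    """Read a counter such as 'nr_hugepages' or 'free_hugepages'; None if unavailable."""
    path = Path(directory) / name
    try:
        return int(path.read_text().strip())
    except (OSError, ValueError):
        return None


if __name__ == "__main__":
    print(f"requested: {hugepage_reservation_gib(180000):.1f} GiB")  # requested: 351.6 GiB
    allocated = read_hugepage_counter("nr_hugepages")
    if allocated is None:
        print("no per-node 2048 kB huge page counters found on this machine")
    else:
        print(f"allocated: {allocated} pages ({hugepage_reservation_gib(allocated):.1f} GiB)")
```

If the allocated count stays below 180000, the kernel could not find enough contiguous memory on node 0.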

Step 3: Create the model

Follow the instructions in Create the TF model to create the model. There are many ways to accomplish this. We suggest running this Jupyter notebook inside the docker image you just built on your ARM64 or Grace+Hopper node, and forwarding the web-server port to your host system.

Your filesystem or system environment might impose constraints, so treat the following command as an example only. It assumes that HugeCTR was cloned from GitHub into the current working directory (git clone https://github.com/NVIDIA-Merlin/HugeCTR.git). To allow writing files, it first makes the notebook folder (this folder) writable for all users, including the root user inside the docker container, and then starts a Jupyter server.

export HCTR_SRC="${PWD}/HugeCTR" && \
chmod -R 777 "${HCTR_SRC}/hps_trt/notebooks" && \
docker run -it --rm --gpus all --network=host \
    -v "${HCTR_SRC}:/hugectr" \
    nvcr.io/nvstaging/merlin/merlin-hugectr:24.04 \
    jupyter-lab --allow-root --ip 0.0.0.0 --port 8888 --no-browser \
    --notebook-dir=/hugectr/hps_trt/notebooks

Step 4: Prepare data

Next, follow the instructions in Build the HPS-integrated TensorRT engine to create the dataset and the preconditions for benchmarking.

Step 5: Run benchmark

Follow the steps outlined in Benchmark HPS-integrated TensorRT engine on Triton to execute the benchmark itself.