
Deploy SavedModel using HPS with Triton TensorFlow Backend

Overview

This notebook demonstrates how to deploy a SavedModel that uses HPS with the Triton TensorFlow backend. We recommend running hierarchical_parameter_server_demo.ipynb before working through this notebook.

For more details about HPS APIs, please refer to HPS APIs. For more details about HPS, please refer to HugeCTR Hierarchical Parameter Server (HPS).

Installation

Get HPS from NGC

The HPS Python module is preinstalled in the 22.09 and later releases of the Merlin TensorFlow container: nvcr.io/nvidia/merlin/merlin-tensorflow:22.09.

After launching this container, you can verify that the HPS module is available by running the following command:

$ python3 -c "import hierarchical_parameter_server as hps"
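If you prefer to check from inside a notebook cell without stopping on an ImportError, a small helper built on importlib.util.find_spec can report which modules are importable. This is a sketch: check_modules is a hypothetical helper, and the module list simply names the packages this notebook imports later (tensorflow, tritonclient).

```python
import importlib.util

def check_modules(names):
    """Return {module_name: True/False} depending on whether each top-level module can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}

if __name__ == "__main__":
    status = check_modules(["hierarchical_parameter_server", "tensorflow", "tritonclient"])
    for name, found in status.items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```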

The Triton TensorFlow backend is also available in this container.

Configurations

First, we specify the required configurations, e.g., the arguments for generating the dataset, the paths for saving the model, and the model parameters. In this notebook, we use a deep neural network (DNN) model with one embedding table and several dense layers. Note that the model has two inputs: the key tensor (one key per slot, i.e., one-hot) and the dense feature tensor.

import hierarchical_parameter_server as hps
import os
import numpy as np
import tensorflow as tf
import struct

args = dict()

args["gpu_num"] = 1                               # the number of available GPUs
args["iter_num"] = 10                             # the number of training iterations
args["slot_num"] = 5                              # the number of feature fields in this embedding layer
args["embed_vec_size"] = 16                       # the dimension of embedding vectors
args["global_batch_size"] = 1024                  # the global batch size across all GPUs
args["max_vocabulary_size"] = 50000
args["vocabulary_range_per_slot"] = [[0,10000],[10000,20000],[20000,30000],[30000,40000],[40000,50000]]
args["dense_dim"] = 10

args["dense_model_path"] = "hps_tf_triton_dense.model"
args["ps_config_file"] = "hps_tf_triton.json"
args["embedding_table_path"] = "hps_tf_triton_sparse_0.model"
args["saved_path"] = "hps_tf_triton_tf_saved_model"
args["np_key_type"] = np.int64
args["np_vector_type"] = np.float32
args["tf_key_type"] = tf.int64
args["tf_vector_type"] = tf.float32


os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(args["gpu_num"])))
[INFO] hierarchical_parameter_server is imported
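The per-slot vocabulary ranges above are disjoint and together cover [0, max_vocabulary_size), so every key indexes a distinct row of the single embedding table. With 5 slots and 50,000 rows, each slot owns 50,000 / 5 = 10,000 keys. The sketch below is a hypothetical helper that derives the same list instead of writing it by hand, assuming equally sized slots.

```python
def split_vocabulary(max_vocabulary_size, slot_num):
    """Split [0, max_vocabulary_size) into slot_num contiguous, equally sized [low, high) ranges."""
    if max_vocabulary_size % slot_num != 0:
        raise ValueError("max_vocabulary_size must be divisible by slot_num")
    per_slot = max_vocabulary_size // slot_num
    return [[i * per_slot, (i + 1) * per_slot] for i in range(slot_num)]

ranges = split_vocabulary(50000, 5)
print(ranges)  # [[0, 10000], [10000, 20000], [20000, 30000], [30000, 40000], [40000, 50000]]
```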
def generate_random_samples(num_samples, vocabulary_range_per_slot, dense_dim, key_dtype=args["np_key_type"]):
    # Draw one key per slot from that slot's [low, high) vocabulary range.
    keys = list()
    for vocab_range in vocabulary_range_per_slot:
        keys_per_slot = np.random.randint(low=vocab_range[0], high=vocab_range[1], size=(num_samples, 1), dtype=key_dtype)
        keys.append(keys_per_slot)
    keys = np.concatenate(keys, axis=1)  # shape: (num_samples, slot_num)
    dense_features = np.random.random((num_samples, dense_dim)).astype(np.float32)
    labels = np.random.randint(low=0, high=2, size=(num_samples, 1))  # binary labels
    return keys, dense_features, labels

def tf_dataset(keys, dense_features, labels, batchsize):
    dataset = tf.data.Dataset.from_tensor_slices((keys, dense_features, labels))
    dataset = dataset.batch(batchsize, drop_remainder=True)
    return dataset
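generate_random_samples draws each column of the key tensor from its own slot's range, so a key tensor has shape (num_samples, slot_num). tf_dataset then batches with drop_remainder=True, which yields num_samples // batchsize full batches; with 1024 × 10 samples and a batch size of 1024 that is exactly iter_num = 10 training steps. The pure-Python sketch below (no NumPy or TensorFlow; sample_keys and batch_drop_remainder are hypothetical names) mirrors that logic so the shapes can be checked in isolation.

```python
import random

def sample_keys(num_samples, vocabulary_range_per_slot, seed=0):
    """Return num_samples rows, each holding one key per slot drawn uniformly from [low, high)."""
    rng = random.Random(seed)
    return [[rng.randrange(low, high) for low, high in vocabulary_range_per_slot]
            for _ in range(num_samples)]

def batch_drop_remainder(rows, batch_size):
    """Keep only full batches, like tf.data.Dataset.batch(batch_size, drop_remainder=True)."""
    num_batches = len(rows) // batch_size
    return [rows[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]

vocab = [[0, 10000], [10000, 20000], [20000, 30000], [30000, 40000], [40000, 50000]]
keys = sample_keys(1024 * 10 + 7, vocab)   # 7 extra rows that cannot fill a batch
batches = batch_drop_remainder(keys, 1024)
print(len(keys[0]), len(batches))          # 5 10
```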

Train with native TF layers

We define the model graph for training with native TF layers, i.e., tf.nn.embedding_lookup and tf.keras.layers.Dense; the embedding weights are stored in a tf.Variable. We then train the model and extract the trained weights of the embedding table. The dense layers are saved as a separate model graph, which can be loaded directly during inference.
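The layer shapes in the model summary printed by train() follow directly from the configuration: the lookup returns (batch, 5, 16), the reshape flattens it to 5 × 16 = 80 features, and concatenating the 10 dense features gives 90 inputs to fc_1. A Dense layer has in_dim × units weights plus units biases, so fc_1 has 90 × 256 + 256 = 23,296 parameters and fc_2 has 256 × 1 + 1 = 257, for 23,553 in total. The embedding table is not part of that count because the tf.Variable is used inside a TFOpLambda layer, as the warning in the output points out. The sketch below reproduces the gather-and-flatten step and the parameter arithmetic in plain Python; it is illustrative only and does not use TensorFlow.

```python
def embedding_lookup(params, ids):
    """Gather one row of params per id, as tf.nn.embedding_lookup does for a 2-D id tensor."""
    return [[params[i] for i in row] for row in ids]

def flatten_slots(batch):
    """Reshape (batch, slot_num, embed_vec_size) to (batch, slot_num * embed_vec_size)."""
    return [[x for vector in sample for x in vector] for sample in batch]

def dense_param_count(in_dim, units):
    """Kernel (in_dim x units) plus bias (units)."""
    return in_dim * units + units

SLOT_NUM, EMBED_VEC_SIZE, DENSE_DIM = 5, 16, 10
table = [[float(row)] * EMBED_VEC_SIZE for row in range(50)]  # toy 50-row table
ids = [[0, 10, 20, 30, 40], [1, 11, 21, 31, 41]]             # batch of 2 samples, 5 slots
flat = flatten_slots(embedding_lookup(table, ids))
concat_dim = len(flat[0]) + DENSE_DIM
fc_1 = dense_param_count(concat_dim, 256)
fc_2 = dense_param_count(256, 1)
print(len(flat[0]), concat_dim, fc_1, fc_2, fc_1 + fc_2)  # 80 90 23296 257 23553
```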

class TrainModel(tf.keras.models.Model):
    def __init__(self,
                 init_tensors,
                 slot_num,
                 embed_vec_size,
                 dense_dim,
                 **kwargs):
        super(TrainModel, self).__init__(**kwargs)
        
        self.slot_num = slot_num
        self.embed_vec_size = embed_vec_size
        self.dense_dim = dense_dim
        self.init_tensors = init_tensors
        self.params = tf.Variable(initial_value=tf.concat(self.init_tensors, axis=0))
        self.concat = tf.keras.layers.Concatenate(axis=1, name="concatenate")
        self.fc_1 = tf.keras.layers.Dense(units=256, activation=None,
                                                 kernel_initializer="ones",
                                                 bias_initializer="zeros",
                                                 name='fc_1')
        self.fc_2 = tf.keras.layers.Dense(units=1, activation=None,
                                                 kernel_initializer="ones",
                                                 bias_initializer="zeros",
                                                 name='fc_2')

    def call(self, inputs):
        keys, dense_features = inputs[0], inputs[1]
        embedding_vector = tf.nn.embedding_lookup(params=self.params, ids=keys)
        embedding_vector = tf.reshape(embedding_vector, shape=[-1, self.slot_num * self.embed_vec_size])
        concated_features = self.concat([embedding_vector, dense_features])
        logit = self.fc_2(self.fc_1(concated_features))
        return logit

    def summary(self):
        inputs = [tf.keras.Input(shape=(self.slot_num, ), dtype=args["tf_key_type"]),
                  tf.keras.Input(shape=(self.dense_dim, ), dtype=tf.float32)]
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()
def train(args):
    init_tensors = np.ones(shape=[args["max_vocabulary_size"], args["embed_vec_size"]], dtype=args["np_vector_type"])
    
    model = TrainModel(init_tensors, args["slot_num"], args["embed_vec_size"], args["dense_dim"])
    model.summary()
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
    
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logit = model(inputs)
            loss = loss_fn(labels, logit)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return logit, loss

    keys, dense_features, labels = generate_random_samples(args["global_batch_size"]  * args["iter_num"], args["vocabulary_range_per_slot"], args["dense_dim"])
    dataset = tf_dataset(keys, dense_features, labels, args["global_batch_size"])
    for i, (keys, dense_features, labels) in enumerate(dataset):
        inputs = [keys, dense_features]
        _, loss = _train_step(inputs, labels)
        print("-"*20, "Step {}, loss: {}".format(i, loss),  "-"*20)

    return model
trained_model = train(args)
weights_list = trained_model.get_weights()
embedding_weights = weights_list[-1]  # the embedding table (the tf.Variable `params`) is the last weight
dense_model = tf.keras.Model(trained_model.get_layer("concatenate").input,
                             trained_model.get_layer("fc_2").output)
dense_model.summary()
dense_model.save(args["dense_model_path"])
2022-08-31 02:41:00.863222: I tensorflow/core/platform/cpu_feature_guard.cc:152] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-08-31 02:41:01.391912: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30999 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
WARNING:tensorflow:The following Variables were used in a Lambda layer's call (tf.compat.v1.nn.embedding_lookup), but are not present in its tracked objects:   <tf.Variable 'Variable:0' shape=(50000, 16) dtype=float32>. This is a strong indication that the Lambda layer should be rewritten as a subclassed Layer.
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 5)]          0           []                               
                                                                                                  
 tf.compat.v1.nn.embedding_look  (None, 5, 16)       0           ['input_1[0][0]']                
 up (TFOpLambda)                                                                                  
                                                                                                  
 tf.reshape (TFOpLambda)        (None, 80)           0           ['tf.compat.v1.nn.embedding_looku
                                                                 p[0][0]']                        
                                                                                                  
 input_2 (InputLayer)           [(None, 10)]         0           []                               
                                                                                                  
 concatenate (Concatenate)      (None, 90)           0           ['tf.reshape[0][0]',             
                                                                  'input_2[0][0]']                
                                                                                                  
 fc_1 (Dense)                   (None, 256)          23296       ['concatenate[0][0]']            
                                                                                                  
 fc_2 (Dense)                   (None, 1)            257         ['fc_1[0][0]']                   
                                                                                                  
==================================================================================================
Total params: 23,553
Trainable params: 23,553
Non-trainable params: 0
__________________________________________________________________________________________________
-------------------- Step 0, loss: 11092.1826171875 --------------------
-------------------- Step 1, loss: 8587.974609375 --------------------
-------------------- Step 2, loss: 6780.404296875 --------------------
-------------------- Step 3, loss: 5393.8896484375 --------------------
-------------------- Step 4, loss: 4023.296142578125 --------------------
-------------------- Step 5, loss: 2579.7099609375 --------------------
-------------------- Step 6, loss: 1797.363037109375 --------------------
-------------------- Step 7, loss: 1062.6259765625 --------------------
-------------------- Step 8, loss: 566.6324462890625 --------------------
-------------------- Step 9, loss: 221.4973602294922 --------------------
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_3 (InputLayer)           [(None, 80)]         0           []                               
                                                                                                  
 input_2 (InputLayer)           [(None, 10)]         0           []                               
                                                                                                  
 concatenate (Concatenate)      (None, 90)           0           ['input_3[0][0]',                
                                                                  'input_2[0][0]']                
                                                                                                  
 fc_1 (Dense)                   (None, 256)          23296       ['concatenate[1][0]']            
                                                                                                  
 fc_2 (Dense)                   (None, 1)            257         ['fc_1[1][0]']                   
                                                                                                  
==================================================================================================
Total params: 23,553
Trainable params: 23,553
Non-trainable params: 0
__________________________________________________________________________________________________
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
2022-08-31 02:41:02.894935: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel.
INFO:tensorflow:Assets written to: hps_tf_triton_dense.model/assets

Create the inference graph with HPS LookupLayer

To use HPS at inference time, we create an inference model graph that is almost identical to the training graph, except that tf.nn.embedding_lookup is replaced by hps.LookupLayer. The trained dense model graph can be loaded directly, while the embedding weights must be converted to the format required by HPS.

We can then save the inference model graph so that it is ready for deployment. Because the inference SavedModel that uses HPS will be deployed with the Triton TensorFlow backend, HPS must be initialized implicitly; this is enabled by specifying ps_config_file and global_batch_size in the constructor of hps.LookupLayer. For more information, please refer to HPS Initialize.

To this end, we create a JSON configuration file that specifies the details of the embedding tables for the models to be deployed. Here we only show how to deploy a model with a single embedding table, but the configuration can describe multiple models, each with multiple embedding tables.
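Conceptually, hps.LookupLayer maps each key to the corresponding row of the trained table, and the default_value_for_each_table field of the HPS configuration supplies the value used to fill the vector of a key that is not found in the table. The dictionary-based sketch below only models that behavior to make the configuration fields concrete; it is not how HPS stores or caches embeddings, and hps_style_lookup is a hypothetical name. In this notebook every key from 0 to 49,999 is written to the sparse model, so the default value is never needed for in-range keys.

```python
def hps_style_lookup(table, keys, emb_vec_size, default_value=1.0):
    """Toy keyed lookup: known keys return their stored vector; unknown keys return a
    vector filled with default_value (modelled on default_value_for_each_table)."""
    default_vec = [default_value] * emb_vec_size
    return [[list(table.get(k, default_vec)) for k in row] for row in keys]

table = {0: [0.5, 0.5], 7: [2.0, 3.0]}
out = hps_style_lookup(table, [[0, 7, 99]], emb_vec_size=2)
print(out)  # [[[0.5, 0.5], [2.0, 3.0], [1.0, 1.0]]]
```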

%%writefile hps_tf_triton.json
{
    "supportlonglong": true,
    "models": [{
        "model": "hps_tf_triton",
        "sparse_files": ["/hugectr/hierarchical_parameter_server/notebooks/model_repo/hps_tf_triton/1/hps_tf_triton_sparse_0.model"],
        "num_of_worker_buffer_in_pool": 3,
        "embedding_table_names":["sparse_embedding1"],
        "embedding_vecsize_per_table": [16],
        "maxnum_catfeature_query_per_table_per_sample": [5],
        "default_value_for_each_table": [1.0],
        "deployed_device_list": [0],
        "max_batch_size": 1024,
        "cache_refresh_percentage_per_iteration": 0.2,
        "hit_rate_threshold": 1.0,
        "gpucacheper": 1.0,
        "gpucache": true
        }
    ]
}
Writing hps_tf_triton.json
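Several fields in hps_tf_triton.json mirror the model arguments in this notebook: embedding_vecsize_per_table matches embed_vec_size, maxnum_catfeature_query_per_table_per_sample matches slot_num (one key per slot per sample), max_batch_size matches global_batch_size, "supportlonglong" matches the int64 key type, and "model" matches the model_name passed to hps.LookupLayer. The standard-library sketch below is one way to catch accidental mismatches before deployment; check_ps_config is a hypothetical helper that assumes a single-table model and only inspects the fields used here.

```python
import json

def check_ps_config(config, model_name, embed_vec_size, slot_num, global_batch_size, int64_keys=True):
    """Return a list of mismatches between an HPS JSON config (single-table model) and the model arguments."""
    models = {m["model"]: m for m in config.get("models", [])}
    if model_name not in models:
        return [f"model {model_name!r} not found in config"]
    m = models[model_name]
    expected = {
        "embedding_vecsize_per_table": [embed_vec_size],
        "maxnum_catfeature_query_per_table_per_sample": [slot_num],
        "max_batch_size": global_batch_size,
    }
    problems = [f"{field}={m.get(field)!r}, expected {value!r}"
                for field, value in expected.items() if m.get(field) != value]
    if config.get("supportlonglong") != int64_keys:
        problems.append(f"supportlonglong={config.get('supportlonglong')!r}, expected {int64_keys!r}")
    return problems

# In the notebook you would load the real file instead:
#     with open("hps_tf_triton.json") as f:
#         config = json.load(f)
config = json.loads("""{
  "supportlonglong": true,
  "models": [{"model": "hps_tf_triton",
              "embedding_vecsize_per_table": [16],
              "maxnum_catfeature_query_per_table_per_sample": [5],
              "max_batch_size": 1024}]
}""")
print(check_ps_config(config, "hps_tf_triton", embed_vec_size=16, slot_num=5, global_batch_size=1024))  # []
```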
triton_model_repo = "/hugectr/hierarchical_parameter_server/notebooks/model_repo/hps_tf_triton/"

class InferenceModel(tf.keras.models.Model):
    def __init__(self,
                 slot_num,
                 embed_vec_size,
                 dense_dim,
                 dense_model_path,
                 **kwargs):
        super(InferenceModel, self).__init__(**kwargs)
        
        self.slot_num = slot_num
        self.embed_vec_size = embed_vec_size
        self.dense_dim = dense_dim
        self.lookup_layer = hps.LookupLayer(model_name = "hps_tf_triton", 
                                            table_id = 0,
                                            emb_vec_size = self.embed_vec_size,
                                            emb_vec_dtype = args["tf_vector_type"],
                                            ps_config_file = triton_model_repo + args["ps_config_file"],
                                            global_batch_size = args["global_batch_size"],
                                            name = "lookup")
        self.dense_model = tf.keras.models.load_model(dense_model_path)

    def call(self, inputs):
        keys, dense_features = inputs[0], inputs[1]
        embedding_vector = self.lookup_layer(keys)
        embedding_vector = tf.reshape(embedding_vector, shape=[-1, self.slot_num * self.embed_vec_size])
        logit = self.dense_model([embedding_vector, dense_features])
        return logit

    def summary(self):
        inputs = [tf.keras.Input(shape=(self.slot_num, ), dtype=args["tf_key_type"]),
                  tf.keras.Input(shape=(self.dense_dim, ), dtype=tf.float32)]
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()
def create_and_save_inference_graph(args): 
    model = InferenceModel(args["slot_num"], args["embed_vec_size"], args["dense_dim"], args["dense_model_path"])
    model.summary()
    _ = model([tf.keras.Input(shape=(args["slot_num"], ), dtype=args["tf_key_type"]),
               tf.keras.Input(shape=(args["dense_dim"], ), dtype=tf.float32)])
    model.save(args["saved_path"])
def convert_to_sparse_model(embeddings_weights, embedding_table_path, embedding_vec_size):
    # Write the table as a directory with a "key" file (one int64 per row) and an
    # "emb_vector" file (embedding_vec_size float32 values per row).
    # Row i is stored under key i, matching how the training keys index the table rows.
    os.makedirs(embedding_table_path, exist_ok=True)
    with open("{}/key".format(embedding_table_path), 'wb') as key_file, \
         open("{}/emb_vector".format(embedding_table_path), 'wb') as vec_file:
        for key in range(embeddings_weights.shape[0]):
            vec = embeddings_weights[key]
            key_struct = struct.pack('q', key)
            vec_struct = struct.pack(str(embedding_vec_size) + "f", *vec)
            key_file.write(key_struct)
            vec_file.write(vec_struct)
convert_to_sparse_model(embedding_weights, args["embedding_table_path"], args["embed_vec_size"])
create_and_save_inference_graph(args)
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_4 (InputLayer)           [(None, 5)]          0           []                               
                                                                                                  
 lookup (LookupLayer)           (None, 5, 16)        0           ['input_4[0][0]']                
                                                                                                  
 tf.reshape_1 (TFOpLambda)      (None, 80)           0           ['lookup[0][0]']                 
                                                                                                  
 input_5 (InputLayer)           [(None, 10)]         0           []                               
                                                                                                  
 model_1 (Functional)           (None, 1)            23553       ['tf.reshape_1[0][0]',           
                                                                  'input_5[0][0]']                
                                                                                                  
==================================================================================================
Total params: 23,553
Trainable params: 23,553
Non-trainable params: 0
__________________________________________________________________________________________________
INFO:tensorflow:Assets written to: hps_tf_triton_tf_saved_model/assets
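convert_to_sparse_model writes two flat binary files: key holds one 8-byte integer per row (struct format 'q') and emb_vector holds embed_vec_size 4-byte floats per row (format 'f'), both in native byte order. For the 50,000 × 16 table here that is 50,000 × 8 = 400,000 bytes of keys and 50,000 × 16 × 4 = 3,200,000 bytes of vectors. The sketch below writes a tiny table in the same layout to a temporary directory and reads it back with struct, which can serve as a sanity check; write_sparse_model and read_sparse_model are hypothetical helpers, not part of HPS.

```python
import os
import struct
import tempfile

def write_sparse_model(rows, path):
    """Same layout as convert_to_sparse_model: key = int64 per row, emb_vector = float32s per row."""
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "key"), "wb") as kf, open(os.path.join(path, "emb_vector"), "wb") as vf:
        for key, vec in enumerate(rows):
            kf.write(struct.pack("q", key))
            vf.write(struct.pack(f"{len(vec)}f", *vec))

def read_sparse_model(path, emb_vec_size):
    """Read both files back into {key: [floats]}, checking that their sizes are consistent."""
    with open(os.path.join(path, "key"), "rb") as kf:
        key_bytes = kf.read()
    with open(os.path.join(path, "emb_vector"), "rb") as vf:
        vec_bytes = vf.read()
    key_size = struct.calcsize("q")
    vec_size = struct.calcsize(f"{emb_vec_size}f")
    num_keys = len(key_bytes) // key_size
    if len(key_bytes) % key_size or len(vec_bytes) != num_keys * vec_size:
        raise ValueError("key and emb_vector files have inconsistent sizes")
    keys = struct.unpack(f"{num_keys}q", key_bytes)
    return {k: list(struct.unpack_from(f"{emb_vec_size}f", vec_bytes, i * vec_size))
            for i, k in enumerate(keys)}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_sparse_0.model")
    write_sparse_model([[1.0, 2.0], [0.5, -0.5], [3.0, 4.0]], path)
    table = read_sparse_model(path, emb_vec_size=2)
print(table)  # {0: [1.0, 2.0], 1: [0.5, -0.5], 2: [3.0, 4.0]}
```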

Deploy SavedModel using HPS with Triton TensorFlow Backend

To deploy the inference SavedModel with the Triton TensorFlow backend, we first need to create the model repository and define config.pbtxt.

!mkdir -p model_repo/hps_tf_triton/1
!mv hps_tf_triton_tf_saved_model model_repo/hps_tf_triton/1/model.savedmodel
!mv hps_tf_triton_sparse_0.model model_repo/hps_tf_triton/1
!mv hps_tf_triton.json model_repo/hps_tf_triton
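The !mkdir and !mv commands above can also be expressed with pathlib and shutil, which is convenient outside Jupyter. The sketch assumes the three artifacts are available under the names used in args; arrange_model_repo is a hypothetical helper, and the demo runs against dummy files in a temporary directory rather than the real model.

```python
import shutil
import tempfile
from pathlib import Path

def arrange_model_repo(repo_root, model_name, saved_model_dir, sparse_model_dir, ps_config_file, version=1):
    """Move the artifacts into <repo_root>/<model_name>/<version>/, mirroring the shell commands."""
    model_dir = Path(repo_root) / model_name
    version_dir = model_dir / str(version)
    version_dir.mkdir(parents=True, exist_ok=True)
    shutil.move(str(saved_model_dir), str(version_dir / "model.savedmodel"))
    shutil.move(str(sparse_model_dir), str(version_dir / Path(sparse_model_dir).name))
    shutil.move(str(ps_config_file), str(model_dir / Path(ps_config_file).name))
    return model_dir

# Demo with dummy artifacts; in the notebook you would pass the real paths.
with tempfile.TemporaryDirectory() as tmp:
    work = Path(tmp)
    (work / "hps_tf_triton_tf_saved_model").mkdir()
    (work / "hps_tf_triton_tf_saved_model" / "saved_model.pb").touch()
    (work / "hps_tf_triton_sparse_0.model").mkdir()
    (work / "hps_tf_triton_sparse_0.model" / "key").touch()
    (work / "hps_tf_triton.json").write_text("{}")
    arrange_model_repo(work / "model_repo", "hps_tf_triton",
                       work / "hps_tf_triton_tf_saved_model",
                       work / "hps_tf_triton_sparse_0.model",
                       work / "hps_tf_triton.json")
    layout = sorted(p.relative_to(work / "model_repo").as_posix()
                    for p in (work / "model_repo").rglob("*"))
print("\n".join(layout))
```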
%%writefile model_repo/hps_tf_triton/config.pbtxt
name: "hps_tf_triton"
platform: "tensorflow_savedmodel"
max_batch_size:1024
input [
  {
    name: "input_1"
    data_type: TYPE_INT64
    dims: [5]
  },
  {
    name: "input_2"
    data_type: TYPE_FP32
    dims: [10]
  }
]
output [
  {
    name: "output_1"
    data_type: TYPE_FP32
    dims: [1]
  }
]
version_policy: {
        specific:{versions: 1}
},
instance_group [
  {
    count: 1
    kind : KIND_GPU
    gpus: [0]
  }
]
Writing model_repo/hps_tf_triton/config.pbtxt
!tree model_repo/hps_tf_triton
model_repo/hps_tf_triton
├── 1
│   ├── hps_tf_triton_sparse_0.model
│   │   ├── emb_vector
│   │   └── key
│   └── model.savedmodel
│       ├── assets
│       ├── keras_metadata.pb
│       ├── saved_model.pb
│       └── variables
│           ├── variables.data-00000-of-00001
│           └── variables.index
├── config.pbtxt
└── hps_tf_triton.json

5 directories, 8 files

We can then launch the Triton Inference Server with the TensorFlow backend. Note that LD_PRELOAD is used to load the custom TensorFlow operations (i.e., the HPS-related operations) into Triton. For more information, please refer to TensorFlow Custom Operations in Triton.

Note: Because Jupyter does not support background processes, please launch the Triton server independently in the background with the following command.

LD_PRELOAD=/usr/local/lib/python3.8/dist-packages/merlin_hps-1.0.0-py3.8-linux-x86_64.egg/hierarchical_parameter_server/lib/libhierarchical_parameter_server.so tritonserver --model-repository=/hugectr/hierarchical_parameter_server/notebooks/model_repo --backend-config=tensorflow,version=2 --load-model=hps_tf_triton --model-control-mode=explicit
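Since the server has to run outside Jupyter, it can help to assemble the command in one place, for example in a launcher script. The sketch below builds the same tritonserver invocation as an argument list with LD_PRELOAD set in its environment, and polls Triton's HTTP readiness endpoint (GET /v2/health/ready) until the server answers. The library path is the one from the command above and may differ in your container; build_triton_command and wait_until_ready are hypothetical helpers, and the sketch does not start the server itself.

```python
import os
import time
import urllib.error
import urllib.request

HPS_LIB = ("/usr/local/lib/python3.8/dist-packages/merlin_hps-1.0.0-py3.8-linux-x86_64.egg/"
           "hierarchical_parameter_server/lib/libhierarchical_parameter_server.so")

def build_triton_command(model_repository, model_name, hps_lib=HPS_LIB):
    """Return (argv, env) for launching tritonserver with the HPS custom ops preloaded."""
    argv = [
        "tritonserver",
        f"--model-repository={model_repository}",
        "--backend-config=tensorflow,version=2",
        f"--load-model={model_name}",
        "--model-control-mode=explicit",
    ]
    env = dict(os.environ, LD_PRELOAD=hps_lib)
    return argv, env

def wait_until_ready(url="localhost:8000", timeout_s=60.0, interval_s=1.0):
    """Poll GET /v2/health/ready until it returns HTTP 200 or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(f"http://{url}/v2/health/ready", timeout=interval_s) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # not reachable or not ready yet; retry until the deadline
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)

argv, env = build_triton_command("/hugectr/hierarchical_parameter_server/notebooks/model_repo", "hps_tf_triton")
print(" ".join(argv))
# From a terminal-launched script: subprocess.Popen(argv, env=env), then wait_until_ready()
```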

We can then send requests to the Triton Inference Server using the HTTP client. Note that HPS is initialized implicitly when the server processes the first request, so the first request can have a higher latency than later ones.
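tritonclient wraps Triton's HTTP/REST inference API, which follows the KServe v2 inference protocol. By default, set_data_from_numpy sends tensors using Triton's binary-data extension, but the same request can also be written as plain JSON and POSTed to /v2/models/hps_tf_triton/infer. The sketch below builds such a body with the standard library only, using the input and output names from config.pbtxt; build_infer_request is a hypothetical helper, and it flattens each tensor in row-major order for the "data" field.

```python
import json

def flatten(nested):
    """Row-major flatten of a nested list."""
    if isinstance(nested, list):
        return [x for item in nested for x in flatten(item)]
    return [nested]

def shape_of(nested):
    """Shape of a rectangular nested list, e.g. [[1, 2, 3], [4, 5, 6]] -> [2, 3]."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return shape

def build_infer_request(keys, dense):
    """JSON body for POST /v2/models/hps_tf_triton/infer."""
    body = {
        "inputs": [
            {"name": "input_1", "shape": shape_of(keys), "datatype": "INT64", "data": flatten(keys)},
            {"name": "input_2", "shape": shape_of(dense), "datatype": "FP32", "data": flatten(dense)},
        ],
        "outputs": [{"name": "output_1"}],
    }
    return json.dumps(body)

keys = [[1, 10001, 20001, 30001, 40001], [2, 10002, 20002, 30002, 40002]]
dense = [[0.5] * 10, [0.25] * 10]
request_json = build_infer_request(keys, dense)
print(request_json[:80])
# To send it while the server is running:
#   req = urllib.request.Request("http://localhost:8000/v2/models/hps_tf_triton/infer",
#                                data=request_json.encode(), headers={"Content-Type": "application/json"})
#   print(urllib.request.urlopen(req).read())
```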

import os
num_gpu = 1
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(num_gpu)))

import numpy as np
import tritonclient.http as httpclient
from tritonclient.utils import np_to_triton_dtype

triton_client = httpclient.InferenceServerClient(url="localhost:8000", verbose=True)
triton_client.is_server_live()
triton_client.get_model_repository_index()

num_requests = 5
num_samples = 16

for i in range(num_requests):
    print("--------------------------Request {}--------------------------".format(i))
    key_tensor, dense_tensor, _ = generate_random_samples(num_samples, args["vocabulary_range_per_slot"], args["dense_dim"])

    inputs = [
        httpclient.InferInput("input_1", 
                              key_tensor.shape,
                              np_to_triton_dtype(np.int64)),
        httpclient.InferInput("input_2", 
                              dense_tensor.shape,
                              np_to_triton_dtype(np.float32)),
    ]

    inputs[0].set_data_from_numpy(key_tensor)
    inputs[1].set_data_from_numpy(dense_tensor)
    outputs = [
        httpclient.InferRequestedOutput("output_1")
    ]

    # print("Input key tensor is \n{}".format(key_tensor))
    # print("Input dense tensor is \n{}".format(dense_tensor))
    model_name = "hps_tf_triton"
    with httpclient.InferenceServerClient("localhost:8000") as client:
        response = client.infer(model_name,
                                inputs,
                                outputs=outputs)
        result = response.get_response()

        print("Prediction result:\n{}".format(response.as_numpy("output_1")))
        print("Response details:\n{}".format(result))
GET /v2/health/live, headers None
<HTTPSocketPoolResponse status=200 headers={'content-length': '0', 'content-type': 'text/plain'}>
POST /v2/repository/index, headers None

<HTTPSocketPoolResponse status=200 headers={'content-type': 'application/json', 'content-length': '56'}>
bytearray(b'[{"name":"hps_tf_triton","version":"1","state":"READY"}]')
--------------------------Request 0--------------------------
Prediction result:
[[102.5896  ]
 [109.57728 ]
 [105.14982 ]
 [112.536064]
 [115.500206]
 [122.54994 ]
 [118.42626 ]
 [116.46372 ]
 [109.31565 ]
 [123.4528  ]
 [106.171906]
 [108.97251 ]
 [ 92.7147  ]
 [ 92.135   ]
 [103.82919 ]
 [119.42267 ]]
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '1', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [16, 1], 'parameters': {'binary_data_size': 64}}]}
--------------------------Request 1--------------------------
Prediction result:
[[108.39017 ]
 [111.06535 ]
 [112.19471 ]
 [123.16409 ]
 [ 84.799545]
 [115.29873 ]
 [104.93053 ]
 [ 92.51399 ]
 [111.03866 ]
 [118.73721 ]
 [119.996704]
 [111.64917 ]
 [111.96098 ]
 [106.95992 ]
 [118.8165  ]
 [102.89214 ]]
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '1', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [16, 1], 'parameters': {'binary_data_size': 64}}]}
--------------------------Request 2--------------------------
Prediction result:
[[115.43975 ]
 [104.05401 ]
 [ 95.1138  ]
 [109.50248 ]
 [117.69166 ]
 [111.92008 ]
 [ 99.65907 ]
 [ 91.395035]
 [103.35495 ]
 [115.99719 ]
 [114.05845 ]
 [ 90.95559 ]
 [110.51797 ]
 [105.39578 ]
 [104.69898 ]
 [ 97.37328 ]]
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '1', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [16, 1], 'parameters': {'binary_data_size': 64}}]}
--------------------------Request 3--------------------------
Prediction result:
[[119.3963  ]
 [125.18664 ]
 [123.5703  ]
 [112.66611 ]
 [ 99.078514]
 [105.94452 ]
 [102.65439 ]
 [111.734314]
 [ 91.28878 ]
 [104.32374 ]
 [117.849236]
 [102.520256]
 [115.76198 ]
 [102.74941 ]
 [101.43743 ]
 [115.935295]]
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '1', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [16, 1], 'parameters': {'binary_data_size': 64}}]}
--------------------------Request 4--------------------------
Prediction result:
[[111.10583 ]
 [107.11977 ]
 [ 93.52705 ]
 [119.05332 ]
 [116.221054]
 [101.29974 ]
 [111.83873 ]
 [106.383804]
 [ 95.079666]
 [118.63491 ]
 [101.41594 ]
 [107.80403 ]
 [113.61658 ]
 [104.66516 ]
 [107.94727 ]
 [106.81342 ]]
Response details:
{'model_name': 'hps_tf_triton', 'model_version': '1', 'outputs': [{'name': 'output_1', 'datatype': 'FP32', 'shape': [16, 1], 'parameters': {'binary_data_size': 64}}]}