# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
 
Hierarchical Parameter Server Demo
Overview
Hierarchical Parameter Server (HPS) is a distributed recommendation inference framework that combines a high-performance GPU embedding cache with a hierarchical storage architecture to provide low-latency retrieval of embeddings for inference tasks. It is provided as a Python toolkit and can be easily integrated into a TensorFlow (TF) model graph.
This notebook demonstrates how to apply HPS to a trained model and then use it for inference in TensorFlow. For more details about the HPS APIs, refer to HPS APIs. For more details about HPS itself, refer to HugeCTR Hierarchical Parameter Server (HPS).
Installation
Get HPS from NGC
The HPS Python module is preinstalled in Merlin TensorFlow containers 23.08 and later, e.g. nvcr.io/nvidia/merlin/merlin-tensorflow:23.08.
After launching the container, you can verify that the required library is available by running the following command:
$ python3 -c "import hierarchical_parameter_server as hps"
Configurations
First, we specify the required configurations, e.g., the arguments for generating the dataset and the paths for saving the model and its parameters. In this notebook we use a naive deep neural network (DNN) model that has one embedding table and several dense layers.
import hierarchical_parameter_server as hps
import os
import numpy as np
import tensorflow as tf
import struct
# Configuration shared by every step of this demo.
args = {
    "gpu_num": 1,                      # number of GPUs made visible to TensorFlow
    "iter_num": 10,                    # number of batches generated for training and for inference
    "slot_num": 3,                     # number of feature fields in the embedding layer
    "embed_vec_size": 16,              # dimension of each embedding vector
    "global_batch_size": 65536,        # global batch size across all GPUs
    "max_vocabulary_size": 30000,      # number of rows in the embedding table
    # [low, high) key range of each slot; high is exclusive (np.random.randint).
    "vocabulary_range_per_slot": [[0, 10000], [10000, 20000], [20000, 30000]],
    "ps_config_file": "naive_dnn.json",                # HPS JSON config passed to hps.Init
    "dense_model_path": "naive_dnn_dense.model",       # saved Keras model of the dense layers
    "embedding_table_path": "naive_dnn_sparse.model",  # directory for the HPS key/vector files
    "saved_path": "naive_dnn_tf_saved_model",          # saved inference graph
    "np_key_type": np.int64,
    "np_vector_type": np.float32,
    "tf_key_type": tf.int64,
    "tf_vector_type": tf.float32,
}
# Make only the first gpu_num devices visible, e.g. "0" for a single GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(device) for device in range(args["gpu_num"]))
[INFO] hierarchical_parameter_server is imported
def generate_random_samples(num_samples, vocabulary_range_per_slot, key_dtype = args["np_key_type"]):
    """Generate random categorical keys and binary labels for the demo.

    Args:
        num_samples: number of samples (rows) to generate.
        vocabulary_range_per_slot: one ``[low, high)`` pair per slot; keys for
            slot ``i`` are drawn uniformly from that half-open range.
        key_dtype: NumPy dtype of the generated keys. The default is
            ``args["np_key_type"]``, evaluated once when this function is defined.

    Returns:
        ``(keys, labels)``. ``keys`` has shape
        ``(num_samples, len(vocabulary_range_per_slot))`` and dtype ``key_dtype``.
        ``labels`` has shape ``(num_samples, 1)`` with values in {0, 1}.
    """
    # One (num_samples, 1) column per slot, drawn in slot order.
    key_columns = [
        np.random.randint(low=vocab_range[0], high=vocab_range[1],
                          size=(num_samples, 1), dtype=key_dtype)
        for vocab_range in vocabulary_range_per_slot
    ]
    # Concatenate the list of columns directly. Wrapping it in np.array() first
    # would materialise an extra (slot_num, num_samples, 1) copy of every key.
    keys = np.concatenate(key_columns, axis=1)
    labels = np.random.randint(low=0, high=2, size=(num_samples, 1))
    return keys, labels
def tf_dataset(keys, labels, batchsize):
    """Wrap ``(keys, labels)`` in a ``tf.data.Dataset`` of fixed-size batches.

    Samples are sliced along the first axis and grouped into batches of
    ``batchsize``. A trailing partial batch is dropped, so every batch has
    exactly ``batchsize`` samples.
    """
    samples = tf.data.Dataset.from_tensor_slices((keys, labels))
    return samples.batch(batchsize, drop_remainder=True)
Train with native TF layers
We define the model graph for training with native TF layers, i.e., tf.nn.embedding_lookup and tf.keras.layers.Dense. The embedding weights are stored in a tf.Variable. We then train the model and extract the trained weights of the embedding table. The dense layers are saved as a separate model graph, which can be loaded directly during inference.
class TrainModel(tf.keras.models.Model):
    """Naive DNN used for training with native TF ops.

    Keys are looked up in a plain ``tf.Variable`` (``self.params``) with
    ``tf.nn.embedding_lookup``. The looked-up vectors are flattened and passed
    through two Dense layers without activation: ``fc_1`` (256 units) and
    ``fc_2`` (1 unit).

    Args:
        init_tensors: initial value of the embedding table, passed through
            ``tf.concat(..., axis=0)``. ``train()`` passes a single
            (max_vocabulary_size, embed_vec_size) NumPy array of ones.
        slot_num: number of keys (feature fields) per sample.
        embed_vec_size: dimension of each embedding vector.
        **kwargs: forwarded to ``tf.keras.models.Model``.
    """
    def __init__(self,
                 init_tensors,
                 slot_num,
                 embed_vec_size,
                 **kwargs):
        super(TrainModel, self).__init__(**kwargs)

        self.slot_num = slot_num
        self.embed_vec_size = embed_vec_size
        self.init_tensors = init_tensors
        # The whole embedding table is one trainable variable, indexed directly by key.
        self.params = tf.Variable(initial_value=tf.concat(self.init_tensors, axis=0))
        # Constant "ones"/"zeros" initializers make the dense weights deterministic
        # (presumably so the demo is reproducible). The layer names matter: the
        # notebook later retrieves these layers with get_layer("fc_1") / get_layer("fc_2").
        self.fc_1 = tf.keras.layers.Dense(units=256, activation=None,
                                                 kernel_initializer="ones",
                                                 bias_initializer="zeros",
                                                 name='fc_1')
        self.fc_2 = tf.keras.layers.Dense(units=1, activation=None,
                                                 kernel_initializer="ones",
                                                 bias_initializer="zeros",
                                                 name='fc_2')
    def call(self, inputs):
        """Look up and flatten the embeddings, then apply the dense layers.

        Args:
            inputs: integer keys; expected shape (batch, slot_num). The reshape
                below assumes slot_num lookups per sample.

        Returns:
            ``(logit, embedding_vector)``. ``logit`` comes from the 1-unit
            ``fc_2``. ``embedding_vector`` is reshaped to
            (-1, slot_num * embed_vec_size).
        """
        embedding_vector = tf.nn.embedding_lookup(params=self.params, ids=inputs)
        embedding_vector = tf.reshape(embedding_vector, shape=[-1, self.slot_num * self.embed_vec_size])
        logit = self.fc_2(self.fc_1(embedding_vector))
        return logit, embedding_vector
    def summary(self):
        """Print a layer summary by wrapping ``call`` in a functional model.

        The input dtype comes from the module-level ``args["tf_key_type"]``.
        NOTE(review): this overrides ``tf.keras.Model.summary`` with a narrower
        signature (no ``line_length``/``print_fn`` arguments).
        """
        inputs = tf.keras.Input(shape=(self.slot_num,), dtype=args["tf_key_type"])
        # Building this functional model runs call() on a symbolic input. TF then
        # warns that self.params is not tracked by the resulting embedding_lookup
        # Lambda op, which does not matter for printing a summary.
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()
def train(args):
    """Train ``TrainModel`` on random data and return the trained model.

    The model is trained with Adam (learning rate 0.1) and binary cross-entropy
    on logits, for ``args["iter_num"]`` batches of ``args["global_batch_size"]``
    randomly generated samples. The step index and loss are printed after
    every batch.
    """
    embedding_shape = [args["max_vocabulary_size"], args["embed_vec_size"]]
    init_tensors = np.ones(shape=embedding_shape, dtype=args["np_vector_type"])

    model = TrainModel(init_tensors, args["slot_num"], args["embed_vec_size"])
    model.summary()
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def _train_step(batch_ids, batch_labels):
        """Run one forward/backward pass and apply a single optimizer update."""
        with tf.GradientTape() as tape:
            logit, embedding_vector = model(batch_ids)
            loss = loss_fn(batch_labels, logit)
        variables = model.trainable_variables
        grads = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))
        return logit, embedding_vector, loss

    total_samples = args["global_batch_size"] * args["iter_num"]
    keys, labels = generate_random_samples(total_samples, args["vocabulary_range_per_slot"], args["np_key_type"])
    batches = tf_dataset(keys, labels, args["global_batch_size"])
    separator = "-" * 20
    for step, (batch_ids, batch_labels) in enumerate(batches):
        _, _, loss = _train_step(batch_ids, batch_labels)
        print(separator, "Step {}, loss: {}".format(step, loss), separator)
    return model
trained_model = train(args)
weights_list = trained_model.get_weights()
# Read the embedding table from the variable that holds it. The previous
# weights_list[-1] depended on where Keras happens to place the model's own
# variables in get_weights(), and that position would shift if layers or
# variables were added.
embedding_weights = trained_model.params.numpy()
# Rebuild the dense part (fc_1 -> fc_2) as a standalone functional model.
# fc_1.input / fc_2.output refer to the symbolic graph created when train()
# called model.summary(), which runs call() on a tf.keras.Input.
dense_model = tf.keras.models.Model(trained_model.get_layer("fc_1").input, trained_model.get_layer("fc_2").output)
dense_model.summary()
dense_model.save(args["dense_model_path"])
2022-07-12 07:49:56.742983: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
WARNING:tensorflow:The following Variables were used in a Lambda layer's call (tf.compat.v1.nn.embedding_lookup), but are not present in its tracked objects:   <tf.Variable 'Variable:0' shape=(30000, 16) dtype=float32>. This is a strong indication that the Lambda layer should be rewritten as a subclassed Layer.
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 3)]               0         
                                                                 
 tf.compat.v1.nn.embedding_l  (None, 3, 16)            0         
 ookup (TFOpLambda)                                              
                                                                 
 tf.reshape (TFOpLambda)     (None, 48)                0         
                                                                 
 fc_1 (Dense)                (None, 256)               12544     
                                                                 
 fc_2 (Dense)                (None, 1)                 257       
                                                                 
=================================================================
Total params: 12,801
Trainable params: 12,801
Non-trainable params: 0
_________________________________________________________________
2022-07-12 07:49:57.326494: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30989 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
-------------------- Step 0, loss: 6136.6875 --------------------
-------------------- Step 1, loss: 4463.05712890625 --------------------
-------------------- Step 2, loss: 3192.029296875 --------------------
-------------------- Step 3, loss: 2180.40283203125 --------------------
-------------------- Step 4, loss: 1419.980712890625 --------------------
-------------------- Step 5, loss: 879.0396728515625 --------------------
-------------------- Step 6, loss: 513.3021240234375 --------------------
-------------------- Step 7, loss: 272.9712219238281 --------------------
-------------------- Step 8, loss: 129.147705078125 --------------------
-------------------- Step 9, loss: 48.21624755859375 --------------------
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 48)]              0         
                                                                 
 fc_1 (Dense)                (None, 256)               12544     
                                                                 
 fc_2 (Dense)                (None, 1)                 257       
                                                                 
=================================================================
Total params: 12,801
Trainable params: 12,801
Non-trainable params: 0
_________________________________________________________________
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
2022-07-12 07:49:59.645703: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: naive_dnn_dense.model/assets
Create the inference graph with HPS LookupLayer
To use HPS in the inference stage, we create an inference model graph that is almost identical to the training graph, except that tf.nn.embedding_lookup is replaced by hps.LookupLayer. The trained dense model graph can be loaded directly, while the embedding weights must be converted to the format required by HPS.
We can then save the inference model graph, which will be ready to be loaded for inference deployment.
class InferenceModel(tf.keras.models.Model):
    """Inference graph: HPS embedding lookup followed by the trained dense model.

    This mirrors ``TrainModel``, except that ``tf.nn.embedding_lookup`` is
    replaced by ``hps.LookupLayer`` and the dense layers are loaded from the
    saved Keras model at ``dense_model_path``.

    Args:
        slot_num: number of keys (feature fields) per sample.
        embed_vec_size: dimension of each embedding vector returned by HPS.
        dense_model_path: path of the saved dense Keras model to load.
        **kwargs: forwarded to ``tf.keras.models.Model``.
    """
    def __init__(self,
                 slot_num,
                 embed_vec_size,
                 dense_model_path,
                 **kwargs):
        super(InferenceModel, self).__init__(**kwargs)

        self.slot_num = slot_num
        self.embed_vec_size = embed_vec_size
        # model_name / table_id select the table in the HPS config. "naive_dnn"
        # matches the "model" entry of naive_dnn.json. table_id 0 presumably
        # refers to the first entry of its embedding_table_names (verify).
        # emb_vec_size should agree with "embedding_vecsize_per_table" there.
        self.lookup_layer = hps.LookupLayer(model_name = "naive_dnn",
                                            table_id = 0,
                                            emb_vec_size = self.embed_vec_size,
                                            emb_vec_dtype = args["tf_vector_type"],
                                            name = "lookup")
        self.dense_model = tf.keras.models.load_model(dense_model_path)
    def call(self, inputs):
        """Look up embeddings through HPS, flatten them, and apply the dense model.

        Args:
            inputs: integer keys; expected shape (batch, slot_num) (assumption
                based on the Input built in ``summary``).

        Returns:
            ``(logit, embedding_vector)``. ``embedding_vector`` is reshaped to
            (-1, slot_num * embed_vec_size).
        """
        embedding_vector = self.lookup_layer(inputs)
        embedding_vector = tf.reshape(embedding_vector, shape=[-1, self.slot_num * self.embed_vec_size])
        logit = self.dense_model(embedding_vector)
        return logit, embedding_vector
    def summary(self):
        """Print a layer summary by wrapping ``call`` in a functional model.

        The input dtype comes from the module-level ``args["tf_key_type"]``.
        NOTE(review): this overrides ``tf.keras.Model.summary`` with a narrower
        signature.
        """
        inputs = tf.keras.Input(shape=(self.slot_num,), dtype=args["tf_key_type"])
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()
def create_and_save_inference_graph(args):
    """Build ``InferenceModel`` from ``args``, print its summary, and save it to ``args["saved_path"]``."""
    inference_model = InferenceModel(args["slot_num"], args["embed_vec_size"], args["dense_model_path"])
    inference_model.summary()
    # Call the model once on a symbolic Keras input before saving, presumably so
    # the model is built and has a traced call signature to export.
    symbolic_keys = tf.keras.Input(shape=(args["slot_num"],), dtype=args["tf_key_type"])
    _logit, _embedding = inference_model(symbolic_keys)
    inference_model.save(args["saved_path"])
def convert_to_sparse_model(embeddings_weights, embedding_table_path, embedding_vec_size):
    """Write an embedding table to disk as the key/vector files consumed by HPS.

    Two files are written inside ``embedding_table_path``; the directory and
    any missing parents are created:

    * ``key``: the row index of every embedding, as int64.
    * ``emb_vector``: every row in order, as float32.

    Both files use native byte order. This is the same byte layout that
    ``struct.pack('q', ...)`` and ``struct.pack('<N>f', ...)`` produce.

    Args:
        embeddings_weights: 2-D array-like of shape (num_keys, embedding_vec_size).
            Row ``i`` is stored under key ``i``.
        embedding_table_path: output directory.
        embedding_vec_size: expected number of columns in ``embeddings_weights``.

    Raises:
        ValueError: if ``embeddings_weights`` is not 2-D or its second dimension
            differs from ``embedding_vec_size``.
    """
    weights = np.asarray(embeddings_weights)
    if weights.ndim != 2 or weights.shape[1] != embedding_vec_size:
        raise ValueError(
            "expected embedding weights of shape (num_keys, {}), got {}".format(
                embedding_vec_size, weights.shape))
    # Create the directory without invoking a shell. Unlike `mkdir -p` through
    # os.system, path characters are never interpreted and failures raise.
    os.makedirs(embedding_table_path, exist_ok=True)
    # Write each file in one bulk call rather than packing row by row.
    keys = np.arange(weights.shape[0], dtype=np.int64)
    vectors = weights.astype(np.float32, copy=False)
    with open(os.path.join(embedding_table_path, "key"), "wb") as key_file:
        key_file.write(keys.tobytes())
    with open(os.path.join(embedding_table_path, "emb_vector"), "wb") as vec_file:
        vec_file.write(vectors.tobytes())
# Export the trained embedding table in the key/vector file format required by
# HPS, then build and save the inference graph that looks embeddings up via HPS.
convert_to_sparse_model(
    embeddings_weights=embedding_weights,
    embedding_table_path=args["embedding_table_path"],
    embedding_vec_size=args["embed_vec_size"],
)
create_and_save_inference_graph(args=args)
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_3 (InputLayer)        [(None, 3)]               0         
                                                                 
 lookup (LookupLayer)        (None, 3, 16)             0         
                                                                 
 tf.reshape_1 (TFOpLambda)   (None, 48)                0         
                                                                 
 model_1 (Functional)        (None, 1)                 12801     
                                                                 
=================================================================
Total params: 12,801
Trainable params: 12,801
Non-trainable params: 0
_________________________________________________________________
INFO:tensorflow:Assets written to: naive_dnn_tf_saved_model/assets
Inference with saved model graph
To initialize the lookup service provided by HPS, we also need to create a JSON configuration file that specifies the details of the embedding tables for the models to be deployed. Here we only show how to deploy a model with one embedding table, but HPS also supports deploying multiple models, each with multiple embedding tables.
We first call hps.Init to perform the necessary initialization, and then load the saved model graph to run inference. Finally, we peek at the keys and the embedding vectors (which have been reshaped from (None, 3, 16) to (None, 48)) of the last inference batch.
%%writefile naive_dnn.json
{
    "supportlonglong": true,
    "models": [{
        "model": "naive_dnn",
        "sparse_files": ["naive_dnn_sparse.model"],
        "num_of_worker_buffer_in_pool": 3,
        "embedding_table_names":["sparse_embedding1"],
        "embedding_vecsize_per_table": [16],
        "maxnum_catfeature_query_per_table_per_sample": [3],
        "default_value_for_each_table": [1.0],
        "deployed_device_list": [0],
        "max_batch_size": 65536,
        "cache_refresh_percentage_per_iteration": 0.2,
        "hit_rate_threshold": 1.0,
        "gpucacheper": 1.0,
        "gpucache": true
        }
    ]
}
Writing naive_dnn.json
def inference_with_saved_model(args):
    """Initialise HPS, load the saved inference graph, and run it on random batches.

    ``hps.Init`` is called with ``args["ps_config_file"]`` before the
    SavedModel at ``args["saved_path"]`` is loaded. The model then runs on
    ``args["iter_num"]`` freshly generated batches.

    Returns:
        ``(embedding_vectors, id_tensors)``: two lists with one entry per batch,
        holding the looked-up (flattened) embeddings and the input keys.
    """
    hps.Init(global_batch_size = args["global_batch_size"],
             ps_config_file = args["ps_config_file"])
    model = tf.keras.models.load_model(args["saved_path"])
    model.summary()

    peeked_embeddings = []
    peeked_ids = []
    total_samples = args["global_batch_size"] * args["iter_num"]
    keys, labels = generate_random_samples(total_samples, args["vocabulary_range_per_slot"], args["np_key_type"])
    batches = tf_dataset(keys, labels, args["global_batch_size"])
    separator = "-" * 20
    for step, (batch_ids, _batch_labels) in enumerate(batches):
        print(separator, "Step {}".format(step), separator)
        _logit, batch_embeddings = model(batch_ids)
        peeked_embeddings.append(batch_embeddings)
        peeked_ids.append(batch_ids)
    return peeked_embeddings, peeked_ids
# Run inference through the saved graph (HPS is initialised inside the call),
# then show the last batch: its flattened embeddings of shape
# (batch, slot_num * embed_vec_size) and the keys that produced them.
embedding_vectors_peek, id_tensors_peek = inference_with_saved_model(args)
print(embedding_vectors_peek[-1])
print(id_tensors_peek[-1])
=====================================================HPS Parse====================================================
[HCTR][07:50:25.009][INFO][RK0][main]: dense_file is not specified using default: 
[HCTR][07:50:25.009][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
[HCTR][07:50:25.009][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][07:50:25.009][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][07:50:25.009][INFO][RK0][main]: refresh_interval is not specified using default: 0
====================================================HPS Create====================================================
[HCTR][07:50:25.009][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][07:50:25.010][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][07:50:25.010][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][07:50:25.357][INFO][RK0][main]: Table: hps_et.naive_dnn.sparse_embedding1; cached 30000 / 30000 embeddings in volatile database (PreallocatedHashMapBackend); load: 30000 / 18446744073709551615 (0.00%).
[HCTR][07:50:25.357][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][07:50:25.357][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][07:50:25.363][INFO][RK0][main]: Model name: naive_dnn
[HCTR][07:50:25.363][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][07:50:25.363][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][07:50:25.363][INFO][RK0][main]: Use I64 input key: True
[HCTR][07:50:25.363][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][07:50:25.363][INFO][RK0][main]: The size of thread pool: 80
[HCTR][07:50:25.363][INFO][RK0][main]: The size of worker memory pool: 3
[HCTR][07:50:25.363][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][07:50:25.405][INFO][RK0][main]: Creating lookup session for naive_dnn on device: 0
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Model: "inference_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 lookup (LookupLayer)        multiple                  0         
                                                                 
 model_1 (Functional)        (None, 1)                 12801     
                                                                 
=================================================================
Total params: 12,801
Trainable params: 12,801
Non-trainable params: 0
_________________________________________________________________
-------------------- Step 0 --------------------
-------------------- Step 1 --------------------
-------------------- Step 2 --------------------
-------------------- Step 3 --------------------
-------------------- Step 4 --------------------
-------------------- Step 5 --------------------
-------------------- Step 6 --------------------
-------------------- Step 7 --------------------
-------------------- Step 8 --------------------
-------------------- Step 9 --------------------
tf.Tensor(
[[0.23265739 0.23265739 0.23265739 ... 0.11092357 0.11092357 0.11092357]
 [0.09594781 0.09594781 0.09594781 ... 0.16974597 0.16974597 0.16974597]
 [0.22555737 0.22555737 0.22555737 ... 0.20454781 0.20454781 0.20454781]
 ...
 [0.22397298 0.22397298 0.22397298 ... 0.1229516  0.1229516  0.1229516 ]
 [0.12451896 0.12451896 0.12451896 ... 0.21348731 0.21348731 0.21348731]
 [0.11943579 0.11943579 0.11943579 ... 0.2502464  0.2502464  0.2502464 ]], shape=(65536, 48), dtype=float32)
tf.Tensor(
[[ 5283 17773 26371]
 [ 5043 17928 22941]
 [ 5154 18816 28670]
 ...
 [ 9014 16185 22256]
 [ 9893 14515 25771]
 [ 5377 18265 28063]], shape=(65536, 3), dtype=int64)