HPS Table Fusion Demo
This notebook demonstrates how to fuse embedding tables of the same embedding vector size with the HPS plugin for TensorFlow. It is recommended to run hierarchical_parameter_server_demo.ipynb before diving into this notebook.
For more details about HPS APIs, please refer to HPS APIs. For more details about HPS, please refer to HugeCTR Hierarchical Parameter Server (HPS).
Get HPS from NGC
The HPS Python module is preinstalled in the 24.04 and later Merlin HugeCTR Container: nvcr.io/nvidia/merlin/merlin-hugectr:24.04
You can check the existence of the required libraries by running the following Python code after launching this container.
$ python3 -c "import hierarchical_parameter_server as hps"
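The one-liner above fails with an ImportError if the module is absent. As a minimal sketch of a gentler check, the snippet below only looks the modules up on the import path without importing them. The helper name module_available is hypothetical and is not part of HPS.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if the top-level module `name` can be located without importing it."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    # The two modules this notebook relies on
    for mod in ("hierarchical_parameter_server", "tensorflow"):
        status = "found" if module_available(mod) else "missing"
        print(f"{mod}: {status}")
```

Inside the Merlin HugeCTR container both lines should report "found"; elsewhere the check simply reports what is missing instead of raising.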
Create TF SavedModel
First, we specify the required configuration, such as the arguments needed to generate the embedding tables and the template HPS JSON configuration file. In this notebook we use a naive deep neural network (DNN) model with 8 embedding tables of the same embedding vector size and one fully connected layer.
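We define the model with hps.LookupLayer and some native TF layers, and then save it in the SavedModel format. Note that table fusion is turned off while the model is built and saved, by setting fuse_embedding_table to False in the HPS JSON configuration file. After the model is saved, the script rewrites the configuration file with fuse_embedding_table set to True.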
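The off-then-on switch described above can be sketched in isolation. This is an illustration under assumptions rather than the notebook's own code: write_hps_config is a hypothetical helper, the configuration dictionary is reduced to a few keys, and the file is written to a temporary directory.

```python
import json
import os
import tempfile


def write_hps_config(config: dict, path: str, fuse: bool) -> dict:
    """Write a copy of `config` to `path` with `fuse_embedding_table` set to `fuse`."""
    updated = dict(config)  # shallow copy; only a top-level key is changed
    updated["fuse_embedding_table"] = fuse
    with open(path, "w") as outfile:
        json.dump(updated, outfile, indent=4)
    return updated


if __name__ == "__main__":
    # Reduced stand-in for the real HPS configuration (assumption)
    base = {"supportlonglong": False, "fuse_embedding_table": True, "models": []}
    path = os.path.join(tempfile.mkdtemp(), "8_table.json")
    write_hps_config(base, path, fuse=False)  # while the model is built and saved
    write_hps_config(base, path, fuse=True)   # before running inference with fusion
    with open(path) as infile:
        print(json.load(infile)["fuse_embedding_table"])
```

Because the SavedModel only stores the path to the JSON file, the flag that applies is whichever value the file holds when the model is loaded.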
%%writefile create_model_for_table_fusion.py
import hierarchical_parameter_server as hps
import tensorflow as tf
import os
import numpy as np
import struct
import json
import pytest
import time
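# These values match the model summary and HPS log printed later in this notebook
NUM_GPUS = 1
VOCAB_SIZE = 10000
EMB_VEC_SIZE = 128
NUM_QUERY_KEY = 26
EMB_VEC_DTYPE = np.float32
TF_KEY_TYPE = tf.int32
MAX_BATCH_SIZE = 256
NUM_TABLES = 8
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(NUM_GPUS)))
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)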
hps_config = {
    "supportlonglong": False,
    "fuse_embedding_table": True,
    "models": [
        {
            "model": str(NUM_TABLES) + "_table",
            "sparse_files": [],
            "num_of_worker_buffer_in_pool": NUM_TABLES,
            "embedding_table_names": [],
            "embedding_vecsize_per_table": [],
            "maxnum_catfeature_query_per_table_per_sample": [],
            "default_value_for_each_table": [0.0],
            "deployed_device_list": [0],
            "max_batch_size": MAX_BATCH_SIZE,
            "cache_refresh_percentage_per_iteration": 1.0,
            "hit_rate_threshold": 1.0,
            "gpucacheper": 1.0,
            "gpucache": True,
            "embedding_cache_type": "dynamic",
            "use_context_stream": True,
        }
    ],
}
def generate_embedding_tables(hugectr_sparse_model, vocab_range, embedding_vec_size):
    os.system("mkdir -p {}".format(hugectr_sparse_model))
    with open("{}/key".format(hugectr_sparse_model), "wb") as key_file, open(
        "{}/emb_vector".format(hugectr_sparse_model), "wb"
    ) as vec_file:
        for key in range(vocab_range[0], vocab_range[1]):
            # Every embedding vector is filled with the same constant value
            vec = 0.00025 * np.ones((embedding_vec_size,)).astype(np.float32)
            # Keys are stored as int64 ("q"), vectors as float32 ("f")
            key_struct = struct.pack("q", key)
            vec_struct = struct.pack(str(embedding_vec_size) + "f", *vec)
            key_file.write(key_struct)
            vec_file.write(vec_struct)
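def set_up_model_files():
    for i in range(NUM_TABLES):
        table_name = "table" + str(i)
        model_file_name = "embeddings/" + table_name
        # Table i owns the key range [i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE)
        generate_embedding_tables(
            model_file_name, [i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE], EMB_VEC_SIZE
        )
        # Register the table in the per-table lists of the HPS configuration
        model_config = hps_config["models"][0]
        model_config["sparse_files"].append(model_file_name)
        model_config["embedding_table_names"].append(table_name)
        model_config["embedding_vecsize_per_table"].append(EMB_VEC_SIZE)
        model_config["maxnum_catfeature_query_per_table_per_sample"].append(NUM_QUERY_KEY)
    return hps_config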
class InferenceModel(tf.keras.models.Model):
    def __init__(self, num_tables, **kwargs):
        super(InferenceModel, self).__init__(**kwargs)
        # One HPS lookup layer per embedding table
        self.lookup_layers = []
        for i in range(num_tables):
            self.lookup_layers.append(
                hps.LookupLayer(
                    model_name=str(NUM_TABLES) + "_table",
                    table_id=i,
                    emb_vec_size=EMB_VEC_SIZE,
                    emb_vec_dtype=EMB_VEC_DTYPE,
                    ps_config_file=str(NUM_TABLES) + "_table.json",
                    global_batch_size=MAX_BATCH_SIZE,
                    name="embedding_lookup" + str(i),
                )
            )
        # Single fully connected layer; units and name match the printed summary
        self.fc = tf.keras.layers.Dense(units=1, name="fc")

    def call(self, inputs):
        assert len(inputs) == len(self.lookup_layers)
        embeddings = []
        for i in range(len(inputs)):
            # Flatten each (batch, NUM_QUERY_KEY, EMB_VEC_SIZE) lookup result
            embeddings.append(
                tf.reshape(
                    self.lookup_layers[i](inputs[i]), shape=[-1, NUM_QUERY_KEY * EMB_VEC_SIZE]
                )
            )
        concat_embeddings = tf.concat(embeddings, axis=1)
        logit = self.fc(concat_embeddings)
        return logit

    def summary(self):
        inputs = []
        for _ in range(len(self.lookup_layers)):
            inputs.append(tf.keras.Input(shape=(NUM_QUERY_KEY,), dtype=TF_KEY_TYPE))
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()
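def create_savedmodel(hps_config):
    # Overwrite JSON configuration file (fusion off while building the model)
    hps_config["fuse_embedding_table"] = False
    hps_config_json_object = json.dumps(hps_config, indent=4)
    with open(str(NUM_TABLES) + "_table.json", "w") as outfile:
        outfile.write(hps_config_json_object)
    model = InferenceModel(NUM_TABLES)
    model.summary()
    # Build the model with one batch of keys drawn from each table's key range
    inputs = []
    for i in range(NUM_TABLES):
        inputs.append(
            np.random.randint(
                i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE, (MAX_BATCH_SIZE, NUM_QUERY_KEY)
            ).astype(np.int32)
        )
    model(inputs)
    model.save(str(NUM_TABLES) + "_table.savedmodel")
    # Overwrite JSON configuration file (fusion on for later inference)
    hps_config["fuse_embedding_table"] = True
    hps_config_json_object = json.dumps(hps_config, indent=4)
    with open(str(NUM_TABLES) + "_table.json", "w") as outfile:
        outfile.write(hps_config_json_object)


if __name__ == "__main__":
    hps_config = set_up_model_files()
    create_savedmodel(hps_config)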
Writing create_model_for_table_fusion.py
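The script above stores each table as two raw binary files: key holds int64 values and emb_vector holds float32 values, written in the same order. The sketch below writes a tiny table in that layout and reads it back, which can help when inspecting the generated embeddings directory. The helpers write_table and read_table are hypothetical and use only the Python standard library.

```python
import os
import struct
import tempfile


def write_table(directory, keys, emb_vec_size, value=0.00025):
    """Write `key` (int64) and `emb_vector` (float32) files in the layout used above."""
    os.makedirs(directory, exist_ok=True)
    vec = [value] * emb_vec_size
    with open(os.path.join(directory, "key"), "wb") as key_file, open(
        os.path.join(directory, "emb_vector"), "wb"
    ) as vec_file:
        for key in keys:
            key_file.write(struct.pack("q", key))
            vec_file.write(struct.pack(str(emb_vec_size) + "f", *vec))


def read_table(directory, emb_vec_size):
    """Read both files back as a list of keys and a list of embedding vectors."""
    with open(os.path.join(directory, "key"), "rb") as key_file:
        keys = [k for (k,) in struct.iter_unpack("q", key_file.read())]
    with open(os.path.join(directory, "emb_vector"), "rb") as vec_file:
        vecs = [list(v) for v in struct.iter_unpack(str(emb_vec_size) + "f", vec_file.read())]
    return keys, vecs


if __name__ == "__main__":
    table_dir = os.path.join(tempfile.mkdtemp(), "table0")
    write_table(table_dir, range(3), emb_vec_size=4)
    keys, vecs = read_table(table_dir, emb_vec_size=4)
    print(keys, len(vecs), len(vecs[0]))
```

For the tables generated by this notebook, the same reader would be called with emb_vec_size=128 on a directory such as embeddings/table0.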
import os
os.system("python3 create_model_for_table_fusion.py")
2023-03-29 07:24:28.206281: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-29 07:24:36.420084: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-29 07:24:36.926162: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1637] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30996 MB memory: -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
[INFO] hierarchical_parameter_server is imported
Model: "model"
Layer (type)                    Output Shape      Param #   Connected to
input_1 (InputLayer)            [(None, 26)]      0         []
input_2 (InputLayer)            [(None, 26)]      0         []
input_3 (InputLayer)            [(None, 26)]      0         []
input_4 (InputLayer)            [(None, 26)]      0         []
input_5 (InputLayer)            [(None, 26)]      0         []
input_6 (InputLayer)            [(None, 26)]      0         []
input_7 (InputLayer)            [(None, 26)]      0         []
input_8 (InputLayer)            [(None, 26)]      0         []
embedding_lookup0 (LookupLayer  (None, 26, 128)   0         ['input_1[0][0]']
embedding_lookup1 (LookupLayer  (None, 26, 128)   0         ['input_2[0][0]']
embedding_lookup2 (LookupLayer  (None, 26, 128)   0         ['input_3[0][0]']
embedding_lookup3 (LookupLayer  (None, 26, 128)   0         ['input_4[0][0]']
embedding_lookup4 (LookupLayer  (None, 26, 128)   0         ['input_5[0][0]']
embedding_lookup5 (LookupLayer  (None, 26, 128)   0         ['input_6[0][0]']
embedding_lookup6 (LookupLayer  (None, 26, 128)   0         ['input_7[0][0]']
embedding_lookup7 (LookupLayer  (None, 26, 128)   0         ['input_8[0][0]']
tf.reshape (TFOpLambda)         (None, 3328)      0         ['embedding_lookup0[0][0]']
tf.reshape_1 (TFOpLambda)       (None, 3328)      0         ['embedding_lookup1[0][0]']
tf.reshape_2 (TFOpLambda)       (None, 3328)      0         ['embedding_lookup2[0][0]']
tf.reshape_3 (TFOpLambda)       (None, 3328)      0         ['embedding_lookup3[0][0]']
tf.reshape_4 (TFOpLambda)       (None, 3328)      0         ['embedding_lookup4[0][0]']
tf.reshape_5 (TFOpLambda)       (None, 3328)      0         ['embedding_lookup5[0][0]']
tf.reshape_6 (TFOpLambda)       (None, 3328)      0         ['embedding_lookup6[0][0]']
tf.reshape_7 (TFOpLambda)       (None, 3328)      0         ['embedding_lookup7[0][0]']
tf.concat (TFOpLambda)          (None, 26624)     0         ['tf.reshape[0][0]',
fc (Dense)                      (None, 1)         26625     ['tf.concat[0][0]']
Total params: 26,625
Trainable params: 26,625
Non-trainable params: 0
=====================================================HPS Parse====================================================
[HCTR][07:24:38.079][INFO][RK0][main]: dense_file is not specified using default:
[HCTR][07:24:38.079][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
[HCTR][07:24:38.079][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
[HCTR][07:24:38.079][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][07:24:38.079][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][07:24:38.079][INFO][RK0][main]: refresh_interval is not specified using default: 0
[HCTR][07:24:38.079][INFO][RK0][main]: use_static_table is not specified using default: 0
[HCTR][07:24:38.079][INFO][RK0][main]: HPS plugin uses context stream for model 8_table: True
====================================================HPS Create====================================================
[HCTR][07:24:38.080][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][07:24:38.080][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][07:24:38.080][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][07:24:38.080][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][07:24:38.080][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][07:24:38.547][INFO][RK0][main]: Table: hps_et.8_table.table0; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:39.379][INFO][RK0][main]: Table: hps_et.8_table.table1; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:39.830][INFO][RK0][main]: Table: hps_et.8_table.table2; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:40.448][INFO][RK0][main]: Table: hps_et.8_table.table3; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:40.899][INFO][RK0][main]: Table: hps_et.8_table.table4; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:41.934][INFO][RK0][main]: Table: hps_et.8_table.table5; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:43.097][INFO][RK0][main]: Table: hps_et.8_table.table6; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:45.296][INFO][RK0][main]: Table: hps_et.8_table.table7; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:45.296][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][07:24:45.297][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][07:24:45.306][INFO][RK0][main]: Model name: 8_table
[HCTR][07:24:45.306][INFO][RK0][main]: Max batch size: 256
[HCTR][07:24:45.306][INFO][RK0][main]: Fuse embedding tables: False
[HCTR][07:24:45.306][INFO][RK0][main]: Number of embedding tables: 8
[HCTR][07:24:45.306][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][07:24:45.306][INFO][RK0][main]: Embedding cache type: dynamic
[HCTR][07:24:45.306][INFO][RK0][main]: Use I64 input key: False
[HCTR][07:24:45.306][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][07:24:45.306][INFO][RK0][main]: The size of thread pool: 80
[HCTR][07:24:45.306][INFO][RK0][main]: The size of worker memory pool: 8
[HCTR][07:24:45.306][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][07:24:45.306][INFO][RK0][main]: The refresh percentage : 1.000000
[HCTR][07:24:45.469][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table0
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table1
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table2
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table3
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table4
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table5
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table6
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table7
[HCTR][07:24:45.475][INFO][RK0][main]: LookupSession i64_input_key: False
[HCTR][07:24:45.475][INFO][RK0][main]: Creating lookup session for 8_table on device: 0
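When table fusion is enabled, the HPS startup log asks that keys do not overlap between tables. The model files created above give table i the key range from i * VOCAB_SIZE up to (but not including) (i + 1) * VOCAB_SIZE, so the ranges are disjoint by construction. The sketch below checks that property for this layout; table_key_ranges and ranges_overlap are hypothetical helpers, not HPS APIs.

```python
def table_key_ranges(num_tables: int, vocab_size: int):
    """Half-open key range (start, stop) per table, following the i * VOCAB_SIZE layout."""
    return [(i * vocab_size, (i + 1) * vocab_size) for i in range(num_tables)]


def ranges_overlap(ranges) -> bool:
    """Return True if any two non-empty half-open ranges share at least one key."""
    ordered = sorted(ranges)
    # After sorting by start, any overlap shows up between neighbouring ranges
    return any(prev[1] > cur[0] for prev, cur in zip(ordered, ordered[1:]))


if __name__ == "__main__":
    ranges = table_key_ranges(num_tables=8, vocab_size=10000)
    total_keys = sum(stop - start for start, stop in ranges)
    print("overlap:", ranges_overlap(ranges), "total keys:", total_keys)
```

With 8 tables of 10000 keys each, the total is 80000 keys, the same count the fused table reports when it is loaded.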
Make inference with HPS table fusion
We load the TF SavedModel and run inference for several batches. Table fusion is enabled because fuse_embedding_table is set to True in the HPS JSON configuration file.
import hierarchical_parameter_server as hps
import tensorflow as tf
import os
import numpy as np
import time
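# These values match the script above and the HPS log printed below
NUM_GPUS = 1
VOCAB_SIZE = 10000
NUM_QUERY_KEY = 26
MAX_BATCH_SIZE = 256
NUM_ITERS = 100
NUM_TABLES = 8
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(NUM_GPUS)))
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
model = tf.keras.models.load_model(str(NUM_TABLES) + "_table.savedmodel")
# Pre-generate one warm-up batch plus NUM_ITERS timed batches
inputs_seq = []
for _ in range(NUM_ITERS + 1):
    inputs = []
    for i in range(NUM_TABLES):
        inputs.append(
            np.random.randint(
                i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE, (MAX_BATCH_SIZE, NUM_QUERY_KEY)
            ).astype(np.int32)
        )
    inputs_seq.append(inputs)
# Warm-up call, excluded from the timing
preds = model(inputs_seq[0])
start = time.time()
for i in range(NUM_ITERS):
    print("-" * 20, "Step {}".format(i), "-" * 20)
    preds = model(inputs_seq[i + 1])
end = time.time()
print(
    "[INFO] Elapsed time for "
    + str(NUM_ITERS)
    + " iterations: "
    + str(end - start)
    + " seconds"
)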
[INFO] hierarchical_parameter_server is imported
2023-03-29 07:25:39.918038: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-29 07:25:42.325440: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-29 07:25:42.818316: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1637] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30996 MB memory: -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
=====================================================HPS Parse====================================================
[HCTR][07:25:43.756][INFO][RK0][main]: Table fusion is enabled for HPS. Please ensure that there is no key value overlap in different tables and the embedding lookup layer has no dependency in the model graph. For more information, see https://nvidia-merlin.github.io/HugeCTR/main/hierarchical_parameter_server/hps_database_backend.html#configuration
[HCTR][07:25:43.756][INFO][RK0][main]: dense_file is not specified using default:
[HCTR][07:25:43.756][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
[HCTR][07:25:43.756][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
[HCTR][07:25:43.756][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][07:25:43.756][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][07:25:43.756][INFO][RK0][main]: refresh_interval is not specified using default: 0
[HCTR][07:25:43.756][INFO][RK0][main]: use_static_table is not specified using default: 0
[HCTR][07:25:43.756][INFO][RK0][main]: HPS plugin uses context stream for model 8_table: True
====================================================HPS Create====================================================
[HCTR][07:25:43.756][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][07:25:43.756][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][07:25:43.756][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][07:25:43.756][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][07:25:43.756][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][07:25:44.292][INFO][RK0][main]: Table: hps_et.8_table.fused_embedding0; cached 80000 / 80000 embeddings in volatile database (HashMapBackend); load: 80000 / 18446744073709551615 (0.00%).
[HCTR][07:25:44.292][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][07:25:44.292][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][07:25:44.299][INFO][RK0][main]: Model name: 8_table
[HCTR][07:25:44.299][INFO][RK0][main]: Max batch size: 256
[HCTR][07:25:44.299][INFO][RK0][main]: Fuse embedding tables: True
[HCTR][07:25:44.299][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][07:25:44.299][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][07:25:44.299][INFO][RK0][main]: Embedding cache type: dynamic
[HCTR][07:25:44.299][INFO][RK0][main]: Use I64 input key: False
[HCTR][07:25:44.299][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][07:25:44.299][INFO][RK0][main]: The size of thread pool: 80
[HCTR][07:25:44.299][INFO][RK0][main]: The size of worker memory pool: 8
[HCTR][07:25:44.299][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][07:25:44.299][INFO][RK0][main]: The refresh percentage : 1.000000
[HCTR][07:25:44.406][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][07:25:44.406][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.fused_embedding0
[HCTR][07:25:44.407][INFO][RK0][main]: LookupSession i64_input_key: False
[HCTR][07:25:44.407][INFO][RK0][main]: Creating lookup session for 8_table on device: 0
[INFO] Elapsed time for 100 iterations: 0.9442901611328125 seconds
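The elapsed time above can be turned into a per-iteration latency and a rough key-lookup rate. The sketch below does that arithmetic, assuming every iteration feeds a full batch of 256 samples with 26 keys for each of the 8 tables. The measured time covers the whole model call, not only the embedding lookup, so the rate is an end-to-end figure. The function summarize_timing is a hypothetical helper.

```python
def summarize_timing(elapsed_s, num_iters, batch_size, keys_per_table, num_tables):
    """Derive per-iteration latency and key-lookup throughput from a wall-clock total."""
    latency_ms = elapsed_s / num_iters * 1e3
    keys_per_iter = batch_size * keys_per_table * num_tables
    keys_per_s = keys_per_iter * num_iters / elapsed_s
    return latency_ms, keys_per_iter, keys_per_s


# Values taken from the log above; the batch shape is an assumption
latency_ms, keys_per_iter, keys_per_s = summarize_timing(
    elapsed_s=0.9442901611328125,
    num_iters=100,
    batch_size=256,
    keys_per_table=26,
    num_tables=8,
)
print(f"{latency_ms:.2f} ms/iteration, {keys_per_iter} keys/iteration, {keys_per_s:,.0f} keys/s")
```

Running the same script with fuse_embedding_table set to False in the JSON file and comparing the two summaries is one way to estimate what fusion gains on a given machine.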