
HPS for Multiple Tables and Sparse Inputs

Overview

This notebook demonstrates how to use HPS when there are multiple embedding tables and sparse input. It is recommended to run hierarchical_parameter_server_demo.ipynb before diving into this notebook.

For more details about HPS APIs, please refer to HPS APIs. For more details about HPS, please refer to HugeCTR Hierarchical Parameter Server (HPS).

Installation

Get HPS from NGC

The HPS Python module is preinstalled in the Merlin TensorFlow container (version 22.09 and later): nvcr.io/nvidia/merlin/merlin-tensorflow:22.09.

You can check that the required libraries are available by running the following command after launching this container.

$ python3 -c "import hierarchical_parameter_server as hps"

Configurations

First of all, we specify the required configurations, e.g., the arguments needed for generating the dataset and the paths to save the model and the model parameters. In this notebook we use a deep neural network (DNN) model that has two embedding tables and several dense layers. Please note that there are two key inputs: one is a sparse key tensor (multi-hot) while the other is a dense key tensor (one-hot).

import hierarchical_parameter_server as hps
import os
import numpy as np
import tensorflow as tf
import struct

args = dict()

args["gpu_num"] = 1                                         # the number of available GPUs
args["iter_num"] = 10                                       # the number of training iteration
args["global_batch_size"] = 1024                            # the globally batchsize for all GPUs

args["slot_num_per_table"] = [3, 2]                         # the number of feature fields for two embedding tables
args["embed_vec_size_per_table"] = [16, 32]                 # the dimension of embedding vectors for two embedding tables
args["max_vocabulary_size_per_table"] = [30000, 2000]       # the vocabulary size for two embedding tables
args["vocabulary_range_per_slot_per_table"] = [ [[0,10000],[10000,20000],[20000,30000]], [[0, 1000], [1000, 2000]] ]
args["max_nnz_per_slot_per_table"] = [[4, 2, 3], [1, 1]]    # the max number of non-zeros for each slot for two embedding tables

args["dense_model_path"] = "multi_table_sparse_input_dense.model"
args["ps_config_file"] = "multi_table_sparse_input.json"
args["embedding_table_path"] = ["multi_table_sparse_input_sparse_0.model", "multi_table_sparse_input_sparse_1.model"]
args["saved_path"] = "multi_table_sparse_input_tf_saved_model"
args["np_key_type"] = np.int64
args["np_vector_type"] = np.float32
args["tf_key_type"] = tf.int64
args["tf_vector_type"] = tf.float32


os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(args["gpu_num"])))
[INFO] hierarchical_parameter_server is imported
def generate_random_samples(num_samples, vocabulary_range_per_slot_per_table, max_nnz_per_slot_per_table):
    def generate_sparse_keys(num_samples, vocabulary_range_per_slot, max_nnz_per_slot, key_dtype = args["np_key_type"]):
        slot_num = len(max_nnz_per_slot)
        max_nnz_of_all_slots = max(max_nnz_per_slot)
        indices = []
        values = []
        for i in range(num_samples):
            for j in range(slot_num):
                vocab_range = vocabulary_range_per_slot[j]
                max_nnz = max_nnz_per_slot[j]
                nnz = np.random.randint(low=1, high=max_nnz+1)
                entries = sorted(np.random.choice(max_nnz, nnz, replace=False))
                for entry in entries:
                    indices.append([i, j, entry])
                values.extend(np.random.randint(low=vocab_range[0], high=vocab_range[1], size=(nnz, )))
        values = np.array(values, dtype=key_dtype)
        return tf.sparse.SparseTensor(indices = indices,
                                    values = values,
                                    dense_shape = (num_samples, slot_num, max_nnz_of_all_slots))

    def generate_dense_keys(num_samples, vocabulary_range_per_slot, key_dtype = args["np_key_type"]):
        dense_keys = list()
        for vocab_range in vocabulary_range_per_slot:
            keys_per_slot = np.random.randint(low=vocab_range[0], high=vocab_range[1], size=(num_samples, 1), dtype=key_dtype)
            dense_keys.append(keys_per_slot)
        dense_keys = np.concatenate(np.array(dense_keys), axis = 1)
        return dense_keys
    
    assert len(vocabulary_range_per_slot_per_table)==2, "there should be two embedding tables"
    assert max(max_nnz_per_slot_per_table[0])>1, "the first embedding table has sparse key input (multi-hot)"
    assert min(max_nnz_per_slot_per_table[1])==1, "the second embedding table has dense key input (one-hot)"
    
    sparse_keys = generate_sparse_keys(num_samples, vocabulary_range_per_slot_per_table[0], max_nnz_per_slot_per_table[0])
    dense_keys = generate_dense_keys(num_samples, vocabulary_range_per_slot_per_table[1])
    labels = np.random.randint(low=0, high=2, size=(num_samples, 1))
    return sparse_keys, dense_keys, labels

def tf_dataset(sparse_keys, dense_keys, labels, batchsize):
    dataset = tf.data.Dataset.from_tensor_slices((sparse_keys, dense_keys, labels))
    dataset = dataset.batch(batchsize, drop_remainder=True)
    return dataset
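
As a quick sanity check (a sketch, not part of the original workflow), we can generate a handful of samples and inspect the shapes produced by the helpers above:

# Quick shape check (a sketch): the sparse keys, dense keys and labels for 8 samples.
_sparse, _dense, _labels = generate_random_samples(8, args["vocabulary_range_per_slot_per_table"], args["max_nnz_per_slot_per_table"])
print(_sparse.dense_shape)   # (8, 3, 4): (num_samples, slot_num, max of max_nnz_per_slot) for the first table
print(_dense.shape)          # (8, 2): (num_samples, slot_num) for the second table
print(_labels.shape)         # (8, 1)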

Train with native TF layers

We define the model graph for training with native TF layers, i.e., tf.nn.embedding_lookup_sparse, tf.nn.embedding_lookup and tf.keras.layers.Dense. We can then train the model and extract the trained weights of the two embedding tables. The dense layers are saved as a separate model graph, which can be loaded directly during inference.
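
Before building the full model, here is a minimal, self-contained illustration (a sketch with a toy 4x2 table; the variable names are hypothetical) of how tf.nn.embedding_lookup_sparse with combiner="mean" gathers and averages the embedding vectors for each row of a SparseTensor:

# Toy illustration (a sketch): mean-combine the looked-up rows of a small table.
toy_params = tf.constant([[0., 0.], [1., 1.], [2., 2.], [3., 3.]])
toy_sp_ids = tf.sparse.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                                    values=tf.constant([1, 3, 2], dtype=tf.int64),
                                    dense_shape=(2, 2))
toy_emb = tf.nn.embedding_lookup_sparse(params=toy_params, sp_ids=toy_sp_ids,
                                        sp_weights=None, combiner="mean")
print(toy_emb)  # [[2. 2.] [2. 2.]]: row 0 averages keys 1 and 3, row 1 uses key 2 only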

class TrainModel(tf.keras.models.Model):
    def __init__(self,
                 init_tensors_per_table,
                 slot_num_per_table,
                 embed_vec_size_per_table,
                 max_nnz_per_slot_per_table,
                 **kwargs):
        super(TrainModel, self).__init__(**kwargs)
        
        self.slot_num_per_table = slot_num_per_table
        self.embed_vec_size_per_table = embed_vec_size_per_table
        self.max_nnz_per_slot_per_table = max_nnz_per_slot_per_table
        self.max_nnz_of_all_slots_per_table = [max(ele) for ele in self.max_nnz_per_slot_per_table]
        
        self.init_tensors_per_table = init_tensors_per_table
        self.params0 = tf.Variable(initial_value=tf.concat(self.init_tensors_per_table[0], axis=0))
        self.params1 = tf.Variable(initial_value=tf.concat(self.init_tensors_per_table[1], axis=0))
        
        self.reshape = tf.keras.layers.Reshape((self.max_nnz_of_all_slots_per_table[0],),
                                                input_shape=(self.slot_num_per_table[0], self.max_nnz_of_all_slots_per_table[0]))
        
        self.fc_1 = tf.keras.layers.Dense(units=256, activation=None,
                                                 kernel_initializer="ones",
                                                 bias_initializer="zeros",
                                                 name='fc_1')
        self.fc_2 = tf.keras.layers.Dense(units=256, activation=None,
                                                 kernel_initializer="ones",
                                                 bias_initializer="zeros",
                                                 name='fc_2')
        self.fc_3 = tf.keras.layers.Dense(units=1, activation=None,
                                                 kernel_initializer="ones",
                                                 bias_initializer="zeros",
                                                 name='fc_3')

    def call(self, inputs):
        # SparseTensor of keys, shape: (batch_size*slot_num, max_nnz)
        embeddings0 = tf.reshape(tf.nn.embedding_lookup_sparse(params=self.params0, sp_ids=inputs[0], sp_weights = None, combiner="mean"),
                                shape=[-1, self.slot_num_per_table[0] * self.embed_vec_size_per_table[0]])
        # Tensor of keys, shape: (batch_size, slot_num)
        embeddings1 = tf.reshape(tf.nn.embedding_lookup(params=self.params1, ids=inputs[1]), 
                                 shape=[-1, self.slot_num_per_table[1] * self.embed_vec_size_per_table[1]])
        
        logit = self.fc_3(tf.math.add(self.fc_1(embeddings0), self.fc_2(embeddings1)))
        return logit, embeddings0, embeddings1

    def summary(self):
        inputs = [tf.keras.Input(shape=(self.max_nnz_of_all_slots_per_table[0], ), sparse=True, dtype=args["tf_key_type"]),
                  tf.keras.Input(shape=(self.slot_num_per_table[1], ), dtype=args["tf_key_type"])]
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()
def train(args):
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logit, _, _ = model(inputs)
            loss = loss_fn(labels, logit)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return logit, loss

    init_tensors_per_table = [np.ones(shape=[args["max_vocabulary_size_per_table"][0], args["embed_vec_size_per_table"][0]], dtype=args["np_vector_type"]),
                              np.ones(shape=[args["max_vocabulary_size_per_table"][1], args["embed_vec_size_per_table"][1]], dtype=args["np_vector_type"])]

    model = TrainModel(init_tensors_per_table, args["slot_num_per_table"], args["embed_vec_size_per_table"], args["max_nnz_per_slot_per_table"])
    model.summary()
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    sparse_keys, dense_keys, labels = generate_random_samples(args["global_batch_size"]  * args["iter_num"], args["vocabulary_range_per_slot_per_table"], args["max_nnz_per_slot_per_table"])
    dataset = tf_dataset(sparse_keys, dense_keys, labels, args["global_batch_size"])
    for i, (sparse_keys, dense_keys, labels) in enumerate(dataset):
        sparse_keys = tf.sparse.reshape(sparse_keys, [-1, sparse_keys.shape[-1]])
        inputs = [sparse_keys, dense_keys]
        _, loss = _train_step(inputs, labels)
        print("-"*20, "Step {}, loss: {}".format(i, loss),  "-"*20)
    return model
trained_model = train(args)
weights_list = trained_model.get_weights()
embedding_weights_per_table = weights_list[-2:]
dense_model = tf.keras.Model([trained_model.get_layer("fc_1").input, 
                              trained_model.get_layer("fc_2").input], 
                             trained_model.get_layer("fc_3").output)
dense_model.summary()
dense_model.save(args["dense_model_path"])
2022-07-12 07:51:09.676041: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-07-12 07:51:10.271131: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30989 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
WARNING:tensorflow:The following Variables were used in a Lambda layer's call (tf.compat.v1.nn.embedding_lookup_sparse), but are not present in its tracked objects:   <tf.Variable 'Variable:0' shape=(30000, 16) dtype=float32>. This is a strong indication that the Lambda layer should be rewritten as a subclassed Layer.
WARNING:tensorflow:The following Variables were used in a Lambda layer's call (tf.compat.v1.nn.embedding_lookup), but are not present in its tracked objects:   <tf.Variable 'Variable:0' shape=(2000, 32) dtype=float32>. This is a strong indication that the Lambda layer should be rewritten as a subclassed Layer.
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 4)]          0           []                               
                                                                                                  
 input_2 (InputLayer)           [(None, 2)]          0           []                               
                                                                                                  
 tf.compat.v1.nn.embedding_look  (None, 16)          0           ['input_1[0][0]']                
 up_sparse (TFOpLambda)                                                                           
                                                                                                  
 tf.compat.v1.nn.embedding_look  (None, 2, 32)       0           ['input_2[0][0]']                
 up (TFOpLambda)                                                                                  
                                                                                                  
 tf.reshape (TFOpLambda)        (None, 48)           0           ['tf.compat.v1.nn.embedding_looku
                                                                 p_sparse[0][0]']                 
                                                                                                  
 tf.reshape_1 (TFOpLambda)      (None, 64)           0           ['tf.compat.v1.nn.embedding_looku
                                                                 p[0][0]']                        
                                                                                                  
 fc_1 (Dense)                   (None, 256)          12544       ['tf.reshape[0][0]']             
                                                                                                  
 fc_2 (Dense)                   (None, 256)          16640       ['tf.reshape_1[0][0]']           
                                                                                                  
 tf.math.add (TFOpLambda)       (None, 256)          0           ['fc_1[0][0]',                   
                                                                  'fc_2[0][0]']                   
                                                                                                  
 fc_3 (Dense)                   (None, 1)            257         ['tf.math.add[0][0]']            
                                                                                                  
==================================================================================================
Total params: 29,441
Trainable params: 29,441
Non-trainable params: 0
__________________________________________________________________________________________________
-------------------- Step 0, loss: 14588.0 --------------------
-------------------- Step 1, loss: 11693.25 --------------------
-------------------- Step 2, loss: 8232.9658203125 --------------------
-------------------- Step 3, loss: 6276.9736328125 --------------------
-------------------- Step 4, loss: 4676.82861328125 --------------------
-------------------- Step 5, loss: 2921.1875 --------------------
-------------------- Step 6, loss: 1938.2447509765625 --------------------
-------------------- Step 7, loss: 1093.598388671875 --------------------
-------------------- Step 8, loss: 616.3092651367188 --------------------
-------------------- Step 9, loss: 257.61248779296875 --------------------
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_3 (InputLayer)           [(None, 48)]         0           []                               
                                                                                                  
 input_4 (InputLayer)           [(None, 64)]         0           []                               
                                                                                                  
 fc_1 (Dense)                   (None, 256)          12544       ['input_3[0][0]']                
                                                                                                  
 fc_2 (Dense)                   (None, 256)          16640       ['input_4[0][0]']                
                                                                                                  
 tf.math.add (TFOpLambda)       (None, 256)          0           ['fc_1[1][0]',                   
                                                                  'fc_2[1][0]']                   
                                                                                                  
 fc_3 (Dense)                   (None, 1)            257         ['tf.math.add[1][0]']            
                                                                                                  
==================================================================================================
Total params: 29,441
Trainable params: 29,441
Non-trainable params: 0
__________________________________________________________________________________________________
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
2022-07-12 07:51:13.335404: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel.
INFO:tensorflow:Assets written to: multi_table_sparse_input_dense.model/assets
INFO:tensorflow:Assets written to: multi_table_sparse_input_dense.model/assets
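
As a quick check (a sketch, not part of the original workflow), the two extracted embedding tables should have shapes (30000, 16) and (2000, 32), matching max_vocabulary_size_per_table and embed_vec_size_per_table:

print([w.shape for w in embedding_weights_per_table])  # expected: [(30000, 16), (2000, 32)]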

Create the inference graph with HPS SparseLookupLayer and LookupLayer

In order to use HPS in the inference stage, we need to create an inference model graph that is almost the same as the training graph, except that tf.nn.embedding_lookup_sparse is replaced by hps.SparseLookupLayer and tf.nn.embedding_lookup is replaced by hps.LookupLayer. The trained dense model graph can be loaded directly, while the weights of the two embedding tables must be converted to the format required by HPS.

We can then save the inference model graph, which will be ready to be loaded for inference deployment.

class InferenceModel(tf.keras.models.Model):
    def __init__(self,
                 slot_num_per_table,
                 embed_vec_size_per_table,
                 max_nnz_per_slot_per_table,
                 dense_model_path,
                 **kwargs):
        super(InferenceModel, self).__init__(**kwargs)
        
        self.slot_num_per_table = slot_num_per_table
        self.embed_vec_size_per_table = embed_vec_size_per_table
        self.max_nnz_per_slot_per_table = max_nnz_per_slot_per_table
        self.max_nnz_of_all_slots_per_table = [max(ele) for ele in self.max_nnz_per_slot_per_table]
        
        self.sparse_lookup_layer = hps.SparseLookupLayer(model_name = "multi_table_sparse_input", 
                                            table_id = 0,
                                            emb_vec_size = self.embed_vec_size_per_table[0],
                                            emb_vec_dtype = args["tf_vector_type"])
        self.lookup_layer = hps.LookupLayer(model_name = "multi_table_sparse_input", 
                                            table_id = 1,
                                            emb_vec_size = self.embed_vec_size_per_table[1],
                                            emb_vec_dtype = args["tf_vector_type"])
        self.dense_model = tf.keras.models.load_model(dense_model_path)
    
    def call(self, inputs):
        # SparseTensor of keys, shape: (batch_size*slot_num, max_nnz)
        embeddings0 = tf.reshape(self.sparse_lookup_layer(sp_ids=inputs[0], sp_weights = None, combiner="mean"),
                                shape=[-1, self.slot_num_per_table[0] * self.embed_vec_size_per_table[0]])
        # Tensor of keys, shape: (batch_size, slot_num)
        embeddings1 = tf.reshape(self.lookup_layer(inputs[1]), 
                                 shape=[-1, self.slot_num_per_table[1] * self.embed_vec_size_per_table[1]])
        
        logit = self.dense_model([embeddings0, embeddings1])
        return logit, embeddings0, embeddings1

    def summary(self):
        inputs = [tf.keras.Input(shape=(self.max_nnz_of_all_slots_per_table[0], ), sparse=True, dtype=args["tf_key_type"]),
                  tf.keras.Input(shape=(self.slot_num_per_table[1], ), dtype=args["tf_key_type"])]
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()
def create_and_save_inference_graph(args): 
    model = InferenceModel(args["slot_num_per_table"], args["embed_vec_size_per_table"], args["max_nnz_per_slot_per_table"], args["dense_model_path"])
    model.summary()
    inputs = [tf.keras.Input(shape=(max(args["max_nnz_per_slot_per_table"][0]), ), sparse=True, dtype=args["tf_key_type"]),
             tf.keras.Input(shape=(args["slot_num_per_table"][1], ), dtype=args["tf_key_type"])]
    _, _, _= model(inputs)
    model.save(args["saved_path"])
def convert_to_sparse_model(embeddings_weights, embedding_table_path, embedding_vec_size):
    os.system("mkdir -p {}".format(embedding_table_path))
    with open("{}/key".format(embedding_table_path), 'wb') as key_file, \
         open("{}/emb_vector".format(embedding_table_path), 'wb') as vec_file:
        for key in range(embeddings_weights.shape[0]):
            vec = embeddings_weights[key]
            key_struct = struct.pack('q', key)
            vec_struct = struct.pack(str(embedding_vec_size) + "f", *vec)
            key_file.write(key_struct)
            vec_file.write(vec_struct)
convert_to_sparse_model(embedding_weights_per_table[0], args["embedding_table_path"][0], args["embed_vec_size_per_table"][0])
convert_to_sparse_model(embedding_weights_per_table[1], args["embedding_table_path"][1], args["embed_vec_size_per_table"][1])
create_and_save_inference_graph(args)
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_5 (InputLayer)           [(None, 4)]          0           []                               
                                                                                                  
 input_6 (InputLayer)           [(None, 2)]          0           []                               
                                                                                                  
 sparse_lookup_layer (SparseLoo  (None, 16)          0           ['input_5[0][0]']                
 kupLayer)                                                                                        
                                                                                                  
 lookup_layer (LookupLayer)     (None, 2, 32)        0           ['input_6[0][0]']                
                                                                                                  
 tf.reshape_2 (TFOpLambda)      (None, 48)           0           ['sparse_lookup_layer[0][0]']    
                                                                                                  
 tf.reshape_3 (TFOpLambda)      (None, 64)           0           ['lookup_layer[0][0]']           
                                                                                                  
 model_1 (Functional)           (None, 1)            29441       ['tf.reshape_2[0][0]',           
                                                                  'tf.reshape_3[0][0]']           
                                                                                                  
==================================================================================================
Total params: 29,441
Trainable params: 29,441
Non-trainable params: 0
__________________________________________________________________________________________________
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_3 in the SavedModel.
INFO:tensorflow:Assets written to: multi_table_sparse_input_tf_saved_model/assets
INFO:tensorflow:Assets written to: multi_table_sparse_input_tf_saved_model/assets
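
As an optional sanity check (a sketch, not required by the workflow), the binary files written by convert_to_sparse_model can be read back with NumPy and compared against the in-memory weights of the first table:

# Optional verification (a sketch): the "key" file holds int64 keys and the
# "emb_vector" file holds float32 vectors in the same order.
keys_back = np.fromfile(os.path.join(args["embedding_table_path"][0], "key"), dtype=np.int64)
vecs_back = np.fromfile(os.path.join(args["embedding_table_path"][0], "emb_vector"), dtype=np.float32)
vecs_back = vecs_back.reshape(-1, args["embed_vec_size_per_table"][0])
assert np.array_equal(keys_back, np.arange(args["max_vocabulary_size_per_table"][0]))
assert np.allclose(vecs_back, embedding_weights_per_table[0])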

Inference with saved model graph

In order to initialize the lookup service provided by HPS, we also need to create a JSON configuration file that specifies the details of the embedding tables for the models to be deployed. We deploy a model with two embedding tables here, although HPS can actually serve multiple models, each with multiple embedding tables. Please note how maxnum_catfeature_query_per_table_per_sample is specified for the two embedding tables: max_nnz_per_slot of the first table is [4, 2, 3], which sums to 9, while that of the second table is [1, 1], which sums to 2.

We first call hps.Init to do the necessary initialization work, and then load the saved model graph to perform inference. We peek at the keys and the embedding vectors of each table for the last inference batch.

%%writefile multi_table_sparse_input.json
{
    "supportlonglong": true,
    "models": [{
        "model": "multi_table_sparse_input",
        "sparse_files": ["multi_table_sparse_input_sparse_0.model", "multi_table_sparse_input_sparse_1.model"],
        "num_of_worker_buffer_in_pool": 3,
        "embedding_table_names":["sparse_embedding0", "sparse_embedding1"],
        "embedding_vecsize_per_table": [16, 32],
        "maxnum_catfeature_query_per_table_per_sample": [9, 2],
        "default_value_for_each_table": [1.0, 1.0],
        "deployed_device_list": [0],
        "max_batch_size": 1024,
        "cache_refresh_percentage_per_iteration": 0.2,
        "hit_rate_threshold": 1.0,
        "gpucacheper": 1.0,
        "gpucache": true
        }
    ]
}
Writing multi_table_sparse_input.json
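
The maxnum_catfeature_query_per_table_per_sample values above can be cross-checked against the dataset arguments (a minimal sketch):

# Sanity check (a sketch): the per-table sums of max_nnz_per_slot give [4+2+3, 1+1] = [9, 2].
print([sum(max_nnz) for max_nnz in args["max_nnz_per_slot_per_table"]])  # [9, 2]
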
def inference_with_saved_model(args):
    hps.Init(global_batch_size = args["global_batch_size"],
             ps_config_file = args["ps_config_file"])
    model = tf.keras.models.load_model(args["saved_path"])
    model.summary()
    def _infer_step(inputs, labels):
        logit, embeddings0, embeddings1 = model(inputs)
        return logit, embeddings0, embeddings1
    embeddings0_peek = list()
    embeddings1_peek = list()
    inputs_peek = list()
    sparse_keys, dense_keys, labels = generate_random_samples(args["global_batch_size"]  * args["iter_num"], args["vocabulary_range_per_slot_per_table"], args["max_nnz_per_slot_per_table"])
    dataset = tf_dataset(sparse_keys, dense_keys, labels, args["global_batch_size"])
    for i, (sparse_keys, dense_keys, labels) in enumerate(dataset):
        sparse_keys = tf.sparse.reshape(sparse_keys, [-1, sparse_keys.shape[-1]])
        inputs = [sparse_keys, dense_keys]
        logit, embeddings0, embeddings1 = _infer_step(inputs, labels)
        embeddings0_peek.append(embeddings0)
        embeddings1_peek.append(embeddings1)
        inputs_peek.append(inputs)
        print("-"*20, "Step {}".format(i),  "-"*20)
    return embeddings0_peek, embeddings1_peek, inputs_peek
embeddings0_peek, embeddings1_peek, inputs_peek = inference_with_saved_model(args)

# 1st embedding table, input keys are SparseTensor 
print(inputs_peek[-1][0].values)
print(embeddings0_peek[-1])

# 2nd embedding table, input keys are Tensor
print(inputs_peek[-1][1])
print(embeddings1_peek[-1])
=====================================================HPS Parse====================================================
[HCTR][07:51:32.495][INFO][RK0][main]: dense_file is not specified using default: 
[HCTR][07:51:32.495][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
[HCTR][07:51:32.495][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
[HCTR][07:51:32.495][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][07:51:32.495][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][07:51:32.495][INFO][RK0][main]: refresh_interval is not specified using default: 0
====================================================HPS Create====================================================
[HCTR][07:51:32.495][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][07:51:32.495][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][07:51:32.495][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][07:51:32.855][INFO][RK0][main]: Table: hps_et.multi_table_sparse_input.sparse_embedding0; cached 30000 / 30000 embeddings in volatile database (PreallocatedHashMapBackend); load: 30000 / 18446744073709551615 (0.00%).
[HCTR][07:51:33.195][INFO][RK0][main]: Table: hps_et.multi_table_sparse_input.sparse_embedding1; cached 2000 / 2000 embeddings in volatile database (PreallocatedHashMapBackend); load: 2000 / 18446744073709551615 (0.00%).
[HCTR][07:51:33.195][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][07:51:33.195][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][07:51:33.201][INFO][RK0][main]: Model name: multi_table_sparse_input
[HCTR][07:51:33.201][INFO][RK0][main]: Number of embedding tables: 2
[HCTR][07:51:33.201][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][07:51:33.201][INFO][RK0][main]: Use I64 input key: True
[HCTR][07:51:33.201][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][07:51:33.201][INFO][RK0][main]: The size of thread pool: 80
[HCTR][07:51:33.201][INFO][RK0][main]: The size of worker memory pool: 3
[HCTR][07:51:33.201][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][07:51:33.212][INFO][RK0][main]: Creating lookup session for multi_table_sparse_input on device: 0
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Model: "inference_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 sparse_lookup_layer (Sparse  multiple                 0         
 LookupLayer)                                                    
                                                                 
 lookup_layer (LookupLayer)  multiple                  0         
                                                                 
 model_1 (Functional)        (None, 1)                 29441     
                                                                 
=================================================================
Total params: 29,441
Trainable params: 29,441
Non-trainable params: 0
_________________________________________________________________
-------------------- Step 0 --------------------
-------------------- Step 1 --------------------
-------------------- Step 2 --------------------
-------------------- Step 3 --------------------
-------------------- Step 4 --------------------
-------------------- Step 5 --------------------
-------------------- Step 6 --------------------
-------------------- Step 7 --------------------
-------------------- Step 8 --------------------
-------------------- Step 9 --------------------
tf.Tensor([ 9905  1750  4223 ... 20477 22119 23797], shape=(6111,), dtype=int64)
tf.Tensor(
[[0.9122444  0.9122444  0.9122444  ... 1.         1.         1.        ]
 [0.76979905 0.76979905 0.76979905 ... 1.         1.         1.        ]
 [0.7415885  0.7415885  0.7415885  ... 1.         1.         1.        ]
 ...
 [0.66938084 0.66938084 0.66938084 ... 1.         1.         1.        ]
 [0.90488005 0.90488005 0.90488005 ... 1.         1.         1.        ]
 [0.7773342  0.7773342  0.7773342  ... 0.6368773  0.6368773  0.6368773 ]], shape=(1024, 48), dtype=float32)
tf.Tensor(
[[ 276 1610]
 [ 408 1884]
 [ 678 1762]
 ...
 [ 369 1794]
 [ 403 1216]
 [ 909 1427]], shape=(1024, 2), dtype=int64)
tf.Tensor(
[[0.28882617 0.28882617 0.28882617 ... 0.41947648 0.41947648 0.41947648]
 [0.597903   0.597903   0.597903   ... 0.37505823 0.37505823 0.37505823]
 [0.70420146 0.70420146 0.70420146 ... 0.38864705 0.38864705 0.38864705]
 ...
 [0.32224336 0.32224336 0.32224336 ... 0.31987724 0.31987724 0.31987724]
 [0.43596342 0.43596342 0.43596342 ... 0.5383081  0.5383081  0.5383081 ]
 [0.37384593 0.37384593 0.37384593 ... 0.6026224  0.6026224  0.6026224 ]], shape=(1024, 64), dtype=float32)
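
As a final optional cross-check (a sketch; it relies on embedding_weights_per_table extracted earlier), the vectors returned by the HPS LookupLayer for the second (one-hot) table should match the trained embedding weights gathered at the same keys:

# Optional cross-check (a sketch): compare HPS lookups with the trained weights of the second table.
keys_last_batch = inputs_peek[-1][1].numpy()                # shape: (1024, 2)
expected = embedding_weights_per_table[1][keys_last_batch]  # shape: (1024, 2, 32)
assert np.allclose(embeddings1_peek[-1].numpy().reshape(expected.shape), expected)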