
SOK to HPS DLRM Demo

Overview

This notebook demonstrates how to train a DLRM model with SparseOperationKit (SOK) and then run inference with the HierarchicalParameterServer (HPS). We recommend running sparse_operation_kit_demo.ipynb and hierarchical_parameter_server_demo.ipynb before working through this notebook.

For more details about SOK, refer to the SOK Documentation. For the HPS APIs, refer to HPS APIs. For HPS itself, refer to HugeCTR Hierarchical Parameter Server (HPS).

Installation

Get SOK from NGC

The SOK and HPS Python modules are both preinstalled in the Merlin TensorFlow container, version 22.08 and later, for example nvcr.io/nvidia/merlin/merlin-tensorflow:22.08.

After launching the container, you can confirm that the required libraries are available by running the following commands:

$ python3 -c "import sparse_operation_kit as sok"
$ python3 -c "import hierarchical_parameter_server as hps"
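
If you would rather get a report of what is missing than an ImportError, the same check can be sketched with the standard library. The helper name `missing_modules` is hypothetical and is not part of SOK or HPS; `find_spec` only locates the modules and does not import them.

```python
import importlib.util

def missing_modules(names):
    """Return the top-level module names that cannot be located on this system."""
    return [name for name in names if importlib.util.find_spec(name) is None]

required = ["sparse_operation_kit", "hierarchical_parameter_server"]
missing = missing_modules(required)
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("SOK and HPS were both found.")
```
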

Configurations

First, we specify the required configurations: the arguments for generating the dataset, the model parameters, and the paths for saving the model. We use a DLRM model that has one embedding table, bottom MLP layers, an interaction layer, and top MLP layers. Note that the input to the embedding layer is a sparse key tensor.

import sparse_operation_kit as sok
import sys
sys.path.append("/hugectr/sparse_operation_kit/unit_test/test_scripts/tf2/")
import utils

import os
import numpy as np
import tensorflow as tf
import struct

args = dict()

args["gpu_num"] = 1                               # the number of available GPUs
args["iter_num"] = 10                             # the number of training iterations
args["slot_num"] = 26                             # the number of feature fields in this embedding layer
args["embed_vec_size"] = 16                       # the dimension of embedding vectors
args["dense_dim"] = 13                            # the dimension of dense features
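args["global_batch_size"] = 1024                  # the global batch size across all GPUs
args["max_vocabulary_size"] = 260000              # the total number of keys: 26 slots x 10000 keys each
args["vocabulary_range_per_slot"] = [[i*10000, (i+1)*10000] for i in range(26)]  # [low, high) key range of each slot
args["max_nnz"] = 10                              # the max number of non-zeros in each slot
args["combiner"] = "mean"                         # how the embedding vectors within a slot are combined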

args["ps_config_file"] = "dlrm.json"
args["dense_model_path"] = "dlrm_dense.model"
args["embedding_table_path"] = "dlrm_sparse.model"
args["saved_path"] = "dlrm_tf_saved_model"
args["np_key_type"] = np.int64
args["np_vector_type"] = np.float32
args["tf_key_type"] = tf.int64
args["tf_vector_type"] = tf.float32
args["optimizer"] = "plugin_adam"

os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(args["gpu_num"])))
[INFO]: sparse_operation_kit is imported
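
As a quick sanity check on the configuration above, the per-slot key ranges should tile the key space without gaps or overlaps, and their total should equal max_vocabulary_size. The following is a minimal sketch in plain Python; the helper `total_keys` is hypothetical, and the values are copied from the configuration cell.

```python
def total_keys(ranges):
    """Check that half-open [low, high) ranges are contiguous and return the number of keys they cover."""
    expected_low = ranges[0][0]
    for low, high in ranges:
        if low != expected_low or high <= low:
            raise ValueError("range [{}, {}) leaves a gap, overlaps, or is empty".format(low, high))
        expected_low = high
    return expected_low - ranges[0][0]

slot_num = 26
max_vocabulary_size = 260000
vocabulary_range_per_slot = [[i * 10000, (i + 1) * 10000] for i in range(slot_num)]

covered = total_keys(vocabulary_range_per_slot)
print(covered)  # 260000
```
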
def generate_random_samples(num_samples, vocabulary_range_per_slot, max_nnz, dense_dim):
    def generate_sparse_keys(num_samples, vocabulary_range_per_slot, max_nnz, key_dtype = args["np_key_type"]):
        slot_num = len(vocabulary_range_per_slot)
        indices = []
        values = []
        for i in range(num_samples):
            for j in range(slot_num):
                vocab_range = vocabulary_range_per_slot[j]
                nnz = np.random.randint(low=1, high=max_nnz+1)
                entries = sorted(np.random.choice(max_nnz, nnz, replace=False))
                for entry in entries:
                    indices.append([i, j, entry])
                values.extend(np.random.randint(low=vocab_range[0], high=vocab_range[1], size=(nnz, )))
        values = np.array(values, dtype=key_dtype)
        return tf.sparse.SparseTensor(indices = indices,
                                    values = values,
                                    dense_shape = (num_samples, slot_num, max_nnz))

    
    sparse_keys = generate_sparse_keys(num_samples, vocabulary_range_per_slot, max_nnz)
    dense_features = np.random.random((num_samples, dense_dim)).astype(np.float32)
    labels = np.random.randint(low=0, high=2, size=(num_samples, 1))
    return sparse_keys, dense_features, labels

def tf_dataset(sparse_keys, dense_features, labels, batchsize):
    dataset = tf.data.Dataset.from_tensor_slices((sparse_keys, dense_features, labels))
    dataset = dataset.batch(batchsize, drop_remainder=True)
    return dataset
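
generate_sparse_keys builds a SparseTensor with dense shape (num_samples, slot_num, max_nnz), while the training and inference loops reshape each batch to two dimensions with tf.sparse.reshape before the embedding lookup. Because the reshape preserves row-major order, entry (i, j, k) ends up in row i * slot_num + j at column k, so each row holds the keys of one (sample, slot) pair. A sketch of that index mapping in plain Python, with a hypothetical helper name:

```python
def flatten_index(sample, slot, entry, slot_num):
    """Map a (sample, slot, entry) index to its (row, entry) index after a row-major reshape to [-1, max_nnz]."""
    return sample * slot_num + slot, entry

slot_num = 26
# Second sample (index 1), third slot (index 2), fourth key position (index 3).
print(flatten_index(1, 2, 3, slot_num))  # (28, 3)
```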

Train with SOK embedding layers

We define the model graph for training with SOK embedding layers, i.e., sok.DistributedEmbedding. We can then train the model and save the trained weights of the embedding table into the formats required by HPS. As for the dense layers, they are saved as a separate model graph, which can be loaded directly during inference.

class MLP(tf.keras.layers.Layer):
    def __init__(self,
                arch,
                activation='relu',
                out_activation=None,
                **kwargs):
        super(MLP, self).__init__(**kwargs)
        self.layers = []
        index = 0
        for units in arch[:-1]:
            self.layers.append(tf.keras.layers.Dense(units, activation=activation, name="{}_{}".format(kwargs['name'], index)))
            index+=1
        self.layers.append(tf.keras.layers.Dense(arch[-1], activation=out_activation, name="{}_{}".format(kwargs['name'], index)))

            
    def call(self, inputs, training=True):
        x = self.layers[0](inputs)
        for layer in self.layers[1:]:
            x = layer(x)
        return x

class SecondOrderFeatureInteraction(tf.keras.layers.Layer):
    def __init__(self, self_interaction=False):
        super(SecondOrderFeatureInteraction, self).__init__()
        self.self_interaction = self_interaction

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        num_feas = tf.shape(inputs)[1]

        dot_products = tf.matmul(inputs, inputs, transpose_b=True)

        ones = tf.ones_like(dot_products)
        mask = tf.linalg.band_part(ones, 0, -1)
        out_dim = num_feas * (num_feas + 1) // 2

        if not self.self_interaction:
            mask = mask - tf.linalg.band_part(ones, 0, 0)
            out_dim = num_feas * (num_feas - 1) // 2
        flat_interactions = tf.reshape(tf.boolean_mask(dot_products, mask), (batch_size, out_dim))
        return flat_interactions

class DLRM(tf.keras.models.Model):
    def __init__(self,
                 combiner,
                 max_vocabulary_size_per_gpu,
                 embed_vec_size,
                 slot_num,
                 max_nnz,
                 dense_dim,
                 arch_bot,
                 arch_top,
                 self_interaction,
                 **kwargs):
        super(DLRM, self).__init__(**kwargs)
        
        self.combiner = combiner
        self.max_vocabulary_size_per_gpu = max_vocabulary_size_per_gpu
        self.embed_vec_size = embed_vec_size
        self.slot_num = slot_num
        self.max_nnz = max_nnz
        self.dense_dim = dense_dim
        
        self.embedding_layer = sok.DistributedEmbedding(combiner=self.combiner,
                                                        max_vocabulary_size_per_gpu=self.max_vocabulary_size_per_gpu,
                                                        embedding_vec_size=self.embed_vec_size,
                                                        slot_num=self.slot_num,
                                                        max_nnz=self.max_nnz)
        self.bot_nn = MLP(arch_bot, name = "bottom", out_activation='relu')
        self.top_nn = MLP(arch_top, name = "top", out_activation='sigmoid')
        self.interaction_op = SecondOrderFeatureInteraction(self_interaction)
        if self_interaction:
            self.interaction_out_dim = (self.slot_num+1) * (self.slot_num+2) // 2
        else:
            self.interaction_out_dim = self.slot_num * (self.slot_num+1) // 2
        self.reshape_layer1 = tf.keras.layers.Reshape((1, arch_bot[-1]), name = "reshape1")
        self.concat1 = tf.keras.layers.Concatenate(axis=1, name = "concat1")
        self.concat2 = tf.keras.layers.Concatenate(axis=1, name = "concat2")
            
    def call(self, inputs, training=True):
        input_cat = inputs[0]
        input_dense = inputs[1]
        
        embedding_vector = self.embedding_layer(input_cat, training=training)
        dense_x = self.bot_nn(input_dense)
        concat_features = self.concat1([embedding_vector, self.reshape_layer1(dense_x)])
        
        Z = self.interaction_op(concat_features)
        z = self.concat2([dense_x, Z])
        logit = self.top_nn(z)
        return logit, embedding_vector

    def summary(self):
        inputs = [tf.keras.Input(shape=(self.max_nnz, ), sparse=True, dtype=args["tf_key_type"]), 
                  tf.keras.Input(shape=(self.dense_dim, ), dtype=tf.float32)]
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()
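
To see what SecondOrderFeatureInteraction produces, the following plain-Python sketch computes the same upper-triangular pairwise dot products for a single sample. It is an illustration only, not the TensorFlow implementation above, and the function name `pairwise_interactions` is hypothetical.

```python
def pairwise_interactions(vectors, self_interaction=False):
    """Return dot products of feature-vector pairs (i, j) with i < j, or i <= j when self_interaction is set."""
    offset = 0 if self_interaction else 1
    out = []
    for i in range(len(vectors)):
        for j in range(i + offset, len(vectors)):
            out.append(sum(a * b for a, b in zip(vectors[i], vectors[j])))
    return out

# 26 slot embeddings plus 1 bottom-MLP output, each of dimension 16.
features = [[1.0] * 16 for _ in range(27)]
interactions = pairwise_interactions(features)
print(len(interactions))       # 27 * 26 // 2 = 351
print(len(interactions) + 16)  # width of the top-MLP input after concat2: 367
```
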
def train(args):
    dlrm = DLRM(combiner = args["combiner"],
                max_vocabulary_size_per_gpu = args["max_vocabulary_size"] // args["gpu_num"],
                embed_vec_size = args["embed_vec_size"],
                slot_num = args["slot_num"],
                max_nnz = args["max_nnz"],
                dense_dim = args["dense_dim"],
                arch_bot = [256, 128, args["embed_vec_size"]],
                arch_top = [256, 128, 1],
                self_interaction = False)

    emb_opt = utils.get_embedding_optimizer(args["optimizer"])(learning_rate=0.1)
    dense_opt = utils.get_dense_optimizer(args["optimizer"])(learning_rate=0.1)

    init_tensors = np.ones(shape=[args["max_vocabulary_size"], args["embed_vec_size"]], dtype=args["np_vector_type"])
    embedding_saver = sok.Saver()
    embedding_saver.load_embedding_values(dlrm.embedding_layer.embedding_variable, init_tensors)

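    # NOTE: top_nn already ends with a sigmoid, so the model output is a probability rather than a logit.
    # This is what triggers the Keras UserWarning about from_logits=True in the training output.
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)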

    @tf.function
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logit, embedding_vector = dlrm(inputs, training=True)
            loss = loss_fn(labels, logit)
        embedding_variables, other_variable = sok.split_embedding_variable_from_others(dlrm.trainable_variables)
        grads, emb_grads = tape.gradient(loss, [other_variable, embedding_variables])
        if 'plugin' not in args["optimizer"]:
            with sok.OptimizerScope(embedding_variables):
                emb_opt.apply_gradients(zip(emb_grads, embedding_variables),
                                        experimental_aggregate_gradients=False)
        else:
            emb_opt.apply_gradients(zip(emb_grads, embedding_variables),
                                    experimental_aggregate_gradients=False)
        dense_opt.apply_gradients(zip(grads, other_variable))
        return logit, embedding_vector, loss

    sparse_keys, dense_features, labels = generate_random_samples(args["global_batch_size"]  * args["iter_num"], args["vocabulary_range_per_slot"], args["max_nnz"], args["dense_dim"])
    dataset = tf_dataset(sparse_keys, dense_features, labels, args["global_batch_size"])
    for i, (sparse_keys, dense_features, labels) in enumerate(dataset):
        sparse_keys = tf.sparse.reshape(sparse_keys, [-1, sparse_keys.shape[-1]])
        inputs = [sparse_keys, dense_features]
        logit, embedding_vector, loss = _train_step(inputs, labels)
        print("-"*20, "Step {}, loss: {}".format(i, loss),  "-"*20)
    return dlrm, embedding_saver
sok.Init(global_batch_size=args["global_batch_size"])
trained_model, embedding_saver = train(args)
trained_model.summary()
2022-07-29 07:16:16.793169: I tensorflow/core/platform/cpu_feature_guard.cc:152] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-07-29 07:16:17.323141: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2022-07-29 07:16:17.323214: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30997 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc/kit_cc_infra/src/resources/manager.cc:107] Mapping from local_replica_id to device_id:
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc/kit_cc_infra/src/resources/manager.cc:109] 0 -> 0
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc/kit_cc_infra/src/resources/manager.cc:84] Global seed is 4287744788
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc/kit_cc_infra/src/resources/manager.cc:85] Local GPU Count: 1
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc/kit_cc_infra/src/resources/manager.cc:86] Global GPU Count: 1
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc/kit_cc_infra/src/resources/manager.cc:127] Global Replica Id: 0; Local Replica Id: 0
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc/kit_cc_infra/src/parameters/raw_manager.cc:132] Created embedding variable whose name is EmbeddingVariable
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc/kit_cc_infra/src/parameters/raw_param.cc:120] Variable: EmbeddingVariable on global_replica_id: 0 start initialization
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc/kit_cc_infra/src/parameters/raw_param.cc:137] Variable: EmbeddingVariable on global_replica_id: 0 initialization done.
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc/kit_cc_infra/src/facade.cc:257] SparseOperationKit allocated internal memory.
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc/kit_cc_infra/src/parameters/raw_manager.cc:225] Loading embedding values to Variable: EmbeddingVariable...
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc/kit_cc_infra/src/parameters/raw_param.cc:378] Allocated temporary buffer for loading embedding values.
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc_impl/embedding/common/src/dumping_functions.cc:299] num_total_keys = 260000, while total_max_vocabulary_size = 260000
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc_impl/embedding/common/src/dumping_functions.cc:350] Worker 0: Start uploading parameters. Total loop_num = 260
2022-07-29 07:16:17.078977: I sparse_operation_kit/kit_cc/kit_cc_infra/src/parameters/raw_manager.cc:235] Loaded embedding values to Variable: EmbeddingVariable.
/usr/local/lib/python3.8/dist-packages/tensorflow/python/util/dispatch.py:1082: UserWarning: "`binary_crossentropy` received `from_logits=True`, but the `output` argument was produced by a sigmoid or softmax activation and thus does not represent logits. Was this intended?"
  return dispatch_target(*args, **kwargs)
-------------------- Step 0, loss: 0.9379717111587524 --------------------
-------------------- Step 1, loss: 12726.013671875 --------------------
-------------------- Step 2, loss: 73.78772735595703 --------------------
-------------------- Step 3, loss: 71.33247375488281 --------------------
-------------------- Step 4, loss: 33.48320770263672 --------------------
-------------------- Step 5, loss: 234.79978942871094 --------------------
-------------------- Step 6, loss: 1.6663873195648193 --------------------
-------------------- Step 7, loss: 30.426162719726562 --------------------
-------------------- Step 8, loss: 2.430748462677002 --------------------
-------------------- Step 9, loss: 4.768443584442139 --------------------
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_2 (InputLayer)           [(None, 13)]         0           []                               
                                                                                                  
 bottom (MLP)                   (None, 16)           38544       ['input_2[0][0]']                
                                                                                                  
 input_1 (InputLayer)           [(None, 10)]         0           []                               
                                                                                                  
 distributed_embedding (Distrib  (None, 26, 16)      4160000     ['input_1[0][0]']                
 utedEmbedding)                                                                                   
                                                                                                  
 reshape1 (Reshape)             (None, 1, 16)        0           ['bottom[0][0]']                 
                                                                                                  
 concat1 (Concatenate)          (None, 27, 16)       0           ['distributed_embedding[0][0]',  
                                                                  'reshape1[0][0]']               
                                                                                                  
 second_order_feature_interacti  (None, None)        0           ['concat1[0][0]']                
 on (SecondOrderFeatureInteract                                                                   
 ion)                                                                                             
                                                                                                  
 concat2 (Concatenate)          (None, None)         0           ['bottom[0][0]',                 
                                                                  'second_order_feature_interactio
                                                                 n[0][0]']                        
                                                                                                  
 top (MLP)                      (None, 1)            127233      ['concat2[0][0]']                
                                                                                                  
==================================================================================================
Total params: 4,325,777
Trainable params: 4,325,777
Non-trainable params: 0
__________________________________________________________________________________________________
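
The parameter counts in the summary can be reproduced by hand. A Dense layer with n_in inputs and n_out units holds (n_in + 1) * n_out parameters (weights plus biases), and the embedding table holds max_vocabulary_size * embed_vec_size values. The sketch below uses a hypothetical helper, `mlp_params`, with the layer sizes taken from the train function.

```python
def mlp_params(input_dim, arch):
    """Count the weights and biases of a stack of fully connected layers."""
    total = 0
    for units in arch:
        total += (input_dim + 1) * units
        input_dim = units
    return total

embed_vec_size, slot_num, dense_dim = 16, 26, 13
num_feas = slot_num + 1                                       # 26 embeddings + bottom MLP output
top_input = embed_vec_size + num_feas * (num_feas - 1) // 2   # 16 + 351 = 367

bottom = mlp_params(dense_dim, [256, 128, embed_vec_size])
top = mlp_params(top_input, [256, 128, 1])
embedding = 260000 * embed_vec_size

print(bottom, top, embedding, bottom + top + embedding)
# 38544 127233 4160000 4325777
```
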
dense_model = tf.keras.Model([trained_model.get_layer("distributed_embedding").output,
                             trained_model.get_layer("bottom").input],
                             trained_model.get_layer("top").output)
dense_model.summary()
dense_model.save(args["dense_model_path"])
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_2 (InputLayer)           [(None, 13)]         0           []                               
                                                                                                  
 bottom (MLP)                   (None, 16)           38544       ['input_2[0][0]']                
                                                                                                  
 input_3 (InputLayer)           [(None, 26, 16)]     0           []                               
                                                                                                  
 reshape1 (Reshape)             (None, 1, 16)        0           ['bottom[1][0]']                 
                                                                                                  
 concat1 (Concatenate)          (None, 27, 16)       0           ['input_3[0][0]',                
                                                                  'reshape1[1][0]']               
                                                                                                  
 second_order_feature_interacti  (None, None)        0           ['concat1[1][0]']                
 on (SecondOrderFeatureInteract                                                                   
 ion)                                                                                             
                                                                                                  
 concat2 (Concatenate)          (None, None)         0           ['bottom[1][0]',                 
                                                                  'second_order_feature_interactio
                                                                 n[1][0]']                        
                                                                                                  
 top (MLP)                      (None, 1)            127233      ['concat2[1][0]']                
                                                                                                  
==================================================================================================
Total params: 165,777
Trainable params: 165,777
Non-trainable params: 0
__________________________________________________________________________________________________
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
2022-07-29 07:16:56.089529: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel.
WARNING:absl:Found untraced functions such as bottom_0_layer_call_fn, bottom_0_layer_call_and_return_conditional_losses, bottom_1_layer_call_fn, bottom_1_layer_call_and_return_conditional_losses, bottom_2_layer_call_fn while saving (showing 5 of 12). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: dlrm_dense.model/assets
!mkdir -p dlrm_sparse.model
embedding_saver.dump_to_file(trained_model.embedding_layer.embedding_variable, args["embedding_table_path"])
!mv dlrm_sparse.model/EmbeddingVariable_keys.file dlrm_sparse.model/key
!mv dlrm_sparse.model/EmbeddingVariable_values.file dlrm_sparse.model/emb_vector
!ls -l dlrm_sparse.model
2022-07-29 07:17:01.079021: I sparse_operation_kit/kit_cc/kit_cc_infra/src/parameters/raw_manager.cc:192] Saving EmbeddingVariable to dlrm_sparse.model..
2022-07-29 07:17:01.079021: I sparse_operation_kit/kit_cc_impl/embedding/common/src/dumping_functions.cc:60] Worker: 0, GPU: 0 key-index count = 260000
2022-07-29 07:17:01.079021: I sparse_operation_kit/kit_cc_impl/embedding/common/src/dumping_functions.cc:147] Worker: 0, GPU: 0: dumping parameters from hashtable..
2022-07-29 07:17:01.079021: I sparse_operation_kit/kit_cc/kit_cc_infra/src/parameters/raw_manager.cc:200] Saved EmbeddingVariable to dlrm_sparse.model.
total 18360
-rw-r--r-- 1 nobody nogroup 16640000 Jul 29 07:17 emb_vector
-rw-r--r-- 1 nobody nogroup  2080000 Jul 29 07:17 key
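
The file sizes are consistent with flat binary arrays: 260000 keys × 8 bytes (int64) = 2,080,000 bytes for key, and 260000 × 16 × 4 bytes (float32) = 16,640,000 bytes for emb_vector. Assuming that layout and native byte order (an inference from the sizes and the configured key and vector types, not a statement of the file specification), a few entries can be inspected with the struct module. The helper `read_embedding_entries` is hypothetical, and the demonstration uses a tiny synthetic table so that it runs without the trained files.

```python
import os
import struct
import tempfile

def read_embedding_entries(model_dir, embed_vec_size, count):
    """Read the first `count` (key, vector) pairs, assuming raw int64 keys and float32 vectors."""
    entries = []
    vec_bytes_len = 4 * embed_vec_size
    with open(os.path.join(model_dir, "key"), "rb") as key_file, \
         open(os.path.join(model_dir, "emb_vector"), "rb") as vec_file:
        for _ in range(count):
            key_bytes = key_file.read(8)
            vec_bytes = vec_file.read(vec_bytes_len)
            if len(key_bytes) < 8 or len(vec_bytes) < vec_bytes_len:
                break  # reached the end of one of the files
            (key,) = struct.unpack("=q", key_bytes)
            vector = struct.unpack("={}f".format(embed_vec_size), vec_bytes)
            entries.append((key, vector))
    return entries

# Synthetic two-entry table with 2-dimensional vectors (assumed layout, see above).
demo_dir = tempfile.mkdtemp()
with open(os.path.join(demo_dir, "key"), "wb") as f:
    f.write(struct.pack("=2q", 7, 42))
with open(os.path.join(demo_dir, "emb_vector"), "wb") as f:
    f.write(struct.pack("=4f", 1.0, 2.0, 3.0, 4.0))

print(read_embedding_entries(demo_dir, embed_vec_size=2, count=2))
# [(7, (1.0, 2.0)), (42, (3.0, 4.0))]
```

With the trained files, the equivalent call would be read_embedding_entries(args["embedding_table_path"], args["embed_vec_size"], 5).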

Create the inference graph with HPS SparseLookupLayer

To use HPS in the inference stage, we need to create an inference model graph that is almost identical to the training graph, except that sok.DistributedEmbedding is replaced by hps.SparseLookupLayer. The trained dense model graph can be loaded directly, while HPS retrieves the weights of the embedding table from the dlrm_sparse.model folder.

We can then save the inference model graph, which will be ready to be loaded for inference deployment.
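
Conceptually, a sparse lookup with the "mean" combiner averages the embedding vectors of the keys in each row of the 2D sparse input, which gives one vector per (sample, slot) row; the model then reshapes the result to [-1, slot_num, embed_vec_size]. The plain-Python sketch below illustrates that reduction only. It is not how HPS implements the lookup, and `mean_combine` and the toy table are hypothetical.

```python
def mean_combine(table, rows):
    """Average the embedding vectors of the keys in each row (every row must hold at least one key)."""
    combined = []
    for keys in rows:
        vectors = [table[key] for key in keys]
        dim = len(vectors[0])
        combined.append([sum(v[d] for v in vectors) / len(vectors) for d in range(dim)])
    return combined

# Toy table with 2-dimensional vectors; keys 0 and 1 belong to slot 0, key 10000 to slot 1.
table = {0: [1.0, 3.0], 1: [3.0, 5.0], 10000: [2.0, 2.0]}
rows = [[0, 1], [10000]]  # one row per (sample, slot) pair
print(mean_combine(table, rows))
# [[2.0, 4.0], [2.0, 2.0]]
```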

import hierarchical_parameter_server as hps

class InferenceModel(tf.keras.models.Model):
    def __init__(self,
                 slot_num,
                 embed_vec_size,
                 max_nnz,
                 dense_dim,
                 dense_model_path,
                 **kwargs):
        super(InferenceModel, self).__init__(**kwargs)
        
        self.slot_num = slot_num
        self.embed_vec_size = embed_vec_size
        self.max_nnz = max_nnz
        self.dense_dim = dense_dim
        
        self.sparse_lookup_layer = hps.SparseLookupLayer(model_name = "dlrm", 
                                            table_id = 0,
                                            emb_vec_size = self.embed_vec_size,
                                            emb_vec_dtype = args["tf_vector_type"])
        self.dense_model = tf.keras.models.load_model(dense_model_path)
    
    def call(self, inputs):
        input_cat = inputs[0]
        input_dense = inputs[1]

        embeddings = tf.reshape(self.sparse_lookup_layer(sp_ids=input_cat, sp_weights = None, combiner="mean"),
                                shape=[-1, self.slot_num, self.embed_vec_size])
        logit = self.dense_model([embeddings, input_dense])
        return logit, embeddings

    def summary(self):
        inputs = [tf.keras.Input(shape=(self.max_nnz, ), sparse=True, dtype=args["tf_key_type"]), 
                  tf.keras.Input(shape=(self.dense_dim, ), dtype=tf.float32)]
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()
[INFO] hierarchical_parameter_server is imported
def create_and_save_inference_graph(args): 
    model = InferenceModel(args["slot_num"], args["embed_vec_size"], args["max_nnz"], args["dense_dim"], args["dense_model_path"])
    model.summary()
    inputs = [tf.keras.Input(shape=(args["max_nnz"], ), sparse=True, dtype=args["tf_key_type"]), 
              tf.keras.Input(shape=(args["dense_dim"], ), dtype=tf.float32)]
    _, _ = model(inputs)
    model.save(args["saved_path"])
create_and_save_inference_graph(args)
2022-07-29 07:24:43.911439: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-07-29 07:24:44.490542: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30989 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 10)]         0           []                               
                                                                                                  
 sparse_lookup_layer (SparseLoo  (None, 16)          0           ['input_1[0][0]']                
 kupLayer)                                                                                        
                                                                                                  
 tf.reshape (TFOpLambda)        (None, 26, 16)       0           ['sparse_lookup_layer[0][0]']    
                                                                                                  
 input_2 (InputLayer)           [(None, 13)]         0           []                               
                                                                                                  
 model_1 (Functional)           (None, 1)            165777      ['tf.reshape[0][0]',             
                                                                  'input_2[0][0]']                
                                                                                                  
==================================================================================================
Total params: 165,777
Trainable params: 165,777
Non-trainable params: 0
__________________________________________________________________________________________________
2022-07-29 07:24:48.043599: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_3 in the SavedModel.
WARNING:absl:Found untraced functions such as bottom_0_layer_call_fn, bottom_0_layer_call_and_return_conditional_losses, bottom_1_layer_call_fn, bottom_1_layer_call_and_return_conditional_losses, bottom_2_layer_call_fn while saving (showing 5 of 12). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: dlrm_tf_saved_model/assets

Inference with saved model graph

To initialize the lookup service provided by HPS, we also need to create a JSON configuration file that specifies the details of the embedding tables for the models to be deployed. Here we deploy a single DLRM model with one embedding table, although HPS can serve multiple models with multiple embedding tables. Note how maxnum_catfeature_query_per_table_per_sample is specified for the embedding table: max_nnz is 10 for every slot and there are 26 slots, so this entry is set to 10 × 26 = 260.

We first call hps.Init to perform the necessary initialization, and then load the saved model graph to run inference. Finally, we peek at the input keys and the looked-up embedding vectors of the last inference batch.

%%writefile dlrm.json
{
    "supportlonglong": true,
    "models": [{
        "model": "dlrm",
        "sparse_files": ["dlrm_sparse.model"],
        "num_of_worker_buffer_in_pool": 3,
        "embedding_table_names":["sparse_embedding0"],
        "embedding_vecsize_per_table": [16],
        "maxnum_catfeature_query_per_table_per_sample": [260],
        "default_value_for_each_table": [1.0],
        "deployed_device_list": [0],
        "max_batch_size": 1024,
        "cache_refresh_percentage_per_iteration": 0.2,
        "hit_rate_threshold": 1.0,
        "gpucacheper": 1.0,
        "gpucache": true
        }
    ]
}
Overwriting dlrm.json
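
The same configuration can also be generated from the training arguments, so that derived values such as maxnum_catfeature_query_per_table_per_sample = slot_num × max_nnz stay in sync with the model. The following is a sketch using only the json module; the field names and constants are copied from the file above, while the function name `build_ps_config` and the `demo_args` dictionary (a restatement of the relevant args entries) are hypothetical.

```python
import json

def build_ps_config(args):
    """Build the HPS configuration dictionary for a single model with one embedding table."""
    return {
        "supportlonglong": True,
        "models": [{
            "model": "dlrm",
            "sparse_files": [args["embedding_table_path"]],
            "num_of_worker_buffer_in_pool": 3,
            "embedding_table_names": ["sparse_embedding0"],
            "embedding_vecsize_per_table": [args["embed_vec_size"]],
            # Each sample queries at most slot_num * max_nnz keys from this table.
            "maxnum_catfeature_query_per_table_per_sample": [args["slot_num"] * args["max_nnz"]],
            "default_value_for_each_table": [1.0],
            "deployed_device_list": [0],
            "max_batch_size": args["global_batch_size"],
            "cache_refresh_percentage_per_iteration": 0.2,
            "hit_rate_threshold": 1.0,
            "gpucacheper": 1.0,
            "gpucache": True,
        }],
    }

demo_args = {
    "embedding_table_path": "dlrm_sparse.model",
    "embed_vec_size": 16,
    "slot_num": 26,
    "max_nnz": 10,
    "global_batch_size": 1024,
}
config = build_ps_config(demo_args)
print(json.dumps(config, indent=4))
```

To write the file, pass the dictionary to json.dump together with a file opened for writing at args["ps_config_file"].
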
def inference_with_saved_model(args):
    hps.Init(global_batch_size = args["global_batch_size"],
             ps_config_file = args["ps_config_file"])
    model = tf.keras.models.load_model(args["saved_path"])
    model.summary()
    def _infer_step(inputs, labels):
        logit, embeddings = model(inputs)
        return logit, embeddings
    
    embeddings_peek = list()
    inputs_peek = list()
    
    sparse_keys, dense_features, labels = generate_random_samples(args["global_batch_size"]  * args["iter_num"], args["vocabulary_range_per_slot"], args["max_nnz"], args["dense_dim"])
    dataset = tf_dataset(sparse_keys, dense_features, labels, args["global_batch_size"])
    for i, (sparse_keys, dense_features, labels) in enumerate(dataset):
        sparse_keys = tf.sparse.reshape(sparse_keys, [-1, sparse_keys.shape[-1]])
        inputs = [sparse_keys, dense_features]
        logit, embeddings = _infer_step(inputs, labels)
        embeddings_peek.append(embeddings)
        inputs_peek.append(inputs)
        print("-"*20, "Step {}".format(i),  "-"*20)
    return embeddings_peek, inputs_peek
embeddings_peek, inputs_peek = inference_with_saved_model(args)

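# Peek at the last batch: the input keys are a SparseTensor, so print their values,
# followed by the embedding vectors looked up from the embedding table.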
print(inputs_peek[-1][0].values)
print(embeddings_peek[-1])
=====================================================HPS Parse====================================================
[HCTR][07:24:53.183][INFO][RK0][main]: dense_file is not specified using default: 
[HCTR][07:24:53.183][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
[HCTR][07:24:53.183][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][07:24:53.183][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][07:24:53.183][INFO][RK0][main]: refresh_interval is not specified using default: 0
====================================================HPS Create====================================================
[HCTR][07:24:53.184][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][07:24:53.184][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][07:24:53.184][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][07:24:53.682][INFO][RK0][main]: Table: hps_et.dlrm.sparse_embedding0; cached 260000 / 260000 embeddings in volatile database (PreallocatedHashMapBackend); load: 260000 / 18446744073709551615 (0.00%).
[HCTR][07:24:53.682][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][07:24:53.682][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][07:24:53.689][INFO][RK0][main]: Model name: dlrm
[HCTR][07:24:53.689][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][07:24:53.689][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][07:24:53.689][INFO][RK0][main]: Use I64 input key: True
[HCTR][07:24:53.689][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][07:24:53.689][INFO][RK0][main]: The size of thread pool: 80
[HCTR][07:24:53.689][INFO][RK0][main]: The size of worker memory pool: 3
[HCTR][07:24:53.689][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][07:24:53.736][INFO][RK0][main]: Creating lookup session for dlrm on device: 0
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Model: "inference_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 sparse_lookup_layer (Sparse  multiple                 0         
 LookupLayer)                                                    
                                                                 
 model_1 (Functional)        (None, 1)                 165777    
                                                                 
=================================================================
Total params: 165,777
Trainable params: 165,777
Non-trainable params: 0
_________________________________________________________________
-------------------- Step 0 --------------------
-------------------- Step 1 --------------------
-------------------- Step 2 --------------------
-------------------- Step 3 --------------------
-------------------- Step 4 --------------------
-------------------- Step 5 --------------------
-------------------- Step 6 --------------------
-------------------- Step 7 --------------------
-------------------- Step 8 --------------------
-------------------- Step 9 --------------------
tf.Tensor([   888   4486   5745 ... 255671 252879 252045], shape=(145888,), dtype=int64)
tf.Tensor(
[[[0.6825647  0.6801282  0.68074    ... 0.68074226 0.6818684  0.6809397 ]
  [1.3980061  1.3981627  1.3980061  ... 1.3980992  1.3980061  1.3980061 ]
  [0.78289294 0.7833897  0.78293324 ... 0.78336245 0.78305507 0.78301686]
  ...
  [0.880705   0.88164043 0.88109225 ... 0.87982655 0.88028604 0.88119066]
  [0.8650326  0.86442304 0.86414057 ... 0.8642554  0.8640611  0.8645548 ]
  [0.783202   0.78315204 0.78240466 ... 0.7826805  0.78258413 0.7824805 ]]

 [[0.8573375  0.85796195 0.85979205 ... 0.8595341  0.85846806 0.85798156]
  [0.7563881  0.7563928  0.7564304  ... 0.7563316  0.7563634  0.7564283 ]
  [0.62020814 0.6213356  0.62018126 ... 0.62036    0.6201106  0.6201722 ]
  ...
  [0.85459447 0.85330284 0.854774   ... 0.854769   0.8547034  0.85447353]
  [0.64481944 0.6447684  0.6449137  ... 0.64472693 0.64465916 0.64503783]
  [0.7852191  0.78577    0.78521436 ... 0.7852911  0.78544927 0.7853453 ]]

 [[0.6184057  0.61849916 0.61735946 ... 0.61852926 0.61921203 0.6175788 ]
  [0.7092892  0.7092928  0.7092843  ... 0.70928746 0.70928514 0.70928574]
  [0.6360293  0.6360285  0.636029   ... 0.63602984 0.63602865 0.63602734]
  ...
  [0.69062346 0.69038725 0.690281   ... 0.6907744  0.6904431  0.6903974 ]
  [0.6840397  0.684031   0.68404853 ... 0.6840508  0.68404937 0.68404216]
  [0.7159784  0.71973306 0.7159706  ... 0.7161063  0.71603465 0.71592766]]

 ...

 [[0.67292804 0.67351913 0.67328465 ... 0.67328894 0.6733438  0.67301095]
  [0.68593156 0.6859398  0.68593466 ... 0.6859294  0.6859311  0.68593705]
  [0.72352993 0.7230278  0.72331727 ... 0.72321206 0.72359455 0.7233958 ]
  ...
  [0.60178    0.6017275  0.60140777 ... 0.60140765 0.60151523 0.6015818 ]
  [0.73245263 0.73322636 0.7328412  ... 0.73278296 0.7325789  0.7329973 ]
  [0.68950844 0.69225705 0.6898281  ... 0.6889306  0.68944615 0.69020116]]

 [[0.848309   0.84465414 0.84872234 ... 0.8486877  0.84938526 0.8492384 ]
  [0.701107   0.6997489  0.70110285 ... 0.700902   0.7011098  0.70111394]
  [0.5723409  0.5738345  0.5723305  ... 0.57233423 0.57233775 0.572342  ]
  ...
  [0.82768726 0.82793933 0.8282728  ... 0.8282294  0.82802093 0.8280283 ]
  [0.6491487  0.64926434 0.64963746 ... 0.64926565 0.64935625 0.64957225]
  [0.5615084  0.56340796 0.5635457  ... 0.5635438  0.5613529  0.56135494]]

 [[0.9477315  0.94783926 0.94776624 ... 0.9477597  0.9477446  0.9477345 ]
  [0.74906373 0.7491199  0.74906075 ... 0.7490612  0.7490609  0.7490617 ]
  [0.6141995  0.6144503  0.6139838  ... 0.6140719  0.6141932  0.61409426]
  ...
  [0.6773844  0.67902935 0.67736465 ... 0.6773715  0.6773739  0.67744035]
  [0.700472   0.70258003 0.69977176 ... 0.70001334 0.69977176 0.69977176]
  [0.75941193 0.7594471  0.75891864 ... 0.7593392  0.75900066 0.75923026]]], shape=(1024, 26, 16), dtype=float32)
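The shapes in the output above can be cross-checked against the configuration with a little arithmetic. The sketch below uses only numbers that appear in the printed output and in the configuration. The reading of the embedding shape as (batch, slots, embedding dimension) is inferred from those numbers rather than stated by HPS.

```python
# Numbers taken from the printed output and the configuration above.
batch_size = 1024       # first dimension of the embedding tensor
slot_num = 26           # feature fields per sample
embed_vec_size = 16     # embedding_vecsize_per_table
max_nnz = 10            # maximum number of keys per slot
total_keys = 145888     # shape=(145888,) of the sparse key values

# One embedding vector per slot per sample.
expected_embedding_shape = (batch_size, slot_num, embed_vec_size)

# Number of (sample, slot) pairs in the batch, i.e. rows of the 2-D sparse key tensor.
rows = batch_size * slot_num

# The total number of keys in a batch can never exceed rows * max_nnz.
assert total_keys <= rows * max_nnz

# Average number of keys actually present per slot in this batch.
avg_nnz = total_keys / rows

print(expected_embedding_shape)
print(rows, rows * max_nnz)
print(round(avg_nnz, 2))
```

The per-sample bound rows * max_nnz / batch_size equals 260, the value configured for maxnum_catfeature_query_per_table_per_sample.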