# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
http://developer.download.nvidia.com/notebooks/dlsw-notebooks/merlin_hugectr_hps-hps-pretrained-model-training-demo/nvidia_logo.png

HPS Pretrained Model Training Demo

Overview

This notebook demonstrates how to use HPS to load pre-trained embedding tables. We recommend running hierarchical_parameter_server_demo.ipynb before working through this notebook.

For more details about HPS APIs, please refer to HPS APIs. For more details about HPS, please refer to HugeCTR Hierarchical Parameter Server (HPS).

Installation

Get HPS from NGC

The HPS Python module is preinstalled in the Merlin TensorFlow container, version 23.04 and later: nvcr.io/nvidia/merlin/merlin-tensorflow:23.04.

After launching the container, you can verify that the required library is available by running the following command:

$ python3 -c "import hierarchical_parameter_server as hps"

Configurations

First, we specify the required configurations, e.g., the arguments for generating the dataset, the model parameters, and the paths where the model is saved. We use a deep neural network (DNN) model that has one embedding table and several dense layers. Note that the input to the embedding layer is a sparse key tensor.

import hierarchical_parameter_server as hps
import os
import numpy as np
import tensorflow as tf
import struct

args = dict()

args["gpu_num"] = 4                               # the number of available GPUs
args["iter_num"] = 10                             # the number of training iterations
args["slot_num"] = 10                             # the number of feature fields in this embedding layer
args["embed_vec_size"] = 16                       # the dimension of embedding vectors (keep in sync with "embedding_vecsize_per_table" in dnn.json)
args["dense_dim"] = 10                            # the dimension of dense features
args["global_batch_size"] = 1024                  # the global batch size across all GPUs
args["max_vocabulary_size"] = 100000              # the total number of rows in the embedding table
# One contiguous, non-overlapping [low, high) key range per slot, derived from
# slot_num and max_vocabulary_size so the two settings cannot drift apart.
# With the values above this is [[0, 10000], [10000, 20000], ..., [90000, 100000]].
args["vocabulary_range_per_slot"] = [
    [i * (args["max_vocabulary_size"] // args["slot_num"]),
     (i + 1) * (args["max_vocabulary_size"] // args["slot_num"])]
    for i in range(args["slot_num"])
]
args["max_nnz"] = 5                               # the max number of non-zeros per slot (slot_num * max_nnz should match "maxnum_catfeature_query_per_table_per_sample" in dnn.json)
args["combiner"] = "mean"                         # how the keys within one slot are pooled by the embedding lookup

args["ps_config_file"] = "dnn.json"
args["dense_model_path"] = "dnn_dense.model"
args["embedding_table_path"] = "dnn_sparse.model"  # should match "sparse_files" in dnn.json
args["saved_path"] = "dnn_tf_saved_model"
args["np_key_type"] = np.int64
args["np_vector_type"] = np.float32
args["tf_key_type"] = tf.int64
args["tf_vector_type"] = tf.float32

# Expose GPUs 0 .. gpu_num-1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(args["gpu_num"])))
[INFO] hierarchical_parameter_server is imported
def generate_random_samples(num_samples, vocabulary_range_per_slot, max_nnz, dense_dim):
    """Generate a synthetic dataset of sparse keys, dense features and binary labels.

    Returns ``(sparse_keys, dense_features, labels)``:

    * ``sparse_keys`` -- a ``tf.sparse.SparseTensor`` with dense shape
      ``(num_samples, len(vocabulary_range_per_slot), max_nnz)``. Every
      (sample, slot) pair holds between 1 and ``max_nnz`` keys, placed at
      randomly chosen, distinct, ascending positions, and drawn from that
      slot's ``[low, high)`` range.
    * ``dense_features`` -- float32 array of shape ``(num_samples, dense_dim)``
      with values in [0, 1).
    * ``labels`` -- integer array of shape ``(num_samples, 1)`` holding 0 or 1.
    """
    def generate_sparse_keys(num_samples, vocabulary_range_per_slot, max_nnz, key_dtype = args["np_key_type"]):
        num_slots = len(vocabulary_range_per_slot)
        coords = []
        keys = []
        for sample_id in range(num_samples):
            for slot_id in range(num_slots):
                key_range = vocabulary_range_per_slot[slot_id]
                low, high = key_range[0], key_range[1]
                count = np.random.randint(low=1, high=max_nnz+1)
                # Distinct positions within the max_nnz-wide row; sorted so the
                # indices come out in the row-major order SparseTensor expects.
                positions = sorted(np.random.choice(max_nnz, count, replace=False))
                coords.extend([sample_id, slot_id, pos] for pos in positions)
                keys.extend(np.random.randint(low=low, high=high, size=(count, )))
        return tf.sparse.SparseTensor(indices=coords,
                                      values=np.array(keys, dtype=key_dtype),
                                      dense_shape=(num_samples, num_slots, max_nnz))

    sparse_keys = generate_sparse_keys(num_samples, vocabulary_range_per_slot, max_nnz)
    dense_features = np.random.random((num_samples, dense_dim)).astype(np.float32)
    labels = np.random.randint(low=0, high=2, size=(num_samples, 1))
    return sparse_keys, dense_features, labels

def tf_dataset(sparse_keys, dense_features, labels, batchsize):
    """Wrap the samples in a ``tf.data.Dataset`` batched by ``batchsize``.

    Each element is a ``(sparse_keys, dense_features, labels)`` tuple sliced
    along the first axis. A trailing partial batch is dropped, so every batch
    has exactly ``batchsize`` rows.
    """
    return tf.data.Dataset.from_tensor_slices(
        (sparse_keys, dense_features, labels)
    ).batch(batchsize, drop_remainder=True)

Train with native TF layers

We define the model graph for training with native TF ops and layers, i.e., tf.nn.embedding_lookup_sparse and tf.keras.layers.Dense, and store the embedding weights in a tf.Variable. We then train the model and extract the trained weights of the embedding table.

class DNN(tf.keras.models.Model):
    """Baseline model: a trainable embedding table followed by three dense layers.

    The embedding table is a plain ``tf.Variable`` built from ``init_tensors``
    (concatenated along axis 0) and read with ``tf.nn.embedding_lookup_sparse``.
    The pooled per-slot embeddings are flattened, concatenated with the dense
    features and fed through fc1 (1024, relu) -> fc2 (256, relu) -> fc3 (1, sigmoid).
    """

    def __init__(self,
                 init_tensors,
                 combiner,
                 embed_vec_size,
                 slot_num,
                 max_nnz,
                 dense_dim,
                 **kwargs):
        super().__init__(**kwargs)

        # Hyper-parameters used by call() and summary().
        self.combiner = combiner
        self.embed_vec_size = embed_vec_size
        self.slot_num = slot_num
        self.max_nnz = max_nnz
        self.dense_dim = dense_dim

        initial_table = tf.concat(init_tensors, axis=0)
        self.params = tf.Variable(initial_value=initial_table)
        self.fc1 = tf.keras.layers.Dense(units=1024, activation="relu", name="fc1")
        self.fc2 = tf.keras.layers.Dense(units=256, activation="relu", name="fc2")
        self.fc3 = tf.keras.layers.Dense(units=1, activation="sigmoid", name="fc3")

    def call(self, inputs, training=True):
        """Return ``(probability, flat_embeddings)`` for ``[sparse_keys, dense_features]``.

        ``sparse_keys`` is expected to be a SparseTensor with one row per
        (sample, slot), i.e. shape (batch_size * slot_num, max_nnz); the training
        loop reshapes its input that way before calling the model.
        """
        sparse_ids = inputs[0]
        dense_features = inputs[1]

        # One pooled vector per (sample, slot) row ...
        pooled = tf.nn.embedding_lookup_sparse(params=self.params, sp_ids=sparse_ids,
                                               sp_weights=None, combiner=self.combiner)
        # ... regrouped so every sample gets slot_num * embed_vec_size features.
        flat_embeddings = tf.reshape(pooled, shape=[-1, self.slot_num * self.embed_vec_size])
        features = tf.concat([flat_embeddings, dense_features], axis=1)
        hidden = self.fc1(features)
        hidden = self.fc2(hidden)
        probability = self.fc3(hidden)  # sigmoid output, not a raw logit
        return probability, flat_embeddings

    def summary(self):
        """Print a Keras summary by tracing call() on symbolic inputs."""
        symbolic_inputs = [
            tf.keras.Input(shape=(self.max_nnz, ), sparse=True, dtype=args["tf_key_type"]),
            tf.keras.Input(shape=(self.dense_dim, ), dtype=tf.float32),
        ]
        traced = tf.keras.models.Model(inputs=symbolic_inputs, outputs=self.call(symbolic_inputs))
        return traced.summary()
def train(args):
    """Train the baseline DNN on synthetic data under MirroredStrategy and return it.

    Every embedding row starts as ones, and the whole model, including the
    embedding table, is optimised with Adam (lr=0.1). The function runs one pass
    over a freshly generated dataset of ``global_batch_size * iter_num`` samples
    and prints the per-replica loss after each step.
    """
    initial_table = np.ones(shape=[args["max_vocabulary_size"], args["embed_vec_size"]], dtype=args["np_vector_type"])
    strategy = tf.distribute.MirroredStrategy()
    # Model and optimizer are created under the strategy scope so their
    # variables are mirrored across the replicas.
    with strategy.scope():
        model = DNN(initial_table, args["combiner"], args["embed_vec_size"], args["slot_num"], args["max_nnz"], args["dense_dim"])
        model.summary()
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)

    # Unreduced per-example loss; it is averaged over the *global* batch below so
    # that summing the replicas' gradients yields the global-batch gradient.
    per_example_bce = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)

    def _flatten_slots(keys):
        # (batch, slot_num, max_nnz) -> (batch * slot_num, max_nnz): one row per slot.
        return tf.sparse.reshape(keys, [-1, keys.shape[-1]])

    def _step(features, targets):
        with tf.GradientTape() as tape:
            predictions, _ = model(features)
            step_loss = tf.nn.compute_average_loss(per_example_bce(targets, predictions),
                                                   global_batch_size=args["global_batch_size"])
        gradients = tape.gradient(step_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return predictions, step_loss

    def _dataset_fn(input_context):
        per_replica_batch = input_context.get_per_replica_batch_size(args["global_batch_size"])
        keys, dense, targets = generate_random_samples(args["global_batch_size"] * args["iter_num"], args["vocabulary_range_per_slot"], args["max_nnz"], args["dense_dim"])
        batched = tf_dataset(keys, dense, targets, per_replica_batch)
        return batched.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)

    distributed = strategy.distribute_datasets_from_function(_dataset_fn)
    for step, (sparse_keys, dense_features, labels) in enumerate(distributed):
        flat_keys = strategy.run(_flatten_slots, args=(sparse_keys,))
        _, loss = strategy.run(_step, args=([flat_keys, dense_features], labels))
        print("-"*20, f"Step {step}, loss: {loss}", "-"*20)
    return model
trained_model = train(args)
# All model weights as NumPy arrays, kept for inspection.
weights_list = trained_model.get_weights()
# Read the embedding table directly from the model's `params` variable rather
# than relying on its position in get_weights() (weights_list[-1]), which
# depends on Keras' internal weight ordering.
embedding_weights = trained_model.params.numpy()
2022-07-29 06:41:55.554588: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-07-29 06:41:57.606412: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30989 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
2022-07-29 06:41:57.608128: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 30989 MB memory:  -> device: 1, name: Tesla V100-SXM2-32GB, pci bus id: 0000:07:00.0, compute capability: 7.0
2022-07-29 06:41:57.609468: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 30989 MB memory:  -> device: 2, name: Tesla V100-SXM2-32GB, pci bus id: 0000:0a:00.0, compute capability: 7.0
2022-07-29 06:41:57.610818: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 30989 MB memory:  -> device: 3, name: Tesla V100-SXM2-32GB, pci bus id: 0000:0b:00.0, compute capability: 7.0
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
WARNING:tensorflow:The following Variables were used in a Lambda layer's call (tf.compat.v1.nn.embedding_lookup_sparse), but are not present in its tracked objects:   <tf.Variable 'Variable:0' shape=(100000, 16) dtype=float32>. This is a strong indication that the Lambda layer should be rewritten as a subclassed Layer.
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 5)]          0           []                               
                                                                                                  
 tf.compat.v1.nn.embedding_look  (None, 16)          0           ['input_1[0][0]']                
 up_sparse (TFOpLambda)                                                                           
                                                                                                  
 tf.reshape (TFOpLambda)        (None, 160)          0           ['tf.compat.v1.nn.embedding_looku
                                                                 p_sparse[0][0]']                 
                                                                                                  
 input_2 (InputLayer)           [(None, 10)]         0           []                               
                                                                                                  
 tf.concat (TFOpLambda)         (None, 170)          0           ['tf.reshape[0][0]',             
                                                                  'input_2[0][0]']                
                                                                                                  
 fc1 (Dense)                    (None, 1024)         175104      ['tf.concat[0][0]']              
                                                                                                  
 fc2 (Dense)                    (None, 256)          262400      ['fc1[0][0]']                    
                                                                                                  
 fc3 (Dense)                    (None, 1)            257         ['fc2[0][0]']                    
                                                                                                  
==================================================================================================
Total params: 437,761
Trainable params: 437,761
Non-trainable params: 0
__________________________________________________________________________________________________
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
/usr/local/lib/python3.8/dist-packages/tensorflow/python/util/dispatch.py:1082: UserWarning: "`binary_crossentropy` received `from_logits=True`, but the `output` argument was produced by a sigmoid or softmax activation and thus does not represent logits. Was this intended?"
  return dispatch_target(*args, **kwargs)
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
WARNING:tensorflow:Efficient allreduce is not supported for 1 IndexedSlices
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').
-------------------- Step 0, loss: PerReplica:{
  0: tf.Tensor(0.1950232, shape=(), dtype=float32),
  1: tf.Tensor(0.20766959, shape=(), dtype=float32),
  2: tf.Tensor(0.2006835, shape=(), dtype=float32),
  3: tf.Tensor(0.21188965, shape=(), dtype=float32)
} --------------------
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
WARNING:tensorflow:Efficient allreduce is not supported for 1 IndexedSlices
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').
-------------------- Step 1, loss: PerReplica:{
  0: tf.Tensor(681.73474, shape=(), dtype=float32),
  1: tf.Tensor(691.33826, shape=(), dtype=float32),
  2: tf.Tensor(588.15265, shape=(), dtype=float32),
  3: tf.Tensor(622.72485, shape=(), dtype=float32)
} --------------------
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
WARNING:tensorflow:Efficient allreduce is not supported for 1 IndexedSlices
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').
-------------------- Step 2, loss: PerReplica:{
  0: tf.Tensor(6.9260483, shape=(), dtype=float32),
  1: tf.Tensor(8.509967, shape=(), dtype=float32),
  2: tf.Tensor(7.0374002, shape=(), dtype=float32),
  3: tf.Tensor(7.1059036, shape=(), dtype=float32)
} --------------------
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
-------------------- Step 3, loss: PerReplica:{
  0: tf.Tensor(3.002458, shape=(), dtype=float32),
  1: tf.Tensor(3.7079678, shape=(), dtype=float32),
  2: tf.Tensor(3.333396, shape=(), dtype=float32),
  3: tf.Tensor(3.6451607, shape=(), dtype=float32)
} --------------------
WARNING:tensorflow:5 out of the last 5 calls to <function _apply_all_reduce.<locals>._all_reduce at 0x7fba4c2dc1f0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
-------------------- Step 4, loss: PerReplica:{
  0: tf.Tensor(0.8326673, shape=(), dtype=float32),
  1: tf.Tensor(0.79405844, shape=(), dtype=float32),
  2: tf.Tensor(0.85364443, shape=(), dtype=float32),
  3: tf.Tensor(0.92679256, shape=(), dtype=float32)
} --------------------
WARNING:tensorflow:6 out of the last 6 calls to <function _apply_all_reduce.<locals>._all_reduce at 0x7fba4c2dcdc0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
-------------------- Step 5, loss: PerReplica:{
  0: tf.Tensor(0.5796976, shape=(), dtype=float32),
  1: tf.Tensor(0.54752666, shape=(), dtype=float32),
  2: tf.Tensor(0.57471323, shape=(), dtype=float32),
  3: tf.Tensor(0.54845804, shape=(), dtype=float32)
} --------------------
-------------------- Step 6, loss: PerReplica:{
  0: tf.Tensor(0.61678064, shape=(), dtype=float32),
  1: tf.Tensor(0.647662, shape=(), dtype=float32),
  2: tf.Tensor(0.6421599, shape=(), dtype=float32),
  3: tf.Tensor(0.6278339, shape=(), dtype=float32)
} --------------------
-------------------- Step 7, loss: PerReplica:{
  0: tf.Tensor(0.28049487, shape=(), dtype=float32),
  1: tf.Tensor(0.2768654, shape=(), dtype=float32),
  2: tf.Tensor(0.2943622, shape=(), dtype=float32),
  3: tf.Tensor(0.2805586, shape=(), dtype=float32)
} --------------------
-------------------- Step 8, loss: PerReplica:{
  0: tf.Tensor(1.2102679, shape=(), dtype=float32),
  1: tf.Tensor(1.368755, shape=(), dtype=float32),
  2: tf.Tensor(1.4997649, shape=(), dtype=float32),
  3: tf.Tensor(1.5143406, shape=(), dtype=float32)
} --------------------
-------------------- Step 9, loss: PerReplica:{
  0: tf.Tensor(0.413176, shape=(), dtype=float32),
  1: tf.Tensor(0.42411563, shape=(), dtype=float32),
  2: tf.Tensor(0.38453132, shape=(), dtype=float32),
  3: tf.Tensor(0.4314984, shape=(), dtype=float32)
} --------------------

Load the pre-trained embeddings via HPS

To load the pre-trained embeddings with HPS, we first convert them to the format required by HPS. We can then train a new model that uses the pre-trained embeddings and updates only the weights of its dense layers. Note that hps.SparseLookupLayer and hps.LookupLayer are not trainable.

To initialize the lookup service provided by HPS, we also need to create a JSON configuration file that specifies the details of the embedding tables for the models to be deployed. Here we deploy a single model with one embedding table, but HPS also supports multiple models, each with multiple embedding tables. Note how maxnum_catfeature_query_per_table_per_sample is set for the embedding table: max_nnz is 5 for every slot and there are 10 slots, so this entry is configured as 5 × 10 = 50.

def convert_to_sparse_model(embeddings_weights, embedding_table_path, embedding_vec_size):
    """Write an embedding table to disk as the raw key/vector files HPS loads.

    Creates ``embedding_table_path``, including any missing parent directories,
    and writes two raw binary files into it:

    * ``key`` -- the row indices ``0 .. N-1`` as native-endian int64 values;
    * ``emb_vector`` -- the rows of ``embeddings_weights`` as native-endian
      float32 values, ``embedding_vec_size`` per key, in the same order.

    Args:
        embeddings_weights: 2-D array-like of shape (N, embedding_vec_size);
            row ``i`` is the embedding vector for key ``i``.
        embedding_table_path: directory in which the two files are created or
            overwritten.
        embedding_vec_size: expected width of each embedding vector.

    Raises:
        ValueError: if ``embeddings_weights`` is not 2-D with exactly
            ``embedding_vec_size`` columns. The check runs before anything is
            written.
    """
    weights = np.asarray(embeddings_weights, dtype=np.float32)
    if weights.ndim != 2 or weights.shape[1] != embedding_vec_size:
        raise ValueError(
            "expected embeddings of shape (N, {}), got {}".format(embedding_vec_size, weights.shape))

    # os.makedirs instead of shelling out to `mkdir -p`: no shell is involved,
    # so paths containing spaces or shell metacharacters are handled correctly.
    os.makedirs(embedding_table_path, exist_ok=True)

    keys = np.arange(weights.shape[0], dtype=np.int64)
    # tofile() writes raw native-endian bytes in C order. That is byte-for-byte
    # what struct.pack('q', key) / struct.pack('<n>f', *row) produce per row,
    # but as one bulk write instead of one Python call per key.
    with open(os.path.join(embedding_table_path, "key"), "wb") as key_file, \
            open(os.path.join(embedding_table_path, "emb_vector"), "wb") as vec_file:
        keys.tofile(key_file)
        weights.tofile(vec_file)
%%writefile dnn.json
{
    "supportlonglong": true,
    "models": [{
        "model": "dnn",
        "sparse_files": ["dnn_sparse.model"],
        "num_of_worker_buffer_in_pool": 3,
        "embedding_table_names":["sparse_embedding0"],
        "embedding_vecsize_per_table": [16],
        "maxnum_catfeature_query_per_table_per_sample": [50],
        "default_value_for_each_table": [1.0],
        "deployed_device_list": [0,1,2,3],
        "max_batch_size": 1024,
        "cache_refresh_percentage_per_iteration": 0.2,
        "hit_rate_threshold": 1.0,
        "gpucacheper": 1.0,
        "gpucache": true
        }
    ]
}
Overwriting dnn.json
class PreTrainedEmbedding(tf.keras.models.Model):
    """Model that reads pre-trained embeddings via HPS and trains a single dense layer.

    Embeddings come from table 0 of the HPS model ``"dnn"`` through
    ``hps.SparseLookupLayer``, which has no trainable weights. The flattened
    embeddings are concatenated with the dense features and fed to one
    ``Dense(1, sigmoid)`` layer, the only trainable part of the model.
    """

    def __init__(self,
                 combiner,
                 embed_vec_size,
                 slot_num,
                 max_nnz,
                 dense_dim,
                 **kwargs):
        super().__init__(**kwargs)

        # Hyper-parameters used by call() and summary().
        self.combiner = combiner
        self.embed_vec_size = embed_vec_size
        self.slot_num = slot_num
        self.max_nnz = max_nnz
        self.dense_dim = dense_dim

        self.sparse_lookup_layer = hps.SparseLookupLayer(
            model_name="dnn",
            table_id=0,
            emb_vec_size=self.embed_vec_size,
            emb_vec_dtype=args["tf_vector_type"],
        )
        # A single FC layer is enough on top of the pre-trained embeddings.
        self.new_fc = tf.keras.layers.Dense(units=1, activation="sigmoid", name="new_fc")

    def call(self, inputs, training=True):
        """Return ``(probability, flat_embeddings)`` for ``[sparse_keys, dense_features]``.

        ``sparse_keys`` is expected to be a SparseTensor with one row per
        (sample, slot), i.e. shape (batch_size * slot_num, max_nnz); the training
        loop reshapes its input that way before calling the model.
        """
        sparse_ids = inputs[0]
        dense_features = inputs[1]

        # HPS lookup with `combiner` passed through, then regrouped so every
        # sample gets slot_num * embed_vec_size features.
        pooled = self.sparse_lookup_layer(sp_ids=sparse_ids, sp_weights=None, combiner=self.combiner)
        flat_embeddings = tf.reshape(pooled, shape=[-1, self.slot_num * self.embed_vec_size])
        features = tf.concat([flat_embeddings, dense_features], axis=1)
        probability = self.new_fc(features)  # sigmoid output, not a raw logit
        return probability, flat_embeddings

    def summary(self):
        """Print a Keras summary by tracing call() on symbolic inputs."""
        symbolic_inputs = [
            tf.keras.Input(shape=(self.max_nnz, ), sparse=True, dtype=args["tf_key_type"]),
            tf.keras.Input(shape=(self.dense_dim, ), dtype=tf.float32),
        ]
        traced = tf.keras.models.Model(inputs=symbolic_inputs, outputs=self.call(symbolic_inputs))
        return traced.summary()
def train_with_pretrained_embeddings(args):
    """Train a dense head on top of HPS-served pre-trained embeddings and return the model.

    HPS is initialised from ``args["ps_config_file"]`` inside the strategy scope,
    before the model that looks embeddings up through it is built. Only the
    dense layer is optimised (Adam, lr=0.1). The function runs one pass over a
    freshly generated synthetic dataset and prints the per-replica loss after
    each step.
    """
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        hps.Init(global_batch_size = args["global_batch_size"], ps_config_file = args["ps_config_file"])
        model = PreTrainedEmbedding(args["combiner"], args["embed_vec_size"], args["slot_num"], args["max_nnz"], args["dense_dim"])
        model.summary()
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)

    # Per-example loss averaged over the *global* batch, so that summing the
    # replicas' gradients gives the global-batch gradient.
    per_example_bce = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)

    def _flatten_slots(keys):
        # (batch, slot_num, max_nnz) -> (batch * slot_num, max_nnz): one row per slot.
        return tf.sparse.reshape(keys, [-1, keys.shape[-1]])

    def _step(features, targets):
        with tf.GradientTape() as tape:
            predictions, _ = model(features)
            step_loss = tf.nn.compute_average_loss(per_example_bce(targets, predictions),
                                                   global_batch_size=args["global_batch_size"])
        gradients = tape.gradient(step_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return predictions, step_loss

    def _dataset_fn(input_context):
        per_replica_batch = input_context.get_per_replica_batch_size(args["global_batch_size"])
        keys, dense, targets = generate_random_samples(args["global_batch_size"] * args["iter_num"], args["vocabulary_range_per_slot"], args["max_nnz"], args["dense_dim"])
        batched = tf_dataset(keys, dense, targets, per_replica_batch)
        return batched.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)

    distributed = strategy.distribute_datasets_from_function(_dataset_fn)
    for step, (sparse_keys, dense_features, labels) in enumerate(distributed):
        flat_keys = strategy.run(_flatten_slots, args=(sparse_keys,))
        _, loss = strategy.run(_step, args=([flat_keys, dense_features], labels))
        print("-"*20, f"Step {step}, loss: {loss}", "-"*20)
    return model
# Export the embedding table trained above into the directory listed under
# "sparse_files" in dnn.json, then train a new dense head on top of it via HPS.
convert_to_sparse_model(embeddings_weights=embedding_weights,
                        embedding_table_path=args["embedding_table_path"],
                        embedding_vec_size=args["embed_vec_size"])
model = train_with_pretrained_embeddings(args)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
=====================================================HPS Parse====================================================
You are using the plugin with MirroredStrategy.
[HCTR][06:42:16.707][INFO][RK0][main]: dense_file is not specified using default: 
[HCTR][06:42:16.707][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
[HCTR][06:42:16.707][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][06:42:16.707][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][06:42:16.707][INFO][RK0][main]: refresh_interval is not specified using default: 0
====================================================HPS Create====================================================
[HCTR][06:42:16.707][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][06:42:16.707][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][06:42:16.707][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][06:42:17.153][INFO][RK0][main]: Table: hps_et.dnn.sparse_embedding0; cached 100000 / 100000 embeddings in volatile database (PreallocatedHashMapBackend); load: 100000 / 18446744073709551615 (0.00%).
[HCTR][06:42:17.153][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][06:42:17.153][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][06:42:17.160][INFO][RK0][main]: Model name: dnn
[HCTR][06:42:17.160][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][06:42:17.160][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][06:42:17.160][INFO][RK0][main]: Use I64 input key: True
[HCTR][06:42:17.160][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][06:42:17.160][INFO][RK0][main]: The size of thread pool: 80
[HCTR][06:42:17.160][INFO][RK0][main]: The size of worker memory pool: 3
[HCTR][06:42:17.160][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][06:42:17.170][INFO][RK0][main]: Creating embedding cache in device 1.
[HCTR][06:42:17.177][INFO][RK0][main]: Model name: dnn
[HCTR][06:42:17.177][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][06:42:17.177][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][06:42:17.177][INFO][RK0][main]: Use I64 input key: True
[HCTR][06:42:17.177][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][06:42:17.177][INFO][RK0][main]: The size of thread pool: 80
[HCTR][06:42:17.177][INFO][RK0][main]: The size of worker memory pool: 3
[HCTR][06:42:17.177][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][06:42:17.180][INFO][RK0][main]: Creating embedding cache in device 2.
[HCTR][06:42:17.188][INFO][RK0][main]: Model name: dnn
[HCTR][06:42:17.188][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][06:42:17.188][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][06:42:17.188][INFO][RK0][main]: Use I64 input key: True
[HCTR][06:42:17.188][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][06:42:17.188][INFO][RK0][main]: The size of thread pool: 80
[HCTR][06:42:17.188][INFO][RK0][main]: The size of worker memory pool: 3
[HCTR][06:42:17.188][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][06:42:17.191][INFO][RK0][main]: Creating embedding cache in device 3.
[HCTR][06:42:17.197][INFO][RK0][main]: Model name: dnn
[HCTR][06:42:17.197][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][06:42:17.197][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][06:42:17.197][INFO][RK0][main]: Use I64 input key: True
[HCTR][06:42:17.197][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][06:42:17.197][INFO][RK0][main]: The size of thread pool: 80
[HCTR][06:42:17.197][INFO][RK0][main]: The size of worker memory pool: 3
[HCTR][06:42:17.197][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][06:42:17.300][INFO][RK0][main]: Creating lookup session for dnn on device: 0
[HCTR][06:42:17.300][INFO][RK0][main]: Creating lookup session for dnn on device: 1
[HCTR][06:42:17.300][INFO][RK0][main]: Creating lookup session for dnn on device: 2
[HCTR][06:42:17.300][INFO][RK0][main]: Creating lookup session for dnn on device: 3
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_3 (InputLayer)           [(None, 5)]          0           []                               
                                                                                                  
 sparse_lookup_layer (SparseLoo  (None, 16)          0           ['input_3[0][0]']                
 kupLayer)                                                                                        
                                                                                                  
 tf.reshape_1 (TFOpLambda)      (None, 160)          0           ['sparse_lookup_layer[0][0]']    
                                                                                                  
 input_4 (InputLayer)           [(None, 10)]         0           []                               
                                                                                                  
 tf.concat_1 (TFOpLambda)       (None, 170)          0           ['tf.reshape_1[0][0]',           
                                                                  'input_4[0][0]']                
                                                                                                  
 new_fc (Dense)                 (None, 1)            171         ['tf.concat_1[0][0]']            
                                                                                                  
==================================================================================================
Total params: 171
Trainable params: 171
Non-trainable params: 0
__________________________________________________________________________________________________
-------------------- Step 0, loss: PerReplica:{
  0: tf.Tensor(0.17934436, shape=(), dtype=float32),
  1: tf.Tensor(0.17969523, shape=(), dtype=float32),
  2: tf.Tensor(0.18917403, shape=(), dtype=float32),
  3: tf.Tensor(0.18102707, shape=(), dtype=float32)
} --------------------
-------------------- Step 1, loss: PerReplica:{
  0: tf.Tensor(1.7858478, shape=(), dtype=float32),
  1: tf.Tensor(1.68311, shape=(), dtype=float32),
  2: tf.Tensor(1.66279, shape=(), dtype=float32),
  3: tf.Tensor(1.5826445, shape=(), dtype=float32)
} --------------------
-------------------- Step 2, loss: PerReplica:{
  0: tf.Tensor(0.7325904, shape=(), dtype=float32),
  1: tf.Tensor(0.7331751, shape=(), dtype=float32),
  2: tf.Tensor(0.7210605, shape=(), dtype=float32),
  3: tf.Tensor(0.7671325, shape=(), dtype=float32)
} --------------------
-------------------- Step 3, loss: PerReplica:{
  0: tf.Tensor(0.62144834, shape=(), dtype=float32),
  1: tf.Tensor(0.5696643, shape=(), dtype=float32),
  2: tf.Tensor(0.5946336, shape=(), dtype=float32),
  3: tf.Tensor(0.64713424, shape=(), dtype=float32)
} --------------------
-------------------- Step 4, loss: PerReplica:{
  0: tf.Tensor(0.88115656, shape=(), dtype=float32),
  1: tf.Tensor(0.9079187, shape=(), dtype=float32),
  2: tf.Tensor(0.98161024, shape=(), dtype=float32),
  3: tf.Tensor(0.97925556, shape=(), dtype=float32)
} --------------------
-------------------- Step 5, loss: PerReplica:{
  0: tf.Tensor(0.6572284, shape=(), dtype=float32),
  1: tf.Tensor(0.6304919, shape=(), dtype=float32),
  2: tf.Tensor(0.66552734, shape=(), dtype=float32),
  3: tf.Tensor(0.6695935, shape=(), dtype=float32)
} --------------------
-------------------- Step 6, loss: PerReplica:{
  0: tf.Tensor(0.2002374, shape=(), dtype=float32),
  1: tf.Tensor(0.19162768, shape=(), dtype=float32),
  2: tf.Tensor(0.1874283, shape=(), dtype=float32),
  3: tf.Tensor(0.19209734, shape=(), dtype=float32)
} --------------------
-------------------- Step 7, loss: PerReplica:{
  0: tf.Tensor(0.5284709, shape=(), dtype=float32),
  1: tf.Tensor(0.6028371, shape=(), dtype=float32),
  2: tf.Tensor(0.5635803, shape=(), dtype=float32),
  3: tf.Tensor(0.5773235, shape=(), dtype=float32)
} --------------------
-------------------- Step 8, loss: PerReplica:{
  0: tf.Tensor(0.74001855, shape=(), dtype=float32),
  1: tf.Tensor(0.71915305, shape=(), dtype=float32),
  2: tf.Tensor(0.619328, shape=(), dtype=float32),
  3: tf.Tensor(0.7890761, shape=(), dtype=float32)
} --------------------
-------------------- Step 9, loss: PerReplica:{
  0: tf.Tensor(0.55197906, shape=(), dtype=float32),
  1: tf.Tensor(0.5565746, shape=(), dtype=float32),
  2: tf.Tensor(0.52792, shape=(), dtype=float32),
  3: tf.Tensor(0.6230979, shape=(), dtype=float32)
} --------------------