
Hierarchical Parameter Server Demo

Overview

In HugeCTR version 3.5, we provide Python APIs for embedding table lookup with the HugeCTR Hierarchical Parameter Server (HPS). HPS supports different database backends and GPU embedding caches.

This notebook demonstrates how to use HPS with the HugeCTR Python APIs. Without loss of generality, the HPS APIs are used together with the ONNX Runtime APIs to create an ensemble inference model, where HPS is responsible for embedding table lookup while the ONNX model handles the forward pass of the dense neural network.

Installation

Get HugeCTR from NGC

The HugeCTR Python module is preinstalled in the Merlin Training Container version 22.07 and later: nvcr.io/nvidia/merlin/merlin-hugectr:22.07.

You can check that the required library is available by running the following command after launching the container:

$ python3 -c "import hugectr"
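
The ONNX conversion and inference steps later in this notebook also use the hugectr2onnx and onnxruntime modules, which you can check the same way (assuming they are preinstalled in the same container):

$ python3 -c "import hugectr2onnx, onnxruntime"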

Note: This Python module contains both training APIs and offline inference APIs. For online inference with Triton, please refer to HugeCTR Backend.

If you prefer to build HugeCTR from the source code instead of using the NGC container, please refer to the How to Start Your Development documentation.

Data Generation

HugeCTR provides a tool to generate synthetic datasets. The Data Generator is capable of generating datasets with different file formats and distributions. We will generate one-hot Parquet datasets with a power-law distribution for this notebook:

import hugectr
from hugectr.tools import DataGeneratorParams, DataGenerator

data_generator_params = DataGeneratorParams(
  format = hugectr.DataReaderType_t.Parquet,
  label_dim = 1,
  dense_dim = 10,
  num_slot = 4,
  i64_input_key = True,
  nnz_array = [1, 1, 1, 1],
  source = "./data_parquet/file_list.txt",
  eval_source = "./data_parquet/file_list_test.txt",
  slot_size_array = [10000, 10000, 10000, 10000],
  check_type = hugectr.Check_t.Non,
  dist_type = hugectr.Distribution_t.PowerLaw,
  power_law_type = hugectr.PowerLaw_t.Short,
  num_files = 16,
  eval_num_files = 4,
  num_samples_per_file = 40960)
data_generator = DataGenerator(data_generator_params)
data_generator.generate()
[HCTR][11:15:15][INFO][RK0][main]: Generate Parquet dataset
[HCTR][11:15:15][INFO][RK0][main]: train data folder: ./data_parquet, eval data folder: ./data_parquet, slot_size_array: 10000, 10000, 10000, 10000, nnz array: 1, 1, 1, 1, #files for train: 16, #files for eval: 4, #samples per file: 40960, Use power law distribution: 1, alpha of power law: 1.3
[HCTR][11:15:15][INFO][RK0][main]: ./data_parquet exist
[HCTR][11:15:15][INFO][RK0][main]: ./data_parquet exist
[HCTR][11:15:15][INFO][RK0][main]: ./data_parquet/train exist
[HCTR][11:15:15][INFO][RK0][main]: ./data_parquet/train/gen_0.parquet
[HCTR][11:15:17][INFO][RK0][main]: ./data_parquet/train/gen_1.parquet
[HCTR][11:15:17][INFO][RK0][main]: ./data_parquet/train/gen_2.parquet
[HCTR][11:15:17][INFO][RK0][main]: ./data_parquet/train/gen_3.parquet
[HCTR][11:15:17][INFO][RK0][main]: ./data_parquet/train/gen_4.parquet
[HCTR][11:15:18][INFO][RK0][main]: ./data_parquet/train/gen_5.parquet
[HCTR][11:15:18][INFO][RK0][main]: ./data_parquet/train/gen_6.parquet
[HCTR][11:15:18][INFO][RK0][main]: ./data_parquet/train/gen_7.parquet
[HCTR][11:15:18][INFO][RK0][main]: ./data_parquet/train/gen_8.parquet
[HCTR][11:15:18][INFO][RK0][main]: ./data_parquet/train/gen_9.parquet
[HCTR][11:15:19][INFO][RK0][main]: ./data_parquet/train/gen_10.parquet
[HCTR][11:15:19][INFO][RK0][main]: ./data_parquet/train/gen_11.parquet
[HCTR][11:15:19][INFO][RK0][main]: ./data_parquet/train/gen_12.parquet
[HCTR][11:15:19][INFO][RK0][main]: ./data_parquet/train/gen_13.parquet
[HCTR][11:15:19][INFO][RK0][main]: ./data_parquet/train/gen_14.parquet
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/train/gen_15.parquet
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/file_list.txt done!
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/val exist
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/val/gen_0.parquet
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/val/gen_1.parquet
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/val/gen_2.parquet
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/val/gen_3.parquet
[HCTR][11:15:21][INFO][RK0][main]: ./data_parquet/file_list_test.txt done!
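
Before training, you can sanity-check one of the generated files; a minimal sketch, assuming pandas is available and recalling that each sample has 1 label, 10 dense features, and 4 categorical slots:

import pandas as pd

df = pd.read_parquet("./data_parquet/train/gen_0.parquet")
print(df.shape)             # expect (40960, 15): 1 label + 10 dense + 4 categorical columns
print(df.columns.tolist())  # column names assigned by the data generator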

Train from Scratch

We can train from scratch by performing the following steps with the Python APIs:

  1. Create the solver, reader and optimizer, then initialize the model.

  2. Construct the model graph by adding input, sparse embedding and dense layers in order.

  3. Compile the model and have an overview of the model graph.

  4. Dump the model graph to the JSON file.

  5. Fit the model; the model weights and optimizer states are saved implicitly at the snapshot interval.

  6. Dump one batch of evaluation results to files.

%%writefile train.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(model_name = "hps_demo",
                              max_eval_batches = 1,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              lr = 0.001,
                              vvgpu = [[0]],
                              i64_input_key = True,
                              repeat_dataset = True,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Parquet,
                                  source = ["./data_parquet/file_list.txt"],
                                  eval_source = "./data_parquet/file_list_test.txt",
                                  check_type = hugectr.Check_t.Non,
                                  slot_size_array = [10000, 10000, 10000, 10000])
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
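# DataReaderSparseParam(name, nnz_per_slot, is_fixed_length, slot_num):
# each sparse input below covers 2 one-hot slots, i.e., one key per slot.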
model.add(hugectr.Input(label_dim = 1, label_name = "label",
                        dense_dim = 10, dense_name = "dense",
                        data_reader_sparse_param_array = 
                        [hugectr.DataReaderSparseParam("data1", [1, 1], True, 2),
                        hugectr.DataReaderSparseParam("data2", [1, 1], True, 2)]))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 4,
                            embedding_vec_size = 16,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding1",
                            bottom_name = "data1",
                            optimizer = optimizer))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 8,
                            embedding_vec_size = 32,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding2",
                            bottom_name = "data2",
                            optimizer = optimizer))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding1"],
                            top_names = ["reshape1"],
                            leading_dim=32))                            
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding2"],
                            top_names = ["reshape2"],
                            leading_dim=64))                            
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                            bottom_names = ["reshape1", "reshape2", "dense"], top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["concat1"],
                            top_names = ["fc1"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc1"],
                            top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["relu1"],
                            top_names = ["fc2"],
                            num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
                            bottom_names = ["fc2", "label"],
                            top_names = ["loss"]))
model.compile()
model.summary()
model.graph_to_json("hps_demo.json")
model.fit(max_iter = 1100, display = 200, eval_interval = 1000, snapshot = 1000, snapshot_prefix = "hps_demo")
model.export_predictions("hps_demo_pred_" + str(1000), "hps_demo_label_" + str(1000))
Overwriting train.py
!python3 train.py
HugeCTR Version: 3.4
====================================================Model Init=====================================================
[HCTR][11:15:26][INFO][RK0][main]: Initialize model: hps_demo
[HCTR][11:15:26][INFO][RK0][main]: Global seed is 156170895
[HCTR][11:15:26][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][11:15:27][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][11:15:27][INFO][RK0][main]: Start all2all warmup
[HCTR][11:15:27][INFO][RK0][main]: End all2all warmup
[HCTR][11:15:27][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][11:15:27][INFO][RK0][main]: Device 0: Tesla V100-SXM2-32GB
[HCTR][11:15:27][INFO][RK0][main]: num of DataReader workers: 1
[HCTR][11:15:27][INFO][RK0][main]: Vocabulary size: 40000
[HCTR][11:15:27][INFO][RK0][main]: max_vocabulary_size_per_gpu_=21845
[HCTR][11:15:27][INFO][RK0][main]: max_vocabulary_size_per_gpu_=21845
[HCTR][11:15:27][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][11:15:29][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][11:15:29][INFO][RK0][main]: gpu0 init embedding done
[HCTR][11:15:29][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][11:15:29][INFO][RK0][main]: gpu0 init embedding done
[HCTR][11:15:29][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][11:15:29][INFO][RK0][main]: Warm-up done
===================================================Model Summary===================================================
[HCTR][11:15:29][INFO][RK0][main]: label                                   Dense                         Sparse                        
label                                   dense                          data1,data2                   
(None, 1)                               (None, 10)                              
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type                              Input Name                    Output Name                   Output Shape                  
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash      data1                         sparse_embedding1             (None, 2, 16)                 
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash      data2                         sparse_embedding2             (None, 2, 32)                 
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding1             reshape1                      (None, 32)                    
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding2             reshape2                      (None, 64)                    
------------------------------------------------------------------------------------------------------------------
Concat                                  reshape1                      concat1                       (None, 106)                   
                                        reshape2                                                                                  
                                        dense                                                                                     
------------------------------------------------------------------------------------------------------------------
InnerProduct                            concat1                       fc1                           (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc1                           relu1                         (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
InnerProduct                            relu1                         fc2                           (None, 1)                     
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss                  fc2                           loss                                                        
                                        label                                                                                     
------------------------------------------------------------------------------------------------------------------
[HCTR][11:15:29][INFO][RK0][main]: Save the model graph to hps_demo.json successfully
=====================================================Model Fit=====================================================
[HCTR][11:15:29][INFO][RK0][main]: Use non-epoch mode with number of iterations: 1100
[HCTR][11:15:29][INFO][RK0][main]: Training batchsize: 1024, evaluation batchsize: 1024
[HCTR][11:15:29][INFO][RK0][main]: Evaluation interval: 1000, snapshot interval: 1000
[HCTR][11:15:29][INFO][RK0][main]: Dense network trainable: True
[HCTR][11:15:29][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][11:15:29][INFO][RK0][main]: Sparse embedding sparse_embedding2 trainable: True
[HCTR][11:15:29][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][11:15:29][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][11:15:29][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][11:15:29][INFO][RK0][main]: Training source file: ./data_parquet/file_list.txt
[HCTR][11:15:29][INFO][RK0][main]: Evaluation source file: ./data_parquet/file_list_test.txt
[HCTR][11:15:29][INFO][RK0][main]: Iter: 200 Time(200 iters): 0.211451s Loss: 0.694128 lr:0.001
[HCTR][11:15:29][INFO][RK0][main]: Iter: 400 Time(200 iters): 0.267199s Loss: 0.689953 lr:0.001
[HCTR][11:15:29][INFO][RK0][main]: Iter: 600 Time(200 iters): 0.216242s Loss: 0.689657 lr:0.001
[HCTR][11:15:29][INFO][RK0][main]: Iter: 800 Time(200 iters): 0.215779s Loss: 0.677149 lr:0.001
[HCTR][11:15:30][INFO][RK0][main]: Iter: 1000 Time(200 iters): 0.219875s Loss: 0.681208 lr:0.001
[HCTR][11:15:30][INFO][RK0][main]: Evaluation, AUC: 0.49589
[HCTR][11:15:30][INFO][RK0][main]: Eval Time for 1 iters: 0.000359s
[HCTR][11:15:30][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][11:15:30][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][11:15:30][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][11:15:30][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][11:15:30][INFO][RK0][main]: Done
[HCTR][11:15:30][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][11:15:30][INFO][RK0][main]: Done
[HCTR][11:15:30][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][11:15:30][INFO][RK0][main]: Done
[HCTR][11:15:30][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][11:15:30][INFO][RK0][main]: Done
[HCTR][11:15:30][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HCTR][11:15:30][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][11:15:30][INFO][RK0][main]: Dumping dense optimizer states to file, successful
[HCTR][11:15:30][INFO][RK0][main]: Finish 1100 iterations with batchsize: 1024 in 1.53s.
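
The snapshot at iteration 1000 writes the dense model file and one sparse model file per embedding table; the following sections rely on these artifacts. A quick check, assuming the default snapshot naming scheme:

import glob

# expect hps_demo_dense_1000.model plus hps_demo0_sparse_1000.model and hps_demo1_sparse_1000.model
print(sorted(glob.glob("hps_demo*_1000.model")))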

Convert HugeCTR to ONNX

We will convert the saved HugeCTR models to ONNX using the HugeCTR to ONNX Converter. For more information about the converter, refer to the README in the onnx_converter directory of the repository.

To double-check correctness, we will perform the conversion both with and without the sparse embedding models.

import hugectr2onnx
hugectr2onnx.converter.convert(onnx_model_path = "hps_demo_with_embedding.onnx",
                            graph_config = "hps_demo.json",
                            dense_model = "hps_demo_dense_1000.model",
                            convert_embedding = True,
                            sparse_models = ["hps_demo0_sparse_1000.model", "hps_demo1_sparse_1000.model"])

hugectr2onnx.converter.convert(onnx_model_path = "hps_demo_without_embedding.onnx",
                            graph_config = "hps_demo.json",
                            dense_model = "hps_demo_dense_1000.model",
                            convert_embedding = False)
The model is checked!
The model is saved at hps_demo_with_embedding.onnx
Skip sparse embedding layers in converted ONNX model
Skip sparse embedding layers in converted ONNX model
The model is checked!
The model is saved at hps_demo_without_embedding.onnx
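
To see how the two variants differ, you can inspect their input signatures with ONNX Runtime; a minimal sketch (the exact input names and shapes depend on the converter version):

import onnxruntime as ort

for path in ["hps_demo_with_embedding.onnx", "hps_demo_without_embedding.onnx"]:
    sess = ort.InferenceSession(path)
    print(path, [(i.name, i.shape) for i in sess.get_inputs()])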

Inference with HPS & ONNX

We will run inference by performing the following steps with the Python APIs:

  1. Configure the HPS hyperparameters.

  2. Initialize the HPS object, which is responsible for embedding table lookup.

  3. Load the Parquet data.

  4. Make inference with the HPS object and the ONNX inference session of hps_demo_without_embedding.onnx.

  5. Check the correctness by comparing with dumped evaluation results.

  6. Make inference with the ONNX inference session of hps_demo_with_embedding.onnx (double check).

from hugectr.inference import HPS, ParameterServerConfig, InferenceParams

import pandas as pd
import numpy as np

import onnxruntime as ort

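# HugeCTR maps all slots of a model into a single global key space: the raw keys
# of slot i are offset by the sum of the preceding slot sizes, hence key_offset.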
slot_size_array = [10000, 10000, 10000, 10000]
key_offset = np.insert(np.cumsum(slot_size_array), 0, 0)[:-1]
batch_size = 1024

# 1. Configure the HPS hyperparameters
ps_config = ParameterServerConfig(
           emb_table_name = {"hps_demo": ["sparse_embedding1", "sparse_embedding2"]},
           embedding_vec_size = {"hps_demo": [16, 32]},
           max_feature_num_per_sample_per_emb_table = {"hps_demo": [2, 2]},
           inference_params_array = [
              InferenceParams(
                model_name = "hps_demo",
                max_batchsize = batch_size,
                hit_rate_threshold = 1.0,
                dense_model_file = "",
                sparse_model_files = ["hps_demo0_sparse_1000.model", "hps_demo1_sparse_1000.model"],
                deployed_devices = [0],
                use_gpu_embedding_cache = True,
                cache_size_percentage = 0.5,
                i64_input_key = True)
           ])

# 2. Initialize the HPS object
hps = HPS(ps_config)

# 3. Load the Parquet data.
df = pd.read_parquet("data_parquet/val/gen_0.parquet")
dense_input_columns = df.columns[1:11]
cat_input1_columns = df.columns[11:13]
cat_input2_columns = df.columns[13:15]
dense_input = df[dense_input_columns].loc[0:batch_size-1].to_numpy(dtype=np.float32)
cat_input1 = (df[cat_input1_columns].loc[0:batch_size-1].to_numpy(dtype=np.int64) + key_offset[0:2]).reshape((batch_size, 2, 1))
cat_input2 = (df[cat_input2_columns].loc[0:batch_size-1].to_numpy(dtype=np.int64) + key_offset[2:4]).reshape((batch_size, 2, 1))
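# Shape (batch_size, 2, 1): 2 slots per table and 1 key per slot (one-hot),
# matching the sparse inputs expected by the converted ONNX model.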

# 4. Make inference with the HPS object and the ONNX inference session of `hps_demo_without_embedding.onnx`.
embedding1 = hps.lookup(cat_input1.flatten(), "hps_demo", 0).reshape(batch_size, 2, 16)
embedding2 = hps.lookup(cat_input2.flatten(), "hps_demo", 1).reshape(batch_size, 2, 32)
sess = ort.InferenceSession("hps_demo_without_embedding.onnx")
res = sess.run(output_names=[sess.get_outputs()[0].name],
               input_feed={sess.get_inputs()[0].name: dense_input,
               sess.get_inputs()[1].name: embedding1,
               sess.get_inputs()[2].name: embedding2})
pred = res[0]

# 5. Check the correctness by comparing with dumped evaluation results.
ground_truth = np.loadtxt("hps_demo_pred_1000")
print("ground_truth: ", ground_truth)
diff = pred.flatten()-ground_truth
mse = np.mean(diff*diff)
print("pred: ", pred)
print("mse between pred and ground_truth: ", mse)

# 6. Make inference with the ONNX inference session of `hps_demo_with_embedding.onnx` (double check).
sess_ref = ort.InferenceSession("hps_demo_with_embedding.onnx")
res_ref = sess_ref.run(output_names=[sess_ref.get_outputs()[0].name],
                   input_feed={sess_ref.get_inputs()[0].name: dense_input,
                   sess_ref.get_inputs()[1].name: cat_input1,
                   sess_ref.get_inputs()[2].name: cat_input2})
pred_ref = res_ref[0]
diff_ref = pred_ref.flatten()-ground_truth
mse_ref = np.mean(diff_ref*diff_ref)
print("pred_ref: ", pred_ref)
print("mse between pred_ref and ground_truth: ", mse_ref)
[HCTR][11:17:13][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
[HCTR][11:17:13][INFO][RK0][main]: Creating ParallelHashMap CPU database backend...
[HCTR][11:17:13][INFO][RK0][main]: Created parallel (16 partitions) blank database backend in local memory!
[HCTR][11:17:13][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][11:17:13][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][11:17:13][INFO][RK0][main]: Table: hps_et.hps_demo.sparse_embedding1; cached 15749 / 15749 embeddings in volatile database (ParallelHashMap); load: 15749 / 18446744073709551615 (0.00%).
[HCTR][11:17:13][INFO][RK0][main]: Table: hps_et.hps_demo.sparse_embedding2; cached 15781 / 15781 embeddings in volatile database (ParallelHashMap); load: 15781 / 18446744073709551615 (0.00%).
[HCTR][11:17:13][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][11:17:13][INFO][RK0][main]: Create embedding cache in device 0.
[HCTR][11:17:13][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][11:17:13][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][11:17:13][INFO][RK0][main]: Create inference session on device: 0
[HCTR][11:17:13][INFO][RK0][main]: Model name: hps_demo
[HCTR][11:17:13][INFO][RK0][main]: Number of embedding tables: 2
[HCTR][11:17:13][INFO][RK0][main]: Use I64 input key: True
ground_truth:  [0.456111 0.417843 0.428037 ... 0.336745 0.53599  0.508711]
pred:  [[0.45611122]
 [0.4178428 ]
 [0.42803708]
 ...
 [0.3367453 ]
 [0.53599   ]
 [0.5087108 ]]
mse between pred and ground_truth:  8.241691052249094e-14
pred_ref:  [[0.45611122]
 [0.4178428 ]
 [0.42803708]
 ...
 [0.3367453 ]
 [0.53599   ]
 [0.5087108 ]]
mse between pred_ref and ground_truth:  7.573986338301264e-05
2022-03-31 11:17:13.779336470 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'key_to_indice_hash_all_tables'. It is not used by any node and should be removed from the model.

Lookup the Embedding Vector from DLPack

We also provide a lookup_fromdlpack interface that accepts embedding keys on the CPU and returns the embedding vectors on the GPU or CPU.

  1. Suppose you have created a PyTorch/TensorFlow tensor that stores the embedding keys.

  2. Convert the embedding key tensor to a DLPack capsule through the corresponding platform's to_dlpack function.

  3. Create an empty tensor as a buffer to store the embedding vectors.

  4. Convert the buffer tensor to a DLPack capsule.

  5. Look up the embedding vectors for the keys directly through the lookup_fromdlpack interface, which writes them into the buffer tensor.

Note: Please make sure that TensorFlow or PyTorch has been installed correctly in the container.

embedding1 = hps.lookup(cat_input1.flatten(), "hps_demo", 0).reshape(batch_size, 2, 16)
embedding2 = hps.lookup(cat_input2.flatten(), "hps_demo", 1).reshape(batch_size, 2, 32)

# 1. Look up from dlpack for PyTorch tensor on CPU
print("Look up from dlpack for PyTorch tensor")
import torch
import torch.utils.dlpack
print("************Look up from pytorch dlpack on CPU")
device = torch.device("cpu")
key = torch.tensor(cat_input1.flatten(), dtype=torch.int64, device=device)
out = torch.empty((1, cat_input1.flatten().shape[0]*16), dtype=torch.float32, device=device)
key_capsule = torch.utils.dlpack.to_dlpack(key)
print("The device type of embedding keys that lookup dlpack from hps interface for embedding table 0 of hps_demo: {}, the keys: {}".format(key.device, key))
out_capsule = torch.utils.dlpack.to_dlpack(out)
# Look up the embedding vectors through the DLPack capsules
hps.lookup_fromdlpack(key_capsule, out_capsule, "hps_demo", 0)
out_put = torch.utils.dlpack.from_dlpack(out_capsule)
print("The device type of embedding vectors that lookup dlpack from hps interface for embedding table 0 of hps_demo: {}, the vectors: {}\n".format(out_put.device, out_put))
diff = out_put - embedding1.reshape(1, cat_input1.flatten().shape[0]*16)
mse = (diff * diff).mean()
if mse > 1e-4:
    raise RuntimeError("Too large mse between pytorch dlpack on CPU and native HPS lookup api: {}".format(mse))
else:
    print("PyTorch dlpack on CPU results are consistent with native HPS lookup api, mse: {}".format(mse))

# 2. Look up from dlpack for PyTorch tensor on GPU
print("************Look up from pytorch dlpack on GPU")
cuda_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
out = torch.empty((1, cat_input1.flatten().shape[0]*16), dtype=torch.float32, device=cuda_device)
out_capsule = torch.utils.dlpack.to_dlpack(out)
hps.lookup_fromdlpack(key_capsule, out_capsule, "hps_demo", 0)
out_put = torch.utils.dlpack.from_dlpack(out_capsule)
print("The device type of embedding vectors that lookup dlpack from hps interface for embedding table 0 of hps_demo: {}, the vectors: {}\n\n".format(out_put.device, out_put))
diff = out_put.cpu() - embedding1.reshape(1, cat_input1.flatten().shape[0]*16)
mse = (diff * diff).mean()
if mse > 1e-3:
    raise RuntimeError("Too large mse between pytorch dlpack on GPU and native HPS lookup api: {}".format(mse))
else:
    print("PyTorch dlpack on GPU results are consistent with native HPS lookup api, mse: {}".format(mse))
Look up from dlpack for PyTorch tensor
************Look up from pytorch dlpack on CPU
The device type of embedding keys that lookup dlpack from hps interface for embedding table 0 of hps_demo: cpu, the keys: tensor([    0, 10000,     0,  ..., 10037,    57, 10057])
The device type of embedding vectors that lookup dlpack from hps interface for embedding table 0 of hps_demo: cpu, the vectors: tensor([[-0.0843,  0.0634,  0.0409,  ..., -0.0584,  0.0030, -0.0187]])

PyTorch dlpack on CPU results are consistent with native HPS lookup api, mse: 0.0
************Look up from pytorch dlpack on GPU
The device type of embedding vectors that lookup dlpack from hps interface for embedding table 0 of hps_demo: cuda:0, the vectors: tensor([[-0.0843,  0.0634,  0.0409,  ..., -0.0584,  0.0030, -0.0187]],
       device='cuda:0')


PyTorch dlpack on GPU results are consistent with native HPS lookup api, mse: 0.0
# 3. Look up from dlpack for TensorFlow tensor on CPU
print("Look up from dlpack for TensorFlow tensor")
import tensorflow as tf
print("***************Look up from tensorflow dlpack on CPU**********")
with tf.device('/CPU:0'):
    key_tensor = tf.constant(cat_input2.flatten(), dtype=tf.int64)
    out_tensor = tf.zeros([1, cat_input2.flatten().shape[0]*32], dtype=tf.float32)
    print("The device type of embedding keys that lookup dlpack from hps interface for embedding table 1 of hps_demo: {}, the keys: {}".format(key_tensor.device, key_tensor))
    key_capsule = tf.experimental.dlpack.to_dlpack(key_tensor)
    out_dlcapsule = tf.experimental.dlpack.to_dlpack(out_tensor)
hps.lookup_fromdlpack(key_capsule, out_dlcapsule, "hps_demo", 1)
out = tf.experimental.dlpack.from_dlpack(out_dlcapsule)
print("The device type of embedding vectors that lookup dlpack from hps interface for embedding table 1 of hps_demo: {}, the vectors: {}\n".format(out.device, out))
diff = out - embedding2.reshape(1, cat_input2.flatten().shape[0]*32)
mse = tf.reduce_mean(diff * diff)
if mse > 1e-3:
    raise RuntimeError("Too large mse between tensorflow dlpack on CPU and native HPS lookup api: {}".format(mse))
else:
    print("tensorflow dlpack on CPU results are consistent with native HPS lookup api, mse: {}".format(mse))

# 4. Look up from dlpack for TensorFlow tensor on GPU
print("***************Look up from tensorflow dlpack on GPU**********")
with tf.device('/GPU:0'):
    out_tensor = tf.zeros([1, cat_input2.flatten().shape[0]*32], dtype=tf.float32)
    key_capsule = tf.experimental.dlpack.to_dlpack(key_tensor)
    out_dlcapsule = tf.experimental.dlpack.to_dlpack(out_tensor)
hps.lookup_fromdlpack(key_capsule, out_dlcapsule, "hps_demo", 1)
out = tf.experimental.dlpack.from_dlpack(out_dlcapsule)
print("The device type of embedding vectors that lookup dlpack from hps interface for embedding table 1 of hps_demo: {}, the vectors: {}\n".format(out.device, out))
diff = out - embedding2.reshape(1, cat_input2.flatten().shape[0]*32)
mse = tf.reduce_mean(diff * diff)
if mse > 1e-3:
    raise RuntimeError("Too large mse between tensorflow dlpack on GPU and native HPS lookup api: {}".format(mse))
else:
    print("tensorflow dlpack on GPU results are consistent with native HPS lookup api, mse: {}".format(mse))
Look up from dlpack for TensorFlow tensor
***************Look up from tensorflow dlpack on CPU**********
The device type of embedding keys that lookup dlpack from hps interface for embedding table 1 of hps_demo: /job:localhost/replica:0/task:0/device:CPU:0, the keys: [20000 30000 20000 ... 30037 20057 30057]
The device type of embedding vectors that lookup dlpack from hps interface for embedding table 1 of hps_demo: /job:localhost/replica:0/task:0/device:CPU:0, the vectors: [[ 0.04648086  0.06154778 -0.04931969 ...  0.00693844  0.04137739
  -0.06696524]]

tensorflow dlpack on CPU results are consistent with native HPS lookup api, mse: 0.0
***************Look up from tensorflow dlpack on GPU**********
The device type of embedding vectors that lookup dlpack from hps interface for embedding table 1 of hps_demo: /job:localhost/replica:0/task:0/device:GPU:0, the vectors: [[ 0.04648086  0.06154778 -0.04931969 ...  0.00693844  0.04137739
  -0.06696524]]

tensorflow dlpack on GPU results are consistent with native HPS lookup api, mse: 0.0