
Hierarchical Parameter Server Demo

Overview

In HugeCTR version 3.5, we provide Python APIs for embedding table lookup with the HugeCTR Hierarchical Parameter Server (HPS). HPS supports different database backends and GPU embedding caches.

This notebook demonstrates how to use HPS with the HugeCTR Python APIs. Without loss of generality, the HPS APIs are used together with the ONNX Runtime APIs to create an ensemble inference model: HPS is responsible for embedding table lookup, while the ONNX model handles the feed-forward pass of the dense neural network.

Installation

Get HugeCTR from NGC

The HugeCTR Python module is preinstalled in the 22.10 and later Merlin Training Container: nvcr.io/nvidia/merlin/merlin-hugectr:22.10.

You can verify that the required library is available by running the following command after launching this container.

$ python3 -c "import hugectr"

Note: This Python module contains both training APIs and offline inference APIs. For online inference with Triton, please refer to HugeCTR Backend.

If you prefer to build HugeCTR from the source code instead of using the NGC container, please refer to the How to Start Your Development documentation.

Data Generation

HugeCTR provides a tool to generate synthetic datasets. The Data Generator is capable of generating datasets of different file formats and different distributions. We will generate one-hot Parquet datasets with power-law distribution for this notebook:

import hugectr
from hugectr.tools import DataGeneratorParams, DataGenerator

data_generator_params = DataGeneratorParams(
  format = hugectr.DataReaderType_t.Parquet,
  label_dim = 1,
  dense_dim = 10,
  num_slot = 4,
  i64_input_key = True,
  nnz_array = [1, 1, 1, 1],
  source = "./data_parquet/file_list.txt",
  eval_source = "./data_parquet/file_list_test.txt",
  slot_size_array = [10000, 10000, 10000, 10000],
  check_type = hugectr.Check_t.Non,
  dist_type = hugectr.Distribution_t.PowerLaw,
  power_law_type = hugectr.PowerLaw_t.Short,
  num_files = 16,
  eval_num_files = 4,
  num_samples_per_file = 40960)
data_generator = DataGenerator(data_generator_params)
data_generator.generate()
[HCTR][11:15:15][INFO][RK0][main]: Generate Parquet dataset
[HCTR][11:15:15][INFO][RK0][main]: train data folder: ./data_parquet, eval data folder: ./data_parquet, slot_size_array: 10000, 10000, 10000, 10000, nnz array: 1, 1, 1, 1, #files for train: 16, #files for eval: 4, #samples per file: 40960, Use power law distribution: 1, alpha of power law: 1.3
[HCTR][11:15:15][INFO][RK0][main]: ./data_parquet exist
[HCTR][11:15:15][INFO][RK0][main]: ./data_parquet exist
[HCTR][11:15:15][INFO][RK0][main]: ./data_parquet/train exist
[HCTR][11:15:15][INFO][RK0][main]: ./data_parquet/train/gen_0.parquet
[HCTR][11:15:17][INFO][RK0][main]: ./data_parquet/train/gen_1.parquet
[HCTR][11:15:17][INFO][RK0][main]: ./data_parquet/train/gen_2.parquet
[HCTR][11:15:17][INFO][RK0][main]: ./data_parquet/train/gen_3.parquet
[HCTR][11:15:17][INFO][RK0][main]: ./data_parquet/train/gen_4.parquet
[HCTR][11:15:18][INFO][RK0][main]: ./data_parquet/train/gen_5.parquet
[HCTR][11:15:18][INFO][RK0][main]: ./data_parquet/train/gen_6.parquet
[HCTR][11:15:18][INFO][RK0][main]: ./data_parquet/train/gen_7.parquet
[HCTR][11:15:18][INFO][RK0][main]: ./data_parquet/train/gen_8.parquet
[HCTR][11:15:18][INFO][RK0][main]: ./data_parquet/train/gen_9.parquet
[HCTR][11:15:19][INFO][RK0][main]: ./data_parquet/train/gen_10.parquet
[HCTR][11:15:19][INFO][RK0][main]: ./data_parquet/train/gen_11.parquet
[HCTR][11:15:19][INFO][RK0][main]: ./data_parquet/train/gen_12.parquet
[HCTR][11:15:19][INFO][RK0][main]: ./data_parquet/train/gen_13.parquet
[HCTR][11:15:19][INFO][RK0][main]: ./data_parquet/train/gen_14.parquet
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/train/gen_15.parquet
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/file_list.txt done!
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/val exist
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/val/gen_0.parquet
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/val/gen_1.parquet
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/val/gen_2.parquet
[HCTR][11:15:20][INFO][RK0][main]: ./data_parquet/val/gen_3.parquet
[HCTR][11:15:21][INFO][RK0][main]: ./data_parquet/file_list_test.txt done!
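The generator log above reports a power-law distribution with alpha = 1.3. The sketch below is not HugeCTR's generator. It assumes a Zipf-like law in which ID k is drawn with probability proportional to (k + 1)^(-alpha), purely to illustrate what such skew looks like: a small fraction of IDs accounts for most lookups. This is the kind of access pattern that a partial GPU embedding cache (for example, `cache_size_percentage = 0.5` later in this notebook) is meant to exploit. The function name `sample_power_law_keys` is hypothetical.

```python
import numpy as np

def sample_power_law_keys(num_samples, slot_size, alpha=1.3, seed=0):
    """Hypothetical helper: draw IDs in [0, slot_size) with a Zipf-like skew.

    Assumption: P(ID = k) is proportional to (k + 1) ** (-alpha). This only
    illustrates a power-law distribution; it is not HugeCTR's generator code.
    """
    rng = np.random.default_rng(seed)
    weights = np.arange(1, slot_size + 1, dtype=np.float64) ** (-alpha)
    probs = weights / weights.sum()
    return rng.choice(slot_size, size=num_samples, p=probs)

# 16 training files x 40960 samples for one slot of 10000 IDs (values from the log above)
keys = sample_power_law_keys(16 * 40960, 10000)
unique_ids = np.unique(keys)
print(f"distinct IDs seen: {unique_ids.size} of 10000")

# Share of all lookups that hit the 5000 most frequent IDs
counts = np.bincount(keys, minlength=10000)
top_half = np.sort(counts)[::-1][:5000].sum() / keys.size
print(f"share of lookups served by the top 50% of IDs: {top_half:.3f}")
```

Under this assumed law, some IDs are never drawn at all, and the most frequent half of the IDs serves nearly all lookups.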

Train from Scratch

We can train from scratch by performing the following steps with Python APIs:

  1. Create the solver, reader and optimizer, then initialize the model.

  2. Construct the model graph by adding input, sparse embedding and dense layers in order.

  3. Compile the model and print an overview of the model graph.

  4. Dump the model graph to the JSON file.

  5. Fit the model; the model weights and optimizer states are saved implicitly as snapshots.

  6. Dump one batch of evaluation results to files.
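With `snapshot = 1000` and `snapshot_prefix = "hps_demo"`, the snapshot files consumed later by the ONNX converter and by HPS are named `hps_demo_dense_1000.model`, `hps_demo0_sparse_1000.model` and `hps_demo1_sparse_1000.model`. The hypothetical helper below derives these names. It assumes the pattern seen in this notebook holds for any prefix and iteration, which is inferred from these file names rather than taken from the HugeCTR documentation. Optimizer-state files are also written, but their names are not shown here, so they are left out.

```python
def snapshot_files(prefix, iteration, num_sparse_tables):
    """Hypothetical helper: snapshot file names, following the pattern used in this notebook."""
    # One sparse model file per embedding table, numbered from 0
    sparse = [f"{prefix}{i}_sparse_{iteration}.model" for i in range(num_sparse_tables)]
    # A single file for the dense weights
    dense = f"{prefix}_dense_{iteration}.model"
    return dense, sparse

dense, sparse = snapshot_files("hps_demo", 1000, 2)
print(dense)   # hps_demo_dense_1000.model
print(sparse)  # ['hps_demo0_sparse_1000.model', 'hps_demo1_sparse_1000.model']
```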

%%writefile train.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(model_name = "hps_demo",
                              max_eval_batches = 1,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              lr = 0.001,
                              vvgpu = [[0]],
                              i64_input_key = True,
                              repeat_dataset = True,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Parquet,
                                  source = ["./data_parquet/file_list.txt"],
                                  eval_source = "./data_parquet/file_list_test.txt",
                                  check_type = hugectr.Check_t.Non,
                                  slot_size_array = [10000, 10000, 10000, 10000])
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.add(hugectr.Input(label_dim = 1, label_name = "label",
                        dense_dim = 10, dense_name = "dense",
                        data_reader_sparse_param_array = 
                        [hugectr.DataReaderSparseParam("data1", [1, 1], True, 2),
                        hugectr.DataReaderSparseParam("data2", [1, 1], True, 2)]))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 4,
                            embedding_vec_size = 16,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding1",
                            bottom_name = "data1",
                            optimizer = optimizer))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 8,
                            embedding_vec_size = 32,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding2",
                            bottom_name = "data2",
                            optimizer = optimizer))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding1"],
                            top_names = ["reshape1"],
                            leading_dim=32))                            
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding2"],
                            top_names = ["reshape2"],
                            leading_dim=64))                            
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                            bottom_names = ["reshape1", "reshape2", "dense"], top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["concat1"],
                            top_names = ["fc1"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc1"],
                            top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["relu1"],
                            top_names = ["fc2"],
                            num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
                            bottom_names = ["fc2", "label"],
                            top_names = ["loss"]))
model.compile()
model.summary()
model.graph_to_json("hps_demo.json")
model.fit(max_iter = 1100, display = 200, eval_interval = 1000, snapshot = 1000, snapshot_prefix = "hps_demo")
model.export_predictions("hps_demo_pred_" + str(1000), "hps_demo_label_" + str(1000))
Overwriting train.py
!python3 train.py
HugeCTR Version: 3.4
====================================================Model Init=====================================================
[HCTR][11:15:26][INFO][RK0][main]: Initialize model: hps_demo
[HCTR][11:15:26][INFO][RK0][main]: Global seed is 156170895
[HCTR][11:15:26][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][11:15:27][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][11:15:27][INFO][RK0][main]: Start all2all warmup
[HCTR][11:15:27][INFO][RK0][main]: End all2all warmup
[HCTR][11:15:27][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][11:15:27][INFO][RK0][main]: Device 0: Tesla V100-SXM2-32GB
[HCTR][11:15:27][INFO][RK0][main]: num of DataReader workers: 1
[HCTR][11:15:27][INFO][RK0][main]: Vocabulary size: 40000
[HCTR][11:15:27][INFO][RK0][main]: max_vocabulary_size_per_gpu_=21845
[HCTR][11:15:27][INFO][RK0][main]: max_vocabulary_size_per_gpu_=21845
[HCTR][11:15:27][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][11:15:29][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][11:15:29][INFO][RK0][main]: gpu0 init embedding done
[HCTR][11:15:29][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][11:15:29][INFO][RK0][main]: gpu0 init embedding done
[HCTR][11:15:29][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][11:15:29][INFO][RK0][main]: Warm-up done
===================================================Model Summary===================================================
[HCTR][11:15:29][INFO][RK0][main]: label                                   Dense                         Sparse                        
label                                   dense                          data1,data2                   
(None, 1)                               (None, 10)                              
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type                              Input Name                    Output Name                   Output Shape                  
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash      data1                         sparse_embedding1             (None, 2, 16)                 
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash      data2                         sparse_embedding2             (None, 2, 32)                 
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding1             reshape1                      (None, 32)                    
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding2             reshape2                      (None, 64)                    
------------------------------------------------------------------------------------------------------------------
Concat                                  reshape1                      concat1                       (None, 106)                   
                                        reshape2                                                                                  
                                        dense                                                                                     
------------------------------------------------------------------------------------------------------------------
InnerProduct                            concat1                       fc1                           (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc1                           relu1                         (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
InnerProduct                            relu1                         fc2                           (None, 1)                     
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss                  fc2                           loss                                                        
                                        label                                                                                     
------------------------------------------------------------------------------------------------------------------
[HCTR][11:15:29][INFO][RK0][main]: Save the model graph to hps_demo.json successfully
=====================================================Model Fit=====================================================
[HCTR][11:15:29][INFO][RK0][main]: Use non-epoch mode with number of iterations: 1100
[HCTR][11:15:29][INFO][RK0][main]: Training batchsize: 1024, evaluation batchsize: 1024
[HCTR][11:15:29][INFO][RK0][main]: Evaluation interval: 1000, snapshot interval: 1000
[HCTR][11:15:29][INFO][RK0][main]: Dense network trainable: True
[HCTR][11:15:29][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][11:15:29][INFO][RK0][main]: Sparse embedding sparse_embedding2 trainable: True
[HCTR][11:15:29][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][11:15:29][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][11:15:29][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][11:15:29][INFO][RK0][main]: Training source file: ./data_parquet/file_list.txt
[HCTR][11:15:29][INFO][RK0][main]: Evaluation source file: ./data_parquet/file_list_test.txt
[HCTR][11:15:29][INFO][RK0][main]: Iter: 200 Time(200 iters): 0.211451s Loss: 0.694128 lr:0.001
[HCTR][11:15:29][INFO][RK0][main]: Iter: 400 Time(200 iters): 0.267199s Loss: 0.689953 lr:0.001
[HCTR][11:15:29][INFO][RK0][main]: Iter: 600 Time(200 iters): 0.216242s Loss: 0.689657 lr:0.001
[HCTR][11:15:29][INFO][RK0][main]: Iter: 800 Time(200 iters): 0.215779s Loss: 0.677149 lr:0.001
[HCTR][11:15:30][INFO][RK0][main]: Iter: 1000 Time(200 iters): 0.219875s Loss: 0.681208 lr:0.001
[HCTR][11:15:30][INFO][RK0][main]: Evaluation, AUC: 0.49589
[HCTR][11:15:30][INFO][RK0][main]: Eval Time for 1 iters: 0.000359s
[HCTR][11:15:30][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][11:15:30][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][11:15:30][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][11:15:30][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][11:15:30][INFO][RK0][main]: Done
[HCTR][11:15:30][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][11:15:30][INFO][RK0][main]: Done
[HCTR][11:15:30][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][11:15:30][INFO][RK0][main]: Done
[HCTR][11:15:30][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][11:15:30][INFO][RK0][main]: Done
[HCTR][11:15:30][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HCTR][11:15:30][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][11:15:30][INFO][RK0][main]: Dumping dense optimizer states to file, successful
[HCTR][11:15:30][INFO][RK0][main]: Finish 1100 iterations with batchsize: 1024 in 1.53s.
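The shapes in the model summary follow from simple arithmetic. Each `DataReaderSparseParam` feeds 2 one-hot slots, so `Reshape` turns `(None, 2, 16)` into `(None, 32)` and `(None, 2, 32)` into `(None, 64)`, and `Concat` with the 10 dense features gives 106. The log also reports `max_vocabulary_size_per_gpu_=21845` for both tables. The sizing rule in `max_vocab_per_gpu` (a hypothetical name) is an assumption that happens to reproduce this value: it treats the workspace as holding the fp32 embedding rows plus Adam's two moment buffers, that is, 3 copies of each row.

```python
# Shapes from the model summary: each table receives 2 slots with nnz = 1
slots_per_table = 2
vec_sizes = {"sparse_embedding1": 16, "sparse_embedding2": 32}
dense_dim = 10

# Reshape flattens (None, slots, vec_size) into (None, slots * vec_size)
reshape_dims = {name: slots_per_table * v for name, v in vec_sizes.items()}
concat_dim = sum(reshape_dims.values()) + dense_dim
print(reshape_dims, concat_dim)  # {'sparse_embedding1': 32, 'sparse_embedding2': 64} 106

def max_vocab_per_gpu(workspace_mb, vec_size, copies=3, bytes_per_value=4):
    """Assumed rule: workspace bytes / (bytes per row * copies kept per row)."""
    return (workspace_mb * 1024 * 1024) // (vec_size * bytes_per_value * copies)

# workspace_size_per_gpu_in_mb and embedding_vec_size from train.py
print(max_vocab_per_gpu(4, 16), max_vocab_per_gpu(8, 32))  # 21845 21845
```

Doubling both the workspace and the vector size of the second table leaves the per-GPU vocabulary unchanged, which matches the two identical log lines.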

Convert HugeCTR to ONNX

We will convert the saved HugeCTR models to ONNX using the HugeCTR to ONNX Converter. For more information about the converter, refer to the README in the onnx_converter directory of the repository.

To double-check correctness, we convert the model in two ways: with the sparse embedding models included in the ONNX graph, and without them, in which case HPS performs the embedding lookup.

import hugectr2onnx
hugectr2onnx.converter.convert(onnx_model_path = "hps_demo_with_embedding.onnx",
                            graph_config = "hps_demo.json",
                            dense_model = "hps_demo_dense_1000.model",
                            convert_embedding = True,
                            sparse_models = ["hps_demo0_sparse_1000.model", "hps_demo1_sparse_1000.model"])

hugectr2onnx.converter.convert(onnx_model_path = "hps_demo_without_embedding.onnx",
                            graph_config = "hps_demo.json",
                            dense_model = "hps_demo_dense_1000.model",
                            convert_embedding = False)
The model is checked!
The model is saved at hps_demo_with_embedding.onnx
Skip sparse embedding layers in converted ONNX model
Skip sparse embedding layers in converted ONNX model
The model is checked!
The model is saved at hps_demo_without_embedding.onnx
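The inference code below feeds the ONNX sessions by input position. It passes dense features plus two embedding tensors to `hps_demo_without_embedding.onnx`, and dense features plus two key tensors to `hps_demo_with_embedding.onnx`. A quick way to confirm the expected input names, shapes and types is to list each session's inputs with the ONNX Runtime Python API. The helper name `describe_onnx_inputs` is hypothetical, and the CPU execution provider is an assumption that is sufficient for inspecting the inputs.

```python
import os

def describe_onnx_inputs(path):
    """Hypothetical helper: return (name, shape, type) for each graph input of an ONNX model."""
    import onnxruntime as ort  # imported lazily so the helper can be defined without onnxruntime
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    return [(i.name, i.shape, i.type) for i in sess.get_inputs()]

for model_path in ["hps_demo_without_embedding.onnx", "hps_demo_with_embedding.onnx"]:
    if os.path.exists(model_path):
        for name, shape, dtype in describe_onnx_inputs(model_path):
            print(model_path, name, shape, dtype)
```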

Inference with HPS & ONNX

We will run inference by performing the following steps with Python APIs:

  1. Configure the HPS hyperparameters.

  2. Initialize the HPS object, which is responsible for embedding table lookup.

  3. Load the Parquet data.

  4. Run inference with the HPS object and the ONNX inference session of hps_demo_without_embedding.onnx.

  5. Check the correctness by comparing with the dumped evaluation results.

  6. Run inference with the ONNX inference session of hps_demo_with_embedding.onnx (double check).

from hugectr.inference import HPS, ParameterServerConfig, InferenceParams

import pandas as pd
import numpy as np

import onnxruntime as ort

slot_size_array = [10000, 10000, 10000, 10000]
key_offset = np.insert(np.cumsum(slot_size_array), 0, 0)[:-1]
batch_size = 1024

# 1. Configure the HPS hyperparameters
ps_config = ParameterServerConfig(
           emb_table_name = {"hps_demo": ["sparse_embedding1", "sparse_embedding2"]},
           embedding_vec_size = {"hps_demo": [16, 32]},
           max_feature_num_per_sample_per_emb_table = {"hps_demo": [2, 2]},
           inference_params_array = [
              InferenceParams(
                model_name = "hps_demo",
                max_batchsize = batch_size,
                hit_rate_threshold = 1.0,
                dense_model_file = "",
                sparse_model_files = ["hps_demo0_sparse_1000.model", "hps_demo1_sparse_1000.model"],
                deployed_devices = [0],
                use_gpu_embedding_cache = True,
                cache_size_percentage = 0.5,
                i64_input_key = True)
           ])

# 2. Initialize the HPS object
hps = HPS(ps_config)

# 3. Load the Parquet data.
df = pd.read_parquet("data_parquet/val/gen_0.parquet")
dense_input_columns = df.columns[1:11]
cat_input1_columns = df.columns[11:13]
cat_input2_columns = df.columns[13:15]
dense_input = df[dense_input_columns].loc[0:batch_size-1].to_numpy(dtype=np.float32)
cat_input1 = (df[cat_input1_columns].loc[0:batch_size-1].to_numpy(dtype=np.int64) + key_offset[0:2]).reshape((batch_size, 2, 1))
cat_input2 = (df[cat_input2_columns].loc[0:batch_size-1].to_numpy(dtype=np.int64) + key_offset[2:4]).reshape((batch_size, 2, 1))

# 4. Run inference with the HPS object and the ONNX inference session of `hps_demo_without_embedding.onnx`.
embedding1 = hps.lookup(cat_input1.flatten(), "hps_demo", 0).reshape(batch_size, 2, 16)
embedding2 = hps.lookup(cat_input2.flatten(), "hps_demo", 1).reshape(batch_size, 2, 32)
sess = ort.InferenceSession("hps_demo_without_embedding.onnx")
res = sess.run(output_names=[sess.get_outputs()[0].name],
               input_feed={sess.get_inputs()[0].name: dense_input,
               sess.get_inputs()[1].name: embedding1,
               sess.get_inputs()[2].name: embedding2})
pred = res[0]

# 5. Check the correctness by comparing with dumped evaluation results.
ground_truth = np.loadtxt("hps_demo_pred_1000")
print("ground_truth: ", ground_truth)
diff = pred.flatten()-ground_truth
mse = np.mean(diff*diff)
print("pred: ", pred)
print("mse between pred and ground_truth: ", mse)

# 6. Run inference with the ONNX inference session of `hps_demo_with_embedding.onnx` (double check).
sess_ref = ort.InferenceSession("hps_demo_with_embedding.onnx")
res_ref = sess_ref.run(output_names=[sess_ref.get_outputs()[0].name],
                   input_feed={sess_ref.get_inputs()[0].name: dense_input,
                   sess_ref.get_inputs()[1].name: cat_input1,
                   sess_ref.get_inputs()[2].name: cat_input2})
pred_ref = res_ref[0]
diff_ref = pred_ref.flatten()-ground_truth
mse_ref = np.mean(diff_ref*diff_ref)
print("pred_ref: ", pred_ref)
print("mse between pred_ref and ground_truth: ", mse_ref)
[HCTR][11:17:13][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
[HCTR][11:17:13][INFO][RK0][main]: Creating ParallelHashMap CPU database backend...
[HCTR][11:17:13][INFO][RK0][main]: Created parallel (16 partitions) blank database backend in local memory!
[HCTR][11:17:13][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][11:17:13][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][11:17:13][INFO][RK0][main]: Table: hps_et.hps_demo.sparse_embedding1; cached 15749 / 15749 embeddings in volatile database (ParallelHashMap); load: 15749 / 18446744073709551615 (0.00%).
[HCTR][11:17:13][INFO][RK0][main]: Table: hps_et.hps_demo.sparse_embedding2; cached 15781 / 15781 embeddings in volatile database (ParallelHashMap); load: 15781 / 18446744073709551615 (0.00%).
[HCTR][11:17:13][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][11:17:13][INFO][RK0][main]: Create embedding cache in device 0.
[HCTR][11:17:13][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][11:17:13][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][11:17:13][INFO][RK0][main]: Create inference session on device: 0
[HCTR][11:17:13][INFO][RK0][main]: Model name: hps_demo
[HCTR][11:17:13][INFO][RK0][main]: Number of embedding tables: 2
[HCTR][11:17:13][INFO][RK0][main]: Use I64 input key: True
ground_truth:  [0.456111 0.417843 0.428037 ... 0.336745 0.53599  0.508711]
pred:  [[0.45611122]
 [0.4178428 ]
 [0.42803708]
 ...
 [0.3367453 ]
 [0.53599   ]
 [0.5087108 ]]
mse between pred and ground_truth:  8.241691052249094e-14
pred_ref:  [[0.45611122]
 [0.4178428 ]
 [0.42803708]
 ...
 [0.3367453 ]
 [0.53599   ]
 [0.5087108 ]]
mse between pred_ref and ground_truth:  7.573986338301264e-05
2022-03-31 11:17:13.779336470 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'key_to_indice_hash_all_tables'. It is not used by any node and should be removed from the model.
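The `key_offset` array used above shifts each slot's raw IDs into a single global key space. This appears consistent with the training log ("Vocabulary size: 40000") and with the table 1 keys printed later in this notebook, which all lie at or above 20000. The worked example below reproduces the offsets and applies them to a few hypothetical raw IDs from the third and fourth categorical columns, which feed embedding table 1. It also computes how many values a lookup yields for a given number of keys.

```python
import numpy as np

slot_size_array = [10000, 10000, 10000, 10000]
# Start of each slot's ID range in the global key space
key_offset = np.insert(np.cumsum(slot_size_array), 0, 0)[:-1]
print(key_offset.tolist())  # [0, 10000, 20000, 30000]

# Hypothetical raw IDs for two samples; columns are slots 3 and 4 (embedding table 1)
raw_ids = np.array([[5, 347], [1, 174]], dtype=np.int64)
global_keys = raw_ids + key_offset[2:4]
print(global_keys.tolist())  # [[20005, 30347], [20001, 30174]]

# The notebook reshapes each lookup result to (batch_size, 2, vec_size),
# so a lookup yields len(keys) * embedding_vec_size values in total
vec_size = 32
num_values = global_keys.size * vec_size
print(num_values)  # 128
```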

Lookup the Embedding Vector from DLPack

We also provide a lookup_fromdlpack interface that queries embedding keys stored on the CPU and writes the embedding vectors to a buffer on either the GPU or the CPU.

  1. Suppose you have created a PyTorch/TensorFlow tensor that stores the embedding keys.

  2. Convert the embedding key tensor to a DLPack capsule through the corresponding platform's to_dlpack function.

  3. Create an empty tensor as a buffer to store the embedding vectors.

  4. Convert the buffer tensor to a DLPack capsule.

  5. Look up the embedding vectors of the corresponding embedding keys directly through the lookup_fromdlpack interface, which writes them into the embedding vector buffer tensor.

  6. If the output capsule is allocated on the GPU, a device_id needs to be specified in the lookup_fromdlpack interface to select the corresponding embedding cache. If it is not specified, the default is device 0.

Note: Please make sure that TensorFlow or PyTorch has been installed correctly in the container.
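The steps above rely on DLPack being zero-copy: the buffer tensor and its capsule refer to the same memory, so values written through the capsule show up in the original tensor. The sketch below does not use HPS. It illustrates that property with NumPy's own DLPack support (`np.from_dlpack`, available in NumPy 1.23 and later). That `lookup_fromdlpack` writes into the buffer in place is an assumption, but it is consistent with the code below, which reads its results back from `out_capsule`.

```python
import numpy as np

# Producer: a pre-allocated float32 buffer, like the empty "out" tensors below
buffer = np.zeros(8, dtype=np.float32)

# Consumer: np.from_dlpack imports the buffer through the DLPack protocol without copying
view = np.from_dlpack(buffer)
print(np.shares_memory(buffer, view))  # True

# A write through one handle is visible through the other
buffer[:4] = 1.0
print(view[:4].tolist())  # [1.0, 1.0, 1.0, 1.0]
```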

embedding1 = hps.lookup(cat_input1.flatten(), "hps_demo", 0).reshape(batch_size, 2, 16)
embedding2 = hps.lookup(cat_input2.flatten(), "hps_demo", 1).reshape(batch_size, 2, 32)

# 1. Look up from dlpack for PyTorch tensor on CPU
print(" Look up from dlpack for Pytorch tensor")
import torch
import torch.utils.dlpack
print("************Look up from pytorch dlpack on CPU")
device = torch.device("cpu")
key = torch.tensor(cat_input1.flatten(),dtype=torch.int64, device=device)
# Output buffer: one float32 value per (key, embedding dimension) pair
out = torch.empty((1,cat_input1.flatten().shape[0]*16), dtype=torch.float32, device=device)
key_capsule = torch.utils.dlpack.to_dlpack(key)
print("The device type of embedding keys that lookup dlpack from hps interface for embedding table 0 of hps_demo: {}, the keys: {}".format(key.device, key))
out_capsule = torch.utils.dlpack.to_dlpack(out)
# Look up the embedding vectors; HPS writes them into the buffer behind out_capsule
hps.lookup_fromdlpack(key_capsule, out_capsule,"hps_demo", 0)
out_put = torch.utils.dlpack.from_dlpack(out_capsule)
print("[The device type of embedding vectors that lookup dlpack from hps interface for embedding table 0 of hps_demo: {}, the vectors: {}\n".format(out_put.device, out_put))
diff = out_put-embedding1.reshape(1,cat_input1.flatten().shape[0]*16)
mse = (diff*diff).mean()
if mse > 1e-4:
    raise RuntimeError("Too large mse between pytorch dlpack on cpu and native HPS lookup api: {}".format(mse))
else:
    print("Pytorch dlpack on cpu  results are consistent with native HPS lookup api, mse: {}".format(mse))


# 2. Look up from dlpack for PyTorch tensor on GPU
print("************Look up from pytorch dlpack on GPU")
cuda_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# The keys stay on the CPU; only the output buffer is allocated on the GPU
key = torch.tensor(cat_input1.flatten(),dtype=torch.int64, device=device)
key_capsule = torch.utils.dlpack.to_dlpack(key)
out = torch.empty((cat_input1.flatten().shape[0]*16), dtype=torch.float32, device=cuda_device)
out_capsule = torch.utils.dlpack.to_dlpack(out)
hps.lookup_fromdlpack(key_capsule, out_capsule,"hps_demo", 0)
out_put = torch.utils.dlpack.from_dlpack(out_capsule)
print("The device type of embedding vectors that lookup dlpack from hps interface for embedding table 0 of hps_demo: {}, the vectors: {}\n\n".format(out_put.device, out_put))
diff = out_put.cpu()-embedding1.reshape(1,cat_input1.flatten().shape[0]*16)
mse = (diff*diff).mean()
if mse > 1e-3:
    raise RuntimeError("Too large mse between pytorch dlpack on GPU and native HPS lookup api: {}".format(mse))
else:
    print("Pytorch dlpack on GPU results are consistent with native HPS lookup api, mse: {}".format(mse))
 Look up from dlpack for Pytorch tensor
************Look up from pytorch dlpack on CPU
The device type of embedding keys that lookup dlpack from hps interface for embedding table 0 of hps_demo: cpu, the keys: tensor([    4, 10000,    17,  ..., 10208,     5, 10012])
[The device type of embedding vectors that lookup dlpack from hps interface for embedding table 0 of hps_demo: cpu, the vectors: tensor([[ 0.0201,  0.0179,  0.0029,  ...,  0.0168, -0.0059,  0.0017]])

Pytorch dlpack on cpu  results are consistent with native HPS lookup api, mse: 0.0
************Look up from pytorch dlpack on GPU
The device type of embedding vectors that lookup dlpack from hps interface for embedding table 0 of hps_demo: cuda:0, the vectors: tensor([ 0.0201,  0.0179,  0.0029,  ...,  0.0168, -0.0059,  0.0017],
       device='cuda:0')


Pytorch dlpack on GPU results are consistent with native HPS lookup api, mse: 0.0
# 3. Look up from dlpack for TensorFlow tensor on CPU
print("Look up from dlpack for Tensorflow tensor")
import tensorflow as tf
print("***************Look up from tensorflow dlpack on CPU**********")
with tf.device('/CPU:0'):
    key_tensor = tf.constant(cat_input2.flatten(),dtype=tf.int64)
    # Output buffer: one float32 value per (key, embedding dimension) pair
    out_tensor = tf.zeros([1, cat_input2.flatten().shape[0]*32],dtype=tf.float32)
    print("The device type of embedding keys that lookup dlpack from hps interface for embedding table 1 of hps_demo: {}, the keys: {}".format(key_tensor.device, key_tensor))
    key_capsule = tf.experimental.dlpack.to_dlpack(key_tensor)
    out_dlcapsule = tf.experimental.dlpack.to_dlpack(out_tensor)
hps.lookup_fromdlpack(key_capsule,out_dlcapsule, "hps_demo", 1)
out = tf.experimental.dlpack.from_dlpack(out_dlcapsule)
print("The device type of embedding vectors that lookup dlpack from hps interface for embedding table 1 of hps_demo: {}, the vectors: {}\n".format(out.device, out))
diff = out-embedding2.reshape(1,cat_input2.flatten().shape[0]*32)
mse = tf.reduce_mean(tf.square(diff))
if mse > 1e-3:
    raise RuntimeError("Too large mse between tensorflow dlpack on cpu and native HPS lookup api: {}".format(mse))
else:
    print("tensorflow dlpack on CPU results are consistent with native HPS lookup api, mse: {}".format(mse))
    
# 4. Look up from dlpack for TensorFlow tensor on GPU
print("***************Look up from tensorflow dlpack on GPU**********")
with tf.device('/GPU:0'):
    key_tensor = tf.constant(cat_input2.flatten(),dtype=tf.int64)
    out_tensor = tf.zeros([cat_input2.flatten().shape[0]*32],dtype=tf.float32)
    key_capsule = tf.experimental.dlpack.to_dlpack(key_tensor)
    out_dlcapsule = tf.experimental.dlpack.to_dlpack(out_tensor)
hps.lookup_fromdlpack(key_capsule,out_dlcapsule, "hps_demo", 1)
out = tf.experimental.dlpack.from_dlpack(out_dlcapsule)
print("[HUGECTR][INFO] The device type of embedding vectors that lookup dlpack from hps interface for embedding table 1 of wdl: {}, the vectors: {}\n".format(out.device, out))
diff = out-embedding2.reshape(1,cat_input2.flatten().shape[0]*32)
mse = tf.reduce_mean(tf.square(diff))
if mse > 1e-3:
    raise RuntimeError("Too large mse between tensorflow dlpack on GPU and native HPS lookup api: {}".format(mse))
else:
    print("tensorflow dlpack on GPU results are consistent with native HPS lookup api, mse: {}".format(mse))
Look up from dlpack for Tensorflow tensor
***************Look up from tensorflow dlpack on CPU**********
The device type of embedding keys that lookup dlpack from hps interface for embedding table 1 of hps_demo: /job:localhost/replica:0/task:0/device:CPU:0, the keys: [20005 30347 20001 ... 30174 20000 30013]
The device type of embedding vectors that lookup dlpack from hps interface for embedding table 1 of hps_demo: /job:localhost/replica:0/task:0/device:CPU:0, the vectors: [[ 0.02120136  0.03807243 -0.04021286 ... -0.00556568  0.00462132
   0.01774719]]
tensorflow dlpack on CPU results are consistent with native HPS lookup api, mse: 0.0
***************Look up from tensorflow dlpack on GPU**********
[HUGECTR][INFO] The device type of embedding vectors that lookup dlpack from hps interface for embedding table 1 of wdl: /job:localhost/replica:0/task:0/device:GPU:0, the vectors: [ 0.02120136  0.03807243 -0.04021286 ... -0.00556568  0.00462132
  0.01774719]

tensorflow dlpack on GPU results are consistent with native HPS lookup api, mse: 0.0
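The four checks above repeat the same comparison. A framework-agnostic helper could factor it out. The sketch below converts both arguments to flat NumPy arrays, computes the mean squared error and the maximum absolute difference, and raises if the MSE exceeds a tolerance. The name `check_consistent` is hypothetical. Tensors that live on a GPU must be copied to the host first, for example with `.cpu()` in PyTorch.

```python
import numpy as np

def check_consistent(label, result, reference, tol=1e-4):
    """Hypothetical helper: compare a DLPack lookup result with native hps.lookup output.

    Both inputs are flattened to float64 NumPy arrays; GPU tensors must be moved
    to the host before calling. Returns (mse, max_abs_diff).
    """
    got = np.asarray(result, dtype=np.float64).ravel()
    ref = np.asarray(reference, dtype=np.float64).ravel()
    if got.shape != ref.shape:
        raise ValueError(f"{label}: shape mismatch {got.shape} vs {ref.shape}")
    diff = got - ref
    mse = float(np.mean(diff * diff))
    max_abs = float(np.max(np.abs(diff))) if diff.size else 0.0
    if mse > tol:
        raise RuntimeError(f"{label}: mse {mse} exceeds tolerance {tol}")
    return mse, max_abs
```

For example, `check_consistent("pytorch dlpack on GPU", out_put.cpu(), embedding1, tol=1e-3)` would replace one of the `if`/`else` blocks above. Squaring the differences matters because a plain mean of signed differences can hide errors of opposite sign.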

Multi-process inference

It is possible to share a hash map database between multiple processes. The following example launches 3 processes that achieve this using the operating system's shared memory, which is located at /dev/shm on most Unix systems. In this example, we separate the processes into a primary and multiple secondary processes, and only the primary process initializes the shared memory database. The secondary processes wait until the shared memory has been fully initialized. Note that inter-process database access is guaranteed to be thread-safe, so it is also possible to implement more complex initialization/refresh mechanisms for your use case.
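The example below makes secondary processes poll a shared counter once per second. An alternative way to express the same primary/secondary handshake is a `multiprocessing.Event`, which secondaries block on until the primary signals that initialization is done. This sketch swaps in that technique and does not use HPS. `build_fn` is a hypothetical placeholder for constructing `HPS(ps_config)` with `initialize_after_startup` set only for the primary, and `worker` and `launch` are hypothetical names.

```python
import multiprocessing as mp

def worker(name, init_done, ready_count, build_fn):
    """Run in every process; the primary builds first, secondaries wait for it."""
    is_primary = name == "primary"
    if not is_primary:
        init_done.wait()      # block until the primary has initialized, no polling loop
    build_fn(is_primary)      # placeholder: e.g. build HPS with initialize_after_startup=is_primary
    if is_primary:
        init_done.set()       # wake all waiting secondaries at once
    with ready_count.get_lock():  # "+=" on a shared Value is not atomic on its own
        ready_count.value += 1
    return name

def launch(num_secondary, build_fn):
    """Start one primary and num_secondary secondaries; return how many became ready.

    build_fn must be a module-level function so it can be pickled by the
    "spawn" and "forkserver" start methods.
    """
    init_done = mp.Event()
    ready_count = mp.Value("i", 0)
    names = ["primary"] + [f"secondary{i}" for i in range(num_secondary)]
    procs = [mp.Process(target=worker, args=(n, init_done, ready_count, build_fn)) for n in names]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    return ready_count.value
```

When `launch` is called from a script, the call should sit under an `if __name__ == "__main__":` guard. If the primary's `build_fn` can fail, giving `init_done.wait()` a timeout keeps secondaries from waiting forever.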

%%writefile multi_process_hps.py
import os
import time
import multiprocessing as mp
import pandas as pd
import numpy as np
import onnxruntime as ort
from hugectr import DatabaseType_t
from hugectr.inference import HPS, ParameterServerConfig, InferenceParams, VolatileDatabaseParams

slot_size_array = [10000, 10000, 10000, 10000]
key_offset = np.insert(np.cumsum(slot_size_array), 0, 0)[:-1]
batch_size = 2048

def create_hps(name, initialized):
    print(f'subprocess:{name}({os.getpid()})launch...')
    
    # 1. Let secondary processes wait until shared memory is initialized.
    while name != 'primary' and initialized.value == 0:
        print(f'Subprocess {name} awaiting initialization...')
        time.sleep(1)

    # 2. Configure the HPS hyperparameters
    ps_config = ParameterServerConfig(
           emb_table_name = {"hps_demo": ["sparse_embedding1", "sparse_embedding2"]},
           embedding_vec_size = {"hps_demo": [16, 32]},
           max_feature_num_per_sample_per_emb_table = {"hps_demo": [2, 2]},
           inference_params_array = [
              InferenceParams(
                model_name = "hps_demo",
                max_batchsize = batch_size,
                hit_rate_threshold = 1.0,
                dense_model_file = "",
                sparse_model_files = ["hps_demo0_sparse_1000.model", "hps_demo1_sparse_1000.model"],
                deployed_devices = [0],
                use_gpu_embedding_cache = True,
                cache_size_percentage = 0.5,
                i64_input_key = True)
           ],
           volatile_db = VolatileDatabaseParams(
                DatabaseType_t.multi_process_hash_map,  # Use /dev/shm instead of normal memory for storage.
                # Only the primary process loads the model into the shared database;
                # when HPS runs in multiple processes, only one of them needs to initialize it.
                initialize_after_startup = name == 'primary',
           ))

    # 3. Initialize the HPS object
    hps = HPS(ps_config)
    with initialized.get_lock():  # `+=` on a shared Value is not atomic across processes.
        initialized.value += 1
    print(f'Subprocess {name} initialized')

    # 4. Load query data.
    df = pd.read_parquet("data_parquet/val/gen_0.parquet")
    dense_input_columns = df.columns[1:11]
    cat_input1_columns = df.columns[11:13]
    cat_input2_columns = df.columns[13:15]
    dense_input = df[dense_input_columns].loc[0:batch_size-1].to_numpy(dtype=np.float32)
    cat_input1 = (df[cat_input1_columns].loc[0:batch_size-1].to_numpy(dtype=np.int64) + key_offset[0:2]).reshape((batch_size, 2, 1))
    cat_input2 = (df[cat_input2_columns].loc[0:batch_size-1].to_numpy(dtype=np.int64) + key_offset[2:4]).reshape((batch_size, 2, 1))

    # 5. Make inference from the HPS object and the ONNX inference session of `hps_demo_without_embedding.onnx`.
    embedding1 = hps.lookup(cat_input1.flatten(), "hps_demo", 0).reshape(batch_size, 2, 16)
    embedding2 = hps.lookup(cat_input2.flatten(), "hps_demo", 1).reshape(batch_size, 2, 32)
    sess = ort.InferenceSession("hps_demo_without_embedding.onnx")
    res = sess.run(output_names=[sess.get_outputs()[0].name],
                   input_feed={sess.get_inputs()[0].name: dense_input,
                   sess.get_inputs()[1].name: embedding1,
                   sess.get_inputs()[2].name: embedding2})
    pred = res[0]
            
    # 6. Check the correctness by comparing with dumped evaluation results.
    ground_truth = np.loadtxt("hps_demo_pred_1000")
    print(f'Subprocess {name}; ground_truth: {ground_truth}')
    diff = pred.flatten()-ground_truth
    mse = np.mean(diff*diff)
    print(f'Subprocess {name}; pred: {pred}')
    print(f'Subprocess {name}; mse between pred and ground_truth: {mse}')
    
    # 7. Make inference with the ONNX inference session of `hps_demo_with_embedding.onnx` (double check).
    sess_ref = ort.InferenceSession("hps_demo_with_embedding.onnx")
    res_ref = sess_ref.run(output_names=[sess_ref.get_outputs()[0].name],
                   input_feed={sess_ref.get_inputs()[0].name: dense_input,
                   sess_ref.get_inputs()[1].name: cat_input1,
                   sess_ref.get_inputs()[2].name: cat_input2})
    pred_ref = res_ref[0]
    diff_ref = pred_ref.flatten()-ground_truth
    mse_ref = np.mean(diff_ref*diff_ref)
    print(f'Subprocess {name}; pred_ref: {pred_ref}')
    print(f'Subprocess {name}; mse between pred_ref and ground_truth: {mse_ref}')
    
    print(f'Subprocess {name} exiting...')

if __name__ == '__main__':
    # Remove any shared memory database left over from a previous run.
    try:
        os.remove('/dev/shm/hctr_mp_hash_map_database')
    except FileNotFoundError:
        pass
    
    initialized = mp.Value('i', 0)

    # Create sub processes.
    processes = [
        mp.Process(target=create_hps, args=('primary', initialized)),
        mp.Process(target=create_hps, args=('secondary', initialized)),
        mp.Process(target=create_hps, args=('secondary', initialized)),
    ]
    for p in processes:
        p.start()

    # Go to sleep until subprocesses are initialized.
    while initialized.value < len(processes):
        print(f'Main process; awaiting subprocess initialization... So far {initialized.value} initialized...')
        time.sleep(1)
        
    # Wait for subprocesses to exit.
    for i, p in enumerate(processes):
        print(f'Main process; awaiting subprocess {i} to exit...')
        p.join()
    print(f'Main process; exiting...')
Overwriting multi_process_hps.py
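In step 4, the script adds `key_offset` to each slot's raw categorical values before the lookup. The following NumPy-only sketch reproduces that arithmetic for a hypothetical three-sample batch. `key_offset` is the exclusive prefix sum of `slot_size_array`, which keeps the key ranges of the slots disjoint. Slots 2 and 3, which feed embedding table 1, therefore map into [20000, 40000). This is consistent with keys such as 20005 and 30347 printed earlier for table 1. The `raw` values are made up for illustration.

```python
import numpy as np

slot_size_array = [10000, 10000, 10000, 10000]
# Exclusive prefix sum: slot i starts where slots 0..i-1 end.
key_offset = np.insert(np.cumsum(slot_size_array), 0, 0)[:-1]
print(key_offset.tolist())  # [0, 10000, 20000, 30000]

# Hypothetical raw categorical values: 3 samples x 2 slots per table.
raw_table0 = np.array([[5, 7], [0, 9999], [42, 1]], dtype=np.int64)
raw_table1 = np.array([[5, 347], [1, 0], [13, 174]], dtype=np.int64)

# Slots 0-1 feed table 0 and slots 2-3 feed table 1, as in the script.
# The trailing axis of size 1 holds one key per slot (nnz = 1).
keys_table0 = (raw_table0 + key_offset[0:2]).reshape((3, 2, 1))
keys_table1 = (raw_table1 + key_offset[2:4]).reshape((3, 2, 1))
```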
!python3 multi_process_hps.py
subprocess:primary(68230)launch...
[HCTR][12:08:05.939][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
Main process; awaiting subprocess initialization... So far 0 initialized...
====================================================HPS Create====================================================
[HCTR][12:08:05.940][INFO][RK0][main]: Creating Multi-Process HashMap CPU database backend...
subprocess:secondary(68231)launch...
Subprocess secondary awaiting initialization...
[HCTR][12:08:05.940][INFO][RK0][main]: Connecting to shared memory 'hctr_mp_hash_map_database'...
subprocess:secondary(68232)launch...
Subprocess secondary awaiting initialization...
[HCTR][12:08:05.940][INFO][RK0][main]: Connected to shared memory 'hctr_mp_hash_map_database'; OS total = 274877906944 bytes, OS available = 265611984896 bytes, HCTR allocated = 17179869184 bytes, HCTR free = 17179868640 bytes; other processes connected = 0
[HCTR][12:08:06.440][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][12:08:06.440][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][12:08:06.440][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][12:08:06.440][INFO][RK0][main]: Using Local file system backend.
Main process; awaiting subprocess initialization... So far 0 initialized...
Subprocess secondary awaiting initialization...
Subprocess secondary awaiting initialization...
[HCTR][12:08:07.200][INFO][RK0][main]: Table: hps_et.hps_demo.sparse_embedding1; cached 18451 / 18451 embeddings in volatile database (MultiProcessHashMapBackend); load: 18451 / 18446744073709551615 (0.00%).
[HCTR][12:08:07.200][INFO][RK0][main]: Using Local file system backend.
[HCTR][12:08:07.875][INFO][RK0][main]: Table: hps_et.hps_demo.sparse_embedding2; cached 18345 / 18345 embeddings in volatile database (MultiProcessHashMapBackend); load: 18345 / 18446744073709551615 (0.00%).
[HCTR][12:08:07.875][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][12:08:07.876][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][12:08:07.893][INFO][RK0][main]: Model name: hps_demo
[HCTR][12:08:07.893][INFO][RK0][main]: Number of embedding tables: 2
[HCTR][12:08:07.893][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][12:08:07.893][INFO][RK0][main]: Use I64 input key: True
[HCTR][12:08:07.893][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][12:08:07.893][INFO][RK0][main]: The size of thread pool: 256
[HCTR][12:08:07.893][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][12:08:07.893][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][12:08:07.893][INFO][RK0][main]: The refresh percentage : 0.000000
Main process; awaiting subprocess initialization... So far 0 initialized...
Subprocess secondary awaiting initialization...
Subprocess secondary awaiting initialization...
[HCTR][12:08:08.781][INFO][RK0][main]: Creating lookup session for hps_demo on device: 0
Subprocess primary initialized
2022-10-25 12:08:08.862765243 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'key_to_indice_hash_all_tables'. It is not used by any node and should be removed from the model.
Subprocess primary; ground_truth: [0.45834  0.483733 0.521425 ... 0.492849 0.385835 0.533917]
Subprocess primary; pred: [[0.47144616]
 [0.48876405]
 [0.5209218 ]
 ...
 [0.5083375 ]
 [0.46721274]
 [0.44642204]]
Subprocess primary; mse between pred and ground_truth: 0.0018104517608496385
Subprocess primary; pred_ref: [[0.47144616]
 [0.48876405]
 [0.5209218 ]
 ...
 [0.5083375 ]
 [0.46721274]
 [0.44642204]]
Subprocess primary; mse between pred_ref and ground_truth: 0.0018104517608496385
Subprocess primary exiting...
[HCTR][12:08:08.934][INFO][RK0][main]: Disconnecting from shared memory 'hctr_mp_hash_map_database'.
Main process; awaiting subprocess initialization... So far 1 initialized...
[HCTR][12:08:08.943][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
[HCTR][12:08:08.943][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
====================================================HPS Create====================================================
====================================================HPS Create====================================================
[HCTR][12:08:08.944][INFO][RK0][main]: Creating Multi-Process HashMap CPU database backend...
[HCTR][12:08:08.944][INFO][RK0][main]: Creating Multi-Process HashMap CPU database backend...
[HCTR][12:08:08.944][INFO][RK0][main]: Connecting to shared memory 'hctr_mp_hash_map_database'...
[HCTR][12:08:08.944][INFO][RK0][main]: Connecting to shared memory 'hctr_mp_hash_map_database'...
[HCTR][12:08:09.443][INFO][RK0][main]: Detached last process from shared memory 'hctr_mp_hash_map_database'. Auto remove in progress...
[HCTR][12:08:09.443][INFO][RK0][main]: Connected to shared memory 'hctr_mp_hash_map_database'; OS total = 274877906944 bytes, OS available = 256346140672 bytes, HCTR allocated = 17179869184 bytes, HCTR free = 7914048128 bytes; other processes connected = 0
[HCTR][12:08:09.943][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][12:08:09.943][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][12:08:09.943][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][12:08:09.943][INFO][RK0][main]: Using Local file system backend.
Main process; awaiting subprocess initialization... So far 1 initialized...
[HCTR][12:08:09.948][INFO][RK0][main]: Using Local file system backend.
[HCTR][12:08:09.953][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][12:08:09.953][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][12:08:09.970][INFO][RK0][main]: Model name: hps_demo
[HCTR][12:08:09.970][INFO][RK0][main]: Number of embedding tables: 2
[HCTR][12:08:09.970][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][12:08:09.970][INFO][RK0][main]: Use I64 input key: True
[HCTR][12:08:09.970][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][12:08:09.970][INFO][RK0][main]: The size of thread pool: 256
[HCTR][12:08:09.970][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][12:08:09.970][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][12:08:09.970][INFO][RK0][main]: The refresh percentage : 0.000000
[HCTR][12:08:09.943][INFO][RK0][main]: Connected to shared memory 'hctr_mp_hash_map_database'; OS total = 274877906944 bytes, OS available = 256346140672 bytes, HCTR allocated = 17179869184 bytes, HCTR free = 7914048128 bytes; other processes connected = 1
[HCTR][12:08:10.444][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][12:08:10.444][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][12:08:10.444][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][12:08:10.444][INFO][RK0][main]: Using Local file system backend.
[HCTR][12:08:10.448][INFO][RK0][main]: Using Local file system backend.
[HCTR][12:08:10.453][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][12:08:10.453][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][12:08:10.470][INFO][RK0][main]: Model name: hps_demo
[HCTR][12:08:10.470][INFO][RK0][main]: Number of embedding tables: 2
[HCTR][12:08:10.470][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][12:08:10.470][INFO][RK0][main]: Use I64 input key: True
[HCTR][12:08:10.470][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][12:08:10.470][INFO][RK0][main]: The size of thread pool: 256
[HCTR][12:08:10.470][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][12:08:10.470][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][12:08:10.470][INFO][RK0][main]: The refresh percentage : 0.000000
Main process; awaiting subprocess initialization... So far 1 initialized...
[HCTR][12:08:10.971][INFO][RK0][main]: Creating lookup session for hps_demo on device: 0
Subprocess secondary initialized
2022-10-25 12:08:11.064851917 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'key_to_indice_hash_all_tables'. It is not used by any node and should be removed from the model.
Subprocess secondary; ground_truth: [0.45834  0.483733 0.521425 ... 0.492849 0.385835 0.533917]
Subprocess secondary; pred: [[0.47144616]
 [0.48876405]
 [0.5209218 ]
 ...
 [0.5083375 ]
 [0.46721274]
 [0.44642204]]
Subprocess secondary; mse between pred and ground_truth: 0.0018104517608496385
Subprocess secondary; pred_ref: [[0.47144616]
 [0.48876405]
 [0.5209218 ]
 ...
 [0.5083375 ]
 [0.46721274]
 [0.44642204]]
Subprocess secondary; mse between pred_ref and ground_truth: 0.0018104517608496385
Subprocess secondary exiting...
[HCTR][12:08:11.159][INFO][RK0][main]: Disconnecting from shared memory 'hctr_mp_hash_map_database'.
[HCTR][12:08:11.359][INFO][RK0][main]: Creating lookup session for hps_demo on device: 0
Subprocess secondary initialized
2022-10-25 12:08:11.801969051 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'key_to_indice_hash_all_tables'. It is not used by any node and should be removed from the model.
Subprocess secondary; ground_truth: [0.45834  0.483733 0.521425 ... 0.492849 0.385835 0.533917]
Subprocess secondary; pred: [[0.47144616]
 [0.48876405]
 [0.5209218 ]
 ...
 [0.5083375 ]
 [0.46721274]
 [0.44642204]]
Subprocess secondary; mse between pred and ground_truth: 0.0018104517608496385
Subprocess secondary; pred_ref: [[0.47144616]
 [0.48876405]
 [0.5209218 ]
 ...
 [0.5083375 ]
 [0.46721274]
 [0.44642204]]
Subprocess secondary; mse between pred_ref and ground_truth: 0.0018104517608496385
Subprocess secondary exiting...
[HCTR][12:08:11.896][INFO][RK0][main]: Disconnecting from shared memory 'hctr_mp_hash_map_database'.
Main process; awaiting subprocess 0 to exit...
Main process; awaiting subprocess 1 to exit...
[HCTR][12:08:12.445][INFO][RK0][main]: Detached last process from shared memory 'hctr_mp_hash_map_database'. Auto remove in progress...
Main process; awaiting subprocess 2 to exit...
Main process; exiting...
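Steps 5 and 6 of the script rely on two shape conventions. First, the lookup result is reshaped from one vector per key into `(batch_size, slots, vec_size)`. Second, the `(batch_size, 1)` prediction is flattened before it is subtracted from the 1-D ground truth. The following sketch checks both conventions with NumPy alone. It uses a hypothetical `fake_lookup` stand-in that is assumed to return one vector per key in input order; it is not the real `hps.lookup`.

```python
import numpy as np

batch_size, slots_per_table, vec_size = 3, 2, 4

# Hypothetical embedding table: row k holds the vector for key k.
table = np.arange(40 * vec_size, dtype=np.float32).reshape(40, vec_size)


def fake_lookup(keys):
    """Stand-in for hps.lookup: one vector per key, in input order (assumption)."""
    return table[keys]


# Keys shaped like cat_input1 in the script: (batch, slots, 1).
keys = np.array([[[1], [21]], [[2], [22]], [[3], [23]]], dtype=np.int64)
emb = fake_lookup(keys.flatten()).reshape(batch_size, slots_per_table, vec_size)

# MSE between predictions and ground truth, as in step 6.
pred = np.array([[0.5], [0.25], [1.0]], dtype=np.float32)  # shape (3, 1)
ground_truth = np.array([0.5, 0.5, 1.0])                    # shape (3,)
diff = pred.flatten() - ground_truth
mse = np.mean(diff * diff)
```

Without `flatten()`, `pred - ground_truth` would broadcast a `(3, 1)` array against a `(3,)` array into a `(3, 3)` matrix and silently produce a wrong MSE.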