
Multi-GPU Offline Inference

Overview

In HugeCTR version 3.4.1, we provide Python APIs to perform multi-GPU offline inference. This work leverages the HugeCTR Hierarchical Parameter Server and enables concurrent execution on multiple devices. Multi-GPU offline inference currently supports the Norm and Parquet dataset formats.

This notebook explains how to perform multi-GPU offline inference with the HugeCTR Python APIs. For more details about the API, see the HugeCTR Python Interface documentation.

Installation

Get HugeCTR from NGC

The HugeCTR Python module is preinstalled in the 22.04 and later Merlin Training Container: nvcr.io/nvidia/merlin/merlin-training:22.04.

You can check that the required libraries are present by running the following shell command after launching the container.

$ python3 -c "import hugectr"
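
The same check can also be scripted. The following is a minimal sketch, not part of HugeCTR, that uses the standard-library importlib to report whether a module can be located without importing it. The helper name module_available is hypothetical.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if a top-level module called `name` can be located."""
    return importlib.util.find_spec(name) is not None


# Inside the Merlin Training Container this is expected to print True;
# elsewhere it prints False instead of raising ImportError.
print("hugectr available:", module_available("hugectr"))
```

Unlike the one-line import, this reports a missing module as False rather than failing with a traceback.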

Note: This Python module contains both training APIs and offline inference APIs. For online inference with Triton Inference Server, refer to the HugeCTR Backend documentation.

If you prefer to build HugeCTR from the source code instead of using the NGC container, refer to the How to Start Your Development documentation.

Data Generation

HugeCTR provides a tool to generate synthetic datasets. The Data Generator class is capable of generating datasets in different formats and with different distributions. We will generate multi-hot Parquet datasets with a power-law distribution for this notebook:

import hugectr
from hugectr.tools import DataGeneratorParams, DataGenerator

data_generator_params = DataGeneratorParams(
  format = hugectr.DataReaderType_t.Parquet,
  label_dim = 2,
  dense_dim = 2,
  num_slot = 3,
  i64_input_key = True,
  nnz_array = [2, 1, 3],
  source = "./multi_hot_parquet/file_list.txt",
  eval_source = "./multi_hot_parquet/file_list_test.txt",
  slot_size_array = [10000, 10000, 10000],
  check_type = hugectr.Check_t.Non,
  dist_type = hugectr.Distribution_t.PowerLaw,
  power_law_type = hugectr.PowerLaw_t.Short,
  num_files = 16,
  eval_num_files = 4)
data_generator = DataGenerator(data_generator_params)
data_generator.generate()
[HCTR][15:01:03][INFO][RK0][main]: Generate Parquet dataset
[HCTR][15:01:03][INFO][RK0][main]: train data folder: ./multi_hot_parquet, eval data folder: ./multi_hot_parquet, slot_size_array: 10000, 10000, 10000, nnz array: 2, 1, 3, #files for train: 16, #files for eval: 4, #samples per file: 40960, Use power law distribution: 1, alpha of power law: 1.3
[HCTR][15:01:03][INFO][RK0][main]: ./multi_hot_parquet exist
[HCTR][15:01:03][INFO][RK0][main]: ./multi_hot_parquet/train/gen_0.parquet
[HCTR][15:01:05][INFO][RK0][main]: ./multi_hot_parquet/train/gen_1.parquet
[HCTR][15:01:05][INFO][RK0][main]: ./multi_hot_parquet/train/gen_2.parquet
[HCTR][15:01:05][INFO][RK0][main]: ./multi_hot_parquet/train/gen_3.parquet
[HCTR][15:01:05][INFO][RK0][main]: ./multi_hot_parquet/train/gen_4.parquet
[HCTR][15:01:05][INFO][RK0][main]: ./multi_hot_parquet/train/gen_5.parquet
[HCTR][15:01:05][INFO][RK0][main]: ./multi_hot_parquet/train/gen_6.parquet
[HCTR][15:01:06][INFO][RK0][main]: ./multi_hot_parquet/train/gen_7.parquet
[HCTR][15:01:06][INFO][RK0][main]: ./multi_hot_parquet/train/gen_8.parquet
[HCTR][15:01:06][INFO][RK0][main]: ./multi_hot_parquet/train/gen_9.parquet
[HCTR][15:01:06][INFO][RK0][main]: ./multi_hot_parquet/train/gen_10.parquet
[HCTR][15:01:06][INFO][RK0][main]: ./multi_hot_parquet/train/gen_11.parquet
[HCTR][15:01:06][INFO][RK0][main]: ./multi_hot_parquet/train/gen_12.parquet
[HCTR][15:01:07][INFO][RK0][main]: ./multi_hot_parquet/train/gen_13.parquet
[HCTR][15:01:07][INFO][RK0][main]: ./multi_hot_parquet/train/gen_14.parquet
[HCTR][15:01:07][INFO][RK0][main]: ./multi_hot_parquet/train/gen_15.parquet
[HCTR][15:01:07][INFO][RK0][main]: ./multi_hot_parquet/file_list.txt done!
[HCTR][15:01:07][INFO][RK0][main]: ./multi_hot_parquet/val/gen_0.parquet
[HCTR][15:01:07][INFO][RK0][main]: ./multi_hot_parquet/val/gen_1.parquet
[HCTR][15:01:08][INFO][RK0][main]: ./multi_hot_parquet/val/gen_2.parquet
[HCTR][15:01:08][INFO][RK0][main]: ./multi_hot_parquet/val/gen_3.parquet
[HCTR][15:01:08][INFO][RK0][main]: ./multi_hot_parquet/file_list_test.txt done!
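
The generator log gives enough information to work out the dataset size. The sketch below is plain arithmetic, not a HugeCTR API. The file counts and samples per file come from the log above, and the batch size of 16384 is the value the training script in this notebook passes to CreateSolver.

```python
SAMPLES_PER_FILE = 40960  # "#samples per file" in the generator log
TRAIN_FILES = 16          # num_files
EVAL_FILES = 4            # eval_num_files
BATCH_SIZE = 16384        # batchsize used by the training script in this notebook

train_samples = TRAIN_FILES * SAMPLES_PER_FILE
eval_samples = EVAL_FILES * SAMPLES_PER_FILE

# Number of full batches in one pass over each split.
train_batches_per_pass = train_samples // BATCH_SIZE
eval_batches_per_pass = eval_samples // BATCH_SIZE

print("train samples:", train_samples, "batches per pass:", train_batches_per_pass)
print("eval samples:", eval_samples, "batches per pass:", eval_batches_per_pass)
```

With these numbers one pass over the training data is 40 batches, so a run of 1100 iterations with repeat_dataset = True cycles through the training files many times.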

Train from Scratch

We can train from scratch by performing the following steps with the Python APIs:

  1. Create the solver, reader and optimizer, then initialize the model.

  2. Construct the model graph by adding input, sparse embedding and dense layers in order.

  3. Compile the model and print an overview of the model graph.

  4. Dump the model graph to a JSON file.

  5. Fit the model. The model weights and optimizer states are saved implicitly at each snapshot interval.

  6. Dump one batch of evaluation results to files.
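
The layer sizes in the training script follow from the dataset layout. As a rough sketch (plain Python arithmetic, not a HugeCTR API), each sparse input contributes its slot count times embedding_vec_size values after the Reshape layer, and the Concat layer appends the dense features. The slot counts, embedding size and dense dimension are the values used in the script below.

```python
EMBEDDING_VEC_SIZE = 16  # embedding_vec_size of both SparseEmbedding layers
DENSE_DIM = 2            # dense_dim of the Input layer

# Number of slots covered by each sparse input (last argument of DataReaderSparseParam).
slots_per_input = {"data1": 2, "data2": 1}

# Width of each embedding output after it is flattened by a Reshape layer.
reshape_dims = {
    name: slots * EMBEDDING_VEC_SIZE for name, slots in slots_per_input.items()
}

# Width of the Concat output that feeds the first InnerProduct layer.
concat_dim = sum(reshape_dims.values()) + DENSE_DIM

print("leading_dim per Reshape layer:", reshape_dims)
print("Concat output width:", concat_dim)
```

These values correspond to leading_dim = 32 and leading_dim = 16 in the two Reshape layers and to the (None, 50) Concat output shape reported by model.summary().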

%%writefile multi_hot_train.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(model_name = "multi_hot",
                              max_eval_batches = 1,
                              batchsize_eval = 16384,
                              batchsize = 16384,
                              lr = 0.001,
                              vvgpu = [[0]],
                              i64_input_key = True,
                              repeat_dataset = True,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Parquet,
                                  source = ["./multi_hot_parquet/file_list.txt"],
                                  eval_source = "./multi_hot_parquet/file_list_test.txt",
                                  check_type = hugectr.Check_t.Non,
                                  slot_size_array = [10000, 10000, 10000])
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.add(hugectr.Input(label_dim = 2, label_name = "label",
                        dense_dim = 2, dense_name = "dense",
                        data_reader_sparse_param_array = 
                        [hugectr.DataReaderSparseParam("data1", [2, 1], False, 2),
                        hugectr.DataReaderSparseParam("data2", 3, False, 1),]))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 4,
                            embedding_vec_size = 16,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding1",
                            bottom_name = "data1",
                            optimizer = optimizer))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 2,
                            embedding_vec_size = 16,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding2",
                            bottom_name = "data2",
                            optimizer = optimizer))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding1"],
                            top_names = ["reshape1"],
                            leading_dim=32))                            
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding2"],
                            top_names = ["reshape2"],
                            leading_dim=16))                            
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                            bottom_names = ["reshape1", "reshape2", "dense"], top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["concat1"],
                            top_names = ["fc1"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc1"],
                            top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["relu1"],
                            top_names = ["fc2"],
                            num_output=2))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.MultiCrossEntropyLoss,
                            bottom_names = ["fc2", "label"],
                            top_names = ["loss"],
                            target_weight_vec = [0.5, 0.5]))
model.compile()
model.summary()
model.graph_to_json("multi_hot.json")
model.fit(max_iter = 1100, display = 200, eval_interval = 1000, snapshot = 1000, snapshot_prefix = "multi_hot")
model.export_predictions("multi_hot_pred_" + str(1000), "multi_hot_label_" + str(1000))
Overwriting multi_hot_train.py
!python3 multi_hot_train.py
HugeCTR Version: 3.4
====================================================Model Init=====================================================
[HCTR][15:04:04][INFO][RK0][main]: Initialize model: multi_hot
[HCTR][15:04:04][INFO][RK0][main]: Global seed is 2258929170
[HCTR][15:04:04][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][15:04:05][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][15:04:05][INFO][RK0][main]: Start all2all warmup
[HCTR][15:04:05][INFO][RK0][main]: End all2all warmup
[HCTR][15:04:05][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][15:04:05][INFO][RK0][main]: Device 0: Tesla V100-SXM2-32GB
[HCTR][15:04:05][INFO][RK0][main]: num of DataReader workers: 1
[HCTR][15:04:05][INFO][RK0][main]: Vocabulary size: 30000
[HCTR][15:04:05][INFO][RK0][main]: max_vocabulary_size_per_gpu_=65536
[HCTR][15:04:05][INFO][RK0][main]: max_vocabulary_size_per_gpu_=32768
[HCTR][15:04:05][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][15:04:14][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][15:04:14][INFO][RK0][main]: gpu0 init embedding done
[HCTR][15:04:14][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][15:04:14][INFO][RK0][main]: gpu0 init embedding done
[HCTR][15:04:14][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][15:04:14][INFO][RK0][main]: Warm-up done
[HCTR][15:04:14][INFO][RK0][main]: ===================================================Model Summary===================================================
label                                   Dense                         Sparse                        
label                                   dense                          data1,data2                   
(None, 2)                               (None, 2)                               
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type                              Input Name                    Output Name                   Output Shape                  
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash      data1                         sparse_embedding1             (None, 2, 16)                 
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash      data2                         sparse_embedding2             (None, 1, 16)                 
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding1             reshape1                      (None, 32)                    
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding2             reshape2                      (None, 16)                    
------------------------------------------------------------------------------------------------------------------
Concat                                  reshape1                      concat1                       (None, 50)                    
                                        reshape2                                                                                  
                                        dense                                                                                     
------------------------------------------------------------------------------------------------------------------
InnerProduct                            concat1                       fc1                           (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc1                           relu1                         (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
InnerProduct                            relu1                         fc2                           (None, 2)                     
------------------------------------------------------------------------------------------------------------------
MultiCrossEntropyLoss                   fc2                           loss                                                        
                                        label                                                                                     
------------------------------------------------------------------------------------------------------------------
[HCTR][15:04:14][INFO][RK0][main]: Save the model graph to multi_hot.json successfully
=====================================================Model Fit=====================================================
[HCTR][15:04:14][INFO][RK0][main]: Use non-epoch mode with number of iterations: 1100
[HCTR][15:04:14][INFO][RK0][main]: Training batchsize: 16384, evaluation batchsize: 16384
[HCTR][15:04:14][INFO][RK0][main]: Evaluation interval: 1000, snapshot interval: 1000
[HCTR][15:04:14][INFO][RK0][main]: Dense network trainable: True
[HCTR][15:04:14][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][15:04:14][INFO][RK0][main]: Sparse embedding sparse_embedding2 trainable: True
[HCTR][15:04:14][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][15:04:14][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][15:04:14][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][15:04:14][INFO][RK0][main]: Training source file: ./multi_hot_parquet/file_list.txt
[HCTR][15:04:14][INFO][RK0][main]: Evaluation source file: ./multi_hot_parquet/file_list_test.txt
[HCTR][15:04:17][INFO][RK0][main]: Iter: 200 Time(200 iters): 2.73086s Loss: 0.342286 lr:0.001
[HCTR][15:04:20][INFO][RK0][main]: Iter: 400 Time(200 iters): 2.57674s Loss: 0.339907 lr:0.001
[HCTR][15:04:22][INFO][RK0][main]: Iter: 600 Time(200 iters): 2.59306s Loss: 0.338068 lr:0.001
[HCTR][15:04:25][INFO][RK0][main]: Iter: 800 Time(200 iters): 2.56907s Loss: 0.334571 lr:0.001
[HCTR][15:04:27][INFO][RK0][main]: Iter: 1000 Time(200 iters): 2.57584s Loss: 0.331733 lr:0.001
[HCTR][15:04:27][INFO][RK0][main]: Evaluation, AUC: 0.500278
[HCTR][15:04:27][INFO][RK0][main]: Eval Time for 1 iters: 0.001344s
[HCTR][15:04:27][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][15:04:27][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][15:04:27][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][15:04:27][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][15:04:27][INFO][RK0][main]: Done
[HCTR][15:04:27][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][15:04:27][INFO][RK0][main]: Done
[HCTR][15:04:28][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][15:04:28][INFO][RK0][main]: Done
[HCTR][15:04:28][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][15:04:28][INFO][RK0][main]: Done
[HCTR][15:04:28][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HCTR][15:04:28][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][15:04:28][INFO][RK0][main]: Dumping dense optimizer states to file, successful
[HCTR][15:04:29][INFO][RK0][main]: Finish 1100 iterations with batchsize: 16384 in 14.54s.
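
The snapshot taken at iteration 1000 produces the model files that the inference step loads. The sketch below reconstructs those file names from snapshot_prefix and the iteration number. The naming pattern is inferred from the file names used in the inference cell of this notebook, not taken from the HugeCTR documentation, and the helper snapshot_files is hypothetical.

```python
def snapshot_files(prefix: str, iteration: int, num_embeddings: int):
    """Build the dense and sparse model file names for one snapshot.

    Assumed pattern: one dense file, plus one sparse file per embedding
    table, numbered in the order the embeddings were added to the model.
    """
    dense = f"{prefix}_dense_{iteration}.model"
    sparse = [f"{prefix}{i}_sparse_{iteration}.model" for i in range(num_embeddings)]
    return dense, sparse


dense_file, sparse_files = snapshot_files("multi_hot", 1000, 2)
print(dense_file)
print(sparse_files)
```

If you change snapshot_prefix or the snapshot interval in model.fit, check the files actually written to disk before passing their names to InferenceParams.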

Multi-GPU Offline Inference

We can demonstrate multi-GPU offline inference by performing the following steps with Python APIs:

  1. Configure the inference hyperparameters.

  2. Initialize the inference model. The model is a collection of inference sessions deployed on multiple devices.

  3. Run inference on the evaluation dataset.

  4. Check the correctness of the inference by comparing it with the dumped evaluation results.

Note: The max_batchsize configured within InferenceParams is the global batch size. The value for max_batchsize should be divisible by the number of deployed devices. The numpy array returned by InferenceModel.predict is of the shape (max_batchsize * num_batches, label_dim).
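
To make the note concrete, the following sketch works through the batch-size arithmetic for the values used in this notebook: max_batchsize = 1024, four deployed devices, 16 batches and label_dim = 2. It is plain Python for illustration only. The helper names are hypothetical and are not part of the HugeCTR API.

```python
def per_device_batchsize(max_batchsize: int, deployed_devices: list) -> int:
    """Split the global batch size evenly across the deployed devices."""
    num_devices = len(deployed_devices)
    if max_batchsize % num_devices != 0:
        raise ValueError(
            f"max_batchsize={max_batchsize} is not divisible by {num_devices} devices"
        )
    return max_batchsize // num_devices


def prediction_shape(max_batchsize: int, num_batches: int, label_dim: int) -> tuple:
    """Shape of the array described in the note: (max_batchsize * num_batches, label_dim)."""
    return (max_batchsize * num_batches, label_dim)


local_batch = per_device_batchsize(1024, [0, 1, 2, 3])
shape = prediction_shape(1024, 16, 2)
print("batch size per device:", local_batch)
print("prediction shape:", shape)
```

The per-device value of 256 corresponds to the "Max batchsize: 256" line that each inference session logs, and 16384 rows equal one evaluation batch of the training run.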

import hugectr
from hugectr.inference import InferenceModel, InferenceParams
import numpy as np
from mpi4py import MPI

model_config = "multi_hot.json"
inference_params = InferenceParams(
    model_name = "multi_hot",
    max_batchsize = 1024,
    hit_rate_threshold = 1.0,
    dense_model_file = "multi_hot_dense_1000.model",
    sparse_model_files = ["multi_hot0_sparse_1000.model", "multi_hot1_sparse_1000.model"],
    deployed_devices = [0, 1, 2, 3],
    use_gpu_embedding_cache = True,
    cache_size_percentage = 0.5,
    i64_input_key = True
)
inference_model = InferenceModel(model_config, inference_params)
pred = inference_model.predict(
    16,
    "./multi_hot_parquet/file_list_test.txt",
    hugectr.DataReaderType_t.Parquet,
    hugectr.Check_t.Non,
    [10000, 10000, 10000]
)
grount_truth = np.loadtxt("multi_hot_pred_1000")
print("pred: ", pred)
print("grount_truth: ", grount_truth)
diff = pred.flatten()-grount_truth
mse = np.mean(diff*diff)
print("mse: ", mse)
[HCTR][15:04:58][INFO][RK0][main]: Global seed is 3101700364
[HCTR][15:04:58][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
  GPU 1 ->  node 0
  GPU 2 ->  node 0
  GPU 3 ->  node 0
[HCTR][15:05:01][INFO][RK0][main]: Start all2all warmup
[HCTR][15:05:02][INFO][RK0][main]: End all2all warmup
[HCTR][15:05:02][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
[HCTR][15:05:02][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
[HCTR][15:05:02][INFO][RK0][main]: Creating ParallelHashMap CPU database backend...
[HCTR][15:05:02][INFO][RK0][main]: Created parallel (16 partitions) blank database backend in local memory!
[HCTR][15:05:02][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][15:05:02][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][15:05:02][INFO][RK0][main]: Table: hctr_et.multi_hot.sparse_embedding1; cached 16597 / 16597 embeddings in volatile database (ParallelHashMap); load: 16597 / 18446744073709551615 (0.00%).
[HCTR][15:05:02][INFO][RK0][main]: Table: hctr_et.multi_hot.sparse_embedding2; cached 9253 / 9253 embeddings in volatile database (ParallelHashMap); load: 9253 / 18446744073709551615 (0.00%).
[HCTR][15:05:02][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][15:05:02][INFO][RK0][main]: Create embedding cache in device 0.
[HCTR][15:05:02][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][15:05:02][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][15:05:02][INFO][RK0][main]: Create embedding cache in device 1.
[HCTR][15:05:02][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][15:05:02][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][15:05:02][INFO][RK0][main]: Create embedding cache in device 2.
[HCTR][15:05:02][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][15:05:02][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][15:05:02][INFO][RK0][main]: Create embedding cache in device 3.
[HCTR][15:05:02][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][15:05:02][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][15:05:02][INFO][RK0][main]: Global seed is 1801008028
[HCTR][15:05:02][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][15:05:02][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][15:05:02][INFO][RK0][main]: Start all2all warmup
[HCTR][15:05:02][INFO][RK0][main]: End all2all warmup
[HCTR][15:05:02][INFO][RK0][main]: Create inference session on device: 0
[HCTR][15:05:02][INFO][RK0][main]: Model name: multi_hot
[HCTR][15:05:02][INFO][RK0][main]: Use mixed precision: False
[HCTR][15:05:02][INFO][RK0][main]: Use cuda graph: True
[HCTR][15:05:02][INFO][RK0][main]: Max batchsize: 256
[HCTR][15:05:02][INFO][RK0][main]: Use I64 input key: True
[HCTR][15:05:02][INFO][RK0][main]: start create embedding for inference
[HCTR][15:05:02][INFO][RK0][main]: sparse_input name data1
[HCTR][15:05:02][INFO][RK0][main]: sparse_input name data2
[HCTR][15:05:02][INFO][RK0][main]: create embedding for inference success
[HCTR][15:05:02][INFO][RK0][main]: Inference stage skip MultiCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][15:05:02][INFO][RK0][main]: Global seed is 1395008125
[HCTR][15:05:02][INFO][RK0][main]: Device to NUMA mapping:
  GPU 1 ->  node 0
[HCTR][15:05:02][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][15:05:02][INFO][RK0][main]: Start all2all warmup
[HCTR][15:05:02][INFO][RK0][main]: End all2all warmup
[HCTR][15:05:02][INFO][RK0][main]: Create inference session on device: 1
[HCTR][15:05:02][INFO][RK0][main]: Model name: multi_hot
[HCTR][15:05:02][INFO][RK0][main]: Use mixed precision: False
[HCTR][15:05:02][INFO][RK0][main]: Use cuda graph: True
[HCTR][15:05:02][INFO][RK0][main]: Max batchsize: 256
[HCTR][15:05:02][INFO][RK0][main]: Use I64 input key: True
[HCTR][15:05:02][INFO][RK0][main]: start create embedding for inference
[HCTR][15:05:02][INFO][RK0][main]: sparse_input name data1
[HCTR][15:05:02][INFO][RK0][main]: sparse_input name data2
[HCTR][15:05:02][INFO][RK0][main]: create embedding for inference success
[HCTR][15:05:02][INFO][RK0][main]: Inference stage skip MultiCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][15:05:02][INFO][RK0][main]: Global seed is 3124827580
[HCTR][15:05:02][INFO][RK0][main]: Device to NUMA mapping:
  GPU 2 ->  node 0
[HCTR][15:05:03][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][15:05:03][INFO][RK0][main]: Start all2all warmup
[HCTR][15:05:03][INFO][RK0][main]: End all2all warmup
[HCTR][15:05:03][INFO][RK0][main]: Create inference session on device: 2
[HCTR][15:05:03][INFO][RK0][main]: Model name: multi_hot
[HCTR][15:05:03][INFO][RK0][main]: Use mixed precision: False
[HCTR][15:05:03][INFO][RK0][main]: Use cuda graph: True
[HCTR][15:05:03][INFO][RK0][main]: Max batchsize: 256
[HCTR][15:05:03][INFO][RK0][main]: Use I64 input key: True
[HCTR][15:05:03][INFO][RK0][main]: start create embedding for inference
[HCTR][15:05:03][INFO][RK0][main]: sparse_input name data1
[HCTR][15:05:03][INFO][RK0][main]: sparse_input name data2
[HCTR][15:05:03][INFO][RK0][main]: create embedding for inference success
[HCTR][15:05:03][INFO][RK0][main]: Inference stage skip MultiCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][15:05:03][INFO][RK0][main]: Global seed is 355752151
[HCTR][15:05:03][INFO][RK0][main]: Device to NUMA mapping:
  GPU 3 ->  node 0
[HCTR][15:05:03][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][15:05:03][INFO][RK0][main]: Start all2all warmup
[HCTR][15:05:03][INFO][RK0][main]: End all2all warmup
[HCTR][15:05:03][INFO][RK0][main]: Create inference session on device: 3
[HCTR][15:05:03][INFO][RK0][main]: Model name: multi_hot
[HCTR][15:05:03][INFO][RK0][main]: Use mixed precision: False
[HCTR][15:05:03][INFO][RK0][main]: Use cuda graph: True
[HCTR][15:05:03][INFO][RK0][main]: Max batchsize: 256
[HCTR][15:05:03][INFO][RK0][main]: Use I64 input key: True
[HCTR][15:05:03][INFO][RK0][main]: start create embedding for inference
[HCTR][15:05:03][INFO][RK0][main]: sparse_input name data1
[HCTR][15:05:03][INFO][RK0][main]: sparse_input name data2
[HCTR][15:05:03][INFO][RK0][main]: create embedding for inference success
[HCTR][15:05:03][INFO][RK0][main]: Inference stage skip MultiCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][15:05:03][INFO][RK0][main]: Global seed is 3474526165
[HCTR][15:05:03][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][15:05:03][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][15:05:03][INFO][RK0][main]: Start all2all warmup
[HCTR][15:05:03][INFO][RK0][main]: End all2all warmup
[HCTR][15:05:03][INFO][RK0][main]: Vocabulary size: 30000

pred:  [[0.6733939  0.43605337]
 [0.5189075  0.4978796 ]
 [0.39680484 0.16554658]
 ...
 [0.3779142  0.669542  ]
 [0.46529922 0.44098482]
 [0.58435297 0.45384815]]
grount_truth:  [0.673394 0.436053 0.518908 ... 0.440985 0.584353 0.453848]
mse:  0.0012302037921078574
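
The comparison relies on pred.flatten() unrolling the (N, label_dim) array row by row (NumPy's default order), so that it lines up with the one-dimensional array loaded from the prediction file. The plain-Python sketch below illustrates that alignment using only the first and last rows printed above. It is an illustration of the arithmetic, not a replacement for the NumPy code, and it does not reproduce the full-dataset MSE.

```python
# First and last prediction rows, copied from the printed output above.
pred_rows = [
    [0.6733939, 0.43605337],
    [0.58435297, 0.45384815],
]
# The matching entries of the flat array loaded from the prediction file.
truth_flat = [0.673394, 0.436053, 0.584353, 0.453848]

# Row-major flattening: all labels of sample 0, then all labels of sample 1.
pred_flat = [value for row in pred_rows for value in row]

# Mean squared error between the two flat sequences.
squared_errors = [(p - t) ** 2 for p, t in zip(pred_flat, truth_flat)]
mse_subset = sum(squared_errors) / len(squared_errors)

print("flattened predictions:", pred_flat)
print("mse on these rows:", mse_subset)
```

On these rows the values differ only by the rounding applied when the file was written.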