# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
Multi-GPU Offline Inference
Deprecation Warning: This notebook is based on the offline inference API InferenceModel, which will be deprecated in a future release. Please check out the Hierarchical Parameter Server for alternatives based on TensorFlow and TensorRT.
Overview
Starting with HugeCTR version 3.4.1, we provide Python APIs to perform multi-GPU offline inference.
This work leverages the HugeCTR Hierarchical Parameter Server and enables concurrent execution on multiple devices.
The Norm and Parquet dataset formats are currently supported by multi-GPU offline inference.
This notebook explains how to perform multi-GPU offline inference with the HugeCTR Python APIs. For more details about the API, see the HugeCTR Python Interface documentation.
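The idea behind concurrent execution on multiple devices can be illustrated with a toy, CPU-only sketch: the global batch is split into equal shards, each shard is handled by a per-device worker, and the outputs are concatenated in device order. The names `toy_session` and `multi_device_predict` are hypothetical, and threads merely stand in for GPUs; this is not HugeCTR's implementation, which (as the logs later in this notebook show) creates one inference session per deployed device.

```python
from concurrent.futures import ThreadPoolExecutor


def toy_session(device_id, shard):
    """Hypothetical stand-in for one per-device inference session.

    A real session would run the model on GPU `device_id`; here each input
    value is simply mapped to two fake "label" scores.
    """
    return [[x * 0.5, x * 0.25] for x in shard]


def multi_device_predict(global_batch, devices):
    """Split a global batch evenly across devices, run the shards concurrently,
    and concatenate the results in device order."""
    if len(global_batch) % len(devices) != 0:
        raise ValueError("global batch size must be divisible by the number of devices")
    per_device = len(global_batch) // len(devices)
    shards = [global_batch[i * per_device:(i + 1) * per_device]
              for i in range(len(devices))]
    with ThreadPoolExecutor(max_workers=len(devices)) as pool:
        # pool.map preserves input order, so output rows line up with input rows
        results = list(pool.map(toy_session, devices, shards))
    return [row for part in results for row in part]


preds = multi_device_predict(list(range(16)), devices=[0, 1, 2, 3])
print(len(preds), preds[0], preds[-1])  # 16 [0.0, 0.0] [7.5, 3.75]
```

Keeping the shards in device order is what lets the concatenated output be compared row by row against a single-device reference.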
Setup
To set up the environment, refer to HugeCTR Example Notebooks and follow the instructions there before running the following.
Data Generation
HugeCTR provides a tool to generate synthetic datasets. The Data Generator class is capable of generating datasets in different formats and with different distributions. We will generate multi-hot Parquet datasets with a power-law distribution for this notebook:
import hugectr
from hugectr.tools import DataGeneratorParams, DataGenerator
data_generator_params = DataGeneratorParams(
    format = hugectr.DataReaderType_t.Parquet,
    label_dim = 2,
    dense_dim = 2,
    num_slot = 3,
    i64_input_key = True,
    nnz_array = [2, 1, 3],                    # nnz per slot (multi-hot)
    source = "./multi_hot_parquet/file_list.txt",
    eval_source = "./multi_hot_parquet/file_list_test.txt",
    slot_size_array = [10000, 10000, 10000],  # vocabulary size of each slot
    check_type = hugectr.Check_t.Non,
    dist_type = hugectr.Distribution_t.PowerLaw,
    power_law_type = hugectr.PowerLaw_t.Short,
    num_files = 32,                           # training files
    eval_num_files = 8)                       # evaluation files
data_generator = DataGenerator(data_generator_params)
data_generator.generate()
[HCTR][08:59:54.134][INFO][RK0][main]: Generate Parquet dataset
[HCTR][08:59:54.134][INFO][RK0][main]: train data folder: ./multi_hot_parquet, eval data folder: ./multi_hot_parquet, slot_size_array: 10000, 10000, 10000, nnz array: 2, 1, 3, #files for train: 32, #files for eval: 8, #samples per file: 40960, Use power law distribution: 1, alpha of power law: 1.3
[HCTR][08:59:54.136][INFO][RK0][main]: ./multi_hot_parquet exist
[HCTR][08:59:54.140][INFO][RK0][main]: ./multi_hot_parquet/train/gen_0.parquet
[HCTR][08:59:55.615][INFO][RK0][main]: ./multi_hot_parquet/train/gen_1.parquet
[HCTR][08:59:55.850][INFO][RK0][main]: ./multi_hot_parquet/train/gen_2.parquet
[HCTR][08:59:56.078][INFO][RK0][main]: ./multi_hot_parquet/train/gen_3.parquet
[HCTR][08:59:56.311][INFO][RK0][main]: ./multi_hot_parquet/train/gen_4.parquet
[HCTR][08:59:56.534][INFO][RK0][main]: ./multi_hot_parquet/train/gen_5.parquet
[HCTR][08:59:56.770][INFO][RK0][main]: ./multi_hot_parquet/train/gen_6.parquet
[HCTR][08:59:56.959][INFO][RK0][main]: ./multi_hot_parquet/train/gen_7.parquet
[HCTR][08:59:57.152][INFO][RK0][main]: ./multi_hot_parquet/train/gen_8.parquet
[HCTR][08:59:57.309][INFO][RK0][main]: ./multi_hot_parquet/train/gen_9.parquet
[HCTR][08:59:57.496][INFO][RK0][main]: ./multi_hot_parquet/train/gen_10.parquet
[HCTR][08:59:57.671][INFO][RK0][main]: ./multi_hot_parquet/train/gen_11.parquet
[HCTR][08:59:57.879][INFO][RK0][main]: ./multi_hot_parquet/train/gen_12.parquet
[HCTR][08:59:58.069][INFO][RK0][main]: ./multi_hot_parquet/train/gen_13.parquet
[HCTR][08:59:58.240][INFO][RK0][main]: ./multi_hot_parquet/train/gen_14.parquet
[HCTR][08:59:58.423][INFO][RK0][main]: ./multi_hot_parquet/train/gen_15.parquet
[HCTR][08:59:58.619][INFO][RK0][main]: ./multi_hot_parquet/train/gen_16.parquet
[HCTR][08:59:58.833][INFO][RK0][main]: ./multi_hot_parquet/train/gen_17.parquet
[HCTR][08:59:59.017][INFO][RK0][main]: ./multi_hot_parquet/train/gen_18.parquet
[HCTR][08:59:59.176][INFO][RK0][main]: ./multi_hot_parquet/train/gen_19.parquet
[HCTR][08:59:59.358][INFO][RK0][main]: ./multi_hot_parquet/train/gen_20.parquet
[HCTR][08:59:59.527][INFO][RK0][main]: ./multi_hot_parquet/train/gen_21.parquet
[HCTR][08:59:59.722][INFO][RK0][main]: ./multi_hot_parquet/train/gen_22.parquet
[HCTR][08:59:59.939][INFO][RK0][main]: ./multi_hot_parquet/train/gen_23.parquet
[HCTR][09:00:00.107][INFO][RK0][main]: ./multi_hot_parquet/train/gen_24.parquet
[HCTR][09:00:00.294][INFO][RK0][main]: ./multi_hot_parquet/train/gen_25.parquet
[HCTR][09:00:00.509][INFO][RK0][main]: ./multi_hot_parquet/train/gen_26.parquet
[HCTR][09:00:00.695][INFO][RK0][main]: ./multi_hot_parquet/train/gen_27.parquet
[HCTR][09:00:00.955][INFO][RK0][main]: ./multi_hot_parquet/train/gen_28.parquet
[HCTR][09:00:01.190][INFO][RK0][main]: ./multi_hot_parquet/train/gen_29.parquet
[HCTR][09:00:01.365][INFO][RK0][main]: ./multi_hot_parquet/train/gen_30.parquet
[HCTR][09:00:01.509][INFO][RK0][main]: ./multi_hot_parquet/train/gen_31.parquet
[HCTR][09:00:01.698][INFO][RK0][main]: ./multi_hot_parquet/file_list.txt done!
[HCTR][09:00:01.708][INFO][RK0][main]: ./multi_hot_parquet/val/gen_0.parquet
[HCTR][09:00:01.895][INFO][RK0][main]: ./multi_hot_parquet/val/gen_1.parquet
[HCTR][09:00:02.062][INFO][RK0][main]: ./multi_hot_parquet/val/gen_2.parquet
[HCTR][09:00:02.255][INFO][RK0][main]: ./multi_hot_parquet/val/gen_3.parquet
[HCTR][09:00:02.472][INFO][RK0][main]: ./multi_hot_parquet/val/gen_4.parquet
[HCTR][09:00:02.665][INFO][RK0][main]: ./multi_hot_parquet/val/gen_5.parquet
[HCTR][09:00:02.888][INFO][RK0][main]: ./multi_hot_parquet/val/gen_6.parquet
[HCTR][09:00:03.110][INFO][RK0][main]: ./multi_hot_parquet/val/gen_7.parquet
[HCTR][09:00:03.303][INFO][RK0][main]: ./multi_hot_parquet/file_list_test.txt done!
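To get a feel for what power-law-distributed keys look like, the following stdlib-only sketch draws keys for one slot with P(k) proportional to (k + 1)^(-alpha), using alpha = 1.3 as reported in the generator log above. This parameterization is an assumption for illustration; HugeCTR's DataGenerator may sample differently internally, and `power_law_keys` is a hypothetical helper.

```python
import random
from collections import Counter


def power_law_keys(slot_size, num_keys, alpha=1.3, seed=0):
    """Draw `num_keys` keys from [0, slot_size) with P(k) proportional to (k + 1) ** -alpha.

    Illustrative parameterization only, not necessarily the one used by
    HugeCTR's DataGenerator.
    """
    rng = random.Random(seed)
    weights = [(k + 1) ** -alpha for k in range(slot_size)]
    return rng.choices(range(slot_size), weights=weights, k=num_keys)


# One slot of one file: 10000 possible keys, 40960 samples (values taken from the log above)
keys = power_law_keys(slot_size=10000, num_keys=40960)
counts = Counter(keys)
top10_share = sum(c for _, c in counts.most_common(10)) / len(keys)
print(f"distinct keys: {len(counts)}, share of the 10 most frequent keys: {top10_share:.2f}")
```

With this skew, a small set of hot keys accounts for most lookups, which is the kind of access pattern an embedding cache benefits from.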
Train from Scratch
We can train from scratch by performing the following steps with Python APIs:
1. Create the solver, reader, and optimizer, then initialize the model.
2. Construct the model graph by adding the input, sparse embedding, and dense layers in order.
3. Compile the model and get an overview of the model graph.
4. Dump the model graph to a JSON file.
5. Fit the model and save the model weights and optimizer states implicitly.
6. Dump one batch of evaluation results to files.
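The tensor shapes of the model built in multi_hot_train.py can be checked by hand. The sketch below is plain arithmetic, not a HugeCTR API: it assumes that with combiner = "sum" each slot is pooled into a single embedding vector, so a sparse input with n slots yields an (n, embedding_vec_size) output, which Reshape then flattens. The resulting widths agree with the leading_dim values in the script and with the model summary printed by the training run.

```python
# Shape bookkeeping for the model in multi_hot_train.py (plain arithmetic, no HugeCTR needed)
embedding_vec_size = 16
dense_dim = 2
# sparse input name -> nnz per slot, mirroring the DataReaderSparseParam entries
sparse_inputs = {"data1": [2, 1], "data2": [3]}

# With combiner="sum", every slot is pooled into one vector regardless of its nnz,
# so each embedding output has shape (batch, num_slots, embedding_vec_size).
embedding_shapes = {name: (len(nnz), embedding_vec_size) for name, nnz in sparse_inputs.items()}

# Reshape flattens (num_slots, embedding_vec_size) into one row: this is leading_dim.
reshape_widths = {name: slots * vec for name, (slots, vec) in embedding_shapes.items()}

# Concat joins both reshaped embeddings with the dense features.
concat_width = sum(reshape_widths.values()) + dense_dim

print(embedding_shapes)  # {'data1': (2, 16), 'data2': (1, 16)}
print(reshape_widths)    # {'data1': 32, 'data2': 16}
print(concat_width)      # 50
```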
%%writefile multi_hot_train.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(model_name = "multi_hot",
                              max_eval_batches = 1,
                              batchsize_eval = 131072,
                              batchsize = 16384,
                              lr = 0.001,
                              vvgpu = [[0]],
                              i64_input_key = True,
                              repeat_dataset = True,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Parquet,
                                  source = ["./multi_hot_parquet/file_list.txt"],
                                  eval_source = "./multi_hot_parquet/file_list_test.txt",
                                  check_type = hugectr.Check_t.Non,
                                  slot_size_array = [10000, 10000, 10000])
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
# Input: 2 labels, 2 dense features and two sparse inputs
# (data1: 2 slots with nnz_per_slot [2, 1]; data2: 1 slot with nnz_per_slot 3)
model.add(hugectr.Input(label_dim = 2, label_name = "label",
                        dense_dim = 2, dense_name = "dense",
                        data_reader_sparse_param_array =
                        [hugectr.DataReaderSparseParam("data1", [2, 1], False, 2),
                         hugectr.DataReaderSparseParam("data2", 3, False, 1),]))
# One embedding table per sparse input
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
                                  workspace_size_per_gpu_in_mb = 4,
                                  embedding_vec_size = 16,
                                  combiner = "sum",
                                  sparse_embedding_name = "sparse_embedding1",
                                  bottom_name = "data1",
                                  optimizer = optimizer))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
                                  workspace_size_per_gpu_in_mb = 2,
                                  embedding_vec_size = 16,
                                  combiner = "sum",
                                  sparse_embedding_name = "sparse_embedding2",
                                  bottom_name = "data2",
                                  optimizer = optimizer))
# Dense part: flatten the embeddings, concatenate with the dense features, then an MLP
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                             bottom_names = ["sparse_embedding1"],
                             top_names = ["reshape1"],
                             leading_dim=32))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                             bottom_names = ["sparse_embedding2"],
                             top_names = ["reshape2"],
                             leading_dim=16))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                             bottom_names = ["reshape1", "reshape2", "dense"], top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                             bottom_names = ["concat1"],
                             top_names = ["fc1"],
                             num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                             bottom_names = ["fc1"],
                             top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                             bottom_names = ["relu1"],
                             top_names = ["fc2"],
                             num_output=2))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.MultiCrossEntropyLoss,
                             bottom_names = ["fc2", "label"],
                             top_names = ["loss"],
                             target_weight_vec = [0.5, 0.5]))
model.compile()
model.summary()
model.graph_to_json("multi_hot.json")
# Train for 1100 iterations; evaluate and snapshot at iteration 1000
model.fit(max_iter = 1100, display = 200, eval_interval = 1000, snapshot = 1000, snapshot_prefix = "multi_hot")
# Dump the predictions and labels of the evaluation batch for the inference check
model.export_predictions("multi_hot_pred_" + str(1000), "multi_hot_label_" + str(1000))
Overwriting multi_hot_train.py
!python3 multi_hot_train.py
HugeCTR Version: 3.7
====================================================Model Init=====================================================
[HCTR][09:00:10.032][INFO][RK0][main]: Initialize model: multi_hot
[HCTR][09:00:10.032][INFO][RK0][main]: Global seed is 69819197
[HCTR][09:00:10.135][INFO][RK0][main]: Device to NUMA mapping:
GPU 0 -> node 0
[HCTR][09:00:11.978][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][09:00:11.978][INFO][RK0][main]: Start all2all warmup
[HCTR][09:00:11.978][INFO][RK0][main]: End all2all warmup
[HCTR][09:00:11.979][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][09:00:11.980][INFO][RK0][main]: Device 0: Tesla V100-SXM2-32GB
[HCTR][09:00:11.985][INFO][RK0][main]: num of DataReader workers for train: 1
[HCTR][09:00:11.985][INFO][RK0][main]: num of DataReader workers for eval: 1
[HCTR][09:00:12.176][INFO][RK0][main]: Vocabulary size: 30000
[HCTR][09:00:12.177][INFO][RK0][main]: max_vocabulary_size_per_gpu_=21845
[HCTR][09:00:12.179][INFO][RK0][main]: max_vocabulary_size_per_gpu_=10922
[HCTR][09:00:12.181][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][09:00:43.965][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][09:00:43.965][INFO][RK0][main]: gpu0 init embedding done
[HCTR][09:00:43.965][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][09:00:43.965][INFO][RK0][main]: gpu0 init embedding done
[HCTR][09:00:43.969][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][09:00:43.972][INFO][RK0][main]: Warm-up done
===================================================Model Summary===================================================
[HCTR][09:00:43.972][INFO][RK0][main]: label Dense Sparse
label dense data1,data2
(None, 2) (None, 2)
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type Input Name Output Name Output Shape
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash data1 sparse_embedding1 (None, 2, 16)
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash data2 sparse_embedding2 (None, 1, 16)
------------------------------------------------------------------------------------------------------------------
Reshape sparse_embedding1 reshape1 (None, 32)
------------------------------------------------------------------------------------------------------------------
Reshape sparse_embedding2 reshape2 (None, 16)
------------------------------------------------------------------------------------------------------------------
Concat reshape1 concat1 (None, 50)
reshape2
dense
------------------------------------------------------------------------------------------------------------------
InnerProduct concat1 fc1 (None, 1024)
------------------------------------------------------------------------------------------------------------------
ReLU fc1 relu1 (None, 1024)
------------------------------------------------------------------------------------------------------------------
InnerProduct relu1 fc2 (None, 2)
------------------------------------------------------------------------------------------------------------------
MultiCrossEntropyLoss fc2 loss
label
------------------------------------------------------------------------------------------------------------------
[HCTR][09:00:43.977][INFO][RK0][main]: Save the model graph to multi_hot.json successfully
=====================================================Model Fit=====================================================
[HCTR][09:00:43.977][INFO][RK0][main]: Use non-epoch mode with number of iterations: 1100
[HCTR][09:00:43.977][INFO][RK0][main]: Training batchsize: 16384, evaluation batchsize: 131072
[HCTR][09:00:43.977][INFO][RK0][main]: Evaluation interval: 1000, snapshot interval: 1000
[HCTR][09:00:43.977][INFO][RK0][main]: Dense network trainable: True
[HCTR][09:00:43.977][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][09:00:43.977][INFO][RK0][main]: Sparse embedding sparse_embedding2 trainable: True
[HCTR][09:00:43.977][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][09:00:43.977][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][09:00:43.977][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][09:00:43.977][INFO][RK0][main]: Training source file: ./multi_hot_parquet/file_list.txt
[HCTR][09:00:43.977][INFO][RK0][main]: Evaluation source file: ./multi_hot_parquet/file_list_test.txt
[HCTR][09:00:46.346][INFO][RK0][main]: Iter: 200 Time(200 iters): 2.36888s Loss: 0.346413 lr:0.001
[HCTR][09:00:48.421][INFO][RK0][main]: Iter: 400 Time(200 iters): 2.07362s Loss: 0.345891 lr:0.001
[HCTR][09:00:50.519][INFO][RK0][main]: Iter: 600 Time(200 iters): 2.09809s Loss: 0.345239 lr:0.001
[HCTR][09:00:52.586][INFO][RK0][main]: Iter: 800 Time(200 iters): 2.06616s Loss: 0.344346 lr:0.001
[HCTR][09:00:54.656][INFO][RK0][main]: Iter: 1000 Time(200 iters): 2.0697s Loss: 0.343731 lr:0.001
[HCTR][09:00:54.686][INFO][RK0][main]: Evaluation, AUC: 0.499013
[HCTR][09:00:54.686][INFO][RK0][main]: Eval Time for 1 iters: 0.006811s
[HCTR][09:00:54.692][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][09:00:54.830][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][09:00:54.848][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][09:00:54.851][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][09:00:54.852][INFO][RK0][main]: Done
[HCTR][09:00:54.852][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][09:00:54.853][INFO][RK0][main]: Done
[HCTR][09:00:54.886][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][09:00:54.887][INFO][RK0][main]: Done
[HCTR][09:00:54.887][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][09:00:54.888][INFO][RK0][main]: Done
[HCTR][09:00:54.904][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HCTR][09:00:54.906][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][09:00:54.909][INFO][RK0][main]: Dumping dense optimizer states to file, successful
[HCTR][09:00:55.915][INFO][RK0][main]: Finish 1100 iterations with batchsize: 16384 in 11.94s.
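The training log reports max_vocabulary_size_per_gpu_=21845 and 10922 for the two embedding tables. One way to relate these numbers to workspace_size_per_gpu_in_mb is sketched below, under the assumption that each key stores its float32 embedding vector plus one equally sized vector per optimizer state (two for Adam). This formula is an inference that reproduces the logged values, not a quote of HugeCTR's source; `max_vocab_per_gpu` is a hypothetical helper.

```python
MIB = 1024 * 1024
BYTES_PER_FLOAT = 4


def max_vocab_per_gpu(workspace_mb, embedding_vec_size, optimizer_states=2):
    """Estimate how many keys fit in a per-GPU embedding workspace.

    Assumption: each key stores its float32 embedding vector plus one vector of
    the same size per optimizer state (Adam keeps two: m and v).
    """
    bytes_per_key = embedding_vec_size * BYTES_PER_FLOAT * (1 + optimizer_states)
    return workspace_mb * MIB // bytes_per_key


# (table, workspace_size_per_gpu_in_mb, keys it may need = slots * slot size)
for name, workspace_mb, keys_needed in [("sparse_embedding1", 4, 2 * 10000),
                                        ("sparse_embedding2", 2, 1 * 10000)]:
    capacity = max_vocab_per_gpu(workspace_mb, embedding_vec_size=16)
    print(f"{name}: capacity {capacity}, needed {keys_needed}, fits: {capacity >= keys_needed}")
```

Under this assumption both workspaces are just large enough: 21845 slots for the 20000 possible keys of data1 and 10922 for the 10000 possible keys of data2.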
Multi-GPU Offline Inference
We can demonstrate multi-GPU offline inference by performing the following steps with Python APIs:
1. Configure the inference hyperparameters.
2. Initialize the inference model. The model is a collection of inference sessions deployed on multiple devices.
3. Run inference on the evaluation dataset.
4. Check the correctness of the inference by comparing it with the dumped evaluation results.
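The dense and sparse model files passed to InferenceParams come from the snapshot taken during training. The sketch below rebuilds those names from the training settings (snapshot_prefix = "multi_hot", snapshot = 1000, max_iter = 1100, two SparseEmbedding layers). The naming pattern is read off the file names used in this notebook, and the assumption that a snapshot is written at every multiple of snapshot matches the single dump in the training log; both helpers are hypothetical.

```python
def snapshot_iterations(max_iter, snapshot):
    """Iterations at which a snapshot is written, assuming one every `snapshot` iterations."""
    return list(range(snapshot, max_iter + 1, snapshot))


def snapshot_files(prefix, iteration, num_sparse_tables):
    """Rebuild dense/sparse snapshot file names from the pattern used in this notebook."""
    dense = f"{prefix}_dense_{iteration}.model"
    sparse = [f"{prefix}{i}_sparse_{iteration}.model" for i in range(num_sparse_tables)]
    return dense, sparse


# Values from multi_hot_train.py: max_iter=1100, snapshot=1000, two SparseEmbedding layers
for it in snapshot_iterations(max_iter=1100, snapshot=1000):
    dense_file, sparse_files = snapshot_files("multi_hot", it, num_sparse_tables=2)
    print(it, dense_file, sparse_files)
```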
Note: The max_batchsize configured within InferenceParams is the global batch size, and its value should be divisible by the number of deployed devices. The NumPy array returned by InferenceModel.predict has the shape (max_batchsize * num_batches, label_dim).
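The arithmetic in the note can be checked for the values used below: a global max_batchsize of 16384 over 8 devices gives 2048 samples per device (the "Max batchsize" each inference session reports in the log), and 8 batches with label_dim = 2 give a (131072, 2) prediction array. The helper names are hypothetical; this is a plain-Python sketch, not part of the HugeCTR API.

```python
def per_device_batch(max_batchsize, num_devices):
    """Per-GPU batch size when the global max_batchsize is split evenly across devices."""
    if max_batchsize % num_devices != 0:
        raise ValueError(f"max_batchsize={max_batchsize} is not divisible by {num_devices} devices")
    return max_batchsize // num_devices


def prediction_shape(max_batchsize, num_batches, label_dim):
    """Shape of the array returned by predict, per the note above."""
    return (max_batchsize * num_batches, label_dim)


# Values used in this notebook
rows, cols = prediction_shape(max_batchsize=16384, num_batches=8, label_dim=2)
print(per_device_batch(16384, 8))  # 2048
print((rows, cols))                # (131072, 2)

# The ground truth was dumped from max_eval_batches=1 batch of batchsize_eval=131072,
# so both sides hold the same number of values once pred is flattened.
print(rows * cols == 131072 * 1 * 2)  # True
```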
import hugectr
from hugectr.inference import InferenceModel, InferenceParams
import numpy as np
from mpi4py import MPI
model_config = "multi_hot.json"
inference_params = InferenceParams(
    model_name = "multi_hot",
    max_batchsize = 16384,  # global batch size, split evenly across deployed_devices
    hit_rate_threshold = 1.0,
    dense_model_file = "multi_hot_dense_1000.model",
    sparse_model_files = ["multi_hot0_sparse_1000.model", "multi_hot1_sparse_1000.model"],
    deployed_devices = [0, 1, 2, 3, 4, 5, 6, 7],
    use_gpu_embedding_cache = True,
    cache_size_percentage = 0.5,
    i64_input_key = True
)
inference_model = InferenceModel(model_config, inference_params)
# Run 8 batches of max_batchsize samples from the evaluation file list
pred = inference_model.predict(
    8,                                         # number of batches
    "./multi_hot_parquet/file_list_test.txt",  # evaluation file list
    hugectr.DataReaderType_t.Parquet,
    hugectr.Check_t.Non,
    [10000, 10000, 10000]                      # slot_size_array
)
# Compare with the predictions dumped by model.export_predictions during training
ground_truth = np.loadtxt("multi_hot_pred_1000")
print("pred: ", pred)
print("ground_truth: ", ground_truth)
diff = pred.flatten() - ground_truth
mse = np.mean(diff * diff)
print("mse: ", mse)
[HCTR][09:01:06.069][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
[HCTR][09:01:06.072][INFO][RK0][main]: Global seed is 3072588155
[HCTR][09:01:06.222][INFO][RK0][main]: Device to NUMA mapping:
GPU 0 -> node 0
GPU 1 -> node 0
GPU 2 -> node 0
GPU 3 -> node 0
GPU 4 -> node 1
GPU 5 -> node 1
GPU 6 -> node 1
GPU 7 -> node 1
[HCTR][09:01:23.761][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][09:01:23.763][INFO][RK0][main]: Start all2all warmup
[HCTR][09:01:23.996][INFO][RK0][main]: End all2all warmup
[HCTR][09:01:24.013][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
[HCTR][09:01:24.013][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
[HCTR][09:01:24.013][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][09:01:24.013][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][09:01:24.013][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][09:01:24.347][INFO][RK0][main]: Table: hps_et.multi_hot.sparse_embedding1; cached 19849 / 19849 embeddings in volatile database (PreallocatedHashMapBackend); load: 19849 / 18446744073709551615 (0.00%).
[HCTR][09:01:24.622][INFO][RK0][main]: Table: hps_et.multi_hot.sparse_embedding2; cached 9996 / 9996 embeddings in volatile database (PreallocatedHashMapBackend); load: 9996 / 18446744073709551615 (0.00%).
[HCTR][09:01:24.622][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][09:01:24.622][INFO][RK0][main]: Create embedding cache in device 0.
[HCTR][09:01:24.628][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][09:01:24.628][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][09:01:24.628][INFO][RK0][main]: The size of thread pool: 80
[HCTR][09:01:24.628][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][09:01:24.628][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][09:01:24.641][INFO][RK0][main]: Create embedding cache in device 1.
[HCTR][09:01:24.646][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][09:01:24.646][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][09:01:24.646][INFO][RK0][main]: The size of thread pool: 80
[HCTR][09:01:24.646][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][09:01:24.646][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][09:01:24.647][INFO][RK0][main]: Create embedding cache in device 2.
[HCTR][09:01:24.652][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][09:01:24.652][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][09:01:24.652][INFO][RK0][main]: The size of thread pool: 80
[HCTR][09:01:24.652][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][09:01:24.652][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][09:01:24.654][INFO][RK0][main]: Create embedding cache in device 3.
[HCTR][09:01:24.659][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][09:01:24.659][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][09:01:24.659][INFO][RK0][main]: The size of thread pool: 80
[HCTR][09:01:24.659][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][09:01:24.659][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][09:01:24.662][INFO][RK0][main]: Create embedding cache in device 4.
[HCTR][09:01:24.667][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][09:01:24.667][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][09:01:24.667][INFO][RK0][main]: The size of thread pool: 80
[HCTR][09:01:24.667][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][09:01:24.667][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][09:01:24.669][INFO][RK0][main]: Create embedding cache in device 5.
[HCTR][09:01:24.675][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][09:01:24.675][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][09:01:24.675][INFO][RK0][main]: The size of thread pool: 80
[HCTR][09:01:24.675][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][09:01:24.675][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][09:01:24.679][INFO][RK0][main]: Create embedding cache in device 6.
[HCTR][09:01:24.683][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][09:01:24.683][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][09:01:24.683][INFO][RK0][main]: The size of thread pool: 80
[HCTR][09:01:24.683][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][09:01:24.683][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][09:01:24.685][INFO][RK0][main]: Create embedding cache in device 7.
[HCTR][09:01:24.688][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.500000
[HCTR][09:01:24.688][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][09:01:24.688][INFO][RK0][main]: The size of thread pool: 80
[HCTR][09:01:24.688][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][09:01:24.688][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][09:01:24.768][INFO][RK0][main]: Create inference session on device: 0
[HCTR][09:01:24.768][INFO][RK0][main]: Model name: multi_hot
[HCTR][09:01:24.768][INFO][RK0][main]: Use mixed precision: False
[HCTR][09:01:24.768][INFO][RK0][main]: Use cuda graph: True
[HCTR][09:01:24.768][INFO][RK0][main]: Max batchsize: 2048
[HCTR][09:01:24.768][INFO][RK0][main]: Use I64 input key: True
[HCTR][09:01:24.768][INFO][RK0][main]: start create embedding for inference
[HCTR][09:01:24.768][INFO][RK0][main]: sparse_input name data1
[HCTR][09:01:24.768][INFO][RK0][main]: sparse_input name data2
[HCTR][09:01:24.768][INFO][RK0][main]: create embedding for inference success
[HCTR][09:01:24.768][INFO][RK0][main]: Inference stage skip MultiCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][09:01:25.520][INFO][RK0][main]: Create inference session on device: 1
[HCTR][09:01:25.520][INFO][RK0][main]: Model name: multi_hot
[HCTR][09:01:25.520][INFO][RK0][main]: Use mixed precision: False
[HCTR][09:01:25.520][INFO][RK0][main]: Use cuda graph: True
[HCTR][09:01:25.520][INFO][RK0][main]: Max batchsize: 2048
[HCTR][09:01:25.520][INFO][RK0][main]: Use I64 input key: True
[HCTR][09:01:25.520][INFO][RK0][main]: start create embedding for inference
[HCTR][09:01:25.520][INFO][RK0][main]: sparse_input name data1
[HCTR][09:01:25.520][INFO][RK0][main]: sparse_input name data2
[HCTR][09:01:25.520][INFO][RK0][main]: create embedding for inference success
[HCTR][09:01:25.520][INFO][RK0][main]: Inference stage skip MultiCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][09:01:26.275][INFO][RK0][main]: Create inference session on device: 2
[HCTR][09:01:26.275][INFO][RK0][main]: Model name: multi_hot
[HCTR][09:01:26.275][INFO][RK0][main]: Use mixed precision: False
[HCTR][09:01:26.275][INFO][RK0][main]: Use cuda graph: True
[HCTR][09:01:26.275][INFO][RK0][main]: Max batchsize: 2048
[HCTR][09:01:26.275][INFO][RK0][main]: Use I64 input key: True
[HCTR][09:01:26.275][INFO][RK0][main]: start create embedding for inference
[HCTR][09:01:26.275][INFO][RK0][main]: sparse_input name data1
[HCTR][09:01:26.275][INFO][RK0][main]: sparse_input name data2
[HCTR][09:01:26.275][INFO][RK0][main]: create embedding for inference success
[HCTR][09:01:26.275][INFO][RK0][main]: Inference stage skip MultiCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][09:01:27.035][INFO][RK0][main]: Create inference session on device: 3
[HCTR][09:01:27.035][INFO][RK0][main]: Model name: multi_hot
[HCTR][09:01:27.035][INFO][RK0][main]: Use mixed precision: False
[HCTR][09:01:27.035][INFO][RK0][main]: Use cuda graph: True
[HCTR][09:01:27.035][INFO][RK0][main]: Max batchsize: 2048
[HCTR][09:01:27.035][INFO][RK0][main]: Use I64 input key: True
[HCTR][09:01:27.035][INFO][RK0][main]: start create embedding for inference
[HCTR][09:01:27.035][INFO][RK0][main]: sparse_input name data1
[HCTR][09:01:27.035][INFO][RK0][main]: sparse_input name data2
[HCTR][09:01:27.035][INFO][RK0][main]: create embedding for inference success
[HCTR][09:01:27.035][INFO][RK0][main]: Inference stage skip MultiCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][09:01:27.781][INFO][RK0][main]: Create inference session on device: 4
[HCTR][09:01:27.781][INFO][RK0][main]: Model name: multi_hot
[HCTR][09:01:27.781][INFO][RK0][main]: Use mixed precision: False
[HCTR][09:01:27.781][INFO][RK0][main]: Use cuda graph: True
[HCTR][09:01:27.781][INFO][RK0][main]: Max batchsize: 2048
[HCTR][09:01:27.781][INFO][RK0][main]: Use I64 input key: True
[HCTR][09:01:27.781][INFO][RK0][main]: start create embedding for inference
[HCTR][09:01:27.781][INFO][RK0][main]: sparse_input name data1
[HCTR][09:01:27.781][INFO][RK0][main]: sparse_input name data2
[HCTR][09:01:27.781][INFO][RK0][main]: create embedding for inference success
[HCTR][09:01:27.781][INFO][RK0][main]: Inference stage skip MultiCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][09:01:28.534][INFO][RK0][main]: Create inference session on device: 5
[HCTR][09:01:28.534][INFO][RK0][main]: Model name: multi_hot
[HCTR][09:01:28.534][INFO][RK0][main]: Use mixed precision: False
[HCTR][09:01:28.534][INFO][RK0][main]: Use cuda graph: True
[HCTR][09:01:28.534][INFO][RK0][main]: Max batchsize: 2048
[HCTR][09:01:28.534][INFO][RK0][main]: Use I64 input key: True
[HCTR][09:01:28.534][INFO][RK0][main]: start create embedding for inference
[HCTR][09:01:28.534][INFO][RK0][main]: sparse_input name data1
[HCTR][09:01:28.534][INFO][RK0][main]: sparse_input name data2
[HCTR][09:01:28.534][INFO][RK0][main]: create embedding for inference success
[HCTR][09:01:28.534][INFO][RK0][main]: Inference stage skip MultiCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][09:01:29.291][INFO][RK0][main]: Create inference session on device: 6
[HCTR][09:01:29.291][INFO][RK0][main]: Model name: multi_hot
[HCTR][09:01:29.291][INFO][RK0][main]: Use mixed precision: False
[HCTR][09:01:29.291][INFO][RK0][main]: Use cuda graph: True
[HCTR][09:01:29.291][INFO][RK0][main]: Max batchsize: 2048
[HCTR][09:01:29.291][INFO][RK0][main]: Use I64 input key: True
[HCTR][09:01:29.291][INFO][RK0][main]: start create embedding for inference
[HCTR][09:01:29.291][INFO][RK0][main]: sparse_input name data1
[HCTR][09:01:29.291][INFO][RK0][main]: sparse_input name data2
[HCTR][09:01:29.291][INFO][RK0][main]: create embedding for inference success
[HCTR][09:01:29.291][INFO][RK0][main]: Inference stage skip MultiCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][09:01:30.037][INFO][RK0][main]: Create inference session on device: 7
[HCTR][09:01:30.037][INFO][RK0][main]: Model name: multi_hot
[HCTR][09:01:30.037][INFO][RK0][main]: Use mixed precision: False
[HCTR][09:01:30.037][INFO][RK0][main]: Use cuda graph: True
[HCTR][09:01:30.037][INFO][RK0][main]: Max batchsize: 2048
[HCTR][09:01:30.037][INFO][RK0][main]: Use I64 input key: True
[HCTR][09:01:30.038][INFO][RK0][main]: start create embedding for inference
[HCTR][09:01:30.038][INFO][RK0][main]: sparse_input name data1
[HCTR][09:01:30.038][INFO][RK0][main]: sparse_input name data2
[HCTR][09:01:30.038][INFO][RK0][main]: create embedding for inference success
[HCTR][09:01:30.038][INFO][RK0][main]: Inference stage skip MultiCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][09:01:30.807][INFO][RK0][main]: Create inference data reader on 8 GPU(s)
[HCTR][09:01:30.807][INFO][RK0][main]: num of DataReader workers: 8
[HCTR][09:01:30.915][INFO][RK0][main]: Vocabulary size: 30000
[INFO] Inference time for 8 batches: 0.182527
pred: [[0.51329887 0.4888402 ]
[0.55268604 0.62567735]
[0.48302165 0.5015869 ]
...
[0.52275413 0.46319592]
[0.46984023 0.5436093 ]
[0.48216432 0.48920953]]
ground_truth: [0.513299 0.48884 0.552686 ... 0.543609 0.482164 0.48921 ]
mse: 8.482603947165404e-14
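Printing the MSE still leaves the pass/fail decision to the reader. A stricter check, sketched below in plain Python, flattens the predictions row by row (the same order as pred.flatten()), verifies that the lengths match, and requires every element to agree within an absolute tolerance. The 1e-5 tolerance is an assumption, chosen because the dumped ground truth appears to be stored with about six significant digits; the sample values are the first and last rows of pred and the matching ground-truth values printed above. `compare_predictions` is a hypothetical helper.

```python
import math


def compare_predictions(pred_rows, ground_truth, abs_tol=1e-5):
    """Flatten row-major predictions and compare them element-wise with a flat ground-truth list.

    Returns (all_close, mse).
    """
    flat = [v for row in pred_rows for v in row]
    if len(flat) != len(ground_truth):
        raise ValueError(f"length mismatch: {len(flat)} vs {len(ground_truth)}")
    mse = sum((p - g) ** 2 for p, g in zip(flat, ground_truth)) / len(flat)
    all_close = all(math.isclose(p, g, rel_tol=0.0, abs_tol=abs_tol)
                    for p, g in zip(flat, ground_truth))
    return all_close, mse


# First and last rows of `pred`, and the corresponding ground-truth values, as printed above
pred_rows = [[0.51329887, 0.4888402], [0.48216432, 0.48920953]]
ground_truth = [0.513299, 0.48884, 0.482164, 0.48921]
ok, mse = compare_predictions(pred_rows, ground_truth)
print(ok, mse)
```

With NumPy arrays, the same check could be written as np.allclose(pred.flatten(), ground_truth, rtol=0, atol=1e-5).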