# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
HugeCTR Continuous Training
Overview
This notebook shows how to use the Embedding Training Cache (ETC) feature in HugeCTR for continuous training. ETC is designed to handle recommendation models with huge embedding tables through incremental training, which lets you train a model whose size is much larger than the available GPU memory.
To learn more about the ETC, see the Embedding Training Cache documentation.
To learn how to use the APIs of ETC, see the HugeCTR Python Interface documentation.
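The incremental-training idea above can be pictured with a toy, pure-Python model of what happens around each training pass. This is a sketch under stated assumptions, not HugeCTR's implementation. `ToyTrainingCache` and its methods are hypothetical stand-ins named after calls used later in this notebook. The host-side table is a plain dict, the "GPU" table is a dict with a made-up row capacity, and `train_step` applies a fake update. The point it illustrates: only rows whose keys appear in the current keyset are resident during a pass, and the set of touched keys is what an incremental model would export.

```python
from typing import Dict, Iterable, List, Set

# Hypothetical toy model of an embedding training cache; not HugeCTR code.
Table = Dict[int, List[float]]


class ToyTrainingCache:
    def __init__(self, host_table: Table, device_capacity: int, dim: int):
        self.host = host_table            # stands in for the large host/SSD-resident table
        self.device: Table = {}           # stands in for the GPU-resident working set
        self.capacity = device_capacity   # made-up limit on rows that fit on the device
        self.dim = dim
        self.updated: Set[int] = set()    # keys modified since the last incremental export

    def update(self, keyset: Iterable[int]) -> None:
        """Write the current working set back to the host, then load rows for the new keyset."""
        keys = set(keyset)
        if len(keys) > self.capacity:
            raise ValueError("keyset does not fit in the device table")
        self.host.update(self.device)
        # Unseen keys start as zero vectors; this is a simplification for the sketch.
        self.device = {k: list(self.host.get(k, [0.0] * self.dim)) for k in keys}

    def train_step(self, batch_keys: Iterable[int], grad: float, lr: float) -> None:
        """Fake SGD-style update; only rows resident on the device can be touched."""
        for k in batch_keys:
            row = self.device[k]  # a KeyError here means the keyset missed this key
            self.device[k] = [w - lr * grad for w in row]
            self.updated.add(k)

    def get_incremental_model(self) -> Table:
        """Return copies of only the rows touched since the previous call."""
        self.host.update(self.device)
        delta = {k: list(self.host[k]) for k in sorted(self.updated)}
        self.updated.clear()
        return delta


if __name__ == "__main__":
    cache = ToyTrainingCache({}, device_capacity=3, dim=2)
    cache.update([1, 2])                     # pass 1: keys 1 and 2 become resident
    cache.train_step([1], grad=1.0, lr=0.5)
    print(cache.get_incremental_model())     # {1: [-0.5, -0.5]}
    cache.update([2, 3])                     # pass 2: key 1 is written back, key 3 is loaded
    print(sorted(cache.device))              # [2, 3]
```

The capacity check is what makes the keyset matter in this toy: a pass can only proceed if its keys fit in the device table, while the full table lives on the host.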
Setup
To set up the environment, refer to HugeCTR Example Notebooks and follow the instructions there before running the following.
Continuous Training
To download and prepare the dataset, we perform the following steps. At the end of this cell, we provide the shell commands you can run in a terminal to get the data ready for this notebook.
Note: If you have already downloaded the data, skip to the preprocessing step (2). If preprocessing is also done, skip to creating the soft link from the processed data to the notebooks/ directory (3).
Data Preparation
Download the Criteo dataset
To preprocess the downloaded Criteo dataset, we perform the following operations:
Reduce the amount of data to speed up preprocessing
Fill missing values
Remove feature values whose occurrences are very rare, etc.
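The last two operations can be illustrated on single columns in plain Python. This is a toy sketch, not the logic of the repository's preprocessing script: the `min_count` threshold, the `"<missing>"` fill token, the shared `"<rare>"` bucket, and the constant used for missing dense values are all assumptions chosen for illustration.

```python
from collections import Counter
from typing import List, Optional

MISSING = "<missing>"  # hypothetical fill token for absent categorical values
RARE = "<rare>"        # hypothetical shared bucket for infrequent categorical values


def clean_categorical(column: List[Optional[str]], min_count: int = 2) -> List[str]:
    """Fill missing values, then map values seen fewer than min_count times to RARE."""
    filled = [MISSING if v is None or v == "" else v for v in column]
    counts = Counter(filled)
    return [v if counts[v] >= min_count else RARE for v in filled]


def fill_dense(column: List[Optional[float]], fill_value: float = 0.0) -> List[float]:
    """Replace missing dense values with a constant (an assumed strategy)."""
    return [fill_value if v is None else v for v in column]


if __name__ == "__main__":
    print(clean_categorical(["a", "a", None, "b", "", "c"]))
    # ['a', 'a', '<missing>', '<rare>', '<missing>', '<rare>']
    print(fill_dense([1.0, None, 3.0]))
    # [1.0, 0.0, 3.0]
```

Collapsing rare values into one bucket shrinks the number of distinct keys, and therefore the embedding table, at the cost of no longer distinguishing between those rare values.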
Preprocessing by Pandas:
Meanings of the command line arguments:
The 1st argument represents the dataset postfix. It is 0 here since day_0 is used.
The 2nd argument, wdl_data, is where the preprocessed data is stored.
The 3rd argument, pandas, is the processing script to use; here we choose pandas.
The 4th argument, 1, means that normalization is applied to dense features.
The 5th argument, 1, means that feature crossing is applied.
The 6th argument, 10, is the number of data files in each file list.
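To keep the positional arguments straight, the command line can be assembled from named values. The helper `build_preprocess_cmd` is hypothetical: it only encodes the argument order described above and does not run anything. The defaults are chosen to reproduce the command used in this notebook.

```python
import shlex
from typing import List


def build_preprocess_cmd(day: int, out_dir: str, backend: str = "pandas",
                         normalize_dense: bool = True, feature_cross: bool = True,
                         files_per_list: int = 10) -> List[str]:
    """Return the argv for preprocess.sh in its positional argument order (hypothetical helper)."""
    if files_per_list <= 0:
        raise ValueError("files_per_list must be positive")
    return ["bash", "preprocess.sh", str(day), out_dir, backend,
            str(int(normalize_dense)), str(int(feature_cross)), str(files_per_list)]


if __name__ == "__main__":
    print(shlex.join(build_preprocess_cmd(0, "wdl_data")))
    # bash preprocess.sh 0 wdl_data pandas 1 1 10
```

The returned list could be passed to `subprocess.run(..., cwd=...)` with the tools directory as the working directory, since that is where the commands below run `preprocess.sh`.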
For more details about the data preprocessing, please refer to the “Preprocess the Criteo Dataset” section of the README in the samples/criteo directory of the repository on GitHub.
Create a soft link of the dataset folder to the path of this notebook
Run the following commands in a terminal to prepare the data for this notebook:
export project_root=/home/hugectr # set this to the directory where hugectr is downloaded
cd ${project_root}/tools
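# Step 1: download day_0 of the Criteo dataset
wget https://storage.googleapis.com/criteo-cail-datasets/day_0.gz
# Step 2: preprocess day_0 with pandas into wdl_data
bash preprocess.sh 0 wdl_data pandas 1 1 10
# Step 3: link the preprocessed data into the notebooks/ directory
ln -s ${project_root}/tools/wdl_data ${project_root}/notebooks/wdl_data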
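Before running the training scripts, it can help to confirm that the link resolves and that the files the training code refers to are present. The sketch assumes the layout implied by the code in this notebook: `file_list.N.txt` for N = 0 to 5, plus `file_list.N.keyset` for the training lists 0, 1, 3, and 4. The helper name `missing_inputs` is hypothetical.

```python
from pathlib import Path
from typing import List

TRAIN_IDS = (0, 1, 3, 4)  # file lists used for training in this notebook
EVAL_IDS = (2, 5)         # file lists used for evaluation in this notebook


def missing_inputs(data_dir: str = "wdl_data") -> List[str]:
    """Return the expected file-list and keyset paths that do not exist (or are broken links)."""
    root = Path(data_dir)
    expected = [root / f"file_list.{i}.txt" for i in sorted(TRAIN_IDS + EVAL_IDS)]
    expected += [root / f"file_list.{i}.keyset" for i in TRAIN_IDS]
    return [str(p) for p in expected if not p.exists()]


if __name__ == "__main__":
    missing = missing_inputs()
    print("all inputs present" if not missing else "missing: " + ", ".join(missing))
```

`Path.exists()` follows symbolic links, so a dangling `wdl_data` link shows up as every file missing.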
Continuous Training with High-level API
This section gives a code sample of continuous training using a Keras-like high-level API. The high-level API hides much of the complexity from users, making it easy to use while still handling many production scenarios.
In addition to the high-level API, HugeCTR also provides low-level APIs that enable you to customize the training logic. A code sample using the low-level APIs is provided in the next section.
The code sample in this section trains a model from scratch using the embedding training cache, gets the incremental model, and saves the trained dense weights and sparse embedding weights. It performs the following steps:
Create the solver, reader, optimizer, and etc, then initialize the model.
Construct the model graph by adding input, sparse embedding, and dense layers in order.
Compile the model and print a summary of the model graph.
Dump the model graph to a JSON file.
Train the sparse and dense model.
Set the new training datasets and their corresponding keysets.
Train the sparse and dense model incrementally.
Get the incrementally trained embedding table.
Save the model weights and optimizer states explicitly.
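In this notebook, each training file list is paired with a keyset file. As we understand the ETC design, a keyset lists the embedding keys that occur in its file list, so that the corresponding rows can be prepared before that pass. The toy sketch below derives such a set from in-memory samples. The sample layout (a list of slots, each a list of integer keys) and the name `build_keyset` are assumptions, and the sketch does not produce HugeCTR's on-disk keyset format.

```python
from typing import Iterable, List, Sequence

Sample = Sequence[Sequence[int]]  # assumed layout: one list of keys per slot


def build_keyset(samples: Iterable[Sample]) -> List[int]:
    """Return the sorted unique keys that appear in any slot of any sample."""
    keys = set()
    for sample in samples:
        for slot in sample:
            keys.update(slot)
    return sorted(keys)


if __name__ == "__main__":
    batch = [[[3, 7], [11]], [[7], [11, 42]], [[3], []]]
    print(build_keyset(batch))  # [3, 7, 11, 42]
```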
Note: repeat_dataset should be False when using the embedding training cache; in this mode, the num_epochs argument of Model::fit specifies the number of training epochs.
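The difference between the two reading modes can be pictured with a toy iteration scheme. This illustrates the semantics described in the note and is not HugeCTR's data reader. With `repeat_dataset=True`, the data is treated as an endless stream cut off by an iteration budget. With `repeat_dataset=False`, each source is read to the end once per epoch, so the epoch count determines how much data is consumed. The function and parameter names are hypothetical, and the nesting order of epochs and sources in `epoch_mode` is an assumption.

```python
from itertools import cycle, islice
from typing import Iterator, List, Sequence, Tuple


def epoch_mode(sources: Sequence[List[str]], num_epochs: int) -> Iterator[Tuple[int, str]]:
    """Like repeat_dataset=False: read every source to its end, once per epoch.

    The epoch-outer, source-inner nesting is an assumption of this sketch.
    """
    for epoch in range(num_epochs):
        for source in sources:
            for batch in source:
                yield epoch, batch


def repeat_mode(source: List[str], max_iter: int) -> Iterator[str]:
    """Like repeat_dataset=True: loop over the data forever and stop after max_iter batches."""
    return islice(cycle(source), max_iter)


if __name__ == "__main__":
    print(list(epoch_mode([["a0", "a1"], ["b0"]], num_epochs=2)))
    print(list(repeat_mode(["a0", "a1"], max_iter=5)))
```

Epoch mode has natural end points for each source. Those end points are where a keyset-driven cache can swap its working set, which is one reason the finite mode fits ETC.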
%%writefile wdl_train.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 5000,
batchsize_eval = 1024,
batchsize = 1024,
lr = 0.001,
vvgpu = [[0]],
i64_input_key = False,
use_mixed_precision = False,
repeat_dataset = False,
use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
source = ["wdl_data/file_list."+str(i)+".txt" for i in range(2)],
keyset = ["wdl_data/file_list."+str(i)+".keyset" for i in range(2)],
eval_source = "wdl_data/file_list.2.txt",
num_workers=8,
check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
hc_cnfg = hugectr.CreateHMemCache(num_blocks = 2, target_hit_rate = 0.5, max_num_evict = 0)
etc = hugectr.CreateETC(ps_types = [hugectr.TrainPSType_t.Staged, hugectr.TrainPSType_t.Cached],
sparse_models = ["./wdl_0_sparse_model", "./wdl_1_sparse_model"],
local_paths = ["./"], hmem_cache_configs = [hc_cnfg])
model = hugectr.Model(solver, reader, optimizer, etc)
model.add(hugectr.Input(label_dim = 1, label_name = "label",
dense_dim = 13, dense_name = "dense",
data_reader_sparse_param_array =
[hugectr.DataReaderSparseParam("wide_data", 30, True, 1),
hugectr.DataReaderSparseParam("deep_data", 2, False, 26)]))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
workspace_size_per_gpu_in_mb = 69,
embedding_vec_size = 1,
combiner = "sum",
sparse_embedding_name = "sparse_embedding2",
bottom_name = "wide_data",
optimizer = optimizer))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
workspace_size_per_gpu_in_mb = 1074,
embedding_vec_size = 16,
combiner = "sum",
sparse_embedding_name = "sparse_embedding1",
bottom_name = "deep_data",
optimizer = optimizer))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
bottom_names = ["sparse_embedding1"],
top_names = ["reshape1"],
leading_dim=416))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
bottom_names = ["sparse_embedding2"],
top_names = ["reshape2"],
leading_dim=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
bottom_names = ["reshape1", "dense"], top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["concat1"],
top_names = ["fc1"],
num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
bottom_names = ["fc1"],
top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
bottom_names = ["relu1"],
top_names = ["dropout1"],
dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["dropout1"],
top_names = ["fc2"],
num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
bottom_names = ["fc2"],
top_names = ["relu2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
bottom_names = ["relu2"],
top_names = ["dropout2"],
dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["dropout2"],
top_names = ["fc3"],
num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Add,
bottom_names = ["fc3", "reshape2"],
top_names = ["add1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
bottom_names = ["add1", "label"],
top_names = ["loss"]))
model.compile()
model.summary()
model.graph_to_json(graph_config_file = "wdl.json")
model.fit(num_epochs = 1, display = 500, eval_interval = 1000)
# The updated portion of the embedding table from the first model.fit() could be retrieved here:
# updated_model = model.get_incremental_model()
# Switch to new training sources, their keysets, and a new evaluation source
model.set_source(source = ["wdl_data/file_list.3.txt", "wdl_data/file_list.4.txt"],
                 keyset = ["wdl_data/file_list.3.keyset", "wdl_data/file_list.4.keyset"],
                 eval_source = "wdl_data/file_list.5.txt")
model.fit(num_epochs = 1, display = 500, eval_interval = 1000)
# Get the updated portion of the embedding table from the incremental model.fit()
updated_model = model.get_incremental_model()
model.save_params_to_files("wdl_etc")
Writing wdl_train.py
!python3 wdl_train.py
[HCTR][08:03:26.675][INFO][RK0][main]: Empty embedding, trained table will be stored in ./wdl_0_sparse_model
[HCTR][08:03:26.675][INFO][RK0][main]: Empty embedding, trained table will be stored in ./wdl_1_sparse_model
HugeCTR Version: 4.1
====================================================Model Init=====================================================
[HCTR][08:03:26.676][WARNING][RK0][main]: The model name is not specified when creating the solver.
[HCTR][08:03:26.676][INFO][RK0][main]: Global seed is 1441282772
[HCTR][08:03:26.678][INFO][RK0][main]: Device to NUMA mapping:
GPU 0 -> node 0
[HCTR][08:03:28.494][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][08:03:28.494][INFO][RK0][main]: Start all2all warmup
[HCTR][08:03:28.495][INFO][RK0][main]: End all2all warmup
[HCTR][08:03:28.495][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][08:03:28.496][INFO][RK0][main]: Device 0: Tesla V100-SXM2-32GB
[HCTR][08:03:28.496][INFO][RK0][main]: num of DataReader workers for train: 8
[HCTR][08:03:28.496][INFO][RK0][main]: num of DataReader workers for eval: 8
[HCTR][08:03:28.505][INFO][RK0][main]: max_vocabulary_size_per_gpu_=6029312
[HCTR][08:03:28.507][INFO][RK0][main]: max_vocabulary_size_per_gpu_=5865472
[HCTR][08:03:28.511][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][08:03:32.482][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][08:03:32.482][INFO][RK0][main]: gpu0 init embedding done
[HCTR][08:03:32.482][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][08:03:32.484][INFO][RK0][main]: gpu0 init embedding done
[HCTR][08:03:32.484][INFO][RK0][main]: Enable HMEM-Based Parameter Server
[HCTR][08:03:32.484][INFO][RK0][main]: ./wdl_0_sparse_model not exist, create and train from scratch
[HCTR][08:03:32.492][INFO][RK0][main]: Enable HMemCache-Based Parameter Server
[HCTR][08:03:32.492][INFO][RK0][main]: ./wdl_1_sparse_model/key doesn't exist, created
[HCTR][08:03:32.495][INFO][RK0][main]: ./wdl_1_sparse_model/emb_vector doesn't exist, created
[HCTR][08:03:32.498][INFO][RK0][main]: ./wdl_1_sparse_model/Adam.m doesn't exist, created
[HCTR][08:03:32.502][INFO][RK0][main]: ./wdl_1_sparse_model/Adam.v doesn't exist, created
[HCTR][08:03:33.843][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][08:03:33.847][INFO][RK0][main]: Warm-up done
===================================================Model Summary===================================================
[HCTR][08:03:33.847][INFO][RK0][main]: Model structure on each GPU
Label Dense Sparse
label dense wide_data,deep_data
(1024,1) (1024,13)
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type Input Name Output Name Output Shape
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash wide_data sparse_embedding2 (1024,1,1)
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash deep_data sparse_embedding1 (1024,26,16)
------------------------------------------------------------------------------------------------------------------
Reshape sparse_embedding1 reshape1 (1024,416)
------------------------------------------------------------------------------------------------------------------
Reshape sparse_embedding2 reshape2 (1024,1)
------------------------------------------------------------------------------------------------------------------
Concat reshape1 concat1 (1024,429)
dense
------------------------------------------------------------------------------------------------------------------
InnerProduct concat1 fc1 (1024,1024)
------------------------------------------------------------------------------------------------------------------
ReLU fc1 relu1 (1024,1024)
------------------------------------------------------------------------------------------------------------------
Dropout relu1 dropout1 (1024,1024)
------------------------------------------------------------------------------------------------------------------
InnerProduct dropout1 fc2 (1024,1024)
------------------------------------------------------------------------------------------------------------------
ReLU fc2 relu2 (1024,1024)
------------------------------------------------------------------------------------------------------------------
Dropout relu2 dropout2 (1024,1024)
------------------------------------------------------------------------------------------------------------------
InnerProduct dropout2 fc3 (1024,1)
------------------------------------------------------------------------------------------------------------------
Add fc3 add1 (1024,1)
reshape2
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss add1 loss
label
------------------------------------------------------------------------------------------------------------------
[HCTR][08:03:33.857][INFO][RK0][main]: Save the model graph to wdl.json successfully
=====================================================Model Fit=====================================================
[HCTR][08:03:33.857][INFO][RK0][main]: Use embedding training cache mode with number of training sources: 2, number of epochs: 1
[HCTR][08:03:33.857][INFO][RK0][main]: Training batchsize: 1024, evaluation batchsize: 1024
[HCTR][08:03:33.857][INFO][RK0][main]: Evaluation interval: 1000, snapshot interval: 10000
[HCTR][08:03:33.857][INFO][RK0][main]: Dense network trainable: True
[HCTR][08:03:33.857][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][08:03:33.857][INFO][RK0][main]: Sparse embedding sparse_embedding2 trainable: True
[HCTR][08:03:33.857][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][08:03:33.857][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][08:03:33.857][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][08:03:33.858][INFO][RK0][main]: Evaluation source file: wdl_data/file_list.2.txt
[HCTR][08:03:33.858][INFO][RK0][main]: --------------------Epoch 0, source file: wdl_data/file_list.0.txt--------------------
[HCTR][08:03:33.860][INFO][RK0][main]: Preparing embedding table for next pass
[HCTR][08:03:33.952][INFO][RK0][main]: HMEM-Cache PS: Hit rate [load]: 0 %
[HCTR][08:03:36.359][INFO][RK0][main]: --------------------Epoch 0, source file: wdl_data/file_list.1.txt--------------------
[HCTR][08:03:36.360][INFO][RK0][main]: Preparing embedding table for next pass
[HCTR][08:03:37.255][INFO][RK0][main]: HMEM-Cache PS: Hit rate [dump]: 0 %
[HCTR][08:03:37.355][INFO][RK0][main]: HMEM-Cache PS: Hit rate [load]: 0 %
[HCTR][08:03:37.964][INFO][RK0][main]: Iter: 500 Time(500 iters): 4.10376s Loss: 0.108516 lr:0.001
=====================================================Model Fit=====================================================
[HCTR][08:03:39.695][INFO][RK0][main]: Use embedding training cache mode with number of training sources: 2, number of epochs: 1
[HCTR][08:03:39.695][INFO][RK0][main]: Training batchsize: 1024, evaluation batchsize: 1024
[HCTR][08:03:39.695][INFO][RK0][main]: Evaluation interval: 1000, snapshot interval: 10000
[HCTR][08:03:39.695][INFO][RK0][main]: Dense network trainable: True
[HCTR][08:03:39.695][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][08:03:39.695][INFO][RK0][main]: Sparse embedding sparse_embedding2 trainable: True
[HCTR][08:03:39.695][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][08:03:39.695][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][08:03:39.695][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][08:03:39.695][INFO][RK0][main]: Evaluation source file: wdl_data/file_list.5.txt
[HCTR][08:03:39.695][INFO][RK0][main]: --------------------Epoch 0, source file: wdl_data/file_list.3.txt--------------------
[HCTR][08:03:39.696][INFO][RK0][main]: Preparing embedding table for next pass
[HCTR][08:03:40.424][INFO][RK0][main]: HMEM-Cache PS: Hit rate [dump]: 78.9 %
[HCTR][08:03:40.501][INFO][RK0][main]: HMEM-Cache PS: Hit rate [load]: 72.53 %
[HCTR][08:03:42.840][INFO][RK0][main]: --------------------Epoch 0, source file: wdl_data/file_list.4.txt--------------------
[HCTR][08:03:42.841][INFO][RK0][main]: Preparing embedding table for next pass
[HCTR][08:03:43.696][INFO][RK0][main]: HMEM-Cache PS: Hit rate [dump]: 66.2 %
[HCTR][08:03:43.767][INFO][RK0][main]: HMEM-Cache PS: Hit rate [load]: 68.85 %
[HCTR][08:03:44.382][INFO][RK0][main]: Iter: 500 Time(500 iters): 4.68363s Loss: 0.103496 lr:0.001
[HCTR][08:03:46.952][INFO][RK0][main]: HMEM-Cache PS: Hit rate [dump]: 66.14 %
[HCTR][08:03:47.217][INFO][RK0][main]: HMEM-Cache PS: Hit rate [load]: 58.85 %
[HCTR][08:03:47.222][INFO][RK0][main]: Get updated portion of embedding table [DONE}
[HCTR][08:03:48.314][INFO][RK0][main]: HMEM-Cache PS: Hit rate [dump]: 66.14 %
[HCTR][08:03:48.318][INFO][RK0][main]: Updating sparse model in SSD
[HCTR][08:03:48.544][INFO][RK0][main]: Done!
[HCTR][08:03:48.544][INFO][RK0][main]: Sync blocks from HMEM-Cache to SSD
████████████████████████████████████████▏ 100.0% [ 2/ 2 | 13.3 Hz | 0s<0s] m
[HCTR][08:03:48.695][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:03:48.710][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][08:03:48.715][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:03:48.744][INFO][RK0][main]: Dumping dense optimizer states to file, successful
Continuous Training with the Low-level API
This section gives a code sample for continuous training using the low-level API. The program logic is the same as in the preceding code sample.
Although the low-level APIs provide fine-grained control of the training logic, we encourage you to use the high-level API if it satisfies your requirements, because driving the data reader and embedding training cache directly is not straightforward and is error-prone.
For more about the low-level API, see the Low-level Training API documentation and the Low-level Training samples.
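The evaluation bookkeeping in the low-level loop is the fiddliest part: stop after `max_eval_batches`, and rewind the evaluation reader once it is exhausted. The sketch isolates that control flow with a hypothetical `FakeEvalReader` standing in for both the HugeCTR model's `eval()` and the evaluation reader's no-argument `set_source()`. It assumes that call resets the reader to its previous source. This mirrors the structure of the loop, not HugeCTR internals.

```python
class FakeEvalReader:
    """Hypothetical stand-in: serves num_batches batches, then reports exhaustion."""

    def __init__(self, num_batches: int):
        self.num_batches = num_batches
        self.pos = 0
        self.rewinds = 0

    def eval(self) -> bool:
        """Consume one batch; return False when no batch was available."""
        if self.pos >= self.num_batches:
            return False
        self.pos += 1
        return True

    def set_source(self) -> None:
        """Rewind to the start, as the no-argument set_source() is assumed to do."""
        self.pos = 0
        self.rewinds += 1


def run_eval(reader: FakeEvalReader, max_eval_batches: int) -> int:
    """One evaluation round structured like the low-level loop; returns the eval() call count."""
    flag = True  # in the loop, the flag is always True here because exhaustion resets it
    batches = 0
    while flag:
        if batches >= max_eval_batches:
            break
        flag = reader.eval()
        batches += 1
    if not flag:
        reader.set_source()
        flag = True
    return batches


if __name__ == "__main__":
    r = FakeEvalReader(3)
    print(run_eval(r, 2), r.pos, r.rewinds)  # 2 2 0
    print(run_eval(r, 2), r.pos, r.rewinds)  # 2 0 1
```

In the second round, the final `eval()` call finds no batch yet is still counted. So `batches` counts calls rather than evaluated batches, just as in the loop.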
%%writefile wdl_etc.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 5000,
batchsize_eval = 1024,
batchsize = 1024,
vvgpu = [[0]],
i64_input_key = False,
use_mixed_precision = False,
repeat_dataset = False,
use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
source = ["wdl_data/file_list."+str(i)+".txt" for i in range(2)],
keyset = ["wdl_data/file_list."+str(i)+".keyset" for i in range(2)],
eval_source = "wdl_data/file_list.2.txt",
num_workers = 10,
check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
hc_cnfg = hugectr.CreateHMemCache(num_blocks = 2, target_hit_rate = 0.5, max_num_evict = 0)
etc = hugectr.CreateETC(ps_types = [hugectr.TrainPSType_t.Staged, hugectr.TrainPSType_t.Cached],
sparse_models = ["./wdl_0_sparse_model", "./wdl_1_sparse_model"],
local_paths = ["./"], hmem_cache_configs = [hc_cnfg])
model = hugectr.Model(solver, reader, optimizer, etc)
model.construct_from_json(graph_config_file = "wdl.json", include_dense_network = True)
model.compile()
lr_sch = model.get_learning_rate_scheduler()
data_reader_train = model.get_data_reader_train()
data_reader_eval = model.get_data_reader_eval()
etc = model.get_embedding_training_cache()
# First pass: train on file lists 0 and 1, evaluating on file list 2
dataset = [("wdl_data/file_list."+str(i)+".txt", "wdl_data/file_list."+str(i)+".keyset") for i in range(2)]
data_reader_eval.set_source("wdl_data/file_list.2.txt")
data_reader_eval_flag = True
iteration = 0
for file_list, keyset_file in dataset:
    data_reader_train.set_source(file_list)
    data_reader_train_flag = True
    # Prepare the embedding table for the keys in this file list
    etc.update(keyset_file)
    while True:
        lr = lr_sch.get_next()
        model.set_learning_rate(lr)
        data_reader_train_flag = model.train()
        if not data_reader_train_flag:
            break  # the current file list is exhausted
        if iteration % 1000 == 0:
            batches = 0
            while data_reader_eval_flag:
                if batches >= solver.max_eval_batches:
                    break
                data_reader_eval_flag = model.eval()
                batches += 1
            if not data_reader_eval_flag:
                # Reset the evaluation reader once it has been fully consumed
                data_reader_eval.set_source()
                data_reader_eval_flag = True
            metrics = model.get_eval_metrics()
            print("[HUGECTR][INFO] iter: {}, metrics: {}".format(iteration, metrics))
        iteration += 1
    print("[HUGECTR][INFO] trained with data in {}".format(file_list))
# Second pass: continue training incrementally on file lists 3 and 4
dataset = [("wdl_data/file_list."+str(i)+".txt", "wdl_data/file_list."+str(i)+".keyset") for i in range(3, 5)]
for file_list, keyset_file in dataset:
    data_reader_train.set_source(file_list)
    data_reader_train_flag = True
    etc.update(keyset_file)
    while True:
        lr = lr_sch.get_next()
        model.set_learning_rate(lr)
        data_reader_train_flag = model.train()
        if not data_reader_train_flag:
            break
        if iteration % 1000 == 0:
            batches = 0
            while data_reader_eval_flag:
                if batches >= solver.max_eval_batches:
                    break
                data_reader_eval_flag = model.eval()
                batches += 1
            if not data_reader_eval_flag:
                data_reader_eval.set_source()
                data_reader_eval_flag = True
            metrics = model.get_eval_metrics()
            print("[HUGECTR][INFO] iter: {}, metrics: {}".format(iteration, metrics))
        iteration += 1
    print("[HUGECTR][INFO] trained with data in {}".format(file_list))
# Get the updated portion of the embedding table, then save dense and sparse parameters
incremental_model = model.get_incremental_model()
model.save_params_to_files("wdl_etc")
Writing wdl_etc.py
!python3 wdl_etc.py
[HCTR][08:03:56.005][INFO][RK0][main]: Use existing embedding: ./wdl_0_sparse_model
[HCTR][08:03:56.005][INFO][RK0][main]: Use existing embedding: ./wdl_1_sparse_model
HugeCTR Version: 4.1
====================================================Model Init=====================================================
[HCTR][08:03:56.005][WARNING][RK0][main]: The model name is not specified when creating the solver.
[HCTR][08:03:56.005][INFO][RK0][main]: Global seed is 3575646790
[HCTR][08:03:56.008][INFO][RK0][main]: Device to NUMA mapping:
GPU 0 -> node 0
[HCTR][08:03:57.823][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][08:03:57.823][INFO][RK0][main]: Start all2all warmup
[HCTR][08:03:57.823][INFO][RK0][main]: End all2all warmup
[HCTR][08:03:57.824][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][08:03:57.825][INFO][RK0][main]: Device 0: Tesla V100-SXM2-32GB
[HCTR][08:03:57.825][INFO][RK0][main]: num of DataReader workers for train: 10
[HCTR][08:03:57.825][INFO][RK0][main]: num of DataReader workers for eval: 10
[HCTR][08:03:57.836][WARNING][RK0][main]: Embedding vector size(1) is not a multiple of 32, which may affect the GPU resource utilization.
[HCTR][08:03:57.836][INFO][RK0][main]: max_num_frequent_categories is not specified using default: 1
[HCTR][08:03:57.836][INFO][RK0][main]: max_num_infrequent_samples is not specified using default: -1
[HCTR][08:03:57.836][INFO][RK0][main]: p_dup_max is not specified using default: 0.01
[HCTR][08:03:57.836][INFO][RK0][main]: max_all_reduce_bandwidth is not specified using default: 1.3e+11
[HCTR][08:03:57.836][INFO][RK0][main]: max_all_to_all_bandwidth is not specified using default: 1.9e+11
[HCTR][08:03:57.836][INFO][RK0][main]: efficiency_bandwidth_ratio is not specified using default: 1
[HCTR][08:03:57.836][INFO][RK0][main]: communication_type is not specified using default: IB_NVLink
[HCTR][08:03:57.836][INFO][RK0][main]: hybrid_embedding_type is not specified using default: Distributed
[HCTR][08:03:57.836][INFO][RK0][main]: max_vocabulary_size_per_gpu_=6029312
[HCTR][08:03:57.838][WARNING][RK0][main]: Embedding vector size(16) is not a multiple of 32, which may affect the GPU resource utilization.
[HCTR][08:03:57.838][INFO][RK0][main]: max_num_frequent_categories is not specified using default: 1
[HCTR][08:03:57.838][INFO][RK0][main]: max_num_infrequent_samples is not specified using default: -1
[HCTR][08:03:57.838][INFO][RK0][main]: p_dup_max is not specified using default: 0.01
[HCTR][08:03:57.838][INFO][RK0][main]: max_all_reduce_bandwidth is not specified using default: 1.3e+11
[HCTR][08:03:57.838][INFO][RK0][main]: max_all_to_all_bandwidth is not specified using default: 1.9e+11
[HCTR][08:03:57.838][INFO][RK0][main]: efficiency_bandwidth_ratio is not specified using default: 1
[HCTR][08:03:57.838][INFO][RK0][main]: communication_type is not specified using default: IB_NVLink
[HCTR][08:03:57.838][INFO][RK0][main]: hybrid_embedding_type is not specified using default: Distributed
[HCTR][08:03:57.838][INFO][RK0][main]: max_vocabulary_size_per_gpu_=5865472
[HCTR][08:03:57.842][INFO][RK0][main]: Load the model graph from wdl.json successfully
[HCTR][08:03:57.842][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][08:04:01.810][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][08:04:01.811][INFO][RK0][main]: gpu0 init embedding done
[HCTR][08:04:01.811][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][08:04:01.813][INFO][RK0][main]: gpu0 init embedding done
[HCTR][08:04:01.813][INFO][RK0][main]: Enable HMEM-Based Parameter Server
[HCTR][08:04:01.948][INFO][RK0][main]: Enable HMemCache-Based Parameter Server
[HCTR][08:04:03.308][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][08:04:03.312][INFO][RK0][main]: Warm-up done
[HCTR][08:04:03.315][INFO][RK0][main]: Preparing embedding table for next pass
[HCTR][08:04:03.454][INFO][RK0][main]: HMEM-Cache PS: Hit rate [load]: 0 %
[HUGECTR][INFO] iter: 0, metrics: [('AUC', 0.6509891152381897)]
[HUGECTR][INFO] trained with data in wdl_data/file_list.0.txt
[HCTR][08:04:06.386][INFO][RK0][main]: Preparing embedding table for next pass
[HCTR][08:04:06.865][INFO][RK0][main]: HMEM-Cache PS: Hit rate [dump]: 100 %
[HCTR][08:04:06.966][INFO][RK0][main]: HMEM-Cache PS: Hit rate [load]: 78.9 %
[HUGECTR][INFO] trained with data in wdl_data/file_list.1.txt
[HCTR][08:04:09.307][INFO][RK0][main]: Preparing embedding table for next pass
[HCTR][08:04:10.058][INFO][RK0][main]: HMEM-Cache PS: Hit rate [dump]: 100 %
[HCTR][08:04:10.133][INFO][RK0][main]: HMEM-Cache PS: Hit rate [load]: 91.27 %
[HUGECTR][INFO] iter: 1000, metrics: [('AUC', 0.716963529586792)]
[HUGECTR][INFO] trained with data in wdl_data/file_list.3.txt
[HCTR][08:04:12.976][INFO][RK0][main]: Preparing embedding table for next pass
[HCTR][08:04:13.647][INFO][RK0][main]: HMEM-Cache PS: Hit rate [dump]: 91.27 %
[HCTR][08:04:13.724][INFO][RK0][main]: HMEM-Cache PS: Hit rate [load]: 91.25 %
[HUGECTR][INFO] trained with data in wdl_data/file_list.4.txt
[HCTR][08:04:16.497][INFO][RK0][main]: HMEM-Cache PS: Hit rate [dump]: 91.25 %
[HCTR][08:04:16.674][INFO][RK0][main]: HMEM-Cache PS: Hit rate [load]: 100 %
[HCTR][08:04:16.680][INFO][RK0][main]: Get updated portion of embedding table [DONE}
[HCTR][08:04:17.705][INFO][RK0][main]: HMEM-Cache PS: Hit rate [dump]: 91.25 %
[HCTR][08:04:17.705][INFO][RK0][main]: Updating sparse model in SSD
[HCTR][08:04:17.732][INFO][RK0][main]: Done!
[HCTR][08:04:17.732][INFO][RK0][main]: Sync blocks from HMEM-Cache to SSD
████████████████████████████████████████▏ 100.0% [ 2/ 2 | 11.9 Hz | 0s<0s] m
[HCTR][08:04:17.902][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:04:17.917][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][08:04:17.919][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:04:17.954][INFO][RK0][main]: Dumping dense optimizer states to file, successful