HugeCTR Continuous Training
Overview
This notebook demonstrates how to use the Embedding Training Cache (ETC) feature in HugeCTR for continuous training. ETC is designed to handle recommendation models with huge embedding tables through incremental training, which allows you to train models that are much larger than the available GPU memory.
To learn more about the ETC, see the Embedding Training Cache documentation.
To learn how to use the APIs of ETC, see the HugeCTR Python Interface documentation.
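To make the idea of incremental training concrete, the following toy sketch models it in plain Python. It is not HugeCTR code and does not use the HugeCTR API: `ToyParameterServer`, `train_pass`, and `run_passes` are hypothetical names, and the "training" is a fixed nudge rather than real gradients. A host-side store keeps every embedding row. Before each pass, only the rows named in that pass's keyset are copied into a bounded "device" table, updated, and written back, so the device only needs room for one pass's keys.

```python
# Toy model of incremental embedding training; NOT HugeCTR's implementation.
from typing import Dict, List

Row = List[float]


class ToyParameterServer:
    """Host-side store holding every embedding row (hypothetical helper)."""

    def __init__(self, dim: int = 4) -> None:
        self.dim = dim
        self.rows: Dict[int, Row] = {}

    def pull(self, keys) -> Dict[int, Row]:
        # Copy the requested rows out, creating zero-initialized rows on first use.
        table = {}
        for k in keys:
            self.rows.setdefault(k, [0.0] * self.dim)
            table[k] = list(self.rows[k])
        return table

    def push(self, table: Dict[int, Row]) -> None:
        # Write the trained rows back to the host store.
        self.rows.update(table)


def train_pass(device_table: Dict[int, Row], samples, step: float = 0.001) -> None:
    """Stand-in for real training: shift every looked-up row by a fixed step."""
    for sample in samples:
        for k in sample:
            device_table[k] = [w - step for w in device_table[k]]


def run_passes(ps: ToyParameterServer, passes, device_capacity: int) -> int:
    """Train pass by pass; return the largest number of rows resident on the 'device'."""
    peak = 0
    for samples in passes:
        keyset = {k for sample in samples for k in sample}  # unique keys of this pass
        if len(keyset) > device_capacity:
            raise MemoryError("this pass's keyset does not fit in device memory")
        device_table = ps.pull(keyset)
        peak = max(peak, len(device_table))
        train_pass(device_table, samples)
        ps.push(device_table)
    return peak


if __name__ == "__main__":
    ps = ToyParameterServer()
    passes = [[[1, 2], [2, 3]], [[10, 11], [11, 12]], [[1, 12], [20]]]
    print(run_passes(ps, passes, device_capacity=3), len(ps.rows))  # 3 7
```

In this toy, the full table ends with seven rows while at most three are resident at any time, which is the property that lets the model grow beyond device memory.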
Installation
Get HugeCTR from NGC
The continuous training module is preinstalled in Merlin Training Container 22.07 and later: nvcr.io/nvidia/merlin/merlin-hugectr:22.07.
After launching the container, you can verify that HugeCTR is available by running the following command:
$ python3 -c "import hugectr"
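The scripts in this notebook import both `hugectr` and `mpi4py`, so it can be convenient to check both at once. The following is a small helper sketch (`check_modules` is a hypothetical name, not part of HugeCTR) that tries each import and reports the error message instead of stopping at the first failure.

```python
import importlib
from typing import Dict, Iterable, Optional


def check_modules(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Import each module; map its name to None on success or to the error text."""
    results: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            importlib.import_module(name)
            results[name] = None
        except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
            results[name] = str(exc)
    return results


if __name__ == "__main__":
    for name, error in check_modules(["hugectr", "mpi4py"]).items():
        print(f"{name}: {'OK' if error is None else 'MISSING (' + error + ')'}")
```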
If you prefer to build HugeCTR from the source code instead of using the NGC container, refer to the How to Start Your Development documentation.
Continuous Training
Data Preparation
Download the Criteo dataset using the following commands:
$ cd ${project_root}/tools
$ wget https://storage.googleapis.com/criteo-cail-datasets/day_1.gz
To preprocess the downloaded Criteo dataset, we perform the following operations:
Reduce the amount of data to speed up preprocessing
Fill in missing values
Remove feature values that occur very rarely, etc.
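These three operations can be illustrated on a tiny in-memory dataset. This is a plain-Python sketch, not the logic of `preprocess.sh`; the placeholder tokens, the `min_count` threshold, and the choice of 0.0 as the fill value for dense features are assumptions made for the example.

```python
from collections import Counter
from typing import List, Optional, Tuple

# One sample: (dense features, categorical features); None marks a missing value.
Sample = Tuple[List[Optional[float]], List[Optional[str]]]


def preprocess(rows: List[Sample], max_rows: int, min_count: int,
               dense_fill: float = 0.0, missing: str = "<missing>",
               rare: str = "<rare>") -> List[Tuple[List[float], List[str]]]:
    # 1) Reduce the amount of data.
    rows = rows[:max_rows]
    # 2) Fill missing dense and categorical values.
    filled = [([dense_fill if v is None else v for v in dense],
               [missing if c is None else c for c in cats]) for dense, cats in rows]
    # 3) Replace categorical values seen fewer than min_count times in their column.
    num_cat = len(filled[0][1]) if filled else 0
    counts = [Counter(cats[i] for _, cats in filled) for i in range(num_cat)]
    return [(dense, [c if counts[i][c] >= min_count else rare for i, c in enumerate(cats)])
            for dense, cats in filled]
```

Note that in this toy a filled-in `<missing>` value can itself be rare and get collapsed into `<rare>`; a real pipeline may treat these cases differently.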
Preprocess the data with Pandas using the following command:
$ bash preprocess.sh 1 wdl_data pandas 1 1 100
Meanings of the command-line arguments:
The 1st argument, 1, is the dataset postfix. It is 1 here because day_1 is used.
The 2nd argument, wdl_data, is the directory where the preprocessed data is stored.
The 3rd argument, pandas, is the processing script to use; here we choose pandas.
The 4th argument, 1, indicates that normalization is applied to dense features.
The 5th argument, 1, indicates that feature crossing is applied.
The 6th argument, 100, is the number of data files in each file list.
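If you script the preprocessing step, it can help to give the six positional arguments names. The dataclass below is a hypothetical convenience wrapper (its field names are ours, not the script's); it only rebuilds the command line shown above and does not run it.

```python
import shlex
from dataclasses import dataclass
from typing import List


@dataclass
class PreprocessArgs:
    dataset_postfix: int = 1        # 1st: 1 selects day_1
    output_dir: str = "wdl_data"    # 2nd: where the preprocessed data is stored
    script: str = "pandas"          # 3rd: processing script to use
    normalize_dense: bool = True    # 4th: normalize dense features
    feature_cross: bool = True      # 5th: apply feature crossing
    files_per_list: int = 100       # 6th: number of data files per file list

    def argv(self) -> List[str]:
        return ["bash", "preprocess.sh", str(self.dataset_postfix), self.output_dir,
                self.script, str(int(self.normalize_dense)),
                str(int(self.feature_cross)), str(self.files_per_list)]

    def command(self) -> str:
        # shlex.join (Python 3.8+) quotes each argument safely for a POSIX shell.
        return shlex.join(self.argv())


if __name__ == "__main__":
    print(PreprocessArgs().command())  # bash preprocess.sh 1 wdl_data pandas 1 1 100
```

The `argv()` list could be handed to `subprocess.run` with `cwd` set to `${project_root}/tools` if you want to launch the script from Python.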
For more details about the data preprocessing, please refer to the “Preprocess the Criteo Dataset” section of the README in the samples/criteo directory of the repository on GitHub.
Create a soft link of the dataset folder to the path of this notebook using the following command:
$ ln -s ${project_root}/tools/wdl_data ${project_root}/notebooks/wdl_data
Continuous Training with High-level API
This section provides a code sample of continuous training using the Keras-like high-level API. The high-level API hides much of the complexity from users, which makes it easy to use while still handling many of the scenarios in a production environment.
In addition to the high-level API, HugeCTR also provides low-level APIs that let you customize the training logic. A code sample using the low-level APIs is provided in the next section.
The code sample in this section trains a model from scratch using the embedding training cache, gets the incremental model, and saves the trained dense weights and sparse embedding weights. It performs the following steps:
Create the solver, reader, optimizer, and etc, then initialize the model.
Construct the model graph by adding the input, sparse embedding, and dense layers in order.
Compile the model and print a summary of the model graph.
Dump the model graph to a JSON file.
Train the sparse and dense model.
Set the new training datasets and their corresponding keysets.
Train the sparse and dense model incrementally.
Get the incrementally trained embedding table.
Save the model weights and optimizer states explicitly.
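The last steps in this list depend on knowing which embedding rows changed during training; the code comment in the sample describes `model.get_incremental_model()` as getting the updated embedding features from `model.fit()`. The toy class below only illustrates that bookkeeping idea with a set of "dirty" keys. The class and method names are hypothetical, and resetting the record after each retrieval is an assumption of this sketch, so check the HugeCTR Python Interface documentation for the exact semantics of the real method.

```python
from typing import Dict, List, Set, Tuple


class ToyIncrementalTracker:
    """Records which embedding keys were written so only those rows are exported."""

    def __init__(self) -> None:
        self.table: Dict[int, List[float]] = {}
        self._dirty: Set[int] = set()

    def write(self, key: int, vector: List[float]) -> None:
        self.table[key] = vector
        self._dirty.add(key)

    def pop_updated_rows(self) -> List[Tuple[int, List[float]]]:
        # Export the rows touched since the previous call, then start a new record.
        updated = [(k, self.table[k]) for k in sorted(self._dirty)]
        self._dirty.clear()
        return updated


if __name__ == "__main__":
    t = ToyIncrementalTracker()
    t.write(1, [0.1])
    t.write(2, [0.2])
    t.write(1, [0.3])
    print(t.pop_updated_rows())  # [(1, [0.3]), (2, [0.2])]
    print(t.pop_updated_rows())  # []
```

In this toy, each export is proportional to the keys touched since the last call rather than to the size of the full table.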
Note: repeat_dataset should be False when using the embedding training cache, and the num_epochs argument of Model::fit specifies the number of training epochs in this mode.
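One way to picture this mode: training walks through the (source, keyset) pairs once per epoch, and each pass ends when its file list is exhausted, which is presumably why a dataset that repeats indefinitely cannot be used. The generator below is a hypothetical sketch of that ordering. It is consistent with the `Epoch 0, source file: ...` lines in the log that follows, but the nesting of epochs and sources for `num_epochs > 1` is an assumption.

```python
from typing import Iterator, Sequence, Tuple


def etc_pass_schedule(sources: Sequence[str], keysets: Sequence[str],
                      num_epochs: int) -> Iterator[Tuple[int, str, str]]:
    """Yield (epoch, source, keyset) in the assumed order of training passes."""
    if len(sources) != len(keysets):
        raise ValueError("every source file list needs exactly one keyset")
    for epoch in range(num_epochs):
        for source, keyset in zip(sources, keysets):
            yield epoch, source, keyset


if __name__ == "__main__":
    sources = ["wdl_data/file_list.%d.txt" % i for i in range(2)]
    keysets = ["wdl_data/file_list.%d.keyset" % i for i in range(2)]
    for epoch, source, keyset in etc_pass_schedule(sources, keysets, num_epochs=1):
        print(f"Epoch {epoch}, source file: {source} (keyset: {keyset})")
```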
%%writefile wdl_train.py
import hugectr
from mpi4py import MPI

# Solver, data reader, optimizer, and embedding training cache (ETC)
solver = hugectr.CreateSolver(max_eval_batches = 5000,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              lr = 0.001,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = False,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                                  source = ["wdl_data/file_list."+str(i)+".txt" for i in range(2)],
                                  keyset = ["wdl_data/file_list."+str(i)+".keyset" for i in range(2)],
                                  eval_source = "wdl_data/file_list.2.txt",
                                  check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
hc_cnfg = hugectr.CreateHMemCache(num_blocks = 2, target_hit_rate = 0.5, max_num_evict = 0)
etc = hugectr.CreateETC(ps_types = [hugectr.TrainPSType_t.Staged, hugectr.TrainPSType_t.Cached],
                        sparse_models = ["./wdl_0_sparse_model", "./wdl_1_sparse_model"],
                        local_paths = ["./"], hmem_cache_configs = [hc_cnfg])
model = hugectr.Model(solver, reader, optimizer, etc)

# Model graph: input, sparse embeddings, then dense layers
model.add(hugectr.Input(label_dim = 1, label_name = "label",
                        dense_dim = 13, dense_name = "dense",
                        data_reader_sparse_param_array =
                        [hugectr.DataReaderSparseParam("wide_data", 30, True, 1),
                         hugectr.DataReaderSparseParam("deep_data", 2, False, 26)]))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
                                  workspace_size_per_gpu_in_mb = 69,
                                  embedding_vec_size = 1,
                                  combiner = "sum",
                                  sparse_embedding_name = "sparse_embedding2",
                                  bottom_name = "wide_data",
                                  optimizer = optimizer))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
                                  workspace_size_per_gpu_in_mb = 1074,
                                  embedding_vec_size = 16,
                                  combiner = "sum",
                                  sparse_embedding_name = "sparse_embedding1",
                                  bottom_name = "deep_data",
                                  optimizer = optimizer))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                             bottom_names = ["sparse_embedding1"],
                             top_names = ["reshape1"],
                             leading_dim=416))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                             bottom_names = ["sparse_embedding2"],
                             top_names = ["reshape2"],
                             leading_dim=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                             bottom_names = ["reshape1", "dense"], top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                             bottom_names = ["concat1"],
                             top_names = ["fc1"],
                             num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                             bottom_names = ["fc1"],
                             top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                             bottom_names = ["relu1"],
                             top_names = ["dropout1"],
                             dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                             bottom_names = ["dropout1"],
                             top_names = ["fc2"],
                             num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                             bottom_names = ["fc2"],
                             top_names = ["relu2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                             bottom_names = ["relu2"],
                             top_names = ["dropout2"],
                             dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                             bottom_names = ["dropout2"],
                             top_names = ["fc3"],
                             num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Add,
                             bottom_names = ["fc3", "reshape2"],
                             top_names = ["add1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
                             bottom_names = ["add1", "label"],
                             top_names = ["loss"]))
model.compile()
model.summary()
model.graph_to_json(graph_config_file = "wdl.json")

# Train from scratch on file lists 0 and 1
model.fit(num_epochs = 1, display = 500, eval_interval = 1000)
# Get the updated embedding features in model.fit()
# updated_model = model.get_incremental_model()

# Switch to new training data (file lists 3 and 4) and continue training
model.set_source(source = ["wdl_data/file_list.3.txt", "wdl_data/file_list.4.txt"],
                 keyset = ["wdl_data/file_list.3.keyset", "wdl_data/file_list.4.keyset"],
                 eval_source = "wdl_data/file_list.5.txt")
model.fit(num_epochs = 1, display = 500, eval_interval = 1000)
# Get the updated embedding features in model.fit()
updated_model = model.get_incremental_model()
# Save the dense weights, optimizer states, and sparse embedding weights
model.save_params_to_files("wdl_etc")
Writing wdl_train.py
!python3 wdl_train.py
[HUGECTR][12:36:58][INFO][RANK0]: Empty embedding, trained table will be stored in ./wdl_0_sparse_model
[HUGECTR][12:36:58][INFO][RANK0]: Empty embedding, trained table will be stored in ./wdl_1_sparse_model
HugeCTR Version: 3.2
====================================================Model Init=====================================================
[HUGECTR][12:36:58][INFO][RANK0]: Global seed is 3664540043
[HUGECTR][12:36:58][INFO][RANK0]: Device to NUMA mapping:
GPU 0 -> node 0
[HUGECTR][12:36:59][WARNING][RANK0]: Peer-to-peer access cannot be fully enabled.
[HUGECTR][12:36:59][INFO][RANK0]: Start all2all warmup
[HUGECTR][12:36:59][INFO][RANK0]: End all2all warmup
[HUGECTR][12:36:59][INFO][RANK0]: Using All-reduce algorithm: NCCL
[HUGECTR][12:36:59][INFO][RANK0]: Device 0: Tesla V100-SXM2-32GB
[HUGECTR][12:36:59][INFO][RANK0]: num of DataReader workers: 12
[HUGECTR][12:36:59][INFO][RANK0]: max_vocabulary_size_per_gpu_=6029312
[HUGECTR][12:36:59][INFO][RANK0]: max_vocabulary_size_per_gpu_=5865472
[HUGECTR][12:36:59][INFO][RANK0]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HUGECTR][12:37:03][INFO][RANK0]: gpu0 start to init embedding
[HUGECTR][12:37:03][INFO][RANK0]: gpu0 init embedding done
[HUGECTR][12:37:03][INFO][RANK0]: gpu0 start to init embedding
[HUGECTR][12:37:03][INFO][RANK0]: gpu0 init embedding done
[HUGECTR][12:37:03][INFO][RANK0]: Enable HMEM-Based Parameter Server
[HUGECTR][12:37:03][INFO][RANK0]: ./wdl_0_sparse_model not exist, create and train from scratch
[HUGECTR][12:37:03][INFO][RANK0]: Enable HMemCache-Based Parameter Server
[HUGECTR][12:37:03][INFO][RANK0]: ./wdl_1_sparse_model/key doesn't exist, created
[HUGECTR][12:37:03][INFO][RANK0]: ./wdl_1_sparse_model/emb_vector doesn't exist, created
[HUGECTR][12:37:03][INFO][RANK0]: ./wdl_1_sparse_model/Adam.m doesn't exist, created
[HUGECTR][12:37:03][INFO][RANK0]: ./wdl_1_sparse_model/Adam.v doesn't exist, created
[HUGECTR][12:37:04][INFO][RANK0]: Starting AUC NCCL warm-up
[HUGECTR][12:37:04][INFO][RANK0]: Warm-up done
===================================================Model Summary===================================================
label Dense Sparse
label dense wide_data,deep_data
(None, 1) (None, 13)
------------------------------------------------------------------------------------------------------------------
Layer Type Input Name Output Name Output Shape
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash wide_data sparse_embedding2 (None, 1, 1)
DistributedSlotSparseEmbeddingHash deep_data sparse_embedding1 (None, 26, 16)
Reshape sparse_embedding1 reshape1 (None, 416)
Reshape sparse_embedding2 reshape2 (None, 1)
Concat reshape1,dense concat1 (None, 429)
InnerProduct concat1 fc1 (None, 1024)
ReLU fc1 relu1 (None, 1024)
Dropout relu1 dropout1 (None, 1024)
InnerProduct dropout1 fc2 (None, 1024)
ReLU fc2 relu2 (None, 1024)
Dropout relu2 dropout2 (None, 1024)
InnerProduct dropout2 fc3 (None, 1)
Add fc3,reshape2 add1 (None, 1)
BinaryCrossEntropyLoss add1,label loss
------------------------------------------------------------------------------------------------------------------
[HUGECTR][12:37:04][INFO][RANK0]: Save the model graph to wdl.json successfully
=====================================================Model Fit=====================================================
[HUGECTR][12:37:04][INFO][RANK0]: Use embedding training cache mode with number of training sources: 2, number of epochs: 1
[HUGECTR][12:37:04][INFO][RANK0]: Training batchsize: 1024, evaluation batchsize: 1024
[HUGECTR][12:37:04][INFO][RANK0]: Evaluation interval: 1000, snapshot interval: 10000
[HUGECTR][12:37:04][INFO][RANK0]: Sparse embedding trainable: True, dense network trainable: True
[HUGECTR][12:37:04][INFO][RANK0]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HUGECTR][12:37:04][INFO][RANK0]: lr: 0.001000, warmup_steps: 1, decay_start: 0, decay_steps: 1, decay_power: 2.000000, end_lr: 0.000000
[HUGECTR][12:37:04][INFO][RANK0]: Evaluation source file: wdl_data/file_list.2.txt
[HUGECTR][12:37:04][INFO][RANK0]: --------------------Epoch 0, source file: wdl_data/file_list.0.txt--------------------
[HUGECTR][12:37:04][INFO][RANK0]: Preparing embedding table for next pass
[HUGECTR][12:37:05][INFO][RANK0]: HMEM-Cache PS: Hit rate [load]: 0 %
[HUGECTR][12:37:07][INFO][RANK0]: Iter: 500 Time(500 iters): 2.959942s Loss: 0.140601 lr:0.001000
[HUGECTR][12:37:10][INFO][RANK0]: Iter: 1000 Time(500 iters): 2.687422s Loss: 0.127723 lr:0.001000
[HUGECTR][12:37:15][INFO][RANK0]: Evaluation, AUC: 0.738460
[HUGECTR][12:37:15][INFO][RANK0]: Eval Time for 5000 iters: 4.757926s
[HUGECTR][12:37:17][INFO][RANK0]: Iter: 1500 Time(500 iters): 7.310160s Loss: 0.152160 lr:0.001000
[HUGECTR][12:37:20][INFO][RANK0]: Iter: 2000 Time(500 iters): 2.613197s Loss: 0.124371 lr:0.001000
[HUGECTR][12:37:22][INFO][RANK0]: Evaluation, AUC: 0.745345
[HUGECTR][12:37:22][INFO][RANK0]: Eval Time for 5000 iters: 1.907179s
[HUGECTR][12:37:24][INFO][RANK0]: Iter: 2500 Time(500 iters): 4.343850s Loss: 0.134511 lr:0.001000
[HUGECTR][12:37:27][INFO][RANK0]: Iter: 3000 Time(500 iters): 2.505121s Loss: 0.119222 lr:0.001000
[HUGECTR][12:37:28][INFO][RANK0]: Evaluation, AUC: 0.751256
[HUGECTR][12:37:28][INFO][RANK0]: Eval Time for 5000 iters: 1.900262s
[HUGECTR][12:37:31][INFO][RANK0]: Iter: 3500 Time(500 iters): 4.459760s Loss: 0.145278 lr:0.001000
[HUGECTR][12:37:34][INFO][RANK0]: Iter: 4000 Time(500 iters): 2.544999s Loss: 0.134373 lr:0.001000
[HUGECTR][12:37:35][INFO][RANK0]: Evaluation, AUC: 0.753270
[HUGECTR][12:37:35][INFO][RANK0]: Eval Time for 5000 iters: 1.901368s
[HUGECTR][12:37:35][INFO][RANK0]: --------------------Epoch 0, source file: wdl_data/file_list.1.txt--------------------
[HUGECTR][12:37:35][INFO][RANK0]: Preparing embedding table for next pass
[HUGECTR][12:37:37][INFO][RANK0]: HMEM-Cache PS: Hit rate [dump]: 0 %
[HUGECTR][12:37:37][INFO][RANK0]: HMEM-Cache PS: Hit rate [load]: 0 %
[HUGECTR][12:37:40][INFO][RANK0]: Iter: 4500 Time(500 iters): 6.212693s Loss: 0.131819 lr:0.001000
[HUGECTR][12:37:42][INFO][RANK0]: Iter: 5000 Time(500 iters): 2.660587s Loss: 0.117531 lr:0.001000
[HUGECTR][12:37:44][INFO][RANK0]: Evaluation, AUC: 0.754530
[HUGECTR][12:37:44][INFO][RANK0]: Eval Time for 5000 iters: 1.897969s
[HUGECTR][12:37:47][INFO][RANK0]: Iter: 5500 Time(500 iters): 4.340803s Loss: 0.118400 lr:0.001000
[HUGECTR][12:37:49][INFO][RANK0]: Iter: 6000 Time(500 iters): 2.497391s Loss: 0.143188 lr:0.001000
[HUGECTR][12:37:51][INFO][RANK0]: Evaluation, AUC: 0.755805
[HUGECTR][12:37:51][INFO][RANK0]: Eval Time for 5000 iters: 1.904572s
[HUGECTR][12:37:54][INFO][RANK0]: Iter: 6500 Time(500 iters): 4.332877s Loss: 0.159262 lr:0.001000
[HUGECTR][12:37:56][INFO][RANK0]: Iter: 7000 Time(500 iters): 2.426105s Loss: 0.119848 lr:0.001000
[HUGECTR][12:37:58][INFO][RANK0]: Evaluation, AUC: 0.757338
[HUGECTR][12:37:58][INFO][RANK0]: Eval Time for 5000 iters: 1.900609s
[HUGECTR][12:38:00][INFO][RANK0]: Iter: 7500 Time(500 iters): 4.348594s Loss: 0.139543 lr:0.001000
[HUGECTR][12:38:03][INFO][RANK0]: Iter: 8000 Time(500 iters): 2.424926s Loss: 0.109002 lr:0.001000
[HUGECTR][12:38:05][INFO][RANK0]: Evaluation, AUC: 0.758067
[HUGECTR][12:38:05][INFO][RANK0]: Eval Time for 5000 iters: 1.900712s
=====================================================Model Fit=====================================================
[HUGECTR][12:38:05][INFO][RANK0]: Use embedding training cache mode with number of training sources: 2, number of epochs: 1
[HUGECTR][12:38:05][INFO][RANK0]: Training batchsize: 1024, evaluation batchsize: 1024
[HUGECTR][12:38:05][INFO][RANK0]: Evaluation interval: 1000, snapshot interval: 10000
[HUGECTR][12:38:05][INFO][RANK0]: Sparse embedding trainable: True, dense network trainable: True
[HUGECTR][12:38:05][INFO][RANK0]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HUGECTR][12:38:05][INFO][RANK0]: lr: 0.001000, warmup_steps: 1, decay_start: 0, decay_steps: 1, decay_power: 2.000000, end_lr: 0.000000
[HUGECTR][12:38:05][INFO][RANK0]: Evaluation source file: wdl_data/file_list.5.txt
[HUGECTR][12:38:05][INFO][RANK0]: --------------------Epoch 0, source file: wdl_data/file_list.3.txt--------------------
[HUGECTR][12:38:05][INFO][RANK0]: Preparing embedding table for next pass
[HUGECTR][12:38:06][INFO][RANK0]: HMEM-Cache PS: Hit rate [dump]: 77.89 %
[HUGECTR][12:38:06][INFO][RANK0]: HMEM-Cache PS: Hit rate [load]: 71.22 %
[HUGECTR][12:38:08][INFO][RANK0]: Iter: 500 Time(500 iters): 3.652928s Loss: 0.124229 lr:0.001000
[HUGECTR][12:38:11][INFO][RANK0]: Iter: 1000 Time(500 iters): 2.455519s Loss: 0.142507 lr:0.001000
[HUGECTR][12:38:13][INFO][RANK0]: Evaluation, AUC: 0.757185
[HUGECTR][12:38:13][INFO][RANK0]: Eval Time for 5000 iters: 1.909209s
[HUGECTR][12:38:15][INFO][RANK0]: Iter: 1500 Time(500 iters): 4.353392s Loss: 0.123939 lr:0.001000
[HUGECTR][12:38:18][INFO][RANK0]: Iter: 2000 Time(500 iters): 2.522630s Loss: 0.130625 lr:0.001000
[HUGECTR][12:38:21][INFO][RANK0]: Evaluation, AUC: 0.757897
[HUGECTR][12:38:21][INFO][RANK0]: Eval Time for 5000 iters: 3.763415s
[HUGECTR][12:38:24][INFO][RANK0]: Iter: 2500 Time(500 iters): 6.238394s Loss: 0.138125 lr:0.001000
[HUGECTR][12:38:26][INFO][RANK0]: Iter: 3000 Time(500 iters): 2.429449s Loss: 0.126391 lr:0.001000
[HUGECTR][12:38:28][INFO][RANK0]: Evaluation, AUC: 0.757794
[HUGECTR][12:38:28][INFO][RANK0]: Eval Time for 5000 iters: 1.902641s
[HUGECTR][12:38:31][INFO][RANK0]: Iter: 3500 Time(500 iters): 4.398343s Loss: 0.123047 lr:0.001000
[HUGECTR][12:38:33][INFO][RANK0]: Iter: 4000 Time(500 iters): 2.420357s Loss: 0.142649 lr:0.001000
[HUGECTR][12:38:35][INFO][RANK0]: Evaluation, AUC: 0.760467
[HUGECTR][12:38:35][INFO][RANK0]: Eval Time for 5000 iters: 1.899759s
[HUGECTR][12:38:35][INFO][RANK0]: --------------------Epoch 0, source file: wdl_data/file_list.4.txt--------------------
[HUGECTR][12:38:35][INFO][RANK0]: Preparing embedding table for next pass
[HUGECTR][12:38:37][INFO][RANK0]: HMEM-Cache PS: Hit rate [dump]: 64.88 %
[HUGECTR][12:38:37][INFO][RANK0]: HMEM-Cache PS: Hit rate [load]: 67.35 %
[HUGECTR][12:38:40][INFO][RANK0]: Iter: 4500 Time(500 iters): 6.318908s Loss: 0.165160 lr:0.001000
[HUGECTR][12:38:42][INFO][RANK0]: Iter: 5000 Time(500 iters): 2.423369s Loss: 0.112445 lr:0.001000
[HUGECTR][12:38:44][INFO][RANK0]: Evaluation, AUC: 0.759795
[HUGECTR][12:38:44][INFO][RANK0]: Eval Time for 5000 iters: 1.902252s
[HUGECTR][12:38:46][INFO][RANK0]: Iter: 5500 Time(500 iters): 4.329618s Loss: 0.150855 lr:0.001000
[HUGECTR][12:38:49][INFO][RANK0]: Iter: 6000 Time(500 iters): 2.422831s Loss: 0.121576 lr:0.001000
[HUGECTR][12:38:51][INFO][RANK0]: Evaluation, AUC: 0.760036
[HUGECTR][12:38:51][INFO][RANK0]: Eval Time for 5000 iters: 1.896330s
[HUGECTR][12:38:53][INFO][RANK0]: Iter: 6500 Time(500 iters): 4.352440s Loss: 0.131191 lr:0.001000
[HUGECTR][12:38:56][INFO][RANK0]: Iter: 7000 Time(500 iters): 2.426486s Loss: 0.130866 lr:0.001000
[HUGECTR][12:38:57][INFO][RANK0]: Evaluation, AUC: 0.761125
[HUGECTR][12:38:57][INFO][RANK0]: Eval Time for 5000 iters: 1.910397s
[HUGECTR][12:39:00][INFO][RANK0]: Iter: 7500 Time(500 iters): 4.364026s Loss: 0.096611 lr:0.001000
[HUGECTR][12:39:03][INFO][RANK0]: Iter: 8000 Time(500 iters): 2.664058s Loss: 0.142381 lr:0.001000
[HUGECTR][12:39:05][INFO][RANK0]: Evaluation, AUC: 0.762636
[HUGECTR][12:39:05][INFO][RANK0]: Eval Time for 5000 iters: 1.975668s
[HUGECTR][12:39:06][INFO][RANK0]: HMEM-Cache PS: Hit rate [dump]: 64.82 %
[HUGECTR][12:39:07][INFO][RANK0]: HMEM-Cache PS: Hit rate [load]: 57.86 %
[HUGECTR][12:39:07][INFO][RANK0]: Get updated portion of embedding table [DONE}
[HUGECTR][12:39:08][INFO][RANK0]: HMEM-Cache PS: Hit rate [dump]: 64.82 %
[HUGECTR][12:39:08][INFO][RANK0]: Updating sparse model in SSD [DONE]
[HUGECTR][12:39:10][INFO][RANK0]: Sync blocks from HMEM-Cache to SSD
████████████████████████████████████████▏ 100.0% [ 2/ 2 | 66.7 Hz | 0s<0s]
[HUGECTR][12:39:10][INFO][RANK0]: Dumping dense weights to file, successful
[HUGECTR][12:39:10][INFO][RANK0]: Dumping dense optimizer states to file, successful
[HUGECTR][12:39:10][INFO][RANK0]: Dumping untrainable weights to file, successful
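The `HMEM-Cache PS: Hit rate` lines in the log report how well the host-memory cache configured with `CreateHMemCache` served each load and dump. The sketch below assumes a simple reading of that number, namely the percentage of a pass's keys already resident in the cache, and uses an unbounded cache, so it will not reproduce the exact values above; `cache_hit_rate` and `simulate_hit_rates` are hypothetical helpers. Under that assumption it shows why a model trained from scratch starts at 0 % and why later passes over overlapping keys report higher rates.

```python
from typing import Iterable, List, Set


def cache_hit_rate(resident: Set[int], requested: Iterable[int]) -> float:
    """Percentage of unique requested keys already in the cache (toy definition)."""
    keys = set(requested)
    if not keys:
        return 0.0
    return 100.0 * len(keys & resident) / len(keys)


def simulate_hit_rates(keysets: List[List[int]]) -> List[float]:
    """Load each pass's keys into an unbounded cache and record the hit rate."""
    resident: Set[int] = set()
    rates = []
    for keyset in keysets:
        rates.append(cache_hit_rate(resident, keyset))
        resident |= set(keyset)  # keys stay cached for later passes
    return rates


if __name__ == "__main__":
    print(simulate_hit_rates([[1, 2], [2, 3, 4, 5], [1, 5]]))  # [0.0, 25.0, 100.0]
```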
Continuous Training with the Low-level API
This section provides a code sample for continuous training using the low-level API. The program logic is the same as in the preceding code sample.
Although the low-level APIs provide fine-grained control of the training logic, we encourage you to use the high-level API when it satisfies your requirements, because driving the data readers and the embedding training cache by hand is not straightforward and is error prone.
For more about the low-level API, refer to the Low-level Training API documentation and the Low-level Training samples.
%%writefile wdl_etc.py
import hugectr
from mpi4py import MPI

solver = hugectr.CreateSolver(max_eval_batches = 5000,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = False,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                                  source = ["wdl_data/file_list."+str(i)+".txt" for i in range(2)],
                                  keyset = ["wdl_data/file_list."+str(i)+".keyset" for i in range(2)],
                                  eval_source = "wdl_data/file_list.2.txt",
                                  check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
hc_cnfg = hugectr.CreateHMemCache(num_blocks = 2, target_hit_rate = 0.5, max_num_evict = 0)
etc = hugectr.CreateETC(ps_types = [hugectr.TrainPSType_t.Staged, hugectr.TrainPSType_t.Cached],
                        sparse_models = ["./wdl_0_sparse_model", "./wdl_1_sparse_model"],
                        local_paths = ["./"], hmem_cache_configs = [hc_cnfg])
model = hugectr.Model(solver, reader, optimizer, etc)
# Rebuild the network from the graph dumped by wdl_train.py
model.construct_from_json(graph_config_file = "wdl.json", include_dense_network = True)
model.compile()

# Handles used by the manual training loop
lr_sch = model.get_learning_rate_scheduler()
data_reader_train = model.get_data_reader_train()
data_reader_eval = model.get_data_reader_eval()
etc = model.get_embedding_training_cache()

# First round: file lists 0 and 1 with their keysets
dataset = [("wdl_data/file_list."+str(i)+".txt", "wdl_data/file_list."+str(i)+".keyset") for i in range(2)]
data_reader_eval.set_source("wdl_data/file_list.2.txt")
data_reader_eval_flag = True
iteration = 0
for file_list, keyset_file in dataset:
    data_reader_train.set_source(file_list)
    data_reader_train_flag = True
    # Prepare the embedding table for the keys of this pass
    etc.update(keyset_file)
    while True:
        lr = lr_sch.get_next()
        model.set_learning_rate(lr)
        data_reader_train_flag = model.train()
        if not data_reader_train_flag:
            break  # the current file list is exhausted
        if iteration % 1000 == 0:
            batches = 0
            while data_reader_eval_flag:
                if batches >= solver.max_eval_batches:
                    break
                data_reader_eval_flag = model.eval()
                batches += 1
            if not data_reader_eval_flag:
                # Rewind the evaluation source once it is exhausted
                data_reader_eval.set_source()
                data_reader_eval_flag = True
            metrics = model.get_eval_metrics()
            print("[HUGECTR][INFO] iter: {}, metrics: {}".format(iteration, metrics))
        iteration += 1
    print("[HUGECTR][INFO] trained with data in {}".format(file_list))

# Second round: new data in file lists 3 and 4
dataset = [("wdl_data/file_list."+str(i)+".txt", "wdl_data/file_list."+str(i)+".keyset") for i in range(3, 5)]
for file_list, keyset_file in dataset:
    data_reader_train.set_source(file_list)
    data_reader_train_flag = True
    etc.update(keyset_file)
    while True:
        lr = lr_sch.get_next()
        model.set_learning_rate(lr)
        data_reader_train_flag = model.train()
        if not data_reader_train_flag:
            break
        if iteration % 1000 == 0:
            batches = 0
            while data_reader_eval_flag:
                if batches >= solver.max_eval_batches:
                    break
                data_reader_eval_flag = model.eval()
                batches += 1
            if not data_reader_eval_flag:
                data_reader_eval.set_source()
                data_reader_eval_flag = True
            metrics = model.get_eval_metrics()
            print("[HUGECTR][INFO] iter: {}, metrics: {}".format(iteration, metrics))
        iteration += 1
    print("[HUGECTR][INFO] trained with data in {}".format(file_list))

incremental_model = model.get_incremental_model()
model.save_params_to_files("wdl_etc")
Writing wdl_etc.py
!python3 wdl_etc.py
[HUGECTR][12:39:44][INFO][RANK0]: Empty embedding, trained table will be stored in ./wdl_0_sparse_model
[HUGECTR][12:39:44][INFO][RANK0]: Empty embedding, trained table will be stored in ./wdl_1_sparse_model
HugeCTR Version: 3.2
====================================================Model Init=====================================================
[HUGECTR][12:39:44][INFO][RANK0]: Global seed is 3498697826
[HUGECTR][12:39:44][INFO][RANK0]: Device to NUMA mapping:
GPU 0 -> node 0
[HUGECTR][12:39:45][WARNING][RANK0]: Peer-to-peer access cannot be fully enabled.
[HUGECTR][12:39:45][INFO][RANK0]: Start all2all warmup
[HUGECTR][12:39:45][INFO][RANK0]: End all2all warmup
[HUGECTR][12:39:45][INFO][RANK0]: Using All-reduce algorithm: NCCL
[HUGECTR][12:39:45][INFO][RANK0]: Device 0: Tesla V100-SXM2-32GB
[HUGECTR][12:39:45][INFO][RANK0]: num of DataReader workers: 12
[HUGECTR][12:39:45][INFO][RANK0]: max_num_frequent_categories is not specified using default: 1
[HUGECTR][12:39:45][INFO][RANK0]: max_num_infrequent_samples is not specified using default: -1
[HUGECTR][12:39:45][INFO][RANK0]: p_dup_max is not specified using default: 0.010000
[HUGECTR][12:39:45][INFO][RANK0]: max_all_reduce_bandwidth is not specified using default: 130000000000.000000
[HUGECTR][12:39:45][INFO][RANK0]: max_all_to_all_bandwidth is not specified using default: 190000000000.000000
[HUGECTR][12:39:45][INFO][RANK0]: efficiency_bandwidth_ratio is not specified using default: 1.000000
[HUGECTR][12:39:45][INFO][RANK0]: communication_type is not specified using default: IB_NVLink
[HUGECTR][12:39:45][INFO][RANK0]: hybrid_embedding_type is not specified using default: Distributed
[HUGECTR][12:39:45][INFO][RANK0]: max_vocabulary_size_per_gpu_=6029312
[HUGECTR][12:39:45][INFO][RANK0]: max_num_frequent_categories is not specified using default: 1
[HUGECTR][12:39:45][INFO][RANK0]: max_num_infrequent_samples is not specified using default: -1
[HUGECTR][12:39:45][INFO][RANK0]: p_dup_max is not specified using default: 0.010000
[HUGECTR][12:39:45][INFO][RANK0]: max_all_reduce_bandwidth is not specified using default: 130000000000.000000
[HUGECTR][12:39:45][INFO][RANK0]: max_all_to_all_bandwidth is not specified using default: 190000000000.000000
[HUGECTR][12:39:45][INFO][RANK0]: efficiency_bandwidth_ratio is not specified using default: 1.000000
[HUGECTR][12:39:45][INFO][RANK0]: communication_type is not specified using default: IB_NVLink
[HUGECTR][12:39:45][INFO][RANK0]: hybrid_embedding_type is not specified using default: Distributed
[HUGECTR][12:39:45][INFO][RANK0]: max_vocabulary_size_per_gpu_=5865472
[HUGECTR][12:39:45][INFO][RANK0]: Load the model graph from wdl.json successfully
[HUGECTR][12:39:45][INFO][RANK0]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HUGECTR][12:39:49][INFO][RANK0]: gpu0 start to init embedding
[HUGECTR][12:39:49][INFO][RANK0]: gpu0 init embedding done
[HUGECTR][12:39:49][INFO][RANK0]: gpu0 start to init embedding
[HUGECTR][12:39:49][INFO][RANK0]: gpu0 init embedding done
[HUGECTR][12:39:49][INFO][RANK0]: Enable HMEM-Based Parameter Server
[HUGECTR][12:39:49][INFO][RANK0]: ./wdl_0_sparse_model not exist, create and train from scratch
[HUGECTR][12:39:49][INFO][RANK0]: Enable HMemCache-Based Parameter Server
[HUGECTR][12:39:49][INFO][RANK0]: ./wdl_1_sparse_model/key doesn't exist, created
[HUGECTR][12:39:49][INFO][RANK0]: ./wdl_1_sparse_model/emb_vector doesn't exist, created
[HUGECTR][12:39:49][INFO][RANK0]: ./wdl_1_sparse_model/Adam.m doesn't exist, created
[HUGECTR][12:39:49][INFO][RANK0]: ./wdl_1_sparse_model/Adam.v doesn't exist, created
[HUGECTR][12:39:50][INFO][RANK0]: Starting AUC NCCL warm-up
[HUGECTR][12:39:50][INFO][RANK0]: Warm-up done
[HUGECTR][12:39:50][INFO][RANK0]: Preparing embedding table for next pass
[HUGECTR][12:39:50][INFO][RANK0]: HMEM-Cache PS: Hit rate [load]: 0 %
[HUGECTR][INFO] iter: 0, metrics: [('AUC', 0.4865134358406067)]
[HUGECTR][INFO] iter: 1000, metrics: [('AUC', 0.7405899167060852)]
[HUGECTR][INFO] iter: 2000, metrics: [('AUC', 0.7468112707138062)]
[HUGECTR][INFO] iter: 3000, metrics: [('AUC', 0.7530832290649414)]
[HUGECTR][INFO] trained with data in wdl_data/file_list.0.txt
[HUGECTR][12:40:28][INFO][RANK0]: Preparing embedding table for next pass
[HUGECTR][12:40:30][INFO][RANK0]: HMEM-Cache PS: Hit rate [dump]: 0 %
[HUGECTR][12:40:30][INFO][RANK0]: HMEM-Cache PS: Hit rate [load]: 0 %
[HUGECTR][INFO] iter: 4000, metrics: [('AUC', 0.7554274201393127)]
[HUGECTR][INFO] iter: 5000, metrics: [('AUC', 0.7563489079475403)]
[HUGECTR][INFO] iter: 6000, metrics: [('AUC', 0.7577884197235107)]
[HUGECTR][INFO] iter: 7000, metrics: [('AUC', 0.7599539160728455)]
[HUGECTR][INFO] trained with data in wdl_data/file_list.1.txt
[HUGECTR][12:41:08][INFO][RANK0]: Preparing embedding table for next pass
[HUGECTR][12:41:09][INFO][RANK0]: HMEM-Cache PS: Hit rate [dump]: 77.89 %
[HUGECTR][12:41:10][INFO][RANK0]: HMEM-Cache PS: Hit rate [load]: 71.22 %
[HUGECTR][INFO] iter: 8000, metrics: [('AUC', 0.7602559328079224)]
[HUGECTR][INFO] iter: 9000, metrics: [('AUC', 0.7596363425254822)]
[HUGECTR][INFO] iter: 10000, metrics: [('AUC', 0.7619153261184692)]
[HUGECTR][INFO] iter: 11000, metrics: [('AUC', 0.7607191801071167)]
[HUGECTR][INFO] trained with data in wdl_data/file_list.3.txt
[HUGECTR][12:41:48][INFO][RANK0]: Preparing embedding table for next pass
[HUGECTR][12:41:50][INFO][RANK0]: HMEM-Cache PS: Hit rate [dump]: 64.88 %
[HUGECTR][12:41:50][INFO][RANK0]: HMEM-Cache PS: Hit rate [load]: 67.35 %
[HUGECTR][INFO] iter: 12000, metrics: [('AUC', 0.763184666633606)]
[HUGECTR][INFO] iter: 13000, metrics: [('AUC', 0.7622747421264648)]
[HUGECTR][INFO] iter: 14000, metrics: [('AUC', 0.7623080015182495)]
[HUGECTR][INFO] iter: 15000, metrics: [('AUC', 0.7622851729393005)]
[HUGECTR][INFO] trained with data in wdl_data/file_list.4.txt
[HUGECTR][12:42:30][INFO][RANK0]: HMEM-Cache PS: Hit rate [dump]: 64.82 %
[HUGECTR][12:42:31][INFO][RANK0]: HMEM-Cache PS: Hit rate [load]: 63.85 %
[HUGECTR][12:42:31][INFO][RANK0]: Get updated portion of embedding table [DONE}
[HUGECTR][12:42:32][INFO][RANK0]: HMEM-Cache PS: Hit rate [dump]: 64.82 %
[HUGECTR][12:42:32][INFO][RANK0]: Updating sparse model in SSD [DONE]
[HUGECTR][12:42:34][INFO][RANK0]: Sync blocks from HMEM-Cache to SSD
████████████████████████████████████████▏ 100.0% [ 2/ 2 | 64.6 Hz | 0s<0s]
[HUGECTR][12:42:34][INFO][RANK0]: Dumping dense weights to file, successful
[HUGECTR][12:42:34][INFO][RANK0]: Dumping dense optimizer states to file, successful
[HUGECTR][12:42:34][INFO][RANK0]: Dumping untrainable weights to file, successful
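The two loops in `wdl_etc.py` differ only in the dataset list they iterate over. A possible refactoring, sketched below, moves the loop into one function that takes the objects returned by `model.get_learning_rate_scheduler()`, `model.get_data_reader_train()`, `model.get_data_reader_eval()`, and `model.get_embedding_training_cache()`. It only calls methods that `wdl_etc.py` already uses; the function name `train_on_datasets` and its parameters are hypothetical, and the sketch has not been run against a real HugeCTR installation.

```python
from typing import Iterable, Tuple


def train_on_datasets(model, lr_sch, data_reader_train, data_reader_eval, etc,
                      dataset: Iterable[Tuple[str, str]], max_eval_batches: int,
                      start_iteration: int = 0, eval_interval: int = 1000) -> int:
    """Run the manual ETC training loop over (file_list, keyset_file) pairs.

    Mirrors the loop body of wdl_etc.py and returns the updated iteration count.
    """
    iteration = start_iteration
    for file_list, keyset_file in dataset:
        data_reader_train.set_source(file_list)
        etc.update(keyset_file)  # prepare the embedding table for this pass
        while True:
            model.set_learning_rate(lr_sch.get_next())
            if not model.train():
                break  # this file list is exhausted
            if iteration % eval_interval == 0:
                batches, eval_flag = 0, True
                while eval_flag and batches < max_eval_batches:
                    eval_flag = model.eval()
                    batches += 1
                if not eval_flag:
                    data_reader_eval.set_source()  # rewind the evaluation data
                metrics = model.get_eval_metrics()
                print("[HUGECTR][INFO] iter: {}, metrics: {}".format(iteration, metrics))
            iteration += 1
        print("[HUGECTR][INFO] trained with data in {}".format(file_list))
    return iteration
```

In `wdl_etc.py`, after `data_reader_eval.set_source("wdl_data/file_list.2.txt")`, the first loop would become `iteration = train_on_datasets(model, lr_sch, data_reader_train, data_reader_eval, etc, dataset, solver.max_eval_batches)`, and the second loop the same call with the second dataset list and `start_iteration=iteration`. Resetting the evaluation flag inside each evaluation block is equivalent to the original, because the original flag is always `True` between evaluation blocks.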