 
Introduction to the HugeCTR Python Interface
Overview
HugeCTR version 3.1 introduces an enhanced Python interface. The interface supports continuous training and inference with high-level APIs. There are four main improvements.
- First, the model graph can be constructed and dumped to a JSON file with Python code, which saves users from writing JSON configuration files by hand. 
- Second, the API supports the embedding training cache feature with high-level APIs and extends it further for online training cases. (To learn about continuous training, see the example notebook.) 
- Third, the freezing method is provided for both sparse embedding and dense network. This method enables transfer learning and fine-tuning for CTR tasks. 
- Finally, the pre-trained embeddings in other formats can be converted to HugeCTR sparse models and then loaded to facilitate the training process. This is shown in the Load Pre-trained Embeddings section of this notebook. 
This notebook explains how to access and use the enhanced HugeCTR Python interface. Although the low-level training APIs are still maintained for users who want precise control of each training iteration, migrating to the high-level training APIs is strongly recommended. For more details on using the Python API, refer to the HugeCTR Python Interface documentation.
Installation
Get HugeCTR from NGC
The HugeCTR Python module is preinstalled in the 22.09 and later Merlin Training Container: nvcr.io/nvidia/merlin/merlin-hugectr:22.09.
You can verify that the HugeCTR Python module is available by running the following command after launching this container.
$ python3 -c "import hugectr"
If you prefer to build HugeCTR from the source code instead of using the NGC container, refer to the How to Start Your Development documentation.
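If you want the same availability check from inside a Python script (for example, before launching a long job), a minimal sketch is shown below. It relies only on the standard library's importlib.util.find_spec; the helper name hugectr_available is our own and not part of HugeCTR.

```python
import importlib.util


def hugectr_available() -> bool:
    """Return True if the top-level hugectr module can be found on sys.path."""
    # find_spec only locates a top-level module; it does not import it,
    # so no HugeCTR or CUDA initialization happens here.
    return importlib.util.find_spec("hugectr") is not None


if __name__ == "__main__":
    print("hugectr importable:", hugectr_available())
```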
DCN Model
Download and Preprocess Data
- Download the Criteo dataset using the following commands:
  $ cd ${project_root}/tools
  $ wget https://storage.googleapis.com/criteo-cail-datasets/day_1.gz
  In preprocessing, we further reduce the amount of data to speed up the preprocessing, fill missing values, remove feature values whose occurrences are very rare, and so on. Here we choose the pandas preprocessing method to make the dataset ready for HugeCTR training.
- Preprocess the data with pandas using the following command:
  $ bash preprocess.sh 1 dcn_data pandas 1 0
  The first argument is the dataset postfix; it is 1 here because day_1 is used. The second argument, dcn_data, is the directory where the preprocessed data is stored. The fourth argument (the one after pandas), 1, indicates that normalization is applied to the dense features. The last argument, 0, means that feature crossing is not applied.
- Create a soft link to the dataset folder using the following command:
  $ ln -s ${project_root}/tools/dcn_data ${project_root}/notebooks/dcn_data
Note: It will take a while (dozens of minutes) to preprocess the dataset. Please make sure that it is finished successfully before moving forward to the next section.
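Because preprocess.sh takes purely positional arguments, it can help to name them. The sketch below is a hypothetical helper, preprocess_cmd, that builds the argument list from named parameters, following the argument meanings described above; the optional sixth argument (number of data files per file list) is the one used later for the Wide and Deep model.

```python
import shlex
from typing import List, Optional


def preprocess_cmd(day: int, dst_dir: str, method: str = "pandas",
                   normalize_dense: bool = True, feature_cross: bool = False,
                   files_per_list: Optional[int] = None) -> List[str]:
    """Build the argument list for tools/preprocess.sh (hypothetical helper).

    Positional arguments of the script, as described in the text:
      1. dataset postfix (1 for day_1)
      2. output directory for the preprocessed data
      3. preprocessing method ("pandas")
      4. 1 to normalize dense features, 0 otherwise
      5. 1 to apply feature crossing, 0 otherwise
      6. optional: number of data files in each file list
    """
    args = ["bash", "preprocess.sh", str(day), dst_dir, method,
            str(int(normalize_dense)), str(int(feature_cross))]
    if files_per_list is not None:
        args.append(str(files_per_list))
    return args


dcn = preprocess_cmd(1, "dcn_data")
wdl = preprocess_cmd(1, "wdl_data", feature_cross=True, files_per_list=100)
print(shlex.join(dcn))  # bash preprocess.sh 1 dcn_data pandas 1 0
print(shlex.join(wdl))  # bash preprocess.sh 1 wdl_data pandas 1 1 100
```

To run the command rather than print it, you could pass the list to subprocess.run(cmd, cwd="tools", check=True) from ${project_root}.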
Train from Scratch
We can train from scratch, dump the model graph to a JSON file, and save the model weights and optimizer states by performing the following with Python APIs:
- Create the solver, reader and optimizer, then initialize the model. 
- Construct the model graph by adding input, sparse embedding and dense layers in order. 
- Compile the model and have an overview of the model graph. 
- Dump the model graph to the JSON file. 
- Fit the model; the model weights and optimizer states are saved implicitly at each snapshot. 
Please note that the training mode is determined by repeat_dataset within hugectr.CreateSolver.
If it is True, the non-epoch mode training is adopted and the maximum iterations should be specified by max_iter within hugectr.Model.fit.
If it is False, the epoch-mode training is adopted and the number of epochs should be specified by num_epochs within hugectr.Model.fit.
The optimizer that is used to initialize the model applies to the weights of dense layers, while the optimizer for each sparse embedding layer can be specified independently within hugectr.SparseEmbedding.
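The rule above (repeat_dataset selects between max_iter and num_epochs) is easy to get wrong when switching between models, as the Wide and Deep example below uses repeat_dataset = False while DCN uses True. The following sketch is a hypothetical helper, fit_kwargs, that picks the matching keyword for hugectr.Model.fit and fails early if it is missing; it only builds a dictionary and does not call HugeCTR itself.

```python
from typing import Optional


def fit_kwargs(repeat_dataset: bool, max_iter: Optional[int] = None,
               num_epochs: Optional[int] = None, **common) -> dict:
    """Return keyword arguments for hugectr.Model.fit matching the training mode.

    repeat_dataset=True  -> non-epoch mode, requires max_iter
    repeat_dataset=False -> epoch mode, requires num_epochs
    """
    if repeat_dataset:
        if max_iter is None:
            raise ValueError("non-epoch mode (repeat_dataset=True) needs max_iter")
        return {"max_iter": max_iter, **common}
    if num_epochs is None:
        raise ValueError("epoch mode (repeat_dataset=False) needs num_epochs")
    return {"num_epochs": num_epochs, **common}


# Same settings as the DCN training script below.
kwargs = fit_kwargs(True, max_iter=1200, display=500, eval_interval=100,
                    snapshot=1000, snapshot_prefix="dcn")
print(kwargs)
# With HugeCTR installed and a compiled model: model.fit(**kwargs)
```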
import hugectr
%%writefile dcn_train.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 1500,
                              batchsize_eval = 4096,
                              batchsize = 4096,
                              lr = 0.001,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = True,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                                  source = ["./dcn_data/file_list.txt"],
                                  eval_source = "./dcn_data/file_list_test.txt",
                                  check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.add(hugectr.Input(label_dim = 1, label_name = "label",
                        dense_dim = 13, dense_name = "dense",
                        data_reader_sparse_param_array = 
                        [hugectr.DataReaderSparseParam("data1", 2, False, 26)]))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 264,
                            embedding_vec_size = 16,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding1",
                            bottom_name = "data1",
                            optimizer = optimizer))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding1"],
                            top_names = ["reshape1"],
                            leading_dim=416))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                            bottom_names = ["reshape1", "dense"], top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.MultiCross,
                            bottom_names = ["concat1"],
                            top_names = ["multicross1"],
                            num_layers=6))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["concat1"],
                            top_names = ["fc1"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc1"],
                            top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu1"],
                            top_names = ["dropout1"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["dropout1"],
                            top_names = ["fc2"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc2"],
                            top_names = ["relu2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu2"],
                            top_names = ["dropout2"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                            bottom_names = ["dropout2", "multicross1"],
                            top_names = ["concat2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["concat2"],
                            top_names = ["fc3"],
                            num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
                            bottom_names = ["fc3", "label"],
                            top_names = ["loss"]))
model.compile()
model.summary()
model.graph_to_json(graph_config_file = "dcn.json")
model.fit(max_iter = 1200, display = 500, eval_interval = 100, snapshot = 1000, snapshot_prefix = "dcn")
Writing dcn_train.py
!python3 dcn_train.py
HugeCTR Version: 3.2
====================================================Model Init=====================================================
[HUGECTR][03:31:21][INFO][RANK0]: Global seed is 1645340130
[HUGECTR][03:31:22][INFO][RANK0]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HUGECTR][03:31:23][WARNING][RANK0]: Peer-to-peer access cannot be fully enabled.
[HUGECTR][03:31:23][INFO][RANK0]: Start all2all warmup
[HUGECTR][03:31:23][INFO][RANK0]: End all2all warmup
[HUGECTR][03:31:23][INFO][RANK0]: Using All-reduce algorithm: NCCL
[HUGECTR][03:31:23][INFO][RANK0]: Device 0: Tesla V100-SXM2-32GB
[HUGECTR][03:31:23][INFO][RANK0]: num of DataReader workers: 12
[HUGECTR][03:31:23][INFO][RANK0]: max_vocabulary_size_per_gpu_=1441792
[HUGECTR][03:31:23][INFO][RANK0]: Graph analysis to resolve tensor dependency
[HUGECTR][03:31:23][INFO][RANK0]: Add Slice layer for tensor: concat1, creating 2 copies
===================================================Model Compile===================================================
[HUGECTR][03:31:35][INFO][RANK0]: gpu0 start to init embedding
[HUGECTR][03:31:35][INFO][RANK0]: gpu0 init embedding done
[HUGECTR][03:31:35][INFO][RANK0]: Starting AUC NCCL warm-up
[HUGECTR][03:31:35][INFO][RANK0]: Warm-up done
===================================================Model Summary===================================================
label                                   Dense                         Sparse                        
label                                   dense                          data1                         
(None, 1)                               (None, 13)                              
------------------------------------------------------------------------------------------------------------------
Layer Type                              Input Name                    Output Name                   Output Shape                  
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash      data1                         sparse_embedding1             (None, 26, 16)                
Reshape                                 sparse_embedding1             reshape1                      (None, 416)                   
Concat                                  reshape1,dense                concat1                       (None, 429)                   
Slice                                   concat1                       concat1_slice0,concat1_slice1                               
MultiCross                              concat1_slice0                multicross1                   (None, 429)                   
InnerProduct                            concat1_slice1                fc1                           (None, 1024)                  
ReLU                                    fc1                           relu1                         (None, 1024)                  
Dropout                                 relu1                         dropout1                      (None, 1024)                  
InnerProduct                            dropout1                      fc2                           (None, 1024)                  
ReLU                                    fc2                           relu2                         (None, 1024)                  
Dropout                                 relu2                         dropout2                      (None, 1024)                  
Concat                                  dropout2,multicross1          concat2                       (None, 1453)                  
InnerProduct                            concat2                       fc3                           (None, 1)                     
BinaryCrossEntropyLoss                  fc3,label                     loss                                                        
------------------------------------------------------------------------------------------------------------------
[HUGECTR][03:31:35][INFO][RANK0]: Save the model graph to dcn.json successfully
=====================================================Model Fit=====================================================
[HUGECTR][03:31:35][INFO][RANK0]: Use non-epoch mode with number of iterations: 1200
[HUGECTR][03:31:35][INFO][RANK0]: Training batchsize: 4096, evaluation batchsize: 4096
[HUGECTR][03:31:35][INFO][RANK0]: Evaluation interval: 100, snapshot interval: 1000
[HUGECTR][03:31:35][INFO][RANK0]: Sparse embedding trainable: 1, dense network trainable: 1
[HUGECTR][03:31:35][INFO][RANK0]: Use mixed precision: 0, scaler: 1, use cuda graph: -510996182
[HUGECTR][03:31:35][INFO][RANK0]: lr: 0.001000, warmup_steps: 1, decay_start: 0, decay_steps: 1, decay_power: 2.000000, end_lr: 0.000000
[HUGECTR][03:31:35][INFO][RANK0]: Training source file: ./dcn_data/file_list.txt
[HUGECTR][03:31:35][INFO][RANK0]: Evaluation source file: ./dcn_data/file_list_test.txt
[HUGECTR][03:31:39][INFO][RANK0]: Evaluation, AUC: 0.717558
[HUGECTR][03:31:39][INFO][RANK0]: Eval Time for 1500 iters: 2.957773s
[HUGECTR][03:31:43][INFO][RANK0]: Evaluation, AUC: 0.735452
[HUGECTR][03:31:43][INFO][RANK0]: Eval Time for 1500 iters: 2.963541s
[HUGECTR][03:31:47][INFO][RANK0]: Evaluation, AUC: 0.741079
[HUGECTR][03:31:47][INFO][RANK0]: Eval Time for 1500 iters: 2.959102s
[HUGECTR][03:31:50][INFO][RANK0]: Evaluation, AUC: 0.745329
[HUGECTR][03:31:50][INFO][RANK0]: Eval Time for 1500 iters: 2.964232s
[HUGECTR][03:31:51][INFO][RANK0]: Iter: 500 Time(500 iters): 15.479323s Loss: 0.117504 lr:0.001000
[HUGECTR][03:31:54][INFO][RANK0]: Evaluation, AUC: 0.749935
[HUGECTR][03:31:54][INFO][RANK0]: Eval Time for 1500 iters: 2.961690s
[HUGECTR][03:31:58][INFO][RANK0]: Evaluation, AUC: 0.750517
[HUGECTR][03:31:58][INFO][RANK0]: Eval Time for 1500 iters: 2.963790s
[HUGECTR][03:32:01][INFO][RANK0]: Evaluation, AUC: 0.754112
[HUGECTR][03:32:01][INFO][RANK0]: Eval Time for 1500 iters: 2.965818s
[HUGECTR][03:32:05][INFO][RANK0]: Evaluation, AUC: 0.755083
[HUGECTR][03:32:05][INFO][RANK0]: Eval Time for 1500 iters: 2.962515s
[HUGECTR][03:32:09][INFO][RANK0]: Evaluation, AUC: 0.755834
[HUGECTR][03:32:09][INFO][RANK0]: Eval Time for 1500 iters: 2.967796s
[HUGECTR][03:32:09][INFO][RANK0]: Iter: 1000 Time(500 iters): 18.362356s Loss: 0.154462 lr:0.001000
[HUGECTR][03:32:12][INFO][RANK0]: Evaluation, AUC: 0.758410
[HUGECTR][03:32:12][INFO][RANK0]: Eval Time for 1500 iters: 2.969008s
[HUGECTR][03:32:12][INFO][RANK0]: Rank0: Write hash table to file
[HUGECTR][03:32:13][INFO][RANK0]: Dumping sparse weights to files, successful
[HUGECTR][03:32:13][INFO][RANK0]: Rank0: Write optimzer state to file
[HUGECTR][03:32:13][INFO][RANK0]: Done
[HUGECTR][03:32:13][INFO][RANK0]: Rank0: Write optimzer state to file
[HUGECTR][03:32:13][INFO][RANK0]: Done
[HUGECTR][03:32:15][INFO][RANK0]: Dumping sparse optimzer states to files, successful
[HUGECTR][03:32:15][INFO][RANK0]: Dumping dense weights to file, successful
[HUGECTR][03:32:15][INFO][RANK0]: Dumping dense optimizer states to file, successful
[HUGECTR][03:32:15][INFO][RANK0]: Dumping untrainable weights to file, successful
[HUGECTR][03:32:19][INFO][RANK0]: Evaluation, AUC: 0.758818
[HUGECTR][03:32:19][INFO][RANK0]: Eval Time for 1500 iters: 2.966687s
[HUGECTR][03:32:20][INFO][RANK0]: Finish 1200 iterations with batchsize: 4096 in 44.08s.
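Two numbers in the output above can be reproduced by hand: the output shapes in the model summary, and the max_vocabulary_size_per_gpu_ log line. (The Slice layer only duplicates concat1 because it feeds two layers; it does not change its width.) The sketch below redoes that arithmetic in plain Python. The bytes-per-key rule in max_vocab_per_gpu (embedding vector plus two Adam state vectors, all FP32) is our assumption, chosen because it reproduces the logged value of 1441792 for workspace_size_per_gpu_in_mb = 264; check the HugeCTR documentation before applying it to other optimizers.

```python
# Shapes from the DCN graph definition.
slot_num, vec_size, dense_dim = 26, 16, 13
reshape1 = slot_num * vec_size      # Reshape of (None, 26, 16) -> leading_dim 416
concat1 = reshape1 + dense_dim      # Concat(reshape1, dense)
multicross1 = concat1               # MultiCross keeps its input width
dropout2 = 1024                     # fc2 num_output; ReLU and Dropout keep the width
concat2 = dropout2 + multicross1    # Concat(dropout2, multicross1)
print(reshape1, concat1, concat2)   # 416 429 1453


def max_vocab_per_gpu(workspace_mb: int, vec_size: int, opt_state_copies: int) -> int:
    """Keys that fit in the embedding workspace (assumed FP32 storage).

    Assumption: each key stores its embedding vector plus opt_state_copies
    optimizer vectors of the same size (2 for Adam: first and second moments).
    """
    bytes_per_key = vec_size * 4 * (1 + opt_state_copies)
    return workspace_mb * 1024 * 1024 // bytes_per_key


print(max_vocab_per_gpu(264, 16, 2))  # 1441792
```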
Continue Training
We can continue our training based on the saved model graph, model weights, and optimizer states by performing the following with Python APIs:
- Create the solver, reader and optimizer, then initialize the model. 
- Construct the model graph from the saved JSON file. 
- Compile the model and have an overview of the model graph. 
- Load the model weights and optimizer states. 
- Fit the model; the model weights and optimizer states are saved implicitly at each snapshot. 
!ls *.model
dcn0_opt_sparse_1000.model  dcn_dense_1000.model  dcn_opt_dense_1000.model
dcn0_sparse_1000.model:
emb_vector  key
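The listing shows the naming pattern of the snapshot written at iteration 1000 with snapshot_prefix = "dcn": one dense weight file, one dense optimizer file, and, per embedding table, a sparse weight folder and a sparse optimizer file carrying an index (the 0 in dcn0_...). The hypothetical helper snapshot_files below rebuilds these names so that the load calls in the next script need not hard-code them. That the index counts embedding tables in the order they were added is our assumption; with a single table this output cannot confirm it.

```python
def snapshot_files(prefix: str, iteration: int, num_sparse: int = 1) -> dict:
    """Names of the snapshot files for a given prefix and iteration.

    Pattern inferred from the directory listing (hypothetical helper):
      <prefix>_dense_<iter>.model, <prefix>_opt_dense_<iter>.model,
      <prefix><i>_sparse_<iter>.model, <prefix><i>_opt_sparse_<iter>.model
    """
    return {
        "dense": f"{prefix}_dense_{iteration}.model",
        "dense_opt": f"{prefix}_opt_dense_{iteration}.model",
        "sparse": [f"{prefix}{i}_sparse_{iteration}.model" for i in range(num_sparse)],
        "sparse_opt": [f"{prefix}{i}_opt_sparse_{iteration}.model" for i in range(num_sparse)],
    }


files = snapshot_files("dcn", 1000)
print(files)
# With HugeCTR and a compiled model:
#   model.load_dense_weights(files["dense"])
#   model.load_sparse_weights(files["sparse"])
#   model.load_dense_optimizer_states(files["dense_opt"])
#   model.load_sparse_optimizer_states(files["sparse_opt"])
```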
%%writefile dcn_continue.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 1500,
                              batchsize_eval = 4096,
                              batchsize = 4096,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = True,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                                  source = ["./dcn_data/file_list.txt"],
                                  eval_source = "./dcn_data/file_list_test.txt",
                                  check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.construct_from_json(graph_config_file = "dcn.json", include_dense_network = True)
model.compile()
model.load_dense_weights("dcn_dense_1000.model")
model.load_sparse_weights(["dcn0_sparse_1000.model"])
model.load_dense_optimizer_states("dcn_opt_dense_1000.model")
model.load_sparse_optimizer_states(["dcn0_opt_sparse_1000.model"])
model.summary()
model.fit(max_iter = 500, display = 50, eval_interval = 100, snapshot = 10000, snapshot_prefix = "dcn")
Writing dcn_continue.py
!python3 dcn_continue.py
HugeCTR Version: 3.2
====================================================Model Init=====================================================
[HUGECTR][03:32:48][INFO][RANK0]: Global seed is 4147354758
[HUGECTR][03:32:49][INFO][RANK0]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HUGECTR][03:32:50][WARNING][RANK0]: Peer-to-peer access cannot be fully enabled.
[HUGECTR][03:32:50][INFO][RANK0]: Start all2all warmup
[HUGECTR][03:32:50][INFO][RANK0]: End all2all warmup
[HUGECTR][03:32:50][INFO][RANK0]: Using All-reduce algorithm: NCCL
[HUGECTR][03:32:50][INFO][RANK0]: Device 0: Tesla V100-SXM2-32GB
[HUGECTR][03:32:50][INFO][RANK0]: num of DataReader workers: 12
[HUGECTR][03:32:50][INFO][RANK0]: max_num_frequent_categories is not specified using default: 1
[HUGECTR][03:32:50][INFO][RANK0]: max_num_infrequent_samples is not specified using default: -1
[HUGECTR][03:32:50][INFO][RANK0]: p_dup_max is not specified using default: 0.010000
[HUGECTR][03:32:50][INFO][RANK0]: max_all_reduce_bandwidth is not specified using default: 130000000000.000000
[HUGECTR][03:32:50][INFO][RANK0]: max_all_to_all_bandwidth is not specified using default: 190000000000.000000
[HUGECTR][03:32:50][INFO][RANK0]: efficiency_bandwidth_ratio is not specified using default: 1.000000
[HUGECTR][03:32:50][INFO][RANK0]: communication_type is not specified using default: IB_NVLink
[HUGECTR][03:32:50][INFO][RANK0]: hybrid_embedding_type is not specified using default: Distributed
[HUGECTR][03:32:50][INFO][RANK0]: max_vocabulary_size_per_gpu_=1441792
[HUGECTR][03:32:50][INFO][RANK0]: Load the model graph from dcn.json successfully
[HUGECTR][03:32:50][INFO][RANK0]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HUGECTR][03:33:02][INFO][RANK0]: gpu0 start to init embedding
[HUGECTR][03:33:02][INFO][RANK0]: gpu0 init embedding done
[HUGECTR][03:33:02][INFO][RANK0]: Starting AUC NCCL warm-up
[HUGECTR][03:33:02][INFO][RANK0]: Warm-up done
[HUGECTR][03:33:02][INFO][RANK0]: Loading dense model: dcn_dense_1000.model
[HUGECTR][03:33:02][INFO][RANK0]: Loading sparse model: dcn0_sparse_1000.model
[HUGECTR][03:33:03][INFO][RANK0]: Loading dense opt states: dcn_opt_dense_1000.model
[HUGECTR][03:33:03][INFO][RANK0]: Loading sparse optimizer states: dcn0_opt_sparse_1000.model
[HUGECTR][03:33:03][INFO][RANK0]: Rank0: Read optimzer state from file
[HUGECTR][03:33:04][INFO][RANK0]: Done
[HUGECTR][03:33:04][INFO][RANK0]: Rank0: Read optimzer state from file
[HUGECTR][03:33:05][INFO][RANK0]: Done
===================================================Model Summary===================================================
label                                   Dense                         Sparse                        
label                                   dense                          data1                         
(None, 1)                               (None, 13)                              
------------------------------------------------------------------------------------------------------------------
Layer Type                              Input Name                    Output Name                   Output Shape                  
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash      data1                         sparse_embedding1             (None, 26, 16)                
Reshape                                 sparse_embedding1             reshape1                      (None, 416)                   
Concat                                  reshape1,dense                concat1                       (None, 429)                   
Slice                                   concat1                       concat1_slice0,concat1_slice1                               
MultiCross                              concat1_slice0                multicross1                   (None, 429)                   
InnerProduct                            concat1_slice1                fc1                           (None, 1024)                  
ReLU                                    fc1                           relu1                         (None, 1024)                  
Dropout                                 relu1                         dropout1                      (None, 1024)                  
InnerProduct                            dropout1                      fc2                           (None, 1024)                  
ReLU                                    fc2                           relu2                         (None, 1024)                  
Dropout                                 relu2                         dropout2                      (None, 1024)                  
Concat                                  dropout2,multicross1          concat2                       (None, 1453)                  
InnerProduct                            concat2                       fc3                           (None, 1)                     
BinaryCrossEntropyLoss                  fc3,label                     loss                                                        
------------------------------------------------------------------------------------------------------------------
=====================================================Model Fit=====================================================
[HUGECTR][03:33:05][INFO][RANK0]: Use non-epoch mode with number of iterations: 500
[HUGECTR][03:33:05][INFO][RANK0]: Training batchsize: 4096, evaluation batchsize: 4096
[HUGECTR][03:33:05][INFO][RANK0]: Evaluation interval: 100, snapshot interval: 10000
[HUGECTR][03:33:05][INFO][RANK0]: Sparse embedding trainable: 1, dense network trainable: 1
[HUGECTR][03:33:05][INFO][RANK0]: Use mixed precision: 0, scaler: 1, use cuda graph: 1946517802
[HUGECTR][03:33:05][INFO][RANK0]: lr: 0.001000, warmup_steps: 1, decay_start: 0, decay_steps: 1, decay_power: 2.000000, end_lr: 0.000000
[HUGECTR][03:33:05][INFO][RANK0]: Training source file: ./dcn_data/file_list.txt
[HUGECTR][03:33:05][INFO][RANK0]: Evaluation source file: ./dcn_data/file_list_test.txt
[HUGECTR][03:33:06][INFO][RANK0]: Iter: 50 Time(50 iters): 0.451251s Loss: 0.106090 lr:0.001000
[HUGECTR][03:33:06][INFO][RANK0]: Iter: 100 Time(50 iters): 0.351385s Loss: 0.128124 lr:0.001000
[HUGECTR][03:33:09][INFO][RANK0]: Evaluation, AUC: 0.741880
[HUGECTR][03:33:09][INFO][RANK0]: Eval Time for 1500 iters: 2.972658s
[HUGECTR][03:33:10][INFO][RANK0]: Iter: 150 Time(50 iters): 3.329078s Loss: 0.128845 lr:0.001000
[HUGECTR][03:33:10][INFO][RANK0]: Iter: 200 Time(50 iters): 0.351025s Loss: 0.128085 lr:0.001000
[HUGECTR][03:33:13][INFO][RANK0]: Evaluation, AUC: 0.730338
[HUGECTR][03:33:13][INFO][RANK0]: Eval Time for 1500 iters: 2.973015s
[HUGECTR][03:33:13][INFO][RANK0]: Iter: 250 Time(50 iters): 3.329451s Loss: 0.114888 lr:0.001000
[HUGECTR][03:33:14][INFO][RANK0]: Iter: 300 Time(50 iters): 0.350937s Loss: 0.106827 lr:0.001000
[HUGECTR][03:33:17][INFO][RANK0]: Evaluation, AUC: 0.728633
[HUGECTR][03:33:17][INFO][RANK0]: Eval Time for 1500 iters: 2.972206s
[HUGECTR][03:33:17][INFO][RANK0]: Iter: 350 Time(50 iters): 3.328661s Loss: 0.116533 lr:0.001000
[HUGECTR][03:33:17][INFO][RANK0]: Iter: 400 Time(50 iters): 0.351332s Loss: 0.110059 lr:0.001000
[HUGECTR][03:33:20][INFO][RANK0]: Evaluation, AUC: 0.726629
[HUGECTR][03:33:20][INFO][RANK0]: Eval Time for 1500 iters: 2.970395s
[HUGECTR][03:33:21][INFO][RANK0]: Iter: 450 Time(50 iters): 3.327333s Loss: 0.117730 lr:0.001000
[HUGECTR][03:33:21][INFO][RANK0]: Finish 500 iterations with batchsize: 4096 in 15.53s.
Inference
HugeCTR inference is performed with the predict method of hugectr.inference.InferenceSession.
This method takes dense features, embedding columns, and the row pointers of slots as input and returns the prediction results as output.
We need to convert the Criteo data to inference format first.
!python3 ../tools/criteo_predict/criteo2predict.py --src_csv_path=dcn_data/val/test.txt --src_config=../tools/criteo_predict/dcn_data.json --dst_path=./dcn_csr.txt --batch_size=1024
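The embedding columns and row pointers form a compressed sparse row (CSR) layout: all keys are concatenated into one flat list, and the row pointers mark where each row's keys start and end. The sketch below, a hypothetical helper named to_csr, assumes one CSR row per (sample, slot) pair, so a batch of N samples with S slots yields N * S + 1 row pointers. It is meant to illustrate the format; the exact contents of dcn_csr.txt produced by criteo2predict.py may differ in detail.

```python
from typing import List, Sequence, Tuple


def to_csr(samples: Sequence[Sequence[Sequence[int]]]) -> Tuple[List[int], List[int]]:
    """Flatten per-sample, per-slot key lists into (embedding_columns, row_ptrs).

    samples[s][k] is the list of keys of slot k in sample s. Each
    (sample, slot) pair becomes one CSR row: its keys are
    embedding_columns[row_ptrs[r]:row_ptrs[r + 1]].
    """
    embedding_columns: List[int] = []
    row_ptrs = [0]
    for sample in samples:
        for slot_keys in sample:
            embedding_columns.extend(slot_keys)
            row_ptrs.append(len(embedding_columns))
    return embedding_columns, row_ptrs


# Two samples with three slots each; slot 1 of the second sample has two keys.
cols, ptrs = to_csr([[[11], [23], [35]], [[12], [24, 25], [36]]])
print(cols)  # [11, 23, 35, 12, 24, 25, 36]
print(ptrs)  # [0, 1, 2, 3, 4, 6, 7]
```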
We can then make inferences based on the saved model graph and model weights by performing the following with Python APIs:
- Configure the inference related parameters. 
- Create the inference session. 
- Make inference with the InferenceSession.predict method.
%%writefile dcn_inference.py
from hugectr.inference import InferenceParams, CreateInferenceSession
from mpi4py import MPI
def calculate_accuracy(labels, output):
    num_samples = len(labels)
    flags = [1 if ((labels[i] == 0 and output[i] <= 0.5) or (labels[i] == 1 and output[i] > 0.5)) else 0 for i in range(num_samples)]
    correct_samples = sum(flags)
    return float(correct_samples)/(float(num_samples)+1e-16)
config_file = "dcn.json"
# The first four lines of dcn_csr.txt hold the labels, dense features,
# embedding columns (keys), and row pointers, separated by whitespace.
with open("dcn_csr.txt") as data_file:
    labels = [int(item) for item in data_file.readline().split()]
    dense_features = [float(item) for item in data_file.readline().split()]
    embedding_columns = [int(item) for item in data_file.readline().split()]
    row_ptrs = [int(item) for item in data_file.readline().split()]
# create parameter server, embedding cache and inference session
inference_params = InferenceParams(model_name = "dcn",
                                max_batchsize = 1024,
                                hit_rate_threshold = 0.6,
                                dense_model_file = "./dcn_dense_1000.model",
                                sparse_model_files = ["./dcn0_sparse_1000.model"],
                                device_id = 0,
                                use_gpu_embedding_cache = True,
                                cache_size_percentage = 0.9,
                                i64_input_key = False,
                                use_mixed_precision = False)
inference_session = CreateInferenceSession(config_file, inference_params)
output = inference_session.predict(dense_features, embedding_columns, row_ptrs)
accuracy = calculate_accuracy(labels, output)
print("[HUGECTR][INFO] number samples: {}, accuracy: {}".format(len(labels), accuracy))
Writing dcn_inference.py
!python3 dcn_inference.py
[06d04h52m24s][HUGECTR][INFO]: default_emb_vec_value is not specified using default: 0.000000
[06d04h52m26s][HUGECTR][INFO]: Global seed is 3956797427
[06d04h52m28s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.
[06d04h52m28s][HUGECTR][INFO]: Start all2all warmup
[06d04h52m28s][HUGECTR][INFO]: End all2all warmup
[06d04h52m28s][HUGECTR][INFO]: Use mixed precision: 0
[06d04h52m28s][HUGECTR][INFO]: start create embedding for inference
[06d04h52m28s][HUGECTR][INFO]: sparse_input name data1
[06d04h52m28s][HUGECTR][INFO]: create embedding for inference success
[06d04h52m28s][HUGECTR][INFO]: Inference stage skip BinaryCrossEntropyLoss layer, replaced by Sigmoid layer
[HUGECTR][INFO] number samples: 1024, accuracy: 0.96875
Wide and Deep Model
Download and Preprocess Data
- Download the Criteo dataset using the following commands:
  $ cd ${project_root}/tools
  $ wget https://storage.googleapis.com/criteo-cail-datasets/day_1.gz
  In preprocessing, we further reduce the amount of data to speed up the preprocessing, fill missing values, remove feature values whose occurrences are very rare, and so on. Here we choose the pandas preprocessing method to make the dataset ready for HugeCTR training.
- Preprocess the data with pandas using the following command:
  $ bash preprocess.sh 1 wdl_data pandas 1 1 100
  The first argument is the dataset postfix; it is 1 here because day_1 is used. The second argument, wdl_data, is the directory where the preprocessed data is stored. The fourth argument (the one after pandas), 1, indicates that normalization is applied to the dense features. The fifth argument, 1, means that feature crossing is applied. The last argument, 100, is the number of data files in each file list.
- Create a soft link to the dataset folder using the following command:
  $ ln -s ${project_root}/tools/wdl_data ${project_root}/notebooks/wdl_data
Note: It will take a while (dozens of minutes) to preprocess the dataset. Please make sure that it is finished successfully before moving forward to the next section.
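The generated file lists (for example wdl_data/file_list.0.txt and wdl_data/file_list.1.txt) are what DataReaderParams reads through source and eval_source. The sketch below assumes the layout we understand HugeCTR to use for Norm-format file lists: a first line with the number of data files, followed by one path per line. read_file_list is a hypothetical helper for sanity-checking a list before training; it raises an error if the declared count and the listed paths disagree.

```python
from pathlib import Path
from typing import List


def read_file_list(path: str) -> List[str]:
    """Parse a file list assumed to contain a file count followed by one path per line."""
    lines = [ln.strip() for ln in Path(path).read_text().splitlines() if ln.strip()]
    if not lines:
        raise ValueError(f"{path} is empty")
    expected = int(lines[0])  # first line: number of data files (assumed layout)
    files = lines[1:]
    if len(files) != expected:
        raise ValueError(f"{path} declares {expected} files but lists {len(files)}")
    return files
```

For example, read_file_list("wdl_data/file_list.0.txt") returns the data file paths of the training list, so you can check how many files it references before calling model.fit.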
Train from Scratch
We can train from scratch, dump the model graph to a JSON file, and save the model weights and optimizer states by performing the same steps that we followed with the DCN model.
%%writefile wdl_train.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 5000,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              lr = 0.001,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = False,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                          source = ["wdl_data/file_list.0.txt"],
                          eval_source = "wdl_data/file_list.1.txt",
                          check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.add(hugectr.Input(label_dim = 1, label_name = "label",
                        dense_dim = 13, dense_name = "dense",
                        data_reader_sparse_param_array = 
                        [hugectr.DataReaderSparseParam("wide_data", 30, True, 1),
                        hugectr.DataReaderSparseParam("deep_data", 2, False, 26)]))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 69,
                            embedding_vec_size = 1,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding2",
                            bottom_name = "wide_data",
                            optimizer = optimizer))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 1074,
                            embedding_vec_size = 16,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding1",
                            bottom_name = "deep_data",
                            optimizer = optimizer))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding1"],
                            top_names = ["reshape1"],
                            leading_dim=416))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding2"],
                            top_names = ["reshape2"],
                            leading_dim=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                            bottom_names = ["reshape1", "dense"], top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["concat1"],
                            top_names = ["fc1"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc1"],
                            top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu1"],
                            top_names = ["dropout1"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["dropout1"],
                            top_names = ["fc2"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc2"],
                            top_names = ["relu2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu2"],
                            top_names = ["dropout2"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["dropout2"],
                            top_names = ["fc3"],
                            num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Add,
                            bottom_names = ["fc3", "reshape2"],
                            top_names = ["add1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
                            bottom_names = ["add1", "label"],
                            top_names = ["loss"]))
model.compile()
model.summary()
model.graph_to_json(graph_config_file = "wdl.json")
model.fit(num_epochs = 1, display = 500, eval_interval = 500, snapshot = 4000, snapshot_prefix = "wdl")
Writing wdl_train.py
!python3 wdl_train.py
HugeCTR Version: 3.2
====================================================Model Init=====================================================
[HUGECTR][07:13:04][INFO][RANK0]: Global seed is 1910256490
[HUGECTR][07:13:04][INFO][RANK0]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HUGECTR][07:13:06][WARNING][RANK0]: Peer-to-peer access cannot be fully enabled.
[HUGECTR][07:13:06][INFO][RANK0]: Start all2all warmup
[HUGECTR][07:13:06][INFO][RANK0]: End all2all warmup
[HUGECTR][07:13:06][INFO][RANK0]: Using All-reduce algorithm: NCCL
[HUGECTR][07:13:06][INFO][RANK0]: Device 0: Tesla V100-SXM2-16GB
[HUGECTR][07:13:06][INFO][RANK0]: num of DataReader workers: 12
[HUGECTR][07:13:06][INFO][RANK0]: max_vocabulary_size_per_gpu_=6029312
[HUGECTR][07:13:06][INFO][RANK0]: max_vocabulary_size_per_gpu_=5865472
[HUGECTR][07:13:06][INFO][RANK0]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HUGECTR][07:13:09][INFO][RANK0]: gpu0 start to init embedding
[HUGECTR][07:13:09][INFO][RANK0]: gpu0 init embedding done
[HUGECTR][07:13:09][INFO][RANK0]: gpu0 start to init embedding
[HUGECTR][07:13:09][INFO][RANK0]: gpu0 init embedding done
[HUGECTR][07:13:09][INFO][RANK0]: Starting AUC NCCL warm-up
[HUGECTR][07:13:09][INFO][RANK0]: Warm-up done
===================================================Model Summary===================================================
label                                   Dense                         Sparse                        
label                                   dense                          wide_data,deep_data           
(None, 1)                               (None, 13)                              
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type                              Input Name                    Output Name                   Output Shape                  
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash      wide_data                     sparse_embedding2             (None, 1, 1)                  
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash      deep_data                     sparse_embedding1             (None, 26, 16)                
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding1             reshape1                      (None, 416)                   
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding2             reshape2                      (None, 1)                     
------------------------------------------------------------------------------------------------------------------
Concat                                  reshape1                      concat1                       (None, 429)                   
                                        dense                                                                                     
------------------------------------------------------------------------------------------------------------------
InnerProduct                            concat1                       fc1                           (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc1                           relu1                         (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
Dropout                                 relu1                         dropout1                      (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
InnerProduct                            dropout1                      fc2                           (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc2                           relu2                         (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
Dropout                                 relu2                         dropout2                      (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
InnerProduct                            dropout2                      fc3                           (None, 1)                     
------------------------------------------------------------------------------------------------------------------
Add                                     fc3                           add1                          (None, 1)                     
                                        reshape2                                                                                  
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss                  add1                          loss                                                        
                                        label                                                                                     
------------------------------------------------------------------------------------------------------------------
[HUGECTR][07:13:09][INFO][RANK0]: Save the model graph to wdl.json successfully
=====================================================Model Fit=====================================================
[HUGECTR][07:13:09][INFO][RANK0]: Use epoch mode with number of epochs: 1
[HUGECTR][07:13:09][INFO][RANK0]: Training batchsize: 1024, evaluation batchsize: 1024
[HUGECTR][07:13:09][INFO][RANK0]: Evaluation interval: 500, snapshot interval: 4000
[HUGECTR][07:13:09][INFO][RANK0]: Dense network trainable: True
[HUGECTR][07:13:09][INFO][RANK0]: Sparse embedding sparse_embedding1 trainable: True
[HUGECTR][07:13:09][INFO][RANK0]: Sparse embedding sparse_embedding2 trainable: True
[HUGECTR][07:13:09][INFO][RANK0]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HUGECTR][07:13:09][INFO][RANK0]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HUGECTR][07:13:09][INFO][RANK0]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HUGECTR][07:13:09][INFO][RANK0]: Training source file: wdl_data/file_list.0.txt
[HUGECTR][07:13:09][INFO][RANK0]: Evaluation source file: wdl_data/file_list.1.txt
[HUGECTR][07:13:09][INFO][RANK0]: -----------------------------------Epoch 0-----------------------------------
[HUGECTR][07:13:12][INFO][RANK0]: Iter: 500 Time(500 iters): 2.633535s Loss: 0.146090 lr:0.001000
[HUGECTR][07:13:24][INFO][RANK0]: Evaluation, AUC: 0.729127
[HUGECTR][07:13:24][INFO][RANK0]: Eval Time for 5000 iters: 12.329920s
[HUGECTR][07:13:27][INFO][RANK0]: Iter: 1000 Time(500 iters): 14.758588s Loss: 0.133614 lr:0.001000
[HUGECTR][07:13:29][INFO][RANK0]: Evaluation, AUC: 0.739283
[HUGECTR][07:13:29][INFO][RANK0]: Eval Time for 5000 iters: 1.892544s
[HUGECTR][07:13:31][INFO][RANK0]: Iter: 1500 Time(500 iters): 4.309439s Loss: 0.145478 lr:0.001000
[HUGECTR][07:13:33][INFO][RANK0]: Evaluation, AUC: 0.744546
[HUGECTR][07:13:33][INFO][RANK0]: Eval Time for 5000 iters: 1.888380s
[HUGECTR][07:13:35][INFO][RANK0]: Iter: 2000 Time(500 iters): 4.322345s Loss: 0.142099 lr:0.001000
[HUGECTR][07:13:37][INFO][RANK0]: Evaluation, AUC: 0.748392
[HUGECTR][07:13:37][INFO][RANK0]: Eval Time for 5000 iters: 1.894575s
[HUGECTR][07:13:40][INFO][RANK0]: Iter: 2500 Time(500 iters): 4.354580s Loss: 0.167694 lr:0.001000
[HUGECTR][07:13:42][INFO][RANK0]: Evaluation, AUC: 0.748089
[HUGECTR][07:13:42][INFO][RANK0]: Eval Time for 5000 iters: 1.853501s
[HUGECTR][07:13:44][INFO][RANK0]: Iter: 3000 Time(500 iters): 4.269730s Loss: 0.124279 lr:0.001000
[HUGECTR][07:13:46][INFO][RANK0]: Evaluation, AUC: 0.753290
[HUGECTR][07:13:46][INFO][RANK0]: Eval Time for 5000 iters: 1.906998s
[HUGECTR][07:13:48][INFO][RANK0]: Iter: 3500 Time(500 iters): 4.328614s Loss: 0.114806 lr:0.001000
[HUGECTR][07:13:50][INFO][RANK0]: Evaluation, AUC: 0.755007
[HUGECTR][07:13:50][INFO][RANK0]: Eval Time for 5000 iters: 1.897527s
[HUGECTR][07:13:53][INFO][RANK0]: Iter: 4000 Time(500 iters): 4.308839s Loss: 0.128652 lr:0.001000
[HUGECTR][07:13:55][INFO][RANK0]: Evaluation, AUC: 0.756323
[HUGECTR][07:13:55][INFO][RANK0]: Eval Time for 5000 iters: 1.849973s
[HUGECTR][07:13:55][INFO][RANK0]: Rank0: Write hash table to file
[HUGECTR][07:13:55][INFO][RANK0]: Rank0: Write hash table to file
[HUGECTR][07:13:55][INFO][RANK0]: Dumping sparse weights to files, successful
[HUGECTR][07:13:56][INFO][RANK0]: Rank0: Write optimzer state to file
[HUGECTR][07:13:56][INFO][RANK0]: Done
[HUGECTR][07:13:56][INFO][RANK0]: Rank0: Write optimzer state to file
[HUGECTR][07:13:56][INFO][RANK0]: Done
[HUGECTR][07:13:56][INFO][RANK0]: Rank0: Write optimzer state to file
[HUGECTR][07:13:56][INFO][RANK0]: Done
[HUGECTR][07:13:57][INFO][RANK0]: Rank0: Write optimzer state to file
[HUGECTR][07:13:57][INFO][RANK0]: Done
[HUGECTR][07:14:04][INFO][RANK0]: Dumping sparse optimzer states to files, successful
[HUGECTR][07:14:04][INFO][RANK0]: Dumping dense weights to file, successful
[HUGECTR][07:14:04][INFO][RANK0]: Dumping dense optimizer states to file, successful
[HUGECTR][07:14:04][INFO][RANK0]: Dumping untrainable weights to file, successful
[HUGECTR][07:14:04][INFO][RANK0]: Finish 1 epochs 4001 global iterations with batchsize 1024 in 54.53s.
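The output shapes in the model summary follow directly from the layer parameters in wdl_train.py. As a quick arithmetic check in plain Python (no HugeCTR needed), the sketch below recomputes them. The deep embedding produces 26 slots of 16 values, which Reshape flattens to 416, and Concat appends the 13 dense features to give 429. The fine-tuning graph later in this notebook also concatenates the 1-wide reshape2, which yields the 430 reported in its summary. The constant and function names are illustrative only.

```python
# Recompute the per-sample output widths reported by model.summary() for the
# Wide&Deep graphs in this notebook. Values are copied from wdl_train.py.
NUM_DENSE_FEATURES = 13  # dense_dim of hugectr.Input
DEEP_SLOTS = 26          # slot_num of the "deep_data" DataReaderSparseParam
DEEP_VEC_SIZE = 16       # embedding_vec_size of sparse_embedding1
WIDE_SLOTS = 1           # slot_num of the "wide_data" DataReaderSparseParam
WIDE_VEC_SIZE = 1        # embedding_vec_size of sparse_embedding2


def reshape_width(slots: int, vec_size: int) -> int:
    """Reshape flattens (slots, vec_size) into one vector per sample."""
    return slots * vec_size


def concat_width(include_wide: bool) -> int:
    """Width of concat1: deep embeddings + dense features (+ wide part if concatenated)."""
    width = reshape_width(DEEP_SLOTS, DEEP_VEC_SIZE) + NUM_DENSE_FEATURES
    if include_wide:
        width += reshape_width(WIDE_SLOTS, WIDE_VEC_SIZE)
    return width


print("reshape1:", reshape_width(DEEP_SLOTS, DEEP_VEC_SIZE))  # 416, the leading_dim passed to Reshape
print("reshape2:", reshape_width(WIDE_SLOTS, WIDE_VEC_SIZE))  # 1
print("concat1 (training graph):", concat_width(include_wide=False))   # 429
print("concat1 (fine-tuning graph):", concat_width(include_wide=True))  # 430
```

In the training graph the wide part bypasses Concat and is summed with fc3 by the Add layer instead, which is why reshape2 must have width 1 there.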
Fine-tuning
We can load only the sparse embedding layers and their corresponding weights, and then construct a new dense network on top of them. The dense weights are trained first, and the sparse weights are fine-tuned afterwards. We can achieve this with the Python APIs as follows:
- Create the solver, reader and optimizer, then initialize the model. 
- Load the sparse embedding layers from the saved JSON file. 
- Add the dense layers on top of the loaded model graph. 
- Compile the model and print a summary of the model graph. 
- Load the sparse weights and freeze the sparse embedding layers. 
- Train the dense weights. 
- Unfreeze the sparse embedding layers, freeze the dense layers, and reset the learning rate scheduler with a smaller learning rate. 
- Fine-tune the sparse weights. 
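The list passed to model.load_sparse_weights in the next cell is positional. Judging from the layer order in wdl_train.py (the wide table sparse_embedding2 was added first) and the file name used for sparse_embedding1 later in this notebook, wdl0_sparse_4000.model appears to hold sparse_embedding2 and wdl1_sparse_4000.model to hold sparse_embedding1. As a sketch, the hypothetical helper below makes that mapping explicit. It assumes the {prefix}{index}_sparse_{iteration}.model naming observed in this notebook, with index following the order in which the SparseEmbedding layers were added; check the files actually written to your working directory before relying on it.

```python
from typing import Dict, List


def sparse_snapshot_files(prefix: str, embedding_names: List[str], iteration: int) -> Dict[str, str]:
    """Hypothetical helper: map each embedding name to its sparse snapshot path.

    Assumes the "{prefix}{index}_sparse_{iteration}.model" naming seen in this
    notebook, where index is the position of the SparseEmbedding layer in the
    order it was added to the model.
    """
    return {
        name: f"{prefix}{index}_sparse_{iteration}.model"
        for index, name in enumerate(embedding_names)
    }


# wdl_train.py added sparse_embedding2 (wide) before sparse_embedding1 (deep).
files = sparse_snapshot_files("wdl", ["sparse_embedding2", "sparse_embedding1"], 4000)
print(files)
# {'sparse_embedding2': 'wdl0_sparse_4000.model', 'sparse_embedding1': 'wdl1_sparse_4000.model'}
```

The resulting dictionary has the same {embedding_name: path} form that wdl_load_pretrained.py passes to model.load_sparse_weights later in this notebook, so it could be used to load the tables by name instead of by position.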
%%writefile wdl_fine_tune.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 5000,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = False,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                          source = ["wdl_data/file_list.2.txt"],
                          eval_source = "wdl_data/file_list.3.txt",
                          check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.construct_from_json(graph_config_file = "wdl.json", include_dense_network = False)
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding1"],
                            top_names = ["reshape1"],
                            leading_dim=416))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding2"],
                            top_names = ["reshape2"],
                            leading_dim=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                            bottom_names = ["reshape1", "reshape2", "dense"], top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["concat1"],
                            top_names = ["fc1"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc1"],
                            top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu1"],
                            top_names = ["dropout1"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["dropout1"],
                            top_names = ["fc2"],
                            num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
                            bottom_names = ["fc2", "label"],
                            top_names = ["loss"]))
model.compile()
model.summary()
model.load_sparse_weights(["wdl0_sparse_4000.model", "wdl1_sparse_4000.model"])
model.freeze_embedding()
model.fit(num_epochs = 1, display = 500, eval_interval = 1000, snapshot = 100000, snapshot_prefix = "wdl")
model.unfreeze_embedding()
model.freeze_dense()
model.reset_learning_rate_scheduler(base_lr = 0.0001)
model.fit(num_epochs = 2, display = 500, eval_interval = 1000, snapshot = 100000, snapshot_prefix = "wdl")
Overwriting wdl_fine_tune.py
!python3 wdl_fine_tune.py
HugeCTR Version: 3.2
====================================================Model Init=====================================================
[HUGECTR][07:29:56][INFO][RANK0]: Global seed is 2136095432
[HUGECTR][07:29:56][INFO][RANK0]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HUGECTR][07:29:58][WARNING][RANK0]: Peer-to-peer access cannot be fully enabled.
[HUGECTR][07:29:58][INFO][RANK0]: Start all2all warmup
[HUGECTR][07:29:58][INFO][RANK0]: End all2all warmup
[HUGECTR][07:29:58][INFO][RANK0]: Using All-reduce algorithm: NCCL
[HUGECTR][07:29:58][INFO][RANK0]: Device 0: Tesla V100-SXM2-16GB
[HUGECTR][07:29:58][INFO][RANK0]: num of DataReader workers: 12
[HUGECTR][07:29:58][INFO][RANK0]: max_num_frequent_categories is not specified using default: 1
[HUGECTR][07:29:58][INFO][RANK0]: max_num_infrequent_samples is not specified using default: -1
[HUGECTR][07:29:58][INFO][RANK0]: p_dup_max is not specified using default: 0.010000
[HUGECTR][07:29:58][INFO][RANK0]: max_all_reduce_bandwidth is not specified using default: 130000000000.000000
[HUGECTR][07:29:58][INFO][RANK0]: max_all_to_all_bandwidth is not specified using default: 190000000000.000000
[HUGECTR][07:29:58][INFO][RANK0]: efficiency_bandwidth_ratio is not specified using default: 1.000000
[HUGECTR][07:29:58][INFO][RANK0]: communication_type is not specified using default: IB_NVLink
[HUGECTR][07:29:58][INFO][RANK0]: hybrid_embedding_type is not specified using default: Distributed
[HUGECTR][07:29:58][INFO][RANK0]: max_vocabulary_size_per_gpu_=6029312
[HUGECTR][07:29:58][INFO][RANK0]: max_num_frequent_categories is not specified using default: 1
[HUGECTR][07:29:58][INFO][RANK0]: max_num_infrequent_samples is not specified using default: -1
[HUGECTR][07:29:58][INFO][RANK0]: p_dup_max is not specified using default: 0.010000
[HUGECTR][07:29:58][INFO][RANK0]: max_all_reduce_bandwidth is not specified using default: 130000000000.000000
[HUGECTR][07:29:58][INFO][RANK0]: max_all_to_all_bandwidth is not specified using default: 190000000000.000000
[HUGECTR][07:29:58][INFO][RANK0]: efficiency_bandwidth_ratio is not specified using default: 1.000000
[HUGECTR][07:29:58][INFO][RANK0]: communication_type is not specified using default: IB_NVLink
[HUGECTR][07:29:58][INFO][RANK0]: hybrid_embedding_type is not specified using default: Distributed
[HUGECTR][07:29:58][INFO][RANK0]: max_vocabulary_size_per_gpu_=5865472
[HUGECTR][07:29:58][INFO][RANK0]: Load the model graph from wdl.json successfully
[HUGECTR][07:29:58][INFO][RANK0]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HUGECTR][07:30:00][INFO][RANK0]: gpu0 start to init embedding
[HUGECTR][07:30:00][INFO][RANK0]: gpu0 init embedding done
[HUGECTR][07:30:00][INFO][RANK0]: gpu0 start to init embedding
[HUGECTR][07:30:00][INFO][RANK0]: gpu0 init embedding done
[HUGECTR][07:30:00][INFO][RANK0]: Starting AUC NCCL warm-up
[HUGECTR][07:30:00][INFO][RANK0]: Warm-up done
===================================================Model Summary===================================================
label                                   Dense                         Sparse                        
label                                   dense                          wide_data,deep_data           
(None, 1)                               (None, 13)                              
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type                              Input Name                    Output Name                   Output Shape                  
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash      wide_data                     sparse_embedding2             (None, 1, 1)                  
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash      deep_data                     sparse_embedding1             (None, 26, 16)                
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding1             reshape1                      (None, 416)                   
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding2             reshape2                      (None, 1)                     
------------------------------------------------------------------------------------------------------------------
Concat                                  reshape1                      concat1                       (None, 430)                   
                                        reshape2                                                                                  
                                        dense                                                                                     
------------------------------------------------------------------------------------------------------------------
InnerProduct                            concat1                       fc1                           (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc1                           relu1                         (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
Dropout                                 relu1                         dropout1                      (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
InnerProduct                            dropout1                      fc2                           (None, 1)                     
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss                  fc2                           loss                                                        
                                        label                                                                                     
------------------------------------------------------------------------------------------------------------------
[HUGECTR][07:30:00][INFO][RANK0]: Loading sparse model: wdl0_sparse_4000.model
[HUGECTR][07:30:00][INFO][RANK0]: Loading sparse model: wdl1_sparse_4000.model
=====================================================Model Fit=====================================================
[HUGECTR][07:30:00][INFO][RANK0]: Use epoch mode with number of epochs: 1
[HUGECTR][07:30:00][INFO][RANK0]: Training batchsize: 1024, evaluation batchsize: 1024
[HUGECTR][07:30:00][INFO][RANK0]: Evaluation interval: 1000, snapshot interval: 100000
[HUGECTR][07:30:00][INFO][RANK0]: Dense network trainable: True
[HUGECTR][07:30:00][INFO][RANK0]: Sparse embedding sparse_embedding1 trainable: False
[HUGECTR][07:30:00][INFO][RANK0]: Sparse embedding sparse_embedding2 trainable: False
[HUGECTR][07:30:00][INFO][RANK0]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HUGECTR][07:30:00][INFO][RANK0]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HUGECTR][07:30:00][INFO][RANK0]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HUGECTR][07:30:00][INFO][RANK0]: Training source file: wdl_data/file_list.2.txt
[HUGECTR][07:30:00][INFO][RANK0]: Evaluation source file: wdl_data/file_list.3.txt
[HUGECTR][07:30:00][INFO][RANK0]: -----------------------------------Epoch 0-----------------------------------
[HUGECTR][07:30:02][INFO][RANK0]: Iter: 500 Time(500 iters): 2.288018s Loss: 0.115263 lr:0.001000
[HUGECTR][07:30:04][INFO][RANK0]: Iter: 1000 Time(500 iters): 2.084800s Loss: 0.130941 lr:0.001000
[HUGECTR][07:30:06][INFO][RANK0]: Evaluation, AUC: 0.753592
[HUGECTR][07:30:06][INFO][RANK0]: Eval Time for 5000 iters: 1.233550s
[HUGECTR][07:30:08][INFO][RANK0]: Iter: 1500 Time(500 iters): 3.320545s Loss: 0.160203 lr:0.001000
[HUGECTR][07:30:10][INFO][RANK0]: Iter: 2000 Time(500 iters): 2.083907s Loss: 0.133159 lr:0.001000
[HUGECTR][07:30:11][INFO][RANK0]: Evaluation, AUC: 0.757654
[HUGECTR][07:30:11][INFO][RANK0]: Eval Time for 5000 iters: 1.257166s
[HUGECTR][07:30:13][INFO][RANK0]: Iter: 2500 Time(500 iters): 3.344821s Loss: 0.114668 lr:0.001000
[HUGECTR][07:30:15][INFO][RANK0]: Iter: 3000 Time(500 iters): 2.085232s Loss: 0.131622 lr:0.001000
[HUGECTR][07:30:17][INFO][RANK0]: Evaluation, AUC: 0.759316
[HUGECTR][07:30:17][INFO][RANK0]: Eval Time for 5000 iters: 1.307634s
[HUGECTR][07:30:19][INFO][RANK0]: Iter: 3500 Time(500 iters): 3.395008s Loss: 0.140864 lr:0.001000
[HUGECTR][07:30:21][INFO][RANK0]: Iter: 4000 Time(500 iters): 2.080470s Loss: 0.132377 lr:0.001000
[HUGECTR][07:30:22][INFO][RANK0]: Evaluation, AUC: 0.759804
[HUGECTR][07:30:22][INFO][RANK0]: Eval Time for 5000 iters: 1.176526s
[HUGECTR][07:30:22][INFO][RANK0]: Finish 1 epochs 4001 global iterations with batchsize 1024 in 21.89s.
=====================================================Model Fit=====================================================
[HUGECTR][07:30:22][INFO][RANK0]: Use epoch mode with number of epochs: 2
[HUGECTR][07:30:22][INFO][RANK0]: Training batchsize: 1024, evaluation batchsize: 1024
[HUGECTR][07:30:22][INFO][RANK0]: Evaluation interval: 1000, snapshot interval: 100000
[HUGECTR][07:30:22][INFO][RANK0]: Dense network trainable: False
[HUGECTR][07:30:22][INFO][RANK0]: Sparse embedding sparse_embedding1 trainable: True
[HUGECTR][07:30:22][INFO][RANK0]: Sparse embedding sparse_embedding2 trainable: True
[HUGECTR][07:30:22][INFO][RANK0]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HUGECTR][07:30:22][INFO][RANK0]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HUGECTR][07:30:22][INFO][RANK0]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HUGECTR][07:30:22][INFO][RANK0]: Training source file: wdl_data/file_list.2.txt
[HUGECTR][07:30:22][INFO][RANK0]: Evaluation source file: wdl_data/file_list.3.txt
[HUGECTR][07:30:22][INFO][RANK0]: -----------------------------------Epoch 0-----------------------------------
[HUGECTR][07:30:24][INFO][RANK0]: Iter: 500 Time(500 iters): 2.143974s Loss: 0.113414 lr:0.000100
[HUGECTR][07:30:26][INFO][RANK0]: Iter: 1000 Time(500 iters): 2.082424s Loss: 0.128542 lr:0.000100
[HUGECTR][07:30:27][INFO][RANK0]: Evaluation, AUC: 0.761524
[HUGECTR][07:30:27][INFO][RANK0]: Eval Time for 5000 iters: 1.205501s
[HUGECTR][07:30:30][INFO][RANK0]: Iter: 1500 Time(500 iters): 3.291612s Loss: 0.161557 lr:0.000100
[HUGECTR][07:30:32][INFO][RANK0]: Iter: 2000 Time(500 iters): 2.083802s Loss: 0.131485 lr:0.000100
[HUGECTR][07:30:33][INFO][RANK0]: Evaluation, AUC: 0.762616
[HUGECTR][07:30:33][INFO][RANK0]: Eval Time for 5000 iters: 1.170735s
[HUGECTR][07:30:35][INFO][RANK0]: Iter: 2500 Time(500 iters): 3.260273s Loss: 0.111285 lr:0.000100
[HUGECTR][07:30:37][INFO][RANK0]: Iter: 3000 Time(500 iters): 2.086440s Loss: 0.128462 lr:0.000100
[HUGECTR][07:30:38][INFO][RANK0]: Evaluation, AUC: 0.763377
[HUGECTR][07:30:38][INFO][RANK0]: Eval Time for 5000 iters: 1.226101s
[HUGECTR][07:30:40][INFO][RANK0]: Iter: 3500 Time(500 iters): 3.316375s Loss: 0.140509 lr:0.000100
[HUGECTR][07:30:42][INFO][RANK0]: Iter: 4000 Time(500 iters): 2.082592s Loss: 0.127911 lr:0.000100
[HUGECTR][07:30:44][INFO][RANK0]: Evaluation, AUC: 0.763914
[HUGECTR][07:30:44][INFO][RANK0]: Eval Time for 5000 iters: 1.270215s
[HUGECTR][07:30:44][INFO][RANK0]: -----------------------------------Epoch 1-----------------------------------
[HUGECTR][07:30:46][INFO][RANK0]: Iter: 4500 Time(500 iters): 3.370336s Loss: 0.146994 lr:0.000100
[HUGECTR][07:30:48][INFO][RANK0]: Iter: 5000 Time(500 iters): 2.087686s Loss: 0.110219 lr:0.000100
[HUGECTR][07:30:49][INFO][RANK0]: Evaluation, AUC: 0.764316
[HUGECTR][07:30:49][INFO][RANK0]: Eval Time for 5000 iters: 1.142174s
[HUGECTR][07:30:51][INFO][RANK0]: Iter: 5500 Time(500 iters): 3.232587s Loss: 0.144252 lr:0.000100
[HUGECTR][07:30:53][INFO][RANK0]: Iter: 6000 Time(500 iters): 2.087027s Loss: 0.122446 lr:0.000100
[HUGECTR][07:30:54][INFO][RANK0]: Evaluation, AUC: 0.764680
[HUGECTR][07:30:54][INFO][RANK0]: Eval Time for 5000 iters: 1.234112s
[HUGECTR][07:30:56][INFO][RANK0]: Iter: 6500 Time(500 iters): 3.325564s Loss: 0.098065 lr:0.000100
[HUGECTR][07:30:59][INFO][RANK0]: Iter: 7000 Time(500 iters): 2.087908s Loss: 0.132715 lr:0.000100
[HUGECTR][07:31:00][INFO][RANK0]: Evaluation, AUC: 0.764872
[HUGECTR][07:31:00][INFO][RANK0]: Eval Time for 5000 iters: 1.179473s
[HUGECTR][07:31:02][INFO][RANK0]: Iter: 7500 Time(500 iters): 3.268473s Loss: 0.132111 lr:0.000100
[HUGECTR][07:31:04][INFO][RANK0]: Iter: 8000 Time(500 iters): 2.083339s Loss: 0.126090 lr:0.000100
[HUGECTR][07:31:05][INFO][RANK0]: Evaluation, AUC: 0.764933
[HUGECTR][07:31:05][INFO][RANK0]: Eval Time for 5000 iters: 1.272813s
[HUGECTR][07:31:05][INFO][RANK0]: Finish 2 epochs 8002 global iterations with batchsize 1024 in 43.22s.
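Comparing the two phases is easier with the evaluation AUCs pulled out of the log: the dense-only phase ends at 0.759804, and the second phase, which trains only the embeddings at the smaller learning rate, ends at 0.764933. As a sketch, assuming the console output was saved to a text file (for example with python3 wdl_fine_tune.py 2>&1 | tee wdl_fine_tune.log), the hypothetical helper below pairs every "Evaluation, AUC:" line with the preceding "Iter:" line. The iteration counter printed by the second fit call restarts at 500, so iteration numbers repeat across phases.

```python
import re
from typing import List, Tuple

# Patterns for HugeCTR console lines such as
#   "...[RANK0]: Iter: 4000 Time(500 iters): 2.080470s Loss: 0.132377 lr:0.001000"
#   "...[RANK0]: Evaluation, AUC: 0.759804"
ITER_PATTERN = re.compile(r"Iter: (\d+) ")
AUC_PATTERN = re.compile(r"Evaluation, AUC: ([0-9.]+)")


def auc_curve(log_text: str) -> List[Tuple[int, float]]:
    """Return (iteration, AUC) pairs, using the last Iter line seen before each evaluation."""
    curve = []
    last_iter = 0
    for line in log_text.splitlines():
        iter_match = ITER_PATTERN.search(line)
        if iter_match:
            last_iter = int(iter_match.group(1))
        auc_match = AUC_PATTERN.search(line)
        if auc_match:
            curve.append((last_iter, float(auc_match.group(1))))
    return curve


# Two lines copied from the first fit call above.
SAMPLE_LOG = """\
[HUGECTR][07:30:21][INFO][RANK0]: Iter: 4000 Time(500 iters): 2.080470s Loss: 0.132377 lr:0.001000
[HUGECTR][07:30:22][INFO][RANK0]: Evaluation, AUC: 0.759804
"""

print(auc_curve(SAMPLE_LOG))  # [(4000, 0.759804)]
```

To analyze a saved log, pass open("wdl_fine_tune.log").read() to auc_curve.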
Load Pre-trained Embeddings
If you have pre-trained embeddings in other formats, you can convert them to HugeCTR sparse models and then load them to facilitate the training process. For simplicity and generality, we represent the pre-trained embeddings as a dictionary of randomly initialized NumPy arrays, in which each key is an embedding key and each value is the corresponding embedding vector. Note that the Wide&Deep model has two embedding tables; here we load pre-trained embeddings for only one of them (sparse_embedding1) and freeze the corresponding embedding layer.
%%writefile wdl_load_pretrained.py
import hugectr
from mpi4py import MPI
import numpy as np
import os
import struct
solver = hugectr.CreateSolver(max_eval_batches = 5000,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = False,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                          source = ["wdl_data/file_list.0.txt"],
                          eval_source = "wdl_data/file_list.1.txt",
                          check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.construct_from_json(graph_config_file = "wdl.json", include_dense_network = True)
model.compile()
model.summary()
def convert_pretrained_embeddings_to_sparse_model(pre_trained_sparse_embeddings, hugectr_sparse_model, embedding_vec_size):
    # Write the sparse model directory: "key" holds one int64 ('q') per key and
    # "emb_vector" holds embedding_vec_size float32 ('f') values per key, in the same order
    os.makedirs(hugectr_sparse_model, exist_ok=True)
    with open(os.path.join(hugectr_sparse_model, "key"), 'wb') as key_file, \
         open(os.path.join(hugectr_sparse_model, "emb_vector"), 'wb') as vec_file:
        for key, vec in pre_trained_sparse_embeddings.items():
            key_file.write(struct.pack('q', key))
            vec_file.write(struct.pack(str(embedding_vec_size) + "f", *vec))
# Convert the pretrained embeddings
pretrained_embeddings = dict()
hugectr_sparse_model = "wdl1_pretrained.model"
embedding_vec_size = 16
key_range = (0, 100000)
for key in range(key_range[0], key_range[1]):
    pretrained_embeddings[key] = np.random.randn(embedding_vec_size).astype(np.float32)
convert_pretrained_embeddings_to_sparse_model(pretrained_embeddings, hugectr_sparse_model, embedding_vec_size)
print("Successfully convert pretrained embeddings to {}".format(hugectr_sparse_model))
# Load the pretrained sparse models
model.load_sparse_weights({"sparse_embedding1": hugectr_sparse_model})
model.freeze_embedding("sparse_embedding1")
model.fit(num_epochs = 1, display = 500, eval_interval = 1000, snapshot = 100000, snapshot_prefix = "wdl")
Overwriting wdl_load_pretrained.py
!python3 wdl_load_pretrained.py
HugeCTR Version: 3.2
====================================================Model Init=====================================================
[HUGECTR][07:31:36][INFO][RANK0]: Global seed is 3369591795
[HUGECTR][07:31:36][INFO][RANK0]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HUGECTR][07:31:38][WARNING][RANK0]: Peer-to-peer access cannot be fully enabled.
[HUGECTR][07:31:38][INFO][RANK0]: Start all2all warmup
[HUGECTR][07:31:38][INFO][RANK0]: End all2all warmup
[HUGECTR][07:31:38][INFO][RANK0]: Using All-reduce algorithm: NCCL
[HUGECTR][07:31:38][INFO][RANK0]: Device 0: Tesla V100-SXM2-16GB
[HUGECTR][07:31:38][INFO][RANK0]: num of DataReader workers: 12
[HUGECTR][07:31:38][INFO][RANK0]: max_num_frequent_categories is not specified using default: 1
[HUGECTR][07:31:38][INFO][RANK0]: max_num_infrequent_samples is not specified using default: -1
[HUGECTR][07:31:38][INFO][RANK0]: p_dup_max is not specified using default: 0.010000
[HUGECTR][07:31:38][INFO][RANK0]: max_all_reduce_bandwidth is not specified using default: 130000000000.000000
[HUGECTR][07:31:38][INFO][RANK0]: max_all_to_all_bandwidth is not specified using default: 190000000000.000000
[HUGECTR][07:31:38][INFO][RANK0]: efficiency_bandwidth_ratio is not specified using default: 1.000000
[HUGECTR][07:31:38][INFO][RANK0]: communication_type is not specified using default: IB_NVLink
[HUGECTR][07:31:38][INFO][RANK0]: hybrid_embedding_type is not specified using default: Distributed
[HUGECTR][07:31:38][INFO][RANK0]: max_vocabulary_size_per_gpu_=6029312
[HUGECTR][07:31:38][INFO][RANK0]: max_num_frequent_categories is not specified using default: 1
[HUGECTR][07:31:38][INFO][RANK0]: max_num_infrequent_samples is not specified using default: -1
[HUGECTR][07:31:38][INFO][RANK0]: p_dup_max is not specified using default: 0.010000
[HUGECTR][07:31:38][INFO][RANK0]: max_all_reduce_bandwidth is not specified using default: 130000000000.000000
[HUGECTR][07:31:38][INFO][RANK0]: max_all_to_all_bandwidth is not specified using default: 190000000000.000000
[HUGECTR][07:31:38][INFO][RANK0]: efficiency_bandwidth_ratio is not specified using default: 1.000000
[HUGECTR][07:31:38][INFO][RANK0]: communication_type is not specified using default: IB_NVLink
[HUGECTR][07:31:38][INFO][RANK0]: hybrid_embedding_type is not specified using default: Distributed
[HUGECTR][07:31:38][INFO][RANK0]: max_vocabulary_size_per_gpu_=5865472
[HUGECTR][07:31:38][INFO][RANK0]: Load the model graph from wdl.json successfully
[HUGECTR][07:31:38][INFO][RANK0]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HUGECTR][07:31:42][INFO][RANK0]: gpu0 start to init embedding
[HUGECTR][07:31:42][INFO][RANK0]: gpu0 init embedding done
[HUGECTR][07:31:42][INFO][RANK0]: gpu0 start to init embedding
[HUGECTR][07:31:42][INFO][RANK0]: gpu0 init embedding done
[HUGECTR][07:31:42][INFO][RANK0]: Starting AUC NCCL warm-up
[HUGECTR][07:31:42][INFO][RANK0]: Warm-up done
===================================================Model Summary===================================================
label                                   Dense                         Sparse                        
label                                   dense                          wide_data,deep_data           
(None, 1)                               (None, 13)                              
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type                              Input Name                    Output Name                   Output Shape                  
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash      wide_data                     sparse_embedding2             (None, 1, 1)                  
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash      deep_data                     sparse_embedding1             (None, 26, 16)                
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding1             reshape1                      (None, 416)                   
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding2             reshape2                      (None, 1)                     
------------------------------------------------------------------------------------------------------------------
Concat                                  reshape1                      concat1                       (None, 429)                   
                                        dense                                                                                     
------------------------------------------------------------------------------------------------------------------
InnerProduct                            concat1                       fc1                           (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc1                           relu1                         (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
Dropout                                 relu1                         dropout1                      (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
InnerProduct                            dropout1                      fc2                           (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc2                           relu2                         (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
Dropout                                 relu2                         dropout2                      (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
InnerProduct                            dropout2                      fc3                           (None, 1)                     
------------------------------------------------------------------------------------------------------------------
Add                                     fc3                           add1                          (None, 1)                     
                                        reshape2                                                                                  
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss                  add1                          loss                                                        
                                        label                                                                                     
------------------------------------------------------------------------------------------------------------------
Successfully convert pretrained embeddings to wdl1_pretrained.model
[HUGECTR][07:31:42][INFO][RANK0]: Loading sparse model: wdl1_pretrained.model
=====================================================Model Fit=====================================================
[HUGECTR][07:31:42][INFO][RANK0]: Use epoch mode with number of epochs: 1
[HUGECTR][07:31:42][INFO][RANK0]: Training batchsize: 1024, evaluation batchsize: 1024
[HUGECTR][07:31:42][INFO][RANK0]: Evaluation interval: 1000, snapshot interval: 100000
[HUGECTR][07:31:42][INFO][RANK0]: Dense network trainable: True
[HUGECTR][07:31:42][INFO][RANK0]: Sparse embedding sparse_embedding1 trainable: False
[HUGECTR][07:31:42][INFO][RANK0]: Sparse embedding sparse_embedding2 trainable: True
[HUGECTR][07:31:42][INFO][RANK0]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HUGECTR][07:31:42][INFO][RANK0]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HUGECTR][07:31:42][INFO][RANK0]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HUGECTR][07:31:42][INFO][RANK0]: Training source file: wdl_data/file_list.0.txt
[HUGECTR][07:31:42][INFO][RANK0]: Evaluation source file: wdl_data/file_list.1.txt
[HUGECTR][07:31:42][INFO][RANK0]: -----------------------------------Epoch 0-----------------------------------
[HUGECTR][07:31:45][INFO][RANK0]: Iter: 500 Time(500 iters): 2.578954s Loss: 0.144762 lr:0.001000
[HUGECTR][07:31:47][INFO][RANK0]: Iter: 1000 Time(500 iters): 2.417656s Loss: 0.136326 lr:0.001000
[HUGECTR][07:31:49][INFO][RANK0]: Evaluation, AUC: 0.713149
[HUGECTR][07:31:49][INFO][RANK0]: Eval Time for 5000 iters: 1.900026s
[HUGECTR][07:31:52][INFO][RANK0]: Iter: 1500 Time(500 iters): 4.323596s Loss: 0.148682 lr:0.001000
[HUGECTR][07:31:54][INFO][RANK0]: Iter: 2000 Time(500 iters): 2.416977s Loss: 0.145738 lr:0.001000
[HUGECTR][07:31:56][INFO][RANK0]: Evaluation, AUC: 0.725260
[HUGECTR][07:31:56][INFO][RANK0]: Eval Time for 5000 iters: 1.876535s
[HUGECTR][07:31:58][INFO][RANK0]: Iter: 2500 Time(500 iters): 4.297001s Loss: 0.168649 lr:0.001000
[HUGECTR][07:32:01][INFO][RANK0]: Iter: 3000 Time(500 iters): 2.418015s Loss: 0.134682 lr:0.001000
[HUGECTR][07:32:03][INFO][RANK0]: Evaluation, AUC: 0.732183
[HUGECTR][07:32:03][INFO][RANK0]: Eval Time for 5000 iters: 1.877291s
[HUGECTR][07:32:05][INFO][RANK0]: Iter: 3500 Time(500 iters): 4.296748s Loss: 0.117909 lr:0.001000
[HUGECTR][07:32:08][INFO][RANK0]: Iter: 4000 Time(500 iters): 2.411790s Loss: 0.133109 lr:0.001000
[HUGECTR][07:32:09][INFO][RANK0]: Evaluation, AUC: 0.736392
[HUGECTR][07:32:09][INFO][RANK0]: Eval Time for 5000 iters: 1.901939s
[HUGECTR][07:32:09][INFO][RANK0]: Finish 1 epochs 4001 global iterations with batchsize 1024 in 27.09s.
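Before calling model.load_sparse_weights, it can be worth checking that the converted files decode back to the original vectors. Here is a minimal standard-library sketch, assuming the layout written by convert_pretrained_embeddings_to_sparse_model above (one int64 per key in the key file, embedding_vec_size float32 values per key in the emb_vector file, in the same order). read_sparse_model is a hypothetical helper, not part of HugeCTR.

```python
import os
import struct
import tempfile
from typing import Dict, List


def read_sparse_model(model_dir: str, embedding_vec_size: int) -> Dict[int, List[float]]:
    """Read back the 'key' and 'emb_vector' files (hypothetical verification helper).

    Assumes the converter's layout: int64 keys ('q') and float32 vectors ('f'),
    written in matching order.
    """
    key_size = struct.calcsize("q")
    vec_fmt = "{}f".format(embedding_vec_size)
    vec_size = struct.calcsize(vec_fmt)
    with open(os.path.join(model_dir, "key"), "rb") as f:
        key_bytes = f.read()
    with open(os.path.join(model_dir, "emb_vector"), "rb") as f:
        vec_bytes = f.read()
    num_keys = len(key_bytes) // key_size
    # Both files must describe the same number of entries
    if len(key_bytes) % key_size or len(vec_bytes) != num_keys * vec_size:
        raise ValueError("key and emb_vector sizes are inconsistent")
    restored = {}
    for i in range(num_keys):
        (key,) = struct.unpack_from("q", key_bytes, i * key_size)
        restored[key] = list(struct.unpack_from(vec_fmt, vec_bytes, i * vec_size))
    return restored


# Usage: write a tiny 2-dimensional model the same way the converter does, then read it back
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "key"), "wb") as kf, open(os.path.join(tmp, "emb_vector"), "wb") as vf:
        for key, vec in {3: [0.5, -0.25], 9: [1.0, 2.0]}.items():
            kf.write(struct.pack("q", key))
            vf.write(struct.pack("2f", *vec))
    restored = read_sparse_model(tmp, embedding_vec_size=2)
print(restored)  # {3: [0.5, -0.25], 9: [1.0, 2.0]}
```

For the model converted above, read_sparse_model("wdl1_pretrained.model", 16) should return one entry per key in key_range, each matching the float32 array stored in pretrained_embeddings.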
Low-level Training
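The low-level training APIs are still available in the enhanced HugeCTR Python interface. They are useful when you need precise control over each training iteration and each evaluation step. Because the data reader behaves differently in epoch mode and non-epoch mode, you need to drive it differently in each case. In non-epoch mode (repeat_dataset = True), the reader cycles through the data indefinitely, so you call model.start_data_reading() once and bound the loop with a fixed number of iterations. In epoch mode (repeat_dataset = False), model.train() returns False when the training source is exhausted, and you call set_source() on the data reader to begin the next pass. The following examples show how to write low-level training scripts for non-epoch mode, epoch mode, and embedding training cache mode.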
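The difference between the two modes comes down to whether the data reader ever reports exhaustion. The sketch below models that contract with a hypothetical FakeReader class that stands in for the data reader plus model.train(); it is not a HugeCTR API. With repeat=True, train() always succeeds and the loop is bounded only by an iteration count, as in wdl_non_epoch.py. With repeat=False, train() returns False at the end of the data and set_source() rewinds it, as in wdl_epoch.py.

```python
from typing import List


class FakeReader:
    """Hypothetical stand-in for a data reader plus model.train(); not a HugeCTR class."""

    def __init__(self, num_batches: int, repeat: bool):
        self.num_batches = num_batches
        self.repeat = repeat
        self.position = 0  # treated as already started, for simplicity

    def set_source(self) -> None:
        self.position = 0  # rewind to the first batch

    def train(self) -> bool:
        if self.position >= self.num_batches:
            if not self.repeat:
                return False  # epoch mode: signal the end of the data
            self.position = 0  # non-epoch mode: wrap around silently
        self.position += 1
        return True


def non_epoch_loop(reader: FakeReader, max_iter: int) -> int:
    # The reader never runs dry, so max_iter alone bounds the loop
    return sum(1 for _ in range(max_iter) if reader.train())


def epoch_loop(reader: FakeReader, num_epochs: int) -> List[int]:
    iterations_per_epoch = []
    for _ in range(num_epochs):
        reader.set_source()  # rewind at the start of every epoch
        count = 0
        while reader.train():  # False marks the end of the epoch
            count += 1
        iterations_per_epoch.append(count)
    return iterations_per_epoch


print(non_epoch_loop(FakeReader(num_batches=3, repeat=True), max_iter=10))  # 10
print(epoch_loop(FakeReader(num_batches=3, repeat=False), num_epochs=2))    # [3, 3]
```

In epoch mode the number of iterations per epoch is dictated by the data, which is why wdl_epoch.py uses a while loop and a separate iteration counter instead of range(max_iter).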
%%writefile wdl_non_epoch.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 5000,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = True,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                          source = ["wdl_data/file_list.0.txt"],
                          eval_source = "wdl_data/file_list.1.txt",
                          check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.construct_from_json(graph_config_file = "wdl.json", include_dense_network = True)
model.compile()
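model.start_data_reading()  # non-epoch mode: start the data reader once before the training loop
lr_sch = model.get_learning_rate_scheduler()
max_iter = 2000
for i in range(max_iter):
    # Advance the learning rate schedule, then run one training iteration
    lr = lr_sch.get_next()
    model.set_learning_rate(lr)
    model.train()
    if i % 100 == 0:
        loss = model.get_current_loss()
        print("[HUGECTR][INFO] iter: {}; loss: {}".format(i, loss))
    if i % 1000 == 0 and i != 0:
        # Evaluate on solver.max_eval_batches batches; with repeat_dataset = True no end-of-data check is needed
        for _ in range(solver.max_eval_batches):
            model.eval()
        metrics = model.get_eval_metrics()
        print("[HUGECTR][INFO] iter: {}, {}".format(i, metrics))
# Save the sparse and dense weights and optimizer states, tagged with the iteration count
model.save_params_to_files("./", max_iter)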
Writing wdl_non_epoch.py
!python3 wdl_non_epoch.py
HugeCTR Version: 3.2.0
====================================================Model Init=====================================================
[28d09h36m32s][HUGECTR][INFO]: Global seed is 3898093135
[28d09h36m33s][HUGECTR][INFO]: Device to NUMA mapping:
  GPU 0 ->  node 0
[28d09h36m35s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.
[28d09h36m35s][HUGECTR][INFO]: Start all2all warmup
[28d09h36m35s][HUGECTR][INFO]: End all2all warmup
[28d09h36m35s][HUGECTR][INFO]: Using All-reduce algorithm NCCL
Device 0: Tesla V100-SXM2-16GB
[28d09h36m35s][HUGECTR][INFO]: num of DataReader workers: 12
[28d09h36m35s][HUGECTR][INFO]: max_num_frequent_categories is not specified using default: 1
[28d09h36m35s][HUGECTR][INFO]: max_num_infrequent_samples is not specified using default: -1
[28d09h36m35s][HUGECTR][INFO]: p_dup_max is not specified using default: 0.010000
[28d09h36m35s][HUGECTR][INFO]: max_all_reduce_bandwidth is not specified using default: 130000000000.000000
[28d09h36m35s][HUGECTR][INFO]: max_all_to_all_bandwidth is not specified using default: 190000000000.000000
[28d09h36m35s][HUGECTR][INFO]: efficiency_bandwidth_ratio is not specified using default: 1.000000
[28d09h36m35s][HUGECTR][INFO]: communication_type is not specified using default: IB_NVLink
[28d09h36m35s][HUGECTR][INFO]: hybrid_embedding_type is not specified using default: Distributed
[28d09h36m35s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=6029312
[28d09h36m35s][HUGECTR][INFO]: max_num_frequent_categories is not specified using default: 1
[28d09h36m35s][HUGECTR][INFO]: max_num_infrequent_samples is not specified using default: -1
[28d09h36m35s][HUGECTR][INFO]: p_dup_max is not specified using default: 0.010000
[28d09h36m35s][HUGECTR][INFO]: max_all_reduce_bandwidth is not specified using default: 130000000000.000000
[28d09h36m35s][HUGECTR][INFO]: max_all_to_all_bandwidth is not specified using default: 190000000000.000000
[28d09h36m35s][HUGECTR][INFO]: efficiency_bandwidth_ratio is not specified using default: 1.000000
[28d09h36m35s][HUGECTR][INFO]: communication_type is not specified using default: IB_NVLink
[28d09h36m35s][HUGECTR][INFO]: hybrid_embedding_type is not specified using default: Distributed
[28d09h36m35s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=5865472
[28d09h36m35s][HUGECTR][INFO]: Load the model graph from wdl.json, successful
===================================================Model Compile===================================================
[28d09h36m38s][HUGECTR][INFO]: gpu0 start to init embedding
[28d09h36m38s][HUGECTR][INFO]: gpu0 init embedding done
[28d09h36m38s][HUGECTR][INFO]: gpu0 start to init embedding
[28d09h36m38s][HUGECTR][INFO]: gpu0 init embedding done
[28d09h36m38s][HUGECTR][INFO]: Starting AUC NCCL warm-up
[28d09h36m38s][HUGECTR][INFO]: Warm-up done
[HUGECTR][INFO] iter: 0; loss: 1.0029206275939941
[HUGECTR][INFO] iter: 100; loss: 0.12538853287696838
[HUGECTR][INFO] iter: 200; loss: 0.10476257652044296
[HUGECTR][INFO] iter: 300; loss: 0.1463421732187271
[HUGECTR][INFO] iter: 400; loss: 0.1541304737329483
[HUGECTR][INFO] iter: 500; loss: 0.14912495017051697
[HUGECTR][INFO] iter: 600; loss: 0.12571805715560913
[HUGECTR][INFO] iter: 700; loss: 0.13279415667057037
[HUGECTR][INFO] iter: 800; loss: 0.13649113476276398
[HUGECTR][INFO] iter: 900; loss: 0.1288434863090515
[HUGECTR][INFO] iter: 1000; loss: 0.13555476069450378
[HUGECTR][INFO] iter: 1000, [('AUC', 0.7378068566322327)]
[HUGECTR][INFO] iter: 1100; loss: 0.15259310603141785
[HUGECTR][INFO] iter: 1200; loss: 0.15795981884002686
[HUGECTR][INFO] iter: 1300; loss: 0.13971731066703796
[HUGECTR][INFO] iter: 1400; loss: 0.14082138240337372
[HUGECTR][INFO] iter: 1500; loss: 0.14722011983394623
[HUGECTR][INFO] iter: 1600; loss: 0.13573814928531647
[HUGECTR][INFO] iter: 1700; loss: 0.12339376658201218
[HUGECTR][INFO] iter: 1800; loss: 0.15557655692100525
[HUGECTR][INFO] iter: 1900; loss: 0.1399267613887787
[28d09h36m55s][HUGECTR][INFO]: Rank0: Write hash table to file
[28d09h36m56s][HUGECTR][INFO]: Rank0: Write hash table to file
[28d09h36m56s][HUGECTR][INFO]: Dumping sparse weights to files, successful
[28d09h36m56s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[28d09h36m56s][HUGECTR][INFO]: Done
[28d09h36m56s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[28d09h36m56s][HUGECTR][INFO]: Done
[28d09h36m57s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[28d09h36m57s][HUGECTR][INFO]: Done
[28d09h36m58s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[28d09h36m58s][HUGECTR][INFO]: Done
[28d09h37m05s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful
[28d09h37m05s][HUGECTR][INFO]: Dumping dense weights to file, successful
[28d09h37m05s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful
[28d09h37m05s][HUGECTR][INFO]: Dumping untrainable weights to file, successful
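Both low-level scripts call lr_sch.get_next() once per iteration and pass the result to model.set_learning_rate(). The fit log of the pre-trained embedding run prints the scheduler's settings (lr, warmup_steps, end_lr, decay_start, decay_steps, decay_power). To illustrate what a scheduler with those settings could compute, here is a pure-Python sketch of linear warmup followed by polynomial decay. The exact formula is an assumption, not HugeCTR's implementation, and SketchLrScheduler is a hypothetical class. In particular, treating decay_start = 0 as "decay disabled" is consistent with the constant lr:0.001000 in that log but is not confirmed by this notebook.

```python
class SketchLrScheduler:
    """Hypothetical warmup + polynomial-decay scheduler; not HugeCTR's implementation."""

    def __init__(self, base_lr: float, warmup_steps: int, decay_start: int,
                 decay_steps: int, decay_power: float, end_lr: float):
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.decay_start = decay_start
        self.decay_steps = decay_steps
        self.decay_power = decay_power
        self.end_lr = end_lr
        self.step = 0

    def get_next(self) -> float:
        self.step += 1  # 1-based step count
        if self.step <= self.warmup_steps:
            # Linear warmup up to base_lr
            return self.base_lr * self.step / self.warmup_steps
        if self.decay_start == 0 or self.step <= self.decay_start:
            # Decay disabled (assumed meaning of decay_start == 0) or not started yet
            return self.base_lr
        if self.step <= self.decay_start + self.decay_steps:
            # Polynomial decay from base_lr down to end_lr over decay_steps steps
            progress = (self.step - self.decay_start) / self.decay_steps
            return (self.base_lr - self.end_lr) * (1.0 - progress) ** self.decay_power + self.end_lr
        return self.end_lr


# Settings printed in the fit log: the rate stays at 0.001 after the 1-step warmup
flat = SketchLrScheduler(0.001, warmup_steps=1, decay_start=0, decay_steps=1, decay_power=2.0, end_lr=0.0)
print([flat.get_next() for _ in range(3)])  # [0.001, 0.001, 0.001]

# Illustrative settings that exercise every phase
sch = SketchLrScheduler(1.0, warmup_steps=2, decay_start=4, decay_steps=2, decay_power=2.0, end_lr=0.0)
print([sch.get_next() for _ in range(7)])  # [0.5, 1.0, 1.0, 1.0, 0.25, 0.0, 0.0]
```

Keeping the schedule stateful (one get_next() call per step) mirrors how the low-level loops consume it: the caller never passes the step number, so calling get_next() more or less often than once per training iteration shifts the whole schedule.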
%%writefile wdl_epoch.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 5000,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = False,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                          source = ["wdl_data/file_list.0.txt"],
                          eval_source = "wdl_data/file_list.1.txt",
                          check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.construct_from_json(graph_config_file = "wdl.json", include_dense_network = True)
model.compile()
lr_sch = model.get_learning_rate_scheduler()
data_reader_train = model.get_data_reader_train()
data_reader_eval = model.get_data_reader_eval()
# Epoch mode: set the eval reader's source before the first evaluation
data_reader_eval.set_source()
data_reader_eval_flag = True
iteration = 0
for epoch in range(2):
    print("[HUGECTR][INFO] epoch: ", epoch)
    # Rewind the training source at the start of every epoch
    data_reader_train.set_source()
    while True:
        lr = lr_sch.get_next()
        model.set_learning_rate(lr)
        # train() returns False once the training source is exhausted, which ends the epoch
        data_reader_train_flag = model.train()
        if not data_reader_train_flag:
            break
        if iteration % 1000 == 0:
            # Evaluate on at most solver.max_eval_batches batches
            batches = 0
            while data_reader_eval_flag:
                if batches >= solver.max_eval_batches:
                    break
                data_reader_eval_flag = model.eval()
                batches += 1
            # If the eval source ran out, rewind it for the next evaluation
            if not data_reader_eval_flag:
                data_reader_eval.set_source()
                data_reader_eval_flag = True
            metrics = model.get_eval_metrics()
            print("[HUGECTR][INFO] iter: {}, metrics: {}".format(iteration, metrics))
        iteration += 1
model.save_params_to_files("./", iteration)
Writing wdl_epoch.py
!python3 wdl_epoch.py
HugeCTR Version: 3.2.0
====================================================Model Init=====================================================
[28d09h37m13s][HUGECTR][INFO]: Global seed is 2582215374
[28d09h37m14s][HUGECTR][INFO]: Device to NUMA mapping:
  GPU 0 ->  node 0
[28d09h37m16s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.
[28d09h37m16s][HUGECTR][INFO]: Start all2all warmup
[28d09h37m16s][HUGECTR][INFO]: End all2all warmup
[28d09h37m16s][HUGECTR][INFO]: Using All-reduce algorithm NCCL
Device 0: Tesla V100-SXM2-16GB
[28d09h37m16s][HUGECTR][INFO]: num of DataReader workers: 12
[28d09h37m16s][HUGECTR][INFO]: max_num_frequent_categories is not specified using default: 1
[28d09h37m16s][HUGECTR][INFO]: max_num_infrequent_samples is not specified using default: -1
[28d09h37m16s][HUGECTR][INFO]: p_dup_max is not specified using default: 0.010000
[28d09h37m16s][HUGECTR][INFO]: max_all_reduce_bandwidth is not specified using default: 130000000000.000000
[28d09h37m16s][HUGECTR][INFO]: max_all_to_all_bandwidth is not specified using default: 190000000000.000000
[28d09h37m16s][HUGECTR][INFO]: efficiency_bandwidth_ratio is not specified using default: 1.000000
[28d09h37m16s][HUGECTR][INFO]: communication_type is not specified using default: IB_NVLink
[28d09h37m16s][HUGECTR][INFO]: hybrid_embedding_type is not specified using default: Distributed
[28d09h37m16s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=6029312
[28d09h37m16s][HUGECTR][INFO]: max_num_frequent_categories is not specified using default: 1
[28d09h37m16s][HUGECTR][INFO]: max_num_infrequent_samples is not specified using default: -1
[28d09h37m16s][HUGECTR][INFO]: p_dup_max is not specified using default: 0.010000
[28d09h37m16s][HUGECTR][INFO]: max_all_reduce_bandwidth is not specified using default: 130000000000.000000
[28d09h37m16s][HUGECTR][INFO]: max_all_to_all_bandwidth is not specified using default: 190000000000.000000
[28d09h37m16s][HUGECTR][INFO]: efficiency_bandwidth_ratio is not specified using default: 1.000000
[28d09h37m16s][HUGECTR][INFO]: communication_type is not specified using default: IB_NVLink
[28d09h37m16s][HUGECTR][INFO]: hybrid_embedding_type is not specified using default: Distributed
[28d09h37m16s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=5865472
[28d09h37m16s][HUGECTR][INFO]: Load the model graph from wdl.json, successful
===================================================Model Compile===================================================
[28d09h37m19s][HUGECTR][INFO]: gpu0 start to init embedding
[28d09h37m19s][HUGECTR][INFO]: gpu0 init embedding done
[28d09h37m19s][HUGECTR][INFO]: gpu0 start to init embedding
[28d09h37m19s][HUGECTR][INFO]: gpu0 init embedding done
[28d09h37m19s][HUGECTR][INFO]: Starting AUC NCCL warm-up
[28d09h37m19s][HUGECTR][INFO]: Warm-up done
[HUGECTR][INFO] epoch:  0
[HUGECTR][INFO] iter: 0, metrics: [('AUC', 0.5298923850059509)]
[HUGECTR][INFO] iter: 1000, metrics: [('AUC', 0.7407246828079224)]
[HUGECTR][INFO] iter: 2000, metrics: [('AUC', 0.7498546242713928)]
[HUGECTR][INFO] iter: 3000, metrics: [('AUC', 0.7546358704566956)]
[HUGECTR][INFO] epoch:  1
[HUGECTR][INFO] iter: 4000, metrics: [('AUC', 0.7573446035385132)]
[HUGECTR][INFO] iter: 5000, metrics: [('AUC', 0.7208017706871033)]
[HUGECTR][INFO] iter: 6000, metrics: [('AUC', 0.7227433323860168)]
[HUGECTR][INFO] iter: 7000, metrics: [('AUC', 0.7216600775718689)]
[28d09h38m42s][HUGECTR][INFO]: Rank0: Write hash table to file
[28d09h38m42s][HUGECTR][INFO]: Rank0: Write hash table to file
[28d09h38m43s][HUGECTR][INFO]: Dumping sparse weights to files, successful
[28d09h38m43s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[28d09h38m43s][HUGECTR][INFO]: Done
[28d09h38m43s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[28d09h38m43s][HUGECTR][INFO]: Done
[28d09h38m44s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[28d09h38m44s][HUGECTR][INFO]: Done
[28d09h38m44s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[28d09h38m45s][HUGECTR][INFO]: Done
[28d09h38m52s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful
[28d09h38m52s][HUGECTR][INFO]: Dumping dense weights to file, successful
[28d09h38m52s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful
[28d09h38m52s][HUGECTR][INFO]: Dumping untrainable weights to file, successful