# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

Introduction to the HugeCTR Python Interface

Overview

HugeCTR version 3.1 introduces an enhanced Python interface. The interface supports continuous training and inference with high-level APIs. There are four main improvements.

  • First, the model graph can be constructed and dumped to a JSON file with Python code, which saves users from writing JSON configuration files by hand.

  • Second, the API supports the embedding training cache feature with high-level APIs and extends it to online training cases. (To learn about continuous training, see the example notebook.)

  • Third, a freezing method is provided for both the sparse embedding and the dense network. This method enables transfer learning and fine-tuning for CTR tasks.

  • Finally, pre-trained embeddings in other formats can be converted to HugeCTR sparse models and then loaded to facilitate training. This is shown in the Load Pre-trained Embeddings section of this notebook.

This notebook explains how to access and use the enhanced HugeCTR Python interface. Although the low-level training APIs are still maintained for users who want precise control over each training iteration, migrating to the high-level training APIs is strongly recommended. For more details on the Python API, refer to the HugeCTR Python Interface documentation.

Setup

To set up the environment, refer to HugeCTR Example Notebooks and follow the instructions there before running the cells below.

DCN Model

Note: If you have already downloaded the data, skip to the preprocessing step (2). If preprocessing is also done, skip to creating the soft link from the processed data to the notebooks/ directory (3).

Data Preparation

To download and prepare the dataset, we perform the following steps. At the end of this section, we provide the shell commands you can run in a terminal to get the data ready for this notebook.


  1. Download the Criteo dataset

To preprocess the downloaded Kaggle Criteo dataset, we perform the following operations:

  • Reduce the amount of data to speed up preprocessing

  • Fill in missing values

  • Remove feature values that occur very rarely, etc.

  2. Preprocess with Pandas:

    Meanings of the command line arguments:

    • The 1st argument is the dataset postfix. It is 0 here because day_0 is used.

    • The 2nd argument, dcn_data, is the directory where the preprocessed data is stored.

    • The 3rd argument, pandas, selects the preprocessing script; here we use the Pandas-based script.

    • The 4th argument, 1, means that normalization is applied to the dense features.

    • The 5th argument, 0, means that feature crossing is not applied.

    • A 6th argument, such as 100, would set the number of data files in each file list; the command below does not pass one.

    For more details about the data preprocessing, please refer to the “Preprocess the Criteo Dataset” section of the README in the samples/criteo directory of the repository on GitHub.

  3. Create a soft link from the dataset folder to the directory of this notebook
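Because preprocess.sh takes only positional arguments, it is easy to mix them up. The sketch below names each one according to the argument list above and assembles the same command used in Step 2. The helper build_preprocess_cmd is hypothetical and not part of HugeCTR, and treating the 6th argument as optional is an assumption based on the command below omitting it.

```python
import shlex


def build_preprocess_cmd(day, dst_dir, script="pandas", normalize=True,
                         feature_cross=False, files_per_list=None):
    """Assemble the positional arguments for tools/preprocess.sh.

    Hypothetical helper: it only documents the argument order described above.
    """
    args = [
        "bash", "preprocess.sh",
        str(day),                        # 1st: dataset postfix, i.e. day_<day>
        dst_dir,                         # 2nd: output directory for preprocessed data
        script,                          # 3rd: preprocessing script type
        "1" if normalize else "0",       # 4th: normalize dense features
        "1" if feature_cross else "0",   # 5th: apply feature crossing
    ]
    if files_per_list is not None:
        # Assumed optional 6th argument: number of data files per file list.
        args.append(str(files_per_list))
    return args


cmd = build_preprocess_cmd(0, "dcn_data")
print(shlex.join(cmd))  # bash preprocess.sh 0 dcn_data pandas 1 0
```

The resulting list could be passed to subprocess.run from the tools directory, but running the shell commands below directly is equivalent.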

Run the following commands in a terminal to prepare the data for this notebook:

export project_root=/home/hugectr # set this to the directory where hugectr is downloaded
cd ${project_root}/tools
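# Step 1: download day 0 of the Criteo dataset
wget https://storage.googleapis.com/criteo-cail-datasets/day_0.gz
# Step 2: preprocess it with the Pandas script into dcn_data
bash preprocess.sh 0 dcn_data pandas 1 0
# Step 3: link the processed data into the notebooks directory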
ln -s ${project_root}/tools/dcn_data ${project_root}/notebooks/dcn_data

Note: Preprocessing the dataset takes a while (tens of minutes). Make sure it finishes successfully before moving on to the next section.

Train from Scratch

We can train from scratch, dump the model graph to a JSON file, and save the model weights and optimizer states by performing the following steps with the Python APIs:

  1. Create the solver, reader and optimizer, then initialize the model.

  2. Construct the model graph by adding input, sparse embedding and dense layers in order.

  3. Compile the model and print an overview of the model graph.

  4. Dump the model graph to the JSON file.

  5. Fit the model; the model weights and optimizer states are saved implicitly at every snapshot interval (here, at iteration 1000).

Note that the training mode is determined by repeat_dataset in hugectr.CreateSolver. If it is True, non-epoch mode training is used and the maximum number of iterations must be specified with max_iter in hugectr.Model.fit. If it is False, epoch mode training is used and the number of epochs must be specified with num_epochs in hugectr.Model.fit.

The optimizer used to initialize the model applies to the weights of the dense layers, while the optimizer for each sparse embedding layer can be specified independently in hugectr.SparseEmbedding.
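The rule above pairs each value of repeat_dataset with exactly one fit argument. As a sketch of that pairing, the hypothetical helper fit_kwargs (not a HugeCTR API) builds the keyword arguments for hugectr.Model.fit and rejects a mismatched combination. It runs without HugeCTR installed; only the final commented line would touch a real model.

```python
def fit_kwargs(repeat_dataset, max_iter=None, num_epochs=None, **common):
    """Return keyword arguments for hugectr.Model.fit matching the training mode.

    Hypothetical helper: it only encodes the repeat_dataset rule described above.
    """
    kwargs = dict(common)
    if repeat_dataset:
        # Non-epoch mode: the run length is given in iterations.
        if max_iter is None:
            raise ValueError("repeat_dataset=True requires max_iter")
        kwargs["max_iter"] = max_iter
    else:
        # Epoch mode: the run length is given in passes over the dataset.
        if num_epochs is None:
            raise ValueError("repeat_dataset=False requires num_epochs")
        kwargs["num_epochs"] = num_epochs
    return kwargs


# Mirrors the fit call in dcn_train.py, where repeat_dataset = True.
train_kwargs = fit_kwargs(True, max_iter=1200, display=500, eval_interval=100,
                          snapshot=1000, snapshot_prefix="dcn")
# With a compiled hugectr.Model: model.fit(**train_kwargs)
```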

%%writefile dcn_train.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 1500,
                              batchsize_eval = 4096,
                              batchsize = 4096,
                              lr = 0.001,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = True,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                                  source = ["./dcn_data/file_list.txt"],
                                  eval_source = "./dcn_data/file_list_test.txt",
                                  check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.add(hugectr.Input(label_dim = 1, label_name = "label",
                        dense_dim = 13, dense_name = "dense",
                        data_reader_sparse_param_array = 
                        [hugectr.DataReaderSparseParam("data1", 2, False, 26)]))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 264,
                            embedding_vec_size = 16,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding1",
                            bottom_name = "data1",
                            optimizer = optimizer))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding1"],
                            top_names = ["reshape1"],
                            leading_dim=416))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                            bottom_names = ["reshape1", "dense"], top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.MultiCross,
                            bottom_names = ["concat1"],
                            top_names = ["multicross1"],
                            num_layers=6))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["concat1"],
                            top_names = ["fc1"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc1"],
                            top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu1"],
                            top_names = ["dropout1"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["dropout1"],
                            top_names = ["fc2"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc2"],
                            top_names = ["relu2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu2"],
                            top_names = ["dropout2"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                            bottom_names = ["dropout2", "multicross1"],
                            top_names = ["concat2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["concat2"],
                            top_names = ["fc3"],
                            num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
                            bottom_names = ["fc3", "label"],
                            top_names = ["loss"]))
model.compile()
model.summary()
model.graph_to_json(graph_config_file = "dcn.json")
model.fit(max_iter = 1200, display = 500, eval_interval = 100, snapshot = 1000, snapshot_prefix = "dcn")
Writing dcn_train.py
!python3 dcn_train.py
HugeCTR Version: 4.1
====================================================Model Init=====================================================
[HCTR][08:37:13.891][WARNING][RK0][main]: The model name is not specified when creating the solver.
[HCTR][08:37:13.892][INFO][RK0][main]: Global seed is 3840413353
[HCTR][08:37:13.894][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][08:37:15.710][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][08:37:15.710][INFO][RK0][main]: Start all2all warmup
[HCTR][08:37:15.710][INFO][RK0][main]: End all2all warmup
[HCTR][08:37:15.710][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][08:37:15.711][INFO][RK0][main]: Device 0: Tesla V100-SXM2-32GB
[HCTR][08:37:15.712][INFO][RK0][main]: num of DataReader workers for train: 12
[HCTR][08:37:15.712][INFO][RK0][main]: num of DataReader workers for eval: 12
[HCTR][08:37:15.748][INFO][RK0][main]: max_vocabulary_size_per_gpu_=1441792
[HCTR][08:37:15.750][INFO][RK0][main]: Graph analysis to resolve tensor dependency
[HCTR][08:37:15.750][INFO][RK0][main]: Add Slice layer for tensor: concat1, creating 2 copies
===================================================Model Compile===================================================
[HCTR][08:37:29.637][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][08:37:29.638][INFO][RK0][main]: gpu0 init embedding done
[HCTR][08:37:29.640][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][08:37:29.643][INFO][RK0][main]: Warm-up done
===================================================Model Summary===================================================
[HCTR][08:37:29.643][INFO][RK0][main]: Model structure on each GPU
Label                                   Dense                         Sparse                        
label                                   dense                          data1                         
(4096,1)                                (4096,13)                               
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type                              Input Name                    Output Name                   Output Shape                  
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash      data1                         sparse_embedding1             (4096,26,16)                  
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding1             reshape1                      (4096,416)                    
------------------------------------------------------------------------------------------------------------------
Concat                                  reshape1                      concat1                       (4096,429)                    
                                        dense                                                                                     
------------------------------------------------------------------------------------------------------------------
Slice                                   concat1                       concat1_slice0                (4096,429)                    
                                                                      concat1_slice1                (4096,429)                    
------------------------------------------------------------------------------------------------------------------
MultiCross                              concat1_slice0                multicross1                   (4096,429)                    
------------------------------------------------------------------------------------------------------------------
InnerProduct                            concat1_slice1                fc1                           (4096,1024)                   
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc1                           relu1                         (4096,1024)                   
------------------------------------------------------------------------------------------------------------------
Dropout                                 relu1                         dropout1                      (4096,1024)                   
------------------------------------------------------------------------------------------------------------------
InnerProduct                            dropout1                      fc2                           (4096,1024)                   
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc2                           relu2                         (4096,1024)                   
------------------------------------------------------------------------------------------------------------------
Dropout                                 relu2                         dropout2                      (4096,1024)                   
------------------------------------------------------------------------------------------------------------------
Concat                                  dropout2                      concat2                       (4096,1453)                   
                                        multicross1                                                                               
------------------------------------------------------------------------------------------------------------------
InnerProduct                            concat2                       fc3                           (4096,1)                      
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss                  fc3                           loss                                                        
                                        label                                                                                     
------------------------------------------------------------------------------------------------------------------
[HCTR][08:37:29.645][INFO][RK0][main]: Save the model graph to dcn.json successfully
=====================================================Model Fit=====================================================
[HCTR][08:37:29.645][INFO][RK0][main]: Use non-epoch mode with number of iterations: 1200
[HCTR][08:37:29.645][INFO][RK0][main]: Training batchsize: 4096, evaluation batchsize: 4096
[HCTR][08:37:29.645][INFO][RK0][main]: Evaluation interval: 100, snapshot interval: 1000
[HCTR][08:37:29.645][INFO][RK0][main]: Dense network trainable: True
[HCTR][08:37:29.645][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][08:37:29.645][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][08:37:29.645][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][08:37:29.645][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][08:37:29.645][INFO][RK0][main]: Training source file: ./dcn_data/file_list.txt
[HCTR][08:37:29.645][INFO][RK0][main]: Evaluation source file: ./dcn_data/file_list_test.txt
[HCTR][08:37:33.797][INFO][RK0][main]: Evaluation, AUC: 0.722862
[HCTR][08:37:33.797][INFO][RK0][main]: Eval Time for 1500 iters: 3.29509s
[HCTR][08:37:37.855][INFO][RK0][main]: Evaluation, AUC: 0.738291
[HCTR][08:37:37.855][INFO][RK0][main]: Eval Time for 1500 iters: 3.29477s
[HCTR][08:37:41.915][INFO][RK0][main]: Evaluation, AUC: 0.748639
[HCTR][08:37:41.915][INFO][RK0][main]: Eval Time for 1500 iters: 3.2957s
[HCTR][08:37:45.987][INFO][RK0][main]: Evaluation, AUC: 0.753537
[HCTR][08:37:45.987][INFO][RK0][main]: Eval Time for 1500 iters: 3.30702s
[HCTR][08:37:46.752][INFO][RK0][main]: Iter: 500 Time(500 iters): 17.1058s Loss: 0.126047 lr:0.001
[HCTR][08:37:50.048][INFO][RK0][main]: Evaluation, AUC: 0.755874
[HCTR][08:37:50.048][INFO][RK0][main]: Eval Time for 1500 iters: 3.29613s
[HCTR][08:37:54.108][INFO][RK0][main]: Evaluation, AUC: 0.758
[HCTR][08:37:54.108][INFO][RK0][main]: Eval Time for 1500 iters: 3.29422s
[HCTR][08:37:58.166][INFO][RK0][main]: Evaluation, AUC: 0.760666
[HCTR][08:37:58.166][INFO][RK0][main]: Eval Time for 1500 iters: 3.29585s
[HCTR][08:38:02.225][INFO][RK0][main]: Evaluation, AUC: 0.763448
[HCTR][08:38:02.225][INFO][RK0][main]: Eval Time for 1500 iters: 3.29549s
[HCTR][08:38:06.286][INFO][RK0][main]: Evaluation, AUC: 0.764422
[HCTR][08:38:06.286][INFO][RK0][main]: Eval Time for 1500 iters: 3.29466s
[HCTR][08:38:07.051][INFO][RK0][main]: Iter: 1000 Time(500 iters): 20.2983s Loss: 0.113481 lr:0.001
[HCTR][08:38:10.347][INFO][RK0][main]: Evaluation, AUC: 0.755063
[HCTR][08:38:10.347][INFO][RK0][main]: Eval Time for 1500 iters: 3.2954s
[HCTR][08:38:10.347][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:38:10.369][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][08:38:10.429][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][08:38:10.485][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][08:38:10.485][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:38:10.729][INFO][RK0][main]: Done
[HCTR][08:38:10.790][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][08:38:10.790][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:38:11.037][INFO][RK0][main]: Done
[HCTR][08:38:11.042][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HCTR][08:38:11.045][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:38:11.062][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][08:38:11.067][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:38:11.098][INFO][RK0][main]: Dumping dense optimizer states to file, successful
[HCTR][08:38:15.161][INFO][RK0][main]: Evaluation, AUC: 0.741793
[HCTR][08:38:15.162][INFO][RK0][main]: Eval Time for 1500 iters: 3.29657s
[HCTR][08:38:15.922][INFO][RK0][main]: Finish 1200 iterations with batchsize: 4096 in 46.28s.
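The output shapes in the model summary follow from the layer parameters in dcn_train.py, and the arithmetic below reproduces them. The last part checks one reading of the logged max_vocabulary_size_per_gpu_. It assumes, though this notebook does not state it, that each key in the workspace stores its 16-float embedding vector plus Adam's two moment vectors, all as 4-byte floats. Under that assumption the numbers match the log exactly, but treat it as an interpretation rather than HugeCTR's documented formula.

```python
# Values taken from dcn_train.py
num_slots = 26        # slot count in DataReaderSparseParam("data1", 2, False, 26)
emb_vec_size = 16     # embedding_vec_size of sparse_embedding1
dense_dim = 13        # dense_dim of the Input layer
fc_width = 1024       # num_output of fc1 and fc2

reshape1 = num_slots * emb_vec_size   # Reshape leading_dim: 26 * 16
concat1 = reshape1 + dense_dim        # reshaped embeddings + dense features
multicross1 = concat1                 # MultiCross keeps the input width
concat2 = fc_width + multicross1      # dropout2 + multicross1

print(reshape1, concat1, concat2)     # 416 429 1453

# Assumption: per key = embedding vector + 2 Adam moment vectors, FP32 (4 bytes).
workspace_mb = 264                    # workspace_size_per_gpu_in_mb
bytes_per_key = emb_vec_size * 4 * (1 + 2)
max_vocab = workspace_mb * 1024 * 1024 // bytes_per_key
print(max_vocab)                      # 1441792
```

If the assumption holds, switching the sparse optimizer to one with fewer state vectors would let the same workspace hold more keys.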

Continue Training

We can continue training from the saved model graph, model weights, and optimizer states by performing the following steps with the Python APIs:

  1. Create the solver, reader and optimizer, then initialize the model.

  2. Construct the model graph from the saved JSON file with model.construct_from_json.

  3. Compile the model and print an overview of the model graph.

  4. Load the model weights and optimizer states with model.load_dense_weights, model.load_sparse_weights, model.load_dense_optimizer_states, and model.load_sparse_optimizer_states.

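  5. Fit the model. Model weights and optimizer states are saved implicitly at every snapshot interval; because snapshot = 10000 exceeds max_iter = 500 here, this run writes no new snapshot.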
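The file names loaded in dcn_continue.py follow a visible pattern built from snapshot_prefix and the snapshot iteration. The hypothetical helper snapshot_files below derives them; the pattern is inferred from the files used in this notebook, and reading the digit after the prefix (dcn0) as the index of the sparse embedding layer is an assumption.

```python
def snapshot_files(prefix, iteration, num_sparse=1):
    """Return snapshot file names for a snapshot_prefix and iteration.

    Hypothetical helper: the naming pattern is inferred from this notebook
    (dcn_dense_1000.model, dcn0_sparse_1000.model, ...), not from HugeCTR docs.
    """
    return {
        "dense": f"{prefix}_dense_{iteration}.model",
        "dense_opt": f"{prefix}_opt_dense_{iteration}.model",
        # Assumed: one sparse file per embedding layer, indexed from 0.
        "sparse": [f"{prefix}{i}_sparse_{iteration}.model" for i in range(num_sparse)],
        "sparse_opt": [f"{prefix}{i}_opt_sparse_{iteration}.model" for i in range(num_sparse)],
    }


files = snapshot_files("dcn", 1000)
# With a compiled hugectr.Model these would feed:
#   model.load_dense_weights(files["dense"])
#   model.load_sparse_weights(files["sparse"])
#   model.load_dense_optimizer_states(files["dense_opt"])
#   model.load_sparse_optimizer_states(files["sparse_opt"])
```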

%%writefile dcn_continue.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 1500,
                              batchsize_eval = 4096,
                              batchsize = 4096,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = True,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                                  source = ["./dcn_data/file_list.txt"],
                                  eval_source = "./dcn_data/file_list_test.txt",
                                  check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.construct_from_json(graph_config_file = "dcn.json", include_dense_network = True)
model.compile()
model.load_dense_weights("dcn_dense_1000.model")
model.load_sparse_weights(["dcn0_sparse_1000.model"])
model.load_dense_optimizer_states("dcn_opt_dense_1000.model")
model.load_sparse_optimizer_states(["dcn0_opt_sparse_1000.model"])
model.summary()
model.fit(max_iter = 500, display = 50, eval_interval = 100, snapshot = 10000, snapshot_prefix = "dcn")
Writing dcn_continue.py
!python3 dcn_continue.py
HugeCTR Version: 4.1
====================================================Model Init=====================================================
[HCTR][08:38:41.347][WARNING][RK0][main]: The model name is not specified when creating the solver.
[HCTR][08:38:41.347][INFO][RK0][main]: Global seed is 2833570033
[HCTR][08:38:41.349][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][08:38:43.163][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][08:38:43.163][INFO][RK0][main]: Start all2all warmup
[HCTR][08:38:43.164][INFO][RK0][main]: End all2all warmup
[HCTR][08:38:43.164][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][08:38:43.165][INFO][RK0][main]: Device 0: Tesla V100-SXM2-32GB
[HCTR][08:38:43.166][INFO][RK0][main]: num of DataReader workers for train: 12
[HCTR][08:38:43.166][INFO][RK0][main]: num of DataReader workers for eval: 12
[HCTR][08:38:43.203][WARNING][RK0][main]: Embedding vector size(16) is not a multiple of 32, which may affect the GPU resource utilization.
[HCTR][08:38:43.203][INFO][RK0][main]: max_num_frequent_categories is not specified using default: 1
[HCTR][08:38:43.203][INFO][RK0][main]: max_num_infrequent_samples is not specified using default: -1
[HCTR][08:38:43.203][INFO][RK0][main]: p_dup_max is not specified using default: 0.01
[HCTR][08:38:43.203][INFO][RK0][main]: max_all_reduce_bandwidth is not specified using default: 1.3e+11
[HCTR][08:38:43.203][INFO][RK0][main]: max_all_to_all_bandwidth is not specified using default: 1.9e+11
[HCTR][08:38:43.203][INFO][RK0][main]: efficiency_bandwidth_ratio is not specified using default: 1
[HCTR][08:38:43.203][INFO][RK0][main]: communication_type is not specified using default: IB_NVLink
[HCTR][08:38:43.203][INFO][RK0][main]: hybrid_embedding_type is not specified using default: Distributed
[HCTR][08:38:43.203][INFO][RK0][main]: max_vocabulary_size_per_gpu_=1441792
[HCTR][08:38:43.205][INFO][RK0][main]: Load the model graph from dcn.json successfully
[HCTR][08:38:43.205][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][08:38:57.092][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][08:38:57.093][INFO][RK0][main]: gpu0 init embedding done
[HCTR][08:38:57.095][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][08:38:57.099][INFO][RK0][main]: Warm-up done
[HCTR][08:38:57.101][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:38:57.103][INFO][RK0][main]: Loading sparse model: dcn0_sparse_1000.model
[HCTR][08:38:57.103][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:38:57.183][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:38:57.198][INFO][RK0][main]: Loading dense opt states: dcn_opt_dense_1000.model
[HCTR][08:38:57.199][INFO][RK0][main]: Loading sparse optimizer states: dcn0_opt_sparse_1000.model
[HCTR][08:38:57.200][INFO][RK0][main]: Rank0: Read optimzer state from file
[HCTR][08:38:57.200][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:38:57.291][INFO][RK0][main]: Done
[HCTR][08:38:57.291][INFO][RK0][main]: Rank0: Read optimzer state from file
[HCTR][08:38:57.291][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:38:57.376][INFO][RK0][main]: Done
===================================================Model Summary===================================================
[HCTR][08:38:57.376][INFO][RK0][main]: Model structure on each GPU
Label                                   Dense                         Sparse                        
label                                   dense                          data1                         
(4096,1)                                (4096,13)                               
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type                              Input Name                    Output Name                   Output Shape                  
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash      data1                         sparse_embedding1             (4096,26,16)                  
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding1             reshape1                      (4096,416)                    
------------------------------------------------------------------------------------------------------------------
Concat                                  reshape1                      concat1                       (4096,429)                    
                                        dense                                                                                     
------------------------------------------------------------------------------------------------------------------
Slice                                   concat1                       concat1_slice0                (4096,429)                    
                                                                      concat1_slice1                (4096,429)                    
------------------------------------------------------------------------------------------------------------------
MultiCross                              concat1_slice0                multicross1                   (4096,429)                    
------------------------------------------------------------------------------------------------------------------
InnerProduct                            concat1_slice1                fc1                           (4096,1024)                   
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc1                           relu1                         (4096,1024)                   
------------------------------------------------------------------------------------------------------------------
Dropout                                 relu1                         dropout1                      (4096,1024)                   
------------------------------------------------------------------------------------------------------------------
InnerProduct                            dropout1                      fc2                           (4096,1024)                   
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc2                           relu2                         (4096,1024)                   
------------------------------------------------------------------------------------------------------------------
Dropout                                 relu2                         dropout2                      (4096,1024)                   
------------------------------------------------------------------------------------------------------------------
Concat                                  dropout2                      concat2                       (4096,1453)                   
                                        multicross1                                                                               
------------------------------------------------------------------------------------------------------------------
InnerProduct                            concat2                       fc3                           (4096,1)                      
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss                  fc3                           loss                                                        
                                        label                                                                                     
------------------------------------------------------------------------------------------------------------------
=====================================================Model Fit=====================================================
[HCTR][08:38:57.376][INFO][RK0][main]: Use non-epoch mode with number of iterations: 500
[HCTR][08:38:57.376][INFO][RK0][main]: Training batchsize: 4096, evaluation batchsize: 4096
[HCTR][08:38:57.376][INFO][RK0][main]: Evaluation interval: 100, snapshot interval: 10000
[HCTR][08:38:57.376][INFO][RK0][main]: Dense network trainable: True
[HCTR][08:38:57.376][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][08:38:57.376][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][08:38:57.376][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][08:38:57.376][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][08:38:57.376][INFO][RK0][main]: Training source file: ./dcn_data/file_list.txt
[HCTR][08:38:57.376][INFO][RK0][main]: Evaluation source file: ./dcn_data/file_list_test.txt
[HCTR][08:38:57.853][INFO][RK0][main]: Iter: 50 Time(50 iters): 0.475771s Loss: 0.113989 lr:0.001
[HCTR][08:38:58.237][INFO][RK0][main]: Iter: 100 Time(50 iters): 0.382273s Loss: 0.105428 lr:0.001
[HCTR][08:39:01.531][INFO][RK0][main]: Evaluation, AUC: 0.746004
[HCTR][08:39:01.531][INFO][RK0][main]: Eval Time for 1500 iters: 3.29415s
[HCTR][08:39:01.915][INFO][RK0][main]: Iter: 150 Time(50 iters): 3.67713s Loss: 0.112908 lr:0.001
[HCTR][08:39:02.299][INFO][RK0][main]: Iter: 200 Time(50 iters): 0.382364s Loss: 0.110116 lr:0.001
[HCTR][08:39:05.592][INFO][RK0][main]: Evaluation, AUC: 0.743096
[HCTR][08:39:05.592][INFO][RK0][main]: Eval Time for 1500 iters: 3.29324s
[HCTR][08:39:05.976][INFO][RK0][main]: Iter: 250 Time(50 iters): 3.67566s Loss: 0.113728 lr:0.001
[HCTR][08:39:06.359][INFO][RK0][main]: Iter: 300 Time(50 iters): 0.381715s Loss: 0.114037 lr:0.001
[HCTR][08:39:09.651][INFO][RK0][main]: Evaluation, AUC: 0.744914
[HCTR][08:39:09.651][INFO][RK0][main]: Eval Time for 1500 iters: 3.29269s
[HCTR][08:39:10.036][INFO][RK0][main]: Iter: 350 Time(50 iters): 3.67574s Loss: 0.100788 lr:0.001
[HCTR][08:39:10.418][INFO][RK0][main]: Iter: 400 Time(50 iters): 0.381021s Loss: 0.119661 lr:0.001
[HCTR][08:39:13.713][INFO][RK0][main]: Evaluation, AUC: 0.739262
[HCTR][08:39:13.713][INFO][RK0][main]: Eval Time for 1500 iters: 3.29468s
[HCTR][08:39:14.094][INFO][RK0][main]: Iter: 450 Time(50 iters): 3.67472s Loss: 0.122573 lr:0.001
[HCTR][08:39:14.466][INFO][RK0][main]: Finish 500 iterations with batchsize: 4096 in 17.09s.

Inference

HugeCTR inference is performed with the hugectr.inference.InferenceSession.predict method. This method takes dense features, embedding columns, and row pointers of slots as input and returns the prediction results. First, we need to convert the Criteo data to the inference format.

!python3 ../tools/criteo_predict/criteo2predict.py --src_csv_path=dcn_data/val/test.txt --src_config=../tools/criteo_predict/dcn_data.json --dst_path=./dcn_csr.txt --batch_size=1024
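The predict method expects the sparse input in a compressed sparse row (CSR) style: a flat list of embedding keys plus row pointers that mark where each row starts in that list. The following minimal sketch assumes one row per (sample, slot) pair, which is our reading of "row pointers of slots". The helper build_csr_inputs is hypothetical and is not part of HugeCTR or criteo2predict.py.

```python
from typing import List, Tuple


def build_csr_inputs(samples: List[List[List[int]]]) -> Tuple[List[int], List[int]]:
    """Flatten per-sample, per-slot key lists into (embedding_columns, row_ptrs).

    Hypothetical helper. samples[i][s] is the list of keys of slot s in sample i.
    row_ptrs starts with 0 and gets one entry per (sample, slot) pair, so the keys
    of row r are embedding_columns[row_ptrs[r]:row_ptrs[r + 1]].
    """
    embedding_columns: List[int] = []
    row_ptrs: List[int] = [0]
    for sample in samples:
        for slot_keys in sample:
            embedding_columns.extend(slot_keys)
            row_ptrs.append(len(embedding_columns))
    return embedding_columns, row_ptrs


# Two samples with two slots each; slot 1 of sample 0 is multi-hot.
samples = [[[10], [20, 21]], [[11], [22]]]
cols, ptrs = build_csr_inputs(samples)
print(cols)  # [10, 20, 21, 11, 22]
print(ptrs)  # [0, 1, 3, 4, 5]
```

With this layout, len(row_ptrs) equals batch_size × num_slots + 1, and a multi-hot slot simply spans more than one entry of embedding_columns.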

We can then run inference based on the saved model graph and model weights by performing the following steps with the Python APIs:

  1. Configure the inference related parameters.

  2. Create the inference session.

  3. Run inference with the hugectr.inference.InferenceSession.predict method.

%%writefile dcn_inference.py
from hugectr.inference import InferenceParams, CreateInferenceSession
from mpi4py import MPI

def calculate_accuracy(labels, output):
    num_samples = len(labels)
    flags = [1 if ((labels[i] == 0 and output[i] <= 0.5) or (labels[i] == 1 and output[i] > 0.5)) else 0 for i in range(num_samples)]
    correct_samples = sum(flags)
    return float(correct_samples)/(float(num_samples)+1e-16)

config_file = "dcn.json"
# The first four lines of dcn_csr.txt hold labels, dense features, embedding columns, and row pointers
with open("dcn_csr.txt") as data_file:
    labels = [int(item) for item in data_file.readline().split()]
    dense_features = [float(item) for item in data_file.readline().split()]
    embedding_columns = [int(item) for item in data_file.readline().split()]
    row_ptrs = [int(item) for item in data_file.readline().split()]

# create parameter server, embedding cache and inference session
inference_params = InferenceParams(model_name = "dcn",
                                max_batchsize = 1024,
                                hit_rate_threshold = 0.6,
                                dense_model_file = "./dcn_dense_1000.model",
                                sparse_model_files = ["./dcn0_sparse_1000.model"],
                                device_id = 0,
                                use_gpu_embedding_cache = True,
                                cache_size_percentage = 0.9,
                                i64_input_key = False,
                                use_mixed_precision = False)
inference_session = CreateInferenceSession(config_file, inference_params)
output = inference_session.predict(dense_features, embedding_columns, row_ptrs)
accuracy = calculate_accuracy(labels, output)
print("[HUGECTR][INFO] number samples: {}, accuracy: {}".format(len(labels), accuracy))
Writing dcn_inference.py
!python3 dcn_inference.py
[HCTR][08:41:25.641][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
====================================================HPS Create====================================================
[HCTR][08:41:25.641][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][08:41:25.642][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][08:41:25.642][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][08:41:25.642][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][08:41:25.642][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][08:41:25.642][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:41:25.891][INFO][RK0][main]: Table: hps_et.dcn.sparse_embedding1; cached 282873 / 282873 embeddings in volatile database (HashMapBackend); load: 282873 / 18446744073709551615 (0.00%).
[HCTR][08:41:25.892][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][08:41:25.892][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][08:41:25.898][INFO][RK0][main]: Model name: dcn
[HCTR][08:41:25.898][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][08:41:25.898][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.900000
[HCTR][08:41:25.898][INFO][RK0][main]: Use I64 input key: False
[HCTR][08:41:25.898][INFO][RK0][main]: Configured cache hit rate threshold: 0.600000
[HCTR][08:41:25.898][INFO][RK0][main]: The size of thread pool: 80
[HCTR][08:41:25.898][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][08:41:25.898][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][08:41:25.898][INFO][RK0][main]: The refresh percentage : 0.000000
[HCTR][08:41:26.851][INFO][RK0][main]: Global seed is 1715681389
[HCTR][08:41:26.854][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][08:41:27.788][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][08:41:27.788][INFO][RK0][main]: Start all2all warmup
[HCTR][08:41:27.788][INFO][RK0][main]: End all2all warmup
[HCTR][08:41:27.788][INFO][RK0][main]: Model name: dcn
[HCTR][08:41:27.788][INFO][RK0][main]: Use mixed precision: False
[HCTR][08:41:27.788][INFO][RK0][main]: Use cuda graph: True
[HCTR][08:41:27.788][INFO][RK0][main]: Max batchsize: 1024
[HCTR][08:41:27.788][INFO][RK0][main]: Use I64 input key: False
[HCTR][08:41:27.788][INFO][RK0][main]: start create embedding for inference
[HCTR][08:41:27.788][INFO][RK0][main]: sparse_input name data1
[HCTR][08:41:27.788][INFO][RK0][main]: create embedding for inference success
[HCTR][08:41:27.789][INFO][RK0][main]: Inference stage skip BinaryCrossEntropyLoss layer, replaced by Sigmoid layer
[HUGECTR][INFO] number samples: 1024, accuracy: 0.970703125
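The accuracy printed above is the fraction of samples whose thresholded prediction matches the label. With 1024 samples, 0.970703125 corresponds to 994 correct predictions (994 / 1024). The following sketch is a vectorized NumPy equivalent of calculate_accuracy from dcn_inference.py. It assumes 0/1 labels and the same 0.5 threshold. The name calculate_accuracy_np is ours.

```python
import numpy as np


def calculate_accuracy_np(labels, output, threshold=0.5):
    """Vectorized equivalent of calculate_accuracy in dcn_inference.py.

    A sample is correct when label 1 has output > threshold, or label 0 has
    output <= threshold. Returns 0.0 for an empty input, like the original.
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        return 0.0
    predicted = (np.asarray(output) > threshold).astype(labels.dtype)
    return float(np.mean(predicted == labels))


print(calculate_accuracy_np([1, 0, 1, 0], [0.9, 0.2, 0.4, 0.7]))  # 0.5
```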

Wide and Deep Model

Download and Preprocess Data

  1. Download the Criteo dataset using the following command:

    During preprocessing, we further reduce the amount of data to speed up processing, fill in missing values, and remove feature values that occur very rarely. Here we use the pandas preprocessing method to prepare the dataset for HugeCTR training.

  2. Preprocess the data with pandas using the following command:

    The first argument is the dataset postfix. It is 1 here because day_1 is used. The second argument, wdl_data, is the directory where the preprocessed data is stored. The fourth argument (the one after pandas), 1, indicates that normalization is applied to the dense features. The fifth argument, 1, indicates that feature crossing is applied. The last argument, 100, is the number of data files in each file list.

  3. Create a soft link to the dataset folder using the following command:

Run the following commands in the terminal to prepare the data for this notebook:

export project_root=/home/hugectr # set this to the directory where hugectr is downloaded
cd ${project_root}/tools
# Step 1
wget https://storage.googleapis.com/criteo-cail-datasets/day_0.gz
# Step 2
bash preprocess.sh 1 wdl_data pandas 1 1 100
# Step 3
ln -s ${project_root}/tools/wdl_data ${project_root}/notebooks/wdl_data

Note: Preprocessing the dataset takes a while (tens of minutes). Make sure that it finishes successfully before moving on to the next section.
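One way to confirm that preprocessing finished is to check the generated file lists before training. The following sketch makes two assumptions. First, it assumes the Norm file-list layout: a file count on the first line, then one data-file path per line. Second, it assumes the listed paths resolve relative to the current working directory. The helper check_file_list is hypothetical.

```python
import os
from typing import List


def check_file_list(list_path: str) -> List[str]:
    """Return the data files named in a Norm-format file list (hypothetical helper).

    Assumed layout: first line = number of files, then one path per line.
    Raises ValueError on a count mismatch and FileNotFoundError on missing files.
    """
    with open(list_path) as f:
        lines = [line.strip() for line in f if line.strip()]
    if not lines:
        raise ValueError(f"{list_path} is empty")
    expected = int(lines[0])
    paths = lines[1:]
    if len(paths) != expected:
        raise ValueError(f"{list_path}: header says {expected} files, found {len(paths)}")
    missing = [p for p in paths if not os.path.isfile(p)]
    if missing:
        raise FileNotFoundError(f"{list_path}: missing files, e.g. {missing[:3]}")
    return paths
```

For example, the following loop checks the four file lists that the Wide&Deep scripts use:

```python
for i in range(4):
    print(len(check_file_list(f"wdl_data/file_list.{i}.txt")))
```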

Train from Scratch

We can train from scratch, dump the model graph to a JSON file, and save the model weights and optimizer states by performing the same steps that we followed with the DCN model.

%%writefile wdl_train.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 5000,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              lr = 0.001,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = False,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                          source = ["wdl_data/file_list.0.txt"],
                          eval_source = "wdl_data/file_list.1.txt",
                          check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.add(hugectr.Input(label_dim = 1, label_name = "label",
                        dense_dim = 13, dense_name = "dense",
                        data_reader_sparse_param_array = 
                        [hugectr.DataReaderSparseParam("wide_data", 30, True, 1),
                        hugectr.DataReaderSparseParam("deep_data", 2, False, 26)]))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 69,
                            embedding_vec_size = 1,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding2",
                            bottom_name = "wide_data",
                            optimizer = optimizer))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 1074,
                            embedding_vec_size = 16,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding1",
                            bottom_name = "deep_data",
                            optimizer = optimizer))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding1"],
                            top_names = ["reshape1"],
                            leading_dim=416))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding2"],
                            top_names = ["reshape2"],
                            leading_dim=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                            bottom_names = ["reshape1", "dense"], top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["concat1"],
                            top_names = ["fc1"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc1"],
                            top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu1"],
                            top_names = ["dropout1"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["dropout1"],
                            top_names = ["fc2"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc2"],
                            top_names = ["relu2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu2"],
                            top_names = ["dropout2"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["dropout2"],
                            top_names = ["fc3"],
                            num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Add,
                            bottom_names = ["fc3", "reshape2"],
                            top_names = ["add1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
                            bottom_names = ["add1", "label"],
                            top_names = ["loss"]))
model.compile()
model.summary()
model.graph_to_json(graph_config_file = "wdl.json")
model.fit(num_epochs = 1, display = 500, eval_interval = 500, snapshot = 4000, snapshot_prefix = "wdl")
!python3 wdl_train.py
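The leading_dim values in the script follow from the slot and embedding settings. The deep part has 26 slots with 16-dimensional vectors. The wide part has a single slot with 1-dimensional vectors. The sketch below reproduces this bookkeeping. It assumes the embedding output is laid out as (batch, slot_num, embedding_vec_size), which is consistent with the leading_dim values used here. flattened_dim is an illustrative name, not a HugeCTR function.

```python
def flattened_dim(slot_num: int, embedding_vec_size: int) -> int:
    # Assumed embedding output shape: (batch, slot_num, embedding_vec_size).
    # A Reshape with leading_dim = slot_num * embedding_vec_size flattens it to 2-D.
    return slot_num * embedding_vec_size


dense_dim = 13
deep = flattened_dim(slot_num=26, embedding_vec_size=16)  # reshape1
wide = flattened_dim(slot_num=1, embedding_vec_size=1)    # reshape2

# wdl_train.py: concat1 = [reshape1, dense]; reshape2 is added to fc3 (num_output=1).
train_fc1_in = deep + dense_dim
# wdl_fine_tune.py: concat1 = [reshape1, reshape2, dense].
fine_tune_fc1_in = deep + wide + dense_dim
print(deep, wide, train_fc1_in, fine_tune_fc1_in)  # 416 1 429 430
```

The Add layer requires matching widths, so the wide output (1) lines up with fc3, whose num_output is 1.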

Fine-tuning

We can load only the sparse embedding layers and their corresponding weights, and then construct a new dense network on top of them. The dense weights are trained first, and the sparse weights are fine-tuned afterward. We can achieve this by performing the following steps with the Python APIs:

  1. Create the solver, reader and optimizer, then initialize the model.

  2. Load the sparse embedding layers from the saved JSON file.

  3. Add the dense layers on top of the loaded model graph.

  4. Compile the model and print a summary of the model graph.

  5. Load the sparse weights and freeze the sparse embedding layers.

  6. Train the dense weights.

  7. Unfreeze the sparse embedding layers, freeze the dense layers, and reset the learning rate scheduler with a smaller base learning rate.

  8. Fine-tune the sparse weights.
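The freeze/unfreeze schedule in steps 5 to 8 can be illustrated without HugeCTR. The toy NumPy sketch below is not HugeCTR code. A hypothetical model predicts embedding[key] · dense_w and trains with squared error. Freezing a group means its gradient is computed but not applied. The second phase uses a smaller learning rate, mirroring reset_learning_rate_scheduler(base_lr = 0.0001).

```python
import numpy as np


def train_step(params, keys, targets, lr, freeze_embedding=False, freeze_dense=False):
    """One SGD step on a toy model: prediction = embedding[key] . dense_w (MSE loss).

    Frozen groups still take part in the forward pass; only their update is skipped.
    """
    emb_rows = params["embedding"][keys]              # (n, d)
    err = emb_rows @ params["dense_w"] - targets      # (n,)
    n = len(keys)
    grad_w = 2.0 / n * emb_rows.T @ err               # dMSE/d(dense_w)
    grad_emb = np.zeros_like(params["embedding"])
    # Scatter-add per-sample gradients into the embedding rows that were looked up.
    np.add.at(grad_emb, keys, 2.0 / n * err[:, None] * params["dense_w"][None, :])
    if not freeze_dense:
        params["dense_w"] -= lr * grad_w
    if not freeze_embedding:
        params["embedding"] -= lr * grad_emb
    return float(np.mean(err ** 2))


rng = np.random.default_rng(0)
params = {"embedding": rng.normal(size=(8, 4)), "dense_w": rng.normal(size=4)}
keys = rng.integers(0, 8, size=32)
targets = rng.normal(size=32)

emb_before = params["embedding"].copy()
for _ in range(50):  # Phase 1: embeddings frozen, dense weights trained.
    train_step(params, keys, targets, lr=0.05, freeze_embedding=True)

dense_before = params["dense_w"].copy()
for _ in range(50):  # Phase 2: dense frozen, embeddings fine-tuned with a smaller lr.
    train_step(params, keys, targets, lr=0.005, freeze_dense=True)
```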

%%writefile wdl_fine_tune.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 5000,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = False,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                          source = ["wdl_data/file_list.2.txt"],
                          eval_source = "wdl_data/file_list.3.txt",
                          check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.construct_from_json(graph_config_file = "wdl.json", include_dense_network = False)
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding1"],
                            top_names = ["reshape1"],
                            leading_dim=416))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding2"],
                            top_names = ["reshape2"],
                            leading_dim=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                            bottom_names = ["reshape1", "reshape2", "dense"], top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["concat1"],
                            top_names = ["fc1"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc1"],
                            top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu1"],
                            top_names = ["dropout1"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["dropout1"],
                            top_names = ["fc2"],
                            num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
                            bottom_names = ["fc2", "label"],
                            top_names = ["loss"]))
model.compile()
model.summary()
model.load_sparse_weights(["wdl0_sparse_4000.model", "wdl1_sparse_4000.model"])
model.freeze_embedding()
model.fit(num_epochs = 1, display = 500, eval_interval = 1000, snapshot = 100000, snapshot_prefix = "wdl")
model.unfreeze_embedding()
model.freeze_dense()
model.reset_learning_rate_scheduler(base_lr = 0.0001)
model.fit(num_epochs = 2, display = 500, eval_interval = 1000, snapshot = 100000, snapshot_prefix = "wdl")
!python3 wdl_fine_tune.py

Load Pre-trained Embeddings

If you have pre-trained embeddings in other formats, you can convert them to HugeCTR sparse models and then load them to facilitate the training process. For simplicity and generality, we represent the pre-trained embeddings with a dictionary of randomly initialized NumPy arrays, where the dictionary keys are the embedding keys and the array values are the embedding vectors. Note that the Wide&Deep model has two embedding tables. Here we load the pre-trained embeddings for only one table (sparse_embedding1) and freeze the corresponding embedding layer.

%%writefile wdl_load_pretrained.py
import hugectr
from mpi4py import MPI
import numpy as np
import os
import struct
solver = hugectr.CreateSolver(max_eval_batches = 5000,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = False,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                          source = ["wdl_data/file_list.0.txt"],
                          eval_source = "wdl_data/file_list.1.txt",
                          check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.construct_from_json(graph_config_file = "wdl.json", include_dense_network = True)
model.compile()
model.summary()

def convert_pretrained_embeddings_to_sparse_model(pre_trained_sparse_embeddings, hugectr_sparse_model, embedding_vec_size):
    # Write a directory with a "key" file (int64 keys) and an "emb_vector" file
    # (float32 vectors of embedding_vec_size values, stored in the same order as the keys)
    os.makedirs(hugectr_sparse_model, exist_ok=True)
    with open("{}/key".format(hugectr_sparse_model), 'wb') as key_file, \
        open("{}/emb_vector".format(hugectr_sparse_model), 'wb') as vec_file:
        for key, vec in pre_trained_sparse_embeddings.items():
            key_file.write(struct.pack('q', key))
            vec_file.write(struct.pack(str(embedding_vec_size) + "f", *vec))

# Convert the pretrained embeddings
pretrained_embeddings = dict()
hugectr_sparse_model = "wdl1_pretrained.model"
embedding_vec_size = 16
key_range = (0, 100000)
for key in range(key_range[0], key_range[1]):
    pretrained_embeddings[key] = np.random.randn(embedding_vec_size).astype(np.float32)
convert_pretrained_embeddings_to_sparse_model(pretrained_embeddings, hugectr_sparse_model, embedding_vec_size)
print("Successfully converted pretrained embeddings to {}".format(hugectr_sparse_model))

# Load the pretrained sparse models
model.load_sparse_weights({"sparse_embedding1": hugectr_sparse_model})
model.freeze_embedding("sparse_embedding1")
model.fit(num_epochs = 1, display = 500, eval_interval = 1000, snapshot = 100000, snapshot_prefix = "wdl")
!python3 wdl_load_pretrained.py
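To sanity-check the converted files, you can read them back. This sketch assumes exactly the layout that convert_pretrained_embeddings_to_sparse_model writes above. That means native-endian int64 keys in key ('q'), and float32 vectors in emb_vector, in the same order. read_sparse_model is a hypothetical helper, not a HugeCTR API.

```python
import os

import numpy as np


def read_sparse_model(model_dir: str, embedding_vec_size: int):
    """Read back the key / emb_vector files written by the conversion function."""
    keys = np.fromfile(os.path.join(model_dir, "key"), dtype=np.int64)
    vectors = np.fromfile(os.path.join(model_dir, "emb_vector"), dtype=np.float32)
    if vectors.size != keys.size * embedding_vec_size:
        raise ValueError("key and emb_vector sizes do not match")
    return keys, vectors.reshape(keys.size, embedding_vec_size)
```

For the model converted above, read_sparse_model("wdl1_pretrained.model", 16) should return 100000 keys and a (100000, 16) array.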

Low-level Training

The enhanced HugeCTR Python interface still maintains the low-level training APIs. If you want precise control over each training iteration and each evaluation step, these APIs can be helpful. Because the data reader behaves differently in epoch mode and non-epoch mode, pay attention to how the data reader is handled when using low-level training. We demonstrate how to write low-level training scripts for non-epoch mode, epoch mode, and embedding training cache mode.
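The two modes differ in what ends the loop. The non-epoch script sets repeat_dataset = True and runs a fixed max_iter. The epoch script sets repeat_dataset = False, stops when model.train() returns False, and calls set_source() to rewind the reader at each epoch. The toy simulation below mirrors that control flow with a hypothetical MockReader class. It does not call HugeCTR, and all of its names are illustrative only.

```python
class MockReader:
    """Hypothetical stand-in for a data reader; NOT the HugeCTR API.

    Serves num_batches batches per pass. With repeat=True it wraps around
    (non-epoch mode); with repeat=False it returns False once exhausted,
    until set_source() rewinds it (epoch mode).
    """

    def __init__(self, num_batches: int, repeat: bool):
        self.num_batches = num_batches
        self.repeat = repeat
        self.pos = 0

    def set_source(self) -> None:
        self.pos = 0

    def next_batch(self) -> bool:
        if self.pos >= self.num_batches:
            if not self.repeat:
                return False
            self.pos = 0
        self.pos += 1
        return True


def run_non_epoch(reader: MockReader, max_iter: int) -> int:
    # Non-epoch mode: a fixed iteration budget; a repeating reader never runs dry.
    steps = 0
    for _ in range(max_iter):
        if not reader.next_batch():
            raise RuntimeError("reader exhausted; non-epoch mode expects repeat=True")
        steps += 1
    return steps


def run_epoch(reader: MockReader, num_epochs: int) -> int:
    # Epoch mode: rewind at the start of each epoch, stop when the data runs out.
    steps = 0
    for _ in range(num_epochs):
        reader.set_source()
        while reader.next_batch():
            steps += 1
    return steps


print(run_non_epoch(MockReader(num_batches=5, repeat=True), max_iter=12))  # 12
print(run_epoch(MockReader(num_batches=5, repeat=False), num_epochs=2))    # 10
```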

%%writefile wdl_non_epoch.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 5000,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = True,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                          source = ["wdl_data/file_list.0.txt"],
                          eval_source = "wdl_data/file_list.1.txt",
                          check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.construct_from_json(graph_config_file = "wdl.json", include_dense_network = True)
model.compile()
model.start_data_reading()
lr_sch = model.get_learning_rate_scheduler()
max_iter = 2000
for i in range(max_iter):
    lr = lr_sch.get_next()
    model.set_learning_rate(lr)
    model.train()
    if (i%100 == 0):
        loss = model.get_current_loss()
        print("[HUGECTR][INFO] iter: {}; loss: {}".format(i, loss))
    if (i%1000 == 0 and i != 0):
        for _ in range(solver.max_eval_batches):
            model.eval()
        metrics = model.get_eval_metrics()
        print("[HUGECTR][INFO] iter: {}, {}".format(i, metrics))
model.save_params_to_files("./", max_iter)
!python3 wdl_non_epoch.py
%%writefile wdl_epoch.py
import hugectr
from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 5000,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              vvgpu = [[0]],
                              i64_input_key = False,
                              use_mixed_precision = False,
                              repeat_dataset = False,
                              use_cuda_graph = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                          source = ["wdl_data/file_list.0.txt"],
                          eval_source = "wdl_data/file_list.1.txt",
                          check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam)
model = hugectr.Model(solver, reader, optimizer)
model.construct_from_json(graph_config_file = "wdl.json", include_dense_network = True)
model.compile()
lr_sch = model.get_learning_rate_scheduler()
data_reader_train = model.get_data_reader_train()
data_reader_eval = model.get_data_reader_eval()
data_reader_eval.set_source()
data_reader_eval_flag = True
iteration = 0
for epoch in range(2):
    print("[HUGECTR][INFO] epoch: ", epoch)
    # Rewind the training data reader at the start of every epoch
    data_reader_train.set_source()
    data_reader_train_flag = True
    while True:
        lr = lr_sch.get_next()
        model.set_learning_rate(lr)
        # model.train() returns False once the training data for this epoch is exhausted
        data_reader_train_flag = model.train()
        if not data_reader_train_flag:
            break
        if iteration % 1000 == 0:
            batches = 0
            while data_reader_eval_flag:
                if batches >= solver.max_eval_batches:
                    break
                data_reader_eval_flag = model.eval()
                batches += 1
            # Rewind the evaluation data reader if it ran out of data
            if not data_reader_eval_flag:
                data_reader_eval.set_source()
                data_reader_eval_flag = True
            metrics = model.get_eval_metrics()
            print("[HUGECTR][INFO] iter: {}, metrics: {}".format(iteration, metrics))
        iteration += 1
!python3 wdl_epoch.py