
HugeCTR training with HDFS example

Overview

HugeCTR v3.4 introduced support for HDFS. You can now move data and model files from HDFS to the local filesystem through the HugeCTR API and use them for training. After training, you can dump the trained parameters and optimizer states to HDFS. This example notebook demonstrates the end-to-end procedure of training with HDFS.

Get HugeCTR from NGC

The HugeCTR Python module is preinstalled in the 22.04 and later Merlin Training Container: nvcr.io/nvidia/merlin/merlin-training:22.04.

You can check that the required libraries are available by running the following command after launching the container.

$ python3 -c "import hugectr"
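
The same check can also be done from inside Python without raising a traceback when a module is missing. The following is a minimal sketch; the helper name module_available is hypothetical and is not part of HugeCTR.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    # hugectr and mpi4py are the two modules imported by the training script.
    for name in ("hugectr", "mpi4py"):
        status = "found" if module_available(name) else "MISSING"
        print(f"{name}: {status}")
```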

If you prefer to build HugeCTR from the source code instead of using the NGC container, refer to the How to Start Your Development documentation.

Hadoop Installation and Configuration

Download and Install Hadoop

  1. Download a JDK:

    wget https://download.java.net/java/GA/jdk16.0.2/d4a915d82b4c4fbb9bde534da945d746/7/GPL/openjdk-16.0.2_linux-x64_bin.tar.gz
    tar -zxvf openjdk-16.0.2_linux-x64_bin.tar.gz
    mv jdk-16.0.2 /usr/local
    
  2. Set Java environment variables:

    export JAVA_HOME=/usr/local/jdk-16.0.2
    export JRE_HOME=${JAVA_HOME}/jre
    export CLASSPATH=.:${JAVA_HOME}/lib:${JRE_HOME}/lib
    export PATH=.:${JAVA_HOME}/bin:$PATH
    
  3. Download and install Hadoop:

    wget https://dlcdn.apache.org/hadoop/common/hadoop-3.3.1/hadoop-3.3.1.tar.gz
    tar -zxvf hadoop-3.3.1.tar.gz
    mv hadoop-3.3.1 /usr/local/hadoop
    

Hadoop Configuration

Set Hadoop environment variables:

export HDFS_NAMENODE_USER=root
export HDFS_DATANODE_USER=root
export HDFS_SECONDARYNAMENODE_USER=root
export YARN_RESOURCEMANAGER_USER=root
export YARN_NODEMANAGER_USER=root
echo 'export JAVA_HOME=/usr/local/jdk-16.0.2' >> /usr/local/hadoop/etc/hadoop/hadoop-env.sh

core-site.xml config:

<property>
    <name>fs.defaultFS</name>
    <value>hdfs://namenode:9000</value>
</property>
<property>
    <name>hadoop.tmp.dir</name>
    <value>/usr/local/hadoop/tmp</value>
</property>

hdfs-site.xml for name node:

<property>
     <name>dfs.replication</name>
     <value>4</value>
</property>
<property>
     <name>dfs.namenode.name.dir</name>
     <value>file:/usr/local/hadoop/hdfs/name</value>
</property>
<property>
     <name>dfs.client.block.write.replace-datanode-on-failure.enable</name>         
     <value>true</value>
</property>
<property>
    <name>dfs.client.block.write.replace-datanode-on-failure.policy</name>
    <value>NEVER</value>
</property>

hdfs-site.xml for data node:

<property>
     <name>dfs.replication</name>
     <value>4</value>
</property>
<property>
     <name>dfs.datanode.data.dir</name>
     <value>file:/usr/local/hadoop/hdfs/data</value>
</property>
<property>
     <name>dfs.client.block.write.replace-datanode-on-failure.enable</name>         
     <value>true</value>
</property>
<property>
    <name>dfs.client.block.write.replace-datanode-on-failure.policy</name>
    <value>NEVER</value>
</property>
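
The fragments above list only the <property> elements. In a complete Hadoop configuration file they sit inside a single <configuration> root element. As a rough sketch of how such a file could be generated and read back, the helpers build_site_xml and parse_site_xml below are hypothetical names introduced for illustration; the property names and values are the ones shown in the core-site.xml fragment.

```python
import xml.etree.ElementTree as ET
from xml.sax.saxutils import escape


def build_site_xml(properties: dict) -> str:
    """Wrap name/value pairs in the <configuration> root element."""
    lines = ["<configuration>"]
    for name, value in properties.items():
        lines.append("    <property>")
        lines.append(f"        <name>{escape(name)}</name>")
        lines.append(f"        <value>{escape(str(value))}</value>")
        lines.append("    </property>")
    lines.append("</configuration>")
    return "\n".join(lines) + "\n"


def parse_site_xml(xml_text: str) -> dict:
    """Read the name/value pairs back, which also checks well-formedness."""
    root = ET.fromstring(xml_text)
    return {p.findtext("name"): p.findtext("value") for p in root.iter("property")}


# Values copied from the core-site.xml fragment in this section.
core_site = build_site_xml({
    "fs.defaultFS": "hdfs://namenode:9000",
    "hadoop.tmp.dir": "/usr/local/hadoop/tmp",
})

if __name__ == "__main__":
    print(core_site)
    print(parse_site_xml(core_site))
```

Writing the returned string to /usr/local/hadoop/etc/hadoop/core-site.xml is left out on purpose so that the sketch does not overwrite an existing configuration.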

workers file for all nodes:

worker1
worker2
worker3
worker4

Start HDFS

  1. Enable ssh connection:

    ssh-keygen -t rsa -P '' -f ~/.ssh/id_rsa
    cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys
    /etc/init.d/ssh start
    
  2. Format the NameNode:

    /usr/local/hadoop/bin/hdfs namenode -format
    
  3. Format the DataNodes:

    /usr/local/hadoop/bin/hdfs datanode -format
    
  4. Start HDFS from the NameNode:

    /usr/local/hadoop/sbin/start-dfs.sh
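
After start-dfs.sh returns, one quick sanity check is to confirm that the NameNode accepts TCP connections on the port configured in fs.defaultFS. The sketch below assumes the host name namenode and port 9000 from the core-site.xml fragment; the helper port_is_open is a hypothetical name, and a successful connection only shows that something is listening, not that HDFS is healthy.

```python
import socket


def port_is_open(host: str, port: int, timeout: float = 3.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and name-resolution failures.
        return False


if __name__ == "__main__":
    # "namenode" and 9000 mirror hdfs://namenode:9000 in core-site.xml.
    print("NameNode port reachable:", port_is_open("namenode", 9000))
```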
    

Wide and Deep Model

In the Docker container, nvcr.io/nvidia/merlin/merlin-training:22.04, make sure that you installed Hadoop and set the proper environment variables as instructed in the preceding sections.

If you chose to compile HugeCTR, make sure that you set -DENABLE_HDFS=ON.

  • Run export CLASSPATH=$(hadoop classpath --glob) first to link the required JAR file.

  • Make sure that the model files are present in your Hadoop cluster and provide the correct paths to them.
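
To see which model files must already be in HDFS, it can help to list the paths that the training script in this example loads. The sketch below reproduces the naming pattern used by that script (for example /model/wdl/_dense_1000.model) and checks each path with hdfs dfs -test -e. The helper names snapshot_paths, hdfs_test_command, and hdfs_exists are hypothetical, and the check is skipped when the hdfs command is not on PATH.

```python
import shutil
import subprocess


def snapshot_paths(prefix: str, iteration: int, num_embeddings: int) -> dict:
    """List the files the training script loads, following its naming pattern."""
    ids = range(num_embeddings)
    return {
        "dense_weights": f"{prefix}_dense_{iteration}.model",
        "dense_optimizer_states": f"{prefix}_opt_dense_{iteration}.model",
        "sparse_weights": [f"{prefix}{i}_sparse_{iteration}.model" for i in ids],
        "sparse_optimizer_states": [f"{prefix}{i}_opt_sparse_{iteration}.model" for i in ids],
    }


def hdfs_test_command(path: str) -> list:
    """Build the argv for `hdfs dfs -test -e <path>` (exit code 0 if it exists)."""
    return ["hdfs", "dfs", "-test", "-e", path]


def hdfs_exists(path: str) -> bool:
    """Run the existence test; requires the hdfs CLI and a running cluster."""
    return subprocess.run(hdfs_test_command(path)).returncode == 0


if __name__ == "__main__":
    paths = snapshot_paths("/model/wdl/", 1000, 2)
    flat = [paths["dense_weights"], paths["dense_optimizer_states"]]
    flat += paths["sparse_weights"] + paths["sparse_optimizer_states"]
    if shutil.which("hdfs") is None:
        print("hdfs command not found; expected paths:", *flat, sep="\n  ")
    else:
        for p in flat:
            print(p, "OK" if hdfs_exists(p) else "MISSING")
```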

Now you can run the following sample.

%%writefile train_with_hdfs.py
import hugectr
from mpi4py import MPI
from hugectr.data import DataSource, DataSourceParams

data_source_params = DataSourceParams(
    use_hdfs = True, # whether to use HDFS to save model files
    namenode = 'localhost', # HDFS NameNode IP address or hostname
    port = 9000, # HDFS port
)

solver = hugectr.CreateSolver(max_eval_batches = 1280,
                              batchsize_eval = 1024,
                              batchsize = 1024,
                              lr = 0.001,
                              vvgpu = [[0]],
                              repeat_dataset = True,
                              data_source_params = data_source_params)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
                                  source = ['./wdl_norm/file_list.txt'],
                                  eval_source = './wdl_norm/file_list_test.txt',
                                  check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam,
                                    update_type = hugectr.Update_t.Global,
                                    beta1 = 0.9,
                                    beta2 = 0.999,
                                    epsilon = 0.0000001)
model = hugectr.Model(solver, reader, optimizer)
model.add(hugectr.Input(label_dim = 1, label_name = "label",
                        dense_dim = 13, dense_name = "dense",
                        data_reader_sparse_param_array = 
                        # the total number of slots should be equal to data_generator_params.num_slot
                        [hugectr.DataReaderSparseParam("wide_data", 2, True, 1),
                        hugectr.DataReaderSparseParam("deep_data", 1, True, 26)]))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 69,
                            embedding_vec_size = 1,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding2",
                            bottom_name = "wide_data",
                            optimizer = optimizer))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 1074,
                            embedding_vec_size = 16,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding1",
                            bottom_name = "deep_data",
                            optimizer = optimizer))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding1"],
                            top_names = ["reshape1"],
                            leading_dim=416))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding2"],
                            top_names = ["reshape2"],
                            leading_dim=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                            bottom_names = ["reshape1", "dense"],
                            top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["concat1"],
                            top_names = ["fc1"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc1"],
                            top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu1"],
                            top_names = ["dropout1"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["dropout1"],
                            top_names = ["fc2"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc2"],
                            top_names = ["relu2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu2"],
                            top_names = ["dropout2"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["dropout2"],
                            top_names = ["fc3"],
                            num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Add,
                            bottom_names = ["fc3", "reshape2"],
                            top_names = ["add1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
                            bottom_names = ["add1", "label"],
                            top_names = ["loss"]))
model.compile()
model.summary()

model.load_dense_weights('/model/wdl/_dense_1000.model')
model.load_dense_optimizer_states('/model/wdl/_opt_dense_1000.model')
model.load_sparse_weights(['/model/wdl/0_sparse_1000.model', '/model/wdl/1_sparse_1000.model'])
model.load_sparse_optimizer_states(['/model/wdl/0_opt_sparse_1000.model', '/model/wdl/1_opt_sparse_1000.model'])

model.fit(max_iter = 1020, display = 200, eval_interval = 500, snapshot = 1000, snapshot_prefix = "/model/wdl/")

Overwriting train_with_hdfs.py

!python train_with_hdfs.py

HugeCTR Version: 3.3
====================================================Model Init=====================================================
[HCTR][09:00:54][WARNING][RK0][main]: The model name is not specified when creating the solver.
[HCTR][09:00:54][INFO][RK0][main]: Global seed is 1285686508
[HCTR][09:00:55][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][09:00:56][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][09:00:56][INFO][RK0][main]: Start all2all warmup
[HCTR][09:00:56][INFO][RK0][main]: End all2all warmup
[HCTR][09:00:56][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][09:00:56][INFO][RK0][main]: Device 0: Tesla V100-PCIE-32GB
[HCTR][09:00:56][INFO][RK0][main]: num of DataReader workers: 12
[HCTR][09:00:56][INFO][RK0][main]: max_vocabulary_size_per_gpu_=6029312
[HCTR][09:00:56][INFO][RK0][main]: max_vocabulary_size_per_gpu_=5865472
[HCTR][09:00:56][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][09:01:00][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][09:01:00][INFO][RK0][main]: gpu0 init embedding done
[HCTR][09:01:00][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][09:01:00][INFO][RK0][main]: gpu0 init embedding done
[HCTR][09:01:00][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][09:01:00][INFO][RK0][main]: Warm-up done
[HCTR][09:01:00][INFO][RK0][main]: ===================================================Model Summary===================================================
label                                   Dense                         Sparse                        
label                                   dense                          wide_data,deep_data           
(None, 1)                               (None, 13)                              
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type                              Input Name                    Output Name                   Output Shape                  
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
LocalizedSlotSparseEmbeddingHash        wide_data                     sparse_embedding2             (None, 1, 1)                  
------------------------------------------------------------------------------------------------------------------
LocalizedSlotSparseEmbeddingHash        deep_data                     sparse_embedding1             (None, 26, 16)                
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding1             reshape1                      (None, 416)                   
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding2             reshape2                      (None, 1)                     
------------------------------------------------------------------------------------------------------------------
Concat                                  reshape1                      concat1                       (None, 429)                   
                                        dense                                                                                     
------------------------------------------------------------------------------------------------------------------
InnerProduct                            concat1                       fc1                           (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc1                           relu1                         (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
Dropout                                 relu1                         dropout1                      (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
InnerProduct                            dropout1                      fc2                           (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc2                           relu2                         (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
Dropout                                 relu2                         dropout2                      (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
InnerProduct                            dropout2                      fc3                           (None, 1)                     
------------------------------------------------------------------------------------------------------------------
Add                                     fc3                           add1                          (None, 1)                     
                                        reshape2                                                                                  
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss                  add1                          loss                                                        
                                        label                                                                                     
------------------------------------------------------------------------------------------------------------------
2022-02-23 09:01:00,548 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
[HDFS][INFO]: Read file /model/wdl/_dense_1000.model successfully!
[HDFS][INFO]: Read file /model/wdl/_opt_dense_1000.model successfully!
[HCTR][09:01:01][INFO][RK0][main]: Loading dense opt states: 
[HCTR][09:01:01][INFO][RK0][main]: Loading sparse model: /model/wdl/0_sparse_1000.model
[HDFS][INFO]: Read file /model/wdl/0_sparse_1000.model/key successfully!
[HDFS][INFO]: Read file /model/wdl/0_sparse_1000.model/slot_id successfully!
[HDFS][INFO]: Read file /model/wdl/0_sparse_1000.model/emb_vector successfully!
[HCTR][09:01:01][INFO][RK0][main]: Start to upload embedding table file to GPUs, total loop_num: 128
[HCTR][09:01:01][INFO][RK0][main]: Done
[HCTR][09:01:01][INFO][RK0][main]: Loading sparse model: /model/wdl/1_sparse_1000.model
[HDFS][INFO]: Read file /model/wdl/1_sparse_1000.model/key successfully!
[HDFS][INFO]: Read file /model/wdl/1_sparse_1000.model/slot_id successfully!
[HDFS][INFO]: Read file /model/wdl/1_sparse_1000.model/emb_vector successfully!
[HCTR][09:01:01][INFO][RK0][main]: Start to upload embedding table file to GPUs, total loop_num: 518
[HCTR][09:01:01][INFO][RK0][main]: Done
[HCTR][09:01:01][INFO][RK0][main]: Loading sparse optimizer states: /model/wdl/0_opt_sparse_1000.model
[HCTR][09:01:01][INFO][RK0][main]: Rank0: Read optimzer state from file
[HDFS][INFO]: Read file /model/wdl/0_opt_sparse_1000.model successfully!
[HCTR][09:01:01][INFO][RK0][main]: Done
[HCTR][09:01:01][INFO][RK0][main]: Rank0: Read optimzer state from file
[HDFS][INFO]: Read file /model/wdl/0_opt_sparse_1000.model successfully!
[HCTR][09:01:01][INFO][RK0][main]: Done
[HCTR][09:01:01][INFO][RK0][main]: Loading sparse optimizer states: /model/wdl/1_opt_sparse_1000.model
[HCTR][09:01:01][INFO][RK0][main]: Rank0: Read optimzer state from file
[HDFS][INFO]: Read file /model/wdl/1_opt_sparse_1000.model successfully!
[HCTR][09:01:02][INFO][RK0][main]: Done
[HCTR][09:01:02][INFO][RK0][main]: Rank0: Read optimzer state from file
[HDFS][INFO]: Read file /model/wdl/1_opt_sparse_1000.model successfully!
[HCTR][09:01:02][INFO][RK0][main]: Done
=====================================================Model Fit=====================================================
[HCTR][09:01:02][INFO][RK0][main]: Use non-epoch mode with number of iterations: 1020
[HCTR][09:01:02][INFO][RK0][main]: Training batchsize: 1024, evaluation batchsize: 1024
[HCTR][09:01:02][INFO][RK0][main]: Evaluation interval: 500, snapshot interval: 1000
[HCTR][09:01:02][INFO][RK0][main]: Dense network trainable: True
[HCTR][09:01:02][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][09:01:02][INFO][RK0][main]: Sparse embedding sparse_embedding2 trainable: True
[HCTR][09:01:02][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][09:01:02][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][09:01:02][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][09:01:02][INFO][RK0][main]: Training source file: ./wdl_norm/file_list.txt
[HCTR][09:01:02][INFO][RK0][main]: Evaluation source file: ./wdl_norm/file_list_test.txt
[HCTR][09:01:04][INFO][RK0][main]: Iter: 200 Time(200 iters): 1.12465s Loss: 0.632464 lr:0.001
[HCTR][09:01:05][INFO][RK0][main]: Iter: 400 Time(200 iters): 1.03567s Loss: 0.612515 lr:0.001
[HCTR][09:01:06][INFO][RK0][main]: Evaluation, AUC: 0.499877
[HCTR][09:01:06][INFO][RK0][main]: Eval Time for 1280 iters: 0.647875s
[HCTR][09:01:06][INFO][RK0][main]: Iter: 600 Time(200 iters): 1.68717s Loss: 0.625102 lr:0.001
[HCTR][09:01:07][INFO][RK0][main]: Iter: 800 Time(200 iters): 1.03752s Loss: 0.608092 lr:0.001
[HCTR][09:01:08][INFO][RK0][main]: Iter: 1000 Time(200 iters): 1.03691s Loss: 0.688194 lr:0.001
[HCTR][09:01:09][INFO][RK0][main]: Evaluation, AUC: 0.500383
[HCTR][09:01:09][INFO][RK0][main]: Eval Time for 1280 iters: 0.650671s
[HCTR][09:01:09][INFO][RK0][main]: Rank0: Dump hash table from GPU0
[HCTR][09:01:09][INFO][RK0][main]: Rank0: Write hash table <key,value> pairs to file
[HDFS][INFO]: Write to HDFS /model/wdl/0_sparse_1000.model/key successfully!
[HDFS][INFO]: Write to HDFS /model/wdl/0_sparse_1000.model/slot_id successfully!
[HDFS][INFO]: Write to HDFS /model/wdl/0_sparse_1000.model/emb_vector successfully!
[HCTR][09:01:09][INFO][RK0][main]: Done
[HCTR][09:01:09][INFO][RK0][main]: Rank0: Dump hash table from GPU0
[HCTR][09:01:09][INFO][RK0][main]: Rank0: Write hash table <key,value> pairs to file
[HDFS][INFO]: Write to HDFS /model/wdl/1_sparse_1000.model/key successfully!
[HDFS][INFO]: Write to HDFS /model/wdl/1_sparse_1000.model/slot_id successfully!
[HDFS][INFO]: Write to HDFS /model/wdl/1_sparse_1000.model/emb_vector successfully!
[HCTR][09:01:09][INFO][RK0][main]: Done
[HCTR][09:01:09][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][09:01:09][INFO][RK0][main]: Rank0: Write optimzer state to file
[HDFS][INFO]: Write to HDFS /model/wdl/0_opt_sparse_1000.model successfully!
[HCTR][09:01:09][INFO][RK0][main]: Done
[HCTR][09:01:09][INFO][RK0][main]: Rank0: Write optimzer state to file
[HDFS][INFO]: Write to HDFS /model/wdl/0_opt_sparse_1000.model successfully!
[HCTR][09:01:10][INFO][RK0][main]: Done
[HCTR][09:01:10][INFO][RK0][main]: Rank0: Write optimzer state to file
[HDFS][INFO]: Write to HDFS /model/wdl/1_opt_sparse_1000.model successfully!
[HCTR][09:01:11][INFO][RK0][main]: Done
[HCTR][09:01:11][INFO][RK0][main]: Rank0: Write optimzer state to file
[HDFS][INFO]: Write to HDFS /model/wdl/1_opt_sparse_1000.model successfully!
[HCTR][09:01:12][INFO][RK0][main]: Done
[HCTR][09:01:12][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HDFS][INFO]: Write to HDFS /model/wdl/_dense_1000.model successfully!
[HCTR][09:01:12][INFO][RK0][main]: Dumping dense weights to HDFS, successful
[HDFS][INFO]: Write to HDFS /model/wdl/_opt_dense_1000.model successfully!
[HCTR][09:01:12][INFO][RK0][main]: Dumping dense optimizer states to HDFS, successful
[HCTR][09:01:12][INFO][RK0][main]: Finish 1020 iterations with batchsize: 1024 in 9.82s.
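
To track the loss and AUC values without reading the whole log, the training output can be parsed with a couple of regular expressions. The following sketch is only an illustration based on the line formats visible in the output above; the names ITER_RE, AUC_RE, and parse_log are hypothetical, and the patterns may need adjusting for other HugeCTR versions.

```python
import re

# Patterns modeled on the "Iter: ..." and "Evaluation, AUC: ..." lines above.
ITER_RE = re.compile(r"Iter: (\d+) Time\((\d+) iters\): ([\d.]+)s Loss: ([\d.]+)")
AUC_RE = re.compile(r"Evaluation, AUC: ([\d.]+)")


def parse_log(lines):
    """Return ([(iteration, loss), ...], [auc, ...]) from HugeCTR log lines."""
    losses, aucs = [], []
    for line in lines:
        match = ITER_RE.search(line)
        if match:
            losses.append((int(match.group(1)), float(match.group(4))))
            continue
        match = AUC_RE.search(line)
        if match:
            aucs.append(float(match.group(1)))
    return losses, aucs


# Three lines copied from the output above.
sample = [
    "[HCTR][09:01:04][INFO][RK0][main]: Iter: 200 Time(200 iters): 1.12465s Loss: 0.632464 lr:0.001",
    "[HCTR][09:01:05][INFO][RK0][main]: Iter: 400 Time(200 iters): 1.03567s Loss: 0.612515 lr:0.001",
    "[HCTR][09:01:06][INFO][RK0][main]: Evaluation, AUC: 0.499877",
]

if __name__ == "__main__":
    losses, aucs = parse_log(sample)
    print("losses:", losses)
    print("AUCs:", aucs)
```

To use it on a real run, the output of python train_with_hdfs.py can be redirected to a file and its lines passed to parse_log.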