HugeCTR training with HDFS example
Overview
HugeCTR v3.4 introduced support for HDFS. You can move your data and model files from HDFS to the local filesystem through the HugeCTR API and use them for training. After training, you can dump the trained parameters and optimizer states to HDFS. This example notebook demonstrates the end-to-end procedure of training with HDFS.
Get HugeCTR from NGC
The HugeCTR Python module is preinstalled in the 22.07 and later Merlin Training Container: nvcr.io/nvidia/merlin/merlin-hugectr:22.07.
After launching the container, you can check that the required libraries are available by running the following command.
$ python3 -c "import hugectr"
If you prefer to build HugeCTR from the source code instead of using the NGC container, refer to the How to Start Your Development documentation.
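The import check above can also be done from inside a notebook cell without actually importing the module. The following is a minimal sketch; module_available is a hypothetical helper name, and the list of modules checked (hugectr and mpi4py, both imported by the training script later in this notebook) is an assumption about what you want to verify.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if a top-level module called `name` can be located."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    # Modules imported by the training script in this notebook.
    for name in ("hugectr", "mpi4py"):
        status = "found" if module_available(name) else "MISSING"
        print(f"{name}: {status}")
```

A MISSING result inside the Merlin container usually means the notebook kernel is not using the container's Python environment.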
Hadoop Installation and Configuration
Download and Install Hadoop
Download a JDK:
wget https://download.java.net/java/GA/jdk16.0.2/d4a915d82b4c4fbb9bde534da945d746/7/GPL/openjdk-16.0.2_linux-x64_bin.tar.gz
tar -zxvf openjdk-16.0.2_linux-x64_bin.tar.gz
mv jdk-16.0.2 /usr/local
Set Java environment variables:
export JAVA_HOME=/usr/local/jdk-16.0.2
export JRE_HOME=${JAVA_HOME}/jre
export CLASSPATH=.:${JAVA_HOME}/lib:${JRE_HOME}/lib
export PATH=.:${JAVA_HOME}/bin:$PATH
Download and install Hadoop:
wget https://dlcdn.apache.org/hadoop/common/hadoop-3.3.1/hadoop-3.3.1.tar.gz
tar -zxvf hadoop-3.3.1.tar.gz
mv hadoop-3.3.1 /usr/local/hadoop
Hadoop configuration
Set Hadoop environment variables:
export HDFS_NAMENODE_USER=root
export HDFS_DATANODE_USER=root
export HDFS_SECONDARYNAMENODE_USER=root
export YARN_RESOURCEMANAGER_USER=root
export YARN_NODEMANAGER_USER=root
echo 'export JAVA_HOME=/usr/local/jdk-16.0.2' >> /usr/local/hadoop/etc/hadoop/hadoop-env.sh
core-site.xml config:
<property>
<name>fs.defaultFS</name>
<value>hdfs://namenode:9000</value>
</property>
<property>
<name>hadoop.tmp.dir</name>
<value>/usr/local/hadoop/tmp</value>
</property>
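Hadoop site files wrap their property entries in a single configuration root element. If you prefer to generate these files rather than edit them by hand, the sketch below builds that structure with the Python standard library. build_site_xml is a hypothetical helper name, and the output path shown in the comment assumes the /usr/local/hadoop install location used in this example.

```python
import xml.etree.ElementTree as ET


def build_site_xml(properties: dict) -> str:
    """Render name/value pairs as a Hadoop *-site.xml document."""
    root = ET.Element("configuration")
    for name, value in properties.items():
        prop = ET.SubElement(root, "property")
        ET.SubElement(prop, "name").text = name
        ET.SubElement(prop, "value").text = value
    return ET.tostring(root, encoding="unicode")


# The two core-site.xml properties listed above.
core_site = build_site_xml({
    "fs.defaultFS": "hdfs://namenode:9000",
    "hadoop.tmp.dir": "/usr/local/hadoop/tmp",
})
print(core_site)
# To install it (assumed path):
# open("/usr/local/hadoop/etc/hadoop/core-site.xml", "w").write(core_site)
```

The same helper can render the hdfs-site.xml properties by passing a different dictionary.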
hdfs-site.xml for the NameNode:
<property>
<name>dfs.replication</name>
<value>4</value>
</property>
<property>
<name>dfs.namenode.name.dir</name>
<value>file:/usr/local/hadoop/hdfs/name</value>
</property>
<property>
<name>dfs.client.block.write.replace-datanode-on-failure.enable</name>
<value>true</value>
</property>
<property>
<name>dfs.client.block.write.replace-datanode-on-failure.policy</name>
<value>NEVER</value>
</property>
hdfs-site.xml for the DataNodes:
<property>
<name>dfs.replication</name>
<value>4</value>
</property>
<property>
<name>dfs.datanode.data.dir</name>
<value>file:/usr/local/hadoop/hdfs/data</value>
</property>
<property>
<name>dfs.client.block.write.replace-datanode-on-failure.enable</name>
<value>true</value>
</property>
<property>
<name>dfs.client.block.write.replace-datanode-on-failure.policy</name>
<value>NEVER</value>
</property>
workers for all nodes:
worker1
worker2
worker3
worker4
Start HDFS
Enable SSH connections:
ssh-keygen -t rsa -P '' -f ~/.ssh/id_rsa
cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys
/etc/init.d/ssh start
Format the NameNode:
/usr/local/hadoop/bin/hdfs namenode -format
Format the DataNodes:
/usr/local/hadoop/bin/hdfs datanode -format
Start HDFS from the NameNode:
/usr/local/hadoop/sbin/start-dfs.sh
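Before moving on, it can help to confirm that the cluster actually came up. The sketch below shells out to hdfs dfsadmin -report, which prints cluster capacity and the list of live DataNodes. hdfs_is_up and dfsadmin_report_cmd are hypothetical helper names, and the binary path assumes the /usr/local/hadoop install location used in this example.

```python
import subprocess
from pathlib import Path

# Install location assumed from the steps above.
HDFS_BIN = Path("/usr/local/hadoop/bin/hdfs")


def dfsadmin_report_cmd(hdfs_bin=HDFS_BIN):
    """Build the command that asks the NameNode for a cluster report."""
    return [str(hdfs_bin), "dfsadmin", "-report"]


def hdfs_is_up(hdfs_bin=HDFS_BIN, timeout=60) -> bool:
    """Return True if `hdfs dfsadmin -report` exits successfully."""
    if not Path(hdfs_bin).exists():
        return False
    try:
        result = subprocess.run(
            dfsadmin_report_cmd(hdfs_bin),
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0


if __name__ == "__main__":
    print("HDFS reachable:", hdfs_is_up())
```

A False result only says that the report command failed or the binary was not found; check the NameNode and DataNode logs for the cause.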
Wide and Deep Model
In the Docker container, nvcr.io/nvidia/merlin/merlin-hugectr:22.07, make sure that you installed Hadoop and set the proper environment variables as instructed in the preceding sections.
If you chose to compile HugeCTR, make sure that you set the -DENABLE_HDFS build option to ON.
Run the following command first to link the required JAR files:
export CLASSPATH=$(hadoop classpath --glob)
Make sure that the model files are present in your Hadoop cluster and that you provide the correct paths to them.
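If the model files are not in HDFS yet, they can be copied there with hdfs dfs -mkdir and hdfs dfs -put. The sketch below only builds and prints the commands so that you can review them before running anything. hdfs_upload_commands is a hypothetical helper, and the local directory ./wdl_local is an assumed location for previously trained model files; the HDFS directory /model/wdl is the one the training script reads from.

```python
import shlex


def hdfs_upload_commands(local_paths, hdfs_dir,
                         hdfs_bin="/usr/local/hadoop/bin/hdfs"):
    """Return the hdfs CLI commands that copy local_paths into hdfs_dir."""
    commands = [[hdfs_bin, "dfs", "-mkdir", "-p", hdfs_dir]]
    for path in local_paths:
        # -f overwrites an existing destination; -put also copies directories,
        # which matters because each sparse model is a directory.
        commands.append([hdfs_bin, "dfs", "-put", "-f", path, hdfs_dir])
    return commands


# Assumed local copies of the files that the training script loads.
local_paths = [
    "./wdl_local/_dense_1000.model",
    "./wdl_local/_opt_dense_1000.model",
    "./wdl_local/0_sparse_1000.model",
    "./wdl_local/1_sparse_1000.model",
    "./wdl_local/0_opt_sparse_1000.model",
    "./wdl_local/1_opt_sparse_1000.model",
]

for cmd in hdfs_upload_commands(local_paths, "/model/wdl"):
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

To execute the commands instead of printing them, pass each list to subprocess.run with check=True.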
Now you can run the following sample.
%%writefile train_with_hdfs.py
import hugectr
from mpi4py import MPI
from hugectr.data import DataSource, DataSourceParams
data_source_params = DataSourceParams(
use_hdfs = True, # whether to use HDFS to load and save model files
namenode = 'localhost', # HDFS NameNode hostname or IP address
port = 9000, #HDFS port
)
solver = hugectr.CreateSolver(max_eval_batches = 1280,
batchsize_eval = 1024,
batchsize = 1024,
lr = 0.001,
vvgpu = [[0]],
repeat_dataset = True,
data_source_params = data_source_params)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,
source = ['./wdl_norm/file_list.txt'],
eval_source = './wdl_norm/file_list_test.txt',
check_type = hugectr.Check_t.Sum)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam,
update_type = hugectr.Update_t.Global,
beta1 = 0.9,
beta2 = 0.999,
epsilon = 0.0000001)
model = hugectr.Model(solver, reader, optimizer)
model.add(hugectr.Input(label_dim = 1, label_name = "label",
dense_dim = 13, dense_name = "dense",
data_reader_sparse_param_array =
# the total number of slots should be equal to data_generator_params.num_slot
[hugectr.DataReaderSparseParam("wide_data", 2, True, 1),
hugectr.DataReaderSparseParam("deep_data", 1, True, 26)]))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
workspace_size_per_gpu_in_mb = 69,
embedding_vec_size = 1,
combiner = "sum",
sparse_embedding_name = "sparse_embedding2",
bottom_name = "wide_data",
optimizer = optimizer))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
workspace_size_per_gpu_in_mb = 1074,
embedding_vec_size = 16,
combiner = "sum",
sparse_embedding_name = "sparse_embedding1",
bottom_name = "deep_data",
optimizer = optimizer))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
bottom_names = ["sparse_embedding1"],
top_names = ["reshape1"],
leading_dim=416))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
bottom_names = ["sparse_embedding2"],
top_names = ["reshape2"],
leading_dim=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
bottom_names = ["reshape1", "dense"],
top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["concat1"],
top_names = ["fc1"],
num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
bottom_names = ["fc1"],
top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
bottom_names = ["relu1"],
top_names = ["dropout1"],
dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["dropout1"],
top_names = ["fc2"],
num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
bottom_names = ["fc2"],
top_names = ["relu2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
bottom_names = ["relu2"],
top_names = ["dropout2"],
dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["dropout2"],
top_names = ["fc3"],
num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Add,
bottom_names = ["fc3", "reshape2"],
top_names = ["add1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
bottom_names = ["add1", "label"],
top_names = ["loss"]))
model.compile()
model.summary()
model.load_dense_weights('/model/wdl/_dense_1000.model')
model.load_dense_optimizer_states('/model/wdl/_opt_dense_1000.model')
model.load_sparse_weights(['/model/wdl/0_sparse_1000.model', '/model/wdl/1_sparse_1000.model'])
model.load_sparse_optimizer_states(['/model/wdl/0_opt_sparse_1000.model', '/model/wdl/1_opt_sparse_1000.model'])
model.fit(max_iter = 1020, display = 200, eval_interval = 500, snapshot = 1000, snapshot_prefix = "/model/wdl/")
Overwriting train_with_hdfs.py
!python train_with_hdfs.py
HugeCTR Version: 3.3
====================================================Model Init=====================================================
[HCTR][09:00:54][WARNING][RK0][main]: The model name is not specified when creating the solver.
[HCTR][09:00:54][INFO][RK0][main]: Global seed is 1285686508
[HCTR][09:00:55][INFO][RK0][main]: Device to NUMA mapping:
GPU 0 -> node 0
[HCTR][09:00:56][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][09:00:56][INFO][RK0][main]: Start all2all warmup
[HCTR][09:00:56][INFO][RK0][main]: End all2all warmup
[HCTR][09:00:56][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][09:00:56][INFO][RK0][main]: Device 0: Tesla V100-PCIE-32GB
[HCTR][09:00:56][INFO][RK0][main]: num of DataReader workers: 12
[HCTR][09:00:56][INFO][RK0][main]: max_vocabulary_size_per_gpu_=6029312
[HCTR][09:00:56][INFO][RK0][main]: max_vocabulary_size_per_gpu_=5865472
[HCTR][09:00:56][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][09:01:00][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][09:01:00][INFO][RK0][main]: gpu0 init embedding done
[HCTR][09:01:00][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][09:01:00][INFO][RK0][main]: gpu0 init embedding done
[HCTR][09:01:00][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][09:01:00][INFO][RK0][main]: Warm-up done
[HCTR][09:01:00][INFO][RK0][main]: ===================================================Model Summary===================================================
label Dense Sparse
label dense wide_data,deep_data
(None, 1) (None, 13)
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type Input Name Output Name Output Shape
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
LocalizedSlotSparseEmbeddingHash wide_data sparse_embedding2 (None, 1, 1)
------------------------------------------------------------------------------------------------------------------
LocalizedSlotSparseEmbeddingHash deep_data sparse_embedding1 (None, 26, 16)
------------------------------------------------------------------------------------------------------------------
Reshape sparse_embedding1 reshape1 (None, 416)
------------------------------------------------------------------------------------------------------------------
Reshape sparse_embedding2 reshape2 (None, 1)
------------------------------------------------------------------------------------------------------------------
Concat reshape1 concat1 (None, 429)
dense
------------------------------------------------------------------------------------------------------------------
InnerProduct concat1 fc1 (None, 1024)
------------------------------------------------------------------------------------------------------------------
ReLU fc1 relu1 (None, 1024)
------------------------------------------------------------------------------------------------------------------
Dropout relu1 dropout1 (None, 1024)
------------------------------------------------------------------------------------------------------------------
InnerProduct dropout1 fc2 (None, 1024)
------------------------------------------------------------------------------------------------------------------
ReLU fc2 relu2 (None, 1024)
------------------------------------------------------------------------------------------------------------------
Dropout relu2 dropout2 (None, 1024)
------------------------------------------------------------------------------------------------------------------
InnerProduct dropout2 fc3 (None, 1)
------------------------------------------------------------------------------------------------------------------
Add fc3 add1 (None, 1)
reshape2
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss add1 loss
label
------------------------------------------------------------------------------------------------------------------
2022-02-23 09:01:00,548 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
[HDFS][INFO]: Read file /model/wdl/_dense_1000.model successfully!
[HDFS][INFO]: Read file /model/wdl/_opt_dense_1000.model successfully!
[HCTR][09:01:01][INFO][RK0][main]: Loading dense opt states:
[HCTR][09:01:01][INFO][RK0][main]: Loading sparse model: /model/wdl/0_sparse_1000.model
[HDFS][INFO]: Read file /model/wdl/0_sparse_1000.model/key successfully!
[HDFS][INFO]: Read file /model/wdl/0_sparse_1000.model/slot_id successfully!
[HDFS][INFO]: Read file /model/wdl/0_sparse_1000.model/emb_vector successfully!
[HCTR][09:01:01][INFO][RK0][main]: Start to upload embedding table file to GPUs, total loop_num: 128
[HCTR][09:01:01][INFO][RK0][main]: Done
[HCTR][09:01:01][INFO][RK0][main]: Loading sparse model: /model/wdl/1_sparse_1000.model
[HDFS][INFO]: Read file /model/wdl/1_sparse_1000.model/key successfully!
[HDFS][INFO]: Read file /model/wdl/1_sparse_1000.model/slot_id successfully!
[HDFS][INFO]: Read file /model/wdl/1_sparse_1000.model/emb_vector successfully!
[HCTR][09:01:01][INFO][RK0][main]: Start to upload embedding table file to GPUs, total loop_num: 518
[HCTR][09:01:01][INFO][RK0][main]: Done
[HCTR][09:01:01][INFO][RK0][main]: Loading sparse optimizer states: /model/wdl/0_opt_sparse_1000.model
[HCTR][09:01:01][INFO][RK0][main]: Rank0: Read optimzer state from file
[HDFS][INFO]: Read file /model/wdl/0_opt_sparse_1000.model successfully!
[HCTR][09:01:01][INFO][RK0][main]: Done
[HCTR][09:01:01][INFO][RK0][main]: Rank0: Read optimzer state from file
[HDFS][INFO]: Read file /model/wdl/0_opt_sparse_1000.model successfully!
[HCTR][09:01:01][INFO][RK0][main]: Done
[HCTR][09:01:01][INFO][RK0][main]: Loading sparse optimizer states: /model/wdl/1_opt_sparse_1000.model
[HCTR][09:01:01][INFO][RK0][main]: Rank0: Read optimzer state from file
[HDFS][INFO]: Read file /model/wdl/1_opt_sparse_1000.model successfully!
[HCTR][09:01:02][INFO][RK0][main]: Done
[HCTR][09:01:02][INFO][RK0][main]: Rank0: Read optimzer state from file
[HDFS][INFO]: Read file /model/wdl/1_opt_sparse_1000.model successfully!
[HCTR][09:01:02][INFO][RK0][main]: Done
=====================================================Model Fit=====================================================
[HCTR][09:01:02][INFO][RK0][main]: Use non-epoch mode with number of iterations: 1020
[HCTR][09:01:02][INFO][RK0][main]: Training batchsize: 1024, evaluation batchsize: 1024
[HCTR][09:01:02][INFO][RK0][main]: Evaluation interval: 500, snapshot interval: 1000
[HCTR][09:01:02][INFO][RK0][main]: Dense network trainable: True
[HCTR][09:01:02][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][09:01:02][INFO][RK0][main]: Sparse embedding sparse_embedding2 trainable: True
[HCTR][09:01:02][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][09:01:02][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][09:01:02][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][09:01:02][INFO][RK0][main]: Training source file: ./wdl_norm/file_list.txt
[HCTR][09:01:02][INFO][RK0][main]: Evaluation source file: ./wdl_norm/file_list_test.txt
[HCTR][09:01:04][INFO][RK0][main]: Iter: 200 Time(200 iters): 1.12465s Loss: 0.632464 lr:0.001
[HCTR][09:01:05][INFO][RK0][main]: Iter: 400 Time(200 iters): 1.03567s Loss: 0.612515 lr:0.001
[HCTR][09:01:06][INFO][RK0][main]: Evaluation, AUC: 0.499877
[HCTR][09:01:06][INFO][RK0][main]: Eval Time for 1280 iters: 0.647875s
[HCTR][09:01:06][INFO][RK0][main]: Iter: 600 Time(200 iters): 1.68717s Loss: 0.625102 lr:0.001
[HCTR][09:01:07][INFO][RK0][main]: Iter: 800 Time(200 iters): 1.03752s Loss: 0.608092 lr:0.001
[HCTR][09:01:08][INFO][RK0][main]: Iter: 1000 Time(200 iters): 1.03691s Loss: 0.688194 lr:0.001
[HCTR][09:01:09][INFO][RK0][main]: Evaluation, AUC: 0.500383
[HCTR][09:01:09][INFO][RK0][main]: Eval Time for 1280 iters: 0.650671s
[HCTR][09:01:09][INFO][RK0][main]: Rank0: Dump hash table from GPU0
[HCTR][09:01:09][INFO][RK0][main]: Rank0: Write hash table <key,value> pairs to file
[HDFS][INFO]: Write to HDFS /model/wdl/0_sparse_1000.model/key successfully!
[HDFS][INFO]: Write to HDFS /model/wdl/0_sparse_1000.model/slot_id successfully!
[HDFS][INFO]: Write to HDFS /model/wdl/0_sparse_1000.model/emb_vector successfully!
[HCTR][09:01:09][INFO][RK0][main]: Done
[HCTR][09:01:09][INFO][RK0][main]: Rank0: Dump hash table from GPU0
[HCTR][09:01:09][INFO][RK0][main]: Rank0: Write hash table <key,value> pairs to file
[HDFS][INFO]: Write to HDFS /model/wdl/1_sparse_1000.model/key successfully!
[HDFS][INFO]: Write to HDFS /model/wdl/1_sparse_1000.model/slot_id successfully!
[HDFS][INFO]: Write to HDFS /model/wdl/1_sparse_1000.model/emb_vector successfully!
[HCTR][09:01:09][INFO][RK0][main]: Done
[HCTR][09:01:09][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][09:01:09][INFO][RK0][main]: Rank0: Write optimzer state to file
[HDFS][INFO]: Write to HDFS /model/wdl/0_opt_sparse_1000.model successfully!
[HCTR][09:01:09][INFO][RK0][main]: Done
[HCTR][09:01:09][INFO][RK0][main]: Rank0: Write optimzer state to file
[HDFS][INFO]: Write to HDFS /model/wdl/0_opt_sparse_1000.model successfully!
[HCTR][09:01:10][INFO][RK0][main]: Done
[HCTR][09:01:10][INFO][RK0][main]: Rank0: Write optimzer state to file
[HDFS][INFO]: Write to HDFS /model/wdl/1_opt_sparse_1000.model successfully!
[HCTR][09:01:11][INFO][RK0][main]: Done
[HCTR][09:01:11][INFO][RK0][main]: Rank0: Write optimzer state to file
[HDFS][INFO]: Write to HDFS /model/wdl/1_opt_sparse_1000.model successfully!
[HCTR][09:01:12][INFO][RK0][main]: Done
[HCTR][09:01:12][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HDFS][INFO]: Write to HDFS /model/wdl/_dense_1000.model successfully!
[HCTR][09:01:12][INFO][RK0][main]: Dumping dense weights to HDFS, successful
[HDFS][INFO]: Write to HDFS /model/wdl/_opt_dense_1000.model successfully!
[HCTR][09:01:12][INFO][RK0][main]: Dumping dense optimizer states to HDFS, successful
[HCTR][09:01:12][INFO][RK0][main]: Finish 1020 iterations with batchsize: 1024 in 9.82s.
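The paths in the script and in the log above follow a regular pattern built from snapshot_prefix and the iteration at which the snapshot was taken. The helper below reproduces that pattern so that the load calls for a later run can be derived from the values passed to fit. It is a sketch inferred from the file names in this example only; snapshot_paths is a hypothetical name and not part of the HugeCTR API.

```python
def snapshot_paths(prefix: str, iteration: int, num_embeddings: int) -> dict:
    """List the snapshot files for one iteration, following the naming
    pattern seen in this example's log output."""
    indices = range(num_embeddings)
    return {
        "dense": f"{prefix}_dense_{iteration}.model",
        "dense_opt": f"{prefix}_opt_dense_{iteration}.model",
        "sparse": [f"{prefix}{i}_sparse_{iteration}.model" for i in indices],
        "sparse_opt": [f"{prefix}{i}_opt_sparse_{iteration}.model" for i in indices],
    }


# Values used in this notebook: prefix "/model/wdl/", snapshot at iteration
# 1000, and two sparse embeddings (wide and deep).
paths = snapshot_paths("/model/wdl/", 1000, 2)
for kind, value in paths.items():
    print(kind, value)
```

With these values, paths["dense"] can be passed to model.load_dense_weights and paths["sparse"] to model.load_sparse_weights, matching the calls in train_with_hdfs.py.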