[1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Scaling Criteo: Training with HugeCTR
Overview
HugeCTR is an open-source framework that accelerates the training of click-through rate (CTR) estimation models on NVIDIA GPUs. It is written in CUDA C++ and makes heavy use of GPU-accelerated libraries such as cuBLAS, cuDNN, and NCCL. HugeCTR offers several advantages for training deep learning recommender systems:
1. Speed: HugeCTR is a highly efficient framework written in C++. We have observed speedups of up to 10x. HugeCTR on an NVIDIA DGX A100 system proved to be the fastest commercially available solution for training the Deep Learning Recommendation Model (DLRM) architecture developed by Facebook.
2. Scale: HugeCTR supports model-parallel scaling. It distributes large embedding tables over multiple GPUs or multiple nodes.
3. Easy to use: HugeCTR provides a Python API similar to Keras. Examples for popular deep learning recommender system architectures (Wide&Deep, DLRM, DCN, DeepFM) are available.
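The model.py cell later in this notebook uses `LocalizedSlotSparseEmbeddingHash`, in which each categorical slot's embedding table is kept on a single GPU. The sketch below is a hypothetical illustration, not HugeCTR code: it assumes a simple round-robin placement (slot i goes to GPU i % num_gpus) and reports how many embedding rows each GPU would hold for this notebook's `slot_size_array`.

```python
# Hypothetical sketch: per-GPU embedding rows under an ASSUMED round-robin
# slot placement (slot i on GPU i % num_gpus). Not HugeCTR's internal code.

# Cardinalities of the 26 Criteo categorical slots, copied from model.py
SLOT_SIZES = [
    10000000, 10000000, 3014529, 400781, 11, 2209, 11869, 148, 4, 977,
    15, 38713, 10000000, 10000000, 10000000, 584616, 12883, 109, 37, 17177,
    7425, 20266, 4, 7085, 1535, 64,
]


def place_slots(slot_sizes, num_gpus):
    """Return, for each GPU, the list of slot ids assigned to it."""
    placement = [[] for _ in range(num_gpus)]
    for slot_id in range(len(slot_sizes)):
        placement[slot_id % num_gpus].append(slot_id)
    return placement


def rows_per_gpu(slot_sizes, num_gpus):
    """Total embedding rows each GPU would hold under that placement."""
    return [
        sum(slot_sizes[slot_id] for slot_id in slots)
        for slots in place_slots(slot_sizes, num_gpus)
    ]


if __name__ == "__main__":
    for num_gpus in (1, 2, 4):
        print(num_gpus, rows_per_gpu(SLOT_SIZES, num_gpus))
```

With two GPUs this assumed placement puts 33,048,312 rows on GPU 0 and 21,072,145 on GPU 1, because three of the five 10M-row slots (0, 12, and 14) have even indices. Placing slots by count alone does not balance tables whose cardinalities are this skewed.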
HugeCTR can also train recommender system models whose embedding tables are larger than GPU memory by leveraging a parameter server.
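The idea behind a parameter server for oversized embedding tables can be pictured as a cache: the full table lives in a larger, slower store, and only the rows needed by recent batches are kept in fast device memory. The following is a conceptual sketch only. `ParameterServer` and `DeviceCache` are hypothetical classes, the LRU eviction policy is an assumption, and none of this reflects HugeCTR's actual API or implementation.

```python
from collections import OrderedDict


class ParameterServer:
    """Hypothetical host-side store that holds the full embedding table."""

    def __init__(self, dim):
        self.dim = dim
        self.rows = {}

    def pull(self, key):
        # Unseen keys start as zero vectors (an illustrative choice); return a copy.
        return list(self.rows.get(key, [0.0] * self.dim))

    def push(self, key, vec):
        self.rows[key] = list(vec)


class DeviceCache:
    """Hypothetical fixed-capacity LRU cache standing in for GPU memory."""

    def __init__(self, ps, capacity):
        self.ps = ps
        self.capacity = capacity
        self.rows = OrderedDict()
        self.misses = 0

    def lookup(self, key):
        if key in self.rows:
            self.rows.move_to_end(key)  # mark as most recently used
            return self.rows[key]
        self.misses += 1
        if len(self.rows) >= self.capacity:
            # Evict the least recently used row and write it back,
            # so updates made while it was cached are not lost.
            old_key, old_vec = self.rows.popitem(last=False)
            self.ps.push(old_key, old_vec)
        vec = self.ps.pull(key)
        self.rows[key] = vec
        return vec

    def update(self, key, grad, lr):
        # Plain SGD step applied to the cached copy of the row
        vec = self.lookup(key)
        for i, g in enumerate(grad):
            vec[i] -= lr * g


if __name__ == "__main__":
    cache = DeviceCache(ParameterServer(dim=4), capacity=2)
    for key in [1, 2, 1, 3, 2]:
        cache.lookup(key)
    print("cache misses:", cache.misses)
```

For the access sequence 1, 2, 1, 3, 2 with room for two rows, only the second lookup of key 1 hits, so the sketch records four misses. The cost of such a design depends on how often batches touch rows that are not already cached.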
More information about HugeCTR is available in its GitHub repository.
Learning objectives
In this notebook, we learn how to use HugeCTR to train recommender system models:
- Use HugeCTR to define a recommender system model
- Train Facebook’s Deep Learning Recommendation Model with HugeCTR
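DLRM's distinctive component is its feature interaction: the output of a bottom MLP over the dense features is combined with the categorical embedding vectors through pairwise dot products. The sketch below illustrates that idea in plain Python, following the DLRM formulation of concatenating the dense vector with every pairwise dot product. HugeCTR's `Interaction` layer is implemented in CUDA and may order or pad its output differently, so treat this as an illustration rather than a reference implementation.

```python
from itertools import combinations


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def dlrm_interaction(dense_out, embeddings):
    """Concatenate the bottom-MLP output with all pairwise dot products.

    dense_out:  bottom-MLP output, same width as each embedding vector
    embeddings: one embedding vector per categorical slot
    """
    vectors = [list(dense_out)] + [list(e) for e in embeddings]
    if any(len(v) != len(dense_out) for v in vectors):
        raise ValueError("bottom-MLP output and embeddings must share one width")
    # Each unordered pair (i < j) once; no vector is dotted with itself.
    pairwise = [dot(a, b) for a, b in combinations(vectors, 2)]
    return list(dense_out) + pairwise


if __name__ == "__main__":
    print(dlrm_interaction([1.0, 2.0], [[3.0, 4.0]]))
```

This is why the last bottom-MLP layer in model.py has `num_output=128`, the same as `embedding_vec_size`. For this notebook's model, 27 vectors of width 128 (the bottom-MLP output plus 26 embeddings) give 128 + 27·26/2 = 128 + 351 = 479 values per sample in this formulation.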
Getting Started
Because HugeCTR runs the training in CUDA C++, we define the training pipeline and model architecture in a script and execute it from the command line. We use HugeCTR's Python API, which is similar to defining Keras models.
If you are not familiar with HugeCTR’s Python API and parameters, you can read more in its GitHub repository:
- HugeCTR User Guide
- HugeCTR Python API
- HugeCTR Configuration File
- HugeCTR example architectures
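In the model.py cell below, layers are added one at a time with `model.add(...)` and connected by name: each layer reads the tensors listed in `bottom_names` and produces those in `top_names`. The helper below is hypothetical plain Python, not part of HugeCTR. It walks the same layer list and checks that every layer reads only tensors produced by an input or an earlier layer, which is a cheap way to catch a misspelled tensor name before launching a run.

```python
# Hypothetical helper (not part of HugeCTR): validates name-based layer wiring.

# Tensors provided by the Input layer in model.py
INPUTS = {"label", "dense", "data1"}

# (layer type, bottom_names, top_names), transcribed from model.py
LAYERS = [
    ("SparseEmbedding", ["data1"], ["sparse_embedding1"]),
    ("InnerProduct", ["dense"], ["fc1"]),
    ("ReLU", ["fc1"], ["relu1"]),
    ("InnerProduct", ["relu1"], ["fc2"]),
    ("ReLU", ["fc2"], ["relu2"]),
    ("InnerProduct", ["relu2"], ["fc3"]),
    ("ReLU", ["fc3"], ["relu3"]),
    ("Interaction", ["relu3", "sparse_embedding1"], ["interaction1"]),
    ("InnerProduct", ["interaction1"], ["fc4"]),
    ("ReLU", ["fc4"], ["relu4"]),
    ("InnerProduct", ["relu4"], ["fc5"]),
    ("ReLU", ["fc5"], ["relu5"]),
    ("InnerProduct", ["relu5"], ["fc6"]),
    ("ReLU", ["fc6"], ["relu6"]),
    ("InnerProduct", ["relu6"], ["fc7"]),
    ("ReLU", ["fc7"], ["relu7"]),
    ("InnerProduct", ["relu7"], ["fc8"]),
    ("BinaryCrossEntropyLoss", ["fc8", "label"], ["loss"]),
]


def check_graph(layers, inputs):
    """Return every tensor name produced; raise on undefined or duplicated names."""
    available = set(inputs)
    for layer_type, bottoms, tops in layers:
        missing = [name for name in bottoms if name not in available]
        if missing:
            raise ValueError(f"{layer_type} reads undefined tensor(s) {missing}")
        duplicated = [name for name in tops if name in available]
        if duplicated:
            raise ValueError(f"{layer_type} redefines tensor(s) {duplicated}")
        available.update(tops)
    return available


if __name__ == "__main__":
    tensors = check_graph(LAYERS, INPUTS)
    print("loss" in tensors)
```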
We write the code to a ./model.py file and execute it. The training run creates snapshots, which we will use for inference in the next notebook.
[3]:
%%writefile './model.py'
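import hugectr
from mpi4py import MPI  # noqa

# Solver: GPU layout, iteration counts, batch sizes, and evaluation/snapshot intervals
solver = hugectr.solver_parser_helper(
    vvgpu=[[0]],  # a single node using GPU 0
    max_iter=10000,
    max_eval_batches=100,
    batchsize_eval=2720,
    batchsize=2720,
    display=1000,  # print the training loss every 1000 iterations
    eval_interval=3200,
    snapshot=3200,
    i64_input_key=True,  # categorical keys are 64-bit integers
    use_mixed_precision=False,
    repeat_dataset=True,  # train for max_iter iterations instead of by epochs
)
optimizer = hugectr.optimizer.CreateOptimizer(
    optimizer_type=hugectr.Optimizer_t.SGD, use_mixed_precision=False
)
model = hugectr.Model(solver, optimizer)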
# Input: 1 label, 13 dense features, and 26 categorical slots read from Parquet files
model.add(
    hugectr.Input(
        data_reader_type=hugectr.DataReaderType_t.Parquet,
        source="/raid/data/criteo/test_dask/output/train/_file_list.txt",
        eval_source="/raid/data/criteo/test_dask/output/valid/_file_list.txt",
        check_type=hugectr.Check_t.Non,
        label_dim=1,
        label_name="label",
        dense_dim=13,
        dense_name="dense",
        # cardinality of each of the 26 categorical slots
        slot_size_array=[
            10000000,
            10000000,
            3014529,
            400781,
            11,
            2209,
            11869,
            148,
            4,
            977,
            15,
            38713,
            10000000,
            10000000,
            10000000,
            584616,
            12883,
            109,
            37,
            17177,
            7425,
            20266,
            4,
            7085,
            1535,
            64,
        ],
        data_reader_sparse_param_array=[
            hugectr.DataReaderSparseParam(hugectr.DataReaderSparse_t.Localized, 26, 1, 26)
        ],
        sparse_names=["data1"],
    )
)
# Embedding: hash-based, slot-localized table producing 128-dimensional vectors
model.add(
    hugectr.SparseEmbedding(
        embedding_type=hugectr.Embedding_t.LocalizedSlotSparseEmbeddingHash,
        max_vocabulary_size_per_gpu=15500000,
        embedding_vec_size=128,
        combiner=0,
        sparse_embedding_name="sparse_embedding1",
        bottom_name="data1",
    )
)
# Bottom MLP over the 13 dense features: 512 -> 256 -> 128
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["dense"],
        top_names=["fc1"],
        num_output=512,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc1"], top_names=["relu1"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu1"],
        top_names=["fc2"],
        num_output=256,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc2"], top_names=["relu2"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu2"],
        top_names=["fc3"],
        num_output=128,  # matches embedding_vec_size for the dot-product interaction
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc3"], top_names=["relu3"])
)
# DLRM feature interaction between the bottom-MLP output and the embedding vectors
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.Interaction,
        bottom_names=["relu3", "sparse_embedding1"],
        top_names=["interaction1"],
    )
)
# Top MLP: 1024 -> 1024 -> 512 -> 256 -> 1 output
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["interaction1"],
        top_names=["fc4"],
        num_output=1024,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc4"], top_names=["relu4"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu4"],
        top_names=["fc5"],
        num_output=1024,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc5"], top_names=["relu5"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu5"],
        top_names=["fc6"],
        num_output=512,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc6"], top_names=["relu6"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu6"],
        top_names=["fc7"],
        num_output=256,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc7"], top_names=["relu7"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu7"],
        top_names=["fc8"],
        num_output=1,
    )
)
# Binary cross-entropy loss against the click label
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.BinaryCrossEntropyLoss,
        bottom_names=["fc8", "label"],
        top_names=["loss"],
    )
)
model.compile()
model.summary()
model.fit()
Overwriting ./model.py
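A few numbers implied by the configuration above can be checked by hand before launching the run. This plain-Python sketch recomputes the total vocabulary (the sum of `slot_size_array`, which should match the `Vocabulary size` line in the training log), a lower bound on the memory for the embedding values the single GPU's hash table is sized for, and the iterations at which evaluation and snapshots happen. It assumes 4-byte float32 embedding values and ignores hash-table keys and any other overhead.

```python
# Figures copied from model.py; BYTES_PER_VALUE is an ASSUMPTION (float32 values).
SLOT_SIZES = [
    10000000, 10000000, 3014529, 400781, 11, 2209, 11869, 148, 4, 977,
    15, 38713, 10000000, 10000000, 10000000, 584616, 12883, 109, 37, 17177,
    7425, 20266, 4, 7085, 1535, 64,
]
MAX_VOCAB_PER_GPU = 15500000
EMBEDDING_VEC_SIZE = 128
BYTES_PER_VALUE = 4  # assumption: float32, since use_mixed_precision=False
MAX_ITER = 10000
EVAL_INTERVAL = 3200  # snapshot interval is the same in this notebook
BATCHSIZE = 2720

# Sum of per-slot cardinalities
total_vocab = sum(SLOT_SIZES)
# Embedding values only, at full per-GPU capacity
table_bytes = MAX_VOCAB_PER_GPU * EMBEDDING_VEC_SIZE * BYTES_PER_VALUE
# Iterations at which evaluation (and a snapshot) is triggered
eval_iters = list(range(EVAL_INTERVAL, MAX_ITER + 1, EVAL_INTERVAL))
# Training samples consumed over the whole run
samples_seen = MAX_ITER * BATCHSIZE

print(f"total vocabulary: {total_vocab:,}")
print(f"embedding value capacity on GPU 0: {table_bytes / 1e9:.2f} GB")
print(f"evaluations/snapshots at iterations: {eval_iters}")
print(f"training samples seen: {samples_seen:,}")
```

This gives a total vocabulary of 54,120,457, about 7.94 GB of embedding values at full capacity, three evaluations and snapshots (iterations 3200, 6400, and 9600), and 27,200,000 training samples over the run.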
[4]:
!python model.py
===================================Model Init====================================
[29d14h45m11s][HUGECTR][INFO]: Global seed is 2284029516
[29d14h45m13s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.
Device 0: Tesla V100-SXM2-32GB
[29d14h45m13s][HUGECTR][INFO]: num of DataReader workers: 1
[29d14h45m13s][HUGECTR][INFO]: num_internal_buffers 1
[29d14h45m13s][HUGECTR][INFO]: num_internal_buffers 1
[29d14h45m13s][HUGECTR][INFO]: Vocabulary size: 54120457
[29d14h45m13s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=15500000
[29d14h45m13s][HUGECTR][INFO]: All2All Warmup Start
[29d14h45m13s][HUGECTR][INFO]: All2All Warmup End
[29d14h45m27s][HUGECTR][INFO]: gpu0 start to init embedding
[29d14h45m27s][HUGECTR][INFO]: gpu0 init embedding done
==================================Model Summary==================================
Label Name Dense Name Sparse Name
label dense data1
--------------------------------------------------------------------------------
Layer Type Input Name Output Name
--------------------------------------------------------------------------------
LocalizedHash data1 sparse_embedding1
InnerProduct dense fc1
ReLU fc1 relu1
InnerProduct relu1 fc2
ReLU fc2 relu2
InnerProduct relu2 fc3
ReLU fc3 relu3
Interaction relu3, sparse_embedding1 interaction1
InnerProduct interaction1 fc4
ReLU fc4 relu4
InnerProduct relu4 fc5
ReLU fc5 relu5
InnerProduct relu5 fc6
ReLU fc6 relu6
InnerProduct relu6 fc7
ReLU fc7 relu7
InnerProduct relu7 fc8
BinaryCrossEntropyLoss fc8, label loss
--------------------------------------------------------------------------------
=====================================Model Fit====================================
[29d14h45m27s][HUGECTR][INFO]: Use non-epoch mode with number of iterations: 10000
[29d14h45m27s][HUGECTR][INFO]: Training batchsize: 2720, evaluation batchsize: 2720
[29d14h45m27s][HUGECTR][INFO]: Evaluation interval: 3200, snapshot interval: 3200
[29d14h45m33s][HUGECTR][INFO]: Iter: 1000 Time(1000 iters): 5.641792s Loss: 0.164940 lr:0.001000
[29d14h45m39s][HUGECTR][INFO]: Iter: 2000 Time(1000 iters): 5.656664s Loss: 0.135722 lr:0.001000
[29d14h45m44s][HUGECTR][INFO]: Iter: 3000 Time(1000 iters): 5.651847s Loss: 0.160054 lr:0.001000
[29d14h45m46s][HUGECTR][INFO]: Evaluation, AUC: 0.552071
[29d14h45m46s][HUGECTR][INFO]: Eval Time for 100 iters: 0.285196s
[29d14h45m47s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0
[29d14h45m48s][HUGECTR][INFO]: Rank0: Write hash table <key,slot_id,value> pairs to file
[29d14h45m49s][HUGECTR][INFO]: Done
[29d14h45m56s][HUGECTR][INFO]: Iter: 4000 Time(1000 iters): 11.744939s Loss: 0.143607 lr:0.001000
[29d14h46m20s][HUGECTR][INFO]: Iter: 5000 Time(1000 iters): 5.678187s Loss: 0.129964 lr:0.001000
[29d14h46m70s][HUGECTR][INFO]: Iter: 6000 Time(1000 iters): 5.678660s Loss: 0.134518 lr:0.001000
[29d14h46m10s][HUGECTR][INFO]: Evaluation, AUC: 0.612846
[29d14h46m10s][HUGECTR][INFO]: Eval Time for 100 iters: 0.218399s
[29d14h46m12s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0
[29d14h46m14s][HUGECTR][INFO]: Rank0: Write hash table <key,slot_id,value> pairs to file
[29d14h46m15s][HUGECTR][INFO]: Done
[29d14h46m23s][HUGECTR][INFO]: Iter: 7000 Time(1000 iters): 15.466943s Loss: 0.131755 lr:0.001000
[29d14h46m29s][HUGECTR][INFO]: Iter: 8000 Time(1000 iters): 5.657035s Loss: 0.136876 lr:0.001000
[29d14h46m34s][HUGECTR][INFO]: Iter: 9000 Time(1000 iters): 5.659428s Loss: 0.141471 lr:0.001000
[29d14h46m38s][HUGECTR][INFO]: Evaluation, AUC: 0.635054
[29d14h46m38s][HUGECTR][INFO]: Eval Time for 100 iters: 0.192339s
[29d14h46m40s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0
[29d14h46m43s][HUGECTR][INFO]: Rank0: Write hash table <key,slot_id,value> pairs to file
[29d14h46m45s][HUGECTR][INFO]: Done
We trained the model and created snapshots.
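To compare runs, it can help to pull the loss, AUC, and throughput out of the training log. Below is a minimal regex-based sketch, assuming the line format shown in the output above; that format may differ between HugeCTR versions, and `parse_log` is a hypothetical helper rather than part of HugeCTR.

```python
import re

# Patterns for the log lines shown above (format assumed stable)
ITER_RE = re.compile(
    r"Iter: (\d+) Time\((\d+) iters\): ([\d.]+)s Loss: ([\d.]+) lr:([\d.]+)"
)
AUC_RE = re.compile(r"Evaluation, AUC: ([\d.]+)")

# A few lines copied from the training output above
SAMPLE_LOG = """\
[29d14h45m33s][HUGECTR][INFO]: Iter: 1000 Time(1000 iters): 5.641792s Loss: 0.164940 lr:0.001000
[29d14h45m46s][HUGECTR][INFO]: Evaluation, AUC: 0.552071
[29d14h45m56s][HUGECTR][INFO]: Iter: 4000 Time(1000 iters): 11.744939s Loss: 0.143607 lr:0.001000
"""


def parse_log(text, batchsize):
    """Return (training steps, AUC values) parsed from HugeCTR log text."""
    steps, aucs = [], []
    for line in text.splitlines():
        match = ITER_RE.search(line)
        if match:
            it, n_iters, secs, loss, lr = match.groups()
            steps.append({
                "iter": int(it),
                "loss": float(loss),
                "lr": float(lr),
                # samples processed per second over the reported window
                "samples_per_s": int(n_iters) * batchsize / float(secs),
            })
            continue
        match = AUC_RE.search(line)
        if match:
            aucs.append(float(match.group(1)))
    return steps, aucs


if __name__ == "__main__":
    steps, aucs = parse_log(SAMPLE_LOG, batchsize=2720)
    for step in steps:
        print(step["iter"], step["loss"], round(step["samples_per_s"]))
    print("AUC:", aucs)
```

For the first window, 1000 iterations of batch size 2720 in about 5.64 s work out to roughly 482,000 samples per second. Windows that also include an evaluation and a snapshot, such as the one ending at iteration 4000 (11.74 s), come out noticeably slower.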