# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

Scaling Criteo: Training with HugeCTR

Overview

HugeCTR is an open-source framework to accelerate the training of CTR estimation models on NVIDIA GPUs. It is written in CUDA C++ and makes heavy use of GPU-accelerated libraries such as cuBLAS, cuDNN, and NCCL.

HugeCTR offers multiple advantages to train deep learning recommender systems:

  1. Speed: HugeCTR is a highly efficient framework written in C++. We experienced up to a 10x speedup. HugeCTR on an NVIDIA DGX A100 system proved to be the fastest commercially available solution for training the Deep Learning Recommendation Model (DLRM) architecture developed by Facebook.

  2. Scale: HugeCTR supports model parallel scaling. It distributes the large embedding tables over multiple GPUs or multiple nodes.

  3. Easy-to-use: HugeCTR provides a Python API similar to Keras. Examples for popular deep learning recommender system architectures (Wide&Deep, DLRM, DCN, DeepFM) are available.

HugeCTR is able to train recommender system models with larger-than-memory embedding tables by leveraging a parameter server.

You can find more information about HugeCTR in its GitHub repository: https://github.com/NVIDIA-Merlin/HugeCTR.

Learning objectives

In this notebook, we learn how to use HugeCTR to train recommender system models.

Training with HugeCTR

As HugeCTR optimizes the training in CUDA C++, we need to define the training pipeline and model architecture and execute it via the command line. We will use the Python API, which is similar to Keras models.

If you are not familiar with HugeCTR’s Python API and parameters, you can read more in its GitHub repository: https://github.com/NVIDIA-Merlin/HugeCTR.

We will write the code to a ./model.py file and execute it. The script creates model snapshots, which we will use for inference in the next notebook.

# List the preprocessed dataset output directory.
# NOTE(review): this path is hard-coded for the author's machine and differs from
# the INPUT_DATA_DIR default ('/tmp/model/data') used below — adjust it to your setup.
!ls /raid/data/criteo/test_dask/output/
test_dask  train  valid  workflow
import os
import shutil

# Start from a clean model directory: drop any output from a previous run, then
# recreate ./criteo_hugectr/1, where the training script saves criteo.json.
# Stdlib calls replace the former `os.system("rm -rf ...")` / `mkdir -p` shell-outs,
# which silently ignored failures and required a POSIX shell.
shutil.rmtree("./criteo_hugectr/", ignore_errors=True)  # same as `rm -rf`: missing dir is fine
os.makedirs("./criteo_hugectr/1", exist_ok=True)  # same as `mkdir -p`
0
# Root directory of the preprocessed dataset; override it with the
# INPUT_DATA_DIR environment variable.
INPUT_DATA_DIR = os.getenv("INPUT_DATA_DIR", "/tmp/model/data")

# File list of the training Parquet partitions read by the HugeCTR data reader.
data_path = os.path.join(INPUT_DATA_DIR, "train", "_file_list.txt")

We use graph_to_json to convert the model to a JSON configuration, which is required for inference.

# Evaluate on the held-out validation split rather than on the training file list;
# otherwise the AUC reported during training measures fit to the training data.
# Fall back to the training list (the previous behavior) if no validation split exists.
eval_data_path = os.path.join(INPUT_DATA_DIR, "valid", "_file_list.txt")
if not os.path.exists(eval_data_path):
    print(
        f"WARNING: {eval_data_path} not found; "
        "evaluating on the training file list instead."
    )
    eval_data_path = data_path

# Generate the standalone HugeCTR training script ./model.py. The file-list paths
# are interpolated with !r so they become correctly quoted/escaped Python literals
# even if they contain quotes or backslashes.
file_to_write = f"""
import hugectr
from mpi4py import MPI  # noqa

# HugeCTR
solver = hugectr.CreateSolver(
    vvgpu=[[0]],
    max_eval_batches=100,
    batchsize_eval=2720,
    batchsize=2720,
    i64_input_key=True,
    use_mixed_precision=False,
    repeat_dataset=True,
)
optimizer = hugectr.CreateOptimizer(optimizer_type=hugectr.Optimizer_t.SGD)
reader = hugectr.DataReaderParams(
    data_reader_type=hugectr.DataReaderType_t.Parquet,
    source=[{data_path!r}],
    eval_source={eval_data_path!r},
    check_type=hugectr.Check_t.Non,
    slot_size_array=[
        10000000,
        10000000,
        3014529,
        400781,
        11,
        2209,
        11869,
        148,
        4,
        977,
        15,
        38713,
        10000000,
        10000000,
        10000000,
        584616,
        12883,
        109,
        37,
        17177,
        7425,
        20266,
        4,
        7085,
        1535,
        64,
    ],
)
model = hugectr.Model(solver, reader, optimizer)
model.add(
    hugectr.Input(
        label_dim=1,
        label_name="label",
        dense_dim=13,
        dense_name="dense",
        data_reader_sparse_param_array=[hugectr.DataReaderSparseParam("data1", 1, False, 26)],
    )
)
model.add(
    hugectr.SparseEmbedding(
        embedding_type=hugectr.Embedding_t.LocalizedSlotSparseEmbeddingHash,
        workspace_size_per_gpu_in_mb=6000,
        embedding_vec_size=128,
        combiner="sum",
        sparse_embedding_name="sparse_embedding1",
        bottom_name="data1",
        optimizer=optimizer,
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["dense"],
        top_names=["fc1"],
        num_output=512,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc1"], top_names=["relu1"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu1"],
        top_names=["fc2"],
        num_output=256,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc2"], top_names=["relu2"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu2"],
        top_names=["fc3"],
        num_output=128,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc3"], top_names=["relu3"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.Interaction,
        bottom_names=["relu3", "sparse_embedding1"],
        top_names=["interaction1"],
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["interaction1"],
        top_names=["fc4"],
        num_output=1024,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc4"], top_names=["relu4"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu4"],
        top_names=["fc5"],
        num_output=1024,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc5"], top_names=["relu5"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu5"],
        top_names=["fc6"],
        num_output=512,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc6"], top_names=["relu6"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu6"],
        top_names=["fc7"],
        num_output=256,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc7"], top_names=["relu7"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu7"],
        top_names=["fc8"],
        num_output=1,
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.BinaryCrossEntropyLoss,
        bottom_names=["fc8", "label"],
        top_names=["loss"],
    )
)

MAX_ITER = 10000
EVAL_INTERVAL = 3200
model.compile()
model.summary()
model.fit(max_iter=MAX_ITER, eval_interval=EVAL_INTERVAL, display=1000, snapshot=3200)
model.graph_to_json(graph_config_file="./criteo_hugectr/1/criteo.json")
"""
with open('./model.py', 'w', encoding='utf-8') as fi:
    fi.write(file_to_write)
import time

start = time.time()
!python model.py
end = time.time() - start
print(f"run_time: {end}")
HugeCTR Version: 3.7
====================================================Model Init=====================================================
[HCTR][02:44:38.212][WARNING][RK0][main]: The model name is not specified when creating the solver.
[HCTR][02:44:38.212][WARNING][RK0][main]: MPI was already initialized somewhere elese. Lifetime service disabled.
[HCTR][02:44:38.212][INFO][RK0][main]: Global seed is 3391378239
[HCTR][02:44:38.255][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][02:44:40.126][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][02:44:40.127][INFO][RK0][main]: Start all2all warmup
[HCTR][02:44:40.127][INFO][RK0][main]: End all2all warmup
[HCTR][02:44:40.127][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][02:44:40.127][INFO][RK0][main]: Device 0: Quadro RTX 8000
[HCTR][02:44:40.128][INFO][RK0][main]: num of DataReader workers: 1
[HCTR][02:44:40.129][INFO][RK0][main]: Vocabulary size: 54120457
[HCTR][02:44:40.130][INFO][RK0][main]: max_vocabulary_size_per_gpu_=12288000
[HCTR][02:44:40.130][DEBUG][RK0][tid #139916176520960]: file_name_ /tmp/pytest-of-root/pytest-9/test_criteo_hugectr0/tests/crit_test/train/part_0.parquet file_total_rows_ 138449698
[HCTR][02:44:40.130][DEBUG][RK0][tid #139916168128256]: file_name_ /tmp/pytest-of-root/pytest-9/test_criteo_hugectr0/tests/crit_test/train/part_0.parquet file_total_rows_ 138449698
[HCTR][02:44:40.138][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][02:44:55.150][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][02:44:55.230][INFO][RK0][main]: gpu0 init embedding done
[HCTR][02:44:55.234][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][02:44:55.235][INFO][RK0][main]: Warm-up done
===================================================Model Summary===================================================
[HCTR][02:44:55.235][INFO][RK0][main]: label                                   Dense                         Sparse                        
label                                   dense                          data1                         
(None, 1)                               (None, 13)                              
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type                              Input Name                    Output Name                   Output Shape                  
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
LocalizedSlotSparseEmbeddingHash        data1                         sparse_embedding1             (None, 26, 128)               
------------------------------------------------------------------------------------------------------------------
InnerProduct                            dense                         fc1                           (None, 512)                   
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc1                           relu1                         (None, 512)                   
------------------------------------------------------------------------------------------------------------------
InnerProduct                            relu1                         fc2                           (None, 256)                   
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc2                           relu2                         (None, 256)                   
------------------------------------------------------------------------------------------------------------------
InnerProduct                            relu2                         fc3                           (None, 128)                   
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc3                           relu3                         (None, 128)                   
------------------------------------------------------------------------------------------------------------------
Interaction                             relu3                         interaction1                  (None, 480)                   
                                        sparse_embedding1                                                                         
------------------------------------------------------------------------------------------------------------------
InnerProduct                            interaction1                  fc4                           (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc4                           relu4                         (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
InnerProduct                            relu4                         fc5                           (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc5                           relu5                         (None, 1024)                  
------------------------------------------------------------------------------------------------------------------
InnerProduct                            relu5                         fc6                           (None, 512)                   
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc6                           relu6                         (None, 512)                   
------------------------------------------------------------------------------------------------------------------
InnerProduct                            relu6                         fc7                           (None, 256)                   
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc7                           relu7                         (None, 256)                   
------------------------------------------------------------------------------------------------------------------
InnerProduct                            relu7                         fc8                           (None, 1)                     
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss                  fc8                           loss                                                        
                                        label                                                                                     
------------------------------------------------------------------------------------------------------------------
=====================================================Model Fit=====================================================
[HCTR][02:44:55.235][INFO][RK0][main]: Use non-epoch mode with number of iterations: 10000
[HCTR][02:44:55.235][INFO][RK0][main]: Training batchsize: 2720, evaluation batchsize: 2720
[HCTR][02:44:55.235][INFO][RK0][main]: Evaluation interval: 3200, snapshot interval: 3200
[HCTR][02:44:55.235][INFO][RK0][main]: Dense network trainable: True
[HCTR][02:44:55.235][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][02:44:55.235][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][02:44:55.235][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][02:44:55.235][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][02:44:55.235][INFO][RK0][main]: Training source file: /tmp/pytest-of-root/pytest-9/test_criteo_hugectr0/tests/crit_test/train/_file_list.txt
[HCTR][02:44:55.235][INFO][RK0][main]: Evaluation source file: /tmp/pytest-of-root/pytest-9/test_criteo_hugectr0/tests/crit_test/train/_file_list.txt
[HCTR][02:45:01.551][INFO][RK0][main]: Iter: 1000 Time(1000 iters): 6.31026s Loss: 0.170242 lr:0.001
[HCTR][02:45:08.116][INFO][RK0][main]: Iter: 2000 Time(1000 iters): 6.5595s Loss: 0.142086 lr:0.001
[HCTR][02:45:14.999][INFO][RK0][main]: Iter: 3000 Time(1000 iters): 6.87726s Loss: 0.144497 lr:0.001
[HCTR][02:45:16.619][INFO][RK0][main]: Evaluation, AUC: 0.522062
[HCTR][02:45:16.619][INFO][RK0][main]: Eval Time for 100 iters: 0.218802s
[HCTR][02:45:17.186][INFO][RK0][main]: Rank0: Dump hash table from GPU0
[HCTR][02:45:17.362][INFO][RK0][main]: Rank0: Write hash table <key,value> pairs to file
[HCTR][02:45:18.490][INFO][RK0][main]: Done
[HCTR][02:45:18.802][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][02:45:18.802][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HCTR][02:45:18.812][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][02:45:18.812][INFO][RK0][main]: Dumping dense optimizer states to file, successful
[HCTR][02:45:24.512][INFO][RK0][main]: Iter: 4000 Time(1000 iters): 9.50778s Loss: 0.142673 lr:0.001
[HCTR][02:45:31.873][INFO][RK0][main]: Iter: 5000 Time(1000 iters): 7.35528s Loss: 0.13817 lr:0.001
[HCTR][02:45:39.491][INFO][RK0][main]: Iter: 6000 Time(1000 iters): 7.61235s Loss: 0.145115 lr:0.001
[HCTR][02:45:42.840][INFO][RK0][main]: Evaluation, AUC: 0.57392
[HCTR][02:45:42.840][INFO][RK0][main]: Eval Time for 100 iters: 0.249069s
[HCTR][02:45:43.756][INFO][RK0][main]: Rank0: Dump hash table from GPU0
[HCTR][02:45:44.043][INFO][RK0][main]: Rank0: Write hash table <key,value> pairs to file
[HCTR][02:45:45.935][INFO][RK0][main]: Done
[HCTR][02:45:46.480][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][02:45:46.480][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HCTR][02:45:46.486][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][02:45:46.486][INFO][RK0][main]: Dumping dense optimizer states to file, successful
[HCTR][02:45:51.203][INFO][RK0][main]: Iter: 7000 Time(1000 iters): 11.7059s Loss: 0.138048 lr:0.001
[HCTR][02:45:59.222][INFO][RK0][main]: Iter: 8000 Time(1000 iters): 8.01361s Loss: 0.149459 lr:0.001
[HCTR][02:46:07.359][INFO][RK0][main]: Iter: 9000 Time(1000 iters): 8.1318s Loss: 0.152849 lr:0.001
[HCTR][02:46:12.572][INFO][RK0][main]: Evaluation, AUC: 0.624589
[HCTR][02:46:12.572][INFO][RK0][main]: Eval Time for 100 iters: 0.223472s
[HCTR][02:46:13.798][INFO][RK0][main]: Rank0: Dump hash table from GPU0
[HCTR][02:46:14.172][INFO][RK0][main]: Rank0: Write hash table <key,value> pairs to file
[HCTR][02:46:16.936][INFO][RK0][main]: Done
[HCTR][02:46:17.654][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][02:46:17.655][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HCTR][02:46:17.661][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][02:46:17.661][INFO][RK0][main]: Dumping dense optimizer states to file, successful
[HCTR][02:46:21.006][INFO][RK0][main]: Finish 10000 iterations with batchsize: 2720 in 85.77s.
[HCTR][02:46:21.006][INFO][RK0][main]: Save the model graph to ./criteo_hugectr/1/criteo.json successfully
run_time: 104.00127220153809

We trained the model and created snapshots.