[1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

Scaling Criteo: Training with HugeCTR

Overview

HugeCTR is an open-source framework that accelerates the training of click-through rate (CTR) estimation models on NVIDIA GPUs. It is written in CUDA C++ and makes heavy use of GPU-accelerated libraries such as cuBLAS, cuDNN, and NCCL. HugeCTR offers several advantages for training deep learning recommender systems:

  1. Speed: HugeCTR is a highly efficient framework written in C++. We observed speedups of up to 10x. HugeCTR on an NVIDIA DGX A100 system proved to be the fastest commercially available solution for training the Deep Learning Recommendation Model (DLRM) architecture developed by Facebook.

  2. Scale: HugeCTR supports model parallel scaling. It distributes the large embedding tables over multiple GPUs or multiple nodes.

  3. Easy to use: HugeCTR provides a Python API similar to Keras. Examples for popular deep learning recommender system architectures (Wide&Deep, DLRM, DCN, DeepFM) are available.

HugeCTR is able to train recommender system models with larger-than-memory embedding tables by leveraging a parameter server.
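To see why embedding tables can outgrow GPU memory, here is a rough sketch that estimates the size of the Criteo embedding tables from the `slot_size_array` used in the training script later in this notebook. It assumes one FP32 vector of `embedding_vec_size` values per category and ignores hash-table overhead and optimizer state, so it is a lower bound; `embedding_table_bytes` is a hypothetical helper, not part of HugeCTR.

```python
# Cardinalities of the 26 categorical Criteo features, copied from the
# slot_size_array in this notebook's training script.
SLOT_SIZES = [
    10000000, 10000000, 3014529, 400781, 11, 2209, 11869, 148, 4, 977,
    15, 38713, 10000000, 10000000, 10000000, 584616, 12883, 109, 37,
    17177, 7425, 20266, 4, 7085, 1535, 64,
]


def embedding_table_bytes(slot_sizes, embedding_vec_size=128, bytes_per_value=4):
    """Lower-bound size: one FP32 vector per category, no hash or optimizer overhead."""
    return sum(slot_sizes) * embedding_vec_size * bytes_per_value


total_rows = sum(SLOT_SIZES)
total_bytes = embedding_table_bytes(SLOT_SIZES)
print(f"{total_rows:,} rows -> {total_bytes / 2**30:.1f} GiB")  # 54,120,457 rows -> 25.8 GiB
```

Even this lower bound fills most of a single 32 GB GPU such as the Tesla V100 used below, which is the situation that model-parallel embeddings and the parameter server are designed for.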

You can find more information about HugeCTR in its GitHub repository.

Learning objectives

In this notebook, we learn how to use HugeCTR to train recommender system models.

Training with HugeCTR

Because HugeCTR runs the optimized training in CUDA C++, we define the training pipeline and model architecture in a script and execute it from the command line. We use HugeCTR's Python API, which is similar to the Keras API.

If you are not familiar with HugeCTR’s Python API and parameters, you can read more in its GitHub repository:

  - HugeCTR User Guide
  - HugeCTR Python API
  - HugeCTR example architectures

We write the code to ./model.py and execute it. Training creates snapshots of the model, which we use for inference in the next notebook.
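The script calls `model.fit(max_iter=10000, eval_interval=3200, display=1000, snapshot=3200)` with a batch size of 2720 and `max_eval_batches=100`. Assuming evaluation and snapshots fire at every multiple of their interval, the schedule works out as below; the three evaluations and three snapshot dumps in the training log are consistent with this. The helper `multiples` is only for illustration.

```python
MAX_ITER = 10000
EVAL_INTERVAL = 3200
SNAPSHOT_INTERVAL = 3200
BATCH_SIZE = 2720
MAX_EVAL_BATCHES = 100


def multiples(interval, max_iter):
    """Iterations (1..max_iter) at which an event with this interval fires."""
    return list(range(interval, max_iter + 1, interval))


eval_iters = multiples(EVAL_INTERVAL, MAX_ITER)          # [3200, 6400, 9600]
snapshot_iters = multiples(SNAPSHOT_INTERVAL, MAX_ITER)  # [3200, 6400, 9600]
train_samples = MAX_ITER * BATCH_SIZE                    # 27,200,000 samples
eval_samples = MAX_EVAL_BATCHES * BATCH_SIZE             # 272,000 samples per evaluation
print(eval_iters, snapshot_iters, train_samples, eval_samples)
```

Because the solver sets `repeat_dataset=True`, training is driven by the iteration count rather than by epochs, so these numbers do not depend on the dataset size.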

[2]:
!ls /raid/data/criteo/test_dask/output/
train  valid  workflow
[3]:
import os

os.system("rm -rf ./criteo_hugectr/")
os.system("mkdir -p ./criteo_hugectr/1")
[3]:
0
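The cell above shells out to `rm` and `mkdir`; the `0` it prints is just the exit status of the last `os.system` call. A sketch of the same setup using only the standard library is shown here. Like the original cell, it deletes any existing output. The `1` subdirectory is presumably a model version folder for serving in the next notebook, and `reset_model_dir` is a hypothetical helper name.

```python
import os
import shutil

MODEL_DIR = "./criteo_hugectr"


def reset_model_dir(model_dir=MODEL_DIR, version="1"):
    """Delete model_dir if it exists, then recreate model_dir/<version>."""
    shutil.rmtree(model_dir, ignore_errors=True)  # equivalent of `rm -rf`
    version_dir = os.path.join(model_dir, version)
    os.makedirs(version_dir, exist_ok=True)       # equivalent of `mkdir -p`; raises on failure
    return version_dir


print(reset_model_dir())
```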

At the end of the script, we call graph_to_json to save the model graph as a JSON configuration file, which is required for inference.
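After the script has run, you can confirm that `graph_to_json` produced a well-formed file by loading it with Python's `json` module. This sketch only checks that the file parses as a JSON object and lists its top-level keys; it makes no assumption about HugeCTR's JSON schema, and `graph_config_keys` is a hypothetical helper.

```python
import json
from pathlib import Path

GRAPH_CONFIG = Path("./criteo_hugectr/1/criteo.json")


def graph_config_keys(path=GRAPH_CONFIG):
    """Parse the exported graph config and return its sorted top-level keys."""
    with open(path, encoding="utf-8") as f:
        config = json.load(f)  # raises json.JSONDecodeError if the file is malformed
    if not isinstance(config, dict):
        raise ValueError(f"expected a JSON object in {path}")
    return sorted(config)


if GRAPH_CONFIG.exists():
    print(graph_config_keys())
else:
    print(f"{GRAPH_CONFIG} not found; run model.py first")
```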

[4]:
%%writefile './model.py'

import hugectr
from mpi4py import MPI  # noqa

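# Solver: train on a single GPU (GPU 0) in FP32 with 64-bit categorical keys;
# repeat_dataset=True runs training for a fixed number of iterations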
solver = hugectr.CreateSolver(
    vvgpu=[[0]],
    max_eval_batches=100,
    batchsize_eval=2720,
    batchsize=2720,
    i64_input_key=True,
    use_mixed_precision=False,
    repeat_dataset=True,
)
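# SGD optimizer, also passed to the embedding layer below
optimizer = hugectr.CreateOptimizer(optimizer_type=hugectr.Optimizer_t.SGD)
# Data reader for the preprocessed Parquet files listed in _file_list.txt
reader = hugectr.DataReaderParams(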
    data_reader_type=hugectr.DataReaderType_t.Parquet,
    source=["/raid/data/criteo/test_dask/output/train/_file_list.txt"],
    eval_source="/raid/data/criteo/test_dask/output/valid/_file_list.txt",
    check_type=hugectr.Check_t.Non,
    # Cardinality of each of the 26 categorical features
    slot_size_array=[
        10000000,
        10000000,
        3014529,
        400781,
        11,
        2209,
        11869,
        148,
        4,
        977,
        15,
        38713,
        10000000,
        10000000,
        10000000,
        584616,
        12883,
        109,
        37,
        17177,
        7425,
        20266,
        4,
        7085,
        1535,
        64,
    ],
)
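# Assemble the model from the solver, data reader, and optimizer
model = hugectr.Model(solver, reader, optimizer)
# Input: 1 label, 13 dense features, and 26 categorical slots ("data1")
model.add(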
    hugectr.Input(
        label_dim=1,
        label_name="label",
        dense_dim=13,
        dense_name="dense",
        data_reader_sparse_param_array=[hugectr.DataReaderSparseParam("data1", 1, False, 26)],
    )
)
# Sparse embedding: a 128-dim vector for each of the 26 slots (keys within a slot are summed)
model.add(
    hugectr.SparseEmbedding(
        embedding_type=hugectr.Embedding_t.LocalizedSlotSparseEmbeddingHash,
        workspace_size_per_gpu_in_mb=6000,
        embedding_vec_size=128,
        combiner="sum",
        sparse_embedding_name="sparse_embedding1",
        bottom_name="data1",
        optimizer=optimizer,
    )
)
# Bottom MLP on the dense features: 512 -> 256 -> 128 (matches embedding_vec_size)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["dense"],
        top_names=["fc1"],
        num_output=512,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc1"], top_names=["relu1"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu1"],
        top_names=["fc2"],
        num_output=256,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc2"], top_names=["relu2"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu2"],
        top_names=["fc3"],
        num_output=128,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc3"], top_names=["relu3"])
)
# DLRM-style dot interaction between the bottom-MLP output and the 26 embeddings
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.Interaction,
        bottom_names=["relu3", "sparse_embedding1"],
        top_names=["interaction1"],
    )
)
# Top MLP: 1024 -> 1024 -> 512 -> 256 -> 1 output
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["interaction1"],
        top_names=["fc4"],
        num_output=1024,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc4"], top_names=["relu4"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu4"],
        top_names=["fc5"],
        num_output=1024,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc5"], top_names=["relu5"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu5"],
        top_names=["fc6"],
        num_output=512,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc6"], top_names=["relu6"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu6"],
        top_names=["fc7"],
        num_output=256,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc7"], top_names=["relu7"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu7"],
        top_names=["fc8"],
        num_output=1,
    )
)
# Binary cross-entropy loss between the fc8 output and the label
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.BinaryCrossEntropyLoss,
        bottom_names=["fc8", "label"],
        top_names=["loss"],
    )
)
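# Build the model, print a summary, and train for 10,000 iterations,
# evaluating and writing a snapshot every 3,200 iterations
model.compile()
model.summary()
model.fit(max_iter=10000, eval_interval=3200, display=1000, snapshot=3200)
# Export the model graph as JSON for inference
model.graph_to_json(graph_config_file="./criteo_hugectr/1/criteo.json")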
Overwriting ./model.py
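Two values in the training log below follow directly from this configuration. The log reports `max_vocabulary_size_per_gpu_=12288000`, which equals 6000 MiB divided by one 128-dim FP32 vector (512 bytes); this suggests the hash table is sized from `workspace_size_per_gpu_in_mb`, but that formula is inferred from the numbers, not taken from HugeCTR documentation. The `Interaction` layer combines 27 vectors (26 embeddings plus the 128-dim `relu3` output, presumably why `fc3` has 128 outputs to match `embedding_vec_size`), giving 128 + 27 × 26 / 2 = 479 values; the model summary reports 480, which suggests one element of padding. The helper names here are hypothetical.

```python
MIB = 1024 * 1024
BYTES_PER_FLOAT32 = 4


def vocab_capacity(workspace_mb, embedding_vec_size):
    """Rows of FP32 embedding vectors that fit in the per-GPU workspace (assumed formula)."""
    return workspace_mb * MIB // (embedding_vec_size * BYTES_PER_FLOAT32)


def interaction_width(num_sparse, vec_size):
    """DLRM-style dot interaction: bottom-MLP output plus all pairwise dot products."""
    num_vectors = num_sparse + 1  # embeddings + the bottom-MLP output vector
    return vec_size + num_vectors * (num_vectors - 1) // 2


print(vocab_capacity(6000, 128))   # 12288000, as in the log
print(interaction_width(26, 128))  # 479; the model summary reports 480
```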
[5]:
!python model.py
====================================================Model Init=====================================================
[26d18h47m40s][HUGECTR][INFO]: Global seed is 488900738
[26d18h47m43s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.
Device 0: Tesla V100-SXM2-32GB
[26d18h47m43s][HUGECTR][INFO]: num of DataReader workers: 1
[26d18h47m43s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=12288000
[26d18h47m43s][HUGECTR][INFO]: All2All Warmup Start
[26d18h47m43s][HUGECTR][INFO]: All2All Warmup End
===================================================Model Compile===================================================
[26d18h47m58s][HUGECTR][INFO]: gpu0 start to init embedding
[26d18h47m58s][HUGECTR][INFO]: gpu0 init embedding done
===================================================Model Summary===================================================
Label                                   Dense                         Sparse
label                                   dense                          data1
(None, 1)                               (None, 13)
------------------------------------------------------------------------------------------------------------------
Layer Type                              Input Name                    Output Name                   Output Shape
------------------------------------------------------------------------------------------------------------------
LocalizedSlotSparseEmbeddingHash        data1                         sparse_embedding1             (None, 26, 128)
InnerProduct                            dense                         fc1                           (None, 512)
ReLU                                    fc1                           relu1                         (None, 512)
InnerProduct                            relu1                         fc2                           (None, 256)
ReLU                                    fc2                           relu2                         (None, 256)
InnerProduct                            relu2                         fc3                           (None, 128)
ReLU                                    fc3                           relu3                         (None, 128)
Interaction                             relu3,sparse_embedding1       interaction1                  (None, 480)
InnerProduct                            interaction1                  fc4                           (None, 1024)
ReLU                                    fc4                           relu4                         (None, 1024)
InnerProduct                            relu4                         fc5                           (None, 1024)
ReLU                                    fc5                           relu5                         (None, 1024)
InnerProduct                            relu5                         fc6                           (None, 512)
ReLU                                    fc6                           relu6                         (None, 512)
InnerProduct                            relu6                         fc7                           (None, 256)
ReLU                                    fc7                           relu7                         (None, 256)
InnerProduct                            relu7                         fc8                           (None, 1)
BinaryCrossEntropyLoss                  fc8,label                     loss
------------------------------------------------------------------------------------------------------------------
=====================================================Model Fit=====================================================
[26d18h47m58s][HUGECTR][INFO]: Use non-epoch mode with number of iterations: 10000
[26d18h47m58s][HUGECTR][INFO]: Training batchsize: 2720, evaluation batchsize: 2720
[26d18h47m58s][HUGECTR][INFO]: Evaluation interval: 3200, snapshot interval: 3200
[26d18h47m58s][HUGECTR][INFO]: Sparse embedding trainable: 1, dense network trainable: 1
[26d18h47m58s][HUGECTR][INFO]: Use mixed precision: 0, scaler: 1.000000, use cuda graph: 1
[26d18h47m58s][HUGECTR][INFO]: lr: 0.001000, warmup_steps: 1, decay_start: 0, decay_steps: 1, decay_power: 2.000000, end_lr: 0.000000
[26d18h47m58s][HUGECTR][INFO]: Training source file: /raid/data/criteo2/test_dask/output/train/_file_list.txt
[26d18h47m58s][HUGECTR][INFO]: Evaluation source file: /raid/data/criteo2/test_dask/output/train/_file_list.txt
[26d18h48m04s][HUGECTR][INFO]: Iter: 1000 Time(1000 iters): 5.344550s Loss: 0.157622 lr:0.001000
[26d18h48m09s][HUGECTR][INFO]: Iter: 2000 Time(1000 iters): 5.351840s Loss: 0.139202 lr:0.001000
[26d18h48m14s][HUGECTR][INFO]: Iter: 3000 Time(1000 iters): 5.370395s Loss: 0.150230 lr:0.001000
[26d18h48m16s][HUGECTR][INFO]: Evaluation, AUC: 0.547015
[26d18h48m16s][HUGECTR][INFO]: Eval Time for 100 iters: 0.284275s
[26d18h48m17s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0
[26d18h48m17s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file
[26d18h48m20s][HUGECTR][INFO]: Done
[26d18h48m20s][HUGECTR][INFO]: Dumping sparse weights to files, successful
[26d18h48m20s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful
[26d18h48m20s][HUGECTR][INFO]: Dumping dense weights to file, successful
[26d18h48m20s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful
[26d18h48m20s][HUGECTR][INFO]: Dumping untrainable weights to file, successful
[26d18h48m25s][HUGECTR][INFO]: Iter: 4000 Time(1000 iters): 10.234373s Loss: 0.138970 lr:0.001000
[26d18h48m30s][HUGECTR][INFO]: Iter: 5000 Time(1000 iters): 5.334527s Loss: 0.123318 lr:0.001000
[26d18h48m35s][HUGECTR][INFO]: Iter: 6000 Time(1000 iters): 5.342809s Loss: 0.135370 lr:0.001000
[26d18h48m38s][HUGECTR][INFO]: Evaluation, AUC: 0.612456
[26d18h48m38s][HUGECTR][INFO]: Eval Time for 100 iters: 0.341947s
[26d18h48m39s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0
[26d18h48m40s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file
[26d18h48m44s][HUGECTR][INFO]: Done
[26d18h48m45s][HUGECTR][INFO]: Dumping sparse weights to files, successful
[26d18h48m45s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful
[26d18h48m45s][HUGECTR][INFO]: Dumping dense weights to file, successful
[26d18h48m45s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful
[26d18h48m45s][HUGECTR][INFO]: Dumping untrainable weights to file, successful
[26d18h48m48s][HUGECTR][INFO]: Iter: 7000 Time(1000 iters): 12.887706s Loss: 0.164043 lr:0.001000
[26d18h48m54s][HUGECTR][INFO]: Iter: 8000 Time(1000 iters): 5.328255s Loss: 0.137928 lr:0.001000
[26d18h48m59s][HUGECTR][INFO]: Iter: 9000 Time(1000 iters): 5.346786s Loss: 0.149743 lr:0.001000
[26d18h49m02s][HUGECTR][INFO]: Evaluation, AUC: 0.630659
[26d18h49m02s][HUGECTR][INFO]: Eval Time for 100 iters: 0.344536s
[26d18h49m04s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0
[26d18h49m05s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file
[26d18h49m11s][HUGECTR][INFO]: Done
[26d18h49m12s][HUGECTR][INFO]: Dumping sparse weights to files, successful
[26d18h49m12s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful
[26d18h49m12s][HUGECTR][INFO]: Dumping dense weights to file, successful
[26d18h49m12s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful
[26d18h49m12s][HUGECTR][INFO]: Dumping untrainable weights to file, successful
[26d18h49m14s][HUGECTR][INFO]: Save the model graph to ./criteo_hugectr/1/criteo.json, successful

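We trained the model for 10,000 iterations, saved a snapshot every 3,200 iterations, and exported the model graph to ./criteo_hugectr/1/criteo.json.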
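In the run above, the evaluation AUC rose from 0.547 to 0.631 over the three evaluations. If you save the console output (for example with `!python model.py | tee train.log`), you can extract the loss and AUC values with two regular expressions. This sketch assumes the line formats shown in the log above; they are console messages rather than a stable interface, so the patterns may need adjusting for other HugeCTR versions. `parse_training_log` is a hypothetical helper.

```python
import re

# Patterns for the "Iter: ... Loss: ..." and "Evaluation, AUC: ..." lines shown in the log
LOSS_PATTERN = re.compile(r"Iter: (\d+) .*? Loss: ([0-9.]+)")
AUC_PATTERN = re.compile(r"Evaluation, AUC: ([0-9.]+)")


def parse_training_log(text):
    """Return ([(iteration, loss), ...], [auc, ...]) from HugeCTR console output."""
    losses = [(int(it), float(loss)) for it, loss in LOSS_PATTERN.findall(text)]
    aucs = [float(auc) for auc in AUC_PATTERN.findall(text)]
    return losses, aucs


sample = """\
[26d18h48m14s][HUGECTR][INFO]: Iter: 3000 Time(1000 iters): 5.370395s Loss: 0.150230 lr:0.001000
[26d18h48m16s][HUGECTR][INFO]: Evaluation, AUC: 0.547015
"""
print(parse_training_log(sample))
```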