[1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

Scaling Criteo: Training with HugeCTR

Overview

HugeCTR is an open-source framework for accelerating the training of CTR estimation models on NVIDIA GPUs. It is written in CUDA C++ and heavily exploits GPU-accelerated libraries such as cuBLAS, cuDNN, and NCCL. HugeCTR offers multiple advantages for training deep learning recommender systems:

1. Speed: HugeCTR is a highly efficient framework written in C++; we have observed speed-ups of up to 10x. On an NVIDIA DGX A100 system, HugeCTR proved to be the fastest commercially available solution for training the Deep Learning Recommendation Model (DLRM) architecture developed by Facebook.
2. Scale: HugeCTR supports model-parallel scaling and distributes large embedding tables over multiple GPUs or multiple nodes.
3. Ease of use: HugeCTR provides an easy-to-use Python API similar to Keras, and examples for popular deep learning recommender system architectures (Wide&Deep, DLRM, DCN, DeepFM) are available.

HugeCTR is able to train recommender system models with larger-than-memory embedding tables by leveraging a parameter server.

You can find more information about HugeCTR in its GitHub repository.

Learning objectives

In this notebook, we learn how to use HugeCTR for training recommender system models:

- Use HugeCTR to define a recommender system model
- Train Facebook’s Deep Learning Recommendation Model (DLRM) with HugeCTR

Getting Started

As HugeCTR optimizes the training in CUDA C++, we need to define the training pipeline and model architecture in a script and execute it via the command line. We will use the Python API, which is similar to the Keras model API.

If you are not familiar with HugeCTR’s Python API and parameters, you can read more in its GitHub repository:

- HugeCTR User Guide
- HugeCTR Python API
- HugeCTR Configuration File
- HugeCTR example architectures

We will write the code to a ./model.py file and execute it. Training will create snapshots, which we will use for inference in the next notebook.

[3]:
%%writefile './model.py'

import hugectr
from mpi4py import MPI  # noqa

# Solver: global training configuration (GPUs, number of iterations, batch sizes,
# evaluation and snapshot intervals)
solver = hugectr.solver_parser_helper(
    vvgpu=[[0]],
    max_iter=10000,
    max_eval_batches=100,
    batchsize_eval=2720,
    batchsize=2720,
    display=1000,
    eval_interval=3200,
    snapshot=3200,
    i64_input_key=True,
    use_mixed_precision=False,
    repeat_dataset=True,
)
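# Optimizer: plain SGD in full (FP32) precision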
optimizer = hugectr.optimizer.CreateOptimizer(
    optimizer_type=hugectr.Optimizer_t.SGD, use_mixed_precision=False
)
model = hugectr.Model(solver, optimizer)
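# Input layer: Parquet data reader with one binary label, 13 dense (continuous)
# features, and 26 categorical slots whose cardinalities are listed in slot_size_array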
model.add(
    hugectr.Input(
        data_reader_type=hugectr.DataReaderType_t.Parquet,
        source="/raid/data/criteo/test_dask/output/train/_file_list.txt",
        eval_source="/raid/data/criteo/test_dask/output/valid/_file_list.txt",
        check_type=hugectr.Check_t.Non,
        label_dim=1,
        label_name="label",
        dense_dim=13,
        dense_name="dense",
        slot_size_array=[
            10000000,
            10000000,
            3014529,
            400781,
            11,
            2209,
            11869,
            148,
            4,
            977,
            15,
            38713,
            10000000,
            10000000,
            10000000,
            584616,
            12883,
            109,
            37,
            17177,
            7425,
            20266,
            4,
            7085,
            1535,
            64,
        ],
        data_reader_sparse_param_array=[
            hugectr.DataReaderSparseParam(hugectr.DataReaderSparse_t.Localized, 26, 1, 26)
        ],
        sparse_names=["data1"],
    )
)
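# Model-parallel embedding: hash-based, localized-slot embedding that maps the
# 26 categorical slots to 128-dimensional vectors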
model.add(
    hugectr.SparseEmbedding(
        embedding_type=hugectr.Embedding_t.LocalizedSlotSparseEmbeddingHash,
        max_vocabulary_size_per_gpu=15500000,
        embedding_vec_size=128,
        combiner=0,
        sparse_embedding_name="sparse_embedding1",
        bottom_name="data1",
    )
)
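# Bottom MLP on the dense features: 512 -> 256 -> 128 with ReLU activations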
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["dense"],
        top_names=["fc1"],
        num_output=512,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc1"], top_names=["relu1"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu1"],
        top_names=["fc2"],
        num_output=256,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc2"], top_names=["relu2"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu2"],
        top_names=["fc3"],
        num_output=128,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc3"], top_names=["relu3"])
)
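# Interaction layer (DLRM): combines the bottom MLP output with the embedding
# vectors through pairwise feature interactions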
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.Interaction,
        bottom_names=["relu3", "sparse_embedding1"],
        top_names=["interaction1"],
    )
)
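# Top MLP on the interaction output: 1024 -> 1024 -> 512 -> 256 -> 1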
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["interaction1"],
        top_names=["fc4"],
        num_output=1024,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc4"], top_names=["relu4"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu4"],
        top_names=["fc5"],
        num_output=1024,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc5"], top_names=["relu5"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu5"],
        top_names=["fc6"],
        num_output=512,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc6"], top_names=["relu6"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu6"],
        top_names=["fc7"],
        num_output=256,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc7"], top_names=["relu7"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu7"],
        top_names=["fc8"],
        num_output=1,
    )
)
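# Binary cross-entropy loss computed from the final logit (fc8) and the label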
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.BinaryCrossEntropyLoss,
        bottom_names=["fc8", "label"],
        top_names=["loss"],
    )
)
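# Build the computation graph, print the layer summary, and run training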
model.compile()
model.summary()
model.fit()
Overwriting ./model.py
[4]:
!python model.py
===================================Model Init====================================
[29d14h45m11s][HUGECTR][INFO]: Global seed is 2284029516
[29d14h45m13s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.
Device 0: Tesla V100-SXM2-32GB
[29d14h45m13s][HUGECTR][INFO]: num of DataReader workers: 1
[29d14h45m13s][HUGECTR][INFO]: num_internal_buffers 1
[29d14h45m13s][HUGECTR][INFO]: num_internal_buffers 1
[29d14h45m13s][HUGECTR][INFO]: Vocabulary size: 54120457
[29d14h45m13s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=15500000
[29d14h45m13s][HUGECTR][INFO]: All2All Warmup Start
[29d14h45m13s][HUGECTR][INFO]: All2All Warmup End
[29d14h45m27s][HUGECTR][INFO]: gpu0 start to init embedding
[29d14h45m27s][HUGECTR][INFO]: gpu0 init embedding done
==================================Model Summary==================================
Label Name                    Dense Name                    Sparse Name
label                         dense                         data1
--------------------------------------------------------------------------------
Layer Type                    Input Name                    Output Name
--------------------------------------------------------------------------------
LocalizedHash                 data1                         sparse_embedding1
InnerProduct                  dense                         fc1
ReLU                          fc1                           relu1
InnerProduct                  relu1                         fc2
ReLU                          fc2                           relu2
InnerProduct                  relu2                         fc3
ReLU                          fc3                           relu3
Interaction                   relu3, sparse_embedding1      interaction1
InnerProduct                  interaction1                  fc4
ReLU                          fc4                           relu4
InnerProduct                  relu4                         fc5
ReLU                          fc5                           relu5
InnerProduct                  relu5                         fc6
ReLU                          fc6                           relu6
InnerProduct                  relu6                         fc7
ReLU                          fc7                           relu7
InnerProduct                  relu7                         fc8
BinaryCrossEntropyLoss        fc8, label                    loss
--------------------------------------------------------------------------------
=====================================Model Fit====================================
[29d14h45m27s][HUGECTR][INFO]: Use non-epoch mode with number of iterations: 10000
[29d14h45m27s][HUGECTR][INFO]: Training batchsize: 2720, evaluation batchsize: 2720
[29d14h45m27s][HUGECTR][INFO]: Evaluation interval: 3200, snapshot interval: 3200
[29d14h45m33s][HUGECTR][INFO]: Iter: 1000 Time(1000 iters): 5.641792s Loss: 0.164940 lr:0.001000
[29d14h45m39s][HUGECTR][INFO]: Iter: 2000 Time(1000 iters): 5.656664s Loss: 0.135722 lr:0.001000
[29d14h45m44s][HUGECTR][INFO]: Iter: 3000 Time(1000 iters): 5.651847s Loss: 0.160054 lr:0.001000
[29d14h45m46s][HUGECTR][INFO]: Evaluation, AUC: 0.552071
[29d14h45m46s][HUGECTR][INFO]: Eval Time for 100 iters: 0.285196s
[29d14h45m47s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0
[29d14h45m48s][HUGECTR][INFO]: Rank0: Write hash table <key,slot_id,value> pairs to file
[29d14h45m49s][HUGECTR][INFO]: Done
[29d14h45m56s][HUGECTR][INFO]: Iter: 4000 Time(1000 iters): 11.744939s Loss: 0.143607 lr:0.001000
[29d14h46m20s][HUGECTR][INFO]: Iter: 5000 Time(1000 iters): 5.678187s Loss: 0.129964 lr:0.001000
[29d14h46m70s][HUGECTR][INFO]: Iter: 6000 Time(1000 iters): 5.678660s Loss: 0.134518 lr:0.001000
[29d14h46m10s][HUGECTR][INFO]: Evaluation, AUC: 0.612846
[29d14h46m10s][HUGECTR][INFO]: Eval Time for 100 iters: 0.218399s
[29d14h46m12s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0
[29d14h46m14s][HUGECTR][INFO]: Rank0: Write hash table <key,slot_id,value> pairs to file
[29d14h46m15s][HUGECTR][INFO]: Done
[29d14h46m23s][HUGECTR][INFO]: Iter: 7000 Time(1000 iters): 15.466943s Loss: 0.131755 lr:0.001000
[29d14h46m29s][HUGECTR][INFO]: Iter: 8000 Time(1000 iters): 5.657035s Loss: 0.136876 lr:0.001000
[29d14h46m34s][HUGECTR][INFO]: Iter: 9000 Time(1000 iters): 5.659428s Loss: 0.141471 lr:0.001000
[29d14h46m38s][HUGECTR][INFO]: Evaluation, AUC: 0.635054
[29d14h46m38s][HUGECTR][INFO]: Eval Time for 100 iters: 0.192339s
[29d14h46m40s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0
[29d14h46m43s][HUGECTR][INFO]: Rank0: Write hash table <key,slot_id,value> pairs to file
[29d14h46m45s][HUGECTR][INFO]: Done

We trained the model and created snapshots.
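To double-check what was written, we can list the snapshot artifacts in the working directory. This is a minimal sketch: the "*.model*" glob pattern is an assumption about HugeCTR's default snapshot naming and may need to be adjusted for your HugeCTR version.

import glob

# List the snapshot artifacts HugeCTR wrote during training.
# NOTE: the "*.model*" pattern is an assumption about the default
# snapshot file names; adjust it if your HugeCTR version differs.
for path in sorted(glob.glob("./*.model*")):
    print(path)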