Scaling Criteo: Training with HugeCTR

# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

This notebook was created using the latest stable merlin-hugectr container.

Overview

HugeCTR is an open-source framework that accelerates the training of click-through rate (CTR) estimation models on NVIDIA GPUs. It is written in CUDA C++ and makes heavy use of GPU-accelerated libraries such as cuBLAS, cuDNN, and NCCL.

HugeCTR offers multiple advantages for training deep learning recommender systems:

  1. Speed: HugeCTR is a highly efficient framework written in C++. We observed speedups of up to 10x. HugeCTR on an NVIDIA DGX A100 system proved to be the fastest commercially available solution for training the Deep Learning Recommendation Model (DLRM) architecture developed by Facebook.

  2. Scale: HugeCTR supports model parallel scaling. It distributes the large embedding tables over multiple GPUs or multiple nodes.

  3. Ease of use: HugeCTR provides a Python API similar to Keras, along with examples for popular deep learning recommender system architectures (Wide&Deep, DLRM, DCN, DeepFM).
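
To make the model-parallel idea from the "Scale" item concrete, the sketch below places whole embedding tables (slots) on GPUs in round-robin order. It is an illustration only, not HugeCTR's implementation: the function name `assign_slots_round_robin` is hypothetical, and the placement rule is an assumption chosen for simplicity.

```python
from collections import defaultdict


def assign_slots_round_robin(slot_sizes, num_gpus):
    """Place each embedding table (slot) on exactly one GPU, round-robin.

    Hypothetical helper for illustration; not part of the HugeCTR API.
    """
    placement = defaultdict(list)
    for slot_id, size in enumerate(slot_sizes):
        placement[slot_id % num_gpus].append((slot_id, size))
    return dict(placement)


# A few table cardinalities, chosen only for illustration.
slot_sizes = [10_000_000, 3_014_529, 400_781, 11]
placement = assign_slots_round_robin(slot_sizes, num_gpus=2)

for gpu, slots in sorted(placement.items()):
    rows = sum(size for _, size in slots)
    print(f"GPU {gpu}: slots {[slot_id for slot_id, _ in slots]}, {rows} rows")
```

Because each table lives on a single GPU, no one device has to hold every embedding row, which is what allows the combined tables to exceed the memory of one GPU.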

HugeCTR is able to train recommender system models with larger-than-memory embedding tables by leveraging a parameter server.

You can find more information about HugeCTR in its GitHub repository.

Learning objectives

In this notebook, we learn how to use HugeCTR for training recommender system models.

Training with HugeCTR

Because HugeCTR runs the training in CUDA C++, we define the training pipeline and model architecture in a script and execute it from the command line. We use the Python API, which is similar to Keras.

If you are not familiar with HugeCTR’s Python API and parameters, you can read more in its GitHub repository.

import os

BASE_DIR = os.environ.get("BASE_DIR", "/raid/data/criteo")
OUTPUT_DATA_DIR = os.environ.get("OUTPUT_DATA_DIR", BASE_DIR + "/test_dask/output")

First, we clean the output directory.

os.system("rm -rf " + os.path.join(OUTPUT_DATA_DIR, "criteo_hugectr/"))
os.system("mkdir -p " + os.path.join(OUTPUT_DATA_DIR, "criteo_hugectr/1/"))
0
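
The same cleanup can be sketched with the standard library instead of shell commands, which avoids building command strings by concatenation. The helper name `reset_output_dir` is hypothetical; the directory layout (`criteo_hugectr/1/`) is taken from the cell above.

```python
import os
import shutil


def reset_output_dir(output_data_dir):
    """Remove any previous model directory and recreate the version folder.

    Hypothetical helper; mirrors the `rm -rf` / `mkdir -p` calls above.
    """
    model_dir = os.path.join(output_data_dir, "criteo_hugectr")
    # ignore_errors=True matches `rm -rf`: a missing directory is not an error.
    shutil.rmtree(model_dir, ignore_errors=True)
    version_dir = os.path.join(model_dir, "1")
    # exist_ok=True matches `mkdir -p`.
    os.makedirs(version_dir, exist_ok=True)
    return version_dir
```

Calling `reset_output_dir(OUTPUT_DATA_DIR)` would leave an empty `criteo_hugectr/1/` directory, as the shell commands do.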

We write the code to a ./model.py file and execute it. The code creates snapshots, which we will use for inference in the next notebook. We use graph_to_json to convert the model to a JSON configuration, which is required for inference.

file_to_write = f"""
import hugectr
from mpi4py import MPI  # noqa

# HugeCTR
solver = hugectr.CreateSolver(
    vvgpu=[[0]],
    max_eval_batches=100,
    batchsize_eval=2720,
    batchsize=2720,
    i64_input_key=True,
    use_mixed_precision=False,
    repeat_dataset=True,
)
optimizer = hugectr.CreateOptimizer(optimizer_type=hugectr.Optimizer_t.SGD)
reader = hugectr.DataReaderParams(
    data_reader_type=hugectr.DataReaderType_t.Parquet,
    source=["{os.path.join(OUTPUT_DATA_DIR, "train/_file_list.txt")}"],
    eval_source="{os.path.join(OUTPUT_DATA_DIR, "valid/_file_list.txt")}",
    check_type=hugectr.Check_t.Non,
    slot_size_array=[
        10000000,
        10000000,
        3014529,
        400781,
        11,
        2209,
        11869,
        148,
        4,
        977,
        15,
        38713,
        10000000,
        10000000,
        10000000,
        584616,
        12883,
        109,
        37,
        17177,
        7425,
        20266,
        4,
        7085,
        1535,
        64,
    ],
)
model = hugectr.Model(solver, reader, optimizer)
model.add(
    hugectr.Input(
        label_dim=1,
        label_name="label",
        dense_dim=13,
        dense_name="dense",
        data_reader_sparse_param_array=[hugectr.DataReaderSparseParam("data1", 1, False, 26)],
    )
)
model.add(
    hugectr.SparseEmbedding(
        embedding_type=hugectr.Embedding_t.LocalizedSlotSparseEmbeddingHash,
        workspace_size_per_gpu_in_mb=6000,
        embedding_vec_size=128,
        combiner="sum",
        sparse_embedding_name="sparse_embedding1",
        bottom_name="data1",
        optimizer=optimizer,
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["dense"],
        top_names=["fc1"],
        num_output=512,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc1"], top_names=["relu1"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu1"],
        top_names=["fc2"],
        num_output=256,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc2"], top_names=["relu2"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu2"],
        top_names=["fc3"],
        num_output=128,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc3"], top_names=["relu3"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.Interaction,
        bottom_names=["relu3", "sparse_embedding1"],
        top_names=["interaction1"],
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["interaction1"],
        top_names=["fc4"],
        num_output=1024,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc4"], top_names=["relu4"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu4"],
        top_names=["fc5"],
        num_output=1024,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc5"], top_names=["relu5"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu5"],
        top_names=["fc6"],
        num_output=512,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc6"], top_names=["relu6"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu6"],
        top_names=["fc7"],
        num_output=256,
    )
)
model.add(
    hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU, bottom_names=["fc7"], top_names=["relu7"])
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu7"],
        top_names=["fc8"],
        num_output=1,
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.BinaryCrossEntropyLoss,
        bottom_names=["fc8", "label"],
        top_names=["loss"],
    )
)

MAX_ITER = 10000
EVAL_INTERVAL = 3200
model.compile()
model.summary()
model.fit(max_iter=MAX_ITER, eval_interval=EVAL_INTERVAL, display=1000, snapshot=3200)
model.graph_to_json(graph_config_file="{os.path.join(OUTPUT_DATA_DIR, "criteo_hugectr/1/", "criteo.json")}")
"""
with open('./model.py', 'w', encoding='utf-8') as fi:
    fi.write(file_to_write)
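
Before launching a long training run, it can help to confirm that the file lists referenced by the data reader exist. The following is a minimal sketch; the helper name `missing_inputs` is hypothetical, and the two paths are the ones passed as `source` and `eval_source` above.

```python
import os


def missing_inputs(output_data_dir):
    """Return the expected file-list paths that do not exist on disk.

    Hypothetical helper; the paths match `source` and `eval_source` in model.py.
    """
    expected = [
        os.path.join(output_data_dir, "train", "_file_list.txt"),
        os.path.join(output_data_dir, "valid", "_file_list.txt"),
    ]
    return [path for path in expected if not os.path.isfile(path)]
```

An empty list from `missing_inputs(OUTPUT_DATA_DIR)` means both file lists were found; otherwise the returned paths show what needs to be generated first.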

We can execute the training process.

import time

start = time.time()
os.system('python model.py')
run_time = time.time() - start  # elapsed wall-clock time in seconds
print(f"run_time: {run_time}")
HugeCTR Version: 4.1
====================================================Model Init=====================================================
[HCTR][13:56:03.374][WARNING][RK0][main]: The model name is not specified when creating the solver.
[HCTR][13:56:03.374][INFO][RK0][main]: Global seed is 2831956451
[HCTR][13:56:03.376][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][13:56:06.490][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][13:56:06.490][INFO][RK0][main]: Start all2all warmup
[HCTR][13:56:06.490][INFO][RK0][main]: End all2all warmup
[HCTR][13:56:06.491][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][13:56:06.493][INFO][RK0][main]: Device 0: Tesla V100-SXM2-32GB-LS
[HCTR][13:56:06.493][INFO][RK0][main]: num of DataReader workers for train: 1
[HCTR][13:56:06.493][INFO][RK0][main]: num of DataReader workers for eval: 1
[HCTR][13:56:06.540][INFO][RK0][main]: Vocabulary size: 54120457
[HCTR][13:56:06.541][INFO][RK0][main]: max_vocabulary_size_per_gpu_=12288000
[HCTR][13:56:06.562][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][13:56:24.055][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][13:56:24.107][INFO][RK0][main]: gpu0 init embedding done
[HCTR][13:56:24.111][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][13:56:24.113][INFO][RK0][main]: Warm-up done
===================================================Model Summary===================================================
run_time: 125.9122965335846
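
Two values in the log above can be reproduced from the model definition, which is a useful sanity check on the configuration. The vocabulary size is the sum of `slot_size_array`. The per-GPU limit is consistent with dividing the embedding workspace by the size of one embedding vector; that formula, and the 4-byte float width (suggested by `use_mixed_precision=False`), are assumptions that happen to match the logged number rather than a documented rule.

```python
# Cardinalities copied from slot_size_array in model.py.
slot_size_array = [
    10000000, 10000000, 3014529, 400781, 11, 2209, 11869, 148, 4, 977, 15,
    38713, 10000000, 10000000, 10000000, 584616, 12883, 109, 37, 17177,
    7425, 20266, 4, 7085, 1535, 64,
]

vocabulary_size = sum(slot_size_array)
print(vocabulary_size)  # 54120457, matches "Vocabulary size" in the log

workspace_size_per_gpu_in_mb = 6000  # from hugectr.SparseEmbedding
embedding_vec_size = 128             # from hugectr.SparseEmbedding
bytes_per_float = 4                  # assumption: float32 embeddings

# Assumed relation: workspace bytes divided by bytes per embedding row.
max_rows_per_gpu = (workspace_size_per_gpu_in_mb * 1024 * 1024) // (
    embedding_vec_size * bytes_per_float
)
print(max_rows_per_gpu)  # 12288000, matches max_vocabulary_size_per_gpu_ in the log
```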

We trained the model and created snapshots.
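
Note that `os.system` returns the exit status instead of raising an error, so a failed training run would still print a run time. A minimal sketch of a stricter launcher, using `subprocess.run` with `check=True`, is shown below; the helper name `run_script` is hypothetical.

```python
import subprocess
import sys
import time


def run_script(path):
    """Run a Python script and return the elapsed time in seconds.

    Hypothetical helper; raises subprocess.CalledProcessError if the
    script exits with a non-zero status.
    """
    start = time.time()
    # sys.executable ensures the script runs with the current interpreter.
    subprocess.run([sys.executable, path], check=True)
    return time.time() - start
```

With this helper, `run_script("model.py")` would stop the notebook with an exception if training failed, rather than continuing silently.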