[1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


Overview

In this notebook, we provide an overview of what the HugeCTR framework is, along with its features and benefits. We will use HugeCTR to train a basic neural network architecture.

Learning Objectives:
* Adopt the NVTabular workflow to provide input files to HugeCTR
* Define a HugeCTR neural network architecture
* Train a deep learning model with HugeCTR

Why use HugeCTR?

HugeCTR is a GPU-accelerated recommender framework designed to distribute training across multiple GPUs and nodes and estimate Click-Through Rates (CTRs).

HugeCTR offers multiple advantages for training deep learning recommender systems:
1. Speed: HugeCTR is a highly efficient framework written in C++. We have observed speedups of up to 10x. HugeCTR on an NVIDIA DGX A100 system proved to be the fastest commercially available solution for training the Deep Learning Recommendation Model (DLRM) architecture developed by Facebook.
2. Scale: HugeCTR supports model-parallel scaling. It distributes large embedding tables over multiple GPUs or multiple nodes.
3. Ease of use: HugeCTR provides an easy-to-use Python API similar to Keras. Examples for popular deep learning recommender system architectures (Wide&Deep, DLRM, DCN, DeepFM) are available.

Other Features of HugeCTR

HugeCTR is designed to scale deep learning models for recommender systems. Other important features include:
* Model oversubscription: training embedding tables that do not fit within GPU or CPU memory on a single node (only the embeddings required for each batch are prefetched from a parameter server)
* Asynchronous and multithreaded data pipelines
* A highly optimized data loader
* Support for data formats such as Parquet and binary
* Integration with Triton Inference Server for deployment to production

Getting Started

In this example, we train a neural network with HugeCTR, using the preprocessed datasets generated with NVTabular in the 02-ETL-with-NVTabular notebook.

[2]:
# External dependencies
import os
import nvtabular as nvt

We define the base directories for the input data and for the saved models.

[3]:
# path to preprocessed data
INPUT_DATA_DIR = os.environ.get(
    "INPUT_DATA_DIR", os.path.expanduser("~/nvt-examples/movielens/data/")
)

# path to save the models
MODEL_BASE_DIR = os.environ.get("MODEL_BASE_DIR", os.path.expanduser("~/nvt-examples/"))

Let’s load our saved workflow from the 02-ETL-with-NVTabular notebook.

[4]:
workflow = nvt.Workflow.load(os.path.join(INPUT_DATA_DIR, "workflow"))
[5]:
workflow.output_dtypes
[5]:
{'genres': ListDtype(int64),
 'movieId': dtype('int64'),
 'userId': dtype('int64'),
 'rating': dtype('int8')}

Note: The workflow does not produce any numerical (continuous) output columns.

Let’s clear the existing directory and create the output folders.

[6]:
MODEL_DIR = os.path.join(MODEL_BASE_DIR, "model/movielens_hugectr/")
# IPython substitutes {MODEL_DIR} with the value of the Python variable
!rm -rf {MODEL_DIR}
!mkdir -p {MODEL_DIR}1
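If you prefer to avoid shell magics, the same reset can be done in plain Python. This is a minimal sketch using only the standard library; `reset_model_dir` is a hypothetical helper name, and the `1` subfolder follows the version-folder layout used in this notebook.

```python
import os
import shutil
import tempfile


def reset_model_dir(model_dir: str, version: str = "1") -> str:
    """Delete model_dir if it exists, then recreate it with a version subfolder.

    Equivalent to `rm -rf MODEL_DIR` followed by `mkdir -p MODEL_DIR/1`.
    """
    shutil.rmtree(model_dir, ignore_errors=True)  # no error if it is missing
    version_dir = os.path.join(model_dir, version)
    os.makedirs(version_dir, exist_ok=True)  # also creates model_dir itself
    return version_dir


# Usage against a throwaway directory, so no real model folder is touched.
base = tempfile.mkdtemp()
model_dir = os.path.join(base, "model/movielens_hugectr/")
print(reset_model_dir(model_dir))
```

`shutil.rmtree(..., ignore_errors=True)` plays the role of `-f`, and `os.makedirs(..., exist_ok=True)` plays the role of `mkdir -p`.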

Scaling Accelerated Training with HugeCTR

HugeCTR is a deep learning framework dedicated to recommendation systems. It is written in CUDA C++. Because HugeCTR optimizes training in CUDA C++, we define the training pipeline and model architecture in a script and execute it from the command line. We will use the Python API, which is similar to Keras models.

HugeCTR has four main components:
* Solver: Specifies various details such as the active GPU list, batch size, and model_file
* Optimizer: Specifies the type of optimizer and its hyperparameters
* DataReader: Specifies the training/evaluation data
* Model: Specifies the embeddings and dense layers. Note that embeddings must precede the dense layers

Solver

Let’s take a look at the parameters of the Solver. Most of these hyperparameters should be familiar from other frameworks.

solver = hugectr.CreateSolver(
- vvgpu: GPU indices used in the training process, specified on two levels (one inner list per node). For example, [[0,1],[1,2]] indicates that two nodes are used: GPUs 0 and 1 on the first node and GPUs 1 and 2 on the second node. It is also possible to specify non-contiguous GPU indices, such as [[0, 2, 4, 7]] on a single node
- batchsize: Minibatch size used in training
- batchsize_eval: Minibatch size used in evaluation
- max_eval_batches: Maximum number of batches used in evaluation. It is recommended that this number is equal to or bigger than the actual number of batches in the evaluation dataset.
  If max_iter is used, the evaluation runs for max_eval_batches by repeating the evaluation dataset infinitely.
  With num_epochs, on the other hand, HugeCTR stops the evaluation once all the evaluation data is consumed
- i64_input_key: Whether the categorical input keys are 64-bit integers (the Parquet data used here requires True)
- repeat_dataset: Whether to repeat the training dataset, which enables non-epoch mode training with max_iter
- use_mixed_precision / scaler: Enable mixed precision training with the given loss scaler. Only 128, 256, 512, and 1024 scalers are supported
)
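The recommendation for max_eval_batches comes down to simple arithmetic: one evaluation consumes at most max_eval_batches × batchsize_eval samples. This sketch checks whether that covers an evaluation set; the helper names and the 500,000-row evaluation set are hypothetical, while 2048 and 160 are the values used in model.py later in this notebook.

```python
import math


def eval_batches_needed(num_eval_samples: int, batchsize_eval: int) -> int:
    """Batches needed to see every evaluation sample exactly once."""
    return math.ceil(num_eval_samples / batchsize_eval)


def eval_covers_dataset(num_eval_samples, batchsize_eval, max_eval_batches):
    """True if one evaluation pass can consume the whole evaluation set."""
    return max_eval_batches >= eval_batches_needed(num_eval_samples, batchsize_eval)


# Values used in model.py later in this notebook.
batchsize_eval, max_eval_batches = 2048, 160
print(max_eval_batches * batchsize_eval)  # 327680 samples per evaluation
# A hypothetical evaluation set of 500,000 rows needs more batches than that:
print(eval_batches_needed(500_000, batchsize_eval))  # 245
print(eval_covers_dataset(500_000, batchsize_eval, max_eval_batches))  # False
```

If the check returns False, each evaluation only sees part of the evaluation data, so the reported AUC is computed on a subset.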

Optimizer

The optimizer is the algorithm to update the model parameters. HugeCTR supports the common algorithms.

optimizer = hugectr.CreateOptimizer(
- optimizer_type: Optimizer algorithm: Adam, MomentumSGD, Nesterov, or SGD
- learning_rate: Learning rate for the optimizer
)
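Our model uses Adam. To make the update rule concrete, here is the textbook Adam step for a single scalar parameter. This is a sketch of the general algorithm, not HugeCTR's implementation. The learning rate 0.001 matches the `lr: 0.001000` reported in the training log. The beta and epsilon values are common defaults chosen for illustration.

```python
import math


def adam_step(param, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-7):
    """One textbook Adam update for a scalar parameter.

    t is the 1-based step counter; returns the updated (param, m, v).
    """
    m = beta1 * m + (1 - beta1) * grad  # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad  # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)  # bias-corrected estimates
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


p1, _, _ = adam_step(1.0, 0.5, 0.0, 0.0, 1)
p, m, v = 1.0, 0.0, 0.0
for t, g in enumerate([0.5, 0.4, 0.3], start=1):
    p, m, v = adam_step(p, g, m, v, t)
    print(t, round(p, 6))
```

On the first step the bias correction makes the update roughly `lr` in the direction opposite the gradient, whatever the gradient's magnitude. This is one reason Adam is a forgiving default for sparse embedding training.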

DataReader

The data reader defines the training and evaluation dataset.

reader = hugectr.DataReaderParams(
- data_reader_type: Data format to read (for example, hugectr.DataReaderType_t.Parquet)
- source: The training dataset file list. IMPORTANT: This must be a Python list
- eval_source: The evaluation dataset file list (a single path)
- check_type: The data error detection mechanism (Check_t.Sum: checksum, Check_t.Non: no detection)
- slot_size_array: The list of categorical feature cardinalities, one per slot
)
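The source and eval_source entries point at `_file_list.txt` files written by NVTabular rather than at the data files themselves. The sketch below reads such a list. Its format is an assumption: a file count on the first line, followed by one data-file path per line. `read_file_list` is a hypothetical helper, and the Parquet paths are made up.

```python
import os
import tempfile


def read_file_list(path):
    """Parse a file list assumed to hold a file count on the first line,
    followed by one data-file path per line."""
    with open(path) as f:
        lines = [line.strip() for line in f if line.strip()]
    count, files = int(lines[0]), lines[1:]
    if count != len(files):
        raise ValueError(f"header says {count} files, found {len(files)}")
    return files


# Build a small example list in a temporary directory (paths are hypothetical).
tmp = tempfile.mkdtemp()
list_path = os.path.join(tmp, "_file_list.txt")
with open(list_path, "w") as f:
    f.write("2\n/data/train/part_0.parquet\n/data/train/part_1.parquet\n")
print(read_file_list(list_path))
```

Checking the header against the number of paths is a cheap way to catch a stale file list before a long training run.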

Model

We initialize the model with the solver, optimizer and data reader:

model = hugectr.Model(solver, reader, optimizer)

We can add multiple layers to the model with the model.add function. We will focus on:
- Input defines the input data
- SparseEmbedding defines the embedding layer
- DenseLayer defines dense layers, such as fully connected, ReLU, BatchNorm, etc.

HugeCTR connects layers by name. For each layer, we define the names of its input (bottom) and output (top) tensors.
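Because layers are connected only through names, a typo in a bottom name breaks the graph. The plain-Python sketch below shows the idea. It does not call HugeCTR. `check_wiring` is a hypothetical helper that walks the layers in `add()` order and verifies that each input name was produced earlier. The layer list mirrors the model defined later in this notebook.

```python
def check_wiring(input_names, layers):
    """Verify that every layer only consumes tensors produced earlier.

    input_names: names created by the Input layer (label, dense, sparse).
    layers: list of (layer_type, bottom_names, top_names) in add() order.
    """
    available = set(input_names)
    for layer_type, bottoms, tops in layers:
        missing = [b for b in bottoms if b not in available]
        if missing:
            raise ValueError(f"{layer_type} reads undefined tensors: {missing}")
        available.update(tops)
    return available


layers = [
    ("SparseEmbedding", ["data1"], ["sparse_embedding1"]),
    ("Reshape", ["sparse_embedding1"], ["reshape1"]),
    ("InnerProduct", ["reshape1"], ["fc1"]),
    ("ReLU", ["fc1"], ["relu1"]),
    ("InnerProduct", ["relu1"], ["fc2"]),
    ("ReLU", ["fc2"], ["relu2"]),
    ("InnerProduct", ["relu2"], ["fc3"]),
    ("BinaryCrossEntropyLoss", ["fc3", "label"], ["loss"]),
]
print(sorted(check_wiring(["label", "dense", "data1"], layers)))
```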

Input layer:

This layer is required to define the input data.

hugectr.Input(
    label_dim: Number of label columns
    label_name: Name of the label tensor in the network architecture
    dense_dim: Number of continuous columns
    dense_name: Name of the continuous-feature tensor in the network architecture
    data_reader_sparse_param_array: Configuration describing how to read the sparse data (number of slots, non-zeros per slot) and the name of the resulting tensor
)

SparseEmbedding:

This layer defines the embedding table.

hugectr.SparseEmbedding(
    embedding_type: Embedding type, which determines how the embedding tables are distributed across GPUs
    workspace_size_per_gpu_in_mb: Workspace size per GPU in MB, which bounds how large the embedding table can grow
    embedding_vec_size: Embedding vector size
    combiner: Intra-slot reduction operation (for example, sum)
    sparse_embedding_name: Name of this layer's output
    bottom_name: Name of the sparse input to read from
    optimizer: Optimizer to use
)
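The combiner matters for multi-hot slots such as genres, where one sample can contain several keys. In the sketch below, a sum combiner reduces all embedding vectors in a slot to a single vector, so every slot yields one vector per sample. With 3 slots and embedding_vec_size=16, that gives the (None, 3, 16) output shape in the model summary. The toy table, keys, and the `combine_slot` helper are hypothetical. The `mean` branch is included only for comparison.

```python
def combine_slot(embedding_table, keys, combiner="sum"):
    """Reduce the embedding vectors of all keys in one slot to one vector."""
    vectors = [embedding_table[k] for k in keys]
    summed = [sum(column) for column in zip(*vectors)]
    if combiner == "sum":
        return summed
    if combiner == "mean":
        return [x / len(vectors) for x in summed]
    raise ValueError(f"unknown combiner: {combiner}")


# Toy table with 2-dimensional vectors (the model in this notebook uses 16).
table = {7: [1.0, 2.0], 9: [0.5, -1.0], 3: [0.0, 4.0]}
# One sample with three slots; the last slot is multi-hot (e.g. two genres).
sample = [[7], [3], [7, 9]]
result = [combine_slot(table, slot) for slot in sample]
print(result)  # [[1.0, 2.0], [0.0, 4.0], [1.5, 1.0]]
```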

DenseLayer:

This layer is copied to each GPU and is normally used for the MLP tower.

hugectr.DenseLayer(
    layer_type: Layer type, such as FullyConnected, Reshape, Concat, Loss, BatchNorm, etc.
    bottom_names: Input layer names
    top_names: Output names of the layer
    ...: Depending on the layer type, additional parameters can be defined
)

This is only a short introduction to the API. You can find more details in the official documentation: Python Interface and Layer Book.

Let’s define our model

We have walked through the documentation to understand the API. Now we can define our model. We write it to ./model.py and execute that script afterwards.

We need the cardinalities of each categorical feature to assign as slot_size_array in the model below.

[7]:
from nvtabular.ops import get_embedding_sizes

embeddings = get_embedding_sizes(workflow)
print(embeddings)
({'movieId': (56586, 512), 'userId': (162542, 512)}, {'genres': (21, 16)})
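The printed cardinalities can be turned into slot_size_array explicitly. This sketch copies the output above. The slot order (userId, movieId, genres) is an assumption taken from the slot_size_array in model.py, and it must match the categorical column order in the Parquet files. The second value in each tuple is NVTabular's suggested embedding dimension. This model ignores it and uses embedding_vec_size=16 for every slot.

The sketch also checks the vocabulary against the embedding workspace. The capacity formula (workspace bytes divided by embedding_vec_size float32 values) is inferred from the training log below. That log reports `Vocabulary size: 219149` and `max_vocabulary_size_per_gpu_=3276800` for these settings. The formula is not taken from the HugeCTR documentation.

```python
# Output of get_embedding_sizes(workflow), copied from the cell above:
# (single-hot columns, multi-hot columns), each name -> (cardinality, emb dim).
embeddings = ({"movieId": (56586, 512), "userId": (162542, 512)}, {"genres": (21, 16)})

# Assumed slot order; it must match the column order the data reader sees.
slot_order = ["userId", "movieId", "genres"]
cardinalities = {**embeddings[0], **embeddings[1]}
slot_size_array = [cardinalities[name][0] for name in slot_order]
print(slot_size_array)  # [162542, 56586, 21]

# Rough capacity check for workspace_size_per_gpu_in_mb (float32 vectors).
workspace_mb, embedding_vec_size, bytes_per_float = 200, 16, 4
capacity = workspace_mb * 1024 * 1024 // (embedding_vec_size * bytes_per_float)
vocabulary = sum(slot_size_array)
print(vocabulary, capacity, vocabulary <= capacity)  # 219149 3276800 True
```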

At the end of the script, we use graph_to_json to save the model graph as a JSON configuration, which is required for inference.

[8]:
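%%writefile './model.py'

import os

import hugectr
from mpi4py import MPI  # noqa

# model.py runs in a separate process, so it cannot see the notebook's
# variables; read the same paths (and defaults) from the environment.
INPUT_DATA_DIR = os.environ.get(
    "INPUT_DATA_DIR", os.path.expanduser("~/nvt-examples/movielens/data/")
)
MODEL_BASE_DIR = os.environ.get("MODEL_BASE_DIR", os.path.expanduser("~/nvt-examples/"))
MODEL_DIR = os.path.join(MODEL_BASE_DIR, "model/movielens_hugectr/")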

solver = hugectr.CreateSolver(
    vvgpu=[[0]],
    batchsize=2048,
    batchsize_eval=2048,
    max_eval_batches=160,
    i64_input_key=True,
    use_mixed_precision=False,
    repeat_dataset=True,
)
optimizer = hugectr.CreateOptimizer(optimizer_type=hugectr.Optimizer_t.Adam)
reader = hugectr.DataReaderParams(
    data_reader_type=hugectr.DataReaderType_t.Parquet,
    source=[INPUT_DATA_DIR + "train/_file_list.txt"],
    eval_source=INPUT_DATA_DIR + "valid/_file_list.txt",
    check_type=hugectr.Check_t.Non,
    slot_size_array=[162542, 56586, 21],
)


model = hugectr.Model(solver, reader, optimizer)

model.add(
    hugectr.Input(
        label_dim=1,
        label_name="label",
        dense_dim=0,
        dense_name="dense",
        data_reader_sparse_param_array=[
            hugectr.DataReaderSparseParam("data1", nnz_per_slot=[1, 1, 2], is_fixed_length=False, slot_num=3)
        ],
    )
)
model.add(
    hugectr.SparseEmbedding(
        embedding_type=hugectr.Embedding_t.LocalizedSlotSparseEmbeddingHash,
        workspace_size_per_gpu_in_mb=200,
        embedding_vec_size=16,
        combiner="sum",
        sparse_embedding_name="sparse_embedding1",
        bottom_name="data1",
        optimizer=optimizer,
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.Reshape,
        bottom_names=["sparse_embedding1"],
        top_names=["reshape1"],
        leading_dim=48,
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["reshape1"],
        top_names=["fc1"],
        num_output=128,
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.ReLU,
        bottom_names=["fc1"],
        top_names=["relu1"],
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu1"],
        top_names=["fc2"],
        num_output=128,
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.ReLU,
        bottom_names=["fc2"],
        top_names=["relu2"],
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.InnerProduct,
        bottom_names=["relu2"],
        top_names=["fc3"],
        num_output=1,
    )
)
model.add(
    hugectr.DenseLayer(
        layer_type=hugectr.Layer_t.BinaryCrossEntropyLoss,
        bottom_names=["fc3", "label"],
        top_names=["loss"],
    )
)
model.compile()
model.summary()
model.fit(max_iter=2000, display=100, eval_interval=200, snapshot=1900)
model.graph_to_json(graph_config_file=MODEL_DIR + "1/movielens.json")
Overwriting ./model.py
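The shapes in the model summary follow from simple arithmetic. This sketch reproduces the Output Shape column for the dense tower and counts the dense parameters. It assumes each InnerProduct layer has a weight matrix plus a bias vector. It does not call HugeCTR.

```python
slot_num, embedding_vec_size = 3, 16
hidden_sizes = [128, 128, 1]  # num_output of fc1, fc2, fc3

reshape_dim = slot_num * embedding_vec_size  # must equal leading_dim in Reshape
print("reshape1:", (None, reshape_dim))

in_dim, dense_params = reshape_dim, 0
for i, out_dim in enumerate(hidden_sizes, start=1):
    layer_params = in_dim * out_dim + out_dim  # weights + bias
    dense_params += layer_params
    print(f"fc{i}: (None, {out_dim}), {layer_params} params")
    in_dim = out_dim
print("dense parameters:", dense_params)
```

The dense tower is tiny (about 23 thousand parameters) compared with the embedding tables (219,149 keys × 16 values). This imbalance is typical of recommender models and is why HugeCTR focuses on distributing the embeddings.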

We train our model.

[9]:
!python model.py
====================================================Model Init=====================================================
[30d01h11m05s][HUGECTR][INFO]: Global seed is 2919760786
[30d01h11m05s][HUGECTR][INFO]: Device to NUMA mapping:
  GPU 0 ->  node 0

[30d01h11m06s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.
[30d01h11m06s][HUGECTR][INFO]: Start all2all warmup
[30d01h11m06s][HUGECTR][INFO]: End all2all warmup
[30d01h11m06s][HUGECTR][INFO]: Using All-reduce algorithm OneShot
Device 0: Tesla V100-DGXS-16GB
[30d01h11m06s][HUGECTR][INFO]: num of DataReader workers: 1
[30d01h11m06s][HUGECTR][INFO]: Vocabulary size: 219149
[30d01h11m06s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=3276800
[30d01h11m06s][HUGECTR][INFO]: All2All Warmup Start
[30d01h11m06s][HUGECTR][INFO]: All2All Warmup End
===================================================Model Compile===================================================
[30d01h11m08s][HUGECTR][INFO]: gpu0 start to init embedding
[30d01h11m08s][HUGECTR][INFO]: gpu0 init embedding done
[30d01h11m08s][HUGECTR][INFO]: Starting AUC NCCL warm-up
[30d01h11m08s][HUGECTR][INFO]: Warm-up done
===================================================Model Summary===================================================
Label                                   Dense                         Sparse
label                                   dense                          data1
(None, 1)                               (None, 0)
------------------------------------------------------------------------------------------------------------------
Layer Type                              Input Name                    Output Name                   Output Shape
------------------------------------------------------------------------------------------------------------------
LocalizedSlotSparseEmbeddingHash        data1                         sparse_embedding1             (None, 3, 16)
Reshape                                 sparse_embedding1             reshape1                      (None, 48)
InnerProduct                            reshape1                      fc1                           (None, 128)
ReLU                                    fc1                           relu1                         (None, 128)
InnerProduct                            relu1                         fc2                           (None, 128)
ReLU                                    fc2                           relu2                         (None, 128)
InnerProduct                            relu2                         fc3                           (None, 1)
BinaryCrossEntropyLoss                  fc3,label                     loss
------------------------------------------------------------------------------------------------------------------
=====================================================Model Fit=====================================================
[30d10h11m80s][HUGECTR][INFO]: Use non-epoch mode with number of iterations: 2000
[30d10h11m80s][HUGECTR][INFO]: Training batchsize: 2048, evaluation batchsize: 2048
[30d10h11m80s][HUGECTR][INFO]: Evaluation interval: 200, snapshot interval: 1900
[30d10h11m80s][HUGECTR][INFO]: Sparse embedding trainable: 1, dense network trainable: 1
[30d10h11m80s][HUGECTR][INFO]: Use mixed precision: 0, scaler: 1.000000, use cuda graph: 1
[30d10h11m80s][HUGECTR][INFO]: lr: 0.001000, warmup_steps: 1, decay_start: 0, decay_steps: 1, decay_power: 2.000000, end_lr: 0.000000
[30d10h11m80s][HUGECTR][INFO]: Training source file: /root/nvt-examples/movielens/data/train/_file_list.txt
[30d10h11m80s][HUGECTR][INFO]: Evaluation source file: /root/nvt-examples/movielens/data/valid/_file_list.txt
[30d10h11m80s][HUGECTR][INFO]: Iter: 100 Time(100 iters): 0.221617s Loss: 0.601177 lr:0.001000
[30d10h11m80s][HUGECTR][INFO]: Iter: 200 Time(100 iters): 0.218594s Loss: 0.562358 lr:0.001000
[30d10h11m80s][HUGECTR][INFO]: Evaluation, AUC: 0.747171
[30d10h11m80s][HUGECTR][INFO]: Eval Time for 160 iters: 0.035586s
[30d10h11m90s][HUGECTR][INFO]: Iter: 300 Time(100 iters): 0.255790s Loss: 0.558368 lr:0.001000
[30d10h11m90s][HUGECTR][INFO]: Iter: 400 Time(100 iters): 0.219147s Loss: 0.542835 lr:0.001000
[30d10h11m90s][HUGECTR][INFO]: Evaluation, AUC: 0.766279
[30d10h11m90s][HUGECTR][INFO]: Eval Time for 160 iters: 0.035058s
[30d10h11m90s][HUGECTR][INFO]: Iter: 500 Time(100 iters): 0.281789s Loss: 0.535660 lr:0.001000
[30d10h11m90s][HUGECTR][INFO]: Iter: 600 Time(100 iters): 0.219292s Loss: 0.536590 lr:0.001000
[30d10h11m90s][HUGECTR][INFO]: Evaluation, AUC: 0.775542
[30d10h11m90s][HUGECTR][INFO]: Eval Time for 160 iters: 0.034913s
[30d10h11m10s][HUGECTR][INFO]: Iter: 700 Time(100 iters): 0.255189s Loss: 0.541284 lr:0.001000
[30d10h11m10s][HUGECTR][INFO]: Iter: 800 Time(100 iters): 0.219360s Loss: 0.530490 lr:0.001000
[30d10h11m10s][HUGECTR][INFO]: Evaluation, AUC: 0.782824
[30d10h11m10s][HUGECTR][INFO]: Eval Time for 160 iters: 0.062139s
[30d10h11m10s][HUGECTR][INFO]: Iter: 900 Time(100 iters): 0.282552s Loss: 0.529440 lr:0.001000
[30d10h11m10s][HUGECTR][INFO]: Iter: 1000 Time(100 iters): 0.244506s Loss: 0.523832 lr:0.001000
[30d10h11m10s][HUGECTR][INFO]: Evaluation, AUC: 0.787961
[30d10h11m10s][HUGECTR][INFO]: Eval Time for 160 iters: 0.035205s
[30d10h11m11s][HUGECTR][INFO]: Iter: 1100 Time(100 iters): 0.255605s Loss: 0.531827 lr:0.001000
[30d10h11m11s][HUGECTR][INFO]: Iter: 1200 Time(100 iters): 0.219146s Loss: 0.540005 lr:0.001000
[30d10h11m11s][HUGECTR][INFO]: Evaluation, AUC: 0.789803
[30d10h11m11s][HUGECTR][INFO]: Eval Time for 160 iters: 0.034100s
[30d10h11m11s][HUGECTR][INFO]: Iter: 1300 Time(100 iters): 0.254645s Loss: 0.533909 lr:0.001000
[30d10h11m11s][HUGECTR][INFO]: Iter: 1400 Time(100 iters): 0.219378s Loss: 0.511975 lr:0.001000
[30d10h11m11s][HUGECTR][INFO]: Evaluation, AUC: 0.793778
[30d10h11m11s][HUGECTR][INFO]: Eval Time for 160 iters: 0.060744s
[30d10h11m12s][HUGECTR][INFO]: Iter: 1500 Time(100 iters): 0.306601s Loss: 0.517142 lr:0.001000
[30d10h11m12s][HUGECTR][INFO]: Iter: 1600 Time(100 iters): 0.219319s Loss: 0.519984 lr:0.001000
[30d10h11m12s][HUGECTR][INFO]: Evaluation, AUC: 0.795280
[30d10h11m12s][HUGECTR][INFO]: Eval Time for 160 iters: 0.034830s
[30d10h11m12s][HUGECTR][INFO]: Iter: 1700 Time(100 iters): 0.255265s Loss: 0.500303 lr:0.001000
[30d10h11m12s][HUGECTR][INFO]: Iter: 1800 Time(100 iters): 0.219257s Loss: 0.537229 lr:0.001000
[30d10h11m12s][HUGECTR][INFO]: Evaluation, AUC: 0.796465
[30d10h11m12s][HUGECTR][INFO]: Eval Time for 160 iters: 0.034618s
[30d10h11m13s][HUGECTR][INFO]: Iter: 1900 Time(100 iters): 0.254921s Loss: 0.516836 lr:0.001000
[30d10h11m13s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0
[30d10h11m13s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file
[30d10h11m13s][HUGECTR][INFO]: Done
[30d10h11m13s][HUGECTR][INFO]: Dumping sparse weights to files, successful
[30d10h11m13s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[30d10h11m13s][HUGECTR][INFO]: Done
[30d10h11m13s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[30d10h11m13s][HUGECTR][INFO]: Done
[30d10h11m13s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful
[30d10h11m13s][HUGECTR][INFO]: Dumping dense weights to file, successful
[30d10h11m13s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful
[30d10h11m13s][HUGECTR][INFO]: Dumping untrainable weights to file, successful
Finish 2000 iterations with batchsize: 2048 in 5.65s
[30d10h11m14s][HUGECTR][INFO]: Save the model graph to /root/nvt-examples/model/movielens_hugectr/1/movielens.json, successful
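The final summary line gives an end-to-end throughput figure. This worked example uses only the numbers reported by `fit()` above. The result is an average over the whole call as reported, not a per-iteration benchmark.

```python
iterations, batchsize, seconds = 2000, 2048, 5.65  # from "Finish 2000 iterations ..."
samples = iterations * batchsize
print(samples)  # 4096000
print(round(samples / seconds))  # samples per second, roughly 725 thousand
```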

After training terminates, several .model files and folders have been generated in the working directory. We need to move them into the 1 folder under the movielens_hugectr folder.

[10]:
!mv *.model {MODEL_DIR}1/
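The same move can be done from Python. This sketch uses `glob` and `shutil.move`. `move_model_files` is a hypothetical helper, and the snapshot names in the demo are made up for illustration. Run it from the directory where training wrote its snapshots, passing the model's `1` folder as the destination.

```python
import glob
import os
import shutil
import tempfile


def move_model_files(src_dir, dest_dir, pattern="*.model"):
    """Move files and folders matching pattern from src_dir into dest_dir."""
    os.makedirs(dest_dir, exist_ok=True)
    moved = []
    for path in sorted(glob.glob(os.path.join(src_dir, pattern))):
        shutil.move(path, os.path.join(dest_dir, os.path.basename(path)))
        moved.append(os.path.basename(path))
    return moved


# Demo in temporary directories with hypothetical snapshot names.
src, dest = tempfile.mkdtemp(), tempfile.mkdtemp()
for name in ["0_sparse_1900.model", "_dense_1900.model"]:
    os.makedirs(os.path.join(src, name))
moved = move_model_files(src, os.path.join(dest, "1"))
print(moved)
```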