[1]:
# Copyright 2020 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

NVTabular / HugeCTR Criteo Example

Here we’ll show how to use NVTabular first as a preprocessing library to prepare the Criteo Display Advertising Challenge dataset, and then train a model using HugeCTR.

Data Prep

Before we get started, make sure you’ve run the optimize_criteo notebook, which will convert the tsv data published by Criteo into the parquet format that our accelerated readers prefer. It’s fair to mention at this point that that notebook will take ~30 minutes to run. While we’re hoping to release accelerated csv readers in the near future, we also believe that inefficiencies in existing data representations like csv are in no small part a consequence of inefficiencies in the existing hardware/software stack. Accelerating these pipelines on new hardware like GPUs may require us to make new choices about the representations we use to store that data, and parquet represents a strong alternative.

[2]:
# Standard Libraries
import os
from time import time
import re
import shutil
import glob
import warnings

# External Dependencies
import numpy as np
import cupy as cp
import cudf
import dask_cudf
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from dask.utils import parse_bytes
from dask.delayed import delayed
import rmm

# NVTabular
import nvtabular as nvt
import nvtabular.ops as ops
from nvtabular.io import Shuffle
from nvtabular.utils import _pynvml_mem_size, device_mem_size

# HugeCTR
import sys
sys.path.append("/usr/local/hugectr/lib")
from hugectr import Session, solver_parser_helper, get_learning_rate_scheduler

Dataset and Dataset Schema

Once our data is ready, we’ll define some high level parameters to describe where our data is and what it “looks like” at a high level.

[3]:
# define some information about where to get our data
BASE_DIR = "/raid/criteo/tests/"
input_path = os.path.join(BASE_DIR, "crit_int_pq")
dask_workdir = os.path.join(BASE_DIR, "test_dask/workdir")
output_path = os.path.join(BASE_DIR, "test_dask/output")
stats_path = os.path.join(BASE_DIR, "test_dask/stats")

NUM_TRAIN_DAYS = 23 # number of days worth of data to use for training, the rest will be used for validation
NUM_GPUS = [0,1,2,3,4,5,6,7]  # GPU device indices to use (a list of ids, not a count)

# define our dataset schema
CONTINUOUS_COLUMNS = ['I' + str(x) for x in range(1,14)]
CATEGORICAL_COLUMNS =  ['C' + str(x) for x in range(1,27)]
LABEL_COLUMNS = ['label']
COLUMNS = CONTINUOUS_COLUMNS + CATEGORICAL_COLUMNS + LABEL_COLUMNS


def _reset_dir(path):
    """Remove ``path`` if it is an existing directory, then recreate it empty.

    Uses ``os.makedirs`` so missing parent directories are created too,
    rather than relying on another directory having been created first.
    """
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path)


# Make sure we have clean Dask worker, stats and output directories
# (same order as before: worker space, stats space, output path).
for _path in (dask_workdir, stats_path, output_path):
    _reset_dir(_path)
[4]:
# Sanity check: list BASE_DIR to confirm the crit_int_pq input directory is present.
! ls $BASE_DIR
crit_int_pq  NVTabular  test_dask
[5]:
# Daily input files are named day_<N>.parquet. Count them, then use days
# [0, NUM_TRAIN_DAYS) for training and the remaining days for validation.
# Assumes the day files are numbered contiguously starting at 0.
fname = 'day_{}.parquet'
# Build the pattern from fname: escape the literal parts (notably the '.')
# and require the *whole* file name to match, so stray files such as
# 'day_1.parquet.crc' are not counted as extra days.
prefix, suffix = fname.split('{}')
day_pattern = re.compile(re.escape(prefix) + '[0-9]{1,2}' + re.escape(suffix))
num_days = sum(1 for name in os.listdir(input_path) if day_pattern.fullmatch(name))
assert 0 < NUM_TRAIN_DAYS < num_days, (
    f"Need more than NUM_TRAIN_DAYS={NUM_TRAIN_DAYS} day files in {input_path} "
    f"(found {num_days}) to leave at least one day for validation"
)
train_paths = [os.path.join(input_path, fname.format(day)) for day in range(NUM_TRAIN_DAYS)]
valid_paths = [os.path.join(input_path, fname.format(day)) for day in range(NUM_TRAIN_DAYS, num_days)]
print(train_paths)
print(valid_paths)
['/raid/criteo/tests/crit_int_pq/day_0.parquet', '/raid/criteo/tests/crit_int_pq/day_1.parquet', '/raid/criteo/tests/crit_int_pq/day_2.parquet', '/raid/criteo/tests/crit_int_pq/day_3.parquet', '/raid/criteo/tests/crit_int_pq/day_4.parquet', '/raid/criteo/tests/crit_int_pq/day_5.parquet', '/raid/criteo/tests/crit_int_pq/day_6.parquet', '/raid/criteo/tests/crit_int_pq/day_7.parquet', '/raid/criteo/tests/crit_int_pq/day_8.parquet', '/raid/criteo/tests/crit_int_pq/day_9.parquet', '/raid/criteo/tests/crit_int_pq/day_10.parquet', '/raid/criteo/tests/crit_int_pq/day_11.parquet', '/raid/criteo/tests/crit_int_pq/day_12.parquet', '/raid/criteo/tests/crit_int_pq/day_13.parquet', '/raid/criteo/tests/crit_int_pq/day_14.parquet', '/raid/criteo/tests/crit_int_pq/day_15.parquet', '/raid/criteo/tests/crit_int_pq/day_16.parquet', '/raid/criteo/tests/crit_int_pq/day_17.parquet', '/raid/criteo/tests/crit_int_pq/day_18.parquet', '/raid/criteo/tests/crit_int_pq/day_19.parquet', '/raid/criteo/tests/crit_int_pq/day_20.parquet', '/raid/criteo/tests/crit_int_pq/day_21.parquet', '/raid/criteo/tests/crit_int_pq/day_22.parquet']
['/raid/criteo/tests/crit_int_pq/day_23.parquet']

Deploy a Distributed-Dask Cluster

[6]:
# Dask dashboard
dashboard_port = "8787"

# Deploy a Single-Machine Multi-GPU Cluster
protocol = "tcp"             # "tcp" or "ucx"
visible_devices = ",".join([str(n) for n in NUM_GPUS])  # Select devices to place workers
device_limit_frac = 0.7      # Spill GPU-Worker memory to host at this limit.
device_pool_frac = 0.8       # Fraction of device memory used for each worker's RMM pool.
part_mem_frac = 0.15         # Fraction of device memory passed to nvt.Dataset as part_size.

# Convert the fractions above into byte counts. device_mem_size reports a
# single device's total memory (presumably the current device), so these
# limits assume all GPUs in NUM_GPUS have the same capacity.
device_size = device_mem_size(kind="total")
device_limit = int(device_limit_frac * device_size)
device_pool_size = int(device_pool_frac * device_size)
part_size = int(part_mem_frac * device_size)

# Check if any device memory is already occupied. Compare each device's free
# memory against that *same* device's total, so the estimate stays correct
# even when the visible GPUs differ in capacity.
for dev in visible_devices.split(","):
    dev_index = int(dev)
    total_mem = _pynvml_mem_size(kind="total", index=dev_index)
    free_mem = _pynvml_mem_size(kind="free", index=dev_index)
    used = (total_mem - free_mem) / 1e9
    if used > 1.0:
        warnings.warn(f"BEWARE - {used} GB is already occupied on device {dev_index}!")

cluster = None               # (Optional) set to an existing cluster or scheduler address to reuse it
if cluster is None:
    cluster = LocalCUDACluster(
        protocol=protocol,
        n_workers=len(visible_devices.split(",")),
        CUDA_VISIBLE_DEVICES=visible_devices,
        device_memory_limit=device_limit,
        local_directory=dask_workdir,
        dashboard_address=":" + dashboard_port,
    )

# Create the distributed client
client = Client(cluster)
client
[6]:

Client

Cluster

  • Workers: 8
  • Cores: 8
  • Memory: 2.16 TB

Initialize Memory Pools

[7]:
# Initialize an RMM memory pool on ALL workers
def _rmm_pool():
    """Reinitialize RMM on the calling process with a pooled allocator.

    The initial pool size is ``device_pool_size`` (device_pool_frac times the
    total device memory, computed when the cluster is set up) rounded down to
    a multiple of 256 bytes. Run on every worker via ``client.run`` below.
    """
    rmm.reinitialize(
        pool_allocator=True,
        # RMM may require the pool size to be a multiple of 256, so round down.
        initial_pool_size=(device_pool_size // 256) * 256,
    )

client.run(_rmm_pool)
[7]:
{'tcp://127.0.0.1:34419': None,
 'tcp://127.0.0.1:34529': None,
 'tcp://127.0.0.1:34703': None,
 'tcp://127.0.0.1:34721': None,
 'tcp://127.0.0.1:37447': None,
 'tcp://127.0.0.1:40953': None,
 'tcp://127.0.0.1:42517': None,
 'tcp://127.0.0.1:44611': None}

Preprocessing

At this point, our data still isn’t in a form that’s ideal for consumption by neural networks. The most pressing issues are missing values and the fact that our categorical variables are still represented by random, discrete identifiers, and need to be transformed into contiguous indices that can be leveraged by a learned embedding. Less pressing, but still important for learning dynamics, are the distributions of our continuous variables, which span multiple orders of magnitude and are uncentered (i.e. E[x] != 0).

We can fix these issues in a concise and GPU-accelerated manner with an NVTabular Workflow. We’ll instantiate one with our current dataset schema, then symbolically add operations on that schema. By setting all these Ops to use replace=True, the schema itself will remain unmodified, while the variables represented by each field in the schema will be transformed.

Frequency Thresholding

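One interesting thing worth pointing out is that NVTabular’s Categorify op supports frequency thresholding (via its freq_threshold argument). This handy functionality maps all categories that occur in the dataset fewer than some threshold number of times (for example, 15 occurrences throughout the dataset) to the same index, keeping the model from overfitting to sparse signals. Note that the workflow below does not enable it; instead, it bounds the size of each embedding table by taking the Categorify output modulo 10M (the MOD10M LambdaOp).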

[8]:
proc = nvt.Workflow(
    cat_names=CATEGORICAL_COLUMNS,
    cont_names=CONTINUOUS_COLUMNS,
    label_name=LABEL_COLUMNS,
    client=client,
)

# Continuous features: zero-fill missing values, clip negatives to 0, then
# take the log. The fill and clip must come *before* the log: remember that
# LogOp performs log(1+x), not log(x), so it needs non-null, non-negative input.
# Note that no normalization op is applied in this pipeline.
proc.add_cont_feature([ops.FillMissing(), ops.Clip(min_value=0), ops.LogOp()])

# Categorical features: Categorify (its statistics are written to stats_path),
# then fold the resulting indices modulo num_buckets so no categorical slot
# has more than 10M distinct values. Frequency thresholding is NOT enabled
# here (no freq_threshold argument is passed to Categorify).
num_buckets = 10000000
proc.add_cat_preprocess([ops.Categorify(out_path=stats_path), ops.LambdaOp(op_name="MOD10M", f=lambda col, gdf: col % num_buckets)])

Now we instantiate dataset iterators to loop through our dataset (which doesn’t fit into GPU memory). We need to enforce the data types HugeCTR requires, so we collect them in a dictionary and pass it as the dtypes argument when the workflow writes out the processed data.

[9]:
# Output dtypes for HugeCTR: int64 keys for categorical slots (the solver is
# configured with i64_input_key=True) and float32 for dense features and labels.
dict_dtypes = {col: np.int64 for col in CATEGORICAL_COLUMNS}
dict_dtypes.update({col: np.float32 for col in CONTINUOUS_COLUMNS + LABEL_COLUMNS})
[10]:
# Lazily-read parquet datasets for each split, chunked into part_size pieces.
train_dataset, valid_dataset = (
    nvt.Dataset(split_paths, engine='parquet', part_size=part_size)
    for split_paths in (train_paths, valid_paths)
)

Now run them through our workflow: collect statistics on the train set, then transform both sets and save them as parquet files.

[11]:
# Per-split output directories for the transformed parquet files.
output_train_dir = os.path.join(output_path, 'train/')
output_valid_dir = os.path.join(output_path, 'valid/')
# Create them from Python instead of `! mkdir -p $...`: portable, and safe for
# paths containing spaces or shell metacharacters. exist_ok mirrors `mkdir -p`.
os.makedirs(output_train_dir, exist_ok=True)
os.makedirs(output_valid_dir, exist_ok=True)

For reference, let’s time it to see how long it takes…

[12]:
%%time
proc.apply(train_dataset, shuffle=nvt.io.Shuffle.PER_PARTITION, output_format="parquet", output_path=output_train_dir, dtypes=dict_dtypes)
CPU times: user 18.6 s, sys: 1.84 s, total: 20.5 s
Wall time: 2min 49s
[13]:
%%time
proc.apply(valid_dataset, record_stats=False, shuffle=nvt.io.Shuffle.PER_PARTITION, output_format="parquet", output_path=output_valid_dir, dtypes=dict_dtypes)
CPU times: user 853 ms, sys: 260 ms, total: 1.11 s
Wall time: 11.9 s

Get the embedding table sizes, which are used to configure slot_size_array in dlrm_fp32_64k.json.

[14]:
# Table size per categorical slot: the first entry of each embedding-size
# record, capped at num_buckets because the MOD10M op folds larger vocabularies.
embeddings = [
    int(min(sizes[0], num_buckets))
    for sizes in ops.get_embedding_sizes(proc).values()
]
print(embeddings)
[10000000, 10000000, 3014529, 400781, 11, 2209, 11869, 148, 4, 977, 15, 38713, 10000000, 10000000, 10000000, 584616, 12883, 109, 37, 17177, 7425, 20266, 4, 7085, 1535, 64]

And just like that, we have training and validation sets ready to feed to a model!

HugeCTR

Training

We’ll train with HugeCTR using the DLRM configuration file (dlrm_fp32_64k.json).

First, we’ll shutdown our Dask client from earlier to free up some memory so that we can share it with HugeCTR.

[15]:
# Shut down the Dask scheduler and workers so the GPU memory held by the
# worker processes (e.g. their RMM pools) is freed before HugeCTR starts.
client.shutdown()

Finally, we run HugeCTR. For reference, let’s time it to see how long it takes…

[16]:
%%time
# Set config file
json_file = "dlrm_fp32_64k.json"
# Set solver config
solver_config = solver_parser_helper(seed = 0,
                                     batchsize = 16384,
                                     batchsize_eval = 16384,
                                     model_file = "",
                                     embedding_files = [],
                                     vvgpu = [NUM_GPUS],
                                     use_mixed_precision = False,
                                     scaler = 1.0,
                                     i64_input_key = True,
                                     use_algorithm_search = True,
                                     use_cuda_graph = True,
                                     repeat_dataset = True
                                    )
# Set learning rate
lr_sch = get_learning_rate_scheduler(json_file)
# Train model
sess = Session(solver_config, json_file)
sess.start_data_reading()
for i in range(10000):
    lr = lr_sch.get_next()
    sess.set_learning_rate(lr)
    sess.train()
    if (i%100 == 0):
        loss = sess.get_current_loss()
        print("[HUGECTR][INFO] iter: {}; loss: {}".format(i, loss))
    if (i%3000 == 0 and i != 0):
        metrics = sess.evaluation()
        print("[HUGECTR][INFO] iter: {}, {}".format(i, metrics))
[HUGECTR][INFO] iter: 0; loss: 0.8532629609107971
[HUGECTR][INFO] iter: 100; loss: 0.13733375072479248
[HUGECTR][INFO] iter: 200; loss: 0.12044114619493484
[HUGECTR][INFO] iter: 300; loss: 0.1432936042547226
[HUGECTR][INFO] iter: 400; loss: 0.13736993074417114
[HUGECTR][INFO] iter: 500; loss: 0.13921479880809784
[HUGECTR][INFO] iter: 600; loss: 0.12881088256835938
[HUGECTR][INFO] iter: 700; loss: 0.12584912776947021
[HUGECTR][INFO] iter: 800; loss: 0.1269063502550125
[HUGECTR][INFO] iter: 900; loss: 0.1276361644268036
[HUGECTR][INFO] iter: 1000; loss: 0.12601923942565918
[HUGECTR][INFO] iter: 1100; loss: 0.12928898632526398
[HUGECTR][INFO] iter: 1200; loss: 0.12370038777589798
[HUGECTR][INFO] iter: 1300; loss: 0.13637907803058624
[HUGECTR][INFO] iter: 1400; loss: 0.12262138724327087
[HUGECTR][INFO] iter: 1500; loss: 0.12716645002365112
[HUGECTR][INFO] iter: 1600; loss: 0.12878908216953278
[HUGECTR][INFO] iter: 1700; loss: 0.13806402683258057
[HUGECTR][INFO] iter: 1800; loss: 0.12637314200401306
[HUGECTR][INFO] iter: 1900; loss: 0.1234893649816513
[HUGECTR][INFO] iter: 2000; loss: 0.12500672042369843
[HUGECTR][INFO] iter: 2100; loss: 0.1271117925643921
[HUGECTR][INFO] iter: 2200; loss: 0.12065654993057251
[HUGECTR][INFO] iter: 2300; loss: 0.12455953657627106
[HUGECTR][INFO] iter: 2400; loss: 0.13445869088172913
[HUGECTR][INFO] iter: 2500; loss: 0.12091702222824097
[HUGECTR][INFO] iter: 2600; loss: 0.1275034099817276
[HUGECTR][INFO] iter: 2700; loss: 0.12200944125652313
[HUGECTR][INFO] iter: 2800; loss: 0.12480510026216507
[HUGECTR][INFO] iter: 2900; loss: 0.12914004921913147
[HUGECTR][INFO] iter: 3000; loss: 0.12693384289741516
[HUGECTR][INFO] iter: 3000, [('AUC', 0.7665499448776245)]
[HUGECTR][INFO] iter: 3100; loss: 0.12380503118038177
[HUGECTR][INFO] iter: 3200; loss: 0.12198879569768906
[HUGECTR][INFO] iter: 3300; loss: 0.11890366673469543
[HUGECTR][INFO] iter: 3400; loss: 0.11795458942651749
[HUGECTR][INFO] iter: 3500; loss: 0.1266060322523117
[HUGECTR][INFO] iter: 3600; loss: 0.1308339685201645
[HUGECTR][INFO] iter: 3700; loss: 0.11925296485424042
[HUGECTR][INFO] iter: 3800; loss: 0.12146525084972382
[HUGECTR][INFO] iter: 3900; loss: 0.1292012482881546
[HUGECTR][INFO] iter: 4000; loss: 0.12852615118026733
[HUGECTR][INFO] iter: 4100; loss: 0.128790944814682
[HUGECTR][INFO] iter: 4200; loss: 0.13038936257362366
[HUGECTR][INFO] iter: 4300; loss: 0.13004642724990845
[HUGECTR][INFO] iter: 4400; loss: 0.12568017840385437
[HUGECTR][INFO] iter: 4500; loss: 0.12528616189956665
[HUGECTR][INFO] iter: 4600; loss: 0.12257300317287445
[HUGECTR][INFO] iter: 4700; loss: 0.12529920041561127
[HUGECTR][INFO] iter: 4800; loss: 0.12477346509695053
[HUGECTR][INFO] iter: 4900; loss: 0.12581917643547058
[HUGECTR][INFO] iter: 5000; loss: 0.12937895953655243
[HUGECTR][INFO] iter: 5100; loss: 0.12715725600719452
[HUGECTR][INFO] iter: 5200; loss: 0.1305316984653473
[HUGECTR][INFO] iter: 5300; loss: 0.12407000362873077
[HUGECTR][INFO] iter: 5400; loss: 0.11724398285150528
[HUGECTR][INFO] iter: 5500; loss: 0.1297476887702942
[HUGECTR][INFO] iter: 5600; loss: 0.1252257376909256
[HUGECTR][INFO] iter: 5700; loss: 0.13481514155864716
[HUGECTR][INFO] iter: 5800; loss: 0.11881910264492035
[HUGECTR][INFO] iter: 5900; loss: 0.1231309324502945
[HUGECTR][INFO] iter: 6000; loss: 0.11981914937496185
[HUGECTR][INFO] iter: 6000, [('AUC', 0.7478535175323486)]
[HUGECTR][INFO] iter: 6100; loss: 0.12740889191627502
[HUGECTR][INFO] iter: 6200; loss: 0.1184406653046608
[HUGECTR][INFO] iter: 6300; loss: 0.1215326264500618
[HUGECTR][INFO] iter: 6400; loss: 0.12018976360559464
[HUGECTR][INFO] iter: 6500; loss: 0.12207344174385071
[HUGECTR][INFO] iter: 6600; loss: 0.11936748027801514
[HUGECTR][INFO] iter: 6700; loss: 0.1344636082649231
[HUGECTR][INFO] iter: 6800; loss: 0.1235312819480896
[HUGECTR][INFO] iter: 6900; loss: 0.11865617334842682
[HUGECTR][INFO] iter: 7000; loss: 0.12486278265714645
[HUGECTR][INFO] iter: 7100; loss: 0.13070285320281982
[HUGECTR][INFO] iter: 7200; loss: 0.12883560359477997
[HUGECTR][INFO] iter: 7300; loss: 0.12401801347732544
[HUGECTR][INFO] iter: 7400; loss: 0.12302699685096741
[HUGECTR][INFO] iter: 7500; loss: 0.13381913304328918
[HUGECTR][INFO] iter: 7600; loss: 0.12709784507751465
[HUGECTR][INFO] iter: 7700; loss: 0.12482510507106781
[HUGECTR][INFO] iter: 7800; loss: 0.12176606804132462
[HUGECTR][INFO] iter: 7900; loss: 0.12543131411075592
[HUGECTR][INFO] iter: 8000; loss: 0.11884783953428268
[HUGECTR][INFO] iter: 8100; loss: 0.1285083293914795
[HUGECTR][INFO] iter: 8200; loss: 0.12941600382328033
[HUGECTR][INFO] iter: 8300; loss: 0.1245264783501625
[HUGECTR][INFO] iter: 8400; loss: 0.1230475902557373
[HUGECTR][INFO] iter: 8500; loss: 0.1257411241531372
[HUGECTR][INFO] iter: 8600; loss: 0.12116973102092743
[HUGECTR][INFO] iter: 8700; loss: 0.12535282969474792
[HUGECTR][INFO] iter: 8800; loss: 0.12397449463605881
[HUGECTR][INFO] iter: 8900; loss: 0.12262888252735138
[HUGECTR][INFO] iter: 9000; loss: 0.11161893606185913
[HUGECTR][INFO] iter: 9000, [('AUC', 0.7808812856674194)]
[HUGECTR][INFO] iter: 9100; loss: 0.13083060085773468
[HUGECTR][INFO] iter: 9200; loss: 0.120280422270298
[HUGECTR][INFO] iter: 9300; loss: 0.12189973145723343
[HUGECTR][INFO] iter: 9400; loss: 0.11685363948345184
[HUGECTR][INFO] iter: 9500; loss: 0.12826286256313324
[HUGECTR][INFO] iter: 9600; loss: 0.11898329108953476
[HUGECTR][INFO] iter: 9700; loss: 0.12399856001138687
[HUGECTR][INFO] iter: 9800; loss: 0.11943891644477844
[HUGECTR][INFO] iter: 9900; loss: 0.12630650401115417
CPU times: user 8min 55s, sys: 45.3 s, total: 9min 40s
Wall time: 1min 39s
[17]: