[1]:
# Copyright 2020 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

NVTabular / HugeCTR Criteo Example

Here we’ll show how to use NVTabular first as a preprocessing library to prepare the Criteo Display Advertising Challenge dataset, and then train a model using HugeCTR.

Data Prep

Before we get started, make sure you’ve run the optimize_criteo notebook, which converts the tsv data published by Criteo into the parquet format that our accelerated readers prefer. Fair warning: that notebook takes roughly 30 minutes to run. While we’re hoping to release accelerated csv readers in the near future, we also believe that inefficiencies in existing data representations like csv are in no small part a consequence of inefficiencies in the existing hardware/software stack. Accelerating these pipelines on new hardware like GPUs may require us to make new choices about how we represent stored data, and parquet is a strong alternative.
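For reference, below is a rough sketch of the kind of conversion that notebook performs, using cudf. The raw-file location here is a made-up placeholder, the column names follow the schema we define later in this notebook, and the real notebook processes each day in chunks since a single day of raw data is far larger than GPU memory.

# Illustrative sketch only: paths are hypothetical and each raw day file is
# tens of GB, so a one-shot read_csv like this would need chunking in practice.
import os
import cudf

RAW_DATA_DIR = '/raid/criteo/tests/raw'          # hypothetical location of the raw day_* files
PARQUET_DIR = '/raid/criteo/tests/crit_int_pq'   # destination for the parquet output

# The raw day_* files are headerless, tab-separated:
# 1 label, 13 integer features, 26 hashed categorical features.
cols = ['label'] + ['I' + str(x) for x in range(1, 14)] + ['C' + str(x) for x in range(1, 27)]

for day in range(24):
    gdf = cudf.read_csv(os.path.join(RAW_DATA_DIR, f'day_{day}'), sep='\t', names=cols)
    gdf.to_parquet(os.path.join(PARQUET_DIR, f'day_{day}.parquet'))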

[2]:
import os
from time import time
import re
import glob
import warnings


# tools for data preproc/loading
import torch
import rmm
import nvtabular as nvt
from nvtabular.ops import Normalize, Categorify, LogOp, FillMissing, Clip, get_embedding_sizes
from nvtabular.utils import device_mem_size

/opt/conda/envs/rapids/lib/python3.7/site-packages/numba/cuda/envvars.py:17: NumbaWarning: 
Environment variables with the 'NUMBAPRO' prefix are deprecated and consequently ignored, found use of NUMBAPRO_NVVM=/usr/local/cuda/nvvm/lib64/libnvvm.so.

For more information about alternatives visit: ('http://numba.pydata.org/numba-doc/latest/cuda/overview.html', '#cudatoolkit-lookup')
  warnings.warn(errors.NumbaWarning(msg))
/opt/conda/envs/rapids/lib/python3.7/site-packages/numba/cuda/envvars.py:17: NumbaWarning: 
Environment variables with the 'NUMBAPRO' prefix are deprecated and consequently ignored, found use of NUMBAPRO_LIBDEVICE=/usr/local/cuda/nvvm/libdevice/.

For more information about alternatives visit: ('http://numba.pydata.org/numba-doc/latest/cuda/overview.html', '#cudatoolkit-lookup')
  warnings.warn(errors.NumbaWarning(msg))

Initializing the Memory Pool

For applications like the one that follows, where RAPIDS will be the primary consumer of GPU memory and compute, a best practice is to use the RAPIDS Memory Manager library rmm to allocate a dedicated pool of GPU memory that allows for fast, asynchronous memory management. Here, we’ll dedicate 80% of free GPU memory to this pool to make sure we get the most utilization possible.

[3]:
rmm.reinitialize(pool_allocator=True, initial_pool_size=0.8 * device_mem_size(kind='free'))
/nvtabular/nvtabular/io.py:113: UserWarning: get_memory_info is not supported. Using total device memory from NVML.
  warnings.warn("get_memory_info is not supported. Using total device memory from NVML.")

Dataset and Dataset Schema

Once our data is ready, we’ll define some high-level parameters describing where the data lives and what it looks like.

[4]:
# define some information about where to get our data
INPUT_DATA_DIR = os.environ.get('INPUT_DATA_DIR', '/raid/criteo/tests/crit_int_pq')
OUTPUT_DATA_DIR = os.environ.get('OUTPUT_DATA_DIR', '/raid/criteo/tests/test_dask') # where we'll save our processed data

BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 800000))
NUM_PARTS = int(os.environ.get('NUM_PARTS', 2))
NUM_TRAIN_DAYS = 23 # number of days worth of data to use for training, the rest will be used for validation

# define our dataset schema
CONTINUOUS_COLUMNS = ['I' + str(x) for x in range(1,14)]
CATEGORICAL_COLUMNS =  ['C' + str(x) for x in range(1,27)]
LABEL_COLUMNS = ['label']
COLUMNS = CONTINUOUS_COLUMNS + CATEGORICAL_COLUMNS + LABEL_COLUMNS
[5]:
! ls $INPUT_DATA_DIR
_metadata       day_12.parquet  day_17.parquet  day_21.parquet  day_5.parquet
day_0.parquet   day_13.parquet  day_18.parquet  day_22.parquet  day_6.parquet
day_1.parquet   day_14.parquet  day_19.parquet  day_23.parquet  day_7.parquet
day_10.parquet  day_15.parquet  day_2.parquet   day_3.parquet   day_8.parquet
day_11.parquet  day_16.parquet  day_20.parquet  day_4.parquet   day_9.parquet
[6]:
fname = 'day_{}.parquet'
num_days = len([i for i in os.listdir(INPUT_DATA_DIR) if re.match(fname.format('[0-9]{1,2}'), i) is not None])
train_paths = [os.path.join(INPUT_DATA_DIR, fname.format(day)) for day in range(NUM_TRAIN_DAYS)]
valid_paths = [os.path.join(INPUT_DATA_DIR, fname.format(day)) for day in range(NUM_TRAIN_DAYS, num_days)]

Preprocessing

At this point, our data still isn’t in a form that’s ideal for consumption by neural networks. The most pressing issues are missing values and the fact that our categorical variables are still represented by arbitrary, discrete identifiers, which need to be transformed into contiguous indices that can be leveraged by a learned embedding. Less pressing, but still important for learning dynamics, are the distributions of our continuous variables, which span multiple orders of magnitude and are uncentered (i.e. E[x] != 0).
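To make those transformations concrete, here’s a toy pandas/numpy illustration (not NVTabular code, and the example values are invented) of what we’re after; NVTabular’s ops apply the same ideas at scale on the GPU.

# Toy illustration of the two fixes described above:
# categorical ids -> contiguous integer indices, and
# continuous values -> zero-filled, log-transformed, standardized features.
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'C1': ['a9f3', 'b771', 'a9f3', None],   # hashed categorical ids
    'I1': [3.0, np.nan, 12000.0, 7.0],      # skewed continuous feature
})

# categorical: map each distinct id to a small integer an embedding can index;
# missing values (code -1) end up at index 0
df['C1'] = df['C1'].astype('category').cat.codes + 1

# continuous: fill missing with 0, clip negatives, log(1+x), then standardize
x = np.log1p(df['I1'].fillna(0).clip(lower=0))
df['I1'] = (x - x.mean()) / x.std()

print(df)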

We can fix these issues in a concise and GPU-accelerated manner with an NVTabular Workflow. We’ll instantiate one with our current dataset schema, then symbolically add operations on that schema. By setting all these Ops to use replace=True, the schema itself remains unmodified, while the variables represented by each field in the schema are transformed.

Frequency Thresholding

One interesting thing worth pointing out is that we’re using frequency thresholding in our Categorify op. This handy functionality maps every category that occurs fewer than a threshold number of times in the dataset (here, 15 occurrences) to the same index, keeping the model from overfitting to rare, sparse signals.

[7]:
proc = nvt.Workflow(
    cat_names=CATEGORICAL_COLUMNS,
    cont_names=CONTINUOUS_COLUMNS,
    label_name=LABEL_COLUMNS)

# log -> normalize continuous features. Note that doing this in the opposite
# order wouldn't make sense! Note also that we're zero filling continuous
# values before the log: this is a good time to remember that LogOp
# performs log(1+x), not log(x)
proc.add_cont_feature([FillMissing(), Clip(min_value=0), LogOp()])
proc.add_cont_preprocess(Normalize())

# categorification with frequency thresholding
proc.add_cat_preprocess(Categorify(freq_threshold=15, out_path=OUTPUT_DATA_DIR))

Now we instantiate dataset iterators to loop through our dataset (which we couldn’t fit into GPU memory). HugeCTR requires specific data types, so we define them in a dictionary and pass it as an argument when creating our datasets.

[8]:
import numpy as np

dict_dtypes={}

for col in CONTINUOUS_COLUMNS:
    dict_dtypes[col] = np.float32

for col in CATEGORICAL_COLUMNS:
    dict_dtypes[col] = np.int64

for col in LABEL_COLUMNS:
    dict_dtypes[col] = np.float32
[9]:
train_dataset = nvt.Dataset(train_paths, engine='parquet', part_mem_fraction=0.15, dtypes=dict_dtypes)
valid_dataset = nvt.Dataset(valid_paths, engine='parquet', part_mem_fraction=0.15, dtypes=dict_dtypes)

Now we run the datasets through our workflow: collect statistics on the train set, then transform both datasets and save the results to parquet files.

[10]:
output_train_dir = os.path.join(OUTPUT_DATA_DIR, 'train/')
output_valid_dir = os.path.join(OUTPUT_DATA_DIR, 'valid/')
! mkdir -p $output_train_dir
! mkdir -p $output_valid_dir

For reference, let’s time it to see how long it takes…

[11]:
%%time
proc.apply(train_dataset, shuffle=nvt.io.Shuffle.PER_PARTITION, output_format="parquet", output_path=output_train_dir, out_files_per_proc=15)
CPU times: user 20min 37s, sys: 13min 14s, total: 33min 52s
Wall time: 1h 14min 31s
[12]:
%%time
proc.apply(valid_dataset ,record_stats=False, shuffle=nvt.io.Shuffle.PER_PARTITION, output_format="parquet", output_path=output_valid_dir, out_files_per_proc=15)
CPU times: user 29 s, sys: 26.1 s, total: 55.1 s
Wall time: 2min 30s
[13]:
embeddings = get_embedding_sizes(proc)
print(embeddings.values())
dict_values([(7599500, 16), (5345303, 16), (561810, 16), (242827, 16), (11, 6), (2209, 16), (10616, 16), (100, 16), (4, 3), (968, 16), (15, 7), (33521, 16), (7838519, 16), (2580502, 16), (6878028, 16), (298771, 16), (11951, 16), (97, 16), (35, 12), (17022, 16), (7339, 16), (20046, 16), (4, 3), (7068, 16), (1377, 16), (63, 16)])
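
Each pair above is (cardinality after thresholding, suggested embedding width) for one categorical column. When wiring up a HugeCTR configuration such as dlrm_fp32_16k.json, it’s the cardinalities that determine the embedding vocabulary, so a small sketch for pulling them out in column order (the C1..C26 ordering here is our assumption) looks like:

# embeddings maps each categorical column to (cardinality, embedding width);
# pull out the per-column cardinalities for use in the HugeCTR config.
cardinalities = [embeddings[col][0] for col in CATEGORICAL_COLUMNS]
print(cardinalities)
print('total vocabulary size:', sum(cardinalities))

The sum of these cardinalities, 31,457,706, matches the vocabulary size HugeCTR reports during initialization below.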

And just like that, we have training and validation sets ready to feed to a model!

HugeCTR

Training

We’ll run the huge_ctr binary with the DLRM configuration file dlrm_fp32_16k.json.

First, we’ll reinitialize rmm without the memory pool from earlier, freeing up GPU memory so that we can share it with HugeCTR.

[14]:
rmm.reinitialize(pool_allocator=False)

Finally, we run HugeCTR. For reference, let’s time it to see how long it takes…

[27]:
%%time
! /usr/local/hugectr/bin/huge_ctr --train dlrm_fp32_16k.json
[0, init_start, ]
HugeCTR Version: 2.2.1
Config file: dlrm_fp32_16k.json
[28d23h07m10s][HUGECTR][INFO]: algorithm_search is not specified using default: 1
[28d23h07m10s][HUGECTR][INFO]: Algorithm search: ON
Device 0: Tesla P100-DGXS-16GB
Device 1: Tesla P100-DGXS-16GB
[28d23h07m15s][HUGECTR][INFO]: Initial seed is 104555387
[28d23h07m15s][HUGECTR][INFO]: cache_eval_data is not specified using default: 0
[28d23h07m15s][HUGECTR][INFO]: num_internal_buffers 1
[28d23h07m15s][HUGECTR][INFO]: num_internal_buffers 1
[28d23h07m15s][HUGECTR][INFO]: Vocabulary size: 31457706
[28d23h07m15s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=15500000
[28d23h07m15s][HUGECTR][INFO]: All2All Warmup Start
[28d23h07m15s][HUGECTR][INFO]: All2All Warmup End
[28d23h07m36s][HUGECTR][INFO]: gpu0 start to init embedding
[28d23h07m36s][HUGECTR][INFO]: gpu1 start to init embedding
[28d23h07m37s][HUGECTR][INFO]: gpu0 init embedding done
[28d23h07m37s][HUGECTR][INFO]: gpu1 init embedding done
[28418.6, init_end, ]
[28418.7, run_start, ]
HugeCTR training start:
[28418.7, train_epoch_start, 0, ]
[28d23h08m13s][HUGECTR][INFO]: Iter: 1000 Time(1000 iters): 35.098988s Loss: 0.130554 lr:1.001000
[28d23h08m48s][HUGECTR][INFO]: Iter: 2000 Time(1000 iters): 34.997547s Loss: 0.133759 lr:2.001000
[28d23h09m23s][HUGECTR][INFO]: Iter: 3000 Time(1000 iters): 34.888029s Loss: 0.136231 lr:3.001000
[140359, eval_start, 0.032, ]
[28d23h10m53s][HUGECTR][INFO]: Evaluation, AUC: 0.729536
[222983, eval_accuracy, 0.729536, 0.032, 3200, ]
[28d23h10m53s][HUGECTR][INFO]: Eval Time for 5440 iters: 82.624101s
[222983, eval_stop, 0.032, ]
[28d23h11m24s][HUGECTR][INFO]: Iter: 4000 Time(1000 iters): 120.358955s Loss: 0.132684 lr:4.001000
[28d23h12m02s][HUGECTR][INFO]: Iter: 5000 Time(1000 iters): 38.544305s Loss: 0.119664 lr:5.001000
[28d23h12m41s][HUGECTR][INFO]: Iter: 6000 Time(1000 iters): 38.445513s Loss: 0.121679 lr:6.001000
[346374, eval_start, 0.064, ]
[28d23h14m18s][HUGECTR][INFO]: Evaluation, AUC: 0.754327
[427758, eval_accuracy, 0.754327, 0.064, 6400, ]
[28d23h14m18s][HUGECTR][INFO]: Eval Time for 5440 iters: 81.383839s
[427758, eval_stop, 0.064, ]
[28d23h14m41s][HUGECTR][INFO]: Iter: 7000 Time(1000 iters): 119.988648s Loss: 0.128801 lr:7.001000
[28d23h15m19s][HUGECTR][INFO]: Iter: 8000 Time(1000 iters): 38.559991s Loss: 0.125091 lr:8.000000
[28d23h15m58s][HUGECTR][INFO]: Iter: 9000 Time(1000 iters): 38.992278s Loss: 0.139365 lr:8.000000
[551288, eval_start, 0.096, ]
[28d23h17m42s][HUGECTR][INFO]: Evaluation, AUC: 673643.375000
[632540, eval_accuracy, 673643, 0.096, 9600, ]
[632540, train_samples, 157302784, ]
Hit target accuracy AUC 0.772500 at epoch 0.096000 with batchsize: 16384 in 604.12 s. Average speed 260355.76 records/s.
[632539.75, eval_stop, 0.096000, ]
[632539.75, train_epoch_end, 1, ]
[632539.76, run_stop, ]
CPU times: user 11.7 s, sys: 3.62 s, total: 15.3 s
Wall time: 10min 35s
[ ]: