
HugeCTR End-to-End Example with NVTabular

Overview

In this sample notebook, we are going to:

  1. Preprocess data using NVTabular

  2. Train a model with HugeCTR

  3. Run offline inference with the HugeCTR Hierarchical Parameter Server (HPS)

Setup

To set up the environment, follow the instructions in the HugeCTR Example Notebooks before running the cells below.

Data Preparation

import os
import shutil

# Directory layout; each path can be overridden through an environment variable
BASE_DIR = os.environ.get("BASE_DIR", "/hugectr_e2e")
DATA_DIR = os.environ.get("DATA_DIR", BASE_DIR + "/criteo")
TRAIN_DIR = os.environ.get("TRAIN_DIR", DATA_DIR + "/train")
VAL_DIR = os.environ.get("VAL_DIR", DATA_DIR + "/val")
MODEL_DIR = os.environ.get("MODEL_DIR", BASE_DIR + "/model")

# Create the directories after resolving the paths so that overrides are honored
for d in (TRAIN_DIR, VAL_DIR, MODEL_DIR):
    os.makedirs(d, exist_ok=True)

Download one day (day 0) of the Criteo data:

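# Uncomment the next two lines to download and decompress day 0; otherwise, symlink an existing copy through INPUT_DATA.
#!wget -P $DATA_DIR https://storage.googleapis.com/criteo-cail-datasets/day_0.gz
#!gzip -d -c $DATA_DIR/day_0.gz > $DATA_DIR/day_0
INPUT_DATA = os.environ.get("INPUT_DATA", DATA_DIR + "/day_0")
# Only create the symlink if day_0 is not already present (for example, after decompressing in place)
!test -e $DATA_DIR/day_0 || ln -s $INPUT_DATA $DATA_DIR/day_0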
--2022-11-11 09:06:02--  https://storage.googleapis.com/criteo-cail-datasets/day_0.gz
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.191.48, 142.251.46.208, 172.217.164.112, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.191.48|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 16309554343 (15G) [application/octet-stream]
Saving to: ‘/hugectr_e2e/criteo/day_0.gz’

day_0.gz            100%[===================>]  15.19G  94.4MB/s    in 4m 58s  

2022-11-11 09:11:00 (52.2 MB/s) - ‘/hugectr_e2e/criteo/day_0.gz’ saved [16309554343/16309554343]
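If the `gzip` command-line tool isn't available, the decompression step can also be done from Python. The following is a minimal sketch that uses only the standard library; `gunzip_stream` is a hypothetical helper, not part of HugeCTR or NVTabular. It streams the data in fixed-size chunks, so the uncompressed file never has to fit in memory.

```python
import gzip
import shutil


def gunzip_stream(src_path: str, dst_path: str, chunk_size: int = 16 * 1024 * 1024) -> None:
    """Decompress src_path into dst_path chunk by chunk (similar to `gzip -d -c src > dst`)."""
    with gzip.open(src_path, "rb") as src, open(dst_path, "wb") as dst:
        # copyfileobj reads and writes at most chunk_size bytes at a time
        shutil.copyfileobj(src, dst, chunk_size)
```

For example, `gunzip_stream(DATA_DIR + "/day_0.gz", DATA_DIR + "/day_0")` does roughly what the commented-out `gzip -d -c` line does.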

Split the data: the first 10,000,000 lines for training and the last 2,000,000 lines for validation

!head -n 10000000 $DATA_DIR/day_0 > $DATA_DIR/train/train.txt
!tail -n 2000000 $DATA_DIR/day_0 > $DATA_DIR/val/test.txt 
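The same `head`/`tail` split can be sketched in plain Python, assuming a local text file; `split_head_tail` is a hypothetical helper. The two outputs are disjoint only when the source has at least `n_train + n_val` lines, which holds for a full Criteo day. Unlike GNU `tail`, which can seek backwards from the end of a regular file, this version reads the whole file a second time, so it is slower on large inputs.

```python
from collections import deque
from itertools import islice


def split_head_tail(src_path: str, train_path: str, val_path: str,
                    n_train: int, n_val: int) -> None:
    """Write the first n_train lines to train_path and the last n_val lines to val_path."""
    # First pass: stream the leading lines, like `head -n n_train`
    with open(src_path, "r") as src, open(train_path, "w") as train:
        train.writelines(islice(src, n_train))
    # Second pass: a bounded deque keeps only the last n_val lines in memory, like `tail -n n_val`
    with open(src_path, "r") as src:
        tail = deque(src, maxlen=n_val)
    with open(val_path, "w") as val:
        val.writelines(tail)
```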

Data Preprocessing using NVTabular

import sys
import argparse
import glob
import time
import numpy as np
import numba

import dask_cudf
import cudf
import nvtabular as nvt
from nvtabular.io import Shuffle
from nvtabular.ops import Categorify, Clip, FillMissing, Normalize, get_embedding_sizes
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from nvtabular.utils import pynvml_mem_size, device_mem_size
import warnings

import logging
logging.basicConfig(format='%(asctime)s %(message)s')
logging.root.setLevel(logging.NOTSET)

# define dataset schema
CATEGORICAL_COLUMNS=["C" + str(x) for x in range(1, 27)]
CONTINUOUS_COLUMNS=["I" + str(x) for x in range(1, 14)]
LABEL_COLUMNS = ['label']
COLUMNS =  LABEL_COLUMNS + CONTINUOUS_COLUMNS +  CATEGORICAL_COLUMNS
# The /samples/criteo model doesn't use dense features
criteo_COLUMN = LABEL_COLUMNS + CATEGORICAL_COLUMNS
# Feature-cross columns, each built from a pair of categorical columns
CROSS_COLUMNS = ["C1_C2", "C3_C4"]

NUM_INTEGER_COLUMNS = 13
NUM_CATEGORICAL_COLUMNS = 26
NUM_TOTAL_COLUMNS = 1 + NUM_INTEGER_COLUMNS + NUM_CATEGORICAL_COLUMNS
# Dask dashboard
dashboard_port = "8787"

# Deploy a Single-Machine Multi-GPU Cluster
protocol = "tcp"  # "tcp" or "ucx"
if numba.cuda.is_available():
    NUM_GPUS = list(range(len(numba.cuda.gpus)))
else:
    NUM_GPUS = []
visible_devices = ",".join([str(n) for n in NUM_GPUS])  # Select devices to place workers
device_limit_frac = 0.7  # Spill GPU-Worker memory to host at this limit.
device_pool_frac = 0.8
part_mem_frac = 0.15

# Derive device_limit, device_pool_size, and part_size from the total device memory
device_size = device_mem_size(kind="total")
device_limit = int(device_limit_frac * device_size)
device_pool_size = int(device_pool_frac * device_size)
part_size = int(part_mem_frac * device_size)

# Check if any device memory is already occupied
for dev in visible_devices.split(","):
    fmem = pynvml_mem_size(kind="free", index=int(dev))
    used = (device_size - fmem) / 1e9
    if used > 1.0:
        warnings.warn(f"BEWARE - {used} GB is already occupied on device {int(dev)}!")

cluster = None  # (Optional) Specify existing scheduler port
if cluster is None:
    cluster = LocalCUDACluster(
        protocol=protocol,
        n_workers=len(visible_devices.split(",")),
        CUDA_VISIBLE_DEVICES=visible_devices,
        device_memory_limit=device_limit,
        dashboard_address=":" + dashboard_port,
        rmm_pool_size=(device_pool_size // 256) * 256
    )

# Create the distributed client
client = Client(cluster)
client
2022-11-14 06:59:41,436 Using selector: EpollSelector
2022-11-14 06:59:43,469 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,494 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,499 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,502 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,515 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,525 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,575 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,576 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize

Client

Client-ebccc26f-63e9-11ed-957d-54ab3adac0a5

Connection method: Cluster object Cluster type: dask_cuda.LocalCUDACluster
Dashboard: http://127.0.0.1:8787/status
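The worker memory settings used above are plain fractions of the total device memory. The sketch below (`dask_cuda_memory_budget` is a hypothetical helper) reproduces that arithmetic so the values can be checked before starting the cluster: the spill limit is 70% of the device, the RMM pool is 80% rounded down to a multiple of 256 bytes, and `part_size` is 15%.

```python
def dask_cuda_memory_budget(device_size: int,
                            limit_frac: float = 0.7,
                            pool_frac: float = 0.8,
                            part_frac: float = 0.15) -> dict:
    """Recompute the per-GPU memory settings used to build the LocalCUDACluster."""
    device_limit = int(limit_frac * device_size)      # spill worker memory to host above this
    device_pool_size = int(pool_frac * device_size)   # requested RMM pool size before rounding
    part_size = int(part_frac * device_size)          # computed in the notebook, not passed to the cluster
    return {
        "device_limit": device_limit,
        "device_pool_size": device_pool_size,
        # Same rounding as rmm_pool_size=(device_pool_size // 256) * 256
        "rmm_pool_size": (device_pool_size // 256) * 256,
        "part_size": part_size,
    }
```

In the notebook, the input would be `device_mem_size(kind="total")`, which is already imported from `nvtabular.utils`.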


train_output = os.path.join(DATA_DIR, "train")
print("Training output data: "+train_output)
val_output = os.path.join(DATA_DIR, "val")
print("Validation output data: "+val_output)
train_input = os.path.join(DATA_DIR, "train/train.txt")
print("Training dataset: "+train_input)
val_input = os.path.join(DATA_DIR, "val/test.txt")
PREPROCESS_DIR_temp_train = os.path.join(DATA_DIR, 'train/temp-parquet-after-conversion')  
PREPROCESS_DIR_temp_val = os.path.join(DATA_DIR, "val/temp-parquet-after-conversion")
PREPROCESS_DIR_temp = [PREPROCESS_DIR_temp_train, PREPROCESS_DIR_temp_val]

# Make sure we have a clean parquet space for cudf conversion
for one_path in PREPROCESS_DIR_temp:
    if os.path.exists(one_path):
        shutil.rmtree(one_path)
    os.makedirs(one_path)

#calculate the total processing time
runtime = time.time()

## train/valid txt to parquet
train_valid_paths = [(train_input,PREPROCESS_DIR_temp_train),(val_input,PREPROCESS_DIR_temp_val)]

for input_path, temp_output in train_valid_paths:

    ddf = dask_cudf.read_csv(input_path, sep='\t', names=LABEL_COLUMNS + CONTINUOUS_COLUMNS + CATEGORICAL_COLUMNS)

    ddf["label"] = ddf['label'].astype('float32')
    ddf[CONTINUOUS_COLUMNS] = ddf[CONTINUOUS_COLUMNS].astype('float32')

    # Save it as parquet format for better memory usage
    ddf.to_parquet(temp_output,header=True)
    ##-----------------------------------##

COLUMNS =  LABEL_COLUMNS + CONTINUOUS_COLUMNS + CROSS_COLUMNS + CATEGORICAL_COLUMNS
train_paths = glob.glob(os.path.join(PREPROCESS_DIR_temp_train, "*.parquet"))
valid_paths = glob.glob(os.path.join(PREPROCESS_DIR_temp_val, "*.parquet"))

categorify_op = Categorify()
cat_features = CATEGORICAL_COLUMNS >> categorify_op
cont_features = CONTINUOUS_COLUMNS >> FillMissing() >> Clip(min_value=0) >> Normalize()
cross_cat_op = Categorify(encode_type="combo")

features = LABEL_COLUMNS

features += cont_features
if CROSS_COLUMNS:
    feature_pairs = [pair.split("_") for pair in CROSS_COLUMNS]
    for pair in feature_pairs:
        features += [pair] >> cross_cat_op

features += cat_features

workflow = nvt.Workflow(features)

logging.info("Preprocessing")

output_format = 'parquet'

# just for /samples/criteo model
train_ds_iterator = nvt.Dataset(train_paths, engine='parquet')
valid_ds_iterator = nvt.Dataset(valid_paths, engine='parquet')

shuffle = nvt.io.Shuffle.PER_PARTITION

logging.info('Train Datasets Preprocessing.....')

dict_dtypes = {}
for col in CATEGORICAL_COLUMNS:
    dict_dtypes[col] = np.int64
for col in CONTINUOUS_COLUMNS:
    dict_dtypes[col] = np.float32
for col in CROSS_COLUMNS:
    dict_dtypes[col] = np.int64
for col in LABEL_COLUMNS:
    dict_dtypes[col] = np.float32

conts = CONTINUOUS_COLUMNS

workflow.fit(train_ds_iterator)

if output_format == 'hugectr':
    workflow.transform(train_ds_iterator).to_hugectr(
            cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
            conts=conts,
            labels=LABEL_COLUMNS,
            output_path=train_output,
            shuffle=shuffle)
else:
    workflow.transform(train_ds_iterator).to_parquet(
            output_path=train_output,
            dtypes=dict_dtypes,
            cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
            conts=conts,
            labels=LABEL_COLUMNS,
            shuffle=shuffle)

###Getting slot size###    
#--------------------##
embeddings_dict_cat = categorify_op.get_embedding_sizes(CATEGORICAL_COLUMNS)
embeddings_dict_cross = cross_cat_op.get_embedding_sizes(CROSS_COLUMNS)
embeddings = [embeddings_dict_cat[c][0] for c in CATEGORICAL_COLUMNS] + [embeddings_dict_cross[c][0] for c in CROSS_COLUMNS]

print(embeddings)
##--------------------##

logging.info('Valid Datasets Preprocessing.....')

if output_format == 'hugectr':
    workflow.transform(valid_ds_iterator).to_hugectr(
            cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
            conts=conts,
            labels=LABEL_COLUMNS,
            output_path=val_output,
            shuffle=shuffle)
else:
    workflow.transform(valid_ds_iterator).to_parquet(
            output_path=val_output,
            dtypes=dict_dtypes,
            cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
            conts=conts,
            labels=LABEL_COLUMNS,
            shuffle=shuffle)

embeddings_dict_cat = categorify_op.get_embedding_sizes(CATEGORICAL_COLUMNS)
embeddings_dict_cross = cross_cat_op.get_embedding_sizes(CROSS_COLUMNS)
embeddings = [embeddings_dict_cat[c][0] for c in CATEGORICAL_COLUMNS] + [embeddings_dict_cross[c][0] for c in CROSS_COLUMNS]

print(embeddings)
##--------------------##

## Shut down the Dask client and the local CUDA cluster
client.close()
cluster.close()

runtime = time.time() - runtime

print("\nDask-NVTabular Criteo Preprocessing Done!")
print(f"Runtime[s]         | {runtime}")
print("======================================\n")
Training output data: /hugectr_e2e/criteo/train
Validation output data: /hugectr_e2e/criteo/val
Training dataset: /hugectr_e2e/criteo/train/train.txt
2022-11-14 07:03:35,978 Preprocessing
2022-11-14 07:03:36,175 Train Datasets Preprocessing.....
2022-11-14 07:03:41,211 Valid Datasets Preprocessing.....
[1234907, 19683, 13780, 6867, 18490, 4, 6264, 1235, 50, 854680, 114026, 75736, 11, 2159, 7533, 61, 4, 919, 15, 1307783, 404742, 1105613, 87714, 9032, 77, 34, 1581605, 1093030]
[1234907, 19683, 13780, 6867, 18490, 4, 6264, 1235, 50, 854680, 114026, 75736, 11, 2159, 7533, 61, 4, 919, 15, 1307783, 404742, 1105613, 87714, 9032, 77, 34, 1581605, 1093030]

Dask-NVTabular Criteo Preprocessing Done!
Runtime[s]         | 8.17225694656372
======================================
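The continuous-feature chain `FillMissing() >> Clip(min_value=0) >> Normalize()` can be approximated per column as below. This is a simplified, single-machine sketch with hypothetical helper names: it assumes FillMissing's default fill value of 0 and standardizes with the population standard deviation, whereas NVTabular computes its statistics with Dask on the GPU and its exact estimator may differ. As in the notebook, the statistics are learned from the training split only (`workflow.fit(train_ds_iterator)`) and then reused when transforming the validation split.

```python
import math
from typing import List, Optional, Sequence, Tuple


def _fill_and_clip(values: Sequence[Optional[float]], fill_val: float = 0.0) -> List[float]:
    # FillMissing replaces missing values; Clip(min_value=0) floors the result at zero
    return [max(fill_val if v is None else v, 0.0) for v in values]


def fit_normalize(train_values: Sequence[Optional[float]]) -> Tuple[float, float]:
    """Mean and population standard deviation of the filled and clipped training column."""
    cleaned = _fill_and_clip(train_values)
    mean = sum(cleaned) / len(cleaned)
    std = math.sqrt(sum((x - mean) ** 2 for x in cleaned) / len(cleaned))
    return mean, std


def transform_normalize(values: Sequence[Optional[float]], mean: float, std: float) -> List[float]:
    """Apply training statistics to any split (train or validation)."""
    cleaned = _fill_and_clip(values)
    return [(x - mean) / std if std > 0 else 0.0 for x in cleaned]
```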
### Record the slot size array
SLOT_SIZE_ARRAY = embeddings
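`SLOT_SIZE_ARRAY` holds one cardinality per categorical slot: the 26 `Categorify` columns followed by the two crossed columns. The rough sketch below shows the idea behind those numbers with hypothetical helpers: each distinct training value gets a contiguous integer id, and the slot size is the number of ids the embedding table must hold. It assumes id 0 is reserved for missing or unseen values; the reserved indices in NVTabular vary by version, so this is not expected to reproduce the exact numbers above. The crossed columns (`encode_type="combo"`) are modeled here by encoding each `(C1, C2)` pair as a single value.

```python
from typing import Dict, Hashable, List, Optional, Sequence


def fit_categorify(train_values: Sequence[Optional[Hashable]]) -> Dict[Hashable, int]:
    """Assign contiguous ids, starting at 1, to distinct non-missing training values."""
    vocab: Dict[Hashable, int] = {}
    for v in train_values:
        if v is not None and v not in vocab:
            vocab[v] = len(vocab) + 1  # id 0 is reserved for missing or unseen values
    return vocab


def transform_categorify(values: Sequence[Optional[Hashable]], vocab: Dict[Hashable, int]) -> List[int]:
    # Values not seen during fit (including None) fall back to the reserved id 0
    return [vocab.get(v, 0) for v in values]


def slot_size(vocab: Dict[Hashable, int]) -> int:
    # Number of rows the embedding table needs for this slot, including the reserved id
    return len(vocab) + 1
```

For a crossed column, `fit_categorify(list(zip(c1_values, c2_values)))` treats each pair tuple as one category.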

Training a WDL (Wide & Deep) model with HugeCTR

%%writefile './train.py'
import hugectr
import os
import argparse
from mpi4py import MPI
parser = argparse.ArgumentParser(description=("HugeCTR Training"))
parser.add_argument("--data_path", type=str, required=True, help="Input dataset path")
parser.add_argument("--model_path", type=str, required=True, help="Directory path to write output")
args = parser.parse_args()
SLOT_SIZE_ARRAY = [1234907, 19683, 13780, 6867, 18490, 4, 6264, 1235, 50, 854680, 114026, 75736, 11, 2159, 7533, 61, 4, 919, 15, 1307783, 404742, 1105613, 87714, 9032, 77, 34, 1581605, 1093030]

solver = hugectr.CreateSolver(max_eval_batches = 4000,
                              batchsize_eval = 2720,
                              batchsize = 2720,
                              lr = 0.001,
                              vvgpu = [[0]],
                              repeat_dataset = True,
                              i64_input_key = True)

reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Parquet,
                                  source = [os.path.join(args.data_path, "train/_file_list.txt")],
                                  eval_source = os.path.join(args.data_path, "val/_file_list.txt"),
                                  check_type = hugectr.Check_t.Non,
                                  slot_size_array = SLOT_SIZE_ARRAY)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam,
                                    update_type = hugectr.Update_t.Global,
                                    beta1 = 0.9,
                                    beta2 = 0.999,
                                    epsilon = 0.0000001)
model = hugectr.Model(solver, reader, optimizer)

model.add(hugectr.Input(label_dim = 1, label_name = "label",
                        dense_dim = 13, dense_name = "dense",
                        data_reader_sparse_param_array = 
                        [hugectr.DataReaderSparseParam("wide_data", 1, True, 2),
                        hugectr.DataReaderSparseParam("deep_data", 2, False, 26)]))

model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 80,
                            embedding_vec_size = 1,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding2",
                            bottom_name = "wide_data",
                            optimizer = optimizer))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, 
                            workspace_size_per_gpu_in_mb = 1350,
                            embedding_vec_size = 16,
                            combiner = "sum",
                            sparse_embedding_name = "sparse_embedding1",
                            bottom_name = "deep_data",
                            optimizer = optimizer))

model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding1"],
                            top_names = ["reshape1"],
                            leading_dim=416))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding2"],
                            top_names = ["reshape2"],
                            leading_dim=2))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReduceSum,
                            bottom_names = ["reshape2"],
                            top_names = ["wide_redn"],
                            axis = 1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
                            bottom_names = ["reshape1", "dense"],
                            top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["concat1"],
                            top_names = ["fc1"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc1"],
                            top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu1"],
                            top_names = ["dropout1"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["dropout1"],
                            top_names = ["fc2"],
                            num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
                            bottom_names = ["fc2"],
                            top_names = ["relu2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
                            bottom_names = ["relu2"],
                            top_names = ["dropout2"],
                            dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
                            bottom_names = ["dropout2"],
                            top_names = ["fc3"],
                            num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Add,
                            bottom_names = ["fc3", "wide_redn"],
                            top_names = ["add1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
                            bottom_names = ["add1", "label"],
                            top_names = ["loss"]))
model.compile()
model.summary()
model.fit(max_iter = 21000, display = 1000, eval_interval = 4000, snapshot = 20000, snapshot_prefix = os.path.join(args.model_path, "wdl/"))
model.graph_to_json(graph_config_file = os.path.join(args.model_path, "wdl.json"))
Overwriting ./train.py
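The `leading_dim` values in `train.py` are not free parameters: each Reshape must flatten its embedding output to `slots * embedding_vec_size`, and the Concat width adds the 13 dense features. The hypothetical helper below recomputes the shapes that `model.summary()` prints, which is handy if you change `embedding_vec_size` or the number of slots and need to update `leading_dim` accordingly.

```python
from typing import Dict, Sequence, Tuple


def wdl_shapes(batch: int = 2720, dense_dim: int = 13,
               deep_slots: int = 26, deep_vec: int = 16,
               wide_slots: int = 2, wide_vec: int = 1,
               hidden: Sequence[int] = (1024, 1024)) -> Dict[str, Tuple[int, ...]]:
    """Output shapes of the WDL graph defined in train.py, keyed by top name."""
    shapes: Dict[str, Tuple[int, ...]] = {
        "sparse_embedding1": (batch, deep_slots, deep_vec),
        "sparse_embedding2": (batch, wide_slots, wide_vec),
        "reshape1": (batch, deep_slots * deep_vec),   # must match leading_dim=416
        "reshape2": (batch, wide_slots * wide_vec),   # must match leading_dim=2
        "wide_redn": (batch, 1),                      # ReduceSum over axis 1
        "concat1": (batch, deep_slots * deep_vec + dense_dim),
    }
    for i, units in enumerate(hidden, start=1):
        shapes[f"fc{i}"] = (batch, units)
    shapes[f"fc{len(hidden) + 1}"] = (batch, 1)       # deep logit
    shapes["add1"] = (batch, 1)                       # deep logit + wide logit
    return shapes
```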
!python train.py --data_path $DATA_DIR --model_path $MODEL_DIR
HugeCTR Version: 4.0
====================================================Model Init=====================================================
[HCTR][07:27:59.400][WARNING][RK0][main]: The model name is not specified when creating the solver.
[HCTR][07:27:59.400][INFO][RK0][main]: Global seed is 2258624009
[HCTR][07:27:59.403][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][07:28:01.312][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][07:28:01.312][INFO][RK0][main]: Start all2all warmup
[HCTR][07:28:01.312][INFO][RK0][main]: End all2all warmup
[HCTR][07:28:01.313][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][07:28:01.313][INFO][RK0][main]: Device 0: Tesla V100-SXM2-32GB
[HCTR][07:28:01.314][INFO][RK0][main]: num of DataReader workers for train: 1
[HCTR][07:28:01.314][INFO][RK0][main]: num of DataReader workers for eval: 1
[HCTR][07:28:01.318][INFO][RK0][main]: Vocabulary size: 7946054
[HCTR][07:28:01.318][INFO][RK0][main]: max_vocabulary_size_per_gpu_=6990506
[HCTR][07:28:01.334][INFO][RK0][main]: max_vocabulary_size_per_gpu_=7372800
[HCTR][07:28:01.341][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][07:28:09.781][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][07:28:09.782][INFO][RK0][main]: gpu0 init embedding done
[HCTR][07:28:09.782][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][07:28:09.785][INFO][RK0][main]: gpu0 init embedding done
[HCTR][07:28:09.787][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][07:28:09.792][INFO][RK0][main]: Warm-up done
===================================================Model Summary===================================================
[HCTR][07:28:09.792][INFO][RK0][main]: Model structure on each GPU
Label                                   Dense                         Sparse                        
label                                   dense                          wide_data,deep_data           
(2720,1)                                (2720,13)                               
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type                              Input Name                    Output Name                   Output Shape                  
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash      wide_data                     sparse_embedding2             (2720,2,1)                    
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash      deep_data                     sparse_embedding1             (2720,26,16)                  
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding1             reshape1                      (2720,416)                    
------------------------------------------------------------------------------------------------------------------
Reshape                                 sparse_embedding2             reshape2                      (2720,2)                      
------------------------------------------------------------------------------------------------------------------
ReduceSum                               reshape2                      wide_redn                     (2720,1)                      
------------------------------------------------------------------------------------------------------------------
Concat                                  reshape1                      concat1                       (2720,429)                    
                                        dense                                                                                     
------------------------------------------------------------------------------------------------------------------
InnerProduct                            concat1                       fc1                           (2720,1024)                   
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc1                           relu1                         (2720,1024)                   
------------------------------------------------------------------------------------------------------------------
Dropout                                 relu1                         dropout1                      (2720,1024)                   
------------------------------------------------------------------------------------------------------------------
InnerProduct                            dropout1                      fc2                           (2720,1024)                   
------------------------------------------------------------------------------------------------------------------
ReLU                                    fc2                           relu2                         (2720,1024)                   
------------------------------------------------------------------------------------------------------------------
Dropout                                 relu2                         dropout2                      (2720,1024)                   
------------------------------------------------------------------------------------------------------------------
InnerProduct                            dropout2                      fc3                           (2720,1)                      
------------------------------------------------------------------------------------------------------------------
Add                                     fc3                           add1                          (2720,1)                      
                                        wide_redn                                                                                 
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss                  add1                          loss                                                        
                                        label                                                                                     
------------------------------------------------------------------------------------------------------------------
=====================================================Model Fit=====================================================
[HCTR][07:28:09.792][INFO][RK0][main]: Use non-epoch mode with number of iterations: 21000
[HCTR][07:28:09.792][INFO][RK0][main]: Training batchsize: 2720, evaluation batchsize: 2720
[HCTR][07:28:09.792][INFO][RK0][main]: Evaluation interval: 4000, snapshot interval: 20000
[HCTR][07:28:09.792][INFO][RK0][main]: Dense network trainable: True
[HCTR][07:28:09.792][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][07:28:09.792][INFO][RK0][main]: Sparse embedding sparse_embedding2 trainable: True
[HCTR][07:28:09.792][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][07:28:09.792][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][07:28:09.792][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][07:28:09.792][INFO][RK0][main]: Training source file: /hugectr_e2e/criteo/train/_file_list.txt
[HCTR][07:28:09.792][INFO][RK0][main]: Evaluation source file: /hugectr_e2e/criteo/val/_file_list.txt
[HCTR][07:28:18.503][INFO][RK0][main]: Iter: 1000 Time(1000 iters): 8.70626s Loss: 0.131954 lr:0.001
[HCTR][07:28:27.274][INFO][RK0][main]: Iter: 2000 Time(1000 iters): 8.76682s Loss: 0.135973 lr:0.001
[HCTR][07:28:36.261][INFO][RK0][main]: Iter: 3000 Time(1000 iters): 8.98195s Loss: 0.116014 lr:0.001
[HCTR][07:28:45.385][INFO][RK0][main]: Iter: 4000 Time(1000 iters): 9.12015s Loss: 0.100682 lr:0.001
[HCTR][07:28:50.025][INFO][RK0][main]: Evaluation, AUC: 0.734929
[HCTR][07:28:50.025][INFO][RK0][main]: Eval Time for 4000 iters: 4.6372s
[HCTR][07:28:59.179][INFO][RK0][main]: Iter: 5000 Time(1000 iters): 13.7896s Loss: 0.111253 lr:0.001
[HCTR][07:29:08.355][INFO][RK0][main]: Iter: 6000 Time(1000 iters): 9.17119s Loss: 0.11407 lr:0.001
[HCTR][07:29:17.488][INFO][RK0][main]: Iter: 7000 Time(1000 iters): 9.12885s Loss: 0.102613 lr:0.001
[HCTR][07:29:26.636][INFO][RK0][main]: Iter: 8000 Time(1000 iters): 9.14357s Loss: 0.0954151 lr:0.001
[HCTR][07:29:31.243][INFO][RK0][main]: Evaluation, AUC: 0.709346
[HCTR][07:29:31.243][INFO][RK0][main]: Eval Time for 4000 iters: 4.60503s
[HCTR][07:29:40.356][INFO][RK0][main]: Iter: 9000 Time(1000 iters): 13.7146s Loss: 0.0999723 lr:0.001
[HCTR][07:29:49.485][INFO][RK0][main]: Iter: 10000 Time(1000 iters): 9.12492s Loss: 0.0854849 lr:0.001
[HCTR][07:29:58.601][INFO][RK0][main]: Iter: 11000 Time(1000 iters): 9.112s Loss: 0.086353 lr:0.001
[HCTR][07:30:07.736][INFO][RK0][main]: Iter: 12000 Time(1000 iters): 9.13015s Loss: 0.0903414 lr:0.001
[HCTR][07:30:12.334][INFO][RK0][main]: Evaluation, AUC: 0.694103
[HCTR][07:30:12.334][INFO][RK0][main]: Eval Time for 4000 iters: 4.59641s
[HCTR][07:30:21.455][INFO][RK0][main]: Iter: 13000 Time(1000 iters): 13.7151s Loss: 0.0813873 lr:0.001
[HCTR][07:30:30.579][INFO][RK0][main]: Iter: 14000 Time(1000 iters): 9.11932s Loss: 0.0972778 lr:0.001
[HCTR][07:30:39.711][INFO][RK0][main]: Iter: 15000 Time(1000 iters): 9.12719s Loss: 0.0762291 lr:0.001
[HCTR][07:30:48.820][INFO][RK0][main]: Iter: 16000 Time(1000 iters): 9.10478s Loss: 0.092993 lr:0.001
[HCTR][07:30:53.425][INFO][RK0][main]: Evaluation, AUC: 0.681651
[HCTR][07:30:53.425][INFO][RK0][main]: Eval Time for 4000 iters: 4.60313s
[HCTR][07:31:02.550][INFO][RK0][main]: Iter: 17000 Time(1000 iters): 13.7253s Loss: 0.0736029 lr:0.001
[HCTR][07:31:11.675][INFO][RK0][main]: Iter: 18000 Time(1000 iters): 9.12101s Loss: 0.0938892 lr:0.001
[HCTR][07:31:20.801][INFO][RK0][main]: Iter: 19000 Time(1000 iters): 9.12153s Loss: 0.0925995 lr:0.001
[HCTR][07:31:29.922][INFO][RK0][main]: Iter: 20000 Time(1000 iters): 9.11565s Loss: 0.0869264 lr:0.001
[HCTR][07:31:34.533][INFO][RK0][main]: Evaluation, AUC: 0.678614
[HCTR][07:31:34.533][INFO][RK0][main]: Eval Time for 4000 iters: 4.61036s
[HCTR][07:31:34.533][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:34.558][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][07:31:34.575][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:34.731][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][07:31:34.906][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][07:31:34.922][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][07:31:34.922][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:34.937][INFO][RK0][main]: Done
[HCTR][07:31:34.940][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][07:31:34.940][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:34.953][INFO][RK0][main]: Done
[HCTR][07:31:35.221][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][07:31:35.221][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:35.443][INFO][RK0][main]: Done
[HCTR][07:31:35.720][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][07:31:35.720][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:35.941][INFO][RK0][main]: Done
[HCTR][07:31:35.953][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HCTR][07:31:35.954][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:35.957][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][07:31:35.958][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:35.964][INFO][RK0][main]: Dumping dense optimizer states to file, successful
[HCTR][07:31:45.076][INFO][RK0][main]: Finish 21000 iterations with batchsize: 2720 in 215.28s.
[HCTR][07:31:45.076][INFO][RK0][main]: Save the model graph to /hugectr_e2e/model/wdl.json successfully
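The arguments to `model.fit` determine when evaluation and snapshots happen. The hypothetical helper below recomputes that schedule, assuming (as the log shows) that evaluation runs at every multiple of `eval_interval` and snapshots at every multiple of `snapshot`. With 21,000 iterations of batch size 2,720, the model sees 57,120,000 samples, about 5.7 passes over the 10,000,000 training rows. The dense snapshot name follows the `wdl/_dense_20000.model` pattern used in the inference step.

```python
from typing import Dict, List


def training_schedule(max_iter: int, batch_size: int, eval_interval: int,
                      snapshot: int, num_train_rows: int, prefix: str = "wdl/") -> Dict[str, object]:
    """Evaluation/snapshot iterations and data volume for a non-epoch HugeCTR fit."""
    eval_iters: List[int] = list(range(eval_interval, max_iter + 1, eval_interval))
    snapshot_iters: List[int] = list(range(snapshot, max_iter + 1, snapshot))
    samples_seen = max_iter * batch_size
    return {
        "eval_iters": eval_iters,
        "snapshot_iters": snapshot_iters,
        "samples_seen": samples_seen,
        "passes_over_train": samples_seen / num_train_rows,
        "dense_files": [f"{prefix}_dense_{it}.model" for it in snapshot_iters],
    }
```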

Load the model into HPS and run inference with HugeCTR

from hugectr.inference import InferenceParams, CreateInferenceSession
import pandas as pd
import numpy as np
    
CATEGORICAL_COLUMNS=["C1_C2","C3_C4"]+["C" + str(x) for x in range(1, 27)]
CONTINUOUS_COLUMNS=["I" + str(x) for x in range(1, 14)]
LABEL_COLUMNS = ['label']
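# Offset each slot's category ids by the cumulative sizes of the preceding slots
shift = np.insert(np.cumsum(SLOT_SIZE_ARRAY), 0, 0)[:-1]
# Run inference on the first 10 validation rows
test_df = pd.read_parquet(os.path.join(DATA_DIR, "val/part_0.parquet"))[:10]
config_file = os.path.join(MODEL_DIR, "wdl.json")
# CSR row pointers, one array per embedding table:
# 10 rows x 2 wide slots -> 21 entries, 10 rows x 26 deep slots -> 261 entries
row_ptrs = list(range(0, 21)) + list(range(0, 261))
dense_features = list(test_df[CONTINUOUS_COLUMNS].values.flatten())
test_df[CATEGORICAL_COLUMNS] = test_df[CATEGORICAL_COLUMNS].astype(np.int64)
embedding_columns = list((test_df[CATEGORICAL_COLUMNS] + shift).values.flatten())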

# create parameter server, embedding cache and inference session
inference_params = InferenceParams(model_name = "wdl",
                            max_batchsize = 64,
                            hit_rate_threshold = 0.9,
                            dense_model_file = os.path.join(MODEL_DIR, "wdl/_dense_20000.model"),
                            sparse_model_files = [os.path.join(MODEL_DIR, "wdl/0_sparse_20000.model"), os.path.join(MODEL_DIR, "wdl/1_sparse_20000.model")],
                            device_id = 0,
                            use_gpu_embedding_cache = True,
                            cache_size_percentage = 0.9,
                            i64_input_key = True,
                            use_mixed_precision = False
                            )
inference_session = CreateInferenceSession(config_file, inference_params)
output = inference_session.predict(dense_features, embedding_columns, row_ptrs)
print("WDL multi-embedding table inference result is {}".format(output))
[HCTR][08:01:43.577][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
WDL multi-embedding table inference result is [0.016221938654780388, 0.08543526381254196, 3.26810294382085e-07, 0.02832142263650894, 0.06627560406923294, 0.0002603427565190941, 0.00022551437723450363, 0.02671617455780506, 0.0031104201916605234, 0.0017484374111518264]
[HCTR][08:01:43.578][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
[HCTR][08:01:43.578][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
====================================================HPS Create====================================================
[HCTR][08:01:43.578][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][08:01:43.578][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][08:01:43.578][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][08:01:43.578][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][08:01:43.578][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][08:01:43.578][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:01:44.508][INFO][RK0][main]: Table: hps_et.wdl.sparse_embedding2; cached 2327936 / 2327936 embeddings in volatile database (HashMapBackend); load: 2327936 / 18446744073709551615 (0.00%).
[HCTR][08:01:44.508][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:01:45.076][INFO][RK0][main]: Table: hps_et.wdl.sparse_embedding1; cached 4169063 / 4169063 embeddings in volatile database (HashMapBackend); load: 4169063 / 18446744073709551615 (0.00%).
[HCTR][08:01:45.083][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][08:01:45.083][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][08:01:45.088][INFO][RK0][main]: Model name: wdl
[HCTR][08:01:45.088][INFO][RK0][main]: Number of embedding tables: 2
[HCTR][08:01:45.088][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.900000
[HCTR][08:01:45.088][INFO][RK0][main]: Use I64 input key: True
[HCTR][08:01:45.088][INFO][RK0][main]: Configured cache hit rate threshold: 0.900000
[HCTR][08:01:45.088][INFO][RK0][main]: The size of thread pool: 80
[HCTR][08:01:45.088][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][08:01:45.088][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][08:01:45.088][INFO][RK0][main]: The refresh percentage : 0.000000
[HCTR][08:01:46.044][INFO][RK0][main]: Global seed is 1583686956
[HCTR][08:01:46.047][INFO][RK0][main]: Device to NUMA mapping:
  GPU 0 ->  node 0
[HCTR][08:01:46.985][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][08:01:46.985][INFO][RK0][main]: Start all2all warmup
[HCTR][08:01:46.985][INFO][RK0][main]: End all2all warmup
[HCTR][08:01:46.986][INFO][RK0][main]: Model name: wdl
[HCTR][08:01:46.986][INFO][RK0][main]: Use mixed precision: False
[HCTR][08:01:46.986][INFO][RK0][main]: Use cuda graph: True
[HCTR][08:01:46.986][INFO][RK0][main]: Max batchsize: 64
[HCTR][08:01:46.986][INFO][RK0][main]: Use I64 input key: True
[HCTR][08:01:46.986][INFO][RK0][main]: start create embedding for inference
[HCTR][08:01:46.986][INFO][RK0][main]: sparse_input name wide_data
[HCTR][08:01:46.986][INFO][RK0][main]: sparse_input name deep_data
[HCTR][08:01:46.986][INFO][RK0][main]: create embedding for inference success
[HCTR][08:01:46.987][INFO][RK0][main]: Inference stage skip BinaryCrossEntropyLoss layer, replaced by Sigmoid layer
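The two hand-built inputs to `predict` follow simple rules. `shift` is an exclusive prefix sum of the slot sizes, and `row_ptrs` concatenates one CSR row-pointer array per embedding table. Assuming exactly one key per slot, as in this dataset, each table with `s` slots contributes `rows * s + 1` pointers. The hypothetical helpers below generalize the hard-coded `range(0, 21)` and `range(0, 261)` to any batch size.

```python
from itertools import accumulate
from typing import List, Sequence


def exclusive_prefix_sum(sizes: Sequence[int]) -> List[int]:
    """Same values as np.insert(np.cumsum(sizes), 0, 0)[:-1]."""
    return list(accumulate(sizes, initial=0))[:-1]


def build_row_ptrs(num_rows: int, slots_per_table: Sequence[int]) -> List[int]:
    """Concatenate one CSR row-pointer array per embedding table (one key per slot)."""
    row_ptrs: List[int] = []
    for num_slots in slots_per_table:
        # num_rows * num_slots rows of one key each -> offsets 0, 1, ..., num_rows * num_slots
        row_ptrs.extend(range(num_rows * num_slots + 1))
    return row_ptrs
```

For example, `exclusive_prefix_sum(SLOT_SIZE_ARRAY)` yields the same offsets as `shift`, and `build_row_ptrs(len(test_df), [2, 26])` reproduces `row_ptrs` for the wide and deep tables.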