HugeCTR End-to-End Example with NVTabular
Overview
In this sample notebook, we are going to:
Preprocess data using NVTabular
Train a model with HugeCTR
Run offline inference using the HugeCTR Hierarchical Parameter Server (HPS)
Setup
To set up the environment, refer to the HugeCTR Example Notebooks and follow the instructions there before running the following cells.
Data Preparation
import os
import shutil
!mkdir -p /hugectr_e2e
!mkdir -p /hugectr_e2e/criteo/train
!mkdir -p /hugectr_e2e/criteo/val
!mkdir -p /hugectr_e2e/model
BASE_DIR = os.environ.get("BASE_DIR", "/hugectr_e2e")
DATA_DIR = os.environ.get("DATA_DIR", BASE_DIR + "/criteo")
TRAIN_DIR = os.environ.get("TRAIN_DIR", DATA_DIR +"/train")
VAL_DIR = os.environ.get("VAL_DIR", DATA_DIR +"/val")
MODEL_DIR = os.environ.get("MODEL_DIR", BASE_DIR + "/model")
Download the Criteo data for 1 day:
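#!wget -P $DATA_DIR https://storage.googleapis.com/criteo-cail-datasets/day_0.gz # Uncomment this line and the next one to download the data; otherwise, symlink existing data below.
#!gzip -d -c $DATA_DIR/day_0.gz > $DATA_DIR/day_0
INPUT_DATA = os.environ.get("INPUT_DATA", DATA_DIR + "/day_0")
# Point INPUT_DATA at an existing uncompressed day_0 file; skip this link if you downloaded and unzipped the data above.
!ln -s $INPUT_DATA $DATA_DIR/day_0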
--2022-11-11 09:06:02-- https://storage.googleapis.com/criteo-cail-datasets/day_0.gz
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.191.48, 142.251.46.208, 172.217.164.112, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.191.48|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 16309554343 (15G) [application/octet-stream]
Saving to: ‘/hugectr_e2e/criteo/day_0.gz’
day_0.gz 100%[===================>] 15.19G 94.4MB/s in 4m 58s
2022-11-11 09:11:00 (52.2 MB/s) - ‘/hugectr_e2e/criteo/day_0.gz’ saved [16309554343/16309554343]
Split data into training and validation sets
!head -n 10000000 $DATA_DIR/day_0 > $DATA_DIR/train/train.txt
!tail -n 2000000 $DATA_DIR/day_0 > $DATA_DIR/val/test.txt
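The `head`/`tail` pair above takes the first 10 million lines of `day_0` for training and the last 2 million lines for validation. A rough pure-Python equivalent follows. It is a sketch, and the helper name `split_head_tail` is hypothetical, not part of NVTabular or HugeCTR.

```python
from collections import deque
from itertools import islice


def split_head_tail(src_path, train_path, val_path, n_train, n_val):
    """Write the first n_train lines to train_path and the last n_val lines to val_path.

    Mirrors `head -n n_train` and `tail -n n_val`. The two ranges overlap
    if the source file has fewer than n_train + n_val lines.
    """
    # Stream only the first n_train lines (like `head`).
    with open(src_path) as src, open(train_path, "w") as train:
        train.writelines(islice(src, n_train))
    # A bounded deque keeps only the most recent n_val lines while streaming (like `tail`).
    with open(src_path) as src:
        tail = deque(src, maxlen=n_val)
    with open(val_path, "w") as val:
        val.writelines(tail)
    return len(tail)
```

Because `tail` keeps only the last `n_val` lines, the two sets are disjoint as long as the source has at least `n_train + n_val` lines. That holds for a full day of Criteo data at these sizes.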
Data Preprocessing using NVTabular
import sys
import argparse
import glob
import time
import numpy as np
import numba
import dask_cudf
import cudf
import nvtabular as nvt
from nvtabular.io import Shuffle
from nvtabular.ops import Categorify, Clip, FillMissing, Normalize, get_embedding_sizes
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from nvtabular.utils import pynvml_mem_size, device_mem_size
import warnings
import logging
logging.basicConfig(format='%(asctime)s %(message)s')
logging.root.setLevel(logging.NOTSET)
# define dataset schema
CATEGORICAL_COLUMNS=["C" + str(x) for x in range(1, 27)]
CONTINUOUS_COLUMNS=["I" + str(x) for x in range(1, 14)]
LABEL_COLUMNS = ['label']
COLUMNS = LABEL_COLUMNS + CONTINUOUS_COLUMNS + CATEGORICAL_COLUMNS
# The /samples/criteo model doesn't use dense features
criteo_COLUMN=LABEL_COLUMNS + CATEGORICAL_COLUMNS
# Feature-cross columns (each pair of columns is combined into a new categorical column)
CROSS_COLUMNS = ["C1_C2", "C3_C4"]
NUM_INTEGER_COLUMNS = 13
NUM_CATEGORICAL_COLUMNS = 26
NUM_TOTAL_COLUMNS = 1 + NUM_INTEGER_COLUMNS + NUM_CATEGORICAL_COLUMNS
# Dask dashboard
dashboard_port = "8787"
# Deploy a Single-Machine Multi-GPU Cluster
protocol = "tcp" # "tcp" or "ucx"
if numba.cuda.is_available():
    NUM_GPUS = list(range(len(numba.cuda.gpus)))
else:
    NUM_GPUS = []
visible_devices = ",".join([str(n) for n in NUM_GPUS]) # Select devices to place workers
device_limit_frac = 0.7 # Spill GPU-Worker memory to host at this limit.
device_pool_frac = 0.8
part_mem_frac = 0.15
# Use the total device size to compute the absolute memory limits below
device_size = device_mem_size(kind="total")
device_limit = int(device_limit_frac * device_size)
device_pool_size = int(device_pool_frac * device_size)
part_size = int(part_mem_frac * device_size)
# Check if any device memory is already occupied
for dev in visible_devices.split(","):
    fmem = pynvml_mem_size(kind="free", index=int(dev))
    used = (device_size - fmem) / 1e9
    if used > 1.0:
        warnings.warn(f"BEWARE - {used} GB is already occupied on device {int(dev)}!")
cluster = None # (Optional) Specify existing scheduler port
if cluster is None:
    cluster = LocalCUDACluster(
        protocol=protocol,
        n_workers=len(visible_devices.split(",")),
        CUDA_VISIBLE_DEVICES=visible_devices,
        device_memory_limit=device_limit,
        dashboard_address=":" + dashboard_port,
        rmm_pool_size=(device_pool_size // 256) * 256
    )
# Create the distributed client
client = Client(cluster)
client
2022-11-14 06:59:41,436 Using selector: EpollSelector
2022-11-14 06:59:43,469 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,494 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,499 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,502 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,515 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,525 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,575 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-14 06:59:43,576 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
Client: Client-ebccc26f-63e9-11ed-957d-54ab3adac0a5 | Connection method: Cluster object | Cluster type: dask_cuda.LocalCUDACluster
Dashboard: http://127.0.0.1:8787/status
Scheduler: Scheduler-46de49cf-ae86-402e-84bb-9fc2adfe2312 at tcp://127.0.0.1:45161 | Status: running | Workers: 8 | Total threads: 8 | Total memory: 503.79 GiB
Workers 0-7: 1 thread and 62.97 GiB of host memory each, one Tesla V100-SXM2-32GB GPU (31.75 GiB) per worker
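The memory fractions in the cluster setup are converted into absolute byte counts per GPU. The sketch below repeats that arithmetic for a hypothetical 32 GiB device; the notebook itself reads the size from `device_mem_size(kind="total")`. The function name is illustrative only. Rounding `rmm_pool_size` down to a multiple of 256 presumably keeps the RMM pool size aligned to 256 bytes.

```python
def dask_cuda_memory_settings(device_size, limit_frac=0.7, pool_frac=0.8, part_frac=0.15):
    """Repeat the notebook's memory arithmetic for one GPU (all values in bytes)."""
    device_limit = int(limit_frac * device_size)      # spill worker memory to host above this
    device_pool_size = int(pool_frac * device_size)   # RMM pool requested per worker
    part_size = int(part_frac * device_size)          # target partition size
    rmm_pool_size = (device_pool_size // 256) * 256   # rounded down to a multiple of 256
    return {
        "device_limit": device_limit,
        "rmm_pool_size": rmm_pool_size,
        "part_size": part_size,
    }


# Hypothetical 32 GiB device; the real value comes from device_mem_size(kind="total").
settings = dask_cuda_memory_settings(32 * 1024**3)
for name, value in settings.items():
    print(f"{name}: {value / 1024**3:.2f} GiB")
```

In the code shown, `part_size` is computed but never passed to `nvt.Dataset`. It has no effect unless you pass it in, for example through the `part_size` argument of `nvt.Dataset`.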
train_output = os.path.join(DATA_DIR, "train")
print("Training output data: "+train_output)
val_output = os.path.join(DATA_DIR, "val")
print("Validation output data: "+val_output)
train_input = os.path.join(DATA_DIR, "train/train.txt")
print("Training dataset: "+train_input)
val_input = os.path.join(DATA_DIR, "val/test.txt")
PREPROCESS_DIR_temp_train = os.path.join(DATA_DIR, 'train/temp-parquet-after-conversion')
PREPROCESS_DIR_temp_val = os.path.join(DATA_DIR, "val/temp-parquet-after-conversion")
if not os.path.exists(PREPROCESS_DIR_temp_train):
    os.makedirs(PREPROCESS_DIR_temp_train)
if not os.path.exists(PREPROCESS_DIR_temp_val):
    os.makedirs(PREPROCESS_DIR_temp_val)
PREPROCESS_DIR_temp = [PREPROCESS_DIR_temp_train, PREPROCESS_DIR_temp_val]
# Make sure we have a clean parquet space for cudf conversion
for one_path in PREPROCESS_DIR_temp:
    if os.path.exists(one_path):
        shutil.rmtree(one_path)
    os.mkdir(one_path)
#calculate the total processing time
runtime = time.time()
## train/valid txt to parquet
train_valid_paths = [(train_input,PREPROCESS_DIR_temp_train),(val_input,PREPROCESS_DIR_temp_val)]
for input, temp_output in train_valid_paths:
    ddf = dask_cudf.read_csv(input,sep='\t',names=LABEL_COLUMNS + CONTINUOUS_COLUMNS + CATEGORICAL_COLUMNS)
    ddf["label"] = ddf['label'].astype('float32')
    ddf[CONTINUOUS_COLUMNS] = ddf[CONTINUOUS_COLUMNS].astype('float32')
    # Save it as parquet format for better memory usage
    ddf.to_parquet(temp_output,header=True)
##-----------------------------------##
COLUMNS = LABEL_COLUMNS + CONTINUOUS_COLUMNS + CROSS_COLUMNS + CATEGORICAL_COLUMNS
train_paths = glob.glob(os.path.join(PREPROCESS_DIR_temp_train, "*.parquet"))
valid_paths = glob.glob(os.path.join(PREPROCESS_DIR_temp_val, "*.parquet"))
categorify_op = Categorify()
cat_features = CATEGORICAL_COLUMNS >> categorify_op
cont_features = CONTINUOUS_COLUMNS >> FillMissing() >> Clip(min_value=0) >> Normalize()
cross_cat_op = Categorify(encode_type="combo")
features = LABEL_COLUMNS
features += cont_features
if CROSS_COLUMNS:
    feature_pairs = [pair.split("_") for pair in CROSS_COLUMNS]
    for pair in feature_pairs:
        features += [pair] >> cross_cat_op
features += cat_features
workflow = nvt.Workflow(features)
logging.info("Preprocessing")
output_format = 'parquet'
# just for /samples/criteo model
train_ds_iterator = nvt.Dataset(train_paths, engine='parquet')
valid_ds_iterator = nvt.Dataset(valid_paths, engine='parquet')
shuffle = nvt.io.Shuffle.PER_PARTITION
logging.info('Train Datasets Preprocessing.....')
dict_dtypes = {}
for col in CATEGORICAL_COLUMNS:
    dict_dtypes[col] = np.int64
for col in CONTINUOUS_COLUMNS:
    dict_dtypes[col] = np.float32
for col in CROSS_COLUMNS:
    dict_dtypes[col] = np.int64
for col in LABEL_COLUMNS:
    dict_dtypes[col] = np.float32
conts = CONTINUOUS_COLUMNS
workflow.fit(train_ds_iterator)
if output_format == 'hugectr':
    workflow.transform(train_ds_iterator).to_hugectr(
        cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
        conts=conts,
        labels=LABEL_COLUMNS,
        output_path=train_output,
        shuffle=shuffle)
else:
    workflow.transform(train_ds_iterator).to_parquet(
        output_path=train_output,
        dtypes=dict_dtypes,
        cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
        conts=conts,
        labels=LABEL_COLUMNS,
        shuffle=shuffle)
###Getting slot size###
#--------------------##
embeddings_dict_cat = categorify_op.get_embedding_sizes(CATEGORICAL_COLUMNS)
embeddings_dict_cross = cross_cat_op.get_embedding_sizes(CROSS_COLUMNS)
embeddings = [embeddings_dict_cat[c][0] for c in CATEGORICAL_COLUMNS] + [embeddings_dict_cross[c][0] for c in CROSS_COLUMNS]
print(embeddings)
##--------------------##
logging.info('Valid Datasets Preprocessing.....')
if output_format == 'hugectr':
    workflow.transform(valid_ds_iterator).to_hugectr(
        cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
        conts=conts,
        labels=LABEL_COLUMNS,
        output_path=val_output,
        shuffle=shuffle)
else:
    workflow.transform(valid_ds_iterator).to_parquet(
        output_path=val_output,
        dtypes=dict_dtypes,
        cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
        conts=conts,
        labels=LABEL_COLUMNS,
        shuffle=shuffle)
embeddings_dict_cat = categorify_op.get_embedding_sizes(CATEGORICAL_COLUMNS)
embeddings_dict_cross = cross_cat_op.get_embedding_sizes(CROSS_COLUMNS)
embeddings = [embeddings_dict_cat[c][0] for c in CATEGORICAL_COLUMNS] + [embeddings_dict_cross[c][0] for c in CROSS_COLUMNS]
print(embeddings)
##--------------------##
## Close the Dask client
client.close()
runtime = time.time() - runtime
print("\nDask-NVTabular Criteo Preprocessing Done!")
print(f"Runtime[s] | {runtime}")
print("======================================\n")
Training output data: /hugectr_e2e/criteo/train
Validation output data: /hugectr_e2e/criteo/val
Training dataset: /hugectr_e2e/criteo/train/train.txt
2022-11-14 07:03:35,978 Preprocessing
2022-11-14 07:03:36,175 Train Datasets Preprocessing.....
2022-11-14 07:03:41,211 Valid Datasets Preprocessing.....
[1234907, 19683, 13780, 6867, 18490, 4, 6264, 1235, 50, 854680, 114026, 75736, 11, 2159, 7533, 61, 4, 919, 15, 1307783, 404742, 1105613, 87714, 9032, 77, 34, 1581605, 1093030]
[1234907, 19683, 13780, 6867, 18490, 4, 6264, 1235, 50, 854680, 114026, 75736, 11, 2159, 7533, 61, 4, 919, 15, 1307783, 404742, 1105613, 87714, 9032, 77, 34, 1581605, 1093030]
Dask-NVTabular Criteo Preprocessing Done!
Runtime[s] | 8.17225694656372
======================================
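`cont_features` chains `FillMissing`, `Clip(min_value=0)` and `Normalize`, and `workflow.fit` learns the statistics from the training data only before the same workflow transforms both splits. The NumPy sketch below approximates that chain for one column. It assumes that `FillMissing` fills with 0 and that `Normalize` standardizes with the training mean and standard deviation. Those assumptions may not match NVTabular's internals exactly, for example its variance estimator.

```python
import numpy as np


def fit_continuous(train_col):
    """Learn the statistics Normalize needs, after FillMissing and Clip are applied."""
    x = np.nan_to_num(train_col, nan=0.0)   # FillMissing: assumed fill value 0
    x = np.clip(x, 0.0, None)               # Clip(min_value=0): negatives become 0
    return x.mean(), x.std()


def transform_continuous(col, mean, std):
    """Apply the same three steps to any split, reusing the training statistics."""
    x = np.nan_to_num(col, nan=0.0)
    x = np.clip(x, 0.0, None)
    return (x - mean) / std


train_i1 = np.array([1.0, np.nan, 3.0, -2.0, 6.0], dtype=np.float32)
val_i1 = np.array([2.0, np.nan], dtype=np.float32)
mean, std = fit_continuous(train_i1)
print(transform_continuous(val_i1, mean, std))
```

Reusing the training statistics for the validation split mirrors the notebook, which calls `workflow.fit` only on `train_ds_iterator`.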
### Record the slot size array
SLOT_SIZE_ARRAY = embeddings
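Each entry of `SLOT_SIZE_ARRAY` is the cardinality that `Categorify` produced for one column, and HugeCTR later receives it as `slot_size_array`. The plain-Python sketch below shows the idea for one ordinary column and one `encode_type="combo"` cross column. It assigns ids in first-seen order and reserves only id 0 for missing or unseen values, whereas NVTabular's ordering and reserved ids differ between versions. The helpers `fit_categorify` and `apply_categorify` are hypothetical.

```python
def fit_categorify(values):
    """Map each distinct value to a contiguous integer id.

    Assumption: ids start at 1 and 0 is reserved for missing/unseen values,
    so the cardinality reported for the column is len(mapping) + 1.
    """
    mapping = {}
    for v in values:
        if v is not None and v not in mapping:
            mapping[v] = len(mapping) + 1
    return mapping


def apply_categorify(values, mapping):
    """Encode values with a fitted mapping; unseen values fall back to id 0."""
    return [mapping.get(v, 0) for v in values]


c1 = ["a", "b", "a", None]
c2 = ["x", "x", "y", "y"]
c1_map = fit_categorify(c1)
# encode_type="combo": the pair (C1, C2) is treated as a single categorical value.
cross_map = fit_categorify(list(zip(c1, c2)))
slot_sizes = [len(c1_map) + 1, len(cross_map) + 1]
print(slot_sizes)
```

Because cross columns are encoded from value pairs, their cardinality can exceed that of either source column. In the output above, C1_C2 and C3_C4 have 1581605 and 1093030 entries.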
Training a WDL model with HugeCTR
%%writefile './train.py'
import hugectr
import os
import argparse
from mpi4py import MPI
parser = argparse.ArgumentParser(description=("HugeCTR Training"))
parser.add_argument("--data_path", type=str, help="Input dataset path (Required)")
parser.add_argument("--model_path", type=str, help="Directory path to write output (Required)")
args = parser.parse_args()
SLOT_SIZE_ARRAY = [1234907, 19683, 13780, 6867, 18490, 4, 6264, 1235, 50, 854680, 114026, 75736, 11, 2159, 7533, 61, 4, 919, 15, 1307783, 404742, 1105613, 87714, 9032, 77, 34, 1581605, 1093030]
solver = hugectr.CreateSolver(max_eval_batches = 4000,
batchsize_eval = 2720,
batchsize = 2720,
lr = 0.001,
vvgpu = [[0]],
repeat_dataset = True,
i64_input_key = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Parquet,
source = [os.path.join(args.data_path, "train/_file_list.txt")],
eval_source = os.path.join(args.data_path, "val/_file_list.txt"),
check_type = hugectr.Check_t.Non,
slot_size_array = SLOT_SIZE_ARRAY)
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam,
update_type = hugectr.Update_t.Global,
beta1 = 0.9,
beta2 = 0.999,
epsilon = 0.0000001)
model = hugectr.Model(solver, reader, optimizer)
model.add(hugectr.Input(label_dim = 1, label_name = "label",
dense_dim = 13, dense_name = "dense",
data_reader_sparse_param_array =
[hugectr.DataReaderSparseParam("wide_data", 1, True, 2),
hugectr.DataReaderSparseParam("deep_data", 2, False, 26)]))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
workspace_size_per_gpu_in_mb = 80,
embedding_vec_size = 1,
combiner = "sum",
sparse_embedding_name = "sparse_embedding2",
bottom_name = "wide_data",
optimizer = optimizer))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
workspace_size_per_gpu_in_mb = 1350,
embedding_vec_size = 16,
combiner = "sum",
sparse_embedding_name = "sparse_embedding1",
bottom_name = "deep_data",
optimizer = optimizer))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
bottom_names = ["sparse_embedding1"],
top_names = ["reshape1"],
leading_dim=416))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
bottom_names = ["sparse_embedding2"],
top_names = ["reshape2"],
leading_dim=2))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReduceSum,
bottom_names = ["reshape2"],
top_names = ["wide_redn"],
axis = 1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
bottom_names = ["reshape1", "dense"],
top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["concat1"],
top_names = ["fc1"],
num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
bottom_names = ["fc1"],
top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
bottom_names = ["relu1"],
top_names = ["dropout1"],
dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["dropout1"],
top_names = ["fc2"],
num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
bottom_names = ["fc2"],
top_names = ["relu2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
bottom_names = ["relu2"],
top_names = ["dropout2"],
dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["dropout2"],
top_names = ["fc3"],
num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Add,
bottom_names = ["fc3", "wide_redn"],
top_names = ["add1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
bottom_names = ["add1", "label"],
top_names = ["loss"]))
model.compile()
model.summary()
model.fit(max_iter = 21000, display = 1000, eval_interval = 4000, snapshot = 20000, snapshot_prefix = os.path.join(args.model_path, "wdl/"))
model.graph_to_json(graph_config_file = os.path.join(args.model_path, "wdl.json"))
Overwriting ./train.py
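The layer shapes in the model summary follow from simple arithmetic. The 26 deep slots times a 16-dimensional embedding give 416 columns, and adding the 13 dense features gives 429. The two wide slots with 1-dimensional embeddings reduce-sum to one logit, which is added to the deep tower's output. The NumPy sketch below traces those shapes with random stand-in weights. Biases are omitted for brevity, and dropout is skipped because it is inactive at inference. The final sigmoid mirrors the inference log, where `BinaryCrossEntropyLoss` is replaced by a Sigmoid layer.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_dense, n_deep_slots, n_wide_slots, emb_dim = 4, 13, 26, 2, 16

# Stand-ins for the embedding lookups (random values, not trained weights).
deep_emb = rng.normal(size=(batch, n_deep_slots, emb_dim))   # sparse_embedding1
wide_emb = rng.normal(size=(batch, n_wide_slots, 1))         # sparse_embedding2
dense = rng.normal(size=(batch, n_dense))

reshape1 = deep_emb.reshape(batch, n_deep_slots * emb_dim)                    # (4, 416)
wide_redn = wide_emb.reshape(batch, n_wide_slots).sum(axis=1, keepdims=True)  # (4, 1)
concat1 = np.concatenate([reshape1, dense], axis=1)                           # (4, 429)


def inner_product(x, n_out):
    """Fully connected layer with random weights and no bias (illustration only)."""
    w = rng.normal(scale=0.01, size=(x.shape[1], n_out))
    return x @ w


def relu(x):
    return np.maximum(x, 0.0)


h = relu(inner_product(concat1, 1024))   # fc1 -> relu1 (dropout1 skipped)
h = relu(inner_product(h, 1024))         # fc2 -> relu2 (dropout2 skipped)
fc3 = inner_product(h, 1)
add1 = fc3 + wide_redn                   # deep logit + wide logit, shape (4, 1)
prob = 1.0 / (1.0 + np.exp(-add1))       # sigmoid, as used at inference time
print(reshape1.shape, concat1.shape, add1.shape)
```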
!python train.py --data_path $DATA_DIR --model_path $MODEL_DIR
HugeCTR Version: 4.0
====================================================Model Init=====================================================
[HCTR][07:27:59.400][WARNING][RK0][main]: The model name is not specified when creating the solver.
[HCTR][07:27:59.400][INFO][RK0][main]: Global seed is 2258624009
[HCTR][07:27:59.403][INFO][RK0][main]: Device to NUMA mapping:
GPU 0 -> node 0
[HCTR][07:28:01.312][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][07:28:01.312][INFO][RK0][main]: Start all2all warmup
[HCTR][07:28:01.312][INFO][RK0][main]: End all2all warmup
[HCTR][07:28:01.313][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][07:28:01.313][INFO][RK0][main]: Device 0: Tesla V100-SXM2-32GB
[HCTR][07:28:01.314][INFO][RK0][main]: num of DataReader workers for train: 1
[HCTR][07:28:01.314][INFO][RK0][main]: num of DataReader workers for eval: 1
[HCTR][07:28:01.318][INFO][RK0][main]: Vocabulary size: 7946054
[HCTR][07:28:01.318][INFO][RK0][main]: max_vocabulary_size_per_gpu_=6990506
[HCTR][07:28:01.334][INFO][RK0][main]: max_vocabulary_size_per_gpu_=7372800
[HCTR][07:28:01.341][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][07:28:09.781][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][07:28:09.782][INFO][RK0][main]: gpu0 init embedding done
[HCTR][07:28:09.782][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][07:28:09.785][INFO][RK0][main]: gpu0 init embedding done
[HCTR][07:28:09.787][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][07:28:09.792][INFO][RK0][main]: Warm-up done
===================================================Model Summary===================================================
[HCTR][07:28:09.792][INFO][RK0][main]: Model structure on each GPU
Label Dense Sparse
label dense wide_data,deep_data
(2720,1) (2720,13)
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type Input Name Output Name Output Shape
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash wide_data sparse_embedding2 (2720,2,1)
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash deep_data sparse_embedding1 (2720,26,16)
------------------------------------------------------------------------------------------------------------------
Reshape sparse_embedding1 reshape1 (2720,416)
------------------------------------------------------------------------------------------------------------------
Reshape sparse_embedding2 reshape2 (2720,2)
------------------------------------------------------------------------------------------------------------------
ReduceSum reshape2 wide_redn (2720,1)
------------------------------------------------------------------------------------------------------------------
Concat reshape1 concat1 (2720,429)
dense
------------------------------------------------------------------------------------------------------------------
InnerProduct concat1 fc1 (2720,1024)
------------------------------------------------------------------------------------------------------------------
ReLU fc1 relu1 (2720,1024)
------------------------------------------------------------------------------------------------------------------
Dropout relu1 dropout1 (2720,1024)
------------------------------------------------------------------------------------------------------------------
InnerProduct dropout1 fc2 (2720,1024)
------------------------------------------------------------------------------------------------------------------
ReLU fc2 relu2 (2720,1024)
------------------------------------------------------------------------------------------------------------------
Dropout relu2 dropout2 (2720,1024)
------------------------------------------------------------------------------------------------------------------
InnerProduct dropout2 fc3 (2720,1)
------------------------------------------------------------------------------------------------------------------
Add fc3 add1 (2720,1)
wide_redn
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss add1 loss
label
------------------------------------------------------------------------------------------------------------------
=====================================================Model Fit=====================================================
[HCTR][07:28:09.792][INFO][RK0][main]: Use non-epoch mode with number of iterations: 21000
[HCTR][07:28:09.792][INFO][RK0][main]: Training batchsize: 2720, evaluation batchsize: 2720
[HCTR][07:28:09.792][INFO][RK0][main]: Evaluation interval: 4000, snapshot interval: 20000
[HCTR][07:28:09.792][INFO][RK0][main]: Dense network trainable: True
[HCTR][07:28:09.792][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][07:28:09.792][INFO][RK0][main]: Sparse embedding sparse_embedding2 trainable: True
[HCTR][07:28:09.792][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][07:28:09.792][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][07:28:09.792][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][07:28:09.792][INFO][RK0][main]: Training source file: /hugectr_e2e/criteo/train/_file_list.txt
[HCTR][07:28:09.792][INFO][RK0][main]: Evaluation source file: /hugectr_e2e/criteo/val/_file_list.txt
[HCTR][07:28:18.503][INFO][RK0][main]: Iter: 1000 Time(1000 iters): 8.70626s Loss: 0.131954 lr:0.001
[HCTR][07:28:27.274][INFO][RK0][main]: Iter: 2000 Time(1000 iters): 8.76682s Loss: 0.135973 lr:0.001
[HCTR][07:28:36.261][INFO][RK0][main]: Iter: 3000 Time(1000 iters): 8.98195s Loss: 0.116014 lr:0.001
[HCTR][07:28:45.385][INFO][RK0][main]: Iter: 4000 Time(1000 iters): 9.12015s Loss: 0.100682 lr:0.001
[HCTR][07:28:50.025][INFO][RK0][main]: Evaluation, AUC: 0.734929
[HCTR][07:28:50.025][INFO][RK0][main]: Eval Time for 4000 iters: 4.6372s
[HCTR][07:28:59.179][INFO][RK0][main]: Iter: 5000 Time(1000 iters): 13.7896s Loss: 0.111253 lr:0.001
[HCTR][07:29:08.355][INFO][RK0][main]: Iter: 6000 Time(1000 iters): 9.17119s Loss: 0.11407 lr:0.001
[HCTR][07:29:17.488][INFO][RK0][main]: Iter: 7000 Time(1000 iters): 9.12885s Loss: 0.102613 lr:0.001
[HCTR][07:29:26.636][INFO][RK0][main]: Iter: 8000 Time(1000 iters): 9.14357s Loss: 0.0954151 lr:0.001
[HCTR][07:29:31.243][INFO][RK0][main]: Evaluation, AUC: 0.709346
[HCTR][07:29:31.243][INFO][RK0][main]: Eval Time for 4000 iters: 4.60503s
[HCTR][07:29:40.356][INFO][RK0][main]: Iter: 9000 Time(1000 iters): 13.7146s Loss: 0.0999723 lr:0.001
[HCTR][07:29:49.485][INFO][RK0][main]: Iter: 10000 Time(1000 iters): 9.12492s Loss: 0.0854849 lr:0.001
[HCTR][07:29:58.601][INFO][RK0][main]: Iter: 11000 Time(1000 iters): 9.112s Loss: 0.086353 lr:0.001
[HCTR][07:30:07.736][INFO][RK0][main]: Iter: 12000 Time(1000 iters): 9.13015s Loss: 0.0903414 lr:0.001
[HCTR][07:30:12.334][INFO][RK0][main]: Evaluation, AUC: 0.694103
[HCTR][07:30:12.334][INFO][RK0][main]: Eval Time for 4000 iters: 4.59641s
[HCTR][07:30:21.455][INFO][RK0][main]: Iter: 13000 Time(1000 iters): 13.7151s Loss: 0.0813873 lr:0.001
[HCTR][07:30:30.579][INFO][RK0][main]: Iter: 14000 Time(1000 iters): 9.11932s Loss: 0.0972778 lr:0.001
[HCTR][07:30:39.711][INFO][RK0][main]: Iter: 15000 Time(1000 iters): 9.12719s Loss: 0.0762291 lr:0.001
[HCTR][07:30:48.820][INFO][RK0][main]: Iter: 16000 Time(1000 iters): 9.10478s Loss: 0.092993 lr:0.001
[HCTR][07:30:53.425][INFO][RK0][main]: Evaluation, AUC: 0.681651
[HCTR][07:30:53.425][INFO][RK0][main]: Eval Time for 4000 iters: 4.60313s
[HCTR][07:31:02.550][INFO][RK0][main]: Iter: 17000 Time(1000 iters): 13.7253s Loss: 0.0736029 lr:0.001
[HCTR][07:31:11.675][INFO][RK0][main]: Iter: 18000 Time(1000 iters): 9.12101s Loss: 0.0938892 lr:0.001
[HCTR][07:31:20.801][INFO][RK0][main]: Iter: 19000 Time(1000 iters): 9.12153s Loss: 0.0925995 lr:0.001
[HCTR][07:31:29.922][INFO][RK0][main]: Iter: 20000 Time(1000 iters): 9.11565s Loss: 0.0869264 lr:0.001
[HCTR][07:31:34.533][INFO][RK0][main]: Evaluation, AUC: 0.678614
[HCTR][07:31:34.533][INFO][RK0][main]: Eval Time for 4000 iters: 4.61036s
[HCTR][07:31:34.533][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:34.558][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][07:31:34.575][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:34.731][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][07:31:34.906][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][07:31:34.922][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][07:31:34.922][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:34.937][INFO][RK0][main]: Done
[HCTR][07:31:34.940][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][07:31:34.940][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:34.953][INFO][RK0][main]: Done
[HCTR][07:31:35.221][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][07:31:35.221][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:35.443][INFO][RK0][main]: Done
[HCTR][07:31:35.720][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][07:31:35.720][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:35.941][INFO][RK0][main]: Done
[HCTR][07:31:35.953][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HCTR][07:31:35.954][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:35.957][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][07:31:35.958][INFO][RK0][main]: Using Local file system backend.
[HCTR][07:31:35.964][INFO][RK0][main]: Dumping dense optimizer states to file, successful
[HCTR][07:31:45.076][INFO][RK0][main]: Finish 21000 iterations with batchsize: 2720 in 215.28s.
[HCTR][07:31:45.076][INFO][RK0][main]: Save the model graph to /hugectr_e2e/model/wdl.json successfully
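The following check shows how much data the run consumed. It uses only numbers from this notebook: `max_iter`, `batchsize`, and the line counts passed to `head` and `tail`.

```python
max_iter = 21000
batchsize = 2720
train_rows = 10_000_000    # rows written by `head -n 10000000`
eval_rows = 2_000_000      # rows written by `tail -n 2000000`

samples_seen = max_iter * batchsize
passes = samples_seen / train_rows
full_eval_batches = eval_rows // batchsize
print(f"samples seen: {samples_seen:,}")
print(f"passes over the training set: {passes:.2f}")
print(f"full evaluation batches available: {full_eval_batches}")
```

With `repeat_dataset = True`, the reader makes about 5.7 passes over one day's 10 million training rows. That may explain why validation AUC falls from 0.734929 at iteration 4000 to 0.678614 at iteration 20000 while the training loss trends downward. Treat this as a plausible reading of the log, not a diagnosis. Note also that `max_eval_batches = 4000` is more than the 735 full batches the validation file provides.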
Load the model into HPS and run inference with HugeCTR
from hugectr.inference import InferenceParams, CreateInferenceSession
import pandas as pd
import numpy as np
CATEGORICAL_COLUMNS=["C1_C2","C3_C4"]+["C" + str(x) for x in range(1, 27)]
CONTINUOUS_COLUMNS=["I" + str(x) for x in range(1, 14)]
LABEL_COLUMNS = ['label']
shift = np.insert(np.cumsum(SLOT_SIZE_ARRAY), 0, 0)[:-1]
test_df=pd.read_parquet(os.path.join(DATA_DIR, "val/part_0.parquet"))[:10]
config_file = os.path.join(MODEL_DIR, "wdl.json")
row_ptrs = list(range(0,21))+list(range(0,261))
dense_features = list(test_df[CONTINUOUS_COLUMNS].values.flatten())
test_df[CATEGORICAL_COLUMNS] = test_df[CATEGORICAL_COLUMNS].astype(np.int64)
embedding_columns = list((test_df[CATEGORICAL_COLUMNS]+shift).values.flatten())
# create parameter server, embedding cache and inference session
inference_params = InferenceParams(model_name = "wdl",
max_batchsize = 64,
hit_rate_threshold = 0.9,
dense_model_file = os.path.join(MODEL_DIR, "wdl/_dense_20000.model"),
sparse_model_files = [os.path.join(MODEL_DIR, "wdl/0_sparse_20000.model"), os.path.join(MODEL_DIR, "wdl/1_sparse_20000.model")],
device_id = 0,
use_gpu_embedding_cache = True,
cache_size_percentage = 0.9,
i64_input_key = True,
use_mixed_precision = False
)
inference_session = CreateInferenceSession(config_file, inference_params)
output = inference_session.predict(dense_features, embedding_columns, row_ptrs)
print("WDL multi-embedding table inference result is {}".format(output))
[HCTR][08:01:43.577][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
WDL multi-embedding table inference result is [0.016221938654780388, 0.08543526381254196, 3.26810294382085e-07, 0.02832142263650894, 0.06627560406923294, 0.0002603427565190941, 0.00022551437723450363, 0.02671617455780506, 0.0031104201916605234, 0.0017484374111518264]
[HCTR][08:01:43.578][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
[HCTR][08:01:43.578][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
====================================================HPS Create====================================================
[HCTR][08:01:43.578][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][08:01:43.578][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][08:01:43.578][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][08:01:43.578][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][08:01:43.578][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][08:01:43.578][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:01:44.508][INFO][RK0][main]: Table: hps_et.wdl.sparse_embedding2; cached 2327936 / 2327936 embeddings in volatile database (HashMapBackend); load: 2327936 / 18446744073709551615 (0.00%).
[HCTR][08:01:44.508][INFO][RK0][main]: Using Local file system backend.
[HCTR][08:01:45.076][INFO][RK0][main]: Table: hps_et.wdl.sparse_embedding1; cached 4169063 / 4169063 embeddings in volatile database (HashMapBackend); load: 4169063 / 18446744073709551615 (0.00%).
[HCTR][08:01:45.083][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][08:01:45.083][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][08:01:45.088][INFO][RK0][main]: Model name: wdl
[HCTR][08:01:45.088][INFO][RK0][main]: Number of embedding tables: 2
[HCTR][08:01:45.088][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.900000
[HCTR][08:01:45.088][INFO][RK0][main]: Use I64 input key: True
[HCTR][08:01:45.088][INFO][RK0][main]: Configured cache hit rate threshold: 0.900000
[HCTR][08:01:45.088][INFO][RK0][main]: The size of thread pool: 80
[HCTR][08:01:45.088][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][08:01:45.088][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][08:01:45.088][INFO][RK0][main]: The refresh percentage : 0.000000
[HCTR][08:01:46.044][INFO][RK0][main]: Global seed is 1583686956
[HCTR][08:01:46.047][INFO][RK0][main]: Device to NUMA mapping:
GPU 0 -> node 0
[HCTR][08:01:46.985][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][08:01:46.985][INFO][RK0][main]: Start all2all warmup
[HCTR][08:01:46.985][INFO][RK0][main]: End all2all warmup
[HCTR][08:01:46.986][INFO][RK0][main]: Model name: wdl
[HCTR][08:01:46.986][INFO][RK0][main]: Use mixed precision: False
[HCTR][08:01:46.986][INFO][RK0][main]: Use cuda graph: True
[HCTR][08:01:46.986][INFO][RK0][main]: Max batchsize: 64
[HCTR][08:01:46.986][INFO][RK0][main]: Use I64 input key: True
[HCTR][08:01:46.986][INFO][RK0][main]: start create embedding for inference
[HCTR][08:01:46.986][INFO][RK0][main]: sparse_input name wide_data
[HCTR][08:01:46.986][INFO][RK0][main]: sparse_input name deep_data
[HCTR][08:01:46.986][INFO][RK0][main]: create embedding for inference success
[HCTR][08:01:46.987][INFO][RK0][main]: Inference stage skip BinaryCrossEntropyLoss layer, replaced by Sigmoid layer
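`shift` and `row_ptrs` encode how the categorical features reach the two embedding tables. `shift` is an exclusive prefix sum of `SLOT_SIZE_ARRAY`, so per-slot ids from different slots map to disjoint global keys. `row_ptrs` is a CSR-style offset list per table: with one key per slot, a table with `s` slots and `n` samples needs `range(0, n*s + 1)`. The sketch below reproduces both with small hypothetical slot sizes. It assumes, as the notebook's values suggest, that the two tables' offset lists are simply concatenated. The helper names are hypothetical.

```python
import numpy as np


def global_keys(local_ids, slot_sizes):
    """Offset per-slot ids so that every slot owns a disjoint key range.

    local_ids has shape (batch, num_slots); slot_sizes[j] is the cardinality of slot j.
    """
    shift = np.insert(np.cumsum(slot_sizes), 0, 0)[:-1]   # exclusive prefix sum
    return np.asarray(local_ids, dtype=np.int64) + shift


def row_ptrs_one_hot(batch, num_slots):
    """CSR row pointers when each slot holds exactly one key per sample."""
    return list(range(batch * num_slots + 1))


slot_sizes = [3, 5, 4]
ids = [[1, 0, 2],
       [2, 4, 3]]
print(global_keys(ids, slot_sizes))
# The notebook's row_ptrs: 10 samples x 2 wide slots, then 10 samples x 26 deep slots.
row_ptrs = row_ptrs_one_hot(10, 2) + row_ptrs_one_hot(10, 26)
```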