# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
HugeCTR End-to-End Example with NVTabular
Overview
In this sample notebook, we are going to:
Preprocess data using NVTabular
Train a model with HugeCTR
Do offline inference using HugeCTR HPS
Setup
To set up the environment, refer to HugeCTR Example Notebooks and follow the instructions there before running the following cells.
Data Preparation
import os
import shutil
# Resolve the working directories first. Each one can be overridden through an
# environment variable of the same name.
BASE_DIR = os.environ.get("BASE_DIR", "/hugectr_e2e")
DATA_DIR = os.environ.get("DATA_DIR", BASE_DIR + "/criteo")
TRAIN_DIR = os.environ.get("TRAIN_DIR", DATA_DIR + "/train")
VAL_DIR = os.environ.get("VAL_DIR", DATA_DIR + "/val")
MODEL_DIR = os.environ.get("MODEL_DIR", BASE_DIR + "/model")

# Create the directories from the resolved paths rather than from hard-coded
# "/hugectr_e2e/..." literals, so that an environment override is honoured.
# DATA_DIR/train and DATA_DIR/val are listed explicitly because the split and
# preprocessing cells write there regardless of TRAIN_DIR / VAL_DIR.
for _directory in (
    BASE_DIR,
    DATA_DIR,
    os.path.join(DATA_DIR, "train"),
    os.path.join(DATA_DIR, "val"),
    TRAIN_DIR,
    VAL_DIR,
    MODEL_DIR,
):
    os.makedirs(_directory, exist_ok=True)
Download the Criteo data for 1 day:
# Uncomment the next two lines to download and decompress day 0 of the Criteo
# dataset. Otherwise set INPUT_DATA to an existing copy; it is soft-linked below.
#!wget -P $DATA_DIR https://storage.googleapis.com/criteo-cail-datasets/day_0.gz
#!gzip -d -c $DATA_DIR/day_0.gz > $DATA_DIR/day_0
INPUT_DATA = os.environ.get("INPUT_DATA", DATA_DIR + "/day_0")

# Destination the following cells read from.
_day0_path = os.path.join(DATA_DIR, "day_0")

# Link only when the data lives somewhere else and nothing is at the destination:
#  * with the default INPUT_DATA the source *is* the destination, so linking
#    would create a symlink that points at itself;
#  * re-running the cell would otherwise fail with "File exists".
if (
    os.path.abspath(INPUT_DATA) != os.path.abspath(_day0_path)
    and not os.path.lexists(_day0_path)
):
    os.symlink(INPUT_DATA, _day0_path)
ln: failed to create symbolic link '/hugectr_e2e/criteo/day_0': File exists
Unzip and split data
# Training split: the first 10,000,000 lines of the day-0 file.
!head -n 10000000 $DATA_DIR/day_0 > $DATA_DIR/train/train.txt
# Validation split: the last 2,000,000 lines of the same file.
!tail -n 2000000 $DATA_DIR/day_0 > $DATA_DIR/val/test.txt
Data Preprocessing using NVTabular
import sys
import argparse
import glob
import time
import numpy as np
import numba
import dask_cudf
import cudf
import nvtabular as nvt
from nvtabular.io import Shuffle
from nvtabular.ops import Categorify, Clip, FillMissing, Normalize, get_embedding_sizes
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from nvtabular.utils import pynvml_mem_size, device_mem_size
import warnings
import logging
# Emit timestamped log records and let the root logger pass every level.
logging.basicConfig(format='%(asctime)s %(message)s')
logging.getLogger().setLevel(logging.NOTSET)

# Dataset schema: one label, 13 continuous (I1..I13) and 26 categorical
# (C1..C26) columns.
NUM_INTEGER_COLUMNS = 13
NUM_CATEGORICAL_COLUMNS = 26
NUM_TOTAL_COLUMNS = 1 + NUM_INTEGER_COLUMNS + NUM_CATEGORICAL_COLUMNS

LABEL_COLUMNS = ['label']
CONTINUOUS_COLUMNS = [f"I{idx}" for idx in range(1, 14)]
CATEGORICAL_COLUMNS = [f"C{idx}" for idx in range(1, 27)]
COLUMNS = [*LABEL_COLUMNS, *CONTINUOUS_COLUMNS, *CATEGORICAL_COLUMNS]

# Column set for the /samples/criteo mode, which has no dense features.
criteo_COLUMN = [*LABEL_COLUMNS, *CATEGORICAL_COLUMNS]

# New feature-cross columns; each name joins two source columns with "_".
CROSS_COLUMNS = ["C1_C2", "C3_C4"]
# Start a single-machine, multi-GPU Dask-CUDA cluster (one worker per visible
# GPU) and connect a distributed client to it.

# Dask dashboard
dashboard_port = "8787"

# Deploy a Single-Machine Multi-GPU Cluster
protocol = "tcp"  # "tcp" or "ucx"

# NOTE: despite its name, NUM_GPUS holds the *list of GPU indices*, not a count.
if numba.cuda.is_available():
    NUM_GPUS = list(range(len(numba.cuda.gpus)))
else:
    NUM_GPUS = []

# Fail early with a clear message. Without this guard an empty device list
# becomes "" and the memory check below dies on int("").
if not NUM_GPUS:
    raise RuntimeError("No CUDA device is available; this notebook requires at least one GPU.")

visible_devices = ",".join([str(n) for n in NUM_GPUS])  # Select devices to place workers
device_limit_frac = 0.7  # Spill GPU-Worker memory to host at this limit.
device_pool_frac = 0.8
part_mem_frac = 0.15

# Turn the fractions above into absolute sizes based on the total device memory
# reported by device_mem_size (presumably for the current device; verify on
# heterogeneous machines).
device_size = device_mem_size(kind="total")
device_limit = int(device_limit_frac * device_size)
device_pool_size = int(device_pool_frac * device_size)
part_size = int(part_mem_frac * device_size)

# Check if any device memory is already occupied
for dev in NUM_GPUS:
    fmem = pynvml_mem_size(kind="free", index=dev)
    used = (device_size - fmem) / 1e9
    if used > 1.0:
        warnings.warn(f"BEWARE - {used} GB is already occupied on device {dev}!")

cluster = None  # (Optional) Specify existing scheduler port
if cluster is None:
    cluster = LocalCUDACluster(
        protocol=protocol,
        n_workers=len(NUM_GPUS),
        CUDA_VISIBLE_DEVICES=visible_devices,
        device_memory_limit=device_limit,
        dashboard_address=":" + dashboard_port,
        # Round the pool size down to a multiple of 256.
        rmm_pool_size=(device_pool_size // 256) * 256,
    )

# Create the distributed client
client = Client(cluster)
client
2023-01-06 04:03:09,380 Using selector: EpollSelector
2023-01-06 04:03:11,334 - distributed.preloading - INFO - Creating preload: dask_cuda.initialize
2023-01-06 04:03:11,334 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2023-01-06 04:03:11,343 - distributed.preloading - INFO - Creating preload: dask_cuda.initialize
2023-01-06 04:03:11,344 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2023-01-06 04:03:11,362 - distributed.preloading - INFO - Creating preload: dask_cuda.initialize
2023-01-06 04:03:11,362 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2023-01-06 04:03:11,381 - distributed.preloading - INFO - Creating preload: dask_cuda.initialize
2023-01-06 04:03:11,381 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2023-01-06 04:03:11,402 - distributed.preloading - INFO - Creating preload: dask_cuda.initialize
2023-01-06 04:03:11,402 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2023-01-06 04:03:11,411 - distributed.preloading - INFO - Creating preload: dask_cuda.initialize
2023-01-06 04:03:11,411 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2023-01-06 04:03:11,413 - distributed.preloading - INFO - Creating preload: dask_cuda.initialize
2023-01-06 04:03:11,413 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2023-01-06 04:03:11,419 - distributed.preloading - INFO - Creating preload: dask_cuda.initialize
2023-01-06 04:03:11,419 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
Client
Client-0a63fe32-8d77-11ed-88dd-54ab3adac0a5
Connection method: Cluster object | Cluster type: dask_cuda.LocalCUDACluster |
Dashboard: http://127.0.0.1:8787/status |
Cluster Info
LocalCUDACluster
46ea38f1
Dashboard: http://127.0.0.1:8787/status | Workers: 8 |
Total threads: 8 | Total memory: 503.79 GiB |
Status: running | Using processes: True |
Scheduler Info
Scheduler
Scheduler-5770dfa1-869c-4143-a327-dd59936a928a
Comm: tcp://127.0.0.1:44759 | Workers: 8 |
Dashboard: http://127.0.0.1:8787/status | Total threads: 8 |
Started: Just now | Total memory: 503.79 GiB |
Workers
Worker: 0
Comm: tcp://127.0.0.1:41549 | Total threads: 1 |
Dashboard: http://127.0.0.1:44715/status | Memory: 62.97 GiB |
Nanny: tcp://127.0.0.1:35427 | |
Local directory: /tmp/dask-worker-space/worker-j9o6tjnq | |
GPU: Tesla V100-SXM2-32GB | GPU memory: 31.75 GiB |
Worker: 1
Comm: tcp://127.0.0.1:33573 | Total threads: 1 |
Dashboard: http://127.0.0.1:36839/status | Memory: 62.97 GiB |
Nanny: tcp://127.0.0.1:43633 | |
Local directory: /tmp/dask-worker-space/worker-g048l4_5 | |
GPU: Tesla V100-SXM2-32GB | GPU memory: 31.75 GiB |
Worker: 2
Comm: tcp://127.0.0.1:40435 | Total threads: 1 |
Dashboard: http://127.0.0.1:37093/status | Memory: 62.97 GiB |
Nanny: tcp://127.0.0.1:41905 | |
Local directory: /tmp/dask-worker-space/worker-nb0yv_rz | |
GPU: Tesla V100-SXM2-32GB | GPU memory: 31.75 GiB |
Worker: 3
Comm: tcp://127.0.0.1:41707 | Total threads: 1 |
Dashboard: http://127.0.0.1:37285/status | Memory: 62.97 GiB |
Nanny: tcp://127.0.0.1:43925 | |
Local directory: /tmp/dask-worker-space/worker-8vibnk55 | |
GPU: Tesla V100-SXM2-32GB | GPU memory: 31.75 GiB |
Worker: 4
Comm: tcp://127.0.0.1:40165 | Total threads: 1 |
Dashboard: http://127.0.0.1:46549/status | Memory: 62.97 GiB |
Nanny: tcp://127.0.0.1:40305 | |
Local directory: /tmp/dask-worker-space/worker-p9qwcklg | |
GPU: Tesla V100-SXM2-32GB | GPU memory: 31.75 GiB |
Worker: 5
Comm: tcp://127.0.0.1:36597 | Total threads: 1 |
Dashboard: http://127.0.0.1:42895/status | Memory: 62.97 GiB |
Nanny: tcp://127.0.0.1:41439 | |
Local directory: /tmp/dask-worker-space/worker-y00valwq | |
GPU: Tesla V100-SXM2-32GB | GPU memory: 31.75 GiB |
Worker: 6
Comm: tcp://127.0.0.1:45495 | Total threads: 1 |
Dashboard: http://127.0.0.1:44953/status | Memory: 62.97 GiB |
Nanny: tcp://127.0.0.1:37091 | |
Local directory: /tmp/dask-worker-space/worker-4lj3i2cp | |
GPU: Tesla V100-SXM2-32GB | GPU memory: 31.75 GiB |
Worker: 7
Comm: tcp://127.0.0.1:46123 | Total threads: 1 |
Dashboard: http://127.0.0.1:40675/status | Memory: 62.97 GiB |
Nanny: tcp://127.0.0.1:36179 | |
Local directory: /tmp/dask-worker-space/worker-dvzzo2d1 | |
GPU: Tesla V100-SXM2-32GB | GPU memory: 31.75 GiB |
# Output directories for the preprocessed datasets.
train_output = os.path.join(DATA_DIR, "train")
print("Training output data: "+train_output)
val_output = os.path.join(DATA_DIR, "val")
print("Validation output data: "+val_output)

# Raw tab-separated inputs produced by the head/tail split.
train_input = os.path.join(DATA_DIR, "train/train.txt")
print("Training dataset: "+train_input)
val_input = os.path.join(DATA_DIR, "val/test.txt")

# Scratch directories that receive the intermediate txt -> parquet conversion.
PREPROCESS_DIR_temp_train = os.path.join(DATA_DIR, 'train/temp-parquet-after-conversion')
PREPROCESS_DIR_temp_val = os.path.join(DATA_DIR, "val/temp-parquet-after-conversion")
PREPROCESS_DIR_temp = [PREPROCESS_DIR_temp_train, PREPROCESS_DIR_temp_val]

# Make sure we have a clean parquet space for cudf conversion: drop leftovers
# from a previous run, then create the directory (and any missing parents) once.
for one_path in PREPROCESS_DIR_temp:
    if os.path.exists(one_path):
        shutil.rmtree(one_path)
    os.makedirs(one_path)
# Start the wall-clock timer for the whole preprocessing run.
runtime = time.time()

## train/valid txt to parquet
# Each pair is (raw tab-separated file, directory receiving the parquet output).
train_valid_paths = [
    (train_input, PREPROCESS_DIR_temp_train),
    (val_input, PREPROCESS_DIR_temp_val),
]
# The loop variable is named txt_path (not `input`) so that the builtin is not
# shadowed for the rest of the notebook session.
for txt_path, temp_output in train_valid_paths:
    ddf = dask_cudf.read_csv(
        txt_path,
        sep='\t',
        names=LABEL_COLUMNS + CONTINUOUS_COLUMNS + CATEGORICAL_COLUMNS,
    )
    # Label and continuous features are stored as float32.
    ddf["label"] = ddf['label'].astype('float32')
    ddf[CONTINUOUS_COLUMNS] = ddf[CONTINUOUS_COLUMNS].astype('float32')
    # Save it as parquet format for better memory usage
    ddf.to_parquet(temp_output, header=True)
##-----------------------------------##
# Column list extended with the feature-cross columns.
COLUMNS = LABEL_COLUMNS + CONTINUOUS_COLUMNS + CROSS_COLUMNS + CATEGORICAL_COLUMNS

# Parquet files written by the txt -> parquet conversion.
train_paths = glob.glob(os.path.join(PREPROCESS_DIR_temp_train, "*.parquet"))
valid_paths = glob.glob(os.path.join(PREPROCESS_DIR_temp_val, "*.parquet"))

# The two Categorify operators are kept in variables because their embedding
# sizes are queried after the workflow has been fitted.
categorify_op = Categorify()
cross_cat_op = Categorify(encode_type="combo")

# Continuous columns: fill missing values, clip at 0, then normalize.
cont_features = CONTINUOUS_COLUMNS >> FillMissing() >> Clip(min_value=0) >> Normalize()
# Categorical columns: encode with the plain Categorify operator.
cat_features = CATEGORICAL_COLUMNS >> categorify_op

# Assemble the workflow graph in order: label, continuous, cross, categorical.
# NOTE(review): `features` starts as the LABEL_COLUMNS list and is extended
# with `+=`; the statement order is kept exactly as in the original cell.
features = LABEL_COLUMNS
features += cont_features
if CROSS_COLUMNS:
    # "C1_C2" -> ["C1", "C2"]; each pair is combo-encoded as one feature.
    feature_pairs = [pair.split("_") for pair in CROSS_COLUMNS]
    for pair in feature_pairs:
        features += [pair] >> cross_cat_op
features += cat_features

workflow = nvt.Workflow(features)
logging.info("Preprocessing")

output_format = 'parquet'

# just for /samples/criteo model
train_ds_iterator = nvt.Dataset(train_paths, engine='parquet')
valid_ds_iterator = nvt.Dataset(valid_paths, engine='parquet')

shuffle = nvt.io.Shuffle.PER_PARTITION

# Output dtypes: int64 for categorical and cross columns, float32 for
# continuous and label columns.
dict_dtypes = {
    **{col: np.int64 for col in CATEGORICAL_COLUMNS},
    **{col: np.float32 for col in CONTINUOUS_COLUMNS},
    **{col: np.int64 for col in CROSS_COLUMNS},
    **{col: np.float32 for col in LABEL_COLUMNS},
}

conts = CONTINUOUS_COLUMNS


def _export_dataset(dataset, output_path):
    """Transform *dataset* with the fitted workflow and write it to *output_path*.

    Writes HugeCTR's format when ``output_format`` is ``'hugectr'``, otherwise
    parquet with the dtypes in ``dict_dtypes``.
    """
    transformed = workflow.transform(dataset)
    common_kwargs = dict(
        cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
        conts=conts,
        labels=LABEL_COLUMNS,
        output_path=output_path,
        shuffle=shuffle,
    )
    if output_format == 'hugectr':
        transformed.to_hugectr(**common_kwargs)
    else:
        transformed.to_parquet(dtypes=dict_dtypes, **common_kwargs)


def _collect_slot_sizes():
    """Return (categorical sizes dict, cross sizes dict, flat slot-size list).

    The flat list holds the first element of each embedding-size entry, with
    the categorical columns first and the cross columns after them.
    """
    cat_sizes = categorify_op.get_embedding_sizes(CATEGORICAL_COLUMNS)
    cross_sizes = cross_cat_op.get_embedding_sizes(CROSS_COLUMNS)
    flat_sizes = [cat_sizes[c][0] for c in CATEGORICAL_COLUMNS] + [cross_sizes[c][0] for c in CROSS_COLUMNS]
    return cat_sizes, cross_sizes, flat_sizes


logging.info('Train Datasets Preprocessing.....')

# Statistics are fitted on the training data only, then applied to both splits.
workflow.fit(train_ds_iterator)
_export_dataset(train_ds_iterator, train_output)

###Getting slot size###
#--------------------##
embeddings_dict_cat, embeddings_dict_cross, embeddings = _collect_slot_sizes()
print(embeddings)
##--------------------##

logging.info('Valid Datasets Preprocessing.....')

_export_dataset(valid_ds_iterator, val_output)

embeddings_dict_cat, embeddings_dict_cross, embeddings = _collect_slot_sizes()
print(embeddings)
##--------------------##
## Shutdown clusters
# Close the Dask client now that preprocessing has finished.
client.close()

# Elapsed wall-clock time since `runtime` was initialised with time.time().
runtime = time.time() - runtime

# One print call emitting the same three lines, in the same order.
print(
    "\nDask-NVTabular Criteo Preprocessing Done!",
    f"Runtime[s] | {runtime}",
    "======================================\n",
    sep="\n",
)
Training output data: /hugectr_e2e/criteo/train
Validation output data: /hugectr_e2e/criteo/val
Training dataset: /hugectr_e2e/criteo/train/train.txt
2023-01-06 04:03:19,407 Preprocessing
2023-01-06 04:03:19,749 Train Datasets Preprocessing.....
/usr/local/lib/python3.8/dist-packages/merlin/io/dataset.py:867: UserWarning: Only creating 5 files. Did not have enough partitions to create 8 files.
warnings.warn(
2023-01-06 04:03:25,189 Valid Datasets Preprocessing.....
[1234907, 19683, 13780, 6867, 18490, 4, 6264, 1235, 50, 854680, 114026, 75736, 11, 2159, 7533, 61, 4, 919, 15, 1307783, 404742, 1105613, 87714, 9032, 77, 34, 1581605, 1093030]
[1234907, 19683, 13780, 6867, 18490, 4, 6264, 1235, 50, 854680, 114026, 75736, 11, 2159, 7533, 61, 4, 919, 15, 1307783, 404742, 1105613, 87714, 9032, 77, 34, 1581605, 1093030]
Dask-NVTabular Criteo Preprocessing Done!
Runtime[s] | 11.454381227493286
======================================
# Tear down the Dask deployment: client first, then the local cluster object.
# NOTE(review): client.close() has already been called at the end of the
# preprocessing cell — confirm that shutdown() on a closed client is intended.
client.shutdown()
cluster.close()
### Record the slot size array
# `embeddings` is the per-column size list printed by the preprocessing cell
# (categorical columns first, then cross columns). It is reused as the
# slot-size array at inference time.
SLOT_SIZE_ARRAY = embeddings
Training a WDL model with HugeCTR
%%writefile './train.py'
import hugectr
import os
import argparse
from mpi4py import MPI
# Command-line interface of the training script.
# Both paths are mandatory, so argparse enforces them. Omitting one now yields
# a usage error instead of a TypeError from os.path.join(None, ...).
parser = argparse.ArgumentParser(description=("HugeCTR Training"))
parser.add_argument("--data_path", type=str, required=True, help="Input dataset path (Required)")
parser.add_argument("--model_path", type=str, required=True, help="Directory path to write output (Required)")
args = parser.parse_args()
# Per-slot sizes copied by hand from the list printed by the NVTabular
# preprocessing cell (26 categorical columns followed by the 2 cross columns).
# Update this literal whenever the preprocessing is re-run on different data.
SLOT_SIZE_ARRAY = [1234907, 19683, 13780, 6867, 18490, 4, 6264, 1235, 50, 854680, 114026, 75736, 11, 2159, 7533, 61, 4, 919, 15, 1307783, 404742, 1105613, 87714, 9032, 77, 34, 1581605, 1093030]
# Solver: batch sizes, learning rate and device layout.
# vvgpu = [[0]] — presumably a single node using GPU 0; confirm for multi-GPU runs.
# i64_input_key = True matches the int64 categorical keys written by preprocessing.
solver = hugectr.CreateSolver(max_eval_batches = 4000,
batchsize_eval = 2720,
batchsize = 2720,
lr = 0.001,
vvgpu = [[0]],
repeat_dataset = True,
i64_input_key = True)
# Data reader: parquet file lists under <data_path>/train and <data_path>/val.
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Parquet,
source = [os.path.join(args.data_path, "train/_file_list.txt")],
eval_source = os.path.join(args.data_path, "val/_file_list.txt"),
check_type = hugectr.Check_t.Non,
slot_size_array = SLOT_SIZE_ARRAY)
# Adam optimizer; the same object is also passed to both sparse embeddings.
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam,
update_type = hugectr.Update_t.Global,
beta1 = 0.9,
beta2 = 0.999,
epsilon = 0.0000001)
# Wide & Deep model graph. Layers are connected through their bottom/top names,
# so the order of the model.add(...) calls below is part of the definition.
model = hugectr.Model(solver, reader, optimizer)
# Inputs: 1 label, 13 dense features, and two sparse inputs —
# "wide_data" (2 slots) and "deep_data" (26 slots).
# NOTE(review): SLOT_SIZE_ARRAY lists the 26 categorical sizes before the 2
# cross sizes, while wide_data (2 slots) is declared before deep_data
# (26 slots) — confirm the reader assigns slots in the intended order.
model.add(hugectr.Input(label_dim = 1, label_name = "label",
dense_dim = 13, dense_name = "dense",
data_reader_sparse_param_array =
[hugectr.DataReaderSparseParam("wide_data", 1, True, 2),
hugectr.DataReaderSparseParam("deep_data", 2, False, 26)]))
# Wide branch embedding: vector size 1 per slot.
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
workspace_size_per_gpu_in_mb = 80,
embedding_vec_size = 1,
combiner = "sum",
sparse_embedding_name = "sparse_embedding2",
bottom_name = "wide_data",
optimizer = optimizer))
# Deep branch embedding: vector size 16 per slot.
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
workspace_size_per_gpu_in_mb = 1350,
embedding_vec_size = 16,
combiner = "sum",
sparse_embedding_name = "sparse_embedding1",
bottom_name = "deep_data",
optimizer = optimizer))
# Flatten the deep embeddings: 26 slots x 16 = 416 values per sample.
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
bottom_names = ["sparse_embedding1"],
top_names = ["reshape1"],
leading_dim=416))
# Flatten the wide embeddings: 2 slots x 1 = 2 values per sample.
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
bottom_names = ["sparse_embedding2"],
top_names = ["reshape2"],
leading_dim=2))
# Sum the wide values along axis 1 to get the wide-branch output.
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReduceSum,
bottom_names = ["reshape2"],
top_names = ["wide_redn"],
axis = 1))
# Deep branch input: flattened embeddings concatenated with the 13 dense features.
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
bottom_names = ["reshape1", "dense"],
top_names = ["concat1"]))
# MLP block 1: InnerProduct(1024) -> ReLU -> Dropout(0.5).
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["concat1"],
top_names = ["fc1"],
num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
bottom_names = ["fc1"],
top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
bottom_names = ["relu1"],
top_names = ["dropout1"],
dropout_rate=0.5))
# MLP block 2: InnerProduct(1024) -> ReLU -> Dropout(0.5).
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["dropout1"],
top_names = ["fc2"],
num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
bottom_names = ["fc2"],
top_names = ["relu2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
bottom_names = ["relu2"],
top_names = ["dropout2"],
dropout_rate=0.5))
# Deep-branch output layer: a single value per sample.
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["dropout2"],
top_names = ["fc3"],
num_output=1))
# Combine the deep output with the wide output.
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Add,
bottom_names = ["fc3", "wide_redn"],
top_names = ["add1"]))
# Binary cross-entropy loss on the combined output against the label.
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
bottom_names = ["add1", "label"],
top_names = ["loss"]))
# Build the model and print its layer summary.
model.compile()
model.summary()
# Train for 21000 iterations: log every 1000, evaluate every 4000, and write a
# snapshot at iteration 20000 under <model_path>/wdl/. The inference cell loads
# the resulting *_20000.model files.
model.fit(max_iter = 21000, display = 1000, eval_interval = 4000, snapshot = 20000, snapshot_prefix = os.path.join(args.model_path, "wdl/"))
# Export the graph definition that inference uses as its model config.
model.graph_to_json(graph_config_file = os.path.join(args.model_path, "wdl.json"))
Overwriting ./train.py
# Run the training script written above, passing the notebook's data and model directories.
!python train.py --data_path $DATA_DIR --model_path $MODEL_DIR
HugeCTR Version: 4.2
====================================================Model Init=====================================================
[HCTR][04:03:34.493][WARNING][RK0][main]: The model name is not specified when creating the solver.
[HCTR][04:03:34.493][INFO][RK0][main]: Global seed is 1268002110
[HCTR][04:03:34.495][INFO][RK0][main]: Device to NUMA mapping:
GPU 0 -> node 0
[HCTR][04:03:36.281][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][04:03:36.281][DEBUG][RK0][main]: [device 0] allocating 0.0000 GB, available 30.1804
[HCTR][04:03:36.281][INFO][RK0][main]: Start all2all warmup
[HCTR][04:03:36.281][INFO][RK0][main]: End all2all warmup
[HCTR][04:03:36.282][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][04:03:36.282][INFO][RK0][main]: Device 0: Tesla V100-SXM2-32GB
[HCTR][04:03:36.283][INFO][RK0][main]: num of DataReader workers for train: 1
[HCTR][04:03:36.283][INFO][RK0][main]: num of DataReader workers for eval: 1
[HCTR][04:03:36.283][DEBUG][RK0][main]: [device 0] allocating 0.0054 GB, available 29.9246
[HCTR][04:03:36.283][DEBUG][RK0][main]: [device 0] allocating 0.0054 GB, available 29.9187
[HCTR][04:03:36.284][DEBUG][RK0][main]: [device 0] allocating 0.0000 GB, available 29.9187
[HCTR][04:03:36.284][DEBUG][RK0][main]: [device 0] allocating 0.0000 GB, available 29.9187
[HCTR][04:03:36.284][INFO][RK0][main]: Vocabulary size: 7946054
[HCTR][04:03:36.285][INFO][RK0][main]: max_vocabulary_size_per_gpu_=6990506
[HCTR][04:03:36.298][DEBUG][RK0][main]: [device 0] allocating 0.0788 GB, available 29.5886
[HCTR][04:03:36.298][INFO][RK0][main]: max_vocabulary_size_per_gpu_=7372800
[HCTR][04:03:36.305][DEBUG][RK0][main]: [device 0] allocating 1.3516 GB, available 28.1101
[HCTR][04:03:36.306][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][04:03:36.308][DEBUG][RK0][main]: [device 0] allocating 0.2162 GB, available 27.8777
[HCTR][04:03:36.309][DEBUG][RK0][main]: [device 0] allocating 0.0056 GB, available 27.8718
[HCTR][04:03:44.735][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][04:03:44.735][INFO][RK0][main]: gpu0 init embedding done
[HCTR][04:03:44.735][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][04:03:44.738][INFO][RK0][main]: gpu0 init embedding done
[HCTR][04:03:44.738][DEBUG][RK0][main]: [device 0] allocating 0.0001 GB, available 27.8718
[HCTR][04:03:44.740][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][04:03:44.744][INFO][RK0][main]: Warm-up done
===================================================Model Summary===================================================
[HCTR][04:03:44.745][INFO][RK0][main]: Model structure on each GPU
Label Dense Sparse
label dense wide_data,deep_data
(2720,1) (2720,13)
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type Input Name Output Name Output Shape
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash wide_data sparse_embedding2 (2720,2,1)
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash deep_data sparse_embedding1 (2720,26,16)
------------------------------------------------------------------------------------------------------------------
Reshape sparse_embedding1 reshape1 (2720,416)
------------------------------------------------------------------------------------------------------------------
Reshape sparse_embedding2 reshape2 (2720,2)
------------------------------------------------------------------------------------------------------------------
ReduceSum reshape2 wide_redn (2720,1)
------------------------------------------------------------------------------------------------------------------
Concat reshape1 concat1 (2720,429)
dense
------------------------------------------------------------------------------------------------------------------
InnerProduct concat1 fc1 (2720,1024)
------------------------------------------------------------------------------------------------------------------
ReLU fc1 relu1 (2720,1024)
------------------------------------------------------------------------------------------------------------------
Dropout relu1 dropout1 (2720,1024)
------------------------------------------------------------------------------------------------------------------
InnerProduct dropout1 fc2 (2720,1024)
------------------------------------------------------------------------------------------------------------------
ReLU fc2 relu2 (2720,1024)
------------------------------------------------------------------------------------------------------------------
Dropout relu2 dropout2 (2720,1024)
------------------------------------------------------------------------------------------------------------------
InnerProduct dropout2 fc3 (2720,1)
------------------------------------------------------------------------------------------------------------------
Add fc3 add1 (2720,1)
wide_redn
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss add1 loss
label
------------------------------------------------------------------------------------------------------------------
=====================================================Model Fit=====================================================
[HCTR][04:03:44.745][INFO][RK0][main]: Use non-epoch mode with number of iterations: 21000
[HCTR][04:03:44.745][INFO][RK0][main]: Training batchsize: 2720, evaluation batchsize: 2720
[HCTR][04:03:44.745][INFO][RK0][main]: Evaluation interval: 4000, snapshot interval: 20000
[HCTR][04:03:44.745][INFO][RK0][main]: Dense network trainable: True
[HCTR][04:03:44.745][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][04:03:44.745][INFO][RK0][main]: Sparse embedding sparse_embedding2 trainable: True
[HCTR][04:03:44.745][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][04:03:44.745][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][04:03:44.745][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][04:03:44.745][INFO][RK0][main]: Training source file: /hugectr_e2e/criteo/train/_file_list.txt
[HCTR][04:03:44.745][INFO][RK0][main]: Evaluation source file: /hugectr_e2e/criteo/val/_file_list.txt
[HCTR][04:03:53.398][INFO][RK0][main]: Iter: 1000 Time(1000 iters): 8.64913s Loss: 0.121202 lr:0.001
[HCTR][04:04:02.135][INFO][RK0][main]: Iter: 2000 Time(1000 iters): 8.73203s Loss: 0.119842 lr:0.001
[HCTR][04:04:11.063][INFO][RK0][main]: Iter: 3000 Time(1000 iters): 8.92374s Loss: 0.140827 lr:0.001
[HCTR][04:04:20.136][INFO][RK0][main]: Iter: 4000 Time(1000 iters): 9.06817s Loss: 0.112631 lr:0.001
[HCTR][04:04:24.719][INFO][RK0][main]: Evaluation, AUC: 0.740936
[HCTR][04:04:24.719][INFO][RK0][main]: Eval Time for 4000 iters: 4.58235s
[HCTR][04:04:33.827][INFO][RK0][main]: Iter: 5000 Time(1000 iters): 13.687s Loss: 0.0898167 lr:0.001
[HCTR][04:04:42.918][INFO][RK0][main]: Iter: 6000 Time(1000 iters): 9.0867s Loss: 0.0845359 lr:0.001
[HCTR][04:04:52.021][INFO][RK0][main]: Iter: 7000 Time(1000 iters): 9.09826s Loss: 0.103367 lr:0.001
[HCTR][04:05:01.114][INFO][RK0][main]: Iter: 8000 Time(1000 iters): 9.0885s Loss: 0.0986205 lr:0.001
[HCTR][04:05:05.659][INFO][RK0][main]: Evaluation, AUC: 0.697036
[HCTR][04:05:05.659][INFO][RK0][main]: Eval Time for 4000 iters: 4.5436s
[HCTR][04:05:14.755][INFO][RK0][main]: Iter: 9000 Time(1000 iters): 13.6362s Loss: 0.0968494 lr:0.001
[HCTR][04:05:23.829][INFO][RK0][main]: Iter: 10000 Time(1000 iters): 9.06979s Loss: 0.107726 lr:0.001
[HCTR][04:05:32.906][INFO][RK0][main]: Iter: 11000 Time(1000 iters): 9.07225s Loss: 0.0775421 lr:0.001
[HCTR][04:05:42.011][INFO][RK0][main]: Iter: 12000 Time(1000 iters): 9.10076s Loss: 0.0998987 lr:0.001
[HCTR][04:05:46.553][INFO][RK0][main]: Evaluation, AUC: 0.695771
[HCTR][04:05:46.554][INFO][RK0][main]: Eval Time for 4000 iters: 4.54159s
[HCTR][04:05:55.651][INFO][RK0][main]: Iter: 13000 Time(1000 iters): 13.636s Loss: 0.0899532 lr:0.001
[HCTR][04:06:04.757][INFO][RK0][main]: Iter: 14000 Time(1000 iters): 9.1015s Loss: 0.0869383 lr:0.001
[HCTR][04:06:13.862][INFO][RK0][main]: Iter: 15000 Time(1000 iters): 9.10059s Loss: 0.0718445 lr:0.001
[HCTR][04:06:22.967][INFO][RK0][main]: Iter: 16000 Time(1000 iters): 9.10076s Loss: 0.0784668 lr:0.001
[HCTR][04:06:27.514][INFO][RK0][main]: Evaluation, AUC: 0.68094
[HCTR][04:06:27.514][INFO][RK0][main]: Eval Time for 4000 iters: 4.54556s
[HCTR][04:06:36.619][INFO][RK0][main]: Iter: 17000 Time(1000 iters): 13.6469s Loss: 0.0711678 lr:0.001
[HCTR][04:06:45.710][INFO][RK0][main]: Iter: 18000 Time(1000 iters): 9.08659s Loss: 0.0941869 lr:0.001
[HCTR][04:06:54.800][INFO][RK0][main]: Iter: 19000 Time(1000 iters): 9.08513s Loss: 0.0780168 lr:0.001
[HCTR][04:07:03.884][INFO][RK0][main]: Iter: 20000 Time(1000 iters): 9.07991s Loss: 0.0978516 lr:0.001
[HCTR][04:07:08.432][INFO][RK0][main]: Evaluation, AUC: 0.67476
[HCTR][04:07:08.432][INFO][RK0][main]: Eval Time for 4000 iters: 4.54741s
[HCTR][04:07:08.433][INFO][RK0][main]: Using Local file system backend.
[HCTR][04:07:08.456][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][04:07:08.473][INFO][RK0][main]: Using Local file system backend.
[HCTR][04:07:08.627][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][04:07:08.799][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][04:07:08.814][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][04:07:08.814][INFO][RK0][main]: Using Local file system backend.
[HCTR][04:07:08.829][INFO][RK0][main]: Done
[HCTR][04:07:08.832][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][04:07:08.832][INFO][RK0][main]: Using Local file system backend.
[HCTR][04:07:08.845][INFO][RK0][main]: Done
[HCTR][04:07:09.109][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][04:07:09.109][INFO][RK0][main]: Using Local file system backend.
[HCTR][04:07:09.318][INFO][RK0][main]: Done
[HCTR][04:07:09.597][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][04:07:09.597][INFO][RK0][main]: Using Local file system backend.
[HCTR][04:07:09.808][INFO][RK0][main]: Done
[HCTR][04:07:09.825][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HCTR][04:07:09.826][INFO][RK0][main]: Using Local file system backend.
[HCTR][04:07:09.829][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][04:07:09.831][INFO][RK0][main]: Using Local file system backend.
[HCTR][04:07:09.837][INFO][RK0][main]: Dumping dense optimizer states to file, successful
[HCTR][04:07:18.897][INFO][RK0][main]: Finish 21000 iterations with batchsize: 2720 in 214.15s.
[HCTR][04:07:18.897][INFO][RK0][main]: Save the model graph to /hugectr_e2e/model/wdl.json successfully
Load the model into HPS and run inference with HugeCTR
from hugectr.inference import InferenceModel, InferenceParams
import hugectr
import os
# Graph definition exported by train.py via graph_to_json.
model_config = os.path.join(MODEL_DIR, "wdl.json")

# Inference configuration. The dense and sparse model files are the snapshot
# written at iteration 20000 under MODEL_DIR/wdl/.
inference_params = InferenceParams(
    model_name="wdl",
    max_batchsize=1024,
    hit_rate_threshold=1.0,
    dense_model_file=os.path.join(MODEL_DIR, "wdl/_dense_20000.model"),
    sparse_model_files=[
        os.path.join(MODEL_DIR, "wdl/0_sparse_20000.model"),
        os.path.join(MODEL_DIR, "wdl/1_sparse_20000.model"),
    ],
    deployed_devices=[0],
    use_gpu_embedding_cache=True,
    cache_size_percentage=1.0,
    i64_input_key=True,
)
inference_model = InferenceModel(model_config, inference_params)

# Predict on 10 batches of the validation set. The file list is resolved from
# DATA_DIR (not a hard-coded "/hugectr_e2e/..." path), so a DATA_DIR override is
# honoured here exactly as it is by preprocessing and training.
pred = inference_model.predict(
    10,
    os.path.join(DATA_DIR, "val/_file_list.txt"),
    hugectr.DataReaderType_t.Parquet,
    hugectr.Check_t.Non,
    SLOT_SIZE_ARRAY,
)
print(pred.shape)
print(pred)
[HCTR][05:22:53.279][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
[HCTR][05:22:53.279][INFO][RK0][main]: Global seed is 2968606722
[HCTR][05:22:53.279][INFO][RK0][main]: Device to NUMA mapping:
GPU 0 -> node 0
[HCTR][05:22:53.317][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][05:22:53.317][DEBUG][RK0][main]: [device 0] allocating 0.0000 GB, available 29.8757
[HCTR][05:22:53.317][INFO][RK0][main]: Start all2all warmup
[HCTR][05:22:53.317][INFO][RK0][main]: End all2all warmup
[HCTR][05:22:53.318][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
[HCTR][05:22:53.318][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
====================================================HPS Create====================================================
[HCTR][05:22:53.318][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][05:22:53.318][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][05:22:53.318][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][05:22:53.318][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][05:22:53.318][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][05:22:53.318][INFO][RK0][main]: Using Local file system backend.
[HCTR][05:22:53.659][INFO][RK0][main]: Table: hps_et.wdl.sparse_embedding2; cached 2327936 / 2327936 embeddings in volatile database (HashMapBackend); load: 2327936 / 18446744073709551615 (0.00%).
[HCTR][05:22:53.659][INFO][RK0][main]: Using Local file system backend.
[HCTR][05:22:54.219][INFO][RK0][main]: Table: hps_et.wdl.sparse_embedding1; cached 4169063 / 4169063 embeddings in volatile database (HashMapBackend); load: 4169063 / 18446744073709551615 (0.00%).
[HCTR][05:22:54.235][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][05:22:54.235][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][05:22:54.239][INFO][RK0][main]: Model name: wdl
[HCTR][05:22:54.239][INFO][RK0][main]: Max batch size: 1024
[HCTR][05:22:54.239][INFO][RK0][main]: Number of embedding tables: 2
[HCTR][05:22:54.239][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][05:22:54.239][INFO][RK0][main]: Use static table: False
[HCTR][05:22:54.239][INFO][RK0][main]: Use I64 input key: True
[HCTR][05:22:54.239][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][05:22:54.239][INFO][RK0][main]: The size of thread pool: 80
[HCTR][05:22:54.239][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][05:22:54.239][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][05:22:54.239][INFO][RK0][main]: The refresh percentage : 0.000000
[HCTR][05:22:54.246][INFO][RK0][main]: Model name: wdl
[HCTR][05:22:54.246][INFO][RK0][main]: Use mixed precision: False
[HCTR][05:22:54.246][INFO][RK0][main]: Use cuda graph: True
[HCTR][05:22:54.247][INFO][RK0][main]: Max batchsize: 1024
[HCTR][05:22:54.247][INFO][RK0][main]: Use I64 input key: True
[HCTR][05:22:54.247][INFO][RK0][main]: start create embedding for inference
[HCTR][05:22:54.247][INFO][RK0][main]: sparse_input name wide_data
[HCTR][05:22:54.247][INFO][RK0][main]: sparse_input name deep_data
[HCTR][05:22:54.247][INFO][RK0][main]: create embedding for inference success
[HCTR][05:22:54.247][DEBUG][RK0][main]: [device 0] allocating 0.0049 GB, available 29.2468
[HCTR][05:22:54.248][INFO][RK0][main]: Inference stage skip BinaryCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][05:22:54.248][DEBUG][RK0][main]: [device 0] allocating 0.0388 GB, available 29.2000
[HCTR][05:22:55.955][DEBUG][RK0][main]: [device 0] allocating 0.0001 GB, available 29.1960
[HCTR][05:22:56.325][INFO][RK0][main]: Create inference data reader on 1 GPU(s)
[HCTR][05:22:56.325][INFO][RK0][main]: num of DataReader workers: 1
[HCTR][05:22:56.325][DEBUG][RK0][main]: [device 0] allocating 0.0020 GB, available 29.9617
[HCTR][05:22:56.325][DEBUG][RK0][main]: [device 0] allocating 0.0000 GB, available 29.9617
[HCTR][05:22:56.326][INFO][RK0][main]: Vocabulary size: 7946054
[HCTR][05:22:56.352][INFO][RK0][main]: Inference time for 10 batches: 0.02287
(10240, 1)
[[2.0730977e-07]
[2.4348024e-05]
[4.8165547e-04]
...
[1.3334073e-10]
[8.7153958e-03]
[2.7467359e-02]]