# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
HugeCTR Wide and Deep Model with Criteo
Overview
In this notebook, we provide a tutorial that shows how to train a wide and deep model with the HugeCTR high-level Python API, using the original Criteo dataset as training data. We also show how to produce prediction results backed by different types of local databases.
Setup HugeCTR
To set up the environment, refer to HugeCTR Example Notebooks and follow the instructions there before running the cells below.
Dataset Preprocessing
Generate training and validation data folders
# Define the data folders that store the original and the preprocessed data
# Standard Libraries
import os
from time import time
import re
import shutil
import glob
import warnings
BASE_DIR = "/wdl_train"
train_path = os.path.join(BASE_DIR, "train")
val_path = os.path.join(BASE_DIR, "val")
CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
n_workers = len(CUDA_VISIBLE_DEVICES.split(","))
frac_size = 0.15
allow_multi_gpu = False
use_rmm_pool = False
max_day = None # (Optional) -- Limit the dataset to day 0-max_day for debugging
if os.path.isdir(train_path):
shutil.rmtree(train_path)
os.makedirs(train_path)
if os.path.isdir(val_path):
shutil.rmtree(val_path)
os.makedirs(val_path)
!ls -l $train_path
total 0
Download the original Criteo dataset
!apt-get install wget
!wget -P $train_path https://storage.googleapis.com/criteo-cail-datasets/day_0.gz
Split the dataset into training and validation sets; decompress day_0.gz first if you have not already done so.
#!gzip -d -c $train_path/day_0.gz > day_0
!head -n 45840617 day_0 > $train_path/train.txt
!tail -n 2000000 day_0 > $val_path/test.txt
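Before preprocessing, it can help to sanity-check the raw layout: each line of a Criteo day file is a label followed by 13 integer features and 26 hashed categorical features, all tab-separated. A minimal sketch (assuming day_0 sits in the current working directory, as in the commands above):

```python
import pandas as pd

# Column layout of the raw Criteo day files: label, I1-I13 integers, C1-C26 hashed categoricals.
cols = ["label"] + [f"I{i}" for i in range(1, 14)] + [f"C{i}" for i in range(1, 27)]

# Read only a handful of rows; the full file is tens of gigabytes.
sample = pd.read_csv("day_0", sep="\t", names=cols, nrows=5)
print(sample.dtypes)
print(sample.head())
```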
Preprocessing with NVTabular
%%writefile /wdl_train/preprocess.py
import os
import sys
import argparse
import glob
import time
import numpy as np
import pandas as pd
import shutil
import dask_cudf
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
import cudf
import rmm
import nvtabular as nvt
from nvtabular.io import Shuffle
from nvtabular.utils import device_mem_size
from nvtabular.ops import Categorify, Clip, FillMissing, LambdaOp, Normalize, Rename, Operator, get_embedding_sizes
#%load_ext memory_profiler
import logging
logging.basicConfig(format='%(asctime)s %(message)s')
logging.root.setLevel(logging.NOTSET)
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('asyncio').setLevel(logging.WARNING)
# define dataset schema
CATEGORICAL_COLUMNS=["C" + str(x) for x in range(1, 27)]
CONTINUOUS_COLUMNS=["I" + str(x) for x in range(1, 14)]
LABEL_COLUMNS = ['label']
COLUMNS = LABEL_COLUMNS + CONTINUOUS_COLUMNS + CATEGORICAL_COLUMNS
#/samples/criteo mode doesn't have dense features
criteo_COLUMN=LABEL_COLUMNS + CATEGORICAL_COLUMNS
#For new feature cross columns
CROSS_COLUMNS = []
NUM_INTEGER_COLUMNS = 13
NUM_CATEGORICAL_COLUMNS = 26
NUM_TOTAL_COLUMNS = 1 + NUM_INTEGER_COLUMNS + NUM_CATEGORICAL_COLUMNS
# Initialize RMM pool on ALL workers
def setup_rmm_pool(client, pool_size):
client.run(rmm.reinitialize, pool_allocator=True, initial_pool_size=pool_size)
return None
# Convert a byte count to the given unit ('g' = gigabytes); used to report the partition size
def bytesto(bytes, to, bsize=1024):
a = {'k' : 1, 'm': 2, 'g' : 3, 't' : 4, 'p' : 5, 'e' : 6 }
return float(bytes) / (bsize ** a[to])
#process the data with NVTabular
def process_NVT(args):
if args.feature_cross_list:
feature_pairs = [pair.split("_") for pair in args.feature_cross_list.split(",")]
for pair in feature_pairs:
CROSS_COLUMNS.append(pair[0]+'_'+pair[1])
logging.info('NVTabular processing')
train_input = os.path.join(args.data_path, "train/train.txt")
val_input = os.path.join(args.data_path, "val/test.txt")
PREPROCESS_DIR_temp_train = os.path.join(args.out_path, 'train/temp-parquet-after-conversion')
PREPROCESS_DIR_temp_val = os.path.join(args.out_path, 'val/temp-parquet-after-conversion')
PREPROCESS_DIR_temp = [PREPROCESS_DIR_temp_train, PREPROCESS_DIR_temp_val]
train_output = os.path.join(args.out_path, "train")
val_output = os.path.join(args.out_path, "val")
# Make sure we have a clean parquet space for cudf conversion
for one_path in PREPROCESS_DIR_temp:
if os.path.exists(one_path):
shutil.rmtree(one_path)
os.mkdir(one_path)
## Get Dask Client
# Deploy a Single-Machine Multi-GPU Cluster
device_size = device_mem_size(kind="total")
cluster = None
if args.protocol == "ucx":
UCX_TLS = os.environ.get("UCX_TLS", "tcp,cuda_copy,cuda_ipc,sockcm")
os.environ["UCX_TLS"] = UCX_TLS
cluster = LocalCUDACluster(
protocol = args.protocol,
CUDA_VISIBLE_DEVICES = args.devices,
n_workers = len(args.devices.split(",")),
enable_nvlink=True,
device_memory_limit = int(device_size * args.device_limit_frac),
dashboard_address=":" + args.dashboard_port
)
else:
cluster = LocalCUDACluster(
protocol = args.protocol,
n_workers = len(args.devices.split(",")),
CUDA_VISIBLE_DEVICES = args.devices,
device_memory_limit = int(device_size * args.device_limit_frac),
dashboard_address=":" + args.dashboard_port
)
# Create the distributed client
client = Client(cluster)
if args.device_pool_frac > 0.01:
setup_rmm_pool(client, int(args.device_pool_frac*device_size))
#calculate the total processing time
runtime = time.time()
#test dataset without the label feature
if args.dataset_type == 'test':
global LABEL_COLUMNS
LABEL_COLUMNS = []
##-----------------------------------##
# Dask rapids converts txt to parquet
# Dask cudf dataframe = ddf
## train/valid txt to parquet
train_valid_paths = [(train_input,PREPROCESS_DIR_temp_train),(val_input,PREPROCESS_DIR_temp_val)]
for input, temp_output in train_valid_paths:
ddf = dask_cudf.read_csv(input,sep='\t',names=LABEL_COLUMNS + CONTINUOUS_COLUMNS + CATEGORICAL_COLUMNS)
## Convert label col to FP32
if args.parquet_format and args.dataset_type == 'train':
ddf["label"] = ddf['label'].astype('float32')
# Save it as parquet format for better memory usage
ddf.to_parquet(temp_output,header=True)
##-----------------------------------##
COLUMNS = LABEL_COLUMNS + CONTINUOUS_COLUMNS + CROSS_COLUMNS + CATEGORICAL_COLUMNS
train_paths = glob.glob(os.path.join(PREPROCESS_DIR_temp_train, "*.parquet"))
valid_paths = glob.glob(os.path.join(PREPROCESS_DIR_temp_val, "*.parquet"))
categorify_op = Categorify(freq_threshold=args.freq_limit)
cat_features = CATEGORICAL_COLUMNS >> categorify_op
cont_features = CONTINUOUS_COLUMNS >> FillMissing() >> Clip(min_value=0) >> Normalize()
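# encode_type="combo" makes Categorify treat each column pair as one combined category,
# producing a single encoded column (and a single embedding table) per cross feature.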
cross_cat_op = Categorify(encode_type="combo", freq_threshold=args.freq_limit)
features = LABEL_COLUMNS
if args.criteo_mode == 0:
features += cont_features
if args.feature_cross_list:
feature_pairs = [pair.split("_") for pair in args.feature_cross_list.split(",")]
for pair in feature_pairs:
features += [pair] >> cross_cat_op
features += cat_features
workflow = nvt.Workflow(features, client=client)
logging.info("Preprocessing")
output_format = 'hugectr'
if args.parquet_format:
output_format = 'parquet'
# just for /samples/criteo model
train_ds_iterator = nvt.Dataset(train_paths, engine='parquet', part_size=int(args.part_mem_frac * device_size))
valid_ds_iterator = nvt.Dataset(valid_paths, engine='parquet', part_size=int(args.part_mem_frac * device_size))
shuffle = None
if args.shuffle == "PER_WORKER":
shuffle = nvt.io.Shuffle.PER_WORKER
elif args.shuffle == "PER_PARTITION":
shuffle = nvt.io.Shuffle.PER_PARTITION
logging.info('Train Datasets Preprocessing.....')
dict_dtypes = {}
for col in CATEGORICAL_COLUMNS:
dict_dtypes[col] = np.int64
if not args.criteo_mode:
for col in CONTINUOUS_COLUMNS:
dict_dtypes[col] = np.float32
for col in CROSS_COLUMNS:
dict_dtypes[col] = np.int64
for col in LABEL_COLUMNS:
dict_dtypes[col] = np.float32
conts = CONTINUOUS_COLUMNS if not args.criteo_mode else []
workflow.fit(train_ds_iterator)
if output_format == 'hugectr':
workflow.transform(train_ds_iterator).to_hugectr(
cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
conts=conts,
labels=LABEL_COLUMNS,
output_path=train_output,
shuffle=shuffle,
out_files_per_proc=args.out_files_per_proc,
num_threads=args.num_io_threads)
else:
workflow.transform(train_ds_iterator).to_parquet(
output_path=train_output,
dtypes=dict_dtypes,
cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
conts=conts,
labels=LABEL_COLUMNS,
shuffle=shuffle,
out_files_per_proc=args.out_files_per_proc,
num_threads=args.num_io_threads)
###Getting slot size###
#--------------------##
embeddings_dict_cat = categorify_op.get_embedding_sizes(CATEGORICAL_COLUMNS)
embeddings_dict_cross = cross_cat_op.get_embedding_sizes(CROSS_COLUMNS)
embeddings = [embeddings_dict_cat[c][0] for c in CATEGORICAL_COLUMNS] + [embeddings_dict_cross[c][0] for c in CROSS_COLUMNS]
print(embeddings)
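# These per-slot cardinalities are the values the HugeCTR training script later expects for slot_size_array.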
##--------------------##
logging.info('Valid Datasets Preprocessing.....')
if output_format == 'hugectr':
workflow.transform(valid_ds_iterator).to_hugectr(
cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
conts=conts,
labels=LABEL_COLUMNS,
output_path=val_output,
shuffle=shuffle,
out_files_per_proc=args.out_files_per_proc,
num_threads=args.num_io_threads)
else:
workflow.transform(valid_ds_iterator).to_parquet(
output_path=val_output,
dtypes=dict_dtypes,
cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
conts=conts,
labels=LABEL_COLUMNS,
shuffle=shuffle,
out_files_per_proc=args.out_files_per_proc,
num_threads=args.num_io_threads)
embeddings_dict_cat = categorify_op.get_embedding_sizes(CATEGORICAL_COLUMNS)
embeddings_dict_cross = cross_cat_op.get_embedding_sizes(CROSS_COLUMNS)
embeddings = [embeddings_dict_cat[c][0] for c in CATEGORICAL_COLUMNS] + [embeddings_dict_cross[c][0] for c in CROSS_COLUMNS]
print(embeddings)
##--------------------##
## Shutdown clusters
client.close()
logging.info('NVTabular processing done')
runtime = time.time() - runtime
print("\nDask-NVTabular Criteo Preprocessing")
print("--------------------------------------")
print(f"data_path | {args.data_path}")
print(f"output_path | {args.out_path}")
print(f"partition size | {'%.2f GB'%bytesto(int(args.part_mem_frac * device_size),'g')}")
print(f"protocol | {args.protocol}")
print(f"device(s) | {args.devices}")
print(f"rmm-pool-frac | {(args.device_pool_frac)}")
print(f"out-files-per-proc | {args.out_files_per_proc}")
print(f"num_io_threads | {args.num_io_threads}")
print(f"shuffle | {args.shuffle}")
print("======================================")
print(f"Runtime[s] | {runtime}")
print("======================================\n")
def parse_args():
parser = argparse.ArgumentParser(description=("Multi-GPU Criteo Preprocessing"))
#
# System Options
#
parser.add_argument("--data_path", type=str, help="Input dataset path (Required)")
parser.add_argument("--out_path", type=str, help="Directory path to write output (Required)")
parser.add_argument(
"-d",
"--devices",
default=os.environ.get("CUDA_VISIBLE_DEVICES", "0"),
type=str,
help='Comma-separated list of visible devices (e.g. "0,1,2,3"). '
)
parser.add_argument(
"-p",
"--protocol",
choices=["tcp", "ucx"],
default="tcp",
type=str,
help="Communication protocol to use (Default 'tcp')",
)
parser.add_argument(
"--device_limit_frac",
default=0.5,
type=float,
help="Worker device-memory limit as a fraction of GPU capacity (Default 0.8). "
)
parser.add_argument(
"--device_pool_frac",
default=0.9,
type=float,
help="RMM pool size for each worker as a fraction of GPU capacity (Default 0.9). "
"The RMM pool frac is the same for all GPUs, make sure each one has enough memory size",
)
parser.add_argument(
"--num_io_threads",
default=0,
type=int,
help="Number of threads to use when writing output data (Default 0). "
"If 0 is specified, multi-threading will not be used for IO.",
)
#
# Data-Decomposition Parameters
#
parser.add_argument(
"--part_mem_frac",
default=0.125,
type=float,
help="Maximum size desired for dataset partitions as a fraction "
"of GPU capacity (Default 0.125)",
)
parser.add_argument(
"--out_files_per_proc",
default=1,
type=int,
help="Number of output files to write on each worker (Default 1)",
)
#
# Preprocessing Options
#
parser.add_argument(
"-f",
"--freq_limit",
default=0,
type=int,
help="Frequency limit for categorical encoding (Default 0)",
)
parser.add_argument(
"-s",
"--shuffle",
choices=["PER_WORKER", "PER_PARTITION", "NONE"],
default="PER_PARTITION",
help="Shuffle algorithm to use when writing output data to disk (Default PER_PARTITION)",
)
parser.add_argument(
"--feature_cross_list", default=None, type=str, help="List of feature crossing cols (e.g. C1_C2, C3_C4)"
)
#
# Diagnostics Options
#
parser.add_argument(
"--profile",
metavar="PATH",
default=None,
type=str,
help="Specify a file path to export a Dask profile report (E.g. dask-report.html)."
"If this option is excluded from the command, not profile will be exported",
)
parser.add_argument(
"--dashboard_port",
default="8787",
type=str,
help="Specify the desired port of Dask's diagnostics-dashboard (Default `3787`). "
"The dashboard will be hosted at http://<IP>:<PORT>/status",
)
#
# Format
#
parser.add_argument('--criteo_mode', type=int, default=0)
parser.add_argument('--parquet_format', type=int, default=1)
parser.add_argument('--dataset_type', type=str, default='train')
args = parser.parse_args()
args.n_workers = len(args.devices.split(","))
return args
if __name__ == '__main__':
args = parse_args()
process_NVT(args)
Writing /wdl_train/preprocess.py
!python3 /wdl_train/preprocess.py --data_path /wdl_train/ \
--out_path /wdl_train/ --freq_limit 6 --feature_cross_list C1_C2,C3_C4 \
--device_pool_frac 0.5 --devices '0' --num_io_threads 2
2023-02-10 10:32:34,808 NVTabular processing
2023-02-10 10:32:36,590 - distributed.preloading - INFO - Creating preload: dask_cuda.initialize
2023-02-10 10:32:36,590 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2023-02-10 10:32:36,604 Unable to start CUDA Context
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/pynvml/nvml.py", line 782, in _nvmlGetFunctionPointer
_nvmlGetFunctionPointer_cache[name] = getattr(nvmlLib, name)
File "/usr/lib/python3.8/ctypes/__init__.py", line 386, in __getattr__
func = self.__getitem__(name)
File "/usr/lib/python3.8/ctypes/__init__.py", line 391, in __getitem__
func = self._FuncPtr((name_or_ordinal, self))
AttributeError: /usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1: undefined symbol: nvmlDeviceGetComputeRunningProcesses_v2
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/dask_cuda/initialize.py", line 41, in _create_cuda_context
ctx = has_cuda_context()
File "/usr/local/lib/python3.8/dist-packages/distributed/diagnostics/nvml.py", line 164, in has_cuda_context
running_processes = pynvml.nvmlDeviceGetComputeRunningProcesses_v2(handle)
File "/usr/local/lib/python3.8/dist-packages/pynvml/nvml.py", line 2191, in nvmlDeviceGetComputeRunningProcesses_v2
fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeRunningProcesses_v2")
File "/usr/local/lib/python3.8/dist-packages/pynvml/nvml.py", line 785, in _nvmlGetFunctionPointer
raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND)
pynvml.nvml.NVMLError_FunctionNotFound: Function Not Found
/usr/local/lib/python3.8/dist-packages/merlin/core/utils.py:384: FutureWarning: The `client` argument is deprecated from DaskExecutor and will be removed in a future version of NVTabular. By default, a global client in the same python context will be detected automatically, and `merlin.utils.set_dask_client` (as well as `Distributed` and `Serial`) can be used for explicit control.
warnings.warn(
2023-02-10 10:32:54,260 Preprocessing
2023-02-10 10:32:54,521 Train Datasets Preprocessing.....
[249058, 19561, 14212, 6890, 18592, 4, 6356, 1254, 52, 226170, 80508, 72308, 11, 2169, 7597, 61, 4, 923, 15, 249619, 168974, 243480, 68212, 9169, 75, 34, 281564, 415262]
2023-02-10 10:34:09,155 Valid Datasets Preprocessing.....
[249058, 19561, 14212, 6890, 18592, 4, 6356, 1254, 52, 226170, 80508, 72308, 11, 2169, 7597, 61, 4, 923, 15, 249619, 168974, 243480, 68212, 9169, 75, 34, 281564, 415262]
2023-02-10 10:34:10,596 NVTabular processing done
Dask-NVTabular Criteo Preprocessing
--------------------------------------
data_path | /wdl_train/
output_path | /wdl_train/
partition size | 3.97 GB
protocol | tcp
device(s) | 0
rmm-pool-frac | 0.5
out-files-per-proc | 1
num_io_threads | 2
shuffle | PER_PARTITION
======================================
Runtime[s] | 92.90126419067383
======================================
Check the preprocessed training data
!ls -ll /wdl_train/train
total 14449264
-rw-r--r-- 1 root root 34 Feb 10 10:34 _file_list.txt
-rw-r--r-- 1 root root 450893 Feb 10 10:34 _metadata
-rw-r--r-- 1 root root 1510 Feb 10 10:34 _metadata.json
-rw-r--r-- 1 root root 3245838178 Feb 10 10:34 part_0.parquet
-rw-r--r-- 1 root root 27296 Feb 10 10:33 schema.pbtxt
drwxr-xr-x 2 root root 4096 Feb 10 10:32 temp-parquet-after-conversion
-rw-r--r-- 1 root root 11549710546 Feb 10 10:32 train.txt
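Optionally, inspect a few rows of the preprocessed parquet to confirm the dtypes requested in preprocess.py (int64 categoricals, float32 continuous features and label). A small sketch, assuming the output paths used above:

```python
import pandas as pd

# Read only a few columns; part_0.parquet is several gigabytes.
# C1_C2 and C3_C4 are the cross features added by --feature_cross_list.
cols = ["label", "I1", "C1", "C1_C2", "C3_C4"]
df = pd.read_parquet("/wdl_train/train/part_0.parquet", columns=cols)
print(df.dtypes)   # label and I1 should be float32; C1, C1_C2, C3_C4 should be int64
print(df.head(3))
```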
WDL Model Training
%%writefile './model.py'
import hugectr
#from mpi4py import MPI
solver = hugectr.CreateSolver(max_eval_batches = 4000,
batchsize_eval = 2720,
batchsize = 2720,
lr = 0.001,
vvgpu = [[2]],
repeat_dataset = True,
i64_input_key = True)
reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Parquet,
source = ["/wdl_train/train/_file_list.txt"],
eval_source = "/wdl_train/val/_file_list.txt",
check_type = hugectr.Check_t.Non,
slot_size_array = [249058, 19561, 14212, 6890, 18592, 4, 6356, 1254, 52, 226170, 80508, 72308, 11, 2169, 7597, 61, 4, 923, 15, 249619, 168974, 243480, 68212, 9169, 75, 34, 278018, 415262])
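# slot_size_array above lists the cardinality of every categorical slot: the 26 base C* columns
# followed by the two cross columns (C1_C2, C3_C4), in the same order as the preprocessing output.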
optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam,
update_type = hugectr.Update_t.Global,
beta1 = 0.9,
beta2 = 0.999,
epsilon = 0.0000001)
model = hugectr.Model(solver, reader, optimizer)
model.add(hugectr.Input(label_dim = 1, label_name = "label",
dense_dim = 13, dense_name = "dense",
data_reader_sparse_param_array =
[hugectr.DataReaderSparseParam("wide_data", 1, True, 2),
hugectr.DataReaderSparseParam("deep_data", 2, False, 26)]))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
workspace_size_per_gpu_in_mb = 24,
embedding_vec_size = 1,
combiner = "sum",
sparse_embedding_name = "sparse_embedding2",
bottom_name = "wide_data",
optimizer = optimizer))
model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
workspace_size_per_gpu_in_mb = 405,
embedding_vec_size = 16,
combiner = "sum",
sparse_embedding_name = "sparse_embedding1",
bottom_name = "deep_data",
optimizer = optimizer))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
bottom_names = ["sparse_embedding1"],
top_names = ["reshape1"],
leading_dim=416))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
bottom_names = ["sparse_embedding2"],
top_names = ["reshape2"],
leading_dim=2))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReduceSum,
bottom_names = ["reshape2"],
top_names = ["wide_redn"],
axis = 1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,
bottom_names = ["reshape1", "dense"],
top_names = ["concat1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["concat1"],
top_names = ["fc1"],
num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
bottom_names = ["fc1"],
top_names = ["relu1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
bottom_names = ["relu1"],
top_names = ["dropout1"],
dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["dropout1"],
top_names = ["fc2"],
num_output=1024))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,
bottom_names = ["fc2"],
top_names = ["relu2"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,
bottom_names = ["relu2"],
top_names = ["dropout2"],
dropout_rate=0.5))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,
bottom_names = ["dropout2"],
top_names = ["fc3"],
num_output=1))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Add,
bottom_names = ["fc3", "wide_redn"],
top_names = ["add1"]))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
bottom_names = ["add1", "label"],
top_names = ["loss"]))
model.compile()
model.summary()
model.fit(max_iter = 21000, display = 1000, eval_interval = 4000, snapshot = 20000, snapshot_prefix = "wdl")
model.graph_to_json(graph_config_file = "wdl.json")
Overwriting ./model.py
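The two workspace_size_per_gpu_in_mb values above have to be large enough to hold the embedding weights plus optimizer state for every key assigned to the GPU. The sketch below estimates them from the slot sizes; the 3x factor for Adam optimizer state is an assumption inferred from the max_vocabulary_size_per_gpu values that HugeCTR prints during model initialization, so treat the results as rough lower bounds and keep some headroom on top.

```python
# Rough workspace sizing for the two embedding tables (single-GPU run).
slot_size_array = [249058, 19561, 14212, 6890, 18592, 4, 6356, 1254, 52, 226170,
                   80508, 72308, 11, 2169, 7597, 61, 4, 923, 15, 249619, 168974,
                   243480, 68212, 9169, 75, 34, 278018, 415262]

def workspace_mb(vocab_size, embedding_vec_size, bytes_per_float=4, optimizer_factor=3):
    """Estimated minimum workspace_size_per_gpu_in_mb (optimizer_factor=3 assumes Adam state)."""
    return vocab_size * embedding_vec_size * bytes_per_float * optimizer_factor / (1024 ** 2)

deep_vocab = sum(slot_size_array[:26])   # 26 base categorical slots
wide_vocab = sum(slot_size_array[26:])   # 2 cross-feature slots
print(f"deep table: ~{workspace_mb(deep_vocab, 16):.0f} MB needed (configured: 405 MB)")
print(f"wide table: ~{workspace_mb(wide_vocab, 1):.0f} MB needed (configured: 24 MB)")
```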
!python ./model.py
HugeCTR Version: 4.2
====================================================Model Init=====================================================
[HCTR][10:35:38.637][WARNING][RK0][main]: The model name is not specified when creating the solver.
[HCTR][10:35:38.637][INFO][RK0][main]: Global seed is 1886280762
[HCTR][10:35:38.640][INFO][RK0][main]: Device to NUMA mapping:
GPU 2 -> node 0
[HCTR][10:35:40.821][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][10:35:40.821][DEBUG][RK0][main]: [device 2] allocating 0.0000 GB, available 30.8035
[HCTR][10:35:40.821][INFO][RK0][main]: Start all2all warmup
[HCTR][10:35:40.821][INFO][RK0][main]: End all2all warmup
[HCTR][10:35:40.822][INFO][RK0][main]: Using All-reduce algorithm: NCCL
[HCTR][10:35:40.823][INFO][RK0][main]: Device 2: Tesla V100-SXM2-32GB
[HCTR][10:35:40.824][INFO][RK0][main]: num of DataReader workers for train: 1
[HCTR][10:35:40.824][INFO][RK0][main]: num of DataReader workers for eval: 1
[HCTR][10:35:40.824][DEBUG][RK0][main]: [device 2] allocating 0.0054 GB, available 30.5476
[HCTR][10:35:40.825][DEBUG][RK0][main]: [device 2] allocating 0.0054 GB, available 30.5417
[HCTR][10:35:40.825][DEBUG][RK0][main]: [device 2] allocating 0.0000 GB, available 30.5417
[HCTR][10:35:40.826][DEBUG][RK0][main]: [device 2] allocating 0.0000 GB, available 30.5417
[HCTR][10:35:40.826][INFO][RK0][main]: Vocabulary size: 2138588
[HCTR][10:35:40.826][INFO][RK0][main]: max_vocabulary_size_per_gpu_=2097152
[HCTR][10:35:40.838][DEBUG][RK0][main]: [device 2] allocating 0.0241 GB, available 30.3914
[HCTR][10:35:40.845][INFO][RK0][main]: max_vocabulary_size_per_gpu_=2211840
[HCTR][10:35:40.851][DEBUG][RK0][main]: [device 2] allocating 0.4288 GB, available 29.9617
[HCTR][10:35:40.851][INFO][RK0][main]: Graph analysis to resolve tensor dependency
===================================================Model Compile===================================================
[HCTR][10:35:40.856][DEBUG][RK0][main]: [device 2] allocating 0.2162 GB, available 29.4792
[HCTR][10:35:40.856][DEBUG][RK0][main]: [device 2] allocating 0.0056 GB, available 29.4734
[HCTR][10:35:50.016][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][10:35:50.016][INFO][RK0][main]: gpu0 init embedding done
[HCTR][10:35:50.016][INFO][RK0][main]: gpu0 start to init embedding
[HCTR][10:35:50.017][INFO][RK0][main]: gpu0 init embedding done
[HCTR][10:35:50.017][DEBUG][RK0][main]: [device 2] allocating 0.0001 GB, available 29.4734
[HCTR][10:35:50.019][INFO][RK0][main]: Starting AUC NCCL warm-up
[HCTR][10:35:50.023][INFO][RK0][main]: Warm-up done
===================================================Model Summary===================================================
[HCTR][10:35:50.023][INFO][RK0][main]: Model structure on each GPU
Label Dense Sparse
label dense wide_data,deep_data
(2720,1) (2720,13)
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
Layer Type Input Name Output Name Output Shape
——————————————————————————————————————————————————————————————————————————————————————————————————————————————————
DistributedSlotSparseEmbeddingHash wide_data sparse_embedding2 (2720,2,1)
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash deep_data sparse_embedding1 (2720,26,16)
------------------------------------------------------------------------------------------------------------------
Reshape sparse_embedding1 reshape1 (2720,416)
------------------------------------------------------------------------------------------------------------------
Reshape sparse_embedding2 reshape2 (2720,2)
------------------------------------------------------------------------------------------------------------------
ReduceSum reshape2 wide_redn (2720,1)
------------------------------------------------------------------------------------------------------------------
Concat reshape1 concat1 (2720,429)
dense
------------------------------------------------------------------------------------------------------------------
InnerProduct concat1 fc1 (2720,1024)
------------------------------------------------------------------------------------------------------------------
ReLU fc1 relu1 (2720,1024)
------------------------------------------------------------------------------------------------------------------
Dropout relu1 dropout1 (2720,1024)
------------------------------------------------------------------------------------------------------------------
InnerProduct dropout1 fc2 (2720,1024)
------------------------------------------------------------------------------------------------------------------
ReLU fc2 relu2 (2720,1024)
------------------------------------------------------------------------------------------------------------------
Dropout relu2 dropout2 (2720,1024)
------------------------------------------------------------------------------------------------------------------
InnerProduct dropout2 fc3 (2720,1)
------------------------------------------------------------------------------------------------------------------
Add fc3 add1 (2720,1)
wide_redn
------------------------------------------------------------------------------------------------------------------
BinaryCrossEntropyLoss add1 loss
label
------------------------------------------------------------------------------------------------------------------
=====================================================Model Fit=====================================================
[HCTR][10:35:50.024][INFO][RK0][main]: Use non-epoch mode with number of iterations: 21000
[HCTR][10:35:50.024][INFO][RK0][main]: Training batchsize: 2720, evaluation batchsize: 2720
[HCTR][10:35:50.024][INFO][RK0][main]: Evaluation interval: 4000, snapshot interval: 20000
[HCTR][10:35:50.024][INFO][RK0][main]: Dense network trainable: True
[HCTR][10:35:50.024][INFO][RK0][main]: Sparse embedding sparse_embedding1 trainable: True
[HCTR][10:35:50.024][INFO][RK0][main]: Sparse embedding sparse_embedding2 trainable: True
[HCTR][10:35:50.024][INFO][RK0][main]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True
[HCTR][10:35:50.024][INFO][RK0][main]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000
[HCTR][10:35:50.024][INFO][RK0][main]: decay_start: 0, decay_steps: 1, decay_power: 2.000000
[HCTR][10:35:50.024][INFO][RK0][main]: Training source file: /wdl_train/train/_file_list.txt
[HCTR][10:35:50.024][INFO][RK0][main]: Evaluation source file: /wdl_train/val/_file_list.txt
[HCTR][10:35:55.354][INFO][RK0][main]: Iter: 1000 Time(1000 iters): 5.32919s Loss: 0.117964 lr:0.001
[HCTR][10:36:00.665][INFO][RK0][main]: Iter: 2000 Time(1000 iters): 5.30907s Loss: 0.126084 lr:0.001
[HCTR][10:36:06.002][INFO][RK0][main]: Iter: 3000 Time(1000 iters): 5.33581s Loss: 0.138335 lr:0.001
[HCTR][10:36:11.349][INFO][RK0][main]: Iter: 4000 Time(1000 iters): 5.34525s Loss: 0.101962 lr:0.001
[HCTR][10:36:15.933][INFO][RK0][main]: Evaluation, AUC: 0.763947
[HCTR][10:36:15.933][INFO][RK0][main]: Eval Time for 4000 iters: 4.583s
[HCTR][10:36:21.258][INFO][RK0][main]: Iter: 5000 Time(1000 iters): 9.90761s Loss: 0.120185 lr:0.001
[HCTR][10:36:26.620][INFO][RK0][main]: Iter: 6000 Time(1000 iters): 5.35972s Loss: 0.128626 lr:0.001
[HCTR][10:36:31.947][INFO][RK0][main]: Iter: 7000 Time(1000 iters): 5.32628s Loss: 0.125264 lr:0.001
[HCTR][10:36:37.320][INFO][RK0][main]: Iter: 8000 Time(1000 iters): 5.37131s Loss: 0.121486 lr:0.001
[HCTR][10:36:41.803][INFO][RK0][main]: Evaluation, AUC: 0.767916
[HCTR][10:36:41.803][INFO][RK0][main]: Eval Time for 4000 iters: 4.48175s
[HCTR][10:36:47.154][INFO][RK0][main]: Iter: 9000 Time(1000 iters): 9.83223s Loss: 0.109454 lr:0.001
[HCTR][10:36:52.522][INFO][RK0][main]: Iter: 10000 Time(1000 iters): 5.36677s Loss: 0.149472 lr:0.001
[HCTR][10:36:57.896][INFO][RK0][main]: Iter: 11000 Time(1000 iters): 5.37183s Loss: 0.118341 lr:0.001
[HCTR][10:37:03.264][INFO][RK0][main]: Iter: 12000 Time(1000 iters): 5.36706s Loss: 0.128496 lr:0.001
[HCTR][10:37:07.728][INFO][RK0][main]: Evaluation, AUC: 0.769081
[HCTR][10:37:07.728][INFO][RK0][main]: Eval Time for 4000 iters: 4.46314s
[HCTR][10:37:13.098][INFO][RK0][main]: Iter: 13000 Time(1000 iters): 9.83154s Loss: 0.118482 lr:0.001
[HCTR][10:37:18.447][INFO][RK0][main]: Iter: 14000 Time(1000 iters): 5.34802s Loss: 0.122699 lr:0.001
[HCTR][10:37:23.812][INFO][RK0][main]: Iter: 15000 Time(1000 iters): 5.36294s Loss: 0.118947 lr:0.001
[HCTR][10:37:29.176][INFO][RK0][main]: Iter: 16000 Time(1000 iters): 5.36303s Loss: 0.112516 lr:0.001
[HCTR][10:37:33.646][INFO][RK0][main]: Evaluation, AUC: 0.772322
[HCTR][10:37:33.646][INFO][RK0][main]: Eval Time for 4000 iters: 4.46896s
[HCTR][10:37:39.146][INFO][RK0][main]: Iter: 17000 Time(1000 iters): 9.96806s Loss: 0.11619 lr:0.001
[HCTR][10:37:44.517][INFO][RK0][main]: Iter: 18000 Time(1000 iters): 5.37011s Loss: 0.113035 lr:0.001
[HCTR][10:37:49.891][INFO][RK0][main]: Iter: 19000 Time(1000 iters): 5.37157s Loss: 0.116589 lr:0.001
[HCTR][10:37:55.236][INFO][RK0][main]: Iter: 20000 Time(1000 iters): 5.34424s Loss: 0.127488 lr:0.001
[HCTR][10:37:59.698][INFO][RK0][main]: Evaluation, AUC: 0.768523
[HCTR][10:37:59.698][INFO][RK0][main]: Eval Time for 4000 iters: 4.46044s
[HCTR][10:37:59.698][INFO][RK0][main]: Using Local file system backend.
[HCTR][10:37:59.771][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][10:37:59.810][INFO][RK0][main]: Using Local file system backend.
[HCTR][10:37:59.867][INFO][RK0][main]: Rank0: Write hash table to file
[HCTR][10:38:00.091][INFO][RK0][main]: Dumping sparse weights to files, successful
[HCTR][10:38:00.092][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][10:38:00.092][INFO][RK0][main]: Using Local file system backend.
[HCTR][10:38:00.115][INFO][RK0][main]: Done
[HCTR][10:38:00.116][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][10:38:00.116][INFO][RK0][main]: Using Local file system backend.
[HCTR][10:38:00.140][INFO][RK0][main]: Done
[HCTR][10:38:00.245][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][10:38:00.245][INFO][RK0][main]: Using Local file system backend.
[HCTR][10:38:00.610][INFO][RK0][main]: Done
[HCTR][10:38:00.695][INFO][RK0][main]: Rank0: Write optimzer state to file
[HCTR][10:38:00.695][INFO][RK0][main]: Using Local file system backend.
[HCTR][10:38:01.057][INFO][RK0][main]: Done
[HCTR][10:38:01.063][INFO][RK0][main]: Dumping sparse optimzer states to files, successful
[HCTR][10:38:01.064][INFO][RK0][main]: Using Local file system backend.
[HCTR][10:38:01.079][INFO][RK0][main]: Dumping dense weights to file, successful
[HCTR][10:38:01.081][INFO][RK0][main]: Using Local file system backend.
[HCTR][10:38:01.116][INFO][RK0][main]: Dumping dense optimizer states to file, successful
[HCTR][10:38:06.487][INFO][RK0][main]: Finish 21000 iterations with batchsize: 2720 in 136.46s.
[HCTR][10:38:06.488][INFO][RK0][main]: Save the model graph to wdl.json successfully
[1676025486.656545] [dgx1v-loki-23:602 :0] cuda_copy_iface.c:468 UCX ERROR cuCtxGetCurrent(&cuda_context)() failed:
[1676025486.656590] [dgx1v-loki-23:602 :0] cuda_ipc_iface.c:531 UCX ERROR cuCtxGetCurrent(&cuda_context)() failed:
[HCTR][10:38:06.707][INFO][RK0][main]: MPI finalization done.
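Training with snapshot = 20000 and snapshot_prefix = "wdl", followed by graph_to_json, leaves the dense model, the two sparse embedding tables, and the network JSON in the working directory; the inference steps below load them by these names. A quick existence check (a sketch, assuming the files were written to the notebook's current directory):

```python
import glob
import os

# Artifacts produced by model.fit(snapshot=20000, snapshot_prefix="wdl") and model.graph_to_json(...)
for pattern in ["wdl.json", "wdl_dense_20000.model", "wdl*_sparse_20000.model"]:
    for path in sorted(glob.glob(pattern)):
        kind = "dir" if os.path.isdir(path) else f"{os.path.getsize(path)} bytes"
        print(f"{path}: {kind}")
```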
Prepare Inference Request
!ls -l /wdl_train/val
total 633080
-rw-r--r-- 1 root root 32 Feb 10 10:34 _file_list.txt
-rw-r--r-- 1 root root 21894 Feb 10 10:34 _metadata
-rw-r--r-- 1 root root 1509 Feb 10 10:34 _metadata.json
-rw-r--r-- 1 root root 138441430 Feb 10 10:34 part_0.parquet
-rw-r--r-- 1 root root 27296 Feb 10 10:34 schema.pbtxt
drwxr-xr-x 2 root root 50 Feb 10 10:32 temp-parquet-after-conversion
-rw-r--r-- 1 root root 509766965 Feb 10 10:32 test.txt
import pandas as pd
df = pd.read_parquet("/wdl_train/val/part_0.parquet")
df.head()
I1 | I2 | I3 | I4 | I5 | I6 | I7 | I8 | I9 | I10 | ... | C17 | C18 | C19 | C20 | C21 | C22 | C23 | C24 | C25 | C26 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | -0.054112 | -0.267624 | 0.371471 | -0.076760 | -0.131771 | -0.206385 | -0.064249 | -0.208772 | 0.324461 | -0.470383 | ... | 1 | 49 | 1 | 3 | 3 | 3 | 9 | 402 | 1 | 4 |
1 | -0.048792 | -0.547466 | -0.594327 | -0.157301 | -0.224758 | -0.206385 | -0.064249 | 0.949396 | -0.760031 | -0.470383 | ... | 3 | 2 | 1 | 0 | 1645 | 0 | 1358 | 1232 | 1 | 1 |
2 | -0.059432 | -0.516221 | -0.594327 | -0.115539 | -0.209261 | -0.206385 | -0.064249 | -0.281810 | -0.687732 | -0.470383 | ... | 0 | 1 | 1 | 33190 | 32473 | 34242 | 0 | 954 | 3 | 3 |
3 | -0.029284 | -0.548824 | -0.057773 | -0.105099 | -0.224758 | -0.206385 | -0.064249 | -0.255725 | -0.760031 | -0.470383 | ... | 1 | 1 | 2 | 1 | 1 | 1 | 1 | 622 | 1 | 2 |
4 | -0.061206 | -0.548824 | -0.594327 | -0.145369 | -0.209261 | -0.206385 | 0.339399 | -0.281810 | 0.035263 | -0.470383 | ... | 0 | 1 | 1 | 90 | 143 | 101 | 0 | 21 | 1 | 3 |
5 rows × 42 columns
df.head(10).to_csv('/wdl_train/infer_test.csv', sep=',', index=False,header=True)
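The prediction script in the next section feeds these 10 rows to the low-level predict() API, which expects flat lists of dense features, embedding keys, and CSR-style row pointers. Because the preprocessed data encodes every categorical column starting from 0, the keys are shifted by the cumulative slot sizes so that each slot owns its own key range, and the row pointers mark one key per slot per sample. A small sketch of that bookkeeping for a batch of 10 (the wide-first ordering mirrors wdl_predict.py below):

```python
import numpy as np

batch_size = 10
wide_slots, deep_slots = 2, 26   # cross features (C1_C2, C3_C4) vs. the 26 base categoricals

# One key per slot per sample, so the row pointers are plain ranges.
wide_row_ptrs = list(range(0, batch_size * wide_slots + 1))   # 0..20
deep_row_ptrs = list(range(0, batch_size * deep_slots + 1))   # 0..260
row_ptrs = wide_row_ptrs + deep_row_ptrs

# Per-slot key offsets (same slot order as slot_size_array in training).
emb_size = [249058, 19561, 14212, 6890, 18592, 4, 6356, 1254, 52, 226170, 80508,
            72308, 11, 2169, 7597, 61, 4, 923, 15, 249619, 168974, 243480, 68212,
            9169, 75, 34, 278018, 415262]
shift = np.insert(np.cumsum(emb_size), 0, 0)[:-1]
print(shift[:5])   # offsets for the first few slots
```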
Create the prediction script
%%writefile '/wdl_train/wdl_predict.py'
from hugectr.inference import InferenceParams, CreateInferenceSession
import hugectr
import pandas as pd
import numpy as np
import sys
from mpi4py import MPI
def wdl_inference(model_name, network_file, dense_file, embedding_file_list, data_file, enable_cache, use_rocksdb=False, rocksdb_path=None):
CATEGORICAL_COLUMNS=["C" + str(x) for x in range(1, 27)]+["C1_C2","C3_C4"]
CONTINUOUS_COLUMNS=["I" + str(x) for x in range(1, 14)]
LABEL_COLUMNS = ['label']
emb_size = [249058, 19561, 14212, 6890, 18592, 4, 6356, 1254, 52, 226170, 80508, 72308, 11, 2169, 7597, 61, 4, 923, 15, 249619, 168974, 243480, 68212, 9169, 75, 34, 278018, 415262]
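# Shift each column's 0-based categorical codes by the cumulative slot sizes so that keys
# are unique across slots (same slot ordering as slot_size_array used during training).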
shift = np.insert(np.cumsum(emb_size), 0, 0)[:-1]
test_df=pd.read_csv(data_file,sep=',')
config_file = network_file
row_ptrs = list(range(0,21))+list(range(0,261))
dense_features = list(test_df[CONTINUOUS_COLUMNS].values.flatten())
test_df[CATEGORICAL_COLUMNS] = test_df[CATEGORICAL_COLUMNS].astype(np.int64)
embedding_columns = list((test_df[CATEGORICAL_COLUMNS]+shift).values.flatten())
persistent_db_params = hugectr.inference.PersistentDatabaseParams()
if use_rocksdb:
persistent_db_params = hugectr.inference.PersistentDatabaseParams(
backend = hugectr.DatabaseType_t.rocks_db,
path = rocksdb_path
)
# create parameter server, embedding cache and inference session
inference_params = InferenceParams(model_name = model_name,
max_batchsize = 64,
hit_rate_threshold = 0.5,
dense_model_file = dense_file,
sparse_model_files = embedding_file_list,
device_id = 0,
use_gpu_embedding_cache = enable_cache,
cache_size_percentage = 0.9,
persistent_db = persistent_db_params,
i64_input_key = True,
use_mixed_precision = False)
inference_session = CreateInferenceSession(config_file, inference_params)
output = inference_session.predict(dense_features, embedding_columns, row_ptrs)
print("WDL multi-embedding table inference result is {}".format(output))
if __name__ == "__main__":
model_name = sys.argv[1]
print("{} multi-embedding table prediction".format(model_name))
network_file = sys.argv[2]
print("{} multi-embedding table prediction network is {}".format(model_name,network_file))
dense_file = sys.argv[3]
print("{} multi-embedding table prediction dense file is {}".format(model_name,dense_file))
embedding_file_list = str(sys.argv[4]).split(',')
print("{} multi-embedding table prediction sparse files are {}".format(model_name,embedding_file_list))
data_file = sys.argv[5]
print("{} multi-embedding table prediction input data path is {}".format(model_name,data_file))
input_dbtype = sys.argv[6]
print("{} multi-embedding table prediction input dbtype path is {}".format(model_name,input_dbtype))
if input_dbtype=="disabled":
wdl_inference(model_name, network_file, dense_file, embedding_file_list, data_file, True)
if input_dbtype=="rocksdb":
rocksdb_path = sys.argv[7]
print("{} multi-embedding table prediction rocksdb_path path is {}".format(model_name,rocksdb_path))
wdl_inference(model_name, network_file, dense_file, embedding_file_list, data_file, True, True, rocksdb_path)
Overwriting /wdl_train/wdl_predict.py
Prediction
Use different types of database backends as the local parameter server to produce wide and deep model prediction results.
Load model embedding tables into local memory as the parameter server
!python /wdl_train/wdl_predict.py "wdl" "./wdl.json" "./wdl_dense_20000.model" "./wdl0_sparse_20000.model/,./wdl1_sparse_20000.model" "/wdl_train/infer_test.csv" "disabled"
wdl multi-embedding table prediction
wdl multi-embedding table prediction network is ./wdl.json
wdl multi-embedding table prediction dense file is ./wdl_dense_20000.model
wdl multi-embedding table prediction sparse files are ['./wdl0_sparse_20000.model/', './wdl1_sparse_20000.model']
wdl multi-embedding table prediction input data path is /wdl_train/infer_test.csv
wdl multi-embedding table prediction input dbtype path is disabled
[HCTR][10:53:44.990][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
[HCTR][10:53:44.990][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
[HCTR][10:53:44.990][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
====================================================HPS Create====================================================
[HCTR][10:53:44.990][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][10:53:44.991][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][10:53:44.991][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][10:53:44.991][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][10:53:44.991][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][10:53:44.991][INFO][RK0][main]: Using Local file system backend.
[HCTR][10:53:45.479][INFO][RK0][main]: Table: hps_et.wdl.sparse_embedding2; cached 664320 / 664320 embeddings in volatile database (HashMapBackend); load: 664320 / 18446744073709551615 (0.00%).
[HCTR][10:53:45.479][INFO][RK0][main]: Using Local file system backend.
[HCTR][10:53:45.829][INFO][RK0][main]: Table: hps_et.wdl.sparse_embedding1; cached 1030499 / 1030499 embeddings in volatile database (HashMapBackend); load: 1030499 / 18446744073709551615 (0.00%).
[HCTR][10:53:45.836][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][10:53:45.836][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][10:53:45.842][INFO][RK0][main]: Model name: wdl
[HCTR][10:53:45.842][INFO][RK0][main]: Max batch size: 64
[HCTR][10:53:45.842][INFO][RK0][main]: Number of embedding tables: 2
[HCTR][10:53:45.842][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.900000
[HCTR][10:53:45.842][INFO][RK0][main]: Use static table: False
[HCTR][10:53:45.842][INFO][RK0][main]: Use I64 input key: True
[HCTR][10:53:45.842][INFO][RK0][main]: Configured cache hit rate threshold: 0.500000
[HCTR][10:53:45.842][INFO][RK0][main]: The size of thread pool: 80
[HCTR][10:53:45.842][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][10:53:45.842][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][10:53:45.842][INFO][RK0][main]: The refresh percentage : 0.000000
[HCTR][10:53:46.786][INFO][RK0][main]: Global seed is 1984285016
[HCTR][10:53:46.789][INFO][RK0][main]: Device to NUMA mapping:
GPU 0 -> node 0
[HCTR][10:53:47.797][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][10:53:47.797][DEBUG][RK0][main]: [device 0] allocating 0.0000 GB, available 30.7156
[HCTR][10:53:47.797][INFO][RK0][main]: Start all2all warmup
[HCTR][10:53:47.797][INFO][RK0][main]: End all2all warmup
[HCTR][10:53:47.798][INFO][RK0][main]: Model name: wdl
[HCTR][10:53:47.798][INFO][RK0][main]: Use mixed precision: False
[HCTR][10:53:47.798][INFO][RK0][main]: Use cuda graph: True
[HCTR][10:53:47.798][INFO][RK0][main]: Max batchsize: 64
[HCTR][10:53:47.798][INFO][RK0][main]: Use I64 input key: True
[HCTR][10:53:47.798][INFO][RK0][main]: start create embedding for inference
[HCTR][10:53:47.798][INFO][RK0][main]: sparse_input name wide_data
[HCTR][10:53:47.798][INFO][RK0][main]: sparse_input name deep_data
[HCTR][10:53:47.798][INFO][RK0][main]: create embedding for inference success
[HCTR][10:53:47.798][DEBUG][RK0][main]: [device 0] allocating 0.0003 GB, available 30.4636
[HCTR][10:53:47.799][INFO][RK0][main]: Inference stage skip BinaryCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][10:53:47.799][DEBUG][RK0][main]: [device 0] allocating 0.0128 GB, available 30.4421
WDL multi-embedding table inference result is [0.011136045679450035, 0.006747737061232328, 0.005509266164153814, 0.0118627417832613, 0.01798960007727146, 0.010030664503574371, 0.02108118124306202, 0.008684462867677212, 0.07753805071115494, 0.011398322880268097]
Load model embedding tables into local RocksDB as the parameter server
Create a RocksDB directory with read and write permissions for storing the model embedding tables.
!mkdir -p -m 700 /wdl_train/rocksdb
!python /wdl_train/wdl_predict.py "wdl" "./wdl.json" \
"./wdl_dense_20000.model" \
"./wdl0_sparse_20000.model/,./wdl1_sparse_20000.model" \
"/wdl_train/infer_test.csv" \
"rocksdb" "/wdl_train/rocksdb"
wdl multi-embedding table prediction
wdl multi-embedding table prediction network is ./wdl.json
wdl multi-embedding table prediction dense file is ./wdl_dense_20000.model
wdl multi-embedding table prediction sparse files are ['./wdl0_sparse_20000.model/', './wdl1_sparse_20000.model']
wdl multi-embedding table prediction input data path is /wdl_train/infer_test.csv
wdl multi-embedding table prediction input dbtype path is rocksdb
wdl multi-embedding table prediction rocksdb_path path is /wdl_train/rocksdb
[HCTR][10:56:41.546][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
[HCTR][10:56:41.546][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
[HCTR][10:56:41.546][INFO][RK0][main]: default_emb_vec_value is not specified using default: 0
====================================================HPS Create====================================================
[HCTR][10:56:41.546][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][10:56:41.547][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][10:56:41.547][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][10:56:41.547][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][10:56:41.547][INFO][RK0][main]: Creating RocksDB backend...
[HCTR][10:56:41.547][INFO][RK0][main]: Connecting to RocksDB database...
[HCTR][10:56:41.548][ERROR][RK0][main]: RocksDB /wdl_train/rocksdb: Listing column names failed!
[HCTR][10:56:41.548][INFO][RK0][main]: RocksDB /wdl_train/rocksdb, found column family "default".
[HCTR][10:56:41.583][INFO][RK0][main]: Connected to RocksDB database!
[HCTR][10:56:41.583][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][10:56:41.583][INFO][RK0][main]: Using Local file system backend.
[HCTR][10:56:42.084][INFO][RK0][main]: Table: hps_et.wdl.sparse_embedding2; cached 664320 / 664320 embeddings in volatile database (HashMapBackend); load: 664320 / 18446744073709551615 (0.00%).
[HCTR][10:56:43.351][INFO][RK0][main]: Table: hps_et.wdl.sparse_embedding2; cached 664320 embeddings in persistent database (RocksDB).
[HCTR][10:56:43.351][INFO][RK0][main]: Using Local file system backend.
[HCTR][10:56:43.693][INFO][RK0][main]: Table: hps_et.wdl.sparse_embedding1; cached 1030499 / 1030499 embeddings in volatile database (HashMapBackend); load: 1030499 / 18446744073709551615 (0.00%).
[HCTR][10:56:45.979][INFO][RK0][main]: Table: hps_et.wdl.sparse_embedding1; cached 1030499 embeddings in persistent database (RocksDB).
[HCTR][10:56:45.985][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][10:56:45.985][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][10:56:45.991][INFO][RK0][main]: Model name: wdl
[HCTR][10:56:45.991][INFO][RK0][main]: Max batch size: 64
[HCTR][10:56:45.991][INFO][RK0][main]: Number of embedding tables: 2
[HCTR][10:56:45.991][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.900000
[HCTR][10:56:45.991][INFO][RK0][main]: Use static table: False
[HCTR][10:56:45.991][INFO][RK0][main]: Use I64 input key: True
[HCTR][10:56:45.991][INFO][RK0][main]: Configured cache hit rate threshold: 0.500000
[HCTR][10:56:45.991][INFO][RK0][main]: The size of thread pool: 80
[HCTR][10:56:45.991][INFO][RK0][main]: The size of worker memory pool: 2
[HCTR][10:56:45.991][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][10:56:45.991][INFO][RK0][main]: The refresh percentage : 0.000000
[HCTR][10:56:46.953][INFO][RK0][main]: Global seed is 3196997041
[HCTR][10:56:46.956][INFO][RK0][main]: Device to NUMA mapping:
GPU 0 -> node 0
[HCTR][10:56:48.005][WARNING][RK0][main]: Peer-to-peer access cannot be fully enabled.
[HCTR][10:56:48.005][DEBUG][RK0][main]: [device 0] allocating 0.0000 GB, available 30.7156
[HCTR][10:56:48.005][INFO][RK0][main]: Start all2all warmup
[HCTR][10:56:48.005][INFO][RK0][main]: End all2all warmup
[HCTR][10:56:48.006][INFO][RK0][main]: Model name: wdl
[HCTR][10:56:48.006][INFO][RK0][main]: Use mixed precision: False
[HCTR][10:56:48.006][INFO][RK0][main]: Use cuda graph: True
[HCTR][10:56:48.006][INFO][RK0][main]: Max batchsize: 64
[HCTR][10:56:48.006][INFO][RK0][main]: Use I64 input key: True
[HCTR][10:56:48.006][INFO][RK0][main]: start create embedding for inference
[HCTR][10:56:48.006][INFO][RK0][main]: sparse_input name wide_data
[HCTR][10:56:48.006][INFO][RK0][main]: sparse_input name deep_data
[HCTR][10:56:48.006][INFO][RK0][main]: create embedding for inference success
[HCTR][10:56:48.006][DEBUG][RK0][main]: [device 0] allocating 0.0003 GB, available 30.4636
[HCTR][10:56:48.007][INFO][RK0][main]: Inference stage skip BinaryCrossEntropyLoss layer, replaced by Sigmoid layer
[HCTR][10:56:48.007][DEBUG][RK0][main]: [device 0] allocating 0.0128 GB, available 30.4421
WDL multi-embedding table inference result is [0.011136045679450035, 0.006747737061232328, 0.005509266164153814, 0.0118627417832613, 0.01798960007727146, 0.010030664503574371, 0.02108118124306202, 0.008684462867677212, 0.07753805071115494, 0.011398322880268097]
[HCTR][10:56:48.571][INFO][RK0][main]: Disconnecting from RocksDB database...
[HCTR][10:56:48.573][INFO][RK0][main]: Disconnected from RocksDB database!