# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

HugeCTR Wide and Deep Model with Criteo

Overview

In this notebook, we provide a tutorial that shows how to train a Wide & Deep model with the high-level HugeCTR Python API, using the original Criteo dataset as training data. We also show how to produce prediction results with different types of local databases.

Dataset Preprocessing

Generate training and validation data folders

# define some data folder to store the original and preprocessed data
# Standard Libraries
import os
from time import time
import re
import shutil
import glob
import warnings
# Root folder for the raw Criteo download and all preprocessed output.
BASE_DIR = "/wdl_train"
train_path = os.path.join(BASE_DIR, "train")
val_path = os.path.join(BASE_DIR, "val")

# Number of comma-separated device ids in CUDA_VISIBLE_DEVICES ("0" if unset).
CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
n_workers = len(CUDA_VISIBLE_DEVICES.split(","))

# Optional knobs (not referenced by the cells shown in this notebook).
frac_size = 0.15
allow_multi_gpu = False
use_rmm_pool = False
max_day = None  # (Optional) -- Limit the dataset to day 0-max_day for debugging

# Start from empty folders: remove anything left over from a previous run,
# then recreate the directory (train first, then val).
for _folder in (train_path, val_path):
    if os.path.isdir(_folder):
        shutil.rmtree(_folder)
    os.makedirs(_folder)
!ls -l $train_path
total 14537948
-rw-r--r-- 1 root root  3336516608 Jul  5 05:44 0.8870d61b8a1f4deca0f911acfb072999.parquet
-rw-r--r-- 1 root root          61 Jul  5 05:44 _file_list.txt
-rw-r--r-- 1 root root      602767 Jul  5 05:44 _metadata
-rw-r--r-- 1 root root        1538 Jul  5 05:44 _metadata.json
drwxr-xr-x 2 root root        4096 Jul  5 05:41 temp-parquet-after-conversion/
-rwxrwxr-x 1 1025 1025 11549710546 Jul  5 05:39 train.txt*

Download the original Criteo dataset

!apt-get install wget
Reading package lists... Done
Building dependency tree       
Reading state information... Done
wget is already the newest version (1.20.3-1ubuntu1).
0 upgraded, 0 newly installed, 0 to remove and 4 not upgraded.
!wget -P $train_path https://storage.googleapis.com/criteo-cail-datasets/day_0.gz
--2022-06-16 03:27:50--  https://storage.googleapis.com/criteo-cail-datasets/day_0.gz
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.191.80, 172.217.5.112, 142.250.189.208, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.191.80|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 16309554343 (15G) [application/octet-stream]
Saving to: ‘wdl_train/train/day_0.gz’

day_0.gz               100%[===================================================>]   15G  --.-KB/s    in  6m 12s 
2021-07-05 03:34:04 (79.2 MB/s) - 'day_0.gz' saved [16309554343/16309554343]

Split the decompressed data into a training set (the first 45,840,617 lines) and a validation set (the last 2,000,000 lines).

!gzip -d -c $train_path/day_0.gz > day_0
!head -n 45840617 day_0 > $train_path/train.txt
!tail -n 2000000 day_0 > $val_path/test.txt 

Preprocessing with NVTabular

%%writefile /wdl_train/preprocess.py
import os
import sys
import argparse
import glob
import time
from cudf.io.parquet import ParquetWriter
import numpy as np
import pandas as pd
import concurrent.futures as cf
from concurrent.futures import as_completed
import shutil

import dask_cudf
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from dask.utils import parse_bytes
from dask.delayed import delayed

import cudf
import rmm
import nvtabular as nvt
from nvtabular.io import Shuffle
from nvtabular.utils import device_mem_size
from nvtabular.ops import Categorify, Clip, FillMissing, HashBucket, LambdaOp, Normalize, Rename, Operator, get_embedding_sizes
#%load_ext memory_profiler

import logging
# Timestamped log lines; let everything through at the root logger, but keep
# the numba and asyncio loggers at WARNING so they do not flood the output.
logging.basicConfig(format='%(asctime)s %(message)s')
logging.getLogger().setLevel(logging.NOTSET)
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('asyncio').setLevel(logging.WARNING)

# Dataset schema: 1 label, 13 integer (dense) columns, 26 categorical columns.
NUM_INTEGER_COLUMNS = 13
NUM_CATEGORICAL_COLUMNS = 26
NUM_TOTAL_COLUMNS = 1 + NUM_INTEGER_COLUMNS + NUM_CATEGORICAL_COLUMNS

LABEL_COLUMNS = ['label']
CONTINUOUS_COLUMNS = [f"I{idx}" for idx in range(1, NUM_INTEGER_COLUMNS + 1)]
CATEGORICAL_COLUMNS = [f"C{idx}" for idx in range(1, NUM_CATEGORICAL_COLUMNS + 1)]
COLUMNS = LABEL_COLUMNS + CONTINUOUS_COLUMNS + CATEGORICAL_COLUMNS
# The /samples/criteo mode has no dense features.
criteo_COLUMN = LABEL_COLUMNS + CATEGORICAL_COLUMNS
# Names of generated feature-cross columns; filled in at run time.
CROSS_COLUMNS = []


# Initialize RMM pool on ALL workers
def setup_rmm_pool(client, pool_size):
    client.run(rmm.reinitialize, pool_allocator=True, initial_pool_size=pool_size)
    return None

# Convert a byte count to a larger unit (used to report the partition size in GB).
def bytesto(bytes, to, bsize=1024):
    """Express ``bytes`` in the unit ``to``.

    Args:
        bytes: number of bytes. The name shadows the builtin but is kept so
            keyword callers keep working.
        to: target unit, one of 'k', 'm', 'g', 't', 'p', 'e' (case-insensitive).
        bsize: ratio between successive units (1024 for binary units).

    Returns:
        ``bytes / bsize ** n`` as a float, where ``n`` is the unit's exponent.

    Raises:
        KeyError: if ``to`` is not a known unit.
    """
    exponents = {'k': 1, 'm': 2, 'g': 3, 't': 4, 'p': 5, 'e': 6}
    return bytes / (bsize ** exponents[to.lower()])

class FeatureCross(Operator):
    """NVTabular operator that crosses each selected column with one partner column.

    For every selected column ``name`` the output frame holds
    ``gdf[name] + gdf[dependency]`` under the same column name. The meaning of
    ``+`` depends on the column dtype (assumed to be string concatenation for
    the raw Criteo categorical columns -- TODO confirm).
    """

    def __init__(self, dependency):
        # Name of the partner column that every selected column is crossed with.
        self.dependency = dependency

    def transform(self, columns, gdf):
        # Empty frame of the same type as the input, filled column by column.
        crossed = type(gdf)()
        for name in columns.names:
            crossed[name] = gdf[name] + gdf[self.dependency]
        return crossed

    def dependencies(self):
        # Declares that transform() also reads the partner column.
        return [self.dependency]

# Process the data with NVTabular.
def _parse_cross_pairs(args):
    """Return the (column, partner) pairs requested by ``--feature_cross_list``.

    Rebuilds the module-level ``CROSS_COLUMNS`` list from scratch, so repeated
    calls do not accumulate duplicates. Names are "<column>_<partner>", which
    matches the ``Rename(postfix="_" + partner)`` applied in the workflow.

    Crosses are only added to the workflow when ``criteo_mode == 0``. In
    criteo mode the list is ignored with a warning; otherwise the exported
    column lists would reference columns the workflow never produces.
    """
    CROSS_COLUMNS.clear()
    if not args.feature_cross_list:
        return []
    if args.criteo_mode != 0:
        logging.warning('--feature_cross_list is ignored when --criteo_mode is nonzero')
        return []
    pairs = []
    for item in args.feature_cross_list.split(","):
        item = item.strip()  # tolerate "C1_C2, C3_C4"
        if not item:
            continue
        parts = item.split("_")
        if len(parts) < 2:
            raise ValueError(f"invalid feature cross {item!r}; expected <col>_<col>")
        pairs.append((parts[0], parts[1]))
        CROSS_COLUMNS.append(parts[0] + '_' + parts[1])
    return pairs


def _start_cluster(args, device_size):
    """Start a single-machine multi-GPU ``LocalCUDACluster`` on ``args.devices``."""
    cluster_kwargs = dict(
        protocol=args.protocol,
        n_workers=len(args.devices.split(",")),
        CUDA_VISIBLE_DEVICES=args.devices,
        device_memory_limit=int(device_size * args.device_limit_frac),
        dashboard_address=":" + args.dashboard_port,
    )
    if args.protocol == "ucx":
        # Keep a user-provided UCX_TLS; otherwise use this transport list.
        os.environ["UCX_TLS"] = os.environ.get("UCX_TLS", "tcp,cuda_copy,cuda_ipc,sockcm")
        cluster_kwargs["enable_nvlink"] = True
    return LocalCUDACluster(**cluster_kwargs)


def _txt_to_parquet(args, label_columns, jobs):
    """Convert each tab-separated text file to parquet with dask_cudf.

    ``jobs`` is a sequence of (input_txt_path, output_parquet_dir) pairs.
    """
    names = label_columns + CONTINUOUS_COLUMNS + CATEGORICAL_COLUMNS
    for src, dst in jobs:
        ddf = dask_cudf.read_csv(src, sep='\t', names=names)
        # Convert label col to FP32 for parquet training data.
        if args.parquet_format and args.dataset_type == 'train':
            ddf["label"] = ddf['label'].astype('float32')
        # Save it as parquet format for better memory usage
        ddf.to_parquet(dst, header=True)


def _write_output(workflow, dataset, output_path, output_format, dtypes, export_kwargs):
    """Transform ``dataset`` with the fitted ``workflow`` and write it to ``output_path``.

    ``output_format == 'hugectr'`` uses ``to_hugectr``; anything else uses
    ``to_parquet`` with the explicit ``dtypes`` mapping.
    """
    transformed = workflow.transform(dataset)
    if output_format == 'hugectr':
        transformed.to_hugectr(output_path=output_path, **export_kwargs)
    else:
        transformed.to_parquet(output_path=output_path, dtypes=dtypes, **export_kwargs)


def _slot_sizes(categorify_op, cross_cat_op):
    """Return the first embedding-size entry (category count) per slot.

    Categorical columns come first, then cross columns. This is the order
    HugeCTR's ``slot_size_array`` is filled from.
    """
    cat_sizes = categorify_op.get_embedding_sizes(CATEGORICAL_COLUMNS)
    sizes = [cat_sizes[c][0] for c in CATEGORICAL_COLUMNS]
    if CROSS_COLUMNS:
        # Only query the cross Categorify when it was actually part of the workflow.
        cross_sizes = cross_cat_op.get_embedding_sizes(CROSS_COLUMNS)
        sizes += [cross_sizes[c][0] for c in CROSS_COLUMNS]
    return sizes


def process_NVT(args):
    """Convert raw Criteo text to parquet and preprocess it with NVTabular.

    Steps:
      1. Read ``<data_path>/train/train.txt`` and ``<data_path>/val/test.txt``
         (tab separated).
      2. Write them as temporary parquet under
         ``<out_path>/{train,val}/temp-parquet-after-conversion``.
      3. Fit an NVTabular workflow on the training split:
         * Categorify for categorical and feature-cross columns.
         * FillMissing -> Clip(min 0) -> Normalize for continuous columns,
           unless ``criteo_mode`` is set.
      4. Write both splits to ``<out_path>/{train,val}`` in parquet or
         HugeCTR format.
      5. Print the per-slot sizes for HugeCTR's ``slot_size_array``, then a
         run summary.

    If ``args.profile`` is set, a Dask performance report is written there.

    Note: when ``dataset_type == 'test'`` the module-level ``LABEL_COLUMNS``
    is reset to ``[]`` because the input files have no label column.
    """
    global LABEL_COLUMNS
    import contextlib
    from dask.distributed import performance_report

    cross_pairs = _parse_cross_pairs(args)

    logging.info('NVTabular processing')
    train_input = os.path.join(args.data_path, "train/train.txt")
    val_input = os.path.join(args.data_path, "val/test.txt")
    temp_train = os.path.join(args.out_path, 'train/temp-parquet-after-conversion')
    temp_val = os.path.join(args.out_path, 'val/temp-parquet-after-conversion')
    train_output = os.path.join(args.out_path, "train")
    val_output = os.path.join(args.out_path, "val")

    # Make sure we have a clean parquet space for cudf conversion
    for one_path in (temp_train, temp_val):
        if os.path.exists(one_path):
            shutil.rmtree(one_path)
        os.makedirs(one_path)

    # Deploy a Single-Machine Multi-GPU Cluster
    device_size = device_mem_size(kind="total")
    part_size = int(args.part_mem_frac * device_size)
    cluster = _start_cluster(args, device_size)
    client = Client(cluster)
    try:
        if args.device_pool_frac > 0.01:
            setup_rmm_pool(client, int(args.device_pool_frac * device_size))

        # calculate the total processing time
        runtime = time.time()

        # test dataset without the label feature
        if args.dataset_type == 'test':
            LABEL_COLUMNS = []

        profile_ctx = (performance_report(filename=args.profile)
                       if args.profile else contextlib.nullcontext())
        with profile_ctx:
            _txt_to_parquet(args, LABEL_COLUMNS, [(train_input, temp_train), (val_input, temp_val)])
            train_paths = glob.glob(os.path.join(temp_train, "*.parquet"))
            valid_paths = glob.glob(os.path.join(temp_val, "*.parquet"))

            categorify_op = Categorify(freq_threshold=args.freq_limit)
            cross_cat_op = Categorify(freq_threshold=args.freq_limit)
            cat_features = CATEGORICAL_COLUMNS >> categorify_op

            # Start from a copy so the module-level label list is never extended in place.
            features = list(LABEL_COLUMNS)
            if args.criteo_mode == 0:
                features += CONTINUOUS_COLUMNS >> FillMissing() >> Clip(min_value=0) >> Normalize()
                for col0, col1 in cross_pairs:
                    features += col0 >> FeatureCross(col1) >> Rename(postfix="_" + col1) >> cross_cat_op
            features += cat_features

            workflow = nvt.Workflow(features, client=client)

            logging.info("Preprocessing")
            output_format = 'parquet' if args.parquet_format else 'hugectr'

            train_ds_iterator = nvt.Dataset(train_paths, engine='parquet', part_size=part_size)
            valid_ds_iterator = nvt.Dataset(valid_paths, engine='parquet', part_size=part_size)

            # "NONE" (or any other value) means no shuffling.
            shuffle = {
                "PER_WORKER": nvt.io.Shuffle.PER_WORKER,
                "PER_PARTITION": nvt.io.Shuffle.PER_PARTITION,
            }.get(args.shuffle)

            logging.info('Train Datasets Preprocessing.....')

            conts = [] if args.criteo_mode else CONTINUOUS_COLUMNS
            dict_dtypes = {col: np.int64 for col in CATEGORICAL_COLUMNS}
            dict_dtypes.update({col: np.float32 for col in conts})
            dict_dtypes.update({col: np.int64 for col in CROSS_COLUMNS})
            dict_dtypes.update({col: np.float32 for col in LABEL_COLUMNS})

            export_kwargs = dict(
                cats=CATEGORICAL_COLUMNS + CROSS_COLUMNS,
                conts=conts,
                labels=LABEL_COLUMNS,
                shuffle=shuffle,
                out_files_per_proc=args.out_files_per_proc,
                num_threads=args.num_io_threads,
            )

            workflow.fit(train_ds_iterator)
            _write_output(workflow, train_ds_iterator, train_output, output_format,
                          dict_dtypes, export_kwargs)

            # Slot sizes to copy into HugeCTR's slot_size_array.
            print(_slot_sizes(categorify_op, cross_cat_op))

            logging.info('Valid Datasets Preprocessing.....')
            _write_output(workflow, valid_ds_iterator, val_output, output_format,
                          dict_dtypes, export_kwargs)
    finally:
        # Shut down both the client and the cluster (workers), even on failure.
        client.close()
        cluster.close()

    logging.info('NVTabular processing done')

    runtime = time.time() - runtime

    print("\nDask-NVTabular Criteo Preprocessing")
    print("--------------------------------------")
    print(f"data_path          | {args.data_path}")
    print(f"output_path        | {args.out_path}")
    print(f"partition size     | {'%.2f GB'%bytesto(part_size,'g')}")
    print(f"protocol           | {args.protocol}")
    print(f"device(s)          | {args.devices}")
    print(f"rmm-pool-frac      | {(args.device_pool_frac)}")
    print(f"out-files-per-proc | {args.out_files_per_proc}")
    print(f"num_io_threads     | {args.num_io_threads}")
    print(f"shuffle            | {args.shuffle}")
    print("======================================")
    print(f"Runtime[s]         | {runtime}")
    print("======================================\n")


def parse_args():
    """Build the command-line parser for the preprocessing script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with every option below, plus ``n_workers``. That
        attribute is the number of comma-separated device ids in ``--devices``.
    """
    parser = argparse.ArgumentParser(description=("Multi-GPU Criteo Preprocessing"))

    #
    # System Options
    #

    parser.add_argument("--data_path", type=str, help="Input dataset path (Required)")
    parser.add_argument("--out_path", type=str, help="Directory path to write output (Required)")
    parser.add_argument(
        "-d",
        "--devices",
        default=os.environ.get("CUDA_VISIBLE_DEVICES", "0"),
        type=str,
        help='Comma-separated list of visible devices (e.g. "0,1,2,3"). '
    )
    parser.add_argument(
        "-p",
        "--protocol",
        choices=["tcp", "ucx"],
        default="tcp",
        type=str,
        help="Communication protocol to use (Default 'tcp')",
    )
    parser.add_argument(
        "--device_limit_frac",
        default=0.5,
        type=float,
        help="Worker device-memory limit as a fraction of GPU capacity (Default 0.5). "
    )
    parser.add_argument(
        "--device_pool_frac",
        default=0.9,
        type=float,
        help="RMM pool size for each worker as a fraction of GPU capacity (Default 0.9). "
        "The RMM pool frac is the same for all GPUs, make sure each one has enough memory size",
    )
    parser.add_argument(
        "--num_io_threads",
        default=0,
        type=int,
        help="Number of threads to use when writing output data (Default 0). "
        "If 0 is specified, multi-threading will not be used for IO.",
    )

    #
    # Data-Decomposition Parameters
    #

    parser.add_argument(
        "--part_mem_frac",
        default=0.125,
        type=float,
        help="Maximum size desired for dataset partitions as a fraction "
        "of GPU capacity (Default 0.125)",
    )
    parser.add_argument(
        "--out_files_per_proc",
        default=1,
        type=int,
        help="Number of output files to write on each worker (Default 1)",
    )

    #
    # Preprocessing Options
    #

    parser.add_argument(
        "-f",
        "--freq_limit",
        default=0,
        type=int,
        help="Frequency limit for categorical encoding (Default 0)",
    )
    parser.add_argument(
        "-s",
        "--shuffle",
        choices=["PER_WORKER", "PER_PARTITION", "NONE"],
        default="PER_PARTITION",
        help="Shuffle algorithm to use when writing output data to disk (Default PER_PARTITION)",
    )

    parser.add_argument(
        "--feature_cross_list",
        default=None,
        type=str,
        help="Comma-separated list of feature crossing cols (e.g. C1_C2,C3_C4)",
    )

    #
    # Diagnostics Options
    #

    parser.add_argument(
        "--profile",
        metavar="PATH",
        default=None,
        type=str,
        help="Specify a file path to export a Dask profile report (E.g. dask-report.html). "
        "If this option is excluded from the command, no profile will be exported",
    )
    parser.add_argument(
        "--dashboard_port",
        default="8787",
        type=str,
        help="Specify the desired port of Dask's diagnostics-dashboard (Default `8787`). "
        "The dashboard will be hosted at http://<IP>:<PORT>/status",
    )

    #
    # Format
    #

    parser.add_argument(
        '--criteo_mode', type=int, default=0,
        help="If nonzero, drop the continuous (dense) features (Default 0)",
    )
    parser.add_argument(
        '--parquet_format', type=int, default=1,
        help="If nonzero, write Parquet output; otherwise write HugeCTR format (Default 1)",
    )
    parser.add_argument(
        '--dataset_type', type=str, default='train',
        help="'train' (input has a label column) or 'test' (no label column) (Default 'train')",
    )

    args = parser.parse_args()
    args.n_workers = len(args.devices.split(","))
    return args

if __name__ == '__main__':
    # Parse the command line, then run the text -> parquet -> NVTabular pipeline.
    process_NVT(parse_args())
Writing /wdl_train/preprocess.py
import pandas as pd
!python3 /wdl_train/preprocess.py --data_path wdl_train/ \
--out_path wdl_train/ --freq_limit 6 --feature_cross_list C1_C2,C3_C4 \
--device_pool_frac 0.5  --devices '0' --num_io_threads 2
2021-07-05 05:41:34,199 NVTabular processing
2021-07-05 05:42:00,112 Preprocessing
2021-07-05 05:42:00,469 Train Datasets Preprocessing.....
[249058, 19561, 14212, 6890, 18592, 4, 6356, 1254, 52, 226170, 80508, 72308, 11, 2169, 7597, 61, 4, 923, 15, 249619, 168974, 243480, 68212, 9169, 75, 34, 278018, 415262]
2021-07-05 05:44:17,349 Valid Datasets Preprocessing.....
[249058, 19561, 14212, 6890, 18592, 4, 6356, 1254, 52, 226170, 80508, 72308, 11, 2169, 7597, 61, 4, 923, 15, 249619, 168974, 243480, 68212, 9169, 75, 34, 278018, 415262]
2021-07-05 05:44:19,138 NVTabular processing done

Dask-NVTabular Criteo Preprocessing
--------------------------------------
data_path          | wdl_train/
output_path        | wdl_train/
partition size     | 2.77 GB
protocol           | tcp
device(s)          | 0
rmm-pool-frac      | 0.5
out-files-per-proc | 1
num_io_threads     | 2
shuffle            | PER_PARTITION
======================================
Runtime[s]         | 159.50506210327148
======================================

Check the preprocessed training data

!ls -ll /wdl_train/train
total 14537948
-rw-r--r-- 1 root root  3336516608 Jul  5 05:44 0.8870d61b8a1f4deca0f911acfb072999.parquet
-rw-r--r-- 1 root root          61 Jul  5 05:44 _file_list.txt
-rw-r--r-- 1 root root      602767 Jul  5 05:44 _metadata
-rw-r--r-- 1 root root        1538 Jul  5 05:44 _metadata.json
drwxr-xr-x 2 root root        4096 Jul  5 05:41 temp-parquet-after-conversion
-rwxrwxr-x 1 1025 1025 11549710546 Jul  5 05:39 train.txt
import pandas as pd
# NVTabular names its output file with a per-run hash-like suffix, so look the
# file up instead of hard-coding the name produced by one particular run.
import glob
train_parquet_files = sorted(glob.glob("/wdl_train/train/*.parquet"))
if not train_parquet_files:
    raise FileNotFoundError("no .parquet files found in /wdl_train/train")
df = pd.read_parquet(train_parquet_files[0])
df.head(2)
I1 I2 I3 I4 I5 I6 I7 I8 I9 I10 ... C17 C18 C19 C20 C21 C22 C23 C24 C25 C26
0 -0.048792 -0.368150 0.478781 -0.133437 -0.069780 0.068484 0.743047 -0.266159 1.481252 1.386036 ... 3 356 10 183947 140830 28449 64057 6432 10 22
1 -0.061206 0.840877 -0.594327 0.148456 -0.209261 -0.206385 -0.064249 -0.281810 -0.760031 -0.470383 ... 1 781 10 207893 27876 112273 65971 3414 10 5

2 rows × 42 columns

WDL Model Training

%%writefile './model.py'
import hugectr

# Size of each sparse slot, as printed by the NVTabular preprocessing step:
# the 26 categorical columns followed by the 2 feature-cross columns.
SLOT_SIZE_ARRAY = [249058, 19561, 14212, 6890, 18592, 4, 6356, 1254, 52, 226170,
                   80508, 72308, 11, 2169, 7597, 61, 4, 923, 15, 249619,
                   168974, 243480, 68212, 9169, 75, 34, 278018, 415262]

# vvgpu=[[2]] trains on GPU id 2 of this node; adjust for your machine.
solver = hugectr.CreateSolver(
    max_eval_batches=4000,
    batchsize_eval=2720,
    batchsize=2720,
    lr=0.001,
    vvgpu=[[2]],
    repeat_dataset=True,
    i64_input_key=True,
)

reader = hugectr.DataReaderParams(
    data_reader_type=hugectr.DataReaderType_t.Parquet,
    source=["./train/_file_list.txt"],
    eval_source="./val/_file_list.txt",
    check_type=hugectr.Check_t.Non,
    slot_size_array=SLOT_SIZE_ARRAY,
)

optimizer = hugectr.CreateOptimizer(
    optimizer_type=hugectr.Optimizer_t.Adam,
    update_type=hugectr.Update_t.Global,
    beta1=0.9,
    beta2=0.999,
    epsilon=0.0000001,
)

model = hugectr.Model(solver, reader, optimizer)

# Inputs: 1 label, 13 dense features, 2 wide slots and 26 deep slots.
model.add(hugectr.Input(
    label_dim=1, label_name="label",
    dense_dim=13, dense_name="dense",
    data_reader_sparse_param_array=[
        hugectr.DataReaderSparseParam("wide_data", 1, True, 2),
        hugectr.DataReaderSparseParam("deep_data", 2, False, 26),
    ],
))

# Wide part: 1-dimensional embedding of the wide slots.
model.add(hugectr.SparseEmbedding(
    embedding_type=hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
    workspace_size_per_gpu_in_mb=24,
    embedding_vec_size=1,
    combiner="sum",
    sparse_embedding_name="sparse_embedding2",
    bottom_name="wide_data",
    optimizer=optimizer,
))
# Deep part: 16-dimensional embedding of the deep slots.
model.add(hugectr.SparseEmbedding(
    embedding_type=hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
    workspace_size_per_gpu_in_mb=405,
    embedding_vec_size=16,
    combiner="sum",
    sparse_embedding_name="sparse_embedding1",
    bottom_name="deep_data",
    optimizer=optimizer,
))


def add_dense(layer_type, bottom_names, top_names, **extra):
    """Append a single hugectr.DenseLayer to the global ``model``."""
    model.add(hugectr.DenseLayer(layer_type=layer_type,
                                 bottom_names=bottom_names,
                                 top_names=top_names,
                                 **extra))


Layer = hugectr.Layer_t

# Flatten the deep embedding (26 x 16 = 416) and reduce the wide one to a scalar.
add_dense(Layer.Reshape, ["sparse_embedding1"], ["reshape1"], leading_dim=416)
add_dense(Layer.Reshape, ["sparse_embedding2"], ["reshape2"], leading_dim=2)
add_dense(Layer.ReduceSum, ["reshape2"], ["wide_redn"], axis=1)
add_dense(Layer.Concat, ["reshape1", "dense"], ["concat1"])

# Deep tower: two 1024-unit hidden layers, each followed by ReLU and 50% dropout.
for fc_name, relu_name, dropout_name, input_name in (
        ("fc1", "relu1", "dropout1", "concat1"),
        ("fc2", "relu2", "dropout2", "dropout1")):
    add_dense(Layer.InnerProduct, [input_name], [fc_name], num_output=1024)
    add_dense(Layer.ReLU, [fc_name], [relu_name])
    add_dense(Layer.Dropout, [relu_name], [dropout_name], dropout_rate=0.5)

# Combine the deep logit with the wide logit and train with binary cross-entropy.
add_dense(Layer.InnerProduct, ["dropout2"], ["fc3"], num_output=1)
add_dense(Layer.Add, ["fc3", "wide_redn"], ["add1"])
add_dense(Layer.BinaryCrossEntropyLoss, ["add1", "label"], ["loss"])

model.compile()
model.summary()
model.fit(max_iter=21000, display=1000, eval_interval=4000, snapshot=20000, snapshot_prefix="wdl")
model.graph_to_json(graph_config_file="wdl.json")
Overwriting ./model.py
!python ./model.py
====================================================Model Init=====================================================
[13d04h56m26s][HUGECTR][INFO]: Global seed is 3689394843
[13d04h56m31s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.
Device 2: A30
[13d04h56m31s][HUGECTR][INFO]: num of DataReader workers: 1
[13d04h56m31s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=2097152
[13d04h56m31s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=2211840
===================================================Model Compile===================================================
[13d04h56m46s][HUGECTR][INFO]: gpu0 start to init embedding
[13d04h56m46s][HUGECTR][INFO]: gpu0 init embedding done
[13d04h56m46s][HUGECTR][INFO]: gpu0 start to init embedding
[13d04h56m46s][HUGECTR][INFO]: gpu0 init embedding done
===================================================Model Summary===================================================
Label                                   Dense                         Sparse                        
label                                   dense                          wide_data,deep_data           
(None, 1)                               (None, 13)                              
------------------------------------------------------------------------------------------------------------------
Layer Type                              Input Name                    Output Name                   Output Shape                  
------------------------------------------------------------------------------------------------------------------
DistributedSlotSparseEmbeddingHash      wide_data                     sparse_embedding2             (None, 2, 1)                  
DistributedSlotSparseEmbeddingHash      deep_data                     sparse_embedding1             (None, 26, 16)                
Reshape                                 sparse_embedding1             reshape1                      (None, 416)                   
Reshape                                 sparse_embedding2             reshape2                      (None, 2)                     
ReduceSum                               reshape2                      wide_redn                     (None, 1)                     
Concat                                  reshape1,dense                concat1                       (None, 429)                   
InnerProduct                            concat1                       fc1                           (None, 1024)                  
ReLU                                    fc1                           relu1                         (None, 1024)                  
Dropout                                 relu1                         dropout1                      (None, 1024)                  
InnerProduct                            dropout1                      fc2                           (None, 1024)                  
ReLU                                    fc2                           relu2                         (None, 1024)                  
Dropout                                 relu2                         dropout2                      (None, 1024)                  
InnerProduct                            dropout2                      fc3                           (None, 1)                     
Add                                     fc3,wide_redn                 add1                          (None, 1)                     
BinaryCrossEntropyLoss                  add1,label                    loss                                                        
------------------------------------------------------------------------------------------------------------------
=====================================================Model Fit=====================================================
[13d40h56m46s][HUGECTR][INFO]: Use non-epoch mode with number of iterations: 21000
[13d40h56m46s][HUGECTR][INFO]: Training batchsize: 2720, evaluation batchsize: 2720
[13d40h56m46s][HUGECTR][INFO]: Evaluation interval: 4000, snapshot interval: 20000
[13d40h56m46s][HUGECTR][INFO]: Sparse embedding trainable: 1, dense network trainable: 1
[13d40h56m46s][HUGECTR][INFO]: Use mixed precision: 0, scaler: 1.000000, use cuda graph: 1
[13d40h56m46s][HUGECTR][INFO]: lr: 0.001000, warmup_steps: 1, decay_start: 0, decay_steps: 1, decay_power: 2.000000, end_lr: 0.000000
[13d40h56m46s][HUGECTR][INFO]: Training source file: ./train/_file_list.txt
[13d40h56m46s][HUGECTR][INFO]: Evaluation source file: ./val/_file_list.txt
[13d40h56m53s][HUGECTR][INFO]: Iter: 1000 Time(1000 iters): 6.920620s Loss: 0.083037 lr:0.001000
[13d40h57m00s][HUGECTR][INFO]: Iter: 2000 Time(1000 iters): 6.706585s Loss: 0.120559 lr:0.001000
[13d40h57m60s][HUGECTR][INFO]: Iter: 3000 Time(1000 iters): 6.699129s Loss: 0.117169 lr:0.001000
[13d40h57m13s][HUGECTR][INFO]: Iter: 4000 Time(1000 iters): 6.758591s Loss: 0.083112 lr:0.001000
[13d40h57m23s][HUGECTR][INFO]: Evaluation, AUC: 0.824140
[13d40h57m23s][HUGECTR][INFO]: Eval Time for 4000 iters: 9.483117s
[13d40h57m29s][HUGECTR][INFO]: Iter: 5000 Time(1000 iters): 16.187022s Loss: 0.131896 lr:0.001000
[13d40h57m36s][HUGECTR][INFO]: Iter: 6000 Time(1000 iters): 6.748882s Loss: 0.082966 lr:0.001000
[13d40h57m43s][HUGECTR][INFO]: Iter: 7000 Time(1000 iters): 6.761953s Loss: 0.091929 lr:0.001000
[13d40h57m50s][HUGECTR][INFO]: Iter: 8000 Time(1000 iters): 6.874048s Loss: 0.080763 lr:0.001000
[13d40h57m59s][HUGECTR][INFO]: Evaluation, AUC: 0.826269
[13d40h57m59s][HUGECTR][INFO]: Eval Time for 4000 iters: 9.275068s
[13d40h58m60s][HUGECTR][INFO]: Iter: 9000 Time(1000 iters): 15.969286s Loss: 0.088093 lr:0.001000
[13d40h58m12s][HUGECTR][INFO]: Iter: 10000 Time(1000 iters): 6.652935s Loss: 0.137476 lr:0.001000
[13d40h58m19s][HUGECTR][INFO]: Iter: 11000 Time(1000 iters): 6.751184s Loss: 0.116295 lr:0.001000
[13d40h58m26s][HUGECTR][INFO]: Iter: 12000 Time(1000 iters): 6.659960s Loss: 0.151319 lr:0.001000
[13d40h58m35s][HUGECTR][INFO]: Evaluation, AUC: 0.827362
[13d40h58m35s][HUGECTR][INFO]: Eval Time for 4000 iters: 9.378966s
[13d40h58m42s][HUGECTR][INFO]: Iter: 13000 Time(1000 iters): 16.001544s Loss: 0.094625 lr:0.001000
[13d40h58m48s][HUGECTR][INFO]: Iter: 14000 Time(1000 iters): 6.678430s Loss: 0.121618 lr:0.001000
[13d40h58m55s][HUGECTR][INFO]: Iter: 15000 Time(1000 iters): 6.840206s Loss: 0.083302 lr:0.001000
[13d40h59m20s][HUGECTR][INFO]: Iter: 16000 Time(1000 iters): 6.489092s Loss: 0.102394 lr:0.001000
[13d40h59m11s][HUGECTR][INFO]: Evaluation, AUC: 0.829899
[13d40h59m11s][HUGECTR][INFO]: Eval Time for 4000 iters: 9.338721s
[13d40h59m18s][HUGECTR][INFO]: Iter: 17000 Time(1000 iters): 15.868251s Loss: 0.108997 lr:0.001000
[13d40h59m24s][HUGECTR][INFO]: Iter: 18000 Time(1000 iters): 5.960831s Loss: 0.098293 lr:0.001000
[13d40h59m29s][HUGECTR][INFO]: Iter: 19000 Time(1000 iters): 5.980448s Loss: 0.071080 lr:0.001000
[13d40h59m35s][HUGECTR][INFO]: Iter: 20000 Time(1000 iters): 5.984280s Loss: 0.115342 lr:0.001000
[13d40h59m45s][HUGECTR][INFO]: Evaluation, AUC: 0.828875
[13d40h59m45s][HUGECTR][INFO]: Eval Time for 4000 iters: 9.337684s
[13d40h59m45s][HUGECTR][INFO]: Rank0: Write hash table to file
[13d40h59m45s][HUGECTR][INFO]: Rank0: Write hash table to file
[13d40h59m45s][HUGECTR][INFO]: Dumping sparse weights to files, successful
[13d40h59m45s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[13d40h59m45s][HUGECTR][INFO]: Done
[13d40h59m45s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[13d40h59m45s][HUGECTR][INFO]: Done
[13d40h59m45s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[13d40h59m45s][HUGECTR][INFO]: Done
[13d40h59m45s][HUGECTR][INFO]: Rank0: Write optimzer state to file
[13d40h59m45s][HUGECTR][INFO]: Done
[13d40h59m45s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful
[13d40h59m45s][HUGECTR][INFO]: Dumping dense weights to file, successful
[13d40h59m45s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful
[13d40h59m45s][HUGECTR][INFO]: Dumping untrainable weights to file, successful
[13d40h59m51s][HUGECTR][INFO]: Save the model graph to wdl.json, successful
!ls -ll
total 301620
-rw-rw-r-- 1 1025 1025     49824 Jul  6 07:00 HugeCTR_WDL_Training.ipynb
drwxr-xr-x 2 root root      4096 Jul  5 05:43 categories
drwxr-xr-x 3 root root      4096 Jul  5 05:44 dask-worker-space
-rw-r--r-- 1 root root      5539 Jul  6 07:00 model.py
-rw-r--r-- 1 root root     14265 Jul  6 06:59 preprocess.py
drwxr-xr-x 3 root root      4096 Jul  5 23:34 train
drwxr-xr-x 3 root root      4096 Jul  5 05:44 val
-rw-r--r-- 1 root root  17108704 Jul  6 03:28 wdl0_opt_sparse_20000.model
drwxr-xr-x 2 root root      4096 Jul  5 06:32 wdl0_sparse_20000.model
-rw-r--r-- 1 root root 273739264 Jul  6 03:28 wdl1_opt_sparse_20000.model
drwxr-xr-x 2 root root      4096 Jul  5 06:32 wdl1_sparse_20000.model
-rw-r--r-- 1 root root   5963780 Jul  6 03:28 wdl_dense_20000.model
-rw-r--r-- 1 root root      3158 Jul  6 03:28 wdl_infer.json
-rw-r--r-- 1 root root  11927560 Jul  6 03:28 wdl_opt_dense_20000.model
!python /wdl_infer/wdl_python_infer.py "wdl" "/wdl_infer/model/wdl/1/wdl.json" \
"/wdl_infer/model/wdl/1/wdl_dense_20000.model" \
"/wdl_infer/model/wdl/1/wdl0_sparse_20000.model/,/wdl_infer/model/wdl/1/wdl1_sparse_20000.model" \
"/wdl_infer/first_ten.csv"
['/wdl_infer/model/wdl/1/wdl0_sparse_20000.model/', '/wdl_infer/model/wdl/1/wdl1_sparse_20000.model']
[14d04h23m54s][HUGECTR][INFO]: default_emb_vec_value is not specified using default: 0.000000
[14d04h23m54s][HUGECTR][INFO]: default_emb_vec_value is not specified using default: 0.000000
Local RocksDB is initializing the embedding table: wdl0
Last Iteration insert successfully
Local RocksDB is initializing the embedding table: wdl1
Last Iteration insert successfully
[14d04h24m08s][HUGECTR][INFO]: Global seed is 2483322206
[14d04h24m08s][HUGECTR][INFO]: Device to NUMA mapping:
  GPU 2 ->  node 1

[14d04h24m13s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.
[14d04h24m13s][HUGECTR][INFO]: Start all2all warmup
[14d04h24m13s][HUGECTR][INFO]: End all2all warmup
[14d04h24m13s][HUGECTR][INFO]: Use mixed precision: 0
[14d04h24m13s][HUGECTR][INFO]: start create embedding for inference
[14d04h24m13s][HUGECTR][INFO]: sparse_input name wide_data
[14d04h24m13s][HUGECTR][INFO]: sparse_input name deep_data
[14d04h24m13s][HUGECTR][INFO]: create embedding for inference success
[14d04h24m13s][HUGECTR][INFO]: Inference stage skip BinaryCrossEntropyLoss layer, replaced by Sigmoid layer
Rocksdb gets missing keys from model: wdl and table: 0
Rocksdb gets missing keys from model: wdl and table: 1
WDL multi-embedding table inference result is [0.2726621925830841, 0.16786302626132965, 0.06844793260097504, 0.21687281131744385, 0.28839486837387085, 0.09961184859275818, 0.1451544463634491, 0.1859627217054367, 0.1754387617111206, 0.14994166791439056]
[HUGECTR][INFO] WDL multi-embedding table inference using GPU cache, prediction error is less  than threshold:0.0001, error is 1.1102230246251565e-16

Prepare Inference Request

!ls -l /wdl_train/val
total 637376
-rw-r--r-- 1 root root 142856977 Jul  5 05:44 0.110d099942694a5cbf1b71eb73e10f27.parquet
-rw-r--r-- 1 root root        51 Jul  6 07:02 _file_list.txt
-rw-r--r-- 1 root root     27701 Jul  5 05:44 _metadata
-rw-r--r-- 1 root root      1537 Jul  5 05:44 _metadata.json
drwxr-xr-x 2 root root      4096 Jul  5 05:42 temp-parquet-after-conversion
-rw-r--r-- 1 1025 1025 509766965 Jul  5 04:45 test.txt
import pandas as pd
# Load the preprocessed validation Parquet file and preview its first rows.
val_parquet = "/wdl_train/val/0.110d099942694a5cbf1b71eb73e10f27.parquet"
df = pd.read_parquet(val_parquet)
df.head()
I1 I2 I3 I4 I5 I6 I7 I8 I9 I10 ... C17 C18 C19 C20 C21 C22 C23 C24 C25 C26
0 0.061161 0.974006 -0.594327 -0.157301 -0.224758 0.618222 -0.064249 -0.281810 -0.760031 1.386036 ... 2 666 1 33722 24373 91481 62242 7673 44 28
1 -0.061206 -0.437431 0.156849 -0.146861 -0.193763 0.893091 -0.064249 0.286841 -0.109336 3.242455 ... 1 666 10 0 97438 0 21446 4472 56 19
2 0.043427 -0.464600 -0.379705 -0.120014 0.054203 -0.206385 -0.064249 -0.093999 -0.543133 -0.470383 ... 1 575 10 0 46601 0 12090 540 10 17
3 -0.059432 -0.273058 -0.487016 -0.143878 -0.193763 -0.206385 -0.064249 -0.279201 -0.109336 -0.470383 ... 0 351 10 125237 4329 238309 0 8488 56 22
4 -0.048792 -0.418412 0.693403 0.300589 -0.193763 -0.206385 -0.064249 -0.281810 0.902856 -0.470383 ... 0 575 7 69747 76381 207280 0 444 73 22

5 rows × 42 columns

# Save the first 10 validation rows, with a header row and without the index,
# as the request file that wdl_predict.py later reads from /wdl_train/infer_test.csv.
df.head(10).to_csv(
    "/wdl_train/infer_test.csv",
    sep=",",
    index=False,
    header=True,
)

Create prediction scripts

%%writefile '/wdl_train/wdl_predict.py'
from hugectr.inference import InferenceParams, CreateInferenceSession
import hugectr
import pandas as pd
import numpy as np
import sys
from mpi4py import MPI
def wdl_inference(model_name, network_file, dense_file, embedding_file_list, data_file,
                  enable_cache, dbtype=hugectr.Database_t.Local, rocksdb_path="",
                  device_id=2):
    """Run HugeCTR inference for the WDL model on every sample in a CSV file.

    Args:
        model_name: Model name passed to ``InferenceParams``.
        network_file: Path of the inference network configuration (e.g. ``wdl.json``).
        dense_file: Path of the dense model weights file.
        embedding_file_list: Sparse model paths, one per embedding table.
        data_file: CSV file with a header row containing the continuous columns
            ``I1``..``I13`` and the categorical columns ``C1``..``C26``,
            ``C1_C2`` and ``C3_C4``.
        enable_cache: Whether to use the GPU embedding cache.
        dbtype: Parameter-server backend, a ``hugectr.Database_t`` value.
        rocksdb_path: Directory passed to ``InferenceParams`` as ``rocksdb_path``.
        device_id: GPU used for inference (defaults to 2).

    Returns:
        The value returned by ``inference_session.predict``; it is also printed.
    """
    # 26 single-feature slots followed by the 2 crossed slots; this order must
    # match emb_size below.
    deep_columns = ["C" + str(x) for x in range(1, 27)]
    wide_columns = ["C1_C2", "C3_C4"]
    CATEGORICAL_COLUMNS = deep_columns + wide_columns
    CONTINUOUS_COLUMNS = ["I" + str(x) for x in range(1, 14)]
    # Vocabulary size of each categorical column, in CATEGORICAL_COLUMNS order.
    emb_size = [249058, 19561, 14212, 6890, 18592, 4, 6356, 1254, 52, 226170, 80508, 72308, 11, 2169, 7597, 61, 4, 923, 15, 249619, 168974, 243480, 68212, 9169, 75, 34, 278018, 415262]
    # Column i's keys are offset by sum(emb_size[:i]) so that all columns share
    # one non-overlapping key space.
    shift = np.insert(np.cumsum(emb_size), 0, 0)[:-1]

    test_df = pd.read_csv(data_file, sep=',')
    num_samples = len(test_df)

    # Every slot holds exactly one key per sample, so a sparse input with S
    # slots has row pointers 0..num_samples*S. The wide input (2 slots) comes
    # first, then the deep input (26 slots). For 10 samples this yields
    # range(0, 21) + range(0, 261).
    row_ptrs = (list(range(num_samples * len(wide_columns) + 1))
                + list(range(num_samples * len(deep_columns) + 1)))
    dense_features = list(test_df[CONTINUOUS_COLUMNS].values.flatten())
    # Cast the keys to int64 to match i64_input_key=True below.
    categorical = test_df[CATEGORICAL_COLUMNS].astype(np.int64)
    embedding_columns = list((categorical + shift).values.flatten())

    # create parameter server, embedding cache and inference session
    # NOTE(review): max_batchsize is 64, so data_file presumably must not hold
    # more than 64 rows — confirm against HugeCTR's predict() limits.
    inference_params = InferenceParams(model_name = model_name,
                                max_batchsize = 64,
                                hit_rate_threshold = 0.5,
                                dense_model_file = dense_file,
                                sparse_model_files = embedding_file_list,
                                device_id = device_id,
                                use_gpu_embedding_cache = enable_cache,
                                cache_size_percentage = 0.9,
                                i64_input_key = True,
                                use_mixed_precision = False,
                                db_type = dbtype,
                                rocksdb_path=rocksdb_path,
                                cache_size_percentage_redis=0.5)
    inference_session = CreateInferenceSession(network_file, inference_params)
    output = inference_session.predict(dense_features, embedding_columns, row_ptrs)
    print("WDL multi-embedding table inference result is {}".format(output))
    return output

if __name__ == "__main__":
    # Command line:
    #   wdl_predict.py MODEL_NAME NETWORK_FILE DENSE_FILE SPARSE_FILES DATA_FILE {local|rocksdb} [ROCKSDB_PATH]
    # SPARSE_FILES is a comma-separated list with one path per embedding table;
    # ROCKSDB_PATH is required only for the "rocksdb" backend.
    if len(sys.argv) < 7:
        sys.exit("usage: {} MODEL_NAME NETWORK_FILE DENSE_FILE SPARSE_FILES DATA_FILE "
                 "{{local|rocksdb}} [ROCKSDB_PATH]".format(sys.argv[0]))
    model_name = sys.argv[1]
    print("{} multi-embedding table prediction".format(model_name))
    network_file = sys.argv[2]
    print("{} multi-embedding table prediction network is {}".format(model_name,network_file))
    dense_file = sys.argv[3]
    print("{} multi-embedding table prediction dense file is {}".format(model_name,dense_file))
    embedding_file_list = str(sys.argv[4]).split(',')
    print("{} multi-embedding table prediction sparse files are {}".format(model_name,embedding_file_list))
    data_file = sys.argv[5]
    print("{} multi-embedding table prediction input data path is {}".format(model_name,data_file))
    input_dbtype = sys.argv[6]
    print("{} multi-embedding table prediction input dbtype path is {}".format(model_name,input_dbtype))
    if input_dbtype == "local":
        wdl_inference(model_name, network_file, dense_file, embedding_file_list, data_file, True, hugectr.Database_t.Local)
    elif input_dbtype == "rocksdb":
        if len(sys.argv) < 8:
            sys.exit("dbtype 'rocksdb' requires ROCKSDB_PATH as the 7th argument")
        rocksdb_path = sys.argv[7]
        print("{} multi-embedding table prediction rocksdb_path path is {}".format(model_name,rocksdb_path))
        wdl_inference(model_name, network_file, dense_file, embedding_file_list, data_file, True, hugectr.Database_t.RocksDB,rocksdb_path)
    else:
        # Previously an unknown backend silently produced no prediction at all.
        sys.exit("unsupported dbtype {!r}: expected 'local' or 'rocksdb'".format(input_dbtype))
Overwriting /wdl_train/wdl_predict.py

Prediction

Use different types of local databases as the parameter server to obtain prediction results from the wide and deep model.

Load the model embedding tables into local memory as the parameter server

!python /wdl_train/wdl_predict.py "wdl" "/wdl_infer/model/wdl/1/wdl.json" "/wdl_infer/model/wdl/1/wdl_dense_20000.model" "/wdl_infer/model/wdl/1/wdl0_sparse_20000.model/,/wdl_infer/model/wdl/1/wdl1_sparse_20000.model" "/wdl_train/infer_test.csv" "local"
wdl multi-embedding table prediction
wdl multi-embedding table prediction network is /wdl_infer/model/wdl/1/wdl.json
wdl multi-embedding table prediction dense file is /wdl_infer/model/wdl/1/wdl_dense_20000.model
wdl multi-embedding table prediction sparse files are ['/wdl_infer/model/wdl/1/wdl0_sparse_20000.model/', '/wdl_infer/model/wdl/1/wdl1_sparse_20000.model']
wdl multi-embedding table prediction input data path is /wdl_train/infer_test.csv
wdl multi-embedding table prediction input dbtype path is local
[15d09h17m26s][HUGECTR][INFO]: default_emb_vec_value is not specified using default: 0.000000
[15d09h17m26s][HUGECTR][INFO]: default_emb_vec_value is not specified using default: 0.000000
[15d09h17m32s][HUGECTR][INFO]: Global seed is 1361897547
[15d09h17m32s][HUGECTR][INFO]: Device to NUMA mapping:
  GPU 2 ->  node 1

[15d09h17m48s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.
[15d09h17m48s][HUGECTR][INFO]: Start all2all warmup
[15d09h17m48s][HUGECTR][INFO]: End all2all warmup
[15d09h17m48s][HUGECTR][INFO]: Use mixed precision: 0
[15d09h17m48s][HUGECTR][INFO]: start create embedding for inference
[15d09h17m48s][HUGECTR][INFO]: sparse_input name wide_data
[15d09h17m48s][HUGECTR][INFO]: sparse_input name deep_data
[15d09h17m48s][HUGECTR][INFO]: create embedding for inference success
[15d09h17m48s][HUGECTR][INFO]: Inference stage skip BinaryCrossEntropyLoss layer, replaced by Sigmoid layer
WDL multi-embedding table inference result is [0.019959857687354088, 0.025274723768234253, 0.017903145402669907, 0.006932722870260477, 0.02339070290327072, 0.022747302427887917, 0.05989734083414078, 0.015981541946530342, 0.005822415463626385, 0.01423134095966816]

Load the model embedding tables into a local RocksDB database as the parameter server

Create a RocksDB directory with read and write permissions for storing the model embedding tables.

!mkdir -p -m 700 /wdl_train/rocksdb
!python /wdl_train/wdl_predict.py "wdl" "/wdl_infer/model/wdl/1/wdl.json" \
"/wdl_infer/model/wdl/1/wdl_dense_20000.model" \
"/wdl_infer/model/wdl/1/wdl0_sparse_20000.model/,/wdl_infer/model/wdl/1/wdl1_sparse_20000.model" \
"/wdl_train/infer_test.csv" \
"rocksdb"  "/wdl_train/rocksdb"
wdl multi-embedding table prediction
wdl multi-embedding table prediction network is /wdl_infer/model/wdl/1/wdl.json
wdl multi-embedding table prediction dense file is /wdl_infer/model/wdl/1/wdl_dense_20000.model
wdl multi-embedding table prediction sparse files are ['/wdl_infer/model/wdl/1/wdl0_sparse_20000.model/', '/wdl_infer/model/wdl/1/wdl1_sparse_20000.model']
wdl multi-embedding table prediction input data path is /wdl_train/infer_test.csv
wdl multi-embedding table prediction input dbtype path is rocksdb
wdl multi-embedding table prediction rocksdb_path path is /wdl_train/rocksdb
[15d12h32m00s][HUGECTR][INFO]: default_emb_vec_value is not specified using default: 0.000000
[15d12h32m00s][HUGECTR][INFO]: default_emb_vec_value is not specified using default: 0.000000
Local RocksDB is initializing the embedding table: wdl0
Last Iteration insert successfully
Local RocksDB is initializing the embedding table: wdl1
Last Iteration insert successfully
[15d12h32m07s][HUGECTR][INFO]: Global seed is 1156574989
[15d12h32m08s][HUGECTR][INFO]: Device to NUMA mapping:
  GPU 2 ->  node 1

[15d12h32m21s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.
[15d12h32m21s][HUGECTR][INFO]: Start all2all warmup
[15d12h32m21s][HUGECTR][INFO]: End all2all warmup
[15d12h32m21s][HUGECTR][INFO]: Use mixed precision: 0
[15d12h32m21s][HUGECTR][INFO]: start create embedding for inference
[15d12h32m21s][HUGECTR][INFO]: sparse_input name wide_data
[15d12h32m21s][HUGECTR][INFO]: sparse_input name deep_data
[15d12h32m21s][HUGECTR][INFO]: create embedding for inference success
[15d12h32m21s][HUGECTR][INFO]: Inference stage skip BinaryCrossEntropyLoss layer, replaced by Sigmoid layer
Rocksdb gets missing keys from model: wdl and table: 0
Rocksdb gets missing keys from model: wdl and table: 1
WDL multi-embedding table inference result is [0.019959857687354088, 0.025274723768234253, 0.017903145402669907, 0.006932722870260477, 0.02339070290327072, 0.022747302427887917, 0.05989734083414078, 0.015981541946530342, 0.005822415463626385, 0.01423134095966816]