# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
http://developer.download.nvidia.com/notebooks/dlsw-notebooks/merlin_hugectr_hps-hps-table-fusion-demo/nvidia_logo.png

HPS Table Fusion Demo

Overview

This notebook demonstrates how to fuse embedding tables of the same embedding vector size with the HPS plugin for TensorFlow. It is recommended to run hierarchical_parameter_server_demo.ipynb before diving into this notebook.

For more details about HPS APIs, please refer to HPS APIs. For more details about HPS, please refer to HugeCTR Hierarchical Parameter Server (HPS).

Installation

Get HPS from NGC

The HPS Python module is preinstalled in the 23.12 and later Merlin TensorFlow Container: nvcr.io/nvidia/merlin/merlin-tensorflow:23.12.

You can check that the required libraries are available by running the following command after launching this container.

$ python3 -c "import hierarchical_parameter_server as hps"

Create TF SavedModel

First of all, we specify the required configurations, e.g., the arguments needed for generating the embedding tables and the template HPS JSON configuration file. In this notebook we use a naive deep neural network (DNN) model that has 8 embedding tables of the same embedding vector size and one fully connected layer.

We define the model with hps.LookupLayer and some native TF layers, and then save it in the SavedModel format. Please note that table fusion is turned off here by setting fuse_embedding_table to False.

%%writefile create_model_for_table_fusion.py

import hierarchical_parameter_server as hps
import tensorflow as tf
import os
import numpy as np
import struct
import json
import pytest
import time

# Demo-wide settings shared by the table generator, the HPS configuration
# and the Keras model defined in this script.
NUM_GPUS = 1
VOCAB_SIZE = 10000  # number of keys stored in each embedding table
EMB_VEC_SIZE = 128  # length of every embedding vector
NUM_QUERY_KEY = 26  # keys looked up per table for each sample
EMB_VEC_DTYPE = np.float32
TF_KEY_TYPE = tf.int32
MAX_BATCH_SIZE = 256
NUM_ITERS = 100
NUM_TABLES = 8
USE_CONTEXT_STREAM = True

# Expose only the first NUM_GPUS devices to TensorFlow.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(NUM_GPUS)))

gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
tf.config.threading.set_inter_op_parallelism_threads(1)

# Template HPS configuration. The per-table lists that start out empty are
# filled in by set_up_model_files(); create_savedmodel() overrides
# "fuse_embedding_table" right before each time the JSON file is written.
hps_config = {
    "supportlonglong": False,
    "fuse_embedding_table": True,
    "models": [
        {
            "model": str(NUM_TABLES) + "_table",
            "sparse_files": [],
            "num_of_worker_buffer_in_pool": NUM_TABLES,
            "embedding_table_names": [],
            "embedding_vecsize_per_table": [],
            "maxnum_catfeature_query_per_table_per_sample": [],
            # One default value per table; a shorter list makes HPS log
            # "default_value_for_each_table.size() is not equal to the
            # number of embedding tables".
            "default_value_for_each_table": [0.0] * NUM_TABLES,
            "deployed_device_list": [0],
            "max_batch_size": MAX_BATCH_SIZE,
            "cache_refresh_percentage_per_iteration": 1.0,
            "hit_rate_threshold": 1.0,
            "gpucacheper": 1.0,
            "gpucache": True,
            "embedding_cache_type": "dynamic",
            # Driven by the module-level switch instead of a second literal.
            "use_context_stream": USE_CONTEXT_STREAM,
        }
    ],
}

def generate_embedding_tables(hugectr_sparse_model, vocab_range, embedding_vec_size):
    """Write a synthetic embedding table as raw ``key`` and ``emb_vector`` files.

    Every key in the range is paired with the same constant vector whose
    components are all ``0.00025`` (stored as float32).

    Args:
        hugectr_sparse_model: Directory that receives the two files; it is
            created (including parents) when it does not exist yet.
        vocab_range: Two-element sequence ``[start, stop]``; one record is
            written for each integer key in ``range(start, stop)``.
        embedding_vec_size: Number of float32 components per vector.
    """
    # Create the directory without going through a shell, so paths that
    # contain spaces or shell metacharacters work and failures raise.
    os.makedirs(hugectr_sparse_model, exist_ok=True)

    # The vector is identical for all keys, so build and pack it only once.
    vec = 0.00025 * np.ones((embedding_vec_size,)).astype(np.float32)
    vec_struct = struct.pack(str(embedding_vec_size) + "f", *vec)

    key_path = os.path.join(hugectr_sparse_model, "key")
    vec_path = os.path.join(hugectr_sparse_model, "emb_vector")
    with open(key_path, "wb") as key_file, open(vec_path, "wb") as vec_file:
        for key in range(vocab_range[0], vocab_range[1]):
            # Keys are stored as native 8-byte signed integers ("q").
            key_file.write(struct.pack("q", key))
            vec_file.write(vec_struct)


def set_up_model_files():
    """Generate the files of every embedding table and register them in ``hps_config``.

    For each of the ``NUM_TABLES`` tables this writes the table under
    ``embeddings/table<i>`` with the disjoint key range
    ``[i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE)`` and appends its path, name,
    vector size and per-sample query count to the first model entry of the
    module-level ``hps_config``.

    Returns:
        The (mutated) module-level ``hps_config`` dictionary.
    """
    model_cfg = hps_config["models"][0]
    for table_idx in range(NUM_TABLES):
        name = f"table{table_idx}"
        directory = f"embeddings/{name}"
        first_key = table_idx * VOCAB_SIZE

        generate_embedding_tables(
            directory, [first_key, first_key + VOCAB_SIZE], EMB_VEC_SIZE
        )

        model_cfg["sparse_files"].append(directory)
        model_cfg["embedding_table_names"].append(name)
        model_cfg["embedding_vecsize_per_table"].append(EMB_VEC_SIZE)
        model_cfg["maxnum_catfeature_query_per_table_per_sample"].append(NUM_QUERY_KEY)
    return hps_config

class InferenceModel(tf.keras.models.Model):
    """Keras model with one ``hps.LookupLayer`` per table feeding a single dense layer."""

    def __init__(self, num_tables, **kwargs):
        """Create ``num_tables`` lookup layers and the final one-unit dense layer.

        Args:
            num_tables: Number of lookup layers to create; layer ``i`` reads
                table id ``i``.
            **kwargs: Forwarded to ``tf.keras.models.Model``.
        """
        super().__init__(**kwargs)
        # Model name and configuration file are derived from the module-level
        # NUM_TABLES, matching the names used when the JSON file is written.
        self.lookup_layers = [
            hps.LookupLayer(
                model_name=str(NUM_TABLES) + "_table",
                table_id=table_id,
                emb_vec_size=EMB_VEC_SIZE,
                emb_vec_dtype=EMB_VEC_DTYPE,
                ps_config_file=str(NUM_TABLES) + "_table.json",
                global_batch_size=MAX_BATCH_SIZE,
                name="embedding_lookup" + str(table_id),
            )
            for table_id in range(num_tables)
        ]
        self.fc = tf.keras.layers.Dense(
            units=1,
            activation=None,
            kernel_initializer="ones",
            bias_initializer="zeros",
            name="fc",
        )

    def call(self, inputs):
        """Look up each input in its table, flatten, concatenate and apply ``fc``.

        Args:
            inputs: Sequence holding one key tensor per lookup layer.

        Returns:
            The output of the dense layer applied to the concatenated,
            flattened embeddings.
        """
        assert len(inputs) == len(self.lookup_layers)
        flat_width = NUM_QUERY_KEY * EMB_VEC_SIZE
        flattened = [
            tf.reshape(self.lookup_layers[idx](inputs[idx]), shape=[-1, flat_width])
            for idx in range(len(inputs))
        ]
        return self.fc(tf.concat(flattened, axis=1))

    def summary(self):
        """Build a functional model over symbolic inputs and return its ``summary()``."""
        symbolic_inputs = [
            tf.keras.Input(shape=(NUM_QUERY_KEY,), dtype=TF_KEY_TYPE)
            for _ in range(len(self.lookup_layers))
        ]
        functional = tf.keras.models.Model(
            inputs=symbolic_inputs, outputs=self.call(symbolic_inputs)
        )
        return functional.summary()


def create_savedmodel(hps_config):
    """Export the demo model as a SavedModel and leave a fusion-enabled JSON behind.

    The JSON configuration is first written with ``fuse_embedding_table``
    set to ``False``; the model is then instantiated, summarized, called once
    on random keys and saved. Finally the JSON file is overwritten with
    ``fuse_embedding_table`` set to ``True``.

    Args:
        hps_config: HPS configuration dictionary; its
            ``"fuse_embedding_table"`` entry is modified in place.
    """
    config_path = str(NUM_TABLES) + "_table.json"

    def write_config(fuse_tables):
        # Overwrite JSON configuration file
        hps_config["fuse_embedding_table"] = fuse_tables
        serialized = json.dumps(hps_config, indent=4)
        with open(config_path, "w") as config_file:
            config_file.write(serialized)

    write_config(False)

    model = InferenceModel(NUM_TABLES)
    model.summary()
    # One batch of random keys per table, drawn from that table's own range.
    sample_inputs = [
        np.random.randint(
            table_idx * VOCAB_SIZE,
            (table_idx + 1) * VOCAB_SIZE,
            (MAX_BATCH_SIZE, NUM_QUERY_KEY),
        ).astype(np.int32)
        for table_idx in range(NUM_TABLES)
    ]
    model(sample_inputs)
    model.save(str(NUM_TABLES) + "_table.savedmodel")

    write_config(True)

if __name__ == "__main__":
    # Generate the table files (which also completes the HPS configuration),
    # then build and export the model with that configuration.
    create_savedmodel(set_up_model_files())
Writing create_model_for_table_fusion.py
import os
# Run the script written above in a separate Python process. The value
# returned by os.system appears as the last line of the recorded output (0).
os.system("python3 create_model_for_table_fusion.py")
2023-03-29 07:24:28.206281: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-29 07:24:36.420084: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-29 07:24:36.926162: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1637] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30996 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
[INFO] hierarchical_parameter_server is imported
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 26)]         0           []                               
                                                                                                  
 input_2 (InputLayer)           [(None, 26)]         0           []                               
                                                                                                  
 input_3 (InputLayer)           [(None, 26)]         0           []                               
                                                                                                  
 input_4 (InputLayer)           [(None, 26)]         0           []                               
                                                                                                  
 input_5 (InputLayer)           [(None, 26)]         0           []                               
                                                                                                  
 input_6 (InputLayer)           [(None, 26)]         0           []                               
                                                                                                  
 input_7 (InputLayer)           [(None, 26)]         0           []                               
                                                                                                  
 input_8 (InputLayer)           [(None, 26)]         0           []                               
                                                                                                  
 embedding_lookup0 (LookupLayer  (None, 26, 128)     0           ['input_1[0][0]']                
 )                                                                                                
                                                                                                  
 embedding_lookup1 (LookupLayer  (None, 26, 128)     0           ['input_2[0][0]']                
 )                                                                                                
                                                                                                  
 embedding_lookup2 (LookupLayer  (None, 26, 128)     0           ['input_3[0][0]']                
 )                                                                                                
                                                                                                  
 embedding_lookup3 (LookupLayer  (None, 26, 128)     0           ['input_4[0][0]']                
 )                                                                                                
                                                                                                  
 embedding_lookup4 (LookupLayer  (None, 26, 128)     0           ['input_5[0][0]']                
 )                                                                                                
                                                                                                  
 embedding_lookup5 (LookupLayer  (None, 26, 128)     0           ['input_6[0][0]']                
 )                                                                                                
                                                                                                  
 embedding_lookup6 (LookupLayer  (None, 26, 128)     0           ['input_7[0][0]']                
 )                                                                                                
                                                                                                  
 embedding_lookup7 (LookupLayer  (None, 26, 128)     0           ['input_8[0][0]']                
 )                                                                                                
                                                                                                  
 tf.reshape (TFOpLambda)        (None, 3328)         0           ['embedding_lookup0[0][0]']      
                                                                                                  
 tf.reshape_1 (TFOpLambda)      (None, 3328)         0           ['embedding_lookup1[0][0]']      
                                                                                                  
 tf.reshape_2 (TFOpLambda)      (None, 3328)         0           ['embedding_lookup2[0][0]']      
                                                                                                  
 tf.reshape_3 (TFOpLambda)      (None, 3328)         0           ['embedding_lookup3[0][0]']      
                                                                                                  
 tf.reshape_4 (TFOpLambda)      (None, 3328)         0           ['embedding_lookup4[0][0]']      
                                                                                                  
 tf.reshape_5 (TFOpLambda)      (None, 3328)         0           ['embedding_lookup5[0][0]']      
                                                                                                  
 tf.reshape_6 (TFOpLambda)      (None, 3328)         0           ['embedding_lookup6[0][0]']      
                                                                                                  
 tf.reshape_7 (TFOpLambda)      (None, 3328)         0           ['embedding_lookup7[0][0]']      
                                                                                                  
 tf.concat (TFOpLambda)         (None, 26624)        0           ['tf.reshape[0][0]',             
                                                                  'tf.reshape_1[0][0]',           
                                                                  'tf.reshape_2[0][0]',           
                                                                  'tf.reshape_3[0][0]',           
                                                                  'tf.reshape_4[0][0]',           
                                                                  'tf.reshape_5[0][0]',           
                                                                  'tf.reshape_6[0][0]',           
                                                                  'tf.reshape_7[0][0]']           
                                                                                                  
 fc (Dense)                     (None, 1)            26625       ['tf.concat[0][0]']              
                                                                                                  
==================================================================================================
Total params: 26,625
Trainable params: 26,625
Non-trainable params: 0
__________________________________________________________________________________________________
=====================================================HPS Parse====================================================
[HCTR][07:24:38.079][INFO][RK0][main]: dense_file is not specified using default: 
[HCTR][07:24:38.079][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
[HCTR][07:24:38.079][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
[HCTR][07:24:38.079][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][07:24:38.079][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][07:24:38.079][INFO][RK0][main]: refresh_interval is not specified using default: 0
[HCTR][07:24:38.079][INFO][RK0][main]: use_static_table is not specified using default: 0
[HCTR][07:24:38.079][INFO][RK0][main]: HPS plugin uses context stream for model 8_table: True
====================================================HPS Create====================================================
[HCTR][07:24:38.080][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][07:24:38.080][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][07:24:38.080][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][07:24:38.080][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][07:24:38.080][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][07:24:38.547][INFO][RK0][main]: Table: hps_et.8_table.table0; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:39.379][INFO][RK0][main]: Table: hps_et.8_table.table1; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:39.830][INFO][RK0][main]: Table: hps_et.8_table.table2; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:40.448][INFO][RK0][main]: Table: hps_et.8_table.table3; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:40.899][INFO][RK0][main]: Table: hps_et.8_table.table4; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:41.934][INFO][RK0][main]: Table: hps_et.8_table.table5; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:43.097][INFO][RK0][main]: Table: hps_et.8_table.table6; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:45.296][INFO][RK0][main]: Table: hps_et.8_table.table7; cached 10000 / 10000 embeddings in volatile database (HashMapBackend); load: 10000 / 18446744073709551615 (0.00%).
[HCTR][07:24:45.296][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][07:24:45.297][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][07:24:45.306][INFO][RK0][main]: Model name: 8_table
[HCTR][07:24:45.306][INFO][RK0][main]: Max batch size: 256
[HCTR][07:24:45.306][INFO][RK0][main]: Fuse embedding tables: False
[HCTR][07:24:45.306][INFO][RK0][main]: Number of embedding tables: 8
[HCTR][07:24:45.306][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][07:24:45.306][INFO][RK0][main]: Embedding cache type: dynamic
[HCTR][07:24:45.306][INFO][RK0][main]: Use I64 input key: False
[HCTR][07:24:45.306][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][07:24:45.306][INFO][RK0][main]: The size of thread pool: 80
[HCTR][07:24:45.306][INFO][RK0][main]: The size of worker memory pool: 8
[HCTR][07:24:45.306][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][07:24:45.306][INFO][RK0][main]: The refresh percentage : 1.000000
[HCTR][07:24:45.469][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table0
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table1
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table2
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table3
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table4
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table5
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table6
[HCTR][07:24:45.470][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.table7
[HCTR][07:24:45.475][INFO][RK0][main]: LookupSession i64_input_key: False
[HCTR][07:24:45.475][INFO][RK0][main]: Creating lookup session for 8_table on device: 0
0

Make inference with HPS table fusion

We load the TF SavedModel and make inference for several batches. The table fusion is enabled since fuse_embedding_table is True within the HPS JSON configuration file.

import hierarchical_parameter_server as hps
import tensorflow as tf
import os
import numpy as np
import time

# These values mirror the ones used in create_model_for_table_fusion.py.
NUM_GPUS = 1
VOCAB_SIZE = 10000
NUM_QUERY_KEY = 26
MAX_BATCH_SIZE = 256
NUM_ITERS = 100
NUM_TABLES = 8

# Expose only the first NUM_GPUS devices to TensorFlow.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(NUM_GPUS)))

gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
tf.config.threading.set_inter_op_parallelism_threads(1)


model = tf.keras.models.load_model(str(NUM_TABLES) + "_table.savedmodel")

# Pre-generate all batches (one warm-up batch plus NUM_ITERS timed ones) so
# that random-number generation is not part of the measured interval. Keys
# for table i are drawn from [i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE).
inputs_seq = []
for _ in range(NUM_ITERS + 1):
    inputs = [
        np.random.randint(
            i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE, (MAX_BATCH_SIZE, NUM_QUERY_KEY)
        ).astype(np.int32)
        for i in range(NUM_TABLES)
    ]
    inputs_seq.append(inputs)

# Warm-up call, deliberately outside the timed region.
preds = model(inputs_seq[0])

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# short durations; time.time() can jump when the system clock is adjusted.
start = time.perf_counter()
for i in range(NUM_ITERS):
    print("-" * 20, "Step {}".format(i), "-" * 20)
    preds = model(inputs_seq[i + 1])
end = time.perf_counter()
print(
    "[INFO] Elapsed time for "
    + str(NUM_ITERS)
    + " iterations: "
    + str(end - start)
    + " seconds"
)
[INFO] hierarchical_parameter_server is imported
2023-03-29 07:25:39.918038: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-29 07:25:42.325440: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-29 07:25:42.818316: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1637] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30996 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
=====================================================HPS Parse====================================================
[HCTR][07:25:43.756][INFO][RK0][main]: Table fusion is enabled for HPS. Please ensure that there is no key value overlap in different tables and the embedding lookup layer has no dependency in the model graph. For more information, see https://nvidia-merlin.github.io/HugeCTR/main/hierarchical_parameter_server/hps_database_backend.html#configuration
[HCTR][07:25:43.756][INFO][RK0][main]: dense_file is not specified using default: 
[HCTR][07:25:43.756][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
[HCTR][07:25:43.756][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
[HCTR][07:25:43.756][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][07:25:43.756][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][07:25:43.756][INFO][RK0][main]: refresh_interval is not specified using default: 0
[HCTR][07:25:43.756][INFO][RK0][main]: use_static_table is not specified using default: 0
[HCTR][07:25:43.756][INFO][RK0][main]: HPS plugin uses context stream for model 8_table: True
====================================================HPS Create====================================================
[HCTR][07:25:43.756][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][07:25:43.756][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][07:25:43.756][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][07:25:43.756][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][07:25:43.756][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][07:25:44.292][INFO][RK0][main]: Table: hps_et.8_table.fused_embedding0; cached 80000 / 80000 embeddings in volatile database (HashMapBackend); load: 80000 / 18446744073709551615 (0.00%).
[HCTR][07:25:44.292][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][07:25:44.292][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][07:25:44.299][INFO][RK0][main]: Model name: 8_table
[HCTR][07:25:44.299][INFO][RK0][main]: Max batch size: 256
[HCTR][07:25:44.299][INFO][RK0][main]: Fuse embedding tables: True
[HCTR][07:25:44.299][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][07:25:44.299][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][07:25:44.299][INFO][RK0][main]: Embedding cache type: dynamic
[HCTR][07:25:44.299][INFO][RK0][main]: Use I64 input key: False
[HCTR][07:25:44.299][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][07:25:44.299][INFO][RK0][main]: The size of thread pool: 80
[HCTR][07:25:44.299][INFO][RK0][main]: The size of worker memory pool: 8
[HCTR][07:25:44.299][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][07:25:44.299][INFO][RK0][main]: The refresh percentage : 1.000000
[HCTR][07:25:44.406][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][07:25:44.406][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.fused_embedding0
[HCTR][07:25:44.407][INFO][RK0][main]: LookupSession i64_input_key: False
[HCTR][07:25:44.407][INFO][RK0][main]: Creating lookup session for 8_table on device: 0
-------------------- Step 0 --------------------
-------------------- Step 1 --------------------
-------------------- Step 2 --------------------
-------------------- Step 3 --------------------
-------------------- Step 4 --------------------
-------------------- Step 5 --------------------
-------------------- Step 6 --------------------
-------------------- Step 7 --------------------
-------------------- Step 8 --------------------
-------------------- Step 9 --------------------
-------------------- Step 10 --------------------
-------------------- Step 11 --------------------
-------------------- Step 12 --------------------
-------------------- Step 13 --------------------
-------------------- Step 14 --------------------
-------------------- Step 15 --------------------
-------------------- Step 16 --------------------
-------------------- Step 17 --------------------
-------------------- Step 18 --------------------
-------------------- Step 19 --------------------
-------------------- Step 20 --------------------
-------------------- Step 21 --------------------
-------------------- Step 22 --------------------
-------------------- Step 23 --------------------
-------------------- Step 24 --------------------
-------------------- Step 25 --------------------
-------------------- Step 26 --------------------
-------------------- Step 27 --------------------
-------------------- Step 28 --------------------
-------------------- Step 29 --------------------
-------------------- Step 30 --------------------
-------------------- Step 31 --------------------
-------------------- Step 32 --------------------
-------------------- Step 33 --------------------
-------------------- Step 34 --------------------
-------------------- Step 35 --------------------
-------------------- Step 36 --------------------
-------------------- Step 37 --------------------
-------------------- Step 38 --------------------
-------------------- Step 39 --------------------
-------------------- Step 40 --------------------
-------------------- Step 41 --------------------
-------------------- Step 42 --------------------
-------------------- Step 43 --------------------
-------------------- Step 44 --------------------
-------------------- Step 45 --------------------
-------------------- Step 46 --------------------
-------------------- Step 47 --------------------
-------------------- Step 48 --------------------
-------------------- Step 49 --------------------
-------------------- Step 50 --------------------
-------------------- Step 51 --------------------
-------------------- Step 52 --------------------
-------------------- Step 53 --------------------
-------------------- Step 54 --------------------
-------------------- Step 55 --------------------
-------------------- Step 56 --------------------
-------------------- Step 57 --------------------
-------------------- Step 58 --------------------
-------------------- Step 59 --------------------
-------------------- Step 60 --------------------
-------------------- Step 61 --------------------
-------------------- Step 62 --------------------
-------------------- Step 63 --------------------
-------------------- Step 64 --------------------
-------------------- Step 65 --------------------
-------------------- Step 66 --------------------
-------------------- Step 67 --------------------
-------------------- Step 68 --------------------
-------------------- Step 69 --------------------
-------------------- Step 70 --------------------
-------------------- Step 71 --------------------
-------------------- Step 72 --------------------
-------------------- Step 73 --------------------
-------------------- Step 74 --------------------
-------------------- Step 75 --------------------
-------------------- Step 76 --------------------
-------------------- Step 77 --------------------
-------------------- Step 78 --------------------
-------------------- Step 79 --------------------
-------------------- Step 80 --------------------
-------------------- Step 81 --------------------
-------------------- Step 82 --------------------
-------------------- Step 83 --------------------
-------------------- Step 84 --------------------
-------------------- Step 85 --------------------
-------------------- Step 86 --------------------
-------------------- Step 87 --------------------
-------------------- Step 88 --------------------
-------------------- Step 89 --------------------
-------------------- Step 90 --------------------
-------------------- Step 91 --------------------
-------------------- Step 92 --------------------
-------------------- Step 93 --------------------
-------------------- Step 94 --------------------
-------------------- Step 95 --------------------
-------------------- Step 96 --------------------
-------------------- Step 97 --------------------
-------------------- Step 98 --------------------
-------------------- Step 99 --------------------
[INFO] Elapsed time for 100 iterations: 0.9442901611328125 seconds