# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
http://developer.download.nvidia.com/notebooks/dlsw-notebooks/merlin_hugectr_hps-hierarchical-parameter-server-demo/nvidia_logo.png

HPS Torch Demo

Overview

Hierarchical Parameter Server (HPS) is a distributed recommendation inference framework that combines a high-performance GPU embedding cache with a hierarchical storage architecture to realize low-latency retrieval of embeddings for inference tasks. It is provided as a PyTorch plugin and can be easily used in Torch models.

This notebook demonstrates how to apply HPS to the Torch model and use it for inference. For more details about HPS APIs, please refer to HPS APIs. For more details about HPS, please refer to HugeCTR Hierarchical Parameter Server (HPS).

Installation

Get HPS from NGC

The HPS Python module is preinstalled in Merlin PyTorch containers 23.09 and later, for example: nvcr.io/nvidia/merlin/merlin-pytorch:23.09.

You can verify that the required libraries are available by running the following command after launching the container.

$ python3 -c "import hps_torch"

Data Generation

First, we specify the configuration for data generation. We generate 8 embedding tables, each with 10,000 keys and the same embedding vector size of 128. The maximum batch size is 256, and each sample has 10 keys to look up in each table.

import torch
import hps_torch
from typing import List
import os
import numpy as np
import struct
import json
import pytest
import time

# Demo configuration shared by every cell below.
NUM_GPUS = 1  # number of GPUs exposed to CUDA
VOCAB_SIZE = 10000  # keys per embedding table
EMB_VEC_SIZE = 128  # length of every embedding vector
NUM_QUERY_KEY = 10  # keys looked up per sample for each table
MAX_BATCH_SIZE = 256  # samples per lookup batch
NUM_ITERS = 100  # number of timed lookup iterations
NUM_TABLES = 8  # number of embedding tables
USE_CONTEXT_STREAM = True

# Make GPUs 0 .. NUM_GPUS - 1 visible (e.g. "0" when NUM_GPUS == 1).
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in range(NUM_GPUS))
[INFO] hps_torch is imported
# HPS inference configuration for one model that serves all NUM_TABLES tables.
# The per-table lists that are empty here ("sparse_files",
# "embedding_table_names", ...) are populated by set_up_model_files() once the
# table files have been generated.
hps_config = {
    # int32 keys (the lookup inputs are created with dtype=torch.int32).
    "supportlonglong": False,
    "fuse_embedding_table": True,
    "models": [
        {
            "model": str(NUM_TABLES) + "_table",
            "sparse_files": [],
            "num_of_worker_buffer_in_pool": NUM_TABLES,
            "embedding_table_names": [],
            "embedding_vecsize_per_table": [],
            "maxnum_catfeature_query_per_table_per_sample": [],
            # One default value per table; a single-element list triggers the
            # "default_value_for_each_table.size() is not equal to the number
            # of embedding tables" warning at parse time.
            "default_value_for_each_table": [0.0] * NUM_TABLES,
            "deployed_device_list": [0],
            "max_batch_size": MAX_BATCH_SIZE,
            "cache_refresh_percentage_per_iteration": 1.0,
            "hit_rate_threshold": 1.0,
            "gpucacheper": 1.0,
            "gpucache": True,
            "embedding_cache_type": "static",
            # Driven by the module-level switch instead of a hard-coded True.
            "use_context_stream": USE_CONTEXT_STREAM,
        }
    ],
}

def generate_embedding_tables(
    hugectr_sparse_model, vocab_range, embedding_vec_size, embedding_table
):
    """Write a random embedding table as a ``key`` / ``emb_vector`` file pair.

    Creates the directory ``hugectr_sparse_model`` (if needed) containing:

    * ``key``: the keys ``vocab_range[0] .. vocab_range[1] - 1`` as
      native-byte-order int64 values.
    * ``emb_vector``: one float32 vector of ``embedding_vec_size`` random
      values in [0, 1) per key, in the same order as ``key``.

    Every generated vector is also stored in ``embedding_table[key]`` so the
    caller keeps an in-memory ground-truth copy.

    Args:
        hugectr_sparse_model: Output directory path.
        vocab_range: ``(start, stop)`` pair; keys cover ``range(start, stop)``.
            An empty or reversed range produces empty files.
        embedding_vec_size: Length of each embedding vector.
        embedding_table: Indexable container (e.g. a 2-D NumPy array) that
            receives a copy of each generated vector, indexed by key.
    """
    # os.makedirs neither spawns a shell nor ignores failures, unlike
    # os.system("mkdir -p ..."), and handles paths with spaces or shell
    # metacharacters correctly.
    os.makedirs(hugectr_sparse_model, exist_ok=True)
    start, stop = vocab_range[0], vocab_range[1]
    num_keys = max(stop - start, 0)
    keys = np.arange(start, start + num_keys, dtype=np.int64)
    vectors = np.random.random((num_keys, embedding_vec_size)).astype(np.float32)
    with open(os.path.join(hugectr_sparse_model, "key"), "wb") as key_file, open(
        os.path.join(hugectr_sparse_model, "emb_vector"), "wb"
    ) as vec_file:
        # Bulk writes produce the same bytes as struct.pack("q") /
        # struct.pack("<n>f") per key (native byte order), without a
        # Python-level pack and write for every key.
        key_file.write(keys.tobytes())
        vec_file.write(vectors.tobytes())
    for key, vec in zip(range(start, stop), vectors):
        embedding_table[key] = vec


def set_up_model_files():
    """Generate NUM_TABLES embedding tables on disk and write the HPS config.

    For each table ``i``, the keys ``[i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE)``
    are written to ``embeddings/table<i>`` by ``generate_embedding_tables``.
    The table is then registered in ``hps_config["models"][0]``. Finally, the
    whole config is serialised to ``<NUM_TABLES>_table.json``.

    Calling this function again regenerates the tables and replaces (rather
    than duplicates) the per-table entries in ``hps_config``.

    Returns:
        float32 array of shape ``(NUM_TABLES * VOCAB_SIZE, EMB_VEC_SIZE)``
        holding the ground-truth vector for every key.
    """
    # Allocate directly as float32 instead of allocating float64 and copying
    # the whole array with astype().
    embedding_table = np.zeros((NUM_TABLES * VOCAB_SIZE, EMB_VEC_SIZE), dtype=np.float32)
    model_config = hps_config["models"][0]
    sparse_files = []
    table_names = []
    for i in range(NUM_TABLES):
        table_name = f"table{i}"
        model_file_name = f"embeddings/{table_name}"
        generate_embedding_tables(
            model_file_name, [i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE], EMB_VEC_SIZE, embedding_table
        )
        sparse_files.append(model_file_name)
        table_names.append(table_name)
    # Assign fresh lists instead of appending to the existing ones, so that a
    # second call does not register every table twice. Assigning to existing
    # keys keeps the dict's key order, so the JSON layout is unchanged.
    model_config["sparse_files"] = sparse_files
    model_config["embedding_table_names"] = table_names
    model_config["embedding_vecsize_per_table"] = [EMB_VEC_SIZE] * NUM_TABLES
    model_config["maxnum_catfeature_query_per_table_per_sample"] = [NUM_QUERY_KEY] * NUM_TABLES
    with open(f"{NUM_TABLES}_table.json", "w") as outfile:
        json.dump(hps_config, outfile, indent=4)
    return embedding_table
# Generate the on-disk tables and the HPS JSON config; keep the returned
# in-memory copy as the ground truth for the lookup check further down.
embedding_table = set_up_model_files()
!du -lh embeddings
5.0M	embeddings/table0
5.0M	embeddings/table1
5.0M	embeddings/table2
5.0M	embeddings/table3
5.0M	embeddings/table4
5.0M	embeddings/table5
5.0M	embeddings/table6
5.0M	embeddings/table7
40M	embeddings

Lookup with Table Fusion

HPS supports fusing tables of the same embedding vector size via CPU multithreading. This can be achieved with torch.jit.fork and torch.jit.wait when the HPS plugin for Torch is employed. For more details, please refer to HPS Configuration.

We conduct embedding lookup with table fusion and compare the results with the ground truth.

class Model(torch.nn.Module):
    """Look up one HPS embedding table per layer and concatenate the results.

    Layer ``t`` is an ``hps_torch.LookupLayer`` bound to table ``t`` of
    ``model_name`` in the HPS config file ``ps_config_file``. Its vector size
    is ``emb_vec_size[t]``. ``forward`` launches every lookup with
    ``torch.jit.fork`` and collects the results with ``torch.jit.wait``.
    """

    def __init__(self, ps_config_file: str, model_name: str, emb_vec_size: List[int]):
        super().__init__()
        lookup_layers = []
        for table_id, vec_size in enumerate(emb_vec_size):
            lookup_layers.append(
                hps_torch.LookupLayer(ps_config_file, model_name, table_id, vec_size)
            )
        self.layers = torch.nn.ModuleList(lookup_layers)

    def forward(self, keys_list: torch.Tensor):
        """Return the dim-0 concatenation of every table's lookup result.

        ``keys_list[t]`` supplies the keys for table ``t``.
        """
        pending = torch.jit.annotate(List[torch.jit.Future[torch.Tensor]], [])
        for table_id, lookup in enumerate(self.layers):
            pending.append(torch.jit.fork(lookup, keys_list[table_id]))
        # Collect results in submission order so the output follows table order.
        gathered: List[torch.Tensor] = []
        for handle in pending:
            gathered.append(torch.jit.wait(handle))
        return torch.cat(gathered)
# Script the model so the torch.jit.fork calls in forward are compiled by
# TorchScript; the config file and model name match set_up_model_files().
model = torch.jit.script(
    Model(
        ps_config_file=f"{NUM_TABLES}_table.json",
        model_name=f"{NUM_TABLES}_table",
        emb_vec_size=[EMB_VEC_SIZE] * NUM_TABLES,
    )
)
# NUM_ITERS + 1 batches of random keys: batch 0 is used for the warm-up call and
# the rest are timed. Each batch stacks one (MAX_BATCH_SIZE, NUM_QUERY_KEY)
# int32 tensor per table, where table t draws keys from
# [t * VOCAB_SIZE, (t + 1) * VOCAB_SIZE). Every batch therefore has shape
# (NUM_TABLES, MAX_BATCH_SIZE, NUM_QUERY_KEY) and lives on the GPU.
inputs_seq = [
    torch.stack(
        [
            torch.randint(
                table * VOCAB_SIZE,
                (table + 1) * VOCAB_SIZE,
                (MAX_BATCH_SIZE, NUM_QUERY_KEY),
                dtype=torch.int32,
            ).cuda()
            for table in range(NUM_TABLES)
        ]
    )
    for _ in range(NUM_ITERS + 1)
]

# Untimed warm-up call on batch 0, so that first-call overhead (e.g.
# TorchScript optimisation) is excluded from the measurement. Its output is
# not checked.
preds = model(inputs_seq[0])
preds_seq = []
# CUDA work is asynchronous. Synchronize before starting and before stopping
# the clock so the interval covers the GPU lookups, not just kernel launches.
# perf_counter is the monotonic clock intended for measuring intervals.
torch.cuda.synchronize()
start = time.perf_counter()
for keys in inputs_seq[1:]:
    preds_seq.append(model(keys))
torch.cuda.synchronize()
end = time.perf_counter()
print(f"[INFO] Elapsed time for {NUM_ITERS} iterations: {end - start} seconds")
# Shape: (NUM_ITERS, ...) stacked results, moved to host memory for checking.
preds_seq = torch.stack(preds_seq).cpu().numpy()

# Ground truth: index the in-memory copy of the tables with the same timed
# batches (inputs_seq[1:]) and concatenate the per-table results along axis 0,
# matching the torch.cat in Model.forward.
preds_seq_gt = np.array(
    [np.concatenate(embedding_table[keys.cpu().numpy()]) for keys in inputs_seq[1:]]
)

diff = preds_seq - preds_seq_gt
mse = np.square(diff).mean()
assert mse <= 1e-6
print(f"HPS Torch Plugin embedding lookup with table fusion, MSE: {mse} ")
=====================================================HPS Parse====================================================
[HCTR][05:25:11.836][INFO][RK0][main]: Table fusion is enabled for HPS. Please ensure that there is no key value overlap in different tables and the embedding lookup layer has no dependency in the model graph. For more information, see https://nvidia-merlin.github.io/HugeCTR/main/hierarchical_parameter_server/hps_database_backend.html#configuration
[HCTR][05:25:11.836][INFO][RK0][main]: fuse_embedding_table is not specified using default: 1
[HCTR][05:25:11.839][INFO][RK0][main]: dense_file is not specified using default: 
[HCTR][05:25:11.839][WARNING][RK0][main]: default_value_for_each_table.size() is not equal to the number of embedding tables
[HCTR][05:25:11.839][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
[HCTR][05:25:11.839][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][05:25:11.839][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][05:25:11.839][INFO][RK0][main]: refresh_interval is not specified using default: 0
[HCTR][05:25:11.839][INFO][RK0][main]: fuse_embedding_table is not specified using default: 1
[HCTR][05:25:11.839][INFO][RK0][main]: use_static_table is not specified using default: 0
[HCTR][05:25:11.839][INFO][RK0][main]: use_hctr_cache_implementation is not specified using default: 1
[HCTR][05:25:11.839][INFO][RK0][main]: thread_pool_size is not specified using default: 16
[HCTR][05:25:11.839][INFO][RK0][main]: init_ec is not specified using default: 1
[HCTR][05:25:11.839][INFO][RK0][main]: HPS plugin uses context stream for model 8_table: True
====================================================HPS Create====================================================
[HCTR][05:25:11.840][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][05:25:11.840][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][05:25:11.840][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][05:25:11.840][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][05:25:11.840][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][05:25:11.880][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][05:25:11.880][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][05:25:11.880][INFO][RK0][main]: Model name: 8_table
[HCTR][05:25:11.880][INFO][RK0][main]: Max batch size: 256
[HCTR][05:25:11.880][INFO][RK0][main]: Fuse embedding tables: True
[HCTR][05:25:11.880][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][05:25:11.880][INFO][RK0][main]: Embedding cache type: static
[HCTR][05:25:11.880][INFO][RK0][main]: Use I64 input key: False
[HCTR][05:25:11.880][INFO][RK0][main]: The size of worker memory pool: 8
[HCTR][05:25:11.880][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][05:25:11.880][INFO][RK0][main]: The refresh percentage : 1.000000
[HCTR][05:25:11.936][INFO][RK0][main]: Initialize the embedding cache by by inserting the same size model file with embedding cache from beginning
[HCTR][05:25:11.936][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][05:25:11.936][INFO][RK0][main]: EC initialization on device 0 for hps_et.8_table.fused_embedding0
[HCTR][05:25:11.936][INFO][RK0][main]: To achieve the best performance, when using static table, the pointers of keys and vectors in HPS lookup should preferably be aligned to at least 16 Bytes.
[HCTR][05:25:11.975][INFO][RK0][main]: Initialize the embedding table 0 for iteration 0 with number of 1000 keys.
[HCTR][05:25:12.018][INFO][RK0][main]: Initialize the embedding table 0 for iteration 1 with number of 1000 keys.
[HCTR][05:25:12.041][INFO][RK0][main]: Initialize the embedding table 0 for iteration 2 with number of 1000 keys.
[HCTR][05:25:12.059][INFO][RK0][main]: Initialize the embedding table 0 for iteration 3 with number of 1000 keys.
[HCTR][05:25:12.070][INFO][RK0][main]: Initialize the embedding table 0 for iteration 4 with number of 1000 keys.
[HCTR][05:25:12.088][INFO][RK0][main]: Initialize the embedding table 0 for iteration 5 with number of 1000 keys.
[HCTR][05:25:12.104][INFO][RK0][main]: Initialize the embedding table 0 for iteration 6 with number of 1000 keys.
[HCTR][05:25:12.113][INFO][RK0][main]: Initialize the embedding table 0 for iteration 7 with number of 1000 keys.
[HCTR][05:25:12.123][INFO][RK0][main]: Initialize the embedding table 0 for iteration 8 with number of 1000 keys.
[HCTR][05:25:12.137][INFO][RK0][main]: Initialize the embedding table 0 for iteration 9 with number of 1000 keys.
[HCTR][05:25:12.167][INFO][RK0][main]: Initialize the embedding table 0 for iteration 0 with number of 1000 keys.
[HCTR][05:25:12.196][INFO][RK0][main]: Initialize the embedding table 0 for iteration 1 with number of 1000 keys.
[HCTR][05:25:12.210][INFO][RK0][main]: Initialize the embedding table 0 for iteration 2 with number of 1000 keys.
[HCTR][05:25:12.223][INFO][RK0][main]: Initialize the embedding table 0 for iteration 3 with number of 1000 keys.
[HCTR][05:25:12.239][INFO][RK0][main]: Initialize the embedding table 0 for iteration 4 with number of 1000 keys.
[HCTR][05:25:12.252][INFO][RK0][main]: Initialize the embedding table 0 for iteration 5 with number of 1000 keys.
[HCTR][05:25:12.284][INFO][RK0][main]: Initialize the embedding table 0 for iteration 6 with number of 1000 keys.
[HCTR][05:25:12.296][INFO][RK0][main]: Initialize the embedding table 0 for iteration 7 with number of 1000 keys.
[HCTR][05:25:12.307][INFO][RK0][main]: Initialize the embedding table 0 for iteration 8 with number of 1000 keys.
[HCTR][05:25:12.319][INFO][RK0][main]: Initialize the embedding table 0 for iteration 9 with number of 1000 keys.
[HCTR][05:25:12.336][INFO][RK0][main]: Initialize the embedding table 0 for iteration 0 with number of 1000 keys.
[HCTR][05:25:12.360][INFO][RK0][main]: Initialize the embedding table 0 for iteration 1 with number of 1000 keys.
[HCTR][05:25:12.368][INFO][RK0][main]: Initialize the embedding table 0 for iteration 2 with number of 1000 keys.
[HCTR][05:25:12.380][INFO][RK0][main]: Initialize the embedding table 0 for iteration 3 with number of 1000 keys.
[HCTR][05:25:12.390][INFO][RK0][main]: Initialize the embedding table 0 for iteration 4 with number of 1000 keys.
[HCTR][05:25:12.409][INFO][RK0][main]: Initialize the embedding table 0 for iteration 5 with number of 1000 keys.
[HCTR][05:25:12.437][INFO][RK0][main]: Initialize the embedding table 0 for iteration 6 with number of 1000 keys.
[HCTR][05:25:12.446][INFO][RK0][main]: Initialize the embedding table 0 for iteration 7 with number of 1000 keys.
[HCTR][05:25:12.453][INFO][RK0][main]: Initialize the embedding table 0 for iteration 8 with number of 1000 keys.
[HCTR][05:25:12.475][INFO][RK0][main]: Initialize the embedding table 0 for iteration 9 with number of 1000 keys.
[HCTR][05:25:12.515][INFO][RK0][main]: Initialize the embedding table 0 for iteration 0 with number of 1000 keys.
[HCTR][05:25:12.535][INFO][RK0][main]: Initialize the embedding table 0 for iteration 1 with number of 1000 keys.
[HCTR][05:25:12.551][INFO][RK0][main]: Initialize the embedding table 0 for iteration 2 with number of 1000 keys.
[HCTR][05:25:12.560][INFO][RK0][main]: Initialize the embedding table 0 for iteration 3 with number of 1000 keys.
[HCTR][05:25:12.580][INFO][RK0][main]: Initialize the embedding table 0 for iteration 4 with number of 1000 keys.
[HCTR][05:25:12.597][INFO][RK0][main]: Initialize the embedding table 0 for iteration 5 with number of 1000 keys.
[HCTR][05:25:12.606][INFO][RK0][main]: Initialize the embedding table 0 for iteration 6 with number of 1000 keys.
[HCTR][05:25:12.615][INFO][RK0][main]: Initialize the embedding table 0 for iteration 7 with number of 1000 keys.
[HCTR][05:25:12.624][INFO][RK0][main]: Initialize the embedding table 0 for iteration 8 with number of 1000 keys.
[HCTR][05:25:12.632][INFO][RK0][main]: Initialize the embedding table 0 for iteration 9 with number of 1000 keys.
[HCTR][05:25:12.668][INFO][RK0][main]: Initialize the embedding table 0 for iteration 0 with number of 1000 keys.
[HCTR][05:25:12.678][INFO][RK0][main]: Initialize the embedding table 0 for iteration 1 with number of 1000 keys.
[HCTR][05:25:12.695][INFO][RK0][main]: Initialize the embedding table 0 for iteration 2 with number of 1000 keys.
[HCTR][05:25:12.712][INFO][RK0][main]: Initialize the embedding table 0 for iteration 3 with number of 1000 keys.
[HCTR][05:25:12.725][INFO][RK0][main]: Initialize the embedding table 0 for iteration 4 with number of 1000 keys.
[HCTR][05:25:12.740][INFO][RK0][main]: Initialize the embedding table 0 for iteration 5 with number of 1000 keys.
[HCTR][05:25:12.756][INFO][RK0][main]: Initialize the embedding table 0 for iteration 6 with number of 1000 keys.
[HCTR][05:25:12.768][INFO][RK0][main]: Initialize the embedding table 0 for iteration 7 with number of 1000 keys.
[HCTR][05:25:12.783][INFO][RK0][main]: Initialize the embedding table 0 for iteration 8 with number of 1000 keys.
[HCTR][05:25:12.794][INFO][RK0][main]: Initialize the embedding table 0 for iteration 9 with number of 1000 keys.
[HCTR][05:25:12.821][INFO][RK0][main]: Initialize the embedding table 0 for iteration 0 with number of 1000 keys.
[HCTR][05:25:12.844][INFO][RK0][main]: Initialize the embedding table 0 for iteration 1 with number of 1000 keys.
[HCTR][05:25:12.861][INFO][RK0][main]: Initialize the embedding table 0 for iteration 2 with number of 1000 keys.
[HCTR][05:25:12.880][INFO][RK0][main]: Initialize the embedding table 0 for iteration 3 with number of 1000 keys.
[HCTR][05:25:12.890][INFO][RK0][main]: Initialize the embedding table 0 for iteration 4 with number of 1000 keys.
[HCTR][05:25:12.900][INFO][RK0][main]: Initialize the embedding table 0 for iteration 5 with number of 1000 keys.
[HCTR][05:25:12.920][INFO][RK0][main]: Initialize the embedding table 0 for iteration 6 with number of 1000 keys.
[HCTR][05:25:12.929][INFO][RK0][main]: Initialize the embedding table 0 for iteration 7 with number of 1000 keys.
[HCTR][05:25:12.938][INFO][RK0][main]: Initialize the embedding table 0 for iteration 8 with number of 1000 keys.
[HCTR][05:25:12.957][INFO][RK0][main]: Initialize the embedding table 0 for iteration 9 with number of 1000 keys.
[HCTR][05:25:12.979][INFO][RK0][main]: Initialize the embedding table 0 for iteration 0 with number of 1000 keys.
[HCTR][05:25:13.006][INFO][RK0][main]: Initialize the embedding table 0 for iteration 1 with number of 1000 keys.
[HCTR][05:25:13.016][INFO][RK0][main]: Initialize the embedding table 0 for iteration 2 with number of 1000 keys.
[HCTR][05:25:13.027][INFO][RK0][main]: Initialize the embedding table 0 for iteration 3 with number of 1000 keys.
[HCTR][05:25:13.037][INFO][RK0][main]: Initialize the embedding table 0 for iteration 4 with number of 1000 keys.
[HCTR][05:25:13.046][INFO][RK0][main]: Initialize the embedding table 0 for iteration 5 with number of 1000 keys.
[HCTR][05:25:13.056][INFO][RK0][main]: Initialize the embedding table 0 for iteration 6 with number of 1000 keys.
[HCTR][05:25:13.064][INFO][RK0][main]: Initialize the embedding table 0 for iteration 7 with number of 1000 keys.
[HCTR][05:25:13.085][INFO][RK0][main]: Initialize the embedding table 0 for iteration 8 with number of 1000 keys.
[HCTR][05:25:13.095][INFO][RK0][main]: Initialize the embedding table 0 for iteration 9 with number of 1000 keys.
[HCTR][05:25:13.110][INFO][RK0][main]: Initialize the embedding table 0 for iteration 0 with number of 1000 keys.
[HCTR][05:25:13.125][INFO][RK0][main]: Initialize the embedding table 0 for iteration 1 with number of 1000 keys.
[HCTR][05:25:13.136][INFO][RK0][main]: Initialize the embedding table 0 for iteration 2 with number of 1000 keys.
[HCTR][05:25:13.163][INFO][RK0][main]: Initialize the embedding table 0 for iteration 3 with number of 1000 keys.
[HCTR][05:25:13.173][INFO][RK0][main]: Initialize the embedding table 0 for iteration 4 with number of 1000 keys.
[HCTR][05:25:13.183][INFO][RK0][main]: Initialize the embedding table 0 for iteration 5 with number of 1000 keys.
[HCTR][05:25:13.194][INFO][RK0][main]: Initialize the embedding table 0 for iteration 6 with number of 1000 keys.
[HCTR][05:25:13.212][INFO][RK0][main]: Initialize the embedding table 0 for iteration 7 with number of 1000 keys.
[HCTR][05:25:13.231][INFO][RK0][main]: Initialize the embedding table 0 for iteration 8 with number of 1000 keys.
[HCTR][05:25:13.249][INFO][RK0][main]: Initialize the embedding table 0 for iteration 9 with number of 1000 keys.
[HCTR][05:25:13.250][INFO][RK0][main]: LookupSession i64_input_key: False
[HCTR][05:25:13.250][INFO][RK0][main]: Creating lookup session for 8_table on device: 0
[INFO] Elapsed time for 100 iterations: 0.10996460914611816 seconds
HPS Torch Plugin embedding lookup with table fusion, MSE: 0.0