Program Listing for File merlin_hashtable.cuh


/*
 * Copyright (c) 2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include <atomic>
#include <cstdint>
#include <limits>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <type_traits>
#include "merlin/allocator.cuh"
#include "merlin/array_kernels.cuh"
#include "merlin/core_kernels.cuh"
#include "merlin/flexible_buffer.cuh"
#include "merlin/group_lock.cuh"
#include "merlin/memory_pool.cuh"
#include "merlin/types.cuh"
#include "merlin/utils.cuh"

namespace nv {
namespace merlin {

struct EvictStrategy {
  enum EvictStrategyEnum {
    kLru = 0,
    kLfu = 1,
    kEpochLru = 2,
    kEpochLfu = 3,
    kCustomized = 4,
  };
};

struct HashTableOptions {
  size_t init_capacity = 0;
  size_t max_capacity = 0;
  size_t max_hbm_for_vectors = 0;
  size_t max_bucket_size = 128;
  size_t dim = 64;
  float max_load_factor = 0.5f;
  int block_size = 128;
  int io_block_size = 1024;
  int device_id = -1;
  bool io_by_cpu = false;
  bool use_constant_memory = false;
  /*
   * reserved_key_start_bit = 0 is the default behavior: HKV reserves
   * `0xFFFFFFFFFFFFFFFD`, `0xFFFFFFFFFFFFFFFE`, and `0xFFFFFFFFFFFFFFFF` for
   * internal use. If the default reserved keys conflict with your keys, set
   * reserved_key_start_bit to a number between 1 and 62.
   * reserved_key_start_bit = 1 means the reserved keys are indexed by the
   * insignificant bits 1 and 2, bit 0 is 0, and all the other bits are 1, so
   * the new reserved keys are `0xFFFFFFFFFFFFFFFE`, `0xFFFFFFFFFFFFFFFC`,
   * `0xFFFFFFFFFFFFFFF8`, and `0xFFFFFFFFFFFFFFFA`. The console log prints
   * the reserved keys during table initialization. (See the example after
   * this struct.)
   */
  int reserved_key_start_bit = 0;
  size_t num_of_buckets_per_alloc = 1;
  MemoryPoolOptions device_memory_pool;
  MemoryPoolOptions host_memory_pool;
};
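
/*
 * Example (editorial sketch, not part of the original header): one plausible
 * way to fill `HashTableOptions` for a table of 64-dim float vectors. The
 * capacities, HBM budget, and other values below are illustrative only, not
 * tuned recommendations.
 */
inline HashTableOptions make_example_options() {
  HashTableOptions opts;
  opts.init_capacity = 64 * 1024 * 1024;    // initial number of key slots
  opts.max_capacity = 128 * 1024 * 1024;    // upper bound for table growth
  opts.max_hbm_for_vectors = 16ull << 30;   // 16 GB of HBM kept for vectors
  opts.dim = 64;                            // vector dimension
  opts.max_load_factor = 0.5f;              // grow once this factor is hit
  opts.reserved_key_start_bit = 0;          // keep the default reserved keys
  return opts;
}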

template <class K, class S>
using EraseIfPredict = bool (*)(
    const K& key,
    S& score,
    const K& pattern,
    const S& threshold
);
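
/*
 * Example (editorial sketch, not part of the original header): a device
 * function whose signature matches `EraseIfPredict<K, S>`. It returns true
 * for entries whose key matches `pattern` under a bitwise test and whose
 * score is below `threshold`; the exact semantics are chosen by the caller.
 */
template <class K, class S>
__device__ __forceinline__ bool erase_if_score_below(const K& key, S& score,
                                                      const K& pattern,
                                                      const S& threshold) {
  return ((key & pattern) == pattern) && (score < threshold);
}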

#if THRUST_VERSION >= 101600
static constexpr auto& thrust_par = thrust::cuda::par_nosync;
#else
static constexpr auto& thrust_par = thrust::cuda::par;
#endif
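
/*
 * Example (editorial sketch, not part of the original header): the policy
 * chosen above is combined with a CUDA stream (and, later in this file, a
 * custom allocator) before being passed to a Thrust algorithm. The function
 * name and buffers here are illustrative.
 */
inline void example_sort_offsets_by_address(uintptr_t* d_addrs,
                                            int* d_offsets, size_t n,
                                            cudaStream_t stream) {
  thrust::device_ptr<uintptr_t> addrs(d_addrs);
  thrust::device_ptr<int> offsets(d_offsets);
  // Sort the offsets by their destination addresses on the given stream.
  thrust::sort_by_key(thrust_par.on(stream), addrs, addrs + n, offsets,
                      thrust::less<uintptr_t>());
}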

template <typename K, typename V, typename S = uint64_t>
class HashTableBase {
 public:
  using size_type = size_t;
  using key_type = K;
  using value_type = V;
  using score_type = S;
  using allocator_type = BaseAllocator;

 public:
  virtual ~HashTableBase() {}

  virtual void init(const HashTableOptions& options,
                    allocator_type* allocator = nullptr) = 0;

  virtual void insert_or_assign(const size_type n,
                                const key_type* keys,                // (n)
                                const value_type* values,            // (n, DIM)
                                const score_type* scores = nullptr,  // (n)
                                cudaStream_t stream = 0, bool unique_key = true,
                                bool ignore_evict_strategy = false) = 0;

  virtual void insert_and_evict(const size_type n,
                                const key_type* keys,          // (n)
                                const value_type* values,      // (n, DIM)
                                const score_type* scores,      // (n)
                                key_type* evicted_keys,        // (n)
                                value_type* evicted_values,    // (n, DIM)
                                score_type* evicted_scores,    // (n)
                                size_type* d_evicted_counter,  // (1)
                                cudaStream_t stream = 0, bool unique_key = true,
                                bool ignore_evict_strategy = false) = 0;

  virtual size_type insert_and_evict(const size_type n,
                                     const key_type* keys,        // (n)
                                     const value_type* values,    // (n, DIM)
                                     const score_type* scores,    // (n)
                                     key_type* evicted_keys,      // (n)
                                     value_type* evicted_values,  // (n, DIM)
                                     score_type* evicted_scores,  // (n)
                                     cudaStream_t stream = 0,
                                     bool unique_key = true,
                                     bool ignore_evict_strategy = false) = 0;

  virtual void accum_or_assign(const size_type n,
                               const key_type* keys,                // (n)
                               const value_type* value_or_deltas,   // (n, DIM)
                               const bool* accum_or_assigns,        // (n)
                               const score_type* scores = nullptr,  // (n)
                               cudaStream_t stream = 0,
                               bool ignore_evict_strategy = false) = 0;

  virtual void find_or_insert(const size_type n, const key_type* keys,  // (n)
                              value_type* values,            // (n * DIM)
                              score_type* scores = nullptr,  // (n)
                              cudaStream_t stream = 0, bool unique_key = true,
                              bool ignore_evict_strategy = false) = 0;

  virtual void find_or_insert(const size_type n, const key_type* keys,  // (n)
                              value_type** values,                      // (n)
                              bool* founds,                             // (n)
                              score_type* scores = nullptr,             // (n)
                              cudaStream_t stream = 0, bool unique_key = true,
                              bool ignore_evict_strategy = false) = 0;

  virtual void assign(const size_type n,
                      const key_type* keys,                // (n)
                      const value_type* values,            // (n, DIM)
                      const score_type* scores = nullptr,  // (n)
                      cudaStream_t stream = 0, bool unique_key = true) = 0;

  virtual void assign_scores(const size_type n,
                             const key_type* keys,                // (n)
                             const score_type* scores = nullptr,  // (n)
                             cudaStream_t stream = 0,
                             bool unique_key = true) = 0;

  virtual void assign(const size_type n,
                      const key_type* keys,                // (n)
                      const score_type* scores = nullptr,  // (n)
                      cudaStream_t stream = 0, bool unique_key = true) = 0;

  virtual void assign_values(const size_type n,
                             const key_type* keys,      // (n)
                             const value_type* values,  // (n, DIM)
                             cudaStream_t stream = 0,
                             bool unique_key = true) = 0;
  virtual void find(const size_type n, const key_type* keys,  // (n)
                    value_type* values,                       // (n, DIM)
                    bool* founds,                             // (n)
                    score_type* scores = nullptr,             // (n)
                    cudaStream_t stream = 0) const = 0;

  virtual void find(const size_type n, const key_type* keys,  // (n)
                    value_type* values,                       // (n, DIM)
                    key_type* missed_keys,                    // (n)
                    int* missed_indices,                      // (n)
                    int* missed_size,                         // scalar
                    score_type* scores = nullptr,             // (n)
                    cudaStream_t stream = 0) const = 0;

  virtual void find(const size_type n, const key_type* keys,  // (n)
                    value_type** values,                      // (n)
                    bool* founds,                             // (n)
                    score_type* scores = nullptr,             // (n)
                    cudaStream_t stream = 0, bool unique_key = true) const = 0;

  virtual void contains(const size_type n, const key_type* keys,  // (n)
                        bool* founds,                             // (n)
                        cudaStream_t stream = 0) const = 0;

  virtual void erase(const size_type n, const key_type* keys,
                     cudaStream_t stream = 0) = 0;

  virtual void clear(cudaStream_t stream = 0) = 0;

  virtual void export_batch(size_type n, const size_type offset,
                            size_type* d_counter,          // (1)
                            key_type* keys,                // (n)
                            value_type* values,            // (n, DIM)
                            score_type* scores = nullptr,  // (n)
                            cudaStream_t stream = 0) const = 0;

  virtual size_type export_batch(const size_type n, const size_type offset,
                                 key_type* keys,                // (n)
                                 value_type* values,            // (n, DIM)
                                 score_type* scores = nullptr,  // (n)
                                 cudaStream_t stream = 0) const = 0;

  virtual bool empty(cudaStream_t stream = 0) const = 0;

  virtual size_type size(cudaStream_t stream = 0) const = 0;

  virtual size_type capacity() const = 0;

  virtual void reserve(const size_type new_capacity,
                       cudaStream_t stream = 0) = 0;

  virtual float load_factor(cudaStream_t stream = 0) const = 0;

  virtual void set_max_capacity(size_type new_max_capacity) = 0;

  virtual size_type dim() const noexcept = 0;

  virtual size_type max_bucket_size() const noexcept = 0;

  virtual size_type bucket_count() const noexcept = 0;

  virtual size_type save(BaseKVFile<K, V, S>* file,
                         const size_t max_workspace_size = 1L * 1024 * 1024,
                         cudaStream_t stream = 0) const = 0;

  virtual size_type load(BaseKVFile<K, V, S>* file,
                         const size_t max_workspace_size = 1L * 1024 * 1024,
                         cudaStream_t stream = 0) = 0;

  virtual void set_global_epoch(const uint64_t epoch) = 0;
};
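
/*
 * Example (editorial sketch, not part of the original header): driving a
 * table through the abstract `HashTableBase` interface. All buffers are
 * assumed to be device-accessible and sized for `n` keys of dimension
 * `table.dim()`; error handling and synchronization are omitted.
 */
inline void example_upsert_then_find(HashTableBase<uint64_t, float>& table,
                                     size_t n, const uint64_t* d_keys,
                                     const float* d_values, float* d_outputs,
                                     bool* d_founds, cudaStream_t stream) {
  // Insert n key/vector pairs, or overwrite them if they already exist.
  table.insert_or_assign(n, d_keys, d_values, /*scores=*/nullptr, stream);
  // Read them back; d_founds[i] reports whether d_keys[i] is in the table.
  table.find(n, d_keys, d_outputs, d_founds, /*scores=*/nullptr, stream);
}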

template <typename K, typename V, typename S = uint64_t,
          int Strategy = EvictStrategy::kLru, typename ArchTag = Sm80>
class HashTable : public HashTableBase<K, V, S> {
 public:
  using size_type = size_t;
  using key_type = K;
  using value_type = V;
  using score_type = S;
  static constexpr int evict_strategy = Strategy;

  using Pred = EraseIfPredict<key_type, score_type>;
  using allocator_type = BaseAllocator;

 private:
  using TableCore = nv::merlin::Table<key_type, value_type, score_type>;
  static constexpr unsigned int TILE_SIZE = 4;

  using DeviceMemoryPool = MemoryPool<DeviceAllocator<char>>;
  using HostMemoryPool = MemoryPool<HostAllocator<char>>;

 public:
  HashTable() {
    static_assert((std::is_same<key_type, int64_t>::value ||
                   std::is_same<key_type, uint64_t>::value),
                  "The key_type must be int64_t or uint64_t.");

    static_assert(std::is_same<score_type, uint64_t>::value,
                  "The key_type must be uint64_t.");
  }

  ~HashTable() {
    if (initialized_) {
      CUDA_CHECK(cudaDeviceSynchronize());

      initialized_ = false;
      destroy_table<key_type, value_type, score_type>(&table_, allocator_);
      allocator_->free(MemoryType::Device, d_table_);
      dev_mem_pool_.reset();
      host_mem_pool_.reset();

      CUDA_CHECK(cudaDeviceSynchronize());
      if (default_allocator_ && allocator_ != nullptr) {
        delete allocator_;
      }
    }
  }

 private:
  HashTable(const HashTable&) = delete;
  HashTable& operator=(const HashTable&) = delete;
  HashTable(HashTable&&) = delete;
  HashTable& operator=(HashTable&&) = delete;

 public:
  void init(const HashTableOptions& options,
            allocator_type* allocator = nullptr) {
    if (initialized_) {
      return;
    }
    options_ = options;
    MERLIN_CHECK(options.reserved_key_start_bit >= 0 &&
                     options.reserved_key_start_bit <= MAX_RESERVED_KEY_BIT,
                 "options.reserved_key_start_bit should >= 0 and <= 62.");
    CUDA_CHECK(init_reserved_keys(options.reserved_key_start_bit));

    default_allocator_ = (allocator == nullptr);
    allocator_ = (allocator == nullptr) ? (new DefaultAllocator()) : allocator;

    thrust_allocator_.set_allocator(allocator_);

    if (options_.device_id >= 0) {
      CUDA_CHECK(cudaSetDevice(options_.device_id));
    } else {
      CUDA_CHECK(cudaGetDevice(&(options_.device_id)));
    }

    MERLIN_CHECK(ispow2(static_cast<uint32_t>(options_.max_bucket_size)),
                 "Bucket size should be a power of 2");
    MERLIN_CHECK(
        ispow2(static_cast<uint32_t>(options_.num_of_buckets_per_alloc)),
        "The `num_of_buckets_per_alloc` should be a power of 2");
    MERLIN_CHECK(options_.init_capacity >= options_.num_of_buckets_per_alloc *
                                               options_.max_bucket_size,
                 "The `num_of_buckets_per_alloc` must be less than or equal "
                 "to the initial number of required buckets");

    options_.block_size = SAFE_GET_BLOCK_SIZE(options_.block_size);

    MERLIN_CHECK(
        (((options_.max_bucket_size * (sizeof(key_type) + sizeof(score_type))) %
          128) == 0),
        "Storage size of keys and scores in one bucket should be the mutiple "
        "of cache line size");

    // Construct table.
    cudaDeviceProp deviceProp;
    CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, options_.device_id));
    shared_mem_size_ = deviceProp.sharedMemPerBlock;
    create_table<key_type, value_type, score_type>(
        &table_, allocator_, options_.dim, options_.init_capacity,
        options_.max_capacity, options_.max_hbm_for_vectors,
        options_.max_bucket_size, options_.num_of_buckets_per_alloc);
    options_.block_size = SAFE_GET_BLOCK_SIZE(options_.block_size);
    reach_max_capacity_ = (options_.init_capacity * 2 > options_.max_capacity);
    MERLIN_CHECK((!(options_.io_by_cpu && options_.max_hbm_for_vectors != 0)),
                 "[HierarchicalKV] `io_by_cpu` should not be true when "
                 "`max_hbm_for_vectors` is not 0!");
    allocator_->alloc(MemoryType::Device, (void**)&(d_table_),
                      sizeof(TableCore));

    sync_table_configuration();

    // Create memory pools.
    dev_mem_pool_ = std::make_unique<MemoryPool<DeviceAllocator<char>>>(
        options_.device_memory_pool, allocator_);
    host_mem_pool_ = std::make_unique<MemoryPool<HostAllocator<char>>>(
        options_.host_memory_pool, allocator_);

    CUDA_CHECK(cudaDeviceSynchronize());

    initialized_ = true;
    CudaCheckError();
  }

  void insert_or_assign(const size_type n,
                        const key_type* keys,                // (n)
                        const value_type* values,            // (n, DIM)
                        const score_type* scores = nullptr,  // (n)
                        cudaStream_t stream = 0, bool unique_key = true,
                        bool ignore_evict_strategy = false) {
    if (n == 0) {
      return;
    }

    while (!reach_max_capacity_ &&
           fast_load_factor(n, stream) > options_.max_load_factor) {
      reserve(capacity() * 2, stream);
    }

    if (!ignore_evict_strategy) {
      check_evict_strategy(scores);
    }

    insert_unique_lock lock(mutex_, stream);

    if (is_fast_mode()) {
      static thread_local int step_counter = 0;
      static thread_local float load_factor = 0.0;

      if (((step_counter++) % kernel_select_interval_) == 0) {
        load_factor = fast_load_factor(0, stream, false);
      }

      using Selector = KernelSelector_Upsert<key_type, value_type, score_type,
                                             evict_strategy, ArchTag>;
      if (Selector::callable(unique_key,
                             static_cast<uint32_t>(options_.max_bucket_size),
                             static_cast<uint32_t>(options_.dim))) {
        typename Selector::Params kernelParams(
            load_factor, table_->buckets, table_->buckets_size,
            table_->buckets_num,
            static_cast<uint32_t>(options_.max_bucket_size),
            static_cast<uint32_t>(options_.dim), keys, values, scores, n,
            global_epoch_);
        Selector::select_kernel(kernelParams, stream);
      } else {
        using Selector = SelectUpsertKernelWithIO<key_type, value_type,
                                                  score_type, evict_strategy>;
        Selector::execute_kernel(
            load_factor, options_.block_size, options_.max_bucket_size,
            table_->buckets_num, options_.dim, stream, n, d_table_,
            table_->buckets, keys, reinterpret_cast<const value_type*>(values),
            scores, global_epoch_);
      }
    } else {
      const size_type dev_ws_size{
          n * (sizeof(value_type*) + sizeof(int) + sizeof(key_type*))};
      auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
      auto d_dst{dev_ws.get<value_type**>(0)};
      auto keys_ptr{reinterpret_cast<key_type**>(d_dst + n)};
      auto d_src_offset{reinterpret_cast<int*>(keys_ptr + n)};

      CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_size, stream));

      constexpr uint32_t MinBucketCapacityFilter =
          sizeof(VecD_Load) / sizeof(D);

      bool filter_condition =
          unique_key && options_.max_bucket_size >= MinBucketCapacityFilter &&
          !options_.io_by_cpu;

      if (filter_condition) {
        constexpr uint32_t BLOCK_SIZE = 128;

        upsert_kernel_lock_key_hybrid<key_type, value_type, score_type,
                                      BLOCK_SIZE, evict_strategy>
            <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
                table_->buckets, table_->buckets_size, table_->buckets_num,
                options_.max_bucket_size, options_.dim, keys, d_dst, scores,
                keys_ptr, d_src_offset, n, global_epoch_);

      } else {
        const size_t block_size = options_.block_size;
        const size_t N = n * TILE_SIZE;
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        upsert_kernel<key_type, value_type, score_type, evict_strategy,
                      TILE_SIZE><<<grid_size, block_size, 0, stream>>>(
            d_table_, table_->buckets, options_.max_bucket_size,
            table_->buckets_num, options_.dim, keys, d_dst, scores,
            d_src_offset, global_epoch_, N);
      }

      {
        thrust::device_ptr<uintptr_t> d_dst_ptr(
            reinterpret_cast<uintptr_t*>(d_dst));
        thrust::device_ptr<int> d_src_offset_ptr(d_src_offset);

        thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream), d_dst_ptr,
                            d_dst_ptr + n, d_src_offset_ptr,
                            thrust::less<uintptr_t>());
      }

      if (filter_condition) {
        const size_t block_size = options_.io_block_size;
        const size_t N = n * dim();
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        write_kernel_unlock_key<key_type, value_type, score_type>
            <<<grid_size, block_size, 0, stream>>>(values, d_dst, d_src_offset,
                                                   dim(), keys, keys_ptr, N);

      } else if (options_.io_by_cpu) {
        const size_type host_ws_size{dev_ws_size +
                                     n * sizeof(value_type) * dim()};
        auto host_ws{host_mem_pool_->get_workspace<1>(host_ws_size, stream)};
        auto h_dst{host_ws.get<value_type**>(0)};
        auto h_src_offset{reinterpret_cast<int*>(h_dst + n)};
        auto h_values{reinterpret_cast<value_type*>(h_src_offset + n)};

        CUDA_CHECK(cudaMemcpyAsync(h_dst, d_dst, dev_ws_size,
                                   cudaMemcpyDeviceToHost, stream));
        CUDA_CHECK(cudaMemcpyAsync(h_values, values, host_ws_size - dev_ws_size,
                                   cudaMemcpyDeviceToHost, stream));
        CUDA_CHECK(cudaStreamSynchronize(stream));

        write_by_cpu<value_type>(h_dst, h_values, h_src_offset, dim(), n);
      } else {
        const size_t block_size = options_.io_block_size;
        const size_t N = n * dim();
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        write_kernel<key_type, value_type, score_type>
            <<<grid_size, block_size, 0, stream>>>(values, d_dst, d_src_offset,
                                                   dim(), N);
      }
    }

    CudaCheckError();
  }

  void insert_and_evict(const size_type n,
                        const key_type* keys,          // (n)
                        const value_type* values,      // (n, DIM)
                        const score_type* scores,      // (n)
                        key_type* evicted_keys,        // (n)
                        value_type* evicted_values,    // (n, DIM)
                        score_type* evicted_scores,    // (n)
                        size_type* d_evicted_counter,  // (1)
                        cudaStream_t stream = 0, bool unique_key = true,
                        bool ignore_evict_strategy = false) {
    if (n == 0) {
      return;
    }

    while (!reach_max_capacity_ &&
           fast_load_factor(n, stream) > options_.max_load_factor) {
      reserve(capacity() * 2, stream);
    }

    if (!ignore_evict_strategy) {
      check_evict_strategy(scores);
    }

    insert_unique_lock lock(mutex_, stream);

    // TODO: Currently, eviction is only needed when using HashTable as an
    // HBM cache.
    if (!is_fast_mode()) {
      throw std::runtime_error(
          "insert_and_evict is only allowed in pure HBM mode.");
    }

    static thread_local int step_counter = 0;
    static thread_local float load_factor = 0.0;

    if (((step_counter++) % kernel_select_interval_) == 0) {
      load_factor = fast_load_factor(0, stream, false);
    }

    using Selector =
        KernelSelector_UpsertAndEvict<key_type, value_type, score_type,
                                      evict_strategy, ArchTag>;
    if (Selector::callable(unique_key,
                           static_cast<uint32_t>(options_.max_bucket_size),
                           static_cast<uint32_t>(options_.dim))) {
      typename Selector::Params kernelParams(
          load_factor, table_->buckets, table_->buckets_size,
          table_->buckets_num, static_cast<uint32_t>(options_.max_bucket_size),
          static_cast<uint32_t>(options_.dim), keys, values, scores,
          evicted_keys, evicted_values, evicted_scores, n, d_evicted_counter,
          global_epoch_);
      Selector::select_kernel(kernelParams, stream);
    } else {
      // Always use the max tile size to minimize data dependencies.
      const int TILE_SIZE = 32;
      size_t n_offsets = (n + TILE_SIZE - 1) / TILE_SIZE;
      const size_type dev_ws_size =
          n_offsets * sizeof(int64_t) + n * sizeof(bool) + sizeof(size_type);

      auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
      auto d_offsets{dev_ws.get<int64_t*>(0)};
      auto d_masks = reinterpret_cast<bool*>(d_offsets + n_offsets);

      CUDA_CHECK(
          cudaMemsetAsync(d_offsets, 0, n_offsets * sizeof(int64_t), stream));
      CUDA_CHECK(cudaMemsetAsync(d_masks, 0, n * sizeof(bool), stream));

      size_type block_size = options_.block_size;
      size_type grid_size = SAFE_GET_GRID_SIZE(n, block_size);
      CUDA_CHECK(memset64Async(evicted_keys, EMPTY_KEY_CPU, n, stream));
      using Selector =
          SelectUpsertAndEvictKernelWithIO<key_type, value_type, score_type,
                                           evict_strategy>;

      Selector::execute_kernel(
          load_factor, options_.block_size, options_.max_bucket_size,
          table_->buckets_num, options_.dim, stream, n, d_table_,
          table_->buckets, keys, values, scores, evicted_keys, evicted_values,
          evicted_scores, global_epoch_);

      keys_not_empty<K>
          <<<grid_size, block_size, 0, stream>>>(evicted_keys, d_masks, n);
      gpu_boolean_mask<K, V, S, int64_t, TILE_SIZE>(
          grid_size, block_size, d_masks, n, d_evicted_counter, d_offsets,
          evicted_keys, evicted_values, evicted_scores, dim(), stream);
    }
    return;
  }

  size_type insert_and_evict(const size_type n,
                             const key_type* keys,        // (n)
                             const value_type* values,    // (n, DIM)
                             const score_type* scores,    // (n)
                             key_type* evicted_keys,      // (n)
                             value_type* evicted_values,  // (n, DIM)
                             score_type* evicted_scores,  // (n)
                             cudaStream_t stream = 0, bool unique_key = true,
                             bool ignore_evict_strategy = false) {
    if (n == 0) {
      return 0;
    }
    auto dev_ws{dev_mem_pool_->get_workspace<1>(sizeof(size_type), stream)};
    size_type* d_evicted_counter{dev_ws.get<size_type*>(0)};

    CUDA_CHECK(
        cudaMemsetAsync(d_evicted_counter, 0, sizeof(size_type), stream));
    insert_and_evict(n, keys, values, scores, evicted_keys, evicted_values,
                     evicted_scores, d_evicted_counter, stream, unique_key,
                     ignore_evict_strategy);

    size_type h_evicted_counter = 0;
    CUDA_CHECK(cudaMemcpyAsync(&h_evicted_counter, d_evicted_counter,
                               sizeof(size_type), cudaMemcpyDeviceToHost,
                               stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));
    CudaCheckError();
    return h_evicted_counter;
  }

  void accum_or_assign(const size_type n,
                       const key_type* keys,                // (n)
                       const value_type* value_or_deltas,   // (n, DIM)
                       const bool* accum_or_assigns,        // (n)
                       const score_type* scores = nullptr,  // (n)
                       cudaStream_t stream = 0,
                       bool ignore_evict_strategy = false) {
    if (n == 0) {
      return;
    }

    while (!reach_max_capacity_ &&
           fast_load_factor(n, stream) > options_.max_load_factor) {
      reserve(capacity() * 2, stream);
    }

    if (!ignore_evict_strategy) {
      check_evict_strategy(scores);
    }

    insert_unique_lock lock(mutex_, stream);

    if (is_fast_mode()) {
      using Selector =
          SelectAccumOrAssignKernelWithIO<key_type, value_type, score_type,
                                          evict_strategy>;
      static thread_local int step_counter = 0;
      static thread_local float load_factor = 0.0;

      if (((step_counter++) % kernel_select_interval_) == 0) {
        load_factor = fast_load_factor(0, stream, false);
      }
      Selector::execute_kernel(
          load_factor, options_.block_size, options_.max_bucket_size,
          table_->buckets_num, dim(), stream, n, d_table_, keys,
          value_or_deltas, scores, accum_or_assigns, global_epoch_);

    } else {
      const size_type dev_ws_size{
          n * (sizeof(value_type*) + sizeof(int) + sizeof(bool))};
      auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
      auto dst{dev_ws.get<value_type**>(0)};
      auto src_offset{reinterpret_cast<int*>(dst + n)};
      auto founds{reinterpret_cast<bool*>(src_offset + n)};

      CUDA_CHECK(cudaMemsetAsync(dst, 0, dev_ws_size, stream));

      {
        const size_t block_size = options_.block_size;
        const size_t N = n * TILE_SIZE;
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        accum_or_assign_kernel<key_type, value_type, score_type, evict_strategy,
                               TILE_SIZE><<<grid_size, block_size, 0, stream>>>(
            d_table_, options_.max_bucket_size, table_->buckets_num, dim(),
            keys, dst, scores, accum_or_assigns, src_offset, founds,
            global_epoch_, N);
      }

      {
        thrust::device_ptr<uintptr_t> dst_ptr(
            reinterpret_cast<uintptr_t*>(dst));
        thrust::device_ptr<int> src_offset_ptr(src_offset);

        thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream), dst_ptr,
                            dst_ptr + n, src_offset_ptr,
                            thrust::less<uintptr_t>());
      }

      {
        const size_t block_size = options_.io_block_size;
        const size_t N = n * dim();
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        write_with_accum_kernel<key_type, value_type, score_type>
            <<<grid_size, block_size, 0, stream>>>(value_or_deltas, dst,
                                                   accum_or_assigns, founds,
                                                   src_offset, dim(), N);
      }
    }
    CudaCheckError();
  }

  void find_or_insert(const size_type n, const key_type* keys,  // (n)
                      value_type* values,                       // (n * DIM)
                      score_type* scores = nullptr,             // (n)
                      cudaStream_t stream = 0, bool unique_key = true,
                      bool ignore_evict_strategy = false) {
    if (n == 0) {
      return;
    }

    while (!reach_max_capacity_ &&
           fast_load_factor(n, stream) > options_.max_load_factor) {
      reserve(capacity() * 2, stream);
    }

    if (!ignore_evict_strategy) {
      check_evict_strategy(scores);
    }

    insert_unique_lock lock(mutex_, stream);

    if (is_fast_mode()) {
      static thread_local int step_counter = 0;
      static thread_local float load_factor = 0.0;

      if (((step_counter++) % kernel_select_interval_) == 0) {
        load_factor = fast_load_factor(0, stream, false);
      }

      using Selector =
          KernelSelector_FindOrInsert<key_type, value_type, score_type,
                                      evict_strategy, ArchTag>;
      if (Selector::callable(unique_key,
                             static_cast<uint32_t>(options_.max_bucket_size),
                             static_cast<uint32_t>(options_.dim))) {
        typename Selector::Params kernelParams(
            load_factor, table_->buckets, table_->buckets_size,
            table_->buckets_num,
            static_cast<uint32_t>(options_.max_bucket_size),
            static_cast<uint32_t>(options_.dim), keys, values, scores, n,
            global_epoch_);
        Selector::select_kernel(kernelParams, stream);
      } else {
        using Selector =
            SelectFindOrInsertKernelWithIO<key_type, value_type, score_type,
                                           evict_strategy>;
        Selector::execute_kernel(
            load_factor, options_.block_size, options_.max_bucket_size,
            table_->buckets_num, options_.dim, stream, n, d_table_,
            table_->buckets, keys, values, scores, global_epoch_);
      }
    } else {
      const size_type dev_ws_size{n * (sizeof(value_type*) + sizeof(int) +
                                       sizeof(bool) + sizeof(key_type*))};
      auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
      auto d_table_value_addrs{dev_ws.get<value_type**>(0)};
      auto keys_ptr{reinterpret_cast<key_type**>(d_table_value_addrs + n)};
      auto param_key_index{reinterpret_cast<int*>(keys_ptr + n)};
      auto founds{reinterpret_cast<bool*>(param_key_index + n)};

      CUDA_CHECK(cudaMemsetAsync(d_table_value_addrs, 0, dev_ws_size, stream));

      constexpr uint32_t MinBucketCapacityFilter =
          sizeof(VecD_Load) / sizeof(D);

      bool filter_condition =
          unique_key && options_.max_bucket_size >= MinBucketCapacityFilter &&
          !options_.io_by_cpu;

      if (filter_condition) {
        constexpr uint32_t BLOCK_SIZE = 128;

        find_or_insert_kernel_lock_key_hybrid<key_type, value_type, score_type,
                                              BLOCK_SIZE, evict_strategy>
            <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
                table_->buckets, table_->buckets_size, table_->buckets_num,
                options_.max_bucket_size, options_.dim, keys,
                d_table_value_addrs, scores, keys_ptr, param_key_index, founds,
                n, global_epoch_);

      } else {
        const size_t block_size = options_.block_size;
        const size_t N = n * TILE_SIZE;
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        find_or_insert_kernel<key_type, value_type, score_type, evict_strategy,
                              TILE_SIZE><<<grid_size, block_size, 0, stream>>>(
            d_table_, table_->buckets, options_.max_bucket_size,
            table_->buckets_num, options_.dim, keys, d_table_value_addrs,
            scores, founds, param_key_index, global_epoch_, N);
      }

      {
        thrust::device_ptr<uintptr_t> table_value_ptr(
            reinterpret_cast<uintptr_t*>(d_table_value_addrs));
        thrust::device_ptr<int> param_key_index_ptr(param_key_index);

        thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream),
                            table_value_ptr, table_value_ptr + n,
                            param_key_index_ptr, thrust::less<uintptr_t>());
      }

      if (filter_condition) {
        const size_t block_size = options_.io_block_size;
        const size_t N = n * dim();
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        read_or_write_kernel_unlock_key<key_type, value_type, score_type, V>
            <<<grid_size, block_size, 0, stream>>>(d_table_value_addrs, values,
                                                   founds, param_key_index,
                                                   keys_ptr, keys, dim(), N);

      } else if (options_.io_by_cpu) {
        const size_type host_ws_size{
            dev_ws_size + n * (sizeof(bool) + sizeof(value_type) * dim())};
        auto host_ws{host_mem_pool_->get_workspace<1>(host_ws_size, stream)};
        auto h_table_value_addrs{host_ws.get<value_type**>(0)};
        auto h_param_key_index{reinterpret_cast<int*>(h_table_value_addrs + n)};
        auto h_founds{reinterpret_cast<bool*>(h_param_key_index + n)};
        auto h_param_values{reinterpret_cast<value_type*>(h_founds + n)};

        CUDA_CHECK(cudaMemcpyAsync(h_table_value_addrs, d_table_value_addrs,
                                   dev_ws_size, cudaMemcpyDeviceToHost,
                                   stream));
        CUDA_CHECK(cudaMemcpyAsync(h_founds, founds, n * sizeof(bool),
                                   cudaMemcpyDeviceToHost, stream));
        CUDA_CHECK(cudaMemcpyAsync(h_param_values, values,
                                   n * sizeof(value_type) * dim(),
                                   cudaMemcpyDeviceToHost, stream));
        CUDA_CHECK(cudaStreamSynchronize(stream));

        read_or_write_by_cpu<value_type>(h_table_value_addrs, h_param_values,
                                         h_param_key_index, h_founds, dim(), n);
        CUDA_CHECK(cudaMemcpyAsync(values, h_param_values,
                                   n * sizeof(value_type) * dim(),
                                   cudaMemcpyHostToDevice, stream));
      } else {
        const size_t block_size = options_.io_block_size;
        const size_t N = n * dim();
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        read_or_write_kernel<key_type, value_type, score_type>
            <<<grid_size, block_size, 0, stream>>>(
                d_table_value_addrs, values, founds, param_key_index, dim(), N);
      }
    }

    CudaCheckError();
  }

  void find_or_insert(const size_type n, const key_type* keys,  // (n)
                      value_type** values,                      // (n)
                      bool* founds,                             // (n)
                      score_type* scores = nullptr,             // (n)
                      cudaStream_t stream = 0, bool unique_key = true,
                      bool ignore_evict_strategy = false) {
    if (n == 0) {
      return;
    }

    while (!reach_max_capacity_ &&
           fast_load_factor(n, stream) > options_.max_load_factor) {
      reserve(capacity() * 2, stream);
    }

    if (!ignore_evict_strategy) {
      check_evict_strategy(scores);
    }

    insert_unique_lock lock(mutex_, stream);

    constexpr uint32_t MinBucketCapacityFilter = sizeof(VecD_Load) / sizeof(D);

    if (unique_key && options_.max_bucket_size >= MinBucketCapacityFilter) {
      constexpr uint32_t BLOCK_SIZE = 128U;

      const size_type dev_ws_size{n * sizeof(key_type**)};
      auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
      auto keys_ptr{dev_ws.get<key_type**>(0)};
      CUDA_CHECK(cudaMemsetAsync(keys_ptr, 0, dev_ws_size, stream));

      find_or_insert_ptr_kernel_lock_key<key_type, value_type, score_type,
                                         BLOCK_SIZE, evict_strategy>
          <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
              table_->buckets, table_->buckets_size, table_->buckets_num,
              options_.max_bucket_size, options_.dim, keys, values, scores,
              keys_ptr, n, founds, global_epoch_);

      find_or_insert_ptr_kernel_unlock_key<key_type>
          <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
              keys, keys_ptr, n);
    } else {
      using Selector = SelectFindOrInsertPtrKernel<key_type, value_type,
                                                   score_type, evict_strategy>;
      static thread_local int step_counter = 0;
      static thread_local float load_factor = 0.0;

      if (((step_counter++) % kernel_select_interval_) == 0) {
        load_factor = fast_load_factor(0, stream, false);
      }
      Selector::execute_kernel(
          load_factor, options_.block_size, options_.max_bucket_size,
          table_->buckets_num, options_.dim, stream, n, d_table_,
          table_->buckets, keys, values, scores, founds, global_epoch_);
    }

    CudaCheckError();
  }
  void assign(const size_type n,
              const key_type* keys,                // (n)
              const value_type* values,            // (n, DIM)
              const score_type* scores = nullptr,  // (n)
              cudaStream_t stream = 0, bool unique_key = true) {
    if (n == 0) {
      return;
    }

    check_evict_strategy(scores);

    update_shared_lock lock(mutex_, stream);

    if (is_fast_mode()) {
      static thread_local int step_counter = 0;
      static thread_local float load_factor = 0.0;

      if (((step_counter++) % kernel_select_interval_) == 0) {
        load_factor = fast_load_factor(0, stream, false);
      }
      using Selector = KernelSelector_Update<key_type, value_type, score_type,
                                             evict_strategy, ArchTag>;
      if (Selector::callable(unique_key,
                             static_cast<uint32_t>(options_.max_bucket_size),
                             static_cast<uint32_t>(options_.dim))) {
        typename Selector::Params kernelParams(
            load_factor, table_->buckets, table_->buckets_num,
            static_cast<uint32_t>(options_.max_bucket_size),
            static_cast<uint32_t>(options_.dim), keys, values, scores, n,
            global_epoch_);
        Selector::select_kernel(kernelParams, stream);
      } else {
        using Selector = SelectUpdateKernelWithIO<key_type, value_type,
                                                  score_type, evict_strategy>;
        Selector::execute_kernel(
            load_factor, options_.block_size, options_.max_bucket_size,
            table_->buckets_num, options_.dim, stream, n, d_table_,
            table_->buckets, keys, values, scores, global_epoch_);
      }
    } else {
      const size_type dev_ws_size{
          n * (sizeof(value_type*) + sizeof(key_type) + sizeof(int))};
      auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
      auto d_dst{dev_ws.get<value_type**>(0)};
      auto keys_ptr{reinterpret_cast<key_type**>(d_dst + n)};
      auto d_src_offset{reinterpret_cast<int*>(keys_ptr + n)};

      CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_size, stream));

      constexpr uint32_t MinBucketCapacityFilter =
          sizeof(VecD_Load) / sizeof(D);

      bool filter_condition =
          options_.max_bucket_size >= MinBucketCapacityFilter &&
          !options_.io_by_cpu && unique_key;

      if (filter_condition) {
        constexpr uint32_t BLOCK_SIZE = 128U;

        tlp_update_kernel_hybrid<key_type, value_type, score_type,
                                 evict_strategy>
            <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
                table_->buckets, table_->buckets_num, options_.max_bucket_size,
                options_.dim, keys, d_dst, scores, keys_ptr, d_src_offset,
                global_epoch_, n);

      } else {
        const size_t block_size = options_.block_size;
        const size_t N = n * TILE_SIZE;
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        update_kernel<key_type, value_type, score_type, evict_strategy,
                      TILE_SIZE><<<grid_size, block_size, 0, stream>>>(
            d_table_, table_->buckets, options_.max_bucket_size,
            table_->buckets_num, options_.dim, keys, d_dst, scores,
            d_src_offset, global_epoch_, N);
      }

      {
        thrust::device_ptr<uintptr_t> d_dst_ptr(
            reinterpret_cast<uintptr_t*>(d_dst));
        thrust::device_ptr<int> d_src_offset_ptr(d_src_offset);

        thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream), d_dst_ptr,
                            d_dst_ptr + n, d_src_offset_ptr,
                            thrust::less<uintptr_t>());
      }

      if (filter_condition) {
        const size_t block_size = options_.io_block_size;
        const size_t N = n * dim();
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        write_kernel_unlock_key<key_type, value_type, score_type>
            <<<grid_size, block_size, 0, stream>>>(values, d_dst, d_src_offset,
                                                   dim(), keys, keys_ptr, N);

      } else if (options_.io_by_cpu) {
        const size_type host_ws_size{dev_ws_size +
                                     n * sizeof(value_type) * dim()};
        auto host_ws{host_mem_pool_->get_workspace<1>(host_ws_size, stream)};
        auto h_dst{host_ws.get<value_type**>(0)};
        auto h_src_offset{reinterpret_cast<int*>(h_dst + n)};
        auto h_values{reinterpret_cast<value_type*>(h_src_offset + n)};

        CUDA_CHECK(cudaMemcpyAsync(h_dst, d_dst, dev_ws_size,
                                   cudaMemcpyDeviceToHost, stream));
        CUDA_CHECK(cudaMemcpyAsync(h_values, values, host_ws_size - dev_ws_size,
                                   cudaMemcpyDeviceToHost, stream));
        CUDA_CHECK(cudaStreamSynchronize(stream));

        write_by_cpu<value_type>(h_dst, h_values, h_src_offset, dim(), n);
      } else {
        const size_t block_size = options_.io_block_size;
        const size_t N = n * dim();
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        write_kernel<key_type, value_type, score_type>
            <<<grid_size, block_size, 0, stream>>>(values, d_dst, d_src_offset,
                                                   dim(), N);
      }
    }

    CudaCheckError();
  }

  void assign_scores(const size_type n,
                     const key_type* keys,                // (n)
                     const score_type* scores = nullptr,  // (n)
                     cudaStream_t stream = 0, bool unique_key = true) {
    if (n == 0) {
      return;
    }

    check_evict_strategy(scores);

    {
      update_shared_lock lock(mutex_, stream);
      static thread_local int step_counter = 0;
      static thread_local float load_factor = 0.0;

      if (((step_counter++) % kernel_select_interval_) == 0) {
        load_factor = fast_load_factor(0, stream, false);
      }
      using Selector = KernelSelector_UpdateScore<key_type, value_type,
                                                  score_type, evict_strategy>;
      if (Selector::callable(unique_key,
                             static_cast<uint32_t>(options_.max_bucket_size))) {
        typename Selector::Params kernelParams(
            load_factor, table_->buckets, table_->buckets_num,
            static_cast<uint32_t>(options_.max_bucket_size), keys, scores, n,
            global_epoch_);
        Selector::select_kernel(kernelParams, stream);
      } else {
        using Selector = SelectUpdateScoreKernel<key_type, value_type,
                                                 score_type, evict_strategy>;
        Selector::execute_kernel(load_factor, options_.block_size,
                                 options_.max_bucket_size, table_->buckets_num,
                                 stream, n, d_table_, table_->buckets, keys,
                                 scores, global_epoch_);
      }
    }

    CudaCheckError();
  }

  void assign(const size_type n,
              const key_type* keys,                // (n)
              const score_type* scores = nullptr,  // (n)
              cudaStream_t stream = 0, bool unique_key = true) {
    assign_scores(n, keys, scores, stream, unique_key);
  }

  void assign_values(const size_type n,
                     const key_type* keys,      // (n)
                     const value_type* values,  // (n, DIM)
                     cudaStream_t stream = 0, bool unique_key = true) {
    if (n == 0) {
      return;
    }

    update_shared_lock lock(mutex_, stream);

    if (is_fast_mode()) {
      static thread_local int step_counter = 0;
      static thread_local float load_factor = 0.0;

      if (((step_counter++) % kernel_select_interval_) == 0) {
        load_factor = fast_load_factor(0, stream, false);
      }
      using Selector = KernelSelector_UpdateValues<key_type, value_type,
                                                   score_type, ArchTag>;
      if (Selector::callable(unique_key,
                             static_cast<uint32_t>(options_.max_bucket_size),
                             static_cast<uint32_t>(options_.dim))) {
        typename Selector::Params kernelParams(
            load_factor, table_->buckets, table_->buckets_num,
            static_cast<uint32_t>(options_.max_bucket_size),
            static_cast<uint32_t>(options_.dim), keys, values, n);
        Selector::select_kernel(kernelParams, stream);
      } else {
        using Selector =
            SelectUpdateValuesKernelWithIO<key_type, value_type, score_type>;
        Selector::execute_kernel(load_factor, options_.block_size,
                                 options_.max_bucket_size, table_->buckets_num,
                                 options_.dim, stream, n, d_table_,
                                 table_->buckets, keys, values);
      }
    } else {
      const size_type dev_ws_size{
          n * (sizeof(value_type*) + sizeof(key_type) + sizeof(int))};
      auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
      auto d_dst{dev_ws.get<value_type**>(0)};
      auto keys_ptr{reinterpret_cast<key_type**>(d_dst + n)};
      auto d_src_offset{reinterpret_cast<int*>(keys_ptr + n)};

      CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_size, stream));

      constexpr uint32_t MinBucketCapacityFilter =
          sizeof(VecD_Load) / sizeof(D);

      bool filter_condition =
          options_.max_bucket_size >= MinBucketCapacityFilter &&
          !options_.io_by_cpu && unique_key;

      if (filter_condition) {
        constexpr uint32_t BLOCK_SIZE = 128U;

        tlp_update_values_kernel_hybrid<key_type, value_type, score_type>
            <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
                table_->buckets, table_->buckets_num, options_.max_bucket_size,
                options_.dim, keys, d_dst, keys_ptr, d_src_offset, n);

      } else {
        const size_t block_size = options_.block_size;
        const size_t N = n * TILE_SIZE;
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        update_values_kernel<key_type, value_type, score_type, TILE_SIZE>
            <<<grid_size, block_size, 0, stream>>>(
                d_table_, table_->buckets, options_.max_bucket_size,
                table_->buckets_num, options_.dim, keys, d_dst, d_src_offset,
                N);
      }

      {
        thrust::device_ptr<uintptr_t> d_dst_ptr(
            reinterpret_cast<uintptr_t*>(d_dst));
        thrust::device_ptr<int> d_src_offset_ptr(d_src_offset);

        thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream), d_dst_ptr,
                            d_dst_ptr + n, d_src_offset_ptr,
                            thrust::less<uintptr_t>());
      }

      if (filter_condition) {
        const size_t block_size = options_.io_block_size;
        const size_t N = n * dim();
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        write_kernel_unlock_key<key_type, value_type, score_type>
            <<<grid_size, block_size, 0, stream>>>(values, d_dst, d_src_offset,
                                                   dim(), keys, keys_ptr, N);

      } else if (options_.io_by_cpu) {
        const size_type host_ws_size{dev_ws_size +
                                     n * sizeof(value_type) * dim()};
        auto host_ws{host_mem_pool_->get_workspace<1>(host_ws_size, stream)};
        auto h_dst{host_ws.get<value_type**>(0)};
        auto h_src_offset{reinterpret_cast<int*>(h_dst + n)};
        auto h_values{reinterpret_cast<value_type*>(h_src_offset + n)};

        CUDA_CHECK(cudaMemcpyAsync(h_dst, d_dst, dev_ws_size,
                                   cudaMemcpyDeviceToHost, stream));
        CUDA_CHECK(cudaMemcpyAsync(h_values, values, host_ws_size - dev_ws_size,
                                   cudaMemcpyDeviceToHost, stream));
        CUDA_CHECK(cudaStreamSynchronize(stream));

        write_by_cpu<value_type>(h_dst, h_values, h_src_offset, dim(), n);
      } else {
        const size_t block_size = options_.io_block_size;
        const size_t N = n * dim();
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        write_kernel<key_type, value_type, score_type>
            <<<grid_size, block_size, 0, stream>>>(values, d_dst, d_src_offset,
                                                   dim(), N);
      }
    }

    CudaCheckError();
  }

  void find(const size_type n, const key_type* keys,  // (n)
            value_type* values,                       // (n, DIM)
            bool* founds,                             // (n)
            score_type* scores = nullptr,             // (n)
            cudaStream_t stream = 0) const {
    if (n == 0) {
      return;
    }

    CUDA_CHECK(cudaMemsetAsync(founds, 0, n * sizeof(bool), stream));

    read_shared_lock lock(mutex_, stream);

    const uint32_t value_size = dim() * sizeof(V);

    if (is_fast_mode()) {
      using Selector = SelectPipelineLookupKernelWithIO<key_type, value_type,
                                                        score_type, ArchTag>;
      const uint32_t pipeline_max_size = Selector::max_value_size();
      // The pipeline lookup kernel only supports "bucket_size = 128".
      if (options_.max_bucket_size == 128 && value_size <= pipeline_max_size) {
        LookupKernelParams<key_type, value_type, score_type> lookupParams(
            table_->buckets, table_->buckets_num, static_cast<uint32_t>(dim()),
            keys, values, scores, founds, n);
        Selector::select_kernel(lookupParams, stream);
      } else {
        using Selector =
            SelectLookupKernelWithIO<key_type, value_type, score_type>;
        static thread_local int step_counter = 0;
        static thread_local float load_factor = 0.0;

        if (((step_counter++) % kernel_select_interval_) == 0) {
          load_factor = fast_load_factor(0, stream, false);
        }
        Selector::execute_kernel(load_factor, options_.block_size,
                                 options_.max_bucket_size, table_->buckets_num,
                                 options_.dim, stream, n, d_table_,
                                 table_->buckets, keys, values, scores, founds);
      }
    } else {
      const size_type dev_ws_size{n * (sizeof(value_type*) + sizeof(int))};
      auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
      auto src{dev_ws.get<value_type**>(0)};
      auto dst_offset{reinterpret_cast<int*>(src + n)};

      CUDA_CHECK(cudaMemsetAsync(src, 0, dev_ws_size, stream));

      constexpr uint32_t MinBucketCapacityFilter =
          sizeof(VecD_Load) / sizeof(D);

      bool filter_condition =
          options_.max_bucket_size >= MinBucketCapacityFilter;

      if (filter_condition) {
        constexpr uint32_t BLOCK_SIZE = 128U;

        tlp_lookup_kernel_hybrid<key_type, value_type, score_type>
            <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
                table_->buckets, table_->buckets_num, options_.max_bucket_size,
                options_.dim, keys, src, scores, dst_offset, founds, n);
      } else {
        const size_t block_size = options_.block_size;
        const size_t N = n * TILE_SIZE;
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        lookup_kernel<key_type, value_type, score_type, TILE_SIZE>
            <<<grid_size, block_size, 0, stream>>>(
                d_table_, table_->buckets, options_.max_bucket_size,
                table_->buckets_num, options_.dim, keys, src, scores, founds,
                dst_offset, N);
      }

      if (values != nullptr) {
        thrust::device_ptr<uintptr_t> src_ptr(
            reinterpret_cast<uintptr_t*>(src));
        thrust::device_ptr<int> dst_offset_ptr(dst_offset);

        thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream), src_ptr,
                            src_ptr + n, dst_offset_ptr,
                            thrust::less<uintptr_t>());

        const size_t block_size = options_.io_block_size;
        const size_t N = n * dim();
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        read_kernel<key_type, value_type, score_type>
            <<<grid_size, block_size, 0, stream>>>(src, values, founds,
                                                   dst_offset, dim(), N);
      }
    }

    CudaCheckError();
  }

  void find(const size_type n, const key_type* keys,  // (n)
            value_type* values,                       // (n, DIM)
            key_type* missed_keys,                    // (n)
            int* missed_indices,                      // (n)
            int* missed_size,                         // scalar
            score_type* scores = nullptr,             // (n)
            cudaStream_t stream = 0) const {
    if (n == 0) {
      return;
    }

    CUDA_CHECK(cudaMemsetAsync(missed_size, 0, sizeof(*missed_size), stream));

    read_shared_lock lock(mutex_, stream);

    const uint32_t value_size = options_.dim * sizeof(V);

    if (is_fast_mode()) {
      using Selector = SelectPipelineLookupKernelWithIO<key_type, value_type,
                                                        score_type, ArchTag>;
      const uint32_t pipeline_max_size = Selector::max_value_size();
      // The pipeline lookup kernel only supports "bucket_size = 128".
      if (options_.max_bucket_size == 128 && value_size <= pipeline_max_size) {
        LookupKernelParamsV2<key_type, value_type, score_type> lookupParams(
            table_->buckets, table_->buckets_num, static_cast<uint32_t>(dim()),
            keys, values, scores, missed_keys, missed_indices, missed_size, n);
        Selector::select_kernel(lookupParams, stream);
      } else {
        using Selector =
            SelectLookupKernelWithIOV2<key_type, value_type, score_type>;
        static thread_local int step_counter = 0;
        static thread_local float load_factor = 0.0;

        if (((step_counter++) % kernel_select_interval_) == 0) {
          load_factor = fast_load_factor(0, stream, false);
        }
        Selector::execute_kernel(load_factor, options_.block_size,
                                 options_.max_bucket_size, table_->buckets_num,
                                 options_.dim, stream, n, d_table_,
                                 table_->buckets, keys, values, scores,
                                 missed_keys, missed_indices, missed_size);
      }
    } else {
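      // Device workspace layout: n value pointers followed by n destination
      // offsets used to place the gathered values into the right rows of
      // `values`.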
      const size_type dev_ws_size{n * (sizeof(value_type*) + sizeof(int))};
      auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
      auto src{dev_ws.get<value_type**>(0)};
      auto dst_offset{reinterpret_cast<int*>(src + n)};

      CUDA_CHECK(cudaMemsetAsync(src, 0, dev_ws_size, stream));

      constexpr uint32_t MinBucketCapacityFilter =
          sizeof(VecD_Load) / sizeof(D);

      bool filter_condition =
          options_.max_bucket_size >= MinBucketCapacityFilter;

      if (filter_condition) {
        constexpr uint32_t BLOCK_SIZE = 128U;

        tlp_lookup_kernel_hybrid<key_type, value_type, score_type>
            <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
                table_->buckets, table_->buckets_num, options_.max_bucket_size,
                options_.dim, keys, src, scores, dst_offset, missed_keys,
                missed_indices, missed_size, n);
      } else {
        const size_t block_size = options_.block_size;
        const size_t N = n * TILE_SIZE;
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        lookup_kernel<key_type, value_type, score_type, TILE_SIZE>
            <<<grid_size, block_size, 0, stream>>>(
                d_table_, table_->buckets, options_.max_bucket_size,
                table_->buckets_num, options_.dim, keys, src, scores,
                missed_keys, missed_indices, missed_size, dst_offset, N);
      }

      if (values != nullptr) {
        thrust::device_ptr<uintptr_t> src_ptr(
            reinterpret_cast<uintptr_t*>(src));
        thrust::device_ptr<int> dst_offset_ptr(dst_offset);

        thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream), src_ptr,
                            src_ptr + n, dst_offset_ptr,
                            thrust::less<uintptr_t>());

        const size_t block_size = options_.io_block_size;
        const size_t N = n * dim();
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        read_kernel<key_type, value_type, score_type>
            <<<grid_size, block_size, 0, stream>>>(src, values, dst_offset,
                                                   dim(), N);
      }
    }

    CudaCheckError();
  }

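  // Lookup variant that returns device pointers to the in-table values
  // instead of copying them; `founds[i]` is set when `keys[i]` is present.
  // Passing `unique_key = true` allows the filtered fast-path kernel to be
  // used; note that the returned pointers can be invalidated by later table
  // modifications.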
  void find(const size_type n, const key_type* keys,  // (n)
            value_type** values,                      // (n)
            bool* founds,                             // (n)
            score_type* scores = nullptr,             // (n)
            cudaStream_t stream = 0, bool unique_key = true) const {
    if (n == 0) {
      return;
    }

    CUDA_CHECK(cudaMemsetAsync(founds, 0, n * sizeof(bool), stream));

    read_shared_lock lock(mutex_, stream);

    constexpr uint32_t MinBucketCapacityFilter = sizeof(VecD_Load) / sizeof(D);
    if (unique_key && options_.max_bucket_size >= MinBucketCapacityFilter) {
      constexpr uint32_t BLOCK_SIZE = 128U;
      tlp_lookup_ptr_kernel_with_filter<key_type, value_type, score_type>
          <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
              table_->buckets, table_->buckets_num, options_.max_bucket_size,
              options_.dim, keys, values, scores, founds, n);
    } else {
      using Selector = SelectLookupPtrKernel<key_type, value_type, score_type>;
      static thread_local int step_counter = 0;
      static thread_local float load_factor = 0.0;

      if (((step_counter++) % kernel_select_interval_) == 0) {
        load_factor = fast_load_factor(0, stream, false);
      }

      Selector::execute_kernel(load_factor, options_.block_size,
                               options_.max_bucket_size, table_->buckets_num,
                               options_.dim, stream, n, d_table_,
                               table_->buckets, keys, values, scores, founds);
    }

    CudaCheckError();
  }

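  // Membership test only: sets `founds[i]` when `keys[i]` exists in the
  // table; no values or scores are copied.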
  void contains(const size_type n, const key_type* keys,  // (n)
                bool* founds,                             // (n)
                cudaStream_t stream = 0) const {
    if (n == 0) {
      return;
    }

    read_shared_lock lock(mutex_, stream);

    if (options_.max_bucket_size == 128) {
      // Pipeline lookup kernel only supports "bucket_size = 128".
      using Selector = SelectPipelineContainsKernel<key_type, value_type,
                                                    score_type, ArchTag>;
      ContainsKernelParams<key_type, value_type, score_type> containsParams(
          table_->buckets, table_->buckets_num, static_cast<uint32_t>(dim()),
          keys, founds, n);
      Selector::select_kernel(containsParams, stream);
    } else {
      using Selector = SelectContainsKernel<key_type, value_type, score_type>;
      static thread_local int step_counter = 0;
      static thread_local float load_factor = 0.0;

      if (((step_counter++) % kernel_select_interval_) == 0) {
        load_factor = fast_load_factor(0, stream, false);
      }
      Selector::execute_kernel(load_factor, options_.block_size,
                               options_.max_bucket_size, table_->buckets_num,
                               options_.dim, stream, n, d_table_,
                               table_->buckets, keys, founds);
    }
    CudaCheckError();
  }

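  // Removes the given keys from the table; keys that are absent are ignored.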
  void erase(const size_type n, const key_type* keys, cudaStream_t stream = 0) {
    if (n == 0) {
      return;
    }

    update_read_lock lock(mutex_, stream);

    {
      const size_t block_size = options_.block_size;
      const size_t N = n * TILE_SIZE;
      const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

      remove_kernel<key_type, value_type, score_type, TILE_SIZE>
          <<<grid_size, block_size, 0, stream>>>(
              d_table_, keys, table_->buckets, table_->buckets_size,
              table_->bucket_max_size, table_->buckets_num, N);
    }

    CudaCheckError();
    return;
  }

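  // Removes every entry for which PredFunctor, evaluated with `pattern` and
  // `threshold`, returns true, and returns the number of removed entries.
  // The stream is synchronized to bring the count back to the host.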
  template <template <typename, typename> class PredFunctor>
  size_type erase_if(const key_type& pattern, const score_type& threshold,
                     cudaStream_t stream = 0) {
    update_read_lock lock(mutex_, stream);

    auto dev_ws{dev_mem_pool_->get_workspace<1>(sizeof(size_type), stream)};
    auto d_count{dev_ws.get<size_type*>(0)};

    CUDA_CHECK(cudaMemsetAsync(d_count, 0, sizeof(size_type), stream));

    {
      const size_t block_size = options_.block_size;
      const size_t N = table_->buckets_num;
      const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

      remove_kernel<key_type, value_type, score_type, PredFunctor>
          <<<grid_size, block_size, 0, stream>>>(
              d_table_, pattern, threshold, d_count, table_->buckets,
              table_->buckets_size, table_->bucket_max_size,
              table_->buckets_num, N);
    }

    size_type count = 0;
    CUDA_CHECK(cudaMemcpyAsync(&count, d_count, sizeof(size_type),
                               cudaMemcpyDeviceToHost, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));

    CudaCheckError();
    return count;
  }

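  // Removes all entries by clearing every slot of every bucket.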
  void clear(cudaStream_t stream = 0) {
    update_read_lock lock(mutex_, stream);

    const size_t block_size = options_.block_size;
    const size_t N = table_->buckets_num * table_->bucket_max_size;
    const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

    clear_kernel<key_type, value_type, score_type>
        <<<grid_size, block_size, 0, stream>>>(d_table_, table_->buckets, N);

    CudaCheckError();
  }

 public:
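  // Dumps up to `n` entries, scanning table slots beginning at `offset`;
  // the number of entries actually exported is written to the device
  // counter `d_counter`.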
  void export_batch(size_type n, const size_type offset,
                    size_type* d_counter,          // (1)
                    key_type* keys,                // (n)
                    value_type* values,            // (n, DIM)
                    score_type* scores = nullptr,  // (n)
                    cudaStream_t stream = 0) const {
    read_shared_lock lock(mutex_, stream);

    CUDA_CHECK(cudaMemsetAsync(d_counter, 0, sizeof(size_type), stream));
    if (offset >= table_->capacity) {
      return;
    }
    n = std::min(table_->capacity - offset, n);

    size_type shared_size;
    size_type block_size;
    std::tie(shared_size, block_size) =
        dump_kernel_shared_memory_size<K, V, S>(shared_mem_size_);

    const size_t grid_size = SAFE_GET_GRID_SIZE(n, block_size);

    dump_kernel<key_type, value_type, score_type>
        <<<grid_size, block_size, shared_size, stream>>>(
            d_table_, table_->buckets, keys, values, scores, offset, n,
            d_counter);

    CudaCheckError();
  }

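  // Convenience overload that allocates a temporary device counter and
  // returns the exported count on the host (synchronizes `stream`).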
  size_type export_batch(const size_type n, const size_type offset,
                         key_type* keys,                // (n)
                         value_type* values,            // (n, DIM)
                         score_type* scores = nullptr,  // (n)
                         cudaStream_t stream = 0) const {
    auto dev_ws{dev_mem_pool_->get_workspace<1>(sizeof(size_type), stream)};
    auto d_counter{dev_ws.get<size_type*>(0)};

    CUDA_CHECK(cudaMemsetAsync(d_counter, 0, sizeof(size_type), stream));
    export_batch(n, offset, d_counter, keys, values, scores, stream);

    size_type counter = 0;
    CUDA_CHECK(cudaMemcpyAsync(&counter, d_counter, sizeof(size_type),
                               cudaMemcpyDeviceToHost, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));
    return counter;
  }

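  // Like export_batch, but only dumps entries selected by PredFunctor when
  // evaluated against `pattern` and `threshold`.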
  template <template <typename, typename> class PredFunctor>
  void export_batch_if(const key_type& pattern, const score_type& threshold,
                       size_type n, const size_type offset,
                       size_type* d_counter,
                       key_type* keys,                // (n)
                       value_type* values,            // (n, DIM)
                       score_type* scores = nullptr,  // (n)
                       cudaStream_t stream = 0) const {
    read_shared_lock lock(mutex_, stream);
    CUDA_CHECK(cudaMemsetAsync(d_counter, 0, sizeof(size_type), stream));

    if (offset >= table_->capacity) {
      return;
    }
    n = std::min(table_->capacity - offset, n);

    const size_t score_size = scores ? sizeof(score_type) : 0;
    const size_t kvm_size =
        sizeof(key_type) + sizeof(value_type) * dim() + score_size;
    const size_t block_size = std::min(shared_mem_size_ / 2 / kvm_size, 1024UL);
    MERLIN_CHECK(
        block_size > 0,
        "[HierarchicalKV] block_size <= 0, the K-V-S size may be too large!");

    const size_t shared_size = kvm_size * block_size;
    const size_t grid_size = SAFE_GET_GRID_SIZE(n, block_size);

    dump_kernel<key_type, value_type, score_type, PredFunctor>
        <<<grid_size, block_size, shared_size, stream>>>(
            d_table_, table_->buckets, pattern, threshold, keys, values, scores,
            offset, n, d_counter);

    CudaCheckError();
  }

 public:
  bool empty(cudaStream_t stream = 0) const { return size(stream) == 0; }

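  // Sums the per-bucket sizes in chunks small enough that each int-typed
  // thrust::reduce cannot overflow, then accumulates the chunk totals.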
  size_type size(cudaStream_t stream = 0) const {
    read_shared_lock lock(mutex_, stream);

    size_type h_size = 0;

    const size_type N = table_->buckets_num;
    const size_type step = static_cast<size_type>(
        std::numeric_limits<int>::max() / options_.max_bucket_size);

    thrust::device_ptr<int> size_ptr(table_->buckets_size);

    for (size_type start_i = 0; start_i < N; start_i += step) {
      size_type end_i = std::min(start_i + step, N);
      h_size += thrust::reduce(thrust_par(thrust_allocator_).on(stream),
                               size_ptr + start_i, size_ptr + end_i, 0,
                               thrust::plus<int>());
    }

    CudaCheckError();
    return h_size;
  }

  size_type capacity() const { return table_->capacity; }

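  // Grows the table by repeatedly doubling its capacity and rehashing until
  // `new_capacity` is reached or doubling would exceed `max_capacity`.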
  void reserve(const size_type new_capacity, cudaStream_t stream = 0) {
    if (reach_max_capacity_ || new_capacity > options_.max_capacity) {
      reach_max_capacity_ = (capacity() * 2 > options_.max_capacity);
      return;
    }

    {
      update_read_lock lock(mutex_, stream);

      // Once we have exclusive access, make sure that pending GPU calls have
      // been processed.
      CUDA_CHECK(cudaDeviceSynchronize());

      while (capacity() < new_capacity &&
             capacity() * 2 <= options_.max_capacity) {
        double_capacity<key_type, value_type, score_type>(&table_, allocator_);
        CUDA_CHECK(cudaDeviceSynchronize());
        sync_table_configuration();

        const size_t block_size = options_.block_size;
        const size_t N = TILE_SIZE * table_->buckets_num / 2;
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

        rehash_kernel_for_fast_mode<key_type, value_type, score_type, TILE_SIZE>
            <<<grid_size, block_size, 0, stream>>>(d_table_, table_->buckets,
                                                   N);
      }
      CUDA_CHECK(cudaDeviceSynchronize());
      reach_max_capacity_ = (capacity() * 2 > options_.max_capacity);
    }
    CudaCheckError();
  }

  float load_factor(cudaStream_t stream = 0) const {
    return static_cast<float>((size(stream) * 1.0) / (capacity() * 1.0));
  }

  void set_max_capacity(size_type new_max_capacity) {
    if (!is_power(2, new_max_capacity)) {
      throw std::invalid_argument(
          "Non-power-of-2 new_max_capacity is not supported.");
    }

    update_read_lock lock(mutex_);

    if (new_max_capacity < capacity()) {
      return;
    }
    if (reach_max_capacity_) {
      reach_max_capacity_ = false;
    }
    options_.max_capacity = new_max_capacity;
  }

  size_type dim() const noexcept { return options_.dim; }

  size_type max_bucket_size() const noexcept {
    return options_.max_bucket_size;
  }

  size_type bucket_count() const noexcept { return table_->buckets_num; }

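  // Serializes the whole table through `file` in batches whose footprint is
  // bounded by `max_workspace_size`; returns the number of entries written.
  //
  // A minimal usage sketch, assuming a user-defined `MyKVFile` that
  // implements the BaseKVFile<K, V, S> interface:
  //
  //   MyKVFile file("/tmp/table.bin");
  //   size_t dumped = table.save(&file, 16UL * 1024 * 1024, stream);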
  size_type save(BaseKVFile<K, V, S>* file,
                 const size_t max_workspace_size = 1L * 1024 * 1024,
                 cudaStream_t stream = 0) const {
    const size_type tuple_size{sizeof(key_type) + sizeof(score_type) +
                               sizeof(value_type) * dim()};
    MERLIN_CHECK(max_workspace_size >= tuple_size,
                 "[HierarchicalKV] max_workspace_size is smaller than a single "
                 "`key + score + value` tuple! Please set a larger value!");

    size_type shared_size;
    size_type block_size;
    std::tie(shared_size, block_size) =
        dump_kernel_shared_memory_size<K, V, S>(shared_mem_size_);

    // Request exclusive access (to make sure capacity won't change anymore).
    update_read_lock lock(mutex_, stream);

    const size_type total_size{capacity()};
    const size_type n{std::min(max_workspace_size / tuple_size, total_size)};
    const size_type grid_size{SAFE_GET_GRID_SIZE(n, block_size)};

    // Grab temporary device and host memory.
    const size_type host_ws_size{n * tuple_size};
    auto host_ws{host_mem_pool_->get_workspace<1>(host_ws_size, stream)};
    auto h_keys{host_ws.get<key_type*>(0)};
    auto h_scores{reinterpret_cast<score_type*>(h_keys + n)};
    auto h_values{reinterpret_cast<value_type*>(h_scores + n)};

    const size_type dev_ws_size{sizeof(size_type) + host_ws_size};
    auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
    auto d_count{dev_ws.get<size_type*>(0)};
    auto d_keys{reinterpret_cast<key_type*>(d_count + 1)};
    auto d_scores{reinterpret_cast<score_type*>(d_keys + n)};
    auto d_values{reinterpret_cast<value_type*>(d_scores + n)};

    // Step through table, dumping contents in batches.
    size_type total_count{0};
    for (size_type i{0}; i < total_size; i += n) {
      // Dump the next batch to workspace, and then write it to the file.
      CUDA_CHECK(cudaMemsetAsync(d_count, 0, sizeof(size_type), stream));

      dump_kernel<key_type, value_type, score_type>
          <<<grid_size, block_size, shared_size, stream>>>(
              d_table_, table_->buckets, d_keys, d_values, d_scores, i,
              std::min(total_size - i, n), d_count);

      size_type count;
      CUDA_CHECK(cudaMemcpyAsync(&count, d_count, sizeof(size_type),
                                 cudaMemcpyDeviceToHost, stream));
      CUDA_CHECK(cudaStreamSynchronize(stream));

      if (count == n) {
        CUDA_CHECK(cudaMemcpyAsync(h_keys, d_keys, host_ws_size,
                                   cudaMemcpyDeviceToHost, stream));
      } else {
        CUDA_CHECK(cudaMemcpyAsync(h_keys, d_keys, sizeof(key_type) * count,
                                   cudaMemcpyDeviceToHost, stream));
        CUDA_CHECK(cudaMemcpyAsync(h_scores, d_scores,
                                   sizeof(score_type) * count,
                                   cudaMemcpyDeviceToHost, stream));
        CUDA_CHECK(cudaMemcpyAsync(h_values, d_values,
                                   sizeof(value_type) * dim() * count,
                                   cudaMemcpyDeviceToHost, stream));
      }

      CUDA_CHECK(cudaStreamSynchronize(stream));
      file->write(count, dim(), h_keys, h_values, h_scores);
      total_count += count;
    }

    return total_count;
  }

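  // Restores entries from `file` batch by batch and inserts them with
  // insert_or_assign; returns the number of entries read.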
  size_type load(BaseKVFile<K, V, S>* file,
                 const size_t max_workspace_size = 1L * 1024 * 1024,
                 cudaStream_t stream = 0) {
    const size_type tuple_size{sizeof(key_type) + sizeof(score_type) +
                               sizeof(value_type) * dim()};
    MERLIN_CHECK(max_workspace_size >= tuple_size,
                 "[HierarchicalKV] max_workspace_size is smaller than a single "
                 "`key + score + value` tuple! Please set a larger value!");

    const size_type n{max_workspace_size / tuple_size};
    const size_type ws_size{n * tuple_size};

    // Grab enough host memory to hold batch data.
    auto host_ws{host_mem_pool_->get_workspace<1>(ws_size, stream)};
    auto h_keys{host_ws.get<key_type*>(0)};
    auto h_scores{reinterpret_cast<score_type*>(h_keys + n)};
    auto h_values{reinterpret_cast<value_type*>(h_scores + n)};

    // Attempt a first read.
    size_type count{file->read(n, dim(), h_keys, h_values, h_scores)};
    if (count == 0) {
      return 0;
    }

    // Grab equal amount of device memory as temporary storage.
    auto dev_ws{dev_mem_pool_->get_workspace<1>(ws_size, stream)};
    auto d_keys{dev_ws.get<key_type*>(0)};
    auto d_scores{reinterpret_cast<score_type*>(d_keys + n)};
    auto d_values{reinterpret_cast<value_type*>(d_scores + n)};

    size_type total_count{0};
    do {
      if (count == n) {
        CUDA_CHECK(cudaMemcpyAsync(d_keys, h_keys, ws_size,
                                   cudaMemcpyHostToDevice, stream));
      } else {
        CUDA_CHECK(cudaMemcpyAsync(d_keys, h_keys, sizeof(key_type) * count,
                                   cudaMemcpyHostToDevice, stream));
        CUDA_CHECK(cudaMemcpyAsync(d_scores, h_scores,
                                   sizeof(score_type) * count,
                                   cudaMemcpyHostToDevice, stream));
        CUDA_CHECK(cudaMemcpyAsync(d_values, h_values,
                                   sizeof(value_type) * dim() * count,
                                   cudaMemcpyHostToDevice, stream));
      }

      set_global_epoch(static_cast<S>(IGNORED_GLOBAL_EPOCH));
      insert_or_assign(count, d_keys, d_values, d_scores, stream, true, true);
      total_count += count;

      // Read next batch.
      CUDA_CHECK(cudaStreamSynchronize(stream));
      count = file->read(n, dim(), h_keys, h_values, h_scores);
    } while (count > 0);

    return total_count;
  }

  void set_global_epoch(const uint64_t epoch) { global_epoch_ = epoch; }

 private:
  bool is_power(size_t base, size_t n) {
    if (base < 2) {
      throw std::invalid_argument("is_power requires a base of at least 2.");
    }
    while (n > 1) {
      if (n % base != 0) {
        return false;
      }
      n /= base;
    }
    return true;
  }

 private:
  inline bool is_fast_mode() const noexcept { return table_->is_pure_hbm; }

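  // Estimates the load factor by sampling at most the first 1024 buckets
  // instead of scanning the whole table; `delta` accounts for keys that are
  // about to be inserted.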
  inline float fast_load_factor(const size_type delta = 0,
                                cudaStream_t stream = 0,
                                const bool need_lock = true) const {
    read_shared_lock lock(mutex_, std::defer_lock, stream);
    if (need_lock) {
      lock.lock();
    }

    size_t N = std::min(table_->buckets_num, 1024UL);

    thrust::device_ptr<int> size_ptr(table_->buckets_size);

    int size = thrust::reduce(thrust_par(thrust_allocator_).on(stream),
                              size_ptr, size_ptr + N, 0, thrust::plus<int>());

    CudaCheckError();
    return static_cast<float>((delta * 1.0) / (capacity() * 1.0) +
                              (size * 1.0) /
                                  (options_.max_bucket_size * N * 1.0));
  }

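  // Validates that the presence or absence of `scores` matches the
  // configured eviction strategy, and that a valid global epoch has been set
  // for the epoch-based strategies.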
  inline void check_evict_strategy(const score_type* scores) {
    if (evict_strategy == EvictStrategy::kLru ||
        evict_strategy == EvictStrategy::kEpochLru) {
      MERLIN_CHECK(scores == nullptr,
                   "the scores should not be specified when running on "
                   "LRU or Epoch LRU mode.");
    }

    if (evict_strategy == EvictStrategy::kLfu ||
        evict_strategy == EvictStrategy::kEpochLfu) {
      MERLIN_CHECK(scores != nullptr,
                   "the scores should be specified when running on "
                   "LFU or Epoch LFU mode.");
    }

    if (evict_strategy == EvictStrategy::kCustomized) {
      MERLIN_CHECK(scores != nullptr,
                   "the scores should be specified when running on "
                   "customized mode.");
    }

    if ((evict_strategy == EvictStrategy::kEpochLru ||
         evict_strategy == EvictStrategy::kEpochLfu)) {
      MERLIN_CHECK(
          global_epoch_ != static_cast<S>(IGNORED_GLOBAL_EPOCH),
          "the global_epoch is invalid and should be assigned by calling "
          "`set_global_epoch` when running on "
          "Epoch LRU or Epoch LFU mode.");
    }
  }

  inline void sync_table_configuration() {
    CUDA_CHECK(
        cudaMemcpy(d_table_, table_, sizeof(TableCore), cudaMemcpyDefault));
  }

 private:
  HashTableOptions options_;
  TableCore* table_ = nullptr;
  TableCore* d_table_ = nullptr;
  size_t shared_mem_size_ = 0;
  std::atomic_bool reach_max_capacity_{false};
  bool initialized_ = false;
  mutable group_shared_mutex mutex_;
  const unsigned int kernel_select_interval_ = 7;
  std::unique_ptr<DeviceMemoryPool> dev_mem_pool_;
  std::unique_ptr<HostMemoryPool> host_mem_pool_;
  allocator_type* allocator_;
  ThrustAllocator<uint8_t> thrust_allocator_;
  bool default_allocator_ = true;
  std::atomic<uint64_t> global_epoch_{
      static_cast<uint64_t>(IGNORED_GLOBAL_EPOCH)};
};

}  // namespace merlin
}  // namespace nv