HPS Layers

SparseLookupLayer

class hierarchical_parameter_server.SparseLookupLayer(*args, **kwargs)[source]

Bases: keras.engine.base_layer.Layer

Abbreviated as hps.SparseLookupLayer(*args, **kwargs).

This is a wrapper class for the HPS sparse lookup layer, which performs essentially the same function as tf.nn.embedding_lookup_sparse. Note that ps_config_file and global_batch_size must be specified in the constructor if you want to use implicit HPS initialization.

Parameters
  • model_name (str) – The name of the model that has embedding tables.

  • table_id (int) – The index of the embedding table for the model specified by model_name.

  • emb_vec_size (int) – The embedding vector size for the embedding table specified by model_name and table_id.

  • emb_vec_dtype – The data type of embedding vectors which must be tf.float32.

  • ps_config_file (str) – The JSON configuration file for HPS initialization.

  • global_batch_size (int) – The global batch size for HPS that is deployed on multiple GPUs.
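For reference, ps_config_file points to a JSON file that describes the deployed models and their embedding tables. The exact schema depends on the HPS version; the sketch below shows commonly used fields, and all names and values are illustrative:

```json
{
  "supportlonglong": true,
  "models": [
    {
      "model": "demo_model",
      "sparse_files": ["demo_model/sparse_0.model"],
      "embedding_table_names": ["sparse_embedding0"],
      "embedding_vecsize_per_table": [16],
      "default_value_for_each_table": [0.0],
      "deployed_device_list": [0],
      "max_batch_size": 1024,
      "hit_rate_threshold": 1.0,
      "gpucacheper": 0.5,
      "gpucache": true
    }
  ]
}
```

The default_value_for_each_table entry supplies the default embeddings that are returned when an ID is not found in HPS.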

Examples

import hierarchical_parameter_server as hps

sparse_lookup_layer = hps.SparseLookupLayer(model_name=args.model_name,
                                            table_id=args.table_id,
                                            emb_vec_size=args.embed_vec_size,
                                            emb_vec_dtype=tf.float32,
                                            ps_config_file=args.ps_config_file,
                                            global_batch_size=args.global_batch_size)

@tf.function
def _infer_step(inputs):
    embedding_vector = sparse_lookup_layer(sp_ids=inputs,
                                           sp_weights=None,
                                           combiner="mean")
    ...

for i, (inputs, labels) in enumerate(dataset):
    _infer_step(inputs)

call(sp_ids, sp_weights, name=None, combiner=None, max_norm=None)[source]

Looks up embeddings for the given IDs and weights from a list of tensors. This op assumes that there is at least one ID for each row in the dense tensor represented by sp_ids (i.e., there are no rows with empty features), and that all the indices of sp_ids are in canonical row-major order. The sp_ids and sp_weights (if not None) are SparseTensors of rank 2. Embeddings are always aggregated along the last dimension. If an ID value cannot be found in HPS, the default embeddings are retrieved, which can be specified in the HPS configuration JSON file.

Parameters
  • sp_ids – N x M SparseTensor of int64 IDs where N is typically batch size and M is arbitrary.

  • sp_weights – Either a SparseTensor of float or double weights, or None to indicate all weights should be taken to be 1. If specified, sp_weights must have exactly the same shape and indices as sp_ids.

  • combiner

    A string that specifies the reduction op:

    "sum": Computes the weighted sum of the embedding results for each row.

    "mean": Computes the weighted sum divided by the total weight.

    "sqrtn": Computes the weighted sum divided by the square root of the sum of the squares of the weights.

    The default value is "mean".

  • max_norm – If not None, each embedding is clipped if its L2 norm is larger than this value before combining.

Returns

emb_vector – A dense tensor representing the combined embeddings for the sparse IDs. For each row in the dense tensor represented by sp_ids, the op looks up the embeddings for all IDs in that row, multiplies them by the corresponding weight, and combines these embeddings as specified. In other words, if

shape(sp_ids) = shape(sp_weights) = [d0, d1]

then

shape(output) = [d0, self.emb_vec_size]

For instance, if self.emb_vec_size is 16, and sp_ids / sp_weights are

[0, 0]: id 1, weight 2.0
[0, 1]: id 3, weight 0.5
[1, 0]: id 0, weight 1.0
[2, 3]: id 1, weight 3.0

with combiner = "mean", then the output is a 3x16 matrix where

output[0, :] = (vector_for_id_1 * 2.0 + vector_for_id_3 * 0.5) / (2.0 + 0.5)
output[1, :] = (vector_for_id_0 * 1.0) / 1.0
output[2, :] = (vector_for_id_1 * 3.0) / 3.0
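The combiner arithmetic above can be sketched in plain NumPy. The table values here are hypothetical stand-ins for the vectors HPS would return per ID; this illustrates the semantics only, not the HPS implementation:

```python
import numpy as np

# Illustrative embedding table: 4 IDs x 4-dim float32 vectors
# (hypothetical values standing in for what HPS would return).
table = np.arange(16, dtype=np.float32).reshape(4, 4)

# Sparse input mirroring the worked example above:
# row 0 -> (id 1, w 2.0), (id 3, w 0.5); row 1 -> (id 0, w 1.0); row 2 -> (id 1, w 3.0)
rows = [[(1, 2.0), (3, 0.5)], [(0, 1.0)], [(1, 3.0)]]

def combine(rows, table, combiner="mean"):
    out = []
    for row in rows:
        weighted = sum(table[i] * w for i, w in row)
        weights = np.array([w for _, w in row], dtype=np.float32)
        if combiner == "sum":
            out.append(weighted)
        elif combiner == "mean":
            out.append(weighted / weights.sum())
        elif combiner == "sqrtn":
            out.append(weighted / np.sqrt((weights ** 2).sum()))
        else:
            raise ValueError('combiner must be one of "sum", "mean", "sqrtn"')
    return np.stack(out)

dense = combine(rows, table, "mean")
# dense[1] equals table[0] exactly: a single ID with weight 1.0 divided by 1.0.
```

A row with a single ID and weight w yields that ID's vector unchanged under both "mean" (w / w) and "sqrtn" (w / sqrt(w^2)).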

Return type

tf.Tensor of float32

Raises
  • TypeError – If sp_ids is not a SparseTensor, or if sp_weights is neither None nor a SparseTensor.

  • ValueError – If combiner is not one of {"mean", "sqrtn", "sum"}.

LookupLayer

class hierarchical_parameter_server.LookupLayer(*args, **kwargs)[source]

Bases: keras.engine.base_layer.Layer

Abbreviated as hps.LookupLayer(*args, **kwargs).

This is a wrapper class for the HPS lookup layer, which performs essentially the same function as tf.nn.embedding_lookup. Note that ps_config_file and global_batch_size must be specified in the constructor if you want to use implicit HPS initialization.

Parameters
  • model_name (str) – The name of the model that has embedding tables.

  • table_id (int) – The index of the embedding table for the model specified by model_name.

  • emb_vec_size (int) – The embedding vector size for the embedding table specified by model_name and table_id.

  • emb_vec_dtype – The data type of embedding vectors which must be tf.float32.

  • ps_config_file (str) – The JSON configuration file for HPS initialization.

  • global_batch_size (int) – The global batch size for HPS that is deployed on multiple GPUs.

Examples

import hierarchical_parameter_server as hps

lookup_layer = hps.LookupLayer(model_name=args.model_name,
                               table_id=args.table_id,
                               emb_vec_size=args.embed_vec_size,
                               emb_vec_dtype=tf.float32,
                               ps_config_file=args.ps_config_file,
                               global_batch_size=args.global_batch_size)

@tf.function
def _infer_step(inputs):
    embedding_vector = lookup_layer(inputs)
    ...

for i, (inputs, labels) in enumerate(dataset):
    _infer_step(inputs)

call(inputs)[source]

The forward logic of this wrapper class.

Parameters

inputs – Keys stored in a Tensor. The data type must be tf.int64.

Returns

emb_vector – The embedding vectors for the input keys. Its shape is inputs.get_shape() + [emb_vec_size].

Return type

tf.Tensor of float32
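The shape relationship can be illustrated with NumPy fancy indexing standing in for the HPS lookup. The table here is a hypothetical dense array, not an actual HPS table:

```python
import numpy as np

emb_vec_size = 8
# Hypothetical embedding table: 100 keys, each mapped to an 8-dim float32 vector.
table = np.random.rand(100, emb_vec_size).astype(np.float32)

# A batch of int64 keys with shape (2, 2).
keys = np.array([[3, 7], [42, 0]], dtype=np.int64)

# Dense lookup: one vector per key, so the output shape is
# keys.shape + (emb_vec_size,) == (2, 2, 8).
emb_vector = table[keys]
```

Unlike the sparse lookup, no combining happens: every key gets its own vector, and the embedding dimension is simply appended to the input shape.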