HPS Plugin for Torch
LookupLayer class
This is a wrapper class for the HPS lookup layer, which performs the same function as torch.nn.Embedding. It inherits from torch.nn.Module.
hps_torch.LookupLayer.__init__
Arguments
ps_config_file: String. The JSON configuration file for HPS initialization.
model_name: String. The name of the model that has embedding tables.
table_id: Integer. The index of the embedding table for the model specified by model_name.
emb_vec_size: Integer. The embedding vector size for the embedding table specified by model_name and table_id.
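Because emb_vec_size must agree with the table that model_name and table_id select in ps_config_file, it can help to derive the expected value from the configuration before constructing the layer. The following is a minimal sketch, assuming the config follows the HPS JSON layout with a "models" list whose entries carry "model" and "embedding_vecsize_per_table" fields; the example config and the helper expected_emb_vec_size are hypothetical, so check the field names against the HPS configuration documentation.

```python
import json

# Hypothetical HPS-style config. The field names ("models", "model",
# "embedding_table_names", "embedding_vecsize_per_table") are assumptions
# modelled on the HPS JSON layout, not taken from this page.
EXAMPLE_CONFIG = """
{
  "models": [
    {
      "model": "demo_model",
      "embedding_table_names": ["sparse_embedding0", "sparse_embedding1"],
      "embedding_vecsize_per_table": [16, 32]
    }
  ]
}
"""


def expected_emb_vec_size(config_text: str, model_name: str, table_id: int) -> int:
    """Return the vector size that emb_vec_size should match for (model_name, table_id)."""
    config = json.loads(config_text)
    for model in config["models"]:
        if model["model"] == model_name:
            sizes = model["embedding_vecsize_per_table"]
            # table_id is an index into the model's list of embedding tables.
            if not 0 <= table_id < len(sizes):
                raise IndexError(f"table_id {table_id} out of range for model {model_name!r}")
            return sizes[table_id]
    raise KeyError(f"model {model_name!r} not found in config")


if __name__ == "__main__":
    # For a real file, read it first: open(path).read() or json.load(f).
    print(expected_emb_vec_size(EXAMPLE_CONFIG, "demo_model", 1))
```

Passing a mismatched emb_vec_size would mean the layer expects vectors of a different length than the table stores, so a check like this is a cheap guard before model construction.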
hps_torch.LookupLayer.forward
Arguments
keys: Tensor of torch.int32 or torch.int64.
Returns
vectors: Tensor of torch.float32.
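The forward contract can be illustrated with a pure-Python sketch. This is not the hps_torch implementation: nested lists stand in for tensors, a dict stands in for the embedding table, dtypes (int32/int64 keys, float32 output) are not modelled, and mapping unknown keys to an all-zero vector is an assumption of this sketch. Assuming LookupLayer follows the torch.nn.Embedding convention, each integer key is replaced by its embedding vector, so the output shape is keys.shape + [emb_vec_size].

```python
from typing import Dict, List, Sequence, Union

# Nested lists of int stand in for an int32/int64 keys tensor of any rank.
Keys = Union[int, Sequence["Keys"]]


def lookup(table: Dict[int, List[float]], keys: Keys, emb_vec_size: int):
    """Replace every integer key with a float vector of length emb_vec_size,
    so the result has shape keys.shape + [emb_vec_size]."""
    if isinstance(keys, int):
        # Assumption of this sketch: unknown keys map to an all-zero vector.
        vec = table.get(keys, [0.0] * emb_vec_size)
        if len(vec) != emb_vec_size:
            raise ValueError(f"vector for key {keys} has size {len(vec)}, expected {emb_vec_size}")
        return [float(v) for v in vec]
    # Recurse over the leading dimension, preserving the keys' structure.
    return [lookup(table, k, emb_vec_size) for k in keys]


def shape(nested) -> List[int]:
    """Return the shape of a rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return dims


if __name__ == "__main__":
    table = {1: [0.1, 0.2, 0.3], 7: [1.0, 2.0, 3.0]}
    out = lookup(table, [[1, 7], [7, 42]], emb_vec_size=3)
    print(shape(out))  # a [2, 2] batch of keys becomes [2, 2, 3]
```

A [batch, slots] tensor of keys therefore yields a [batch, slots, emb_vec_size] tensor of vectors, which is the same shape rule torch.nn.Embedding applies to its index input.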