#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Text, Union
import torch
from merlin.models.utils.doc_utils import docstring_parameter
from merlin.schema import Tags, TagsType
from merlin_standard_lib import Schema, categorical_cardinalities
from merlin_standard_lib.utils.embedding_utils import get_embedding_sizes_from_schema
from ..block.base import SequentialBlock
from ..tabular.base import (
TABULAR_MODULE_PARAMS_DOCSTRING,
FilterFeatures,
TabularAggregationType,
TabularTransformation,
TabularTransformationType,
)
from ..utils.torch_utils import calculate_batch_size_from_input_size, get_output_sizes_from_schema
from .base import InputBlock
EMBEDDING_FEATURES_PARAMS_DOCSTRING = """
feature_config: Dict[str, FeatureConfig]
This specifies what TableConfig to use for each feature. For shared embeddings, the same
TableConfig can be used for multiple features.
item_id: str, optional
The name of the feature that's used for the item_id.
"""
@docstring_parameter(
tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING,
embedding_features_parameters=EMBEDDING_FEATURES_PARAMS_DOCSTRING,
)
class EmbeddingFeatures(InputBlock):
"""Input block for embedding-lookups for categorical features.
For multi-hot features, the embeddings will be aggregated into a single tensor using the mean.
Parameters
----------
{embedding_features_parameters}
{tabular_module_parameters}
"""
def __init__(
self,
feature_config: Dict[str, "FeatureConfig"],
item_id: Optional[str] = None,
pre: Optional[TabularTransformationType] = None,
post: Optional[TabularTransformationType] = None,
aggregation: Optional[TabularAggregationType] = None,
schema: Optional[Schema] = None,
):
super().__init__(pre=pre, post=post, aggregation=aggregation, schema=schema)
self.item_id = item_id
self.feature_config = feature_config
self.filter_features = FilterFeatures(list(feature_config.keys()))
embedding_tables = {}
features_dim = {}
tables: Dict[str, TableConfig] = {}
for name, feature in self.feature_config.items():
table: TableConfig = feature.table
features_dim[name] = table.dim
if name not in tables:
tables[name] = table
for name, table in tables.items():
embedding_tables[name] = self.table_to_embedding_module(table)
self.embedding_tables = torch.nn.ModuleDict(embedding_tables)
@property
def item_embedding_table(self):
assert self.item_id is not None
return self.embedding_tables[self.item_id]
    def table_to_embedding_module(self, table: "TableConfig") -> torch.nn.Module:
embedding_table = EmbeddingBagWrapper(table.vocabulary_size, table.dim, mode=table.combiner)
if table.initializer is not None:
table.initializer(embedding_table.weight)
return embedding_table
    @classmethod
def from_schema( # type: ignore
cls,
schema: Schema,
embedding_dims: Optional[Dict[str, int]] = None,
embedding_dim_default: int = 64,
infer_embedding_sizes: bool = False,
infer_embedding_sizes_multiplier: float = 2.0,
embeddings_initializers: Optional[Dict[str, Callable[[Any], None]]] = None,
combiner: str = "mean",
tags: Optional[TagsType] = None,
item_id: Optional[str] = None,
automatic_build: bool = True,
max_sequence_length: Optional[int] = None,
aggregation=None,
pre=None,
post=None,
**kwargs,
) -> Optional["EmbeddingFeatures"]:
"""Instantitates ``EmbeddingFeatures`` from a ``DatasetSchema``.
Parameters
----------
schema : DatasetSchema
Dataset schema
embedding_dims : Optional[Dict[str, int]], optional
The dimension of the embedding table for each feature (key),
            by default None
        embedding_dim_default : int, optional
            Default dimension of the embedding table, when the feature is not found
            in ``embedding_dims``, by default 64
infer_embedding_sizes : bool, optional
Automatically defines the embedding dimension from the
feature cardinality in the schema,
by default False
        infer_embedding_sizes_multiplier : float, optional
            Multiplier used by the heuristic to infer the embedding dimension from
            the feature cardinality. Generally reasonable values range between 2.0 and 10.0.
            By default 2.0
        embeddings_initializers : Optional[Dict[str, Callable[[Any], None]]], optional
            Dict where keys are feature names and values are callables to initialize
            the embedding tables, by default None
combiner : Optional[str], optional
Feature aggregation option, by default "mean"
tags : Optional[Union[DefaultTags, list, str]], optional
Tags to filter columns, by default None
item_id : Optional[str], optional
Name of the item id column (feature), by default None
automatic_build : bool, optional
Automatically infers input size from features, by default True
max_sequence_length : Optional[int], optional
            Maximum sequence length for list features, by default None
Returns
-------
Optional[EmbeddingFeatures]
Returns the ``EmbeddingFeatures`` for the dataset schema
"""
# TODO: propagate item-id from ITEM_ID tag
if tags:
schema = schema.select_by_tag(tags)
_item_id = schema.select_by_tag(Tags.ITEM_ID)
if not item_id and len(_item_id) > 0:
if len(_item_id) > 1:
raise ValueError(
"Multiple columns with tag ITEM_ID found. "
"Please specify the item_id column name."
)
item_id = list(_item_id)[0].name
embedding_dims = embedding_dims or {}
if infer_embedding_sizes:
embedding_dims_infered = get_embedding_sizes_from_schema(
schema, infer_embedding_sizes_multiplier
)
embedding_dims = {
**embedding_dims,
**{k: v for k, v in embedding_dims_infered.items() if k not in embedding_dims},
}
embeddings_initializers = embeddings_initializers or {}
emb_config = {}
cardinalities = categorical_cardinalities(schema)
for key, cardinality in cardinalities.items():
embedding_size = embedding_dims.get(key, embedding_dim_default)
embedding_initializer = embeddings_initializers.get(key, None)
emb_config[key] = (cardinality, embedding_size, embedding_initializer)
feature_config: Dict[str, FeatureConfig] = {}
        for name, (vocab_size, dim, emb_initializer) in emb_config.items():
feature_config[name] = FeatureConfig(
TableConfig(
vocabulary_size=vocab_size,
dim=dim,
name=name,
combiner=combiner,
                    initializer=emb_initializer,
)
)
if not feature_config:
return None
output = cls(feature_config, item_id=item_id, pre=pre, post=post, aggregation=aggregation)
if automatic_build and schema:
output.build(
get_output_sizes_from_schema(
schema,
kwargs.get("batch_size", -1),
max_sequence_length=max_sequence_length,
),
schema=schema,
)
return output
    def item_ids(self, inputs) -> torch.Tensor:
return inputs[self.item_id]
    def forward(self, inputs, **kwargs):
embedded_outputs = {}
filtered_inputs = self.filter_features(inputs)
for name, val in filtered_inputs.items():
if isinstance(val, tuple):
values, offsets = val
values = torch.squeeze(values, -1)
                # handle the case where values contains only a single value
if len(values.shape) == 0:
values = values.unsqueeze(0)
embedded_outputs[name] = self.embedding_tables[name](values, offsets[:, 0])
else:
# if len(val.shape) <= 1:
# val = val.unsqueeze(0)
embedded_outputs[name] = self.embedding_tables[name](val)
# Store raw item ids for masking and/or negative sampling
# This makes this module stateful.
if self.item_id:
self.item_seq = self.item_ids(inputs)
embedded_outputs = super().forward(embedded_outputs)
return embedded_outputs
    def forward_output_size(self, input_sizes):
sizes = {}
batch_size = calculate_batch_size_from_input_size(input_sizes)
for name, feature in self.feature_config.items():
sizes[name] = torch.Size([batch_size, feature.table.dim])
return sizes
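# Usage sketch (kept as a comment so importing this module has no side effects):
# a minimal, hypothetical configuration of ``EmbeddingFeatures`` built directly
# from the ``FeatureConfig``/``TableConfig`` classes defined later in this file.
# The feature names, cardinalities and dimensions below are illustrative only.
#
#   feature_config = {
#       "item_id": FeatureConfig(TableConfig(vocabulary_size=1000, dim=64, name="item_id")),
#       "category": FeatureConfig(TableConfig(vocabulary_size=50, dim=16, name="category")),
#   }
#   emb = EmbeddingFeatures(feature_config, item_id="item_id")
#   outputs = emb({"item_id": torch.tensor([1, 2]), "category": torch.tensor([3, 4])})
#   # outputs["item_id"].shape -> torch.Size([2, 64])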
class EmbeddingBagWrapper(torch.nn.EmbeddingBag):
"""
Wrapper class for the PyTorch EmbeddingBag module.
This class extends the torch.nn.EmbeddingBag class and overrides
the forward method to handle 1D tensor inputs
by reshaping them to 2D as required by the EmbeddingBag.
"""
    def forward(self, input, **kwargs):
# EmbeddingBag requires 2D tensors (or offsets)
if len(input.shape) == 1:
input = input.unsqueeze(-1)
return super().forward(input, **kwargs)
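# Illustration (hypothetical sizes): a 1D tensor of ids is reshaped into a batch of
# single-item bags, so each id yields one (mean-combined) embedding row.
#
#   bag = EmbeddingBagWrapper(num_embeddings=10, embedding_dim=4, mode="mean")
#   bag(torch.tensor([1, 2, 3])).shape  # -> torch.Size([3, 4])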
@docstring_parameter(
tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING,
embedding_features_parameters=EMBEDDING_FEATURES_PARAMS_DOCSTRING,
)
class SoftEmbeddingFeatures(EmbeddingFeatures):
"""
Encapsulate continuous features encoded using the Soft-one hot encoding
embedding technique (SoftEmbedding), from https://arxiv.org/pdf/1708.00065.pdf
In a nutshell, it keeps an embedding table for each continuous feature,
which is represented as a weighted average of embeddings.
Parameters
----------
feature_config: Dict[str, FeatureConfig]
This specifies what TableConfig to use for each feature. For shared embeddings, the same
TableConfig can be used for multiple features.
layer_norm: boolean
When layer_norm is true, TabularLayerNorm will be used in post.
{tabular_module_parameters}
"""
def __init__(
self,
feature_config: Dict[str, "FeatureConfig"],
layer_norm: bool = True,
pre: Optional[TabularTransformationType] = None,
post: Optional[TabularTransformationType] = None,
aggregation: Optional[TabularAggregationType] = None,
**kwarg,
):
if layer_norm:
from transformers4rec.torch import TabularLayerNorm
post = TabularLayerNorm.from_feature_config(feature_config)
super().__init__(feature_config, pre=pre, post=post, aggregation=aggregation)
    @classmethod
def from_schema( # type: ignore
cls,
schema: Schema,
soft_embedding_cardinalities: Optional[Dict[str, int]] = None,
soft_embedding_cardinality_default: int = 10,
soft_embedding_dims: Optional[Dict[str, int]] = None,
soft_embedding_dim_default: int = 8,
embeddings_initializers: Optional[Dict[str, Callable[[Any], None]]] = None,
layer_norm: bool = True,
combiner: str = "mean",
tags: Optional[TagsType] = None,
automatic_build: bool = True,
max_sequence_length: Optional[int] = None,
**kwargs,
) -> Optional["SoftEmbeddingFeatures"]:
"""
        Instantiates ``SoftEmbeddingFeatures`` from a ``DatasetSchema``.
Parameters
----------
schema : DatasetSchema
Dataset schema
soft_embedding_cardinalities : Optional[Dict[str, int]], optional
The cardinality of the embedding table for each feature (key),
by default None
soft_embedding_cardinality_default : Optional[int], optional
Default cardinality of the embedding table, when the feature
is not found in ``soft_embedding_cardinalities``, by default 10
soft_embedding_dims : Optional[Dict[str, int]], optional
The dimension of the embedding table for each feature (key), by default None
soft_embedding_dim_default : Optional[int], optional
Default dimension of the embedding table, when the feature
            is not found in ``soft_embedding_dims``, by default 8
        embeddings_initializers : Optional[Dict[str, Callable[[Any], None]]], optional
            Dict where keys are feature names and values are callables to initialize
            the embedding tables, by default None
combiner : Optional[str], optional
Feature aggregation option, by default "mean"
tags : Optional[Union[DefaultTags, list, str]], optional
Tags to filter columns, by default None
automatic_build : bool, optional
Automatically infers input size from features, by default True
max_sequence_length : Optional[int], optional
Maximum sequence length for list features, by default None
Returns
-------
Optional[SoftEmbeddingFeatures]
Returns a ``SoftEmbeddingFeatures`` instance from the dataset schema
"""
# TODO: propagate item-id from ITEM_ID tag
if tags:
schema = schema.select_by_tag(tags)
soft_embedding_cardinalities = soft_embedding_cardinalities or {}
soft_embedding_dims = soft_embedding_dims or {}
embeddings_initializers = embeddings_initializers or {}
sizes = {}
cardinalities = categorical_cardinalities(schema)
for col_name in schema.column_names:
# If this is NOT a categorical feature
if col_name not in cardinalities:
embedding_size = soft_embedding_dims.get(col_name, soft_embedding_dim_default)
cardinality = soft_embedding_cardinalities.get(
col_name, soft_embedding_cardinality_default
)
emb_initializer = embeddings_initializers.get(col_name, None)
sizes[col_name] = (cardinality, embedding_size, emb_initializer)
feature_config: Dict[str, FeatureConfig] = {}
for name, (vocab_size, dim, emb_initializer) in sizes.items():
feature_config[name] = FeatureConfig(
TableConfig(
vocabulary_size=vocab_size,
dim=dim,
name=name,
combiner=combiner,
initializer=emb_initializer,
)
)
if not feature_config:
return None
output = cls(feature_config, layer_norm=layer_norm, **kwargs)
if automatic_build and schema:
output.build(
get_output_sizes_from_schema(
schema,
kwargs.get("batch_size", -1),
max_sequence_length=max_sequence_length,
)
)
return output
    def table_to_embedding_module(self, table: "TableConfig") -> "SoftEmbedding":
return SoftEmbedding(table.vocabulary_size, table.dim, table.initializer)
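# Usage sketch (comment-only, illustrative names and sizes): encoding two continuous
# features with SoftEmbeddingFeatures built directly from FeatureConfig objects,
# with layer_norm disabled for brevity.
#
#   soft_config = {
#       "price": FeatureConfig(TableConfig(vocabulary_size=10, dim=8, name="price")),
#       "age_days": FeatureConfig(TableConfig(vocabulary_size=10, dim=8, name="age_days")),
#   }
#   soft = SoftEmbeddingFeatures(soft_config, layer_norm=False)
#   out = soft({"price": torch.rand(2), "age_days": torch.rand(2)})
#   # out["price"].shape -> torch.Size([2, 8])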
class TableConfig:
"""
    Class to configure the embedding lookup table for a categorical feature.
Attributes
----------
vocabulary_size : int
The size of the vocabulary,
i.e., the cardinality of the categorical feature.
dim : int
The dimensionality of the embedding vectors.
initializer : Optional[Callable[[torch.Tensor], None]]
The initializer function for the embedding weights.
If None, the weights are initialized using a normal
distribution with mean 0.0 and standard deviation 0.05.
combiner : Optional[str]
The combiner operation used to aggregate bag of embeddings.
Possible options are "mean", "sum", and "sqrtn".
By default "mean".
name : Optional[str]
The name of the lookup table.
By default None.
"""
def __init__(
self,
vocabulary_size: int,
dim: int,
initializer: Optional[Callable[[torch.Tensor], None]] = None,
combiner: Text = "mean",
name: Optional[Text] = None,
):
if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
raise ValueError("Invalid vocabulary_size {}.".format(vocabulary_size))
if not isinstance(dim, int) or dim < 1:
raise ValueError("Invalid dim {}.".format(dim))
if combiner not in ("mean", "sum", "sqrtn"):
raise ValueError("Invalid combiner {}".format(combiner))
if (initializer is not None) and (not callable(initializer)):
raise ValueError("initializer must be callable if specified.")
self.initializer: Callable[[torch.Tensor], None]
if initializer is None:
self.initializer = partial(torch.nn.init.normal_, mean=0.0, std=0.05) # type: ignore
else:
self.initializer = initializer
self.vocabulary_size = vocabulary_size
self.dim = dim
self.combiner = combiner
self.name = name
def __repr__(self):
return (
"TableConfig(vocabulary_size={vocabulary_size!r}, dim={dim!r}, "
"combiner={combiner!r}, name={name!r})".format(
vocabulary_size=self.vocabulary_size,
dim=self.dim,
combiner=self.combiner,
name=self.name,
)
)
class FeatureConfig:
"""
    Class to configure the embedding table of a categorical feature,
    together with an optional maximum sequence length.
Attributes
----------
table : TableConfig
Configuration for the lookup table,
which is used for embedding lookup and aggregation.
max_sequence_length : int, optional
Maximum sequence length for sequence features.
By default 0.
name : str, optional
The feature name.
By default None
"""
def __init__(
self, table: TableConfig, max_sequence_length: int = 0, name: Optional[Text] = None
):
self.table = table
self.max_sequence_length = max_sequence_length
self.name = name
def __repr__(self):
return (
"FeatureConfig(table={table!r}, "
"max_sequence_length={max_sequence_length!r}, name={name!r})".format(
table=self.table, max_sequence_length=self.max_sequence_length, name=self.name
)
)
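# Construction sketch (hypothetical values): TableConfig validates its arguments and
# falls back to a normal(mean=0.0, std=0.05) initializer when none is given, while
# FeatureConfig simply bundles the table with an optional max_sequence_length and name.
#
#   table = TableConfig(vocabulary_size=1000, dim=32, combiner="mean", name="item_id")
#   feature = FeatureConfig(table, max_sequence_length=20, name="item_id_seq")
#   # TableConfig(vocabulary_size=0, dim=32) would raise ValueError("Invalid vocabulary_size 0.")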
class SoftEmbedding(torch.nn.Module):
"""
Soft-one hot encoding embedding technique, from https://arxiv.org/pdf/1708.00065.pdf
In a nutshell, it represents a continuous feature as a weighted average of embeddings
"""
def __init__(self, num_embeddings, embeddings_dim, emb_initializer=None):
"""
Parameters
----------
num_embeddings: Number of embeddings to use (cardinality of the embedding table).
embeddings_dim: The dimension of the vector space for projecting the scalar value.
        emb_initializer: Optional callable used to initialize the embedding table weights.
"""
assert (
num_embeddings > 0
), "The number of embeddings for soft embeddings needs to be greater than 0"
assert (
embeddings_dim > 0
), "The embeddings dim for soft embeddings needs to be greater than 0"
super(SoftEmbedding, self).__init__()
self.embedding_table = torch.nn.Embedding(num_embeddings, embeddings_dim)
if emb_initializer:
emb_initializer(self.embedding_table.weight)
self.projection_layer = torch.nn.Linear(1, num_embeddings, bias=True)
self.softmax = torch.nn.Softmax(dim=-1)
    def forward(self, input_numeric):
input_numeric = input_numeric.unsqueeze(-1)
weights = self.softmax(self.projection_layer(input_numeric))
soft_one_hot_embeddings = (weights.unsqueeze(-1) * self.embedding_table.weight).sum(-2)
return soft_one_hot_embeddings
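# Shape sketch of the soft one-hot mechanism (hypothetical sizes): each scalar is
# projected to ``num_embeddings`` logits, softmaxed into weights, and used to take
# a weighted average of the embedding table rows.
#
#   soft_emb = SoftEmbedding(num_embeddings=10, embeddings_dim=8)
#   x = torch.rand(4, 20)   # e.g. (batch, sequence length) of continuous values
#   soft_emb(x).shape       # -> torch.Size([4, 20, 8])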
class PretrainedEmbeddingsInitializer(torch.nn.Module):
"""
Initializer of embedding tables with pre-trained weights
Parameters
----------
weight_matrix : Union[torch.Tensor, List[List[float]]]
A 2D torch or numpy tensor or lists of lists with the pre-trained
        weights for embeddings. The expected dims are
(embedding_cardinality, embedding_dim). The embedding_cardinality
can be inferred from the column schema, for example,
`schema.select_by_name("item_id").feature[0].int_domain.max + 1`.
The first position of the embedding table is reserved for padded
items (id=0).
trainable : bool
Whether the embedding table should be trainable or not
"""
def __init__(
self,
weight_matrix: Union[torch.Tensor, List[List[float]]],
trainable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
# The weight matrix is kept in CPU, but when forward() is called
# to initialize the embedding table weight will be copied to
# the embedding table device (e.g. cuda)
self.weight_matrix = torch.tensor(weight_matrix, device="cpu")
self.trainable = trainable
    def forward(self, x):
with torch.no_grad():
x.copy_(self.weight_matrix)
x.requires_grad = self.trainable
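# Usage sketch (comment-only, stand-in weights): plugging pre-trained vectors into an
# embedding table through TableConfig's ``initializer`` hook. The weight matrix shape
# must be (vocabulary_size, dim), with row 0 reserved for padding.
#
#   pretrained = torch.rand(1000, 64)  # stand-in for real pre-trained vectors
#   init = PretrainedEmbeddingsInitializer(pretrained, trainable=False)
#   table = TableConfig(vocabulary_size=1000, dim=64, initializer=init, name="item_id")
#   emb = EmbeddingFeatures({"item_id": FeatureConfig(table)}, item_id="item_id")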
@docstring_parameter(
tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING,
)
class PretrainedEmbeddingFeatures(InputBlock):
"""Input block for pre-trained embeddings features.
For 3-D features, if sequence_combiner is set, the features are aggregated
using the second dimension (sequence length)
Parameters
----------
features: List[str]
A list of the pre-trained embeddings feature names.
        You will typically pass ``schema.select_by_tag(Tags.EMBEDDING).column_names``,
        as that is the tag added to pre-trained embedding features when using
        ``merlin.dataloader.ops.embeddings.EmbeddingOperator``.
pretrained_output_dims: Optional[Union[int, Dict[str, int]]]
If provided, it projects features to specified dim(s).
If an int, all features are projected to that dim.
        If a dict, only the features listed in the dict are projected, each to its specified dim.
sequence_combiner: Optional[Union[str, torch.nn.Module]], optional
A string ("mean", "sum", "max", "min") or torch.nn.Module specifying
how to combine the second dimension of the pre-trained embeddings if it is 3D.
Default is None (no sequence combiner used)
normalizer: Optional[Union[str, TabularTransformationType]]
        A tabular layer (e.g. ``tr.TabularLayerNorm()``) or string ("layer-norm") to be applied
        to the pre-trained embeddings after they are projected and sequence-combined.
        Default is None (no normalization)
    schema : Optional[Schema]
        The schema of the input data.
{tabular_module_parameters}
"""
def __init__(
self,
features: List[str],
pretrained_output_dims: Optional[Union[int, Dict[str, int]]] = None,
sequence_combiner: Optional[Union[str, torch.nn.Module]] = None,
pre: Optional[TabularTransformationType] = None,
post: Optional[TabularTransformationType] = None,
aggregation: Optional[TabularAggregationType] = None,
normalizer: Optional[TabularTransformationType] = None,
schema: Optional[Schema] = None,
):
if isinstance(normalizer, str):
normalizer = TabularTransformation.parse(normalizer)
if not post:
post = normalizer
elif normalizer:
post = SequentialBlock(normalizer, post) # type: ignore
super().__init__(pre=pre, post=post, aggregation=aggregation, schema=schema)
self.features = features
self.filter_features = FilterFeatures(features)
self.pretrained_output_dims = pretrained_output_dims
self.sequence_combiner = self.parse_combiner(sequence_combiner)
    def build(self, input_size, **kwargs):
if input_size is not None:
if self.pretrained_output_dims:
self.projection = torch.nn.ModuleDict()
if isinstance(self.pretrained_output_dims, int):
for key in self.features:
self.projection[key] = torch.nn.Linear(
input_size[key][-1], self.pretrained_output_dims
)
elif isinstance(self.pretrained_output_dims, dict):
for key in self.features:
self.projection[key] = torch.nn.Linear(
input_size[key][-1], self.pretrained_output_dims[key]
)
return super().build(input_size, **kwargs)
    @classmethod
def from_schema(
cls,
schema: Schema,
tags: Optional[TagsType] = None,
pretrained_output_dims=None,
sequence_combiner=None,
normalizer: Optional[Union[str, TabularTransformationType]] = None,
pre: Optional[TabularTransformationType] = None,
post: Optional[TabularTransformationType] = None,
aggregation: Optional[TabularAggregationType] = None,
**kwargs,
): # type: ignore
if tags:
schema = schema.select_by_tag(tags)
features = schema.column_names
return cls(
features=features,
pretrained_output_dims=pretrained_output_dims,
sequence_combiner=sequence_combiner,
pre=pre,
post=post,
aggregation=aggregation,
normalizer=normalizer,
)
    def forward(self, inputs):
output = self.filter_features(inputs)
if self.pretrained_output_dims:
output = {key: self.projection[key](val) for key, val in output.items()}
if self.sequence_combiner:
for key, val in output.items():
if val.dim() > 2:
output[key] = self.sequence_combiner(val, axis=1)
return output
    def forward_output_size(self, input_sizes):
sizes = self.filter_features.forward_output_size(input_sizes)
if self.pretrained_output_dims:
if isinstance(self.pretrained_output_dims, dict):
sizes.update(
{
key: torch.Size(list(sizes[key][:-1]) + [self.pretrained_output_dims[key]])
for key in self.features
}
)
else:
sizes.update(
{
key: torch.Size(list(sizes[key][:-1]) + [self.pretrained_output_dims])
for key in self.features
}
)
return sizes
    def parse_combiner(self, combiner):
if isinstance(combiner, str):
if combiner == "sum":
combiner = torch.sum
elif combiner == "max":
combiner = torch.max
elif combiner == "min":
combiner = torch.min
elif combiner == "mean":
combiner = torch.mean
return combiner
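# Usage sketch (comment-only, hypothetical feature name and sizes): wiring pre-trained
# embedding features with a projection to 32 dims and a mean sequence combiner.
# ``build`` is called with the per-feature input sizes so the projection layers exist
# before the first forward pass.
#
#   block = PretrainedEmbeddingFeatures(
#       features=["item_embedding"],
#       pretrained_output_dims=32,
#       sequence_combiner="mean",
#   )
#   block.build({"item_embedding": torch.Size([-1, 20, 128])})
#   out = block({"item_embedding": torch.rand(2, 20, 128)})
#   # out["item_embedding"].shape -> torch.Size([2, 32])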