Source code for transformers4rec.torch.features.embedding

#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from functools import partial
from typing import Any, Callable, Dict, Optional, Text, Union

import torch

from merlin_standard_lib import Schema, Tag
from merlin_standard_lib.utils.doc_utils import docstring_parameter
from merlin_standard_lib.utils.embedding_utils import get_embedding_sizes_from_schema

from ..tabular.base import (
    TABULAR_MODULE_PARAMS_DOCSTRING,
    FilterFeatures,
    TabularAggregationType,
    TabularTransformationType,
)
from ..utils.torch_utils import calculate_batch_size_from_input_size, get_output_sizes_from_schema
from .base import InputBlock

# Shared "Parameters" section that ``docstring_parameter`` injects into the
# ``EmbeddingFeatures`` class docstring.
EMBEDDING_FEATURES_PARAMS_DOCSTRING = """
    feature_config: Dict[str, FeatureConfig]
        Maps each feature name to the FeatureConfig whose TableConfig describes its embedding
        table. Each feature gets its own embedding module, even when several features
        reference the same TableConfig object (weights are not shared).
    item_id: str, optional
        The name of the feature that's used for the item_id.
"""


@docstring_parameter(
    tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING,
    embedding_features_parameters=EMBEDDING_FEATURES_PARAMS_DOCSTRING,
)
class EmbeddingFeatures(InputBlock):
    """Input block for embedding-lookups for categorical features.

    For multi-hot features, the embeddings will be aggregated into a single tensor using
    each table's ``combiner`` (``mean`` by default).

    Parameters
    ----------
    {embedding_features_parameters}
    {tabular_module_parameters}
    """

    def __init__(
        self,
        feature_config: Dict[str, "FeatureConfig"],
        item_id: Optional[str] = None,
        pre: Optional[TabularTransformationType] = None,
        post: Optional[TabularTransformationType] = None,
        aggregation: Optional[TabularAggregationType] = None,
        schema: Optional[Schema] = None,
    ):
        super().__init__(pre=pre, post=post, aggregation=aggregation, schema=schema)
        self.item_id = item_id
        self.feature_config = feature_config
        # Only the configured features are embedded; other inputs are dropped in forward.
        self.filter_features = FilterFeatures(list(feature_config.keys()))

        # One embedding module per feature name. The ModuleDict keys must be the
        # feature names because ``forward`` looks tables up by input key.
        self.embedding_tables = torch.nn.ModuleDict(
            {
                name: self.table_to_embedding_module(feature.table)
                for name, feature in self.feature_config.items()
            }
        )

    @property
    def item_embedding_table(self):
        """Embedding module of the ``item_id`` feature (requires ``item_id`` to be set)."""
        assert self.item_id is not None
        return self.embedding_tables[self.item_id]

    def table_to_embedding_module(self, table: "TableConfig") -> torch.nn.Module:
        """Build the embedding module for ``table`` and apply its initializer.

        ``table.combiner`` is used as the ``torch.nn.EmbeddingBag`` mode, which only
        supports "sum", "mean" and "max"; other combiners are rejected by torch.
        """
        embedding_table = EmbeddingBagWrapper(table.vocabulary_size, table.dim, mode=table.combiner)
        if table.initializer is not None:
            table.initializer(embedding_table.weight)
        return embedding_table

    @classmethod
    def from_schema(  # type: ignore
        cls,
        schema: Schema,
        embedding_dims: Optional[Dict[str, int]] = None,
        embedding_dim_default: int = 64,
        infer_embedding_sizes: bool = False,
        infer_embedding_sizes_multiplier: float = 2.0,
        embeddings_initializers: Optional[Dict[str, Callable[[Any], None]]] = None,
        combiner: str = "mean",
        tags: Optional[Union[Tag, list, str]] = None,
        item_id: Optional[str] = None,
        automatic_build: bool = True,
        max_sequence_length: Optional[int] = None,
        aggregation=None,
        pre=None,
        post=None,
        **kwargs,
    ) -> Optional["EmbeddingFeatures"]:
        """Instantiates ``EmbeddingFeatures`` from a ``DatasetSchema``.

        Parameters
        ----------
        schema : DatasetSchema
            Dataset schema
        embedding_dims : Optional[Dict[str, int]], optional
            The dimension of the embedding table for each feature (key), by default None
        embedding_dim_default : int, optional
            Default dimension of the embedding table, when the feature is not found in
            ``embedding_dims``, by default 64
        infer_embedding_sizes : bool, optional
            Automatically defines the embedding dimension from the feature cardinality in
            the schema, by default False. Entries given explicitly in ``embedding_dims``
            take precedence over inferred sizes.
        infer_embedding_sizes_multiplier : float, optional
            Multiplier used by the heuristic to infer the embedding dimension from its
            cardinality, by default 2.0. Generally reasonable values range between 2.0
            and 10.0
        embeddings_initializers : Optional[Dict[str, Callable[[Any], None]]]
            Dict where keys are feature names and values are callables used to initialize
            the embedding tables
        combiner : str, optional
            Feature aggregation option, by default "mean"
        tags : Optional[Union[Tag, list, str]], optional
            Tags to filter columns, by default None
        item_id : Optional[str], optional
            Name of the item id column (feature). When not given, the first column
            tagged "item_id" in the schema is used, if any. By default None
        automatic_build : bool, optional
            Automatically infers input size from features, by default True
        max_sequence_length : Optional[int], optional
            Maximum sequence length for list features, by default None
        aggregation, pre, post : optional
            Forwarded to the ``EmbeddingFeatures`` constructor
        **kwargs
            Only ``batch_size`` is read (used for the automatic build), by default -1

        Returns
        -------
        Optional[EmbeddingFeatures]
            Returns the ``EmbeddingFeatures`` for the dataset schema, or None when the
            schema has no categorical cardinalities
        """
        # TODO: propagate item-id from ITEM_ID tag
        if tags:
            schema = schema.select_by_tag(tags)

        if not item_id:
            item_id_columns = schema.select_by_tag(["item_id"]).column_names
            if item_id_columns:
                item_id = item_id_columns[0]

        embedding_dims = embedding_dims or {}
        if infer_embedding_sizes:
            inferred_dims = get_embedding_sizes_from_schema(
                schema, infer_embedding_sizes_multiplier
            )
            # Explicitly provided dimensions win over the inferred ones.
            embedding_dims = {**inferred_dims, **embedding_dims}

        embeddings_initializers = embeddings_initializers or {}

        feature_config: Dict[str, FeatureConfig] = {}
        for name, cardinality in schema.categorical_cardinalities().items():
            feature_config[name] = FeatureConfig(
                TableConfig(
                    vocabulary_size=cardinality,
                    dim=embedding_dims.get(name, embedding_dim_default),
                    name=name,
                    combiner=combiner,
                    initializer=embeddings_initializers.get(name, None),
                )
            )

        if not feature_config:
            return None

        output = cls(feature_config, item_id=item_id, pre=pre, post=post, aggregation=aggregation)

        if automatic_build and schema:
            output.build(
                get_output_sizes_from_schema(
                    schema,
                    kwargs.get("batch_size", -1),
                    max_sequence_length=max_sequence_length,
                ),
                schema=schema,
            )

        return output

    def item_ids(self, inputs) -> torch.Tensor:
        """Return the raw ``item_id`` column from ``inputs``."""
        return inputs[self.item_id]

    def forward(self, inputs, **kwargs):
        """Embed every configured feature in ``inputs``.

        Tuple values are treated as a ``(values, offsets)`` pair; any other value is fed
        directly to the feature's embedding module. The resulting dict of embeddings is
        then passed to the parent ``forward`` (pre/post/aggregation handling).
        """
        embedded_outputs = {}
        filtered_inputs = self.filter_features(inputs)
        for name, val in filtered_inputs.items():
            if isinstance(val, tuple):
                values, offsets = val
                values = torch.squeeze(values, -1)
                # for the case where only one value in values
                if len(values.shape) == 0:
                    values = values.unsqueeze(0)
                # NOTE(review): assumes offsets is 2-D with the offsets in column 0 — confirm.
                embedded_outputs[name] = self.embedding_tables[name](values, offsets[:, 0])
            else:
                embedded_outputs[name] = self.embedding_tables[name](val)

        # Store raw item ids for masking and/or negative sampling
        # This makes this module stateful.
        if self.item_id:
            self.item_seq = self.item_ids(inputs)

        embedded_outputs = super().forward(embedded_outputs)

        return embedded_outputs

    def forward_output_size(self, input_sizes):
        """Output size per feature: ``(batch_size, table.dim)``."""
        batch_size = calculate_batch_size_from_input_size(input_sizes)
        return {
            name: torch.Size([batch_size, feature.table.dim])
            for name, feature in self.feature_config.items()
        }
class EmbeddingBagWrapper(torch.nn.EmbeddingBag):
    """``torch.nn.EmbeddingBag`` that also accepts a 1-D id tensor without offsets.

    When ``offsets`` is not given, a 1-D input of ``n`` ids is reshaped to ``(n, 1)``,
    so every id becomes its own bag and the output has shape ``(n, embedding_dim)``.
    When ``offsets`` is given, the input is forwarded untouched. This is the
    ``(values, offsets)`` bag layout expected by ``torch.nn.EmbeddingBag``.
    """

    def forward(self, input, offsets=None, **kwargs):
        # EmbeddingBag requires either a 2-D input (one bag per row) or a 1-D input
        # together with offsets; a 2-D input combined with offsets is rejected.
        if offsets is None and len(input.shape) == 1:
            input = input.unsqueeze(-1)
        return super().forward(input, offsets, **kwargs)
@docstring_parameter(
    tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING,
    embedding_features_parameters=EMBEDDING_FEATURES_PARAMS_DOCSTRING,
)
class SoftEmbeddingFeatures(EmbeddingFeatures):
    """
    Encapsulate continuous features encoded using the Soft-one hot encoding
    embedding technique (SoftEmbedding), from https://arxiv.org/pdf/1708.00065.pdf
    In a nutshell, it keeps an embedding table for each continuous feature,
    which is represented as a weighted average of embeddings.

    Parameters
    ----------
    feature_config: Dict[str, FeatureConfig]
        Maps each feature name to the FeatureConfig whose TableConfig describes its
        soft-embedding table: ``vocabulary_size`` is the number of embeddings and ``dim``
        the embedding dimension. Each feature gets its own SoftEmbedding module.
    layer_norm: boolean
        When layer_norm is true, TabularLayerNorm will be used in post. This replaces
        any ``post`` argument that was passed.
    {tabular_module_parameters}
    """

    def __init__(
        self,
        feature_config: Dict[str, "FeatureConfig"],
        layer_norm: bool = True,
        pre: Optional[TabularTransformationType] = None,
        post: Optional[TabularTransformationType] = None,
        aggregation: Optional[TabularAggregationType] = None,
        **kwargs,
    ):
        # NOTE(review): with layer_norm=True the caller-supplied ``post`` is discarded
        # rather than combined with the layer norm. Confirm this is intended.
        if layer_norm:
            from transformers4rec.torch import TabularLayerNorm

            post = TabularLayerNorm.from_feature_config(feature_config)

        # Any extra keyword arguments (e.g. ``batch_size`` forwarded by ``from_schema``)
        # are accepted and ignored.
        super().__init__(feature_config, pre=pre, post=post, aggregation=aggregation)

    @classmethod
    def from_schema(  # type: ignore
        cls,
        schema: Schema,
        soft_embedding_cardinalities: Optional[Dict[str, int]] = None,
        soft_embedding_cardinality_default: int = 10,
        soft_embedding_dims: Optional[Dict[str, int]] = None,
        soft_embedding_dim_default: int = 8,
        embeddings_initializers: Optional[Dict[str, Callable[[Any], None]]] = None,
        layer_norm: bool = True,
        combiner: str = "mean",
        tags: Optional[Union[Tag, list, str]] = None,
        automatic_build: bool = True,
        max_sequence_length: Optional[int] = None,
        **kwargs,
    ) -> Optional["SoftEmbeddingFeatures"]:
        """
        Instantiates ``SoftEmbeddingFeatures`` from a ``DatasetSchema``.

        Parameters
        ----------
        schema : DatasetSchema
            Dataset schema
        soft_embedding_cardinalities : Optional[Dict[str, int]], optional
            The cardinality of the embedding table for each feature (key),
            by default None
        soft_embedding_cardinality_default : Optional[int], optional
            Default cardinality of the embedding table, when the feature
            is not found in ``soft_embedding_cardinalities``, by default 10
        soft_embedding_dims : Optional[Dict[str, int]], optional
            The dimension of the embedding table for each feature (key), by default None
        soft_embedding_dim_default : Optional[int], optional
            Default dimension of the embedding table, when the feature
            is not found in ``soft_embedding_dims``, by default 8
        embeddings_initializers : Optional[Dict[str, Callable[[Any], None]]]
            Dict where keys are feature names and values are callables used to initialize
            the embedding tables
        layer_norm : bool, optional
            Whether to use ``TabularLayerNorm`` as the post transformation, by default True
        combiner : Optional[str], optional
            Feature aggregation option, by default "mean"
        tags : Optional[Union[Tag, list, str]], optional
            Tags to filter columns, by default None
        automatic_build : bool, optional
            Automatically infers input size from features, by default True
        max_sequence_length : Optional[int], optional
            Maximum sequence length for list features, by default None
        **kwargs
            Forwarded to the constructor; ``batch_size`` (default -1) is also used
            for the automatic build

        Returns
        -------
        Optional[SoftEmbeddingFeatures]
            Returns a ``SoftEmbeddingFeatures`` instance from the dataset schema, or None
            when the schema has no non-categorical columns
        """
        # TODO: propagate item-id from ITEM_ID tag
        if tags:
            schema = schema.select_by_tag(tags)

        soft_embedding_cardinalities = soft_embedding_cardinalities or {}
        soft_embedding_dims = soft_embedding_dims or {}
        embeddings_initializers = embeddings_initializers or {}

        # Soft embeddings are only built for columns that are NOT categorical.
        cardinalities = schema.categorical_cardinalities()
        feature_config: Dict[str, FeatureConfig] = {}
        for col_name in schema.column_names:
            if col_name in cardinalities:
                continue
            feature_config[col_name] = FeatureConfig(
                TableConfig(
                    vocabulary_size=soft_embedding_cardinalities.get(
                        col_name, soft_embedding_cardinality_default
                    ),
                    dim=soft_embedding_dims.get(col_name, soft_embedding_dim_default),
                    name=col_name,
                    combiner=combiner,
                    initializer=embeddings_initializers.get(col_name, None),
                )
            )

        if not feature_config:
            return None

        output = cls(feature_config, layer_norm=layer_norm, **kwargs)

        if automatic_build and schema:
            output.build(
                get_output_sizes_from_schema(
                    schema,
                    kwargs.get("batch_size", -1),
                    max_sequence_length=max_sequence_length,
                )
            )

        return output

    def table_to_embedding_module(self, table: "TableConfig") -> "SoftEmbedding":
        """Build a ``SoftEmbedding`` for ``table``; ``table.combiner`` is not used here."""
        return SoftEmbedding(table.vocabulary_size, table.dim, table.initializer)
class TableConfig:
    """Configuration of a single embedding table.

    Parameters
    ----------
    vocabulary_size: int
        Number of rows of the table. Must be a positive int (bools are rejected).
    dim: int
        Embedding dimension. Must be a positive int (bools are rejected).
    initializer: Callable[[torch.Tensor], None], optional
        Called with the table's weight tensor to initialize it. Defaults to
        ``torch.nn.init.normal_`` with mean 0.0 and std 0.05.
    combiner: str
        How multi-valued inputs are reduced: "mean" (default), "sum", "max" or "sqrtn".
        ``torch.nn.EmbeddingBag`` only implements "sum", "mean" and "max".
    name: str, optional
        Optional name of the table.

    Raises
    ------
    ValueError
        If any of the arguments is invalid.
    """

    # "max" is a native torch.nn.EmbeddingBag mode; "sqrtn" is kept for
    # backward compatibility.
    _VALID_COMBINERS = ("mean", "sum", "max", "sqrtn")

    def __init__(
        self,
        vocabulary_size: int,
        dim: int,
        initializer: Optional[Callable[[torch.Tensor], None]] = None,
        combiner: Text = "mean",
        name: Optional[Text] = None,
    ):
        if not self._is_positive_int(vocabulary_size):
            raise ValueError("Invalid vocabulary_size {}.".format(vocabulary_size))
        if not self._is_positive_int(dim):
            raise ValueError("Invalid dim {}.".format(dim))
        if combiner not in self._VALID_COMBINERS:
            raise ValueError("Invalid combiner {}".format(combiner))
        if (initializer is not None) and (not callable(initializer)):
            raise ValueError("initializer must be callable if specified.")

        self.initializer: Callable[[torch.Tensor], None]
        if initializer is None:
            self.initializer = partial(torch.nn.init.normal_, mean=0.0, std=0.05)  # type: ignore
        else:
            self.initializer = initializer

        self.vocabulary_size = vocabulary_size
        self.dim = dim
        self.combiner = combiner
        self.name = name

    @staticmethod
    def _is_positive_int(value) -> bool:
        # bool is a subclass of int, so it has to be excluded explicitly.
        return isinstance(value, int) and not isinstance(value, bool) and value >= 1

    def __repr__(self):
        return (
            "TableConfig(vocabulary_size={vocabulary_size!r}, dim={dim!r}, "
            "combiner={combiner!r}, name={name!r})".format(
                vocabulary_size=self.vocabulary_size,
                dim=self.dim,
                combiner=self.combiner,
                name=self.name,
            )
        )
class FeatureConfig:
    """Associates a feature with the ``TableConfig`` describing its embedding table.

    Parameters
    ----------
    table: TableConfig
        Embedding table configuration for this feature.
    max_sequence_length: int
        Stored on the instance as-is; not consulted elsewhere in this module. Defaults to 0.
    name: str, optional
        Optional name of the feature.
    """

    def __init__(
        self, table: TableConfig, max_sequence_length: int = 0, name: Optional[Text] = None
    ):
        self.table, self.max_sequence_length, self.name = table, max_sequence_length, name

    def __repr__(self):
        return (
            f"FeatureConfig(table={self.table!r}, "
            f"max_sequence_length={self.max_sequence_length!r}, name={self.name!r})"
        )
class SoftEmbedding(torch.nn.Module):
    """
    Soft-one hot encoding embedding technique, from https://arxiv.org/pdf/1708.00065.pdf
    In a nutshell, it represents a continuous feature as a weighted average of embeddings
    """

    def __init__(self, num_embeddings, embeddings_dim, emb_initializer=None):
        """

        Parameters
        ----------
        num_embeddings: Number of embeddings to use (cardinality of the embedding table).
            Must be greater than 0.
        embeddings_dim: The dimension of the vector space for projecting the scalar value.
            Must be greater than 0.
        emb_initializer: Optional callable, applied to the embedding table's weight tensor
            to initialize it. When None, the default ``torch.nn.Embedding`` initialization
            is kept.
        """
        assert (
            num_embeddings > 0
        ), "The number of embeddings for soft embeddings needs to be greater than 0"
        assert (
            embeddings_dim > 0
        ), "The embeddings dim for soft embeddings needs to be greater than 0"

        super(SoftEmbedding, self).__init__()
        self.embedding_table = torch.nn.Embedding(num_embeddings, embeddings_dim)
        if emb_initializer is not None:
            emb_initializer(self.embedding_table.weight)
        # Maps each scalar value to one logit per embedding.
        self.projection_layer = torch.nn.Linear(1, num_embeddings, bias=True)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, input_numeric):
        """Return the softmax-weighted average of the embedding rows for each scalar.

        The output has the input's shape with an extra trailing ``embeddings_dim`` axis.
        """
        # Add a trailing axis of size 1 so each scalar feeds the Linear(1, n) projection.
        input_numeric = input_numeric.unsqueeze(-1)
        weights = self.softmax(self.projection_layer(input_numeric))
        # (..., num_embeddings) @ (num_embeddings, dim) -> (..., dim). This equals
        # (weights[..., None] * W).sum(-2) without materialising the
        # (..., num_embeddings, dim) intermediate.
        soft_one_hot_embeddings = torch.matmul(weights, self.embedding_table.weight)
        return soft_one_hot_embeddings