#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections.abc
import inspect
from copy import copy, deepcopy
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Sequence, Type, Union
import tensorflow as tf
from tensorflow.keras import backend
from tensorflow.python import to_dlpack
from tensorflow.python.tpu.tpu_embedding_v2_utils import FeatureConfig, TableConfig
import merlin.io
from merlin.core.dispatch import DataFrameType
from merlin.io import Dataset
from merlin.models.tf.blocks.mlp import InitializerType, RegularizerType
from merlin.models.tf.core.base import Block, BlockType
from merlin.models.tf.core.combinators import ParallelBlock, SequentialBlock
from merlin.models.tf.core.tabular import (
TABULAR_MODULE_PARAMS_DOCSTRING,
Filter,
TabularAggregationType,
TabularBlock,
)
from merlin.models.tf.transforms.tensor import ListToDense, ListToSparse
# pylint has issues with TF array ops, so disable checks until fixed:
# https://github.com/PyCQA/pylint/issues/3613
# pylint: disable=no-value-for-parameter, unexpected-keyword-arg
from merlin.models.tf.typing import TabularData
from merlin.models.tf.utils.tf_utils import (
call_layer,
df_to_tensor,
list_col_to_ragged,
tensor_to_df,
)
from merlin.models.utils import schema_utils
from merlin.models.utils.doc_utils import docstring_parameter
from merlin.models.utils.schema_utils import (
create_categorical_column,
infer_embedding_dim,
schema_to_tensorflow_metadata_json,
tensorflow_metadata_json_to_schema,
)
from merlin.schema import ColumnSchema, Schema, Tags, TagsType
EMBEDDING_FEATURES_PARAMS_DOCSTRING = """
feature_config: Dict[str, FeatureConfig]
This specifies what TableConfig to use for each feature. For shared embeddings, the same
TableConfig can be used for multiple features.
item_id: str, optional
The name of the feature that's used for the item_id.
"""
class EmbeddingTableBase(Block):
def __init__(self, dim: int, *col_schemas: ColumnSchema, trainable=True, **kwargs):
super(EmbeddingTableBase, self).__init__(trainable=trainable, **kwargs)
self.dim = dim
self.features = {}
if len(col_schemas) == 0:
raise ValueError("At least one col_schema must be provided to the embedding table.")
self.col_schema = col_schemas[0]
for col_schema in col_schemas:
self.add_feature(col_schema)
@property
def _schema(self):
return Schema([col_schema for col_schema in self.features.values()])
@classmethod
def from_pretrained(
cls,
data: Union[Dataset, DataFrameType],
col_schema: Optional[ColumnSchema] = None,
trainable=True,
**kwargs,
):
raise NotImplementedError()
@property
def input_dim(self):
return self.col_schema.int_domain.max + 1
@property
def table_name(self):
return self.col_schema.int_domain.name or self.col_schema.name
def add_feature(self, col_schema: ColumnSchema) -> None:
"""Add a feature to the table.
        Adding more than one feature enables the table to look up and return embeddings
for more than one feature when called with tabular data (Dict[str, TensorLike]).
Additional column schemas must have an int domain that matches the existing ones.
Parameters
----------
col_schema : ColumnSchema
"""
if not col_schema.int_domain:
raise ValueError("`col_schema` needs to have an int-domain")
if (
col_schema.int_domain.name
and self.col_schema.int_domain.name
and col_schema.int_domain.name != self.col_schema.int_domain.name
):
raise ValueError(
"`col_schema` int-domain name does not match table domain name. "
f"{col_schema.int_domain.name} != {self.col_schema.int_domain.name} "
)
if col_schema.int_domain.max != self.col_schema.int_domain.max:
raise ValueError(
"`col_schema.int_domain.max` does not match existing input dim."
f"{col_schema.int_domain.max} != {self.col_schema.int_domain.max} "
)
self.features[col_schema.name] = col_schema
def get_config(self):
config = super().get_config()
config["dim"] = self.dim
schema = schema_to_tensorflow_metadata_json(self.schema)
config["schema"] = schema
return config
@classmethod
def from_config(cls, config):
dim = config.pop("dim")
schema = tensorflow_metadata_json_to_schema(config.pop("schema"))
return cls(dim, *schema, **config)
CombinerType = Union[str, tf.keras.layers.Layer]
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class EmbeddingTable(EmbeddingTableBase):
"""Embedding table that is backed by a standard Keras Embedding Layer.
Parameters
----------
dim: Dimension of the dense embedding.
col_schema: ColumnSchema
Schema of the column. This is used to infer the cardinality.
embeddings_initializer: Initializer for the `embeddings`
matrix (see `keras.initializers`).
embeddings_regularizer: Regularizer function applied to
the `embeddings` matrix (see `keras.regularizers`).
embeddings_constraint: Constraint function applied to
the `embeddings` matrix (see `keras.constraints`).
mask_zero: Boolean, whether or not the input value 0 is a special "padding"
value that should be masked out.
This is useful when using recurrent layers
which may take variable length input.
If this is `True`, then all subsequent layers
in the model need to support masking or an exception will be raised.
If mask_zero is set to True, as a consequence, index 0 cannot be
used in the vocabulary (input_dim should equal size of
vocabulary + 1).
input_length: Length of input sequences, when it is constant.
This argument is required if you are going to connect
`Flatten` then `Dense` layers upstream
(without it, the shape of the dense outputs cannot be computed).
    sequence_combiner: A string specifying how to combine embedding results for each
        entry ("mean", "sqrtn" and "sum" are supported) or a layer.
        Default is None (no combiner used).
trainable: Boolean, whether the layer's variables should be trainable.
name: String name of the layer.
dtype: The dtype of the layer's computations and weights. Can also be a
`tf.keras.mixed_precision.Policy`, which allows the computation and weight
dtype to differ. Default of `None` means to use
`tf.keras.mixed_precision.global_policy()`, which is a float32 policy
        unless set to a different value.
dynamic: Set this to `True` if your layer should only be run eagerly, and
should not be used to generate a static computation graph.
This would be the case for a Tree-RNN or a recursive network,
for example, or generally for any layer that manipulates tensors
using Python control flow. If `False`, we assume that the layer can
safely be used to generate a static computation graph.
l2_batch_regularization_factor: float, optional
Factor for L2 regularization of the embeddings vectors (from the current batch only)
by default 0.0
**kwargs: Forwarded Keras Layer parameters
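
    Example (a minimal sketch; assumes `schema` is a merlin Schema with a categorical
    column "item_id" and `item_ids` is an integer id tensor)::

        import merlin.models.tf as mm

        table = mm.EmbeddingTable(64, schema.select_by_name("item_id").first)
        item_embeddings = table(item_ids)  # shape: (batch_size, 64)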
"""
def __init__(
self,
dim: int,
*col_schemas: ColumnSchema,
embeddings_initializer="uniform",
embeddings_regularizer=None,
activity_regularizer=None,
embeddings_constraint=None,
mask_zero=False,
input_length=None,
sequence_combiner: Optional[CombinerType] = None,
trainable=True,
name=None,
dtype=None,
dynamic=False,
table=None,
l2_batch_regularization_factor=0.0,
**kwargs,
):
"""Create an EmbeddingTable."""
super(EmbeddingTable, self).__init__(
dim,
*col_schemas,
trainable=trainable,
name=name,
dtype=dtype,
dynamic=dynamic,
**kwargs,
)
if table is not None:
self.table = table
else:
table_kwargs = dict(
embeddings_initializer=embeddings_initializer,
embeddings_regularizer=embeddings_regularizer,
activity_regularizer=activity_regularizer,
embeddings_constraint=embeddings_constraint,
mask_zero=mask_zero,
input_length=input_length,
trainable=trainable,
)
self.table = tf.keras.layers.Embedding(
input_dim=self.input_dim,
output_dim=self.dim,
name=self.table_name,
**table_kwargs,
)
self.sequence_combiner = sequence_combiner
self.supports_masking = True
self.l2_batch_regularization_factor = l2_batch_regularization_factor
def select_by_tag(self, tags: Union[Tags, Sequence[Tags]]) -> Optional["EmbeddingTable"]:
"""Select features in EmbeddingTable by tags.
Since an EmbeddingTable can be a shared-embedding table, this method filters
the schema for features that match the tags.
If none of the features match the tags, it will return None.
Parameters
----------
tags: Union[Tags, Sequence[Tags]]
A list of tags.
Returns
-------
An EmbeddingTable if the tags match. If no features match, it returns None.
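
        Example (illustrative; `table` is a shared EmbeddingTable whose features
        include columns tagged with Tags.ITEM)::

            item_table = table.select_by_tag(Tags.ITEM)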
"""
        if not isinstance(tags, collections.abc.Sequence):
tags = [tags]
selected_schema = self.schema.select_by_tag(tags)
if not selected_schema:
return
config = self.get_config()
config["schema"] = schema_utils.schema_to_tensorflow_metadata_json(selected_schema)
embedding_table = EmbeddingTable.from_config(config, table=self.table)
return embedding_table
@classmethod
def from_pretrained(
cls,
data: Union[Dataset, DataFrameType],
trainable=True,
name=None,
col_schema=None,
**kwargs,
) -> "EmbeddingTable":
"""Create From pre-trained embeddings from a Dataset or DataFrame.
Parameters
----------
data : Union[Dataset, DataFrameType]
A dataset containing the pre-trained embedding weights
trainable : bool
Whether the layer should be trained or not.
        name : str
            The name of the layer.
        col_schema : ColumnSchema, optional
            Schema of the column. If not provided, one is created from `name` and the
            number of rows in `data`.
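
        Example (a minimal sketch; assumes `pretrained_df` has one row per item id and
        one column per embedding dimension)::

            table = EmbeddingTable.from_pretrained(
                pretrained_df, name="item_id", trainable=False
            )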
"""
if hasattr(data, "to_ddf"):
data = data.to_ddf().compute()
embeddings = df_to_tensor(data, tf.float32)
num_items, dim = tuple(embeddings.shape)
if not col_schema:
if not name:
raise ValueError("`name` is required when not using a ColumnSchema")
col_schema = create_categorical_column(name, num_items - 1)
return cls(
dim,
col_schema,
name=name,
embeddings_initializer=tf.keras.initializers.constant(embeddings),
trainable=trainable,
**kwargs,
)
@classmethod
def from_dataset(
cls,
data: Union[Dataset, DataFrameType],
trainable=True,
name=None,
col_schema=None,
**kwargs,
) -> "EmbeddingTable":
"""Create From pre-trained embeddings from a Dataset or DataFrame.
Parameters
----------
data : Union[Dataset, DataFrameType]
A dataset containing the pre-trained embedding weights
trainable : bool
Whether the layer should be trained or not.
name : str
The name of the layer.
"""
return cls.from_pretrained(
data, trainable=trainable, name=name, col_schema=col_schema, **kwargs
)
def to_dataset(self, gpu=None) -> merlin.io.Dataset:
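        """Export the embedding table weights as a merlin.io.Dataset."""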
return merlin.io.Dataset(self.to_df(gpu=gpu))
def to_df(self, gpu=None):
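        """Export the embedding table weights as a dataframe."""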
return tensor_to_df(self.table.embeddings, gpu=gpu)
def _maybe_build(self, inputs):
"""Creates state between layer instantiation and layer call.
Invoked automatically before the first execution of `call()`.
"""
self.table._maybe_build(inputs)
return super(EmbeddingTable, self)._maybe_build(inputs)
def build(self, input_shapes):
if not self.table.built:
self.table.build(input_shapes)
return super(EmbeddingTable, self).build(input_shapes)
def call(
self, inputs: Union[tf.Tensor, TabularData], **kwargs
) -> Union[tf.Tensor, TabularData]:
"""
Parameters
----------
inputs : Union[tf.Tensor, tf.RaggedTensor, tf.SparseTensor]
Tensors or dictionary of tensors representing the input batch.
Returns
-------
A tensor or dict of tensors corresponding to the embeddings for inputs
"""
if isinstance(inputs, dict):
out = {}
for feature_name in self.schema.column_names:
if feature_name in inputs:
out[feature_name] = self._call_table(inputs[feature_name], **kwargs)
else:
out = self._call_table(inputs, **kwargs)
return out
def _call_table(self, inputs, **kwargs):
if isinstance(inputs, tuple) and len(inputs) == 2:
inputs = list_col_to_ragged(inputs)
        # Eliminate the trailing dim == 1 of dense/ragged tensors before the embedding lookup
        if isinstance(inputs, (tf.Tensor, tf.RaggedTensor)) and inputs.shape[-1] == 1:
            inputs = tf.squeeze(inputs, axis=-1)
if isinstance(inputs, (tf.RaggedTensor, tf.SparseTensor)):
if self.sequence_combiner and isinstance(self.sequence_combiner, str):
if isinstance(inputs, tf.RaggedTensor):
inputs = inputs.to_sparse()
if len(inputs.dense_shape) == 3 and inputs.dense_shape[-1] == 1:
inputs = tf.sparse.reshape(inputs, inputs.dense_shape[:-1])
out = tf.nn.safe_embedding_lookup_sparse(
self.table.embeddings, inputs, None, combiner=self.sequence_combiner
)
else:
if isinstance(inputs, tf.SparseTensor):
                    raise ValueError(
                        "Sparse tensors are not supported without a sequence_combiner; "
                        "please convert the tensor to a ragged or dense tensor."
                    )
out = call_layer(self.table, inputs, **kwargs)
if isinstance(self.sequence_combiner, tf.keras.layers.Layer):
out = call_layer(self.sequence_combiner, out, **kwargs)
else:
out = call_layer(self.table, inputs, **kwargs)
if self.l2_batch_regularization_factor > 0:
self.add_loss(self.l2_batch_regularization_factor * tf.reduce_sum(tf.square(out)))
if self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype:
# Instead of casting the variable as in most layers, cast the output, as
# this is mathematically equivalent but is faster.
out = tf.cast(out, self._dtype_policy.compute_dtype)
return out
def compute_output_shape(
self, input_shape: Union[tf.TensorShape, Dict[str, tf.TensorShape]]
) -> Union[tf.TensorShape, Dict[str, tf.TensorShape]]:
if isinstance(input_shape, dict):
output_shapes = {}
for feature_name in self.schema.column_names:
if feature_name in input_shape:
output_shapes[feature_name] = self._compute_output_shape_table(
input_shape[feature_name]
)
else:
output_shapes = self._compute_output_shape_table(input_shape)
return output_shapes
def _compute_output_shape_table(
self, input_shape: Union[tf.TensorShape, tuple]
) -> tf.TensorShape:
if isinstance(input_shape, tuple) and isinstance(input_shape[1], tf.TensorShape):
input_shape = tf.TensorShape([input_shape[1][0], None])
first_dims = input_shape
if (self.sequence_combiner is not None) or (input_shape.rank > 1 and input_shape[-1] == 1):
if len(input_shape) == 3:
first_dims = [input_shape[0]]
else:
first_dims = input_shape[:-1]
output_shapes = tf.TensorShape(first_dims + [self.dim])
return output_shapes
def compute_call_output_shape(self, input_shapes):
return self.compute_output_shape(input_shapes)
@classmethod
def from_config(cls, config, table=None):
if table:
config["table"] = table
else:
config["table"] = tf.keras.layers.deserialize(config["table"])
if "combiner-layer" in config:
config["sequence_combiner"] = tf.keras.layers.deserialize(config.pop("combiner-layer"))
return super().from_config(config)
def get_config(self):
config = super().get_config()
config["table"] = tf.keras.layers.serialize(self.table)
if isinstance(self.sequence_combiner, tf.keras.layers.Layer):
config["combiner-layer"] = tf.keras.layers.serialize(self.sequence_combiner)
else:
config["sequence_combiner"] = self.sequence_combiner
return config
def Embeddings(
schema: Schema,
dim: Optional[Union[Dict[str, int], int]] = None,
infer_dim_fn: Callable[[ColumnSchema], int] = infer_embedding_dim,
sequence_combiner: Optional[Union[CombinerType, Dict[str, CombinerType]]] = "mean",
embeddings_initializer: Optional[Union[InitializerType, Dict[str, InitializerType]]] = None,
embeddings_regularizer: Optional[Union[RegularizerType, Dict[str, RegularizerType]]] = None,
activity_regularizer: Optional[Union[RegularizerType, Dict[str, RegularizerType]]] = None,
trainable: Optional[Union[bool, Dict[str, bool]]] = None,
table_cls: Type[tf.keras.layers.Layer] = EmbeddingTable,
pre: Optional[BlockType] = None,
post: Optional[BlockType] = None,
aggregation: Optional[TabularAggregationType] = None,
block_name: str = "embeddings",
l2_batch_regularization_factor: Optional[Union[float, Dict[str, float]]] = 0.0,
**kwargs,
) -> ParallelBlock:
"""Creates a ParallelBlock with an EmbeddingTable for each categorical feature
in the schema.
Parameters
----------
schema: Schema
        Schema of the input data. This Schema object can be automatically generated by
        [NVTabular](https://nvidia-merlin.github.io/NVTabular/main/Introduction.html),
        or it can be constructed manually.
    dim: Optional[Union[Dict[str, int], int]], optional
        A single embedding dim to use for all features, or a dict mapping feature names
        to embedding dims ({"feature_name": dim, ...}), by default None
    infer_dim_fn: Callable[[ColumnSchema], int], optional
        The function used to infer the embedding dimension, by default infer_embedding_dim
    sequence_combiner: Optional[Union[str, tf.keras.layers.Layer]], optional
        A string specifying how to combine embedding results for each
        entry ("mean", "sqrtn" and "sum" are supported) or a layer,
        by default "mean"
    embeddings_initializer: Union[InitializerType, Dict[str, InitializerType]], optional
        An initializer function, or a dict where keys are feature names and values are
        callables that initialize the embedding tables. Pre-trained embeddings can be
        provided via this arg.
    embeddings_regularizer: Union[RegularizerType, Dict[str, RegularizerType]], optional
        A regularizer function, or a dict where keys are feature names and values are
        callables that apply regularization to the embedding tables.
    activity_regularizer: Union[RegularizerType, Dict[str, RegularizerType]], optional
        A regularizer function, or a dict where keys are feature names and values are
        callables that apply regularization to the activations of the embedding tables.
    trainable: Optional[Union[bool, Dict[str, bool]]], optional
        Whether the embedding tables should be trainable, either as a single boolean
        applied to all tables or as a dict mapping column names to booleans (e.g. to
        freeze the embeddings of specific columns), by default None
table_cls: Type[tf.keras.layers.Layer], by default EmbeddingTable
The class to use for each embedding table.
pre: Optional[BlockType], optional
Transformation block to apply before the embeddings lookup, by default None
post: Optional[BlockType], optional
Transformation block to apply after the embeddings lookup, by default None
aggregation: Optional[TabularAggregationType], optional
Transformation block to apply for aggregating the inputs, by default None
block_name: str, optional
Name of the block, by default "embeddings"
    l2_batch_regularization_factor: Optional[Union[float, Dict[str, float]]], optional
        Factor for L2 regularization of the embedding vectors (from the current batch only),
        by default 0.0. If a dict is provided, the keys are feature names and the values
        are regularization factors.
Returns
-------
ParallelBlock
        Returns a ParallelBlock with an embedding table for each categorical feature
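
    Example (a minimal sketch; assumes `schema` is the schema of the input data)::

        import merlin.models.tf as mm
        from merlin.schema import Tags

        embeddings = mm.Embeddings(
            schema.select_by_tag(Tags.CATEGORICAL),
            dim={"item_id": 64},  # other features fall back to infer_dim_fn
        )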
"""
if trainable:
kwargs["trainable"] = trainable
if embeddings_initializer:
kwargs["embeddings_initializer"] = embeddings_initializer
if embeddings_regularizer:
kwargs["embeddings_regularizer"] = embeddings_regularizer
if activity_regularizer:
kwargs["activity_regularizer"] = activity_regularizer
if sequence_combiner:
kwargs["sequence_combiner"] = sequence_combiner
if l2_batch_regularization_factor:
kwargs["l2_batch_regularization_factor"] = l2_batch_regularization_factor
tables = {}
for col in schema:
table_kwargs = _forward_kwargs_to_table(col, table_cls, kwargs)
table_name = col.int_domain.name or col.name
if table_name in tables:
tables[table_name].add_feature(col)
else:
tables[table_name] = table_cls(
_get_dim(col, dim, infer_dim_fn),
col,
name=table_name,
**table_kwargs,
)
return ParallelBlock(
tables, pre=pre, post=post, aggregation=aggregation, name=block_name, schema=schema
)
def _forward_kwargs_to_table(col, table_cls, kwargs):
arg_spec = inspect.getfullargspec(table_cls.__init__)
supported_kwargs = arg_spec.kwonlyargs
if arg_spec.defaults:
supported_kwargs += arg_spec.args[-len(arg_spec.defaults) :]
table_kwargs = {}
for key, val in kwargs.items():
if key in supported_kwargs:
if isinstance(val, dict):
if col.name in val:
table_kwargs[key] = val[col.name]
else:
table_kwargs[key] = val
return table_kwargs
def _get_dim(col, embedding_dims, infer_dim_fn):
dim = None
if isinstance(embedding_dims, dict):
dim = embedding_dims.get(col.name)
elif isinstance(embedding_dims, int):
dim = embedding_dims
if not dim:
dim = infer_dim_fn(col)
return dim
class AverageEmbeddingsByWeightFeature(tf.keras.layers.Layer):
def __init__(self, weight_feature_name: str, axis=1, **kwargs):
"""Computes the weighted average of a Tensor based
on one of the input features.
Typically used as a combiner for EmbeddingTable
for aggregating sequential embedding features
Parameters
----------
weight_feature_name : str
Name of the feature to be used as weight for average
axis : int, optional
Axis for reduction, by default 1 (assuming the 2nd dim is
the sequence length)
"""
super(AverageEmbeddingsByWeightFeature, self).__init__(**kwargs)
self.axis = axis
self.weight_feature_name = weight_feature_name
def call(self, inputs, features):
weight_feature = features[self.weight_feature_name]
if isinstance(inputs, tf.RaggedTensor) and not isinstance(weight_feature, tf.RaggedTensor):
raise ValueError(
f"If inputs is a tf.RaggedTensor, the weight feature ({self.weight_feature_name}) "
f"should also be a tf.RaggedTensor (and not a {type(weight_feature)}), "
"so that the list length can vary per example for both input embedding "
"and weight features."
)
weights = tf.expand_dims(tf.cast(weight_feature, tf.float32), -1)
output = tf.divide(
tf.reduce_sum(tf.multiply(inputs, weights), axis=self.axis),
tf.reduce_sum(weights, axis=self.axis),
)
return output
    def compute_output_shape(self, input_shape):
        # The weighted average removes the reduced (sequence) axis from the shape
        dims = tf.TensorShape(input_shape).as_list()
        del dims[self.axis]
        return tf.TensorShape(dims)
@staticmethod
def from_schema_convention(schema: Schema, weight_features_name_suffix: str = "_weight"):
"""Infers the weight features corresponding to sequential embedding
features based on the feature name suffix. For example, if a
sequential categorical feature is called `item_id_seq`, if there is another
feature in the schema called `item_id_seq_weight`, then it will be used
for weighted average. If a weight feature cannot be found for a given
seq cat. feature then standard mean is used as combiner
Parameters
----------
schema : Schema
The feature schema
weight_features_name_suffix : str
Suffix to look for a corresponding weight feature
Returns
-------
        Dict[str, tf.keras.layers.Layer]
            A dict where the key is the sequential categorical feature name and the value
            is an AverageEmbeddingsByWeightFeature combiner with the corresponding weight
            feature name (or a mean-pooling Lambda layer when no weight feature is found)
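
        Example (illustrative; assumes `item_id_seq` is a sequential categorical feature
        and `item_id_seq_weight` its corresponding weight feature in `schema`)::

            combiners = AverageEmbeddingsByWeightFeature.from_schema_convention(schema)
            embeddings = Embeddings(schema, sequence_combiner=combiners)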
"""
cat_cols = schema.select_by_tag(Tags.CATEGORICAL)
seq_combiners = {}
for cat_col in cat_cols:
combiner = None
if Tags.SEQUENCE in cat_col.tags:
weight_col_name = f"{cat_col.name}{weight_features_name_suffix}"
if weight_col_name in schema.column_names:
combiner = AverageEmbeddingsByWeightFeature(weight_col_name)
else:
combiner = tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1))
seq_combiners[cat_col.name] = combiner
return seq_combiners
@dataclass
class EmbeddingOptions:
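    """Options for configuring the embedding tables created by
    `EmbeddingFeatures.from_schema`: pre-defined or inferred embedding dimensions,
    initializers, L2 regularization, and the combiner for multi-hot features."""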
embedding_dims: Optional[Dict[str, int]] = None
embedding_dim_default: Optional[int] = 64
infer_embedding_sizes: bool = False
infer_embedding_sizes_multiplier: float = 2.0
infer_embeddings_ensure_dim_multiple_of_8: bool = False
embeddings_initializers: Optional[
Union[Dict[str, Callable[[Any], None]], Callable[[Any], None], str]
] = None
embeddings_l2_reg: float = 0.0
combiner: Optional[str] = "mean"
@docstring_parameter(
tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING,
embedding_features_parameters=EMBEDDING_FEATURES_PARAMS_DOCSTRING,
)
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class EmbeddingFeatures(TabularBlock):
"""Input block for embedding-lookups for categorical features.
For multi-hot features, the embeddings will be aggregated into a single tensor using the mean.
Parameters
----------
{embedding_features_parameters}
{tabular_module_parameters}
"""
    def __init__(
self,
feature_config: Dict[str, "FeatureConfig"],
pre: Optional[BlockType] = None,
post: Optional[BlockType] = None,
aggregation: Optional[TabularAggregationType] = None,
schema: Optional[Schema] = None,
name=None,
add_default_pre=True,
l2_reg: Optional[float] = 0.0,
**kwargs,
):
if add_default_pre:
embedding_pre = [Filter(list(feature_config.keys())), ListToSparse()]
pre = [embedding_pre, pre] if pre else embedding_pre # type: ignore
self.feature_config = feature_config
self.l2_reg = l2_reg
self.embedding_tables = {}
tables: Dict[str, TableConfig] = {}
for _, feature in self.feature_config.items():
table: TableConfig = feature.table
if table.name not in tables:
tables[table.name] = table
for table_name, table in tables.items():
self.embedding_tables[table_name] = tf.keras.layers.Embedding(
table.vocabulary_size,
table.dim,
name=table_name,
embeddings_initializer=table.initializer,
)
super().__init__(
pre=pre,
post=post,
aggregation=aggregation,
name=name,
schema=schema,
is_input=True,
**kwargs,
)
    @classmethod
def from_schema( # type: ignore
cls,
schema: Schema,
embedding_options: EmbeddingOptions = EmbeddingOptions(),
tags: Optional[TagsType] = None,
max_sequence_length: Optional[int] = None,
**kwargs,
) -> Optional["EmbeddingFeatures"]:
"""Instantiates embedding features from the schema
Parameters
----------
schema : Schema
            The feature schema
embedding_options : EmbeddingOptions, optional
An EmbeddingOptions instance, which allows for a number of
options for the embedding table, by default EmbeddingOptions()
tags : Optional[TagsType], optional
If provided, keeps only features from those tags, by default None
max_sequence_length : Optional[int], optional
Maximum sequence length of sparse features (if any), by default None
Returns
-------
EmbeddingFeatures
An instance of EmbeddingFeatures block, with the embedding
            layers created under the hood
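
        Example (a minimal sketch; assumes `schema` holds the categorical features)::

            emb_block = EmbeddingFeatures.from_schema(
                schema,
                embedding_options=EmbeddingOptions(embedding_dim_default=32),
                tags=[Tags.CATEGORICAL],
            )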
"""
schema_copy = copy(schema)
if tags:
schema_copy = schema_copy.select_by_tag(tags)
embedding_dims = embedding_options.embedding_dims or {}
if embedding_options.infer_embedding_sizes:
inferred_embedding_dims = schema_utils.get_embedding_sizes_from_schema(
schema,
embedding_options.infer_embedding_sizes_multiplier,
embedding_options.infer_embeddings_ensure_dim_multiple_of_8,
)
# Adding inferred embedding dims only for features where the embedding sizes
# were not pre-defined
inferred_embedding_dims = {
k: v for k, v in inferred_embedding_dims.items() if k not in embedding_dims
}
embedding_dims = {**embedding_dims, **inferred_embedding_dims}
initializer_default = tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05)
embeddings_initializer = embedding_options.embeddings_initializers or initializer_default
emb_config = {}
cardinalities = schema_utils.categorical_cardinalities(schema)
for key, cardinality in cardinalities.items():
embedding_size = embedding_dims.get(key, embedding_options.embedding_dim_default)
if isinstance(embeddings_initializer, dict):
emb_initializer = embeddings_initializer.get(key, initializer_default)
else:
emb_initializer = embeddings_initializer
emb_config[key] = (cardinality, embedding_size, emb_initializer)
feature_config: Dict[str, FeatureConfig] = {}
tables: Dict[str, TableConfig] = {}
domains = schema_utils.categorical_domains(schema)
        for name, (vocab_size, dim, emb_initializer) in emb_config.items():
table_name = domains[name]
table = tables.get(table_name, None)
if not table:
table = TableConfig(
vocabulary_size=vocab_size,
dim=dim,
name=table_name,
combiner=embedding_options.combiner,
                    initializer=emb_initializer,
)
tables[table_name] = table
feature_config[name] = FeatureConfig(table)
if not feature_config:
return None
output = cls(
feature_config,
schema=schema_copy,
l2_reg=embedding_options.embeddings_l2_reg,
**kwargs,
)
return output
    def build(self, input_shapes):
for name, embedding_table in self.embedding_tables.items():
embedding_table.build(())
if hasattr(self, "_context"):
self._context.add_embedding_table(name, self.embedding_tables[name])
if isinstance(input_shapes, dict):
super().build(input_shapes)
else:
tf.keras.layers.Layer.build(self, input_shapes)
    def call(self, inputs: TabularData, **kwargs) -> TabularData:
embedded_outputs = {}
for name, val in inputs.items():
embedded_outputs[name] = self.lookup_feature(name, val)
if self.l2_reg > 0:
self.add_loss(self.l2_reg * tf.reduce_sum(tf.square(embedded_outputs[name])))
return embedded_outputs
    def compute_call_output_shape(self, input_shapes):
batch_size = self.calculate_batch_size_from_input_shapes(input_shapes)
output_shapes = {}
for name, val in input_shapes.items():
output_shapes[name] = tf.TensorShape([batch_size, self.feature_config[name].table.dim])
return output_shapes
    def lookup_feature(self, name, val, output_sequence=False):
dtype = backend.dtype(val)
if dtype != "int32" and dtype != "int64":
val = tf.cast(val, "int32")
table: TableConfig = self.feature_config[name].table
table_var = self.embedding_tables[table.name].embeddings
if isinstance(val, tf.SparseTensor):
if len(val.dense_shape) == 3 and val.dense_shape[-1] == 1:
val = tf.sparse.reshape(val, val.dense_shape[:-1])
out = tf.nn.safe_embedding_lookup_sparse(table_var, val, None, combiner=table.combiner)
else:
if output_sequence:
out = tf.gather(table_var, tf.cast(val, tf.int32))
else:
if len(val.shape) > 1:
# TODO: Check if it is correct to retrieve only the 1st element
# of second dim for non-sequential multi-hot categ features
out = tf.gather(table_var, tf.cast(val, tf.int32)[:, 0])
else:
out = tf.gather(table_var, tf.cast(val, tf.int32))
if self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype:
# Instead of casting the variable as in most layers, cast the output, as
# this is mathematically equivalent but is faster.
out = tf.cast(out, self._dtype_policy.compute_dtype)
return out
    def table_config(self, feature_name: str):
return self.feature_config[feature_name].table
    def get_embedding_table(self, table_name: Union[str, Tags], l2_normalization: bool = False):
if isinstance(table_name, Tags):
feature_names = self.schema.select_by_tag(table_name).column_names
if len(feature_names) == 1:
table_name = feature_names[0]
elif len(feature_names) > 1:
raise ValueError(
f"There is more than one feature associated to the tag {table_name}"
)
else:
raise ValueError(f"Could not find a feature associated to the tag {table_name}")
embeddings = self.embedding_tables[table_name].embeddings
if l2_normalization:
embeddings = tf.linalg.l2_normalize(embeddings, axis=-1)
return embeddings
    def embedding_table_df(
self, table_name: Union[str, Tags], l2_normalization: bool = False, gpu: bool = True
):
"""Retrieves a dataframe with the embedding table
Parameters
----------
table_name : Union[str, Tags]
Tag or name of the embedding table
l2_normalization : bool, optional
Whether the L2-normalization should be applied to
embeddings (common approach for Matrix Factorization
and Retrieval models in general), by default False
gpu : bool, optional
            Whether to use the GPU, by default True
Returns
-------
Union[pd.DataFrame, cudf.DataFrame]
            Returns a cudf or pandas dataframe, depending on the `gpu` flag
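
        Example (illustrative; `emb_block` is an EmbeddingFeatures instance)::

            item_emb_df = emb_block.embedding_table_df("item_id", gpu=False)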
"""
embeddings = self.get_embedding_table(table_name, l2_normalization)
if gpu:
import cudf
import cupy
# Note: It is not possible to convert Tensorflow tensors to the cudf dataframe
# directly using dlPack (as the example commented below) because cudf.from_dlpack()
# expects the 2D tensor to be in Fortran order (column-major), which is not
# supported by TF (https://github.com/rapidsai/cudf/issues/10754).
# df = cudf.from_dlpack(to_dlpack(tf.convert_to_tensor(embeddings)))
embeddings_cupy = cupy.fromDlpack(to_dlpack(tf.convert_to_tensor(embeddings)))
df = cudf.DataFrame(embeddings_cupy)
df.columns = [str(col) for col in list(df.columns)]
            df = df.set_index(cudf.RangeIndex(0, embeddings.shape[0]))
else:
import pandas as pd
df = pd.DataFrame(embeddings.numpy())
df.columns = [str(col) for col in list(df.columns)]
            df = df.set_index(pd.RangeIndex(0, embeddings.shape[0]))
return df
    def embedding_table_dataset(
self, table_name: Union[str, Tags], l2_normalization: bool = False, gpu=True
) -> merlin.io.Dataset:
"""Creates a Dataset for the embedding table
Parameters
----------
table_name : Union[str, Tags]
Tag or name of the embedding table
l2_normalization : bool, optional
Whether the L2-normalization should be applied to
embeddings (common approach for Matrix Factorization
and Retrieval models in general), by default False
gpu : bool, optional
            Whether to use the GPU, by default True
Returns
-------
merlin.io.Dataset
Returns a Dataset with the embeddings
"""
return merlin.io.Dataset(self.embedding_table_df(table_name, l2_normalization, gpu))
    def export_embedding_table(
self,
table_name: Union[str, Tags],
export_path: str,
l2_normalization: bool = False,
gpu=True,
):
"""Exports the embedding table to parquet file
Parameters
----------
table_name : Union[str, Tags]
Tag or name of the embedding table
export_path : str
Path for the generated parquet file
l2_normalization : bool, optional
Whether the L2-normalization should be applied to
embeddings (common approach for Matrix Factorization
and Retrieval models in general), by default False
gpu : bool, optional
            Whether to use the GPU, by default True
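
        Example (illustrative; `emb_block` is an EmbeddingFeatures instance)::

            emb_block.export_embedding_table("item_id", "item_embeddings.parquet", gpu=False)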
"""
df = self.embedding_table_df(table_name, l2_normalization, gpu=gpu)
df.to_parquet(export_path)
    def get_config(self):
config = super().get_config()
feature_configs = {}
for key, val in self.feature_config.items():
feature_config_dict = dict(name=val.name, max_sequence_length=val.max_sequence_length)
feature_config_dict["table"] = serialize_table_config(val.table)
feature_configs[key] = feature_config_dict
config["feature_config"] = feature_configs
config["l2_reg"] = self.l2_reg
return config
    @classmethod
def from_config(cls, config):
# Deserialize feature_config
feature_configs = {}
for key, val in config["feature_config"].items():
table = deserialize_table_config(val["table"])
feature_config_params = {**val, "table": table}
feature_configs[key] = FeatureConfig(**feature_config_params)
config["feature_config"] = feature_configs
        # Set `add_default_pre` to False since `pre` will be provided from the config
config["add_default_pre"] = False
return super().from_config(config)
@docstring_parameter(
tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING,
embedding_features_parameters=EMBEDDING_FEATURES_PARAMS_DOCSTRING,
)
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class SequenceEmbeddingFeatures(EmbeddingFeatures):
"""Input block for embedding-lookups for categorical features. This module produces 3-D tensors,
this is useful for sequential models like transformers.
Parameters
----------
{embedding_features_parameters}
padding_idx: int
The symbol to use for padding.
{tabular_module_parameters}
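
    Example (a minimal sketch; assumes `schema` holds sequential categorical features
    and `batch` is a dict of input tensors)::

        emb_block = SequenceEmbeddingFeatures.from_schema(schema)
        outputs = emb_block(batch)  # each value has shape (batch_size, seq_length, dim)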
"""
    def __init__(
self,
feature_config: Dict[str, FeatureConfig],
max_seq_length: Optional[int] = None,
mask_zero: bool = True,
padding_idx: int = 0,
pre: Optional[BlockType] = None,
post: Optional[BlockType] = None,
aggregation: Optional[TabularAggregationType] = None,
schema: Optional[Schema] = None,
name: Optional[str] = None,
add_default_pre=True,
**kwargs,
):
if add_default_pre:
embedding_pre = [Filter(list(feature_config.keys())), ListToDense(max_seq_length)]
pre = [embedding_pre, pre] if pre else embedding_pre # type: ignore
super().__init__(
feature_config=feature_config,
pre=pre,
post=post,
aggregation=aggregation,
name=name,
schema=schema,
add_default_pre=False,
**kwargs,
)
self.padding_idx = padding_idx
self.mask_zero = mask_zero
    def lookup_feature(self, name, val, **kwargs):
return super(SequenceEmbeddingFeatures, self).lookup_feature(
name, val, output_sequence=True
)
    def compute_call_output_shape(self, input_shapes):
batch_size = self.calculate_batch_size_from_input_shapes(input_shapes)
sequence_length = input_shapes[list(self.feature_config.keys())[0]][1]
output_shapes = {}
for name, val in input_shapes.items():
output_shapes[name] = tf.TensorShape(
[batch_size, sequence_length, self.feature_config[name].table.dim]
)
return output_shapes
    def compute_mask(self, inputs, mask=None):
if not self.mask_zero:
return None
outputs = {}
for key, val in inputs.items():
outputs[key] = tf.not_equal(val, self.padding_idx)
return outputs
    def get_config(self):
config = super().get_config()
config["mask_zero"] = self.mask_zero
config["padding_idx"] = self.padding_idx
return config
def ContinuousEmbedding(
inputs: Block,
embedding_block: Block,
aggregation=None,
continuous_aggregation="concat",
name: str = "continuous",
**kwargs,
) -> SequentialBlock:
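    """Connects the continuous features to an embedding block, in parallel with the
    remaining inputs.

    The continuous features are selected (Tags.CONTINUOUS), aggregated with
    `continuous_aggregation`, passed through `embedding_block`, and connected as a
    named branch alongside the remaining features of `inputs` (add_rest=True).
    """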
continuous_embedding = Filter(Tags.CONTINUOUS, aggregation=continuous_aggregation).connect(
embedding_block
)
outputs = inputs.connect_branch(
continuous_embedding.as_tabular(name), add_rest=True, aggregation=aggregation, **kwargs
)
return outputs
def serialize_table_config(table_config: TableConfig) -> Dict[str, Any]:
table = deepcopy(table_config.__dict__)
if "initializer" in table:
table["initializer"] = tf.keras.initializers.serialize(table["initializer"])
if "optimizer" in table:
table["optimizer"] = tf.keras.optimizers.serialize(table["optimizer"])
return table
def deserialize_table_config(table_params: Dict[str, Any]) -> TableConfig:
if "initializer" in table_params and table_params["initializer"]:
table_params["initializer"] = tf.keras.initializers.deserialize(table_params["initializer"])
if "optimizer" in table_params and table_params["optimizer"]:
table_params["optimizer"] = tf.keras.optimizers.deserialize(table_params["optimizer"])
table = TableConfig(**table_params)
return table
def serialize_feature_config(feature_config: Dict[str, FeatureConfig]) -> Dict[str, Any]:
    outputs = {}
    for key, val in feature_config.items():
        feature_config_dict = dict(name=val.name, max_sequence_length=val.max_sequence_length)
        feature_config_dict["table"] = serialize_table_config(val.table)
        outputs[key] = feature_config_dict
    return outputs