#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from copy import copy, deepcopy
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Union
import tensorflow as tf
from tensorflow.python import to_dlpack
from tensorflow.python.keras import backend
from tensorflow.python.tpu.tpu_embedding_v2_utils import FeatureConfig, TableConfig
import merlin.io
from merlin.models.tf.blocks.core.base import Block, BlockType
from merlin.models.tf.blocks.core.combinators import SequentialBlock
from merlin.models.tf.blocks.core.tabular import (
TABULAR_MODULE_PARAMS_DOCSTRING,
Filter,
TabularAggregationType,
TabularBlock,
)
from merlin.models.tf.blocks.core.transformations import AsDenseFeatures, AsSparseFeatures
# pylint has issues with TF array ops, so disable checks until fixed:
# https://github.com/PyCQA/pylint/issues/3613
# pylint: disable=no-value-for-parameter, unexpected-keyword-arg
from merlin.models.tf.typing import TabularData
from merlin.models.utils import schema_utils
from merlin.models.utils.doc_utils import docstring_parameter
from merlin.schema import Schema, Tags, TagsType


EMBEDDING_FEATURES_PARAMS_DOCSTRING = """
feature_config: Dict[str, FeatureConfig]
This specifies what TableConfig to use for each feature. For shared embeddings, the same
TableConfig can be used for multiple features.
item_id: str, optional
The name of the feature that's used for the item_id.
"""


@dataclass
class EmbeddingOptions:
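    """Options for the embedding tables created by ``EmbeddingFeatures.from_schema``.

    Attributes
    ----------
    embedding_dims : Optional[Dict[str, int]]
        Explicit embedding dim per feature name, by default None
    embedding_dim_default : Optional[int]
        Dim used for features not listed in ``embedding_dims``, by default 64
    infer_embedding_sizes : bool
        Whether to infer embedding sizes from the features' cardinalities,
        by default False
    infer_embedding_sizes_multiplier : float
        Multiplier used by the size-inference heuristic, by default 2.0
    infer_embeddings_ensure_dim_multiple_of_8 : bool
        Whether inferred dims should be rounded up to a multiple of 8,
        by default False
    embeddings_initializers : optional
        A single initializer, or a dict mapping feature names to initializers,
        by default None (a truncated normal is used)
    embeddings_l2_reg : float
        Factor for L2 regularization of the embeddings looked up in each batch,
        by default 0.0
    combiner : Optional[str]
        Combiner for aggregating multi-hot (sparse) features, by default "mean"
    """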
embedding_dims: Optional[Dict[str, int]] = None
embedding_dim_default: Optional[int] = 64
infer_embedding_sizes: bool = False
infer_embedding_sizes_multiplier: float = 2.0
infer_embeddings_ensure_dim_multiple_of_8: bool = False
embeddings_initializers: Optional[
Union[Dict[str, Callable[[Any], None]], Callable[[Any], None]]
] = None
embeddings_l2_reg: float = 0.0
combiner: Optional[str] = "mean"


@docstring_parameter(
tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING,
embedding_features_parameters=EMBEDDING_FEATURES_PARAMS_DOCSTRING,
)
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class EmbeddingFeatures(TabularBlock):
"""Input block for embedding-lookups for categorical features.
    For multi-hot features, the embeddings will be aggregated into a single tensor using the mean.

    Parameters
----------
{embedding_features_parameters}
{tabular_module_parameters}
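
    Examples
    --------
    A minimal sketch with illustrative names; ``TableConfig`` and ``FeatureConfig``
    come from ``tensorflow.python.tpu.tpu_embedding_v2_utils``:

    >>> table = TableConfig(vocabulary_size=1000, dim=16, name="item_id")  # doctest: +SKIP
    >>> emb_block = EmbeddingFeatures(dict(item_id=FeatureConfig(table)))  # doctest: +SKIP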
"""

    def __init__(
self,
feature_config: Dict[str, "FeatureConfig"],
pre: Optional[BlockType] = None,
post: Optional[BlockType] = None,
aggregation: Optional[TabularAggregationType] = None,
schema: Optional[Schema] = None,
name=None,
add_default_pre=True,
l2_reg: Optional[float] = 0.0,
**kwargs,
):
if add_default_pre:
embedding_pre = [Filter(list(feature_config.keys())), AsSparseFeatures()]
pre = [embedding_pre, pre] if pre else embedding_pre # type: ignore
self.feature_config = feature_config
self.l2_reg = l2_reg
super().__init__(
pre=pre,
post=post,
aggregation=aggregation,
name=name,
schema=schema,
is_input=True,
**kwargs,
)

    @classmethod
def from_schema( # type: ignore
cls,
schema: Schema,
embedding_options: EmbeddingOptions = EmbeddingOptions(),
tags: Optional[TagsType] = None,
max_sequence_length: Optional[int] = None,
**kwargs,
) -> Optional["EmbeddingFeatures"]:
"""Instantiates embedding features from the schema

        Parameters
----------
schema : Schema
            The features schema
embedding_options : EmbeddingOptions, optional
An EmbeddingOptions instance, which allows for a number of
options for the embedding table, by default EmbeddingOptions()
tags : Optional[TagsType], optional
If provided, keeps only features from those tags, by default None
max_sequence_length : Optional[int], optional
Maximum sequence length of sparse features (if any), by default None

        Returns
-------
EmbeddingFeatures
An instance of EmbeddingFeatures block, with the embedding
layers created under-the-hood
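
        Examples
        --------
        A minimal sketch, assuming ``schema`` holds categorical features
        (variable names are illustrative):

        >>> options = EmbeddingOptions(embedding_dim_default=32)  # doctest: +SKIP
        >>> emb_block = EmbeddingFeatures.from_schema(schema, options)  # doctest: +SKIP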
"""
schema_copy = copy(schema)
if tags:
schema_copy = schema_copy.select_by_tag(tags)
embedding_dims = embedding_options.embedding_dims or {}
if embedding_options.infer_embedding_sizes:
inferred_embedding_dims = schema_utils.get_embedding_sizes_from_schema(
schema,
embedding_options.infer_embedding_sizes_multiplier,
embedding_options.infer_embeddings_ensure_dim_multiple_of_8,
)
# Adding inferred embedding dims only for features where the embedding sizes
# were not pre-defined
inferred_embedding_dims = {
k: v for k, v in inferred_embedding_dims.items() if k not in embedding_dims
}
embedding_dims = {**embedding_dims, **inferred_embedding_dims}
initializer_default = tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05)
embeddings_initializer = embedding_options.embeddings_initializers or initializer_default
emb_config = {}
cardinalities = schema_utils.categorical_cardinalities(schema)
for key, cardinality in cardinalities.items():
embedding_size = embedding_dims.get(key, embedding_options.embedding_dim_default)
if isinstance(embeddings_initializer, dict):
emb_initializer = embeddings_initializer.get(key, initializer_default)
else:
emb_initializer = embeddings_initializer
emb_config[key] = (cardinality, embedding_size, emb_initializer)
feature_config: Dict[str, FeatureConfig] = {}
tables: Dict[str, TableConfig] = {}
domains = schema_utils.categorical_domains(schema)
        for name, (vocab_size, dim, emb_initializer) in emb_config.items():
table_name = domains[name]
table = tables.get(table_name, None)
if not table:
table = TableConfig(
vocabulary_size=vocab_size,
dim=dim,
name=table_name,
combiner=embedding_options.combiner,
                    initializer=emb_initializer,
)
tables[table_name] = table
feature_config[name] = FeatureConfig(table)
if not feature_config:
return None
output = cls(
feature_config,
schema=schema_copy,
l2_reg=embedding_options.embeddings_l2_reg,
**kwargs,
)
return output

    def build(self, input_shapes):
self.embedding_tables = {}
tables: Dict[str, TableConfig] = {}
for name, feature in self.feature_config.items():
table: TableConfig = feature.table
if table.name not in tables:
tables[table.name] = table
for name, table in tables.items():
add_fn = (
self.context.add_embedding_weight if hasattr(self, "_context") else self.add_weight
)
self.embedding_tables[name] = add_fn(
name=name,
trainable=True,
initializer=table.initializer,
shape=(table.vocabulary_size, table.dim),
)
if isinstance(input_shapes, dict):
super().build(input_shapes)
else:
tf.keras.layers.Layer.build(self, input_shapes)

    def call(self, inputs: TabularData, **kwargs) -> TabularData:
embedded_outputs = {}
for name, val in inputs.items():
embedded_outputs[name] = self.lookup_feature(name, val)
if self.l2_reg > 0:
self.add_loss(self.l2_reg * tf.reduce_sum(tf.square(embedded_outputs[name])))
return embedded_outputs

    def compute_call_output_shape(self, input_shapes):
batch_size = self.calculate_batch_size_from_input_shapes(input_shapes)
output_shapes = {}
for name, val in input_shapes.items():
output_shapes[name] = tf.TensorShape([batch_size, self.feature_config[name].table.dim])
return output_shapes

    def lookup_feature(self, name, val, output_sequence=False):
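        """Looks up the embeddings for one feature.

        Sparse inputs are combined according to ``table.combiner``; dense inputs
        use a plain ``tf.gather``. When ``output_sequence`` is False and the input
        is 2-D, only the first element of the second dim is looked up
        (non-sequential multi-hot features).
        """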
dtype = backend.dtype(val)
if dtype != "int32" and dtype != "int64":
val = tf.cast(val, "int32")
table: TableConfig = self.feature_config[name].table
table_var = self.embedding_tables[table.name]
if isinstance(val, tf.SparseTensor):
out = tf.nn.safe_embedding_lookup_sparse(table_var, val, None, combiner=table.combiner)
else:
if output_sequence:
out = tf.gather(table_var, tf.cast(val, tf.int32))
else:
if len(val.shape) > 1:
# TODO: Check if it is correct to retrieve only the 1st element
# of second dim for non-sequential multi-hot categ features
out = tf.gather(table_var, tf.cast(val, tf.int32)[:, 0])
else:
out = tf.gather(table_var, tf.cast(val, tf.int32))
if self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype:
# Instead of casting the variable as in most layers, cast the output, as
# this is mathematically equivalent but is faster.
out = tf.cast(out, self._dtype_policy.compute_dtype)
return out

    def table_config(self, feature_name: str):
return self.feature_config[feature_name].table

    def get_embedding_table(self, table_name: Union[str, Tags], l2_normalization: bool = False):
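        """Returns the embedding-table weights for a table name or a (unique)
        feature tag, optionally L2-normalized along the embedding axis."""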
if isinstance(table_name, Tags):
feature_names = self.schema.select_by_tag(table_name).column_names
if len(feature_names) == 1:
table_name = feature_names[0]
elif len(feature_names) > 1:
raise ValueError(
f"There is more than one feature associated to the tag {table_name}"
)
else:
raise ValueError(f"Could not find a feature associated to the tag {table_name}")
embeddings = self.embedding_tables[table_name]
if l2_normalization:
embeddings = tf.linalg.l2_normalize(embeddings, axis=-1)
return embeddings

    def embedding_table_df(
self, table_name: Union[str, Tags], l2_normalization: bool = False, gpu: bool = True
):
"""Retrieves a dataframe with the embedding table

        Parameters
----------
table_name : Union[str, Tags]
Tag or name of the embedding table
        l2_normalization : bool, optional
            Whether L2-normalization should be applied to the
            embeddings (a common approach for Matrix Factorization
            and Retrieval models in general), by default False
        gpu : bool, optional
            Whether to use the GPU, by default True

        Returns
        -------
        Union[pd.DataFrame, cudf.DataFrame]
            Returns a dataframe (cudf or pandas), depending on the ``gpu`` flag
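
        Examples
        --------
        A minimal sketch (the ``item_id`` table name is illustrative):

        >>> df = emb_block.embedding_table_df("item_id", gpu=False)  # doctest: +SKIP
        >>> df.shape  # (vocabulary_size, dim)  # doctest: +SKIP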
"""
embeddings = self.get_embedding_table(table_name, l2_normalization)
if gpu:
import cudf
import cupy
            # Note: It is not possible to convert TensorFlow tensors to a cudf DataFrame
            # directly via DLPack (as in the commented example below) because
            # cudf.from_dlpack() expects the 2D tensor to be in Fortran order
            # (column-major), which TF does not support
            # (https://github.com/rapidsai/cudf/issues/10754).
# df = cudf.from_dlpack(to_dlpack(tf.convert_to_tensor(embeddings)))
embeddings_cupy = cupy.fromDlpack(to_dlpack(tf.convert_to_tensor(embeddings)))
df = cudf.DataFrame(embeddings_cupy)
df.columns = [str(col) for col in list(df.columns)]
            df = df.set_index(cudf.RangeIndex(0, embeddings.shape[0]))
else:
import pandas as pd
df = pd.DataFrame(embeddings.numpy())
df.columns = [str(col) for col in list(df.columns)]
            df = df.set_index(pd.RangeIndex(0, embeddings.shape[0]))
return df

    def embedding_table_dataset(
        self, table_name: Union[str, Tags], l2_normalization: bool = False, gpu: bool = True
    ) -> merlin.io.Dataset:
"""Creates a Dataset for the embedding table

        Parameters
----------
table_name : Union[str, Tags]
Tag or name of the embedding table
        l2_normalization : bool, optional
            Whether L2-normalization should be applied to the
            embeddings (a common approach for Matrix Factorization
            and Retrieval models in general), by default False
        gpu : bool, optional
            Whether to use the GPU, by default True

        Returns
-------
merlin.io.Dataset
Returns a Dataset with the embeddings
"""
return merlin.io.Dataset(self.embedding_table_df(table_name, l2_normalization, gpu))

    def export_embedding_table(
        self,
        table_name: Union[str, Tags],
        export_path: str,
        l2_normalization: bool = False,
        gpu: bool = True,
    ):
"""Exports the embedding table to parquet file
Parameters
----------
table_name : Union[str, Tags]
Tag or name of the embedding table
export_path : str
Path for the generated parquet file
        l2_normalization : bool, optional
            Whether L2-normalization should be applied to the
            embeddings (a common approach for Matrix Factorization
            and Retrieval models in general), by default False
        gpu : bool, optional
            Whether to use the GPU, by default True
"""
df = self.embedding_table_df(table_name, l2_normalization, gpu=gpu)
df.to_parquet(export_path)

    def get_config(self):
config = super().get_config()
feature_configs = {}
for key, val in self.feature_config.items():
feature_config_dict = dict(name=val.name, max_sequence_length=val.max_sequence_length)
feature_config_dict["table"] = serialize_table_config(val.table)
feature_configs[key] = feature_config_dict
config["feature_config"] = feature_configs
return config

    @classmethod
def from_config(cls, config):
# Deserialize feature_config
feature_configs, table_configs = {}, {}
for key, val in config["feature_config"].items():
feature_params = deepcopy(val)
table_params = feature_params["table"]
if "name" in table_configs:
feature_params["table"] = table_configs["name"]
else:
table = deserialize_table_config(table_params)
if table.name:
table_configs[table.name] = table
feature_params["table"] = table
feature_configs[key] = FeatureConfig(**feature_params)
config["feature_config"] = feature_configs
        # Set `add_default_pre` to False since pre will be provided from the config
config["add_default_pre"] = False
return super().from_config(config)


@docstring_parameter(
tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING,
embedding_features_parameters=EMBEDDING_FEATURES_PARAMS_DOCSTRING,
)
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class SequenceEmbeddingFeatures(EmbeddingFeatures):
"""Input block for embedding-lookups for categorical features. This module produces 3-D tensors,
this is useful for sequential models like transformers.
Parameters
----------
{embedding_features_parameters}
padding_idx: int
The symbol to use for padding.
{tabular_module_parameters}
"""

    def __init__(
self,
feature_config: Dict[str, FeatureConfig],
max_seq_length: Optional[int] = None,
mask_zero: bool = True,
padding_idx: int = 0,
pre: Optional[BlockType] = None,
post: Optional[BlockType] = None,
aggregation: Optional[TabularAggregationType] = None,
schema: Optional[Schema] = None,
name: Optional[str] = None,
add_default_pre=True,
**kwargs,
):
if add_default_pre:
embedding_pre = [Filter(list(feature_config.keys())), AsDenseFeatures(max_seq_length)]
pre = [embedding_pre, pre] if pre else embedding_pre # type: ignore
super().__init__(
feature_config=feature_config,
pre=pre,
post=post,
aggregation=aggregation,
name=name,
schema=schema,
add_default_pre=False,
**kwargs,
)
self.padding_idx = padding_idx
self.mask_zero = mask_zero

    def lookup_feature(self, name, val, **kwargs):
return super(SequenceEmbeddingFeatures, self).lookup_feature(
name, val, output_sequence=True
)

    def compute_call_output_shape(self, input_shapes):
batch_size = self.calculate_batch_size_from_input_shapes(input_shapes)
sequence_length = input_shapes[list(self.feature_config.keys())[0]][1]
output_shapes = {}
for name, val in input_shapes.items():
output_shapes[name] = tf.TensorShape(
[batch_size, sequence_length, self.feature_config[name].table.dim]
)
return output_shapes

    def compute_mask(self, inputs, mask=None):
if not self.mask_zero:
return None
outputs = {}
for key, val in inputs.items():
outputs[key] = tf.not_equal(val, self.padding_idx)
return outputs

    def get_config(self):
config = super().get_config()
config["mask_zero"] = self.mask_zero
config["padding_idx"] = self.padding_idx
return config


def ContinuousEmbedding(
inputs: Block,
embedding_block: Block,
aggregation=None,
continuous_aggregation="concat",
name: str = "continuous",
**kwargs,
) -> SequentialBlock:
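    """Branches ``inputs`` so that the continuous features are filtered,
    aggregated (by default with "concat") and passed through ``embedding_block``,
    while the remaining features are kept as-is (``add_rest=True``). The
    branches can optionally be combined with ``aggregation``.
    """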
continuous_embedding = Filter(Tags.CONTINUOUS, aggregation=continuous_aggregation).connect(
embedding_block
)
outputs = inputs.connect_branch(
continuous_embedding.as_tabular(name), add_rest=True, aggregation=aggregation, **kwargs
)
return outputs


def serialize_table_config(table_config: TableConfig) -> Dict[str, Any]:
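    """Serializes a TableConfig into a plain dict, converting the Keras
    initializer/optimizer objects into their serialized forms."""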
table = deepcopy(table_config.__dict__)
if "initializer" in table:
table["initializer"] = tf.keras.initializers.serialize(table["initializer"])
if "optimizer" in table:
table["optimizer"] = tf.keras.optimizers.serialize(table["optimizer"])
return table


def deserialize_table_config(table_params: Dict[str, Any]) -> TableConfig:
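    """Re-creates a TableConfig from a dict produced by ``serialize_table_config``."""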
if "initializer" in table_params and table_params["initializer"]:
table_params["initializer"] = tf.keras.initializers.deserialize(table_params["initializer"])
if "optimizer" in table_params and table_params["optimizer"]:
table_params["optimizer"] = tf.keras.optimizers.deserialize(table_params["optimizer"])
table = TableConfig(**table_params)
return table


def serialize_feature_config(feature_config: Dict[str, FeatureConfig]) -> Dict[str, Any]:
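    """Serializes a dict of FeatureConfig (feature name -> config) into plain
    dicts, including each feature's (possibly shared) TableConfig."""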
outputs = {}
for key, val in feature_config.items():
feature_config_dict = dict(name=val.name, max_sequence_length=val.max_sequence_length)
feature_config_dict["table"] = serialize_table_config(feature_config_dict["table"])
outputs[key] = feature_config_dict
return outputs