#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from functools import reduce
from typing import Dict, List, Optional, Union
import tensorflow as tf
from merlin_standard_lib import Registry, RegistryMixin, Schema
from merlin_standard_lib.utils.doc_utils import docstring_parameter
from transformers4rec.config.schema import SchemaMixin
from ..block.base import Block, SequentialBlock
from ..typing import TabularData, TensorOrTabularData
from ..utils.tf_utils import (
calculate_batch_size_from_input_shapes,
maybe_deserialize_keras_objects,
maybe_serialize_keras_objects,
)
# Registry for `TabularTransformation` classes; this is the object returned by
# `TabularTransformation.registry()`.
tabular_transformation_registry: Registry = Registry.class_registry("tf.tabular_transformations")
# Registry for `TabularAggregation` classes; this is the object returned by
# `TabularAggregation.registry()`.
tabular_aggregation_registry: Registry = Registry.class_registry("tf.tabular_aggregations")
class TabularTransformation(
    SchemaMixin, tf.keras.layers.Layer, RegistryMixin["TabularTransformation"], ABC
):
    """Base layer for operations that take `TabularData` and return `TabularData`.

    Subclasses are expected to override `call`; the implementation here only raises.
    """

    @classmethod
    def registry(cls) -> Registry:
        """Return the module-level registry used for tabular transformations."""
        return tabular_transformation_registry

    def call(self, inputs: TabularData, **kwargs) -> TabularData:
        """Transform `inputs`.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()
class TabularAggregation(
    SchemaMixin, tf.keras.layers.Layer, RegistryMixin["TabularAggregation"], ABC
):
    """Base layer for operations that reduce `TabularData` to a single `Tensor`.

    Subclasses are expected to override `call`; the helpers below deal with the shapes of the
    features being aggregated.
    """

    @classmethod
    def registry(cls) -> Registry:
        """Return the module-level registry used for tabular aggregations."""
        return tabular_aggregation_registry

    def call(self, inputs: TabularData, **kwargs) -> tf.Tensor:
        """Aggregate `inputs`.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()

    def _expand_non_sequential_features(self, inputs: TabularData) -> TabularData:
        """Tile non-sequential features along a new second axis.

        Features with fewer than three dimensions are treated as non-sequential. When at least
        one sequential feature is present, each non-sequential feature gets a new axis at
        position 1 and is repeated `sequence_length` times along it.

        Note that `inputs` is updated in place and the same dictionary is returned.
        """
        shapes = {name: tensor.shape for name, tensor in inputs.items()}
        seq_features_shapes, sequence_length = self._get_seq_features_shapes(shapes)
        if not seq_features_shapes:
            return inputs

        multiples = (1, sequence_length, 1)
        for fname in set(inputs.keys()) - set(seq_features_shapes.keys()):
            # Insert the sequence axis, then repeat the feature at every position.
            expanded = tf.expand_dims(inputs[fname], 1)
            inputs[fname] = tf.tile(expanded, multiples)

        return inputs

    def _get_seq_features_shapes(self, inputs_sizes: Dict[str, tf.TensorShape]):
        """Collect the leading two dims of every feature that has three or more dims.

        Returns
        -------
        tuple
            `(seq_features_shapes, sequence_length)`; `sequence_length` is 0 when no feature
            has three or more dimensions.

        Raises
        ------
        ValueError
            If the collected features do not all share the same first two dimensions.
        """
        seq_features_shapes = {
            fname: tuple(fshape[:2]) for fname, fshape in inputs_sizes.items() if len(fshape) >= 3
        }
        if not seq_features_shapes:
            return seq_features_shapes, 0

        distinct_shapes = set(seq_features_shapes.values())
        if len(distinct_shapes) > 1:
            raise ValueError(
                "All sequential features must share the same shape in the first two dims "
                "(batch_size, seq_length): {}".format(seq_features_shapes)
            )

        first_shape = next(iter(seq_features_shapes.values()))
        return seq_features_shapes, first_shape[1]

    def _check_concat_shapes(self, inputs: TabularData):
        """Raise if the features differ in any dimension other than the last one."""
        input_sizes = {name: tensor.shape for name, tensor in inputs.items()}
        leading_dims = {tuple(shape[:-1]) for shape in input_sizes.values()}
        if len(leading_dims) > 1:
            raise Exception(
                "All features dimensions except the last one must match: {}".format(input_sizes)
            )

    def _get_agg_output_size(self, input_size, agg_dim):
        """Build the output shape of an aggregation whose last dimension is `agg_dim`.

        The result is `(batch_size, sequence_length, agg_dim)` when `input_size` contains a
        feature with three or more dimensions, and `(batch_size, agg_dim)` otherwise.
        """
        batch_size = calculate_batch_size_from_input_shapes(input_size)
        seq_features_shapes, sequence_length = self._get_seq_features_shapes(input_size)
        if seq_features_shapes:
            return (batch_size, sequence_length, agg_dim)
        return (batch_size, agg_dim)
# A single transformation: either a `TabularTransformation` instance or a string
# (strings are handed to `TabularTransformation.parse`).
TabularTransformationType = Union[str, TabularTransformation]
# One transformation or a list of them.
TabularTransformationsType = Union[TabularTransformationType, List[TabularTransformationType]]
# A single aggregation: either a `TabularAggregation` instance or a string
# (strings are handed to `TabularAggregation.parse`).
TabularAggregationType = Union[str, TabularAggregation]
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class SequentialTabularTransformations(SequentialBlock):
    """A sequential container, modules will be added to it in the order they are passed in.

    Parameters
    ----------
    transformation: TabularTransformationsType
        A single transformation (instance or string) or a list/tuple of them. A list or tuple
        whose only element is itself a list is unwrapped. Every item is passed through
        `TabularTransformation.parse` and the results are called in order.
    """

    def __init__(self, transformation: TabularTransformationsType):
        # Wrap single items first: calling `len()` or indexing on a bare Keras layer raises
        # `TypeError`, so the sequence-only checks must not run before this one.
        if not isinstance(transformation, (list, tuple)):
            transformation = [transformation]
        elif len(transformation) == 1 and isinstance(transformation[0], list):
            transformation = transformation[0]

        super().__init__([TabularTransformation.parse(t) for t in transformation])

    def append(self, transformation):
        """Parse `transformation` and append it to `self.transformations`.

        NOTE(review): `self.transformations` is not defined in this file; presumably it is
        provided by `SequentialBlock` -- confirm.
        """
        self.transformations.append(TabularTransformation.parse(transformation))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild the container by deserializing every value of `config`.

        Parameters
        ----------
        config: dict
            Mapping whose values are serialized Keras objects.
        custom_objects: dict, optional
            Forwarded to `tf.keras.utils.deserialize_keras_object`.
        """
        layers = [
            tf.keras.utils.deserialize_keras_object(conf, custom_objects=custom_objects)
            for conf in config.values()
        ]

        return SequentialTabularTransformations(layers)
# Shared parameter documentation. It is passed to `docstring_parameter` as
# `tabular_module_parameters`, matching the `{tabular_module_parameters}` placeholder used in
# docstrings elsewhere in this module.
TABULAR_MODULE_PARAMS_DOCSTRING = """
pre: Union[str, TabularTransformation, List[str], List[TabularTransformation]], optional
Transformations to apply on the inputs when the module is called (so **before** `call`).
post: Union[str, TabularTransformation, List[str], List[TabularTransformation]], optional
Transformations to apply on the inputs after the module is called (so **after** `call`).
aggregation: Union[str, TabularAggregation], optional
Aggregation to apply after processing the `call`-method to output a single Tensor.
Next to providing a class that extends TabularAggregation, it's also possible to provide
the name that the class is registered in the `tabular_aggregation_registry`. Out of the box
this contains: "concat", "stack", "element-wise-sum" &
"element-wise-sum-item-multi".
schema: Optional[DatasetSchema]
DatasetSchema containing the columns used in this block.
name: Optional[str]
Name of the layer.
"""
@docstring_parameter(tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING)
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class TabularBlock(Block):
    """Layer that's specialized for tabular-data by integrating many often used operations.

    Note, when extending this class, typically you want to overwrite the
    `compute_call_output_shape` method instead of the normal `compute_output_shape`. This is
    because a Block can contain pre- and post-processing and the output-shapes are handled
    automatically in `compute_output_shape`. The output of `compute_call_output_shape` should
    be the shape that's outputted by the `call`-method.

    Parameters
    ----------
    {tabular_module_parameters}
    """

    def __init__(
        self,
        pre: Optional[TabularTransformationsType] = None,
        post: Optional[TabularTransformationsType] = None,
        aggregation: Optional[TabularAggregationType] = None,
        schema: Optional[Schema] = None,
        name: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        self.input_size = None
        self.set_pre(pre)
        self.set_post(post)
        self.set_aggregation(aggregation)

        if schema:
            self.set_schema(schema)

    @classmethod
    def from_schema(cls, schema: Schema, tags=None, **kwargs) -> Optional["TabularBlock"]:
        """Instantiate a TabularBlock from a schema.

        Parameters
        ----------
        schema: Schema
            Schema to take the column names from; it is copied before being filtered.
        tags: optional
            When given, only the columns selected by `select_by_tag(tags)` are kept.
        kwargs
            Forwarded to `from_features`.

        Returns
        -------
        Optional[TabularBlock]
            None when no column is left after the optional tag selection.
        """
        schema_copy = schema.copy()
        if tags:
            schema_copy = schema_copy.select_by_tag(tags)

        if not schema_copy.column_names:
            return None

        return cls.from_features(schema_copy.column_names, schema=schema_copy, **kwargs)

    @classmethod
    @docstring_parameter(tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING, extra_padding=4)
    def from_features(
        cls,
        features: List[str],
        pre: Optional[TabularTransformationsType] = None,
        post: Optional[TabularTransformationsType] = None,
        aggregation: Optional[TabularAggregationType] = None,
        name=None,
        **kwargs,
    ) -> "TabularBlock":
        """
        Initializes a TabularBlock instance where the contents of features will be filtered out

        Parameters
        ----------
        features: List[str]
            A list of feature-names that will be used as the first pre-processing op to filter out
            all other features not in this list.
        {tabular_module_parameters}

        Returns
        -------
        TabularBlock
        """
        filter_features = FilterFeatures(features)
        if not pre:
            pre = filter_features  # type: ignore
        elif isinstance(pre, (list, tuple)):
            # Splice a list of transformations in after the filter. Nesting it as a single
            # element would hand a list to `TabularTransformation.parse`.
            pre = [filter_features, *pre]  # type: ignore
        else:
            pre = [filter_features, pre]  # type: ignore

        return cls(pre=pre, post=post, aggregation=aggregation, name=name, **kwargs)

    def pre_call(
        self, inputs: TabularData, transformations: Optional[TabularTransformationsType] = None
    ) -> TabularData:
        """Method that's typically called before the forward method for pre-processing.

        Parameters
        ----------
        inputs: TabularData
            input-data, typically the raw inputs of the block.
        transformations: TabularTransformationsType, optional
            Transformations to apply; falls back to `self.pre` when not provided.

        Returns
        -------
        TabularData
        """
        return self._maybe_apply_transformations(
            inputs, transformations=transformations or self.pre
        )

    def call(self, inputs: TabularData, **kwargs) -> TabularData:
        """Return `inputs` unchanged; subclasses override this with the actual computation."""
        return inputs

    def post_call(
        self,
        inputs: TabularData,
        transformations: Optional[TabularTransformationsType] = None,
        merge_with: Optional[Union["TabularBlock", List["TabularBlock"]]] = None,
        aggregation: Optional[TabularAggregationType] = None,
    ) -> TensorOrTabularData:
        """Method that's typically called after the forward method for post-processing.

        Parameters
        ----------
        inputs: TabularData
            input-data, typically the output of the forward method.
        transformations: TabularTransformationType, optional
            Transformations to apply on the input data; falls back to `self.post`.
        merge_with: Union[TabularBlock, List[TabularBlock]], optional
            Other TabularBlock's to call and merge the outputs with. Items that are not
            callable are merged into the outputs as they are.
        aggregation: TabularAggregationType, optional
            Aggregation to aggregate the output to a single Tensor; falls back to
            `self.aggregation`.

        Returns
        -------
        TensorOrTabularData (Tensor when aggregation is set, else TabularData)
        """
        _aggregation: Optional[TabularAggregation] = None
        if aggregation:
            _aggregation = TabularAggregation.parse(aggregation)
        _aggregation = _aggregation or getattr(self, "aggregation", None)

        outputs = inputs
        if merge_with:
            if not isinstance(merge_with, list):
                merge_with = [merge_with]
            # Merge into a shallow copy so the dictionary passed in by the caller is not
            # modified. Each layer still sees what was merged before it, as previously.
            outputs = dict(inputs)
            for layer_or_tensor in merge_with:
                to_add = layer_or_tensor(outputs) if callable(layer_or_tensor) else layer_or_tensor
                outputs.update(to_add)

        outputs = self._maybe_apply_transformations(
            outputs, transformations=transformations or self.post
        )

        if _aggregation:
            schema = getattr(self, "schema", None)
            _aggregation.set_schema(schema)
            return _aggregation(outputs)

        return outputs

    def __call__(  # type: ignore
        self,
        inputs: TabularData,
        *args,
        pre: Optional[TabularTransformationsType] = None,
        post: Optional[TabularTransformationsType] = None,
        merge_with: Optional[Union["TabularBlock", List["TabularBlock"]]] = None,
        aggregation: Optional[TabularAggregationType] = None,
        **kwargs,
    ) -> TensorOrTabularData:
        """We overwrite the call method in order to be able to do pre- and post-processing.

        Parameters
        ----------
        inputs: TabularData
            Input TabularData.
        pre: TabularTransformationsType, optional
            Transformations to apply before calling the forward method. If pre is None, this
            method will check if `self.pre` is set.
        post: TabularTransformationsType, optional
            Transformations to apply after calling the forward method. If post is None, this
            method will check if `self.post` is set.
        merge_with: Union[TabularBlock, List[TabularBlock]]
            Other TabularBlock's to call and merge the outputs with.
        aggregation: TabularAggregationType, optional
            Aggregation to aggregate the output to a single Tensor.

        Returns
        -------
        TensorOrTabularData (Tensor when aggregation is set, else TabularData)
        """
        inputs = self.pre_call(inputs, transformations=pre)

        # This will call the `call` method implemented by this class or a subclass.
        outputs = super().__call__(inputs, *args, **kwargs)  # noqa

        # Post-processing only applies to dictionary outputs; anything else is returned as is.
        if isinstance(outputs, dict):
            outputs = self.post_call(
                outputs, transformations=post, merge_with=merge_with, aggregation=aggregation
            )

        return outputs

    def _maybe_apply_transformations(
        self,
        inputs: TabularData,
        transformations: Optional[TabularTransformationsType] = None,
    ) -> TabularData:
        """Apply transformations to the inputs if these are defined.

        Parameters
        ----------
        inputs: TabularData
            Data to transform.
        transformations: TabularTransformationsType, optional
            Passed through `TabularTransformation.parse` and then called on `inputs`.

        Returns
        -------
        TabularData
            The transformed data, or `inputs` itself when `transformations` is falsy.
        """
        if transformations:
            transformations = TabularTransformation.parse(transformations)
            return transformations(inputs)

        return inputs

    def compute_call_output_shape(self, input_shapes):
        """Shape produced by `call`; the identity here, matching the default `call`."""
        return input_shapes

    def compute_output_shape(self, input_shapes):
        """Chain the shapes through `pre`, `call`, `post` and the aggregation."""
        if self.pre:
            input_shapes = self.pre.compute_output_shape(input_shapes)

        output_shapes = self._check_post_output_size(self.compute_call_output_shape(input_shapes))

        return output_shapes

    def get_config(self):
        """Return the layer config, including `pre`, `post`, `aggregation` and the schema."""
        config = super(TabularBlock, self).get_config()
        config = maybe_serialize_keras_objects(self, config, ["pre", "post", "aggregation"])

        if self.schema:
            config["schema"] = self.schema.to_json()

        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the block from the output of `get_config`."""
        config = maybe_deserialize_keras_objects(config, ["pre", "post", "aggregation"])
        if "schema" in config:
            config["schema"] = Schema().from_json(config["schema"])

        return super().from_config(config)

    def _check_post_output_size(self, input_shapes):
        """Apply the shape changes of `post` and the aggregation to dictionary shapes."""
        output_shapes = input_shapes

        if isinstance(output_shapes, dict):
            if self.post:
                output_shapes = self.post.compute_output_shape(output_shapes)
            if self.aggregation:
                schema = getattr(self, "schema", None)
                self.aggregation.set_schema(schema)
                output_shapes = self.aggregation.compute_output_shape(output_shapes)

        return output_shapes

    def apply_to_all(self, inputs, columns_to_filter=None):
        """Call this block on every element of `inputs` via `tf.nest.map_structure`.

        Parameters
        ----------
        inputs
            Structure to map over.
        columns_to_filter: optional
            When given, `inputs` is first reduced with `FilterFeatures(columns_to_filter)`.
        """
        if columns_to_filter:
            inputs = FilterFeatures(columns_to_filter)(inputs)
        outputs = tf.nest.map_structure(self, inputs)

        return outputs

    def set_schema(self, schema=None):
        """Set the schema on `pre`, `post`, the aggregation and finally on the block itself."""
        self._maybe_set_schema(self.pre, schema)
        self._maybe_set_schema(self.post, schema)
        self._maybe_set_schema(self.aggregation, schema)

        return super().set_schema(schema)

    @staticmethod
    def _as_sequential_transformations(
        value: Optional[TabularTransformationsType],
    ) -> Optional[SequentialTabularTransformations]:
        """Normalize `value` to a `SequentialTabularTransformations`, or None.

        Strings, layers, lists and tuples are wrapped. Falsy values and unsupported types
        yield None.
        """
        if not value:
            return None
        if isinstance(value, SequentialTabularTransformations):
            return value
        # Strings and tuples are accepted too: the public type allows strings, and silently
        # dropping them would disable the requested transformation without any error.
        if isinstance(value, (str, tf.keras.layers.Layer, list, tuple)):
            return SequentialTabularTransformations(value)
        return None

    def set_pre(self, value: Optional[TabularTransformationsType]):
        """Set the transformations that are applied before `call`.

        Parameters
        ----------
        value: TabularTransformationsType, optional
            A string, a layer, or a list/tuple of these. Falsy values clear `pre`.
        """
        self._pre: Optional[
            SequentialTabularTransformations
        ] = self._as_sequential_transformations(value)

    @property
    def pre(self) -> Optional[SequentialTabularTransformations]:
        """
        Returns
        -------
        SequentialTabularTransformations, optional
        """
        return self._pre

    @property
    def post(self) -> Optional[SequentialTabularTransformations]:
        """
        Returns
        -------
        SequentialTabularTransformations, optional
        """
        return self._post

    def set_post(self, value: Optional[TabularTransformationsType]):
        """Set the transformations that are applied after `call`.

        Parameters
        ----------
        value: TabularTransformationsType, optional
            A string, a layer, or a list/tuple of these. Falsy values clear `post`.
        """
        self._post: Optional[
            SequentialTabularTransformations
        ] = self._as_sequential_transformations(value)

    @property
    def aggregation(self) -> Optional[TabularAggregation]:
        """
        Returns
        -------
        TabularAggregation, optional
        """
        return self._aggregation

    def set_aggregation(self, value: Optional[Union[str, TabularAggregation]]):
        """Set the aggregation that is applied to the outputs.

        Parameters
        ----------
        value: Union[str, TabularAggregation], optional
            Passed through `TabularAggregation.parse`. Falsy values clear the aggregation.
        """
        if value:
            self._aggregation: Optional[TabularAggregation] = TabularAggregation.parse(value)
        else:
            self._aggregation = None

    def repr_ignore(self):
        """Return an empty list; subclasses may override."""
        return []

    def repr_add(self):
        """Return an empty list; subclasses may override."""
        return []

    def __rrshift__(self, other):
        """Support `other >> self` by delegating to `right_shift_layer`."""
        # Imported at call time, as in the original code.
        from ..block.base import right_shift_layer

        return right_shift_layer(self, other)
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class FilterFeatures(TabularTransformation):
    """Transformation that filters out certain features from `TabularData`.

    Parameters
    ----------
    to_include: List[str]
        List of features to include in the result of calling the module
    pop: bool
        When True, the features that are returned are also removed from the dictionary that
        was passed in.
    """

    def __init__(
        self, to_include, trainable=False, name=None, dtype=None, dynamic=False, pop=False, **kwargs
    ):
        super().__init__(trainable, name, dtype, dynamic, **kwargs)
        self.to_include = to_include
        self.pop = pop

    def call(self, inputs: TabularData, **kwargs) -> TabularData:
        """Filter out features from inputs.

        Parameters
        ----------
        inputs: TabularData
            Input dictionary containing features to filter.

        Returns
        -------
        TabularData
            Filtered TabularData that only contains the feature-names in `self.to_include`.
        """
        assert isinstance(inputs, dict), "Inputs needs to be a dict"

        outputs = {k: v for k, v in inputs.items() if k in self.to_include}
        if self.pop:
            # Move the selected features out of the caller's dictionary.
            for key in outputs.keys():
                inputs.pop(key)

        return outputs

    def compute_output_shape(self, input_shape):
        """Keep only the shapes of the features listed in `self.to_include`."""
        return {k: v for k, v in input_shape.items() if k in self.to_include}

    def get_config(self):
        """Return the layer config, including `to_include` and `pop`."""
        config = super().get_config()
        config["to_include"] = self.to_include
        # `pop` is a constructor argument, so it has to be stored as well; otherwise a
        # get_config/from_config round trip silently resets it to False.
        config["pop"] = self.pop

        return config
@docstring_parameter(tabular_module_parameters=TABULAR_MODULE_PARAMS_DOCSTRING)
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class MergeTabular(TabularBlock):
    """Merge multiple TabularBlock's into a single output of TabularData.

    Parameters
    ----------
    blocks_to_merge: Union[TabularBlock, Dict[str, TabularBlock]]
        TabularBlocks to merge into, this can also be one or multiple dictionaries keyed by the
        name the module should have.
    {tabular_module_parameters}
    """

    def __init__(
        self,
        *blocks_to_merge: Union[TabularBlock, Dict[str, TabularBlock]],
        pre: Optional[TabularTransformationsType] = None,
        post: Optional[TabularTransformationsType] = None,
        aggregation: Optional[TabularAggregationType] = None,
        schema: Optional[Schema] = None,
        name: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(
            pre=pre, post=post, aggregation=aggregation, schema=schema, name=name, **kwargs
        )

        self.to_merge: Union[List[TabularBlock], Dict[str, TabularBlock]]
        # An empty argument list satisfies `all(...)` vacuously, so it is ruled out explicitly
        # and ends up in the ValueError below instead of failing inside a reduction.
        if blocks_to_merge and all(isinstance(x, dict) for x in blocks_to_merge):
            # Later dictionaries win on duplicate keys. A fresh dict is built so the caller's
            # dictionaries are never aliased.
            to_merge: Dict[str, TabularBlock] = {}
            for named_blocks in blocks_to_merge:
                to_merge.update(named_blocks)  # type: ignore
            self.to_merge = to_merge
        elif blocks_to_merge and all(
            isinstance(x, tf.keras.layers.Layer) for x in blocks_to_merge
        ):
            self.to_merge = list(blocks_to_merge)  # type: ignore
        else:
            raise ValueError(
                "Please provide one or multiple layer's to merge or "
                f"dictionaries of layer. got: {blocks_to_merge}"
            )

        # Merge schemas if necessary.
        if not schema and all(getattr(m, "schema", False) for m in self.merge_values):
            s = reduce(lambda a, b: a + b, [m.schema for m in self.merge_values])  # type: ignore
            self.set_schema(s)

    @property
    def merge_values(self) -> List[tf.keras.layers.Layer]:
        """The blocks to merge, as a list."""
        if isinstance(self.to_merge, dict):
            return list(self.to_merge.values())

        return self.to_merge

    @property
    def to_merge_dict(self) -> Dict[str, tf.keras.layers.Layer]:
        """The blocks to merge, as a dictionary; list entries are keyed by their index."""
        if isinstance(self.to_merge, dict):
            return self.to_merge

        return {str(i): m for i, m in enumerate(self.to_merge)}

    def call(self, inputs, **kwargs):
        """Call every block on `inputs` and merge the resulting dictionaries.

        Blocks are called in order; on duplicate keys the output of the later block wins.
        """
        assert isinstance(inputs, dict), "Inputs needs to be a dict"

        outputs = {}
        for layer in self.merge_values:
            outputs.update(layer(inputs))

        return outputs

    def compute_call_output_shape(self, input_shape):
        """Merge the output shapes of every block, mirroring `call`."""
        output_shapes = {}

        for layer in self.merge_values:
            output_shapes.update(layer.compute_output_shape(input_shape))

        return output_shapes

    def get_config(self):
        """Return the layer config.

        NOTE(review): this asks for an attribute named `merge_layers`, but this class stores
        its blocks in `to_merge` and never sets `merge_layers` -- confirm whether the merged
        blocks are meant to be part of the config.
        """
        return maybe_serialize_keras_objects(
            self, super(MergeTabular, self).get_config(), ["merge_layers"]
        )
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class AsTabular(tf.keras.layers.Layer):
    """Converts a Tensor to TabularData by converting it to a dictionary.

    Parameters
    ----------
    output_name: str
        Name that should be used as the key in the output dictionary.
    name: str
        Name of the layer.
    """

    def __init__(self, output_name: str, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.output_name = output_name

    def call(self, inputs, **kwargs):
        """Return a one-entry dictionary that maps `self.output_name` to `inputs`."""
        return {self.output_name: inputs}

    def get_config(self):
        """Return the layer config, including `output_name`."""
        config = super(AsTabular, self).get_config()
        config["output_name"] = self.output_name

        return config
def merge_tabular(self, other, aggregation=None, **kwargs):
    """Combine `self` and `other` in a new `MergeTabular`.

    Parameters
    ----------
    other
        The block to merge with `self`.
    aggregation: optional
        Forwarded to `MergeTabular` as its `aggregation`.
    kwargs
        Forwarded to `MergeTabular`.

    Returns
    -------
    MergeTabular
    """
    merged = MergeTabular(self, other, aggregation=aggregation, **kwargs)
    return merged


# Attach the function to TabularBlock, both as the `merge` method and as the `+` operator.
TabularBlock.merge = merge_tabular
TabularBlock.__add__ = merge_tabular