#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from merlin_standard_lib import Schema

from ...config.schema import requires_schema
from ..typing import TabularData
from ..utils.torch_utils import calculate_batch_size_from_input_size
from .base import TabularAggregation, tabular_aggregation_registry


@tabular_aggregation_registry.register("concat")
class ConcatFeatures(TabularAggregation):
    """Aggregation by concatenating all values in TabularData along the last dimension;
    all non-sequential values are first expanded to a sequence.

    The output of this concatenation has 3 dimensions.
    """

    def forward(
        self,
        inputs: TabularData,
    ) -> torch.Tensor:
        self._expand_non_sequential_features(inputs)
        self._check_concat_shapes(inputs)

        tensors = []
        for name in sorted(inputs.keys()):
            val = inputs[name]
            tensors.append(val)

        return torch.cat(tensors, dim=-1)

    def forward_output_size(self, input_size):
        agg_dim = sum([i[-1] for i in input_size.values()])
        output_size = self._get_agg_output_size(input_size, agg_dim)
        return output_size
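
# Usage sketch (hypothetical feature names and shapes, for illustration only):
# "concat" joins features on the last dimension, so the output width is the
# sum of the individual feature widths.
#
#     aggregation = ConcatFeatures()
#     inputs = {
#         "category_emb": torch.randn(32, 20, 16),  # (batch, seq_len, dim)
#         "item_emb": torch.randn(32, 20, 8),
#     }
#     output = aggregation(inputs)
#     assert output.shape == (32, 20, 24)  # 16 + 8 concatenated on dim=-1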


@tabular_aggregation_registry.register("stack")
class StackFeatures(TabularAggregation):
    """Aggregation by stacking all values in the input dictionary along the given dimension.

    Parameters
    ----------
    axis: int, default=-1
        Axis to use for the stacking operation.
    """

    def __init__(self, axis: int = -1):
        super().__init__()
        self.axis = axis

    def forward(self, inputs: TabularData) -> torch.Tensor:
        self._expand_non_sequential_features(inputs)
        self._check_concat_shapes(inputs)

        tensors = []
        for name in sorted(inputs.keys()):
            tensors.append(inputs[name])

        return torch.stack(tensors, dim=self.axis)

    def forward_output_size(self, input_size):
        batch_size = calculate_batch_size_from_input_size(input_size)
        seq_features_shapes, sequence_length = self._get_seq_features_shapes(input_size)

        if len(seq_features_shapes) > 0:
            output_size = [
                batch_size,
                sequence_length,
            ]
        else:
            output_size = [batch_size]

        num_features = len(input_size)
        if self.axis == -1:
            output_size.append(num_features)
        else:
            output_size.insert(self.axis, num_features)

        return tuple(output_size)
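
# Usage sketch (hypothetical names/shapes): stacking equally-shaped features
# adds a new axis whose length is the number of features.
#
#     aggregation = StackFeatures()  # axis=-1
#     inputs = {
#         "category_emb": torch.randn(32, 20, 8),
#         "item_emb": torch.randn(32, 20, 8),
#     }
#     output = aggregation(inputs)
#     assert output.shape == (32, 20, 8, 2)  # the 2 features form the last axis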


class ElementwiseFeatureAggregation(TabularAggregation):
    """Base class for aggregation methods that aggregate features element-wise.

    It implements two check methods to ensure inputs have the correct shape.
    """

    def _check_input_shapes_equal(self, inputs: TabularData):
        """Checks that the shapes of all inputs are equal.

        Parameters
        ----------
        inputs : TabularData
            Dictionary of tensors.
        """
        all_input_shapes_equal = len(set([x.shape for x in inputs.values()])) == 1
        if not all_input_shapes_equal:
            raise ValueError(
                "The shapes of all input features are not equal, which is required for"
                " element-wise aggregation: {}".format({k: v.shape for k, v in inputs.items()})
            )

    def _check_inputs_last_dim_equal(self, inputs_sizes):
        """Checks that the last dimensions of all inputs are equal.

        Parameters
        ----------
        inputs_sizes : dict[str, Union[List[int], torch.Size]]
            A dictionary containing the sizes of the inputs.
        """
        all_input_last_dim_equal = len(set([x[-1] for x in inputs_sizes.values()])) == 1
        if not all_input_last_dim_equal:
            raise ValueError(
                f"The last dim of all input features is not equal, which is"
                f" required for element-wise aggregation: {inputs_sizes}"
            )
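
# Sketch of the shape contract (hypothetical sizes): both checks raise a
# ValueError on mismatch. For any ElementwiseFeatureAggregation subclass `agg`:
#
#     agg._check_inputs_last_dim_equal({"a": [32, 20, 8], "b": [32, 20, 8]})  # passes
#     agg._check_inputs_last_dim_equal({"a": [32, 20, 8], "b": [32, 20, 4]})  # raises ValueError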


@tabular_aggregation_registry.register("element-wise-sum")
class ElementwiseSum(ElementwiseFeatureAggregation):
    """Aggregation by first stacking all values in TabularData in the first dimension, and then
    summing the result."""

    def __init__(self):
        super().__init__()
        self.stack = StackFeatures(axis=0)

    def forward(self, inputs: TabularData) -> torch.Tensor:
        self._expand_non_sequential_features(inputs)
        self._check_input_shapes_equal(inputs)

        return self.stack(inputs).sum(dim=0)

    def forward_output_size(self, input_size):
        self._check_inputs_last_dim_equal(input_size)
        agg_dim = list(input_size.values())[0][-1]
        output_size = self._get_agg_output_size(input_size, agg_dim)
        return output_size
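
# Usage sketch (hypothetical names/shapes): all features must share the same
# shape; the output is their element-wise sum and keeps that shape.
#
#     aggregation = ElementwiseSum()
#     inputs = {
#         "category_emb": torch.randn(32, 20, 8),
#         "item_emb": torch.randn(32, 20, 8),
#     }
#     output = aggregation(inputs)  # category_emb + item_emb
#     assert output.shape == (32, 20, 8)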


@tabular_aggregation_registry.register("element-wise-sum-item-multi")
@requires_schema
class ElementwiseSumItemMulti(ElementwiseFeatureAggregation):
    """Aggregation by applying the `ElementwiseSum` aggregation to all features except the
    item-id, and then multiplying the result element-wise with the item-id embedding.

    Parameters
    ----------
    schema: Schema
        Schema used to locate the item-id column.
    """

    def __init__(self, schema: Schema = None):
        super().__init__()
        self.stack = StackFeatures(axis=0)
        self.schema = schema
        self.item_id_col_name = None

    def forward(self, inputs: TabularData) -> torch.Tensor:
        item_id_inputs = self.get_item_ids_from_inputs(inputs)
        self._expand_non_sequential_features(inputs)
        self._check_input_shapes_equal(inputs)

        schema: Schema = self.schema  # type: ignore
        other_inputs = {k: v for k, v in inputs.items() if k != schema.item_id_column_name}
        other_inputs_sum = self.stack(other_inputs).sum(dim=0)
        result = item_id_inputs.multiply(other_inputs_sum)
        return result

    def forward_output_size(self, input_size):
        self._check_inputs_last_dim_equal(input_size)
        agg_dim = list(input_size.values())[0][-1]
        output_size = self._get_agg_output_size(input_size, agg_dim)
        return output_size
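
# Usage sketch (hypothetical schema and feature names): assuming the schema's
# item-id column is named "item_id", the item-id embedding is kept out of the
# sum and multiplied element-wise with the sum of the remaining features.
#
#     aggregation = ElementwiseSumItemMulti(schema=schema)
#     inputs = {
#         "category_emb": torch.randn(32, 20, 8),
#         "item_id": torch.randn(32, 20, 8),
#         "price_emb": torch.randn(32, 20, 8),
#     }
#     output = aggregation(inputs)  # item_id * (category_emb + price_emb)
#     assert output.shape == (32, 20, 8)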