Source code for transformers4rec.torch.tabular.aggregation

#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch

from merlin_standard_lib import Schema

from ...config.schema import requires_schema
from ..typing import TabularData
from ..utils.torch_utils import calculate_batch_size_from_input_size
from .base import TabularAggregation, tabular_aggregation_registry


[docs]@tabular_aggregation_registry.register("concat") class ConcatFeatures(TabularAggregation): """Aggregation by stacking all values in TabularData, all non-sequential values will be converted to a sequence. The output of this concatenation will have 3 dimensions. """

    def forward(self, inputs: TabularData) -> torch.Tensor:
        self._expand_non_sequential_features(inputs)
        self._check_concat_shapes(inputs)

        # Iterate in sorted-key order so the concatenated layout is deterministic
        tensors = []
        for name in sorted(inputs.keys()):
            tensors.append(inputs[name])

        return torch.cat(tensors, dim=-1)

    def forward_output_size(self, input_size):
        agg_dim = sum(i[-1] for i in input_size.values())
        output_size = self._get_agg_output_size(input_size, agg_dim)

        return output_size
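

# Usage sketch for ConcatFeatures (illustrative only: the feature names and
# shapes below are made-up assumptions, not part of this module). Features
# sharing batch and sequence dims are concatenated along the last dim,
# e.g. (2, 4, 8) and (2, 4, 16) -> (2, 4, 24):
#
#     inputs = {"category_emb": torch.randn(2, 4, 16), "item_emb": torch.randn(2, 4, 8)}
#     output = ConcatFeatures()(inputs)
#     assert output.shape == (2, 4, 24)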
[docs]@tabular_aggregation_registry.register("stack") class StackFeatures(TabularAggregation): """Aggregation by stacking all values in input dictionary in the given dimension. Parameters ---------- axis: int, default=-1 Axis to use for the stacking operation. """ def __init__(self, axis: int = -1): super().__init__() self.axis = axis

    def forward(self, inputs: TabularData) -> torch.Tensor:
        self._expand_non_sequential_features(inputs)
        self._check_concat_shapes(inputs)

        tensors = []
        for name in sorted(inputs.keys()):
            tensors.append(inputs[name])

        return torch.stack(tensors, dim=self.axis)

    def forward_output_size(self, input_size):
        batch_size = calculate_batch_size_from_input_size(input_size)
        seq_features_shapes, sequence_length = self._get_seq_features_shapes(input_size)

        if len(seq_features_shapes) > 0:
            output_size = [batch_size, sequence_length]
        else:
            output_size = [batch_size]

        num_features = len(input_size)
        if self.axis == -1:
            output_size.append(num_features)
        else:
            output_size.insert(self.axis, num_features)

        return tuple(output_size)
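

# Usage sketch for StackFeatures (illustrative; the feature names and shapes
# are assumptions). Stacking two (2, 4, 8) features along the default axis=-1
# adds a trailing dimension with one slot per feature:
#
#     inputs = {"a": torch.randn(2, 4, 8), "b": torch.randn(2, 4, 8)}
#     output = StackFeatures()(inputs)
#     assert output.shape == (2, 4, 8, 2)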


class ElementwiseFeatureAggregation(TabularAggregation):
    """Base class for aggregation methods that aggregate features element-wise.
    It implements two check methods to ensure inputs have the correct shape.
    """

    def _check_input_shapes_equal(self, inputs: TabularData):
        """Checks that the shapes of all inputs are equal.

        Parameters
        ----------
        inputs : TabularData
            Dictionary of tensors.
        """
        all_input_shapes_equal = len({x.shape for x in inputs.values()}) == 1
        if not all_input_shapes_equal:
            raise ValueError(
                "The shapes of all input features are not equal, which is required for"
                " element-wise aggregation: {}".format({k: v.shape for k, v in inputs.items()})
            )

    def _check_inputs_last_dim_equal(self, inputs_sizes):
        """Checks that the last dimensions of all inputs are equal.

        Parameters
        ----------
        inputs_sizes : dict[str, Union[List[int], torch.Size]]
            A dictionary containing the sizes of the inputs.
        """
        all_input_last_dim_equal = len({x[-1] for x in inputs_sizes.values()}) == 1
        if not all_input_last_dim_equal:
            raise ValueError(
                "The last dim of all input features is not equal, which is"
                f" required for element-wise aggregation: {inputs_sizes}"
            )
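

# Behavior sketch for the checks above (illustrative; the feature names and
# shapes are made up). Any element-wise subclass, e.g. ElementwiseSum below,
# rejects inputs whose shapes differ:
#
#     agg = ElementwiseSum()
#     agg({"a": torch.randn(2, 4, 8), "b": torch.randn(2, 4, 16)})
#     # -> ValueError: the shapes of all input features are not equal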
[docs]@tabular_aggregation_registry.register("element-wise-sum") class ElementwiseSum(ElementwiseFeatureAggregation): """Aggregation by first stacking all values in TabularData in the first dimension, and then summing the result.""" def __init__(self): super().__init__() self.stack = StackFeatures(axis=0)

    def forward(self, inputs: TabularData) -> torch.Tensor:
        self._expand_non_sequential_features(inputs)
        self._check_input_shapes_equal(inputs)

        # Stack along a new leading "feature" axis, then sum that axis out
        return self.stack(inputs).sum(dim=0)

    def forward_output_size(self, input_size):
        self._check_inputs_last_dim_equal(input_size)
        agg_dim = list(input_size.values())[0][-1]
        output_size = self._get_agg_output_size(input_size, agg_dim)

        return output_size
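

# Usage sketch for ElementwiseSum (illustrative; names and shapes are
# assumptions). Inputs must share one shape, which is also the output shape:
#
#     inputs = {"a": torch.randn(2, 4, 8), "b": torch.randn(2, 4, 8)}
#     output = ElementwiseSum()(inputs)
#     assert output.shape == (2, 4, 8)  # equivalent to inputs["a"] + inputs["b"]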
[docs]@tabular_aggregation_registry.register("element-wise-sum-item-multi") @requires_schema class ElementwiseSumItemMulti(ElementwiseFeatureAggregation): """Aggregation by applying the `ElementwiseSum` aggregation to all features except the item-id, and then multiplying this with the item-ids. Parameters ---------- schema: DatasetSchema """ def __init__(self, schema: Schema = None): super().__init__() self.stack = StackFeatures(axis=0) self.schema = schema self.item_id_col_name = None

    def forward(self, inputs: TabularData) -> torch.Tensor:
        item_id_inputs = self.get_item_ids_from_inputs(inputs)
        self._expand_non_sequential_features(inputs)
        self._check_input_shapes_equal(inputs)

        schema: Schema = self.schema  # type: ignore
        # Element-wise sum of every feature except the item-id ...
        other_inputs = {k: v for k, v in inputs.items() if k != schema.item_id_column_name}
        other_inputs_sum = self.stack(other_inputs).sum(dim=0)
        # ... multiplied element-wise with the item-id embeddings
        result = item_id_inputs.multiply(other_inputs_sum)

        return result

    def forward_output_size(self, input_size):
        self._check_inputs_last_dim_equal(input_size)
        agg_dim = list(input_size.values())[0][-1]
        output_size = self._get_agg_output_size(input_size, agg_dim)

        return output_size
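

# Usage sketch for ElementwiseSumItemMulti (illustrative; the feature names
# are hypothetical and how `schema` is built is an assumption outside this
# module: it just needs to be a Schema that identifies the item-id column).
#
#     agg = ElementwiseSumItemMulti(schema)
#     inputs = {
#         "item_id/list": torch.randn(2, 4, 8),   # item-id embeddings
#         "category/list": torch.randn(2, 4, 8),
#     }
#     output = agg(inputs)  # item-id embeddings * element-wise sum of the rest
#     assert output.shape == (2, 4, 8)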