#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from typing import Dict, Optional
import torch
from ..features.embedding import FeatureConfig, TableConfig
from ..typing import TabularData, TensorOrTabularData
from .base import TabularTransformation, tabular_transformation_registry
LOG = logging.getLogger("transformers4rec")


@tabular_transformation_registry.register_with_multiple_names("stochastic-swap-noise", "ssn")
class StochasticSwapNoise(TabularTransformation):
"""
    Applies stochastic swap noise to sequence features: each value is replaced, with
    probability ``replacement_prob``, by another value sampled from the same batch.
    It can be applied as a `pre` transform like `TransformerBlock(pre="stochastic-swap-noise")`
"""
def __init__(self, schema=None, pad_token=0, replacement_prob=0.1):
super().__init__()
self.schema = schema
self.pad_token = pad_token
self.replacement_prob = replacement_prob

    def forward(  # type: ignore
self, inputs: TensorOrTabularData, input_mask: Optional[torch.Tensor] = None, **kwargs
) -> TensorOrTabularData:
if self.schema:
            if input_mask is None:
                input_mask = self.get_padding_mask_from_item_id(inputs, self.pad_token)
if isinstance(inputs, dict):
return {key: self.augment(val, input_mask) for key, val in inputs.items()}
return self.augment(inputs, input_mask)

    def forward_output_size(self, input_size):
return input_size

    def augment(
self, input_tensor: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
with torch.no_grad():
if mask is not None:
if input_tensor.ndim == mask.ndim - 1:
mask = mask[:, 0]
sse_prob_replacement_matrix = torch.full(
input_tensor.shape,
self.replacement_prob,
device=input_tensor.device,
)
sse_replacement_mask = torch.bernoulli(sse_prob_replacement_matrix).bool()
if mask is not None:
sse_replacement_mask = sse_replacement_mask & mask
n_values_to_replace = sse_replacement_mask.sum()
if mask is not None:
masked = torch.masked_select(input_tensor, mask)
else:
masked = torch.clone(input_tensor)
input_permutation = torch.randperm(masked.shape[0])
sampled_values_to_replace = masked[input_permutation][
:n_values_to_replace # type: ignore
]
output_tensor = input_tensor.clone()
            if input_tensor[sse_replacement_mask].size() != sampled_values_to_replace.size():
sampled_values_to_replace = torch.squeeze(sampled_values_to_replace)
output_tensor[sse_replacement_mask] = sampled_values_to_replace
return output_tensor
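

# Illustrative sketch (not part of the original module): applying StochasticSwapNoise to a
# right-padded item-id sequence as a standalone transformation. The tensor values and the
# ``replacement_prob`` below are arbitrary assumptions. Because no schema is provided here,
# the padding mask is passed explicitly so padded positions are never swapped.
def _example_stochastic_swap_noise():  # pragma: no cover
    # Two sessions, right-padded with pad token 0.
    item_ids = torch.tensor([[5, 8, 2, 0, 0], [7, 3, 9, 4, 0]])
    mask = item_ids != 0  # True for real items, False for padding
    ssn = StochasticSwapNoise(pad_token=0, replacement_prob=0.3)
    # Each non-padded value is replaced, with probability 0.3, by another
    # non-padded value sampled from the same batch.
    return ssn.forward(item_ids, input_mask=mask)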


@tabular_transformation_registry.register_with_multiple_names("layer-norm")
class TabularLayerNorm(TabularTransformation):
"""
    Applies layer normalization to each input feature individually, before aggregation.
"""
def __init__(self, features_dim: Optional[Dict[str, int]] = None):
super().__init__()
self.feature_layer_norm = torch.nn.ModuleDict()
self._set_features_layer_norm(features_dim)
def _set_features_layer_norm(self, features_dim):
feature_layer_norm = {}
if features_dim:
for fname, dim in features_dim.items():
if dim == 1:
LOG.warning(
f"Layer norm can only be applied on features with more than 1 dim, "
f"but feature {fname} has dim {dim}"
)
continue
feature_layer_norm[fname] = torch.nn.LayerNorm(normalized_shape=dim)
self.feature_layer_norm.update(feature_layer_norm)

    @classmethod
def from_feature_config(cls, feature_config: Dict[str, FeatureConfig]):
features_dim = {}
for name, feature in feature_config.items():
table: TableConfig = feature.table
features_dim[name] = table.dim
return cls(features_dim)

    def forward(self, inputs: TabularData, **kwargs) -> TabularData:
return {
key: (self.feature_layer_norm[key](val) if key in self.feature_layer_norm else val)
for key, val in inputs.items()
}

    def forward_output_size(self, input_size):
return input_size

    def build(self, input_size, **kwargs):
if input_size is not None:
features_dim = {k: v[-1] for k, v in input_size.items()}
self._set_features_layer_norm(features_dim)
return super().build(input_size, **kwargs)
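

# Illustrative sketch (not part of the original module): applying TabularLayerNorm to a
# dict of dense feature tensors. Feature names and embedding dims are assumptions; the
# last dimension of each tensor must match the dim registered in ``features_dim``.
def _example_tabular_layer_norm():  # pragma: no cover
    layer_norm = TabularLayerNorm(features_dim={"item_id": 64, "category": 32})
    inputs = {
        "item_id": torch.randn(2, 10, 64),  # (batch, sequence, embedding dim)
        "category": torch.randn(2, 10, 32),
    }
    # Each feature is normalized over its last dimension; features not present in
    # ``features_dim`` (or with dim 1) are passed through unchanged.
    return layer_norm.forward(inputs)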


@tabular_transformation_registry.register_with_multiple_names("dropout")
class TabularDropout(TabularTransformation):
"""
    Applies dropout to the input features.
"""
def __init__(self, dropout_rate=0.0):
super().__init__()
self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, inputs: TensorOrTabularData, **kwargs) -> TensorOrTabularData:  # type: ignore
        if isinstance(inputs, dict):
            return {key: self.dropout(val) for key, val in inputs.items()}
        return self.dropout(inputs)

    def forward_output_size(self, input_size):
return input_size
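

# Illustrative sketch (not part of the original module): TabularDropout zeroes each feature
# value independently with probability ``dropout_rate``. Shapes and the rate are arbitrary
# assumptions for the example.
def _example_tabular_dropout():  # pragma: no cover
    dropout = TabularDropout(dropout_rate=0.2)
    inputs = {
        "item_id": torch.randn(2, 10, 64),
        "category": torch.randn(2, 10, 32),
    }
    # In training mode, surviving values are rescaled by 1 / (1 - dropout_rate),
    # which is the standard torch.nn.Dropout behaviour.
    return dropout.forward(inputs)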