#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
import tensorflow as tf
from keras.utils import control_flow_util
from tensorflow.keras import backend
from tensorflow.python.ops import array_ops
from merlin.models.tf.core.base import Block
from merlin.models.tf.core.tabular import TabularBlock, TensorOrTabularData
[docs]@Block.registry.register_with_multiple_names("stochastic-swap-noise", "ssn")
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class StochasticSwapNoise(TabularBlock):
"""
Applies Stochastic replacement of sequence features
"""
[docs] def __init__(self, schema=None, pad_token=0, replacement_prob=0.1, **kwargs):
super().__init__(**kwargs)
self.schema = schema
self.pad_token = pad_token
self.replacement_prob = replacement_prob
[docs] def call(
self,
inputs: TensorOrTabularData,
input_mask: Optional[tf.Tensor] = None,
training=False,
**kwargs,
) -> TensorOrTabularData:
def augment(input_mask):
if self._schema:
input_mask = input_mask or self.get_padding_mask_from_item_id(
inputs, self.pad_token
)
if isinstance(inputs, dict):
return {key: self.augment(val, input_mask) for key, val in inputs.items()}
return self.augment(inputs, input_mask)
output = control_flow_util.smart_cond(training, lambda: augment(input_mask), lambda: inputs)
return output
[docs] def augment(self, input_tensor: tf.Tensor, mask: Optional[tf.Tensor], **kwargs) -> tf.Tensor:
if mask is not None:
if len(input_tensor.shape) == len(mask.shape) - 1:
mask = mask[:, 0]
casted = tf.cast(
backend.random_binomial(array_ops.shape(input_tensor), p=self.replacement_prob),
tf.int32,
)
replacement_mask_matrix = casted * tf.cast(mask, tf.int32)
n_values_to_replace = tf.reduce_sum(replacement_mask_matrix)
input_flattened_non_zero = tf.boolean_mask(
input_tensor, tf.cast(replacement_mask_matrix, tf.bool)
)
sampled_values_to_replace = tf.gather(
input_flattened_non_zero,
tf.random.shuffle(tf.range(tf.shape(input_flattened_non_zero)[0]))[
:n_values_to_replace
],
)
replacement_indices = tf.sparse.from_dense(replacement_mask_matrix).indices
output_tensor = tf.tensor_scatter_nd_update(
input_tensor, replacement_indices, sampled_values_to_replace
)
return output_tensor
[docs] def compute_output_shape(self, input_shape):
return input_shape
[docs] def get_config(self):
config = super().get_config()
config["pad_token"] = self.pad_token
config["replacement_prob"] = self.replacement_prob
return config