#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dataclasses import dataclass
from typing import Any, Dict
import tensorflow as tf
from tensorflow.keras import backend
from tensorflow.python.ops import array_ops
from merlin_standard_lib import Registry
from merlin_standard_lib.utils.doc_utils import docstring_parameter
masking_registry = Registry("tf.masking")
MASK_SEQUENCE_PARAMETERS_DOCSTRING = """
    hidden_size: int
        The hidden dimension of the input tensors, needed to initialize the trainable
        vector of masked positions.
    padding_idx: int, default = 0
        Index of the padding item, used when padding batches of sequences to the same length.
    eval_on_last_item_seq_only: bool, default = True
        Predict only the last item of each sequence during evaluation.
"""


@dataclass
class MaskingInfo:
    schema: tf.Tensor
    targets: tf.Tensor


@docstring_parameter(mask_sequence_parameters=MASK_SEQUENCE_PARAMETERS_DOCSTRING)
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class MaskSequence(tf.keras.layers.Layer):
"""Base class to prepare masked items inputs/labels for language modeling tasks.
    Transformer architectures can be trained in different ways. Depending on the training
    method, there is a specific masking schema. The masking schema sets the items to be
    predicted (labels) and masks (hides) their positions in the sequence so that they are
    not used by the Transformer layers for prediction.
We currently provide 4 different masking schemes out of the box:
- Causal LM (clm)
- Masked LM (mlm)
- Permutation LM (plm)
- Replacement Token Detection (rtd)
    This class can be extended to add a different masking scheme.
Parameters
----------
{mask_sequence_parameters}
"""
# TODO: Link to masking-class in the doc-string.
def __init__(self, padding_idx: int = 0, eval_on_last_item_seq_only: bool = True, **kwargs):
super(MaskSequence, self).__init__(**kwargs)
self.padding_idx = padding_idx
self.eval_on_last_item_seq_only = eval_on_last_item_seq_only
self.mask_schema = None
self.masked_targets = None

    def get_config(self):
config = super(MaskSequence, self).get_config()
config.update(
{
"padding_idx": self.padding_idx,
"eval_on_last_item_seq_only": self.eval_on_last_item_seq_only,
}
)
return config
def _compute_masked_targets(self, item_ids: tf.Tensor, training=False) -> MaskingInfo:
"""
Method to prepare masked labels based on the sequence of item ids.
        It returns the true labels of masked positions and the related boolean mask.
Parameters
----------
item_ids: tf.Tensor
The sequence of input item ids used for deriving labels of
next item prediction task.
training: bool
Flag to indicate whether we are in `Training` mode or not.
During training, the labels can be any items within the sequence
based on the selected masking task.
During evaluation, we are predicting all next items or last item only
in the sequence based on the param `eval_on_last_item_seq_only`.
"""
raise NotImplementedError

    def compute_masked_targets(self, item_ids: tf.Tensor, training=False) -> MaskingInfo:
"""
Method to prepare masked labels based on the sequence of item ids.
        It returns the true labels of masked positions and the related boolean mask,
        and updates the class attributes `mask_schema` and `masked_targets` so they
        can be reused by other modules.
Parameters
----------
item_ids: tf.Tensor
The sequence of input item ids used for deriving labels of
next item prediction task.
training: bool
Flag to indicate whether we are in `Training` mode or not.
During training, the labels can be any items within the sequence
based on the selected masking task.
            During evaluation, we predict all next items or only the last item
            in the sequence, depending on the param `eval_on_last_item_seq_only`.
Returns
-------
        MaskingInfo
            The boolean mask schema and the masked targets.
"""
assert item_ids.shape.rank == 2, "`item_ids` must have 2 dimensions."
masking_info = self._compute_masked_targets(item_ids, training=training)
self.mask_schema, self.masked_targets = masking_info.schema, masking_info.targets
return masking_info

    def predict_all(self, item_ids: tf.Tensor) -> MaskingInfo:
"""
Prepare labels for all next item predictions instead of
last-item predictions in a user's sequence.
Parameters
----------
item_ids: tf.Tensor
The sequence of input item ids used for deriving labels of
next item prediction task.
Returns
-------
        MaskingInfo
            The boolean mask schema and the shifted labels.
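
        Examples
        --------
        A worked example (assuming ``padding_idx=0``): labels are the item ids
        shifted one position to the left, padded back to the original length,
        and the mask flags the non-padded labels::

            item_ids    = [[1, 2, 3, 0]]
            labels      = [[2, 3, 0, 0]]
            mask_labels = [[True, True, False, False]]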
"""
# TODO : Add option to predict N-last items
# shift sequence of item-ids
labels = item_ids[:, 1:]
        # Shifting reduces the sequence length by one, so append a padded item to
        # restore the original length.
        # This is important for ReformerModel(), for example
labels = tf.concat(
[
labels,
tf.zeros((tf.shape(labels)[0], 1), dtype=labels.dtype),
],
axis=-1,
)
        # the mask is True only where the target is not the padding index
mask_labels = labels != self.padding_idx
return MaskingInfo(mask_labels, labels)

    def call(self, inputs: tf.Tensor, item_ids: tf.Tensor, training=False) -> tf.Tensor:
_ = self.compute_masked_targets(item_ids=item_ids, training=training)
return self.apply_mask_to_inputs(inputs, self.mask_schema)
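
    # ``call()`` relies on ``apply_mask_to_inputs``, which is not part of this
    # excerpt. A minimal sketch of the expected base-class behavior, assuming the
    # trainable ``masked_item_embedding`` created in ``build()`` replaces the input
    # vectors at masked (label) positions:
    def apply_mask_to_inputs(self, inputs: tf.Tensor, schema: tf.Tensor) -> tf.Tensor:
        return tf.where(
            tf.cast(tf.expand_dims(schema, -1), tf.bool),  # True at positions to hide
            tf.cast(self.masked_item_embedding, dtype=inputs.dtype),
            inputs,
        )
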
@property
def transformer_arguments(self) -> Dict[str, Any]:
"""
Prepare additional arguments to pass to the Transformer forward methods.
"""
return {**self.transformer_required_arguments(), **self.transformer_optional_arguments()}
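
    # ``transformer_arguments`` merges two hook methods that are not defined in this
    # excerpt. A minimal sketch, assuming they default to empty dicts and are meant
    # to be overridden by masking schemes that must pass extra tensors (e.g.
    # permutation masks) to the Transformer forward call:
    def transformer_required_arguments(self) -> Dict[str, Any]:
        # Extra arguments the Transformer call cannot run without for this scheme.
        return {}

    def transformer_optional_arguments(self) -> Dict[str, Any]:
        # Optional extra arguments; none by default.
        return {}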

    def build(self, input_shape):
self.hidden_size = input_shape[-1]
# Create a trainable embedding to replace masked interactions
initializer = tf.random_normal_initializer(mean=0.0, stddev=0.001)
self.masked_item_embedding = tf.Variable(
initializer(shape=[self.hidden_size], dtype=tf.float32)
)
return super().build(input_shape)


@masking_registry.register_with_multiple_names("clm", "causal")
@docstring_parameter(mask_sequence_parameters=MASK_SEQUENCE_PARAMETERS_DOCSTRING)
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class CausalLanguageModeling(MaskSequence):
"""
In Causal Language Modeling (clm) you predict the next item based on past positions of the
sequence. Future positions are masked.
Parameters
----------
{mask_sequence_parameters}
    train_on_last_item_seq_only: bool, default = False
        Predict only the last item during training.
"""
def __init__(
self,
padding_idx: int = 0,
eval_on_last_item_seq_only: bool = True,
train_on_last_item_seq_only: bool = False,
**kwargs
):
super(CausalLanguageModeling, self).__init__(
padding_idx=padding_idx, eval_on_last_item_seq_only=eval_on_last_item_seq_only, **kwargs
)
self.train_on_last_item_seq_only = train_on_last_item_seq_only
self.label_seq_trg_eval = tf.Variable(
tf.zeros(shape=[1, 1], dtype=tf.int32),
dtype=tf.int32,
trainable=False,
shape=tf.TensorShape([None, None]),
)

    def get_config(self):
config = super(CausalLanguageModeling, self).get_config()
config.update(
{
"train_on_last_item_seq_only": self.train_on_last_item_seq_only,
}
)
return config
def _compute_masked_targets(self, item_ids: tf.Tensor, training=False) -> MaskingInfo:
item_ids = tf.cast(item_ids, dtype=tf.int32)
masking_info: MaskingInfo = self.predict_all(item_ids)
mask_labels, labels = masking_info.schema, masking_info.targets
if (self.eval_on_last_item_seq_only and not training) or (
self.train_on_last_item_seq_only and training
):
last_item_sessions = tf.reduce_sum(tf.cast(mask_labels, labels.dtype), axis=1) - 1
rows_ids = tf.range(tf.shape(labels)[0], dtype=labels.dtype)
self.label_seq_trg_eval.assign(tf.zeros(tf.shape(labels), dtype=tf.int32))
indices = tf.concat(
[tf.expand_dims(rows_ids, 1), tf.expand_dims(last_item_sessions, 1)], axis=1
)
self.label_seq_trg_eval.scatter_nd_update(
indices=indices, updates=tf.gather_nd(labels, indices)
)
# Updating labels and mask
mask_labels = self.label_seq_trg_eval != self.padding_idx
else:
self.label_seq_trg_eval.assign(labels)
return MaskingInfo(mask_labels, self.label_seq_trg_eval)
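
    # For illustration (assumed toy input): with item_ids = [[1, 2, 3, 0]],
    # padding_idx = 0 and last-item-only mode active, ``predict_all`` first yields
    # labels [[2, 3, 0, 0]]; the scatter update then keeps only the last label,
    # giving targets [[0, 3, 0, 0]] and mask [[False, True, False, False]].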


@masking_registry.register_with_multiple_names("mlm", "masked")
@docstring_parameter(mask_sequence_parameters=MASK_SEQUENCE_PARAMETERS_DOCSTRING)
@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class MaskedLanguageModeling(MaskSequence):
"""
    In Masked Language Modeling (mlm), some positions of the sequence are randomly
    selected to be predicted, and the items at those positions are masked.
During training, the Transformer layer is allowed to use positions on the right (future info).
During inference, all past items are visible for the Transformer layer, which tries to predict
the next item.
Parameters
----------
{mask_sequence_parameters}
mlm_probability: Optional[float], default = 0.15
Probability of an item to be selected (masked) as a label of the given sequence.
        Note: we enforce that at least one item is masked in each sequence, so that
        the network can learn something from it.
"""
def __init__(
self,
padding_idx: int = 0,
eval_on_last_item_seq_only: bool = True,
mlm_probability: float = 0.15,
**kwargs
):
super(MaskedLanguageModeling, self).__init__(
padding_idx=padding_idx, eval_on_last_item_seq_only=eval_on_last_item_seq_only, **kwargs
)
self.mlm_probability = mlm_probability
self.labels = tf.Variable(
tf.zeros(shape=[1, 1], dtype=tf.int32),
dtype=tf.int32,
trainable=False,
shape=tf.TensorShape([None, None]),
)

    def get_config(self):
config = super(MaskedLanguageModeling, self).get_config()
config.update(
{
"mlm_probability": self.mlm_probability,
}
)
return config
def _compute_masked_targets(self, item_ids: tf.Tensor, training: bool = False) -> MaskingInfo:
"""
        Prepare a sequence with its mask schema for masked language modeling prediction.
        The function is based on HuggingFace's transformers/data/data_collator.py.
Parameters
----------
item_ids: tf.Tensor
            Sequence of input item ids (the target column).
        Returns
        -------
        MaskingInfo
            The boolean masking schema for masked target positions and the
            sequence of masked target item ids.
"""
# cast item_ids to int32
item_ids = tf.cast(item_ids, dtype=tf.int32)
self.labels.assign(tf.fill(tf.shape(item_ids), self.padding_idx))
non_padded_mask = tf.cast(item_ids != self.padding_idx, self.labels.dtype)
rows_ids = tf.range(tf.shape(item_ids)[0], dtype=tf.int64)
# During training, masks labels to be predicted according to a probability, ensuring that
# each session has at least one label to predict
if training:
# Selects a percentage of items to be masked (selected as labels)
probability_matrix = tf.cast(
backend.random_bernoulli(array_ops.shape(item_ids), p=self.mlm_probability),
self.labels.dtype,
)
mask_labels = probability_matrix * non_padded_mask
self.labels.assign(
tf.where(
tf.cast(mask_labels, tf.bool),
item_ids,
tf.cast(tf.fill(tf.shape(item_ids), self.padding_idx), dtype=item_ids.dtype),
)
)
# Set at least one item in the sequence to mask, so that the network
# can learn something with this session
one_random_index_by_session = tf.random.categorical(
tf.math.log(tf.cast(non_padded_mask, tf.float32)), num_samples=1
)
indices = tf.concat([tf.expand_dims(rows_ids, 1), one_random_index_by_session], axis=1)
self.labels.scatter_nd_update(indices=indices, updates=tf.gather_nd(item_ids, indices))
mask_labels = tf.cast(self.labels != self.padding_idx, self.labels.dtype)
# If a sequence has only masked labels, unmask one of the labels
sequences_with_only_labels = tf.reduce_sum(mask_labels, axis=1) == tf.reduce_sum(
non_padded_mask, axis=1
)
sampled_labels_to_unmask = tf.random.categorical(
tf.math.log(tf.cast(mask_labels, tf.float32)), num_samples=1
)
labels_to_unmask = tf.boolean_mask(sampled_labels_to_unmask, sequences_with_only_labels)
rows_to_unmask = tf.boolean_mask(rows_ids, sequences_with_only_labels)
indices = tf.concat([tf.expand_dims(rows_to_unmask, 1), labels_to_unmask], axis=1)
num_updates = tf.shape(indices)[0]
self.labels.scatter_nd_update(
indices, tf.cast(tf.fill((num_updates,), self.padding_idx), self.labels.dtype)
)
mask_labels = self.labels != self.padding_idx
else:
if self.eval_on_last_item_seq_only:
last_item_sessions = tf.reduce_sum(non_padded_mask, axis=1) - 1
indices = tf.concat(
[
tf.expand_dims(rows_ids, 1),
tf.cast(tf.expand_dims(last_item_sessions, 1), tf.int64),
],
axis=1,
)
self.labels.scatter_nd_update(
indices=indices, updates=tf.gather_nd(item_ids, indices)
)
mask_labels = self.labels != self.padding_idx
else:
masking_info = self.predict_all(item_ids)
mask_labels, labels = masking_info.schema, masking_info.targets
self.labels.assign(labels)
return MaskingInfo(mask_labels, self.labels)
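
    # For illustration (assumed toy input): during training with
    # item_ids = [[10, 20, 30, 0]] and padding_idx = 0, suppose the Bernoulli draw
    # and the guaranteed per-session pick both select position 1. Then
    # labels = [[0, 20, 0, 0]] and mask = [[False, True, False, False]], and
    # ``apply_mask_to_inputs`` hides position 1 from the Transformer, which must
    # predict item 20 there.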
# @masking_registry.register_with_multiple_names("plm", "permutation")
# class PermutationLanguageModeling(MaskSequence):
# pass
#
#
# @masking_registry.register_with_multiple_names("rtd", "replacement")
# class ReplacementLanguageModeling(MaskSequence):
# pass
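

# A minimal usage sketch (illustrative only, not part of the library): build a
# CausalLanguageModeling layer and run it on a toy batch. The shapes and item ids
# below are assumptions for demonstration.
if __name__ == "__main__":
    clm = CausalLanguageModeling(padding_idx=0)
    inputs = tf.random.uniform((2, 4, 8))  # (batch, seq_len, hidden_size)
    item_ids = tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]])  # 0 = padding
    masked_inputs = clm(inputs, item_ids=item_ids, training=True)
    print(masked_inputs.shape)  # (2, 4, 8)
    print(clm.masked_targets)  # item ids shifted one position to the left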