merlin.models.tf.ReplaceMaskedEmbeddings
- class merlin.models.tf.ReplaceMaskedEmbeddings(*args, **kwargs)
Bases: merlin.models.tf.core.base.Block
Takes a 3D input tensor (batch size x seq. length x embedding dim) and replaces the embeddings at the positions to be masked with a single trainable dummy embedding. This block looks for the Keras mask (._keras_mask) in the following order:
1. Checks whether the input tensor has a mask.
2. Checks whether there is a single target and whether it has a mask.
3. If there are multiple targets (dict), returns the mask of the target whose first 2 dims match those of the input.
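The lookup order above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: `FakeTensor` and `find_mask` are hypothetical names, a tensor is modeled as a shape plus an optional `keras_mask` (standing in for `._keras_mask`), and in case 3 the sketch assumes the first masked target with matching leading dimensions wins.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

Mask = List[List[bool]]  # batch_size x seq_len, True marks a masked position


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: a shape plus an optional Keras-style mask."""
    shape: Tuple[int, ...]
    keras_mask: Optional[Mask] = None


def find_mask(
    inputs: FakeTensor,
    targets: Union[None, FakeTensor, Dict[str, FakeTensor]] = None,
) -> Optional[Mask]:
    # 1. The input tensor's own mask takes precedence.
    if inputs.keras_mask is not None:
        return inputs.keras_mask
    # 2. A single target: use its mask (None if it has none).
    if isinstance(targets, FakeTensor):
        return targets.keras_mask
    # 3. Several targets: use the first masked target whose first two dims
    #    (batch size, sequence length) match those of the input.
    if isinstance(targets, dict):
        for target in targets.values():
            if target.keras_mask is not None and target.shape[:2] == inputs.shape[:2]:
                return target.keras_mask
    return None


# Usage: 3D input embeddings without a mask, plus two targets.
emb = FakeTensor(shape=(2, 3, 4))
item_mask = [[False, True, False], [True, False, False]]
targets = {
    "category": FakeTensor(shape=(2,)),
    "item_id": FakeTensor(shape=(2, 3), keras_mask=item_mask),
}
print(find_mask(emb, targets))  # [[False, True, False], [True, False, False]]
```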
This is useful when the PredictMasked() transformation is used in the Loader, which randomly selects some targets to be predicted and uses Keras masking to propagate the _keras_mask. Replacing the input embeddings at masked positions avoids target leakage when training models with Masked Language Modeling (BERT-like).
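A minimal sketch of the replacement itself, written in plain Python so that it runs without TensorFlow. `replace_masked_embeddings` is a hypothetical helper, and the mask convention (True marks a position to predict) is an assumption; in TensorFlow the same per-position selection is typically written with an element-wise select such as `tf.where`, but this is not the library's implementation.

```python
from typing import List, Sequence

Embeddings = List[List[List[float]]]  # batch_size x seq_len x emb_dim
Mask = List[List[bool]]               # batch_size x seq_len, True = masked


def replace_masked_embeddings(
    embeddings: Embeddings, mask: Mask, dummy: Sequence[float]
) -> Embeddings:
    """Return a copy of `embeddings` in which every masked position holds `dummy`.

    `dummy` plays the role of the single trainable embedding; here it is just a
    fixed list because this sketch does no training.
    """
    return [
        [list(dummy) if is_masked else list(vec) for vec, is_masked in zip(seq, seq_mask)]
        for seq, seq_mask in zip(embeddings, mask)
    ]


# One sequence of three items; position 1 is a target chosen for prediction.
emb = [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]]
mask = [[False, True, False]]
out = replace_masked_embeddings(emb, mask, dummy=[0.5, -0.5])
print(out)  # [[[1.0, 1.0], [0.5, -0.5], [3.0, 3.0]]]
```

Because the masked position now carries only the shared dummy vector, the embedding of the item to be predicted (`[2.0, 2.0]` in the example) is no longer visible to the layers that follow.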
Note: To support inference, the input sequence and its corresponding mask should be extended by one position at the end to account for the next-item (target) position. To do this, set SequenceMaskLastInference as a pre-layer of ReplaceMaskedEmbeddings() using a SequentialBlock:
`mm.SequentialBlock([mm.SequenceMaskLastInference(), mm.ReplaceMaskedEmbeddings()])`
Methods
__init__(**kwargs)
add_loss(losses, **kwargs) – Add loss tensor(s), potentially dependent on layer inputs.
add_metric(value[, name]) – Adds metric tensor to the layer.
add_update(updates) – Add update op(s), potentially dependent on layer inputs.
add_variable(*args, **kwargs) – Deprecated, do NOT use! Alias for add_weight.
add_weight([name, shape, dtype, ...]) – Adds a new variable to the layer.
as_tabular([name])
build(input_shape)
build_from_config(config)
call(inputs[, targets]) – If the sequence of input embeddings or the corresponding sequential targets is masked (with tensor._keras_mask defined), replaces the input embeddings for masked elements.
call_outputs(outputs[, training])
check_schema([schema])
compute_mask(inputs[, mask]) – Computes an output mask tensor.
compute_output_shape(input_shape) – Computes the output shape of the layer.
compute_output_signature(input_signature) – Compute the output tensor signature of the layer based on the inputs.
connect(*block[, block_name, context]) – Connect the block to other blocks sequentially.
connect_branch(*branches[, add_rest, post, ...]) – Connect the block to one or multiple branches.
connect_debug_block([append]) – Connect the block to a debug block.
connect_with_residual(block[, activation]) – Connect the block to other blocks sequentially with a residual connection.
connect_with_shortcut(block[, ...]) – Connect the block to other blocks sequentially with a shortcut connection.
copy()
count_params() – Count the total number of scalars composing the weights.
finalize_state() – Finalizes the layer's state after updating layer weights.
from_config(config) – Creates a layer from its config.
from_layer(layer)
get_build_config()
get_config()
get_input_at(node_index) – Retrieves the input tensor(s) of a layer at a given node.
get_input_mask_at(node_index) – Retrieves the input mask tensor(s) of a layer at a given node.
get_input_shape_at(node_index) – Retrieves the input shape(s) of a layer at a given node.
get_item_ids_from_inputs(inputs)
get_output_at(node_index) – Retrieves the output tensor(s) of a layer at a given node.
get_output_mask_at(node_index) – Retrieves the output mask tensor(s) of a layer at a given node.
get_output_shape_at(node_index) – Retrieves the output shape(s) of a layer at a given node.
get_padding_mask_from_item_id(inputs[, ...])
get_weights() – Returns the current weights of the layer, as NumPy arrays.
parse(*block)
parse_block(input)
prepare([block, post, aggregation]) – Transform the inputs of this block.
register_features(feature_shapes)
repeat([num]) – Repeat the block num times.
repeat_in_parallel([num, prefix, names, ...]) – Repeat the block num times in parallel.
select_by_name(name)
select_by_tag(tags)
set_schema([schema])
set_weights(weights) – Sets the weights of the layer, from NumPy arrays.
with_name_scope(method) – Decorator to automatically enter the module name scope.
Attributes
REQUIRES_SCHEMA
activity_regularizer – Optional regularizer function for the output of this layer.
compute_dtype – The dtype of the layer's computations.
context
dtype – The dtype of the layer weights.
dtype_policy – The dtype policy associated with this layer.
dynamic – Whether the layer is dynamic (eager-only); set in the constructor.
has_schema
inbound_nodes – Return Functional API nodes upstream of this layer.
input – Retrieves the input tensor(s) of a layer.
input_mask – Retrieves the input mask tensor(s) of a layer.
input_shape – Retrieves the input shape(s) of a layer.
input_spec – InputSpec instance(s) describing the input format for this layer.
losses – List of losses added using the add_loss() API.
metrics – List of metrics added using the add_metric() API.
name – Name of the layer (string), set in the constructor.
name_scope – Returns a tf.name_scope instance for this class.
non_trainable_variables
non_trainable_weights – List of all non-trainable weights tracked by this layer.
outbound_nodes – Return Functional API nodes downstream of this layer.
output – Retrieves the output tensor(s) of a layer.
output_mask – Retrieves the output mask tensor(s) of a layer.
output_shape – Retrieves the output shape(s) of a layer.
registry
schema
stateful
submodules – Sequence of all sub-modules.
supports_masking – Whether this layer supports computing a mask using compute_mask.
trainable
trainable_variables
trainable_weights – List of all trainable weights tracked by this layer.
updates
variable_dtype – Alias of Layer.dtype, the dtype of the weights.
variables – Returns the list of all layer variables/weights.
weights – Returns the list of all layer variables/weights.
- call(inputs: Union[tensorflow.python.framework.ops.Tensor, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor], targets: Optional[Union[tensorflow.python.framework.ops.Tensor, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor, Dict[str, tensorflow.python.framework.ops.Tensor]]] = None) → Union[tensorflow.python.framework.ops.Tensor, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor]
If the sequence of input embeddings or the corresponding sequential targets is masked (with tensor._keras_mask defined), replaces the input embeddings for masked elements.
- Parameters
inputs (Union[tf.Tensor, tf.RaggedTensor]) – A tensor with sequences of vectors. Needs to be 3D (batch_size, sequence_length, embeddings dim). If inputs._keras_mask is defined, it is used to infer the mask.
targets (Union[tf.Tensor, tf.RaggedTensor, TabularData], optional) – The target values, from which the mask can be extracted if targets._keras_mask is defined.
- Returns
If training, returns a tensor with the masked inputs replaced by the dummy embedding.
- Return type
Union[tf.Tensor, tf.RaggedTensor]
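The inference note in the class description (extending the input sequence and its mask by one position for the next item) can be illustrated with a small plain-Python sketch. `extend_for_next_item` is hypothetical; it assumes that at inference only the appended position is masked, and it fills that position with zeros, which may differ from how SequenceMaskLastInference fills it. In practice, use the `mm.SequentialBlock` combination from the note rather than a helper like this.

```python
from typing import List, Tuple

Embeddings = List[List[List[float]]]  # batch_size x seq_len x emb_dim
Mask = List[List[bool]]               # batch_size x seq_len, True = masked


def extend_for_next_item(embeddings: Embeddings) -> Tuple[Embeddings, Mask]:
    """Hypothetical helper: append one placeholder position per sequence and mask it.

    Assumes non-empty sequences. The placeholder values are zeros here; they do
    not matter because the masked slot is later replaced by the dummy embedding.
    """
    extended: Embeddings = []
    mask: Mask = []
    for seq in embeddings:
        emb_dim = len(seq[0])
        extended.append([list(vec) for vec in seq] + [[0.0] * emb_dim])
        # Only the new last slot (the next-item position) is masked.
        mask.append([False] * len(seq) + [True])
    return extended, mask


emb = [[[1.0, 2.0], [3.0, 4.0]]]
extended, mask = extend_for_next_item(emb)
print(extended)  # [[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]]
print(mask)      # [[False, False, True]]
```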