merlin.models.tf.CachedCrossBatchSampler

class merlin.models.tf.CachedCrossBatchSampler(*args, **kwargs)[source]

Bases: merlin.models.tf.blocks.sampling.base.ItemSampler

Provides efficient cached cross-batch [1] / inter-batch [2] negative sampling for two-tower item retrieval models. The cache consists of a fixed-capacity FIFO queue that keeps the item embeddings from the last N batches. All items in the queue are sampled as negatives for upcoming batches, which is more efficient than computing embeddings exclusively for negative items. This sampling is popularity-biased, because popular items are observed more often in training batches.

Compared to InBatchSampler, the CachedCrossBatchSampler allows for a larger number of negative items, not limited by the batch size. Gradients are not computed for the cached negative embeddings, which keeps the approach scalable. A common combination of samplers for the ItemRetrievalScorer is [InBatchSampler(), CachedCrossBatchSampler(ignore_last_batch_on_sample=True)], which computes gradients for the in-batch negatives but not for the cached item embeddings.

Note: Ignoring false negatives (negative items equal to the positive ones) is handled by ItemRetrievalScorer(…, sampling_downscore_false_negatives=True).
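The popularity bias can be seen in a small pure-Python sketch. It assumes a hypothetical stream of training batches of item ids in which item "a" occurs more often than the others; the ids, the batch contents, and the capacity of 6 are invented for illustration and do not come from the library.

```python
from collections import Counter, deque

# Hypothetical training batches of positive item ids; "a" is the popular item.
batches = [
    ["a", "b", "a"],
    ["a", "c", "a"],
    ["d", "a", "b"],
]

# FIFO cache that keeps only the 6 most recently seen items.
cache = deque(maxlen=6)
for batch in batches:
    cache.extend(batch)

# Every cached item is used as a negative, so frequent items dominate.
negative_counts = Counter(cache)
print(negative_counts.most_common())
```

Because the cache is filled from the training batches themselves, an item that appears often as a positive also appears often among the cached negatives.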

References

[1] Wang, Jinpeng, Jieming Zhu, and Xiuqiang He. “Cross-Batch Negative Sampling for Training Two-Tower Recommenders.” Proceedings of the 44th International ACM SIGIR Conference on Research and Development in Information Retrieval. 2021.

[2] Zhou, Chang, et al. “Contrastive learning for debiased candidate generation in large-scale recommender systems.” Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining. 2021.

Parameters
  • capacity (int) – The queue capacity to store samples

  • ignore_last_batch_on_sample (bool) – Whether to ignore (exclude) the most recently added batch when sampling. Defaults to True, since for sampling from the current batch we recommend InBatchSampler(), which allows computing gradients for in-batch negative items
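The interplay of the two parameters can be sketched in plain Python. This is an illustrative model only, not the library's implementation: the class name FifoNegativeCache, the use of collections.deque, and plain lists standing in for embedding tensors are all assumptions made for the example. It also assumes that capacity counts individual item embeddings.

```python
from collections import deque


class FifoNegativeCache:
    """Toy stand-in for a fixed-capacity FIFO cache of item embeddings (hypothetical)."""

    def __init__(self, capacity, ignore_last_batch_on_sample=True):
        self.capacity = capacity
        self.ignore_last_batch_on_sample = ignore_last_batch_on_sample
        # deque(maxlen=...) evicts the oldest entries once capacity is reached.
        self._queue = deque(maxlen=capacity)
        self._last_batch_size = 0

    def add(self, item_embeddings):
        """Append the current batch of embeddings to the queue."""
        self._queue.extend(item_embeddings)
        # At most `capacity` items of the batch can survive in the queue.
        self._last_batch_size = min(len(item_embeddings), self.capacity)

    def sample(self):
        """Return every cached embedding, optionally skipping the newest batch."""
        cached = list(self._queue)
        if self.ignore_last_batch_on_sample and self._last_batch_size:
            cached = cached[: len(cached) - self._last_batch_size]
        return cached


cache = FifoNegativeCache(capacity=4, ignore_last_batch_on_sample=True)
cache.add([[0.1, 0.2], [0.3, 0.4]])  # batch 1
cache.add([[0.5, 0.6], [0.7, 0.8]])  # batch 2
negatives = cache.sample()           # batch 2 is skipped, batch 1 is returned
cache.add([[0.9, 1.0], [1.1, 1.2]])  # batch 3 evicts batch 1
negatives_after = cache.sample()     # batch 3 is skipped, batch 2 is returned
```

With ignore_last_batch_on_sample=False the same sketch would return the full queue, including the batch that was just added.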

__init__(capacity: int, ignore_last_batch_on_sample: bool = True, **kwargs)[source]

Methods

__init__(capacity[, ignore_last_batch_on_sample])

add(inputs[, training])

add_loss(losses, **kwargs)

Add loss tensor(s), potentially dependent on layer inputs.

add_metric(value[, name])

Adds metric tensor to the layer.

add_update(updates)

Add update op(s), potentially dependent on layer inputs.

add_variable(*args, **kwargs)

Deprecated, do NOT use! Alias for add_weight.

add_weight([name, shape, dtype, …])

Adds a new variable to the layer.

build(input_shapes)

call(inputs[, training])

Adds the current batch to the FIFO queue cache and samples all item embeddings from the last N cached batches.

compute_mask(inputs[, mask])

Computes an output mask tensor.

compute_output_shape(input_shape)

Computes the output shape of the layer.

compute_output_signature(input_signature)

Compute the output tensor signature of the layer based on the inputs.

count_params()

Count the total number of scalars composing the weights.

finalize_state()

Finalizes the layer’s state after updating layer weights.

from_config(config)

Creates a layer from its config.

get_config()

Returns the config of the layer.

get_input_at(node_index)

Retrieves the input tensor(s) of a layer at a given node.

get_input_mask_at(node_index)

Retrieves the input mask tensor(s) of a layer at a given node.

get_input_shape_at(node_index)

Retrieves the input shape(s) of a layer at a given node.

get_output_at(node_index)

Retrieves the output tensor(s) of a layer at a given node.

get_output_mask_at(node_index)

Retrieves the output mask tensor(s) of a layer at a given node.

get_output_shape_at(node_index)

Retrieves the output shape(s) of a layer at a given node.

get_weights()

Returns the current weights of the layer, as NumPy arrays.

sample()

set_max_num_samples(value)

set_weights(weights)

Sets the weights of the layer, from NumPy arrays.

with_name_scope(method)

Decorator to automatically enter the module name scope.

Attributes

activity_regularizer

Optional regularizer function for the output of this layer.

compute_dtype

The dtype of the layer’s computations.

dtype

The dtype of the layer weights.

dtype_policy

The dtype policy associated with this layer.

dynamic

Whether the layer is dynamic (eager-only); set in the constructor.

inbound_nodes

Return Functional API nodes upstream of this layer.

input

Retrieves the input tensor(s) of a layer.

input_mask

Retrieves the input mask tensor(s) of a layer.

input_shape

Retrieves the input shape(s) of a layer.

input_spec

InputSpec instance(s) describing the input format for this layer.

item_embeddings_queue

losses

List of losses added using the add_loss() API.

max_num_samples

metrics

List of metrics added using the add_metric() API.

name

Name of the layer (string), set in the constructor.

name_scope

Returns a tf.name_scope instance for this class.

non_trainable_variables

non_trainable_weights

List of all non-trainable weights tracked by this layer.

outbound_nodes

Return Functional API nodes downstream of this layer.

output

Retrieves the output tensor(s) of a layer.

output_mask

Retrieves the output mask tensor(s) of a layer.

output_shape

Retrieves the output shape(s) of a layer.

required_features

stateful

submodules

Sequence of all sub-modules.

supports_masking

Whether this layer supports computing a mask using compute_mask.

trainable

trainable_variables

trainable_weights

List of all trainable weights tracked by this layer.

updates

variable_dtype

Alias of Layer.dtype, the dtype of the weights.

variables

Returns the list of all layer variables/weights.

weights

Returns the list of all layer variables/weights.

property item_embeddings_queue
build(input_shapes: Dict[str, tensorflow.python.framework.ops.Tensor]) → None[source]
call(inputs: Dict[str, tensorflow.python.framework.ops.Tensor], training=True) → merlin.models.tf.blocks.core.base.EmbeddingWithMetadata[source]

Adds the current batch to the FIFO queue cache and samples all item embeddings from the last N cached batches.

Parameters
  • inputs (TabularData) –

    Dict with two keys:

    “items_embeddings”: Item embeddings tensor.

    “items_metadata”: Dict like {“<feature name>”: “<feature tensor>”} containing features that might be relevant for the sampler (e.g. item id, item popularity, item recency). The CachedCrossBatchSampler does not use metadata features itself, but “item_id” is required when it is used in combination with ItemRetrievalScorer(…, sampling_downscore_false_negatives=True), so that false negatives can be identified and downscored.

  • training (bool, optional) – Flag indicating whether the layer is in training mode, by default True
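As a rough illustration of the expected input layout, the sketch below builds such a dict with plain Python lists standing in for tensors. The feature "item_popularity", all values, and the helper check_sampler_inputs are hypothetical and not part of the library.

```python
def check_sampler_inputs(inputs, require_item_id=False):
    """Validate the two-key layout described above (hypothetical helper)."""
    missing = {"items_embeddings", "items_metadata"} - set(inputs)
    if missing:
        raise KeyError(f"missing keys: {sorted(missing)}")
    num_items = len(inputs["items_embeddings"])
    for name, values in inputs["items_metadata"].items():
        # Each metadata feature must have one value per item embedding.
        if len(values) != num_items:
            raise ValueError(
                f"feature {name!r} has {len(values)} values, expected {num_items}"
            )
    if require_item_id and "item_id" not in inputs["items_metadata"]:
        raise KeyError("'item_id' is needed to identify false negatives")
    return num_items


inputs = {
    "items_embeddings": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    "items_metadata": {
        "item_id": [11, 42, 7],
        "item_popularity": [0.9, 0.1, 0.3],  # invented feature
    },
}
num_items = check_sampler_inputs(inputs, require_item_id=True)
```

The check on "item_id" mirrors the requirement stated above: the sampler itself does not need it, but false-negative handling does.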

Returns

Value object with the sampled item embeddings and item metadata

Return type

EmbeddingWithMetadata
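To illustrate why “item_id” matters for false negatives, here is a plain-Python sketch of the downscoring idea: a sampled negative whose id equals the positive item's id receives a very low score, so it contributes almost nothing to the loss. The function name, the score layout, and the value -1e9 are assumptions for illustration; the actual behavior belongs to ItemRetrievalScorer and may differ.

```python
def downscore_false_negatives(positive_ids, negative_ids, negative_scores, min_score=-1e9):
    """Replace the score of any negative that matches the row's positive item.

    negative_scores[i][j] is the score of negative j for example i.
    """
    adjusted = []
    for pos_id, row in zip(positive_ids, negative_scores):
        adjusted.append([
            min_score if neg_id == pos_id else score
            for neg_id, score in zip(negative_ids, row)
        ])
    return adjusted


positive_ids = [42, 7]        # positive item of each example
negative_ids = [11, 42, 7]    # ids of the sampled (cached) negatives
negative_scores = [
    [0.2, 0.9, 0.1],
    [0.4, 0.3, 0.8],
]
adjusted = downscore_false_negatives(positive_ids, negative_ids, negative_scores)
```

In this toy example the second negative is a false negative for the first example and the third negative is one for the second example, so only those two scores are replaced.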

add(inputs: Dict[str, tensorflow.python.framework.ops.Tensor], training: bool = True) → None[source]
sample() → merlin.models.tf.blocks.core.base.EmbeddingWithMetadata[source]