merlin.models.tf.ContrastiveOutput

class merlin.models.tf.ContrastiveOutput(*args, **kwargs)

Bases: merlin.models.tf.outputs.base.ModelOutput

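Categorical output for contrastive learning: the query embedding is scored by dot product against the positive candidate and against the negatives drawn by negative_samplers.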

Parameters
  • to_call (Union[Schema, ColumnSchema, EmbeddingTable, 'CategoricalTarget', 'EmbeddingTablePrediction', 'DotProduct']) – The target feature to predict. To apply the weight-tying technique [1], provide the EmbeddingTable or EmbeddingTablePrediction associated with the target feature.

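  • negative_samplers (ItemSamplersType) – Negative sampler(s) used to draw negative candidates: a CandidateSampler instance, a string alias such as "in-batch", or a sequence of these. This argument is required (it has no default value).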

  • pre (Optional[Block], optional) – Optional block to transform predictions before computing the binary logits, by default None

  • post (Optional[Block], optional) – Optional block to transform the binary logits, by default None

  • logits_temperature (float, optional) – Temperature T used to reduce model overconfidence: the logits are divided by T (logits / T), by default 1

  • name (str, optional) – The name of the task, by default None

  • default_loss (Union[str, tf.keras.losses.Loss], optional) – Default loss to use for categorical classification, by default 'categorical_crossentropy'

  • default_metrics_fn (Callable, optional) – A function returning the list of default metrics to use for categorical classification, by default default_categorical_prediction_metrics

  • store_negative_ids (bool, optional) – Whether to store negative ids for post-processing, by default False

  • logq_sampling_correction (bool, optional) – The LogQ correction is a standard technique for sampled softmax and popularity-biased sampling. It subtracts from the logits the log expected count/probability of the positive and negative samples, so that popular items are not over-penalized for being sampled more often as negatives. It can be enabled if a single negative sampler is provided and that sampler provides the sampling probabilities (i.e., implements with_sampling_probs()). An alternative way to perform logQ correction is ContrastiveOutput(…, post=PopularityLogitsCorrection(item_frequencies)), where you provide the item frequency probability distribution (prior). Default is False.
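The two logit adjustments described above (logq_sampling_correction and logits_temperature) can be illustrated in plain Python. This is a minimal sketch, not the library's implementation: it assumes the LogQ correction subtracts log(q) from each raw logit, where q is that candidate's sampling probability, and that temperature scaling then divides the corrected logits by T. The helpers logq_correct, scale_by_temperature, and softmax are hypothetical names used only for illustration.

```python
import math
from typing import List


def logq_correct(logits: List[float], sampling_probs: List[float]) -> List[float]:
    # Subtract log(q) from each logit, where q is the probability that the
    # candidate was sampled. Candidates with a large q (popular items) end up
    # with lower corrected logits than rare ones, so the loss pushes them down
    # less for appearing often as negatives.
    return [s - math.log(q) for s, q in zip(logits, sampling_probs)]


def scale_by_temperature(logits: List[float], temperature: float = 1.0) -> List[float]:
    # Divide every logit by T: T > 1 flattens the softmax, T < 1 sharpens it.
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    return [s / temperature for s in logits]


def softmax(logits: List[float]) -> List[float]:
    # Numerically stable softmax: shift by the max logit before exponentiating.
    m = max(logits)
    exps = [math.exp(s - m) for s in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical values: column 0 is the positive, the rest are sampled negatives.
raw_logits = [2.0, 1.5, 0.5]
sampling_probs = [0.2, 0.4, 0.4]

corrected = logq_correct(raw_logits, sampling_probs)
probs_t1 = softmax(scale_by_temperature(corrected, 1.0))
probs_t02 = softmax(scale_by_temperature(corrected, 0.2))
print(probs_t02[0] > probs_t1[0])  # True: a lower T concentrates probability on the top logit
```

Setting T = 0.2, as in the usage examples in the Notes, sharpens the distribution over candidates; T = 1 leaves the (corrected) logits unchanged.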

References

[1] Hakan Inan, Khashayar Khosravi, and Richard Socher. 2016. Tying word vectors and word classifiers. arXiv:1611.01462 (2016).

Notes

If to_call is set to DotProduct(), the schema of the target cannot be inferred. In that case, pass a schema containing only the ITEM_ID feature through the schema argument, which is accepted via **kwargs. Example usage:

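import merlin.models.tf as mm
from merlin.models.tf.outputs.base import DotProduct
from merlin.schema import Tags

# `schema` is assumed to be the merlin.schema.Schema of your training data.
outputs = mm.ContrastiveOutput(
    to_call=DotProduct(),
    negative_samplers="in-batch",
    schema=schema.select_by_tag(Tags.ITEM_ID),
    logits_temperature=0.2,
)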
The schema argument is not needed when the schema is passed directly to to_call. Example usage:

outputs = mm.ContrastiveOutput(
    to_call=schema.select_by_tag(Tags.ITEM_ID),
    negative_samplers="in-batch",
    logits_temperature=0.2,
)

__init__(to_call: typing.Union[merlin.schema.schema.Schema, merlin.schema.schema.ColumnSchema, merlin.models.tf.inputs.embedding.EmbeddingTable, merlin.models.tf.outputs.classification.CategoricalTarget, merlin.models.tf.outputs.classification.EmbeddingTablePrediction, merlin.models.tf.outputs.base.DotProduct], negative_samplers: typing.Union[merlin.models.tf.outputs.sampling.base.CandidateSampler, typing.Sequence[typing.Union[merlin.models.tf.outputs.sampling.base.CandidateSampler, str]], str], target_name: typing.Optional[str] = None, pre: typing.Optional[keras.engine.base_layer.Layer] = None, post: typing.Optional[keras.engine.base_layer.Layer] = None, logits_temperature: float = 1.0, name: typing.Optional[str] = None, default_loss: typing.Union[str, keras.losses.Loss] = 'categorical_crossentropy', default_metrics_fn: typing.Callable[[], typing.Sequence[keras.metrics.base_metric.Metric]] = <function default_categorical_prediction_metrics>, downscore_false_negatives=True, false_negative_score: float = -655.04, query_name: str = 'query', candidate_name: str = 'candidate', store_negative_ids: bool = False, logq_sampling_correction: typing.Optional[bool] = False, **kwargs)

Methods

__init__(to_call, negative_samplers[, ...])

add_loss(losses, **kwargs)

Add loss tensor(s), potentially dependent on layer inputs.

add_metric(value[, name])

Adds metric tensor to the layer.

add_update(updates)

Add update op(s), potentially dependent on layer inputs.

add_variable(*args, **kwargs)

Deprecated, do NOT use! Alias for add_weight.

add_weight([name, shape, dtype, ...])

Adds a new variable to the layer.

build(input_shape)

build_from_config(config)

call(inputs[, features, targets, training, ...])

call_contrastive(inputs, features, targets)

compute_mask(inputs[, mask])

Computes an output mask tensor.

compute_output_shape(input_shape)

compute_output_signature(input_signature)

Compute the output tensor signature of the layer based on the inputs.

count_params()

Count the total number of scalars composing the weights.

create_default_metrics()

embedding_lookup(ids)

finalize_state()

Finalizes the layer's state after updating layer weights.

from_config(config)

get_build_config()

get_config()

get_input_at(node_index)

Retrieves the input tensor(s) of a layer at a given node.

get_input_mask_at(node_index)

Retrieves the input mask tensor(s) of a layer at a given node.

get_input_shape_at(node_index)

Retrieves the input shape(s) of a layer at a given node.

get_output_at(node_index)

Retrieves the output tensor(s) of a layer at a given node.

get_output_mask_at(node_index)

Retrieves the output mask tensor(s) of a layer at a given node.

get_output_shape_at(node_index)

Retrieves the output shape(s) of a layer at a given node.

get_task_name(target_name)

get_weights()

Returns the current weights of the layer, as NumPy arrays.

outputs(query_embedding, positive, negative)

Method to compute the dot product between the query embeddings and positive/negative candidates

sample_negatives(positive, features[, ...])

Method to sample negatives from self.negative_samplers

set_negative_samplers(negative_samplers)

set_weights(weights)

Sets the weights of the layer, from NumPy arrays.

to_dataset([gpu])

with_name_scope(method)

Decorator to automatically enter the module name scope.

Attributes

activity_regularizer

Optional regularizer function for the output of this layer.

compute_dtype

The dtype of the layer's computations.

dtype

The dtype of the layer weights.

dtype_policy

The dtype policy associated with this layer.

dynamic

Whether the layer is dynamic (eager-only); set in the constructor.

has_candidate_weights

inbound_nodes

Return Functional API nodes upstream of this layer.

input

Retrieves the input tensor(s) of a layer.

input_mask

Retrieves the input mask tensor(s) of a layer.

input_shape

Retrieves the input shape(s) of a layer.

input_spec

InputSpec instance(s) describing the input format for this layer.

keys

losses

List of losses added using the add_loss() API.

metrics

List of metrics added using the add_metric() API.

name

Name of the layer (string), set in the constructor.

name_scope

Returns a tf.name_scope instance for this class.

non_trainable_variables

non_trainable_weights

List of all non-trainable weights tracked by this layer.

outbound_nodes

Return Functional API nodes downstream of this layer.

output

Retrieves the output tensor(s) of a layer.

output_mask

Retrieves the output mask tensor(s) of a layer.

output_shape

Retrieves the output shape(s) of a layer.

stateful

submodules

Sequence of all sub-modules.

supports_masking

Whether this layer supports computing a mask using compute_mask.

task_name

trainable

trainable_variables

trainable_weights

List of all trainable weights tracked by this layer.

updates

variable_dtype

Alias of Layer.dtype, the dtype of the weights.

variables

Returns the list of all layer variables/weights.

weights

Returns the list of all layer variables/weights.

build(input_shape)
call(inputs, features=None, targets=None, training=False, testing=False)
call_contrastive(inputs, features, targets, training=False, testing=False)
outputs(query_embedding: tensorflow.python.framework.ops.Tensor, positive: merlin.models.tf.outputs.sampling.base.Candidate, negative: merlin.models.tf.outputs.sampling.base.Candidate) -> merlin.models.tf.core.prediction.Prediction

Method to compute the dot product between the query embeddings and positive/negative candidates

Parameters
  • query_embedding (tf.Tensor) – Tensor of query embeddings.

  • positive (Candidate) – Stores the ids and metadata (such as embeddings) of the positive candidates.

  • negative (Candidate) – Stores the ids and metadata (such as embeddings) of the sampled negative candidates.

Returns

A Prediction object with the prediction scores, the targets, and the negative candidate ids if store_negative_ids is enabled.

Return type

Prediction
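The scoring performed by outputs() can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the method itself (which works on tf.Tensor and Candidate objects and returns a Prediction): it assumes the positive score is placed in column 0, the negative scores follow, and the targets are one-hot with the 1 in column 0, which is the layout a categorical cross-entropy loss such as the default 'categorical_crossentropy' expects. The function names are hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    # Plain dot product between two embedding vectors.
    return sum(x * y for x, y in zip(a, b))


def contrastive_scores(
    query: Vector, positive: Vector, negatives: List[Vector]
) -> Tuple[Vector, Vector]:
    # Score the query against its positive candidate and every sampled negative.
    logits = [dot(query, positive)] + [dot(query, n) for n in negatives]
    # One-hot target (assumption): the positive candidate sits in column 0.
    targets = [1.0] + [0.0] * len(negatives)
    return logits, targets


def categorical_crossentropy(logits: Vector, targets: Vector) -> float:
    # Cross-entropy between the targets and softmax(logits),
    # computed with log-sum-exp for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(s - m) for s in logits))
    return -sum(t * (s - log_z) for s, t in zip(logits, targets))


query = [1.0, 0.0]
positive = [0.9, 0.1]
negatives = [[0.1, 0.9], [-1.0, 0.0]]

logits, targets = contrastive_scores(query, positive, negatives)
print(logits)   # [0.9, 0.1, -1.0]
print(targets)  # [1.0, 0.0, 0.0]
print(categorical_crossentropy(logits, targets))
```

The loss shrinks as the query's score for its positive grows relative to the scores for the negatives, which is what contrastive training optimizes.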

sample_negatives(positive: merlin.models.tf.outputs.sampling.base.Candidate, features: Dict[str, tensorflow.python.framework.ops.Tensor], training=False, testing=False) -> Tuple[merlin.models.tf.outputs.sampling.base.Candidate, merlin.models.tf.outputs.sampling.base.Candidate]

Method to sample negatives from self.negative_samplers

Parameters
  • positive (Candidate) – Positive candidates (ids and metadata)

  • features (TabularData) – Dictionary of input raw tensors

  • training (bool, optional) – Flag for train mode, by default False

  • testing (bool, optional) – Flag for test mode, by default False

Returns

A tuple of candidates: the provided positive ids and the sampled negative ids, each with its sampling probability attached.

Return type

Tuple[Candidate, Candidate]
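For the "in-batch" sampler used in the usage examples, a common approach is to reuse the positives of the whole batch as negatives for every row. The sketch below assumes that behavior, and also illustrates one plausible reading of the downscore_false_negatives and false_negative_score arguments from the __init__ signature: a sampled "negative" whose id equals the row's positive id is not a true negative, so its score is replaced by a large negative value. The helper names in_batch_negative_ids and mask_false_negatives, and this interpretation of the two arguments, are assumptions rather than the library's code.

```python
from typing import List

# Default value of false_negative_score in the __init__ signature.
FALSE_NEGATIVE_SCORE = -655.04


def in_batch_negative_ids(positive_ids: List[int]) -> List[List[int]]:
    # Every row receives all positive item ids of the batch as candidate negatives.
    return [list(positive_ids) for _ in positive_ids]


def mask_false_negatives(
    positive_ids: List[int],
    negative_ids: List[List[int]],
    negative_scores: List[List[float]],
    false_negative_score: float = FALSE_NEGATIVE_SCORE,
) -> List[List[float]]:
    # Replace the score of any sampled negative that is actually the row's positive item.
    return [
        [false_negative_score if neg_id == pos_id else score
         for neg_id, score in zip(row_ids, row_scores)]
        for pos_id, row_ids, row_scores in zip(positive_ids, negative_ids, negative_scores)
    ]


batch_positive_ids = [10, 20, 10]  # item 10 appears twice in this hypothetical batch
neg_ids = in_batch_negative_ids(batch_positive_ids)
scores = [[0.5, 0.2, 0.5], [0.1, 0.9, 0.1], [0.4, 0.3, 0.4]]  # hypothetical dot-product scores
masked = mask_false_negatives(batch_positive_ids, neg_ids, scores)
print(masked)  # [[-655.04, 0.2, -655.04], [0.1, -655.04, 0.1], [-655.04, 0.3, -655.04]]
```

Without this masking, duplicate items in a batch would be treated as negatives for their own positive, and the loss would push the model to score the correct item lower.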

embedding_lookup(ids: tensorflow.python.framework.ops.Tensor)
to_dataset(gpu=None) -> merlin.io.dataset.Dataset
property has_candidate_weights: bool
property keys: List[str]
set_negative_samplers(negative_samplers: Union[merlin.models.tf.outputs.sampling.base.CandidateSampler, Sequence[Union[merlin.models.tf.outputs.sampling.base.CandidateSampler, str]], str])
get_config()
classmethod from_config(config)