#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import warnings
from typing import List, Optional, Protocol, Tuple, Union, runtime_checkable

import tensorflow as tf
from tensorflow.keras.layers import Layer

import merlin.io
from merlin.models.tf.core.prediction import Prediction
from merlin.models.tf.inputs.embedding import EmbeddingTable
from merlin.models.tf.outputs.base import DotProduct, MetricsFn, ModelOutput
from merlin.models.tf.outputs.classification import (
    CategoricalTarget,
    EmbeddingTablePrediction,
    default_categorical_prediction_metrics,
)
from merlin.models.tf.outputs.sampling.base import (
    Candidate,
    ItemSamplersType,
    parse_negative_samplers,
)
from merlin.models.tf.typing import TabularData
from merlin.models.tf.utils import tf_utils
from merlin.models.utils import schema_utils
from merlin.models.utils.constants import MIN_FLOAT
from merlin.schema import ColumnSchema, Schema

LOG = logging.getLogger("merlin_models")


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class ContrastiveOutput(ModelOutput):
"""Categorical output
Parameters
----------
to_call: Union[Schema, ColumnSchema,
EmbeddingTable, 'CategoricalTarget',
'EmbeddingTablePrediction', 'DotProduct']
The target feature to predict. To perform weight-tying [1] technique, you should provide
the `EmbeddingTable` or `EmbeddingTablePrediction` related to the
target feature.
negative_samplers: ItemSamplersType
List of samplers for negative sampling,
by default None
pre: Optional[Block], optional
Optional block to transform predictions before computing the binary logits,
by default None
post: Optional[Block], optional
Optional block to transform the binary logits,
by default None
logits_temperature: float, optional
Parameter used to reduce model overconfidence, so that logits / T.
by default 1
name: str, optional
The name of the task, by default None
default_loss: Union[str, tf.keras.losses.Loss], optional
Default loss to use for categorical-classification
by default 'categorical_crossentropy'
get_default_metrics: Callable, optional
A function returning the list of default metrics
to use for categorical-classification
store_negative_ids: bool, optional
Whether to store negative ids for post-processing
by default False
logq_sampling_correction: bool, optional
The LogQ correction is a standard technique for
sampled softmax and popularity-biased sampling.
It subtracts from the logits the
log expected count/prob of the positive and
negative samples in order to not overpenalize the
popular items for being sampled more often as negatives.
It can be enabled if a single negative sampler is provided
and if it provides the sampler provides the
sampling probabilities (i.e. implements with_sampling_probs()).
Another alternative for performing logQ correction is using
ContrastiveOutput(..., post=PopularityLogitsCorrection(item_frequencies)),
where you need to provide the items frequency probability distribution (prior).
Default is False.
References:
----------
[1] Hakan Inan, Khashayar Khosravi, and Richard Socher. 2016. Tying word vectors
and word classifiers: A loss framework for language modeling. arXiv preprint
arXiv:1611.01462 (2016).
Notes:
----------
In case `to_call` is set as `DotProduct()`, schema of target couldn't be inferred
therefore, the user should feed a schema only with ITEM_ID feature as schema arg,
which is treated as a `kwargs` arg below.
Example usage::
outputs=mm.ContrastiveOutput(
to_call=DotProduct(),
negative_samplers="in-batch",
schema=schema.select_by_tag(Tags.ITEM_ID),
logits_temperature = 0.2,
)
The schema arg is not needed when we pass the schema to `to_call` arg.
Example usage::
outputs=mm.ContrastiveOutput(
to_call=schema.select_by_tag(Tags.ITEM_ID),
negative_samplers="in-batch",
logits_temperature = 0.2,
)
"""

    def __init__(
        self,
        to_call: Union[
            Schema,
            ColumnSchema,
            EmbeddingTable,
            CategoricalTarget,
            EmbeddingTablePrediction,
            DotProduct,
        ],
        negative_samplers: ItemSamplersType,
        target_name: str = None,
        pre: Optional[Layer] = None,
        post: Optional[Layer] = None,
        logits_temperature: float = 1.0,
        name: Optional[str] = None,
        default_loss: Union[str, tf.keras.losses.Loss] = "categorical_crossentropy",
        default_metrics_fn: MetricsFn = default_categorical_prediction_metrics,
        downscore_false_negatives=True,
        false_negative_score: float = MIN_FLOAT,
        query_name: str = "query",
        candidate_name: str = "candidate",
        store_negative_ids: bool = False,
        logq_sampling_correction: Optional[bool] = False,
        **kwargs,
    ):
        self.col_schema = None
        _to_call = None
        if to_call is not None:
            if isinstance(to_call, (Schema, ColumnSchema)):
                if isinstance(to_call, Schema):
                    if len(to_call) == 1:
                        to_call = to_call.first
                    else:
                        raise ValueError("to_call must be a single column schema")
                self.col_schema = to_call
                _to_call = CategoricalTarget(to_call)
                target_name = target_name or to_call.name
            elif isinstance(to_call, EmbeddingTable):
                _to_call = EmbeddingTablePrediction(to_call)
                target_name = _to_call.table.col_schema.name
                self.col_schema = _to_call.table.col_schema
            else:
                _to_call = to_call
                if "schema" in kwargs:
                    self.col_schema = kwargs.pop("schema").first
                if not self.col_schema:
                    raise ValueError(
                        "schema of target couldn't be inferred, please provide `schema=...`"
                    )

        self.negative_samplers = parse_negative_samplers(negative_samplers)
        self.downscore_false_negatives = downscore_false_negatives
        self.false_negative_score = false_negative_score
        self.query_name = query_name
        self.candidate_name = candidate_name
        self.store_negative_ids = store_negative_ids
        self.logq_sampling_correction = logq_sampling_correction
        self.target_name = kwargs.pop("target", target_name)

        super().__init__(
            to_call=_to_call,
            default_loss=default_loss,
            default_metrics_fn=default_metrics_fn,
            name=name,
            target=self.target_name,
            pre=pre,
            post=post,
            logits_temperature=logits_temperature,
            **kwargs,
        )
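
    # Weight-tying sketch (illustrative only; how the item-id `EmbeddingTable` is obtained
    # depends on how the input block was built and is an assumption here):
    #
    #     item_id_table = ...  # EmbeddingTable of the item-id feature, reused from the inputs
    #     output = mm.ContrastiveOutput(
    #         to_call=item_id_table,
    #         negative_samplers="in-batch",
    #     )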

    def build(self, input_shape):
        if (
            isinstance(input_shape, dict)
            and all(key in input_shape for key in self.keys)
            and not isinstance(self.to_call, DotProduct)
        ):
            self.to_call = DotProduct(*self.keys)

        super().build(input_shape)

    def call(self, inputs, features=None, targets=None, training=False, testing=False):
        call_kwargs = dict(features=features, targets=targets, training=training, testing=testing)

        if training or testing:
            if self.has_candidate_weights and targets is None:
                return tf_utils.call_layer(self.to_call, inputs, **call_kwargs)

            return self.call_contrastive(inputs, **call_kwargs)

        return tf_utils.call_layer(self.to_call, inputs, **call_kwargs)
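
    # Note on dispatching: during training/evaluation, `call` scores contrastively via
    # `call_contrastive` (positive candidate plus sampled negatives). At inference time,
    # or when candidate weights are available but no targets were provided, the wrapped
    # `to_call` layer scores the query against the candidates directly.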

    def call_contrastive(self, inputs, features, targets, training=False, testing=False):
        if isinstance(inputs, dict) and self.query_name in inputs:
            query_embedding = inputs[self.query_name]
        elif isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
            query_embedding = inputs
        else:
            raise ValueError("Couldn't infer query embedding")

        is_ragged = isinstance(query_embedding, tf.RaggedTensor)
        if is_ragged:
            # Get flat values of the ragged tensor
            original_query_embedding = query_embedding
            query_embedding = query_embedding.flat_values

        if self.has_candidate_weights:
            positive_id = targets
            if isinstance(targets, dict):
                positive_id = targets[self.col_schema.name]
            positive_embedding = self.embedding_lookup(positive_id)
        else:
            positive_id = features[self.col_schema.name]
            positive_embedding = inputs[self.candidate_name]

        if isinstance(positive_id, tf.RaggedTensor):
            # Select positive candidates at masked positions
            target_mask = positive_id._keras_mask.with_row_splits_dtype(
                positive_id.row_splits.dtype
            )
            # Flatten target tensor to have the same shape as the query tensor
            positive_id = tf.ragged.boolean_mask(positive_id, target_mask)
            original_target = positive_id
            positive_id = positive_id.flat_values
            positive_embedding = tf.ragged.boolean_mask(
                positive_embedding, target_mask
            ).flat_values

        positive = Candidate(id=positive_id, metadata={**features}).with_embedding(
            positive_embedding
        )

        negative, positive = self.sample_negatives(
            positive, features, training=training, testing=testing
        )
        if self.has_candidate_weights and (
            positive.id.shape != negative.id.shape or positive != negative
        ):
            negative = negative.with_embedding(self.embedding_lookup(negative.id))

        logits = self.outputs(query_embedding, positive, negative)
        if is_ragged:
            # `copy_with_updates` returns a new Prediction; re-assign it so the ragged
            # structure of the inputs is restored in the outputs and targets.
            logits = logits.copy_with_updates(
                outputs=original_query_embedding.with_flat_values(logits.outputs),
                targets=original_target.with_flat_values(logits.targets),
            )

        return logits
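
    # Note: for sequential (session-based) inputs the query embeddings arrive as a
    # tf.RaggedTensor. Scoring is done on the flat values and the ragged structure is
    # restored afterwards, so each masked position gets its own contrastive logits.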

    def outputs(
        self, query_embedding: tf.Tensor, positive: Candidate, negative: Candidate
    ) -> Prediction:
        """Compute the dot product between the query embeddings and the
        positive/negative candidates

        Parameters
        ----------
        query_embedding : tf.Tensor
            Tensor of query embeddings.
        positive : Candidate
            Stores the ids and metadata (such as embeddings) of the positive candidates.
        negative : Candidate
            Stores the ids and metadata (such as embeddings) of the sampled negative candidates.

        Returns
        -------
        Prediction
            A Prediction object with the prediction scores, the targets and,
            when `store_negative_ids` is enabled, the negative candidate ids.
        """
        if not positive.has_embedding:
            raise ValueError("Positive candidate must have an embedding")
        if not negative.has_embedding:
            raise ValueError("Negative candidate must have an embedding")

        # Apply dot-product
        negative_scores = tf.linalg.matmul(query_embedding, negative.embedding, transpose_b=True)
        positive_scores = tf.reduce_sum(
            tf.multiply(query_embedding, positive.embedding), keepdims=True, axis=-1
        )

        if self.logq_sampling_correction:
            if positive.sampling_prob is None or negative.sampling_prob is None:
                warnings.warn(
                    "The logQ sampling correction is enabled, but sampling probs were not found "
                    "for both positive and negative candidates",
                    RuntimeWarning,
                )
            epsilon = 1e-16
            positive_scores -= tf.math.log(positive.sampling_prob + epsilon)
            negative_scores -= tf.math.log(tf.transpose(negative.sampling_prob + epsilon))

        if self.downscore_false_negatives:
            negative_scores, _ = tf_utils.rescore_false_negatives(
                positive.id, negative.id, negative_scores, self.false_negative_score
            )

        outputs = tf.concat([positive_scores, negative_scores], axis=-1)

        # Ensure that the output is always fp32, to avoid numerical
        # instabilities with the mixed_float16 policy
        outputs = tf.cast(outputs, tf.float32)

        targets = tf.concat(
            [
                tf.ones([tf.shape(outputs)[0], 1], dtype=outputs.dtype),
                tf.zeros(
                    [tf.shape(outputs)[0], tf.shape(outputs)[1] - 1],
                    dtype=outputs.dtype,
                ),
            ],
            axis=1,
        )

        if self.store_negative_ids:
            return Prediction(outputs, targets, negative_candidate_ids=negative.id)
        return Prediction(outputs, targets)
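
    # Shape sketch for `outputs` (illustrative): with a query embedding of shape
    # [batch_size, dim], a positive embedding of shape [batch_size, dim] and negative
    # embeddings of shape [num_negatives, dim], the scores become
    #     positive_scores: [batch_size, 1]
    #     negative_scores: [batch_size, num_negatives]
    # so the concatenated logits have shape [batch_size, 1 + num_negatives], with the
    # one-hot targets marking column 0 (the positive candidate) as the label.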

    def sample_negatives(
        self,
        positive: Candidate,
        features: TabularData,
        training=False,
        testing=False,
    ) -> Tuple[Candidate, Candidate]:
        """Sample negatives from `self.negative_samplers`

        Parameters
        ----------
        positive : Candidate
            Positive candidates (ids and metadata)
        features : TabularData
            Dictionary of input raw tensors
        training : bool, optional
            Flag for train mode, by default False
        testing : bool, optional
            Flag for test mode, by default False

        Returns
        -------
        Tuple[Candidate, Candidate]
            Tuple with the sampled negative candidates and the provided positive
            candidates, both enriched with their sampling probabilities
        """
        sampling_kwargs = {"training": training, "testing": testing, "features": features}

        candidates: List[Candidate] = []

        if self.logq_sampling_correction and len(self.negative_samplers) > 1:
            raise ValueError(
                "It is only possible to apply logQ sampling correction "
                "(logq_sampling_correction=True) when only one negative sampler is provided."
            )

        for sampler in self.negative_samplers:
            neg_samples: Candidate = tf_utils.call_layer(sampler, positive, **sampling_kwargs)
            # Adds to the positive and negative candidates their sampling probs from the sampler
            positive = sampler.with_sampling_probs(positive)
            neg_samples = sampler.with_sampling_probs(neg_samples)

            if neg_samples.id is not None:
                candidates.append(neg_samples)
            else:
                LOG.warning(
                    f"The sampler {type(sampler).__name__} returned no samples for this batch."
                )

        if len(candidates) == 0:
            raise Exception(
                f"No negative items were sampled from samplers {self.negative_samplers}"
            )

        negatives = candidates[0]
        if len(candidates) > 1:
            for neg in candidates[1:]:
                negatives += neg

        return negatives, positive
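
    # Note: when more than one negative sampler is configured, the candidates sampled by
    # each of them are merged into a single `Candidate` of negatives for the batch.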

    def embedding_lookup(self, ids: tf.Tensor):
        return self.to_call.embedding_lookup(tf.squeeze(ids, axis=-1))

    def to_dataset(self, gpu=None) -> merlin.io.Dataset:
        return merlin.io.Dataset(tf_utils.tensor_to_df(self.to_call.embeddings, gpu=gpu))

    @property
    def has_candidate_weights(self) -> bool:
        if isinstance(self.to_call, DotProduct):
            return False

        return isinstance(self.to_call, LookUpProtocol)

    @property
    def keys(self) -> List[str]:
        return [self.query_name, self.candidate_name]

    def set_negative_samplers(self, negative_samplers: ItemSamplersType):
        if negative_samplers is not None:
            negative_samplers = parse_negative_samplers(negative_samplers)
        self.negative_samplers = negative_samplers

    def get_config(self):
        config = super().get_config()
        config = tf_utils.maybe_serialize_keras_objects(self, config, ["negative_samplers"])
        config["target"] = self.target_name
        config["downscore_false_negatives"] = self.downscore_false_negatives
        config["false_negative_score"] = self.false_negative_score
        config["query_name"] = self.query_name
        config["candidate_name"] = self.candidate_name
        config["store_negative_ids"] = self.store_negative_ids
        config["schema"] = schema_utils.schema_to_tensorflow_metadata_json(
            Schema([self.col_schema])
        )

        return config

    @classmethod
    def from_config(cls, config):
        config["schema"] = schema_utils.tensorflow_metadata_json_to_schema(config["schema"])
        config = tf_utils.maybe_deserialize_keras_objects(config, ["negative_samplers"])

        return super().from_config(config)


@runtime_checkable
class LookUpProtocol(Protocol):
    """Protocol for embedding lookup layers"""

    @property
    def embeddings(self):
        pass

    def embedding_lookup(self, inputs, **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        pass
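

# Usage sketch (illustrative only, never executed on import). It mirrors the examples
# in the `ContrastiveOutput` docstring; `mm`, `schema` and `Tags` are assumed to come
# from the surrounding application code.
#
#     import merlin.models.tf as mm
#     from merlin.schema import Tags
#
#     # In-batch negative sampling over the item-id vocabulary:
#     output = mm.ContrastiveOutput(
#         to_call=schema.select_by_tag(Tags.ITEM_ID),
#         negative_samplers="in-batch",
#         logits_temperature=0.2,
#     )
#     # `output` can then be used as the output block of a retrieval model.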