Source code for

# Copyright (c) 2021, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, Optional, Sequence, Union

import tensorflow as tf
from keras.utils import tf_inspect
from tensorflow.python.ops import embedding_ops

from import ItemSampler
from import Block, BlockType, EmbeddingWithMetadata, PredictionOutput
from import ParallelBlock
from import TabularAggregationType
from import ModelBlock
from import L2Norm
from import TabularData
from import (
from merlin.models.utils.constants import MIN_FLOAT
from merlin.schema import Schema

# Module-level logger, registered under the fixed name "merlin_models"
# (not under this module's own name).
LOG = logging.getLogger("merlin_models")

class TowerBlock(ModelBlock):
    """Wrapper block around a single tower (query or item) of a retrieval model.

    The class adds nothing on top of ``ModelBlock``; it exists as a distinct
    type for the tower blocks.
    """


class RetrievalMixin:
    """Interface for retrieval models that expose their query and item towers.

    Both accessors are placeholders that always raise; the class mixing this
    in is expected to override them.
    """

    def query_block(self) -> TowerBlock:
        """Return the query tower of a RetrievalModel instance.

        Raises
        ------
        NotImplementedError
            Always, unless the method is overridden.
        """
        raise NotImplementedError

    def item_block(self) -> TowerBlock:
        """Return the item tower of a RetrievalModel instance.

        Raises
        ------
        NotImplementedError
            Always, unless the method is overridden.
        """
        raise NotImplementedError

@tf.keras.utils.register_keras_serializable(package="merlin.models")
class DualEncoderBlock(ParallelBlock):
    """Parallel block holding the query tower and the item tower of a retrieval model.

    The two towers are wrapped in ``TowerBlock`` and registered as the
    ``"query"`` and ``"item"`` branches of the underlying ``ParallelBlock``.
    """

    def __init__(
        self,
        query_block: Block,
        item_block: Block,
        pre: Optional[BlockType] = None,
        post: Optional[BlockType] = None,
        aggregation: Optional[TabularAggregationType] = None,
        schema: Optional[Schema] = None,
        name: Optional[str] = None,
        strict: bool = False,
        l2_normalization: bool = False,
        **kwargs,
    ):
        """Prepare the Query and Item towers of a Retrieval block

        Parameters
        ----------
        query_block : Block
            The `Block` instance that combines user features
        item_block : Block
            The `Block` instance that combines items features.
        pre : Optional[BlockType], optional
            Optional `Block` instance to apply before the `call` method of the
            Two-Tower block
        post : Optional[BlockType], optional
            Optional `Block` instance to apply on both outputs of Two-tower model
        aggregation : Optional[TabularAggregationType], optional
            The Aggregation operation to apply after processing the `call`
            method to output a single Tensor.
        schema : Optional[Schema], optional
            The `Schema` object with the input features.
        name : Optional[str], optional
            Name of the layer.
        strict : bool, optional
            If enabled, check that the input of the ParallelBlock instance is a
            dictionary.
        l2_normalization : bool
            Apply L2 normalization to the user and item representations before
            computing dot interactions. Defaults to False.
        """
        if l2_normalization:
            # Append the normalization to each tower before it gets wrapped.
            query_block = query_block.connect(L2Norm())
            item_block = item_block.connect(L2Norm())

        self._query_block = TowerBlock(query_block)
        self._item_block = TowerBlock(item_block)

        branches = {"query": self._query_block, "item": self._item_block}

        super().__init__(
            branches,
            pre=pre,
            post=post,
            aggregation=aggregation,
            schema=schema,
            name=name,
            strict=strict,
            **kwargs,
        )

    def query_block(self) -> TowerBlock:
        """Return the tower stored for the query side."""
        return self._query_block

    def item_block(self) -> TowerBlock:
        """Return the tower stored for the item side."""
        return self._item_block

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild the block from its Keras config.

        The instance is created as a plain ``ParallelBlock`` and then re-typed
        to ``cls``, so ``__init__`` of this class does not run on this path.

        Parameters
        ----------
        config : dict
            Configuration handed to ``cls.parse_config``.
        custom_objects : optional
            Forwarded unchanged to ``cls.parse_config``.

        Returns
        -------
        An instance whose class is ``cls``.
        """
        inputs, config = cls.parse_config(config, custom_objects)
        output = ParallelBlock(inputs, **config)
        output.__class__ = cls

        # __init__ is bypassed above, so without this step the attributes read by
        # query_block() / item_block() would be missing on restored instances.
        # NOTE(review): assumes parse_config returns the branches as a dict keyed
        # "query"/"item", mirroring what __init__ passes to ParallelBlock -- confirm.
        if isinstance(inputs, dict):
            if "query" in inputs:
                output._query_block = inputs["query"]
            if "item" in inputs:
                output._item_block = inputs["item"]

        return output
@Block.registry.register_with_multiple_names("item_retrieval_scorer")
@tf.keras.utils.register_keras_serializable(package="merlin_models")
class ItemRetrievalScorer(Block):
    """Dot-product scorer for item retrieval.

    Expects query/user and item embeddings as input and uses the dot product
    to score the positive item (inputs["item"]) and, when training or testing,
    the negative items provided by the samplers.

    Parameters
    ----------
    samplers: Sequence[ItemSampler], optional
        Item samplers that provide the negative samples. A single sampler that
        is not wrapped in a list or tuple is accepted as well.
    sampling_downscore_false_negatives: bool, optional
        Identify false negatives (sampled item ids equal to the positive item)
        and downscore them to `sampling_downscore_false_negatives_value`,
        by default True
    sampling_downscore_false_negatives_value: float, optional
        Score given to false negatives when
        `sampling_downscore_false_negatives=True`, by default `MIN_FLOAT`
    item_id_feature_name: str
        Name of the column containing the item ids, by default "item_id"
    item_domain: str
        Key used to fetch the item embedding table from the block context
        (`self.context.get_embedding`), by default "item_id"
    query_name: str
        Input key of the query/user embeddings, by default "query"
    item_name: str
        Input key of the item embeddings, by default "item"
    cache_query: bool
        Store the query embeddings in the "query" entry of the context,
        by default False
    sampled_softmax_mode: bool
        Use sampled softmax for scoring, by default False
    store_negative_ids: bool
        Return the negative item ids as part of the output, by default False
    """
def __init__(
    self,
    samplers: Sequence[ItemSampler] = (),
    sampling_downscore_false_negatives=True,
    sampling_downscore_false_negatives_value: float = MIN_FLOAT,
    item_id_feature_name: str = "item_id",
    item_domain: str = "item_id",
    query_name: str = "query",
    item_name: str = "item",
    cache_query: bool = False,
    sampled_softmax_mode: bool = False,
    store_negative_ids: bool = False,
    **kwargs,
):
    """Initializes the `ItemRetrievalScorer` class.

    Stores every argument on the instance, normalizes `samplers` to a
    list/tuple and finally derives the features required by the samplers.
    Remaining keyword arguments go to the parent constructor.
    """
    super().__init__(**kwargs)

    # False-negative handling.
    self.downscore_false_negatives = sampling_downscore_false_negatives
    self.false_negatives_score = sampling_downscore_false_negatives_value

    # Names used to address features, embeddings and the two towers.
    self.item_id_feature_name = item_id_feature_name
    self.item_domain = item_domain
    self.query_name = query_name
    self.item_name = item_name

    # Behavior switches.
    self.cache_query = cache_query
    self.store_negative_ids = store_negative_ids

    # A bare sampler is accepted too: wrap it in a one-element tuple.
    if isinstance(samplers, (list, tuple)):
        self.samplers = samplers
    else:
        self.samplers = (samplers,)  # type: ignore
    self.sampled_softmax_mode = sampled_softmax_mode

    # Must run last: it reads the sampler list and the flags set above.
    self.set_required_features()
def build(self, input_shapes):
    """Create the cached-query weight (for dict inputs) and build the block.

    Parameters
    ----------
    input_shapes: tuple or dict
        Shape of the input tensor. When a dict is given, the entry under
        `self.query_name` is used as the shape of a non-trainable,
        zero-initialized float32 weight named "query" added to the context.
    """
    # NOTE(review): the nesting was reconstructed from a flattened source; the
    # weight creation is kept inside the dict branch because the query shape is
    # only available there.
    if isinstance(input_shapes, dict):
        self.context.add_weight(
            name="query",
            shape=input_shapes[self.query_name],
            dtype=tf.float32,
            trainable=False,
            initializer=tf.keras.initializers.Zeros(),
        )

    super().build(input_shapes)
def _check_input_from_two_tower(self, inputs):
    """Validate that `inputs` holds exactly the query and the item entries.

    Parameters
    ----------
    inputs: dict
        Dictionary of inputs.

    Raises
    ------
    ValueError
        If the keys of `inputs` differ from `{self.query_name, self.item_name}`.
    """
    expected_names = [self.query_name, self.item_name]
    if set(inputs.keys()) == set(expected_names):
        return

    raise ValueError(
        f"Wrong input-names, expected: {expected_names} "
        f"but got: {inputs.keys()}"
    )
def call(
    self,
    inputs: Union[tf.Tensor, TabularData],
    training: bool = True,
    testing: bool = False,
    **kwargs,
) -> Union[tf.Tensor, TabularData]:
    """Score the positive item with a dot product, outside training/testing.

    For the sampled-softmax mode, logits are computed by multiplying the
    query vector and the item embeddings matrix
    (self.context.get_embedding(self.item_domain)).

    Parameters
    ----------
    inputs : Union[tf.Tensor, TabularData]
        Dict with the query and item embeddings
        (e.g. `{"query": <emb>, "item": <emb>}`), or a tensor in
        sampled-softmax mode.
    training : bool, optional
        Flag that indicates whether in training mode, by default True
    testing : bool, optional
        Flag that indicates whether in testing mode, by default False

    Returns
    -------
    Union[tf.Tensor, TabularData]
        The unchanged `inputs` when `training` or `testing` is set; otherwise
        a tensor with the scores.
    """
    # NOTE(review): statement nesting was reconstructed from a flattened source.
    if self.cache_query:
        # enabled only during top-k evaluation
        query = inputs[self.query_name]
        cached_rows = tf.shape(self.context["query"])[0]
        # pad with zeros to match shape of initial query variable
        missing_rows = cached_rows - tf.shape(query)[0]
        if missing_rows > 0:
            query = tf.pad(query, [[0, missing_rows], [0, 0]])
        query = query[:cached_rows]
        self.context["query"].assign(tf.cast(query, tf.float32))

    # Scoring during training/testing is left to `call_outputs`.
    if training or testing:
        return inputs

    if self.sampled_softmax_mode:
        return self._get_logits_for_sampled_softmax(inputs)

    self._check_input_from_two_tower(inputs)
    query_times_item = tf.multiply(inputs[self.query_name], inputs[self.item_name])
    return tf.reduce_sum(query_times_item, keepdims=True, axis=-1)
[docs] @tf.function def call_outputs( self, outputs: PredictionOutput, features: Dict[str, tf.Tensor] = None, training=True, testing=False, **kwargs, ) -> "PredictionOutput": """Score the positive item and, when `training` or `testing` is set, the sampled negative items, using dot products with the query embeddings. Parameters ---------- outputs : PredictionOutput Its `predictions` and `targets` attributes are read. Outside sampled-softmax mode `predictions` is indexed by `self.query_name` and `self.item_name`; in sampled-softmax mode it must be a tensor and the item embeddings are looked up from `targets`. features : Dict[str, tf.Tensor], optional Input features. Read for the positive item ids when `targets` is not a tensor (and sampled-softmax mode is off) and for the metadata handed to the samplers. training : bool, optional Flag that indicates whether in training mode, by default True testing : bool, optional Flag that indicates whether in testing mode, by default False Returns ------- PredictionOutput Holds the scores and the targets prepared for the loss and metrics, together with `positive_item_ids`, `valid_negatives_mask` and `negative_item_ids`. When negatives are sampled, the positive score is in the first column and the negative scores in the subsequent columns. """ targets, predictions = outputs.targets, outputs.predictions valid_negatives_mask = None if self.sampled_softmax_mode or isinstance(targets, tf.Tensor): positive_item_ids = targets else: positive_item_ids = tf.squeeze(features[self.item_id_feature_name]) neg_items_ids = None if training or testing: assert ( len(self.samplers) > 0 ), "At least one sampler is required by ItemRetrievalScorer for negative sampling" if self.sampled_softmax_mode: predictions = self._prepare_query_item_vectors_for_sampled_softmax( predictions, targets ) batch_items_embeddings = predictions[self.item_name] if self.sampled_softmax_mode: batch_items_metadata = {self.item_id_feature_name: positive_item_ids} else: batch_items_metadata = { feat_name: tf.squeeze(features[feat_name]) for feat_name in self._required_features } positive_scores = tf.reduce_sum( tf.multiply(predictions[self.query_name], predictions[self.item_name]), keepdims=True, axis=-1, ) neg_items_embeddings_list = [] neg_items_ids_list = [] # Adds items from the current batch into samplers and sample a number of negatives for sampler in self.samplers: input_data = EmbeddingWithMetadata(batch_items_embeddings, 
batch_items_metadata) sampling_kwargs = {"training": training} if "item_weights" in tf_inspect.getargspec( sampling_kwargs["item_weights"] = self.context.get_embedding(self.item_domain) neg_items = sampler(input_data.__dict__, **sampling_kwargs) if tf.shape(neg_items.embeddings)[0] > 0: # Accumulates sampled negative items from all samplers neg_items_embeddings_list.append(neg_items.embeddings) if self.downscore_false_negatives: neg_items_ids_list.append(neg_items.metadata[self.item_id_feature_name]) else: LOG.warn( f"The sampler {type(sampler).__name__} returned no samples for this batch." ) if len(neg_items_embeddings_list) == 0: raise Exception(f"No negative items where sampled from samplers {self.samplers}") elif len(neg_items_embeddings_list) == 1: neg_items_embeddings = neg_items_embeddings_list[0] else: neg_items_embeddings = tf.concat(neg_items_embeddings_list, axis=0) negative_scores = tf.linalg.matmul( predictions[self.query_name], neg_items_embeddings, transpose_b=True ) if self.downscore_false_negatives or self.store_negative_ids: if isinstance(targets, tf.Tensor): positive_item_ids = targets else: positive_item_ids = tf.squeeze(features[self.item_id_feature_name]) if len(neg_items_ids_list) == 1: neg_items_ids = neg_items_ids_list[0] else: neg_items_ids = tf.concat(neg_items_ids_list, axis=0) negative_scores, valid_negatives_mask = rescore_false_negatives( positive_item_ids, neg_items_ids, negative_scores, self.false_negatives_score ) predictions = tf.concat([positive_scores, negative_scores], axis=-1) # To ensure that the output is always fp32, avoiding numerical # instabilities with mixed_float16 policy predictions = tf.cast(predictions, tf.float32) assert isinstance(predictions, tf.Tensor), "Predictions must be a tensor" # prepare targets for computing the loss and metrics if self.sampled_softmax_mode and not training: # Converts target ids to one-hot representation num_classes = tf.shape(predictions)[-1] targets_one_hot = 
tf.one_hot(tf.reshape(targets, (-1,)), num_classes) return PredictionOutput( predictions, targets_one_hot, positive_item_ids=positive_item_ids, valid_negatives_mask=valid_negatives_mask, negative_item_ids=neg_items_ids, ) else: # Positives in the first column and negatives in the subsequent columns targets = tf.concat( [ tf.ones([tf.shape(predictions)[0], 1], dtype=predictions.dtype), tf.zeros( [tf.shape(predictions)[0], tf.shape(predictions)[1] - 1], dtype=predictions.dtype, ), ], axis=1, ) return PredictionOutput( predictions, targets, positive_item_ids=positive_item_ids, valid_negatives_mask=valid_negatives_mask, negative_item_ids=neg_items_ids, )
def _get_logits_for_sampled_softmax(self, inputs):
    """Multiply the query tensor with the transposed item embedding table.

    Parameters
    ----------
    inputs : tf.Tensor
        Query vectors.

    Returns
    -------
    tf.Tensor
        Result of `tf.matmul(inputs, transpose(embedding_table))`, where the
        table is fetched from the context under `self.item_domain`.

    Raises
    ------
    ValueError
        If `inputs` is not a `tf.Tensor`.
    """
    if isinstance(inputs, tf.Tensor):
        item_table = self.context.get_embedding(self.item_domain)
        return tf.matmul(inputs, tf.transpose(item_table))

    raise ValueError(
        f"Inputs to the Sampled Softmax block should be tensors, got {type(inputs)}"
    )


def _prepare_query_item_vectors_for_sampled_softmax(
    self, predictions: tf.Tensor, targets: tf.Tensor
):
    """Pair the query vectors with the embeddings of the target items.

    Parameters
    ----------
    predictions : tf.Tensor
        Query vectors.
    targets : tf.Tensor
        Ids used to look up rows of the item embedding table.

    Returns
    -------
    dict
        `{self.query_name: predictions, self.item_name: <looked-up embeddings>}`

    Raises
    ------
    ValueError
        If `predictions` is not a `tf.Tensor`.
    """
    if not isinstance(predictions, tf.Tensor):
        raise ValueError(
            f"Inputs to the Sampled Softmax block should be tensors, got {type(predictions)}"
        )

    # extract positive items embeddings
    item_table = self.context.get_embedding(self.item_domain)
    positive_item_vectors = embedding_ops.embedding_lookup(item_table, targets)

    return {self.query_name: predictions, self.item_name: positive_item_vectors}
def set_required_features(self):
    """Collect the feature names the scorer needs and store them.

    The result is written to `self._required_features` as a list without
    duplicates. It contains the item-id feature when false negatives are
    downscored, plus every feature listed in `required_features` of each
    sampler.
    """
    feature_names = set()

    if self.downscore_false_negatives:
        feature_names.add(self.item_id_feature_name)

    for sampler in self.samplers:
        feature_names.update(sampler.required_features)

    self._required_features = list(feature_names)
def get_config(self):
    """Returns the configuration of the block as a dictionary.

    Starts from the parent configuration, serializes the samplers through
    `maybe_serialize_keras_objects` and then records every constructor
    argument under its constructor name.

    Returns
    -------
    dict
        The configuration of the block.
    """
    base_config = super().get_config()
    config = maybe_serialize_keras_objects(self, base_config, ["samplers"])

    # Keys are the constructor argument names, not the attribute names.
    config.update(
        {
            "sampling_downscore_false_negatives": self.downscore_false_negatives,
            "sampling_downscore_false_negatives_value": self.false_negatives_score,
            "item_id_feature_name": self.item_id_feature_name,
            "item_domain": self.item_domain,
            "query_name": self.query_name,
            "item_name": self.item_name,
            "cache_query": self.cache_query,
            "sampled_softmax_mode": self.sampled_softmax_mode,
            "store_negative_ids": self.store_negative_ids,
        }
    )

    return config
@classmethod
def from_config(cls, config):
    """Creates a new instance of the class from its config.

    The "samplers" entry is passed through `maybe_deserialize_keras_objects`
    before the parent `from_config` builds the instance.

    Parameters
    ----------
    config: dict
        A dictionary, typically the output of get_config.

    Returns
    -------
    A new instance of the `ItemRetrievalScorer` class.
    """
    restored_config = maybe_deserialize_keras_objects(config, ["samplers"])
    return super().from_config(restored_config)