# Source code for merlin.models.tf.outputs.base

#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import types as python_types
import warnings
from typing import Callable, Optional, Sequence, Union

import tensorflow as tf
from keras.utils import generic_utils
from keras.utils.generic_utils import to_snake_case
from tensorflow.keras.layers import Layer

from merlin.models.tf.core.base import name_fn
from merlin.models.tf.core.combinators import ParallelBlock
from merlin.models.tf.core.prediction import Prediction
from merlin.models.tf.transforms.bias import LogitsTemperatureScaler
from merlin.models.tf.utils import tf_utils

# Signature of a zero-argument factory that returns the default metrics of an output.
MetricsFn = Callable[[], Sequence[tf.keras.metrics.Metric]]

# Alias for a model output: either a `ModelOutput` (forward reference, the class
# is defined further down in this module) or a `ParallelBlock`.
ModelOutputType = Union["ModelOutput", ParallelBlock]


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class ModelOutput(Layer):
    """Base-class for prediction blocks.

    Parameters
    ----------
    to_call : Layer
        The layer to call in the forward-pass of the model
    default_loss: Union[str, tf.keras.losses.Loss]
        Default loss to set if the user does not specify one
    default_metrics_fn: Callable
        A zero-argument function returning the list of default metrics
        to set if the user does not specify any
    name: Optional[Text], optional
        Layer name. When None, the task name derived from ``target``
        (see ``get_task_name``) is used, by default None
    target: Optional[str], optional
        Label name, by default None
    pre: Optional[Layer], optional
        Optional layer to transform inputs before applying the prediction layer,
        by default None
    post: Optional[Layer], optional
        Optional layer to transform predictions after applying the prediction layer,
        by default None
    logits_temperature: float, optional
        Parameter used to reduce model overconfidence: logits are divided by T.
        A scaler layer is only created when it differs from 1, by default 1.
    **kwargs
        Forwarded to ``tf.keras.layers.Layer``. A ``logits_scaler`` entry, if
        present, is consumed here (this is how ``from_config`` restores a
        serialized scaler) and takes precedence over ``logits_temperature``.
    """

    def __init__(
        self,
        to_call: Layer,
        default_loss: Union[str, tf.keras.losses.Loss],
        default_metrics_fn: MetricsFn,
        name: Optional[str] = None,
        target: Optional[str] = None,
        pre: Optional[Layer] = None,
        post: Optional[Layer] = None,
        logits_temperature: float = 1.0,
        **kwargs,
    ):
        # Not a public argument: supplied back by `from_config` when a scaler was serialized.
        logits_scaler = kwargs.pop("logits_scaler", None)
        self.target = target
        self.full_name = self.get_task_name(self.target)

        super().__init__(name=name or self.full_name, **kwargs)

        self.to_call = to_call
        self.default_loss = default_loss
        self.default_metrics_fn = default_metrics_fn
        self.pre = pre
        self.post = post

        if logits_scaler is not None:
            # A restored scaler wins; keep `logits_temperature` in sync with it.
            self.logits_scaler = logits_scaler
            self.logits_temperature = logits_scaler.temperature
        else:
            self.logits_temperature = logits_temperature
            # No scaler is attached for the identity temperature; `__call__` and
            # `get_config` look the attribute up with a default of None.
            if logits_temperature != 1.0:
                self.logits_scaler = LogitsTemperatureScaler(logits_temperature)

    @property
    def task_name(self) -> str:
        """Name of the task (same as ``full_name``)."""
        return self.full_name

    def build(self, input_shape=None):
        """Builds the PredictionBlock.

        Builds ``pre`` (if any), then ``to_call`` and ``post`` (if any), threading
        the output shape of each stage into the next one.

        Parameters
        ----------
        input_shape : tf.TensorShape, optional
            The input shape, by default None
        """
        if self.pre is not None:
            self.pre.build(input_shape)
            input_shape = self.pre.compute_output_shape(input_shape)

        self.to_call.build(input_shape)
        input_shape = self.to_call.compute_output_shape(input_shape)

        if self.post is not None:
            self.post.build(input_shape)

        self.built = True

    def call(self, inputs, training=False, testing=False, **kwargs):
        """Applies ``to_call`` to ``inputs``, forwarding the mode flags and kwargs."""
        return tf_utils.call_layer(
            self.to_call, inputs, training=training, testing=testing, **kwargs
        )

    def compute_output_shape(self, input_shape):
        """Returns the output shape after ``pre`` (optional), ``to_call`` and ``post`` (optional)."""
        output_shape = input_shape
        if self.pre is not None:
            output_shape = self.pre.compute_output_shape(output_shape)

        output_shape = self.to_call.compute_output_shape(output_shape)

        if self.post is not None:
            output_shape = self.post.compute_output_shape(output_shape)

        return output_shape

    def __call__(self, inputs, *args, **kwargs):
        """Runs ``pre`` -> Keras layer call -> ``post`` -> optional logits scaler.

        ``pre`` and ``post`` are applied outside ``Layer.__call__`` (so outside
        ``call``); ``post`` additionally receives ``target_name=self.target``.
        """
        training = kwargs.get("training", False)
        testing = kwargs.get("testing", False)

        # call pre
        if self.pre is not None:
            inputs = tf_utils.call_layer(self.pre, inputs, **kwargs)

        # super call
        outputs = super(ModelOutput, self).__call__(inputs, *args, **kwargs)

        if self.post is not None:
            outputs = tf_utils.call_layer(self.post, outputs, target_name=self.target, **kwargs)

        if getattr(self, "logits_scaler", None):
            if isinstance(outputs, tf.Tensor):
                targets = kwargs.pop("targets", None)
                # When targets of several tasks are given, keep only this task's target.
                if isinstance(targets, dict) and self.target in targets:
                    targets = targets[self.target]
                # In training/evaluation the raw tensor is wrapped together with its
                # targets before scaling (presumably so the scaler can see both).
                if training or testing:
                    outputs = Prediction(outputs, targets)
            outputs = tf_utils.call_layer(self.logits_scaler, outputs, **kwargs)

        return outputs

    def get_default_metrics(self):
        """Returns the default metrics for this output.

        The base implementation calls ``default_metrics_fn``; subclasses may override it.
        """
        return self.default_metrics_fn()

    def create_default_metrics(self):
        """Creates the default metrics, prefixing each metric name with the task name.

        Returns
        -------
        Sequence[tf.keras.metrics.Metric]
            The metrics returned by ``get_default_metrics`` with their ``_name``
            rewritten to ``"<full_name>/<snake_case metric name>"``.
        """
        metrics = self.get_default_metrics()

        for metric in metrics:
            metric._name = self.full_name + "/" + to_snake_case(metric.name)

        return metrics

    def _serialize_function_to_config(self, inputs):
        """Serializes a callable into ``(payload, function_type, module)``.

        Python functions (including lambdas) are dumped as bytecode with type
        ``"lambda"``; other callables are stored by ``__name__`` with type ``"function"``.

        Note: This code is adapted from Keras source code of the
        [Lambda layer](https://github.com/keras-team/keras/blob/master/keras/layers/core/lambda_layer.py#L300)

        Raises
        ------
        ValueError
            If ``inputs`` is not callable.
        """
        if isinstance(inputs, python_types.LambdaType):
            output = generic_utils.func_dump(inputs)
            output_type = "lambda"
            module = inputs.__module__
        elif callable(inputs):
            output = inputs.__name__
            output_type = "function"
            module = inputs.__module__
        else:
            raise ValueError("Invalid input for serialization, type: %s " % type(inputs))

        return output, output_type, module

    @classmethod
    def _parse_function_from_config(
        cls, config, func_attr_name, module_attr_name, func_type_attr_name
    ):
        """De-serializes a callable previously written by ``_serialize_function_to_config``.

        Pops ``module_attr_name`` and ``func_type_attr_name`` from ``config`` (the
        dict is mutated) and returns the callable rebuilt from ``config[func_attr_name]``.

        Note: This code is adapted from Keras source code of the
        [Lambda layer](https://github.com/keras-team/keras/blob/master/keras/layers/core/lambda_layer.py#L350)

        Raises
        ------
        TypeError
            If the stored function type is neither ``"function"`` nor ``"lambda"``.
        """
        globs = globals().copy()
        module = config.pop(module_attr_name, None)
        if module in sys.modules:
            # Make the names of the defining module resolvable below.
            globs.update(sys.modules[module].__dict__)
        elif module is not None:
            # Note: we don't know the name of the function if it's a lambda.
            warnings.warn(
                "{} is not loaded, but {} uses it to restore `{}`. "
                "It may cause errors.".format(module, cls.__name__, func_attr_name),
                UserWarning,
                stacklevel=2,
            )
        function_type = config.pop(func_type_attr_name)
        if function_type == "function":
            # Custom/registered Keras objects are consulted first; `module_objects`
            # lets a plain name fall back to the defining module's namespace
            # (without it Keras has nothing to fall back to for unregistered names).
            function = generic_utils.deserialize_keras_object(
                config[func_attr_name],
                module_objects=globs,
                printable_module_name="default metrics function",
            )
        elif function_type == "lambda":
            # Unsafe deserialization from bytecode: only load configs from trusted sources.
            function = generic_utils.func_load(config[func_attr_name], globs=globs)
        else:
            supported_types = ["function", "lambda"]
            raise TypeError(
                f"Unsupported value for `function_type` argument. Received: "
                f"function_type={function_type}. Expected one of {supported_types}"
            )
        return function

    def get_config(self):
        """Returns the layer config, including the serialized ``default_metrics_fn``.

        ``default_loss`` is stored verbatim when it is a string and serialized as a
        Keras object otherwise, together with ``to_call``, ``pre``, ``post`` and
        ``logits_scaler``.
        """
        config = super(ModelOutput, self).get_config()
        function_config = self._serialize_function_to_config(self.default_metrics_fn)
        config.update(
            {
                "default_metrics_fn": function_config[0],
                "function_type": function_config[1],
                "module": function_config[2],
                "target": self.target,
            }
        )

        objects = [
            "to_call",
            "pre",
            "post",
            "logits_scaler",
        ]

        if isinstance(self.default_loss, str):
            config["default_loss"] = self.default_loss
        else:
            objects.append("default_loss")

        config = tf_utils.maybe_serialize_keras_objects(self, config, objects)

        return config

    @classmethod
    def get_task_name(cls, target_name: str) -> str:
        """Returns the name of the task

        Parameters
        ----------
        target_name : str
            Name of the target

        Returns
        -------
        str
            Returns the task name, which includes the target name when one is given,
            otherwise the snake-cased class name
        """
        base_name = to_snake_case(cls.__name__)

        return name_fn(target_name, base_name) if target_name else base_name

    @classmethod
    def from_config(cls, config):
        """Re-creates the layer from a config produced by ``get_config``.

        The caller's ``config`` is not modified.
        """
        # Shallow copy: the helpers below pop and overwrite top-level keys, which
        # would otherwise make the caller's config unusable for a second call.
        config = dict(config)
        config["default_metrics_fn"] = cls._parse_function_from_config(
            config, "default_metrics_fn", "module", "function_type"
        )
        config = tf_utils.maybe_deserialize_keras_objects(
            config,
            {
                "default_loss": tf.keras.losses.deserialize,
                "to_call": tf.keras.layers.deserialize,
                "pre": tf.keras.layers.deserialize,
                "post": tf.keras.layers.deserialize,
                "logits_scaler": tf.keras.layers.deserialize,
            },
        )

        return super().from_config(config)
# NOTE: registered under "merlin_models" (not "merlin.models"); the package is part of
# the registered name, so changing it would break loading previously saved models.
@tf.keras.utils.register_keras_serializable(package="merlin_models")
class DotProduct(Layer):
    """Dot-product between queries & items.

    Parameters
    ----------
    query_name : str, optional
        Identify query tower for query/user embeddings, by default 'query'
    item_name : str, optional
        Identify item tower for item embeddings, by default 'candidate'
    """

    def __init__(self, query_name: str = "query", item_name: str = "candidate", **kwargs):
        super().__init__(**kwargs)
        self.query_name = query_name
        self.item_name = item_name

    def call(self, inputs, **kwargs):
        """Returns the dot-product of the query and item embeddings.

        Parameters
        ----------
        inputs : mapping
            Holds the query tensor under ``query_name`` and the item tensor under
            ``item_name``; the two are multiplied element-wise (with TF broadcasting).

        Returns
        -------
        tf.Tensor
            The product summed over the last axis, keeping that axis with size 1.
        """
        return tf.reduce_sum(
            tf.multiply(inputs[self.query_name], inputs[self.item_name]), keepdims=True, axis=-1
        )

    def compute_output_shape(self, input_shape):
        """Returns ``(batch_size, 1)``, with the batch size taken from the input shapes."""
        batch_size = tf_utils.calculate_batch_size_from_input_shapes(input_shape)

        return batch_size, 1

    def get_config(self):
        """Returns the base layer config extended with the query/item names."""
        return {
            **super(DotProduct, self).get_config(),
            "query_name": self.query_name,
            "item_name": self.item_name,
        }