#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import types as python_types
import warnings
from typing import Callable, Optional, Sequence, Union
import tensorflow as tf
from keras.utils import generic_utils
from keras.utils.generic_utils import to_snake_case
from tensorflow.keras.layers import Layer
from merlin.models.tf.core.base import name_fn
from merlin.models.tf.core.combinators import ParallelBlock
from merlin.models.tf.core.prediction import Prediction
from merlin.models.tf.transforms.bias import LogitsTemperatureScaler
from merlin.models.tf.utils import tf_utils
# Type of a zero-argument callable that returns a sequence of Keras metrics.
MetricsFn = Callable[[], Sequence[tf.keras.metrics.Metric]]
# A model output is either a single ``ModelOutput`` or a ``ParallelBlock``.
ModelOutputType = Union["ModelOutput", ParallelBlock]
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class ModelOutput(Layer):
    """Base-class for prediction blocks.

    Parameters
    ----------
    to_call : Layer
        The layer to call in the forward-pass of the model
    default_loss: Union[str, tf.keras.losses.Loss]
        Default loss to set if the user does not specify one
    default_metrics_fn: Callable
        A function returning the list of default metrics to set
        if the user does not specify any
    name: Optional[str], optional
        Task name, by default None (the task name derived from ``target``
        is then used as the layer name)
    target: Optional[str], optional
        Label name, by default None
    pre: Optional[Block], optional
        Optional block to transform predictions before applying the prediction layer,
        by default None
    post: Optional[Block], optional
        Optional block to transform predictions after applying the prediction layer,
        by default None
    logits_temperature: float, optional
        Parameter used to reduce model overconfidence, so that logits / T.
        by default 1.
    **kwargs
        Forwarded to ``Layer.__init__``. A ``logits_scaler`` entry, if present,
        is removed first and used as the scaler layer; its ``temperature``
        attribute then overrides ``logits_temperature``.
    """

    def __init__(
        self,
        to_call: Layer,
        default_loss: Union[str, tf.keras.losses.Loss],
        default_metrics_fn: MetricsFn,
        name: Optional[str] = None,
        target: Optional[str] = None,
        pre: Optional[Layer] = None,
        post: Optional[Layer] = None,
        logits_temperature: float = 1.0,
        **kwargs,
    ):
        # ``logits_scaler`` is not a Layer argument (``get_config`` lists it,
        # so it can come back through ``from_config``); take it out of
        # ``kwargs`` before they are handed to ``Layer.__init__``.
        logits_scaler = kwargs.pop("logits_scaler", None)

        # The task name is needed as the fallback layer name, so it is
        # computed before calling the parent constructor.
        self.target = target
        self.full_name = self.get_task_name(self.target)

        super().__init__(name=name or self.full_name, **kwargs)

        self.to_call = to_call
        self.default_loss = default_loss
        self.default_metrics_fn = default_metrics_fn
        self.pre = pre
        self.post = post

        if logits_scaler is not None:
            self.logits_scaler = logits_scaler
            self.logits_temperature = logits_scaler.temperature
        else:
            self.logits_temperature = logits_temperature
            # With the neutral temperature no scaler is created and the
            # ``logits_scaler`` attribute stays unset; ``__call__`` reads it
            # through ``getattr`` for that reason.
            if logits_temperature != 1.0:
                self.logits_scaler = LogitsTemperatureScaler(logits_temperature)

    @property
    def task_name(self) -> str:
        """Name of the task, identical to ``full_name``."""
        return self.full_name

    def build(self, input_shape=None):
        """Builds the ModelOutput and its ``pre``/``to_call``/``post`` layers.

        Parameters
        ----------
        input_shape : tf.TensorShape, optional
            The input shape, by default None
        """
        # Each stage is built with the output shape of the previous one.
        if self.pre is not None:
            self.pre.build(input_shape)
            input_shape = self.pre.compute_output_shape(input_shape)

        self.to_call.build(input_shape)
        input_shape = self.to_call.compute_output_shape(input_shape)

        if self.post is not None:
            self.post.build(input_shape)

        self.built = True

    def call(self, inputs, training=False, testing=False, **kwargs):
        """Run ``to_call`` on ``inputs``.

        ``pre``, ``post`` and the logits scaler are applied in ``__call__``,
        not here.
        """
        return tf_utils.call_layer(
            self.to_call, inputs, training=training, testing=testing, **kwargs
        )

    def compute_output_shape(self, input_shape):
        """Chain the output shapes of ``pre``, ``to_call`` and ``post``."""
        output_shape = input_shape
        if self.pre is not None:
            output_shape = self.pre.compute_output_shape(output_shape)

        output_shape = self.to_call.compute_output_shape(output_shape)

        if self.post is not None:
            output_shape = self.post.compute_output_shape(output_shape)

        return output_shape

    def __call__(self, inputs, *args, **kwargs):
        """Apply ``pre``, the layer itself, ``post`` and the logits scaler.

        Parameters
        ----------
        inputs
            Inputs passed to ``pre`` (if set) and then to the layer.
        *args, **kwargs
            Forwarded to the wrapped calls. ``training``, ``testing`` and
            ``targets`` are read from ``kwargs``.

        Returns
        -------
        The outputs of the last stage that was applied.
        """
        training = kwargs.get("training", False)
        testing = kwargs.get("testing", False)

        # call pre
        if self.pre:
            inputs = tf_utils.call_layer(self.pre, inputs, **kwargs)

        # super call
        outputs = super(ModelOutput, self).__call__(inputs, *args, **kwargs)

        if self.post:
            outputs = tf_utils.call_layer(self.post, outputs, target_name=self.target, **kwargs)

        if getattr(self, "logits_scaler", None):
            if isinstance(outputs, tf.Tensor):
                # ``targets`` is removed from ``kwargs`` here, so it is not
                # forwarded to the scaler as a keyword argument.
                targets = kwargs.pop("targets", None)
                if isinstance(targets, dict) and self.target in targets:
                    targets = targets[self.target]
                # During training/testing a raw tensor is wrapped together
                # with its targets before scaling.
                if training or testing:
                    outputs = Prediction(outputs, targets)
            outputs = tf_utils.call_layer(self.logits_scaler, outputs, **kwargs)

        return outputs

    def create_default_metrics(self):
        """Create the default metrics, prefixed with the task name.

        Returns
        -------
        The metrics returned by ``default_metrics_fn``, each renamed to
        ``"<full_name>/<snake_case_metric_name>"``.
        """
        # The metrics factory is stored as ``default_metrics_fn`` by the
        # constructor; this class defines no ``get_default_metrics`` method.
        metrics = self.default_metrics_fn()

        for metric in metrics:
            metric._name = self.full_name + "/" + to_snake_case(metric.name)

        return metrics

    def _serialize_function_to_config(self, inputs):
        """Serialize a callable so it can be stored in the layer config.

        Note: This code is adapted from Keras source code of
        the [Lambda layer]
        (https://github.com/keras-team/keras/blob/master/keras/layers/core/lambda_layer.py#L300)

        Returns
        -------
        tuple
            ``(output, output_type, module)`` where ``output_type`` is
            ``"lambda"`` or ``"function"``.

        Raises
        ------
        ValueError
            If ``inputs`` is not callable.
        """
        if isinstance(inputs, python_types.LambdaType):
            output = generic_utils.func_dump(inputs)
            output_type = "lambda"
            module = inputs.__module__
        elif callable(inputs):
            output = inputs.__name__
            output_type = "function"
            module = inputs.__module__
        else:
            raise ValueError("Invalid input for serialization, type: %s " % type(inputs))

        return output, output_type, module

    @classmethod
    def _parse_function_from_config(
        cls, config, func_attr_name, module_attr_name, func_type_attr_name
    ):
        """De-serialize a callable stored in a layer config.

        Note: This code is adapted from Keras source code of
        the [Lambda layer]
        (https://github.com/keras-team/keras/blob/master/keras/layers/core/lambda_layer.py#L350)

        The ``module_attr_name`` and ``func_type_attr_name`` entries are
        removed from ``config`` as a side effect.

        Raises
        ------
        TypeError
            If the stored function type is neither "function" nor "lambda".
        """
        globs = globals().copy()
        module = config.pop(module_attr_name, None)
        if module in sys.modules:
            globs.update(sys.modules[module].__dict__)
        elif module is not None:
            # Note: we don't know the name of the function if it's a lambda.
            warnings.warn(
                "{} is not loaded, but a Lambda layer uses it. "
                "It may cause errors.".format(module),
                UserWarning,
                stacklevel=2,
            )

        function_type = config.pop(func_type_attr_name)
        if function_type == "function":
            function = generic_utils.deserialize_keras_object(
                config[func_attr_name], printable_module_name="default metrics function"
            )
        elif function_type == "lambda":
            # Unsafe deserialization from bytecode
            function = generic_utils.func_load(config[func_attr_name], globs=globs)
        else:
            supported_types = ["function", "lambda"]
            raise TypeError(
                f"Unsupported value for `function_type` argument. Received: "
                f"function_type={function_type}. Expected one of {supported_types}"
            )

        return function

    def get_config(self):
        """Return the layer config, including the serialized metrics function."""
        config = super(ModelOutput, self).get_config()

        function_config = self._serialize_function_to_config(self.default_metrics_fn)
        config.update(
            {
                "default_metrics_fn": function_config[0],
                "function_type": function_config[1],
                "module": function_config[2],
                "target": self.target,
            }
        )

        # Attribute names handed to ``maybe_serialize_keras_objects``.
        objects = [
            "to_call",
            "pre",
            "post",
            "logits_scaler",
        ]

        # A loss given by name is stored as-is; anything else is handed to
        # the helper together with the other attributes.
        if isinstance(self.default_loss, str):
            config["default_loss"] = self.default_loss
        else:
            objects.append("default_loss")

        config = tf_utils.maybe_serialize_keras_objects(self, config, objects)

        return config

    @classmethod
    def get_task_name(cls, target_name: Optional[str]) -> str:
        """Returns the name of the task

        Parameters
        ----------
        target_name : Optional[str]
            Name of the target; when empty or None only the snake-cased
            class name is returned

        Returns
        -------
        str
            Returns the task name, which includes the target name
        """
        base_name = to_snake_case(cls.__name__)

        return name_fn(target_name, base_name) if target_name else base_name

    @classmethod
    def from_config(cls, config):
        """Re-create the layer from a config produced by ``get_config``."""
        # Work on a shallow copy: parsing the metrics function pops keys,
        # which must not leak into the dict owned by the caller.
        config = config.copy()

        config["default_metrics_fn"] = cls._parse_function_from_config(
            config, "default_metrics_fn", "module", "function_type"
        )

        config = tf_utils.maybe_deserialize_keras_objects(
            config,
            {
                "default_loss": tf.keras.losses.deserialize,
                "to_call": tf.keras.layers.deserialize,
                "pre": tf.keras.layers.deserialize,
                "post": tf.keras.layers.deserialize,
                "logits_scaler": tf.keras.layers.deserialize,
            },
        )

        return super().from_config(config)
@tf.keras.utils.register_keras_serializable(package="merlin_models")
class DotProduct(Layer):
    """Dot-product between queries & items.

    Parameters
    ----------
    query_name : str, optional
        Key of the query/user embeddings in the inputs, by default 'query'
    item_name : str, optional
        Key of the item embeddings in the inputs, by default 'candidate'
    """

    def __init__(self, query_name: str = "query", item_name: str = "candidate", **kwargs):
        super().__init__(**kwargs)
        self.query_name = query_name
        self.item_name = item_name

    def call(self, inputs, **kwargs):
        """Multiply both embeddings element-wise and sum over the last axis.

        The reduced axis is kept (``keepdims=True``).
        """
        query_embeddings = inputs[self.query_name]
        item_embeddings = inputs[self.item_name]
        elementwise = tf.multiply(query_embeddings, item_embeddings)

        return tf.reduce_sum(elementwise, keepdims=True, axis=-1)

    def compute_output_shape(self, input_shape):
        """Return ``(batch_size, 1)`` for the given input shapes."""
        return tf_utils.calculate_batch_size_from_input_shapes(input_shape), 1

    def get_config(self):
        """Return the base layer config extended with both input names."""
        # Build a new dict rather than updating the parent's result in place.
        config = dict(super().get_config())
        config["query_name"] = self.query_name
        config["item_name"] = self.item_name

        return config