#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from typing import List, Optional, Union
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.python.ops import embedding_ops
import merlin.io
from merlin.models.tf.inputs.embedding import EmbeddingTable
from merlin.models.tf.metrics.topk import AvgPrecisionAt, MRRAt, NDCGAt, PrecisionAt, RecallAt
from merlin.models.tf.outputs.base import MetricsFn, ModelOutput
from merlin.models.tf.utils.tf_utils import (
maybe_deserialize_keras_objects,
maybe_serialize_keras_objects,
tensor_to_df,
)
from merlin.schema import ColumnSchema, Schema
LOG = logging.getLogger("merlin_models")
def default_binary_metrics() -> List[tf.keras.metrics.Metric]:
    """Returns the default metrics for binary classification.

    Returns
    -------
    List[tf.keras.metrics.Metric]
        List with metrics for binary classification
    """
    # Return a list (not a tuple) so the value matches the declared
    # ``List`` return type and callers may extend it safely.
    return [
        tf.keras.metrics.Precision(name="precision"),
        tf.keras.metrics.Recall(name="recall"),
        tf.keras.metrics.BinaryAccuracy(name="binary_accuracy"),
        tf.keras.metrics.AUC(name="auc"),
    ]
def default_categorical_prediction_metrics(k=10) -> List[tf.keras.metrics.Metric]:
    """Returns the default top-k metrics for categorical classification.

    Parameters
    ----------
    k : int, optional
        Cut-off for the top-k metrics, by default 10

    Returns
    -------
    List[tf.keras.metrics.Metric]
        List with top-k metrics for categorical classification
    """
    # Return a list to match the sibling ``default_binary_metrics`` contract.
    return [
        RecallAt(k),
        MRRAt(k),
        NDCGAt(k),
        AvgPrecisionAt(k),
        PrecisionAt(k),
    ]
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class BinaryOutput(ModelOutput):
    """Binary-classification prediction block.

    Parameters
    ----------
    target: Union[str, ColumnSchema], optional
        The name of the target. If a ColumnSchema is provided, the target
        name is inferred from it.
    pre: Optional[Layer], optional
        Optional block to transform predictions before computing the binary logits,
        by default None
    post: Optional[Layer], optional
        Optional block to transform the binary logits,
        by default None
    logits_temperature: float, optional
        Parameter used to reduce model overconfidence, so that logits / T.
        by default 1.0
    name: str, optional
        The name of the task.
    default_loss: Union[str, tf.keras.losses.Loss], optional
        Default loss to use for binary-classification,
        by default 'binary_crossentropy'
    default_metrics_fn: Callable
        A function returning the list of default metrics
        to use for binary-classification
    """

    def __init__(
        self,
        target: Optional[Union[str, ColumnSchema]] = None,
        pre: Optional[Layer] = None,
        post: Optional[Layer] = None,
        logits_temperature: float = 1.0,
        name: Optional[str] = None,
        default_loss: Union[str, tf.keras.losses.Loss] = "binary_crossentropy",
        default_metrics_fn: MetricsFn = default_binary_metrics,
        **kwargs,
    ):
        # A ColumnSchema is reduced to its column name; ModelOutput
        # expects the target as a string.
        if isinstance(target, ColumnSchema):
            target = target.name
        # Allow a pre-built prediction layer via kwargs; otherwise use a
        # single sigmoid unit producing the binary probability.
        to_call = kwargs.pop("to_call", None)
        super().__init__(
            to_call=to_call or tf.keras.layers.Dense(1, activation="sigmoid"),
            default_loss=default_loss,
            default_metrics_fn=default_metrics_fn,
            target=target,
            pre=pre,
            post=post,
            logits_temperature=logits_temperature,
            name=name,
            **kwargs,
        )
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class CategoricalOutput(ModelOutput):
    """Categorical (multi-class) prediction block.

    Parameters
    ----------
    to_call: Union[Schema, ColumnSchema, EmbeddingTable, 'CategoricalTarget',
        'EmbeddingTablePrediction']
        The target feature to predict. To perform the weight-tying [1] technique,
        you should provide the `EmbeddingTable` or `EmbeddingTablePrediction`
        related to the target feature.
    target_name: str, optional
        The name of the target column. If not provided, it is inferred from
        ``to_call`` when possible, by default None
    pre: Optional[Layer], optional
        Optional block to transform predictions before computing the logits,
        by default None
    post: Optional[Layer], optional
        Optional block to transform the logits,
        by default None
    logits_temperature: float, optional
        Parameter used to reduce model overconfidence, so that logits / T.
        by default 1.0
    name: str, optional
        The name of the task, by default None
    default_loss: Union[str, tf.keras.losses.Loss], optional
        Default loss to use for categorical-classification,
        by default 'categorical_crossentropy'
    default_metrics_fn: Callable, optional
        A function returning the list of default metrics
        to use for categorical-classification

    References:
    ----------
    [1] Hakan Inan, Khashayar Khosravi, and Richard Socher. 2016. Tying word vectors
    and word classifiers: A loss framework for language modeling. arXiv preprint
    arXiv:1611.01462 (2016).
    """

    def __init__(
        self,
        to_call: Union[
            Schema, ColumnSchema, EmbeddingTable, "CategoricalTarget", "EmbeddingTablePrediction"
        ],
        target_name: str = None,
        pre: Optional[Layer] = None,
        post: Optional[Layer] = None,
        logits_temperature: float = 1.0,
        name: Optional[str] = None,
        default_loss: Union[str, tf.keras.losses.Loss] = "categorical_crossentropy",
        default_metrics_fn: MetricsFn = default_categorical_prediction_metrics,
        **kwargs,
    ):
        self.max_num_samples = kwargs.pop("max_num_samples", None)
        # A deserialized layer may arrive via kwargs (see get_config);
        # it is used as-is when the positional ``to_call`` is None.
        _to_call = kwargs.pop("to_call", None)
        if to_call is not None:
            if isinstance(to_call, (Schema, ColumnSchema)):
                # Build the output layer from the schema of the target column.
                _to_call = CategoricalTarget(to_call)
                if isinstance(to_call, Schema):
                    to_call = to_call.first
                target_name = target_name or to_call.name
            elif isinstance(to_call, EmbeddingTable):
                # Weight-tying: reuse the embedding table as output weights.
                _to_call = EmbeddingTablePrediction(to_call)
                target_name = _to_call.table.col_schema.name
            else:
                _to_call = to_call

        self.target_name = kwargs.pop("target", target_name)
        super().__init__(
            to_call=_to_call,
            default_loss=default_loss,
            default_metrics_fn=default_metrics_fn,
            name=name,
            target=self.target_name,
            pre=pre,
            post=post,
            logits_temperature=logits_temperature,
            **kwargs,
        )

    def to_dataset(self, gpu=True) -> merlin.io.Dataset:
        """Exports the output-layer embeddings as a merlin Dataset."""
        return merlin.io.Dataset(tensor_to_df(self.to_call.embeddings, gpu=gpu))

    def get_config(self):
        config = super().get_config()
        config["max_num_samples"] = self.max_num_samples
        config["target_name"] = self.target_name
        return config
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class CategoricalTarget(tf.keras.layers.Dense):
    """Prediction of a categorical feature.

    A ``Dense`` layer whose number of units equals the cardinality of the
    target categorical feature (inferred from its schema).

    Parameters
    ----------
    feature : Union[Schema, ColumnSchema]
        The schema description of the categorical feature to predict.
    activation : optional
        Activation function to use, by default None
    use_bias : bool, optional
        Whether the layer uses a bias vector, by default True
    kernel_initializer : str, optional
        Initializer for the kernel weights matrix, by default "glorot_uniform"
    bias_initializer : str, optional
        Initializer for the bias vector, by default "zeros"
    kernel_regularizer : optional
        Regularizer function applied to the kernel weights matrix, by default None
    bias_regularizer : optional
        Regularizer function applied to the bias vector, by default None
    activity_regularizer : optional
        Regularizer function applied to the output of the layer (its "activation"),
        by default None
    kernel_constraint : optional
        Constraint function applied to the kernel weights matrix, by default None
    bias_constraint : optional
        Constraint function applied to the bias vector, by default None
    """

    def __init__(
        self,
        feature: Union[Schema, ColumnSchema] = None,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        if feature is None:
            # No schema provided: the caller must pass the output size directly.
            units = kwargs.pop("units")
            self.num_classes = units
        else:
            col_schema = feature
            if isinstance(feature, Schema):
                assert len(feature) == 1, "Schema can have max 1 feature"
                col_schema = feature.first
            # Cardinality of the categorical domain (+1 for the zero index).
            self.num_classes = col_schema.int_domain.max + 1
            units = self.num_classes

        super().__init__(
            units,
            activation,
            use_bias,
            kernel_initializer,
            bias_initializer,
            kernel_regularizer,
            bias_regularizer,
            activity_regularizer,
            kernel_constraint,
            bias_constraint,
            **kwargs,
        )

    def embedding_lookup(self, inputs: tf.Tensor, **kwargs):
        """Retrieve hidden representation vectors from the kernel weight
        matrix for the given "input" positions.

        Parameters
        ----------
        inputs : tf.Tensor
            Tensor of indices to retrieve from the weight matrix.

        Returns
        -------
        tf.Tensor
            Tensor of hidden representation vectors.
        """
        return embedding_ops.embedding_lookup(self.embeddings, inputs, **kwargs)

    @property
    def embeddings(self):
        # The kernel has shape (hidden_dim, num_classes); expose it
        # transposed so each row corresponds to one class.
        return tf.transpose(self.kernel)
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class EmbeddingTablePrediction(Layer):
    """Prediction of a categorical feature using weight-sharing [1] with an
    embedding table.

    The logits are computed as the matrix product of the inputs with the
    (transposed) embedding table, optionally plus a bias term.

    Parameters
    ----------
    table : EmbeddingTable
        The embedding table to use as the weight matrix
    bias_initializer : str, optional
        Initializer for the bias vector, by default "zeros"
    use_bias : bool, optional
        Whether to add a bias term to weight-tying, by default False

    References:
    ----------
    [1] Hakan Inan, Khashayar Khosravi, and Richard Socher. 2016. Tying word vectors
    and word classifiers: A loss framework for language modeling. arXiv preprint
    arXiv:1611.01462 (2016).
    """

    def __init__(self, table: EmbeddingTable, bias_initializer="zeros", use_bias=False, **kwargs):
        self.table = table
        # One output unit per embedding row (i.e. per class).
        self.num_classes = table.input_dim
        self.bias_initializer = bias_initializer
        self.use_bias = use_bias
        super().__init__(**kwargs)

    def build(self, input_shape):
        if self.use_bias:
            self.bias = self.add_weight(
                name="output_layer_bias",
                shape=(self.num_classes,),
                initializer=self.bias_initializer,
            )
        # Make sure the shared table's weights exist before they are used.
        self.table.build(input_shape)
        return super().build(input_shape)

    def call(self, inputs, training=False, **kwargs) -> tf.Tensor:
        # Ragged inputs are flattened for the matmul and re-wrapped after.
        ragged_inputs = inputs if isinstance(inputs, tf.RaggedTensor) else None
        if ragged_inputs is not None:
            inputs = ragged_inputs.flat_values

        # Weight-tying: project onto the transposed embedding matrix.
        logits = tf.matmul(inputs, self.table.table.embeddings, transpose_b=True)
        if self.use_bias:
            logits = tf.nn.bias_add(logits, self.bias)

        if ragged_inputs is not None:
            logits = ragged_inputs.with_flat_values(logits)
        return logits

    @property
    def embeddings(self):
        return self.table.table.embeddings

    def embedding_lookup(self, inputs, **kwargs):
        return self.table.table(inputs, **kwargs)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_classes)

    def get_config(self):
        base_config = {**super().get_config()}
        return maybe_serialize_keras_objects(self, base_config, ["table"])

    @classmethod
    def from_config(cls, config):
        config = maybe_deserialize_keras_objects(config, ["table"])
        return super().from_config(config)