Source code for

# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from functools import partial
from typing import Optional, Union

import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer

from import Block, MetricOrMetrics
from import PredictionTask
from import (
from merlin.models.utils.schema_utils import categorical_cardinalities
from merlin.schema import Schema, Tags

@tf.keras.utils.register_keras_serializable(package="merlin.models")
class BinaryClassificationTask(PredictionTask):
    """
    Prediction task for binary classification.

    Parameters
    ----------
    target: Union[str, Schema], optional
        The name of the target. If a Schema is provided, the target is inferred from the
        single column carrying the `Tags.BINARY_CLASSIFICATION` tag.
    task_name: str, optional
        The name of the task.
    task_block: Block, optional
        The block to use for the task.
    **kwargs
        Forwarded to ``PredictionTask.__init__`` after two optional keys are consumed:
        ``output_layer`` (layer used instead of the default ``Dense(1)`` projection) and
        ``target_name`` (used as the target name only when ``target`` is falsy and is not
        a Schema).

    Raises
    ------
    ValueError
        If ``target`` is a Schema that contains zero, or more than one, column tagged
        with `Tags.BINARY_CLASSIFICATION`.
    """

    # Default loss to use
    DEFAULT_LOSS = "binary_crossentropy"

    # Default metrics to use
    DEFAULT_METRICS = (
        partial(tf.keras.metrics.Precision, name="precision"),
        partial(tf.keras.metrics.Recall, name="recall"),
        partial(tf.keras.metrics.BinaryAccuracy, name="binary_accuracy"),
        partial(tf.keras.metrics.AUC, name="auc"),
    )

    def __init__(
        self,
        target: Optional[Union[str, Schema]] = None,
        task_name: Optional[str] = None,
        task_block: Optional[Layer] = None,
        **kwargs,
    ):
        if isinstance(target, Schema):
            tagged_columns = target.select_by_tag(Tags.BINARY_CLASSIFICATION).column_names
            # The message fragments are joined by implicit string concatenation so that
            # ValueError receives ONE string; separating them with commas would pass a
            # tuple of arguments and render the error as a tuple.
            if not tagged_columns:
                raise ValueError(
                    "Binary classification task requires a column with a "
                    "`Tags.BINARY_CLASSIFICATION` tag."
                )
            if len(tagged_columns) > 1:
                raise ValueError(
                    "Binary classification task requires a single column with a "
                    "`Tags.BINARY_CLASSIFICATION` tag. "
                    "Found {} columns. ".format(len(tagged_columns))
                    + "Please specify the column name with the `target` argument."
                )
            target_name = tagged_columns[0]
        else:
            # Fall back to a `target_name` keyword when no (truthy) target was given.
            target_name = target if target else kwargs.pop("target_name", None)

        # `output_layer` must be removed from kwargs before they reach the parent class.
        output_layer = kwargs.pop("output_layer", None)

        super().__init__(
            target_name=target_name,
            task_name=task_name,
            task_block=task_block,
            **kwargs,
        )

        self.output_layer = output_layer or tf.keras.layers.Dense(
            1, activation="linear", name="output_layer"
        )
        # To ensure that the output is always fp32, avoiding numerical
        # instabilities with mixed_float16 (fp16) policy
        self.output_activation = tf.keras.layers.Activation(
            "sigmoid", dtype="float32", name="prediction"
        )

    def call(self, inputs, training=False, **kwargs):
        """Projects the input with the output layer to a single logit

        Parameters
        ----------
        inputs : tf.Tensor
            Input tensor
        training : bool, optional
            Flag that indicates whether it is training or not, by default False.
            Not used by this method.

        Returns
        -------
        tf.Tensor
            Tensor with the classification probabilities (sigmoid of the output
            layer's result, computed in float32)
        """
        return self.output_activation(self.output_layer(inputs))

    def compute_output_shape(self, input_shape):
        """Computes the output shape based on the input shape

        Parameters
        ----------
        input_shape : tf.TensorShape
            The input shape

        Returns
        -------
        tf.TensorShape
            The output shape, as reported by the output layer
        """
        return self.output_layer.compute_output_shape(input_shape)

    def get_config(self):
        """Return a Python dict containing the configuration of the model."""
        config = super().get_config()
        # Adds the serialized `output_layer` to the parent's config.
        config = maybe_serialize_keras_objects(
            self,
            config,
            {"output_layer": tf.keras.layers.serialize},
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Re-create the task from a config dict produced by ``get_config``."""
        config = maybe_deserialize_keras_objects(
            config, ["output_layer"], tf.keras.layers.deserialize
        )
        return super().from_config(config)
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class CategFeaturePrediction(Block):
    """Block that predicts a categorical feature.

    ``num_classes`` is inferred from the schema: it is the entry for the
    feature in ``categorical_cardinalities(schema)``.

    Parameters
    ----------
    schema : Schema
        Schema used to look up the feature and its cardinality.
    feature_name : Union[str, Tags], optional
        Column name of the feature to predict. A ``Tags`` member is resolved to the
        first column of ``schema`` carrying that tag. When omitted (or otherwise
        falsy) the first column tagged ``Tags.ITEM_ID`` is used.
    bias_initializer : optional
        Bias initializer of the output ``Dense`` layer, by default "zeros".
    kernel_initializer : optional
        Kernel initializer of the output ``Dense`` layer, by default "random_normal".
    activation : optional
        Activation applied (in float32) on top of the linear output layer,
        by default None.
    """

    def __init__(
        self,
        schema: Schema,
        feature_name: Optional[Union[str, Tags]] = None,
        bias_initializer="zeros",
        kernel_initializer="random_normal",
        activation=None,
        **kwargs,
    ):
        super(CategFeaturePrediction, self).__init__(**kwargs)
        self.bias_initializer = bias_initializer
        self.kernel_initializer = kernel_initializer

        # The feature name is used below as the key into the cardinalities mapping
        # and inside the output layer's name, so it has to be an actual column name.
        # A tag (e.g. `Tags.ITEM_ID`) is therefore translated to the column it marks
        # rather than being used verbatim.
        if not feature_name:
            feature_name = Tags.ITEM_ID
        if isinstance(feature_name, Tags):
            feature_name = schema.select_by_tag(feature_name).column_names[0]
        self.feature_name = feature_name

        self.num_classes = categorical_cardinalities(schema)[self.feature_name]
        self.activation = activation

        # To ensure that the output is always fp32, avoiding numerical
        # instabilities with mixed_float16 policy
        self.output_activation = tf.keras.layers.Activation(
            activation, dtype="float32", name="predictions"
        )

    def build(self, input_shape):
        """Create the linear output layer with one unit per class."""
        self.output_layer = Dense(
            units=self.num_classes,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            name=f"{self.feature_name}-prediction",
            activation="linear",
        )

        return super().build(input_shape)

    def call(self, inputs, training=False, **kwargs) -> tf.Tensor:
        """Apply the output layer followed by the float32 output activation."""
        return self.output_activation(self.output_layer(inputs))

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with its last dimension replaced by ``num_classes``."""
        return input_shape[:-1] + (self.num_classes,)
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class MultiClassClassificationTask(PredictionTask):
    """
    Prediction task for multi-class classification.

    Parameters
    ----------
    target_name : Optional[str], optional
        Label name, by default None
    task_name: str, optional
        The name of the task.
    task_block: Block, optional
        The block to use for the task.
    pre: Block, optional
        Block forwarded to the parent class as ``pre``, by default None.
    """

    DEFAULT_LOSS = "categorical_crossentropy"
    # NOTE(review): `tf.keras.metrics.Accuracy` counts exact matches between labels and
    # predictions -- confirm this is the intended default next to a categorical
    # cross-entropy loss.
    DEFAULT_METRICS: MetricOrMetrics = (partial(tf.keras.metrics.Accuracy, name="accuracy"),)

    def __init__(
        self,
        target_name: Optional[str] = None,
        task_name: Optional[str] = None,
        task_block: Optional[Layer] = None,
        pre: Optional[Block] = None,
        **kwargs,
    ):
        # Everything is handed to the parent class unchanged.
        parent_args = dict(
            target_name=target_name,
            task_name=task_name,
            task_block=task_block,
            pre=pre,
        )
        super().__init__(**parent_args, **kwargs)

    @classmethod
    def from_schema(
        cls,
        schema: Schema,
        feature_name: str = Tags.ITEM_ID,
        bias_initializer="zeros",
        kernel_initializer="random_normal",
        extra_pre: Optional[Block] = None,
        **kwargs,
    ) -> "MultiClassClassificationTask":
        """Create from Schema.

        Builds a ``CategFeaturePrediction`` block from ``schema`` and ``feature_name``
        (with the given initializers) and uses it as the task's ``pre`` block. When
        ``extra_pre`` is truthy, the prediction block is first connected to it via
        ``connect``. Remaining keyword arguments go to the class constructor.
        """
        prediction_block = CategFeaturePrediction(
            schema,
            feature_name,
            bias_initializer=bias_initializer,
            kernel_initializer=kernel_initializer,
        )
        pre_block = prediction_block.connect(extra_pre) if extra_pre else prediction_block

        return cls(pre=pre_block, **kwargs)

    def call(self, inputs, training=False, **kwargs):
        """Return ``inputs`` untouched; no transformation happens in this method."""
        return inputs