Source code for merlin.models.tf.prediction_tasks.classification

#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional, Union

import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer

from merlin.models.tf.core.base import Block, MetricOrMetrics
from merlin.models.tf.prediction_tasks.base import PredictionTask
from merlin.models.tf.utils.tf_utils import (
    maybe_deserialize_keras_objects,
    maybe_serialize_keras_objects,
)
from merlin.models.utils.schema_utils import categorical_cardinalities
from merlin.schema import Schema, Tags


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class BinaryClassificationTask(PredictionTask):
    """
    Prediction task for binary classification.

    Parameters
    ----------
    target: Union[str, Schema], optional
        The name of the target. If a Schema is provided, the target is inferred
        from the single column tagged with ``Tags.BINARY_CLASSIFICATION``.
    task_name: str, optional
        The name of the task.
    task_block: Block, optional
        The block to use for the task.
    **kwargs
        Forwarded to ``PredictionTask``. Two keys are consumed here:
        ``target_name`` (used e.g. when restoring from a config) is accepted as
        a fallback for ``target``, and ``output_layer`` replaces the default
        single-unit linear ``Dense`` layer.

    Raises
    ------
    ValueError
        If ``target`` is a Schema with no column, or more than one column,
        tagged ``Tags.BINARY_CLASSIFICATION``.
    """

    # Default loss to use
    DEFAULT_LOSS = "binary_crossentropy"
    # Default metrics to use
    DEFAULT_METRICS = (
        tf.keras.metrics.Precision,
        tf.keras.metrics.Recall,
        tf.keras.metrics.BinaryAccuracy,
        tf.keras.metrics.AUC,
    )

    def __init__(
        self,
        target: Optional[Union[str, Schema]] = None,
        task_name: Optional[str] = None,
        task_block: Optional[Layer] = None,
        **kwargs,
    ):
        # Always remove `target_name` from kwargs: if it were left in while
        # `target` is also given, it would be passed twice to
        # `PredictionTask.__init__` and raise a TypeError.
        kwarg_target_name = kwargs.pop("target_name", None)
        if isinstance(target, Schema):
            target_name = self._target_name_from_schema(target)
        else:
            target_name = target if target else kwarg_target_name

        output_layer = kwargs.pop("output_layer", None)
        super().__init__(
            target_name=target_name,
            task_name=task_name,
            task_block=task_block,
            **kwargs,
        )
        self.output_layer = output_layer or tf.keras.layers.Dense(
            1, activation="linear", name="output_layer"
        )
        # To ensure that the output is always fp32, avoiding numerical
        # instabilities with mixed_float16 (fp16) policy
        self.output_activation = tf.keras.layers.Activation(
            "sigmoid", dtype="float32", name="prediction"
        )

    @staticmethod
    def _target_name_from_schema(schema: Schema) -> str:
        """Return the name of the single column tagged BINARY_CLASSIFICATION.

        Raises
        ------
        ValueError
            If zero or more than one column carries the tag.
        """
        target_names = schema.select_by_tag(Tags.BINARY_CLASSIFICATION).column_names
        if not target_names:
            # Implicit string concatenation (no commas) so the exception
            # carries one readable message instead of a tuple of fragments.
            raise ValueError(
                "Binary classification task requires a column with a "
                "`Tags.BINARY_CLASSIFICATION` tag."
            )
        if len(target_names) > 1:
            raise ValueError(
                "Binary classification task requires a single column with a "
                "`Tags.BINARY_CLASSIFICATION` tag. "
                f"Found {len(target_names)} columns. "
                "Please specify the column name with the `target` argument."
            )
        return target_names[0]

    def call(self, inputs, training=False, **kwargs):
        """Apply the output layer, then a float32 sigmoid."""
        return self.output_activation(self.output_layer(inputs))

    def compute_output_shape(self, input_shape):
        # The element-wise sigmoid does not change the shape.
        return self.output_layer.compute_output_shape(input_shape)

    def get_config(self):
        """Return the layer config, serializing the (possibly custom) output layer."""
        config = super().get_config()
        config = maybe_serialize_keras_objects(
            self,
            config,
            {"output_layer": tf.keras.layers.serialize},
        )

        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the task, deserializing ``output_layer`` if present."""
        config = maybe_deserialize_keras_objects(
            config, ["output_layer"], tf.keras.layers.deserialize
        )

        return super().from_config(config)
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class CategFeaturePrediction(Block):
    """Block that predicts a categorical feature.

    ``num_classes`` is inferred from the cardinality of ``feature_name`` in
    ``schema`` (looked up with ``categorical_cardinalities``). The block applies
    a linear ``Dense(num_classes)`` layer followed by ``activation``, which is
    computed in float32.

    Parameters
    ----------
    schema : Schema
        Schema containing the categorical feature to predict.
    feature_name : Optional[str], optional
        Name of the feature to predict. A ``Tags`` member is also accepted and
        resolved to the first column carrying that tag. When falsy, the first
        column tagged ``Tags.ITEM_ID`` is used.
    bias_initializer : optional
        Bias initializer of the output ``Dense`` layer, by default "zeros".
    kernel_initializer : optional
        Kernel initializer of the output ``Dense`` layer,
        by default "random_normal".
    activation : optional
        Activation passed to ``tf.keras.layers.Activation``, by default None.
    """

    def __init__(
        self,
        schema: Schema,
        feature_name: Optional[str] = None,
        bias_initializer="zeros",
        kernel_initializer="random_normal",
        activation=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bias_initializer = bias_initializer
        self.kernel_initializer = kernel_initializer
        if isinstance(feature_name, Tags):
            # A tag (such as the `Tags.ITEM_ID` default of
            # `MultiClassClassificationTask.from_schema`) is not a column name,
            # so it cannot index `categorical_cardinalities`; resolve it first.
            feature_name = schema.select_by_tag(feature_name).column_names[0]
        self.feature_name = feature_name or schema.select_by_tag(Tags.ITEM_ID).column_names[0]
        self.num_classes = categorical_cardinalities(schema)[self.feature_name]
        self.activation = activation

        # To ensure that the output is always fp32, avoiding numerical
        # instabilities with mixed_float16 policy
        self.output_activation = tf.keras.layers.Activation(
            activation, dtype="float32", name="predictions"
        )

    def build(self, input_shape):
        """Create the ``Dense(num_classes)`` logits layer."""
        self.output_layer = Dense(
            units=self.num_classes,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            name=f"{self.feature_name}-prediction",
            activation="linear",
        )
        return super().build(input_shape)

    def call(self, inputs, training=False, **kwargs) -> tf.Tensor:
        """Return ``activation(Dense(inputs))`` in float32."""
        return self.output_activation(self.output_layer(inputs))

    def compute_output_shape(self, input_shape):
        # The last dimension is replaced by the number of classes.
        return input_shape[:-1] + (self.num_classes,)
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class MultiClassClassificationTask(PredictionTask):
    """
    Prediction task for multi-class classification.

    Parameters
    ----------
    target_name : Optional[str], optional
        Label name, by default None
    task_name: str, optional
        The name of the task.
    task_block: Block, optional
        The block to use for the task.
    pre: Block, optional
        Block forwarded to ``PredictionTask``; ``from_schema`` passes a
        ``CategFeaturePrediction`` here.
    """

    DEFAULT_LOSS = "categorical_crossentropy"
    # NOTE(review): `tf.keras.metrics.Accuracy` checks exact equality between
    # labels and predictions; with probability outputs a categorical accuracy
    # metric may be what is intended — confirm before relying on this default.
    DEFAULT_METRICS: MetricOrMetrics = (tf.keras.metrics.Accuracy,)

    def __init__(
        self,
        target_name: Optional[str] = None,
        task_name: Optional[str] = None,
        task_block: Optional[Layer] = None,
        pre: Optional[Block] = None,
        **kwargs,
    ):
        super().__init__(
            target_name=target_name,
            task_name=task_name,
            task_block=task_block,
            pre=pre,
            **kwargs,
        )

    @classmethod
    def from_schema(
        cls,
        schema: Schema,
        feature_name: Union[str, Tags] = Tags.ITEM_ID,
        bias_initializer="zeros",
        kernel_initializer="random_normal",
        extra_pre: Optional[Block] = None,
        **kwargs,
    ) -> "MultiClassClassificationTask":
        """Create from Schema.

        Parameters
        ----------
        schema : Schema
            Schema containing the categorical feature to predict.
        feature_name : Union[str, Tags], optional
            Column name of the feature to predict, or a tag whose first tagged
            column is used, by default ``Tags.ITEM_ID``.
        bias_initializer, kernel_initializer : optional
            Initializers of the output ``Dense`` layer.
        extra_pre : Block, optional
            Block connected after the ``CategFeaturePrediction`` block.
        **kwargs
            Forwarded to the class constructor.
        """
        if isinstance(feature_name, Tags):
            # The default is a tag, not a column name; resolve it so that the
            # cardinality lookup in `CategFeaturePrediction` finds the column.
            feature_name = schema.select_by_tag(feature_name).column_names[0]

        pre = CategFeaturePrediction(
            schema,
            feature_name,
            bias_initializer=bias_initializer,
            kernel_initializer=kernel_initializer,
        )
        if extra_pre:
            pre = pre.connect(extra_pre)

        return cls(
            pre=pre,
            **kwargs,
        )

    def call(self, inputs, training=False, **kwargs):
        # Identity: the class prediction is produced by the `pre` block
        # (presumably applied by `PredictionTask` before `call` — verify).
        return inputs