Source code for merlin.models.tf.outputs.regression
from typing import Optional, Union
import tensorflow as tf
from tensorflow.keras.layers import Layer
from merlin.models.tf.outputs.base import MetricsFn, ModelOutput
from merlin.schema import ColumnSchema
[docs]@tf.keras.utils.register_keras_serializable(package="merlin.models")
class RegressionOutput(ModelOutput):
"""Regression prediction block
Parameters
----------
target: str, optional
The name of the target.
pre: Optional[Block], optional
Optional block to transform predictions before computing the regression scores,
by default None
post: Optional[Block], optional
Optional block to transform the regression scores,
by default None
name: str, optional
The name of the task.
default_loss: Union[str, tf.keras.losses.Loss], optional
Default loss to use for regression
by 'mse'
get_default_metrics: Callable, optional
A function returning the list of default metrics to set
if the user does not specify any
Default metrics to use for regression
"""
[docs] def __init__(
self,
target: Optional[Union[str, ColumnSchema]] = None,
pre: Optional[Layer] = None,
post: Optional[Layer] = None,
name: Optional[str] = None,
default_loss="mse",
default_metrics_fn: MetricsFn = lambda: (
tf.keras.metrics.RootMeanSquaredError(name="root_mean_squared_error"),
),
**kwargs,
):
if isinstance(target, ColumnSchema):
target = target.name
to_call = kwargs.pop("to_call", None)
super().__init__(
to_call=to_call or tf.keras.layers.Dense(1, activation="linear"),
default_loss=default_loss,
default_metrics_fn=default_metrics_fn,
target=target,
pre=pre,
post=post,
name=name,
**kwargs,
)