from typing import Dict, List, NamedTuple, Optional, Text, Union
import tensorflow as tf
from keras.utils.generic_utils import to_snake_case
from tensorflow.keras.layers import Layer
from merlin.models.tf.core.base import (
Block,
BlockType,
ContextMixin,
ModelContext,
PredictionOutput,
name_fn,
)
from merlin.models.tf.core.combinators import ParallelBlock
from merlin.models.tf.typing import TabularData, TensorOrTabularData
from merlin.models.tf.utils import tf_utils
from merlin.models.utils.schema_utils import tensorflow_metadata_json_to_schema
from merlin.schema import Schema, Tags
class TaskResults(NamedTuple):
    """Immutable pair holding a task's predictions and their matching targets.

    Both fields accept either a single tensor or tabular data, as expressed by
    their annotations.
    """

    # Output produced by the task.
    predictions: Union[TabularData, tf.Tensor]
    # Labels the predictions are compared against.
    targets: Union[TabularData, tf.Tensor]
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class PredictionTask(Layer, ContextMixin):
    """Base-class for prediction tasks.

    Parameters
    ----------
    target_name : Optional[str], optional
        Label name, by default None
    task_name : Optional[str], optional
        Task name, by default None. When omitted it is derived from the class
        name (and from `target_name` when one is set).
    pre : Optional[Block], optional
        Optional block to transform predictions before computing loss and metrics,
        by default None
    pre_eval_topk : Optional[Block], optional
        Optional block to apply additional transform predictions before computing
        top-k evaluation loss and metrics, by default None
    task_block : Optional[Layer], optional
        Optional block to apply to inputs before computing predictions,
        by default None
    name : Optional[Text], optional
        Keras layer name; falls back to the task name when not provided,
        by default None
    **kwargs
        Forwarded to the parent layer constructor.
    """

    def __init__(
        self,
        target_name: Optional[str] = None,
        task_name: Optional[str] = None,
        pre: Optional[Block] = None,
        pre_eval_topk: Optional[Block] = None,
        task_block: Optional[Layer] = None,
        name: Optional[Text] = None,
        **kwargs,
    ) -> None:
        # `task_name` (used below as the default layer name) is derived from
        # these two attributes, so they are assigned before `super().__init__`.
        self.target_name = target_name
        self._task_name = task_name
        name = name or self.task_name

        super().__init__(name=name, **kwargs)
        self.task_block = task_block
        self.pre = pre
        self._pre_eval_topk = pre_eval_topk

    @property
    def pre_eval_topk(self):
        """Block applied before top-k evaluation, or None when unset."""
        return self._pre_eval_topk

    @pre_eval_topk.setter
    def pre_eval_topk(self, value: Block):
        """Set pre_eval_topk Block

        Parameters
        ----------
        value : Block
            The block for top-k evaluation
        """
        self._pre_eval_topk = value

    def pre_call(self, inputs: TensorOrTabularData, **kwargs) -> tf.Tensor:
        """Apply PredictionTask to inputs to get predictions scores

        Runs `task_block` (without extra keyword arguments) and then `pre`
        (with `kwargs`) over the inputs, skipping whichever is not set.

        Parameters
        ----------
        inputs : TensorOrTabularData
            inputs of the prediction task
        **kwargs
            Keyword arguments forwarded to the `pre` block only.

        Returns
        -------
        tf.Tensor
        """
        x = inputs

        if self.task_block:
            x = self.task_block(x)

        if self.pre:
            x = self.pre(x, **kwargs)

        return x

    def pre_loss(self, outputs: PredictionOutput, **kwargs) -> "PredictionOutput":
        """Apply `call_outputs` method of `pre` block to transform predictions and targets
        before computing loss and metrics.

        Parameters
        ----------
        outputs : PredictionOutput
            The named tuple containing predictions and targets tensors

        Returns
        -------
        PredictionOutput
            The named tuple containing transformed predictions and targets tensors;
            `outputs` is returned unchanged when no `pre` block is set.
        """
        out = self.pre.call_outputs(outputs, **kwargs) if self.pre else outputs

        return out

    def build(self, input_shape, features_shape=None):
        """Build the optional `pre` block and then this layer.

        Parameters
        ----------
        input_shape
            Shape passed to `pre.build` (when `pre` is set) and to the parent `build`.
        features_shape : optional
            Accepted for interface compatibility; not used here.
        """
        if self.pre:
            self.pre.build(input_shape)

        return super().build(input_shape)

    @staticmethod
    def _squeeze_non_batch_dims(predictions: tf.Tensor) -> tf.Tensor:
        """Drop statically-known size-1 dimensions of `predictions`, except axis 0.

        Axis 0 is assumed to be the batch axis (TODO confirm for all callers); it is
        never squeezed so that a batch of one example keeps its batch dimension.
        """
        unit_axes = [axis for axis, dim in enumerate(predictions.shape) if axis > 0 and dim == 1]
        # An empty `axis` list makes `tf.squeeze` remove *every* size-1 dimension,
        # so return early instead of calling it with nothing to squeeze.
        if not unit_axes:
            return predictions

        return tf.squeeze(predictions, axis=unit_axes)

    def __call__(self, *args, **kwargs):
        """Compute predictions and, when training/testing, pair them with targets.

        Returns the raw layer outputs unless the `training` or `testing` keyword
        argument is truthy, in which case a `PredictionOutput` is returned.
        """
        inputs = self.pre_call(*args, **kwargs)

        # This will call the `call` method implemented by the super class.
        outputs = super().__call__(inputs, **kwargs)  # noqa

        if not (kwargs.get("training", False) or kwargs.get("testing", False)):
            return outputs

        targets = kwargs.get("targets", {})
        # Pick this task's entry out of dict-shaped targets / outputs when present.
        if isinstance(targets, dict) and self.target_name:
            targets = targets.get(self.target_name, targets)

        if isinstance(outputs, dict) and self.target_name and self.task_name in outputs:
            outputs = outputs[self.task_name]

        # `pre_loss` returns its input unchanged when no `pre` block is set.
        prediction_output = self.pre_loss(PredictionOutput(outputs, targets), **kwargs)

        targets_missing_one_dim = isinstance(prediction_output.targets, tf.Tensor) and len(
            prediction_output.targets.shape
        ) == (len(prediction_output.predictions.shape) - 1)
        if targets_missing_one_dim:
            # A bare `tf.squeeze` would also drop the batch axis when the batch holds
            # a single example, leaving predictions with a lower rank than targets.
            prediction_output = prediction_output.copy_with_updates(
                predictions=self._squeeze_non_batch_dims(prediction_output.predictions)
            )

        return prediction_output

    def build_task(self, input_shape, schema: Schema, body: Block, **kwargs):
        """Build the layer for `input_shape`; `schema` and `body` are unused here."""
        return super().build(input_shape)

    @property
    def task_name(self):
        """Explicit task name, or one derived from the class name and target name."""
        if self._task_name:
            return self._task_name

        base_name = to_snake_case(self.__class__.__name__)

        return name_fn(self.target_name, base_name) if self.target_name else base_name

    def child_name(self, name):
        """Return `name` combined with this task's name via `name_fn`."""
        return name_fn(self.task_name, name)

    @classmethod
    def from_config(cls, config):
        """Re-create the task from `config`, deserializing `pre` and `task_block`."""
        config = tf_utils.maybe_deserialize_keras_objects(
            config,
            {
                "pre": tf.keras.layers.deserialize,
                "task_block": tf.keras.layers.deserialize,
            },
        )

        return super().from_config(config)

    def get_config(self):
        """Return the layer config extended with `pre`, `task_block` and the names."""
        config = super().get_config()
        config = tf_utils.maybe_serialize_keras_objects(
            self,
            config,
            ["pre", "task_block"],
        )

        if self.target_name:
            config["target_name"] = self.target_name
        if self._task_name:
            config["task_name"] = self._task_name

        return config

    def create_default_metrics(self):
        """Instantiate every class in `self.DEFAULT_METRICS` with a task-scoped name.

        `DEFAULT_METRICS` is not defined on this class; presumably subclasses
        provide it (verify against the concrete task classes).
        """
        return [
            metric(name=self.child_name(to_snake_case(metric.__name__)))
            for metric in self.DEFAULT_METRICS
        ]

    def compute_output_shape(self, input_shape):
        """Propagate `input_shape` through `task_block` and `pre` when they are set."""
        if self.task_block:
            input_shape = self.task_block.compute_output_shape(input_shape)

        if self.pre:
            input_shape = self.pre.compute_output_shape(input_shape)

        return input_shape
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class ParallelPredictionBlock(ParallelBlock):
    """Multi-task prediction block.

    Parameters
    ----------
    prediction_tasks: *PredictionTask
        List of tasks to be used for prediction.
    task_blocks: Optional[Union[Layer, Dict[str, Layer]]]
        Task blocks to be used for prediction.
    bias_block : Optional[Layer]
        Bias block to be used for prediction.
    task_weights : Optional[List[float]]
        Weights for each task, in the same order as `prediction_tasks`.
    pre : Optional[BlockType]
        Forwarded to the parent block constructor.
    post : Optional[BlockType]
        Forwarded to the parent block constructor.
    """

    def __init__(
        self,
        *prediction_tasks: PredictionTask,
        task_blocks: Optional[Union[Layer, Dict[str, Layer]]] = None,
        bias_block: Optional[Layer] = None,
        task_weights=None,
        pre: Optional[BlockType] = None,
        post: Optional[BlockType] = None,
        **kwargs,
    ):
        self.prediction_tasks = prediction_tasks
        self.bias_block = bias_block

        if bias_block:
            self.bias_logit = tf.keras.layers.Dense(1)

        self.prediction_task_dict = {task.task_name: task for task in prediction_tasks}

        super(ParallelPredictionBlock, self).__init__(
            self.prediction_task_dict, pre=pre, post=post, use_layer_name=False, **kwargs
        )

        # Per-task weights keyed by task name. `add_task` writes to this mapping, so
        # it must always exist; tasks without an explicit weight default to 1.0.
        self._task_weight_dict = {task.task_name: 1.0 for task in prediction_tasks}
        if task_weights:
            for task, weight in zip(prediction_tasks, task_weights):
                self._task_weight_dict[task.task_name] = weight

        if task_blocks:
            self._set_task_blocks(task_blocks)

    @classmethod
    def get_tasks_from_schema(
        cls,
        schema,
        task_weight_dict: Optional[Dict[str, float]] = None,
        task_pre_dict: Optional[Dict[str, Block]] = None,
    ):
        """Create prediction tasks for the target columns of a schema

        Parameters
        ----------
        schema : Schema
            The `Schema` with the input features
        task_weight_dict : Optional[Dict[str, float]], optional
            Dict where keys are target feature names and values are weights for each task,
            by default None
        task_pre_dict: Optional[Dict[str, Block]], optional
            Dict where keys are target feature names and values are Blocks to be used as
            pre for those tasks

        Returns
        -------
        Tuple[tuple, tuple]
            `(task_weights, tasks)`, both ordered by target name. Both tuples are
            empty when the schema has no binary-classification or regression targets.
        """
        task_weight_dict = task_weight_dict or {}
        task_pre_dict = task_pre_dict or {}

        from merlin.models.tf.prediction_tasks.classification import BinaryClassificationTask
        from merlin.models.tf.prediction_tasks.regression import RegressionTask

        # TODO: Add multi-class classification here. Figure out how to get number of classes
        task_class_by_tag = (
            (Tags.BINARY_CLASSIFICATION, BinaryClassificationTask),
            (Tags.REGRESSION, RegressionTask),
        )

        entries = []
        for tag, task_class in task_class_by_tag:
            for target in sorted(schema.select_by_tag(tag).column_names):
                task = task_class(target, pre=task_pre_dict.get(target, None))
                entries.append((target, task, task_weight_dict.get(target, 1.0)))

        # Unpacking `zip(*[])` into three names would raise, so handle "no targets".
        if not entries:
            return (), ()

        # Ensures tasks are sorted by name, so that they match the metrics
        # which are sorted the same way by Keras.
        # Sort on the target name only: comparing whole tuples would fall back to
        # comparing task objects whenever two entries share a target name.
        entries.sort(key=lambda entry: entry[0])
        _, tasks, tasks_weights = zip(*entries)

        return tasks_weights, tasks

    @classmethod
    def from_schema(  # type: ignore
        cls,
        schema: Schema,
        task_blocks: Optional[Union[Layer, Dict[str, Layer]]] = None,
        task_weight_dict: Optional[Dict[str, float]] = None,
        task_pre_dict: Optional[Dict[str, Block]] = None,
        bias_block: Optional[Layer] = None,
        **kwargs,
    ) -> "ParallelPredictionBlock":
        """Build Multi-task prediction Block from schema

        Parameters
        ----------
        schema : Schema
            The `Schema` with the input features
        task_blocks : Optional[Union[Layer, Dict[str, Layer]]], optional
            Task blocks to be used for prediction, by default None
        task_weight_dict : Optional[Dict[str, float]], optional
            Dict where keys are target feature names and values are weights for each task,
            by default None
        task_pre_dict: Optional[Dict[str, Block]], optional
            Dict where keys are target feature names and values are Blocks to be used as pre
            for those tasks
        bias_block : Optional[Layer], optional
            Bias block to be used for prediction, by default None
        """
        task_weights, tasks = cls.get_tasks_from_schema(schema, task_weight_dict, task_pre_dict)

        return cls(
            *tasks,
            task_blocks=task_blocks,
            task_weights=task_weights,
            bias_block=bias_block,
            **kwargs,
        )

    @classmethod
    def task_names_from_schema(cls, schema: Schema) -> List[str]:
        """Return the names of the tasks that `get_tasks_from_schema` would create."""
        _, tasks = cls.get_tasks_from_schema(schema)

        return [task.task_name for task in tasks]

    def _set_task_blocks(self, task_blocks):
        """Assign task blocks to the registered prediction tasks.

        Parameters
        ----------
        task_blocks : Union[Layer, Dict[str, Layer]]
            A dict maps a task name (or one "/"-separated part of it) to the block
            for that task. A single Layer is re-created from its config once per
            task, so every task receives its own instance.

        Raises
        ------
        ValueError
            If a dict key matches no task, matches several tasks, or if
            `task_blocks` is neither a Layer nor a dict.
        """
        if isinstance(task_blocks, dict):
            tasks_multi_names = self._prediction_tasks_multi_names()
            for key, task_block in task_blocks.items():
                if key not in tasks_multi_names:
                    raise ValueError(
                        f"Couldn't find {key} in prediction_tasks, "
                        f"only found: {', '.join(list(self.prediction_task_dict.keys()))}"
                    )

                tasks = tasks_multi_names[key]
                if len(tasks) != 1:
                    raise ValueError(
                        f"Ambiguous name: {key}, can't resolve it to a task "
                        "because there are multiple tasks that contain the key: "
                        f"{', '.join([task.task_name for task in tasks])}"
                    )

                self.prediction_task_dict[tasks[0].task_name].task_block = task_block
        elif isinstance(task_blocks, Layer):
            for task in self.prediction_task_dict.values():
                task.task_block = task_blocks.from_config(task_blocks.get_config())
        else:
            raise ValueError("`task_blocks` must be a Layer or a Dict[str, Layer]")

    def _prediction_tasks_multi_names(self) -> Dict[str, List[PredictionTask]]:
        """Map each task name and each "/"-separated part of it to the matching tasks.

        A task is listed at most once per key, so a task whose name contains no "/"
        is not reported as ambiguous with itself.
        """
        prediction_tasks_multi_names = {
            name: [val] for name, val in self.prediction_task_dict.items()
        }

        for name, value in self.prediction_task_dict.items():
            for name_part in name.split("/"):
                matches = prediction_tasks_multi_names.setdefault(name_part, [])
                # Identity check: the full name was already registered above, and a
                # part may repeat inside one name (e.g. "a/a").
                if not any(existing is value for existing in matches):
                    matches.append(value)

        return prediction_tasks_multi_names

    def add_task(self, task: PredictionTask, task_weight=1):
        """Register an additional task under its target name.

        Parameters
        ----------
        task : PredictionTask
            Task to add; it must have a `target_name`.
        task_weight : optional
            Weight stored for the task; a falsy value stores nothing. By default 1

        Returns
        -------
        ParallelPredictionBlock
            This block, to allow chaining.

        Raises
        ------
        ValueError
            If the task has no `target_name`.
        """
        key = task.target_name
        if not key:
            raise ValueError("PredictionTask must have a target_name")

        # NOTE(review): tasks passed to the constructor are keyed by `task_name`,
        # while this uses `target_name` and does not touch `prediction_task_dict`.
        # Confirm which key callers rely on before unifying the two.
        self.parallel_dict[key] = task
        if task_weight:
            self._task_weight_dict[key] = task_weight

        return self

    def pop_labels(self, inputs: Dict[Text, tf.Tensor]):
        """Remove every entry named after a parallel layer from `inputs` and return them.

        `inputs` is modified in place; a missing name raises `KeyError`.
        """
        return {name: inputs.pop(name) for name in self.parallel_dict.keys()}

    def call(
        self,
        inputs: Union[TabularData, tf.Tensor],
        training: bool = False,
        bias_outputs=None,
        **kwargs,
    ):
        """Run every task and optionally add a shared bias to each output.

        Parameters
        ----------
        inputs : Union[TabularData, tf.Tensor]
            Dict inputs that do not already hold one entry per parallel layer are
            first passed through `self.body`.
        training : bool, optional
            Forwarded to the bias block and the parent `call`, by default False
        bias_outputs : optional
            Precomputed bias added to every task output. When None and a bias block
            is set, it is computed from dict inputs that still need `self.body`.
        """
        if isinstance(inputs, dict) and not all(
            name in inputs for name in self.parallel_dict.keys()
        ):
            # Compare with None: the truth value of a tensor is ambiguous or raises.
            if self.bias_block and bias_outputs is None:
                bias_outputs = self.bias_block(inputs, training=training)
            inputs = self.body(inputs)

        outputs = super(ParallelPredictionBlock, self).call(inputs, training=training, **kwargs)

        if bias_outputs is not None:
            for key in outputs:
                outputs[key] += bias_outputs

        return outputs

    def compute_call_output_shape(self, input_shape):
        """Compute output shapes, passing dict shapes through `self.body` when needed."""
        if isinstance(input_shape, dict) and not all(
            name in input_shape for name in self.parallel_dict.keys()
        ):
            input_shape = self.body.compute_output_shape(input_shape)

        return super().compute_call_output_shape(input_shape)

    @property
    def task_blocks(self) -> Dict[str, Optional[Layer]]:
        """Task block of every registered task, keyed by task name."""
        return {name: task.task_block for name, task in self.prediction_task_dict.items()}

    @property
    def task_names(self) -> List[str]:
        """Names of the registered tasks."""
        return list(self.prediction_task_dict)

    def repr_ignore(self) -> List[str]:
        """Attribute names returned for exclusion from the block's repr."""
        return ["prediction_tasks", "parallel_layers"]

    def _set_context(self, context: "ModelContext"):
        """Set `context` on every registered task and then on this block."""
        for task in self.prediction_task_dict.values():
            task._set_context(context)
        super(ParallelPredictionBlock, self)._set_context(context)

    @classmethod
    def from_config(cls, config, **kwargs):
        """Re-create the block from `config`.

        Deserializes `body` and `prediction_tasks`, converts a serialized `schema`
        back into a `Schema`, and drops `parallel_layers` because the tasks are
        passed positionally instead.
        """
        config = tf_utils.maybe_deserialize_keras_objects(config, ["body", "prediction_tasks"])

        if "schema" in config:
            config["schema"] = tensorflow_metadata_json_to_schema(config["schema"])

        prediction_tasks = config.pop("prediction_tasks", [])
        config.pop("parallel_layers", None)

        return cls(*prediction_tasks, **config)

    def get_config(self):
        """Return the parent config extended with `body` and `prediction_tasks`."""
        config = super().get_config()
        config = tf_utils.maybe_serialize_keras_objects(self, config, ["body", "prediction_tasks"])

        return config