Source code for merlin.models.tf.blocks.experts

#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, List, Optional, Tuple, Union

import tensorflow as tf

from merlin.models.tf.core.aggregation import StackFeatures
from merlin.models.tf.core.base import Block
from merlin.models.tf.core.combinators import ParallelBlock, TabularBlock
from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask
from merlin.models.tf.typing import TabularData
from merlin.schema import Schema


class MMOEGate(Block):
    """Gate block for a Mixture-of-Experts model: computes a softmax
    distribution over the experts, conditioned on the original block inputs
    (the "shortcut"), and returns the gate-weighted sum of the stacked
    expert outputs."""

    def __init__(self, num_experts: int, dim=32, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.dim = dim
        self.num_experts = num_experts

        self.gate = tf.keras.layers.Dense(dim, name=f"gate_{name}")
        self.softmax = tf.keras.layers.Dense(
            num_experts, use_bias=False, activation="softmax", name=f"gate_distribution_{name}"
        )

    def call(self, inputs: TabularData, **kwargs):
        shortcut = inputs.pop("shortcut")
        expert_outputs = list(inputs.values())[0]

        expanded_gate_output = tf.expand_dims(self.softmax(self.gate(shortcut)), axis=-1)
        out = tf.reduce_sum(expert_outputs * expanded_gate_output, axis=1, keepdims=False)

        return out

    def compute_output_shape(self, input_shape):
        return input_shape["shortcut"]

    def get_config(self):
        config = super().get_config()
        config.update(dim=self.dim, num_experts=self.num_experts)

        return config
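
# Shape walkthrough for MMOEGate.call (a sketch derived from the code above;
# dimension names are illustrative):
#
#   shortcut:        (batch, input_dim)                the original block inputs
#   expert_outputs:  (batch, num_experts, expert_dim)  stacked via StackFeatures(axis=1)
#   gate(shortcut):  (batch, dim)
#   softmax(...):    (batch, num_experts)
#   expand_dims:     (batch, num_experts, 1)
#   output:          (batch, expert_dim)               gate-weighted sum over experts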


def MMOEBlock(
    outputs: Union[List[str], List[PredictionTask], ParallelPredictionBlock],
    expert_block: Block,
    num_experts: int,
    gate_dim: int = 32,
):
    """Builds a Multi-gate Mixture-of-Experts (MMoE) block, as introduced in
    "Modeling Task Relationships in Multi-task Learning with Multi-gate
    Mixture-of-Experts" (Ma et al., KDD 2018): all tasks share the same pool
    of experts, and each task mixes them through its own gate."""
    if isinstance(outputs, ParallelPredictionBlock):
        output_names = outputs.task_names
    elif all(isinstance(x, PredictionTask) for x in outputs):
        output_names = [o.task_name for o in outputs]  # type: ignore
    else:
        output_names = outputs  # type: ignore

    experts = expert_block.repeat_in_parallel(
        num_experts, prefix="expert_", aggregation=StackFeatures(axis=1)
    )
    gates = MMOEGate(num_experts, dim=gate_dim).repeat_in_parallel(names=output_names)
    mmoe = expert_block.connect_with_shortcut(experts, block_outputs_name="experts")
    mmoe = mmoe.connect(gates, block_name="MMOE")

    return mmoe
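
# Usage sketch (hypothetical; assumes `merlin.models.tf` publicly exports
# `InputBlock`, `PredictionTasks`, `MLPBlock` and `Model`, plus a `schema`
# describing the dataset; names and sizes are illustrative):
#
#     import merlin.models.tf as mm
#
#     prediction_tasks = mm.PredictionTasks(schema)
#     mmoe = mm.MMOEBlock(prediction_tasks, expert_block=mm.MLPBlock([64, 32]), num_experts=4)
#     model = mm.Model(mm.InputBlock(schema), mmoe, prediction_tasks)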


class CGCGateTransformation(TabularBlock):
    """Post-processing block for CGC: routes the individual expert outputs
    through one MMOEGate per task, plus an optional gate for the shared
    experts."""

    def __init__(
        self,
        task_names: List[str],
        num_task_experts: int = 1,
        num_shared_experts: int = 1,
        add_shared_gate: bool = True,
        dim: int = 32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        num_total_experts = num_task_experts + num_shared_experts
        self.task_names = [*task_names, "shared"] if add_shared_gate else task_names
        self.stack = StackFeatures(axis=1)
        self.gate_dict: Dict[str, MMOEGate] = {
            name: MMOEGate(num_total_experts, dim=dim) for name in task_names
        }

        if add_shared_gate:
            self.gate_dict["shared"] = MMOEGate(
                len(task_names) * num_task_experts + num_shared_experts, dim=dim
            )

    def call(self, expert_outputs: TabularData, **kwargs) -> TabularData:  # type: ignore
        outputs: TabularData = {}

        shortcut = expert_outputs.pop("shortcut")
        outputs["shortcut"] = shortcut

        for name in self.task_names:
            experts = dict(
                experts=self.stack(self.filter_expert_outputs(expert_outputs, name)),
                shortcut=shortcut,
            )
            outputs[name] = self.gate_dict[name](experts)

        return outputs

    def filter_expert_outputs(self, expert_outputs: TabularData, task_name: str) -> TabularData:
        if task_name == "shared":
            return expert_outputs

        filtered_experts: TabularData = {}
        for name, val in expert_outputs.items():
            if name.startswith((task_name, "shared")):
                filtered_experts[name] = val

        return filtered_experts

    def compute_output_shape(self, input_shape):
        tensor_output_shape = list(input_shape.values())[0]

        return {name: tensor_output_shape for name in self.task_names}
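
# Gate fan-in implied above (derived from the code): each task gate mixes that
# task's own experts plus the shared experts (num_task_experts +
# num_shared_experts inputs), while the optional "shared" gate mixes every
# expert (len(task_names) * num_task_experts + num_shared_experts inputs).
# E.g. with 2 tasks, 2 experts per task and 1 shared expert, each task gate
# weighs 3 experts and the shared gate weighs 5.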


class CGCBlock(ParallelBlock):
    """Implements the Customized Gate Control (CGC) block from "Progressive
    Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for
    Personalized Recommendations" (Tang et al., RecSys 2020): each task owns
    private experts and additionally reads from a pool of shared experts."""

    def __init__(
        self,
        outputs: Union[List[str], List[PredictionTask], ParallelPredictionBlock],
        expert_block: Union[Block, tf.keras.layers.Layer],
        num_task_experts: int = 1,
        num_shared_experts: int = 1,
        add_shared_gate: bool = True,
        schema: Optional[Schema] = None,
        name: Optional[str] = None,
        **kwargs,
    ):
        if not isinstance(expert_block, Block):
            expert_block = Block.from_layer(expert_block)

        if isinstance(outputs, ParallelPredictionBlock):
            output_names = outputs.task_names
        elif all(isinstance(x, PredictionTask) for x in outputs):
            output_names = [o.task_name for o in outputs]  # type: ignore
        else:
            output_names = outputs  # type: ignore

        task_experts = dict(
            [
                create_expert(expert_block, f"{task}/expert_{i}")
                for task in output_names
                for i in range(num_task_experts)
            ]
        )
        shared_experts = dict(
            [create_expert(expert_block, f"shared/expert_{i}") for i in range(num_shared_experts)]
        )

        post = CGCGateTransformation(
            output_names, num_task_experts, num_shared_experts, add_shared_gate=add_shared_gate
        )
        super().__init__(
            task_experts,
            shared_experts,
            post=post,
            aggregation=None,
            schema=schema,
            name=name,
            strict=False,
            **kwargs,
        )

    def call(self, inputs, **kwargs):
        if isinstance(inputs, dict):
            outputs = dict(shortcut=inputs["shortcut"])
            for name, layer in self.parallel_dict.items():
                # Expert names look like "<task>/expert_<i>"; strip the last
                # segment to route each expert its task's (or the shared) inputs.
                input_name = "/".join(name.split("/")[:-1])
                outputs.update(layer(inputs[input_name]))

            return outputs
        else:
            outputs = super().call(inputs, **kwargs)
            outputs["shortcut"] = inputs  # type: ignore

            return outputs

    def compute_call_output_shape(self, input_shape):
        if isinstance(input_shape, dict):
            output_shapes = {}

            for name, layer in self.parallel_dict.items():
                input_name = "/".join(name.split("/")[:-1])
                output_shapes.update(layer.compute_output_shape(input_shape[input_name]))

            return output_shapes

        return super().compute_call_output_shape(input_shape)


def create_expert(expert_block: Block, name: str) -> Tuple[str, TabularBlock]:
    return name, expert_block.as_tabular(name)
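

# Usage sketch (hypothetical, mirroring the MMOEBlock note above; assumes the
# same public `merlin.models.tf` exports and an illustrative `schema`):
#
#     import merlin.models.tf as mm
#
#     prediction_tasks = mm.PredictionTasks(schema)
#     cgc = mm.CGCBlock(
#         prediction_tasks,
#         expert_block=mm.MLPBlock([64, 32]),
#         num_task_experts=2,
#         num_shared_experts=1,
#     )
#     model = mm.Model(mm.InputBlock(schema), cgc, prediction_tasks)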