# Source code for merlin.models.tf.blocks.cross

#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional, Tuple, Union

import tensorflow as tf

from merlin.models.tf.blocks.mlp import DenseMaybeLowRank, InitializerType, RegularizerType
from merlin.models.tf.core.combinators import Filter, SequentialBlock, TabularBlock
from merlin.models.tf.utils.tf_utils import (
    maybe_deserialize_keras_objects,
    maybe_serialize_keras_objects,
)
from merlin.schema import Schema, Tags


def CrossBlock(
    depth: int = 1,
    filter: Optional[Union[Schema, Tags, List[str], Filter]] = None,
    low_rank_dim: Optional[int] = None,
    use_bias: bool = True,
    kernel_initializer: InitializerType = "truncated_normal",
    bias_initializer: InitializerType = "zeros",
    kernel_regularizer: Optional[RegularizerType] = None,
    bias_regularizer: Optional[RegularizerType] = None,
    inputs: Optional[tf.keras.layers.Layer] = None,
    **kwargs,
) -> SequentialBlock:
    """This block provides a way to create high-order feature interactions
    by a number of stacked Cross Layers, from
    DCN V2: Improved Deep & Cross Network [1].
    See Eq. (1) for full-rank and Eq. (2) for low-rank version.

    All cross layers except the last one are created with ``output_x0=True``,
    so each of them forwards the pair ``(x0, x_i)`` to the next layer; the last
    layer returns only ``x_depth``.

    References
    ----------
    .. [1]. Wang, Ruoxi, et al. "DCN V2: Improved deep & cross network and
       practical lessons for web-scale learning to rank systems." Proceedings
       of the Web Conference 2021. 2021. https://arxiv.org/pdf/2008.13535.pdf

    Parameters
    ----------
    depth : int, optional
        Number of cross-layers to be stacked, by default 1
    filter : Optional[Union[Schema, Tags, List[str], Filter]], optional
        Features filter to be applied on the input, by default None
    low_rank_dim : Optional[int], optional
        If this argument is provided, the weight (`W in R(dxd)`),
        where d is the input features dimension, is factorized into two
        matrices U and V, both of shape (dxr), with `low_rank_dim = r`,
        by default None
    use_bias : bool, optional
        Enables or not the bias term, by default True
    kernel_initializer : InitializerType, optional
        Initializer to use on the kernel matrix, by default "truncated_normal"
    bias_initializer : InitializerType, optional
        Initializer to use on the bias vector, by default "zeros"
    kernel_regularizer : Optional[RegularizerType], optional
        Regularizer to use on the kernel matrix, by default None
    bias_regularizer : Optional[RegularizerType], optional
        Regularizer to use on the bias vector, by default None
    inputs : Optional[tf.keras.layers.Layer], optional
        If provided, this layer is placed first in the block and is always
        followed by a `TabularBlock(aggregation="concat")`, so that the
        features are concatenated before reaching the cross layers (which need
        a single aggregated input), by default None
    **kwargs
        Forwarded to the returned `SequentialBlock`.

    Returns
    -------
    SequentialBlock
        A `SequentialBlock` with a number of stacked Cross layers

    Raises
    ------
    ValueError
        Number of cross layers (depth) should be positive
    """
    # Validate before constructing any layer so an invalid depth fails fast.
    if depth <= 0:
        raise ValueError(f"Number of cross layers (depth) should be positive but is {depth}.")

    # Explicit None check: "was an input layer provided?" should not depend on
    # the truthiness of the layer object.
    layers: List[tf.keras.layers.Layer] = (
        [inputs, TabularBlock(aggregation="concat")] if inputs is not None else []
    )

    for i in range(depth):
        layers.append(
            Cross(
                low_rank_dim=low_rank_dim,
                use_bias=use_bias,
                kernel_initializer=kernel_initializer,
                bias_initializer=bias_initializer,
                kernel_regularizer=kernel_regularizer,
                bias_regularizer=bias_regularizer,
                # Every layer but the last forwards (x0, x_i) so the next Cross
                # layer still has access to the original features x0.
                output_x0=i < depth - 1,
            )
        )

    return SequentialBlock(layers, filter=filter, block_name="CrossBlock", **kwargs)
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class Cross(tf.keras.layers.Layer):
    """Implementation of the Cross Layers from
    DCN V2: Improved Deep & Cross Network [1]_.
    See Eq. (1) for the full-rank and Eq. (2) for the low-rank version.

    This layer creates interactions of all input features. When used inside
    `CrossBlock`, stacked `Cross` layers model high-order feature interactions.

    The `call` method accepts either a single tensor or a pair of tensors
    `(x0, xi)` given as a tuple or a 2-element list. The first input `x0` is
    the base layer that contains the original features (usually the embedding
    layer); the second input `xi` is the output of the previous `Cross` layer
    in the stack, i.e., the i-th `Cross` layer. When a single tensor is given
    (e.g. for the first `Cross` layer in the stack), x0 = xi. `x0` and `xi`
    must have the same shape.

    The output is x_{i+1} = x0 .* (W * xi + bias) + xi, where .* designates
    elementwise multiplication and `W * xi + bias` is computed by the layer's
    `dense` sub-layer. W can be a full-rank matrix, or a low-rank factorization
    (see `low_rank_dim`) to reduce the computational cost.

    References
    ----------
    .. [1]. Wang, Ruoxi, et al. "DCN V2: Improved deep & cross network and
       practical lessons for web-scale learning to rank systems." Proceedings
       of the Web Conference 2021. 2021. https://arxiv.org/pdf/2008.13535.pdf

    Parameters
    ----------
    low_rank_dim : Optional[int], optional
        If this argument is provided, the weight (`W in R(dxd)`),
        where d is the input features dimension, is factorized into two
        matrices U and V, both of shape (dxr), with `low_rank_dim = r`,
        by default None
    use_bias : bool, optional
        Enables or not the bias term, by default True
    kernel_initializer : InitializerType, optional
        Initializer to use on the kernel matrix, by default "truncated_normal"
    bias_initializer : InitializerType, optional
        Initializer to use on the bias vector, by default "zeros"
    kernel_regularizer : Optional[RegularizerType], optional
        Regularizer to use on the kernel matrix, by default None
    bias_regularizer : Optional[RegularizerType], optional
        Regularizer to use on the bias vector, by default None
    output_x0 : bool
        Whether to return the tuple `(x0, x_{i+1})` instead of only `x_{i+1}`,
        so that a following `Cross` layer still receives the original features
        `x0`, by default False
    **kwargs
        Forwarded to `tf.keras.layers.Layer`. A `dense` keyword may also be
        passed (this is what `from_config` does) to supply a pre-built
        projection layer; when it is given, `low_rank_dim`, `use_bias` and the
        initializer/regularizer arguments are not used.
    """

    def __init__(
        self,
        low_rank_dim: Optional[int] = None,
        use_bias: bool = True,
        kernel_initializer: InitializerType = "truncated_normal",
        bias_initializer: InitializerType = "zeros",
        kernel_regularizer: Optional[RegularizerType] = None,
        bias_regularizer: Optional[RegularizerType] = None,
        output_x0: bool = False,
        **kwargs,
    ):
        # `dense` travels through **kwargs (from `from_config`); pop it so it is
        # not forwarded to the Keras base constructor.
        dense = kwargs.pop("dense", None)
        if dense is None:
            dense = DenseMaybeLowRank(
                low_rank_dim=low_rank_dim,
                use_bias=use_bias,
                kernel_initializer=kernel_initializer,
                bias_initializer=bias_initializer,
                kernel_regularizer=kernel_regularizer,
                bias_regularizer=bias_regularizer,
            )
        super(Cross, self).__init__(**kwargs)
        self.dense = dense
        self.output_x0 = output_x0

    @staticmethod
    def _is_shape_pair(input_shape) -> bool:
        """Return True when `input_shape` describes an `(x0, x)` pair of inputs.

        A plain shape such as `(None, 8)` also has length 2, so the first
        element is inspected to tell a pair of shapes from a single shape.
        """
        return (
            isinstance(input_shape, (tuple, list))
            and len(input_shape) == 2
            and isinstance(input_shape[0], (tf.TensorShape, tuple, list))
        )

    def compute_output_shape(self, input_shape):
        """Return the output shape(s), mirroring the structure `call` returns.

        `call` returns `(x0, output)` when `output_x0` is True and only
        `output` otherwise; `output` has the shape of `x`.
        """
        if self._is_shape_pair(input_shape):
            x0_shape, x_shape = input_shape[0], input_shape[1]
        else:
            x0_shape = x_shape = input_shape

        if self.output_x0:
            return x0_shape, x_shape
        return x_shape

    def call(
        self,
        inputs: Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor], List[tf.Tensor]],
        **kwargs,
    ):
        """Compute `x0 * dense(x) + x`.

        `inputs` is either a single tensor (used as both `x0` and `x`) or an
        `(x0, x)` pair. Returns `(x0, output)` if `output_x0` else `output`.
        """
        # Accept lists as well as tuples. NOTE(review): Keras' nested-structure
        # handling does not always preserve the container type; confirm for the
        # TF versions in use.
        if isinstance(inputs, (tuple, list)):
            x0, x = inputs
        else:
            x0 = x = inputs

        self.validate_inputs(x0, x)

        projected = self.dense(x)
        output = x0 * projected + x

        if self.output_x0:
            return x0, output
        return output

    def validate_inputs(self, x0, x):
        """Assert at runtime that `x0` and `x` have identical dynamic shapes."""
        tf.assert_equal(
            tf.shape(x0),
            tf.shape(x),
            message="`x0` ({}) and `x` ({}) shapes mismatch!".format(x0.shape, x.shape),
        )

    def get_config(self):
        """Return the layer config, including `output_x0` and the serialized `dense`.

        `output_x0` must be stored explicitly: otherwise `from_config` would
        rebuild the layer with the default `output_x0=False` and change what it
        returns. The initializer/regularizer arguments are not stored because
        they only configure `dense`, which is serialized itself.
        """
        config = dict()
        config.update(super(Cross, self).get_config())
        config["output_x0"] = self.output_x0

        return maybe_serialize_keras_objects(self, config, ["dense"])

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer, deserializing the `dense` sub-layer first."""
        config = maybe_deserialize_keras_objects(config, ["dense"])

        return cls(**config)