# Source listing extracted from rendered docs — presumably merlin.models.tf
# interaction blocks (DLRM/FM/xDeepFM feature-interaction layers); original
# module path lost in extraction, verify against the merlin-models repository.

# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import tensorflow as tf

# NOTE(review): module paths were stripped during extraction ("from import X");
# reconstructed below from the merlin-models package layout — confirm against
# the original source tree.
from merlin.models.tf.blocks.mlp import MLPBlock
from merlin.models.tf.core.aggregation import StackFeatures
from merlin.models.tf.core.base import Block
from merlin.models.tf.core.combinators import MapValues, ParallelBlock, SequentialBlock
from merlin.models.tf.core.tabular import Filter
from merlin.models.tf.inputs.base import InputBlockV2
from merlin.models.tf.inputs.embedding import Embeddings
from merlin.models.tf.transforms.features import CategoryEncoding, ToSparse
from merlin.schema import Schema, Tags

# Valid values for DotProductInteraction's `interaction_type` argument.
# `None` is the plain (parameter-free) FM-style dot-product interaction;
# the "field_*" variants add a learned bilinear kernel (FiBiNet-style).
_INTERACTION_TYPES = (None, "field_all", "field_each", "field_interaction")

@tf.keras.utils.register_keras_serializable(package="merlin_models")
class DotProductInteraction(tf.keras.layers.Layer):
    """
    Layer implementing the factorization machine style feature
    interaction layer suggested by the DLRM and DeepFM architectures,
    generalized to include a dot-product version of the parameterized
    interaction suggested by the FiBiNet architecture (which normally
    uses element-wise multiplication instead of dot product).

    Maps from tensors of shape `(batch_size, num_features, embedding_dim)`
    to tensors of shape `(batch_size, (num_features - 1)*num_features // 2)`
    if `self_interaction` is `False`, otherwise `(batch_size, num_features**2)`.

    Parameters
    ------------------------
    interaction_type: {}
        The type of feature interaction to use. `None` defaults to the
        standard factorization machine style interaction, and the
        alternatives use the implementation defined in the FiBiNet
        architecture (with the element-wise multiplication replaced
        with a dot product).
    self_interaction: bool
        Whether to calculate the interaction of a feature with itself.
    """.format(
        _INTERACTION_TYPES
    )

    def __init__(self, interaction_type=None, self_interaction=False, name=None, **kwargs):
        if interaction_type not in _INTERACTION_TYPES:
            raise ValueError("Unknown interaction type {}".format(interaction_type))
        self.interaction_type = interaction_type
        self.self_interaction = self_interaction
        super(DotProductInteraction, self).__init__(name=name, **kwargs)

    def build(self, input_shape):
        # The plain FM-style interaction (interaction_type=None) has no
        # trainable parameters.
        if self.interaction_type is None:
            self.built = True
            return

        # Base kernel is (embedding_dim, embedding_dim); "field_each" and
        # "field_interaction" prepend one resp. two num_features axes so
        # each field (pair) gets its own bilinear kernel.
        kernel_shape = [input_shape[2], input_shape[2]]
        if self.interaction_type in _INTERACTION_TYPES[2:]:
            idx = _INTERACTION_TYPES.index(self.interaction_type)
            for _ in range(idx - 1):
                kernel_shape.insert(0, input_shape[1])

        self.kernel = self.add_weight(
            name="bilinear_interaction_kernel",
            shape=kernel_shape,
            initializer="glorot_normal",
            trainable=True,
        )
        self.built = True

    def call(self, inputs):
        right = inputs

        # First transform v_i depending on the interaction type.
        if self.interaction_type is None:
            left = inputs
        elif self.interaction_type == "field_all":
            left = tf.matmul(inputs, self.kernel)
        elif self.interaction_type == "field_each":
            left = tf.einsum("b...k,...jk->b...j", inputs, self.kernel)
        else:  # "field_interaction"
            left = tf.einsum("b...k,f...jk->bf...j", inputs, self.kernel)

        # Do the interaction between v_i and v_j.
        # Output shape will be (batch_size, num_features, num_features).
        if self.interaction_type != "field_interaction":
            interactions = tf.matmul(left, right, transpose_b=True)
        else:
            interactions = tf.einsum("b...jk,b...k->b...j", left, right)

        # Build a (num_features, num_features) boolean mask selecting the
        # upper triangle (optionally without the diagonal).
        ones = tf.reduce_sum(tf.zeros_like(interactions), axis=0) + 1
        mask = tf.linalg.band_part(ones, 0, -1)  # set lower diagonal to zero
        if not self.self_interaction:
            mask = mask - tf.linalg.band_part(ones, 0, 0)  # get rid of diagonal
        mask = tf.cast(mask, tf.bool)
        x = tf.boolean_mask(interactions, mask, axis=1)

        # Masking destroys shape information, set explicitly.
        x.set_shape(self.compute_output_shape(inputs.shape))
        return x

    def compute_output_shape(self, input_shape):
        if self.self_interaction:
            output_dim = input_shape[1] ** 2
        else:
            # Number of unordered feature pairs: n*(n-1)/2.
            output_dim = input_shape[1] * (input_shape[1] - 1) // 2
        return input_shape[0], output_dim

    def get_config(self):
        # FIX: previously returned only the two custom fields, dropping the
        # base Layer config (name, trainable, dtype, ...). For a layer
        # registered with `register_keras_serializable`, merge with the
        # parent config so serialization round-trips faithfully.
        config = super(DotProductInteraction, self).get_config()
        config.update(
            {
                "interaction_type": self.interaction_type,
                "self_interaction": self.self_interaction,
            }
        )
        return config
class XDeepFmOuterProduct(tf.keras.layers.Layer):
    """
    Layer implementing the outer product transformation used in the
    Compressed Interaction Network (CIN) proposed by the xDeepFM
    architecture.

    Treats the feature dimension H_k of a B x H_k x D feature embedding
    tensor as a feature map of the D embedding elements, and computes
    element-wise multiplication interaction between these maps and those
    from an initial input tensor x_0 before taking the inner product with
    a parameter matrix.

    Parameters
    ------------
    dim : int
        Feature dimension of the layer. Output will be of shape
        (batch_size, dim, embedding_dim)
    """

    def __init__(self, dim, **kwargs):
        self.dim = dim
        super().__init__(**kwargs)

    def build(self, input_shapes):
        # Expect exactly two rank-3 inputs [x_{k-1}, x_0] with a shared
        # embedding dimension in the last axis.
        if not isinstance(input_shapes[0], (tuple, tf.TensorShape)):
            raise ValueError("Should be called on a list of inputs.")
        if len(input_shapes) != 2:
            raise ValueError("Should only have two inputs, found {}".format(len(input_shapes)))
        for shape in input_shapes:
            if len(shape) != 3:
                raise ValueError("Found shape {} without 3 dimensions".format(shape))
        if input_shapes[0][-1] != input_shapes[1][-1]:
            raise ValueError(
                "Last dimension should match, found dimensions {} and {}".format(
                    input_shapes[0][-1], input_shapes[1][-1]
                )
            )

        # H_k x H_{k-1} x m
        shape = (self.dim, input_shapes[0][1], input_shapes[1][1])
        self.kernel = self.add_weight(
            name="kernel", initializer="glorot_uniform", trainable=True, shape=shape
        )
        self.built = True

    def call(self, inputs):
        """
        Parameters
        ------------
        inputs : array-like(tf.Tensor)
            The two input tensors, the first of which should
            be the output of the previous layer, and the second
            of which should be the input to the CIN.
        """
        x_k_minus_1, x_0 = inputs

        # Shape manipulations so we can do an element-wise multiply that
        # broadcasts every H_{k-1} map against every one of the m maps of x_0.
        x_k_minus_1 = tf.expand_dims(x_k_minus_1, axis=2)  # B x H_{k-1} x 1 x D
        x_k_minus_1 = tf.tile(x_k_minus_1, [1, 1, x_0.shape[1], 1])  # B x H_{k-1} x m x D
        x_k_minus_1 = tf.transpose(x_k_minus_1, (1, 0, 2, 3))  # H_{k-1} x B x m x D

        z_k = x_k_minus_1 * x_0  # H_{k-1} x B x m x D
        z_k = tf.transpose(z_k, (1, 0, 2, 3))  # B x H_{k-1} x m x D

        # Contract the (H_{k-1}, m) feature-map axes with the kernel to
        # map down to B x H_k x D.
        x_k = tf.tensordot(self.kernel, z_k, axes=[[1, 2], [1, 2]])
        x_k = tf.transpose(x_k, (1, 0, 2))
        return x_k

    def compute_output_shape(self, input_shapes):
        return (input_shapes[0][0], self.dim, input_shapes[0][2])

    def get_config(self):
        # FIX: `dim` is a required constructor argument but was never
        # serialized, so saved models could not be reloaded; include it
        # alongside the base Layer config.
        config = super().get_config()
        config.update({"dim": self.dim})
        return config
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class FMPairwiseInteraction(Block):
    """Compute pairwise (2nd-order) feature interactions like defined in
    Factorized Machine [1].

    References
    ----------
    [1] Steffen, Rendle, "Factorization Machines" IEEE International
    Conference on Data Mining, 2010.
    """

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Parameters
        ----------
        inputs : array-like(tf.Tensor)
            A 3-D tensor of shape (bs, n_features, embedding_dim)
            containing the stacked embeddings of input features.

        Returns
        -------
        A 2-D tensor of shape (bs, K) containing pairwise interactions

        Raises
        ------
        ValueError
            If `inputs` is not a 3-D tensor.
        """
        # FIX: was an `assert`, which is stripped under `python -O`; raise
        # a ValueError for real input validation (consistent with the
        # validation style in compute_output_shape).
        if len(inputs.shape) != 3:
            raise ValueError("inputs should be a 3-D tensor")

        # Uses the O(n*d) FM identity:
        # sum_{i<j} <v_i, v_j> = 0.5 * ((sum_i v_i)^2 - sum_i v_i^2)
        # sum_square part
        summed_square = tf.square(tf.reduce_sum(inputs, 1))
        # square_sum part
        squared_sum = tf.reduce_sum(tf.square(inputs), 1)

        # second order
        return 0.5 * tf.subtract(summed_square, squared_sum)

    def compute_output_shape(self, input_shapes):
        """Computes the output shape based on the input shapes

        Parameters
        ----------
        input_shapes : tf.TensorShape
            The input shapes

        Returns
        -------
        tf.TensorShape
            The output shape
        """
        if len(input_shapes) != 3:
            raise ValueError("Found shape {} without 3 dimensions".format(input_shapes))
        # Feature axis is summed out; embedding axis survives.
        return (input_shapes[0], input_shapes[2])
def FMBlock(
    schema: Schema,
    fm_input_block: Optional[Block] = None,
    wide_input_block: Optional[Block] = None,
    wide_logit_block: Optional[Block] = None,
    factors_dim: Optional[int] = None,
    **kwargs,
) -> Block:
    """Implements the Factorization Machine, as introduced in [1].
    It consists in the sum of a wide component that weights each feature
    individually and a 2nd-level feature interaction using factors.

    References
    ----------
    [1] Steffen, Rendle, "Factorization Machines" IEEE International
    Conference on Data Mining, 2010.

    Parameters
    ----------
    schema : Schema
        The schema of input features
    fm_input_block : Optional[Block], by default None
        The input block for the 2nd-order feature interaction in
        Factorization Machine. Only categorical features will be used by
        this block, as it computes dot product between all paired
        combinations of embedding values. If not provided, an
        InputBlockV2 is instantiated based on schema.
        Note: All features (including continuous) are considered in the
        1st-order (wide) part which uses another input block.
    wide_input_block: Optional[Block], by default None
        The input for the wide block. If not provided, creates a default
        block that encodes categorical features with one-hot / multi-hot
        representation and also includes the continuous features.
    wide_logit_block: Optional[Block], by default None
        The output layer of the wide input. The last dimension needs to
        be 1. You might want to provide your own output logit block if
        you want to add dropout or kernel regularization to the wide
        block.
    factors_dim : Optional[int], optional
        If fm_input_block is not provided, the factors_dim is used to
        define the embeddings dim to instantiate InputBlockV2,
        by default None

    Returns
    -------
    Block
        A block that, when called, outputs a 2D tensor (batch size, 1)
        with the sum of the wide component and 2nd-order interaction
        component of FM.
        (FIX: the annotation previously claimed ``tf.Tensor``, but this
        factory builds and returns a ``ParallelBlock``.)
    """
    cat_schema = schema.select_by_tag(Tags.CATEGORICAL)
    cont_schema = schema.select_by_tag(Tags.CONTINUOUS)

    # 1st-order (wide) component: a linear layer over one-/multi-hot
    # categoricals concatenated with (sparsified) continuous features.
    wide_input_block = wide_input_block or ParallelBlock(
        {
            "categorical": CategoryEncoding(cat_schema, output_mode="multi_hot", sparse=True),
            "continuous": Filter(cont_schema).connect(ToSparse()),
        },
        aggregation="concat",
    )
    wide_logit_block = wide_logit_block or MLPBlock([1], activation="linear", use_bias=True)
    first_order = wide_input_block.connect(wide_logit_block)

    # 2nd-order component: embeddings of categorical features only,
    # stacked and fed through the pairwise interaction, then summed to
    # a single logit per example.
    fm_input_block = fm_input_block or InputBlockV2(
        cat_schema,
        categorical=Embeddings(cat_schema, dim=factors_dim),
        aggregation=None,
    )
    pairwise_interaction = SequentialBlock(
        Filter(cat_schema),
        fm_input_block,
        FMPairwiseInteraction().prepare(aggregation=StackFeatures(axis=-1)),
        MapValues(tf.keras.layers.Lambda(lambda x: tf.reduce_sum(x, axis=1, keepdims=True))),
    )

    fm_block = ParallelBlock([first_order, pairwise_interaction], aggregation="element-wise-sum")
    return fm_block