Model Architectures

Transformers4Rec provides modularized building blocks that can be combined with plain PyTorch modules and Keras layers. This provides great flexibility in model definition, as you can use the blocks to build custom architectures, e.g., with multiple towers, multiple heads, and multiple losses (multi-task).

In Fig. 2, we provide a reference architecture for next-item prediction with Transformers, which can be used for both sequential and session-based recommendation. We can divide that reference architecture into four conceptual layers, described next.

Fig. 2 - Transformers4Rec meta-architecture

Feature Aggregation (Input Block)

To be fed into a Transformer block, the sequences of input features (user_ids, user metadata, item_ids, item metadata) must be aggregated into a single vector representation per element in the sequence, which we call the interaction embedding.

Current feature aggregation options are:

  • Concat - Concatenation of the features

  • Element-wise sum - Features are summed element-wise. For that, all features must have the same dimension, i.e., categorical embeddings must share the same dimension and continuous features are projected to that dimension.

  • Element-wise sum & item multiplication - Similar to Element-wise sum, all features are summed, except for the item id embedding, which is multiplied by the sum of the other features. The aggregation formula is available in our paper.
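
The three aggregation options can be sketched for a single sequence position in plain Python. This is an illustration, not the library's implementation: the feature names and values are hypothetical, and the item-multiplication variant follows the one-line description above (item id embedding multiplied element-wise by the sum of the other features).

```python
from typing import Dict, List

Vector = List[float]


def concat(features: Dict[str, Vector]) -> Vector:
    """Concat: join all feature vectors (their dimensions may differ)."""
    out: Vector = []
    for name in sorted(features):  # fixed order so the layout is deterministic
        out.extend(features[name])
    return out


def elementwise_sum(features: Dict[str, Vector]) -> Vector:
    """Element-wise sum: all feature vectors must share the same dimension."""
    dims = {len(v) for v in features.values()}
    if len(dims) != 1:
        raise ValueError(f"features must share one dimension, got {sorted(dims)}")
    return [sum(vals) for vals in zip(*features.values())]


def elementwise_sum_item_multi(features: Dict[str, Vector], item_key: str) -> Vector:
    """Item id embedding multiplied element-wise by the sum of the other features."""
    item = features[item_key]
    other_sum = elementwise_sum({k: v for k, v in features.items() if k != item_key})
    if len(item) != len(other_sum):
        raise ValueError("item embedding and other features must share one dimension")
    return [i * s for i, s in zip(item, other_sum)]


# Hypothetical features for one position; "price" is a continuous feature
# assumed to be already projected to the embedding dimension (3).
features = {
    "item_id": [1.0, 2.0, 3.0],
    "category": [0.5, 0.5, 0.5],
    "price": [1.0, 0.0, -1.0],
}
print(concat(features))                                 # 9 values
print(elementwise_sum(features))                        # [2.5, 2.5, 2.5]
print(elementwise_sum_item_multi(features, "item_id"))  # [1.5, 1.0, -1.5]
```

Concat keeps every feature's full width, while both sum variants require a shared dimension, which is why continuous features must be projected first.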

Categorical features are represented by embeddings. Numerical features can be represented as a scalar, projected by a fully-connected (FC) layer to multiple dimensions, or represented as a weighted average of embeddings using the Soft One-Hot Embeddings technique (more details in the online appendix of our paper). Categorical input features are optionally normalized (with layer normalization) before aggregation. Continuous features should be normalized during feature engineering of the dataset.
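
A minimal sketch of the Soft One-Hot Embeddings idea for one continuous value, assuming the scalar is mapped to one logit per row of an embedding table and the softmax of those logits weights the rows. The projection weights and table below are hypothetical hand-picked values; in a model they would be learned, and the library's implementation details may differ.

```python
import math
from typing import List

Vector = List[float]


def softmax(logits: Vector) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def soft_one_hot_embedding(
    x: float, proj_weight: Vector, proj_bias: Vector, table: List[Vector]
) -> Vector:
    """Represent scalar x as a softmax-weighted average of embedding-table rows.

    proj_weight/proj_bias map x to one logit per table row (like a 1-input linear layer).
    """
    weights = softmax([w * x + b for w, b in zip(proj_weight, proj_bias)])
    dim = len(table[0])
    return [sum(weights[r] * table[r][d] for r in range(len(table))) for d in range(dim)]


# Hypothetical parameters: 2 embedding rows of dimension 2.
table = [[1.0, 0.0], [0.0, 1.0]]
proj_weight, proj_bias = [-1.0, 1.0], [0.0, 0.0]
print(soft_one_hot_embedding(0.0, proj_weight, proj_bias, table))  # [0.5, 0.5]
print(soft_one_hot_embedding(3.0, proj_weight, proj_bias, table))  # mostly row 1
```

Unlike a hard bucketization, nearby values get similar mixtures of rows, so the representation changes smoothly with the input value.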

The core class of this module is TabularSequenceFeatures, which processes and aggregates all features and outputs a sequence of interaction embeddings to be fed into Transformer blocks. It can be instantiated automatically from a dataset schema generated by NVTabular (from_schema()), which directly creates all the layers needed to represent the categorical and continuous features in the dataset. In addition, it has options to aggregate the sequential features and to prepare masked labels depending on the chosen sequence masking approach (see next section).

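from transformers4rec.torch import TabularSequenceFeatures
# `schema` is assumed to be the NVTabular dataset schema, already loaded
tabular_inputs = TabularSequenceFeatures.from_schema(
    schema, max_sequence_length=20, aggregation="concat", masking="mlm")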

Sequence Masking

Transformer architectures can be trained in different ways, and each training method has its own masking scheme. The masking scheme sets the items to be predicted (labels) and masks the positions of the sequence that the Transformer layers must not use for prediction. Transformers4Rec currently supports the following training approaches, inspired by NLP:

  • Causal LM (masking="clm") - Predicts the next item based on past positions of the sequence. Future positions are masked.

  • Masked LM (masking="mlm") - Randomly selects some positions of the sequence to be predicted and masks them. The Transformer layer is allowed to use positions on the right (future information) during training. During inference, all past items are visible to the Transformer layer, which tries to predict the next item.

  • Permutation LM (masking="plm") - Uses a permutation factorization at the level of the self-attention layer to define the accessible bidirectional context.

  • Replacement Token Detection (masking="rtd") - Uses MLM to randomly select some items but replaces them with random tokens. Then a discriminator model (which may or may not share weights with the generator) classifies whether the item at each position belongs to the original sequence. The generator-discriminator architecture is jointly trained using the Masked LM and RTD tasks.

Note that not all Transformer architectures support all of these training approaches. Transformers4Rec will raise an exception when you attempt to use an invalid combination and will suggest the appropriate masking techniques for that architecture.
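
To make the label and mask definitions concrete, here is a plain-Python sketch of how Causal LM and Masked LM could build targets for one sequence of item ids. The PAD and MASK ids, the masking probability, and the function names are hypothetical; the library builds these tensors internally, and its details (for example, padding handling) may differ.

```python
import random
from typing import List, Tuple

PAD = 0    # hypothetical id: padding, and "no label" at a position
MASK = -1  # hypothetical id of the mask token fed to the Transformer


def clm_labels(items: List[int]) -> List[int]:
    """Causal LM: the label at position t is the next item, items[t + 1]."""
    return items[1:] + [PAD]  # the last position has no next item


def causal_attention_mask(n: int) -> List[List[bool]]:
    """mask[t][s] is True when position t may attend to position s (only s <= t)."""
    return [[s <= t for s in range(n)] for t in range(n)]


def mlm_inputs_and_labels(
    items: List[int], prob: float, rng: random.Random
) -> Tuple[List[int], List[int]]:
    """Masked LM: hide a random subset of positions and use them as labels."""
    inputs: List[int] = []
    labels: List[int] = []
    for item in items:
        if rng.random() < prob:
            inputs.append(MASK)  # the model must recover this item
            labels.append(item)
        else:
            inputs.append(item)  # visible context, on the left and on the right
            labels.append(PAD)   # not a prediction target
    return inputs, labels


sequence = [5, 7, 9, 11]
print(clm_labels(sequence))  # [7, 9, 11, 0]
print(mlm_inputs_and_labels(sequence, prob=0.5, rng=random.Random(42)))
```

The contrast is visible in the data: Causal LM predicts every next item but only sees the past, while Masked LM predicts only the hidden positions and may look at both sides of them.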

Sequence Processing (Transformer/RNN Block)

The Transformer block processes the input sequences of interaction embeddings created by the input block using Transformer architectures such as XLNet and GPT-2, or RNN architectures such as LSTM and GRU. The created block is a standard Keras layer or PyTorch module, depending on the underlying framework, and can be replaced by any other block of the same type that accepts a sequence as input.

In the following example, a SequentialBlock module is used to build the model body: a TabularSequenceFeatures object (tabular_inputs, defined in the previous code snippet), followed by an MLP projection layer to 64 dimensions (to match the Transformer d_model), and then an XLNet Transformer block with 2 layers (4 heads each).

from transformers4rec.config import transformer
from transformers4rec.torch import MLPBlock, SequentialBlock, TransformerBlock

# Configures the XLNet Transformer architecture
transformer_config = transformer.XLNetConfig.build(
    d_model=64, n_head=4, n_layer=2, total_seq_length=20
)

# Defines the model body including: inputs, masking, projection and transformer block.
model_body = SequentialBlock(
    tabular_inputs,
    MLPBlock([64]),  # projects the interaction embeddings to d_model=64
    TransformerBlock(transformer_config, masking=tabular_inputs.masking),
)

Prediction head (Output Block)

Following the input and Transformer blocks, the model outputs its predictions. The library supports the following prediction heads, which can have multiple losses and metrics and can be combined for multi-task learning.

  • Next Item Prediction - Predicts next items for a given sequence of interactions. During training, the targets can be the next item or randomly selected items, depending on the masking scheme. At inference, it always predicts the next item the user will interact with. Currently, cross-entropy and pairwise losses are supported.

  • Classification - Predicts a categorical feature using the whole sequence. In the context of recommendation, it can be used, for example, to predict whether the user is going to abandon a product added to the cart or proceed to purchase it.

  • Regression - Predicts a continuous feature using the whole sequence, for example the elapsed time until the user returns to a service.
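
The two loss families mentioned for Next Item Prediction can be sketched on a toy score vector (one score per catalog item, for example a dot product between the sequence representation and each item embedding). BPR is used here as one example of a pairwise loss; the function names and scores are hypothetical, and this sketch does not reflect the library's exact pairwise loss options or negative sampling.

```python
import math
from typing import List


def cross_entropy(scores: List[float], target: int) -> float:
    """Softmax cross-entropy of the target item over all catalog scores."""
    m = max(scores)  # shift by the max for a numerically stable log-sum-exp
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[target]


def bpr_loss(scores: List[float], target: int, negatives: List[int]) -> float:
    """Pairwise BPR-style loss: mean of -log(sigmoid(s_target - s_negative))."""
    diffs = [scores[target] - scores[n] for n in negatives]
    return sum(math.log1p(math.exp(-d)) for d in diffs) / len(diffs)


# Hypothetical scores for a 4-item catalog; item 2 is the true next item.
scores = [0.1, 1.5, 2.0, -0.3]
print(cross_entropy(scores, target=2))
print(bpr_loss(scores, target=2, negatives=[0, 3]))
```

Cross-entropy normalizes over the whole catalog, while a pairwise loss only compares the target with a few sampled negatives, which is cheaper for very large catalogs.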

In the following example, we instantiate a head for the NextItemPredictionTask with the previously defined model_body. The head enables the weight_tying option, which is described in the next section. Decoupling model bodies and heads allows for flexible model architecture definitions, such as multiple towers and/or heads. Finally, the Model class combines the heads and wraps the whole model.

from transformers4rec.torch import Head, Model
from transformers4rec.torch.model.head import NextItemPredictionTask

# Defines the head related to next item prediction task
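head = Head(
    model_body,
    NextItemPredictionTask(weight_tying=True, hf_format=True),
    inputs=tabular_inputs,  # input block whose item-id embedding table is reused for weight tying
)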

# Get the end-to-end Model class
model = Model(head)

Tying embeddings

For NextItemPredictionTask, we have added a best practice originally proposed by the NLP community, Tying Embeddings, which ties the weights of the input (item id) embedding matrix to the output projection layer. Tied embeddings significantly reduce memory requirements, and both our own experimentation during recent competitions and the empirical analysis detailed in our paper and online appendix show that the method is effective. It is enabled by default but can be disabled by setting weight_tying=False.
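
A minimal sketch of what tying means, assuming the sequence representation already has the same dimension as the item embeddings (in practice a projection may be needed) and ignoring output biases. The class and function names are hypothetical; the point is only that the input lookup and the output scoring read the same table, so the separate n_items x dim output matrix disappears.

```python
from typing import List

Vector = List[float]


class TiedItemScorer:
    """Toy model whose input lookup and output projection share one item table."""

    def __init__(self, item_table: List[Vector]):
        self.item_table = item_table  # n_items x dim: the only item parameters

    def embed(self, item_id: int) -> Vector:
        # Input side: item id -> embedding row.
        return self.item_table[item_id]

    def logits(self, hidden: Vector) -> Vector:
        # Output side: score every item with the SAME table (hidden @ table^T).
        return [sum(h * e for h, e in zip(hidden, row)) for row in self.item_table]


def item_param_count(n_items: int, dim: int, tied: bool) -> int:
    """Item parameters: one shared table if tied, else input table + output matrix."""
    return n_items * dim * (1 if tied else 2)


scorer = TiedItemScorer([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(scorer.logits(scorer.embed(2)))               # [1.0, 1.0, 2.0]
print(item_param_count(1_000_000, 64, tied=True))   # 64000000
print(item_param_count(1_000_000, 64, tied=False))  # 128000000
```

Because every update to the table affects both sides, items seen often as inputs also get well-trained output scores, which is one intuition for why tying helps.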

Regularization

The library supports a number of regularization techniques, such as Dropout, Weight Decay, Softmax Temperature Scaling, Stochastic Shared Embeddings, and Label Smoothing. In our extensive experiments tuning the hyperparameters of all regularization techniques on different datasets, we found Label Smoothing particularly useful for improving both train and validation accuracy and for better calibrating the predictions.
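
Label Smoothing and Softmax Temperature Scaling both act on the softmax output, so they can be sketched together in plain Python. The formulation below (keep 1 - smoothing of the mass on the true item, spread smoothing uniformly over all items, and divide logits by a temperature) is one common variant and an assumption here; the function names and default values are hypothetical, not the library's settings.

```python
import math
from typing import List


def log_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Log-probabilities after dividing the logits by a softmax temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # stable log-sum-exp
    log_z = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_z for z in scaled]


def smoothed_cross_entropy(
    logits: List[float], target: int, smoothing: float = 0.1, temperature: float = 1.0
) -> float:
    """Cross-entropy against a label-smoothed target distribution.

    The true item keeps (1 - smoothing) of the mass and `smoothing` is spread
    uniformly over all items, one common formulation of label smoothing.
    """
    n = len(logits)
    log_probs = log_softmax(logits, temperature)
    target_dist = [smoothing / n] * n
    target_dist[target] += 1.0 - smoothing
    return -sum(p * lp for p, lp in zip(target_dist, log_probs))


logits = [4.0, 1.0, 0.5]  # hypothetical item scores; item 0 is the target
print(smoothed_cross_entropy(logits, 0, smoothing=0.0))
print(smoothed_cross_entropy(logits, 0, smoothing=0.1))
print(smoothed_cross_entropy(logits, 0, smoothing=0.1, temperature=2.0))
```

With smoothing, an over-confident prediction is penalized through the small target mass on the other items, which is the mechanism behind the better-calibrated predictions mentioned above.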

More details on the options available for each building block can be found in our API Documentation.