transformers4rec.torch package

Submodules

transformers4rec.torch.masking module

class transformers4rec.torch.masking.MaskingInfo(schema: torch.Tensor, targets: torch.Tensor)[source]

Bases: object

schema: torch.Tensor
targets: torch.Tensor
class transformers4rec.torch.masking.MaskSequence(hidden_size: int, padding_idx: int = 0, eval_on_last_item_seq_only: bool = True, **kwargs)[source]

Bases: transformers4rec.torch.utils.torch_utils.OutputSizeMixin, torch.nn.modules.module.Module

Base class to prepare masked items inputs/labels for language modeling tasks.

Transformer architectures can be trained in different ways. Depending on the training method, there is a specific masking schema. The masking schema sets the items to be predicted (labels) and masks (hides) their positions in the sequence so that they are not used by the Transformer layers for prediction.

We currently provide 4 different masking schemes out of the box:
  • Causal LM (clm)

  • Masked LM (mlm)

  • Permutation LM (plm)

  • Replacement Token Detection (rtd)

This class can be extended to add a different masking scheme.

Parameters
  • hidden_size (int) – The hidden dimension of input tensors, needed to initialize the trainable vector of masked positions.

  • padding_idx (int, default = 0) – Index of the padding token, used to pad sequences in a batch to the same length
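A minimal usage sketch (illustrative, not part of the generated reference): a masking scheme is typically attached to the sequential inputs via the masking argument of TabularSequenceFeatures.from_schema, or instantiated directly from this module. The schema object and the sizes below are placeholder assumptions.

    import transformers4rec.torch as tr
    from transformers4rec.torch.masking import MaskedLanguageModeling

    # Assumed: `schema` describes the preprocessed sequential features of the dataset
    inputs = tr.TabularSequenceFeatures.from_schema(
        schema,
        max_sequence_length=20,
        d_output=64,
        masking="mlm",  # string aliases for the schemes above: "clm", "mlm", "plm", "rtd"
    )

    # The masking block can also be built directly:
    mlm = MaskedLanguageModeling(hidden_size=64, padding_idx=0)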

compute_masked_targets(item_ids: torch.Tensor, training=False) → transformers4rec.torch.masking.MaskingInfo[source]

Method to prepare masked labels based on the sequence of item ids. It returns the true labels of masked positions and the related boolean mask, and updates the class attributes mask_schema and masked_targets so they can be reused by other modules.

Parameters
  • item_ids (torch.Tensor) – The sequence of input item ids used for deriving labels of next item prediction task.

  • training (bool) – Flag to indicate whether we are in Training mode or not. During training, the labels can be any items within the sequence based on the selected masking task. During evaluation, we are predicting the last item in the sequence.

Returns

The mask schema and the masked targets, wrapped in a MaskingInfo object.

Return type

MaskingInfo
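A hedged sketch of calling the method directly on a toy batch (values are placeholders): in training mode the labels follow the selected masking scheme, while in evaluation mode only the last item of each sequence is used as the label when eval_on_last_item_seq_only=True.

    import torch
    from transformers4rec.torch.masking import CausalLanguageModeling

    clm = CausalLanguageModeling(hidden_size=64, padding_idx=0)
    item_ids = torch.tensor([[5, 3, 8, 2, 0], [7, 4, 0, 0, 0]])  # 0 = padding

    train_info = clm.compute_masked_targets(item_ids, training=True)
    eval_info = clm.compute_masked_targets(item_ids, training=False)
    # Both calls return a MaskingInfo holding the boolean mask (schema) and the labels (targets)
    print(train_info.schema.shape, train_info.targets.shape)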

apply_mask_to_inputs(inputs: torch.Tensor, schema: torch.Tensor) → torch.Tensor[source]

Control the masked positions in the inputs by replacing the true interaction embeddings with a learnable masked embedding.

Parameters
  • inputs (torch.Tensor) – The 3-D tensor of interaction embeddings resulting from the ops: TabularFeatures + aggregation + projection(optional)

  • schema (MaskingSchema) – The boolean mask indicating masked positions.
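A short, hedged continuation of the sketch above, applying the boolean mask to placeholder 3-D interaction embeddings:

    import torch
    from transformers4rec.torch.masking import CausalLanguageModeling

    clm = CausalLanguageModeling(hidden_size=64, padding_idx=0)
    item_ids = torch.tensor([[5, 3, 8, 2, 0]])
    info = clm.compute_masked_targets(item_ids, training=True)

    interaction_embeddings = torch.rand(1, 5, 64)  # (batch, seq_len, hidden_size)
    masked_inputs = clm.apply_mask_to_inputs(interaction_embeddings, info.schema)
    print(masked_inputs.shape)  # same shape as the input embeddings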

predict_all(item_ids: torch.Tensor) → transformers4rec.torch.masking.MaskingInfo[source]

Prepare labels for all next item predictions instead of last-item predictions in a user’s sequence.

Parameters

item_ids (torch.Tensor) – The sequence of input item ids used for deriving labels of next item prediction task.

Returns

The mask schema and the next-item labels, wrapped in a MaskingInfo object.

Return type

MaskingInfo
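A hedged sketch with placeholder values:

    import torch
    from transformers4rec.torch.masking import MaskedLanguageModeling

    mlm = MaskedLanguageModeling(hidden_size=64, padding_idx=0)
    item_ids = torch.tensor([[5, 3, 8, 2, 0]])
    info = mlm.predict_all(item_ids)
    print(info.targets)  # a next-item label for every non-padded position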

forward(inputs: torch.Tensor, item_ids: torch.Tensor, training=False) → torch.Tensor[source]
forward_output_size(input_size)[source]
transformer_required_arguments() → Dict[str, Any][source]
transformer_optional_arguments() → Dict[str, Any][source]
property transformer_arguments

Prepare additional arguments to pass to the Transformer forward methods.

class transformers4rec.torch.masking.CausalLanguageModeling(hidden_size: int, padding_idx: int = 0, eval_on_last_item_seq_only: bool = True, train_on_last_item_seq_only: bool = False, **kwargs)[source]

Bases: transformers4rec.torch.masking.MaskSequence

In Causal Language Modeling (clm) you predict the next item based on past positions of the sequence. Future positions are masked.

Parameters
  • hidden_size (int) – The hidden dimension of input tensors, needed to initialize trainable vector of masked positions.

  • padding_idx (int, default = 0) – Index of the padding item, used to pad sequences in a batch to the same length

  • eval_on_last_item_seq_only (bool, default = True) – Predict only last item during evaluation

  • train_on_last_item_seq_only (bool, default = False) – Predict only the last item during training
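An illustrative (non-normative) view of the CLM labeling scheme: each position is trained to predict the next item, so the labels are the input sequence shifted one step to the left, with padding appended.

    import torch

    item_ids = torch.tensor([[5, 3, 8, 2, 0]])  # 0 = padding
    labels = torch.cat([item_ids[:, 1:], torch.zeros_like(item_ids[:, :1])], dim=1)
    print(labels)  # tensor([[3, 8, 2, 0, 0]])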

apply_mask_to_inputs(inputs: torch.Tensor, mask_schema: torch.Tensor) → torch.Tensor[source]
class transformers4rec.torch.masking.MaskedLanguageModeling(hidden_size: int, padding_idx: int = 0, eval_on_last_item_seq_only: bool = True, mlm_probability: float = 0.15, **kwargs)[source]

Bases: transformers4rec.torch.masking.MaskSequence

In Masked Language Modeling (mlm) you randomly select some positions of the sequence to be predicted, which are masked. During training, the Transformer layer is allowed to use positions on the right (future information). During inference, all past items are visible to the Transformer layer, which tries to predict the next item.

Parameters
  • hidden_size (int) – The hidden dimension of input tensors, needed to initialize trainable vector of masked positions.

  • padding_idx (int, default = 0) – Index of the padding item, used to pad sequences in a batch to the same length

  • eval_on_last_item_seq_only (bool, default = True) – Predict only last item during evaluation

  • mlm_probability (Optional[float], default = 0.15) – Probability of an item being selected (masked) as a label of the given sequence. Note: we enforce that at least one item is masked for each sequence, so that the network can learn something from it.
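A hedged sketch on a toy batch: each non-padded position is selected as a label with probability mlm_probability, and at least one position per sequence is always selected.

    import torch
    from transformers4rec.torch.masking import MaskedLanguageModeling

    mlm = MaskedLanguageModeling(hidden_size=64, padding_idx=0, mlm_probability=0.3)
    item_ids = torch.tensor([[5, 3, 8, 2, 7, 0, 0]])
    info = mlm.compute_masked_targets(item_ids, training=True)
    print(info.schema)   # randomly selected (masked) positions
    print(info.targets)  # item ids at the masked positions, padding index elsewhere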

class transformers4rec.torch.masking.PermutationLanguageModeling(hidden_size: int, padding_idx: int = 0, eval_on_last_item_seq_only: bool = True, plm_probability: float = 0.16666666666666666, max_span_length: int = 5, permute_all: bool = False, **kwargs)[source]

Bases: transformers4rec.torch.masking.MaskSequence

In Permutation Language Modeling (plm) you use a permutation factorization at the level of the self-attention layer to define the accessible bidirectional context.

Parameters
  • hidden_size (int) – The hidden dimension of input tensors, needed to initialize trainable vector of masked positions.

  • padding_idx (int, default = 0) – Index of the padding item, used to pad sequences in a batch to the same length

  • eval_on_last_item_seq_only (bool, default = True) – Predict only last item during evaluation

  • max_span_length (int) – maximum length of a span of masked items

  • plm_probability (float) – The ratio of surrounding items left unmasked, which defines the context of the span-based prediction segment of items

  • permute_all (bool) – Whether to permute all items (True) or to compute partial span-based prediction (False).
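A hedged instantiation sketch with placeholder sizes. For PLM the additional tensors required by XLNet-style models (e.g. the permutation and target mappings) are exposed through transformer_required_arguments() once the targets have been computed.

    import torch
    from transformers4rec.torch.masking import PermutationLanguageModeling

    plm = PermutationLanguageModeling(hidden_size=64, plm_probability=1 / 6, max_span_length=3)
    item_ids = torch.tensor([[5, 3, 8, 2, 7, 9, 0, 0]])
    info = plm.compute_masked_targets(item_ids, training=True)
    print(plm.transformer_required_arguments().keys())  # extra kwargs for the transformer forward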

compute_masked_targets(item_ids: torch.Tensor, training=False) → transformers4rec.torch.masking.MaskingInfo[source]
transformer_required_arguments() → Dict[str, Any][source]
class transformers4rec.torch.masking.ReplacementLanguageModeling(hidden_size: int, padding_idx: int = 0, eval_on_last_item_seq_only: bool = True, sample_from_batch: bool = False, **kwargs)[source]

Bases: transformers4rec.torch.masking.MaskedLanguageModeling

In Replacement Language Modeling (rtd) you use MLM to randomly select some items, but replace them with random tokens. Then a discriminator model (which may or may not share weights with the generator) is asked to classify whether the item at each position belongs to the original sequence or not. The generator-discriminator architecture is trained jointly using the Masked LM and RTD tasks.

Parameters
  • hidden_size (int) – The hidden dimension of input tensors, needed to initialize trainable vector of masked positions.

  • padding_idx (int, default = 0) – Index of the padding item, used to pad sequences in a batch to the same length

  • eval_on_last_item_seq_only (bool, default = True) – Predict only last item during evaluation

  • sample_from_batch (bool) – Whether to sample replacement item ids from the same batch or not

get_fake_tokens(itemid_seq, target_flat, logits)[source]

The second task of RTD is binary classification to train the discriminator. It consists of generating fake data by replacing [MASK] positions with random items; the ELECTRA discriminator then learns to detect the fake replacements.

Parameters
  • itemid_seq (torch.Tensor of shape (bs, max_seq_len)) – input sequence of item ids

  • target_flat (torch.Tensor of shape (bs*max_seq_len)) – flattened masked label sequences

  • logits (torch.Tensor of shape (#pos_item, vocab_size) or (#pos_item, #pos_item)) – MLM probabilities of positive items computed by the generator model. The logits are over the whole corpus if sample_from_batch = False, and over the positive (masked) items of the current batch otherwise

Returns

  • corrupted_inputs (torch.Tensor of shape (bs, max_seq_len)) – input sequence of item ids with fake replacement

  • discriminator_labels (torch.Tensor of shape (bs, max_seq_len)) – binary labels to distinguish between original and replaced items

  • batch_updates (torch.Tensor of shape (#pos_item)) – the indices of replacement item within the current batch if sample_from_batch is enabled
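A hedged end-to-end sketch of the RTD corruption step on a toy sequence; the generator logits below are random placeholders standing in for a real generator model's output over the masked positions.

    import torch
    from transformers4rec.torch.masking import ReplacementLanguageModeling

    rtd = ReplacementLanguageModeling(hidden_size=64, padding_idx=0, sample_from_batch=False)
    item_ids = torch.tensor([[5, 3, 8, 2, 0]])
    info = rtd.compute_masked_targets(item_ids, training=True)

    target_flat = info.targets.flatten()
    n_masked = int((target_flat != 0).sum())  # number of positive (masked) items
    fake_logits = torch.rand(n_masked, 10)    # placeholder generator output, vocab_size = 10

    corrupted, disc_labels, batch_updates = rtd.get_fake_tokens(item_ids, target_flat, fake_logits)
    print(corrupted)    # item ids with fake replacements at the masked positions
    print(disc_labels)  # binary labels distinguishing original from replaced items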

sample_from_softmax(logits: torch.Tensor) → torch.Tensor[source]

Sampling method for replacement token modeling (ELECTRA)

Parameters

logits (torch.Tensor of shape (#pos_item, vocab_size)) – probability scores for masked positions returned by the generator model

Returns

samples – ids of replacement items.

Return type

torch.Tensor of shape (#pos_item)
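A simplified, hedged illustration of what this sampling amounts to: drawing one replacement id per masked position from the categorical distribution defined by the generator logits (the actual implementation may use a different sampling trick, such as Gumbel noise).

    import torch

    logits = torch.rand(4, 10)  # (#pos_item, vocab_size), placeholder generator scores
    probs = torch.softmax(logits, dim=-1)
    samples = torch.multinomial(probs, num_samples=1).squeeze(-1)
    print(samples)  # one sampled replacement id per masked position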

transformers4rec.torch.ranking_metric module

class transformers4rec.torch.ranking_metric.RankingMetric(top_ks=None, labels_onehot=False)[source]

Bases: torchmetrics.metric.Metric

Metric wrapper for computing ranking metrics@K for session-based tasks.

Parameters
  • top_ks (list, default [2, 5]) – list of cutoffs

  • labels_onehot (bool) – Whether to transform the labels to a one-hot representation

update(preds: torch.Tensor, target: torch.Tensor, **kwargs)[source]
compute()[source]
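A hedged usage sketch with one of the concrete subclasses below; the score and label shapes are assumptions (a score over the item catalog per example, and integer target ids combined with labels_onehot=True).

    import torch
    from transformers4rec.torch.ranking_metric import RecallAt

    metric = RecallAt(top_ks=[2, 5], labels_onehot=True)
    scores = torch.rand(8, 100)            # (batch_size, item-catalog size)
    labels = torch.randint(0, 100, (8,))   # ground-truth next-item ids
    metric.update(scores, labels)
    print(metric.compute())                # one value per cutoff in top_ks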
class transformers4rec.torch.ranking_metric.PrecisionAt(top_ks=None, labels_onehot=False)[source]

Bases: transformers4rec.torch.ranking_metric.RankingMetric

class transformers4rec.torch.ranking_metric.RecallAt(top_ks=None, labels_onehot=False)[source]

Bases: transformers4rec.torch.ranking_metric.RankingMetric

class transformers4rec.torch.ranking_metric.AvgPrecisionAt(top_ks=None, labels_onehot=False)[source]

Bases: transformers4rec.torch.ranking_metric.RankingMetric

class transformers4rec.torch.ranking_metric.DCGAt(top_ks=None, labels_onehot=False)[source]

Bases: transformers4rec.torch.ranking_metric.RankingMetric

class transformers4rec.torch.ranking_metric.NDCGAt(top_ks=None, labels_onehot=False)[source]

Bases: transformers4rec.torch.ranking_metric.RankingMetric

class transformers4rec.torch.ranking_metric.MeanRecipricolRankAt(top_ks=None, labels_onehot=False)[source]

Bases: transformers4rec.torch.ranking_metric.RankingMetric

transformers4rec.torch.trainer module

transformers4rec.torch.typing module

Module contents