transformers4rec.tf.model package

Submodules

transformers4rec.tf.model.head module

transformers4rec.tf.model.model module

transformers4rec.tf.model.prediction_task module

transformers4rec.tf.model.prediction_task.name_fn(name, inp)
class transformers4rec.tf.model.prediction_task.BinaryClassificationTask(*args, **kwargs)

Bases: transformers4rec.tf.model.base.PredictionTask

DEFAULT_LOSS = <keras.losses.BinaryCrossentropy object>
DEFAULT_METRICS = (<class 'keras.metrics.Precision'>, <class 'keras.metrics.Recall'>, <class 'keras.metrics.BinaryAccuracy'>, <class 'keras.metrics.AUC'>)
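The defaults above pair binary cross-entropy with thresholded classification metrics. As a rough, framework-free illustration of what they measure (not the Keras or Transformers4Rec code), the sketch below computes the loss, precision, recall and binary accuracy in plain Python. It assumes the model outputs probabilities and uses a 0.5 decision threshold; AUC is omitted for brevity, and the helper names are hypothetical.

```python
import math


def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clipped to avoid log(0)."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(y_true)


def threshold_metrics(y_true, y_prob, threshold=0.5):
    """Precision, recall and binary accuracy after thresholding probabilities."""
    y_pred = [1 if p > threshold else 0 for p in y_prob]
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for y, yp in pairs if y == 1 and yp == 1)
    fp = sum(1 for y, yp in pairs if y == 0 and yp == 1)
    fn = sum(1 for y, yp in pairs if y == 1 and yp == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    accuracy = sum(1 for y, yp in pairs if y == yp) / len(pairs)
    return precision, recall, accuracy


# Toy batch: two positives and two negatives.
labels = [1, 0, 1, 0]
probs = [0.9, 0.2, 0.4, 0.6]
loss = binary_crossentropy(labels, probs)
precision, recall, accuracy = threshold_metrics(labels, probs)
```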
class transformers4rec.tf.model.prediction_task.RegressionTask(*args, **kwargs)

Bases: transformers4rec.tf.model.base.PredictionTask

DEFAULT_LOSS = <keras.losses.MeanSquaredError object>
DEFAULT_METRICS = (<class 'keras.metrics.RootMeanSquaredError'>,)
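For the regression defaults, the mean-squared-error loss and the root-mean-squared-error metric can be sketched in plain Python as below. This is an illustrative single-batch computation with hypothetical helper names; the Keras metric object accumulates its state across batches rather than recomputing from scratch.

```python
import math


def mean_squared_error(y_true, y_pred):
    """Average of squared differences (the loss)."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def root_mean_squared_error(y_true, y_pred):
    """Square root of the MSE, reported in the target's own units (the metric)."""
    return math.sqrt(mean_squared_error(y_true, y_pred))


targets = [3.0, 5.0, 2.5]
predictions = [2.0, 5.0, 4.5]
mse = mean_squared_error(targets, predictions)
rmse = root_mean_squared_error(targets, predictions)
```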
class transformers4rec.tf.model.prediction_task.NextItemPredictionTask(*args, **kwargs)

Bases: transformers4rec.tf.model.base.PredictionTask

Next-item prediction task.

Parameters
  • loss – Loss function. Defaults to SparseCategoricalCrossentropy().

  • metrics – List of RankingMetrics to be evaluated.

  • prediction_metrics – List of Keras metrics used to summarize the predictions.

  • label_metrics – List of Keras metrics used to summarize the labels.

  • loss_metrics – List of Keras metrics used to summarize the loss.

  • name – Optional task name.

  • target_dim (int) – Dimension of the target.

  • weight_tying (bool) – If True, the weights of the item-id embedding table are shared with the prediction (output) layer.

  • item_embedding_table (tf.Variable) – The embedding table variable for the item-id feature.

  • softmax_temperature (float) – Softmax temperature T, used to reduce model overconfidence: output probabilities are computed as softmax(logits / T). A value of 1.0 reduces to the regular softmax.
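Conceptually, weight tying and the softmax temperature combine as follows: the item scores (logits) are the dot products between the sequence representation and each row of the item embedding table, and the probabilities are softmax(logits / T). The plain-Python sketch below assumes a toy 3-item, 2-dimensional table and a hypothetical hidden vector; the library's actual output layer may include further components (for example a bias term), and the helper names are hypothetical.

```python
import math


def tied_item_logits(hidden, item_embeddings):
    """Weight tying: score every item by dot(hidden, item embedding row)."""
    return [sum(h * e for h, e in zip(hidden, row)) for row in item_embeddings]


def softmax_with_temperature(logits, temperature=1.0):
    """softmax(logits / T), shifted by the max for numerical stability."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


item_embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy table: 3 items, dim 2
hidden = [2.0, 1.0]  # hypothetical output of the sequence body for one position
logits = tied_item_logits(hidden, item_embeddings)
probs_t1 = softmax_with_temperature(logits, temperature=1.0)
probs_t2 = softmax_with_temperature(logits, temperature=2.0)
```

With T = 2.0 the distribution is flatter than with T = 1.0, while the ranking of items is unchanged, so a temperature above 1 tempers overconfident probabilities without altering the top-k recommendations.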

DEFAULT_LOSS = <keras.losses.SparseCategoricalCrossentropy object>
DEFAULT_METRICS = (NDCGAt(top_ks=[10, 20]), AvgPrecisionAt(top_ks=[10, 20]), RecallAt(top_ks=[10, 20]))
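The default ranking metrics evaluate where the true next item lands in the list of items ranked by score. For the common case of a single target item per example they reduce to simple formulas, sketched below in plain Python. This is an illustration with hypothetical helper names, not the library's NDCGAt, AvgPrecisionAt or RecallAt implementations, and it uses small cut-offs instead of 10 and 20 to keep the toy example readable.

```python
import math


def top_k_items(scores, k):
    """Indices of the k highest-scoring items, best first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def ranking_metrics_at_k(scores, target, k):
    """Recall@k, NDCG@k and AP@k for a single relevant (target) item."""
    top = top_k_items(scores, k)
    if target not in top:
        return {"recall": 0.0, "ndcg": 0.0, "avg_precision": 0.0}
    rank = top.index(target) + 1  # 1-based position of the hit
    return {
        "recall": 1.0,
        "ndcg": 1.0 / math.log2(rank + 1),  # ideal DCG is 1 with one relevant item
        "avg_precision": 1.0 / rank,  # precision at the single hit position
    }


scores = [0.1, 0.7, 0.3, 0.9, 0.05]  # toy scores over 5 items
target = 2  # true next item, ranked third
at_2 = ranking_metrics_at_k(scores, target, k=2)
at_3 = ranking_metrics_at_k(scores, target, k=3)
```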
build(input_shape, body, inputs=None)
call(inputs, **kwargs)
remove_pad_3d(inp_tensor, non_pad_mask)
compute_loss(inputs, targets=None, compute_metrics: bool = True, call_task: bool = True, sample_weight: Optional[tensorflow.python.framework.ops.Tensor] = None, **kwargs) → tensorflow.python.framework.ops.Tensor
calculate_metrics(predictions, targets=None, sample_weight=None, forward=True, loss=None)
metric_results(mode: Optional[str] = None) → Dict[str, tensorflow.python.framework.ops.Tensor]
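Based on its name and signature, remove_pad_3d presumably drops padded positions from a (batch, sequence, dim) tensor before the loss is computed. The sketch below illustrates that idea in plain Python, followed by a sparse categorical cross-entropy over the remaining positions. It assumes item id 0 marks padding and that the scores are unnormalized logits; the helper names are hypothetical and this is not the method's actual implementation.

```python
import math

PAD_ID = 0  # assumption: item id 0 marks padded positions


def remove_pad_3d_sketch(inp, non_pad_mask):
    """Flatten (batch, seq, dim) nested lists to (n_kept, dim), keeping masked-in positions."""
    return [
        vec
        for seq, mask in zip(inp, non_pad_mask)
        for vec, keep in zip(seq, mask)
        if keep
    ]


def sparse_categorical_crossentropy_from_logits(logit_rows, targets):
    """Mean of -log softmax(row)[target], computed with a stable log-sum-exp."""
    total = 0.0
    for row, target in zip(logit_rows, targets):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        total += log_sum_exp - row[target]
    return total / len(targets)


# Toy batch: 2 sequences of length 3 over a catalog of 3 item ids (0 = padding).
labels = [[1, 2, PAD_ID], [2, PAD_ID, PAD_ID]]
logits = [
    [[0.0, 2.0, 0.0], [0.0, 0.0, 2.0], [9.0, 9.0, 9.0]],  # last step is padding
    [[0.0, 0.0, 2.0], [9.0, 9.0, 9.0], [9.0, 9.0, 9.0]],  # last two steps are padding
]
non_pad_mask = [[label != PAD_ID for label in seq] for seq in labels]

flat_logits = remove_pad_3d_sketch(logits, non_pad_mask)
flat_targets = [label for seq in labels for label in seq if label != PAD_ID]
loss = sparse_categorical_crossentropy_from_logits(flat_logits, flat_targets)
```

Dropping padded positions before the loss keeps the large dummy logits at padded steps from affecting the result; weighting a loss over the full 3-D tensor by the mask would be an equivalent alternative.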

Module contents