transformers4rec.tf.model package
Submodules
transformers4rec.tf.model.head module
transformers4rec.tf.model.model module
transformers4rec.tf.model.prediction_task module
class transformers4rec.tf.model.prediction_task.BinaryClassificationTask(*args, **kwargs)
Bases: transformers4rec.tf.model.base.PredictionTask
DEFAULT_LOSS = <keras.losses.BinaryCrossentropy object>
DEFAULT_METRICS = (<class 'keras.metrics.Precision'>, <class 'keras.metrics.Recall'>, <class 'keras.metrics.BinaryAccuracy'>, <class 'keras.metrics.AUC'>)
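To make these defaults concrete, the following is a minimal pure-Python sketch of what BinaryCrossentropy, Precision, Recall and BinaryAccuracy compute on a small batch. It is not the Keras implementation: it assumes the model outputs probabilities (not logits), uses a 0.5 decision threshold, approximates Keras's epsilon clipping, ignores sample weights and omits AUC. The function names and toy data are hypothetical.

```python
import math


def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clipped to [eps, 1 - eps]."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(y_true)


def threshold_metrics(y_true, y_prob, threshold=0.5):
    """Precision, recall and accuracy after thresholding the probabilities."""
    y_pred = [1 if p > threshold else 0 for p in y_prob]
    tp = sum(1 for y, yp in zip(y_true, y_pred) if y == 1 and yp == 1)
    fp = sum(1 for y, yp in zip(y_true, y_pred) if y == 0 and yp == 1)
    fn = sum(1 for y, yp in zip(y_true, y_pred) if y == 1 and yp == 0)
    correct = sum(1 for y, yp in zip(y_true, y_pred) if y == yp)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall, correct / len(y_true)


# Hypothetical batch: binary labels and predicted probabilities.
y_true = [1, 0, 1, 0]
y_prob = [0.9, 0.2, 0.4, 0.6]
print(binary_cross_entropy(y_true, y_prob))
print(threshold_metrics(y_true, y_prob))  # (0.5, 0.5, 0.5)
```

The loss uses the raw probabilities, while precision, recall and accuracy only see the thresholded decisions, which is why AUC (threshold-free) is a useful complement in the default metric set.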
class transformers4rec.tf.model.prediction_task.RegressionTask(*args, **kwargs)
Bases: transformers4rec.tf.model.base.PredictionTask
DEFAULT_LOSS = <keras.losses.MeanSquaredError object>
DEFAULT_METRICS = (<class 'keras.metrics.RootMeanSquaredError'>,)
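The regression defaults train on the mean squared error but report its square root. A minimal pure-Python sketch of the relationship follows; the function names and the rating data are hypothetical, and Keras details such as sample weights and streaming accumulation across batches are ignored.

```python
import math


def mean_squared_error(y_true, y_pred):
    """Mean of the squared residuals, the quantity minimized by MeanSquaredError."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def root_mean_squared_error(y_true, y_pred):
    """Square root of the MSE, expressed in the same units as the target."""
    return math.sqrt(mean_squared_error(y_true, y_pred))


# Hypothetical ratings and model predictions.
ratings = [4.0, 3.0, 5.0, 1.0]
predicted = [3.5, 3.0, 4.0, 2.0]
print(mean_squared_error(ratings, predicted))       # 0.5625
print(root_mean_squared_error(ratings, predicted))  # 0.75
```

Reporting RMSE rather than MSE keeps the metric on the target's scale (for example, "off by 0.75 stars on average in the RMS sense"), while the squared loss stays smooth for optimization.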
class transformers4rec.tf.model.prediction_task.NextItemPredictionTask(*args, **kwargs)
Bases: transformers4rec.tf.model.base.PredictionTask
Next-item prediction task.
Parameters
loss – Loss function. Defaults to SparseCategoricalCrossentropy().
metrics – List of RankingMetrics to be evaluated.
prediction_metrics – List of Keras metrics used to summarize the predictions.
label_metrics – List of Keras metrics used to summarize the labels.
loss_metrics – List of Keras metrics used to summarize the loss.
name – Optional task name.
target_dim (int) – Dimension of the target.
weight_tying (bool) – If True, the item-id embedding table weights are shared with the prediction network layer.
item_embedding_table (tf.Variable) – The item embedding table variable.
softmax_temperature (float) – Softmax temperature T, used to reduce model overconfidence by computing softmax(logits / T). A value of 1.0 gives the regular softmax.
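The interaction of weight_tying and softmax_temperature can be illustrated with a minimal pure-Python sketch. It assumes that, with weight tying, each item's logit is the dot product of the sequence representation with that item's row of the embedding table, and it ignores any output bias or projection the library may add; the helper names and toy values are hypothetical, not the transformers4rec implementation.

```python
import math


def softmax(logits, temperature=1.0):
    """softmax(logits / T), computed stably by subtracting the largest scaled logit."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def tied_item_logits(hidden, item_embedding_table):
    """Score every item by a dot product with its row of the item embedding table.

    With weight tying, the table that embeds input item ids doubles as the
    output projection, so no separate [dim x num_items] matrix is learned.
    """
    return [sum(h * w for h, w in zip(hidden, row)) for row in item_embedding_table]


# Hypothetical toy values: 3 items with 2-dimensional embeddings.
item_embedding_table = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
hidden = [2.0, 1.0]  # stand-in for the sequence representation at the last position

logits = tied_item_logits(hidden, item_embedding_table)  # [2.0, 1.0, 3.0]
print(softmax(logits, temperature=1.0))
print(softmax(logits, temperature=2.0))  # flatter distribution, same item ranking
```

Dividing by T > 1 shrinks the gaps between logits, so the probabilities become less peaked without changing which item ranks first; T = 1.0 leaves the softmax unchanged.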
DEFAULT_LOSS = <keras.losses.SparseCategoricalCrossentropy object>
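"Sparse" here means the label is an integer item id rather than a one-hot vector over the item catalog, so the loss for each example is simply -log of the probability assigned to the target item. A minimal pure-Python sketch, assuming the inputs are already normalized probabilities (the Keras default, from_logits=False) and roughly mirroring Keras's epsilon clipping; the function name and data are hypothetical.

```python
import math


def sparse_categorical_crossentropy(target_ids, probs_batch, eps=1e-7):
    """Mean of -log p(target) over a batch whose labels are integer item ids."""
    losses = []
    for target, probs in zip(target_ids, probs_batch):
        p = min(max(probs[target], eps), 1.0 - eps)  # avoid log(0)
        losses.append(-math.log(p))
    return sum(losses) / len(losses)


# Hypothetical batch: two examples, a catalog of three items.
probs_batch = [[0.7, 0.2, 0.1], [0.25, 0.25, 0.5]]
target_ids = [0, 2]
print(sparse_categorical_crossentropy(target_ids, probs_batch))
```

Avoiding one-hot labels matters for next-item prediction because the catalog can contain many thousands of items, and only the target's probability enters the loss.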
DEFAULT_METRICS = (NDCGAt(), AvgPrecisionAt(), RecallAt())
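The default ranking metrics can be illustrated for the common next-item case where each example has exactly one relevant item (the true next item). The sketch below uses the textbook definitions under that single-target assumption; it is not necessarily the exact normalization used by the library's NDCGAt, AvgPrecisionAt and RecallAt, and the helper name and scores are hypothetical.

```python
import math


def ranking_metrics_at_k(scores, target, k):
    """Recall@k, NDCG@k and AP@k when exactly one item (target) is relevant."""
    # Rank items by descending score; stable sort keeps index order on ties.
    ranking = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    rank = ranking.index(target) + 1  # 1-based position of the target item
    if rank > k:
        return {"recall": 0.0, "ndcg": 0.0, "avg_precision": 0.0}
    return {
        "recall": 1.0,                      # the single relevant item is in the top k
        "ndcg": 1.0 / math.log2(rank + 1),  # DCG of one hit; the ideal DCG is 1
        "avg_precision": 1.0 / rank,        # precision at the position of the hit
    }


scores = [0.1, 0.5, 0.3, 0.9]  # hypothetical scores over a 4-item catalog
print(ranking_metrics_at_k(scores, target=2, k=3))  # target ranked 3rd
print(ranking_metrics_at_k(scores, target=2, k=2))  # target outside the top 2
```

Recall@k only asks whether the target made the cut-off, while NDCG@k and AP@k also reward placing it higher; in practice each value is averaged over all examples in the evaluation set.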