merlin.models.tf.NextItemPredictionTask
merlin.models.tf.NextItemPredictionTask(schema: merlin.schema.schema.Schema, weight_tying: bool = True, extra_pre_call: Optional[merlin.models.tf.core.base.Block] = None, target_name: Optional[str] = None, task_name: Optional[str] = None, task_block: Optional[keras.engine.base_layer.Layer] = None, logits_temperature: float = 1.0, l2_normalization: bool = False, sampled_softmax: bool = False, num_sampled: int = 100, min_sampled_id: int = 0, post_logits: Optional[merlin.models.tf.core.base.Block] = None) → merlin.models.tf.prediction_tasks.classification.MultiClassClassificationTask
Function to create the NextItemPrediction task with the right parameters.
- Parameters
schema (Schema) – The schema object, including the features to use and their properties.
weight_tying (bool) – If True, the item_id embedding weights are shared with the prediction (output) layer. Defaults to True.
extra_pre_call (Optional[Block]) – Optional extra pre-call block. Defaults to None.
target_name (Optional[str]) – If specified, the name of the target tensor to retrieve from the dataloader. Defaults to None.
task_name (Optional[str]) – Name of the task. Defaults to None.
task_block (Optional[Layer]) – A block that applies additional layer operations to the inputs. Defaults to None.
logits_temperature (float) – Temperature T used to reduce model overconfidence: the logits are divided by T (logits / T). Defaults to 1.0.
l2_normalization (bool) – Apply L2 normalization before computing the dot-product interactions. Defaults to False.
sampled_softmax (bool) – If False, compute the logit scores over all items in the catalog; if True, compute them over a sampled subset of candidate items. Defaults to False.
num_sampled (int) – When sampled_softmax is enabled, the number of negative candidates to sample for each batch. Defaults to 100.
min_sampled_id (int) – The minimum id value that can be sampled. Useful for excluding the first categorically encoded ids, which are usually reserved for nulls, out-of-vocabulary values, or padding. Defaults to 0.
post_logits (Optional[Block]) – Optional block for post-processing the logits. Defaults to None. For example, you can use post_logits = mm.PopularitySamplingBlock(item_frequency) for popularity sampling correction.
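The interaction between weight_tying, l2_normalization and logits_temperature can be illustrated with a small, framework-free sketch. It assumes that each item's logit is the dot product between the model's output representation and that item's row in the item_id embedding table (the table that weight tying reuses). The helpers next_item_logits, l2_normalize and softmax are hypothetical names, not part of Merlin Models, and the embedding values are toy data.

```python
import math
from typing import List

Vector = List[float]


def l2_normalize(v: Vector, eps: float = 1e-12) -> Vector:
    """Scale v to unit L2 norm (eps avoids division by zero)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def next_item_logits(
    hidden: Vector,
    item_embeddings: List[Vector],
    l2_normalization: bool = False,
    logits_temperature: float = 1.0,
) -> Vector:
    """Score every item in the catalog for one example (hypothetical sketch).

    With weight tying, `item_embeddings` is the same table used to embed the
    item_id input feature, so no separate output projection is learned.
    """
    if l2_normalization:
        # Normalizing both sides turns the dot product into cosine similarity.
        hidden = l2_normalize(hidden)
        item_embeddings = [l2_normalize(e) for e in item_embeddings]
    # One logit per item: dot(hidden, embedding of that item).
    logits = [sum(h * e for h, e in zip(hidden, emb)) for emb in item_embeddings]
    # Dividing by a temperature T > 1 flattens the softmax distribution.
    return [z / logits_temperature for z in logits]


def softmax(logits: Vector) -> Vector:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy catalog of 3 items with 2-dimensional embeddings (hypothetical values).
item_table = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
hidden = [2.0, 0.0]

print(next_item_logits(hidden, item_table))  # → [2.0, 0.0, 2.0]
print(next_item_logits(hidden, item_table, logits_temperature=2.0))  # → [1.0, 0.0, 1.0]
```

Raising the temperature from 1.0 to 2.0 halves every logit, which lowers the top softmax probability (from about 0.47 to about 0.42 in this toy case), i.e. the predicted distribution becomes less confident.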
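The effect of sampled_softmax, num_sampled and min_sampled_id can also be sketched in plain Python. This sketch assumes negatives are drawn uniformly, with replacement, from the id range [min_sampled_id, num_items) and are shared by every example in the batch; Merlin Models may use a different sampling distribution. The functions sample_negatives and batch_sampled_logits are hypothetical names used only for illustration.

```python
import random
from typing import List, Optional, Tuple


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def sample_negatives(
    num_items: int, num_sampled: int, min_sampled_id: int, rng: random.Random
) -> List[int]:
    """Draw num_sampled ids uniformly (with replacement) from [min_sampled_id, num_items).

    Ids below min_sampled_id (e.g. 0 reserved for padding/OOV) are never drawn.
    """
    return [rng.randrange(min_sampled_id, num_items) for _ in range(num_sampled)]


def batch_sampled_logits(
    hiddens: List[List[float]],
    item_embeddings: List[List[float]],
    target_ids: List[int],
    num_sampled: int = 100,
    min_sampled_id: int = 0,
    rng: Optional[random.Random] = None,
) -> Tuple[List[List[float]], List[int]]:
    """Score each example's target plus a shared set of sampled negatives."""
    if rng is None:
        rng = random.Random()
    # One set of negatives is sampled per batch and shared by all examples.
    negatives = sample_negatives(len(item_embeddings), num_sampled, min_sampled_id, rng)
    rows = []
    for hidden, target in zip(hiddens, target_ids):
        # Column 0 holds the positive (target) item; the rest are the negatives.
        candidates = [target] + negatives
        rows.append([dot(hidden, item_embeddings[c]) for c in candidates])
    return rows, negatives


# 10 hypothetical items with 2-d embeddings; id 0 is treated as padding.
table = [[0.1 * i, 1.0 - 0.1 * i] for i in range(10)]
logits, negatives = batch_sampled_logits(
    hiddens=[[1.0, 0.0], [0.0, 1.0]],
    item_embeddings=table,
    target_ids=[3, 7],
    num_sampled=4,
    min_sampled_id=1,
    rng=random.Random(42),
)
print(negatives)
print(len(logits), len(logits[0]))  # 2 rows, 1 positive + 4 negatives each
```

Each row scores only 1 + num_sampled candidates instead of the full catalog, with the positive item in column 0. The sketch does not remove accidental hits (a sampled negative equal to the target) or correct for the sampling distribution, both of which sampled-softmax implementations commonly handle.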
- Returns
The next item prediction task
- Return type