merlin.models.tf.NextItemPredictionTask

merlin.models.tf.NextItemPredictionTask(schema: merlin.schema.schema.Schema, weight_tying: bool = True, extra_pre_call: Optional[merlin.models.tf.core.base.Block] = None, target_name: Optional[str] = None, task_name: Optional[str] = None, task_block: Optional[keras.engine.base_layer.Layer] = None, logits_temperature: float = 1.0, l2_normalization: bool = False, sampled_softmax: bool = False, num_sampled: int = 100, min_sampled_id: int = 0, post_logits: Optional[merlin.models.tf.core.base.Block] = None) → merlin.models.tf.prediction_tasks.classification.MultiClassClassificationTask

Function that creates the next-item prediction task with the appropriate parameters.

Parameters
  • schema (Schema) – The schema object, including the features to use and their properties.

  • weight_tying (bool) – If True, the item_id embedding weights are shared with the output layer of the prediction network. Defaults to True.

  • extra_pre_call (Optional[PredictionBlock]) – Optional extra pre-call block. Defaults to None.

  • target_name (Optional[str]) – If specified, the name of the target tensor to retrieve from the dataloader. Defaults to None.

  • task_name (Optional[str]) – Name of the task. Defaults to None.

  • task_block (Block) – Block that applies additional layer operations to the inputs. Defaults to None.

  • logits_temperature (float) – Temperature T used to reduce model overconfidence: the logits are divided by T (logits / T). Defaults to 1.0.

  • l2_normalization (bool) – Apply L2 normalization before computing dot interactions. Defaults to False.

  • sampled_softmax (bool) – If True, compute the logits over a sampled subset of candidate items instead of over all items in the catalog. Defaults to False.

  • num_sampled (int) – When sampled_softmax is enabled, the number of negative candidates to sample for each batch. Defaults to 100.

  • min_sampled_id (int) – The minimum id value to be sampled. Useful to ignore the first categorical encoded ids, which are usually reserved for <nulls>, out-of-vocabulary or padding. Defaults to 0.

  • post_logits (Optional[PredictionBlock]) – Optional extra block for post-processing the logits. Defaults to None. For example, post_logits = mm.PopularitySamplingBlock(item_frequency) can be used for popularity sampling correction.
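
The scoring options above (weight_tying, l2_normalization, logits_temperature) can be illustrated with a small pure-Python sketch. It assumes that each item is scored by the dot product between the final hidden representation and that item's row in the item-id embedding table, which is the table that weight tying shares with the output layer. The helpers `l2_normalize`, `next_item_logits` and `softmax` are hypothetical and are not part of Merlin Models; the actual TensorFlow implementation may differ in detail.

```python
import math
from typing import List, Sequence

Vector = List[float]


def l2_normalize(v: Sequence[float], eps: float = 1e-12) -> Vector:
    """Scale a vector to unit L2 norm (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def next_item_logits(
    hidden: Sequence[float],
    item_embeddings: Sequence[Sequence[float]],
    l2_normalization: bool = False,
    logits_temperature: float = 1.0,
) -> Vector:
    """Hypothetical helper: score every catalog item against one hidden vector.

    With weight tying (assumed here), `item_embeddings` is the same table
    used to embed the input item ids, so no separate output matrix is learned.
    """
    if l2_normalization:
        hidden = l2_normalize(hidden)
        item_embeddings = [l2_normalize(e) for e in item_embeddings]
    # One dot product per catalog item, then divide by the temperature T.
    return [
        sum(h * e for h, e in zip(hidden, emb)) / logits_temperature
        for emb in item_embeddings
    ]


def softmax(logits: Sequence[float]) -> Vector:
    """Numerically stable softmax (subtracts the max logit first)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [x / total for x in exps]


if __name__ == "__main__":
    table = [[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]]  # 3 items, embedding dim 2
    h = [1.0, 1.0]
    print(next_item_logits(h, table))  # [1.0, 2.0, 7.0]
    print(next_item_logits(h, table, logits_temperature=2.0))  # [0.5, 1.0, 3.5]
```

Dividing by a temperature T > 1 shrinks the gaps between logits, so the resulting softmax is flatter (less confident). With l2_normalization=True every logit is a cosine similarity scaled by 1/T, so it lies in [-1/T, 1/T].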

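The sampled_softmax, num_sampled and min_sampled_id options can be sketched as follows. This sketch assumes uniform sampling with replacement over ids in [min_sampled_id, num_items), one set of negatives shared by the whole batch, and a loss that scores the true target only against those sampled negatives; Merlin Models may use a different sampler. The helpers `sample_negative_ids` and `sampled_softmax_loss` are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence

Vector = List[float]


def sample_negative_ids(
    num_items: int,
    num_sampled: int = 100,
    min_sampled_id: int = 0,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical helper: draw `num_sampled` ids uniformly from [min_sampled_id, num_items).

    Ids below `min_sampled_id` (e.g. reserved for nulls, OOV or padding)
    are never drawn. Sampling with replacement is an assumption.
    """
    if rng is None:
        rng = random.Random()
    return [rng.randrange(min_sampled_id, num_items) for _ in range(num_sampled)]


def sampled_softmax_loss(
    hidden: Sequence[float],
    item_embeddings: Sequence[Sequence[float]],
    target_id: int,
    negative_ids: Sequence[int],
) -> float:
    """Cross-entropy of the target over [target] + sampled negatives only,
    instead of over the whole catalog."""

    def score(item_id: int) -> float:
        return sum(h * e for h, e in zip(hidden, item_embeddings[item_id]))

    # Drop accidental hits so the positive is not also counted as a negative.
    negatives = [i for i in negative_ids if i != target_id]
    logits = [score(target_id)] + [score(i) for i in negatives]
    m = max(logits)  # log-sum-exp with max subtraction for stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[0]


if __name__ == "__main__":
    rng = random.Random(42)
    # Id 0 plays the role of a reserved padding id, hence min_sampled_id=1.
    table = [[0.0, 0.0]] + [[float(i), 1.0] for i in range(1, 10)]
    negatives = sample_negative_ids(len(table), num_sampled=5, min_sampled_id=1, rng=rng)
    print(negatives)  # every id lies in [1, 10)
    print(sampled_softmax_loss([0.1, 1.0], table, target_id=3, negative_ids=negatives))
```

Because only the target plus at most num_sampled candidates are scored per example, the cost no longer grows with the catalog size. The trade-off is a biased estimate of the full softmax, which is what corrections such as the popularity sampling correction mentioned under post_logits aim to reduce.
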
Returns

The next item prediction task

Return type

PredictionTask