merlin.models.tf.DLRMModel

merlin.models.tf.DLRMModel(schema: merlin.schema.schema.Schema, *, embeddings: Optional[merlin.models.tf.core.base.Block] = None, embedding_dim: Optional[int] = None, embedding_options: Optional[merlin.models.tf.inputs.embedding.EmbeddingOptions] = None, bottom_block: Optional[merlin.models.tf.core.base.Block] = None, top_block: Optional[merlin.models.tf.core.base.Block] = None, prediction_tasks: Optional[Union[merlin.models.tf.prediction_tasks.base.PredictionTask, List[merlin.models.tf.prediction_tasks.base.PredictionTask], merlin.models.tf.prediction_tasks.base.ParallelPredictionBlock, ModelOutput, merlin.models.tf.core.combinators.ParallelBlock]] = None) → merlin.models.tf.models.base.Model

DLRM (Deep Learning Recommendation Model) architecture [1].
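A pure-Python sketch of the feature interaction described in [1] may help make the architecture concrete. The bottom-block output is treated as one more vector next to the categorical embeddings, every distinct pair of vectors is reduced to a dot product, and those scalars are concatenated with the bottom-block output to form the input of the top block. This illustrates the technique in [1]; it is not Merlin's implementation. The helper names `dot` and `dlrm_interaction` are hypothetical, and leaving out self-interactions (a vector dotted with itself) is an assumption of this sketch.

```python
from itertools import combinations
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Inner product of two vectors; both must have the same dimension."""
    if len(a) != len(b):
        raise ValueError("dot-product interaction needs vectors of equal dimension")
    return sum(x * y for x, y in zip(a, b))


def dlrm_interaction(bottom_output: Vector, embeddings: List[Vector]) -> Vector:
    """Hypothetical sketch of the DLRM interaction from [1].

    The bottom-block output joins the categorical embeddings, each distinct
    pair is dotted together (self-interactions excluded, an assumption here),
    and the scalars are appended to the bottom-block output. The returned
    list is what a top block would consume.
    """
    vectors = [bottom_output] + embeddings
    pairwise = [dot(a, b) for a, b in combinations(vectors, 2)]
    return bottom_output + pairwise


# Toy input: 2-dim vectors, one bottom-block output, two categorical embeddings.
bottom = [1.0, 2.0]
embs = [[0.5, 0.0], [1.0, -1.0]]
top_input = dlrm_interaction(bottom, embs)
print(top_input)  # [1.0, 2.0, 0.5, -1.0, 0.5]
```

With n vectors of dimension d (the bottom-block output plus n - 1 embeddings), the top block receives d + n(n - 1)/2 values; in the toy call, d = 2 and n = 3 give 5.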

Example Usage::

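    from merlin.models.tf import DLRMModel, MLPBlock

    # assumes `schema` and `train_data` are already defined
    dlrm = DLRMModel(schema, embedding_dim=64, bottom_block=MLPBlock([256, 64]))
    dlrm.compile(optimizer="adam")
    dlrm.fit(train_data, epochs=10)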
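In the example, the bottom MLP ends with a 64-unit layer and embedding_dim is 64. Assuming the dot-product interaction of [1], these two widths have to agree, because the bottom-block output is dotted with each categorical embedding. The hypothetical helpers below are not part of Merlin Models (whose own validation may differ); they check this constraint before a model is built, describing the bottom block as the list of layer widths that would be passed to MLPBlock.

```python
from typing import Sequence


def bottom_output_dim(mlp_units: Sequence[int]) -> int:
    """Hypothetical helper: an MLP's output width is the size of its last layer."""
    if not mlp_units:
        raise ValueError("the MLP needs at least one layer")
    return mlp_units[-1]


def check_dlrm_dims(embedding_dim: int, bottom_mlp_units: Sequence[int]) -> None:
    """Hypothetical pre-flight check: the bottom block must output embedding_dim
    values so it can be dotted with every categorical embedding."""
    out_dim = bottom_output_dim(bottom_mlp_units)
    if out_dim != embedding_dim:
        raise ValueError(
            f"bottom block outputs {out_dim} values but embedding_dim is "
            f"{embedding_dim}; the dot-product interaction needs them to match"
        )


# The configuration from the example passes silently.
check_dlrm_dims(64, [256, 64])

# A mismatched bottom block is rejected.
try:
    check_dlrm_dims(64, [256, 128])
except ValueError as err:
    print(err)
```

The check only matters when the schema has continuous features, since the bottom block is the part of the model that processes them.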

References

[1] Naumov, Maxim, et al. “Deep learning recommendation model for personalization and recommendation systems.” arXiv preprint arXiv:1906.00091 (2019).

Parameters
  • schema (Schema) – The Schema with the input features

  • embeddings (Optional[Block]) – Optional block for categorical embeddings. Overrides the default embeddings inferred from the schema.

  • embedding_dim (Optional[int]) – Dimension of the categorical embeddings

  • embedding_options (Optional[EmbeddingOptions]) – Configuration for categorical embeddings. Alternatively use the embeddings parameter.

  • bottom_block (Optional[Block]) – The Block that combines the continuous features (typically an MLPBlock)

  • top_block (Optional[Block], optional) – The optional Block that combines the outputs of the bottom block and of the feature-interaction layer (pairwise dot products, in the style of a factorization machine), by default None

  • prediction_tasks (Optional[Union[PredictionTask, List[PredictionTask], ParallelPredictionBlock, ModelOutputType]]) – The prediction tasks to be used; by default, they are inferred from the schema. For custom prediction tasks, we recommend using OutputBlock and blocks based on ModelOutput rather than those based on PredictionTask, which will be deprecated.

Return type

Model