merlin.models.tf.MatrixFactorizationModel
merlin.models.tf.MatrixFactorizationModel(
    schema: merlin.schema.schema.Schema,
    dim: int,
    query_id_tag=Tags.USER_ID,
    item_id_tag=Tags.ITEM_ID,
    embeddings_initializers: Optional[Dict[str, Callable[[Any], None]]] = None,
    post: Optional[Union[merlin.models.tf.blocks.core.base.Block, str, Sequence[str]]] = None,
    prediction_tasks: Optional[Union[merlin.models.tf.prediction_tasks.base.PredictionTask, List[merlin.models.tf.prediction_tasks.base.PredictionTask], merlin.models.tf.prediction_tasks.base.ParallelPredictionBlock]] = None,
    logits_temperature: float = 1.0,
    loss: Optional[Union[str, keras.losses.Loss]] = 'bpr',
    metrics: Union[Sequence[Union[keras.metrics.base_metric.Metric, Type[keras.metrics.base_metric.Metric]]], keras.metrics.base_metric.Metric, Type[keras.metrics.base_metric.Metric]] = [RecallAt(), MRRAt(), NDCGAt(), AvgPrecisionAt(), PrecisionAt()],
    samplers: Sequence[merlin.models.tf.blocks.sampling.base.ItemSampler] = (),
    **kwargs,
) → Union[merlin.models.tf.models.base.Model, merlin.models.tf.models.base.RetrievalModel]

Builds a matrix factorization model.
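Matrix factorization represents every query (user) ID and every item ID by a learned vector of size `dim` and scores a (query, item) pair by the dot product of the two vectors. The plain-Python sketch below illustrates only that scoring idea; it is not Merlin's implementation, and the helpers `init_embeddings`, `mf_score` and `top_k_items` are hypothetical names chosen for this illustration.

```python
import random
from typing import List

Matrix = List[List[float]]


def init_embeddings(num_rows: int, dim: int, seed: int = 0) -> Matrix:
    """Hypothetical helper: one random `dim`-sized vector per ID (row)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 0.05) for _ in range(dim)] for _ in range(num_rows)]


def mf_score(user_emb: Matrix, item_emb: Matrix, user_id: int, item_id: int) -> float:
    """Score of a (user, item) pair: dot product of their embedding vectors."""
    return sum(u * v for u, v in zip(user_emb[user_id], item_emb[item_id]))


def top_k_items(user_emb: Matrix, item_emb: Matrix, user_id: int, k: int) -> List[int]:
    """Rank every item for one user by MF score and keep the k best item IDs."""
    scores = [mf_score(user_emb, item_emb, user_id, i) for i in range(len(item_emb))]
    return sorted(range(len(item_emb)), key=lambda i: scores[i], reverse=True)[:k]


if __name__ == "__main__":
    dim = 8  # plays the role of the `dim` argument (embedding size)
    users = init_embeddings(num_rows=3, dim=dim, seed=1)
    items = init_embeddings(num_rows=10, dim=dim, seed=2)
    print(top_k_items(users, items, user_id=0, k=3))
```

In a trained model the vectors are learned rather than random; ranking all items by this score is what makes a factorized model usable for retrieval.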
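Example Usage::

    import merlin.models.tf as mm
    # `schema` (a Schema) and `train_data` (a training dataset) are assumed to exist
    mf = mm.MatrixFactorizationModel(schema, dim=128)
    mf.compile(optimizer="adam")
    mf.fit(train_data, epochs=10)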
Parameters
schema (Schema) – The schema describing the input features.
dim (int) – The dimension of the query and item embeddings.
query_id_tag (Tag) – The tag used to select the query (user) ID feature from the schema. Defaults to Tags.USER_ID.
item_id_tag (Tag) – The tag used to select the item ID feature from the schema. Defaults to Tags.ITEM_ID.
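embeddings_initializers (Optional[Dict[str, Callable[[Any], None]]]) – An optional dictionary of initializers for the embedding tables, keyed by feature name. Defaults to None.
post (Optional[Union[Block, str, Sequence[str]]]) – An optional Block (or the name or names of registered blocks) applied to both outputs of the model, the query and item embeddings.
prediction_tasks (optional) – An optional PredictionTask, list of PredictionTask, or ParallelPredictionBlock to apply on top of the model.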
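logits_temperature (float) – Temperature T by which the logits are divided (logits / T) to reduce model overconfidence. Defaults to 1.0.
loss (Optional[Union[str, keras.losses.Loss]]) – Loss function, given either as a loss name or as a keras.losses.Loss instance. Defaults to "bpr" (Bayesian Personalized Ranking).
metrics (Union[Metric, Type[Metric], Sequence[Union[Metric, Type[Metric]]]]) – Metrics to compute for the model. Defaults to RecallAt, MRRAt, NDCGAt, AvgPrecisionAt and PrecisionAt.
samplers (Sequence[ItemSampler]) – Samplers used for negative sampling. Defaults to an empty sequence, in which case in-batch negative sampling ([InBatchSampler()]) is used.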
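The `samplers`, `logits_temperature` and `loss` parameters can be pictured together in a minimal plain-Python sketch of BPR with in-batch negatives: each query in a batch is paired with its own item as the positive, every other item in the batch serves as a negative, the logits are divided by the temperature T, and each positive/negative pair contributes −log σ(s_pos − s_neg). This assumes the temperature is applied before the loss and that the BPR terms are averaged over all in-batch pairs; Merlin's actual loss and sampler code may differ, and `dot`, `softplus` and `in_batch_bpr_loss` are hypothetical helpers.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def in_batch_bpr_loss(query_emb: Matrix, item_emb: Matrix, temperature: float = 1.0) -> float:
    """Hypothetical BPR loss where the other items in the batch act as negatives.

    Row b of `query_emb` is paired with row b of `item_emb` (its positive item).
    """
    if len(query_emb) != len(item_emb) or len(query_emb) < 2:
        raise ValueError("need matching batches of at least two (query, item) pairs")
    n = len(query_emb)
    # logits[b][j]: temperature-scaled score of query b against item j (logits / T)
    logits = [[dot(q, it) / temperature for it in item_emb] for q in query_emb]
    total = 0.0
    for b in range(n):
        for j in range(n):
            if j != b:
                # BPR term -log(sigmoid(pos - neg)) equals softplus(neg - pos)
                total += softplus(logits[b][j] - logits[b][b])
    return total / (n * (n - 1))  # mean over all (positive, negative) pairs
```

With identical embeddings every pair scores the same and the loss equals log 2; when positives already outscore negatives, lowering `temperature` below 1 widens the score gaps and lowers the loss.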
Returns
The matrix factorization model.
Return type
Union[Model, RetrievalModel]