# Copyright 2023 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

Getting Started with Merlin Models: Develop a Model for MovieLens

This notebook was created using the latest stable merlin-tensorflow container.

Overview

Merlin Models is a library for training recommender models. It lets data scientists and ML engineers easily train standard RecSys models on their own datasets, producing GPU-accelerated models with best practices baked into the library. It also lets researchers build custom models by incorporating standard components of deep learning recommender models, and then benchmark their new models on example offline datasets. Merlin Models is part of the Merlin open source framework.

Core features are:

  • Many different recommender system architectures (tabular, two-tower, sequential) or tasks (binary, multi-class classification, multi-task)

  • Flexible APIs targeted to both production and research

  • Deep integration with the NVIDIA Merlin platform, including NVTabular for ETL and Merlin Systems for model serving

Learning objectives

Downloading and preparing the dataset

import os
import merlin.models.tf as mm

from merlin.datasets.entertainment import get_movielens
2023-01-10 12:07:22.054533: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-01-10 12:07:24.185109: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-10 12:07:26.250320: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 16249 MB memory:  -> device: 0, name: Quadro GV100, pci bus id: 0000:15:00.0, compute capability: 7.0
2023-01-10 12:07:26.251401: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 30625 MB memory:  -> device: 1, name: Quadro GV100, pci bus id: 0000:2d:00.0, compute capability: 7.0

We provide the get_movielens() function as a convenience to download the dataset, perform simple preprocessing, and split the data into training and validation datasets.

input_path = os.environ.get("INPUT_DATA_DIR", os.path.expanduser("~/merlin-models-data/movielens/"))
train, valid = get_movielens(variant="ml-1m", path=input_path)
/home/gmoreira/projects/nvidia/nvidia_merlin/core/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.USER: 'user'>, <Tags.ID: 'id'>].
  warnings.warn(
/home/gmoreira/projects/nvidia/nvidia_merlin/core/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
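To give a feel for the kind of "simple preprocessing" and splitting that a helper such as get_movielens() performs, here is a minimal plain-Python sketch. It is not the library's implementation: the threshold of 3 for deriving a binary target from a star rating, the 80/20 split ratio, and the function names binarize_rating and split_rows are all assumptions made for illustration.

```python
import random


def binarize_rating(rating, threshold=3):
    """Map a star rating to a binary label (assumed rule: rating > threshold)."""
    return 1 if rating > threshold else 0


def split_rows(rows, valid_fraction=0.2, seed=42):
    """Shuffle rows reproducibly and split them into train and validation lists."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    n_valid = int(len(shuffled) * valid_fraction)
    return shuffled[n_valid:], shuffled[:n_valid]


# Hypothetical toy interactions: (user_id, movie_id, rating)
toy_rows = [(u, m, r) for u, m, r in zip(range(10), range(100, 110),
                                         [5, 3, 4, 1, 2, 5, 4, 3, 2, 5])]

# Add a rating_binary-style column next to the original rating.
labeled_rows = [(u, m, r, binarize_rating(r)) for u, m, r in toy_rows]
toy_train, toy_valid = split_rows(labeled_rows)
```

With ten toy rows and a validation fraction of 0.2, the sketch holds out two rows for validation and keeps eight for training.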

Training the DLRM Model with Merlin Models

We define the DLRM model, whose prediction task is binary classification. Because the schema tags each column, the categorical features are identified (and embedded) and the target columns are inferred automatically. We talk more about the schema in the next example notebook (02).

# Ignoring the rating regression target column, to keep only the rating_binary target column for prediction
schema = train.schema.without(['rating'])
model = mm.DLRMModel(
    schema,
    embedding_dim=64,
    bottom_block=mm.MLPBlock([128, 64]),
    top_block=mm.MLPBlock([128, 64, 32]),
    prediction_tasks=mm.OutputBlock(schema),
)

model.compile(optimizer="adam")
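For intuition about what a DLRM-style model computes, the sketch below shows the characteristic interaction step in plain Python: every pair of equally sized feature vectors is combined with a dot product, and the resulting interaction values are passed to a final scoring layer. This is a conceptual illustration only, not how mm.DLRMModel is implemented. The toy embeddings, the weights, and the single linear layer standing in for the top block are all hypothetical.

```python
import math
from itertools import combinations


def dot(u, v):
    """Dot product of two equally sized vectors."""
    return sum(a * b for a, b in zip(u, v))


def pairwise_interactions(vectors):
    """Dot product of every unordered pair of vectors, in a fixed order."""
    return [dot(u, v) for u, v in combinations(vectors, 2)]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# Hypothetical toy embeddings (dimension 3) for a user, a movie and a genre.
user_emb = [0.5, -1.0, 0.25]
movie_emb = [1.0, 0.5, -2.0]
genre_emb = [0.0, 2.0, 1.0]

interactions = pairwise_interactions([user_emb, movie_emb, genre_emb])

# Stand-in for the top block: one linear layer with made-up weights,
# followed by a sigmoid to produce a probability-like score.
top_weights = [0.4, -0.2, 0.1]
score = sigmoid(dot(top_weights, interactions))
```

In the model defined above, embedding_dim=64 plays the role of the toy dimension 3, and the top block is an MLP rather than a single linear layer.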

Next, we train the model.

model.fit(train, batch_size=1024)
782/782 [==============================] - 19s 20ms/step - loss: 0.5407 - precision: 0.7333 - recall: 0.8212 - binary_accuracy: 0.7252 - auc: 0.7903 - regularization_loss: 0.0000e+00 - loss_batch: 0.5407
<keras.callbacks.History at 0x7f2772a828e0>
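The progress bar above reports 782 steps for one pass over the training data at batch_size=1024. As a rough sanity check, and assuming the last, partially filled batch is kept rather than dropped, the step count bounds the number of training rows. The helper name row_count_bounds is made up for this sketch.

```python
import math


def row_count_bounds(steps, batch_size):
    """Smallest and largest row counts that produce exactly `steps` batches,
    assuming a final partial batch still counts as one step."""
    return (steps - 1) * batch_size + 1, steps * batch_size


train_low, train_high = row_count_bounds(782, 1024)

# Both ends of the range round up to the same number of steps.
steps_at_low = math.ceil(train_low / 1024)
steps_at_high = math.ceil(train_high / 1024)
```

Under that assumption the training set holds between 799,745 and 800,768 rows.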

We evaluate the model…

metrics = model.evaluate(valid, batch_size=1024, return_dict=True)
/home/gmoreira/projects/nvidia/nvidia_merlin/core/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.USER: 'user'>, <Tags.ID: 'id'>].
  warnings.warn(
/home/gmoreira/projects/nvidia/nvidia_merlin/core/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
196/196 [==============================] - 3s 13ms/step - loss: 0.5254 - precision: 0.7435 - recall: 0.8228 - binary_accuracy: 0.7356 - auc: 0.8056 - regularization_loss: 0.0000e+00 - loss_batch: 0.5255

… and check the evaluation metrics. Because the rating column was removed from the schema above, rating_binary is the only column tagged as a target, so the model has a single binary classification head. If both rating_binary and rating were kept as targets, the model would instead have two heads (multi-task learning), one for binary classification and the other for regression.
The output below shows the default metrics for binary classification: Precision, Recall, Accuracy and AUC. Regression tasks get RMSE by default. You can also provide your own metrics in model.compile().

metrics
{'loss': 0.5254488587379456,
 'precision': 0.7434586882591248,
 'recall': 0.8228089809417725,
 'binary_accuracy': 0.7355642318725586,
 'auc': 0.8055890202522278,
 'regularization_loss': 0.0,
 'loss_batch': 0.5332178473472595}
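To make the classification metrics above concrete, here is a minimal plain-Python sketch of how precision, recall and binary accuracy follow from predicted scores and true labels. It assumes a decision threshold of 0.5 with a strict "greater than" comparison. The function name binary_metrics and the toy data are hypothetical, and this is not the code that Keras or Merlin Models runs.

```python
def binary_metrics(labels, scores, threshold=0.5):
    """Compute precision, recall and accuracy for binary labels (0 or 1)."""
    tp = fp = tn = fn = 0
    for label, score in zip(labels, scores):
        predicted = 1 if score > threshold else 0
        if predicted == 1 and label == 1:
            tp += 1
        elif predicted == 1 and label == 0:
            fp += 1
        elif predicted == 0 and label == 0:
            tn += 1
        else:
            fn += 1
    return {
        # Of everything predicted positive, how much was truly positive.
        "precision": tp / (tp + fp) if (tp + fp) else 0.0,
        # Of everything truly positive, how much was found.
        "recall": tp / (tp + fn) if (tp + fn) else 0.0,
        # Share of all predictions that were correct.
        "binary_accuracy": (tp + tn) / len(labels) if labels else 0.0,
    }


# Hypothetical toy labels and model scores.
toy_labels = [1, 1, 1, 0, 0, 1, 0, 1]
toy_scores = [0.9, 0.8, 0.4, 0.7, 0.2, 0.6, 0.1, 0.3]
toy_metrics = binary_metrics(toy_labels, toy_scores)
```

AUC is left out of the sketch because it summarizes performance across all thresholds rather than at a single one.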

Conclusion

Merlin Models enables users to define and train a deep learning recommender model with only three commands.

model = mm.DLRMModel(
    schema,
    embedding_dim=64,
    bottom_block=mm.MLPBlock([128, 64]),
    top_block=mm.MLPBlock([128, 64, 32]),
    prediction_tasks=mm.OutputBlock(schema),
)
model.compile(optimizer="adam")
model.fit(train, batch_size=1024)

Next steps

In the next example notebooks, we will show the integration with NVTabular and explore different recommender models.