# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

Getting Started with Merlin Models: Develop a Model for MovieLens

This notebook was created using the latest stable merlin-tensorflow container.

Overview

Merlin Models is a library for training recommender models. It lets data scientists and ML engineers easily train standard RecSys models on their own datasets, producing GPU-accelerated models with best practices built into the library. It also lets researchers build custom models from the standard components of deep learning recommender models and then benchmark their new models on example offline datasets. Merlin Models is part of the Merlin open source framework.

Core features are:

  • Many different recommender system architectures (tabular, two-tower, sequential) and tasks (binary classification, multi-class classification, multi-task)

  • Flexible APIs targeted to both production and research

  • Deep integration with the NVIDIA Merlin platform, including NVTabular for ETL and Merlin Systems for model serving

Learning objectives

Downloading and preparing the dataset

import os
import merlin.models.tf as mm

from merlin.datasets.entertainment import get_movielens

We provide the get_movielens() function as a convenience to download the dataset, perform simple preprocessing, and split the data into training and validation datasets.
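The kind of preprocessing and splitting that get_movielens() performs can be pictured with a small pure-Python sketch. Everything here is an assumption for illustration: the toy ratings, the hypothetical helpers `binarize_rating` and `split_train_valid`, the 3-star threshold used to derive a binary label such as `rating_binary`, and the random 80/20 split. The library's actual preprocessing may differ.

```python
import random


def binarize_rating(rating, threshold=3):
    """Map a 1-5 star rating to a binary label (1 = liked).

    The threshold is an assumption for illustration only.
    """
    return 1 if rating > threshold else 0


def split_train_valid(rows, valid_fraction=0.2, seed=42):
    """Hypothetical helper: shuffle rows and split them into (train, valid)."""
    rng = random.Random(seed)
    shuffled = list(rows)
    rng.shuffle(shuffled)
    n_valid = int(len(shuffled) * valid_fraction)
    return shuffled[n_valid:], shuffled[:n_valid]


# Toy (userId, movieId, rating) triples standing in for MovieLens ratings.
ratings = [(1, 10, 5), (1, 20, 3), (2, 10, 4), (2, 30, 1), (3, 20, 2),
           (3, 30, 5), (4, 10, 3), (4, 40, 4), (5, 20, 5), (5, 40, 2)]

# Append the derived binary label to each row.
rows = [(u, m, r, binarize_rating(r)) for u, m, r in ratings]
train_rows, valid_rows = split_train_valid(rows)
print(len(train_rows), len(valid_rows))  # 8 2
```

A fixed seed keeps the split reproducible between runs, which matters when comparing models on the same validation set.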

input_path = os.environ.get("INPUT_DATA_DIR", os.path.expanduser("~/merlin-models-data/movielens/"))
train, valid = get_movielens(variant="ml-1m", path=input_path)

Training the DLRM Model with Merlin Models

We define the DLRM model, whose prediction task is binary classification. Thanks to the schema tags, the categorical features are identified (and embedded) and the target column is inferred automatically. We discuss the schema in more detail in the next example notebook (02).
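Tag-based feature selection can be illustrated with a toy schema. This is a minimal sketch assuming a plain dictionary that maps column names to sets of tag strings; `toy_schema` and `select_by_tag` here are hypothetical stand-ins, not Merlin's classes. The real `merlin.schema.Schema` has its own `select_by_tag` method, used with `Tags.TARGET` in the Conclusion's code.

```python
# Hypothetical toy schema: column name -> set of tags. In Merlin, this
# information lives in a merlin.schema.Schema; plain dicts are used here.
toy_schema = {
    "userId": {"categorical", "user_id"},
    "movieId": {"categorical", "item_id"},
    "rating_binary": {"target", "binary_classification"},
}


def select_by_tag(schema, tag):
    """Return the names of the columns carrying `tag`, in schema order."""
    return [name for name, tags in schema.items() if tag in tags]


# Columns to embed, and the column the model should predict.
categorical_features = select_by_tag(toy_schema, "categorical")
target_columns = select_by_tag(toy_schema, "target")
print(categorical_features)  # ['userId', 'movieId']
print(target_columns)        # ['rating_binary']
```

Because the model reads roles from tags rather than from hard-coded names, the same model definition can be reused on a dataset with different column names as long as its schema is tagged.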

model = mm.DLRMModel(
    train.schema,
    embedding_dim=64,
    bottom_block=mm.MLPBlock([128, 64]),
    top_block=mm.MLPBlock([128, 64, 32]),
    prediction_tasks=mm.BinaryClassificationTask(train.schema),
)

model.compile(optimizer="adam")
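The `bottom_block`, `embedding_dim`, and `top_block` arguments map onto the DLRM architecture: a bottom MLP over continuous features, embeddings for categorical features, pairwise dot-product interactions between those vectors, and a top MLP that produces the prediction. The NumPy sketch below walks through one forward pass with random, untrained weights; the feature counts are hypothetical, embedding lookups are replaced by random vectors, and it is not Merlin's implementation. It does show why `bottom_block` ends in 64 units: the dot-product interaction needs the bottom MLP output to have the same width as the embeddings (`embedding_dim=64`).

```python
import numpy as np

rng = np.random.default_rng(0)


def mlp(x, sizes, rng):
    """ReLU MLP with random placeholder weights (not trained)."""
    for size in sizes:
        w = rng.normal(scale=0.1, size=(x.shape[1], size))
        x = np.maximum(x @ w, 0.0)
    return x


def dot_interaction(vectors):
    """Pairwise dot products between feature vectors, upper triangle only.

    vectors: array of shape (batch, num_features, dim)
    returns: array of shape (batch, num_features * (num_features - 1) // 2)
    """
    gram = vectors @ vectors.transpose(0, 2, 1)  # (batch, F, F)
    rows, cols = np.triu_indices(vectors.shape[1], k=1)
    return gram[:, rows, cols]


batch, embedding_dim = 4, 64
num_continuous = 3   # hypothetical number of continuous features
num_categorical = 2  # e.g. a user id and an item id

continuous = rng.normal(size=(batch, num_continuous))
# Stand-in for embedding lookups: random vectors of width embedding_dim.
embeddings = rng.normal(size=(batch, num_categorical, embedding_dim))

# Bottom MLP [128, 64]: its last layer must equal embedding_dim.
bottom_out = mlp(continuous, [128, 64], rng)                             # (4, 64)
stacked = np.concatenate([bottom_out[:, None, :], embeddings], axis=1)  # (4, 3, 64)
interactions = dot_interaction(stacked)                                  # (4, 3)

# Top MLP [128, 64, 32] over [bottom output, interactions], then a sigmoid head.
top_in = np.concatenate([bottom_out, interactions], axis=1)  # (4, 67)
top_out = mlp(top_in, [128, 64, 32], rng)                    # (4, 32)
logits = top_out @ rng.normal(scale=0.1, size=(32, 1))
probabilities = 1.0 / (1.0 + np.exp(-logits))                # (4, 1)
print(probabilities.shape)  # (4, 1)
```

Keeping only the upper triangle of the Gram matrix avoids counting each feature pair twice and drops the self-interactions on the diagonal.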

Next, we train the model.

model.fit(train, batch_size=1024)
782/782 [==============================] - 12s 9ms/step - rating_binary/binary_classification_task/precision: 0.7294 - rating_binary/binary_classification_task/recall: 0.8198 - rating_binary/binary_classification_task/binary_accuracy: 0.7214 - rating_binary/binary_classification_task/auc: 0.7843 - loss: 0.5477 - regularization_loss: 0.0000e+00 - total_loss: 0.5477
<keras.callbacks.History at 0x7f1c2d278b80>
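The progress bar reports 782 steps and a loss value, and both can be reproduced with simple arithmetic under stated assumptions. MovieLens-1M contains 1,000,209 ratings. Assuming get_movielens() keeps all of them and splits them roughly 80/20, and that the final partial batch is kept, the number of steps is the ceiling of rows divided by batch size. The same arithmetic gives the 196 validation steps reported during evaluation. The loss is sketched as mean binary cross-entropy, assuming that is the binary task's default loss. The helpers `steps_per_epoch` and `binary_cross_entropy` are hypothetical.

```python
import math


def steps_per_epoch(num_rows, batch_size):
    """Batches per epoch when the last, partial batch is kept."""
    return math.ceil(num_rows / batch_size)


def binary_cross_entropy(labels, probabilities, eps=1e-7):
    """Mean binary cross-entropy; eps clips probabilities away from 0 and 1."""
    total = 0.0
    for y, p in zip(labels, probabilities):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(labels)


# MovieLens-1M has 1,000,209 ratings; the 80/20 split is an assumption.
total_ratings = 1_000_209
train_rows = int(total_ratings * 0.8)    # 800,167
valid_rows = total_ratings - train_rows  # 200,042
print(steps_per_epoch(train_rows, 1024))  # 782
print(steps_per_epoch(valid_rows, 1024))  # 196

# An uninformative prediction of 0.5 costs ln(2) per example.
print(round(binary_cross_entropy([1, 0], [0.5, 0.5]), 4))  # 0.6931
```

The ln(2) ≈ 0.693 baseline is a useful reference point: a training loss of about 0.55 means the model is doing noticeably better than always predicting 0.5.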

We evaluate the model…

metrics = model.evaluate(valid, batch_size=1024, return_dict=True)
196/196 [==============================] - 2s 5ms/step - rating_binary/binary_classification_task/precision: 0.7409 - rating_binary/binary_classification_task/recall: 0.8057 - rating_binary/binary_classification_task/binary_accuracy: 0.7264 - rating_binary/binary_classification_task/auc: 0.7936 - loss: 0.5384 - regularization_loss: 0.0000e+00 - total_loss: 0.5384

… and check the evaluation metrics. By default, we use typical binary classification metrics: precision, recall, accuracy, and AUC. You can also provide your own list of metrics by setting BinaryClassificationTask(..., metrics=[]).

metrics
{'rating_binary/binary_classification_task/precision': 0.7408556938171387,
 'rating_binary/binary_classification_task/recall': 0.8057063221931458,
 'rating_binary/binary_classification_task/binary_accuracy': 0.7263810634613037,
 'rating_binary/binary_classification_task/auc': 0.7936474680900574,
 'loss': 0.5632060170173645,
 'regularization_loss': 0.0,
 'total_loss': 0.5632060170173645}
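The default metrics can be computed by hand to make their definitions concrete. This pure-Python sketch uses made-up labels and probabilities, a 0.5 decision threshold for precision, recall, and accuracy, and an exact pairwise ROC AUC. Keras' `AUC` metric approximates the curve with a fixed set of thresholds, so its value can differ slightly from this exact computation. The function names are hypothetical.

```python
def precision_recall_accuracy(labels, probabilities, threshold=0.5):
    """Threshold the probabilities, then derive metrics from the confusion counts."""
    tp = fp = tn = fn = 0
    for y, p in zip(labels, probabilities):
        predicted = 1 if p > threshold else 0
        if predicted == 1 and y == 1:
            tp += 1
        elif predicted == 1:
            fp += 1
        elif y == 1:
            fn += 1
        else:
            tn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    accuracy = (tp + tn) / len(labels)
    return precision, recall, accuracy


def roc_auc(labels, probabilities):
    """Exact ROC AUC: share of (positive, negative) pairs ranked correctly."""
    positives = [p for y, p in zip(labels, probabilities) if y == 1]
    negatives = [p for y, p in zip(labels, probabilities) if y == 0]
    correct = 0.0
    for pos in positives:
        for neg in negatives:
            if pos > neg:
                correct += 1.0
            elif pos == neg:
                correct += 0.5  # ties count as half a correct ranking
    return correct / (len(positives) * len(negatives))


labels = [1, 1, 1, 0, 0, 0]
probs = [0.9, 0.7, 0.4, 0.6, 0.2, 0.1]
precision, recall, accuracy = precision_recall_accuracy(labels, probs)
print(round(precision, 4), round(recall, 4), round(accuracy, 4))  # 0.6667 0.6667 0.6667
print(round(roc_auc(labels, probs), 4))  # 0.8889
```

Precision, recall, and accuracy depend on the chosen threshold, while AUC only depends on how positives are ranked relative to negatives, which is why AUC is often the headline metric for click or rating prediction.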

Conclusion

Merlin Models enables users to define and train a deep learning recommender model with only three commands.

import merlin.models.tf as mm
from merlin.schema import Tags

model = mm.DLRMModel(
    train.schema,
    embedding_dim=64,
    bottom_block=mm.MLPBlock([128, 64]),
    top_block=mm.MLPBlock([128, 64, 32]),
    prediction_tasks=mm.BinaryClassificationTask(
        # Look up the target column by its schema tag instead of hard-coding its name
        train.schema.select_by_tag(Tags.TARGET).column_names[0]
    ),
)
model.compile(optimizer="adam")
model.fit(train, batch_size=1024)

Next steps

In the next example notebooks, we show how to integrate with NVTabular and how to explore different recommender models.