# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ================================

Taking the Next Step with Merlin Models: Define Your Own Architecture

This notebook was created using the latest stable merlin-tensorflow container.

In the Iterating over Deep Learning Models using Merlin Models notebook, we benchmarked standard and deep learning-based ranking models provided by the high-level Merlin Models API. The library also includes the standard deep learning components that let recsys practitioners and researchers define custom models, train them, and export them for inference.

In this example, we combine pre-existing blocks and demonstrate how to create the DLRM architecture.

Learning objectives

  • Understand the building blocks of Merlin Models

  • Define a model architecture from scratch

Introduction to Merlin Models core building blocks

The Block is the core abstraction in Merlin Models and is the class from which all blocks inherit. The class extends the tf.keras.layers.Layer base class and implements a number of properties that simplify the creation of custom blocks and models. These properties include the Schema object, which is used to determine embedding dimensions, input shapes, and output shapes. Additionally, the Block has a ModelContext instance that stores and retrieves public variables and shares them, as additional metadata, with other blocks in the same model.

Before deep-diving into the definition of the DLRM architecture, let’s start by listing the core components you need to know to define a model from scratch:

Feature Blocks

These are input blocks that process raw inputs according to their types and shapes. Merlin Models supports three main feature blocks:

  • EmbeddingFeatures: Input block for embedding lookups of categorical features.

  • SequenceEmbeddingFeatures: Input block for embedding lookups of sequential categorical features; it returns 3D tensors (batch, sequence length, embedding dimension).

  • ContinuousFeatures: Input block for continuous features.
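To make the three feature blocks concrete, here is a plain-Python sketch of what each one computes for a batch. It is not Merlin code. The helpers make_table, embedding_features, sequence_embedding_features, and continuous_features are hypothetical, the randomly initialised lists stand in for trainable embedding weights, and the embedding width is 4 instead of 64 to keep the printout short.

```python
import random

EMBEDDING_DIM = 4  # the notebook's blocks use 64; 4 keeps the printout short


def make_table(cardinality, dim=EMBEDDING_DIM, seed=0):
    """Randomly initialised embedding table: one row of `dim` floats per category id."""
    rng = random.Random(seed)
    return [[rng.uniform(-0.05, 0.05) for _ in range(dim)] for _ in range(cardinality)]


def embedding_features(batch, tables):
    """One embedding row per id -> (batch, dim) per feature (the EmbeddingFeatures idea)."""
    return {name: [tables[name][i] for i in batch[name]] for name in tables}


def sequence_embedding_features(batch, tables):
    """One row per id in each list -> (batch, seq_len, dim) (the SequenceEmbeddingFeatures idea)."""
    return {name: [[tables[name][i] for i in seq] for seq in batch[name]] for name in tables}


def continuous_features(batch, names):
    """Continuous values passed through as (batch, 1) columns (the ContinuousFeatures idea)."""
    return {name: [[float(v)] for v in batch[name]] for name in names}


# Hypothetical two-row batch; every id must be smaller than its table's cardinality.
batch = {"userId": [3, 1], "genres": [[2, 5], [3, 7]], "TE_zipcode_rating": [1.09, 0.39]}
cat = embedding_features(batch, {"userId": make_table(10)})
seq = sequence_embedding_features(batch, {"genres": make_table(20, seed=1)})
cont = continuous_features(batch, ["TE_zipcode_rating"])
print(len(cat["userId"]), len(cat["userId"][0]))                              # 2 4
print(len(seq["genres"]), len(seq["genres"][0]), len(seq["genres"][0][0]))  # 2 2 4
print(cont)  # {'TE_zipcode_rating': [[1.09], [0.39]]}
```

The only difference between the first two blocks is whether each row holds one id or a list of ids. That extra list axis is what makes the sequential output 3D.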

Transformation Blocks

They include operators commonly used to transform tensors in different parts of the model, such as:

  • ToDense: It takes a dictionary of raw input tensors and transforms the sparse tensors into dense tensors.

  • L2Norm: It takes a single tensor or a dictionary of hidden tensors and applies L2 normalization along a given axis.

  • LogitsTemperatureScaler: It scales the output tensor of predicted logits to lower the model’s confidence.
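Two of these transformations reduce to short array operations. The following plain-Python sketch shows them: padding variable-length lists into a rectangular batch (the idea behind a sparse/ragged-to-dense conversion) and row-wise L2 normalization. The helpers to_dense and l2_norm are hypothetical. The padding value (0) and right-padding are assumptions of this sketch, not necessarily what Merlin's ToDense does.

```python
import math


def to_dense(ragged, pad_value=0):
    """Right-pad variable-length rows to the longest row so the batch becomes rectangular."""
    max_len = max(len(row) for row in ragged)
    return [row + [pad_value] * (max_len - len(row)) for row in ragged]


def l2_norm(rows, eps=1e-12):
    """Divide each row by its L2 norm (normalization along the last axis)."""
    out = []
    for row in rows:
        norm = math.sqrt(sum(x * x for x in row))
        out.append([x / max(norm, eps) for x in row])  # eps guards against all-zero rows
    return out


# The genres lists of the five validation rows displayed later in this notebook.
genres = [[2], [2], [3, 7, 1], [7, 10, 2], [1]]
print(to_dense(genres))       # [[2, 0, 0], [2, 0, 0], [3, 7, 1], [7, 10, 2], [1, 0, 0]]
print(l2_norm([[3.0, 4.0]]))  # [[0.6, 0.8]]
```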

Aggregation Blocks

They include common aggregation operations to combine multiple tensors, such as:

  • ConcatFeatures: Concatenate a dictionary of tensors along a given dimension.

  • StackFeatures: Stack a dictionary of tensors along a given dimension.

  • CosineSimilarity: Calculate the cosine similarity between two tensors.
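The three aggregations can be sketched in plain Python as follows. Concatenation joins feature vectors end to end, so their widths may differ. Stacking adds a new feature axis, so all widths must match. Cosine similarity compares two vectors. The helper names are hypothetical. Sorting the dictionary keys to fix the output order is an assumption of this sketch, not a documented Merlin guarantee.

```python
import math


def concat_features(tensors):
    """Concatenate a dict of (batch, dim_i) tensors on the last axis -> (batch, sum(dim_i))."""
    names = sorted(tensors)  # deterministic key order (assumption of this sketch)
    batch_size = len(tensors[names[0]])
    return [[x for n in names for x in tensors[n][b]] for b in range(batch_size)]


def stack_features(tensors):
    """Stack a dict of (batch, dim) tensors on a new axis -> (batch, num_features, dim)."""
    names = sorted(tensors)
    batch_size = len(tensors[names[0]])
    return [[tensors[n][b] for n in names] for b in range(batch_size)]


def cosine_similarity(u, v):
    """dot(u, v) / (|u| * |v|): 1 for parallel vectors, 0 for orthogonal ones."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


tensors = {"a": [[1.0, 0.0]], "b": [[0.0, 2.0]]}  # batch of one example, two features
print(concat_features(tensors))                   # [[1.0, 0.0, 0.0, 2.0]]
print(stack_features(tensors))                    # [[[1.0, 0.0], [0.0, 2.0]]]
print(cosine_similarity([1.0, 0.0], [0.0, 2.0]))  # 0.0
```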

Connect Methods

The base Block class implements several connect methods that control how a given block is linked to other blocks:

  • connect: Connect the block to other blocks sequentially. The output is a tensor returned by the last block.

  • connect_branch: Link the block to other blocks in parallel. The output is a dictionary containing the output tensor of each block.

  • connect_with_shortcut: Connect the block to other blocks sequentially and apply a skip connection with the block’s output.

  • connect_with_residual: Connect the block to other blocks sequentially and apply a residual sum with the block’s output.
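The data flow of the four connect methods can be mimicked with plain functions over Python lists. In Merlin these are methods called on a block. Here they are hypothetical free functions where `first` plays the role of that block, and the toy blocks double and inc stand in for real layers. Placing the shortcut before the block output in the concatenation is a choice made for this sketch.

```python
def connect(*blocks):
    """Sequential: each block consumes the previous output; the last output is returned."""
    def run(x):
        for block in blocks:
            x = block(x)
        return x
    return run


def connect_branch(branches):
    """Parallel: every branch sees the same input; outputs are collected in a dict."""
    return lambda x: {name: block(x) for name, block in branches.items()}


def connect_with_shortcut(first, block):
    """Run `first`, feed its output to `block`, then concatenate both outputs (skip connection)."""
    def run(x):
        h = first(x)
        return h + block(h)  # list concatenation, shortcut first in this sketch
    return run


def connect_with_residual(first, block):
    """Run `first`, feed its output to `block`, then add both outputs element-wise."""
    def run(x):
        h = first(x)
        out = block(h)
        if len(out) != len(h):
            raise ValueError("a residual sum needs matching widths")
        return [a + b for a, b in zip(h, out)]
    return run


def double(v):  # toy block
    return [2 * x for x in v]


def inc(v):  # toy block
    return [x + 1 for x in v]


x = [1, 2]
print(connect(double, inc)(x))                            # [3, 5]
print(connect_branch({"double": double, "inc": inc})(x))  # {'double': [2, 4], 'inc': [2, 3]}
print(connect_with_shortcut(double, inc)(x))              # [2, 4, 3, 5]
print(connect_with_residual(double, inc)(x))              # [5, 9]
```

The shortcut variant grows the output width, while the residual variant keeps it fixed. For that reason only the residual variant requires matching widths.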

Prediction Tasks

Merlin Models introduces the PredictionTask layer that defines the necessary blocks and transformation operations to compute the final prediction scores. It also provides the default loss and metrics related to the given prediction task.
Merlin Models supports the core tasks BinaryClassificationTask, MultiClassClassificationTask, and RegressionTask. In addition to these tasks, Merlin Models provides tasks that are specific to recommender systems: NextItemPredictionTask and ItemRetrievalTask.
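For a binary classification task, the model's final output is one logit per example. A sigmoid turns that logit into a probability, and binary cross-entropy is the standard loss for this setting. The following plain-Python sketch computes both, using the numerically stable logits form of the loss. It illustrates the math only and is not Merlin's implementation, which operates on TensorFlow tensors.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce_from_logits(labels, logits):
    """Mean binary cross-entropy computed directly from logits.

    Per example: max(z, 0) - z * y + log(1 + exp(-|z|)), which equals
    -[y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z))] without overflow.
    """
    losses = [max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z))) for y, z in zip(labels, logits)]
    return sum(losses) / len(losses)


labels = [1, 0, 1]
logits = [2.0, -1.0, 0.0]
print([round(sigmoid(z), 3) for z in logits])  # predicted probabilities
print(bce_from_logits(labels, logits))         # mean loss over the batch
```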

Implement the DLRM model with MovieLens-1M data

Now that we have introduced the core blocks of Merlin Models, let’s take a look at how we can combine them to define the DLRM architecture:

import tensorflow as tf
import merlin.models.tf as mm

from merlin.datasets.entertainment import get_movielens
from merlin.schema.tags import Tags
2022-09-14 20:23:06.071868: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-09-14 20:23:07.207541: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 16249 MB memory:  -> device: 0, name: Quadro GV100, pci bus id: 0000:2d:00.0, compute capability: 7.0

We use the get_movielens function to download, extract, and preprocess the MovieLens 1M dataset:

train, valid = get_movielens(variant="ml-1m")
/usr/local/lib/python3.8/dist-packages/cudf/core/frame.py:384: UserWarning: The deep parameter is ignored and is only included for pandas compatibility.
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.USER: 'user'>, <Tags.ID: 'id'>].
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(

We display the first five rows of the validation data and use them to check the outputs of each building block:

valid.head()
   userId  movieId  title      genres  gender  age  occupation  zipcode  TE_age_rating  TE_gender_rating  TE_occupation_rating  TE_zipcode_rating  TE_movieId_rating  TE_userId_rating  rating_binary  rating
0    1332      724    726         [2]       1    5           7       76       1.992826         -0.580296              1.166594           1.093124           0.526717          1.030316              1     5.0
1    4320      102    102         [2]       1    5           3     1424       1.985156         -0.599401              0.337638           0.394303           0.875647          1.060459              1     4.0
2    1351      581    583   [3, 7, 1]       1    5          16     1107       1.949597         -0.534489              0.687319          -0.252332          -0.630569         -0.574388              0     2.0
3     700     1070   1074  [7, 10, 2]       1    3           1      316      -1.045463         -0.576565             -0.751345           0.135090          -1.847183          0.218730              0     3.0
4    2366      130    130         [1]       1    4           4     1848       0.774309         -0.599401             -0.098956           0.586722           0.945900          0.524332              1     5.0

We convert the first five rows of the validation dataset to a batch of input tensors:

batch = mm.sample_batch(valid, batch_size=5, shuffle=False, include_targets=False)
batch["userId"]
<tf.Tensor: shape=(5, 1), dtype=int32, numpy=
array([[1332],
       [4320],
       [1351],
       [ 700],
       [2366]], dtype=int32)>

Define the inputs block

For simplicity, let's create a schema that contains only the following subset of continuous and categorical features, plus the binary target rating_binary:

sub_schema = train.schema.select_by_name(
    [
        "userId",
        "movieId",
        "title",
        "gender",
        "TE_zipcode_rating",
        "TE_movieId_rating",
        "rating_binary",
    ]
)
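Schema-driven blocks decide which columns to consume by looking at column names and tags. As a rough mental model, and not the merlin.schema implementation, selection can be pictured as filtering a mapping from column names to tags. TOY_SCHEMA, select_by_name, select_by_tag, and the tag strings below are hypothetical. They were chosen so that the results line up with the feature lists the continuous and embedding blocks return in the next cells.

```python
# Hypothetical toy schema: column name -> set of tags. It only mimics the idea of
# merlin.schema (names plus tags); it is not that library's API or its exact tag set.
TOY_SCHEMA = {
    "userId": {"categorical", "user", "id"},
    "movieId": {"categorical", "item", "id"},
    "title": {"categorical", "item"},
    "genres": {"categorical", "item", "list"},
    "gender": {"categorical", "user"},
    "TE_zipcode_rating": {"continuous"},
    "TE_movieId_rating": {"continuous"},
    "rating_binary": {"target", "binary_classification"},
    "rating": {"target", "regression"},
}


def select_by_name(schema, names):
    """Keep only the listed columns, in the order given."""
    return {name: schema[name] for name in names if name in schema}


def select_by_tag(schema, tag):
    """Names of the columns that carry `tag`, in schema order."""
    return [name for name, tags in schema.items() if tag in tags]


sub = select_by_name(TOY_SCHEMA, ["userId", "movieId", "title", "gender",
                                  "TE_zipcode_rating", "TE_movieId_rating", "rating_binary"])
print(select_by_tag(sub, "continuous"))   # ['TE_zipcode_rating', 'TE_movieId_rating']
print(select_by_tag(sub, "categorical"))  # ['userId', 'movieId', 'title', 'gender']
print(select_by_tag(sub, "target"))       # ['rating_binary']
```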

We define the continuous block based on the schema:

continuous_block = mm.ContinuousFeatures.from_schema(sub_schema, tags=Tags.CONTINUOUS)

We display the output of the continuous block for the first batch. It is a dictionary holding the raw tensor of each continuous feature:

continuous_block(batch)
{'TE_zipcode_rating': <tf.Tensor: shape=(5, 1), dtype=float32, numpy=
 array([[ 1.0931238 ],
        [ 0.39430296],
        [-0.25233188],
        [ 0.13508965],
        [ 0.5867218 ]], dtype=float32)>,
 'TE_movieId_rating': <tf.Tensor: shape=(5, 1), dtype=float32, numpy=
 array([[ 0.5267168],
        [ 0.8756474],
        [-0.6305694],
        [-1.8471828],
        [ 0.9459003]], dtype=float32)>}

We connect the continuous block to an MLPBlock instance to project the continuous features into the same dimensionality (64) as the embeddings of the categorical features:

deep_continuous_block = continuous_block.connect(mm.MLPBlock([64]))
deep_continuous_block(batch).shape
TensorShape([5, 64])

We define the categorical embedding block based on the schema:

embedding_block = mm.EmbeddingFeatures.from_schema(sub_schema)

We display the output of the categorical embedding block for the first batch. It is a dictionary with one embedding tensor per categorical feature, each with the default dimension of 64:

embeddings = embedding_block(batch)
embeddings.keys(), embeddings["userId"].shape
(dict_keys(['userId', 'movieId', 'title', 'gender']), TensorShape([5, 64]))

Let’s store the continuous and categorical representations in a single dictionary using a ParallelBlock instance:

dlrm_input_block = mm.ParallelBlock(
    {"embeddings": embedding_block, "deep_continuous": deep_continuous_block}
)
print("Output shapes of DLRM input block:")
for key, val in dlrm_input_block(batch).items():
    print("\t%s : %s" % (key, val.shape))
Output shapes of DLRM input block:
	userId : (5, 64)
	movieId : (5, 64)
	title : (5, 64)
	gender : (5, 64)
	deep_continuous : (5, 64)

The output shows that the ParallelBlock class applies the embedding and continuous blocks in parallel to the same input batch and merges the resulting tensors into a single dictionary.
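The merging behaviour seen in the printed shapes can be sketched in plain Python. A branch that returns a dictionary, like the embedding block, contributes its own keys. A branch that returns a single tensor, like deep_continuous, is stored under the branch name. parallel_block and the toy branches are hypothetical, and raising an error on duplicate keys is a choice made for this sketch rather than documented Merlin behaviour.

```python
def parallel_block(branches):
    """Apply every branch to the same batch and merge the outputs into one flat dict."""
    def run(batch):
        merged = {}
        for name, branch in branches.items():
            out = branch(batch)
            # dict outputs keep their own keys; a single tensor is keyed by the branch name
            items = out.items() if isinstance(out, dict) else [(name, out)]
            for key, value in items:
                if key in merged:
                    raise ValueError(f"duplicate output key: {key}")
                merged[key] = value
        return merged
    return run


def toy_embeddings(batch):  # stands in for embedding_block
    return {"userId": [[0.1] * 3 for _ in batch["userId"]],
            "gender": [[0.2] * 3 for _ in batch["gender"]]}


def toy_deep_continuous(batch):  # stands in for deep_continuous_block
    return [[0.5] * 3 for _ in batch["TE_zipcode_rating"]]


batch = {"userId": [1332, 4320], "gender": [1, 1], "TE_zipcode_rating": [1.09, 0.39]}
block = parallel_block({"embeddings": toy_embeddings, "deep_continuous": toy_deep_continuous})
outputs = block(batch)
print(list(outputs))  # ['userId', 'gender', 'deep_continuous']
```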

Define the interaction block

Now that we have a vector representation of each input feature, we will create the DLRM interaction block. It consists of three operations:

  • Compute the dot product between every pair of feature vectors (the categorical embeddings and the projected continuous representation) to learn pairwise interactions.

  • Concatenate the resulting pairwise interactions with the deep representation of the continuous features (skip connection).

  • Apply an MLPBlock with a series of dense layers to the concatenated tensor.
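For a single example, the first two operations can be sketched in plain Python. The sketch computes dot products over every distinct pair of feature vectors, then concatenates the deep_continuous vector, which is selected by name much like a filter. With five 64-dimensional vectors there are 5 × 4 / 2 = 10 distinct pairs, so the result has 64 + 10 = 74 values. That width is consistent with the (5, 74) output printed below. The function names are hypothetical. Excluding self-pairs and putting the shortcut first are assumptions of this sketch, not statements about DotProductInteractionBlock internals.

```python
import itertools


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def dot_product_interaction(features, shortcut_name="deep_continuous"):
    """DLRM-style interaction for one example (sketch).

    1. Dot product of every distinct pair of feature vectors (no self-pairs).
    2. Concatenate the shortcut vector, selected by name, with those products.
    """
    vectors = list(features.values())
    pairwise = [dot(u, v) for u, v in itertools.combinations(vectors, 2)]
    return features[shortcut_name] + pairwise


dim = 64
# Hypothetical example: four categorical embeddings plus the projected continuous vector.
example = {f"cat_{i}": [float(i)] * dim for i in range(4)}
example["deep_continuous"] = [1.0] * dim

out = dot_product_interaction(example)
n = len(example)
print(n * (n - 1) // 2)  # 10 pairwise products
print(len(out))          # 74 = 64 + 10
```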

First, we use the connect_with_shortcut method to create the first two operations of the DLRM interaction block:

from merlin.models.tf.blocks.dlrm import DotProductInteractionBlock

dlrm_interaction = dlrm_input_block.connect_with_shortcut(
    DotProductInteractionBlock(), shortcut_filter=mm.Filter("deep_continuous"), aggregation="concat"
)

The Filter operation allows us to select the deep_continuous tensor from the dlrm_input_block outputs.

The following diagram provides a visualization of the operations that we constructed in the dlrm_interaction object.

[Diagram: residual_interaction.png]
dlrm_interaction(batch)
<tf.Tensor: shape=(5, 74), dtype=float32, numpy=
array([[ 2.30220780e-02,  0.00000000e+00,  0.00000000e+00,
         7.56775588e-02,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  2.28175014e-01,  0.00000000e+00,
         3.19941580e-01,  0.00000000e+00,  2.78777480e-02,
         0.00000000e+00,  4.73161668e-01,  7.64442533e-02,
         2.35672563e-01,  2.91652471e-01,  1.66354105e-01,
         3.41646910e-01,  0.00000000e+00,  2.71705240e-01,
         1.37824044e-01,  0.00000000e+00,  2.48484939e-01,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.39127672e-01,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         3.48613858e-02,  1.25599518e-01,  0.00000000e+00,
         4.45570499e-02,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.99817121e-02,  1.44250020e-02,
         0.00000000e+00,  1.90662712e-01,  0.00000000e+00,
         0.00000000e+00,  6.26910478e-02,  0.00000000e+00,
         0.00000000e+00,  2.94190586e-01,  1.50031894e-01,
         0.00000000e+00,  1.84342951e-01,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.04685947e-01,  6.66037053e-02,
         4.16741706e-03,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  4.50987183e-02, -2.46724137e-03,
        -6.76170439e-02, -9.77967530e-02, -9.46517102e-03,
         9.14053060e-03, -2.63647735e-02, -1.81730222e-02,
        -3.92094720e-03,  3.92234437e-02],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  3.07496935e-02,  9.01977122e-02,
         0.00000000e+00,  6.98774904e-02,  0.00000000e+00,
         1.15502082e-01,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  3.65737498e-01,  9.89171788e-02,
         2.00192884e-01,  2.92842209e-01,  0.00000000e+00,
         1.84250385e-01,  0.00000000e+00,  4.24410552e-02,
         1.90225601e-01,  1.20380744e-01,  1.70127630e-01,
         1.58379674e-01,  0.00000000e+00,  0.00000000e+00,
         9.06804949e-03,  7.82235116e-02,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.78655088e-01,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.76252164e-02,  3.10025290e-02,  0.00000000e+00,
         0.00000000e+00,  2.63492391e-02,  1.61362186e-01,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  4.64269891e-03,  0.00000000e+00,
         9.54076052e-02,  2.13437006e-01,  1.52400732e-01,
         0.00000000e+00,  1.64790913e-01,  6.50010034e-02,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  2.11483538e-01,  0.00000000e+00,
         7.20069092e-03,  1.05328977e-01,  0.00000000e+00,
         0.00000000e+00, -5.07842051e-03, -3.00218947e-02,
         4.93249968e-02,  6.40595555e-02,  2.05376633e-02,
        -3.37055735e-02,  2.52429601e-02,  3.80666833e-03,
        -2.01439932e-02, -5.69556700e-03],
       [ 2.18789969e-02,  3.36510167e-02,  1.95350796e-01,
         6.10727519e-02,  0.00000000e+00,  0.00000000e+00,
         1.02285758e-01,  0.00000000e+00,  4.84003574e-02,
         0.00000000e+00,  1.55193418e-01,  1.38503924e-01,
         7.77052790e-02,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  6.71746209e-03,
         0.00000000e+00,  1.04955822e-01,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.85306251e-01,  6.36328012e-02,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         2.10280612e-01,  1.58533588e-01,  0.00000000e+00,
         5.32950163e-02,  9.59217921e-03,  2.31691465e-01,
         0.00000000e+00,  0.00000000e+00,  1.55598149e-01,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         5.84047437e-02,  2.57161334e-02,  1.14955440e-01,
         7.11445585e-02,  0.00000000e+00,  6.75121546e-02,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.76699877e-01,  0.00000000e+00,  0.00000000e+00,
         3.69086601e-02,  0.00000000e+00,  1.93824619e-03,
         4.37308550e-02,  0.00000000e+00,  1.04565844e-01,
         0.00000000e+00,  0.00000000e+00,  1.72027379e-01,
         1.87704325e-01, -6.01736642e-03, -9.64383222e-03,
         1.43044777e-02,  2.08113343e-04, -1.79108307e-02,
         4.82701696e-04, -1.83034837e-02, -3.67737515e-03,
        -3.17467004e-02, -3.01161502e-03],
       [ 1.05011471e-01,  0.00000000e+00,  5.52209735e-01,
         3.04437160e-01,  0.00000000e+00,  0.00000000e+00,
         2.06330895e-01,  6.28595725e-02,  7.87643418e-02,
         3.92730497e-02,  2.44794011e-01,  5.47987223e-01,
         8.38067681e-02,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.90068081e-01,
         0.00000000e+00,  3.45876396e-01,  1.88550934e-01,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  3.29886556e-01,  2.34476715e-01,
         1.31867975e-01,  0.00000000e+00,  0.00000000e+00,
         3.70391816e-01,  3.36082906e-01,  0.00000000e+00,
         2.34773606e-01,  1.60487220e-01,  5.28671205e-01,
         1.17527368e-03,  0.00000000e+00,  1.98144943e-01,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.11697763e-01,  2.85532802e-01,  2.98117757e-01,
         1.21483751e-01,  5.78673929e-02,  2.20516354e-01,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         3.86492640e-01,  0.00000000e+00,  0.00000000e+00,
         1.85204111e-02,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  4.58851695e-01,
         0.00000000e+00,  0.00000000e+00,  3.30851346e-01,
         5.57816863e-01, -7.11362436e-03,  4.32709344e-02,
        -3.88916172e-02, -1.30204365e-01,  1.82463657e-02,
        -1.38023719e-02, -2.07279958e-02, -1.39129860e-02,
        -1.52076939e-02,  1.52949709e-03],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  3.72213125e-03,  8.35260451e-02,
         0.00000000e+00,  1.10449523e-01,  0.00000000e+00,
         1.71817169e-01,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  4.42635894e-01,  1.10036075e-01,
         2.37903923e-01,  3.38033378e-01,  3.05014253e-02,
         2.42390379e-01,  0.00000000e+00,  9.20888484e-02,
         2.09882915e-01,  9.92736742e-02,  2.11230367e-01,
         1.30059257e-01,  0.00000000e+00,  0.00000000e+00,
         3.49052809e-02,  6.48877323e-02,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.62512243e-01,
         0.00000000e+00,  1.71390176e-02,  0.00000000e+00,
         2.54175626e-02,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  2.92394385e-02,  1.58785328e-01,
         0.00000000e+00,  1.14827603e-02,  0.00000000e+00,
         0.00000000e+00,  1.62667409e-02,  0.00000000e+00,
         8.40894058e-02,  2.61706412e-01,  1.75590441e-01,
         0.00000000e+00,  1.94019809e-01,  4.01518792e-02,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  2.24219605e-01,  0.00000000e+00,
         7.74760824e-03,  9.18345302e-02,  0.00000000e+00,
         0.00000000e+00,  4.90122102e-03, -2.59129852e-02,
        -2.30328739e-02,  1.15348995e-02,  1.12771615e-02,
        -3.55871092e-03, -3.02606113e-02,  2.49744840e-02,
         1.98186152e-02,  1.06542893e-02]], dtype=float32)>

Then, we project the learned interaction using a series of dense layers:

deep_dlrm_interaction = dlrm_interaction.connect(mm.MLPBlock([64, 128, 512]))
deep_dlrm_interaction(batch)
<tf.Tensor: shape=(5, 512), dtype=float32, numpy=
array([[0.        , 0.        , 0.06550899, ..., 0.00943315, 0.        ,
        0.02917717],
       [0.        , 0.00613478, 0.02798904, ..., 0.00486526, 0.        ,
        0.0473245 ],
       [0.        , 0.        , 0.04229244, ..., 0.        , 0.        ,
        0.02768494],
       [0.        , 0.        , 0.09441784, ..., 0.        , 0.        ,
        0.04710757],
       [0.        , 0.        , 0.03271887, ..., 0.00290654, 0.        ,
        0.04926534]], dtype=float32)>
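The widths above make it easy to estimate how many trainable weights the two MLP blocks add. deep_continuous_block maps the 2 continuous features to 64 units. The interaction MLP maps the 74-wide interaction vector through 64, 128, and 512 units. The sketch below assumes every MLPBlock layer is a fully connected layer with a bias and no other trainable parameters. That is an assumption, so model.summary() remains the reliable check. The embedding tables are not counted.

```python
def mlp_param_count(input_dim, units, use_bias=True):
    """Weights (plus biases) in a stack of fully connected layers."""
    total, fan_in = 0, input_dim
    for width in units:
        total += fan_in * width + (width if use_bias else 0)
        fan_in = width
    return total


print(mlp_param_count(2, [64]))             # 192: continuous projection, 2 -> 64
print(mlp_param_count(74, [64, 128, 512]))  # 79168: 74 -> 64 -> 128 -> 512
```

Most of the interaction MLP's parameters sit in the final 128 → 512 layer (65,536 weights plus 512 biases), which is worth keeping in mind when choosing the layer widths.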

Define the prediction block

At this stage, we have created the DLRM block, which accepts a dictionary of categorical and continuous tensors as input. The output of this block is an interaction representation vector of dimension 512. The next step is to use this hidden representation for a given prediction task. In our case, we use the label rating_binary, and the objective is to predict whether user A will give movie B a high rating.

We use the BinaryClassificationTask class and evaluate performance with the AUC metric. We also use the LogitsTemperatureScaler block as a pre-transformation that scales the logits returned by the task before the loss and metrics are computed:

from merlin.models.tf.transforms.bias import LogitsTemperatureScaler

binary_task = mm.BinaryClassificationTask(
    sub_schema,
    pre=LogitsTemperatureScaler(temperature=2),
)
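Temperature scaling divides each logit by the temperature before the sigmoid is applied. With temperature=2, confident predictions are pulled toward 0.5, which is how the scaler lowers the model's confidence. The plain-Python sketch below shows the effect on three example logits. It assumes division by the temperature, the usual definition of temperature scaling, and the helper names are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def scale_logits(logits, temperature):
    """Temperature scaling: divide logits by T; T > 1 pulls probabilities toward 0.5."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    return [z / temperature for z in logits]


logits = [3.0, -1.0, 0.0]
print([round(sigmoid(z), 3) for z in logits])                     # [0.953, 0.269, 0.5]
print([round(sigmoid(z), 3) for z in scale_logits(logits, 2.0)])  # [0.818, 0.378, 0.5]
```

Dividing by a positive temperature never changes the order of the scores. Ranking metrics such as exact AUC are therefore unaffected, while the loss and the predicted probabilities do change.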

Define, train, and evaluate the final DLRM Model

We pass the deep DLRM interaction block and the binary task to mm.Model, which builds the final model for us. Note that the Model class inherits from the tf.keras.Model class:

model = mm.Model(deep_dlrm_interaction, binary_task)
type(model)
merlin.models.tf.models.base.Model

We train the model using the built-in tf.keras fit method:

model.compile(optimizer="adam", metrics=[tf.keras.metrics.AUC()])
model.fit(train, batch_size=1024, epochs=1)
782/782 [==============================] - 16s 14ms/step - loss: 0.6473 - auc: 0.7240 - regularization_loss: 0.0000e+00
<keras.callbacks.History at 0x7f3f5856f250>
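tf.keras.metrics.AUC approximates the area under the ROC curve by bucketing predictions into a fixed set of thresholds (200 by default). The quantity it approximates has a simple exact definition. It is the probability that a randomly chosen positive example is scored higher than a randomly chosen negative one, with ties counting one half. The plain-Python sketch below computes that exact value by brute force. It is for intuition only and is not how Keras computes the metric. The labels and scores are hypothetical.

```python
def roc_auc(labels, scores):
    """Exact ROC AUC: fraction of (positive, negative) pairs ranked correctly, ties count 1/2."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


labels = [1, 0, 1, 0, 1]
scores = [0.9, 0.3, 0.4, 0.6, 0.8]  # hypothetical predicted probabilities
print(roc_auc(labels, scores))      # 5 of the 6 positive/negative pairs are ordered correctly
```

An AUC of 0.5 corresponds to random ranking and 1.0 to a perfect ranking. This is a useful scale for reading the validation AUC reported by evaluate().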

Let’s check out the model evaluation scores:

metrics = model.evaluate(valid, batch_size=1024, return_dict=True)
metrics
196/196 [==============================] - 3s 9ms/step - loss: 0.6393 - auc: 0.7334 - regularization_loss: 0.0000e+00
{'loss': 0.6393412947654724,
 'auc': 0.7334249019622803,
 'regularization_loss': 0.0}

Note that the evaluate() progress bar shows a running average of the loss that is updated after each batch, whereas the loss stored in the dictionary is the final value of that average over all validation batches.
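To see how per-batch losses relate to the single number returned in the dictionary, here is a sketch of a running mean, optionally weighted by batch size, that a Keras-style progress bar typically displays. Its last value is the aggregate over all batches. Keras's built-in loss tracking generally weights each batch by its number of samples. Treat that as an assumption to verify for your TensorFlow and Merlin versions. The loss values and batch sizes below are hypothetical.

```python
def running_means(batch_losses, batch_sizes=None):
    """Running (optionally size-weighted) mean of per-batch losses, one value per batch."""
    if batch_sizes is None:
        batch_sizes = [1] * len(batch_losses)
    means, weighted_sum, count = [], 0.0, 0
    for loss, size in zip(batch_losses, batch_sizes):
        weighted_sum += loss * size
        count += size
        means.append(weighted_sum / count)
    return means


# Hypothetical per-batch losses; the last batch is smaller than the others.
losses = [0.70, 0.64, 0.61, 0.62]
sizes = [1024, 1024, 1024, 500]
trace = running_means(losses, sizes)
print([round(m, 4) for m in trace])  # what a progress bar would show after each batch
print(round(trace[-1], 4))           # the single aggregate value
```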

Save the model so we can use it for serving predictions in production or for resuming training with new observations:

model.save("custom_dlrm")
WARNING:absl:Function `_wrapped_model` contains input name(s) TE_age_rating, TE_gender_rating, TE_movieId_rating, TE_occupation_rating, TE_userId_rating, TE_zipcode_rating, movieId, userId with unsupported characters which will be renamed to te_age_rating, te_gender_rating, te_movieid_rating, te_occupation_rating, te_userid_rating, te_zipcode_rating, movieid, userid in the SavedModel.
WARNING:absl:Found untraced functions such as model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, logits_temperature_scaler_layer_call_fn, logits_temperature_scaler_layer_call_and_return_conditional_losses, output_layer_layer_call_fn while saving (showing 5 of 66). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: custom_dlrm/assets

Conclusion

Merlin Models provides common and state-of-the-art RecSys architectures in a high-level API, as well as all the low-level building blocks you need to create your own architecture (input blocks, MLP layers, prediction tasks, loss functions, etc.). In this example, we used a subset of these pre-existing blocks to create the DLRM model; see our documentation to discover more. You can also contribute to the library by submitting new RecSys architectures and custom building blocks.

Next steps

To learn how to deploy the trained DLRM model, visit the Merlin Systems library and run the Serving-Ranking-Models-With-Merlin-Systems.ipynb notebook, which deploys an ensemble of an NVTabular Workflow and a model trained with Merlin Models to Triton Inference Server.