# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ================================
Iterating over Deep Learning Models using Merlin Models
In this example, we’ll define several popular deep learning-based model architectures, train, and evaluate them and show how Merlin Models simplifies and eases this common and iterative process.
In this example notebook, we use the Ali-CCP: Alibaba Click and Conversion Prediction dataset to build our recommender system models. This is a dataset gathered from real-world traffic logs of the recommender system in Taobao, the largest online retail platform in the world. To download the raw Ali-CCP training and test datasets visit tianchi.aliyun.com. We have curated the raw dataset via this get_aliccp() function and generated the parquet files that we use in this example. If you want to use synthetic ali-ccp dataset instead, you can simple run the following command:
from merlin.datasets.synthetic import generate_data
train, valid = generate_data("aliccp-raw", 1000000, set_sizes=(0.7, 0.3))
Learning objectives
Preparing the data with NVTabular
Training different deep learning-based recommender models with Merlin Models
Importing Libraries
Let’s start with importing the libraries that we’ll use in this notebook.
import os
import numpy as np
from nvtabular.loader.tf_utils import configure_tensorflow
configure_tensorflow()
import nvtabular as nvt
from nvtabular.ops import *
from merlin.models.utils.example_utils import workflow_fit_transform, save_results
from merlin.schema.tags import Tags
import merlin.models.tf as mm
from merlin.io.dataset import Dataset
import tensorflow as tf
2022-04-05 16:19:29.988780: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-04-05 16:19:31.114629: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 16254 MB memory: -> device: 0, name: Quadro GV100, pci bus id: 0000:15:00.0, compute capability: 7.0
Feature Engineering with NVTabular
When we work on a new recommender systems, we explore the dataset, first. In doing so, we define our input and output paths. We will use the parquet files in the test folder to validate our trained model.
DATA_FOLDER = os.environ.get("DATA_FOLDER", "/workspace/data/")
train_path = os.path.join(DATA_FOLDER, 'train', '*.parquet')
valid_path = os.path.join(DATA_FOLDER, 'test','*.parquet')
output_path =os.path.join(DATA_FOLDER, "processed")
Our dataset has only categorical features. Below, we create continuous features using target encoding (TE) technique. Target Encoding calculates the statistics from a target variable grouped by the unique values of one or more categorical features. For example, in a binary classification problem, TE calculates the conditional probability that the target is true for each category value- a simple mean. To learn more about TE, visit this medium blog.
Note that the Ali-CCP
dataset has click
and conversion
target columns but we only focus on building different ranking models with binary target column click
.
We use a utility function, workflow_fit_transform
perform to fit and transform steps on the raw dataset applying the operators defined in the NVTabular workflow pipeline below, and also save our workflow model. After fit and transform, the processed parquet files are saved to output_path
.
%%time
user_id = ["user_id"] >> Categorify(freq_threshold=5) >> TagAsUserID()
item_id = ["item_id"] >> Categorify(freq_threshold=5) >> TagAsItemID()
add_feat = ["user_item_categories", "user_item_shops", "user_item_brands", "user_item_intentions","item_category", "item_shop", "item_brand"] >> Categorify()
te_feat = (
["user_id", "item_id"] + add_feat >>
TargetEncoding(
['click'],
kfold=1,
p_smooth=20
) >>
Normalize()
)
targets = ["click"] >> AddMetadata(tags=[Tags.BINARY_CLASSIFICATION, "target"])
outputs = user_id + item_id + targets + add_feat + te_feat
# Remove rows where item_id==0 and user_id==0
outputs = outputs>> Filter(f=lambda df: (df["item_id"] != 0) & (df["user_id"] != 0))
workflow_fit_transform(outputs, train_path, valid_path, output_path)
/usr/local/lib/python3.8/dist-packages/cudf/core/dataframe.py:1253: UserWarning: The deep parameter is ignored and is only included for pandas compatibility.
warnings.warn(
CPU times: user 23.2 s, sys: 30.2 s, total: 53.4 s
Wall time: 57.1 s
Training Recommender Models
NVTabular exported the schema file of our processed dataset. The schema.pbtxt
is a protobuf text file contains features metadata, including statistics about features such as cardinality, min and max values and also tags based on their characteristics and dtypes (e.g., categorical, continuous, list, item_id). The metadata information is loaded from schema and their tags are used to automatically set the parameters of Merlin Models. In other words, Merlin Models relies on the schema object to automatically build all necessary input and output layers.
train = Dataset(os.path.join(output_path, 'train', '*.parquet'), part_size="500MB")
valid = Dataset(os.path.join(output_path, 'valid', '*.parquet'), part_size="500MB")
# define schema object
schema = train.schema
/usr/local/lib/python3.8/dist-packages/cudf/core/dataframe.py:1253: UserWarning: The deep parameter is ignored and is only included for pandas compatibility.
warnings.warn(
target_column = schema.select_by_tag(Tags.TARGET).column_names[0]
target_column
'click'
We can print out all the features that are included in the schema.pbtxt
file.
schema.column_names
['user_id',
'item_id',
'click',
'user_item_categories',
'user_item_shops',
'user_item_brands',
'user_item_intentions',
'item_category',
'item_shop',
'item_brand',
'TE_user_item_categories_click',
'TE_user_item_shops_click',
'TE_user_item_brands_click',
'TE_user_item_intentions_click',
'TE_item_category_click',
'TE_item_shop_click',
'TE_item_brand_click',
'TE_user_id_click',
'TE_item_id_click']
Initialize Dataloaders
We’re ready to start training, for that, we create our dataset objects, and under the hood we use Merlin BatchedDataset
class for reading chunks of parquet files. BatchedDataset
asynchronously iterate through CSV or Parquet dataframes on GPU by leveraging an NVTabular Dataset
. To read more about Merlin optimized dataloaders visit here.
NCF Model
We will first build and train a Neural Collaborative Filtering (NCF) model. Neural Collaborative Filtering (NCF) Model architecture explores neural network architectures for collaborative filtering, in other words explores the use of deep neural networks for learning the interaction function from data.
NCF feed categorical features into embedding layer, concat the embedding outputs and add multiple hidden layers via its MLP layer tower as seen in the figure. GMF and MLP uses separate user and item embeddings, and then outputs of their interactions from GMF Layer and MLP Layer are concatenated and fed to the final NeuMF (Neural Matrix Factorisation) layer.
With schema
object we enable NCF model easily to recognize item_id and user_id columns (defined in the schema.pbtxt with corresponding tags). Input block of embedding layers will be generated using item_id and user_id as seen in the Figure.
model = mm.benchmark.NCFModel(
schema,
embedding_dim=64,
mlp_block=mm.MLPBlock([128, 64]),
prediction_tasks=mm.BinaryClassificationTask(target_column, metrics=[tf.keras.metrics.AUC()]),
)
%%time
batch_size = 16*1024
LR=0.01
opt = tf.keras.optimizers.Adagrad(learning_rate=LR)
model.compile(optimizer=opt, run_eagerly=False)
model.fit(train, validation_data=valid, batch_size=batch_size)
2022-04-05 16:20:29.232080: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
2207/2209 [============================>.] - ETA: 0s - auc: 0.5287 - loss: 0.1823 - regularization_loss: 0.0000e+00 - total_loss: 0.1823
2022-04-05 16:21:34.753153: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: cond/else/_1/cond/cond/branch_executed/_139
2209/2209 [==============================] - 90s 35ms/step - auc: 0.5287 - loss: 0.1822 - regularization_loss: 0.0000e+00 - total_loss: 0.1822 - val_auc: 0.5001 - val_loss: 0.1327 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.1327
CPU times: user 2min 24s, sys: 20.2 s, total: 2min 44s
Wall time: 1min 32s
<keras.callbacks.History at 0x7f09c037f160>
Let’s save our accuracy results
save_results('NCF', model)
MLP Model
Now we will change our model to Multi-Layer Percepton (MLP) model. MLP models feed categorical features into embedding layer, concat the embedding outputs and add multiple hidden layers.
Steps:
Change the model to MLP model
Rerun the pipeline from there from model.fit
# uses default embedding_dim = 64
model = mm.MLPBlock([64, 32]).to_model(
schema,
prediction_tasks=mm.BinaryClassificationTask(target_column, metrics=[tf.keras.metrics.AUC()])
)
%%time
opt = tf.keras.optimizers.Adagrad(learning_rate=LR)
model.compile(optimizer=opt, run_eagerly=False)
model.fit(train, validation_data=valid, batch_size=batch_size)
2209/2209 [==============================] - ETA: 0s - auc_1: 0.6690 - loss: 0.1677 - regularization_loss: 0.0000e+00 - total_loss: 0.1677
2022-04-05 16:23:02.687741: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: cond/then/_0/cond/cond/branch_executed/_140
2209/2209 [==============================] - 89s 38ms/step - auc_1: 0.6690 - loss: 0.1677 - regularization_loss: 0.0000e+00 - total_loss: 0.1677 - val_auc_1: 0.5746 - val_loss: 0.1365 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.1365
CPU times: user 2min 55s, sys: 28 s, total: 3min 23s
Wall time: 1min 29s
<keras.callbacks.History at 0x7f09b05ed0a0>
save_results('MLP', model)
DLRM Model
Deep Learning Recommendation Model (DLRM) architecture is a popular neural network model originally proposed by Facebook in 2019 as a personalization deep learning model.
DLRM accepts two types of features: categorical and numerical.
For each categorical feature, an embedding table is used to provide dense representation to each unique value.
For numerical features, they are fed to model as dense features, and then transformed by a simple neural network referred to as “bottom MLP”. This part of the network consists of a series of linear layers with ReLU activations.
The output of the bottom MLP and the embedding vectors are then fed into the dot product interaction operation (see Pairwise interaction step). The output of “dot interaction” is then concatenated with the features resulting from the bottom MLP (we apply a skip-connection there) and fed into the “top MLP” which is also a series of dense layers with activations ((a fully connected NN).
The model outputs a single number (here we use sigmoid function to generate probabilities) which can be interpreted as a likelihood of a certain user clicking on an ad, watching a movie, or viewing a news page.
Steps:
Change the model to
DLRMModel
Rerun the pipeline from there from model.fit
model = mm.DLRMModel(
schema,
embedding_dim=64,
bottom_block=mm.MLPBlock([128, 64]),
top_block=mm.MLPBlock([128, 64, 32]),
prediction_tasks=mm.BinaryClassificationTask(target_column, metrics=[tf.keras.metrics.AUC()])
)
%%time
opt = tf.keras.optimizers.Adagrad(learning_rate=LR)
model.compile(optimizer=opt, run_eagerly=False)
model.fit(train, validation_data=valid, batch_size=batch_size)
2207/2209 [============================>.] - ETA: 0s - auc_2: 0.6915 - loss: 0.1616 - regularization_loss: 0.0000e+00 - total_loss: 0.1616
2022-04-05 16:24:57.719674: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: cond/then/_0/cond/cond/branch_executed/_158
2209/2209 [==============================] - 116s 49ms/step - auc_2: 0.6915 - loss: 0.1616 - regularization_loss: 0.0000e+00 - total_loss: 0.1616 - val_auc_2: 0.5724 - val_loss: 0.1364 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.1364
CPU times: user 3min 43s, sys: 38.6 s, total: 4min 21s
Wall time: 1min 57s
<keras.callbacks.History at 0x7f09a3226550>
save_results("DLRM", model)
DCN Model
DCN-V2 is an architecture proposed as an improvement upon the original DCN model. The explicit feature interactions of the inputs are learned through cross layers, and then combined with a deep network to learn complementary implicit interactions. The overall model architecture is depicted in Figure below, with two ways to combine the cross network with the deep network: (1) stacked and (2) parallel. The output of the embbedding layer is the concatenation of all the embedded vectors and the normalized dense features: x0 = [xembed,1; … ; xembed,𝑛; 𝑥dense].
In this example, we build DCN-v2 stacked
architecture.
Steps:
Change the model to
DCNModel
Rerun the pipeline from there to model.fit
model = mm.DCNModel(
schema,
depth=2,
deep_block=mm.MLPBlock([64, 32]),
prediction_tasks=mm.BinaryClassificationTask(target_column, metrics=[tf.keras.metrics.AUC()])
)
%%time
opt = tf.keras.optimizers.Adagrad(learning_rate=0.005)
model.compile(optimizer=opt, run_eagerly=False)
model.fit(train, validation_data=valid, batch_size=batch_size)
2208/2209 [============================>.] - ETA: 0s - auc_3: 0.6411 - loss: 0.1858 - regularization_loss: 0.0000e+00 - total_loss: 0.1858
2022-04-05 16:26:44.596443: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: cond/else/_1/cond/cond/branch_executed/_166
2209/2209 [==============================] - 106s 45ms/step - auc_3: 0.6411 - loss: 0.1858 - regularization_loss: 0.0000e+00 - total_loss: 0.1858 - val_auc_3: 0.5726 - val_loss: 0.1385 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.1385
CPU times: user 3min 26s, sys: 35.7 s, total: 4min 1s
Wall time: 1min 46s
<keras.callbacks.History at 0x7f09a2158fd0>
save_results("DCN", model)
Let’s visualize our model validation accuracy values. Since we did not do any hyper-parameter optimization or extensive feature engineering here, we do not come up with a final conclusion that one model is superior to another.
import matplotlib.pyplot as plt
def create_bar_chart(text_file_name, models_name):
"""a func to plot barcharts via parsing the accuracy results in a text file"""
auc = []
with open(text_file_name, "r") as infile:
for line in infile:
if "auc" in line:
data = [line.rstrip().split(":")]
key, value = zip(*data)
auc.append(float(value[0]))
X_axis = np.arange(len(models_name))
plt.title("Models' accuracy metrics comparison", pad=20)
plt.bar(X_axis - 0.2, auc, 0.4, label="AUC")
plt.xticks(X_axis, models_name)
plt.xlabel("Models")
plt.ylabel("AUC")
plt.show()
models_name = ["NCF", "MLP", "DLRM", "DCN"]
create_bar_chart("results.txt", models_name)