# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

Advanced NVTabular Workflow

This notebook was created using the latest stable merlin-tensorflow container.

Overview

This notebook builds on what is covered in the 01-Getting-started notebook. Here we take a closer look at running more complex Operators.

We will also examine the functionality of the Schema – this can be very useful for learning more about our data and instructing which columns we would like to use and how.

Learning objectives

  • Leveraging more complex Merlin NVTabular Operators

  • Splitting our data in chunks to limit memory footprint

  • Understanding the role of the schema and modifying it

Downloading the dataset

import os
from merlin.datasets.entertainment import get_movielens

input_path = os.environ.get("INPUT_DATA_DIR", os.path.expanduser("~/merlin-framework/movielens/"))
get_movielens(variant="ml-1m", path=input_path); #noqa

Reading in the data

from merlin.core.dispatch import get_lib

# os.path.join keeps the path valid even if INPUT_DATA_DIR has no trailing slash
data = get_lib().read_parquet(os.path.join(input_path, 'ml-1m', 'train.parquet')).sample(frac=1)

train = data.iloc[:600_000]
valid = data.iloc[600_000:]

movies = get_lib().read_parquet(os.path.join(input_path, 'ml-1m', 'movies_converted.parquet'))

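We shuffle the ratings and split them into a train set (the first 600,000 rows) and a validation set (the remaining rows). From these sets we will use the userId, movieId and rating columns.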

train.head()
        userId  movieId  rating  timestamp
315038  1880    1674     3       975379394
835377  5021    3468     4       962583655
105787  699     2205     2       986531971
150012  966     1653     1       975115827
137192  883     3911     4       975260600

We will also use the metadata contained in movies.

movies.head()
   movieId  title                               genres
0  1        Toy Story (1995)                    [Animation, Children's, Comedy]
1  2        Jumanji (1995)                      [Adventure, Children's, Fantasy]
2  3        Grumpier Old Men (1995)             [Comedy, Romance]
3  4        Waiting to Exhale (1995)            [Comedy, Drama]
4  5        Father of the Bride Part II (1995)  [Comedy]

Let’s create Merlin Datasets that we will be able to run through our workflow (a workflow defines the operations we would like performed on our data).

import nvtabular as nvt
from merlin.schema.tags import Tags

train_ds = nvt.Dataset(train, npartitions=2)
valid_ds = nvt.Dataset(valid)

train_ds, valid_ds
(<merlin.io.dataset.Dataset at 0x7f1c2628e550>,
 <merlin.io.dataset.Dataset at 0x7f1c2628e0a0>)

The constructor nvt.Dataset(...) accepts an important parameter: npartitions. We can leverage it to specify into how many chunks we would like our data to be split. Our workflow will process data in chunks and by increasing the number of partitions we can limit the memory footprint.

We have to remember though to shuffle the data so that all relevant records reside in a single partition. Otherwise, depending on the operations we chose to perform, the results of our workflow might be incorrect.

# shuffle_by_keys returns a new Dataset, so we keep the result
train_ds = train_ds.shuffle_by_keys('userId')
valid_ds = valid_ds.shuffle_by_keys('userId')

We have now shuffled the records so that all examples for the same userId reside in the same partition.
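To build intuition for what a key-based shuffle guarantees, here is a minimal pure-Python sketch. It is not how Merlin or Dask implement the shuffle; the function partition_by_key and the toy rows are made up for illustration. The only point is that assigning rows to partitions from the key value keeps every row for a given userId together.

```python
from collections import defaultdict


def partition_by_key(rows, key, npartitions):
    """Assign each row to a partition derived from its key value (illustrative only)."""
    partitions = [[] for _ in range(npartitions)]
    for row in rows:
        # Rows with equal keys always map to the same partition index.
        partitions[hash(row[key]) % npartitions].append(row)
    return partitions


# Toy interaction data (made up for this sketch).
rows = [
    {"userId": 1, "movieId": 10},
    {"userId": 2, "movieId": 11},
    {"userId": 1, "movieId": 12},
    {"userId": 3, "movieId": 13},
    {"userId": 2, "movieId": 14},
]

partitions = partition_by_key(rows, key="userId", npartitions=2)

# Record which partitions each user ended up in.
user_partitions = defaultdict(set)
for index, partition in enumerate(partitions):
    for row in partition:
        user_partitions[row["userId"]].add(index)

# Each user should appear in exactly one partition.
print({user: sorted(parts) for user, parts in user_partitions.items()})
```

Had we simply cut the rows into consecutive chunks instead, the two rows for userId 1 could land in different chunks, and any per-user statistic computed chunk by chunk would then be incomplete.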

The number of partitions hasn’t changed.

train_ds.npartitions
2

Defining the workflow

Joining an external dataset

We want to make use of information contained in the movies DataFrame.

Let us join it onto our data, keeping only the genres column from movies (movieId is also listed in columns_ext because it is the join key).

genres = ['movieId'] >> nvt.ops.JoinExternal(movies, on='movieId', columns_ext=['movieId', 'genres'])
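Conceptually, this join looks up each movieId in the external table and attaches the selected columns. The pure-Python sketch below illustrates that idea under the assumption that the join behaves like a left join (rows without a match are kept and receive a missing value). The helper left_join and the input rows are hypothetical; the movie entries are taken from movies.head() above.

```python
# A few rows from the external movies table shown earlier.
movies_rows = [
    {"movieId": 1, "title": "Toy Story (1995)",
     "genres": ["Animation", "Children's", "Comedy"]},
    {"movieId": 3, "title": "Grumpier Old Men (1995)",
     "genres": ["Comedy", "Romance"]},
]

# Toy input column (made up); movieId 99 has no match in movies_rows.
input_rows = [{"movieId": 1}, {"movieId": 3}, {"movieId": 1}, {"movieId": 99}]


def left_join(left, right, on, columns_ext):
    """Attach the columns in columns_ext from `right` to every row of `left`."""
    lookup = {row[on]: {c: row[c] for c in columns_ext} for row in right}
    missing = {c: None for c in columns_ext if c != on}
    # Unmatched rows are kept and filled with None (left-join semantics).
    return [{**row, **lookup.get(row[on], missing)} for row in left]


joined = left_join(input_rows, movies_rows, on="movieId",
                   columns_ext=["movieId", "genres"])
for row in joined:
    print(row)
```

Because title is not listed in columns_ext, it never reaches the joined output.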

The genres information is represented as a list of strings. In order for us to be able to use it for training an ML model we need to encode the strings as categories.

movies['genres'].head()
0     [Animation, Children's, Comedy]
1    [Adventure, Children's, Fantasy]
2                   [Comedy, Romance]
3                     [Comedy, Drama]
4                            [Comedy]
Name: genres, dtype: list

Encoding a list of strings as categories

We leverage the Categorify operator to encode our lists of strings.

Additionally, let’s treat the categories that appear infrequently as unknowns.

genres = genres >> nvt.ops.Categorify(freq_threshold=10)
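The following pure-Python sketch shows the idea behind frequency-thresholded encoding of list columns. It is an illustration only, not NVTabular's implementation: the helper names are hypothetical, the index layout (0 reserved for infrequent values, remaining indices assigned alphabetically) is an assumption of this sketch, and the actual indices Categorify produces depend on the NVTabular version.

```python
from collections import Counter


def fit_categories(lists, freq_threshold):
    """Build a value -> index mapping, keeping only sufficiently frequent values."""
    counts = Counter(value for values in lists for value in values)
    frequent = sorted(v for v, c in counts.items() if c >= freq_threshold)
    # Index 0 is reserved for infrequent or unseen values (assumption of this sketch).
    return {value: index for index, value in enumerate(frequent, start=1)}


def encode(lists, mapping):
    """Replace every string with its index; unknown values share index 0."""
    return [[mapping.get(value, 0) for value in values] for values in lists]


# Toy genre lists (made up for this sketch).
toy_genres = [
    ["Comedy", "Romance"],
    ["Comedy", "Drama"],
    ["Comedy"],
    ["Drama", "Film-Noir"],
]

mapping = fit_categories(toy_genres, freq_threshold=2)
encoded = encode(toy_genres, mapping)
print(mapping)
print(encoded)
```

With a threshold of 2, Romance and Film-Noir each occur only once in the toy data, so both collapse into the shared index for unknowns.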

Running a custom preprocessing step

Let’s look at an example of how we can run a custom preprocessing step on our data.

Here we will define our own function that will transform the rating column to binary_rating. We will output a value of False for any rating below 4 and a value of True for any rating of 4 or 5.

def rating_to_binary(col):
    return col > 3

The LambdaOp is an operator that takes in a function and applies it to each selected column as a whole; with a function like ours, this amounts to an element-wise transformation.

Here, LambdaOp takes in the rating_to_binary function and will apply it to the values in the rating column.

binary_rating = ['rating'] >> nvt.ops.LambdaOp(rating_to_binary) >> nvt.ops.Rename(name='binary_rating')
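The comparison col > 3 is evaluated on the whole column at once. Element by element, it is equivalent to the plain-Python sketch below. The function name rating_to_binary_elementwise is hypothetical and only meant to show the resulting values; the ratings are the five shown in train.head() plus a 5.

```python
def rating_to_binary_elementwise(ratings):
    """Plain-Python equivalent of `col > 3` applied to every value in a column."""
    return [rating > 3 for rating in ratings]


ratings = [3, 4, 2, 1, 4, 5]
binary = rating_to_binary_elementwise(ratings)
print(binary)
```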

Tagging columns

Let us also tag our columns. This will help streamline model training and serving down the road, should we opt to do so using other components of the Merlin Framework.

Running NVTabular operators on our data automatically tags the output columns. For instance, the Categorify operator we used above will tag the output columns as Tags.CATEGORICAL. Still, some information that can be useful down the road needs to be added by hand. This is true for instance for target information. Let’s add this information below.

userId = ['userId'] >> nvt.ops.Categorify() >> nvt.ops.AddTags(tags=[Tags.USER_ID, Tags.CATEGORICAL, Tags.USER])
movieId = ['movieId'] >> nvt.ops.Categorify() >> nvt.ops.AddTags(tags=[Tags.ITEM_ID, Tags.CATEGORICAL, Tags.ITEM])
binary_rating = binary_rating >> nvt.ops.AddTags(tags=[Tags.TARGET, Tags.BINARY_CLASSIFICATION])

Applying the workflow to the train and validation sets

We are now ready to create our workflow.

workflow = nvt.Workflow(userId + movieId + genres + binary_rating)

Let us now fit the workflow to our train set and transform the train and validation sets.

train_transformed = workflow.fit_transform(train_ds)
valid_transformed = workflow.transform(valid_ds)
valid_transformed.compute().head()
   userId  movieId  genres     binary_rating
0  112     1326     [1]        True
1  251     1508     [1]        True
2  147     254      [7, 1]     True
3  2555    238      [1]        True
4  566     1000     [3, 8, 4]  False

Our data was processed as we expected.

Let us look at the schema.

train_transformed.schema
name tags dtype is_list is_ragged properties.num_buckets properties.freq_threshold properties.max_size properties.start_index properties.cat_path properties.domain.min properties.domain.max properties.domain.name properties.embedding_sizes.cardinality properties.embedding_sizes.dimension
0 userId (Tags.USER, Tags.CATEGORICAL, Tags.USER_ID) int64 False False NaN 0.0 0.0 0.0 .//categories/unique.userId.parquet 0.0 6041.0 userId 6041.0 210.0
1 movieId (Tags.ITEM_ID, Tags.ITEM, Tags.CATEGORICAL) int64 False False NaN 0.0 0.0 0.0 .//categories/unique.movieId.parquet 0.0 1852.0 movieId 1852.0 108.0
2 genres (Tags.CATEGORICAL) int64 True True NaN 100.0 0.0 0.0 .//categories/unique.genres.parquet 0.0 19.0 genres 19.0 16.0
3 binary_rating (Tags.BINARY_CLASSIFICATION, Tags.TARGET) bool False False NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

The schema contains a lot of useful information, such as the frequency threshold that was used to process a column or its cardinality.

Beyond providing us with useful information, schema can be used to inform other components of the Merlin Framework on what information from our datasets we would like it to act on and in what way.

This is what Tags.TARGET, Tags.ITEM_ID, and Tags.USER_ID achieve. When we train our deep learning model, it will be able to infer from these tags which columns to use and what kind of information each one carries.
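To make the role of tags concrete, here is a toy stand-in for the schema: a plain dictionary from column name to a set of tag names, mirroring the schema output above. The dictionary and the helper select_by_tag are hypothetical simplifications; the real Schema object carries much more metadata and, as far as we know, provides its own tag-based selection.

```python
# Toy schema: column name -> tags, mirroring the schema printed above.
toy_schema = {
    "userId": {"USER_ID", "CATEGORICAL", "USER"},
    "movieId": {"ITEM_ID", "CATEGORICAL", "ITEM"},
    "genres": {"CATEGORICAL"},
    "binary_rating": {"TARGET", "BINARY_CLASSIFICATION"},
}


def select_by_tag(schema, tag):
    """Return the names of all columns carrying the given tag."""
    return [name for name, tags in schema.items() if tag in tags]


# A model can discover its inputs and target from tags alone.
targets = select_by_tag(toy_schema, "TARGET")
categorical_inputs = select_by_tag(toy_schema, "CATEGORICAL")
print(targets, categorical_inputs)
```

This is why adding Tags.TARGET by hand matters: without it, nothing in the schema would distinguish the label column from the features.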

We can also use the schema to remove information that we wouldn’t want our model to use in training.

Let us look at an example of that below.

Training a DLRM Model

Let us train a DLRM Model.

It is usually advantageous to train on the entirety of the available data. Let us do so below – we will utilize all the columns included in our schema.

We tell the model about the structure of our data by passing in the Schema (train_transformed.schema in the example below).

import tensorflow
import merlin.models.tf as mm

model = mm.DLRMModel(
    train_transformed.schema,
    embedding_dim=64,
    bottom_block=mm.MLPBlock([128, 64]),
    top_block=mm.MLPBlock([128, 64, 32]),
    prediction_tasks=mm.BinaryClassificationTask('binary_rating')  # target column name as it appears in the schema
)

opt = tensorflow.optimizers.Adam(learning_rate=5e-3)
model.compile(optimizer=opt)
model.fit(train_transformed, validation_data=valid_transformed, batch_size=1024, epochs=5)

model.optimizer.learning_rate = 1e-3
model.fit(train_transformed, validation_data=valid_transformed, batch_size=1024, epochs=3)
Epoch 1/5
782/782 [==============================] - 14s 13ms/step - loss: 0.5644 - precision: 0.7184 - recall: 0.8133 - binary_accuracy: 0.7092 - auc: 0.7688 - regularization_loss: 0.0000e+00 - val_loss: 0.5397 - val_precision: 0.7416 - val_recall: 0.8116 - val_binary_accuracy: 0.7290 - val_auc: 0.7955 - val_regularization_loss: 0.0000e+00
Epoch 2/5
782/782 [==============================] - 9s 11ms/step - loss: 0.5418 - precision: 0.7363 - recall: 0.8131 - binary_accuracy: 0.7250 - auc: 0.7907 - regularization_loss: 0.0000e+00 - val_loss: 0.5288 - val_precision: 0.7366 - val_recall: 0.8376 - val_binary_accuracy: 0.7343 - val_auc: 0.8043 - val_regularization_loss: 0.0000e+00
Epoch 3/5
782/782 [==============================] - 10s 12ms/step - loss: 0.5331 - precision: 0.7403 - recall: 0.8172 - binary_accuracy: 0.7299 - auc: 0.7984 - regularization_loss: 0.0000e+00 - val_loss: 0.5184 - val_precision: 0.7562 - val_recall: 0.8081 - val_binary_accuracy: 0.7398 - val_auc: 0.8119 - val_regularization_loss: 0.0000e+00
Epoch 4/5
782/782 [==============================] - 9s 12ms/step - loss: 0.5248 - precision: 0.7453 - recall: 0.8207 - binary_accuracy: 0.7355 - auc: 0.8058 - regularization_loss: 0.0000e+00 - val_loss: 0.5080 - val_precision: 0.7599 - val_recall: 0.8182 - val_binary_accuracy: 0.7468 - val_auc: 0.8203 - val_regularization_loss: 0.0000e+00
Epoch 5/5
782/782 [==============================] - 9s 12ms/step - loss: 0.5162 - precision: 0.7519 - recall: 0.8204 - binary_accuracy: 0.7410 - auc: 0.8133 - regularization_loss: 0.0000e+00 - val_loss: 0.4997 - val_precision: 0.7601 - val_recall: 0.8349 - val_binary_accuracy: 0.7534 - val_auc: 0.8287 - val_regularization_loss: 0.0000e+00
Epoch 1/3
782/782 [==============================] - 10s 12ms/step - loss: 0.4921 - precision: 0.7671 - recall: 0.8265 - binary_accuracy: 0.7558 - auc: 0.8325 - regularization_loss: 0.0000e+00 - val_loss: 0.4798 - val_precision: 0.7747 - val_recall: 0.8338 - val_binary_accuracy: 0.7649 - val_auc: 0.8422 - val_regularization_loss: 0.0000e+00
Epoch 2/3
782/782 [==============================] - 10s 12ms/step - loss: 0.4815 - precision: 0.7739 - recall: 0.8304 - binary_accuracy: 0.7628 - auc: 0.8406 - regularization_loss: 0.0000e+00 - val_loss: 0.4685 - val_precision: 0.7804 - val_recall: 0.8411 - val_binary_accuracy: 0.7725 - val_auc: 0.8505 - val_regularization_loss: 0.0000e+00
Epoch 3/3
782/782 [==============================] - 10s 12ms/step - loss: 0.4717 - precision: 0.7797 - recall: 0.8349 - binary_accuracy: 0.7694 - auc: 0.8478 - regularization_loss: 0.0000e+00 - val_loss: 0.4578 - val_precision: 0.7901 - val_recall: 0.8385 - val_binary_accuracy: 0.7790 - val_auc: 0.8580 - val_regularization_loss: 0.0000e+00
<keras.callbacks.History at 0x7f1c021e8b80>

Let us now retry the training without the genres information. Maybe we have reason to suspect this data has some issues and we would like to verify how the model performs without this information being passed in.

We can achieve all this by modifying the passed in schema using the without method.

import tensorflow
import merlin.models.tf as mm

model = mm.DLRMModel(
    train_transformed.schema.without('genres'),   # <--- this is where we make the change
    embedding_dim=64,
    bottom_block=mm.MLPBlock([128, 64]),
    top_block=mm.MLPBlock([128, 64, 32]),
    prediction_tasks=mm.BinaryClassificationTask('binary_rating')  # target column name as it appears in the schema
)

opt = tensorflow.optimizers.Adam(learning_rate=5e-3)
model.compile(optimizer=opt)
model.fit(train_transformed, validation_data=valid_transformed, batch_size=1024, epochs=5)

model.optimizer.learning_rate = 1e-3
metrics = model.fit(train_transformed, validation_data=valid_transformed, batch_size=1024, epochs=3)
Epoch 1/5
782/782 [==============================] - 10s 11ms/step - loss: 0.5666 - precision_1: 0.7193 - recall_1: 0.8067 - binary_accuracy: 0.7078 - auc_1: 0.7668 - regularization_loss: 0.0000e+00 - val_loss: 0.5415 - val_precision_1: 0.7386 - val_recall_1: 0.8128 - val_binary_accuracy: 0.7268 - val_auc_1: 0.7936 - val_regularization_loss: 0.0000e+00
Epoch 2/5
782/782 [==============================] - 8s 10ms/step - loss: 0.5451 - precision_1: 0.7343 - recall_1: 0.8115 - binary_accuracy: 0.7227 - auc_1: 0.7880 - regularization_loss: 0.0000e+00 - val_loss: 0.5314 - val_precision_1: 0.7508 - val_recall_1: 0.7996 - val_binary_accuracy: 0.7320 - val_auc_1: 0.8009 - val_regularization_loss: 0.0000e+00
Epoch 3/5
782/782 [==============================] - 8s 10ms/step - loss: 0.5374 - precision_1: 0.7398 - recall_1: 0.8116 - binary_accuracy: 0.7274 - auc_1: 0.7949 - regularization_loss: 0.0000e+00 - val_loss: 0.5239 - val_precision_1: 0.7529 - val_recall_1: 0.8055 - val_binary_accuracy: 0.7361 - val_auc_1: 0.8070 - val_regularization_loss: 0.0000e+00
Epoch 4/5
782/782 [==============================] - 8s 9ms/step - loss: 0.5307 - precision_1: 0.7423 - recall_1: 0.8171 - binary_accuracy: 0.7316 - auc_1: 0.8007 - regularization_loss: 0.0000e+00 - val_loss: 0.5171 - val_precision_1: 0.7567 - val_recall_1: 0.8113 - val_binary_accuracy: 0.7415 - val_auc_1: 0.8134 - val_regularization_loss: 0.0000e+00
Epoch 5/5
782/782 [==============================] - 8s 10ms/step - loss: 0.5233 - precision_1: 0.7476 - recall_1: 0.8185 - binary_accuracy: 0.7366 - auc_1: 0.8073 - regularization_loss: 0.0000e+00 - val_loss: 0.5092 - val_precision_1: 0.7577 - val_recall_1: 0.8232 - val_binary_accuracy: 0.7468 - val_auc_1: 0.8196 - val_regularization_loss: 0.0000e+00
Epoch 1/3
782/782 [==============================] - 8s 10ms/step - loss: 0.5038 - precision_1: 0.7594 - recall_1: 0.8249 - binary_accuracy: 0.7489 - auc_1: 0.8235 - regularization_loss: 0.0000e+00 - val_loss: 0.4938 - val_precision_1: 0.7656 - val_recall_1: 0.8327 - val_binary_accuracy: 0.7571 - val_auc_1: 0.8316 - val_regularization_loss: 0.0000e+00
Epoch 2/3
782/782 [==============================] - 8s 10ms/step - loss: 0.4953 - precision_1: 0.7637 - recall_1: 0.8316 - binary_accuracy: 0.7551 - auc_1: 0.8303 - regularization_loss: 0.0000e+00 - val_loss: 0.4856 - val_precision_1: 0.7670 - val_recall_1: 0.8440 - val_binary_accuracy: 0.7628 - val_auc_1: 0.8381 - val_regularization_loss: 0.0000e+00
Epoch 3/3
782/782 [==============================] - 8s 10ms/step - loss: 0.4881 - precision_1: 0.7674 - recall_1: 0.8360 - binary_accuracy: 0.7599 - auc_1: 0.8360 - regularization_loss: 0.0000e+00 - val_loss: 0.4781 - val_precision_1: 0.7727 - val_recall_1: 0.8451 - val_binary_accuracy: 0.7679 - val_auc_1: 0.8437 - val_regularization_loss: 0.0000e+00

As it turns out, genres contained signal that was genuinely useful to the model: without it, performance on the validation set decreased. After the final epoch, val_auc dropped from 0.8580 to 0.8437 and val_loss rose from 0.4578 to 0.4781.
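The comparison can be made explicit with a few lines of arithmetic. The values below are copied by hand from the last epoch of each training log above; in a live session they could instead be read from the History object returned by model.fit (its history attribute is a dictionary of per-epoch metric lists). The variable names are ours.

```python
# Final-epoch validation metrics, copied from the training logs above.
with_genres = {"val_loss": 0.4578, "val_auc": 0.8580, "val_binary_accuracy": 0.7790}
without_genres = {"val_loss": 0.4781, "val_auc": 0.8437, "val_binary_accuracy": 0.7679}

# Change caused by removing genres (negative means the metric went down).
deltas = {
    name: round(without_genres[name] - with_genres[name], 4)
    for name in with_genres
}

for name, delta in deltas.items():
    print(f"{name}: {delta:+.4f}")
```

Keep in mind that this is a single run on a random split, so small differences between the two models should be read with some caution.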

Using the Schema (combined with the ability of the Merlin Framework to infer the appropriate shape of the model based on our data) we can streamline the training and experimentation phase to a significant extent.

Next steps

We use NVTabular across the Merlin repositories to preprocess data and engineer features before training. The following examples show more complex NVTabular pipelines:

  • This notebook demonstrates how to aggregate data – we are going from multiple rows of session information to a single row describing a session.

  • The following notebook demonstrates how to derive features from timestamps and how to define a custom operator (ItemRecency).

  • In this notebook, among other functionality, you can familiarize yourself with using TargetEncoding, filling missing values, Normalization as well as adding various MetaData. Do note, this notebook is not maintained and thus there might be discrepancies between it and the most up to date version of NVTabular.