# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

Getting Started MovieLens: ETL with NVTabular#

This notebook is created using the latest stable merlin-hugectr, merlin-tensorflow, or merlin-pytorch container.

Overview#

NVTabular is a feature engineering and preprocessing library for tabular data designed to quickly and easily manipulate terabyte scale datasets used to train deep learning based recommender systems. It provides a high level abstraction to simplify code and accelerates computation on the GPU using the RAPIDS cuDF library.

Deep learning models require input features in a specific format. Categorical features need to be contiguous integers (0, …, |C|-1) so that they can be used with an embedding layer. We will use NVTabular to preprocess the categorical features.
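To make the requirement concrete, here is a minimal pure-Python sketch of an integer encoding for a single categorical column. It is not NVTabular's implementation: the helper names `build_vocab` and `encode` are hypothetical, the sample ids are made up, and assigning ids in sorted order is an assumption chosen only for illustration.

```python
def build_vocab(values):
    """Map each distinct category to a contiguous integer id (0 .. |C|-1)."""
    return {cat: idx for idx, cat in enumerate(sorted(set(values)))}


def encode(values, vocab):
    """Replace every category with its integer id."""
    return [vocab[v] for v in values]


user_ids = ["u42", "u7", "u42", "u99"]  # made-up raw ids
vocab = build_vocab(user_ids)
encoded = encode(user_ids, vocab)

print(vocab)
print(encoded)
```

The encoded values can be used directly as row indices into an embedding table with |C| rows.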

One other challenge is multi-hot categorical features. A product can have multiple categories assigned, but the number of categories per product varies. For example, a movie can have one or multiple genres:

  • Father of the Bride Part II: [Comedy]

  • Toy Story: [Adventure, Animation, Children, Comedy, Fantasy]

  • Jumanji: [Adventure, Children, Fantasy]

A common strategy is to use only the first category or only the most frequent ones. However, a better strategy is to use all provided categories per datapoint. RAPIDS cuDF added list support in release v0.16, and NVTabular now supports multi-hot categorical features.
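As an illustration of what using all categories means for the data layout, the sketch below encodes the three example movies as variable-length lists of integer ids and then flattens them into a values list plus row offsets, a common way to store ragged lists in columnar formats. The names `example_movies`, `genre_vocab`, `values` and `offsets` are hypothetical, and this is not a description of cuDF internals.

```python
example_movies = {
    "Father of the Bride Part II": ["Comedy"],
    "Toy Story": ["Adventure", "Animation", "Children", "Comedy", "Fantasy"],
    "Jumanji": ["Adventure", "Children", "Fantasy"],
}

# Contiguous ids for every genre that appears at least once.
all_genres = sorted({g for genres in example_movies.values() for g in genres})
genre_vocab = {g: i for i, g in enumerate(all_genres)}

# Multi-hot encoding: one variable-length list of ids per movie.
encoded = {
    title: [genre_vocab[g] for g in genres]
    for title, genres in example_movies.items()
}

# Ragged lists flattened into values + offsets;
# row i is values[offsets[i]:offsets[i + 1]].
values, offsets = [], [0]
for ids in encoded.values():
    values.extend(ids)
    offsets.append(len(values))

print(encoded)
print(values, offsets)
```

No genre is dropped, and rows of different lengths share one flat buffer.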

Learning objectives#

In this notebook, we learn how to Categorify single-hot and multi-hot categorical input features with NVTabular:

  • Learn to use NVTabular for GPU-accelerated ETL (preprocessing and feature engineering)

  • Get familiar with NVTabular’s high-level API

  • Join two dataframes with the JoinExternal operator

  • Preprocess single-hot categorical input features with NVTabular

  • Preprocess multi-hot categorical input features with NVTabular

  • Use LambdaOp for custom row-wise dataframe manipulations with NVTabular

For more information about NVTabular, refer to the documentation or the GitHub repository.

ETL with NVTabular#

# External dependencies
import os
import shutil
import numpy as np
from nvtabular.ops import *
from merlin.schema.tags import Tags

import nvtabular as nvt

from os import path

# Get dataframe library - cudf or pandas
from merlin.core.dispatch import get_lib
df_lib = get_lib()

We define our base input directory, containing the data.

INPUT_DATA_DIR = os.environ.get(
    "INPUT_DATA_DIR", os.path.expanduser("/workspace/nvt-examples/movielens/data/")
)
movies = df_lib.read_parquet(os.path.join(INPUT_DATA_DIR, "movies_converted.parquet"))
movies.head()
movieId genres
0 1 [Adventure, Animation, Children, Comedy, Fantasy]
1 2 [Adventure, Children, Fantasy]
2 3 [Comedy, Romance]
3 4 [Comedy, Drama, Romance]
4 5 [Comedy]

Defining our Preprocessing Pipeline#

The first step is to define the feature engineering and preprocessing pipeline.

NVTabular already implements many common calculations, called ops. An op is applied to a ColumnGroup with the overloaded >> operator, which in turn returns a new ColumnGroup. A ColumnGroup is a list of column names given as strings.

Example:

features = [ column_name, ...] >> op1 >> op2 >> ...
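The chaining syntax relies on Python operator overloading. The following toy sketch shows one way a plain list of column names can be combined with ops through `>>`. The classes `Node` and `Op` are hypothetical stand-ins written for this illustration and do not reflect NVTabular's actual classes.

```python
class Node:
    """Hypothetical stand-in for a ColumnGroup: column names plus applied ops."""

    def __init__(self, columns, ops=()):
        self.columns = list(columns)
        self.ops = tuple(ops)

    def __rshift__(self, op):
        # `node >> op` returns a new node; the original is left untouched.
        return Node(self.columns, self.ops + (op,))


class Op:
    """Hypothetical stand-in for an op: here just a named step."""

    def __init__(self, name):
        self.name = name

    def __rrshift__(self, columns):
        # Called for `["col", ...] >> op`, because a plain list
        # does not define `>>` itself.
        return Node(columns, (self,))


features = ["userId", "movieId"] >> Op("op1") >> Op("op2")
print(features.columns, [op.name for op in features.ops])
```

Because every `>>` returns a new object, intermediate results can be reused as the starting point of several branches.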

This may sound more complicated than it is. Let’s define our first pipeline for the MovieLens dataset.

Currently, our dataset consists of two separate dataframes. First, we use the JoinExternal operator to left-join the metadata (genres) to our rating dataset.

As we process the data, we are also adding tags. They will be helpful to our model during training and inference. Once specified here, we will not have to provide this information down the road.

CATEGORICAL_COLUMNS = ["userId", "movieId"]
LABEL_COLUMNS = ["rating"]
userId = ["userId"] >> TagAsUserID()
movieId = ["movieId"] >> TagAsItemID()

joined = userId + movieId >> JoinExternal(movies, on=["movieId"])
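For readers unfamiliar with left joins, the following pure-Python sketch shows the semantics on a few rows. The helper `left_join` and the sample ratings are hypothetical; the sketch only illustrates the join logic and does not show how JoinExternal is implemented.

```python
sample_ratings = [  # made-up rows
    {"userId": 1, "movieId": 1, "rating": 4.0},
    {"userId": 1, "movieId": 5, "rating": 2.5},
    {"userId": 2, "movieId": 99, "rating": 5.0},  # no metadata for movie 99
]
movie_meta = [
    {"movieId": 1, "genres": ["Adventure", "Animation", "Children", "Comedy", "Fantasy"]},
    {"movieId": 5, "genres": ["Comedy"]},
]


def left_join(left_rows, right_rows, on):
    """Keep every left row; attach right columns, or None when nothing matches."""
    lookup = {row[on]: row for row in right_rows}
    right_cols = [col for col in right_rows[0] if col != on]
    joined_rows = []
    for row in left_rows:
        match = lookup.get(row[on])
        extra = {col: (match[col] if match else None) for col in right_cols}
        joined_rows.append({**row, **extra})
    return joined_rows


joined_rows = left_join(sample_ratings, movie_meta, on="movieId")
for row in joined_rows:
    print(row)
```

Every rating survives the join, and a rating whose movie has no metadata simply gets a missing `genres` value.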

Data pipelines are Directed Acyclic Graphs (DAGs). We can visualize them with graphviz.

joined.graph

Embedding layers of neural networks require categorical features to be contiguous, incremental integers: 0, 1, 2, …, |C|-1. We need to ensure that our categorical features fulfill this requirement.

Currently, our genres are lists of strings. In addition, we should transform the single-hot categorical features userId and movieId as well.
NVTabular provides the Categorify operator, which offers this functionality through a high-level API out of the box. In NVTabular release v0.3, list support was added for multi-hot categorical features. Both kinds of features are handled in the same way, with no changes needed.

Next, we will add Categorify for our categorical features (single hot: userId, movieId and multi-hot: genres).

cat_features = joined >> Categorify()

The ratings are on a scale from 1 to 5. We want to predict a binary target with 1 for ratings > 3 and 0 for ratings <= 3. We use LambdaOp for this.

ratings = nvt.ColumnGroup(["rating"]) >> LambdaOp(lambda col: (col > 3).astype("int8")) >> AddTags(Tags.TARGET)
output = cat_features + ratings
(output).graph
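The thresholding performed by the lambda can be checked on a handful of values without a GPU. This is a plain-Python sketch of the same rule; the helper name `to_binary_target` and the sample ratings are made up for illustration.

```python
sample_ratings = [1.0, 3.0, 3.5, 4.0, 5.0]  # made-up ratings


def to_binary_target(values, threshold=3):
    """Return 1 for ratings strictly greater than the threshold, else 0."""
    return [int(v > threshold) for v in values]


targets = to_binary_target(sample_ratings)
print(targets)
```

Note that a rating of exactly 3 falls into the negative class because the comparison is strict.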

We initialize our NVTabular workflow.

workflow = nvt.Workflow(output)

Running the pipeline#

In general, the ops in our Workflow need statistics about our data before they can be applied. For example, the Normalize op requires the dataset mean and standard deviation, and the Categorify op requires the full set of categories that a feature can take. However, we frequently need to measure these properties across datasets that are too large to fit into GPU memory (or CPU memory, for that matter) at once.

NVTabular solves this by providing the Dataset class, which breaks a set of parquet or csv files into a collection of cudf.DataFrame chunks that can fit in device memory. The main purpose of this class is to abstract away the raw format of the data, and to allow other NVTabular classes to reliably materialize a dask_cudf.DataFrame collection (and/or collection-based iterator) on demand. Under the hood, the data decomposition corresponds to the construction of a dask_cudf.DataFrame object. By representing our dataset as a lazily-evaluated Dask collection, we can handle the calculation of complex global statistics (and later, can also iterate over the partitions while feeding data into a neural network). part_size defines the size of each chunk read into GPU memory at once.
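The idea of computing a global statistic without holding the whole dataset in memory can be sketched in a few lines of plain Python. The helpers `iter_chunks` and `global_mean_std` are hypothetical and only mimic partition-wise processing; they are not how NVTabular or Dask compute statistics, and the sum-of-squares formula used here is the simplest option rather than the most numerically robust one.

```python
import math


def iter_chunks(values, chunk_size):
    """Yield successive slices, standing in for partitions read from disk."""
    for start in range(0, len(values), chunk_size):
        yield values[start:start + chunk_size]


def global_mean_std(chunks):
    """Accumulate count, sum and sum of squares one chunk at a time."""
    n, total, total_sq = 0, 0.0, 0.0
    for chunk in chunks:
        n += len(chunk)
        total += sum(chunk)
        total_sq += sum(v * v for v in chunk)
    mean = total / n
    variance = total_sq / n - mean * mean  # population variance
    return mean, math.sqrt(max(variance, 0.0))


data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]  # made-up values
mean, std = global_mean_std(iter_chunks(data, chunk_size=3))
print(mean, std)
```

Only three running totals are kept between chunks, so memory use does not depend on the size of the dataset.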

Now we instantiate dataset iterators to loop through our dataset (which we couldn’t fit into GPU memory). HugeCTR expects the categorical input columns as int64 and the continuous/label columns as float32. We need to enforce the required HugeCTR data types, so we set them in a dictionary and pass it as an argument when writing the transformed dataset.

dict_dtypes = {}

for col in CATEGORICAL_COLUMNS:
    dict_dtypes[col] = np.int64

for col in LABEL_COLUMNS:
    dict_dtypes[col] = np.float32
train_dataset = nvt.Dataset([os.path.join(INPUT_DATA_DIR, "train.parquet")])
valid_dataset = nvt.Dataset([os.path.join(INPUT_DATA_DIR, "valid.parquet")])

Now that we have our datasets, we’ll apply our Workflow to them and save the results out to parquet files for fast reading at train time. Similar to the scikit-learn API, we collect the statistics of our train dataset with .fit.

%%time
workflow.fit(train_dataset)
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.USER: 'user'>, <Tags.ID: 'id'>].
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
CPU times: user 692 ms, sys: 354 ms, total: 1.05 s
Wall time: 1.06 s
<nvtabular.workflow.workflow.Workflow at 0x7fb2d7fc42b0>
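The fit/transform split matters because statistics are collected from the training split only and then reused everywhere else. The toy encoder below sketches that pattern. `TinyCategorify` is a hypothetical class, and reserving id 0 for values unseen during fitting is an assumption of this sketch, not a statement about how Categorify assigns ids.

```python
class TinyCategorify:
    """Hypothetical, simplified encoder following the fit/transform pattern."""

    def __init__(self):
        self.vocab = {}

    def fit(self, values):
        # Ids start at 1; id 0 is reserved for unseen values (sketch assumption).
        distinct = sorted(set(values))
        self.vocab = {cat: idx for idx, cat in enumerate(distinct, start=1)}
        return self

    def transform(self, values):
        return [self.vocab.get(v, 0) for v in values]


train_ids = [10, 30, 20, 10]  # made-up raw ids
valid_ids = [20, 40]          # 40 never appears in the training split

encoder = TinyCategorify().fit(train_ids)
train_encoded = encoder.transform(train_ids)
valid_encoded = encoder.transform(valid_ids)
print(train_encoded, valid_encoded)
```

Fitting once and transforming both splits with the same mapping keeps the ids consistent between training and validation.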

We clear our output directories.

# Make sure we have a clean output path
if path.exists(os.path.join(INPUT_DATA_DIR, "train")):
    shutil.rmtree(os.path.join(INPUT_DATA_DIR, "train"))
if path.exists(os.path.join(INPUT_DATA_DIR, "valid")):
    shutil.rmtree(os.path.join(INPUT_DATA_DIR, "valid"))

We transform our datasets with the fitted workflow using .transform. We add the 'userId', 'movieId', and 'genres' columns to _metadata.json, because HugeCTR training needs this json file to obtain the required information about the rows in each parquet file.

%%time
workflow.transform(train_dataset).to_parquet(
    output_path=os.path.join(INPUT_DATA_DIR, "train"),
    shuffle=nvt.io.Shuffle.PER_PARTITION,
    cats=["userId", "movieId", "genres"],
    labels=["rating"],
    dtypes=dict_dtypes,
    # write_hugectr_keyset is only needed when this ETL notebook feeds HugeCTR training;
    # remove it otherwise to speed up computation.
    write_hugectr_keyset=True,
)

%%time
workflow.transform(valid_dataset).to_parquet(
    output_path=os.path.join(INPUT_DATA_DIR, "valid"),
    shuffle=False,
    cats=["userId", "movieId", "genres"],
    labels=["rating"],
    dtypes=dict_dtypes,
    # write_hugectr_keyset is only needed when this ETL notebook feeds HugeCTR training;
    # remove it otherwise to speed up computation.
    write_hugectr_keyset=True,
)
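The training split is written with a per-partition shuffle while the validation split is left in order. Going by the option's name, we assume this means that rows are shuffled within each partition and never moved across partitions; the sketch below illustrates that reading with a hypothetical helper, `shuffle_per_partition`, and made-up partitions.

```python
import random


def shuffle_per_partition(partitions, seed=0):
    """Shuffle rows inside each partition; membership and partition order stay fixed."""
    rng = random.Random(seed)
    shuffled = []
    for part in partitions:
        rows = list(part)  # copy so the input is not modified
        rng.shuffle(rows)
        shuffled.append(rows)
    return shuffled


partitions = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]  # made-up row ids
result = shuffle_per_partition(partitions)
print(result)
```

Shuffling this way needs only one partition in memory at a time, at the cost of a less thorough mix than a global shuffle.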

In the next notebooks, we will train a deep learning model. Our training pipeline requires information about the data schema to define the neural network architecture. We will save the NVTabular workflow to disk so that we can restore it in the next notebooks.

workflow.save(os.path.join(INPUT_DATA_DIR, "workflow"))
workflow.output_schema
| | name | tags | dtype | is_list | is_ragged | properties.num_buckets | properties.freq_threshold | properties.max_size | properties.start_index | properties.cat_path | properties.domain.min | properties.domain.max | properties.domain.name | properties.embedding_sizes.cardinality | properties.embedding_sizes.dimension |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | userId | (Tags.CATEGORICAL) | DType(name='int64', element_type=<ElementType.... | False | False | NaN | 0.0 | 0.0 | 0.0 | .//categories/unique.userId.parquet | 0.0 | 162541.0 | userId | 162542.0 | 512.0 |
| 1 | movieId | (Tags.CATEGORICAL) | DType(name='int64', element_type=<ElementType.... | False | False | NaN | 0.0 | 0.0 | 0.0 | .//categories/unique.movieId.parquet | 0.0 | 56658.0 | movieId | 56659.0 | 512.0 |
| 2 | genres | (Tags.CATEGORICAL) | DType(name='int64', element_type=<ElementType.... | True | True | NaN | 0.0 | 0.0 | 0.0 | .//categories/unique.genres.parquet | 0.0 | 20.0 | genres | 21.0 | 16.0 |
| 3 | rating | (Tags.TARGET) | DType(name='int8', element_type=<ElementType.I... | False | False | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
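The embedding dimensions in the schema look consistent with a clamped power-law heuristic. The sketch below is a reconstruction that reproduces the three values shown above. The function name `embedding_dim`, the constants 1.6 and 0.56, and the bounds 16 and 512 are assumptions on our part and should not be read as a documented guarantee of the library.

```python
def embedding_dim(cardinality, minimum_size=16, maximum_size=512):
    """Grow with cardinality ** 0.56, clamped to [minimum_size, maximum_size]."""
    suggested = round(1.6 * cardinality ** 0.56)
    return min(max(minimum_size, suggested), maximum_size)


# Cardinalities taken from the output schema above.
cardinalities = {"userId": 162542, "movieId": 56659, "genres": 21}
dims = {name: embedding_dim(n) for name, n in cardinalities.items()}
print(dims)
```

Under this heuristic, high-cardinality features such as userId and movieId hit the upper bound, while the small genres vocabulary is lifted to the lower bound.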

Checking the pre-processing outputs#

We can take a look at the data.

import glob

TRAIN_PATHS = sorted(glob.glob(os.path.join(INPUT_DATA_DIR, "train", "*.parquet")))
VALID_PATHS = sorted(glob.glob(os.path.join(INPUT_DATA_DIR, "valid", "*.parquet")))
TRAIN_PATHS, VALID_PATHS
(['/workspace/nvt-examples/movielens/data/train/part_0.parquet'],
 ['/workspace/nvt-examples/movielens/data/valid/part_0.parquet'])

We can see that genres are now lists of integers.

df = df_lib.read_parquet(TRAIN_PATHS[0])
df.head()
userId movieId genres rating
0 1691 332 [1] 1.0
1 1001 154 [2, 6] 1.0
2 967 245 [3, 2] 0.0
3 150851 622 [3, 5, 7, 15] 1.0
4 39553 1146 [3, 1] 1.0

Next Steps#

You can read the API documentation for the operators and classes used in this notebook.

The next step for learning to use Merlin for creating a recommender system is to train a model. Refer to Training with TensorFlow, Training with HugeCTR, or Training with PyTorch.