# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

Getting Started MovieLens: Training with PyTorch#

This notebook is created using the latest stable merlin-pytorch container.

Overview#

We often observe that PyTorch training pipelines are bottlenecked by the dataloader: the native PyTorch dataloader samples items from the dataset one at a time, which is slow. In our experiments, we were able to speed up existing PyTorch pipelines by using a highly optimized dataloader.

In this tutorial, we use the highly optimized Merlin Dataloader. To learn more about it, please consult the examples in its repository.

Learning objectives#

This notebook explains how to use the Merlin dataloader to accelerate PyTorch training.

  1. Use Merlin dataloader with PyTorch

  2. Leverage multi-hot encoded input features

MovieLens25M#

MovieLens25M is a popular dataset for recommender systems and is used in many academic publications. It contains 25M movie ratings for 62,000 movies given by 162,000 users. Many projects use only the user/item/rating information of MovieLens, but the original dataset also provides metadata for the movies, for example, the genres assigned to each movie. Although we may not improve on state-of-the-art results with our neural network architecture, the purpose of this notebook is to explain how to integrate multi-hot categorical features into a neural network.

# External dependencies
import os
import gc
import glob

import nvtabular as nvt
from merlin.schema.tags import Tags

We define our base directory, which contains the data.

INPUT_DATA_DIR = os.environ.get(
    "INPUT_DATA_DIR", os.path.expanduser("/workspace/nvt-examples/movielens/data/")
)

And the paths to our train and validation datasets.

# Output from ETL-with-NVTabular
TRAIN_PATHS = sorted(glob.glob(os.path.join(INPUT_DATA_DIR, "train", "*.parquet")))
VALID_PATHS = sorted(glob.glob(os.path.join(INPUT_DATA_DIR, "valid", "*.parquet")))

Initializing the Merlin Dataloader for PyTorch#

import torch
from merlin.loader.torch import Loader

from nvtabular.framework_utils.torch.models import Model
from nvtabular.framework_utils.torch.utils import process_epoch

First, we take a look at our dataloader and how the data is represented as tensors. The Merlin dataloader automatically recognizes single-hot and multi-hot columns and represents them accordingly.

BATCH_SIZE = 1024 * 32  # Batch Size

train_dataset = nvt.Dataset(TRAIN_PATHS)
validation_dataset = nvt.Dataset(VALID_PATHS)

train_loader = Loader(
    train_dataset,
    batch_size=BATCH_SIZE,
)

valid_loader = Loader(
    validation_dataset,
    batch_size=BATCH_SIZE,
)

Let’s generate a batch and take a look at the input features.

The single-hot categorical features (userId and movieId) have the shape (32768, 1), where 32768 is the batch size.

For the multi-hot categorical feature genres, we receive a tuple of two tensors. The first tensor contains the actual data, the genre IDs. Note that this tensor can hold more values than the batch size, because a single datapoint in the batch can have more than one genre (multi-hot).

The second tensor is a supporting offsets tensor: for each example, it gives the starting index of that example's genres within the first tensor.

For example, suppose the offsets tensor starts with the values 0, 2, 6, 7 (a short decoding sketch follows this list):

  • the first two values are 0 and 2, so the first two values (indices 0 and 1) in the first tensor belong to the first datapoint in the batch (the same row as the first userId/movieId);

  • the third value is 6, so the values at indices 2 through 5 (the 3rd to 6th values) in the first tensor belong to the second datapoint;

  • the fourth value is 7, so the value at index 6 (the 7th value) belongs to the third datapoint;

  • and so on
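
The following is a minimal sketch of how to decode this (values, offsets) layout back into per-row lists of genre IDs. The tensors are hypothetical, matching the example above, not taken from the real batch below.

import torch

# Hypothetical flat genre IDs for three rows and their start offsets.
# Row i spans values[offsets[i]:offsets[i + 1]]; the last row runs to
# the end of the values tensor.
values = torch.tensor([1, 2, 6, 8, 3, 9, 4])
offsets = torch.tensor([0, 2, 6])

lengths = torch.diff(offsets, append=torch.tensor([len(values)]))
rows = torch.split(values, lengths.tolist())
print(rows)  # (tensor([1, 2]), tensor([6, 8, 3, 9]), tensor([4]))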

batch = next(iter(train_loader))
batch
({'genres': (tensor([1, 2, 6,  ..., 8, 1, 4], device='cuda:0'),
   tensor([[    0],
           [    1],
           [    3],
           ...,
           [88555],
           [88556],
           [88557]], device='cuda:0', dtype=torch.int32)),
  'userId': tensor([[1691],
          [1001],
          [ 967],
          ...,
          [ 848],
          [1847],
          [5456]], device='cuda:0'),
  'movieId': tensor([[ 332],
          [ 154],
          [ 245],
          ...,
          [3095],
          [1062],
          [3705]], device='cuda:0')},
 tensor([1., 1., 0.,  ..., 1., 1., 0.], device='cuda:0'))

As each datapoint can have a different number of genres, it is more efficient to represent the genres as two flat tensors: one with the actual values (the first tensor) and one with the starting offset of each datapoint (the second tensor).

del batch
gc.collect()
71

Defining Neural Network Architecture#

We implemented a simple PyTorch architecture.

  • Single-hot categorical features are fed into an Embedding Layer

  • Each value of a multi-hot categorical feature is fed into an Embedding Layer, and the multiple embedding outputs are combined by summing (a standalone sketch of this step follows below)

  • The outputs of the Embedding Layers are concatenated

  • The concatenated embeddings are fed through multiple feed-forward layers (Linear layers with ReLU activations, BatchNorm, and Dropout)

You can see more details by checking out the implementation.

# ??Model
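
As a standalone illustration of the multi-hot pooling step (a minimal sketch with hypothetical inputs, not the Model source), an EmbeddingBag in sum mode looks up one embedding per genre ID and sums them per row, consuming exactly the (values, offsets) layout shown above:

import torch

# Hypothetical example; 21 and 16 match the cardinality and embedding
# size computed for genres from the schema below.
genre_embedding = torch.nn.EmbeddingBag(num_embeddings=21, embedding_dim=16, mode="sum")

values = torch.tensor([1, 2, 6, 8, 3, 9, 4])  # flat genre IDs for three rows
offsets = torch.tensor([0, 2, 6])             # start index of each row

pooled = genre_embedding(values, offsets)     # shape (3, 16): one summed vector per row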

In order to initialize the model, we need to provide the cardinality and embedding size of each categorical feature.

Let’s capture them from our schema file and store this information in dicts.

def extract_info(col_name, schema):
    '''extracts embedding cardinality and dimension from schema'''
    return (
        int(schema.select_by_name(col_name).first.properties['embedding_sizes']['cardinality']),
        int(schema.select_by_name(col_name).first.properties['embedding_sizes']['dimension'])
    )

single_hot_embedding_tables_shapes = {
    col_name: extract_info(col_name, train_loader.dataset.schema) for col_name in ['userId', 'movieId']
}
multi_hot_embedding_tables_shapes = {
    col_name: extract_info(col_name, train_loader.dataset.schema) for col_name in ['genres']
}
single_hot_embedding_tables_shapes, multi_hot_embedding_tables_shapes
({'userId': (162542, 512), 'movieId': (56659, 512)}, {'genres': (21, 16)})
model = Model(
    embedding_table_shapes=(single_hot_embedding_tables_shapes, multi_hot_embedding_tables_shapes),
    num_continuous=0,
    emb_dropout=0.0,
    layer_hidden_dims=[128, 128, 128],
    layer_dropout_rates=[0.0, 0.0, 0.0],
).to("cuda")
model
Model(
  (initial_cat_layer): ConcatenatedEmbeddings(
    (embedding_layers): ModuleList(
      (0): Embedding(162542, 512)
      (1): Embedding(56659, 512)
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (mh_cat_layer): MultiHotEmbeddings(
    (embedding_layers): ModuleList(
      (0): EmbeddingBag(21, 16, mode=sum)
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (initial_cont_layer): BatchNorm1d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layers): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=1040, out_features=128, bias=True)
      (1): ReLU(inplace=True)
      (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.0, inplace=False)
    )
    (1): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU(inplace=True)
      (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.0, inplace=False)
    )
    (2): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU(inplace=True)
      (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.0, inplace=False)
    )
  )
  (output_layer): Linear(in_features=128, out_features=1, bias=True)
)
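
Note that the first Linear layer has in_features=1040: the concatenated embedding outputs of userId (512), movieId (512), and the pooled genres embedding (16) add up to 512 + 512 + 16 = 1040.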

We initialize the optimizer.

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

We use the process_epoch function to train and validate our model. It iterates over the dataset, computes the loss, and performs the optimizer step.

%%time
from time import time
EPOCHS = 1
for epoch in range(EPOCHS):
    start = time()
    train_loss, y_pred, y = process_epoch(train_loader,
                                          model,
                                          train=True,
                                          optimizer=optimizer,
                                          loss_func=torch.nn.BCEWithLogitsLoss())
    valid_loss, y_pred, y = process_epoch(valid_loader,
                                          model,
                                          train=False)
    print(f"Epoch {epoch:02d}. Train loss: {train_loss:.4f}. Valid loss: {valid_loss:.4f}.")
Total batches: 610
Total batches: 152
Epoch 00. Train loss: 0.5204. Valid loss: 2.2798.
CPU times: user 17 s, sys: 323 ms, total: 17.4 s
Wall time: 17.3 s
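
For reference, here is a minimal sketch of the per-batch work that process_epoch performs in training mode. It is not the process_epoch source: it assumes each batch from the loader is a (features, target) pair as shown earlier, and uses a hypothetical forward_fn stand-in for the model's forward call, whose exact signature may differ.

import torch

loss_func = torch.nn.BCEWithLogitsLoss()

def train_one_epoch(loader, forward_fn, optimizer):
    losses = []
    for features, target in loader:
        optimizer.zero_grad()
        logits = forward_fn(features).squeeze()  # raw scores, shape (batch_size,)
        loss = loss_func(logits, target)
        loss.backward()   # backpropagate
        optimizer.step()  # update parameters
        losses.append(loss.item())
    return sum(losses) / len(losses)  # mean training loss over the epoch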