[1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Getting Started MovieLens: Training with TensorFlow
Overview
We observed that TensorFlow training pipelines can be slow because the dataloader is a bottleneck. The native dataloader in TensorFlow randomly samples each item from the dataset, which is very slow. The window dataloader in TensorFlow is not much faster. In our experiments, we were able to speed up existing TensorFlow pipelines by 9x using a highly optimized dataloader.
Applying deep learning models to recommendation systems faces unique challenges in comparison to other domains, such as computer vision and natural language processing. The datasets and common model architectures have unique characteristics, which require custom solutions. Recommendation system datasets can be terabytes in size with billions of examples, but each example is represented by only a few bytes. For example, the Criteo CTR dataset, the largest publicly available dataset, is 1.3TB with 4 billion examples. The model architectures normally have large embedding tables for the users and items, which do not fit on a single GPU. You can read more in our blog post.
Learning objectives
This notebook explains how to use the NVTabular dataloader to accelerate TensorFlow training.
1. Use the NVTabular dataloader with a TensorFlow Keras model
2. Leverage multi-hot encoded input features
MovieLens25M
MovieLens25M is a popular dataset for recommender systems and is used in academic publications. The dataset contains 25M movie ratings for 62,000 movies given by 162,000 users. Many projects use only the user/item/rating information of MovieLens, but the original dataset provides movie metadata as well, such as the genres of each movie. Although we may not improve on state-of-the-art results with our neural network architecture, the purpose of this notebook is to explain how to integrate multi-hot categorical features into a neural network.
NVTabular dataloader for TensorFlow
We’ve identified that the dataloader is one bottleneck in deep learning recommender systems when training pipelines with TensorFlow. The dataloader cannot prepare the next batch fast enough and therefore, the GPU is not fully utilized.
We developed a highly customized tabular dataloader for accelerating existing TensorFlow pipelines. In our experiments, we see a 9x speed-up of the same training workflow with the NVTabular dataloader. The NVTabular dataloader's features are:
- removes the bottleneck of item-by-item dataloading
- enables larger-than-memory datasets by streaming from disk
- reads data directly into GPU memory, removing CPU-GPU communication
- prepares batches asynchronously on the GPU to avoid CPU-GPU communication
- supports the commonly used .parquet format
- integrates easily into existing TensorFlow pipelines by using a similar API
- works with tf.keras models
You can find more information in our blog post.
Getting Started
[2]:
# External dependencies
import gc
import glob
import os
import time

import nvtabular as nvt
We define our base directory, containing the data.
[3]:
BASE_DIR = os.path.expanduser("~/nvt-examples/movielens/data")
Defining Hyperparameters
First, we define the data schema and differentiate between single-hot and multi-hot categorical features. Note that we do not have any numerical input features.
[4]:
BATCH_SIZE = 1024*32 # Batch Size
CATEGORICAL_COLUMNS = ['movieId', 'userId'] # Single-hot
CATEGORICAL_MH_COLUMNS = ['genres'] # Multi-hot
NUMERIC_COLUMNS = []
# Output from ETL-with-NVTabular
TRAIN_PATHS = sorted(glob.glob(os.path.join(BASE_DIR, "train", "*.parquet")))
VALID_PATHS = sorted(glob.glob(os.path.join(BASE_DIR, "valid", "*.parquet")))
In the previous notebook, we used NVTabular for ETL and stored the workflow to disk. We can load the NVTabular workflow to extract important metadata for our training pipeline.
[5]:
proc = nvt.Workflow.load(os.path.join(BASE_DIR, "workflow"))
The embedding table shows the cardinality of each categorical variable along with its associated embedding size. Each entry is of the form (cardinality, embedding_size).
[6]:
EMBEDDING_TABLE_SHAPES = nvt.ops.get_embedding_sizes(proc)
EMBEDDING_TABLE_SHAPES
[6]:
{'genres': (21, 16), 'movieId': (56586, 512), 'userId': (162542, 512)}
Initializing the NVTabular Dataloader for TensorFlow
We import TensorFlow and some NVTabular TF extensions, such as custom TensorFlow layers that support multi-hot features, and the NVTabular TensorFlow data loader.
[7]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# we can control how much memory to give tensorflow with this environment variable
# IMPORTANT: make sure you do this before you initialize TF's runtime, otherwise
# TF will have claimed all free GPU memory
os.environ["TF_MEMORY_ALLOCATION"] = "0.8"  # fraction of free memory

import tensorflow as tf
from tensorflow.python.feature_column import feature_column_v2 as fc

from nvtabular.loader.tensorflow import KerasSequenceLoader, KerasSequenceValidater
from nvtabular.framework_utils.tensorflow import layers
First, we take a look at our data loader and how the data is represented as tensors. The NVTabular data loader is initialized as usual, and we specify both single-hot and multi-hot categorical features as cat_names. The data loader will automatically recognize the single/multi-hot columns and represent them accordingly.
[8]:
train_dataset_tf = KerasSequenceLoader(
    TRAIN_PATHS,  # you could also use a glob pattern
    batch_size=BATCH_SIZE,
    label_names=['rating'],
    cat_names=CATEGORICAL_COLUMNS + CATEGORICAL_MH_COLUMNS,
    cont_names=NUMERIC_COLUMNS,
    engine='parquet',
    shuffle=True,
    buffer_size=0.06,  # amount of data to buffer; values < 1 are a fraction of GPU memory
    parts_per_chunk=1,
)

valid_dataset_tf = KerasSequenceLoader(
    VALID_PATHS,  # you could also use a glob pattern
    batch_size=BATCH_SIZE,
    label_names=['rating'],
    cat_names=CATEGORICAL_COLUMNS + CATEGORICAL_MH_COLUMNS,
    cont_names=NUMERIC_COLUMNS,
    engine='parquet',
    shuffle=False,
    buffer_size=0.06,
    parts_per_chunk=1,
)
Let’s generate a batch and take a look at the input features. We can see that the single-hot categorical features (userId and movieId) have a shape of (32768, 1), which is the batch size (as usual). For the multi-hot categorical feature genres, we receive two tensors, genres__values and genres__nnzs. genres__values contains the actual data, the genre IDs. Note that this tensor has more values than the batch size, because one datapoint in the batch can contain more than one genre (multi-hot). genres__nnzs is a supporting tensor describing how many genres are associated with each datapoint in the batch. For example:
- if the first value in genres__nnzs is 5, then the first 5 values in genres__values are associated with the first datapoint in the batch (movieId/userId).
- if the second value in genres__nnzs is 2, then the 6th and 7th values in genres__values are associated with the second datapoint in the batch (continuing after the previous datapoint stopped).
- if the third value in genres__nnzs is 1, then the 8th value in genres__values is associated with the third datapoint in the batch.
- and so on
[9]:
batch = next(iter(train_dataset_tf))
batch[0]
[9]:
{'genres__values': <tf.Tensor: shape=(88761, 1), dtype=int64, numpy=
array([[ 6],
[16],
[ 9],
...,
[ 2],
[ 9],
[19]])>,
'genres__nnzs': <tf.Tensor: shape=(32768, 1), dtype=int32, numpy=
array([[2],
[1],
[3],
...,
[4],
[3],
[3]], dtype=int32)>,
'movieId': <tf.Tensor: shape=(32768, 1), dtype=int64, numpy=
array([[ 560],
[ 3067],
[ 900],
...,
[ 1121],
[ 5508],
[22156]])>,
'userId': <tf.Tensor: shape=(32768, 1), dtype=int64, numpy=
array([[ 52292],
[153071],
[135484],
...,
[117318],
[ 52951],
[143658]])>}
We can see that the sum of genres__nnzs is equal to the shape of genres__values.
[10]:
tf.reduce_sum(batch[0]['genres__nnzs'])
[10]:
<tf.Tensor: shape=(), dtype=int32, numpy=88761>
As each datapoint can have a different number of genres, it is more efficient to represent the genres as two flat tensors: one with the actual values (genres__values) and one with the number of values for each datapoint (genres__nnzs).
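The way the offsets in genres__nnzs carve the flat genres__values tensor into per-datapoint groups can be sketched in plain Python (a simplified illustration of the semantics, not NVTabular code):

```python
def split_by_nnzs(values, nnzs):
    """Group the flat `values` list into one sublist per datapoint."""
    out, start = [], 0
    for n in nnzs:
        out.append(values[start:start + n])
        start += n
    return out

# Toy numbers following the explanation above: three datapoints
# with 2, 1, and 3 genres respectively
split_by_nnzs([6, 16, 9, 2, 9, 19], [2, 1, 3])
# → [[6, 16], [9], [2, 9, 19]]
```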
[11]:
del batch
gc.collect()
[11]:
61
Defining Neural Network Architecture
We will define a common neural network architecture for tabular data.
* Single-hot categorical features are fed into an Embedding Layer
* Each value of a multi-hot categorical feature is fed into an Embedding Layer, and the multiple embedding outputs are combined via averaging
* The outputs of the Embedding Layers are concatenated
* The concatenated layers are fed through multiple feed-forward layers (Dense Layers with ReLU activations)
* The final output is a single number with a sigmoid activation function
First, we will define some dictionaries/lists for our network architecture.
[12]:
inputs = {} # tf.keras.Input placeholders for each feature to be used
emb_layers = []  # output of all embedding layers, which will be concatenated
We create tf.keras.Input tensors for all 4 input features.
[13]:
for col in CATEGORICAL_COLUMNS:
inputs[col] = tf.keras.Input(
name=col,
dtype=tf.int32,
shape=(1,)
)
# Note that we need two input tensors for multi-hot categorical features
for col in CATEGORICAL_MH_COLUMNS:
inputs[col+'__values'] = tf.keras.Input(
name=f"{col}__values",
dtype=tf.int64,
shape=(1,)
)
inputs[col+'__nnzs'] = tf.keras.Input(
name=f"{col}__nnzs",
dtype=tf.int64,
shape=(1,)
)
Next, we initialize Embedding Layers with tf.feature_column.embedding_column.
[14]:
for col in CATEGORICAL_COLUMNS+CATEGORICAL_MH_COLUMNS:
emb_layers.append(
tf.feature_column.embedding_column(
tf.feature_column.categorical_column_with_identity(
col,
EMBEDDING_TABLE_SHAPES[col][0] # Input dimension (vocab size)
), EMBEDDING_TABLE_SHAPES[col][1] # Embedding output dimension
)
)
emb_layers
[14]:
[EmbeddingColumn(categorical_column=IdentityCategoricalColumn(key='movieId', number_buckets=56586, default_value=None), dimension=512, combiner='mean', initializer=<tensorflow.python.ops.init_ops.TruncatedNormal object at 0x7fbca110b370>, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None, trainable=True, use_safe_embedding_lookup=True),
EmbeddingColumn(categorical_column=IdentityCategoricalColumn(key='userId', number_buckets=162542, default_value=None), dimension=512, combiner='mean', initializer=<tensorflow.python.ops.init_ops.TruncatedNormal object at 0x7fbca110b2e0>, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None, trainable=True, use_safe_embedding_lookup=True),
EmbeddingColumn(categorical_column=IdentityCategoricalColumn(key='genres', number_buckets=21, default_value=None), dimension=16, combiner='mean', initializer=<tensorflow.python.ops.init_ops.TruncatedNormal object at 0x7fbca110ba60>, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None, trainable=True, use_safe_embedding_lookup=True)]
NVTabular implements a custom TensorFlow layer, layers.DenseFeatures, which takes the different tf.keras.Input tensors and the pre-initialized tf.feature_columns as input and automatically concatenates them into a flat tensor. In the case of multi-hot categorical features, DenseFeatures organizes the __values and __nnzs inputs into a RaggedTensor and combines them. DenseFeatures can handle numeric inputs as well, but MovieLens does not provide numerical input features.
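The mean combining applied to the multi-hot embeddings can be sketched in plain Python (a simplified illustration under the assumption of mean combining, not the actual DenseFeatures implementation):

```python
def mean_combine(embeddings, nnzs):
    """Average the embedding vectors belonging to each datapoint.

    embeddings: flat list of embedding vectors, one per entry in __values
    nnzs: number of values belonging to each datapoint
    """
    out, start = [], 0
    for n in nnzs:
        rows = embeddings[start:start + n]
        # element-wise mean over this datapoint's embedding vectors
        out.append([sum(col) / n for col in zip(*rows)])
        start += n
    return out

# Two datapoints: the first has two genre embeddings, the second has one
mean_combine([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2, 1])
# → [[2.0, 3.0], [5.0, 6.0]]
```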
[15]:
emb_layer = layers.DenseFeatures(emb_layers)
x_emb_output = emb_layer(inputs)
x_emb_output
[15]:
<KerasTensor: shape=(None, 1040) dtype=float32 (created by layer 'dense_features')>
We can see that the output shape of the concatenated layer is equal to the sum of the individual Embedding output dimensions (1040 = 16+512+512).
[16]:
EMBEDDING_TABLE_SHAPES
[16]:
{'genres': (21, 16), 'movieId': (56586, 512), 'userId': (162542, 512)}
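As a quick cross-check, the 1040 width can be recomputed from these (cardinality, embedding_size) pairs:

```python
# Embedding shapes as printed by the workflow above
shapes = {'genres': (21, 16), 'movieId': (56586, 512), 'userId': (162542, 512)}

# The concatenated DenseFeatures output width is the sum of the embedding sizes
total_width = sum(size for _, size in shapes.values())
# 16 + 512 + 512 = 1040
```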
We add multiple Dense Layers. Finally, we initialize the tf.keras.Model and add the optimizer.
[17]:
x = tf.keras.layers.Dense(128, activation="relu")(x_emb_output)
x = tf.keras.layers.Dense(128, activation="relu")(x)
x = tf.keras.layers.Dense(128, activation="relu")(x)
x = tf.keras.layers.Dense(1, activation="sigmoid", name="output")(x)
model = tf.keras.Model(inputs=inputs, outputs=x)
model.compile('sgd', 'binary_crossentropy')
[18]:
# Note: plot_model requires the pydot and graphviz dependencies to be installed
tf.keras.utils.plot_model(model)
[18]:
Training the deep learning model
We can train our model with model.fit. We need to use a callback to add the validation dataloader.
[19]:
validation_callback = KerasSequenceValidater(valid_dataset_tf)
history = model.fit(train_dataset_tf, callbacks=[validation_callback], epochs=1)
611/611 [==============================] - 13s 20ms/step - loss: 0.6745