[1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

NVTabular demo on Rossmann data - TensorFlow

Overview

NVTabular is a feature engineering and preprocessing library for tabular data, designed to quickly and easily manipulate terabyte-scale datasets used to train deep-learning-based recommender systems. It provides a high-level abstraction to simplify code and accelerates computation on the GPU using the RAPIDS cuDF library.

Learning objectives

In the previous notebooks (01-Download-Convert.ipynb and 02-ETL-with-NVTabular.ipynb), we downloaded and preprocessed the dataset and created features for it. Now we are ready to train our deep learning model. In this notebook, we use TensorFlow with the NVTabular data loader for TensorFlow to accelerate the training pipeline.

[2]:
import os
import math
import json
import glob

Loading NVTabular workflow

This time, we only need to define our data directories. We can load the data schema from the NVTabular workflow.

[3]:
DATA_DIR = os.environ.get("OUTPUT_DATA_DIR", os.path.expanduser("~/nvt-examples/data/"))
PREPROCESS_DIR = os.path.join(DATA_DIR, "ross_pre/")
PREPROCESS_DIR_TRAIN = os.path.join(PREPROCESS_DIR, "train")
PREPROCESS_DIR_VALID = os.path.join(PREPROCESS_DIR, "valid")

What files are available to train on in our directories?

[5]:
!ls $PREPROCESS_DIR
stats.json  train  valid
[6]:
!ls $PREPROCESS_DIR_TRAIN
0.1136d38916184bd39bf3d0cc6af8aecc.parquet  _metadata
_file_list.txt                              _metadata.json
[7]:
!ls $PREPROCESS_DIR_VALID
0.bcd2404e802640f29b1427feaacbd24a.parquet  _metadata
_file_list.txt                              _metadata.json

We load the data schema and statistics from stats.json, which we created in the previous notebook (02-ETL-with-NVTabular.ipynb).

[8]:
# use a context manager so the file handle is closed after reading
with open(os.path.join(PREPROCESS_DIR, "stats.json"), "r") as f:
    stats = json.load(f)
[8]:
CATEGORICAL_COLUMNS = stats["CATEGORICAL_COLUMNS"]
CONTINUOUS_COLUMNS = stats["CONTINUOUS_COLUMNS"]
LABEL_COLUMNS = stats["LABEL_COLUMNS"]
COLUMNS = CATEGORICAL_COLUMNS + CONTINUOUS_COLUMNS + LABEL_COLUMNS

The embedding table shows the cardinality of each categorical variable along with its associated embedding size. Each entry is of the form (cardinality, embedding_size).

[9]:
EMBEDDING_TABLE_SHAPES = stats["EMBEDDING_TABLE_SHAPES"]
EMBEDDING_TABLE_SHAPES
[9]:
{'Assortment': [4, 16],
 'CompetitionMonthsOpen': [26, 16],
 'CompetitionOpenSinceYear': [24, 16],
 'Day': [32, 16],
 'DayOfWeek': [8, 16],
 'Events': [22, 16],
 'Month': [13, 16],
 'Promo2SinceYear': [9, 16],
 'Promo2Weeks': [27, 16],
 'PromoInterval': [4, 16],
 'Promo_bw': [7, 16],
 'Promo_fw': [7, 16],
 'SchoolHoliday_bw': [9, 16],
 'SchoolHoliday_fw': [9, 16],
 'State': [13, 16],
 'StateHoliday': [3, 16],
 'StateHoliday_bw': [4, 16],
 'StateHoliday_fw': [4, 16],
 'Store': [1116, 81],
 'StoreType': [5, 16],
 'Week': [53, 16],
 'Year': [4, 16]}
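
To make the (cardinality, embedding_size) pairs more concrete, here is a minimal sketch that counts the weights each embedding table would hold and the width of the concatenated embedding vector. It uses only three entries copied from the output above; the helper name `embedding_parameter_count` is hypothetical and not part of NVTabular.

```python
# Three entries copied from the EMBEDDING_TABLE_SHAPES output above:
# name -> [cardinality, embedding_size]
embedding_table_shapes = {
    "Assortment": [4, 16],
    "DayOfWeek": [8, 16],
    "Store": [1116, 81],
}


def embedding_parameter_count(shapes):
    """Return the number of weights in each embedding table (rows * columns)."""
    return {name: cardinality * dim for name, (cardinality, dim) in shapes.items()}


param_counts = embedding_parameter_count(embedding_table_shapes)

# Each categorical feature contributes `embedding_size` values per row once
# the embeddings are concatenated.
concatenated_width = sum(dim for _, dim in embedding_table_shapes.values())

print(param_counts)
print(concatenated_width)
```

The Store table dominates the parameter count because it has by far the largest cardinality.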

Training a Network

Now that our data is preprocessed and saved, we can use dataset loaders to stream the preprocessed parquet files batch by batch while training neural networks.

We’ll start by setting some universal hyperparameters for our model and optimizer. These settings will be the same across all of the frameworks that we explore in the different notebooks.

12% RMSPE is achievable using the Novograd optimizer, but we know of no Novograd implementation for TensorFlow that supports sparse gradients, so we do not include that solution below. If you’re interested in contributing to NVTabular, feel free to take this challenge on and submit a pull request if successful.

[10]:
EMBEDDING_DROPOUT_RATE = 0.04
DROPOUT_RATES = [0.001, 0.01]
HIDDEN_DIMS = [1000, 500]
BATCH_SIZE = 65536
LEARNING_RATE = 0.001
EPOCHS = 25

# TODO: Calculate on the fly rather than recalling from previous analysis.
MAX_SALES_IN_TRAINING_SET = 38722.0
MAX_LOG_SALES_PREDICTION = 1.2 * math.log(MAX_SALES_IN_TRAINING_SET + 1.0)

TRAIN_PATHS = sorted(glob.glob(os.path.join(PREPROCESS_DIR_TRAIN, "*.parquet")))
VALID_PATHS = sorted(glob.glob(os.path.join(PREPROCESS_DIR_VALID, "*.parquet")))
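
As a quick check of the arithmetic behind MAX_LOG_SALES_PREDICTION, the sketch below evaluates the cap and maps it back to sales space. It assumes the labels are stored as log(sales + 1), which is what the "+ 1.0" above suggests; the lowercase variable names are illustrative only.

```python
import math

MAX_SALES_IN_TRAINING_SET = 38722.0

# Largest label value, assuming labels are log(sales + 1).
max_log_sales = math.log(MAX_SALES_IN_TRAINING_SET + 1.0)

# The 1.2 factor leaves headroom above the largest value seen in training.
max_log_sales_prediction = 1.2 * max_log_sales

# Undo the log transform to see the cap in sales space.
max_sales_prediction = math.exp(max_log_sales_prediction) - 1.0

print(round(max_log_sales, 3), round(max_log_sales_prediction, 3))
print(max_sales_prediction > MAX_SALES_IN_TRAINING_SET)
```

Because the factor is applied in log space, the cap in sales space sits well above the training maximum, so the model can still predict values somewhat larger than anything seen during training.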

TensorFlow

TensorFlow: Preparing Datasets

KerasSequenceLoader wraps a lightweight iterator around a dataset object to handle chunking, shuffling, and application of any workflows (which can be applied online as a preprocessing step). For column names, you can use either a list of string names or a list of TensorFlow feature_columns that will be used to feed the network.

[11]:
import tensorflow as tf

# we can control how much memory to give tensorflow with this environment variable
# IMPORTANT: make sure you do this before you initialize TF's runtime, otherwise
# it's too late and TF will have claimed all free GPU memory
# set ONE of the two forms below; a second assignment would overwrite the first
# os.environ["TF_MEMORY_ALLOCATION"] = "8192"  # explicit MB
os.environ["TF_MEMORY_ALLOCATION"] = "0.5"  # fraction of free memory
from nvtabular.loader.tensorflow import KerasSequenceLoader, KerasSequenceValidater


# cheap wrapper to keep things some semblance of neat
def make_categorical_embedding_column(name, dictionary_size, embedding_dim):
    return tf.feature_column.embedding_column(
        tf.feature_column.categorical_column_with_identity(name, dictionary_size), embedding_dim
    )


# instantiate our columns
categorical_columns = [
    make_categorical_embedding_column(name, *EMBEDDING_TABLE_SHAPES[name])
    for name in CATEGORICAL_COLUMNS
]
continuous_columns = [tf.feature_column.numeric_column(name, (1,)) for name in CONTINUOUS_COLUMNS]

# feed them to our datasets
train_dataset = KerasSequenceLoader(
    TRAIN_PATHS,  # you could also use a glob pattern
    feature_columns=categorical_columns + continuous_columns,
    batch_size=BATCH_SIZE,
    label_names=LABEL_COLUMNS,
    shuffle=True,
    buffer_size=0.06,  # amount of data, as a fraction of GPU memory, to load at once
)

valid_dataset = KerasSequenceLoader(
    VALID_PATHS,  # you could also use a glob pattern
    feature_columns=categorical_columns + continuous_columns,
    batch_size=BATCH_SIZE * 4,
    label_names=LABEL_COLUMNS,
    shuffle=False,
    buffer_size=0.06,  # amount of data, as a fraction of GPU memory, to load at once
)

TensorFlow: Defining a Model

Using Keras, we can define the layers of our model and their parameters explicitly. Here, for the sake of consistency, we’ll mimic fast.ai’s TabularModel.

[12]:
# DenseFeatures layer needs a dictionary of {feature_name: input}
categorical_inputs = {}
for column_name in CATEGORICAL_COLUMNS:
    categorical_inputs[column_name] = tf.keras.Input(name=column_name, shape=(1,), dtype=tf.int64)
categorical_embedding_layer = tf.keras.layers.DenseFeatures(categorical_columns)
categorical_x = categorical_embedding_layer(categorical_inputs)
categorical_x = tf.keras.layers.Dropout(EMBEDDING_DROPOUT_RATE)(categorical_x)

# Just concatenating continuous, so can use a list
continuous_inputs = []
for column_name in CONTINUOUS_COLUMNS:
    continuous_inputs.append(tf.keras.Input(name=column_name, shape=(1,), dtype=tf.float32))
continuous_embedding_layer = tf.keras.layers.Concatenate(axis=1)
continuous_x = continuous_embedding_layer(continuous_inputs)
continuous_x = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1)(continuous_x)

# concatenate and build MLP
x = tf.keras.layers.Concatenate(axis=1)([categorical_x, continuous_x])
for dim, dropout_rate in zip(HIDDEN_DIMS, DROPOUT_RATES):
    x = tf.keras.layers.Dense(dim, activation="relu")(x)
    x = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1)(x)
    x = tf.keras.layers.Dropout(dropout_rate)(x)
x = tf.keras.layers.Dense(1, activation="linear")(x)

# TODO: Initialize model weights to fix saturation issues.
# For now, we'll just scale the output of our model directly before
# hitting the sigmoid.
x = 0.1 * x

x = MAX_LOG_SALES_PREDICTION * tf.keras.activations.sigmoid(x)

# combine all our inputs into a single list
# (note that you can still use .fit, .predict, etc. on a dict
# that maps input tensor names to input values)
inputs = list(categorical_inputs.values()) + continuous_inputs
tf_model = tf.keras.Model(inputs=inputs, outputs=x)
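
The last two steps of the model (scale by 0.1, then multiply a sigmoid by MAX_LOG_SALES_PREDICTION) bound every prediction to the open interval (0, MAX_LOG_SALES_PREDICTION). A minimal plain-Python sketch of that output head follows; the function name `scaled_sigmoid_head` and the sample logits are hypothetical and chosen only to illustrate the bounds.

```python
import math

MAX_LOG_SALES_PREDICTION = 1.2 * math.log(38722.0 + 1.0)


def scaled_sigmoid_head(logit, scale=0.1, upper=MAX_LOG_SALES_PREDICTION):
    """Scale the raw logit, then squash it into the interval (0, upper)."""
    return upper / (1.0 + math.exp(-scale * logit))


# Hypothetical raw outputs of the final Dense(1) layer.
outputs = [scaled_sigmoid_head(z) for z in (-50.0, 0.0, 50.0)]
print(outputs)
```

A logit of zero maps to half of the cap, and the 0.1 scale flattens the curve so that moderately large logits do not immediately saturate the sigmoid.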

TensorFlow: Training

[13]:
def rmspe_tf(y_true, y_pred):
    # map back into "true" space by undoing transform
    y_true = tf.exp(y_true) - 1
    y_pred = tf.exp(y_pred) - 1

    percent_error = (y_true - y_pred) / y_true
    return tf.sqrt(tf.reduce_mean(percent_error ** 2))
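
To see what rmspe_tf measures, here is a plain-Python sketch of the same computation on a tiny, made-up example. The sales figures are hypothetical; the sketch also assumes no true value is zero, since that would divide by zero.

```python
import math


def rmspe(y_true_log, y_pred_log):
    """Root mean squared percentage error on log(sales + 1) values."""
    squared_errors = []
    for t_log, p_log in zip(y_true_log, y_pred_log):
        # undo the log(sales + 1) transform, as rmspe_tf does
        t = math.exp(t_log) - 1.0
        p = math.exp(p_log) - 1.0
        squared_errors.append(((t - p) / t) ** 2)
    return math.sqrt(sum(squared_errors) / len(squared_errors))


# Hypothetical sales: one prediction is 10% too low, the other 10% too high.
true_sales = [100.0, 200.0]
pred_sales = [90.0, 220.0]

score = rmspe(
    [math.log1p(s) for s in true_sales],
    [math.log1p(s) for s in pred_sales],
)
print(round(score, 4))  # 0.1
```

Both errors are 10% of the true value, so the metric comes out at 0.1, regardless of the fact that the absolute errors (10 and 20) differ.
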
[14]:
%%time
from time import time

optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
tf_model.compile(optimizer, "mse", metrics=[rmspe_tf])

validation_callback = KerasSequenceValidater(valid_dataset)
start = time()
history = tf_model.fit(
    train_dataset,
    callbacks=[validation_callback],
    epochs=EPOCHS,
)
t_final = time() - start
total_rows = train_dataset.num_rows_processed + valid_dataset.num_rows_processed
throughput = (EPOCHS * total_rows) / t_final
print(f"run_time: {t_final} - rows: {total_rows} - epochs: {EPOCHS} - dl_thru: {throughput}")
Epoch 1/25
13/13 [==============================] - 7s 166ms/step - loss: 6.1835 - rmspe_tf: 0.8900
Epoch 2/25
13/13 [==============================] - 2s 166ms/step - loss: 5.3472 - rmspe_tf: 0.8925
Epoch 3/25
13/13 [==============================] - 2s 166ms/step - loss: 4.6994 - rmspe_tf: 0.8807
Epoch 4/25
13/13 [==============================] - 2s 162ms/step - loss: 3.8926 - rmspe_tf: 0.8565
Epoch 5/25
13/13 [==============================] - 2s 165ms/step - loss: 2.9031 - rmspe_tf: 0.8123
Epoch 6/25
13/13 [==============================] - 2s 164ms/step - loss: 1.8590 - rmspe_tf: 0.7358
Epoch 7/25
13/13 [==============================] - 2s 161ms/step - loss: 0.9700 - rmspe_tf: 0.6148
Epoch 8/25
13/13 [==============================] - 2s 163ms/step - loss: 0.3942 - rmspe_tf: 0.4515
Epoch 9/25
13/13 [==============================] - 2s 164ms/step - loss: 0.1273 - rmspe_tf: 0.2961
Epoch 10/25
13/13 [==============================] - 2s 161ms/step - loss: 0.0468 - rmspe_tf: 0.2237
Epoch 11/25
13/13 [==============================] - 2s 163ms/step - loss: 0.0347 - rmspe_tf: 0.2213
Epoch 12/25
13/13 [==============================] - 2s 163ms/step - loss: 0.0342 - rmspe_tf: 0.2298
Epoch 13/25
13/13 [==============================] - 2s 164ms/step - loss: 0.0325 - rmspe_tf: 0.2151
Epoch 14/25
13/13 [==============================] - 2s 163ms/step - loss: 0.0293 - rmspe_tf: 0.2076
Epoch 15/25
13/13 [==============================] - 2s 164ms/step - loss: 0.0274 - rmspe_tf: 0.1980
Epoch 16/25
13/13 [==============================] - 2s 161ms/step - loss: 0.0261 - rmspe_tf: 0.1939
Epoch 17/25
13/13 [==============================] - 2s 163ms/step - loss: 0.0252 - rmspe_tf: 0.1875
Epoch 18/25
13/13 [==============================] - 2s 163ms/step - loss: 0.0245 - rmspe_tf: 0.1830
Epoch 19/25
13/13 [==============================] - 2s 165ms/step - loss: 0.0241 - rmspe_tf: 0.1885
Epoch 20/25
13/13 [==============================] - 2s 163ms/step - loss: 0.0236 - rmspe_tf: 0.1814
Epoch 21/25
13/13 [==============================] - 2s 166ms/step - loss: 0.0231 - rmspe_tf: 0.1797
Epoch 22/25
13/13 [==============================] - 2s 161ms/step - loss: 0.0227 - rmspe_tf: 0.1823
Epoch 23/25
13/13 [==============================] - 2s 164ms/step - loss: 0.0224 - rmspe_tf: 0.1754
Epoch 24/25
13/13 [==============================] - 2s 164ms/step - loss: 0.0221 - rmspe_tf: 0.1745
Epoch 25/25
13/13 [==============================] - 2s 163ms/step - loss: 0.0223 - rmspe_tf: 0.1740
run_time: 67.14452314376831 - rows: 76 - epochs: 25 - dl_thru: 28.29717020897987
CPU times: user 2min 52s, sys: 11.7 s, total: 3min 4s
Wall time: 1min 7s
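
The "13/13" in the progress bar is the number of batches per epoch. As a rough sketch, assuming the loader keeps the final partial batch, we can work out which training-set sizes are consistent with 13 steps at BATCH_SIZE = 65536. The helper `steps_per_epoch` is hypothetical and is not an NVTabular function.

```python
import math

BATCH_SIZE = 65536
STEPS_PER_EPOCH = 13  # read off the "13/13" progress bar above


def steps_per_epoch(num_rows, batch_size):
    """Number of batches per epoch when the final partial batch is kept (assumed)."""
    return math.ceil(num_rows / batch_size)


# Smallest and largest row counts that yield exactly 13 steps.
min_rows = (STEPS_PER_EPOCH - 1) * BATCH_SIZE + 1
max_rows = STEPS_PER_EPOCH * BATCH_SIZE

print(min_rows, max_rows)  # 786433 851968
```

Under that assumption, the training set holds between 786,433 and 851,968 rows.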