Rossmann Store Sales Example

The following example illustrates how to use NVTabular to preprocess and load tabular data for training neural networks in both PyTorch and TensorFlow. We’ll use a dataset built by FastAI for the Kaggle Rossmann Store Sales competition. Some pandas preprocessing is required to build the appropriate feature set, so make sure to run rossmann-store-sales-preproc.ipynb before going through this notebook.

[ ]:
import nvtabular as nvt
import os
import glob

Preparing our dataset

Let’s start by defining some a priori information about our data: its schema (which columns to use and what kinds of variables they represent) and the location of the files that hold data following this schema. Throughout, I’ll use UPPERCASE variables for this sort of a priori information, which you would usually supply through command-line arguments or config files.
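As a minimal sketch of the command-line alternative, the snippet below reads a few of these values with Python’s standard argparse module. The flag names and defaults are hypothetical (this notebook hard-codes the constants instead), and only a handful of columns are shown.

```python
import argparse

def parse_config(argv=None):
    # Hypothetical flags standing in for the UPPERCASE constants used in this notebook
    parser = argparse.ArgumentParser(description="Rossmann preprocessing config (sketch)")
    parser.add_argument("--data-dir", default="/data")
    parser.add_argument("--categorical-columns", nargs="+", default=["Store", "DayOfWeek"])
    parser.add_argument("--continuous-columns", nargs="+", default=["CompetitionDistance"])
    parser.add_argument("--label-columns", nargs="+", default=["Sales"])
    return parser.parse_args(argv)

# Pass an explicit argv list here: inside a notebook, sys.argv holds the kernel's own
# arguments. In a standalone script you would call parse_config() with no arguments.
args = parse_config(["--data-dir", "/tmp/rossmann"])
DATA_DIR = args.data_dir
COLUMNS = args.categorical_columns + args.continuous_columns + args.label_columns
```

argparse turns `--data-dir` into the attribute `data_dir`, so the resulting names map directly onto the constants used below.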

[2]:
DATA_DIR = os.environ.get('DATA_DIR', '/data')

CATEGORICAL_COLUMNS = [
    'Store', 'DayOfWeek', 'Year', 'Month', 'Day', 'StateHoliday', 'CompetitionMonthsOpen',
    'Promo2Weeks', 'StoreType', 'Assortment', 'PromoInterval', 'CompetitionOpenSinceYear', 'Promo2SinceYear',
    'State', 'Week', 'Events', 'Promo_fw', 'Promo_bw', 'StateHoliday_fw', 'StateHoliday_bw',
    'SchoolHoliday_fw', 'SchoolHoliday_bw'
]

CONTINUOUS_COLUMNS = [
    'CompetitionDistance', 'Max_TemperatureC', 'Mean_TemperatureC', 'Min_TemperatureC',
    'Max_Humidity', 'Mean_Humidity', 'Min_Humidity', 'Max_Wind_SpeedKm_h',
    'Mean_Wind_SpeedKm_h', 'CloudCover', 'trend', 'trend_DE',
    'AfterStateHoliday', 'BeforeStateHoliday', 'Promo', 'SchoolHoliday'
]
LABEL_COLUMNS = ['Sales']

COLUMNS = CATEGORICAL_COLUMNS + CONTINUOUS_COLUMNS + LABEL_COLUMNS

What files are available to train on in our data directory?

[3]:
! ls $DATA_DIR
jp_ross  test.csv  train.csv  valid.csv

train.csv and valid.csv look like good candidates, so let’s use those.

[4]:
TRAIN_PATH = os.path.join(DATA_DIR, 'train.csv')
VALID_PATH = os.path.join(DATA_DIR, 'valid.csv')

Workflows and Preprocessing

A Workflow represents the chain of feature engineering and preprocessing operations performed on a dataset. It is instantiated with a description of the dataset’s schema so that it can keep track of how columns are transformed by each operation.

[5]:
# note that here, we want to perform a normalization transformation on the label
# column. Since NVT doesn't support transforming label columns right now, we'll
# pretend it's a regular continuous column during our feature engineering phase
proc = nvt.Workflow(
    cat_names=CATEGORICAL_COLUMNS,
    cont_names=CONTINUOUS_COLUMNS+LABEL_COLUMNS,
    label_name=LABEL_COLUMNS
)

Ops

We add operations to a Workflow with the add_cat_feature/add_cont_feature and add_cat_preprocess/add_cont_preprocess methods, using the cat variants for categorical variables and the cont variants for continuous ones. When we’re done adding ops, we call the finalize method so the Workflow can build a representation of its outputs.

[6]:
proc.add_cont_feature(nvt.ops.FillMissing())
proc.add_cont_preprocess(nvt.ops.LogOp(columns=['Sales']))
proc.add_cont_preprocess(nvt.ops.Normalize())
proc.add_cat_preprocess(nvt.ops.Categorify())
proc.finalize()

Datasets

In general, the Ops in our Workflow need measurements of statistical properties of our data before they can be applied. For example, the Normalize op requires the dataset mean and standard deviation, and the Categorify op requires an accounting of all the categories a particular feature can take. However, we frequently need to measure these properties across datasets that are too large to fit into GPU memory (or CPU memory, for that matter) all at once.

NVTabular solves this by providing the dataset object, an iterator over manageable chunks of a set of parquet or csv files, which we can use to compute statistics in an online fashion (and, later, to train neural networks in batches loaded from disk). The size of those chunks is determined by the gpu_memory_frac kwarg: each chunk’s memory footprint is that fraction of the available GPU memory.
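To make "online" concrete, here is a framework-free sketch of gathering both kinds of statistics one chunk at a time: a running mean and standard deviation (merged with Chan et al.’s parallel update for count, mean and sum of squared deviations) and a category vocabulary. It illustrates the idea only and is not NVTabular’s implementation. Whether Normalize divides by n or n - 1, and how Categorify numbers its categories, are not stated in this notebook, so both choices below are assumptions.

```python
import math
from collections import Counter

def chunk_moments(values):
    """(count, mean, M2) for one chunk, where M2 is the sum of squared deviations."""
    n = len(values)
    mean = sum(values) / n
    return n, mean, sum((v - mean) ** 2 for v in values)

def merge_moments(a, b):
    """Combine two (count, mean, M2) summaries without revisiting the raw data."""
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2

def online_stats(chunks):
    """One pass over (continuous, categorical) chunks, never holding a whole column."""
    moments, categories = None, Counter()
    for cont_chunk, cat_chunk in chunks:
        if cont_chunk:
            m = chunk_moments(cont_chunk)
            moments = m if moments is None else merge_moments(moments, m)
        categories.update(cat_chunk)
    if moments is None:
        raise ValueError("no continuous values seen")
    n, mean, m2 = moments
    std = math.sqrt(m2 / n)  # population std (assumed divisor)
    # contiguous ids in sorted order, starting at 1 so 0 can mean "unseen" (assumption)
    encoding = {cat: i for i, cat in enumerate(sorted(categories), start=1)}
    return mean, std, encoding

chunks = [([1.0, 2.0, 3.0], ["a", "b"]), ([4.0, 5.0], ["c", "a"]), ([6.0], ["d"])]
mean, std, encoding = online_stats(chunks)
```

The merge step is what lets each chunk be discarded after it is read: only three numbers per column (plus the category counts) survive between chunks.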

Larger chunks lead to shorter run times because they make better use of the GPU’s parallel-processing power, but they constrain the memory left for everything else and can force expensive operations to cache to disk, which lowers efficiency.
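A back-of-the-envelope sketch of how gpu_memory_frac translates into rows per chunk. The 32 GiB GPU and the flat 8 bytes per value are hypothetical numbers; real footprints also depend on dtypes, string columns, and temporary buffers created by ops.

```python
def rows_per_chunk(gpu_memory_bytes, gpu_memory_frac, bytes_per_row):
    """Rough number of rows whose footprint fits in the given memory fraction."""
    budget = gpu_memory_bytes * gpu_memory_frac
    return int(budget // bytes_per_row)

# Hypothetical: a 32 GiB GPU and all 39 columns in COLUMNS stored as 8-byte values
GPU_MEMORY_BYTES = 32 * 2**30
BYTES_PER_ROW = 39 * 8
rows = rows_per_chunk(GPU_MEMORY_BYTES, 0.2, BYTES_PER_ROW)  # roughly 22 million rows
```

Under these assumptions a single chunk at GPU_MEMORY_FRAC = 0.2 would hold on the order of 22 million rows, so the Rossmann data used here would likely be read in very few chunks.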

[7]:
GPU_MEMORY_FRAC = 0.2
train_ds_iterator = nvt.dataset(TRAIN_PATH, gpu_memory_frac=GPU_MEMORY_FRAC, columns=COLUMNS)
valid_ds_iterator = nvt.dataset(VALID_PATH, gpu_memory_frac=GPU_MEMORY_FRAC, columns=COLUMNS)

Now that we have our datasets, we’ll apply our Workflow to them and save the results out to parquet files for fast reading at train time. We’ll also measure and record statistics on our training set using the record_stats=True kwarg so that our Workflow can use them at apply time; the validation set is processed with record_stats=False, so it is transformed using the statistics gathered from the training set.
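The fit-on-train, apply-everywhere pattern can be sketched for the Sales column in plain Python: FillMissing, then LogOp (log(1 + x), as the metric code later in this notebook undoes it), then Normalize, with mean and standard deviation taken from the training split only. The fill value of 0 and the population standard deviation are assumptions of this sketch, not documented NVTabular behavior.

```python
import math

def fill_log(values, fill_val=0.0):
    # FillMissing then LogOp: replace None with fill_val (assumed 0), then log(1 + x)
    return [math.log1p(fill_val if v is None else v) for v in values]

def fit_normalize(train_values):
    """Record mean and (population) std on the training split only."""
    logged = fill_log(train_values)
    mean = sum(logged) / len(logged)
    std = math.sqrt(sum((x - mean) ** 2 for x in logged) / len(logged))
    return mean, std

def apply_normalize(values, mean, std):
    """Transform any split, e.g. validation, with statistics from the training split."""
    return [(x - mean) / std for x in fill_log(values)]

train_sales = [0.0, math.expm1(2.0)]  # after log1p: [0, 2], so mean 1 and std 1
mean, std = fit_normalize(train_sales)
valid_z = apply_normalize([math.expm1(3.0), None], mean, std)  # approximately [2.0, -1.0]
```

Reusing the training statistics keeps information from the validation set out of the preprocessing, and it means both splits share one inverse transform.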

[8]:
PREPROCESS_DIR = os.path.join(DATA_DIR, 'jp_ross')
PREPROCESS_DIR_TRAIN = os.path.join(PREPROCESS_DIR, 'train')
PREPROCESS_DIR_VALID = os.path.join(PREPROCESS_DIR, 'valid')
! mkdir -p $PREPROCESS_DIR_TRAIN
! mkdir -p $PREPROCESS_DIR_VALID

[9]:
proc.apply(train_ds_iterator, apply_offline=True, record_stats=True, output_path=PREPROCESS_DIR_TRAIN, shuffle=False)
proc.apply(valid_ds_iterator, apply_offline=True, record_stats=False, output_path=PREPROCESS_DIR_VALID, shuffle=False)

Finalize columns

The FastAI workflow uses the Workflow.ds_to_tensors method, which maps a dataset to its corresponding PyTorch tensors. To make sure it runs correctly, we call the create_final_cols method so that the Workflow builds the output dataset schema, and then we remove every instance of the label column that was added to that schema’s continuous columns when we processed it.

[10]:
proc.create_final_cols()
# using log op and normalize on sales column causes it to get added to
# continuous columns_ctx, so we'll remove it here
while True:
    try:
        proc.columns_ctx['final']['cols']['continuous'].remove(LABEL_COLUMNS[0])
    except ValueError:
        break
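The loop above is needed because Python’s list.remove deletes only the first matching element and raises ValueError once none is left. A single-pass alternative uses slice assignment, which edits the list in place so any nested structure holding a reference to it sees the change. Calling it as remove_all(proc.columns_ctx['final']['cols']['continuous'], LABEL_COLUMNS[0]) assumes that entry is a plain Python list, which the original loop’s use of .remove and ValueError suggests; the helper name and the schema dict below are hypothetical.

```python
def remove_all(items, value):
    """Delete every occurrence of value from items in place."""
    # Slice assignment keeps the same list object, so other references stay valid
    items[:] = [item for item in items if item != value]

cont_cols = ["CompetitionDistance", "Sales", "Promo", "Sales"]
schema = {"continuous": cont_cols}     # hypothetical stand-in for the nested columns_ctx entry
cont_cols.remove("Sales")              # list.remove deletes only the first match
after_one_remove = list(schema["continuous"])
remove_all(cont_cols, "Sales")         # now every "Sales" entry is gone
```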

Training a Network

Now that our data is preprocessed and saved out, we can use datasets to read through the preprocessed parquet files in an online fashion to train neural networks! Even better, using DLPack, we can pass data loaded by cuDF’s accelerated parquet reader to networks in TensorFlow and PyTorch in remarkably analogous ways. Let’s compare them one-to-one below!

We’ll start by setting some universal hyperparameters for our model and optimizer (without making any claims about the quality of these hyperparameter choices).

[11]:
BATCH_SIZE = 65536
LEARNING_RATE = 1e-3
EMBEDDING_DROPOUT_RATE = 0.04
DROPOUT_RATES = [0.001, 0.01]
HIDDEN_DIMS = [1000, 500]
EPOCHS = 10

# our categorical encoder provides a handy utility for coming up with default embedding sizes
# based on the number of potential categories, so we'll just use those defaults
EMBEDDING_TABLE_SHAPES = {
    column: shape for column, shape in
        proc.df_ops['Categorify'].get_emb_sz(
            proc.stats['categories'], proc.columns_ctx['categorical']['base']
        )
}

TRAIN_PATHS = sorted(glob.glob(os.path.join(PREPROCESS_DIR_TRAIN, '*.parquet')))
VALID_PATHS = sorted(glob.glob(os.path.join(PREPROCESS_DIR_VALID, '*.parquet')))
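For intuition about the default embedding sizes, here is fastai v1’s rule of thumb, min(600, round(1.6 * n_cat ** 0.56)). The rule used inside Categorify.get_emb_sz is not shown in this notebook and may differ (for example, in its cap), so treat this as a stand-in. The cardinalities are hypothetical, not measured from the Rossmann data.

```python
def emb_sz_rule(n_cat: int) -> int:
    """fastai v1's heuristic: width grows sublinearly with cardinality, capped at 600."""
    return min(600, round(1.6 * n_cat ** 0.56))

def embedding_table_shapes(cardinalities):
    # (cardinality, width) pairs, the layout the TensorFlow cell unpacks
    # into dictionary_size and embedding_dim
    return {col: (n_cat, emb_sz_rule(n_cat)) for col, n_cat in cardinalities.items()}

# Hypothetical cardinalities for illustration only
shapes = embedding_table_shapes({"DayOfWeek": 8, "Store": 1116})
```

A rule like this keeps small vocabularies cheap while letting high-cardinality columns such as Store get wider embeddings.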

Data Loaders

The first thing we need to do is set up the objects that get data into our models.

TensorFlow

KerasSequenceDataset wraps a lightweight iterator around a dataset object to handle chunking, shuffling, and the application of any workflows (which can be applied online as a preprocessing step). For the columns, we can pass either a list of string names or a list of TensorFlow feature_columns that will be used to feed the network.
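The shuffle and buffer_size arguments used below trade memory for randomness: rows can only be shuffled against other rows that are loaded at the same time. The following is a simplified model of buffer-level shuffling in plain Python, not KerasSequenceDataset’s actual code. It loads buffer_batches * batch_size rows, shuffles them, and emits batches from that buffer.

```python
import random

def buffered_shuffle_batches(rows, batch_size, buffer_batches, seed=None):
    """Yield shuffled batches, shuffling only within a bounded buffer of rows."""
    rng = random.Random(seed)
    buffer_rows = batch_size * buffer_batches

    def drain(buf):
        rng.shuffle(buf)  # shuffle in place, then slice into batches
        for i in range(0, len(buf), batch_size):
            yield buf[i:i + batch_size]

    buffer = []
    for row in rows:
        buffer.append(row)
        if len(buffer) == buffer_rows:
            yield from drain(buffer)
            buffer = []
    if buffer:  # final partial buffer
        yield from drain(buffer)

batches = list(buffered_shuffle_batches(range(10), batch_size=3, buffer_batches=2, seed=0))
```

With a small buffer, the first batches can only contain rows from the start of the data; a larger buffer mixes rows more thoroughly at the cost of holding more of them in memory.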

[12]:
import tensorflow as tf

# we can control how much memory to give tensorflow with this environment variable
# IMPORTANT: make sure you do this before you initialize TF's runtime, otherwise
# it's too late and TF will have claimed all free GPU memory
os.environ['TF_MEMORY_ALLOCATION'] = "0.5" # fraction of free memory
# or, alternatively, an explicit amount in MB:
# os.environ['TF_MEMORY_ALLOCATION'] = "8192"
from nvtabular.tf_dataloader import KerasSequenceDataset


# small helper to keep the column definitions tidy
def make_categorical_embedding_column(name, dictionary_size, embedding_dim):
    return tf.feature_column.embedding_column(
        tf.feature_column.categorical_column_with_identity(name, dictionary_size),
        embedding_dim
    )

# instantiate our columns
categorical_columns = [
    make_categorical_embedding_column(name, *EMBEDDING_TABLE_SHAPES[name]) for
        name in CATEGORICAL_COLUMNS
]
continuous_columns = [
    tf.feature_column.numeric_column(name, (1,)) for name in CONTINUOUS_COLUMNS
]

# feed them to our datasets
train_dataset_tf = KerasSequenceDataset(
    TRAIN_PATHS, # you could also use a glob pattern
    categorical_columns+continuous_columns,
    batch_size=BATCH_SIZE,
    label_name=LABEL_COLUMNS[0],
    shuffle=True,
    buffer_size=48 # how many batches to load at once
)
valid_dataset_tf = KerasSequenceDataset(
    VALID_PATHS, # you could also use a glob pattern
    categorical_columns+continuous_columns,
    batch_size=BATCH_SIZE*4,
    label_name=LABEL_COLUMNS[0],
    shuffle=False,
    buffer_size=12
)

PyTorch

workflow.ds_to_tensors maps a symbolic dataset object to cat_features, cont_features, and labels PyTorch tensors by iterating through the dataset and concatenating the results. Note that this means the whole dataset is held in memory. For larger-than-memory datasets, see the example in criteo-example.ipynb, which leverages PyTorch ChainDatasets.

[ ]:
import torch
from nvtabular.torch_dataloader import TensorItrDataset, DLDataLoader
from fastai.basic_data import DataBunch
from fastai.tabular import TabularModel
from fastai.basic_train import Learner
from fastai.layers import MSELossFlat


def make_batched_dataloader(paths, columns, batch_size):
    ds_iterator = nvt.dataset(paths, columns=columns)
    ds_tensors = proc.ds_to_tensors(ds_iterator, apply_ops=False)
    ds_batch_sets = TensorItrDataset(ds_tensors, batch_size=batch_size)
    return DLDataLoader(
        ds_batch_sets,
        batch_size=None,
        pin_memory=False,
        num_workers=0
    )


train_dataset_pt = make_batched_dataloader(TRAIN_PATHS, COLUMNS, BATCH_SIZE)
valid_dataset_pt = make_batched_dataloader(VALID_PATHS, COLUMNS, BATCH_SIZE*4)
databunch = DataBunch(
    train_dataset_pt,
    valid_dataset_pt,
    collate_fn=lambda x: x,
    device="cuda"
)
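Because TensorItrDataset is constructed with a batch_size, each item it yields is presumably already a full batch, which is why DLDataLoader receives batch_size=None: in PyTorch’s DataLoader, batch_size=None turns off automatic batching, and we assume DLDataLoader follows the same semantics, as its arguments suggest. The hypothetical helper below sketches that pre-batching with plain Python lists; torch tensors slice the same way along their first dimension.

```python
def iter_tensor_batches(tensors, batch_size):
    """Yield aligned (cat, cont, label) slices of equal-length sequences.

    Hypothetical stand-in for TensorItrDataset, not its actual implementation."""
    n = len(tensors[0])
    if any(len(t) != n for t in tensors):
        raise ValueError("all tensors must have the same number of rows")
    for start in range(0, n, batch_size):
        yield tuple(t[start:start + batch_size] for t in tensors)

cats = [0, 1, 2, 3, 4]
conts = [0.0, 0.5, 1.0, 1.5, 2.0]
labels = [10, 11, 12, 13, 14]
batches = list(iter_tensor_batches((cats, conts, labels), 2))  # last batch holds one row
```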

Defining a model

Next we’ll need to define the inputs that will feed our model and build an architecture on top of them. For now, we’ll just stick to a simple MLP model.

TensorFlow

Using Keras, we can define the layers of our model and their parameters explicitly. Here, for the sake of consistency, I’ve tried to recreate the model built by FastAI as faithfully as I can from FastAI’s description of it, without making any claims as to whether this is the right model to use.

[14]:
# DenseFeatures layer needs a dictionary of {feature_name: input}
categorical_inputs = {}
for column_name in CATEGORICAL_COLUMNS:
    categorical_inputs[column_name] = tf.keras.Input(name=column_name, shape=(1,), dtype=tf.int64)
categorical_embedding_layer = tf.keras.layers.DenseFeatures(categorical_columns)
categorical_x = categorical_embedding_layer(categorical_inputs)
categorical_x = tf.keras.layers.Dropout(EMBEDDING_DROPOUT_RATE)(categorical_x)

# Just concatenating continuous, so can use a list
continuous_inputs = []
for column_name in CONTINUOUS_COLUMNS:
    continuous_inputs.append(tf.keras.Input(name=column_name, shape=(1,), dtype=tf.float32))
continuous_embedding_layer = tf.keras.layers.Concatenate(axis=1)
continuous_x = continuous_embedding_layer(continuous_inputs)
continuous_x = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1)(continuous_x)

# concatenate and build MLP
x = tf.keras.layers.Concatenate(axis=1)([categorical_x, continuous_x])
for dim, dropout_rate in zip(HIDDEN_DIMS, DROPOUT_RATES):
    x = tf.keras.layers.Dense(dim, activation='relu')(x)
    x = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1)(x)
    x = tf.keras.layers.Dropout(dropout_rate)(x)
x = tf.keras.layers.Dense(1, activation='linear')(x)

# combine all our inputs into a single list
# (note that you can still use .fit, .predict, etc. on a dict
# that maps input tensor names to input values)
inputs = list(categorical_inputs.values()) + continuous_inputs
tf_model = tf.keras.Model(inputs=inputs, outputs=x)

PyTorch

Using FastAI’s TabularModel, we can build an MLP under the hood by defining its high-level characteristics.

[15]:
emb_szs = list(EMBEDDING_TABLE_SHAPES.values())
pt_model = TabularModel(
    emb_szs=emb_szs,
    n_cont=len(CONTINUOUS_COLUMNS),
    out_sz=1,
    layers=HIDDEN_DIMS,
    ps=DROPOUT_RATES,
    use_bn=True,
    emb_drop=EMBEDDING_DROPOUT_RATE
)
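Both models concatenate the embedding outputs with the continuous features before the MLP, so the first hidden layer’s input width is the sum of the embedding widths plus len(CONTINUOUS_COLUMNS), which is 16. The sketch below counts the weights and biases of the fully connected layers only (embeddings and BatchNorm are excluded); the 200-dimensional embedding total is a hypothetical figure, whereas the real value would be sum(dim for _, dim in EMBEDDING_TABLE_SHAPES.values()).

```python
def dense_param_count(input_dim, hidden_dims, out_dim=1):
    """Weights plus biases of a stack of fully connected layers."""
    dims = [input_dim] + list(hidden_dims) + [out_dim]
    return sum(d_in * d_out + d_out for d_in, d_out in zip(dims, dims[1:]))

# Hypothetical: embeddings totalling 200 dims, plus the 16 continuous features
input_dim = 200 + 16
total = dense_param_count(input_dim, [1000, 500])  # same as HIDDEN_DIMS above
# 216*1000 + 1000 = 217,000; 1000*500 + 500 = 500,500; 500*1 + 1 = 501
```

The first two Dense layers dominate the count, so changing HIDDEN_DIMS affects model size far more than the output layer does.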

Define optimizer and train

This is probably the part that is most conceptually consistent across the frameworks: we’ll define an objective and a method for optimizing it, then fit our model to our dataset iterators using that optimization scheme.

For both frameworks, we’ll build a quick implementation of the metric Kaggle used in the original competition so that we can keep tabs on it during training.
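The metric is root mean squared percentage error, sqrt(mean(((y - y_hat) / y)^2)), computed over rows whose true sales are nonzero. Because the model works in normalized log space, both framework versions below first undo Normalize and then log(1 + x). Here is a framework-free sketch of the same computation, with scalar mean and std standing in for the proc.stats entries.

```python
import math

def invert_sales_transform(z, mean, std):
    """Undo Normalize (z * std + mean), then undo log(1 + x) with expm1."""
    return math.expm1(z * std + mean)

def rmspe(y_true, y_pred):
    """Root mean squared percentage error, skipping rows whose true value is 0."""
    pairs = [(t, p) for t, p in zip(y_true, y_pred) if t != 0]
    if not pairs:
        raise ValueError("RMSPE is undefined when every true value is zero")
    return math.sqrt(sum(((t - p) / t) ** 2 for t, p in pairs) / len(pairs))

# Two 10% misses; the zero-sales row is ignored, so the score is 0.1
score = rmspe([100.0, 200.0, 0.0], [110.0, 180.0, 5.0])
```

Dividing by the true value makes a miss of 20 on a 200-sales day count the same as a miss of 10 on a 100-sales day.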

TensorFlow

[16]:
def rmspe_tf(y_true, y_pred):
    # map back into "true" space by undoing transform
    y_true = y_true*proc.stats['stds']['Sales'] + proc.stats['means']['Sales']
    y_pred = y_pred*proc.stats['stds']['Sales'] + proc.stats['means']['Sales']

    # and then the log(1+x)
    y_true = tf.exp(y_true) - 1
    y_pred = tf.exp(y_pred) - 1

    # drop zeroes for stability (and consistency with Kaggle)
    where = tf.not_equal(y_true, 0.)
    y_true = y_true[where]
    y_pred = y_pred[where]

    percent_error = (y_true - y_pred) / y_true
    return tf.sqrt(tf.reduce_mean(percent_error**2))

optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)
tf_model.compile(optimizer, 'mse', metrics=[rmspe_tf])
history = tf_model.fit(
    train_dataset_tf,
    validation_data=valid_dataset_tf,
    epochs=EPOCHS
)
Train for 10 steps, validate for 1 steps
Epoch 1/10
10/10 [==============================] - 8s 832ms/step - loss: 3.9479 - rmspe_tf: 2.3827 - val_loss: 2.0348 - val_rmspe_tf: 0.8523
Epoch 2/10
10/10 [==============================] - 2s 160ms/step - loss: 1.3154 - rmspe_tf: 0.6430 - val_loss: 1.1368 - val_rmspe_tf: 0.5126
Epoch 3/10
10/10 [==============================] - 2s 157ms/step - loss: 0.8508 - rmspe_tf: 0.5018 - val_loss: 0.8112 - val_rmspe_tf: 0.4539
Epoch 4/10
10/10 [==============================] - 2s 158ms/step - loss: 0.6647 - rmspe_tf: 0.4183 - val_loss: 0.6579 - val_rmspe_tf: 0.4051
Epoch 5/10
10/10 [==============================] - 2s 158ms/step - loss: 0.5661 - rmspe_tf: 0.4015 - val_loss: 0.5817 - val_rmspe_tf: 0.3725
Epoch 6/10
10/10 [==============================] - 2s 153ms/step - loss: 0.4897 - rmspe_tf: 0.3520 - val_loss: 0.5235 - val_rmspe_tf: 0.3531
Epoch 7/10
10/10 [==============================] - 2s 156ms/step - loss: 0.4248 - rmspe_tf: 0.3577 - val_loss: 0.4651 - val_rmspe_tf: 0.3333
Epoch 8/10
10/10 [==============================] - 2s 153ms/step - loss: 0.3668 - rmspe_tf: 0.3019 - val_loss: 0.4158 - val_rmspe_tf: 0.3108
Epoch 9/10
10/10 [==============================] - 2s 153ms/step - loss: 0.3192 - rmspe_tf: 0.3110 - val_loss: 0.3742 - val_rmspe_tf: 0.2967
Epoch 10/10
10/10 [==============================] - 2s 156ms/step - loss: 0.2801 - rmspe_tf: 0.2737 - val_loss: 0.3424 - val_rmspe_tf: 0.2840
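The header line "Train for 10 steps, validate for 1 steps" also bounds the size of each split. The sketch below does that arithmetic, assuming a final partial batch counts as a step (that is, the step count is a ceiling division).

```python
import math

BATCH_SIZE = 65536

def steps_per_epoch(n_rows, batch_size):
    return math.ceil(n_rows / batch_size)

def row_range_for_steps(steps, batch_size):
    """Smallest and largest row counts that produce exactly `steps` batches."""
    return (steps - 1) * batch_size + 1, steps * batch_size

train_lo, train_hi = row_range_for_steps(10, BATCH_SIZE)     # (589825, 655360)
valid_lo, valid_hi = row_range_for_steps(1, BATCH_SIZE * 4)  # (1, 262144)
```

With batches this large, each epoch is only a handful of optimizer steps, which is consistent with the short per-epoch times in the log.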

PyTorch

[17]:
def rmspe_pt(y_pred, y_true):
    # fastai v1 calls metrics as metric(predictions, targets), so predictions come first
    # flatten both so (batch, 1) and (batch,) shapes line up elementwise
    y_pred = y_pred.reshape(-1)
    y_true = y_true.reshape(-1)

    # map back into "true" space by undoing transform
    y_true = y_true*proc.stats['stds']['Sales'] + proc.stats['means']['Sales']
    y_pred = y_pred*proc.stats['stds']['Sales'] + proc.stats['means']['Sales']

    # and then the log(1+x)
    y_true = torch.exp(y_true) - 1
    y_pred = torch.exp(y_pred) - 1

    # drop zeroes for stability (and consistency with Kaggle)
    where = y_true != 0.
    y_true = y_true[where]
    y_pred = y_pred[where]

    percent_error = (y_true - y_pred) / y_true
    return torch.sqrt((percent_error**2).mean())

learn = Learner(databunch, pt_model, metrics=[rmspe_pt], wd=0.)
learn.loss_func = MSELossFlat()
learn.fit(EPOCHS, LEARNING_RATE)
epoch  train_loss  valid_loss  rmspe_pt  time
0      3.554359    0.975809    0.533289  00:01
1      2.273754    0.890281    0.506510  00:01
2      1.679724    0.748983    0.425627  00:01
3      1.315282    0.568450    0.379923  00:01
4      1.042458    0.349413    0.277927  00:01
5      0.843311    0.299327    0.261619  00:01
6      0.695372    0.277244    0.243338  00:01
7      0.585021    0.265418    0.240927  00:01
8      0.501111    0.258835    0.237369  00:01
9      0.436057    0.251688    0.232957  00:01