# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======
https://developer.download.nvidia.com/notebooks/dlsw-notebooks/merlin_transformers4rec_getting-started-session-based-01-etl-with-nvtabular/nvidia_logo.png

ETL with NVTabular

In this notebook, we generate synthetic data and then create sequential features with NVTabular. This data is used in the next notebook to train a session-based recommendation model.

NVTabular is a feature engineering and preprocessing library for tabular data, designed to quickly and easily manipulate terabyte-scale datasets used to train deep-learning-based recommender systems. It provides a high-level abstraction to simplify code and accelerates computation on the GPU using the RAPIDS cuDF library.

Import required libraries

import os
import glob

import numpy as np
import pandas as pd

import cudf
import cupy as cp
import nvtabular as nvt
from nvtabular.ops import *
from merlin.schema.tags import Tags

Define Input/Output Path

# Root folder for all generated data; override it with the INPUT_DATA_DIR environment variable.
INPUT_DATA_DIR = os.getenv("INPUT_DATA_DIR", "/workspace/data/")

Create Synthetic Input Data

NUM_ROWS = 100000
# Fixed seed so the synthetic data -- and every output derived from it (the
# processed sessions, the per-day splits, trained models) -- is reproducible
# across runs. A local Generator is used so the global NumPy RNG is untouched.
SEED = 42
rng = np.random.default_rng(SEED)

# Long-tailed (log-normal) item popularity, truncated to int32 and clipped to [1, 50000].
long_tailed_item_distribution = np.clip(
    rng.lognormal(3.0, 1.0, NUM_ROWS).astype(np.int32), 1, 50000
)

# generate random item interaction features (session ids drawn from [70000, 90000))
df = pd.DataFrame(rng.integers(70000, 90000, NUM_ROWS), columns=['session_id'])
df['item_id'] = long_tailed_item_distribution

# generate category mapping for each item-id: 334 equal-width bins labelled 1..334
df['category'] = pd.cut(df['item_id'], bins=334, labels=np.arange(1, 335)).astype(np.int32)
df['age_days'] = rng.uniform(0, 1, NUM_ROWS).astype(np.float32)
df['weekday_sin'] = rng.uniform(0, 1, NUM_ROWS).astype(np.float32)

# generate day mapping for each session: every session gets a single day in [1, 9]
map_day = dict(zip(df.session_id.unique(), rng.integers(1, 10, size=df.session_id.nunique())))
df['day'] = df.session_id.map(map_day)

Visualize a couple of rows of the synthetic dataset:

# Show the first five rows of the synthetic interactions.
df.head(n=5)
session_id item_id category age_days weekday_sin day
0 84344 5 2 0.197794 0.220711 1
1 79183 26 7 0.659679 0.554893 2
2 76110 7 2 0.545001 0.476261 5
3 86269 78 21 0.231765 0.040279 2
4 73974 90 24 0.321135 0.082030 5

Feature Engineering with NVTabular

Deep learning models require dense input features. Categorical features are sparse and need to be represented by dense embeddings in the model. To allow for that, categorical features first need to be encoded as contiguous integers (0, ..., |C|), where |C| is the feature cardinality (number of unique values), so that their embeddings can be efficiently stored in embedding layers. We use NVTabular to preprocess the categorical features so that all categorical columns are encoded as contiguous integers. Note that the Categorify op automatically encodes OOV (out-of-vocabulary) and null values as 0. Our synthetic dataset does not contain any nulls. However, 0 is also used to pad the sequences in the input block. If you want to reserve 0 exclusively for padding the sequence features, you can set the start_index=1 argument in the Categorify op so that the encoded null or OOV values start from 1 instead of 0.

Here our goal is to create sequential features. In the following cell, we group the interaction features at the session level. Because the synthetic data has no timestamp column, the Groupby op is not given any sort columns, so the interactions are not explicitly sorted by time. Note that we also trim each feature sequence in a session, keeping only its last SESSIONS_MAX_LENGTH interactions. We use the NVTabular library so that we can easily preprocess and create features on the GPU with a few lines of code.

# Maximum number of interactions kept in each session sequence.
SESSIONS_MAX_LENGTH = 20

# Encode the categorical columns as contiguous integer ids
# (Categorify maps null / out-of-vocabulary values to 0).
categ_feats = ['session_id', 'item_id', 'category'] >> nvt.ops.Categorify()

# Encoded categoricals plus the raw columns that are aggregated per session.
groupby_feats = categ_feats + ['day', 'age_days', 'weekday_sin']

# Per-session aggregations. With name_sep="-" every output column is named
# "<input column>-<aggregation>", e.g. "item_id-list" or "day-first".
session_aggs = {
    "item_id": ["list", "count"],
    "category": ["list"],
    "day": ["first"],
    "age_days": ["list"],
    "weekday_sin": ["list"],
}
groupby_features = groupby_feats >> nvt.ops.Groupby(
    groupby_cols=["session_id"],
    aggs=session_aggs,
    name_sep="-",
)


def _keep_last(columns):
    """Truncate each list column in *columns* to its last SESSIONS_MAX_LENGTH items.

    A negative start index in ``ListSlice`` mirrors Python's ``seq[-n:]``.
    A new op is built on every call, so no operator instance is shared
    between branches of the graph.
    """
    return columns >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH)


# Truncated category sequences; ValueCount records list-length statistics in the schema.
sequence_features_truncated = (
    _keep_last(groupby_features['category-list']) >> nvt.ops.ValueCount()
)

# Truncated item-id sequences, tagged as the item-id feature.
sequence_features_truncated_item = (
    _keep_last(groupby_features['item_id-list'])
    >> nvt.ops.TagAsItemID()
    >> nvt.ops.ValueCount()
)

# Truncated continuous-valued sequences, tagged as continuous features.
sequence_features_truncated_cont = (
    _keep_last(groupby_features['age_days-list', 'weekday_sin-list'])
    >> nvt.ops.AddMetadata(tags=[Tags.CONTINUOUS])
    >> nvt.ops.ValueCount()
)

# Sessions with a single interaction are not usable for next-item prediction
# (training or evaluation), so keep only sessions with at least this many.
MINIMUM_SESSION_LENGTH = 2

# Scalar session-level columns followed by the truncated sequence columns.
selected_features = (
    groupby_features['item_id-count', 'day-first', 'session_id']
    + sequence_features_truncated_item
    + sequence_features_truncated
    + sequence_features_truncated_cont
)

# "item_id-count" is aggregated before the sequences are truncated, so this
# filters on the full (untruncated) session length.
filtered_sessions = selected_features >> nvt.ops.Filter(
    f=lambda sessions: sessions["item_id-count"] >= MINIMUM_SESSION_LENGTH
)

# Re-run ValueCount after the Filter so the list-length statistics in the
# output schema describe only the sessions that were kept.
seq_feats_list = (
    filtered_sessions['item_id-list', 'category-list', 'age_days-list', 'weekday_sin-list']
    >> nvt.ops.ValueCount()
)

session_level_cols = filtered_sessions['session_id', 'day-first', 'item_id-count']
workflow = nvt.Workflow(session_level_cols + seq_feats_list)

# Wrap the pandas frame in an NVTabular Dataset (cpu=False: process on the GPU).
dataset = nvt.Dataset(df, cpu=False)

# fit_transform() fits the workflow (computing the statistics its ops need,
# e.g. the Categorify mappings) and then applies it to the same dataset.
sessions_ds = workflow.fit_transform(dataset)

# Materialize the result: NVTabular Dataset -> Dask-cuDF (to_ddf) -> cuDF (compute).
sessions_gdf = sessions_ds.to_ddf().compute()
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
# Show the first three sessions produced by the workflow.
sessions_gdf.head(n=3)
session_id day-first item_id-count item_id-list category-list age_days-list weekday_sin-list
0 1 1 18 [4, 11, 17, 29, 18, 1, 3, 48, 11, 9, 16, 5, 19... [1, 3, 6, 8, 5, 1, 1, 13, 3, 2, 6, 2, 4, 52, 1... [0.14561038, 0.9393455, 0.012047833, 0.658193,... [0.23078364, 0.99029666, 0.89728844, 0.9642181...
1 2 4 15 [97, 7, 44, 24, 31, 23, 41, 245, 11, 3, 28, 11... [28, 2, 12, 7, 9, 7, 10, 61, 3, 1, 8, 3, 18, 2... [0.54006344, 0.71162707, 0.2320292, 0.49496385... [0.72449577, 0.35770282, 0.13853826, 0.0450636...
2 3 9 15 [5, 27, 26, 111, 97, 50, 3, 4, 7, 31, 29, 23, ... [2, 8, 7, 27, 28, 13, 1, 1, 2, 9, 8, 7, 2, 11, 1] [0.14291424, 0.11157788, 0.7810709, 0.11342292... [0.35786915, 0.467376, 0.34360662, 0.50400823,...
# Show the column dtypes of the session-level frame; the "-list" columns hold the sequences.
sessions_gdf.dtypes
session_id          int64
day-first           int64
item_id-count       int32
item_id-list         list
category-list        list
age_days-list        list
weekday_sin-list     list
dtype: object

It is possible to save the preprocessing workflow. This is useful for applying the same preprocessing to other data (with the same schema) and for deploying the session-based recommendation pipeline to Triton Inference Server.

# Inspect the output schema (column names, tags, dtypes and properties) of the fitted workflow.
workflow.output_schema
name tags dtype is_list is_ragged properties.num_buckets properties.freq_threshold properties.max_size properties.start_index properties.cat_path properties.domain.min properties.domain.max properties.domain.name properties.embedding_sizes.cardinality properties.embedding_sizes.dimension properties.value_count.min properties.value_count.max
0 session_id (Tags.CATEGORICAL) int64 False False NaN 0.0 0.0 0.0 .//categories/unique.session_id.parquet 0.0 19875.0 session_id 19876.0 408.0 NaN NaN
1 day-first () int64 False False NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2 item_id-count (Tags.CATEGORICAL) int32 False False NaN 0.0 0.0 0.0 .//categories/unique.item_id.parquet 0.0 489.0 item_id 490.0 51.0 NaN NaN
3 item_id-list (Tags.ITEM, Tags.ITEM_ID, Tags.ID, Tags.LIST, ... int64 True True NaN 0.0 0.0 0.0 .//categories/unique.item_id.parquet 0.0 489.0 item_id 490.0 51.0 2.0 18.0
4 category-list (Tags.LIST, Tags.CATEGORICAL) int64 True True NaN 0.0 0.0 0.0 .//categories/unique.category.parquet 0.0 176.0 category 177.0 29.0 2.0 18.0
5 age_days-list (Tags.LIST, Tags.CONTINUOUS) float32 True True NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 2.0 18.0
6 weekday_sin-list (Tags.LIST, Tags.CONTINUOUS) float32 True True NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 2.0 18.0

The following cell writes the transformed data as Parquet files to the provided folder and also generates a schema.pbtxt file there.

# The workflow was already fit on `dataset` above, so only transform here.
# Calling fit_transform again would recompute every statistic (e.g. the
# Categorify mappings) for no benefit, and any change in them would no longer
# match the `sessions_gdf` that is used for the per-day splits below.
workflow.transform(dataset).to_parquet(os.path.join(INPUT_DATA_DIR, "processed_nvt"))
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
# Persist the fitted workflow to the directory "workflow_etl" (relative to the
# current working directory) so the same preprocessing can be reapplied to new
# data or deployed for serving.
workflow.save('workflow_etl')

Export pre-processed data by day

In this example we are going to split the preprocessed parquet files by days, to allow for temporal training and evaluation. There will be a folder for each day and three parquet files within each day folder: train.parquet, validation.parquet and test.parquet.

OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.join(INPUT_DATA_DIR, "sessions_by_day"))
# Create the output folder from Python instead of the IPython-only
# `!mkdir -p $OUTPUT_DIR` shell escape: this also runs as a plain Python script
# and is not affected by shell word-splitting of paths that contain spaces.
os.makedirs(OUTPUT_DIR, exist_ok=True)

from transformers4rec.data.preprocessing import save_time_based_splits

# One folder per 'day-first' value, each with train/validation/test parquet files.
# The synthetic data has no timestamp column, so 'session_id' stands in for it.
save_time_based_splits(
    data=nvt.Dataset(sessions_gdf),
    output_dir=OUTPUT_DIR,
    partition_col='day-first',
    timestamp_col='session_id',
)
Creating time-based splits: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 22.48it/s]

Checking the preprocessed outputs

# Read back the training split written for day 1 to inspect the exported data.
day_one_train = os.path.join(OUTPUT_DIR, "1", "train.parquet")
TRAIN_PATHS = sorted(glob.glob(day_one_train))
gdf = cudf.read_parquet(TRAIN_PATHS[0])
gdf
session_id item_id-count item_id-list category-list age_days-list weekday_sin-list
0 1 18 [4, 11, 17, 29, 18, 1, 3, 48, 11, 9, 16, 5, 19... [1, 3, 6, 8, 5, 1, 1, 13, 3, 2, 6, 2, 4, 52, 1... [0.14561038, 0.9393455, 0.012047833, 0.658193,... [0.23078364, 0.99029666, 0.89728844, 0.9642181...
1 4 15 [36, 49, 9, 95, 12, 26, 35, 185, 43, 14, 19, 2... [11, 13, 2, 24, 3, 7, 9, 55, 12, 3, 4, 7, 1, 1... [0.4289175, 0.41714236, 0.6593241, 0.7470034, ... [0.5122762, 0.11083387, 0.26527187, 0.77329, 0...
2 30 13 [29, 12, 18, 20, 46, 77, 7, 7, 26, 21, 111, 2, 1] [8, 3, 5, 5, 12, 20, 2, 2, 7, 5, 27, 1, 1] [0.079172395, 0.26267487, 0.9678789, 0.601294,... [0.1440753, 0.5550622, 0.18317387, 0.06565472,...
4 63 12 [5, 44, 10, 5, 26, 6, 193, 11, 13, 9, 10, 60] [2, 12, 3, 2, 7, 2, 50, 3, 4, 2, 3, 15] [0.7659222, 0.9388312, 0.28288805, 0.75763357,... [0.99894804, 0.038836945, 0.85671306, 0.345418...
5 75 12 [48, 21, 5, 40, 4, 182, 36, 39, 54, 37, 8, 116] [13, 5, 2, 10, 1, 48, 11, 10, 14, 11, 4, 31] [0.4896862, 0.7550025, 0.92395943, 0.4152636, ... [0.08632153, 0.82823294, 0.50390047, 0.4975271...
... ... ... ... ... ... ...
2111 19151 2 [24, 19] [7, 4] [0.3092607, 0.25387767] [0.6523481, 0.059806556]
2112 19173 2 [60, 37] [15, 11] [0.82798934, 0.054636054] [0.84105706, 0.52476853]
2113 19188 2 [10, 21] [3, 5] [0.92787683, 0.5812024] [0.13824013, 0.74283314]
2114 19194 2 [4, 158] [1, 41] [0.22679287, 0.024510423] [0.9538698, 0.4295912]
2115 19204 2 [37, 16] [11, 6] [0.5972207, 0.11343666] [0.81323135, 0.46290976]

1687 rows × 6 columns

You have just created session-level features to train a session-based recommendation model using NVTabular. Now you can move to the next notebook, 02-session-based-XLNet-with-PyT.ipynb, to train a session-based recommendation model using XLNet, one of the state-of-the-art NLP models. Please shut down this kernel to free the GPU memory before you start the next one.