# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================================================
# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
ETL with NVTabular
In this notebook we are going to generate synthetic data and then create sequential features with NVTabular. Such data will be used in the next notebook to train a session-based recommendation model.
NVTabular is a feature engineering and preprocessing library for tabular data, designed to quickly and easily manipulate terabyte-scale datasets used to train deep-learning-based recommender systems. It provides a high-level abstraction to simplify code and accelerates computation on the GPU using the RAPIDS cuDF library.
Import required libraries
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import glob
import cudf
import numpy as np
import pandas as pd
import nvtabular as nvt
from nvtabular.ops import *
from merlin.schema.tags import Tags
/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'
warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Define Input/Output Path
# Root directory for the generated data; the INPUT_DATA_DIR environment
# variable overrides the default location.
INPUT_DATA_DIR = os.getenv("INPUT_DATA_DIR", "/workspace/data/")
Create Synthetic Input Data
# Number of synthetic interaction rows to generate. Environment variables are
# always strings, so convert once here: NUM_ROWS is then an int whether or not
# the variable is set, and no per-use int() casts are needed.
NUM_ROWS = int(os.environ.get("NUM_ROWS", 100000))

# Item ids are drawn from a log-normal distribution (long tail), truncated to
# integers and clipped to the range [1, 50000].
long_tailed_item_distribution = np.clip(
    np.random.lognormal(3.0, 1.0, NUM_ROWS).astype(np.int32), 1, 50000
)

# generate random item interaction features
df = pd.DataFrame(np.random.randint(70000, 90000, NUM_ROWS), columns=["session_id"])
df["item_id"] = long_tailed_item_distribution

# generate category mapping for each item-id:
# item ids are cut into 334 equal-width bins labelled 1..334
df["category"] = pd.cut(df["item_id"], bins=334, labels=np.arange(1, 335)).astype(np.int32)
df["age_days"] = np.random.uniform(0, 1, NUM_ROWS).astype(np.float32)
df["weekday_sin"] = np.random.uniform(0, 1, NUM_ROWS).astype(np.float32)

# generate day mapping for each session:
# every distinct session id gets one random day in 1..9, shared by all its rows
unique_sessions = df.session_id.unique()
map_day = dict(zip(unique_sessions, np.random.randint(1, 10, size=len(unique_sessions))))
df["day"] = df.session_id.map(map_day)
Visualize a couple of rows of the synthetic dataset:
# Preview the first rows of the synthetic interaction DataFrame.
df.head()
session_id | item_id | category | age_days | weekday_sin | day | |
---|---|---|---|---|---|---|
0 | 88348 | 28 | 7 | 0.416052 | 0.116508 | 1 |
1 | 86615 | 6 | 2 | 0.998783 | 0.539034 | 6 |
2 | 85161 | 14 | 4 | 0.975656 | 0.246331 | 3 |
3 | 75889 | 61 | 16 | 0.329182 | 0.033715 | 9 |
4 | 75396 | 29 | 8 | 0.219127 | 0.993250 | 7 |
Feature Engineering with NVTabular
Deep Learning models require dense input features. Categorical features are sparse, and need to be represented by dense embeddings in the model. To allow for that, categorical features first need to be encoded as contiguous integers $(0, \ldots, |C|)$, where $|C|$ is the feature cardinality (number of unique values), so that their embeddings can be efficiently stored in embedding layers. We will use NVTabular to preprocess the categorical features, so that all categorical columns are encoded as contiguous integers. Note that the `Categorify` op automatically encodes nulls to `1` and OOVs (out-of-vocabulary values) to `2`. We reserve `0` for padding the sequences in the input block, and the encoding of the other categories starts from `3`. Our synthetic dataset does not contain any nulls.
Here our goal is to create sequential features. To do so, we are grouping the features together at the session level in the following cell. In this synthetically generated example dataset, we do not have a timestamp column, but if we had one (that’s the case for most real-world datasets), we would be sorting the interactions by the timestamp column as in this example notebook. Note that we also trim each feature sequence in a session to a certain length. Here, we use the NVTabular library so that we can easily preprocess and create features on GPU with a few lines.
# Upper bound on the number of interactions kept per session.
SESSIONS_MAX_LENGTH = 20
# Sessions with fewer interactions than this are dropped: a single-item
# session is not valid for next-item prediction training and evaluation.
MINIMUM_SESSION_LENGTH = 2

# Encode the categorical columns as contiguous integer ids.
categ_feats = ["item_id", "category"] >> nvt.ops.Categorify()

# Columns that feed the session-level aggregation.
groupby_feats = categ_feats + ["session_id", "day", "age_days", "weekday_sin"]

# Per-column aggregations applied within each session. Output columns are
# named "<column>-<aggregation>" because of name_sep="-".
session_aggs = {
    "item_id": ["list", "count"],
    "category": ["list"],
    "day": ["first"],
    "age_days": ["list"],
    "weekday_sin": ["list"],
}
groupby_features = groupby_feats >> nvt.ops.Groupby(
    groupby_cols=["session_id"],
    aggs=session_aggs,
    name_sep="-",
)

# Trim every list feature with a negative slice start so that at most
# SESSIONS_MAX_LENGTH entries per session remain.
sequence_features_truncated = (
    groupby_features["category-list"]
    >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH)
)
sequence_features_truncated_item = (
    groupby_features["item_id-list"]
    >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH)
    >> nvt.ops.TagAsItemID()
)
sequence_features_truncated_cont = (
    groupby_features["age_days-list", "weekday_sin-list"]
    >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH)
    >> nvt.ops.AddMetadata(tags=[Tags.CONTINUOUS])
)

# Gather the per-session scalar columns together with the trimmed sequences.
selected_features = (
    groupby_features["item_id-count", "day-first", "session_id"]
    + sequence_features_truncated_item
    + sequence_features_truncated
    + sequence_features_truncated_cont
)

# Keep only the sessions that have enough interactions.
filtered_sessions = selected_features >> nvt.ops.Filter(
    f=lambda sessions: sessions["item_id-count"] >= MINIMUM_SESSION_LENGTH
)

# ValueCount adds the observed min/max list lengths to the output schema.
seq_feats_list = (
    filtered_sessions["item_id-list", "category-list", "age_days-list", "weekday_sin-list"]
    >> nvt.ops.ValueCount()
)

workflow = nvt.Workflow(filtered_sessions["session_id", "day-first"] + seq_feats_list)
dataset = nvt.Dataset(df)

# Fit the workflow statistics, transform the data and export parquet files;
# this step also generates the schema file.
processed_dataset = workflow.fit_transform(dataset)
processed_dataset.to_parquet(os.path.join(INPUT_DATA_DIR, "processed_nvt"))
It is possible to save the preprocessing workflow. That is useful to apply the same preprocessing to other data (with the same schema) and also to deploy the session-based recommendation pipeline to Triton Inference Server.
# Inspect the schema of the workflow output: column names, tags, dtypes and properties.
workflow.output_schema
name | tags | dtype | is_list | is_ragged | properties.num_buckets | properties.freq_threshold | properties.max_size | properties.cat_path | properties.domain.min | properties.domain.max | properties.domain.name | properties.embedding_sizes.cardinality | properties.embedding_sizes.dimension | properties.value_count.min | properties.value_count.max | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | session_id | () | DType(name='int64', element_type=<ElementType.... | False | False | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
1 | day-first | () | DType(name='int64', element_type=<ElementType.... | False | False | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2 | item_id-list | (Tags.CATEGORICAL, Tags.ID, Tags.LIST, Tags.ITEM) | DType(name='int64', element_type=<ElementType.... | True | True | NaN | 0.0 | 0.0 | .//categories/unique.item_id.parquet | 0.0 | 494.0 | item_id | 495.0 | 52.0 | 2.0 | 16.0 |
3 | category-list | (Tags.CATEGORICAL, Tags.LIST) | DType(name='int64', element_type=<ElementType.... | True | True | NaN | 0.0 | 0.0 | .//categories/unique.category.parquet | 0.0 | 171.0 | category | 172.0 | 29.0 | 2.0 | 16.0 |
4 | age_days-list | (Tags.CONTINUOUS, Tags.LIST) | DType(name='float32', element_type=<ElementTyp... | True | True | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 2.0 | 16.0 |
5 | weekday_sin-list | (Tags.CONTINUOUS, Tags.LIST) | DType(name='float32', element_type=<ElementTyp... | True | True | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 2.0 | 16.0 |
Save NVTabular workflow.
# Persist the fitted workflow in the "workflow_etl" folder under INPUT_DATA_DIR.
workflow_dir = os.path.join(INPUT_DATA_DIR, "workflow_etl")
workflow.save(workflow_dir)
Export pre-processed data by day
In this example we are going to split the preprocessed parquet files by day, to allow for temporal training and evaluation. There will be a folder for each day and three parquet files within each day folder: `train.parquet`, `validation.parquet` and `test.parquet`.
# Destination for the per-day splits; the OUTPUT_DIR environment variable
# overrides the default folder under INPUT_DATA_DIR.
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.join(INPUT_DATA_DIR, "sessions_by_day"))

# Read in the processed parquet file(s). The export may be split across
# several part_*.parquet files (e.g. for a larger NUM_ROWS), so collect all
# of them instead of assuming everything is in part_0.parquet.
processed_dir = os.path.join(INPUT_DATA_DIR, "processed_nvt")
processed_files = sorted(glob.glob(os.path.join(glob.escape(processed_dir), "part_*.parquet")))
if not processed_files:
    raise FileNotFoundError(f"No part_*.parquet files found in {processed_dir}")
sessions_gdf = cudf.read_parquet(processed_files)
print(sessions_gdf.head(3))
session_id day-first item_id-list category-list \
0 70000 1 [306, 5, 40, 17] [104, 3, 12, 6]
1 70001 1 [43, 20, 69, 8, 57] [13, 6, 21, 3, 16]
2 70002 1 [137, 35, 37, 85, 65, 5] [37, 10, 11, 22, 18, 3]
age_days-list \
0 [0.044022594, 0.34956282, 0.7326993, 0.09403495]
1 [0.8072543, 0.28916782, 0.04966254, 0.08417622...
2 [0.04696693, 0.94499177, 0.2922437, 0.83047426...
weekday_sin-list
0 [0.7417527, 0.60325843, 0.07417604, 0.28911334]
1 [0.7995051, 0.86722755, 0.84298295, 0.15793765...
2 [0.72519076, 0.92308444, 0.40120387, 0.3821016...
from transformers4rec.utils.data_utils import save_time_based_splits

# Export the sessions into OUTPUT_DIR, partitioned on the 'day-first' column.
# NOTE(review): the synthetic data has no real timestamp, so 'session_id' is
# passed as the timestamp column — confirm how the helper uses it.
sessions_dataset = nvt.Dataset(sessions_gdf)
save_time_based_splits(
    data=sessions_dataset,
    output_dir=OUTPUT_DIR,
    partition_col="day-first",
    timestamp_col="session_id",
)
Creating time-based splits: 100%|██████████| 9/9 [00:02<00:00, 4.12it/s]
Check out the preprocessed outputs
# Load the training split written for day "1" back with pandas and preview it.
day_one_dir = os.path.join(OUTPUT_DIR, "1")
TRAIN_PATHS = os.path.join(day_one_dir, "train.parquet")
df = pd.read_parquet(TRAIN_PATHS)
df.head()
session_id | item_id-list | category-list | age_days-list | weekday_sin-list | |
---|---|---|---|---|---|
0 | 70000 | [306, 5, 40, 17] | [104, 3, 12, 6] | [0.044022594, 0.34956282, 0.7326993, 0.09403495] | [0.7417527, 0.60325843, 0.07417604, 0.28911334] |
1 | 70001 | [43, 20, 69, 8, 57] | [13, 6, 21, 3, 16] | [0.8072543, 0.28916782, 0.04966254, 0.08417622... | [0.7995051, 0.86722755, 0.84298295, 0.15793765... |
2 | 70002 | [137, 35, 37, 85, 65, 5] | [37, 10, 11, 22, 18, 3] | [0.04696693, 0.94499177, 0.2922437, 0.83047426... | [0.72519076, 0.92308444, 0.40120387, 0.3821016... |
4 | 70007 | [28, 9, 153, 74, 53, 15, 173] | [9, 4, 39, 20, 15, 5, 46] | [0.4730765, 0.69885534, 0.034774363, 0.7225920... | [0.33613566, 0.660022, 0.72897774, 0.66087157,... |
5 | 70021 | [59, 32, 11, 21, 23, 23, 9, 15] | [17, 10, 7, 7, 8, 8, 4, 5] | [0.07898139, 0.27463168, 0.1885847, 0.5203435,... | [0.39734098, 0.74895114, 0.43540764, 0.8372503... |
import gc
# Drop the preview DataFrame and run the garbage collector so its memory can be reclaimed.
del df
gc.collect()
512
You have just created session-level features to train a session-based recommendation model using NVTabular. Now you can move on to the next notebook, `02-session-based-XLNet-with-PyT.ipynb`, to train a session-based recommendation model using XLNet, one of the state-of-the-art NLP models. Please shut down this kernel to free the GPU memory before you start the next one.