# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
ETL with NVTabular
You can execute these tutorial notebooks using the latest stable merlin-pytorch container.
Launch the docker container
docker run -it --gpus device=0 -p 8000:8000 -p 8001:8001 -p 8002:8002 -p 8888:8888 -v <path_to_data>:/workspace/data/ nvcr.io/nvidia/merlin/merlin-pytorch:22.XX
This command mounts your local data folder, which includes your data files, to the /workspace/data directory in the merlin-pytorch Docker container.
1. Introduction
In this notebook, we will create a preprocessing and feature engineering pipeline with the RAPIDS cuDF and Merlin NVTabular libraries to prepare our dataset for session-based recommendation model training.
NVTabular is a feature engineering and preprocessing library for tabular data that is designed to easily manipulate terabyte-scale datasets and train deep learning (DL) based recommender systems. It provides a high-level abstraction to simplify code, accelerates computation on the GPU using the RAPIDS Dask-cuDF library, and is designed to be interoperable with both PyTorch and TensorFlow using dataloaders that have been developed as extensions of native framework code.
Our main goal is to create sequential features. To do that, we are going to perform the following steps:
- Categorify categorical features with the Categorify() op.
- Create temporal features with a user-defined custom op and the Lambda op.
- Transform continuous features using the Log and Normalize ops.
- Group all these features together at the session level, sorting the interactions by time, with the Groupby op.
- Finally, export the preprocessed datasets to Parquet files with hive partitioning.
1.1. Dataset
In our hands-on exercise notebooks we are going to use a subset of the publicly available eCommerce dataset. The eCommerce behavior data contains 7 months of data (from October 2019 to April 2020) from a large multi-category online store. Each row in the file represents an event. All events are related to products and users, and each event is like a many-to-many relation between products and users.
The data was collected by the Open CDP project, and the source of the dataset is the REES46 Marketing Platform.
2. Import Libraries
import os
import numpy as np
import cupy as cp
import glob
import cudf
import nvtabular as nvt
from nvtabular.ops import Operator
from merlin.dag import ColumnSelector
from merlin.schema import Schema, Tags
3. Set up Input and Output Data Paths
# define the path from which to read our data
INPUT_DATA_DIR = os.environ.get("INPUT_DATA_DIR", "/workspace/data/")
4. Read the Input Parquet file
Even though the original dataset contains 7 months of data files, we are going to use the first seven days of the Oct-2019.csv ecommerce dataset. We already performed certain preprocessing steps on the first month (Oct-2019) of the raw dataset in the 01-preprocess notebook:
- We created the event_time_ts column from the event_time column, which shows the time when the event happened (in UTC).
- We created the prod_first_event_time_ts column, which indicates the timestamp at which an item was seen for the first time.
- We removed the rows where user_session is null. As a result, 2 rows were removed.
- We categorified the user_session column, so that it now has only integer values.
- We removed consecutively repeated (user, item) interactions. For example, an original session with [1, 2, 4, 1, 2, 2, 3, 3, 3] product interactions has become [1, 2, 4, 1, 2, 3] after removing the repeated interactions on the same item within the same session.
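To make the last rule concrete, here is a minimal plain-Python sketch that collapses consecutive repeats within one session's item list using `itertools.groupby`. The helper name `remove_consecutive_repeats` is hypothetical; it only illustrates the rule and is not the code used in the 01-preprocess notebook.

```python
from itertools import groupby

def remove_consecutive_repeats(session):
    """Collapse each run of the same item into a single interaction.

    Non-adjacent repeats (e.g. item 1 reappearing later) are kept.
    """
    # groupby yields one (key, run) pair per run of equal adjacent items
    return [item for item, _ in groupby(session)]

original = [1, 2, 4, 1, 2, 2, 3, 3, 3]
print(remove_consecutive_repeats(original))  # [1, 2, 4, 1, 2, 3]
```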
Below, we start by reading in Oct-2019.parquet with cuDF. To create and save the Oct-2019.parquet file, please run the 01-preprocess.ipynb notebook first.
%%time
df = cudf.read_parquet(os.path.join(INPUT_DATA_DIR, 'Oct-2019.parquet'))
df.head(5)
CPU times: user 585 ms, sys: 239 ms, total: 824 ms
Wall time: 840 ms
| | user_session | event_type | product_id | category_id | category_code | brand | price | user_id | event_time_ts | prod_first_event_time_ts |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 43 | view | 5300797 | 2053013563173241677 | <NA> | panasonic | 39.90 | 513903572 | 1570460611 | 1569948287 |
| 1 | 43 | view | 5300798 | 2053013563173241677 | <NA> | panasonic | 32.18 | 513903572 | 1570460616 | 1569934097 |
| 2 | 43 | view | 5300284 | 2053013563173241677 | <NA> | rowenta | 30.86 | 513903572 | 1570460621 | 1569927253 |
| 3 | 43 | view | 5300382 | 2053013563173241677 | <NA> | remington | 28.22 | 513903572 | 1570460636 | 1570026747 |
| 4 | 43 | view | 5300366 | 2053013563173241677 | <NA> | polaris | 26.46 | 513903572 | 1570460650 | 1570097085 |
df.shape
(6390928, 10)
Let’s check if there is any column with nulls.
df.isnull().any()
user_session False
event_type False
product_id False
category_id False
category_code True
brand True
price False
user_id False
event_time_ts False
prod_first_event_time_ts False
dtype: bool
We see that the 'category_code' and 'brand' columns have null values. In the following cell we fill these nulls via the Categorify op, which encodes all categorical columns to contiguous integers. Note that we set start_index=1 in the Categorify op for the categorical columns: we want the encoded null values to start from 1 instead of 0, because we reserve 0 for padding the sequence features.
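The encoding scheme described above can be sketched in plain Python. The helper `encode_with_reserved_padding` is hypothetical: it sends nulls to `start_index` and the remaining categories to the following integers in order of first appearance, so 0 is never assigned. NVTabular's `Categorify` uses its own ordering and bookkeeping, so the exact ids it produces will differ.

```python
def encode_with_reserved_padding(values, start_index=1):
    """Encode categories as contiguous ints, leaving ids below start_index unused.

    Hypothetical helper: nulls (None) get start_index, other categories follow
    in first-appearance order. This ordering is a simplification.
    """
    null_id = start_index
    mapping = {}
    encoded = []
    for v in values:
        if v is None:
            encoded.append(null_id)
            continue
        if v not in mapping:
            # the next free id after the null id and previously seen categories
            mapping[v] = null_id + 1 + len(mapping)
        encoded.append(mapping[v])
    return encoded, mapping

brands = ["panasonic", None, "rowenta", "panasonic", None]
ids, mapping = encode_with_reserved_padding(brands)
print(ids)  # [2, 1, 3, 2, 1]; id 0 stays free for padding
```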
5. Initialize NVTabular Workflow
5.1. Categorical Features Encoding
# categorify features
cat_feats = ['user_session', 'category_code', 'brand', 'user_id', 'product_id', 'category_id', 'event_type'] >> nvt.ops.Categorify(start_index=1)
5.2. Extract Temporal Features
# create time features
session_ts = ['event_time_ts']
session_time = (
session_ts >>
nvt.ops.LambdaOp(lambda col: cudf.to_datetime(col, unit='s')) >>
nvt.ops.Rename(name = 'event_time_dt')
)
sessiontime_weekday = (
session_time >>
nvt.ops.LambdaOp(lambda col: col.dt.weekday) >>
nvt.ops.Rename(name ='et_dayofweek')
)
Now let’s create cyclical features from the sessiontime_weekday column. Temporal features such as hour, day of week, and month have an inherently cyclical character. We represent the day of week as a cyclical feature (sine and cosine) so that it lives in a continuous space in which any two consecutive days are the same distance apart, including the wrap-around from the last day of the week back to the first. In other words, cyclical features convey closeness between values that a plain integer encoding hides (there, day 7 and day 1 would be six units apart).
def get_cycled_feature_value_sin(col, max_value):
    value_scaled = (col + 0.000001) / max_value
    value_sin = np.sin(2*np.pi*value_scaled)
    return value_sin
def get_cycled_feature_value_cos(col, max_value):
    value_scaled = (col + 0.000001) / max_value
    value_cos = np.cos(2*np.pi*value_scaled)
    return value_cos
weekday_sin = sessiontime_weekday >> (lambda col: get_cycled_feature_value_sin(col+1, 7)) >> nvt.ops.Rename(name = 'et_dayofweek_sin')
weekday_cos = sessiontime_weekday >> (lambda col: get_cycled_feature_value_cos(col+1, 7)) >> nvt.ops.Rename(name = 'et_dayofweek_cos')
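The following NumPy sketch checks the property claimed above. After mapping the seven shifted weekday values 1..7 onto the unit circle, every pair of consecutive days, including day 7 back to day 1, is the same Euclidean distance apart. The `cyclical_encode` helper is illustrative and omits the small `0.000001` offset used in the functions above.

```python
import numpy as np

def cyclical_encode(value, max_value):
    """Map a periodic value onto the unit circle as a (sin, cos) pair."""
    angle = 2 * np.pi * value / max_value
    return np.sin(angle), np.cos(angle)

# Days of the week shifted to 1..7, as with `col + 1` above
days = np.arange(1, 8)
sin_vals, cos_vals = cyclical_encode(days, 7)
points = np.stack([sin_vals, cos_vals], axis=1)

# Distance from each day to the next one, wrapping day 7 around to day 1
next_points = np.roll(points, -1, axis=0)
step_dists = np.linalg.norm(points - next_points, axis=1)
print(np.round(step_dists, 4))  # seven equal values, 2*sin(pi/7) ~ 0.8678
```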
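5.2.1. Add Product Recency Feature
Let’s define a custom op to calculate product recency in days, i.e. how many days have passed between the event and the first time the product was seen (prod_first_event_time_ts).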
# Compute item recency: define a custom op
class ItemRecency(nvt.ops.Operator):
    def transform(self, columns, gdf):
        for column in columns.names:
            col = gdf[column]
            item_first_timestamp = gdf['prod_first_event_time_ts']
            # age of the item (in days) at the time of the event
            delta_days = (col - item_first_timestamp) / (60*60*24)
            # clip negative ages to 0
            gdf[column + "_age_days"] = delta_days * (delta_days >=0)
        return gdf
    def compute_selector(
        self,
        input_schema: Schema,
        selector: ColumnSelector,
        parents_selector: ColumnSelector,
        dependencies_selector: ColumnSelector,
    ) -> ColumnSelector:
        # operate on the columns coming from the parent node
        self._validate_matching_cols(input_schema, parents_selector, "computing input selector")
        return parents_selector
    def column_mapping(self, col_selector):
        # each input column produces one "<name>_age_days" output column
        column_mapping = {}
        for col_name in col_selector.names:
            column_mapping[col_name + "_age_days"] = [col_name]
        return column_mapping
    @property
    def dependencies(self):
        # extra input column required by transform()
        return ["prod_first_event_time_ts"]
    @property
    def output_dtype(self):
        return np.float64
recency_features = ['event_time_ts'] >> ItemRecency()
recency_features_norm = recency_features >> nvt.ops.LogOp() >> nvt.ops.Normalize(out_dtype=np.float32) >> nvt.ops.Rename(name='product_recency_days_log_norm')
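As a rough CPU illustration of what this node computes, the pandas sketch below reproduces the `ItemRecency` arithmetic on a few toy rows. It then approximates `LogOp` (which computes log(1 + x)) followed by `Normalize` with `np.log1p` and a plain standardization. The timestamps are made up, and the exact standard-deviation convention used by `Normalize` is not reproduced here.

```python
import numpy as np
import pandas as pd

SECONDS_PER_DAY = 60 * 60 * 24

# Toy rows (Unix seconds); the third event predates the item's first-seen time
df = pd.DataFrame({
    "event_time_ts": [1570460611, 1570460616, 1569900000],
    "prod_first_event_time_ts": [1569948287, 1569934097, 1569934097],
})

# Same arithmetic as ItemRecency.transform: age in days, negatives clipped to 0
delta_days = (df["event_time_ts"] - df["prod_first_event_time_ts"]) / SECONDS_PER_DAY
age_days = delta_days * (delta_days >= 0)

# Approximation of LogOp + Normalize: log(1 + x), then standardize
log_age = np.log1p(age_days)
recency_norm = ((log_age - log_age.mean()) / log_age.std()).astype(np.float32)
print(recency_norm)
```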
time_features = (
session_time +
sessiontime_weekday +
weekday_sin +
weekday_cos +
recency_features_norm
)
5.3. Normalize Continuous Features
# Smoothing price long-tailed distribution and applying standardization
price_log = ['price'] >> nvt.ops.LogOp() >> nvt.ops.Normalize(out_dtype=np.float32) >> nvt.ops.Rename(name='price_log_norm')
# Relative price to the average price for the category_id
def relative_price_to_avg_categ(col, gdf):
    # col holds the average price of the row's category_id
    epsilon = 1e-5
    col = ((gdf['price'] - col) / (col + epsilon)) * (col > 0).astype(int)
    return col
avg_category_id_pr = ['category_id'] >> nvt.ops.JoinGroupby(cont_cols =['price'], stats=["mean"]) >> nvt.ops.Rename(name='avg_category_id_price')
relative_price_to_avg_category = avg_category_id_pr >> nvt.ops.LambdaOp(relative_price_to_avg_categ, dependency=['price']) >> nvt.ops.Rename(name="relative_price_to_avg_categ_id")
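The pandas sketch below mirrors this pair of nodes on toy data. `groupby(...).transform("mean")` attaches each category's mean price to every row, roughly what `JoinGroupby(cont_cols=['price'], stats=["mean"])` produces, and the formula from `relative_price_to_avg_categ` is then applied. The data values are invented for illustration.

```python
import pandas as pd

EPSILON = 1e-5  # same guard against division by zero as above

df = pd.DataFrame({
    "category_id": [10, 10, 20, 20, 20],
    "price": [30.0, 50.0, 10.0, 20.0, 30.0],
})

# Attach each category's mean price to every row of that category
avg_price = df.groupby("category_id")["price"].transform("mean")

# Relative deviation from the category mean; zeroed when the mean is not positive
df["relative_price_to_avg_categ_id"] = (
    (df["price"] - avg_price) / (avg_price + EPSILON)
) * (avg_price > 0).astype(int)
print(df)
```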
5.4. Grouping interactions into sessions
Aggregate by session ID and create the sequential features.
groupby_feats = ['event_time_ts', 'user_session'] + cat_feats + time_features + price_log + relative_price_to_avg_category
# Define Groupby Workflow
groupby_features = groupby_feats >> nvt.ops.Groupby(
groupby_cols=["user_session"],
sort_cols=["event_time_ts"],
aggs={
'user_id': ['first'],
'product_id': ["list", "count"],
'category_code': ["list"],
'brand': ["list"],
'category_id': ["list"],
'event_time_ts': ["first"],
'event_time_dt': ["first"],
'et_dayofweek_sin': ["list"],
'et_dayofweek_cos': ["list"],
'price_log_norm': ["list"],
'relative_price_to_avg_categ_id': ["list"],
'product_recency_days_log_norm': ["list"]
},
name_sep="-") >> nvt.ops.AddMetadata(tags=[Tags.CATEGORICAL])
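To show the semantics of this `Groupby` step without a GPU, here is a pandas sketch on a toy event log. Rows are sorted by time, grouped by session, and aggregated with `first`, `list`, and `count`, with output names built by hand to match `name_sep="-"`. It approximates what the NVTabular op produces; it is not its implementation.

```python
import pandas as pd

# Toy interaction log: one row per event, deliberately out of time order
events = pd.DataFrame({
    "user_session": [1, 1, 1, 2, 2],
    "event_time_ts": [105, 100, 110, 200, 190],
    "product_id": [7, 5, 9, 3, 4],
    "user_id": [42, 42, 42, 43, 43],
})

# Sort by time first so that list aggregations preserve interaction order
sessions = (
    events.sort_values(["user_session", "event_time_ts"])
    .groupby("user_session")
    .agg(**{
        "user_id-first": ("user_id", "first"),
        "product_id-list": ("product_id", list),
        "product_id-count": ("product_id", "count"),
        "event_time_ts-first": ("event_time_ts", "first"),
    })
    .reset_index()
)
print(sessions)
```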
Select the columns that contain lists.
groupby_features_list = groupby_features['product_id-list',
'category_code-list',
'brand-list',
'category_id-list',
'et_dayofweek_sin-list',
'et_dayofweek_cos-list',
'price_log_norm-list',
'relative_price_to_avg_categ_id-list',
'product_recency_days_log_norm-list']
SESSIONS_MAX_LENGTH = 20
MINIMUM_SESSION_LENGTH = 2
We truncate the sequence features to at most SESSIONS_MAX_LENGTH elements (20 in our example); with pad=True, shorter sequences are padded to that same length.
groupby_features_trim = groupby_features_list >> nvt.ops.ListSlice(0, SESSIONS_MAX_LENGTH, pad=True) >> nvt.ops.Rename(postfix = '_seq')
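The per-row effect of `ListSlice(0, SESSIONS_MAX_LENGTH, pad=True)` can be sketched in plain Python as below. Because the slice starts at 0 and the lists were sorted by time, it keeps the earliest interactions of each session. The sketch assumes right-padding with 0, which is consistent with 0 being reserved for padding; the helper name `slice_and_pad` is hypothetical.

```python
def slice_and_pad(sequence, max_length, pad_value=0):
    """Keep the first `max_length` items and right-pad shorter lists."""
    truncated = list(sequence[:max_length])
    return truncated + [pad_value] * (max_length - len(truncated))

print(slice_and_pad([5, 7, 9], 5))          # [5, 7, 9, 0, 0]
print(slice_and_pad(list(range(1, 8)), 5))  # [1, 2, 3, 4, 5]
```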
Create a day_index column in order to partition sessions by day when saving the Parquet files.
# calculate the session day index based on the 'event_time_dt-first' column
day_index = ((groupby_features['event_time_dt-first']) >>
nvt.ops.LambdaOp(lambda col: (col - col.min()).dt.days +1) >>
nvt.ops.Rename(f = lambda col: "day_index")
)
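Note that `(col - col.min()).dt.days + 1` counts whole 24-hour periods elapsed since the earliest session start rather than calendar days; the two coincide only when the earliest session starts at midnight. The pandas sketch below, on made-up timestamps, shows the difference.

```python
import pandas as pd

first_event = pd.Series(pd.to_datetime([
    "2019-10-01 08:00",
    "2019-10-01 23:59",
    "2019-10-02 09:00",  # 25 hours after the earliest start
    "2019-10-02 07:00",  # calendar day 2, but only 23 hours after the earliest start
    "2019-10-07 12:00",
]))

# Whole days elapsed since the earliest session start, shifted to start at 1
day_index = (first_event - first_event.min()).dt.days + 1
print(day_index.tolist())  # [1, 1, 2, 1, 7]
```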
Select certain columns to be used in model training
selected_features = groupby_features['user_session', 'product_id-count'] + groupby_features_trim + day_index
Filter out the sessions that have fewer than MINIMUM_SESSION_LENGTH (2) interactions.
filtered_sessions = selected_features >> nvt.ops.Filter(f=lambda df: df["product_id-count"] >= MINIMUM_SESSION_LENGTH)
# avoid numba warnings
from numba import config
config.CUDA_LOW_OCCUPANCY_WARNINGS = 0
Initialize the NVTabular dataset object and workflow graph.
NVTabular’s preprocessing and feature engineering workflows are directed graphs of operators. When we initialize a Workflow with our pipeline, the workflow organizes the input and output columns.
dataset = nvt.Dataset(df)
workflow = nvt.Workflow(filtered_sessions)
workflow.fit(dataset)
sessions_gdf = workflow.transform(dataset).to_ddf()
/usr/local/lib/python3.8/dist-packages/cudf/core/frame.py:384: UserWarning: The deep parameter is ignored and is only included for pandas compatibility.
warnings.warn(
Above, we created an NVTabular Dataset object from our input DataFrame. We then computed statistics for this workflow on the input dataset, i.e. on our training set, using the workflow.fit() method, so that our Workflow can use these statistics to transform any given input. Finally, workflow.transform() applied the pipeline, and to_ddf() returned the result as a Dask-cuDF DataFrame.
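The fit/transform split can be illustrated with a toy normalizer in plain Python. `HypotheticalNormalize` is not an NVTabular class; it only shows the pattern: statistics are learned once in `fit()` and reused unchanged in `transform()`, which is also what allows a saved workflow to process new data at inference time.

```python
import statistics

class HypotheticalNormalize:
    """Toy stand-in for a statistics-based op: fit() learns, transform() applies."""

    def fit(self, values):
        # Statistics are computed once, on the training data only
        self.mean = statistics.fmean(values)
        self.std = statistics.stdev(values)
        return self

    def transform(self, values):
        # The stored training statistics are reused for any later input
        return [(v - self.mean) / self.std for v in values]

op = HypotheticalNormalize().fit([1.0, 2.0, 3.0])
print(op.transform([2.0, 4.0]))  # [0.0, 2.0]
```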
Let’s print the head of our preprocessed dataset. Notice that each example (row) is now a session, and the sequential features with respect to user interactions were converted to lists of matching length.
sessions_gdf.head(3)
| | user_session | product_id-count | product_id-list_seq | category_code-list_seq | brand-list_seq | category_id-list_seq | et_dayofweek_sin-list_seq | et_dayofweek_cos-list_seq | price_log_norm-list_seq | relative_price_to_avg_categ_id-list_seq | product_recency_days_log_norm-list_seq | day_index |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2 | 779 | [19064, 52057, 13290, 11446, 15835, 879, 633, ... | [1, 1, 1, 1, 1, 12, 12, 12, 12, 12, 12, 12, 12... | [171, 120, 231, 392, 562, 20, 9, 20, 295, 143,... | [3, 3, 3, 3, 3, 17, 17, 17, 17, 17, 17, 17, 17... | [0.9749277, 0.9749277, 0.9749277, 0.9749277, 0... | [-0.22252177, -0.22252177, -0.22252177, -0.222... | [-0.6063042879104614, -0.5922226905822754, -0.... | [0.03519274808521567, 0.05391073897011643, 0.0... | [-2.266085624694824, -2.266085624694824, -2.26... | 1 |
| 1 | 3 | 316 | [252, 2801, 5399, 1074, 252, 355, 327, 319, 34... | [1, 17, 17, 15, 1, 1, 17, 31, 17, 17, 17, 17, ... | [1, 1, 1, 1, 1, 50, 1, 1, 36, 1, 1, 36, 50, 1,... | [234, 36, 36, 30, 234, 52, 36, 48, 36, 36, 36,... | [0.43388295, 0.43388295, 0.43388295, 0.4338829... | [-0.90096927, -0.90096927, -0.90096927, -0.900... | [0.7637995481491089, 0.4069388806819916, 0.258... | [0.0006990219343376929, -0.0487534739286619, -... | [-0.8581507205963135, -0.9379308223724365, -1.... | 2 |
| 2 | 4 | 277 | [765, 353, 1360, 1965, 2204, 3129, 726, 861, 9... | [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 1... | [448, 114, 1, 20, 20, 72, 114, 143, 20, 141, 7... | [17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 1... | [0.43388295, 0.43388295, 0.43388295, 0.4338829... | [-0.90096927, -0.90096927, -0.90096927, -0.900... | [-1.7807854413986206, -0.5645747184753418, -0.... | [-0.8327210168829422, -0.19123243552338817, 0.... | [-0.799250602722168, -0.7864173650741577, -0.8... | 2 |
6. Exporting data
We export the dataset to Parquet, partitioned by the session day_index column.
# define partition column
PARTITION_COL = 'day_index'
# define output_folder to store the partitioned parquet files
OUTPUT_FOLDER = os.environ.get("OUTPUT_FOLDER", INPUT_DATA_DIR + "sessions_by_day")
!mkdir -p $OUTPUT_FOLDER
In this section we are going to create the folder structure shown below. It organizes the Parquet files by day, which makes incremental training and evaluation easier.
/sessions_by_day/
|-- 1
| |-- train.parquet
| |-- valid.parquet
| |-- test.parquet
|-- 2
| |-- train.parquet
| |-- valid.parquet
| |-- test.parquet
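The save_time_based_splits function takes the processed sessions as an NVTabular Dataset object and writes hive-partitioned data to disk, with one folder per day_index value containing the train, validation, and test files.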
from transformers4rec.data.preprocessing import save_time_based_splits
save_time_based_splits(data=nvt.Dataset(sessions_gdf),
output_dir= OUTPUT_FOLDER,
partition_col=PARTITION_COL,
timestamp_col='user_session',
)
Creating time-based splits: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:03<00:00, 1.86it/s]
# check out the OUTPUT_FOLDER
!ls $OUTPUT_FOLDER
1 2 3 4 5 6 7
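For intuition about what each day folder holds, here is a sketch of one way to split a single day's sessions by time into train, validation, and test sets with pandas. The helper `split_day_by_time` and its 10% validation and test fractions are assumptions for illustration, not the internals of `save_time_based_splits`. Writing each part to `train.parquet`, `valid.parquet`, and `test.parquet` would then be one `DataFrame.to_parquet` call per split.

```python
import pandas as pd

def split_day_by_time(day_df, timestamp_col, val_size=0.1, test_size=0.1):
    """Hypothetical helper: oldest rows go to train, newest rows go to test."""
    ordered = day_df.sort_values(timestamp_col).reset_index(drop=True)
    n = len(ordered)
    n_test = int(n * test_size)
    n_val = int(n * val_size)
    n_train = n - n_val - n_test
    train = ordered.iloc[:n_train]
    valid = ordered.iloc[n_train:n_train + n_val]
    test = ordered.iloc[n_train + n_val:]
    return train, valid, test

# Toy day: 20 sessions in reverse order, using user_session as the ordering key
day = pd.DataFrame({"user_session": range(20, 0, -1), "day_index": 1})
train, valid, test = split_day_by_time(day, "user_session")
print(len(train), len(valid), len(test))  # 16 2 2
```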
Save the NVTabular workflow so that it can be loaded at the inference step.
workflow.output_schema.column_names
['user_session',
'product_id-count',
'product_id-list_seq',
'category_code-list_seq',
'brand-list_seq',
'category_id-list_seq',
'et_dayofweek_sin-list_seq',
'et_dayofweek_cos-list_seq',
'price_log_norm-list_seq',
'relative_price_to_avg_categ_id-list_seq',
'product_recency_days_log_norm-list_seq',
'day_index']
workflow_path = os.path.join(INPUT_DATA_DIR, 'workflow_etl')
workflow.save(workflow_path)
7. Wrap Up
That’s it! We finished our first task: we preprocessed our dataset and created new features to train a session-based recommendation model. Please run the cell below to shut down the kernel before moving on to the next notebook.
import IPython
app = IPython.Application.instance()
app.kernel.do_shutdown(True)
{'status': 'ok', 'restart': True}