# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ================================

Exporting Ranking Models#

This notebook is created using the latest stable merlin-tensorflow container.

In this example notebook, we demonstrate how to export (save) the NVTabular workflow and a ranking model for model deployment with the Merlin Systems library.

Learning Objectives:

  • Export NVTabular workflow for model deployment

  • Export TensorFlow DLRM model for model deployment

We will follow the steps below:

  • Prepare the data with NVTabular and export NVTabular workflow

  • Train a DLRM model with Merlin Models and export the trained model

Importing Libraries#

Let’s start with importing the libraries that we’ll use in this notebook.

import os

import nvtabular as nvt
from nvtabular.ops import *

from merlin.models.utils.example_utils import workflow_fit_transform
from merlin.schema.tags import Tags

import merlin.models.tf as mm
from merlin.io.dataset import Dataset
import tensorflow as tf
2022-10-19 17:20:17.650375: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-10-19 17:20:19.081535: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory
2022-10-19 17:20:19.081560: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1850] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
2022-10-19 17:20:19.121312: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.

Feature Engineering with NVTabular#

We use synthetic train and test datasets, generated to mimic the real Ali-CCP (Alibaba Click and Conversion Prediction) dataset, to build our recommender system ranking models.

If you would like to use the real Ali-CCP dataset instead, you can download the training and test datasets from tianchi.aliyun.com. You can then use the get_aliccp() function to curate the raw CSV files and save them as parquet files, as sketched below.
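
The snippet below is a minimal sketch of that path (it is not executed in this notebook): it assumes the raw CSV files have already been downloaded and extracted to a local folder, and the exact arguments of get_aliccp() may differ between Merlin versions, so check the merlin.datasets documentation for your release.

from merlin.datasets.ecommerce import get_aliccp

# Hypothetical location of the extracted Ali-CCP CSV files (adjust to your setup).
ALICCP_RAW_DIR = "/workspace/data/aliccp_raw"

# Curate the raw CSV files into train/valid Dataset objects backed by parquet files.
real_train, real_valid = get_aliccp(ALICCP_RAW_DIR)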

from merlin.datasets.synthetic import generate_data

DATA_FOLDER = os.environ.get("DATA_FOLDER", "workspace/data/")
NUM_ROWS = os.environ.get("NUM_ROWS", 1000000)
SYNTHETIC_DATA = eval(os.environ.get("SYNTHETIC_DATA", "True"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 512))

if SYNTHETIC_DATA:
    train, valid = generate_data("aliccp-raw", int(NUM_ROWS), set_sizes=(0.7, 0.3))
    # save the datasets as parquet files
    train.to_ddf().to_parquet(os.path.join(DATA_FOLDER, "train"))
    valid.to_ddf().to_parquet(os.path.join(DATA_FOLDER, "valid"))
/home/alaiacano/.pyenv/versions/3.8.10/envs/merlin38/lib/python3.8/site-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.USER: 'user'>, <Tags.ID: 'id'>].
  warnings.warn(
/home/alaiacano/.pyenv/versions/3.8.10/envs/merlin38/lib/python3.8/site-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
/home/alaiacano/.pyenv/versions/3.8.10/envs/merlin38/lib/python3.8/site-packages/merlin/io/dataset.py:251: UserWarning: Initializing an NVTabular Dataset in CPU mode.This is an experimental feature with extremely limited support!
  warnings.warn(

Let’s define our input and output paths.

train_path = os.path.join(DATA_FOLDER, "train", "*.parquet")
valid_path = os.path.join(DATA_FOLDER, "valid", "*.parquet")
output_path = os.path.join(DATA_FOLDER, "processed")

After we execute the fit() and transform() functions on the raw dataset, applying the operators defined in the NVTabular workflow pipeline below, the processed parquet files are saved to output_path.

%%time
category_temp_directory = os.path.join(DATA_FOLDER, "categories")
user_id = ["user_id"] >> Categorify(out_path=category_temp_directory) >> TagAsUserID()
item_id = ["item_id"] >> Categorify(out_path=category_temp_directory) >> TagAsItemID()
targets = ["click"] >> AddMetadata(tags=[Tags.BINARY_CLASSIFICATION, "target"])

item_features = ["item_category", "item_shop", "item_brand"] >> Categorify(out_path=category_temp_directory) >> TagAsItemFeatures()

user_features = (
    [
        "user_shops",
        "user_profile",
        "user_group",
        "user_gender",
        "user_age",
        "user_consumption_2",
        "user_is_occupied",
        "user_geography",
        "user_intentions",
        "user_brands",
        "user_categories",
    ]
    >> Categorify(out_path=category_temp_directory)
    >> TagAsUserFeatures()
)

outputs = user_id + item_id + item_features + user_features + targets

workflow = nvt.Workflow(outputs)

train_dataset = nvt.Dataset(train_path)
valid_dataset = nvt.Dataset(valid_path)

workflow.fit(train_dataset)
workflow.transform(train_dataset).to_parquet(output_path=output_path + "/train/")
workflow.transform(valid_dataset).to_parquet(output_path=output_path + "/valid/")
/home/alaiacano/.pyenv/versions/3.8.10/envs/merlin38/lib/python3.8/site-packages/merlin/io/dataset.py:251: UserWarning: Initializing an NVTabular Dataset in CPU mode.This is an experimental feature with extremely limited support!
  warnings.warn(
/home/alaiacano/.pyenv/versions/3.8.10/envs/merlin38/lib/python3.8/site-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.USER: 'user'>, <Tags.ID: 'id'>].
  warnings.warn(
/home/alaiacano/.pyenv/versions/3.8.10/envs/merlin38/lib/python3.8/site-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
CPU times: user 7.6 s, sys: 1.49 s, total: 9.09 s
Wall time: 8.23 s
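
If you would like to verify the transformation, you can peek at a few rows of the processed training split. This is an optional check that simply reads back the parquet files we just wrote, using the same nvt.Dataset and to_ddf() calls as above.

# Optional: read the transformed training data back and show the first few rows.
nvt.Dataset(os.path.join(output_path, "train", "*.parquet")).to_ddf().head()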

We save the NVTabular workflow to the workflow folder under DATA_FOLDER.

workflow.save(os.path.join(DATA_FOLDER, "workflow"))

Let’s check out our saved workflow model folder.

!pip install seedir
Collecting seedir
  Using cached seedir-0.3.1-py3-none-any.whl (114 kB)
Collecting emoji
  Using cached emoji-2.1.0-py3-none-any.whl
Requirement already satisfied: natsort in /home/alaiacano/.pyenv/versions/3.8.10/envs/merlin38/lib/python3.8/site-packages (from seedir) (8.1.0)
Installing collected packages: emoji, seedir
Successfully installed emoji-2.1.0 seedir-0.3.1

[notice] A new release of pip available: 22.2.2 -> 22.3
[notice] To update, run: python -m pip install --upgrade pip
import seedir as sd

sd.seedir(
    DATA_FOLDER,
    style="lines",
    itemlimit=10,
    depthlimit=3,
    exclude_folders=".ipynb_checkpoints",
    sort=True,
)
data/
├─categories/
│ └─categories/
│   ├─unique.item_brand.parquet
│   ├─unique.item_category.parquet
│   ├─unique.item_id.parquet
│   ├─unique.item_shop.parquet
│   ├─unique.user_age.parquet
│   ├─unique.user_brands.parquet
│   ├─unique.user_categories.parquet
│   ├─unique.user_consumption_2.parquet
│   ├─unique.user_gender.parquet
│   └─unique.user_geography.parquet
├─processed/
│ ├─train/
│ │ ├─_file_list.txt
│ │ ├─_metadata
│ │ ├─_metadata.json
│ │ ├─part_0.parquet
│ │ └─schema.pbtxt
│ └─valid/
│   ├─_file_list.txt
│   ├─_metadata
│   ├─_metadata.json
│   ├─part_0.parquet
│   └─schema.pbtxt
├─train/
│ └─part.0.parquet
├─valid/
│ └─part.0.parquet
└─workflow/
  ├─categories/
  │ ├─unique.item_brand.parquet
  │ ├─unique.item_category.parquet
  │ ├─unique.item_id.parquet
  │ ├─unique.item_shop.parquet
  │ ├─unique.user_age.parquet
  │ ├─unique.user_brands.parquet
  │ ├─unique.user_categories.parquet
  │ ├─unique.user_consumption_2.parquet
  │ ├─unique.user_gender.parquet
  │ └─unique.user_geography.parquet
  ├─metadata.json
  └─workflow.pkl

Build and Train a DLRM model#

In this example, we build, train, and export a Deep Learning Recommendation Model (DLRM) architecture. To learn more about how to train different deep learning models, how to easily transition from one model to another, and the seamless integration between data preparation and model training, visit the 03-Exploring-different-models.ipynb notebook.

The NVTabular workflow above exports a schema file, schema.pbtxt, for our processed dataset. To learn more about the schema object, the schema file, and tags, you can explore the 02-Merlin-Models-and-NVTabular-integration.ipynb notebook.

# define train and valid dataset objects
train = Dataset(os.path.join(output_path, "train", "*.parquet"))
valid = Dataset(os.path.join(output_path, "valid", "*.parquet"))

# define schema object
schema = train.schema
target_column = schema.select_by_tag(Tags.TARGET).column_names[0]
target_column
'click'
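If you are curious which columns the tags from the NVTabular workflow map to, you can inspect the schema by tag. This is an optional check using the same select_by_tag helper shown above.

# Optional: list the columns treated as user features, item features, and
# categorical inputs, based on the tags assigned in the NVTabular workflow.
print("user features:", schema.select_by_tag(Tags.USER).column_names)
print("item features:", schema.select_by_tag(Tags.ITEM).column_names)
print("categorical:  ", schema.select_by_tag(Tags.CATEGORICAL).column_names)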
model = mm.DLRMModel(
    schema,
    embedding_dim=64,
    bottom_block=mm.MLPBlock([128, 64]),
    top_block=mm.MLPBlock([128, 64, 32]),
    prediction_tasks=mm.BinaryClassificationTask(target_column),
)
%%time

model.compile("adam", run_eagerly=False, metrics=[tf.keras.metrics.AUC()])
model.fit(train, validation_data=valid, batch_size=BATCH_SIZE)
1368/1368 [==============================] - 30s 18ms/step - loss: 0.6932 - auc: 0.4999 - regularization_loss: 0.0000e+00 - val_loss: 0.6932 - val_auc: 0.4998 - val_regularization_loss: 0.0000e+00
CPU times: user 1min 21s, sys: 12.1 s, total: 1min 33s
Wall time: 30.9 s
<keras.callbacks.History at 0x7f5127386700>
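
If you want a standalone evaluation of the trained model on the validation set, you can also call evaluate(), which reports the same loss and AUC metrics as a dictionary. This step is optional.

# Optional: evaluate the trained model on the validation set and print the metrics.
eval_metrics = model.evaluate(valid, batch_size=BATCH_SIZE, return_dict=True)
print(eval_metrics)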

Save model#

The last step of the machine learning (ML)/deep learning (DL) pipeline is to deploy the ETL workflow and the saved model to production. In the production setting, we want to transform the input data exactly as we did during training (ETL): we need to apply the same mean/std to continuous features and use the same categorical mappings to convert categories to contiguous integers before the DL model makes a prediction. Therefore, we deploy the NVTabular workflow together with the TensorFlow model as an ensemble model to Triton Inference Server using the Merlin Systems library. The ensemble model guarantees that the same transformations are applied to the raw inputs.

Let’s save our DLRM model.

model.save(os.path.join(DATA_FOLDER, "dlrm"))
INFO:tensorflow:Unsupported signature for serialization: ((PredictionOutput(predictions=TensorSpec(shape=(None, 1), dtype=tf.float32, name='outputs/predictions'), targets=TensorSpec(shape=(None, 1), dtype=tf.float32, name='outputs/targets'), positive_item_ids=None, label_relevant_counts=None, valid_negatives_mask=None, negative_item_ids=None, sample_weight=None), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f4ebc177f40>), {}).
WARNING:absl:Found untraced functions such as train_compute_metrics, model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, output_layer_layer_call_fn, output_layer_layer_call_and_return_conditional_losses while saving (showing 5 of 97). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: workspace/data/dlrm/assets
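
Before moving on, you can optionally reload the exported model to confirm the SavedModel directory is intact. Merlin Models writes a standard TensorFlow SavedModel, so the usual Keras loading API applies; this is just a quick sanity check.

# Optional sanity check: reload the exported model with the standard Keras API.
loaded_model = tf.keras.models.load_model(os.path.join(DATA_FOLDER, "dlrm"))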

We have exported the NVTabular workflow and the DLRM model; now it is time to move on to the next step: model deployment with Merlin Systems.

Deploying the model with Merlin Systems#

We trained and exported our ranking model and NVTabular workflow. In the next step, we will learn how to deploy our trained DLRM model to Triton Inference Server with the Merlin Systems library. NVIDIA Triton Inference Server (TIS) simplifies the deployment of AI models at scale in production. TIS provides a cloud and edge inferencing solution optimized for both CPUs and GPUs. It supports a number of different machine learning frameworks, such as TensorFlow and PyTorch.

For the next step, visit the Merlin Systems library and execute the Serving-Ranking-Models-With-Merlin-Systems notebook to deploy our saved DLRM and NVTabular workflow models as an ensemble to TIS and obtain prediction results for a given request. To do so, you need to mount the saved DLRM and NVTabular workflow models into the inference container, following the instructions in the README.md.
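
As a preview of what that notebook does, the ensemble is constructed roughly as sketched below. This is a minimal sketch, assuming the merlin-systems package is available in the container and reusing the workflow and model objects still in memory; the export path is hypothetical, the serving notebook additionally removes the target column from the workflow before export, and it remains the authoritative reference.

from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.ops.tensorflow import PredictTensorflow
from merlin.systems.dag.ops.workflow import TransformWorkflow

# Chain the NVTabular workflow and the TensorFlow model into a single serving graph.
serving_ops = (
    workflow.input_schema.column_names
    >> TransformWorkflow(workflow)
    >> PredictTensorflow(model)
)

# Export the Triton ensemble (model configs plus artifacts) to a hypothetical path.
ensemble = Ensemble(serving_ops, workflow.input_schema)
ensemble.export(os.path.join(DATA_FOLDER, "ensemble"))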