# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

Transformer-based architecture for next-item prediction task#

This notebook is created using the latest stable merlin-tensorflow container.

Overview#

In this use case we will train a Transformer-based architecture for the next-item prediction task.

Note: the data for this notebook is automatically downloaded to the folder specified in the cells below.

We will use the booking.com dataset to train a session-based model. The dataset contains 1,166,835 anonymized hotel reservations in the train set and 378,667 in the test set. Each reservation is part of a customer’s trip (identified by utrip_id), which consists of consecutive reservations.

We will reshape the data to organize it into ‘sessions’. Each session will be a full customer itinerary in chronological order. The goal will be to predict the city_id of the final reservation of each trip.
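
To make this concrete, here is a small illustration (not executed in this notebook) of what one trip looks like once reshaped into a session. The city_ids are those of trip 1000027_1 shown further below; in practice the split into inputs and target is handled for us by the masking blocks used during training and evaluation.

# Illustrative sketch only: one trip reshaped into a session.
# The input is the ordered list of visited cities; the target is the final city.
trip_1000027_1 = [8183, 15626, 60902, 30628]            # city_ids in checkin order
session_input, target_city = trip_1000027_1[:-1], trip_1000027_1[-1]
print(session_input, target_city)                       # [8183, 15626, 60902] 30628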

Learning objectives#

  • Training a Transformer-based architecture for the next-item prediction task

Downloading and preparing the dataset#

We will download the dataset using functionality provided by Merlin Models. The dataset can be found on GitHub here.

Read more about libraries used in the import statements below

# Use the asynchronous CUDA malloc allocator so TensorFlow allocates GPU memory on demand instead of reserving a large block up front.
import os
os.environ["TF_GPU_ALLOCATOR"]="cuda_malloc_async"

from merlin.core.dispatch import get_lib
from merlin.datasets.ecommerce import get_booking

import numpy as np
import timeit

from nvtabular import *
from nvtabular import ops

from merlin.schema.tags import Tags
import merlin.models.tf as mm

INPUT_DATA_DIR = os.environ.get('INPUT_DATA_DIR', '/workspace/data')
OUTPUT_DATA_DIR = os.environ.get('OUTPUT_DATA_DIR', '/workspace/data')
NUM_EPOCHS = int(os.environ.get('NUM_EPOCHS', '5'))
2023-05-31 06:06:25.697025: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.
/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/torch.py:43: UserWarning: PyTorch dtype mappings did not load successfully due to an error: No module named 'torch'
  warn(f"PyTorch dtype mappings did not load successfully due to an error: {exc.msg}")
2023-05-31 06:06:26.988036: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-05-31 06:06:26.988386: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-05-31 06:06:26.988518: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
[INFO]: sparse_operation_kit is imported
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.base has been moved to tensorflow.python.trackable.base. The old module will be deleted in version 2.11.
[SOK INFO] Import /usr/local/lib/python3.8/dist-packages/merlin_sok-1.1.4-py3.8-linux-x86_64.egg/sparse_operation_kit/lib/libsok_experiment.so
[SOK INFO] Import /usr/local/lib/python3.8/dist-packages/merlin_sok-1.1.4-py3.8-linux-x86_64.egg/sparse_operation_kit/lib/libsok_experiment.so
[SOK INFO] Initialize finished, communication tool: horovod
2023-05-31 06:06:28.519868: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-31 06:06:28.520815: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-05-31 06:06:28.520999: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-05-31 06:06:28.521129: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-05-31 06:06:28.591345: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-05-31 06:06:28.591534: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-05-31 06:06:28.591665: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-05-31 06:06:28.591770: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2023-05-31 06:06:28.591778: I tensorflow/core/common_runtime/gpu/gpu_process_state.cc:222] Using CUDA malloc Async allocator for GPU: 0
2023-05-31 06:06:28.591860: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1621] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 24576 MB memory:  -> device: 0, name: Quadro RTX 8000, pci bus id: 0000:08:00.0, compute capability: 7.5
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Let’s download the data.

get_booking(INPUT_DATA_DIR)
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.USER: 'user'>, <Tags.ID: 'id'>].
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.SESSION_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.SESSION: 'session'>, <Tags.ID: 'id'>].
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
(<merlin.io.dataset.Dataset at 0x7fe90a7fce80>,
 <merlin.io.dataset.Dataset at 0x7fe90a7f7820>)

Each trip has a unique utrip_id. During each trip, a customer visits several destinations.

# When displaying cudf dataframes use print() or display(), otherwise Jupyter creates hidden copies.
train = get_lib().read_csv(f'{INPUT_DATA_DIR}/train_set.csv', parse_dates=['checkin', 'checkout'])
print(train.head())
   user_id    checkin   checkout  city_id device_class  affiliate_id  \
0  1000027 2016-08-13 2016-08-14     8183      desktop          7168   
1  1000027 2016-08-14 2016-08-16    15626      desktop          7168   
2  1000027 2016-08-16 2016-08-18    60902      desktop          7168   
3  1000027 2016-08-18 2016-08-21    30628      desktop           253   
4  1000033 2016-04-09 2016-04-11    38677       mobile           359   

  booker_country hotel_country   utrip_id  
0        Elbonia        Gondal  1000027_1  
1        Elbonia        Gondal  1000027_1  
2        Elbonia        Gondal  1000027_1  
3        Elbonia        Gondal  1000027_1  
4         Gondal  Cobra Island  1000033_1  

We will train on sequences of city_id, booker_country, hotel_country, and check-in/check-out weekdays; based on this information, our model will attempt to predict the next city_id (the next hop in the journey).

We will train a transformer model that can work with sequences of variable length within a batch. This functionality is provided out of the box and doesn’t require any changes to the architecture. Thanks to it, we do not have to pad or trim our sequences to any particular length, so our model can make effective use of all of the data!

There is one exception: for the masked language modeling task we will be training on, we need to discard sequences shorter than two hops. This makes sense, as there is nothing our model could learn from an itinerary with only a single destination!

Let us begin by splitting the data into a train and validation set based on trip ID.

Let’s see how many unique trips there are in the dataset. We also shuffle the trips along the way so that our validation set is a random sample of all trips.

# Unique trip ids.
utrip_ids = train.sample(frac=1).utrip_id.unique()
print('Number of unique trips is :', len(utrip_ids))
Number of unique trips is : 217686

Now let’s assign data to our train and validation sets. Furthermore, we sort the data by utrip_id and checkin so that our sequences of visited city_ids are in the proper order.

Also, let’s remove trips where only a single city was visited, as they cannot be modeled as a sequence.

train = get_lib().from_pandas(
    train.to_pandas().join(train.to_pandas().groupby('utrip_id').size().rename('num_examples'), on='utrip_id')
)
train = train[train.num_examples > 1]

train.checkin = train.checkin.astype('int')
train.checkout = train.checkout.astype('int')

train_set_utrip_ids = utrip_ids[:int(0.8 * utrip_ids.shape[0])]
validation_set_utrip_ids = utrip_ids[int(0.8 * utrip_ids.shape[0]):]

train_set = train[train.utrip_id.isin(train_set_utrip_ids)].sort_values(['utrip_id', 'checkin'])
validation_set = train[train.utrip_id.isin(validation_set_utrip_ids)].sort_values(['utrip_id', 'checkin'])

Preprocessing with NVTabular#

We can now begin with data preprocessing.

We will combine trips into “sessions”, discard trips that are too short and calculate total trip length.

We will use NVTabular for this work. It offers optimized tabular data preprocessing operators that run on the GPU. If you would like to learn more about the NVTabular library, please take a look here.

Read more about the Merlin’s Dataset API
Read more about how parquet files are read in and processed by Merlin
Read more about Tags

Read more about NVTabular Workflows

Read more about the NVTabular Operators

train_set_dataset = Dataset(train_set)
validation_set_dataset = Dataset(validation_set)
weekday_checkin = (
    ["checkin"]
    >> ops.LambdaOp(lambda col: get_lib().to_datetime(col).dt.weekday)
    >> ops.Rename(name="weekday_checkin")
)

weekday_checkout = (
    ["checkout"]
    >> ops.LambdaOp(lambda col: get_lib().to_datetime(col).dt.weekday)
    >> ops.Rename(name="weekday_checkout")
)

categorical_features = (['city_id', 'booker_country', 'hotel_country'] +
                         weekday_checkin + weekday_checkout
                       ) >> ops.Categorify()

groupby_features = categorical_features + ['utrip_id', 'checkin'] >> ops.Groupby(
    groupby_cols=['utrip_id'],
    aggs={
        'city_id': ['list', 'count'],
        'booker_country': ['list'],
        'hotel_country': ['list'],
        'weekday_checkin': ['list'],
        'weekday_checkout': ['list']
    },
    sort_cols="checkin"
)

list_features = (
            groupby_features['city_id_list', 'booker_country_list', 'hotel_country_list', 
                 'weekday_checkin_list', 'weekday_checkout_list'
            ] >> ops.AddTags([Tags.SEQUENCE])
)

# Filter out sessions with fewer than 2 interactions
MINIMUM_SESSION_LENGTH = 2
features = list_features + (groupby_features['city_id_count'] >>  ops.AddTags([Tags.CONTINUOUS]))
filtered_sessions = features >> ops.Filter(f=lambda df: df["city_id_count"] >= MINIMUM_SESSION_LENGTH) 
wf = Workflow(filtered_sessions)

wf.fit_transform(train_set_dataset).to_parquet(os.path.join(OUTPUT_DATA_DIR, 'train_processed.parquet'))
wf.transform(validation_set_dataset).to_parquet(os.path.join(OUTPUT_DATA_DIR, 'validation_processed.parquet'))

wf.save(os.path.join(OUTPUT_DATA_DIR, 'workflow'))

Our data consists of sequences of visited city_ids, booker_country and hotel_country categories, and check-in/check-out weekdays (all encoded as integers), along with a city_id_count column (which contains the count of visited cities in a trip).

Dataset(os.path.join(OUTPUT_DATA_DIR, 'train_processed.parquet')).head()
city_id_list booker_country_list hotel_country_list weekday_checkin_list weekday_checkout_list city_id_count
0 [8238, 156, 2278, 2097] [3, 3, 3, 3] [3, 3, 3, 3] [5, 7, 4, 3] [7, 4, 2, 7] 4
1 [63, 1160, 87, 618, 63] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] [5, 1, 4, 3, 5] [6, 4, 2, 5, 4] 5
2 [7, 6, 24, 1050, 65, 52, 3] [2, 2, 2, 2, 2, 2, 2] [2, 2, 2, 16, 16, 3, 3] [5, 1, 2, 6, 5, 7, 4] [6, 3, 1, 5, 7, 4, 3] 7
3 [1032, 757, 140, 3] [2, 2, 2, 2] [19, 19, 19, 3] [1, 4, 2, 3] [4, 3, 2, 5] 4
4 [3603, 262, 662, 250, 359] [1, 1, 1, 1, 1] [30, 30, 30, 30, 30] [1, 3, 6, 5, 1] [2, 1, 5, 6, 3] 5

We are now ready to train our model.

Here is the schema of the data that our model will use.

seq_schema = Workflow.load(os.path.join(OUTPUT_DATA_DIR, 'workflow')).output_schema.select_by_tag(Tags.SEQUENCE)
seq_schema
name tags dtype is_list is_ragged properties.num_buckets properties.freq_threshold properties.max_size properties.start_index properties.cat_path properties.domain.min properties.domain.max properties.domain.name properties.embedding_sizes.cardinality properties.embedding_sizes.dimension properties.value_count.min properties.value_count.max
0 city_id_list (Tags.SEQUENCE, Tags.CATEGORICAL) DType(name='int64', element_type=<ElementType.... True True None 0 0 0 .//categories/unique.city_id.parquet 0 37202 city_id 37203 512 0 None
1 booker_country_list (Tags.SEQUENCE, Tags.CATEGORICAL) DType(name='int64', element_type=<ElementType.... True True None 0 0 0 .//categories/unique.booker_country.parquet 0 5 booker_country 6 16 0 None
2 hotel_country_list (Tags.SEQUENCE, Tags.CATEGORICAL) DType(name='int64', element_type=<ElementType.... True True None 0 0 0 .//categories/unique.hotel_country.parquet 0 194 hotel_country 195 31 0 None
3 weekday_checkin_list (Tags.SEQUENCE, Tags.CATEGORICAL) DType(name='int64', element_type=<ElementType.... True True None 0 0 0 .//categories/unique.weekday_checkin.parquet 0 7 weekday_checkin 8 16 0 None
4 weekday_checkout_list (Tags.SEQUENCE, Tags.CATEGORICAL) DType(name='int64', element_type=<ElementType.... True True None 0 0 0 .//categories/unique.weekday_checkout.parquet 0 7 weekday_checkout 8 16 0 None

Let’s also identify the target column.

target = Workflow.load(os.path.join(OUTPUT_DATA_DIR, 'workflow')).output_schema.select_by_tag(Tags.SEQUENCE).column_names[0]
target
'city_id_list'

Constructing the model#

Let’s construct our model.

We can specify various hyperparameters, such as the number of heads and number of layers to use.

For the transformer portion of our model, we will use the XLNet architecture.

Later, when we run the fit method on our model, we will specify a masking_prob of 0.3 and link it to the transformer block defined in our model. Through the combination of these parameters, our model will train on sequences where any given timestep is masked with a probability of 0.3, and it will be the model’s training task to infer the target value for that step!

To summarize, Masked Language Modeling is implemented by:

  • SequenceMaskRandom() - Used as a pre for model.fit(), it randomly selects items from the sequence to be masked as prediction targets, using Keras masking. This block also adds the necessary configuration to the specified transformer block so that it is pre-configured with the layers needed to prepare the inputs to the HuggingFace transformer layer and to post-process its outputs. For example, one pre-processing operation replaces the input embeddings at masked positions with a dummy trainable embedding, to avoid leakage of the targets (a small illustration of random masking follows below).
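
Conceptually, random masking works as in the small numpy sketch below. This is only an illustration of the idea, not Merlin’s internal implementation; the session values are taken from the processed data shown earlier, and the value 0 merely stands in for the trainable mask embedding.

# Illustrative sketch of random masking (not the library's implementation).
import numpy as np

rng = np.random.default_rng(0)
session = np.array([8238, 156, 2278, 2097])    # an example city_id_list from the processed data
masked = rng.random(len(session)) < 0.3        # each step becomes a target with probability 0.3
targets = session[masked]                      # the model must recover these city_ids
inputs = np.where(masked, 0, session)          # masked positions replaced by a dummy value
print(inputs, targets)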

Read more about the APIs used to construct models

dmodel=48
mlp_block = mm.MLPBlock(
                [128,dmodel],
                activation='relu',
                no_activation_last_layer=True,
            )
transformer_block = mm.XLNetBlock(d_model=dmodel, n_head=4, n_layer=2)
model = mm.Model(
    mm.InputBlockV2(
        seq_schema,
        embeddings=mm.Embeddings(
            Workflow.load(os.path.join(OUTPUT_DATA_DIR, 'workflow')).output_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None
        ),
    ),
    mlp_block,
    transformer_block,
    mm.CategoricalOutput(
        Workflow.load(os.path.join(OUTPUT_DATA_DIR, 'workflow')).output_schema.select_by_name(target),
        default_loss="categorical_crossentropy",
    ),
)

Model training#

model.compile(run_eagerly=False, optimizer='adam', loss="categorical_crossentropy")

model.fit(
    Dataset(os.path.join(OUTPUT_DATA_DIR, 'train_processed.parquet')),
    batch_size=64,
    epochs=NUM_EPOCHS,
    pre=mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3, transformer=transformer_block)
)
/usr/local/lib/python3.8/dist-packages/keras/initializers/initializers_v2.py:120: UserWarning: The initializer TruncatedNormal is unseeded and being called multiple times, which will return identical values  each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.
  warnings.warn(
Epoch 1/5
2023-05-31 06:06:44.034041: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8700
WARNING:tensorflow:Gradients do not exist for variables ['model/mask_emb:0', 'transformer/layer_._0/rel_attn/r_s_bias:0', 'transformer/layer_._0/rel_attn/seg_embed:0', 'transformer/layer_._1/rel_attn/r_s_bias:0', 'transformer/layer_._1/rel_attn/seg_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['model/mask_emb:0', 'transformer/layer_._0/rel_attn/r_s_bias:0', 'transformer/layer_._0/rel_attn/seg_embed:0', 'transformer/layer_._1/rel_attn/r_s_bias:0', 'transformer/layer_._1/rel_attn/seg_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
2023-05-31 06:06:54.541024: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: model/xl_net_block/sequential_block_5/replace_masked_embeddings/RaggedWhere/Assert/AssertGuard/branch_executed/_95
2720/2720 [==============================] - 81s 25ms/step - loss: 7.3315 - recall_at_10: 0.1973 - mrr_at_10: 0.0863 - ndcg_at_10: 0.1123 - map_at_10: 0.0863 - precision_at_10: 0.0197 - regularization_loss: 0.0000e+00 - loss_batch: 7.3306
Epoch 2/5
2720/2720 [==============================] - 70s 25ms/step - loss: 6.0979 - recall_at_10: 0.3633 - mrr_at_10: 0.1707 - ndcg_at_10: 0.2161 - map_at_10: 0.1707 - precision_at_10: 0.0363 - regularization_loss: 0.0000e+00 - loss_batch: 6.0950
Epoch 3/5
2720/2720 [==============================] - 71s 26ms/step - loss: 5.5827 - recall_at_10: 0.4306 - mrr_at_10: 0.2056 - ndcg_at_10: 0.2588 - map_at_10: 0.2056 - precision_at_10: 0.0431 - regularization_loss: 0.0000e+00 - loss_batch: 5.5806
Epoch 4/5
2720/2720 [==============================] - 72s 26ms/step - loss: 5.3211 - recall_at_10: 0.4627 - mrr_at_10: 0.2213 - ndcg_at_10: 0.2784 - map_at_10: 0.2213 - precision_at_10: 0.0463 - regularization_loss: 0.0000e+00 - loss_batch: 5.3194
Epoch 5/5
2720/2720 [==============================] - 71s 26ms/step - loss: 5.1920 - recall_at_10: 0.4787 - mrr_at_10: 0.2306 - ndcg_at_10: 0.2892 - map_at_10: 0.2306 - precision_at_10: 0.0479 - regularization_loss: 0.0000e+00 - loss_batch: 5.1903
<keras.callbacks.History at 0x7fe67105a7f0>

Model evaluation#

We have trained our model.

However, the metrics reported during training come from the masked language modeling task: for each example, a portion of the steps in the sequence was masked, and the metrics were computed on predicting those masked steps.

In practice, we care more about how well our model performs on the next-item prediction task, as that mimics the scenario in which the model is likely to be used.

Let’s measure the performance of the model on the task of predicting the last item in a sequence.

We will mask the last item using SequenceMaskLast and run evaluation.

metrics = model.evaluate(
    Dataset(os.path.join(OUTPUT_DATA_DIR, 'validation_processed.parquet')),
    batch_size=128,
    pre=mm.SequenceMaskLast(schema=seq_schema, target=target, transformer=transformer_block),
    return_dict=True
)
2023-05-31 06:12:51.968982: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: model/xl_net_block/sequential_block_5/replace_masked_embeddings/RaggedWhere/Assert/AssertGuard/branch_executed/_74
340/340 [==============================] - 11s 20ms/step - loss: 4.7151 - recall_at_10: 0.5533 - mrr_at_10: 0.3083 - ndcg_at_10: 0.3665 - map_at_10: 0.3083 - precision_at_10: 0.0553 - regularization_loss: 0.0000e+00 - loss_batch: 4.7149
metrics
{'loss': 4.715089797973633,
 'recall_at_10': 0.5533444881439209,
 'mrr_at_10': 0.30831339955329895,
 'ndcg_at_10': 0.36654922366142273,
 'map_at_10': 0.30831339955329895,
 'precision_at_10': 0.055334459990262985,
 'regularization_loss': 0.0,
 'loss_batch': 4.635858535766602}

Serving predictions using the Triton Inference Server#

Now, we will serve our trained model on the NVIDIA Triton Inference Server (TIS). TIS is open-source inference serving software that helps standardize model deployment and execution and delivers fast and scalable AI in production. To make it easy to serve recommender models on TIS, the NVIDIA Merlin team designed and developed the Merlin Systems library, which provides tools and operators for serving end-to-end recommender system pipelines on TIS.

In order to perform inference on the Triton Inference Server, we need to export the inference operators to disk.

The inference operators form an Ensemble, which is a pipeline that takes in raw data, processes it using NVTabular, and finally outputs predictions from the model that we trained.

Let’s write the Ensemble to disk (we will later load it on Triton to perform inference).

from merlin.systems.dag.ops.tensorflow import PredictTensorflow
from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.ops.workflow import TransformWorkflow

inf_ops = wf.input_schema.column_names >> TransformWorkflow(wf) >> PredictTensorflow(model)

ensemble = Ensemble(inf_ops, wf.input_schema)
ensemble.export(os.path.join(OUTPUT_DATA_DIR, 'ensemble'));
WARNING:tensorflow:Skipping full serialization of Keras layer TFSharedEmbeddings(
  (_feature_shapes): Dict(
    (city_id_list): TensorShape([64, None, 1])
    (booker_country_list): TensorShape([64, None, 1])
    (hotel_country_list): TensorShape([64, None, 1])
    (weekday_checkin_list): TensorShape([64, None, 1])
    (weekday_checkout_list): TensorShape([64, None, 1])
  )
  (_feature_dtypes): Dict(
    (city_id_list): tf.int64
    (booker_country_list): tf.int64
    (hotel_country_list): tf.int64
    (weekday_checkin_list): tf.int64
    (weekday_checkout_list): tf.int64
  )
), because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(
  (_feature_shapes): Dict(
    (city_id_list): TensorShape([64, None, 1])
    (booker_country_list): TensorShape([64, None, 1])
    (hotel_country_list): TensorShape([64, None, 1])
    (weekday_checkin_list): TensorShape([64, None, 1])
    (weekday_checkout_list): TensorShape([64, None, 1])
  )
  (_feature_dtypes): Dict(
    (city_id_list): tf.int64
    (booker_country_list): tf.int64
    (hotel_country_list): tf.int64
    (weekday_checkin_list): tf.int64
    (weekday_checkout_list): tf.int64
  )
), because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(
  (_feature_shapes): Dict(
    (city_id_list): TensorShape([64, None, 1])
    (booker_country_list): TensorShape([64, None, 1])
    (hotel_country_list): TensorShape([64, None, 1])
    (weekday_checkin_list): TensorShape([64, None, 1])
    (weekday_checkout_list): TensorShape([64, None, 1])
  )
  (_feature_dtypes): Dict(
    (city_id_list): tf.int64
    (booker_country_list): tf.int64
    (hotel_country_list): tf.int64
    (weekday_checkin_list): tf.int64
    (weekday_checkout_list): tf.int64
  )
), because it is not built.
WARNING:absl:Found untraced functions such as model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, sequence_mask_random_layer_call_fn, sequence_mask_random_layer_call_and_return_conditional_losses, sequence_mask_last_layer_call_fn while saving (showing 5 of 108). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmp1sakw940/model.savedmodel/assets
INFO:tensorflow:Assets written to: /tmp/tmp1sakw940/model.savedmodel/assets
/usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:101: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
  config[key] = tf.keras.utils.serialize_keras_object(maybe_value)
/usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:288: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
  config[i] = tf.keras.utils.serialize_keras_object(layer)
/usr/local/lib/python3.8/dist-packages/keras/saving/legacy/saved_model/layer_serialization.py:134: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
  return serialization.serialize_keras_object(obj)
/usr/local/lib/python3.8/dist-packages/keras/initializers/initializers_v2.py:120: UserWarning: The initializer TruncatedNormal is unseeded and being called multiple times, which will return identical values  each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.
  warnings.warn(
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
/usr/local/lib/python3.8/dist-packages/keras/initializers/initializers_v2.py:120: UserWarning: The initializer TruncatedNormal is unseeded and being called multiple times, which will return identical values  each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.
  warnings.warn(
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'city_id_count', which is not being used by any downstream operator in the ensemble graph.
  warnings.warn(
WARNING:tensorflow:Skipping full serialization of Keras layer TFSharedEmbeddings(
  (_feature_shapes): Dict(
    (city_id_list): TensorShape([64, None, 1])
    (booker_country_list): TensorShape([64, None, 1])
    (hotel_country_list): TensorShape([64, None, 1])
    (weekday_checkin_list): TensorShape([64, None, 1])
    (weekday_checkout_list): TensorShape([64, None, 1])
  )
  (_feature_dtypes): Dict(
    (city_id_list): tf.int64
    (booker_country_list): tf.int64
    (hotel_country_list): tf.int64
    (weekday_checkin_list): tf.int64
    (weekday_checkout_list): tf.int64
  )
), because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer TFSharedEmbeddings(
  (_feature_shapes): Dict(
    (city_id_list): TensorShape([64, None, 1])
    (booker_country_list): TensorShape([64, None, 1])
    (hotel_country_list): TensorShape([64, None, 1])
    (weekday_checkin_list): TensorShape([64, None, 1])
    (weekday_checkout_list): TensorShape([64, None, 1])
  )
  (_feature_dtypes): Dict(
    (city_id_list): tf.int64
    (booker_country_list): tf.int64
    (hotel_country_list): tf.int64
    (weekday_checkin_list): tf.int64
    (weekday_checkout_list): tf.int64
  )
), because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(
  (_feature_shapes): Dict(
    (city_id_list): TensorShape([64, None, 1])
    (booker_country_list): TensorShape([64, None, 1])
    (hotel_country_list): TensorShape([64, None, 1])
    (weekday_checkin_list): TensorShape([64, None, 1])
    (weekday_checkout_list): TensorShape([64, None, 1])
  )
  (_feature_dtypes): Dict(
    (city_id_list): tf.int64
    (booker_country_list): tf.int64
    (hotel_country_list): tf.int64
    (weekday_checkin_list): tf.int64
    (weekday_checkout_list): tf.int64
  )
), because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(
  (_feature_shapes): Dict(
    (city_id_list): TensorShape([64, None, 1])
    (booker_country_list): TensorShape([64, None, 1])
    (hotel_country_list): TensorShape([64, None, 1])
    (weekday_checkin_list): TensorShape([64, None, 1])
    (weekday_checkout_list): TensorShape([64, None, 1])
  )
  (_feature_dtypes): Dict(
    (city_id_list): tf.int64
    (booker_country_list): tf.int64
    (hotel_country_list): tf.int64
    (weekday_checkin_list): tf.int64
    (weekday_checkout_list): tf.int64
  )
), because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(
  (_feature_shapes): Dict(
    (city_id_list): TensorShape([64, None, 1])
    (booker_country_list): TensorShape([64, None, 1])
    (hotel_country_list): TensorShape([64, None, 1])
    (weekday_checkin_list): TensorShape([64, None, 1])
    (weekday_checkout_list): TensorShape([64, None, 1])
  )
  (_feature_dtypes): Dict(
    (city_id_list): tf.int64
    (booker_country_list): tf.int64
    (hotel_country_list): tf.int64
    (weekday_checkin_list): tf.int64
    (weekday_checkout_list): tf.int64
  )
), because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(
  (_feature_shapes): Dict(
    (city_id_list): TensorShape([64, None, 1])
    (booker_country_list): TensorShape([64, None, 1])
    (hotel_country_list): TensorShape([64, None, 1])
    (weekday_checkin_list): TensorShape([64, None, 1])
    (weekday_checkout_list): TensorShape([64, None, 1])
  )
  (_feature_dtypes): Dict(
    (city_id_list): tf.int64
    (booker_country_list): tf.int64
    (hotel_country_list): tf.int64
    (weekday_checkin_list): tf.int64
    (weekday_checkout_list): tf.int64
  )
), because it is not built.
WARNING:absl:Found untraced functions such as model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, sequence_mask_random_layer_call_fn, sequence_mask_random_layer_call_and_return_conditional_losses, sequence_mask_last_layer_call_fn while saving (showing 5 of 108). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /workspace/data/ensemble/1_predicttensorflowtriton/1/model.savedmodel/assets
INFO:tensorflow:Assets written to: /workspace/data/ensemble/1_predicttensorflowtriton/1/model.savedmodel/assets
/usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:101: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
  config[key] = tf.keras.utils.serialize_keras_object(maybe_value)
/usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:288: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
  config[i] = tf.keras.utils.serialize_keras_object(layer)
/usr/local/lib/python3.8/dist-packages/keras/saving/legacy/saved_model/layer_serialization.py:134: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
  return serialization.serialize_keras_object(obj)
/usr/local/lib/python3.8/dist-packages/keras/initializers/initializers_v2.py:120: UserWarning: The initializer TruncatedNormal is unseeded and being called multiple times, which will return identical values  each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.
  warnings.warn(
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.

After we export the ensemble, we are ready to start the Triton Inference Server.

The server is installed in the Merlin TensorFlow and Merlin PyTorch containers. If you are not using one of these containers, ensure that Triton is installed in your environment. For more information, see the Triton Inference Server documentation.

You can start the server by running the following command:

tritonserver --model-repository={OUTPUT_DATA_DIR}/ensemble/

For the --model-repository argument, specify the same value as the export_path that you specified previously in the ensemble.export method.

After you run the tritonserver command, wait until your terminal shows messages like the following example:

I0414 18:29:50.741833 4067 grpc_server.cc:4421] Started GRPCInferenceService at 0.0.0.0:8001
I0414 18:29:50.742197 4067 http_server.cc:3113] Started HTTPService at 0.0.0.0:8000
I0414 18:29:50.783470 4067 http_server.cc:178] Started Metrics Service at 0.0.0.0:8002
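
Before sending requests, you can optionally verify from Python that the server is up and the ensemble is loaded. This is a minimal sketch; it assumes the tritonclient package is available in your environment and that the model is exposed under the name executor_model used in the inference call below.

import tritonclient.grpc as grpcclient

# Optional sanity check: confirm the server is live and the ensemble model is ready.
with grpcclient.InferenceServerClient("localhost:8001") as client:
    print(client.is_server_live())                 # True once the server has started
    print(client.is_model_ready("executor_model"))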

Let us now package our data for inference and send the first 4 rows of our validation data to Triton. These rows correspond to a single trip (all rows share the same utrip_id) with four stops. The data will first be processed by the NVTabular workflow and subsequently passed to our transformer model for prediction.

from merlin.systems.triton import convert_df_to_triton_input

validation_data = validation_set_dataset.compute()
inputs = convert_df_to_triton_input(wf.input_schema, validation_data.iloc[:4])
import tritonclient.grpc as grpcclient

with grpcclient.InferenceServerClient("localhost:8001") as client:
    response = client.infer('executor_model', inputs)

The response consists of logits coming from our model.

response.as_numpy('city_id_list/categorical_output')
array([[-2.8206294 , -1.3849059 ,  1.9042726 , ...,  0.851537  ,
        -2.4237087 , -0.73849726]], dtype=float32)
predictions = response.as_numpy('city_id_list/categorical_output')
predictions.shape
(1, 37203)

The above values are the logits output by the last layer of our model. Their number corresponds to the cardinality of city_id, our target variable:

cardinality = wf.output_schema['city_id_list'].properties['embedding_sizes']['cardinality']
cardinality
37203
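
As a final step, we could turn the logits into concrete recommendations. The sketch below is not part of the original notebook; it simply takes the 10 highest-scoring positions. Note that these indices are the Categorify-encoded city ids, so mapping them back to the raw city_id values would require the category mapping saved by the workflow (see the cat_path entries in the schema above).

# A minimal sketch: top-10 most likely next cities (as encoded category indices).
import numpy as np

top_10 = np.argsort(-predictions[0])[:10]
print(top_10)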

Summary#

We have trained a transformer model for the next-item prediction task using masked language modeling.

For another session-based example that goes deeper into data preprocessing and that covers several advanced techniques (Weight Tying, Temperature Scaling) please see Session-Based Next Item Prediction for Fashion E-Commerce.