# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.
Transformer-based architecture for next-item prediction task with pretrained embeddings
This notebook was created using the latest stable merlin-tensorflow container.
Overview
In this use case we will train a Transformer-based architecture for next-item prediction task with pretrained embeddings.
You can choose to download the full dataset manually or use synthetic data.
We will use the SIGIR eCOM 2021 Data Challenge Dataset to train a session-based model. The dataset contains 36M events of users browsing an online store.
We will reshape the data to organize it into ‘sessions’. Each session will be a full customer online journey in chronological order. The goal will be to predict the url of the next action taken.
Learning objectives
Training a Transformer-based architecture for a next-item prediction task
Downloading and preparing the dataset
import os
import cudf
import numpy as np
import pandas as pd
import nvtabular as nvt
from merlin.schema import ColumnSchema, Schema, Tags
OUTPUT_DATA_DIR = os.environ.get('OUTPUT_DATA_DIR', '/workspace/data')
NUM_EPOCHS = int(os.environ.get('NUM_EPOCHS', 5))
NUM_EXAMPLES = int(os.environ.get('NUM_EXAMPLES', 100_000))
MINIMUM_SESSION_LENGTH = int(os.environ.get('MINIMUM_SESSION_LENGTH', 5))
You can download the full dataset by registering here. If you choose to download the data, please place it alongside this notebook in the sigir_dataset directory and extract it.
By default, in this notebook we will be using synthetically generated data based on the SIGIR dataset, but you can run on the full dataset by changing the value of the boolean flag below.
RUN_ON_SYNTHETIC_DATA = True
Clean downloaded data
If you are training on the full SIGIR dataset, the following code will pre-process it.
Here we deal with NaN values, drop rows with missing information and parse strings containing lists into lists.
The synthetically generated data is already clean – it doesn’t require this pre-processing.
import ast

if not RUN_ON_SYNTHETIC_DATA:
    # One row per browsing event
    train = nvt.Dataset('/workspace/sigir_dataset/train/browsing_train.csv', part_size='500MB')
    # SKU metadata; the vectors are stored as strings such as '[0.1, -0.2, ...]'
    skus = pd.read_csv('/workspace/sigir_dataset/train/sku_to_content.csv')
    # Missing vectors are NaN: replace them with empty strings before parsing
    skus['description_vector'] = skus['description_vector'].replace(np.nan, '')
    skus['image_vector'] = skus['image_vector'].replace(np.nan, '')
    # Parse the strings into lists (ast.literal_eval only accepts literals, unlike eval)
    skus['description_vector'] = skus['description_vector'].apply(lambda x: [] if len(x) == 0 else ast.literal_eval(x))
    skus['image_vector'] = skus['image_vector'].apply(lambda x: [] if len(x) == 0 else ast.literal_eval(x))
    # Keep only SKUs that have a description embedding
    skus = skus[skus.description_vector.apply(len) > 0]
    skus = nvt.Dataset(skus)
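The parsing logic above can be tried in isolation on a few hand-written values. The sketch below is a minimal pure-Python illustration, assuming the raw CSV stores each vector as a string such as '[0.1, -0.2]' and a missing vector as NaN; the helper parse_vector and the sample rows are hypothetical and not part of the notebook.

```python
import ast
import math


def parse_vector(raw):
    """Turn a stringified list (or a missing value) into a list of floats.

    Hypothetical helper mirroring the cleaning step: NaN or an empty string
    becomes an empty list; anything else is parsed with ast.literal_eval,
    which only accepts Python literals (unlike eval).
    """
    if raw is None or (isinstance(raw, float) and math.isnan(raw)) or raw == "":
        return []
    return [float(v) for v in ast.literal_eval(raw)]


raw_rows = [
    {"sku": "a", "description_vector": "[0.1, -0.2, 0.3]"},
    {"sku": "b", "description_vector": float("nan")},  # missing vector
    {"sku": "c", "description_vector": ""},            # empty string
]

parsed = [{**row, "description_vector": parse_vector(row["description_vector"])} for row in raw_rows]

# Keep only SKUs that have a description embedding, as the cleaning code does
cleaned = [row for row in parsed if len(row["description_vector"]) > 0]
print([row["sku"] for row in cleaned])  # ['a']
```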
Generate synthetic data
If you are not running on the full dataset, the following lines of code will generate its synthetic counterpart.
if RUN_ON_SYNTHETIC_DATA:
    from merlin.datasets.synthetic import generate_data

    # Synthetic counterparts of the browsing events and the SKU metadata
    train = generate_data('sigir-browsing', NUM_EXAMPLES)
    skus = generate_data('sigir-sku', NUM_EXAMPLES)
Constructing a workflow
We need to process our data further before we can use it to train our model.
In particular, the skus dataset contains the mapping between the product_sku_hash (essentially an item id) and the description_vector, an embedding obtained from the product description.
We would like to enable our model to make use of this piece of information. In order to feed this data to our model, we need to map the product_sku_hash to an id.
But we need to make sure that skus and the train dataset (event information) are processed consistently, so that the same product_sku_hash is mapped to the same id both when processing skus and when processing train.
We do so by defining and fitting a Categorify op once and using it to process both the skus and the train datasets.
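Why fitting a single Categorify op matters can be shown with a small pure-Python stand-in. The sketch below is not NVTabular's implementation: it fits one value-to-id mapping on the train events and reuses it for the SKU table, so the same product hash receives the same id in both. The helpers fit_mapping and transform, the sample hashes, and the choice of reserving id 0 for values not seen during fitting are assumptions made for the sketch; NVTabular's exact scheme for reserved ids may differ.

```python
def fit_mapping(values, first_id=1):
    """Assign a stable integer id to every distinct value (hypothetical helper).

    Ids start at `first_id`; id 0 is left free for unseen values.
    Values are sorted so the mapping does not depend on row order.
    """
    return {value: i for i, value in enumerate(sorted(set(values)), start=first_id)}


def transform(values, mapping, unknown_id=0):
    """Encode values with an already-fitted mapping."""
    return [mapping.get(value, unknown_id) for value in values]


# Product hashes as they appear in the event log (train) and in the SKU table
train_skus = ["f3a9", "0b1c", "f3a9", "77de"]
sku_table = ["77de", "0b1c", "9999"]  # "9999" never appears in train

mapping = fit_mapping(train_skus)           # fit once, on train
train_ids = transform(train_skus, mapping)  # reuse for train ...
sku_ids = transform(sku_table, mapping)     # ... and for the SKU table

print(mapping)    # {'0b1c': 1, '77de': 2, 'f3a9': 3}
print(train_ids)  # [3, 1, 3, 2]
print(sku_ids)    # [2, 1, 0]
```

Had a second mapping been fitted on the SKU table alone, "77de" would have received a different id there than in train, and the embeddings would be attached to the wrong products.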
Additionally, we apply some further processing to the train dataset. We group rows by session_id_hash so that each training example will contain events from a single customer visit to the online store arranged in chronological order.
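The grouping, sorting, and filtering performed by the Groupby and Filter ops in the workflow can be pictured with plain Python. This is a minimal sketch on a few made-up events, not how NVTabular executes it on the GPU; the column names follow the notebook, and MINIMUM_SESSION_LENGTH is set to 3 here only to keep the example small.

```python
from collections import defaultdict

MINIMUM_SESSION_LENGTH = 3  # the notebook defaults to 5; 3 keeps the example short

# Made-up events: (session_id_hash, server_timestamp_epoch_ms, hashed_url)
events = [
    ("s1", 300, "u7"),
    ("s2", 100, "u1"),
    ("s1", 100, "u4"),
    ("s1", 200, "u9"),
    ("s2", 50, "u3"),
]

# Group rows by session
sessions = defaultdict(list)
for session_id, timestamp, url in events:
    sessions[session_id].append((timestamp, url))

result = {}
for session_id, rows in sessions.items():
    rows.sort()  # chronological order, like sort_cols="server_timestamp_epoch_ms"
    hashed_url_list = [url for _, url in rows]
    hashed_url_count = len(hashed_url_list)
    # Drop sessions that are too short, like the Filter op
    if hashed_url_count >= MINIMUM_SESSION_LENGTH:
        result[session_id] = hashed_url_list

print(result)  # {'s1': ['u4', 'u9', 'u7']}
```

The workflow sorts on the timestamp after NormalizeMinMax has been applied; this yields the same order as sorting on the raw timestamps, because min-max scaling preserves order.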
If you would like to learn more about leveraging NVTabular to process tabular data on the GPU using a set of industry-standard operators, please consult the examples available here.
Let’s first process the train dataset and retain the Categorify operator (cat_op) for processing of skus.
cat_op = nvt.ops.Categorify()
out = ['product_sku_hash'] >> cat_op >> nvt.ops.TagAsItemID()
out += ['event_type', 'product_action', 'session_id_hash', 'hashed_url'] >> nvt.ops.Categorify()
out += ['server_timestamp_epoch_ms'] >> nvt.ops.NormalizeMinMax()
groupby_features = out >> nvt.ops.Groupby(
    groupby_cols=['session_id_hash'],
    aggs={
        'product_sku_hash': ['list'],
        'event_type': ['list'],
        'product_action': ['list'],
        'hashed_url': ['list', 'count'],
        'server_timestamp_epoch_ms': ['list'],
    },
    sort_cols="server_timestamp_epoch_ms",
)
filtered_sessions = groupby_features >> nvt.ops.Filter(f=lambda df: df["hashed_url_count"] >= MINIMUM_SESSION_LENGTH)
# We won't need `session_id_hash` or `hashed_url_count` any longer
wf = nvt.Workflow(
    filtered_sessions[
        'product_sku_hash_list',
        'event_type_list',
        'product_action_list',
        'hashed_url_list',
    ]
)
# Let's save the output of our workflow -- transformed `train` for later use (training of our model).
wf.fit_transform(train).to_parquet('train_transformed')
Here are a couple of example rows from train_transformed.
nvt.Dataset('train_transformed', engine='parquet').head()
| | product_sku_hash_list | event_type_list | product_action_list | hashed_url_list |
|---|---|---|---|---|
| 0 | [578, 972, 378, 420, 328, 126, 233, 925, 410, ... | [3, 4, 4, 3, 3, 3, 3, 3, 3, 4, 4, 3, 4, 4, 4, ... | [3, 3, 5, 6, 4, 3, 3, 4, 4, 4, 6, 5, 3, 4, 3, ... | [766, 955, 745, 210, 940, 688, 986, 524, 425, ... |
| 1 | [298, 304, 393, 697, 706, 313, 834, 83, 502, 1... | [4, 4, 4, 3, 4, 4, 4, 3, 3, 3, 4, 4, 3, 4, 3, ... | [3, 5, 6, 4, 4, 3, 3, 3, 6, 6, 3, 3, 6, 6, 3, ... | [13, 221, 915, 658, 456, 378, 802, 180, 580, 4... |
| 2 | [706, 221, 22, 702, 339, 645, 436, 358, 84, 35... | [4, 3, 4, 4, 4, 4, 4, 4, 3, 3, 3, 4, 3, 4, 3, ... | [3, 6, 4, 6, 3, 3, 5, 5, 4, 6, 4, 6, 3, 5, 6, ... | [271, 940, 562, 498, 172, 239, 270, 215, 489, ... |
| 3 | [278, 153, 189, 717, 580, 540, 219, 79, 200, 9... | [3, 3, 3, 3, 4, 4, 3, 4, 4, 3, 4, 4, 3, 3, 3, ... | [6, 6, 6, 6, 3, 4, 4, 4, 4, 4, 3, 6, 5, 4, 3, ... | [169, 419, 875, 725, 926, 770, 160, 554, 763, ... |
| 4 | [156, 922, 914, 592, 842, 916, 137, 928, 615, ... | [3, 4, 4, 4, 3, 4, 4, 4, 4, 3, 4, 3, 4, 3, 4, ... | [6, 4, 5, 6, 5, 4, 3, 3, 6, 5, 6, 5, 3, 6, 3, ... | [318, 506, 281, 191, 506, 480, 965, 399, 761, ... |
Now that we have processed the train set, we can use the mapping preserved in cat_op to process the skus dataset containing the embeddings we are after.
Let’s now Categorify the product_sku_hash in skus and grab just the description embedding information.
skus.head()
| | product_sku_hash | description_vector | category_hash | price_bucket |
|---|---|---|---|---|
| 0 | 13 | [0.07939800762120258, 0.3465797761609977, -0.3... | 16 | 0.186690 |
| 1 | 25 | [0.4275482879608162, -0.30569476366666, 0.1440... | 38 | 0.951997 |
| 2 | 18 | [-0.31035419787213536, 0.18070481533058008, 0.... | 22 | 0.973384 |
| 3 | 1 | [-0.31319783485940356, -0.11623980504981396, -... | 138 | 0.146260 |
| 4 | 11 | [0.25091279302969943, -0.33473442518442525, 0.... | 119 | 0.808252 |
out = ['product_sku_hash'] >> cat_op
wf_skus = nvt.Workflow(out + 'description_vector')
skus_ds = wf_skus.transform(skus)
skus_ds.head()
| | product_sku_hash | description_vector |
|---|---|---|
| 0 | 836 | [0.07939800762120258, 0.3465797761609977, -0.3... |
| 1 | 979 | [0.4275482879608162, -0.30569476366666, 0.1440... |
| 2 | 11 | [-0.31035419787213536, 0.18070481533058008, 0.... |
| 3 | 469 | [-0.31319783485940356, -0.11623980504981396, -... |
| 4 | 118 | [0.25091279302969943, -0.33473442518442525, 0.... |
Let us now export the embedding information to a numpy array and write it to disk.
We will later pass this information to the Loader so that it loads the correct embedding for the product corresponding to a given step of a customer journey.
The embeddings are linked to the train set through the product_sku_hash information.
skus_ds.to_npy('skus.npy')
How will the Loader know which embedding to associate with a given row of the train set?
The product_sku_hash ids have been exported along with the embeddings and are contained in the first column of the output numpy array.
Here is the id of the first embedding stored in skus.npy:
np.load('skus.npy')[0, 0]
836.0
and here is the embedding vector corresponding to the product_sku_hash id referenced above:
np.load('skus.npy')[0, 1:]
array([ 0.07939801, 0.34657978, -0.38269496, 0.56307004, -0.10142923,
0.03702352, -0.11606304, 0.10070879, -0.21879928, 0.06107687,
-0.20743195, -0.01330719, 0.60182867, 0.0920322 , 0.2648726 ,
0.56061561, 0.48643498, 0.39045152, -0.40012162, 0.09153962,
-0.38351605, 0.57134731, 0.59986226, -0.40321368, -0.32984972,
0.37559494, 0.1554353 , -0.0413067 , 0.33814398, 0.30678041,
0.24001132, 0.42737922, 0.41554601, -0.40451691, 0.50428902,
-0.2004803 , -0.38297056, 0.06580838, 0.48285745, 0.51406472,
0.02268894, 0.36343324, 0.32497967, -0.29736346, -0.00538915,
0.12329302, -0.04998194, 0.27843002, 0.20212714, 0.39019503])
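The layout of skus.npy (ids in column 0, the description vector in the remaining columns) can be reproduced with NumPy alone. The sketch below builds a small array in that layout from made-up 3-dimensional vectors, round-trips it through np.save/np.load in memory, and splits it back into ids and vectors the same way the notebook does when constructing the EmbeddingOperator. It does not call NVTabular's to_npy.

```python
import io

import numpy as np

# Made-up data: three encoded product ids and their 3-dimensional vectors
ids = np.array([836, 979, 11])
vectors = np.array([
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
])

# Same layout as skus.npy: id in column 0, vector in columns 1..N
table = np.hstack([ids[:, None], vectors])

# Round-trip through the .npy format without touching the file system
buffer = io.BytesIO()
np.save(buffer, table)
buffer.seek(0)
loaded = np.load(buffer)

# Split it back, as done when constructing the EmbeddingOperator
lookup_ids = loaded[:, 0].astype(int)
embedding_matrix = loaded[:, 1:].astype(np.float32)

print(lookup_ids.tolist())     # [836, 979, 11]
print(embedding_matrix.shape)  # (3, 3)
```

Because the whole array shares one floating-point dtype, the id column comes back as floats (hence 836.0 above), which is why it is cast back to int before being passed as id_lookup_table.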
We are now ready to construct the Loader that will feed the data to our model.
We begin by reading in the embeddings information.
embeddings = np.load('skus.npy')
We are now ready to define the Loader.
We are passing in an EmbeddingOperator that ensures the correct sku information (the correct description_vector) is associated with the correct step in the customer journey, with the lookup key contained in product_sku_hash_list.
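Conceptually, the lookup the EmbeddingOperator performs maps every id in a session's product_sku_hash_list to the row of the embedding matrix whose id matches, producing one vector per step. The NumPy sketch below illustrates that idea on made-up data; it is an assumption about the operator's behavior at the level of inputs and outputs, not its implementation (which works on whole batches and may handle unknown ids differently; here an unknown id would raise a KeyError).

```python
import numpy as np

# Made-up lookup table, in the same form as passed to the EmbeddingOperator:
# id_lookup_table[i] is the product id whose vector is embeddings[i]
id_lookup_table = np.array([836, 979, 11])
embeddings = np.array([
    [0.1, 0.2],
    [0.3, 0.4],
    [0.5, 0.6],
], dtype=np.float32)

# Map each product id to its row in the embedding matrix
row_of_id = {int(product_id): row for row, product_id in enumerate(id_lookup_table)}

# One session: product ids in chronological order (ids may repeat)
product_sku_hash_list = [11, 836, 11]

# Gather one vector per step -> shape (sequence_length, embedding_dim)
product_embeddings = embeddings[[row_of_id[i] for i in product_sku_hash_list]]

print(product_embeddings.shape)  # (3, 2)
```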
When specifying the dataset, we create a Merlin Dataset based on the train_transformed data we saved above.
Depending on the hardware you are running on and the size of the dataset you are using, if you run out of GPU memory you can specify one of several parameters that ease the memory load (npartitions, part_size, or part_mem_fraction).
The BATCH_SIZE of 16 should work on a broad set of hardware, but if you are training on a lot of data and your hardware permits, you might want to increase it significantly.
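The batch size also determines how many steps Keras runs per epoch. Assuming the Loader yields ceil(number_of_sessions / BATCH_SIZE) batches (that is, the last, partial batch is kept), the relation can be checked with a few lines of arithmetic; steps_per_epoch is a hypothetical helper. For instance, the 18 steps per epoch reported in the training log are consistent with somewhere between 273 and 288 sessions surviving the MINIMUM_SESSION_LENGTH filter at BATCH_SIZE = 16.

```python
import math


def steps_per_epoch(num_sessions, batch_size):
    """Number of batches when the last, partial batch is kept (hypothetical helper)."""
    return math.ceil(num_sessions / batch_size)


BATCH_SIZE = 16

# Session counts consistent with 18 steps per epoch at BATCH_SIZE = 16
compatible = [n for n in range(1, 1000) if steps_per_epoch(n, BATCH_SIZE) == 18]
print(compatible[0], compatible[-1])  # 273 288

# Larger batches mean fewer (but bigger) steps for the same data
print(steps_per_epoch(280, 16), steps_per_epoch(280, 64))  # 18 5
```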
BATCH_SIZE = 16
from merlin.dataloader.tensorflow import Loader
from merlin.dataloader.ops.embeddings import EmbeddingOperator
import merlin.models.tf as mm
embedding_operator = EmbeddingOperator(
embeddings[:, 1:].astype(np.float32),
id_lookup_table=embeddings[:, 0].astype(int),
lookup_key="product_sku_hash_list",
embedding_name='product_embeddings'
)
loader = Loader(
dataset=nvt.Dataset('train_transformed', engine='parquet'),
batch_size=BATCH_SIZE,
transforms=[
embedding_operator
],
shuffle=True
)
Using the EmbeddingOperator object, we referenced our product_embeddings and instructed the Loader which column to use as the key for looking up the embeddings.
Below is an example batch of data that our model will consume.
batch = mm.sample_batch(loader, batch_size=BATCH_SIZE, include_targets=False, prepare_features=True)
The product_embeddings are included in the batch.
batch.keys()
dict_keys(['product_sku_hash_list', 'event_type_list', 'product_action_list', 'hashed_url_list', 'product_embeddings'])
Creating and training the model
We are now ready to construct our model.
import merlin.models.tf as mm
input_block = mm.InputBlockV2(
    loader.output_schema,
    embeddings=mm.Embeddings(
        loader.output_schema.select_by_tag(Tags.CATEGORICAL),
        sequence_combiner=None,
    ),
    pretrained_embeddings=mm.PretrainedEmbeddings(
        loader.output_schema.select_by_tag(Tags.EMBEDDING),
        sequence_combiner=None,
        normalizer="l2-norm",
        output_dims={"product_embeddings": 64},
    ),
)
We have now constructed an input_block that will take our batch and transform it into a form suitable for processing by the subsequent layers of our model.
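Our reading of the PretrainedEmbeddings arguments is that each 50-dimensional product embedding is L2-normalized (normalizer="l2-norm") and then projected to 64 dimensions (output_dims={"product_embeddings": 64}). The NumPy sketch below shows that arithmetic on made-up data, with a random matrix standing in for the learned projection weights; it illustrates the idea and is not Merlin Models' code.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

batch_size, seq_len, pretrained_dim, output_dim = 2, 4, 50, 64

# Made-up pretrained vectors: one 50-dimensional vector per step of each session
product_embeddings = rng.normal(size=(batch_size, seq_len, pretrained_dim))

# L2-normalize each vector along the last axis (guarding against zero norms)
norms = np.linalg.norm(product_embeddings, axis=-1, keepdims=True)
normalized = product_embeddings / np.maximum(norms, 1e-12)

# A random matrix stands in for the learned 50 -> 64 projection
projection = rng.normal(size=(pretrained_dim, output_dim))
projected = normalized @ projection

print(np.allclose(np.linalg.norm(normalized, axis=-1), 1.0))  # True
print(projected.shape)  # (2, 4, 64)
```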
To test that everything has worked, we can pass our example batch through the input_block.
input_batch = input_block(batch)
Let us now construct the remaining layers of our model.
target = 'hashed_url_list'
# We do not need the `train_transformed` dataset here, but we do need
# to access the schema.
# It contains important information that will help our model construct itself.
schema = nvt.Dataset('train_transformed', engine='parquet').schema
dmodel = 64
# Two dense layers (128 and dmodel units); the last one has no activation
mlp_block = mm.MLPBlock(
    [128, dmodel],
    activation='relu',
    no_activation_last_layer=True,
)
# XLNet transformer with 4 attention heads and 2 layers
transformer_block = mm.XLNetBlock(d_model=dmodel, n_head=4, n_layer=2)
model = mm.Model(
    input_block,
    mlp_block,
    transformer_block,
    mm.CategoricalOutput(
        schema.select_by_name(target),
        default_loss="categorical_crossentropy",
    ),
)
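Training passes pre=mm.SequenceMaskRandom(..., masking_prob=0.3), a masked-language-model style objective: a random subset of positions in each session is hidden and the model learns to predict the target (hashed_url_list) at those positions. The pure-Python sketch below shows the masking idea on one made-up session; the helper mask_sequence is hypothetical, and the fallback that forces at least one masked position is a choice made for this sketch, not necessarily what SequenceMaskRandom does.

```python
import random


def mask_sequence(items, masking_prob, rng, mask_token="[MASK]"):
    """Randomly hide positions of a sequence (hypothetical helper).

    Returns the masked inputs and a dict {position: original item}
    holding the targets the model would be trained to predict.
    """
    positions = [i for i in range(len(items)) if rng.random() < masking_prob]
    if not positions:
        # Make sure every session contributes at least one target
        positions = [rng.randrange(len(items))]
    inputs = list(items)
    targets = {}
    for i in positions:
        targets[i] = inputs[i]
        inputs[i] = mask_token
    return inputs, targets


rng = random.Random(42)
session = ["u4", "u9", "u7", "u1", "u3", "u8"]
inputs, targets = mask_sequence(session, masking_prob=0.3, rng=rng)
print(inputs)
print(targets)
```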
And let us train it.
model.compile(run_eagerly=False, optimizer='adam', loss="categorical_crossentropy")
model.fit(loader, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS, pre=mm.SequenceMaskRandom(schema=loader.output_schema, target=target, masking_prob=0.3, transformer=transformer_block))
Epoch 1/5
18/18 [==============================] - 42s 2s/step - loss: 6.9800 - recall_at_10: 0.0106 - mrr_at_10: 0.0033 - ndcg_at_10: 0.0050 - map_at_10: 0.0033 - precision_at_10: 0.0011 - regularization_loss: 0.0000e+00 - loss_batch: 6.9689
Epoch 2/5
18/18 [==============================] - 34s 2s/step - loss: 6.9591 - recall_at_10: 0.0106 - mrr_at_10: 0.0031 - ndcg_at_10: 0.0048 - map_at_10: 0.0031 - precision_at_10: 0.0011 - regularization_loss: 0.0000e+00 - loss_batch: 6.9363
Epoch 3/5
18/18 [==============================] - 39s 2s/step - loss: 6.9471 - recall_at_10: 0.0107 - mrr_at_10: 0.0028 - ndcg_at_10: 0.0046 - map_at_10: 0.0028 - precision_at_10: 0.0011 - regularization_loss: 0.0000e+00 - loss_batch: 6.9206
Epoch 4/5
18/18 [==============================] - 38s 2s/step - loss: 6.9398 - recall_at_10: 0.0103 - mrr_at_10: 0.0030 - ndcg_at_10: 0.0047 - map_at_10: 0.0030 - precision_at_10: 0.0010 - regularization_loss: 0.0000e+00 - loss_batch: 6.9015
Epoch 5/5
18/18 [==============================] - 38s 2s/step - loss: 6.9375 - recall_at_10: 0.0104 - mrr_at_10: 0.0030 - ndcg_at_10: 0.0047 - map_at_10: 0.0030 - precision_at_10: 0.0010 - regularization_loss: 0.0000e+00 - loss_batch: 6.9095
<keras.callbacks.History at 0x7f55081d17c0>
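The training log reports ranking metrics over the top 10 predicted urls. As a reference for reading them, the sketch below computes recall@10 and MRR@10 for a single prediction in plain Python using the standard definitions; it is not Merlin Models' metric code, which averages over batches and may differ in details. The scores are made up.

```python
def top_k(scores, k):
    """Indices of the k highest scores, best first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def recall_at_k(scores, target, k=10):
    """1.0 if the true item is among the top-k predictions, else 0.0."""
    return 1.0 if target in top_k(scores, k) else 0.0


def mrr_at_k(scores, target, k=10):
    """Reciprocal rank of the true item if it is in the top-k, else 0.0."""
    ranked = top_k(scores, k)
    return 1.0 / (ranked.index(target) + 1) if target in ranked else 0.0


# Made-up logits over 12 urls
scores = [0.1, 2.0, 0.3, 1.5, 0.2, 1.8, 0.0, 0.4, 0.6, 0.5, 0.7, 0.9]
print(recall_at_k(scores, target=5), mrr_at_k(scores, target=5))  # 1.0 0.5
print(recall_at_k(scores, target=6), mrr_at_k(scores, target=6))  # 0.0 0.0
```

As a rough baseline: ranking N urls uniformly at random gives an expected recall@10 of 10/N and a cross-entropy loss of ln N. With N on the order of 1,000 (the encoded hashed_url ids in the sample rows above stay below 1,000), that is about 0.01 and 6.9, close to the logged values, which suggests the model learns little on this synthetic data in 5 epochs.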
Serving predictions
Now that we have prepared a workflow for processing our data (wf), defined the embedding operator (embedding_operator), and trained our model (model), we have all the components we need to serve our model using the Triton Inference Server (TIS).
Let us define a set of inference operators (a pipeline for processing our data all the way to obtaining predictions) and export them as an ensemble that we will be able to serve using TIS.
from merlin.systems.dag.ops.tensorflow import PredictTensorflow
from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.ops.workflow import TransformWorkflow
inference_operators = wf.input_schema.column_names >> TransformWorkflow(wf) >> embedding_operator >> PredictTensorflow(model)
ensemble = Ensemble(inference_operators, wf.input_schema)
ensemble.export(os.path.join(OUTPUT_DATA_DIR, 'ensemble'));
After we export the ensemble, we are ready to start the Triton Inference Server.
The server is installed in the Merlin TensorFlow and Merlin PyTorch containers. If you are not using one of our containers, ensure it is installed in your environment. For more information, see the Triton Inference Server documentation.
You can start the server by running the following command:
tritonserver --model-repository={OUTPUT_DATA_DIR}/ensemble/
For the --model-repository argument, specify the same value as the export_path that you passed to the ensemble.export method (here, os.path.join(OUTPUT_DATA_DIR, 'ensemble')).
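Because the ensemble was exported to os.path.join(OUTPUT_DATA_DIR, 'ensemble'), the server command can also be assembled from the same environment variable, so the two paths cannot drift apart. The sketch below only builds the argument list; launching it with subprocess.Popen is left as an optional, commented-out step, and the helper name triton_command is hypothetical.

```python
import os


def triton_command(output_data_dir):
    """Build the tritonserver invocation for the exported ensemble (hypothetical helper)."""
    model_repository = os.path.join(output_data_dir, "ensemble")
    return ["tritonserver", f"--model-repository={model_repository}"]


OUTPUT_DATA_DIR = os.environ.get("OUTPUT_DATA_DIR", "/workspace/data")
command = triton_command(OUTPUT_DATA_DIR)
print(" ".join(command))

# To start the server from Python instead of a terminal (optional):
# import subprocess
# server = subprocess.Popen(command)
```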
After you run the tritonserver command, wait until your terminal shows messages like the following example:
I0414 18:29:50.741833 4067 grpc_server.cc:4421] Started GRPCInferenceService at 0.0.0.0:8001
I0414 18:29:50.742197 4067 http_server.cc:3113] Started HTTPService at 0.0.0.0:8000
I0414 18:29:50.783470 4067 http_server.cc:178] Started Metrics Service at 0.0.0.0:8002
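Instead of watching the terminal, a script can poll the server until it reports ready. The sketch below is a generic polling helper, wait_until (hypothetical), demonstrated with a stand-in predicate. With the Triton gRPC client, the predicate could combine client.is_server_ready() and client.is_model_ready('executor_model'), assuming the server listens on the default gRPC port 8001 used later in this notebook; that part is shown only as a comment because it needs a running server.

```python
import time


def wait_until(predicate, timeout_s=60.0, interval_s=1.0):
    """Call `predicate` until it returns True or `timeout_s` elapses (hypothetical helper).

    Exceptions raised by the predicate (for example, connection errors while
    the server is still starting) are treated as "not ready yet".
    """
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            if predicate():
                return True
        except Exception:
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)


# Stand-in predicate: fails once, then reports ready on the third call
calls = {"n": 0}


def fake_is_server_ready():
    calls["n"] += 1
    if calls["n"] == 1:
        raise ConnectionError("server still starting")
    return calls["n"] >= 3


print(wait_until(fake_is_server_ready, timeout_s=5.0, interval_s=0.01))  # True

# With a running Triton server (requires tritonclient):
# import tritonclient.grpc as grpcclient
# client = grpcclient.InferenceServerClient("localhost:8001")
# wait_until(lambda: client.is_server_ready() and client.is_model_ready("executor_model"))
```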
Let us now package our data for inference. We will send 5 rows of data, corresponding to a single customer journey (session) through the online store. The data will first be processed by the NVTabular workflow and then passed to our transformer model for prediction.
# obtaining five rows of data
df = train.head(5)
# making sure all the rows correspond to the same online session (have the same `session_id_hash`)
df['session_id_hash'] = df['session_id_hash'].iloc[0]
Let us now send the data to the Triton Inference Server for inference.
from merlin.systems.triton import convert_df_to_triton_input
import tritonclient.grpc as grpcclient

# Convert the raw rows into Triton inputs matching the workflow's input schema
inputs = convert_df_to_triton_input(wf.input_schema, df)
with grpcclient.InferenceServerClient("localhost:8001") as client:
    response = client.infer('executor_model', inputs)
Let’s parse the response.
predictions = response.as_numpy("hashed_url_list/categorical_output")
predictions
array([[-2.2332087 , -2.1218574 , -2.390479 , ..., -0.7735352 ,
0.1954267 , -0.34523243]], dtype=float32)
The response contains logits over the hashed url ids; the highest logit marks the url the customer is most likely to visit as the next step of their journey through the online store.
Here is the predicted hashed url id:
predicted_hashed_url_id = predictions.argmax()
predicted_hashed_url_id
34
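argmax returns only the single most likely url. To get a ranked shortlist, such as the top 10 used by the training metrics, the logits can be sorted and, if probabilities are needed, passed through a softmax. The NumPy sketch below uses a made-up logits row shaped like the response (one row per request); top_k_urls is a hypothetical helper, and the returned ids are still the Categorify-encoded hashed_url ids, so mapping them back to raw url hashes would require the mapping learned by the workflow.

```python
import numpy as np


def top_k_urls(logits, k):
    """Return the k highest-scoring ids and their softmax probabilities (hypothetical helper)."""
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax over the whole vocabulary
    shifted = logits - logits.max()
    probs = np.exp(shifted) / np.exp(shifted).sum()
    # argsort is ascending, so reverse it and keep the first k ids
    top_ids = np.argsort(logits)[::-1][:k]
    return top_ids, probs[top_ids]


# Made-up logits for one request (the real response has one row per request)
predictions = np.array([[-2.2, 0.5, 3.1, -0.7, 1.9]])
ids, probs = top_k_urls(predictions[0], k=3)
print(ids.tolist())  # [2, 4, 1]
print(probs)
```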
Summary
We have trained a transformer model for the next-item prediction task using language model masking.
For another session-based example that goes deeper into data preprocessing and that covers several advanced techniques (Weight Tying, Temperature Scaling) please see Session-Based Next Item Prediction for Fashion E-Commerce.