# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Training and Serving Merlin on AWS SageMaker#

This notebook is created using the latest stable merlin-tensorflow container. Note that the AWS libraries used in this notebook require AWS credentials; if you are running this notebook in a container, you might need to restart the container with your AWS credentials mounted, e.g., -v $HOME/.aws:$HOME/.aws.

With AWS SageMaker, you can package your own models so that they can be trained and deployed in the SageMaker environment. This notebook shows you how to use Merlin for training and inference in the SageMaker environment.

It assumes that readers are familiar with some basic concepts in NVIDIA Merlin, such as:

  • Using NVTabular to GPU-accelerate preprocessing and feature engineering,

  • Training a ranking model using Merlin Models, and

  • Inference with the Triton Inference Server and Merlin Models for Tensorflow.

To learn more about these concepts in NVIDIA Merlin, see, for example, Deploying a Multi-Stage Recommender System in this repository or the example notebooks in Merlin Models.

To run this notebook, you need to have the Amazon SageMaker Python SDK installed.

! python -m pip install sagemaker
Collecting sagemaker
  Downloading sagemaker-2.116.0.tar.gz (592 kB)
     |████████████████████████████████| 592 kB 4.4 MB/s eta 0:00:01
Requirement already satisfied: attrs<23,>=20.3.0 in /usr/local/lib/python3.8/dist-packages (from sagemaker) (22.1.0)
Requirement already satisfied: boto3<2.0,>=1.20.21 in /usr/local/lib/python3.8/dist-packages (from sagemaker) (1.25.2)
Requirement already satisfied: google-pasta in /usr/local/lib/python3.8/dist-packages (from sagemaker) (0.2.0)
Collecting importlib-metadata<5.0,>=1.4.0
  Downloading importlib_metadata-4.13.0-py3-none-any.whl (23 kB)
Requirement already satisfied: numpy<2.0,>=1.9.0 in /usr/local/lib/python3.8/dist-packages (from sagemaker) (1.22.4)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from sagemaker) (21.3)
Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from sagemaker) (1.3.5)
Collecting pathos
  Downloading pathos-0.3.0-py3-none-any.whl (79 kB)
     |████████████████████████████████| 79 kB 10.5 MB/s eta 0:00:01
Collecting protobuf3-to-dict<1.0,>=0.1.5
  Downloading protobuf3-to-dict-0.1.5.tar.gz (3.5 kB)
Requirement already satisfied: protobuf<4.0,>=3.1 in /usr/local/lib/python3.8/dist-packages (from sagemaker) (3.19.6)
Collecting schema
  Downloading schema-0.7.5-py2.py3-none-any.whl (17 kB)
Collecting smdebug_rulesconfig==1.0.1
  Downloading smdebug_rulesconfig-1.0.1-py2.py3-none-any.whl (20 kB)
Requirement already satisfied: s3transfer<0.7.0,>=0.6.0 in /usr/local/lib/python3.8/dist-packages (from boto3<2.0,>=1.20.21->sagemaker) (0.6.0)
Requirement already satisfied: botocore<1.29.0,>=1.28.2 in /usr/local/lib/python3.8/dist-packages (from boto3<2.0,>=1.20.21->sagemaker) (1.28.2)
Requirement already satisfied: jmespath<2.0.0,>=0.7.1 in /usr/local/lib/python3.8/dist-packages (from boto3<2.0,>=1.20.21->sagemaker) (1.0.1)
Requirement already satisfied: six in /usr/lib/python3/dist-packages (from google-pasta->sagemaker) (1.14.0)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.8/dist-packages (from importlib-metadata<5.0,>=1.4.0->sagemaker) (3.10.0)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.0->sagemaker) (3.0.9)
Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->sagemaker) (2022.5)
Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->sagemaker) (2.8.2)
Collecting pox>=0.3.2
  Downloading pox-0.3.2-py3-none-any.whl (29 kB)
Collecting dill>=0.3.6
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
     |████████████████████████████████| 110 kB 17.3 MB/s eta 0:00:01
Collecting multiprocess>=0.70.14
  Downloading multiprocess-0.70.14-py38-none-any.whl (132 kB)
     |████████████████████████████████| 132 kB 17.8 MB/s eta 0:00:01
Collecting ppft>=1.7.6.6
  Downloading ppft-1.7.6.6-py3-none-any.whl (52 kB)
     |████████████████████████████████| 52 kB 2.9 MB/s  eta 0:00:01
Collecting contextlib2>=0.5.5
  Downloading contextlib2-21.6.0-py2.py3-none-any.whl (13 kB)
Requirement already satisfied: urllib3<1.27,>=1.25.4 in /usr/local/lib/python3.8/dist-packages (from botocore<1.29.0,>=1.28.2->boto3<2.0,>=1.20.21->sagemaker) (1.26.12)
Building wheels for collected packages: sagemaker, protobuf3-to-dict
  Building wheel for sagemaker (setup.py) ... done
  Created wheel for sagemaker: filename=sagemaker-2.116.0-py2.py3-none-any.whl size=809052 sha256=f446dd6eed6d268b7f3f2709f8f11c1ba153e382fbea9b2caedd517c1fb71215
  Stored in directory: /root/.cache/pip/wheels/3e/cb/b1/5b13ff7b150aa151e4a11030a6c41b1e457c31a52ea1ef11b0
  Building wheel for protobuf3-to-dict (setup.py) ... done
  Created wheel for protobuf3-to-dict: filename=protobuf3_to_dict-0.1.5-py3-none-any.whl size=4029 sha256=8f99baaa875ba544d54f624f95dfbf4fd52ca96d52ce8af6d05c1ff2bb8435b2
  Stored in directory: /root/.cache/pip/wheels/fc/10/27/2d1e23d8b9a9013a83fbb418a0b17b1e6f81c8db8f53b53934
Successfully built sagemaker protobuf3-to-dict
Installing collected packages: importlib-metadata, pox, dill, multiprocess, ppft, pathos, protobuf3-to-dict, contextlib2, schema, smdebug-rulesconfig, sagemaker
  Attempting uninstall: importlib-metadata
    Found existing installation: importlib-metadata 5.0.0
    Uninstalling importlib-metadata-5.0.0:
      Successfully uninstalled importlib-metadata-5.0.0
Successfully installed contextlib2-21.6.0 dill-0.3.6 importlib-metadata-4.13.0 multiprocess-0.70.14 pathos-0.3.0 pox-0.3.2 ppft-1.7.6.6 protobuf3-to-dict-0.1.5 sagemaker-2.116.0 schema-0.7.5 smdebug-rulesconfig-1.0.1

Part 1: Generating Dataset and Docker image#

Generating Dataset#

In this notebook, we use synthetic train and test datasets that mimic the real Ali-CCP (Alibaba Click and Conversion Prediction) dataset to build our recommender system ranking models. Ali-CCP is a dataset gathered from real-world traffic logs of the recommender system in Taobao, the largest online retail platform in the world.

If you would like to use the real Ali-CCP dataset instead, you can download the training and test datasets from tianchi.aliyun.com. You can then use the get_aliccp() function to curate the raw csv files and save them as parquet files, as sketched below.
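
The following is a minimal sketch of that conversion, assuming the raw csv files have been unzipped to a local directory (the path below is hypothetical). The exact signature of get_aliccp() may differ between Merlin releases, so check the Merlin Models documentation for your version.

from merlin.datasets.ecommerce import get_aliccp

# Hypothetical location of the unzipped raw Ali-CCP csv files.
RAW_ALICCP_DIR = "/workspace/data/aliccp_raw/"

# get_aliccp() curates the raw csv files and returns train/valid datasets,
# which we then write out as parquet files for the rest of the notebook.
train, valid = get_aliccp(RAW_ALICCP_DIR)
train.to_ddf().to_parquet("/workspace/data/train")
valid.to_ddf().to_parquet("/workspace/data/valid")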

import os

from merlin.datasets.synthetic import generate_data

DATA_FOLDER = os.environ.get("DATA_FOLDER", "/workspace/data/")
NUM_ROWS = os.environ.get("NUM_ROWS", 1_000_000)
SYNTHETIC_DATA = eval(os.environ.get("SYNTHETIC_DATA", "True"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 512))

if SYNTHETIC_DATA:
    train, valid = generate_data("aliccp-raw", int(NUM_ROWS), set_sizes=(0.7, 0.3))
    # save the datasets as parquet files
    train.to_ddf().to_parquet(os.path.join(DATA_FOLDER, "train"))
    valid.to_ddf().to_parquet(os.path.join(DATA_FOLDER, "valid"))
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.USER: 'user'>, <Tags.ID: 'id'>].
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(

Training Script#

The training script train.py in this example starts with the synthetic dataset we have created in the previous cell and produces a ranking model by performing the following tasks:

  • Perform feature engineering and preprocessing with NVTabular. NVTabular implements common feature engineering and preprocessing operators in easy-to-use, high-level APIs.

  • Use Merlin Models to train Facebook’s DLRM model in TensorFlow.

  • Prepare ensemble models for serving on Triton Inference Server. The training script saves the final NVTabular workflow and the trained DLRM model to model_dir as an ensemble model. Make sure that your script writes all of its artifacts within model_dir, since SageMaker packages any files in this directory into a compressed tar archive and makes the archive available at an S3 location. The ensemble model uploaded to S3 will be used to serve predictions with the Triton Inference Server later in this notebook.

%%writefile train.py
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import json
import logging
import os
import sys
import tempfile

# We can control how much memory to give tensorflow with this environment variable
# IMPORTANT: make sure you do this before you initialize TF's runtime, otherwise
# TF will have claimed all free GPU memory
os.environ["TF_MEMORY_ALLOCATION"] = "0.7"  # fraction of free memory

import merlin.io
import merlin.models.tf as mm
import nvtabular as nvt
import tensorflow as tf
from merlin.schema.tags import Tags
from merlin.systems.dag.ops.workflow import TransformWorkflow
from merlin.systems.dag.ops.tensorflow import PredictTensorflow
from merlin.systems.dag.ensemble import Ensemble
import numpy as np
from nvtabular.ops import *


logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))


def parse_args():
    """
    Parse arguments passed from the SageMaker API to the container.
    """

    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the client are passed as command-line arguments to the script
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1024)

    # Data directories
    parser.add_argument(
        "--train_dir", type=str, default=os.environ.get("SM_CHANNEL_TRAIN")
    )
    parser.add_argument(
        "--valid_dir", type=str, default=os.environ.get("SM_CHANNEL_VALID")
    )

    # Model directory: we will use the default set by SageMaker, /opt/ml/model
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR"))

    return parser.parse_known_args()


def create_nvtabular_workflow(train_path, valid_path):
    user_id = ["user_id"] >> Categorify() >> TagAsUserID()
    item_id = ["item_id"] >> Categorify() >> TagAsItemID()
    targets = ["click"] >> AddMetadata(tags=[Tags.BINARY_CLASSIFICATION, "target"])

    item_features = (
        ["item_category", "item_shop", "item_brand"]
        >> Categorify()
        >> TagAsItemFeatures()
    )

    user_features = (
        [
            "user_shops",
            "user_profile",
            "user_group",
            "user_gender",
            "user_age",
            "user_consumption_2",
            "user_is_occupied",
            "user_geography",
            "user_intentions",
            "user_brands",
            "user_categories",
        ]
        >> Categorify()
        >> TagAsUserFeatures()
    )

    outputs = user_id + item_id + item_features + user_features + targets

    workflow = nvt.Workflow(outputs)

    return workflow


def create_ensemble(workflow, model):
    serving_operators = (
        workflow.input_schema.column_names
        >> TransformWorkflow(workflow)
        >> PredictTensorflow(model)
    )
    ensemble = Ensemble(serving_operators, workflow.input_schema)
    return ensemble


def train():
    """
    Train the Merlin model.
    """
    train_path = os.path.join(args.train_dir, "*.parquet")
    valid_path = os.path.join(args.valid_dir, "*.parquet")

    workflow = create_nvtabular_workflow(
        train_path=train_path,
        valid_path=valid_path,
    )

    train_dataset = nvt.Dataset(train_path)
    valid_dataset = nvt.Dataset(valid_path)

    output_path = tempfile.mkdtemp()
    workflow_path = os.path.join(output_path, "workflow")

    workflow.fit(train_dataset)
    workflow.transform(train_dataset).to_parquet(
        output_path=os.path.join(output_path, "train")
    )
    workflow.transform(valid_dataset).to_parquet(
        output_path=os.path.join(output_path, "valid")
    )

    workflow.save(workflow_path)
    logger.info(f"Workflow saved to {workflow_path}.")

    train_data = merlin.io.Dataset(os.path.join(output_path, "train", "*.parquet"))
    valid_data = merlin.io.Dataset(os.path.join(output_path, "valid", "*.parquet"))

    schema = train_data.schema
    target_column = schema.select_by_tag(Tags.TARGET).column_names[0]

    model = mm.DLRMModel(
        schema,
        embedding_dim=64,
        bottom_block=mm.MLPBlock([128, 64]),
        top_block=mm.MLPBlock([128, 64, 32]),
        prediction_tasks=mm.BinaryClassificationTask(target_column),
    )

    model.compile("adam", run_eagerly=False, metrics=[tf.keras.metrics.AUC()])

    batch_size = args.batch_size
    epochs = args.epochs
    logger.info(f"batch_size = {batch_size}, epochs = {epochs}")

    model.fit(
        train_data,
        validation_data=valid_data,
        batch_size=args.batch_size,
        epochs=epochs,
        verbose=2,
    )

    model_path = os.path.join(output_path, "dlrm")
    model.save(model_path)
    logger.info(f"Model saved to {model_path}.")

    # We remove the label columns from the workflow's inputs.
    # This removes all columns with the TARGET tag from the workflow.
    # We do this because we need to set the workflow to only require the
    # features needed to predict, not train, when creating an inference
    # pipeline.
    label_columns = workflow.output_schema.select_by_tag(Tags.TARGET).column_names
    workflow.remove_inputs(label_columns)

    ensemble = create_ensemble(workflow, model)
    ensemble_path = args.model_dir
    ensemble.export(ensemble_path)
    logger.info(f"Ensemble graph saved to {ensemble_path}.")


if __name__ == "__main__":
    args, _ = parse_args()
    train()
Overwriting train.py
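
If you want to sanity-check train.py locally before launching a SageMaker job, one option is to set the SM_* environment variables that SageMaker would normally provide and invoke the script directly. This is only a minimal sketch, assuming you are inside the merlin-tensorflow container and that the local data paths below (which are illustrative) contain the parquet files generated earlier.

import os
import subprocess

# These SM_* variables are normally set by SageMaker; here we point them
# at local directories for a quick local smoke test.
env = os.environ.copy()
env["SM_CHANNEL_TRAIN"] = "/workspace/data/train"
env["SM_CHANNEL_VALID"] = "/workspace/data/valid"
env["SM_MODEL_DIR"] = "/tmp/local_model"
os.makedirs(env["SM_MODEL_DIR"], exist_ok=True)

# Run a single short epoch just to verify the script end to end.
subprocess.run(
    ["python", "train.py", "--epochs", "1", "--batch_size", "1024"],
    env=env,
    check=True,
)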

Create the Dockerfile#

The Dockerfile describes the image that will be used on SageMaker for training and inference. We start from the latest stable merlin-tensorflow docker image and install the sagemaker-training-toolkit library, which makes the image compatible with SageMaker for training models.

%%writefile container/Dockerfile

FROM nvcr.io/nvidia/merlin/merlin-tensorflow:22.10

RUN pip3 install sagemaker-training
Overwriting container/Dockerfile

Building and registering the container#

The following shell code shows how to build the container image using docker build and push the container image to ECR using docker push. This code is available as the shell script build_and_push_image.sh. If you are running this notebook inside the merlin-tensorflow docker container, you probably need to execute the script outside the container (e.g., in your terminal where you can run the docker command).

You need to have the AWS CLI installed to run this code. To install the AWS CLI, see Installing or updating the latest version of the AWS CLI.

This code looks for an ECR repository in the account you’re using and in the region set by the REGION variable in the script (us-east-1 below); adjust that variable to the region you want to use. If the repository doesn’t exist, the script will create it.

Note that running the following script requires permissions to create new repositories on Amazon ECR.

%%writefile ./build_and_push_image.sh
#!/bin/bash

set -euo pipefail

# The name of our algorithm
ALGORITHM_NAME=sagemaker-merlin-tensorflow
REGION=us-east-1

cd container

ACCOUNT=$(aws sts get-caller-identity --query Account --output text --region ${REGION})

# Construct the full ECR repository name and image URI for this account and region

REPOSITORY="${ACCOUNT}.dkr.ecr.${REGION}.amazonaws.com"
IMAGE_URI="${REPOSITORY}/${ALGORITHM_NAME}:latest"

# Get the login command from ECR and execute it directly
aws ecr get-login-password --region ${REGION} | docker login --username AWS --password-stdin ${REPOSITORY}

# If the repository doesn't exist in ECR, create it.
# (Use an `if ! ...` construct so that `set -e` does not abort the script
# when describe-repositories fails for a missing repository.)

if ! aws ecr describe-repositories --repository-names "${ALGORITHM_NAME}" --region ${REGION} > /dev/null 2>&1
then
    aws ecr create-repository --repository-name "${ALGORITHM_NAME}" --region ${REGION} > /dev/null
fi

# Build the docker image locally with the image name and then push it to ECR
# with the full name.

docker build  -t ${ALGORITHM_NAME} .
docker tag ${ALGORITHM_NAME} ${IMAGE_URI}

docker push ${IMAGE_URI}
Overwriting ./build_and_push_image.sh
# If you are able to run `docker` from the notebook environment, you can uncomment and run the script below.
# ! ./build_and_push_image.sh
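
If you ran the script from a separate terminal, you can optionally verify from this notebook that the image has been pushed to ECR. This is a small sketch using boto3; it assumes the repository name defined in the script above and the latest tag.

import boto3

ecr_client = boto3.client("ecr")
response = ecr_client.describe_images(
    repositoryName="sagemaker-merlin-tensorflow",
    imageIds=[{"imageTag": "latest"}],
)
# Print when the latest image was pushed as a quick confirmation.
print(response["imageDetails"][0]["imagePushedAt"])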

Part 2: Training your Merlin model on SageMaker#

To deploy the training script onto SageMaker, we use the SageMaker Python SDK. Here, we create a SageMaker session that we will use for our SageMaker operations, specify the S3 prefix to use, and get the execution role for working with SageMaker.

import sagemaker

sess = sagemaker.Session()

# S3 prefix
prefix = "DEMO-merlin-tensorflow-aliccp"

role = sagemaker.get_execution_role()

print(role)
Couldn't call 'get_role' to get Role ARN from role name AWSOS-AD-Engineer to get Role path.
arn:aws:iam::843263297212:role/AWSOS-AD-Engineer

We can use the SageMaker Python SDK to upload the Ali-CCP synthetic data to our S3 bucket.

data_location = sess.upload_data(DATA_FOLDER, key_prefix=prefix)

print(data_location)
s3://sagemaker-us-east-1-843263297212/DEMO-merlin-tensorflow-aliccp

Training on SageMaker using the Python SDK#

SageMaker provides a Python SDK for training models on SageMaker.

Here, we start by constructing the ECR URI of the image we pushed in the previous section.

import boto3

sts_client = boto3.client("sts")
account = sts_client.get_caller_identity()["Account"]

my_session = boto3.session.Session()
region = my_session.region_name

algorithm_name = "sagemaker-merlin-tensorflow"

ecr_image = "{}.dkr.ecr.{}.amazonaws.com/{}:latest".format(
    account, region, algorithm_name
)

print(ecr_image)
843263297212.dkr.ecr.us-east-1.amazonaws.com/sagemaker-merlin-tensorflow:latest

We can call Estimator.fit() to start training on SageMaker. Here, we use a g4dn GPU instance, which is equipped with an NVIDIA T4 GPU. Our training script train.py is passed to the Estimator through the entry_point parameter. Behind the scenes, the SageMaker Python SDK uploads the script specified in the entry_point field (train.py in our case) to the S3 bucket and configures the training instance, via environment variables such as SAGEMAKER_PROGRAM, so that the instance can download the script from S3 and run it. We also set our hyperparameters in the hyperparameters field. We uploaded our dataset to our S3 bucket in the previous code cell, and the S3 URLs of the training and validation sets are passed into the fit() method.

import os
from sagemaker.estimator import Estimator


training_instance_type = "ml.g4dn.xlarge"  # GPU instance, T4

estimator = Estimator(
    role=role,
    instance_count=1,
    instance_type=training_instance_type,
    image_uri=ecr_image,
    entry_point="train.py",
    hyperparameters={
        "batch_size": 1_024,
        "epoch": 10,
    },
)

estimator.fit(
    {
        "train": f"{data_location}/train/",
        "valid": f"{data_location}/valid/",
    }
)
2022-11-09 10:18:31 Starting - Starting the training job...
2022-11-09 10:18:54 Starting - Preparing the instances for trainingProfilerReport-1667989110: InProgress
......
2022-11-09 10:19:54 Downloading - Downloading input data...
2022-11-09 10:20:34 Training - Downloading the training image..................................==================================
== Triton Inference Server Base ==
==================================
NVIDIA Release 22.08 (build 42766143)
Copyright (c) 2018-2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
Various files include modifications (c) NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license
NOTE: CUDA Forward Compatibility mode ENABLED.
  Using CUDA 11.7 driver version 515.65.01 with kernel driver version 510.47.03.
  See https://docs.nvidia.com/deploy/cuda-compatibility/ for details.
2022-11-09 10:27:03,405 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)
2022-11-09 10:27:03,438 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)
2022-11-09 10:27:03,473 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)
2022-11-09 10:27:03,485 sagemaker-training-toolkit INFO     Invoking user script
Training Env:
{
    "additional_framework_parameters": {},
    "channel_input_dirs": {
        "train": "/opt/ml/input/data/train",
        "valid": "/opt/ml/input/data/valid"
    },
    "current_host": "algo-1",
    "current_instance_group": "homogeneousCluster",
    "current_instance_group_hosts": [
        "algo-1"
    ],
    "current_instance_type": "ml.g4dn.xlarge",
    "distribution_hosts": [],
    "distribution_instance_groups": [],
    "framework_module": null,
    "hosts": [
        "algo-1"
    ],
    "hyperparameters": {
        "batch_size": 1024,
        "epoch": 10
    },
    "input_config_dir": "/opt/ml/input/config",
    "input_data_config": {
        "train": {
            "TrainingInputMode": "File",
            "S3DistributionType": "FullyReplicated",
            "RecordWrapperType": "None"
        },
        "valid": {
            "TrainingInputMode": "File",
            "S3DistributionType": "FullyReplicated",
            "RecordWrapperType": "None"
        }
    },
    "input_dir": "/opt/ml/input",
    "instance_groups": [
        "homogeneousCluster"
    ],
    "instance_groups_dict": {
        "homogeneousCluster": {
            "instance_group_name": "homogeneousCluster",
            "instance_type": "ml.g4dn.xlarge",
            "hosts": [
                "algo-1"
            ]
        }
    },
    "is_hetero": false,
    "is_master": true,
    "is_modelparallel_enabled": null,
    "is_smddpmprun_installed": false,
    "job_name": "sagemaker-merlin-tensorflow-2022-11-09-10-18-29-376",
    "log_level": 20,
    "master_hostname": "algo-1",
    "model_dir": "/opt/ml/model",
    "module_dir": "s3://sagemaker-us-east-1-843263297212/sagemaker-merlin-tensorflow-2022-11-09-10-18-29-376/source/sourcedir.tar.gz",
    "module_name": "train",
    "network_interface_name": "eth0",
    "num_cpus": 4,
    "num_gpus": 1,
    "num_neurons": 0,
    "output_data_dir": "/opt/ml/output/data",
    "output_dir": "/opt/ml/output",
    "output_intermediate_dir": "/opt/ml/output/intermediate",
    "resource_config": {
        "current_host": "algo-1",
        "current_instance_type": "ml.g4dn.xlarge",
        "current_group_name": "homogeneousCluster",
        "hosts": [
            "algo-1"
        ],
        "instance_groups": [
            {
                "instance_group_name": "homogeneousCluster",
                "instance_type": "ml.g4dn.xlarge",
                "hosts": [
                    "algo-1"
                ]
            }
        ],
        "network_interface_name": "eth0"
    },
    "user_entry_point": "train.py"
}
Environment variables:
SM_HOSTS=["algo-1"]
SM_NETWORK_INTERFACE_NAME=eth0
SM_HPS={"batch_size":1024,"epoch":10}
SM_USER_ENTRY_POINT=train.py
SM_FRAMEWORK_PARAMS={}
SM_RESOURCE_CONFIG={"current_group_name":"homogeneousCluster","current_host":"algo-1","current_instance_type":"ml.g4dn.xlarge","hosts":["algo-1"],"instance_groups":[{"hosts":["algo-1"],"instance_group_name":"homogeneousCluster","instance_type":"ml.g4dn.xlarge"}],"network_interface_name":"eth0"}
SM_INPUT_DATA_CONFIG={"train":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"},"valid":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"}}
SM_OUTPUT_DATA_DIR=/opt/ml/output/data
SM_CHANNELS=["train","valid"]
SM_CURRENT_HOST=algo-1
SM_CURRENT_INSTANCE_TYPE=ml.g4dn.xlarge
SM_CURRENT_INSTANCE_GROUP=homogeneousCluster
SM_CURRENT_INSTANCE_GROUP_HOSTS=["algo-1"]
SM_INSTANCE_GROUPS=["homogeneousCluster"]
SM_INSTANCE_GROUPS_DICT={"homogeneousCluster":{"hosts":["algo-1"],"instance_group_name":"homogeneousCluster","instance_type":"ml.g4dn.xlarge"}}
SM_DISTRIBUTION_INSTANCE_GROUPS=[]
SM_IS_HETERO=false
SM_MODULE_NAME=train
SM_LOG_LEVEL=20
SM_FRAMEWORK_MODULE=
SM_INPUT_DIR=/opt/ml/input
SM_INPUT_CONFIG_DIR=/opt/ml/input/config
SM_OUTPUT_DIR=/opt/ml/output
SM_NUM_CPUS=4
SM_NUM_GPUS=1
SM_NUM_NEURONS=0
SM_MODEL_DIR=/opt/ml/model
SM_MODULE_DIR=s3://sagemaker-us-east-1-843263297212/sagemaker-merlin-tensorflow-2022-11-09-10-18-29-376/source/sourcedir.tar.gz
SM_TRAINING_ENV={"additional_framework_parameters":{},"channel_input_dirs":{"train":"/opt/ml/input/data/train","valid":"/opt/ml/input/data/valid"},"current_host":"algo-1","current_instance_group":"homogeneousCluster","current_instance_group_hosts":["algo-1"],"current_instance_type":"ml.g4dn.xlarge","distribution_hosts":[],"distribution_instance_groups":[],"framework_module":null,"hosts":["algo-1"],"hyperparameters":{"batch_size":1024,"epoch":10},"input_config_dir":"/opt/ml/input/config","input_data_config":{"train":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"},"valid":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"}},"input_dir":"/opt/ml/input","instance_groups":["homogeneousCluster"],"instance_groups_dict":{"homogeneousCluster":{"hosts":["algo-1"],"instance_group_name":"homogeneousCluster","instance_type":"ml.g4dn.xlarge"}},"is_hetero":false,"is_master":true,"is_modelparallel_enabled":null,"is_smddpmprun_installed":false,"job_name":"sagemaker-merlin-tensorflow-2022-11-09-10-18-29-376","log_level":20,"master_hostname":"algo-1","model_dir":"/opt/ml/model","module_dir":"s3://sagemaker-us-east-1-843263297212/sagemaker-merlin-tensorflow-2022-11-09-10-18-29-376/source/sourcedir.tar.gz","module_name":"train","network_interface_name":"eth0","num_cpus":4,"num_gpus":1,"num_neurons":0,"output_data_dir":"/opt/ml/output/data","output_dir":"/opt/ml/output","output_intermediate_dir":"/opt/ml/output/intermediate","resource_config":{"current_group_name":"homogeneousCluster","current_host":"algo-1","current_instance_type":"ml.g4dn.xlarge","hosts":["algo-1"],"instance_groups":[{"hosts":["algo-1"],"instance_group_name":"homogeneousCluster","instance_type":"ml.g4dn.xlarge"}],"network_interface_name":"eth0"},"user_entry_point":"train.py"}
SM_USER_ARGS=["--batch_size","1024","--epoch","10"]
SM_OUTPUT_INTERMEDIATE_DIR=/opt/ml/output/intermediate
SM_CHANNEL_TRAIN=/opt/ml/input/data/train
SM_CHANNEL_VALID=/opt/ml/input/data/valid
SM_HP_BATCH_SIZE=1024
SM_HP_EPOCH=10
PYTHONPATH=/opt/ml/code:/usr/local/bin:/opt/tritonserver:/usr/local/lib/python3.8/dist-packages:/usr/lib/python38.zip:/usr/lib/python3.8:/usr/lib/python3.8/lib-dynload:/usr/local/lib/python3.8/dist-packages/faiss-1.7.2-py3.8.egg:/usr/local/lib/python3.8/dist-packages/merlin_sok-1.1.4-py3.8-linux-x86_64.egg:/usr/local/lib/python3.8/dist-packages/merlin_hps-1.0.0-py3.8-linux-x86_64.egg:/usr/lib/python3/dist-packages
Invoking script with the following command:
/usr/bin/python3 train.py --batch_size 1024 --epoch 10
2022-11-09 10:27:03,486 sagemaker-training-toolkit INFO     Exceptions not imported for SageMaker Debugger as it is not installed.

2022-11-09 10:27:16 Training - Training image download completed. Training in progress.2022-11-09 10:27:08.761711: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-11-09 10:27:12.818302: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-09 10:27:12.819693: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-09 10:27:12.819906: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-09 10:27:12.894084: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-09 10:27:12.895367: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-09 10:27:12.895631: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-09 10:27:12.895807: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-09 10:27:16.651703: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-09 10:27:16.651981: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-09 10:27:16.652183: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-09 10:27:16.653025: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 10752 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:1e.0, compute capability: 7.5
Workflow saved to /tmp/tmp5fpdavsc/workflow.
batch_size = 1024, epochs = 10
Epoch 1/10
684/684 - 14s - loss: 0.6932 - auc: 0.4998 - regularization_loss: 0.0000e+00 - val_loss: 0.6931 - val_auc: 0.5000 - val_regularization_loss: 0.0000e+00 - 14s/epoch - 20ms/step
Epoch 2/10
684/684 - 8s - loss: 0.6931 - auc: 0.5026 - regularization_loss: 0.0000e+00 - val_loss: 0.6932 - val_auc: 0.4990 - val_regularization_loss: 0.0000e+00 - 8s/epoch - 11ms/step
Epoch 3/10
684/684 - 7s - loss: 0.6922 - auc: 0.5222 - regularization_loss: 0.0000e+00 - val_loss: 0.6941 - val_auc: 0.4989 - val_regularization_loss: 0.0000e+00 - 7s/epoch - 11ms/step
Epoch 4/10
684/684 - 7s - loss: 0.6858 - auc: 0.5509 - regularization_loss: 0.0000e+00 - val_loss: 0.6991 - val_auc: 0.4994 - val_regularization_loss: 0.0000e+00 - 7s/epoch - 11ms/step
Epoch 5/10
684/684 - 7s - loss: 0.6790 - auc: 0.5660 - regularization_loss: 0.0000e+00 - val_loss: 0.7052 - val_auc: 0.4993 - val_regularization_loss: 0.0000e+00 - 7s/epoch - 11ms/step
Epoch 6/10
684/684 - 8s - loss: 0.6751 - auc: 0.5722 - regularization_loss: 0.0000e+00 - val_loss: 0.7096 - val_auc: 0.4994 - val_regularization_loss: 0.0000e+00 - 8s/epoch - 11ms/step
Epoch 7/10
684/684 - 7s - loss: 0.6722 - auc: 0.5755 - regularization_loss: 0.0000e+00 - val_loss: 0.7184 - val_auc: 0.4991 - val_regularization_loss: 0.0000e+00 - 7s/epoch - 11ms/step
Epoch 8/10
684/684 - 7s - loss: 0.6700 - auc: 0.5777 - regularization_loss: 0.0000e+00 - val_loss: 0.7289 - val_auc: 0.4990 - val_regularization_loss: 0.0000e+00 - 7s/epoch - 11ms/step
Epoch 9/10
684/684 - 8s - loss: 0.6687 - auc: 0.5792 - regularization_loss: 0.0000e+00 - val_loss: 0.7404 - val_auc: 0.4994 - val_regularization_loss: 0.0000e+00 - 8s/epoch - 11ms/step
Epoch 10/10
684/684 - 8s - loss: 0.6678 - auc: 0.5801 - regularization_loss: 0.0000e+00 - val_loss: 0.7393 - val_auc: 0.4988 - val_regularization_loss: 0.0000e+00 - 8s/epoch - 11ms/step
/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.12) or chardet (3.0.4) doesn't match a supported version!
  warnings.warn("urllib3 ({}) or chardet ({}) doesn't match a supported "
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.USER: 'user'>, <Tags.ID: 'id'>].
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
WARNING:absl:Found untraced functions such as train_compute_metrics, model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, output_layer_layer_call_fn, output_layer_layer_call_and_return_conditional_losses while saving (showing 5 of 97). These functions will not be directly callable after loading.
INFO:__main__:Model saved to /tmp/tmp5fpdavsc/dlrm.
Model saved to /tmp/tmp5fpdavsc/dlrm.
WARNING:absl:Found untraced functions such as train_compute_metrics, model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, output_layer_layer_call_fn, output_layer_layer_call_and_return_conditional_losses while saving (showing 5 of 97). These functions will not be directly callable after loading.
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.USER: 'user'>, <Tags.ID: 'id'>].
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:absl:Found untraced functions such as train_compute_metrics, model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, output_layer_layer_call_fn, output_layer_layer_call_and_return_conditional_losses while saving (showing 5 of 97). These functions will not be directly callable after loading.
Ensemble graph saved to /opt/ml/model.
INFO:__main__:Ensemble graph saved to /opt/ml/model.
2022-11-09 10:29:21,498 sagemaker-training-toolkit INFO     Reporting training SUCCESS

2022-11-09 10:29:41 Uploading - Uploading generated training model
2022-11-09 10:29:41 Completed - Training job completed
Training seconds: 589
Billable seconds: 589
print(estimator.model_data)
s3://sagemaker-us-east-1-843263297212/sagemaker-merlin-tensorflow-2022-11-09-10-18-29-376/output/model.tar.gz
from sagemaker.s3 import S3Downloader as s3down

s3down.download(estimator.model_data, "/tmp/ensemble/")
! cd /tmp/ensemble && tar xvzf model.tar.gz
1_predicttensorflow/
1_predicttensorflow/config.pbtxt
1_predicttensorflow/1/
1_predicttensorflow/1/model.savedmodel/
1_predicttensorflow/1/model.savedmodel/assets/
1_predicttensorflow/1/model.savedmodel/variables/
1_predicttensorflow/1/model.savedmodel/variables/variables.index
1_predicttensorflow/1/model.savedmodel/variables/variables.data-00000-of-00001
1_predicttensorflow/1/model.savedmodel/saved_model.pb
1_predicttensorflow/1/model.savedmodel/keras_metadata.pb
ensemble_model/
ensemble_model/config.pbtxt
ensemble_model/1/
0_transformworkflow/
0_transformworkflow/config.pbtxt
0_transformworkflow/1/
0_transformworkflow/1/model.py
0_transformworkflow/1/workflow/
0_transformworkflow/1/workflow/categories/
0_transformworkflow/1/workflow/categories/unique.user_profile.parquet
0_transformworkflow/1/workflow/categories/unique.user_age.parquet
0_transformworkflow/1/workflow/categories/unique.user_group.parquet
0_transformworkflow/1/workflow/categories/unique.user_intentions.parquet
0_transformworkflow/1/workflow/categories/unique.item_brand.parquet
0_transformworkflow/1/workflow/categories/unique.user_geography.parquet
0_transformworkflow/1/workflow/categories/unique.user_is_occupied.parquet
0_transformworkflow/1/workflow/categories/unique.user_id.parquet
0_transformworkflow/1/workflow/categories/unique.user_gender.parquet
0_transformworkflow/1/workflow/categories/unique.user_shops.parquet
0_transformworkflow/1/workflow/categories/unique.item_category.parquet
0_transformworkflow/1/workflow/categories/unique.user_brands.parquet
0_transformworkflow/1/workflow/categories/unique.user_consumption_2.parquet
0_transformworkflow/1/workflow/categories/unique.item_id.parquet
0_transformworkflow/1/workflow/categories/unique.item_shop.parquet
0_transformworkflow/1/workflow/categories/unique.user_categories.parquet
0_transformworkflow/1/workflow/workflow.pkl
0_transformworkflow/1/workflow/metadata.json

Part 3: Retrieving Recommendations from Triton Inference Server#

Although we used the SageMaker Python SDK to train our model, here we will use boto3 to launch our inference endpoint, as it offers more low-level control than the Python SDK.

The model artifact model.tar.gz uploaded to S3 by the SageMaker training job contains three directories: 0_transformworkflow for the NVTabular workflow, 1_predicttensorflow for the TensorFlow model, and ensemble_model for the ensemble graph that we can use in Triton.

/tmp/ensemble/
├── 0_transformworkflow
│   ├── 1
│   │   ├── model.py
│   │   └── workflow
│   │       ├── categories
│   │       │   ├── unique.item_brand.parquet
│   │       │   ├── unique.item_category.parquet
│   │       │   ├── unique.item_id.parquet
│   │       │   ├── unique.item_shop.parquet
│   │       │   ├── unique.user_age.parquet
│   │       │   ├── unique.user_brands.parquet
│   │       │   ├── unique.user_categories.parquet
│   │       │   ├── unique.user_consumption_2.parquet
│   │       │   ├── unique.user_gender.parquet
│   │       │   ├── unique.user_geography.parquet
│   │       │   ├── unique.user_group.parquet
│   │       │   ├── unique.user_id.parquet
│   │       │   ├── unique.user_intentions.parquet
│   │       │   ├── unique.user_is_occupied.parquet
│   │       │   ├── unique.user_profile.parquet
│   │       │   └── unique.user_shops.parquet
│   │       ├── metadata.json
│   │       └── workflow.pkl
│   └── config.pbtxt
├── 1_predicttensorflow
│   ├── 1
│   │   └── model.savedmodel
│   │       ├── assets
│   │       ├── keras_metadata.pb
│   │       ├── saved_model.pb
│   │       └── variables
│   │           ├── variables.data-00000-of-00001
│   │           └── variables.index
│   └── config.pbtxt
├── ensemble_model
│   ├── 1
│   └── config.pbtxt
└── model.tar.gz

We specify that we only want to use ensemble_model in Triton by passing the environment variable SAGEMAKER_TRITON_DEFAULT_MODEL_NAME.

import time

import boto3

sm_client = boto3.client(service_name="sagemaker")

container = {
    "Image": ecr_image,
    "ModelDataUrl": estimator.model_data,
    "Environment": {
        "SAGEMAKER_TRITON_TENSORFLOW_VERSION": "2",
        "SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "ensemble_model",
    },
}

model_name = "model-triton-merlin-ensemble-" + time.strftime(
    "%Y-%m-%d-%H-%M-%S", time.gmtime()
)

create_model_response = sm_client.create_model(
    ModelName=model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

model_arn = create_model_response["ModelArn"]

print(f"Model Arn: {model_arn}")
Model Arn: arn:aws:sagemaker:us-east-1:843263297212:model/model-triton-merlin-ensemble-2022-11-09-10-29-57

We again use a g4dn GPU instance, which is equipped with an NVIDIA T4 GPU, for launching the Triton Inference Server.

endpoint_instance_type = "ml.g4dn.xlarge"

endpoint_config_name = "endpoint-config-triton-merlin-ensemble-" + time.strftime(
    "%Y-%m-%d-%H-%M-%S", time.gmtime()
)

create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": endpoint_instance_type,
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

endpoint_config_arn = create_endpoint_config_response["EndpointConfigArn"]

print(f"Endpoint Config Arn: {endpoint_config_arn}")
Endpoint Config Arn: arn:aws:sagemaker:us-east-1:843263297212:endpoint-config/endpoint-config-triton-merlin-ensemble-2022-11-09-10-29-58
endpoint_name = "endpoint-triton-merlin-ensemble-" + time.strftime(
    "%Y-%m-%d-%H-%M-%S", time.gmtime()
)

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

endpoint_arn = create_endpoint_response["EndpointArn"]

print(f"Endpoint Arn: {endpoint_arn}")
Endpoint Arn: arn:aws:sagemaker:us-east-1:843263297212:endpoint/endpoint-triton-merlin-ensemble-2022-11-09-10-29-58
status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
print(f"Endpoint Creation Status: {status}")

while status == "Creating":
    time.sleep(60)
    rv = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = rv["EndpointStatus"]
    print(f"Endpoint Creation Status: {status}")

endpoint_arn = rv["EndpointArn"]

print(f"Endpoint Arn: {endpoint_arn}")
print(f"Endpoint Status: {status}")
Endpoint Creation Status: Creating
Endpoint Creation Status: Creating
Endpoint Creation Status: Creating
Endpoint Creation Status: Creating
Endpoint Creation Status: Creating
Endpoint Creation Status: Creating
Endpoint Creation Status: Creating
Endpoint Creation Status: Creating
Endpoint Creation Status: InService
Endpoint Arn: arn:aws:sagemaker:us-east-1:843263297212:endpoint/endpoint-triton-merlin-ensemble-2022-11-09-10-29-58
Endpoint Status: InService
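
Instead of polling describe_endpoint in a loop as above, you can also use the built-in boto3 waiter, which blocks until the endpoint is in service and raises an error if creation fails. The snippet below is an equivalent alternative, not the approach used above.

# Block until the endpoint reaches the InService state.
waiter = sm_client.get_waiter("endpoint_in_service")
waiter.wait(EndpointName=endpoint_name)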

Send a Request to Triton Inference Server to Transform a Raw Dataset#

Once we have an endpoint running, we can test it by sending requests. Here, we take a few rows from the raw validation set and use the NVTabular workflow we downloaded from S3 in the previous section to determine which input columns the request needs (after dropping the label columns); the endpoint itself runs the workflow and the model to produce predictions.

from merlin.schema.tags import Tags
from merlin.core.dispatch import get_lib
from nvtabular.workflow import Workflow

df_lib = get_lib()

workflow = Workflow.load("/tmp/ensemble/0_transformworkflow/1/workflow/")

label_columns = workflow.output_schema.select_by_tag(Tags.TARGET).column_names
workflow.remove_inputs(label_columns)

# read in data for request
batch = df_lib.read_parquet(
    os.path.join(DATA_FOLDER, "valid", "part.0.parquet"),
    columns=workflow.input_schema.column_names,
)[:10]
print(batch)
                     user_id  item_id  item_category  item_shop  item_brand  \
__null_dask_index__                                                           
700000                    12        2              3        194          67   
700001                    12       30             80       5621        1936   
700002                    18        5             12        776         267   
700003                    35        6             14        970         334   
700004                    51       11             28       1939         668   
700005                    22       83            226      15893        5474   
700006                    13       38            102       7172        2470   
700007                    10        7             17       1163         401   
700008                     4        4              9        582         201   
700009                     4       24             64       4458        1536   

                     user_shops  user_profile  user_group  user_gender  \
__null_dask_index__                                                      
700000                      636             1           1            1   
700001                      636             1           1            1   
700002                      983             1           1            1   
700003                     1965             2           1            1   
700004                     2890             3           1            1   
700005                     1214             2           1            1   
700006                      694             1           1            1   
700007                      521             1           1            1   
700008                      174             1           1            1   
700009                      174             1           1            1   

                     user_age  user_consumption_2  user_is_occupied  \
__null_dask_index__                                                   
700000                      1                   1                 1   
700001                      1                   1                 1   
700002                      1                   1                 1   
700003                      1                   1                 1   
700004                      1                   1                 1   
700005                      1                   1                 1   
700006                      1                   1                 1   
700007                      1                   1                 1   
700008                      1                   1                 1   
700009                      1                   1                 1   

                     user_geography  user_intentions  user_brands  \
__null_dask_index__                                                 
700000                            1              184          316   
700001                            1              184          316   
700002                            1              285          489   
700003                            1              569          977   
700004                            1              837         1436   
700005                            1              352          604   
700006                            1              201          345   
700007                            1              151          259   
700008                            1               51           87   
700009                            1               51           87   

                     user_categories  
__null_dask_index__                   
700000                            34  
700001                            34  
700002                            52  
700003                           103  
700004                           151  
700005                            64  
700006                            37  
700007                            28  
700008                            10  
700009                            10  

In the following code cell, we use a utility function provided by Merlin Systems to convert our dataframe into the payload format that Triton expects for inference requests.

from merlin.systems.triton import convert_df_to_triton_input
import tritonclient.http as httpclient

inputs = convert_df_to_triton_input(workflow.input_schema, batch, httpclient.InferInput)

request_body, header_length = httpclient.InferenceServerClient.generate_request_body(
    inputs
)

print(request_body)
b'{"inputs":[{"name":"user_id","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"item_id","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"item_category","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"item_shop","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"item_brand","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"user_shops","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"user_profile","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"user_group","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"user_gender","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"user_age","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"user_consumption_2","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"user_is_occupied","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"user_geography","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"user_intentions","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"user_brands","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}},{"name":"user_categories","shape":[10,1],"datatype":"INT32","parameters":{"binary_data_size":40}}],"parameters":{"binary_data_output":true}}\x0c\x00\x00\x00\x0c\x00\x00\x00\x12\x00\x00\x00#\x00\x00\x003\x00\x00\x00\x16\x00\x00\x00\r\x00\x00\x00\n\x00\x00\x00\x04\x00\x00\x00\x04\x00\x00\x00\x02\x00\x00\x00\x1e\x00\x00\x00\x05\x00\x00\x00\x06\x00\x00\x00\x0b\x00\x00\x00S\x00\x00\x00&\x00\x00\x00\x07\x00\x00\x00\x04\x00\x00\x00\x18\x00\x00\x00\x03\x00\x00\x00P\x00\x00\x00\x0c\x00\x00\x00\x0e\x00\x00\x00\x1c\x00\x00\x00\xe2\x00\x00\x00f\x00\x00\x00\x11\x00\x00\x00\t\x00\x00\x00@\x00\x00\x00\xc2\x00\x00\x00\xf5\x15\x00\x00\x08\x03\x00\x00\xca\x03\x00\x00\x93\x07\x00\x00\x15>\x00\x00\x04\x1c\x00\x00\x8b\x04\x00\x00F\x02\x00\x00j\x11\x00\x00C\x00\x00\x00\x90\x07\x00\x00\x0b\x01\x00\x00N\x01\x00\x00\x9c\x02\x00\x00b\x15\x00\x00\xa6\t\x00\x00\x91\x01\x00\x00\xc9\x00\x00\x00\x00\x06\x00\x00|\x02\x00\x00|\x02\x00\x00\xd7\x03\x00\x00\xad\x07\x00\x00J\x0b\x00\x00\xbe\x04\x00\x00\xb6\x02\x00\x00\t\x02\x00\x00\xae\x00\x00\x00\xae\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00
\x00\x01\x00\x00\x00\x01\x00\x00\x00\xb8\x00\x00\x00\xb8\x00\x00\x00\x1d\x01\x00\x009\x02\x00\x00E\x03\x00\x00`\x01\x00\x00\xc9\x00\x00\x00\x97\x00\x00\x003\x00\x00\x003\x00\x00\x00<\x01\x00\x00<\x01\x00\x00\xe9\x01\x00\x00\xd1\x03\x00\x00\x9c\x05\x00\x00\\\x02\x00\x00Y\x01\x00\x00\x03\x01\x00\x00W\x00\x00\x00W\x00\x00\x00"\x00\x00\x00"\x00\x00\x004\x00\x00\x00g\x00\x00\x00\x97\x00\x00\x00@\x00\x00\x00%\x00\x00\x00\x1c\x00\x00\x00\n\x00\x00\x00\n\x00\x00\x00'

Triton uses the KServe community standard inference protocols. Here, we use the binary+json format for the inference request for optimal performance.

For Triton to correctly parse the binary payload, we have to specify the length of the JSON request metadata via the json-header-size attribute of the content type.

runtime_sm_client = boto3.client("sagemaker-runtime")

response = runtime_sm_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType=f"application/vnd.sagemaker-triton.binary+json;json-header-size={header_length}",
    Body=request_body,
)

# Parse json header size length from the response
header_length_prefix = "application/vnd.sagemaker-triton.binary+json;json-header-size="
header_length_str = response["ContentType"][len(header_length_prefix):]

# Read response body
result = httpclient.InferenceServerClient.parse_response_body(
    response["Body"].read(), header_length=int(header_length_str)
)
output_data = result.as_numpy("click/binary_classification_task")
print("predicted sigmoid result:\n", output_data)
predicted sigmoid result:
 [[0.48595208]
 [0.4647554 ]
 [0.50048226]
 [0.53553176]
 [0.5209902 ]
 [0.54944164]
 [0.5032344 ]
 [0.475241  ]
 [0.5077254 ]
 [0.5009623 ]]

Terminate endpoint and clean up artifacts#

Don’t forget to clean up artifacts and terminate the endpoint, or the endpoint will continue to incur costs.

sm_client.delete_model(ModelName=model_name)
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_endpoint(EndpointName=endpoint_name)
{'ResponseMetadata': {'RequestId': '6ad24616-5c7c-4525-a63c-62d1b06ee8ad',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '6ad24616-5c7c-4525-a63c-62d1b06ee8ad',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '0',
   'date': 'Wed, 09 Nov 2022 10:38:12 GMT'},
  'RetryAttempts': 0}}