# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

Running on multiple GPUs or on CPU

This notebook was created using the latest stable merlin-tensorflow container.

Overview

In this notebook we will look at running NVTabular operations on multiple GPUs or just on the CPU.

NVTabular supports switching easily between multi-GPU, single-GPU, and CPU-only execution by changing only a parameter or two. A common use case is to develop locally on the CPU and then deploy the NVTabular workflow in the cloud on a multi-GPU cluster.

The default behavior is to use a single GPU if one is available and to fall back to the CPU otherwise. However, moving to multiple GPUs can offer speedups of 100-1000x over CPU-only workflows (you can read more about this in our blog post). Still, the key word here is options: there will be workloads you might want to run on multiple GPUs, on a single GPU, or maybe even on your laptop with only a couple of CPU cores. NVTabular facilitates all of these scenarios.

Learning objectives

  • Setting up a dask cluster and executing transformations on multiple GPUs

  • Running CPU only workflows

Downloading the dataset

import os
from merlin.datasets.entertainment import get_movielens

input_path = os.environ.get("INPUT_DATA_DIR", os.path.expanduser("~/merlin-framework/movielens/"))
train, valid = get_movielens(variant="ml-1m", path=input_path); #noqa
downloading ml-1m.zip: 5.93MB [00:02, 1.98MB/s]
unzipping files: 100%|██████████| 5/5 [00:00<00:00, 58.48files/s]
INFO:merlin.datasets.entertainment.movielens.dataset:starting ETL..
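
Before moving on, it can be useful to glance at the data we just downloaded. The returned train and valid objects are merlin.io.Dataset instances, and to_ddf() exposes the dask DataFrame backing them (this quick peek is optional and not part of the original flow):

# Optional: inspect a few rows of the training split via the backing
# dask (or dask_cudf) DataFrame.
train.to_ddf().head()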

Running on multiple GPUs

Multi-GPU and multi-node scaling

NVTabular is built on top of RAPIDS cuDF, dask_cudf, and Dask.

Dask is a task-based library for parallel scheduling and execution. Although it is certainly possible to use its task-scheduling machinery directly to implement customized parallel workflows (we do exactly that in NVTabular), most users interact with Dask only through a Dask Collection API. The most popular collection APIs include:

  • Dask DataFrame: Dask-based version of the Pandas DataFrame/Series API. Note that dask_cudf is just a wrapper around this collection module (dask.dataframe).

  • Dask Array: Dask-based version of the NumPy array API

  • Dask Bag: Similar to a Dask-based version of PyToolz or a Pythonic version of PySpark RDD

For example, Dask DataFrame provides a convenient API for decomposing large Pandas (or cuDF) DataFrame/Series objects into a collection of DataFrame partitions.

[Figure: a large DataFrame decomposed into a collection of smaller Dask DataFrame partitions (dask-dataframe.svg)]

We use dask_cudf to process large datasets as a collection of cuDF DataFrames instead of Pandas. cuDF is a GPU DataFrame library for loading, joining, aggregating, filtering, and otherwise manipulating data.
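
To make the idea of partitions concrete, here is a minimal sketch (illustrative only, and pandas-backed; with dask_cudf the partitions would be cuDF DataFrames instead):

import pandas as pd
import dask.dataframe as dd

# Decompose a small pandas DataFrame into four Dask DataFrame partitions.
pdf = pd.DataFrame({"userId": list(range(8)), "rating": [3, 4, 5, 2, 1, 4, 5, 3]})
ddf = dd.from_pandas(pdf, npartitions=4)
print(ddf.npartitions)                 # 4
print(ddf["rating"].mean().compute())  # operations are lazy until .compute()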

Dask makes it easy to schedule tasks for multiple workers, whether multi-GPU or multi-node. We just need to initialize a Dask cluster (LocalCUDACluster) and NVTabular will use the cluster to execute the workflow.
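
In its simplest form the pattern takes just two calls, as in the hedged sketch below (left commented out so it does not conflict with the fully parameterized cluster we build in the next section):

# Minimal sketch: start a single-machine CUDA cluster with default settings
# (one worker per visible GPU) and attach a client so Dask tasks run on it.
# from dask_cuda import LocalCUDACluster
# from dask.distributed import Client
# cluster = LocalCUDACluster()
# client = Client(cluster)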

Starting a dask cluster

import numba
import warnings
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
import nvtabular as nvt
from merlin.core.compat import pynvml_mem_size, device_mem_size

dask_workdir = "test_dask/workdir"

The following code will automatically generate the parameters for the local CUDA cluster. It infers the number of available GPUs, calculates memory limits that work across a wide range of scenarios, and so on.

# Dask dashboard
dashboard_port = "8787"

# Deploy a Single-Machine Multi-GPU Cluster
protocol = "tcp"  # "tcp" or "ucx"

if numba.cuda.is_available():
    NUM_GPUS = list(range(len(numba.cuda.gpus)))
else:
    NUM_GPUS = []
try:
    visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
except KeyError:
    visible_devices = ",".join([str(n) for n in NUM_GPUS])  # Delect devices to place workers
device_limit_frac = 0.7  # Spill GPU-Worker memory to host at this limit.
device_pool_frac = 0.8
part_mem_frac = 0.15

# Use total device size to calculate args.device_limit_frac
device_size = device_mem_size(kind="total")
device_limit = int(device_limit_frac * device_size)
device_pool_size = int(device_pool_frac * device_size)
part_size = int(part_mem_frac * device_size)

# Check if any device memory is already occupied
if NUM_GPUS:
    devices = visible_devices.split(",")
else:
    devices = []
for dev in devices:
    fmem = pynvml_mem_size(kind="free", index=int(dev))
    used = (device_size - fmem) / 1e9
    if used > 1.0:
        warnings.warn(f"BEWARE - {used} GB is already occupied on device {int(dev)}!")

cluster = None  # (Optional) Specify existing scheduler port
/tmp/ipykernel_11/2427665162.py:26: UserWarning: BEWARE - 1.25140992 GB is already occupied on device 0!
  warnings.warn(f"BEWARE - {used} GB is already occupied on device {int(dev)}!")

We can now initialize the CUDA cluster.

if cluster is None and NUM_GPUS:
    cluster = LocalCUDACluster(
        protocol=protocol,
        n_workers=len(visible_devices.split(",")),
        CUDA_VISIBLE_DEVICES=visible_devices,
        device_memory_limit=device_limit,
        local_directory=dask_workdir,
        dashboard_address=":" + dashboard_port,
        rmm_pool_size=(device_pool_size // 256) * 256
    )
2022-09-13 07:23:07,266 - distributed.diskutils - INFO - Found stale lock file and directory '/workspace/examples/test_dask/workdir/dask-worker-space/worker-kymfv__r', purging
2022-09-13 07:23:07,266 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize

We can now attach a Client to the cluster.

Before we do so, take a look at the options available to us in the Client(...) constructor. Instead of attaching to a cluster we initialized locally, another option is to connect to a remote CUDA cluster. Such a cluster might not only include multiple GPUs, but can also span multiple nodes. This enables scaling to arbitrarily large data.
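
For instance, connecting to an existing scheduler might look like the sketch below (the address is a hypothetical placeholder; substitute your scheduler's real address):

# Hypothetical: connect to an already-running remote Dask scheduler instead
# of a local cluster. The address below is a placeholder.
# client = Client("tcp://scheduler-host:8786")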

if cluster:
    client = Client(cluster)
else:
    client = Client(processes=False)  # No GPUs available: fall back to a local, in-process scheduler
client.cluster
2022-08-26 01:26:46,865 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize

And that’s it! All we have to do is define the cluster, and NVTabular will automatically run the workload on available hardware!

Let’s put this to the test.

Defining and running a Workflow on multiple GPUs

categories = ['userId', 'movieId', 'zipcode'] >> nvt.ops.Categorify(freq_threshold=10)  # encode categoricals; values occurring fewer than 10 times share one category
age = ['age'] >> nvt.ops.Bucketize([0, 10, 21, 45])  # assign ages to buckets split at the given boundaries

example_workflow = nvt.Workflow(categories + age)
example_workflow.fit_transform(train).to_parquet('train')
example_workflow.transform(valid).to_parquet('valid')
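
If you would like to sanity-check the output, one optional way (not part of the original flow) is to read the transformed Parquet data back and peek at a few rows:

# Optional sanity check: load the transformed train set back as a Dataset
# and inspect it through the backing dask DataFrame.
nvt.Dataset('train', engine='parquet').to_ddf().head()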

We can see below that data has been loaded onto all of our GPUs, and all of them were utilized in running the calculations.

!nvidia-smi
Fri Aug 26 01:26:54 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:06:00.0 Off |                    0 |
| N/A   35C    P0    71W / 160W |  14349MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000000:07:00.0 Off |                    0 |
| N/A   33C    P0    49W / 160W |  13568MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  On   | 00000000:0A:00.0 Off |                    0 |
| N/A   33C    P0    48W / 160W |  13568MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  On   | 00000000:0B:00.0 Off |                    0 |
| N/A   34C    P0    49W / 160W |  13568MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

Running on CPU

How do we run the workflow on the CPU only? To do so, we create our Datasets and specify that they should be backed by the CPU. Neither GPU memory nor GPU compute will be used.

train = nvt.Dataset('train', engine='parquet', cpu=True)
valid = nvt.Dataset('valid', engine='parquet', cpu=True)
/usr/local/lib/python3.8/dist-packages/merlin/io/dataset.py:251: UserWarning: Initializing an NVTabular Dataset in CPU mode.This is an experimental feature with extremely limited support!
  warnings.warn(

We can now execute the workflow on the CPU.

example_workflow = nvt.Workflow(categories + age)
example_workflow.fit_transform(train)
example_workflow.transform(valid)
<merlin.io.dataset.Dataset at 0x7f2930b63e80>

In summary, if you would like to create a Dataset directly on the CPU, you can do so by passing cpu=True to the constructor, as follows.

nvt.Dataset(..., cpu=True)
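
Finally, if you started the local cluster earlier in this notebook, you may want to release its resources when you are done; an optional cleanup sketch:

# Optional cleanup: close the client, and the local cluster if one was created.
client.close()
if cluster is not None:
    cluster.close()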

Summary

NVTabular works seamlessly across a variety of settings. NVTabular operators can run on the CPU and scale to multi-GPU or multi-node clusters with a minimal amount of configuration.