# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png

Multi-GPU with MovieLens: ETL and Training

Overview

NVIDIA Merlin is an open source framework to accelerate and scale end-to-end recommender system pipelines on GPU. In this notebook, we use NVTabular, Merlin’s ETL component, to scale feature engineering and pre-processing to multiple GPUs and then perform data-parallel distributed training of a neural network on multiple GPUs with PyTorch, Horovod, and NCCL.

The pre-requisites for this notebook are to be familiar with NVTabular and its API:

In this notebook, we will focus only on the new information related to multi-GPU training, so please check out the other notebooks first (if you haven’t already.)

Learning objectives

In this notebook, we learn how to scale ETL and deep learning training to multiple GPUs

  • Learn to use larger than GPU/host memory datasets for ETL and training

  • Use multi-GPU or multi node for ETL with NVTabular

  • Use NVTabular dataloader to accelerate PyTorch pipelines

  • Scale PyTorch training with Horovod

Dataset

In this notebook, we use the MovieLens25M dataset. It is popular for recommender systems and is used in academic publications. The dataset contains 25M movie ratings for 62,000 movies given by 162,000 users. Many projects use only the user/item/rating information of MovieLens, but the original dataset provides metadata for the movies, as well.

Note: We are using the MovieLens 25M dataset in this example for simplicity, although the dataset is not large enough to require multi-GPU training. However, the functionality demonstrated in this notebook can be easily extended to scale recommender pipelines for larger datasets in the same way.

Tools

Download and Convert

First, we will download and convert the dataset to Parquet. This section is based on 01-Download-Convert.ipynb.

Download

# External dependencies
import os
import pathlib

import cudf  # cuDF is an implementation of Pandas-like Dataframe on GPU

from nvtabular.utils import download_file

# Root directory for all data in this notebook; override it with the
# INPUT_DATA_DIR environment variable. "~" is expanded to the user's home.
INPUT_DATA_DIR = os.environ.get("INPUT_DATA_DIR", "~/nvt-examples/multigpu-movielens/data/")
BASE_DIR = pathlib.Path(INPUT_DATA_DIR).expanduser()

# Fetch the MovieLens 25M archive into BASE_DIR. redownload=False is passed
# through to download_file (presumably so an existing archive is not fetched
# again -- confirm against nvtabular.utils.download_file).
zip_path = BASE_DIR / "ml-25m.zip"
download_file(
    "http://files.grouplens.org/datasets/movielens/ml-25m.zip",
    zip_path,
    redownload=False,
)
downloading ml-25m.zip: 262MB [01:11, 3.66MB/s]                                                                                             
unzipping files: 100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:09<00:00,  1.16s/files]

Convert

# Read the raw movie metadata into a GPU dataframe.
movies = cudf.read_csv(BASE_DIR / "ml-25m" / "movies.csv")
# Genres are stored as one "|"-separated string per movie; split them into a list column.
movies["genres"] = movies["genres"].str.split("|")
# The title is not used as a feature, so drop it before writing.
movies = movies.drop("title", axis=1)
movies.to_parquet(BASE_DIR / "ml-25m" / "movies_converted.parquet")

Split into train and validation datasets

ratings = cudf.read_csv(pathlib.Path(BASE_DIR, "ml-25m", "ratings.csv"))
ratings = ratings.drop("timestamp", axis=1)

# Shuffle the dataset: sampling every row without replacement is a random permutation.
ratings = ratings.sample(len(ratings), replace=False)

# Hold out the last 20% of the shuffled ratings for validation.
# The split uses an explicit boundary index rather than negative slicing:
# ratings[:-num_valid] would select *no* rows when num_valid is 0, because
# -0 == 0, leaving the training set empty for very small inputs.
num_valid = int(len(ratings) * 0.2)
num_train = len(ratings) - num_valid
train = ratings[:num_train]
valid = ratings[num_train:]

train.to_parquet(pathlib.Path(BASE_DIR, "train.parquet"))
valid.to_parquet(pathlib.Path(BASE_DIR, "valid.parquet"))

ETL with NVTabular

We finished downloading and converting the dataset. We will preprocess and engineer features with NVTabular on multiple GPUs. You can read more

Deploy a Distributed-Dask Cluster

This section is based on scaling-criteo/02-ETL-with-NVTabular.ipynb and multi-gpu-toy-example/multi-gpu_dask.ipynb

# Standard Libraries
import shutil

# External Dependencies
import cudf
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
import rmm

# NVTabular
import nvtabular as nvt
from nvtabular.io import Shuffle
from nvtabular.utils import device_mem_size
# Directories used by the Dask cluster and the NVTabular workflow
input_path = BASE_DIR / "converted" / "movielens"
dask_workdir = BASE_DIR / "test_dask" / "workdir"
output_path = BASE_DIR / "test_dask" / "output"
stats_path = BASE_DIR / "test_dask" / "stats"

# Start from empty Dask work, stats and output directories: remove anything
# left over from a previous run, then recreate the directory (and any
# missing parents).
for clean_dir in (dask_workdir, stats_path, output_path):
    if clean_dir.is_dir():
        shutil.rmtree(clean_dir)
    clean_dir.mkdir(parents=True)

# Total device memory; used below to derive the spill threshold
capacity = device_mem_size(kind="total")

# Deploy a Single-Machine Multi-GPU Cluster
protocol = "tcp"  # "tcp" or "ucx"
visible_devices = "0,1"  # Select devices to place workers
# Fraction of device memory at which GPU-worker memory is spilled to host.
# Reduce it if spilling fails to prevent device memory errors.
device_spill_frac = 0.5
cluster = None  # (Optional) Specify existing scheduler port
if cluster is None:
    cluster = LocalCUDACluster(
        protocol=protocol,
        CUDA_VISIBLE_DEVICES=visible_devices,
        local_directory=dask_workdir,
        device_memory_limit=capacity * device_spill_frac,
    )

# Create the distributed client (the bare name on the last line displays it in the notebook)
client = Client(cluster)
client
/opt/conda/lib/python3.8/site-packages/distributed/node.py:160: UserWarning: Port 8787 is already in use.
Perhaps you already have a cluster running?
Hosting the HTTP server on port 33621 instead
  warnings.warn(
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.preloading - INFO - Import preload module: dask_cuda.initialize

Client

Client-1186085c-5cc7-11ec-8176-54ab3a8c9e18

Connection method: Cluster object Cluster type: LocalCUDACluster
Dashboard: http://127.0.0.1:33621/status

Cluster Info

# Initialize RMM pool on ALL workers
def _rmm_pool():
    """Re-initialize RMM with the pool allocator enabled and the default pool size."""
    # initial_pool_size=None requests the default size
    rmm.reinitialize(pool_allocator=True, initial_pool_size=None)


# Run the initializer through the Dask client; the result maps each worker
# address to the function's return value (None here).
client.run(_rmm_pool)
{'tcp://127.0.0.1:34485': None, 'tcp://127.0.0.1:36269': None}

Defining our Preprocessing Pipeline

This subsection is based on getting-started-movielens/02-ETL-with-NVTabular.ipynb.

# Movie metadata (movieId + genres) produced by the "Convert" step.
movies = cudf.read_parquet(pathlib.Path(BASE_DIR, "ml-25m", "movies_converted.parquet"))

# Feature graph: join the movie metadata onto userId/movieId by movieId,
# then Categorify the resulting columns.
joined = ["userId", "movieId"] >> nvt.ops.JoinExternal(movies, on=["movieId"])
cat_features = joined >> nvt.ops.Categorify()
# Binarize the target: ratings above 3 become 1, all others 0 (stored as int8).
ratings = nvt.ColumnSelector(["rating"]) >> nvt.ops.LambdaOp(lambda col: (col > 3).astype("int8"))
output = cat_features + ratings
workflow = nvt.Workflow(output)

# Remove the output of any previous run. shutil.rmtree is used instead of
# "!rm -rf $BASE_DIR/..." so that a BASE_DIR containing spaces or shell
# metacharacters is handled as a single path, and so the cell does not depend
# on IPython shell magics. ignore_errors=True keeps the "rm -f" behaviour of
# silently accepting a directory that does not exist.
for split_name in ("train", "valid"):
    shutil.rmtree(pathlib.Path(BASE_DIR, split_name), ignore_errors=True)

train_iter = nvt.Dataset([str(pathlib.Path(BASE_DIR, "train.parquet"))], part_size="100MB")
valid_iter = nvt.Dataset([str(pathlib.Path(BASE_DIR, "valid.parquet"))], part_size="100MB")

# Fit on the training data only, then persist the fitted workflow; the
# training script reloads it from BASE_DIR/workflow.
workflow.fit(train_iter)
workflow.save(str(pathlib.Path(BASE_DIR, "workflow")))

shuffle = Shuffle.PER_WORKER  # Shuffle algorithm
out_files_per_proc = 4  # Number of output files per worker

# Apply the fitted workflow to both splits and write the results as Parquet.
workflow.transform(train_iter).to_parquet(
    output_path=pathlib.Path(BASE_DIR, "train"),
    shuffle=shuffle,
    out_files_per_proc=out_files_per_proc,
)
workflow.transform(valid_iter).to_parquet(
    output_path=pathlib.Path(BASE_DIR, "valid"),
    shuffle=shuffle,
    out_files_per_proc=out_files_per_proc,
)

# ETL is finished: shut down the Dask client and close the cluster before training.
client.shutdown()
cluster.close()
/opt/conda/lib/python3.8/site-packages/distributed/worker.py:3801: UserWarning: Large object of size 3.80 MiB detected in task graph: 
  ([<Node JoinExternal>], 'read-parquet-b1cb9ccb4ca4 ...  1, 2], None)})
Consider scattering large objects ahead of time
with client.scatter to reduce scheduler burden and 
keep data on workers

    future = client.submit(func, big_data)    # bad

    big_future = client.scatter(big_data)     # good
    future = client.submit(func, big_future)  # good
  warnings.warn(
/nvtabular/nvtabular/io/dataset.py:868: UserWarning: Only created 2 files did not have enough
partitions to create 8 files.
  warnings.warn(

Training with PyTorch on multiple GPUs

In this section, we will train a PyTorch model with multi-GPU support. In the NVTabular v0.5 release, we added multi-GPU support for NVTabular dataloaders. We will modify the getting-started-movielens/03-Training-with-PyTorch.ipynb to use multiple GPUs. Please review that notebook, if you have questions about the general functionality of the NVTabular dataloaders or the neural network architecture.

NVTabular dataloader for PyTorch

We’ve identified that the dataloader is one bottleneck in deep learning recommender systems when training pipelines with PyTorch. The normal PyTorch dataloaders cannot prepare the next training batches fast enough and therefore, the GPU is not fully utilized.

We developed a highly customized tabular dataloader for accelerating existing pipelines in PyTorch. In our experiments, we see a speed-up by 9x of the same training workflow with NVTabular dataloader. NVTabular dataloader’s features are:

  • removing bottleneck of item-by-item dataloading

  • enabling larger than memory dataset by streaming from disk

  • reading data directly into GPU memory and removing CPU-GPU communication

  • preparing batch asynchronously in GPU to avoid CPU-GPU communication

  • supporting commonly used .parquet format

  • easy integration into existing PyTorch pipelines by using similar API

  • supporting multi-GPU training with Horovod

You can find more information on the dataloaders in our blogpost.

Using Horovod with PyTorch and NVTabular

The training script below is based on getting-started-movielens/03-Training-with-PyTorch.ipynb, with a few important changes:

  • We provide several additional parameters to the TorchAsyncItr class, including the total number of workers hvd.size(), the current worker’s id number hvd.rank(), and a function for generating random seeds seed_fn().

    train_dataset = TorchAsyncItr(
        ...
        global_size=hvd.size(),
        global_rank=hvd.rank(),
        seed_fn=seed_fn,
    )
  • The seed function uses Horovod to collectively generate a random seed that’s shared by all workers so that they can each shuffle the dataset in a consistent way and select partitions to work on without overlap. The seed function is called by the dataloader during the shuffling process at the beginning of each epoch:

    def seed_fn():
        # Cap each worker's fragment so that the sum over hvd.size() workers
        # stays within the torch.int maximum
        max_rand = torch.iinfo(torch.int).max // hvd.size()

        # Generate a seed fragment
        seed_fragment = cupy.random.randint(0, max_rand)

        # Aggregate seed fragments from all Horovod workers
        seed_tensor = torch.tensor(seed_fragment)
        reduced_seed = hvd.allreduce(seed_tensor, name="shuffle_seed", op=hvd.mpi_ops.Sum)

        # Fold the summed value back into [0, max_rand)
        return reduced_seed % max_rand
  • We wrap the PyTorch optimizer with Horovod’s DistributedOptimizer class and scale the learning rate by the number of workers:

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01 * lr_scaler)
    optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
  • We broadcast the model and optimizer parameters to all workers with Horovod:

    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

The rest of the script is the same as the MovieLens example in getting-started-movielens/03-Training-with-PyTorch.ipynb. In order to run it with Horovod, we first need to write it to a file.

%%writefile './torch_trainer.py'

import argparse
import glob
import os
from time import time

import cupy
import torch

import nvtabular as nvt
from nvtabular.framework_utils.torch.models import Model
from nvtabular.framework_utils.torch.utils import process_epoch
from nvtabular.loader.torch import DLDataLoader, TorchAsyncItr

# Horovod must be the last import to avoid conflicts
import horovod.torch as hvd  # noqa: E402, isort:skip


def _csv_list(value):
    """Parse a comma-separated command-line value into a list of column names.

    Surrounding whitespace is stripped and empty entries are dropped, so an
    empty string yields an empty list (which the ``args.x or default``
    fallbacks below treat the same as "option not given").
    """
    return [name.strip() for name in value.split(",") if name.strip()]


# Command-line interface of the training script.
# Numeric options are parsed as int and column options as lists: without an
# explicit type argparse returns plain strings, which breaks
# range(args.epochs) and list operations on the column names.
parser = argparse.ArgumentParser(description="Train a multi-gpu model with Torch and Horovod")
parser.add_argument("--dir_in", default=None, help="Input directory")
parser.add_argument("--batch_size", type=int, default=None, help="Batch size")
parser.add_argument(
    "--cats", type=_csv_list, default=None, help="Categorical columns (comma-separated)"
)
parser.add_argument(
    "--cats_mh",
    type=_csv_list,
    default=None,
    help="Categorical multihot columns (comma-separated)",
)
parser.add_argument(
    "--conts", type=_csv_list, default=None, help="Continuous columns (comma-separated)"
)
parser.add_argument(
    "--labels", type=_csv_list, default=None, help="Label columns (comma-separated)"
)
parser.add_argument("--epochs", type=int, default=1, help="Training epochs")
args = parser.parse_args()

# Initialize Horovod; the rank queries below follow it.
hvd.init()

# Use the GPU whose index equals this process's local rank.
gpu_to_use = hvd.local_rank()
if torch.cuda.is_available():
    torch.cuda.set_device(gpu_to_use)


# Run configuration: command-line values take precedence, otherwise the
# MovieLens defaults are used.
BASE_DIR = os.path.expanduser(args.dir_in if args.dir_in else "./data/")
BATCH_SIZE = int(args.batch_size if args.batch_size else 16384)  # Batch Size
CATEGORICAL_COLUMNS = args.cats if args.cats else ["movieId", "userId"]  # Single-hot
CATEGORICAL_MH_COLUMNS = args.cats_mh if args.cats_mh else ["genres"]  # Multi-hot
NUMERIC_COLUMNS = args.conts if args.conts else []

# Output from ETL-with-NVTabular, in sorted order
TRAIN_PATHS = glob.glob(os.path.join(BASE_DIR, "train", "*.parquet"))
TRAIN_PATHS.sort()

# Load the workflow saved under BASE_DIR/workflow and read the embedding
# table sizes from it.
proc = nvt.Workflow.load(os.path.join(BASE_DIR, "workflow/"))

EMBEDDING_TABLE_SHAPES = nvt.ops.get_embedding_sizes(proc)


# TensorItrDataset returns a single batch of x_cat, x_cont, y.
def collate_fn(x):
    """Return the batch unchanged.

    The DLDataLoader below is created with ``batch_size=None`` and this
    function as ``collate_fn``, so whatever the dataset yields is passed
    through as-is.
    """
    return x


# Seed with system randomness (or a static seed)
cupy.random.seed(None)


def seed_fn():
    """
    Generate consistent dataloader shuffle seeds across workers

    Reseeds each worker's dataloader each epoch to get a fresh shuffle
    that's consistent across workers: every worker draws its own random
    fragment and the fragments are summed over all Horovod workers.
    """
    # Cap each fragment so that the sum over hvd.size() workers stays within
    # the torch.int maximum.
    fragment_limit = torch.iinfo(torch.int).max // hvd.size()

    # Draw this worker's fragment and sum the fragments of all Horovod workers
    local_fragment = torch.tensor(cupy.random.randint(0, fragment_limit))
    shared_seed = hvd.allreduce(local_fragment, name="shuffle_seed", op=hvd.mpi_ops.Sum)

    return shared_seed % fragment_limit


# Single-hot and multi-hot categorical columns are handed to the loader together.
all_categoricals = CATEGORICAL_COLUMNS + CATEGORICAL_MH_COLUMNS

# NVTabular iterator over the training files. global_size / global_rank
# describe this worker's position among the Horovod workers, and seed_fn
# supplies the shuffle seed shared by all of them.
train_dataset = TorchAsyncItr(
    nvt.Dataset(TRAIN_PATHS),
    batch_size=BATCH_SIZE,
    cats=all_categoricals,
    conts=NUMERIC_COLUMNS,
    labels=["rating"],
    device=gpu_to_use,
    global_size=hvd.size(),
    global_rank=hvd.rank(),
    shuffle=True,
    seed_fn=seed_fn,
)

# batch_size=None together with the pass-through collate_fn leaves batching
# to train_dataset.
train_loader = DLDataLoader(
    train_dataset,
    batch_size=None,
    collate_fn=collate_fn,
    pin_memory=False,
    num_workers=0,
)


# Build the (single-hot, multi-hot) pair of embedding-shape dicts passed to Model.
#
# get_embedding_sizes may return either one mapping covering every column or
# a (single_hot, multi_hot) tuple of mappings; both layouts are supported.
# Every configured column is looked up, rather than hard-coding the first two
# single-hot columns and the first multi-hot column, so that --cats/--cats_mh
# lists of any length work instead of raising IndexError or silently dropping
# columns. For the default columns the result is identical.
if isinstance(EMBEDDING_TABLE_SHAPES, tuple):
    single_hot_shapes = EMBEDDING_TABLE_SHAPES[0]
    multi_hot_shapes = EMBEDDING_TABLE_SHAPES[1]
else:
    single_hot_shapes = EMBEDDING_TABLE_SHAPES
    multi_hot_shapes = EMBEDDING_TABLE_SHAPES

EMBEDDING_TABLE_SHAPES_TUPLE = (
    {col: single_hot_shapes[col] for col in CATEGORICAL_COLUMNS},
    {col: multi_hot_shapes[col] for col in CATEGORICAL_MH_COLUMNS},
)

# Build the network and move it to the current CUDA device.
# NOTE(review): num_continuous is hard-coded to 0, which matches the default
# (empty) NUMERIC_COLUMNS; it presumably needs to change if continuous
# columns are supplied via --conts -- confirm against the Model API.
model = Model(
    embedding_table_shapes=EMBEDDING_TABLE_SHAPES_TUPLE,
    num_continuous=0,
    emb_dropout=0.0,
    layer_hidden_dims=[128, 128, 128],
    layer_dropout_rates=[0.0, 0.0, 0.0],
).cuda()

# Scale the learning rate by the number of Horovod workers.
lr_scaler = hvd.size()

optimizer = torch.optim.Adam(model.parameters(), lr=0.01 * lr_scaler)

# Broadcast the model parameters and optimizer state from rank 0 to all workers.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

# Wrap the PyTorch optimizer with Horovod's DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())

# argparse yields a string when --epochs is given without a declared type,
# and range() rejects strings. int() accepts both the integer default and a
# command-line string, so the loop works either way.
num_epochs = int(args.epochs)

for epoch in range(num_epochs):
    start = time()
    print(f"Training epoch {epoch}")
    train_loss, y_pred, y = process_epoch(train_loader, model, train=True, optimizer=optimizer)
    # hvd.join is called after each epoch and again after the broadcast
    # (presumably to wait until every worker has finished its share of the
    # data -- see the Horovod documentation for the exact semantics).
    hvd.join(gpu_to_use)
    # Re-broadcast rank 0's parameters to all workers.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    print(f"Epoch {epoch:02d}. Train loss: {train_loss:.4f}.")
    hvd.join(gpu_to_use)
    t_final = time() - start
    total_rows = train_dataset.num_rows_processed
    # Timing, rows processed, and rows per second for this epoch.
    print(
        f"run_time: {t_final} - rows: {total_rows} - "
        f"epochs: {epoch} - dl_thru: {total_rows / t_final}"
    )


hvd.join(gpu_to_use)
# Only local rank 0 prints the completion message.
if hvd.local_rank() == 0:
    print("Training complete")
Overwriting ./torch_trainer.py
!horovodrun -np 2 python torch_trainer.py --dir_in $BASE_DIR --batch_size 16384
[1,0]<stdout>:Training epoch 0
[1,1]<stdout>:Training epoch 0
distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client
_GatheringFuture exception was never retrieved
future: <_GatheringFuture finished exception=CancelledError()>
asyncio.exceptions.CancelledError
[1,1]<stdout>:Total batches: 488
[1,1]<stderr>:[2021-12-14 10:18:27.983077: E /tmp/pip-install-mwp6l21a/horovod_1f03263b83654efeb4c82a546f8aadfa/horovod/common/operations.cc:649] [1]: Horovod background loop uncaught exception: CUDA error: invalid configuration argument
[1,1]<stderr>:CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
[1,1]<stderr>:For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
[1,1]<stderr>:Exception raised from launch_vectorized_kernel at /opt/pytorch/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh:103 (most recent call first):
[1,1]<stderr>:frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6c (0x7f07b599fe1c in /opt/conda/lib/python3.8/site-packages/torch/lib/libc10.so)
[1,1]<stderr>:frame #1: void at::native::gpu_kernel_impl<at::native::FillFunctor<float> >(at::TensorIteratorBase&, at::native::FillFunctor<float> const&) + 0xcfc (0x7f07b75db72c in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
[1,1]<stderr>:frame #2: void at::native::gpu_kernel<at::native::FillFunctor<float> >(at::TensorIteratorBase&, at::native::FillFunctor<float> const&) + 0x33b (0x7f07b75dc12b in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
[1,1]<stderr>:frame #3: <unknown function> + 0x1b9dbd3 (0x7f07b75c8bd3 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
[1,1]<stderr>:frame #4: at::native::fill_kernel_cuda(at::TensorIterator&, c10::Scalar const&) + 0x34 (0x7f07b75c9aa4 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
[1,1]<stderr>:frame #5: <unknown function> + 0x143a915 (0x7f0801bd8915 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
[1,1]<stderr>:frame #6: <unknown function> + 0x1029a89 (0x7f07b6a54a89 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
[1,1]<stderr>:frame #7: at::_ops::fill__Scalar::call(at::Tensor&, c10::Scalar const&) + 0x131 (0x7f08020e4611 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
[1,1]<stderr>:frame #8: at::native::zero_(at::Tensor&) + 0x7f (0x7f0801bd804f in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
[1,1]<stderr>:frame #9: <unknown function> + 0x1037095 (0x7f07b6a62095 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
[1,1]<stderr>:frame #10: at::_ops::zero_::call(at::Tensor&) + 0x12a (0x7f08020dd4ca in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
[1,1]<stderr>:frame #11: at::native::zeros(c10::ArrayRef<long>, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>) + 0x130 (0x7f0801df79b0 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
[1,1]<stderr>:frame #12: <unknown function> + 0x20089c9 (0x7f08027a69c9 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
[1,1]<stderr>:frame #13: <unknown function> + 0x1e00d5e (0x7f080259ed5e in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
[1,1]<stderr>:frame #14: <unknown function> + 0x1ded287 (0x7f080258b287 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
[1,1]<stderr>:frame #15: at::_ops::zeros::call(c10::ArrayRef<long>, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>) + 0x1ab (0x7f080211070b in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
[1,1]<stderr>:frame #16: horovod::torch::TorchOpContext::AllocateZeros(long, horovod::common::DataType, std::shared_ptr<horovod::common::Tensor>*) + 0xdf (0x7f0783bae9cf in /opt/conda/lib/python3.8/site-packages/horovod/torch/mpi_lib_v2.cpython-38-x86_64-linux-gnu.so)
[1,1]<stderr>:frame #17: horovod::common::TensorQueue::GetTensorEntriesFromResponse(horovod::common::Response const&, std::vector<horovod::common::TensorTableEntry, std::allocator<horovod::common::TensorTableEntry> >&, bool) + 0x60d (0x7f0783b2da3d in /opt/conda/lib/python3.8/site-packages/horovod/torch/mpi_lib_v2.cpython-38-x86_64-linux-gnu.so)
[1,1]<stderr>:frame #18: horovod::common::ResponseCache::put(horovod::common::Response const&, horovod::common::TensorQueue&, bool) + 0x1c6 (0x7f0783b21956 in /opt/conda/lib/python3.8/site-packages/horovod/torch/mpi_lib_v2.cpython-38-x86_64-linux-gnu.so)
[1,1]<stderr>:frame #19: horovod::common::Controller::Co[1,1]<stderr>:mputeResponseList(bool, horovod::common::HorovodGlobalState&, horovod::common::ProcessSet&) + 0x1869 (0x7f0783ae2bd9 in /opt/conda/lib/python3.8/site-packages/horovod/torch/mpi_lib_v2.cpython-38-x86_64-linux-gnu.so)
[1,1]<stderr>:frame #20: <unknown function> + 0x9cb3b (0x7f0783b05b3b in /opt/conda/lib/python3.8/site-packages/horovod/torch/mpi_lib_v2.cpython-38-x86_64-linux-gnu.so)
[1,1]<stderr>:frame #21: <unknown function> + 0xcc9d4 (0x7f08d9efb9d4 in /opt/conda/bin/../lib/libstdc++.so.6)
[1,1]<stderr>:frame #22: <unknown function> + 0x9609 (0x7f09a71e1609 in /usr/lib/x86_64-linux-gnu/libpthread.so.0)
[1,1]<stderr>:frame #23: clone + 0x43 (0x7f09a6fa1293 in /usr/lib/x86_64-linux-gnu/libc.so.6)
[1,1]<stderr>:
[1,1]<stderr>:/opt/conda/lib/python3.8/site-packages/numba/cuda/compiler.py:865: NumbaPerformanceWarning: Grid size (13) < 2 * SM count (160) will likely result in GPU under utilization due to low occupancy.
[1,1]<stderr>:  warn(NumbaPerformanceWarning(msg))
[1,1]<stderr>:Traceback (most recent call last):
[1,1]<stderr>:  File "/opt/conda/lib/python3.8/site-packages/horovod/torch/mpi_ops.py", line 944, in synchronize
[1,1]<stderr>:    mpi_lib.horovod_torch_wait_and_clear(handle)
[1,1]<stderr>:RuntimeError: Horovod has been shut down. This was caused by an exception on one of the ranks or an attempt to allreduce, allgather or broadcast a tensor after one of the ranks finished execution. If the shutdown was caused by an exception, you should see the exception in the log before the first shutdown message.
[1,1]<stderr>:
[1,1]<stderr>:During handling of the above exception, another exception occurred:
[1,1]<stderr>:
[1,1]<stderr>:Traceback (most recent call last):
[1,1]<stderr>:  File "torch_trainer.py", line 138, in <module>
[1,1]<stderr>:    hvd.join(gpu_to_use)
[1,1]<stderr>:  File "/opt/conda/lib/python3.8/site-packages/horovod/torch/mpi_ops.py", line 972, in join
[1,1]<stderr>:    return synchronize(handle).item()
[1,1]<stderr>:  File "/opt/conda/lib/python3.8/site-packages/horovod/torch/mpi_ops.py", line 949, in synchronize
[1,1]<stderr>:    raise HorovodInternalError(e)
[1,1]<stderr>:horovod.common.exceptions.HorovodInternalError: Horovod has been shut down. This was caused by an exception on one of the ranks or an attempt to allreduce, allgather or broadcast a tensor after one of the ranks finished execution. If the shutdown was caused by an exception, you should see the exception in the log before the first shutdown message.
^C