# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png

Multi-GPU Training with TensorFlow on MovieLens

Overview

NVIDIA Merlin is an open-source framework to accelerate and scale end-to-end recommender system pipelines on GPU. In this notebook, we use NVTabular, Merlin’s ETL component, to scale feature engineering and pre-processing to multiple GPUs and then perform data-parallel distributed training of a neural network on multiple GPUs with TensorFlow, Horovod, and NCCL.

This notebook assumes that you are familiar with NVTabular and its API.

In this notebook, we will focus only on the new information related to multi-GPU training, so please check out the other notebooks first (if you haven’t already.)

Learning objectives

In this notebook, we learn how to scale ETL and deep learning training to multiple GPUs:

  • Learn to use larger than GPU/host memory datasets for ETL and training

  • Use multi-GPU or multi node for ETL with NVTabular

  • Use NVTabular dataloader to accelerate TensorFlow pipelines

  • Scale TensorFlow training with Horovod

Dataset

In this notebook, we use the MovieLens25M dataset. It is popular for recommender systems and is used in academic publications. The dataset contains 25M movie ratings for 62,000 movies given by 162,000 users. Many projects use only the user/item/rating information of MovieLens, but the original dataset provides metadata for the movies, as well.

Note: We are using the MovieLens 25M dataset in this example for simplicity, although the dataset is not large enough to require multi-GPU training. However, the functionality demonstrated in this notebook can be easily extended to scale recommender pipelines for larger datasets in the same way.

Tools

Download and Convert

First, we will download and convert the dataset to Parquet. This section is based on 01-Download-Convert.ipynb.

Download

# External dependencies
import os
import pathlib

import cudf  # cuDF is an implementation of Pandas-like Dataframe on GPU

from merlin.core.utils import download_file

# All artifacts of this example live under INPUT_DATA_DIR, which can be
# overridden through the environment; "~" is expanded to the home directory.
INPUT_DATA_DIR = os.environ.get(
    "INPUT_DATA_DIR", "~/nvt-examples/multigpu-movielens/data/"
)
BASE_DIR = pathlib.Path(INPUT_DATA_DIR).expanduser()
zip_path = BASE_DIR / "ml-25m.zip"

# Fetch the MovieLens 25M archive over HTTPS so the downloaded data cannot be
# silently altered in transit (plain HTTP offers no integrity protection).
# redownload=False presumably reuses an archive that is already present.
download_file(
    "https://files.grouplens.org/datasets/movielens/ml-25m.zip",
    zip_path,
    redownload=False,
)
downloading ml-25m.zip: 262MB [00:06, 41.9MB/s]                                                                                                                                            
unzipping files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:04<00:00,  1.74files/s]

Convert

# Convert the movie metadata to Parquet: drop the "title" column and turn the
# pipe-separated "genres" string into a list column (used as a multi-hot
# feature later on).
movies = cudf.read_csv(pathlib.Path(BASE_DIR, "ml-25m", "movies.csv"))
movies = movies.drop("title", axis=1)
movies["genres"] = movies["genres"].str.split("|")
movies.to_parquet(pathlib.Path(BASE_DIR, "ml-25m", "movies_converted.parquet"))

Split into train and validation datasets

ratings = cudf.read_csv(pathlib.Path(BASE_DIR, "ml-25m", "ratings.csv"))
# The timestamp is not used as a feature in this example.
ratings = ratings.drop("timestamp", axis=1)

# Shuffle all rows (sample 100% without replacement) so the split below is
# random. No seed is set, so the split differs from run to run.
ratings = ratings.sample(len(ratings), replace=False)

# Hold out the last 20% of the shuffled rows for validation. Slice with an
# explicit train size: `ratings[:-num_valid]` would be EMPTY (and
# `ratings[-num_valid:]` the whole frame) if num_valid were 0.
num_valid = int(len(ratings) * 0.2)
num_train = len(ratings) - num_valid
train = ratings[:num_train]
valid = ratings[num_train:]

train.to_parquet(pathlib.Path(BASE_DIR, "train.parquet"))
valid.to_parquet(pathlib.Path(BASE_DIR, "valid.parquet"))

ETL with NVTabular

We have finished downloading and converting the dataset. Next, we will preprocess and engineer features with NVTabular on multiple GPUs. You can read more about multi-GPU ETL in the NVTabular documentation.

Deploy a Distributed-Dask Cluster

This section is based on scaling-criteo/02-ETL-with-NVTabular.ipynb and multi-gpu-toy-example/multi-gpu_dask.ipynb

# Standard Libraries
import shutil

# External Dependencies
import cupy as cp
import numpy as np
import cudf
import dask_cudf
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from dask.utils import parse_bytes
from dask.delayed import delayed
import rmm

# NVTabular
import nvtabular as nvt
import nvtabular.ops as ops
from merlin.io import Shuffle
from merlin.core.utils import device_mem_size
# define some information about where to get our data
input_path = pathlib.Path(BASE_DIR, "converted", "movielens")  # not used in the code shown here
dask_workdir = pathlib.Path(BASE_DIR, "test_dask", "workdir")
output_path = pathlib.Path(BASE_DIR, "test_dask", "output")
stats_path = pathlib.Path(BASE_DIR, "test_dask", "stats")


def _reset_dir(path):
    """Delete the directory ``path`` if it exists, then recreate it empty.

    Missing parent directories are created as well.
    """
    if path.is_dir():
        shutil.rmtree(path)
    path.mkdir(parents=True)


# Make sure we have clean worker, stats, and output spaces for Dask
_reset_dir(dask_workdir)
_reset_dir(stats_path)
_reset_dir(output_path)

# Get device memory capacity
capacity = device_mem_size(kind="total")

# Deploy a Single-Machine Multi-GPU Cluster
protocol = "tcp"  # "tcp" or "ucx"
visible_devices = "0,1"  # Select devices to place workers
# Spill GPU-worker memory to host once it reaches this fraction of device
# capacity. Reduce it if spilling fails, to prevent device memory errors.
device_spill_frac = 0.5
# (Optional) Set to an existing cluster object or scheduler address to
# connect to it instead of starting a new LocalCUDACluster.
cluster = None
if cluster is None:
    cluster = LocalCUDACluster(
        protocol=protocol,
        CUDA_VISIBLE_DEVICES=visible_devices,
        local_directory=dask_workdir,
        device_memory_limit=capacity * device_spill_frac,
    )

# Create the distributed client
client = Client(cluster)
client

Client

Cluster

  • Workers: 2
  • Cores: 2
  • Memory: 125.84 GiB
# Initialize RMM pool on ALL workers
def _rmm_pool():
    """Reinitialize RMM in the calling process to use a pool allocator.

    ``initial_pool_size=None`` leaves the starting pool size to RMM's default.
    """
    rmm.reinitialize(initial_pool_size=None, pool_allocator=True)


# Execute the initializer on every Dask worker, not just in this process.
client.run(_rmm_pool)
{'tcp://127.0.0.1:40789': None, 'tcp://127.0.0.1:43439': None}

Defining our Preprocessing Pipeline

This subsection is based on getting-started-movielens/02-ETL-with-NVTabular.ipynb.

# Preprocessing graph: join the movie metadata (genres) onto each rating by
# movieId, encode the categorical columns with Categorify, and binarize the
# rating into the training target (1 if rating > 3, else 0).
movies = cudf.read_parquet(pathlib.Path(BASE_DIR, "ml-25m", "movies_converted.parquet"))
joined = ["userId", "movieId"] >> nvt.ops.JoinExternal(movies, on=["movieId"])
cat_features = joined >> nvt.ops.Categorify()
ratings = nvt.ColumnSelector(["rating"]) >> nvt.ops.LambdaOp(
    lambda col: (col > 3).astype("int8"), dtype=np.int8
)
output = cat_features + ratings
workflow = nvt.Workflow(output)

# Remove the outputs of previous runs. shutil.rmtree replaces the former
# `!rm -rf $BASE_DIR/...` shell escapes: the shell split the unquoted path on
# whitespace (so a BASE_DIR containing spaces could delete the wrong paths),
# and the `!` syntax only works inside IPython.
shutil.rmtree(pathlib.Path(BASE_DIR, "train"), ignore_errors=True)
shutil.rmtree(pathlib.Path(BASE_DIR, "valid"), ignore_errors=True)

train_iter = nvt.Dataset([str(pathlib.Path(BASE_DIR, "train.parquet"))], part_size="100MB")
valid_iter = nvt.Dataset([str(pathlib.Path(BASE_DIR, "valid.parquet"))], part_size="100MB")

# Fit the workflow on the training data only, then persist it so the
# training script can load it (and its embedding sizes) later.
workflow.fit(train_iter)
workflow.save(str(pathlib.Path(BASE_DIR, "workflow")))

shuffle = Shuffle.PER_WORKER  # Shuffle algorithm
out_files_per_proc = 4  # Number of output files per worker
workflow.transform(train_iter).to_parquet(
    output_path=pathlib.Path(BASE_DIR, "train"),
    shuffle=shuffle,
    out_files_per_proc=out_files_per_proc,
)
workflow.transform(valid_iter).to_parquet(
    output_path=pathlib.Path(BASE_DIR, "valid"),
    shuffle=shuffle,
    out_files_per_proc=out_files_per_proc,
)

# Release the Dask cluster; training runs in separate Horovod processes.
client.shutdown()
cluster.close()
/usr/local/lib/python3.8/dist-packages/distributed/worker.py:3560: UserWarning: Large object of size 1.90 MiB detected in task graph: 
  ("('read-parquet-d36dd514a8adc53a9a91115c9be1d852' ... 1115c9be1d852')
Consider scattering large objects ahead of time
with client.scatter to reduce scheduler burden and 
keep data on workers

    future = client.submit(func, big_data)    # bad

    big_future = client.scatter(big_data)     # good
    future = client.submit(func, big_future)  # good
  warnings.warn(

Training with TensorFlow on Multiple GPUs

In this section, we will train a TensorFlow model with multi-GPU support. In the NVTabular v0.5 release, we added multi-GPU support for NVTabular dataloaders. We will modify the getting-started-movielens/03-Training-with-TF.ipynb to use multiple GPUs. Please review that notebook, if you have questions about the general functionality of the NVTabular dataloaders or the neural network architecture.

NVTabular dataloader for TensorFlow

We’ve identified that the dataloader is one bottleneck in deep learning recommender systems when training pipelines with TensorFlow. The normal TensorFlow dataloaders cannot prepare the next training batches fast enough and therefore, the GPU is not fully utilized.

We developed a highly customized tabular dataloader for accelerating existing pipelines in TensorFlow. In our experiments, we see a 9x speed-up of the same training workflow with the NVTabular dataloader. The NVTabular dataloader’s features are:

  • removing bottleneck of item-by-item dataloading

  • enabling larger than memory dataset by streaming from disk

  • reading data directly into GPU memory and removing CPU-GPU communication

  • preparing batches asynchronously on the GPU to avoid CPU-GPU communication

  • supporting commonly used .parquet format

  • easy integration into existing TensorFlow pipelines by using similar API - works with tf.keras models

  • supporting multi-GPU training with Horovod

You can find more information on the dataloaders in our blogpost.

Using Horovod with Tensorflow and NVTabular

The training script below is based on getting-started-movielens/03-Training-with-TF.ipynb, with a few important changes:

  • We provide several additional parameters to the KerasSequenceLoader class, including the total number of workers hvd.size(), the current worker’s id number hvd.rank(), and a function for generating random seeds seed_fn().

    train_dataset_tf = KerasSequenceLoader(
        ...
        global_size=hvd.size(),
        global_rank=hvd.rank(),
        seed_fn=seed_fn,
    )

  • The seed function uses Horovod to collectively generate a random seed that’s shared by all workers so that they can each shuffle the dataset in a consistent way and select partitions to work on without overlap. The seed function is called by the dataloader during the shuffling process at the beginning of each epoch:

    def seed_fn():
        min_int, max_int = tf.int32.limits
        max_rand = max_int // hvd.size()

        # Generate a seed fragment on each worker
        seed_fragment = cupy.random.randint(0, max_rand).get()

        # Aggregate seed fragments from all Horovod workers
        seed_tensor = tf.constant(seed_fragment)
        reduced_seed = hvd.allreduce(seed_tensor, name="shuffle_seed", op=hvd.mpi_ops.Sum) 

        return reduced_seed % max_rand
  • We wrap the TensorFlow optimizer with Horovod’s DistributedOptimizer class and scale the learning rate by the number of workers:

    opt = tf.keras.optimizers.SGD(0.01 * hvd.size())
    opt = hvd.DistributedOptimizer(opt)
  • We wrap the TensorFlow gradient tape with Horovod’s DistributedGradientTape class:

    with tf.GradientTape() as tape:
        ...
    tape = hvd.DistributedGradientTape(tape, sparse_as_dense=True)
  • After the first batch, we broadcast the model and optimizer parameters to all workers with Horovod:

    # Note: broadcast should be done after the first gradient step to
    # ensure optimizer initialization.
    if first_batch:
        hvd.broadcast_variables(model.variables, root_rank=0)
        hvd.broadcast_variables(opt.variables(), root_rank=0)
  • We only save checkpoints from the first worker to avoid multiple workers trying to write to the same files:

    if hvd.rank() == 0:
        checkpoint.save(checkpoint_dir)

The rest of the script is the same as the MovieLens example in getting-started-movielens/03-Training-with-TF.ipynb. In order to run it with Horovod, we first need to write it to a file.

%%writefile './tf_trainer.py'

# External dependencies
import argparse
import glob
import os

import cupy

# we can control how much memory to give tensorflow with this environment variable
# IMPORTANT: make sure you do this before you initialize TF's runtime, otherwise
# TF will have claimed all free GPU memory
os.environ["TF_MEMORY_ALLOCATION"] = "0.3"  # fraction of free memory

import nvtabular as nvt  # noqa: E402 isort:skip
from nvtabular.framework_utils.tensorflow import layers  # noqa: E402 isort:skip
from nvtabular.loader.tensorflow import KerasSequenceLoader  # noqa: E402 isort:skip

import tensorflow as tf  # noqa: E402 isort:skip
import horovod.tensorflow as hvd  # noqa: E402 isort:skip

def _csv_list(value):
    """Split a comma-separated command-line value into a list of column names.

    ``"movieId, userId"`` becomes ``["movieId", "userId"]``. Empty items are
    dropped, so ``""`` becomes ``[]``, which the ``or`` defaults below treat
    as "not given".
    """
    return [name.strip() for name in value.split(",") if name.strip()]


parser = argparse.ArgumentParser(
    description="Multi-GPU MovieLens training with TensorFlow, Horovod, and NVTabular."
)
parser.add_argument("--dir_in", default=None, help="Input directory")
parser.add_argument("--batch_size", default=None, help="batch size")
# Column options are parsed into lists: the script concatenates them with
# other column lists and iterates them per column, which a raw string breaks.
parser.add_argument(
    "--cats", default=None, type=_csv_list, help="comma-separated categorical columns"
)
parser.add_argument(
    "--cats_mh",
    default=None,
    type=_csv_list,
    help="comma-separated categorical multihot columns",
)
parser.add_argument(
    "--conts", default=None, type=_csv_list, help="comma-separated continuous columns"
)
# NOTE: --labels is parsed but not used yet; the data loader hard-codes
# label_names=["rating"].
parser.add_argument(
    "--labels", default=None, type=_csv_list, help="comma-separated label columns"
)
args = parser.parse_args()


BASE_DIR = args.dir_in or "./data/"
BATCH_SIZE = int(args.batch_size or 16384)  # Batch Size
CATEGORICAL_COLUMNS = args.cats or ["movieId", "userId"]  # Single-hot
CATEGORICAL_MH_COLUMNS = args.cats_mh or ["genres"]  # Multi-hot
NUMERIC_COLUMNS = args.conts or []
TRAIN_PATHS = sorted(
    glob.glob(os.path.join(BASE_DIR, "train/*.parquet"))
)  # Output from ETL-with-NVTabular
if not TRAIN_PATHS:
    # Fail fast with a clear message rather than with an obscure error (or an
    # empty epoch) later inside the data loader.
    raise FileNotFoundError(
        f"No training files match {os.path.join(BASE_DIR, 'train/*.parquet')}; "
        "run the NVTabular ETL step first or pass --dir_in."
    )
hvd.init()

# Seed with system randomness (or a static seed)
cupy.random.seed(None)


def seed_fn():
    """
    Generate consistent dataloader shuffle seeds across workers

    Reseeds each worker's dataloader each epoch to get a fresh shuffle
    that's consistent across workers.

    Each worker draws a random fragment in ``[0, max_rand)``, where
    ``max_rand = int32_max // hvd.size()``. A Horovod sum-allreduce adds the
    fragments, so every worker receives the same total, and that total cannot
    exceed the int32 maximum. The returned value (a tensor) is the total
    modulo ``max_rand``.
    """
    _, max_int = tf.int32.limits  # only the upper int32 bound is needed
    max_rand = max_int // hvd.size()

    # Generate a seed fragment on each worker
    seed_fragment = cupy.random.randint(0, max_rand).get()

    # Aggregate seed fragments from all Horovod workers
    seed_tensor = tf.constant(seed_fragment)
    reduced_seed = hvd.allreduce(seed_tensor, name="shuffle_seed", op=hvd.mpi_ops.Sum)

    return reduced_seed % max_rand


# Load the workflow fitted during ETL and derive the embedding table shapes,
# i.e. (vocabulary size, embedding dimension) per categorical column.
proc = nvt.Workflow.load(os.path.join(BASE_DIR, "workflow/"))
EMBEDDING_TABLE_SHAPES, MH_EMBEDDING_TABLE_SHAPES = nvt.ops.get_embedding_sizes(proc)
# Fold the multi-hot shapes into the single-hot dict so that one lookup
# covers every categorical column.
EMBEDDING_TABLE_SHAPES.update(MH_EMBEDDING_TABLE_SHAPES)

# Horovod-aware loader: global_size/global_rank let each worker select its
# own partitions, and seed_fn gives all workers the same per-epoch shuffle.
train_dataset_tf = KerasSequenceLoader(
    TRAIN_PATHS,  # a glob pattern would also work here
    engine="parquet",
    batch_size=BATCH_SIZE,
    cat_names=CATEGORICAL_COLUMNS + CATEGORICAL_MH_COLUMNS,
    cont_names=NUMERIC_COLUMNS,
    label_names=["rating"],
    shuffle=True,
    # Read-ahead buffer; per the NVTabular loader docs, a value below 1 is
    # interpreted as a fraction of GPU memory.
    buffer_size=0.06,
    parts_per_chunk=1,
    global_size=hvd.size(),
    global_rank=hvd.rank(),
    seed_fn=seed_fn,
)
# Symbolic Keras inputs keyed by column name. Each single-hot column is one
# int32 tensor; each multi-hot column needs a (values, nnzs) pair of int64
# tensors.
inputs = {
    col: tf.keras.Input(name=col, dtype=tf.int32, shape=(1,))
    for col in CATEGORICAL_COLUMNS
}
for col in CATEGORICAL_MH_COLUMNS:
    inputs[col] = (
        tf.keras.Input(name=f"{col}__values", dtype=tf.int64, shape=(1,)),
        tf.keras.Input(name=f"{col}__nnzs", dtype=tf.int64, shape=(1,)),
    )

# One embedding column per categorical feature, sized from the workflow:
# EMBEDDING_TABLE_SHAPES[col] is (vocab size, embedding output dimension).
emb_layers = [
    tf.feature_column.embedding_column(
        tf.feature_column.categorical_column_with_identity(
            col, EMBEDDING_TABLE_SHAPES[col][0]
        ),
        EMBEDDING_TABLE_SHAPES[col][1],
    )
    for col in CATEGORICAL_COLUMNS + CATEGORICAL_MH_COLUMNS
]
emb_layer = layers.DenseFeatures(emb_layers)
x_emb_output = emb_layer(inputs)

# Three 128-unit ReLU layers followed by a sigmoid output for binary labels.
x = x_emb_output
for _ in range(3):
    x = tf.keras.layers.Dense(128, activation="relu")(x)
x = tf.keras.layers.Dense(1, activation="sigmoid")(x)
model = tf.keras.Model(inputs=inputs, outputs=x)

loss = tf.losses.BinaryCrossentropy()
# Learning rate scaled by the number of workers, optimizer wrapped by Horovod.
opt = hvd.DistributedOptimizer(tf.keras.optimizers.SGD(0.01 * hvd.size()))
checkpoint_dir = "./checkpoints"
checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)


# experimental_relax_shapes presumably lets one traced graph serve batches of
# varying size (e.g. a smaller final batch) instead of retracing; confirm
# against the tf.function docs for the TF version in use.
@tf.function(experimental_relax_shapes=True)
def training_step(examples, labels, first_batch):
    """Run one Horovod data-parallel training step and return its loss.

    Args:
        examples: one batch of features from ``train_dataset_tf``
            (presumably a dict keyed by input column name, matching
            ``inputs`` -- confirm against the loader).
        labels: the labels for the same batch.
        first_batch: True only on the first step (the loop passes
            ``batch == 0``); triggers broadcasting rank 0's model and
            optimizer variables to all other workers.

    Returns:
        This worker's loss for the batch, as computed by ``loss``; it is not
        averaged across workers here.
    """
    with tf.GradientTape() as tape:
        probs = model(examples, training=True)
        loss_value = loss(labels, probs)
    # Horovod: add Horovod Distributed GradientTape.
    # sparse_as_dense=True presumably densifies sparse (embedding) gradients
    # before they are reduced across workers.
    tape = hvd.DistributedGradientTape(tape, sparse_as_dense=True)
    grads = tape.gradient(loss_value, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    #
    # Note: broadcast should be done after the first gradient step to ensure optimizer
    # initialization.
    if first_batch:
        hvd.broadcast_variables(model.variables, root_rank=0)
        hvd.broadcast_variables(opt.variables(), root_rank=0)
    return loss_value


# Each worker iterates over its own shard of the training data (the loader
# was built with global_size/global_rank), so no manual step adjustment is
# needed. hvd.join() presumably lets workers whose shard yields fewer batches
# wait for the others -- see the Horovod `join` docs.
for batch, (examples, labels) in enumerate(train_dataset_tf):
    loss_value = training_step(examples, labels, batch == 0)
    # Log from the global rank-0 worker only (matching the checkpoint guard
    # below) so multi-node runs do not print duplicate progress lines.
    if batch % 100 == 0 and hvd.rank() == 0:
        print("Step #%d\tLoss: %.6f" % (batch, loss_value))
hvd.join()

# Horovod: save checkpoints only on worker 0 to prevent other workers from
# corrupting it.
if hvd.rank() == 0:
    checkpoint.save(checkpoint_dir)
Overwriting ./tf_trainer.py

We’ll also need a small wrapper script to check environment variables set by the Horovod runner to see which rank we’ll be assigned, in order to set CUDA_VISIBLE_DEVICES properly for each worker:

%%writefile './hvd_wrapper.sh'

#!/bin/bash

# Pin each Horovod worker to one GPU based on its node-local rank, unless the
# caller already set CUDA_VISIBLE_DEVICES. Must run before NVTabular/TF are
# imported by the wrapped command.
if [ -z "${CUDA_VISIBLE_DEVICES:-}" ]; then
    # Get local process ID from OpenMPI or alternatively from SLURM
    if [ -n "${OMPI_COMM_WORLD_LOCAL_RANK:-}" ]; then
        LOCAL_RANK="${OMPI_COMM_WORLD_LOCAL_RANK}"
    elif [ -n "${SLURM_LOCALID:-}" ]; then
        LOCAL_RANK="${SLURM_LOCALID}"
    fi
    # Only restrict visibility when a rank is known: exporting an empty
    # CUDA_VISIBLE_DEVICES would hide every GPU from the process.
    if [ -n "${LOCAL_RANK:-}" ]; then
        export CUDA_VISIBLE_DEVICES="${LOCAL_RANK}"
    else
        echo "hvd_wrapper.sh: no local rank found (OMPI_COMM_WORLD_LOCAL_RANK/SLURM_LOCALID unset); leaving CUDA_VISIBLE_DEVICES unset" >&2
    fi
fi

exec "$@"
Overwriting ./hvd_wrapper.sh

OpenMPI and Slurm are tools for running distributed compute jobs. In this example, we’re using OpenMPI, but depending on the environment you run distributed training jobs in, you may need to check slightly different environment variables to find the total number of workers (global size) and each process’s worker number (global rank.)

Why do we have to check environment variables instead of using hvd.rank() and hvd.local_rank()? NVTabular does some GPU configuration when imported and needs to be imported before Horovod to avoid conflicts. We need to set GPU visibility before NVTabular is imported (when Horovod isn’t yet available) so that multiple processes don’t each try to configure all the GPUs, so as a workaround, we “cheat” and peek at environment variables set by horovodrun to decide which GPU each process should use.

!horovodrun -np 2 sh hvd_wrapper.sh python tf_trainer.py --dir_in $BASE_DIR --batch_size 16384
2021-06-04 16:39:06.000313: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
[1,0]<stderr>:2021-06-04 16:39:08.979997: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
[1,1]<stderr>:2021-06-04 16:39:09.064191: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
[1,0]<stderr>:2021-06-04 16:39:10.138200: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
[1,0]<stderr>:2021-06-04 16:39:10.138376: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
[1,0]<stderr>:2021-06-04 16:39:10.139777: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1746] Found device 0 with properties: 
[1,0]<stderr>:pciBusID: 0000:0b:00.0 name: GeForce GTX 1080 Ti computeCapability: 6.1
[1,0]<stderr>:coreClock: 1.582GHz coreCount: 28 deviceMemorySize: 10.91GiB deviceMemoryBandwidth: 451.17GiB/s
[1,0]<stderr>:2021-06-04 16:39:10.139823: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
[1,0]<stderr>:2021-06-04 16:39:10.139907: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
[1,0]<stderr>:2021-06-04 16:39:10.139949: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
[1,0]<stderr>:2021-06-04 16:39:10.139990: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcufft.so.10
[1,0]<stderr>:2021-06-04 16:39:10.140029: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcurand.so.10
[1,0]<stderr>:2021-06-04 16:39:10.140084: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusolver.so.11
[1,0]<stderr>:2021-06-04 16:39:10.140123: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusparse.so.11
[1,0]<stderr>:2021-06-04 16:39:10.140169: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.8
[1,0]<stderr>:2021-06-04 16:39:10.144021: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1888] Adding visible gpu devices: 0
[1,1]<stderr>:2021-06-04 16:39:10.367414: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
[1,1]<stderr>:2021-06-04 16:39:10.367496: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
[1,1]<stderr>:2021-06-04 16:39:10.368324: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1746] Found device 0 with properties: 
[1,1]<stderr>:pciBusID: 0000:42:00.0 name: GeForce GTX 1080 Ti computeCapability: 6.1
[1,1]<stderr>:coreClock: 1.582GHz coreCount: 28 deviceMemorySize: 10.92GiB deviceMemoryBandwidth: 451.17GiB/s
[1,1]<stderr>:2021-06-04 16:39:10.368347: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
[1,1]<stderr>:2021-06-04 16:39:10.368396: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
[1,1]<stderr>:2021-06-04 16:39:10.368424: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
[1,1]<stderr>:2021-06-04 16:39:10.368451: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcufft.so.10
[1,1]<stderr>:2021-06-04 16:39:10.368475: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcurand.so.10
[1,1]<stderr>:2021-06-04 16:39:10.368512: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusolver.so.11
[1,1]<stderr>:2021-06-04 16:39:10.368537: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusparse.so.11
[1,1]<stderr>:2021-06-04 16:39:10.368573: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.8
[1,1]<stderr>:2021-06-04 16:39:10.369841: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1888] Adding visible gpu devices: 0
[1,1]<stderr>:2021-06-04 16:39:11.730033: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set
[1,1]<stderr>:2021-06-04 16:39:11.730907: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1746] Found device 0 with properties: 
[1,1]<stderr>:pciBusID: 0000:42:00.0 name: GeForce GTX 1080 Ti computeCapability: 6.1
[1,1]<stderr>:coreClock: 1.582GHz coreCount: 28 deviceMemorySize: 10.92GiB deviceMemoryBandwidth: 451.17GiB/s
[1,1]<stderr>:2021-06-04 16:39:11.730990: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
[1,1]<stderr>:2021-06-04 16:39:11.731005: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
[1,1]<stderr>:2021-06-04 16:39:11.731018: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
[1,1]<stderr>:2021-06-04 16:39:11.731029: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcufft.so.10
[1,1]<stderr>:2021-06-04 16:39:11.731038: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcurand.so.10
[1,1]<stderr>:2021-06-04 16:39:11.731049: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusolver.so.11
[1,1]<stderr>:2021-06-04 16:39:11.731059: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusparse.so.11
[1,1]<stderr>:2021-06-04 16:39:11.731078: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.8
[1,1]<stderr>:2021-06-04 16:39:11.732312: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1888] Adding visible gpu devices: 0
[1,1]<stderr>:2021-06-04 16:39:11.732350: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
[1,1]<stderr>:2021-06-04 16:39:11.732473: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1287] Device interconnect StreamExecutor with strength 1 edge matrix:
[1,1]<stderr>:2021-06-04 16:39:11.732487: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1293]      0 
[1,1]<stderr>:2021-06-04 16:39:11.732493: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1306] 0:   N 
[1,1]<stderr>:2021-06-04 16:39:11.734431: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1432] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 3352 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:42:00.0, compute capability: 6.1)
[1,0]<stderr>:2021-06-04 16:39:11.821346: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set
[1,0]<stderr>:2021-06-04 16:39:11.822270: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1746] Found device 0 with properties: 
[1,0]<stderr>:pciBusID: 0000:0b:00.0 name: GeForce GTX 1080 Ti computeCapability: 6.1
[1,0]<stderr>:coreClock: 1.582GHz coreCount: 28 deviceMemorySize: 10.91GiB deviceMemoryBandwidth: 451.17GiB/s
[1,0]<stderr>:2021-06-04 16:39:11.822360: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
[1,0]<stderr>:2021-06-04 16:39:11.822376: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
[1,0]<stderr>:2021-06-04 16:39:11.822389: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
[1,0]<stderr>:2021-06-04 16:39:11.822400: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcufft.so.10
[1,0]<stderr>:2021-06-04 16:39:11.822411: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcurand.so.10
[1,0]<stderr>:2021-06-04 16:39:11.822425: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusolver.so.11
[1,0]<stderr>:2021-06-04 16:39:11.822434: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusparse.so.11
[1,0]<stderr>:2021-06-04 16:39:11.822454: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.8
[1,0]<stderr>:2021-06-04 16:39:11.823684: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1888] Adding visible gpu devices: 0
[1,0]<stderr>:2021-06-04 16:39:11.823731: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
[1,0]<stderr>:2021-06-04 16:39:11.823868: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1287] Device interconnect StreamExecutor with strength 1 edge matrix:
[1,0]<stderr>:2021-06-04 16:39:11.823881: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1293]      0 
[1,0]<stderr>:2021-06-04 16:39:11.823888: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1306] 0:   N 
[1,0]<stderr>:2021-06-04 16:39:11.825784: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1432] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 3352 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:0b:00.0, compute capability: 6.1)
[1,0]<stderr>:2021-06-04 16:39:17.634485: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
[1,0]<stderr>:2021-06-04 16:39:17.668915: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2993950000 Hz
[1,1]<stderr>:2021-06-04 16:39:17.694128: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
[1,1]<stderr>:2021-06-04 16:39:17.703326: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2993950000 Hz
[1,0]<stderr>:2021-06-04 16:39:17.780825: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
[1,1]<stderr>:2021-06-04 16:39:17.810644: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
[1,0]<stderr>:2021-06-04 16:39:17.984966: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
[1,1]<stderr>:2021-06-04 16:39:18.012113: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
[1,0]<stdout>:Step #0	Loss: 0.695094
[1,0]<stdout>:Step #100	Loss: 0.669580
[1,0]<stdout>:Step #200	Loss: 0.661098
[1,0]<stdout>:Step #300	Loss: 0.660680
[1,0]<stdout>:Step #400	Loss: 0.658633
[1,0]<stdout>:Step #500	Loss: 0.660251
[1,0]<stdout>:Step #600	Loss: 0.657047