#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import warnings
from functools import partial

from merlin.core.compat.tensorflow import tensorflow as tf
from merlin.dataloader.loader_base import LoaderBase
from merlin.table import TensorColumn, TensorflowColumn, TensorTable
from merlin.table.conversions import _dispatch_dlpack_fns, convert_col


class Loader(LoaderBase, tf.keras.utils.Sequence):
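    """A Keras-compatible data loader that iterates over a Merlin dataset
    and yields batches of TensorFlow tensors.

    Because this class implements `tf.keras.utils.Sequence`, it can be passed
    directly to `model.fit`. A minimal sketch (the `dataset` and `model`
    objects here are assumptions, not defined in this module):

        loader = Loader(dataset, batch_size=1024, shuffle=True)
        model.fit(loader, epochs=1)
    """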

    def __init__(
self,
dataset,
batch_size,
shuffle=False,
seed_fn=None,
parts_per_chunk=1,
global_size=None,
global_rank=None,
drop_last=False,
transforms=None,
device=None,
):
super().__init__(
dataset,
batch_size,
shuffle,
seed_fn,
parts_per_chunk,
global_size,
global_rank,
drop_last,
transforms,
device,
)
self._validate_batch_size(batch_size)
self.create_table = partial(TensorTable, _unsafe=True)
self.create_column = partial(TensorColumn, _unsafe=True)
column = self.create_column(self.array_lib().array([]))
_to_dlpack_fn, _from_dlpack_fn = _dispatch_dlpack_fns(column, TensorflowColumn)
self.convert_col = partial(
convert_col, _to_dlpack_fn=_to_dlpack_fn, _from_dlpack_fn=_from_dlpack_fn, _unsafe=True
)

    def _validate_batch_size(self, batch_size):
        # A positive integer n is a power of two iff n & (n - 1) == 0.
        is_power_of_two = batch_size & (batch_size - 1) == 0
        if self.device != "cpu" and (batch_size < 16 or not is_power_of_two):
            warnings.warn(
                "Due to a CUDA memory alignment issue in some TensorFlow "
                "operations such as embedding ops, we recommend that "
                "'batch_size' be a power of two that is greater than or "
                "equal to 16.",
                UserWarning,
            )
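
    # Example (sketch): 1024 passes the check above, while 1000 would trigger
    # the warning on GPU:
    #
    #   1024 & 1023 == 0  # True  -> 1024 is a power of two
    #   1000 & 999 == 0   # False -> 1000 is not a power of two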

    def __len__(self):
        """Number of batches in the Sequence.

        Note: this also resets the loader state, which is required because
        Keras calls `__getitem__` before starting the main loop through the
        loader.
        """
LoaderBase.stop(self)
return LoaderBase.__len__(self)

    def __getitem__(self, index):
        """Get the batch at position `index`.

        Note: this returns the next batch in the iterator, not the batch at
        position `index`, because the dataloader is implemented as an
        iterator and does not currently support fetching a batch by index.
        """
return self.__next__()

    def __next__(self):
        """Return the next batch from the dataloader."""
converted_batch = self.convert_batch(super().__next__())
for map_fn in self._map_fns:
converted_batch = map_fn(*converted_batch)
return converted_batch

    def peek(self):
        """Grab the next batch from the dataloader without removing it
        from the queue."""
converted_batch = self.convert_batch(self._peek_next_batch())
for map_fn in self._map_fns:
converted_batch = map_fn(*converted_batch)
return converted_batch
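
    # Example (sketch): inspect the first batch without consuming it, e.g. to
    # check tensor shapes and dtypes before building a model:
    #
    #   inputs, targets = loader.peek()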

    def on_epoch_end(self):
        # Stop the loader's background iteration when Keras signals the
        # end of an epoch.
        self.stop()

    def convert_batch(self, batch):
        """Convert a batch to the appropriate tensor column type and format
        it as a flat dictionary, where each list column is split into
        separate values and offsets entries.

        Parameters
        ----------
        batch : tuple
            Tuple of dictionary inputs and n-dimensional array of targets

        Returns
        -------
        tuple
            A tuple of dictionary inputs, with list columns split into
            values and offsets, and targets as an array
        """
inputs, targets = batch
column_type = TensorflowColumn
tf_inputs = {}
if inputs is not None:
inputs_table = self.create_table(inputs)
for col_name, col in inputs_table.items():
tf_inputs[col_name] = self.convert_col(col, column_type)
tf_target = None
if targets is not None:
if isinstance(targets, dict):
targets_table = self.create_table(targets)
tf_targets = {}
for col_name, col in targets_table.items():
tf_targets[col_name] = self.convert_col(col, column_type)
tf_target = self.create_table(tf_targets).to_dict()
else:
targets_col = self.create_column(targets)
tf_target = self.convert_col(targets_col, column_type).values
return (self.create_table(tf_inputs).to_dict(), tf_target)
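
    # Example (sketch): for a list column named "genres", the flat dictionary
    # returned above holds the flattened element values and the row boundaries
    # as separate entries. The column name and key suffixes are assumptions
    # for illustration; the exact keys come from `TensorTable.to_dict`.
    #
    #   inputs, targets = next(loader)
    #   values = inputs["genres__values"]    # 1-D tensor of list elements
    #   offsets = inputs["genres__offsets"]  # 1-D tensor of row offsets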

    def map(self, fn):
        """Apply a function to each batch.

        This can, for instance, be used to add `sample_weight` to batches
        before they reach the model.
        """
self._map_fns.append(fn)
return self
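
    # Example (sketch): attach a per-sample weight to every batch. The map
    # function receives the unpacked (inputs, targets) batch and must return
    # the new batch tuple; the uniform weighting below is an assumption used
    # only for illustration.
    #
    #   def add_sample_weight(inputs, targets):
    #       sample_weight = tf.ones_like(targets, dtype=tf.float32)
    #       return inputs, targets, sample_weight
    #
    #   loader = loader.map(add_sample_weight)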


class KerasSequenceValidater(tf.keras.callbacks.Callback):
    """Keras callback that runs validation over a Merlin dataloader at the
    end of each epoch and adds the resulting `val_*` metrics to the logs.
    """

    _supports_tf_logs = True

    def __init__(self, dataloader):
        super().__init__()
        self.dataloader = dataloader

    def on_epoch_end(self, epoch, logs=None):
        """Run validation at the end of an epoch.

        Parameters
        ----------
        epoch : int
            Integer representing the current epoch.
        logs : dict, optional
            Dictionary of logs collected so far, to be added to;
            by default None.

        Returns
        -------
        dict
            The logs dictionary updated with the validation metric results.
        """
logs = logs if logs is not None else {}
for X, y_true in self.dataloader:
y_pred = self.model(X)
# TODO: how do we want to handle the multi-output case?
for metric in self.model.metrics:
metric.update_state(y_true, y_pred)
set_logs = {}
for metric in self.model.metrics:
set_logs[f"val_{metric.name}"] = metric.result().numpy()
logs.update(set_logs)
print(set_logs)
return logs
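
    # Example (sketch): validate against a held-out Loader after every epoch.
    # The `train_loader` and `valid_loader` names are assumptions used only
    # for illustration.
    #
    #   valid_callback = KerasSequenceValidater(valid_loader)
    #   model.fit(train_loader, epochs=2, callbacks=[valid_callback])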