merlin.dataloader.jax.Loader
class merlin.dataloader.jax.Loader(dataset, batch_size, shuffle=False, seed_fn=None, parts_per_chunk=1, global_size=None, global_rank=None, drop_last=False, transforms=None, device=None)[source]
Bases: merlin.dataloader.loader_base.LoaderBase
JAX dataloader
- Parameters
dataset (merlin.io.Dataset) – The dataset to load
batch_size (int) – Number of rows to yield at each iteration
shuffle (bool, default False) – Whether to shuffle chunks of batches before iterating through them.
seed_fn (callable) – Function used to initialize random state
parts_per_chunk (int) – Number of dataset partitions, with size dictated by buffer_size, to load and concatenate asynchronously. More partitions lead to better epoch-level randomness but can reduce throughput.
global_size (int, optional) – When doing distributed training, this indicates the number of total processes that are training the model.
global_rank (int, optional) – When doing distributed training, this indicates the local rank for the current process.
drop_last (bool, default False) – Whether to drop the last batch in an epoch. This is useful when you need to guarantee that each batch contains exactly batch_size rows, since the last batch will usually contain fewer rows.
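The effect of drop_last on batch sizes can be sketched with plain arithmetic. The snippet below is only an illustration of the behaviour described above, not the library's implementation; the helper name batch_sizes is hypothetical and does not exist in merlin.dataloader.

```python
def batch_sizes(num_rows, batch_size, drop_last=False):
    """Return the row count of each batch in one epoch (illustrative only)."""
    # Number of full batches, plus the rows left over for a final short batch.
    full, remainder = divmod(num_rows, batch_size)
    sizes = [batch_size] * full
    # Without drop_last, the leftover rows form a final, smaller batch.
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


print(batch_sizes(10, 4))                  # [4, 4, 2]
print(batch_sizes(10, 4, drop_last=True))  # [4, 4]
```

With drop_last=True every batch has exactly batch_size rows, at the cost of skipping the leftover rows in that epoch.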
__init__(dataset, batch_size, shuffle=False, seed_fn=None, parts_per_chunk=1, global_size=None, global_rank=None, drop_last=False, transforms=None, device=None)
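A minimal construction sketch, using only parameters that appear in the signature above. The helper make_loader and its default batch size are hypothetical and exist only for illustration; the dataset argument is assumed to be an existing merlin.io.Dataset. The import is placed inside the function so the snippet can be loaded without Merlin installed.

```python
def make_loader(dataset, batch_size=1024):
    """Build a JAX Loader from a merlin.io.Dataset (hypothetical helper)."""
    # Lazy import: Merlin is only required when the helper is actually called.
    from merlin.dataloader.jax import Loader

    return Loader(
        dataset,
        batch_size=batch_size,
        shuffle=True,    # shuffle chunks of batches before iterating
        drop_last=True,  # keep every batch at exactly batch_size rows
    )
```

Iterating over the returned loader is assumed to yield one batch per step; the structure of each batch is not described in this section.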
Methods
__init__(dataset, batch_size[, shuffle, …])
array_lib()
epochs([epochs]) – Create a dataloader that will efficiently run for more than one epoch.
make_tensors(gdf[, use_row_lengths]) – Yields batches of tensors from a dataframe.
peek() – Get the next batch without advancing the iterator.
stop() – Halts and resets the initialization parameters of the dataloader.
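The "peek without advancing" behaviour can be illustrated with a small buffered iterator. This is a generic sketch of the idea in plain Python, not how Loader implements peek; the class PeekableIterator is hypothetical.

```python
class PeekableIterator:
    """Wrap an iterable so the next item can be inspected without consuming it."""

    def __init__(self, iterable):
        self._it = iter(iterable)
        self._buffer = []  # holds at most one item fetched ahead of time

    def __iter__(self):
        return self

    def peek(self):
        # Fetch the next item once and keep it until __next__ hands it out.
        if not self._buffer:
            self._buffer.append(next(self._it))
        return self._buffer[0]

    def __next__(self):
        # Return the buffered item first, if peek() already fetched one.
        if self._buffer:
            return self._buffer.pop()
        return next(self._it)


batches = PeekableIterator([[1, 2], [3, 4]])
first = batches.peek()   # inspect the upcoming batch
same = next(batches)     # iteration still starts from that same batch
```

Calling peek repeatedly returns the same item until the iterator is advanced.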
Attributes
input_schema – Get input schema of data to be loaded.
output_schema – Get output schema of data being loaded.
schema – Get input schema of data to be loaded.
transforms