merlin.dataloader.jax.Loader
class merlin.dataloader.jax.Loader(dataset, batch_size, shuffle=False, seed_fn=None, parts_per_chunk=1, global_size=None, global_rank=None, drop_last=False, transforms=None, device=None)
Bases: merlin.dataloader.loader_base.LoaderBase
JAX dataloader.
Parameters
dataset (merlin.io.Dataset) – The dataset to load
batch_size (int) – Number of rows to yield at each iteration
shuffle (bool, default False) – Whether to shuffle chunks of batches before iterating through them.
seed_fn (callable) – Function used to initialize random state
parts_per_chunk (int) – Number of dataset partitions, each sized according to buffer_size, to load and concatenate asynchronously. More partitions lead to better epoch-level randomness, but can reduce throughput.
global_size (int, optional) – When doing distributed training, this indicates the total number of processes training the model.
global_rank (int, optional) – When doing distributed training, this indicates the rank of the current process.
drop_last (bool, default False) – Whether or not to drop the last batch in an epoch. This is useful when you need to guarantee that each batch contains exactly batch_size rows, since the last batch usually contains fewer rows.
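A minimal plain-Python sketch of the batch arithmetic that batch_size and drop_last imply for a single process reading num_rows rows. The helpers num_batches and batch_sizes are hypothetical illustrations, not part of merlin.dataloader, and the sketch ignores shuffling and distributed splitting.

```python
def num_batches(num_rows: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: how many batches one epoch would yield."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        # Only full batches of exactly batch_size rows are kept.
        return num_rows // batch_size
    # Ceiling division: the final, possibly smaller, batch is kept.
    return (num_rows + batch_size - 1) // batch_size


def batch_sizes(num_rows: int, batch_size: int, drop_last: bool = False) -> list:
    """Hypothetical helper: the row count of every batch in one epoch."""
    sizes = [batch_size] * (num_rows // batch_size)
    remainder = num_rows % batch_size
    # A short trailing batch exists only when num_rows is not a multiple of batch_size.
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes
```

For example, 10 rows with batch_size=4 give batches of 4, 4 and 2 rows, or only 4 and 4 when drop_last=True. When the row count is an exact multiple of batch_size, drop_last has no effect.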
__init__(dataset, batch_size, shuffle=False, seed_fn=None, parts_per_chunk=1, global_size=None, global_rank=None, drop_last=False, transforms=None, device=None)
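When global_size and global_rank are set, each process has to read a distinct subset of the dataset. The sketch below shows one way such a split could work: partitions are permuted with a seed that every rank shares, then divided into contiguous, near-equal slices. The function partitions_for_rank and the convention that seed_fn returns an integer seed are hypothetical assumptions. This is not Merlin's actual partition-assignment code.

```python
import random
from typing import Callable, List, Optional


def partitions_for_rank(
    num_partitions: int,
    global_size: int,
    global_rank: int,
    shuffle: bool = False,
    seed_fn: Optional[Callable[[], int]] = None,
) -> List[int]:
    """Hypothetical sketch: choose which dataset partitions one process reads."""
    if global_size < 1 or not 0 <= global_rank < global_size:
        raise ValueError("need global_size >= 1 and 0 <= global_rank < global_size")
    indices = list(range(num_partitions))
    if shuffle:
        # All ranks must apply the same permutation, otherwise their slices would overlap.
        # Assumption: seed_fn returns an integer seed that is identical on every rank.
        seed = seed_fn() if seed_fn is not None else 0
        random.Random(seed).shuffle(indices)
    # Split into global_size contiguous slices whose lengths differ by at most one
    # (the same layout numpy.array_split produces).
    base, extra = divmod(num_partitions, global_size)
    start = global_rank * base + min(global_rank, extra)
    stop = start + base + (1 if global_rank < extra else 0)
    return indices[start:stop]
```

With 10 partitions and global_size=4, ranks 0 to 3 receive 3, 3, 2 and 2 partitions. Together they cover every partition exactly once, with or without shuffling, as long as every rank uses the same seed.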
Methods
__init__(dataset, batch_size[, shuffle, …])
array_lib()
epochs([epochs]) – Create a dataloader that will efficiently run for more than one epoch.
make_tensors(gdf[, use_row_lengths]) – Yield batches of tensors from a dataframe.
peek() – Get the next batch without advancing the iterator.
stop() – Halt and reset the initialization parameters of the dataloader.
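The peek() and epochs() entries are only one line each. This sketch illustrates the described behavior with hypothetical plain-Python stand-ins. PeekableBatches returns the next batch without consuming it, and run_epochs chains several passes over a batch source that is re-created once per epoch. Neither is Merlin code; the real Loader produces batches of tensors from the dataset.

```python
from typing import Callable, Generic, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


class PeekableBatches(Generic[T]):
    """Hypothetical iterator wrapper illustrating peek() semantics."""

    def __init__(self, batches: Iterable[T]) -> None:
        self._it: Iterator[T] = iter(batches)
        self._pending: List[T] = []  # holds at most one peeked batch

    def __iter__(self) -> "PeekableBatches[T]":
        return self

    def __next__(self) -> T:
        # Hand out a previously peeked batch first, so peeking never skips data.
        if self._pending:
            return self._pending.pop()
        return next(self._it)

    def peek(self) -> T:
        # Fetch the next batch but keep it buffered; raises StopIteration when exhausted.
        if not self._pending:
            self._pending.append(next(self._it))
        return self._pending[0]


def run_epochs(make_epoch: Callable[[], Iterable[T]], epochs: int) -> Iterator[T]:
    """Hypothetical multi-epoch driver: re-create the batch source once per epoch."""
    for _ in range(epochs):
        yield from make_epoch()
```

Calling peek() repeatedly returns the same batch until next() consumes it, so a training loop can inspect the first batch (for example, to read shapes) without losing it.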
Attributes
input_schema – Get input schema of data to be loaded.
output_schema – Get output schema of data being loaded.
schema – Get input schema of data to be loaded.
transforms