merlin.dataloader.jax.Loader

class merlin.dataloader.jax.Loader(dataset, batch_size, shuffle=False, seed_fn=None, parts_per_chunk=1, global_size=None, global_rank=None, drop_last=False, transforms=None, device=None)

Bases: merlin.dataloader.loader_base.LoaderBase

JAX dataloader.

Parameters
  • dataset (merlin.io.Dataset) – The dataset to load

  • batch_size (int) – Number of rows to yield at each iteration

  • shuffle (bool, default False) – Whether to shuffle chunks of batches before iterating through them.

  • seed_fn (callable, optional) – Function used to initialize the random state.

  • parts_per_chunk (int, default 1) – Number of dataset partitions, each sized according to buffer_size, to load and concatenate asynchronously. More partitions lead to better epoch-level randomness but can reduce throughput.

  • global_size (int, optional) – When doing distributed training, the total number of processes training the model.

  • global_rank (int, optional) – When doing distributed training, the rank of the current process among the global_size processes.

  • drop_last (bool, default False) – Whether to drop the last batch in an epoch. This is useful when every batch must contain exactly batch_size rows, since the last batch usually contains fewer.
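To make the parameters above concrete, the sketch below models how they could combine during one epoch. It is a simplified, hypothetical model rather than Merlin's implementation: `iter_batches` and `num_batches` are illustrative names, partitions are assumed to be assigned to processes round-robin by rank, and shuffling permutes all rows of a process's shard with a seeded `random.Random` instead of shuffling chunks of batches.

```python
import math
import random
from typing import Callable, List, Optional, Sequence


def iter_batches(
    partitions: Sequence[Sequence[int]],
    batch_size: int,
    shuffle: bool = False,
    seed_fn: Optional[Callable[[], int]] = None,
    global_size: Optional[int] = None,
    global_rank: Optional[int] = None,
    drop_last: bool = False,
) -> List[List[int]]:
    """Hypothetical model of how a loader could turn partitions into batches."""
    # Distributed case (assumption): each process keeps every
    # global_size-th partition, starting at its own rank.
    if global_size is not None:
        rank = global_rank or 0
        partitions = partitions[rank::global_size]

    rows = [row for part in partitions for row in part]

    if shuffle:
        # seed_fn supplies the seed, so the same seed gives the same order.
        seed = seed_fn() if seed_fn is not None else None
        random.Random(seed).shuffle(rows)

    batches = [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]
    # With drop_last, a trailing batch shorter than batch_size is discarded.
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


def num_batches(num_rows: int, batch_size: int, drop_last: bool) -> int:
    """Batches per epoch for num_rows rows under the model above."""
    return num_rows // batch_size if drop_last else math.ceil(num_rows / batch_size)
```

Under this model, 10 rows with batch_size=4 yield 3 batches per epoch (the last holding 2 rows) when drop_last=False, and 2 full batches when drop_last=True.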

__init__(dataset, batch_size, shuffle=False, seed_fn=None, parts_per_chunk=1, global_size=None, global_rank=None, drop_last=False, transforms=None, device=None)

Methods

  • __init__(dataset, batch_size[, shuffle, …])

  • array_lib()

  • epochs([epochs]) – Create a dataloader that will efficiently run for more than one epoch.

  • make_tensors(gdf[, use_row_lengths]) – Yields batches of tensors from a dataframe.

  • peek() – Get the next batch without advancing the iterator.

  • stop() – Halts and resets the initialization parameters of the dataloader.
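Three of the methods above describe behaviors that can be modeled in plain Python. The sketch below is hypothetical and is not Merlin's implementation: `chain_epochs` imitates running several epochs back to back as epochs() describes, `PeekableBatches` shows the one-item lookahead that peek() describes, and `encode_list_column` shows one common way to flatten a list column into values plus per-row offsets or row lengths. That this mirrors what use_row_lengths selects in make_tensors is an assumption.

```python
import itertools
from typing import Callable, Dict, Generic, Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def chain_epochs(make_epoch: Callable[[int], Iterable[T]], epochs: int) -> Iterator[T]:
    """Yield the batches of `epochs` consecutive passes as one stream.

    `make_epoch(e)` is a hypothetical callable returning epoch e's batches.
    """
    return itertools.chain.from_iterable(make_epoch(e) for e in range(epochs))


class PeekableBatches(Generic[T]):
    """Wrap an iterator so the next item can be inspected without consuming it."""

    _EMPTY = object()

    def __init__(self, batches: Iterable[T]) -> None:
        self._it: Iterator[T] = iter(batches)
        self._buffer: object = self._EMPTY

    def peek(self) -> T:
        # Pull one item into the buffer; repeated peeks return the same item.
        if self._buffer is self._EMPTY:
            self._buffer = next(self._it)  # raises StopIteration when exhausted
        return self._buffer  # type: ignore[return-value]

    def __iter__(self) -> "PeekableBatches[T]":
        return self

    def __next__(self) -> T:
        # Serve the buffered item first if peek() already fetched one.
        if self._buffer is not self._EMPTY:
            item, self._buffer = self._buffer, self._EMPTY
            return item  # type: ignore[return-value]
        return next(self._it)


def encode_list_column(
    column: Sequence[Sequence[float]], use_row_lengths: bool = False
) -> Dict[str, List]:
    """Flatten a ragged list column into values plus per-row metadata."""
    values = [v for row in column for v in row]
    row_lengths = [len(row) for row in column]
    if use_row_lengths:
        return {"values": values, "row_lengths": row_lengths}
    # Offsets mark where each row starts in `values`, plus a final end index.
    return {"values": values, "offsets": [0, *itertools.accumulate(row_lengths)]}
```

With the offsets encoding, row i is recovered as values[offsets[i]:offsets[i + 1]].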

Attributes

  • input_schema – Get input schema of data to be loaded.

  • output_schema – Get output schema of data being loaded.

  • schema – Get input schema of data to be loaded.

  • transforms