TensorFlow Dataloader

class nvtabular.loader.tensorflow.KerasSequenceLoader(paths_or_dataset, batch_size, label_names=None, feature_columns=None, cat_names=None, cont_names=None, engine=None, shuffle=True, seed_fn=None, buffer_size=0.1, device=None, parts_per_chunk=1, reader_kwargs=None, global_size=None, global_rank=None, drop_last=False, sparse_names=None, sparse_max=None, sparse_as_dense=False, schema=None)[source]

Bases: keras.utils.data_utils.Sequence, nvtabular.loader.backend.DataLoader

Infinite generator used to asynchronously iterate through CSV or Parquet dataframes on GPU by leveraging an NVTabular Dataset. Applies preprocessing via NVTabular Workflow objects and outputs tabular dictionaries of TensorFlow Tensors via dlpack. Useful for training tabular models built in Keras and trained via tf.keras.Model.fit.

The data loading scheme is implemented by loading, preprocessing, and batching data in an asynchronous thread. The amount of randomness in shuffling is controlled by the buffer_size and parts_per_chunk kwargs. At load time, sub-chunks of data with size controlled by buffer_size are loaded from random partitions in the dataset, and parts_per_chunk of them are concatenated into a single chunk, shuffled, and split into batches. Each chunk therefore holds buffer_size*parts_per_chunk rows, and because the dataloader runs asynchronously there are, including the batch being processed by your network, 3*buffer_size*parts_per_chunk rows of data in GPU memory at any given time. For a fixed memory budget, using more parts_per_chunk comes at the expense of a smaller buffer_size, increasing the number of reads and reducing throughput. The goal should be to maximize the total amount of memory utilized at once without going OOM, using the fewest reads needed to meet your epoch-level randomness needs.
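
As a rough illustration of how these knobs fit together, the sketch below builds a loader over Parquet files and hands it straight to Keras; the file paths, column names, and model are hypothetical placeholders:

from nvtabular.loader.tensorflow import KerasSequenceLoader

train_loader = KerasSequenceLoader(
    "/data/train/*.parquet",        # file pattern, list of files, or nvt.Dataset
    batch_size=65536,
    label_names=["click"],
    cat_names=["user_id", "item_id"],
    cont_names=["price"],
    engine="parquet",
    shuffle=True,
    buffer_size=0.06,               # fraction of total GPU memory per buffered chunk
    parts_per_chunk=1,              # more parts -> better shuffling, smaller buffers
)

# model = tf.keras.Model(...)      # any Keras tabular model
# model.fit(train_loader, epochs=2)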

An important thing to note is that TensorFlow’s default behavior is to claim all GPU memory for itself at initialization time, which leaves none for NVTabular to load or preprocess data. As such, we attempt to configure TensorFlow to restrict its memory allocation on a given GPU using the environment variables TF_MEMORY_ALLOCATION and TF_VISIBLE_DEVICE. If TF_MEMORY_ALLOCATION < 1, it will be assumed that this refers to a fraction of free GPU memory on the given device. Otherwise, it will refer to an explicit allocation amount in MB. TF_VISIBLE_DEVICE should be an integer GPU index.
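
For example, to cap TensorFlow at half of the free memory on GPU 0, set these variables before importing the loader (the 0.5 fraction is an arbitrary example):

import os

os.environ["TF_MEMORY_ALLOCATION"] = "0.5"  # fraction of free GPU memory TensorFlow may claim
os.environ["TF_VISIBLE_DEVICE"] = "0"       # integer GPU index

from nvtabular.loader.tensorflow import KerasSequenceLoader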

Iterator output is of the form (dict(features), list(labels)), where each element of the features dict is a feature_name: feature_tensor pair, each element of the labels list is a tensor, and all tensors are of shape (batch_size, 1). Note that this means vectorized continuous and multi-hot categorical features are not currently supported. The underlying NVTabular Dataset object is stored in the data attribute, and should be used for updating NVTabular Workflow statistics:

import nvtabular as nvt
from nvtabular.loader.tensorflow import KerasSequenceLoader

workflow = nvt.Workflow(...)
dataset = KerasSequenceLoader(...)
workflow.update_stats(dataset.data.to_iter(), record_stats=True)
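
A quick way to inspect the batch structure described above is to pull a single batch from the loader (feature names depend on your dataset):

features, labels = next(iter(dataset))
# features: dict of {feature_name: tensor of shape (batch_size, 1)}
# labels: label tensors, each of shape (batch_size, 1)
print({name: tensor.shape for name, tensor in features.items()})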
Parameters
  • paths_or_dataset (str or list(str)) – Either a string representing a file pattern (see tf.io.gfile.glob for pattern rules), a list of filenames to be iterated through, or a Dataset object, in which case buffer_size, engine, and reader_kwargs will be ignored

  • batch_size (int) – Number of samples to yield at each iteration

  • label_names (list(str)) – Column names of the target variables in the dataframe specified by paths_or_dataset

  • feature_columns (list(tf.feature_column) or None) – A list of TensorFlow feature columns representing the inputs exposed to the model to be trained. Columns with parent columns will climb the parent tree, and the names of the columns in the unique set of terminal columns will be used as the column names. If left as None, must specify cat_names and cont_names

  • cat_names (list(str) or None) – List of categorical column names. Ignored if feature_columns is specified

  • cont_names (list(str) or None) – List of continuous column names. Ignored if feature_columns is specified

  • engine ({'csv', 'parquet', None}, default None) – String specifying the type of read engine to use. If left as None, will try to infer the engine type from the file extension.

  • shuffle (bool, default True) – Whether to shuffle chunks of batches before iterating through them.

  • buffer_size (float or int) – If 0 < buffer_size < 1, buffer_size will refer to the fraction of total GPU memory to occupy with a buffered chunk. If 1 < buffer_size < batch_size, the number of rows read for a buffered chunk will be equal to int(buffer_size*batch_size). Otherwise, if buffer_size > batch_size, buffer_size rows will be read in each chunk (except for the last chunk in a dataset, which will, in general, be smaller). Larger chunk sizes will lead to more efficiency and randomness, but require more memory.

  • device (None) – Which GPU device to load from. Ignored for now

  • parts_per_chunk (int) – Number of dataset partitions with size dictated by buffer_size to load and concatenate asynchronously. More partitions leads to better epoch-level randomness but can negatively impact throughput

  • reader_kwargs (dict) – Extra kwargs to pass when instantiating the underlying nvtabular.Dataset

  • sparse_names (list(str) or None) – Names of the columns that should be represented as sparse tensors (see the sketch after this parameter list)

  • sparse_max (dict) – Dictionary mapping each sparse column name to an integer giving the maximum sequence length for that column

  • sparse_as_dense (bool) – If True, convert the sparse tensors produced for these columns to dense tensors
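
For list (multi-hot) columns, sparse_names, sparse_max, and sparse_as_dense control how values are emitted as sparse (or padded dense) tensors. A hypothetical sketch, assuming a list column named "genres" with at most 6 entries per row:

loader = KerasSequenceLoader(
    "/data/train/*.parquet",            # hypothetical paths
    batch_size=1024,
    label_names=["click"],
    cat_names=["user_id", "genres"],
    cont_names=["price"],
    sparse_names=["genres"],            # emit this column as a sparse tensor
    sparse_max={"genres": 6},           # hypothetical maximum sequence length
    sparse_as_dense=False,              # set True to pad into a dense tensor
)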

map(fn)[source]

Applies a function to each batch.

This can, for instance, be used to add a sample_weight to the batches fed to the model.
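
A sketch of using map to attach a per-sample weight; it assumes the mapped function receives and returns the (features, labels) pair produced by the loader, and the feature name used to infer the batch size is hypothetical:

import tensorflow as tf

def add_sample_weight(features, labels):
    # Hypothetical uniform weight per sample; replace with real weighting logic.
    batch_size = tf.shape(features["price"])[0]
    sample_weight = tf.ones((batch_size, 1), dtype=tf.float32)
    return features, labels, sample_weight

loader.map(add_sample_weight)  # registers the function to run on every batch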

class nvtabular.loader.tensorflow.KerasSequenceValidater(dataloader)[source]

Bases: keras.callbacks.Callback

on_epoch_end(epoch, logs=None)[source]
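
A typical pattern is to compute validation metrics from a second loader at the end of each training epoch; a sketch, assuming a model and training loader have been built as above:

from nvtabular.loader.tensorflow import KerasSequenceLoader, KerasSequenceValidater

valid_loader = KerasSequenceLoader(...)  # built over the validation files, like the training loader
validation_callback = KerasSequenceValidater(valid_loader)

# model.fit(train_loader, callbacks=[validation_callback], epochs=2)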

TensorFlow Layers

class nvtabular.framework_utils.tensorflow.layers.embedding.DenseFeatures(*args, **kwargs)[source]

Bases: keras.engine.base_layer.Layer

Layer which maps a dictionary of input tensors to a dense, continuous vector digestible by a neural network. Meant to reproduce the API exposed by tf.keras.layers.DenseFeatures while reducing overhead for the case of one-hot categorical and scalar numeric features.

Uses TensorFlow feature_column to represent inputs to the layer, but does not perform any preprocessing associated with those columns. As such, it should only be passed numeric_column objects and their subclasses, embedding_column and indicator_column. Preprocessing functionality should be moved to NVTabular.

For multi-hot categorical or vector continuous data, represent the data for a feature with a dictionary entry “<feature_name>__values” corresponding to the flattened array of all values in the batch. For multi-hot categorical data, there should be a corresponding “<feature_name>__nnzs” entry that describes how many categories are present in each sample (and so has length batch_size).

Note that categorical columns should be wrapped in embedding or indicator columns first, consistent with the API used by tf.keras.layers.DenseFeatures.

Example usage:

import tensorflow as tf
from nvtabular.framework_utils.tensorflow.layers.embedding import DenseFeatures

column_a = tf.feature_column.numeric_column("a", (1,))
column_b = tf.feature_column.categorical_column_with_identity("b", 100)
column_b_embedding = tf.feature_column.embedding_column(column_b, 4)

inputs = {
    "a": tf.keras.Input(name="a", shape=(1,), dtype=tf.float32),
    "b": tf.keras.Input(name="b", shape=(1,), dtype=tf.int64)
}
x = DenseFeatures([column_a, column_b_embedding])(inputs)
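
The resulting dense vector x can then feed an ordinary Keras stack; one possible continuation of the example above (layer sizes are arbitrary):

hidden = tf.keras.layers.Dense(64, activation="relu")(x)
output = tf.keras.layers.Dense(1, activation="sigmoid")(hidden)
model = tf.keras.Model(inputs=inputs, outputs=output)
model.compile(optimizer="adam", loss="binary_crossentropy")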
Parameters
  • feature_columns (list of tf.feature_column) – feature columns describing the inputs to the layer

  • aggregation (str in ("concat", "stack")) – how to combine the embeddings from multiple features

build(input_shapes)[source]
call(inputs)[source]
compute_output_shape(input_shapes)[source]
get_config()[source]
class nvtabular.framework_utils.tensorflow.layers.embedding.LinearFeatures(*args, **kwargs)[source]

Bases: keras.engine.base_layer.Layer

Layer which implements a linear combination of one-hot categorical and scalar numeric features. Based on the “wide” branch of the Wide & Deep network architecture.

Uses TensorFlow feature_columns to represent inputs to the layer, but does not perform any preprocessing associated with those columns. As such, it should only be passed numeric_column and categorical_column_with_identity objects. Preprocessing functionality should be moved to NVTabular.

Also note that, unlike DenseFeatures, categorical columns should NOT be wrapped in embedding or indicator columns first.

Example usage:

import tensorflow as tf
from nvtabular.framework_utils.tensorflow.layers.embedding import LinearFeatures

column_a = tf.feature_column.numeric_column("a", (1,))
column_b = tf.feature_column.categorical_column_with_identity("b", 100)

inputs = {
    "a": tf.keras.Input(name="a", shape=(1,), dtype=tf.float32),
    "b": tf.keras.Input(name="b", shape=(1,), dtype=tf.int64)
}
x = LinearFeatures([column_a, column_b])(inputs)
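
One possible continuation, treating the wide output as the logit of a binary classifier (a sketch; a full Wide & Deep model would add the output of a DenseFeatures branch before the sigmoid):

output = tf.keras.layers.Activation("sigmoid")(x)
model = tf.keras.Model(inputs=inputs, outputs=output)
model.compile(optimizer="adam", loss="binary_crossentropy")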
Parameters
  • feature_columns (list of tf.feature_column) – feature columns describing the inputs to the layer

build(input_shapes)[source]
call(inputs)[source]
compute_output_shape(input_shapes)[source]
get_config()[source]
class nvtabular.framework_utils.tensorflow.layers.interaction.DotProductInteraction(*args, **kwargs)[source]

Bases: keras.engine.base_layer.Layer

build(input_shape)[source]
call(value)[source]
compute_output_shape(input_shape)[source]
get_config()[source]