zea.data.dataloader

H5 dataloader for loading images from zea datasets.

Example

import zea

loader = zea.Dataloader(
    file_paths="/path/to/dataset",
    key="data/image",
    batch_size=16,
    image_range=(-60, 0),
    normalization_range=(0, 1),
    image_size=(256, 256),
    num_threads=16,
)

for batch in loader:
    # batch is a numpy array of shape (batch_size, 256, 256, 1)
    ...

Functions

generate_h5_indices(file_paths, file_shapes, ...)

Generate indices for h5 files.

Classes

Dataloader(file_paths[, key, batch_size, ...])

High-performance HDF5 dataloader built on Grain.

H5DataSource(file_paths[, key, n_frames, ...])

Thread-safe random-access data source for HDF5 files.

class zea.data.dataloader.Dataloader(file_paths, key='data/image', batch_size=16, n_frames=1, shuffle=True, return_filename=False, seed=None, limit_n_samples=None, limit_n_frames=None, drop_remainder=False, image_size=None, resize_type=None, resize_axes=None, resize_kwargs=None, image_range=None, normalization_range=None, clip_image_range=False, assert_image_range=True, dataset_repetitions=None, cache=False, additional_axes_iter=None, sort_files=True, overlapping_blocks=False, augmentation=None, initial_frame_axis=0, insert_frame_axis=True, frame_index_stride=1, frame_axis=-1, validate=True, prefetch=True, shard_index=None, num_shards=1, num_threads=16, prefetch_buffer_size=500, **kwargs)[source]

Bases: object

High-performance HDF5 dataloader built on Grain.

grain threads (N) → h5py (thread-local handles) → numpy → user

The entire pipeline runs in numpy; there is no framework dependency until you feed tensors to your model.

Loads a dataset by performing the following steps:

  • Find all .hdf5 files in the given directory(ies)

  • Load the data from each file using the specified key

  • Apply the following transformations in order (if specified):

    • shuffle

    • shard

    • add channel dim

    • clip_image_range

    • assert_image_range

    • resize

    • repeat

    • batch

    • normalize

    • augmentation
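As an illustration of the clip_image_range and normalization steps, here is a minimal numpy sketch; the helper name is hypothetical (the actual transforms are internal to the Dataloader), but it shows the intended mapping from image_range to normalization_range:

```python
import numpy as np

def clip_and_normalize(batch, image_range, normalization_range):
    """Hypothetical sketch of the clip + normalize steps.

    Clips values to image_range, then linearly maps image_range
    onto normalization_range.
    """
    lo, hi = image_range
    out_lo, out_hi = normalization_range
    batch = np.clip(batch, lo, hi)             # clip_image_range
    batch = (batch - lo) / (hi - lo)           # scale to [0, 1]
    return batch * (out_hi - out_lo) + out_lo  # map to target range

images = np.array([[-80.0, -60.0, -30.0, 0.0, 10.0]])
print(clip_and_normalize(images, image_range=(-60, 0), normalization_range=(0, 1)))
# [[0.  0.  0.5 1.  1. ]]
```

With image_range=(-60, 0) and normalization_range=(0, 1), a dB value of -30 lands at 0.5, and out-of-range values are clipped to the endpoints first.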

Parameters:
  • file_paths (Union[List[str], str]) – Path(s) to directory(ies) and/or HDF5 file(s).

  • key (str) – HDF5 dataset key. Default is "data/image".

  • batch_size (int | None) – Batch size. Set to None to disable batching. Default is 16.

  • n_frames (int) – Number of consecutive frames per sample. Default is 1. When n_frames > 1, frames are grouped into blocks.

  • shuffle (bool) – Shuffle dataset each epoch. Default is True.

  • return_filename (bool) – Return filename metadata together with each sample. Default is False.

  • seed (int | None) – Random seed used for shuffling. Default is None. If None and shuffle=True, a random seed is generated.

  • limit_n_samples (int | None) – Limit total number of samples (useful for debugging). Default is None (no limit).

  • limit_n_frames (int | None) – Limit frames loaded per file to the first N frames. Default is None (no limit).

  • drop_remainder (bool) – Drop the final incomplete batch. Default is False.

  • image_size (tuple | None) – Target (height, width). Default is None (no resizing).

  • resize_type (str | None) – Resize strategy. One of "resize", "center_crop", "random_crop" or "crop_or_pad". Default is None, which resolves to "resize" when image_size is set.

  • resize_axes (tuple | None) – Axes to resize along, must have length 2 (height, width). Only needed when data has more than (h, w, c) dimensions. Axes are interpreted after frame-axis insertion/reordering. Default is None.

  • resize_kwargs (dict | None) – Extra keyword arguments passed to Resizer. Default is None.

  • image_range (tuple | None) – Source value range of images, e.g. (-60, 0). Used for clipping/asserting/normalization. Default is None.

  • normalization_range (tuple | None) – Target value range, e.g. (0, 1). If set, image_range must also be set. Default is None.

  • clip_image_range (bool) – Clip values to image_range before normalization. Default is False.

  • assert_image_range (bool) – Assert values stay within image_range. Default is True.

  • dataset_repetitions (int | None) – Repeat dataset this many times. Repetition happens after sharding. Default is None (no repetition).

  • cache (bool) – Cache loaded samples in RAM. Default is False. Note that with overlapping_blocks=True, the same frame can be part of multiple samples, so caching will consume more memory.

  • additional_axes_iter (tuple | None) – Additional axes to iterate over in addition to initial_frame_axis. Default is None.

  • sort_files (bool) – Sort files numerically before indexing. Default is True.

  • overlapping_blocks (bool) – If True, frame blocks overlap by n_frames - 1. Has no effect when n_frames == 1. Default is False.

  • augmentation (callable) – Callable applied to each batch after normalization. Default is None.

  • initial_frame_axis (int) – Axis in file data that represents frames. Default is 0.

  • insert_frame_axis (bool) – If True, keep per-frame samples and move/insert the frame dimension at frame_axis. If False, loaded frames are concatenated along frame_axis. Default is True.

  • frame_index_stride (int) – Step between selected frames in a block. Default is 1.

  • frame_axis (int) – Axis along which frames are stacked/placed in output. Default is -1.

  • validate (bool) – Validate discovered files against the zea format. Default is True.

  • prefetch (bool) – Enable Grain prefetching for iteration. Default is True.

  • shard_index (int | None) – Shard index to select when num_shards > 1. Must satisfy 0 <= shard_index < num_shards.

  • num_shards (int) – Total number of shards for distributed loading. Sharding happens before downstream transforms. Default is 1.

  • num_threads (int) – Number of Grain read threads (0 means main thread only). Default is 16.

  • prefetch_buffer_size (int) – Size of the Grain buffer for reading elements per Python process (not per thread). Useful when reading from a distributed file system. Default is 500.
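To illustrate how n_frames, frame_index_stride, initial_frame_axis, insert_frame_axis, and frame_axis interact, here is a hedged numpy sketch of the frame-selection semantics described above; the helper is hypothetical and not part of the zea API:

```python
import numpy as np

def select_frames(data, n_frames, stride=1, initial_frame_axis=0,
                  insert_frame_axis=True, frame_axis=-1):
    """Hypothetical sketch: pick a block of frames and place the frame dim."""
    idx = np.arange(0, n_frames * stride, stride)        # frame_index_stride
    block = np.take(data, idx, axis=initial_frame_axis)  # select frames
    if insert_frame_axis:
        # keep a separate frame dimension and move it to frame_axis
        return np.moveaxis(block, initial_frame_axis, frame_axis)
    # otherwise concatenate frames along an existing axis
    frames = np.split(block, n_frames, axis=initial_frame_axis)
    frames = [np.squeeze(f, axis=initial_frame_axis) for f in frames]
    return np.concatenate(frames, axis=frame_axis)

video = np.zeros((10, 64, 64))                       # (frames, h, w) in the file
print(select_frames(video, n_frames=3).shape)        # (64, 64, 3)
print(select_frames(video, n_frames=3,
                    insert_frame_axis=False).shape)  # (64, 192)
```

With insert_frame_axis=True the three frames become a new trailing dimension; with insert_frame_axis=False they are concatenated along the existing last axis.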

Example

loader = Dataloader(
    file_paths="/data/camus",
    key="data/image_sc",
    batch_size=32,
    image_range=(-60, 0),
    normalization_range=(0, 1),
    image_size=(256, 256),
)
for batch in loader:
    ...  # batch.shape == (32, 256, 256, 1)
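For distributed loading, shard_index / num_shards split the sample index space across processes before the downstream transforms run. A minimal sketch of one common partitioning strategy (interleaved slicing; this is an illustration of the semantics, not the exact internal Grain code):

```python
def shard_indices(num_samples, shard_index, num_shards):
    """Hypothetical sketch: give shard i every num_shards-th sample,
    starting at i, so shards are disjoint and cover the dataset."""
    assert 0 <= shard_index < num_shards
    return list(range(shard_index, num_samples, num_shards))

# Two shards over 10 samples cover the dataset exactly once, disjointly:
print(shard_indices(10, 0, 2))  # [0, 2, 4, 6, 8]
print(shard_indices(10, 1, 2))  # [1, 3, 5, 7, 9]
```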

close()[source]

Release file handles.

property dataset

The underlying grain.MapDataset.

summary()[source]

Print dataset statistics and per-directory breakdown.

to_iter_dataset()[source]

Convert to a grain.IterDataset with prefetching.

This is called automatically when you iterate, but you can call it explicitly if you want to hold onto the IterDataset object.

class zea.data.dataloader.H5DataSource(file_paths, key='data/image', n_frames=1, frame_index_stride=1, frame_axis=-1, insert_frame_axis=True, initial_frame_axis=0, additional_axes_iter=None, sort_files=True, overlapping_blocks=False, limit_n_samples=None, limit_n_frames=None, return_filename=False, cache=False, validate=True, **kwargs)[source]

Bases: object

Thread-safe random-access data source for HDF5 files.

Implements grain.RandomAccessDataSource protocol (__getitem__ and __len__) so it can be plugged directly into a grain.MapDataset pipeline.

Each worker thread gets its own H5FileHandleCache via threading.local() so h5py file handles are never shared across threads.
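The thread-local handle pattern described above can be sketched in plain Python. The class below is hypothetical (H5FileHandleCache's real implementation will differ); it only demonstrates the threading.local() mechanism that keeps per-thread caches isolated:

```python
import threading

class LocalHandleCache:
    """Sketch: each thread sees its own cache dict, so file handles
    opened in one thread are never reused by another."""
    def __init__(self):
        self._local = threading.local()

    @property
    def handles(self):
        if not hasattr(self._local, "handles"):
            self._local.handles = {}  # created lazily, once per thread
        return self._local.handles

cache = LocalHandleCache()
seen = []

def worker(name):
    cache.handles[name] = object()  # "open" a handle in this thread
    seen.append(len(cache.handles)) # each thread sees only its own entry

threads = [threading.Thread(target=worker, args=(f"t{i}",)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(seen)  # [1, 1, 1, 1]: every thread started with an empty cache of its own
```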

Parameters:
  • file_paths (Union[List[str], str]) – Path(s) to HDF5 directory(ies) or file(s).

  • key (str) – HDF5 dataset key, e.g. "data/image".

  • n_frames (int) – Number of consecutive frames per sample.

  • frame_index_stride (int) – Stride between frames.

  • frame_axis (int) – Axis along which frames are stacked in the output.

  • insert_frame_axis (bool) – Whether to insert a new axis for frames.

  • initial_frame_axis (int) – Source axis that stores frames in the file.

  • additional_axes_iter (tuple | None) – Extra axes to iterate over.

  • sort_files (bool) – Sort files numerically.

  • overlapping_blocks (bool) – Allow overlapping frame blocks.

  • limit_n_samples (int | None) – Cap the number of samples.

  • limit_n_frames (int | None) – Cap frames loaded per file.

  • return_filename (bool) – Return filename metadata with each sample.

  • cache (bool) – Cache loaded samples to RAM.

  • validate (bool) – Validate dataset against the zea format.

close()[source]

Close all file handles across all threads.

zea.data.dataloader.generate_h5_indices(file_paths, file_shapes, n_frames, frame_index_stride, key='data/image', initial_frame_axis=0, additional_axes_iter=None, sort_files=True, overlapping_blocks=False, limit_n_frames=None)[source]

Generate indices for h5 files.

Generates a list of indices used to extract images from HDF5 files. The length of this list equals the length of the extracted dataset.

Parameters:
  • file_paths (List[str]) – List of file paths.

  • file_shapes (list) – List of file shapes.

  • n_frames (int) – Number of frames to load from each hdf5 file.

  • frame_index_stride (int) – Interval between frames to load.

  • key (str) – Key of hdf5 dataset to grab data from. Defaults to "data/image".

  • initial_frame_axis (int) – Axis to iterate over. Defaults to 0.

  • additional_axes_iter (Optional[List[int]]) – Additional axes to iterate over in the dataset. Defaults to None.

  • sort_files (bool) – Sort files by number. Defaults to True.

  • overlapping_blocks (bool) – If True, takes n_frames from the sequence and then advances by one frame, so consecutive blocks overlap. Defaults to False.

  • limit_n_frames (int | None) – Limit the number of frames loaded from each file to the first limit_n_frames frames. Defaults to None.

Returns:

List of tuples with indices to extract images from hdf5 files.

(file_name, key, indices) with indices being a tuple of slices.

Return type:

list

Example

[
    (
        "/folder/path_to_file.hdf5",
        "data/image",
        (slice(0, 1, 1), slice(None, 256, None), slice(None, 256, None)),
    ),
    (
        "/folder/path_to_file.hdf5",
        "data/image",
        (slice(1, 2, 1), slice(None, 256, None), slice(None, 256, None)),
    ),
    ...,
]
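The block-slicing behavior, including overlapping_blocks, can be sketched with a hypothetical helper that mirrors only the frame-slice component of the returned index tuples (assumed semantics: non-overlapping blocks advance by a whole block, overlapping blocks advance by one frame):

```python
def frame_slices(n_total, n_frames, stride=1, overlapping=False):
    """Sketch: frame slices for each block in a file with n_total frames."""
    span = n_frames * stride           # frames covered by one block
    step = 1 if overlapping else span  # block-to-block advance
    return [slice(start, start + span, stride)
            for start in range(0, n_total - span + 1, step)]

# 6 frames, blocks of 2: three disjoint blocks...
print(frame_slices(6, n_frames=2))
# [slice(0, 2, 1), slice(2, 4, 1), slice(4, 6, 1)]

# ...or five blocks that overlap by n_frames - 1:
print(len(frame_slices(6, n_frames=2, overlapping=True)))  # 5
```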