zea.data.dataloader
H5 dataloader for loading images from zea datasets.
Example
import zea

loader = zea.Dataloader(
    file_paths="/path/to/dataset",
    key="data/image",
    batch_size=16,
    image_range=(-60, 0),
    normalization_range=(0, 1),
    image_size=(256, 256),
    num_threads=16,
)
for batch in loader:
    # batch is a numpy array of shape (batch_size, 256, 256, 1)
    ...
Functions
- generate_h5_indices: Generate indices for h5 files.
Classes
- Dataloader: High-performance HDF5 dataloader built on Grain.
- H5DataSource: Thread-safe random-access data source for HDF5 files.
- class zea.data.dataloader.Dataloader(file_paths, key='data/image', batch_size=16, n_frames=1, shuffle=True, return_filename=False, seed=None, limit_n_samples=None, limit_n_frames=None, drop_remainder=False, image_size=None, resize_type=None, resize_axes=None, resize_kwargs=None, image_range=None, normalization_range=None, clip_image_range=False, assert_image_range=True, dataset_repetitions=None, cache=False, additional_axes_iter=None, sort_files=True, overlapping_blocks=False, augmentation=None, initial_frame_axis=0, insert_frame_axis=True, frame_index_stride=1, frame_axis=-1, validate=True, prefetch=True, shard_index=None, num_shards=1, num_threads=16, prefetch_buffer_size=500, **kwargs)
Bases: object
High-performance HDF5 dataloader built on Grain.
Data flow: grain threads (N) → h5py (thread-local handles) → numpy → user
The entire pipeline runs in NumPy: there is no framework dependency until you feed tensors to your model.
To load a dataset, the dataloader performs the following steps in order:
1. Find all .hdf5 files in the given directory(ies).
2. Load the data from each file using the specified key.
3. Apply the following transformations in order (where specified):
   - shuffle
   - shard
   - add channel dim
   - clip_image_range
   - assert_image_range
   - resize
   - repeat
   - batch
   - normalize
   - augmentation
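The value-range steps in this list (add channel dim, clip_image_range, assert_image_range, normalize) can be sketched in NumPy as below. This is a minimal illustration under two stated assumptions: normalization is taken to be a linear map from image_range onto normalization_range, and all steps are applied to a single array, whereas the loader normalizes only after resizing and batching. The helper name sketch_value_transforms is hypothetical and not part of zea.

```python
from typing import Optional, Tuple

import numpy as np

Range = Tuple[float, float]


def sketch_value_transforms(
    image: np.ndarray,
    image_range: Range,
    normalization_range: Optional[Range] = None,
    clip_image_range: bool = False,
    assert_image_range: bool = True,
) -> np.ndarray:
    """Hypothetical helper: channel dim, clip, range check, linear normalization."""
    lo, hi = image_range
    # "add channel dim": append a trailing axis of size 1, e.g. (h, w) -> (h, w, 1).
    image = image[..., np.newaxis]
    if clip_image_range:
        # Force every value into [lo, hi] before the range check.
        image = np.clip(image, lo, hi)
    if assert_image_range and (image.min() < lo or image.max() > hi):
        raise ValueError(f"values outside image_range {image_range}")
    if normalization_range is not None:
        a, b = normalization_range
        # Assumed linear map: lo -> a and hi -> b.
        image = (image - lo) / (hi - lo) * (b - a) + a
    return image
```

With image_range=(-60, 0) and normalization_range=(0, 1), a value of -30 dB lands at 0.5; without clipping, a value such as -70 fails the range check.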
- Parameters:
file_paths (Union[List[str], str]): Path(s) to directory(ies) and/or HDF5 file(s).
key (str): HDF5 dataset key. Default is "data/image".
batch_size (int | None): Batch size. Set to None to disable batching. Default is 16.
n_frames (int): Number of consecutive frames per sample. Default is 1. When n_frames > 1, frames are grouped into blocks.
shuffle (bool): Shuffle the dataset each epoch. Default is True.
return_filename (bool): Return filename metadata together with each sample. Default is False.
seed (int | None): Random seed used for shuffling. Default is None. If None and shuffle=True, a random seed is generated.
limit_n_samples (int | None): Limit the total number of samples (useful for debugging). Default is None (no limit).
limit_n_frames (int | None): Load only the first N frames of each file. Default is None (no limit).
drop_remainder (bool): Drop the final incomplete batch. Default is False.
image_size (tuple | None): Target (height, width). Default is None (no resizing).
resize_type (str | None): Resize strategy. One of "resize", "center_crop", "random_crop" or "crop_or_pad". Default is None, which resolves to "resize" when image_size is set.
resize_axes (tuple | None): Axes to resize along; must have length 2 (height, width). Only needed when the data has more than (h, w, c) dimensions. Axes are interpreted after frame-axis insertion/reordering. Default is None.
resize_kwargs (dict | None): Extra keyword arguments passed to Resizer. Default is None.
image_range (tuple | None): Source value range of the images, e.g. (-60, 0). Used for clipping, asserting and normalization. Default is None.
normalization_range (tuple | None): Target value range, e.g. (0, 1). If set, image_range must also be set. Default is None.
clip_image_range (bool): Clip values to image_range before normalization. Default is False.
assert_image_range (bool): Assert that values stay within image_range. Default is True.
dataset_repetitions (int | None): Repeat the dataset this many times. Repetition happens after sharding. Default is None (no repetition).
cache (bool): Cache loaded samples in RAM. Default is False. Note that with overlapping_blocks=True the same frame can be part of multiple samples, so caching consumes more memory.
additional_axes_iter (tuple | None): Additional axes to iterate over, besides initial_frame_axis. Default is None.
sort_files (bool): Sort files numerically before indexing. Default is True.
overlapping_blocks (bool): If True, frame blocks overlap by n_frames - 1. Has no effect when n_frames == 1. Default is False.
augmentation (callable | None): Callable applied to each batch after normalization. Default is None.
initial_frame_axis (int): Axis in the file data that represents frames. Default is 0.
insert_frame_axis (bool): If True, keep per-frame samples and move/insert the frame dimension at frame_axis. If False, loaded frames are concatenated along frame_axis. Default is True.
frame_index_stride (int): Step between selected frames in a block. Default is 1.
frame_axis (int): Axis along which frames are stacked/placed in the output. Default is -1.
validate (bool): Validate discovered files against the zea format. Default is True.
prefetch (bool): Enable Grain prefetching during iteration. Default is True.
shard_index (int | None): Shard index to select when num_shards > 1. Must satisfy 0 <= shard_index < num_shards. Default is None.
num_shards (int): Total number of shards for distributed loading. Sharding happens before downstream transforms. Default is 1.
num_threads (int): Number of Grain read threads (0 means main thread only). Default is 16.
prefetch_buffer_size (int): Size of the Grain buffer for reading elements, per Python process (not per thread). Useful when reading from a distributed file system. Default is 500.
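The interaction of shard_index, num_shards, dataset_repetitions, batch_size and drop_remainder can be illustrated on plain sample indices. This pure-Python sketch assumes strided sharding (shard k keeps samples k, k + num_shards, ...), which may differ from how zea or Grain actually assigns samples to shards, and it leaves out shuffling. The function name sketch_index_pipeline is hypothetical.

```python
from typing import List, Optional


def sketch_index_pipeline(
    n_samples: int,
    shard_index: int = 0,
    num_shards: int = 1,
    dataset_repetitions: Optional[int] = None,
    batch_size: int = 16,
    drop_remainder: bool = False,
) -> List[List[int]]:
    """Hypothetical shard -> repeat -> batch pipeline over sample indices."""
    if not 0 <= shard_index < num_shards:
        raise ValueError("need 0 <= shard_index < num_shards")
    # Assumed strided sharding: shard k keeps indices k, k + num_shards, ...
    indices = list(range(n_samples))[shard_index::num_shards]
    # Repetition is applied after sharding, to this shard's indices only.
    repetitions = 1 if dataset_repetitions is None else dataset_repetitions
    indices = indices * repetitions
    # Group consecutive indices into batches; the last one may be shorter.
    batches = [indices[i : i + batch_size] for i in range(0, len(indices), batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches
```

Because repetition only duplicates each shard's own indices, the shards stay disjoint even when dataset_repetitions > 1.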
Example
from zea.data.dataloader import Dataloader

loader = Dataloader(
    file_paths="/data/camus",
    key="data/image_sc",
    batch_size=32,
    image_range=(-60, 0),
    normalization_range=(0, 1),
    image_size=(256, 256),
)
for batch in loader:
    ...  # batch.shape == (32, 256, 256, 1)
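The frame-related parameters (n_frames, frame_axis, insert_frame_axis) are easiest to picture on a single block of frames. The NumPy sketch below is an interpretation of the parameter descriptions, not zea's implementation: it assumes the block has already been read as an array of shape (n_frames, height, width) and ignores the channel dimension the loader adds. place_frames is a hypothetical name.

```python
import numpy as np


def place_frames(
    frames: np.ndarray, frame_axis: int = -1, insert_frame_axis: bool = True
) -> np.ndarray:
    """Hypothetical helper; `frames` has shape (n_frames, height, width)."""
    if insert_frame_axis:
        # Keep frames as their own dimension and move it to `frame_axis`.
        return np.moveaxis(frames, 0, frame_axis)
    # Otherwise join the individual (height, width) frames along `frame_axis`.
    return np.concatenate(list(frames), axis=frame_axis)
```

For a block of 3 frames of 4 x 5 pixels, the default settings give shape (4, 5, 3), while insert_frame_axis=False with frame_axis=-1 gives (4, 15).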
- property dataset
The underlying grain.MapDataset.
- class zea.data.dataloader.H5DataSource(file_paths, key='data/image', n_frames=1, frame_index_stride=1, frame_axis=-1, insert_frame_axis=True, initial_frame_axis=0, additional_axes_iter=None, sort_files=True, overlapping_blocks=False, limit_n_samples=None, limit_n_frames=None, return_filename=False, cache=False, validate=True, **kwargs)
Bases: object
Thread-safe random-access data source for HDF5 files.
Implements the grain.RandomAccessDataSource protocol (__getitem__ and __len__), so it can be plugged directly into a grain.MapDataset pipeline. Each worker thread gets its own H5FileHandleCache via threading.local(), so h5py file handles are never shared across threads.
- Parameters:
file_paths (Union[List[str], str]): Path(s) to HDF5 directory(ies) or file(s).
key (str): HDF5 dataset key, e.g. "data/image".
n_frames (int): Number of consecutive frames per sample.
frame_index_stride (int): Stride between frames.
frame_axis (int): Axis along which frames are stacked in the output.
insert_frame_axis (bool): Whether to insert a new axis for frames.
initial_frame_axis (int): Source axis that stores frames in the file.
additional_axes_iter (tuple | None): Extra axes to iterate over.
sort_files (bool): Sort files numerically.
overlapping_blocks (bool): Allow overlapping frame blocks.
limit_n_samples (int | None): Cap the number of samples.
limit_n_frames (int | None): Cap the number of frames loaded per file.
return_filename (bool): Return filename metadata with each sample.
cache (bool): Cache loaded samples in RAM.
validate (bool): Validate the dataset against the zea format.
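The per-thread handle pattern described above can be sketched with the standard library alone. ThreadLocalHandleCache and SketchDataSource are hypothetical stand-ins, not zea's H5FileHandleCache or H5DataSource. The file opener is injected as open_fn so the sketch runs without h5py; with h5py installed it would be something like lambda p: h5py.File(p, "r"), and handle[key][slices] would then read the selected hyperslab from disk. Entries follow the (file_name, key, indices) layout returned by generate_h5_indices.

```python
import threading
from typing import Any, Callable, Dict, List, Sequence, Tuple

Entry = Tuple[str, str, Tuple[slice, ...]]


class ThreadLocalHandleCache:
    """Hypothetical per-thread cache of open file handles."""

    def __init__(self, open_fn: Callable[[str], Any]) -> None:
        self._open_fn = open_fn
        self._local = threading.local()

    def get(self, path: str) -> Any:
        # threading.local gives each thread its own `handles` dict,
        # so a handle opened in one thread is never used by another.
        handles: Dict[str, Any] = getattr(self._local, "handles", None)
        if handles is None:
            handles = self._local.handles = {}
        if path not in handles:
            handles[path] = self._open_fn(path)
        return handles[path]


class SketchDataSource:
    """Hypothetical random-access source exposing __len__ and __getitem__."""

    def __init__(self, entries: Sequence[Entry], open_fn: Callable[[str], Any]) -> None:
        self._entries: List[Entry] = list(entries)
        self._cache = ThreadLocalHandleCache(open_fn)

    def __len__(self) -> int:
        return len(self._entries)

    def __getitem__(self, index: int) -> Any:
        path, key, slices = self._entries[index]
        # Reuse (or lazily open) this thread's handle, then read the selection.
        handle = self._cache.get(path)
        return handle[key][slices]
```

Opening handles lazily per thread avoids sharing one h5py handle between reader threads, at the cost of one open file per thread per file.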
- zea.data.dataloader.generate_h5_indices(file_paths, file_shapes, n_frames, frame_index_stride, key='data/image', initial_frame_axis=0, additional_axes_iter=None, sort_files=True, overlapping_blocks=False, limit_n_frames=None)
Generate indices for h5 files.
Generates a list of indices for extracting images from HDF5 files. The length of this list equals the number of samples in the extracted dataset.
- Parameters:
file_paths (List[str]): List of file paths.
file_shapes (list): List of file shapes.
n_frames (int): Number of consecutive frames per sample to load from each HDF5 file.
frame_index_stride (int): Interval between frames to load.
key (str): Key of the HDF5 dataset to read data from. Defaults to "data/image".
initial_frame_axis (int): Axis to iterate over. Defaults to 0.
additional_axes_iter (Optional[List[int]]): Additional axes to iterate over in the dataset. Defaults to None.
sort_files (bool): Sort files by number. Defaults to True.
overlapping_blocks (bool): If True, take n_frames from the sequence, then move the window forward by 1 frame. Defaults to False.
limit_n_frames (int | None): Limit the number of frames loaded from each file; only the first limit_n_frames frames of each file are used. Defaults to None.
- Returns:
- List of tuples with indices for extracting images from HDF5 files. Each tuple has the form (file_name, key, indices), where indices is a tuple of slices.
- Return type:
list
Example
[
    (
        "/folder/path_to_file.hdf5",
        "data/image",
        (slice(0, 1, 1), slice(None, 256, None), slice(None, 256, None)),
    ),
    (
        "/folder/path_to_file.hdf5",
        "data/image",
        (slice(1, 2, 1), slice(None, 256, None), slice(None, 256, None)),
    ),
    ...,
]
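To make the index layout concrete, here is a minimal re-creation under stated assumptions; it is not zea's generate_h5_indices. It assumes that a block of n_frames frames is the slice slice(start, start + n_frames * frame_index_stride, frame_index_stride), that non-overlapping blocks advance by a whole block while overlapping blocks advance by one frame, and that every other axis is read in full as slice(None, dim, None). It ignores additional_axes_iter and sort_files. sketch_h5_indices is a hypothetical name.

```python
from typing import List, Optional, Sequence, Tuple

# One entry per sample: (file path, dataset key, tuple of slices).
IndexEntry = Tuple[str, str, Tuple[slice, ...]]


def sketch_h5_indices(
    file_paths: Sequence[str],
    file_shapes: Sequence[Tuple[int, ...]],
    n_frames: int,
    frame_index_stride: int,
    key: str = "data/image",
    initial_frame_axis: int = 0,
    overlapping_blocks: bool = False,
    limit_n_frames: Optional[int] = None,
) -> List[IndexEntry]:
    """Hypothetical re-creation of the index layout (no additional_axes_iter)."""
    entries: List[IndexEntry] = []
    for path, shape in zip(file_paths, file_shapes):
        n_available = shape[initial_frame_axis]
        if limit_n_frames is not None:
            # Only the first limit_n_frames frames of each file are considered.
            n_available = min(n_available, limit_n_frames)
        # Offset of the last frame in a block relative to its first frame.
        last_offset = (n_frames - 1) * frame_index_stride
        # Overlapping blocks advance by one frame; otherwise by a whole block.
        step = 1 if overlapping_blocks else n_frames * frame_index_stride
        for start in range(0, n_available - last_offset, step):
            # Selects exactly n_frames frames: start, start + stride, ...
            stop = start + n_frames * frame_index_stride
            frame_slice = slice(start, stop, frame_index_stride)
            # Non-frame axes are read in full, written as slice(None, dim, None).
            slices = [slice(None, dim, None) for dim in shape]
            slices[initial_frame_axis] = frame_slice
            entries.append((path, key, tuple(slices)))
    return entries
```

Under these assumptions, a single file of shape (3, 256, 256) with n_frames=1 and frame_index_stride=1 yields three entries, the first two of which match the example list.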