Source code for zea.data.dataloader

"""H5 dataloader for loading images from zea datasets.

Example:
    .. code-block:: python

        import zea

        loader = zea.Dataloader(
            file_paths="/path/to/dataset",
            key="data/image",
            batch_size=16,
            image_range=(-60, 0),
            normalization_range=(0, 1),
            image_size=(256, 256),
            num_threads=16,
        )

        for batch in loader:
            # batch is a numpy array of shape (batch_size, 256, 256, 1)
            ...
"""

import re
import threading
from itertools import product
from pathlib import Path
from typing import List

import grain
import keras
import numpy as np

from zea import log
from zea.data.datasets import Dataset, H5FileHandleCache, count_samples_per_directory
from zea.data.layers import Resizer
from zea.utils import canonicalize_axis, map_negative_indices

DEFAULT_NORMALIZATION_RANGE = (0, 1)


[docs] def generate_h5_indices( file_paths: List[str], file_shapes: list, n_frames: int, frame_index_stride: int, key: str = "data/image", initial_frame_axis: int = 0, additional_axes_iter: List[int] | None = None, sort_files: bool = True, overlapping_blocks: bool = False, limit_n_frames: int | None = None, ): """Generate indices for h5 files. Generates a list of indices to extract images from hdf5 files. Length of this list is the length of the extracted dataset. Args: file_paths (list): List of file paths. file_shapes (list): List of file shapes. n_frames (int): Number of frames to load from each hdf5 file. frame_index_stride (int): Interval between frames to load. key (str, optional): Key of hdf5 dataset to grab data from. Defaults to "data/image". initial_frame_axis (int, optional): Axis to iterate over. Defaults to 0. additional_axes_iter (list, optional): Additional axes to iterate over in the dataset. Defaults to None. sort_files (bool, optional): Sort files by number. Defaults to True. overlapping_blocks (bool, optional): Will take n_frames from sequence, then move by 1. Defaults to False. limit_n_frames (int, optional): Limit the number of frames to load from each file. This means n_frames per data file will be used. These will be the first frames in the file. Defaults to None. Returns: list: List of tuples with indices to extract images from hdf5 files. (file_name, key, indices) with indices being a tuple of slices. Example: .. code-block:: python [ ( "/folder/path_to_file.hdf5", "data/image", (slice(0, 1, 1), slice(None, 256, None), slice(None, 256, None)), ), ( "/folder/path_to_file.hdf5", "data/image", (slice(1, 2, 1), slice(None, 256, None), slice(None, 256, None)), ), ..., ] """ if limit_n_frames is None: limit_n_frames = np.inf else: assert limit_n_frames > 0, f"limit_n_frames must be > 0, got {limit_n_frames}" assert len(file_paths) == len(file_shapes), "file_paths and file_shapes must have same length" if additional_axes_iter: # cannot contain initial_frame_axis assert initial_frame_axis not in additional_axes_iter, ( "initial_frame_axis cannot be in additional_axes_iter. " "We are already iterating over that axis." ) else: additional_axes_iter = [] if sort_files: try: # this is like an np.argsort, returns the indices that would sort the array indices_sorting_file_paths = sorted( range(len(file_paths)), key=lambda i: int(re.findall(r"\d+", file_paths[i])[-2]), ) file_paths = [file_paths[i] for i in indices_sorting_file_paths] file_shapes = [file_shapes[i] for i in indices_sorting_file_paths] except Exception: log.warning("Could not sort file_paths by number.") # block size with stride included block_size = n_frames * frame_index_stride if not overlapping_blocks: block_step_size = block_size else: # now blocks overlap by n_frames - 1 block_step_size = 1 def axis_indices_files(): # For every file for shape in file_shapes: n_frames_in_file = shape[initial_frame_axis] # Optionally limit frames to load from each file n_frames_in_file = min(n_frames_in_file, limit_n_frames) indices = [ slice(i, i + block_size, frame_index_stride) for i in range(0, n_frames_in_file - block_size + 1, block_step_size) ] yield [indices] indices = [] skipped_files = 0 for file, shape, axis_indices in zip(file_paths, file_shapes, list(axis_indices_files())): # remove all the files that have empty list at initial_frame_axis # this can happen if the file is too small to fit a block if not axis_indices[0]: # initial_frame_axis is the first entry in axis_indices skipped_files += 1 continue if additional_axes_iter: axis_indices += [list(range(shape[axis])) for axis in additional_axes_iter] axis_indices = product(*axis_indices) for axis_index in axis_indices: full_indices = [slice(size) for size in shape] for i, axis in enumerate([initial_frame_axis] + list(additional_axes_iter)): full_indices[axis] = axis_index[i] indices.append((file, key, tuple(full_indices))) if skipped_files > 0: log.warning( f"Skipping {skipped_files} files with not enough frames " f"which is about {skipped_files / len(file_paths) * 100:.2f}% of the " f"dataset. This can be fine if you expect set `n_frames` and " "`frame_index_stride` to be high. Minimum frames in a file needs to be at " f"least n_frames * frame_index_stride = {n_frames * frame_index_stride}. " ) return indices
[docs] class H5DataSource: """Thread-safe random-access data source for HDF5 files. Implements ``grain.RandomAccessDataSource`` protocol (``__getitem__`` and ``__len__``) so it can be plugged directly into a ``grain.MapDataset`` pipeline. Each worker thread gets its own ``H5FileHandleCache`` via ``threading.local()`` so ``h5py`` file handles are never shared across threads. Args: file_paths: Path(s) to HDF5 directory(ies) or file(s). key: HDF5 dataset key, e.g. ``"data/image"``. n_frames: Number of consecutive frames per sample. frame_index_stride: Stride between frames. frame_axis: Axis along which frames are stacked in the output. insert_frame_axis: Whether to insert a new axis for frames. initial_frame_axis: Source axis that stores frames in the file. additional_axes_iter: Extra axes to iterate over. sort_files: Sort files numerically. overlapping_blocks: Allow overlapping frame blocks. limit_n_samples: Cap the number of samples. limit_n_frames: Cap frames loaded per file. return_filename: Return filename metadata with each sample. cache: Cache loaded samples to RAM. validate: Validate dataset against the zea format. """ def __init__( self, file_paths: List[str] | str, key: str = "data/image", n_frames: int = 1, frame_index_stride: int = 1, frame_axis: int = -1, insert_frame_axis: bool = True, initial_frame_axis: int = 0, additional_axes_iter: tuple | None = None, sort_files: bool = True, overlapping_blocks: bool = False, limit_n_samples: int | None = None, limit_n_frames: int | None = None, return_filename: bool = False, cache: bool = False, validate: bool = True, **kwargs, ): self.return_filename = return_filename self.cache = cache self._data_cache = {} self.key = key self.n_frames = int(n_frames) self.frame_index_stride = int(frame_index_stride) self.frame_axis = int(frame_axis) self.insert_frame_axis = insert_frame_axis assert self.frame_index_stride > 0, ( f"`frame_index_stride` must be > 0, got {self.frame_index_stride}" ) assert self.n_frames > 0, f"`n_frames` must be > 0, got {self.n_frames}" # Discover files and shapes (reuses Dataset machinery) _dataset = Dataset(file_paths, validate=validate, **kwargs) self.file_paths = _dataset.file_paths self.file_shapes = _dataset.load_file_shapes(key) _dataset.close() num_dims = len(self.file_shapes[0]) self.initial_frame_axis = canonicalize_axis(int(initial_frame_axis), num_dims) self.additional_axes_iter = map_negative_indices(list(additional_axes_iter or []), num_dims) # Compute per-sample index table self.indices = generate_h5_indices( file_paths=self.file_paths, file_shapes=self.file_shapes, n_frames=self.n_frames, frame_index_stride=self.frame_index_stride, key=self.key, initial_frame_axis=self.initial_frame_axis, additional_axes_iter=self.additional_axes_iter, sort_files=sort_files, overlapping_blocks=overlapping_blocks, limit_n_frames=limit_n_frames, ) if limit_n_samples is not None: log.info(f"H5DataSource: Limiting to {limit_n_samples} / {len(self.indices)} samples.") self.indices = self.indices[:limit_n_samples] # Thread-local file handle caches (one per thread) self._local = threading.local() self._all_caches: set[H5FileHandleCache] = set() self._all_caches_lock = threading.Lock() def __len__(self) -> int: return len(self.indices) def __getitem__(self, index: int): """Return a single sample as a numpy array. Thread-safe.""" if self.cache and index in self._data_cache: return self._data_cache[index] file_name, key, indices = self.indices[index] file_handle_cache = self._get_file_handle_cache() file = file_handle_cache.get_file(file_name) try: images = file.load_data(key, indices) except (OSError, IOError): # Invalidate cache entry and retry once file_handle_cache.pop(file_name) file = file_handle_cache.get_file(file_name) images = file.load_data(key, indices) if self.insert_frame_axis: initial = self.initial_frame_axis if self.additional_axes_iter: initial -= sum(ax < self.initial_frame_axis for ax in self.additional_axes_iter) images = np.moveaxis(images, initial, self.frame_axis) else: images = np.concatenate(images, axis=self.frame_axis) if self.return_filename: file_data = { "fullpath": file.filename, # same as file.path, but str type "filename": file.stem, "indices": indices, } result = (images, file_data) else: result = images if self.cache: self._data_cache[index] = result return result def __repr__(self) -> str: return ( f"H5DataSource(n_samples={len(self)}, n_files={len(self.file_paths)}, key='{self.key}')" ) def _get_file_handle_cache(self) -> H5FileHandleCache: """Return the file-handle cache for the current thread.""" if not hasattr(self._local, "cache"): self._local.cache = H5FileHandleCache() with self._all_caches_lock: self._all_caches.add(self._local.cache) return self._local.cache
[docs] def close(self): """Close all file handles across all threads.""" with self._all_caches_lock: for c in self._all_caches: c.close() self._all_caches.clear()
[docs] class Dataloader: """High-performance HDF5 dataloader built on `Grain <https://github.com/google/grain>`_. .. code-block:: text grain threads (N) → h5py (thread-local handles) → numpy → user The entire pipeline runs in numpy — no framework dependency until you feed tensors to your model. Does the following in order to load a dataset: - Find all .hdf5 files in the director(ies) - Load the data from each file using the specified key - Apply the following transformations in order (if specified): - shuffle - shard - add channel dim - clip_image_range - assert_image_range - resize - repeat - batch - normalize - augmentation Args: file_paths: Path(s) to directory(ies) and/or HDF5 file(s). key: HDF5 dataset key. Default is ``"data/image"``. batch_size: Batch size. Set to ``None`` to disable batching. Default is ``16``. n_frames: Number of consecutive frames per sample. Default is ``1``. When ``n_frames > 1``, frames are grouped into blocks. shuffle: Shuffle dataset each epoch. Default is ``True``. return_filename: Return filename metadata together with each sample. Default is ``False``. seed: Random seed used for shuffling. Default is ``None``. If ``None`` and ``shuffle=True``, a random seed is generated. limit_n_samples: Limit total number of samples (useful for debugging). Default is ``None`` (no limit). limit_n_frames: Limit frames loaded per file to the first N frames. Default is ``None`` (no limit). drop_remainder: Drop the final incomplete batch. Default is ``False``. image_size: Target ``(height, width)``. Default is ``None`` (no resizing). resize_type: Resize strategy. One of ``"resize"``, ``"center_crop"``, ``"random_crop"`` or ``"crop_or_pad"``. Default is ``None``, which resolves to ``"resize"`` when `image_size` is set. resize_axes: Axes to resize along, must have length 2 (height, width). Only needed when data has more than ``(h, w, c)`` dimensions. Axes are interpreted after frame-axis insertion/reordering. Default is ``None``. resize_kwargs: Extra keyword arguments passed to ``Resizer``. Default is ``None``. image_range: Source value range of images, e.g. ``(-60, 0)``. Used for clipping/asserting/normalization. Default is ``None``. normalization_range: Target value range, e.g. ``(0, 1)``. If set, ``image_range`` must also be set. Default is ``None``. clip_image_range: Clip values to ``image_range`` before normalization. Default is ``False``. assert_image_range: Assert values stay within ``image_range``. Default is ``True``. dataset_repetitions: Repeat dataset this many times. Repetition happens after sharding. Default is ``None`` (no repetition). cache: Cache loaded samples in RAM. Default is ``False``. Note that with ``overlapping_blocks=True``, the same frame can be part of multiple samples, so caching will consume more memory. additional_axes_iter: Additional axes to iterate over in addition to ``initial_frame_axis``. Default is ``None``. sort_files: Sort files numerically before indexing. Default is ``True``. overlapping_blocks: If ``True``, frame blocks overlap by ``n_frames - 1``. Has no effect when ``n_frames == 1``. Default is ``False``. augmentation: Callable applied to each batch after normalization. Default is ``None``. initial_frame_axis: Axis in file data that represents frames. Default is ``0``. insert_frame_axis: If ``True``, keep per-frame samples and move/insert the frame dimension at ``frame_axis``. If ``False``, loaded frames are concatenated along ``frame_axis``. Default is ``True``. frame_index_stride: Step between selected frames in a block. Default is ``1``. frame_axis: Axis along which frames are stacked/placed in output. Default is ``-1``. validate: Validate discovered files against the zea format. Default is ``True``. prefetch: Enable Grain prefetching for iteration. Default is ``True``. shard_index: Shard index to select when ``num_shards > 1``. Must satisfy ``0 <= shard_index < num_shards``. num_shards: Total number of shards for distributed loading. Sharding happens before downstream transforms. Default is ``1``. num_threads: Number of Grain read threads (``0`` means main thread only). Default is ``16``. prefetch_buffer_size: Size of the Grain buffer for reading elements per Python process (not per thread). Useful when reading from a distributed file system. Default is ``500``. Example: .. code-block:: python loader = Dataloader( file_paths="/data/camus", key="data/image_sc", batch_size=32, image_range=(-60, 0), normalization_range=(0, 1), image_size=(256, 256), ) for batch in loader: ... # batch.shape == (32, 256, 256, 1) """ def __init__( self, file_paths: List[str] | str, key: str = "data/image", batch_size: int | None = 16, n_frames: int = 1, shuffle: bool = True, return_filename: bool = False, seed: int | None = None, limit_n_samples: int | None = None, limit_n_frames: int | None = None, drop_remainder: bool = False, image_size: tuple | None = None, resize_type: str | None = None, resize_axes: tuple | None = None, resize_kwargs: dict | None = None, image_range: tuple | None = None, normalization_range: tuple | None = None, clip_image_range: bool = False, assert_image_range: bool = True, dataset_repetitions: int | None = None, cache: bool = False, additional_axes_iter: tuple | None = None, sort_files: bool = True, overlapping_blocks: bool = False, augmentation: callable = None, initial_frame_axis: int = 0, insert_frame_axis: bool = True, frame_index_stride: int = 1, frame_axis: int = -1, validate: bool = True, prefetch: bool = True, shard_index: int | None = None, num_shards: int = 1, num_threads: int = 16, prefetch_buffer_size: int = 500, **kwargs, ): # ── Validation ──────────────────────────────────────────────── if normalization_range is not None: assert image_range is not None, ( "If normalization_range is set, image_range must be set too." ) if num_shards > 1: assert shard_index is not None, "shard_index must be specified" assert 0 <= shard_index < num_shards resize_kwargs = resize_kwargs or {} # ── Store config ────────────────────────────────────────────── self.batch_size = batch_size self.return_filename = return_filename self.num_threads = num_threads self.prefetch_buffer_size = prefetch_buffer_size self.prefetch = prefetch self.shuffle = shuffle # Grain requires a concrete seed for shuffle — generate one if needed if seed is None and shuffle: seed = int(np.random.default_rng().integers(0, 2**31)) self.seed = seed self._rng = np.random.default_rng(seed) # ── Data source ─────────────────────────────────────────────── self.source = H5DataSource( file_paths=file_paths, key=key, n_frames=n_frames, frame_index_stride=frame_index_stride, frame_axis=frame_axis, insert_frame_axis=insert_frame_axis, initial_frame_axis=initial_frame_axis, additional_axes_iter=additional_axes_iter, sort_files=sort_files, overlapping_blocks=overlapping_blocks, limit_n_samples=limit_n_samples, limit_n_frames=limit_n_frames, return_filename=return_filename, cache=cache, validate=validate, **kwargs, ) # ── Store pipeline config for rebuilding per epoch ──────────── self._pipeline_cfg = dict( num_shards=num_shards, shard_index=shard_index, clip_image_range=clip_image_range, assert_image_range=assert_image_range, image_range=image_range, normalization_range=normalization_range, dataset_repetitions=dataset_repetitions, drop_remainder=drop_remainder, augmentation=augmentation, resizer=None, ) # Pre-build the resizer (stateless, reusable across epochs) if image_size or resize_type: resize_type = resize_type or "resize" if frame_axis != -1: assert resize_axes is not None, ( "Resizing only works with frame_axis = -1. Alternatively, " "you can specify resize_axes." ) self._pipeline_cfg["resizer"] = Resizer( image_size=image_size, resize_type=resize_type, resize_axes=resize_axes, seed=seed, **resize_kwargs, ) self._map_dataset = self._build_pipeline(seed) def _build_pipeline(self, seed: int): """Build the Grain MapDataset pipeline with the given shuffle seed.""" cfg = self._pipeline_cfg def _ds_map(ds, fn): if self.return_filename: return ds.map(lambda item: (fn(item[0]), item[1])) return ds.map(fn) ds = grain.MapDataset.source(self.source) if self.shuffle: ds = ds.shuffle(seed=seed) if cfg["num_shards"] > 1: ds = ds[cfg["shard_index"] :: cfg["num_shards"]] ds = _ds_map(ds, self._ensure_channel_dim) if cfg["clip_image_range"] and cfg["image_range"] is not None: lo, hi = cfg["image_range"] ds = _ds_map(ds, lambda x, _lo=lo, _hi=hi: keras.ops.clip(x, _lo, _hi)) if cfg["assert_image_range"] and cfg["image_range"] is not None: _ir = cfg["image_range"] ds = _ds_map(ds, lambda x, _r=_ir: Dataloader._assert_image_range(x, _r)) if cfg["resizer"] is not None: ds = _ds_map(ds, cfg["resizer"]) if cfg["dataset_repetitions"] is not None: ds = ds.repeat(num_epochs=cfg["dataset_repetitions"]) if self.batch_size is not None: ds = ds.batch(batch_size=self.batch_size, drop_remainder=cfg["drop_remainder"]) if cfg["normalization_range"] is not None: _ir, _nr = cfg["image_range"], cfg["normalization_range"] ds = _ds_map(ds, lambda x, _a=_ir, _b=_nr: Dataloader._normalize(x, _a, _b)) if cfg["augmentation"] is not None: ds = _ds_map(ds, cfg["augmentation"]) return ds @property def dataset(self): """The underlying ``grain.MapDataset``.""" return self._map_dataset
[docs] def to_iter_dataset(self): """Convert to a ``grain.IterDataset`` with prefetching. This is called automatically when you iterate, but you can call it explicitly if you want to hold onto the ``IterDataset`` object. """ return self._map_dataset.to_iter_dataset( grain.ReadOptions( num_threads=self.num_threads, prefetch_buffer_size=self.prefetch_buffer_size if self.prefetch else 0, ) )
def __iter__(self): # Rebuild the pipeline with a fresh seed so each epoch sees a different order if self.shuffle: self._map_dataset = self._build_pipeline(seed=int(self._rng.integers(0, 2**31))) return iter(self.to_iter_dataset()) def __len__(self): """Number of batches (or samples if unbatched).""" return len(self._map_dataset) def __repr__(self): return ( f"<Dataloader: {len(self.source)} samples, " f"batch_size={self.batch_size}, " f"key='{self.source.key}', " f"threads={self.num_threads}>" ) @staticmethod def _ensure_channel_dim(image): """Ensure at least 3-D (H, W, C) so batching produces uniform shapes.""" if len(keras.ops.shape(image)) < 3: return keras.ops.expand_dims(image, axis=-1) return image @staticmethod def _assert_image_range(image, image_range): """Assert that image values are within the specified range.""" minval = float(keras.ops.min(image)) maxval = float(keras.ops.max(image)) if minval < image_range[0]: raise ValueError( f"Image min {minval} is below image_range lower bound {image_range[0]}" ) if maxval > image_range[1]: raise ValueError( f"Image max {maxval} is above image_range upper bound {image_range[1]}" ) return image @staticmethod def _normalize(image, image_range, normalization_range): """Normalize image from image_range to normalization_range.""" left_min, left_max = image_range right_min, right_max = normalization_range scale = (right_max - right_min) / (left_max - left_min) offset = right_min - scale * left_min return keras.ops.add(keras.ops.multiply(image, scale), offset)
[docs] def summary(self): """Print dataset statistics and per-directory breakdown.""" src = self.source total_samples = len(src) file_names = [idx[0] for idx in src.indices] directories = sorted({str(Path(f).parent) for f in file_names}) samples_per_dir = count_samples_per_directory(file_names, directories) parts = [f"Dataloader with {total_samples} total samples:"] for dir_path, count in samples_per_dir.items(): pct = (count / total_samples) * 100 if total_samples else 0 parts.append(f" {dir_path}: {count} samples ({pct:.1f}%)") print("\n".join(parts))
[docs] def close(self): """Release file handles.""" self.source.close()