"""Ultrasound signal processing functions (``zea.func.ultrasound``)."""

import keras
import numpy as np
import scipy.signal
from keras import ops

from zea import log
from zea.func import split_seed
from zea.func.tensor import (
    resample,
    split_into_windows,
)


def demodulate_not_jitable(
    rf_data,
    sampling_frequency=None,
    demodulation_frequency=None,
    bandwidth=None,
    filter_coeff=None,
):
    """Demodulates an RF signal to complex base-band (IQ).

    Demodulates the radiofrequency (RF) bandpass signals and returns the
    Inphase/Quadrature (I/Q) components. IQ is a complex whose real (imaginary)
    part contains the in-phase (quadrature) component.

    This function operates (i.e. demodulates) on the RF signal over the
    (fast-) time axis which is assumed to be the second to last axis.

    Args:
        rf_data (ndarray): real valued input array of size [..., n_ax, n_el].
            second to last axis is fast-time axis.
        sampling_frequency (float): the sampling frequency of the RF signals (in Hz).
            Only not necessary when filter_coeff is provided.
        demodulation_frequency (float, optional): Modulation frequency (in Hz).
            If None, it is estimated from the power spectrum of the data.
        bandwidth (float, optional): Bandwidth of RF signal in % of center frequency.
            Defaults to None. The bandwidth in % is defined by:
            B = Bandwidth_in_% = Bandwidth_in_Hz*(100/center_frequency).
            The cutoff frequency: Wn = Bandwidth_in_Hz/sampling_frequency, i.e:
            Wn = B*(center_frequency/100)/sampling_frequency.
        filter_coeff (list, optional): (b, a), numerator and denominator coefficients
            of FIR filter for quadratic band pass filter. All other parameters are
            ignored if filter_coeff are provided. Instead the given filter_coeff is
            directly used. If not provided, a filter is derived from the other params
            (sampling_frequency, center_frequency, bandwidth). see
            https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter.html

    Returns:
        iq_data (ndarray): complex valued base-band signal.
    """
    # Work in numpy: scipy filtering below is not jitable / backend-agnostic.
    rf_data = ops.convert_to_numpy(rf_data)
    assert np.isreal(rf_data).all(), f"RF must contain real RF signals, got {rf_data.dtype}"

    input_shape = rf_data.shape
    n_dim = len(input_shape)
    if n_dim > 2:
        *_, n_ax, n_el = input_shape
    else:
        n_ax, n_el = input_shape

    if filter_coeff is None:
        assert sampling_frequency is not None, "provide sampling_frequency when no filter is given."
        # Time vector (t0 kept for a potential acquisition-start offset)
        t = np.arange(n_ax) / sampling_frequency
        t0 = 0
        t = t + t0

        # Estimate center frequency from the data if not given
        if demodulation_frequency is None:
            # Keep a maximum of 100 randomly selected scanlines
            idx = np.arange(n_el)
            if n_el > 100:
                idx = np.random.permutation(idx)[:100]
            # Power Spectrum, summed over the selected scanlines
            P = np.sum(
                np.abs(np.fft.fft(np.take(rf_data, idx, axis=-1), axis=-2)) ** 2,
                axis=-1,
            )
            # Keep only positive frequencies.
            # NOTE(review): for n_dim > 2 this slices the leading batch axis,
            # not the frequency axis — presumably only 2-D input reaches this
            # branch; confirm, otherwise this should be P[..., : n_ax // 2].
            P = P[: n_ax // 2]
            # Carrier frequency = spectral centroid of the power spectrum
            idx = np.sum(np.arange(n_ax // 2) * P) / np.sum(P)
            demodulation_frequency = idx * sampling_frequency / n_ax

        # Normalized cut-off frequency
        if bandwidth is None:
            Wn = min(2 * demodulation_frequency / sampling_frequency, 0.5)
            bandwidth = demodulation_frequency * Wn
        else:
            assert np.isscalar(bandwidth), "The signal bandwidth (in %) must be a scalar."
            assert (bandwidth > 0) & (bandwidth <= 200), (
                "The signal bandwidth (in %) must be within the interval of ]0,200]."
            )
            # bandwidth in Hz
            bandwidth = demodulation_frequency * bandwidth / 100
            Wn = bandwidth / sampling_frequency
        assert (Wn > 0) & (Wn <= 1), (
            "The normalized cutoff frequency is not within the interval of (0,1). "
            "Check the input parameters!"
        )

        # Down-mixing of the RF signals
        carrier = np.exp(-1j * 2 * np.pi * demodulation_frequency * t)
        # add the singleton dimensions so the carrier broadcasts over [..., n_ax, n_el]
        carrier = np.reshape(carrier, (*[1] * (n_dim - 2), n_ax, 1))
        iq_data = rf_data * carrier

        # Low-pass filter (zero-phase, applied along fast-time)
        N = 5
        b, a = scipy.signal.butter(N, Wn, "low")
        # factor 2: to preserve the envelope amplitude
        iq_data = scipy.signal.filtfilt(b, a, iq_data, axis=-2) * 2

        # Display a warning message if harmful aliasing is suspected
        # the RF signal is undersampled
        if sampling_frequency < (2 * demodulation_frequency + bandwidth):
            # lower and higher frequencies of the bandpass signal
            fL = demodulation_frequency - bandwidth / 2
            fH = demodulation_frequency + bandwidth / 2
            n = fH // (fH - fL)
            # bandpass-sampling condition: aliases must be mutually exclusive
            harmless_aliasing = any(
                (2 * fH / np.arange(1, n) <= sampling_frequency)
                & (sampling_frequency <= 2 * fL / np.arange(1, n))
            )
            if not harmless_aliasing:
                log.warning(
                    "rf2iq:harmful_aliasing Harmful aliasing is present: the aliases"
                    " are not mutually exclusive!"
                )
    else:
        # User-supplied filter: directly band-pass filter the raw RF
        b, a = filter_coeff
        iq_data = scipy.signal.lfilter(b, a, rf_data, axis=-2) * 2
    return iq_data
def upmix(iq_data, sampling_frequency, demodulation_frequency, upsampling_rate=6):
    """Upsamples and upmixes complex base-band (IQ) signals back to RF.

    Args:
        iq_data (ndarray): complex valued input array of size [..., n_ax, n_el].
            second to last axis is fast-time axis.
        sampling_frequency (float): the sampling frequency of the input IQ signal
            (in Hz). The resulting sampling frequency of the RF data is
            ``upsampling_rate`` times higher.
        demodulation_frequency (float): modulation frequency (in Hz).
        upsampling_rate (int, optional): temporal upsampling factor. Defaults to 6.

    Returns:
        rf_data (ndarray): output real valued rf data.
    """
    assert iq_data.dtype in [
        "complex64",
        "complex128",
    ], "IQ must contain all complex signals."

    shape = iq_data.shape
    n_dim = len(shape)
    if n_dim > 2:
        *_, n_ax, _ = shape
    else:
        n_ax, _ = shape

    # Fast-time vector at the upsampled rate (t0 kept as explicit zero offset)
    n_ax_up = n_ax * upsampling_rate
    fs_up = sampling_frequency * upsampling_rate
    t = ops.arange(n_ax_up, dtype="float32") / fs_up
    t0 = 0
    t = t + t0

    # Linear interpolation along the fast-time axis
    iq_upsampled = resample(
        iq_data,
        n_samples=n_ax_up,
        axis=-2,
        order=1,
    )

    # Up-mixing: multiply with the complex carrier exp(+j 2*pi*f_demod*t)
    t = ops.cast(t, dtype="complex64")
    f_demod = ops.cast(demodulation_frequency, dtype="complex64")
    carrier = ops.exp(1j * 2 * np.pi * f_demod * t)
    carrier = ops.reshape(carrier, (*[1] * (n_dim - 2), n_ax_up, 1))

    # Real part with sqrt(2) gain to preserve signal power
    rf = ops.real(iq_upsampled * carrier) * ops.sqrt(2)
    return ops.cast(rf, "float32")
def _sinc(x):
    """Return the normalized sinc function. Equivalent to np.sinc(x)."""
    # Replace zeros with a tiny value so the division is well defined;
    # sin(y)/y then evaluates to ~1 at those positions, matching sinc(0) = 1.
    safe_x = ops.where(x == 0, 1.0e-20, x)
    y = np.pi * safe_x
    return ops.sin(y) / y
def get_band_pass_filter(num_taps, sampling_frequency, f1, f2, validate=True):
    """Band pass filter

    Compatible with ``jax.jit`` when ``numtaps`` is static.
    Based on ``scipy.signal.firwin`` with hamming window.

    Args:
        num_taps (int): number of taps in filter.
        sampling_frequency (float): sample frequency in Hz.
        f1 (float): cutoff frequency in Hz of left band edge.
        f2 (float): cutoff frequency in Hz of right band edge.
        validate (bool, optional): whether to validate the cutoff frequencies.
            Defaults to True.

    Returns:
        ndarray: band pass filter
    """
    sampling_frequency = ops.cast(sampling_frequency, "float32")
    f1 = ops.cast(f1, "float32")
    f2 = ops.cast(f2, "float32")
    # Normalize cutoffs to the Nyquist frequency (firwin convention)
    nyq = 0.5 * sampling_frequency
    f1 = f1 / nyq
    f2 = f2 / nyq
    if validate:
        # NOTE: validation needs concrete values, so it is skipped inside jit
        # by passing validate=False.
        if f1 <= 0 or f2 >= 1:
            raise ValueError(
                f"Invalid cutoff frequency: frequencies must be greater than 0 and less than fs/2. "
                f"Got f1={f1 * nyq} Hz, f2={f2 * nyq} Hz."
            )
        if f1 >= f2:
            raise ValueError(
                f"Invalid cutoff frequencies: the frequencies must be strictly increasing. "
                f"Got f1={f1 * nyq} Hz, f2={f2 * nyq} Hz."
            )
    # Build up the coefficients: ideal band-pass = difference of two sinc low-passes,
    # centered at alpha for linear phase.
    alpha = 0.5 * (num_taps - 1)
    m = ops.arange(0, num_taps, dtype="float32") - alpha
    h = f2 * _sinc(f2 * m) - f1 * _sinc(f1 * m)
    # Get and apply the window function.
    win = np.hamming(num_taps)
    win = ops.convert_to_tensor(win, dtype=h.dtype)
    h *= win
    # Use center frequency for scaling: 0 for lowpass, 1 (Nyquist) for highpass, or band center
    scale_frequency = ops.where(f1 == 0, 0.0, ops.where(f2 == 1, 1.0, 0.5 * (f1 + f2)))
    # Normalize so the gain at scale_frequency is unity
    # (mirrors scipy.signal.firwin's scale=True behavior).
    c = ops.cos(np.pi * m * scale_frequency)
    s = ops.sum(h * c)
    h /= s
    return h
def get_low_pass_iq_filter(num_taps, sampling_frequency, center_frequency, bandwidth):
    """Design complex low-pass filter.

    A real-valued low-pass FIR prototype with cutoff ``bandwidth / 2`` is
    designed and then modulated up to the center frequency, yielding
    complex-valued taps.

    Args:
        num_taps (int): number of taps in filter.
        sampling_frequency (float): sample frequency.
        center_frequency (float): center frequency.
        bandwidth (float): bandwidth in Hz.

    Raises:
        ValueError: if cutoff frequency (bandwidth / 2) is not within
            (0, sampling_frequency / 2)

    Returns:
        ndarray: Complex-valued low-pass filter
    """
    half_band = bandwidth / 2
    nyquist = sampling_frequency / 2
    if not 0 < half_band < nyquist:
        raise ValueError(
            f"Cutoff frequency must be within (0, sampling_frequency / 2), "
            f"got {half_band} Hz, must be within (0, {nyquist}) Hz"
        )

    # Real-valued low-pass prototype
    prototype = scipy.signal.firwin(num_taps, half_band, pass_zero=True, fs=sampling_frequency)

    # Modulate to the center frequency to obtain complex taps
    tap_times = np.arange(num_taps) / sampling_frequency
    return prototype * np.exp(1j * 2 * np.pi * center_frequency * tap_times)
def complex_to_channels(complex_data, axis=-1):
    """Unroll complex data to separate channels.

    Args:
        complex_data (complex ndarray): complex input data.
        axis (int, optional): on which axis to extend. Defaults to -1.

    Returns:
        ndarray: real array with real and imaginary components unrolled
            over two channels at axis.
    """
    # Real part goes in channel 0, imaginary part in channel 1.
    real_part = ops.expand_dims(ops.real(complex_data), axis=axis)
    imag_part = ops.expand_dims(ops.imag(complex_data), axis=axis)
    return ops.concatenate((real_part, imag_part), axis=axis)
def channels_to_complex(data):
    """Convert array with real and imaginary components at different channels
    to complex data array.

    Args:
        data (ndarray): input data, with the real component at index 0 and the
            imaginary component at index 1 of the last axis.

    Returns:
        ndarray: complex array with real and imaginary components.
    """
    assert data.shape[-1] == 2, "Data must have two channels."
    casted = ops.cast(data, "complex64")
    real_part = casted[..., 0]
    imag_part = casted[..., 1]
    return real_part + 1j * imag_part
def hilbert(x, N: int | None = None, axis=-1):
    """Implementation of the Hilbert transform function that computes the analytical signal.

    Operates in the Fourier domain by applying a filter that zeros out negative
    frequencies and doubles positive frequencies.

    .. note::
        This is NOT the mathematical Hilbert transform as defined in the
        `Wikipedia article <https://en.wikipedia.org/wiki/Hilbert_transform>`_,
        but instead computes the analytical signal. The implementation reproduces
        the behavior of the :func:`scipy.signal.hilbert` function.

    Args:
        x (ndarray): Input data of any shape.
        N (int, optional): Number of points to use for the FFT. If specified and
            greater than the length of the data along the specified axis, the data
            will be zero-padded. If None, uses the length of x along the specified
            axis. Defaults to None.
        axis (int, optional): Axis along which to compute the Hilbert transform.
            Defaults to -1 (last axis).

    Returns:
        ndarray: Complex analytical signal with the same shape as the input
            (or padded to length N if specified). The real part is the original
            signal and the imaginary part is the Hilbert transform of the signal.

    Raises:
        ValueError: If N is specified and is less than the length of x along
            the specified axis.

    Example:
        >>> import numpy as np
        >>> from zea.func import hilbert
        >>> x = np.array([1.0, 2.0, 3.0, 4.0])
        >>> analytical_signal = hilbert(x)
        >>> envelope = np.abs(analytical_signal)
    """
    input_shape = x.shape
    n_dim = len(input_shape)
    n_ax = input_shape[axis]
    # Normalize axis to a non-negative index
    if axis < 0:
        axis = n_dim + axis
    if N is not None:
        if N < n_ax:
            raise ValueError(f"N must be greater or equal to n_ax, got N={N}, n_ax={n_ax}")
        # Zero-pad at the end of the transform axis up to length N
        pad = np.maximum(N - n_ax, 0)
        pad_list = [[0, 0] for _ in range(n_dim)]
        pad_list[axis] = [0, pad]
        x = ops.pad(x, pad_list, mode="constant", constant_values=0.0)
    else:
        N = n_ax

    # Create filter to zero out negative frequencies
    # h[0] = 1, h[1:N//2] = 2, h[N//2] = 1 (if even), rest = 0
    indices = ops.arange(N, dtype="float32")
    h = ops.zeros(N, dtype="float32")
    h = ops.where(indices == 0, 1.0, h)
    h = ops.where((indices > 0) & (indices < N / 2.0), 2.0, h)
    h = ops.where((N % 2 == 0) & (indices == N / 2.0), 1.0, h)
    h = ops.cast(h, "complex64")

    idx = list(range(n_dim))
    # make sure axis gets to the end for fft (operates on last axis)
    idx.remove(axis)
    idx.append(axis)
    x = ops.transpose(x, idx)

    # Broadcast the filter against the (now last-axis) transform dimension
    if x.ndim > 1:
        h = ops.reshape(h, [1] * (x.ndim - 1) + [-1])
    h = h + 1j * ops.zeros_like(h)

    # ops.fft takes/returns (real, imag) pairs; rebuild a complex spectrum
    Xf_r, Xf_i = ops.fft((x, ops.zeros_like(x)))
    Xf_r = ops.cast(Xf_r, "complex64")
    Xf_i = ops.cast(Xf_i, "complex64")
    Xf = Xf_r + 1j * Xf_i
    # Apply the one-sided (analytic-signal) filter
    Xf = Xf * h

    # x = np.fft.ifft(Xf)
    # do manual ifft using fft, via ifft(X) = conj(fft(conj(X))) / N
    Xf_r = ops.real(Xf)
    Xf_i = ops.imag(Xf)
    Xf_r_inv, Xf_i_inv = ops.fft((Xf_r, -Xf_i))
    Xf_i_inv = ops.cast(Xf_i_inv, "complex64")
    Xf_r_inv = ops.cast(Xf_r_inv, "complex64")
    N = ops.cast(N, "complex64")
    x = Xf_r_inv / N
    x = x + 1j * (-Xf_i_inv / N)

    # switch back to original shape
    idx = list(range(n_dim))
    idx.insert(axis, idx.pop(-1))
    x = ops.transpose(x, idx)
    return x
def demodulate(data, demodulation_frequency, sampling_frequency, axis=-3):
    """Demodulates the input data to baseband.

    The function computes the analytical signal (the signal with negative
    frequencies removed) and then shifts the spectrum of the signal to baseband
    by multiplying with a complex exponential. Where the spectrum was centered
    around `center_frequency` before, it is now centered around 0 Hz.

    The baseband IQ data are complex-valued. The real and imaginary parts are
    stored in two real-valued channels.

    Args:
        data (ops.Tensor): The input data to demodulate of shape
            `(..., axis, ..., 1)`.
        demodulation_frequency (float): The center frequency of the signal.
        sampling_frequency (float): The sampling frequency of the signal.
        axis (int, optional): The axis along which to demodulate. Defaults to -3.

    Returns:
        ops.Tensor: The demodulated IQ data of shape `(..., axis, ..., 2)`.
    """
    # Analytical signal: negative frequencies removed
    analytic = hilbert(data, axis=axis)

    # Sample indices along the demodulation axis, shaped to broadcast
    # against the data (singleton on every other axis)
    sample_idx = ops.arange(analytic.shape[axis])
    expand = [None] * data.ndim
    expand[axis] = slice(None)
    sample_idx = sample_idx[tuple(expand)]

    # Cast everything to complex64 for the phasor multiplication
    f_demod = ops.cast(demodulation_frequency, dtype="complex64")
    f_sample = ops.cast(sampling_frequency, dtype="complex64")
    sample_idx = ops.cast(sample_idx, dtype="complex64")

    # Shift the spectrum to baseband: multiply by exp(-j 2*pi*f_d*n/f_s)
    exponent = -1j * 2 * np.pi * f_demod * sample_idx / f_sample
    baseband = analytic * ops.exp(exponent)

    # Drop the singleton channel axis and unroll real/imag into two channels
    return complex_to_channels(ops.squeeze(baseband, axis=-1))
def compute_time_to_peak_stack(waveforms, center_frequencies, waveform_sampling_frequency=250e6):
    """Compute the time of the peak of each waveform in a stack of waveforms.

    Args:
        waveforms (ndarray): The waveforms of shape (n_waveforms, n_samples).
        center_frequencies (ndarray): The center frequencies of the waveforms in Hz
            of shape (n_waveforms,) or a scalar if all waveforms have the same
            center frequency.
        waveform_sampling_frequency (float): The sampling frequency of the
            waveforms in Hz.

    Returns:
        ndarray: The time to peak for each waveform in seconds.
    """
    # Broadcast a scalar center frequency to one value per waveform
    center_frequencies = center_frequencies * ops.ones((waveforms.shape[0],))
    peak_times = [
        compute_time_to_peak(waveform, fc, waveform_sampling_frequency)
        for waveform, fc in zip(waveforms, center_frequencies)
    ]
    return ops.stack(peak_times)
def compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency=250e6):
    """Compute the time of the peak of the waveform.

    The waveform is demodulated to baseband, the envelope is taken as the
    magnitude of the complex IQ signal, and the time of the envelope maximum
    is returned.

    Args:
        waveform (ndarray): The waveform of shape (n_samples).
        center_frequency (float): The center frequency of the waveform in Hz.
        waveform_sampling_frequency (float): The sampling frequency of the
            waveform in Hz.

    Returns:
        float: The time to peak for the waveform in seconds.

    Raises:
        ValueError: If the waveform has zero samples.
    """
    n_samples = waveform.shape[0]
    if n_samples == 0:
        raise ValueError("Waveform has zero samples.")
    # waveform[..., None] appends the singleton channel axis that `demodulate`
    # expects (it squeezes axis=-1), so the sample axis is -2.
    # Fix: previously axis=-1 was passed, which ran the Hilbert transform over
    # the length-1 channel axis; that degenerates demodulation to the identity,
    # making the "envelope" just the rectified waveform.
    waveforms_iq_complex_channels = demodulate(
        waveform[..., None], center_frequency, waveform_sampling_frequency, axis=-2
    )
    waveforms_iq_complex = channels_to_complex(waveforms_iq_complex_channels)
    # Envelope = magnitude of the baseband IQ signal
    envelope = ops.abs(waveforms_iq_complex)
    peak_idx = ops.argmax(envelope, axis=-1)
    # Convert the sample index of the maximum to seconds
    t_peak = ops.cast(peak_idx, dtype="float32") / waveform_sampling_frequency
    return t_peak
def envelope_detect(data, axis=-3):
    """Envelope detection of RF signals.

    If the input data is real, it first applies the Hilbert transform along the
    specified axis and then computes the magnitude of the resulting complex
    signal. If the input data is complex (two channels), it computes the
    magnitude directly.

    Args:
        - data (Tensor): The beamformed data of shape
            (..., grid_size_z, grid_size_x, n_ch).
        - axis (int): Axis along which to apply the Hilbert transform.
            Defaults to -3.

    Returns:
        - envelope_data (Tensor): The envelope detected data of shape
            (..., grid_size_z, grid_size_x).
    """
    if data.shape[-1] == 2:
        # Two-channel IQ input: combine channels into a single complex array
        complex_data = channels_to_complex(data)
    else:
        # Real RF input: analytic signal via Hilbert transform, zero-padded to
        # the next power of 2, M = 2^ceil(log2(n_ax)), for the FFT.
        # see https://github.com/tue-bmd/zea/discussions/147
        n_ax = ops.shape(data)[axis]
        fft_size = int(2 ** np.ceil(np.log2(n_ax)))
        analytic = hilbert(data, N=fft_size, axis=axis)
        # Discard the zero-padding and drop the trailing singleton channel
        analytic = ops.take(analytic, ops.arange(n_ax), axis=axis)
        complex_data = ops.squeeze(analytic, axis=-1)
    return ops.abs(complex_data)
def log_compress(data, eps=1e-16):
    """Apply logarithmic compression to data."""
    floor = ops.convert_to_tensor(eps, dtype=data.dtype)
    # Substitute eps for exact zeros so log10 never sees 0
    safe_data = ops.where(data == 0, floor, data)
    return 20 * ops.log10(safe_data)
def make_tgc_curve(n_ax, attenuation_coef, sampling_frequency, center_frequency, sound_speed=1540):
    """Create a Time Gain Compensation (TGC) curve to compensate for
    depth-dependent attenuation.

    Args:
        n_ax (int): Number of samples in the axial direction
        attenuation_coef (float): Attenuation coefficient in dB/cm/MHz.
            For example, typical value for soft tissue is around
            0.5 to 0.75 dB/cm/MHz.
        sampling_frequency (float): Sampling frequency in Hz
        center_frequency (float): Center frequency in Hz
        sound_speed (float): Speed of sound in m/s (default: 1540)

    Returns:
        np.ndarray: TGC gain curve of shape (n_ax,) in linear scale
    """
    # Arrival time of each sample (seconds), then the one-way depth in cm
    # (sound travels there and back, so halve the round-trip distance).
    sample_times = np.arange(n_ax) / sampling_frequency
    depth_cm = sample_times * sound_speed / 2 * 100
    # Two-way attenuation (transmit + receive) in dB, center frequency in MHz
    attenuation_db = 2 * attenuation_coef * depth_cm * (center_frequency * 1e-6)
    # Gain that undoes the attenuation, as a linear amplitude factor
    return (10 ** (attenuation_db / 20)).astype(np.float32)
def dehaze_nuclear_diffusion(
    hazy_video,
    diffusion_model,
    n_steps: int = 5000,
    initial_step: int = 4500,
    window_size: int = 7,
    window_stride: int | None = None,
    hard_project: bool = True,
    seed=None,
    verbose: bool = True,
    **guidance_kwargs,
):
    r"""Dehaze ultrasound videos using Nuclear Diffusion posterior sampling.

    This function performs video dehazing by combining diffusion posterior sampling
    with low-rank temporal modeling. It processes long video sequences by splitting
    them into overlapping windows, applying
    `Nuclear Diffusion <https://tue-bmd.github.io/nuclear-diffusion/>`_ to each
    window, and averaging predictions across windows for smooth results.

    .. seealso::
        - :doc:`../../notebooks/models/nuclear_dehazing_example`: Detailed tutorial notebook
        - :class:`~zea.models.diffusion.NuclearDiffusion`: The guidance method used for dehazing
        - :func:`~zea.func.split_into_windows`: Window splitting utility

    The method performs posterior sampling to separate the video into:

    - **Tissue component** (:math:`\mathbf{X}`): Dynamic foreground signal with
      complex structure
    - **Haze component** (:math:`\mathbf{L}`): Low-rank background artifacts

    Nuclear Diffusion replaces the sparsity prior in RPCA with a learned diffusion
    prior while maintaining a nuclear norm penalty on the background component.
    Given video observations :math:`\mathbf{Y} \in \mathbb{R}^{n \times p}`, the
    method jointly samples:

    .. math::
        \mathbf{X}, \mathbf{L} \sim p_\theta(\mathbf{X}, \mathbf{L} \mid \mathbf{Y})

    where :math:`\mathbf{X}` is the dynamic foreground (tissue) and
    :math:`\mathbf{L}` is the low-rank background (haze). The posterior factorizes as:

    .. math::
        p(\mathbf{Y}, \mathbf{L}, \mathbf{X}) =
        p(\mathbf{Y} \mid \mathbf{L}, \mathbf{X}) \, p(\mathbf{L}) \, p_\theta(\mathbf{X})

    - **Likelihood**: :math:`p(\mathbf{Y} \mid \mathbf{L}, \mathbf{X}) =
      \mathcal{N}(\mathbf{Y}; \mathbf{L}+\mathbf{X}, \mu^{-1} \mathbf{I})`
    - **Low-rank prior**: :math:`p(\mathbf{L}) \propto \exp(-\gamma \|\mathbf{L}\|_*)`
      where :math:`\|\mathbf{L}\|_* = \sum_i \sigma_i(\mathbf{L})` is the nuclear norm
    - **Diffusion prior**: :math:`p_\theta(\mathbf{X})` learned from data, capturing
      complex signal structure

    The method operates by alternating between reverse diffusion and
    measurement-guided updates, minimizing both the data fidelity and the low-rank
    penalty. This allows it to effectively separate structured foreground dynamics
    from the low-rank haze, even when the foreground is not sparse.

    Args:
        hazy_video: Input hazy video as a tensor of shape
            ``(frames, height, width, channels)``.
        diffusion_model: Pre-trained diffusion model configured with Nuclear
            Diffusion guidance (``guidance="nuclear-dps"``) and haze operator
            (``operator="linear_interp"``).
        n_steps: Number of diffusion steps for posterior sampling. More steps
            generally produce better quality but take longer. Default is 5000.
        initial_step: Starting step for progressive blending in the diffusion
            process. Must be less than ``n_steps`` and non-negative. Passed to the
            NuclearDiffusion guidance function's ``compute_error`` method.
            Default is 4500.
        window_size: Number of frames to process together in each window. Larger
            windows capture more temporal context but require more memory.
            Default is 7.
        window_stride: Stride between consecutive windows. If ``None``, uses
            non-overlapping windows (stride = window_size). Smaller strides create
            more overlap and smoother results but increase computation time.
        hard_project: Whether to preserve bright speckle values from the input by
            projecting positive values from the hazy input. This helps preserve
            fine tissue texture. Default is ``True``.
        seed: Random seed for reproducibility. If ``None``, uses default random state.
        verbose: Whether to display progress information. Default is ``True``.
        **guidance_kwargs: Additional keyword arguments for Nuclear Diffusion guidance:

            - **omega** (float): Weight for measurement error term. Default is 1.0.
            - **gamma** (float): Weight for nuclear norm penalty. Default is 1.0.
            - **rank_weight_factor** (float, optional): Enhanced weighting for
              larger singular values.

    Returns:
        tuple: A tuple ``(tissue_frames, haze_frames)`` containing:

            - **tissue_frames**: Dehazed tissue component as a numpy array.
            - **haze_frames**: Estimated low-rank haze component as a numpy array.

    Raises:
        ValueError: If the model is not configured with Nuclear Diffusion guidance.

    .. note::
        This function requires a diffusion model with Nuclear Diffusion guidance.
        Initialize your model with ``guidance="nuclear-dps"`` and
        ``operator="linear_interp"``.

    .. admonition:: Reference
        T. S. W. Stevens, O. Nolan, J.-L. Robert, and R. J. G. van Sloun,
        "Nuclear Diffusion Models for Low-Rank Background Suppression in Videos,"
        *IEEE International Conference on Acoustics, Speech and Signal Processing
        (ICASSP)*, 2026. https://arxiv.org/abs/2509.20886
    """  # noqa: E501
    assert initial_step < n_steps, "initial_step must be less than n_steps."
    assert initial_step >= 0, "initial_step must be non-negative."
    assert diffusion_model is not None, (
        "You must pass a diffusion model to `dehaze_nuclear_diffusion`. To see which models are "
        "available on zeahub, visit https://huggingface.co/zeahub/models or"
        "see the available presets: https://github.com/tue-bmd/zea/blob/main/zea/models/presets.py"
    )

    def _nuclear_diffusion_posterior_sample(
        diffusion_model,
        measurements,
        n_steps: int,
        seed=None,
        verbose: bool = True,
        initial_step: int = 100,
        **guidance_kwargs,
    ):
        """Internal method for Nuclear Diffusion posterior sampling.

        This method performs posterior sampling for a single batch/window of
        frames. It alternates between reverse diffusion on the tissue component
        and gradient updates that enforce measurement consistency and low-rank
        structure on the haze component.

        Args:
            diffusion_model: The diffusion model with Nuclear Diffusion guidance.
            measurements: Measurements of shape ``(batch, frames, H, W, C)``.
            n_steps: Number of diffusion steps.
            seed: Random seed.
            verbose: Whether to show progress.
            initial_step: Starting diffusion step.
            **guidance_kwargs: Guidance parameters (omega, gamma, etc.).

        Returns:
            tuple: ``(tissue_images, haze_images)`` as tensors.
        """
        measurements = ops.convert_to_tensor(measurements)
        image_shape = ops.shape(measurements)

        # Ensure 5D input: (batch, frames, height, width, channels)
        if len(image_shape) != 5:
            raise ValueError(f"Expected 5D input (batch, frames, H, W, C), got shape {image_shape}")

        n_batches, n_frames, image_height, image_width, n_channels = image_shape
        frame_shape = (n_batches, n_frames, image_height, image_width, n_channels)

        # Prepare diffusion: validates params, computes step size, sets up progress tracking
        step_size, progbar = diffusion_model.prepare_diffusion(n_steps, initial_step, verbose)

        # Seed splitting handles None gracefully across all backends
        seed, seed1 = split_seed(seed, 2)
        # Tissue starts from Gaussian noise; haze starts from zeros
        initial_noise_tissue = keras.random.normal(shape=frame_shape, seed=seed1)
        initial_noise_haze = ops.zeros(frame_shape)

        # Base diffusion times (same pattern as reverse_diffusion / reverse_conditional_diffusion)
        base_diffusion_times = ops.ones((n_batches, n_frames, 1, 1, 1)) * diffusion_model.max_t

        # Initialize noisy samples at the starting diffusion time
        start_diffusion_times = base_diffusion_times - initial_step * step_size
        noise_rates, signal_rates = diffusion_model.diffusion_schedule(start_diffusion_times)
        next_noisy_tissue = signal_rates * measurements + noise_rates * initial_noise_tissue
        next_noisy_haze = initial_noise_haze

        initial_step_t = ops.convert_to_tensor(initial_step, dtype=initial_noise_tissue.dtype)

        # Reverse diffusion loop
        for step in range(initial_step, n_steps):
            noisy_tissue = next_noisy_tissue
            noisy_haze = next_noisy_haze

            # Compute diffusion schedule for current and next step
            diffusion_times = base_diffusion_times - step * step_size
            noise_rates, signal_rates = diffusion_model.diffusion_schedule(diffusion_times)
            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = diffusion_model.diffusion_schedule(
                next_diffusion_times
            )

            # Compute gradients from guidance function
            (
                (gradients_tissue, gradients_haze),
                (
                    measurement_error,
                    (pred_noises_tissue, pred_tissue, pred_haze, l2_error, nuclear_penalty),
                ),
            ) = diffusion_model.guidance_fn(
                noisy_tissue,
                noisy_haze,
                measurements=measurements,
                noise_rates=noise_rates,
                signal_rates=signal_rates,
                initial_step=initial_step_t,
                step=step,
                total_steps=n_steps,
                **guidance_kwargs,
            )

            # DDIM step for tissue component (deterministic)
            next_noisy_tissue = diffusion_model.reverse_diffusion_step(
                shape=frame_shape,
                pred_images=pred_tissue,
                pred_noises=pred_noises_tissue,
                signal_rates=signal_rates,
                next_signal_rates=next_signal_rates,
                next_noise_rates=next_noise_rates,
            )
            # Haze carries over its current prediction (no DDIM step)
            next_noisy_haze = pred_haze

            # Apply guidance updates (gradient descent on both components)
            next_noisy_tissue = next_noisy_tissue - gradients_tissue
            next_noisy_haze = next_noisy_haze - gradients_haze

            if progbar is not None:
                progbar.update(
                    step + 1,
                    [
                        ("total_error", measurement_error),
                        ("l2_error", l2_error),
                        ("nuclear_penalty", nuclear_penalty),
                    ],
                )

        # Return the final predictions from the last guidance evaluation
        return pred_tissue, pred_haze

    # Validate configuration
    if diffusion_model.guidance_fn is None:
        raise ValueError(
            "Model must have guidance function set. Initialize with guidance='nuclear-dps'."
        )

    # Import here to avoid circular dependency
    from zea.models.diffusion import NuclearDiffusion

    if not isinstance(diffusion_model.guidance_fn, NuclearDiffusion):
        raise ValueError(
            f"dehaze_nuclear_diffusion() requires Nuclear Diffusion guidance, "
            f"but model has {type(diffusion_model.guidance_fn).__name__}. "
            "Initialize with guidance='nuclear-dps'."
        )

    # Get sequence length
    seq_len = ops.shape(hazy_video)[0]
    if verbose:
        log.info(f"[Nuclear Diffusion] Processing {seq_len} frames.")

    # Split video into windows
    windows, window_indices = split_into_windows(
        hazy_video, window_size=window_size, stride=window_stride
    )
    if verbose:
        log.info(
            f"[Nuclear Diffusion] Split into {len(windows)} windows with sizes:"
            f" {[len(w) for w in windows]}"
        )

    # Accumulate predictions for each frame (one list per output frame)
    frame_tissue_preds = [[] for _ in range(int(seq_len))]
    frame_haze_preds = [[] for _ in range(int(seq_len))]

    progbar = keras.utils.Progbar(len(windows), verbose=verbose, unit_name="window")

    # Process each window
    for window_idx, (window, frame_indices) in enumerate(zip(windows, window_indices)):
        window_batch = ops.expand_dims(window, axis=0)  # Add batch dimension
        # Derive an independent seed per window
        seed, window_seed = split_seed(seed, 2)
        tissue_images, haze_images = _nuclear_diffusion_posterior_sample(
            diffusion_model,
            measurements=window_batch,
            n_steps=n_steps,
            initial_step=initial_step,
            seed=window_seed,
            verbose=False,  # Disable per-window progress
            **guidance_kwargs,
        )
        # Remove batch dimension
        tissue_frames_window = ops.squeeze(tissue_images, axis=0)
        haze_frames_window = ops.squeeze(haze_images, axis=0)
        # Accumulate predictions for overlapping frames
        for i, frame_idx in enumerate(frame_indices):
            frame_tissue_preds[frame_idx].append(tissue_frames_window[i])
            frame_haze_preds[frame_idx].append(haze_frames_window[i])
        progbar.add(1)

    # Average predictions across overlapping windows
    tissue_frames = []
    haze_frames = []
    for i in range(int(seq_len)):
        # Stack and average tissue predictions
        stacked_tissue = ops.stack(frame_tissue_preds[i], axis=0)
        tissue_frames.append(ops.mean(stacked_tissue, axis=0))
        # Stack and average haze predictions
        stacked_haze = ops.stack(frame_haze_preds[i], axis=0)
        haze_frames.append(ops.mean(stacked_haze, axis=0))

    # Stack frames into sequences
    tissue_frames = ops.stack(tissue_frames, axis=0)
    haze_frames = ops.stack(haze_frames, axis=0)

    # Apply hard projection if requested
    if hard_project:
        tissue_np = ops.convert_to_numpy(tissue_frames)
        hazy_np = ops.convert_to_numpy(hazy_video)
        # Preserve bright speckle values from hazy input
        proj = tissue_np.copy()
        proj[proj > 0] = hazy_np[proj > 0]
        tissue_frames = proj
        # Recompute haze from preserved tissue
        # NOTE(review): the "- 1" offset presumably assumes frames normalized
        # to [-1, 1] — confirm against the model's data range.
        haze_frames = hazy_np - tissue_frames - 1
    else:
        # Convert to numpy
        tissue_frames = ops.convert_to_numpy(tissue_frames)
        haze_frames = ops.convert_to_numpy(haze_frames)
        hazy_np = ops.convert_to_numpy(hazy_video)
        # NOTE(review): this recomputes haze from the *haze* estimate
        # (hazy - haze - 1), unlike the hard_project branch which uses
        # hazy - tissue - 1; it discards the sampled haze. Possibly intended
        # to be ``hazy_np - tissue_frames - 1`` — confirm.
        haze_frames = hazy_np - haze_frames - 1

    return tissue_frames, haze_frames