zea.models.diffusion

Diffusion models for ultrasound image generation and posterior sampling.

To try this model, simply load one of the available presets:

>>> from zea.models.diffusion import DiffusionModel

>>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic")

See also

A tutorial notebook where this model is used: Diffusion models for ultrasound image generation.

Classes

DDS(diffusion_model, operator[, disable_jit])

Decomposed Diffusion Sampling guidance.

DPS(diffusion_model, operator[, disable_jit])

Diffusion Posterior Sampling guidance.

DiffusionGuidance(diffusion_model, operator)

Base class for diffusion guidance methods.

DiffusionModel(*args, **kwargs)

Implementation of a diffusion generative model.

NuclearDiffusion(diffusion_model, operator)

Nuclear Diffusion posterior sampling guidance.

class zea.models.diffusion.DDS(diffusion_model, operator, disable_jit=False)

Bases: DiffusionGuidance

Decomposed Diffusion Sampling guidance.

Reference paper: https://arxiv.org/pdf/2303.05754

Initialize the diffusion guidance.

Parameters:
  • diffusion_model (DiffusionModel) – The diffusion model to use for guidance.

  • operator (Operator) – Forward operator defining the measurement model.

  • disable_jit (bool) – Whether to disable JIT compilation.

Acg(x, **op_kwargs)
__call__(noisy_images, measurements, noise_rates, signal_rates, n_inner=5, eps=1e-05, verbose=False, **op_kwargs)

Call the DDS guidance function.

Parameters:
  • noisy_images – Noisy images.

  • measurements – Target measurements.

  • noise_rates – Current noise rates.

  • signal_rates – Current signal rates.

  • n_inner – Number of conjugate gradient steps.

  • eps – Convergence threshold for conjugate gradient.

  • verbose – Whether to calculate error.

  • **op_kwargs – Additional arguments for the operator.

Returns:

Tuple of (gradients, (measurement_error, (pred_noises, pred_images)))

call(noisy_images, measurements, noise_rates, signal_rates, n_inner, eps, verbose, **op_kwargs)

Call the DDS guidance function.

Parameters:
  • noisy_images – Noisy images.

  • measurements – Target measurements.

  • noise_rates – Current noise rates.

  • signal_rates – Current signal rates.

  • n_inner – Number of conjugate gradient steps.

  • eps – Convergence threshold for conjugate gradient.

  • verbose – Whether to calculate error.

  • **op_kwargs – Additional arguments for the operator.

Returns:

Tuple of (gradients, (measurement_error, (pred_noises, pred_images)))

conjugate_gradient_inner_loop(i, loop_state, eps=1e-05)

A single iteration of the conjugate gradient method. This involves minimizing the error of x along the current search vector p, and then choosing the next search vector.

Reference code from: https://github.com/svi-diffusion/
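The loop state above follows the standard conjugate gradient recurrence: minimize the error along the current search vector p, then pick the next (A-conjugate) search vector. As a hedged illustration in plain NumPy (not the zea implementation, which operates on image tensors through the forward operator), one iteration for solving A x = b looks like:

```python
import numpy as np

def cg_step(A, x, r, p):
    """One conjugate gradient iteration for solving A @ x = b.

    Minimizes the error of x along the current search vector p,
    then chooses the next search vector. Illustrative sketch only.
    """
    Ap = A @ p
    alpha = (r @ r) / (p @ Ap)      # optimal step size along p
    x_next = x + alpha * p          # update solution estimate
    r_next = r - alpha * Ap         # update residual b - A @ x
    beta = (r_next @ r_next) / (r @ r)
    p_next = r_next + beta * p      # next (A-conjugate) search direction
    return x_next, r_next, p_next

# Usage: a few iterations on a small symmetric positive-definite system.
A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x = np.zeros(2)
r = b - A @ x
p = r.copy()
for _ in range(10):
    if r @ r < 1e-12:  # converged (eps-style threshold)
        break
    x, r, p = cg_step(A, x, r, p)
```

For a 2x2 system CG converges in at most two iterations; the eps check mirrors the convergence threshold documented above.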

setup()

Set up the DDS guidance function.

class zea.models.diffusion.DPS(diffusion_model, operator, disable_jit=False)

Bases: DiffusionGuidance

Diffusion Posterior Sampling guidance.

Initialize the diffusion guidance.

Parameters:
  • diffusion_model (DiffusionModel) – The diffusion model to use for guidance.

  • operator (Operator) – Forward operator defining the measurement model.

  • disable_jit (bool) – Whether to disable JIT compilation.

__call__(noisy_images, **kwargs)

Call the gradient function.

Parameters:
  • noisy_images – Noisy images.

  • measurements – Target measurements.

  • operator – Forward operator.

  • noise_rates – Current noise rates.

  • signal_rates – Current signal rates.

  • omega – Weight for the measurement error.

  • **kwargs – Additional arguments for the operator.

Returns:

Tuple of (gradients, (measurement_error, (pred_noises, pred_images)))

compute_error(noisy_images, measurements, noise_rates, signal_rates, omega, **kwargs)

Compute measurement error for diffusion posterior sampling.

Parameters:
  • noisy_images – Noisy images.

  • measurements – Target measurement.

  • noise_rates – Current noise rates.

  • signal_rates – Current signal rates.

  • omega – Weight for the measurement error.

  • **kwargs – Additional arguments for the operator.

Returns:

Tuple of (measurement_error, (pred_noises, pred_images))
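In diffusion posterior sampling, the measurement error is typically the weighted squared distance between the measurement and the forward operator applied to a Tweedie-style estimate of the clean image. A hedged NumPy sketch (the function name, the identity forward operator, and the exact weighting are illustrative assumptions, not the zea code):

```python
import numpy as np

def dps_measurement_error(noisy_images, measurements, noise_rates,
                          signal_rates, omega, pred_noises, forward_op):
    # Tweedie-style estimate of the clean image from the noise prediction:
    # x_0 ≈ (x_t - noise_rate * eps_hat) / signal_rate
    pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
    # Weighted L2 measurement error: omega * ||y - A(x_0)||^2
    residual = measurements - forward_op(pred_images)
    return omega * np.sum(residual ** 2), pred_images

# Usage with a hypothetical identity forward operator.
x_t = np.array([0.5, 1.0])       # noisy images
eps_hat = np.array([0.1, -0.2])  # predicted noise
y = np.array([0.4, 1.3])         # measurement
err, x0 = dps_measurement_error(x_t, y, noise_rates=0.2, signal_rates=0.98,
                                omega=1.0, pred_noises=eps_hat,
                                forward_op=lambda x: x)
```

Gradients of this error with respect to the noisy images give the guidance direction returned by __call__().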

setup()

Set up the autograd function for DPS.

class zea.models.diffusion.DiffusionGuidance(diffusion_model, operator, disable_jit=False)

Bases: ABC, Object

Base class for diffusion guidance methods.

Initialize the diffusion guidance.

Parameters:
  • diffusion_model (DiffusionModel) – The diffusion model to use for guidance.

  • operator (Operator) – Forward operator defining the measurement model.

  • disable_jit (bool) – Whether to disable JIT compilation.

abstractmethod __call__(*args, **kwargs)

Call the guidance function.

abstractmethod setup()

Set up the guidance function. Must be implemented by subclasses.

class zea.models.diffusion.DiffusionModel(*args, **kwargs)

Bases: DeepGenerativeModel

Implementation of a diffusion generative model. Heavily inspired by https://keras.io/examples/generative/ddim/

Initialize a diffusion model.

Parameters:
  • input_shape – Shape of the input data. Typically of the form (height, width, channels) for images.

  • input_range – Range of the input data.

  • min_signal_rate – Minimum signal rate for the diffusion schedule.

  • max_signal_rate – Maximum signal rate for the diffusion schedule.

  • network_name – Name of the network architecture to use. Options are "unet_time_conditional" or "dense_time_conditional".

  • network_kwargs – Additional keyword arguments for the network.

  • name – Name of the model.

  • guidance – Guidance method to use. Can be a string, a dict with "name" and "params" keys, or a DiffusionGuidance object.

  • operator – Operator to use. Can be a string, a dict with "name" and "params" keys, or an Operator object.

  • ema_val – Exponential moving average value for the network weights.

  • min_t – Minimum diffusion time for sampling during training.

  • max_t – Maximum diffusion time for sampling during training.

  • **kwargs – Additional arguments.

call(inputs, training=False, network=None, **kwargs)

Calls the score network.

If network is not provided, will use the exponential moving average network if training is False, otherwise the regular network.

denoise(noisy_images, noise_rates, signal_rates, training, network=None)

Predict noise component and calculate the image component using it.

diffusion_schedule(diffusion_times)

Cosine diffusion schedule (https://arxiv.org/abs/2102.09672).

Parameters:

diffusion_times – tensor with diffusion times in [0, 1]

Returns:

noise_rates – tensor with noise rates
signal_rates – tensor with signal rates

according to:

  • x_t = signal_rate * x_0 + noise_rate * noise

  • x_t = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * noise

or with stochastic sampling:

  • x_t = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t - sigma_t^2) * noise + sigma_t * epsilon

where:

  • sigma_t = sqrt((1 - alpha_t) / (1 - alpha_{t+1})) * sqrt(1 - alpha_{t+1} / alpha_t)

Note

t+1 = previous time step; t = current time step.
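A cosine schedule of this kind can be sketched in plain NumPy by interpolating an angle between the extreme signal rates, in the style of the Keras DDIM example this model is based on. The rate bounds below are illustrative assumptions (the model takes them as min_signal_rate / max_signal_rate at construction):

```python
import numpy as np

def cosine_diffusion_schedule(diffusion_times,
                              min_signal_rate=0.02,
                              max_signal_rate=0.95):
    """Sketch of a cosine diffusion schedule.

    The diffusion angle is interpolated between acos(max_signal_rate)
    and acos(min_signal_rate); signal and noise rates are its cosine
    and sine, so signal_rate^2 + noise_rate^2 == 1 at every time.
    """
    start_angle = np.arccos(max_signal_rate)
    end_angle = np.arccos(min_signal_rate)
    angles = start_angle + diffusion_times * (end_angle - start_angle)
    signal_rates = np.cos(angles)   # sqrt(alpha_t)
    noise_rates = np.sin(angles)    # sqrt(1 - alpha_t)
    return noise_rates, signal_rates

# Usage: rates over five evenly spaced diffusion times in [0, 1].
t = np.linspace(0.0, 1.0, 5)
noise_rates, signal_rates = cosine_diffusion_schedule(t)
```

The sine/cosine parameterization guarantees the variance-preserving identity x_t = signal_rate * x_0 + noise_rate * noise stated above.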

get_config()

Returns the config of the object.

An object config is a Python dictionary (serializable) containing the information needed to re-instantiate it.

linear_diffusion_schedule(diffusion_times)

Create a linear diffusion schedule.

log_likelihood(data, **kwargs)

Approximate log-likelihood of the data under the model.

Parameters:
  • data – Data to compute log-likelihood for.

  • **kwargs – Additional arguments.

Returns:

Approximate log-likelihood.

property metrics

Metrics for training.

posterior_sample(measurements, n_samples=1, n_steps=20, initial_step=0, initial_samples=None, seed=None, **kwargs)

Sample from the posterior distribution given measurements.

Parameters:
  • measurements – Input measurements. Typically of shape (batch_size, *input_shape).

  • n_samples – Number of posterior samples to generate. Will generate n_samples samples for each measurement in the measurements batch.

  • n_steps – Number of diffusion steps.

  • initial_step – Initial step to start from. Can warm-start the diffusion process with a partially noised image, thereby skipping part of the diffusion process. An initial step closer to n_steps results in a shorter diffusion process (i.e. less noise added to the initial image). A value of 0 means the diffusion process starts from pure noise.

  • initial_samples – Optional initial samples to start from. If provided, these samples will be used as the starting point for the diffusion process. Only used if initial_step is greater than 0. Must be of shape (batch_size, n_samples, *input_shape).

  • seed – Random seed generator.

  • **kwargs – Additional arguments.

Returns:

Posterior samples p(x|y), of shape (batch_size, n_samples, *input_shape).

prepare_diffusion(diffusion_steps, initial_step, verbose, disable_jit=False)

Prepare the diffusion process.

This method sets up the parameters for the diffusion process, including validation of the initial step and calculation of the step size.

prepare_schedule(base_diffusion_times, initial_noise, initial_samples, initial_step, step_size)

Prepare the diffusion schedule.

This method sets up the initial noisy images based on the provided initial noise and samples. It handles the case where the initial step is greater than 0, allowing for the use of partially noised images for initialization of the diffusion process.

Parameters:
  • base_diffusion_times – Base diffusion times.

  • initial_noise – Initial noise tensor.

  • initial_samples – Optional initial samples to start from.

  • initial_step – Initial step to start from.

  • step_size – Step size for the diffusion process.

Returns:

Noisy images after the initial step.

Return type:

next_noisy_images

reverse_conditional_diffusion(measurements, initial_noise, diffusion_steps, initial_samples=None, initial_step=0, stochastic_sampling=False, seed=None, verbose=False, track_progress_type='x_0', disable_jit=False, **kwargs)

Reverse diffusion process conditioned on some measurement.

Effectively performs diffusion posterior sampling p(x_0 | y).

Parameters:
  • measurements – Conditioning data.

  • initial_noise – Initial noise tensor.

  • diffusion_steps (int) – Number of diffusion steps.

  • initial_samples – Optional initial samples to start from.

  • initial_step (int) – Initial step to start from.

  • stochastic_sampling (bool) – Whether to use stochastic sampling (DDPM).

  • seed – Random seed generator.

  • verbose (bool) – Whether to show a progress bar.

  • track_progress_type (Literal[None, 'x_0', 'x_t']) – Type of progress tracking ("x_0" or "x_t").

  • disable_jit (bool) – Whether to disable JIT compilation.

  • **kwargs – Additional arguments. These are passed to the guidance function and the operator. Examples are omega, mask, etc.

Returns:

Generated images.

reverse_diffusion(initial_noise, diffusion_steps, initial_samples=None, initial_step=0, stochastic_sampling=False, seed=None, verbose=True, track_progress_type='x_0', disable_jit=False, training=False, network_type=None)

Reverse diffusion process to generate images from noise.

Parameters:
  • initial_noise – Initial noise tensor.

  • diffusion_steps (int) – Number of diffusion steps.

  • initial_samples – Optional initial samples to start from.

  • initial_step (int) – Initial step to start from.

  • stochastic_sampling (bool) – Whether to use stochastic sampling (DDPM).

  • seed (SeedGenerator | None) – Random seed generator.

  • verbose (bool) – Whether to show a progress bar.

  • track_progress_type (Literal[None, 'x_0', 'x_t']) – Type of progress tracking ("x_0" or "x_t").

  • disable_jit (bool) – Whether to disable JIT compilation.

  • training (bool) – Whether to use the training mode of the network.

  • network_type (Literal[None, 'main', 'ema']) – Which network to use ("main" or "ema"). If None, uses the network based on the training argument.

Returns:

Generated images.

reverse_diffusion_step(shape, pred_images, pred_noises, signal_rates, next_signal_rates, next_noise_rates, seed=None, stochastic_sampling=False)

A single reverse diffusion step.

Parameters:
  • shape – Shape of the input tensor.

  • pred_images – Predicted images.

  • pred_noises – Predicted noises.

  • signal_rates – Current signal rates.

  • next_signal_rates – Next signal rates.

  • next_noise_rates – Next noise rates.

  • seed – Random seed generator.

  • stochastic_sampling – Whether to use stochastic sampling (DDPM).

Returns:

Noisy images after the reverse diffusion step.

Return type:

next_noisy_images
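In DDIM-style samplers, the deterministic core of such a step remixes the predicted clean image and predicted noise at the next (lower) noise level. A hedged NumPy sketch (illustrative, not the zea implementation; for the stochastic DDPM variant the caller is assumed to pass noise rates already adjusted for sigma, per the diffusion_schedule() notes above):

```python
import numpy as np

def reverse_diffusion_step(pred_images, pred_noises,
                           next_signal_rates, next_noise_rates,
                           sigma=0.0, rng=None):
    """Sketch of a single reverse diffusion step.

    Deterministically forms the next noisy image from the predicted
    clean image and predicted noise; with sigma > 0, fresh Gaussian
    noise is mixed in for stochastic (DDPM-like) sampling.
    """
    next_noisy = next_signal_rates * pred_images + next_noise_rates * pred_noises
    if sigma > 0.0:
        rng = rng or np.random.default_rng(0)
        next_noisy = next_noisy + sigma * rng.standard_normal(next_noisy.shape)
    return next_noisy

# Usage: deterministic (DDIM) step at the next noise level.
x0_hat = np.array([0.2, -0.1])   # predicted images
eps_hat = np.array([0.5, 0.3])   # predicted noises
x_next = reverse_diffusion_step(x0_hat, eps_hat,
                                next_signal_rates=0.9,
                                next_noise_rates=np.sqrt(1 - 0.81))
```

Iterating this step from pure noise down to zero noise level is what reverse_diffusion() implements.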

sample(n_samples=1, n_steps=20, seed=None, **kwargs)

Sample from the model.

Parameters:
  • n_samples – Number of samples to generate.

  • n_steps – Number of diffusion steps.

  • seed – Random seed generator.

  • **kwargs – Additional arguments.

Returns:

Generated samples of shape (n_samples, *input_shape).

start_track_progress(diffusion_steps, initial_step=0)

Initialize progress tracking for the diffusion process. For diffusion animations we keep track of the diffusion progress. For a large number of steps, we do not store all the images due to memory constraints.

store_progress(step, track_progress_type, next_noisy_images, pred_images)

Store the progress of the diffusion process.

Parameters:
  • step – Current diffusion step.

  • track_progress_type – Type of progress tracking ("x_0" or "x_t").

  • next_noisy_images – Noisy images after the current step.

  • pred_images – Predicted images.

Notes

  • x_0 is considered the predicted image (aka Tweedie estimate)

  • x_t is the noisy intermediate image

test_step(data)

Custom test step so we can call model.fit() on the diffusion model.

train_step(data)

Custom train step so we can call model.fit() on the diffusion model.

Note

Only implemented for the TensorFlow backend.

class zea.models.diffusion.NuclearDiffusion(diffusion_model, operator, disable_jit=False)

Bases: DPS

Nuclear Diffusion posterior sampling guidance.

A hybrid framework that combines diffusion posterior sampling (DPS) with low-rank temporal modeling for video restoration. This method replaces the sparsity assumption in Robust Principal Component Analysis (RPCA) with a learned diffusion prior while maintaining a nuclear norm penalty on the background component to encourage low-rank temporal structure.


Mathematical Formulation:

Given observations \(\mathbf{Y} \in \mathbb{R}^{n \times p}\) (video frames), Nuclear Diffusion jointly samples the signal \(\mathbf{X}\) and low-rank background \(\mathbf{L}\) from the posterior:

\[\mathbf{X}, \mathbf{L} \sim p_\theta(\mathbf{X}, \mathbf{L} \mid \mathbf{Y})\]

The posterior is factorized as:

\[p(\mathbf{Y}, \mathbf{L}, \mathbf{X}) = p(\mathbf{Y} \mid \mathbf{L}, \mathbf{X}) \, p(\mathbf{L}) \, p_\theta(\mathbf{X})\]

where:

  • \(p(\mathbf{Y} \mid \mathbf{L}, \mathbf{X}) = \mathcal{N}(\mathbf{Y}; \mathbf{L}+\mathbf{X}, \mu^{-1} \mathbf{I})\) is the likelihood (measurement model)

  • \(p(\mathbf{L}) \propto \exp(-\gamma \|\mathbf{L}\|_*)\) enforces low-rank structure via the nuclear norm \(\|\mathbf{L}\|_* = \sum_i \sigma_i(\mathbf{L})\)

  • \(p_\theta(\mathbf{X})\) is a learned diffusion prior capturing complex signal structure

The diffusion prior operates on individual frames \(\mathbf{x}^t \in \mathbb{R}^n\), while temporal dependencies are enforced through the nuclear norm on \(\mathbf{L}\).

This guidance method alternates between reverse diffusion and measurement-guided updates, computing gradients from both the measurement error and the nuclear norm penalty.

Parameters:
  • diffusion_model (DiffusionModel) – The diffusion model for the signal component.

  • operator (Operator) – Forward operator defining the measurement model.

  • disable_jit (bool) – Whether to disable JIT compilation.

Reference

T. Stevens, M. Wijkstra, M. Mischi, and R. J. G. van Sloun, "Nuclear Diffusion Models for Low-Rank Background Suppression in Videos," IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2026. https://arxiv.org/abs/2509.20886


__call__(noisy_images1, noisy_images2, measurements, noise_rates, signal_rates, omega=1.0, gamma=1.0, **kwargs)

Compute guidance gradients for posterior sampling.

This method concatenates the noisy foreground and background images, computes the combined loss via compute_error(), and returns separate gradients for each component.

Parameters:
  • noisy_images1 – Noisy foreground images \(\mathbf{x}_t\) from the diffusion model, shape (batch, frames, H, W, C).

  • noisy_images2 – Noisy background images \(\mathbf{L}_t\), shape (batch, frames, H, W, C).

  • measurements – Target measurements \(\mathbf{Y}\), shape (batch, frames, H, W, C).

  • noise_rates – Current noise rates from diffusion schedule.

  • signal_rates – Current signal rates from diffusion schedule.

  • omega (float) – Weight for the measurement error term. Default is 1.0.

  • gamma (float) – Weight for the nuclear norm penalty term. Default is 1.0.

  • **kwargs – Additional arguments passed to compute_error() (e.g., rank_weight_factor, step, total_steps).

Returns:

A tuple containing:

  • gradients (tuple): (grad_foreground, grad_background) – gradients for the foreground and background components.

  • loss_info (tuple): (loss, aux), where:

    • loss (float): Combined loss value.

    • aux (tuple): Auxiliary outputs from compute_error().

compute_error(combined_images, measurements, noise_rates, signal_rates, omega=1.0, gamma=1.0, rank_weight_factor=None, step=None, total_steps=None, initial_step=100, max_alpha=0.5, **kwargs)

Compute measurement error for joint diffusion posterior sampling.

Parameters:
  • combined_images – Concatenated noisy images, containing both foreground and background components, shape (batch, frames, H, W, 2C). In the context of cardiac ultrasound dehazing, the first C channels correspond to the tissue signal (foreground), and the next C channels correspond to the haze (background) component.

  • measurements – Target measurements \(\mathbf{Y}\), shape (batch, frames, H, W, C).

  • noise_rates – Current noise rates from the diffusion schedule, shape (batch, frames, 1, 1, 1).

  • signal_rates – Current signal rates from the diffusion schedule, shape (batch, frames, 1, 1, 1).

  • omega (float) – Weight \(\omega\) for the measurement error term (L2 reconstruction loss).

  • gamma (float) – Weight \(\gamma\) for the nuclear norm penalty term.

  • rank_weight_factor (float | None) – Optional weight factor for weighted_nuclear_norm_penalty(). If None, uses standard nuclear_norm_penalty().

  • step (int | None) – Current diffusion step for progressive blending. Used to compute \(\alpha(t)\).

  • total_steps (int | None) – Total number of diffusion steps.

  • initial_step (int) – Step at which to start progressive blending.

  • max_alpha (float) – Maximum value for \(\alpha\) at the final step. The alpha parameter mixes foreground and background predictions, but only after the initial_step to allow the diffusion model to first focus on generating the foreground signal before blending in the background component.

  • **kwargs – Additional arguments (unused).

Returns:

A tuple containing:

  • measurement_error (float): Combined loss \(\mathcal{L}\).

  • aux (tuple): Auxiliary outputs: (pred_noises_foreground, pred_images_foreground, noisy_background_images, l2_error, nuclear_penalty).

Note

The progressive blending factor \(\alpha(t)\) linearly increases from 0 at initial_step and plateaus at max_alpha once normalized progress reaches max_alpha, allowing the background component to gradually influence the reconstruction and then saturate for the remainder of sampling.
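One plausible reading of this blending schedule, as a hypothetical sketch (the function name and exact linear ramp are assumptions consistent with the note above, not the zea source):

```python
def progressive_alpha(step, total_steps, initial_step=100, max_alpha=0.5):
    """Hypothetical sketch of the blending factor alpha(t) described above.

    Zero before initial_step, then growing linearly with normalized
    progress and saturating at max_alpha for the rest of sampling.
    """
    if step < initial_step or total_steps <= initial_step:
        return 0.0
    progress = (step - initial_step) / (total_steps - initial_step)
    return min(progress, max_alpha)
```

Under this reading, early steps reconstruct the foreground alone (alpha = 0), and the background contribution saturates at max_alpha well before the final step.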

static nuclear_norm_penalty(background_images)

Compute nuclear norm penalty for low-rank enforcement.

The nuclear norm (sum of singular values) encourages low-rank structure in the background component across time. For a matrix \(\mathbf{L}\), it is defined as:

\[\|\mathbf{L}\|_* = \sum_{i=1}^{r} \sigma_i(\mathbf{L})\]

where \(\sigma_i\) are the singular values and \(r\) is the rank.

Parameters:

background_images – Background images of shape (batch, frames, height, width, channels). Each sequence is reshaped to a matrix of shape (frames, height x width x channels) before computing the nuclear norm.

Returns:

Nuclear norm penalty summed across the batch and normalized by number of frames.

Note

The input is reshaped from (batch, frames, H, W, C) to (batch, frames, HxWxC) before computing the singular values.
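The reshape-then-SVD computation described above can be sketched in plain NumPy (illustrative; the zea version operates on backend tensors via ops.svd):

```python
import numpy as np

def nuclear_norm_penalty(background_images):
    """Sketch of the nuclear norm penalty.

    Reshapes each (frames, H, W, C) sequence to a (frames, H*W*C)
    matrix, sums its singular values, then sums across the batch and
    normalizes by the number of frames.
    """
    batch, frames = background_images.shape[:2]
    mats = background_images.reshape(batch, frames, -1)
    penalty = 0.0
    for mat in mats:
        penalty += np.linalg.svd(mat, compute_uv=False).sum()
    return penalty / frames

# Usage: a rank-1 sequence (every frame is a multiple of the same image),
# so only one singular value is nonzero and the penalty is small.
frame = np.ones((1, 4, 4, 1))
video = np.concatenate([frame, 2 * frame, 3 * frame], axis=0)[None]  # (1, 3, 4, 4, 1)
p = nuclear_norm_penalty(video)
```

For this rank-1 example the single nonzero singular value equals the Frobenius norm of the matrix, so the penalty is sqrt(16 * 14) / 3.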

static weighted_nuclear_norm_penalty(background_images, weight_factor=2.0)

Compute weighted nuclear norm penalty with enhanced rank control.

This implements a WNNM-style (Weighted Nuclear Norm Minimization) penalty that penalizes smaller singular values more heavily than larger ones, suppressing the spectrum tail to enforce low-rank structure. The weighted penalty is:

\[\|\mathbf{L}\|_{w,*} = \sum_{i=1}^{r} w_i \cdot \sigma_i(\mathbf{L})\]

where \(w_i = 1 + \alpha \cdot \frac{i}{r}\) increases linearly with the index \(i\), and \(\alpha\) is the weight_factor. Since ops.svd returns singular values in descending order (\(\sigma_1 \geq \sigma_2 \geq \cdots\)), higher indices correspond to smaller singular values, which receive larger weights.

Parameters:
  • background_images – Background images of shape (batch, frames, height, width, channels).

  • weight_factor (float) – Scaling factor \(\alpha\) controlling how much more to penalize smaller singular values (the spectrum tail). Default is 2.0.

Returns:

Weighted nuclear norm penalty summed across the batch and normalized by number of frames.

Note

This is a drop-in replacement for nuclear_norm_penalty() that provides better rank control by more aggressively penalizing the tail of the singular value spectrum (smaller singular values) rather than the leading ones.
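The weighted variant can be sketched the same way; the 0-based indexing of the weights w_i = 1 + weight_factor * i / r is an assumption here (the formula above only states that weights grow linearly toward the spectrum tail):

```python
import numpy as np

def weighted_nuclear_norm_penalty(background_images, weight_factor=2.0):
    """Sketch of the weighted (WNNM-style) nuclear norm penalty.

    Singular values come back in descending order, so the weight
    1 + weight_factor * i / r grows toward the tail of the spectrum,
    penalizing small singular values more heavily than leading ones.
    """
    batch, frames = background_images.shape[:2]
    mats = background_images.reshape(batch, frames, -1)
    penalty = 0.0
    for mat in mats:
        svals = np.linalg.svd(mat, compute_uv=False)  # descending order
        r = len(svals)
        weights = 1.0 + weight_factor * np.arange(r) / r
        penalty += (weights * svals).sum()
    return penalty / frames

# Usage: with weight_factor=0 this reduces to the plain nuclear norm
# penalty; with weight_factor>0 the penalty can only be larger.
rng = np.random.default_rng(0)
video = rng.standard_normal((1, 3, 2, 2, 1))
wp = weighted_nuclear_norm_penalty(video)
```

Because every weight is at least 1, the weighted penalty upper-bounds the unweighted one, which is what makes it a stricter low-rank prior on the background.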