denoise.data_utils

class denoise.data_utils.InferenceBatchSizeOptimizer(model: torch.nn.Module, input_shape: tuple, device: torch.device = torch.device, max_batch_size: int = 512, precision: str = 'fp32', n_channels: int = 5)

Bases: object

Determines the optimal batch size to use for inference.

Differences in GPU memory (e.g., a 32 GB V100 vs. an 80 GB A100), model size, and reconstructed image size all influence how many images can be processed per inference batch. Processing one image per batch works, but it is slow and wastes GPU capacity; this class helps determine the optimal batch size instead.

Parameters:
  • model (torch.nn.Module) – PyTorch model to be used for inference

  • input_shape (tuple) – size of the images to be denoised

  • device (torch.device) – CUDA device

  • max_batch_size (int) – maximum batch size to check

  • precision (str) – whether to use 32-bit floating point ('fp32') or automatic mixed precision (AMP)

estimate_peak_memory(batch_size: int) → float
find_optimal_batch_size() → int
get_available_memory()
profile()
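The internals of these methods are not documented here. As a rough illustration of the idea, the sketch below assumes peak memory grows roughly linearly with batch size and binary-searches for the largest batch that fits a memory budget. `linear_memory_model`, `find_largest_fitting_batch`, and the `safety` headroom factor are hypothetical and are not the class's actual implementation.

```python
from typing import Callable


def linear_memory_model(fixed_gb: float, per_sample_gb: float) -> Callable[[int], float]:
    """Hypothetical estimate: a fixed cost (weights, workspace) plus a cost per image."""
    return lambda batch_size: fixed_gb + per_sample_gb * batch_size


def find_largest_fitting_batch(
    estimate_peak_memory: Callable[[int], float],
    available_gb: float,
    max_batch_size: int = 512,
    safety: float = 0.9,
) -> int:
    """Binary-search the largest batch size whose estimated peak memory fits.

    Assumes the estimate never decreases as batch size grows.
    Returns 0 if not even a single image fits.
    """
    budget = available_gb * safety  # keep headroom for allocator fragmentation
    lo, hi, best = 1, max_batch_size, 0
    while lo <= hi:
        mid = (lo + hi) // 2
        if estimate_peak_memory(mid) <= budget:
            best, lo = mid, mid + 1  # mid fits: look for something larger
        else:
            hi = mid - 1  # mid does not fit: look smaller
    return best
```

For example, `find_largest_fitting_batch(linear_memory_model(2.0, 0.25), 32.0)` returns 107. A binary search needs only about log2(max_batch_size) probes, which matters if each probe is a real forward pass rather than a formula.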
denoise.data_utils.extract_sliding_window_patches_25d(x: torch.Tensor, patch_size: Tuple[int, int] = (512, 512), overlap: float = 0.5, pad_mode: str = 'reflect', pad_value: float = 0.0, return_coords: bool = True) → Tuple[torch.Tensor, List[Tuple[int, int]] | None, Dict[str, int]]

Extract sliding-window patches from a 2.5D CT input tensor.

Input shape:
x: [N, C, H, W]
  • N: number of samples/windows (e.g., the number of center slices being denoised)

  • C: “channels” = number of adjacent slices in your 2.5D stack

  • H, W: spatial size of each slice
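How the 2.5D stacks are built is not specified here. One common approach, sketched below under stated assumptions, takes each center slice of a [D, H, W] volume plus its C // 2 neighbors on either side, clamping indices at the volume boundary. Odd C, edge clamping, and the helpers `neighbor_indices` and `make_25d_stacks` are all hypothetical; the package may handle volume edges differently.

```python
import torch


def neighbor_indices(center: int, n_channels: int, depth: int) -> list[int]:
    """Slice indices for one 2.5D stack: center +/- n_channels // 2, clamped to [0, depth - 1]."""
    half = n_channels // 2
    return [min(max(center + o, 0), depth - 1) for o in range(-half, half + 1)]


def make_25d_stacks(volume: torch.Tensor, n_channels: int) -> torch.Tensor:
    """Turn a [D, H, W] volume into [D, C, H, W] stacks, one per center slice."""
    if n_channels % 2 == 0:
        raise ValueError("n_channels is assumed odd so the center slice is well defined")
    depth = volume.shape[0]
    idx = torch.tensor([neighbor_indices(z, n_channels, depth) for z in range(depth)])  # [D, C]
    return volume[idx]  # advanced indexing -> [D, C, H, W]
```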

IMPORTANT (alignment guarantee):

Patches are extracted with the SAME (top, left) coordinates across ALL channels C, so adjacent slices remain pixel-aligned within each [C, ph, pw] patch.
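In tensor terms, the guarantee follows from slicing only the last two dimensions: a single (top, left) pair indexes every channel at once. A minimal sketch, where `gather_patches` is a hypothetical helper that assumes the coordinates already fit inside `x` (i.e. any padding has been applied):

```python
import torch


def gather_patches(
    x: torch.Tensor, coords: list[tuple[int, int]], ph: int, pw: int
) -> torch.Tensor:
    """Gather [N, P, C, ph, pw] patches from x: [N, C, H, W].

    Each slice x[:, :, top:top+ph, left:left+pw] covers all C channels with the
    same spatial window, so the slices inside one patch stay pixel-aligned.
    """
    return torch.stack(
        [x[:, :, top:top + ph, left:left + pw] for top, left in coords], dim=1
    )
```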

Parameters:
  • x – Tensor [N, C, H, W]

  • patch_size – (ph, pw)

  • overlap – fraction in [0, 1). The stride is (1 - overlap) * patch_size (up to rounding), so overlap=0.5 gives a stride of half the patch size and overlap=0.0 gives non-overlapping patches.

  • pad_mode – padding mode for edges. Options include: “reflect”, “replicate”, “constant”.

  • pad_value – used only if pad_mode=“constant”.

  • return_coords – if True, return list of (top,left) coords for each patch.

Returns:

  • patches – Tensor [N, P, C, ph, pw], where P is the number of spatial patches per image.

  • coords – List[(top, left)] of length P, shared across all N samples. None if return_coords=False.

  • meta – dict with useful info:

      • H_in, W_in: original spatial size

      • H_pad, W_pad: padded spatial size

      • ph, pw: patch size

      • stride_h, stride_w: vertical and horizontal stride

      • n_rows, n_cols: number of patch rows and columns

      • P: number of patches per image

      • pad_top, pad_left (always 0 here; we pad bottom/right for simplicity)

      • pad_bottom, pad_right
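The meta fields are tied together by simple grid arithmetic. The sketch below reproduces them under three assumptions: stride = round((1 - overlap) * patch) with a minimum of 1, enough rows and columns that the last patch reaches the bottom/right edge, and padding applied only on the bottom/right (as noted for pad_top/pad_left). `sliding_window_grid` is hypothetical, the package's exact rounding may differ, and coords are returned inside the dict here only for convenience.

```python
import math


def sliding_window_grid(H: int, W: int, ph: int, pw: int, overlap: float = 0.5) -> dict:
    """Hypothetical reconstruction of the extraction grid and its meta fields."""
    if not 0.0 <= overlap < 1.0:
        raise ValueError("overlap must be in [0, 1)")
    stride_h = max(1, round(ph * (1.0 - overlap)))
    stride_w = max(1, round(pw * (1.0 - overlap)))
    # Enough rows/cols that the last window reaches (or passes) the bottom/right edge.
    n_rows = math.ceil(max(H - ph, 0) / stride_h) + 1
    n_cols = math.ceil(max(W - pw, 0) / stride_w) + 1
    # Pad only bottom/right so the last window lies entirely inside the padded image.
    H_pad = (n_rows - 1) * stride_h + ph
    W_pad = (n_cols - 1) * stride_w + pw
    coords = [(r * stride_h, c * stride_w) for r in range(n_rows) for c in range(n_cols)]
    return {
        "H_in": H, "W_in": W, "H_pad": H_pad, "W_pad": W_pad,
        "ph": ph, "pw": pw, "stride_h": stride_h, "stride_w": stride_w,
        "n_rows": n_rows, "n_cols": n_cols, "P": n_rows * n_cols,
        "pad_top": 0, "pad_left": 0,
        "pad_bottom": H_pad - H, "pad_right": W_pad - W,
        "coords": coords,  # row-major (top, left) pairs, length P
    }
```

Under these assumptions, a 500 × 500 slice with 256 × 256 patches and overlap=0.5 gives a 3 × 3 grid (P = 9) with 12 pixels of bottom/right padding.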

denoise.data_utils.stitch_sliding_window_patches(patch_outputs: torch.Tensor, coords: List[Tuple[int, int]], meta: Dict[str, int], *, output_size: Tuple[int, int] | None = None, window: Literal['uniform', 'hann', 'cosine'] = 'hann', eps: float = 1e-06) → torch.Tensor

Stitch regression outputs from sliding-window patches back into a full image using overlap-add.

This follows the same overlap-add approach as the segmentation/logits stitcher, but is intended for regression targets (e.g., predicting the middle slice in a 2.5D setup).

Parameters:
  • patch_outputs

    Either:
    • [N, P, K, ph, pw] (recommended; K typically = 1 for regression)

    • [N*P, K, ph, pw] (will be reshaped using P from meta)

    Note: K is typically 1, but any number of output channels is supported.

  • coords – List of (top, left) coordinates of length P, matching the extraction order.

  • meta

    Dict returned by extract_sliding_window_patches_25d, containing:

    H_in, W_in, H_pad, W_pad, ph, pw, P, pad_bottom, pad_right, etc.

  • output_size – If provided, crop final result to (H, W). If None, uses (H_in, W_in).

  • window

    Blending window for overlap regions:
    • “uniform”: simple average

    • “hann”/“cosine”: smooth blending (recommended)

  • eps – Small constant to avoid divide-by-zero.
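The blending window sets how much each patch contributes to every pixel it covers. The sketch below shows the two options under stated assumptions: a separable 2D window built from symmetric 1D Hann windows, with “hann” and “cosine” treated as the same window (the parameter list groups them together). A symmetric Hann window is exactly zero on its outermost pixels, so image borders covered by a single patch would receive zero total weight; this sketch floors the window at a small value, one common workaround, and that is also why a divide-by-zero guard such as eps matters. `blend_window` and `floor` are hypothetical.

```python
import torch


def blend_window(ph: int, pw: int, kind: str = "hann", floor: float = 1e-3) -> torch.Tensor:
    """Return a [ph, pw] blending weight map.

    "uniform": every pixel weighted 1, so overlaps become a plain average.
    "hann"/"cosine": separable Hann taper, highest near the patch center and
    lowest at its edges, which hides seams between neighboring patches.
    """
    if kind == "uniform":
        return torch.ones(ph, pw)
    if kind in ("hann", "cosine"):
        wy = torch.hann_window(ph, periodic=False)
        wx = torch.hann_window(pw, periodic=False)
        # Outer product gives the 2D window; the floor keeps edge weights > 0.
        return torch.outer(wy, wx).clamp_min(floor)
    raise ValueError(f"unknown window: {kind!r}")
```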

Returns:

full – [N, K, H_out, W_out] stitched regression output (cropped to the original size by default).

For the middle-slice regression use case, K=1, so the output is [N, 1, H, W].
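Overlap-add itself is a weighted sum followed by a normalization: each patch output is multiplied by the window and accumulated into a padded canvas, the window is accumulated into a matching weight map, and the canvas is divided by the weight map (plus eps) before cropping. The sketch below implements that under stated assumptions; `overlap_add` is hypothetical, takes the window as a tensor rather than by name, and expects patch outputs already shaped [N, P, K, ph, pw]. It is not the package's implementation.

```python
import torch


def overlap_add(
    patch_outputs: torch.Tensor,      # [N, P, K, ph, pw]
    coords: list[tuple[int, int]],    # (top, left) per patch, length P
    padded_size: tuple[int, int],     # (H_pad, W_pad)
    output_size: tuple[int, int],     # (H_out, W_out) crop
    window: torch.Tensor,             # [ph, pw] blending weights, all > 0
    eps: float = 1e-6,
) -> torch.Tensor:
    n, p, k, ph, pw = patch_outputs.shape
    if p != len(coords):
        raise ValueError("coords must have one entry per patch")
    H_pad, W_pad = padded_size
    canvas = patch_outputs.new_zeros((n, k, H_pad, W_pad))
    weight = patch_outputs.new_zeros((1, 1, H_pad, W_pad))
    w = window.to(patch_outputs)  # match dtype and device
    for i, (top, left) in enumerate(coords):
        canvas[:, :, top:top + ph, left:left + pw] += patch_outputs[:, i] * w
        weight[:, :, top:top + ph, left:left + pw] += w
    full = canvas / (weight + eps)  # weighted average wherever patches overlap
    H_out, W_out = output_size
    return full[:, :, :H_out, :W_out]
```

Because the result is a normalized weighted average, any strictly positive window reproduces the input when overlapping patches agree; the window only changes the result where neighboring predictions differ, which is exactly where seams would otherwise appear.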

denoise.data_utils.stitch_sliding_window_patches_core(patch_outputs: torch.Tensor, coords: List[Tuple[int, int]], meta: Dict[str, int], *, output_size: Tuple[int, int] | None = None, window: Literal['uniform', 'hann', 'cosine'] = 'hann', eps: float = 1e-06) → torch.Tensor

Stitch model outputs from sliding-window patches back into a full image using overlap-add.

Typical workflow:

    patches, coords, meta = extract_sliding_window_patches_25d(…)
    logits_patches = model(patches_reshaped)  # then reshape back to [N, P, K, ph, pw]
    full_logits = stitch_sliding_window_patches(logits_patches, coords, meta)
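The reshape steps in this workflow amount to folding the patch axis into the batch axis before the model and unfolding it afterwards. A self-contained sketch with made-up sizes and a stand-in model (a single Conv2d mapping C input slices to K = 1 output channel, purely hypothetical in place of the real denoiser):

```python
import torch
from torch import nn

N, P, C, K, ph, pw = 2, 9, 5, 1, 32, 32
patches = torch.randn(N, P, C, ph, pw)  # extraction output: [N, P, C, ph, pw]
model = nn.Conv2d(C, K, kernel_size=3, padding=1).eval()  # hypothetical stand-in model

with torch.no_grad():
    patches_reshaped = patches.flatten(0, 1)  # [N*P, C, ph, pw]: one patch per batch element
    out = model(patches_reshaped)  # [N*P, K, ph, pw]
    logits_patches = out.reshape(N, P, *out.shape[1:])  # back to [N, P, K, ph, pw]
```

Both flatten and reshape use row-major order, so patch p of sample n sits at batch index n * P + p and returns to the same place. Per the parameter notes, the stitchers also accept the flat [N*P, K, ph, pw] form and reshape it using P from meta, so the final reshape can be left to them.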

Parameters:
  • patch_outputs

    Either:
    • [N, P, K, ph, pw] (recommended)

    • [N*P, K, ph, pw] (will be reshaped using P from meta)

    where:

    • N = number of 2.5D stacks / samples

    • P = number of patches per sample

    • K = channels (e.g., logits for num_classes)

    • (ph, pw) = patch spatial size

  • coords – List of (top, left) coordinates of length P, matching the extraction order.

  • meta

    Dict returned by extract_sliding_window_patches_25d, containing:

    H_in, W_in, H_pad, W_pad, ph, pw, P, pad_bottom, pad_right, etc.

  • output_size – If provided, crop final result to (H, W). If None, uses (H_in, W_in).

  • window

    Blending window for overlap regions:
    • “uniform”: simple average

    • “hann”/“cosine”: smooth blending (recommended)

  • eps – Small constant to avoid divide-by-zero.

Returns:

full – [N, K, H_out, W_out] stitched output (cropped to original size by default).