denoise.data_utils
- class denoise.data_utils.InferenceBatchSizeOptimizer(model: torch.nn.Module, input_shape: tuple, device: torch.device = torch.device, max_batch_size: int = 512, precision: str = 'fp32', n_channels: int = 5)
Bases: object
Class for determining the optimal batch size to use for inference.
Differences in GPU memory (e.g., a 32 GB V100 vs. an 80 GB A100), model size, and reconstructed image size all influence how many images can be processed in one inference batch. Processing one image per batch always works, but it is slow and underuses the GPU; this class determines the optimal batch size for a given model, image size, and device.
- params
model (obj) – PyTorch model to be used for inference
input_shape (tuple) – size of the images to be denoised
device (obj) – CUDA device
max_batch_size (int) – maximum batch size to check
precision (str) – whether to use 32-bit floating point ('fp32') or automatic mixed precision (AMP)
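The search strategy the class uses is not described here. One common approach, sketched below under stated assumptions, doubles the batch size until a trial pass runs out of memory and then binary-searches between the last size that fit and the first that failed. `find_max_batch_size` and `try_batch` are hypothetical names: `try_batch(bs)` stands in for one inference step at batch size `bs` and is assumed to raise `RuntimeError` on out-of-memory (PyTorch's `torch.cuda.OutOfMemoryError` subclasses `RuntimeError`). The sketch also assumes memory use grows monotonically with batch size.

```python
from typing import Callable


def find_max_batch_size(try_batch: Callable[[int], None], max_batch_size: int = 512) -> int:
    """Largest batch size in [1, max_batch_size] for which try_batch(bs) does not raise."""

    def fits(bs: int) -> bool:
        try:
            try_batch(bs)
            return True
        except RuntimeError:
            # On a real CUDA device you would typically call torch.cuda.empty_cache() here.
            return False

    if not fits(1):
        raise RuntimeError("batch size 1 does not fit in memory")

    # Invariant: lo fits; hi is the first size known to fail, or max_batch_size + 1.
    lo, hi = 1, max_batch_size + 1
    bs = 2
    # Phase 1: double the batch size until it fails or exceeds the cap.
    while bs <= max_batch_size:
        if not fits(bs):
            hi = bs
            break
        lo = bs
        bs *= 2
    # Phase 2: binary search strictly between lo and hi.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if fits(mid):
            lo = mid
        else:
            hi = mid
    return lo
```

Doubling first keeps the number of trial passes logarithmic in `max_batch_size`, which matters because each trial is a full forward pass on the GPU.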
- denoise.data_utils.extract_sliding_window_patches_25d(x: torch.Tensor, patch_size: Tuple[int, int] = (512, 512), overlap: float = 0.5, pad_mode: str = 'reflect', pad_value: float = 0.0, return_coords: bool = True) → Tuple[torch.Tensor, List[Tuple[int, int]] | None, Dict[str, int]]
Extract sliding-window patches from a 2.5D CT input tensor.
- Input shape:
- x: [N, C, H, W]
N: number of samples/windows (e.g., number of center slices you are inferring)
C: “channels” = number of adjacent slices in your 2.5D stack
H, W: spatial size of each slice
- IMPORTANT (alignment guarantee):
Patches are extracted with the SAME (top,left) coordinates across ALL channels C. So neighbors remain perfectly aligned within each [C, h, w] patch.
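The alignment guarantee follows from slicing every channel with the same spatial window in a single indexing operation. A minimal NumPy sketch (the library itself works on `torch.Tensor`s; `gather_patches` is a hypothetical helper, and the input is assumed to be padded already so every window lies fully inside it):

```python
import numpy as np


def gather_patches(x: np.ndarray, coords, ph: int, pw: int) -> np.ndarray:
    """[N, C, H, W] -> [N, P, C, ph, pw], one window per (top, left) in coords.

    Each slice keeps the full channel axis, so the C adjacent slices of a
    2.5D stack are cut at identical pixel positions and stay aligned.
    """
    windows = [x[:, :, top:top + ph, left:left + pw] for top, left in coords]
    return np.stack(windows, axis=1)
```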
- Parameters:
x – Tensor [N, C, H, W]
patch_size – (ph, pw)
overlap – fraction of each patch shared with its neighbor, in [0, 1). The stride is (1 - overlap) * patch_size, so overlap=0.5 gives stride = 0.5 * patch_size and overlap=0.0 gives non-overlapping patches.
pad_mode – padding mode for edges. Options include: “reflect”, “replicate”, “constant”.
pad_value – used only if pad_mode=“constant”.
return_coords – if True, return list of (top,left) coords for each patch.
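Because padding is added only at the bottom and right, pixel (i, j) of the padded image is still pixel (i, j) of the original, so the (top, left) coordinates need no offset. A NumPy sketch of the three padding modes; `pad_bottom_right` is a hypothetical helper, and mapping “replicate” onto NumPy's `"edge"` mode is an assumption about equivalent behavior. Note that PyTorch's `torch.nn.functional.pad` in reflect mode requires each pad to be smaller than the corresponding input dimension, whereas `np.pad` does not enforce this.

```python
import numpy as np

# Torch-style pad_mode names mapped onto np.pad modes (assumed equivalents).
_NP_MODES = {"reflect": "reflect", "replicate": "edge", "constant": "constant"}


def pad_bottom_right(x: np.ndarray, pad_bottom: int, pad_right: int,
                     pad_mode: str = "reflect", pad_value: float = 0.0) -> np.ndarray:
    """Pad x [N, C, H, W] on the bottom and right only; top/left padding stays 0."""
    pad_width = ((0, 0), (0, 0), (0, pad_bottom), (0, pad_right))
    mode = _NP_MODES[pad_mode]
    if mode == "constant":
        return np.pad(x, pad_width, mode="constant", constant_values=pad_value)
    return np.pad(x, pad_width, mode=mode)
```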
- Returns:
patches – Tensor [N, P, C, ph, pw], where P is the number of spatial patches per image.
coords: List[(top,left)] of length P (shared across N). None if return_coords=False.
meta: dict with useful info:
H_in, W_in: original spatial size
H_pad, W_pad: padded spatial size
ph, pw: patch size
stride_h, stride_w
n_rows, n_cols
P: number of patches per image
pad_top, pad_left (always 0 here; we pad bottom/right for simplicity)
pad_bottom, pad_right
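The meta fields can be derived from the image size, patch size, and overlap alone. The sketch below is an assumption about how this is done, not the library's code: it takes stride = round((1 - overlap) * patch), picks the smallest padded size for which the last window ends exactly at the padded border, and lists coordinates in row-major order. `sliding_window_layout` is a hypothetical name.

```python
import math
from typing import Dict, List, Tuple


def sliding_window_layout(
    H: int, W: int, patch_size: Tuple[int, int] = (512, 512), overlap: float = 0.5
) -> Tuple[List[Tuple[int, int]], Dict[str, int]]:
    """Compute (top, left) coords and a meta dict for bottom/right-padded sliding windows."""
    if not 0.0 <= overlap < 1.0:
        raise ValueError("overlap must be in [0, 1)")
    ph, pw = patch_size
    # Stride shrinks as overlap grows; never let it drop below one pixel.
    stride_h = max(1, round(ph * (1.0 - overlap)))
    stride_w = max(1, round(pw * (1.0 - overlap)))

    def axis_layout(size: int, patch: int, stride: int) -> Tuple[int, int]:
        # Fewest windows whose last one reaches the end of the axis, and the padded length they span.
        n = 1 if size <= patch else math.ceil((size - patch) / stride) + 1
        return n, (n - 1) * stride + patch

    n_rows, H_pad = axis_layout(H, ph, stride_h)
    n_cols, W_pad = axis_layout(W, pw, stride_w)
    # Row-major order: left to right within a row, rows top to bottom.
    coords = [(r * stride_h, c * stride_w) for r in range(n_rows) for c in range(n_cols)]
    meta = {
        "H_in": H, "W_in": W, "H_pad": H_pad, "W_pad": W_pad,
        "ph": ph, "pw": pw, "stride_h": stride_h, "stride_w": stride_w,
        "n_rows": n_rows, "n_cols": n_cols, "P": len(coords),
        "pad_top": 0, "pad_left": 0,
        "pad_bottom": H_pad - H, "pad_right": W_pad - W,
    }
    return coords, meta
```

For a 1000 × 1000 slice with 512 × 512 patches and overlap=0.5, this sketch gives stride 256, a 3 × 3 grid (P = 9), and 24 rows and columns of bottom/right padding.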
- denoise.data_utils.stitch_sliding_window_patches(patch_outputs: torch.Tensor, coords: List[Tuple[int, int]], meta: Dict[str, int], *, output_size: Tuple[int, int] | None = None, window: Literal['uniform', 'hann', 'cosine'] = 'hann', eps: float = 1e-06) → torch.Tensor
Stitch regression outputs from sliding-window patches back into a full image using overlap-add.
This is identical in spirit to the segmentation/logits stitcher, but intended for regression targets (e.g., predicting the middle slice in a 2.5D setup).
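Overlap-add accumulates each weighted patch prediction into a canvas and divides by the accumulated weights, so pixels covered by several patches receive a weighted average. A NumPy sketch on the padded canvas, using hypothetical helpers `make_window` and `overlap_add`. It treats “cosine” as an alias of “hann”, as the window parameter groups them, and uses the Hann variant `np.hanning(n + 2)[1:-1]`, which drops the zero end points so border pixels keep a positive weight; both are assumptions rather than the library's documented choices.

```python
import numpy as np


def make_window(ph: int, pw: int, window: str = "hann") -> np.ndarray:
    """2-D blending window of shape (ph, pw)."""
    if window == "uniform":
        return np.ones((ph, pw))
    if window in ("hann", "cosine"):
        # np.hanning(n + 2)[1:-1] drops the two zero end points of the
        # symmetric Hann window, so every pixel keeps a positive weight.
        return np.outer(np.hanning(ph + 2)[1:-1], np.hanning(pw + 2)[1:-1])
    raise ValueError(f"unknown window: {window!r}")


def overlap_add(patch_outputs, coords, H_pad, W_pad, window="hann", eps=1e-6):
    """Blend [N, P, K, ph, pw] patch predictions into an [N, K, H_pad, W_pad] canvas."""
    N, P, K, ph, pw = patch_outputs.shape
    w = make_window(ph, pw, window)
    acc = np.zeros((N, K, H_pad, W_pad))  # running sum of weighted predictions
    norm = np.zeros((H_pad, W_pad))       # running sum of weights (same for every n, k)
    for p, (top, left) in enumerate(coords):
        acc[:, :, top:top + ph, left:left + pw] += patch_outputs[:, p] * w
        norm[top:top + ph, left:left + pw] += w
    return acc / (norm + eps)
```

If every patch prediction is an exact crop of the same image, the weighted average returns that image (up to eps) regardless of the window, which makes a convenient sanity check.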
- Parameters:
patch_outputs –
- Either:
[N, P, K, ph, pw] (recommended; K typically = 1 for regression)
[N*P, K, ph, pw] (will be reshaped using P from meta)
Note: K may be 1, but we keep it general.
coords – List of (top, left) coordinates of length P, matching the extraction order.
meta –
- Dict returned by extract_sliding_window_patches_25d, containing:
H_in, W_in, H_pad, W_pad, ph, pw, P, pad_bottom, pad_right, etc.
output_size – If provided, crop final result to (H, W). If None, uses (H_in, W_in).
window –
- Blending window for overlap regions:
“uniform”: simple average
“hann”/“cosine”: smooth blending (recommended)
eps – Small constant to avoid divide-by-zero.
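Written out, the blended value at pixel (i, j) of the padded canvas is a weighted average of every patch prediction covering it. Here y_p is the prediction of patch p, (t_p, l_p) its (top, left) coordinate, m_p a coverage mask, and w the blending window; the Hann form is the textbook symmetric definition and may differ in detail from the implementation.

```latex
\hat{y}(i,j) = \frac{\sum_{p=1}^{P} m_p(i,j)\, w(i - t_p,\, j - l_p)\, y_p(i - t_p,\, j - l_p)}
                    {\sum_{p=1}^{P} m_p(i,j)\, w(i - t_p,\, j - l_p) + \varepsilon},
\qquad
m_p(i,j) = \begin{cases} 1 & t_p \le i < t_p + p_h,\; l_p \le j < l_p + p_w \\ 0 & \text{otherwise} \end{cases}

w_{\text{uniform}}(a,b) = 1,
\qquad
w_{\text{hann}}(a,b) = h_{p_h}(a)\, h_{p_w}(b),
\quad
h_n(a) = \tfrac{1}{2}\left(1 - \cos\frac{2\pi a}{n-1}\right),\; a = 0,\dots,n-1
```

With the uniform window the expression reduces to the simple average of the overlapping predictions (up to ε). The symmetric Hann window is zero at a = 0 and a = n − 1, so a pixel on the outer image border covered by a single patch would come out as roughly 0/ε; a variant with strictly positive edge weights avoids this.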
- Returns:
full – [N, K, H_out, W_out] stitched regression output (cropped to the original size by default). For your middle-slice regression use case, K = 1, so the output is [N, 1, H, W].
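The stitcher accepts either patch layout and crops away the padding at the end. A NumPy sketch of that bookkeeping, assuming the flat [N*P, K, ph, pw] layout is sample-major (all P patches of sample 0, then sample 1, and so on), which is what flattening [N, P, ...] produces; `to_npkhw` and `crop_to` are hypothetical helpers. Cropping only has to drop the bottom/right padding because pad_top and pad_left are always 0.

```python
import numpy as np


def to_npkhw(patch_outputs: np.ndarray, P: int) -> np.ndarray:
    """Normalize patch outputs to [N, P, K, ph, pw]."""
    if patch_outputs.ndim == 5:
        return patch_outputs
    if patch_outputs.ndim == 4:
        NP = patch_outputs.shape[0]
        if NP % P != 0:
            raise ValueError(f"leading dim {NP} is not a multiple of P={P}")
        # Sample-major split: rows 0..P-1 belong to sample 0, rows P..2P-1 to sample 1, ...
        return patch_outputs.reshape(NP // P, P, *patch_outputs.shape[1:])
    raise ValueError(f"expected 4-D or 5-D patch outputs, got {patch_outputs.ndim}-D")


def crop_to(full: np.ndarray, H: int, W: int) -> np.ndarray:
    """[N, K, H_pad, W_pad] -> [N, K, H, W]; padding sits only at the bottom/right."""
    return full[..., :H, :W]
```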
- denoise.data_utils.stitch_sliding_window_patches_core(patch_outputs: torch.Tensor, coords: List[Tuple[int, int]], meta: Dict[str, int], *, output_size: Tuple[int, int] | None = None, window: Literal['uniform', 'hann', 'cosine'] = 'hann', eps: float = 1e-06) → torch.Tensor
Stitch model outputs from sliding-window patches back into a full image using overlap-add.
- Typical workflow:
    patches, coords, meta = extract_sliding_window_patches_25d(…)
    patches_reshaped = patches.flatten(0, 1)  # [N, P, C, ph, pw] -> [N*P, C, ph, pw]
    logits_patches = model(patches_reshaped)  # then reshape back to [N,P,K,ph,pw]
    full_logits = stitch_sliding_window_patches(logits_patches, coords, meta)
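The reshaping step in this workflow merges the sample and patch axes into one batch axis for the model, then splits them again. A NumPy sketch with a toy stand-in model that returns the middle channel of each 2.5D stack (K = 1); `run_model_on_patches` and `middle_slice_model` are hypothetical. In PyTorch the equivalent calls are `patches.flatten(0, 1)` before the forward pass and `out.unflatten(0, (N, P))` after it.

```python
import numpy as np


def run_model_on_patches(model, patches: np.ndarray) -> np.ndarray:
    """Apply a per-image model to [N, P, C, ph, pw] patches; return [N, P, K, ph, pw]."""
    N, P = patches.shape[:2]
    batch = patches.reshape(N * P, *patches.shape[2:])  # [N*P, C, ph, pw]
    out = model(batch)                                  # [N*P, K, ph, pw]
    return out.reshape(N, P, *out.shape[1:])            # back to [N, P, K, ph, pw]


def middle_slice_model(batch: np.ndarray) -> np.ndarray:
    """Toy stand-in for a denoiser: return the center channel of each stack (K = 1)."""
    c = batch.shape[1] // 2
    return batch[:, c:c + 1]
```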
- Parameters:
patch_outputs –
- Either:
[N, P, K, ph, pw] (recommended)
[N*P, K, ph, pw] (will be reshaped using P from meta)
- where:
N = number of 2.5D stacks / samples
P = number of patches per sample
K = channels (e.g., logits for num_classes)
(ph, pw) = patch spatial size
coords – List of (top, left) coordinates of length P, matching the extraction order.
meta –
- Dict returned by extract_sliding_window_patches_25d, containing:
H_in, W_in, H_pad, W_pad, ph, pw, P, pad_bottom, pad_right, etc.
output_size – If provided, crop final result to (H, W). If None, uses (H_in, W_in).
window –
- Blending window for overlap regions:
“uniform”: simple average
“hann”/“cosine”: smooth blending (recommended)
eps – Small constant to avoid divide-by-zero.
- Returns:
full – [N, K, H_out, W_out] stitched output (cropped to original size by default).