import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional, Literal, Union
def _make_blend_window(
ph: int,
pw: int,
window: Literal["uniform", "hann", "cosine"] = "hann",
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""
Create a 2D blending window of shape [1, 1, ph, pw] for overlap-add stitching.
- "uniform": all ones
- "hann"/"cosine": raised-cosine (Hann) in both dims, outer product -> 2D
"""
if window == "uniform":
w2d = torch.ones((ph, pw), device=device, dtype=dtype)
elif window in ("hann", "cosine"):
# Hann window in PyTorch is periodic by default; set periodic=False for overlap-add style.
wh = torch.hann_window(ph, periodic=False, device=device, dtype=dtype).clamp_min(1e-6)
ww = torch.hann_window(pw, periodic=False, device=device, dtype=dtype).clamp_min(1e-6)
w2d = wh[:, None] * ww[None, :]
else:
raise ValueError("window must be one of: 'uniform', 'hann', 'cosine'")
return w2d[None, None, :, :] # [1,1,ph,pw]
[docs]
def stitch_sliding_window_patches_core(
patch_outputs: torch.Tensor,
coords: List[Tuple[int, int]],
meta: Dict[str, int],
*,
output_size: Optional[Tuple[int, int]] = None,
window: Literal["uniform", "hann", "cosine"] = "hann",
eps: float = 1e-6,
) -> torch.Tensor:
"""
Stitch model outputs from sliding-window patches back into a full image using overlap-add.
Typical workflow:
patches, coords, meta = extract_sliding_window_patches_25d(...)
logits_patches = model(patches_reshaped) # then reshape back to [N,P,K,ph,pw]
full_logits = stitch_sliding_window_patches(logits_patches, coords, meta)
Args:
patch_outputs:
Either:
- [N, P, K, ph, pw] (recommended)
- [N*P, K, ph, pw] (will be reshaped using P from meta)
where:
N = number of 2.5D stacks / samples
P = number of patches per sample
K = channels (e.g., logits for num_classes)
(ph, pw) = patch spatial size
coords:
List of (top, left) coordinates of length P, matching the extraction order.
meta:
Dict returned by extract_sliding_window_patches_2p5d, containing:
H_in, W_in, H_pad, W_pad, ph, pw, P, pad_bottom, pad_right, etc.
output_size:
If provided, crop final result to (H, W). If None, uses (H_in, W_in).
window:
Blending window for overlap regions:
- "uniform": simple average
- "hann"/"cosine": smooth blending (recommended)
eps:
Small constant to avoid divide-by-zero.
Returns:
full: [N, K, H_out, W_out] stitched output (cropped to original size by default).
"""
if patch_outputs.dim() == 4:
# [N*P, K, ph, pw] -> [N, P, K, ph, pw]
NP, K, ph, pw = patch_outputs.shape
P = int(meta["P"])
if NP % P != 0:
raise ValueError(f"Cannot reshape: NP={NP} not divisible by P={P}.")
N = NP // P
patch_outputs = patch_outputs.view(N, P, K, ph, pw)
elif patch_outputs.dim() == 5:
# [N, P, K, ph, pw]
N, P, K, ph, pw = patch_outputs.shape
if P != int(meta["P"]):
raise ValueError(f"P mismatch: patch_outputs has P={P}, meta['P']={meta['P']}.")
else:
raise ValueError(f"Expected patch_outputs dim 4 or 5, got {patch_outputs.dim()}.")
if len(coords) != P:
raise ValueError(f"coords length {len(coords)} != P {P}.")
H_pad = int(meta["H_pad"])
W_pad = int(meta["W_pad"])
# Accumulators
device = patch_outputs.device
dtype = patch_outputs.dtype
acc = torch.zeros((N, K, H_pad, W_pad), device=device, dtype=dtype)
wacc = torch.zeros((N, 1, H_pad, W_pad), device=device, dtype=dtype)
wpatch = _make_blend_window(ph, pw, window=window, device=device, dtype=dtype) # [1,1,ph,pw]
# Overlap-add accumulation
# patch_outputs: [N, P, K, ph, pw]
for p, (top, left) in enumerate(coords):
patch = patch_outputs[:, p, :, :, :] # [N, K, ph, pw]
acc[:, :, top:top + ph, left:left + pw] += patch * wpatch
wacc[:, :, top:top + ph, left:left + pw] += wpatch
full = acc / (wacc + eps) # [N, K, H_pad, W_pad]
# Crop padding back to original size (or a user-specified output_size)
if output_size is None:
H_out, W_out = int(meta["H_in"]), int(meta["W_in"])
else:
H_out, W_out = output_size
full = full[:, :, :H_out, :W_out].contiguous()
return full
def stitch_sliding_window_patches(
    patch_outputs: torch.Tensor,
    coords: List[Tuple[int, int]],
    meta: Dict[str, int],
    *,
    output_size: Optional[Tuple[int, int]] = None,
    window: Literal["uniform", "hann", "cosine"] = "hann",
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Stitch regression outputs from sliding-window patches back into a full image using overlap-add.

    Thin wrapper around ``stitch_sliding_window_patches_core``: all arguments are
    forwarded unchanged. It exists to name the regression use case (e.g.
    predicting the middle slice in a 2.5D setup), where K is typically 1.

    Args:
        patch_outputs:
            Either:
              - [N, P, K, ph, pw] (recommended; K typically = 1 for regression)
              - [N*P, K, ph, pw] (reshaped using P from meta)
        coords:
            List of (top, left) coordinates of length P, matching the extraction order.
        meta:
            Metadata dict describing the padded canvas; expected to come from
            extract_sliding_window_patches_2p5d (keys such as H_in, W_in, H_pad,
            W_pad, P).
        output_size:
            If provided, crop the final result to (H, W). If None, uses (H_in, W_in).
        window:
            Blending window for overlap regions:
              - "uniform": simple average
              - "hann"/"cosine": smooth blending (recommended)
        eps:
            Small constant to avoid divide-by-zero; forwarded to the core stitcher.

    Returns:
        full: [N, K, H_out, W_out] stitched regression output (cropped to the
        original size by default). With K=1 the output is [N, 1, H, W].
    """
    # The core stitcher never inspects channel semantics, so regression outputs
    # go through exactly the same overlap-add path as segmentation logits.
    return stitch_sliding_window_patches_core(
        patch_outputs=patch_outputs,
        coords=coords,
        meta=meta,
        output_size=output_size,
        window=window,
        eps=eps,
    )
class InferenceBatchSizeOptimizer:
    """
    Find the largest inference batch size that runs without error on a CUDA device.

    GPU memory (e.g. a 32GB V100 vs. an 80GB A100), model size, and image size all
    influence how many images fit in one forward pass. Processing one image per
    batch is slow and wasteful, so this class binary-searches batch sizes in
    [1, max_batch_size] by running the model on random inputs of shape
    (batch, n_channels, *input_shape).

    Params:
        model (nn.Module): model used for inference; it is switched to eval mode
            and moved to ``device`` (``nn.Module.to`` works in place).
        input_shape (tuple): spatial shape of one input, e.g. (H, W) for 2.5D
            or (D, H, W) for 3D.
        device (torch.device | str): CUDA device to probe.
        max_batch_size (int): largest batch size to try; must be >= 1.
        precision (str): 'fp32', or 'amp' to run the forward pass under torch.autocast.
        n_channels (int): channel dimension of the dummy input.

    Raises:
        ValueError: if ``precision`` is not 'fp32'/'amp' or ``max_batch_size`` < 1.
    """

    def __init__(self, model: nn.Module, input_shape: tuple, device: torch.device = torch.device('cuda'),
                 max_batch_size: int = 512, precision: str = 'fp32', n_channels: int = 5):
        # Validate before touching the model so a bad argument has no side effects.
        precision = precision.lower()
        if precision not in ['fp32', 'amp']:
            raise ValueError("precision must be either 'fp32' or 'amp'")
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
        # Normalize so string devices ('cuda:0') also expose .index / .type.
        self.device = torch.device(device)
        self.model = model.eval().to(self.device)
        self.input_shape = input_shape  # (H, W) for 2.5d or (D, H, W) for 3d
        self.max_batch_size = max_batch_size
        self.precision = precision
        self.n_channels = n_channels
        self.cached_optimal_batch_size = None

    def get_available_memory(self):
        """Return free device memory in MB (per torch.cuda.mem_get_info) after releasing cached blocks."""
        torch.cuda.empty_cache()
        return torch.cuda.mem_get_info(self.device.index)[0] / 1024**2  # MB

    def estimate_peak_memory(self, batch_size: int) -> float:
        """
        Run one no-grad forward pass at ``batch_size`` and return peak allocated memory in MB.

        Raises:
            RuntimeError: if the forward pass fails (OOM or any other runtime error).
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(self.device)
        dummy_input = torch.randn((batch_size, self.n_channels, *self.input_shape), device=self.device)
        try:
            with torch.no_grad():
                if self.precision == 'amp':
                    with torch.autocast(device_type=self.device.type):
                        self.model(dummy_input)
                else:
                    self.model(dummy_input)
        except RuntimeError as e:
            raise RuntimeError(f"OOM or other error at batch size {batch_size}: {e}") from e
        finally:
            del dummy_input
        return torch.cuda.max_memory_allocated(self.device) / 1024**2  # MB

    def find_optimal_batch_size(self) -> int:
        """
        Binary-search the largest batch size in [1, max_batch_size] whose forward pass succeeds.

        The result is cached on the instance. Assumes that if a batch size
        fails, every larger one fails too.

        Raises:
            RuntimeError: if no batch size succeeds (not even 1), e.g. because the
                model is misconfigured for the dummy input; nothing is cached then.
        """
        if self.cached_optimal_batch_size is not None:
            return self.cached_optimal_batch_size

        low, high = 1, self.max_batch_size
        best = None
        # Keep only the message: holding the exception object would keep its
        # traceback frames (and the GPU tensors they reference) alive.
        last_error = None
        while low <= high:
            mid = (low + high) // 2
            try:
                self.estimate_peak_memory(mid)
            except RuntimeError as err:
                last_error = str(err)
                high = mid - 1
            else:
                best = mid
                low = mid + 1

        if best is None:
            raise RuntimeError(
                "No batch size succeeded (not even 1); cannot determine an inference "
                f"batch size. Last error: {last_error}"
            )
        self.cached_optimal_batch_size = best
        return best

    def profile(self):
        """Return a dict with the optimal batch size, its peak memory (MB), and currently free memory (MB)."""
        batch_size = self.find_optimal_batch_size()
        peak_memory = self.estimate_peak_memory(batch_size)
        available_memory = self.get_available_memory()
        return {
            'optimal_batch_size': batch_size,
            'peak_memory_used_MB': peak_memory,
            'available_memory_MB': available_memory
        }