# Source code for denoise.model3d

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
3D U-Net for Noise2Inverse volumetric denoising.

Architecture source: SSD_3D (Laugros et al., bioRxiv 2025)
  "Self-supervised image restoration in coherent X-ray neuronal microscopy"
  https://doi.org/10.1101/2025.02.10.633538

Original U-Net implementation: ELEKTRONN3 (Martin Drawitsch, MPG)
  https://github.com/ELEKTRONN/elektronn3
  Based on https://github.com/jaxony/unet-pytorch (Jackson Huang, MIT License)

Modifications in this file:
  - Removed test utilities
  - Added unet3d() factory function for N2I-compatible instantiation
"""

# Public API of this module: the network class and its factory function.
__all__ = ['UNet', 'unet3d']

import copy

from typing import Sequence, Union, Tuple, Optional

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.nn import functional as F


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def get_conv(dim=3):
    """Return the convolution class for a 3-D (``nn.Conv3d``) or 2-D (``nn.Conv2d``) network.

    Raises:
        ValueError: if ``dim`` is neither 3 nor 2.
    """
    for ndim, conv_cls in ((3, nn.Conv3d), (2, nn.Conv2d)):
        if dim == ndim:
            return conv_cls
    raise ValueError('dim has to be 2 or 3')


def get_convtranspose(dim=3):
    """Return the transposed-convolution class for a 3-D or 2-D network.

    Raises:
        ValueError: if ``dim`` is neither 3 nor 2.
    """
    for ndim, upconv_cls in ((3, nn.ConvTranspose3d), (2, nn.ConvTranspose2d)):
        if dim == ndim:
            return upconv_cls
    raise ValueError('dim has to be 2 or 3')


def get_maxpool(dim=3):
    """Return the max-pooling class for a 3-D (``nn.MaxPool3d``) or 2-D (``nn.MaxPool2d``) network.

    Raises:
        ValueError: if ``dim`` is neither 3 nor 2.
    """
    for ndim, pool_cls in ((3, nn.MaxPool3d), (2, nn.MaxPool2d)):
        if dim == ndim:
            return pool_cls
    raise ValueError('dim has to be 2 or 3')


def get_normalization(normtype: Optional[str], num_channels: int, dim: int = 3) -> nn.Module:
    """Create a normalization layer for ``num_channels`` feature channels.

    Args:
        normtype: One of
            ``None`` or ``'none'`` -- identity (no normalization);
            ``'group'`` -- GroupNorm with 8 groups;
            ``'group<G>'`` -- GroupNorm with ``G`` groups (``G >= 1``);
            ``'instance'`` -- InstanceNorm;
            ``'batch'`` -- BatchNorm;
            ``'layer'`` -- GroupNorm with a single group.
        num_channels: Number of channels of the tensor being normalised.
            For the group-norm variants it must be divisible by the group
            count.
        dim: ``3`` selects the 3-D instance/batch-norm classes; any other
            value selects the 2-D ones.

    Returns:
        The normalization module.

    Raises:
        ValueError: for an unknown ``normtype`` or a malformed or zero
            group count.
    """
    if normtype is None or normtype == 'none':
        return nn.Identity()

    if normtype.startswith('group'):
        group_spec = normtype[len('group'):]
        if not group_spec:
            num_groups = 8
        elif group_spec.isdecimal() and int(group_spec) > 0:
            num_groups = int(group_spec)
        else:
            # Rejects e.g. 'groupx' and 'group0'; a zero group count would
            # otherwise fail inside nn.GroupNorm with an unrelated error.
            raise ValueError(
                f'normtype "{normtype}" not understood. Use "group<G>".'
            )
        return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)

    if normtype == 'instance':
        return nn.InstanceNorm3d(num_channels) if dim == 3 else nn.InstanceNorm2d(num_channels)
    if normtype == 'batch':
        return nn.BatchNorm3d(num_channels) if dim == 3 else nn.BatchNorm2d(num_channels)
    if normtype == 'layer':
        # A single group normalises over all channels and spatial positions.
        return nn.GroupNorm(1, num_channels=num_channels)
    raise ValueError(f'Unknown normalization type "{normtype}".')


def planar_kernel(x):
    """Expand an integer size ``k`` to ``(1, k, k)`` so the first spatial axis spans a single element.

    Non-integer values (e.g. tuples) are returned unchanged.
    """
    if not isinstance(x, int):
        return x
    return 1, x, x


def planar_pad(x):
    """Expand an integer padding ``p`` to ``(0, p, p)`` (no padding along the first spatial axis).

    Non-integer values (e.g. tuples) are returned unchanged.
    """
    if not isinstance(x, int):
        return x
    return 0, x, x


def conv3(in_channels, out_channels, kernel_size=3, stride=1,
          padding=1, bias=True, planar=False, dim=3):
    """Create a convolution layer (3x3(x3) with padding 1 by default).

    When ``planar`` is true, integer ``kernel_size``/``stride`` become
    ``(1, k, k)`` and integer ``padding`` becomes ``(0, p, p)``, so the kernel
    spans a single slice along the first spatial axis. Tuple arguments are
    passed through unchanged.
    """
    conv_cls = get_conv(dim)
    if not planar:
        return conv_cls(in_channels, out_channels, kernel_size=kernel_size,
                        stride=stride, padding=padding, bias=bias)
    return conv_cls(
        in_channels, out_channels,
        kernel_size=planar_kernel(kernel_size),
        stride=planar_kernel(stride),
        padding=planar_pad(padding),
        bias=bias,
    )


def upconv2(in_channels, out_channels, mode='transpose', planar=False, dim=3):
    """Create a 2x upsampling layer for the decoder path.

    Args:
        in_channels: Channels of the tensor to upsample.
        out_channels: Channels produced by the layer.
        mode: ``'transpose'`` for a kernel-2 / stride-2 transposed convolution,
            or a mode containing ``'resizeconv'`` for interpolation followed by
            a convolution (:class:`ResizeConv`). For the latter, ``'linear'``
            in the name selects (tri/bi)linear instead of nearest-neighbour
            interpolation, and a trailing ``'1'`` selects a 1x1 convolution
            instead of 3x3.
        planar: Use ``(1, 2, 2)`` kernel/stride so the first spatial axis is
            not upsampled.
        dim: 3 or 2.

    Raises:
        ValueError: if ``mode`` is not recognised.
    """
    kernel_size = 2
    stride = 2
    if planar:
        kernel_size = planar_kernel(kernel_size)
        stride = planar_kernel(stride)
    if mode == 'transpose':
        return get_convtranspose(dim)(in_channels, out_channels,
                                     kernel_size=kernel_size, stride=stride)
    if 'resizeconv' in mode:
        upsampling_mode = ('trilinear' if dim == 3 else 'bilinear') if 'linear' in mode else 'nearest'
        rc_kernel_size = 1 if mode.endswith('1') else 3
        return ResizeConv(in_channels, out_channels, planar=planar, dim=dim,
                          upsampling_mode=upsampling_mode, kernel_size=rc_kernel_size)
    # Previously fell through and returned None, which only failed at forward time.
    raise ValueError(f'Invalid up_mode: {mode}')


def conv1(in_channels, out_channels, dim=3):
    """Create a 1x1(x1) convolution (mixes channels only), using the layer's default bias."""
    conv_cls = get_conv(dim)
    return conv_cls(in_channels, out_channels, kernel_size=1)


def get_activation(activation):
    """Instantiate an activation module.

    Args:
        activation: Either one of the names ``'relu'``, ``'leaky'``
            (LeakyReLU with slope 0.1), ``'prelu'`` (single shared
            parameter), ``'rrelu'``, ``'silu'``, ``'lin'`` (identity), or any
            other object (typically an ``nn.Module`` instance), which is
            deep-copied so that each layer gets its own copy.

    Returns:
        A new activation module.

    Raises:
        ValueError: if ``activation`` is a string that is not a known name.
            (Previously ``None`` was returned, which only failed at forward
            time with an opaque "not callable" error.)
    """
    if not isinstance(activation, str):
        return copy.deepcopy(activation)

    # Factories, so every call returns a fresh module (PReLU has parameters).
    factories = {
        'relu': nn.ReLU,
        'leaky': lambda: nn.LeakyReLU(negative_slope=0.1),
        'prelu': lambda: nn.PReLU(num_parameters=1),
        'rrelu': nn.RReLU,
        'silu': nn.SiLU,
        'lin': nn.Identity,
    }
    try:
        factory = factories[activation]
    except KeyError:
        raise ValueError(
            f'Unknown activation "{activation}". '
            f'Expected one of {sorted(factories)} or an nn.Module instance.'
        ) from None
    return factory()


# ---------------------------------------------------------------------------
# Network blocks
# ---------------------------------------------------------------------------

class DownConv(nn.Module):
    """Encoder stage: two (conv -> norm -> activation) layers, then an optional 2x max-pool.

    ``forward(x)`` returns ``(pooled, before_pool)``. ``before_pool`` is the
    un-pooled feature map that the decoder uses as its skip connection.
    Without pooling, ``pooled`` is the same tensor as ``before_pool``
    (``nn.Identity``).
    """
    def __init__(self, in_channels, out_channels, pooling=True, planar=False,
                 activation='relu', normalization=None, full_norm=True, dim=3, conv_mode='same'):
        super().__init__()
        self.pooling = pooling
        self.dim = dim
        # 'same'-style modes pad by one so 3x3(x3) convs keep the spatial size.
        pad = 0
        if 'same' in conv_mode:
            pad = 1

        self.conv1 = conv3(in_channels, out_channels, planar=planar, dim=dim, padding=pad)
        self.conv2 = conv3(out_channels, out_channels, planar=planar, dim=dim, padding=pad)

        if not self.pooling:
            self.pool = nn.Identity()
            # Placeholder meaning "no pooling"; not read elsewhere in the visible code.
            self.pool_ks = -123
        else:
            pool_kernel = planar_kernel(2) if planar else 2
            # ceil_mode=True keeps partial windows, so odd sizes round up
            # instead of dropping their last row/column/slice.
            self.pool = get_maxpool(dim)(kernel_size=pool_kernel, ceil_mode=True)
            self.pool_ks = pool_kernel

        self.act1 = get_activation(activation)
        self.act2 = get_activation(activation)
        # With full_norm=False only the second convolution is normalised.
        if full_norm:
            self.norm0 = get_normalization(normalization, out_channels, dim=dim)
        else:
            self.norm0 = nn.Identity()
        self.norm1 = get_normalization(normalization, out_channels, dim=dim)

    def forward(self, x):
        h = self.conv1(x)
        h = self.act1(self.norm0(h))
        h = self.act2(self.norm1(self.conv2(h)))
        return self.pool(h), h


@torch.jit.script
def autocrop(from_down: torch.Tensor, from_up: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Crop encoder and upsampled decoder features so they can be merged.

    Two kinds of spatial mismatch are handled (batch and channel axes, i.e.
    dims 0 and 1, are never cropped):

    1. Parity: ceil-mode pooling of an odd-sized encoder map followed by 2x
       upsampling yields a decoder map that is one element larger. On every
       axis where the size difference is odd, ``from_up`` is trimmed by one
       element at the end.
    2. Size: if ``from_down`` is still larger (e.g. because unpadded
       convolutions shrank the decoder features), it is center-cropped to
       ``from_up``'s spatial size.

    The slicing handles 4-D and 5-D tensors only.

    Args:
        from_down: Feature map from the encoder pathway (skip connection).
        from_up: Upsampled feature map from the decoder pathway.

    Returns:
        ``(from_down, from_up)`` with matching spatial shapes.

    Raises:
        AssertionError: if, after the parity crop, ``from_up`` is larger than
            ``from_down`` along a checked spatial axis.
    """
    # Fast path: nothing to crop.
    if from_down.shape[2:] == from_up.shape[2:]:
        return from_down, from_up

    # Step 1: drop one trailing element of from_up on each axis whose size
    # difference to from_down is odd.
    ds = from_down.shape[2:]
    us = from_up.shape[2:]
    upcrop = [u - ((u - d) % 2) for d, u in zip(ds, us)]

    ndim = from_down.dim()  # .ndim is not used here; .dim() is the scriptable form
    if ndim == 4:
        from_up = from_up[:, :, :upcrop[0], :upcrop[1]]
    if ndim == 5:
        from_up = from_up[:, :, :upcrop[0], :upcrop[1], :upcrop[2]]

    # Step 2: center-crop from_down to from_up's (now cropped) spatial size.
    ds = from_down.shape[2:]
    us = from_up.shape[2:]
    assert ds[0] >= us[0], f'{ds, us}'
    assert ds[1] >= us[1]
    if ndim == 4:
        from_down = from_down[
            :, :,
            (ds[0] - us[0]) // 2:(ds[0] + us[0]) // 2,
            (ds[1] - us[1]) // 2:(ds[1] + us[1]) // 2,
        ]
    elif ndim == 5:
        assert ds[2] >= us[2]
        from_down = from_down[
            :, :,
            (ds[0] - us[0]) // 2:(ds[0] + us[0]) // 2,
            (ds[1] - us[1]) // 2:(ds[1] + us[1]) // 2,
            (ds[2] - us[2]) // 2:(ds[2] + us[2]) // 2,
        ]
    return from_down, from_up


class DummyAttention(nn.Module):
    """No-op stand-in for :class:`GridAttention`, used when attention is disabled.

    ``forward(x, g)`` ignores the gating signal ``g`` and returns ``x``
    unchanged together with ``None`` in place of an attention map.
    """
    def forward(self, x, g):
        attention_map = None
        return x, attention_map


class GridAttention(nn.Module):
    """Additive grid attention gate (Oktay et al., 2018).

    The skip features ``x`` are projected by a strided convolution
    (``theta``, stride ``sub_sample_factor``). The gating signal ``g`` is
    projected by a 1x1 convolution (``phi``) and interpolated to the same
    grid. The rectified sum is reduced to a single-channel sigmoid map
    (``psi``), which is interpolated back to ``x``'s spatial size and
    multiplied onto ``x``. The product then goes through a 1x1 convolution
    plus batch norm (``w``).

    ``forward(x, g)`` returns ``(gated_features, attention_map)``.
    """
    def __init__(self, in_channels, gating_channels, inter_channels=None,
                 dim=3, sub_sample_factor=2):
        super().__init__()
        assert dim in [2, 3]
        self.dim = dim
        self.sub_sample_factor = tuple(sub_sample_factor for _ in range(dim))
        self.sub_sample_kernel_size = self.sub_sample_factor
        self.in_channels = in_channels
        self.gating_channels = gating_channels
        # Falsy inter_channels (None or 0) falls back to half the input channels.
        self.inter_channels = inter_channels or max(1, in_channels // 2)

        if dim == 3:
            conv_nd, norm_nd, self.upsample_mode = nn.Conv3d, nn.BatchNorm3d, 'trilinear'
        else:
            conv_nd, norm_nd, self.upsample_mode = nn.Conv2d, nn.BatchNorm2d, 'bilinear'

        # Submodules are created in this order on purpose: it fixes both the
        # state_dict layout and the order of random initialisation.
        self.w = nn.Sequential(
            conv_nd(in_channels, in_channels, kernel_size=1),
            norm_nd(in_channels),
        )
        self.theta = conv_nd(in_channels, self.inter_channels,
                             kernel_size=self.sub_sample_kernel_size,
                             stride=self.sub_sample_factor, bias=False)
        self.phi = conv_nd(gating_channels, self.inter_channels,
                           kernel_size=1, stride=1, padding=0, bias=True)
        self.psi = conv_nd(self.inter_channels, 1, kernel_size=1, stride=1, bias=True)
        self.init_weights()

    def forward(self, x, g):
        theta_x = self.theta(x)
        grid = theta_x.shape[2:]
        phi_g = F.interpolate(self.phi(g), size=grid,
                              mode=self.upsample_mode, align_corners=False)
        logits = self.psi(F.relu(theta_x + phi_g, inplace=True))
        attn = F.interpolate(torch.sigmoid(logits), size=x.shape[2:],
                             mode=self.upsample_mode, align_corners=False)
        gated = attn.expand_as(x) * x
        return self.w(gated), attn

    def init_weights(self):
        """Kaiming-normal conv/linear weights; batch-norm weight ~ N(1, 0.02) and zero bias."""
        def _init(module):
            name = type(module).__name__
            if 'Conv' in name or 'Linear' in name:
                nn.init.kaiming_normal_(module.weight.data, a=0, mode='fan_in')
            elif 'BatchNorm' in name:
                nn.init.normal_(module.weight.data, 1.0, 0.02)
                nn.init.constant_(module.bias.data, 0.0)
        self.apply(_init)


class UpConv(nn.Module):
    """Decoder stage: 2x upsampling, merge with encoder skip features, then two (conv -> norm -> activation) layers.

    ``forward(enc, dec)`` takes the encoder skip tensor ``enc`` and the
    coarser decoder tensor ``dec``. It returns ``out_channels`` feature maps
    at ``enc``'s resolution (after :func:`autocrop`).

    If attention is enabled, the gate receives the skip features and the
    *pre-upsampling* ``dec`` as its gating signal. The resulting attention
    map is stored in ``self.att`` when running eagerly (not under
    TorchScript).
    """
    att: Optional[torch.Tensor]  # type declaration for self.att (lets TorchScript type the attribute)

    def __init__(self, in_channels, out_channels, merge_mode='concat', up_mode='transpose',
                 planar=False, activation='relu', normalization=None, full_norm=True,
                 dim=3, conv_mode='same', attention=False):
        super().__init__()
        self.merge_mode = merge_mode
        self.up_mode = up_mode
        # 'same'-style modes pad by one so 3x3(x3) convs keep the spatial size.
        pad = 0
        if 'same' in conv_mode:
            pad = 1

        self.upconv = upconv2(in_channels, out_channels, mode=up_mode, planar=planar, dim=dim)

        # Concatenation doubles the channel count seen by the first conv.
        conv1_in = out_channels * 2 if merge_mode == 'concat' else out_channels
        self.conv1 = conv3(conv1_in, out_channels, planar=planar, dim=dim, padding=pad)
        self.conv2 = conv3(out_channels, out_channels, planar=planar, dim=dim, padding=pad)

        self.act0 = get_activation(activation)
        self.act1 = get_activation(activation)
        self.act2 = get_activation(activation)

        # With full_norm=False only the last convolution is normalised.
        if not full_norm:
            self.norm0 = nn.Identity()
            self.norm1 = nn.Identity()
        else:
            self.norm0 = get_normalization(normalization, out_channels, dim=dim)
            self.norm1 = get_normalization(normalization, out_channels, dim=dim)
        self.norm2 = get_normalization(normalization, out_channels, dim=dim)

        if attention:
            # NOTE(review): assumes the skip tensor has in_channels // 2
            # channels, which holds when UNet builds this stage.
            self.attention = GridAttention(in_channels // 2, in_channels, dim=dim)
        else:
            self.attention = DummyAttention()
        self.att = None

    def forward(self, enc, dec):
        up = self.upconv(dec)
        enc, up = autocrop(enc, up)
        gated_enc, att = self.attention(enc, dec)
        if not torch.jit.is_scripting():
            self.att = att
        up = self.act0(self.norm0(up))
        if self.merge_mode == 'concat':
            merged = torch.cat((up, gated_enc), 1)
        else:
            merged = up + gated_enc
        out = self.act1(self.norm1(self.conv1(merged)))
        return self.act2(self.norm2(self.conv2(out)))


class ResizeConv(nn.Module):
    """Upsample by 2 with interpolation, then convolve: an alternative to transposed convolution.

    Args:
        kernel_size: 3 for a 3x3(x3) convolution with padding 1, or 1 for a
            1x1(x1) convolution. Other values raise ``ValueError``.
        planar: For ``dim == 3``, scale by ``(1, 2, 2)`` so the first spatial
            axis is not upsampled.
        upsampling_mode: Interpolation mode passed to ``nn.Upsample``.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, planar=False,
                 dim=3, upsampling_mode='nearest'):
        super().__init__()
        if dim == 3 and planar:
            self.scale_factor = planar_kernel(2)  # (1, 2, 2)
        else:
            self.scale_factor = 2
        self.upsample = nn.Upsample(scale_factor=self.scale_factor, mode=upsampling_mode)

        if kernel_size == 1:
            self.conv = conv1(in_channels, out_channels, dim=dim)
        elif kernel_size == 3:
            self.conv = conv3(in_channels, out_channels, padding=1, planar=planar, dim=dim)
        else:
            raise ValueError(f'kernel_size={kernel_size} not supported.')

    def forward(self, x):
        upsampled = self.upsample(x)
        return self.conv(upsampled)


# ---------------------------------------------------------------------------
# Main UNet
# ---------------------------------------------------------------------------

class UNet(nn.Module):
    """
    U-Net with skip connections for volumetric image restoration.

    Encoder: ``n_blocks`` :class:`DownConv` stages with ``start_filts * 2**i``
    filters. Every stage except the deepest is followed by a 2x max-pool
    (``ceil_mode=True``). Decoder: ``n_blocks - 1`` :class:`UpConv` stages that
    halve the filter count and merge with the matching encoder output. A final
    1x1 convolution maps to ``out_channels``.

    Input:  [B, in_channels, D, H, W]   ([B, in_channels, H, W] for dim=2)
    Output: [B, out_channels, D, H, W]

    With ``conv_mode='same'`` the output has the input's spatial size. Odd
    intermediate sizes are handled by ceil-mode pooling plus :func:`autocrop`.
    Sizes divisible by ``2**(n_blocks - 1)`` need no cropping at all.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the final 1x1 convolution.
        n_blocks: Number of encoder stages (>= 1).
        start_filts: Filters of the first encoder stage; stage ``i`` uses
            ``start_filts * 2**i``.
        up_mode: ``'transpose'`` or one of the ``'resizeconv_*'`` variants
            (see :func:`upconv2`).
        merge_mode: ``'concat'`` or ``'add'`` for combining skip and
            upsampled features.
        planar_blocks: Indices of encoder stages (and their mirrored decoder
            stages) that use ``(1, k, k)`` kernels. Such stages do not
            convolve, pool or upsample along the first spatial axis.
        batch_norm: Deprecated; must be left at ``'unset'`` (use
            ``normalization`` instead).
        attention: Add :class:`GridAttention` gates to the skip connections.
        activation: Activation name or module (see :func:`get_activation`).
        normalization: Normalization type (see :func:`get_normalization`).
        full_norm: If False, only the last convolution of each stage is
            followed by normalization.
        dim: 3 for volumetric, 2 for planar images.
        conv_mode: ``'same'`` pads 3x3(x3) convolutions by one. Any value not
            containing ``'same'`` disables padding, so feature maps shrink.

    Raises:
        ValueError: for invalid ``n_blocks``, ``dim``, ``up_mode`` or
            ``merge_mode``.
        RuntimeError: if the deprecated ``batch_norm`` argument is used.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        n_blocks: int = 3,
        start_filts: int = 32,
        up_mode: str = 'transpose',
        merge_mode: str = 'concat',
        planar_blocks: Sequence = (),
        batch_norm: str = 'unset',
        attention: bool = False,
        activation: Union[str, nn.Module] = 'relu',
        normalization: str = 'layer',
        full_norm: bool = True,
        dim: int = 3,
        conv_mode: str = 'same',
    ):
        super().__init__()
        if n_blocks < 1:
            raise ValueError('n_blocks must be >= 1.')
        if dim not in {2, 3}:
            raise ValueError('dim must be 2 or 3.')
        if batch_norm != 'unset':
            raise RuntimeError('Use normalization= instead of batch_norm=.')
        if up_mode not in ('transpose', 'resizeconv_nearest', 'resizeconv_linear',
                           'resizeconv_nearest1', 'resizeconv_linear1'):
            raise ValueError(f'Invalid up_mode: {up_mode}')
        if merge_mode not in ('concat', 'add'):
            raise ValueError(f'Invalid merge_mode: {merge_mode}')

        self.out_channels = out_channels
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.n_blocks = n_blocks
        self.planar_blocks = planar_blocks
        self.normalization = normalization
        self.attention = attention
        self.conv_mode = conv_mode
        self.activation = activation
        self.dim = dim
        self.up_mode = up_mode
        self.merge_mode = merge_mode

        self.down_convs = nn.ModuleList()
        self.up_convs = nn.ModuleList()

        # Encoder: the filter count doubles per stage; the deepest stage
        # (the bottleneck) is not pooled.
        outs = None
        for i in range(n_blocks):
            ins = in_channels if i == 0 else outs
            outs = start_filts * (2 ** i)
            planar = i in planar_blocks
            self.down_convs.append(DownConv(
                ins, outs,
                pooling=(i < n_blocks - 1),
                planar=planar,
                activation=activation,
                normalization=normalization,
                full_norm=full_norm,
                dim=dim,
                conv_mode=conv_mode,
            ))

        # Decoder: mirror the encoder, halving the filter count per stage.
        for i in range(n_blocks - 1):
            ins = outs
            outs = ins // 2
            # Decoder stage i merges with encoder stage n_blocks - 2 - i.
            planar = (n_blocks - 2 - i) in planar_blocks
            self.up_convs.append(UpConv(
                ins, outs,
                up_mode=up_mode,
                merge_mode=merge_mode,
                planar=planar,
                activation=activation,
                normalization=normalization,
                attention=attention,
                full_norm=full_norm,
                dim=dim,
                conv_mode=conv_mode,
            ))

        self.conv_final = conv1(outs, out_channels, dim=dim)
        self.apply(self.weight_init)

    @staticmethod
    def weight_init(m):
        """Xavier-normal init for (transposed) convolution weights; zero biases."""
        # NOTE(review): nn.Module.apply also visits GridAttention's child
        # convolutions individually, so this early return only skips the
        # container itself. The attention convs are still re-initialised here.
        if isinstance(m, GridAttention):
            return
        if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        encoder_outs = []
        for module in self.down_convs:
            x, before_pool = module(x)
            encoder_outs.append(before_pool)
        for i, module in enumerate(self.up_convs):
            # encoder_outs[-1] is the bottleneck (already in x), so the first
            # decoder stage pairs with encoder_outs[-2].
            x = module(encoder_outs[-(i + 2)], x)
        return self.conv_final(x)

    @torch.jit.unused
    def forward_gradcp(self, x):
        """Forward pass with gradient checkpointing.

        Trades extra compute (each block is re-run during backward) for lower
        activation memory. Non-reentrant checkpointing is used on purpose.
        With the reentrant variant, a segment whose inputs do not require grad
        produces an output without grad. The raw input volume does not require
        grad, so that would cascade through every checkpointed block, and only
        ``conv_final`` would receive gradients.
        """
        encoder_outs = []
        for module in self.down_convs:
            x, before_pool = checkpoint(module, x, use_reentrant=False)
            encoder_outs.append(before_pool)
        for i, module in enumerate(self.up_convs):
            x = checkpoint(module, encoder_outs[-(i + 2)], x, use_reentrant=False)
        return self.conv_final(x)
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
def unet3d(
    in_channels: int = 1,
    out_channels: int = 1,
    n_blocks: int = 3,
    start_filts: int = 32,
    normalization: str = 'layer',
    activation: str = 'relu',
) -> UNet:
    """
    Factory function matching the configuration used in Laugros et al. 2025.

    Builds a 3-D :class:`UNet` with transposed-convolution upsampling,
    concatenation skips, no planar blocks, no attention and 'same' padding.

    Input shape:  [B, in_channels, D, H, W]
    Output shape: [B, out_channels, D, H, W]

    The encoder pools ``n_blocks - 1`` times. D, H and W divisible by
    ``2**(n_blocks - 1)`` (e.g. 4 for ``n_blocks=3``) keep every stage evenly
    sized. Other sizes are handled by ceil-mode pooling and cropping, and the
    output spatial size still equals the input's.

    Parameters
    ----------
    in_channels : int
        Number of input channels (1 for single-channel volumes).
    out_channels : int
        Number of output channels (1 for denoising).
    n_blocks : int
        Encoder depth. Controls receptive field and GPU memory use.
        Default 3 is a good balance; use 4 for deeper context (needs more memory).
    start_filts : int
        Filter count of the first encoder block. Subsequent blocks double this.
    normalization : str
        Normalisation type: 'layer', 'group' (8 groups), 'group<G>',
        'instance', 'batch', or 'none'.
    activation : str
        Activation function: 'relu', 'leaky', 'prelu', 'rrelu', 'silu', or
        'lin' (identity).

    Returns
    -------
    UNet
        The configured network.
    """
    return UNet(
        in_channels=in_channels,
        out_channels=out_channels,
        n_blocks=n_blocks,
        start_filts=start_filts,
        up_mode='transpose',
        merge_mode='concat',
        planar_blocks=(),
        attention=False,
        activation=activation,
        normalization=normalization,
        full_norm=True,
        dim=3,
        conv_mode='same',
    )