Source code for denoise.model

import torch
import torch.nn as nn
import torch.nn.functional as F

class unet_box_gn(torch.nn.Module):
    """Double convolution block with group normalisation.

    Applies two consecutive stages of 3x3 convolution (padding 1),
    ``GroupNorm`` and ``LeakyReLU(0.1)``.  The first stage maps ``in_ch``
    to ``out_ch`` channels; the second keeps ``out_ch`` channels.  With a
    3x3 kernel and padding 1 the spatial size is unchanged.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels produced by both convolutions.
        groups: ``num_groups`` passed to both ``GroupNorm`` layers.
    """

    def __init__(self, in_ch, out_ch, groups):
        super().__init__()
        stages = []
        # Stage 1 consumes in_ch, stage 2 consumes the out_ch it produced.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.GroupNorm(groups, out_ch),
                nn.LeakyReLU(negative_slope=0.1, inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/norm/activation stages."""
        return self.double_conv(x)
class unet_bottleneck_gn(torch.nn.Module):
    """Single convolution block with group normalisation.

    One 3x3 convolution (padding 1) followed by ``GroupNorm`` and
    ``LeakyReLU(0.1)``.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels produced by the convolution.
        groups: ``num_groups`` passed to the ``GroupNorm`` layer.
    """

    def __init__(self, in_ch, out_ch, groups):
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        norm = nn.GroupNorm(groups, out_ch)
        activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.bn_conv = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply the conv/norm/activation stack to ``x``."""
        return self.bn_conv(x)
class unet_up(torch.nn.Module):
    """Upsampling step: doubles the spatial size with nearest-neighbour mode.

    Args:
        ch: Accepted for signature symmetry with the other blocks but not
            used; the layer has no parameters.
    """

    def __init__(self, ch,):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2, mode='nearest')
        # NOTE: the attribute is called ``down_scale`` although it upsamples;
        # the name is kept because external code may access it.
        self.down_scale = nn.Sequential(upsample)

    def forward(self, x):
        """Return ``x`` upsampled by a factor of two."""
        return self.down_scale(x)
class unet_down(torch.nn.Module):
    """Downsampling step: 2x2 max pooling.

    Args:
        ch: Accepted for signature symmetry with the other blocks but not
            used; the layer has no parameters.
    """

    def __init__(self, ch):
        super().__init__()
        pool = nn.MaxPool2d(2)
        self.maxpool = nn.Sequential(pool)

    def forward(self, x):
        """Return ``x`` after 2x2 max pooling."""
        return self.maxpool(x)
class unet_ns_gn(torch.nn.Module):
    """U-Net-shaped encoder/decoder with GroupNorm and no skip connections.

    The network is a plain chain: a 1x1 input projection, three
    conv-block + max-pool stages, a bottleneck, three upsample +
    conv-block stages and a 1x1 output head.  Nothing is concatenated
    between the encoder and the decoder (presumably what "ns" in the
    name stands for).

    Channel widths, with ``f = start_filter_size``::

        ich -> f -> 4f -> 8f -> 16f -> 16f -> 8f -> 4f -> 4f -> 2f -> och

    The spatial size is halved three times and doubled three times, so
    the output matches the input size only when height and width are
    multiples of 8 (max pooling floors odd sizes).

    Args:
        start_filter_size: Base channel width ``f``.
        ich: Number of input channels.
        och: Number of output channels.
        channels_per_group: Target channels per GroupNorm group.  Each
            norm layer uses ``int(channels / channels_per_group)`` groups.

    Raises:
        ValueError: If ``channels_per_group`` is not positive, or if a
            layer would end up with zero groups (``start_filter_size``
            smaller than ``channels_per_group``).  GroupNorm itself still
            raises ``ValueError`` when the channel count is not divisible
            by the resulting group count.
    """

    def __init__(self, start_filter_size, ich=1, och=1, channels_per_group=8):
        super().__init__()

        if channels_per_group <= 0:
            raise ValueError(
                f"channels_per_group must be positive, got {channels_per_group}"
            )

        def groups_for(channels):
            # Same truncating arithmetic as before; the explicit check turns
            # the zero-group case into a descriptive error instead of an
            # opaque failure inside GroupNorm.
            groups = int(channels / channels_per_group)
            if groups < 1:
                raise ValueError(
                    f"{channels} channels with channels_per_group="
                    f"{channels_per_group} gives zero GroupNorm groups; "
                    "increase start_filter_size or lower channels_per_group"
                )
            return groups

        f1 = start_filter_size
        f2 = start_filter_size * 2
        f4 = start_filter_size * 4
        f8 = start_filter_size * 8
        f16 = start_filter_size * 16

        # Submodules are created in the original order so parameter
        # initialisation and state_dict ordering are unchanged.
        self.in_box = torch.nn.Sequential(
            torch.nn.Conv2d(ich, f1, kernel_size=1, padding=0),
            nn.GroupNorm(num_groups=groups_for(f1), num_channels=f1),
            nn.LeakyReLU(.1, inplace=True),
        )

        # Encoder
        self.box1 = unet_box_gn(f1, f4, groups=groups_for(f4))
        self.down1 = unet_down(f4)
        self.box2 = unet_box_gn(f4, f8, groups=groups_for(f8))
        self.down2 = unet_down(f8)
        self.box3 = unet_box_gn(f8, f16, groups=groups_for(f16))
        self.down3 = unet_down(f16)

        self.bottleneck = unet_bottleneck_gn(f16, f16, groups=groups_for(f16))

        # Decoder
        self.up1 = unet_up(f16)
        self.box4 = unet_box_gn(f16, f8, groups=groups_for(f8))
        self.up2 = unet_up(f8)
        self.box5 = unet_box_gn(f8, f4, groups=groups_for(f4))
        self.up3 = unet_up(f4)
        self.box6 = unet_box_gn(f4, f4, groups=groups_for(f4))

        self.out_layer = torch.nn.Sequential(
            torch.nn.Conv2d(f4, f2, kernel_size=1, padding=0),
            nn.GroupNorm(num_groups=groups_for(f2), num_channels=f2),
            nn.LeakyReLU(.1, inplace=True),
            torch.nn.Conv2d(f2, och, kernel_size=1, padding=0),
        )

    def forward(self, x):
        """Run ``x`` through every stage in order and return the result."""
        stages = (
            self.in_box,
            self.box1, self.down1,
            self.box2, self.down2,
            self.box3, self.down3,
            self.bottleneck,
            self.up1, self.box4,
            self.up2, self.box5,
            self.up3, self.box6,
            self.out_layer,
        )
        output = x
        for stage in stages:
            output = stage(output)
        return output