#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Denoise a single CT slice using a trained Noise2Inverse model.
"""
import os
import yaml
import numpy as np
import torch
import tifffile
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
from denoise.model import unet_ns_gn
from denoise.utils import save2img
from denoise.data_utils import extract_sliding_window_patches_25d, stitch_sliding_window_patches
from denoise import tiffs as tiffs_mod
from denoise import log
[docs]
def run(args):
# Read the YAML file
with open(args.config, 'r') as file:
params = yaml.safe_load(file)
# 3D mode operates on full volumes — single-slice inference is not meaningful
mode = getattr(args, 'mode', None) or params['train'].get('mode', '2.5d')
if mode == '3d':
raise RuntimeError(
"The 'slice' command is not available in --mode 3d.\n"
"3D denoising processes full volumes. Use:\n\n"
" denoise volume --config %s --mode 3d" % args.config
)
# create directory for denoised slices
full_recon_name = params['dataset']['full_recon_name']
base_name = full_recon_name[:-4] if full_recon_name.endswith('_rec') else full_recon_name
out_path = params['dataset']['directory_to_reconstructions'] + '/' + base_name + '_denoised_slices'
if not os.path.isdir(out_path):
os.mkdir(out_path)
# setup cuda device
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load in model
n_slices = params['train']['n_slices']
_ckpt_map = {'val': 'best_val_model.pth', 'lcl': 'best_lcl_model.pth', 'edge': 'best_edge_model.pth'}
ckpt_name = _ckpt_map[getattr(args, 'checkpoint', 'lcl')]
model_dir = getattr(args, 'model_dir', None) or \
params['dataset']['directory_to_reconstructions'] + '/TrainOutput'
path_to_mdl = os.path.join(model_dir, ckpt_name)
log.info("Using checkpoint: %s" % ckpt_name)
checkpoint = torch.load(path_to_mdl, map_location=torch.device('cpu'), weights_only=False)
model = unet_ns_gn(ich=n_slices, start_filter_size=16, channels_per_group=8)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(dev)
log.info("Loading slice %d" % args.slice_number)
# path to data
full_recon_path = params['dataset']['directory_to_reconstructions'] + '/' + params['dataset']['full_recon_name']
# collect tiff files
tiffs_collection = tiffs_mod.glob(full_recon_path)
# supports 2.5D modeling
S = len(tiffs_collection)
left = n_slices // 2
right = n_slices - 1 - left
offsets = np.arange(-left, right + 1, dtype=int)
idxs = args.slice_number + offsets
idxs_mapped = np.clip(idxs, 0, S - 1)
# get image slice
list_of_images_to_process = [tiffs_collection[img_num] for img_num in idxs_mapped]
# load in data
images, _, _ = tiffs_mod.load_stack(list_of_images_to_process)
images = torch.from_numpy(images[np.newaxis]).to(dev)
# normalize image stack using training mean/std
mean4norm = params['dataset']['mean4norm']
std4norm = params['dataset']['std4norm']
images = (images - mean4norm) / std4norm
psz = params['train']['psz']
patches, coords, meta = extract_sliding_window_patches_25d(
images,
patch_size=(psz, psz),
overlap=params['infer']['overlap'],
pad_mode="reflect",
return_coords=True,
)
denoised_patches = torch.zeros((1, meta["P"], 1, psz, psz))
# denoise image
with torch.no_grad():
for i in tqdm(range(patches.shape[1])):
denoised = model(patches[:, i])
denoised_patches[:, i] = denoised
denoised = stitch_sliding_window_patches(
denoised_patches, coords, meta, window=params['infer']['window']
).cpu().squeeze().numpy()
# rescale back to original values
denoised = denoised * std4norm + mean4norm
# save denoised slice
tifffile.imwrite(f'{out_path}/{args.slice_number:05d}.tiff', denoised)
log.info("Saved denoised slice to %s/%05d.tiff" % (out_path, args.slice_number))