# Source code for denoise.volume

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Denoise the entire CT volume using a trained Noise2Inverse model.
"""

import os
import time
import shutil
import yaml
import numpy as np
import torch
import warnings

warnings.filterwarnings("ignore")

from pathlib import Path
from tqdm import tqdm
from torch.utils.data import DataLoader

from denoise.model import unet_ns_gn
from denoise.model3d import unet3d
from denoise.data import TomoDatasetInfer
from denoise.data3d import TomoDataset3DInfer, _hann3d
from denoise.data_utils import InferenceBatchSizeOptimizer
from denoise import tiffs as tiffs_mod
from denoise import log


# Checkpoint file names selectable through ``args.checkpoint``.
_CKPT_MAP = {'val': 'best_val_model.pth', 'lcl': 'best_lcl_model.pth', 'edge': 'best_edge_model.pth'}


def _build_model(params, mode, path_to_mdl, dev):
    """Instantiate the network for *mode*, load weights from *path_to_mdl*, move it to *dev* in eval mode."""
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects; only point this at checkpoints produced by a trusted training run.
    ckpt = torch.load(path_to_mdl, map_location='cpu', weights_only=False)
    if mode == '3d':
        n_blocks = int(params['train'].get('n_blocks_3d', 3))
        start_filts = int(params['train'].get('start_filts_3d', 32))
        model = unet3d(in_channels=1, out_channels=1, n_blocks=n_blocks, start_filts=start_filts)
    else:
        model = unet_ns_gn(ich=params['train']['n_slices'], start_filter_size=16, channels_per_group=8)
    model.load_state_dict(ckpt['model_state_dict'])
    model.to(dev).eval()
    return model


def _prepare_output_dir(output_dir):
    """Create an empty *output_dir*, deleting any previous contents first."""
    if os.path.isdir(output_dir):
        shutil.rmtree(output_dir)
    os.mkdir(output_dir)


def _infer_3d(model, params, args, dev, window_type):
    """Denoise with overlapping 3D patches and stitch them on the fly.

    Returns the stitched volume cropped to the dataset's unpadded
    ``(D, H, W)`` extent (float32, since the accumulators are float32).
    """
    ds_test = TomoDataset3DInfer(params=params, start_slice=args.start_slice, end_slice=args.end_slice)
    psz = ds_test.psz
    patch_shape = (psz, psz, psz)
    stats = InferenceBatchSizeOptimizer(model=model, input_shape=patch_shape, device=dev,
                                        max_batch_size=64, precision='fp32', n_channels=1).profile()
    dl_test = DataLoader(dataset=ds_test, batch_size=stats['optimal_batch_size'], shuffle=False,
                         num_workers=2, drop_last=False, prefetch_factor=2,
                         # Pinning only helps host->GPU copies; on CPU it is ignored with a warning.
                         pin_memory=(dev.type == 'cuda'))

    # Online stitching — accumulate directly into acc/wacc to avoid
    # allocating a (N_patches, psz, psz, psz) array that can exceed RAM.
    if window_type in ('hann', 'cosine'):
        w3d = _hann3d(psz, psz, psz)
    else:
        w3d = np.ones(patch_shape, dtype=np.float32)
    acc = np.zeros((ds_test.D_pad, ds_test.H_pad, ds_test.W_pad), dtype=np.float32)
    wacc = np.zeros_like(acc)

    insert_cnt = 0
    log.info("Processing %d 3D patches ..." % len(ds_test))
    with torch.no_grad():
        for X in tqdm(dl_test):
            out = model(X.to(dev)).cpu().squeeze(1).numpy()  # [B, psz, psz, psz]
            for b, patch in enumerate(out):
                # ds_test.index is consumed in loader order (shuffle=False).
                d, h, w = ds_test.index[insert_cnt + b]
                region = (slice(d, d + psz), slice(h, h + psz), slice(w, w + psz))
                acc[region] += patch * w3d
                wacc[region] += w3d
            insert_cnt += out.shape[0]

    log.info("Stitching 3D denoised volume ...")
    # The epsilon guards against division by zero where no window weight landed.
    return (acc / (wacc + 1e-6))[:ds_test.D, :ds_test.H, :ds_test.W].copy()


def _infer_25d(model, params, args, dev, window_type):
    """Denoise with 2.5D patches (collected in memory) and stitch them via the dataset."""
    ds_test = TomoDatasetInfer(params=params, start_slice=args.start_slice, end_slice=args.end_slice)
    log.info("Loaded %d slices of size %dx%d" % (ds_test.vol.shape[0], ds_test.vol.shape[1], ds_test.vol.shape[2]))
    psz = params['train']['psz']
    stats = InferenceBatchSizeOptimizer(model=model, input_shape=(psz, psz), device=dev,
                                        max_batch_size=512, precision='fp32').profile()
    dl_test = DataLoader(dataset=ds_test, batch_size=stats['optimal_batch_size'], shuffle=False,
                         num_workers=4, drop_last=False, prefetch_factor=6,
                         pin_memory=(dev.type == 'cuda'))

    preds = np.zeros((ds_test.total_patches, psz, psz))
    log.info("Patch volume size: %dx%dx%d" % (preds.shape[0], preds.shape[1], preds.shape[2]))

    insert_cnt = 0
    log.info("Processing data ...")
    with torch.no_grad():
        for X, _ in tqdm(dl_test):
            output = model(X.to(dev)).cpu().squeeze(dim=1).numpy()
            preds[insert_cnt:insert_cnt + X.shape[0]] = output
            insert_cnt += X.shape[0]

    log.info("Stitching denoised data ...")
    return ds_test.stitch_predictions(preds, window=window_type, keep_k_dim=False)


def run(args):
    """Denoise an entire reconstructed CT volume with a trained Noise2Inverse model.

    Reads the YAML config at ``args.config``, loads the selected checkpoint,
    runs patch-wise inference in ``'3d'`` or 2.5D mode, rescales the result
    with ``dataset.std4norm`` / ``dataset.mean4norm`` and writes it as a stack
    via ``tiffs_mod.save_stack`` into
    ``<directory_to_reconstructions>/<base>_denoised_volume_<mode>``.
    An existing directory at that path is deleted and recreated, but only
    after the model has been loaded successfully.

    Args:
        args: namespace providing ``config``, ``start_slice`` and ``end_slice``.
            It may also provide ``mode`` (overrides ``train.mode``),
            ``checkpoint`` (one of ``'val'``, ``'lcl'``, ``'edge'``; missing or
            ``None`` means ``'lcl'``) and ``model_dir`` (defaults to
            ``<directory_to_reconstructions>/TrainOutput``).

    Raises:
        KeyError: if ``args.checkpoint`` names an unknown checkpoint.
    """
    with open(args.config, 'r') as file:
        params = yaml.safe_load(file)

    # Precedence: CLI flag > YAML config > default '2.5d'.
    mode = getattr(args, 'mode', None) or params['train'].get('mode', '2.5d')
    log.info("Inference mode: %s" % mode)
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    recon_root = params['dataset']['directory_to_reconstructions']

    # Same "value or fallback" rule as mode/model_dir, so an explicit None
    # (e.g. an argparse default) selects 'lcl' instead of raising KeyError(None).
    ckpt_key = getattr(args, 'checkpoint', None) or 'lcl'
    if ckpt_key not in _CKPT_MAP:
        raise KeyError("Unknown checkpoint %r; expected one of %s" % (ckpt_key, sorted(_CKPT_MAP)))
    ckpt_name = _CKPT_MAP[ckpt_key]
    model_dir = getattr(args, 'model_dir', None) or os.path.join(recon_root, 'TrainOutput')
    log.info("Using checkpoint: %s" % ckpt_name)
    model = _build_model(params, mode, os.path.join(model_dir, ckpt_name), dev)

    # Prepare the output directory only once the model is known to load, so a
    # bad checkpoint choice does not wipe a previous run's results.
    full_recon_name = params['dataset']['full_recon_name']
    base_name = full_recon_name[:-4] if full_recon_name.endswith('_rec') else full_recon_name
    output_dir = os.path.join(recon_root, base_name + '_denoised_volume_' + mode)
    _prepare_output_dir(output_dir)

    # One window lookup for both modes (previously 2.5D required the key).
    window_type = (params.get('infer') or {}).get('window', 'hann')

    log.info("Loading data into CPU memory, it will take a while ...")
    # Running inference in helpers lets the datasets and accumulators be freed
    # before the rescale/save steps below, lowering peak memory.
    if mode == '3d':
        preds = _infer_3d(model, params, args, dev, window_type)
    else:
        preds = _infer_25d(model, params, args, dev, window_type)

    # Rescale back to the original intensity range.
    preds = preds * params['dataset']['std4norm'] + params['dataset']['mean4norm']

    log.info("Saving data to %s ..." % output_dir)
    # start_slice presumably arrives as a CLI string; empty (or None) means no offset.
    if not args.start_slice:
        tiffs_mod.save_stack(output_dir, preds)
    else:
        tiffs_mod.save_stack(output_dir, preds, offset=int(args.start_slice))
    log.info("Done.")