#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Noise2Inverse model training using distributed data parallel (DDP).
"""
import os
import time
import shutil
import yaml
import datetime
import numpy as np
import torch
from copy import deepcopy
from matplotlib import pyplot as plt
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from denoise.model import unet_ns_gn
from denoise.model3d import unet3d
from denoise.loss import LCL
from denoise.data import TomoDatasetTrain
from denoise.data3d import TomoDataset3DTrain
from denoise.utils import save2img
from denoise.eval import laplacian_score_batch
from denoise import log
[docs]
def count_parameters(model):
    """Return how many scalar weights of *model* are trainable.

    Only parameters with ``requires_grad`` set are counted; each contributes
    its element count (``numel``).
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
[docs]
def _prepare_output_dir(odir, resume):
    """Create the training output tree ``<odir>/results`` (called on rank 0 only).

    With *resume* the existing directory is kept and only ``results`` is
    created if missing; otherwise any existing *odir* is deleted and recreated.
    """
    if resume:
        os.makedirs(f'{odir}/results', exist_ok=True)
    else:
        if os.path.isdir(odir):
            shutil.rmtree(odir)
        os.makedirs(odir)
        os.mkdir(f'{odir}/results')


def _save_checkpoint(path, model, optimizer, **extra):
    """Save the unwrapped model and optimizer state, plus any *extra* keys, to *path*.

    ``model.module`` (the model inside the DDP wrapper) is stored so the
    checkpoint can be loaded into a plain model. ``torch.save`` serializes the
    tensors immediately, so no defensive copy of the state dicts is needed.
    """
    state = {
        'model_state_dict': model.module.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    state.update(extra)
    torch.save(state, path)


def _save_preview_images(odir, epoch, mode, center_idx, pred_view1, pred_view2, view1, view2):
    """Write PNG snapshots of one randomly chosen sample from the given batch.

    In 3D mode the centre depth slice of the ``[B, 1, D, H, W]`` tensors is
    shown. In 2.5D mode the last output channel of the predictions and slice
    *center_idx* of the inputs are shown.
    """
    ridx = np.random.randint(pred_view1.shape[0])
    if mode == '3d':
        D = pred_view1.shape[2]
        images = (pred_view1[ridx, 0, D // 2], pred_view2[ridx, 0, D // 2],
                  view1[ridx, 0, D // 2], view2[ridx, 0, D // 2])
    else:
        images = (pred_view1[ridx, -1], pred_view2[ridx, -1],
                  view1[ridx, center_idx], view2[ridx, center_idx])
    for img, suffix in zip(images, ('pred_view1', 'pred_view2', 'view1', 'view2')):
        save2img(img.detach().cpu().numpy(), '%s/results/_%d_%s.png' % (odir, epoch, suffix))


def _save_training_plots(odir, train_loss, val_loss, train_lcl_loss, val_lcl_loss, edge_values):
    """Plot the per-epoch L1, LCL and edge-score histories into ``<odir>/results``."""
    for train_curve, val_curve, fname in (
            (train_loss, val_loss, '__model_training.png'),
            (train_lcl_loss, val_lcl_loss, '__model_lcl_training.png')):
        plt.figure(figsize=(12, 8))
        plt.title("Training Progress")
        plt.plot(train_curve, label="Training Loss")
        plt.plot(val_curve, label="Validation Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(f'{odir}/results/{fname}')
        plt.close()
    plt.figure(figsize=(12, 8))
    plt.title("Training Progress")
    plt.plot(edge_values)
    plt.xlabel("Epoch")
    plt.ylabel("EDGE Gradient")
    plt.savefig(f'{odir}/results/__edge_training.png')
    plt.close()


def run(args):
    """Train a Noise2Inverse denoiser with DistributedDataParallel.

    The process is expected to be started by a launcher that sets the
    ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` environment variables. It joins
    the process group (NCCL when available, otherwise Gloo) and trains either:

    * a 2.5D ``unet_ns_gn`` (mode ``'2.5d'``), whose output is compared
      against the centre input slice of the other view; or
    * a 3D ``unet3d`` (mode ``'3d'``), whose output is compared against the
      other view's full patch.

    The training objective always has an L1 cross-view term. After
    ``warmup`` model updates, an LCL term weighted by ``beta`` is added.

    Parameters
    ----------
    args : namespace
        ``args.config`` is the path to the YAML config. These optional
        attributes are read via ``getattr``:

        * ``output_dir`` — defaults to ``<directory_to_reconstructions>/TrainOutput``.
        * ``resume`` — continue from ``<odir>/resume.pth``.
        * ``finetune`` — a checkpoint file, or a directory holding
          ``best_val_model.pth``; only the weights are loaded.
        * ``mode`` — ``'2.5d'`` or ``'3d'``; overrides the YAML value.

    Side effects
    ------------
    * Rank 0 may rewrite ``args.config`` (stores ``train.mode``) and calls
      ``save_normalization_value`` on it.
    * Without ``resume``, the output directory is deleted and recreated.
    * Rank ``world_size - 1`` writes all checkpoints, previews and plots into
      the output directory:
        - ``best_val_model.pth``, ``best_lcl_model.pth``, ``best_edge_model.pth``;
        - ``resume.pth``, overwritten every epoch;
        - preview images and plots, every 5 epochs.

    Raises
    ------
    RuntimeError
        If ``resume`` is set but the output directory or ``resume.pth`` is
        missing, or if the ``finetune`` checkpoint does not exist.
    """
    # read the YAML file
    with open(args.config, 'r') as file:
        params = yaml.safe_load(file)
    START_TIME = time.time()

    # setup distributed training using PyTorch's DDP framework
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Use a 3-hour timeout to accommodate large dataset loading and normalization
    _timeout = datetime.timedelta(hours=3)
    backend = "nccl" if torch.distributed.is_nccl_available() else "gloo"
    torch.distributed.init_process_group(backend, rank=rank, world_size=world_size, timeout=_timeout)
    # Bind this process to its GPU before the first collective (the barrier
    # below). Otherwise NCCL has to guess the device from the global rank.
    torch.cuda.set_device(local_rank)

    # create directory containing training results in the directory of the reconstructions
    path_to_reconstructions = params['dataset']['directory_to_reconstructions']
    odir = getattr(args, 'output_dir', None) or (path_to_reconstructions + '/TrainOutput')
    resume = getattr(args, 'resume', False)
    # Every rank performs this read-only check. If only rank 0 raised, the
    # other ranks would block in the barrier until the 3-hour timeout expired.
    if resume and not os.path.isdir(odir):
        raise RuntimeError("--resume specified but TrainOutput not found: %s" % odir)
    if rank == 0:
        _prepare_output_dir(odir, resume)
    torch.distributed.barrier()
    log.info("local rank %d (global rank %d) of a world size %d started" % (local_rank, rank, world_size))

    # Determine mode: CLI flag > YAML > default 2.5d
    mode = getattr(args, 'mode', None) or params['train'].get('mode', '2.5d')
    log.info("Training mode: %s" % mode)
    # Save mode into config so inference commands can read it (rank 0 only)
    if rank == 0 and params['train'].get('mode') != mode:
        with open(args.config, 'r') as f:
            cfg = yaml.safe_load(f)
        cfg['train']['mode'] = mode
        with open(args.config, 'w') as f:
            yaml.safe_dump(cfg, f, default_flow_style=False, sort_keys=False)
    torch.distributed.barrier()

    log.info("Loading data into CPU memory, it will take a while ...")
    if mode == '3d':
        ds_train = TomoDataset3DTrain(params=params, config_file=args.config)
    else:
        ds_train = TomoDatasetTrain(params=params, config_file=args.config)
    # Only rank 0 writes normalization stats
    if rank == 0:
        from denoise.data import save_normalization_value
        log.info("Saving training mean and std to config for inference")
        save_normalization_value(config_file=args.config, mean=ds_train.split0_mean, std=ds_train.split0_std)
    torch.distributed.barrier()
    train_sampler = DistributedSampler(dataset=ds_train, shuffle=True, drop_last=True)
    # 3D datasets can be hundreds of GB. Forked DataLoader workers trigger
    # copy-on-write on every random patch access, which balloons memory.
    # num_workers=0 avoids forking: the main process reads patches directly.
    _nw = 0 if mode == '3d' else 4
    _pf = None if _nw == 0 else 2
    dl_train = DataLoader(dataset=ds_train, batch_size=params['train']['mbsz'], sampler=train_sampler,
                          num_workers=_nw, drop_last=False, prefetch_factor=_pf, pin_memory=True)
    log.info("Loaded %d samples into CPU memory for training." % len(ds_train))

    # Initialize model
    n_slices = params['train']['n_slices']
    if mode == '3d':
        psz_3d = int(params['train'].get('psz_3d', params['train'].get('psz', 64)))
        n_blocks = int(params['train'].get('n_blocks_3d', 3))
        start_filts = int(params['train'].get('start_filts_3d', 32))
        model = unet3d(in_channels=1, out_channels=1, n_blocks=n_blocks,
                       start_filts=start_filts).cuda()
    else:
        model = unet_ns_gn(ich=n_slices, start_filter_size=16, channels_per_group=8).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=params['train']['lr'])
    model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
    log.info("Number of model parameters: %s" % f"{count_parameters(model):,}")

    # loss functions and warmup criteria
    criterion = torch.nn.L1Loss()
    criterion_lcl = LCL()
    beta = .01
    z_stride = params['train'].get('z_stride', 1)
    warmup = params['train']['warmup'] // max(1, z_stride)

    # training state (overwritten on --resume)
    model_updates = 0
    start_epoch = 1
    continue_warmup = True
    train_loss, val_loss = [], []
    edge_values = []
    train_lcl_loss, val_lcl_loss = [], []
    best_val_loss, best_edge, best_lcl_loss = np.inf, 0, np.inf
    best_val_epoch, best_edge_epoch, best_lcl_epoch = 0, 0, 0
    epochs_since_improvement = 0
    patience = params['train'].get('patience', 0)
    center_idx = n_slices // 2  # used only in 2.5d mode

    if resume:
        resume_path = f"{odir}/resume.pth"
        if not os.path.exists(resume_path):
            raise RuntimeError("--resume specified but no resume checkpoint found at: %s" % resume_path)
        # weights_only=False unpickles arbitrary objects: only load checkpoints
        # this script wrote itself (trusted files).
        ckpt = torch.load(resume_path, map_location='cpu', weights_only=False)
        model.module.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        start_epoch = ckpt['epoch'] + 1
        model_updates = ckpt['model_updates']
        best_val_loss = ckpt['best_val_loss']
        best_lcl_loss = ckpt['best_lcl_loss']
        best_edge = ckpt['best_edge']
        best_val_epoch = ckpt['best_val_epoch']
        best_lcl_epoch = ckpt['best_lcl_epoch']
        best_edge_epoch = ckpt['best_edge_epoch']
        train_loss = ckpt['train_loss']
        val_loss = ckpt['val_loss']
        train_lcl_loss = ckpt['train_lcl_loss']
        val_lcl_loss = ckpt['val_lcl_loss']
        edge_values = ckpt['edge_values']
        continue_warmup = ckpt['continue_warmup']
        epochs_since_improvement = ckpt.get('epochs_since_improvement', 0)
        log.info("Resuming training from epoch %d (model_updates=%d)" % (start_epoch, model_updates))
    elif getattr(args, 'finetune', None):
        finetune_path = args.finetune
        if os.path.isdir(finetune_path):
            finetune_path = os.path.join(finetune_path, 'best_val_model.pth')
        if not os.path.exists(finetune_path):
            raise RuntimeError("--finetune: checkpoint not found: %s" % finetune_path)
        ckpt = torch.load(finetune_path, map_location='cpu', weights_only=False)
        model.module.load_state_dict(ckpt['model_state_dict'])
        log.info("Fine-tuning from: %s (training state reset from scratch)" % finetune_path)
    else:
        log.info('Initializing model from scratch')

    # start training
    for epoch in range(start_epoch, params['train']['maxep'] + 1):
        step_losses, step_val_losses, step_lcl_loss, step_lcl_val_loss, step_edge_values = [], [], [], [], []
        tick_ep = time.time()
        model.train()
        dl_train.sampler.set_epoch(epoch)

        # ---- training loop ----
        for X_mb, Y_mb in dl_train:
            X_mb_dev = X_mb.cuda()
            Y_mb_dev = Y_mb.cuda()
            optimizer.zero_grad(set_to_none=True)
            pred_view1 = model(X_mb_dev)
            pred_view2 = model(Y_mb_dev)
            if mode == '3d':
                # 3D: inputs/outputs are [B, 1, D, H, W]; target is the full patch
                loss_view1 = criterion(pred_view1, Y_mb_dev)
                loss_view2 = criterion(pred_view2, X_mb_dev)
            else:
                # 2.5D: squeeze channel dim, compare against center slice
                loss_view1 = criterion(pred_view1.squeeze(dim=1), Y_mb_dev[:, center_idx])
                loss_view2 = criterion(pred_view2.squeeze(dim=1), X_mb_dev[:, center_idx])
            if model_updates <= warmup:
                # Warmup: optimize only the cross-view L1 term; LCL is logged as 0.
                loss_lcl1 = torch.tensor(0.)
                loss_lcl2 = torch.tensor(0.)
                loss = 0.5 * (loss_view1 + loss_view2)
            else:
                if mode == '3d':
                    # Apply LCL on center slice of 3D output
                    D = pred_view1.shape[2]
                    loss_lcl1 = criterion_lcl(pred_view1[:, :, D // 2]) * beta
                    loss_lcl2 = criterion_lcl(pred_view2[:, :, D // 2]) * beta
                else:
                    loss_lcl1 = criterion_lcl(pred_view1) * beta
                    loss_lcl2 = criterion_lcl(pred_view2) * beta
                loss = 0.5 * (loss_view1 + loss_view2) + 0.5 * (loss_lcl1 + loss_lcl2)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            step_losses.append((loss_view1 + loss_view2).detach().cpu().numpy())
            step_lcl_loss.append((loss_lcl1 + loss_lcl2).detach().cpu().numpy())
            model_updates += 1

        model.eval()
        with torch.no_grad():
            # ---- validation loop ----
            # NOTE(review): this re-iterates dl_train (this rank's shard of the
            # training set) in eval mode; there is no held-out split.
            for X_mb, Y_mb in dl_train:
                X_mb_dev = X_mb.cuda()
                Y_mb_dev = Y_mb.cuda()
                pred_view1 = model(X_mb_dev)
                pred_view2 = model(Y_mb_dev)
                if mode == '3d':
                    loss_view1 = criterion(pred_view1, Y_mb_dev)
                    loss_view2 = criterion(pred_view2, X_mb_dev)
                    D = pred_view1.shape[2]
                    loss_lcl1 = criterion_lcl(pred_view1[:, :, D // 2]) * beta
                    loss_lcl2 = criterion_lcl(pred_view2[:, :, D // 2]) * beta
                    # Edge score on center slice
                    lap_score = (laplacian_score_batch(pred_view1[:, :, D // 2].cpu()) +
                                 laplacian_score_batch(pred_view2[:, :, D // 2].cpu()))
                else:
                    loss_view1 = criterion(pred_view1.squeeze(dim=1), Y_mb_dev[:, center_idx])
                    loss_view2 = criterion(pred_view2.squeeze(dim=1), X_mb_dev[:, center_idx])
                    loss_lcl1 = criterion_lcl(pred_view1) * beta
                    loss_lcl2 = criterion_lcl(pred_view2) * beta
                    lap_score = laplacian_score_batch(pred_view1.cpu()) + laplacian_score_batch(pred_view2.cpu())
                loss = loss_view1 + loss_view2
                step_edge_values.append(lap_score)
                step_val_losses.append(loss.detach().cpu().numpy())
                loss_lcl = loss_lcl1 + loss_lcl2
                step_lcl_val_loss.append(loss_lcl.cpu().numpy())
        ep_time = time.time() - tick_ep

        # Early stopping: rank world_size-1 decides, broadcasts to all ranks.
        # NOTE(review): the decision uses only that rank's local validation losses.
        stop_tensor = torch.zeros(1, dtype=torch.int32, device='cuda')
        if rank == world_size - 1:
            if patience > 0 and not continue_warmup:
                if np.mean(step_val_losses) < best_val_loss:
                    epochs_since_improvement = 0
                else:
                    epochs_since_improvement += 1
                if epochs_since_improvement >= patience:
                    stop_tensor.fill_(1)
        torch.distributed.broadcast(stop_tensor, src=world_size - 1)
        early_stop = bool(stop_tensor.item())
        # Only rank world_size-1 logs, checkpoints and plots; the other ranks
        # proceed straight to the next epoch (or stop).
        if rank != world_size - 1:
            if early_stop:
                break
            continue

        mean_train = np.mean(step_losses)
        mean_train_lcl = np.mean(step_lcl_loss)
        mean_val = np.mean(step_val_losses)
        mean_val_lcl = np.mean(step_lcl_val_loss)
        mean_edge = np.mean(step_edge_values)

        log.info('Epoch %d' % epoch)
        log.info('[Train] L1 loss: %.6f, %.6f => %.6f, rate: %.2fs/ep' % (
            mean_train, step_losses[0], step_losses[-1], ep_time))
        log.info('[Train] LCL loss: %.6f, %.6f => %.6f, rate: %.2fs/ep' % (
            mean_train_lcl, step_lcl_loss[0], step_lcl_loss[-1], ep_time))
        log.info('[Val] L1 loss: %.6f, %.6f => %.6f, rate: %.2fs/ep' % (
            mean_val, step_val_losses[0], step_val_losses[-1], ep_time))
        log.info('[Val] LCL loss: %.6f, %.6f => %.6f, rate: %.2fs/ep' % (
            mean_val_lcl, step_lcl_val_loss[0], step_lcl_val_loss[-1], ep_time))
        log.info('[Val] EDGE Value: %.4f, %.4f => %.4f, rate: %.2fs/ep' % (
            mean_edge, step_edge_values[0], step_edge_values[-1], ep_time))
        train_loss.append(mean_train)
        val_loss.append(mean_val)
        train_lcl_loss.append(mean_train_lcl)
        val_lcl_loss.append(mean_val_lcl)
        edge_values.append(mean_edge)

        # Save the best model with the lowest lcl loss
        if mean_val_lcl < best_lcl_loss:
            best_lcl_loss = mean_val_lcl
            best_lcl_epoch = epoch
            _save_checkpoint(f"{odir}/best_lcl_model.pth", model, optimizer)
        # Save the best model with the lowest val loss
        if mean_val < best_val_loss:
            best_val_loss = mean_val
            best_val_epoch = epoch
            _save_checkpoint(f"{odir}/best_val_model.pth", model, optimizer)
        # Save the best model with the highest edge value
        if mean_edge > best_edge:
            best_edge = mean_edge
            best_edge_epoch = epoch
            _save_checkpoint(f"{odir}/best_edge_model.pth", model, optimizer)

        # When warmup ends, the LCL/edge "best" trackers restart, because the
        # objective now includes the LCL term. Early-stopping patience restarts too.
        if model_updates > warmup and continue_warmup:
            best_edge, best_lcl_loss = 0, np.inf
            best_edge_epoch, best_lcl_epoch = 0, 0
            epochs_since_improvement = 0
            continue_warmup = False
        CRNT_TIME = time.time()
        log.info("[Info] Training Time: %.2f seconds" % (CRNT_TIME - START_TIME))

        # option to view the denoising process during training (last validation batch)
        if epoch % 5 == 0:
            _save_preview_images(odir, epoch, mode, center_idx,
                                 pred_view1, pred_view2, X_mb_dev, Y_mb_dev)

        # Keep track of when/where the best model is
        log.info('Lowest model validation loss %.6f at epoch %d' % (best_val_loss, best_val_epoch))
        log.info('Lowest model LCL loss %.6f at epoch %d' % (best_lcl_loss, best_lcl_epoch))
        log.info('Highest model EDGE score %.6f at epoch %d' % (best_edge, best_edge_epoch))
        log.info('Number of model updates: %s' % f"{model_updates:,}")
        log.info('Is model warming up?: %s' % continue_warmup)

        # View the training/validation loss during training
        if epoch % 5 == 0:
            _save_training_plots(odir, train_loss, val_loss, train_lcl_loss, val_lcl_loss, edge_values)

        # Save resume checkpoint (overwritten each epoch; enables --resume after interruption)
        _save_checkpoint(
            f"{odir}/resume.pth", model, optimizer,
            epoch=epoch,
            model_updates=model_updates,
            best_val_loss=best_val_loss,
            best_lcl_loss=best_lcl_loss,
            best_edge=best_edge,
            best_val_epoch=best_val_epoch,
            best_lcl_epoch=best_lcl_epoch,
            best_edge_epoch=best_edge_epoch,
            train_loss=train_loss,
            val_loss=val_loss,
            train_lcl_loss=train_lcl_loss,
            val_lcl_loss=val_lcl_loss,
            edge_values=edge_values,
            continue_warmup=continue_warmup,
            epochs_since_improvement=epochs_since_improvement,
        )
        if early_stop:
            log.info("Early stopping: no val loss improvement for %d epochs (best was epoch %d)" % (patience, best_val_epoch))
            break