"""Trainer for the STTN video-inpainting GAN (generator + discriminator)."""
import os
import glob
import importlib

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from tensorboardX import SummaryWriter

from backend.tools.train.dataset_sttn import Dataset
from backend.tools.train.loss_sttn import AdversarialLoss


class Trainer:
    """Drives adversarial training of an inpainting generator/discriminator pair.

    Expects ``config`` with keys: ``data_loader``, ``trainer`` (``batch_size``,
    ``num_workers``, ``lr``, ``beta1``, ``beta2``, ``iterations``, ``niter``,
    ``niter_steady``, ``save_freq``), ``losses`` (``GAN_LOSS``,
    ``adversarial_weight``, ``hole_weight``, ``valid_weight``), ``model``,
    ``device``, ``save_dir``, ``distributed`` and — when distributed —
    ``world_size`` / ``global_rank`` / ``local_rank``.
    """

    def __init__(self, config, debug=False):
        self.config = config
        self.epoch = 0
        self.iteration = 0
        if debug:
            # Shrink all schedules so a debug run finishes quickly.
            self.config['trainer']['save_freq'] = 5
            self.config['trainer']['valid_freq'] = 5
            self.config['trainer']['iterations'] = 5

        # setup data set and data loader
        self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug)
        self.train_sampler = None
        self.train_args = config['trainer']
        if config['distributed']:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=config['world_size'],
                rank=config['global_rank'])
        self.train_loader = DataLoader(
            self.train_dataset,
            # Global batch size is split evenly across distributed workers.
            batch_size=self.train_args['batch_size'] // config['world_size'],
            shuffle=(self.train_sampler is None),
            num_workers=self.train_args['num_workers'],
            sampler=self.train_sampler)

        # set loss functions
        self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])
        self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
        self.l1_loss = nn.L1Loss()

        # setup models including generator and discriminator
        net = importlib.import_module('model.' + config['model'])
        self.netG = net.InpaintGenerator()
        self.netG = self.netG.to(self.config['device'])
        self.netD = net.Discriminator(
            in_channels=3,
            # hinge loss supplies its own activation, so skip the sigmoid head.
            use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
        self.netD = self.netD.to(self.config['device'])
        self.optimG = torch.optim.Adam(
            self.netG.parameters(),
            lr=config['trainer']['lr'],
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
        self.optimD = torch.optim.Adam(
            self.netD.parameters(),
            lr=config['trainer']['lr'],
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
        # Resume BEFORE wrapping in DDP so checkpoints load without the
        # 'module.' key prefix.
        self.load()

        if config['distributed']:
            self.netG = DDP(
                self.netG,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False)
            self.netD = DDP(
                self.netD,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False)

        # set summary writer; only rank 0 emits tensorboard logs
        self.dis_writer = None
        self.gen_writer = None
        self.summary = {}
        if self.config['global_rank'] == 0 or (not config['distributed']):
            self.dis_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'dis'))
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen'))

    # get current learning rate
    def get_lr(self):
        """Return the generator optimizer's current learning rate."""
        return self.optimG.param_groups[0]['lr']

    # learning rate scheduler, step
    def adjust_learning_rate(self):
        """Step-decay both optimizers' LR by 10x every ``niter`` iterations,
        frozen once ``niter_steady`` iterations are reached."""
        decay = 0.1 ** (min(self.iteration, self.config['trainer']['niter_steady'])
                        // self.config['trainer']['niter'])
        new_lr = self.config['trainer']['lr'] * decay
        if new_lr != self.get_lr():
            for param_group in self.optimG.param_groups:
                param_group['lr'] = new_lr
            for param_group in self.optimD.param_groups:
                param_group['lr'] = new_lr

    # add summary
    def add_summary(self, writer, name, val):
        """Accumulate ``val`` under ``name``; every 100 iterations flush the
        running mean to tensorboard and reset the accumulator."""
        if name not in self.summary:
            self.summary[name] = 0
        self.summary[name] += val
        if writer is not None and self.iteration % 100 == 0:
            writer.add_scalar(name, self.summary[name] / 100, self.iteration)
            self.summary[name] = 0

    # load netG and netD
    def load(self):
        """Resume netG, netD and both optimizers from the newest checkpoint in
        ``save_dir``; no-op (with a warning) when none exists."""
        model_path = self.config['save_dir']
        latest_path = os.path.join(model_path, 'latest.ckpt')
        if os.path.isfile(latest_path):
            # latest.ckpt stores the zero-padded checkpoint tag on its last
            # line. Use a context manager so the handle is always closed.
            with open(latest_path, 'r') as f:
                latest_epoch = f.read().splitlines()[-1]
        else:
            # BUGFIX: the old code globbed '*.pth' and kept the 'gen_'/'dis_'/
            # 'opt_' prefix in the tag, producing paths like 'gen_opt_00005.pth'
            # below. Scan only generator checkpoints and strip prefix/suffix so
            # the bare tag can be reused for all three file names.
            ckpts = [os.path.basename(p)[len('gen_'):-len('.pth')]
                     for p in glob.glob(os.path.join(model_path, 'gen_*.pth'))]
            ckpts.sort()
            latest_epoch = ckpts[-1] if len(ckpts) > 0 else None
        if latest_epoch is not None:
            tag = str(latest_epoch).zfill(5)
            gen_path = os.path.join(model_path, 'gen_{}.pth'.format(tag))
            dis_path = os.path.join(model_path, 'dis_{}.pth'.format(tag))
            opt_path = os.path.join(model_path, 'opt_{}.pth'.format(tag))
            if self.config['global_rank'] == 0:
                print('Loading model from {}...'.format(gen_path))
            data = torch.load(gen_path, map_location=self.config['device'])
            self.netG.load_state_dict(data['netG'])
            data = torch.load(dis_path, map_location=self.config['device'])
            self.netD.load_state_dict(data['netD'])
            data = torch.load(opt_path, map_location=self.config['device'])
            self.optimG.load_state_dict(data['optimG'])
            self.optimD.load_state_dict(data['optimD'])
            self.epoch = data['epoch']
            self.iteration = data['iteration']
        else:
            if self.config['global_rank'] == 0:
                print('Warning: There is no trained model found. '
                      'An initialized model will be used.')

    # save parameters every eval_epoch
    def save(self, it):
        """Write gen/dis/opt checkpoints tagged ``it`` (rank 0 only) and record
        the tag in latest.ckpt."""
        if self.config['global_rank'] == 0:
            tag = str(it).zfill(5)
            gen_path = os.path.join(
                self.config['save_dir'], 'gen_{}.pth'.format(tag))
            dis_path = os.path.join(
                self.config['save_dir'], 'dis_{}.pth'.format(tag))
            opt_path = os.path.join(
                self.config['save_dir'], 'opt_{}.pth'.format(tag))
            print('\nsaving model to {} ...'.format(gen_path))
            # Unwrap (D)DP so the saved state dict has no 'module.' prefix.
            if isinstance(self.netG, (torch.nn.DataParallel, DDP)):
                netG = self.netG.module
                netD = self.netD.module
            else:
                netG = self.netG
                netD = self.netD
            torch.save({'netG': netG.state_dict()}, gen_path)
            torch.save({'netD': netD.state_dict()}, dis_path)
            torch.save({'epoch': self.epoch,
                        'iteration': self.iteration,
                        'optimG': self.optimG.state_dict(),
                        'optimD': self.optimD.state_dict()}, opt_path)
            # Record the newest tag with a plain file write (portable; the old
            # `os.system('echo ...')` depended on a shell and ignored failures).
            with open(os.path.join(self.config['save_dir'], 'latest.ckpt'), 'w') as f:
                f.write(tag + '\n')

    # train entry
    def train(self):
        """Run epochs until the configured iteration budget is exhausted."""
        pbar = range(int(self.train_args['iterations']))
        if self.config['global_rank'] == 0:
            pbar = tqdm(pbar, initial=self.iteration,
                        dynamic_ncols=True, smoothing=0.01)
        while True:
            self.epoch += 1
            if self.config['distributed']:
                # Reshuffle shards each epoch so workers see fresh orderings.
                self.train_sampler.set_epoch(self.epoch)
            self._train_epoch(pbar)
            if self.iteration > self.train_args['iterations']:
                break
        print('\nEnd training....')

    # process input and calculate loss every training epoch
    def _train_epoch(self, pbar):
        """One pass over the loader: alternate D and G updates per batch."""
        device = self.config['device']
        for frames, masks in self.train_loader:
            self.adjust_learning_rate()
            self.iteration += 1
            frames, masks = frames.to(device), masks.to(device)
            b, t, c, h, w = frames.size()
            # Zero out the hole regions before feeding the generator.
            masked_frame = (frames * (1 - masks).float())
            pred_img = self.netG(masked_frame, masks)
            # Flatten the time axis so frames are scored per-image.
            frames = frames.view(b * t, c, h, w)
            masks = masks.view(b * t, 1, h, w)
            # Composite: keep known pixels, fill holes with the prediction.
            comp_img = frames * (1. - masks) + masks * pred_img
            gen_loss = 0
            dis_loss = 0

            # discriminator adversarial loss (detach comp_img: D update only)
            real_vid_feat = self.netD(frames)
            fake_vid_feat = self.netD(comp_img.detach())
            dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)
            dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)
            dis_loss += (dis_real_loss + dis_fake_loss) / 2
            self.add_summary(
                self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
            self.add_summary(
                self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
            self.optimD.zero_grad()
            dis_loss.backward()
            self.optimD.step()

            # generator adversarial loss
            gen_vid_feat = self.netD(comp_img)
            gan_loss = self.adversarial_loss(gen_vid_feat, True, False)
            gan_loss = gan_loss * self.config['losses']['adversarial_weight']
            gen_loss += gan_loss
            self.add_summary(
                self.gen_writer, 'loss/gan_loss', gan_loss.item())

            # generator l1 loss, each term normalised by the area it covers
            hole_loss = self.l1_loss(pred_img * masks, frames * masks)
            hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
            gen_loss += hole_loss
            self.add_summary(
                self.gen_writer, 'loss/hole_loss', hole_loss.item())

            valid_loss = self.l1_loss(pred_img * (1 - masks), frames * (1 - masks))
            valid_loss = valid_loss / torch.mean(1 - masks) * self.config['losses']['valid_weight']
            gen_loss += valid_loss
            self.add_summary(
                self.gen_writer, 'loss/valid_loss', valid_loss.item())

            self.optimG.zero_grad()
            gen_loss.backward()
            self.optimG.step()

            # console logs
            if self.config['global_rank'] == 0:
                pbar.update(1)
                pbar.set_description((
                    f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};"
                    f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
                )

            # saving models
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration // self.train_args['save_freq']))
            if self.iteration > self.train_args['iterations']:
                break