mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-21 13:47:38 +08:00
添加sttn训练代码
This commit is contained in:
257
backend/tools/train/trainer_sttn.py
Normal file
257
backend/tools/train/trainer_sttn.py
Normal file
@@ -0,0 +1,257 @@
|
||||
import os
|
||||
import glob
|
||||
from tqdm import tqdm
|
||||
import importlib
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from tensorboardX import SummaryWriter
|
||||
from backend.tools.train.dataset_sttn import Dataset
|
||||
from backend.tools.train.loss_sttn import AdversarialLoss
|
||||
|
||||
|
||||
class Trainer:
    """Adversarial training loop for an inpainting generator/discriminator pair.

    The trainer is driven entirely by a nested ``config`` mapping. Keys read
    by this class:

    * ``config['trainer']``: ``batch_size``, ``num_workers``, ``lr``,
      ``beta1``, ``beta2``, ``niter``, ``niter_steady``, ``iterations``,
      ``save_freq`` (``valid_freq`` is only written in debug mode).
    * ``config['losses']``: ``GAN_LOSS``, ``adversarial_weight``,
      ``hole_weight``, ``valid_weight``.
    * ``config['data_loader']``, ``config['model']``, ``config['device']``,
      ``config['save_dir']``, ``config['distributed']``,
      ``config['world_size']``, ``config['global_rank']`` and
      ``config['local_rank']`` (the latter only when distributed).

    Checkpoints are stored in ``config['save_dir']`` as ``gen_<id>.pth``,
    ``dis_<id>.pth`` and ``opt_<id>.pth`` (``<id>`` zero-padded to five
    characters); ``latest.ckpt`` holds the id of the most recent one.
    """

    def __init__(self, config, debug=False):
        """Build the data pipeline, losses, networks, optimizers and writers.

        Args:
            config: Nested configuration mapping (see the class docstring).
            debug: When true, ``save_freq``, ``valid_freq`` and ``iterations``
                in ``config['trainer']`` are overwritten with 5 (the mapping
                is modified in place) and the flag is forwarded to the
                dataset.
        """
        self.config = config
        self.epoch = 0
        self.iteration = 0
        if debug:
            self.config['trainer']['save_freq'] = 5
            self.config['trainer']['valid_freq'] = 5
            self.config['trainer']['iterations'] = 5

        # setup data set and data loader
        self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug)
        self.train_sampler = None
        self.train_args = config['trainer']
        if config['distributed']:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=config['world_size'],
                rank=config['global_rank'])
        self.train_loader = DataLoader(
            self.train_dataset,
            # the configured batch size is split evenly across processes
            batch_size=self.train_args['batch_size'] // config['world_size'],
            # DataLoader does not allow shuffle together with a sampler
            shuffle=(self.train_sampler is None),
            num_workers=self.train_args['num_workers'],
            sampler=self.train_sampler)

        # set loss functions
        self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])
        self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
        self.l1_loss = nn.L1Loss()

        # setup models including generator and discriminator
        net = importlib.import_module('model.' + config['model'])
        self.netG = net.InpaintGenerator()
        self.netG = self.netG.to(self.config['device'])
        self.netD = net.Discriminator(
            in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
        self.netD = self.netD.to(self.config['device'])
        self.optimG = torch.optim.Adam(
            self.netG.parameters(),
            lr=config['trainer']['lr'],
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
        self.optimD = torch.optim.Adam(
            self.netD.parameters(),
            lr=config['trainer']['lr'],
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
        # Restore weights before the DDP wrap so the state-dict keys match
        # what save() writes (it unwraps ``.module`` first).
        self.load()

        if config['distributed']:
            self.netG = DDP(
                self.netG,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False)
            self.netD = DDP(
                self.netD,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False)

        # set summary writer
        self.dis_writer = None
        self.gen_writer = None
        self.summary = {}
        if self.config['global_rank'] == 0 or (not config['distributed']):
            self.dis_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'dis'))
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen'))

    # get current learning rate
    def get_lr(self):
        """Return the learning rate of the generator optimizer's first group."""
        return self.optimG.param_groups[0]['lr']

    # learning rate scheduler, step
    def adjust_learning_rate(self):
        """Apply step decay to both optimizers.

        The rate is ``lr * 0.1 ** (min(iteration, niter_steady) // niter)``,
        so it drops by 10x every ``niter`` iterations and stops changing once
        ``iteration`` reaches ``niter_steady``.
        """
        trainer_cfg = self.config['trainer']
        decay = 0.1 ** (min(self.iteration,
                            trainer_cfg['niter_steady']) // trainer_cfg['niter'])
        new_lr = trainer_cfg['lr'] * decay
        if new_lr != self.get_lr():
            for param_group in self.optimG.param_groups:
                param_group['lr'] = new_lr
            for param_group in self.optimD.param_groups:
                param_group['lr'] = new_lr

    # add summary
    def add_summary(self, writer, name, val):
        """Accumulate ``val`` under ``name`` and periodically log the mean.

        Whenever ``self.iteration`` is a multiple of 100 and ``writer`` is not
        ``None``, the accumulated sum divided by 100 is written as a scalar
        and the accumulator is reset.

        Args:
            writer: Object exposing ``add_scalar(name, value, step)``, or
                ``None`` to only accumulate.
            name: Scalar tag.
            val: Numeric value to add to the running sum.
        """
        if name not in self.summary:
            self.summary[name] = 0
        self.summary[name] += val
        if writer is not None and self.iteration % 100 == 0:
            writer.add_scalar(name, self.summary[name] / 100, self.iteration)
            self.summary[name] = 0

    @staticmethod
    def _latest_checkpoint_id(model_path):
        """Return the id of the newest checkpoint in ``model_path`` or ``None``.

        The last non-blank line of ``latest.ckpt`` wins when present. It is
        stripped because files written by a shell ``echo`` may carry trailing
        spaces or ``\\r``. Otherwise the numerically largest ``<id>`` among
        ``gen_<id>.pth`` files is used.
        """
        marker = os.path.join(model_path, 'latest.ckpt')
        if os.path.isfile(marker):
            with open(marker, 'r') as marker_file:
                entries = [line.strip() for line in marker_file if line.strip()]
            if entries:
                return entries[-1]
        # Only generator files are scanned: globbing every '*.pth' would
        # return stems such as 'opt_00005', which are not checkpoint ids.
        prefix, suffix = 'gen_', '.pth'
        ids = []
        for path in glob.glob(os.path.join(model_path, prefix + '*' + suffix)):
            stem = os.path.basename(path)[len(prefix):-len(suffix)]
            if stem.isdigit():
                ids.append(stem)
        return max(ids, key=int) if ids else None

    # load netG and netD
    def load(self):
        """Restore networks, optimizers and counters from the newest checkpoint.

        If no checkpoint is found in ``config['save_dir']`` the freshly
        initialised state is kept and rank 0 prints a warning.
        """
        model_path = self.config['save_dir']
        latest_epoch = self._latest_checkpoint_id(model_path)
        if latest_epoch is None:
            if self.config['global_rank'] == 0:
                print(
                    'Warnning: There is no trained model found. An initialized model will be used.')
            return

        ckpt_id = str(latest_epoch).zfill(5)
        gen_path = os.path.join(model_path, 'gen_{}.pth'.format(ckpt_id))
        dis_path = os.path.join(model_path, 'dis_{}.pth'.format(ckpt_id))
        opt_path = os.path.join(model_path, 'opt_{}.pth'.format(ckpt_id))
        if self.config['global_rank'] == 0:
            print('Loading model from {}...'.format(gen_path))
        data = torch.load(gen_path, map_location=self.config['device'])
        self.netG.load_state_dict(data['netG'])
        data = torch.load(dis_path, map_location=self.config['device'])
        self.netD.load_state_dict(data['netD'])
        data = torch.load(opt_path, map_location=self.config['device'])
        self.optimG.load_state_dict(data['optimG'])
        self.optimD.load_state_dict(data['optimD'])
        self.epoch = data['epoch']
        self.iteration = data['iteration']

    # save parameters every eval_epoch
    def save(self, it):
        """Write a checkpoint with id ``it`` (rank 0 only).

        Creates ``gen_<id>.pth``, ``dis_<id>.pth`` and ``opt_<id>.pth`` in
        ``config['save_dir']`` and records ``<id>`` in ``latest.ckpt``.

        Args:
            it: Checkpoint id; stringified and zero-padded to five characters.
        """
        if self.config['global_rank'] != 0:
            return

        save_dir = self.config['save_dir']
        ckpt_id = str(it).zfill(5)
        os.makedirs(save_dir, exist_ok=True)
        gen_path = os.path.join(save_dir, 'gen_{}.pth'.format(ckpt_id))
        dis_path = os.path.join(save_dir, 'dis_{}.pth'.format(ckpt_id))
        opt_path = os.path.join(save_dir, 'opt_{}.pth'.format(ckpt_id))
        print('\nsaving model to {} ...'.format(gen_path))
        # Store the bare modules so the files load into unwrapped networks.
        if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
            netG = self.netG.module
            netD = self.netD.module
        else:
            netG = self.netG
            netD = self.netD
        torch.save({'netG': netG.state_dict()}, gen_path)
        torch.save({'netD': netD.state_dict()}, dis_path)
        torch.save({'epoch': self.epoch,
                    'iteration': self.iteration,
                    'optimG': self.optimG.state_dict(),
                    'optimD': self.optimD.state_dict()}, opt_path)
        # Written from Python instead of a shell `echo`: no quoting problems
        # with unusual paths and no platform-specific trailing whitespace.
        with open(os.path.join(save_dir, 'latest.ckpt'), 'w') as marker_file:
            marker_file.write('{}\n'.format(ckpt_id))

    # train entry
    def train(self):
        """Run epochs until ``train_args['iterations']`` iterations are done.

        Returns immediately (after printing the end message) when the restored
        iteration counter has already reached the target.

        Raises:
            RuntimeError: If an epoch completes without a single batch, which
                would otherwise make this loop spin forever.
        """
        total_iterations = self.train_args['iterations']
        is_main = self.config['global_rank'] == 0
        pbar = range(int(total_iterations))
        if is_main:
            pbar = tqdm(pbar, initial=self.iteration, dynamic_ncols=True, smoothing=0.01)

        while self.iteration < total_iterations:
            self.epoch += 1
            if self.config['distributed']:
                # reseed the sampler so every epoch uses a new shuffle
                self.train_sampler.set_epoch(self.epoch)

            iteration_before = self.iteration
            self._train_epoch(pbar)
            if self.iteration == iteration_before:
                raise RuntimeError(
                    'The training data loader produced no batches; cannot make progress.')
        if is_main:
            pbar.close()
        print('\nEnd training....')

    # process input and calculate loss every training epoch
    def _train_epoch(self, pbar):
        """Run one pass over the training loader, updating netD then netG.

        Each batch yields ``frames`` (5-D, unpacked as ``b, t, c, h, w``) and
        ``masks`` (reshaped to ``(b * t, 1, h, w)``). The pass ends early once
        ``train_args['iterations']`` iterations have been completed.

        Args:
            pbar: Progress bar; only used (``update`` / ``set_description``)
                on global rank 0.
        """
        device = self.config['device']
        loss_cfg = self.config['losses']

        for frames, masks in self.train_loader:
            self.adjust_learning_rate()
            self.iteration += 1

            frames, masks = frames.to(device), masks.to(device)
            b, t, c, h, w = frames.size()
            masked_frame = (frames * (1 - masks).float())
            pred_img = self.netG(masked_frame, masks)
            # NOTE(review): pred_img is assumed to come back as
            # (b * t, c, h, w) so it lines up with the flattened frames.
            frames = frames.view(b * t, c, h, w)
            masks = masks.view(b * t, 1, h, w)
            # keep original pixels where mask == 0, prediction where mask == 1
            comp_img = frames * (1. - masks) + masks * pred_img

            gen_loss = 0
            dis_loss = 0

            # discriminator adversarial loss
            real_vid_feat = self.netD(frames)
            # detach: the discriminator step must not backprop into netG
            fake_vid_feat = self.netD(comp_img.detach())
            dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)
            dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)
            dis_loss += (dis_real_loss + dis_fake_loss) / 2
            self.add_summary(
                self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
            self.add_summary(
                self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
            self.optimD.zero_grad()
            dis_loss.backward()
            self.optimD.step()

            # generator adversarial loss
            gen_vid_feat = self.netD(comp_img)
            gan_loss = self.adversarial_loss(gen_vid_feat, True, False)
            gan_loss = gan_loss * loss_cfg['adversarial_weight']
            gen_loss += gan_loss
            self.add_summary(
                self.gen_writer, 'loss/gan_loss', gan_loss.item())

            # generator l1 loss
            hole_loss = self.l1_loss(pred_img * masks, frames * masks)
            # renormalise by the masked fraction; NOTE(review): a batch whose
            # masks are all zero makes this a division by zero
            hole_loss = hole_loss / torch.mean(masks) * loss_cfg['hole_weight']
            gen_loss += hole_loss
            self.add_summary(
                self.gen_writer, 'loss/hole_loss', hole_loss.item())

            valid_loss = self.l1_loss(pred_img * (1 - masks), frames * (1 - masks))
            valid_loss = valid_loss / torch.mean(1 - masks) * loss_cfg['valid_weight']
            gen_loss += valid_loss
            self.add_summary(
                self.gen_writer, 'loss/valid_loss', valid_loss.item())

            self.optimG.zero_grad()
            gen_loss.backward()
            self.optimG.step()

            # console logs
            if self.config['global_rank'] == 0:
                pbar.update(1)
                pbar.set_description((
                    f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};"
                    f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
                )

            # saving models
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration // self.train_args['save_freq']))
            # stop exactly at the configured iteration count
            if self.iteration >= self.train_args['iterations']:
                break
|
||||
|
||||
Reference in New Issue
Block a user