import os
import glob
import importlib

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from tensorboardX import SummaryWriter
from tqdm import tqdm

from backend.tools.train.dataset_sttn import Dataset
from backend.tools.train.loss_sttn import AdversarialLoss

class Trainer:
    def __init__(self, config, debug=False):
        self.config = config
        self.epoch = 0
        self.iteration = 0
        # in debug mode, shrink the run so a full save/valid cycle finishes quickly
        if debug:
            self.config['trainer']['save_freq'] = 5
            self.config['trainer']['valid_freq'] = 5
            self.config['trainer']['iterations'] = 5

        # set up the dataset and data loader
        self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug)
        self.train_sampler = None
        self.train_args = config['trainer']
        if config['distributed']:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=config['world_size'],
                rank=config['global_rank'])
        self.train_loader = DataLoader(
            self.train_dataset,
            # split the global batch size evenly across distributed workers
            batch_size=self.train_args['batch_size'] // config['world_size'],
            shuffle=(self.train_sampler is None),
            num_workers=self.train_args['num_workers'],
            sampler=self.train_sampler)

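        # Example of the split above: with batch_size=8 and world_size=2
        # (illustrative values, not taken from the shipped config), each of
        # the two processes loads batches of 4, keeping the effective global
        # batch size at 8.
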
        # set up the loss functions
        self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])
        self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
        self.l1_loss = nn.L1Loss()

        # set up the generator and discriminator models
        net = importlib.import_module('model.' + config['model'])
        self.netG = net.InpaintGenerator()
        self.netG = self.netG.to(self.config['device'])
        # hinge GAN loss works on raw scores, so the output sigmoid is disabled for it
        self.netD = net.Discriminator(
            in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
        self.netD = self.netD.to(self.config['device'])
        self.optimG = torch.optim.Adam(
            self.netG.parameters(),
            lr=config['trainer']['lr'],
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
        self.optimD = torch.optim.Adam(
            self.netD.parameters(),
            lr=config['trainer']['lr'],
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
        self.load()

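        # Note: importlib resolves the architecture at runtime, so a
        # config['model'] of, say, 'sttn' (illustrative name) imports
        # model.sttn, which is expected to expose InpaintGenerator and
        # Discriminator classes with the signatures used above.
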
        # wrap both models for distributed training
        if config['distributed']:
            self.netG = DDP(
                self.netG,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False)
            self.netD = DDP(
                self.netD,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False)

        # set up the TensorBoard summary writers (rank 0 only)
        self.dis_writer = None
        self.gen_writer = None
        self.summary = {}
        if self.config['global_rank'] == 0 or (not config['distributed']):
            self.dis_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'dis'))
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen'))

    # get the current learning rate
    def get_lr(self):
        return self.optimG.param_groups[0]['lr']

    # learning rate scheduler: staircase decay applied to both optimizers
    def adjust_learning_rate(self):
        decay = 0.1 ** (min(self.iteration,
                            self.config['trainer']['niter_steady']) // self.config['trainer']['niter'])
        new_lr = self.config['trainer']['lr'] * decay
        if new_lr != self.get_lr():
            for param_group in self.optimG.param_groups:
                param_group['lr'] = new_lr
            for param_group in self.optimD.param_groups:
                param_group['lr'] = new_lr

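    # Worked example for the staircase schedule above: with lr=1e-4,
    # niter=150000, and niter_steady=300000 (illustrative values, not
    # necessarily the shipped config), the learning rate is 1e-4 until
    # iteration 149999, 1e-5 from iteration 150000, 1e-6 from 300000, and
    # then constant, because min(iteration, niter_steady) stops growing.
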
    # accumulate a scalar and flush its 100-iteration average to TensorBoard
    def add_summary(self, writer, name, val):
        if name not in self.summary:
            self.summary[name] = 0
        self.summary[name] += val
        if writer is not None and self.iteration % 100 == 0:
            writer.add_scalar(name, self.summary[name] / 100, self.iteration)
            self.summary[name] = 0

    # load the latest netG, netD, and optimizer checkpoints if any exist
    def load(self):
        model_path = self.config['save_dir']
        if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
            with open(os.path.join(model_path, 'latest.ckpt'), 'r') as f:
                latest_epoch = f.read().splitlines()[-1]
        else:
            # keep only the numeric suffix so 'gen_00005' and 'opt_00005'
            # both resolve to checkpoint index '00005'
            ckpts = [os.path.basename(i).split('.pth')[0].split('_')[-1]
                     for i in glob.glob(os.path.join(model_path, '*.pth'))]
            ckpts.sort()
            latest_epoch = ckpts[-1] if len(ckpts) > 0 else None
        if latest_epoch is not None:
            gen_path = os.path.join(
                model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5)))
            dis_path = os.path.join(
                model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5)))
            opt_path = os.path.join(
                model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5)))
            if self.config['global_rank'] == 0:
                print('Loading model from {}...'.format(gen_path))
            data = torch.load(gen_path, map_location=self.config['device'])
            self.netG.load_state_dict(data['netG'])
            data = torch.load(dis_path, map_location=self.config['device'])
            self.netD.load_state_dict(data['netD'])
            data = torch.load(opt_path, map_location=self.config['device'])
            self.optimG.load_state_dict(data['optimG'])
            self.optimD.load_state_dict(data['optimD'])
            self.epoch = data['epoch']
            self.iteration = data['iteration']
        else:
            if self.config['global_rank'] == 0:
                print(
                    'Warning: no trained model found; an initialized model will be used.')

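    # Each save produces three files: gen_XXXXX.pth (generator weights),
    # dis_XXXXX.pth (discriminator weights), and opt_XXXXX.pth (optimizer
    # states plus the epoch/iteration counters), while latest.ckpt records
    # the index of the most recent save.
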
    # save netG, netD, and optimizer states (rank 0 only)
    def save(self, it):
        if self.config['global_rank'] == 0:
            gen_path = os.path.join(
                self.config['save_dir'], 'gen_{}.pth'.format(str(it).zfill(5)))
            dis_path = os.path.join(
                self.config['save_dir'], 'dis_{}.pth'.format(str(it).zfill(5)))
            opt_path = os.path.join(
                self.config['save_dir'], 'opt_{}.pth'.format(str(it).zfill(5)))
            print('\nsaving model to {} ...'.format(gen_path))
            # unwrap DataParallel/DDP so the raw module weights are saved
            if isinstance(self.netG, (torch.nn.DataParallel, DDP)):
                netG = self.netG.module
                netD = self.netD.module
            else:
                netG = self.netG
                netD = self.netD
            torch.save({'netG': netG.state_dict()}, gen_path)
            torch.save({'netD': netD.state_dict()}, dis_path)
            torch.save({'epoch': self.epoch,
                        'iteration': self.iteration,
                        'optimG': self.optimG.state_dict(),
                        'optimD': self.optimD.state_dict()}, opt_path)
            # record the latest checkpoint index without shelling out
            with open(os.path.join(self.config['save_dir'], 'latest.ckpt'), 'w') as f:
                f.write(str(it).zfill(5))

    # training entry point
    def train(self):
        pbar = range(int(self.train_args['iterations']))
        if self.config['global_rank'] == 0:
            pbar = tqdm(pbar, initial=self.iteration, dynamic_ncols=True, smoothing=0.01)

        while True:
            self.epoch += 1
            if self.config['distributed']:
                # reshuffle the shards so each epoch sees a different ordering
                self.train_sampler.set_epoch(self.epoch)

            self._train_epoch(pbar)
            if self.iteration > self.train_args['iterations']:
                break
        print('\nEnd training....')

    # process inputs and compute losses for each training epoch
    def _train_epoch(self, pbar):
        device = self.config['device']

        for frames, masks in self.train_loader:
            self.adjust_learning_rate()
            self.iteration += 1

            frames, masks = frames.to(device), masks.to(device)
            b, t, c, h, w = frames.size()
            # masks are 1 inside the corrupted (subtitle) region, 0 elsewhere
            masked_frame = (frames * (1 - masks).float())
            pred_img = self.netG(masked_frame, masks)
            # flatten the temporal dimension so frames align with pred_img
            frames = frames.view(b * t, c, h, w)
            masks = masks.view(b * t, 1, h, w)
            # composite: keep known pixels, fill the holes with the prediction
            comp_img = frames * (1. - masks) + masks * pred_img

            gen_loss = 0
            dis_loss = 0

            # discriminator adversarial loss
            real_vid_feat = self.netD(frames)
            fake_vid_feat = self.netD(comp_img.detach())
            dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)
            dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)
            dis_loss += (dis_real_loss + dis_fake_loss) / 2
            self.add_summary(
                self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
            self.add_summary(
                self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
            self.optimD.zero_grad()
            dis_loss.backward()
            self.optimD.step()

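            # The two optimizer steps in this loop implement standard
            # alternating GAN training: the discriminator update above saw
            # the composite detached, so only netD received gradients; the
            # generator pass below feeds the same composite through netD
            # without detaching, letting the adversarial gradient reach netG.
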
            # generator adversarial loss
            gen_vid_feat = self.netD(comp_img)
            gan_loss = self.adversarial_loss(gen_vid_feat, True, False)
            gan_loss = gan_loss * self.config['losses']['adversarial_weight']
            gen_loss += gan_loss
            self.add_summary(
                self.gen_writer, 'loss/gan_loss', gan_loss.item())

            # generator L1 loss inside the holes, normalized by hole area
            hole_loss = self.l1_loss(pred_img * masks, frames * masks)
            hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
            gen_loss += hole_loss
            self.add_summary(
                self.gen_writer, 'loss/hole_loss', hole_loss.item())

            # generator L1 loss outside the holes, normalized by valid area
            valid_loss = self.l1_loss(pred_img * (1 - masks), frames * (1 - masks))
            valid_loss = valid_loss / torch.mean(1 - masks) * self.config['losses']['valid_weight']
            gen_loss += valid_loss
            self.add_summary(
                self.gen_writer, 'loss/valid_loss', valid_loss.item())

            self.optimG.zero_grad()
            gen_loss.backward()
            self.optimG.step()

            # console logs
            if self.config['global_rank'] == 0:
                pbar.update(1)
                pbar.set_description((
                    f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f}; "
                    f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
                )

            # periodically save checkpoints, indexed by save count
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration // self.train_args['save_freq']))
            if self.iteration > self.train_args['iterations']:
                break
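
# Minimal usage sketch (assumed config keys matching the reads above; the
# concrete values are illustrative, not the project's shipped defaults):
#
#   config = {
#       'model': 'sttn', 'distributed': False, 'world_size': 1,
#       'global_rank': 0, 'local_rank': 0, 'device': 'cuda',
#       'save_dir': './checkpoints',
#       'data_loader': {...},  # dataset options, left elided here
#       'losses': {'GAN_LOSS': 'hinge', 'adversarial_weight': 0.01,
#                  'hole_weight': 1.0, 'valid_weight': 1.0},
#       'trainer': {'lr': 1e-4, 'beta1': 0.0, 'beta2': 0.99,
#                   'batch_size': 8, 'num_workers': 4, 'iterations': 500000,
#                   'niter': 150000, 'niter_steady': 300000,
#                   'save_freq': 10000, 'valid_freq': 10000},
#   }
#   Trainer(config).train()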