mirror of https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-04 04:34:41 +08:00

添加注释 (Add comments)
1 .gitignore vendored
@@ -372,3 +372,4 @@ test*_no_sub*.mp4
 /backend/models/video/ProPainter.pth
 /backend/models/big-lama/big-lama.pt
 /test/debug/
+/backend/tools/train/release_model/
@@ -4,8 +4,8 @@
     "data_loader": {
         "name": "davis",
         "data_root": "datasets/",
-        "w": 432,
-        "h": 240,
+        "w": 640,
+        "h": 120,
         "sample_length": 5
     },
     "losses": {
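The w and h values above are consumed by the Dataset added in this commit, which resizes every frame with img.resize(self.size) where size = (w, h). A minimal PIL sketch of the change (dummy input size; PIL takes width first):

from PIL import Image

img = Image.new('RGB', (1280, 720))  # dummy frame, size chosen for illustration
old = img.resize((432, 240)).size    # target size before this commit
new = img.resize((640, 120)).size    # target size after this commit
print(old, new)                      # (432, 240) (640, 120)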
@@ -3,9 +3,9 @@
     "save_dir": "release_model/",
     "data_loader": {
         "name": "youtube-vos",
-        "data_root": "datasets/",
-        "w": 432,
-        "h": 240,
+        "data_root": "datasets_sttn/",
+        "w": 640,
+        "h": 120,
         "sample_length": 5
     },
     "losses": {
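These data_loader values reach the Dataset class below through the training entry point: the script loads the JSON config and the trainer passes config['data_loader'] straight to Dataset. A minimal sketch of that path, assuming the patched config sits at configs_sttn/youtube-vos.json as in the argparse default further down:

import json

from backend.tools.train.dataset_sttn import Dataset

# Sketch only: mirrors config = json.load(open(args.config)) in the training
# script and Dataset(config['data_loader'], split='train') in the trainer.
config = json.load(open('configs_sttn/youtube-vos.json'))
train_dataset = Dataset(config['data_loader'], split='train')
print(len(train_dataset))  # number of videos listed in train.json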
@@ -8,62 +8,78 @@ from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_random_motion
 from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip


+# Custom dataset
 class Dataset(torch.utils.data.Dataset):
     def __init__(self, args: dict, split='train', debug=False):
+        # Constructor: takes the config dict and the dataset split, 'train' by default
         self.args = args
         self.split = split
-        self.sample_length = args['sample_length']
-        self.size = self.w, self.h = (args['w'], args['h'])
+        self.sample_length = args['sample_length']  # sample length parameter
+        self.size = self.w, self.h = (args['w'], args['h'])  # target image width and height

+        # Open the JSON file that describes the dataset
         with open(os.path.join(args['data_root'], args['name'], split+'.json'), 'r') as f:
-            self.video_dict = json.load(f)
-        self.video_names = list(self.video_dict.keys())
-        if debug or split != 'train':
+            self.video_dict = json.load(f)  # load the JSON content
+        self.video_names = list(self.video_dict.keys())  # list of video names
+        if debug or split != 'train':  # in debug mode or for non-training splits, keep only the first 100 videos
            self.video_names = self.video_names[:100]

+        # Transform pipeline that turns stacked frames into tensors
         self._to_tensors = transforms.Compose([
             Stack(),
-            ToTorchFormatTensor(), ])
+            ToTorchFormatTensor(),  # tensor format convenient for PyTorch
+        ])

     def __len__(self):
+        # Number of videos in the dataset
         return len(self.video_names)

     def __getitem__(self, index):
+        # Fetch one sample
         try:
-            item = self.load_item(index)
+            item = self.load_item(index)  # try to load the item at the given index
         except:
-            print('Loading error in video {}'.format(self.video_names[index]))
-            item = self.load_item(0)
+            print('Loading error in video {}'.format(self.video_names[index]))  # report the loading error
+            item = self.load_item(0)  # fall back to the first item
         return item

     def load_item(self, index):
-        video_name = self.video_names[index]
+        # Actual item-loading logic
+        video_name = self.video_names[index]  # video name for this index
+        # Build the frame file names for all frames of the video
         all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])]
+        # Generate randomly shaped masks with random motion
         all_masks = create_random_shape_with_random_motion(
             len(all_frames), imageHeight=self.h, imageWidth=self.w)
+        # Pick the reference frame indices
         ref_index = get_ref_index(len(all_frames), self.sample_length)
-        # read video frames
+        # Read the video frames
         frames = []
         masks = []
         for idx in ref_index:
+            # Read the image, convert it to RGB, resize it, and append it to the list
             img = ZipReader.imread('{}/{}/JPEGImages/{}.zip'.format(
                 self.args['data_root'], self.args['name'], video_name), all_frames[idx]).convert('RGB')
             img = img.resize(self.size)
             frames.append(img)
             masks.append(all_masks[idx])
         if self.split == 'train':
+            # For the training split, randomly flip the frames horizontally
             frames = GroupRandomHorizontalFlip()(frames)
-        # To tensors
-        frame_tensors = self._to_tensors(frames)*2.0 - 1.0
-        mask_tensors = self._to_tensors(masks)
-        return frame_tensors, mask_tensors
+        # Convert to tensors
+        frame_tensors = self._to_tensors(frames)*2.0 - 1.0  # normalize frames to [-1, 1]
+        mask_tensors = self._to_tensors(masks)  # masks as tensors
+        return frame_tensors, mask_tensors  # return the frame and mask tensors


 def get_ref_index(length, sample_length):
+    # Choose the reference frame indices
     if random.uniform(0, 1) > 0.5:
+        # With probability one half, sample random frames
         ref_index = random.sample(range(length), sample_length)
-        ref_index.sort()
+        ref_index.sort()  # sort to keep temporal order
     else:
+        # Otherwise pick a run of consecutive frames
        pivot = random.randint(0, length-sample_length)
        ref_index = [pivot+i for i in range(sample_length)]
     return ref_index
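A hedged usage sketch for the Dataset above. The dict mirrors the data_loader sections in this commit; data_root is set to datasets_sttn/ to match the JSON files added below, which is an assumption on our part, and the dataset archives must exist on disk for this to run:

from backend.tools.train.dataset_sttn import Dataset

# Assumed values; the real ones come from the configs_sttn JSON files.
args = {"name": "davis", "data_root": "datasets_sttn/", "w": 640, "h": 120, "sample_length": 5}
dataset = Dataset(args, split='train')
frame_tensors, mask_tensors = dataset[0]
# frame_tensors: sample_length frames scaled to [-1, 1]; mask_tensors: the matching masks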
1 backend/tools/train/datasets_sttn/davis/test.json Normal file
@@ -0,0 +1 @@
+{"bear": 82, "bike-packing": 69, "blackswan": 50, "bmx-bumps": 90, "bmx-trees": 80, "boat": 75, "boxing-fisheye": 87, "breakdance": 84, "breakdance-flare": 71, "bus": 80, "camel": 90, "car-roundabout": 75, "car-shadow": 40, "car-turn": 80, "cat-girl": 89, "classic-car": 63, "color-run": 84, "cows": 104, "crossing": 52, "dance-jump": 60, "dance-twirl": 90, "dancing": 62, "disc-jockey": 76, "dog": 60, "dog-agility": 25, "dog-gooses": 86, "dogs-jump": 66, "dogs-scale": 83, "drift-chicane": 52, "drift-straight": 50, "drift-turn": 64, "drone": 91, "elephant": 80, "flamingo": 80, "goat": 90, "gold-fish": 78, "hike": 80, "hockey": 75, "horsejump-high": 50, "horsejump-low": 60, "india": 81, "judo": 34, "kid-football": 68, "kite-surf": 50, "kite-walk": 80, "koala": 100, "lab-coat": 47, "lady-running": 65, "libby": 49, "lindy-hop": 73, "loading": 50, "longboard": 52, "lucia": 70, "mallard-fly": 70, "mallard-water": 80, "mbike-trick": 79, "miami-surf": 70, "motocross-bumps": 60, "motocross-jump": 40, "motorbike": 43, "night-race": 46, "paragliding": 70, "paragliding-launch": 80, "parkour": 100, "pigs": 79, "planes-water": 38, "rallye": 50, "rhino": 90, "rollerblade": 35, "schoolgirls": 80, "scooter-black": 43, "scooter-board": 91, "scooter-gray": 75, "sheep": 68, "shooting": 40, "skate-park": 80, "snowboard": 66, "soapbox": 99, "soccerball": 48, "stroller": 91, "stunt": 71, "surf": 55, "swing": 60, "tennis": 70, "tractor-sand": 76, "train": 80, "tuk-tuk": 59, "upside-down": 65, "varanus-cage": 67, "walking": 72}
1 backend/tools/train/datasets_sttn/davis/train.json Normal file
@@ -0,0 +1 @@
+{"baseball": 90, "basketball-game": 77, "bears-ball": 78, "bmx-rider": 85, "butterfly": 80, "car-competition": 66, "cat": 52, "chairlift": 99, "circus": 73, "city-ride": 70, "crafting": 45, "curling": 76, "dog-competition": 85, "dolphins-show": 74, "dribbling": 49, "drone-flying": 70, "ducks": 75, "elephant-hyenas": 55, "giraffes": 88, "gym-ball": 69, "helicopter-landing": 77, "horse-race": 80, "horses-kids": 78, "hurdles-race": 55, "ice-hockey": 52, "jet-ski": 83, "juggling-selfie": 78, "kayak-race": 63, "kids-robot": 75, "landing": 35, "luggage": 83, "mantaray": 73, "marbles": 70, "mascot": 78, "mermaid": 78, "monster-trucks": 99, "motorbike-indoors": 79, "motorbike-race": 88, "music-band": 87, "obstacles": 81, "obstacles-race": 48, "peacock": 75, "plane-exhibition": 73, "puppet": 100, "robot-battle": 85, "robotic-arm": 82, "rodeo": 85, "sea-turtle": 90, "skydiving-jumping": 75, "snowboard-race": 75, "snowboard-sand": 55, "surfer": 80, "swimmer": 86, "table-tennis": 70, "tram": 84, "trucks-race": 78, "twist-dance": 83, "volleyball-beach": 73, "water-slide": 88, "weightlifting": 90}
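Each of these JSON files is a flat {video_name: frame_count} map; Dataset.load_item above expands a count into zero-padded frame file names. A small sketch of that expansion:

video_dict = {"bear": 82}  # one entry from test.json above
all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(video_dict["bear"])]
print(all_frames[0], all_frames[-1])  # 00000.jpg 00081.jpg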
1 backend/tools/train/datasets_sttn/youtube-vos/test.json Normal file
File diff suppressed because one or more lines are too long
1 backend/tools/train/datasets_sttn/youtube-vos/train.json Normal file
File diff suppressed because one or more lines are too long
@@ -3,39 +3,54 @@ import torch.nn as nn


 class AdversarialLoss(nn.Module):
     r"""
-    Adversarial loss
-    https://arxiv.org/abs/1711.10337
-    """
+    Adversarial loss
+    Implemented following https://arxiv.org/abs/1711.10337
+    """

     def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
         r"""
-        type = nsgan | lsgan | hinge
-        """
+        Available loss types: 'nsgan' | 'lsgan' | 'hinge'
+        type: which kind of GAN loss to use.
+        target_real_label: target label value for real images.
+        target_fake_label: target label value for generated images.
+        """
         super(AdversarialLoss, self).__init__()
-        self.type = type
+        self.type = type  # loss type
+        # Register the labels as buffers so they are saved and loaded together with the model
         self.register_buffer('real_label', torch.tensor(target_real_label))
         self.register_buffer('fake_label', torch.tensor(target_fake_label))

+        # Initialize the criterion according to the chosen type
         if type == 'nsgan':
-            self.criterion = nn.BCELoss()
+            self.criterion = nn.BCELoss()  # binary cross-entropy (non-saturating GAN)
         elif type == 'lsgan':
-            self.criterion = nn.MSELoss()
+            self.criterion = nn.MSELoss()  # mean squared error (least-squares GAN)
         elif type == 'hinge':
-            self.criterion = nn.ReLU()
+            self.criterion = nn.ReLU()  # ReLU, used to build the hinge loss

     def __call__(self, outputs, is_real, is_disc=None):
+        """
+        Compute the loss.
+        outputs: network outputs.
+        is_real: True for real samples, False for generated samples.
+        is_disc: whether the discriminator is currently being optimized.
+        """
         if self.type == 'hinge':
+            # Hinge loss
             if is_disc:
+                # Discriminator branch
                 if is_real:
-                    outputs = -outputs
+                    outputs = -outputs  # flip the sign for real samples
+                # max(0, 1 - (real/fake) sample outputs)
                 return self.criterion(1 + outputs).mean()
             else:
+                # Generator branch: -min(0, -outputs) = max(0, outputs)
                 return (-outputs).mean()
         else:
+            # nsgan and lsgan losses
             labels = (self.real_label if is_real else self.fake_label).expand_as(
                 outputs)
+            # Loss between the model outputs and the target labels
             loss = self.criterion(outputs, labels)
             return loss
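To make the hinge branches above concrete, here is a small worked example (scores chosen for illustration): for the discriminator the loss is mean(ReLU(1 - outputs)) on real samples and mean(ReLU(1 + outputs)) on fake ones, while the generator gets mean(-outputs):

import torch

from backend.tools.train.loss_sttn import AdversarialLoss

loss_fn = AdversarialLoss(type='hinge')
outputs = torch.tensor([0.8, -0.3])     # illustrative discriminator scores

d_real = loss_fn(outputs, True, True)   # mean(relu(1 - outputs)) = 0.75
d_fake = loss_fn(outputs, False, True)  # mean(relu(1 + outputs)) = 1.25
g_loss = loss_fn(outputs, True, False)  # mean(-outputs) = -0.25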
@@ -14,7 +14,7 @@ from backend.tools.train.utils_sttn import (
 )

 parser = argparse.ArgumentParser(description='STTN')
-parser.add_argument('-c', '--config', default='configs/youtube-vos.json', type=str)
+parser.add_argument('-c', '--config', default='configs_sttn/youtube-vos.json', type=str)
 parser.add_argument('-m', '--model', default='sttn', type=str)
 parser.add_argument('-p', '--port', default='23455', type=str)
 parser.add_argument('-e', '--exam', action='store_true')
@@ -22,56 +22,75 @@ args = parser.parse_args()


 def main_worker(rank, config):
+    # If local_rank is absent from the config, set both local_rank and global_rank to the given rank
     if 'local_rank' not in config:
         config['local_rank'] = config['global_rank'] = rank
-    if config['distributed']:
-        torch.cuda.set_device(int(config['local_rank']))
-        torch.distributed.init_process_group(backend='nccl',
-                                             init_method=config['init_method'],
-                                             world_size=config['world_size'],
-                                             rank=config['global_rank'],
-                                             group_name='mtorch'
-                                             )
-        print('using GPU {}-{} for training'.format(
-            int(config['global_rank']), int(config['local_rank'])))
-
-    config['save_dir'] = os.path.join(config['save_dir'], '{}_{}'.format(config['model'],
-                                                                         os.path.basename(args.config).split('.')[0]))
+    # If the config requests distributed training
+    if config['distributed']:
+        # Bind this process to the CUDA device matching its local rank
+        torch.cuda.set_device(int(config['local_rank']))
+        # Initialize the distributed process group with the nccl backend
+        torch.distributed.init_process_group(
+            backend='nccl',
+            init_method=config['init_method'],
+            world_size=config['world_size'],
+            rank=config['global_rank'],
+            group_name='mtorch'
+        )
+        # Report which GPU is used, with the global and local ranks
+        print('using GPU {}-{} for training'.format(
+            int(config['global_rank']), int(config['local_rank']))
+        )
+
+    # Build the save directory from the model name and the config file name
+    config['save_dir'] = os.path.join(
+        config['save_dir'], '{}_{}'.format(config['model'], os.path.basename(args.config).split('.')[0])
+    )
+
+    # Use the matching CUDA device when available, otherwise fall back to the CPU
     if torch.cuda.is_available():
         config['device'] = torch.device("cuda:{}".format(config['local_rank']))
     else:
         config['device'] = 'cpu'

+    # If training is not distributed, or this is the main node (rank 0)
     if (not config['distributed']) or config['global_rank'] == 0:
+        # Create the save directory; ignore it if it already exists (exist_ok=True)
         os.makedirs(config['save_dir'], exist_ok=True)
+        # Path where the config file is saved
         config_path = os.path.join(
-            config['save_dir'], config['config'].split('/')[-1])
+            config['save_dir'], config['config'].split('/')[-1]
+        )
+        # If the config file is not there yet, copy it from the given config path
         if not os.path.isfile(config_path):
             copyfile(config['config'], config_path)
+        # Report the created folder
         print('[**] create folder {}'.format(config['save_dir']))

+    # Build the trainer from the config and the debug flag
     trainer = Trainer(config, debug=args.exam)
+    # Start training
     trainer.train()


 if __name__ == "__main__":

-    # loading configs
+    # Load the config file
     config = json.load(open(args.config))
-    config['model'] = args.model
-    config['config'] = args.config
+    config['model'] = args.model  # model name
+    config['config'] = args.config  # config file path

-    # setting distributed configurations
-    config['world_size'] = get_world_size()
-    config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
-    config['distributed'] = True if config['world_size'] > 1 else False
+    # Distributed training settings
+    config['world_size'] = get_world_size()  # total number of GPUs (processes) taking part in training
+    config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"  # init method: master node IP plus port
+    config['distributed'] = True if config['world_size'] > 1 else False  # enable distributed mode when world_size > 1

-    # setup distributed parallel training environments
+    # Set up the distributed parallel training environment
     if get_master_ip() == "127.0.0.1":
-        # manually launch distributed processes
+        # The master IP is the local host, so launch the distributed training processes manually
         mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
     else:
         # multiple processes have been launched by openmpi
-        config['local_rank'] = get_local_rank()
-        config['global_rank'] = get_global_rank()
-        main_worker(-1, config)
+        # Processes were launched by an external tool such as OpenMPI; no need to spawn them manually
+        config['local_rank'] = get_local_rank()  # local (per-node) rank
+        config['global_rank'] = get_global_rank()  # global rank
+        main_worker(-1, config)  # run the main worker
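For a concrete reading of the bootstrap above, these are the values a plain single-process run would produce, assuming the helper functions return their usual single-node defaults (illustrative only):

world_size = 1                           # assumed get_world_size() result
init_method = "tcp://127.0.0.1:23455"    # get_master_ip() plus the default --port
distributed = world_size > 1             # False, so no process group is initialized
# get_master_ip() == "127.0.0.1" also holds in this case, so mp.spawn still
# launches main_worker once (nprocs=world_size == 1) instead of relying on an
# external launcher such as OpenMPI.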
@@ -1,257 +1,319 @@
 import os
 import glob
 from tqdm import tqdm
 import importlib
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from tensorboardX import SummaryWriter

 from backend.inpaint.sttn.auto_sttn import Discriminator
 from backend.inpaint.sttn.auto_sttn import InpaintGenerator
 from backend.tools.train.dataset_sttn import Dataset
 from backend.tools.train.loss_sttn import AdversarialLoss

 class Trainer:
     def __init__(self, config, debug=False):
-        self.config = config
-        self.epoch = 0
-        self.iteration = 0
+        # Trainer initialization
+        self.config = config  # keep the configuration
+        self.epoch = 0  # current training epoch
+        self.iteration = 0  # current iteration count
         if debug:
+            # In debug mode, use much more frequent save and validation intervals
             self.config['trainer']['save_freq'] = 5
             self.config['trainer']['valid_freq'] = 5
             self.config['trainer']['iterations'] = 5

-        # setup data set and data loader
-        self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug)
-        self.train_sampler = None
-        self.train_args = config['trainer']
+        # Set up the dataset and data loader
+        self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug)  # training set
+        self.train_sampler = None  # sampler is None by default
+        self.train_args = config['trainer']  # training parameters
         if config['distributed']:
+            # For distributed training, initialize a distributed sampler
             self.train_sampler = DistributedSampler(
                 self.train_dataset,
                 num_replicas=config['world_size'],
-                rank=config['global_rank'])
+                rank=config['global_rank']
+            )
         self.train_loader = DataLoader(
             self.train_dataset,
             batch_size=self.train_args['batch_size'] // config['world_size'],
-            shuffle=(self.train_sampler is None),
+            shuffle=(self.train_sampler is None),  # shuffle only when there is no sampler
             num_workers=self.train_args['num_workers'],
-            sampler=self.train_sampler)
+            sampler=self.train_sampler
+        )

-        # set loss functions
-        self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])
-        self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
-        self.l1_loss = nn.L1Loss()
+        # Set up the loss functions
+        self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])  # adversarial loss
+        self.adversarial_loss = self.adversarial_loss.to(self.config['device'])  # move the loss to the device
+        self.l1_loss = nn.L1Loss()  # L1 loss

-        # setup models including generator and discriminator
-        net = importlib.import_module('model.' + config['model'])
-        self.netG = net.InpaintGenerator()
-        self.netG = self.netG.to(self.config['device'])
-        self.netD = net.Discriminator(
-            in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
-        self.netD = self.netD.to(self.config['device'])
+        # Initialize the generator and discriminator models
+        self.netG = InpaintGenerator()  # generator network
+        self.netG = self.netG.to(self.config['device'])  # move it to the device
+        self.netD = Discriminator(
+            in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge'
+        )
+        self.netD = self.netD.to(self.config['device'])  # discriminator network
+        # Initialize the optimizers
         self.optimG = torch.optim.Adam(
-            self.netG.parameters(),
-            lr=config['trainer']['lr'],
-            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
+            self.netG.parameters(),  # generator parameters
+            lr=config['trainer']['lr'],  # learning rate
+            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
+        )
         self.optimD = torch.optim.Adam(
-            self.netD.parameters(),
-            lr=config['trainer']['lr'],
-            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
-        self.load()
+            self.netD.parameters(),  # discriminator parameters
+            lr=config['trainer']['lr'],  # learning rate
+            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
+        )
+        self.load()  # load checkpoints if available

         if config['distributed']:
+            # For distributed training, wrap the models with DistributedDataParallel
             self.netG = DDP(
                 self.netG,
                 device_ids=[self.config['local_rank']],
                 output_device=self.config['local_rank'],
                 broadcast_buffers=True,
-                find_unused_parameters=False)
+                find_unused_parameters=False
+            )
             self.netD = DDP(
                 self.netD,
                 device_ids=[self.config['local_rank']],
                 output_device=self.config['local_rank'],
                 broadcast_buffers=True,
-                find_unused_parameters=False)
+                find_unused_parameters=False
+            )

-        # set summary writer
-        self.dis_writer = None
-        self.gen_writer = None
-        self.summary = {}
+        # Set up the summary writers
+        self.dis_writer = None  # discriminator writer
+        self.gen_writer = None  # generator writer
+        self.summary = {}  # accumulated summary statistics
         if self.config['global_rank'] == 0 or (not config['distributed']):
+            # Only on the main node, or when training is not distributed
             self.dis_writer = SummaryWriter(
-                os.path.join(config['save_dir'], 'dis'))
+                os.path.join(config['save_dir'], 'dis')
+            )
             self.gen_writer = SummaryWriter(
-                os.path.join(config['save_dir'], 'gen'))
+                os.path.join(config['save_dir'], 'gen')
+            )

-    # get current learning rate
+    # Return the current learning rate
     def get_lr(self):
         return self.optimG.param_groups[0]['lr']

-    # learning rate scheduler, step
+    # Learning rate scheduler, stepped per iteration
     def adjust_learning_rate(self):
-        decay = 0.1 ** (min(self.iteration,
-                            self.config['trainer']['niter_steady']) // self.config['trainer']['niter'])
+        # Compute the decayed learning rate
+        decay = 0.1 ** (min(self.iteration, self.config['trainer']['niter_steady']) // self.config['trainer']['niter'])
         new_lr = self.config['trainer']['lr'] * decay
+        # If the new learning rate differs from the current one, update both optimizers
         if new_lr != self.get_lr():
             for param_group in self.optimG.param_groups:
                 param_group['lr'] = new_lr
             for param_group in self.optimD.param_groups:
                 param_group['lr'] = new_lr

-    # add summary
+    # Record summary statistics
     def add_summary(self, writer, name, val):
+        # Accumulate the value on every iteration
         if name not in self.summary:
             self.summary[name] = 0
         self.summary[name] += val
+        # Write the 100-iteration average once every 100 iterations
         if writer is not None and self.iteration % 100 == 0:
             writer.add_scalar(name, self.summary[name] / 100, self.iteration)
             self.summary[name] = 0

-    # load netG and netD
+    # Load netG and netD
     def load(self):
-        model_path = self.config['save_dir']
+        model_path = self.config['save_dir']  # directory where models are saved
+        # Check whether a latest checkpoint marker exists
         if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
+            # Read the number of the last saved epoch
             latest_epoch = open(os.path.join(
                 model_path, 'latest.ckpt'), 'r').read().splitlines()[-1]
         else:
+            # No latest.ckpt: list the stored model files and take the most recent one
             ckpts = [os.path.basename(i).split('.pth')[0] for i in glob.glob(
                 os.path.join(model_path, '*.pth'))]
-            ckpts.sort()
-            latest_epoch = ckpts[-1] if len(ckpts) > 0 else None
+            ckpts.sort()  # sort the checkpoints so the most recent comes last
+            latest_epoch = ckpts[-1] if len(ckpts) > 0 else None  # most recent epoch, if any
         if latest_epoch is not None:
+            # Build the generator, discriminator, and optimizer checkpoint paths
             gen_path = os.path.join(
                 model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5)))
             dis_path = os.path.join(
                 model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5)))
             opt_path = os.path.join(
                 model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5)))
+            # On the main node, report which model is being loaded
             if self.config['global_rank'] == 0:
                 print('Loading model from {}...'.format(gen_path))
+            # Load the generator
             data = torch.load(gen_path, map_location=self.config['device'])
             self.netG.load_state_dict(data['netG'])
+            # Load the discriminator
             data = torch.load(dis_path, map_location=self.config['device'])
             self.netD.load_state_dict(data['netD'])
+            # Load the optimizer states
             data = torch.load(opt_path, map_location=self.config['device'])
             self.optimG.load_state_dict(data['optimG'])
             self.optimD.load_state_dict(data['optimD'])
+            # Restore the current epoch and iteration counters
             self.epoch = data['epoch']
             self.iteration = data['iteration']
         else:
+            # No model files found: warn that an initialized model will be used
             if self.config['global_rank'] == 0:
-                print(
-                    'Warnning: There is no trained model found. An initialized model will be used.')
+                print('Warning: There is no trained model found. An initialized model will be used.')

-    # save parameters every eval_epoch
+    # Save the model parameters, called once per evaluation period (eval_epoch)
     def save(self, it):
+        # Only save on the process with global rank 0, normally the main node
         if self.config['global_rank'] == 0:
+            # Path for the generator state dict
             gen_path = os.path.join(
                 self.config['save_dir'], 'gen_{}.pth'.format(str(it).zfill(5)))
+            # Path for the discriminator state dict
             dis_path = os.path.join(
                 self.config['save_dir'], 'dis_{}.pth'.format(str(it).zfill(5)))
+            # Path for the optimizer state dict
             opt_path = os.path.join(
                 self.config['save_dir'], 'opt_{}.pth'.format(str(it).zfill(5)))

+            # Report that the model is being saved
             print('\nsaving model to {} ...'.format(gen_path))

+            # If the model is wrapped by DataParallel or DDP, unwrap the underlying module
             if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
                 netG = self.netG.module
                 netD = self.netD.module
             else:
                 netG = self.netG
                 netD = self.netD

+            # Save the generator and discriminator parameters
             torch.save({'netG': netG.state_dict()}, gen_path)
             torch.save({'netD': netD.state_dict()}, dis_path)
-            torch.save({'epoch': self.epoch,
-                        'iteration': self.iteration,
-                        'optimG': self.optimG.state_dict(),
-                        'optimD': self.optimD.state_dict()}, opt_path)
+            # Save the current epoch, iteration count, and optimizer states
+            torch.save({
+                'epoch': self.epoch,
+                'iteration': self.iteration,
+                'optimG': self.optimG.state_dict(),
+                'optimD': self.optimD.state_dict()
+            }, opt_path)

+            # Write the latest iteration number into "latest.ckpt"
             os.system('echo {} > {}'.format(str(it).zfill(5),
                                             os.path.join(self.config['save_dir'], 'latest.ckpt')))

-    # train entry
+    # Training entry point
     def train(self):
+        # Initialize the progress range
         pbar = range(int(self.train_args['iterations']))
+        # On the global rank 0 process, show a progress bar
         if self.config['global_rank'] == 0:
             pbar = tqdm(pbar, initial=self.iteration, dynamic_ncols=True, smoothing=0.01)

+        # Main training loop
         while True:
-            self.epoch += 1
+            self.epoch += 1  # advance the epoch counter
             if self.config['distributed']:
+                # In distributed training, reseed the sampler so every process sees different data
                 self.train_sampler.set_epoch(self.epoch)

+            # Train for one epoch
             self._train_epoch(pbar)
+            # Exit once the iteration count exceeds the configured limit
             if self.iteration > self.train_args['iterations']:
                 break
+        # Training finished
         print('\nEnd training....')

-    # process input and calculate loss every training epoch
-    def _train_epoch(self, pbar):
-        device = self.config['device']
+    # Process the input and compute the losses for every training epoch
+    def _train_epoch(self, pbar):
+        device = self.config['device']  # target device

+        # Iterate over the data loader
         for frames, masks in self.train_loader:
+            # Adjust the learning rate
             self.adjust_learning_rate()
+            # Advance the iteration counter
             self.iteration += 1

+            # Move frames and masks to the device
             frames, masks = frames.to(device), masks.to(device)
-            b, t, c, h, w = frames.size()
-            masked_frame = (frames * (1 - masks).float())
-            pred_img = self.netG(masked_frame, masks)
+            b, t, c, h, w = frames.size()  # frame and mask dimensions
+            masked_frame = (frames * (1 - masks).float())  # apply the masks to the frames
+            pred_img = self.netG(masked_frame, masks)  # let the generator fill in the masked frames
+            # Reshape frames and masks to match the network's expected input layout
             frames = frames.view(b * t, c, h, w)
             masks = masks.view(b * t, 1, h, w)
-            comp_img = frames * (1. - masks) + masks * pred_img
+            comp_img = frames * (1. - masks) + masks * pred_img  # final composited image

-            gen_loss = 0
-            dis_loss = 0
+            gen_loss = 0  # initialize the generator loss
+            dis_loss = 0  # initialize the discriminator loss

-            # discriminator adversarial loss
-            real_vid_feat = self.netD(frames)
-            fake_vid_feat = self.netD(comp_img.detach())
-            dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)
-            dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)
-            dis_loss += (dis_real_loss + dis_fake_loss) / 2
-            self.add_summary(
-                self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
-            self.add_summary(
-                self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
+            # Discriminator adversarial loss
+            real_vid_feat = self.netD(frames)  # discriminator on real frames
+            fake_vid_feat = self.netD(comp_img.detach())  # on composited frames; detach blocks generator gradients
+            dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)  # loss on real frames
+            dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)  # loss on composited frames
+            dis_loss += (dis_real_loss + dis_fake_loss) / 2  # averaged discriminator loss
+            # Log the discriminator losses
+            self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
+            self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
+            # Optimize the discriminator
             self.optimD.zero_grad()
             dis_loss.backward()
             self.optimD.step()

-            # generator adversarial loss
+            # Generator adversarial loss
             gen_vid_feat = self.netD(comp_img)
-            gan_loss = self.adversarial_loss(gen_vid_feat, True, False)
-            gan_loss = gan_loss * self.config['losses']['adversarial_weight']
-            gen_loss += gan_loss
-            self.add_summary(
-                self.gen_writer, 'loss/gan_loss', gan_loss.item())
+            gan_loss = self.adversarial_loss(gen_vid_feat, True, False)  # adversarial loss for the generator
+            gan_loss = gan_loss * self.config['losses']['adversarial_weight']  # scale by the configured weight
+            gen_loss += gan_loss  # accumulate into the generator loss
+            # Log the generator adversarial loss
+            self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item())

-            # generator l1 loss
-            hole_loss = self.l1_loss(pred_img * masks, frames * masks)
+            # Generator L1 loss
+            hole_loss = self.l1_loss(pred_img * masks, frames * masks)  # L1 restricted to the masked region
+            # Normalize by the mask mean and scale by hole_weight from the config
             hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
-            gen_loss += hole_loss
-            self.add_summary(
-                self.gen_writer, 'loss/hole_loss', hole_loss.item())
+            gen_loss += hole_loss  # accumulate into the generator loss
+            # Log hole_loss
+            self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item())

+            # L1 loss over the region outside the masks
             valid_loss = self.l1_loss(pred_img * (1 - masks), frames * (1 - masks))
+            # Normalize by the non-masked mean and scale by valid_weight from the config
             valid_loss = valid_loss / torch.mean(1 - masks) * self.config['losses']['valid_weight']
-            gen_loss += valid_loss
-            self.add_summary(
-                self.gen_writer, 'loss/valid_loss', valid_loss.item())
+            gen_loss += valid_loss  # accumulate into the generator loss
+            # Log valid_loss
+            self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item())

+            # Optimize the generator
             self.optimG.zero_grad()
             gen_loss.backward()
             self.optimG.step()

-            # console logs
+            # Console logging
             if self.config['global_rank'] == 0:
-                pbar.update(1)
-                pbar.set_description((
-                    f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};"
+                pbar.update(1)  # advance the progress bar
+                pbar.set_description((  # progress bar description
+                    f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};"  # loss readout
                     f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
                 )

-            # saving models
+            # Save the models
             if self.iteration % self.train_args['save_freq'] == 0:
                 self.save(int(self.iteration // self.train_args['save_freq']))
+            # Stop once the iteration limit is reached
             if self.iteration > self.train_args['iterations']:
                 break
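In equation form, the generator objective assembled in _train_epoch above is, writing P for the generator output, F for the ground-truth frames, and M for the mask (notation ours; the weights come from the losses section of the config):

\mathcal{L}_G = w_{\mathrm{adv}}\,\mathcal{L}_{\mathrm{adv}}
  + w_{\mathrm{hole}}\,\frac{\operatorname{mean}\lvert(P - F)\odot M\rvert}{\operatorname{mean}(M)}
  + w_{\mathrm{valid}}\,\frac{\operatorname{mean}\lvert(P - F)\odot(1 - M)\rvert}{\operatorname{mean}(1 - M)},
\qquad
\mathcal{L}_D = \tfrac{1}{2}\left(\mathcal{L}_{\mathrm{real}} + \mathcal{L}_{\mathrm{fake}}\right)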