Add comments

This commit is contained in:
YaoFANGUK
2024-01-09 11:05:07 +08:00
parent 6b353455a0
commit a3dd7b797d
11 changed files with 271 additions and 154 deletions

.gitignore vendored

@@ -372,3 +372,4 @@ test*_no_sub*.mp4
/backend/models/video/ProPainter.pth
/backend/models/big-lama/big-lama.pt
/test/debug/
/backend/tools/train/release_model/


@@ -4,8 +4,8 @@
"data_loader": {
"name": "davis",
"data_root": "datasets/",
"w": 432,
"h": 240,
"w": 640,
"h": 120,
"sample_length": 5
},
"losses": {


@@ -3,9 +3,9 @@
"save_dir": "release_model/",
"data_loader": {
"name": "youtube-vos",
"data_root": "datasets/",
"w": 432,
"h": 240,
"data_root": "datasets_sttn/",
"w": 640,
"h": 120,
"sample_length": 5
},
"losses": {


@@ -8,62 +8,78 @@ from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_r
from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip
# Custom dataset
class Dataset(torch.utils.data.Dataset):
def __init__(self, args: dict, split='train', debug=False):
# Initializer: takes the config dict and the dataset split, defaulting to 'train'
self.args = args
self.split = split
self.sample_length = args['sample_length']
self.size = self.w, self.h = (args['w'], args['h'])
self.sample_length = args['sample_length'] # sample length parameter
self.size = self.w, self.h = (args['w'], args['h']) # target image width and height
# open the JSON file that stores the dataset metadata
with open(os.path.join(args['data_root'], args['name'], split+'.json'), 'r') as f:
self.video_dict = json.load(f)
self.video_names = list(self.video_dict.keys())
if debug or split != 'train':
self.video_dict = json.load(f) # load the JSON contents
self.video_names = list(self.video_dict.keys()) # list of video names
if debug or split != 'train': # in debug mode or for non-train splits, keep only the first 100 videos
self.video_names = self.video_names[:100]
# data transforms: stack the frames and convert them to tensors
self._to_tensors = transforms.Compose([
Stack(),
ToTorchFormatTensor(), ])
ToTorchFormatTensor(), # tensor format convenient for PyTorch
])
def __len__(self):
# number of videos in the dataset
return len(self.video_names)
def __getitem__(self, index):
# fetch a single sample
try:
item = self.load_item(index)
item = self.load_item(index) # try to load the item at the given index
except:
print('Loading error in video {}'.format(self.video_names[index]))
item = self.load_item(0)
print('Loading error in video {}'.format(self.video_names[index])) # report which video failed to load
item = self.load_item(0) # fall back to the first item
return item
def load_item(self, index):
video_name = self.video_names[index]
# actual loading logic for one item
video_name = self.video_names[index] # look up the video name by index
# build the frame file names for every frame of the video
all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])]
# generate randomly shaped masks with random motion
all_masks = create_random_shape_with_random_motion(
len(all_frames), imageHeight=self.h, imageWidth=self.w)
# pick the reference frame indices
ref_index = get_ref_index(len(all_frames), self.sample_length)
# read video frames
frames = []
masks = []
for idx in ref_index:
# read the image, convert it to RGB, resize, and append to the list
img = ZipReader.imread('{}/{}/JPEGImages/{}.zip'.format(
self.args['data_root'], self.args['name'], video_name), all_frames[idx]).convert('RGB')
img = img.resize(self.size)
frames.append(img)
masks.append(all_masks[idx])
if self.split == 'train':
# for the training split, randomly flip the frames horizontally
frames = GroupRandomHorizontalFlip()(frames)
# To tensors
frame_tensors = self._to_tensors(frames)*2.0 - 1.0
mask_tensors = self._to_tensors(masks)
return frame_tensors, mask_tensors
# convert to tensors
frame_tensors = self._to_tensors(frames)*2.0 - 1.0 # normalize to [-1, 1]
mask_tensors = self._to_tensors(masks) # convert the masks to tensors
return frame_tensors, mask_tensors # return the frame and mask tensors
def get_ref_index(length, sample_length):
# reference-frame index selection
if random.uniform(0, 1) > 0.5:
# with probability 0.5, sample frames at random
ref_index = random.sample(range(length), sample_length)
ref_index.sort()
ref_index.sort() # sort to keep temporal order
else:
# otherwise pick consecutive frames
pivot = random.randint(0, length-sample_length)
ref_index = [pivot+i for i in range(sample_length)]
return ref_index
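Since get_ref_index drives the temporal sampling above, here is a self-contained sketch of its behavior (logic copied from this hunk; printed values are illustrative):

import random

def get_ref_index(length, sample_length):
    # With probability 0.5, sample distinct frames and sort them;
    # otherwise take a run of consecutive frames.
    if random.uniform(0, 1) > 0.5:
        ref_index = sorted(random.sample(range(length), sample_length))
    else:
        pivot = random.randint(0, length - sample_length)
        ref_index = [pivot + i for i in range(sample_length)]
    return ref_index

# e.g. for an 82-frame video with sample_length=5:
print(get_ref_index(82, 5))  # e.g. [3, 20, 41, 60, 77] or [40, 41, 42, 43, 44]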


@@ -0,0 +1 @@
{"bear": 82, "bike-packing": 69, "blackswan": 50, "bmx-bumps": 90, "bmx-trees": 80, "boat": 75, "boxing-fisheye": 87, "breakdance": 84, "breakdance-flare": 71, "bus": 80, "camel": 90, "car-roundabout": 75, "car-shadow": 40, "car-turn": 80, "cat-girl": 89, "classic-car": 63, "color-run": 84, "cows": 104, "crossing": 52, "dance-jump": 60, "dance-twirl": 90, "dancing": 62, "disc-jockey": 76, "dog": 60, "dog-agility": 25, "dog-gooses": 86, "dogs-jump": 66, "dogs-scale": 83, "drift-chicane": 52, "drift-straight": 50, "drift-turn": 64, "drone": 91, "elephant": 80, "flamingo": 80, "goat": 90, "gold-fish": 78, "hike": 80, "hockey": 75, "horsejump-high": 50, "horsejump-low": 60, "india": 81, "judo": 34, "kid-football": 68, "kite-surf": 50, "kite-walk": 80, "koala": 100, "lab-coat": 47, "lady-running": 65, "libby": 49, "lindy-hop": 73, "loading": 50, "longboard": 52, "lucia": 70, "mallard-fly": 70, "mallard-water": 80, "mbike-trick": 79, "miami-surf": 70, "motocross-bumps": 60, "motocross-jump": 40, "motorbike": 43, "night-race": 46, "paragliding": 70, "paragliding-launch": 80, "parkour": 100, "pigs": 79, "planes-water": 38, "rallye": 50, "rhino": 90, "rollerblade": 35, "schoolgirls": 80, "scooter-black": 43, "scooter-board": 91, "scooter-gray": 75, "sheep": 68, "shooting": 40, "skate-park": 80, "snowboard": 66, "soapbox": 99, "soccerball": 48, "stroller": 91, "stunt": 71, "surf": 55, "swing": 60, "tennis": 70, "tractor-sand": 76, "train": 80, "tuk-tuk": 59, "upside-down": 65, "varanus-cage": 67, "walking": 72}


@@ -0,0 +1 @@
{"baseball": 90, "basketball-game": 77, "bears-ball": 78, "bmx-rider": 85, "butterfly": 80, "car-competition": 66, "cat": 52, "chairlift": 99, "circus": 73, "city-ride": 70, "crafting": 45, "curling": 76, "dog-competition": 85, "dolphins-show": 74, "dribbling": 49, "drone-flying": 70, "ducks": 75, "elephant-hyenas": 55, "giraffes": 88, "gym-ball": 69, "helicopter-landing": 77, "horse-race": 80, "horses-kids": 78, "hurdles-race": 55, "ice-hockey": 52, "jet-ski": 83, "juggling-selfie": 78, "kayak-race": 63, "kids-robot": 75, "landing": 35, "luggage": 83, "mantaray": 73, "marbles": 70, "mascot": 78, "mermaid": 78, "monster-trucks": 99, "motorbike-indoors": 79, "motorbike-race": 88, "music-band": 87, "obstacles": 81, "obstacles-race": 48, "peacock": 75, "plane-exhibition": 73, "puppet": 100, "robot-battle": 85, "robotic-arm": 82, "rodeo": 85, "sea-turtle": 90, "skydiving-jumping": 75, "snowboard-race": 75, "snowboard-sand": 55, "surfer": 80, "swimmer": 86, "table-tennis": 70, "tram": 84, "trucks-race": 78, "twist-dance": 83, "volleyball-beach": 73, "water-slide": 88, "weightlifting": 90}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@@ -3,39 +3,54 @@ import torch.nn as nn
class AdversarialLoss(nn.Module):
r"""
Adversarial loss
https://arxiv.org/abs/1711.10337
"""
Adversarial loss
Implemented following https://arxiv.org/abs/1711.10337
"""
def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
r"""
type = nsgan | lsgan | hinge
"""
Available loss types: 'nsgan' | 'lsgan' | 'hinge'
type: which type of GAN loss to use.
target_real_label: target label value for real images.
target_fake_label: target label value for generated images.
"""
super(AdversarialLoss, self).__init__()
self.type = type
self.type = type # loss type
# register the labels as buffers so they are saved and loaded together with the model
self.register_buffer('real_label', torch.tensor(target_real_label))
self.register_buffer('fake_label', torch.tensor(target_fake_label))
# initialize the criterion according to the chosen type
if type == 'nsgan':
self.criterion = nn.BCELoss()
self.criterion = nn.BCELoss() # binary cross-entropy loss (non-saturating GAN)
elif type == 'lsgan':
self.criterion = nn.MSELoss()
self.criterion = nn.MSELoss() # mean squared error loss (least-squares GAN)
elif type == 'hinge':
self.criterion = nn.ReLU()
self.criterion = nn.ReLU() # ReLU used to implement the hinge loss
def __call__(self, outputs, is_real, is_disc=None):
"""
Compute the loss.
outputs: network outputs.
is_real: True for real samples, False for generated ones.
is_disc: whether the discriminator is currently being optimized.
"""
if self.type == 'hinge':
# hinge loss
if is_disc:
# discriminator update
if is_real:
outputs = -outputs
outputs = -outputs # flip the sign for real samples
# max(0, 1 - output) for real samples, max(0, 1 + output) for fake ones
return self.criterion(1 + outputs).mean()
else:
# generator update: the loss is the negated discriminator output, i.e. -mean(output)
return (-outputs).mean()
else:
# nsgan and lsgan losses
labels = (self.real_label if is_real else self.fake_label).expand_as(
outputs)
# loss between the model outputs and the target labels
loss = self.criterion(outputs, labels)
return loss
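A short usage sketch of the annotated AdversarialLoss (tensor shapes are illustrative; the import path matches the trainer shown further below):

import torch
from backend.tools.train.loss_sttn import AdversarialLoss

criterion = AdversarialLoss(type='hinge')
real_scores = torch.randn(4, 1)  # stand-ins for discriminator outputs
fake_scores = torch.randn(4, 1)

# Discriminator side: mean(relu(1 - D(real))) and mean(relu(1 + D(fake)))
d_loss = (criterion(real_scores, is_real=True, is_disc=True)
          + criterion(fake_scores, is_real=False, is_disc=True)) / 2
# Generator side: -mean(D(fake))
g_loss = criterion(fake_scores, is_real=True, is_disc=False)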


@@ -14,7 +14,7 @@ from backend.tools.train.utils_sttn import (
)
parser = argparse.ArgumentParser(description='STTN')
parser.add_argument('-c', '--config', default='configs/youtube-vos.json', type=str)
parser.add_argument('-c', '--config', default='configs_sttn/youtube-vos.json', type=str)
parser.add_argument('-m', '--model', default='sttn', type=str)
parser.add_argument('-p', '--port', default='23455', type=str)
parser.add_argument('-e', '--exam', action='store_true')
@@ -22,56 +22,75 @@ args = parser.parse_args()
def main_worker(rank, config):
# if local_rank is not in the config, set both local_rank and global_rank to the passed-in rank
if 'local_rank' not in config:
config['local_rank'] = config['global_rank'] = rank
if config['distributed']:
torch.cuda.set_device(int(config['local_rank']))
torch.distributed.init_process_group(backend='nccl',
init_method=config['init_method'],
world_size=config['world_size'],
rank=config['global_rank'],
group_name='mtorch'
)
print('using GPU {}-{} for training'.format(
int(config['global_rank']), int(config['local_rank'])))
config['save_dir'] = os.path.join(config['save_dir'], '{}_{}'.format(config['model'],
os.path.basename(args.config).split('.')[0]))
# if distributed training is enabled in the config
if config['distributed']:
# bind the CUDA device that matches the local rank
torch.cuda.set_device(int(config['local_rank']))
# initialize the distributed process group via the NCCL backend
torch.distributed.init_process_group(
backend='nccl',
init_method=config['init_method'],
world_size=config['world_size'],
rank=config['global_rank'],
group_name='mtorch'
)
# log which GPU is in use: global and local rank
print('using GPU {}-{} for training'.format(
int(config['global_rank']), int(config['local_rank']))
)
# build the save directory from the model name and the config file name
config['save_dir'] = os.path.join(
config['save_dir'], '{}_{}'.format(config['model'], os.path.basename(args.config).split('.')[0])
)
# use the matching CUDA device if available, otherwise the CPU
if torch.cuda.is_available():
config['device'] = torch.device("cuda:{}".format(config['local_rank']))
else:
config['device'] = 'cpu'
# when not distributed, or on the distributed main process (rank 0)
if (not config['distributed']) or config['global_rank'] == 0:
# create the save directory; exist_ok=True skips creation if it already exists
os.makedirs(config['save_dir'], exist_ok=True)
# where the config file is copied to
config_path = os.path.join(
config['save_dir'], config['config'].split('/')[-1])
config['save_dir'], config['config'].split('/')[-1]
)
# if the config copy does not exist yet, copy it from the given path
if not os.path.isfile(config_path):
copyfile(config['config'], config_path)
# report the created folder
print('[**] create folder {}'.format(config['save_dir']))
# initialize the trainer with the config and the debug flag
trainer = Trainer(config, debug=args.exam)
# start training
trainer.train()
if __name__ == "__main__":
# load the config file
config = json.load(open(args.config))
config['model'] = args.model
config['config'] = args.config
config['model'] = args.model # model name
config['config'] = args.config # config file path
# setting distributed configurations
config['world_size'] = get_world_size()
config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
config['distributed'] = True if config['world_size'] > 1 else False
# distributed-training settings
config['world_size'] = get_world_size() # total number of processes, i.e. the number of GPUs taking part in training
config['init_method'] = f"tcp://{get_master_ip()}:{args.port}" # init method: master node IP and port
config['distributed'] = True if config['world_size'] > 1 else False # enable distributed training when world_size > 1
# set up the distributed parallel training environment
if get_master_ip() == "127.0.0.1":
# if the master IP is localhost, launch the distributed training processes manually
mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
else:
# multiple processes have been launched by openmpi
config['local_rank'] = get_local_rank()
config['global_rank'] = get_global_rank()
main_worker(-1, config)
# processes were already launched by an external tool (e.g. OpenMPI); no need to spawn them manually
config['local_rank'] = get_local_rank() # local (per-node) rank
config['global_rank'] = get_global_rank() # global rank
main_worker(-1, config) # run the main worker
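Given the argparse defaults above, a launch might look like the following sketch (the script filename train_sttn.py is an assumption; the flags mirror the parser):

import subprocess

# -c config path (now defaulting to configs_sttn/), -m model name, -p TCP port
# used to build init_method for distributed training.
subprocess.run([
    'python', 'train_sttn.py',  # assumed filename for this training script
    '-c', 'configs_sttn/youtube-vos.json',
    '-m', 'sttn',
    '-p', '23455',
], check=True)

With one visible GPU, world_size is 1, distributed stays False, and a single worker runs; on a multi-GPU localhost, mp.spawn starts one worker per GPU.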


@@ -1,257 +1,319 @@
import os
import glob
from tqdm import tqdm
import importlib
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from tensorboardX import SummaryWriter
from backend.inpaint.sttn.auto_sttn import Discriminator
from backend.inpaint.sttn.auto_sttn import InpaintGenerator
from backend.tools.train.dataset_sttn import Dataset
from backend.tools.train.loss_sttn import AdversarialLoss
class Trainer:
def __init__(self, config, debug=False):
self.config = config
self.epoch = 0
self.iteration = 0
# trainer initialization
self.config = config # store the config
self.epoch = 0 # current epoch
self.iteration = 0 # current iteration count
if debug:
# in debug mode, save and validate much more frequently
self.config['trainer']['save_freq'] = 5
self.config['trainer']['valid_freq'] = 5
self.config['trainer']['iterations'] = 5
# setup data set and data loader
self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug)
self.train_sampler = None
self.train_args = config['trainer']
# set up the dataset and data loader
self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug) # training dataset
self.train_sampler = None # no sampler by default
self.train_args = config['trainer'] # training hyper-parameters
if config['distributed']:
# in distributed mode, use a DistributedSampler
self.train_sampler = DistributedSampler(
self.train_dataset,
num_replicas=config['world_size'],
rank=config['global_rank'])
rank=config['global_rank']
)
self.train_loader = DataLoader(
self.train_dataset,
batch_size=self.train_args['batch_size'] // config['world_size'],
shuffle=(self.train_sampler is None),
shuffle=(self.train_sampler is None), # shuffle only when no sampler is set
num_workers=self.train_args['num_workers'],
sampler=self.train_sampler)
sampler=self.train_sampler
)
# set loss functions
self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])
self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
self.l1_loss = nn.L1Loss()
# set up the loss functions
self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS']) # adversarial loss
self.adversarial_loss = self.adversarial_loss.to(self.config['device']) # move the loss to the target device
self.l1_loss = nn.L1Loss() # L1 loss
# setup models including generator and discriminator
net = importlib.import_module('model.' + config['model'])
self.netG = net.InpaintGenerator()
self.netG = self.netG.to(self.config['device'])
self.netD = net.Discriminator(
in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
self.netD = self.netD.to(self.config['device'])
# initialize the generator and discriminator
self.netG = InpaintGenerator() # generator network
self.netG = self.netG.to(self.config['device']) # move to device
self.netD = Discriminator(
in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge'
)
self.netD = self.netD.to(self.config['device']) # discriminator network
# initialize the optimizers
self.optimG = torch.optim.Adam(
self.netG.parameters(),
lr=config['trainer']['lr'],
betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
self.netG.parameters(), # generator parameters
lr=config['trainer']['lr'], # learning rate
betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
)
self.optimD = torch.optim.Adam(
self.netD.parameters(),
lr=config['trainer']['lr'],
betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
self.load()
self.netD.parameters(), # discriminator parameters
lr=config['trainer']['lr'], # learning rate
betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
)
self.load() # load checkpoints if present
if config['distributed']:
# in distributed mode, wrap the networks with DistributedDataParallel
self.netG = DDP(
self.netG,
device_ids=[self.config['local_rank']],
output_device=self.config['local_rank'],
broadcast_buffers=True,
find_unused_parameters=False)
find_unused_parameters=False
)
self.netD = DDP(
self.netD,
device_ids=[self.config['local_rank']],
output_device=self.config['local_rank'],
broadcast_buffers=True,
find_unused_parameters=False)
find_unused_parameters=False
)
# set summary writer
self.dis_writer = None
self.gen_writer = None
self.summary = {}
# set up the summary writers
self.dis_writer = None # discriminator writer
self.gen_writer = None # generator writer
self.summary = {} # accumulated summary statistics
if self.config['global_rank'] == 0 or (not config['distributed']):
# only on the main process, or when not distributed
self.dis_writer = SummaryWriter(
os.path.join(config['save_dir'], 'dis'))
os.path.join(config['save_dir'], 'dis')
)
self.gen_writer = SummaryWriter(
os.path.join(config['save_dir'], 'gen'))
os.path.join(config['save_dir'], 'gen')
)
# get the current learning rate
def get_lr(self):
return self.optimG.param_groups[0]['lr']
# learning rate scheduler step
def adjust_learning_rate(self):
decay = 0.1 ** (min(self.iteration,
self.config['trainer']['niter_steady']) // self.config['trainer']['niter'])
# compute the decayed learning rate
decay = 0.1 ** (min(self.iteration, self.config['trainer']['niter_steady']) // self.config['trainer']['niter'])
new_lr = self.config['trainer']['lr'] * decay
# if the learning rate changed, update it in both optimizers
if new_lr != self.get_lr():
for param_group in self.optimG.param_groups:
param_group['lr'] = new_lr
for param_group in self.optimD.param_groups:
param_group['lr'] = new_lr
# add a summary entry
def add_summary(self, writer, name, val):
# accumulate the statistic on every iteration
if name not in self.summary:
self.summary[name] = 0
self.summary[name] += val
# log the average once every 100 iterations
if writer is not None and self.iteration % 100 == 0:
writer.add_scalar(name, self.summary[name] / 100, self.iteration)
self.summary[name] = 0
# load netG and netD
def load(self):
model_path = self.config['save_dir']
model_path = self.config['save_dir'] # model save path
# check whether a latest checkpoint exists
if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
# read the number of the most recent epoch
latest_epoch = open(os.path.join(
model_path, 'latest.ckpt'), 'r').read().splitlines()[-1]
else:
# if latest.ckpt is missing, list the saved model files and take the most recent one
ckpts = [os.path.basename(i).split('.pth')[0] for i in glob.glob(
os.path.join(model_path, '*.pth'))]
ckpts.sort()
latest_epoch = ckpts[-1] if len(ckpts) > 0 else None
ckpts.sort() # sort the checkpoint files to find the most recent
latest_epoch = ckpts[-1] if len(ckpts) > 0 else None # most recent epoch, if any
if latest_epoch is not None:
# build the checkpoint paths for the generator, discriminator, and optimizer
gen_path = os.path.join(
model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5)))
dis_path = os.path.join(
model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5)))
opt_path = os.path.join(
model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5)))
# on the main process, report which model is being loaded
if self.config['global_rank'] == 0:
print('Loading model from {}...'.format(gen_path))
# load the generator weights
data = torch.load(gen_path, map_location=self.config['device'])
self.netG.load_state_dict(data['netG'])
# load the discriminator weights
data = torch.load(dis_path, map_location=self.config['device'])
self.netD.load_state_dict(data['netD'])
# load the optimizer states
data = torch.load(opt_path, map_location=self.config['device'])
self.optimG.load_state_dict(data['optimG'])
self.optimD.load_state_dict(data['optimD'])
# restore the current epoch and iteration
self.epoch = data['epoch']
self.iteration = data['iteration']
else:
# warn when no trained model is found
if self.config['global_rank'] == 0:
print(
'Warnning: There is no trained model found. An initialized model will be used.')
print('Warning: There is no trained model found. An initialized model will be used.')
# save the parameters once per evaluation period (eval_epoch)
def save(self, it):
# save only on the global rank-0 process (the main node)
if self.config['global_rank'] == 0:
# path for the generator state dict
gen_path = os.path.join(
self.config['save_dir'], 'gen_{}.pth'.format(str(it).zfill(5)))
# path for the discriminator state dict
dis_path = os.path.join(
self.config['save_dir'], 'dis_{}.pth'.format(str(it).zfill(5)))
# path for the optimizer state dict
opt_path = os.path.join(
self.config['save_dir'], 'opt_{}.pth'.format(str(it).zfill(5)))
# report that the model is being saved
print('\nsaving model to {} ...'.format(gen_path))
# if the model is wrapped by DataParallel or DDP, unwrap it to get the underlying module
if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
netG = self.netG.module
netD = self.netD.module
else:
netG = self.netG
netD = self.netD
# save the generator and discriminator weights
torch.save({'netG': netG.state_dict()}, gen_path)
torch.save({'netD': netD.state_dict()}, dis_path)
torch.save({'epoch': self.epoch,
'iteration': self.iteration,
'optimG': self.optimG.state_dict(),
'optimD': self.optimD.state_dict()}, opt_path)
# save the current epoch, iteration count, and optimizer states
torch.save({
'epoch': self.epoch,
'iteration': self.iteration,
'optimG': self.optimG.state_dict(),
'optimD': self.optimD.state_dict()
}, opt_path)
# write the latest iteration number to "latest.ckpt"
os.system('echo {} > {}'.format(str(it).zfill(5),
os.path.join(self.config['save_dir'], 'latest.ckpt')))
# training entry point
def train(self):
# initialize the progress range
pbar = range(int(self.train_args['iterations']))
# show a progress bar only on the global rank-0 process
if self.config['global_rank'] == 0:
pbar = tqdm(pbar, initial=self.iteration, dynamic_ncols=True, smoothing=0.01)
# main training loop
while True:
self.epoch += 1
self.epoch += 1 # increment the epoch counter
if self.config['distributed']:
# in distributed mode, reseed the sampler so each process sees different data
self.train_sampler.set_epoch(self.epoch)
# train for one epoch
self._train_epoch(pbar)
# stop once the configured iteration limit is exceeded
if self.iteration > self.train_args['iterations']:
break
# training finished
print('\nEnd training....')
# process input and calculate loss every training epoch
def _train_epoch(self, pbar):
device = self.config['device']
# process inputs and compute losses for each training epoch
def _train_epoch(self, pbar):
device = self.config['device'] # target device
# iterate over the data loader
for frames, masks in self.train_loader:
# adjust the learning rate
self.adjust_learning_rate()
# increment the iteration counter
self.iteration += 1
# move frames and masks to the device
frames, masks = frames.to(device), masks.to(device)
b, t, c, h, w = frames.size()
masked_frame = (frames * (1 - masks).float())
pred_img = self.netG(masked_frame, masks)
b, t, c, h, w = frames.size() # batch, time, channels, height, width
masked_frame = (frames * (1 - masks).float()) # apply the masks to the frames
pred_img = self.netG(masked_frame, masks) # inpaint the masked frames with the generator
# reshape frames and masks to match the networks' expected input layout
frames = frames.view(b * t, c, h, w)
masks = masks.view(b * t, 1, h, w)
comp_img = frames * (1. - masks) + masks * pred_img
comp_img = frames * (1. - masks) + masks * pred_img # final composite: original outside the mask, prediction inside
gen_loss = 0
dis_loss = 0
gen_loss = 0 # generator loss accumulator
dis_loss = 0 # discriminator loss accumulator
# discriminator adversarial loss
real_vid_feat = self.netD(frames)
fake_vid_feat = self.netD(comp_img.detach())
dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)
dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)
dis_loss += (dis_real_loss + dis_fake_loss) / 2
self.add_summary(
self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
self.add_summary(
self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
# discriminator adversarial loss
real_vid_feat = self.netD(frames) # discriminator on real frames
fake_vid_feat = self.netD(comp_img.detach()) # discriminator on composited frames; detach keeps gradients out of the generator
dis_real_loss = self.adversarial_loss(real_vid_feat, True, True) # loss on real frames
dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True) # loss on generated frames
dis_loss += (dis_real_loss + dis_fake_loss) / 2 # average the two discriminator losses
# log the discriminator losses
self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
# discriminator optimization step
self.optimD.zero_grad()
dis_loss.backward()
self.optimD.step()
# generator adversarial loss
gen_vid_feat = self.netD(comp_img)
gan_loss = self.adversarial_loss(gen_vid_feat, True, False)
gan_loss = gan_loss * self.config['losses']['adversarial_weight']
gen_loss += gan_loss
self.add_summary(
self.gen_writer, 'loss/gan_loss', gan_loss.item())
gan_loss = self.adversarial_loss(gen_vid_feat, True, False) # adversarial loss for the generator
gan_loss = gan_loss * self.config['losses']['adversarial_weight'] # scale by the adversarial weight
gen_loss += gan_loss # accumulate into the generator loss
# log the generator adversarial loss
self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item())
# generator l1 loss
hole_loss = self.l1_loss(pred_img * masks, frames * masks)
# generator L1 loss
hole_loss = self.l1_loss(pred_img * masks, frames * masks) # loss restricted to the masked (hole) region
# normalize by the mask area and scale by hole_weight from the config
hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
gen_loss += hole_loss
self.add_summary(
self.gen_writer, 'loss/hole_loss', hole_loss.item())
gen_loss += hole_loss # accumulate into the generator loss
# log hole_loss
self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item())
# L1 loss on the region outside the mask
valid_loss = self.l1_loss(pred_img * (1 - masks), frames * (1 - masks))
# normalize by the unmasked area and scale by valid_weight from the config
valid_loss = valid_loss / torch.mean(1 - masks) * self.config['losses']['valid_weight']
gen_loss += valid_loss
self.add_summary(
self.gen_writer, 'loss/valid_loss', valid_loss.item())
gen_loss += valid_loss # accumulate into the generator loss
# log valid_loss
self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item())
# generator optimization step
self.optimG.zero_grad()
gen_loss.backward()
self.optimG.step()
# console logs
if self.config['global_rank'] == 0:
pbar.update(1)
pbar.set_description((
f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};"
pbar.update(1) # advance the progress bar
pbar.set_description(( # update the progress-bar description
f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};" # display the loss values
f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
)
# save the models
if self.iteration % self.train_args['save_freq'] == 0:
self.save(int(self.iteration // self.train_args['save_freq']))
# stop when the iteration limit is reached
if self.iteration > self.train_args['iterations']:
break
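Taken together, _train_epoch builds an area-normalized generator objective. A sketch with dummy tensors (shapes and weights are illustrative; the real weights come from the losses section of the config):

import torch
import torch.nn as nn

l1 = nn.L1Loss()
# Stand-ins for one flattened batch: b*t = 10 frames of 3 x 120 x 640.
pred_img = torch.rand(10, 3, 120, 640)
frames = torch.rand(10, 3, 120, 640)
masks = (torch.rand(10, 1, 120, 640) > 0.9).float()

# Dividing by the mean mask value keeps each term comparable no matter how
# much of the frame is masked, mirroring the hunk above.
hole_loss = l1(pred_img * masks, frames * masks) / torch.mean(masks)
valid_loss = l1(pred_img * (1 - masks), frames * (1 - masks)) / torch.mean(1 - masks)
gen_loss = hole_loss + valid_loss  # hole_weight and valid_weight assumed to be 1.0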