mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-22 22:27:33 +08:00
添加注释
This commit is contained in:
@@ -3,39 +3,54 @@ import torch.nn as nn
|
||||
|
||||
|
||||
class AdversarialLoss(nn.Module):
|
||||
r"""
|
||||
Adversarial loss
|
||||
https://arxiv.org/abs/1711.10337
|
||||
"""
|
||||
对抗性损失
|
||||
根据论文 https://arxiv.org/abs/1711.10337 实现
|
||||
"""
|
||||
|
||||
def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
|
||||
r"""
|
||||
type = nsgan | lsgan | hinge
|
||||
"""
|
||||
可以选择的损失类型有 'nsgan' | 'lsgan' | 'hinge'
|
||||
type: 指定使用哪种类型的 GAN 损失。
|
||||
target_real_label: 真实图像的目标标签值。
|
||||
target_fake_label: 生成图像的目标标签值。
|
||||
"""
|
||||
super(AdversarialLoss, self).__init__()
|
||||
self.type = type
|
||||
self.type = type # 损失类型
|
||||
# 使用缓冲区注册标签,这样在模型保存和加载时会一同保存和加载
|
||||
self.register_buffer('real_label', torch.tensor(target_real_label))
|
||||
self.register_buffer('fake_label', torch.tensor(target_fake_label))
|
||||
|
||||
# 根据选择的类型初始化不同的损失函数
|
||||
if type == 'nsgan':
|
||||
self.criterion = nn.BCELoss()
|
||||
self.criterion = nn.BCELoss() # 二进制交叉熵损失(非饱和GAN)
|
||||
elif type == 'lsgan':
|
||||
self.criterion = nn.MSELoss()
|
||||
self.criterion = nn.MSELoss() # 均方误差损失(最小平方GAN)
|
||||
elif type == 'hinge':
|
||||
self.criterion = nn.ReLU()
|
||||
self.criterion = nn.ReLU() # 适用于hinge损失的ReLU函数
|
||||
|
||||
def __call__(self, outputs, is_real, is_disc=None):
|
||||
"""
|
||||
调用函数计算损失。
|
||||
outputs: 网络输出。
|
||||
is_real: 如果是真实样本,则为 True;如果是生成样本,则为 False。
|
||||
is_disc: 指示当前是否在优化判别器。
|
||||
"""
|
||||
if self.type == 'hinge':
|
||||
# 对于 hinge 损失
|
||||
if is_disc:
|
||||
# 如果是判别器
|
||||
if is_real:
|
||||
outputs = -outputs
|
||||
outputs = -outputs # 对真实样本反向标签
|
||||
# max(0, 1 - (真/假)示例输出)
|
||||
return self.criterion(1 + outputs).mean()
|
||||
else:
|
||||
# 如果是生成器, -min(0, -输出) = max(0, 输出)
|
||||
return (-outputs).mean()
|
||||
else:
|
||||
# 对于 nsgan 和 lsgan 损失
|
||||
labels = (self.real_label if is_real else self.fake_label).expand_as(
|
||||
outputs)
|
||||
# 计算模型输出和目标标签之间的损失
|
||||
loss = self.criterion(outputs, labels)
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user