添加注释

This commit is contained in:
YaoFANGUK
2024-01-09 11:05:07 +08:00
parent 6b353455a0
commit a3dd7b797d
11 changed files with 271 additions and 154 deletions

View File

@@ -3,39 +3,54 @@ import torch.nn as nn
class AdversarialLoss(nn.Module):
r"""
Adversarial loss
https://arxiv.org/abs/1711.10337
"""
对抗性损失
根据论文 https://arxiv.org/abs/1711.10337 实现
"""
def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
r"""
type = nsgan | lsgan | hinge
"""
可以选择的损失类型有 'nsgan' | 'lsgan' | 'hinge'
type: 指定使用哪种类型的 GAN 损失。
target_real_label: 真实图像的目标标签值。
target_fake_label: 生成图像的目标标签值。
"""
super(AdversarialLoss, self).__init__()
self.type = type
self.type = type # 损失类型
# 使用缓冲区注册标签,这样在模型保存和加载时会一同保存和加载
self.register_buffer('real_label', torch.tensor(target_real_label))
self.register_buffer('fake_label', torch.tensor(target_fake_label))
# 根据选择的类型初始化不同的损失函数
if type == 'nsgan':
self.criterion = nn.BCELoss()
self.criterion = nn.BCELoss() # 二进制交叉熵损失非饱和GAN
elif type == 'lsgan':
self.criterion = nn.MSELoss()
self.criterion = nn.MSELoss() # 均方误差损失最小平方GAN
elif type == 'hinge':
self.criterion = nn.ReLU()
self.criterion = nn.ReLU() # 适用于hinge损失的ReLU函数
def __call__(self, outputs, is_real, is_disc=None):
"""
调用函数计算损失。
outputs: 网络输出。
is_real: 如果是真实样本,则为 True如果是生成样本则为 False。
is_disc: 指示当前是否在优化判别器。
"""
if self.type == 'hinge':
# 对于 hinge 损失
if is_disc:
# 如果是判别器
if is_real:
outputs = -outputs
outputs = -outputs # 对真实样本反向标签
# max(0, 1 - (真/假)示例输出)
return self.criterion(1 + outputs).mean()
else:
# 如果是生成器, -min(0, -输出) = max(0, 输出)
return (-outputs).mean()
else:
# 对于 nsgan 和 lsgan 损失
labels = (self.real_label if is_real else self.fake_label).expand_as(
outputs)
# 计算模型输出和目标标签之间的损失
loss = self.criterion(outputs, labels)
return loss