mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-04 04:34:41 +08:00
57 lines
2.3 KiB
Python
57 lines
2.3 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
|
||
|
||
class AdversarialLoss(nn.Module):
|
||
"""
|
||
对抗性损失
|
||
根据论文 https://arxiv.org/abs/1711.10337 实现
|
||
"""
|
||
|
||
def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
|
||
"""
|
||
可以选择的损失类型有 'nsgan' | 'lsgan' | 'hinge'
|
||
type: 指定使用哪种类型的 GAN 损失。
|
||
target_real_label: 真实图像的目标标签值。
|
||
target_fake_label: 生成图像的目标标签值。
|
||
"""
|
||
super(AdversarialLoss, self).__init__()
|
||
self.type = type # 损失类型
|
||
# 使用缓冲区注册标签,这样在模型保存和加载时会一同保存和加载
|
||
self.register_buffer('real_label', torch.tensor(target_real_label))
|
||
self.register_buffer('fake_label', torch.tensor(target_fake_label))
|
||
|
||
# 根据选择的类型初始化不同的损失函数
|
||
if type == 'nsgan':
|
||
self.criterion = nn.BCELoss() # 二进制交叉熵损失(非饱和GAN)
|
||
elif type == 'lsgan':
|
||
self.criterion = nn.MSELoss() # 均方误差损失(最小平方GAN)
|
||
elif type == 'hinge':
|
||
self.criterion = nn.ReLU() # 适用于hinge损失的ReLU函数
|
||
|
||
def __call__(self, outputs, is_real, is_disc=None):
|
||
"""
|
||
调用函数计算损失。
|
||
outputs: 网络输出。
|
||
is_real: 如果是真实样本,则为 True;如果是生成样本,则为 False。
|
||
is_disc: 指示当前是否在优化判别器。
|
||
"""
|
||
if self.type == 'hinge':
|
||
# 对于 hinge 损失
|
||
if is_disc:
|
||
# 如果是判别器
|
||
if is_real:
|
||
outputs = -outputs # 对真实样本反向标签
|
||
# max(0, 1 - (真/假)示例输出)
|
||
return self.criterion(1 + outputs).mean()
|
||
else:
|
||
# 如果是生成器, -min(0, -输出) = max(0, 输出)
|
||
return (-outputs).mean()
|
||
else:
|
||
# 对于 nsgan 和 lsgan 损失
|
||
labels = (self.real_label if is_real else self.fake_label).expand_as(
|
||
outputs)
|
||
# 计算模型输出和目标标签之间的损失
|
||
loss = self.criterion(outputs, labels)
|
||
return loss
|