mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-22 18:04:43 +08:00
42 lines
1.2 KiB
Python
42 lines
1.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337

    Supported ``type`` values:

    * ``'nsgan'`` -- non-saturating GAN loss computed with ``nn.BCELoss``;
      ``outputs`` must already be probabilities in [0, 1] (BCELoss requirement).
    * ``'lsgan'`` -- least-squares GAN loss computed with ``nn.MSELoss``.
    * ``'hinge'`` -- hinge loss. Discriminator: ``mean(relu(1 - D))`` for real
      samples and ``mean(relu(1 + D))`` for fake samples; generator:
      ``mean(-D)``.
    """

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        r"""
        type = nsgan | lsgan | hinge

        Args:
            type: which adversarial objective to use (see class docstring).
            target_real_label: target value used for real samples
                (only used by 'nsgan' and 'lsgan').
            target_fake_label: target value used for fake samples
                (only used by 'nsgan' and 'lsgan').

        Raises:
            ValueError: if ``type`` is not one of 'nsgan', 'lsgan', 'hinge'.
        """
        super(AdversarialLoss, self).__init__()
        self.type = type
        # Registered as buffers (not parameters) so the labels move with the
        # module on .to()/.cuda() but are never seen by an optimizer.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))

        if type == 'nsgan':
            self.criterion = nn.BCELoss()
        elif type == 'lsgan':
            self.criterion = nn.MSELoss()
        elif type == 'hinge':
            self.criterion = nn.ReLU()
        else:
            # Fail fast: otherwise ``criterion`` is never set and the mistake
            # only surfaces as an AttributeError on the first loss computation.
            raise ValueError(
                f"Unsupported adversarial loss type {type!r}; "
                "expected 'nsgan', 'lsgan' or 'hinge'"
            )

    def forward(self, outputs, is_real, is_disc=None):
        """Compute the adversarial loss for discriminator ``outputs``.

        Implemented as ``forward`` (not ``__call__``) so that calling the
        module instance goes through ``nn.Module.__call__`` and any registered
        hooks run.

        Args:
            outputs: discriminator outputs (a tensor).
            is_real: whether ``outputs`` should be scored against the "real"
                target (for hinge generator mode this flag is ignored).
            is_disc: hinge only -- truthy when computing the discriminator's
                loss, falsy (default) when computing the generator's loss.

        Returns:
            Scalar loss tensor.
        """
        if self.type == 'hinge':
            if is_disc:
                # Real: relu(1 - D); fake: relu(1 + D).
                if is_real:
                    outputs = -outputs
                return self.criterion(1 + outputs).mean()
            # Generator: maximise D on generated samples.
            return (-outputs).mean()

        # Broadcast the scalar target buffer to the shape of ``outputs``.
        labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
        return self.criterion(outputs, labels)
|
|
|
|
|