mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-06-11 11:33:12 +08:00
init
This commit is contained in:
44
backend/inpaint/sttn/core/loss.py
Normal file
44
backend/inpaint/sttn/core/loss.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
import os
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as models
|
||||
|
||||
|
||||
class AdversarialLoss(nn.Module):
|
||||
r"""
|
||||
Adversarial loss
|
||||
https://arxiv.org/abs/1711.10337
|
||||
"""
|
||||
|
||||
def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
|
||||
r"""
|
||||
type = nsgan | lsgan | hinge
|
||||
"""
|
||||
super(AdversarialLoss, self).__init__()
|
||||
self.type = type
|
||||
self.register_buffer('real_label', torch.tensor(target_real_label))
|
||||
self.register_buffer('fake_label', torch.tensor(target_fake_label))
|
||||
|
||||
if type == 'nsgan':
|
||||
self.criterion = nn.BCELoss()
|
||||
elif type == 'lsgan':
|
||||
self.criterion = nn.MSELoss()
|
||||
elif type == 'hinge':
|
||||
self.criterion = nn.ReLU()
|
||||
|
||||
def __call__(self, outputs, is_real, is_disc=None):
|
||||
if self.type == 'hinge':
|
||||
if is_disc:
|
||||
if is_real:
|
||||
outputs = -outputs
|
||||
return self.criterion(1 + outputs).mean()
|
||||
else:
|
||||
return (-outputs).mean()
|
||||
else:
|
||||
labels = (self.real_label if is_real else self.fake_label).expand_as(
|
||||
outputs)
|
||||
loss = self.criterion(outputs, labels)
|
||||
return loss
|
||||
|
||||
|
||||
Reference in New Issue
Block a user