mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-27 22:24:42 +08:00
77 lines
2.3 KiB
Python
77 lines
2.3 KiB
Python
import os
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from skimage import io
|
|
from skimage.transform import resize
|
|
from torch.utils.data import Dataset
|
|
|
|
from saicinpainting.evaluation.evaluator import InpaintingEvaluator
|
|
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
|
|
|
|
|
|
class SimpleImageDataset(Dataset):
|
|
def __init__(self, root_dir, image_size=(400, 600)):
|
|
self.root_dir = root_dir
|
|
self.files = sorted(os.listdir(root_dir))
|
|
self.image_size = image_size
|
|
|
|
def __getitem__(self, index):
|
|
img_name = os.path.join(self.root_dir, self.files[index])
|
|
image = io.imread(img_name)
|
|
image = resize(image, self.image_size, anti_aliasing=True)
|
|
image = torch.FloatTensor(image).permute(2, 0, 1)
|
|
return image
|
|
|
|
def __len__(self):
|
|
return len(self.files)
|
|
|
|
|
|
def create_rectangle_mask(height, width):
|
|
mask = np.ones((height, width))
|
|
up_left_corner = width // 4, height // 4
|
|
down_right_corner = (width - up_left_corner[0] - 1, height - up_left_corner[1] - 1)
|
|
cv2.rectangle(mask, up_left_corner, down_right_corner, (0, 0, 0), thickness=cv2.FILLED)
|
|
return mask
|
|
|
|
|
|
class Model():
|
|
def __call__(self, img_batch, mask_batch):
|
|
mean = (img_batch * mask_batch[:, None, :, :]).sum(dim=(2, 3)) / mask_batch.sum(dim=(1, 2))[:, None]
|
|
inpainted = mean[:, :, None, None] * (1 - mask_batch[:, None, :, :]) + img_batch * mask_batch[:, None, :, :]
|
|
return inpainted
|
|
|
|
|
|
class SimpleImageSquareMaskDataset(Dataset):
|
|
def __init__(self, dataset):
|
|
self.dataset = dataset
|
|
self.mask = torch.FloatTensor(create_rectangle_mask(*self.dataset.image_size))
|
|
self.model = Model()
|
|
|
|
def __getitem__(self, index):
|
|
img = self.dataset[index]
|
|
mask = self.mask.clone()
|
|
inpainted = self.model(img[None, ...], mask[None, ...])
|
|
return dict(image=img, mask=mask, inpainted=inpainted)
|
|
|
|
def __len__(self):
|
|
return len(self.dataset)
|
|
|
|
|
|
dataset = SimpleImageDataset('imgs')
|
|
mask_dataset = SimpleImageSquareMaskDataset(dataset)
|
|
model = Model()
|
|
metrics = {
|
|
'ssim': SSIMScore(),
|
|
'lpips': LPIPSScore(),
|
|
'fid': FIDScore()
|
|
}
|
|
|
|
evaluator = InpaintingEvaluator(
|
|
mask_dataset, scores=metrics, batch_size=3, area_grouping=True
|
|
)
|
|
|
|
results = evaluator.evaluate(model)
|
|
print(results)
|