mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-03-12 23:27:33 +08:00
init
This commit is contained in:
76
backend/inpaint/lama/bin/evaluator_example.py
Normal file
76
backend/inpaint/lama/bin/evaluator_example.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from skimage import io
|
||||
from skimage.transform import resize
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from saicinpainting.evaluation.evaluator import InpaintingEvaluator
|
||||
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
|
||||
|
||||
|
||||
class SimpleImageDataset(Dataset):
|
||||
def __init__(self, root_dir, image_size=(400, 600)):
|
||||
self.root_dir = root_dir
|
||||
self.files = sorted(os.listdir(root_dir))
|
||||
self.image_size = image_size
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_name = os.path.join(self.root_dir, self.files[index])
|
||||
image = io.imread(img_name)
|
||||
image = resize(image, self.image_size, anti_aliasing=True)
|
||||
image = torch.FloatTensor(image).permute(2, 0, 1)
|
||||
return image
|
||||
|
||||
def __len__(self):
|
||||
return len(self.files)
|
||||
|
||||
|
||||
def create_rectangle_mask(height, width):
|
||||
mask = np.ones((height, width))
|
||||
up_left_corner = width // 4, height // 4
|
||||
down_right_corner = (width - up_left_corner[0] - 1, height - up_left_corner[1] - 1)
|
||||
cv2.rectangle(mask, up_left_corner, down_right_corner, (0, 0, 0), thickness=cv2.FILLED)
|
||||
return mask
|
||||
|
||||
|
||||
class Model():
|
||||
def __call__(self, img_batch, mask_batch):
|
||||
mean = (img_batch * mask_batch[:, None, :, :]).sum(dim=(2, 3)) / mask_batch.sum(dim=(1, 2))[:, None]
|
||||
inpainted = mean[:, :, None, None] * (1 - mask_batch[:, None, :, :]) + img_batch * mask_batch[:, None, :, :]
|
||||
return inpainted
|
||||
|
||||
|
||||
class SimpleImageSquareMaskDataset(Dataset):
|
||||
def __init__(self, dataset):
|
||||
self.dataset = dataset
|
||||
self.mask = torch.FloatTensor(create_rectangle_mask(*self.dataset.image_size))
|
||||
self.model = Model()
|
||||
|
||||
def __getitem__(self, index):
|
||||
img = self.dataset[index]
|
||||
mask = self.mask.clone()
|
||||
inpainted = self.model(img[None, ...], mask[None, ...])
|
||||
return dict(image=img, mask=mask, inpainted=inpainted)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
|
||||
dataset = SimpleImageDataset('imgs')
|
||||
mask_dataset = SimpleImageSquareMaskDataset(dataset)
|
||||
model = Model()
|
||||
metrics = {
|
||||
'ssim': SSIMScore(),
|
||||
'lpips': LPIPSScore(),
|
||||
'fid': FIDScore()
|
||||
}
|
||||
|
||||
evaluator = InpaintingEvaluator(
|
||||
mask_dataset, scores=metrics, batch_size=3, area_grouping=True
|
||||
)
|
||||
|
||||
results = evaluator.evaluate(model)
|
||||
print(results)
|
||||
Reference in New Issue
Block a user