mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-15 20:34:45 +08:00
81 lines
2.1 KiB
Python
81 lines
2.1 KiB
Python
import os
|
|
import sys
|
|
import torch
|
|
import numpy as np
|
|
import cv2
|
|
from PIL import Image
|
|
from torch.hub import download_url_to_file, get_dir
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
# Source https://github.com/advimman/lama
|
|
def get_image(image):
    """Convert an image to a float32 channel-first array scaled by 1/255.

    Args:
        image: A ``PIL.Image.Image`` or a ``numpy.ndarray``. A 3-D array is
            treated as channel-last and transposed to (C, H, W); a 2-D array
            gets a leading channel axis of size 1.

    Returns:
        A new ``np.float32`` array of shape (C, H, W) holding the input
        values divided by 255. The input object is never modified.

    Raises:
        TypeError: If ``image`` is neither a PIL Image nor a numpy array
            (a subclass of ``Exception``, so existing broad handlers still work).
        ValueError: If the resulting array is not 2-D or 3-D.
    """
    # The two types are disjoint, so checking ndarray first does not change
    # which branch runs; it just keeps the ndarray path free of PIL lookups.
    if isinstance(image, np.ndarray):
        # No defensive copy needed: astype() below always allocates a new array.
        img = image
    elif isinstance(image, Image.Image):
        img = np.array(image)
    else:
        raise TypeError("Input image should be either PIL Image or numpy array!")

    if img.ndim == 3:
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    elif img.ndim == 2:
        img = img[np.newaxis, ...]
    else:
        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and would let malformed input through silently.
        raise ValueError(f"Expected a 2-D or 3-D image, got {img.ndim}-D input")

    return img.astype(np.float32) / 255
|
|
|
|
|
|
def ceil_modulo(x, mod):
    """Round ``x`` up to the next multiple of ``mod``.

    A value that is already an exact multiple of ``mod`` is returned as-is.
    """
    quotient, remainder = divmod(x, mod)
    return x if remainder == 0 else (quotient + 1) * mod
|
|
|
|
|
|
def scale_image(img, factor, interpolation=cv2.INTER_AREA):
    """Resize a channel-first image by ``factor`` on both spatial axes.

    Args:
        img: Array laid out as (C, H, W).
        factor: Scale applied to both width and height (``fx`` and ``fy``).
        interpolation: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        The resized array, again channel-first.
    """
    # cv2.resize expects HW or HWC, so either drop the singleton channel axis
    # or move channels to the end before resizing.
    hwc = img[0] if img.shape[0] == 1 else np.transpose(img, (1, 2, 0))

    resized = cv2.resize(hwc, dsize=None, fx=factor, fy=factor, interpolation=interpolation)

    # Back to channel-first: a 2-D result regains its leading channel axis.
    return resized[None, ...] if resized.ndim == 2 else np.transpose(resized, (2, 0, 1))
|
|
|
|
|
|
def pad_img_to_modulo(img, mod):
    """Pad a 3-D (C, H, W) array so that H and W become multiples of ``mod``.

    Padding is added only at the bottom and right edges, using numpy's
    ``"symmetric"`` mode; the leading (channel) axis is never padded.
    """
    _, height, width = img.shape
    pad_bottom = ceil_modulo(height, mod) - height
    pad_right = ceil_modulo(width, mod) - width
    pad_widths = ((0, 0), (0, pad_bottom), (0, pad_right))
    return np.pad(img, pad_widths, mode="symmetric")
|
|
|
|
|
|
def prepare_img_and_mask(image, mask, device, pad_out_to_modulo=8, scale_factor=None):
    """Turn an image and its mask into batched tensors on ``device``.

    Both inputs go through ``get_image`` (float32, channel-first, /255). They
    are optionally rescaled: the image uses the default interpolation and the
    mask uses nearest-neighbour. They are then optionally padded so that the
    spatial size is a multiple of ``pad_out_to_modulo``.

    Args:
        image: PIL Image or numpy array for the picture.
        mask: PIL Image or numpy array for the mask (presumably the same
            spatial size as ``image`` — not checked here).
        device: Target passed to ``Tensor.to``.
        pad_out_to_modulo: Padding multiple; padding is skipped when this is
            ``None`` or not greater than 1.
        scale_factor: Optional resize factor; ``None`` means no resizing.

    Returns:
        ``(image_tensor, mask_tensor)``, each with a leading batch axis of
        size 1. The mask is binarised to an integer tensor of 0/1 values.
    """
    img_arr = get_image(image)
    mask_arr = get_image(mask)

    if scale_factor is not None:
        img_arr = scale_image(img_arr, scale_factor)
        # Nearest-neighbour keeps the mask free of interpolated in-between values.
        mask_arr = scale_image(mask_arr, scale_factor, interpolation=cv2.INTER_NEAREST)

    should_pad = pad_out_to_modulo is not None and pad_out_to_modulo > 1
    if should_pad:
        img_arr, mask_arr = (pad_img_to_modulo(arr, pad_out_to_modulo) for arr in (img_arr, mask_arr))

    img_tensor = torch.from_numpy(img_arr).unsqueeze(0).to(device)
    mask_tensor = torch.from_numpy(mask_arr).unsqueeze(0).to(device)

    # Any positive mask value counts as "masked"; multiplying the bool result
    # by 1 yields an integer 0/1 tensor.
    mask_tensor = (mask_tensor > 0) * 1

    return img_tensor, mask_tensor
|