# Mirror of https://github.com/YaoFANGUK/video-subtitle-remover.git
# Synced 2026-03-07 03:07:33 +08:00
# 197 lines, 6.1 KiB, Python
import os
|
|
import sys
|
|
import numpy as np
|
|
import torch
|
|
import yaml
|
|
import glob
|
|
import argparse
|
|
from omegaconf import OmegaConf
|
|
from pathlib import Path
|
|
|
|
# Request a single thread from each of the numeric runtimes addressed by
# these variables. All five are set to the string '1' in one call.
os.environ.update({
    'OMP_NUM_THREADS': '1',
    'OPENBLAS_NUM_THREADS': '1',
    'MKL_NUM_THREADS': '1',
    'VECLIB_MAXIMUM_THREADS': '1',
    'NUMEXPR_NUM_THREADS': '1',
})
|
|
|
|
from backend.inpaint.lama.saicinpainting.evaluation.utils import move_to_device
|
|
from backend.inpaint.lama.saicinpainting.training.trainers import load_checkpoint
|
|
from backend.inpaint.lama.saicinpainting.evaluation.data import pad_tensor_to_modulo
|
|
from backend.inpaint.utils import load_img_to_array, save_array_to_img
|
|
|
|
|
|
@torch.no_grad()
def inpaint_img_with_lama(
        img: np.ndarray,
        mask: np.ndarray,
        config_p: str,
        ckpt_p: str,
        mod=8,
        device="cuda"
):
    """Inpaint the masked region of ``img`` with a LaMa model loaded from disk.

    The model is loaded from ``ckpt_p`` on every call. To process many
    images with one model, load it once with ``build_lama_model`` and call
    ``inpaint_img_with_builded_lama`` instead.

    Args:
        img: Image array indexed as (height, width, channels); values are
            divided by 255 before being fed to the model.
        mask: 2-D array. Every entry greater than zero marks a pixel to
            inpaint.
        config_p: Path to the prediction config loaded with OmegaConf. It
            must provide ``model.checkpoint`` and ``out_key``.
        ckpt_p: Directory holding ``config.yaml`` and
            ``models/<model.checkpoint>``.
        mod: Modulus passed to ``pad_tensor_to_modulo`` for both image and
            mask.
        device: Anything accepted by ``torch.device``.

    Returns:
        ``uint8`` array cropped back to the height and width of ``img``.

    Raises:
        AssertionError: If ``mask`` is not 2-D.
    """
    assert len(mask.shape) == 2, "mask must be a 2-D array"
    # A 0/1 mask is rescaled to 0/255; the mask is re-binarised further down.
    if np.max(mask) == 1:
        mask = mask * 255
    image_t = torch.from_numpy(img).float().div(255.)
    mask_t = torch.from_numpy(mask).float()

    predict_config = OmegaConf.load(config_p)
    predict_config.model.path = ckpt_p
    device = torch.device(device)

    train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
    with open(train_config_path, 'r') as f:
        train_config = OmegaConf.create(yaml.safe_load(f))

    train_config.training_model.predict_only = True
    train_config.visualizer.kind = 'noop'

    checkpoint_path = os.path.join(
        predict_config.model.path, 'models',
        predict_config.model.checkpoint
    )
    # Load on the CPU first, then move to the requested device.
    model = load_checkpoint(
        train_config, checkpoint_path, strict=False, map_location='cpu')
    model.freeze()
    # The batch below is always moved to `device` and this function has no
    # separate refinement path, so the model must always live there too.
    # Moving it only when `refine` was falsy left model and inputs on
    # different devices whenever the config enabled `refine`.
    model.to(device)

    # (H, W, C) -> (1, C, H, W) for the image, (H, W) -> (1, 1, H, W) for
    # the mask.
    image_t = image_t.permute(2, 0, 1).unsqueeze(0)
    mask_t = mask_t[None, None]
    # Remember the unpadded size so the padding can be cropped off again.
    orig_height, orig_width = image_t.shape[2], image_t.shape[3]

    batch = {
        'image': pad_tensor_to_modulo(image_t, mod),
        'mask': pad_tensor_to_modulo(mask_t, mod),
    }
    batch = move_to_device(batch, device)
    batch['mask'] = (batch['mask'] > 0) * 1

    batch = model(batch)
    cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
    cur_res = cur_res.detach().cpu().numpy()
    cur_res = cur_res[:orig_height, :orig_width]

    return np.clip(cur_res * 255, 0, 255).astype('uint8')
|
|
|
|
|
|
def build_lama_model(
        config_p: str,
        ckpt_p: str,
        device="cuda"
):
    """Load a LaMa model once so it can be reused across many images.

    Args:
        config_p: Path to the prediction config loaded with OmegaConf. It
            must provide ``model.checkpoint``.
        ckpt_p: Directory holding the training ``config.yaml``.
        device: Anything accepted by ``torch.device``; the model is moved
            there before it is returned.

    Returns:
        The loaded model, moved to ``device`` and frozen.
    """
    predict_config = OmegaConf.load(config_p)
    predict_config.model.path = ckpt_p
    device = torch.device(device)

    train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
    with open(train_config_path, 'r') as f:
        train_config = OmegaConf.create(yaml.safe_load(f))

    train_config.training_model.predict_only = True
    train_config.visualizer.kind = 'noop'

    # NOTE(review): this resolves to a `models` directory next to `ckpt_p`,
    # whereas inpaint_img_with_lama looks for `models` inside `ckpt_p`.
    # Confirm which layout is intended; the path is left unchanged here.
    checkpoint_path = os.path.join(
        predict_config.model.path, '../models',
        predict_config.model.checkpoint
    )
    # Load the weights on the CPU first (as inpaint_img_with_lama does) and
    # move them afterwards, so loading does not depend on the device the
    # checkpoint was saved from.
    model = load_checkpoint(
        train_config, checkpoint_path, strict=False, map_location='cpu')
    model.to(device)
    model.freeze()
    return model
|
|
|
|
|
|
@torch.no_grad()
def inpaint_img_with_builded_lama(
        model,
        img: np.ndarray,
        mask: np.ndarray,
        config_p=None,
        mod=8,
        device="cuda"
):
    """Inpaint the masked region of ``img`` with an already loaded model.

    Args:
        model: Callable that takes the batch dict (keys ``'image'`` and
            ``'mask'``) and returns a mapping with an ``"inpainted"`` entry.
        img: Image array indexed as (height, width, channels); values are
            divided by 255 before being fed to the model.
        mask: 2-D array. Every entry greater than zero marks a pixel to
            inpaint.
        config_p: Not used by this function.
        mod: Modulus passed to ``pad_tensor_to_modulo`` for both image and
            mask.
        device: Device the batch is moved to before the model is called.

    Returns:
        ``uint8`` array cropped back to the height and width of ``img``.
    """
    assert len(mask.shape) == 2
    if np.max(mask) == 1:
        mask = mask * 255

    # (H, W, C) -> (1, C, H, W) for the image, (H, W) -> (1, 1, H, W) for
    # the mask.
    image_t = torch.from_numpy(img).float().div(255.).permute(2, 0, 1).unsqueeze(0)
    mask_t = torch.from_numpy(mask).float()[None, None]
    # Size before padding, used to crop the prediction at the end.
    orig_height, orig_width = image_t.shape[2], image_t.shape[3]

    batch = move_to_device(
        {
            'image': pad_tensor_to_modulo(image_t, mod),
            'mask': pad_tensor_to_modulo(mask_t, mod),
        },
        device,
    )
    batch['mask'] = (batch['mask'] > 0) * 1

    batch = model(batch)
    prediction = batch["inpainted"][0].permute(1, 2, 0).detach().cpu().numpy()
    prediction = prediction[:orig_height, :orig_width]

    return np.clip(prediction * 255, 0, 255).astype('uint8')
|
|
|
|
|
|
|
|
def setup_args(parser):
    """Register the command-line options of the inpainting script on ``parser``.

    Adds ``--input_img``, ``--input_mask_glob``, ``--output_dir`` and
    ``--lama_ckpt`` as required string options, and ``--lama_config`` as an
    optional string option with a default path.
    """
    required_str = {"type": str, "required": True}
    option_specs = (
        ("--input_img", {
            **required_str,
            "help": "Path to a single input img",
        }),
        ("--input_mask_glob", {
            **required_str,
            "help": "Glob to input masks",
        }),
        ("--output_dir", {
            **required_str,
            "help": "Output path to the directory with results.",
        }),
        ("--lama_config", {
            "type": str,
            "default": "./lama/configs/prediction/default.yaml",
            "help": "The path to the config file of lama model. "
                    "Default: the config of big-lama",
        }),
        ("--lama_ckpt", {
            **required_str,
            "help": "The path to the lama checkpoint.",
        }),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
|
|
|
|
|
|
if __name__ == "__main__":
    """Example usage:
    python lama_inpaint.py \
        --input_img FA_demo/FA1_dog.png \
        --input_mask_glob "results/FA1_dog/mask*.png" \
        --output_dir results \
        --lama_config lama/configs/prediction/default.yaml \
        --lama_ckpt big-lama
    """
    cli = argparse.ArgumentParser()
    setup_args(cli)
    args = cli.parse_args(sys.argv[1:])
    # Use the GPU when one is available, otherwise run on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Collect the masks first, then create <output_dir>/<input image stem>.
    mask_paths = sorted(glob.glob(args.input_mask_glob))
    out_dir = Path(args.output_dir) / Path(args.input_img).stem
    out_dir.mkdir(parents=True, exist_ok=True)

    img = load_img_to_array(args.input_img)
    # One output image per mask, named after the mask file.
    for mask_path in mask_paths:
        mask = load_img_to_array(mask_path)
        target_path = out_dir / f"inpainted_with_{Path(mask_path).name}"
        result = inpaint_img_with_lama(
            img, mask, args.lama_config, args.lama_ckpt, device=device)
        save_array_to_img(result, target_path)