# -*- coding: utf-8 -*-
import argparse
import importlib
import os

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# My libs
from core.utils import Stack, ToTorchFormatTensor

parser = argparse.ArgumentParser(description="STTN")
parser.add_argument("-v", "--video", type=str, required=True)
parser.add_argument("-m", "--mask", type=str, required=True)
parser.add_argument("-c", "--ckpt", type=str, required=True)
parser.add_argument("--model", type=str, default='sttn')
args = parser.parse_args()
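
# Example invocation (paths are illustrative, not part of the repo):
#   python <this script> -v input.mp4 -m masks/ -c checkpoints/sttn.pth
# --mask points at a directory of per-frame mask images (see read_mask below).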

w, h = 432, 240      # STTN operates at a fixed 432x240 resolution
ref_length = 10      # sample a reference frame every 10 frames
neighbor_stride = 5  # radius of the local temporal window
default_fps = 24     # fps of the output video

_to_tensors = transforms.Compose([
    Stack(),
    ToTorchFormatTensor()])


# sample reference frames from the whole video
def get_ref_index(neighbor_ids, length):
    ref_index = []
    for i in range(0, length, ref_length):
        if i not in neighbor_ids:
            ref_index.append(i)
    return ref_index
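
# Worked example (numbers illustrative): with length=30 and
# neighbor_ids=[0, 1, ..., 10], the stride-10 candidates are frames 0, 10, 20;
# 0 and 10 fall inside the local window, leaving ref_index=[20].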


# read frame-wise masks
def read_mask(mpath):
    masks = []
    mnames = os.listdir(mpath)
    mnames.sort()
    for m in mnames:
        m = Image.open(os.path.join(mpath, m))
        m = m.resize((w, h), Image.NEAREST)
        # binarize: any non-zero pixel marks a hole to be inpainted
        m = np.array(m.convert('L'))
        m = np.array(m > 0).astype(np.uint8)
        # dilate so the mask fully covers the edges of the subtitle region
        m = cv2.dilate(m, cv2.getStructuringElement(
            cv2.MORPH_CROSS, (3, 3)), iterations=4)
        masks.append(Image.fromarray(m * 255))
    return masks
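
# Expected mask layout (illustrative): one mask image per video frame, sorted
# by filename (e.g. masks/00000.png, masks/00001.png, ...); any non-zero pixel
# marks the subtitle region to remove.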


# read frames from video
def read_frame_from_videos(vname):
    frames = []
    vidcap = cv2.VideoCapture(vname)
    success, image = vidcap.read()
    while success:
        # OpenCV decodes to BGR; convert to RGB for PIL and the model
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        frames.append(image.resize((w, h)))
        success, image = vidcap.read()
    vidcap.release()
    return frames
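
# Note: the whole clip is decoded into memory at 432x240, so memory use grows
# linearly with the frame count; very long videos may need chunked processing.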


def main_worker():
    # set up models
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = importlib.import_module('model.' + args.model)
    model = net.InpaintGenerator().to(device)
    data = torch.load(args.ckpt, map_location=device)
    model.load_state_dict(data['netG'])
    print('loading from: {}'.format(args.ckpt))
    model.eval()

    # prepare dataset, encode all frames into deep space
    frames = read_frame_from_videos(args.video)
    video_length = len(frames)
    # scale frames from [0, 1] to the model's [-1, 1] range
    feats = _to_tensors(frames).unsqueeze(0) * 2 - 1
    frames = [np.array(f).astype(np.uint8) for f in frames]

    masks = read_mask(args.mask)
    binary_masks = [np.expand_dims((np.array(m) != 0).astype(np.uint8), 2) for m in masks]
    masks = _to_tensors(masks).unsqueeze(0)
    feats, masks = feats.to(device), masks.to(device)
    comp_frames = [None] * video_length

    with torch.no_grad():
        # zero out masked pixels, then encode every frame once; the per-window
        # loop below only indexes into these cached features
        feats = model.encoder((feats * (1 - masks).float()).view(video_length, 3, h, w))
        _, c, feat_h, feat_w = feats.size()
        feats = feats.view(1, video_length, c, feat_h, feat_w)
    print('loading videos and masks from: {}'.format(args.video))

    # completing holes by spatial-temporal transformers
    for f in range(0, video_length, neighbor_stride):
        neighbor_ids = [i for i in range(max(0, f - neighbor_stride),
                                         min(video_length, f + neighbor_stride + 1))]
        ref_ids = get_ref_index(neighbor_ids, video_length)
        with torch.no_grad():
            pred_feat = model.infer(
                feats[0, neighbor_ids + ref_ids, :, :, :],
                masks[0, neighbor_ids + ref_ids, :, :, :])
            pred_img = torch.tanh(model.decoder(
                pred_feat[:len(neighbor_ids), :, :, :])).detach()
            # map predictions from [-1, 1] back to [0, 255]
            pred_img = (pred_img + 1) / 2
            pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
        for i in range(len(neighbor_ids)):
            idx = neighbor_ids[i]
            # paste the prediction into the hole; keep original pixels elsewhere
            img = np.array(pred_img[i]).astype(
                np.uint8) * binary_masks[idx] + frames[idx] * (1 - binary_masks[idx])
            if comp_frames[idx] is None:
                comp_frames[idx] = img
            else:
                # average the overlapping window predictions for this frame
                comp_frames[idx] = comp_frames[idx].astype(
                    np.float32) * 0.5 + img.astype(np.float32) * 0.5
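
    # With neighbor_stride = 5, windows start at f = 0, 5, 10, ... and each one
    # covers frames [f-5, f+5] (clipped to the clip bounds), so most frames are
    # predicted by two consecutive windows and blended by the 0.5/0.5 average above.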

    # write the completed frames back out; note the output uses the fixed
    # default_fps (24), not the source video's fps
    writer = cv2.VideoWriter(f"{args.mask}_result.mp4",
                             cv2.VideoWriter_fourcc(*"mp4v"), default_fps, (w, h))
    for f in range(video_length):
        comp = np.array(comp_frames[f]).astype(
            np.uint8) * binary_masks[f] + frames[f] * (1 - binary_masks[f])
        # frames are RGB in memory; OpenCV expects BGR
        writer.write(cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_RGB2BGR))
    writer.release()
    print('Finish in {}'.format(f"{args.mask}_result.mp4"))


if __name__ == '__main__':
    main_worker()