# Mirror of https://github.com/YaoFANGUK/video-subtitle-remover.git
# (synced 2026-05-13 06:37:32 +08:00).
# 80 lines, 1.8 KiB, Python, executable file.
import sys
|
|
import argparse
|
|
import os
|
|
import cv2
|
|
import glob
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from .raft import RAFT
|
|
from .utils import flow_viz
|
|
from .utils.utils import InputPadder
|
|
|
|
|
|
|
|
# Device that batches and models in this module are moved to. Prefer the GPU,
# but fall back to the CPU so the module still works on hosts without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
def load_image(imfile):
    """Read an image file into a channel-first float tensor.

    The image is decoded with PIL, converted to 8-bit RGB and returned as a
    float tensor of shape ``(3, H, W)``. Pixel values are left in the 0-255
    range; no normalisation is applied here.

    Args:
        imfile: Path or file object accepted by ``PIL.Image.open``.

    Returns:
        torch.Tensor: The image as a ``(3, H, W)`` float tensor.
    """
    # Image.open is lazy and keeps the file open; the context manager makes
    # sure the handle is released once the pixels have been copied out.
    with Image.open(imfile) as pil_img:
        # Force three channels so grayscale, palette and RGBA inputs do not
        # break the (H, W, C) -> (C, H, W) permute below.
        pixels = np.array(pil_img.convert('RGB')).astype(np.uint8)
    return torch.from_numpy(pixels).permute(2, 0, 1).float()
|
|
|
|
|
|
def load_image_list(image_files):
    """Load several image files into one padded batch tensor on ``DEVICE``.

    Files are processed in sorted path order, each read with ``load_image``,
    stacked along a new leading dimension, moved to ``DEVICE`` and padded by
    an ``InputPadder`` built from the batch shape.

    Args:
        image_files: Iterable of image paths.

    Returns:
        The first element of what ``InputPadder.pad`` returns for the batch.
    """
    frames = [load_image(path) for path in sorted(image_files)]
    batch = torch.stack(frames, dim=0).to(DEVICE)

    # The padder is configured from the shape of the whole batch.
    padded = InputPadder(batch.shape).pad(batch)
    return padded[0]
|
|
|
|
|
|
def viz(img, flo, out_path='/home/chengao/test/flow.png'):
    """Render a flow field as a colour image and write it to disk.

    Args:
        img: Source frame batch. It is not used for the rendering (only the
            flow image is written); the parameter is kept so existing
            callers keep working.
        flo: Flow batch. Only the first item is visualised; it is permuted
            from channel-first to channel-last before colour mapping.
        out_path: Destination file for the rendered flow. Defaults to the
            location that used to be hard-coded.
    """
    import warnings

    flo = flo[0].permute(1, 2, 0).cpu().numpy()

    # map flow to rgb image
    flow_rgb = flow_viz.flow_to_image(flo)

    # Reverse the channel axis before handing the image to OpenCV, which
    # stores colour images in BGR order (flow_to_image presumably returns
    # RGB -- confirm against flow_viz).
    written = cv2.imwrite(out_path, flow_rgb[:, :, [2, 1, 0]])

    # cv2.imwrite signals most failures (e.g. a missing directory) through
    # its return value rather than an exception; surface that to the user
    # instead of silently producing no file.
    if not written:
        warnings.warn(f'could not write flow visualisation to {out_path!r}')
|
|
|
|
|
|
def demo(args):
    """Run RAFT on consecutive frames of a directory and visualise the flow.

    Every ``*.png`` and ``*.jpg`` file in ``args.path`` is loaded (in sorted
    order, via ``load_image_list``) and the flow between each frame and its
    successor is passed to ``viz``.

    Args:
        args: Namespace providing ``path`` (directory holding the frames)
            and ``model`` (checkpoint path read by ``RAFT_infer``), plus
            whatever ``RAFT`` itself reads from it.

    Raises:
        RuntimeError: If ``args.path`` contains no ``.png``/``.jpg`` files.
    """
    # Reuse the shared loader instead of repeating its steps here.
    model = RAFT_infer(args)

    frame_paths = glob.glob(os.path.join(args.path, '*.png')) + \
        glob.glob(os.path.join(args.path, '*.jpg'))
    if not frame_paths:
        # Without this check an empty directory only fails later, when an
        # empty list is stacked into a batch, with a much less helpful error.
        raise RuntimeError(f'no .png or .jpg images found in {args.path!r}')

    with torch.no_grad():
        frames = load_image_list(frame_paths)

        # Pair every frame with its successor; indexing with None restores
        # the leading batch dimension of size one.
        for first, second in zip(frames[:-1], frames[1:]):
            image1 = first[None]
            image2 = second[None]

            # Only the second output (the upsampled flow) is visualised.
            _, flow_up = model(image1, image2, iters=20, test_mode=True)
            viz(image1, flow_up)
|
|
|
|
|
|
def RAFT_infer(args):
    """Build a RAFT model, load its checkpoint and prepare it for inference.

    Args:
        args: Namespace passed to ``RAFT``; ``args.model`` is the path of the
            checkpoint file to load.

    Returns:
        The unwrapped RAFT module, moved to ``DEVICE`` and set to eval mode.
    """
    # The weights are loaded through a DataParallel wrapper and the inner
    # module is extracted afterwards (presumably because the checkpoint keys
    # carry the wrapper's ``module.`` prefix -- confirm against the file).
    wrapped = torch.nn.DataParallel(RAFT(args))

    # map_location='cpu' lets a checkpoint saved on any device (for example
    # a GPU index that does not exist on this host) deserialize; the model
    # is moved to DEVICE below anyway.
    # NOTE(security): torch.load unpickles the file, so only load
    # checkpoints that come from a trusted source.
    state_dict = torch.load(args.model, map_location='cpu')
    wrapped.load_state_dict(state_dict)

    model = wrapped.module
    model.to(DEVICE)
    model.eval()

    return model
|