mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-22 22:27:33 +08:00
修复bug
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@@ -18,10 +19,15 @@ class LamaInpaint:
|
||||
self.device = device
|
||||
|
||||
def __call__(self, image: Union[Image.Image, np.ndarray], mask: Union[Image.Image, np.ndarray]):
|
||||
if isinstance(image, np.ndarray):
|
||||
orig_height, orig_width = image.shape[:2]
|
||||
else:
|
||||
orig_height, orig_width = np.array(image).shape[:2]
|
||||
image, mask = prepare_img_and_mask(image, mask, self.device)
|
||||
with torch.inference_mode():
|
||||
inpainted = self.model(image, mask)
|
||||
cur_res = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||
cur_res = np.clip(cur_res * 255, 0, 255).astype(np.uint8)
|
||||
cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
|
||||
cur_res = cur_res[:orig_height, :orig_width]
|
||||
return cur_res
|
||||
|
||||
|
||||
Reference in New Issue
Block a user