修复bug

This commit is contained in:
YaoFANGUK
2023-12-13 19:50:15 +08:00
parent 6b817bd57a
commit 29c5317a69
3 changed files with 16 additions and 9 deletions

View File

@@ -1,6 +1,7 @@
import os
from typing import Union
import cv2
import torch
import numpy as np
from PIL import Image
@@ -18,10 +19,15 @@ class LamaInpaint:
self.device = device
def __call__(self, image: Union[Image.Image, np.ndarray], mask: Union[Image.Image, np.ndarray]):
if isinstance(image, np.ndarray):
orig_height, orig_width = image.shape[:2]
else:
orig_height, orig_width = np.array(image).shape[:2]
image, mask = prepare_img_and_mask(image, mask, self.device)
with torch.inference_mode():
inpainted = self.model(image, mask)
cur_res = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = np.clip(cur_res * 255, 0, 255).astype(np.uint8)
cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
cur_res = cur_res[:orig_height, :orig_width]
return cur_res