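"""Point-prompt segmentation with Segment Anything (SAM).

Loads a SAM checkpoint, predicts masks for a single image from point
prompts, and writes the masks plus point/mask visualizations to an
output directory. See the example usage under ``__main__`` below.
"""
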
import sys
import argparse
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt
from typing import List
import torch

from segment_anything import SamPredictor, sam_model_registry
from backend.inpaint.utils import load_img_to_array, save_array_to_img, dilate_mask, \
    show_mask, show_points


def predict_masks_with_sam(
        img: np.ndarray,
        point_coords: List[List[float]],
        point_labels: List[int],
        model_type: str,
        ckpt_p: str,
        device="cuda"
):
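    """Predict segmentation masks for a single image from point prompts.

    Builds the requested SAM model from the checkpoint at ``ckpt_p``,
    runs point-prompted prediction with ``multimask_output=True``, and
    returns the raw ``(masks, scores, logits)`` tuple from
    ``SamPredictor.predict``.
    """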
    point_coords = np.array(point_coords)
    point_labels = np.array(point_labels)
    sam = sam_model_registry[model_type](checkpoint=ckpt_p)
    sam.to(device=device)
    predictor = SamPredictor(sam)

    predictor.set_image(img)
    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,
    )
    return masks, scores, logits


def build_sam_model(model_type: str, ckpt_p: str, device="cuda"):
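    """Load a SAM checkpoint and return a ready-to-use ``SamPredictor``."""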
    sam = sam_model_registry[model_type](checkpoint=ckpt_p)
    sam.to(device=device)
    predictor = SamPredictor(sam)
    return predictor


def setup_args(parser):
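    """Register the command-line arguments used by the demo below."""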
    parser.add_argument(
        "--input_img", type=str, required=True,
        help="Path to a single input image.",
    )
    parser.add_argument(
        "--point_coords", type=float, nargs='+', required=True,
        help="The coordinates of the point prompt, [coord_W coord_H].",
    )
    parser.add_argument(
        "--point_labels", type=int, nargs='+', required=True,
        help="The labels of the point prompt, 1 or 0.",
    )
    parser.add_argument(
        "--dilate_kernel_size", type=int, default=None,
        help="Dilate kernel size. Default: None",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True,
        help="Output path to the directory with results.",
    )
    parser.add_argument(
        "--sam_model_type", type=str,
        default="vit_h", choices=['vit_h', 'vit_l', 'vit_b'],
        help="The type of SAM model to load. Default: 'vit_h'",
    )
    parser.add_argument(
        "--sam_ckpt", type=str, required=True,
        help="The path to the SAM checkpoint to use for mask generation.",
    )


if __name__ == "__main__":
    """Example usage:
    python sam_segment.py \
        --input_img FA_demo/FA1_dog.png \
        --point_coords 750 500 \
        --point_labels 1 \
        --dilate_kernel_size 15 \
        --output_dir ./results \
        --sam_model_type "vit_h" \
        --sam_ckpt sam_vit_h_4b8939.pth
    """
    parser = argparse.ArgumentParser()
    setup_args(parser)
    args = parser.parse_args(sys.argv[1:])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    img = load_img_to_array(args.input_img)

    masks, _, _ = predict_masks_with_sam(
        img,
        [args.point_coords],
        args.point_labels,
        model_type=args.sam_model_type,
        ckpt_p=args.sam_ckpt,
        device=device,
    )
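    # SAM returns boolean masks; convert to 8-bit 0/255 images for saving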
    masks = masks.astype(np.uint8) * 255

    # dilate mask to avoid unmasked edge effect
    if args.dilate_kernel_size is not None:
        masks = [dilate_mask(mask, args.dilate_kernel_size) for mask in masks]

    # visualize the segmentation results
    img_stem = Path(args.input_img).stem
    out_dir = Path(args.output_dir) / img_stem
    out_dir.mkdir(parents=True, exist_ok=True)
    for idx, mask in enumerate(masks):
        # path to the results
        mask_p = out_dir / f"mask_{idx}.png"
        img_points_p = out_dir / "with_points.png"
        img_mask_p = out_dir / f"with_{Path(mask_p).name}"

        # save the mask
        save_array_to_img(mask, mask_p)

        # save the pointed and masked image
        dpi = plt.rcParams['figure.dpi']
        height, width = img.shape[:2]
        plt.figure(figsize=(width / dpi / 0.77, height / dpi / 0.77))
        plt.imshow(img)
        plt.axis('off')
        show_points(plt.gca(), [args.point_coords], args.point_labels,
                    size=(width * 0.04) ** 2)
        plt.savefig(img_points_p, bbox_inches='tight', pad_inches=0)
        show_mask(plt.gca(), mask, random_color=False)
        plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
        plt.close()