mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-02 06:27:31 +08:00
修复LAMA模式100%卡死:帧区间扩展超出视频总帧数导致FramePrefetcher死锁
- 限制字幕区间 end 不超过 frame_count,防止内循环消费哨兵后外层永久阻塞
- LAMA 批量推理改为 mini-batch(4 帧),避免 GPU OOM
- 各 inpaint 模型在 inpaint_area 为空时返回原始帧
- FFmpeg 子进程添加 600s 超时保护

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -28,35 +28,42 @@ class LamaInpaint:
|
||||
return cur_res
|
||||
|
||||
def _inpaint_batch(self, images: List[np.ndarray], masks: List[np.ndarray]):
|
||||
"""批量推理:将多帧合并为一个 batch tensor 一次性送入 GPU"""
|
||||
"""批量推理:将多帧分小批次送入 GPU,避免单次推理过大导致卡死"""
|
||||
if len(images) == 1:
|
||||
return [self.inpaint(images[0], masks[0])]
|
||||
|
||||
orig_height, orig_width = images[0].shape[:2]
|
||||
batch_imgs = []
|
||||
batch_masks = []
|
||||
for img, msk in zip(images, masks):
|
||||
batch_imgs.append(get_image(img))
|
||||
batch_masks.append(get_image(msk))
|
||||
# 分小批次推理,每批最多 4 帧
|
||||
mini_batch_size = 4
|
||||
results = [None] * len(images)
|
||||
for start in range(0, len(images), mini_batch_size):
|
||||
end = min(start + mini_batch_size, len(images))
|
||||
batch_imgs = []
|
||||
batch_masks = []
|
||||
for i in range(start, end):
|
||||
batch_imgs.append(get_image(images[i]))
|
||||
batch_masks.append(get_image(masks[i]))
|
||||
|
||||
# 堆叠为 (B, C, H, W) 并 pad 到 8 的倍数
|
||||
batch_imgs = np.stack(batch_imgs)
|
||||
batch_masks = np.stack(batch_masks)
|
||||
padded_imgs = np.stack([pad_img_to_modulo(img, 8) for img in batch_imgs])
|
||||
padded_masks = np.stack([pad_img_to_modulo(m, 8) for m in batch_masks])
|
||||
|
||||
# 对每个样本做 pad
|
||||
padded_imgs = np.stack([pad_img_to_modulo(img, 8) for img in batch_imgs])
|
||||
padded_masks = np.stack([pad_img_to_modulo(m, 8) for m in batch_masks])
|
||||
img_tensor = torch.from_numpy(padded_imgs).to(self.device)
|
||||
mask_tensor = torch.from_numpy(padded_masks).to(self.device)
|
||||
mask_tensor = (mask_tensor > 0) * 1
|
||||
|
||||
img_tensor = torch.from_numpy(padded_imgs).to(self.device)
|
||||
mask_tensor = torch.from_numpy(padded_masks).to(self.device)
|
||||
mask_tensor = (mask_tensor > 0) * 1
|
||||
with torch.inference_mode():
|
||||
inpainted = self.model(img_tensor, mask_tensor)
|
||||
batch_results = inpainted.permute(0, 2, 3, 1).detach().cpu().numpy()
|
||||
batch_results = np.clip(batch_results * 255, 0, 255).astype('uint8')
|
||||
|
||||
with torch.inference_mode():
|
||||
inpainted = self.model(img_tensor, mask_tensor)
|
||||
results = inpainted.permute(0, 2, 3, 1).detach().cpu().numpy()
|
||||
results = np.clip(results * 255, 0, 255).astype('uint8')
|
||||
for i in range(end - start):
|
||||
results[start + i] = batch_results[i][:orig_height, :orig_width]
|
||||
|
||||
return [results[i][:orig_height, :orig_width] for i in range(len(images))]
|
||||
del img_tensor, mask_tensor, padded_imgs, padded_masks
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return results
|
||||
|
||||
def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
|
||||
"""
|
||||
@@ -98,6 +105,9 @@ class LamaInpaint:
|
||||
for k in range(len(inpaint_area)):
|
||||
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = comps[k][j]
|
||||
inpainted_frames.append(frame)
|
||||
else:
|
||||
# 无需处理的区域,返回原始帧
|
||||
inpainted_frames = frames_hr
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -413,6 +413,8 @@ class PropainterInpaint:
|
||||
# 将最终帧添加到列表
|
||||
inpainted_frames.append(frame)
|
||||
# print(f'processing frame, {len(frames_hr) - j} left')
|
||||
else:
|
||||
inpainted_frames = frames_hr
|
||||
return inpainted_frames
|
||||
|
||||
|
||||
|
||||
@@ -92,6 +92,8 @@ class STTNInpaint:
|
||||
# 将最终帧添加到列表
|
||||
inpainted_frames.append(frame)
|
||||
# print(f'processing frame, {len(frames_hr) - j} left')
|
||||
else:
|
||||
inpainted_frames = frames_hr
|
||||
return inpainted_frames
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -94,6 +94,8 @@ class STTNDetInpaint:
|
||||
# 将最终帧添加到列表
|
||||
inpainted_frames.append(frame)
|
||||
# print(f'processing frame, {len(frames_hr) - j} left')
|
||||
else:
|
||||
inpainted_frames = frames_hr
|
||||
return inpainted_frames
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -271,9 +271,9 @@ class SubtitleRemover:
|
||||
del sub_detector
|
||||
gc.collect()
|
||||
start_end_map = dict()
|
||||
for interval in continuous_frame_no_list:
|
||||
start, end = interval
|
||||
start_end_map[start] = end
|
||||
for start, end in continuous_frame_no_list:
|
||||
# 确保区间不超出视频总帧数,否则会导致 FramePrefetcher 哨兵被内循环消费后外层死锁
|
||||
start_end_map[start] = min(end, self.frame_count)
|
||||
current_frame_index = 0
|
||||
self.append_output(tr['Main']['ProcessingStartRemovingSubtitles'])
|
||||
# 使用帧预读取,I/O 与推理重叠
|
||||
@@ -423,7 +423,7 @@ class SubtitleRemover:
|
||||
"-vn", "-loglevel", "error", temp.name]
|
||||
use_shell = True if os.name == "nt" else False
|
||||
try:
|
||||
subprocess.check_output(audio_extract_command, stdin=open(os.devnull), shell=use_shell)
|
||||
subprocess.check_output(audio_extract_command, stdin=open(os.devnull), shell=use_shell, timeout=600)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
self.append_output(tr['Main']['FailToExtractAudio'].format(str(e)))
|
||||
@@ -437,7 +437,7 @@ class SubtitleRemover:
|
||||
"-acodec", "copy",
|
||||
"-loglevel", "error", self.video_out_path]
|
||||
try:
|
||||
subprocess.check_output(audio_merge_command, stdin=open(os.devnull), shell=use_shell)
|
||||
subprocess.check_output(audio_merge_command, stdin=open(os.devnull), shell=use_shell, timeout=600)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
self.append_output(tr['Main']['FailToMergeAudio'].format(str(e)))
|
||||
|
||||
@@ -97,4 +97,8 @@ class FFmpegVideoWriter:
|
||||
self._process.stdin.close()
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
self._process.wait()
|
||||
try:
|
||||
self._process.wait(timeout=600)
|
||||
except subprocess.TimeoutExpired:
|
||||
self._process.terminate()
|
||||
self._process.wait(timeout=5)
|
||||
|
||||
Reference in New Issue
Block a user