diff --git a/README.md b/README.md
index 6867f44..34d256b 100755
--- a/README.md
+++ b/README.md
@@ -1,5 +1,9 @@
 简体中文 | [English](README_en.md)
+
+ VSR Logo +
+ ## 项目简介 ![License](https://img.shields.io/badge/License-Apache%202-red.svg) diff --git a/README_en.md b/README_en.md index 330d555..fb8d1ef 100755 --- a/README_en.md +++ b/README_en.md @@ -1,5 +1,9 @@ [简体中文](README.md) | English +
+ VSR Logo +
+ ## Project Introduction ![License](https://img.shields.io/badge/License-Apache%202-red.svg) @@ -7,38 +11,34 @@ ![support os](https://img.shields.io/badge/OS-Windows/macOS/Linux-green.svg) [![Docker](https://img.shields.io/badge/Docker-Image-blue?logo=docker)](https://hub.docker.com/r/eritpchy/video-subtitle-remover) -Video-subtitle-remover (VSR) is an AI-based software that removes hardcoded subtitles from videos. It mainly implements the following functionalities: +Video-subtitle-remover (VSR) is an AI-based software that removes hardcoded subtitles from videos. +It mainly implements the following functionalities: +- **Lossless resolution**: Removes hardcoded subtitles from videos and generates files without subtitles +- Fills in the removed subtitle text area using a powerful AI algorithm model (non-adjacent pixel filling and mosaic removal) +- Supports custom subtitle positions by only removing subtitles in the defined location (input position) +- Supports automatic removal of all text throughout the entire video (without inputting a position) +- Supports multi-selection of images for batch removal of watermark text -- **Lossless resolution**: Removes hardcoded subtitles from videos and generates files without subtitles. -- Fills in the removed subtitle text area using a powerful AI algorithm model (non-adjacent pixel filling and mosaic removal). -- Supports custom subtitle positions by only removing subtitles in the defined location (input position). -- Supports automatic removal of all text throughout the entire video (without inputting a position). -- Supports multi-selection of images for batch removal of watermark text. +![demo.png](https://github.com/YaoFANGUK/video-subtitle-remover/raw/main/design/demo.png) -![demo.png](design/demo.png) +**Instructions:** -> Download the .zip package directly, extract, and run it. If it cannot run, follow the tutorial below to try installing the conda environment and running the source code. 
-
-**Download Links:**
-
-Windows GPU Version v1.1.0 (GPU):
-
-- Baidu Cloud Disk: vsr_windows_gpu_v1.1.0.zip Extraction Code: **vsr1**
-
-- Google Drive: vsr_windows_gpu_v1.1.0.zip
+- If you have questions, please join the discussion group: QQ Group 210150985 (full), 806152575 (full), 816881808 (full), 295894827
+- Download the compressed package, extract it, and run it directly. If it cannot run, follow the tutorial below to install from source.
 
+**Download:** Release
 
 **Pre-built Package Comparison**:
 
-| Pre-built Package Name            | Python | Paddle | Torch | Environment                | Supported Compute Capability Range |
-|-----------------------------------|--------|--------|-------|----------------------------|------------------------------------|
-| `vse-windows-cpu.7z`              | 3.12   | 3.0.0  | 2.7.0 | Universal                  | Universal                          |
-| `vse-windows-directml.7z`         | 3.12   | 3.0.0  | 2.4.1 | Windows without Nvidia GPU | Universal                          |
-| `vse-windows-nvidia-cuda-11.8.7z` | 3.12   | 3.0.0  | 2.7.0 | CUDA 11.8                  | 3.5 – 8.9                          |
-| `vse-windows-nvidia-cuda-12.6.7z` | 3.12   | 3.0.0  | 2.7.0 | CUDA 12.6                  | 5.0 – 8.9                          |
-| `vse-windows-nvidia-cuda-12.8.7z` | 3.12   | 3.0.0  | 2.7.0 | CUDA 12.8                  | 5.0 – 9.0+                         |
+| Pre-built Package Name            | Python | Paddle | Torch | Environment            | Supported Compute Capability Range |
+|-----------------------------------|--------|--------|-------|------------------------|------------------------------------|
+| `vsr-windows-cpu.7z`              | 3.12   | 3.0.0  | 2.7.0 | Universal              | Universal                          |
+| `vsr-windows-directml.7z`         | 3.12   | 3.0.0  | 2.4.1 | Windows non-Nvidia GPU | Universal                          |
+| `vsr-windows-nvidia-cuda-11.8.7z` | 3.12   | 3.0.0  | 2.7.0 | CUDA 11.8              | 3.5 – 8.9                          |
+| `vsr-windows-nvidia-cuda-12.6.7z` | 3.12   | 3.0.0  | 2.7.0 | CUDA 12.6              | 5.0 – 8.9                          |
+| `vsr-windows-nvidia-cuda-12.8.7z` | 3.12   | 3.0.0  | 2.7.0 | CUDA 12.8              | 5.0 – 9.0+                         |
 
-> NVIDIA provides a list of supported compute capabilities for each GPU model.
You can refer to the following link: [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) to check which CUDA version is compatible with your GPU.
+> NVIDIA provides a list of compute capabilities for each GPU model. Refer to [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) to check which CUDA version is compatible with your GPU.
 
 **Docker Versions:**
 
 ```shell
@@ -46,7 +46,7 @@ Windows GPU Version v1.1.0 (GPU):
     docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.4.0-cuda11.8 python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
 
     # Nvidia 40 Series Graphics Cards
-    docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.4.0-cuda12.6 python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4 
+    docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.4.0-cuda12.6 python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
 
     # Nvidia 50 Series Graphics Cards
     docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.4.0-cuda12.8 python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
@@ -57,11 +57,11 @@ Windows GPU Version v1.1.0 (GPU):
     # CPU
-    docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.4.0-cpu python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
+    docker run -it --name vsr eritpchy/video-subtitle-remover:1.4.0-cpu python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
 
-    # Copy to host
+    # Export video
     docker cp vsr:/vsr/test/test_no_sub.mp4 ./
     ```
 
-**Commandline:**
+**Command Line:**
 
 ```
 Video Subtitle Remover Command Line Tool
@@ -82,12 +82,11 @@ options:

demo2.gif

-- Click to view demo video👇
-

demo.gif

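The pre-built package table above pairs each build with a supported compute-capability range. As a rough sketch of that selection logic (the helper name and the float encoding of compute capability are illustrative assumptions; newer CUDA builds are preferred where the ranges overlap):

```python
def pick_prebuilt_package(compute_capability: float, has_nvidia_gpu: bool = True) -> str:
    """Map an NVIDIA compute capability to a pre-built package name (illustrative helper)."""
    if not has_nvidia_gpu:
        # The DirectML build covers Windows machines without an NVIDIA GPU
        return "vsr-windows-directml.7z"
    if 3.5 <= compute_capability < 5.0:
        # Only the CUDA 11.8 build supports capabilities 3.5 - 4.x
        return "vsr-windows-nvidia-cuda-11.8.7z"
    if compute_capability >= 5.0:
        # CUDA 12.8 covers 5.0 - 9.0+, so prefer the newest toolkit
        return "vsr-windows-nvidia-cuda-12.8.7z"
    # GPUs below compute capability 3.5 fall back to the CPU build
    return "vsr-windows-cpu.7z"
```

Check your GPU's compute capability on the [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) page before choosing a package.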
 ## Source Code Usage Instructions
+
 #### 1. Install Python
 
 Please ensure that you have installed Python 3.12+.
@@ -159,7 +158,6 @@ This project supports four running modes: CUDA (NVIDIA GPU acceleration), CPU (n
     ```shell
     pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
     ```
-
 - Install Torch GPU version (CUDA 11.8):
     ```shell
     pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu118
     ```
@@ -170,7 +168,7 @@ This project supports four running modes: CUDA (NVIDIA GPU acceleration), CPU (n
     pip install -r requirements.txt
     ```
 
-- For Linux systems, you also need to install
+- For Linux systems, you also need to install:
 
     ```shell
     # for cuda 12.x
@@ -187,9 +185,8 @@ This project supports four running modes: CUDA (NVIDIA GPU acceleration), CPU (n
     ```shell
     pip install paddlepaddle==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
     pip install -r requirements.txt
-    pip install -r requirements_directml.txt
+    pip install torch_directml==0.2.5.dev240914
     ```
-
 ##### (3) CPU Only (For systems without GPU or those not wanting to use GPU acceleration)
 - Suitable for systems without GPU or those that do not wish to use GPU.
@@ -198,7 +195,6 @@ This project supports four running modes: CUDA (NVIDIA GPU acceleration), CPU (n
     pip install torch==2.7.0 torchvision==0.22.0
     pip install -r requirements.txt
     ```
-
 ##### (4) Running on macOS (Apple Silicon)
 - Suitable for macOS (Apple Silicon) devices
 - For macOS (Intel), please use the CPU mode. Forcing GPU usage will only be slower.
@@ -209,7 +205,6 @@ This project supports four running modes: CUDA (NVIDIA GPU acceleration), CPU (n
     pip install -r requirements.txt
     ```
 > Tested with Python 3.13
-
 #### 4. Run the program
 
 - Run the graphical interface
@@ -225,23 +220,21 @@
 python ./backend/main.py
 ```
 
 ## Common Issues
-
 1.
How to deal with slow removal speed
 
 You can greatly increase the removal speed by modifying the parameters in backend/config.py:
-
 ```python
 MODE = InpaintMode.STTN  # Set to STTN algorithm
-STTN_SKIP_DETECTION = True  # Skip subtitle detection
+STTN_SKIP_DETECTION = True  # Skip subtitle detection; faster, but may miss some subtitles or damage frames that contain none
 ```
 
 2. What to do if the video removal results are not satisfactory
 
 Modify the values in backend/config.py and try different removal algorithms. Here is an introduction to the algorithms:
 
-> - **InpaintMode.STTN** algorithm: Good for live-action videos and fast in speed, capable of skipping subtitle detection
-> - **InpaintMode.LAMA** algorithm: Best for images and effective for animated videos, moderate speed, unable to skip subtitle detection
-> - **InpaintMode.PROPAINTER** algorithm: Consumes a significant amount of VRAM, slower in speed, works better for videos with very intense movement
+> - InpaintMode.STTN algorithm: Good for live-action videos and fast, capable of skipping subtitle detection
+> - InpaintMode.LAMA algorithm: Best for images and effective for animated videos, moderate speed, unable to skip subtitle detection
+> - InpaintMode.PROPAINTER algorithm: Consumes a significant amount of VRAM, slower, works better for videos with very intense movement
 
 - Using the STTN algorithm
 
@@ -256,21 +249,18 @@
 STTN_REFERENCE_LENGTH = 10
 STTN_MAX_LOAD_NUM = 30
 ```
 
 - Using the LAMA algorithm
-
 ```python
 MODE = InpaintMode.LAMA  # Set to LAMA algorithm
 LAMA_SUPER_FAST = False  # Ensure quality
 ```
+> If you are not satisfied with the subtitle removal results, check the training methods in the design folder, train with the code in backend/tools/train, and then replace the old model with the newly trained one.
 
-3. CondaHTTPError
-
-Place the .condarc file from the project in the user directory (C:/Users/). If the file already exists in the user directory, overwrite it.
-
-Solution: https://zhuanlan.zhihu.com/p/260034241
-
-4. 7z file extraction error
+3. 7z file extraction error
 
 Solution: Upgrade the 7-zip extraction program to the latest version.
 
+## Sponsor
+
+
diff --git a/backend/inpaint/lama_inpaint.py b/backend/inpaint/lama_inpaint.py
index e68d1ac..10ad96a 100644
--- a/backend/inpaint/lama_inpaint.py
+++ b/backend/inpaint/lama_inpaint.py
@@ -1,9 +1,10 @@
 import os
+import gc
 from typing import Union, List
 import torch
 import numpy as np
 from PIL import Image
-from backend.inpaint.utils.lama_util import prepare_img_and_mask
+from backend.inpaint.utils.lama_util import prepare_img_and_mask, get_image, pad_img_to_modulo
 from backend import config
 from backend.tools.inpaint_tools import get_inpaint_area_by_mask
@@ -26,6 +27,37 @@ class LamaInpaint:
         cur_res = cur_res[:orig_height, :orig_width]
         return cur_res
 
+    def _inpaint_batch(self, images: List[np.ndarray], masks: List[np.ndarray]):
+        """Batch inference: stack multiple frames into one batch tensor and run a single GPU forward pass"""
+        if len(images) == 1:
+            return [self.inpaint(images[0], masks[0])]
+
+        orig_height, orig_width = images[0].shape[:2]
+        batch_imgs = []
+        batch_masks = []
+        for img, msk in zip(images, masks):
+            batch_imgs.append(get_image(img))
+            batch_masks.append(get_image(msk))
+
+        # Stack into (B, C, H, W)
+        batch_imgs = np.stack(batch_imgs)
+        batch_masks = np.stack(batch_masks)
+
+        # Pad each sample to a multiple of 8
+        padded_imgs = np.stack([pad_img_to_modulo(img, 8) for img in batch_imgs])
+        padded_masks = np.stack([pad_img_to_modulo(m, 8) for m in batch_masks])
+
+        img_tensor = torch.from_numpy(padded_imgs).to(self.device)
+        mask_tensor = torch.from_numpy(padded_masks).to(self.device)
+        mask_tensor = (mask_tensor > 0) * 1
+
+        with torch.inference_mode():
+            inpainted = self.model(img_tensor, mask_tensor)
+        results = inpainted.permute(0, 2, 3, 1).detach().cpu().numpy()
+        results = np.clip(results * 255, 0, 255).astype('uint8')
+
+        return [results[i][:orig_height, :orig_width] for i in range(len(images))]
+
     def __call__(self,
                 input_frames: List[np.ndarray], input_mask: np.ndarray):
         """
        :param input_frames: original video frames
@@ -38,48 +70,37 @@ class LamaInpaint:
         # Determine the vertical band used for subtitle removal
         split_h = int(W_ori * 3 / 16)
         inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask)
-        # Initialize frame storage variables
-        # List of high-resolution frames (shallow copy + per-frame copy to avoid deepcopy overhead)
+        # List of high-resolution frames
         frames_hr = [f.copy() for f in input_frames]
-        frames_scaled = {}  # Dict of scaled frames
-        masks_scaled = {}  # Dict of scaled masks
         comps = {}  # Dict of inpainted frames
         # Store the final video frames
         inpainted_frames = []
-        for k in range(len(inpaint_area)):
-            frames_scaled[k] = []  # Initialize a list for each removal region
-            masks_scaled[k] = []  # Initialize a list for each removal region
-        # Read and scale the frames
-        for j in range(len(frames_hr)):
-            image = frames_hr[j]
-            # Crop and scale each removal region
-            for k in range(len(inpaint_area)):
-                image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :]  # crop
-                mask_crop = mask[inpaint_area[k][0]:inpaint_area[k][1], :, :]  # crop
-                frames_scaled[k].append(image_crop)  # append the cropped frame to its list
-                masks_scaled[k].append(mask_crop)  # append the cropped mask to its list
-
-        # Process each removal region
         for k in range(len(inpaint_area)):
-            # Call inpaint frame by frame
-            comps[k] = []
-            for i in range(len(frames_scaled[k])):
-                inpainted_frame = self.inpaint(frames_scaled[k][i], masks_scaled[k][i])
-                comps[k].append(inpainted_frame)
+            # Collect all cropped frames and masks for this region
+            cropped_frames = []
+            cropped_masks = []
+            for j in range(len(frames_hr)):
+                image_crop = frames_hr[j][inpaint_area[k][0]:inpaint_area[k][1], :, :]
+                mask_crop = mask[inpaint_area[k][0]:inpaint_area[k][1], :, :]
+                cropped_frames.append(image_crop)
+                cropped_masks.append(mask_crop)
+
+            # Batch inference
+            comps[k] = self._inpaint_batch(cropped_frames, cropped_masks)
+            del cropped_frames, cropped_masks
+            gc.collect()
 
         # If there are regions to remove
         if inpaint_area:
             for j in range(len(frames_hr)):
-                frame = frames_hr[j]  # take the original frame
-                # For each removal segment
+                frame = frames_hr[j]
                 for k in range(len(inpaint_area)):
-                    comp = comps[k][j]  # the inpainted crop
-                    # Blend the image inside the masked region
-                    frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = comp
-                # Append the final frame to the list
+                    frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = comps[k][j]
                 inpainted_frames.append(frame)
-                # print(f'processing frame, {len(frames_hr) - j} left')
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
         return inpainted_frames
diff --git a/backend/inpaint/propainter_inpaint.py b/backend/inpaint/propainter_inpaint.py
index 726f713..8fffc58 100644
--- a/backend/inpaint/propainter_inpaint.py
+++ b/backend/inpaint/propainter_inpaint.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 import os
+import gc
 import cv2
-import copy
 import numpy as np
 import scipy.ndimage
 from PIL import Image
@@ -374,7 +374,7 @@ class PropainterInpaint:
         inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask, multiple=8)
         # Initialize frame storage variables
         # List of high-resolution frames
-        frames_hr = copy.deepcopy(input_frames)
+        frames_hr = [f.copy() for f in input_frames]
         frames_scaled = {}  # Dict of scaled frames
         masks_scaled = {}  # Dict of scaled masks
         comps = {}  # Dict of inpainted frames
@@ -398,6 +398,8 @@ class PropainterInpaint:
         for k in range(len(inpaint_area)):
             # Call the inpaint function
             comps[k] = self.inpaint(frames_scaled[k], masks_scaled[k][0])
+            del frames_scaled[k], masks_scaled[k]
+            gc.collect()
 
         # If there are regions to remove
         if inpaint_area:
diff --git a/backend/inpaint/sttn_auto_inpaint.py b/backend/inpaint/sttn_auto_inpaint.py
index bc25af9..ccc14c5 100644
--- a/backend/inpaint/sttn_auto_inpaint.py
+++ b/backend/inpaint/sttn_auto_inpaint.py
@@ -1,6 +1,7 @@
 import os
 import time
 import sys
+import gc
 from typing import List
 
 import cv2
@@ -15,6 +16,8 @@ from backend.config import config
 from backend.inpaint.sttn.auto_sttn import InpaintGenerator
 from backend.inpaint.utils.sttn_utils import Stack, ToTorchFormatTensor
 from backend.tools.inpaint_tools import get_inpaint_area_by_mask, is_frame_number_in_ab_sections
+from backend.tools.video_io import FramePrefetcher
+from backend.tools.hardware_accelerator import HardwareAccelerator
 
 # Define the image preprocessing pipeline
 _to_tensors = transforms.Compose([
@@ -125,7 +128,7 @@ class STTNInpaint:
         feats = feats.to(self.device)
         # Initialize a list of the same length as the video to store finished frames
         comp_frames = [None] * frame_length
-        #
Disable gradient computation for inference to save memory and speed up
+        # Disable gradient computation once for the whole inference pass, saving memory and speeding up
         with torch.no_grad():
             # Pass the prepared frames through the encoder to produce feature representations
             feats = self.model.encoder(feats.view(frame_length, 3, self.model_input_height, self.model_input_width))
@@ -133,33 +136,27 @@ class STTNInpaint:
             _, c, feat_h, feat_w = feats.size()
             # Reshape the features to match the model's expected input
             feats = feats.view(1, frame_length, c, feat_h, feat_w)
-        # Get the repaint region
-        # Loop over the video with the configured neighbor stride
-        for f in range(0, frame_length, self.neighbor_stride):
-            # Compute the IDs of neighboring frames
-            neighbor_ids = [i for i in range(max(0, f - self.neighbor_stride), min(frame_length, f + self.neighbor_stride + 1))]
-            # Get the reference frame indices
-            ref_ids = self.get_ref_index(neighbor_ids, frame_length)
-            # Disable gradient computation here as well
-            with torch.no_grad():
+            # Loop over the video with the configured neighbor stride
+            for f in range(0, frame_length, self.neighbor_stride):
+                # Compute the IDs of neighboring frames
+                neighbor_ids = [i for i in range(max(0, f - self.neighbor_stride), min(frame_length, f + self.neighbor_stride + 1))]
+                # Get the reference frame indices
+                ref_ids = self.get_ref_index(neighbor_ids, frame_length)
                 # Run the model on the features and pass them to the decoder to generate completed frames
                 pred_feat = self.model.infer(feats[0, neighbor_ids + ref_ids, :, :, :])
-                # Decode the predicted features into images with tanh, then detach the tensor
-                pred_img = torch.tanh(self.model.decoder(pred_feat[:len(neighbor_ids), :, :, :])).detach()
-                # Rescale the result tensor to the 0-255 range (image pixel values)
+                # Decode the predicted features into images and apply tanh
+                pred_img = torch.tanh(self.model.decoder(pred_feat[:len(neighbor_ids), :, :, :]))
+                # Rescale the result tensor to the 0-255 range
                 pred_img = (pred_img + 1) / 2
                 # Move the tensor back to the CPU and convert it to a NumPy array
                 pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
                 # Iterate over the neighboring frames
                 for i in range(len(neighbor_ids)):
                     idx = neighbor_ids[i]
-                    # Convert the predicted image to unsigned 8-bit integers
-                    img = np.array(pred_img[i]).astype(np.uint8)
+                    img = pred_img[i].astype(np.uint8)
                     if comp_frames[idx] is None:
-                        # If this slot is empty, assign the newly computed image
                         comp_frames[idx] = img
                     else:
-                        # If an image already exists here, blend old and new to improve quality
                         comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5
         # Return the completed frame sequence
         return comp_frames
@@ -203,6 +200,8 @@ class STTNAutoInpaint:
         try:
             # Read video frame info
             reader, frame_info = self.read_frame_info_from_video()
+            #
Use frame prefetching so I/O overlaps with inference
+            prefetcher = FramePrefetcher(reader)
 
             if input_sub_remover is not None:
                 ab_sections = input_sub_remover.ab_sections
@@ -212,24 +211,35 @@ class STTNAutoInpaint:
             # Create the video writer for the repaired output
             writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori']))
-            # Compute how many repair iterations are needed
-            rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1
             # Compute the split height used to size the repair region
             split_h = int(frame_info['W_ori'] * 3 / 16)
-
+
             if input_mask is None:
                 # Read the mask
                 mask = self.sttn_inpaint.read_mask(self.mask_path)
             else:
                 _, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY)
                 mask = mask[:, :, None]
-
+
             # Get the repair region positions
             inpaint_area = get_inpaint_area_by_mask(frame_info['W_ori'], frame_info['H_ori'], split_h, mask)
+            # Dynamically adjust clip_gap to the available VRAM to avoid OOM
+            effective_clip_gap = self.clip_gap
+            vram_mb = HardwareAccelerator.instance().get_available_vram_mb()
+            if vram_mb > 0:
+                # Rough estimate: each frame needs about W * H * 3 * 4 bytes, so clip_gap frames need about clip_gap * W * H * 12 bytes (including intermediate tensors)
+                bytes_per_frame = frame_info['W_ori'] * frame_info['H_ori'] * 12
+                max_frames_by_vram = int(vram_mb * 1024 * 1024 / bytes_per_frame)
+                max_frames_by_vram = max(max_frames_by_vram, 10)  # at least 10 frames
+                effective_clip_gap = min(self.clip_gap, max_frames_by_vram)
+                if effective_clip_gap < self.clip_gap:
+                    tqdm.write(f'GPU VRAM: {vram_mb:.0f}MB, adjusting clip_gap: {self.clip_gap} -> {effective_clip_gap}')
+            # Compute how many repair iterations are needed
+            rec_time = frame_info['len'] // effective_clip_gap if frame_info['len'] % effective_clip_gap == 0 else frame_info['len'] // effective_clip_gap + 1
             # Iterate over the chunks
             for i in range(rec_time):
-                start_f = i * self.clip_gap  # start frame index
-                end_f = min((i + 1) * self.clip_gap, frame_info['len'])  # end frame index
+                start_f = i * effective_clip_gap  # start frame index
+                end_f = min((i + 1) * effective_clip_gap, frame_info['len'])  # end frame index
                 tqdm.write(f'Processing: {start_f + 1} - {end_f} / Total: {frame_info['len']}')
                 frames_hr =
[]  # List of high-resolution frames
@@ -243,7 +253,7 @@ class STTNAutoInpaint:
                 # Read and repair the high-resolution frames
                 valid_frames_count = 0
                 for j in range(start_f, end_f):
-                    success, image = reader.read()
+                    success, image = prefetcher.read()
                     if not success:
                         print(f"Warning: Failed to read frame {j}.")
                         break
@@ -309,10 +319,17 @@ class STTNAutoInpaint:
                         input_sub_remover.update_progress(tbar, increment=1)
                     if original_frame is not None and input_sub_remover.gui_mode:
                         input_sub_remover.update_preview_with_comp(original_frame, frame)
+                # Free GPU cache after each chunk
+                del frames_hr, frames, comps
+                gc.collect()
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
         except Exception as e:
             print(f"Error during video processing: {str(e)}")
             # Do not re-raise; allow the program to continue
         finally:
+            if reader:
+                prefetcher.release()
             if writer:
                 writer.release()
diff --git a/backend/inpaint/sttn_det_inpaint.py b/backend/inpaint/sttn_det_inpaint.py
index 730154b..65b8311 100644
--- a/backend/inpaint/sttn_det_inpaint.py
+++ b/backend/inpaint/sttn_det_inpaint.py
@@ -126,38 +126,35 @@ class STTNDetInpaint:
         frame_length = len(frames)
         # Preprocess the frames into tensors and normalize
         feats = _to_tensors(frames).unsqueeze(0) * 2 - 1
-
+
         binary_masks = [np.expand_dims((np.array(m) > 0.5).astype(np.uint8), 2) for m in masks]
         # Convert the masks to a tensor
-        masks = (_to_tensors(masks).unsqueeze(0) > 0.5).float()
-
+        masks_tensor = (_to_tensors(masks).unsqueeze(0) > 0.5).float()
+
         # Move the tensors to the target device (CPU or GPU)
-        feats, masks = feats.to(self.device), masks.to(self.device)
+        feats, masks_tensor = feats.to(self.device), masks_tensor.to(self.device)
         # Initialize a list of the same length as the video to store finished frames
         comp_frames = [None] * frame_length
-        # Disable gradient computation for inference to save memory and speed up
+        # Disable gradient computation once for the whole inference pass, saving memory and speeding up
         with torch.no_grad():
             # Pass the prepared frames through the encoder to produce feature representations
-            feats = self.model.encoder((feats*(1-masks).float()).view(frame_length, 3, self.model_input_height, self.model_input_width))
+            feats = self.model.encoder((feats*(1-masks_tensor).float()).view(frame_length, 3, self.model_input_height, self.model_input_width))
             # Get the feature dimensions
             _, c, feat_h, feat_w = feats.size()
             #
Reshape the features to match the model's expected input
             feats = feats.view(1, frame_length, c, feat_h, feat_w)
-            # Get the repaint region
-            # Loop over the video with the configured neighbor stride
-            for f in range(0, frame_length, self.neighbor_stride):
-                # Compute the IDs of neighboring frames
-                neighbor_ids = [i for i in range(max(0, f - self.neighbor_stride), min(frame_length, f + self.neighbor_stride + 1))]
-                # Get the reference frame indices
-                ref_ids = self.get_ref_index(neighbor_ids, frame_length)
-                # Disable gradient computation here as well
-                with torch.no_grad():
+            # Loop over the video with the configured neighbor stride
+            for f in range(0, frame_length, self.neighbor_stride):
+                # Compute the IDs of neighboring frames
+                neighbor_ids = [i for i in range(max(0, f - self.neighbor_stride), min(frame_length, f + self.neighbor_stride + 1))]
+                # Get the reference frame indices
+                ref_ids = self.get_ref_index(neighbor_ids, frame_length)
                 # Run the model on the features and pass them to the decoder to generate completed frames
                 pred_feat = self.model.infer(
-                    feats[0, neighbor_ids + ref_ids, :, :, :], masks[0, neighbor_ids + ref_ids, :, :, :])
+                    feats[0, neighbor_ids + ref_ids, :, :, :], masks_tensor[0, neighbor_ids + ref_ids, :, :, :])
 
-                # Decode the predicted features into images with tanh, then detach the tensor
-                pred_img = torch.tanh(self.model.decoder(pred_feat[:len(neighbor_ids), :, :, :])).detach()
+                # Decode the predicted features into images and apply tanh
+                pred_img = torch.tanh(self.model.decoder(pred_feat[:len(neighbor_ids), :, :, :]))
                 # Rescale the result tensor to the 0-255 range (image pixel values)
                 pred_img = (pred_img + 1) / 2
                 # Move the tensor back to the CPU and convert it to a NumPy array
@@ -166,13 +163,10 @@ class STTNDetInpaint:
                 for i in range(len(neighbor_ids)):
                     idx = neighbor_ids[i]
                     # Convert the predicted image to unsigned 8-bit integers
-                    img = np.array(pred_img[i]).astype(
-                        np.uint8)*binary_masks[idx] + frames[idx] * (1-binary_masks[idx])
+                    img = pred_img[i].astype(np.uint8) * binary_masks[idx] + frames[idx] * (1 - binary_masks[idx])
                     if comp_frames[idx] is None:
-                        # If this slot is empty, assign the newly computed image
                         comp_frames[idx] = img
                     else:
-                        # If an image already exists here, blend old and new to improve quality
                         comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5
         # Return the completed frame sequence
         return comp_frames
diff --git a/backend/tools/hardware_accelerator.py b/backend/tools/hardware_accelerator.py
index 5cc0e67..e4e0d31 100644
--- a/backend/tools/hardware_accelerator.py
+++
b/backend/tools/hardware_accelerator.py
@@ -106,6 +106,27 @@ class HardwareAccelerator:
     def set_enabled(self, enable):
         self.__enabled = enable
 
+    def get_available_vram_mb(self):
+        """Return the available GPU VRAM in MB, or 0 when no GPU is available"""
+        if not self.__enabled:
+            return 0
+        if self.__cuda:
+            try:
+                free_vram = torch.cuda.mem_get_info()[0]  # (free, total)
+                return free_vram / (1024 * 1024)
+            except Exception:
+                return 0
+        if self.__mps:
+            try:
+                # MPS has no direct query API; use system memory (unified on Apple Silicon) as a reference
+                import subprocess
+                result = subprocess.run(['sysctl', '-n', 'hw.memsize'], capture_output=True, text=True)
+                total_mem = int(result.stdout.strip()) / (1024 * 1024)
+                return total_mem * 0.5  # conservatively assume half is available
+            except Exception:
+                return 0
+        return 0
+
     @property
     def device(self):
         """
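The VRAM-aware chunk sizing introduced in `sttn_auto_inpaint.py` above can be sketched in isolation. The function name is illustrative; the ~12 bytes per pixel per frame cost model and the 10-frame floor mirror the diff:

```python
def effective_clip_gap(clip_gap: int, width: int, height: int, vram_mb: float) -> int:
    """Cap the frames processed per chunk by available VRAM (sketch of the heuristic above)."""
    if vram_mb <= 0:
        # No GPU or unknown VRAM: keep the configured chunk size
        return clip_gap
    # Rough cost model: ~12 bytes per pixel per frame (3 float32 channels)
    bytes_per_frame = width * height * 12
    max_frames_by_vram = int(vram_mb * 1024 * 1024 / bytes_per_frame)
    # Never shrink below 10 frames, matching the floor in the diff
    max_frames_by_vram = max(max_frames_by_vram, 10)
    return min(clip_gap, max_frames_by_vram)
```

For a 1920x1080 video with only 100 MB of free VRAM, this clamps a configured chunk of 50 frames down to the 10-frame floor, trading throughput for stability instead of running out of memory.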