diff --git a/.github/workflows/build-docker.yml b/.github/workflows/build-docker.yml
index affdb8c..8487224 100644
--- a/.github/workflows/build-docker.yml
+++ b/.github/workflows/build-docker.yml
@@ -36,6 +36,8 @@ jobs:
version: "12.8"
- type: directml
version: "latest"
+ - type: cpu
+ version: "latest"
steps:
@@ -94,4 +96,5 @@ jobs:
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- repository: ${{ secrets.DOCKERHUB_USERNAME }}/video-subtitle-remover
\ No newline at end of file
+ repository: ${{ secrets.DOCKERHUB_USERNAME }}/video-subtitle-remover
+ enable-url-completion: true
\ No newline at end of file
diff --git a/.github/workflows/build-windows-cpu.yml b/.github/workflows/build-windows-cpu.yml
new file mode 100644
index 0000000..1156e2e
--- /dev/null
+++ b/.github/workflows/build-windows-cpu.yml
@@ -0,0 +1,98 @@
+name: Build Windows CPU
+
+on:
+ push:
+ branches:
+ - '**'
+ workflow_dispatch:
+ inputs:
+ ssh:
+ description: 'SSH connection to Actions'
+ required: false
+ default: false
+
+permissions:
+ contents: write
+
+jobs:
+ build:
+ runs-on: windows-2019
+ steps:
+ - uses: actions/checkout@v4
+ - name: 读取 VERSION
+ id: version
+ run: |
+ VERSION=$(sed -n 's/^VERSION = "\(.*\)"/\1/p' backend/config.py)
+ echo "VERSION=$VERSION" >> $GITHUB_ENV
+ echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
+ shell: bash
+ # - name: 检查 tag 是否已存在
+ # run: |
+ # TAG_NAME="${VERSION}"
+ # if git ls-remote --tags origin | grep -q "refs/tags/$TAG_NAME"; then
+ # echo "Tag $TAG_NAME 已存在,发布中止"
+ # exit 1
+ # fi
+ # shell: bash
+ - uses: actions/setup-python@v5
+ with:
+ python-version: '3.12'
+ cache: 'pip' # caching pip dependencies
+ - name: 禁用硬件加速
+ run: sed -i 's/HARDWARD_ACCELERATION_OPTION *= *.*/HARDWARD_ACCELERATION_OPTION = False/g' backend/config.py
+ shell: bash
+ - run: pip install paddlepaddle==3.0.0
+ - run: pip install torch==2.7.0 torchvision==0.22.0
+ - run: pip install -r requirements.txt
+ - run: pip freeze > requirements.txt
+ - run: pip install QPT==1.0b8 setuptools
+ - name: 获取 site-packages 路径
+ shell: bash
+ run: |
+ SITE_PACKAGES=$(python -c "import site, os; print(os.path.join(site.getsitepackages()[0], 'Lib', 'site-packages'))")
+ SITE_PACKAGES_UNIX=$(cygpath -u "$SITE_PACKAGES")
+ echo "site-packages路径: $SITE_PACKAGES"
+ echo "site-packages UNIX路径: $SITE_PACKAGES_UNIX"
+ echo "SITE_PACKAGES_UNIX=$SITE_PACKAGES_UNIX" >> $GITHUB_ENV
+ echo "SITE_PACKAGES=$SITE_PACKAGES" >> $GITHUB_ENV
+ - name: 修复QPT内部错误
+ run: sed -i '98c\ try:\n dep = pkg.requires()\n except TypeError:\n continue' ${SITE_PACKAGES_UNIX}/qpt/kernel/qpackage.py
+ shell: bash
+ - name: Start SSH via tmate
+ if: (github.event.inputs.ssh == 'true' && github.event.inputs.ssh != 'false') || contains(github.event.action, 'ssh')
+ uses: mxschmitt/action-tmate@v3
+ - run: |
+ python backend/tools/makedist.py && \
+ mv ../vsr_out ./vsr_out && \
+ cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/ && \
+ sed -i 's/force=False)/force=False)\n exit(-1)/g' ./vsr_out/*/Python/Lib/site-packages/qpt/executor.py
+ env:
+ QPT_Action: "True"
+ shell: bash
+ - name: 上传 Debug 文件夹到 Artifacts
+ uses: actions/upload-artifact@v4
+ with:
+ name: vsr-v${{ env.VERSION }}-windows-cpu-debug
+ path: vsr_out/Debug/
+ - name: 上传 Release 文件夹到 Artifacts
+ uses: actions/upload-artifact@v4
+ with:
+ name: vsr-v${{ env.VERSION }}-windows-cpu-release
+ path: vsr_out/Release/
+ - name: 打包 Release 文件夹
+ run: |
+ cd vsr_out/Release
+ 7z a -t7z -mx=9 -m0=LZMA2 -ms=on -mfb=64 -md=32m -mmt=on -v2000m vsr-v${{ env.VERSION }}-windows-cpu.7z * && \
+ # 检测是否只有一个分卷
+ if [ -f vsr-v${{ env.VERSION }}-windows-cpu.7z.001 ] && [ ! -f vsr-v${{ env.VERSION }}-windows-cpu.7z.002 ]; then \
+ mv vsr-v${{ env.VERSION }}-windows-cpu.7z.001 vsr-v${{ env.VERSION }}-windows-cpu.7z; fi
+ shell: bash
+ - name: Release
+ uses: softprops/action-gh-release@v1
+ with:
+ prerelease: true
+ tag_name: ${{ env.VERSION }}
+ target_commitish: ${{ github.sha }}
+ name: 硬字幕去除器 v${{ env.VERSION }}
+ files: |
+ vsr_out/Release/vsr-v${{ env.VERSION }}-windows-cpu.7z*
\ No newline at end of file
diff --git a/.github/workflows/build-windows-cuda-11.8.yml b/.github/workflows/build-windows-cuda-11.8.yml
index 81d3458..cc5acb9 100644
--- a/.github/workflows/build-windows-cuda-11.8.yml
+++ b/.github/workflows/build-windows-cuda-11.8.yml
@@ -61,7 +61,8 @@ jobs:
- run: |
python backend/tools/makedist.py --cuda 11.8 && \
mv ../vsr_out ./vsr_out && \
- cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/
+ cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/ && \
+ sed -i 's/force=False)/force=False)\n exit(-1)/g' ./vsr_out/*/Python/Lib/site-packages/qpt/executor.py
env:
QPT_Action: "True"
shell: bash
diff --git a/.github/workflows/build-windows-cuda-12.6.yml b/.github/workflows/build-windows-cuda-12.6.yml
index 5da6eb6..24ea342 100644
--- a/.github/workflows/build-windows-cuda-12.6.yml
+++ b/.github/workflows/build-windows-cuda-12.6.yml
@@ -61,7 +61,8 @@ jobs:
- run: |
python backend/tools/makedist.py --cuda 12.6 && \
mv ../vsr_out ./vsr_out && \
- cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/
+ cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/ && \
+ sed -i 's/force=False)/force=False)\n exit(-1)/g' ./vsr_out/*/Python/Lib/site-packages/qpt/executor.py
env:
QPT_Action: "True"
shell: bash
diff --git a/.github/workflows/build-windows-cuda-12.8.yml b/.github/workflows/build-windows-cuda-12.8.yml
index fecc588..b78e44c 100644
--- a/.github/workflows/build-windows-cuda-12.8.yml
+++ b/.github/workflows/build-windows-cuda-12.8.yml
@@ -61,7 +61,8 @@ jobs:
- run: |
python backend/tools/makedist.py --cuda 12.8 && \
mv ../vsr_out ./vsr_out && \
- cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/
+ cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/ && \
+ sed -i 's/force=False)/force=False)\n exit(-1)/g' ./vsr_out/*/Python/Lib/site-packages/qpt/executor.py
env:
QPT_Action: "True"
shell: bash
diff --git a/.github/workflows/build-windows-directml.yml b/.github/workflows/build-windows-directml.yml
index f81dd06..c19a3d1 100644
--- a/.github/workflows/build-windows-directml.yml
+++ b/.github/workflows/build-windows-directml.yml
@@ -61,7 +61,8 @@ jobs:
- run: |
python backend/tools/makedist.py --directml && \
mv ../vsr_out ./vsr_out && \
- cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/
+ cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/ && \
+ sed -i 's/force=False)/force=False)\n exit(-1)/g' ./vsr_out/*/Python/Lib/site-packages/qpt/executor.py
env:
QPT_Action: "True"
shell: bash
diff --git a/.gitignore b/.gitignore
index 9a3671c..e69685b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -369,8 +369,9 @@ test_*.mp4
test*_no_sub*.mp4
/test/coods/
/local_test/
-/backend/models/video/ProPainter.pth
+/backend/models/propainter/ProPainter.pth
/backend/models/big-lama/big-lama.pt
/test/debug/
/backend/tools/train/release_model/
-model.onnx
\ No newline at end of file
+model.onnx
+/config/config.json
diff --git a/README.md b/README.md
index 1540904..b2ccea9 100755
--- a/README.md
+++ b/README.md
@@ -4,7 +4,8 @@


-
+
+[](https://hub.docker.com/r/eritpchy/video-subtitle-remover)
Video-subtitle-remover (VSR) 是一款基于AI技术,将视频中的硬字幕去除的软件。
主要实现了以下功能:
@@ -14,7 +15,7 @@ Video-subtitle-remover (VSR) 是一款基于AI技术,将视频中的硬字幕
- 支持全视频自动去除所有文本(不传入位置)
- 支持多选图片批量去除水印文本
-

+
**使用说明:**
@@ -32,32 +33,52 @@ Windows GPU版本v1.1.0(GPU):
**预构建包对比说明**:
| 预构建包名 | Python | Paddle | Torch | 环境 | 支持的计算能力范围|
|---------------|------------|--------------|--------------|-----------------------------|----------|
-| `vsr-windows-directml.7z` | 3.12 | 3.0.0 | 2.4.1 | Windows 非Nvidia显卡 | 通用 |
-| `vsr-windows-nvidia-cuda-11.8.7z` | 3.12 | 3.0.0 | 2.7.0 | CUDA 11.8 | 3.5 – 8.9 |
-| `vsr-windows-nvidia-cuda-12.6.7z` | 3.12 | 3.0.0 | 2.7.0 | CUDA 12.6 | 5.0 – 8.9 |
-| `vsr-windows-nvidia-cuda-12.8.7z` | 3.12 | 3.0.0 | 2.7.0 | CUDA 12.8 | 5.0 – 9.0+ |
+| `vsr-windows-cpu.7z` | 3.12 | 3.0.0 | 2.7.0 | 通用 | 通用 |
+| `vsr-windows-directml.7z` | 3.12 | 3.0.0 | 2.4.1 | Windows 非Nvidia显卡 | 通用 |
+| `vsr-windows-nvidia-cuda-11.8.7z` | 3.12 | 3.0.0 | 2.7.0 | CUDA 11.8 | 3.5 – 8.9 |
+| `vsr-windows-nvidia-cuda-12.6.7z` | 3.12 | 3.0.0 | 2.7.0 | CUDA 12.6 | 5.0 – 8.9 |
+| `vsr-windows-nvidia-cuda-12.8.7z` | 3.12 | 3.0.0 | 2.7.0 | CUDA 12.8 | 5.0 – 9.0+ |
> NVIDIA官方提供了各GPU型号的计算能力列表,您可以参考链接: [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) 查看你的GPU适合哪个CUDA版本
**Docker版本:**
```shell
# Nvidia 10 20 30系显卡
- docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda11.8
+ docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.2.0-cuda11.8 python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
# Nvidia 40系显卡
- docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda12.6
+ docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.2.0-cuda12.6 python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
# Nvidia 50系显卡
- docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda12.8
+ docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.2.0-cuda12.8 python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
# AMD / Intel 独显 集显
- docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-directml
+ docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.2.0-directml python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
- # 演示视频, 输入
- /vsr/test/test.mp4
+ # CPU
+ docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.2.0-cpu python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
+
+ # 导出视频
docker cp vsr:/vsr/test/test_no_sub.mp4 ./
```
+**命令行参数:**
+```
+Video Subtitle Remover Command Line Tool
+
+options:
+ -h, --help show this help message and exit
+ --input INPUT, -i INPUT
+ Input video file path
+ --output OUTPUT, -o OUTPUT
+ Output video file path (optional)
+ --ymin YMIN Subtitle area ymin (optional)
+ --ymax YMAX Subtitle area ymax (optional)
+ --xmin XMIN Subtitle area xmin (optional)
+ --xmax XMAX Subtitle area xmax (optional)
+ --inpaint-mode {sttn-auto,sttn-det,lama,propainter,opencv}
+ Inpaint mode, default is sttn-auto
+```
## 演示
- GUI版:
@@ -116,7 +137,7 @@ cd <源码所在目录>
#### 4. 安装合适的运行环境
-本项目支持 CUDA(NVIDIA显卡加速)和 DirectML(AMD、Intel等GPU/APU加速)两种运行模式。
+本项目支持 CUDA(NVIDIA显卡加速)、CPU(无 GPU)和 DirectML(AMD、Intel等GPU/APU加速)三种运行模式。
##### (1) CUDA(NVIDIA 显卡用户)
@@ -152,6 +173,16 @@ cd <源码所在目录>
pip install -r requirements.txt
```
+- Linux系统还需要安装
+
+ ```shell
+ # for cuda 12.x
+ pip install onnxruntime-gpu==1.22.0
+ # for cuda 11.x
+ pip install onnxruntime-gpu==1.20.1 --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-11/pypi/simple/
+ ```
+ > 详情见: [Install ONNX Runtime](https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-12x)
+
##### (2) DirectML(AMD、Intel等GPU/APU加速卡用户)
- 适用于 Windows 设备的 AMD/NVIDIA/Intel GPU。
@@ -161,7 +192,14 @@ cd <源码所在目录>
pip install -r requirements.txt
pip install torch_directml==0.2.5.dev240914
```
+##### (3) CPU 运行(无 GPU 加速)
+- 适用于没有 GPU 或不希望使用 GPU 的情况。
+ ```shell
+ pip install paddlepaddle==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
+ pip install torch==2.7.0 torchvision==0.22.0
+ pip install -r requirements.txt
+ ```
#### 4. 运行程序
diff --git a/README_en.md b/README_en.md
index 1bbd841..72711c5 100755
--- a/README_en.md
+++ b/README_en.md
@@ -5,6 +5,7 @@



+[](https://hub.docker.com/r/eritpchy/video-subtitle-remover)
Video-subtitle-remover (VSR) is an AI-based software that removes hardcoded subtitles from videos. It mainly implements the following functionalities:
@@ -14,7 +15,7 @@ Video-subtitle-remover (VSR) is an AI-based software that removes hardcoded subt
- Supports automatic removal of all text throughout the entire video (without inputting a position).
- Supports multi-selection of images for batch removal of watermark text.
-
+
> Download the .zip package directly, extract, and run it. If it cannot run, follow the tutorial below to try installing the conda environment and running the source code.
@@ -30,33 +31,53 @@ Windows GPU Version v1.1.0 (GPU):
**Pre-built Package Comparison**:
| Pre-built Package Name | Python | Paddle | Torch | Environment | Supported Compute Capability Range |
-|----------------------------------|--------|--------|--------|-----------------------------------|------------------------------------|
-| `vse-windows-directml.7z` | 3.12 | 3.0.0 | 2.4.1 | Windows without Nvidia GPU | Universal |
-| `vse-windows-nvidia-cuda-11.8.7z`| 3.12 | 3.0.0 | 2.7.0 | CUDA 11.8 | 3.5 – 8.9 |
-| `vse-windows-nvidia-cuda-12.6.7z`| 3.12 | 3.0.0 | 2.7.0 | CUDA 12.6 | 5.0 – 8.9 |
-| `vse-windows-nvidia-cuda-12.8.7z`| 3.12 | 3.0.0 | 2.7.0 | CUDA 12.8 | 5.0 – 9.0+ |
+|----------------------------------|------|-------|--------|-----------------------------------|------------------------------------|
+| `vse-windows-cpu.7z` | 3.12 | 3.0.0 | 2.7.0 | Universal | Universal |
+| `vse-windows-directml.7z` | 3.12 | 3.0.0 | 2.4.1 | Windows without Nvidia GPU | Universal |
+| `vse-windows-nvidia-cuda-11.8.7z`| 3.12 | 3.0.0 | 2.7.0 | CUDA 11.8 | 3.5 – 8.9 |
+| `vse-windows-nvidia-cuda-12.6.7z`| 3.12 | 3.0.0 | 2.7.0 | CUDA 12.6 | 5.0 – 8.9 |
+| `vse-windows-nvidia-cuda-12.8.7z`| 3.12 | 3.0.0 | 2.7.0 | CUDA 12.8 | 5.0 – 9.0+ |
> NVIDIA provides a list of supported compute capabilities for each GPU model. You can refer to the following link: [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) to check which CUDA version is compatible with your GPU.
**Docker Versions:**
```shell
# Nvidia 10, 20, 30 Series Graphics Cards
- docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda11.8
+ docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.2.0-cuda11.8 python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
# Nvidia 40 Series Graphics Cards
- docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda12.6
+ docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.2.0-cuda12.6 python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
# Nvidia 50 Series Graphics Cards
- docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda12.8
+ docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.2.0-cuda12.8 python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
# AMD / Intel Dedicated or Integrated Graphics
- docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-directml
+ docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.2.0-directml python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
- # Demo video, input
- /vsr/test/test.mp4
+ # CPU
+ docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.2.0-cpu python backend/main.py -i test/test.mp4 -o test/test_no_sub.mp4
+
+ # Copy to host
docker cp vsr:/vsr/test/test_no_sub.mp4 ./
```
+**Commandline:**
+```
+Video Subtitle Remover Command Line Tool
+
+options:
+ -h, --help show this help message and exit
+ --input INPUT, -i INPUT
+ Input video file path
+ --output OUTPUT, -o OUTPUT
+ Output video file path (optional)
+ --ymin YMIN Subtitle area ymin (optional)
+ --ymax YMAX Subtitle area ymax (optional)
+ --xmin XMIN Subtitle area xmin (optional)
+ --xmax XMAX Subtitle area xmax (optional)
+ --inpaint-mode {sttn-auto,sttn-det,lama,propainter,opencv}
+ Inpaint mode, default is sttn-auto
+```
## Demonstration
- GUI:
@@ -114,7 +135,7 @@ cd
#### 4. Install the Appropriate Runtime Environment
-This project supports two runtime modes: CUDA (NVIDIA GPU acceleration) and DirectML (AMD, Intel, and other GPUs/APUs).
+This project supports three runtime modes: CUDA (NVIDIA GPU acceleration), CPU (no GPU) and DirectML (AMD, Intel, and other GPUs/APUs).
##### (1) CUDA (For NVIDIA GPU users)
@@ -151,6 +172,16 @@ This project supports two runtime modes: CUDA (NVIDIA GPU acceleration) and Dire
pip install -r requirements.txt
```
+- For Linux systems, you also need to install
+
+ ```shell
+ # for cuda 12.x
+ pip install onnxruntime-gpu==1.22.0
+ # for cuda 11.x
+ pip install onnxruntime-gpu==1.20.1 --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-11/pypi/simple/
+ ```
+ > For more details, see: [Install ONNX Runtime](https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-12x)
+
##### (2) DirectML (For AMD, Intel, and other GPU/APU users)
- Suitable for Windows devices with AMD/NVIDIA/Intel GPUs.
@@ -161,6 +192,14 @@ This project supports two runtime modes: CUDA (NVIDIA GPU acceleration) and Dire
pip install -r requirements_directml.txt
```
+##### (3) CPU Only (For systems without GPU or those not wanting to use GPU acceleration)
+
+- Suitable for systems without GPU or those that do not wish to use GPU.
+ ```shell
+ pip install paddlepaddle==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
+ pip install torch==2.7.0 torchvision==0.22.0
+ pip install -r requirements.txt
+ ```
#### 4. Run the program
diff --git a/backend/config.py b/backend/config.py
index bef89a1..72ce60a 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -1,162 +1,122 @@
-import warnings
-from enum import Enum, unique
-warnings.filterwarnings('ignore')
+
import os
-import torch
-import logging
-import platform
-import stat
-from fsplit.filesplit import Filesplit
-import onnxruntime as ort
+from pathlib import Path
+from qfluentwidgets import (qconfig, ConfigItem, QConfig, OptionsValidator, BoolValidator, OptionsConfigItem,
+ EnumSerializer, RangeValidator, RangeConfigItem)
+from backend.tools.constant import InpaintMode, SubtitleDetectMode
+import configparser
# 项目版本号
-VERSION = "1.1.1"
-# ×××××××××××××××××××× [不要改] start ××××××××××××××××××××
-logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印
-logging.disable(logging.WARNING) # 关闭WARNING日志的打印
-try:
- import torch_directml
- device = torch_directml.device(torch_directml.default_device())
- USE_DML = True
-except:
- USE_DML = False
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-LAMA_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'big-lama')
-STTN_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn', 'infer_model.pth')
-VIDEO_INPAINT_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'video')
-MODEL_VERSION = 'V4'
-DET_MODEL_BASE = os.path.join(BASE_DIR, 'models')
-DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det')
+VERSION = "1.2.2"
+PROJECT_HOME_URL = "https://github.com/YaoFANGUK/video-subtitle-remover"
+PROJECT_ISSUES_URL = PROJECT_HOME_URL + "/issues"
+PROJECT_RELEASES_URL = PROJECT_HOME_URL + "/releases"
+PROJECT_UPDATE_URLS = [
+ "https://api.github.com/repos/YaoFANGUK/video-subtitle-remover/releases/latest",
+ "https://accelerate.xdow.net/api/repos/YaoFANGUK/video-subtitle-remover/releases/latest",
+]
-# 查看该路径下是否有模型完整文件,没有的话合并小文件生成完整文件
-if 'big-lama.pt' not in (os.listdir(LAMA_MODEL_PATH)):
- fs = Filesplit()
- fs.merge(input_dir=LAMA_MODEL_PATH)
+# 硬件加速选项开关
+HARDWARD_ACCELERATION_OPTION = True
-if 'inference.pdiparams' not in os.listdir(DET_MODEL_PATH):
- fs = Filesplit()
- fs.merge(input_dir=DET_MODEL_PATH)
+class Config(QConfig):
+ # 界面语言设置
+ intefaceTexts = {
+ '简体中文': 'ch',
+ '繁體中文': 'chinese_cht',
+ 'English': 'en',
+ '한국어': 'ko',
+ '日本語': 'japan',
+ 'Tiếng Việt': 'vi',
+ 'Español': 'es'
+ }
+ interface = OptionsConfigItem("Window", "Interface", "ChineseSimplified", OptionsValidator(intefaceTexts.values()), restart = True)
+
+ # 窗口位置和大小
+ windowX = ConfigItem("Window", "X", None)
+ windowY = ConfigItem("Window", "Y", None)
+ windowW = ConfigItem("Window", "Width", 1200)
+ windowH = ConfigItem("Window", "Height", 1200)
-if 'ProPainter.pth' not in os.listdir(VIDEO_INPAINT_MODEL_PATH):
- fs = Filesplit()
- fs.merge(input_dir=VIDEO_INPAINT_MODEL_PATH)
+ subtitleSelectionAreaX = ConfigItem("Main", "SubtitleSelectionAreaX", 0.15)
+ subtitleSelectionAreaY = ConfigItem("Main", "SubtitleSelectionAreaY", 0.88)
+ subtitleSelectionAreaW = ConfigItem("Main", "SubtitleSelectionAreaW", 0.70)
+ subtitleSelectionAreaH = ConfigItem("Main", "SubtitleSelectionAreaH", 0.11)
-# 指定ffmpeg可执行程序路径
-sys_str = platform.system()
-if sys_str == "Windows":
- ffmpeg_bin = os.path.join('win_x64', 'ffmpeg.exe')
-elif sys_str == "Linux":
- ffmpeg_bin = os.path.join('linux_x64', 'ffmpeg')
-else:
- ffmpeg_bin = os.path.join('macos', 'ffmpeg')
-FFMPEG_PATH = os.path.join(BASE_DIR, '', 'ffmpeg', ffmpeg_bin)
-
-if 'ffmpeg.exe' not in os.listdir(os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64')):
- fs = Filesplit()
- fs.merge(input_dir=os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64'))
-# 将ffmpeg添加可执行权限
-os.chmod(FFMPEG_PATH, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
-os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
-
-# 是否使用ONNX(DirectML/AMD/Intel)
-ONNX_PROVIDERS = []
-available_providers = ort.get_available_providers()
-for provider in available_providers:
- if provider in [
- "CPUExecutionProvider"
- ]:
- continue
- if provider not in [
- "DmlExecutionProvider", # DirectML,适用于 Windows GPU
- "ROCMExecutionProvider", # AMD ROCm
- "MIGraphXExecutionProvider", # AMD MIGraphX
- "VitisAIExecutionProvider", # AMD VitisAI,适用于 RyzenAI & Windows, 实测和DirectML性能似乎差不多
- "OpenVINOExecutionProvider", # Intel GPU
- "MetalExecutionProvider", # Apple macOS
- "CoreMLExecutionProvider", # Apple macOS
- "CUDAExecutionProvider", # Nvidia GPU
- ]:
- continue
- ONNX_PROVIDERS.append(provider)
-# ×××××××××××××××××××× [不要改] end ××××××××××××××××××××
-
-
-@unique
-class InpaintMode(Enum):
"""
- 图像重绘算法枚举
+ MODE可选算法类型
+ - InpaintMode.STTN_AUTO 智能擦除版
+ - InpaintMode.STTN_DET 带字幕检测版, 无智能擦除
+ - InpaintMode.LAMA 算法:对于动画类视频效果好,速度一般,不可以跳过字幕检测
+ - InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好
"""
- STTN = 'sttn'
- LAMA = 'lama'
- PROPAINTER = 'propainter'
+ # 【设置inpaint算法】
+ inpaintMode = OptionsConfigItem("Main", "InpaintMode", InpaintMode.STTN_AUTO, OptionsValidator(InpaintMode), EnumSerializer(InpaintMode))
+
+ subtitleDetectMode = OptionsConfigItem("Main", "SubtitleDetectMode", SubtitleDetectMode.Accurate, OptionsValidator(SubtitleDetectMode), EnumSerializer(SubtitleDetectMode))
+ # 【设置像素点偏差】
+ # 用于判断是不是非字幕区域(一般认为字幕文本框的长度是要大于宽度的,如果字幕框的高大于宽,且大于的幅度超过指定像素点大小,则认为是错误检测)
+ subtitleYXAxisDifferencePixel = RangeConfigItem("Main", "SubtitleYXAxisDifferencePixel", 10, RangeValidator(0, 300))
+ # 用于放大mask大小,防止自动检测的文本框过小,inpaint阶段出现文字边,有残留
+ subtitleAreaDeviationPixel = RangeConfigItem("Main", "SubtitleAreaDeviationPixel", 10, RangeValidator(1, 300))
+ # 同于判断两个文本框是否为同一行字幕,高度差距指定像素点以内认为是同一行
+ subtitleAreaYAxisDifferencePixel = RangeConfigItem("Main", "SubtitleAreaYAxisDifferencePixel", 20, RangeValidator(0, 300))
+ # 用于判断两个字幕文本的矩形框是否相似,如果X轴和Y轴偏差都在指定阈值内,则认为时同一个文本框
+ subtitleAreaPixelToleranceYPixel = RangeConfigItem("Main", "SubtitleAreaPixelToleranceYPixel", 20, RangeValidator(0, 300))
+ subtitleAreaPixelToleranceXPixel = RangeConfigItem("Main", "SubtitleAreaPixelToleranceXPixel", 20, RangeValidator(0, 300))
+ subtitleTimelineBackwardFrameCount = RangeConfigItem("Main", "SubtitleTimelineBackwardFrameCount", 3, RangeValidator(0, 300))
+ subtitleTimelineForwardFrameCount = RangeConfigItem("Main", "subtitleTimelineForwardFrameCount", 3, RangeValidator(0, 300))
+ # 以下参数仅适用STTN算法时,才生效
+ """
+ 1. STTN_SKIP_DETECTION
+ 含义:是否使用跳过检测
+ 效果:设置为True跳过字幕检测,会省去很大时间,但是可能误伤无字幕的视频帧或者会导致去除的字幕漏了
-# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
-# 是否使用h264编码,如果需要安卓手机分享生成的视频,请打开该选项
-USE_H264 = True
+ 2. STTN_NEIGHBOR_STRIDE
+ 含义:相邻帧数步长, 如果需要为第50帧填充缺失的区域,STTN_NEIGHBOR_STRIDE=5,那么算法会使用第45帧、第40帧等作为参照。
+ 效果:用于控制参考帧选择的密度,较大的步长意味着使用更少、更分散的参考帧,较小的步长意味着使用更多、更集中的参考帧。
-# ×××××××××× 通用设置 start ××××××××××
-"""
-MODE可选算法类型
-- InpaintMode.STTN 算法:对于真人视频效果较好,速度快,可以跳过字幕检测
-- InpaintMode.LAMA 算法:对于动画类视频效果好,速度一般,不可以跳过字幕检测
-- InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好
-"""
-# 【设置inpaint算法】
-MODE = InpaintMode.STTN
-# 【设置像素点偏差】
-# 用于判断是不是非字幕区域(一般认为字幕文本框的长度是要大于宽度的,如果字幕框的高大于宽,且大于的幅度超过指定像素点大小,则认为是错误检测)
-THRESHOLD_HEIGHT_WIDTH_DIFFERENCE = 10
-# 用于放大mask大小,防止自动检测的文本框过小,inpaint阶段出现文字边,有残留
-SUBTITLE_AREA_DEVIATION_PIXEL = 20
-# 同于判断两个文本框是否为同一行字幕,高度差距指定像素点以内认为是同一行
-THRESHOLD_HEIGHT_DIFFERENCE = 20
-# 用于判断两个字幕文本的矩形框是否相似,如果X轴和Y轴偏差都在指定阈值内,则认为时同一个文本框
-PIXEL_TOLERANCE_Y = 20 # 允许检测框纵向偏差的像素点数
-PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差的像素点数
-# ×××××××××× 通用设置 end ××××××××××
+ 3. STTN_REFERENCE_LENGTH
+ 含义:参数帧数量,STTN算法会查看每个待修复帧的前后若干帧来获得用于修复的上下文信息
+ 效果:调大会增加显存占用,处理效果变好,但是处理速度变慢
-# ×××××××××× InpaintMode.STTN算法设置 start ××××××××××
-# 以下参数仅适用STTN算法时,才生效
-"""
-1. STTN_SKIP_DETECTION
-含义:是否使用跳过检测
-效果:设置为True跳过字幕检测,会省去很大时间,但是可能误伤无字幕的视频帧或者会导致去除的字幕漏了
+ 4. STTN_MAX_LOAD_NUM
+ 含义:STTN算法每次最多加载的视频帧数量
+ 效果:设置越大速度越慢,但效果越好
+ 注意:要保证STTN_MAX_LOAD_NUM大于STTN_NEIGHBOR_STRIDE和STTN_REFERENCE_LENGTH
+ """
+ # 参考帧步长
+ sttnNeighborStride = RangeConfigItem("Sttn", "NeighborStride", 5, RangeValidator(1, 100))
+ # 参考帧数量
+ sttnReferenceLength = RangeConfigItem("Sttn", "ReferenceLength", 10, RangeValidator(1, 100))
+ # 设置STTN算法最大同时处理的帧数量
+ sttnMaxLoadNum = RangeConfigItem("Sttn", "MaxLoadNum", 50, RangeValidator(1, 300))
+ getSttnMaxLoadNum = lambda self: max(self.sttnMaxLoadNum.value, self.sttnNeighborStride.value * self.sttnReferenceLength.value)
+
+ # 以下参数仅适用PROPAINTER算法时,才生效
+ # 【根据自己的GPU显存大小设置】最大同时处理的图片数量,设置越大处理效果越好,但是要求显存越高
+ # 1280x720p视频设置80需要25G显存,设置50需要19G显存
+ # 720x480p视频设置80需要8G显存,设置50需要7G显存
+ propainterMaxLoadNum = RangeConfigItem("ProPainter", "MaxLoadNum", 70, RangeValidator(1, 300))
-2. STTN_NEIGHBOR_STRIDE
-含义:相邻帧数步长, 如果需要为第50帧填充缺失的区域,STTN_NEIGHBOR_STRIDE=5,那么算法会使用第45帧、第40帧等作为参照。
-效果:用于控制参考帧选择的密度,较大的步长意味着使用更少、更分散的参考帧,较小的步长意味着使用更多、更集中的参考帧。
+ # 是否使用硬件加速
+ hardwareAcceleration = ConfigItem("Main", "HardwareAcceleration", HARDWARD_ACCELERATION_OPTION, BoolValidator())
+
+ # 启动时检查应用更新
+ checkUpdateOnStartup = ConfigItem("Main", "CheckUpdateOnStartup", True, BoolValidator())
-3. STTN_REFERENCE_LENGTH
-含义:参数帧数量,STTN算法会查看每个待修复帧的前后若干帧来获得用于修复的上下文信息
-效果:调大会增加显存占用,处理效果变好,但是处理速度变慢
+CONFIG_FILE = 'config/config.json'
+config = Config()
+qconfig.load(CONFIG_FILE, config)
-4. STTN_MAX_LOAD_NUM
-含义:STTN算法每次最多加载的视频帧数量
-效果:设置越大速度越慢,但效果越好
-注意:要保证STTN_MAX_LOAD_NUM大于STTN_NEIGHBOR_STRIDE和STTN_REFERENCE_LENGTH
-"""
-STTN_SKIP_DETECTION = True
-# 参考帧步长
-STTN_NEIGHBOR_STRIDE = 5
-# 参考帧长度(数量)
-STTN_REFERENCE_LENGTH = 10
-# 设置STTN算法最大同时处理的帧数量
-STTN_MAX_LOAD_NUM = 50
-if STTN_MAX_LOAD_NUM < STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE:
- STTN_MAX_LOAD_NUM = STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE
-# ×××××××××× InpaintMode.STTN算法设置 end ××××××××××
+# 读取界面语言配置
+tr = configparser.ConfigParser()
-# ×××××××××× InpaintMode.PROPAINTER算法设置 start ××××××××××
-# 【根据自己的GPU显存大小设置】最大同时处理的图片数量,设置越大处理效果越好,但是要求显存越高
-# 1280x720p视频设置80需要25G显存,设置50需要19G显存
-# 720x480p视频设置80需要8G显存,设置50需要7G显存
-PROPAINTER_MAX_LOAD_NUM = 70
-# ×××××××××× InpaintMode.PROPAINTER算法设置 end ××××××××××
+TRANSLATION_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'interface', f"{config.interface.value}.ini")
+tr.read(TRANSLATION_FILE, encoding='utf-8')
-# ×××××××××× InpaintMode.LAMA算法设置 start ××××××××××
-# 是否开启极速模式,开启后不保证inpaint效果,仅仅对包含文本的区域文本进行去除
-LAMA_SUPER_FAST = False
-# ×××××××××× InpaintMode.LAMA算法设置 end ××××××××××
-# ×××××××××××××××××××× [可以改] end ××××××××××××××××××××
+# 项目的base目录
+BASE_DIR = str(Path(os.path.abspath(__file__)).parent)
+
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
\ No newline at end of file
diff --git a/backend/inpaint/lama_inpaint.py b/backend/inpaint/lama_inpaint.py
index 403ae67..c81c2f6 100644
--- a/backend/inpaint/lama_inpaint.py
+++ b/backend/inpaint/lama_inpaint.py
@@ -1,22 +1,20 @@
import os
-from typing import Union
+import copy
+from typing import Union, List
import torch
import numpy as np
from PIL import Image
from backend.inpaint.utils.lama_util import prepare_img_and_mask
from backend import config
-
+from backend.tools.inpaint_tools import get_inpaint_area_by_mask
class LamaInpaint:
- def __init__(self, device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), model_path=None) -> None:
- if model_path is None:
- model_path = os.path.join(config.LAMA_MODEL_PATH, 'big-lama.pt')
+ def __init__(self, device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), model_path='big-lama.pt') -> None:
self.model = torch.jit.load(model_path, map_location=device)
self.model.eval()
- self.model.to(device)
self.device = device
- def __call__(self, image: Union[Image.Image, np.ndarray], mask: Union[Image.Image, np.ndarray]):
+ def inpaint(self, image: Union[Image.Image, np.ndarray], mask: Union[Image.Image, np.ndarray]):
if isinstance(image, np.ndarray):
orig_height, orig_width = image.shape[:2]
else:
@@ -29,3 +27,60 @@ class LamaInpaint:
cur_res = cur_res[:orig_height, :orig_width]
return cur_res
+ def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
+ """
+ :param input_frames: 原视频帧
+ :param input_mask: 字幕区域mask
+ """
+ mask = input_mask[:, :, None]
+ H_ori, W_ori = mask.shape[:2]
+ H_ori = int(H_ori + 0.5)
+ W_ori = int(W_ori + 0.5)
+ # 确定去字幕的垂直高度部分
+ split_h = int(W_ori * 3 / 16)
+ inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask)
+ # 初始化帧存储变量
+ # 高分辨率帧存储列表
+ frames_hr = copy.deepcopy(input_frames)
+ frames_scaled = {} # 存放缩放后帧的字典
+ masks_scaled = {} # 存放缩放后遮罩的字典
+ comps = {} # 存放补全后帧的字典
+ # 存储最终的视频帧
+ inpainted_frames = []
+ for k in range(len(inpaint_area)):
+ frames_scaled[k] = [] # 为每个去除部分初始化一个列表
+ masks_scaled[k] = [] # 为每个去除部分初始化一个列表
+
+ # 读取并缩放帧
+ for j in range(len(frames_hr)):
+ image = frames_hr[j]
+ # 对每个去除部分进行切割和缩放
+ for k in range(len(inpaint_area)):
+ image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 切割
+ mask_crop = mask[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 切割
+ frames_scaled[k].append(image_crop) # 将切割后的帧添加到对应列表
+ masks_scaled[k].append(mask_crop) # 将切割后的遮罩添加到对应列表
+
+ # 处理每一个去除部分
+ for k in range(len(inpaint_area)):
+ # 调用inpaint函数逐帧处理
+ comps[k] = []
+ for i in range(len(frames_scaled[k])):
+ inpainted_frame = self.inpaint(frames_scaled[k][i], masks_scaled[k][i])
+ comps[k].append(inpainted_frame)
+
+ # 如果存在去除部分
+ if inpaint_area:
+ for j in range(len(frames_hr)):
+ frame = frames_hr[j] # 取出原始帧
+ # 对于模式中的每一个段落
+ for k in range(len(inpaint_area)):
+ comp = comps[k][j] # 获取补全后的帧
+ # 实现遮罩区域内的图像融合
+ frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = comp
+ # 将最终帧添加到列表
+ inpainted_frames.append(frame)
+ # print(f'processing frame, {len(frames_hr) - j} left')
+ return inpainted_frames
+
+
diff --git a/backend/inpaint/opencv_inpaint.py b/backend/inpaint/opencv_inpaint.py
new file mode 100644
index 0000000..3779afa
--- /dev/null
+++ b/backend/inpaint/opencv_inpaint.py
@@ -0,0 +1,15 @@
+import cv2
+
+class OpenCVInpaint:
+
+ def __init__(self):
+ pass
+
+ def inpaint(self, frame, mask):
+ return cv2.inpaint(frame, mask, 3, cv2.INTER_LINEAR)
+
+ def __call__(self, frames, mask):
+ comp = []
+ for frame in frames:
+ comp.append(self.inpaint(frame, mask))
+ return comp
\ No newline at end of file
diff --git a/backend/inpaint/video_inpaint.py b/backend/inpaint/propainter_inpaint.py
similarity index 79%
rename from backend/inpaint/video_inpaint.py
rename to backend/inpaint/propainter_inpaint.py
index 94be3d1..726f713 100644
--- a/backend/inpaint/video_inpaint.py
+++ b/backend/inpaint/propainter_inpaint.py
@@ -1,9 +1,11 @@
# -*- coding: utf-8 -*-
import os
import cv2
+import copy
import numpy as np
import scipy.ndimage
from PIL import Image
+from typing import List
import torch
import torchvision
@@ -14,12 +16,12 @@ from backend.inpaint.video.model.recurrent_flow_completion import RecurrentFlowC
from backend.inpaint.video.model.propainter import InpaintGenerator
from backend.inpaint.video.core.utils import to_tensors
from backend.inpaint.video.model.misc import get_device
+from backend.tools.inpaint_tools import get_inpaint_area_by_mask
import warnings
warnings.filterwarnings("ignore")
-
def binary_mask(mask, th=0.1):
mask[mask > th] = 1
mask[mask <= th] = 0
@@ -33,6 +35,11 @@ def read_mask(mpath, length, size, flow_mask_dilates=8, mask_dilates=5):
flow_masks = []
# 如果传入的直接为numpy array
if isinstance(mpath, np.ndarray):
+ if mpath.ndim == 3 and mpath.shape[2] == 1:
+ mpath = mpath.squeeze(2) # 从 (H,W,1) 转为 (H,W)
+ elif mpath.ndim == 3 and mpath.shape[2] == 3:
+ # 如果是彩色图像,转为灰度
+ mpath = cv2.cvtColor(mpath, cv2.COLOR_BGR2GRAY)
masks_img = [Image.fromarray(mpath)]
# input single img path
else:
@@ -129,9 +136,10 @@ def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=
return ref_index
-class VideoInpaint:
- def __init__(self, sub_video_length=config.PROPAINTER_MAX_LOAD_NUM, use_fp16=True):
- self.device = get_device()
+class PropainterInpaint:
+ def __init__(self, device, model_dir, sub_video_length=80, use_fp16=True):
+ self.device = device
+ self.model_dir = model_dir
self.use_fp16 = use_fp16
self.use_half = True if self.use_fp16 else False
if self.device == torch.device('cpu'):
@@ -157,21 +165,27 @@ class VideoInpaint:
def init_raft_model(self):
# set up RAFT and flow competition model
- return RAFT_bi(os.path.join(config.VIDEO_INPAINT_MODEL_PATH, 'raft-things.pth'), self.device)
+ return RAFT_bi(os.path.join(self.model_dir, 'raft-things.pth'), self.device)
def init_fix_flow_model(self):
fix_flow_complete_model = RecurrentFlowCompleteNet(
- os.path.join(config.VIDEO_INPAINT_MODEL_PATH, 'recurrent_flow_completion.pth'))
+ os.path.join(self.model_dir, 'recurrent_flow_completion.pth'))
for p in fix_flow_complete_model.parameters():
p.requires_grad = False
+
+ if self.use_half:
+ fix_flow_complete_model = fix_flow_complete_model.half()
fix_flow_complete_model.to(self.device)
fix_flow_complete_model.eval()
return fix_flow_complete_model
def init_inpaint_model(self):
# set up ProPainter model
- return InpaintGenerator(model_path=os.path.join(config.VIDEO_INPAINT_MODEL_PATH, 'ProPainter.pth')).to(
- self.device).eval()
+ model = InpaintGenerator(model_path=os.path.join(self.model_dir, 'ProPainter.pth'))
+ if self.use_half:
+ model = model.half()
+ model = model.to(self.device).eval()
+ return model
def inpaint(self, frames, mask):
if isinstance(frames[0], np.ndarray):
@@ -235,8 +249,6 @@ class VideoInpaint:
if self.use_half:
frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
- fix_flow_complete = self.fix_flow_complete.half()
- self.model = self.model.half()
# ---- complete flow ----
flow_length = gt_flows_bi[0].size(1)
@@ -248,10 +260,10 @@ class VideoInpaint:
e_f = min(flow_length, f + self.sub_video_length + pad_len)
pad_len_s = max(0, f) - s_f
pad_len_e = e_f - min(flow_length, f + self.sub_video_length)
- pred_flows_bi_sub, _ = fix_flow_complete.forward_bidirect_flow(
+ pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow(
(gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
flow_masks[:, s_f:e_f + 1])
- pred_flows_bi_sub = fix_flow_complete.combine_flow(
+ pred_flows_bi_sub = self.fix_flow_complete.combine_flow(
(gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
pred_flows_bi_sub,
flow_masks[:, s_f:e_f + 1])
@@ -264,8 +276,8 @@ class VideoInpaint:
pred_flows_b = torch.cat(pred_flows_b, dim=1)
pred_flows_bi = (pred_flows_f, pred_flows_b)
else:
- pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
- pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
+ pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
+ pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
torch.cuda.empty_cache()
# ---- image propagation ----
@@ -348,6 +360,59 @@ class VideoInpaint:
comp_frames = [cv2.cvtColor(i, cv2.COLOR_RGB2BGR) for i in comp_frames]
return comp_frames
+ def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
+ """
+ :param input_frames: 原视频帧
+ :param input_mask: 字幕区域mask
+ """
+ mask = input_mask[:, :, None]
+ H_ori, W_ori = mask.shape[:2]
+ H_ori = int(H_ori + 0.5)
+ W_ori = int(W_ori + 0.5)
+ # 确定去字幕的垂直高度部分
+ split_h = int(W_ori * 3 / 16)
+ inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask, multiple=8)
+ # 初始化帧存储变量
+ # 高分辨率帧存储列表
+ frames_hr = copy.deepcopy(input_frames)
+ frames_scaled = {} # 存放缩放后帧的字典
+ masks_scaled = {} # 存放缩放后遮罩的字典
+ comps = {} # 存放补全后帧的字典
+ # 存储最终的视频帧
+ inpainted_frames = []
+ for k in range(len(inpaint_area)):
+ frames_scaled[k] = [] # 为每个去除部分初始化一个列表
+ masks_scaled[k] = [] # 为每个去除部分初始化一个列表
+
+ # 读取并缩放帧
+ for j in range(len(frames_hr)):
+ image = frames_hr[j]
+ # 对每个去除部分进行切割和缩放
+ for k in range(len(inpaint_area)):
+ image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], inpaint_area[k][2]:inpaint_area[k][3], :] # 切割
+ mask_crop = mask[inpaint_area[k][0]:inpaint_area[k][1], inpaint_area[k][2]:inpaint_area[k][3], :] # 切割
+ frames_scaled[k].append(image_crop) # 将缩放后的帧添加到对应列表
+ masks_scaled[k].append(mask_crop) # 将缩放后的遮罩添加到对应列表
+
+ # 处理每一个去除部分
+ for k in range(len(inpaint_area)):
+ # 调用inpaint函数进行处理
+ comps[k] = self.inpaint(frames_scaled[k], masks_scaled[k][0])
+
+ # 如果存在去除部分
+ if inpaint_area:
+ for j in range(len(frames_hr)):
+ frame = frames_hr[j] # 取出原始帧
+ # 对于模式中的每一个段落
+ for k in range(len(inpaint_area)):
+ comp = comps[k][j] # 获取补全后的帧
+ # 实现遮罩区域内的图像融合
+ frame[inpaint_area[k][0]:inpaint_area[k][1], inpaint_area[k][2]:inpaint_area[k][3], :] = comp
+ # 将最终帧添加到列表
+ inpainted_frames.append(frame)
+ # print(f'processing frame, {len(frames_hr) - j} left')
+ return inpainted_frames
+
def read_frames(v_path):
video_cap = cv2.VideoCapture(v_path)
@@ -362,11 +427,11 @@ def read_frames(v_path):
if __name__ == '__main__':
- # VideoInpaint
- video_inpaint = VideoInpaint(sub_video_length=80)
+ # PropainterInpaint
+ propainter_inpaint = PropainterInpaint(get_device(), ModelConfig().PROPAINTER_MODEL_DIR, sub_video_length=80)
frames = read_frames('/home/yao/Documents/Project/video-subtitle-remover/local_test/test1.mp4')
mask = cv2.imread('/home/yao/Documents/Project/video-subtitle-remover/local_test/test1_mask.png')
- inpainted_frames = video_inpaint.inpaint(frames, mask)
+ inpainted_frames = propainter_inpaint.inpaint(frames, mask)
save_root = '/home/yao/Documents/Project/video-subtitle-remover/local_test/'
video_out_path = os.path.join(save_root, 'inpaint_out.mp4')
print("size: ", inpainted_frames[0].shape)
diff --git a/backend/inpaint/sttn/network_sttn.py b/backend/inpaint/sttn/network_sttn.py
index 57d5016..22ad3b7 100644
--- a/backend/inpaint/sttn/network_sttn.py
+++ b/backend/inpaint/sttn/network_sttn.py
@@ -1,9 +1,11 @@
''' Spatial-Temporal Transformer Networks
'''
+import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
+import torchvision.models as models
from backend.inpaint.utils.spectral_norm import spectral_norm as _spectral_norm
@@ -61,7 +63,7 @@ class BaseNetwork(nn.Module):
class InpaintGenerator(BaseNetwork):
- def __init__(self, init_weights=True): # 1046
+ def __init__(self, init_weights=True):
super(InpaintGenerator, self).__init__()
channel = 256
stack_num = 8
@@ -82,7 +84,7 @@ class InpaintGenerator(BaseNetwork):
nn.LeakyReLU(0.2, inplace=True),
)
- # decoder: decode image from features
+ # decoder: decode frames from features
self.decoder = nn.Sequential(
deconv(channel, 128, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
@@ -114,11 +116,9 @@ class InpaintGenerator(BaseNetwork):
masks = masks.view(t, c, h, w)
masks = F.interpolate(masks, scale_factor=1.0/4)
t, c, _, _ = feat.size()
- output = self.transformer({'x': feat, 'm': masks, 'b': 1, 'c': c})
- enc_feat = output['x']
- attn = output['attn']
- mm = output['smm']
- return enc_feat, attn, mm
+ enc_feat = self.transformer(
+ {'x': feat, 'm': masks, 'b': 1, 'c': c})['x']
+ return enc_feat
class deconv(nn.Module):
@@ -133,8 +133,9 @@ class deconv(nn.Module):
return self.conv(x)
-# ##################################################
-# ################## Transformer ####################
+# #############################################################################
+# ############################# Transformer ##################################
+# #############################################################################
class Attention(nn.Module):
@@ -205,22 +206,14 @@ class MultiHeadedAttention(nn.Module):
tmp1.append(y)
y = torch.cat(tmp1,1)
'''
- y, attn = self.attention(query, key, value, mm)
-
- # return attention value for visualization
- # here we return the attention value of patchsize=18
- if width == 18:
- select_attn = attn.view(t, out_h*out_w, t, out_h, out_w)[0]
- # mm, [b, thw, thw]
- select_mm = mm[0].view(t*out_h*out_w, t, out_h, out_w)[0]
-
+ y, _ = self.attention(query, key, value, mm)
# 3) "Concat" using a view and apply a final linear.
y = y.view(b, t, out_h, out_w, d_k, height, width)
y = y.permute(0, 1, 4, 2, 5, 3, 6).contiguous().view(bt, d_k, h, w)
output.append(y)
output = torch.cat(output, 1)
x = self.output_linear(output)
- return x, select_attn, select_mm
+ return x
# Standard 2 layerd FFN of transformer
@@ -251,10 +244,9 @@ class TransformerBlock(nn.Module):
def forward(self, x):
x, m, b, c = x['x'], x['m'], x['b'], x['c']
- val, attn, mm = self.attention(x, m, b, c)
- x = x + val
+ x = x + self.attention(x, m, b, c)
x = x + self.feed_forward(x)
- return {'x': x, 'm': m, 'b': b, 'c': c, 'attn': attn, 'smm': mm}
+ return {'x': x, 'm': m, 'b': b, 'c': c}
# ######################################################################
@@ -309,4 +301,4 @@ class Discriminator(BaseNetwork):
def spectral_norm(module, mode=True):
if mode:
return _spectral_norm(module)
- return module
+ return module
\ No newline at end of file
diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_auto_inpaint.py
similarity index 80%
rename from backend/inpaint/sttn_inpaint.py
rename to backend/inpaint/sttn_auto_inpaint.py
index 4e0f504..4761486 100644
--- a/backend/inpaint/sttn_inpaint.py
+++ b/backend/inpaint/sttn_auto_inpaint.py
@@ -10,9 +10,10 @@ import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-from backend import config
+from backend.config import config
from backend.inpaint.sttn.auto_sttn import InpaintGenerator
from backend.inpaint.utils.sttn_utils import Stack, ToTorchFormatTensor
+from backend.tools.inpaint_tools import get_inpaint_area_by_mask
# 定义图像预处理方式
_to_tensors = transforms.Compose([
@@ -20,21 +21,20 @@ _to_tensors = transforms.Compose([
ToTorchFormatTensor() # 将堆叠的图像转化为PyTorch张量
])
-
class STTNInpaint:
- def __init__(self):
- self.device = config.device
+ def __init__(self, device, model_path):
+ self.device = device
# 1. 创建InpaintGenerator模型实例并装载到选择的设备上
self.model = InpaintGenerator().to(self.device)
# 2. 载入预训练模型的权重,转载模型的状态字典
- self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location='cpu')['netG'])
+ self.model.load_state_dict(torch.load(model_path, map_location='cpu')['netG'])
# 3. # 将模型设置为评估模式
self.model.eval()
# 模型输入用的宽和高
self.model_input_width, self.model_input_height = 640, 120
# 2. 设置相连帧数
- self.neighbor_stride = config.STTN_NEIGHBOR_STRIDE
- self.ref_length = config.STTN_REFERENCE_LENGTH
+ self.neighbor_stride = config.sttnNeighborStride.value
+ self.ref_length = config.sttnReferenceLength.value
def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
"""
@@ -48,7 +48,7 @@ class STTNInpaint:
W_ori = int(W_ori + 0.5)
# 确定去字幕的垂直高度部分
split_h = int(W_ori * 3 / 16)
- inpaint_area = self.get_inpaint_area_by_mask(H_ori, split_h, mask)
+ inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask)
# 初始化帧存储变量
# 高分辨率帧存储列表
frames_hr = copy.deepcopy(input_frames)
@@ -87,7 +87,7 @@ class STTNInpaint:
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
# 将最终帧添加到列表
inpainted_frames.append(frame)
- print(f'processing frame, {len(frames_hr) - j} left')
+ # print(f'processing frame, {len(frames_hr) - j} left')
return inpainted_frames
@staticmethod
@@ -163,64 +163,8 @@ class STTNInpaint:
# 返回处理完成的帧序列
return comp_frames
- @staticmethod
- def get_inpaint_area_by_mask(H, h, mask):
- """
- 获取字幕去除区域,根据mask来确定需要填补的区域和高度
- """
- # 存储绘画区域的列表
- inpaint_area = []
- # 从视频底部的字幕位置开始,假设字幕通常位于底部
- to_H = from_H = H
- # 从底部向上遍历遮罩
- while from_H != 0:
- if to_H - h < 0:
- # 如果下一段会超出顶端,则从顶端开始
- from_H = 0
- to_H = h
- else:
- # 确定段的上边界
- from_H = to_H - h
- # 检查当前段落是否包含遮罩像素
- if not np.all(mask[from_H:to_H, :] == 0) and np.sum(mask[from_H:to_H, :]) > 10:
- # 如果不是第一个段落,向下移动以确保没遗漏遮罩区域
- if to_H != H:
- move = 0
- while to_H + move < H and not np.all(mask[to_H + move, :] == 0):
- move += 1
- # 确保没有越过底部
- if to_H + move < H and move < h:
- to_H += move
- from_H += move
- # 将该段落添加到列表中
- if (from_H, to_H) not in inpaint_area:
- inpaint_area.append((from_H, to_H))
- else:
- break
- # 移动到下一个段落
- to_H -= h
- return inpaint_area # 返回绘画区域列表
- @staticmethod
- def get_inpaint_area_by_selection(input_sub_area, mask):
- print('use selection area for inpainting')
- height, width = mask.shape[:2]
- ymin, ymax, _, _ = input_sub_area
- interval_size = 135
- # 存储结果的列表
- inpaint_area = []
- # 计算并存储标准区间
- for i in range(ymin, ymax, interval_size):
- inpaint_area.append((i, i + interval_size))
- # 检查最后一个区间是否达到了最大值
- if inpaint_area[-1][1] != ymax:
- # 如果没有,则创建一个新的区间,开始于最后一个区间的结束,结束于扩大后的值
- if inpaint_area[-1][1] + interval_size <= height:
- inpaint_area.append((inpaint_area[-1][1], inpaint_area[-1][1] + interval_size))
- return inpaint_area # 返回绘画区域列表
-
-
-class STTNVideoInpaint:
+class STTNAutoInpaint:
def read_frame_info_from_video(self):
# 使用opencv读取视频
@@ -235,9 +179,9 @@ class STTNVideoInpaint:
# 返回视频读取对象、帧信息和视频写入对象
return reader, frame_info
- def __init__(self, video_path, mask_path=None, clip_gap=None):
+ def __init__(self, device, model_path, video_path, mask_path=None, clip_gap=None):
# STTNInpaint视频修复实例初始化
- self.sttn_inpaint = STTNInpaint()
+ self.sttn_inpaint = STTNInpaint(device, model_path)
# 视频和掩码路径
self.video_path = video_path
self.mask_path = mask_path
@@ -248,7 +192,7 @@ class STTNVideoInpaint:
)
# 配置可在一次处理中加载的最大帧数
if clip_gap is None:
- self.clip_gap = config.STTN_MAX_LOAD_NUM
+ self.clip_gap = config.getSttnMaxLoadNum()
else:
self.clip_gap = clip_gap
@@ -277,8 +221,7 @@ class STTNVideoInpaint:
mask = mask[:, :, None]
# 得到修复区域位置
- inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
-
+ inpaint_area = get_inpaint_area_by_mask(frame_info['W_ori'], frame_info['H_ori'], split_h, mask)
# 遍历每一次的迭代次数
for i in range(rec_time):
start_f = i * self.clip_gap # 起始帧位置
@@ -346,7 +289,7 @@ class STTNVideoInpaint:
if tbar is not None:
input_sub_remover.update_progress(tbar, increment=1)
if original_frame is not None and input_sub_remover.gui_mode:
- input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame])
+ input_sub_remover.update_preview_with_comp(original_frame, frame)
except Exception as e:
print(f"Error during video processing: {str(e)}")
# 不抛出异常,允许程序继续执行
@@ -360,7 +303,7 @@ if __name__ == '__main__':
video_path = '../../test/test.mp4'
# 记录开始时间
start = time.time()
- sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path, clip_gap=config.STTN_MAX_LOAD_NUM)
+ sttn_video_inpaint = STTNAutoInpaint(video_path, mask_path, clip_gap=config.getSttnMaxLoadNum())
sttn_video_inpaint()
print(f'video generated at {sttn_video_inpaint.video_out_path}')
print(f'time cost: {time.time() - start}')
diff --git a/backend/inpaint/sttn_det_inpaint.py b/backend/inpaint/sttn_det_inpaint.py
new file mode 100644
index 0000000..6250512
--- /dev/null
+++ b/backend/inpaint/sttn_det_inpaint.py
@@ -0,0 +1,179 @@
+import copy
+import time
+
+import cv2
+import numpy as np
+import torch
+from torchvision import transforms
+from typing import List
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+from backend.config import config
+from backend.inpaint.sttn.network_sttn import InpaintGenerator
+from backend.inpaint.utils.sttn_utils import Stack, ToTorchFormatTensor
+from backend.tools.inpaint_tools import get_inpaint_area_by_mask
+
+# 定义图像预处理方式
+_to_tensors = transforms.Compose([
+ Stack(), # 将图像堆叠为序列
+ ToTorchFormatTensor() # 将堆叠的图像转化为PyTorch张量
+])
+
+class STTNDetInpaint:
+ def __init__(self, device, model_path):
+ self.device = device
+ # 1. 创建InpaintGenerator模型实例并装载到选择的设备上
+ self.model = InpaintGenerator().to(self.device)
+ # 2. 载入预训练模型的权重,转载模型的状态字典
+ self.model.load_state_dict(torch.load(model_path, map_location='cpu')['netG'])
+ # 3. # 将模型设置为评估模式
+ self.model.eval()
+ # 模型输入用的宽和高
+ self.model_input_width, self.model_input_height = 432, 240
+ # 2. 设置相连帧数
+ self.neighbor_stride = config.sttnNeighborStride.value
+ self.ref_length = config.sttnReferenceLength.value
+
+ def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
+ """
+ :param input_frames: 原视频帧
+ :param mask: 字幕区域mask
+ """
+ mask = input_mask[:, :, None]
+ H_ori, W_ori = mask.shape[:2]
+ H_ori = int(H_ori + 0.5)
+ W_ori = int(W_ori + 0.5)
+ # 确定去字幕的垂直高度部分
+ if H_ori > W_ori:
+ split_h = int(H_ori * 5 / 9)
+ else:
+ split_h = int(W_ori * 5 / 18)
+ inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask)
+ # 初始化帧存储变量
+ # 高分辨率帧存储列表
+ frames_hr = copy.deepcopy(input_frames)
+ frames_scaled = {} # 存放缩放后帧的字典
+ masks_scaled = {} # 存放缩放后遮罩的字典
+ comps = {} # 存放补全后帧的字典
+ # 存储最终的视频帧
+ inpainted_frames = []
+ for k in range(len(inpaint_area)):
+ frames_scaled[k] = [] # 为每个去除部分初始化一个列表
+ masks_scaled[k] = [] # 为每个去除部分初始化一个列表
+
+ # 读取并缩放帧
+ for j in range(len(frames_hr)):
+ image = frames_hr[j]
+ # 对每个去除部分进行切割和缩放
+ for k in range(len(inpaint_area)):
+ image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 切割
+ mask_crop = mask[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 切割
+ image_resize = cv2.resize(image_crop, (self.model_input_width, self.model_input_height)) # 缩放
+ mask_resize = cv2.resize(mask_crop, (self.model_input_width, self.model_input_height)) # 缩放
+ frames_scaled[k].append(image_resize) # 将缩放后的帧添加到对应列表
+ masks_scaled[k].append(mask_resize) # 将缩放后的遮罩添加到对应列表
+
+ # 处理每一个去除部分
+ for k in range(len(inpaint_area)):
+ # 调用inpaint函数进行处理
+ comps[k] = self.inpaint(frames_scaled[k], masks_scaled[k])
+
+ # 如果存在去除部分
+ if inpaint_area:
+ for j in range(len(frames_hr)):
+ frame = frames_hr[j] # 取出原始帧
+ # 对于模式中的每一个段落
+ for k in range(len(inpaint_area)):
+ comp = cv2.resize(comps[k][j], (W_ori, split_h)) # 将补全帧缩放回原大小
+ comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB) # 转换颜色空间
+ # 获取遮罩区域并进行图像合成
+ mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :] # 取出遮罩区域
+ # 实现遮罩区域内的图像融合
+ frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = comp
+ # 将最终帧添加到列表
+ inpainted_frames.append(frame)
+ # print(f'processing frame, {len(frames_hr) - j} left')
+ return inpainted_frames
+
+ @staticmethod
+ def read_mask(path):
+ img = cv2.imread(path, 0)
+ # 转为binary mask
+ ret, img = cv2.threshold(img, 127, 1, cv2.THRESH_BINARY)
+ img = img[:, :, None]
+ return img
+
+ def get_ref_index(self, neighbor_ids, length):
+ """
+ 采样整个视频的参考帧
+ """
+ # 初始化参考帧的索引列表
+ ref_index = []
+ # 在视频长度范围内根据ref_length逐步迭代
+ for i in range(0, length, self.ref_length):
+ # 如果当前帧不在近邻帧中
+ if i not in neighbor_ids:
+ # 将它添加到参考帧列表
+ ref_index.append(i)
+ # 返回参考帧索引列表
+ return ref_index
+
+ def inpaint(self, frames: List[np.ndarray], masks: List[np.ndarray]):
+ """
+ 使用STTN完成空洞填充(空洞即被遮罩的区域)
+ """
+ frame_length = len(frames)
+ # 对帧进行预处理转换为张量,并进行归一化
+ feats = _to_tensors(frames).unsqueeze(0) * 2 - 1
+
+ binary_masks = [np.expand_dims((np.array(m) > 0.5).astype(np.uint8), 2) for m in masks]
+ # 将掩码转换为张量
+ masks = (_to_tensors(masks).unsqueeze(0) > 0.5).float()
+
+ # 把特征张量转移到指定的设备(CPU或GPU)
+ feats, masks = feats.to(self.device), masks.to(self.device)
+ # 初始化一个与视频长度相同的列表,用于存储处理完成的帧
+ comp_frames = [None] * frame_length
+ # 关闭梯度计算,用于推理阶段节省内存并加速
+ with torch.no_grad():
+ # 将处理好的帧通过编码器,产生特征表示
+ feats = self.model.encoder((feats*(1-masks).float()).view(frame_length, 3, self.model_input_height, self.model_input_width))
+ # 获取特征维度信息
+ _, c, feat_h, feat_w = feats.size()
+ # 调整特征形状以匹配模型的期望输入
+ feats = feats.view(1, frame_length, c, feat_h, feat_w)
+ # 获取重绘区域
+ # 在设定的邻居帧步幅内循环处理视频
+ for f in range(0, frame_length, self.neighbor_stride):
+ # 计算邻近帧的ID
+ neighbor_ids = [i for i in range(max(0, f - self.neighbor_stride), min(frame_length, f + self.neighbor_stride + 1))]
+ # 获取参考帧的索引
+ ref_ids = self.get_ref_index(neighbor_ids, frame_length)
+ # 同样关闭梯度计算
+ with torch.no_grad():
+ # 通过模型推断特征并传递给解码器以生成完成的帧
+ pred_feat = self.model.infer(
+ feats[0, neighbor_ids + ref_ids, :, :, :], masks[0, neighbor_ids + ref_ids, :, :, :])
+
+ # 将预测的特征通过解码器生成图片,并应用激活函数tanh,然后分离出张量
+ pred_img = torch.tanh(self.model.decoder(pred_feat[:len(neighbor_ids), :, :, :])).detach()
+ # 将结果张量重新缩放到0到255的范围内(图像像素值)
+ pred_img = (pred_img + 1) / 2
+ # 将张量移动回CPU并转为NumPy数组
+ pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
+ # 遍历邻近帧
+ for i in range(len(neighbor_ids)):
+ idx = neighbor_ids[i]
+ # 将预测的图片转换为无符号8位整数格式
+ img = np.array(pred_img[i]).astype(
+ np.uint8)*binary_masks[idx] + frames[idx] * (1-binary_masks[idx])
+ if comp_frames[idx] is None:
+ # 如果该位置为空,则赋值为新计算出的图片
+ comp_frames[idx] = img
+ else:
+ # 如果此位置之前已有图片,则将新旧图片混合以提高质量
+ comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5
+ # 返回处理完成的帧序列
+ return comp_frames
diff --git a/backend/interface/ch.ini b/backend/interface/ch.ini
new file mode 100644
index 0000000..c0035f8
--- /dev/null
+++ b/backend/interface/ch.ini
@@ -0,0 +1,129 @@
+[Setting]
+FeedbackButton = 提供反馈
+FeedbackTitle = 提供反馈
+FeedbackDesc = 通过提供反馈帮助我们改进
+CopyrightButton = 检查更新
+CopyrightTitle = 关于
+CopyrightDesc = ©版权所有 2023, YaoFANGUK, Jason Eric (界面设计), 当前版本: {}
+ProjectLinkTitle = 字幕去除器
+ProjectLinkDesc = 基于AI的图片/视频硬字幕去除、文本水印去除,无损分辨率生成去字幕、去水印后的图片/视频文件。无需申请第三方API,本地实现。
+BasicSetting = 基础设置
+AdvancedSetting = 高级设置
+SubtitleDetectionSetting = 字幕检测设置
+SttnSetting = STTN设置
+ProPainterSetting = ProPainter设置
+AboutSetting = 关于
+HardwareAcceleration = 硬件加速
+HardwareAccelerationDesc = 使用GPU或ONNX后端进行加速处理
+SubtitleYXAxisDifferencePixel = 高宽像素差阈值
+SubtitleYXAxisDifferencePixelDesc = 用于判断是不是非字幕区域,默认为10 (一般认为字幕文本框的长度是要大于宽度的,如果字幕框的高大于宽,且大于的幅度超过指定像素点大小,则认为是错误检测)
+SubtitleAreaDeviationPixel = 允许的像素偏移量
+SubtitleAreaDeviationPixelDesc = 用于放大mask大小,防止自动检测的文本框过小,inpaint阶段出现文字边,有残留,默认为10
+SubtitleAreaYAxisDifferencePixel = 同行字幕高度差阈值
+SubtitleAreaYAxisDifferencePixelDesc = 同于判断两个文本框是否为同一行字幕,高度差距指定像素点以内认为是同一行,默认为20
+SubtitleAreaPixelToleranceYPixel = Y轴容忍像素偏差
+SubtitleAreaPixelToleranceYPixelDesc = 用于判断两个字幕文本的矩形框是否相似,如果Y轴偏差都在指定阈值内,则认为时同一个文本框,默认为20
+SubtitleAreaPixelToleranceXPixel = X轴容忍像素偏差
+SubtitleAreaPixelToleranceXPixelDesc = 用于判断两个字幕文本的矩形框是否相似,如果X轴偏差都在指定阈值内,则认为时同一个文本框,默认为20
+SubtitleTimelineBackwardFrameCount = 字幕时间轴回退帧数
+SubtitleTimelineBackwardFrameCountDesc = 用于在时间轴上在检测到字幕的基础上往前增加的处理帧数, 增加这个可以处理更加缓慢渐入的字幕, 默认为3帧
+subtitleTimelineForwardFrameCount = 字幕时间轴前进帧数
+subtitleTimelineForwardFrameCountDesc = 用于在时间轴上在检测到字幕的基础上往后增加的处理帧数, 增加这个可以处理更加缓慢渐出的字幕, 默认为3帧
+SttnNeighborStride = 参考帧步长
+SttnNeighborStrideDesc = 默认为5
+SttnReferenceLength = 参考帧数量
+SttnReferenceLengthDesc = 默认为10
+SttnMaxLoadNum = 最大同时处理的帧数量
+SttnMaxLoadNumDesc = 设置越大处理效果越好,但是要求显存越高,默认为50
+PropainterMaxLoadNum = 最大同时处理的帧数量
+PropainterMaxLoadNumDesc = 设置越大处理效果越好,但是要求显存越高,默认为70
+CheckUpdateOnStartup = 在应用程序启动时检查更新
+CheckUpdateOnStartupDesc = 新版本将更加稳定, 并拥有更多功能(建议启用此选项)
+UpdatesAvailableTitle = 有可用更新
+UpdatesAvailableDesc = 发现新版本 {}, 是否更新?
+NoUpdatesAvailableTitle = 无可用更新
+NoUpdatesAvailableDesc = 软件已是最新版本
+
+[SubtitleExtractorGUI]
+Title = 字幕去除器
+Open = 打开
+AllFile = 所有文件
+Vertical = 垂直方向
+Horizontal = 水平方向
+Run = 运行
+Stop = 停止
+Setting = 设置
+OpenVideoSuccess = 成功打开视频
+OpenVideoFailed = 无法打开视频: {}, 格式不兼容或文件损坏
+OpenVideoFirst = 请先打开视频
+SubtitleArea = 字幕区域
+VideoPreview = 视频预览
+InterfaceLanguage = 界面语言
+InpaintMode = 处理模型
+SelectSubtitleArea = 请在视频预览中框选处理区域
+InpaintModeDesc = STTN智能擦除, 对于真人视频效果较好,速度快, 智能擦除(最低4GB显存)
+ STTN字幕检测 带字幕检测版, 无智能擦除(最低4GB显存)
+ LAMA: 对于动画类视频效果好,速度一般(显存要求较低)
+ ProPainter: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好(最低8GB显存)
+ OpenCV: 极速模式, 不保证inpaint效果,仅仅对包含文本的区域文本进行去除(显存要求较低)
+SubtitleDetectMode = 字幕检测
+ErrorDuringProcessing = 处理过程中发生错误: {}
+
+[Main]
+SubtitleDetectionAcceleratorON = 字幕检测使用{}进行加速
+OnnxExectionProviderNotSupportedSkipped = ONNX 执行提供程序: {} 不支持,已跳过。
+OnnxExecutionProviderDetected=检测到 ONNX 执行提供程序: {}
+OnnxRuntimeNotInstall = ONNX 运行环境未安装,已跳过。
+NoSubtitleDetected = 未检测到任何字幕, 请检查文件{}是否正确
+DirectMLWarning = 警告: DirectML 加速仅适用于 STTN 模型,其他模型将使用 CPU 运行。
+ProcessingStartFindingSubtitles = [处理中] 开始查找字幕...
+FinishedFindingSubtitles = [结束] 查找字幕完成...
+ProcessingStartRemovingSubtitles = [处理中] 开始移除字幕...
+UseModel = 去除字幕使用模型: {}
+FullScreenProcessingNote = 未设置字幕区域,将对全屏进行处理,最终效果可能不理想
+ReadFileFailed = 读取文件 {} 失败
+FinishedProcessing = [完成] 字幕移除成功, 文件已保存到 {}
+ProcessingTime = 处理时间: {}秒
+FailToMergeAudio = 合并音频失败: {}
+FailToExtractAudio = 提取音频失败: {}
+CopyFileFailed = 复制文件 {} 到 {} 失败, 原因: {}
+
+[TaskList]
+Pending = 待处理
+Processing = 处理中
+Completed = 已完成
+Failed = 失败
+Name = 名称
+Progress = ⠀进度⠀
+Status = ㅤ状态ㅤ
+OpenSourceVideoLocation = 打开原始视频位置
+OpenTargetVideoLocation = 打开目标视频位置
+ResetTaskStatus = 重置任务状态
+DeleteTask = 删除任务
+Warning = 警告
+UnableToLocateFile = 找不到文件,可能已被移动或删除
+TargetFileNotFound = 文件尚未生成,请先等待任务完成
+
+[VersionService]
+VersionInfo = 当前版本: {} 最新版本: {}
+RequestError = 尝试访问 {} 失败, 原因: {}
+
+[InpaintMode]
+SttnAuto = STTN智能擦除
+SttnDet = STTN字幕检测
+LAMA = LAMA
+ProPainter = ProPainter
+OpenCV = OpenCV
+
+[SubtitleDetectMode]
+Fast = 快速
+Accurate = 精准
+
+[InterfaceLanguage]
+ChineseSimplified = 简体中文
+ChineseTraditional = 繁體中文
+English = English
+Japanese = 日本語
+Korean = 한국어
+Vietnamese = Tiếng Việt
+Español = Español
\ No newline at end of file
diff --git a/backend/interface/chinese_cht.ini b/backend/interface/chinese_cht.ini
new file mode 100644
index 0000000..08f23ba
--- /dev/null
+++ b/backend/interface/chinese_cht.ini
@@ -0,0 +1,129 @@
+[Setting]
+FeedbackButton = 提供反饋
+FeedbackTitle = 提供反饋
+FeedbackDesc = 透過提供反饋幫助我們改進
+CopyrightButton = 檢查更新
+CopyrightTitle = 關於
+CopyrightDesc = ©版權所有 2023, YaoFANGUK, Jason Eric (介面設計), 當前版本: {}
+ProjectLinkTitle = 字幕去除器
+ProjectLinkDesc = 基於AI的圖片/影片硬字幕去除、文字浮水印去除,無損解析度生成去字幕、去浮水印後的圖片/影片檔案。無需申請第三方API,本地實現。
+BasicSetting = 基礎設定
+AdvancedSetting = 進階設定
+SubtitleDetectionSetting = 字幕檢測設定
+SttnSetting = STTN設定
+ProPainterSetting = ProPainter設定
+AboutSetting = 關於
+HardwareAcceleration = 硬體加速
+HardwareAccelerationDesc = 使用GPU或ONNX後端進行加速處理
+SubtitleYXAxisDifferencePixel = 高寬像素差閾值
+SubtitleYXAxisDifferencePixelDesc = 用於判斷是否為非字幕區域,預設為10 (一般認為字幕文字框的長度大於寬度,若字幕框的高大於寬,且幅度超過指定像素點大小,則視為錯誤檢測)
+SubtitleAreaDeviationPixel = 允許的像素偏移量
+SubtitleAreaDeviationPixelDesc = 用於放大遮罩大小,防止自動檢測的文字框過小,inpaint階段出現文字邊緣殘留,預設為10
+SubtitleAreaYAxisDifferencePixel = 同行字幕高度差閾值
+SubtitleAreaYAxisDifferencePixelDesc = 用於判斷兩個文字框是否為同一行字幕,高度差距在指定像素點內視為同一行,預設為20
+SubtitleAreaPixelToleranceYPixel = Y軸容忍像素偏差
+SubtitleAreaPixelToleranceYPixelDesc = 用於判斷兩個字幕文字框是否相似,若Y軸偏差均在指定閾值內,則視為同一個文字框,預設為20
+SubtitleAreaPixelToleranceXPixel = X軸容忍像素偏差
+SubtitleAreaPixelToleranceXPixelDesc = 用於判斷兩個字幕文字框是否相似,若X軸偏差均在指定閾值內,則視為同一個文字框,預設為20
+SubtitleTimelineBackwardFrameCount = 字幕時間軸回退影格數
+SubtitleTimelineBackwardFrameCountDesc = 用於在時間軸上於檢測到字幕的基礎上往前增加的處理影格數,可處理漸入較慢的字幕,預設為3影格
+subtitleTimelineForwardFrameCount = 字幕時間軸前進影格數
+subtitleTimelineForwardFrameCountDesc = 用於在時間軸上於檢測到字幕的基礎上往後增加的處理影格數,可處理漸出較慢的字幕,預設為3影格
+SttnNeighborStride = 參考影格步長
+SttnNeighborStrideDesc = 預設為5
+SttnReferenceLength = 參考影格數量
+SttnReferenceLengthDesc = 預設為10
+SttnMaxLoadNum = 最大同時處理的影格數量
+SttnMaxLoadNumDesc = 數值越大處理效果越好,但需更高顯示記憶體,預設為50
+PropainterMaxLoadNum = 最大同時處理的影格數量
+PropainterMaxLoadNumDesc = 數值越大處理效果越好,但需更高顯示記憶體,預設為70
+CheckUpdateOnStartup = 在應用程式啟動時檢查更新
+CheckUpdateOnStartupDesc = 新版本將更穩定並提供更多功能(建議啟用此選項)
+UpdatesAvailableTitle = 有可用更新
+UpdatesAvailableDesc = 發現新版本 {},是否更新?
+NoUpdatesAvailableTitle = 無可用更新
+NoUpdatesAvailableDesc = 軟體已是最新版本
+
+[SubtitleExtractorGUI]
+Title = 字幕去除器
+Open = 開啟
+AllFile = 所有檔案
+Vertical = 垂直方向
+Horizontal = 水平方向
+Run = 執行
+Stop = 停止
+Setting = 設定
+OpenVideoSuccess = 成功開啟影片
+OpenVideoFailed = 無法開啟影片: {},格式不相容或檔案損毀
+OpenVideoFirst = 請先開啟影片
+SubtitleArea = 字幕區域
+VideoPreview = 影片預覽
+InterfaceLanguage = 介面語言
+InpaintMode = 處理模型
+SelectSubtitleArea = 請在影片預覽中框選處理區域
+InpaintModeDesc = STTN智能擦除,對於真人視頻效果較好,速度快,智能擦除(最低4GB顯存)
+ STTN字幕檢測 帶字幕檢測版,無智能擦除(最低4GB顯存)
+ LAMA:對於動畫類視頻效果好,速度一般(顯存要求較低)
+ ProPainter:需要消耗大量顯存,速度較慢,對運動非常劇烈的視頻效果較好(最低8GB顯存)
+ OpenCV:極速模式,不保證inpaint效果,僅僅對包含文本的區域文本進行去除(顯存要求較低)
+SubtitleDetectMode = 字幕檢測模式
+ErrorDuringProcessing = 處理過程中發生錯誤: {}
+
+[Main]
+SubtitleDetectionAcceleratorON = 字幕檢測使用{}進行加速
+OnnxExectionProviderNotSupportedSkipped = ONNX 執行提供者: {} 不支援,已略過。
+OnnxExecutionProviderDetected = 偵測到 ONNX 執行提供者: {}
+OnnxRuntimeNotInstall = ONNX 執行環境未安裝,已略過。
+NoSubtitleDetected = 未檢測到任何字幕,請檢查檔案{}是否正確
+DirectMLWarning = 警告:DirectML 加速僅適用於 STTN 模型,其他模型將使用 CPU 執行。
+ProcessingStartFindingSubtitles = [處理中] 開始搜尋字幕...
+FinishedFindingSubtitles = [結束] 搜尋字幕完成...
+ProcessingStartRemovingSubtitles = [處理中] 開始移除字幕...
+UseModel = 去除字幕使用模型: {}
+FullScreenProcessingNote = 未設定字幕區域,將對全螢幕進行處理,最終效果可能不理想
+ReadFileFailed = 讀取檔案 {} 失敗
+FinishedProcessing = [完成] 字幕移除成功,檔案已儲存至 {}
+ProcessingTime = 處理時間: {}秒
+FailToMergeAudio = 合併音訊失敗: {}
+FailToExtractAudio = 提取音訊失敗: {}
+CopyFileFailed = 複製檔案 {} 至 {} 失敗,原因: {}
+
+[TaskList]
+Pending = 待處理
+Processing = 處理中
+Completed = 已完成
+Failed = 失敗
+Name = 名稱
+Progress = ⠀進度⠀
+Status = ㅤ狀態ㅤ
+OpenSourceVideoLocation = 開啟原始影片位置
+OpenTargetVideoLocation = 開啟目標影片位置
+ResetTaskStatus = 重設任務狀態
+DeleteTask = 刪除任務
+Warning = 警告
+UnableToLocateFile = 找不到檔案,可能已被移動或刪除
+TargetFileNotFound = 檔案尚未生成,請先等待任務完成
+
+[VersionService]
+VersionInfo = 當前版本: {} 最新版本: {}
+RequestError = 嘗試存取 {} 失敗,原因: {}
+
+[InpaintMode]
+SttnAuto = STTN智慧擦除
+SttnDet = STTN字幕檢測
+LAMA = LAMA
+ProPainter = ProPainter
+OpenCV = OpenCV
+
+[SubtitleDetectMode]
+Fast = 快速
+Accurate = 精準
+
+[InterfaceLanguage]
+ChineseSimplified = 简体中文
+ChineseTraditional = 繁體中文
+English = English
+Japanese = 日本語
+Korean = 한국어
+Vietnamese = Tiếng Việt
+Español = Español
\ No newline at end of file
diff --git a/backend/interface/en.ini b/backend/interface/en.ini
new file mode 100644
index 0000000..417e274
--- /dev/null
+++ b/backend/interface/en.ini
@@ -0,0 +1,129 @@
+[Setting]
+FeedbackButton = Provide Feedback
+FeedbackTitle = Provide Feedback
+FeedbackDesc = Help us improve by providing feedback
+CopyrightButton = Check Updates
+CopyrightTitle = About
+CopyrightDesc = © Copyright 2023, YaoFANGUK, Jason Eric (UI Design), Current Version: {}
+ProjectLinkTitle = Subtitle Remover
+ProjectLinkDesc = AI-based image/video hard subtitle removal and text watermark removal, generating output files with original resolution. No third-party API required, locally implemented.
+BasicSetting = Basic Settings
+AdvancedSetting = Advanced Settings
+SubtitleDetectionSetting = Subtitle Detection Settings
+SttnSetting = STTN Settings
+ProPainterSetting = ProPainter Settings
+AboutSetting = About
+HardwareAcceleration = Hardware Acceleration
+HardwareAccelerationDesc = Accelerate processing using GPU or ONNX backend
+SubtitleYXAxisDifferencePixel = Height-Width Pixel Difference Threshold
+SubtitleYXAxisDifferencePixelDesc = Determines non-subtitle regions (default 10). Subtitles generally have longer text boxes; if height exceeds width beyond this threshold, considered false detection.
+SubtitleAreaDeviationPixel = Allowed Pixel Offset
+SubtitleAreaDeviationPixelDesc = Enlarge mask size to prevent residual text edges during inpaint (default 10).
+SubtitleAreaYAxisDifferencePixel = Same-line Subtitle Height Difference Threshold
+SubtitleAreaYAxisDifferencePixelDesc = Determines if two text boxes belong to the same subtitle line (default 20 pixels).
+SubtitleAreaPixelToleranceYPixel = Y-axis Pixel Tolerance
+SubtitleAreaPixelToleranceYPixelDesc = Determines similarity between subtitle boxes on Y-axis (default 20 pixels).
+SubtitleAreaPixelToleranceXPixel = X-axis Pixel Tolerance
+SubtitleAreaPixelToleranceXPixelDesc = Determines similarity between subtitle boxes on X-axis (default 20 pixels).
+SubtitleTimelineBackwardFrameCount = Timeline Backward Frames
+SubtitleTimelineBackwardFrameCountDesc = Add frames before detected subtitles to handle slow fade-ins (default 3 frames).
+subtitleTimelineForwardFrameCount = Timeline Forward Frames
+subtitleTimelineForwardFrameCountDesc = Add frames after detected subtitles to handle slow fade-outs (default 3 frames).
+SttnNeighborStride = Reference Frame Stride
+SttnNeighborStrideDesc = Default: 5
+SttnReferenceLength = Reference Frame Count
+SttnReferenceLengthDesc = Default: 10
+SttnMaxLoadNum = Max Concurrent Processing Frames
+SttnMaxLoadNumDesc = Higher values improve quality but require more VRAM (default 50).
+PropainterMaxLoadNum = Max Concurrent Processing Frames
+PropainterMaxLoadNumDesc = Higher values improve quality but require more VRAM (default 70).
+CheckUpdateOnStartup = Check Updates on Startup
+CheckUpdateOnStartupDesc = New versions offer improved stability and features (recommended).
+UpdatesAvailableTitle = Update Available
+UpdatesAvailableDesc = New version {} found. Update now?
+NoUpdatesAvailableTitle = No Updates Available
+NoUpdatesAvailableDesc = Software is up-to-date.
+
+[SubtitleExtractorGUI]
+Title = Subtitle Remover
+Open = Open
+AllFile = All Files
+Vertical = Vertical
+Horizontal = Horizontal
+Run = Run
+Stop = Stop
+Setting = Settings
+OpenVideoSuccess = Video opened successfully
+OpenVideoFailed = Failed to open video: {} (invalid format or corrupted file)
+OpenVideoFirst = Please open a video first
+SubtitleArea = Subtitle Area
+VideoPreview = Video Preview
+InterfaceLanguage = Interface Language
+InpaintMode = Processing Model
+SelectSubtitleArea = Select processing area in video preview
+InpaintModeDesc = STTN Smart Inpainting: Best for real-person videos, fast speed, smart inpainting (minimum 4GB VRAM)
+ STTN Subtitle Detection: With subtitle detection, no smart inpainting (minimum 4GB VRAM)
+ LAMA: Good for animation videos, moderate speed (low VRAM requirement)
+ ProPainter: Consumes a lot of VRAM, slower speed, best for videos with intense motion (minimum 8GB VRAM)
+ OpenCV: Ultra-fast mode, inpainting effect not guaranteed, only removes text in detected regions (low VRAM requirement)
+SubtitleDetectMode = Subtitle Detection
+ErrorDuringProcessing = Error during processing: {}
+
+[Main]
+SubtitleDetectionAcceleratorON = Subtitle detection accelerated with {}
+OnnxExectionProviderNotSupportedSkipped = ONNX provider: {} not supported, skipped.
+OnnxExecutionProviderDetected = Detected ONNX provider: {}
+OnnxRuntimeNotInstall = ONNX runtime not installed, skipped.
+NoSubtitleDetected = No subtitles detected. Check file: {}
+DirectMLWarning = Warning: DirectML acceleration only works with STTN model.
+ProcessingStartFindingSubtitles = [Processing] Detecting subtitles...
+FinishedFindingSubtitles = [Complete] Subtitle detection finished.
+ProcessingStartRemovingSubtitles = [Processing] Removing subtitles...
+UseModel = Use model for subtitle removal: {}
+FullScreenProcessingNote = Processing full screen (no area selected). Quality may vary.
+ReadFileFailed = Failed to read file: {}
+FinishedProcessing = [Complete] Subtitles removed. Output saved to: {}
+ProcessingTime = Processing time: {} seconds
+FailToMergeAudio = Audio merge failed: {}
+FailToExtractAudio = Audio extraction failed: {}
+CopyFileFailed = Failed to copy {} to {}. Reason: {}
+
+[TaskList]
+Pending = Pending
+Processing = Processing
+Completed = Completed
+Failed = Failed
+Name = Name
+Progress = ⠀Progress⠀
+Status = ㅤStatusㅤ
+OpenSourceVideoLocation = Open Source File
+OpenTargetVideoLocation = Open Output File
+ResetTaskStatus = Reset Task
+DeleteTask = Delete Task
+Warning = Warning
+UnableToLocateFile = File not found (may be moved/deleted)
+TargetFileNotFound = Output file not generated. Wait for completion.
+
+[VersionService]
+VersionInfo = Current: {} Latest: {}
+RequestError = Failed to access {}. Reason: {}
+
+[InpaintMode]
+SttnAuto = STTN Smart Erase
+SttnDet = STTN Detection
+LAMA = LAMA
+ProPainter = ProPainter
+OpenCV = OpenCV
+
+[SubtitleDetectMode]
+Fast = Fast
+Accurate = Accurate
+
+[InterfaceLanguage]
+ChineseSimplified = 简体中文
+ChineseTraditional = 繁體中文
+English = English
+Japanese = 日本語
+Korean = 한국어
+Vietnamese = Tiếng Việt
+Español = Español
\ No newline at end of file
diff --git a/backend/interface/es.ini b/backend/interface/es.ini
new file mode 100644
index 0000000..e38666e
--- /dev/null
+++ b/backend/interface/es.ini
@@ -0,0 +1,129 @@
+[Setting]
+FeedbackButton = Proporcionar comentarios
+FeedbackTitle = Proporcionar comentarios
+FeedbackDesc = Ayúdanos a mejorar enviando tus comentarios
+CopyrightButton = Buscar actualizaciones
+CopyrightTitle = Acerca de
+CopyrightDesc = © Derechos reservados 2023, YaoFANGUK, Jason Eric (Diseño UI), Versión actual: {}
+ProjectLinkTitle = Eliminador de subtítulos
+ProjectLinkDesc = Eliminación IA de subtítulos duros y marcas de agua en imágenes/videos, genera archivos en resolución original. Implementación local sin APIs externas.
+BasicSetting = Configuración básica
+AdvancedSetting = Configuración avanzada
+SubtitleDetectionSetting = Detección de subtítulos
+SttnSetting = Configuración STTN
+ProPainterSetting = Configuración ProPainter
+AboutSetting = Acerca de
+HardwareAcceleration = Aceleración hardware
+HardwareAccelerationDesc = Usar GPU o backend ONNX para acelerar el procesamiento
+SubtitleYXAxisDifferencePixel = Umbral diferencia alto/ancho
+SubtitleYXAxisDifferencePixelDesc = Determina áreas no subtituladas (valor predeterminado 10). Los subtítulos suelen ser rectángulos horizontales.
+SubtitleAreaDeviationPixel = Margen de píxeles permitido
+SubtitleAreaDeviationPixelDesc = Amplía la máscara para evitar bordes residuales (valor predeterminado 10).
+SubtitleAreaYAxisDifferencePixel = Umbral altura misma línea
+SubtitleAreaYAxisDifferencePixelDesc = Determina si dos cuadros están en la misma línea (valor predeterminado 20 píxeles).
+SubtitleAreaPixelToleranceYPixel = Tolerancia eje Y
+SubtitleAreaPixelToleranceYPixelDesc = Determina similitud vertical entre subtítulos (valor predeterminado 20).
+SubtitleAreaPixelToleranceXPixel = Tolerancia eje X
+SubtitleAreaPixelToleranceXPixelDesc = Determina similitud horizontal entre subtítulos (valor predeterminado 20).
+SubtitleTimelineBackwardFrameCount = Retroceso en línea temporal
+SubtitleTimelineBackwardFrameCountDesc = Añade fotogramas antes de subtítulos detectados (valor predeterminado 3).
+subtitleTimelineForwardFrameCount = Avance en línea temporal
+subtitleTimelineForwardFrameCountDesc = Añade fotogramas después de subtítulos detectados (valor predeterminado 3).
+SttnNeighborStride = Intervalo de referencia
+SttnNeighborStrideDesc = Valor predeterminado: 5
+SttnReferenceLength = Cantidad de referencias
+SttnReferenceLengthDesc = Valor predeterminado: 10
+SttnMaxLoadNum = Máx. fotogramas simultáneos
+SttnMaxLoadNumDesc = Mayor valor mejora calidad pero requiere más VRAM (valor predeterminado 50).
+PropainterMaxLoadNum = Máx. fotogramas simultáneos
+PropainterMaxLoadNumDesc = Mayor valor mejora calidad pero requiere más VRAM (valor predeterminado 70).
+CheckUpdateOnStartup = Buscar actualizaciones al iniciar
+CheckUpdateOnStartupDesc = Versiones nuevas ofrecen mejor estabilidad y funciones (recomendado).
+UpdatesAvailableTitle = Actualización disponible
+UpdatesAvailableDesc = Nueva versión {} disponible. ¿Actualizar ahora?
+NoUpdatesAvailableTitle = Sin actualizaciones
+NoUpdatesAvailableDesc = El software está actualizado.
+
+[SubtitleExtractorGUI]
+Title = Eliminador de subtítulos
+Open = Abrir
+AllFile = Todos los archivos
+Vertical = Vertical
+Horizontal = Horizontal
+Run = Ejecutar
+Stop = Detener
+Setting = Configuración
+OpenVideoSuccess = Video abierto correctamente
+OpenVideoFailed = Error al abrir video: {} (formato incompatible o archivo dañado)
+OpenVideoFirst = Abre un video primero
+SubtitleArea = Área de subtítulos
+VideoPreview = Vista previa
+InterfaceLanguage = Idioma de interfaz
+InpaintMode = Modelo de procesamiento
+SelectSubtitleArea = Selecciona área en vista previa
+InpaintModeDesc = STTN Borrado inteligente: Mejor para videos de personas reales, velocidad rápida, borrado inteligente (mínimo 4GB de VRAM)
+ STTN Detección de subtítulos: Con detección de subtítulos, sin borrado inteligente (mínimo 4GB de VRAM)
+ LAMA: Bueno para videos animados, velocidad media (bajo requerimiento de VRAM)
+ ProPainter: Consume mucha VRAM, velocidad lenta, mejor para videos con mucho movimiento (mínimo 8GB de VRAM)
+ OpenCV: Modo ultra rápido, el efecto de borrado no está garantizado, solo elimina texto en las áreas detectadas (bajo requerimiento de VRAM)
+SubtitleDetectMode = Detección de subtítulos
+ErrorDuringProcessing = Error durante el procesamiento: {}
+
+[Main]
+SubtitleDetectionAcceleratorON = Detección de subtítulos acelerada con {}
+OnnxExectionProviderNotSupportedSkipped = Proveedor ONNX: {} no soportado, omitido.
+OnnxExecutionProviderDetected = Proveedor ONNX detectado: {}
+OnnxRuntimeNotInstall = Entorno ONNX no instalado, omitido.
+NoSubtitleDetected = Sin subtítulos detectados. Verifica archivo: {}
+DirectMLWarning = Advertencia: Aceleración DirectML solo funciona con modelo STTN.
+ProcessingStartFindingSubtitles = [Procesando] Detectando subtítulos...
+FinishedFindingSubtitles = [Completo] Detección finalizada.
+ProcessingStartRemovingSubtitles = [Procesando] Eliminando subtítulos...
+UseModel = Usar modelo para eliminar subtítulos: {}
+FullScreenProcessingNote = Procesando pantalla completa (área no seleccionada).
+ReadFileFailed = Error al leer archivo: {}
+FinishedProcessing = [Completo] Subtítulos eliminados. Guardado en: {}
+ProcessingTime = Tiempo procesamiento: {} segundos
+FailToMergeAudio = Error mezclando audio: {}
+FailToExtractAudio = Error extrayendo audio: {}
+CopyFileFailed = Error copiando {} a {}. Razón: {}
+
+[TaskList]
+Pending = Pendiente
+Processing = Procesando
+Completed = Completado
+Failed = Fallado
+Name = Nombre
+Progress = ⠀Progreso⠀
+Status = ㅤEstadoㅤ
+OpenSourceVideoLocation = Abrir ubicación original
+OpenTargetVideoLocation = Abrir archivo resultante
+ResetTaskStatus = Reiniciar tarea
+DeleteTask = Eliminar tarea
+Warning = Advertencia
+UnableToLocateFile = Archivo no encontrado (posiblemente movido/eliminado)
+TargetFileNotFound = Archivo resultado no generado. Espera a completar.
+
+[VersionService]
+VersionInfo = Versión actual: {} Última versión: {}
+RequestError = Error accediendo {}. Razón: {}
+
+[InpaintMode]
+SttnAuto = STTN borrado inteligente
+SttnDet = STTN detección
+LAMA = LAMA
+ProPainter = ProPainter
+OpenCV = OpenCV
+
+[SubtitleDetectMode]
+Fast = Rápido
+Accurate = Preciso
+
+[InterfaceLanguage]
+ChineseSimplified = 简体中文
+ChineseTraditional = 繁體中文
+English = English
+Japanese = 日本語
+Korean = 한국어
+Vietnamese = Tiếng Việt
+Español = Español
\ No newline at end of file
diff --git a/backend/interface/japan.ini b/backend/interface/japan.ini
new file mode 100644
index 0000000..aa0cd16
--- /dev/null
+++ b/backend/interface/japan.ini
@@ -0,0 +1,129 @@
+[Setting]
+FeedbackButton = フィードバック
+FeedbackTitle = フィードバック提供
+FeedbackDesc = フィードバックを通じて改善にご協力ください
+CopyrightButton = アップデート確認
+CopyrightTitle = バージョン情報
+CopyrightDesc = ©著作権 2023, YaoFANGUK, Jason Eric (UIデザイン), 現在のバージョン: {}
+ProjectLinkTitle = 字幕除去ツール
+ProjectLinkDesc = AIベースの画像/動画ハード字幕除去、テキスト透かし除去。解像度を維持したまま字幕・透かしを除去したファイルを生成。サードパーティAPI不要のローカル実装。
+BasicSetting = 基本設定
+AdvancedSetting = 高度設定
+SubtitleDetectionSetting = 字幕検出設定
+SttnSetting = STTN設定
+ProPainterSetting = ProPainter設定
+AboutSetting = 情報
+HardwareAcceleration = ハードウェアアクセラレーション
+HardwareAccelerationDesc = GPUまたはONNXバックエンドを使用した高速処理
+SubtitleYXAxisDifferencePixel = 高さ/幅ピクセル差しきい値
+SubtitleYXAxisDifferencePixelDesc = 非字幕領域判定基準(デフォルト10)。字幕ボックスは通常幅より高さが短い
+SubtitleAreaDeviationPixel = 許容ピクセル偏差
+SubtitleAreaDeviationPixelDesc = マスクサイズ拡張(デフォルト10)。小さなテキストボックス防止
+SubtitleAreaYAxisDifferencePixel = 同行字幕高さ差しきい値
+SubtitleAreaYAxisDifferencePixelDesc = 同一行字幕判定基準(デフォルト20ピクセル)
+SubtitleAreaPixelToleranceYPixel = Y軸許容偏差
+SubtitleAreaPixelToleranceYPixelDesc = 類似字幕ボックスY軸偏差基準(デフォルト20)
+SubtitleAreaPixelToleranceXPixel = X軸許容偏差
+SubtitleAreaPixelToleranceXPixelDesc = 類似字幕ボックスX軸偏差基準(デフォルト20)
+SubtitleTimelineBackwardFrameCount = 字幕タイムライン後退フレーム数
+SubtitleTimelineBackwardFrameCountDesc = 徐々に出現する字幕処理用追加フレーム(デフォルト3)
+subtitleTimelineForwardFrameCount = 字幕タイムライン前進フレーム数
+subtitleTimelineForwardFrameCountDesc = 徐々に消える字幕処理用追加フレーム(デフォルト3)
+SttnNeighborStride = 参照フレーム間隔
+SttnNeighborStrideDesc = デフォルト: 5
+SttnReferenceLength = 参照フレーム数
+SttnReferenceLengthDesc = デフォルト: 10
+SttnMaxLoadNum = 最大同時処理フレーム数
+SttnMaxLoadNumDesc = 値が大きいほど高品質(VRAM要求増加、デフォルト50)
+PropainterMaxLoadNum = 最大同時処理フレーム数
+PropainterMaxLoadNumDesc = 値が大きいほど高品質(VRAM要求増加、デフォルト70)
+CheckUpdateOnStartup = 起動時アップデート確認
+CheckUpdateOnStartupDesc = 新バージョンは安定性/機能向上(推奨)
+UpdatesAvailableTitle = 利用可能なアップデート
+UpdatesAvailableDesc = 新バージョン {} を発見。更新しますか?
+NoUpdatesAvailableTitle = 利用可能なアップデートなし
+NoUpdatesAvailableDesc = 最新バージョンです
+
+[SubtitleExtractorGUI]
+Title = 字幕除去ツール
+Open = 開く
+AllFile = 全てのファイル
+Vertical = 垂直方向
+Horizontal = 水平方向
+Run = 実行
+Stop = 停止
+Setting = 設定
+OpenVideoSuccess = 動画を正常に開きました
+OpenVideoFailed = 動画を開けません: {}(形式非対応/ファイル破損)
+OpenVideoFirst = 動画を先に開いてください
+SubtitleArea = 字幕領域
+VideoPreview = 動画プレビュー
+InterfaceLanguage = インターフェース言語
+InpaintMode = 処理モデル
+SelectSubtitleArea = プレビューで処理領域を選択
+InpaintModeDesc = STTNスマート消去:実写動画に最適、高速、スマート消去(最低4GB VRAM)
+ STTN字幕検出:字幕検出付き、スマート消去なし(最低4GB VRAM)
+ LAMA:アニメ動画に最適、速度は普通(VRAM要件低め)
+ ProPainter:大量のVRAMを消費、速度は遅い、激しい動きの動画に最適(最低8GB VRAM)
+ OpenCV:超高速モード、消去効果は保証されません、検出されたテキスト領域のみ削除(VRAM要件低め)
+SubtitleDetectMode = 字幕検出
+ErrorDuringProcessing = 処理中にエラーが発生しました: {}
+
+[Main]
+SubtitleDetectionAcceleratorON = 字幕検出を{}で加速
+OnnxExectionProviderNotSupportedSkipped = ONNXプロバイダ: {} 非対応
+OnnxExecutionProviderDetected=ONNXプロバイダ検出: {}
+OnnxRuntimeNotInstall = ONNXランタイム未インストール
+NoSubtitleDetected = 字幕未検出。ファイル確認: {}
+DirectMLWarning = 警告: DirectML加速はSTTNモデルのみ
+ProcessingStartFindingSubtitles = [処理中] 字幕検索開始...
+FinishedFindingSubtitles = [完了] 字幕検索終了
+ProcessingStartRemovingSubtitles = [処理中] 字幕削除開始...
+UseModel = 字幕除去用モデル: {}
+FullScreenProcessingNote = 全画面処理(領域未選択)
+ReadFileFailed = ファイル読み込み失敗: {}
+FinishedProcessing = [完了] 字幕削除成功。保存先: {}
+ProcessingTime = 処理時間: {}秒
+FailToMergeAudio = 音声統合失敗: {}
+FailToExtractAudio = 音声抽出失敗: {}
+CopyFileFailed = ファイルコピー失敗 {} → {}。理由: {}
+
+[TaskList]
+Pending = 待機中
+Processing = 処理中
+Completed = 完了
+Failed = 失敗
+Name = 名称
+Progress = ⠀進捗⠀
+Status = ㅤ状態ㅤ
+OpenSourceVideoLocation = 元動画場所を開く
+OpenTargetVideoLocation = 結果動画場所を開く
+ResetTaskStatus = タスク状態リセット
+DeleteTask = タスク削除
+Warning = 警告
+UnableToLocateFile = ファイルが見つかりません
+TargetFileNotFound = ファイル未生成(処理完了待機)
+
+[VersionService]
+VersionInfo = 現在バージョン: {} 最新バージョン: {}
+RequestError = {} へのアクセス失敗。理由: {}
+
+[InpaintMode]
+SttnAuto = STTNインテリジェント消去
+SttnDet = STTN字幕検出
+LAMA = LAMA
+ProPainter = ProPainter
+OpenCV = OpenCV
+
+[SubtitleDetectMode]
+Fast = 高速
+Accurate = 高精度
+
+[InterfaceLanguage]
+ChineseSimplified = 简体中文
+ChineseTraditional = 繁體中文
+English = English
+Japanese = 日本語
+Korean = 한국어
+Vietnamese = Tiếng Việt
+Español = Español
\ No newline at end of file
diff --git a/backend/interface/ko.ini b/backend/interface/ko.ini
new file mode 100644
index 0000000..abef289
--- /dev/null
+++ b/backend/interface/ko.ini
@@ -0,0 +1,129 @@
+[Setting]
+FeedbackButton = 피드백
+FeedbackTitle = 피드백 제공
+FeedbackDesc = 피드백을 통해 개선에 도움을 주세요
+CopyrightButton = 업데이트 확인
+CopyrightTitle = 정보
+CopyrightDesc = ©저작권 2023, YaoFANGUK, Jason Eric (UI 디자인), 현재 버전: {}
+ProjectLinkTitle = 자막 제거 도구
+ProjectLinkDesc = AI 기반 이미지/동영상 하드 자막 제거, 텍스트 워터마크 제거. 원본 해상도 유지하며 자막 및 워터마크 제거된 파일 생성. 타사 API 불필요, 로컬 처리 구현.
+BasicSetting = 기본 설정
+AdvancedSetting = 고급 설정
+SubtitleDetectionSetting = 자막 감지 설정
+SttnSetting = STTN 설정
+ProPainterSetting = ProPainter 설정
+AboutSetting = 정보
+HardwareAcceleration = 하드웨어 가속
+HardwareAccelerationDesc = GPU 또는 ONNX 백엔드 사용 가속 처리
+SubtitleYXAxisDifferencePixel = 높이-너비 픽셀 차이 임계값
+SubtitleYXAxisDifferencePixelDesc = 비자막 영역 판단 기준 (기본값 10). 자막 상자는 일반적으로 가로 길이가 세로보다 큽니다.
+SubtitleAreaDeviationPixel = 허용 픽셀 편차
+SubtitleAreaDeviationPixelDesc = 마스크 크기 확장 (기본값 10). 작은 텍스트 상자 방지
+SubtitleAreaYAxisDifferencePixel = 동일 행 자막 높이 차 임계값
+SubtitleAreaYAxisDifferencePixelDesc = 동일 행 자막 판단 기준 (기본값 20픽셀)
+SubtitleAreaPixelToleranceYPixel = Y축 허용 편차
+SubtitleAreaPixelToleranceYPixelDesc = 유사 자막 상자 Y축 편차 기준 (기본값 20)
+SubtitleAreaPixelToleranceXPixel = X축 허용 편차
+SubtitleAreaPixelToleranceXPixelDesc = 유사 자막 상자 X축 편차 기준 (기본값 20)
+SubtitleTimelineBackwardFrameCount = 타임라인 역방향 프레임 수
+SubtitleTimelineBackwardFrameCountDesc = 점진적 시작 자막 처리용 추가 프레임 (기본값 3)
+subtitleTimelineForwardFrameCount = 타임라인 순방향 프레임 수
+subtitleTimelineForwardFrameCountDesc = 점진적 종료 자막 처리용 추가 프레임 (기본값 3)
+SttnNeighborStride = 참조 프레임 간격
+SttnNeighborStrideDesc = 기본값: 5
+SttnReferenceLength = 참조 프레임 수
+SttnReferenceLengthDesc = 기본값: 10
+SttnMaxLoadNum = 최대 동시 처리 프레임
+SttnMaxLoadNumDesc = 값 클수록 품질 향상 (VRAM 요구 증가, 기본값 50)
+PropainterMaxLoadNum = 최대 동시 처리 프레임
+PropainterMaxLoadNumDesc = 값 클수록 품질 향상 (VRAM 요구 증가, 기본값 70)
+CheckUpdateOnStartup = 시작시 업데이트 확인
+CheckUpdateOnStartupDesc = 새 버전은 안정성/기능 개선 포함 (권장)
+UpdatesAvailableTitle = 업데이트 가능
+UpdatesAvailableDesc = 새 버전 {} 발견. 업데이트할까요?
+NoUpdatesAvailableTitle = 사용 가능한 업데이트 없음
+NoUpdatesAvailableDesc = 최신 버전입니다.
+
+[SubtitleExtractorGUI]
+Title = 자막 제거 도구
+Open = 열기
+AllFile = 모든 파일
+Vertical = 수직
+Horizontal = 수평
+Run = 실행
+Stop = 중지
+Setting = 설정
+OpenVideoSuccess = 동영상 열기 성공
+OpenVideoFailed = 동영상 열기 실패: {} (형식 불일치/파일 손상)
+OpenVideoFirst = 동영상을 먼저 열어주세요
+SubtitleArea = 자막 영역
+VideoPreview = 동영상 미리보기
+InterfaceLanguage = 인터페이스 언어
+InpaintMode = 처리 모델
+SelectSubtitleArea = 미리보기에서 처리 영역 선택
+InpaintModeDesc = STTN 스마트 지우기: 실제 인물 영상에 적합, 빠른 속도, 스마트 지우기(최소 4GB VRAM)
+ STTN 자막 감지: 자막 감지 버전, 스마트 지우기 없음(최소 4GB VRAM)
+ LAMA: 애니메이션 영상에 적합, 보통 속도(VRAM 요구량 낮음)
+ ProPainter: 많은 VRAM 소모, 느린 속도, 격렬한 움직임 영상에 적합(최소 8GB VRAM)
+ OpenCV: 초고속 모드, 인페인트 효과 보장 안 됨, 텍스트 영역만 제거(VRAM 요구량 낮음)
+SubtitleDetectMode = 자막 감지
+ErrorDuringProcessing = 처리 중 오류: {}
+
+[Main]
+SubtitleDetectionAcceleratorON = 자막 감지 {} 가속 사용
+OnnxExectionProviderNotSupportedSkipped = ONNX 공급자: {} 지원 안됨
+OnnxExecutionProviderDetected=ONNX 공급자 감지: {}
+OnnxRuntimeNotInstall = ONNX 런타임 미설치
+NoSubtitleDetected = 자막 없음. 파일 확인: {}
+DirectMLWarning = 경고: DirectML 가속은 STTN 모델 전용
+ProcessingStartFindingSubtitles = [진행중] 자막 검색 시작...
+FinishedFindingSubtitles = [완료] 자막 검색 완료
+ProcessingStartRemovingSubtitles = [진행중] 자막 제거 시작...
+UseModel = 자막 제거 모델 사용: {}
+FullScreenProcessingNote = 전체 화면 처리 (영역 미선택)
+ReadFileFailed = 파일 읽기 실패: {}
+FinishedProcessing = [완료] 자막 제거 완료. 저장 위치: {}
+ProcessingTime = 처리 시간: {}초
+FailToMergeAudio = 오디오 병합 실패: {}
+FailToExtractAudio = 오디오 추출 실패: {}
+CopyFileFailed = 파일 복사 실패 {} → {}. 이유: {}
+
+[TaskList]
+Pending = 대기중
+Processing = 처리중
+Completed = 완료됨
+Failed = 실패
+Name = 이름
+Progress = ⠀진행률⠀
+Status = ㅤ상태ㅤ
+OpenSourceVideoLocation = 원본 파일 위치 열기
+OpenTargetVideoLocation = 결과 파일 위치 열기
+ResetTaskStatus = 작업 상태 재설정
+DeleteTask = 작업 삭제
+Warning = 경고
+UnableToLocateFile = 파일 찾을 수 없음
+TargetFileNotFound = 결과 파일 미생성 (작업 완료 대기)
+
+[VersionService]
+VersionInfo = 현재 버전: {} 최신 버전: {}
+RequestError = {} 접근 실패. 이유: {}
+
+[InpaintMode]
+SttnAuto = STTN 지능형 제거
+SttnDet = STTN 자막 감지
+LAMA = LAMA
+ProPainter = ProPainter
+OpenCV = OpenCV
+
+[SubtitleDetectMode]
+Fast = 빠름
+Accurate = 정확
+
+[InterfaceLanguage]
+ChineseSimplified = 简体中文
+ChineseTraditional = 繁體中文
+English = English
+Japanese = 日本語
+Korean = 한국어
+Vietnamese = Tiếng Việt
+Español = Español
\ No newline at end of file
diff --git a/backend/interface/vi.ini b/backend/interface/vi.ini
new file mode 100644
index 0000000..894aeed
--- /dev/null
+++ b/backend/interface/vi.ini
@@ -0,0 +1,129 @@
+[Setting]
+FeedbackButton = Gửi phản hồi
+FeedbackTitle = Gửi phản hồi
+FeedbackDesc = Giúp chúng tôi cải thiện bằng cách gửi phản hồi
+CopyrightButton = Kiểm tra cập nhật
+CopyrightTitle = Giới thiệu
+CopyrightDesc = ©Bản quyền 2023, YaoFANGUK, Jason Eric (Thiết kế UI), Phiên bản hiện tại: {}
+ProjectLinkTitle = Công cụ xóa phụ đề
+ProjectLinkDesc = Xóa phụ đề cứng và watermark văn bản từ ảnh/video bằng AI, tạo file đầu ra giữ nguyên độ phân giải. Không cần API bên thứ ba, xử lý cục bộ.
+BasicSetting = Cài đặt cơ bản
+AdvancedSetting = Cài đặt nâng cao
+SubtitleDetectionSetting = Cài đặt phát hiện phụ đề
+SttnSetting = Cài đặt STTN
+ProPainterSetting = Cài đặt ProPainter
+AboutSetting = Giới thiệu
+HardwareAcceleration = Tăng tốc phần cứng
+HardwareAccelerationDesc = Sử dụng GPU hoặc backend ONNX để tăng tốc xử lý
+SubtitleYXAxisDifferencePixel = Ngưỡng chênh lệch chiều cao/rộng
+SubtitleYXAxisDifferencePixelDesc = Xác định vùng không phải phụ đề (mặc định 10). Hộp phụ đề thường có chiều dài lớn hơn chiều rộng.
+SubtitleAreaDeviationPixel = Độ lệch pixel cho phép
+SubtitleAreaDeviationPixelDesc = Mở rộng kích thước mask (mặc định 10), tránh hộp văn bản quá nhỏ
+SubtitleAreaYAxisDifferencePixel = Ngưỡng chênh lệch chiều cao cùng dòng
+SubtitleAreaYAxisDifferencePixelDesc = Xác định phụ đề cùng dòng (mặc định 20 pixel)
+SubtitleAreaPixelToleranceYPixel = Dung sai trục Y
+SubtitleAreaPixelToleranceYPixelDesc = Xác định hộp phụ đề tương tự theo trục Y (mặc định 20)
+SubtitleAreaPixelToleranceXPixel = Dung sai trục X
+SubtitleAreaPixelToleranceXPixelDesc = Xác định hộp phụ đề tương tự theo trục X (mặc định 20)
+SubtitleTimelineBackwardFrameCount = Số khung lùi timeline
+SubtitleTimelineBackwardFrameCountDesc = Thêm khung xử lý cho phụ đề xuất hiện dần (mặc định 3)
+subtitleTimelineForwardFrameCount = Số khung tiến timeline
+subtitleTimelineForwardFrameCountDesc = Thêm khung xử lý cho phụ đề biến mất dần (mặc định 3)
+SttnNeighborStride = Bước khung tham chiếu
+SttnNeighborStrideDesc = Mặc định: 5
+SttnReferenceLength = Số khung tham chiếu
+SttnReferenceLengthDesc = Mặc định: 10
+SttnMaxLoadNum = Số khung xử lý tối đa
+SttnMaxLoadNumDesc = Càng cao càng tốt (yêu cầu nhiều VRAM, mặc định 50)
+PropainterMaxLoadNum = Số khung xử lý tối đa
+PropainterMaxLoadNumDesc = Càng cao càng tốt (yêu cầu nhiều VRAM, mặc định 70)
+CheckUpdateOnStartup = Kiểm tra cập nhật khi khởi động
+CheckUpdateOnStartupDesc = Phiên bản mới ổn định hơn (khuyến nghị bật)
+UpdatesAvailableTitle = Có bản cập nhật
+UpdatesAvailableDesc = Phát hiện phiên bản mới {}, cập nhật?
+NoUpdatesAvailableTitle = Không có cập nhật
+NoUpdatesAvailableDesc = Đang dùng phiên bản mới nhất
+
+[SubtitleExtractorGUI]
+Title = Công cụ xóa phụ đề
+Open = Mở
+AllFile = Tất cả file
+Vertical = Dọc
+Horizontal = Ngang
+Run = Chạy
+Stop = Dừng
+Setting = Cài đặt
+OpenVideoSuccess = Mở video thành công
+OpenVideoFailed = Lỗi mở video: {} (định dạng không hỗ trợ)
+OpenVideoFirst = Vui lòng mở video trước
+SubtitleArea = Vùng phụ đề
+VideoPreview = Xem trước video
+InterfaceLanguage = Ngôn ngữ giao diện
+InpaintMode = Chế độ xử lý
+SelectSubtitleArea = Chọn vùng xử lý trong preview
+InpaintModeDesc = STTN Xóa thông minh: Phù hợp cho video người thật, tốc độ nhanh, xóa thông minh (tối thiểu 4GB VRAM)
+ STTN Phát hiện phụ đề: Có phát hiện phụ đề, không xóa thông minh (tối thiểu 4GB VRAM)
+ LAMA: Phù hợp cho video hoạt hình, tốc độ trung bình (yêu cầu VRAM thấp)
+ ProPainter: Tiêu tốn nhiều VRAM, tốc độ chậm, phù hợp cho video chuyển động mạnh (tối thiểu 8GB VRAM)
+ OpenCV: Chế độ siêu nhanh, không đảm bảo hiệu quả xóa, chỉ xóa vùng chứa văn bản (yêu cầu VRAM thấp)
+SubtitleDetectMode = Chế độ phát hiện
+ErrorDuringProcessing = Lỗi khi xử lý: {}
+
+[Main]
+SubtitleDetectionAcceleratorON = Phát hiện phụ đề được tăng tốc bằng {}
+OnnxExectionProviderNotSupportedSkipped = ONNX provider: {} không hỗ trợ
+OnnxExecutionProviderDetected= Phát hiện ONNX provider: {}
+OnnxRuntimeNotInstall = Chưa cài đặt ONNX Runtime
+NoSubtitleDetected = Không phát hiện phụ đề, kiểm tra file: {}
+DirectMLWarning = Cảnh báo: DirectML chỉ hỗ trợ STTN
+ProcessingStartFindingSubtitles = [Đang xử lý] Bắt đầu tìm phụ đề...
+FinishedFindingSubtitles = [Hoàn thành] Tìm phụ đề xong
+ProcessingStartRemovingSubtitles = [Đang xử lý] Bắt đầu xóa phụ đề...
+UseModel = Sử dụng mô hình xóa phụ đề: {}
+FullScreenProcessingNote = Xử lý toàn màn hình (không chọn vùng)
+ReadFileFailed = Lỗi đọc file: {}
+FinishedProcessing = [Hoàn thành] Xóa phụ đề thành công, lưu tại: {}
+ProcessingTime = Thời gian xử lý: {} giây
+FailToMergeAudio = Lỗi ghép audio: {}
+FailToExtractAudio = Lỗi trích xuất audio: {}
+CopyFileFailed = Lỗi sao chép {} → {}, lý do: {}
+
+[TaskList]
+Pending = Đang chờ
+Processing = Đang xử lý
+Completed = Hoàn thành
+Failed = Thất bại
+Name = Tên
+Progress = ⠀Tiến trình⠀
+Status = ㅤTrạng tháiㅤ
+OpenSourceVideoLocation = Mở vị trí video gốc
+OpenTargetVideoLocation = Mở vị trí file kết quả
+ResetTaskStatus = Đặt lại trạng thái
+DeleteTask = Xóa task
+Warning = Cảnh báo
+UnableToLocateFile = Không tìm thấy file
+TargetFileNotFound = File chưa được tạo
+
+[VersionService]
+VersionInfo = Phiên bản hiện tại: {} Mới nhất: {}
+RequestError = Lỗi truy cập {}, lý do: {}
+
+[InpaintMode]
+SttnAuto = STTN xóa thông minh
+SttnDet = STTN phát hiện
+LAMA = LAMA
+ProPainter = ProPainter
+OpenCV = OpenCV
+
+[SubtitleDetectMode]
+Fast = Nhanh
+Accurate = Chính xác
+
+[InterfaceLanguage]
+ChineseSimplified = 简体中文
+ChineseTraditional = 繁體中文
+English = English
+Japanese = 日本語
+Korean = 한국어
+Vietnamese = Tiếng Việt
+Español = Español
\ No newline at end of file
diff --git a/backend/main.py b/backend/main.py
index eb401d1..ca94c12 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -1,5 +1,7 @@
+import gc
import torch
import shutil
+import traceback
import subprocess
import os
from pathlib import Path
@@ -10,575 +12,43 @@ from functools import cached_property
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-import config
-from backend.tools.common_tools import is_video_or_image, is_image_file
-from backend.scenedetect import scene_detect
-from backend.scenedetect.detectors import ContentDetector
-from backend.inpaint.sttn_inpaint import STTNInpaint, STTNVideoInpaint
+from backend.config import *
+from backend.tools.hardware_accelerator import HardwareAccelerator
+from backend.tools.common_tools import is_video_or_image, is_image_file, get_readable_path, read_image
+from backend.inpaint.sttn_auto_inpaint import STTNAutoInpaint
+from backend.inpaint.sttn_det_inpaint import STTNDetInpaint
from backend.inpaint.lama_inpaint import LamaInpaint
-from backend.inpaint.video_inpaint import VideoInpaint
-from backend.tools.inpaint_tools import create_mask, batch_generator
-import importlib
+from backend.inpaint.opencv_inpaint import OpenCVInpaint
+from backend.inpaint.propainter_inpaint import PropainterInpaint
+from backend.tools.inpaint_tools import create_mask, batch_generator, expand_frame_ranges
+from backend.tools.model_config import ModelConfig
+from backend.tools.ffmpeg_cli import FFmpegCLI
+from backend.tools.subtitle_detect import SubtitleDetect
import platform
import tempfile
import multiprocessing
from shapely.geometry import Polygon
import time
from tqdm import tqdm
-
-
-class SubtitleDetect:
- """
- 文本框检测类,用于检测视频帧中是否存在文本框
- """
-
- def __init__(self, video_path, sub_area=None):
- self.video_path = video_path
- self.sub_area = sub_area
-
- @cached_property
- def text_detector(self):
- import paddle
- paddle.disable_signal_handler()
- from paddleocr.tools.infer import utility
- from paddleocr.tools.infer.predict_det import TextDetector
- # 获取参数对象
- importlib.reload(config)
- args = utility.parse_args()
- args.det_algorithm = 'DB'
- args.det_model_dir = self.convertToOnnxModelIfNeeded(config.DET_MODEL_PATH)
- args.use_onnx=len(config.ONNX_PROVIDERS) > 0
- args.onnx_providers=config.ONNX_PROVIDERS
- return TextDetector(args)
-
- def detect_subtitle(self, img):
- dt_boxes, elapse = self.text_detector(img)
- return dt_boxes, elapse
-
- @staticmethod
- def get_coordinates(dt_box):
- """
- 从返回的检测框中获取坐标
- :param dt_box 检测框返回结果
- :return list 坐标点列表
- """
- coordinate_list = list()
- if isinstance(dt_box, list):
- for i in dt_box:
- i = list(i)
- (x1, y1) = int(i[0][0]), int(i[0][1])
- (x2, y2) = int(i[1][0]), int(i[1][1])
- (x3, y3) = int(i[2][0]), int(i[2][1])
- (x4, y4) = int(i[3][0]), int(i[3][1])
- xmin = max(x1, x4)
- xmax = min(x2, x3)
- ymin = max(y1, y2)
- ymax = min(y3, y4)
- coordinate_list.append((xmin, xmax, ymin, ymax))
- return coordinate_list
-
- def find_subtitle_frame_no(self, sub_remover=None):
- video_cap = cv2.VideoCapture(self.video_path)
- frame_count = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
- tbar = tqdm(total=int(frame_count), unit='frame', position=0, file=sys.__stdout__, desc='Subtitle Finding')
- current_frame_no = 0
- subtitle_frame_no_box_dict = {}
- print('[Processing] start finding subtitles...')
- while video_cap.isOpened():
- ret, frame = video_cap.read()
- # 如果读取视频帧失败(视频读到最后一帧)
- if not ret:
- break
- # 读取视频帧成功
- current_frame_no += 1
- dt_boxes, elapse = self.detect_subtitle(frame)
- coordinate_list = self.get_coordinates(dt_boxes.tolist())
- if coordinate_list:
- temp_list = []
- for coordinate in coordinate_list:
- xmin, xmax, ymin, ymax = coordinate
- if self.sub_area is not None:
- s_ymin, s_ymax, s_xmin, s_xmax = self.sub_area
- if (s_xmin <= xmin and xmax <= s_xmax
- and s_ymin <= ymin
- and ymax <= s_ymax):
- temp_list.append((xmin, xmax, ymin, ymax))
- else:
- temp_list.append((xmin, xmax, ymin, ymax))
- if len(temp_list) > 0:
- subtitle_frame_no_box_dict[current_frame_no] = temp_list
- tbar.update(1)
- if sub_remover:
- sub_remover.progress_total = (100 * float(current_frame_no) / float(frame_count)) // 2
- subtitle_frame_no_box_dict = self.unify_regions(subtitle_frame_no_box_dict)
- # if config.UNITE_COORDINATES:
- # subtitle_frame_no_box_dict = self.get_subtitle_frame_no_box_dict_with_united_coordinates(subtitle_frame_no_box_dict)
- # if sub_remover is not None:
- # try:
- # # 当帧数大于1时,说明并非图片或单帧
- # if sub_remover.frame_count > 1:
- # subtitle_frame_no_box_dict = self.filter_mistake_sub_area(subtitle_frame_no_box_dict,
- # sub_remover.fps)
- # except Exception:
- # pass
- # subtitle_frame_no_box_dict = self.prevent_missed_detection(subtitle_frame_no_box_dict)
- print('[Finished] Finished finding subtitles...')
- new_subtitle_frame_no_box_dict = dict()
- for key in subtitle_frame_no_box_dict.keys():
- if len(subtitle_frame_no_box_dict[key]) > 0:
- new_subtitle_frame_no_box_dict[key] = subtitle_frame_no_box_dict[key]
- return new_subtitle_frame_no_box_dict
-
- def convertToOnnxModelIfNeeded(self, model_dir, model_filename="inference.pdmodel", params_filename="inference.pdiparams", opset_version=14):
- """Converts a Paddle model to ONNX if ONNX providers are available and the model does not already exist."""
-
- if not config.ONNX_PROVIDERS:
- return model_dir
-
- onnx_model_path = os.path.join(model_dir, "model.onnx")
-
- if os.path.exists(onnx_model_path):
- print(f"ONNX model already exists: {onnx_model_path}. Skipping conversion.")
- return onnx_model_path
-
- print(f"Converting Paddle model {model_dir} to ONNX...")
- model_file = os.path.join(model_dir, model_filename)
- params_file = os.path.join(model_dir, params_filename) if params_filename else ""
-
- try:
- import paddle2onnx
- # Ensure the target directory exists
- os.makedirs(os.path.dirname(onnx_model_path), exist_ok=True)
-
- # Convert and save the model
- onnx_model = paddle2onnx.export(
- model_filename=model_file,
- params_filename=params_file,
- save_file=onnx_model_path,
- opset_version=opset_version,
- auto_upgrade_opset=True,
- verbose=True,
- enable_onnx_checker=True,
- enable_experimental_op=True,
- enable_optimize=True,
- custom_op_info={},
- deploy_backend="onnxruntime",
- calibration_file="calibration.cache",
- external_file=os.path.join(model_dir, "external_data"),
- export_fp16_model=False,
- )
-
- print(f"Conversion successful. ONNX model saved to: {onnx_model_path}")
- return onnx_model_path
- except Exception as e:
- print(f"Error during conversion: {e}")
- return model_dir
-
-
- @staticmethod
- def split_range_by_scene(intervals, points):
- # 确保离散值列表是有序的
- points.sort()
- # 用于存储结果区间的列表
- result_intervals = []
- # 遍历区间
- for start, end in intervals:
- # 在当前区间内的点
- current_points = [p for p in points if start <= p <= end]
-
- # 遍历当前区间内的离散点
- for p in current_points:
- # 如果当前离散点不是区间的起始点,添加从区间开始到离散点前一个数字的区间
- if start < p:
- result_intervals.append((start, p - 1))
- # 更新区间开始为当前离散点
- start = p
- # 添加从最后一个离散点或区间开始到区间结束的区间
- result_intervals.append((start, end))
- # 输出结果
- return result_intervals
-
- @staticmethod
- def get_scene_div_frame_no(v_path):
- """
- 获取发生场景切换的帧号
- """
- scene_div_frame_no_list = []
- scene_list = scene_detect(v_path, ContentDetector())
- for scene in scene_list:
- start, end = scene
- if start.frame_num == 0:
- pass
- else:
- scene_div_frame_no_list.append(start.frame_num + 1)
- return scene_div_frame_no_list
-
- @staticmethod
- def are_similar(region1, region2):
- """判断两个区域是否相似。"""
- xmin1, xmax1, ymin1, ymax1 = region1
- xmin2, xmax2, ymin2, ymax2 = region2
-
- return abs(xmin1 - xmin2) <= config.PIXEL_TOLERANCE_X and abs(xmax1 - xmax2) <= config.PIXEL_TOLERANCE_X and \
- abs(ymin1 - ymin2) <= config.PIXEL_TOLERANCE_Y and abs(ymax1 - ymax2) <= config.PIXEL_TOLERANCE_Y
-
- def unify_regions(self, raw_regions):
- """将连续相似的区域统一,保持列表结构。"""
- if len(raw_regions) > 0:
- keys = sorted(raw_regions.keys()) # 对键进行排序以确保它们是连续的
- unified_regions = {}
-
- # 初始化
- last_key = keys[0]
- unify_value_map = {last_key: raw_regions[last_key]}
-
- for key in keys[1:]:
- current_regions = raw_regions[key]
-
- # 新增一个列表来存放匹配过的标准区间
- new_unify_values = []
-
- for idx, region in enumerate(current_regions):
- last_standard_region = unify_value_map[last_key][idx] if idx < len(unify_value_map[last_key]) else None
-
- # 如果当前的区间与前一个键的对应区间相似,我们统一它们
- if last_standard_region and self.are_similar(region, last_standard_region):
- new_unify_values.append(last_standard_region)
- else:
- new_unify_values.append(region)
-
- # 更新unify_value_map为最新的区间值
- unify_value_map[key] = new_unify_values
- last_key = key
-
- # 将最终统一后的结果传递给unified_regions
- for key in keys:
- unified_regions[key] = unify_value_map[key]
- return unified_regions
- else:
- return raw_regions
-
- @staticmethod
- def find_continuous_ranges(subtitle_frame_no_box_dict):
- """
- 获取字幕出现的起始帧号与结束帧号
- """
- numbers = sorted(list(subtitle_frame_no_box_dict.keys()))
- ranges = []
- start = numbers[0] # 初始区间开始值
-
- for i in range(1, len(numbers)):
- # 如果当前数字与前一个数字间隔超过1,
- # 则上一个区间结束,记录当前区间的开始与结束
- if numbers[i] - numbers[i - 1] != 1:
- end = numbers[i - 1] # 则该数字是当前连续区间的终点
- ranges.append((start, end))
- start = numbers[i] # 开始下一个连续区间
- # 添加最后一个区间
- ranges.append((start, numbers[-1]))
- return ranges
-
- @staticmethod
- def find_continuous_ranges_with_same_mask(subtitle_frame_no_box_dict):
- numbers = sorted(list(subtitle_frame_no_box_dict.keys()))
- ranges = []
- start = numbers[0] # 初始区间开始值
- for i in range(1, len(numbers)):
- # 如果当前帧号与前一个帧号间隔超过1,
- # 则上一个区间结束,记录当前区间的开始与结束
- if numbers[i] - numbers[i - 1] != 1:
- end = numbers[i - 1] # 则该数字是当前连续区间的终点
- ranges.append((start, end))
- start = numbers[i] # 开始下一个连续区间
- # 如果当前帧号与前一个帧号间隔为1,且当前帧号对应的坐标点与上一帧号对应的坐标点不一致
- # 记录当前区间的开始与结束
- if numbers[i] - numbers[i - 1] == 1:
- if subtitle_frame_no_box_dict[numbers[i]] != subtitle_frame_no_box_dict[numbers[i - 1]]:
- end = numbers[i - 1] # 则该数字是当前连续区间的终点
- ranges.append((start, end))
- start = numbers[i] # 开始下一个连续区间
- # 添加最后一个区间
- ranges.append((start, numbers[-1]))
- return ranges
-
- @staticmethod
- def sub_area_to_polygon(sub_area):
- """
- xmin, xmax, ymin, ymax = sub_area
- """
- s_xmin = sub_area[0]
- s_xmax = sub_area[1]
- s_ymin = sub_area[2]
- s_ymax = sub_area[3]
- return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]])
-
- @staticmethod
- def expand_and_merge_intervals(intervals, expand_size=config.STTN_NEIGHBOR_STRIDE*config.STTN_REFERENCE_LENGTH, max_length=config.STTN_MAX_LOAD_NUM):
- # 初始化输出区间列表
- expanded_intervals = []
-
- # 对每个原始区间进行扩展
- for interval in intervals:
- start, end = interval
-
- # 扩展至至少 'expand_size' 个单位,但不超过 'max_length' 个单位
- expansion_amount = max(expand_size - (end - start + 1), 0)
-
- # 在保证包含原区间的前提下尽可能平分前后扩展量
- expand_start = max(start - expansion_amount // 2, 1) # 确保起始点不小于1
- expand_end = end + expansion_amount // 2
-
- # 如果扩展后的区间超出了最大长度,进行调整
- if (expand_end - expand_start + 1) > max_length:
- expand_end = expand_start + max_length - 1
-
- # 对于单点的处理,需额外保证有至少 'expand_size' 长度
- if start == end:
- if expand_end - expand_start + 1 < expand_size:
- expand_end = expand_start + expand_size - 1
-
- # 检查与前一个区间是否有重叠并进行相应的合并
- if expanded_intervals and expand_start <= expanded_intervals[-1][1]:
- previous_start, previous_end = expanded_intervals.pop()
- expand_start = previous_start
- expand_end = max(expand_end, previous_end)
-
- # 添加扩展后的区间至结果列表
- expanded_intervals.append((expand_start, expand_end))
-
- return expanded_intervals
-
- @staticmethod
- def filter_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
- """
- 合并传入的字幕起始区间,确保区间大小最低为STTN_REFERENCE_LENGTH
- """
- expanded = []
- # 首先单独处理单点区间以扩展它们
- for start, end in intervals:
- if start == end: # 单点区间
- # 扩展到接近的目标长度,但保证前后不重叠
- prev_end = expanded[-1][1] if expanded else float('-inf')
- next_start = float('inf')
- # 查找下一个区间的起始点
- for ns, ne in intervals:
- if ns > end:
- next_start = ns
- break
- # 确定新的扩展起点和终点
- new_start = max(start - (target_length - 1) // 2, prev_end + 1)
- new_end = min(start + (target_length - 1) // 2, next_start - 1)
- # 如果新的扩展终点在起点前面,说明没有足够空间来进行扩展
- if new_end < new_start:
- new_start, new_end = start, start # 保持原样
- expanded.append((new_start, new_end))
- else:
- # 非单点区间直接保留,稍后处理任何可能的重叠
- expanded.append((start, end))
- # 排序以合并那些因扩展导致重叠的区间
- expanded.sort(key=lambda x: x[0])
- # 合并重叠的区间,但仅当它们之间真正重叠且小于目标长度时
- merged = [expanded[0]]
- for start, end in expanded[1:]:
- last_start, last_end = merged[-1]
- # 检查是否重叠
- if start <= last_end and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
- # 需要合并
- merged[-1] = (last_start, max(last_end, end)) # 合并区间
- elif start == last_end + 1 and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
- # 相邻区间也需要合并的场景
- merged[-1] = (last_start, end)
- else:
- # 如果没有重叠且都大于目标长度,则直接保留
- merged.append((start, end))
- return merged
-
- def compute_iou(self, box1, box2):
- box1_polygon = self.sub_area_to_polygon(box1)
- box2_polygon = self.sub_area_to_polygon(box2)
- intersection = box1_polygon.intersection(box2_polygon)
- if intersection.is_empty:
- return -1
- else:
- union_area = (box1_polygon.area + box2_polygon.area - intersection.area)
- if union_area > 0:
- intersection_area_rate = intersection.area / union_area
- else:
- intersection_area_rate = 0
- return intersection_area_rate
-
- def get_area_max_box_dict(self, sub_frame_no_list_continuous, subtitle_frame_no_box_dict):
- _area_max_box_dict = dict()
- for start_no, end_no in sub_frame_no_list_continuous:
- # 寻找面积最大文本框
- current_no = start_no
- # 查找当前区间矩形框最大面积
- area_max_box_list = []
- while current_no <= end_no:
- for coord in subtitle_frame_no_box_dict[current_no]:
- # 取出每一个文本框坐标
- xmin, xmax, ymin, ymax = coord
- # 计算当前文本框坐标面积
- current_area = abs(xmax - xmin) * abs(ymax - ymin)
- # 如果区间最大框列表为空,则当前面积为区间最大面积
- if len(area_max_box_list) < 1:
- area_max_box_list.append({
- 'area': current_area,
- 'xmin': xmin,
- 'xmax': xmax,
- 'ymin': ymin,
- 'ymax': ymax
- })
- # 如果列表非空,判断当前文本框是与区间最大文本框在同一区域
- else:
- has_same_position = False
- # 遍历每个区间最大文本框,判断当前文本框位置是否与区间最大文本框列表的某个文本框位于同一行且交叉
- for area_max_box in area_max_box_list:
- if (area_max_box['ymin'] - config.THRESHOLD_HEIGHT_DIFFERENCE <= ymin
- and ymax <= area_max_box['ymax'] + config.THRESHOLD_HEIGHT_DIFFERENCE):
- if self.compute_iou((xmin, xmax, ymin, ymax), (
- area_max_box['xmin'], area_max_box['xmax'], area_max_box['ymin'],
- area_max_box['ymax'])) != -1:
- # 如果高度差异不一样
- if abs(abs(area_max_box['ymax'] - area_max_box['ymin']) - abs(
- ymax - ymin)) < config.THRESHOLD_HEIGHT_DIFFERENCE:
- has_same_position = True
- # 如果在同一行,则计算当前面积是不是最大
- # 判断面积大小,若当前面积更大,则将当前行的最大区域坐标点更新
- if has_same_position and current_area > area_max_box['area']:
- area_max_box['area'] = current_area
- area_max_box['xmin'] = xmin
- area_max_box['xmax'] = xmax
- area_max_box['ymin'] = ymin
- area_max_box['ymax'] = ymax
- # 如果遍历了所有的区间最大文本框列表,发现是新的一行,则直接添加
- if not has_same_position:
- new_large_area = {
- 'area': current_area,
- 'xmin': xmin,
- 'xmax': xmax,
- 'ymin': ymin,
- 'ymax': ymax
- }
- if new_large_area not in area_max_box_list:
- area_max_box_list.append(new_large_area)
- break
- current_no += 1
- _area_max_box_list = list()
- for area_max_box in area_max_box_list:
- if area_max_box not in _area_max_box_list:
- _area_max_box_list.append(area_max_box)
- _area_max_box_dict[f'{start_no}->{end_no}'] = _area_max_box_list
- return _area_max_box_dict
-
- def get_subtitle_frame_no_box_dict_with_united_coordinates(self, subtitle_frame_no_box_dict):
- """
- 将多个视频帧的文本区域坐标统一
- """
- subtitle_frame_no_box_dict_with_united_coordinates = dict()
- frame_no_list = self.find_continuous_ranges_with_same_mask(subtitle_frame_no_box_dict)
- area_max_box_dict = self.get_area_max_box_dict(frame_no_list, subtitle_frame_no_box_dict)
- for start_no, end_no in frame_no_list:
- current_no = start_no
- while True:
- area_max_box_list = area_max_box_dict[f'{start_no}->{end_no}']
- current_boxes = subtitle_frame_no_box_dict[current_no]
- new_subtitle_frame_no_box_list = []
- for current_box in current_boxes:
- current_xmin, current_xmax, current_ymin, current_ymax = current_box
- for max_box in area_max_box_list:
- large_xmin = max_box['xmin']
- large_xmax = max_box['xmax']
- large_ymin = max_box['ymin']
- large_ymax = max_box['ymax']
- box1 = (current_xmin, current_xmax, current_ymin, current_ymax)
- box2 = (large_xmin, large_xmax, large_ymin, large_ymax)
- res = self.compute_iou(box1, box2)
- if res != -1:
- new_subtitle_frame_no_box = (large_xmin, large_xmax, large_ymin, large_ymax)
- if new_subtitle_frame_no_box not in new_subtitle_frame_no_box_list:
- new_subtitle_frame_no_box_list.append(new_subtitle_frame_no_box)
- subtitle_frame_no_box_dict_with_united_coordinates[current_no] = new_subtitle_frame_no_box_list
- current_no += 1
- if current_no > end_no:
- break
- return subtitle_frame_no_box_dict_with_united_coordinates
-
- def prevent_missed_detection(self, subtitle_frame_no_box_dict):
- """
- 添加额外的文本框,防止漏检
- """
- frame_no_list = self.find_continuous_ranges_with_same_mask(subtitle_frame_no_box_dict)
- for start_no, end_no in frame_no_list:
- current_no = start_no
- while True:
- current_box_list = subtitle_frame_no_box_dict[current_no]
- if current_no + 1 != end_no and (current_no + 1) in subtitle_frame_no_box_dict.keys():
- next_box_list = subtitle_frame_no_box_dict[current_no + 1]
- if set(current_box_list).issubset(set(next_box_list)):
- subtitle_frame_no_box_dict[current_no] = subtitle_frame_no_box_dict[current_no + 1]
- current_no += 1
- if current_no > end_no:
- break
- return subtitle_frame_no_box_dict
-
- @staticmethod
- def get_frequency_in_range(sub_frame_no_list_continuous, subtitle_frame_no_box_dict):
- sub_area_with_frequency = {}
- for start_no, end_no in sub_frame_no_list_continuous:
- current_no = start_no
- while True:
- current_box_list = subtitle_frame_no_box_dict[current_no]
- for current_box in current_box_list:
- if str(current_box) not in sub_area_with_frequency.keys():
- sub_area_with_frequency[f'{current_box}'] = 1
- else:
- sub_area_with_frequency[f'{current_box}'] += 1
- current_no += 1
- if current_no > end_no:
- break
- return sub_area_with_frequency
-
- def filter_mistake_sub_area(self, subtitle_frame_no_box_dict, fps):
- """
- 过滤错误的字幕区域
- """
- sub_frame_no_list_continuous = self.find_continuous_ranges_with_same_mask(subtitle_frame_no_box_dict)
- sub_area_with_frequency = self.get_frequency_in_range(sub_frame_no_list_continuous, subtitle_frame_no_box_dict)
- correct_sub_area = []
- for sub_area in sub_area_with_frequency.keys():
- if sub_area_with_frequency[sub_area] >= (fps // 2):
- correct_sub_area.append(sub_area)
- else:
- print(f'drop {sub_area}')
- correct_subtitle_frame_no_box_dict = dict()
- for frame_no in subtitle_frame_no_box_dict.keys():
- current_box_list = subtitle_frame_no_box_dict[frame_no]
- new_box_list = []
- for current_box in current_box_list:
- if str(current_box) in correct_sub_area and current_box not in new_box_list:
- new_box_list.append(current_box)
- correct_subtitle_frame_no_box_dict[frame_no] = new_box_list
- return correct_subtitle_frame_no_box_dict
-
+import numpy as np
class SubtitleRemover:
def __init__(self, vd_path, sub_area=None, gui_mode=False):
- importlib.reload(config)
# 线程锁
self.lock = threading.RLock()
# 用户指定的字幕区域位置
self.sub_area = sub_area
# 是否为gui运行,gui运行需要显示预览
self.gui_mode = gui_mode
+ self.hardware_accelerator = HardwareAccelerator.instance()
+ # 是否使用硬件加速
+ self.hardware_accelerator.set_enabled(config.hardwareAcceleration.value)
+ self.model_config = ModelConfig()
# 判断是否为图片
- self.is_picture = False
- if is_image_file(str(vd_path)):
- self.sub_area = None
- self.is_picture = True
+ self.is_picture = is_image_file(str(vd_path))
# 视频路径
self.video_path = vd_path
- self.video_cap = cv2.VideoCapture(vd_path)
+ self.video_cap = cv2.VideoCapture(get_readable_path(vd_path))
# 通过视频路径获取视频名称
self.vd_name = Path(self.video_path).stem
# 视频帧总数
@@ -590,61 +60,27 @@ class SubtitleRemover:
self.mask_size = (int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)))
self.frame_height = int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
self.frame_width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
- # 创建字幕检测对象
- self.sub_detector = SubtitleDetect(self.video_path, self.sub_area)
# 创建视频临时对象,windows下delete=True会有permission denied的报错
self.video_temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
# 创建视频写对象
- self.video_writer = cv2.VideoWriter(self.video_temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size)
- self.video_out_name = os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4')
- self.video_inpaint = None
- self.lama_inpaint = None
+ self.video_writer = cv2.VideoWriter(get_readable_path(self.video_temp_file.name), cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size)
+ self.video_out_path = os.path.abspath(os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4'))
+ self.propainter_inpaint = None
self.ext = os.path.splitext(vd_path)[-1]
if self.is_picture:
pic_dir = os.path.join(os.path.dirname(self.video_path), 'no_sub')
if not os.path.exists(pic_dir):
os.makedirs(pic_dir)
- self.video_out_name = os.path.join(pic_dir, f'{self.vd_name}{self.ext}')
- if torch.cuda.is_available():
- print('use GPU for acceleration')
- if config.USE_DML:
- print('use DirectML for acceleration')
- if config.MODE != config.InpaintMode.STTN:
- print('Warning: DirectML acceleration is only available for STTN model. Falling back to CPU for other models.')
- for provider in config.ONNX_PROVIDERS:
- print(f"Detected execution provider: {provider}")
-
+ self.video_out_path = os.path.join(pic_dir, f'{self.vd_name}{self.ext}')
# 总处理进度
self.progress_total = 0
self.progress_remover = 0
self.isFinished = False
- # 预览帧
- self.preview_frame = None
# 是否将原音频嵌入到去除字幕后的视频
self.is_successful_merged = False
-
- @staticmethod
- def get_coordinates(dt_box):
- """
- 从返回的检测框中获取坐标
- :param dt_box 检测框返回结果
- :return list 坐标点列表
- """
- coordinate_list = list()
- if isinstance(dt_box, list):
- for i in dt_box:
- i = list(i)
- (x1, y1) = int(i[0][0]), int(i[0][1])
- (x2, y2) = int(i[1][0]), int(i[1][1])
- (x3, y3) = int(i[2][0]), int(i[2][1])
- (x4, y4) = int(i[3][0]), int(i[3][1])
- xmin = max(x1, x4)
- xmax = min(x2, x3)
- ymin = max(y1, y2)
- ymax = min(y3, y4)
- coordinate_list.append((xmin, xmax, ymin, ymax))
- return coordinate_list
+ # 进度监听器列表
+ self.progress_listeners = []
@staticmethod
def is_current_frame_no_start(frame_no, continuous_frame_no_list):
@@ -669,18 +105,67 @@ class SubtitleRemover:
def update_progress(self, tbar, increment):
tbar.update(increment)
current_percentage = (tbar.n / tbar.total) * 100
- self.progress_remover = int(current_percentage) // 2
- self.progress_total = 50 + self.progress_remover
+ self.progress_remover = int(current_percentage)
+ self.progress_total = self.progress_remover
+ self.notify_progress_listeners()
+
+ def append_output(self, *args):
+ """输出信息到控制台
+ Args:
+ *args: 要输出的内容,多个参数将用空格连接
+ """
+ print(*args)
+
+ def add_progress_listener(self, listener):
+ """
+ 添加进度监听器
+
+ Args:
+ listener: 一个回调函数,接收参数 (progress_total, isFinished)
+ """
+ if listener not in self.progress_listeners:
+ self.progress_listeners.append(listener)
+
+ def remove_progress_listener(self, listener):
+ """
+ 移除进度监听器
+
+ Args:
+ listener: 要移除的监听器函数
+ """
+ if listener in self.progress_listeners:
+ self.progress_listeners.remove(listener)
+
+ def notify_progress_listeners(self):
+ """
+ 通知所有进度监听器当前进度
+ """
+ for listener in self.progress_listeners:
+ try:
+ listener(self.progress_total, self.isFinished)
+ except Exception as e:
+ traceback.print_exc()
+
+ def update_preview_with_comp(self, frame_ori, frame_comp):
+ """
+ 更新预览
+ """
+ pass
def propainter_mode(self, tbar):
- print('use propainter mode')
- sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
- continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
- scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path)
- continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list,
+ sub_detector = SubtitleDetect(self.video_path, self.sub_area)
+ sub_list = sub_detector.find_subtitle_frame_no(sub_remover=self)
+ if len(sub_list) == 0:
+ raise Exception(tr['Main']['NoSubtitleDetected'].format(self.video_path))
+ continuous_frame_no_list = sub_detector.find_continuous_ranges_with_same_mask(sub_list)
+ scene_div_points = sub_detector.get_scene_div_frame_no(self.video_path)
+ continuous_frame_no_list = sub_detector.split_range_by_scene(continuous_frame_no_list,
scene_div_points)
- self.video_inpaint = VideoInpaint(config.PROPAINTER_MAX_LOAD_NUM)
- print('[Processing] start removing subtitles...')
+ del sub_detector
+ gc.collect()
+ device = self.hardware_accelerator.device if self.hardware_accelerator.has_cuda() else torch.device("cpu")
+ propainter_inpaint = PropainterInpaint(device, self.model_config.PROPAINTER_MODEL_DIR, config.propainterMaxLoadNum.value)
+ self.append_output(tr['Main']['ProcessingStartRemovingSubtitles'])
index = 0
while True:
ret, frame = self.video_cap.read()
@@ -690,22 +175,22 @@ class SubtitleRemover:
# 如果当前帧没有水印/文本则直接写
if index not in sub_list.keys():
self.video_writer.write(frame)
- print(f'write frame: {index}')
+ # self.append_output(f'write frame: {index}')
self.update_progress(tbar, increment=1)
continue
# 如果有水印,判断该帧是不是开头帧
else:
# 如果是开头帧,则批推理到尾帧
if self.is_current_frame_no_start(index, continuous_frame_no_list):
- # print(f'No 1 Current index: {index}')
+ # self.append_output(f'No 1 Current index: {index}')
start_frame_no = index
- print(f'find start: {start_frame_no}')
+ # self.append_output(f'find start: {start_frame_no}')
# 找到结束帧
end_frame_no = self.find_frame_no_end(index, continuous_frame_no_list)
# 判断当前帧号是不是字幕起始位置
# 如果获取的结束帧号不为-1则说明
if end_frame_no != -1:
- print(f'find end: {end_frame_no}')
+ # self.append_output(f'find end: {end_frame_no}')
# ************ 读取该区间所有帧 start ************
temp_frames = list()
# 将头帧加入处理列表
@@ -725,261 +210,262 @@ class SubtitleRemover:
elif len(temp_frames) == 1:
inner_index += 1
single_mask = create_mask(self.mask_size, sub_list[index])
- if self.lama_inpaint is None:
- self.lama_inpaint = LamaInpaint()
- inpainted_frame = self.lama_inpaint(frame, single_mask)
+ inpainted_frame = self.lama_inpaint.inpaint(frame, single_mask)
self.video_writer.write(inpainted_frame)
- print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
+ # self.append_output(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
self.update_progress(tbar, increment=1)
continue
else:
# 将读取的视频帧分批处理
# 1. 获取当前批次使用的mask
mask = create_mask(self.mask_size, sub_list[start_frame_no])
- for batch in batch_generator(temp_frames, config.PROPAINTER_MAX_LOAD_NUM):
+ for batch in batch_generator(temp_frames, config.propainterMaxLoadNum.value):
# 2. 调用批推理
if len(batch) == 1:
single_mask = create_mask(self.mask_size, sub_list[start_frame_no])
- if self.lama_inpaint is None:
- self.lama_inpaint = LamaInpaint()
- inpainted_frame = self.lama_inpaint(frame, single_mask)
+ inpainted_frame = self.lama_inpaint.inpaint(frame, single_mask)
self.video_writer.write(inpainted_frame)
- print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
+ # self.append_output(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
inner_index += 1
self.update_progress(tbar, increment=1)
elif len(batch) > 1:
- inpainted_frames = self.video_inpaint.inpaint(batch, mask)
+ inpainted_frames = propainter_inpaint(batch, mask)
for i, inpainted_frame in enumerate(inpainted_frames):
self.video_writer.write(inpainted_frame)
- print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
+ # self.append_output(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
inner_index += 1
- if self.gui_mode:
- self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
+ self.update_preview_with_comp(np.clip(batch[i]+mask[:,:,np.newaxis]*0.3,0,255).astype(np.uint8), inpainted_frame)
self.update_progress(tbar, increment=len(batch))
- def sttn_mode_with_no_detection(self, tbar):
+ def sttn_auto_mode(self, tbar):
"""
使用sttn对选中区域进行重绘,不进行字幕检测
"""
- print('use sttn mode with no detection')
- print('[Processing] start removing subtitles...')
+ self.append_output(tr['Main']['ProcessingStartRemovingSubtitles'])
if self.sub_area is not None:
ymin, ymax, xmin, xmax = self.sub_area
else:
- print('[Info] No subtitle area has been set. Video will be processed in full screen. As a result, the final outcome might be suboptimal.')
+ self.append_output(tr['Main']['FullScreenProcessingNote'])
ymin, ymax, xmin, xmax = 0, self.frame_height, 0, self.frame_width
mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
mask = create_mask(self.mask_size, mask_area_coordinates)
- sttn_video_inpaint = STTNVideoInpaint(self.video_path)
+ sttn_video_inpaint = STTNAutoInpaint(self.hardware_accelerator.device, self.model_config.STTN_AUTO_MODEL_PATH, self.video_path)
sttn_video_inpaint(input_mask=mask, input_sub_remover=self, tbar=tbar)
- def sttn_mode(self, tbar):
- # 是否跳过字幕帧寻找
- if config.STTN_SKIP_DETECTION:
- # 若跳过则世界使用sttn模式
- self.sttn_mode_with_no_detection(tbar)
- else:
- print('use sttn mode')
- sttn_inpaint = STTNInpaint()
- sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
- continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
- print(continuous_frame_no_list)
- continuous_frame_no_list = self.sub_detector.filter_and_merge_intervals(continuous_frame_no_list)
- print(continuous_frame_no_list)
- start_end_map = dict()
- for interval in continuous_frame_no_list:
- start, end = interval
- start_end_map[start] = end
- current_frame_index = 0
- print('[Processing] start removing subtitles...')
- while True:
- ret, frame = self.video_cap.read()
- # 如果读取到为,则结束
- if not ret:
- break
- current_frame_index += 1
- # 判断当前帧号是不是字幕区间开始, 如果不是,则直接写
- if current_frame_index not in start_end_map.keys():
- self.video_writer.write(frame)
- print(f'write frame: {current_frame_index}')
- self.update_progress(tbar, increment=1)
- if self.gui_mode:
- self.preview_frame = cv2.hconcat([frame, frame])
- # 如果是区间开始,则找到尾巴
- else:
- start_frame_index = current_frame_index
- end_frame_index = start_end_map[current_frame_index]
- print(f'processing frame {start_frame_index} to {end_frame_index}')
- # 用于存储需要去字幕的视频帧
- frames_need_inpaint = list()
- frames_need_inpaint.append(frame)
- inner_index = 0
- # 接着往下读,直到读取到尾巴
- for j in range(end_frame_index - start_frame_index):
- ret, frame = self.video_cap.read()
- if not ret:
- break
- current_frame_index += 1
- frames_need_inpaint.append(frame)
- mask_area_coordinates = []
- # 1. 获取当前批次的mask坐标全集
- for mask_index in range(start_frame_index, end_frame_index):
- if mask_index in sub_list.keys():
- for area in sub_list[mask_index]:
- xmin, xmax, ymin, ymax = area
- # 判断是不是非字幕区域(如果宽大于长,则认为是错误检测)
- if (ymax - ymin) - (xmax - xmin) > config.THRESHOLD_HEIGHT_WIDTH_DIFFERENCE:
- continue
- if area not in mask_area_coordinates:
- mask_area_coordinates.append(area)
- # 1. 获取当前批次使用的mask
- mask = create_mask(self.mask_size, mask_area_coordinates)
- print(f'inpaint with mask: {mask_area_coordinates}')
- for batch in batch_generator(frames_need_inpaint, config.STTN_MAX_LOAD_NUM):
- # 2. 调用批推理
- if len(batch) >= 1:
- inpainted_frames = sttn_inpaint(batch, mask)
- for i, inpainted_frame in enumerate(inpainted_frames):
- self.video_writer.write(inpainted_frame)
- print(f'write frame: {start_frame_index + inner_index} with mask')
- inner_index += 1
- if self.gui_mode:
- self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
- self.update_progress(tbar, increment=len(batch))
-
- def lama_mode(self, tbar):
- print('use lama mode')
- sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
- if self.lama_inpaint is None:
- self.lama_inpaint = LamaInpaint()
- index = 0
- print('[Processing] start removing subtitles...')
+ def video_inpaint(self, tbar, model):
+ sub_detector = SubtitleDetect(self.video_path, self.sub_area)
+ sub_list = sub_detector.find_subtitle_frame_no(sub_remover=self)
+ if len(sub_list) == 0:
+ raise Exception(tr['Main']['NoSubtitleDetected'].format(self.video_path))
+ continuous_frame_no_list = sub_detector.find_continuous_ranges_with_same_mask(sub_list)
+ tbar.write(f"Subtitle detected: {continuous_frame_no_list}")
+ continuous_frame_no_list = expand_frame_ranges(continuous_frame_no_list, config.subtitleTimelineBackwardFrameCount.value, config.subtitleTimelineForwardFrameCount.value)
+ tbar.write(f"Subtitle timeline expand ({config.subtitleTimelineBackwardFrameCount.value} <- -> {config.subtitleTimelineForwardFrameCount.value}): {continuous_frame_no_list}")
+ continuous_frame_no_list = sub_detector.filter_and_merge_intervals(continuous_frame_no_list, config.sttnReferenceLength.value)
+ tbar.write(f'Subtitle filter_and_merge_intervals: {continuous_frame_no_list}')
+ del sub_detector
+ gc.collect()
+ start_end_map = dict()
+ for interval in continuous_frame_no_list:
+ start, end = interval
+ start_end_map[start] = end
+ current_frame_index = 0
+ self.append_output(tr['Main']['ProcessingStartRemovingSubtitles'])
while True:
ret, frame = self.video_cap.read()
+ # 如果读取到为,则结束
if not ret:
break
- original_frame = frame
- index += 1
- if index in sub_list.keys():
- mask = create_mask(self.mask_size, sub_list[index])
- if config.LAMA_SUPER_FAST:
- frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA)
- else:
- frame = self.lama_inpaint(frame, mask)
- if self.gui_mode:
- self.preview_frame = cv2.hconcat([original_frame, frame])
- if self.is_picture:
- cv2.imencode(self.ext, frame)[1].tofile(self.video_out_name)
- else:
+ current_frame_index += 1
+ # 判断当前帧号是不是字幕区间开始, 如果不是,则直接写
+ if current_frame_index not in start_end_map.keys():
self.video_writer.write(frame)
- tbar.update(1)
- self.progress_remover = 100 * float(index) / float(self.frame_count) // 2
- self.progress_total = 50 + self.progress_remover
+ # self.append_output(f'write frame: {current_frame_index}')
+ self.update_progress(tbar, increment=1)
+ self.update_preview_with_comp(frame, frame)
+ # 如果是区间开始,则找到尾巴
+ else:
+ start_frame_index = current_frame_index
+ end_frame_index = start_end_map[current_frame_index]
+ tbar.write(f'processing frame {start_frame_index} to {end_frame_index}')
+ # 用于存储需要去字幕的视频帧
+ frames_need_inpaint = list()
+ frames_need_inpaint.append(frame)
+ inner_index = 0
+ # 接着往下读,直到读取到尾巴
+ for j in range(end_frame_index - start_frame_index):
+ ret, frame = self.video_cap.read()
+ if not ret:
+ break
+ current_frame_index += 1
+ frames_need_inpaint.append(frame)
+ mask_area_coordinates = []
+ # 1. 获取当前批次的mask坐标全集
+ for mask_index in range(start_frame_index, end_frame_index):
+ if mask_index in sub_list.keys():
+ for area in sub_list[mask_index]:
+ xmin, xmax, ymin, ymax = area
+ # 判断是不是非字幕区域(如果宽大于长,则认为是错误检测)
+ if (ymax - ymin) - (xmax - xmin) > config.subtitleYXAxisDifferencePixel.value:
+ continue
+ if area not in mask_area_coordinates:
+ mask_area_coordinates.append(area)
+ # 1. 获取当前批次使用的mask
+ mask = create_mask(self.mask_size, mask_area_coordinates)
+ # self.append_output(f'inpaint with mask: {mask_area_coordinates}')
+ for batch in batch_generator(frames_need_inpaint, config.getSttnMaxLoadNum()):
+ # 2. 调用批推理
+ if len(batch) >= 1:
+ inpainted_frames = model(batch, mask)
+ for i, inpainted_frame in enumerate(inpainted_frames):
+ self.video_writer.write(inpainted_frame)
+ # self.append_output(f'write frame: {start_frame_index + inner_index} with mask')
+ inner_index += 1
+ self.update_preview_with_comp(np.clip(batch[i]+mask[:,:,np.newaxis]*0.3,0,255).astype(np.uint8), inpainted_frame)
+ self.update_progress(tbar, increment=len(batch))
def run(self):
# 记录开始时间
start_time = time.time()
+ # 如果使用GPU加速,则打印GPU加速提示
+ if self.hardware_accelerator.has_accelerator():
+ accelerator_name = self.hardware_accelerator.accelerator_name
+ self.append_output(tr['Main']['SubtitleDetectionAcceleratorON'].format(accelerator_name))
+ if accelerator_name == 'DirectML' and config.inpaintMode.value not in [InpaintMode.STTN_AUTO, InpaintMode.STTN_DET]:
+ self.append_output(tr['Main']['DirectMLWarning'])
# 重置进度条
self.progress_total = 0
tbar = tqdm(total=int(self.frame_count), unit='frame', position=0, file=sys.__stdout__,
desc='Subtitle Removing')
if self.is_picture:
- sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
- self.lama_inpaint = LamaInpaint()
- original_frame = cv2.imread(self.video_path)
+ original_frame = read_image(self.video_path)
+ if original_frame is None:
+ self.append_output(tr['Main']['ReadImageFailed'].format(self.video_path))
+ return
+ sub_detector = SubtitleDetect(self.video_path, self.sub_area)
+ sub_list = sub_detector.detect_subtitle(original_frame)
+ del sub_detector
+ gc.collect()
if len(sub_list):
- mask = create_mask(original_frame.shape[0:2], sub_list[1])
- inpainted_frame = self.lama_inpaint(original_frame, mask)
+ mask = create_mask(original_frame.shape[0:2], sub_list)
+ inpainted_frame = self.lama_inpaint.inpaint(original_frame, mask)
+ self.update_preview_with_comp(np.clip(original_frame+mask[:,:,np.newaxis]*0.3,0,255).astype(np.uint8), inpainted_frame)
else:
inpainted_frame = original_frame
- if self.gui_mode:
- self.preview_frame = cv2.hconcat([original_frame, inpainted_frame])
- cv2.imencode(self.ext, inpainted_frame)[1].tofile(self.video_out_name)
+ self.update_preview_with_comp(original_frame, inpainted_frame)
+ cv2.imencode(self.ext, inpainted_frame)[1].tofile(self.video_out_path)
tbar.update(1)
self.progress_total = 100
else:
# 精准模式下,获取场景分割的帧号,进一步切割
- if config.MODE == config.InpaintMode.PROPAINTER:
+ self.log_model()
+ if config.inpaintMode.value == InpaintMode.PROPAINTER:
self.propainter_mode(tbar)
- elif config.MODE == config.InpaintMode.STTN:
- self.sttn_mode(tbar)
+ elif config.inpaintMode.value == InpaintMode.STTN_AUTO:
+ self.sttn_auto_mode(tbar)
+ elif config.inpaintMode.value == InpaintMode.STTN_DET:
+ self.video_inpaint(tbar, self.sttn_det_inpaint)
+ elif config.inpaintMode.value == InpaintMode.LAMA:
+ self.video_inpaint(tbar, self.lama_inpaint)
+ elif config.inpaintMode.value == InpaintMode.OPENCV:
+ self.video_inpaint(tbar, OpenCVInpaint())
else:
- self.lama_mode(tbar)
+ raise Exception(f'inpaint mode: {config.inpaintMode.value} not implemented')
+
self.video_cap.release()
self.video_writer.release()
if not self.is_picture:
# 将原音频合并到新生成的视频文件中
self.merge_audio_to_video()
- print(f"[Finished]Subtitle successfully removed, video generated at:{self.video_out_name}")
- else:
- print(f"[Finished]Subtitle successfully removed, picture generated at:{self.video_out_name}")
- print(f'time cost: {round(time.time() - start_time, 2)}s')
+ self.append_output(tr['Main']['FinishedProcessing'].format(self.video_out_path))
+ self.append_output(tr['Main']['ProcessingTime'].format(round(time.time() - start_time)))
self.isFinished = True
self.progress_total = 100
if os.path.exists(self.video_temp_file.name):
try:
os.remove(self.video_temp_file.name)
except Exception:
- if platform.system() in ['Windows']:
- pass
- else:
- print(f'failed to delete temp file {self.video_temp_file.name}')
+ pass #ignore
+
+ def log_model(self):
+ model_friendly_name = list(tr['InpaintMode'].values())[list(InpaintMode).index(config.inpaintMode.value)]
+ model_device = 'CPU'
+ if config.inpaintMode.value != InpaintMode.OPENCV and self.hardware_accelerator.has_accelerator():
+ accelerator_name = self.hardware_accelerator.accelerator_name
+ if accelerator_name == 'DirectML' and config.inpaintMode.value in [InpaintMode.STTN_AUTO, InpaintMode.STTN_DET]:
+ model_device = 'DirectML'
+ if self.hardware_accelerator.has_cuda():
+ model_device = accelerator_name
+ self.append_output(tr['Main']['UseModel'].format(f"{model_friendly_name} ({model_device})"))
def merge_audio_to_video(self):
# 创建音频临时对象,windows下delete=True会有permission denied的报错
temp = tempfile.NamedTemporaryFile(suffix='.aac', delete=False)
- audio_extract_command = [config.FFMPEG_PATH,
+ audio_extract_command = [FFmpegCLI.instance().ffmpeg_path,
"-y", "-i", self.video_path,
"-acodec", "copy",
"-vn", "-loglevel", "error", temp.name]
use_shell = True if os.name == "nt" else False
try:
subprocess.check_output(audio_extract_command, stdin=open(os.devnull), shell=use_shell)
- except Exception:
- print('fail to extract audio')
+ except Exception as e:
+ traceback.print_exc()
+ self.append_output(tr['Main']['FailToExtractAudio'].format(str(e)))
return
else:
if os.path.exists(self.video_temp_file.name):
- audio_merge_command = [config.FFMPEG_PATH,
+ audio_merge_command = [FFmpegCLI.instance().ffmpeg_path,
"-y", "-i", self.video_temp_file.name,
"-i", temp.name,
- "-vcodec", "libx264" if config.USE_H264 else "copy",
+ "-vcodec", "copy",
"-acodec", "copy",
- "-loglevel", "error", self.video_out_name]
+ "-loglevel", "error", self.video_out_path]
try:
subprocess.check_output(audio_merge_command, stdin=open(os.devnull), shell=use_shell)
- except Exception:
- print('fail to merge audio')
+ except Exception as e:
+ traceback.print_exc()
+ self.append_output(tr['Main']['FailToMergeAudio'].format(str(e)))
return
if os.path.exists(temp.name):
try:
os.remove(temp.name)
except Exception:
- if platform.system() in ['Windows']:
- pass
- else:
- print(f'failed to delete temp file {temp.name}')
+ #ignore
+ pass
self.is_successful_merged = True
finally:
temp.close()
if not self.is_successful_merged:
try:
- shutil.copy2(self.video_temp_file.name, self.video_out_name)
+ shutil.copy2(self.video_temp_file.name, self.video_out_path)
except IOError as e:
- print("Unable to copy file. %s" % e)
+ self.append_output(tr['Main']['CopyFileFailed'].format(self.video_temp_file.name, self.video_out_path, str(e)))
self.video_temp_file.close()
+ @cached_property
+ def lama_inpaint(self):
+ model_path = os.path.join(self.model_config.LAMA_MODEL_DIR, 'big-lama.pt')
+ device = self.hardware_accelerator.device if self.hardware_accelerator.has_cuda() else torch.device("cpu")
+ return LamaInpaint(device, model_path)
+
+ @cached_property
+ def sttn_det_inpaint(self):
+ return STTNDetInpaint(self.hardware_accelerator.device, self.model_config.STTN_DET_MODEL_PATH)
+
if __name__ == '__main__':
multiprocessing.set_start_method("spawn")
- # 1. 提示用户输入视频路径
- video_path = input(f"Please input video or image file path: ").strip()
- # 判断视频路径是不是一个目录,是目录的化,批量处理改目录下的所有视频文件
- # 2. 按以下顺序传入字幕区域
- # sub_area = (ymin, ymax, xmin, xmax)
- # 3. 新建字幕提取对象
- if is_video_or_image(video_path):
- sd = SubtitleRemover(video_path, sub_area=None)
- sd.run()
- else:
- print(f'Invalid video path: {video_path}')
+ from backend.tools.args_handler import parse_args
+ args = parse_args()
+ sub_area = None if args.ymin is None or args.ymax is None or args.xmin is None or args.xmax is None else (
+ args.ymin, args.ymax, args.xmin, args.xmax)
+
+ print('Subtitle Area:', 'fullscreen' if sub_area is None else sub_area)
+ sr = SubtitleRemover(args.input, sub_area=sub_area)
+ if not is_video_or_image(args.input):
+ sr.append_output(f'Error: {video_path} is not supported not corrupted.')
+ exit(-1)
+ sr.video_out_path = args.output
+ config.inpaintMode.value = args.inpaint_mode
+ sr.run()
+
diff --git a/backend/models/video/ProPainter_1.pth b/backend/models/propainter/ProPainter_1.pth
similarity index 100%
rename from backend/models/video/ProPainter_1.pth
rename to backend/models/propainter/ProPainter_1.pth
diff --git a/backend/models/video/ProPainter_2.pth b/backend/models/propainter/ProPainter_2.pth
similarity index 100%
rename from backend/models/video/ProPainter_2.pth
rename to backend/models/propainter/ProPainter_2.pth
diff --git a/backend/models/video/ProPainter_3.pth b/backend/models/propainter/ProPainter_3.pth
similarity index 100%
rename from backend/models/video/ProPainter_3.pth
rename to backend/models/propainter/ProPainter_3.pth
diff --git a/backend/models/video/ProPainter_4.pth b/backend/models/propainter/ProPainter_4.pth
similarity index 100%
rename from backend/models/video/ProPainter_4.pth
rename to backend/models/propainter/ProPainter_4.pth
diff --git a/backend/models/video/fs_manifest.csv b/backend/models/propainter/fs_manifest.csv
similarity index 100%
rename from backend/models/video/fs_manifest.csv
rename to backend/models/propainter/fs_manifest.csv
diff --git a/backend/models/video/raft-things.pth b/backend/models/propainter/raft-things.pth
similarity index 100%
rename from backend/models/video/raft-things.pth
rename to backend/models/propainter/raft-things.pth
diff --git a/backend/models/video/recurrent_flow_completion.pth b/backend/models/propainter/recurrent_flow_completion.pth
similarity index 100%
rename from backend/models/video/recurrent_flow_completion.pth
rename to backend/models/propainter/recurrent_flow_completion.pth
diff --git a/backend/models/sttn/infer_model.pth b/backend/models/sttn-auto/infer_model.pth
similarity index 100%
rename from backend/models/sttn/infer_model.pth
rename to backend/models/sttn-auto/infer_model.pth
diff --git a/backend/models/sttn-det/sttn.pth b/backend/models/sttn-det/sttn.pth
new file mode 100644
index 0000000..4b9c245
Binary files /dev/null and b/backend/models/sttn-det/sttn.pth differ
diff --git a/backend/tools/args_handler.py b/backend/tools/args_handler.py
new file mode 100644
index 0000000..a9a420d
--- /dev/null
+++ b/backend/tools/args_handler.py
@@ -0,0 +1,41 @@
+import argparse
+from enum import Enum
+
+from .constant import InpaintMode
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Video Subtitle Remover Command Line Tool"
+ )
+ parser.add_argument(
+ "--input", "-i", required=True, type=str,
+ help="Input video file path"
+ )
+ parser.add_argument(
+ "--output", "-o", required=False, type=str, default=None,
+ help="Output video file path (optional)"
+ )
+ parser.add_argument(
+ "--ymin", type=int, default=None,
+ help="Subtitle area ymin (optional)"
+ )
+ parser.add_argument(
+ "--ymax", type=int, default=None,
+ help="Subtitle area ymax (optional)"
+ )
+ parser.add_argument(
+ "--xmin", type=int, default=None,
+ help="Subtitle area xmin (optional)"
+ )
+ parser.add_argument(
+ "--xmax", type=int, default=None,
+ help="Subtitle area xmax (optional)"
+ )
+ parser.add_argument(
+ "--inpaint-mode", type=str, default="sttn-auto",
+ choices=[mode.name.lower().replace('_','-') for mode in InpaintMode],
+ help="Inpaint mode, default is sttn-auto"
+ )
+ args = parser.parse_args()
+ args.inpaint_mode = InpaintMode[args.inpaint_mode.replace('-','_').upper()]
+ return args
\ No newline at end of file
diff --git a/backend/tools/common_tools.py b/backend/tools/common_tools.py
index 54372d5..9e74240 100644
--- a/backend/tools/common_tools.py
+++ b/backend/tools/common_tools.py
@@ -1,4 +1,10 @@
import os
+import sys
+import ctypes
+
+import cv2
+import numpy as np
+from fsplit.filesplit import Filesplit
video_extensions = {
'.mp4', '.m4a', '.m4v', '.f4v', '.f4a', '.m4b', '.m4r', '.f4b', '.mov',
@@ -30,3 +36,24 @@ def is_video_or_image(filename):
file_extension = os.path.splitext(filename)[-1].lower()
# 检查扩展名是否在定义的视频或图片文件后缀集合中
return file_extension in video_extensions or file_extension in image_extensions
+
+def merge_big_file_if_not_exists(dir, file):
+ if file not in os.listdir(dir):
+ fs = Filesplit()
+ fs.merge(input_dir=dir)
+
+def get_readable_path(path):
+ if sys.platform != 'win32':
+ return path
+ buf = ctypes.create_unicode_buffer(4096)
+ ctypes.windll.kernel32.GetShortPathNameW(path, buf, 4096)
+ return buf.value
+
+def read_image(path):
+ if os.path.getsize(path) > 100*1024*1024: # 100MB
+ print(f"Image {path} is too large, skip")
+ return None
+ img = cv2.imdecode(np.fromfile(path, dtype=np.uint8), -1)
+ if img is not None and img.shape[-1] == 4:
+ img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
+ return img
\ No newline at end of file
diff --git a/backend/tools/concurrent/__init__.py b/backend/tools/concurrent/__init__.py
new file mode 100644
index 0000000..d7c1bc7
--- /dev/null
+++ b/backend/tools/concurrent/__init__.py
@@ -0,0 +1 @@
+from .task_manager import TaskExecutor, Future
\ No newline at end of file
diff --git a/backend/tools/concurrent/future.py b/backend/tools/concurrent/future.py
new file mode 100644
index 0000000..e52e107
--- /dev/null
+++ b/backend/tools/concurrent/future.py
@@ -0,0 +1,248 @@
+from typing import List, Optional, Callable, Iterable, Sized, Tuple, Union
+
+from PySide6.QtCore import QObject, Signal, QMutex, QSemaphore
+
+
+class FutureError(BaseException):
+ pass
+
+
+class FutureFailed(FutureError):
+ def __init__(self, _exception: Optional[BaseException]):
+ super().__init__()
+ self.exception = _exception
+
+ def __repr__(self):
+ return f"FutureFailed({self.exception})"
+
+ def __str__(self):
+ return f"FutureFailed({self.exception})"
+
+
+class GatheredFutureFailed(FutureError):
+ def __init__(self, failures: List[Tuple['Future', BaseException]]):
+ super().__init__()
+ self.failures = failures
+
+ def __repr__(self):
+ return f"GatheredFutureFailed({self.failures})"
+
+ def __str__(self):
+ return f"GatheredFutureFailed({self.failures})"
+
+ def __iter__(self):
+ return iter(self.failures)
+
+ def __len__(self):
+ return len(self.failures)
+
+
+class FutureCancelled(FutureError):
+ def __init__(self):
+ super().__init__()
+
+ def __repr__(self):
+ return f"FutureCanceled()"
+
+ def __str__(self):
+ return f"FutureCanceled()"
+
+
+class Future(QObject):
+ result = Signal(object) # self
+ done = Signal(object) # self
+ failed = Signal(object) # self
+ partialDone = Signal(object) # child future
+ childrenDone = Signal(object) # self
+
+ def __init__(self, semaphore=0):
+ super().__init__()
+ self._taskID = None
+ self._failedCallback = lambda e: None
+ self._done = False
+ self._failed = False
+ self._result = None
+ self._exception = None
+ self._children = []
+ self._counter = 0
+ self._parent = None
+ self._callback = lambda _: None
+ self._mutex = QMutex()
+ self._extra = {}
+ self._semaphore = QSemaphore(semaphore)
+
+ def __onChildDone(self, childFuture: 'Future') -> None:
+ self._mutex.lock()
+ if childFuture.isFailed():
+ self._failed = True
+ self._counter += 1
+ self.partialDone.emit(childFuture)
+ try:
+ idx = getattr(childFuture, "_idx")
+ self._result[idx] = childFuture._result
+ self._mutex.unlock()
+ except AttributeError:
+ self._mutex.unlock()
+ raise RuntimeError(
+ "Invalid child future: please ensure that the child future is created by method 'Future.setChildren'")
+
+ if self._counter == len(self._children):
+ if self._failed: # set failed
+ l = []
+
+ for i, child in enumerate(self._children):
+ if isinstance(e := child.getException(), FutureError):
+ l.append((self._children[i], e))
+
+ self.setFailed(GatheredFutureFailed(l))
+ else:
+ self.setResult(self._result)
+
+ def __setChildren(self, children: List['Future']) -> None:
+ self._children = children
+ self._result = [None] * len(children)
+
+ for i, fut in enumerate(self._children):
+ setattr(fut, f"_idx", i)
+ fut.childrenDone.connect(self.__onChildDone)
+ fut._parent = self
+
+ for i, fut in enumerate(self._children): # check if child is done
+ if fut.isDone():
+ self.__onChildDone(fut)
+
+ def setResult(self, result) -> None:
+ """
+ :param result: The result to set
+ :return: None
+
+ do not set result in thread pool,or it may not set correctly
+ please use in main thread,or use signal-slot to set result !!!
+ """
+ if self._done:
+ raise RuntimeError("Future already done")
+
+ self._result = result
+ self._done = True
+
+ if self._parent:
+ self.childrenDone.emit(self)
+
+ if self._callback:
+ self._callback(result)
+
+ self.result.emit(result)
+ self.done.emit(self)
+
+ def setFailed(self, exception) -> None:
+ """
+ :param exception: The exception to set
+ :return: None
+ """
+ if self._done:
+ raise RuntimeError("Future already done")
+
+ self._exception = FutureFailed(exception)
+ self._done = True
+ self._failed = True
+
+ if self._parent:
+ self.childrenDone.emit(self)
+
+ if self._failedCallback:
+ self._failedCallback(self)
+
+ self.failed.emit(self._exception)
+ self.done.emit(self)
+
+ def setCallback(self, callback: Callable[[object, ], None]) -> None:
+ self._callback = callback
+
+ def setFailedCallback(self, callback: Callable[['Future', ], None]) -> None:
+ self._failedCallback = lambda e: callback(self)
+
+ def hasException(self) -> bool:
+ if self._children:
+ return any([fut.hasException() for fut in self._children])
+ else:
+ return self._exception is not None
+
+ def hasChildren(self) -> bool:
+ return bool(self._children)
+
+ def getException(self) -> Optional[BaseException]:
+ return self._exception
+
+ def setTaskID(self, _id: int) -> None:
+ self._taskID = _id
+
+ def getTaskID(self) -> int:
+ return self._taskID
+
+ def getChildren(self) -> List['Future']:
+ return self._children
+
+ @staticmethod
+ def gather(futures: {Iterable, Sized}) -> 'Future':
+ """
+ :param futures: An iterable of Future objects
+ :return: A Future object that will be done when all futures are done
+ """
+ future = Future()
+ future.__setChildren(futures)
+ return future
+
+ @property
+ def semaphore(self):
+ return self._semaphore
+
+ def wait(self):
+ if self.hasChildren():
+ for child in self.getChildren():
+ child.wait()
+ else:
+ self.semaphore.acquire(1)
+
+ def synchronize(self):
+ self.wait()
+
+ def isDone(self) -> bool:
+ return self._done
+
+ def isFailed(self) -> bool:
+ return self._failed
+
+ def getResult(self) -> Union[object, List[object]]:
+ return self._result
+
+ def setExtra(self, key, value):
+ self._extra[key] = value
+
+ def getExtra(self, key):
+ return self._extra.get(key, None)
+
+ def hasExtra(self, key):
+ return key in self._extra
+
+ def then(self, onSuccess: Callable, onFailed: Callable = None, onFinished : Callable = None):
+ self.result.connect(onSuccess)
+
+ if onFailed:
+ self.failed.connect(onFailed)
+
+ if onFinished:
+ self.done.connect(onFinished)
+
+ return self
+
+ def __getattr__(self, item):
+ return self.getExtra(item)
+
+ def __repr__(self):
+ return f"Future:({self._result})"
+
+ def __str__(self):
+ return f"Future({self._result})"
+
+ def __eq__(self, other):
+ return self._result == other._result
\ No newline at end of file
diff --git a/backend/tools/concurrent/task.py b/backend/tools/concurrent/task.py
new file mode 100644
index 0000000..bb42331
--- /dev/null
+++ b/backend/tools/concurrent/task.py
@@ -0,0 +1,48 @@
+import functools
+from typing import Optional
+
+from PySide6.QtCore import QObject, Signal, QRunnable
+
+from .future import Future
+
+
+class WorkerSignal(QObject):
+ finished = Signal(object)
+
+
+class BaseTask(QRunnable):
+ def __init__(self, _id: int, future: Future):
+ super().__init__()
+ self._signal = WorkerSignal() # Signal(object)
+ self._future = future
+ self._id = _id
+ self._exception: Optional[BaseException] = None
+ self._semaphore = future.semaphore
+
+ @property
+ def finished(self):
+ return self._signal.finished
+
+ @property
+ def signal(self):
+ return self._signal
+
+ def _taskDone(self, **data):
+ for d in data.items():
+ self._future.setExtra(*d)
+ self._signal.finished.emit(self._future)
+ self._semaphore.release(1)
+
+
+class Task(BaseTask):
+ def __init__(self, _id: int, future: Future, target: functools.partial, args, kwargs):
+ super().__init__(_id=_id, future=future)
+ self._target = target
+ self._kwargs = kwargs
+ self._args = args
+
+ def run(self) -> None:
+ try:
+ self._taskDone(result=self._target(*self._args, **self._kwargs))
+ except Exception as exception:
+ self._taskDone(exception=exception)
\ No newline at end of file
diff --git a/backend/tools/concurrent/task_manager.py b/backend/tools/concurrent/task_manager.py
new file mode 100644
index 0000000..23943a4
--- /dev/null
+++ b/backend/tools/concurrent/task_manager.py
@@ -0,0 +1,111 @@
+import functools
+import warnings
+from typing import Dict, List, Callable
+
+from PySide6 import QtCore
+from PySide6.QtCore import QThreadPool, QObject, QRunnable
+
+from .future import Future, FutureCancelled
+from .task import BaseTask, Task
+
+
+def cpu_count():
+ return 8
+
+
+class BaseTaskExecutor(QObject):
+ def __init__(self, useGlobalThreadPool=True):
+ super().__init__()
+ self.useGlobalThreadPool = useGlobalThreadPool
+
+ if useGlobalThreadPool:
+ self.threadPool = QThreadPool.globalInstance()
+ else:
+ self.threadPool = QThreadPool()
+ self.threadPool.setMaxThreadCount(2 * cpu_count()) # IO-Bound = 2*N, CPU-Bound = N + 1
+
+ self.taskMap = {}
+ self.tasks: Dict[int, BaseTask] = {}
+ self.taskCounter = 0
+
+ def deleteLater(self) -> None:
+ if not self.useGlobalThreadPool:
+ self.threadPool.clear()
+ self.threadPool.waitForDone()
+ self.threadPool.deleteLater()
+
+ super().deleteLater()
+
+ def _taskRun(self, task: BaseTask, future: Future, **kwargs):
+ self.tasks[self.taskCounter] = task
+ future.setTaskID(self.taskCounter)
+ task.signal.finished.connect(self._taskDone, type=QtCore.Qt.ConnectionType.QueuedConnection)
+ self.threadPool.start(task)
+ self.taskCounter += 1
+
+ def _taskDone(self, fut: Future):
+ """
+ need manually set Future.setFailed() or Future.setResult() to be called!!!
+ """
+ self.tasks.pop(fut.getTaskID())
+ if isinstance(e := fut.getExtra("exception"), Exception):
+ fut.setFailed(e)
+ else:
+ fut.setResult(fut.getExtra("result"))
+
+ def _taskCancel(self, fut: Future):
+ stack: List[Future] = [fut]
+ while stack:
+ f = stack.pop()
+
+ if not f.hasChildren() and not f.isDone():
+ self._taskSingleCancel(f)
+ f.setFailed(FutureCancelled())
+
+ stack.extend(f.getChildren())
+
+ def _taskSingleCancel(self, fut: Future):
+ _id = fut.getTaskID()
+ taskRef: BaseTask = self.tasks[_id]
+
+ if taskRef is not None:
+ try:
+ taskRef.setAutoDelete(False)
+ self.threadPool.cancel(taskRef)
+ taskRef.setAutoDelete(True)
+ except RuntimeError:
+ print("wrapped C/C++ object of type FetchImageTask has been deleted")
+
+ del taskRef
+
+ def cancelTask(self, fut: Future):
+ warnings.warn("BaseTaskExecutor.cancelTask: 目前好像不能正常工作...", DeprecationWarning)
+ self._taskCancel(fut)
+
+
+class TaskExecutor(BaseTaskExecutor):
+
+ globalInstance = None
+
+ def asyncRun(self, target: Callable, *args, **kwargs) -> Future:
+ future = Future()
+ task = Task(
+ _id=self.taskCounter,
+ future=future,
+ target=target if target is functools.partial else functools.partial(target),
+ args=args,
+ kwargs=kwargs
+ )
+ self._taskRun(task, future)
+ return future
+
+ @classmethod
+ def instance(cls):
+ if cls.globalInstance is None:
+ cls.globalInstance = TaskExecutor()
+
+ return cls.globalInstance
+
+ @classmethod
+ def runTask(cls, task: Callable, *args, **kwargs) -> Future:
+ return cls.instance().asyncRun(task, *args, **kwargs)
\ No newline at end of file
diff --git a/backend/tools/constant.py b/backend/tools/constant.py
new file mode 100644
index 0000000..0645bc5
--- /dev/null
+++ b/backend/tools/constant.py
@@ -0,0 +1,20 @@
+from enum import Enum, unique
+
+@unique
+class InpaintMode(Enum):
+ """
+ 图像重绘算法枚举
+ """
+ STTN_AUTO = "sttn-auto"
+ STTN_DET = "sttn-det"
+ LAMA = "lama"
+ PROPAINTER = "propainter"
+ OPENCV = "opencv"
+
+@unique
+class SubtitleDetectMode(Enum):
+ """
+ 字幕检测算法枚举
+ """
+ Fast = 0
+ Accurate = 1
\ No newline at end of file
diff --git a/backend/tools/ffmpeg_cli.py b/backend/tools/ffmpeg_cli.py
new file mode 100644
index 0000000..1c64fec
--- /dev/null
+++ b/backend/tools/ffmpeg_cli.py
@@ -0,0 +1,36 @@
+import os
+import stat
+
+import platform
+from .common_tools import merge_big_file_if_not_exists
+from backend.config import BASE_DIR
+
+class FFmpegCLI:
+
+ """
+ 进程管理器类,用于管理子进程的生命周期
+ 使用弱引用避免内存泄漏
+ """
+ _instance = None
+
+ @classmethod
+ def instance(cls):
+ """单例模式获取实例"""
+ if cls._instance is None:
+ cls._instance = FFmpegCLI()
+ return cls._instance
+
+ def __init__(self):
+ os.chmod(self.ffmpeg_path, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
+
+ @property
+ def ffmpeg_path(self):
+ system = platform.system()
+ if system == "Windows":
+ ffmpeg_dir = os.path.join(BASE_DIR, 'ffmpeg', 'win_x64')
+ merge_big_file_if_not_exists(ffmpeg_dir, 'ffmpeg.exe')
+ return os.path.join(ffmpeg_dir, 'ffmpeg.exe')
+ elif system == "Linux":
+ return os.path.join(BASE_DIR, 'ffmpeg', 'linux_x64', 'ffmpeg')
+ else:
+ return os.path.join(BASE_DIR, 'ffmpeg', 'macos', 'ffmpeg')
\ No newline at end of file
diff --git a/backend/tools/hardware_accelerator.py b/backend/tools/hardware_accelerator.py
new file mode 100644
index 0000000..8cbd666
--- /dev/null
+++ b/backend/tools/hardware_accelerator.py
@@ -0,0 +1,120 @@
+import traceback
+import importlib.util
+
+import torch
+
+from backend.config import tr
+
+class HardwareAccelerator:
+
+ # 类变量,用于存储单例实例
+ _instance = None
+
+ @classmethod
+ def instance(cls):
+ """获取单例实例"""
+ if cls._instance is None:
+ cls._instance = HardwareAccelerator()
+ cls._instance.initialize()
+ return cls._instance
+
+ def __init__(self):
+ self.__cuda = False
+ self.__dml = False
+ self.__onnx_providers = []
+ self.__enabled = True
+ self.__device = None
+
+ def initialize(self):
+ self.check_directml_available()
+ self.check_cuda_available()
+ self.load_onnx_providers()
+
+ def check_directml_available(self):
+ self.__dml = importlib.util.find_spec("torch_directml")
+
+ def check_cuda_available(self):
+ self.__cuda = torch.cuda.is_available()
+
+ def load_onnx_providers(self):
+ try:
+ import onnxruntime as ort
+ available_providers = ort.get_available_providers()
+ for provider in available_providers:
+ if provider in [
+ "CPUExecutionProvider"
+ ]:
+ continue
+ if provider not in [
+ "DmlExecutionProvider", # DirectML,适用于 Windows GPU
+ "ROCMExecutionProvider", # AMD ROCm
+ "MIGraphXExecutionProvider", # AMD MIGraphX
+ "VitisAIExecutionProvider", # AMD VitisAI,适用于 RyzenAI & Windows, 实测和DirectML性能似乎差不多
+ "OpenVINOExecutionProvider", # Intel GPU
+ "MetalExecutionProvider", # Apple macOS
+ "CoreMLExecutionProvider", # Apple macOS
+ "CUDAExecutionProvider", # Nvidia GPU
+ ]:
+ print(tr['Main']['OnnxExectionProviderNotSupportedSkipped'].format(provider))
+ continue
+ print(tr['Main']['OnnxExecutionProviderDetected'].format(provider))
+ self.__onnx_providers.append(provider)
+ except ModuleNotFoundError as e:
+ print(tr['Main']['OnnxRuntimeNotInstall'])
+
+ def has_accelerator(self):
+ if not self.__enabled:
+ return False
+ return self.__cuda or self.__dml or len(self.__onnx_providers) > 0
+
+ @property
+ def accelerator_name(self):
+ if not self.__enabled:
+ return "CPU"
+ if self.__dml:
+ return "DirectML"
+ if self.__cuda:
+ return "GPU"
+ elif len(self.__onnx_providers) > 0:
+ return ", ".join(self.__onnx_providers)
+ else:
+ return "CPU"
+
+ @property
+ def onnx_providers(self):
+ if not self.__enabled:
+ return []
+ return self.__onnx_providers
+
+ def has_cuda(self):
+ if not self.__enabled:
+ return False
+ return self.__cuda
+
+ def set_enabled(self, enable):
+ self.__enabled = enable
+
+ @property
+ def device(self):
+ """
+ onnxruntime-directml 1.21.1-1.22.0(往上未测试) 和 torch-directml 不能同时初始化, 会相互影响
+ 提示site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 266, in run
+ return self._sess.run(output_names, input_feed, run_options)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ UnicodeDecodeError: 'utf-8' codec can't decode byte 0xb2 in position 344: invalid start bn 344: invalid start byte
+ onnxruntime-directml 1.21.1 则正常, 但Win10跑不起来, Win11正常
+ 为了避免冲突以及避免重写一个QPT智能部署流程, 这里采用延迟初始化的方式+继续使用onnxruntime-directml 1.20.1
+ 当然SubtitleDetect放到一个独立进程去操作也是可以的
+ """
+ if self.__enabled:
+ if self.__dml:
+ try:
+ import torch_directml
+ return torch_directml.device(torch_directml.default_device())
+ self.__dml = True
+ except:
+ traceback.print_exc()
+ self.__dml = False
+ if self.__cuda:
+ return torch.device("cuda:0")
+ return torch.device("cpu")
\ No newline at end of file
diff --git a/backend/tools/inpaint_tools.py b/backend/tools/inpaint_tools.py
index a727794..27116a5 100644
--- a/backend/tools/inpaint_tools.py
+++ b/backend/tools/inpaint_tools.py
@@ -2,9 +2,7 @@ import multiprocessing
import cv2
import numpy as np
-from backend import config
-from backend.inpaint.lama_inpaint import LamaInpaint
-
+from backend.config import config
def batch_generator(data, max_batch_size):
"""
@@ -30,88 +28,277 @@ def batch_generator(data, max_batch_size):
if last_batch_start < n_samples:
yield data[last_batch_start:]
-
-def inference_task(batch_data):
- inpainted_frame_dict = dict()
- for data in batch_data:
- index, original_frame, coords_list = data
- mask_size = original_frame.shape[:2]
- mask = create_mask(mask_size, coords_list)
- inpaint_frame = inpaint(original_frame, mask)
- inpainted_frame_dict[index] = inpaint_frame
- return inpainted_frame_dict
-
-
-def parallel_inference(inputs, batch_size=None, pool_size=None):
- """
- 并行推理,同时保持结果顺序
- """
- if pool_size is None:
- pool_size = multiprocessing.cpu_count()
- # 使用上下文管理器自动管理进程池
- with multiprocessing.Pool(processes=pool_size) as pool:
- batched_inputs = list(batch_generator(inputs, batch_size))
- # 使用map函数保证输入输出的顺序是一致的
- batch_results = pool.map(inference_task, batched_inputs)
- # 将批推理结果展平
- index_inpainted_frames = [item for sublist in batch_results for item in sublist]
- return index_inpainted_frames
-
-
-def inpaint(img, mask):
- lama_inpaint_instance = LamaInpaint()
- img_inpainted = lama_inpaint_instance(img, mask)
- return img_inpainted
-
-
-def inpaint_with_multiple_masks(censored_img, mask_list):
- inpainted_frame = censored_img
- if mask_list:
- for mask in mask_list:
- inpainted_frame = inpaint(inpainted_frame, mask)
- return inpainted_frame
-
-
def create_mask(size, coords_list):
mask = np.zeros(size, dtype="uint8")
if coords_list:
for coords in coords_list:
xmin, xmax, ymin, ymax = coords
# 为了避免框过小,放大10个像素
- x1 = xmin - config.SUBTITLE_AREA_DEVIATION_PIXEL
+ x1 = xmin - config.subtitleAreaDeviationPixel.value
if x1 < 0:
x1 = 0
- y1 = ymin - config.SUBTITLE_AREA_DEVIATION_PIXEL
+ y1 = ymin - config.subtitleAreaDeviationPixel.value
if y1 < 0:
y1 = 0
- x2 = xmax + config.SUBTITLE_AREA_DEVIATION_PIXEL
- y2 = ymax + config.SUBTITLE_AREA_DEVIATION_PIXEL
+ x2 = xmax + config.subtitleAreaDeviationPixel.value
+ y2 = ymax + config.subtitleAreaDeviationPixel.value
cv2.rectangle(mask, (x1, y1),
(x2, y2), (255, 255, 255), thickness=-1)
return mask
-
-def inpaint_video(video_path, sub_list):
- index = 0
- frame_to_inpaint_list = []
- video_cap = cv2.VideoCapture(video_path)
- while True:
- # 读取视频帧
- ret, frame = video_cap.read()
- if not ret:
- break
- index += 1
- if index in sub_list.keys():
- frame_to_inpaint_list.append((index, frame, sub_list[index]))
- if len(frame_to_inpaint_list) > config.PROPAINTER_MAX_LOAD_NUM:
- batch_results = parallel_inference(frame_to_inpaint_list)
- for index, frame in batch_results:
- file_name = f'/home/yao/Documents/Project/video-subtitle-remover/test/temp/{index}.png'
- cv2.imwrite(file_name, frame)
- print(f"success write: {file_name}")
- frame_to_inpaint_list.clear()
- print(f'finished')
-
+def get_inpaint_area_by_mask(W, H, h, mask, multiple=1):
+ """
+ 获取字幕去除区域,根据mask来确定需要填补的区域和高度,
+ 并根据模型要求调整区域大小为指定倍数
+
+ Args:
+ W: 图像宽度
+ H: 图像高度
+ h: 检测区域高度
+ mask: 遮罩图像
+ multiple: 区域尺寸需要满足的倍数,默认为1
+
+ Returns:
+ 调整后的绘画区域列表,格式为[(ymin, ymax, xmin, xmax), ...]
+ """
+ # 存储绘画区域的列表
+ inpaint_area = []
+
+ # 如果mask全为0,直接返回空列表
+ if np.all(mask == 0):
+ return inpaint_area
+
+ # 使用连通组件分析找出mask中的所有孤岛
+ # 首先确保mask是二值图像
+ binary_mask = (mask > 0).astype(np.uint8) * 255
+
+ # 查找连通组件
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)
+
+ # 跳过背景(标签0)
+ island_info = []
+ for i in range(1, num_labels):
+ # 获取当前孤岛的统计信息
+ x = stats[i, cv2.CC_STAT_LEFT]
+ y = stats[i, cv2.CC_STAT_TOP]
+ w = stats[i, cv2.CC_STAT_WIDTH]
+ height = stats[i, cv2.CC_STAT_HEIGHT]
+ area = stats[i, cv2.CC_STAT_AREA]
+
+ # 忽略太小的区域(可能是噪点)
+ if area < 10:
+ continue
+
+ # 保存孤岛信息:顶部y坐标,底部y坐标,中心点y坐标,面积,标签
+ center_y = int(centroids[i][1])
+ island_info.append((y, y + height, center_y, area, i))
+
+ # 如果没有有效孤岛,返回空列表
+ if not island_info:
+ return inpaint_area
+
+ # 按中心点y坐标排序孤岛
+ island_info.sort(key=lambda x: x[2])
+
+ # 尝试合并孤岛
+ merged_islands = []
+ current_group = [island_info[0]]
+
+ for i in range(1, len(island_info)):
+ # 当前组的范围
+ min_y = min([island[0] for island in current_group])
+ max_y = max([island[1] for island in current_group])
+
+ # 当前孤岛
+ top_y, bottom_y, center_y, _, _ = island_info[i]
+
+ # 计算如果添加当前孤岛,新组的范围
+ new_min_y = min(min_y, top_y)
+ new_max_y = max(max_y, bottom_y)
+
+ # 检查是否有mask连接当前组和新孤岛
+ has_connection = False
+ if max_y < top_y: # 只有当前组在新孤岛上方时才需要检查连接
+ # 检查两个区域之间是否有mask像素
+ middle_region = binary_mask[max_y:top_y, :]
+ if np.any(middle_region > 0):
+ has_connection = True
+ else: # 重叠或相邻
+ has_connection = True
+
+ # 检查合并后的高度是否在h范围内,并且有连接
+ if new_max_y - new_min_y <= h and has_connection:
+ # 可以合并
+ current_group.append(island_info[i])
+ else:
+ # 无法合并,保存当前组并开始新组
+ merged_islands.append(current_group)
+ current_group = [island_info[i]]
+
+ # 添加最后一个组
+ merged_islands.append(current_group)
+
+ # 为每个合并后的组创建区域
+ for group in merged_islands:
+ # 获取组内所有孤岛的范围
+ min_y = min([island[0] for island in group])
+ max_y = max([island[1] for island in group])
+
+ # 计算组的中心点
+ center_y = sum([island[2] for island in group]) // len(group)
+
+ # 确保区域高度精确等于h
+ half_h = h // 2
+
+ # 从中心点向上下扩展,确保高度为h
+ ymin = max(0, center_y - half_h)
+ ymax = ymin + h # 确保高度精确等于h
+
+ # 如果超出图像底部,从底部向上调整
+ if ymax > H:
+ ymax = H
+ ymin = max(0, H - h) # 确保高度为h
+
+ # 检查是否包含了所有孤岛
+ if ymin > min_y or ymax < max_y:
+ # 如果区域不能完全包含所有孤岛,尝试调整位置但保持高度为h
+ if max_y - min_y <= h:
+ # 孤岛总高度不超过h,可以调整位置使其完全包含
+ ymin = min_y
+ ymax = ymin + h
+ # 如果超出底部,从底部向上调整
+ if ymax > H:
+ ymax = H
+ ymin = max(0, H - h)
+ else:
+ # 孤岛总高度超过h,无法完全包含,优先包含中心区域
+ # 计算孤岛的中心
+ island_center = (min_y + max_y) // 2
+ ymin = max(0, island_center - half_h)
+ ymax = ymin + h
+ # 如果超出底部,从底部向上调整
+ if ymax > H:
+ ymax = H
+ ymin = max(0, H - h)
+
+ # 使用完整宽度
+ xmin = 0
+ xmax = W
+
+ # 调整区域大小为指定倍数
+ if multiple > 1:
+ # 计算区域高度
+ height = ymax - ymin
+ # 计算需要调整的高度,使其成为multiple的倍数
+ remainder = height % multiple
+
+ if remainder != 0:
+ # 需要调整的像素数
+ adjust_pixels = multiple - remainder
+
+ # 计算区域中心点
+ center_y = (ymin + ymax) / 2
+
+ # 优先对称扩展
+ if ymin - adjust_pixels/2 >= 0 and ymax + adjust_pixels/2 <= H:
+ # 对称扩展
+ ymin = int(center_y - height/2 - adjust_pixels/2)
+ ymax = int(center_y + height/2 + adjust_pixels/2)
+ # 如果对称扩展会超出边界,尝试对称缩小
+ elif height > multiple: # 确保缩小后高度至少为multiple
+ # 对称缩小
+ ymin = int(center_y - (height - remainder)/2)
+ ymax = int(center_y + (height - remainder)/2)
+ # 如果无法对称调整,则尝试单边调整
+ else:
+ # 向下扩展
+ if ymax + adjust_pixels <= H:
+ ymax += adjust_pixels
+ # 向上扩展
+ elif ymin - adjust_pixels >= 0:
+ ymin -= adjust_pixels
+ # 如果都不行,则尝试缩小区域
+ elif height > multiple:
+ ymax = ymin + height - remainder
+
+ # 调整宽度,确保是multiple的倍数
+ width = xmax - xmin
+ remainder_w = width % multiple
+
+ if remainder_w != 0:
+ # 需要调整的像素数
+ adjust_pixels_w = multiple - remainder_w
+
+ # 计算中心点,对称缩小
+ center_x = (xmin + xmax) / 2
+ xmin = int(center_x - (width - remainder_w)/2)
+ xmax = int(center_x + (width - remainder_w)/2)
+
+ # 将该区域添加到列表中,格式为(ymin, ymax, xmin, xmax)
+ area = (int(ymin), int(ymax), int(xmin), int(xmax))
+ if area not in inpaint_area:
+ inpaint_area.append(area)
+
+ return inpaint_area # 返回绘画区域列表,格式为[(ymin, ymax, xmin, xmax), ...]
+
+def expand_frame_ranges(frame_ranges, backward_frame_count, forward_frame_count):
+ """
+ 扩展帧区间列表,向前和向后扩展指定的帧数,并确保区间连续性
+
+ Args:
+ frame_ranges: 帧区间列表,格式为[(start1, end1), (start2, end2), ...]
+ backward_frame_count: 向前扩展的帧数
+ forward_frame_count: 向后扩展的帧数
+
+ Returns:
+ 扩展后的帧区间列表,保证连续性
+ """
+ if not frame_ranges:
+ return []
+
+ # 按起始帧排序
+ sorted_ranges = sorted(frame_ranges)
+ expanded_ranges = []
+
+ for i, (start, end) in enumerate(sorted_ranges):
+ # 向前扩展,但不能小于1
+ new_start = max(1, start - backward_frame_count)
+
+ # 向后扩展
+ new_end = end + forward_frame_count
+
+ # 检查是否与下一个区间重叠
+ if i < len(sorted_ranges) - 1:
+ next_start = sorted_ranges[i + 1][0]
+
+ # 如果扩展后的结束帧超过了下一个区间的起始帧
+ if new_end >= next_start:
+ # 计算中点
+ mid_point = (end + next_start) // 2
+
+ # 如果区间是连续的(相差1),则对半平分
+ if next_start - end == 1:
+ new_end = end # 保持原结束帧
+ else:
+ # 非连续区间,限制扩展到下一个区间起始帧减去backward_frame_count
+ max_expand = next_start - 1 # 确保不会与下一个区间重叠
+ new_end = min(new_end, max_expand)
+
+ # 确保与前一个区间不重叠
+ if expanded_ranges:
+ prev_end = expanded_ranges[-1][1]
+ if new_start <= prev_end:
+ # 如果新区间的开始小于等于前一个区间的结束,调整开始位置
+ new_start = prev_end + 1
+
+ # 确保区间有效(开始不大于结束)
+ if new_start <= new_end:
+ expanded_ranges.append((new_start, new_end))
+ else:
+ # 如果调整后区间无效,保留原始区间
+ expanded_ranges.append((start, end))
+
+ return expanded_ranges
if __name__ == '__main__':
multiprocessing.set_start_method("spawn")
diff --git a/backend/tools/model_config.py b/backend/tools/model_config.py
new file mode 100644
index 0000000..641cf1f
--- /dev/null
+++ b/backend/tools/model_config.py
@@ -0,0 +1,63 @@
+import os
+from backend.config import config, BASE_DIR
+from backend.tools.common_tools import merge_big_file_if_not_exists
+from backend.tools.constant import SubtitleDetectMode
+
+class ModelConfig:
+ def __init__(self):
+ self.LAMA_MODEL_DIR = os.path.join(BASE_DIR, 'models', 'big-lama')
+ self.STTN_AUTO_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn-auto', 'infer_model.pth')
+ self.STTN_DET_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn-det', 'sttn.pth')
+ self.PROPAINTER_MODEL_DIR = os.path.join(BASE_DIR,'models', 'propainter')
+ if config.subtitleDetectMode.value == SubtitleDetectMode.Fast:
+ self.DET_MODEL_DIR = os.path.join(BASE_DIR,'models', 'V4', 'ch_det_fast')
+ elif config.subtitleDetectMode.value == SubtitleDetectMode.Accurate:
+ self.DET_MODEL_DIR = os.path.join(BASE_DIR, 'models', 'V4', 'ch_det')
+ else:
+ raise ValueError(f"Invalid subtitle detect mode: {config.subtitleDetectMode.value}")
+
+ merge_big_file_if_not_exists(self.LAMA_MODEL_DIR, 'bit-lama.pt')
+ merge_big_file_if_not_exists(self.PROPAINTER_MODEL_DIR, 'ProPainter.pth')
+ merge_big_file_if_not_exists(self.DET_MODEL_DIR, 'inference.pdiparams')
+
+ def convertToOnnxModelIfNeeded(self, model_dir, model_filename="inference.pdmodel", params_filename="inference.pdiparams", opset_version=14):
+ """Converts a Paddle model to ONNX if ONNX providers are available and the model does not already exist."""
+
+ onnx_model_path = os.path.join(model_dir, "model.onnx")
+
+ if os.path.exists(onnx_model_path):
+ print(f"ONNX model already exists: {onnx_model_path}. Skipping conversion.")
+ return onnx_model_path
+
+ print(f"Converting Paddle model {model_dir} to ONNX...")
+ model_file = os.path.join(model_dir, model_filename)
+ params_file = os.path.join(model_dir, params_filename) if params_filename else ""
+
+ try:
+ import paddle2onnx
+ # Ensure the target directory exists
+ os.makedirs(os.path.dirname(onnx_model_path), exist_ok=True)
+
+ # Convert and save the model
+ onnx_model = paddle2onnx.export(
+ model_filename=model_file,
+ params_filename=params_file,
+ save_file=onnx_model_path,
+ opset_version=opset_version,
+ auto_upgrade_opset=True,
+ verbose=True,
+ enable_onnx_checker=True,
+ enable_experimental_op=True,
+ enable_optimize=True,
+ custom_op_info={},
+ deploy_backend="onnxruntime",
+ calibration_file="calibration.cache",
+ external_file=os.path.join(model_dir, "external_data"),
+ export_fp16_model=False,
+ )
+
+ print(f"Conversion successful. ONNX model saved to: {onnx_model_path}")
+ return onnx_model_path
+ except Exception as e:
+ print(f"Error during conversion: {e}")
+ return model_dir
diff --git a/backend/tools/ocr.py b/backend/tools/ocr.py
new file mode 100644
index 0000000..0feeab7
--- /dev/null
+++ b/backend/tools/ocr.py
@@ -0,0 +1,20 @@
+def get_coordinates(dt_box):
+ """
+ 从返回的检测框中获取坐标
+ :param dt_box 检测框返回结果
+ :return list 坐标点列表
+ """
+ coordinate_list = list()
+ if isinstance(dt_box, list):
+ for i in dt_box:
+ i = list(i)
+ (x1, y1) = int(i[0][0]), int(i[0][1])
+ (x2, y2) = int(i[1][0]), int(i[1][1])
+ (x3, y3) = int(i[2][0]), int(i[2][1])
+ (x4, y4) = int(i[3][0]), int(i[3][1])
+ xmin = max(x1, x4)
+ xmax = min(x2, x3)
+ ymin = max(y1, y2)
+ ymax = min(y3, y4)
+ coordinate_list.append((xmin, xmax, ymin, ymax))
+ return coordinate_list
diff --git a/backend/tools/process_manager.py b/backend/tools/process_manager.py
new file mode 100644
index 0000000..7d4642c
--- /dev/null
+++ b/backend/tools/process_manager.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+"""
+@desc: 进程管理器,用于管理和终止子进程
+"""
+import weakref
+import signal
+import os
+import platform
+import logging
+import atexit
+import subprocess
+import concurrent.futures
+
+class ProcessManager:
+ """
+ 进程管理器类,用于管理子进程的生命周期
+ 使用弱引用避免内存泄漏
+ """
+ _instance = None
+
+ @classmethod
+ def instance(cls):
+ """单例模式获取实例"""
+ if cls._instance is None:
+ cls._instance = ProcessManager()
+ return cls._instance
+
+ def __init__(self):
+ """初始化进程管理器"""
+ self.processes = {}
+ self.logger = logging.getLogger(__name__)
+
+ # 注册退出处理函数
+ atexit.register(self.terminate_all)
+
+ def add_process(self, process, name=None):
+ """
+ 添加进程到管理器
+
+ Args:
+ process: 要添加的进程对象 (subprocess.Popen实例)
+ name: 进程名称,如果不提供则使用进程ID
+ """
+ if process is None:
+ return
+
+ process_id = name or f"Process:{id(process)}"
+ self.processes[process_id] = process
+ print(f"Added process: {process_id}, PID: {process.pid if hasattr(process, 'pid') else 'unknown'}")
+ return process_id
+
+ def add_pid(self, pid, name=None):
+ process_id = name or f"Pid:{pid}"
+ self.processes[process_id] = pid
+ print(f"Added process: {process_id}, PID: {pid}")
+ return process_id
+
+ def remove_process(self, process_id):
+ """
+ 从管理器中移除进程
+
+ Args:
+ process_id: 进程ID或名称
+ """
+ if process_id in self.processes:
+ del self.processes[process_id]
+ print(f"Removed process: {process_id}")
+ return True
+ return False
+
+ def terminate_all(self):
+ """并发终止所有管理的进程"""
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ futures = []
+ for process_id, process in list(self.processes.items()):
+ if isinstance(process, int):
+ futures.append(executor.submit(self.terminate_by_pid, process))
+ else:
+ futures.append(executor.submit(self.terminate_by_process, process))
+
+ # 等待所有终止操作完成
+ concurrent.futures.wait(futures)
+
+ # 清空进程字典
+ self.processes.clear()
+
+ def terminate_by_process(self, process):
+ if process is None:
+ return
+ try:
+ print(f"Terminating process: pid: {process.pid}")
+ if hasattr(process, 'poll') and process.poll() is not None:
+ # 进程已经结束,直接返回
+ return
+
+ # 进程还在运行
+ process.terminate()
+ if hasattr(process, 'join'):
+ try:
+ process.join(timeout=3)
+ except:
+ pass
+ if hasattr(process, 'wait'):
+ try:
+ process.wait(timeout=3)
+ except:
+ pass
+ # 进程未能正常终止,尝试强制终止
+ if hasattr(process, 'kill'):
+ process.kill()
+ except Exception as e:
+ # print(f"Error terminating process: {str(e)}")
+ pass
+ self.terminate_by_pid(process.pid)
+
+ def terminate_by_pid(self, pid):
+ try:
+ # 使用系统命令强制终止进程
+ if platform.system() == 'Windows':
+ subprocess.run(['taskkill', '/F', '/T', '/PID', str(pid)],
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=3)
+ else:
+ subprocess.run(['pkill', '-9', '-P', str(pid)],
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=2)
+ subprocess.run(['kill', '-9', str(pid)],
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=3)
+ except Exception as e:
+ print(f"Error forcibly terminating process with PID {pid}: {str(e)}")
\ No newline at end of file
diff --git a/backend/tools/subtitle_detect.py b/backend/tools/subtitle_detect.py
new file mode 100644
index 0000000..0dcdd6b
--- /dev/null
+++ b/backend/tools/subtitle_detect.py
@@ -0,0 +1,258 @@
+import sys
+from functools import cached_property
+
+import cv2
+from tqdm import tqdm
+
+from .model_config import ModelConfig
+from .hardware_accelerator import HardwareAccelerator
+from .common_tools import get_readable_path
+from .ocr import get_coordinates
+from backend.config import config, tr
+from backend.scenedetect import scene_detect
+from backend.scenedetect.detectors import ContentDetector
+
+class SubtitleDetect:
+ """
+ 文本框检测类,用于检测视频帧中是否存在文本框
+ """
+
+ def __init__(self, video_path, sub_area=None):
+ self.video_path = video_path
+ self.sub_area = sub_area
+
+ @cached_property
+ def text_detector(self):
+ import paddle
+ paddle.disable_signal_handler()
+ from paddleocr.tools.infer import utility
+ from paddleocr.tools.infer.predict_det import TextDetector
+ hardware_accelerator = HardwareAccelerator.instance()
+ onnx_providers = hardware_accelerator.onnx_providers
+ model_config = ModelConfig()
+ parser = utility.init_args()
+ args = parser.parse_args([])
+ args.det_algorithm = 'DB'
+ args.det_model_dir = model_config.convertToOnnxModelIfNeeded(model_config.DET_MODEL_DIR) if len(onnx_providers) > 0 else model_config.DET_MODEL_DIR
+ args.use_gpu=hardware_accelerator.has_cuda()
+ args.use_onnx=len(onnx_providers) > 0
+ args.onnx_providers=onnx_providers
+ return TextDetector(args)
+
+ def detect_subtitle(self, img):
+ temp_list = []
+ dt_boxes, elapse = self.text_detector(img)
+ coordinate_list = get_coordinates(dt_boxes.tolist())
+ if coordinate_list:
+ for coordinate in coordinate_list:
+ xmin, xmax, ymin, ymax = coordinate
+ if self.sub_area is not None:
+ s_ymin, s_ymax, s_xmin, s_xmax = self.sub_area
+ if (s_xmin <= xmin and xmax <= s_xmax
+ and s_ymin <= ymin
+ and ymax <= s_ymax):
+ temp_list.append((xmin, xmax, ymin, ymax))
+ else:
+ temp_list.append((xmin, xmax, ymin, ymax))
+ return temp_list
+
+ def find_subtitle_frame_no(self, sub_remover=None):
+ video_cap = cv2.VideoCapture(get_readable_path(self.video_path))
+ frame_count = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
+ tbar = tqdm(total=int(frame_count), unit='frame', position=0, file=sys.__stdout__, desc='Subtitle Finding')
+ current_frame_no = 0
+ subtitle_frame_no_box_dict = {}
+ if sub_remover:
+ sub_remover.append_output(tr['Main']['ProcessingStartFindingSubtitles'])
+ while video_cap.isOpened():
+ ret, frame = video_cap.read()
+ # 如果读取视频帧失败(视频读到最后一帧)
+ if not ret:
+ break
+ # 读取视频帧成功
+ current_frame_no += 1
+ temp_list = self.detect_subtitle(frame)
+ if len(temp_list) > 0:
+ subtitle_frame_no_box_dict[current_frame_no] = temp_list
+ tbar.update(1)
+ if sub_remover:
+ sub_remover.progress_total = (100 * float(current_frame_no) / float(frame_count)) // 2
+ subtitle_frame_no_box_dict = self.unify_regions(subtitle_frame_no_box_dict)
+ if sub_remover:
+ sub_remover.append_output(tr['Main']['FinishedFindingSubtitles'])
+ new_subtitle_frame_no_box_dict = dict()
+ for key in subtitle_frame_no_box_dict.keys():
+ if len(subtitle_frame_no_box_dict[key]) > 0:
+ new_subtitle_frame_no_box_dict[key] = subtitle_frame_no_box_dict[key]
+ return new_subtitle_frame_no_box_dict
+
+ @staticmethod
+ def split_range_by_scene(intervals, points):
+ # 确保离散值列表是有序的
+ points.sort()
+ # 用于存储结果区间的列表
+ result_intervals = []
+ # 遍历区间
+ for start, end in intervals:
+ # 在当前区间内的点
+ current_points = [p for p in points if start <= p <= end]
+
+ # 遍历当前区间内的离散点
+ for p in current_points:
+ # 如果当前离散点不是区间的起始点,添加从区间开始到离散点前一个数字的区间
+ if start < p:
+ result_intervals.append((start, p - 1))
+ # 更新区间开始为当前离散点
+ start = p
+ # 添加从最后一个离散点或区间开始到区间结束的区间
+ result_intervals.append((start, end))
+ # 输出结果
+ return result_intervals
+
+ @staticmethod
+ def get_scene_div_frame_no(v_path):
+ """
+ 获取发生场景切换的帧号
+ """
+ scene_div_frame_no_list = []
+ scene_list = scene_detect(v_path, ContentDetector())
+ for scene in scene_list:
+ start, end = scene
+ if start.frame_num == 0:
+ pass
+ else:
+ scene_div_frame_no_list.append(start.frame_num + 1)
+ return scene_div_frame_no_list
+
+ @staticmethod
+ def are_similar(region1, region2):
+ """判断两个区域是否相似。"""
+ xmin1, xmax1, ymin1, ymax1 = region1
+ xmin2, xmax2, ymin2, ymax2 = region2
+
+ return abs(xmin1 - xmin2) <= config.subtitleAreaPixelToleranceXPixel.value and abs(xmax1 - xmax2) <= config.subtitleAreaPixelToleranceXPixel.value and \
+ abs(ymin1 - ymin2) <= config.subtitleAreaPixelToleranceYPixel.value and abs(ymax1 - ymax2) <= config.subtitleAreaPixelToleranceYPixel.value
+
+ def unify_regions(self, raw_regions):
+ """将连续相似的区域统一,保持列表结构。"""
+ if len(raw_regions) > 0:
+ keys = sorted(raw_regions.keys()) # 对键进行排序以确保它们是连续的
+ unified_regions = {}
+
+ # 初始化
+ last_key = keys[0]
+ unify_value_map = {last_key: raw_regions[last_key]}
+
+ for key in keys[1:]:
+ current_regions = raw_regions[key]
+
+ # 新增一个列表来存放匹配过的标准区间
+ new_unify_values = []
+
+ for idx, region in enumerate(current_regions):
+ last_standard_region = unify_value_map[last_key][idx] if idx < len(unify_value_map[last_key]) else None
+
+ # 如果当前的区间与前一个键的对应区间相似,我们统一它们
+ if last_standard_region and self.are_similar(region, last_standard_region):
+ new_unify_values.append(last_standard_region)
+ else:
+ new_unify_values.append(region)
+
+ # 更新unify_value_map为最新的区间值
+ unify_value_map[key] = new_unify_values
+ last_key = key
+
+ # 将最终统一后的结果传递给unified_regions
+ for key in keys:
+ unified_regions[key] = unify_value_map[key]
+ return unified_regions
+ else:
+ return raw_regions
+
+ @staticmethod
+ def find_continuous_ranges(subtitle_frame_no_box_dict):
+ """
+ 获取字幕出现的起始帧号与结束帧号
+ """
+ numbers = sorted(list(subtitle_frame_no_box_dict.keys()))
+ ranges = []
+ start = numbers[0] # 初始区间开始值
+
+ for i in range(1, len(numbers)):
+ # 如果当前数字与前一个数字间隔超过1,
+ # 则上一个区间结束,记录当前区间的开始与结束
+ if numbers[i] - numbers[i - 1] != 1:
+ end = numbers[i - 1] # 则该数字是当前连续区间的终点
+ ranges.append((start, end))
+ start = numbers[i] # 开始下一个连续区间
+ # 添加最后一个区间
+ ranges.append((start, numbers[-1]))
+ return ranges
+
+ @staticmethod
+ def find_continuous_ranges_with_same_mask(subtitle_frame_no_box_dict):
+ numbers = sorted(list(subtitle_frame_no_box_dict.keys()))
+ ranges = []
+ start = numbers[0] # 初始区间开始值
+ for i in range(1, len(numbers)):
+ # 如果当前帧号与前一个帧号间隔超过1,
+ # 则上一个区间结束,记录当前区间的开始与结束
+ if numbers[i] - numbers[i - 1] != 1:
+ end = numbers[i - 1] # 则该数字是当前连续区间的终点
+ ranges.append((start, end))
+ start = numbers[i] # 开始下一个连续区间
+ # 如果当前帧号与前一个帧号间隔为1,且当前帧号对应的坐标点与上一帧号对应的坐标点不一致
+ # 记录当前区间的开始与结束
+ if numbers[i] - numbers[i - 1] == 1:
+ if subtitle_frame_no_box_dict[numbers[i]] != subtitle_frame_no_box_dict[numbers[i - 1]]:
+ end = numbers[i - 1] # 则该数字是当前连续区间的终点
+ ranges.append((start, end))
+ start = numbers[i] # 开始下一个连续区间
+ # 添加最后一个区间
+ ranges.append((start, numbers[-1]))
+ return ranges
+
+ @staticmethod
+ def filter_and_merge_intervals(intervals, target_length):
+ """
+ 合并传入的字幕起始区间,确保区间大小最低为STTN_REFERENCE_LENGTH
+ """
+ expanded = []
+ # 首先单独处理单点区间以扩展它们
+ for start, end in intervals:
+ if start == end: # 单点区间
+ # 扩展到接近的目标长度,但保证前后不重叠
+ prev_end = expanded[-1][1] if expanded else float('-inf')
+ next_start = float('inf')
+ # 查找下一个区间的起始点
+ for ns, ne in intervals:
+ if ns > end:
+ next_start = ns
+ break
+ # 确定新的扩展起点和终点
+ new_start = max(start - (target_length - 1) // 2, prev_end + 1)
+ new_end = min(start + (target_length - 1) // 2, next_start - 1)
+ # 如果新的扩展终点在起点前面,说明没有足够空间来进行扩展
+ if new_end < new_start:
+ new_start, new_end = start, start # 保持原样
+ expanded.append((new_start, new_end))
+ else:
+ # 非单点区间直接保留,稍后处理任何可能的重叠
+ expanded.append((start, end))
+ # 排序以合并那些因扩展导致重叠的区间
+ expanded.sort(key=lambda x: x[0])
+ # 合并重叠的区间,但仅当它们之间真正重叠且小于目标长度时
+ merged = [expanded[0]]
+ for start, end in expanded[1:]:
+ last_start, last_end = merged[-1]
+ # 检查是否重叠
+ if start <= last_end and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
+ # 需要合并
+ merged[-1] = (last_start, max(last_end, end)) # 合并区间
+ elif start == last_end + 1 and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
+ # 相邻区间也需要合并的场景
+ merged[-1] = (last_start, end)
+ else:
+ # 如果没有重叠且都大于目标长度,则直接保留
+ merged.append((start, end))
+ return merged
diff --git a/backend/tools/subtitle_remover_remote_call.py b/backend/tools/subtitle_remover_remote_call.py
new file mode 100644
index 0000000..669f46a
--- /dev/null
+++ b/backend/tools/subtitle_remover_remote_call.py
@@ -0,0 +1,75 @@
+import multiprocessing
+import threading
+from enum import Enum
+
+class Command(Enum):
+ FINISH = 0,
+ PROGRESS = 1,
+ LOG = 2,
+ MANAGE_PROCESS = 3,
+ ERROR = 4,
+ UPDATE_PREVIEW_WITH_COMP = 5,
+
+class SubtitleRemoverRemoteCall:
+ """
+ 远程回调函数类,用于在多进程环境中传递回调函数
+ """
+ def __init__(self):
+ self.queue = multiprocessing.Queue()
+ self.callbacks = {}
+ self.running = True
+ threading.Thread(target=self.run, daemon=True).start()
+
+ def run(self):
+ try:
+ while self.running:
+ cmd, args = self.queue.get(block=True)
+ if cmd == Command.FINISH:
+ break
+ callback = self.callbacks.get(cmd)
+ if callback:
+ callback(*args)
+ finally:
+ self.running = False
+
+ def stop(self):
+ self.running = False
+
+ def register_update_progress_callback(self, callback):
+ self.callbacks[Command.PROGRESS] = callback
+
+ def register_log_callback(self, callback):
+ self.callbacks[Command.LOG] = callback
+
+ def register_manage_process_callback(self, callback):
+ self.callbacks[Command.MANAGE_PROCESS] = callback
+
+ def register_update_preview_with_comp_callback(self, callback):
+ self.callbacks[Command.UPDATE_PREVIEW_WITH_COMP] = callback
+
+ def register_error_callback(self, callback):
+ self.callbacks[Command.ERROR] = callback
+
+ @staticmethod
+ def remote_call_update_progress(queue, progress, isFinished):
+ queue.put((Command.PROGRESS, (progress, isFinished,)))
+
+ @staticmethod
+ def remote_call_append_log(queue, *args):
+ queue.put((Command.LOG, (*args,)))
+
+ @staticmethod
+ def remote_call_finish(queue, *args):
+ queue.put((Command.FINISH, (None,)))
+
+ @staticmethod
+ def remote_call_catch_error(queue, e):
+ queue.put((Command.ERROR, (e,)))
+
+ @staticmethod
+ def remote_call_manage_process(queue, pid):
+ queue.put((Command.MANAGE_PROCESS, (pid,)))
+
+ @staticmethod
+ def remote_call_update_preview_with_comp(queue, *args):
+ queue.put((Command.UPDATE_PREVIEW_WITH_COMP, (*args,)))
\ No newline at end of file
diff --git a/backend/tools/theme_listener.py b/backend/tools/theme_listener.py
new file mode 100644
index 0000000..13aee96
--- /dev/null
+++ b/backend/tools/theme_listener.py
@@ -0,0 +1,26 @@
+from PySide6 import QtCore, QtWidgets, QtGui
+
+from qfluentwidgets import setTheme, qconfig, Theme
+import darkdetect
+
+
+class SystemThemeListener(QtCore.QThread):
+ """ System theme listener """
+
+ systemThemeChanged = QtCore.Signal()
+
+ def __init__(self, parent=None):
+ super().__init__(parent=parent)
+
+ def run(self):
+ darkdetect.listener(self._onThemeChanged)
+
+ def _onThemeChanged(self, theme: str):
+ theme = Theme.DARK if theme.lower() == "dark" else Theme.LIGHT
+ setTheme(theme)
+ if qconfig.themeMode.value != Theme.AUTO or theme == qconfig.theme:
+ return
+
+ qconfig.theme = Theme.AUTO
+ qconfig._cfg.themeChanged.emit(Theme.AUTO)
+ self.systemThemeChanged.emit()
\ No newline at end of file
diff --git a/backend/tools/version_service.py b/backend/tools/version_service.py
new file mode 100644
index 0000000..22f51a4
--- /dev/null
+++ b/backend/tools/version_service.py
@@ -0,0 +1,83 @@
+# coding: utf-8
+import re
+import os
+import sys
+import requests
+
+from PySide6.QtCore import QVersionNumber
+
+from backend.config import VERSION, PROJECT_UPDATE_URLS, tr
+
+
+class VersionService:
+ """ Version service """
+
+ def __init__(self):
+ self.current_version = VERSION
+ self.lastest_version = VERSION
+ self.version_pattern = re.compile(r'v*((\d+)\.(\d+)\.(\d+))')
+ self.api_endpoints = PROJECT_UPDATE_URLS
+
+ def get_latest_version(self):
+ """ get latest version """
+ headers = {
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36 Edg/112.0.1722.64"
+ }
+
+ proxy = self.get_system_proxy()
+ proxies = {
+ "http": proxy,
+ "https": proxy
+ }
+
+ # 依次尝试不同的API端点
+ for url in self.api_endpoints:
+ try:
+ response = requests.get(url, headers=headers, timeout=5, allow_redirects=True, proxies=proxies)
+ response.raise_for_status()
+
+ # 解析版本
+ version = response.json()['tag_name'] # type:str
+ match = self.version_pattern.search(version)
+ if not match:
+ continue # 如果版本格式不匹配,尝试下一个API
+
+ self.lastest_version = match.group(1)
+ print(tr['VersionService']['VersionInfo'].format(VERSION, self.lastest_version))
+ return self.lastest_version
+ except Exception as e:
+ print(tr['VersionService']['RequestError'].format(url, str(e)))
+ continue # 出错时尝试下一个API
+
+ # 所有API都失败时返回当前版本
+ return VERSION
+
+ def has_new_version(self) -> bool:
+ """ check whether there is a new version """
+ version = QVersionNumber.fromString(self.get_latest_version())
+ current_version = QVersionNumber.fromString(self.current_version)
+ return version > current_version
+
+ def get_system_proxy(self):
+ """ get system proxy """
+ if sys.platform == "win32":
+ try:
+ import winreg
+
+ with winreg.OpenKey(winreg.HKEY_CURRENT_USER, r'Software\Microsoft\Windows\CurrentVersion\Internet Settings') as key:
+ enabled, _ = winreg.QueryValueEx(key, 'ProxyEnable')
+
+ if enabled:
+ return "http://" + winreg.QueryValueEx(key, 'ProxyServer')
+ except:
+ pass
+ elif sys.platform == "darwin":
+ s = os.popen('scutil --proxy').read()
+ info = dict(re.findall(r'(?m)^\s+([A-Z]\w+)\s+:\s+(\S+)', s))
+
+ if info.get('HTTPEnable') == '1':
+ return f"http://{info['HTTPProxy']}:{info['HTTPPort']}"
+ elif info.get('ProxyAutoConfigEnable') == '1':
+ return info['ProxyAutoConfigURLString']
+
+ return os.environ.get("http_proxy")
\ No newline at end of file
diff --git a/design/demo.png b/design/demo.png
index 5917f1c..8362bdb 100644
Binary files a/design/demo.png and b/design/demo.png differ
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 655fa37..5259ab0 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -2,29 +2,53 @@ FROM python:3.12
RUN --mount=type=cache,target=/root/.cache,sharing=private \
apt update && \
- apt install -y libgl1-mesa-glx && \
+ apt install -y libgl1-mesa-glx \
+ # pyside6
+ libegl1 libxkbcommon0 libdbus-1-3 && \
true
ADD . /vsr
ARG CUDA_VERSION=11.8
-ARG USE_DIRECTML=0
+ARG HARDWARD_ACCELERATOR="cuda"
-# 如果是 CUDA 版本,执行 CUDA 特定设置
+# 如果是 CUDA 12.x 版本,执行 CUDA 12.x 特定设置
RUN --mount=type=cache,target=/root/.cache,sharing=private \
- if [ "${USE_DIRECTML:-0}" != "1" ]; then \
+ if [ "${HARDWARD_ACCELERATOR}" = "cuda" ] && [ "${CUDA_VERSION}" != "11.8" ]; then \
pip install paddlepaddle==3.0 && \
pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu$(echo ${CUDA_VERSION} | tr -d '.') && \
pip install -r /vsr/requirements.txt; \
fi
+# 如果是 CUDA 11.8 版本,执行 CUDA 11.8 特定设置
+RUN --mount=type=cache,target=/root/.cache,sharing=private \
+if [ "${HARDWARD_ACCELERATOR}" = "cuda" ] && [ "${CUDA_VERSION}" = "11.8" ]; then \
+ pip install paddlepaddle==3.0 && \
+ pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu$(echo ${CUDA_VERSION} | tr -d '.') && \
+ pip install -r /vsr/requirements.txt && \
+ pip uninstall -y onnxruntime-gpu && \
+ pip install onnxruntime-gpu==1.20.1 --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-11/pypi/simple/ && \
+ # for paddle
+ pip install setuptools==80.4.0; \
+fi
+
# 如果是 DirectML 版本,执行 DirectML 特定设置
RUN --mount=type=cache,target=/root/.cache,sharing=private \
- if [ "${USE_DIRECTML:-0}" = "1" ]; then \
+ if [ "${HARDWARD_ACCELERATOR}" = "directml" ]; then \
pip install paddlepaddle==3.0 && \
pip install torch_directml==0.2.5.dev240914 && \
pip install -r /vsr/requirements.txt; \
fi
+# 如果是 CPU 版本,执行 CPU 特定设置
+RUN --mount=type=cache,target=/root/.cache,sharing=private \
+ if [ "${HARDWARD_ACCELERATOR}" = "cpu" ]; then \
+ pip install paddlepaddle==3.0 && \
+ pip install -r /vsr/requirements.txt && \
+ sed -i 's/HARDWARD_ACCELERATION_OPTION *= *.*/HARDWARD_ACCELERATION_OPTION = False/g' /vsr/backend/config.py; \
+ fi
+
ENV LD_LIBRARY_PATH=/usr/local/lib/python3.12/site-packages/nvidia/cudnn/lib/
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/site-packages/nvidia/cuda_runtime/lib/
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/site-packages/nvidia/cuda_nvrtc/lib/
WORKDIR /vsr
CMD ["python", "/vsr/backend/main.py"]
\ No newline at end of file
diff --git a/gui.py b/gui.py
index 09946db..3fbe17d 100644
--- a/gui.py
+++ b/gui.py
@@ -1,405 +1,184 @@
# -*- coding: utf-8 -*-
"""
-@Author : Fang Yao
-@Time : 2023/4/1 6:07 下午
+@Author : Fang Yao(原作者) / 改写:Jason Eric
+@Time : 2023/4/1 6:07 下午(原始时间)
@FileName: gui.py
-@desc: 字幕去除器图形化界面
+@desc: 字幕去除器图形化界面(由 PySimpleGUI 改写为 PySide6)
"""
+
+import sys
import os
import configparser
-import PySimpleGUI as sg
import cv2
-import sys
-from threading import Thread
import multiprocessing
-sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
-sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-import backend.main
-from backend.tools.common_tools import is_image_file
+from PySide6.QtCore import Qt, QTranslator
+from PySide6 import QtCore, QtWidgets, QtGui
+from PySide6.QtWidgets import QApplication, QFrame, QStackedWidget, QHBoxLayout, QLabel
+from qfluentwidgets import (FluentWindow, PushButton, Slider, ProgressBar, PlainTextEdit,
+ setTheme, Theme, FluentIcon, CardWidget, SettingCardGroup,
+ ComboBoxSettingCard, SwitchSettingCard, setThemeColor, OptionsConfigItem,
+ OptionsValidator, SubtitleLabel, HollowHandleStyle, qconfig, ConfigItem, QConfig,
+ NavigationWidget, NavigationItemPosition, isDarkTheme, InfoBar)
+
+from qframelesswindow.utils import getSystemAccentColor
+from backend.config import config, tr, VERSION
+from backend.tools.theme_listener import SystemThemeListener
+from backend.tools.process_manager import ProcessManager
+from ui.advanced_setting_interface import AdvancedSettingInterface
+from ui.home_interface import HomeInterface
-class SubtitleRemoverGUI:
-
+class SubtitleExtractorGUI(FluentWindow):
def __init__(self):
- self.font = 'Arial 10'
- self.theme = 'LightBrown12'
- sg.theme(self.theme)
- self.icon = os.path.join(os.path.dirname(__file__), 'design', 'vsr.ico')
- self.screen_width, self.screen_height = sg.Window.get_screen_size()
- self.subtitle_config_file = os.path.join(os.path.dirname(__file__), 'subtitle.ini')
- print(self.screen_width, self.screen_height)
- # 设置视频预览区域大小
- self.video_preview_width = 960
- self.video_preview_height = self.video_preview_width * 9 // 16
- # 默认组件大小
- self.horizontal_slider_size = (120, 20)
- self.output_size = (100, 10)
- self.progressbar_size = (60, 20)
- # 分辨率低于1080
- if self.screen_width // 2 < 960:
- self.video_preview_width = 640
- self.video_preview_height = self.video_preview_width * 9 // 16
- self.horizontal_slider_size = (60, 20)
- self.output_size = (58, 10)
- self.progressbar_size = (28, 20)
- # 字幕提取器布局
- self.layout = None
- # 字幕提取其窗口
- self.window = None
- # 视频路径
- self.video_path = None
- # 视频cap
- self.video_cap = None
- # 视频的帧率
- self.fps = None
- # 视频的帧数
- self.frame_count = None
- # 视频的宽
- self.frame_width = None
- # 视频的高
- self.frame_height = None
- # 设置字幕区域高宽
- self.xmin = None
- self.xmax = None
- self.ymin = None
- self.ymax = None
- # 字幕提取器
- self.sr = None
+ super().__init__()
+ # 禁用云母效果
+ self.setMicaEffectEnabled(False)
+ # 设置深色主题并跟随系统主题色
+ # setTheme(Theme.LIGHT)
+ # setThemeColor(getSystemAccentColor(), save=True)
- def run(self):
- # 创建布局
+ # 初始化系统主题监听器并连接信号
+ # self.themeListener = SystemThemeListener(self)
+ # self.themeListener.start()
+
+ # 设置窗口图标
+ self.setWindowIcon(QtGui.QIcon("design/vsr.ico"))
+ self.setWindowTitle(tr['SubtitleExtractorGUI']['Title'] + " v" + VERSION)
+ # 创建界面布局
self._create_layout()
- # 创建窗口
- self.window = sg.Window(title=f'Video Subtitle Remover v{backend.main.config.VERSION}' , layout=self.layout,
- icon=self.icon)
- while True:
- # 循环读取事件
- event, values = self.window.read(timeout=10)
- # 处理【打开】事件
- self._file_event_handler(event, values)
- # 处理【滑动】事件
- self._slide_event_handler(event, values)
- # 处理【运行】事件
- self._run_event_handler(event, values)
- # 如果关闭软件,退出
- if event == sg.WIN_CLOSED:
- break
- # 更新进度条
- if self.sr is not None:
- self.window['-PROG-'].update(self.sr.progress_total)
- if self.sr.preview_frame is not None:
- self.window['-DISPLAY-'].update(data=cv2.imencode('.png', self._img_resize(self.sr.preview_frame))[1].tobytes())
- if self.sr.isFinished:
- # 1) 打开修改字幕滑块区域按钮
- self.window['-Y-SLIDER-'].update(disabled=False)
- self.window['-X-SLIDER-'].update(disabled=False)
- self.window['-Y-SLIDER-H-'].update(disabled=False)
- self.window['-X-SLIDER-W-'].update(disabled=False)
- # 2) 打开【运行】、【打开】和【识别语言】按钮
- self.window['-RUN-'].update(disabled=False)
- self.window['-FILE-'].update(disabled=False)
- self.window['-FILE_BTN-'].update(disabled=False)
- self.sr = None
- if len(self.video_paths) >= 1:
- # 1) 关闭修改字幕滑块区域按钮
- self.window['-Y-SLIDER-'].update(disabled=True)
- self.window['-X-SLIDER-'].update(disabled=True)
- self.window['-Y-SLIDER-H-'].update(disabled=True)
- self.window['-X-SLIDER-W-'].update(disabled=True)
- # 2) 关闭【运行】、【打开】和【识别语言】按钮
- self.window['-RUN-'].update(disabled=True)
- self.window['-FILE-'].update(disabled=True)
- self.window['-FILE_BTN-'].update(disabled=True)
+ self._connectSignalToSlot()
+ self._lazy_check_update()
+
+ def _lazy_check_update(self):
+ """ 延迟检查更新 """
+ if not config.checkUpdateOnStartup.value:
+ return
+ self.check_update_timer = QtCore.QTimer(self)
+ self.check_update_timer.setSingleShot(True)
+ self.check_update_timer.timeout.connect(lambda: self.advancedSettingInterface.check_update(ignore=True))
+ self.check_update_timer.start(2000)
+
+ def _connectSignalToSlot(self):
+ config.appRestartSig.connect(self._showRestartTooltip)
+
+ def _showRestartTooltip(self):
+ """ show restart tooltip """
+ InfoBar.success(
+ 'Updated successfully',
+ 'Configuration takes effect after restart',
+ duration=5000,
+ parent=self
+ )
def _create_layout(self):
- """
- 创建字幕提取器布局
- """
- garbage = os.path.join(os.path.dirname(__file__), 'output')
- if os.path.exists(garbage):
- import shutil
- shutil.rmtree(garbage, True)
- self.layout = [
- # 显示视频预览
- [sg.Image(size=(self.video_preview_width, self.video_preview_height), background_color='black',
- key='-DISPLAY-')],
- # 打开按钮 + 快进快退条
- [sg.Input(key='-FILE-', visible=False, enable_events=True),
- sg.FilesBrowse(button_text='Open', file_types=((
- 'All Files', '*.*'), ('mp4', '*.mp4'),
- ('flv', '*.flv'),
- ('wmv', '*.wmv'),
- ('avi', '*.avi')),
- key='-FILE_BTN-', size=(10, 1), font=self.font),
- sg.Slider(size=self.horizontal_slider_size, range=(1, 1), key='-SLIDER-', orientation='h',
- enable_events=True, font=self.font,
- disable_number_display=True),
- ],
- # 输出区域
- [sg.Output(size=self.output_size, font=self.font),
- sg.Frame(title='Vertical', font=self.font, key='-FRAME1-',
- layout=[[
- sg.Slider(range=(0, 0), orientation='v', size=(10, 20),
- disable_number_display=True,
- enable_events=True, font=self.font,
- pad=((10, 10), (20, 20)),
- default_value=0, key='-Y-SLIDER-'),
- sg.Slider(range=(0, 0), orientation='v', size=(10, 20),
- disable_number_display=True,
- enable_events=True, font=self.font,
- pad=((10, 10), (20, 20)),
- default_value=0, key='-Y-SLIDER-H-'),
- ]], pad=((15, 5), (0, 0))),
- sg.Frame(title='Horizontal', font=self.font, key='-FRAME2-',
- layout=[[
- sg.Slider(range=(0, 0), orientation='v', size=(10, 20),
- disable_number_display=True,
- pad=((10, 10), (20, 20)),
- enable_events=True, font=self.font,
- default_value=0, key='-X-SLIDER-'),
- sg.Slider(range=(0, 0), orientation='v', size=(10, 20),
- disable_number_display=True,
- pad=((10, 10), (20, 20)),
- enable_events=True, font=self.font,
- default_value=0, key='-X-SLIDER-W-'),
- ]], pad=((15, 5), (0, 0)))
- ],
+ # 创建主页面和高级设置页面
+ self.homeInterface = HomeInterface(self)
+ self.homeInterface.setObjectName("HomeInterface")
+ self.advancedSettingInterface = AdvancedSettingInterface(self)
+ self.advancedSettingInterface.setObjectName("AdvancedSettingInterface")
+
+ # 添加到主窗口作为子界面
+ self.addSubInterface(self.homeInterface,FluentIcon.HOME, tr['SubtitleExtractorGUI']['Title'])
+ self.addSubInterface(self.advancedSettingInterface, FluentIcon.SETTING, tr['Setting']['AdvancedSetting'], NavigationItemPosition.BOTTOM)
- # 运行按钮 + 进度条
- [sg.Button(button_text='Run', key='-RUN-',
- font=self.font, size=(20, 1)),
- sg.ProgressBar(100, orientation='h', size=self.progressbar_size, key='-PROG-', auto_size_text=True)
- ],
- ]
+ def on_navigation_item_changed(self, key):
+ """导航项变更时的处理函数"""
+ if key == 'main':
+ self.stackWidget.setCurrentIndex(0)
+ elif key == 'advanced':
+ self.stackWidget.setCurrentIndex(1)
- def _file_event_handler(self, event, values):
- """
- 当点击打开按钮时:
- 1)打开视频文件,将画布显示视频帧
- 2)获取视频信息,初始化进度条滑块范围
- """
- if event == '-FILE-':
- self.video_paths = values['-FILE-'].split(';')
- self.video_path = self.video_paths[0]
- if self.video_path != '':
- self.video_cap = cv2.VideoCapture(self.video_path)
- if self.video_cap is None:
- return
- if self.video_cap.isOpened():
- ret, frame = self.video_cap.read()
- if ret:
- for video in self.video_paths:
- print(f"Open Video Success:{video}")
- # 获取视频的帧数
- self.frame_count = self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
- # 获取视频的高度
- self.frame_height = self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
- # 获取视频的宽度
- self.frame_width = self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)
- # 获取视频的帧率
- self.fps = self.video_cap.get(cv2.CAP_PROP_FPS)
- # 调整视频帧大小,使播放器能够显示
- resized_frame = self._img_resize(frame)
- # resized_frame = cv2.resize(src=frame, dsize=(self.video_preview_width, self.video_preview_height))
- # 显示视频帧
- self.window['-DISPLAY-'].update(data=cv2.imencode('.png', resized_frame)[1].tobytes())
- # 更新视频进度条滑块range
- self.window['-SLIDER-'].update(range=(1, self.frame_count))
- self.window['-SLIDER-'].update(1)
- # 预设字幕区域位置
- y_p, h_p, x_p, w_p = self.parse_subtitle_config()
- y = self.frame_height * y_p
- h = self.frame_height * h_p
- x = self.frame_width * x_p
- w = self.frame_width * w_p
- # 更新视频字幕位置滑块range
- # 更新Y-SLIDER范围
- self.window['-Y-SLIDER-'].update(range=(0, self.frame_height), disabled=False)
- # 更新Y-SLIDER默认值
- self.window['-Y-SLIDER-'].update(y)
- # 更新X-SLIDER范围
- self.window['-X-SLIDER-'].update(range=(0, self.frame_width), disabled=False)
- # 更新X-SLIDER默认值
- self.window['-X-SLIDER-'].update(x)
- # 更新Y-SLIDER-H范围
- self.window['-Y-SLIDER-H-'].update(range=(0, self.frame_height - y))
- # 更新Y-SLIDER-H默认值
- self.window['-Y-SLIDER-H-'].update(h)
- # 更新X-SLIDER-W范围
- self.window['-X-SLIDER-W-'].update(range=(0, self.frame_width - x))
- # 更新X-SLIDER-W默认值
- self.window['-X-SLIDER-W-'].update(w)
- self._update_preview(frame, (y, h, x, w))
+ def closeEvent(self, event):
+ """程序关闭时保存窗口位置并恢复标准输出和标准错误"""
+ self.save_window_position()
+ # 断开信号连接
+ # self.themeListener.terminate()
+ # self.themeListener.deleteLater()
+ ProcessManager.instance().terminate_all()
+ super().closeEvent(event)
- def __disable_button(self):
- # 1) 禁止修改字幕滑块区域
- self.window['-Y-SLIDER-'].update(disabled=True)
- self.window['-X-SLIDER-'].update(disabled=True)
- self.window['-Y-SLIDER-H-'].update(disabled=True)
- self.window['-X-SLIDER-W-'].update(disabled=True)
- # 2) 禁止再次点击【运行】、【打开】和【识别语言】按钮
- self.window['-RUN-'].update(disabled=True)
- self.window['-FILE-'].update(disabled=True)
- self.window['-FILE_BTN-'].update(disabled=True)
+ def _onThemeChangedFinished(self):
+ super()._onThemeChangedFinished()
- def _run_event_handler(self, event, values):
- """
- 当点击运行按钮时:
- 1) 禁止修改字幕滑块区域
- 2) 禁止再次点击【运行】和【打开】按钮
- 3) 设定字幕区域位置
- """
- if event == '-RUN-':
- if self.video_cap is None:
- print('Please Open Video First')
- else:
- # 禁用按钮
- self.__disable_button()
- # 3) 设定字幕区域位置
- self.xmin = int(values['-X-SLIDER-'])
- self.xmax = int(values['-X-SLIDER-'] + values['-X-SLIDER-W-'])
- self.ymin = int(values['-Y-SLIDER-'])
- self.ymax = int(values['-Y-SLIDER-'] + values['-Y-SLIDER-H-'])
- if self.ymax > self.frame_height:
- self.ymax = self.frame_height
- if self.xmax > self.frame_width:
- self.xmax = self.frame_width
- if len(self.video_paths) <= 1:
- subtitle_area = (self.ymin, self.ymax, self.xmin, self.xmax)
- else:
- print(f"{'Processing multiple videos or images'}")
- # 先判断每个视频的分辨率是否一致,一致的话设置相同的字幕区域,否则设置为None
- global_size = None
- for temp_video_path in self.video_paths:
- temp_cap = cv2.VideoCapture(temp_video_path)
- if global_size is None:
- global_size = (int(temp_cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(temp_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
- else:
- temp_size = (int(temp_cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(temp_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
- if temp_size != global_size:
- print('not all video/images in same size, processing in full screen')
- subtitle_area = None
- else:
- subtitle_area = (self.ymin, self.ymax, self.xmin, self.xmax)
- y_p = self.ymin / self.frame_height
- h_p = (self.ymax - self.ymin) / self.frame_height
- x_p = self.xmin / self.frame_width
- w_p = (self.xmax - self.xmin) / self.frame_width
- self.set_subtitle_config(y_p, h_p, x_p, w_p)
+ def save_window_position(self):
+ """保存窗口位置到配置文件"""
+ # 保存窗口位置和大小
+ config.set(config.windowX, self.x())
+ config.set(config.windowY, self.y())
+ config.set(config.windowW, self.width())
+ config.set(config.windowH, self.height())
- def task():
- while self.video_paths:
- video_path = self.video_paths.pop()
- if subtitle_area is not None:
- print(f"{'SubtitleArea'}:({self.ymin},{self.ymax},{self.xmin},{self.xmax})")
- self.sr = backend.main.SubtitleRemover(video_path, subtitle_area, True)
- self.__disable_button()
- self.sr.run()
- Thread(target=task, daemon=True).start()
- self.video_cap.release()
- self.video_cap = None
-
- def _slide_event_handler(self, event, values):
- """
- 当滑动视频进度条/滑动字幕选择区域滑块时:
- 1) 判断视频是否存在,如果存在则显示对应的视频帧
- 2) 绘制rectangle
- """
- if event == '-SLIDER-' or event == '-Y-SLIDER-' or event == '-Y-SLIDER-H-' or event == '-X-SLIDER-' or event \
- == '-X-SLIDER-W-':
- # 判断是否时单张图片
- if is_image_file(self.video_path):
- img = cv2.imread(self.video_path)
- self.window['-Y-SLIDER-H-'].update(range=(0, self.frame_height - values['-Y-SLIDER-']))
- self.window['-X-SLIDER-W-'].update(range=(0, self.frame_width - values['-X-SLIDER-']))
- # 画字幕框
- y = int(values['-Y-SLIDER-'])
- h = int(values['-Y-SLIDER-H-'])
- x = int(values['-X-SLIDER-'])
- w = int(values['-X-SLIDER-W-'])
- self._update_preview(img, (y, h, x, w))
- elif self.video_cap is not None and self.video_cap.isOpened():
- frame_no = int(values['-SLIDER-'])
- self.video_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
- ret, frame = self.video_cap.read()
- if ret:
- self.window['-Y-SLIDER-H-'].update(range=(0, self.frame_height-values['-Y-SLIDER-']))
- self.window['-X-SLIDER-W-'].update(range=(0, self.frame_width-values['-X-SLIDER-']))
- # 画字幕框
- y = int(values['-Y-SLIDER-'])
- h = int(values['-Y-SLIDER-H-'])
- x = int(values['-X-SLIDER-'])
- w = int(values['-X-SLIDER-W-'])
- self._update_preview(frame, (y, h, x, w))
-
- def _update_preview(self, frame, y_h_x_w):
- y, h, x, w = y_h_x_w
- # 画字幕框
- draw = cv2.rectangle(img=frame, pt1=(int(x), int(y)), pt2=(int(x) + int(w), int(y) + int(h)),
- color=(0, 255, 0), thickness=3)
- # 调整视频帧大小,使播放器能够显示
- resized_frame = self._img_resize(draw)
- # 显示视频帧
- self.window['-DISPLAY-'].update(data=cv2.imencode('.png', resized_frame)[1].tobytes())
-
- def _img_resize(self, image):
- top, bottom, left, right = (0, 0, 0, 0)
- height, width = image.shape[0], image.shape[1]
- # 对长短不想等的图片,找到最长的一边
- longest_edge = height
- # 计算短边需要增加多少像素宽度使其与长边等长
- if width < longest_edge:
- dw = longest_edge - width
- left = dw // 2
- right = dw - left
- else:
- pass
- # 给图像增加边界
- constant = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0])
- return cv2.resize(constant, (self.video_preview_width, self.video_preview_height))
-
- def set_subtitle_config(self, y, h, x, w):
- # 写入配置文件
- with open(self.subtitle_config_file, mode='w', encoding='utf-8') as f:
- f.write('[AREA]\n')
- f.write(f'Y = {y}\n')
- f.write(f'H = {h}\n')
- f.write(f'X = {x}\n')
- f.write(f'W = {w}\n')
-
- def parse_subtitle_config(self):
- y_p, h_p, x_p, w_p = .78, .21, .05, .9
- # 如果配置文件不存在,则写入配置文件
- if not os.path.exists(self.subtitle_config_file):
- self.set_subtitle_config(y_p, h_p, x_p, w_p)
- return y_p, h_p, x_p, w_p
- else:
+ def update_progress(self):
+ # 定时器轮询更新进度(现在更新到视频滑块上)
+ if self.se is not None:
try:
- config = configparser.ConfigParser()
- config.read(self.subtitle_config_file, encoding='utf-8')
- conf_y_p, conf_h_p, conf_x_p, conf_w_p = float(config['AREA']['Y']), float(config['AREA']['H']), float(config['AREA']['X']), float(config['AREA']['W'])
- return conf_y_p, conf_h_p, conf_x_p, conf_w_p
- except Exception:
- self.set_subtitle_config(y_p, h_p, x_p, w_p)
- return y_p, h_p, x_p, w_p
+ pos = min(self.frame_count - 1, int(self.se.progress_total / 100 * self.frame_count))
+ if pos != self.video_slider.value():
+ self.video_slider.setValue(pos)
+ # 检查是否完成
+ if self.se.isFinished:
+ self.processing_finished()
+ except Exception as e:
+ # 捕获任何异常,防止崩溃
+ print(f"更新进度时出错: {str(e)}")
+
+ def load_window_position(self):
+ # 尝试读取窗口位置
+ try:
+ x = config.windowX.value
+ y = config.windowY.value
+ width = config.windowW.value
+ height = config.windowH.value
+
+ if not x or not y:
+ self.center_window()
+ return
+
+ # 确保窗口在屏幕内
+ screen_rect = QtWidgets.QApplication.primaryScreen().availableGeometry()
+ if (x >= 0 and y >= 0 and
+ x + width <= screen_rect.width() and
+ y + height <= screen_rect.height()):
+ self.setGeometry(x, y, width, height)
+ else:
+ self.center_window()
+ except Exception as e:
+ print(e)
+ self.center_window()
+
+ def center_window(self):
+ """将窗口居中显示"""
+ screen_rect = QtWidgets.QApplication.primaryScreen().availableGeometry()
+ window_rect = self.frameGeometry()
+ center_point = screen_rect.center()
+ window_rect.moveCenter(center_point)
+ self.move(window_rect.topLeft())
+
+ def keyPressEvent(self, event):
+ """处理键盘事件"""
+ # 检测Ctrl+C组合键
+ if event.key() == QtCore.Qt.Key_C and event.modifiers() == QtCore.Qt.ControlModifier:
+ print("\n程序被用户中断(Ctrl+C),正在退出...")
+ self.close()
+ else:
+ super().keyPressEvent(event)
if __name__ == '__main__':
- try:
- multiprocessing.set_start_method("spawn")
- # 运行图形化界面
- subtitleRemoverGUI = SubtitleRemoverGUI()
- subtitleRemoverGUI.run()
- except Exception as e:
- print(f'[{type(e)}] {e}')
- import traceback
- traceback.print_exc()
- msg = traceback.format_exc()
- err_log_path = os.path.join(os.path.expanduser('~'), 'VSR-Error-Message.log')
- with open(err_log_path, 'w', encoding='utf-8') as f:
- f.writelines(msg)
- import platform
- if platform.system() == 'Windows':
- os.system('pause')
- else:
- input()
+ multiprocessing.set_start_method("spawn")
+ QApplication.setHighDpiScaleFactorRoundingPolicy(
+ Qt.HighDpiScaleFactorRoundingPolicy.PassThrough)
+ app = QtWidgets.QApplication(sys.argv)
+ app.setAttribute(Qt.AA_DontCreateNativeWidgetSiblings)
+ window = SubtitleExtractorGUI()
+ # 先设置透明, 再显示, 否则会有闪烁的效果
+ window.setWindowOpacity(0.0)
+ window.show()
+ window.load_window_position()
+ # 使用动画效果逐渐显示窗口
+ animation = QtCore.QPropertyAnimation(window, b"windowOpacity")
+ animation.setDuration(300) # 300毫秒的动画
+ animation.setStartValue(0.0)
+ animation.setEndValue(1.0)
+ animation.start()
+ app.exec()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 9a4fc9c..5baeba9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,7 @@ lmdb==1.6.2
PyYAML==6.0.2
omegaconf==2.3.0
tqdm==4.67.1
-PySimpleGUI-4-foss==4.60.4.1
+pyside6-fluent-widgets==1.7.7
easydict==1.13
scikit-learn==1.6.1
pandas==2.2.3
@@ -19,5 +19,8 @@ av==14.3.0
einops==0.8.1
paddleocr==2.10.0
paddle2onnx==1.3.1
-onnxruntime-gpu==1.20.1
-onnxruntime-directml==1.20.1; sys_platform == 'win32'
+onnxruntime-gpu==1.20.1; sys_platform == 'linux'
+# do not upgrade, windows 10/11 compatable issues, check hardware_acclerator.py:100
+onnxruntime-directml==1.20.1; sys_platform == 'win32'
+onnxruntime-coreml==1.13.1; sys_platform == 'darwin'
+je-showinfilemanager==1.1.6a4
\ No newline at end of file
diff --git a/ui/advanced_setting_interface.py b/ui/advanced_setting_interface.py
new file mode 100644
index 0000000..a9d6364
--- /dev/null
+++ b/ui/advanced_setting_interface.py
@@ -0,0 +1,247 @@
+"""
+@desc: 高级设置页面
+"""
+
+from PySide6 import QtWidgets, QtCore, QtGui
+from qfluentwidgets import (ScrollArea, ExpandLayout, CardWidget, SubtitleLabel,
+ FluentIcon, NavigationWidget, NavigationItemPosition,
+ SettingCardGroup, RangeSettingCard, SwitchSettingCard,
+ HyperlinkCard, PrimaryPushSettingCard, ComboBoxSettingCard,
+ MessageBox)
+from backend.config import config, tr, VERSION, PROJECT_HOME_URL, PROJECT_ISSUES_URL, PROJECT_RELEASES_URL
+from backend.tools.version_service import VersionService
+from backend.tools.concurrent import TaskExecutor
+
+class AdvancedSettingInterface(ScrollArea):
+ """高级设置页面"""
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.parent = parent
+ self.version_manager = VersionService()
+ self.__initWidget()
+
+ def __initWidget(self):
+ # 创建滚动内容的容器
+ self.scrollWidget = QtWidgets.QWidget(self)
+ self.expandLayout = ExpandLayout(self.scrollWidget)
+
+ # 设置滚动区域属性
+ self.setWidget(self.scrollWidget)
+ self.enableTransparentBackground()
+ self.setWidgetResizable(True)
+ self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
+ self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded)
+
+ # 设置滚动区域样式以适应主题
+ self.setAttribute(QtCore.Qt.WA_StyledBackground)
+
+ # 设置UI
+ self.setup_ui()
+ self.setup_layout()
+
+ def setup_layout(self):
+ self.subtitle_detection_group.addSettingCard(self.subtitle_yx_axis_difference_pixel)
+ self.subtitle_detection_group.addSettingCard(self.subtitle_area_deviation_pixel)
+ self.subtitle_detection_group.addSettingCard(self.subtitle_area_y_axis_difference_pixel)
+ self.subtitle_detection_group.addSettingCard(self.subtitle_area_pixel_tolerance_y_pixel)
+ self.subtitle_detection_group.addSettingCard(self.subtitle_area_pixel_tolerance_x_pixel)
+ self.subtitle_detection_group.addSettingCard(self.subtitle_timeline_backward_frame_count)
+ self.subtitle_detection_group.addSettingCard(self.subtitle_timeline_forward_frame_count)
+ self.expandLayout.addWidget(self.subtitle_detection_group)
+
+ self.sttn_group.addSettingCard(self.sttn_neighbor_stride)
+ self.sttn_group.addSettingCard(self.sttn_reference_length)
+ self.sttn_group.addSettingCard(self.sttn_max_load_num)
+ self.expandLayout.addWidget(self.sttn_group)
+
+ self.propainter_group.addSettingCard(self.propainter_max_load_num)
+ self.expandLayout.addWidget(self.propainter_group)
+
+ self.advanced_group.addSettingCard(self.check_update_on_startup)
+ self.expandLayout.addWidget(self.advanced_group)
+
+ self.about_group.addSettingCard(self.feedback)
+ self.about_group.addSettingCard(self.copyright)
+ self.about_group.addSettingCard(self.project_link)
+
+ self.expandLayout.addWidget(self.about_group)
+ self.expandLayout.setSpacing(16)
+ self.expandLayout.setContentsMargins(16, 16, 16, 48)
+
+ def setup_ui(self):
+ """设置UI"""
+ # 字幕检测设置组
+ self.subtitle_detection_group = SettingCardGroup(tr["Setting"]["SubtitleDetectionSetting"], self.scrollWidget)
+ # STTN设置组
+ self.sttn_group = SettingCardGroup(tr["Setting"]["SttnSetting"], self.scrollWidget)
+ # Propainter设置组
+ self.propainter_group = SettingCardGroup(tr["Setting"]["ProPainterSetting"], self.scrollWidget)
+ # 高级设置组
+ self.advanced_group = SettingCardGroup(tr["Setting"]["AdvancedSetting"], self.scrollWidget)
+ # 关于设置组
+ self.about_group = SettingCardGroup(tr["Setting"]["AboutSetting"], self.scrollWidget)
+
+ self.subtitle_yx_axis_difference_pixel = RangeSettingCard(
+ configItem=config.subtitleYXAxisDifferencePixel,
+ icon=FluentIcon.ZOOM,
+ title=tr["Setting"]["SubtitleYXAxisDifferencePixel"],
+ content=tr["Setting"]["SubtitleYXAxisDifferencePixelDesc"],
+ parent=self.subtitle_detection_group
+ )
+
+ self.subtitle_area_deviation_pixel = RangeSettingCard(
+ configItem=config.subtitleAreaDeviationPixel,
+ icon=FluentIcon.ZOOM_IN,
+ title=tr["Setting"]["SubtitleAreaDeviationPixel"],
+ content=tr["Setting"]["SubtitleAreaDeviationPixelDesc"],
+ parent=self.subtitle_detection_group
+ )
+
+ self.subtitle_area_y_axis_difference_pixel = RangeSettingCard(
+ configItem=config.subtitleAreaYAxisDifferencePixel,
+ icon=FluentIcon.ALIGNMENT,
+ title=tr["Setting"]["SubtitleAreaYAxisDifferencePixel"],
+ content=tr["Setting"]["SubtitleAreaYAxisDifferencePixelDesc"],
+ parent=self.subtitle_detection_group
+ )
+
+ self.subtitle_area_pixel_tolerance_y_pixel = RangeSettingCard(
+ configItem=config.subtitleAreaPixelToleranceYPixel,
+ icon=FluentIcon.UP,
+ title=tr["Setting"]["SubtitleAreaPixelToleranceYPixel"],
+ content=tr["Setting"]["SubtitleAreaPixelToleranceYPixelDesc"],
+ parent=self.subtitle_detection_group
+ )
+
+ self.subtitle_area_pixel_tolerance_x_pixel = RangeSettingCard(
+ configItem=config.subtitleAreaPixelToleranceXPixel,
+ icon=FluentIcon.RIGHT_ARROW,
+ title=tr["Setting"]["SubtitleAreaPixelToleranceXPixel"],
+ content=tr["Setting"]["SubtitleAreaPixelToleranceXPixelDesc"],
+ parent=self.subtitle_detection_group
+ )
+
+ self.subtitle_timeline_backward_frame_count = RangeSettingCard(
+ configItem=config.subtitleTimelineBackwardFrameCount,
+ icon=FluentIcon.PAGE_LEFT,
+ title=tr["Setting"]["SubtitleTimelineBackwardFrameCount"],
+ content=tr["Setting"]["SubtitleTimelineBackwardFrameCountDesc"],
+ parent=self.subtitle_detection_group
+ )
+
+ self.subtitle_timeline_forward_frame_count = RangeSettingCard(
+ configItem=config.subtitleTimelineForwardFrameCount,
+ icon=FluentIcon.PAGE_RIGHT,
+ title=tr["Setting"]["subtitleTimelineForwardFrameCount"],
+ content=tr["Setting"]["subtitleTimelineForwardFrameCountDesc"],
+ parent=self.subtitle_detection_group
+ )
+
+ self.sttn_neighbor_stride = RangeSettingCard(
+ configItem=config.sttnNeighborStride,
+ icon=FluentIcon.UNIT,
+ title=tr["Setting"]["SttnNeighborStride"],
+ content=tr["Setting"]["SttnNeighborStrideDesc"],
+ parent=self.sttn_group
+ )
+
+ self.sttn_reference_length = RangeSettingCard(
+ configItem=config.sttnReferenceLength,
+ icon=FluentIcon.MORE,
+ title=tr["Setting"]["SttnReferenceLength"],
+ content=tr["Setting"]["SttnReferenceLengthDesc"],
+ parent=self.sttn_group
+ )
+
+ self.sttn_max_load_num = RangeSettingCard(
+ configItem=config.sttnMaxLoadNum,
+ icon=FluentIcon.DICTIONARY,
+ title=tr["Setting"]["SttnMaxLoadNum"],
+ content=tr["Setting"]["SttnMaxLoadNumDesc"],
+ parent=self.sttn_group
+ )
+
+ self.propainter_max_load_num = RangeSettingCard(
+ configItem=config.propainterMaxLoadNum,
+ icon=FluentIcon.DICTIONARY,
+ title=tr["Setting"]["PropainterMaxLoadNum"],
+ content=tr["Setting"]["PropainterMaxLoadNumDesc"],
+ parent=self.propainter_group
+ )
+
+ self.check_update_on_startup = SwitchSettingCard(
+ configItem=config.checkUpdateOnStartup,
+ icon=FluentIcon.UPDATE,
+ title=tr["Setting"]["CheckUpdateOnStartup"],
+ content=tr["Setting"]["CheckUpdateOnStartupDesc"],
+ parent=self.advanced_group
+ )
+
+ # 添加反馈链接
+ self.feedback = PrimaryPushSettingCard(
+ text=tr["Setting"]["FeedbackButton"],
+ icon=FluentIcon.MAIL,
+ title=tr["Setting"]["FeedbackTitle"],
+ content=tr["Setting"]["FeedbackDesc"],
+ parent=self.about_group
+ )
+ self.feedback.clicked.connect(lambda: QtGui.QDesktopServices.openUrl(
+ QtCore.QUrl(PROJECT_ISSUES_URL)
+ ))
+ # 添加版权信息
+ self.copyright = PrimaryPushSettingCard(
+ text=tr["Setting"]["CopyrightButton"],
+ icon=FluentIcon.MAIL,
+ title=tr["Setting"]["CopyrightTitle"],
+ content=tr["Setting"]["CopyrightDesc"].format(VERSION),
+ parent=self.about_group
+ )
+ self.copyright.clicked.connect(lambda: self.check_update())
+ # 添加项目链接
+ self.project_link = HyperlinkCard(
+ url=PROJECT_HOME_URL,
+ text=PROJECT_HOME_URL,
+ icon=FluentIcon.GITHUB,
+ title=tr["Setting"]["ProjectLinkTitle"],
+ content=tr["Setting"]["ProjectLinkDesc"],
+ parent=self.about_group
+ )
+
+ def show_message_box(self, title: str, content: str, showYesButton=False, yesSlot=None):
+ """ show message box """
+ w = MessageBox(title, content, self)
+ if not showYesButton:
+ w.cancelButton.setText(self.tr('Close'))
+ w.yesButton.hide()
+ w.buttonLayout.insertStretch(0, 1)
+
+ if w.exec() and yesSlot is not None:
+ yesSlot()
+
+ def check_update(self, ignore=False):
+ """ check software update
+
+ Parameters
+ ----------
+ ignore: bool
+ ignore message box when no updates are available
+ """
+ TaskExecutor.runTask(self.version_manager.has_new_version).then(
+ lambda success: self.on_version_info_fetched(success, ignore))
+
+ def on_version_info_fetched(self, success, ignore=False):
+ if success:
+ self.show_message_box(
+ tr["Setting"]["UpdatesAvailableTitle"],
+ tr["Setting"]["UpdatesAvailableDesc"].format(self.version_manager.lastest_version),
+ True,
+ lambda: QtGui.QDesktopServices.openUrl(
+ QtCore.QUrl(PROJECT_RELEASES_URL)
+ )
+ )
+ elif not ignore:
+ self.show_message_box(
+ tr["Setting"]["NoUpdatesAvailableTitle"],
+ tr["Setting"]["NoUpdatesAvailableDesc"],
+ )
\ No newline at end of file
diff --git a/ui/component/task_list_component.py b/ui/component/task_list_component.py
new file mode 100644
index 0000000..fde4bc8
--- /dev/null
+++ b/ui/component/task_list_component.py
@@ -0,0 +1,331 @@
+import os
+from enum import Enum
+from dataclasses import dataclass
+from PySide6.QtWidgets import QWidget, QVBoxLayout, QMenu, QAbstractItemView, QTableWidgetItem, QHeaderView
+from PySide6.QtCore import Qt, Signal, QModelIndex, QUrl
+from qfluentwidgets import TableWidget, BodyLabel, FluentIcon, InfoBar, InfoBarPosition
+from PySide6.QtGui import QAction, QColor, QBrush
+from showinfm import show_in_file_manager
+
+from backend.config import tr
+
+class TaskStatus(Enum):
+ PENDING = tr['TaskList']['Pending']
+ PROCESSING = tr['TaskList']['Processing']
+ COMPLETED = tr['TaskList']['Completed']
+ FAILED = tr['TaskList']['Failed']
+
+@dataclass
+class Task:
+ path: str
+ name: str
+ progress: int
+ status: TaskStatus
+ output_path: str
+
+class TaskListComponent(QWidget):
+ """任务列表组件"""
+
+ # 定义信号
+ task_selected = Signal(int, str) # 任务被选中时发出信号,参数为任务索引和视频路径
+ task_deleted = Signal(int) # 任务被删除时发出信号,参数为任务索引
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.setObjectName("TaskListComponent")
+
+ # 初始化变量
+ self.tasks = [] # 存储任务列表
+ self.current_task_index = -1 # 当前选中的任务索引
+
+ # 创建布局
+ self.__initWidget()
+
+ def __initWidget(self):
+ """初始化组件"""
+ layout = QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ # 创建表格
+ self.table = TableWidget(self)
+ self.table.setColumnCount(3)
+ self.table.setHorizontalHeaderLabels([tr['TaskList']['Name'], tr['TaskList']['Progress'], tr['TaskList']['Status']])
+
+ # 设置表格样式
+ self.table.setShowGrid(False)
+ self.table.setAlternatingRowColors(True)
+
+ # 设置列宽模式
+ header = self.table.horizontalHeader()
+ header.setSectionResizeMode(0, QHeaderView.Stretch) # 名称列拉伸填充
+ header.setSectionResizeMode(1, QHeaderView.ResizeToContents) # 进度列自适应内容宽度
+ header.setSectionResizeMode(2, QHeaderView.ResizeToContents) # 状态列自适应内容宽度
+
+ self.table.setSelectionBehavior(QAbstractItemView.SelectRows)
+ self.table.setSelectionMode(QAbstractItemView.SingleSelection)
+ self.table.setEditTriggers(QAbstractItemView.NoEditTriggers)
+
+ # 连接信号
+ self.table.setContextMenuPolicy(Qt.CustomContextMenu)
+ self.table.customContextMenuRequested.connect(self.show_context_menu)
+ self.table.clicked.connect(self.on_task_clicked)
+
+ layout.addWidget(self.table)
+
+ def add_task(self, video_path, output_path):
+ """添加任务到列表
+
+ Args:
+ video_path: 视频文件路径
+ """
+ # 覆盖相同路径的任务
+ for row, task in enumerate(self.tasks[:]):
+ if task.path == video_path:
+ self.delete_task(row)
+ continue
+
+ # 获取文件名
+ file_name = os.path.basename(video_path)
+
+ # 添加到任务列表
+ task = Task(
+ path=video_path,
+ name=file_name,
+ progress=0,
+ status=TaskStatus.PENDING,
+ output_path=output_path,
+ )
+ self.tasks.append(task)
+
+ # 更新表格
+ row = len(self.tasks) - 1
+ self.table.setRowCount(len(self.tasks))
+
+ item0 = QTableWidgetItem(file_name)
+ item1 = QTableWidgetItem("0%")
+ item2 = QTableWidgetItem(TaskStatus.PENDING.value)
+
+ # 设置文件名单元格的省略模式为中间省略
+ item0.setTextAlignment(Qt.AlignVCenter | Qt.AlignLeft)
+ item0.setToolTip(video_path) # 设置完整路径为工具提示
+ # 设置表格的文本省略模式
+ self.table.setTextElideMode(Qt.ElideMiddle)
+
+ item1.setTextAlignment(Qt.AlignCenter)
+ item2.setTextAlignment(Qt.AlignCenter)
+
+ self.table.setItem(row, 0, item0)
+ self.table.setItem(row, 1, item1)
+ self.table.setItem(row, 2, item2)
+
+ # 滚动到最新添加的行
+ self.table.scrollToBottom()
+ return True
+
+ def update_task_progress(self, index, progress):
+ """更新任务进度
+
+ Args:
+ index: 任务索引
+ progress: 进度值(0-100)
+ """
+ if 0 <= index < len(self.tasks):
+ self.tasks[index].progress = progress
+
+ # 更新进度单元格
+ progress_item = self.table.item(index, 1)
+ if progress_item:
+ progress_item.setText(f"{progress}%")
+
+ # 如果是当前处理的任务,滚动到可见区域
+ if index == self.current_task_index:
+ self.table.scrollTo(self.table.model().index(index, 0))
+
+ def update_task_status(self, index, status):
+ """更新任务状态
+
+ Args:
+ index: 任务索引
+ status: 任务状态
+ """
+ if 0 <= index < len(self.tasks):
+ self.tasks[index].status = status
+ status_item = self.table.item(index, 2)
+ if status_item:
+ status_item.setText(status.value)
+
+ # 根据状态设置不同颜色
+ if status == TaskStatus.COMPLETED:
+ status_item.setForeground(QBrush(QColor("#2ecc71"))) # 绿色
+ elif status == TaskStatus.PROCESSING:
+ status_item.setForeground(QBrush(QColor("#3498db"))) # 蓝色
+ elif status == TaskStatus.FAILED:
+ status_item.setForeground(QBrush(QColor("#e74c3c"))) # 红色
+
+ # 如果是当前处理的任务,滚动到可见区域
+ if index == self.current_task_index:
+ self.table.scrollTo(self.table.model().index(index, 0))
+
+ # 选中当前行
+ self.table.selectRow(index)
+
+ def get_pending_tasks(self):
+ """获取所有待处理的任务
+
+ Returns:
+ list: 待处理任务列表,每项为 (索引, 任务) 元组
+ """
+ return [(i, task) for i, task in enumerate(self.tasks) if task.status == TaskStatus.PENDING]
+
+ def get_all_tasks(self):
+ """获取所有任务
+
+ Returns:
+ list: 所有任务列表
+ """
+ return self.tasks
+
+ def get_task(self, index):
+ """获取指定索引的任务
+
+ Args:
+ index: 任务索引
+
+ Returns:
+ Task: 任务对象
+ """
+ if 0 <= index < len(self.tasks):
+ return self.tasks[index]
+ return None
+
+ def find_task_index_by_path(self, path):
+ tasks = self.get_all_tasks()
+ for idx, task in enumerate(tasks):
+ if task.path == path:
+ return idx
+ return -1 # 没找到返回-1
+
+ def show_context_menu(self, pos):
+ """显示右键菜单
+
+ Args:
+ pos: 鼠标位置
+ """
+ index = self.table.indexAt(pos)
+ if index.isValid():
+ menu = QMenu(self)
+
+ # 打开视频文件位置
+ open_video_location_action = QAction(tr['TaskList']['OpenSourceVideoLocation'], self)
+ open_video_location_action.triggered.connect(lambda: self.open_file_location(self.tasks[index.row()].path))
+ menu.addAction(open_video_location_action)
+
+ # 打开目标文件位置
+ def open_target_location():
+ task = self.tasks[index.row()]
+ path = task.output_path
+ if task.status != TaskStatus.COMPLETED:
+ InfoBar.warning(
+ title=tr['TaskList']['Warning'],
+ content=tr['TaskList']['TargetFileNotFound'],
+ parent=self.get_root_parent(),
+ duration=3000
+ )
+ return
+ self.open_file_location(path)
+ open_target_location_action = QAction(tr['TaskList']['OpenTargetVideoLocation'], self)
+ open_target_location_action.triggered.connect(open_target_location)
+ menu.addAction(open_target_location_action)
+
+ reset_task_status_action = QAction(tr['TaskList']['ResetTaskStatus'], self)
+ reset_task_status_action.triggered.connect((lambda: (
+ self.update_task_status(index.row(), TaskStatus.PENDING),
+ self.update_task_progress(index.row(), 0)
+ )
+ ))
+ menu.addAction(reset_task_status_action)
+
+ # 删除任务
+ delete_action = QAction(tr['TaskList']['DeleteTask'], self)
+ delete_action.triggered.connect(lambda: self.delete_task(index.row()))
+ menu.addAction(delete_action)
+
+ # 显示菜单
+ menu.exec_(self.table.viewport().mapToGlobal(pos))
+
+ def delete_task(self, row):
+ """删除任务
+
+ Args:
+ row: 行索引
+ """
+ if 0 <= row < len(self.tasks):
+ # 从列表中删除
+ del self.tasks[row]
+
+ # 从表格中删除
+ self.table.removeRow(row)
+
+ # 如果删除的是当前任务,重置当前任务索引
+ if row == self.current_task_index:
+ self.current_task_index = -1
+
+ # 发出任务删除信号
+ self.task_deleted.emit(row)
+
+ def on_task_clicked(self, index):
+ """任务被点击时的处理
+
+ Args:
+ index: 索引
+ """
+ row = index.row()
+ if 0 <= row < len(self.tasks):
+ self.current_task_index = row
+ # 发出信号,通知外部加载对应视频
+ self.task_selected.emit(row, self.tasks[row].path)
+
+ def set_current_task(self, index):
+ """设置当前处理的任务
+
+ Args:
+ index: 任务索引
+ """
+ if 0 <= index < len(self.tasks):
+ self.current_task_index = index
+ self.table.selectRow(index)
+ self.table.scrollTo(self.table.model().index(index, 0))
+
+ def select_task(self, index):
+ """选中指定任务
+
+ Args:
+ index: 任务索引
+ """
+ self.set_current_task(index)
+
+ def open_file_location(self, path):
+ """打开文件所在位置
+
+ Args:
+ row: 行索引
+ path: 目标路径
+ """
+ # 检查视频文件是否存在
+ if not os.path.exists(path):
+ InfoBar.warning(
+ title=tr['TaskList']['Warning'],
+ content=tr['TaskList']['UnableToLocateFile'],
+ parent=self.get_root_parent(),
+ duration=3000
+ )
+ return
+
+ show_in_file_manager(os.path.abspath(path))
+
+ def get_root_parent(self):
+ parent = self
+ while parent.parent():
+ parent = parent.parent()
+ return parent
\ No newline at end of file
diff --git a/ui/component/video_display_component.py b/ui/component/video_display_component.py
new file mode 100644
index 0000000..9920db7
--- /dev/null
+++ b/ui/component/video_display_component.py
@@ -0,0 +1,566 @@
+import cv2
+from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QSizePolicy
+from PySide6.QtCore import Qt, Signal, QRect, QRectF, QTimer, QObject, QEvent
+from PySide6 import QtCore, QtWidgets, QtGui
+from qfluentwidgets import qconfig, CardWidget, HollowHandleStyle
+
+from backend.config import config, tr
+
+class VideoDisplayComponent(QWidget):
+ """视频显示组件,包含视频预览和选择框功能"""
+
+ # 定义信号
+ selection_changed = Signal(QRect) # 选择框变化信号
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.parent = parent
+
+ # 初始化变量
+ self.is_drawing = False
+ self.selection_rect = QRect()
+ self.drag_start_pos = None
+ self.resize_edge = None
+ self.edge_size = 10 # 调整大小的边缘区域
+ self.enable_mouse_events = True # 控制是否启用鼠标事件
+
+ # 获取屏幕大小
+ screen = QtWidgets.QApplication.primaryScreen().size()
+ self.screen_width = screen.width()
+ self.screen_height = screen.height()
+
+ # 设置视频预览区域大小(根据屏幕宽度动态调整)
+ self.video_preview_width = 960
+ self.video_preview_height = self.video_preview_width * 9 // 16
+ if self.screen_width // 2 < 960:
+ self.video_preview_width = 640
+ self.video_preview_height = self.video_preview_width * 9 // 16
+
+ # 视频相关参数
+ self.frame_width = None
+ self.frame_height = None
+ self.scaled_width = None
+ self.scaled_height = None
+ self.border_left = 0
+ self.border_top = 0
+
+ # 保存选择框的相对位置和大小(相对于实际视频的比例)
+ self.selection_ratio = None
+
+ self.__initWidget()
+
+ def __initWidget(self):
+ """初始化组件"""
+ main_layout = QVBoxLayout(self)
+ main_layout.setSpacing(0)
+ main_layout.setContentsMargins(0, 0, 0, 0)
+
+ # 视频预览区域和进度条容器
+ self.video_container = CardWidget(self)
+ self.video_container.setObjectName('videoContainer')
+ video_layout = QVBoxLayout()
+ video_layout.setSpacing(0)
+ video_layout.setContentsMargins(2, 2, 2, 2)
+ video_layout.setAlignment(Qt.AlignCenter)
+
+ # 创建内部黑色背景容器
+ self.black_container = QWidget(self)
+ self.black_container.setObjectName('blackContainer')
+ self.black_container.setStyleSheet("""
+ #blackContainer {
+ background-color: black;
+ border-radius: 10px;
+ border: 0px solid transparent;
+ }
+ """)
+ black_layout = QVBoxLayout()
+ black_layout.setContentsMargins(0, 0, 0, 0)
+ black_layout.setSpacing(0)
+ black_layout.setAlignment(Qt.AlignCenter)
+
+ # 视频显示标签
+ self.video_display = QtWidgets.QLabel()
+ self.video_display.setStyleSheet("""
+ background-color: black;
+ border-top-left-radius: 10px;
+ border-top-right-radius: 10px;
+ border: 0px solid transparent;
+ """)
+ self.video_display.setMinimumSize(self.video_preview_width, self.video_preview_height)
+
+ self.video_display.setMouseTracking(True)
+ self.video_display.setScaledContents(True)
+ self.video_display.setAlignment(Qt.AlignCenter)
+ self.video_display.mousePressEvent = self.selection_mouse_press
+ self.video_display.mouseMoveEvent = self.selection_mouse_move
+ self.video_display.mouseReleaseEvent = self.selection_mouse_release
+
+ # 视频滑块
+ self.video_slider = QtWidgets.QSlider(Qt.Horizontal)
+ self.video_slider.setMinimum(1)
+ self.video_slider.setFixedHeight(22)
+ self.video_slider.setMaximum(100) # 默认最大值设为100,与进度百分比一致
+ self.video_slider.setValue(1)
+ self.video_slider.setStyle(HollowHandleStyle({
+ "handle.color": QtGui.QColor(255, 255, 255),
+ "handle.ring-width": 4,
+ "handle.hollow-radius": 6,
+ "handle.margin": 1
+ }))
+
+ # 视频预览区域
+ self.video_display.setObjectName('videoDisplay')
+ # black_layout.addWidget(self.video_display, 0, Qt.AlignCenter)
+ # 创建一个容器来保持比例
+ ratio_container = QWidget()
+ ratio_layout = QVBoxLayout(ratio_container)
+ ratio_layout.setContentsMargins(0, 0, 0, 0)
+ ratio_layout.addWidget(self.video_display)
+
+ # 设置固定的宽高比
+ ratio_container.setFixedHeight(ratio_container.width() * 9 // 16)
+ ratio_container.setMinimumWidth(self.video_preview_width)
+
+ # 添加到布局
+ black_layout.addWidget(ratio_container)
+
+ # 添加一个事件过滤器来处理大小变化
+ class RatioEventFilter(QObject):
+ def eventFilter(self, obj, event):
+ if event.type() == QEvent.Resize:
+ obj.setFixedHeight(obj.width() * 9 // 16)
+ return False
+
+ ratio_filter = RatioEventFilter(ratio_container)
+ ratio_container.installEventFilter(ratio_filter)
+
+ # 进度条和滑块容器
+ control_container = QWidget(self)
+ control_layout = QVBoxLayout()
+ control_layout.setContentsMargins(8, 8, 8, 8)
+ control_layout.addWidget(self.video_slider)
+
+ control_container.setLayout(control_layout)
+ control_container.setStyleSheet("""
+ background-color: black;
+ border-bottom-left-radius: 8px;
+ border-bottom-right-radius: 8px;
+ """)
+ black_layout.addWidget(control_container)
+
+ self.black_container.setLayout(black_layout)
+ video_layout.addWidget(self.black_container)
+ self.video_container.setLayout(video_layout)
+ main_layout.addWidget(self.video_container)
+
+ def update_video_display(self, frame, draw_selection=True):
+ """更新视频显示"""
+ if frame is None:
+ return
+
+ # 调整视频帧大小以适应视频预览区域
+ frame = cv2.resize(frame, (self.video_preview_width, self.video_preview_height))
+ # 将 OpenCV 帧(BGR 格式)转换为 QImage 并显示在 QLabel 上
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ h, w, ch = rgb_frame.shape
+ bytes_per_line = ch * w
+ image = QtGui.QImage(rgb_frame.data, w, h, bytes_per_line, QtGui.QImage.Format_RGB888)
+ pix = QtGui.QPixmap.fromImage(image)
+
+ # 创建带圆角的图像
+ rounded_pix = QtGui.QPixmap(pix.size())
+ rounded_pix.fill(Qt.transparent) # 填充透明背景
+
+ painter = QtGui.QPainter(rounded_pix)
+ painter.setRenderHint(QtGui.QPainter.Antialiasing) # 抗锯齿
+ painter.setRenderHint(QtGui.QPainter.SmoothPixmapTransform, True)
+
+ # 创建圆角路径
+ path = QtGui.QPainterPath()
+ rect = QRectF(0, 0, pix.width(), pix.height())
+
+ # 手动创建只有左上和右上圆角的路径
+ radius = 8
+ path.moveTo(radius, 0)
+ path.lineTo(pix.width() - radius, 0)
+ path.arcTo(pix.width() - radius * 2, 0, radius * 2, radius * 2, 90, -90)
+ path.lineTo(pix.width(), pix.height())
+ path.lineTo(0, pix.height())
+ path.lineTo(0, radius)
+ path.arcTo(0, 0, radius * 2, radius * 2, 180, -90)
+ path.closeSubpath()
+
+ painter.setClipPath(path)
+ painter.drawPixmap(0, 0, pix)
+ painter.end()
+
+ # 保存当前的pixmap用于绘制选择框
+ self.current_pixmap = rounded_pix.copy()
+
+ self.video_display.setPixmap(rounded_pix)
+
+ # 如果有保存的选择框比例,根据新视频尺寸重新计算选择框
+ if draw_selection and self.selection_ratio is not None and self.scaled_width and self.scaled_height:
+ x_ratio, y_ratio, w_ratio, h_ratio = self.selection_ratio
+
+ # 计算新的选择框坐标和大小
+ x = int(x_ratio * self.scaled_width) + self.border_left
+ y = int(y_ratio * self.scaled_height) + self.border_top
+ w = int(w_ratio * self.scaled_width)
+ h = int(h_ratio * self.scaled_height)
+
+ # 创建新的选择框
+ self.selection_rect = QRect(x, y, w, h)
+
+ # 更新视频显示
+ self.update_preview_with_rect()
+
+ def update_preview_with_rect(self, rect=None):
+ """更新带有选择框的预览"""
+ if not hasattr(self, 'current_pixmap') or self.current_pixmap is None:
+ return
+
+ # 如果提供了新的矩形,使用它
+ if rect is not None:
+ self.selection_rect = rect
+
+ # 创建一个副本用于绘制
+ pixmap_copy = self.current_pixmap.copy()
+ painter = QtGui.QPainter(pixmap_copy)
+
+ # 设置选择框样式
+ pen = QtGui.QPen(QtGui.QColor(0, 255, 0)) # 绿色
+ pen.setWidth(2)
+ painter.setPen(pen)
+
+ # 绘制选择框
+ painter.drawRect(self.selection_rect)
+ painter.end()
+
+ # 更新显示
+ self.video_display.setPixmap(pixmap_copy)
+
+ def selection_mouse_press(self, event):
+ """鼠标按下事件处理"""
+ if not self.enable_mouse_events:
+ return
+ pos = event.pos()
+
+ # 检测双击或三击,重置选择框
+ if event.type() == QtCore.QEvent.MouseButtonDblClick:
+ self.selection_rect = QRect(pos, pos)
+ self.resize_edge = None
+ self.is_drawing = True
+ self.drag_start_pos = pos
+ return
+
+ # 检查是否在选择框边缘(用于调整大小)
+ if self.selection_rect.isValid():
+ # 右下角
+ if abs(pos.x() - self.selection_rect.right()) <= self.edge_size and abs(pos.y() - self.selection_rect.bottom()) <= self.edge_size:
+ self.resize_edge = "bottomright"
+ self.drag_start_pos = pos
+ return
+ # 右上角
+ elif abs(pos.x() - self.selection_rect.right()) <= self.edge_size and abs(pos.y() - self.selection_rect.top()) <= self.edge_size:
+ self.resize_edge = "topright"
+ self.drag_start_pos = pos
+ return
+ # 左下角
+ elif abs(pos.x() - self.selection_rect.left()) <= self.edge_size and abs(pos.y() - self.selection_rect.bottom()) <= self.edge_size:
+ self.resize_edge = "bottomleft"
+ self.drag_start_pos = pos
+ return
+ # 左边缘
+ elif abs(pos.x() - self.selection_rect.left()) <= self.edge_size and self.selection_rect.top() <= pos.y() <= self.selection_rect.bottom():
+ self.resize_edge = "left"
+ self.drag_start_pos = pos
+ return
+ # 右边缘
+ elif abs(pos.x() - self.selection_rect.right()) <= self.edge_size and self.selection_rect.top() <= pos.y() <= self.selection_rect.bottom():
+ self.resize_edge = "right"
+ self.drag_start_pos = pos
+ return
+ # 上边缘
+ elif abs(pos.y() - self.selection_rect.top()) <= self.edge_size and self.selection_rect.left() <= pos.x() <= self.selection_rect.right():
+ self.resize_edge = "top"
+ self.drag_start_pos = pos
+ return
+ # 下边缘
+ elif abs(pos.y() - self.selection_rect.bottom()) <= self.edge_size and self.selection_rect.left() <= pos.x() <= self.selection_rect.right():
+ self.resize_edge = "bottom"
+ self.drag_start_pos = pos
+ return
+ # 左上角
+ elif abs(pos.x() - self.selection_rect.left()) <= self.edge_size and abs(pos.y() - self.selection_rect.top()) <= self.edge_size:
+ self.resize_edge = "topleft"
+ self.drag_start_pos = pos
+ return
+ # 在选择框内部(用于移动)
+ elif self.selection_rect.contains(pos):
+ self.resize_edge = "move"
+ self.drag_start_pos = pos
+ return
+
+ # 开始新的选择
+ self.is_drawing = True
+ self.selection_rect = QRect(pos, pos)
+ self.drag_start_pos = pos
+ self.resize_edge = None
+
+ def selection_mouse_move(self, event):
+ """鼠标移动事件处理"""
+ if not self.enable_mouse_events:
+ return
+ pos = event.pos()
+
+ # 根据不同的操作模式处理鼠标移动
+ if self.is_drawing: # 绘制新选择框
+ self.selection_rect.setBottomRight(pos)
+ self.update_preview_with_rect()
+ elif self.resize_edge: # 调整选择框大小或位置
+ if self.resize_edge == "move":
+ # 移动整个选择框
+ dx = pos.x() - self.drag_start_pos.x()
+ dy = pos.y() - self.drag_start_pos.y()
+
+ # 保存原始选择框尺寸
+ original_width = self.selection_rect.width()
+ original_height = self.selection_rect.height()
+
+ # 计算新位置
+ new_rect = self.selection_rect.translated(dx, dy)
+
+ # 获取视频显示区域
+ display_rect = self.video_display.rect()
+
+ # 检查是否超出边界,如果超出则调整位置但保持尺寸
+ if new_rect.left() < 0:
+ new_rect.moveLeft(0)
+ if new_rect.top() < 0:
+ new_rect.moveTop(0)
+ if new_rect.right() > display_rect.width():
+ new_rect.moveRight(display_rect.width())
+ if new_rect.bottom() > display_rect.height():
+ new_rect.moveBottom(display_rect.height())
+
+ # 确保尺寸不变
+ if new_rect.width() != original_width or new_rect.height() != original_height:
+ # 如果尺寸变了,恢复原始尺寸
+ if new_rect.left() == 0:
+ new_rect.setWidth(original_width)
+ if new_rect.top() == 0:
+ new_rect.setHeight(original_height)
+ if new_rect.right() == display_rect.width():
+ new_rect.setLeft(new_rect.right() - original_width)
+ if new_rect.bottom() == display_rect.height():
+ new_rect.setTop(new_rect.bottom() - original_height)
+
+ self.selection_rect = new_rect
+ self.drag_start_pos = pos
+ else:
+ # 调整选择框大小
+ if "left" in self.resize_edge:
+ self.selection_rect.setLeft(pos.x())
+ if "right" in self.resize_edge:
+ self.selection_rect.setRight(pos.x())
+ if "top" in self.resize_edge:
+ self.selection_rect.setTop(pos.y())
+ if "bottom" in self.resize_edge:
+ self.selection_rect.setBottom(pos.y())
+
+ # 确保选择框在视频显示区域内
+ display_rect = self.video_display.rect()
+ if self.selection_rect.left() < 0:
+ self.selection_rect.setLeft(0)
+ if self.selection_rect.top() < 0:
+ self.selection_rect.setTop(0)
+ if self.selection_rect.right() > display_rect.width():
+ self.selection_rect.setRight(display_rect.width())
+ if self.selection_rect.bottom() > display_rect.height():
+ self.selection_rect.setBottom(display_rect.height())
+
+ self.update_preview_with_rect()
+ else:
+ # 更新鼠标指针形状
+ self.update_cursor_shape(pos)
+
+ def selection_mouse_release(self, event):
+ """鼠标释放事件处理"""
+ if not self.enable_mouse_events:
+ return
+ # 结束绘制或调整
+ self.is_drawing = False
+ self.resize_edge = None
+
+ # 标准化选择框(确保宽度和高度为正)
+ self.selection_rect = self.selection_rect.normalized()
+
+ # 保存选择框的相对位置和大小
+ self.save_selection_ratio()
+
+ # 发送选择框变化信号
+ self.selection_changed.emit(self.selection_rect)
+
+ def update_cursor_shape(self, pos):
+ """根据鼠标位置更新光标形状"""
+ if not self.selection_rect.isValid():
+ self.video_display.setCursor(Qt.ArrowCursor)
+ return
+
+ # 检查鼠标是否在选择框边缘
+ if (abs(pos.x() - self.selection_rect.left()) <= self.edge_size and
+ self.selection_rect.top() <= pos.y() <= self.selection_rect.bottom()):
+ self.video_display.setCursor(Qt.SizeHorCursor)
+ elif (abs(pos.x() - self.selection_rect.right()) <= self.edge_size and
+ self.selection_rect.top() <= pos.y() <= self.selection_rect.bottom()):
+ self.video_display.setCursor(Qt.SizeHorCursor)
+ elif (abs(pos.y() - self.selection_rect.top()) <= self.edge_size and
+ self.selection_rect.left() <= pos.x() <= self.selection_rect.right()):
+ self.video_display.setCursor(Qt.SizeVerCursor)
+ elif (abs(pos.y() - self.selection_rect.bottom()) <= self.edge_size and
+ self.selection_rect.left() <= pos.x() <= self.selection_rect.right()):
+ self.video_display.setCursor(Qt.SizeVerCursor)
+ elif (abs(pos.x() - self.selection_rect.left()) <= self.edge_size and
+ abs(pos.y() - self.selection_rect.top()) <= self.edge_size):
+ self.video_display.setCursor(Qt.SizeFDiagCursor)
+ elif (abs(pos.x() - self.selection_rect.right()) <= self.edge_size and
+ abs(pos.y() - self.selection_rect.top()) <= self.edge_size):
+ self.video_display.setCursor(Qt.SizeBDiagCursor)
+ elif (abs(pos.x() - self.selection_rect.left()) <= self.edge_size and
+ abs(pos.y() - self.selection_rect.bottom()) <= self.edge_size):
+ self.video_display.setCursor(Qt.SizeBDiagCursor)
+ elif (abs(pos.x() - self.selection_rect.right()) <= self.edge_size and
+ abs(pos.y() - self.selection_rect.bottom()) <= self.edge_size):
+ self.video_display.setCursor(Qt.SizeFDiagCursor)
+ elif self.selection_rect.contains(pos):
+ self.video_display.setCursor(Qt.SizeAllCursor)
+ else:
+ self.video_display.setCursor(Qt.ArrowCursor)
+
+ def set_video_parameters(self, frame_width, frame_height, scaled_width=None, scaled_height=None, border_left=0, border_top=0):
+ """设置视频参数"""
+ self.frame_width = frame_width
+ self.frame_height = frame_height
+ self.scaled_width = scaled_width
+ self.scaled_height = scaled_height
+ self.border_left = border_left
+ self.border_top = border_top
+
+ def get_selection_coordinates(self):
+ """获取选择框坐标"""
+ return self.selection_rect
+
+ def set_selection_rect(self, rect):
+ """设置选择框"""
+ self.selection_rect = rect
+ self.save_selection_ratio()
+ self.update_preview_with_rect()
+
+ def load_selection_ratio(self):
+ """从配置中加载选择框的相对位置和大小"""
+ # 检查是否有有效的视频尺寸
+ if not self.scaled_width or not self.scaled_height:
+ return False
+
+ # 从配置中读取选择框的相对位置和大小
+ x_ratio = config.subtitleSelectionAreaX.value
+ y_ratio = config.subtitleSelectionAreaY.value
+ w_ratio = config.subtitleSelectionAreaW.value
+ h_ratio = config.subtitleSelectionAreaH.value
+
+ # 检查配置值是否有效
+ if x_ratio is None or y_ratio is None or w_ratio is None or h_ratio is None:
+ return False
+
+ # 检查配置值是否在有效范围内
+ if w_ratio <= 0.01 or h_ratio <= 0.005:
+ config.set(config.subtitleSelectionAreaX, config.subtitleSelectionAreaX.defaultValue)
+ config.set(config.subtitleSelectionAreaY, config.subtitleSelectionAreaY.defaultValue)
+ config.set(config.subtitleSelectionAreaW, config.subtitleSelectionAreaW.defaultValue)
+ config.set(config.subtitleSelectionAreaH, config.subtitleSelectionAreaH.defaultValue)
+ x_ratio = config.subtitleSelectionAreaX.value
+ y_ratio = config.subtitleSelectionAreaY.value
+ w_ratio = config.subtitleSelectionAreaW.value
+ h_ratio = config.subtitleSelectionAreaH.value
+
+ # 保存选择框比例
+ self.selection_ratio = (x_ratio, y_ratio, w_ratio, h_ratio)
+
+ # 计算实际像素坐标
+ x = int(x_ratio * self.scaled_width) + self.border_left
+ y = int(y_ratio * self.scaled_height) + self.border_top
+ w = int(w_ratio * self.scaled_width)
+ h = int(h_ratio * self.scaled_height)
+
+ # 创建选择框
+ self.selection_rect = QRect(x, y, w, h)
+
+ # 更新预览
+ self.update_preview_with_rect()
+
+ return True
+
+ def save_selection_ratio(self):
+ """保存选择框的相对位置和大小(相对于实际视频的比例)"""
+ if not self.selection_rect.isValid() or not self.scaled_width or not self.scaled_height:
+ return
+
+ # 调整选择框坐标,考虑黑边偏移
+ x_adjusted = max(0, self.selection_rect.x() - self.border_left)
+ y_adjusted = max(0, self.selection_rect.y() - self.border_top)
+
+ # 如果选择框超出了实际视频区域,需要调整宽度和高度
+ w_adjusted = min(self.selection_rect.width(), self.scaled_width - x_adjusted)
+ h_adjusted = min(self.selection_rect.height(), self.scaled_height - y_adjusted)
+
+ # 转换为相对比例
+ x_ratio = x_adjusted / self.scaled_width
+ y_ratio = y_adjusted / self.scaled_height
+ w_ratio = w_adjusted / self.scaled_width
+ h_ratio = h_adjusted / self.scaled_height
+
+ self.selection_ratio = (x_ratio, y_ratio, w_ratio, h_ratio)
+
+ config.subtitleSelectionAreaY.value = y_ratio
+ config.subtitleSelectionAreaH.value = h_ratio
+ config.subtitleSelectionAreaX.value = x_ratio
+ config.subtitleSelectionAreaW.value = w_ratio
+
+ qconfig.save()
+
+ def get_original_coordinates(self):
+ """获取选择框在原始视频中的坐标"""
+ if not self.selection_rect.isValid() or not self.scaled_width or not self.scaled_height:
+ return None
+
+ # 调整选择框坐标,考虑黑边偏移
+ x_adjusted = max(0, self.selection_rect.x() - self.border_left)
+ y_adjusted = max(0, self.selection_rect.y() - self.border_top)
+
+ # 如果选择框超出了实际视频区域,需要调整宽度和高度
+ w_adjusted = min(self.selection_rect.width(), self.scaled_width - x_adjusted)
+ h_adjusted = min(self.selection_rect.height(), self.scaled_height - y_adjusted)
+
+ # 转换为原始视频坐标
+ scale_x = self.frame_width / self.scaled_width
+ scale_y = self.frame_height / self.scaled_height
+
+ xmin = int(x_adjusted * scale_x)
+ xmax = int((x_adjusted + w_adjusted) * scale_x)
+ ymin = int(y_adjusted * scale_y)
+ ymax = int((y_adjusted + h_adjusted) * scale_y)
+
+ # 确保坐标在有效范围内
+ xmin = max(0, min(xmin, self.frame_width))
+ xmax = max(0, min(xmax, self.frame_width))
+ ymin = max(0, min(ymin, self.frame_height))
+ ymax = max(0, min(ymax, self.frame_height))
+
+ return (ymin, ymax, xmin, xmax)
+
+ def set_dragger_enabled(self, enabled):
+ """设置拖动器是否可用"""
+ self.enable_mouse_events = enabled
+ self.video_display.setMouseTracking(enabled)
+ self.video_display.setCursor(Qt.ArrowCursor)
\ No newline at end of file
diff --git a/ui/home_interface.py b/ui/home_interface.py
new file mode 100644
index 0000000..5cb943e
--- /dev/null
+++ b/ui/home_interface.py
@@ -0,0 +1,586 @@
+import os
+import cv2
+import threading
+import atexit
+import multiprocessing
+import time
+import traceback
+from pathlib import Path
+from multiprocessing import managers
+from PySide6.QtWidgets import QWidget, QHBoxLayout, QVBoxLayout, QLabel, QPushButton, QFileDialog
+from PySide6.QtCore import Qt, Slot, QTimer, QRect, QRectF, Signal
+from PySide6 import QtCore, QtWidgets, QtGui
+from qfluentwidgets import (qconfig, PushButton, CardWidget, SubtitleLabel, PlainTextEdit,
+ FluentIcon, HollowHandleStyle)
+from ui.setting_interface import SettingInterface
+from ui.component.video_display_component import VideoDisplayComponent
+from ui.component.task_list_component import TaskListComponent, TaskStatus
+from ui.icon.my_fluent_icon import MyFluentIcon
+from backend.config import config, tr
+from backend.tools.subtitle_remover_remote_call import SubtitleRemoverRemoteCall
+from backend.tools.process_manager import ProcessManager
+from backend.tools.common_tools import get_readable_path, is_image_file, read_image
+
+class HomeInterface(QWidget):
+ progress_signal = Signal(int, bool)
+ append_log_signal = Signal(list)
+ update_preview_with_comp_signal = Signal(list)
+ task_error_signal = Signal(object)
+ def __init__(self, parent=None):
+ super().__init__(parent=parent)
+ self.setObjectName("HomeInterface")
+ # 初始化一些变量
+ self.video_path = None
+ self.video_cap = None
+ self.fps = None
+ self.frame_count = None
+ self.frame_width = None
+ self.frame_height = None
+ self.se = None # 后台字幕提取器
+
+ # 字幕区域参数
+ self.xmin = None
+ self.xmax = None
+ self.ymin = None
+ self.ymax = None
+
+ # 添加自动滚动控制标志
+ self.auto_scroll = True
+ self.running_task = False
+ self.running_process = None
+
+ # 当前正在处理的任务索引
+ self.current_processing_task_index = -1
+
+ self.__initWidget()
+ self.progress_signal.connect(self.update_progress)
+ self.append_log_signal.connect(self.append_log)
+ self.update_preview_with_comp_signal.connect(self.update_preview_with_comp)
+ self.task_error_signal.connect(self.on_task_error)
+
+ def __initWidget(self):
+ """创建主页面"""
+ main_layout = QHBoxLayout(self)
+ main_layout.setSpacing(8)
+ main_layout.setContentsMargins(16, 16, 16, 16)
+
+ # 左侧视频区域
+ left_layout = QVBoxLayout()
+ left_layout.setSpacing(8)
+
+ # 创建视频显示组件
+ self.video_display_component = VideoDisplayComponent(self)
+ left_layout.addWidget(self.video_display_component)
+
+ # 获取视频显示和滑块的引用
+ self.video_display = self.video_display_component.video_display
+ self.video_slider = self.video_display_component.video_slider
+ self.video_slider.valueChanged.connect(self.slider_changed)
+
+ # 输出文本区域
+ self.output_text = PlainTextEdit()
+ self.output_text.setMinimumHeight(150)
+ self.output_text.setReadOnly(True)
+ self.output_text.document().setDocumentMargin(10)
+ # 连接滚动条值变化信号
+ self.output_text.verticalScrollBar().valueChanged.connect(self.on_scroll_change)
+
+ output_container = CardWidget(self)
+ output_layout = QVBoxLayout()
+ output_layout.setContentsMargins(0, 0, 0, 0)
+ output_layout.addWidget(self.output_text)
+ output_container.setLayout(output_layout)
+ left_layout.addWidget(output_container)
+
+ main_layout.addLayout(left_layout, 2)
+
+ # 右侧设置区域
+ right_layout = QVBoxLayout()
+ right_layout.setSpacing(10)
+
+ # 设置容器
+ settings_container = CardWidget(self)
+ settings_container.setLayout(SettingInterface(settings_container))
+ right_layout.addWidget(settings_container)
+
+ # 添加任务列表容器
+ task_list_container = CardWidget(self)
+ task_list_layout = QHBoxLayout()
+ task_list_layout.setContentsMargins(0, 0, 0, 0)
+ task_list_layout.setSpacing(0)
+ self.task_list_component = TaskListComponent(self)
+ self.task_list_component.task_selected.connect(self.on_task_selected)
+ self.task_list_component.task_deleted.connect(self.on_task_deleted)
+ task_list_layout.addWidget(self.task_list_component)
+ task_list_container.setLayout(task_list_layout)
+ right_layout.addWidget(task_list_container, 1) # 占满剩余空间
+
+ # 操作按钮容器
+ button_container = CardWidget(self)
+ button_layout = QHBoxLayout()
+ button_layout.setContentsMargins(16, 16, 16, 16)
+ button_layout.setSpacing(8)
+
+ self.file_button = PushButton(tr['SubtitleExtractorGUI']['Open'], self)
+ self.file_button.setIcon(FluentIcon.FOLDER)
+ self.file_button.clicked.connect(self.open_file)
+ button_layout.addWidget(self.file_button)
+
+ self.run_button = PushButton(tr['SubtitleExtractorGUI']['Run'], self)
+ self.run_button.setIcon(FluentIcon.PLAY)
+ self.run_button.clicked.connect(self.run_button_clicked)
+ button_layout.addWidget(self.run_button)
+
+ self.stop_button = PushButton(tr['SubtitleExtractorGUI']['Stop'], self)
+ self.stop_button.setIcon(MyFluentIcon.Stop)
+ self.stop_button.setVisible(False)
+ self.stop_button.clicked.connect(self.stop_button_clicked)
+
+ button_layout.addWidget(self.stop_button)
+
+ button_container.setLayout(button_layout)
+ right_layout.addWidget(button_container)
+
+ main_layout.addLayout(right_layout, 1)
+
+ def on_scroll_change(self, value):
+ """监控滚动条位置变化"""
+ scrollbar = self.output_text.verticalScrollBar()
+ # 如果滚动到底部,启用自动滚动
+ if value == scrollbar.maximum():
+ self.auto_scroll = True
+ # 如果用户向上滚动,禁用自动滚动
+ elif self.auto_scroll and value < scrollbar.maximum():
+ self.auto_scroll = False
+
+
+ def slider_changed(self, value):
+ if self.video_cap is not None and self.video_cap.isOpened():
+ frame_no = self.video_slider.value()
+ self.video_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
+ ret, frame = self.video_cap.read()
+ if ret:
+ # 更新预览图像
+ self.update_preview(frame)
+
+ def on_task_selected(self, index, file_path):
+ """处理任务被选中事件
+
+ Args:
+ index: 任务索引
+ file_path: 文件路径
+ """
+ # 加载选中的视频进行预览
+ self.load_video(file_path)
+
+ def on_task_deleted(self, index):
+ """处理任务被删除事件
+
+ Args:
+ index: 任务索引
+ """
+ # 如果删除的是正在处理的任务,则需要更新状态
+ if index == self.current_processing_task_index:
+ self.current_processing_task_index = -1
+
+ task = self.task_list_component.get_task(0)
+ if task:
+ # 如果还有任务,选中第一个
+ self.load_video(task.path)
+ self.task_list_component.select_task(0)
+
+ def update_preview(self, frame):
+ # 先缩放图像
+ resized_frame = self._img_resize(frame)
+
+ # 设置视频参数
+ self.video_display_component.set_video_parameters(
+ self.frame_width, self.frame_height,
+ self.scaled_width if hasattr(self, 'scaled_width') else None,
+ self.scaled_height if hasattr(self, 'scaled_height') else None,
+ self.border_left if hasattr(self, 'border_left') else 0,
+ self.border_top if hasattr(self, 'border_top') else 0
+ )
+
+ # 更新视频显示(这会同时保存current_pixmap)
+ self.video_display_component.update_video_display(resized_frame)
+
+ def _img_resize(self, image):
+ height, width = image.shape[:2]
+
+ video_preview_width = self.video_display_component.video_preview_width
+ video_preview_height = self.video_display_component.video_preview_height
+ # 计算等比缩放后的尺寸
+ target_ratio = video_preview_width / video_preview_height
+ image_ratio = width / height
+
+ if image_ratio > target_ratio:
+ # 宽度适配,高度按比例缩放
+ new_width = video_preview_width
+ new_height = int(new_width / image_ratio)
+ top_border = (video_preview_height - new_height) // 2
+ bottom_border = video_preview_height - new_height - top_border
+ left_border = 0
+ right_border = 0
+ else:
+ # 高度适配,宽度按比例缩放
+ new_height = video_preview_height
+ new_width = int(new_height * image_ratio)
+ left_border = (video_preview_width - new_width) // 2
+ right_border = video_preview_width - new_width - left_border
+ top_border = 0
+ bottom_border = 0
+
+ # 先缩放图像
+ resized = cv2.resize(image, (new_width, new_height))
+
+ # 添加黑边以填充到目标尺寸
+ padded = cv2.copyMakeBorder(
+ resized,
+ top_border, bottom_border,
+ left_border, right_border,
+ cv2.BORDER_CONSTANT,
+ value=[0, 0, 0]
+ )
+
+ # 保存边框信息,用于坐标转换
+ self.border_left = left_border
+ self.border_right = right_border
+ self.border_top = top_border
+ self.border_bottom = bottom_border
+ self.original_width = width
+ self.original_height = height
+ self.is_vertical = width < height
+ self.scaled_width = new_width
+ self.scaled_height = new_height
+
+ return padded
+
+ def stop_button_clicked(self):
+ try:
+ self.running_task = False
+ running_process = self.running_process
+ if running_process:
+ ProcessManager.instance().terminate_by_process(running_process)
+ # 更新任务状态为待处理
+ if self.current_processing_task_index >= 0:
+ self.task_list_component.update_task_status(self.current_processing_task_index, TaskStatus.PENDING)
+ finally:
+ self.running_process = None
+ self.run_button.setVisible(True)
+ self.stop_button.setVisible(False)
+
+ def run_button_clicked(self):
+ if not self.task_list_component.get_pending_tasks():
+ self.append_output(tr['SubtitleExtractorGUI']['OpenVideoFirst'])
+ return
+
+ try:
+ # 获取所有待执行的任务
+ pending_tasks = self.task_list_component.get_pending_tasks()
+ if not pending_tasks:
+ return
+
+ self.run_button.setVisible(False)
+ self.stop_button.setVisible(True)
+ # 开启后台线程处理视频
+ def task():
+ self.running_task = True
+ try:
+ while self.running_task:
+ try:
+ pending_tasks = self.task_list_component.get_pending_tasks()
+ if not pending_tasks:
+ break
+ pending_task = pending_tasks[0]
+ # 更新当前处理的任务索引
+ self.current_processing_task_index, task = pending_task
+ if not self.load_video(task.path):
+ self.append_output(tr['SubtitleExtractorGUI']['OpenVideoFailed'].format(task.path))
+ self.task_list_component.update_task_status(self.current_processing_task_index, TaskStatus.FAILED)
+ continue
+
+ # 更新任务状态为运行中
+ self.task_list_component.update_task_progress(self.current_processing_task_index, 1)
+
+ # 选中当前任务
+ self.task_list_component.select_task(self.current_processing_task_index)
+
+ if self.video_cap:
+ self.video_cap.release()
+ self.video_cap = None
+
+
+ # 获取字幕区域坐标(直接从视频显示组件获取)
+ subtitle_area = self.video_display_component.get_original_coordinates()
+ if not subtitle_area:
+ self.append_output(tr['SubtitleExtractorGUI']['SelectSubtitleArea'])
+ return
+ self.append_output(f"{tr['SubtitleExtractorGUI']['SubtitleArea']}: {subtitle_area}")
+
+ self.task_list_component.update_task_status(self.current_processing_task_index, TaskStatus.PROCESSING)
+ process = self.run_subtitle_remover_process(task.path, task.output_path, subtitle_area)
+
+ # 更新任务状态为已完成
+ task = self.task_list_component.get_task(self.current_processing_task_index)
+ if process.exitcode == 0 and task and task.status == TaskStatus.PROCESSING:
+ self.progress_signal.emit(100, True)
+ self.task_list_component.update_task_status(self.current_processing_task_index, TaskStatus.COMPLETED)
+ else:
+ self.task_list_component.update_task_status(self.current_processing_task_index, TaskStatus.FAILED)
+
+ except Exception as e:
+ print(e)
+ self.append_output(f"Error: {e}")
+ # 更新任务状态为失败
+ if self.current_processing_task_index >= 0:
+ self.task_list_component.update_task_status(self.current_processing_task_index, TaskStatus.FAILED)
+ break
+ finally:
+ if self.video_cap:
+ self.video_cap.release()
+ self.video_cap = None
+ time.sleep(1)
+ finally:
+ self.running_task = False
+ self.run_button.setVisible(True)
+ self.stop_button.setVisible(False)
+
+ threading.Thread(target=task, daemon=True).start()
+ except Exception as e:
+ print(traceback.format_exc())
+ self.append_output(f"Error: {e}")
+ # 没有待执行的任务,恢复按钮状态
+ self.run_button.setVisible(True)
+ self.stop_button.setVisible(False)
+
+ @staticmethod
+ def remover_process(queue, video_path, output_path, subtitle_area):
+ """
+ 在子进程中执行字幕提取的函数
+
+ Args:
+ video_path: 视频文件路径
+ output_path: 输出文件路径
+ subtitle_area: 字幕区域坐标 (ymin, ymax, xmin, xmax)
+ """
+ sr = None
+ try:
+ from backend.main import SubtitleRemover
+ sr = SubtitleRemover(video_path, subtitle_area, True)
+ sr.video_out_path = output_path
+ sr.add_progress_listener(lambda progress, isFinished: SubtitleRemoverRemoteCall.remote_call_update_progress(queue, progress, isFinished))
+ sr.append_output = lambda *args: SubtitleRemoverRemoteCall.remote_call_append_log(queue, args)
+ sr.manage_process = lambda pid: SubtitleRemoverRemoteCall.remote_call_manage_process(queue, pid)
+ sr.update_preview_with_comp = lambda *args: SubtitleRemoverRemoteCall.remote_call_update_preview_with_comp(queue, args)
+ sr.run()
+ except Exception as e:
+ traceback.print_exc()
+ SubtitleRemoverRemoteCall.remote_call_catch_error(queue, e)
+ finally:
+ if sr:
+ sr.isFinished = True
+ sr.vsf_running = False
+ SubtitleRemoverRemoteCall.remote_call_finish(queue)
+
+
+ # 修改run_subtitle_remover_process方法
+ def run_subtitle_remover_process(self, video_path, output_path, subtitle_area):
+ """
+ 使用多进程执行字幕提取,并等待进程完成
+
+ Args:
+ video_path: 视频文件路径
+ output_path: 输出文件路径
+ subtitle_area: 字幕区域坐标 (ymin, ymax, xmin, xmax)
+ """
+ subtitle_remover_remote_caller = SubtitleRemoverRemoteCall()
+ subtitle_remover_remote_caller.register_update_progress_callback(self.progress_signal.emit)
+ subtitle_remover_remote_caller.register_log_callback(self.append_log_signal.emit)
+ subtitle_remover_remote_caller.register_update_preview_with_comp_callback(self.update_preview_with_comp_signal.emit)
+ subtitle_remover_remote_caller.register_error_callback(self.task_error_signal.emit)
+ process = multiprocessing.Process(
+ target=HomeInterface.remover_process,
+ args=(subtitle_remover_remote_caller.queue, video_path, output_path, subtitle_area)
+ )
+ try:
+ if not self.running_task:
+ return process
+ process.start()
+ ProcessManager.instance().add_process(process)
+ self.running_process = process
+ process.join()
+ print(f"Process exited with code {process.exitcode}")
+ finally:
+ subtitle_remover_remote_caller.stop()
+ return process
+
+ @Slot()
+ def processing_finished(self):
+ pending_tasks = self.task_list_component.get_pending_tasks()
+ if pending_tasks:
+ # 还有待执行任务, 忽略
+ return
+ # 处理完成后恢复界面可用性
+ self.run_button.setVisible(True)
+ self.stop_button.setVisible(False)
+ self.se = None
+ # 重置视频滑块
+ self.video_slider.setValue(1)
+ # 重置当前处理任务索引
+ self.current_processing_task_index = -1
+
+ @Slot(int, bool)
+ def update_progress(self, progress_total, isFinished):
+ try:
+ pos = min(self.frame_count - 1, int(progress_total / 100 * self.frame_count))
+ if pos != self.video_slider.value():
+ self.video_slider.blockSignals(True)
+ self.video_slider.setValue(pos)
+ self.video_slider.blockSignals(False)
+
+ # 更新任务进度
+ if self.current_processing_task_index >= 0:
+ self.task_list_component.update_task_progress(
+ self.current_processing_task_index,
+ progress_total,
+ )
+
+ # 检查是否完成
+ if isFinished:
+ self.processing_finished()
+ except Exception as e:
+ # 捕获任何异常,防止崩溃
+ print(f"更新进度时出错: {str(e)}")
+
+ @Slot(list)
+ def append_log(self, log):
+ self.append_output(*log)
+
+ def append_output(self, *args):
+ """添加文本到输出区域并控制滚动
+ Args:
+ *args: 要输出的内容,多个参数将用空格连接
+ """
+ # 将所有参数转换为字符串并用空格连接
+ text = ' '.join(str(arg) for arg in args).rstrip()
+ self.output_text.appendPlainText(text)
+ print(*args) # 保持原始的 print 行为
+ # 如果启用了自动滚动,则滚动到底部
+ if self.auto_scroll:
+ scrollbar = self.output_text.verticalScrollBar()
+ scrollbar.setValue(scrollbar.maximum())
+
+ @Slot(list)
+ def update_preview_with_comp(self, args):
+ """更新执行时预览"""
+ frame_ori, frame_comp = args
+
+ subtitle_area = self.video_display_component.get_original_coordinates()
+ if subtitle_area:
+ ymin, ymax, xmin, xmax = subtitle_area
+ if frame_ori is frame_comp:
+ frame_ori = frame_ori.copy()
+ cv2.rectangle(frame_ori, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
+ preview_frame = cv2.hconcat([frame_ori, frame_comp])
+ # 先缩放图像
+ resized_frame = self._img_resize(preview_frame)
+ # 更新视频显示(这会同时保存current_pixmap)
+ self.video_display_component.update_video_display(resized_frame, draw_selection=False)
+ self.video_display_component.set_dragger_enabled(False)
+
+ @Slot(object)
+ def on_task_error(self, e):
+ self.append_output(tr['SubtitleExtractorGUI']['ErrorDuringProcessing'].format(str(e)))
+ if self.current_processing_task_index >= 0:
+ self.task_list_component.update_task_status(self.current_processing_task_index, TaskStatus.FAILED)
+
+ def load_video(self, video_path):
+ self.video_path = video_path
+ if self.video_cap:
+ self.video_cap.release()
+ self.video_cap = None
+ self.video_cap = cv2.VideoCapture(get_readable_path(self.video_path))
+ if not self.video_cap.isOpened():
+ return self.load_as_picture(video_path)
+ ret, frame = self.video_cap.read()
+ if not ret:
+ return self.load_as_picture(video_path)
+ self.frame_count = int(self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ self.frame_height = int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ self.frame_width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ self.fps = self.video_cap.get(cv2.CAP_PROP_FPS)
+
+ self.update_preview(frame)
+ self.video_display_component.load_selection_ratio()
+ self.video_slider.setMaximum(self.frame_count)
+ self.video_slider.setValue(1)
+ self.video_display_component.set_dragger_enabled(True)
+ return True
+
+ def load_as_picture(self, path):
+ if not is_image_file(path):
+ return False
+ self.video_path = path
+ self.video_cap = None
+ frame = read_image(get_readable_path(path))
+ if frame is None:
+ return False
+ self.frame_count = 1
+ self.frame_height = frame.shape[0]
+ self.frame_width = frame.shape[1]
+ self.fps = 1
+ self.update_preview(frame)
+ self.video_display_component.load_selection_ratio()
+ self.video_slider.setMaximum(self.frame_count)
+ self.video_slider.setValue(1)
+ self.video_display_component.set_dragger_enabled(True)
+ return True
+
+
+ def open_file(self):
+ files, _ = QtWidgets.QFileDialog.getOpenFileNames(
+ self,
+ tr['SubtitleExtractorGUI']['Open'],
+ "",
+ "All Files (*.*);;MP4 Files (*.mp4);;FLV Files (*.flv);;WMV Files (*.wmv);;AVI Files (*.avi)"
+ )
+ if files:
+ files_loaded = []
+ # 倒序打开, 确保第一个视频截图显示在屏幕上
+ for path in reversed(files):
+ if self.load_video(path):
+ self.append_output(f"{tr['SubtitleExtractorGUI']['OpenVideoSuccess']}: {path}")
+ files_loaded.append(path)
+ else:
+ self.append_output(f"{tr['SubtitleExtractorGUI']['OpenVideoFailed']}: {path}")
+ # 正序添加, 确保任务列表顺序一致
+ for path in reversed(files_loaded):
+ # 添加到任务列表
+ if is_image_file(path):
+ output_path = os.path.abspath(os.path.join(os.path.dirname(path), f'{Path(path).stem}_no_sub.png'))
+ else:
+ output_path = os.path.abspath(os.path.join(os.path.dirname(path), f'{Path(path).stem}_no_sub.mp4'))
+ self.task_list_component.add_task(path, output_path)
+ self.task_list_component.select_task(max(0, self.task_list_component.find_task_index_by_path(path)))
+
+ def closeEvent(self, event):
+ """窗口关闭时断开信号连接"""
+ try:
+ # 断开信号连接
+ self.progress_signal.disconnect(self.update_progress)
+ self.append_log_signal.disconnect(self.append_log)
+ self.update_preview_with_comp_signal.disconnect(self.update_preview_with_comp)
+ self.task_error_signal.disconnect(self.on_task_error)
+
+ # 释放视频资源
+ if self.video_cap:
+ self.video_cap.release()
+ self.video_cap = None
+
+ # 确保所有子进程都已终止
+ ProcessManager.instance().terminate_all()
+ except Exception as e:
+ print(f"关闭窗口时出错: {str(e)}")
+ super().closeEvent(event)
+
\ No newline at end of file
diff --git a/ui/icon/my_fluent_icon.py b/ui/icon/my_fluent_icon.py
new file mode 100644
index 0000000..6848506
--- /dev/null
+++ b/ui/icon/my_fluent_icon.py
@@ -0,0 +1,11 @@
+from enum import Enum
+
+from qfluentwidgets import getIconColor, Theme, FluentIconBase
+
+
+class MyFluentIcon(FluentIconBase, Enum):
+ Stop = "stop"
+
+ def path(self, theme=Theme.AUTO):
+ # getIconColor() return "white" or "black" according to current theme
+ return f'./ui/icon/{self.value}_{getIconColor(theme)}.svg'
diff --git a/ui/icon/stop_black.svg b/ui/icon/stop_black.svg
new file mode 100644
index 0000000..e4a0c7d
--- /dev/null
+++ b/ui/icon/stop_black.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/ui/icon/stop_white.svg b/ui/icon/stop_white.svg
new file mode 100644
index 0000000..095ad12
--- /dev/null
+++ b/ui/icon/stop_white.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/ui/setting_interface.py b/ui/setting_interface.py
new file mode 100644
index 0000000..67527da
--- /dev/null
+++ b/ui/setting_interface.py
@@ -0,0 +1,70 @@
+from PySide6 import QtWidgets
+from qfluentwidgets import (FluentWindow, PushButton, Slider, ProgressBar, PlainTextEdit,
+ setTheme, Theme, FluentIcon, CardWidget, SettingCardGroup,
+ ComboBoxSettingCard, SwitchSettingCard, RangeSettingCard,
+ PushSettingCard, PrimaryPushSettingCard, OptionsSettingCard,
+ FolderListSettingCard, HyperlinkCard, ColorSettingCard,
+ CustomColorSettingCard)
+from backend.config import config, tr, HARDWARD_ACCELERATION_OPTION
+from backend.tools.constant import InpaintMode, SubtitleDetectMode
+
+class SettingInterface(QtWidgets.QVBoxLayout):
+
+ def __init__(self, parent):
+ super().__init__()
+ self.setContentsMargins(16, 16, 16, 16)
+
+ # 界面语言设置
+ self.interface_combo = ComboBoxSettingCard(
+ configItem=config.interface,
+ icon=FluentIcon.LANGUAGE,
+ title=tr["SubtitleExtractorGUI"]["InterfaceLanguage"],
+ content="",
+ parent=parent,
+ texts=config.intefaceTexts.keys(),
+ )
+ self.addWidget(self.interface_combo)
+
+ # 处理模式设置
+ self.inpaint_mode_combo = ComboBoxSettingCard(
+ configItem=config.inpaintMode,
+ icon=FluentIcon.GLOBE,
+ title=tr["SubtitleExtractorGUI"]["InpaintMode"],
+ content="",
+ parent=parent,
+ texts=[list(tr['InpaintMode'].values())[i] for i,_ in enumerate(config.inpaintMode.validator.options)],
+ )
+ self.inpaint_mode_combo.setToolTip(tr["SubtitleExtractorGUI"]["InpaintModeDesc"])
+ self.addWidget(self.inpaint_mode_combo)
+
+ self.subtitle_detect_model_combo = ComboBoxSettingCard(
+ configItem=config.subtitleDetectMode,
+ icon=FluentIcon.SEARCH,
+ title=tr["SubtitleExtractorGUI"]["SubtitleDetectMode"],
+ content="",
+ parent=parent,
+ texts=[list(tr['SubtitleDetectMode'].values())[i] for i,_ in enumerate(config.subtitleDetectMode.validator.options)],
+ )
+ self.addWidget(self.subtitle_detect_model_combo)
+
+ # 是否启用硬件加速
+ self.hardware_acceleration = SwitchSettingCard(
+ configItem=config.hardwareAcceleration,
+ icon=FluentIcon.SPEED_HIGH,
+ title=tr["Setting"]["HardwareAcceleration"],
+ content=tr["Setting"]["HardwareAccelerationDesc"],
+ parent=parent
+ )
+ self.addWidget(self.hardware_acceleration)
+ # 如果硬件加速选项被禁用, 设置硬件加速为False并只读
+ if not HARDWARD_ACCELERATION_OPTION:
+ self.hardware_acceleration.switchButton.setChecked(False)
+ self.hardware_acceleration.switchButton.setEnabled(False)
+ config.set(config.hardwareAcceleration, False)
+ # 添加一些空间
+ self.addStretch(1)
+
+ def reset_setting(self):
+ """重置所有设置为默认值"""
+ # 这里需要实现重置逻辑
+ pass
\ No newline at end of file