70 Commits

Author SHA1 Message Date
天涯古巷
53baf28326 Revert "Switch to PySimpleGUI-4-foss because the PySimpleGUI author deliberately removed the free legacy version" 2025-04-25 11:03:16 +08:00
天涯古巷
991294c00f Merge pull request #130 from eritpchy/main
Switch to PySimpleGUI-4-foss because the PySimpleGUI author deliberately removed the free legacy version
2025-04-25 07:36:30 +08:00
Jason
ea7e01e3aa Publish Docker image 2025-04-24 23:45:54 +08:00
Jason
acdb150aa2 Support automated CI releases 2025-04-24 16:22:01 +08:00
Jason
285bfbafa7 Update requirements.txt 2025-04-24 16:14:43 +08:00
Jason
97b4159d38 DirectML build supports running the STTN model (Windows) 2025-04-24 15:56:13 +08:00
Jason
bb80445cf4 Fix inpaint_area error at the end of processing 2025-04-24 15:53:28 +08:00
Jason
7e8d0b818b Show the version number in the main UI to make troubleshooting easier 2025-04-24 15:51:21 +08:00
Jason
77758d258b Adapt to torch 2.8.0 nightly builds 2025-04-24 15:47:23 +08:00
Jason
c60234f4ec Switch to PaddleOCR, following upstream updates 2025-04-24 15:46:38 +08:00
Jason
38ff91fad7 Switch to PySimpleGUI-4-foss because the PySimpleGUI author deliberately removed the free legacy version 2025-02-28 15:35:59 +08:00
天涯古巷
9f9cded1ff Update README.md 2025-02-19 09:34:37 +08:00
天涯古巷
54027ceeb0 Add files via upload 2025-02-19 09:33:02 +08:00
天涯古巷
1dc9036dee Update README.md 2025-02-19 09:24:46 +08:00
天涯古巷
09dbfa47f2 Update README.md 2025-01-03 10:29:42 +08:00
天涯古巷
aa83db0f98 Update README.md 2024-12-24 11:37:19 +08:00
天涯古巷
f86c8c9fe8 Update README.md 2024-12-15 17:26:49 +08:00
天涯古巷
3dc8f3bfe0 Update main.py 2024-10-23 16:41:01 +08:00
天涯古巷
b0ca454473 Thanks to 张音乐 for the sponsorship 2024-10-13 14:22:40 +08:00
天涯古巷
9f7fd5b341 Update README_en.md 2024-10-09 18:00:21 +08:00
天涯古巷
535fdecef4 Update README_en.md 2024-10-09 17:58:14 +08:00
天涯古巷
019f7f4517 Update .condarc 2024-10-09 17:45:49 +08:00
天涯古巷
8c5ea2e19d Update README.md 2024-10-09 17:22:15 +08:00
天涯古巷
330cf54e1a Update README.md 2024-10-09 17:15:21 +08:00
天涯古巷
7019572f7b Update README.md 2024-09-30 23:59:58 +08:00
天涯古巷
ee53840adb Update README.md 2024-09-30 22:45:23 +08:00
天涯古巷
96d744b3a7 Update README.md 2024-09-30 22:42:35 +08:00
天涯古巷
32c47873ab Upgrade paddle to 2.6.1 2024-09-30 22:34:06 +08:00
天涯古巷
99770a32b9 Update README.md 2024-09-30 22:29:47 +08:00
天涯古巷
0f71d732e1 Update README.md 2024-09-29 07:18:52 +08:00
天涯古巷
f3a982710d Update README.md 2024-09-25 20:53:36 +08:00
天涯古巷
e07849ef87 Update README.md 2024-09-19 20:10:28 +08:00
天涯古巷
96099ea2d4 Merge pull request #91 from Brikarl/main
Update PySimpleGUI version to 4.70.1
2024-09-19 20:06:44 +08:00
Brikarl
f4c22dd420 Update requirements.txt 2024-09-19 19:45:50 +08:00
天涯古巷
a3452832ff Update README.md 2024-07-10 17:56:02 +08:00
天涯古巷
45e80bc9b0 Thanks to the Talkuv app for its support 2024-07-10 17:53:01 +08:00
天涯古巷
4a09342987 Update README.md 2024-07-06 17:27:20 +08:00
天涯古巷
caf4cb27f4 Update README.md 2024-06-19 08:48:14 +08:00
天涯古巷
c927476c0f Thanks to our benefactor 陈 for the sponsorship 2024-06-06 16:58:44 +08:00
天涯古巷
61aa3d8f88 Update README.md 2024-01-17 19:20:30 +08:00
天涯古巷
67fdacdd8b Update README.md 2024-01-16 22:39:33 +08:00
天涯古巷
3d21963995 Update README.md 2024-01-09 17:18:07 +08:00
YaoFANGUK
a3dd7b797d Add code comments 2024-01-09 11:05:07 +08:00
YaoFANGUK
6b353455a0 Add references 2024-01-09 09:24:19 +08:00
YaoFANGUK
d6736d9206 Add STTN training code 2024-01-08 17:48:21 +08:00
天涯古巷
4abc3409ac Update README.md 2024-01-07 07:07:58 +08:00
YaoFANGUK
2d1eb11fd6 Enlarge the field of view to ensure removal quality 2024-01-05 16:57:40 +08:00
YaoFANGUK
9a65c17a50 Update comments 2024-01-05 14:39:05 +08:00
天涯古巷
3ce8d7409b Update README.md 2024-01-04 14:43:10 +08:00
YaoFANGUK
f9dd30fddf Fix the issue that Android phones could not share generated videos 2024-01-04 14:33:33 +08:00
YaoFANGUK
fda9024084 update readme_en 2024-01-02 14:48:49 +08:00
天涯古巷
19141ff5c9 Update README.md 2024-01-02 14:37:31 +08:00
天涯古巷
97b54f6d9e Update README_en.md 2023-12-30 09:27:48 +08:00
天涯古巷
584e574795 Update README.md 2023-12-30 09:24:57 +08:00
YaoFANGUK
dad37eba7d Update readme 2023-12-29 15:52:55 +08:00
天涯古巷
063a896cb9 Update README.md 2023-12-29 15:41:00 +08:00
天涯古巷
63d8378f36 Update README.md 2023-12-29 12:57:45 +08:00
天涯古巷
4cbfa9ebf0 Update README.md 2023-12-29 11:58:15 +08:00
YaoFANGUK
8a8088be1f Allow selecting a removal area for a single image 2023-12-29 10:48:59 +08:00
YaoFANGUK
757cc5bf77 Thanks for the sponsorship 2023-12-29 10:33:46 +08:00
YaoFANGUK
e536d6af86 Allow using a subtitle area when batch-processing videos of the same resolution 2023-12-29 10:28:00 +08:00
YaoFANGUK
311701d3e6 Add GUI batch-processing support 2023-12-29 10:13:45 +08:00
YaoFANGUK
a7e62db98a Suppress file-deletion errors on Windows 2023-12-29 09:33:07 +08:00
YaoFANGUK
945aeb9bc8 Add file-type detection 2023-12-29 09:23:42 +08:00
YaoFANGUK
6ea7482344 minor 2023-12-29 08:46:36 +08:00
YaoFANGUK
ba396d9569 Process the full frame when no subtitle area is given 2023-12-29 08:45:20 +08:00
天涯古巷
22b021d9ae Update README.md 2023-12-28 20:05:32 +08:00
天涯古巷
49ae0029f5 Update README.md 2023-12-28 20:04:26 +08:00
天涯古巷
f89c109636 Update README.md 2023-12-28 19:39:09 +08:00
天涯古巷
055a08403f Update README.md 2023-12-28 15:55:45 +08:00
26 changed files with 1236 additions and 95 deletions

.condarc
@@ -2,13 +2,9 @@ channels:
   - defaults
 show_channel_urls: true
 default_channels:
-  - http://mirrors.aliyun.com/anaconda/pkgs/main
-  - http://mirrors.aliyun.com/anaconda/pkgs/r
-  - http://mirrors.aliyun.com/anaconda/pkgs/msys2
+  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
+  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r
+  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2
 custom_channels:
-  conda-forge: http://mirrors.aliyun.com/anaconda/cloud
-  msys2: http://mirrors.aliyun.com/anaconda/cloud
-  bioconda: http://mirrors.aliyun.com/anaconda/cloud
-  menpo: http://mirrors.aliyun.com/anaconda/cloud
-  pytorch: http://mirrors.aliyun.com/anaconda/cloud
-  simpleitk: http://mirrors.aliyun.com/anaconda/cloud
+  conda-forge: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
+  pytorch: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud

.gitignore (vendored, 1 line changed)
@@ -372,3 +372,4 @@ test*_no_sub*.mp4
 /backend/models/video/ProPainter.pth
 /backend/models/big-lama/big-lama.pt
 /test/debug/
+/backend/tools/train/release_model/

README.md (101 lines changed)
@@ -12,12 +12,13 @@ Video-subtitle-remover (VSR) is AI-based software that removes hardcoded subtitles from videos.
 - Fills the area of the removed subtitle text using a powerful AI inpainting model (non-adjacent pixel filling and mosaic removal)
 - Supports custom subtitle positions, removing only subtitles inside the given area (position passed in)
 - Supports automatic removal of all text in the whole video (no position passed in)
+- Supports multi-selecting images for batch removal of watermark text
 <p style="text-align:center;"><img src="https://github.com/YaoFANGUK/video-subtitle-remover/raw/main/design/demo.png" alt="demo.png"/></p>
 **Usage notes:**
-- For usage questions, join the QQ group for discussion: 806152575
+- For usage questions, join the QQ group for discussion: 806152575 (full) or 816881808
 - Download the archive, extract it, and run it directly; if it does not run, follow the tutorial below to install from source and run it in a conda environment
 **Download links:**
**下载地址:**
@@ -111,9 +112,9 @@ conda activate videoEnv
 <h5>(1) Download CUDA 11.7</h5>
 <a href="https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_516.01_windows.exe">cuda_11.7.0_516.01_windows.exe</a>
 <h5>(2) Install CUDA 11.7</h5>
-<h5>(3) Download cuDNN 8.2.4</h5>
-<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-windows-x64-v8.2.4.15.zip">cudnn-windows-x64-v8.2.4.15.zip</a></p>
-<h5>(4) Install cuDNN 8.2.4</h5>
+<h5>(3) Download cuDNN v8.4.0 (April 1st, 2022), for CUDA 11.x</h5>
+<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-windows-x86_64-8.4.0.27_cuda11.6-archive.zip">cudnn-windows-x86_64-8.4.0.27_cuda11.6-archive.zip</a></p>
+<h5>(4) Install cuDNN 8.4.0</h5>
 <p>
 Copy the files in the bin, include and lib directories of the extracted cuDNN cuda folder into the corresponding directories under C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\
 </p>
@@ -136,12 +137,12 @@ conda activate videoEnv
 - Install the GPU version of PyTorch:
 ```shell
-conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia
+conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.8 -c pytorch -c nvidia
 ```
 or use
 ```shell
-pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu117
+pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118
 ```
 - Install other dependencies:
@@ -166,29 +167,85 @@ python ./backend/main.py
 ```
 ## FAQ
-1. CondaHTTPError
+1. What to do if removal is slow
+
+   You can greatly speed up removal by changing the parameters in backend/config.py:
+   ```python
+   MODE = InpaintMode.STTN       # use the STTN algorithm
+   STTN_SKIP_DETECTION = True    # skip subtitle detection; this may miss subtitles that should be removed, or touch frames that have none
+   ```
+2. What to do if the removal quality is unsatisfactory
+
+   Change the parameters in backend/config.py and try a different algorithm. Overview of the algorithms:
+   > - The InpaintMode.STTN algorithm works well for live-action video and is fast; it can skip subtitle detection
+   > - The InpaintMode.LAMA algorithm is best for images and works well for animation; moderate speed; it cannot skip subtitle detection
+   > - The InpaintMode.PROPAINTER algorithm needs a lot of VRAM and is slower; it works better for videos with very intense motion
+
+   - Using the STTN algorithm
+   ```python
+   MODE = InpaintMode.STTN  # use the STTN algorithm
+   # number of neighboring frames; larger values increase VRAM usage and improve results
+   STTN_NEIGHBOR_STRIDE = 10
+   # length of the reference frames; larger values increase VRAM usage and improve results
+   STTN_REFERENCE_LENGTH = 10
+   # maximum number of frames the STTN algorithm processes at once; larger is slower but gives better results
+   # make sure STTN_MAX_LOAD_NUM is greater than STTN_NEIGHBOR_STRIDE and STTN_REFERENCE_LENGTH
+   STTN_MAX_LOAD_NUM = 30
+   ```
+   - Using the LAMA algorithm
+   ```python
+   MODE = InpaintMode.LAMA  # use the LAMA algorithm
+   LAMA_SUPER_FAST = False  # prioritize quality
+   ```
+   > If you are not satisfied with the model's results, see the training method in the design folder, train with the code under backend/tools/train, and replace the old model with your own.
+3. CondaHTTPError
   Place the .condarc from the project in your user directory (C:/Users/<your_username>); overwrite the file if it already exists there.
   Solution: https://zhuanlan.zhihu.com/p/260034241
-2. 7z extraction error
+4. 7z extraction error
   Solution: upgrade 7-zip to the latest version
-3. An RTX 4090 does not run with CUDA 11.7
+5. An RTX 4090 does not run with CUDA 11.7
   Solution: switch to CUDA 11.8
+   ```shell
+   pip install torch==2.1.0 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118
+   ```
-## Sponsorship
-<img src="https://i.imgur.com/EMCP5Lv.jpeg" width="600">
-
-| Donor | Total donated | Sponsor tier |
-| --- | --- | --- |
-| 坤V | 400.00 RMB | Gold sponsor seat |
-| 陈凯 | 50.00 RMB | Silver sponsor seat |
-| Tshuang | 20.00 RMB | Silver sponsor seat |
-| 很奇异 | 15.00 RMB | Silver sponsor seat |
-| 何斐 | 10.00 RMB | Bronze sponsor seat |
-| 长缨在手 | 6.00 RMB | Bronze sponsor seat |
-| Leo | 1.00 RMB | Bronze sponsor seat |
+## Sponsorship
+<img src="https://github.com/YaoFANGUK/video-subtitle-extractor/raw/main/design/sponsor.png" width="600">
+
+| Donor | Total donated | Sponsor tier |
+|---------------------------|------------| --- |
+| 坤V | 400.00 RMB | Gold sponsor seat |
+| Jenkit | 200.00 RMB | Gold sponsor seat |
+| 子车松兰 | 188.00 RMB | Gold sponsor seat |
+| 落花未逝 | 100.00 RMB | Gold sponsor seat |
+| 张音乐 | 100.00 RMB | Gold sponsor seat |
+| 麦格 | 100.00 RMB | Gold sponsor seat |
+| 无痕 | 100.00 RMB | Gold sponsor seat |
+| wr | 100.00 RMB | Gold sponsor seat |
+| 陈 | 100.00 RMB | Gold sponsor seat |
+| TalkLuv | 50.00 RMB | Silver sponsor seat |
+| 陈凯 | 50.00 RMB | Silver sponsor seat |
+| Tshuang | 20.00 RMB | Silver sponsor seat |
+| 很奇异 | 15.00 RMB | Silver sponsor seat |
+| 郭鑫 | 12.00 RMB | Silver sponsor seat |
+| 生活不止眼前的苟且 | 10.00 RMB | Bronze sponsor seat |
+| 何斐 | 10.00 RMB | Bronze sponsor seat |
+| 老猫 | 8.80 RMB | Bronze sponsor seat |
+| 伍六七 | 7.77 RMB | Bronze sponsor seat |
+| 长缨在手 | 6.00 RMB | Bronze sponsor seat |
+| 无忌 | 6.00 RMB | Bronze sponsor seat |
+| Stephen | 2.00 RMB | Bronze sponsor seat |
+| Leo | 1.00 RMB | Bronze sponsor seat |

README_en.md
@@ -12,6 +12,7 @@ Video-subtitle-remover (VSR) is an AI-based software that removes hardcoded subtitles from videos.
 - Fills in the removed subtitle text area using a powerful AI algorithm model (non-adjacent pixel filling and mosaic removal).
 - Supports custom subtitle positions by only removing subtitles in the defined location (input position).
 - Supports automatic removal of all text throughout the entire video (without inputting a position).
+- Supports multi-selection of images for batch removal of watermark text.
 <p style="text-align:center;"><img src="https://github.com/YaoFANGUK/video-subtitle-remover/raw/main/design/demo.png" alt="demo.png"/></p>
@@ -110,11 +111,11 @@ Please make sure you have already installed Python 3.8+, use conda to create a project virtual environment and activate it.
 <h5>(1) Download CUDA 11.7</h5>
 <a href="https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_516.01_windows.exe">cuda_11.7.0_516.01_windows.exe</a>
 <h5>(2) Install CUDA 11.7</h5>
-<h5>(3) Download cuDNN 8.2.4</h5>
-<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-windows-x64-v8.2.4.15.zip">cudnn-windows-x64-v8.2.4.15.zip</a></p>
-<h5>(4) Install cuDNN 8.2.4</h5>
+<h5>(3) Download cuDNN 8.4.0</h5>
+<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-windows-x86_64-8.4.0.27_cuda11.6-archive.zip">cudnn-windows-x86_64-8.4.0.27_cuda11.6-archive.zip</a></p>
+<h5>(4) Install cuDNN 8.4.0</h5>
 <p>
-unzip "cudnn-windows-x64-v8.2.4.15.zip", then move all files in "bin, include, lib" in cuda
+unzip "cudnn-windows-x86_64-8.4.0.27_cuda11.6-archive.zip", then move all files in "bin, include, lib" in cuda
 directory to C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\
 </p>
 </details>
@@ -136,12 +137,12 @@ Please make sure you have already installed Python 3.8+, use conda to create a project virtual environment and activate it.
 - Install GPU version of Pytorch:
 ```shell
-conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia
+conda install pytorch==2.1.0 torchvision==0.16.0 pytorch-cuda=11.8 -c pytorch -c nvidia
 ```
 or use
 ```shell
-pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu117
+pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118
 ```
 - Install other dependencies:
@@ -166,12 +167,54 @@ python ./backend/main.py
 ```
 ## Common Issues
-1. CondaHTTPError
+1. How to deal with slow removal speed
+
+   You can greatly increase the removal speed by modifying the parameters in backend/config.py:
+   ```python
+   MODE = InpaintMode.STTN       # Set to STTN algorithm
+   STTN_SKIP_DETECTION = True    # Skip subtitle detection
+   ```
+2. What to do if the video removal results are not satisfactory
+
+   Modify the values in backend/config.py and try different removal algorithms. Here is an introduction to the algorithms:
+   > - **InpaintMode.STTN** algorithm: Good for live-action videos and fast in speed, capable of skipping subtitle detection
+   > - **InpaintMode.LAMA** algorithm: Best for images and effective for animated videos, moderate speed, unable to skip subtitle detection
+   > - **InpaintMode.PROPAINTER** algorithm: Consumes a significant amount of VRAM, slower in speed, works better for videos with very intense movement
+
+   - Using the STTN algorithm
+   ```python
+   MODE = InpaintMode.STTN  # Set to STTN algorithm
+   # Number of neighboring frames, increasing this will increase memory usage and improve the result
+   STTN_NEIGHBOR_STRIDE = 10
+   # Length of reference frames, increasing this will increase memory usage and improve the result
+   STTN_REFERENCE_LENGTH = 10
+   # Set the maximum number of frames processed simultaneously by the STTN algorithm; a larger value leads to slower processing but better results
+   # Ensure that STTN_MAX_LOAD_NUM is greater than STTN_NEIGHBOR_STRIDE and STTN_REFERENCE_LENGTH
+   STTN_MAX_LOAD_NUM = 30
+   ```
+   - Using the LAMA algorithm
+   ```python
+   MODE = InpaintMode.LAMA  # Set to LAMA algorithm
+   LAMA_SUPER_FAST = False  # Ensure quality
+   ```
+3. CondaHTTPError
   Place the .condarc file from the project in the user directory (C:/Users/<your_username>). If the file already exists in the user directory, overwrite it.
   Solution: https://zhuanlan.zhihu.com/p/260034241
-2. 7z file extraction error
+4. 7z file extraction error
   Solution: Upgrade the 7-zip extraction program to the latest version.
+   ```shell
+   pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118
+   ```

backend/config.py
@@ -64,12 +64,17 @@ class InpaintMode(Enum):
 # ×××××××××××××××××××× [editable] start ××××××××××××××××××××
 # Whether to use H.264 encoding; enable this if you need to share the generated videos from an Android phone
 USE_H264 = True
+# ×××××××××× general settings start ××××××××××
 """
 MODE, the available algorithm types:
 - InpaintMode.STTN: works well for live-action video, fast, can skip subtitle detection
 - InpaintMode.LAMA: works well for animation, moderate speed, cannot skip subtitle detection
 - InpaintMode.PROPAINTER: needs a lot of VRAM, slower, works better for videos with very intense motion
 """
+# [Set the inpaint algorithm]
+# - InpaintMode.STTN: works well for live-action video, fast, can skip subtitle detection
+# - InpaintMode.LAMA: works well for animation, moderate speed, cannot skip subtitle detection
+# - InpaintMode.PROPAINTER: needs a lot of VRAM, slower, works better for videos with very intense motion
 MODE = InpaintMode.STTN
 # [Set pixel tolerance]
 # Used to judge non-subtitle regions (a subtitle text box is normally wider than it is tall; if a box is taller than it is wide by more than the given pixel amount, it is treated as a false detection)
@@ -85,17 +90,33 @@ PIXEL_TOLERANCE_X = 20  # allowed horizontal pixel deviation of the detection box
 # ×××××××××× InpaintMode.STTN settings start ××××××××××
 # The following parameters take effect only for the STTN algorithm
-# Whether to skip detection; skipping subtitle detection saves a lot of time but may touch video frames that have no subtitles
+"""
+1. STTN_SKIP_DETECTION
+Meaning: whether to skip subtitle detection
+Effect: if True, skipping detection saves a lot of time, but it may touch frames without subtitles or miss subtitles that should be removed
+2. STTN_NEIGHBOR_STRIDE
+Meaning: stride between neighboring reference frames. If the missing area of frame 50 needs filling and STTN_NEIGHBOR_STRIDE=5, the algorithm uses frame 45, frame 40, and so on as references.
+Effect: controls how densely reference frames are picked; a larger stride means fewer, more spread-out reference frames, a smaller stride means more, more concentrated ones.
+3. STTN_REFERENCE_LENGTH
+Meaning: number of reference frames; the STTN algorithm looks at several frames before and after each frame to be repaired to gather context for inpainting
+Effect: larger values increase VRAM usage and improve results, but slow down processing
+4. STTN_MAX_LOAD_NUM
+Meaning: maximum number of video frames STTN loads at once
+Effect: the larger the value, the slower the processing and the better the result
+Note: keep STTN_MAX_LOAD_NUM greater than STTN_NEIGHBOR_STRIDE and STTN_REFERENCE_LENGTH
+"""
 STTN_SKIP_DETECTION = True
-# number of neighboring frames; larger values increase VRAM usage and improve results
-STTN_NEIGHBOR_STRIDE = 10
-# reference frame length; larger values increase VRAM usage and improve results
+# reference frame stride
+STTN_NEIGHBOR_STRIDE = 5
+# reference frame length (count)
 STTN_REFERENCE_LENGTH = 10
-# maximum number of frames the STTN algorithm processes at once; larger is slower but gives better results
-# keep STTN_MAX_LOAD_NUM greater than STTN_NEIGHBOR_STRIDE and STTN_REFERENCE_LENGTH
-STTN_MAX_LOAD_NUM = 30
-if STTN_MAX_LOAD_NUM < max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH):
-    STTN_MAX_LOAD_NUM = max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH)
+# maximum number of frames the STTN algorithm processes at the same time
+STTN_MAX_LOAD_NUM = 50
+if STTN_MAX_LOAD_NUM < STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE:
+    STTN_MAX_LOAD_NUM = STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE
 # ×××××××××× InpaintMode.STTN settings end ××××××××××
 # ×××××××××× InpaintMode.PROPAINTER settings start ××××××××××
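To make these STTN settings concrete, here is a minimal sketch of how an STTN-style inpainter picks frames for one pass. This is an illustration written for this page, not the project's actual implementation; the function name and the use of STTN_REFERENCE_LENGTH as a sampling step are assumptions.

```python
# Illustration only: STTN-style frame selection for inpainting frame 50 of 120.
def pick_frames(frame_no, total, neighbor_stride=5, reference_length=10):
    # local neighbors: frames within neighbor_stride of the target frame
    neighbors = list(range(max(0, frame_no - neighbor_stride),
                           min(total, frame_no + neighbor_stride + 1)))
    # global references: every reference_length-th frame outside the neighborhood
    references = [i for i in range(0, total, reference_length) if i not in neighbors]
    return neighbors, references

neighbors, references = pick_frames(50, 120)
print(neighbors)   # [45, 46, ..., 55]
print(references)  # [0, 10, 20, 30, 40, 60, 70, 80, 90, 100, 110]
```

The guard above (keeping STTN_MAX_LOAD_NUM at least STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE) then bounds how many of these frames are held in memory at once.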

backend/inpaint/sttn_inpaint.py
@@ -93,6 +93,7 @@ class STTNInpaint:
     @staticmethod
     def read_mask(path):
         img = cv2.imread(path, 0)
+        # convert to a binary mask
         ret, img = cv2.threshold(img, 127, 1, cv2.THRESH_BINARY)
         img = img[:, :, None]
         return img
@@ -200,6 +201,24 @@ class STTNInpaint:
             to_H -= h
         return inpaint_area  # return the list of areas to inpaint
 
+    @staticmethod
+    def get_inpaint_area_by_selection(input_sub_area, mask):
+        print('use selection area for inpainting')
+        height, width = mask.shape[:2]
+        ymin, ymax, _, _ = input_sub_area
+        interval_size = 135
+        # list that collects the results
+        inpaint_area = []
+        # compute and store the regular intervals
+        for i in range(ymin, ymax, interval_size):
+            inpaint_area.append((i, i + interval_size))
+        # check whether the last interval reached the maximum value
+        if inpaint_area[-1][1] != ymax:
+            # if not, add a new interval starting where the last one ends and ending at the extended value
+            if inpaint_area[-1][1] + interval_size <= height:
+                inpaint_area.append((inpaint_area[-1][1], inpaint_area[-1][1] + interval_size))
+        return inpaint_area  # return the list of areas to inpaint
class STTNVideoInpaint:
@@ -291,10 +310,10 @@
                     mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :]
                     frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
                 writer.write(frame)
-                if input_sub_remover is not None and input_sub_remover.gui_mode:
+                if input_sub_remover is not None:
                     if tbar is not None:
                         input_sub_remover.update_progress(tbar, increment=1)
-                    if original_frame is not None:
+                    if original_frame is not None and input_sub_remover.gui_mode:
                         input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame])
         # release the video writer
         writer.release()
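For reference, the banding done by get_inpaint_area_by_selection above can be exercised on its own; the following standalone re-implementation shows how a selected subtitle area is cut into fixed 135-pixel bands (numbers are hypothetical):

```python
# Standalone copy of the banding logic in get_inpaint_area_by_selection.
def split_into_bands(ymin, ymax, height, interval_size=135):
    bands = [(i, i + interval_size) for i in range(ymin, ymax, interval_size)]
    # add one more band if the selection did not end exactly on a band boundary
    if bands and bands[-1][1] != ymax and bands[-1][1] + interval_size <= height:
        bands.append((bands[-1][1], bands[-1][1] + interval_size))
    return bands

# a 1080p frame with a subtitle selection from y=800 to y=1000:
print(split_into_bands(800, 1000, 1080))
# [(800, 935), (935, 1070)] - bands step from ymin in 135 px strides and may overshoot ymax
```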

backend/main.py
@@ -1,3 +1,4 @@
+import torch
 import shutil
 import subprocess
 import os
@@ -5,9 +6,11 @@ from pathlib import Path
 import threading
 import cv2
 import sys
 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 import config
+from backend.tools.common_tools import is_video_or_image, is_image_file
 from backend.scenedetect import scene_detect
 from backend.scenedetect.detectors import ContentDetector
 from backend.inpaint.sttn_inpaint import STTNInpaint, STTNVideoInpaint
@@ -17,7 +20,6 @@ from backend.tools.inpaint_tools import create_mask, batch_generator
 import importlib
 import platform
 import tempfile
-import torch
 import multiprocessing
 from shapely.geometry import Polygon
 import time
@@ -257,7 +259,43 @@
         return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]])
 
     @staticmethod
-    def expand_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
+    def expand_and_merge_intervals(intervals, expand_size=config.STTN_NEIGHBOR_STRIDE*config.STTN_REFERENCE_LENGTH, max_length=config.STTN_MAX_LOAD_NUM):
+        # initialize the output interval list
+        expanded_intervals = []
+        # expand every original interval
+        for interval in intervals:
+            start, end = interval
+            # expand to at least 'expand_size' units, but no more than 'max_length' units
+            expansion_amount = max(expand_size - (end - start + 1), 0)
+            # split the expansion as evenly as possible before and after, while still covering the original interval
+            expand_start = max(start - expansion_amount // 2, 1)  # make sure the start is not less than 1
+            expand_end = end + expansion_amount // 2
+            # if the expanded interval exceeds the maximum length, adjust it
+            if (expand_end - expand_start + 1) > max_length:
+                expand_end = expand_start + max_length - 1
+            # for single points, additionally guarantee a length of at least 'expand_size'
+            if start == end:
+                if expand_end - expand_start + 1 < expand_size:
+                    expand_end = expand_start + expand_size - 1
+            # check for overlap with the previous interval and merge accordingly
+            if expanded_intervals and expand_start <= expanded_intervals[-1][1]:
+                previous_start, previous_end = expanded_intervals.pop()
+                expand_start = previous_start
+                expand_end = max(expand_end, previous_end)
+            # append the expanded interval to the result list
+            expanded_intervals.append((expand_start, expand_end))
+        return expanded_intervals
+
+    @staticmethod
+    def filter_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
         """
         Merge the given subtitle start/end intervals, making sure each interval is at least STTN_REFERENCE_LENGTH long
         """
@@ -290,12 +328,10 @@
         for start, end in expanded[1:]:
             last_start, last_end = merged[-1]
             # check for overlap
-            if start <= last_end and (
-                    end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
+            if start <= last_end and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
                 # needs merging
                 merged[-1] = (last_start, max(last_end, end))  # merge the intervals
-            elif start == last_end + 1 and (
-                    end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
+            elif start == last_end + 1 and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
                 # adjacent intervals that also need merging
                 merged[-1] = (last_start, end)
             else:
@@ -483,7 +519,7 @@
         self.gui_mode = gui_mode
         # check whether the input is an image
         self.is_picture = False
-        if str(vd_path).endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
+        if is_image_file(str(vd_path)):
             self.sub_area = None
             self.is_picture = True
         # video path
@@ -505,8 +541,7 @@
         # create a temporary video file (delete=True raises permission-denied errors on Windows)
         self.video_temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
         # create the video writer
-        self.video_writer = cv2.VideoWriter(self.video_temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps,
-                                            self.size)
+        self.video_writer = cv2.VideoWriter(self.video_temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size)
         self.video_out_name = os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4')
         self.video_inpaint = None
         self.lama_inpaint = None
@@ -647,16 +682,14 @@
                         self.lama_inpaint = LamaInpaint()
                     inpainted_frame = self.lama_inpaint(frame, single_mask)
                     self.video_writer.write(inpainted_frame)
-                    print(
-                        f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
+                    print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
                     inner_index += 1
                     self.update_progress(tbar, increment=1)
             elif len(batch) > 1:
                 inpainted_frames = self.video_inpaint.inpaint(batch, mask)
                 for i, inpainted_frame in enumerate(inpainted_frames):
                     self.video_writer.write(inpainted_frame)
-                    print(
-                        f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
+                    print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
                     inner_index += 1
                     if self.gui_mode:
                         self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
@@ -670,12 +703,13 @@
         print('[Processing] start removing subtitles...')
         if self.sub_area is not None:
             ymin, ymax, xmin, xmax = self.sub_area
-            mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
-            mask = create_mask(self.mask_size, mask_area_coordinates)
-            sttn_video_inpaint = STTNVideoInpaint(self.video_path)
-            sttn_video_inpaint(input_mask=mask, input_sub_remover=self, tbar=tbar)
         else:
-            print('please set subtitle area first')
+            print('[Info] No subtitle area has been set. Video will be processed in full screen. As a result, the final outcome might be suboptimal.')
+            ymin, ymax, xmin, xmax = 0, self.frame_height, 0, self.frame_width
+        mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
+        mask = create_mask(self.mask_size, mask_area_coordinates)
+        sttn_video_inpaint = STTNVideoInpaint(self.video_path)
+        sttn_video_inpaint(input_mask=mask, input_sub_remover=self, tbar=tbar)
     def sttn_mode(self, tbar):
         # whether to skip searching for subtitle frames
@@ -688,7 +722,7 @@
             sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
             continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
             print(continuous_frame_no_list)
-            continuous_frame_no_list = self.sub_detector.expand_and_merge_intervals(continuous_frame_no_list)
+            continuous_frame_no_list = self.sub_detector.filter_and_merge_intervals(continuous_frame_no_list)
             print(continuous_frame_no_list)
             start_end_map = dict()
             for interval in continuous_frame_no_list:
@@ -847,7 +881,7 @@
                 audio_merge_command = [config.FFMPEG_PATH,
                                        "-y", "-i", self.video_temp_file.name,
                                        "-i", temp.name,
-                                       "-vcodec", "copy",
+                                       "-vcodec", "libx264" if config.USE_H264 else "copy",
                                        "-acodec", "copy",
                                        "-loglevel", "error", self.video_out_name]
                 try:
@@ -859,7 +893,10 @@
                 try:
                     os.remove(temp.name)
                 except Exception:
-                    print(f'failed to delete temp file {temp.name}')
+                    if platform.system() in ['Windows']:
+                        pass
+                    else:
+                        print(f'failed to delete temp file {temp.name}')
                 self.is_successful_merged = True
             finally:
                 temp.close()
@@ -874,9 +911,13 @@
 if __name__ == '__main__':
     multiprocessing.set_start_method("spawn")
     # 1. prompt the user for the video path
-    video_path = input(f"Please input video file path: ").strip()
+    video_path = input(f"Please input video or image file path: ").strip()
+    # if the video path is a directory, batch-process all video files in that directory
     # 2. pass in the subtitle area in the following order
     #    sub_area = (ymin, ymax, xmin, xmax)
     # 3. create the subtitle remover object
-    sd = SubtitleRemover(video_path, sub_area=None)
-    sd.run()
+    if is_video_or_image(video_path):
+        sd = SubtitleRemover(video_path, sub_area=None)
+        sd.run()
+    else:
+        print(f'Invalid video path: {video_path}')
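The __main__ block above also implies the programmatic entry point; a minimal sketch (paths and coordinates are placeholders):

```python
# Remove subtitles from a fixed region without the interactive prompt.
# sub_area uses the (ymin, ymax, xmin, xmax) order documented above.
import multiprocessing
from backend.main import SubtitleRemover

if __name__ == '__main__':
    multiprocessing.set_start_method("spawn")  # matches the script's own setup
    remover = SubtitleRemover('/path/to/video.mp4', sub_area=(870, 1050, 0, 1920))
    remover.run()  # writes /path/to/video_no_sub.mp4 next to the input
```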

backend/tools/common_tools.py
@@ -0,0 +1,32 @@
import os

video_extensions = {
    '.mp4', '.m4a', '.m4v', '.f4v', '.f4a', '.m4b', '.m4r', '.f4b', '.mov',
    '.3gp', '.3gp2', '.3g2', '.3gpp', '.3gpp2', '.ogg', '.oga', '.ogv', '.ogx',
    '.wmv', '.wma', '.asf', '.webm', '.flv', '.avi', '.gifv', '.mkv', '.rm',
    '.rmvb', '.vob', '.dvd', '.mpg', '.mpeg', '.mp2', '.mpe', '.mpv', '.mpg',
    '.mpeg', '.m2v', '.svi', '.3gp', '.mxf', '.roq', '.nsv', '.flv', '.f4v',
    '.f4p', '.f4a', '.f4b'
}

image_extensions = {
    '.jpg', '.jpeg', '.jpe', '.jif', '.jfif', '.jfi', '.png', '.gif',
    '.webp', '.tiff', '.tif', '.psd', '.raw', '.arw', '.cr2', '.nrw',
    '.k25', '.bmp', '.dib', '.heif', '.heic', '.ind', '.indd', '.indt',
    '.jp2', '.j2k', '.jpf', '.jpx', '.jpm', '.mj2', '.svg', '.svgz',
    '.ai', '.eps', '.ico'
}

def is_video_file(filename):
    return os.path.splitext(filename)[-1].lower() in video_extensions

def is_image_file(filename):
    return os.path.splitext(filename)[-1].lower() in image_extensions

def is_video_or_image(filename):
    file_extension = os.path.splitext(filename)[-1].lower()
    # check whether the extension is in the defined video or image extension sets
    return file_extension in video_extensions or file_extension in image_extensions
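A small usage check of these helpers (file names are hypothetical):

```python
from backend.tools.common_tools import is_video_file, is_image_file, is_video_or_image

print(is_video_file('movie.MKV'))      # True - the extension is lower-cased first
print(is_image_file('frame.png'))      # True
print(is_video_or_image('notes.txt'))  # False
```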

@@ -22,8 +22,8 @@ def merge_video(video_input_path0, video_input_path1, video_output_path):
 if __name__ == '__main__':
-    v0_path = '../../test/test_2_low.mp4'
-    v1_path = '../../test/test_2_low_no_sub.mp4'
+    v0_path = '../../test/test4.mp4'
+    v1_path = '../../test/test4_no_sub(1).mp4'
     video_out_path = '../../test/demo.mp4'
     merge_video(v0_path, v1_path, video_out_path)
     # ffmpeg command: convert mp4 to gif

backend/tools/train/configs_sttn/davis.json
@@ -0,0 +1,33 @@
{
    "seed": 2020,
    "save_dir": "release_model/",
    "data_loader": {
        "name": "davis",
        "data_root": "datasets/",
        "w": 640,
        "h": 120,
        "sample_length": 5
    },
    "losses": {
        "hole_weight": 1,
        "valid_weight": 1,
        "adversarial_weight": 0.01,
        "GAN_LOSS": "hinge"
    },
    "trainer": {
        "type": "Adam",
        "beta1": 0,
        "beta2": 0.99,
        "lr": 1e-4,
        "d2glr": 1,
        "batch_size": 8,
        "num_workers": 2,
        "verbosity": 2,
        "log_step": 100,
        "save_freq": 1e4,
        "valid_freq": 1e4,
        "iterations": 50e4,
        "niter": 30e4,
        "niter_steady": 30e4
    }
}

backend/tools/train/configs_sttn/youtube-vos.json
@@ -0,0 +1,33 @@
{
    "seed": 2020,
    "save_dir": "release_model/",
    "data_loader": {
        "name": "youtube-vos",
        "data_root": "datasets_sttn/",
        "w": 640,
        "h": 120,
        "sample_length": 5
    },
    "losses": {
        "hole_weight": 1,
        "valid_weight": 1,
        "adversarial_weight": 0.01,
        "GAN_LOSS": "hinge"
    },
    "trainer": {
        "type": "Adam",
        "beta1": 0,
        "beta2": 0.99,
        "lr": 1e-4,
        "d2glr": 1,
        "batch_size": 8,
        "num_workers": 2,
        "verbosity": 2,
        "log_step": 100,
        "save_freq": 1e4,
        "valid_freq": 1e4,
        "iterations": 50e4,
        "niter": 15e4,
        "niter_steady": 30e4
    }
}

backend/tools/train/dataset_sttn.py
@@ -0,0 +1,85 @@
import os
import json
import random
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_random_motion
from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip


# custom dataset
class Dataset(torch.utils.data.Dataset):
    def __init__(self, args: dict, split='train', debug=False):
        # initializer; takes the config dict and the dataset split (default 'train')
        self.args = args
        self.split = split
        self.sample_length = args['sample_length']  # sample length parameter
        self.size = self.w, self.h = (args['w'], args['h'])  # target image width and height

        # open the json file holding the dataset metadata
        with open(os.path.join(args['data_root'], args['name'], split+'.json'), 'r') as f:
            self.video_dict = json.load(f)  # load the json content
        self.video_names = list(self.video_dict.keys())  # list of video names
        if debug or split != 'train':  # in debug mode or for non-training splits, keep only the first 100 videos
            self.video_names = self.video_names[:100]

        # transforms that stack the frames into tensors
        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(),  # tensor format convenient for PyTorch
        ])

    def __len__(self):
        # number of videos in the dataset
        return len(self.video_names)

    def __getitem__(self, index):
        # fetch one sample
        try:
            item = self.load_item(index)  # try to load the item at the given index
        except:
            print('Loading error in video {}'.format(self.video_names[index]))  # report the failure
            item = self.load_item(0)  # fall back to the first item
        return item

    def load_item(self, index):
        # actual item loading
        video_name = self.video_names[index]  # video name for this index
        # build the frame file names for all frames of the video
        all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])]
        # generate randomly shaped masks with random motion
        all_masks = create_random_shape_with_random_motion(
            len(all_frames), imageHeight=self.h, imageWidth=self.w)
        # pick the reference frame indices
        ref_index = get_ref_index(len(all_frames), self.sample_length)
        # read the video frames
        frames = []
        masks = []
        for idx in ref_index:
            # read the image, convert to RGB, resize, and append
            img = ZipReader.imread('{}/{}/JPEGImages/{}.zip'.format(
                self.args['data_root'], self.args['name'], video_name), all_frames[idx]).convert('RGB')
            img = img.resize(self.size)
            frames.append(img)
            masks.append(all_masks[idx])
        if self.split == 'train':
            # random horizontal flip for the training split
            frames = GroupRandomHorizontalFlip()(frames)
        # convert to tensors
        frame_tensors = self._to_tensors(frames)*2.0 - 1.0  # normalize
        mask_tensors = self._to_tensors(masks)  # convert the masks to tensors
        return frame_tensors, mask_tensors  # return frame and mask tensors


def get_ref_index(length, sample_length):
    # pick reference frame indices
    if random.uniform(0, 1) > 0.5:
        # with 50% probability, sample random frames
        ref_index = random.sample(range(length), sample_length)
        ref_index.sort()  # sort to keep temporal order
    else:
        # otherwise pick consecutive frames
        pivot = random.randint(0, length-sample_length)
        ref_index = [pivot+i for i in range(sample_length)]
    return ref_index
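The two sampling branches of get_ref_index can be seen directly; a seeded sketch:

```python
# Demonstration of get_ref_index: roughly half the calls return random sorted
# indices, the other half return a consecutive run.
import random
random.seed(0)
for _ in range(3):
    print(get_ref_index(length=50, sample_length=5))
# each line is either 5 sorted random indices from [0, 50)
# or 5 consecutive indices starting at a random pivot
```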

datasets/davis/train.json
@@ -0,0 +1 @@
{"bear": 82, "bike-packing": 69, "blackswan": 50, "bmx-bumps": 90, "bmx-trees": 80, "boat": 75, "boxing-fisheye": 87, "breakdance": 84, "breakdance-flare": 71, "bus": 80, "camel": 90, "car-roundabout": 75, "car-shadow": 40, "car-turn": 80, "cat-girl": 89, "classic-car": 63, "color-run": 84, "cows": 104, "crossing": 52, "dance-jump": 60, "dance-twirl": 90, "dancing": 62, "disc-jockey": 76, "dog": 60, "dog-agility": 25, "dog-gooses": 86, "dogs-jump": 66, "dogs-scale": 83, "drift-chicane": 52, "drift-straight": 50, "drift-turn": 64, "drone": 91, "elephant": 80, "flamingo": 80, "goat": 90, "gold-fish": 78, "hike": 80, "hockey": 75, "horsejump-high": 50, "horsejump-low": 60, "india": 81, "judo": 34, "kid-football": 68, "kite-surf": 50, "kite-walk": 80, "koala": 100, "lab-coat": 47, "lady-running": 65, "libby": 49, "lindy-hop": 73, "loading": 50, "longboard": 52, "lucia": 70, "mallard-fly": 70, "mallard-water": 80, "mbike-trick": 79, "miami-surf": 70, "motocross-bumps": 60, "motocross-jump": 40, "motorbike": 43, "night-race": 46, "paragliding": 70, "paragliding-launch": 80, "parkour": 100, "pigs": 79, "planes-water": 38, "rallye": 50, "rhino": 90, "rollerblade": 35, "schoolgirls": 80, "scooter-black": 43, "scooter-board": 91, "scooter-gray": 75, "sheep": 68, "shooting": 40, "skate-park": 80, "snowboard": 66, "soapbox": 99, "soccerball": 48, "stroller": 91, "stunt": 71, "surf": 55, "swing": 60, "tennis": 70, "tractor-sand": 76, "train": 80, "tuk-tuk": 59, "upside-down": 65, "varanus-cage": 67, "walking": 72}

datasets_sttn/youtube-vos/train.json
@@ -0,0 +1 @@
{"baseball": 90, "basketball-game": 77, "bears-ball": 78, "bmx-rider": 85, "butterfly": 80, "car-competition": 66, "cat": 52, "chairlift": 99, "circus": 73, "city-ride": 70, "crafting": 45, "curling": 76, "dog-competition": 85, "dolphins-show": 74, "dribbling": 49, "drone-flying": 70, "ducks": 75, "elephant-hyenas": 55, "giraffes": 88, "gym-ball": 69, "helicopter-landing": 77, "horse-race": 80, "horses-kids": 78, "hurdles-race": 55, "ice-hockey": 52, "jet-ski": 83, "juggling-selfie": 78, "kayak-race": 63, "kids-robot": 75, "landing": 35, "luggage": 83, "mantaray": 73, "marbles": 70, "mascot": 78, "mermaid": 78, "monster-trucks": 99, "motorbike-indoors": 79, "motorbike-race": 88, "music-band": 87, "obstacles": 81, "obstacles-race": 48, "peacock": 75, "plane-exhibition": 73, "puppet": 100, "robot-battle": 85, "robotic-arm": 82, "rodeo": 85, "sea-turtle": 90, "skydiving-jumping": 75, "snowboard-race": 75, "snowboard-sand": 55, "surfer": 80, "swimmer": 86, "table-tennis": 70, "tram": 84, "trucks-race": 78, "twist-dance": 83, "volleyball-beach": 73, "water-slide": 88, "weightlifting": 90}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

backend/tools/train/loss_sttn.py
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn


class AdversarialLoss(nn.Module):
    """
    Adversarial loss
    Implemented following https://arxiv.org/abs/1711.10337
    """

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        """
        The available loss types are 'nsgan' | 'lsgan' | 'hinge'
        type: which kind of GAN loss to use.
        target_real_label: target label value for real images.
        target_fake_label: target label value for generated images.
        """
        super(AdversarialLoss, self).__init__()
        self.type = type  # loss type
        # register the labels as buffers so they are saved and loaded with the model
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))

        # initialize the criterion according to the chosen type
        if type == 'nsgan':
            self.criterion = nn.BCELoss()  # binary cross-entropy (non-saturating GAN)
        elif type == 'lsgan':
            self.criterion = nn.MSELoss()  # mean squared error (least-squares GAN)
        elif type == 'hinge':
            self.criterion = nn.ReLU()  # ReLU used for the hinge loss

    def __call__(self, outputs, is_real, is_disc=None):
        """
        Compute the loss.
        outputs: network outputs.
        is_real: True for real samples, False for generated samples.
        is_disc: whether the discriminator is currently being optimized.
        """
        if self.type == 'hinge':
            # hinge loss
            if is_disc:
                # discriminator side
                if is_real:
                    outputs = -outputs  # flip the sign for real samples
                # max(0, 1 - (real/fake) outputs)
                return self.criterion(1 + outputs).mean()
            else:
                # generator side: -min(0, -outputs) = max(0, outputs)
                return (-outputs).mean()
        else:
            # nsgan and lsgan losses
            labels = (self.real_label if is_real else self.fake_label).expand_as(
                outputs)
            # loss between model outputs and target labels
            loss = self.criterion(outputs, labels)
            return loss
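A minimal usage sketch of AdversarialLoss with the hinge variant, mirroring the call pattern the trainer below uses (random tensors stand in for real discriminator outputs):

```python
import torch

adv = AdversarialLoss(type='hinge')
real_scores = torch.randn(8, 1)  # stand-in for netD(real_frames)
fake_scores = torch.randn(8, 1)  # stand-in for netD(composited_frames)

# discriminator update: push real scores above +1 and fake scores below -1
d_loss = (adv(real_scores, True, True) + adv(fake_scores, False, True)) / 2

# generator update: raise the discriminator's score on generated samples
g_loss = adv(fake_scores, True, False)
```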

backend/tools/train/train_sttn.py
@@ -0,0 +1,96 @@
import os
import json
import argparse
from shutil import copyfile
import torch
import torch.multiprocessing as mp
from backend.tools.train.trainer_sttn import Trainer
from backend.tools.train.utils_sttn import (
    get_world_size,
    get_local_rank,
    get_global_rank,
    get_master_ip,
)

parser = argparse.ArgumentParser(description='STTN')
parser.add_argument('-c', '--config', default='configs_sttn/youtube-vos.json', type=str)
parser.add_argument('-m', '--model', default='sttn', type=str)
parser.add_argument('-p', '--port', default='23455', type=str)
parser.add_argument('-e', '--exam', action='store_true')
args = parser.parse_args()


def main_worker(rank, config):
    # if the config has no local_rank yet, set local_rank and global_rank to the given rank
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank
    # when distributed training is requested
    if config['distributed']:
        # bind this process to the GPU matching its local rank
        torch.cuda.set_device(int(config['local_rank']))
        # initialize the distributed process group with the nccl backend
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=config['init_method'],
            world_size=config['world_size'],
            rank=config['global_rank'],
            group_name='mtorch'
        )
        # report which GPU is used, with global and local rank
        print('using GPU {}-{} for training'.format(
            int(config['global_rank']), int(config['local_rank']))
        )

    # build the model save directory from the model name and config file name
    config['save_dir'] = os.path.join(
        config['save_dir'], '{}_{}'.format(config['model'], os.path.basename(args.config).split('.')[0])
    )

    # use the matching CUDA device if available, otherwise the CPU
    if torch.cuda.is_available():
        config['device'] = torch.device("cuda:{}".format(config['local_rank']))
    else:
        config['device'] = 'cpu'

    # if not distributed, or if this is the main node (rank 0)
    if (not config['distributed']) or config['global_rank'] == 0:
        # create the save directory (ignored if it already exists, exist_ok=True)
        os.makedirs(config['save_dir'], exist_ok=True)
        # path where the config file is stored
        config_path = os.path.join(
            config['save_dir'], config['config'].split('/')[-1]
        )
        # copy the config file over if it is not there yet
        if not os.path.isfile(config_path):
            copyfile(config['config'], config_path)
        # report the created directory
        print('[**] create folder {}'.format(config['save_dir']))

    # create the trainer with the config and the debug flag
    trainer = Trainer(config, debug=args.exam)
    # start training
    trainer.train()


if __name__ == "__main__":
    # load the config file
    config = json.load(open(args.config))
    config['model'] = args.model  # model name
    config['config'] = args.config  # config file path

    # distributed-training settings
    config['world_size'] = get_world_size()  # total number of processes (GPUs) taking part in training
    config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"  # init method with master node IP and port
    config['distributed'] = True if config['world_size'] > 1 else False  # enable distributed training when world size > 1

    # set up the distributed training environment
    if get_master_ip() == "127.0.0.1":
        # if the master IP is localhost, spawn the training processes manually
        mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
    else:
        # if the processes were launched by an external tool (e.g. OpenMPI), no manual spawning is needed
        config['local_rank'] = get_local_rank()  # local (per-node) rank
        config['global_rank'] = get_global_rank()  # global rank
        main_worker(-1, config)  # run the main worker

backend/tools/train/trainer_sttn.py
@@ -0,0 +1,319 @@
import os
import glob
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from tensorboardX import SummaryWriter
from backend.inpaint.sttn.auto_sttn import Discriminator
from backend.inpaint.sttn.auto_sttn import InpaintGenerator
from backend.tools.train.dataset_sttn import Dataset
from backend.tools.train.loss_sttn import AdversarialLoss


class Trainer:
    def __init__(self, config, debug=False):
        # trainer initialization
        self.config = config  # keep the config
        self.epoch = 0  # current training epoch
        self.iteration = 0  # current training iteration
        if debug:
            # in debug mode, save and validate much more frequently
            self.config['trainer']['save_freq'] = 5
            self.config['trainer']['valid_freq'] = 5
            self.config['trainer']['iterations'] = 5

        # dataset and data loader
        self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug)  # training set
        self.train_sampler = None  # no sampler by default
        self.train_args = config['trainer']  # training hyperparameters
        if config['distributed']:
            # distributed training uses a distributed sampler
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=config['world_size'],
                rank=config['global_rank']
            )
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.train_args['batch_size'] // config['world_size'],
            shuffle=(self.train_sampler is None),  # shuffle only when there is no sampler
            num_workers=self.train_args['num_workers'],
            sampler=self.train_sampler
        )

        # loss functions
        self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])  # adversarial loss
        self.adversarial_loss = self.adversarial_loss.to(self.config['device'])  # move the loss to the device
        self.l1_loss = nn.L1Loss()  # L1 loss

        # generator and discriminator models
        self.netG = InpaintGenerator()  # generator network
        self.netG = self.netG.to(self.config['device'])  # move to device
        self.netD = Discriminator(
            in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge'
        )
        self.netD = self.netD.to(self.config['device'])  # discriminator network

        # optimizers
        self.optimG = torch.optim.Adam(
            self.netG.parameters(),  # generator parameters
            lr=config['trainer']['lr'],  # learning rate
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
        )
        self.optimD = torch.optim.Adam(
            self.netD.parameters(),  # discriminator parameters
            lr=config['trainer']['lr'],  # learning rate
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
        )
        self.load()  # load checkpoints

        if config['distributed']:
            # wrap the networks in DistributedDataParallel for distributed training
            self.netG = DDP(
                self.netG,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False
            )
            self.netD = DDP(
                self.netD,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False
            )

        # loggers
        self.dis_writer = None  # discriminator writer
        self.gen_writer = None  # generator writer
        self.summary = {}  # running summary statistics
        if self.config['global_rank'] == 0 or (not config['distributed']):
            # only the main node (or non-distributed runs) writes logs
            self.dis_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'dis')
            )
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen')
            )

    # current learning rate
    def get_lr(self):
        return self.optimG.param_groups[0]['lr']

    # adjust the learning rate
    def adjust_learning_rate(self):
        # compute the decayed learning rate
        decay = 0.1 ** (min(self.iteration, self.config['trainer']['niter_steady']) // self.config['trainer']['niter'])
        new_lr = self.config['trainer']['lr'] * decay
        # if the new learning rate differs from the current one, update both optimizers
        if new_lr != self.get_lr():
            for param_group in self.optimG.param_groups:
                param_group['lr'] = new_lr
            for param_group in self.optimD.param_groups:
                param_group['lr'] = new_lr

    # add summary statistics
    def add_summary(self, writer, name, val):
        # accumulate the statistic on every iteration
        if name not in self.summary:
            self.summary[name] = 0
        self.summary[name] += val
        # write once every 100 iterations
        if writer is not None and self.iteration % 100 == 0:
            writer.add_scalar(name, self.summary[name] / 100, self.iteration)
            self.summary[name] = 0

    # load the models (netG and netD)
    def load(self):
        model_path = self.config['save_dir']  # model save path
        # check whether a latest checkpoint marker exists
        if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
            # read the number of the last epoch
            latest_epoch = open(os.path.join(
                model_path, 'latest.ckpt'), 'r').read().splitlines()[-1]
        else:
            # without latest.ckpt, list the stored model files and take the most recent one
            ckpts = [os.path.basename(i).split('.pth')[0] for i in glob.glob(
                os.path.join(model_path, '*.pth'))]
            ckpts.sort()  # sort the checkpoint files to find the most recent one
            latest_epoch = ckpts[-1] if len(ckpts) > 0 else None  # most recent epoch value
        if latest_epoch is not None:
            # build the paths of the generator, discriminator, and optimizer files
            gen_path = os.path.join(
                model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5)))
            dis_path = os.path.join(
                model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5)))
            opt_path = os.path.join(
                model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5)))
            # on the main node, report which model is being loaded
            if self.config['global_rank'] == 0:
                print('Loading model from {}...'.format(gen_path))
            # load the generator
            data = torch.load(gen_path, map_location=self.config['device'])
            self.netG.load_state_dict(data['netG'])
            # load the discriminator
            data = torch.load(dis_path, map_location=self.config['device'])
            self.netD.load_state_dict(data['netD'])
            # load the optimizer states
            data = torch.load(opt_path, map_location=self.config['device'])
            self.optimG.load_state_dict(data['optimG'])
            self.optimD.load_state_dict(data['optimD'])
            # restore the current epoch and iteration
            self.epoch = data['epoch']
            self.iteration = data['iteration']
        else:
            # warn when no trained model was found
            if self.config['global_rank'] == 0:
                print('Warning: There is no trained model found. An initialized model will be used.')

    # save the model parameters; called once per evaluation period (eval_epoch)
    def save(self, it):
        # save only on the process with global rank 0 (normally the main node)
        if self.config['global_rank'] == 0:
            # path for the generator state dict
            gen_path = os.path.join(
                self.config['save_dir'], 'gen_{}.pth'.format(str(it).zfill(5)))
            # path for the discriminator state dict
            dis_path = os.path.join(
                self.config['save_dir'], 'dis_{}.pth'.format(str(it).zfill(5)))
            # path for the optimizer state dict
            opt_path = os.path.join(
                self.config['save_dir'], 'opt_{}.pth'.format(str(it).zfill(5)))
            # report that the model is being saved
            print('\nsaving model to {} ...'.format(gen_path))
            # if the models are wrapped by DataParallel or DDP, unwrap them first
            if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
                netG = self.netG.module
                netD = self.netD.module
            else:
                netG = self.netG
                netD = self.netD
            # save the generator and discriminator parameters
            torch.save({'netG': netG.state_dict()}, gen_path)
            torch.save({'netD': netD.state_dict()}, dis_path)
            # save the current epoch, iteration, and optimizer states
            torch.save({
                'epoch': self.epoch,
                'iteration': self.iteration,
                'optimG': self.optimG.state_dict(),
                'optimD': self.optimD.state_dict()
            }, opt_path)
            # write the latest iteration number to "latest.ckpt"
            os.system('echo {} > {}'.format(str(it).zfill(5),
                                            os.path.join(self.config['save_dir'], 'latest.ckpt')))

    # training entry point
    def train(self):
        # progress range
        pbar = range(int(self.train_args['iterations']))
        # only the global rank 0 process shows a progress bar
        if self.config['global_rank'] == 0:
            pbar = tqdm(pbar, initial=self.iteration, dynamic_ncols=True, smoothing=0.01)

        # training loop
        while True:
            self.epoch += 1  # increase the epoch counter
            if self.config['distributed']:
                # in distributed mode, reseed the sampler so each process sees different data
                self.train_sampler.set_epoch(self.epoch)

            # train for one epoch
            self._train_epoch(pbar)
            # stop once the iteration limit from the config is reached
            if self.iteration > self.train_args['iterations']:
                break
        # training finished
        print('\nEnd training....')

    # process inputs and compute losses for one training epoch
    def _train_epoch(self, pbar):
        device = self.config['device']  # device info

        # iterate over the data loader
        for frames, masks in self.train_loader:
            # adjust the learning rate
            self.adjust_learning_rate()
            # bump the iteration counter
            self.iteration += 1

            # move frames and masks to the device
            frames, masks = frames.to(device), masks.to(device)
            b, t, c, h, w = frames.size()  # frame and mask dimensions
            masked_frame = (frames * (1 - masks).float())  # apply the masks to the frames
            pred_img = self.netG(masked_frame, masks)  # generator fills the masked regions
            # reshape frames and masks to match the network outputs
            frames = frames.view(b * t, c, h, w)
            masks = masks.view(b * t, 1, h, w)
            comp_img = frames * (1. - masks) + masks * pred_img  # final composited image

            gen_loss = 0  # generator loss
            dis_loss = 0  # discriminator loss

            # discriminator adversarial loss
            real_vid_feat = self.netD(frames)  # discriminator on real frames
            fake_vid_feat = self.netD(comp_img.detach())  # discriminator on composited frames (detached: no generator gradients)
            dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)  # loss on real frames
            dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)  # loss on composited frames
            dis_loss += (dis_real_loss + dis_fake_loss) / 2  # averaged discriminator loss
            # log the discriminator losses
            self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
            self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
            # discriminator step
            self.optimD.zero_grad()
            dis_loss.backward()
            self.optimD.step()

            # generator adversarial loss
            gen_vid_feat = self.netD(comp_img)
            gan_loss = self.adversarial_loss(gen_vid_feat, True, False)  # generator adversarial loss
            gan_loss = gan_loss * self.config['losses']['adversarial_weight']  # scale by its weight
            gen_loss += gan_loss  # accumulate into the generator loss
            # log the generator adversarial loss
            self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item())

            # generator L1 loss inside the masks
            hole_loss = self.l1_loss(pred_img * masks, frames * masks)  # only the masked regions
            # normalize by the mask mean and scale by hole_weight from the config
            hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
            gen_loss += hole_loss  # accumulate into the generator loss
            # log hole_loss
            self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item())

            # L1 loss outside the masks
            valid_loss = self.l1_loss(pred_img * (1 - masks), frames * (1 - masks))
            # normalize by the non-masked mean and scale by valid_weight from the config
            valid_loss = valid_loss / torch.mean(1 - masks) * self.config['losses']['valid_weight']
            gen_loss += valid_loss  # accumulate into the generator loss
            # log valid_loss
            self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item())

            # generator step
            self.optimG.zero_grad()
            gen_loss.backward()
            self.optimG.step()

            # console logging
            if self.config['global_rank'] == 0:
                pbar.update(1)  # advance the progress bar
                pbar.set_description((  # progress bar description
                    f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};"  # loss values
                    f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
                )

            # model saving
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration // self.train_args['save_freq']))
            # stop when the iteration limit is reached
            if self.iteration > self.train_args['iterations']:
                break
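The learning-rate decay in adjust_learning_rate is easy to work out by hand; with the youtube-vos config above (lr=1e-4, niter=15e4, niter_steady=30e4) the schedule is:

```python
# standalone recomputation of the decay formula from adjust_learning_rate
lr, niter, niter_steady = 1e-4, 15e4, 30e4
for iteration in [0, 100_000, 150_000, 250_000, 300_000, 500_000]:
    decay = 0.1 ** (min(iteration, niter_steady) // niter)
    print(iteration, lr * decay)
# 0..149999       -> 1e-04
# 150000..299999  -> 1e-05
# >= 300000       -> 1e-06 (min() caps any further decay)
```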

backend/tools/train/utils_sttn.py
@@ -0,0 +1,271 @@
import os
import matplotlib.patches as patches
from matplotlib.path import Path
import io
import cv2
import random
import zipfile
import numpy as np
from PIL import Image, ImageOps
import torch
import matplotlib
from matplotlib import pyplot as plt
matplotlib.use('agg')


class ZipReader(object):
    file_dict = dict()

    def __init__(self):
        super(ZipReader, self).__init__()

    @staticmethod
    def build_file_dict(path):
        file_dict = ZipReader.file_dict
        if path in file_dict:
            return file_dict[path]
        else:
            file_handle = zipfile.ZipFile(path, 'r')
            file_dict[path] = file_handle
            return file_dict[path]

    @staticmethod
    def imread(path, image_name):
        zfile = ZipReader.build_file_dict(path)
        data = zfile.read(image_name)
        im = Image.open(io.BytesIO(data))
        return im


class GroupRandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5
    """

    def __init__(self, is_flow=False):
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        v = random.random()
        if v < 0.5:
            ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
            if self.is_flow:
                for i in range(0, len(ret), 2):
                    # invert flow pixel values when flipping
                    ret[i] = ImageOps.invert(ret[i])
            return ret
        else:
            return img_group


class Stack(object):
    def __init__(self, roll=False):
        self.roll = roll

    def __call__(self, img_group):
        mode = img_group[0].mode
        if mode == '1':
            img_group = [img.convert('L') for img in img_group]
            mode = 'L'
        if mode == 'L':
            return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
        elif mode == 'RGB':
            if self.roll:
                return np.stack([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
            else:
                return np.stack(img_group, axis=2)
        else:
            raise NotImplementedError(f"Image mode {mode}")


class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """

    def __init__(self, div=True):
        self.div = div

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # numpy img: [L, C, H, W]
            img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
        else:
            # handle PIL Image
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format
            # yikes, this transpose takes 80% of the loading time/CPU
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        img = img.float().div(255) if self.div else img.float()
        return img


def create_random_shape_with_random_motion(video_length, imageHeight=240, imageWidth=432):
    # get a random shape
    height = random.randint(imageHeight//3, imageHeight-1)
    width = random.randint(imageWidth//3, imageWidth-1)
    edge_num = random.randint(6, 8)
    ratio = random.randint(6, 8)/10
    region = get_random_shape(
        edge_num=edge_num, ratio=ratio, height=height, width=width)
    region_width, region_height = region.size
    # get random position
    x, y = random.randint(
        0, imageHeight-region_height), random.randint(0, imageWidth-region_width)
    velocity = get_random_velocity(max_speed=3)
    m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
    m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
    masks = [m.convert('L')]
    # return fixed masks
    if random.uniform(0, 1) > 0.5:
        return masks*video_length
    # return moving masks
    for _ in range(video_length-1):
        x, y, velocity = random_move_control_points(
            x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
        m = Image.fromarray(
            np.zeros((imageHeight, imageWidth)).astype(np.uint8))
        m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
        masks.append(m.convert('L'))
    return masks


def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
    '''
    There is the initial point and 3 points per cubic bezier curve.
    Thus, the curve will only pass through n points, which will be the sharp edges.
    The other 2 modify the shape of the bezier curve.
    edge_num, number of possibly sharp edges
    points_num, number of points in the Path
    ratio, (0, 1) magnitude of the perturbation from the unit circle
    '''
    points_num = edge_num*3 + 1
    angles = np.linspace(0, 2*np.pi, points_num)
    codes = np.full(points_num, Path.CURVE4)
    codes[0] = Path.MOVETO
    # Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
    verts = np.stack((np.cos(angles), np.sin(angles))).T * \
        (2*ratio*np.random.random(points_num)+1-ratio)[:, None]
    verts[-1, :] = verts[0, :]
    path = Path(verts, codes)
    # draw paths into images
    fig = plt.figure()
    ax = fig.add_subplot(111)
    patch = patches.PathPatch(path, facecolor='black', lw=2)
    ax.add_patch(patch)
    ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
    ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
    ax.axis('off')  # removes the axis to leave only the shape
    fig.canvas.draw()
    # convert plt images into numpy images
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,)))
    plt.close(fig)
    # postprocess
    data = cv2.resize(data, (width, height))[:, :, 0]
    data = (1 - np.array(data > 0).astype(np.uint8))*255
    coordinates = np.where(data > 0)
    xmin, xmax, ymin, ymax = np.min(coordinates[0]), np.max(
        coordinates[0]), np.min(coordinates[1]), np.max(coordinates[1])
    region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
    return region


def random_accelerate(velocity, maxAcceleration, dist='uniform'):
    speed, angle = velocity
    d_speed, d_angle = maxAcceleration
    if dist == 'uniform':
        speed += np.random.uniform(-d_speed, d_speed)
        angle += np.random.uniform(-d_angle, d_angle)
    elif dist == 'gaussian':
        speed += np.random.normal(0, d_speed / 2)
        angle += np.random.normal(0, d_angle / 2)
    else:
        raise NotImplementedError(
            f'Distribution type {dist} is not supported.')
    return (speed, angle)


def get_random_velocity(max_speed=3, dist='uniform'):
    if dist == 'uniform':
        speed = np.random.uniform(max_speed)
    elif dist == 'gaussian':
        speed = np.abs(np.random.normal(0, max_speed / 2))
    else:
        raise NotImplementedError(
            f'Distribution type {dist} is not supported.')
    angle = np.random.uniform(0, 2 * np.pi)
    return (speed, angle)


def random_move_control_points(X, Y, imageHeight, imageWidth, lineVelocity, region_size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3):
    region_width, region_height = region_size
    speed, angle = lineVelocity
    X += int(speed * np.cos(angle))
    Y += int(speed * np.sin(angle))
    lineVelocity = random_accelerate(
        lineVelocity, maxLineAcceleration, dist='gaussian')
    if (X > imageHeight - region_height) or (X < 0) or (Y > imageWidth - region_width) or (Y < 0):
        lineVelocity = get_random_velocity(maxInitSpeed, dist='gaussian')
    new_X = np.clip(X, 0, imageHeight - region_height)
    new_Y = np.clip(Y, 0, imageWidth - region_width)
    return new_X, new_Y, lineVelocity


def get_world_size():
    """Find OMPI world size without calling mpi functions
    :rtype: int
    """
    if os.environ.get('PMI_SIZE') is not None:
        return int(os.environ.get('PMI_SIZE') or 1)
    elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:
        return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
    else:
        return torch.cuda.device_count()


def get_global_rank():
    """Find OMPI world rank without calling mpi functions
    :rtype: int
    """
    if os.environ.get('PMI_RANK') is not None:
        return int(os.environ.get('PMI_RANK') or 0)
    elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:
        return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
    else:
        return 0


def get_local_rank():
    """Find OMPI local rank without calling mpi functions
    :rtype: int
    """
    if os.environ.get('MPI_LOCALRANKID') is not None:
        return int(os.environ.get('MPI_LOCALRANKID') or 0)
    elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:
        return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
    else:
        return 0


def get_master_ip():
    if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:
        return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]
    elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:
        return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
    else:
        return "127.0.0.1"


if __name__ == '__main__':
    trials = 10
    for _ in range(trials):
        video_length = 10
        # The returned masks are either stationary (50%) or moving (50%)
        masks = create_random_shape_with_random_motion(
            video_length, imageHeight=240, imageWidth=432)
        for m in masks:
            cv2.imshow('mask', np.array(m))
            cv2.waitKey(500)

BIN
design/paper_intro.pdf Normal file

Binary file not shown.

BIN
design/paper_sttn.pdf Normal file

Binary file not shown.

BIN
design/sponsor.png Normal file

Binary file not shown.


gui.py (62 lines changed)
@@ -15,14 +15,15 @@ import multiprocessing
 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 import backend.main
+from backend.tools.common_tools import is_image_file
 
 class SubtitleRemoverGUI:
 
     def __init__(self):
         # on first run, check that the runtime environment works
-        from paddle import fluid
-        fluid.install_check.run_check()
+        from paddle import utils
+        utils.run_check()
         self.font = 'Arial 10'
         self.theme = 'LightBrown12'
         sg.theme(self.theme)
@@ -233,6 +234,17 @@ class SubtitleRemoverGUI:
                 self.window['-X-SLIDER-W-'].update(w)
             self._update_preview(frame, (y, h, x, w))
 
+    def __disable_button(self):
+        # 1) lock the subtitle slider area
+        self.window['-Y-SLIDER-'].update(disabled=True)
+        self.window['-X-SLIDER-'].update(disabled=True)
+        self.window['-Y-SLIDER-H-'].update(disabled=True)
+        self.window['-X-SLIDER-W-'].update(disabled=True)
+        # 2) prevent clicking the Run, Open, and Detect Language buttons again
+        self.window['-RUN-'].update(disabled=True)
+        self.window['-FILE-'].update(disabled=True)
+        self.window['-FILE_BTN-'].update(disabled=True)
+
     def _run_event_handler(self, event, values):
         """
         When the Run button is clicked:
@@ -244,15 +256,8 @@
             if self.video_cap is None:
                 print('Please Open Video First')
             else:
-                # 1) lock the subtitle slider area
-                self.window['-Y-SLIDER-'].update(disabled=True)
-                self.window['-X-SLIDER-'].update(disabled=True)
-                self.window['-Y-SLIDER-H-'].update(disabled=True)
-                self.window['-X-SLIDER-W-'].update(disabled=True)
-                # 2) prevent clicking the Run, Open, and Detect Language buttons again
-                self.window['-RUN-'].update(disabled=True)
-                self.window['-FILE-'].update(disabled=True)
-                self.window['-FILE_BTN-'].update(disabled=True)
+                # disable the buttons
+                self.__disable_button()
                 # 3) set the subtitle area position
                 self.xmin = int(values['-X-SLIDER-'])
                 self.xmax = int(values['-X-SLIDER-'] + values['-X-SLIDER-W-'])
@@ -262,8 +267,23 @@
                     self.ymax = self.frame_height
                 if self.xmax > self.frame_width:
                     self.xmax = self.frame_width
-                print(f"{'SubtitleArea'}({self.ymin},{self.ymax},{self.xmin},{self.xmax})")
-                subtitle_area = (self.ymin, self.ymax, self.xmin, self.xmax)
+                if len(self.video_paths) <= 1:
+                    subtitle_area = (self.ymin, self.ymax, self.xmin, self.xmax)
+                else:
+                    print(f"{'Processing multiple videos or images'}")
+                    # first check whether every video has the same resolution; if so, use the same subtitle area for all of them, otherwise set it to None
+                    global_size = None
+                    for temp_video_path in self.video_paths:
+                        temp_cap = cv2.VideoCapture(temp_video_path)
+                        if global_size is None:
+                            global_size = (int(temp_cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(temp_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
+                        else:
+                            temp_size = (int(temp_cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(temp_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
+                            if temp_size != global_size:
+                                print('not all video/images in same size, processing in full screen')
+                                subtitle_area = None
+                            else:
+                                subtitle_area = (self.ymin, self.ymax, self.xmin, self.xmax)
                 y_p = self.ymin / self.frame_height
                 h_p = (self.ymax - self.ymin) / self.frame_height
                 x_p = self.xmin / self.frame_width
@@ -273,7 +293,10 @@
                 def task():
                     while self.video_paths:
                         video_path = self.video_paths.pop()
+                        if subtitle_area is not None:
+                            print(f"{'SubtitleArea'}({self.ymin},{self.ymax},{self.xmin},{self.xmax})")
                         self.sr = backend.main.SubtitleRemover(video_path, subtitle_area, True)
+                        self.__disable_button()
                         self.sr.run()
                 Thread(target=task, daemon=True).start()
                 self.video_cap.release()
@@ -287,7 +310,18 @@
         """
         if event == '-SLIDER-' or event == '-Y-SLIDER-' or event == '-Y-SLIDER-H-' or event == '-X-SLIDER-' or event \
                 == '-X-SLIDER-W-':
-            if self.video_cap is not None and self.video_cap.isOpened():
+            # check whether the input is a single image
+            if is_image_file(self.video_path):
+                img = cv2.imread(self.video_path)
+                self.window['-Y-SLIDER-H-'].update(range=(0, self.frame_height - values['-Y-SLIDER-']))
+                self.window['-X-SLIDER-W-'].update(range=(0, self.frame_width - values['-X-SLIDER-']))
+                # draw the subtitle box
+                y = int(values['-Y-SLIDER-'])
+                h = int(values['-Y-SLIDER-H-'])
+                x = int(values['-X-SLIDER-'])
+                w = int(values['-X-SLIDER-W-'])
+                self._update_preview(img, (y, h, x, w))
+            elif self.video_cap is not None and self.video_cap.isOpened():
                 frame_no = int(values['-SLIDER-'])
                 self.video_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
                 ret, frame = self.video_cap.read()

requirements.txt
@@ -9,7 +9,7 @@ lmdb==1.4.1
 PyYAML==6.0.1
 omegaconf==2.1.2
 tqdm==4.66.1
-PySimpleGUI==4.55.1
+PySimpleGUI==4.70.1
 easydict==1.9
 scikit-learn==0.24.2
 pandas==2.0.3
@@ -18,4 +18,4 @@ pytorch-lightning==1.2.9
 numpy==1.23.1
 protobuf==3.20.0
 av==11.0.0
-einops==0.7.0
+einops==0.7.0

BIN
test/test4.mp4 Normal file

Binary file not shown.