Mirror of https://github.com/YaoFANGUK/video-subtitle-remover.git (synced 2026-02-13 02:54:45 +08:00)

Compare commits: 70 commits, 1.1.0 ... revert-130
| SHA1 |
| --- |
| 53baf28326 |
| 991294c00f |
| ea7e01e3aa |
| acdb150aa2 |
| 285bfbafa7 |
| 97b4159d38 |
| bb80445cf4 |
| 7e8d0b818b |
| 77758d258b |
| c60234f4ec |
| 38ff91fad7 |
| 9f9cded1ff |
| 54027ceeb0 |
| 1dc9036dee |
| 09dbfa47f2 |
| aa83db0f98 |
| f86c8c9fe8 |
| 3dc8f3bfe0 |
| b0ca454473 |
| 9f7fd5b341 |
| 535fdecef4 |
| 019f7f4517 |
| 8c5ea2e19d |
| 330cf54e1a |
| 7019572f7b |
| ee53840adb |
| 96d744b3a7 |
| 32c47873ab |
| 99770a32b9 |
| 0f71d732e1 |
| f3a982710d |
| e07849ef87 |
| 96099ea2d4 |
| f4c22dd420 |
| a3452832ff |
| 45e80bc9b0 |
| 4a09342987 |
| caf4cb27f4 |
| c927476c0f |
| 61aa3d8f88 |
| 67fdacdd8b |
| 3d21963995 |
| a3dd7b797d |
| 6b353455a0 |
| d6736d9206 |
| 4abc3409ac |
| 2d1eb11fd6 |
| 9a65c17a50 |
| 3ce8d7409b |
| f9dd30fddf |
| fda9024084 |
| 19141ff5c9 |
| 97b54f6d9e |
| 584e574795 |
| dad37eba7d |
| 063a896cb9 |
| 63d8378f36 |
| 4cbfa9ebf0 |
| 8a8088be1f |
| 757cc5bf77 |
| e536d6af86 |
| 311701d3e6 |
| a7e62db98a |
| 945aeb9bc8 |
| 6ea7482344 |
| ba396d9569 |
| 22b021d9ae |
| 49ae0029f5 |
| f89c109636 |
| 055a08403f |
.condarc (14 changed lines)

@@ -2,13 +2,9 @@ channels:
   - defaults
 show_channel_urls: true
 default_channels:
-  - http://mirrors.aliyun.com/anaconda/pkgs/main
-  - http://mirrors.aliyun.com/anaconda/pkgs/r
-  - http://mirrors.aliyun.com/anaconda/pkgs/msys2
+  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
+  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r
+  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2
 custom_channels:
-  conda-forge: http://mirrors.aliyun.com/anaconda/cloud
-  msys2: http://mirrors.aliyun.com/anaconda/cloud
-  bioconda: http://mirrors.aliyun.com/anaconda/cloud
-  menpo: http://mirrors.aliyun.com/anaconda/cloud
-  pytorch: http://mirrors.aliyun.com/anaconda/cloud
-  simpleitk: http://mirrors.aliyun.com/anaconda/cloud
+  conda-forge: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
+  pytorch: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
.gitignore (vendored, 1 changed line)

@@ -372,3 +372,4 @@ test*_no_sub*.mp4
 /backend/models/video/ProPainter.pth
 /backend/models/big-lama/big-lama.pt
 /test/debug/
+/backend/tools/train/release_model/
README.md (101 changed lines)

@@ -12,12 +12,13 @@ Video-subtitle-remover (VSR) is AI-based software that removes hardcoded subtitles from videos
 - Fills the region left by the removed subtitle text using a powerful AI algorithm model (non-adjacent pixel filling and mosaic removal)
 - Supports a custom subtitle position: removes only the subtitles inside the given area (pass a position)
 - Supports automatic removal of all text across the whole video (pass no position)
+- Supports multi-selection of images for batch removal of watermark text

 <p style="text-align:center;"><img src="https://github.com/YaoFANGUK/video-subtitle-remover/raw/main/design/demo.png" alt="demo.png"/></p>

 **Instructions:**

-- For usage questions, join the QQ group for discussion: 806152575
+- For usage questions, join a QQ group for discussion: 806152575 (full) or 816881808
 - Download the archive, unzip it, and run it directly; if it does not run, follow the tutorial below and try installing the conda environment and running from source

 **Download links:**

@@ -111,9 +112,9 @@ conda activate videoEnv
 <h5>(1) Download CUDA 11.7</h5>
 <a href="https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_516.01_windows.exe">cuda_11.7.0_516.01_windows.exe</a>
 <h5>(2) Install CUDA 11.7</h5>
-<h5>(3) Download cuDNN 8.2.4</h5>
-<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-windows-x64-v8.2.4.15.zip">cudnn-windows-x64-v8.2.4.15.zip</a></p>
-<h5>(4) Install cuDNN 8.2.4</h5>
+<h5>(3) Download cuDNN v8.4.0 (April 1st, 2022), for CUDA 11.x</h5>
+<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-windows-x86_64-8.4.0.27_cuda11.6-archive.zip">cudnn-windows-x86_64-8.4.0.27_cuda11.6-archive.zip</a></p>
+<h5>(4) Install cuDNN 8.4.0</h5>
 <p>
   Copy the files in the bin, include, and lib directories of the unzipped cuDNN cuda folder into the corresponding directories under C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\
 </p>

@@ -136,12 +137,12 @@ conda activate videoEnv

 - Install the GPU version of PyTorch:

 ```shell
-conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia
+conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.8 -c pytorch -c nvidia
 ```
 or use
 ```shell
-pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu117
+pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118
 ```

 - Install other dependencies:

@@ -166,29 +167,85 @@ python ./backend/main.py
 ```

 ## FAQ
-1. CondaHTTPError
+1. What to do about slow removal speed
+
+   Changing the parameters in backend/config.py can speed up removal considerably:
+   ```python
+   MODE = InpaintMode.STTN  # use the STTN algorithm
+   STTN_SKIP_DETECTION = True  # skip subtitle detection; this may leave some subtitles unremoved, or touch frames that contain none
+   ```
+
+2. What to do if the removal quality is poor
+
+   Change the parameters in backend/config.py and try a different removal algorithm. Algorithm overview:
+
+   > - InpaintMode.STTN algorithm: works well for live-action video, fast, can skip subtitle detection
+   > - InpaintMode.LAMA algorithm: best for images, works well for animation, moderate speed, cannot skip subtitle detection
+   > - InpaintMode.PROPAINTER algorithm: consumes a large amount of VRAM, slow, works better for videos with very intense motion
+
+   - Using the STTN algorithm
+
+   ```python
+   MODE = InpaintMode.STTN  # use the STTN algorithm
+   # Number of neighboring frames; larger values use more VRAM and improve the result
+   STTN_NEIGHBOR_STRIDE = 10
+   # Reference frame length; larger values use more VRAM and improve the result
+   STTN_REFERENCE_LENGTH = 10
+   # Maximum number of frames the STTN algorithm processes at once; larger is slower but gives better results
+   # Make sure STTN_MAX_LOAD_NUM is greater than STTN_NEIGHBOR_STRIDE and STTN_REFERENCE_LENGTH
+   STTN_MAX_LOAD_NUM = 30
+   ```
+   - Using the LAMA algorithm
+   ```python
+   MODE = InpaintMode.LAMA  # use the LAMA algorithm
+   LAMA_SUPER_FAST = False  # prioritize quality
+   ```
+
+   > If you are not satisfied with the model's subtitle-removal quality, see the training method in the design folder, train with the code under backend/tools/train, and replace the old model with your newly trained one

    Place the .condarc from this project in your user directory (C:/Users/<your username>); if the file already exists there, overwrite it

    Solution: https://zhuanlan.zhihu.com/p/260034241

-2. 7z extraction error
+4. 7z extraction error

    Solution: upgrade the 7-zip extraction program to the latest version

-3. RTX 4090 does not run with CUDA 11.7
+5. RTX 4090 does not run with CUDA 11.7

    Solution: switch to CUDA 11.8
+   ```shell
+   pip install torch==2.1.0 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118
+   ```

-## Sponsorship
-<img src="https://i.imgur.com/EMCP5Lv.jpeg" width="600">
-
-| Donor | Total donated | Sponsor tier |
-| --- | --- | --- |
-| 坤V | 400.00 RMB | Gold sponsor seat |
-| 陈凯 | 50.00 RMB | Silver sponsor seat |
-| Tshuang | 20.00 RMB | Silver sponsor seat |
-| 很奇异 | 15.00 RMB | Silver sponsor seat |
-| 何斐 | 10.00 RMB | Silver sponsor seat |
-| 长缨在手 | 6.00 RMB | Silver sponsor seat |
-| Leo | 1.00 RMB | Bronze sponsor seat |
+## Sponsorship
+
+<img src="https://github.com/YaoFANGUK/video-subtitle-extractor/raw/main/design/sponsor.png" width="600">
+
+| Donor | Total donated | Sponsor tier |
+|---|---|---|
+| 坤V | 400.00 RMB | Gold sponsor seat |
+| Jenkit | 200.00 RMB | Gold sponsor seat |
+| 子车松兰 | 188.00 RMB | Gold sponsor seat |
+| 落花未逝 | 100.00 RMB | Gold sponsor seat |
+| 张音乐 | 100.00 RMB | Gold sponsor seat |
+| 麦格 | 100.00 RMB | Gold sponsor seat |
+| 无痕 | 100.00 RMB | Gold sponsor seat |
+| wr | 100.00 RMB | Gold sponsor seat |
+| 陈 | 100.00 RMB | Gold sponsor seat |
+| TalkLuv | 50.00 RMB | Silver sponsor seat |
+| 陈凯 | 50.00 RMB | Silver sponsor seat |
+| Tshuang | 20.00 RMB | Silver sponsor seat |
+| 很奇异 | 15.00 RMB | Silver sponsor seat |
+| 郭鑫 | 12.00 RMB | Silver sponsor seat |
+| 生活不止眼前的苟且 | 10.00 RMB | Bronze sponsor seat |
+| 何斐 | 10.00 RMB | Bronze sponsor seat |
+| 老猫 | 8.80 RMB | Bronze sponsor seat |
+| 伍六七 | 7.77 RMB | Bronze sponsor seat |
+| 长缨在手 | 6.00 RMB | Bronze sponsor seat |
+| 无忌 | 6.00 RMB | Bronze sponsor seat |
+| Stephen | 2.00 RMB | Bronze sponsor seat |
+| Leo | 1.00 RMB | Bronze sponsor seat |
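For completeness, the README's source-install path ends at `python ./backend/main.py`, which prompts for a path interactively. A minimal hedged sketch of driving removal from Python instead, mirroring the `__main__` block of backend/main.py shown later in this diff (the input path and sub_area values here are made-up; models must already be downloaded):

```python
from backend.main import SubtitleRemover

# sub_area is (ymin, ymax, xmin, xmax); pass None to process the full frame
remover = SubtitleRemover('./test/test.mp4', sub_area=(1000, 1080, 0, 1920))
remover.run()  # per main.py, the result is written next to the input as <name>_no_sub.mp4
```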
README_en.md (59 changed lines)

@@ -12,6 +12,7 @@ Video-subtitle-remover (VSR) is an AI-based software that removes hardcoded subtitles
 - Fills in the removed subtitle text area using a powerful AI algorithm model (non-adjacent pixel filling and mosaic removal).
 - Supports custom subtitle positions by only removing subtitles in the defined location (input position).
 - Supports automatic removal of all text throughout the entire video (without inputting a position).
+- Supports multi-selection of images for batch removal of watermark text.

 <p style="text-align:center;"><img src="https://github.com/YaoFANGUK/video-subtitle-remover/raw/main/design/demo.png" alt="demo.png"/></p>

@@ -110,11 +111,11 @@ Please make sure you have already installed Python 3.8+, use conda to create a project virtual environment
 <h5>(1) Download CUDA 11.7</h5>
 <a href="https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_516.01_windows.exe">cuda_11.7.0_516.01_windows.exe</a>
 <h5>(2) Install CUDA 11.7</h5>
-<h5>(3) Download cuDNN 8.2.4</h5>
-<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-windows-x64-v8.2.4.15.zip">cudnn-windows-x64-v8.2.4.15.zip</a></p>
-<h5>(4) Install cuDNN 8.2.4</h5>
+<h5>(3) Download cuDNN 8.4.0</h5>
+<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-windows-x86_64-8.4.0.27_cuda11.6-archive.zip">cudnn-windows-x86_64-8.4.0.27_cuda11.6-archive.zip</a></p>
+<h5>(4) Install cuDNN 8.4.0</h5>
 <p>
-  unzip "cudnn-windows-x64-v8.2.4.15.zip", then move all files in "bin, include, lib" in cuda
+  unzip "cudnn-windows-x86_64-8.4.0.27_cuda11.6-archive.zip", then move all files in "bin, include, lib" in cuda
   directory to C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\
 </p>
 </details>

@@ -136,12 +137,12 @@ Please make sure you have already installed Python 3.8+, use conda to create a project virtual environment
 - Install GPU version of Pytorch:

 ```shell
-conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia
+conda install pytorch==2.1.0 torchvision==0.16.0 pytorch-cuda=11.8 -c pytorch -c nvidia
 ```
 or use

 ```shell
-pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu117
+pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118
 ```

 - Install other dependencies:

@@ -166,12 +167,54 @@ python ./backend/main.py
 ```

 ## Common Issues
-1. CondaHTTPError
+1. How to deal with slow removal speed
+
+   You can greatly increase the removal speed by modifying the parameters in backend/config.py:
+
+   ```python
+   MODE = InpaintMode.STTN  # Set to STTN algorithm
+   STTN_SKIP_DETECTION = True  # Skip subtitle detection
+   ```
+
+2. What to do if the video removal results are not satisfactory
+
+   Modify the values in backend/config.py and try different removal algorithms. Here is an introduction to the algorithms:
+
+   > - **InpaintMode.STTN** algorithm: Good for live-action videos and fast in speed, capable of skipping subtitle detection
+   > - **InpaintMode.LAMA** algorithm: Best for images and effective for animated videos, moderate speed, unable to skip subtitle detection
+   > - **InpaintMode.PROPAINTER** algorithm: Consumes a significant amount of VRAM, slower in speed, works better for videos with very intense movement
+
+   - Using the STTN algorithm
+
+   ```python
+   MODE = InpaintMode.STTN  # Set to STTN algorithm
+   # Number of neighboring frames; increasing this will increase memory usage and improve the result
+   STTN_NEIGHBOR_STRIDE = 10
+   # Length of reference frames; increasing this will increase memory usage and improve the result
+   STTN_REFERENCE_LENGTH = 10
+   # Set the maximum number of frames processed simultaneously by the STTN algorithm; a larger value leads to slower processing but better results
+   # Ensure that STTN_MAX_LOAD_NUM is greater than STTN_NEIGHBOR_STRIDE and STTN_REFERENCE_LENGTH
+   STTN_MAX_LOAD_NUM = 30
+   ```
+   - Using the LAMA algorithm
+
+   ```python
+   MODE = InpaintMode.LAMA  # Set to LAMA algorithm
+   LAMA_SUPER_FAST = False  # Ensure quality
+   ```
+
+3. CondaHTTPError

    Place the .condarc file from the project in the user directory (C:/Users/<your_username>). If the file already exists in the user directory, overwrite it.

    Solution: https://zhuanlan.zhihu.com/p/260034241

-2. 7z file extraction error
+4. 7z file extraction error

    Solution: Upgrade the 7-zip extraction program to the latest version.
+
+   ```shell
+   pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118
+   ```
backend/config.py

@@ -64,12 +64,17 @@ class InpaintMode(Enum):


 # ×××××××××××××××××××× [editable] start ××××××××××××××××××××
+# Whether to use H.264 encoding; enable this option if you need to share the generated video on Android phones
+USE_H264 = True
+
 # ×××××××××× general settings start ××××××××××
+"""
+MODE selects the algorithm type
+- InpaintMode.STTN algorithm: works well for live-action video, fast, can skip subtitle detection
+- InpaintMode.LAMA algorithm: works well for animation, moderate speed, cannot skip subtitle detection
+- InpaintMode.PROPAINTER algorithm: consumes a large amount of VRAM, slow, works better for videos with very intense motion
+"""
-# [Select the inpaint algorithm]
-# - InpaintMode.STTN algorithm: works well for live-action video, fast, can skip subtitle detection
-# - InpaintMode.LAMA algorithm: works well for animation, moderate speed, cannot skip subtitle detection
-# - InpaintMode.PROPAINTER algorithm: consumes a large amount of VRAM, slow, works better for videos with very intense motion
 MODE = InpaintMode.STTN
 # [Set the pixel tolerance]
 # Used to decide whether a region is a non-subtitle region (a subtitle text box is normally wider than it is tall; if a detected box is taller than it is wide by more than the given number of pixels, it is treated as a false detection)

@@ -85,17 +90,33 @@ PIXEL_TOLERANCE_X = 20  # allowed horizontal deviation of the detection box, in pixels

 # ×××××××××× InpaintMode.STTN settings start ××××××××××
 # The following parameters take effect only when the STTN algorithm is used
-# Whether to skip detection; skipping subtitle detection saves a lot of time, but may touch video frames that have no subtitles
+"""
+1. STTN_SKIP_DETECTION
+Meaning: whether to skip subtitle detection
+Effect: when True, subtitle detection is skipped, which saves a lot of time, but may leave some subtitles unremoved or touch video frames that have no subtitles
+
+2. STTN_NEIGHBOR_STRIDE
+Meaning: stride between neighboring reference frames. If frame 50 needs its missing region filled and STTN_NEIGHBOR_STRIDE=5, the algorithm uses frame 45, frame 40, and so on as references.
+Effect: controls how densely reference frames are sampled; a larger stride means fewer, more spread-out reference frames, while a smaller stride means more, more concentrated ones.
+
+3. STTN_REFERENCE_LENGTH
+Meaning: number of reference frames; the STTN algorithm looks at several frames before and after each frame to be repaired to gather context for the repair
+Effect: larger values use more VRAM and improve the result, but slow down processing
+
+4. STTN_MAX_LOAD_NUM
+Meaning: maximum number of video frames the STTN algorithm loads at once
+Effect: larger values are slower but give better results
+Note: make sure STTN_MAX_LOAD_NUM is greater than STTN_NEIGHBOR_STRIDE and STTN_REFERENCE_LENGTH
+"""
 STTN_SKIP_DETECTION = True
-# Number of neighboring frames; larger values use more VRAM and improve the result
-STTN_NEIGHBOR_STRIDE = 10
-# Reference frame length; larger values use more VRAM and improve the result
+# Reference frame stride
+STTN_NEIGHBOR_STRIDE = 5
+# Reference frame length (count)
 STTN_REFERENCE_LENGTH = 10
-# Maximum number of frames the STTN algorithm processes at once; larger is slower but gives better results
-# Make sure STTN_MAX_LOAD_NUM is greater than STTN_NEIGHBOR_STRIDE and STTN_REFERENCE_LENGTH
-STTN_MAX_LOAD_NUM = 30
-if STTN_MAX_LOAD_NUM < max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH):
-    STTN_MAX_LOAD_NUM = max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH)
+# Maximum number of frames the STTN algorithm processes at once
+STTN_MAX_LOAD_NUM = 50
+if STTN_MAX_LOAD_NUM < STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE:
+    STTN_MAX_LOAD_NUM = STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE
 # ×××××××××× InpaintMode.STTN settings end ××××××××××

 # ×××××××××× InpaintMode.PROPAINTER settings start ××××××××××
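A minimal sketch (not part of the diff) of how the three STTN parameters above interact, following the semantics described in the new docstring; the helper name is hypothetical and purely illustrative:

```python
# Hypothetical helper: pick reference frames for a target frame by stepping
# outwards STTN_NEIGHBOR_STRIDE frames at a time, capped at STTN_REFERENCE_LENGTH frames.
def pick_reference_frames(target, total_frames, stride=5, ref_length=10):
    refs = []
    for offset in range(stride, stride * ref_length + 1, stride):
        for idx in (target - offset, target + offset):
            if 0 <= idx < total_frames:
                refs.append(idx)
        if len(refs) >= ref_length:
            break
    return sorted(refs[:ref_length])

print(pick_reference_frames(50, 200))  # [25, 30, 35, 40, 45, 55, 60, 65, 70, 75]
```

With the new defaults (stride 5, length 10), a window of roughly 50 frames around the target is consulted, which is why the new guard forces STTN_MAX_LOAD_NUM up to STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE.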
backend/inpaint/sttn_inpaint.py

@@ -93,6 +93,7 @@ class STTNInpaint:
     @staticmethod
     def read_mask(path):
         img = cv2.imread(path, 0)
+        # convert to a binary mask
         ret, img = cv2.threshold(img, 127, 1, cv2.THRESH_BINARY)
         img = img[:, :, None]
         return img

@@ -200,6 +201,24 @@ class STTNInpaint:
             to_H -= h
         return inpaint_area  # return the list of inpaint regions

+    @staticmethod
+    def get_inpaint_area_by_selection(input_sub_area, mask):
+        print('use selection area for inpainting')
+        height, width = mask.shape[:2]
+        ymin, ymax, _, _ = input_sub_area
+        interval_size = 135
+        # list holding the results
+        inpaint_area = []
+        # compute and store the standard intervals
+        for i in range(ymin, ymax, interval_size):
+            inpaint_area.append((i, i + interval_size))
+        # check whether the last interval reaches the maximum value
+        if inpaint_area[-1][1] != ymax:
+            # if not, create a new interval starting at the end of the last one and ending at the extended value
+            if inpaint_area[-1][1] + interval_size <= height:
+                inpaint_area.append((inpaint_area[-1][1], inpaint_area[-1][1] + interval_size))
+        return inpaint_area  # return the list of inpaint regions
+

 class STTNVideoInpaint:

@@ -291,10 +310,10 @@ class STTNVideoInpaint:
                     mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :]
                     frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
                 writer.write(frame)
-                if input_sub_remover is not None and input_sub_remover.gui_mode:
+                if input_sub_remover is not None:
                     if tbar is not None:
                         input_sub_remover.update_progress(tbar, increment=1)
-                    if original_frame is not None:
+                    if original_frame is not None and input_sub_remover.gui_mode:
                         input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame])
         # release the video writer
         writer.release()
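A quick usage sketch (not part of the diff) showing what the new get_inpaint_area_by_selection produces; the 720p mask shape and the subtitle area are assumptions:

```python
import numpy as np
from backend.inpaint.sttn_inpaint import STTNInpaint

# Assumed 720p mask; the selected subtitle area covers y in [550, 700)
mask = np.zeros((720, 1280), dtype=np.uint8)
areas = STTNInpaint.get_inpaint_area_by_selection((550, 700, 0, 1280), mask)
print(areas)  # [(550, 685), (685, 820)]: fixed 135-px bands starting at ymin
```

Note that the bands are anchored at ymin in fixed 135-pixel steps, so the last band can extend past ymax (and, as here, past the frame height); the extra extension is only appended when it fits inside the frame.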
backend/main.py

@@ -1,3 +1,4 @@
+import torch
 import shutil
 import subprocess
 import os

@@ -5,9 +6,11 @@ from pathlib import Path
 import threading
 import cv2
 import sys

 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 import config
+from backend.tools.common_tools import is_video_or_image, is_image_file
 from backend.scenedetect import scene_detect
 from backend.scenedetect.detectors import ContentDetector
 from backend.inpaint.sttn_inpaint import STTNInpaint, STTNVideoInpaint

@@ -17,7 +20,6 @@ from backend.tools.inpaint_tools import create_mask, batch_generator
 import importlib
 import platform
 import tempfile
-import torch
 import multiprocessing
 from shapely.geometry import Polygon
 import time

@@ -257,7 +259,43 @@ class SubtitleDetect:
         return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]])

     @staticmethod
-    def expand_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
+    def expand_and_merge_intervals(intervals, expand_size=config.STTN_NEIGHBOR_STRIDE*config.STTN_REFERENCE_LENGTH, max_length=config.STTN_MAX_LOAD_NUM):
+        # initialize the output interval list
+        expanded_intervals = []
+
+        # expand each original interval
+        for interval in intervals:
+            start, end = interval
+
+            # expand to at least 'expand_size' units, but no more than 'max_length' units
+            expansion_amount = max(expand_size - (end - start + 1), 0)
+
+            # split the expansion as evenly as possible before and after, while still containing the original interval
+            expand_start = max(start - expansion_amount // 2, 1)  # make sure the start is not less than 1
+            expand_end = end + expansion_amount // 2
+
+            # if the expanded interval exceeds the maximum length, adjust it
+            if (expand_end - expand_start + 1) > max_length:
+                expand_end = expand_start + max_length - 1
+
+            # single-point intervals additionally need a length of at least 'expand_size'
+            if start == end:
+                if expand_end - expand_start + 1 < expand_size:
+                    expand_end = expand_start + expand_size - 1
+
+            # check for overlap with the previous interval and merge accordingly
+            if expanded_intervals and expand_start <= expanded_intervals[-1][1]:
+                previous_start, previous_end = expanded_intervals.pop()
+                expand_start = previous_start
+                expand_end = max(expand_end, previous_end)
+
+            # append the expanded interval to the result list
+            expanded_intervals.append((expand_start, expand_end))
+
+        return expanded_intervals
+
+    @staticmethod
+    def filter_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
         """
         Merge the given subtitle start/end intervals, making sure each interval is at least STTN_REFERENCE_LENGTH long
         """

@@ -290,12 +328,10 @@ class SubtitleDetect:
         for start, end in expanded[1:]:
             last_start, last_end = merged[-1]
             # check for overlap
-            if start <= last_end and (
-                    end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
+            if start <= last_end and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
                 # merge needed
                 merged[-1] = (last_start, max(last_end, end))  # merge the intervals
-            elif start == last_end + 1 and (
-                    end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
+            elif start == last_end + 1 and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
                 # adjacent intervals that also need merging
                 merged[-1] = (last_start, end)
             else:

@@ -483,7 +519,7 @@ class SubtitleRemover:
         self.gui_mode = gui_mode
         # check whether the input is a picture
         self.is_picture = False
-        if str(vd_path).endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
+        if is_image_file(str(vd_path)):
             self.sub_area = None
             self.is_picture = True
         # video path

@@ -505,8 +541,7 @@ class SubtitleRemover:
         # create a temporary video file; on Windows delete=True raises a permission denied error
         self.video_temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
         # create the video writer
-        self.video_writer = cv2.VideoWriter(self.video_temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps,
-                                            self.size)
+        self.video_writer = cv2.VideoWriter(self.video_temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size)
         self.video_out_name = os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4')
         self.video_inpaint = None
         self.lama_inpaint = None

@@ -647,16 +682,14 @@ class SubtitleRemover:
                         self.lama_inpaint = LamaInpaint()
                     inpainted_frame = self.lama_inpaint(frame, single_mask)
                     self.video_writer.write(inpainted_frame)
-                    print(
-                        f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
+                    print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
                     inner_index += 1
                     self.update_progress(tbar, increment=1)
             elif len(batch) > 1:
                 inpainted_frames = self.video_inpaint.inpaint(batch, mask)
                 for i, inpainted_frame in enumerate(inpainted_frames):
                     self.video_writer.write(inpainted_frame)
-                    print(
-                        f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
+                    print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
                     inner_index += 1
                     if self.gui_mode:
                         self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])

@@ -670,12 +703,13 @@ class SubtitleRemover:
         print('[Processing] start removing subtitles...')
         if self.sub_area is not None:
             ymin, ymax, xmin, xmax = self.sub_area
-            mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
-            mask = create_mask(self.mask_size, mask_area_coordinates)
-            sttn_video_inpaint = STTNVideoInpaint(self.video_path)
-            sttn_video_inpaint(input_mask=mask, input_sub_remover=self, tbar=tbar)
         else:
-            print('please set subtitle area first')
+            print('[Info] No subtitle area has been set. Video will be processed in full screen. As a result, the final outcome might be suboptimal.')
+            ymin, ymax, xmin, xmax = 0, self.frame_height, 0, self.frame_width
+        mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
+        mask = create_mask(self.mask_size, mask_area_coordinates)
+        sttn_video_inpaint = STTNVideoInpaint(self.video_path)
+        sttn_video_inpaint(input_mask=mask, input_sub_remover=self, tbar=tbar)

     def sttn_mode(self, tbar):
         # whether to skip searching for subtitle frames

@@ -688,7 +722,7 @@ class SubtitleRemover:
             sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
             continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
-            print(continuous_frame_no_list)
+            continuous_frame_no_list = self.sub_detector.expand_and_merge_intervals(continuous_frame_no_list)
             continuous_frame_no_list = self.sub_detector.filter_and_merge_intervals(continuous_frame_no_list)
             print(continuous_frame_no_list)
             start_end_map = dict()
             for interval in continuous_frame_no_list:

@@ -847,7 +881,7 @@ class SubtitleRemover:
             audio_merge_command = [config.FFMPEG_PATH,
                                    "-y", "-i", self.video_temp_file.name,
                                    "-i", temp.name,
-                                   "-vcodec", "copy",
+                                   "-vcodec", "libx264" if config.USE_H264 else "copy",
                                    "-acodec", "copy",
                                    "-loglevel", "error", self.video_out_name]
             try:

@@ -859,7 +893,10 @@ class SubtitleRemover:
             try:
                 os.remove(temp.name)
             except Exception:
-                print(f'failed to delete temp file {temp.name}')
+                if platform.system() in ['Windows']:
+                    pass
+                else:
+                    print(f'failed to delete temp file {temp.name}')
             self.is_successful_merged = True
         finally:
             temp.close()

@@ -874,9 +911,13 @@ class SubtitleRemover:
 if __name__ == '__main__':
     multiprocessing.set_start_method("spawn")
     # 1. prompt the user for the video path
-    video_path = input(f"Please input video file path: ").strip()
+    video_path = input(f"Please input video or image file path: ").strip()
+    # check whether the path is a directory; if so, batch-process all video files in that directory
     # 2. pass in the subtitle area in this order:
     # sub_area = (ymin, ymax, xmin, xmax)
     # 3. create the subtitle remover object
-    sd = SubtitleRemover(video_path, sub_area=None)
-    sd.run()
+    if is_video_or_image(video_path):
+        sd = SubtitleRemover(video_path, sub_area=None)
+        sd.run()
+    else:
+        print(f'Invalid video path: {video_path}')
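A small illustration (not part of the diff) of what the new expand_and_merge_intervals does with the config defaults from this changeset (STTN_NEIGHBOR_STRIDE=5, STTN_REFERENCE_LENGTH=10, STTN_MAX_LOAD_NUM=50, so expand_size=50 and max_length=50). The function below is a hypothetical standalone mirror of the method above, written out only so the example is self-contained:

```python
def expand_and_merge(intervals, expand_size=50, max_length=50):
    out = []
    for start, end in intervals:
        # grow each range to roughly expand_size frames, centered on the original
        amount = max(expand_size - (end - start + 1), 0)
        s = max(start - amount // 2, 1)
        e = end + amount // 2
        if (e - s + 1) > max_length:
            e = s + max_length - 1
        if start == end and e - s + 1 < expand_size:
            e = s + expand_size - 1
        # merge with the previous range when they overlap
        if out and s <= out[-1][1]:
            ps, pe = out.pop()
            s, e = ps, max(e, pe)
        out.append((s, e))
    return out

print(expand_and_merge([(100, 104), (120, 125)]))
# [(78, 147)]: both subtitle ranges are padded to ~50 frames and merged where they overlap
```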
backend/tools/common_tools.py (new file, 32 lines)

@@ -0,0 +1,32 @@
import os

video_extensions = {
    '.mp4', '.m4a', '.m4v', '.f4v', '.f4a', '.m4b', '.m4r', '.f4b', '.mov',
    '.3gp', '.3gp2', '.3g2', '.3gpp', '.3gpp2', '.ogg', '.oga', '.ogv', '.ogx',
    '.wmv', '.wma', '.asf', '.webm', '.flv', '.avi', '.gifv', '.mkv', '.rm',
    '.rmvb', '.vob', '.dvd', '.mpg', '.mpeg', '.mp2', '.mpe', '.mpv', '.mpg',
    '.mpeg', '.m2v', '.svi', '.3gp', '.mxf', '.roq', '.nsv', '.flv', '.f4v',
    '.f4p', '.f4a', '.f4b'
}

image_extensions = {
    '.jpg', '.jpeg', '.jpe', '.jif', '.jfif', '.jfi', '.png', '.gif',
    '.webp', '.tiff', '.tif', '.psd', '.raw', '.arw', '.cr2', '.nrw',
    '.k25', '.bmp', '.dib', '.heif', '.heic', '.ind', '.indd', '.indt',
    '.jp2', '.j2k', '.jpf', '.jpx', '.jpm', '.mj2', '.svg', '.svgz',
    '.ai', '.eps', '.ico'
}


def is_video_file(filename):
    return os.path.splitext(filename)[-1].lower() in video_extensions


def is_image_file(filename):
    return os.path.splitext(filename)[-1].lower() in image_extensions


def is_video_or_image(filename):
    file_extension = os.path.splitext(filename)[-1].lower()
    # check whether the extension is in the defined video or image extension sets
    return file_extension in video_extensions or file_extension in image_extensions
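A quick usage note (not part of the diff): these helpers key off the lowercased file extension only, so a path does not need to exist to be classified.

```python
from backend.tools.common_tools import is_video_or_image, is_image_file

print(is_image_file('poster.PNG'))    # True: the comparison is case-insensitive
print(is_video_or_image('clip.mkv'))  # True
print(is_video_or_image('notes.txt')) # False, which is how main.py rejects invalid inputs
```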
@@ -22,8 +22,8 @@ def merge_video(video_input_path0, video_input_path1, video_output_path):


 if __name__ == '__main__':
-    v0_path = '../../test/test_2_low.mp4'
-    v1_path = '../../test/test_2_low_no_sub.mp4'
+    v0_path = '../../test/test4.mp4'
+    v1_path = '../../test/test4_no_sub(1).mp4'
     video_out_path = '../../test/demo.mp4'
     merge_video(v0_path, v1_path, video_out_path)
     # ffmpeg command: mp4 to gif
backend/tools/train/configs_sttn/davis.json (new file, 33 lines)

@@ -0,0 +1,33 @@
{
    "seed": 2020,
    "save_dir": "release_model/",
    "data_loader": {
        "name": "davis",
        "data_root": "datasets/",
        "w": 640,
        "h": 120,
        "sample_length": 5
    },
    "losses": {
        "hole_weight": 1,
        "valid_weight": 1,
        "adversarial_weight": 0.01,
        "GAN_LOSS": "hinge"
    },
    "trainer": {
        "type": "Adam",
        "beta1": 0,
        "beta2": 0.99,
        "lr": 1e-4,
        "d2glr": 1,
        "batch_size": 8,
        "num_workers": 2,
        "verbosity": 2,
        "log_step": 100,
        "save_freq": 1e4,
        "valid_freq": 1e4,
        "iterations": 50e4,
        "niter": 30e4,
        "niter_steady": 30e4
    }
}
backend/tools/train/configs_sttn/youtube-vos.json (new file, 33 lines)

@@ -0,0 +1,33 @@
{
    "seed": 2020,
    "save_dir": "release_model/",
    "data_loader": {
        "name": "youtube-vos",
        "data_root": "datasets_sttn/",
        "w": 640,
        "h": 120,
        "sample_length": 5
    },
    "losses": {
        "hole_weight": 1,
        "valid_weight": 1,
        "adversarial_weight": 0.01,
        "GAN_LOSS": "hinge"
    },
    "trainer": {
        "type": "Adam",
        "beta1": 0,
        "beta2": 0.99,
        "lr": 1e-4,
        "d2glr": 1,
        "batch_size": 8,
        "num_workers": 2,
        "verbosity": 2,
        "log_step": 100,
        "save_freq": 1e4,
        "valid_freq": 1e4,
        "iterations": 50e4,
        "niter": 15e4,
        "niter_steady": 30e4
    }
}
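A quick check (not part of the diff) that these configs load the way train_sttn.py below consumes them; the path assumes you run from the repository root. Note that exponent literals such as 1e4 are valid JSON numbers, so no preprocessing is needed:

```python
import json

config = json.load(open('backend/tools/train/configs_sttn/youtube-vos.json'))
print(config['trainer']['lr'])        # 0.0001
print(config['data_loader']['name'])  # youtube-vos
```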
backend/tools/train/dataset_sttn.py (new file, 85 lines)

@@ -0,0 +1,85 @@
import os
import json
import random
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_random_motion
from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip


# custom dataset
class Dataset(torch.utils.data.Dataset):
    def __init__(self, args: dict, split='train', debug=False):
        # initializer; takes the config parameter dict and the dataset split (default 'train')
        self.args = args
        self.split = split
        self.sample_length = args['sample_length']  # sample length parameter
        self.size = self.w, self.h = (args['w'], args['h'])  # target image width and height

        # open the json file holding the dataset metadata
        with open(os.path.join(args['data_root'], args['name'], split+'.json'), 'r') as f:
            self.video_dict = json.load(f)  # load the json content
        self.video_names = list(self.video_dict.keys())  # list of video names
        if debug or split != 'train':  # in debug mode, or for non-training splits, keep only the first 100 videos
            self.video_names = self.video_names[:100]

        # transforms that stack the frames into tensors
        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(),  # tensor format convenient for PyTorch
        ])

    def __len__(self):
        # number of videos in the dataset
        return len(self.video_names)

    def __getitem__(self, index):
        # fetch one sample
        try:
            item = self.load_item(index)  # try to load the item at the given index
        except:
            print('Loading error in video {}'.format(self.video_names[index]))  # report the failure
            item = self.load_item(0)  # fall back to the first item
        return item

    def load_item(self, index):
        # the actual item-loading logic
        video_name = self.video_names[index]  # look up the video name by index
        # build the frame file names for all frames of the video
        all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])]
        # generate randomly shaped masks with random motion
        all_masks = create_random_shape_with_random_motion(
            len(all_frames), imageHeight=self.h, imageWidth=self.w)
        # pick the reference frame indices
        ref_index = get_ref_index(len(all_frames), self.sample_length)
        # read the video frames
        frames = []
        masks = []
        for idx in ref_index:
            # read the image, convert to RGB, resize, and append it to the list
            img = ZipReader.imread('{}/{}/JPEGImages/{}.zip'.format(
                self.args['data_root'], self.args['name'], video_name), all_frames[idx]).convert('RGB')
            img = img.resize(self.size)
            frames.append(img)
            masks.append(all_masks[idx])
        if self.split == 'train':
            # for the training split, randomly flip the images horizontally
            frames = GroupRandomHorizontalFlip()(frames)
        # convert to tensors
        frame_tensors = self._to_tensors(frames)*2.0 - 1.0  # normalize to [-1, 1]
        mask_tensors = self._to_tensors(masks)  # convert the masks to tensors
        return frame_tensors, mask_tensors  # return the frame and mask tensors


def get_ref_index(length, sample_length):
    # reference frame index selection
    if random.uniform(0, 1) > 0.5:
        # with probability one half, sample random frames
        ref_index = random.sample(range(length), sample_length)
        ref_index.sort()  # sort to keep temporal order
    else:
        # otherwise pick consecutive frames
        pivot = random.randint(0, length-sample_length)
        ref_index = [pivot+i for i in range(sample_length)]
    return ref_index
backend/tools/train/datasets_sttn/davis/test.json (new file, 1 line)

@@ -0,0 +1 @@
{"bear": 82, "bike-packing": 69, "blackswan": 50, "bmx-bumps": 90, "bmx-trees": 80, "boat": 75, "boxing-fisheye": 87, "breakdance": 84, "breakdance-flare": 71, "bus": 80, "camel": 90, "car-roundabout": 75, "car-shadow": 40, "car-turn": 80, "cat-girl": 89, "classic-car": 63, "color-run": 84, "cows": 104, "crossing": 52, "dance-jump": 60, "dance-twirl": 90, "dancing": 62, "disc-jockey": 76, "dog": 60, "dog-agility": 25, "dog-gooses": 86, "dogs-jump": 66, "dogs-scale": 83, "drift-chicane": 52, "drift-straight": 50, "drift-turn": 64, "drone": 91, "elephant": 80, "flamingo": 80, "goat": 90, "gold-fish": 78, "hike": 80, "hockey": 75, "horsejump-high": 50, "horsejump-low": 60, "india": 81, "judo": 34, "kid-football": 68, "kite-surf": 50, "kite-walk": 80, "koala": 100, "lab-coat": 47, "lady-running": 65, "libby": 49, "lindy-hop": 73, "loading": 50, "longboard": 52, "lucia": 70, "mallard-fly": 70, "mallard-water": 80, "mbike-trick": 79, "miami-surf": 70, "motocross-bumps": 60, "motocross-jump": 40, "motorbike": 43, "night-race": 46, "paragliding": 70, "paragliding-launch": 80, "parkour": 100, "pigs": 79, "planes-water": 38, "rallye": 50, "rhino": 90, "rollerblade": 35, "schoolgirls": 80, "scooter-black": 43, "scooter-board": 91, "scooter-gray": 75, "sheep": 68, "shooting": 40, "skate-park": 80, "snowboard": 66, "soapbox": 99, "soccerball": 48, "stroller": 91, "stunt": 71, "surf": 55, "swing": 60, "tennis": 70, "tractor-sand": 76, "train": 80, "tuk-tuk": 59, "upside-down": 65, "varanus-cage": 67, "walking": 72}

backend/tools/train/datasets_sttn/davis/train.json (new file, 1 line)

@@ -0,0 +1 @@
{"baseball": 90, "basketball-game": 77, "bears-ball": 78, "bmx-rider": 85, "butterfly": 80, "car-competition": 66, "cat": 52, "chairlift": 99, "circus": 73, "city-ride": 70, "crafting": 45, "curling": 76, "dog-competition": 85, "dolphins-show": 74, "dribbling": 49, "drone-flying": 70, "ducks": 75, "elephant-hyenas": 55, "giraffes": 88, "gym-ball": 69, "helicopter-landing": 77, "horse-race": 80, "horses-kids": 78, "hurdles-race": 55, "ice-hockey": 52, "jet-ski": 83, "juggling-selfie": 78, "kayak-race": 63, "kids-robot": 75, "landing": 35, "luggage": 83, "mantaray": 73, "marbles": 70, "mascot": 78, "mermaid": 78, "monster-trucks": 99, "motorbike-indoors": 79, "motorbike-race": 88, "music-band": 87, "obstacles": 81, "obstacles-race": 48, "peacock": 75, "plane-exhibition": 73, "puppet": 100, "robot-battle": 85, "robotic-arm": 82, "rodeo": 85, "sea-turtle": 90, "skydiving-jumping": 75, "snowboard-race": 75, "snowboard-sand": 55, "surfer": 80, "swimmer": 86, "table-tennis": 70, "tram": 84, "trucks-race": 78, "twist-dance": 83, "volleyball-beach": 73, "water-slide": 88, "weightlifting": 90}
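These two JSON files map each video name to its frame count, which is exactly what Dataset.load_item above consumes. A hedged construction sketch, assuming the layout from configs_sttn/davis.json with per-video frame zips under <data_root>/<name>/JPEGImages/<video_name>.zip:

```python
from backend.tools.train.dataset_sttn import Dataset

# the data_loader section of configs_sttn/davis.json
data_args = {'name': 'davis', 'data_root': 'datasets/', 'w': 640, 'h': 120, 'sample_length': 5}
train_set = Dataset(data_args, split='train')
frames, masks = train_set[0]
print(frames.shape, masks.shape)  # torch.Size([5, 3, 120, 640]) torch.Size([5, 1, 120, 640])
```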
backend/tools/train/datasets_sttn/youtube-vos/test.json (new file, 1 line)

File diff suppressed because one or more lines are too long

backend/tools/train/datasets_sttn/youtube-vos/train.json (new file, 1 line)

File diff suppressed because one or more lines are too long
backend/tools/train/loss_sttn.py (new file, 56 lines)

@@ -0,0 +1,56 @@
import torch
import torch.nn as nn


class AdversarialLoss(nn.Module):
    """
    Adversarial loss
    Implemented following the paper https://arxiv.org/abs/1711.10337
    """

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        """
        Available loss types: 'nsgan' | 'lsgan' | 'hinge'
        type: which kind of GAN loss to use.
        target_real_label: target label value for real images.
        target_fake_label: target label value for generated images.
        """
        super(AdversarialLoss, self).__init__()
        self.type = type  # loss type
        # register the labels as buffers so they are saved and loaded together with the model
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))

        # initialize the criterion according to the selected type
        if type == 'nsgan':
            self.criterion = nn.BCELoss()  # binary cross-entropy (non-saturating GAN)
        elif type == 'lsgan':
            self.criterion = nn.MSELoss()  # mean squared error (least-squares GAN)
        elif type == 'hinge':
            self.criterion = nn.ReLU()  # ReLU used for the hinge loss

    def __call__(self, outputs, is_real, is_disc=None):
        """
        Compute the loss.
        outputs: network outputs.
        is_real: True for real samples, False for generated samples.
        is_disc: whether the discriminator is currently being optimized.
        """
        if self.type == 'hinge':
            # hinge loss
            if is_disc:
                # discriminator branch
                if is_real:
                    outputs = -outputs  # flip the sign for real samples
                # max(0, 1 - (real/fake) output)
                return self.criterion(1 + outputs).mean()
            else:
                # generator branch: -min(0, -outputs) = max(0, outputs)
                return (-outputs).mean()
        else:
            # nsgan and lsgan losses
            labels = (self.real_label if is_real else self.fake_label).expand_as(
                outputs)
            # loss between the model outputs and the target labels
            loss = self.criterion(outputs, labels)
            return loss
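A minimal usage sketch (not part of the diff) of the hinge variant, with dummy discriminator scores; this mirrors how trainer_sttn.py below calls it:

```python
import torch
from backend.tools.train.loss_sttn import AdversarialLoss

adv = AdversarialLoss(type='hinge')
real_scores = torch.randn(8, 1)
fake_scores = torch.randn(8, 1)

# discriminator: hinge on real and fake scores, averaged
d_loss = (adv(real_scores, True, True) + adv(fake_scores, False, True)) / 2
# generator: wants the discriminator to score composited frames highly
g_loss = adv(fake_scores, True, False)
print(d_loss.item(), g_loss.item())
```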
backend/tools/train/train_sttn.py (new file, 96 lines)

@@ -0,0 +1,96 @@
import os
import json
import argparse
from shutil import copyfile
import torch
import torch.multiprocessing as mp

from backend.tools.train.trainer_sttn import Trainer
from backend.tools.train.utils_sttn import (
    get_world_size,
    get_local_rank,
    get_global_rank,
    get_master_ip,
)

parser = argparse.ArgumentParser(description='STTN')
parser.add_argument('-c', '--config', default='configs_sttn/youtube-vos.json', type=str)
parser.add_argument('-m', '--model', default='sttn', type=str)
parser.add_argument('-p', '--port', default='23455', type=str)
parser.add_argument('-e', '--exam', action='store_true')
args = parser.parse_args()


def main_worker(rank, config):
    # if the config has no local_rank yet, set both local_rank and global_rank to the given rank
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank

    # when the config enables distributed training
    if config['distributed']:
        # bind the CUDA device matching the current local rank
        torch.cuda.set_device(int(config['local_rank']))
        # initialize the distributed process group via the nccl backend
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=config['init_method'],
            world_size=config['world_size'],
            rank=config['global_rank'],
            group_name='mtorch'
        )
        # report which GPU this process is using, with its global and local rank
        print('using GPU {}-{} for training'.format(
            int(config['global_rank']), int(config['local_rank']))
        )

    # build the model save directory path from the model name and config file name
    config['save_dir'] = os.path.join(
        config['save_dir'], '{}_{}'.format(config['model'], os.path.basename(args.config).split('.')[0])
    )

    # use the matching CUDA device if available, otherwise fall back to the CPU
    if torch.cuda.is_available():
        config['device'] = torch.device("cuda:{}".format(config['local_rank']))
    else:
        config['device'] = 'cpu'

    # if not distributed, or if this is the master process (rank 0)
    if (not config['distributed']) or config['global_rank'] == 0:
        # create the save directory (exist_ok=True tolerates an existing directory)
        os.makedirs(config['save_dir'], exist_ok=True)
        # destination path for the config file
        config_path = os.path.join(
            config['save_dir'], config['config'].split('/')[-1]
        )
        # copy the config file over if it is not there yet
        if not os.path.isfile(config_path):
            copyfile(config['config'], config_path)
        # report the created directory
        print('[**] create folder {}'.format(config['save_dir']))

    # build the trainer from the config and the debug flag
    trainer = Trainer(config, debug=args.exam)
    # start training
    trainer.train()


if __name__ == "__main__":
    # load the config file
    config = json.load(open(args.config))
    config['model'] = args.model  # model name
    config['config'] = args.config  # config file path

    # distributed-training settings
    config['world_size'] = get_world_size()  # total number of processes, i.e. the number of GPUs participating in training
    config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"  # init method: master node IP plus port
    config['distributed'] = True if config['world_size'] > 1 else False  # enable distributed training if the world size exceeds 1

    # set up the distributed training environment
    if get_master_ip() == "127.0.0.1":
        # if the master IP is localhost, spawn the training processes manually
        mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
    else:
        # if the processes were launched by an external tool such as OpenMPI, no manual spawning is needed
        config['local_rank'] = get_local_rank()  # local (per-node) rank
        config['global_rank'] = get_global_rank()  # global rank
        main_worker(-1, config)  # run the main worker
backend/tools/train/trainer_sttn.py (new file, 319 lines)

@@ -0,0 +1,319 @@
import os
import glob
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from tensorboardX import SummaryWriter

from backend.inpaint.sttn.auto_sttn import Discriminator
from backend.inpaint.sttn.auto_sttn import InpaintGenerator
from backend.tools.train.dataset_sttn import Dataset
from backend.tools.train.loss_sttn import AdversarialLoss


class Trainer:
    def __init__(self, config, debug=False):
        # trainer initialization
        self.config = config  # keep the config
        self.epoch = 0  # current training epoch
        self.iteration = 0  # current iteration count
        if debug:
            # in debug mode, save and validate much more frequently
            self.config['trainer']['save_freq'] = 5
            self.config['trainer']['valid_freq'] = 5
            self.config['trainer']['iterations'] = 5

        # dataset and data loader
        self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug)  # training set
        self.train_sampler = None  # training sampler, None by default
        self.train_args = config['trainer']  # training hyperparameters
        if config['distributed']:
            # for distributed training, use a distributed sampler
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=config['world_size'],
                rank=config['global_rank']
            )
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.train_args['batch_size'] // config['world_size'],
            shuffle=(self.train_sampler is None),  # shuffle only when there is no sampler
            num_workers=self.train_args['num_workers'],
            sampler=self.train_sampler
        )

        # loss functions
        self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])  # adversarial loss
        self.adversarial_loss = self.adversarial_loss.to(self.config['device'])  # move the loss to the device
        self.l1_loss = nn.L1Loss()  # L1 loss

        # generator and discriminator models
        self.netG = InpaintGenerator()  # generator network
        self.netG = self.netG.to(self.config['device'])  # move to the device
        self.netD = Discriminator(
            in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge'
        )
        self.netD = self.netD.to(self.config['device'])  # discriminator network
        # optimizers
        self.optimG = torch.optim.Adam(
            self.netG.parameters(),  # generator parameters
            lr=config['trainer']['lr'],  # learning rate
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
        )
        self.optimD = torch.optim.Adam(
            self.netD.parameters(),  # discriminator parameters
            lr=config['trainer']['lr'],  # learning rate
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
        )
        self.load()  # load model checkpoints

        if config['distributed']:
            # for distributed training, wrap the networks in DistributedDataParallel
            self.netG = DDP(
                self.netG,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False
            )
            self.netD = DDP(
                self.netD,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False
            )

        # loggers
        self.dis_writer = None  # discriminator writer
        self.gen_writer = None  # generator writer
        self.summary = {}  # running summary statistics
        if self.config['global_rank'] == 0 or (not config['distributed']):
            # only on the master process, or when not training distributed
            self.dis_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'dis')
            )
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen')
            )

    # current learning rate
    def get_lr(self):
        return self.optimG.param_groups[0]['lr']

    # adjust the learning rate
    def adjust_learning_rate(self):
        # compute the decayed learning rate
        decay = 0.1 ** (min(self.iteration, self.config['trainer']['niter_steady']) // self.config['trainer']['niter'])
        new_lr = self.config['trainer']['lr'] * decay
        # if the new learning rate differs from the current one, update both optimizers
        if new_lr != self.get_lr():
            for param_group in self.optimG.param_groups:
                param_group['lr'] = new_lr
            for param_group in self.optimD.param_groups:
                param_group['lr'] = new_lr

    # add summary statistics
    def add_summary(self, writer, name, val):
        # accumulate the statistic on every iteration
        if name not in self.summary:
            self.summary[name] = 0
        self.summary[name] += val
        # log once every 100 iterations
        if writer is not None and self.iteration % 100 == 0:
            writer.add_scalar(name, self.summary[name] / 100, self.iteration)
            self.summary[name] = 0

    # load netG and netD
    def load(self):
        model_path = self.config['save_dir']  # model save directory
        # look for the latest checkpoint marker
        if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
            # read the number of the last epoch
            latest_epoch = open(os.path.join(
                model_path, 'latest.ckpt'), 'r').read().splitlines()[-1]
        else:
            # without latest.ckpt, fall back to the list of saved model files and take the newest
            ckpts = [os.path.basename(i).split('.pth')[0] for i in glob.glob(
                os.path.join(model_path, '*.pth'))]
            ckpts.sort()  # sort the model files to find the newest one
            latest_epoch = ckpts[-1] if len(ckpts) > 0 else None  # newest epoch value
        if latest_epoch is not None:
            # build the generator and discriminator checkpoint paths
            gen_path = os.path.join(
                model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5)))
            dis_path = os.path.join(
                model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5)))
            opt_path = os.path.join(
                model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5)))
            # on the master process, report which model is being loaded
            if self.config['global_rank'] == 0:
                print('Loading model from {}...'.format(gen_path))
            # load the generator
            data = torch.load(gen_path, map_location=self.config['device'])
            self.netG.load_state_dict(data['netG'])
            # load the discriminator
            data = torch.load(dis_path, map_location=self.config['device'])
            self.netD.load_state_dict(data['netD'])
            # load the optimizer states
            data = torch.load(opt_path, map_location=self.config['device'])
            self.optimG.load_state_dict(data['optimG'])
            self.optimD.load_state_dict(data['optimD'])
            # restore the epoch and iteration counters
            self.epoch = data['epoch']
            self.iteration = data['iteration']
        else:
            # warn when no trained model is found
            if self.config['global_rank'] == 0:
                print('Warning: There is no trained model found. An initialized model will be used.')

    # save model parameters, called once per eval_epoch
    def save(self, it):
        # only save on the process with global rank 0, normally the master node
        if self.config['global_rank'] == 0:
            # path for the generator state dict
            gen_path = os.path.join(
                self.config['save_dir'], 'gen_{}.pth'.format(str(it).zfill(5)))
            # path for the discriminator state dict
            dis_path = os.path.join(
                self.config['save_dir'], 'dis_{}.pth'.format(str(it).zfill(5)))
            # path for the optimizer state dict
            opt_path = os.path.join(
                self.config['save_dir'], 'opt_{}.pth'.format(str(it).zfill(5)))

            # report that the model is being saved
            print('\nsaving model to {} ...'.format(gen_path))

            # unwrap DataParallel/DDP wrappers if present to get the raw models
            if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
                netG = self.netG.module
                netD = self.netD.module
            else:
                netG = self.netG
                netD = self.netD

            # save the generator and discriminator parameters
            torch.save({'netG': netG.state_dict()}, gen_path)
            torch.save({'netD': netD.state_dict()}, dis_path)
            # save the current epoch, iteration, and optimizer states
            torch.save({
                'epoch': self.epoch,
                'iteration': self.iteration,
                'optimG': self.optimG.state_dict(),
                'optimD': self.optimD.state_dict()
            }, opt_path)

            # write the latest iteration number to "latest.ckpt"
            os.system('echo {} > {}'.format(str(it).zfill(5),
                                            os.path.join(self.config['save_dir'], 'latest.ckpt')))

    # training entry point

    def train(self):
        # initialize the progress range
        pbar = range(int(self.train_args['iterations']))
        # show a progress bar on the global rank 0 process
        if self.config['global_rank'] == 0:
            pbar = tqdm(pbar, initial=self.iteration, dynamic_ncols=True, smoothing=0.01)

        # main training loop
        while True:
            self.epoch += 1  # increment the epoch counter
            if self.config['distributed']:
                # for distributed training, reseed the sampler so each process sees different data
                self.train_sampler.set_epoch(self.epoch)

            # train one epoch
            self._train_epoch(pbar)
            # stop once the iteration budget is exhausted
            if self.iteration > self.train_args['iterations']:
                break
        # training finished
        print('\nEnd training....')

    # process the inputs and compute the losses for one training pass

    def _train_epoch(self, pbar):
        device = self.config['device']  # target device

        # iterate over the data loader
        for frames, masks in self.train_loader:
            # adjust the learning rate
            self.adjust_learning_rate()
            # increment the iteration counter
            self.iteration += 1

            # move frames and masks to the device
            frames, masks = frames.to(device), masks.to(device)
            b, t, c, h, w = frames.size()  # frame and mask dimensions
            masked_frame = (frames * (1 - masks).float())  # apply the masks to the frames
            pred_img = self.netG(masked_frame, masks)  # generate the filled frames with the generator
            # reshape frames and masks to match the network's expected input
            frames = frames.view(b * t, c, h, w)
            masks = masks.view(b * t, 1, h, w)
            comp_img = frames * (1. - masks) + masks * pred_img  # final composited frames

            gen_loss = 0  # generator loss accumulator
            dis_loss = 0  # discriminator loss accumulator

            # discriminator adversarial loss
            real_vid_feat = self.netD(frames)  # discriminator on real frames
            fake_vid_feat = self.netD(comp_img.detach())  # discriminator on generated frames; detach avoids gradients
            dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)  # loss on real frames
            dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)  # loss on generated frames
            dis_loss += (dis_real_loss + dis_fake_loss) / 2  # averaged discriminator loss
            # log the discriminator losses
            self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
            self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
            # discriminator step
            self.optimD.zero_grad()
            dis_loss.backward()
            self.optimD.step()

            # generator adversarial loss
            gen_vid_feat = self.netD(comp_img)
            gan_loss = self.adversarial_loss(gen_vid_feat, True, False)  # adversarial loss for the generator
            gan_loss = gan_loss * self.config['losses']['adversarial_weight']  # apply the weight
            gen_loss += gan_loss  # accumulate into the generator loss
            # log the generator adversarial loss
            self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item())

            # generator L1 loss inside the holes
            hole_loss = self.l1_loss(pred_img * masks, frames * masks)  # only over masked regions
            # normalize by the mask mean and apply hole_weight from the config
            hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
            gen_loss += hole_loss  # accumulate into the generator loss
            # log hole_loss
            self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item())

            # L1 loss outside the masks
            valid_loss = self.l1_loss(pred_img * (1 - masks), frames * (1 - masks))
            # normalize by the unmasked mean and apply valid_weight from the config
            valid_loss = valid_loss / torch.mean(1 - masks) * self.config['losses']['valid_weight']
            gen_loss += valid_loss  # accumulate into the generator loss
            # log valid_loss
            self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item())

            # generator step
            self.optimG.zero_grad()
            gen_loss.backward()
            self.optimG.step()

            # console logging
            if self.config['global_rank'] == 0:
                pbar.update(1)  # advance the progress bar
                pbar.set_description((  # progress bar description
                    f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};"  # loss values
                    f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
                )

            # model saving
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration // self.train_args['save_freq']))
            # stop when the iteration budget is exhausted
            if self.iteration > self.train_args['iterations']:
                break
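A worked example (not part of the diff) of the step decay in Trainer.adjust_learning_rate, plugging in the values from configs_sttn/davis.json (lr=1e-4, niter=3e5, niter_steady=3e5):

```python
lr, niter, niter_steady = 1e-4, 300_000, 300_000

def decayed_lr(iteration):
    # same formula as Trainer.adjust_learning_rate
    return lr * 0.1 ** (min(iteration, niter_steady) // niter)

print(decayed_lr(100_000))  # 1e-04 (no decay before niter)
print(decayed_lr(300_000))  # 1e-05 (10x decay at niter; capped by niter_steady thereafter)
```

Because niter_steady caps the numerator, the learning rate decays exactly once and then stays flat for the rest of training.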
271
backend/tools/train/utils_sttn.py
Normal file
271
backend/tools/train/utils_sttn.py
Normal file
@@ -0,0 +1,271 @@
|
||||
import os
|
||||
|
||||
import matplotlib.patches as patches
|
||||
from matplotlib.path import Path
|
||||
import io
|
||||
import cv2
|
||||
import random
|
||||
import zipfile
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
import torch
|
||||
import matplotlib
|
||||
from matplotlib import pyplot as plt
|
||||
matplotlib.use('agg')
|
||||
|
||||
|
||||
class ZipReader(object):
|
||||
file_dict = dict()
|
||||
|
||||
def __init__(self):
|
||||
super(ZipReader, self).__init__()
|
||||
|
||||
@staticmethod
|
||||
def build_file_dict(path):
|
||||
file_dict = ZipReader.file_dict
|
||||
if path in file_dict:
|
||||
return file_dict[path]
|
||||
else:
|
||||
file_handle = zipfile.ZipFile(path, 'r')
|
||||
file_dict[path] = file_handle
|
||||
return file_dict[path]
|
||||
|
||||
@staticmethod
|
||||
def imread(path, image_name):
|
||||
zfile = ZipReader.build_file_dict(path)
|
||||
data = zfile.read(image_name)
|
||||
im = Image.open(io.BytesIO(data))
|
||||
return im
|
||||
|
||||
|
||||
class GroupRandomHorizontalFlip(object):
|
||||
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
|
||||
"""
|
||||
|
||||
def __init__(self, is_flow=False):
|
||||
self.is_flow = is_flow
|
||||
|
||||
def __call__(self, img_group, is_flow=False):
|
||||
v = random.random()
|
||||
if v < 0.5:
|
||||
ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
|
||||
if self.is_flow:
|
||||
for i in range(0, len(ret), 2):
|
||||
# invert flow pixel values when flipping
|
||||
ret[i] = ImageOps.invert(ret[i])
|
||||
return ret
|
||||
else:
|
||||
return img_group
|
||||
|
||||
|
||||
class Stack(object):
|
||||
def __init__(self, roll=False):
|
||||
self.roll = roll
|
||||
|
||||
def __call__(self, img_group):
|
||||
mode = img_group[0].mode
|
||||
if mode == '1':
|
||||
img_group = [img.convert('L') for img in img_group]
|
||||
mode = 'L'
|
||||
if mode == 'L':
|
||||
return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
|
||||
elif mode == 'RGB':
|
||||
if self.roll:
|
||||
return np.stack([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
|
||||
else:
|
||||
return np.stack(img_group, axis=2)
|
||||
else:
|
||||
raise NotImplementedError(f"Image mode {mode}")
|
||||
|
||||
|
||||
class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """

    def __init__(self, div=True):
        self.div = div

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # numpy video clip from Stack: (H, W, L, C) -> (L, C, H, W)
            img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
        else:
            # handle a single PIL Image
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format
            # yikes, this transpose takes 80% of the loading time/CPU
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        img = img.float().div(255) if self.div else img.float()
        return img
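In a dataset loader these two transforms are naturally chained. A minimal sketch with a dummy clip (the frame count and sizes are illustrative):

frames = [Image.new('RGB', (432, 240)) for _ in range(5)]  # dummy 5-frame clip
stacked = Stack()(frames)                  # numpy array, shape (H, W, L, C)
clip = ToTorchFormatTensor()(stacked)      # float tensor, shape (L, C, H, W), values in [0, 1]
print(clip.shape)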
def create_random_shape_with_random_motion(video_length, imageHeight=240, imageWidth=432):
    # get a random shape
    height = random.randint(imageHeight//3, imageHeight-1)
    width = random.randint(imageWidth//3, imageWidth-1)
    edge_num = random.randint(6, 8)
    ratio = random.randint(6, 8)/10
    region = get_random_shape(
        edge_num=edge_num, ratio=ratio, height=height, width=width)
    region_width, region_height = region.size
    # get a random position
    x, y = random.randint(
        0, imageHeight-region_height), random.randint(0, imageWidth-region_width)
    velocity = get_random_velocity(max_speed=3)
    m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
    m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
    masks = [m.convert('L')]
    # return stationary masks
    if random.uniform(0, 1) > 0.5:
        return masks*video_length
    # return moving masks
    for _ in range(video_length-1):
        x, y, velocity = random_move_control_points(
            x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
        m = Image.fromarray(
            np.zeros((imageHeight, imageWidth)).astype(np.uint8))
        m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
        masks.append(m.convert('L'))
    return masks
def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
    '''
    There is the initial point and 3 points per cubic bezier curve.
    Thus, the curve will only pass through n points, which will be the sharp edges.
    The other 2 modify the shape of the bezier curve.
    edge_num, number of possibly sharp edges
    points_num, number of points in the Path
    ratio, (0, 1) magnitude of the perturbation from the unit circle
    '''
    points_num = edge_num*3 + 1
    angles = np.linspace(0, 2*np.pi, points_num)
    codes = np.full(points_num, Path.CURVE4)
    codes[0] = Path.MOVETO
    # Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
    verts = np.stack((np.cos(angles), np.sin(angles))).T * \
        (2*ratio*np.random.random(points_num)+1-ratio)[:, None]
    verts[-1, :] = verts[0, :]
    path = Path(verts, codes)
    # draw the path into an image
    fig = plt.figure()
    ax = fig.add_subplot(111)
    patch = patches.PathPatch(path, facecolor='black', lw=2)
    ax.add_patch(patch)
    ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
    ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
    ax.axis('off')  # remove the axes to leave only the shape
    fig.canvas.draw()
    # convert the matplotlib figure into a numpy image
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,)))
    plt.close(fig)
    # postprocess
    data = cv2.resize(data, (width, height))[:, :, 0]
    data = (1 - np.array(data > 0).astype(np.uint8))*255
    coordinates = np.where(data > 0)
    xmin, xmax, ymin, ymax = np.min(coordinates[0]), np.max(
        coordinates[0]), np.min(coordinates[1]), np.max(coordinates[1])
    region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
    return region
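A quick way to eyeball the output (the file name is illustrative):

region = get_random_shape(edge_num=7, ratio=0.6, width=432, height=240)
region.save('random_region.png')  # white blob on a black background
print(region.size)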
def random_accelerate(velocity, maxAcceleration, dist='uniform'):
    speed, angle = velocity
    d_speed, d_angle = maxAcceleration
    if dist == 'uniform':
        speed += np.random.uniform(-d_speed, d_speed)
        angle += np.random.uniform(-d_angle, d_angle)
    elif dist == 'gaussian':
        speed += np.random.normal(0, d_speed / 2)
        angle += np.random.normal(0, d_angle / 2)
    else:
        raise NotImplementedError(
            f'Distribution type {dist} is not supported.')
    return (speed, angle)
def get_random_velocity(max_speed=3, dist='uniform'):
    if dist == 'uniform':
        # draw the speed from [0, max_speed); the original passed max_speed as
        # the lower bound of np.random.uniform, which was a bug
        speed = np.random.uniform(0, max_speed)
    elif dist == 'gaussian':
        speed = np.abs(np.random.normal(0, max_speed / 2))
    else:
        raise NotImplementedError(
            f'Distribution type {dist} is not supported.')
    angle = np.random.uniform(0, 2 * np.pi)
    return (speed, angle)
def random_move_control_points(X, Y, imageHeight, imageWidth, lineVelocity, region_size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3):
    region_width, region_height = region_size
    speed, angle = lineVelocity
    X += int(speed * np.cos(angle))
    Y += int(speed * np.sin(angle))
    lineVelocity = random_accelerate(
        lineVelocity, maxLineAcceleration, dist='gaussian')
    # re-randomize the velocity when the region would leave the frame
    if (X > imageHeight - region_height) or (X < 0) or (Y > imageWidth - region_width) or (Y < 0):
        lineVelocity = get_random_velocity(maxInitSpeed, dist='gaussian')
    new_X = np.clip(X, 0, imageHeight - region_height)
    new_Y = np.clip(Y, 0, imageWidth - region_width)
    return new_X, new_Y, lineVelocity
def get_world_size():
    """Find OMPI world size without calling mpi functions
    :rtype: int
    """
    if os.environ.get('PMI_SIZE') is not None:
        return int(os.environ.get('PMI_SIZE') or 1)
    elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:
        return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
    else:
        return torch.cuda.device_count()


def get_global_rank():
    """Find OMPI world rank without calling mpi functions
    :rtype: int
    """
    if os.environ.get('PMI_RANK') is not None:
        return int(os.environ.get('PMI_RANK') or 0)
    elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:
        return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
    else:
        return 0


def get_local_rank():
    """Find OMPI local rank without calling mpi functions
    :rtype: int
    """
    if os.environ.get('MPI_LOCALRANKID') is not None:
        return int(os.environ.get('MPI_LOCALRANKID') or 0)
    elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:
        return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
    else:
        return 0


def get_master_ip():
    if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:
        return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]
    elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:
        return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
    else:
        return "127.0.0.1"
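Together these helpers give a launcher everything it needs for a torch.distributed rendezvous. A minimal sketch of how they might be wired up (the port number and the nccl backend choice are assumptions, not this repo's code):

import torch.distributed as dist

def init_distributed(port=23456):
    world_size = get_world_size()
    global_rank = get_global_rank()
    if world_size > 1:
        # Rendezvous over TCP at the resolved master address
        dist.init_process_group(
            backend='nccl',
            init_method=f'tcp://{get_master_ip()}:{port}',
            world_size=world_size,
            rank=global_rank,
        )
    return global_rank, world_size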
if __name__ == '__main__':
    trials = 10
    for _ in range(trials):
        video_length = 10
        # The returned masks are either stationary (50%) or moving (50%)
        masks = create_random_shape_with_random_motion(
            video_length, imageHeight=240, imageWidth=432)

        for m in masks:
            cv2.imshow('mask', np.array(m))
            cv2.waitKey(500)
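Note that cv2.imshow needs a GUI-capable OpenCV build; on a headless machine you could dump the masks to disk instead (a sketch, file names illustrative):

for i, m in enumerate(masks):
    cv2.imwrite(f'mask_{i:03d}.png', np.array(m))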
BIN
design/paper_intro.pdf
Normal file
Binary file not shown.
BIN
design/paper_sttn.pdf
Normal file
Binary file not shown.
BIN
design/sponsor.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 794 KiB
62
gui.py
@@ -15,14 +15,15 @@ import multiprocessing
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import backend.main
from backend.tools.common_tools import is_image_file


class SubtitleRemoverGUI:

    def __init__(self):
        # On first run, check that the runtime environment works
        from paddle import fluid
        fluid.install_check.run_check()
        from paddle import utils
        utils.run_check()
        self.font = 'Arial 10'
        self.theme = 'LightBrown12'
        sg.theme(self.theme)
@@ -233,6 +234,17 @@ class SubtitleRemoverGUI:
        self.window['-X-SLIDER-W-'].update(w)
        self._update_preview(frame, (y, h, x, w))

    def __disable_button(self):
        # 1) Lock the subtitle-area sliders
        self.window['-Y-SLIDER-'].update(disabled=True)
        self.window['-X-SLIDER-'].update(disabled=True)
        self.window['-Y-SLIDER-H-'].update(disabled=True)
        self.window['-X-SLIDER-W-'].update(disabled=True)
        # 2) Prevent further clicks on the Run, Open and Detect Language buttons
        self.window['-RUN-'].update(disabled=True)
        self.window['-FILE-'].update(disabled=True)
        self.window['-FILE_BTN-'].update(disabled=True)

    def _run_event_handler(self, event, values):
        """
        When the Run button is clicked:
@@ -244,15 +256,8 @@ class SubtitleRemoverGUI:
        if self.video_cap is None:
            print('Please Open Video First')
        else:
            # 1) Lock the subtitle-area sliders
            self.window['-Y-SLIDER-'].update(disabled=True)
            self.window['-X-SLIDER-'].update(disabled=True)
            self.window['-Y-SLIDER-H-'].update(disabled=True)
            self.window['-X-SLIDER-W-'].update(disabled=True)
            # 2) Prevent further clicks on the Run, Open and Detect Language buttons
            self.window['-RUN-'].update(disabled=True)
            self.window['-FILE-'].update(disabled=True)
            self.window['-FILE_BTN-'].update(disabled=True)
            # disable the buttons
            self.__disable_button()
            # 3) Set the subtitle area position
            self.xmin = int(values['-X-SLIDER-'])
            self.xmax = int(values['-X-SLIDER-'] + values['-X-SLIDER-W-'])
@@ -262,8 +267,23 @@ class SubtitleRemoverGUI:
            self.ymax = self.frame_height
            if self.xmax > self.frame_width:
                self.xmax = self.frame_width
            print(f"{'SubtitleArea'}:({self.ymin},{self.ymax},{self.xmin},{self.xmax})")
            subtitle_area = (self.ymin, self.ymax, self.xmin, self.xmax)
            if len(self.video_paths) <= 1:
                subtitle_area = (self.ymin, self.ymax, self.xmin, self.xmax)
            else:
                print(f"{'Processing multiple videos or images'}")
                # First check whether every video has the same resolution; if so, use the
                # same subtitle area for all of them, otherwise set it to None
                global_size = None
                for temp_video_path in self.video_paths:
                    temp_cap = cv2.VideoCapture(temp_video_path)
                    if global_size is None:
                        global_size = (int(temp_cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(temp_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
                    else:
                        temp_size = (int(temp_cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(temp_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
                        if temp_size != global_size:
                            print('Not all videos/images share the same size; processing the full frame')
                            subtitle_area = None
                        else:
                            subtitle_area = (self.ymin, self.ymax, self.xmin, self.xmax)
            y_p = self.ymin / self.frame_height
            h_p = (self.ymax - self.ymin) / self.frame_height
            x_p = self.xmin / self.frame_width
@@ -273,7 +293,10 @@ class SubtitleRemoverGUI:
            def task():
                while self.video_paths:
                    video_path = self.video_paths.pop()
                    if subtitle_area is not None:
                        print(f"{'SubtitleArea'}:({self.ymin},{self.ymax},{self.xmin},{self.xmax})")
                    self.sr = backend.main.SubtitleRemover(video_path, subtitle_area, True)
                    self.__disable_button()
                    self.sr.run()
            Thread(target=task, daemon=True).start()
            self.video_cap.release()
@@ -287,7 +310,18 @@ class SubtitleRemoverGUI:
        """
        if event == '-SLIDER-' or event == '-Y-SLIDER-' or event == '-Y-SLIDER-H-' or event == '-X-SLIDER-' or event \
                == '-X-SLIDER-W-':
            if self.video_cap is not None and self.video_cap.isOpened():
                # Check whether the input is a single image
                if is_image_file(self.video_path):
                    img = cv2.imread(self.video_path)
                    self.window['-Y-SLIDER-H-'].update(range=(0, self.frame_height - values['-Y-SLIDER-']))
                    self.window['-X-SLIDER-W-'].update(range=(0, self.frame_width - values['-X-SLIDER-']))
                    # Draw the subtitle selection box
                    y = int(values['-Y-SLIDER-'])
                    h = int(values['-Y-SLIDER-H-'])
                    x = int(values['-X-SLIDER-'])
                    w = int(values['-X-SLIDER-W-'])
                    self._update_preview(img, (y, h, x, w))
                elif self.video_cap is not None and self.video_cap.isOpened():
                    frame_no = int(values['-SLIDER-'])
                    self.video_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
                    ret, frame = self.video_cap.read()
requirements.txt
@@ -9,7 +9,7 @@ lmdb==1.4.1
PyYAML==6.0.1
omegaconf==2.1.2
tqdm==4.66.1
PySimpleGUI==4.55.1
PySimpleGUI==4.70.1
easydict==1.9
scikit-learn==0.24.2
pandas==2.0.3
@@ -18,4 +18,4 @@ pytorch-lightning==1.2.9
numpy==1.23.1
protobuf==3.20.0
av==11.0.0
einops==0.7.0
einops==0.7.0
BIN
test/test4.mp4
Normal file
Binary file not shown.