diff --git a/README.md b/README.md
index dc97c83..7ce658b 100755
--- a/README.md
+++ b/README.md
@@ -22,14 +22,16 @@ Video-subtitle-remover (vsr) 是一款基于AI技术,将视频中的硬字幕
## 源码使用说明
+> **无Nvidia显卡请勿使用本项目**,最低配置:
+>
+> **GPU**:GTX 1060或以上显卡
+>
+> CPU: 支持AVX指令集
+
#### 1. 下载安装Miniconda
- Windows: Miniconda3-py38_4.11.0-Windows-x86_64.exe
-
-- MacOS:Miniconda3-py38_4.11.0-MacOSX-x86_64.pkg
-
-
- Linux: Miniconda3-py38_4.11.0-Linux-x86_64.sh
#### 2. 创建并激活虚机环境
@@ -118,6 +120,15 @@ conda activate videoEnv
> 如果安装cuda 11.2,请对应安装8.1.1的cuDNN,并使用对应cuda版本的paddlepaddle,**30系列以上的显卡驱动可能不支持 cuda 11.2及以下版本的安装**
+ - 安装GPU版本Pytorch:
+
+ ```shell
+ conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia
+ ```
+ 或者使用
+ ```shell
+ pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu117
+ ```
- 安装其他依赖:
diff --git a/backend/main.py b/backend/main.py
index 3865858..64f8bee 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -12,6 +12,8 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import importlib
import numpy as np
import tempfile
+import torch
+from paddle import fluid
from tqdm import tqdm
from tools.infer import utility
from tools.infer.predict_det import TextDetector
@@ -119,7 +121,9 @@ class SubtitleRemover:
self.video_temp_out_name = os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_{"".join(random.sample(uln, 8))}.mp4')
self.video_writer = cv2.VideoWriter(self.video_temp_out_name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size)
self.video_out_name = os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4')
-
+ fluid.install_check.run_check()
+ if torch.cuda.is_available():
+ print('使用GPU进行加速')
@staticmethod
def get_coordinates(dt_box):
diff --git a/requirements.txt b/requirements.txt
index e58e9fe..43a0abe 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,6 @@
albumentations==0.5.2
filesplit==3.0.2
opencv-python==4.8.1.78
-torch==2.0.1
-torchvision==0.15.2
scikit-image==0.17.2
imgaug==0.4.0
kornia==0.5.0
@@ -17,4 +15,4 @@ pandas==2.0.3
webdataset==0.2.57
pytorch-lightning==1.2.9
numpy==1.23.1
-
+protobuf==3.20.0