From 9b45f7a9c729511b4b7a34f6ea6ff1982d95d1af Mon Sep 17 00:00:00 2001 From: YaoFANGUK Date: Fri, 27 Oct 2023 09:17:42 +0800 Subject: [PATCH] minor --- README.md | 19 +++++++++++++++---- backend/main.py | 6 +++++- requirements.txt | 4 +--- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index dc97c83..7ce658b 100755 --- a/README.md +++ b/README.md @@ -22,14 +22,16 @@ Video-subtitle-remover (vsr) 是一款基于AI技术,将视频中的硬字幕 ## 源码使用说明 +> **无Nvidia显卡请勿使用本项目**,最低配置: +> +> **GPU**:GTX 1060或以上显卡 +> +> CPU: 支持AVX指令集 + #### 1. 下载安装Miniconda - Windows: Miniconda3-py38_4.11.0-Windows-x86_64.exe - -- MacOS:Miniconda3-py38_4.11.0-MacOSX-x86_64.pkg - - - Linux: Miniconda3-py38_4.11.0-Linux-x86_64.sh #### 2. 创建并激活虚机环境 @@ -118,6 +120,15 @@ conda activate videoEnv > 如果安装cuda 11.2,请对应安装8.1.1的cuDNN,并使用对应cuda版本的paddlepaddle,**30系列以上的显卡驱动可能不支持 cuda 11.2及以下版本的安装** + - 安装GPU版本Pytorch: + + ```shell + conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia + ``` + 或者使用 + ```shell + pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu117 + ``` - 安装其他依赖: diff --git a/backend/main.py b/backend/main.py index 3865858..64f8bee 100644 --- a/backend/main.py +++ b/backend/main.py @@ -12,6 +12,8 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import importlib import numpy as np import tempfile +import torch +from paddle import fluid from tqdm import tqdm from tools.infer import utility from tools.infer.predict_det import TextDetector @@ -119,7 +121,9 @@ class SubtitleRemover: self.video_temp_out_name = os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_{"".join(random.sample(uln, 8))}.mp4') self.video_writer = cv2.VideoWriter(self.video_temp_out_name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size) self.video_out_name = os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4') - + fluid.install_check.run_check() + if torch.cuda.is_available(): + print('使用GPU进行加速') @staticmethod def get_coordinates(dt_box): diff --git a/requirements.txt b/requirements.txt index e58e9fe..43a0abe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,6 @@ albumentations==0.5.2 filesplit==3.0.2 opencv-python==4.8.1.78 -torch==2.0.1 -torchvision==0.15.2 scikit-image==0.17.2 imgaug==0.4.0 kornia==0.5.0 @@ -17,4 +15,4 @@ pandas==2.0.3 webdataset==0.2.57 pytorch-lightning==1.2.9 numpy==1.23.1 - +protobuf==3.20.0