85 Commits
1.1.0 ... main

Author SHA1 Message Date
天涯古巷
a577efd3e1 Update README.md
Some checks failed
Docker Build and Push / check-secrets (push) Successful in 10s
Docker Build and Push / build-and-push (cuda, 11.8) (push) Has been skipped
Docker Build and Push / build-and-push (cuda, 12.6) (push) Has been skipped
Docker Build and Push / build-and-push (cuda, 12.8) (push) Has been skipped
Docker Build and Push / build-and-push (directml, latest) (push) Has been skipped
Build Windows CUDA 11.8 / build (push) Has been cancelled
Build Windows CUDA 12.6 / build (push) Has been cancelled
Build Windows CUDA 12.8 / build (push) Has been cancelled
Build Windows DirectML / build (push) Has been cancelled
2025-12-03 08:40:41 +08:00
方耀
2bb2d54314 ganxie lyons
Some checks failed
Docker Build and Push / check-secrets (push) Successful in 23s
Docker Build and Push / build-and-push (cuda, 11.8) (push) Has been skipped
Docker Build and Push / build-and-push (cuda, 12.6) (push) Has been skipped
Docker Build and Push / build-and-push (cuda, 12.8) (push) Has been skipped
Docker Build and Push / build-and-push (directml, latest) (push) Has been skipped
Build Windows CUDA 11.8 / build (push) Has been cancelled
Build Windows CUDA 12.6 / build (push) Has been cancelled
Build Windows CUDA 12.8 / build (push) Has been cancelled
Build Windows DirectML / build (push) Has been cancelled
2025-06-26 14:43:41 +08:00
天涯古巷
7049a24883 Update README.md
Some checks failed
Docker Build and Push / check-secrets (push) Successful in 12s
Docker Build and Push / build-and-push (cuda, 11.8) (push) Has been skipped
Docker Build and Push / build-and-push (cuda, 12.8) (push) Has been skipped
Docker Build and Push / build-and-push (cuda, 12.6) (push) Has been skipped
Docker Build and Push / build-and-push (directml, latest) (push) Has been skipped
Build Windows CUDA 11.8 / build (push) Has been cancelled
Build Windows CUDA 12.6 / build (push) Has been cancelled
Build Windows CUDA 12.8 / build (push) Has been cancelled
Build Windows DirectML / build (push) Has been cancelled
2025-05-13 10:35:27 +08:00
天涯古巷
7e0ed62b9e Merge pull request #151 from eritpchy/patch-2
Some checks failed
Docker Build and Push / check-secrets (push) Successful in 5s
Docker Build and Push / build-and-push (cuda, 12.6) (push) Has been skipped
Docker Build and Push / build-and-push (directml, latest) (push) Has been skipped
Docker Build and Push / build-and-push (cuda, 11.8) (push) Has been skipped
Docker Build and Push / build-and-push (cuda, 12.8) (push) Has been skipped
Build Windows CUDA 11.8 / build (push) Has been cancelled
Build Windows CUDA 12.6 / build (push) Has been cancelled
Build Windows CUDA 12.8 / build (push) Has been cancelled
Build Windows DirectML / build (push) Has been cancelled
2025年5月积累更新
2025-05-13 09:49:04 +08:00
Jason
7a5384ad95 更新README 2025-05-06 22:10:11 +08:00
Jason
b40ab7d921 发布docker镜像 2025-05-06 21:18:32 +08:00
Jason
7304905a84 支持CI自动发布 2025-04-25 13:01:51 +08:00
Jason
150923b409 Update requirements.txt 2025-04-25 13:01:40 +08:00
Jason
746db4bced DirectML版本支持运行STTN模型(Windows) 2025-04-25 13:01:31 +08:00
Jason
30e7913981 修复结束时inpaint_area报错 2025-04-25 13:01:26 +08:00
Jason
6fc6d35584 主界面显示版本号, 方便定位问题 2025-04-25 13:01:09 +08:00
Jason
87d8d9d3d7 适配torch 2.8.0 nightly build 2025-04-25 13:00:58 +08:00
Jason
3770ccdcfd 改用PaddleOCR, 跟随主线更新 2025-04-25 13:00:30 +08:00
Jason
29873c33ea 由于PySimpleGUI作者故意移除免费的旧版本,改用PySimpleGUI-4-foss 2025-04-25 13:00:19 +08:00
天涯古巷
196a1a8e7b Merge pull request #150 from YaoFANGUK/revert-130-main
Revert "由于PySimpleGUI作者故意移除免费的旧版本,改用PySimpleGUI-4-foss"
2025-04-25 11:03:29 +08:00
天涯古巷
53baf28326 Revert "由于PySimpleGUI作者故意移除免费的旧版本,改用PySimpleGUI-4-foss" 2025-04-25 11:03:16 +08:00
天涯古巷
991294c00f Merge pull request #130 from eritpchy/main
由于PySimpleGUI作者故意移除免费的旧版本,改用PySimpleGUI-4-foss
2025-04-25 07:36:30 +08:00
Jason
ea7e01e3aa 发布docker镜像 2025-04-24 23:45:54 +08:00
Jason
acdb150aa2 支持CI自动发布 2025-04-24 16:22:01 +08:00
Jason
285bfbafa7 Update requirements.txt 2025-04-24 16:14:43 +08:00
Jason
97b4159d38 DirectML版本支持运行STTN模型(Windows) 2025-04-24 15:56:13 +08:00
Jason
bb80445cf4 修复结束时inpaint_area报错 2025-04-24 15:53:28 +08:00
Jason
7e8d0b818b 主界面显示版本号, 方便定位问题 2025-04-24 15:51:21 +08:00
Jason
77758d258b 适配torch 2.8.0 nightly build 2025-04-24 15:47:23 +08:00
Jason
c60234f4ec 改用PaddleOCR, 跟随主线更新 2025-04-24 15:46:38 +08:00
Jason
38ff91fad7 由于PySimpleGUI作者故意移除免费的旧版本,改用PySimpleGUI-4-foss 2025-02-28 15:35:59 +08:00
天涯古巷
9f9cded1ff Update README.md 2025-02-19 09:34:37 +08:00
天涯古巷
54027ceeb0 Add files via upload 2025-02-19 09:33:02 +08:00
天涯古巷
1dc9036dee Update README.md 2025-02-19 09:24:46 +08:00
天涯古巷
09dbfa47f2 Update README.md 2025-01-03 10:29:42 +08:00
天涯古巷
aa83db0f98 Update README.md 2024-12-24 11:37:19 +08:00
天涯古巷
f86c8c9fe8 Update README.md 2024-12-15 17:26:49 +08:00
天涯古巷
3dc8f3bfe0 Update main.py 2024-10-23 16:41:01 +08:00
天涯古巷
b0ca454473 感谢张音乐赞助 2024-10-13 14:22:40 +08:00
天涯古巷
9f7fd5b341 Update README_en.md 2024-10-09 18:00:21 +08:00
天涯古巷
535fdecef4 Update README_en.md 2024-10-09 17:58:14 +08:00
天涯古巷
019f7f4517 Update .condarc 2024-10-09 17:45:49 +08:00
天涯古巷
8c5ea2e19d Update README.md 2024-10-09 17:22:15 +08:00
天涯古巷
330cf54e1a Update README.md 2024-10-09 17:15:21 +08:00
天涯古巷
7019572f7b Update README.md 2024-09-30 23:59:58 +08:00
天涯古巷
ee53840adb Update README.md 2024-09-30 22:45:23 +08:00
天涯古巷
96d744b3a7 Update README.md 2024-09-30 22:42:35 +08:00
天涯古巷
32c47873ab 升级paddle到2.6.1 2024-09-30 22:34:06 +08:00
天涯古巷
99770a32b9 Update README.md 2024-09-30 22:29:47 +08:00
天涯古巷
0f71d732e1 Update README.md 2024-09-29 07:18:52 +08:00
天涯古巷
f3a982710d Update README.md 2024-09-25 20:53:36 +08:00
天涯古巷
e07849ef87 Update README.md 2024-09-19 20:10:28 +08:00
天涯古巷
96099ea2d4 Merge pull request #91 from Brikarl/main
Update PySimpleGUI version to 4.70.1
2024-09-19 20:06:44 +08:00
Brikarl
f4c22dd420 Update requirements.txt 2024-09-19 19:45:50 +08:00
天涯古巷
a3452832ff Update README.md 2024-07-10 17:56:02 +08:00
天涯古巷
45e80bc9b0 感谢Talkuv app的支持 2024-07-10 17:53:01 +08:00
天涯古巷
4a09342987 Update README.md 2024-07-06 17:27:20 +08:00
天涯古巷
caf4cb27f4 Update README.md 2024-06-19 08:48:14 +08:00
天涯古巷
c927476c0f 感谢衣食父母陈的赞助 2024-06-06 16:58:44 +08:00
天涯古巷
61aa3d8f88 Update README.md 2024-01-17 19:20:30 +08:00
天涯古巷
67fdacdd8b Update README.md 2024-01-16 22:39:33 +08:00
天涯古巷
3d21963995 Update README.md 2024-01-09 17:18:07 +08:00
YaoFANGUK
a3dd7b797d 添加注释 2024-01-09 11:05:07 +08:00
YaoFANGUK
6b353455a0 添加文献 2024-01-09 09:24:19 +08:00
YaoFANGUK
d6736d9206 添加sttn训练代码 2024-01-08 17:48:21 +08:00
天涯古巷
4abc3409ac Update README.md 2024-01-07 07:07:58 +08:00
YaoFANGUK
2d1eb11fd6 增大视野,保证去除效果 2024-01-05 16:57:40 +08:00
YaoFANGUK
9a65c17a50 修改备注 2024-01-05 14:39:05 +08:00
天涯古巷
3ce8d7409b Update README.md 2024-01-04 14:43:10 +08:00
YaoFANGUK
f9dd30fddf 兼容安卓手机不能分享生成视频的问题 2024-01-04 14:33:33 +08:00
YaoFANGUK
fda9024084 update readme_en 2024-01-02 14:48:49 +08:00
天涯古巷
19141ff5c9 Update README.md 2024-01-02 14:37:31 +08:00
天涯古巷
97b54f6d9e Update README_en.md 2023-12-30 09:27:48 +08:00
天涯古巷
584e574795 Update README.md 2023-12-30 09:24:57 +08:00
YaoFANGUK
dad37eba7d 更新readme 2023-12-29 15:52:55 +08:00
天涯古巷
063a896cb9 Update README.md 2023-12-29 15:41:00 +08:00
天涯古巷
63d8378f36 Update README.md 2023-12-29 12:57:45 +08:00
天涯古巷
4cbfa9ebf0 Update README.md 2023-12-29 11:58:15 +08:00
YaoFANGUK
8a8088be1f 新增单张图片可以选择去除区域 2023-12-29 10:48:59 +08:00
YaoFANGUK
757cc5bf77 感谢赞助 2023-12-29 10:33:46 +08:00
YaoFANGUK
e536d6af86 统一分辨率视频批处理时候,可以使用字幕区域 2023-12-29 10:28:00 +08:00
YaoFANGUK
311701d3e6 新增gui批处理功能 2023-12-29 10:13:45 +08:00
YaoFANGUK
a7e62db98a 屏蔽windows删除文件报错 2023-12-29 09:33:07 +08:00
YaoFANGUK
945aeb9bc8 新增文件类型判断 2023-12-29 09:23:42 +08:00
YaoFANGUK
6ea7482344 minor 2023-12-29 08:46:36 +08:00
YaoFANGUK
ba396d9569 未传入字幕区域时,进行全屏处理 2023-12-29 08:45:20 +08:00
天涯古巷
22b021d9ae Update README.md 2023-12-28 20:05:32 +08:00
天涯古巷
49ae0029f5 Update README.md 2023-12-28 20:04:26 +08:00
天涯古巷
f89c109636 Update README.md 2023-12-28 19:39:09 +08:00
天涯古巷
055a08403f Update README.md 2023-12-28 15:55:45 +08:00
254 changed files with 2208 additions and 59826 deletions

View File

@@ -2,13 +2,9 @@ channels:
- defaults
show_channel_urls: true
default_channels:
- http://mirrors.aliyun.com/anaconda/pkgs/main
- http://mirrors.aliyun.com/anaconda/pkgs/r
- http://mirrors.aliyun.com/anaconda/pkgs/msys2
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2
custom_channels:
conda-forge: http://mirrors.aliyun.com/anaconda/cloud
msys2: http://mirrors.aliyun.com/anaconda/cloud
bioconda: http://mirrors.aliyun.com/anaconda/cloud
menpo: http://mirrors.aliyun.com/anaconda/cloud
pytorch: http://mirrors.aliyun.com/anaconda/cloud
simpleitk: http://mirrors.aliyun.com/anaconda/cloud
conda-forge: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
pytorch: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud

97
.github/workflows/build-docker.yml vendored Normal file
View File

@@ -0,0 +1,97 @@
name: Docker Build and Push
on:
push:
branches:
- '**'
workflow_dispatch:
jobs:
check-secrets:
runs-on: ubuntu-latest
outputs:
has_secrets: ${{ steps.check.outputs.has_secrets }}
steps:
- id: check
run: |
if [[ -n "${{ secrets.DOCKERHUB_USERNAME }}" && -n "${{ secrets.DOCKERHUB_TOKEN }}" ]]; then
echo "has_secrets=true" >> $GITHUB_OUTPUT
else
echo "has_secrets=false" >> $GITHUB_OUTPUT
echo "未设置 Docker Hub 凭据,将跳过整个 Action"
fi
build-and-push:
needs: check-secrets
if: needs.check-secrets.outputs.has_secrets == 'true'
runs-on: ubuntu-latest
strategy:
matrix:
include:
- type: cuda
version: "11.8"
- type: cuda
version: "12.6"
- type: cuda
version: "12.8"
- type: directml
version: "latest"
steps:
- name: Show system
run: |
echo -e "Total CPU cores\t: $(nproc)"
cat /proc/cpuinfo | grep 'model name'
ulimit -a
- name: Free disk space
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android
df -h
- name: 检出代码
uses: actions/checkout@v4
- name: 读取 VERSION
id: version
run: |
VERSION=$(sed -n 's/^VERSION = "\(.*\)"/\1/p' backend/config.py)
echo "VERSION=$VERSION" >> $GITHUB_ENV
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
shell: bash
- name: 设置 Docker Buildx
uses: docker/setup-buildx-action@v3
- name: 登录到 Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: 提取元数据
id: meta
uses: docker/metadata-action@v4
with:
images: ${{ secrets.DOCKERHUB_USERNAME }}/video-subtitle-remover
tags: |
type=raw,value=${{ env.VERSION }}-${{ matrix.type }}${{ matrix.type == 'cuda' && matrix.version || '' }}
- name: 构建并推送
uses: docker/build-push-action@v6
with:
context: .
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.meta.outputs.tags }}
build-args: |
${{ matrix.type == 'cuda' && format('CUDA_VERSION={0}', matrix.version) || '' }}
${{ matrix.type == 'directml' && 'USE_DIRECTML=1' || '' }}
- name: Docker Hub Description
uses: peter-evans/dockerhub-description@v4
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
repository: ${{ secrets.DOCKERHUB_USERNAME }}/video-subtitle-remover

View File

@@ -0,0 +1,94 @@
name: Build Windows CUDA 11.8
on:
push:
branches:
- '**'
workflow_dispatch:
inputs:
ssh:
description: 'SSH connection to Actions'
required: false
default: false
permissions:
contents: write
jobs:
build:
runs-on: windows-2019
steps:
- uses: actions/checkout@v4
- name: 读取 VERSION
id: version
run: |
VERSION=$(sed -n 's/^VERSION = "\(.*\)"/\1/p' backend/config.py)
echo "VERSION=$VERSION" >> $GITHUB_ENV
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
shell: bash
# - name: 检查 tag 是否已存在
# run: |
# TAG_NAME="${VERSION}"
# if git ls-remote --tags origin | grep -q "refs/tags/$TAG_NAME"; then
# echo "Tag $TAG_NAME 已存在,发布中止"
# exit 1
# fi
# shell: bash
- uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: 'pip' # caching pip dependencies
- run: pip install paddlepaddle==3.0.0
- run: pip install -r requirements.txt
- run: pip freeze > requirements.txt
- run: pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu118
- run: pip install QPT==1.0b8 setuptools
- name: 获取 site-packages 路径
shell: bash
run: |
SITE_PACKAGES=$(python -c "import site, os; print(os.path.join(site.getsitepackages()[0], 'Lib', 'site-packages'))")
SITE_PACKAGES_UNIX=$(cygpath -u "$SITE_PACKAGES")
echo "site-packages路径: $SITE_PACKAGES"
echo "site-packages UNIX路径: $SITE_PACKAGES_UNIX"
echo "SITE_PACKAGES_UNIX=$SITE_PACKAGES_UNIX" >> $GITHUB_ENV
echo "SITE_PACKAGES=$SITE_PACKAGES" >> $GITHUB_ENV
- name: 修复QPT内部错误
run: sed -i '98c\ try:\n dep = pkg.requires()\n except TypeError:\n continue' ${SITE_PACKAGES_UNIX}/qpt/kernel/qpackage.py
shell: bash
- name: Start SSH via tmate
if: (github.event.inputs.ssh == 'true' && github.event.inputs.ssh != 'false') || contains(github.event.action, 'ssh')
uses: mxschmitt/action-tmate@v3
- run: |
python backend/tools/makedist.py --cuda 11.8 && \
mv ../vsr_out ./vsr_out && \
cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/
env:
QPT_Action: "True"
shell: bash
- name: 上传 Debug 文件夹到 Artifacts
uses: actions/upload-artifact@v4
with:
name: vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8-debug
path: vsr_out/Debug/
- name: 上传 Release 文件夹到 Artifacts
uses: actions/upload-artifact@v4
with:
name: vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8-release
path: vsr_out/Release/
- name: 打包 Release 文件夹
run: |
cd vsr_out/Release
7z a -t7z -mx=9 -m0=LZMA2 -ms=on -mfb=64 -md=32m -mmt=on -v2000m vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8.7z * && \
# 检测是否只有一个分卷
if [ -f vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8.7z.001 ] && [ ! -f vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8.7z.002 ]; then \
mv vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8.7z.001 vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8.7z; fi
shell: bash
- name: Release
uses: softprops/action-gh-release@v1
with:
prerelease: true
tag_name: ${{ env.VERSION }}
target_commitish: ${{ github.sha }}
name: 硬字幕去除器 v${{ env.VERSION }}
files: |
vsr_out/Release/vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8.7z*

View File

@@ -0,0 +1,94 @@
name: Build Windows CUDA 12.6
on:
push:
branches:
- '**'
workflow_dispatch:
inputs:
ssh:
description: 'SSH connection to Actions'
required: false
default: false
permissions:
contents: write
jobs:
build:
runs-on: windows-2019
steps:
- uses: actions/checkout@v4
- name: 读取 VERSION
id: version
run: |
VERSION=$(sed -n 's/^VERSION = "\(.*\)"/\1/p' backend/config.py)
echo "VERSION=$VERSION" >> $GITHUB_ENV
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
shell: bash
# - name: 检查 tag 是否已存在
# run: |
# TAG_NAME="${VERSION}"
# if git ls-remote --tags origin | grep -q "refs/tags/$TAG_NAME"; then
# echo "Tag $TAG_NAME 已存在,发布中止"
# exit 1
# fi
# shell: bash
- uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: 'pip' # caching pip dependencies
- run: pip install paddlepaddle==3.0.0
- run: pip install -r requirements.txt
- run: pip freeze > requirements.txt
- run: pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126
- run: pip install QPT==1.0b8 setuptools
- name: 获取 site-packages 路径
shell: bash
run: |
SITE_PACKAGES=$(python -c "import site, os; print(os.path.join(site.getsitepackages()[0], 'Lib', 'site-packages'))")
SITE_PACKAGES_UNIX=$(cygpath -u "$SITE_PACKAGES")
echo "site-packages路径: $SITE_PACKAGES"
echo "site-packages UNIX路径: $SITE_PACKAGES_UNIX"
echo "SITE_PACKAGES_UNIX=$SITE_PACKAGES_UNIX" >> $GITHUB_ENV
echo "SITE_PACKAGES=$SITE_PACKAGES" >> $GITHUB_ENV
- name: 修复QPT内部错误
run: sed -i '98c\ try:\n dep = pkg.requires()\n except TypeError:\n continue' ${SITE_PACKAGES_UNIX}/qpt/kernel/qpackage.py
shell: bash
- name: Start SSH via tmate
if: (github.event.inputs.ssh == 'true' && github.event.inputs.ssh != 'false') || contains(github.event.action, 'ssh')
uses: mxschmitt/action-tmate@v3
- run: |
python backend/tools/makedist.py --cuda 12.6 && \
mv ../vsr_out ./vsr_out && \
cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/
env:
QPT_Action: "True"
shell: bash
- name: 上传 Debug 文件夹到 Artifacts
uses: actions/upload-artifact@v4
with:
name: vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6-debug
path: vsr_out/Debug/
- name: 上传 Release 文件夹到 Artifacts
uses: actions/upload-artifact@v4
with:
name: vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6-release
path: vsr_out/Release/
- name: 打包 Release 文件夹
run: |
cd vsr_out/Release
7z a -t7z -mx=9 -m0=LZMA2 -ms=on -mfb=64 -md=32m -mmt=on -v2000m vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6.7z * && \
# 检测是否只有一个分卷
if [ -f vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6.7z.001 ] && [ ! -f vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6.7z.002 ]; then \
mv vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6.7z.001 vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6.7z; fi
shell: bash
- name: Release
uses: softprops/action-gh-release@v1
with:
prerelease: true
tag_name: ${{ env.VERSION }}
target_commitish: ${{ github.sha }}
name: 硬字幕去除器 v${{ env.VERSION }}
files: |
vsr_out/Release/vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6.7z*

View File

@@ -0,0 +1,94 @@
name: Build Windows CUDA 12.8
on:
push:
branches:
- '**'
workflow_dispatch:
inputs:
ssh:
description: 'SSH connection to Actions'
required: false
default: false
permissions:
contents: write
jobs:
build:
runs-on: windows-2019
steps:
- uses: actions/checkout@v4
- name: 读取 VERSION
id: version
run: |
VERSION=$(sed -n 's/^VERSION = "\(.*\)"/\1/p' backend/config.py)
echo "VERSION=$VERSION" >> $GITHUB_ENV
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
shell: bash
# - name: 检查 tag 是否已存在
# run: |
# TAG_NAME="${VERSION}"
# if git ls-remote --tags origin | grep -q "refs/tags/$TAG_NAME"; then
# echo "Tag $TAG_NAME 已存在,发布中止"
# exit 1
# fi
# shell: bash
- uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: 'pip' # caching pip dependencies
- run: pip install paddlepaddle==3.0.0
- run: pip install -r requirements.txt
- run: pip freeze > requirements.txt
- run: pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu128
- run: pip install QPT==1.0b8 setuptools
- name: 获取 site-packages 路径
shell: bash
run: |
SITE_PACKAGES=$(python -c "import site, os; print(os.path.join(site.getsitepackages()[0], 'Lib', 'site-packages'))")
SITE_PACKAGES_UNIX=$(cygpath -u "$SITE_PACKAGES")
echo "site-packages路径: $SITE_PACKAGES"
echo "site-packages UNIX路径: $SITE_PACKAGES_UNIX"
echo "SITE_PACKAGES_UNIX=$SITE_PACKAGES_UNIX" >> $GITHUB_ENV
echo "SITE_PACKAGES=$SITE_PACKAGES" >> $GITHUB_ENV
- name: 修复QPT内部错误
run: sed -i '98c\ try:\n dep = pkg.requires()\n except TypeError:\n continue' ${SITE_PACKAGES_UNIX}/qpt/kernel/qpackage.py
shell: bash
- name: Start SSH via tmate
if: (github.event.inputs.ssh == 'true' && github.event.inputs.ssh != 'false') || contains(github.event.action, 'ssh')
uses: mxschmitt/action-tmate@v3
- run: |
python backend/tools/makedist.py --cuda 12.8 && \
mv ../vsr_out ./vsr_out && \
cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/
env:
QPT_Action: "True"
shell: bash
- name: 上传 Debug 文件夹到 Artifacts
uses: actions/upload-artifact@v4
with:
name: vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8-debug
path: vsr_out/Debug/
- name: 上传 Release 文件夹到 Artifacts
uses: actions/upload-artifact@v4
with:
name: vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8-release
path: vsr_out/Release/
- name: 打包 Release 文件夹
run: |
cd vsr_out/Release
7z a -t7z -mx=9 -m0=LZMA2 -ms=on -mfb=64 -md=32m -mmt=on -v2000m vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8.7z * && \
# 检测是否只有一个分卷
if [ -f vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8.7z.001 ] && [ ! -f vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8.7z.002 ]; then \
mv vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8.7z.001 vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8.7z; fi
shell: bash
- name: Release
uses: softprops/action-gh-release@v1
with:
prerelease: true
tag_name: ${{ env.VERSION }}
target_commitish: ${{ github.sha }}
name: 硬字幕去除器 v${{ env.VERSION }}
files: |
vsr_out/Release/vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8.7z*

View File

@@ -0,0 +1,94 @@
name: Build Windows DirectML
on:
push:
branches:
- '**'
workflow_dispatch:
inputs:
ssh:
description: 'SSH connection to Actions'
required: false
default: false
permissions:
contents: write
jobs:
build:
runs-on: windows-2019
steps:
- uses: actions/checkout@v4
- name: 读取 VERSION
id: version
run: |
VERSION=$(sed -n 's/^VERSION = "\(.*\)"/\1/p' backend/config.py)
echo "VERSION=$VERSION" >> $GITHUB_ENV
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
shell: bash
# - name: 检查 tag 是否已存在
# run: |
# TAG_NAME="${VERSION}"
# if git ls-remote --tags origin | grep -q "refs/tags/$TAG_NAME"; then
# echo "Tag $TAG_NAME 已存在,发布中止"
# exit 1
# fi
# shell: bash
- uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: 'pip' # caching pip dependencies
- run: pip install paddlepaddle==3.0.0
- run: pip install -r requirements.txt
- run: pip freeze > requirements.txt
- run: pip install torch_directml==0.2.5.dev240914
- run: pip install QPT==1.0b8 setuptools
- name: 获取 site-packages 路径
shell: bash
run: |
SITE_PACKAGES=$(python -c "import site, os; print(os.path.join(site.getsitepackages()[0], 'Lib', 'site-packages'))")
SITE_PACKAGES_UNIX=$(cygpath -u "$SITE_PACKAGES")
echo "site-packages路径: $SITE_PACKAGES"
echo "site-packages UNIX路径: $SITE_PACKAGES_UNIX"
echo "SITE_PACKAGES_UNIX=$SITE_PACKAGES_UNIX" >> $GITHUB_ENV
echo "SITE_PACKAGES=$SITE_PACKAGES" >> $GITHUB_ENV
- name: 修复QPT内部错误
run: sed -i '98c\ try:\n dep = pkg.requires()\n except TypeError:\n continue' ${SITE_PACKAGES_UNIX}/qpt/kernel/qpackage.py
shell: bash
- name: Start SSH via tmate
if: (github.event.inputs.ssh == 'true' && github.event.inputs.ssh != 'false') || contains(github.event.action, 'ssh')
uses: mxschmitt/action-tmate@v3
- run: |
python backend/tools/makedist.py --directml && \
mv ../vsr_out ./vsr_out && \
cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/
env:
QPT_Action: "True"
shell: bash
- name: 上传 Debug 文件夹到 Artifacts
uses: actions/upload-artifact@v4
with:
name: vsr-v${{ env.VERSION }}-windows-directml-debug
path: vsr_out/Debug/
- name: 上传 Release 文件夹到 Artifacts
uses: actions/upload-artifact@v4
with:
name: vsr-v${{ env.VERSION }}-windows-directml-release
path: vsr_out/Release/
- name: 打包 Release 文件夹
run: |
cd vsr_out/Release
7z a -t7z -mx=9 -m0=LZMA2 -ms=on -mfb=64 -md=32m -mmt=on -v2000m vsr-v${{ env.VERSION }}-windows-directml.7z * && \
# 检测是否只有一个分卷
if [ -f vsr-v${{ env.VERSION }}-windows-directml.7z.001 ] && [ ! -f vsr-v${{ env.VERSION }}-windows-directml.7z.002 ]; then \
mv vsr-v${{ env.VERSION }}-windows-directml.7z.001 vsr-v${{ env.VERSION }}-windows-directml.7z; fi
shell: bash
- name: Release
uses: softprops/action-gh-release@v1
with:
prerelease: true
tag_name: ${{ env.VERSION }}
target_commitish: ${{ github.sha }}
name: 硬字幕去除器 v${{ env.VERSION }}
files: |
vsr_out/Release/vsr-v${{ env.VERSION }}-windows-directml.7z*

2
.gitignore vendored
View File

@@ -372,3 +372,5 @@ test*_no_sub*.mp4
/backend/models/video/ProPainter.pth
/backend/models/big-lama/big-lama.pt
/test/debug/
/backend/tools/train/release_model/
model.onnx

273
README.md
View File

@@ -3,7 +3,7 @@
## 项目简介
![License](https://img.shields.io/badge/License-Apache%202-red.svg)
![python version](https://img.shields.io/badge/Python-3.8+-blue.svg)
![python version](https://img.shields.io/badge/Python-3.11+-blue.svg)
![support os](https://img.shields.io/badge/OS-Windows/macOS/Linux-green.svg)
Video-subtitle-remover (VSR) 是一款基于AI技术将视频中的硬字幕去除的软件。
@@ -12,13 +12,14 @@ Video-subtitle-remover (VSR) 是一款基于AI技术将视频中的硬字幕
- 通过超强AI算法模型对去除字幕文本的区域进行填充非相邻像素填充与马赛克去除
- 支持自定义字幕位置,仅去除定义位置中的字幕(传入位置)
- 支持全视频自动去除所有文本(不传入位置)
- 支持多选图片批量去除水印文本
<p style="text-align:center;"><img src="https://github.com/YaoFANGUK/video-subtitle-remover/raw/main/design/demo.png" alt="demo.png"/></p>
**使用说明:**
- 有使用问题请加群讨论QQ群806152575
- 直接下载压缩包解压运行如果不能运行再按照下面的教程尝试源码安装conda环境运行
- 有使用问题请加群讨论QQ群210150985已满、806152575已满、816881808已满、295894827
- 直接下载压缩包解压运行如果不能运行再按照下面的教程尝试源码安装conda环境运行
**下载地址:**
@@ -26,9 +27,36 @@ Windows GPU版本v1.1.0GPU
- 百度网盘: <a href="https://pan.baidu.com/s/1zR6CjRztmOGBbOkqK8R1Ng?pwd=vsr1">vsr_windows_gpu_v1.1.0.zip</a> 提取码:**vsr1**
- Google Drive: <a href="https://drive.google.com/drive/folders/1NRgLNoHHOmdO4GxLhkPbHsYfMOB_3Elr?usp=sharing">vsr_windows_gpu_v1.1.0.zip</a>
- Google Drive: <a href="https://drive.google.com/drive/folders/1NRgLNoHHOmdO4GxLhkPbHsYfMOB_3Elr?usp=sharing">vsr_windows_gpu_v1.1.0.zip</a>
> 仅供具有Nvidia显卡的用户使用(AMD的显卡不行)
**预构建包对比说明**
| 预构建包名 | Python | Paddle | Torch | 环境 | 支持的计算能力范围|
|---------------|------------|--------------|--------------|-----------------------------|----------|
| `vsr-windows-directml.7z` | 3.12 | 3.0.0 | 2.4.1 | Windows 非Nvidia显卡 | 通用 |
| `vsr-windows-nvidia-cuda-11.8.7z` | 3.12 | 3.0.0 | 2.7.0 | CUDA 11.8 | 3.5 8.9 |
| `vsr-windows-nvidia-cuda-12.6.7z` | 3.12 | 3.0.0 | 2.7.0 | CUDA 12.6 | 5.0 8.9 |
| `vsr-windows-nvidia-cuda-12.8.7z` | 3.12 | 3.0.0 | 2.7.0 | CUDA 12.8 | 5.0 9.0+ |
> NVIDIA官方提供了各GPU型号的计算能力列表您可以参考链接: [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) 查看你的GPU适合哪个CUDA版本
**Docker版本**
```shell
# Nvidia 10 20 30系显卡
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda11.8
# Nvidia 40系显卡
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda12.6
# Nvidia 50系显卡
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda12.8
# AMD / Intel 独显 集显
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-directml
# 演示视频, 输入
/vsr/test/test.mp4
docker cp vsr:/vsr/test/test_no_sub.mp4 ./
```
## 演示
@@ -42,114 +70,98 @@ Windows GPU版本v1.1.0GPU
## 源码使用说明
> **无Nvidia显卡请勿使用本项目**,最低配置:
>
> **GPU**GTX 1060或以上显卡
>
> CPU: 支持AVX指令集
#### 1. 下载安装Miniconda
#### 1. 安装 Python
- Windows: <a href="https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Windows-x86_64.exe">Miniconda3-py38_4.11.0-Windows-x86_64.exe</a>
请确保您已经安装了 Python 3.12+。
- Linux: <a href="https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Linux-x86_64.sh">Miniconda3-py38_4.11.0-Linux-x86_64.sh</a>
- Windows 用户可以前往 [Python 官网](https://www.python.org/downloads/windows/) 下载并安装 Python。
- MacOS 用户可以使用 Homebrew 安装:
```shell
brew install python@3.12
```
- Linux 用户可以使用包管理器安装,例如 Ubuntu/Debian
```shell
sudo apt update && sudo apt install python3.12 python3.12-venv python3.12-dev
```
#### 2. 创建并激活虚机环境
#### 2. 安装依赖文件
1切换到源码所在目录
请使用虚拟环境来管理项目依赖,避免与系统环境冲突。
1创建虚拟环境并激活
```shell
python -m venv videoEnv
```
- Windows
```shell
videoEnv\\Scripts\\activate
```
- MacOS/Linux
```shell
source videoEnv/bin/activate
```
#### 3. 创建并激活项目目录
切换到源码所在目录:
```shell
cd <源码所在目录>
```
> 例如:如果的源代码放在D盘的tools文件下并且源代码的文件夹名为video-subtitle-remover输入 ```cd D:/tools/video-subtitle-remover-main```
> 例如:如果的源代码放在 D 盘的 tools 文件下,并且源代码的文件夹名为 video-subtitle-remover输入
> ```shell
> cd D:/tools/video-subtitle-remover-main
> ```
2创建激活conda环境
```shell
conda create -n videoEnv python=3.8
```
#### 4. 安装合适的运行环境
```shell
conda activate videoEnv
```
本项目支持 CUDANVIDIA显卡加速和 DirectMLAMD、Intel等GPU/APU加速两种运行模式。
#### 3. 安装依赖文件
##### (1) CUDANVIDIA 显卡用户)
请确保你已经安装 python 3.8+使用conda创建项目虚拟环境并激活环境 (建议创建虚拟环境运行,以免后续出现问题)
> 请确保您的 NVIDIA 显卡驱动支持所选 CUDA 版本。
- 安装CUDA和cuDNN
- 推荐 CUDA 11.8,对应 cuDNN 8.6.0。
<details>
<summary>Linux用户</summary>
<h5>(1) 下载CUDA 11.7</h5>
<pre><code>wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run</code></pre>
<h5>(2) 安装CUDA 11.7</h5>
<pre><code>sudo sh cuda_11.7.0_515.43.04_linux.run</code></pre>
<p>1. 输入accept</p>
<img src="https://i.328888.xyz/2023/03/31/iwVoeH.png" width="500" alt="">
<p>2. 选中CUDA Toolkit 11.7如果你没有安装nvidia驱动则选中Driver如果你已经安装了nvidia驱动请不要选中driver之后选中install回车</p>
<img src="https://i.328888.xyz/2023/03/31/iwVThJ.png" width="500" alt="">
<p>3. 添加环境变量</p>
<p>在 ~/.bashrc 加入以下内容</p>
<pre><code># CUDA
export PATH=/usr/local/cuda-11.7/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-11.7/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}</code></pre>
<p>使其生效</p>
<pre><code>source ~/.bashrc</code></pre>
<h5>(3) 下载cuDNN 8.4.1</h5>
<p>国内:<a href="https://pan.baidu.com/s/1Gd_pSVzWfX1G7zCuqz6YYA">cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz</a> 提取码57mg</p>
<p>国外:<a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz">cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz</a></p>
<h5>(4) 安装cuDNN 8.4.1</h5>
<pre><code> tar -xf cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz
mv cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive cuda
sudo cp ./cuda/include/* /usr/local/cuda-11.7/include/
sudo cp ./cuda/lib/* /usr/local/cuda-11.7/lib64/
sudo chmod a+r /usr/local/cuda-11.7/lib64/*
sudo chmod a+r /usr/local/cuda-11.7/include/*</code></pre>
</details>
- 安装 CUDA
- Windows[CUDA 11.8 下载](https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe)
- Linux
```shell
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
sudo sh cuda_11.8.0_520.61.05_linux.run
```
- MacOS 不支持 CUDA。
<details>
<summary>Windows用户</summary>
<h5>(1) 下载CUDA 11.7</h5>
<a href="https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_516.01_windows.exe">cuda_11.7.0_516.01_windows.exe</a>
<h5>(2) 安装CUDA 11.7</h5>
<h5>(3) 下载cuDNN 8.2.4</h5>
<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-windows-x64-v8.2.4.15.zip">cudnn-windows-x64-v8.2.4.15.zip</a></p>
<h5>(4) 安装cuDNN 8.2.4</h5>
<p>
将cuDNN解压后的cuda文件夹中的bin, include, lib目录下的文件复制到C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\对应目录下
</p>
</details>
- 安装 cuDNNCUDA 11.8 对应 cuDNN 8.6.0
- [Windows cuDNN 8.6.0 下载](https://developer.download.nvidia.cn/compute/redist/cudnn/v8.6.0/local_installers/11.8/cudnn-windows-x86_64-8.6.0.163_cuda11-archive.zip)
- [Linux cuDNN 8.6.0 下载](https://developer.download.nvidia.cn/compute/redist/cudnn/v8.6.0/local_installers/11.8/cudnn-linux-x86_64-8.6.0.163_cuda11-archive.tar.xz)
- 安装方法请参考 NVIDIA 官方文档。
- 安装GPU版本Paddlepaddle:
- windows:
```shell
python -m pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/windows/mkl/avx/stable.html
```
- Linux:
```shell
python -m pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
```
- 安装GPU版本Pytorch:
```shell
conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia
- 安装 PaddlePaddle GPU 版本CUDA 11.8
```shell
pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
```
或者使用
```shell
pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu117
- 安装 Torch GPU 版本CUDA 11.8
```shell
pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu118
```
- 安装其他依赖:
- 安装其他依赖
```shell
pip install -r requirements.txt
```
##### (2) DirectMLAMD、Intel等GPU/APU加速卡用户
- 适用于 Windows 设备的 AMD/NVIDIA/Intel GPU。
- 安装 ONNX Runtime DirectML 版本:
```shell
pip install paddlepaddle==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
pip install -r requirements.txt
pip install torch_directml==0.2.5.dev240914
```
#### 4. 运行程序
@@ -166,29 +178,80 @@ python ./backend/main.py
```
## 常见问题
1. CondaHTTPError
1. 提取速度慢怎么办
修改backend/config.py中的参数可以大幅度提高去除速度
```python
MODE = InpaintMode.STTN # 设置为STTN算法
STTN_SKIP_DETECTION = True # 跳过字幕检测,跳过后可能会导致要去除的字幕遗漏或者误伤不需要去除字幕的视频帧
```
2. 视频去除效果不好怎么办
修改backend/config.py中的参数尝试不同的去除算法算法介绍
> - InpaintMode.STTN 算法:对于真人视频效果较好,速度快,可以跳过字幕检测
> - InpaintMode.LAMA 算法:对于图片效果最好,对动画类视频效果好,速度一般,不可以跳过字幕检测
> - InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好
- 使用STTN算法
```python
MODE = InpaintMode.STTN # 设置为STTN算法
# 相邻帧数, 调大会增加显存占用,效果变好
STTN_NEIGHBOR_STRIDE = 10
# 参考帧长度, 调大会增加显存占用,效果变好
STTN_REFERENCE_LENGTH = 10
# 设置STTN算法最大同时处理的帧数量设置越大速度越慢但效果越好
# 要保证STTN_MAX_LOAD_NUM大于STTN_NEIGHBOR_STRIDE和STTN_REFERENCE_LENGTH
STTN_MAX_LOAD_NUM = 30
```
- 使用LAMA算法
```python
MODE = InpaintMode.LAMA # 设置为STTN算法
LAMA_SUPER_FAST = False # 保证效果
```
> 如果对模型去字幕的效果不满意可以查看design文件夹里面的训练方法利用backend/tools/train里面的代码进行训练然后将训练的模型替换旧模型即可
3. CondaHTTPError
将项目中的.condarc放在用户目录下(C:/Users/<你的用户名>),如果用户目录已经存在该文件则覆盖
解决方案https://zhuanlan.zhihu.com/p/260034241
2. 7z文件解压错误
4. 7z文件解压错误
解决方案升级7-zip解压程序到最新版本
3. 4090使用cuda 11.7跑不起来
解决方案改用cuda 11.8
## 赞助
<img src="https://i.imgur.com/EMCP5Lv.jpeg" width="600">
| 捐赠者 | 累计捐赠金额 | 赞助席位 |
| --- | --- | --- |
| 坤V | 400.00 RMB | 金牌赞助席位 |
| 陈凯 | 50.00 RMB | 银牌赞助席位 |
| Tshuang | 20.00 RMB | 牌赞助席位 |
| 很奇异| 15.00 RMB | 牌赞助席位 |
| 何斐| 10.00 RMB | 牌赞助席位 |
| 长缨在手| 6.00 RMB | 牌赞助席位 |
| Leo| 1.00 RMB | 牌赞助席位 |
<img src="https://github.com/YaoFANGUK/video-subtitle-extractor/raw/main/design/sponsor.png" width="600">
| 捐赠者 | 累计捐赠金额 | 赞助席位 |
|---------------------------|------------| --- |
| 坤V | 400.00 RMB | 牌赞助席位 |
| Jenkit | 200.00 RMB | 牌赞助席位 |
| 子车松兰 | 188.00 RMB | 牌赞助席位 |
| 落花未逝 | 100.00 RMB | 牌赞助席位 |
| 张音乐 | 100.00 RMB | 牌赞助席位 |
| 麦格 | 100.00 RMB | 金牌赞助席位 |
| 无痕 | 100.00 RMB | 金牌赞助席位 |
| wr | 100.00 RMB | 金牌赞助席位 |
| 陈 | 100.00 RMB | 金牌赞助席位 |
| lyons | 100.00 RMB | 金牌赞助席位 |
| TalkLuv | 50.00 RMB | 银牌赞助席位 |
| 陈凯 | 50.00 RMB | 银牌赞助席位 |
| Freeman | 30.00 RMB | 银牌赞助席位 |
| Tshuang | 20.00 RMB | 银牌赞助席位 |
| 很奇异 | 15.00 RMB | 银牌赞助席位 |
| 郭鑫 | 12.00 RMB | 银牌赞助席位 |
| 生活不止眼前的苟且 | 10.00 RMB | 铜牌赞助席位 |
| 何斐 | 10.00 RMB | 铜牌赞助席位 |
| 老猫 | 8.80 RMB | 铜牌赞助席位 |
| 伍六七 | 7.77 RMB | 铜牌赞助席位 |
| 长缨在手 | 6.00 RMB | 铜牌赞助席位 |
| 无忌 | 6.00 RMB | 铜牌赞助席位 |
| Stephen | 2.00 RMB | 铜牌赞助席位 |
| Leo | 1.00 RMB | 铜牌赞助席位 |

View File

@@ -3,7 +3,7 @@
## Project Introduction
![License](https://img.shields.io/badge/License-Apache%202-red.svg)
![python version](https://img.shields.io/badge/Python-3.8+-blue.svg)
![python version](https://img.shields.io/badge/Python-3.11+-blue.svg)
![support os](https://img.shields.io/badge/OS-Windows/macOS/Linux-green.svg)
Video-subtitle-remover (VSR) is an AI-based software that removes hardcoded subtitles from videos. It mainly implements the following functionalities:
@@ -12,6 +12,7 @@ Video-subtitle-remover (VSR) is an AI-based software that removes hardcoded subt
- Fills in the removed subtitle text area using a powerful AI algorithm model (non-adjacent pixel filling and mosaic removal).
- Supports custom subtitle positions by only removing subtitles in the defined location (input position).
- Supports automatic removal of all text throughout the entire video (without inputting a position).
- Supports multi-selection of images for batch removal of watermark text.
<p style="text-align:center;"><img src="https://github.com/YaoFANGUK/video-subtitle-remover/raw/main/design/demo.png" alt="demo.png"/></p>
@@ -25,7 +26,36 @@ Windows GPU Version v1.1.0 (GPU):
- Google Drive: <a href="https://drive.google.com/drive/folders/1NRgLNoHHOmdO4GxLhkPbHsYfMOB_3Elr?usp=sharing">vsr_windows_gpu_v1.1.0.zip</a>
> For use only by users with Nvidia graphics cards (AMD graphics cards are not supported).
**Pre-built Package Comparison**:
| Pre-built Package Name | Python | Paddle | Torch | Environment | Supported Compute Capability Range |
|----------------------------------|--------|--------|--------|-----------------------------------|------------------------------------|
| `vse-windows-directml.7z` | 3.12 | 3.0.0 | 2.4.1 | Windows without Nvidia GPU | Universal |
| `vse-windows-nvidia-cuda-11.8.7z`| 3.12 | 3.0.0 | 2.7.0 | CUDA 11.8 | 3.5 8.9 |
| `vse-windows-nvidia-cuda-12.6.7z`| 3.12 | 3.0.0 | 2.7.0 | CUDA 12.6 | 5.0 8.9 |
| `vse-windows-nvidia-cuda-12.8.7z`| 3.12 | 3.0.0 | 2.7.0 | CUDA 12.8 | 5.0 9.0+ |
> NVIDIA provides a list of supported compute capabilities for each GPU model. You can refer to the following link: [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) to check which CUDA version is compatible with your GPU.
**Docker Versions:**
```shell
# Nvidia 10, 20, 30 Series Graphics Cards
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda11.8
# Nvidia 40 Series Graphics Cards
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda12.6
# Nvidia 50 Series Graphics Cards
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda12.8
# AMD / Intel Dedicated or Integrated Graphics
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-directml
# Demo video, input
/vsr/test/test.mp4
docker cp vsr:/vsr/test/test_no_sub.mp4 ./
```
## Demonstration
@@ -39,117 +69,98 @@ Windows GPU Version v1.1.0 (GPU):
## Source Code Usage Instructions
> **Do not use this project without an Nvidia graphics card**. The minimum requirements are:
>
> **GPU**: GTX 1060 or higher graphics card
>
> CPU: Supports AVX instruction set
#### 1. Install Python
#### 1. Download and install Miniconda
Please ensure that you have installed Python 3.12+.
- Windows: <a href="https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Windows-x86_64.exe">Miniconda3-py38_4.11.0-Windows-x86_64.exe</a>
- Windows users can go to the [Python official website](https://www.python.org/downloads/windows/) to download and install Python.
- MacOS users can install using Homebrew:
```shell
brew install python@3.12
```
- Linux users can install via the package manager, such as on Ubuntu/Debian:
```shell
sudo apt update && sudo apt install python3.12 python3.12-venv python3.12-dev
```
- Linux: <a href="https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Linux-x86_64.sh">Miniconda3-py38_4.11.0-Linux-x86_64.sh</a>
#### 2. Install Dependencies
#### 2. Create and activate a virtual environment
It is recommended to use a virtual environment to manage project dependencies to avoid conflicts with the system environment.
(1) Switch to the source code directory:
(1) Create and activate the virtual environment:
```shell
python -m venv videoEnv
```
- Windows:
```shell
videoEnv\\Scripts\\activate
```
- MacOS/Linux:
```shell
source videoEnv/bin/activate
```
#### 3. Create and Activate Project Directory
Change to the directory where your source code is located:
```shell
cd <source_code_directory>
```
> For example, if your source code is in the `tools` folder on the D drive and the folder name is `video-subtitle-remover`, use:
> ```shell
> cd D:/tools/video-subtitle-remover-main
> ```
> For example, if your source code is in the `tools` folder on drive D, and the source code folder name is `video-subtitle-remover`, enter `cd D:/tools/video-subtitle-remover-main`.
#### 4. Install the Appropriate Runtime Environment
(2) Create and activate the conda environment:
This project supports two runtime modes: CUDA (NVIDIA GPU acceleration) and DirectML (AMD, Intel, and other GPUs/APUs).
```shell
conda create -n videoEnv python=3.8
```
##### (1) CUDA (For NVIDIA GPU users)
```shell
conda activate videoEnv
```
> Make sure your NVIDIA GPU driver supports the selected CUDA version.
#### 3. Install dependencies
Please make sure you have already installed Python 3.8+, use conda to create a project virtual environment and activate the environment (it is recommended to create a virtual environment to run to avoid subsequent problems).
- Install **CUDA** and **cuDNN**
<details>
<summary>Linux</summary>
<h5>(1) Download CUDA 11.7</h5>
<pre><code>wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run</code></pre>
<h5>(2) Install CUDA 11.7</h5>
<pre><code>sudo sh cuda_11.7.0_515.43.04_linux.run</code></pre>
<p>1. Input accept</p>
<img src="https://i.328888.xyz/2023/03/31/iwVoeH.png" width="500" alt="">
<p>2. make sure CUDA Toolkit 11.7 is chosen (If you have already installed driver, do not select Driver)</p>
<img src="https://i.328888.xyz/2023/03/31/iwVThJ.png" width="500" alt="">
<p>3. Add environment variables</p>
<p>add the following content in <strong>~/.bashrc</strong></p>
<pre><code># CUDA
export PATH=/usr/local/cuda-11.7/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-11.7/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}</code></pre>
<p>Make sure it works</p>
<pre><code>source ~/.bashrc</code></pre>
<h5>(3) Download cuDNN 8.4.1</h5>
<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz">cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz</a></p>
<h5>(4) Install cuDNN 8.4.1</h5>
<pre><code> tar -xf cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz
mv cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive cuda
sudo cp ./cuda/include/* /usr/local/cuda-11.7/include/
sudo cp ./cuda/lib/* /usr/local/cuda-11.7/lib64/
sudo chmod a+r /usr/local/cuda-11.7/lib64/*
sudo chmod a+r /usr/local/cuda-11.7/include/*</code></pre>
</details>
<details>
<summary>Windows</summary>
<h5>(1) Download CUDA 11.7</h5>
<a href="https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_516.01_windows.exe">cuda_11.7.0_516.01_windows.exe</a>
<h5>(2) Install CUDA 11.7</h5>
<h5>(3) Download cuDNN 8.2.4</h5>
<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-windows-x64-v8.2.4.15.zip">cudnn-windows-x64-v8.2.4.15.zip</a></p>
<h5>(4) Install cuDNN 8.2.4</h5>
<p>
unzip "cudnn-windows-x64-v8.2.4.15.zip", then move all files in "bin, include, lib" in cuda
directory to C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\
</p>
</details>
- Install GPU version of Paddlepaddle:
- windows:
```shell
python -m pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/windows/mkl/avx/stable.html
```
- Recommended CUDA 11.8, corresponding to cuDNN 8.6.0.
- Install CUDA:
- Windows: [Download CUDA 11.8](https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe)
- Linux:
```shell
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
sudo sh cuda_11.8.0_520.61.05_linux.run
```
- CUDA is not supported on MacOS.
```shell
python -m pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
```
- Install cuDNN (CUDA 11.8 corresponds to cuDNN 8.6.0):
- [Windows cuDNN 8.6.0 Download](https://developer.download.nvidia.cn/compute/redist/cudnn/v8.6.0/local_installers/11.8/cudnn-windows-x86_64-8.6.0.163_cuda11-archive.zip)
- [Linux cuDNN 8.6.0 Download](https://developer.download.nvidia.cn/compute/redist/cudnn/v8.6.0/local_installers/11.8/cudnn-linux-x86_64-8.6.0.163_cuda11-archive.tar.xz)
- Follow the installation guide in the NVIDIA official documentation.
- Install GPU version of Pytorch:
```shell
conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia
- Install PaddlePaddle GPU version (CUDA 11.8):
```shell
pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
```
or use
```shell
pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu117
- Install Torch GPU version (CUDA 11.8):
```shell
pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu118
```
- Install other dependencies:
```shell
pip install -r requirements.txt
```
##### (2) DirectML (For AMD, Intel, and other GPU/APU users)
- Suitable for Windows devices with AMD/NVIDIA/Intel GPUs.
- Install ONNX Runtime DirectML version:
```shell
pip install paddlepaddle==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
pip install -r requirements.txt
pip install -r requirements_directml.txt
```
#### 4. Run the program
@@ -166,12 +177,52 @@ python ./backend/main.py
```
## Common Issues
1. CondaHTTPError
1. How to deal with slow removal speed
You can greatly increase the removal speed by modifying the parameters in backend/config.py:
```python
MODE = InpaintMode.STTN # Set to STTN algorithm
STTN_SKIP_DETECTION = True # Skip subtitle detection
```
2. What to do if the video removal results are not satisfactory
Modify the values in backend/config.py and try different removal algorithms. Here is an introduction to the algorithms:
> - **InpaintMode.STTN** algorithm: Good for live-action videos and fast in speed, capable of skipping subtitle detection
> - **InpaintMode.LAMA** algorithm: Best for images and effective for animated videos, moderate speed, unable to skip subtitle detection
> - **InpaintMode.PROPAINTER** algorithm: Consumes a significant amount of VRAM, slower in speed, works better for videos with very intense movement
- Using the STTN algorithm
```python
MODE = InpaintMode.STTN # Set to STTN algorithm
# Number of neighboring frames, increasing this will increase memory usage and improve the result
STTN_NEIGHBOR_STRIDE = 10
# Length of reference frames, increasing this will increase memory usage and improve the result
STTN_REFERENCE_LENGTH = 10
# Set the maximum number of frames processed simultaneously by the STTN algorithm, a larger value leads to slower processing but better results
# Ensure that STTN_MAX_LOAD_NUM is greater than STTN_NEIGHBOR_STRIDE and STTN_REFERENCE_LENGTH
STTN_MAX_LOAD_NUM = 30
```
- Using the LAMA algorithm
```python
MODE = InpaintMode.LAMA # Set to LAMA algorithm
LAMA_SUPER_FAST = False # Ensure quality
```
3. CondaHTTPError
Place the .condarc file from the project in the user directory (C:/Users/<your_username>). If the file already exists in the user directory, overwrite it.
Solution: https://zhuanlan.zhihu.com/p/260034241
2. 7z file extraction error
4. 7z file extraction error
Solution: Upgrade the 7-zip extraction program to the latest version.

View File

@@ -7,12 +7,20 @@ import logging
import platform
import stat
from fsplit.filesplit import Filesplit
import paddle
import onnxruntime as ort
# 项目版本号
VERSION = "1.1.1"
# ×××××××××××××××××××× [不要改] start ××××××××××××××××××××
paddle.disable_signal_handler()
logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印
logging.disable(logging.WARNING) # 关闭WARNING日志的打印
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
try:
import torch_directml
device = torch_directml.device(torch_directml.default_device())
USE_DML = True
except:
USE_DML = False
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
LAMA_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'big-lama')
STTN_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn', 'infer_model.pth')
@@ -50,6 +58,27 @@ if 'ffmpeg.exe' not in os.listdir(os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64'
# 将ffmpeg添加可执行权限
os.chmod(FFMPEG_PATH, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# 是否使用ONNX(DirectML/AMD/Intel)
ONNX_PROVIDERS = []
available_providers = ort.get_available_providers()
for provider in available_providers:
if provider in [
"CPUExecutionProvider"
]:
continue
if provider not in [
"DmlExecutionProvider", # DirectML适用于 Windows GPU
"ROCMExecutionProvider", # AMD ROCm
"MIGraphXExecutionProvider", # AMD MIGraphX
"VitisAIExecutionProvider", # AMD VitisAI适用于 RyzenAI & Windows, 实测和DirectML性能似乎差不多
"OpenVINOExecutionProvider", # Intel GPU
"MetalExecutionProvider", # Apple macOS
"CoreMLExecutionProvider", # Apple macOS
"CUDAExecutionProvider", # Nvidia GPU
]:
continue
ONNX_PROVIDERS.append(provider)
# ×××××××××××××××××××× [不要改] end ××××××××××××××××××××
@@ -64,12 +93,17 @@ class InpaintMode(Enum):
# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
# 是否使用h264编码如果需要安卓手机分享生成的视频请打开该选项
USE_H264 = True
# ×××××××××× 通用设置 start ××××××××××
"""
MODE可选算法类型
- InpaintMode.STTN 算法:对于真人视频效果较好,速度快,可以跳过字幕检测
- InpaintMode.LAMA 算法:对于动画类视频效果好,速度一般,不可以跳过字幕检测
- InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好
"""
# 【设置inpaint算法】
# - InpaintMode.STTN 算法:对于真人视频效果较好,速度快,可以跳过字幕检测
# - InpaintMode.LAMA 算法:对于动画类视频效果好,速度一般,不可以字幕检测
# - InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好
MODE = InpaintMode.STTN
# 【设置像素点偏差】
# 用于判断是不是非字幕区域(一般认为字幕文本框的长度是要大于宽度的,如果字幕框的高大于宽,且大于的幅度超过指定像素点大小,则认为是错误检测)
@@ -85,17 +119,33 @@ PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差的像素点数
# ×××××××××× InpaintMode.STTN算法设置 start ××××××××××
# 以下参数仅适用STTN算法时才生效
# 是否使用跳过检测,跳过字幕检测会省去很大时间,但是可能误伤无字幕的视频帧
"""
1. STTN_SKIP_DETECTION
含义:是否使用跳过检测
效果设置为True跳过字幕检测会省去很大时间但是可能误伤无字幕的视频帧或者会导致去除的字幕漏了
2. STTN_NEIGHBOR_STRIDE
含义:相邻帧数步长, 如果需要为第50帧填充缺失的区域STTN_NEIGHBOR_STRIDE=5那么算法会使用第45帧、第40帧等作为参照。
效果:用于控制参考帧选择的密度,较大的步长意味着使用更少、更分散的参考帧,较小的步长意味着使用更多、更集中的参考帧。
3. STTN_REFERENCE_LENGTH
含义参数帧数量STTN算法会查看每个待修复帧的前后若干帧来获得用于修复的上下文信息
效果:调大会增加显存占用,处理效果变好,但是处理速度变慢
4. STTN_MAX_LOAD_NUM
含义STTN算法每次最多加载的视频帧数量
效果:设置越大速度越慢,但效果越好
注意要保证STTN_MAX_LOAD_NUM大于STTN_NEIGHBOR_STRIDE和STTN_REFERENCE_LENGTH
"""
STTN_SKIP_DETECTION = True
# 相邻帧数, 调大会增加显存占用,效果变好
STTN_NEIGHBOR_STRIDE = 10
# 参考帧长度, 调大会增加显存占用,效果变好
# 参考帧步长
STTN_NEIGHBOR_STRIDE = 5
# 参考帧长度(数量)
STTN_REFERENCE_LENGTH = 10
# 设置STTN算法最大同时处理的帧数量,设置越大速度越慢,但效果越好
# 要保证STTN_MAX_LOAD_NUM大于STTN_NEIGHBOR_STRIDE和STTN_REFERENCE_LENGTH
STTN_MAX_LOAD_NUM = 30
if STTN_MAX_LOAD_NUM < max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH):
STTN_MAX_LOAD_NUM = max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH)
# 设置STTN算法最大同时处理的帧数量
STTN_MAX_LOAD_NUM = 50
if STTN_MAX_LOAD_NUM < STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE:
STTN_MAX_LOAD_NUM = STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE
# ×××××××××× InpaintMode.STTN算法设置 end ××××××××××
# ×××××××××× InpaintMode.PROPAINTER算法设置 start ××××××××××

View File

@@ -27,7 +27,7 @@ class STTNInpaint:
# 1. 创建InpaintGenerator模型实例并装载到选择的设备上
self.model = InpaintGenerator().to(self.device)
# 2. 载入预训练模型的权重,转载模型的状态字典
self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location=self.device)['netG'])
self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location='cpu')['netG'])
# 3. # 将模型设置为评估模式
self.model.eval()
# 模型输入用的宽和高
@@ -93,6 +93,7 @@ class STTNInpaint:
@staticmethod
def read_mask(path):
img = cv2.imread(path, 0)
# 转为binary mask
ret, img = cv2.threshold(img, 127, 1, cv2.THRESH_BINARY)
img = img[:, :, None]
return img
@@ -200,6 +201,24 @@ class STTNInpaint:
to_H -= h
return inpaint_area # 返回绘画区域列表
@staticmethod
def get_inpaint_area_by_selection(input_sub_area, mask):
print('use selection area for inpainting')
height, width = mask.shape[:2]
ymin, ymax, _, _ = input_sub_area
interval_size = 135
# 存储结果的列表
inpaint_area = []
# 计算并存储标准区间
for i in range(ymin, ymax, interval_size):
inpaint_area.append((i, i + interval_size))
# 检查最后一个区间是否达到了最大值
if inpaint_area[-1][1] != ymax:
# 如果没有,则创建一个新的区间,开始于最后一个区间的结束,结束于扩大后的值
if inpaint_area[-1][1] + interval_size <= height:
inpaint_area.append((inpaint_area[-1][1], inpaint_area[-1][1] + interval_size))
return inpaint_area # 返回绘画区域列表
class STTNVideoInpaint:
@@ -234,70 +253,106 @@ class STTNVideoInpaint:
self.clip_gap = clip_gap
def __call__(self, input_mask=None, input_sub_remover=None, tbar=None):
# 读取视频帧信息
reader, frame_info = self.read_frame_info_from_video()
if input_sub_remover is not None:
writer = input_sub_remover.video_writer
else:
# 创建视频写入对象,用于输出修复后的视频
writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori']))
# 计算需要迭代修复视频的次数
rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1
# 计算分割高度,用于确定修复区域的大小
split_h = int(frame_info['W_ori'] * 3 / 16)
if input_mask is None:
# 读取掩码
mask = self.sttn_inpaint.read_mask(self.mask_path)
else:
_, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY)
mask = mask[:, :, None]
# 得到修复区域位置
inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
# 遍历每一次的迭代次数
for i in range(rec_time):
start_f = i * self.clip_gap # 起始帧位置
end_f = min((i + 1) * self.clip_gap, frame_info['len']) # 结束帧位置
print('Processing:', start_f + 1, '-', end_f, ' / Total:', frame_info['len'])
frames_hr = [] # 高分辨率帧列表
frames = {} # 帧字典,用于存储裁剪后的图像
comps = {} # 组合字典,用于存储修复后的图像
# 初始化帧字典
for k in range(len(inpaint_area)):
frames[k] = []
# 读取和修复高分辨率帧
for j in range(start_f, end_f):
success, image = reader.read()
frames_hr.append(image)
reader = None
writer = None
try:
# 读取视频帧信息
reader, frame_info = self.read_frame_info_from_video()
if input_sub_remover is not None:
writer = input_sub_remover.video_writer
else:
# 创建视频写入对象,用于输出修复后的视频
writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori']))
# 计算需要迭代修复视频的次数
rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1
# 计算分割高度,用于确定修复区域的大小
split_h = int(frame_info['W_ori'] * 3 / 16)
if input_mask is None:
# 读取掩码
mask = self.sttn_inpaint.read_mask(self.mask_path)
else:
_, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY)
mask = mask[:, :, None]
# 得到修复区域位置
inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
# 遍历每一次的迭代次数
for i in range(rec_time):
start_f = i * self.clip_gap # 起始帧位置
end_f = min((i + 1) * self.clip_gap, frame_info['len']) # 结束帧位置
print('Processing:', start_f + 1, '-', end_f, ' / Total:', frame_info['len'])
frames_hr = [] # 高分辨率帧列表
frames = {} # 帧字典,用于存储裁剪后的图像
comps = {} # 组合字典,用于存储修复后的图像
# 初始化帧字典
for k in range(len(inpaint_area)):
# 裁剪、缩放并添加到帧字典
image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :]
image_resize = cv2.resize(image_crop, (self.sttn_inpaint.model_input_width, self.sttn_inpaint.model_input_height))
frames[k].append(image_resize)
# 对每个修复区域运行修复
for k in range(len(inpaint_area)):
comps[k] = self.sttn_inpaint.inpaint(frames[k])
# 如果有要修复的区域
if inpaint_area is not []:
for j in range(end_f - start_f):
if input_sub_remover is not None and input_sub_remover.gui_mode:
original_frame = copy.deepcopy(frames_hr[j])
else:
original_frame = None
frame = frames_hr[j]
frames[k] = []
# 读取和修复高分辨率帧
valid_frames_count = 0
for j in range(start_f, end_f):
success, image = reader.read()
if not success:
print(f"Warning: Failed to read frame {j}.")
break
frames_hr.append(image)
valid_frames_count += 1
for k in range(len(inpaint_area)):
# 将修复的图像重新扩展到原始分辨率,并融合到原始帧
comp = cv2.resize(comps[k][j], (frame_info['W_ori'], split_h))
comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB)
mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :]
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
writer.write(frame)
if input_sub_remover is not None and input_sub_remover.gui_mode:
if tbar is not None:
input_sub_remover.update_progress(tbar, increment=1)
if original_frame is not None:
input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame])
# 释放视频写入对象
writer.release()
# 裁剪、缩放并添加到帧字典
image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :]
image_resize = cv2.resize(image_crop, (self.sttn_inpaint.model_input_width, self.sttn_inpaint.model_input_height))
frames[k].append(image_resize)
# 如果没有读取到有效帧,则跳过当前迭代
if valid_frames_count == 0:
print(f"Warning: No valid frames found in range {start_f+1}-{end_f}. Skipping this segment.")
continue
# 对每个修复区域运行修复
for k in range(len(inpaint_area)):
if len(frames[k]) > 0: # 确保有帧可以处理
comps[k] = self.sttn_inpaint.inpaint(frames[k])
else:
comps[k] = []
# 如果有要修复的区域
if inpaint_area and valid_frames_count > 0:
for j in range(valid_frames_count):
if input_sub_remover is not None and input_sub_remover.gui_mode:
original_frame = copy.deepcopy(frames_hr[j])
else:
original_frame = None
frame = frames_hr[j]
for k in range(len(inpaint_area)):
if j < len(comps[k]): # 确保索引有效
# 将修复的图像重新扩展到原始分辨率,并融合到原始帧
comp = cv2.resize(comps[k][j], (frame_info['W_ori'], split_h))
comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB)
mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :]
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
writer.write(frame)
if input_sub_remover is not None:
if tbar is not None:
input_sub_remover.update_progress(tbar, increment=1)
if original_frame is not None and input_sub_remover.gui_mode:
input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame])
except Exception as e:
print(f"Error during video processing: {str(e)}")
# 不抛出异常,允许程序继续执行
finally:
if writer:
writer.release()
if __name__ == '__main__':

View File

@@ -53,8 +53,16 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None
return logger
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
torch.__version__)[0][:3])] >= [1, 12, 0]
def get_version_numbers(version_str):
# 匹配主要版本号(支持 2.8.0 或 2.8.0.dev20250422+cu128 等格式)
pattern = r"^(\d+)\.(\d+)\.(\d+)"
match = re.match(pattern, version_str)
if match:
return [int(x) for x in match.groups()]
return [0, 0, 0] # 如果无法匹配,返回默认值
# 使用示例
IS_HIGH_VERSION = get_version_numbers(torch.__version__) >= [1, 12, 0]
def gpu_is_available():

View File

@@ -1,3 +1,4 @@
import torch
import shutil
import subprocess
import os
@@ -5,9 +6,12 @@ from pathlib import Path
import threading
import cv2
import sys
from functools import cached_property
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import config
from backend.tools.common_tools import is_video_or_image, is_image_file
from backend.scenedetect import scene_detect
from backend.scenedetect.detectors import ContentDetector
from backend.inpaint.sttn_inpaint import STTNInpaint, STTNVideoInpaint
@@ -17,13 +21,10 @@ from backend.tools.inpaint_tools import create_mask, batch_generator
import importlib
import platform
import tempfile
import torch
import multiprocessing
from shapely.geometry import Polygon
import time
from tqdm import tqdm
from tools.infer import utility
from tools.infer.predict_det import TextDetector
class SubtitleDetect:
@@ -32,14 +33,23 @@ class SubtitleDetect:
"""
def __init__(self, video_path, sub_area=None):
self.video_path = video_path
self.sub_area = sub_area
@cached_property
def text_detector(self):
import paddle
paddle.disable_signal_handler()
from paddleocr.tools.infer import utility
from paddleocr.tools.infer.predict_det import TextDetector
# 获取参数对象
importlib.reload(config)
args = utility.parse_args()
args.det_algorithm = 'DB'
args.det_model_dir = config.DET_MODEL_PATH
self.text_detector = TextDetector(args)
self.video_path = video_path
self.sub_area = sub_area
args.det_model_dir = self.convertToOnnxModelIfNeeded(config.DET_MODEL_PATH)
args.use_onnx=len(config.ONNX_PROVIDERS) > 0
args.onnx_providers=config.ONNX_PROVIDERS
return TextDetector(args)
def detect_subtitle(self, img):
dt_boxes, elapse = self.text_detector(img)
@@ -119,6 +129,52 @@ class SubtitleDetect:
new_subtitle_frame_no_box_dict[key] = subtitle_frame_no_box_dict[key]
return new_subtitle_frame_no_box_dict
def convertToOnnxModelIfNeeded(self, model_dir, model_filename="inference.pdmodel", params_filename="inference.pdiparams", opset_version=14):
"""Converts a Paddle model to ONNX if ONNX providers are available and the model does not already exist."""
if not config.ONNX_PROVIDERS:
return model_dir
onnx_model_path = os.path.join(model_dir, "model.onnx")
if os.path.exists(onnx_model_path):
print(f"ONNX model already exists: {onnx_model_path}. Skipping conversion.")
return onnx_model_path
print(f"Converting Paddle model {model_dir} to ONNX...")
model_file = os.path.join(model_dir, model_filename)
params_file = os.path.join(model_dir, params_filename) if params_filename else ""
try:
import paddle2onnx
# Ensure the target directory exists
os.makedirs(os.path.dirname(onnx_model_path), exist_ok=True)
# Convert and save the model
onnx_model = paddle2onnx.export(
model_filename=model_file,
params_filename=params_file,
save_file=onnx_model_path,
opset_version=opset_version,
auto_upgrade_opset=True,
verbose=True,
enable_onnx_checker=True,
enable_experimental_op=True,
enable_optimize=True,
custom_op_info={},
deploy_backend="onnxruntime",
calibration_file="calibration.cache",
external_file=os.path.join(model_dir, "external_data"),
export_fp16_model=False,
)
print(f"Conversion successful. ONNX model saved to: {onnx_model_path}")
return onnx_model_path
except Exception as e:
print(f"Error during conversion: {e}")
return model_dir
@staticmethod
def split_range_by_scene(intervals, points):
# 确保离散值列表是有序的
@@ -257,7 +313,43 @@ class SubtitleDetect:
return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]])
@staticmethod
def expand_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
def expand_and_merge_intervals(intervals, expand_size=config.STTN_NEIGHBOR_STRIDE*config.STTN_REFERENCE_LENGTH, max_length=config.STTN_MAX_LOAD_NUM):
# 初始化输出区间列表
expanded_intervals = []
# 对每个原始区间进行扩展
for interval in intervals:
start, end = interval
# 扩展至至少 'expand_size' 个单位,但不超过 'max_length' 个单位
expansion_amount = max(expand_size - (end - start + 1), 0)
# 在保证包含原区间的前提下尽可能平分前后扩展量
expand_start = max(start - expansion_amount // 2, 1) # 确保起始点不小于1
expand_end = end + expansion_amount // 2
# 如果扩展后的区间超出了最大长度,进行调整
if (expand_end - expand_start + 1) > max_length:
expand_end = expand_start + max_length - 1
# 对于单点的处理,需额外保证有至少 'expand_size' 长度
if start == end:
if expand_end - expand_start + 1 < expand_size:
expand_end = expand_start + expand_size - 1
# 检查与前一个区间是否有重叠并进行相应的合并
if expanded_intervals and expand_start <= expanded_intervals[-1][1]:
previous_start, previous_end = expanded_intervals.pop()
expand_start = previous_start
expand_end = max(expand_end, previous_end)
# 添加扩展后的区间至结果列表
expanded_intervals.append((expand_start, expand_end))
return expanded_intervals
@staticmethod
def filter_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
"""
合并传入的字幕起始区间确保区间大小最低为STTN_REFERENCE_LENGTH
"""
@@ -290,12 +382,10 @@ class SubtitleDetect:
for start, end in expanded[1:]:
last_start, last_end = merged[-1]
# 检查是否重叠
if start <= last_end and (
end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
if start <= last_end and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
# 需要合并
merged[-1] = (last_start, max(last_end, end)) # 合并区间
elif start == last_end + 1 and (
end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
elif start == last_end + 1 and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
# 相邻区间也需要合并的场景
merged[-1] = (last_start, end)
else:
@@ -483,7 +573,7 @@ class SubtitleRemover:
self.gui_mode = gui_mode
# 判断是否为图片
self.is_picture = False
if str(vd_path).endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
if is_image_file(str(vd_path)):
self.sub_area = None
self.is_picture = True
# 视频路径
@@ -505,8 +595,7 @@ class SubtitleRemover:
# 创建视频临时对象windows下delete=True会有permission denied的报错
self.video_temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
# 创建视频写对象
self.video_writer = cv2.VideoWriter(self.video_temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps,
self.size)
self.video_writer = cv2.VideoWriter(self.video_temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size)
self.video_out_name = os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4')
self.video_inpaint = None
self.lama_inpaint = None
@@ -518,6 +607,14 @@ class SubtitleRemover:
self.video_out_name = os.path.join(pic_dir, f'{self.vd_name}{self.ext}')
if torch.cuda.is_available():
print('use GPU for acceleration')
if config.USE_DML:
print('use DirectML for acceleration')
if config.MODE != config.InpaintMode.STTN:
print('Warning: DirectML acceleration is only available for STTN model. Falling back to CPU for other models.')
for provider in config.ONNX_PROVIDERS:
print(f"Detected execution provider: {provider}")
# 总处理进度
self.progress_total = 0
self.progress_remover = 0
@@ -647,16 +744,14 @@ class SubtitleRemover:
self.lama_inpaint = LamaInpaint()
inpainted_frame = self.lama_inpaint(frame, single_mask)
self.video_writer.write(inpainted_frame)
print(
f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
inner_index += 1
self.update_progress(tbar, increment=1)
elif len(batch) > 1:
inpainted_frames = self.video_inpaint.inpaint(batch, mask)
for i, inpainted_frame in enumerate(inpainted_frames):
self.video_writer.write(inpainted_frame)
print(
f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
inner_index += 1
if self.gui_mode:
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
@@ -670,12 +765,13 @@ class SubtitleRemover:
print('[Processing] start removing subtitles...')
if self.sub_area is not None:
ymin, ymax, xmin, xmax = self.sub_area
mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
mask = create_mask(self.mask_size, mask_area_coordinates)
sttn_video_inpaint = STTNVideoInpaint(self.video_path)
sttn_video_inpaint(input_mask=mask, input_sub_remover=self, tbar=tbar)
else:
print('please set subtitle area first')
print('[Info] No subtitle area has been set. Video will be processed in full screen. As a result, the final outcome might be suboptimal.')
ymin, ymax, xmin, xmax = 0, self.frame_height, 0, self.frame_width
mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
mask = create_mask(self.mask_size, mask_area_coordinates)
sttn_video_inpaint = STTNVideoInpaint(self.video_path)
sttn_video_inpaint(input_mask=mask, input_sub_remover=self, tbar=tbar)
def sttn_mode(self, tbar):
# 是否跳过字幕帧寻找
@@ -688,7 +784,7 @@ class SubtitleRemover:
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
print(continuous_frame_no_list)
continuous_frame_no_list = self.sub_detector.expand_and_merge_intervals(continuous_frame_no_list)
continuous_frame_no_list = self.sub_detector.filter_and_merge_intervals(continuous_frame_no_list)
print(continuous_frame_no_list)
start_end_map = dict()
for interval in continuous_frame_no_list:
@@ -847,7 +943,7 @@ class SubtitleRemover:
audio_merge_command = [config.FFMPEG_PATH,
"-y", "-i", self.video_temp_file.name,
"-i", temp.name,
"-vcodec", "copy",
"-vcodec", "libx264" if config.USE_H264 else "copy",
"-acodec", "copy",
"-loglevel", "error", self.video_out_name]
try:
@@ -859,7 +955,10 @@ class SubtitleRemover:
try:
os.remove(temp.name)
except Exception:
print(f'failed to delete temp file {temp.name}')
if platform.system() in ['Windows']:
pass
else:
print(f'failed to delete temp file {temp.name}')
self.is_successful_merged = True
finally:
temp.close()
@@ -874,9 +973,13 @@ class SubtitleRemover:
if __name__ == '__main__':
multiprocessing.set_start_method("spawn")
# 1. 提示用户输入视频路径
video_path = input(f"Please input video file path: ").strip()
video_path = input(f"Please input video or image file path: ").strip()
# 判断视频路径是不是一个目录,是目录的化,批量处理改目录下的所有视频文件
# 2. 按以下顺序传入字幕区域
# sub_area = (ymin, ymax, xmin, xmax)
# 3. 新建字幕提取对象
sd = SubtitleRemover(video_path, sub_area=None)
sd.run()
if is_video_or_image(video_path):
sd = SubtitleRemover(video_path, sub_area=None)
sd.run()
else:
print(f'Invalid video path: {video_path}')

View File

@@ -1,16 +0,0 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.filterwarnings("ignore", category=Warning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

View File

@@ -1,109 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import numpy as np
import skimage
import paddle
import signal
import random
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import copy
from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
def term_mp(sig_num, frame):
""" kill all child processes
"""
pid = os.getpid()
pgid = os.getpgid(os.getpid())
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
os.killpg(pgid, signal.SIGKILL)
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
support_dict = [
'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet'
]
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
assert mode in ['Train', 'Eval', 'Test'
], "Mode should be Train, Eval or Test."
dataset = eval(module_name)(config, mode, logger, seed)
loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
drop_last = loader_config['drop_last']
shuffle = loader_config['shuffle']
num_workers = loader_config['num_workers']
if 'use_shared_memory' in loader_config.keys():
use_shared_memory = loader_config['use_shared_memory']
else:
use_shared_memory = True
if mode == "Train":
# Distribute data to multiple cards
batch_sampler = DistributedBatchSampler(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last)
else:
# Distribute data to single card
batch_sampler = BatchSampler(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last)
if 'collate_fn' in loader_config:
from . import collate_fn
collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
else:
collate_fn = None
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
num_workers=num_workers,
return_list=True,
use_shared_memory=use_shared_memory,
collate_fn=collate_fn)
# support exit using ctrl+c
signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)
return data_loader

View File

@@ -1,72 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numbers
import numpy as np
from collections import defaultdict
class DictCollator(object):
"""
data batch
"""
def __call__(self, batch):
# todosupport batch operators
data_dict = defaultdict(list)
to_tensor_keys = []
for sample in batch:
for k, v in sample.items():
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if k not in to_tensor_keys:
to_tensor_keys.append(k)
data_dict[k].append(v)
for k in to_tensor_keys:
data_dict[k] = paddle.to_tensor(data_dict[k])
return data_dict
class ListCollator(object):
"""
data batch
"""
def __call__(self, batch):
# todosupport batch operators
data_dict = defaultdict(list)
to_tensor_idxs = []
for sample in batch:
for idx, v in enumerate(sample):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs:
to_tensor_idxs.append(idx)
data_dict[idx].append(v)
for idx in to_tensor_idxs:
data_dict[idx] = paddle.to_tensor(data_dict[idx])
return list(data_dict.values())
class SSLRotateCollate(object):
"""
bach: [
[(4*3xH*W), (4,)]
[(4*3xH*W), (4,)]
...
]
"""
def __call__(self, batch):
output = [np.concatenate(d, axis=0) for d in zip(*batch)]
return output

View File

@@ -1,26 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.vision.transforms import ColorJitter as pp_ColorJitter
__all__ = ['ColorJitter']
class ColorJitter(object):
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0,**kwargs):
self.aug = pp_ColorJitter(brightness, contrast, saturation, hue)
def __call__(self, data):
image = data['image']
image = self.aug(image)
data['image'] = image
return data

View File

@@ -1,74 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from .iaa_augment import IaaAugment
from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .ColorJitter import ColorJitter
from .operators import *
from .label_ops import *
from .east_process import *
from .sast_process import *
from .pg_process import *
from .gen_table_mask import *
from .vqa import *
from .fce_aug import *
from .fce_targets import FCENetTargets
def transform(data, ops=None):
""" transform """
if ops is None:
ops = []
for op in ops:
data = op(data)
if data is None:
return None
return data
def create_operators(op_param_list, global_config=None):
"""
create operators based on the config
Args:
params(list): a dict list, used to create some operators
"""
assert isinstance(op_param_list, list), ('operator config should be a list')
ops = []
for operator in op_param_list:
assert isinstance(operator,
dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
if global_config is not None:
param.update(global_config)
op = eval(op_name)(**param)
ops.append(op)
return ops

View File

@@ -1,170 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import cv2
import random
import numpy as np
from PIL import Image
from shapely.geometry import Polygon
from ppocr.data.imaug.iaa_augment import IaaAugment
from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
from tools.infer.utility import get_rotate_crop_image
class CopyPaste(object):
def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
self.ext_data_num = 1
self.objects_paste_ratio = objects_paste_ratio
self.limit_paste = limit_paste
augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
self.aug = IaaAugment(augmenter_args)
def __call__(self, data):
point_num = data['polys'].shape[1]
src_img = data['image']
src_polys = data['polys'].tolist()
src_ignores = data['ignore_tags'].tolist()
ext_data = data['ext_data'][0]
ext_image = ext_data['image']
ext_polys = ext_data['polys']
ext_ignores = ext_data['ignore_tags']
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
select_num = max(
1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
random.shuffle(indexs)
select_idxs = indexs[:select_num]
select_polys = ext_polys[select_idxs]
select_ignores = ext_ignores[select_idxs]
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
src_img = Image.fromarray(src_img).convert('RGBA')
for poly, tag in zip(select_polys, select_ignores):
box_img = get_rotate_crop_image(ext_image, poly)
src_img, box = self.paste_img(src_img, box_img, src_polys)
if box is not None:
box = box.tolist()
for _ in range(len(box), point_num):
box.append(box[-1])
src_polys.append(box)
src_ignores.append(tag)
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
h, w = src_img.shape[:2]
src_polys = np.array(src_polys)
src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
data['image'] = src_img
data['polys'] = src_polys
data['ignore_tags'] = np.array(src_ignores)
return data
def paste_img(self, src_img, box_img, src_polys):
box_img_pil = Image.fromarray(box_img).convert('RGBA')
src_w, src_h = src_img.size
box_w, box_h = box_img_pil.size
angle = np.random.randint(0, 360)
box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
box = rotate_bbox(box_img, box, angle)[0]
box_img_pil = box_img_pil.rotate(angle, expand=1)
box_w, box_h = box_img_pil.width, box_img_pil.height
if src_w - box_w < 0 or src_h - box_h < 0:
return src_img, None
paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
src_h - box_h)
if paste_x is None:
return src_img, None
box[:, 0] += paste_x
box[:, 1] += paste_y
r, g, b, A = box_img_pil.split()
src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
return src_img, box
def select_coord(self, src_polys, box, endx, endy):
if self.limit_paste:
xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min(
), box[:, 0].max(), box[:, 1].max()
for _ in range(50):
paste_x = random.randint(0, endx)
paste_y = random.randint(0, endy)
xmin1 = xmin + paste_x
xmax1 = xmax + paste_x
ymin1 = ymin + paste_y
ymax1 = ymax + paste_y
num_poly_in_rect = 0
for poly in src_polys:
if not is_poly_outside_rect(poly, xmin1, ymin1,
xmax1 - xmin1, ymax1 - ymin1):
num_poly_in_rect += 1
break
if num_poly_in_rect == 0:
return paste_x, paste_y
return None, None
else:
paste_x = random.randint(0, endx)
paste_y = random.randint(0, endy)
return paste_x, paste_y
def get_union(pD, pG):
return Polygon(pD).union(Polygon(pG)).area
def get_intersection_over_union(pD, pG):
return get_intersection(pD, pG) / get_union(pD, pG)
def get_intersection(pD, pG):
return Polygon(pD).intersection(Polygon(pG)).area
def rotate_bbox(img, text_polys, angle, scale=1):
"""
from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
Args:
img: np.ndarray
text_polys: np.ndarray N*4*2
angle: int
scale: int
Returns:
"""
w = img.shape[1]
h = img.shape[0]
rangle = np.deg2rad(angle)
nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
rot_mat[0, 2] += rot_move[0]
rot_mat[1, 2] += rot_move[1]
# ---------------------- rotate box ----------------------
rot_text_polys = list()
for bbox in text_polys:
point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
rot_text_polys.append([point1, point2, point3, point4])
return np.array(rot_text_polys, dtype=np.float32)

View File

@@ -1,436 +0,0 @@
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
This code is refered from:
https://github.com/songdejia/EAST/blob/master/data_utils.py
"""
import math
import cv2
import numpy as np
import json
import sys
import os
__all__ = ['EASTProcessTrain']
class EASTProcessTrain(object):
def __init__(self,
image_shape=[512, 512],
background_ratio=0.125,
min_crop_side_ratio=0.1,
min_text_size=10,
**kwargs):
self.input_size = image_shape[1]
self.random_scale = np.array([0.5, 1, 2.0, 3.0])
self.background_ratio = background_ratio
self.min_crop_side_ratio = min_crop_side_ratio
self.min_text_size = min_text_size
def preprocess(self, im):
input_size = self.input_size
im_shape = im.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
im_scale = float(input_size) / float(im_size_max)
im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale)
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
# im = im[:, :, ::-1].astype(np.float32)
im = im / 255
im -= img_mean
im /= img_std
new_h, new_w, _ = im.shape
im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32)
im_padded[:new_h, :new_w, :] = im
im_padded = im_padded.transpose((2, 0, 1))
im_padded = im_padded[np.newaxis, :]
return im_padded, im_scale
def rotate_im_poly(self, im, text_polys):
"""
rotate image with 90 / 180 / 270 degre
"""
im_w, im_h = im.shape[1], im.shape[0]
dst_im = im.copy()
dst_polys = []
rand_degree_ratio = np.random.rand()
rand_degree_cnt = 1
if 0.333 < rand_degree_ratio < 0.666:
rand_degree_cnt = 2
elif rand_degree_ratio > 0.666:
rand_degree_cnt = 3
for i in range(rand_degree_cnt):
dst_im = np.rot90(dst_im)
rot_degree = -90 * rand_degree_cnt
rot_angle = rot_degree * math.pi / 180.0
n_poly = text_polys.shape[0]
cx, cy = 0.5 * im_w, 0.5 * im_h
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
for i in range(n_poly):
wordBB = text_polys[i]
poly = []
for j in range(4):
sx, sy = wordBB[j][0], wordBB[j][1]
dx = math.cos(rot_angle) * (sx - cx)\
- math.sin(rot_angle) * (sy - cy) + ncx
dy = math.sin(rot_angle) * (sx - cx)\
+ math.cos(rot_angle) * (sy - cy) + ncy
poly.append([dx, dy])
dst_polys.append(poly)
dst_polys = np.array(dst_polys, dtype=np.float32)
return dst_im, dst_polys
def polygon_area(self, poly):
"""
compute area of a polygon
:param poly:
:return:
"""
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
return np.sum(edge) / 2.
def check_and_validate_polys(self, polys, tags, img_height, img_width):
"""
check so that the text poly is in the same direction,
and also filter some invalid polygons
:param polys:
:param tags:
:return:
"""
h, w = img_height, img_width
if polys.shape[0] == 0:
return polys
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
validated_polys = []
validated_tags = []
for poly, tag in zip(polys, tags):
p_area = self.polygon_area(poly)
#invalid poly
if abs(p_area) < 1:
continue
if p_area > 0:
#'poly in wrong direction'
if not tag:
tag = True #reversed cases should be ignore
poly = poly[(0, 3, 2, 1), :]
validated_polys.append(poly)
validated_tags.append(tag)
return np.array(validated_polys), np.array(validated_tags)
def draw_img_polys(self, img, polys):
if len(img.shape) == 4:
img = np.squeeze(img, axis=0)
if img.shape[0] == 3:
img = img.transpose((1, 2, 0))
img[:, :, 2] += 123.68
img[:, :, 1] += 116.78
img[:, :, 0] += 103.94
cv2.imwrite("tmp.jpg", img)
img = cv2.imread("tmp.jpg")
for box in polys:
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
import random
ino = random.randint(0, 100)
cv2.imwrite("tmp_%d.jpg" % ino, img)
return
def shrink_poly(self, poly, r):
"""
fit a poly inside the origin poly, maybe bugs here...
used for generate the score map
:param poly: the text poly
:param r: r in the paper
:return: the shrinked poly
"""
# shrink ratio
R = 0.3
# find the longer pair
dist0 = np.linalg.norm(poly[0] - poly[1])
dist1 = np.linalg.norm(poly[2] - poly[3])
dist2 = np.linalg.norm(poly[0] - poly[3])
dist3 = np.linalg.norm(poly[1] - poly[2])
if dist0 + dist1 > dist2 + dist3:
# first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
## p0, p1
theta = np.arctan2((poly[1][1] - poly[0][1]),
(poly[1][0] - poly[0][0]))
poly[0][0] += R * r[0] * np.cos(theta)
poly[0][1] += R * r[0] * np.sin(theta)
poly[1][0] -= R * r[1] * np.cos(theta)
poly[1][1] -= R * r[1] * np.sin(theta)
## p2, p3
theta = np.arctan2((poly[2][1] - poly[3][1]),
(poly[2][0] - poly[3][0]))
poly[3][0] += R * r[3] * np.cos(theta)
poly[3][1] += R * r[3] * np.sin(theta)
poly[2][0] -= R * r[2] * np.cos(theta)
poly[2][1] -= R * r[2] * np.sin(theta)
## p0, p3
theta = np.arctan2((poly[3][0] - poly[0][0]),
(poly[3][1] - poly[0][1]))
poly[0][0] += R * r[0] * np.sin(theta)
poly[0][1] += R * r[0] * np.cos(theta)
poly[3][0] -= R * r[3] * np.sin(theta)
poly[3][1] -= R * r[3] * np.cos(theta)
## p1, p2
theta = np.arctan2((poly[2][0] - poly[1][0]),
(poly[2][1] - poly[1][1]))
poly[1][0] += R * r[1] * np.sin(theta)
poly[1][1] += R * r[1] * np.cos(theta)
poly[2][0] -= R * r[2] * np.sin(theta)
poly[2][1] -= R * r[2] * np.cos(theta)
else:
## p0, p3
# print poly
theta = np.arctan2((poly[3][0] - poly[0][0]),
(poly[3][1] - poly[0][1]))
poly[0][0] += R * r[0] * np.sin(theta)
poly[0][1] += R * r[0] * np.cos(theta)
poly[3][0] -= R * r[3] * np.sin(theta)
poly[3][1] -= R * r[3] * np.cos(theta)
## p1, p2
theta = np.arctan2((poly[2][0] - poly[1][0]),
(poly[2][1] - poly[1][1]))
poly[1][0] += R * r[1] * np.sin(theta)
poly[1][1] += R * r[1] * np.cos(theta)
poly[2][0] -= R * r[2] * np.sin(theta)
poly[2][1] -= R * r[2] * np.cos(theta)
## p0, p1
theta = np.arctan2((poly[1][1] - poly[0][1]),
(poly[1][0] - poly[0][0]))
poly[0][0] += R * r[0] * np.cos(theta)
poly[0][1] += R * r[0] * np.sin(theta)
poly[1][0] -= R * r[1] * np.cos(theta)
poly[1][1] -= R * r[1] * np.sin(theta)
## p2, p3
theta = np.arctan2((poly[2][1] - poly[3][1]),
(poly[2][0] - poly[3][0]))
poly[3][0] += R * r[3] * np.cos(theta)
poly[3][1] += R * r[3] * np.sin(theta)
poly[2][0] -= R * r[2] * np.cos(theta)
poly[2][1] -= R * r[2] * np.sin(theta)
return poly
def generate_quad(self, im_size, polys, tags):
"""
Generate quadrangle.
"""
h, w = im_size
poly_mask = np.zeros((h, w), dtype=np.uint8)
score_map = np.zeros((h, w), dtype=np.uint8)
# (x1, y1, ..., x4, y4, short_edge_norm)
geo_map = np.zeros((h, w, 9), dtype=np.float32)
# mask used during traning, to ignore some hard areas
training_mask = np.ones((h, w), dtype=np.uint8)
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
poly = poly_tag[0]
tag = poly_tag[1]
r = [None, None, None, None]
for i in range(4):
dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4])
dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
r[i] = min(dist1, dist2)
# score map
shrinked_poly = self.shrink_poly(
poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
cv2.fillPoly(score_map, shrinked_poly, 1)
cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
# if the poly is too small, then ignore it during training
poly_h = min(
np.linalg.norm(poly[0] - poly[3]),
np.linalg.norm(poly[1] - poly[2]))
poly_w = min(
np.linalg.norm(poly[0] - poly[1]),
np.linalg.norm(poly[2] - poly[3]))
if min(poly_h, poly_w) < self.min_text_size:
cv2.fillPoly(training_mask,
poly.astype(np.int32)[np.newaxis, :, :], 0)
if tag:
cv2.fillPoly(training_mask,
poly.astype(np.int32)[np.newaxis, :, :], 0)
xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
# geo map.
y_in_poly = xy_in_poly[:, 0]
x_in_poly = xy_in_poly[:, 1]
poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w)
poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
for pno in range(4):
geo_channel_beg = pno * 2
geo_map[y_in_poly, x_in_poly, geo_channel_beg] =\
x_in_poly - poly[pno, 0]
geo_map[y_in_poly, x_in_poly, geo_channel_beg+1] =\
y_in_poly - poly[pno, 1]
geo_map[y_in_poly, x_in_poly, 8] = \
1.0 / max(min(poly_h, poly_w), 1.0)
return score_map, geo_map, training_mask
def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
"""
make random crop from the input image
:param im:
:param polys:
:param tags:
:param crop_background:
:param max_tries:
:return:
"""
h, w, _ = im.shape
pad_h = h // 10
pad_w = w // 10
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
for poly in polys:
poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
w_array[minx + pad_w:maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
h_array[miny + pad_h:maxy + pad_h] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return im, polys, tags
for i in range(max_tries):
xx = np.random.choice(w_axis, size=2)
xmin = np.min(xx) - pad_w
xmax = np.max(xx) - pad_w
xmin = np.clip(xmin, 0, w - 1)
xmax = np.clip(xmax, 0, w - 1)
yy = np.random.choice(h_axis, size=2)
ymin = np.min(yy) - pad_h
ymax = np.max(yy) - pad_h
ymin = np.clip(ymin, 0, h - 1)
ymax = np.clip(ymax, 0, h - 1)
if xmax - xmin < self.min_crop_side_ratio * w or \
ymax - ymin < self.min_crop_side_ratio * h:
# area too small
continue
if polys.shape[0] != 0:
poly_axis_in_area = (polys[:, :, 0] >= xmin)\
& (polys[:, :, 0] <= xmax)\
& (polys[:, :, 1] >= ymin)\
& (polys[:, :, 1] <= ymax)
selected_polys = np.where(
np.sum(poly_axis_in_area, axis=1) == 4)[0]
else:
selected_polys = []
if len(selected_polys) == 0:
# no text in this area
if crop_background:
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
polys = []
tags = []
return im, polys, tags
else:
continue
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
polys = polys[selected_polys]
tags = tags[selected_polys]
polys[:, :, 0] -= xmin
polys[:, :, 1] -= ymin
return im, polys, tags
return im, polys, tags
def crop_background_infor(self, im, text_polys, text_tags):
im, text_polys, text_tags = self.crop_area(
im, text_polys, text_tags, crop_background=True)
if len(text_polys) > 0:
return None
# pad and resize image
input_size = self.input_size
im, ratio = self.preprocess(im)
score_map = np.zeros((input_size, input_size), dtype=np.float32)
geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32)
training_mask = np.ones((input_size, input_size), dtype=np.float32)
return im, score_map, geo_map, training_mask
def crop_foreground_infor(self, im, text_polys, text_tags):
im, text_polys, text_tags = self.crop_area(
im, text_polys, text_tags, crop_background=False)
if text_polys.shape[0] == 0:
return None
#continue for all ignore case
if np.sum((text_tags * 1.0)) >= text_tags.size:
return None
# pad and resize image
input_size = self.input_size
im, ratio = self.preprocess(im)
text_polys[:, :, 0] *= ratio
text_polys[:, :, 1] *= ratio
_, _, new_h, new_w = im.shape
# print(im.shape)
# self.draw_img_polys(im, text_polys)
score_map, geo_map, training_mask = self.generate_quad(
(new_h, new_w), text_polys, text_tags)
return im, score_map, geo_map, training_mask
def __call__(self, data):
im = data['image']
text_polys = data['polys']
text_tags = data['ignore_tags']
if im is None:
return None
if text_polys.shape[0] == 0:
return None
#add rotate cases
if np.random.rand() < 0.5:
im, text_polys = self.rotate_im_poly(im, text_polys)
h, w, _ = im.shape
text_polys, text_tags = self.check_and_validate_polys(text_polys,
text_tags, h, w)
if text_polys.shape[0] == 0:
return None
# random scale this image
rd_scale = np.random.choice(self.random_scale)
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
text_polys *= rd_scale
if np.random.rand() < self.background_ratio:
outs = self.crop_background_infor(im, text_polys, text_tags)
else:
outs = self.crop_foreground_infor(im, text_polys, text_tags)
if outs is None:
return None
im, score_map, geo_map, training_mask = outs
score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
geo_map = np.swapaxes(geo_map, 1, 2)
geo_map = np.swapaxes(geo_map, 1, 0)
geo_map = geo_map[:, ::4, ::4].astype(np.float32)
training_mask = training_mask[np.newaxis, ::4, ::4]
training_mask = training_mask.astype(np.float32)
data['image'] = im[0]
data['score_map'] = score_map
data['geo_map'] = geo_map
data['training_mask'] = training_mask
return data

View File

@@ -1,564 +0,0 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/transforms.py
"""
import numpy as np
from PIL import Image, ImageDraw
import cv2
from shapely.geometry import Polygon
import math
from ppocr.utils.poly_nms import poly_intersection
class RandomScaling:
def __init__(self, size=800, scale=(3. / 4, 5. / 2), **kwargs):
"""Random scale the image while keeping aspect.
Args:
size (int) : Base size before scaling.
scale (tuple(float)) : The range of scaling.
"""
assert isinstance(size, int)
assert isinstance(scale, float) or isinstance(scale, tuple)
self.size = size
self.scale = scale if isinstance(scale, tuple) \
else (1 - scale, 1 + scale)
def __call__(self, data):
image = data['image']
text_polys = data['polys']
h, w, _ = image.shape
aspect_ratio = np.random.uniform(min(self.scale), max(self.scale))
scales = self.size * 1.0 / max(h, w) * aspect_ratio
scales = np.array([scales, scales])
out_size = (int(h * scales[1]), int(w * scales[0]))
image = cv2.resize(image, out_size[::-1])
data['image'] = image
text_polys[:, :, 0::2] = text_polys[:, :, 0::2] * scales[1]
text_polys[:, :, 1::2] = text_polys[:, :, 1::2] * scales[0]
data['polys'] = text_polys
return data
class RandomCropFlip:
def __init__(self,
pad_ratio=0.1,
crop_ratio=0.5,
iter_num=1,
min_area_ratio=0.2,
**kwargs):
"""Random crop and flip a patch of the image.
Args:
crop_ratio (float): The ratio of cropping.
iter_num (int): Number of operations.
min_area_ratio (float): Minimal area ratio between cropped patch
and original image.
"""
assert isinstance(crop_ratio, float)
assert isinstance(iter_num, int)
assert isinstance(min_area_ratio, float)
self.pad_ratio = pad_ratio
self.epsilon = 1e-2
self.crop_ratio = crop_ratio
self.iter_num = iter_num
self.min_area_ratio = min_area_ratio
def __call__(self, results):
for i in range(self.iter_num):
results = self.random_crop_flip(results)
return results
def random_crop_flip(self, results):
image = results['image']
polygons = results['polys']
ignore_tags = results['ignore_tags']
if len(polygons) == 0:
return results
if np.random.random() >= self.crop_ratio:
return results
h, w, _ = image.shape
area = h * w
pad_h = int(h * self.pad_ratio)
pad_w = int(w * self.pad_ratio)
h_axis, w_axis = self.generate_crop_target(image, polygons, pad_h,
pad_w)
if len(h_axis) == 0 or len(w_axis) == 0:
return results
attempt = 0
while attempt < 50:
attempt += 1
polys_keep = []
polys_new = []
ignore_tags_keep = []
ignore_tags_new = []
xx = np.random.choice(w_axis, size=2)
xmin = np.min(xx) - pad_w
xmax = np.max(xx) - pad_w
xmin = np.clip(xmin, 0, w - 1)
xmax = np.clip(xmax, 0, w - 1)
yy = np.random.choice(h_axis, size=2)
ymin = np.min(yy) - pad_h
ymax = np.max(yy) - pad_h
ymin = np.clip(ymin, 0, h - 1)
ymax = np.clip(ymax, 0, h - 1)
if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
# area too small
continue
pts = np.stack([[xmin, xmax, xmax, xmin],
[ymin, ymin, ymax, ymax]]).T.astype(np.int32)
pp = Polygon(pts)
fail_flag = False
for polygon, ignore_tag in zip(polygons, ignore_tags):
ppi = Polygon(polygon.reshape(-1, 2))
ppiou, _ = poly_intersection(ppi, pp, buffer=0)
if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
np.abs(ppiou) > self.epsilon:
fail_flag = True
break
elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
polys_new.append(polygon)
ignore_tags_new.append(ignore_tag)
else:
polys_keep.append(polygon)
ignore_tags_keep.append(ignore_tag)
if fail_flag:
continue
else:
break
cropped = image[ymin:ymax, xmin:xmax, :]
select_type = np.random.randint(3)
if select_type == 0:
img = np.ascontiguousarray(cropped[:, ::-1])
elif select_type == 1:
img = np.ascontiguousarray(cropped[::-1, :])
else:
img = np.ascontiguousarray(cropped[::-1, ::-1])
image[ymin:ymax, xmin:xmax, :] = img
results['img'] = image
if len(polys_new) != 0:
height, width, _ = cropped.shape
if select_type == 0:
for idx, polygon in enumerate(polys_new):
poly = polygon.reshape(-1, 2)
poly[:, 0] = width - poly[:, 0] + 2 * xmin
polys_new[idx] = poly
elif select_type == 1:
for idx, polygon in enumerate(polys_new):
poly = polygon.reshape(-1, 2)
poly[:, 1] = height - poly[:, 1] + 2 * ymin
polys_new[idx] = poly
else:
for idx, polygon in enumerate(polys_new):
poly = polygon.reshape(-1, 2)
poly[:, 0] = width - poly[:, 0] + 2 * xmin
poly[:, 1] = height - poly[:, 1] + 2 * ymin
polys_new[idx] = poly
polygons = polys_keep + polys_new
ignore_tags = ignore_tags_keep + ignore_tags_new
results['polys'] = np.array(polygons)
results['ignore_tags'] = ignore_tags
return results
def generate_crop_target(self, image, all_polys, pad_h, pad_w):
"""Generate crop target and make sure not to crop the polygon
instances.
Args:
image (ndarray): The image waited to be crop.
all_polys (list[list[ndarray]]): All polygons including ground
truth polygons and ground truth ignored polygons.
pad_h (int): Padding length of height.
pad_w (int): Padding length of width.
Returns:
h_axis (ndarray): Vertical cropping range.
w_axis (ndarray): Horizontal cropping range.
"""
h, w, _ = image.shape
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
text_polys = []
for polygon in all_polys:
rect = cv2.minAreaRect(polygon.astype(np.int32).reshape(-1, 2))
box = cv2.boxPoints(rect)
box = np.int0(box)
text_polys.append([box[0], box[1], box[2], box[3]])
polys = np.array(text_polys, dtype=np.int32)
for poly in polys:
poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
w_array[minx + pad_w:maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
h_array[miny + pad_h:maxy + pad_h] = 1
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
return h_axis, w_axis
class RandomCropPolyInstances:
"""Randomly crop images and make sure to contain at least one intact
instance."""
def __init__(self, crop_ratio=5.0 / 8.0, min_side_ratio=0.4, **kwargs):
super().__init__()
self.crop_ratio = crop_ratio
self.min_side_ratio = min_side_ratio
def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):
assert isinstance(min_len, int)
assert len(valid_array) > min_len
start_array = valid_array.copy()
max_start = min(len(start_array) - min_len, max_start)
start_array[max_start:] = 0
start_array[0] = 1
diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
region_starts = np.where(diff_array < 0)[0]
region_ends = np.where(diff_array > 0)[0]
region_ind = np.random.randint(0, len(region_starts))
start = np.random.randint(region_starts[region_ind],
region_ends[region_ind])
end_array = valid_array.copy()
min_end = max(start + min_len, min_end)
end_array[:min_end] = 0
end_array[-1] = 1
diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
region_starts = np.where(diff_array < 0)[0]
region_ends = np.where(diff_array > 0)[0]
region_ind = np.random.randint(0, len(region_starts))
end = np.random.randint(region_starts[region_ind],
region_ends[region_ind])
return start, end
def sample_crop_box(self, img_size, results):
"""Generate crop box and make sure not to crop the polygon instances.
Args:
img_size (tuple(int)): The image size (h, w).
results (dict): The results dict.
"""
assert isinstance(img_size, tuple)
h, w = img_size[:2]
key_masks = results['polys']
x_valid_array = np.ones(w, dtype=np.int32)
y_valid_array = np.ones(h, dtype=np.int32)
selected_mask = key_masks[np.random.randint(0, len(key_masks))]
selected_mask = selected_mask.reshape((-1, 2)).astype(np.int32)
max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0)
min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1)
max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)
for mask in key_masks:
mask = mask.reshape((-1, 2)).astype(np.int32)
clip_x = np.clip(mask[:, 0], 0, w - 1)
clip_y = np.clip(mask[:, 1], 0, h - 1)
min_x, max_x = np.min(clip_x), np.max(clip_x)
min_y, max_y = np.min(clip_y), np.max(clip_y)
x_valid_array[min_x - 2:max_x + 3] = 0
y_valid_array[min_y - 2:max_y + 3] = 0
min_w = int(w * self.min_side_ratio)
min_h = int(h * self.min_side_ratio)
x1, x2 = self.sample_valid_start_end(x_valid_array, min_w, max_x_start,
min_x_end)
y1, y2 = self.sample_valid_start_end(y_valid_array, min_h, max_y_start,
min_y_end)
return np.array([x1, y1, x2, y2])
def crop_img(self, img, bbox):
assert img.ndim == 3
h, w, _ = img.shape
assert 0 <= bbox[1] < bbox[3] <= h
assert 0 <= bbox[0] < bbox[2] <= w
return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
def __call__(self, results):
image = results['image']
polygons = results['polys']
ignore_tags = results['ignore_tags']
if len(polygons) < 1:
return results
if np.random.random_sample() < self.crop_ratio:
crop_box = self.sample_crop_box(image.shape, results)
img = self.crop_img(image, crop_box)
results['image'] = img
# crop and filter masks
x1, y1, x2, y2 = crop_box
w = max(x2 - x1, 1)
h = max(y2 - y1, 1)
polygons[:, :, 0::2] = polygons[:, :, 0::2] - x1
polygons[:, :, 1::2] = polygons[:, :, 1::2] - y1
valid_masks_list = []
valid_tags_list = []
for ind, polygon in enumerate(polygons):
if (polygon[:, ::2] > -4).all() and (
polygon[:, ::2] < w + 4).all() and (
polygon[:, 1::2] > -4).all() and (
polygon[:, 1::2] < h + 4).all():
polygon[:, ::2] = np.clip(polygon[:, ::2], 0, w)
polygon[:, 1::2] = np.clip(polygon[:, 1::2], 0, h)
valid_masks_list.append(polygon)
valid_tags_list.append(ignore_tags[ind])
results['polys'] = np.array(valid_masks_list)
results['ignore_tags'] = valid_tags_list
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str
class RandomRotatePolyInstances:
def __init__(self,
rotate_ratio=0.5,
max_angle=10,
pad_with_fixed_color=False,
pad_value=(0, 0, 0),
**kwargs):
"""Randomly rotate images and polygon masks.
Args:
rotate_ratio (float): The ratio of samples to operate rotation.
max_angle (int): The maximum rotation angle.
pad_with_fixed_color (bool): The flag for whether to pad rotated
image with fixed value. If set to False, the rotated image will
be padded onto cropped image.
pad_value (tuple(int)): The color value for padding rotated image.
"""
self.rotate_ratio = rotate_ratio
self.max_angle = max_angle
self.pad_with_fixed_color = pad_with_fixed_color
self.pad_value = pad_value
def rotate(self, center, points, theta, center_shift=(0, 0)):
# rotate points.
(center_x, center_y) = center
center_y = -center_y
x, y = points[:, ::2], points[:, 1::2]
y = -y
theta = theta / 180 * math.pi
cos = math.cos(theta)
sin = math.sin(theta)
x = (x - center_x)
y = (y - center_y)
_x = center_x + x * cos - y * sin + center_shift[0]
_y = -(center_y + x * sin + y * cos) + center_shift[1]
points[:, ::2], points[:, 1::2] = _x, _y
return points
def cal_canvas_size(self, ori_size, degree):
assert isinstance(ori_size, tuple)
angle = degree * math.pi / 180.0
h, w = ori_size[:2]
cos = math.cos(angle)
sin = math.sin(angle)
canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos))
canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin))
canvas_size = (canvas_h, canvas_w)
return canvas_size
def sample_angle(self, max_angle):
angle = np.random.random_sample() * 2 * max_angle - max_angle
return angle
def rotate_img(self, img, angle, canvas_size):
h, w = img.shape[:2]
rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2)
rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2)
if self.pad_with_fixed_color:
target_img = cv2.warpAffine(
img,
rotation_matrix, (canvas_size[1], canvas_size[0]),
flags=cv2.INTER_NEAREST,
borderValue=self.pad_value)
else:
mask = np.zeros_like(img)
(h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
np.random.randint(0, w * 7 // 8))
img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0]))
mask = cv2.warpAffine(
mask,
rotation_matrix, (canvas_size[1], canvas_size[0]),
borderValue=[1, 1, 1])
target_img = cv2.warpAffine(
img,
rotation_matrix, (canvas_size[1], canvas_size[0]),
borderValue=[0, 0, 0])
target_img = target_img + img_cut * mask
return target_img
def __call__(self, results):
if np.random.random_sample() < self.rotate_ratio:
image = results['image']
polygons = results['polys']
h, w = image.shape[:2]
angle = self.sample_angle(self.max_angle)
canvas_size = self.cal_canvas_size((h, w), angle)
center_shift = (int((canvas_size[1] - w) / 2), int(
(canvas_size[0] - h) / 2))
image = self.rotate_img(image, angle, canvas_size)
results['image'] = image
# rotate polygons
rotated_masks = []
for mask in polygons:
rotated_mask = self.rotate((w / 2, h / 2), mask, angle,
center_shift)
rotated_masks.append(rotated_mask)
results['polys'] = np.array(rotated_masks)
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str
class SquareResizePad:
def __init__(self,
target_size,
pad_ratio=0.6,
pad_with_fixed_color=False,
pad_value=(0, 0, 0),
**kwargs):
"""Resize or pad images to be square shape.
Args:
target_size (int): The target size of square shaped image.
pad_with_fixed_color (bool): The flag for whether to pad rotated
image with fixed value. If set to False, the rescales image will
be padded onto cropped image.
pad_value (tuple(int)): The color value for padding rotated image.
"""
assert isinstance(target_size, int)
assert isinstance(pad_ratio, float)
assert isinstance(pad_with_fixed_color, bool)
assert isinstance(pad_value, tuple)
self.target_size = target_size
self.pad_ratio = pad_ratio
self.pad_with_fixed_color = pad_with_fixed_color
self.pad_value = pad_value
def resize_img(self, img, keep_ratio=True):
h, w, _ = img.shape
if keep_ratio:
t_h = self.target_size if h >= w else int(h * self.target_size / w)
t_w = self.target_size if h <= w else int(w * self.target_size / h)
else:
t_h = t_w = self.target_size
img = cv2.resize(img, (t_w, t_h))
return img, (t_h, t_w)
def square_pad(self, img):
h, w = img.shape[:2]
if h == w:
return img, (0, 0)
pad_size = max(h, w)
if self.pad_with_fixed_color:
expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8)
expand_img[:] = self.pad_value
else:
(h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
np.random.randint(0, w * 7 // 8))
img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
expand_img = cv2.resize(img_cut, (pad_size, pad_size))
if h > w:
y0, x0 = 0, (h - w) // 2
else:
y0, x0 = (w - h) // 2, 0
expand_img[y0:y0 + h, x0:x0 + w] = img
offset = (x0, y0)
return expand_img, offset
def square_pad_mask(self, points, offset):
x0, y0 = offset
pad_points = points.copy()
pad_points[::2] = pad_points[::2] + x0
pad_points[1::2] = pad_points[1::2] + y0
return pad_points
def __call__(self, results):
image = results['image']
polygons = results['polys']
h, w = image.shape[:2]
if np.random.random_sample() < self.pad_ratio:
image, out_size = self.resize_img(image, keep_ratio=True)
image, offset = self.square_pad(image)
else:
image, out_size = self.resize_img(image, keep_ratio=False)
offset = (0, 0)
results['image'] = image
try:
polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[
1] / w + offset[0]
polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[
0] / h + offset[1]
except:
pass
results['polys'] = polygons
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str

View File

@@ -1,658 +0,0 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py
"""
import cv2
import numpy as np
from numpy.fft import fft
from numpy.linalg import norm
import sys
class FCENetTargets:
"""Generate the ground truth targets of FCENet: Fourier Contour Embedding
for Arbitrary-Shaped Text Detection.
[https://arxiv.org/abs/2104.10442]
Args:
fourier_degree (int): The maximum Fourier transform degree k.
resample_step (float): The step size for resampling the text center
line (TCL). It's better not to exceed half of the minimum width.
center_region_shrink_ratio (float): The shrink ratio of text center
region.
level_size_divisors (tuple(int)): The downsample ratio on each level.
level_proportion_range (tuple(tuple(int))): The range of text sizes
assigned to each level.
"""
def __init__(self,
fourier_degree=5,
resample_step=4.0,
center_region_shrink_ratio=0.3,
level_size_divisors=(8, 16, 32),
level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0)),
orientation_thr=2.0,
**kwargs):
super().__init__()
assert isinstance(level_size_divisors, tuple)
assert isinstance(level_proportion_range, tuple)
assert len(level_size_divisors) == len(level_proportion_range)
self.fourier_degree = fourier_degree
self.resample_step = resample_step
self.center_region_shrink_ratio = center_region_shrink_ratio
self.level_size_divisors = level_size_divisors
self.level_proportion_range = level_proportion_range
self.orientation_thr = orientation_thr
def vector_angle(self, vec1, vec2):
if vec1.ndim > 1:
unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
else:
unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
if vec2.ndim > 1:
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
else:
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
return np.arccos(
np.clip(
np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
def resample_line(self, line, n):
"""Resample n points on a line.
Args:
line (ndarray): The points composing a line.
n (int): The resampled points number.
Returns:
resampled_line (ndarray): The points composing the resampled line.
"""
assert line.ndim == 2
assert line.shape[0] >= 2
assert line.shape[1] == 2
assert isinstance(n, int)
assert n > 0
length_list = [
norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
]
total_length = sum(length_list)
length_cumsum = np.cumsum([0.0] + length_list)
delta_length = total_length / (float(n) + 1e-8)
current_edge_ind = 0
resampled_line = [line[0]]
for i in range(1, n):
current_line_len = i * delta_length
while current_line_len >= length_cumsum[current_edge_ind + 1]:
current_edge_ind += 1
current_edge_end_shift = current_line_len - length_cumsum[
current_edge_ind]
end_shift_ratio = current_edge_end_shift / length_list[
current_edge_ind]
current_point = line[current_edge_ind] + (line[current_edge_ind + 1]
- line[current_edge_ind]
) * end_shift_ratio
resampled_line.append(current_point)
resampled_line.append(line[-1])
resampled_line = np.array(resampled_line)
return resampled_line
def reorder_poly_edge(self, points):
"""Get the respective points composing head edge, tail edge, top
sideline and bottom sideline.
Args:
points (ndarray): The points composing a text polygon.
Returns:
head_edge (ndarray): The two points composing the head edge of text
polygon.
tail_edge (ndarray): The two points composing the tail edge of text
polygon.
top_sideline (ndarray): The points composing top curved sideline of
text polygon.
bot_sideline (ndarray): The points composing bottom curved sideline
of text polygon.
"""
assert points.ndim == 2
assert points.shape[0] >= 4
assert points.shape[1] == 2
head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
head_edge, tail_edge = points[head_inds], points[tail_inds]
pad_points = np.vstack([points, points])
if tail_inds[1] < 1:
tail_inds[1] = len(points)
sideline1 = pad_points[head_inds[1]:tail_inds[1]]
sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
sideline_mean_shift = np.mean(
sideline1, axis=0) - np.mean(
sideline2, axis=0)
if sideline_mean_shift[1] > 0:
top_sideline, bot_sideline = sideline2, sideline1
else:
top_sideline, bot_sideline = sideline1, sideline2
return head_edge, tail_edge, top_sideline, bot_sideline
def find_head_tail(self, points, orientation_thr):
"""Find the head edge and tail edge of a text polygon.
Args:
points (ndarray): The points composing a text polygon.
orientation_thr (float): The threshold for distinguishing between
head edge and tail edge among the horizontal and vertical edges
of a quadrangle.
Returns:
head_inds (list): The indexes of two points composing head edge.
tail_inds (list): The indexes of two points composing tail edge.
"""
assert points.ndim == 2
assert points.shape[0] >= 4
assert points.shape[1] == 2
assert isinstance(orientation_thr, float)
if len(points) > 4:
pad_points = np.vstack([points, points[0]])
edge_vec = pad_points[1:] - pad_points[:-1]
theta_sum = []
adjacent_vec_theta = []
for i, edge_vec1 in enumerate(edge_vec):
adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
adjacent_edge_vec = edge_vec[adjacent_ind]
temp_theta_sum = np.sum(
self.vector_angle(edge_vec1, adjacent_edge_vec))
temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
adjacent_edge_vec[1])
theta_sum.append(temp_theta_sum)
adjacent_vec_theta.append(temp_adjacent_theta)
theta_sum_score = np.array(theta_sum) / np.pi
adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
poly_center = np.mean(points, axis=0)
edge_dist = np.maximum(
norm(
pad_points[1:] - poly_center, axis=-1),
norm(
pad_points[:-1] - poly_center, axis=-1))
dist_score = edge_dist / np.max(edge_dist)
position_score = np.zeros(len(edge_vec))
score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
score += 0.35 * dist_score
if len(points) % 2 == 0:
position_score[(len(score) // 2 - 1)] += 1
position_score[-1] += 1
score += 0.1 * position_score
pad_score = np.concatenate([score, score])
score_matrix = np.zeros((len(score), len(score) - 3))
x = np.arange(len(score) - 3) / float(len(score) - 4)
gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
(x - 0.5) / 0.5, 2.) / 2)
gaussian = gaussian / np.max(gaussian)
for i in range(len(score)):
score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
score) - 1)] * gaussian * 0.3
head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
score_matrix.shape)
tail_start = (head_start + tail_increment + 2) % len(points)
head_end = (head_start + 1) % len(points)
tail_end = (tail_start + 1) % len(points)
if head_end > tail_end:
head_start, tail_start = tail_start, head_start
head_end, tail_end = tail_end, head_end
head_inds = [head_start, head_end]
tail_inds = [tail_start, tail_end]
else:
if self.vector_slope(points[1] - points[0]) + self.vector_slope(
points[3] - points[2]) < self.vector_slope(points[
2] - points[1]) + self.vector_slope(points[0] - points[
3]):
horizontal_edge_inds = [[0, 1], [2, 3]]
vertical_edge_inds = [[3, 0], [1, 2]]
else:
horizontal_edge_inds = [[3, 0], [1, 2]]
vertical_edge_inds = [[0, 1], [2, 3]]
vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
0]] - points[vertical_edge_inds[1][1]])
horizontal_len_sum = norm(points[horizontal_edge_inds[0][
0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
[1]])
if vertical_len_sum > horizontal_len_sum * orientation_thr:
head_inds = horizontal_edge_inds[0]
tail_inds = horizontal_edge_inds[1]
else:
head_inds = vertical_edge_inds[0]
tail_inds = vertical_edge_inds[1]
return head_inds, tail_inds
def resample_sidelines(self, sideline1, sideline2, resample_step):
"""Resample two sidelines to be of the same points number according to
step size.
Args:
sideline1 (ndarray): The points composing a sideline of a text
polygon.
sideline2 (ndarray): The points composing another sideline of a
text polygon.
resample_step (float): The resampled step size.
Returns:
resampled_line1 (ndarray): The resampled line 1.
resampled_line2 (ndarray): The resampled line 2.
"""
assert sideline1.ndim == sideline2.ndim == 2
assert sideline1.shape[1] == sideline2.shape[1] == 2
assert sideline1.shape[0] >= 2
assert sideline2.shape[0] >= 2
assert isinstance(resample_step, float)
length1 = sum([
norm(sideline1[i + 1] - sideline1[i])
for i in range(len(sideline1) - 1)
])
length2 = sum([
norm(sideline2[i + 1] - sideline2[i])
for i in range(len(sideline2) - 1)
])
total_length = (length1 + length2) / 2
resample_point_num = max(int(float(total_length) / resample_step), 1)
resampled_line1 = self.resample_line(sideline1, resample_point_num)
resampled_line2 = self.resample_line(sideline2, resample_point_num)
return resampled_line1, resampled_line2
def generate_center_region_mask(self, img_size, text_polys):
"""Generate text center region mask.
Args:
img_size (tuple): The image size of (height, width).
text_polys (list[list[ndarray]]): The list of text polygons.
Returns:
center_region_mask (ndarray): The text center region mask.
"""
assert isinstance(img_size, tuple)
# assert check_argument.is_2dlist(text_polys)
h, w = img_size
center_region_mask = np.zeros((h, w), np.uint8)
center_region_boxes = []
for poly in text_polys:
# assert len(poly) == 1
polygon_points = poly.reshape(-1, 2)
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
resampled_top_line, resampled_bot_line = self.resample_sidelines(
top_line, bot_line, self.resample_step)
resampled_bot_line = resampled_bot_line[::-1]
center_line = (resampled_top_line + resampled_bot_line) / 2
line_head_shrink_len = norm(resampled_top_line[0] -
resampled_bot_line[0]) / 4.0
line_tail_shrink_len = norm(resampled_top_line[-1] -
resampled_bot_line[-1]) / 4.0
head_shrink_num = int(line_head_shrink_len // self.resample_step)
tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
if len(center_line) > head_shrink_num + tail_shrink_num + 2:
center_line = center_line[head_shrink_num:len(center_line) -
tail_shrink_num]
resampled_top_line = resampled_top_line[head_shrink_num:len(
resampled_top_line) - tail_shrink_num]
resampled_bot_line = resampled_bot_line[head_shrink_num:len(
resampled_bot_line) - tail_shrink_num]
for i in range(0, len(center_line) - 1):
tl = center_line[i] + (resampled_top_line[i] - center_line[i]
) * self.center_region_shrink_ratio
tr = center_line[i + 1] + (resampled_top_line[i + 1] -
center_line[i + 1]
) * self.center_region_shrink_ratio
br = center_line[i + 1] + (resampled_bot_line[i + 1] -
center_line[i + 1]
) * self.center_region_shrink_ratio
bl = center_line[i] + (resampled_bot_line[i] - center_line[i]
) * self.center_region_shrink_ratio
current_center_box = np.vstack([tl, tr, br,
bl]).astype(np.int32)
center_region_boxes.append(current_center_box)
cv2.fillPoly(center_region_mask, center_region_boxes, 1)
return center_region_mask
def resample_polygon(self, polygon, n=400):
"""Resample one polygon with n points on its boundary.
Args:
polygon (list[float]): The input polygon.
n (int): The number of resampled points.
Returns:
resampled_polygon (list[float]): The resampled polygon.
"""
length = []
for i in range(len(polygon)):
p1 = polygon[i]
if i == len(polygon) - 1:
p2 = polygon[0]
else:
p2 = polygon[i + 1]
length.append(((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5)
total_length = sum(length)
n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
n_on_each_line = n_on_each_line.astype(np.int32)
new_polygon = []
for i in range(len(polygon)):
num = n_on_each_line[i]
p1 = polygon[i]
if i == len(polygon) - 1:
p2 = polygon[0]
else:
p2 = polygon[i + 1]
if num == 0:
continue
dxdy = (p2 - p1) / num
for j in range(num):
point = p1 + dxdy * j
new_polygon.append(point)
return np.array(new_polygon)
def normalize_polygon(self, polygon):
"""Normalize one polygon so that its start point is at right most.
Args:
polygon (list[float]): The origin polygon.
Returns:
new_polygon (lost[float]): The polygon with start point at right.
"""
temp_polygon = polygon - polygon.mean(axis=0)
x = np.abs(temp_polygon[:, 0])
y = temp_polygon[:, 1]
index_x = np.argsort(x)
index_y = np.argmin(y[index_x[:8]])
index = index_x[index_y]
new_polygon = np.concatenate([polygon[index:], polygon[:index]])
return new_polygon
def poly2fourier(self, polygon, fourier_degree):
"""Perform Fourier transformation to generate Fourier coefficients ck
from polygon.
Args:
polygon (ndarray): An input polygon.
fourier_degree (int): The maximum Fourier degree K.
Returns:
c (ndarray(complex)): Fourier coefficients.
"""
points = polygon[:, 0] + polygon[:, 1] * 1j
c_fft = fft(points) / len(points)
c = np.hstack((c_fft[-fourier_degree:], c_fft[:fourier_degree + 1]))
return c
def clockwise(self, c, fourier_degree):
"""Make sure the polygon reconstructed from Fourier coefficients c in
the clockwise direction.
Args:
polygon (list[float]): The origin polygon.
Returns:
new_polygon (lost[float]): The polygon in clockwise point order.
"""
if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]):
return c
elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]):
return c[::-1]
else:
if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]):
return c
else:
return c[::-1]
def cal_fourier_signature(self, polygon, fourier_degree):
"""Calculate Fourier signature from input polygon.
Args:
polygon (ndarray): The input polygon.
fourier_degree (int): The maximum Fourier degree K.
Returns:
fourier_signature (ndarray): An array shaped (2k+1, 2) containing
real part and image part of 2k+1 Fourier coefficients.
"""
resampled_polygon = self.resample_polygon(polygon)
resampled_polygon = self.normalize_polygon(resampled_polygon)
fourier_coeff = self.poly2fourier(resampled_polygon, fourier_degree)
fourier_coeff = self.clockwise(fourier_coeff, fourier_degree)
real_part = np.real(fourier_coeff).reshape((-1, 1))
image_part = np.imag(fourier_coeff).reshape((-1, 1))
fourier_signature = np.hstack([real_part, image_part])
return fourier_signature
def generate_fourier_maps(self, img_size, text_polys):
"""Generate Fourier coefficient maps.
Args:
img_size (tuple): The image size of (height, width).
text_polys (list[list[ndarray]]): The list of text polygons.
Returns:
fourier_real_map (ndarray): The Fourier coefficient real part maps.
fourier_image_map (ndarray): The Fourier coefficient image part
maps.
"""
assert isinstance(img_size, tuple)
h, w = img_size
k = self.fourier_degree
real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
for poly in text_polys:
mask = np.zeros((h, w), dtype=np.uint8)
polygon = np.array(poly).reshape((1, -1, 2))
cv2.fillPoly(mask, polygon.astype(np.int32), 1)
fourier_coeff = self.cal_fourier_signature(polygon[0], k)
for i in range(-k, k + 1):
if i != 0:
real_map[i + k, :, :] = mask * fourier_coeff[i + k, 0] + (
1 - mask) * real_map[i + k, :, :]
imag_map[i + k, :, :] = mask * fourier_coeff[i + k, 1] + (
1 - mask) * imag_map[i + k, :, :]
else:
yx = np.argwhere(mask > 0.5)
k_ind = np.ones((len(yx)), dtype=np.int64) * k
y, x = yx[:, 0], yx[:, 1]
real_map[k_ind, y, x] = fourier_coeff[k, 0] - x
imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y
return real_map, imag_map
def generate_text_region_mask(self, img_size, text_polys):
"""Generate text center region mask and geometry attribute maps.
Args:
img_size (tuple): The image size (height, width).
text_polys (list[list[ndarray]]): The list of text polygons.
Returns:
text_region_mask (ndarray): The text region mask.
"""
assert isinstance(img_size, tuple)
h, w = img_size
text_region_mask = np.zeros((h, w), dtype=np.uint8)
for poly in text_polys:
polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
cv2.fillPoly(text_region_mask, polygon, 1)
return text_region_mask
def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
"""Generate effective mask by setting the ineffective regions to 0 and
effective regions to 1.
Args:
mask_size (tuple): The mask size.
polygons_ignore (list[[ndarray]]: The list of ignored text
polygons.
Returns:
mask (ndarray): The effective mask of (height, width).
"""
mask = np.ones(mask_size, dtype=np.uint8)
for poly in polygons_ignore:
instance = poly.reshape(-1, 2).astype(np.int32).reshape(1, -1, 2)
cv2.fillPoly(mask, instance, 0)
return mask
def generate_level_targets(self, img_size, text_polys, ignore_polys):
"""Generate ground truth target on each level.
Args:
img_size (list[int]): Shape of input image.
text_polys (list[list[ndarray]]): A list of ground truth polygons.
ignore_polys (list[list[ndarray]]): A list of ignored polygons.
Returns:
level_maps (list(ndarray)): A list of ground target on each level.
"""
h, w = img_size
lv_size_divs = self.level_size_divisors
lv_proportion_range = self.level_proportion_range
lv_text_polys = [[] for i in range(len(lv_size_divs))]
lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
level_maps = []
for poly in text_polys:
polygon = np.array(poly, dtype=np.int).reshape((1, -1, 2))
_, _, box_w, box_h = cv2.boundingRect(polygon)
proportion = max(box_h, box_w) / (h + 1e-8)
for ind, proportion_range in enumerate(lv_proportion_range):
if proportion_range[0] < proportion < proportion_range[1]:
lv_text_polys[ind].append(poly / lv_size_divs[ind])
for ignore_poly in ignore_polys:
polygon = np.array(ignore_poly, dtype=np.int).reshape((1, -1, 2))
_, _, box_w, box_h = cv2.boundingRect(polygon)
proportion = max(box_h, box_w) / (h + 1e-8)
for ind, proportion_range in enumerate(lv_proportion_range):
if proportion_range[0] < proportion < proportion_range[1]:
lv_ignore_polys[ind].append(ignore_poly / lv_size_divs[ind])
for ind, size_divisor in enumerate(lv_size_divs):
current_level_maps = []
level_img_size = (h // size_divisor, w // size_divisor)
text_region = self.generate_text_region_mask(
level_img_size, lv_text_polys[ind])[None]
current_level_maps.append(text_region)
center_region = self.generate_center_region_mask(
level_img_size, lv_text_polys[ind])[None]
current_level_maps.append(center_region)
effective_mask = self.generate_effective_mask(
level_img_size, lv_ignore_polys[ind])[None]
current_level_maps.append(effective_mask)
fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
level_img_size, lv_text_polys[ind])
current_level_maps.append(fourier_real_map)
current_level_maps.append(fourier_image_maps)
level_maps.append(np.concatenate(current_level_maps))
return level_maps
def generate_targets(self, results):
"""Generate the ground truth targets for FCENet.
Args:
results (dict): The input result dictionary.
Returns:
results (dict): The output result dictionary.
"""
assert isinstance(results, dict)
image = results['image']
polygons = results['polys']
ignore_tags = results['ignore_tags']
h, w, _ = image.shape
polygon_masks = []
polygon_masks_ignore = []
for tag, polygon in zip(ignore_tags, polygons):
if tag is True:
polygon_masks_ignore.append(polygon)
else:
polygon_masks.append(polygon)
level_maps = self.generate_level_targets((h, w), polygon_masks,
polygon_masks_ignore)
mapping = {
'p3_maps': level_maps[0],
'p4_maps': level_maps[1],
'p5_maps': level_maps[2]
}
for key, value in mapping.items():
results[key] = value
return results
def __call__(self, results):
results = self.generate_targets(results)
return results

View File

@@ -1,244 +0,0 @@
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
import cv2
import numpy as np
class GenTableMask(object):
""" gen table mask """
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
self.shrink_h_max = 5
self.shrink_w_max = 5
self.mask_type = mask_type
def projection(self, erosion, h, w, spilt_threshold=0):
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
return box_list, projection_map
def projection_cx(self, box_img):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape
# 灰度图片进行二值化处理
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
# 纵向腐蚀
if h < w:
kernel = np.ones((2, 1), np.uint8)
erode = cv2.erode(thresh1, kernel, iterations=1)
else:
erode = thresh1
# 水平膨胀
kernel = np.ones((1, 5), np.uint8)
erosion = cv2.dilate(erode, kernel, iterations=1)
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
spilt_threshold = 0
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
split_bbox_list = []
if len(box_list) > 1:
for i, (h_start, h_end) in enumerate(box_list):
if i == 0:
h_start = 0
if i == len(box_list):
h_end = h
word_img = erosion[h_start:h_end + 1, :]
word_h, word_w = word_img.shape
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0:
h_start -= 1
h_end += 1
word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
split_bbox_list.append([w_start, h_start, w_end, h_end])
else:
split_bbox_list.append([0, 0, w, h])
return split_bbox_list
def shrink_bbox(self, bbox):
left, top, right, bottom = bbox
sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
left_new = left + sh_w
right_new = right - sh_w
top_new = top + sh_h
bottom_new = bottom - sh_h
if left_new >= right_new:
left_new = left
right_new = right
if top_new >= bottom_new:
top_new = top
bottom_new = bottom
return [left_new, top_new, right_new, bottom_new]
def __call__(self, data):
img = data['image']
cells = data['cells']
height, width = img.shape[0:2]
if self.mask_type == 1:
mask_img = np.zeros((height, width), dtype=np.float32)
else:
mask_img = np.zeros((height, width, 3), dtype=np.float32)
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
left, top, right, bottom = bbox
box_img = img[top:bottom, left:right, :].copy()
split_bbox_list = self.projection_cx(box_img)
for sno in range(len(split_bbox_list)):
split_bbox_list[sno][0] += left
split_bbox_list[sno][1] += top
split_bbox_list[sno][2] += left
split_bbox_list[sno][3] += top
for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno]
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0
data['mask_img'] = mask_img
else:
mask_img[top:bottom, left:right, :] = (255, 255, 255)
data['image'] = mask_img
return data
class ResizeTableImage(object):
def __init__(self, max_len, **kwargs):
super(ResizeTableImage, self).__init__()
self.max_len = max_len
def get_img_bbox(self, cells):
bbox_list = []
if len(cells) == 0:
return bbox_list
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
bbox_list.append(bbox)
return bbox_list
def resize_img_table(self, img, bbox_list, max_len):
height, width = img.shape[0:2]
ratio = max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio)
resize_w = int(width * ratio)
img_new = cv2.resize(img, (resize_w, resize_h))
bbox_list_new = []
for bno in range(len(bbox_list)):
left, top, right, bottom = bbox_list[bno].copy()
left = int(left * ratio)
top = int(top * ratio)
right = int(right * ratio)
bottom = int(bottom * ratio)
bbox_list_new.append([left, top, right, bottom])
return img_new, bbox_list_new
def __call__(self, data):
img = data['image']
if 'cells' not in data:
cells = []
else:
cells = data['cells']
bbox_list = self.get_img_bbox(cells)
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
data['image'] = img_new
cell_num = len(cells)
bno = 0
for cno in range(cell_num):
if "bbox" in data['cells'][cno]:
data['cells'][cno]['bbox'] = bbox_list_new[bno]
bno += 1
data['max_len'] = self.max_len
return data
class PaddingTableImage(object):
def __init__(self, **kwargs):
super(PaddingTableImage, self).__init__()
def __call__(self, data):
img = data['image']
max_len = data['max_len']
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy()
data['image'] = padding_img
return data

View File

@@ -1,105 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/iaa_augment.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import imgaug
import imgaug.augmenters as iaa
class AugmenterBuilder(object):
def __init__(self):
pass
def build(self, args, root=True):
if args is None or len(args) == 0:
return None
elif isinstance(args, list):
if root:
sequence = [self.build(value, root=False) for value in args]
return iaa.Sequential(sequence)
else:
return getattr(iaa, args[0])(
*[self.to_tuple_if_list(a) for a in args[1:]])
elif isinstance(args, dict):
cls = getattr(iaa, args['type'])
return cls(**{
k: self.to_tuple_if_list(v)
for k, v in args['args'].items()
})
else:
raise RuntimeError('unknown augmenter arg: ' + str(args))
def to_tuple_if_list(self, obj):
if isinstance(obj, list):
return tuple(obj)
return obj
class IaaAugment():
def __init__(self, augmenter_args=None, **kwargs):
if augmenter_args is None:
augmenter_args = [{
'type': 'Fliplr',
'args': {
'p': 0.5
}
}, {
'type': 'Affine',
'args': {
'rotate': [-10, 10]
}
}, {
'type': 'Resize',
'args': {
'size': [0.5, 3]
}
}]
self.augmenter = AugmenterBuilder().build(augmenter_args)
def __call__(self, data):
image = data['image']
shape = image.shape
if self.augmenter:
aug = self.augmenter.to_deterministic()
data['image'] = aug.augment_image(image)
data = self.may_augment_annotation(aug, data, shape)
return data
def may_augment_annotation(self, aug, data, shape):
if aug is None:
return data
line_polys = []
for poly in data['polys']:
new_poly = self.may_augment_poly(aug, shape, poly)
line_polys.append(new_poly)
data['polys'] = np.array(line_polys)
return data
def may_augment_poly(self, aug, img_shape, poly):
keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
keypoints = aug.augment_keypoints(
[imgaug.KeypointsOnImage(
keypoints, shape=img_shape)])[0].keypoints
poly = [(p.x, p.y) for p in keypoints]
return poly

File diff suppressed because it is too large Load Diff

View File

@@ -1,173 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
np.seterr(divide='ignore', invalid='ignore')
import pyclipper
from shapely.geometry import Polygon
import sys
import warnings
warnings.simplefilter("ignore")
__all__ = ['MakeBorderMap']
class MakeBorderMap(object):
def __init__(self,
shrink_ratio=0.4,
thresh_min=0.3,
thresh_max=0.7,
**kwargs):
self.shrink_ratio = shrink_ratio
self.thresh_min = thresh_min
self.thresh_max = thresh_max
def __call__(self, data):
img = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
canvas = np.zeros(img.shape[:2], dtype=np.float32)
mask = np.zeros(img.shape[:2], dtype=np.float32)
for i in range(len(text_polys)):
if ignore_tags[i]:
continue
self.draw_border_map(text_polys[i], canvas, mask=mask)
canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
data['threshold_map'] = canvas
data['threshold_mask'] = mask
return data
def draw_border_map(self, polygon, canvas, mask):
polygon = np.array(polygon)
assert polygon.ndim == 2
assert polygon.shape[1] == 2
polygon_shape = Polygon(polygon)
if polygon_shape.area <= 0:
return
distance = polygon_shape.area * (
1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
subject = [tuple(l) for l in polygon]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
padded_polygon = np.array(padding.Execute(distance)[0])
cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
xmin = padded_polygon[:, 0].min()
xmax = padded_polygon[:, 0].max()
ymin = padded_polygon[:, 1].min()
ymax = padded_polygon[:, 1].max()
width = xmax - xmin + 1
height = ymax - ymin + 1
polygon[:, 0] = polygon[:, 0] - xmin
polygon[:, 1] = polygon[:, 1] - ymin
xs = np.broadcast_to(
np.linspace(
0, width - 1, num=width).reshape(1, width), (height, width))
ys = np.broadcast_to(
np.linspace(
0, height - 1, num=height).reshape(height, 1), (height, width))
distance_map = np.zeros(
(polygon.shape[0], height, width), dtype=np.float32)
for i in range(polygon.shape[0]):
j = (i + 1) % polygon.shape[0]
absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
distance_map = distance_map.min(axis=0)
xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
xmin_valid - xmin:xmax_valid - xmax + width],
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
def _distance(self, xs, ys, point_1, point_2):
'''
compute the distance from point to a line
ys: coordinates in the first axis
xs: coordinates in the second axis
point_1, point_2: (x, y), the end of the line
'''
height, width = xs.shape[:2]
square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[
1])
square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[
1])
square_distance = np.square(point_1[0] - point_2[0]) + np.square(
point_1[1] - point_2[1])
cosin = (square_distance - square_distance_1 - square_distance_2) / (
2 * np.sqrt(square_distance_1 * square_distance_2))
square_sin = 1 - np.square(cosin)
square_sin = np.nan_to_num(square_sin)
result = np.sqrt(square_distance_1 * square_distance_2 * square_sin /
square_distance)
result[cosin <
0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin
< 0]
# self.extend_line(point_1, point_2, result)
return result
def extend_line(self, point_1, point_2, result, shrink_ratio):
ex_point_1 = (int(
round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
int(
round(point_1[1] + (point_1[1] - point_2[1]) * (
1 + shrink_ratio))))
cv2.line(
result,
tuple(ex_point_1),
tuple(point_1),
4096.0,
1,
lineType=cv2.LINE_AA,
shift=0)
ex_point_2 = (int(
round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
int(
round(point_2[1] + (point_2[1] - point_1[1]) * (
1 + shrink_ratio))))
cv2.line(
result,
tuple(ex_point_2),
tuple(point_2),
4096.0,
1,
lineType=cv2.LINE_AA,
shift=0)
return ex_point_1, ex_point_2

View File

@@ -1,106 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import cv2
import numpy as np
import pyclipper
from shapely.geometry import Polygon
__all__ = ['MakePseGt']
class MakePseGt(object):
def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
self.kernel_num = kernel_num
self.min_shrink_ratio = min_shrink_ratio
self.size = size
def __call__(self, data):
image = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
h, w, _ = image.shape
short_edge = min(h, w)
if short_edge < self.size:
# keep short_size >= self.size
scale = self.size / short_edge
image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
text_polys *= scale
gt_kernels = []
for i in range(1, self.kernel_num + 1):
# s1->sn, from big to small
rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1
) * i
text_kernel, ignore_tags = self.generate_kernel(
image.shape[0:2], rate, text_polys, ignore_tags)
gt_kernels.append(text_kernel)
training_mask = np.ones(image.shape[0:2], dtype='uint8')
for i in range(text_polys.shape[0]):
if ignore_tags[i]:
cv2.fillPoly(training_mask,
text_polys[i].astype(np.int32)[np.newaxis, :, :],
0)
gt_kernels = np.array(gt_kernels)
gt_kernels[gt_kernels > 0] = 1
data['image'] = image
data['polys'] = text_polys
data['gt_kernels'] = gt_kernels[0:]
data['gt_text'] = gt_kernels[0]
data['mask'] = training_mask.astype('float32')
return data
def generate_kernel(self,
img_size,
shrink_ratio,
text_polys,
ignore_tags=None):
"""
Refer to part of the code:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py
"""
h, w = img_size
text_kernel = np.zeros((h, w), dtype=np.float32)
for i, poly in enumerate(text_polys):
polygon = Polygon(poly)
distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (
polygon.length + 1e-6)
subject = [tuple(l) for l in poly]
pco = pyclipper.PyclipperOffset()
pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrinked = np.array(pco.Execute(-distance))
if len(shrinked) == 0 or shrinked.size == 0:
if ignore_tags is not None:
ignore_tags[i] = True
continue
try:
shrinked = np.array(shrinked[0]).reshape(-1, 2)
except:
if ignore_tags is not None:
ignore_tags[i] = True
continue
cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
return text_kernel, ignore_tags

View File

@@ -1,123 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_shrink_map.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
from shapely.geometry import Polygon
import pyclipper
__all__ = ['MakeShrinkMap']
class MakeShrinkMap(object):
r'''
Making binary mask from detection data with ICDAR format.
Typically following the process of class `MakeICDARData`.
'''
def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
self.min_text_size = min_text_size
self.shrink_ratio = shrink_ratio
def __call__(self, data):
image = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
h, w = image.shape[:2]
text_polys, ignore_tags = self.validate_polygons(text_polys,
ignore_tags, h, w)
gt = np.zeros((h, w), dtype=np.float32)
mask = np.ones((h, w), dtype=np.float32)
for i in range(len(text_polys)):
polygon = text_polys[i]
height = max(polygon[:, 1]) - min(polygon[:, 1])
width = max(polygon[:, 0]) - min(polygon[:, 0])
if ignore_tags[i] or min(height, width) < self.min_text_size:
cv2.fillPoly(mask,
polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
else:
polygon_shape = Polygon(polygon)
subject = [tuple(l) for l in polygon]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND,
pyclipper.ET_CLOSEDPOLYGON)
shrinked = []
# Increase the shrink ratio every time we get multiple polygon returned back
possible_ratios = np.arange(self.shrink_ratio, 1,
self.shrink_ratio)
np.append(possible_ratios, 1)
# print(possible_ratios)
for ratio in possible_ratios:
# print(f"Change shrink ratio to {ratio}")
distance = polygon_shape.area * (
1 - np.power(ratio, 2)) / polygon_shape.length
shrinked = padding.Execute(-distance)
if len(shrinked) == 1:
break
if shrinked == []:
cv2.fillPoly(mask,
polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
continue
for each_shirnk in shrinked:
shirnk = np.array(each_shirnk).reshape(-1, 2)
cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
data['shrink_map'] = gt
data['shrink_mask'] = mask
return data
def validate_polygons(self, polygons, ignore_tags, h, w):
'''
polygons (numpy.array, required): of shape (num_instances, num_points, 2)
'''
if len(polygons) == 0:
return polygons, ignore_tags
assert len(polygons) == len(ignore_tags)
for polygon in polygons:
polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
for i in range(len(polygons)):
area = self.polygon_area(polygons[i])
if abs(area) < 1:
ignore_tags[i] = True
if area > 0:
polygons[i] = polygons[i][::-1, :]
return polygons, ignore_tags
def polygon_area(self, polygon):
"""
compute polygon area
"""
area = 0
q = polygon[-1]
for p in polygon:
area += p[0] * q[1] - p[1] * q[0]
q = p
return area / 2.0

View File

@@ -1,468 +0,0 @@
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
import cv2
import numpy as np
import math
class DecodeImage(object):
""" decode image """
def __init__(self,
img_mode='RGB',
channel_first=False,
ignore_orientation=False,
**kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
self.ignore_orientation = ignore_orientation
def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
if self.ignore_orientation:
img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
cv2.IMREAD_COLOR)
else:
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class NRTRDecodeImage(object):
""" decode image """
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
return data
class ToCHWImage(object):
""" convert hwc image to chw image
"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
data['image'] = img.transpose((2, 0, 1))
return data
class Fasttext(object):
def __init__(self, path="None", **kwargs):
import fasttext
self.fast_model = fasttext.load_model(path)
def __call__(self, data):
label = data['label']
fast_label = self.fast_model[label]
data['fast_label'] = fast_label
return data
class KeepKeys(object):
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
class Pad(object):
def __init__(self, size=None, size_div=32, **kwargs):
if size is not None and not isinstance(size, (int, list, tuple)):
raise TypeError("Type of target_size is invalid. Now is {}".format(
type(size)))
if isinstance(size, int):
size = [size, size]
self.size = size
self.size_div = size_div
def __call__(self, data):
img = data['image']
img_h, img_w = img.shape[0], img.shape[1]
if self.size:
resize_h2, resize_w2 = self.size
assert (
img_h < resize_h2 and img_w < resize_w2
), '(h, w) of target size should be greater than (img_h, img_w)'
else:
resize_h2 = max(
int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
self.size_div)
resize_w2 = max(
int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
self.size_div)
img = cv2.copyMakeBorder(
img,
0,
resize_h2 - img_h,
0,
resize_w2 - img_w,
cv2.BORDER_CONSTANT,
value=0)
data['image'] = img
return data
class Resize(object):
def __init__(self, size=(640, 640), **kwargs):
self.size = size
def resize_image(self, img):
resize_h, resize_w = self.size
ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
return img, [ratio_h, ratio_w]
def __call__(self, data):
img = data['image']
if 'polys' in data:
text_polys = data['polys']
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
if 'polys' in data:
new_boxes = []
for box in text_polys:
new_box = []
for cord in box:
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
new_boxes.append(new_box)
data['polys'] = np.array(new_boxes, dtype=np.float32)
data['image'] = img_resize
return data
class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape']
self.resize_type = 1
elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
elif 'resize_long' in kwargs:
self.resize_type = 2
self.resize_long = kwargs.get('resize_long', 960)
else:
self.limit_side_len = 736
self.limit_type = 'min'
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.resize_type == 0:
# img, shape = self.resize_image_type0(img)
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
elif self.resize_type == 2:
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
else:
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
data['image'] = img
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
# return img, np.array([ori_h, ori_w])
return img, [ratio_h, ratio_w]
def resize_image_type0(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'min':
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h, w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = max(int(round(resize_w / 32) * 32), 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except:
print(img.shape, resize_w, resize_h)
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
def resize_image_type2(self, img):
h, w, _ = img.shape
resize_w = w
resize_h = h
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
ratio = float(self.resize_long) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
img = cv2.resize(img, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
class E2EResizeForTest(object):
def __init__(self, **kwargs):
super(E2EResizeForTest, self).__init__()
self.max_side_len = kwargs['max_side_len']
self.valid_set = kwargs['valid_set']
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.valid_set == 'totaltext':
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
img, max_side_len=self.max_side_len)
else:
im_resized, (ratio_h, ratio_w) = self.resize_image(
img, max_side_len=self.max_side_len)
data['image'] = im_resized
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_for_totaltext(self, im, max_side_len=512):
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image(self, im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
class KieResize(object):
def __init__(self, **kwargs):
super(KieResize, self).__init__()
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
'img_scale'][1]
def __call__(self, data):
img = data['image']
points = data['points']
src_h, src_w, _ = img.shape
im_resized, scale_factor, [ratio_h, ratio_w
], [new_h, new_w] = self.resize_image(img)
resize_points = self.resize_boxes(img, points, scale_factor)
data['ori_image'] = img
data['ori_boxes'] = points
data['points'] = resize_points
data['image'] = im_resized
data['shape'] = np.array([new_h, new_w])
return data
def resize_image(self, img):
norm_img = np.zeros([1024, 1024, 3], dtype='float32')
scale = [512, 1024]
h, w = img.shape[:2]
max_long_edge = max(scale)
max_short_edge = min(scale)
scale_factor = min(max_long_edge / max(h, w),
max_short_edge / min(h, w))
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
scale_factor) + 0.5)
max_stride = 32
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(img, (resize_w, resize_h))
new_h, new_w = im.shape[:2]
w_scale = new_w / w
h_scale = new_h / h
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
norm_img[:new_h, :new_w, :] = im
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
def resize_boxes(self, im, points, scale_factor):
points = points * scale_factor
img_shape = im.shape[:2]
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
return points

View File

@@ -1,906 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import cv2
import numpy as np
__all__ = ['PGProcessTrain']
class PGProcessTrain(object):
def __init__(self,
character_dict_path,
max_text_length,
max_text_nums,
tcl_len,
batch_size=14,
min_crop_size=24,
min_text_size=4,
max_text_size=512,
**kwargs):
self.tcl_len = tcl_len
self.max_text_length = max_text_length
self.max_text_nums = max_text_nums
self.batch_size = batch_size
self.min_crop_size = min_crop_size
self.min_text_size = min_text_size
self.max_text_size = max_text_size
self.Lexicon_Table = self.get_dict(character_dict_path)
self.pad_num = len(self.Lexicon_Table)
self.img_id = 0
def get_dict(self, character_dict_path):
character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
character_str += line
dict_character = list(character_str)
return dict_character
def quad_area(self, poly):
"""
compute area of a polygon
:param poly:
:return:
"""
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
return np.sum(edge) / 2.
def gen_quad_from_poly(self, poly):
"""
Generate min area quad from poly.
"""
point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32)
rect = cv2.minAreaRect(poly.astype(
np.int32)) # (center (x,y), (width, height), angle of rotation)
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
if dist < min_dist:
min_dist = dist
first_point_idx = i
for i in range(4):
min_area_quad[i] = box[(first_point_idx + i) % 4]
return min_area_quad
def check_and_validate_polys(self, polys, tags, im_size):
"""
check so that the text poly is in the same direction,
and also filter some invalid polygons
:param polys:
:param tags:
:return:
"""
(h, w) = im_size
if polys.shape[0] == 0:
return polys, np.array([]), np.array([])
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
validated_polys = []
validated_tags = []
hv_tags = []
for poly, tag in zip(polys, tags):
quad = self.gen_quad_from_poly(poly)
p_area = self.quad_area(quad)
if abs(p_area) < 1:
print('invalid poly')
continue
if p_area > 0:
if tag == False:
print('poly in wrong direction')
tag = True # reversed cases should be ignore
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
1), :]
quad = quad[(0, 3, 2, 1), :]
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
quad[2])
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
quad[2])
hv_tag = 1
if len_w * 2.0 < len_h:
hv_tag = 0
validated_polys.append(poly)
validated_tags.append(tag)
hv_tags.append(hv_tag)
return np.array(validated_polys), np.array(validated_tags), np.array(
hv_tags)
def crop_area(self,
im,
polys,
tags,
hv_tags,
txts,
crop_background=False,
max_tries=25):
"""
make random crop from the input image
:param im:
:param polys: [b,4,2]
:param tags:
:param crop_background:
:param max_tries: 50 -> 25
:return:
"""
h, w, _ = im.shape
pad_h = h // 10
pad_w = w // 10
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
for poly in polys:
poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
w_array[minx + pad_w:maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
h_array[miny + pad_h:maxy + pad_h] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return im, polys, tags, hv_tags, txts
for i in range(max_tries):
xx = np.random.choice(w_axis, size=2)
xmin = np.min(xx) - pad_w
xmax = np.max(xx) - pad_w
xmin = np.clip(xmin, 0, w - 1)
xmax = np.clip(xmax, 0, w - 1)
yy = np.random.choice(h_axis, size=2)
ymin = np.min(yy) - pad_h
ymax = np.max(yy) - pad_h
ymin = np.clip(ymin, 0, h - 1)
ymax = np.clip(ymax, 0, h - 1)
if xmax - xmin < self.min_crop_size or \
ymax - ymin < self.min_crop_size:
continue
if polys.shape[0] != 0:
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
selected_polys = np.where(
np.sum(poly_axis_in_area, axis=1) == 4)[0]
else:
selected_polys = []
if len(selected_polys) == 0:
# no text in this area
if crop_background:
txts_tmp = []
for selected_poly in selected_polys:
txts_tmp.append(txts[selected_poly])
txts = txts_tmp
return im[ymin: ymax + 1, xmin: xmax + 1, :], \
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
else:
continue
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
polys = polys[selected_polys]
tags = tags[selected_polys]
hv_tags = hv_tags[selected_polys]
txts_tmp = []
for selected_poly in selected_polys:
txts_tmp.append(txts[selected_poly])
txts = txts_tmp
polys[:, :, 0] -= xmin
polys[:, :, 1] -= ymin
return im, polys, tags, hv_tags, txts
return im, polys, tags, hv_tags, txts
def fit_and_gather_tcl_points_v2(self,
min_area_quad,
poly,
max_h,
max_w,
fixed_point_num=64,
img_id=0,
reference_height=3):
"""
Find the center point of poly as key_points, then fit and gather.
"""
key_point_xys = []
point_num = poly.shape[0]
for idx in range(point_num // 2):
center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0
key_point_xys.append(center_point)
tmp_image = np.zeros(
shape=(
max_h,
max_w, ), dtype='float32')
cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')],
False, 1.0)
ys, xs = np.where(tmp_image > 0)
xy_text = np.array(list(zip(xs, ys)), dtype='float32')
left_center_pt = (
(min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
right_center_pt = (
(min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
proj_unit_vec = (right_center_pt - left_center_pt) / (
np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
proj_unit_vec_tile = np.tile(proj_unit_vec,
(xy_text.shape[0], 1)) # (n, 2)
left_center_pt_tile = np.tile(left_center_pt,
(xy_text.shape[0], 1)) # (n, 2)
xy_text_to_left_center = xy_text - left_center_pt_tile
proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1)
xy_text = xy_text[np.argsort(proj_value)]
# convert to np and keep the num of point not greater then fixed_point_num
pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx
point_num = len(pos_info)
if point_num > fixed_point_num:
keep_ids = [
int((point_num * 1.0 / fixed_point_num) * x)
for x in range(fixed_point_num)
]
pos_info = pos_info[keep_ids, :]
keep = int(min(len(pos_info), fixed_point_num))
if np.random.rand() < 0.2 and reference_height >= 3:
dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape(
[keep, 1])
pos_info += random_float
pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
# padding to fixed length
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
pos_m[:keep] = 1.0
return pos_l, pos_m
def generate_direction_map(self, poly_quads, n_char, direction_map):
"""
"""
width_list = []
height_list = []
for quad in poly_quads:
quad_w = (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3])) / 2.0
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[2] - quad[1])) / 2.0
width_list.append(quad_w)
height_list.append(quad_h)
norm_width = max(sum(width_list) / n_char, 1.0)
average_height = max(sum(height_list) / len(height_list), 1.0)
k = 1
for quad in poly_quads:
direct_vector_full = (
(quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
direct_vector = direct_vector_full / (
np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
direction_label = tuple(
map(float,
[direct_vector[0], direct_vector[1], 1.0 / average_height]))
cv2.fillPoly(direction_map,
quad.round().astype(np.int32)[np.newaxis, :, :],
direction_label)
k += 1
return direction_map
def calculate_average_height(self, poly_quads):
"""
"""
height_list = []
for quad in poly_quads:
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[2] - quad[1])) / 2.0
height_list.append(quad_h)
average_height = max(sum(height_list) / len(height_list), 1.0)
return average_height
def generate_tcl_ctc_label(self,
h,
w,
polys,
tags,
text_strs,
ds_ratio,
tcl_ratio=0.3,
shrink_ratio_of_width=0.15):
"""
Generate polygon.
"""
score_map_big = np.zeros(
(
h,
w, ), dtype=np.float32)
h, w = int(h * ds_ratio), int(w * ds_ratio)
polys = polys * ds_ratio
score_map = np.zeros(
(
h,
w, ), dtype=np.float32)
score_label_map = np.zeros(
(
h,
w, ), dtype=np.float32)
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
training_mask = np.ones(
(
h,
w, ), dtype=np.float32)
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
[1, 1, 3]).astype(np.float32)
label_idx = 0
score_label_map_text_label_list = []
pos_list, pos_mask, label_list = [], [], []
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
poly = poly_tag[0]
tag = poly_tag[1]
# generate min_area_quad
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
min_area_quad_h = 0.5 * (
np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
min_area_quad_w = 0.5 * (
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
continue
if tag:
cv2.fillPoly(training_mask,
poly.astype(np.int32)[np.newaxis, :, :], 0.15)
else:
text_label = text_strs[poly_idx]
text_label = self.prepare_text_label(text_label,
self.Lexicon_Table)
text_label_index_list = [[self.Lexicon_Table.index(c_)]
for c_ in text_label
if c_ in self.Lexicon_Table]
if len(text_label_index_list) < 1:
continue
tcl_poly = self.poly2tcl(poly, tcl_ratio)
tcl_quads = self.poly2quads(tcl_poly)
poly_quads = self.poly2quads(poly)
stcl_quads, quad_index = self.shrink_poly_along_width(
tcl_quads,
shrink_ratio_of_width=shrink_ratio_of_width,
expand_height_ratio=1.0 / tcl_ratio)
cv2.fillPoly(score_map,
np.round(stcl_quads).astype(np.int32), 1.0)
cv2.fillPoly(score_map_big,
np.round(stcl_quads / ds_ratio).astype(np.int32),
1.0)
for idx, quad in enumerate(stcl_quads):
quad_mask = np.zeros((h, w), dtype=np.float32)
quad_mask = cv2.fillPoly(
quad_mask,
np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
quad_mask, tbo_map)
# score label map and score_label_map_text_label_list for refine
if label_idx == 0:
text_pos_list_ = [[len(self.Lexicon_Table)], ]
score_label_map_text_label_list.append(text_pos_list_)
label_idx += 1
cv2.fillPoly(score_label_map,
np.round(poly_quads).astype(np.int32), label_idx)
score_label_map_text_label_list.append(text_label_index_list)
# direction info, fix-me
n_char = len(text_label_index_list)
direction_map = self.generate_direction_map(poly_quads, n_char,
direction_map)
# pos info
average_shrink_height = self.calculate_average_height(
stcl_quads)
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
min_area_quad,
poly,
max_h=h,
max_w=w,
fixed_point_num=64,
img_id=self.img_id,
reference_height=average_shrink_height)
label_l = text_label_index_list
if len(text_label_index_list) < 2:
continue
pos_list.append(pos_l)
pos_mask.append(pos_m)
label_list.append(label_l)
# use big score_map for smooth tcl lines
score_map_big_resized = cv2.resize(
score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio)
score_map = np.array(score_map_big_resized > 1e-3, dtype='float32')
return score_map, score_label_map, tbo_map, direction_map, training_mask, \
pos_list, pos_mask, label_list, score_label_map_text_label_list
def adjust_point(self, poly):
"""
adjust point order.
"""
point_num = poly.shape[0]
if point_num == 4:
len_1 = np.linalg.norm(poly[0] - poly[1])
len_2 = np.linalg.norm(poly[1] - poly[2])
len_3 = np.linalg.norm(poly[2] - poly[3])
len_4 = np.linalg.norm(poly[3] - poly[0])
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
poly = poly[[1, 2, 3, 0], :]
elif point_num > 4:
vector_1 = poly[0] - poly[1]
vector_2 = poly[1] - poly[2]
cos_theta = np.dot(vector_1, vector_2) / (
np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
theta = np.arccos(np.round(cos_theta, decimals=4))
if abs(theta) > (70 / 180 * math.pi):
index = list(range(1, point_num)) + [0]
poly = poly[np.array(index), :]
return poly
def gen_min_area_quad_from_poly(self, poly):
"""
Generate min area quad from poly.
"""
point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32)
if point_num == 4:
min_area_quad = poly
center_point = np.sum(poly, axis=0) / 4
else:
rect = cv2.minAreaRect(poly.astype(
np.int32)) # (center (x,y), (width, height), angle of rotation)
center_point = rect[0]
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
if dist < min_dist:
min_dist = dist
first_point_idx = i
for i in range(4):
min_area_quad[i] = box[(first_point_idx + i) % 4]
return min_area_quad, center_point
def shrink_quad_along_width(self,
quad,
begin_width_ratio=0.,
end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def shrink_poly_along_width(self,
quads,
shrink_ratio_of_width,
expand_height_ratio=1.0):
"""
shrink poly with given length.
"""
upper_edge_list = []
def get_cut_info(edge_len_list, cut_len):
for idx, edge_len in enumerate(edge_len_list):
cut_len -= edge_len
if cut_len <= 0.000001:
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
return idx, ratio
for quad in quads:
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
upper_edge_list.append(upper_edge_len)
# length of left edge and right edge.
left_length = np.linalg.norm(quads[0][0] - quads[0][
3]) * expand_height_ratio
right_length = np.linalg.norm(quads[-1][1] - quads[-1][
2]) * expand_height_ratio
shrink_length = min(left_length, right_length,
sum(upper_edge_list)) * shrink_ratio_of_width
# shrinking length
upper_len_left = shrink_length
upper_len_right = sum(upper_edge_list) - shrink_length
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
left_quad = self.shrink_quad_along_width(
quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
right_quad = self.shrink_quad_along_width(
quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
out_quad_list = []
if left_idx == right_idx:
out_quad_list.append(
[left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
else:
out_quad_list.append(left_quad)
for idx in range(left_idx + 1, right_idx):
out_quad_list.append(quads[idx])
out_quad_list.append(right_quad)
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
def prepare_text_label(self, label_str, Lexicon_Table):
"""
Prepare text lablel by given Lexicon_Table.
"""
if len(Lexicon_Table) == 36:
return label_str.lower()
else:
return label_str
def vector_angle(self, A, B):
"""
Calculate the angle between vector AB and x-axis positive direction.
"""
AB = np.array([B[1] - A[1], B[0] - A[0]])
return np.arctan2(*AB)
def theta_line_cross_point(self, theta, point):
"""
Calculate the line through given point and angle in ax + by + c =0 form.
"""
x, y = point
cos = np.cos(theta)
sin = np.sin(theta)
return [sin, -cos, cos * y - sin * x]
def line_cross_two_point(self, A, B):
"""
Calculate the line through given point A and B in ax + by + c =0 form.
"""
angle = self.vector_angle(A, B)
return self.theta_line_cross_point(angle, A)
def average_angle(self, poly):
"""
Calculate the average angle between left and right edge in given poly.
"""
p0, p1, p2, p3 = poly
angle30 = self.vector_angle(p3, p0)
angle21 = self.vector_angle(p2, p1)
return (angle30 + angle21) / 2
def line_cross_point(self, line1, line2):
"""
line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
"""
a1, b1, c1 = line1
a2, b2, c2 = line2
d = a1 * b2 - a2 * b1
if d == 0:
print('Cross point does not exist')
return np.array([0, 0], dtype=np.float32)
else:
x = (b1 * c2 - b2 * c1) / d
y = (a2 * c1 - a1 * c2) / d
return np.array([x, y], dtype=np.float32)
def quad2tcl(self, poly, ratio):
"""
Generate center line by poly clock-wise point. (4, 2)
"""
ratio_pair = np.array(
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
def poly2tcl(self, poly, ratio):
"""
Generate center line by poly clock-wise point.
"""
ratio_pair = np.array(
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
tcl_poly = np.zeros_like(poly)
point_num = poly.shape[0]
for idx in range(point_num // 2):
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
) * ratio_pair
tcl_poly[idx] = point_pair[0]
tcl_poly[point_num - 1 - idx] = point_pair[1]
return tcl_poly
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
"""
Generate tbo_map for give quad.
"""
# upper and lower line function: ax + by + c = 0;
up_line = self.line_cross_two_point(quad[0], quad[1])
lower_line = self.line_cross_two_point(quad[3], quad[2])
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[1] - quad[2]))
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3]))
# average angle of left and right line.
angle = self.average_angle(quad)
xy_in_poly = np.argwhere(tcl_mask == 1)
for y, x in xy_in_poly:
point = (x, y)
line = self.theta_line_cross_point(angle, point)
cross_point_upper = self.line_cross_point(up_line, line)
cross_point_lower = self.line_cross_point(lower_line, line)
##FIX, offset reverse
upper_offset_x, upper_offset_y = cross_point_upper - point
lower_offset_x, lower_offset_y = cross_point_lower - point
tbo_map[y, x, 0] = upper_offset_y
tbo_map[y, x, 1] = upper_offset_x
tbo_map[y, x, 2] = lower_offset_y
tbo_map[y, x, 3] = lower_offset_x
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
return tbo_map
def poly2quads(self, poly):
"""
Split poly into quads.
"""
quad_list = []
point_num = poly.shape[0]
# point pair
point_pair_list = []
for idx in range(point_num // 2):
point_pair = [poly[idx], poly[point_num - 1 - idx]]
point_pair_list.append(point_pair)
quad_num = point_num // 2 - 1
for idx in range(quad_num):
# reshape and adjust to clock-wise
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
).reshape(4, 2)[[0, 2, 3, 1]])
return np.array(quad_list)
def rotate_im_poly(self, im, text_polys):
"""
rotate image with 90 / 180 / 270 degre
"""
im_w, im_h = im.shape[1], im.shape[0]
dst_im = im.copy()
dst_polys = []
rand_degree_ratio = np.random.rand()
rand_degree_cnt = 1
if rand_degree_ratio > 0.5:
rand_degree_cnt = 3
for i in range(rand_degree_cnt):
dst_im = np.rot90(dst_im)
rot_degree = -90 * rand_degree_cnt
rot_angle = rot_degree * math.pi / 180.0
n_poly = text_polys.shape[0]
cx, cy = 0.5 * im_w, 0.5 * im_h
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
for i in range(n_poly):
wordBB = text_polys[i]
poly = []
for j in range(4): # 16->4
sx, sy = wordBB[j][0], wordBB[j][1]
dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (
sy - cy) + ncx
dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (
sy - cy) + ncy
poly.append([dx, dy])
dst_polys.append(poly)
return dst_im, np.array(dst_polys, dtype=np.float32)
def __call__(self, data):
input_size = 512
im = data['image']
text_polys = data['polys']
text_tags = data['ignore_tags']
text_strs = data['texts']
h, w, _ = im.shape
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
text_polys, text_tags, (h, w))
if text_polys.shape[0] <= 0:
return None
# set aspect ratio and keep area fix
asp_scales = np.arange(1.0, 1.55, 0.1)
asp_scale = np.random.choice(asp_scales)
if np.random.rand() < 0.5:
asp_scale = 1.0 / asp_scale
asp_scale = math.sqrt(asp_scale)
asp_wx = asp_scale
asp_hy = 1.0 / asp_scale
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
text_polys[:, :, 0] *= asp_wx
text_polys[:, :, 1] *= asp_hy
h, w, _ = im.shape
if max(h, w) > 2048:
rd_scale = 2048.0 / max(h, w)
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
text_polys *= rd_scale
h, w, _ = im.shape
if min(h, w) < 16:
return None
# no background
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
im,
text_polys,
text_tags,
hv_tags,
text_strs,
crop_background=False)
if text_polys.shape[0] == 0:
return None
# # continue for all ignore case
if np.sum((text_tags * 1.0)) >= text_tags.size:
return None
new_h, new_w, _ = im.shape
if (new_h is None) or (new_w is None):
return None
# resize image
std_ratio = float(input_size) / max(new_w, new_h)
rand_scales = np.array(
[0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
rz_scale = std_ratio * np.random.choice(rand_scales)
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
text_polys[:, :, 0] *= rz_scale
text_polys[:, :, 1] *= rz_scale
# add gaussian blur
if np.random.rand() < 0.1 * 0.5:
ks = np.random.permutation(5)[0] + 1
ks = int(ks / 2) * 2 + 1
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
# add brighter
if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 + np.random.rand() * 0.5)
im = np.clip(im, 0.0, 255.0)
# add darker
if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 - np.random.rand() * 0.5)
im = np.clip(im, 0.0, 255.0)
# Padding the im to [input_size, input_size]
new_h, new_w, _ = im.shape
if min(new_w, new_h) < input_size * 0.5:
return None
im_padded = np.ones((input_size, input_size, 3), dtype=np.float32)
im_padded[:, :, 2] = 0.485 * 255
im_padded[:, :, 1] = 0.456 * 255
im_padded[:, :, 0] = 0.406 * 255
# Random the start position
del_h = input_size - new_h
del_w = input_size - new_w
sh, sw = 0, 0
if del_h > 1:
sh = int(np.random.rand() * del_h)
if del_w > 1:
sw = int(np.random.rand() * del_w)
# Padding
im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
text_polys[:, :, 0] += sw
text_polys[:, :, 1] += sh
score_map, score_label_map, border_map, direction_map, training_mask, \
pos_list, pos_mask, label_list, score_label_map_text_label = self.generate_tcl_ctc_label(input_size,
input_size,
text_polys,
text_tags,
text_strs, 0.25)
if len(label_list) <= 0: # eliminate negative samples
return None
pos_list_temp = np.zeros([64, 3])
pos_mask_temp = np.zeros([64, 1])
label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num
for i, label in enumerate(label_list):
n = len(label)
if n > self.max_text_length:
label_list[i] = label[:self.max_text_length]
continue
while n < self.max_text_length:
label.append([self.pad_num])
n += 1
for i in range(len(label_list)):
label_list[i] = np.array(label_list[i])
if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums:
return None
for __ in range(self.max_text_nums - len(pos_list), 0, -1):
pos_list.append(pos_list_temp)
pos_mask.append(pos_mask_temp)
label_list.append(label_list_temp)
if self.img_id == self.batch_size - 1:
self.img_id = 0
else:
self.img_id += 1
im_padded[:, :, 2] -= 0.485 * 255
im_padded[:, :, 1] -= 0.456 * 255
im_padded[:, :, 0] -= 0.406 * 255
im_padded[:, :, 2] /= (255.0 * 0.229)
im_padded[:, :, 1] /= (255.0 * 0.224)
im_padded[:, :, 0] /= (255.0 * 0.225)
im_padded = im_padded.transpose((2, 0, 1))
images = im_padded[::-1, :, :]
tcl_maps = score_map[np.newaxis, :, :]
tcl_label_maps = score_label_map[np.newaxis, :, :]
border_maps = border_map.transpose((2, 0, 1))
direction_maps = direction_map.transpose((2, 0, 1))
training_masks = training_mask[np.newaxis, :, :]
pos_list = np.array(pos_list)
pos_mask = np.array(pos_mask)
label_list = np.array(label_list)
data['images'] = images
data['tcl_maps'] = tcl_maps
data['tcl_label_maps'] = tcl_label_maps
data['border_maps'] = border_maps
data['direction_maps'] = direction_maps
data['training_masks'] = training_masks
data['label_list'] = label_list
data['pos_list'] = pos_list
data['pos_mask'] = pos_mask
return data

View File

@@ -1,143 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
import six
class RawRandAugment(object):
def __init__(self,
num_layers=2,
magnitude=5,
fillcolor=(128, 128, 128),
**kwargs):
self.num_layers = num_layers
self.magnitude = magnitude
self.max_level = 10
abso_level = self.magnitude / self.max_level
self.level_map = {
"shearX": 0.3 * abso_level,
"shearY": 0.3 * abso_level,
"translateX": 150.0 / 331 * abso_level,
"translateY": 150.0 / 331 * abso_level,
"rotate": 30 * abso_level,
"color": 0.9 * abso_level,
"posterize": int(4.0 * abso_level),
"solarize": 256.0 * abso_level,
"contrast": 0.9 * abso_level,
"sharpness": 0.9 * abso_level,
"brightness": 0.9 * abso_level,
"autocontrast": 0,
"equalize": 0,
"invert": 0
}
# from https://stackoverflow.com/questions/5252170/
# specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot,
Image.new("RGBA", rot.size, (128, ) * 4),
rot).convert(img.mode)
rnd_ch_op = random.choice
self.func = {
"shearX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC,
fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
Image.BICUBIC,
fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0),
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])),
fillcolor=fillcolor),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"posterize": lambda img, magnitude:
ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude:
ImageOps.solarize(img, magnitude),
"contrast": lambda img, magnitude:
ImageEnhance.Contrast(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"sharpness": lambda img, magnitude:
ImageEnhance.Sharpness(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"brightness": lambda img, magnitude:
ImageEnhance.Brightness(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"autocontrast": lambda img, magnitude:
ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img)
}
def __call__(self, img):
avaiable_op_names = list(self.level_map.keys())
for layer_num in range(self.num_layers):
op_name = np.random.choice(avaiable_op_names)
img = self.func[op_name](img, self.level_map[op_name])
return img
class RandAugment(RawRandAugment):
""" RandAugment wrapper to auto fit different img types """
def __init__(self, prob=0.5, *args, **kwargs):
self.prob = prob
if six.PY2:
super(RandAugment, self).__init__(*args, **kwargs)
else:
super().__init__(*args, **kwargs)
def __call__(self, data):
if np.random.rand() > self.prob:
return data
img = data['image']
if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
if six.PY2:
img = super(RandAugment, self).__call__(img)
else:
img = super().__call__(img)
if isinstance(img, Image.Image):
img = np.asarray(img)
data['image'] = img
return data

View File

@@ -1,234 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
import random
def is_poly_in_rect(poly, x, y, w, h):
poly = np.array(poly)
if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
return False
if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
return False
return True
def is_poly_outside_rect(poly, x, y, w, h):
poly = np.array(poly)
if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
return True
if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
return True
return False
def split_regions(axis):
regions = []
min_axis = 0
for i in range(1, axis.shape[0]):
if axis[i] != axis[i - 1] + 1:
region = axis[min_axis:i]
min_axis = i
regions.append(region)
return regions
def random_select(axis, max_size):
xx = np.random.choice(axis, size=2)
xmin = np.min(xx)
xmax = np.max(xx)
xmin = np.clip(xmin, 0, max_size - 1)
xmax = np.clip(xmax, 0, max_size - 1)
return xmin, xmax
def region_wise_random_select(regions, max_size):
selected_index = list(np.random.choice(len(regions), 2))
selected_values = []
for index in selected_index:
axis = regions[index]
xx = int(np.random.choice(axis, size=1))
selected_values.append(xx)
xmin = min(selected_values)
xmax = max(selected_values)
return xmin, xmax
def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
h, w, _ = im.shape
h_array = np.zeros(h, dtype=np.int32)
w_array = np.zeros(w, dtype=np.int32)
for points in text_polys:
points = np.round(points, decimals=0).astype(np.int32)
minx = np.min(points[:, 0])
maxx = np.max(points[:, 0])
w_array[minx:maxx] = 1
miny = np.min(points[:, 1])
maxy = np.max(points[:, 1])
h_array[miny:maxy] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return 0, 0, w, h
h_regions = split_regions(h_axis)
w_regions = split_regions(w_axis)
for i in range(max_tries):
if len(w_regions) > 1:
xmin, xmax = region_wise_random_select(w_regions, w)
else:
xmin, xmax = random_select(w_axis, w)
if len(h_regions) > 1:
ymin, ymax = region_wise_random_select(h_regions, h)
else:
ymin, ymax = random_select(h_axis, h)
if xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h:
# area too small
continue
num_poly_in_rect = 0
for poly in text_polys:
if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
ymax - ymin):
num_poly_in_rect += 1
break
if num_poly_in_rect > 0:
return xmin, ymin, xmax - xmin, ymax - ymin
return 0, 0, w, h
class EastRandomCropData(object):
def __init__(self,
size=(640, 640),
max_tries=10,
min_crop_side_ratio=0.1,
keep_ratio=True,
**kwargs):
self.size = size
self.max_tries = max_tries
self.min_crop_side_ratio = min_crop_side_ratio
self.keep_ratio = keep_ratio
def __call__(self, data):
img = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
texts = data['texts']
all_care_polys = [
text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
]
# 计算crop区域
crop_x, crop_y, crop_w, crop_h = crop_area(
img, all_care_polys, self.min_crop_side_ratio, self.max_tries)
# crop 图片 保持比例填充
scale_w = self.size[0] / crop_w
scale_h = self.size[1] / crop_h
scale = min(scale_w, scale_h)
h = int(crop_h * scale)
w = int(crop_w * scale)
if self.keep_ratio:
padimg = np.zeros((self.size[1], self.size[0], img.shape[2]),
img.dtype)
padimg[:h, :w] = cv2.resize(
img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
img = padimg
else:
img = cv2.resize(
img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
tuple(self.size))
# crop 文本框
text_polys_crop = []
ignore_tags_crop = []
texts_crop = []
for poly, text, tag in zip(text_polys, texts, ignore_tags):
poly = ((poly - (crop_x, crop_y)) * scale).tolist()
if not is_poly_outside_rect(poly, 0, 0, w, h):
text_polys_crop.append(poly)
ignore_tags_crop.append(tag)
texts_crop.append(text)
data['image'] = img
data['polys'] = np.array(text_polys_crop)
data['ignore_tags'] = ignore_tags_crop
data['texts'] = texts_crop
return data
class RandomCropImgMask(object):
def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs):
self.size = size
self.main_key = main_key
self.crop_keys = crop_keys
self.p = p
def __call__(self, data):
image = data['image']
h, w = image.shape[0:2]
th, tw = self.size
if w == tw and h == th:
return data
mask = data[self.main_key]
if np.max(mask) > 0 and random.random() > self.p:
# make sure to crop the text region
tl = np.min(np.where(mask > 0), axis=1) - (th, tw)
tl[tl < 0] = 0
br = np.max(np.where(mask > 0), axis=1) - (th, tw)
br[br < 0] = 0
br[0] = min(br[0], h - th)
br[1] = min(br[1], w - tw)
i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
else:
i = random.randint(0, h - th) if h - th > 0 else 0
j = random.randint(0, w - tw) if w - tw > 0 else 0
# return i, j, th, tw
for k in data:
if k in self.crop_keys:
if len(data[k].shape) == 3:
if np.argmin(data[k].shape) == 0:
img = data[k][:, i:i + th, j:j + tw]
if img.shape[1] != img.shape[2]:
a = 1
elif np.argmin(data[k].shape) == 2:
img = data[k][i:i + th, j:j + tw, :]
if img.shape[1] != img.shape[0]:
a = 1
else:
img = data[k]
else:
img = data[k][i:i + th, j:j + tw]
if img.shape[0] != img.shape[1]:
a = 1
data[k] = img
return data

View File

@@ -1,601 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import cv2
import numpy as np
import random
import copy
from PIL import Image
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
class RecAug(object):
def __init__(self, use_tia=True, aug_prob=0.4, **kwargs):
self.use_tia = use_tia
self.aug_prob = aug_prob
def __call__(self, data):
img = data['image']
img = warp(img, 10, self.use_tia, self.aug_prob)
data['image'] = img
return data
class RecConAug(object):
def __init__(self,
prob=0.5,
image_shape=(32, 320, 3),
max_text_length=25,
ext_data_num=1,
**kwargs):
self.ext_data_num = ext_data_num
self.prob = prob
self.max_text_length = max_text_length
self.image_shape = image_shape
self.max_wh_ratio = self.image_shape[1] / self.image_shape[0]
def merge_ext_data(self, data, ext_data):
ori_w = round(data['image'].shape[1] / data['image'].shape[0] *
self.image_shape[0])
ext_w = round(ext_data['image'].shape[1] / ext_data['image'].shape[0] *
self.image_shape[0])
data['image'] = cv2.resize(data['image'], (ori_w, self.image_shape[0]))
ext_data['image'] = cv2.resize(ext_data['image'],
(ext_w, self.image_shape[0]))
data['image'] = np.concatenate(
[data['image'], ext_data['image']], axis=1)
data["label"] += ext_data["label"]
return data
def __call__(self, data):
rnd_num = random.random()
if rnd_num > self.prob:
return data
for idx, ext_data in enumerate(data["ext_data"]):
if len(data["label"]) + len(ext_data[
"label"]) > self.max_text_length:
break
concat_ratio = data['image'].shape[1] / data['image'].shape[
0] + ext_data['image'].shape[1] / ext_data['image'].shape[0]
if concat_ratio > self.max_wh_ratio:
break
data = self.merge_ext_data(data, ext_data)
data.pop("ext_data")
return data
class ClsResizeImg(object):
def __init__(self, image_shape, **kwargs):
self.image_shape = image_shape
def __call__(self, data):
img = data['image']
norm_img, _ = resize_norm_img(img, self.image_shape)
data['image'] = norm_img
return data
class NRTRRecResizeImg(object):
def __init__(self, image_shape, resize_type, padding=False, **kwargs):
self.image_shape = image_shape
self.resize_type = resize_type
self.padding = padding
def __call__(self, data):
img = data['image']
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
image_shape = self.image_shape
if self.padding:
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
norm_img = np.expand_dims(resized_image, -1)
norm_img = norm_img.transpose((2, 0, 1))
resized_image = norm_img.astype(np.float32) / 128. - 1.
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
data['image'] = padding_im
return data
if self.resize_type == 'PIL':
image_pil = Image.fromarray(np.uint8(img))
img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
img = np.array(img)
if self.resize_type == 'OpenCV':
img = cv2.resize(img, self.image_shape)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
data['image'] = norm_img.astype(np.float32) / 128. - 1.
return data
class RecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
img = data['image']
if self.infer_mode and self.character_dict_path is not None:
norm_img, valid_ratio = resize_norm_img_chinese(img,
self.image_shape)
else:
norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
self.padding)
data['image'] = norm_img
data['valid_ratio'] = valid_ratio
return data
class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape
self.num_heads = num_heads
self.max_text_length = max_text_length
def __call__(self, data):
img = data['image']
norm_img = resize_norm_img_srn(img, self.image_shape)
data['image'] = norm_img
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
data['encoder_word_pos'] = encoder_word_pos
data['gsrm_word_pos'] = gsrm_word_pos
data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
return data
class SARRecResizeImg(object):
def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
self.image_shape = image_shape
self.width_downsample_ratio = width_downsample_ratio
def __call__(self, data):
img = data['image']
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
img, self.image_shape, self.width_downsample_ratio)
data['image'] = norm_img
data['resized_shape'] = resize_shape
data['pad_shape'] = pad_shape
data['valid_ratio'] = valid_ratio
return data
class PRENResizeImg(object):
def __init__(self, image_shape, **kwargs):
"""
Accroding to original paper's realization, it's a hard resize method here.
So maybe you should optimize it to fit for your task better.
"""
self.dst_h, self.dst_w = image_shape
def __call__(self, data):
img = data['image']
resized_img = cv2.resize(
img, (self.dst_w, self.dst_h), interpolation=cv2.INTER_LINEAR)
resized_img = resized_img.transpose((2, 0, 1)) / 255
resized_img -= 0.5
resized_img /= 0.5
data['image'] = resized_img.astype(np.float32)
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
valid_ratio = 1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor = int(1 / width_downsample_ratio)
# resize
ratio = w / float(h)
resize_w = math.ceil(imgH * ratio)
if resize_w % width_divisor != 0:
resize_w = round(resize_w / width_divisor) * width_divisor
if imgW_min is not None:
resize_w = max(imgW_min, resize_w)
if imgW_max is not None:
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
resized_image = resized_image.astype('float32')
# norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resize_shape = resized_image.shape
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
padding_im[:, :, 0:resize_w] = resized_image
pad_shape = padding_im.shape
return padding_im, resize_shape, pad_shape, valid_ratio
def resize_norm_img(img, image_shape, padding=True):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
if not padding:
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
else:
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
valid_ratio = min(1.0, float(resized_w / imgW))
return padding_im, valid_ratio
def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
max_wh_ratio = imgW * 1.0 / imgH
h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio)
imgW = int(imgH * max_wh_ratio)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
valid_ratio = min(1.0, float(resized_w / imgW))
return padding_im, valid_ratio
def resize_norm_img_srn(img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
[num_heads, 1, 1]) * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
[num_heads, 1, 1]) * [-1e9]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def flag():
"""
flag
"""
return 1 if random.random() > 0.5000001 else -1
def cvtColor(img):
"""
cvtColor
"""
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
delta = 0.001 * random.random() * flag()
hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta)
new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
return new_img
def blur(img):
"""
blur
"""
h, w, _ = img.shape
if h > 10 and w > 10:
return cv2.GaussianBlur(img, (5, 5), 1)
else:
return img
def jitter(img):
"""
jitter
"""
w, h, _ = img.shape
if h > 10 and w > 10:
thres = min(w, h)
s = int(random.random() * thres * 0.01)
src_img = img.copy()
for i in range(s):
img[i:, i:, :] = src_img[:w - i, :h - i, :]
return img
else:
return img
def add_gasuss_noise(image, mean=0, var=0.1):
"""
Gasuss noise
"""
noise = np.random.normal(mean, var**0.5, image.shape)
out = image + 0.5 * noise
out = np.clip(out, 0, 255)
out = np.uint8(out)
return out
def get_crop(image):
"""
random crop
"""
h, w, _ = image.shape
top_min = 1
top_max = 8
top_crop = int(random.randint(top_min, top_max))
top_crop = min(top_crop, h - 1)
crop_img = image.copy()
ratio = random.randint(0, 1)
if ratio:
crop_img = crop_img[top_crop:h, :, :]
else:
crop_img = crop_img[0:h - top_crop, :, :]
return crop_img
class Config:
"""
Config
"""
def __init__(self, use_tia):
self.anglex = random.random() * 30
self.angley = random.random() * 15
self.anglez = random.random() * 10
self.fov = 42
self.r = 0
self.shearx = random.random() * 0.3
self.sheary = random.random() * 0.05
self.borderMode = cv2.BORDER_REPLICATE
self.use_tia = use_tia
def make(self, w, h, ang):
"""
make
"""
self.anglex = random.random() * 5 * flag()
self.angley = random.random() * 5 * flag()
self.anglez = -1 * random.random() * int(ang) * flag()
self.fov = 42
self.r = 0
self.shearx = 0
self.sheary = 0
self.borderMode = cv2.BORDER_REPLICATE
self.w = w
self.h = h
self.perspective = self.use_tia
self.stretch = self.use_tia
self.distort = self.use_tia
self.crop = True
self.affine = False
self.reverse = True
self.noise = True
self.jitter = True
self.blur = True
self.color = True
def rad(x):
"""
rad
"""
return x * np.pi / 180
def get_warpR(config):
"""
get_warpR
"""
anglex, angley, anglez, fov, w, h, r = \
config.anglex, config.angley, config.anglez, config.fov, config.w, config.h, config.r
if w > 69 and w < 112:
anglex = anglex * 1.5
z = np.sqrt(w**2 + h**2) / 2 / np.tan(rad(fov / 2))
# Homogeneous coordinate transformation matrix
rx = np.array([[1, 0, 0, 0],
[0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0], [
0,
-np.sin(rad(anglex)),
np.cos(rad(anglex)),
0,
], [0, 0, 0, 1]], np.float32)
ry = np.array([[np.cos(rad(angley)), 0, np.sin(rad(angley)), 0],
[0, 1, 0, 0], [
-np.sin(rad(angley)),
0,
np.cos(rad(angley)),
0,
], [0, 0, 0, 1]], np.float32)
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0],
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0],
[0, 0, 1, 0], [0, 0, 0, 1]], np.float32)
r = rx.dot(ry).dot(rz)
# generate 4 points
pcenter = np.array([h / 2, w / 2, 0, 0], np.float32)
p1 = np.array([0, 0, 0, 0], np.float32) - pcenter
p2 = np.array([w, 0, 0, 0], np.float32) - pcenter
p3 = np.array([0, h, 0, 0], np.float32) - pcenter
p4 = np.array([w, h, 0, 0], np.float32) - pcenter
dst1 = r.dot(p1)
dst2 = r.dot(p2)
dst3 = r.dot(p3)
dst4 = r.dot(p4)
list_dst = np.array([dst1, dst2, dst3, dst4])
org = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
dst = np.zeros((4, 2), np.float32)
# Project onto the image plane
dst[:, 0] = list_dst[:, 0] * z / (z - list_dst[:, 2]) + pcenter[0]
dst[:, 1] = list_dst[:, 1] * z / (z - list_dst[:, 2]) + pcenter[1]
warpR = cv2.getPerspectiveTransform(org, dst)
dst1, dst2, dst3, dst4 = dst
r1 = int(min(dst1[1], dst2[1]))
r2 = int(max(dst3[1], dst4[1]))
c1 = int(min(dst1[0], dst3[0]))
c2 = int(max(dst2[0], dst4[0]))
try:
ratio = min(1.0 * h / (r2 - r1), 1.0 * w / (c2 - c1))
dx = -c1
dy = -r1
T1 = np.float32([[1., 0, dx], [0, 1., dy], [0, 0, 1.0 / ratio]])
ret = T1.dot(warpR)
except:
ratio = 1.0
T1 = np.float32([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
ret = T1
return ret, (-r1, -c1), ratio, dst
def get_warpAffine(config):
"""
get_warpAffine
"""
anglez = config.anglez
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0],
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32)
return rz
def warp(img, ang, use_tia=True, prob=0.4):
"""
warp
"""
h, w, _ = img.shape
config = Config(use_tia=use_tia)
config.make(w, h, ang)
new_img = img
if config.distort:
img_height, img_width = img.shape[0:2]
if random.random() <= prob and img_height >= 20 and img_width >= 20:
new_img = tia_distort(new_img, random.randint(3, 6))
if config.stretch:
img_height, img_width = img.shape[0:2]
if random.random() <= prob and img_height >= 20 and img_width >= 20:
new_img = tia_stretch(new_img, random.randint(3, 6))
if config.perspective:
if random.random() <= prob:
new_img = tia_perspective(new_img)
if config.crop:
img_height, img_width = img.shape[0:2]
if random.random() <= prob and img_height >= 20 and img_width >= 20:
new_img = get_crop(new_img)
if config.blur:
if random.random() <= prob:
new_img = blur(new_img)
if config.color:
if random.random() <= prob:
new_img = cvtColor(new_img)
if config.jitter:
new_img = jitter(new_img)
if config.noise:
if random.random() <= prob:
new_img = add_gasuss_noise(new_img)
if config.reverse:
if random.random() <= prob:
new_img = 255 - new_img
return new_img

View File

@@ -1,777 +0,0 @@
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
This part code is refered from:
https://github.com/songdejia/EAST/blob/master/data_utils.py
"""
import math
import cv2
import numpy as np
import json
import sys
import os
__all__ = ['SASTProcessTrain']
class SASTProcessTrain(object):
def __init__(self,
image_shape=[512, 512],
min_crop_size=24,
min_crop_side_ratio=0.3,
min_text_size=10,
max_text_size=512,
**kwargs):
self.input_size = image_shape[1]
self.min_crop_size = min_crop_size
self.min_crop_side_ratio = min_crop_side_ratio
self.min_text_size = min_text_size
self.max_text_size = max_text_size
def quad_area(self, poly):
"""
compute area of a polygon
:param poly:
:return:
"""
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
return np.sum(edge) / 2.
def gen_quad_from_poly(self, poly):
"""
Generate min area quad from poly.
"""
point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32)
if True:
rect = cv2.minAreaRect(poly.astype(
np.int32)) # (center (x,y), (width, height), angle of rotation)
center_point = rect[0]
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
if dist < min_dist:
min_dist = dist
first_point_idx = i
for i in range(4):
min_area_quad[i] = box[(first_point_idx + i) % 4]
return min_area_quad
def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
"""
check so that the text poly is in the same direction,
and also filter some invalid polygons
:param polys:
:param tags:
:return:
"""
(h, w) = xxx_todo_changeme
if polys.shape[0] == 0:
return polys, np.array([]), np.array([])
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
validated_polys = []
validated_tags = []
hv_tags = []
for poly, tag in zip(polys, tags):
quad = self.gen_quad_from_poly(poly)
p_area = self.quad_area(quad)
if abs(p_area) < 1:
print('invalid poly')
continue
if p_area > 0:
if tag == False:
print('poly in wrong direction')
tag = True # reversed cases should be ignore
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
1), :]
quad = quad[(0, 3, 2, 1), :]
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
quad[2])
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
quad[2])
hv_tag = 1
if len_w * 2.0 < len_h:
hv_tag = 0
validated_polys.append(poly)
validated_tags.append(tag)
hv_tags.append(hv_tag)
return np.array(validated_polys), np.array(validated_tags), np.array(
hv_tags)
def crop_area(self,
im,
polys,
tags,
hv_tags,
crop_background=False,
max_tries=25):
"""
make random crop from the input image
:param im:
:param polys:
:param tags:
:param crop_background:
:param max_tries: 50 -> 25
:return:
"""
h, w, _ = im.shape
pad_h = h // 10
pad_w = w // 10
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
for poly in polys:
poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
w_array[minx + pad_w:maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
h_array[miny + pad_h:maxy + pad_h] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return im, polys, tags, hv_tags
for i in range(max_tries):
xx = np.random.choice(w_axis, size=2)
xmin = np.min(xx) - pad_w
xmax = np.max(xx) - pad_w
xmin = np.clip(xmin, 0, w - 1)
xmax = np.clip(xmax, 0, w - 1)
yy = np.random.choice(h_axis, size=2)
ymin = np.min(yy) - pad_h
ymax = np.max(yy) - pad_h
ymin = np.clip(ymin, 0, h - 1)
ymax = np.clip(ymax, 0, h - 1)
# if xmax - xmin < ARGS.min_crop_side_ratio * w or \
# ymax - ymin < ARGS.min_crop_side_ratio * h:
if xmax - xmin < self.min_crop_size or \
ymax - ymin < self.min_crop_size:
# area too small
continue
if polys.shape[0] != 0:
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
selected_polys = np.where(
np.sum(poly_axis_in_area, axis=1) == 4)[0]
else:
selected_polys = []
if len(selected_polys) == 0:
# no text in this area
if crop_background:
return im[ymin : ymax + 1, xmin : xmax + 1, :], \
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys]
else:
continue
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
polys = polys[selected_polys]
tags = tags[selected_polys]
hv_tags = hv_tags[selected_polys]
polys[:, :, 0] -= xmin
polys[:, :, 1] -= ymin
return im, polys, tags, hv_tags
return im, polys, tags, hv_tags
def generate_direction_map(self, poly_quads, direction_map):
"""
"""
width_list = []
height_list = []
for quad in poly_quads:
quad_w = (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3])) / 2.0
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[2] - quad[1])) / 2.0
width_list.append(quad_w)
height_list.append(quad_h)
norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)
for quad in poly_quads:
direct_vector_full = (
(quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
direct_vector = direct_vector_full / (
np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
direction_label = tuple(
map(float, [
direct_vector[0], direct_vector[1], 1.0 / (average_height +
1e-6)
]))
cv2.fillPoly(direction_map,
quad.round().astype(np.int32)[np.newaxis, :, :],
direction_label)
return direction_map
def calculate_average_height(self, poly_quads):
"""
"""
height_list = []
for quad in poly_quads:
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[2] - quad[1])) / 2.0
height_list.append(quad_h)
average_height = max(sum(height_list) / len(height_list), 1.0)
return average_height
def generate_tcl_label(self,
hw,
polys,
tags,
ds_ratio,
tcl_ratio=0.3,
shrink_ratio_of_width=0.15):
"""
Generate polygon.
"""
h, w = hw
h, w = int(h * ds_ratio), int(w * ds_ratio)
polys = polys * ds_ratio
score_map = np.zeros(
(
h,
w, ), dtype=np.float32)
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
training_mask = np.ones(
(
h,
w, ), dtype=np.float32)
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
[1, 1, 3]).astype(np.float32)
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
poly = poly_tag[0]
tag = poly_tag[1]
# generate min_area_quad
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
min_area_quad_h = 0.5 * (
np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
min_area_quad_w = 0.5 * (
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
continue
if tag:
# continue
cv2.fillPoly(training_mask,
poly.astype(np.int32)[np.newaxis, :, :], 0.15)
else:
tcl_poly = self.poly2tcl(poly, tcl_ratio)
tcl_quads = self.poly2quads(tcl_poly)
poly_quads = self.poly2quads(poly)
# stcl map
stcl_quads, quad_index = self.shrink_poly_along_width(
tcl_quads,
shrink_ratio_of_width=shrink_ratio_of_width,
expand_height_ratio=1.0 / tcl_ratio)
# generate tcl map
cv2.fillPoly(score_map,
np.round(stcl_quads).astype(np.int32), 1.0)
# generate tbo map
for idx, quad in enumerate(stcl_quads):
quad_mask = np.zeros((h, w), dtype=np.float32)
quad_mask = cv2.fillPoly(
quad_mask,
np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
quad_mask, tbo_map)
return score_map, tbo_map, training_mask
def generate_tvo_and_tco(self,
hw,
polys,
tags,
tcl_ratio=0.3,
ds_ratio=0.25):
"""
Generate tcl map, tvo map and tbo map.
"""
h, w = hw
h, w = int(h * ds_ratio), int(w * ds_ratio)
polys = polys * ds_ratio
poly_mask = np.zeros((h, w), dtype=np.float32)
tvo_map = np.ones((9, h, w), dtype=np.float32)
tvo_map[0:-1:2] = np.tile(np.arange(0, w), (h, 1))
tvo_map[1:-1:2] = np.tile(np.arange(0, w), (h, 1)).T
poly_tv_xy_map = np.zeros((8, h, w), dtype=np.float32)
# tco map
tco_map = np.ones((3, h, w), dtype=np.float32)
tco_map[0] = np.tile(np.arange(0, w), (h, 1))
tco_map[1] = np.tile(np.arange(0, w), (h, 1)).T
poly_tc_xy_map = np.zeros((2, h, w), dtype=np.float32)
poly_short_edge_map = np.ones((h, w), dtype=np.float32)
for poly, poly_tag in zip(polys, tags):
if poly_tag == True:
continue
# adjust point order for vertical poly
poly = self.adjust_point(poly)
# generate min_area_quad
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
min_area_quad_h = 0.5 * (
np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
min_area_quad_w = 0.5 * (
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
# generate tcl map and text, 128 * 128
tcl_poly = self.poly2tcl(poly, tcl_ratio)
# generate poly_tv_xy_map
for idx in range(4):
cv2.fillPoly(
poly_tv_xy_map[2 * idx],
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
float(min(max(min_area_quad[idx, 0], 0), w)))
cv2.fillPoly(
poly_tv_xy_map[2 * idx + 1],
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
float(min(max(min_area_quad[idx, 1], 0), h)))
# generate poly_tc_xy_map
for idx in range(2):
cv2.fillPoly(
poly_tc_xy_map[idx],
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
float(center_point[idx]))
# generate poly_short_edge_map
cv2.fillPoly(
poly_short_edge_map,
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))
# generate poly_mask and training_mask
cv2.fillPoly(poly_mask,
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
1)
tvo_map *= poly_mask
tvo_map[:8] -= poly_tv_xy_map
tvo_map[-1] /= poly_short_edge_map
tvo_map = tvo_map.transpose((1, 2, 0))
tco_map *= poly_mask
tco_map[:2] -= poly_tc_xy_map
tco_map[-1] /= poly_short_edge_map
tco_map = tco_map.transpose((1, 2, 0))
return tvo_map, tco_map
def adjust_point(self, poly):
"""
adjust point order.
"""
point_num = poly.shape[0]
if point_num == 4:
len_1 = np.linalg.norm(poly[0] - poly[1])
len_2 = np.linalg.norm(poly[1] - poly[2])
len_3 = np.linalg.norm(poly[2] - poly[3])
len_4 = np.linalg.norm(poly[3] - poly[0])
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
poly = poly[[1, 2, 3, 0], :]
elif point_num > 4:
vector_1 = poly[0] - poly[1]
vector_2 = poly[1] - poly[2]
cos_theta = np.dot(vector_1, vector_2) / (
np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
theta = np.arccos(np.round(cos_theta, decimals=4))
if abs(theta) > (70 / 180 * math.pi):
index = list(range(1, point_num)) + [0]
poly = poly[np.array(index), :]
return poly
def gen_min_area_quad_from_poly(self, poly):
"""
Generate min area quad from poly.
"""
point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32)
if point_num == 4:
min_area_quad = poly
center_point = np.sum(poly, axis=0) / 4
else:
rect = cv2.minAreaRect(poly.astype(
np.int32)) # (center (x,y), (width, height), angle of rotation)
center_point = rect[0]
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
if dist < min_dist:
min_dist = dist
first_point_idx = i
for i in range(4):
min_area_quad[i] = box[(first_point_idx + i) % 4]
return min_area_quad, center_point
def shrink_quad_along_width(self,
quad,
begin_width_ratio=0.,
end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def shrink_poly_along_width(self,
quads,
shrink_ratio_of_width,
expand_height_ratio=1.0):
"""
shrink poly with given length.
"""
upper_edge_list = []
def get_cut_info(edge_len_list, cut_len):
for idx, edge_len in enumerate(edge_len_list):
cut_len -= edge_len
if cut_len <= 0.000001:
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
return idx, ratio
for quad in quads:
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
upper_edge_list.append(upper_edge_len)
# length of left edge and right edge.
left_length = np.linalg.norm(quads[0][0] - quads[0][
3]) * expand_height_ratio
right_length = np.linalg.norm(quads[-1][1] - quads[-1][
2]) * expand_height_ratio
shrink_length = min(left_length, right_length,
sum(upper_edge_list)) * shrink_ratio_of_width
# shrinking length
upper_len_left = shrink_length
upper_len_right = sum(upper_edge_list) - shrink_length
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
left_quad = self.shrink_quad_along_width(
quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
right_quad = self.shrink_quad_along_width(
quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
out_quad_list = []
if left_idx == right_idx:
out_quad_list.append(
[left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
else:
out_quad_list.append(left_quad)
for idx in range(left_idx + 1, right_idx):
out_quad_list.append(quads[idx])
out_quad_list.append(right_quad)
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
def vector_angle(self, A, B):
"""
Calculate the angle between vector AB and x-axis positive direction.
"""
AB = np.array([B[1] - A[1], B[0] - A[0]])
return np.arctan2(*AB)
def theta_line_cross_point(self, theta, point):
"""
Calculate the line through given point and angle in ax + by + c =0 form.
"""
x, y = point
cos = np.cos(theta)
sin = np.sin(theta)
return [sin, -cos, cos * y - sin * x]
def line_cross_two_point(self, A, B):
"""
Calculate the line through given point A and B in ax + by + c =0 form.
"""
angle = self.vector_angle(A, B)
return self.theta_line_cross_point(angle, A)
def average_angle(self, poly):
"""
Calculate the average angle between left and right edge in given poly.
"""
p0, p1, p2, p3 = poly
angle30 = self.vector_angle(p3, p0)
angle21 = self.vector_angle(p2, p1)
return (angle30 + angle21) / 2
def line_cross_point(self, line1, line2):
"""
line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
"""
a1, b1, c1 = line1
a2, b2, c2 = line2
d = a1 * b2 - a2 * b1
if d == 0:
#print("line1", line1)
#print("line2", line2)
print('Cross point does not exist')
return np.array([0, 0], dtype=np.float32)
else:
x = (b1 * c2 - b2 * c1) / d
y = (a2 * c1 - a1 * c2) / d
return np.array([x, y], dtype=np.float32)
def quad2tcl(self, poly, ratio):
"""
Generate center line by poly clock-wise point. (4, 2)
"""
ratio_pair = np.array(
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
def poly2tcl(self, poly, ratio):
"""
Generate center line by poly clock-wise point.
"""
ratio_pair = np.array(
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
tcl_poly = np.zeros_like(poly)
point_num = poly.shape[0]
for idx in range(point_num // 2):
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
) * ratio_pair
tcl_poly[idx] = point_pair[0]
tcl_poly[point_num - 1 - idx] = point_pair[1]
return tcl_poly
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
"""
Generate tbo_map for give quad.
"""
# upper and lower line function: ax + by + c = 0;
up_line = self.line_cross_two_point(quad[0], quad[1])
lower_line = self.line_cross_two_point(quad[3], quad[2])
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[1] - quad[2]))
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3]))
# average angle of left and right line.
angle = self.average_angle(quad)
xy_in_poly = np.argwhere(tcl_mask == 1)
for y, x in xy_in_poly:
point = (x, y)
line = self.theta_line_cross_point(angle, point)
cross_point_upper = self.line_cross_point(up_line, line)
cross_point_lower = self.line_cross_point(lower_line, line)
##FIX, offset reverse
upper_offset_x, upper_offset_y = cross_point_upper - point
lower_offset_x, lower_offset_y = cross_point_lower - point
tbo_map[y, x, 0] = upper_offset_y
tbo_map[y, x, 1] = upper_offset_x
tbo_map[y, x, 2] = lower_offset_y
tbo_map[y, x, 3] = lower_offset_x
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
return tbo_map
def poly2quads(self, poly):
"""
Split poly into quads.
"""
quad_list = []
point_num = poly.shape[0]
# point pair
point_pair_list = []
for idx in range(point_num // 2):
point_pair = [poly[idx], poly[point_num - 1 - idx]]
point_pair_list.append(point_pair)
quad_num = point_num // 2 - 1
for idx in range(quad_num):
# reshape and adjust to clock-wise
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
).reshape(4, 2)[[0, 2, 3, 1]])
return np.array(quad_list)
def __call__(self, data):
im = data['image']
text_polys = data['polys']
text_tags = data['ignore_tags']
if im is None:
return None
if text_polys.shape[0] == 0:
return None
h, w, _ = im.shape
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
text_polys, text_tags, (h, w))
if text_polys.shape[0] == 0:
return None
#set aspect ratio and keep area fix
asp_scales = np.arange(1.0, 1.55, 0.1)
asp_scale = np.random.choice(asp_scales)
if np.random.rand() < 0.5:
asp_scale = 1.0 / asp_scale
asp_scale = math.sqrt(asp_scale)
asp_wx = asp_scale
asp_hy = 1.0 / asp_scale
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
text_polys[:, :, 0] *= asp_wx
text_polys[:, :, 1] *= asp_hy
h, w, _ = im.shape
if max(h, w) > 2048:
rd_scale = 2048.0 / max(h, w)
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
text_polys *= rd_scale
h, w, _ = im.shape
if min(h, w) < 16:
return None
#no background
im, text_polys, text_tags, hv_tags = self.crop_area(im, \
text_polys, text_tags, hv_tags, crop_background=False)
if text_polys.shape[0] == 0:
return None
#continue for all ignore case
if np.sum((text_tags * 1.0)) >= text_tags.size:
return None
new_h, new_w, _ = im.shape
if (new_h is None) or (new_w is None):
return None
#resize image
std_ratio = float(self.input_size) / max(new_w, new_h)
rand_scales = np.array(
[0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
rz_scale = std_ratio * np.random.choice(rand_scales)
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
text_polys[:, :, 0] *= rz_scale
text_polys[:, :, 1] *= rz_scale
#add gaussian blur
if np.random.rand() < 0.1 * 0.5:
ks = np.random.permutation(5)[0] + 1
ks = int(ks / 2) * 2 + 1
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
#add brighter
if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 + np.random.rand() * 0.5)
im = np.clip(im, 0.0, 255.0)
#add darker
if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 - np.random.rand() * 0.5)
im = np.clip(im, 0.0, 255.0)
# Padding the im to [input_size, input_size]
new_h, new_w, _ = im.shape
if min(new_w, new_h) < self.input_size * 0.5:
return None
im_padded = np.ones(
(self.input_size, self.input_size, 3), dtype=np.float32)
im_padded[:, :, 2] = 0.485 * 255
im_padded[:, :, 1] = 0.456 * 255
im_padded[:, :, 0] = 0.406 * 255
# Random the start position
del_h = self.input_size - new_h
del_w = self.input_size - new_w
sh, sw = 0, 0
if del_h > 1:
sh = int(np.random.rand() * del_h)
if del_w > 1:
sw = int(np.random.rand() * del_w)
# Padding
im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
text_polys[:, :, 0] += sw
text_polys[:, :, 1] += sh
score_map, border_map, training_mask = self.generate_tcl_label(
(self.input_size, self.input_size), text_polys, text_tags, 0.25)
# SAST head
tvo_map, tco_map = self.generate_tvo_and_tco(
(self.input_size, self.input_size),
text_polys,
text_tags,
tcl_ratio=0.3,
ds_ratio=0.25)
# print("test--------tvo_map shape:", tvo_map.shape)
im_padded[:, :, 2] -= 0.485 * 255
im_padded[:, :, 1] -= 0.456 * 255
im_padded[:, :, 0] -= 0.406 * 255
im_padded[:, :, 2] /= (255.0 * 0.229)
im_padded[:, :, 1] /= (255.0 * 0.224)
im_padded[:, :, 0] /= (255.0 * 0.225)
im_padded = im_padded.transpose((2, 0, 1))
data['image'] = im_padded[::-1, :, :]
data['score_map'] = score_map[np.newaxis, :, :]
data['border_map'] = border_map.transpose((2, 0, 1))
data['training_mask'] = training_mask[np.newaxis, :, :]
data['tvo_map'] = tvo_map.transpose((2, 0, 1))
data['tco_map'] = tco_map.transpose((2, 0, 1))
return data

View File

@@ -1,60 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import cv2
import numpy as np
import random
from PIL import Image
from .rec_img_aug import resize_norm_img
class SSLRotateResize(object):
def __init__(self,
image_shape,
padding=False,
select_all=True,
mode="train",
**kwargs):
self.image_shape = image_shape
self.padding = padding
self.select_all = select_all
self.mode = mode
def __call__(self, data):
img = data["image"]
data["image_r90"] = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
data["image_r180"] = cv2.rotate(data["image_r90"],
cv2.ROTATE_90_CLOCKWISE)
data["image_r270"] = cv2.rotate(data["image_r180"],
cv2.ROTATE_90_CLOCKWISE)
images = []
for key in ["image", "image_r90", "image_r180", "image_r270"]:
images.append(
resize_norm_img(
data.pop(key),
image_shape=self.image_shape,
padding=self.padding)[0])
data["image"] = np.stack(images, axis=0)
data["label"] = np.array(list(range(4)))
if not self.select_all:
data["image"] = data["image"][0::2] # just choose 0 and 180
data["label"] = data["label"][0:2] # label needs to be continuous
if self.mode == "test":
data["image"] = data["image"][0]
data["label"] = data["label"][0]
return data

View File

@@ -1,17 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .augment import tia_perspective, tia_distort, tia_stretch
__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']

View File

@@ -1,120 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py
"""
import numpy as np
from .warp_mls import WarpMLS
def tia_distort(src, segment=4):
img_h, img_w = src.shape[:2]
cut = img_w // segment
thresh = cut // 3
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
dst_pts.append(
[img_w - np.random.randint(thresh), np.random.randint(thresh)])
dst_pts.append(
[img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)])
dst_pts.append(
[np.random.randint(thresh), img_h - np.random.randint(thresh)])
half_thresh = thresh * 0.5
for cut_idx in np.arange(1, segment, 1):
src_pts.append([cut * cut_idx, 0])
src_pts.append([cut * cut_idx, img_h])
dst_pts.append([
cut * cut_idx + np.random.randint(thresh) - half_thresh,
np.random.randint(thresh) - half_thresh
])
dst_pts.append([
cut * cut_idx + np.random.randint(thresh) - half_thresh,
img_h + np.random.randint(thresh) - half_thresh
])
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
return dst
def tia_stretch(src, segment=4):
img_h, img_w = src.shape[:2]
cut = img_w // segment
thresh = cut * 4 // 5
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([0, 0])
dst_pts.append([img_w, 0])
dst_pts.append([img_w, img_h])
dst_pts.append([0, img_h])
half_thresh = thresh * 0.5
for cut_idx in np.arange(1, segment, 1):
move = np.random.randint(thresh) - half_thresh
src_pts.append([cut * cut_idx, 0])
src_pts.append([cut * cut_idx, img_h])
dst_pts.append([cut * cut_idx + move, 0])
dst_pts.append([cut * cut_idx + move, img_h])
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
return dst
def tia_perspective(src):
img_h, img_w = src.shape[:2]
thresh = img_h // 2
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([0, np.random.randint(thresh)])
dst_pts.append([img_w, np.random.randint(thresh)])
dst_pts.append([img_w, img_h - np.random.randint(thresh)])
dst_pts.append([0, img_h - np.random.randint(thresh)])
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
return dst

View File

@@ -1,168 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/warp_mls.py
"""
import numpy as np
class WarpMLS:
def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
self.src = src
self.src_pts = src_pts
self.dst_pts = dst_pts
self.pt_count = len(self.dst_pts)
self.dst_w = dst_w
self.dst_h = dst_h
self.trans_ratio = trans_ratio
self.grid_size = 100
self.rdx = np.zeros((self.dst_h, self.dst_w))
self.rdy = np.zeros((self.dst_h, self.dst_w))
@staticmethod
def __bilinear_interp(x, y, v11, v12, v21, v22):
return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
(1 - y) + v22 * y) * x
def generate(self):
self.calc_delta()
return self.gen_img()
def calc_delta(self):
w = np.zeros(self.pt_count, dtype=np.float32)
if self.pt_count < 2:
return
i = 0
while 1:
if self.dst_w <= i < self.dst_w + self.grid_size - 1:
i = self.dst_w - 1
elif i >= self.dst_w:
break
j = 0
while 1:
if self.dst_h <= j < self.dst_h + self.grid_size - 1:
j = self.dst_h - 1
elif j >= self.dst_h:
break
sw = 0
swp = np.zeros(2, dtype=np.float32)
swq = np.zeros(2, dtype=np.float32)
new_pt = np.zeros(2, dtype=np.float32)
cur_pt = np.array([i, j], dtype=np.float32)
k = 0
for k in range(self.pt_count):
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
break
w[k] = 1. / (
(i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) +
(j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]))
sw += w[k]
swp = swp + w[k] * np.array(self.dst_pts[k])
swq = swq + w[k] * np.array(self.src_pts[k])
if k == self.pt_count - 1:
pstar = 1 / sw * swp
qstar = 1 / sw * swq
miu_s = 0
for k in range(self.pt_count):
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
continue
pt_i = self.dst_pts[k] - pstar
miu_s += w[k] * np.sum(pt_i * pt_i)
cur_pt -= pstar
cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])
for k in range(self.pt_count):
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
continue
pt_i = self.dst_pts[k] - pstar
pt_j = np.array([-pt_i[1], pt_i[0]])
tmp_pt = np.zeros(2, dtype=np.float32)
tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \
np.sum(pt_j * cur_pt) * self.src_pts[k][1]
tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \
np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
tmp_pt *= (w[k] / miu_s)
new_pt += tmp_pt
new_pt += qstar
else:
new_pt = self.src_pts[k]
self.rdx[j, i] = new_pt[0] - i
self.rdy[j, i] = new_pt[1] - j
j += self.grid_size
i += self.grid_size
def gen_img(self):
src_h, src_w = self.src.shape[:2]
dst = np.zeros_like(self.src, dtype=np.float32)
for i in np.arange(0, self.dst_h, self.grid_size):
for j in np.arange(0, self.dst_w, self.grid_size):
ni = i + self.grid_size
nj = j + self.grid_size
w = h = self.grid_size
if ni >= self.dst_h:
ni = self.dst_h - 1
h = ni - i + 1
if nj >= self.dst_w:
nj = self.dst_w - 1
w = nj - j + 1
di = np.reshape(np.arange(h), (-1, 1))
dj = np.reshape(np.arange(w), (1, -1))
delta_x = self.__bilinear_interp(
di / h, dj / w, self.rdx[i, j], self.rdx[i, nj],
self.rdx[ni, j], self.rdx[ni, nj])
delta_y = self.__bilinear_interp(
di / h, dj / w, self.rdy[i, j], self.rdy[i, nj],
self.rdy[ni, j], self.rdy[ni, nj])
nx = j + dj + delta_x * self.trans_ratio
ny = i + di + delta_y * self.trans_ratio
nx = np.clip(nx, 0, src_w - 1)
ny = np.clip(ny, 0, src_h - 1)
nxi = np.array(np.floor(nx), dtype=np.int32)
nyi = np.array(np.floor(ny), dtype=np.int32)
nxi1 = np.array(np.ceil(nx), dtype=np.int32)
nyi1 = np.array(np.ceil(ny), dtype=np.int32)
if len(self.src.shape) == 3:
x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3))
y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3))
else:
x = ny - nyi
y = nx - nxi
dst[i:i + h, j:j + w] = self.__bilinear_interp(
x, y, self.src[nyi, nxi], self.src[nyi, nxi1],
self.src[nyi1, nxi], self.src[nyi1, nxi1])
dst = np.clip(dst, 0, 255)
dst = np.array(dst, dtype=np.uint8)
return dst

View File

@@ -1,19 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
__all__ = [
'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation'
]

View File

@@ -1,17 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
from .vqa_token_pad import VQATokenPad
from .vqa_token_relation import VQAReTokenRelation

View File

@@ -1,122 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
class VQASerTokenChunk(object):
def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
self.max_seq_len = max_seq_len
self.infer_mode = infer_mode
def __call__(self, data):
encoded_inputs_all = []
seq_len = len(data['input_ids'])
for index in range(0, seq_len, self.max_seq_len):
chunk_beg = index
chunk_end = min(index + self.max_seq_len, seq_len)
encoded_inputs_example = {}
for key in data:
if key in [
'label', 'input_ids', 'labels', 'token_type_ids',
'bbox', 'attention_mask'
]:
if self.infer_mode and key == 'labels':
encoded_inputs_example[key] = data[key]
else:
encoded_inputs_example[key] = data[key][chunk_beg:
chunk_end]
else:
encoded_inputs_example[key] = data[key]
encoded_inputs_all.append(encoded_inputs_example)
if len(encoded_inputs_all) == 0:
return None
return encoded_inputs_all[0]
class VQAReTokenChunk(object):
def __init__(self,
max_seq_len=512,
entities_labels=None,
infer_mode=False,
**kwargs):
self.max_seq_len = max_seq_len
self.entities_labels = {
'HEADER': 0,
'QUESTION': 1,
'ANSWER': 2
} if entities_labels is None else entities_labels
self.infer_mode = infer_mode
def __call__(self, data):
# prepare data
entities = data.pop('entities')
relations = data.pop('relations')
encoded_inputs_all = []
for index in range(0, len(data["input_ids"]), self.max_seq_len):
item = {}
for key in data:
if key in [
'label', 'input_ids', 'labels', 'token_type_ids',
'bbox', 'attention_mask'
]:
if self.infer_mode and key == 'labels':
item[key] = data[key]
else:
item[key] = data[key][index:index + self.max_seq_len]
else:
item[key] = data[key]
# select entity in current chunk
entities_in_this_span = []
global_to_local_map = {} #
for entity_id, entity in enumerate(entities):
if (index <= entity["start"] < index + self.max_seq_len and
index <= entity["end"] < index + self.max_seq_len):
entity["start"] = entity["start"] - index
entity["end"] = entity["end"] - index
global_to_local_map[entity_id] = len(entities_in_this_span)
entities_in_this_span.append(entity)
# select relations in current chunk
relations_in_this_span = []
for relation in relations:
if (index <= relation["start_index"] < index + self.max_seq_len
and index <= relation["end_index"] <
index + self.max_seq_len):
relations_in_this_span.append({
"head": global_to_local_map[relation["head"]],
"tail": global_to_local_map[relation["tail"]],
"start_index": relation["start_index"] - index,
"end_index": relation["end_index"] - index,
})
item.update({
"entities": self.reformat(entities_in_this_span),
"relations": self.reformat(relations_in_this_span),
})
if len(item['entities']) > 0:
item['entities']['label'] = [
self.entities_labels[x] for x in item['entities']['label']
]
encoded_inputs_all.append(item)
if len(encoded_inputs_all) == 0:
return None
return encoded_inputs_all[0]
def reformat(self, data):
new_data = defaultdict(list)
for item in data:
for k, v in item.items():
new_data[k].append(v)
return new_data

View File

@@ -1,104 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
class VQATokenPad(object):
def __init__(self,
max_seq_len=512,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_token_type_ids=True,
truncation_strategy="longest_first",
return_overflowing_tokens=False,
return_special_tokens_mask=False,
infer_mode=False,
**kwargs):
self.max_seq_len = max_seq_len
self.pad_to_max_seq_len = max_seq_len
self.return_attention_mask = return_attention_mask
self.return_token_type_ids = return_token_type_ids
self.truncation_strategy = truncation_strategy
self.return_overflowing_tokens = return_overflowing_tokens
self.return_special_tokens_mask = return_special_tokens_mask
self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
self.infer_mode = infer_mode
def __call__(self, data):
needs_to_be_padded = self.pad_to_max_seq_len and len(data[
"input_ids"]) < self.max_seq_len
if needs_to_be_padded:
if 'tokenizer_params' in data:
tokenizer_params = data.pop('tokenizer_params')
else:
tokenizer_params = dict(
padding_side='right', pad_token_type_id=0, pad_token_id=1)
difference = self.max_seq_len - len(data["input_ids"])
if tokenizer_params['padding_side'] == 'right':
if self.return_attention_mask:
data["attention_mask"] = [1] * len(data[
"input_ids"]) + [0] * difference
if self.return_token_type_ids:
data["token_type_ids"] = (
data["token_type_ids"] +
[tokenizer_params['pad_token_type_id']] * difference)
if self.return_special_tokens_mask:
data["special_tokens_mask"] = data[
"special_tokens_mask"] + [1] * difference
data["input_ids"] = data["input_ids"] + [
tokenizer_params['pad_token_id']
] * difference
if not self.infer_mode:
data["labels"] = data[
"labels"] + [self.pad_token_label_id] * difference
data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
elif tokenizer_params['padding_side'] == 'left':
if self.return_attention_mask:
data["attention_mask"] = [0] * difference + [
1
] * len(data["input_ids"])
if self.return_token_type_ids:
data["token_type_ids"] = (
[tokenizer_params['pad_token_type_id']] * difference +
data["token_type_ids"])
if self.return_special_tokens_mask:
data["special_tokens_mask"] = [
1
] * difference + data["special_tokens_mask"]
data["input_ids"] = [tokenizer_params['pad_token_id']
] * difference + data["input_ids"]
if not self.infer_mode:
data["labels"] = [self.pad_token_label_id
] * difference + data["labels"]
data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
else:
if self.return_attention_mask:
data["attention_mask"] = [1] * len(data["input_ids"])
for key in data:
if key in [
'input_ids', 'labels', 'token_type_ids', 'bbox',
'attention_mask'
]:
if self.infer_mode:
if key != 'labels':
length = min(len(data[key]), self.max_seq_len)
data[key] = data[key][:length]
else:
continue
data[key] = np.array(data[key], dtype='int64')
return data

View File

@@ -1,67 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class VQAReTokenRelation(object):
def __init__(self, **kwargs):
pass
def __call__(self, data):
"""
build relations
"""
entities = data['entities']
relations = data['relations']
id2label = data.pop('id2label')
empty_entity = data.pop('empty_entity')
entity_id_to_index_map = data.pop('entity_id_to_index_map')
relations = list(set(relations))
relations = [
rel for rel in relations
if rel[0] not in empty_entity and rel[1] not in empty_entity
]
kv_relations = []
for rel in relations:
pair = [id2label[rel[0]], id2label[rel[1]]]
if pair == ["question", "answer"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[0]],
"tail": entity_id_to_index_map[rel[1]]
})
elif pair == ["answer", "question"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[1]],
"tail": entity_id_to_index_map[rel[0]]
})
else:
continue
relations = sorted(
[{
"head": rel["head"],
"tail": rel["tail"],
"start_index": self.get_relation_span(rel, entities)[0],
"end_index": self.get_relation_span(rel, entities)[1],
} for rel in kv_relations],
key=lambda x: x["head"], )
data['relations'] = relations
return data
def get_relation_span(self, rel, entities):
bound = []
for entity_index in [rel["head"], rel["tail"]]:
bound.append(entities[entity_index]["start"])
bound.append(entities[entity_index]["end"])
return min(bound), max(bound)

View File

@@ -1,118 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
from paddle.io import Dataset
import lmdb
import cv2
from .imaug import transform, create_operators
class LMDBDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(LMDBDataSet, self).__init__()
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
logger.info("Initialize indexs of datasets:%s" % data_dir)
self.data_idx_order_list = self.dataset_traversal()
if self.do_shuffle:
np.random.shuffle(self.data_idx_order_list)
self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
def load_hierarchical_lmdb_dataset(self, data_dir):
lmdb_sets = {}
dataset_idx = 0
for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
if not dirnames:
env = lmdb.open(
dirpath,
max_readers=32,
readonly=True,
lock=False,
readahead=False,
meminit=False)
txn = env.begin(write=False)
num_samples = int(txn.get('num-samples'.encode()))
lmdb_sets[dataset_idx] = {"dirpath":dirpath, "env":env, \
"txn":txn, "num_samples":num_samples}
dataset_idx += 1
return lmdb_sets
def dataset_traversal(self):
lmdb_num = len(self.lmdb_sets)
total_sample_num = 0
for lno in range(lmdb_num):
total_sample_num += self.lmdb_sets[lno]['num_samples']
data_idx_order_list = np.zeros((total_sample_num, 2))
beg_idx = 0
for lno in range(lmdb_num):
tmp_sample_num = self.lmdb_sets[lno]['num_samples']
end_idx = beg_idx + tmp_sample_num
data_idx_order_list[beg_idx:end_idx, 0] = lno
data_idx_order_list[beg_idx:end_idx, 1] \
= list(range(tmp_sample_num))
data_idx_order_list[beg_idx:end_idx, 1] += 1
beg_idx = beg_idx + tmp_sample_num
return data_idx_order_list
def get_img_data(self, value):
"""get_img_data"""
if not value:
return None
imgdata = np.frombuffer(value, dtype='uint8')
if imgdata is None:
return None
imgori = cv2.imdecode(imgdata, 1)
if imgori is None:
return None
return imgori
def get_lmdb_sample_info(self, txn, index):
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key)
if label is None:
return None
label = label.decode('utf-8')
img_key = 'image-%09d'.encode() % index
imgbuf = txn.get(img_key)
return imgbuf, label
def __getitem__(self, idx):
lmdb_idx, file_idx = self.data_idx_order_list[idx]
lmdb_idx = int(lmdb_idx)
file_idx = int(file_idx)
sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
file_idx)
if sample_info is None:
return self.__getitem__(np.random.randint(self.__len__()))
img, label = sample_info
data = {'image': img, 'label': label}
outs = transform(data, self.ops)
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return self.data_idx_order_list.shape[0]

View File

@@ -1,106 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
from paddle.io import Dataset
from .imaug import transform, create_operators
import random
class PGDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(PGDataSet, self).__init__()
self.logger = logger
self.seed = seed
self.mode = mode
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
self.need_reset = True in [x < 1 for x in ratio_list]
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx]
img_id = 0
try:
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
if self.mode.lower() == 'eval':
try:
img_id = int(data_line.split(".")[0][7:])
except:
img_id = 0
data = {'img_path': img_path, 'label': label, 'img_id': img_id}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
self.data_idx_order_list[idx], e))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_idx_order_list)

View File

@@ -1,114 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import random
from paddle.io import Dataset
import json
from .imaug import transform, create_operators
class PubTabDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(PubTabDataSet, self).__init__()
self.logger = logger
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
label_file_path = dataset_config.pop('label_file_path')
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.do_hard_select = False
if 'hard_select' in loader_config:
self.do_hard_select = loader_config['hard_select']
self.hard_prob = loader_config['hard_prob']
if self.do_hard_select:
self.img_select_prob = self.load_hard_select_prob()
self.table_select_type = None
if 'table_select_type' in loader_config:
self.table_select_type = loader_config['table_select_type']
self.table_select_prob = loader_config['table_select_prob']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_path)
with open(label_file_path, "rb") as f:
self.data_lines = f.readlines()
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def __getitem__(self, idx):
try:
data_line = self.data_lines[idx]
data_line = data_line.decode('utf-8').strip("\n")
info = json.loads(data_line)
file_name = info['filename']
select_flag = True
if self.do_hard_select:
prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1):
select_flag = False
if self.table_select_type:
structure = info['html']['structure']['tokens'].copy()
structure_str = ''.join(structure)
table_type = "simple"
if 'colspan' in structure_str or 'rowspan' in structure_str:
table_type = "complex"
if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1):
select_flag = False
if select_flag:
cells = info['html']['cells'].copy()
structure = info['html']['structure'].copy()
img_path = os.path.join(self.data_dir, file_name)
data = {
'img_path': img_path,
'cells': cells,
'structure': structure
}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
else:
outs = None
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
data_line, e))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_idx_order_list)

View File

@@ -1,151 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import json
import random
import traceback
from paddle.io import Dataset
from .imaug import transform, create_operators
class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__()
self.logger = logger
self.mode = mode.lower()
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
2)
self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def shuffle_data_random(self):
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def _try_parse_filename_list(self, file_name):
# multiple images -> one gt label
if len(file_name) > 0 and file_name[0] == "[":
try:
info = json.loads(file_name)
file_name = random.choice(info)
except:
pass
return file_name
def get_ext_data(self):
ext_data_num = 0
for op in self.ops:
if hasattr(op, 'ext_data_num'):
ext_data_num = getattr(op, 'ext_data_num')
break
load_data_ops = self.ops[:self.ext_op_transform_idx]
ext_data = []
while len(ext_data) < ext_data_num:
file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
))]
data_line = self.data_lines[file_idx]
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
continue
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
data = transform(data, load_data_ops)
if data is None:
continue
if 'polys' in data.keys():
if data['polys'].shape[1] != 4:
continue
ext_data.append(data)
return ext_data
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx]
try:
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops)
except:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
data_line, traceback.format_exc()))
outs = None
if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation.
rnd_idx = np.random.randint(self.__len__(
)) if self.mode == "train" else (idx + 1) % self.__len__()
return self.__getitem__(rnd_idx)
return outs
def __len__(self):
return len(self.data_idx_order_list)

View File

@@ -1,71 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import paddle
import paddle.nn as nn
# basic_loss
from .basic_loss import LossFromOutput
# det loss
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
from .det_pse_loss import PSELoss
from .det_fce_loss import FCELoss
# rec loss
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
from .rec_nrtr_loss import NRTRLoss
from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss
# cls loss
from .cls_loss import ClsLoss
# e2e loss
from .e2e_pg_loss import PGLoss
from .kie_sdmgr_loss import SDMGRLoss
# basic loss function
from .basic_loss import DistanceLoss
# combined loss function
from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
def build_loss(config):
support_dict = [
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
assert module_name in support_dict, Exception('loss only support {}'.format(
support_dict))
module_class = eval(module_name)(**config)
return module_class

View File

@@ -1,52 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is refer from: https://github.com/viig99/LS-ACELoss
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
class ACELoss(nn.Layer):
def __init__(self, **kwargs):
super().__init__()
self.loss_func = nn.CrossEntropyLoss(
weight=None,
ignore_index=0,
reduction='none',
soft_label=True,
axis=-1)
def __call__(self, predicts, batch):
if isinstance(predicts, (list, tuple)):
predicts = predicts[-1]
B, N = predicts.shape[:2]
div = paddle.to_tensor([N]).astype('float32')
predicts = nn.functional.softmax(predicts, axis=-1)
aggregation_preds = paddle.sum(predicts, axis=1)
aggregation_preds = paddle.divide(aggregation_preds, div)
length = batch[2].astype("float32")
batch = batch[3].astype("float32")
batch[:, 0] = paddle.subtract(div, length)
batch = paddle.divide(batch, div)
loss = self.loss_func(aggregation_preds, batch)
return {"loss_ace": loss}

View File

@@ -1,155 +0,0 @@
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import L1Loss
from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss
class CELoss(nn.Layer):
def __init__(self, epsilon=None):
super().__init__()
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
epsilon = None
self.epsilon = epsilon
def _labelsmoothing(self, target, class_num):
if target.shape[-1] != class_num:
one_hot_target = F.one_hot(target, class_num)
else:
one_hot_target = target
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
return soft_target
def forward(self, x, label):
loss_dict = {}
if self.epsilon is not None:
class_num = x.shape[-1]
label = self._labelsmoothing(label, class_num)
x = -F.log_softmax(x, axis=-1)
loss = paddle.sum(x * label, axis=-1)
else:
if label.shape[-1] == x.shape[-1]:
label = F.softmax(label, axis=-1)
soft_label = True
else:
soft_label = False
loss = F.cross_entropy(x, label=label, soft_label=soft_label)
return loss
class KLJSLoss(object):
def __init__(self, mode='kl'):
assert mode in ['kl', 'js', 'KL', 'JS'
], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
if self.mode.lower() == "js":
loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5
if reduction == "mean":
loss = paddle.mean(loss, axis=[1, 2])
elif reduction == "none" or reduction is None:
return loss
else:
loss = paddle.sum(loss, axis=[1, 2])
return loss
class DMLLoss(nn.Layer):
"""
DMLLoss
"""
def __init__(self, act=None, use_log=False):
super().__init__()
if act is not None:
assert act in ["softmax", "sigmoid"]
if act == "softmax":
self.act = nn.Softmax(axis=-1)
elif act == "sigmoid":
self.act = nn.Sigmoid()
else:
self.act = None
self.use_log = use_log
self.jskl_loss = KLJSLoss(mode="js")
def _kldiv(self, x, target):
eps = 1.0e-10
loss = target * (paddle.log(target + eps) - x)
# batch mean loss
loss = paddle.sum(loss) / loss.shape[0]
return loss
def forward(self, out1, out2):
if self.act is not None:
out1 = self.act(out1) + 1e-10
out2 = self.act(out2) + 1e-10
if self.use_log:
# for recognition distillation, log is needed for feature map
log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2)
loss = (
self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
else:
# for detection distillation log is not needed
loss = self.jskl_loss(out1, out2)
return loss
class DistanceLoss(nn.Layer):
"""
DistanceLoss:
mode: loss mode
"""
def __init__(self, mode="l2", **kargs):
super().__init__()
assert mode in ["l1", "l2", "smooth_l1"]
if mode == "l1":
self.loss_func = nn.L1Loss(**kargs)
elif mode == "l2":
self.loss_func = nn.MSELoss(**kargs)
elif mode == "smooth_l1":
self.loss_func = nn.SmoothL1Loss(**kargs)
def forward(self, x, y):
return self.loss_func(x, y)
class LossFromOutput(nn.Layer):
def __init__(self, key='loss', reduction='none'):
super().__init__()
self.key = key
self.reduction = reduction
def forward(self, predicts, batch):
loss = predicts[self.key]
if self.reduction == 'mean':
loss = paddle.mean(loss)
elif self.reduction == 'sum':
loss = paddle.sum(loss)
return {'loss': loss}

View File

@@ -1,88 +0,0 @@
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# This code is refer from: https://github.com/KaiyangZhou/pytorch-center-loss
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class CenterLoss(nn.Layer):
"""
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
"""
def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
super().__init__()
self.num_classes = num_classes
self.feat_dim = feat_dim
self.centers = paddle.randn(
shape=[self.num_classes, self.feat_dim]).astype("float64")
if center_file_path is not None:
assert os.path.exists(
center_file_path
), f"center path({center_file_path}) must exist when it is not None."
with open(center_file_path, 'rb') as f:
char_dict = pickle.load(f)
for key in char_dict.keys():
self.centers[key] = paddle.to_tensor(char_dict[key])
def __call__(self, predicts, batch):
assert isinstance(predicts, (list, tuple))
features, predicts = predicts
feats_reshape = paddle.reshape(
features, [-1, features.shape[-1]]).astype("float64")
label = paddle.argmax(predicts, axis=2)
label = paddle.reshape(label, [label.shape[0] * label.shape[1]])
batch_size = feats_reshape.shape[0]
#calc l2 distance between feats and centers
square_feat = paddle.sum(paddle.square(feats_reshape),
axis=1,
keepdim=True)
square_feat = paddle.expand(square_feat, [batch_size, self.num_classes])
square_center = paddle.sum(paddle.square(self.centers),
axis=1,
keepdim=True)
square_center = paddle.expand(
square_center, [self.num_classes, batch_size]).astype("float64")
square_center = paddle.transpose(square_center, [1, 0])
distmat = paddle.add(square_feat, square_center)
feat_dot_center = paddle.matmul(feats_reshape,
paddle.transpose(self.centers, [1, 0]))
distmat = distmat - 2.0 * feat_dot_center
#generate the mask
classes = paddle.arange(self.num_classes).astype("int64")
label = paddle.expand(
paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
mask = paddle.equal(
paddle.expand(classes, [batch_size, self.num_classes]),
label).astype("float64")
dist = paddle.multiply(distmat, mask)
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
return {'loss_center': loss}

View File

@@ -1,30 +0,0 @@
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
class ClsLoss(nn.Layer):
def __init__(self, **kwargs):
super(ClsLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean')
def forward(self, predicts, batch):
label = batch[1].astype("int64")
loss = self.loss_func(input=predicts, label=label)
return {'loss': loss}

View File

@@ -1,69 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from .rec_ctc_loss import CTCLoss
from .center_loss import CenterLoss
from .ace_loss import ACELoss
from .rec_sar_loss import SARLoss
from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationSARLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
class CombinedLoss(nn.Layer):
"""
CombinedLoss:
a combionation of loss function
"""
def __init__(self, loss_config_list=None):
super().__init__()
self.loss_func = []
self.loss_weight = []
assert isinstance(loss_config_list, list), (
'operator config should be a list')
for config in loss_config_list:
assert isinstance(config,
dict) and len(config) == 1, "yaml format error"
name = list(config)[0]
param = config[name]
assert "weight" in param, "weight must be in param, but param just contains {}".format(
param.keys())
self.loss_weight.append(param.pop("weight"))
self.loss_func.append(eval(name)(**param))
def forward(self, input, batch, **kargs):
loss_dict = {}
loss_all = 0.
for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch, **kargs)
if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss}
weight = self.loss_weight[idx]
loss = {key: loss[key] * weight for key in loss}
if "loss" in loss:
loss_all += loss["loss"]
else:
loss_all += paddle.add_n(list(loss.values()))
loss_dict.update(loss)
loss_dict["loss"] = loss_all
return loss_dict

View File

@@ -1,153 +0,0 @@
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/basic_loss.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from paddle import nn
import paddle.nn.functional as F
class BalanceLoss(nn.Layer):
def __init__(self,
balance_loss=True,
main_loss_type='DiceLoss',
negative_ratio=3,
return_origin=False,
eps=1e-6,
**kwargs):
"""
The BalanceLoss for Differentiable Binarization text detection
args:
balance_loss (bool): whether balance loss or not, default is True
main_loss_type (str): can only be one of ['CrossEntropy','DiceLoss',
'Euclidean','BCELoss', 'MaskL1Loss'], default is 'DiceLoss'.
negative_ratio (int|float): float, default is 3.
return_origin (bool): whether return unbalanced loss or not, default is False.
eps (float): default is 1e-6.
"""
super(BalanceLoss, self).__init__()
self.balance_loss = balance_loss
self.main_loss_type = main_loss_type
self.negative_ratio = negative_ratio
self.return_origin = return_origin
self.eps = eps
if self.main_loss_type == "CrossEntropy":
self.loss = nn.CrossEntropyLoss()
elif self.main_loss_type == "Euclidean":
self.loss = nn.MSELoss()
elif self.main_loss_type == "DiceLoss":
self.loss = DiceLoss(self.eps)
elif self.main_loss_type == "BCELoss":
self.loss = BCELoss(reduction='none')
elif self.main_loss_type == "MaskL1Loss":
self.loss = MaskL1Loss(self.eps)
else:
loss_type = [
'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss', 'MaskL1Loss'
]
raise Exception(
"main_loss_type in BalanceLoss() can only be one of {}".format(
loss_type))
def forward(self, pred, gt, mask=None):
"""
The BalanceLoss for Differentiable Binarization text detection
args:
pred (variable): predicted feature maps.
gt (variable): ground truth feature maps.
mask (variable): masked maps.
return: (variable) balanced loss
"""
positive = gt * mask
negative = (1 - gt) * mask
positive_count = int(positive.sum())
negative_count = int(
min(negative.sum(), positive_count * self.negative_ratio))
loss = self.loss(pred, gt, mask=mask)
if not self.balance_loss:
return loss
positive_loss = positive * loss
negative_loss = negative * loss
negative_loss = paddle.reshape(negative_loss, shape=[-1])
if negative_count > 0:
sort_loss = negative_loss.sort(descending=True)
negative_loss = sort_loss[:negative_count]
# negative_loss, _ = paddle.topk(negative_loss, k=negative_count_int)
balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
positive_count + negative_count + self.eps)
else:
balance_loss = positive_loss.sum() / (positive_count + self.eps)
if self.return_origin:
return balance_loss, loss
return balance_loss
class DiceLoss(nn.Layer):
def __init__(self, eps=1e-6):
super(DiceLoss, self).__init__()
self.eps = eps
def forward(self, pred, gt, mask, weights=None):
"""
DiceLoss function.
"""
assert pred.shape == gt.shape
assert pred.shape == mask.shape
if weights is not None:
assert weights.shape == mask.shape
mask = weights * mask
intersection = paddle.sum(pred * gt * mask)
union = paddle.sum(pred * mask) + paddle.sum(gt * mask) + self.eps
loss = 1 - 2.0 * intersection / union
assert loss <= 1
return loss
class MaskL1Loss(nn.Layer):
def __init__(self, eps=1e-6):
super(MaskL1Loss, self).__init__()
self.eps = eps
def forward(self, pred, gt, mask):
"""
Mask L1 Loss
"""
loss = (paddle.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
loss = paddle.mean(loss)
return loss
class BCELoss(nn.Layer):
def __init__(self, reduction='mean'):
super(BCELoss, self).__init__()
self.reduction = reduction
def forward(self, input, label, mask=None, weight=None, name=None):
loss = F.binary_cross_entropy(input, label, reduction=self.reduction)
return loss

View File

@@ -1,76 +0,0 @@
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/DB_loss.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
class DBLoss(nn.Layer):
"""
Differentiable Binarization (DB) Loss Function
args:
param (dict): the super paramter for DB Loss
"""
def __init__(self,
balance_loss=True,
main_loss_type='DiceLoss',
alpha=5,
beta=10,
ohem_ratio=3,
eps=1e-6,
**kwargs):
super(DBLoss, self).__init__()
self.alpha = alpha
self.beta = beta
self.dice_loss = DiceLoss(eps=eps)
self.l1_loss = MaskL1Loss(eps=eps)
self.bce_loss = BalanceLoss(
balance_loss=balance_loss,
main_loss_type=main_loss_type,
negative_ratio=ohem_ratio)
def forward(self, predicts, labels):
predict_maps = predicts['maps']
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[
1:]
shrink_maps = predict_maps[:, 0, :, :]
threshold_maps = predict_maps[:, 1, :, :]
binary_maps = predict_maps[:, 2, :, :]
loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,
label_shrink_mask)
loss_threshold_maps = self.l1_loss(threshold_maps, label_threshold_map,
label_threshold_mask)
loss_binary_maps = self.dice_loss(binary_maps, label_shrink_map,
label_shrink_mask)
loss_shrink_maps = self.alpha * loss_shrink_maps
loss_threshold_maps = self.beta * loss_threshold_maps
loss_all = loss_shrink_maps + loss_threshold_maps \
+ loss_binary_maps
losses = {'loss': loss_all, \
"loss_shrink_maps": loss_shrink_maps, \
"loss_threshold_maps": loss_threshold_maps, \
"loss_binary_maps": loss_binary_maps}
return losses

View File

@@ -1,63 +0,0 @@
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from .det_basic_loss import DiceLoss
class EASTLoss(nn.Layer):
"""
"""
def __init__(self,
eps=1e-6,
**kwargs):
super(EASTLoss, self).__init__()
self.dice_loss = DiceLoss(eps=eps)
def forward(self, predicts, labels):
l_score, l_geo, l_mask = labels[1:]
f_score = predicts['f_score']
f_geo = predicts['f_geo']
dice_loss = self.dice_loss(f_score, l_score, l_mask)
#smoooth_l1_loss
channels = 8
l_geo_split = paddle.split(
l_geo, num_or_sections=channels + 1, axis=1)
f_geo_split = paddle.split(f_geo, num_or_sections=channels, axis=1)
smooth_l1 = 0
for i in range(0, channels):
geo_diff = l_geo_split[i] - f_geo_split[i]
abs_geo_diff = paddle.abs(geo_diff)
smooth_l1_sign = paddle.less_than(abs_geo_diff, l_score)
smooth_l1_sign = paddle.cast(smooth_l1_sign, dtype='float32')
in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + \
(abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign)
out_loss = l_geo_split[-1] / channels * in_loss * l_score
smooth_l1 += out_loss
smooth_l1_loss = paddle.mean(smooth_l1 * l_score)
dice_loss = dice_loss * 0.01
total_loss = dice_loss + smooth_l1_loss
losses = {"loss":total_loss, \
"dice_loss":dice_loss,\
"smooth_l1_loss":smooth_l1_loss}
return losses

View File

@@ -1,227 +0,0 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/losses/fce_loss.py
"""
import numpy as np
from paddle import nn
import paddle
import paddle.nn.functional as F
from functools import partial
def multi_apply(func, *args, **kwargs):
pfunc = partial(func, **kwargs) if kwargs else func
map_results = map(pfunc, *args)
return tuple(map(list, zip(*map_results)))
class FCELoss(nn.Layer):
"""The class for implementing FCENet loss
FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped
Text Detection
[https://arxiv.org/abs/2104.10442]
Args:
fourier_degree (int) : The maximum Fourier transform degree k.
num_sample (int) : The sampling points number of regression
loss. If it is too small, fcenet tends to be overfitting.
ohem_ratio (float): the negative/positive ratio in OHEM.
"""
def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
super().__init__()
self.fourier_degree = fourier_degree
self.num_sample = num_sample
self.ohem_ratio = ohem_ratio
def forward(self, preds, labels):
assert isinstance(preds, dict)
preds = preds['levels']
p3_maps, p4_maps, p5_maps = labels[1:]
assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
'fourier degree not equal in FCEhead and FCEtarget'
# to tensor
gts = [p3_maps, p4_maps, p5_maps]
for idx, maps in enumerate(gts):
gts[idx] = paddle.to_tensor(np.stack(maps))
losses = multi_apply(self.forward_single, preds, gts)
loss_tr = paddle.to_tensor(0.).astype('float32')
loss_tcl = paddle.to_tensor(0.).astype('float32')
loss_reg_x = paddle.to_tensor(0.).astype('float32')
loss_reg_y = paddle.to_tensor(0.).astype('float32')
loss_all = paddle.to_tensor(0.).astype('float32')
for idx, loss in enumerate(losses):
loss_all += sum(loss)
if idx == 0:
loss_tr += sum(loss)
elif idx == 1:
loss_tcl += sum(loss)
elif idx == 2:
loss_reg_x += sum(loss)
else:
loss_reg_y += sum(loss)
results = dict(
loss=loss_all,
loss_text=loss_tr,
loss_center=loss_tcl,
loss_reg_x=loss_reg_x,
loss_reg_y=loss_reg_y, )
return results
def forward_single(self, pred, gt):
cls_pred = paddle.transpose(pred[0], (0, 2, 3, 1))
reg_pred = paddle.transpose(pred[1], (0, 2, 3, 1))
gt = paddle.transpose(gt, (0, 2, 3, 1))
k = 2 * self.fourier_degree + 1
tr_pred = paddle.reshape(cls_pred[:, :, :, :2], (-1, 2))
tcl_pred = paddle.reshape(cls_pred[:, :, :, 2:], (-1, 2))
x_pred = paddle.reshape(reg_pred[:, :, :, 0:k], (-1, k))
y_pred = paddle.reshape(reg_pred[:, :, :, k:2 * k], (-1, k))
tr_mask = gt[:, :, :, :1].reshape([-1])
tcl_mask = gt[:, :, :, 1:2].reshape([-1])
train_mask = gt[:, :, :, 2:3].reshape([-1])
x_map = paddle.reshape(gt[:, :, :, 3:3 + k], (-1, k))
y_map = paddle.reshape(gt[:, :, :, 3 + k:], (-1, k))
tr_train_mask = (train_mask * tr_mask).astype('bool')
tr_train_mask2 = paddle.concat(
[tr_train_mask.unsqueeze(1), tr_train_mask.unsqueeze(1)], axis=1)
# tr loss
loss_tr = self.ohem(tr_pred, tr_mask, train_mask)
# tcl loss
loss_tcl = paddle.to_tensor(0.).astype('float32')
tr_neg_mask = tr_train_mask.logical_not()
tr_neg_mask2 = paddle.concat(
[tr_neg_mask.unsqueeze(1), tr_neg_mask.unsqueeze(1)], axis=1)
if tr_train_mask.sum().item() > 0:
loss_tcl_pos = F.cross_entropy(
tcl_pred.masked_select(tr_train_mask2).reshape([-1, 2]),
tcl_mask.masked_select(tr_train_mask).astype('int64'))
loss_tcl_neg = F.cross_entropy(
tcl_pred.masked_select(tr_neg_mask2).reshape([-1, 2]),
tcl_mask.masked_select(tr_neg_mask).astype('int64'))
loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg
# regression loss
loss_reg_x = paddle.to_tensor(0.).astype('float32')
loss_reg_y = paddle.to_tensor(0.).astype('float32')
if tr_train_mask.sum().item() > 0:
weight = (tr_mask.masked_select(tr_train_mask.astype('bool'))
.astype('float32') + tcl_mask.masked_select(
tr_train_mask.astype('bool')).astype('float32')) / 2
weight = weight.reshape([-1, 1])
ft_x, ft_y = self.fourier2poly(x_map, y_map)
ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)
dim = ft_x.shape[1]
tr_train_mask3 = paddle.concat(
[tr_train_mask.unsqueeze(1) for i in range(dim)], axis=1)
loss_reg_x = paddle.mean(weight * F.smooth_l1_loss(
ft_x_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
ft_x.masked_select(tr_train_mask3).reshape([-1, dim]),
reduction='none'))
loss_reg_y = paddle.mean(weight * F.smooth_l1_loss(
ft_y_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
ft_y.masked_select(tr_train_mask3).reshape([-1, dim]),
reduction='none'))
return loss_tr, loss_tcl, loss_reg_x, loss_reg_y
def ohem(self, predict, target, train_mask):
pos = (target * train_mask).astype('bool')
neg = ((1 - target) * train_mask).astype('bool')
pos2 = paddle.concat([pos.unsqueeze(1), pos.unsqueeze(1)], axis=1)
neg2 = paddle.concat([neg.unsqueeze(1), neg.unsqueeze(1)], axis=1)
n_pos = pos.astype('float32').sum()
if n_pos.item() > 0:
loss_pos = F.cross_entropy(
predict.masked_select(pos2).reshape([-1, 2]),
target.masked_select(pos).astype('int64'),
reduction='sum')
loss_neg = F.cross_entropy(
predict.masked_select(neg2).reshape([-1, 2]),
target.masked_select(neg).astype('int64'),
reduction='none')
n_neg = min(
int(neg.astype('float32').sum().item()),
int(self.ohem_ratio * n_pos.astype('float32')))
else:
loss_pos = paddle.to_tensor(0.)
loss_neg = F.cross_entropy(
predict.masked_select(neg2).reshape([-1, 2]),
target.masked_select(neg).astype('int64'),
reduction='none')
n_neg = 100
if len(loss_neg) > n_neg:
loss_neg, _ = paddle.topk(loss_neg, n_neg)
return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).astype('float32')
def fourier2poly(self, real_maps, imag_maps):
"""Transform Fourier coefficient maps to polygon maps.
Args:
real_maps (tensor): A map composed of the real parts of the
Fourier coefficients, whose shape is (-1, 2k+1)
imag_maps (tensor):A map composed of the imag parts of the
Fourier coefficients, whose shape is (-1, 2k+1)
Returns
x_maps (tensor): A map composed of the x value of the polygon
represented by n sample points (xn, yn), whose shape is (-1, n)
y_maps (tensor): A map composed of the y value of the polygon
represented by n sample points (xn, yn), whose shape is (-1, n)
"""
k_vect = paddle.arange(
-self.fourier_degree, self.fourier_degree + 1,
dtype='float32').reshape([-1, 1])
i_vect = paddle.arange(
0, self.num_sample, dtype='float32').reshape([1, -1])
transform_matrix = 2 * np.pi / self.num_sample * paddle.matmul(k_vect,
i_vect)
x1 = paddle.einsum('ak, kn-> an', real_maps,
paddle.cos(transform_matrix))
x2 = paddle.einsum('ak, kn-> an', imag_maps,
paddle.sin(transform_matrix))
y1 = paddle.einsum('ak, kn-> an', real_maps,
paddle.sin(transform_matrix))
y2 = paddle.einsum('ak, kn-> an', imag_maps,
paddle.cos(transform_matrix))
x_maps = x1 - x2
y_maps = y1 + y2
return x_maps, y_maps

View File

@@ -1,149 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py
"""
import paddle
from paddle import nn
from paddle.nn import functional as F
import numpy as np
from ppocr.utils.iou import iou
class PSELoss(nn.Layer):
def __init__(self,
alpha,
ohem_ratio=3,
kernel_sample_mask='pred',
reduction='sum',
eps=1e-6,
**kwargs):
"""Implement PSE Loss.
"""
super(PSELoss, self).__init__()
assert reduction in ['sum', 'mean', 'none']
self.alpha = alpha
self.ohem_ratio = ohem_ratio
self.kernel_sample_mask = kernel_sample_mask
self.reduction = reduction
self.eps = eps
def forward(self, outputs, labels):
predicts = outputs['maps']
predicts = F.interpolate(predicts, scale_factor=4)
texts = predicts[:, 0, :, :]
kernels = predicts[:, 1:, :, :]
gt_texts, gt_kernels, training_masks = labels[1:]
# text loss
selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
loss_text = self.dice_loss(texts, gt_texts, selected_masks)
iou_text = iou((texts > 0).astype('int64'),
gt_texts,
training_masks,
reduce=False)
losses = dict(loss_text=loss_text, iou_text=iou_text)
# kernel loss
loss_kernels = []
if self.kernel_sample_mask == 'gt':
selected_masks = gt_texts * training_masks
elif self.kernel_sample_mask == 'pred':
selected_masks = (
F.sigmoid(texts) > 0.5).astype('float32') * training_masks
for i in range(kernels.shape[1]):
kernel_i = kernels[:, i, :, :]
gt_kernel_i = gt_kernels[:, i, :, :]
loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i,
selected_masks)
loss_kernels.append(loss_kernel_i)
loss_kernels = paddle.mean(paddle.stack(loss_kernels, axis=1), axis=1)
iou_kernel = iou((kernels[:, -1, :, :] > 0).astype('int64'),
gt_kernels[:, -1, :, :],
training_masks * gt_texts,
reduce=False)
losses.update(dict(loss_kernels=loss_kernels, iou_kernel=iou_kernel))
loss = self.alpha * loss_text + (1 - self.alpha) * loss_kernels
losses['loss'] = loss
if self.reduction == 'sum':
losses = {x: paddle.sum(v) for x, v in losses.items()}
elif self.reduction == 'mean':
losses = {x: paddle.mean(v) for x, v in losses.items()}
return losses
def dice_loss(self, input, target, mask):
input = F.sigmoid(input)
input = input.reshape([input.shape[0], -1])
target = target.reshape([target.shape[0], -1])
mask = mask.reshape([mask.shape[0], -1])
input = input * mask
target = target * mask
a = paddle.sum(input * target, 1)
b = paddle.sum(input * input, 1) + self.eps
c = paddle.sum(target * target, 1) + self.eps
d = (2 * a) / (b + c)
return 1 - d
def ohem_single(self, score, gt_text, training_mask, ohem_ratio=3):
pos_num = int(paddle.sum((gt_text > 0.5).astype('float32'))) - int(
paddle.sum(
paddle.logical_and((gt_text > 0.5), (training_mask <= 0.5))
.astype('float32')))
if pos_num == 0:
selected_mask = training_mask
selected_mask = selected_mask.reshape(
[1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
'float32')
return selected_mask
neg_num = int(paddle.sum((gt_text <= 0.5).astype('float32')))
neg_num = int(min(pos_num * ohem_ratio, neg_num))
if neg_num == 0:
selected_mask = training_mask
selected_mask = selected_mask.reshape(
[1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
'float32')
return selected_mask
neg_score = paddle.masked_select(score, gt_text <= 0.5)
neg_score_sorted = paddle.sort(-neg_score)
threshold = -neg_score_sorted[neg_num - 1]
selected_mask = paddle.logical_and(
paddle.logical_or((score >= threshold), (gt_text > 0.5)),
(training_mask > 0.5))
selected_mask = selected_mask.reshape(
[1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
'float32')
return selected_mask
def ohem_batch(self, scores, gt_texts, training_masks, ohem_ratio=3):
selected_masks = []
for i in range(scores.shape[0]):
selected_masks.append(
self.ohem_single(scores[i, :, :], gt_texts[i, :, :],
training_masks[i, :, :], ohem_ratio))
selected_masks = paddle.concat(selected_masks, 0).astype('float32')
return selected_masks

View File

@@ -1,121 +0,0 @@
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from .det_basic_loss import DiceLoss
import numpy as np
class SASTLoss(nn.Layer):
"""
"""
def __init__(self, eps=1e-6, **kwargs):
super(SASTLoss, self).__init__()
self.dice_loss = DiceLoss(eps=eps)
def forward(self, predicts, labels):
"""
tcl_pos: N x 128 x 3
tcl_mask: N x 128 x 1
tcl_label: N x X list or LoDTensor
"""
f_score = predicts['f_score']
f_border = predicts['f_border']
f_tvo = predicts['f_tvo']
f_tco = predicts['f_tco']
l_score, l_border, l_mask, l_tvo, l_tco = labels[1:]
#score_loss
intersection = paddle.sum(f_score * l_score * l_mask)
union = paddle.sum(f_score * l_mask) + paddle.sum(l_score * l_mask)
score_loss = 1.0 - 2 * intersection / (union + 1e-5)
#border loss
l_border_split, l_border_norm = paddle.split(
l_border, num_or_sections=[4, 1], axis=1)
f_border_split = f_border
border_ex_shape = l_border_norm.shape * np.array([1, 4, 1, 1])
l_border_norm_split = paddle.expand(
x=l_border_norm, shape=border_ex_shape)
l_border_score = paddle.expand(x=l_score, shape=border_ex_shape)
l_border_mask = paddle.expand(x=l_mask, shape=border_ex_shape)
border_diff = l_border_split - f_border_split
abs_border_diff = paddle.abs(border_diff)
border_sign = abs_border_diff < 1.0
border_sign = paddle.cast(border_sign, dtype='float32')
border_sign.stop_gradient = True
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
(abs_border_diff - 0.5) * (1.0 - border_sign)
border_out_loss = l_border_norm_split * border_in_loss
border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
(paddle.sum(l_border_score * l_border_mask) + 1e-5)
#tvo_loss
l_tvo_split, l_tvo_norm = paddle.split(
l_tvo, num_or_sections=[8, 1], axis=1)
f_tvo_split = f_tvo
tvo_ex_shape = l_tvo_norm.shape * np.array([1, 8, 1, 1])
l_tvo_norm_split = paddle.expand(x=l_tvo_norm, shape=tvo_ex_shape)
l_tvo_score = paddle.expand(x=l_score, shape=tvo_ex_shape)
l_tvo_mask = paddle.expand(x=l_mask, shape=tvo_ex_shape)
#
tvo_geo_diff = l_tvo_split - f_tvo_split
abs_tvo_geo_diff = paddle.abs(tvo_geo_diff)
tvo_sign = abs_tvo_geo_diff < 1.0
tvo_sign = paddle.cast(tvo_sign, dtype='float32')
tvo_sign.stop_gradient = True
tvo_in_loss = 0.5 * abs_tvo_geo_diff * abs_tvo_geo_diff * tvo_sign + \
(abs_tvo_geo_diff - 0.5) * (1.0 - tvo_sign)
tvo_out_loss = l_tvo_norm_split * tvo_in_loss
tvo_loss = paddle.sum(tvo_out_loss * l_tvo_score * l_tvo_mask) / \
(paddle.sum(l_tvo_score * l_tvo_mask) + 1e-5)
#tco_loss
l_tco_split, l_tco_norm = paddle.split(
l_tco, num_or_sections=[2, 1], axis=1)
f_tco_split = f_tco
tco_ex_shape = l_tco_norm.shape * np.array([1, 2, 1, 1])
l_tco_norm_split = paddle.expand(x=l_tco_norm, shape=tco_ex_shape)
l_tco_score = paddle.expand(x=l_score, shape=tco_ex_shape)
l_tco_mask = paddle.expand(x=l_mask, shape=tco_ex_shape)
tco_geo_diff = l_tco_split - f_tco_split
abs_tco_geo_diff = paddle.abs(tco_geo_diff)
tco_sign = abs_tco_geo_diff < 1.0
tco_sign = paddle.cast(tco_sign, dtype='float32')
tco_sign.stop_gradient = True
tco_in_loss = 0.5 * abs_tco_geo_diff * abs_tco_geo_diff * tco_sign + \
(abs_tco_geo_diff - 0.5) * (1.0 - tco_sign)
tco_out_loss = l_tco_norm_split * tco_in_loss
tco_loss = paddle.sum(tco_out_loss * l_tco_score * l_tco_mask) / \
(paddle.sum(l_tco_score * l_tco_mask) + 1e-5)
# total loss
tvo_lw, tco_lw = 1.5, 1.5
score_lw, border_lw = 1.0, 1.0
total_loss = score_loss * score_lw + border_loss * border_lw + \
tvo_loss * tvo_lw + tco_loss * tco_lw
losses = {'loss':total_loss, "score_loss":score_loss,\
"border_loss":border_loss, 'tvo_loss':tvo_loss, 'tco_loss':tco_loss}
return losses

View File

@@ -1,324 +0,0 @@
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
import numpy as np
import cv2
from .rec_ctc_loss import CTCLoss
from .rec_sar_loss import SARLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
from .det_db_loss import DBLoss
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
def _sum_loss(loss_dict):
if "loss" in loss_dict.keys():
return loss_dict
else:
loss_dict["loss"] = 0.
for k, value in loss_dict.items():
if k == "loss":
continue
else:
loss_dict["loss"] += value
return loss_dict
class DistillationDMLLoss(DMLLoss):
"""
"""
def __init__(self,
model_name_pairs=[],
act=None,
use_log=False,
key=None,
multi_head=False,
dis_head='ctc',
maps_name=None,
name="dml"):
super().__init__(act=act, use_log=use_log)
assert isinstance(model_name_pairs, list)
self.key = key
self.multi_head = multi_head
self.dis_head = dis_head
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name
self.maps_name = self._check_maps_name(maps_name)
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
def _check_maps_name(self, maps_name):
if maps_name is None:
return None
elif type(maps_name) == str:
return [maps_name]
elif type(maps_name) == list:
return [maps_name]
else:
return None
def _slice_out(self, outs):
new_outs = {}
for k in self.maps_name:
if k == "thrink_maps":
new_outs[k] = outs[:, 0, :, :]
elif k == "threshold_maps":
new_outs[k] = outs[:, 1, :, :]
elif k == "binary_maps":
new_outs[k] = outs[:, 2, :, :]
else:
continue
return new_outs
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
if self.maps_name is None:
if self.multi_head:
loss = super().forward(out1[self.dis_head],
out2[self.dis_head])
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
else:
outs1 = self._slice_out(out1)
outs2 = self._slice_out(out2)
for _c, k in enumerate(outs1.keys()):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], self.maps_name, idx)] = loss[key]
else:
loss_dict["{}_{}_{}".format(self.name, self.maps_name[
_c], idx)] = loss
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationCTCLoss(CTCLoss):
def __init__(self,
model_name_list=[],
key=None,
multi_head=False,
name="loss_ctc"):
super().__init__()
self.model_name_list = model_name_list
self.key = key
self.name = name
self.multi_head = multi_head
def forward(self, predicts, batch):
loss_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
if self.multi_head:
assert 'ctc' in out, 'multi head has multi out'
loss = super().forward(out['ctc'], batch[:2] + batch[3:])
else:
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}".format(self.name, model_name,
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
return loss_dict
class DistillationSARLoss(SARLoss):
def __init__(self,
model_name_list=[],
key=None,
multi_head=False,
name="loss_sar",
**kwargs):
ignore_index = kwargs.get('ignore_index', 92)
super().__init__(ignore_index=ignore_index)
self.model_name_list = model_name_list
self.key = key
self.name = name
self.multi_head = multi_head
def forward(self, predicts, batch):
loss_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
if self.multi_head:
assert 'sar' in out, 'multi head has multi out'
loss = super().forward(out['sar'], batch[:1] + batch[2:])
else:
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}".format(self.name, model_name,
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
return loss_dict
class DistillationDBLoss(DBLoss):
def __init__(self,
model_name_list=[],
balance_loss=True,
main_loss_type='DiceLoss',
alpha=5,
beta=10,
ohem_ratio=3,
eps=1e-6,
name="db",
**kwargs):
super().__init__()
self.model_name_list = model_name_list
self.name = name
self.key = None
def forward(self, predicts, batch):
loss_dict = {}
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss.keys():
if key == "loss":
continue
name = "{}_{}_{}".format(self.name, model_name, key)
loss_dict[name] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationDilaDBLoss(DBLoss):
def __init__(self,
model_name_pairs=[],
key=None,
balance_loss=True,
main_loss_type='DiceLoss',
alpha=5,
beta=10,
ohem_ratio=3,
eps=1e-6,
name="dila_dbloss"):
super().__init__()
self.model_name_pairs = model_name_pairs
self.name = name
self.key = key
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
stu_outs = predicts[pair[0]]
tch_outs = predicts[pair[1]]
if self.key is not None:
stu_preds = stu_outs[self.key]
tch_preds = tch_outs[self.key]
stu_shrink_maps = stu_preds[:, 0, :, :]
stu_binary_maps = stu_preds[:, 2, :, :]
# dilation to teacher prediction
dilation_w = np.array([[1, 1], [1, 1]])
th_shrink_maps = tch_preds[:, 0, :, :]
th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
for i in range(th_shrink_maps.shape[0]):
dilate_maps[i] = cv2.dilate(
th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
th_shrink_maps = paddle.to_tensor(dilate_maps)
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
1:]
# calculate the shrink map loss
bce_loss = self.alpha * self.bce_loss(
stu_shrink_maps, th_shrink_maps, label_shrink_mask)
loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
label_shrink_mask)
# k = f"{self.name}_{pair[0]}_{pair[1]}"
k = "{}_{}_{}".format(self.name, pair[0], pair[1])
loss_dict[k] = bce_loss + loss_binary_maps
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationDistanceLoss(DistanceLoss):
"""
"""
def __init__(self,
mode="l2",
model_name_pairs=[],
key=None,
name="loss_distance",
**kargs):
super().__init__(mode=mode, **kargs)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name + "_l2"
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
key]
else:
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict

View File

@@ -1,140 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
import paddle
from .det_basic_loss import DiceLoss
from ppocr.utils.e2e_utils.extract_batchsize import pre_process
class PGLoss(nn.Layer):
def __init__(self,
tcl_bs,
max_text_length,
max_text_nums,
pad_num,
eps=1e-6,
**kwargs):
super(PGLoss, self).__init__()
self.tcl_bs = tcl_bs
self.max_text_nums = max_text_nums
self.max_text_length = max_text_length
self.pad_num = pad_num
self.dice_loss = DiceLoss(eps=eps)
def border_loss(self, f_border, l_border, l_score, l_mask):
l_border_split, l_border_norm = paddle.tensor.split(
l_border, num_or_sections=[4, 1], axis=1)
f_border_split = f_border
b, c, h, w = l_border_norm.shape
l_border_norm_split = paddle.expand(
x=l_border_norm, shape=[b, 4 * c, h, w])
b, c, h, w = l_score.shape
l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w])
b, c, h, w = l_mask.shape
l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w])
border_diff = l_border_split - f_border_split
abs_border_diff = paddle.abs(border_diff)
border_sign = abs_border_diff < 1.0
border_sign = paddle.cast(border_sign, dtype='float32')
border_sign.stop_gradient = True
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
(abs_border_diff - 0.5) * (1.0 - border_sign)
border_out_loss = l_border_norm_split * border_in_loss
border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
(paddle.sum(l_border_score * l_border_mask) + 1e-5)
return border_loss
def direction_loss(self, f_direction, l_direction, l_score, l_mask):
l_direction_split, l_direction_norm = paddle.tensor.split(
l_direction, num_or_sections=[2, 1], axis=1)
f_direction_split = f_direction
b, c, h, w = l_direction_norm.shape
l_direction_norm_split = paddle.expand(
x=l_direction_norm, shape=[b, 2 * c, h, w])
b, c, h, w = l_score.shape
l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w])
b, c, h, w = l_mask.shape
l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w])
direction_diff = l_direction_split - f_direction_split
abs_direction_diff = paddle.abs(direction_diff)
direction_sign = abs_direction_diff < 1.0
direction_sign = paddle.cast(direction_sign, dtype='float32')
direction_sign.stop_gradient = True
direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \
(abs_direction_diff - 0.5) * (1.0 - direction_sign)
direction_out_loss = l_direction_norm_split * direction_in_loss
direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \
(paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
return direction_loss
def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
f_char = paddle.transpose(f_char, [0, 2, 3, 1])
tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
tcl_pos = paddle.cast(tcl_pos, dtype=int)
f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
f_tcl_char = paddle.reshape(f_tcl_char,
[-1, 64, 37]) # len(Lexicon_Table)+1
f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2)
f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
b, c, l = tcl_mask.shape
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l])
tcl_mask_fg.stop_gradient = True
f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
-20.0)
f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2)
f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
N, B, _ = f_tcl_char_ld.shape
input_lengths = paddle.to_tensor([N] * B, dtype='int64')
cost = paddle.nn.functional.ctc_loss(
log_probs=f_tcl_char_ld,
labels=tcl_label,
input_lengths=input_lengths,
label_lengths=label_t,
blank=self.pad_num,
reduction='none')
cost = cost.mean()
return cost
def forward(self, predicts, labels):
images, tcl_maps, tcl_label_maps, border_maps \
, direction_maps, training_masks, label_list, pos_list, pos_mask = labels
# for all the batch_size
pos_list, pos_mask, label_list, label_t = pre_process(
label_list, pos_list, pos_mask, self.max_text_length,
self.max_text_nums, self.pad_num, self.tcl_bs)
f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \
predicts['f_char']
score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
border_loss = self.border_loss(f_border, border_maps, tcl_maps,
training_masks)
direction_loss = self.direction_loss(f_direction, direction_maps,
tcl_maps, training_masks)
ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t)
loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss
losses = {
'loss': loss_all,
"score_loss": score_loss,
"border_loss": border_loss,
"direction_loss": direction_loss,
"ctc_loss": ctc_loss
}
return losses

View File

@@ -1,115 +0,0 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# reference from : https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/losses/sdmgr_loss.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
import paddle
class SDMGRLoss(nn.Layer):
def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0):
super().__init__()
self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
self.node_weight = node_weight
self.edge_weight = edge_weight
self.ignore = ignore
def pre_process(self, gts, tag):
gts, tag = gts.numpy(), tag.numpy().tolist()
temp_gts = []
batch = len(tag)
for i in range(batch):
num, recoder_len = tag[i][0], tag[i][1]
temp_gts.append(
paddle.to_tensor(
gts[i, :num, :num + 1], dtype='int64'))
return temp_gts
def accuracy(self, pred, target, topk=1, thresh=None):
"""Calculate accuracy according to the prediction and target.
Args:
pred (torch.Tensor): The model prediction, shape (N, num_class)
target (torch.Tensor): The target of each prediction, shape (N, )
topk (int | tuple[int], optional): If the predictions in ``topk``
matches the target, the predictions will be regarded as
correct ones. Defaults to 1.
thresh (float, optional): If not None, predictions with scores under
this threshold are considered incorrect. Default to None.
Returns:
float | tuple[float]: If the input ``topk`` is a single integer,
the function will return a single float as accuracy. If
``topk`` is a tuple containing multiple integers, the
function will return a tuple containing accuracies of
each ``topk`` number.
"""
assert isinstance(topk, (int, tuple))
if isinstance(topk, int):
topk = (topk, )
return_single = True
else:
return_single = False
maxk = max(topk)
if pred.shape[0] == 0:
accu = [pred.new_tensor(0.) for i in range(len(topk))]
return accu[0] if return_single else accu
pred_value, pred_label = paddle.topk(pred, maxk, axis=1)
pred_label = pred_label.transpose(
[1, 0]) # transpose to shape (maxk, N)
correct = paddle.equal(pred_label,
(target.reshape([1, -1]).expand_as(pred_label)))
res = []
for k in topk:
correct_k = paddle.sum(correct[:k].reshape([-1]).astype('float32'),
axis=0,
keepdim=True)
res.append(
paddle.multiply(correct_k,
paddle.to_tensor(100.0 / pred.shape[0])))
return res[0] if return_single else res
def forward(self, pred, batch):
node_preds, edge_preds = pred
gts, tag = batch[4], batch[5]
gts = self.pre_process(gts, tag)
node_gts, edge_gts = [], []
for gt in gts:
node_gts.append(gt[:, 0])
edge_gts.append(gt[:, 1:].reshape([-1]))
node_gts = paddle.concat(node_gts)
edge_gts = paddle.concat(edge_gts)
node_valids = paddle.nonzero(node_gts != self.ignore).reshape([-1])
edge_valids = paddle.nonzero(edge_gts != -1).reshape([-1])
loss_node = self.loss_node(node_preds, node_gts)
loss_edge = self.loss_edge(edge_preds, edge_gts)
loss = self.node_weight * loss_node + self.edge_weight * loss_edge
return dict(
loss=loss,
loss_node=loss_node,
loss_edge=loss_edge,
acc_node=self.accuracy(
paddle.gather(node_preds, node_valids),
paddle.gather(node_gts, node_valids)),
acc_edge=self.accuracy(
paddle.gather(edge_preds, edge_valids),
paddle.gather(edge_gts, edge_valids)))

View File

@@ -1,99 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class CosineEmbeddingLoss(nn.Layer):
def __init__(self, margin=0.):
super(CosineEmbeddingLoss, self).__init__()
self.margin = margin
self.epsilon = 1e-12
def forward(self, x1, x2, target):
similarity = paddle.fluid.layers.reduce_sum(
x1 * x2, dim=-1) / (paddle.norm(
x1, axis=-1) * paddle.norm(
x2, axis=-1) + self.epsilon)
one_list = paddle.full_like(target, fill_value=1)
out = paddle.fluid.layers.reduce_mean(
paddle.where(
paddle.equal(target, one_list), 1. - similarity,
paddle.maximum(
paddle.zeros_like(similarity), similarity - self.margin)))
return out
class AsterLoss(nn.Layer):
def __init__(self,
weight=None,
size_average=True,
ignore_index=-100,
sequence_normalize=False,
sample_normalize=True,
**kwargs):
super(AsterLoss, self).__init__()
self.weight = weight
self.size_average = size_average
self.ignore_index = ignore_index
self.sequence_normalize = sequence_normalize
self.sample_normalize = sample_normalize
self.loss_sem = CosineEmbeddingLoss()
self.is_cosin_loss = True
self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none')
def forward(self, predicts, batch):
targets = batch[1].astype("int64")
label_lengths = batch[2].astype('int64')
sem_target = batch[3].astype('float32')
embedding_vectors = predicts['embedding_vectors']
rec_pred = predicts['rec_pred']
if not self.is_cosin_loss:
sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target))
else:
label_target = paddle.ones([embedding_vectors.shape[0]])
sem_loss = paddle.sum(
self.loss_sem(embedding_vectors, sem_target, label_target))
# rec loss
batch_size, def_max_length = targets.shape[0], targets.shape[1]
mask = paddle.zeros([batch_size, def_max_length])
for i in range(batch_size):
mask[i, :label_lengths[i]] = 1
mask = paddle.cast(mask, "float32")
max_length = max(label_lengths)
assert max_length == rec_pred.shape[1]
targets = targets[:, :max_length]
mask = mask[:, :max_length]
rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]])
input = nn.functional.log_softmax(rec_pred, axis=1)
targets = paddle.reshape(targets, [-1, 1])
mask = paddle.reshape(mask, [-1, 1])
output = -paddle.index_sample(input, index=targets) * mask
output = paddle.sum(output)
if self.sequence_normalize:
output = output / paddle.sum(mask)
if self.sample_normalize:
output = output / batch_size
loss = output + sem_loss * 0.1
return {'loss': loss}

View File

@@ -1,39 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class AttentionLoss(nn.Layer):
def __init__(self, **kwargs):
super(AttentionLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
def forward(self, predicts, batch):
targets = batch[1].astype("int64")
label_lengths = batch[2].astype('int64')
batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[
1], predicts.shape[2]
assert len(targets.shape) == len(list(predicts.shape)) - 1, \
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
targets = paddle.reshape(targets, [-1])
return {'loss': paddle.sum(self.loss_func(inputs, targets))}

View File

@@ -1,45 +0,0 @@
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class CTCLoss(nn.Layer):
def __init__(self, use_focal_loss=False, **kwargs):
super(CTCLoss, self).__init__()
self.loss_func = nn.CTCLoss(blank=0, reduction='none')
self.use_focal_loss = use_focal_loss
def forward(self, predicts, batch):
if isinstance(predicts, (list, tuple)):
predicts = predicts[-1]
predicts = predicts.transpose((1, 0, 2))
N, B, _ = predicts.shape
preds_lengths = paddle.to_tensor(
[N] * B, dtype='int64', place=paddle.CPUPlace())
labels = batch[1].astype("int32")
label_lengths = batch[2].astype('int64')
loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
if self.use_focal_loss:
weight = paddle.exp(-loss)
weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
weight = paddle.square(weight)
loss = paddle.multiply(loss, weight)
loss = loss.mean()
return {'loss': loss}

View File

@@ -1,70 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from .ace_loss import ACELoss
from .center_loss import CenterLoss
from .rec_ctc_loss import CTCLoss
class EnhancedCTCLoss(nn.Layer):
def __init__(self,
use_focal_loss=False,
use_ace_loss=False,
ace_loss_weight=0.1,
use_center_loss=False,
center_loss_weight=0.05,
num_classes=6625,
feat_dim=96,
init_center=False,
center_file_path=None,
**kwargs):
super(EnhancedCTCLoss, self).__init__()
self.ctc_loss_func = CTCLoss(use_focal_loss=use_focal_loss)
self.use_ace_loss = False
if use_ace_loss:
self.use_ace_loss = use_ace_loss
self.ace_loss_func = ACELoss()
self.ace_loss_weight = ace_loss_weight
self.use_center_loss = False
if use_center_loss:
self.use_center_loss = use_center_loss
self.center_loss_func = CenterLoss(
num_classes=num_classes,
feat_dim=feat_dim,
init_center=init_center,
center_file_path=center_file_path)
self.center_loss_weight = center_loss_weight
def __call__(self, predicts, batch):
loss = self.ctc_loss_func(predicts, batch)["loss"]
if self.use_center_loss:
center_loss = self.center_loss_func(
predicts, batch)["loss_center"] * self.center_loss_weight
loss = loss + center_loss
if self.use_ace_loss:
ace_loss = self.ace_loss_func(
predicts, batch)["loss_ace"] * self.ace_loss_weight
loss = loss + ace_loss
return {'enhanced_ctc_loss': loss}

View File

@@ -1,58 +0,0 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from .rec_ctc_loss import CTCLoss
from .rec_sar_loss import SARLoss
class MultiLoss(nn.Layer):
def __init__(self, **kwargs):
super().__init__()
self.loss_funcs = {}
self.loss_list = kwargs.pop('loss_config_list')
self.weight_1 = kwargs.get('weight_1', 1.0)
self.weight_2 = kwargs.get('weight_2', 1.0)
self.gtc_loss = kwargs.get('gtc_loss', 'sar')
for loss_info in self.loss_list:
for name, param in loss_info.items():
if param is not None:
kwargs.update(param)
loss = eval(name)(**kwargs)
self.loss_funcs[name] = loss
def forward(self, predicts, batch):
self.total_loss = {}
total_loss = 0.0
# batch [image, label_ctc, label_sar, length, valid_ratio]
for name, loss_func in self.loss_funcs.items():
if name == 'CTCLoss':
loss = loss_func(predicts['ctc'],
batch[:2] + batch[3:])['loss'] * self.weight_1
elif name == 'SARLoss':
loss = loss_func(predicts['sar'],
batch[:1] + batch[2:])['loss'] * self.weight_2
else:
raise NotImplementedError(
'{} is not supported in MultiLoss yet'.format(name))
self.total_loss[name] = loss
total_loss += loss
self.total_loss['loss'] = total_loss
return self.total_loss

View File

@@ -1,30 +0,0 @@
import paddle
from paddle import nn
import paddle.nn.functional as F
class NRTRLoss(nn.Layer):
def __init__(self, smoothing=True, **kwargs):
super(NRTRLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
self.smoothing = smoothing
def forward(self, pred, batch):
pred = pred.reshape([-1, pred.shape[2]])
max_len = batch[2].max()
tgt = batch[1][:, 1:2 + max_len]
tgt = tgt.reshape([-1])
if self.smoothing:
eps = 0.1
n_class = pred.shape[1]
one_hot = F.one_hot(tgt, pred.shape[1])
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape, dtype=tgt.dtype))
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else:
loss = self.loss_func(pred, tgt)
return {'loss': loss}

View File

@@ -1,30 +0,0 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
class PRENLoss(nn.Layer):
def __init__(self, **kwargs):
super(PRENLoss, self).__init__()
# note: 0 is padding idx
self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
def forward(self, predicts, batch):
loss = self.loss_func(predicts, batch[1].astype('int64'))
return {'loss': loss}

View File

@@ -1,29 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class SARLoss(nn.Layer):
def __init__(self, **kwargs):
super(SARLoss, self).__init__()
ignore_index = kwargs.get('ignore_index', 92) # 6626
self.loss_func = paddle.nn.loss.CrossEntropyLoss(
reduction="mean", ignore_index=ignore_index)
def forward(self, predicts, batch):
predict = predicts[:, :
-1, :] # ignore last index of outputs to be in same seq_len with targets
label = batch[1].astype(
"int64")[:, 1:] # ignore first index of target in loss calculation
batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
1], predict.shape[2]
assert len(label.shape) == len(list(predict.shape)) - 1, \
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
inputs = paddle.reshape(predict, [-1, num_classes])
targets = paddle.reshape(label, [-1])
loss = self.loss_func(inputs, targets)
return {'loss': loss}

View File

@@ -1,47 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class SRNLoss(nn.Layer):
def __init__(self, **kwargs):
super(SRNLoss, self).__init__()
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
def forward(self, predicts, batch):
predict = predicts['predict']
word_predict = predicts['word_out']
gsrm_predict = predicts['gsrm_out']
label = batch[1]
casted_label = paddle.cast(x=label, dtype='int64')
casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
cost_word = self.loss_func(word_predict, label=casted_label)
cost_gsrm = self.loss_func(gsrm_predict, label=casted_label)
cost_vsfd = self.loss_func(predict, label=casted_label)
cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1])
cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}

View File

@@ -1,109 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle import fluid
class TableAttentionLoss(nn.Layer):
def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
super(TableAttentionLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
self.structure_weight = structure_weight
self.loc_weight = loc_weight
self.use_giou = use_giou
self.giou_weight = giou_weight
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
'''
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
:param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
:return: loss
'''
ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0])
iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1])
ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2])
iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3])
iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10)
ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10)
# overlap
inters = iw * ih
# union
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
# ious
ious = inters / uni
ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0])
ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1])
ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2])
ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3])
ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10)
eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10)
# enclose erea
enclose = ew * eh + eps
giou = ious - (enclose - uni) / enclose
loss = 1 - giou
if reduction == 'mean':
loss = paddle.mean(loss)
elif reduction == 'sum':
loss = paddle.sum(loss)
else:
raise NotImplementedError
return loss
def forward(self, predicts, batch):
structure_probs = predicts['structure_probs']
structure_targets = batch[1].astype("int64")
structure_targets = structure_targets[:, 1:]
if len(batch) == 6:
structure_mask = batch[5].astype("int64")
structure_mask = structure_mask[:, 1:]
structure_mask = paddle.reshape(structure_mask, [-1])
structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
structure_targets = paddle.reshape(structure_targets, [-1])
structure_loss = self.loss_func(structure_probs, structure_targets)
if len(batch) == 6:
structure_loss = structure_loss * structure_mask
# structure_loss = paddle.sum(structure_loss) * self.structure_weight
structure_loss = paddle.mean(structure_loss) * self.structure_weight
loc_preds = predicts['loc_preds']
loc_targets = batch[2].astype("float32")
loc_targets_mask = batch[4].astype("float32")
loc_targets = loc_targets[:, 1:, :]
loc_targets_mask = loc_targets_mask[:, 1:, :]
loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
if self.use_giou:
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
total_loss = structure_loss + loc_loss + loc_loss_giou
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
else:
total_loss = structure_loss + loc_loss
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}

View File

@@ -1,42 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
class VQASerTokenLayoutLMLoss(nn.Layer):
def __init__(self, num_classes):
super().__init__()
self.loss_class = nn.CrossEntropyLoss()
self.num_classes = num_classes
self.ignore_index = self.loss_class.ignore_index
def forward(self, predicts, batch):
labels = batch[1]
attention_mask = batch[4]
if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1
active_outputs = predicts.reshape(
[-1, self.num_classes])[active_loss]
active_labels = labels.reshape([-1, ])[active_loss]
loss = self.loss_class(active_outputs, active_labels)
else:
loss = self.loss_class(
predicts.reshape([-1, self.num_classes]),
labels.reshape([-1, ]))
return {'loss': loss}

View File

@@ -1,47 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import copy
__all__ = ["build_metric"]
from .det_metric import DetMetric, DetFCEMetric
from .rec_metric import RecMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
from .table_metric import TableMetric
from .kie_metric import KIEMetric
from .vqa_token_ser_metric import VQASerTokenMetric
from .vqa_token_re_metric import VQAReTokenMetric
def build_metric(config):
support_dict = [
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
'VQAReTokenMetric'
]
config = copy.deepcopy(config)
module_name = config.pop("name")
assert module_name in support_dict, Exception(
"metric only support {}".format(support_dict))
module_class = eval(module_name)(**config)
return module_class

View File

@@ -1,46 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ClsMetric(object):
def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator
self.eps = 1e-5
self.reset()
def __call__(self, pred_label, *args, **kwargs):
preds, labels = pred_label
correct_num = 0
all_num = 0
for (pred, pred_conf), (target, _) in zip(preds, labels):
if pred == target:
correct_num += 1
all_num += 1
self.correct_num += correct_num
self.all_num += all_num
return {'acc': correct_num / (all_num + self.eps), }
def get_metric(self):
"""
return metrics {
'acc': 0
}
"""
acc = self.correct_num / (self.all_num + self.eps)
self.reset()
return {'acc': acc}
def reset(self):
self.correct_num = 0
self.all_num = 0

View File

@@ -1,154 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__all__ = ['DetMetric', 'DetFCEMetric']
from .eval_det_iou import DetectionIoUEvaluator
class DetMetric(object):
def __init__(self, main_indicator='hmean', **kwargs):
self.evaluator = DetectionIoUEvaluator()
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
'''
batch: a list produced by dataloaders.
image: np.ndarray of shape (N, C, H, W).
ratio_list: np.ndarray of shape(N,2)
polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
preds: a list of dict produced by post process
points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
'''
gt_polyons_batch = batch[2]
ignore_tags_batch = batch[3]
for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch,
ignore_tags_batch):
# prepare gt
gt_info_list = [{
'points': gt_polyon,
'text': '',
'ignore': ignore_tag
} for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)]
# prepare det
det_info_list = [{
'points': det_polyon,
'text': ''
} for det_polyon in pred['points']]
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
self.results.append(result)
def get_metric(self):
"""
return metrics {
'precision': 0,
'recall': 0,
'hmean': 0
}
"""
metrics = self.evaluator.combine_results(self.results)
self.reset()
return metrics
def reset(self):
self.results = [] # clear results
class DetFCEMetric(object):
def __init__(self, main_indicator='hmean', **kwargs):
self.evaluator = DetectionIoUEvaluator()
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
'''
batch: a list produced by dataloaders.
image: np.ndarray of shape (N, C, H, W).
ratio_list: np.ndarray of shape(N,2)
polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
preds: a list of dict produced by post process
points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
'''
gt_polyons_batch = batch[2]
ignore_tags_batch = batch[3]
for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch,
ignore_tags_batch):
# prepare gt
gt_info_list = [{
'points': gt_polyon,
'text': '',
'ignore': ignore_tag
} for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)]
# prepare det
det_info_list = [{
'points': det_polyon,
'text': '',
'score': score
} for det_polyon, score in zip(pred['points'], pred['scores'])]
for score_thr in self.results.keys():
det_info_list_thr = [
det_info for det_info in det_info_list
if det_info['score'] >= score_thr
]
result = self.evaluator.evaluate_image(gt_info_list,
det_info_list_thr)
self.results[score_thr].append(result)
def get_metric(self):
"""
return metrics {'heman':0,
'thr 0.3':'precision: 0 recall: 0 hmean: 0',
'thr 0.4':'precision: 0 recall: 0 hmean: 0',
'thr 0.5':'precision: 0 recall: 0 hmean: 0',
'thr 0.6':'precision: 0 recall: 0 hmean: 0',
'thr 0.7':'precision: 0 recall: 0 hmean: 0',
'thr 0.8':'precision: 0 recall: 0 hmean: 0',
'thr 0.9':'precision: 0 recall: 0 hmean: 0',
}
"""
metrics = {}
hmean = 0
for score_thr in self.results.keys():
metric = self.evaluator.combine_results(self.results[score_thr])
# for key, value in metric.items():
# metrics['{}_{}'.format(key, score_thr)] = value
metric_str = 'precision:{:.5f} recall:{:.5f} hmean:{:.5f}'.format(
metric['precision'], metric['recall'], metric['hmean'])
metrics['thr {}'.format(score_thr)] = metric_str
hmean = max(hmean, metric['hmean'])
metrics['hmean'] = hmean
self.reset()
return metrics
def reset(self):
self.results = {
0.3: [],
0.4: [],
0.5: [],
0.6: [],
0.7: [],
0.8: [],
0.9: []
} # clear results

View File

@@ -1,73 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import copy
from .rec_metric import RecMetric
from .det_metric import DetMetric
from .e2e_metric import E2EMetric
from .cls_metric import ClsMetric
class DistillationMetric(object):
def __init__(self,
key=None,
base_metric_name=None,
main_indicator=None,
**kwargs):
self.main_indicator = main_indicator
self.key = key
self.main_indicator = main_indicator
self.base_metric_name = base_metric_name
self.kwargs = kwargs
self.metrics = None
def _init_metrcis(self, preds):
self.metrics = dict()
mod = importlib.import_module(__name__)
for key in preds:
self.metrics[key] = getattr(mod, self.base_metric_name)(
main_indicator=self.main_indicator, **self.kwargs)
self.metrics[key].reset()
def __call__(self, preds, batch, **kwargs):
assert isinstance(preds, dict)
if self.metrics is None:
self._init_metrcis(preds)
output = dict()
for key in preds:
self.metrics[key].__call__(preds[key], batch, **kwargs)
def get_metric(self):
"""
return metrics {
'acc': 0,
'norm_edit_dis': 0,
}
"""
output = dict()
for key in self.metrics:
metric = self.metrics[key].get_metric()
# main indicator
if key == self.key:
output.update(metric)
else:
for sub_key in metric:
output["{}_{}".format(key, sub_key)] = metric[sub_key]
return output
def reset(self):
for key in self.metrics:
self.metrics[key].reset()

View File

@@ -1,86 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__all__ = ['E2EMetric']
from ppocr.utils.e2e_metric.Deteval import get_socre_A, get_socre_B, combine_results
from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
class E2EMetric(object):
def __init__(self,
mode,
gt_mat_dir,
character_dict_path,
main_indicator='f_score_e2e',
**kwargs):
self.mode = mode
self.gt_mat_dir = gt_mat_dir
self.label_list = get_dict(character_dict_path)
self.max_index = len(self.label_list)
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
if self.mode == 'A':
gt_polyons_batch = batch[2]
temp_gt_strs_batch = batch[3][0]
ignore_tags_batch = batch[4]
gt_strs_batch = []
for temp_list in temp_gt_strs_batch:
t = ""
for index in temp_list:
if index < self.max_index:
t += self.label_list[index]
gt_strs_batch.append(t)
for pred, gt_polyons, gt_strs, ignore_tags in zip(
[preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
# prepare gt
gt_info_list = [{
'points': gt_polyon,
'text': gt_str,
'ignore': ignore_tag
} for gt_polyon, gt_str, ignore_tag in
zip(gt_polyons, gt_strs, ignore_tags)]
# prepare det
e2e_info_list = [{
'points': det_polyon,
'texts': pred_str
} for det_polyon, pred_str in
zip(pred['points'], pred['texts'])]
result = get_socre_A(gt_info_list, e2e_info_list)
self.results.append(result)
else:
img_id = batch[5][0]
e2e_info_list = [{
'points': det_polyon,
'texts': pred_str
} for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
result = get_socre_B(self.gt_mat_dir, img_id, e2e_info_list)
self.results.append(result)
def get_metric(self):
metrics = combine_results(self.results)
self.reset()
return metrics
def reset(self):
self.results = [] # clear results

View File

@@ -1,225 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from collections import namedtuple
import numpy as np
from shapely.geometry import Polygon
"""
reference from :
https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
"""
class DetectionIoUEvaluator(object):
def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
self.iou_constraint = iou_constraint
self.area_precision_constraint = area_precision_constraint
def evaluate_image(self, gt, pred):
def get_union(pD, pG):
return Polygon(pD).union(Polygon(pG)).area
def get_intersection_over_union(pD, pG):
return get_intersection(pD, pG) / get_union(pD, pG)
def get_intersection(pD, pG):
return Polygon(pD).intersection(Polygon(pG)).area
def compute_ap(confList, matchList, numGtCare):
correct = 0
AP = 0
if len(confList) > 0:
confList = np.array(confList)
matchList = np.array(matchList)
sorted_ind = np.argsort(-confList)
confList = confList[sorted_ind]
matchList = matchList[sorted_ind]
for n in range(len(confList)):
match = matchList[n]
if match:
correct += 1
AP += float(correct) / (n + 1)
if numGtCare > 0:
AP /= numGtCare
return AP
perSampleMetrics = {}
matchedSum = 0
Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
numGlobalCareGt = 0
numGlobalCareDet = 0
arrGlobalConfidences = []
arrGlobalMatches = []
recall = 0
precision = 0
hmean = 0
detMatched = 0
iouMat = np.empty([1, 1])
gtPols = []
detPols = []
gtPolPoints = []
detPolPoints = []
# Array of Ground Truth Polygons' keys marked as don't Care
gtDontCarePolsNum = []
# Array of Detected Polygons' matched with a don't Care GT
detDontCarePolsNum = []
pairs = []
detMatchedNums = []
arrSampleConfidences = []
arrSampleMatch = []
evaluationLog = ""
# print(len(gt))
for n in range(len(gt)):
points = gt[n]['points']
# transcription = gt[n]['text']
dontCare = gt[n]['ignore']
# points = Polygon(points)
# points = points.buffer(0)
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
gtPol = points
gtPols.append(gtPol)
gtPolPoints.append(points)
if dontCare:
gtDontCarePolsNum.append(len(gtPols) - 1)
evaluationLog += "GT polygons: " + str(len(gtPols)) + (
" (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
if len(gtDontCarePolsNum) > 0 else "\n")
for n in range(len(pred)):
points = pred[n]['points']
# points = Polygon(points)
# points = points.buffer(0)
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
detPol = points
detPols.append(detPol)
detPolPoints.append(points)
if len(gtDontCarePolsNum) > 0:
for dontCarePol in gtDontCarePolsNum:
dontCarePol = gtPols[dontCarePol]
intersected_area = get_intersection(dontCarePol, detPol)
pdDimensions = Polygon(detPol).area
precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
if (precision > self.area_precision_constraint):
detDontCarePolsNum.append(len(detPols) - 1)
break
evaluationLog += "DET polygons: " + str(len(detPols)) + (
" (" + str(len(detDontCarePolsNum)) + " don't care)\n"
if len(detDontCarePolsNum) > 0 else "\n")
if len(gtPols) > 0 and len(detPols) > 0:
# Calculate IoU and precision matrixs
outputShape = [len(gtPols), len(detPols)]
iouMat = np.empty(outputShape)
gtRectMat = np.zeros(len(gtPols), np.int8)
detRectMat = np.zeros(len(detPols), np.int8)
for gtNum in range(len(gtPols)):
for detNum in range(len(detPols)):
pG = gtPols[gtNum]
pD = detPols[detNum]
iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
for gtNum in range(len(gtPols)):
for detNum in range(len(detPols)):
if gtRectMat[gtNum] == 0 and detRectMat[
detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
if iouMat[gtNum, detNum] > self.iou_constraint:
gtRectMat[gtNum] = 1
detRectMat[detNum] = 1
detMatched += 1
pairs.append({'gt': gtNum, 'det': detNum})
detMatchedNums.append(detNum)
evaluationLog += "Match GT #" + \
str(gtNum) + " with Det #" + str(detNum) + "\n"
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
numDetCare = (len(detPols) - len(detDontCarePolsNum))
if numGtCare == 0:
recall = float(1)
precision = float(0) if numDetCare > 0 else float(1)
else:
recall = float(detMatched) / numGtCare
precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
hmean = 0 if (precision + recall) == 0 else 2.0 * \
precision * recall / (precision + recall)
matchedSum += detMatched
numGlobalCareGt += numGtCare
numGlobalCareDet += numDetCare
perSampleMetrics = {
'gtCare': numGtCare,
'detCare': numDetCare,
'detMatched': detMatched,
}
return perSampleMetrics
def combine_results(self, results):
numGlobalCareGt = 0
numGlobalCareDet = 0
matchedSum = 0
for result in results:
numGlobalCareGt += result['gtCare']
numGlobalCareDet += result['detCare']
matchedSum += result['detMatched']
methodRecall = 0 if numGlobalCareGt == 0 else float(
matchedSum) / numGlobalCareGt
methodPrecision = 0 if numGlobalCareDet == 0 else float(
matchedSum) / numGlobalCareDet
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
methodRecall * methodPrecision / (
methodRecall + methodPrecision)
# print(methodRecall, methodPrecision, methodHmean)
# sys.exit(-1)
methodMetrics = {
'precision': methodPrecision,
'recall': methodRecall,
'hmean': methodHmean
}
return methodMetrics
if __name__ == '__main__':
evaluator = DetectionIoUEvaluator()
gts = [[{
'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
'text': 1234,
'ignore': False,
}, {
'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
'text': 5678,
'ignore': False,
}]]
preds = [[{
'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
'text': 123,
'ignore': False,
}]]
results = []
for gt, pred in zip(gts, preds):
results.append(evaluator.evaluate_image(gt, pred))
metrics = evaluator.combine_results(results)
print(metrics)

View File

@@ -1,71 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The code is refer from: https://github.com/open-mmlab/mmocr/blob/main/mmocr/core/evaluation/kie_metric.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
__all__ = ['KIEMetric']
class KIEMetric(object):
def __init__(self, main_indicator='hmean', **kwargs):
self.main_indicator = main_indicator
self.reset()
self.node = []
self.gt = []
def __call__(self, preds, batch, **kwargs):
nodes, _ = preds
gts, tag = batch[4].squeeze(0), batch[5].tolist()[0]
gts = gts[:tag[0], :1].reshape([-1])
self.node.append(nodes.numpy())
self.gt.append(gts)
# result = self.compute_f1_score(nodes, gts)
# self.results.append(result)
def compute_f1_score(self, preds, gts):
ignores = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25]
C = preds.shape[1]
classes = np.array(sorted(set(range(C)) - set(ignores)))
hist = np.bincount(
(gts * C).astype('int64') + preds.argmax(1), minlength=C
**2).reshape([C, C]).astype('float32')
diag = np.diag(hist)
recalls = diag / hist.sum(1).clip(min=1)
precisions = diag / hist.sum(0).clip(min=1)
f1 = 2 * recalls * precisions / (recalls + precisions).clip(min=1e-8)
return f1[classes]
def combine_results(self, results):
node = np.concatenate(self.node, 0)
gts = np.concatenate(self.gt, 0)
results = self.compute_f1_score(node, gts)
data = {'hmean': results.mean()}
return data
def get_metric(self):
metrics = self.combine_results(self.results)
self.reset()
return metrics
def reset(self):
self.results = [] # clear results
self.node = []
self.gt = []

View File

@@ -1,76 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import Levenshtein
import string
class RecMetric(object):
def __init__(self,
main_indicator='acc',
is_filter=False,
ignore_space=True,
**kwargs):
self.main_indicator = main_indicator
self.is_filter = is_filter
self.ignore_space = ignore_space
self.eps = 1e-5
self.reset()
def _normalize_text(self, text):
text = ''.join(
filter(lambda x: x in (string.digits + string.ascii_letters), text))
return text.lower()
def __call__(self, pred_label, *args, **kwargs):
preds, labels = pred_label
correct_num = 0
all_num = 0
norm_edit_dis = 0.0
for (pred, pred_conf), (target, _) in zip(preds, labels):
if self.ignore_space:
pred = pred.replace(" ", "")
target = target.replace(" ", "")
if self.is_filter:
pred = self._normalize_text(pred)
target = self._normalize_text(target)
norm_edit_dis += Levenshtein.distance(pred, target) / max(
len(pred), len(target), 1)
if pred == target:
correct_num += 1
all_num += 1
self.correct_num += correct_num
self.all_num += all_num
self.norm_edit_dis += norm_edit_dis
return {
'acc': correct_num / (all_num + self.eps),
'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps)
}
def get_metric(self):
"""
return metrics {
'acc': 0,
'norm_edit_dis': 0,
}
"""
acc = 1.0 * self.correct_num / (self.all_num + self.eps)
norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
def reset(self):
self.correct_num = 0
self.all_num = 0
self.norm_edit_dis = 0

View File

@@ -1,51 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
class TableMetric(object):
def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator
self.eps = 1e-5
self.reset()
def __call__(self, pred, batch, *args, **kwargs):
structure_probs = pred['structure_probs'].numpy()
structure_labels = batch[1]
correct_num = 0
all_num = 0
structure_probs = np.argmax(structure_probs, axis=2)
structure_labels = structure_labels[:, 1:]
batch_size = structure_probs.shape[0]
for bno in range(batch_size):
all_num += 1
if (structure_probs[bno] == structure_labels[bno]).all():
correct_num += 1
self.correct_num += correct_num
self.all_num += all_num
return {'acc': correct_num * 1.0 / (all_num + self.eps), }
def get_metric(self):
"""
return metrics {
'acc': 0,
}
"""
acc = 1.0 * self.correct_num / (self.all_num + self.eps)
self.reset()
return {'acc': acc}
def reset(self):
self.correct_num = 0
self.all_num = 0

View File

@@ -1,176 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
__all__ = ['KIEMetric']
class VQAReTokenMetric(object):
def __init__(self, main_indicator='hmean', **kwargs):
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
pred_relations, relations, entities = preds
self.pred_relations_list.extend(pred_relations)
self.relations_list.extend(relations)
self.entities_list.extend(entities)
def get_metric(self):
gt_relations = []
for b in range(len(self.relations_list)):
rel_sent = []
for head, tail in zip(self.relations_list[b]["head"],
self.relations_list[b]["tail"]):
rel = {}
rel["head_id"] = head
rel["head"] = (self.entities_list[b]["start"][rel["head_id"]],
self.entities_list[b]["end"][rel["head_id"]])
rel["head_type"] = self.entities_list[b]["label"][rel[
"head_id"]]
rel["tail_id"] = tail
rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]],
self.entities_list[b]["end"][rel["tail_id"]])
rel["tail_type"] = self.entities_list[b]["label"][rel[
"tail_id"]]
rel["type"] = 1
rel_sent.append(rel)
gt_relations.append(rel_sent)
re_metrics = self.re_score(
self.pred_relations_list, gt_relations, mode="boundaries")
metrics = {
"precision": re_metrics["ALL"]["p"],
"recall": re_metrics["ALL"]["r"],
"hmean": re_metrics["ALL"]["f1"],
}
self.reset()
return metrics
def reset(self):
self.pred_relations_list = []
self.relations_list = []
self.entities_list = []
def re_score(self, pred_relations, gt_relations, mode="strict"):
"""Evaluate RE predictions
Args:
pred_relations (list) : list of list of predicted relations (several relations in each sentence)
gt_relations (list) : list of list of ground truth relations
rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
"tail": (start_idx (inclusive), end_idx (exclusive)),
"head_type": ent_type,
"tail_type": ent_type,
"type": rel_type}
vocab (Vocab) : dataset vocabulary
mode (str) : in 'strict' or 'boundaries'"""
assert mode in ["strict", "boundaries"]
relation_types = [v for v in [0, 1] if not v == 0]
scores = {
rel: {
"tp": 0,
"fp": 0,
"fn": 0
}
for rel in relation_types + ["ALL"]
}
# Count GT relations and Predicted relations
n_sents = len(gt_relations)
n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
# Count TP, FP and FN per type
for pred_sent, gt_sent in zip(pred_relations, gt_relations):
for rel_type in relation_types:
# strict mode takes argument types into account
if mode == "strict":
pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in pred_sent
if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in gt_sent if rel["type"] == rel_type}
# boundaries mode only takes argument spans into account
elif mode == "boundaries":
pred_rels = {(rel["head"], rel["tail"])
for rel in pred_sent
if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["tail"])
for rel in gt_sent if rel["type"] == rel_type}
scores[rel_type]["tp"] += len(pred_rels & gt_rels)
scores[rel_type]["fp"] += len(pred_rels - gt_rels)
scores[rel_type]["fn"] += len(gt_rels - pred_rels)
# Compute per entity Precision / Recall / F1
for rel_type in scores.keys():
if scores[rel_type]["tp"]:
scores[rel_type]["p"] = scores[rel_type]["tp"] / (
scores[rel_type]["fp"] + scores[rel_type]["tp"])
scores[rel_type]["r"] = scores[rel_type]["tp"] / (
scores[rel_type]["fn"] + scores[rel_type]["tp"])
else:
scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
scores[rel_type]["f1"] = (
2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
(scores[rel_type]["p"] + scores[rel_type]["r"]))
else:
scores[rel_type]["f1"] = 0
# Compute micro F1 Scores
tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
if tp:
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
else:
precision, recall, f1 = 0, 0, 0
scores["ALL"]["p"] = precision
scores["ALL"]["r"] = recall
scores["ALL"]["f1"] = f1
scores["ALL"]["tp"] = tp
scores["ALL"]["fp"] = fp
scores["ALL"]["fn"] = fn
# Compute Macro F1 Scores
scores["ALL"]["Macro_f1"] = np.mean(
[scores[ent_type]["f1"] for ent_type in relation_types])
scores["ALL"]["Macro_p"] = np.mean(
[scores[ent_type]["p"] for ent_type in relation_types])
scores["ALL"]["Macro_r"] = np.mean(
[scores[ent_type]["r"] for ent_type in relation_types])
return scores

View File

@@ -1,47 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
__all__ = ['KIEMetric']
class VQASerTokenMetric(object):
def __init__(self, main_indicator='hmean', **kwargs):
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
preds, labels = preds
self.pred_list.extend(preds)
self.gt_list.extend(labels)
def get_metric(self):
from seqeval.metrics import f1_score, precision_score, recall_score
metrics = {
"precision": precision_score(self.gt_list, self.pred_list),
"recall": recall_score(self.gt_list, self.pred_list),
"hmean": f1_score(self.gt_list, self.pred_list),
}
self.reset()
return metrics
def reset(self):
self.pred_list = []
self.gt_list = []

View File

@@ -1,32 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import importlib
from .base_model import BaseModel
from .distillation_model import DistillationModel
__all__ = ['build_model']
def build_model(config):
config = copy.deepcopy(config)
if not "name" in config:
arch = BaseModel(config)
else:
name = config.pop("name")
mod = importlib.import_module(__name__)
arch = getattr(mod, name)(config)
return arch

View File

@@ -1,100 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
__all__ = ['BaseModel']
class BaseModel(nn.Layer):
def __init__(self, config):
"""
the module for OCR.
args:
config (dict): the super parameters for module.
"""
super(BaseModel, self).__init__()
in_channels = config.get('in_channels', 3)
model_type = config['model_type']
# build transfrom,
# for rec, transfrom can be TPS,None
# for det and cls, transfrom shoule to be None,
# if you make model differently, you can use transfrom in det and cls
if 'Transform' not in config or config['Transform'] is None:
self.use_transform = False
else:
self.use_transform = True
config['Transform']['in_channels'] = in_channels
self.transform = build_transform(config['Transform'])
in_channels = self.transform.out_channels
# build backbone, backbone is need for del, rec and cls
config["Backbone"]['in_channels'] = in_channels
self.backbone = build_backbone(config["Backbone"], model_type)
in_channels = self.backbone.out_channels
# build neck
# for rec, neck can be cnn,rnn or reshape(None)
# for det, neck can be FPN, BIFPN and so on.
# for cls, neck should be none
if 'Neck' not in config or config['Neck'] is None:
self.use_neck = False
else:
self.use_neck = True
config['Neck']['in_channels'] = in_channels
self.neck = build_neck(config['Neck'])
in_channels = self.neck.out_channels
# # build head, head is need for det, rec and cls
if 'Head' not in config or config['Head'] is None:
self.use_head = False
else:
self.use_head = True
config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"])
self.return_all_feats = config.get("return_all_feats", False)
def forward(self, x, data=None):
y = dict()
if self.use_transform:
x = self.transform(x)
x = self.backbone(x)
y["backbone_out"] = x
if self.use_neck:
x = self.neck(x)
y["neck_out"] = x
if self.use_head:
x = self.head(x, targets=data)
# for multi head, save ctc neck out for udml
if isinstance(x, dict) and 'ctc_neck' in x.keys():
y["neck_out"] = x["ctc_neck"]
y["head_out"] = x
elif isinstance(x, dict):
y.update(x)
else:
y["head_out"] = x
if self.return_all_feats:
if self.training:
return y
else:
return {"head_out": y["head_out"]}
else:
return x

View File

@@ -1,60 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
from .base_model import BaseModel
from ppocr.utils.save_load import load_pretrained_params
__all__ = ['DistillationModel']
class DistillationModel(nn.Layer):
def __init__(self, config):
"""
the module for OCR distillation.
args:
config (dict): the super parameters for module.
"""
super().__init__()
self.model_list = []
self.model_name_list = []
for key in config["Models"]:
model_config = config["Models"][key]
freeze_params = False
pretrained = None
if "freeze_params" in model_config:
freeze_params = model_config.pop("freeze_params")
if "pretrained" in model_config:
pretrained = model_config.pop("pretrained")
model = BaseModel(model_config)
if pretrained is not None:
load_pretrained_params(model, pretrained)
if freeze_params:
for param in model.parameters():
param.trainable = False
self.model_list.append(self.add_sublayer(key, model))
self.model_name_list.append(key)
def forward(self, x, data=None):
result_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
result_dict[model_name] = self.model_list[idx](x, data)
return result_dict

View File

@@ -1,64 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["build_backbone"]
def build_backbone(config, model_type):
if model_type == "det" or model_type == "table":
from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet
from .det_resnet_vd_sast import ResNet_SAST
support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"]
elif model_type == "rec" or model_type == "cls":
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31
from .rec_resnet_aster import ResNet_ASTER
from .rec_micronet import MicroNet
from .rec_efficientb3_pren import EfficientNetb3_PREN
from .rec_svtrnet import SVTRNet
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
"ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
'SVTRNet'
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
support_dict = ['ResNet']
elif model_type == 'kie':
from .kie_unet_sdmgr import Kie_backbone
support_dict = ['Kie_backbone']
elif model_type == "table":
from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3
support_dict = ["ResNet", "MobileNetV3"]
elif model_type == 'vqa':
from .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForRe
support_dict = [
"LayoutLMForSer", "LayoutLMv2ForSer", 'LayoutLMv2ForRe',
"LayoutXLMForSer", 'LayoutXLMForRe'
]
else:
raise NotImplementedError
module_name = config.pop("name")
assert module_name in support_dict, Exception(
"when model typs is {}, backbone only support {}".format(model_type,
support_dict))
module_class = eval(module_name)(**config)
return module_class

View File

@@ -1,268 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
__all__ = ['MobileNetV3']
def make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
class MobileNetV3(nn.Layer):
def __init__(self,
in_channels=3,
model_name='large',
scale=0.5,
disable_se=False,
**kwargs):
"""
the MobilenetV3 backbone network for detection module.
Args:
params(dict): the super parameters for build network
"""
super(MobileNetV3, self).__init__()
self.disable_se = disable_se
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', 2],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hardswish', 2],
[3, 200, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hardswish', 1],
[3, 480, 112, True, 'hardswish', 1],
[3, 672, 112, True, 'hardswish', 1],
[5, 672, 160, True, 'hardswish', 2],
[5, 960, 160, True, 'hardswish', 1],
[5, 960, 160, True, 'hardswish', 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hardswish', 2],
[5, 240, 40, True, 'hardswish', 1],
[5, 240, 40, True, 'hardswish', 1],
[5, 120, 48, True, 'hardswish', 1],
[5, 144, 48, True, 'hardswish', 1],
[5, 288, 96, True, 'hardswish', 2],
[5, 576, 96, True, 'hardswish', 1],
[5, 576, 96, True, 'hardswish', 1],
]
cls_ch_squeeze = 576
else:
raise NotImplementedError("mode[" + model_name +
"_model] is not implemented!")
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
assert scale in supported_scale, \
"supported scale are {} but input scale is {}".format(supported_scale, scale)
inplanes = 16
# conv1
self.conv = ConvBNLayer(
in_channels=in_channels,
out_channels=make_divisible(inplanes * scale),
kernel_size=3,
stride=2,
padding=1,
groups=1,
if_act=True,
act='hardswish')
self.stages = []
self.out_channels = []
block_list = []
i = 0
inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in cfg:
se = se and not self.disable_se
start_idx = 2 if model_name == 'large' else 0
if s == 2 and i > start_idx:
self.out_channels.append(inplanes)
self.stages.append(nn.Sequential(*block_list))
block_list = []
block_list.append(
ResidualUnit(
in_channels=inplanes,
mid_channels=make_divisible(scale * exp),
out_channels=make_divisible(scale * c),
kernel_size=k,
stride=s,
use_se=se,
act=nl))
inplanes = make_divisible(scale * c)
i += 1
block_list.append(
ConvBNLayer(
in_channels=inplanes,
out_channels=make_divisible(scale * cls_ch_squeeze),
kernel_size=1,
stride=1,
padding=0,
groups=1,
if_act=True,
act='hardswish'))
self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
for i, stage in enumerate(self.stages):
self.add_sublayer(sublayer=stage, name="stage{}".format(i))
def forward(self, x):
x = self.conv(x)
out_list = []
for stage in self.stages:
x = stage(x)
out_list.append(x)
return out_list
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias_attr=False)
self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.if_act:
if self.act == "relu":
x = F.relu(x)
elif self.act == "hardswish":
x = F.hardswish(x)
else:
print("The activation function({}) is selected incorrectly.".
format(self.act))
exit()
return x
class ResidualUnit(nn.Layer):
def __init__(self,
in_channels,
mid_channels,
out_channels,
kernel_size,
stride,
use_se,
act=None):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_channels == out_channels
self.if_se = use_se
self.expand_conv = ConvBNLayer(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=1,
stride=1,
padding=0,
if_act=True,
act=act)
self.bottleneck_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=kernel_size,
stride=stride,
padding=int((kernel_size - 1) // 2),
groups=mid_channels,
if_act=True,
act=act)
if self.if_se:
self.mid_se = SEModule(mid_channels)
self.linear_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
if_act=False,
act=None)
def forward(self, inputs):
x = self.expand_conv(inputs)
x = self.bottleneck_conv(x)
if self.if_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = paddle.add(inputs, x)
return x
class SEModule(nn.Layer):
def __init__(self, in_channels, reduction=4):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.conv1 = nn.Conv2D(
in_channels=in_channels,
out_channels=in_channels // reduction,
kernel_size=1,
stride=1,
padding=0)
self.conv2 = nn.Conv2D(
in_channels=in_channels // reduction,
out_channels=in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
return inputs * outputs

View File

@@ -1,351 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.ops import DeformConv2D
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant, XavierUniform
__all__ = ["ResNet"]
class DeformableConvV2(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
weight_attr=None,
bias_attr=None,
lr_scale=1,
regularizer=None,
skip_quant=False,
dcn_bias_regularizer=L2Decay(0.),
dcn_bias_lr_scale=2.):
super(DeformableConvV2, self).__init__()
self.offset_channel = 2 * kernel_size**2 * groups
self.mask_channel = kernel_size**2 * groups
if bias_attr:
# in FCOS-DCN head, specifically need learning_rate and regularizer
dcn_bias_attr = ParamAttr(
initializer=Constant(value=0),
regularizer=dcn_bias_regularizer,
learning_rate=dcn_bias_lr_scale)
else:
# in ResNet backbone, do not need bias
dcn_bias_attr = False
self.conv_dcn = DeformConv2D(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2 * dilation,
dilation=dilation,
deformable_groups=groups,
weight_attr=weight_attr,
bias_attr=dcn_bias_attr)
if lr_scale == 1 and regularizer is None:
offset_bias_attr = ParamAttr(initializer=Constant(0.))
else:
offset_bias_attr = ParamAttr(
initializer=Constant(0.),
learning_rate=lr_scale,
regularizer=regularizer)
self.conv_offset = nn.Conv2D(
in_channels,
groups * 3 * kernel_size**2,
kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
weight_attr=ParamAttr(initializer=Constant(0.0)),
bias_attr=offset_bias_attr)
if skip_quant:
self.conv_offset.skip_quant = True
def forward(self, x):
offset_mask = self.conv_offset(x)
offset, mask = paddle.split(
offset_mask,
num_or_sections=[self.offset_channel, self.mask_channel],
axis=1)
mask = F.sigmoid(mask)
y = self.conv_dcn(x, offset, mask=mask)
return y
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
is_dcn=False):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
if not is_dcn:
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
bias_attr=False)
else:
self._conv = DeformableConvV2(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=2, #groups,
bias_attr=False)
self._batch_norm = nn.BatchNorm(out_channels, act=act)
def forward(self, inputs):
if self.is_vd_mode:
inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
is_dcn=False, ):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu')
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
is_dcn=is_dcn)
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None)
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True)
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False, ):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu')
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None)
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True)
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet(nn.Layer):
def __init__(self,
in_channels=3,
layers=50,
dcn_stage=None,
out_indices=None,
**kwargs):
super(ResNet, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.dcn_stage = dcn_stage if dcn_stage is not None else [
False, False, False, False
]
self.out_indices = out_indices if out_indices is not None else [
0, 1, 2, 3
]
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=32,
kernel_size=3,
stride=2,
act='relu')
self.conv1_2 = ConvBNLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
act='relu')
self.conv1_3 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
act='relu')
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
self.out_channels = []
if layers >= 50:
for block in range(len(depth)):
block_list = []
shortcut = False
is_dcn = self.dcn_stage[block]
for i in range(depth[block]):
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
is_dcn=is_dcn))
shortcut = True
block_list.append(bottleneck_block)
if block in self.out_indices:
self.out_channels.append(num_filters[block] * 4)
self.stages.append(nn.Sequential(*block_list))
else:
for block in range(len(depth)):
block_list = []
shortcut = False
# is_dcn = self.dcn_stage[block]
for i in range(depth[block]):
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0))
shortcut = True
block_list.append(basic_block)
if block in self.out_indices:
self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
def forward(self, inputs):
y = self.conv1_1(inputs)
y = self.conv1_2(y)
y = self.conv1_3(y)
y = self.pool2d_max(y)
out = []
for i, block in enumerate(self.stages):
y = block(y)
if i in self.out_indices:
out.append(y)
return out

View File

@@ -1,285 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ["ResNet_SAST"]
class ConvBNLayer(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def forward(self, inputs):
if self.is_vd_mode:
inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet_SAST(nn.Layer):
def __init__(self, in_channels=3, layers=50, **kwargs):
super(ResNet_SAST, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
# depth = [3, 4, 6, 3]
depth = [3, 4, 6, 3, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
# num_channels = [64, 256, 512,
# 1024] if layers >= 50 else [64, 64, 128, 256]
# num_filters = [64, 128, 256, 512]
num_channels = [64, 256, 512,
1024, 2048] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512, 512]
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=32,
kernel_size=3,
stride=2,
act='relu',
name="conv1_1")
self.conv1_2 = ConvBNLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
act='relu',
name="conv1_2")
self.conv1_3 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
act='relu',
name="conv1_3")
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
self.out_channels = [3, 64]
if layers >= 50:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(bottleneck_block)
self.out_channels.append(num_filters[block] * 4)
self.stages.append(nn.Sequential(*block_list))
else:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(basic_block)
self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
def forward(self, inputs):
out = [inputs]
y = self.conv1_1(inputs)
y = self.conv1_2(y)
y = self.conv1_3(y)
out.append(y)
y = self.pool2d_max(y)
for block in self.stages:
y = block(y)
out.append(y)
return out

View File

@@ -1,265 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ["ResNet"]
class ConvBNLayer(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=stride,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet(nn.Layer):
def __init__(self, in_channels=3, layers=50, **kwargs):
super(ResNet, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
# depth = [3, 4, 6, 3]
depth = [3, 4, 6, 3, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_channels = [64, 256, 512, 1024,
2048] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512, 512]
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=7,
stride=2,
act='relu',
name="conv1_1")
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
self.out_channels = [3, 64]
# num_filters = [64, 128, 256, 512, 512]
if layers >= 50:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(bottleneck_block)
self.out_channels.append(num_filters[block] * 4)
self.stages.append(nn.Sequential(*block_list))
else:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(basic_block)
self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
def forward(self, inputs):
out = [inputs]
y = self.conv1_1(inputs)
out.append(y)
y = self.pool2d_max(y)
for block in self.stages:
y = block(y)
out.append(y)
return out

View File

@@ -1,186 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import numpy as np
import cv2
__all__ = ["Kie_backbone"]
class Encoder(nn.Layer):
def __init__(self, num_channels, num_filters):
super(Encoder, self).__init__()
self.conv1 = nn.Conv2D(
num_channels,
num_filters,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm(num_filters, act='relu')
self.conv2 = nn.Conv2D(
num_filters,
num_filters,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False)
self.bn2 = nn.BatchNorm(num_filters, act='relu')
self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
def forward(self, inputs):
x = self.conv1(inputs)
x = self.bn1(x)
x = self.conv2(x)
x = self.bn2(x)
x_pooled = self.pool(x)
return x, x_pooled
class Decoder(nn.Layer):
def __init__(self, num_channels, num_filters):
super(Decoder, self).__init__()
self.conv1 = nn.Conv2D(
num_channels,
num_filters,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm(num_filters, act='relu')
self.conv2 = nn.Conv2D(
num_filters,
num_filters,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False)
self.bn2 = nn.BatchNorm(num_filters, act='relu')
self.conv0 = nn.Conv2D(
num_channels,
num_filters,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
self.bn0 = nn.BatchNorm(num_filters, act='relu')
def forward(self, inputs_prev, inputs):
x = self.conv0(inputs)
x = self.bn0(x)
x = paddle.nn.functional.interpolate(
x, scale_factor=2, mode='bilinear', align_corners=False)
x = paddle.concat([inputs_prev, x], axis=1)
x = self.conv1(x)
x = self.bn1(x)
x = self.conv2(x)
x = self.bn2(x)
return x
class UNet(nn.Layer):
def __init__(self):
super(UNet, self).__init__()
self.down1 = Encoder(num_channels=3, num_filters=16)
self.down2 = Encoder(num_channels=16, num_filters=32)
self.down3 = Encoder(num_channels=32, num_filters=64)
self.down4 = Encoder(num_channels=64, num_filters=128)
self.down5 = Encoder(num_channels=128, num_filters=256)
self.up1 = Decoder(32, 16)
self.up2 = Decoder(64, 32)
self.up3 = Decoder(128, 64)
self.up4 = Decoder(256, 128)
self.out_channels = 16
def forward(self, inputs):
x1, _ = self.down1(inputs)
_, x2 = self.down2(x1)
_, x3 = self.down3(x2)
_, x4 = self.down4(x3)
_, x5 = self.down5(x4)
x = self.up4(x4, x5)
x = self.up3(x3, x)
x = self.up2(x2, x)
x = self.up1(x1, x)
return x
class Kie_backbone(nn.Layer):
def __init__(self, in_channels, **kwargs):
super(Kie_backbone, self).__init__()
self.out_channels = 16
self.img_feat = UNet()
self.maxpool = nn.MaxPool2D(kernel_size=7)
def bbox2roi(self, bbox_list):
rois_list = []
rois_num = []
for img_id, bboxes in enumerate(bbox_list):
rois_num.append(bboxes.shape[0])
rois_list.append(bboxes)
rois = paddle.concat(rois_list, 0)
rois_num = paddle.to_tensor(rois_num, dtype='int32')
return rois, rois_num
def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):
img, relations, texts, gt_bboxes, tag, img_size = img.numpy(
), relations.numpy(), texts.numpy(), gt_bboxes.numpy(), tag.numpy(
).tolist(), img_size.numpy()
temp_relations, temp_texts, temp_gt_bboxes = [], [], []
h, w = int(np.max(img_size[:, 0])), int(np.max(img_size[:, 1]))
img = paddle.to_tensor(img[:, :, :h, :w])
batch = len(tag)
for i in range(batch):
num, recoder_len = tag[i][0], tag[i][1]
temp_relations.append(
paddle.to_tensor(
relations[i, :num, :num, :], dtype='float32'))
temp_texts.append(
paddle.to_tensor(
texts[i, :num, :recoder_len], dtype='float32'))
temp_gt_bboxes.append(
paddle.to_tensor(
gt_bboxes[i, :num, ...], dtype='float32'))
return img, temp_relations, temp_texts, temp_gt_bboxes
def forward(self, inputs):
img = inputs[0]
relations, texts, gt_bboxes, tag, img_size = inputs[1], inputs[
2], inputs[3], inputs[5], inputs[-1]
img, relations, texts, gt_bboxes = self.pre_process(
img, relations, texts, gt_bboxes, tag, img_size)
x = self.img_feat(img)
boxes, rois_num = self.bbox2roi(gt_bboxes)
feats = paddle.fluid.layers.roi_align(
x,
boxes,
spatial_scale=1.0,
pooled_height=7,
pooled_width=7,
rois_num=rois_num)
feats = self.maxpool(feats).squeeze(-1).squeeze(-1)
return [relations, texts, feats]

View File

@@ -1,228 +0,0 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code is refer from:
https://github.com/RuijieJ/pren/blob/main/Nets/EfficientNet.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from collections import namedtuple
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ['EfficientNetb3']
class EffB3Params:
@staticmethod
def get_global_params():
"""
The fllowing are efficientnetb3's arch superparams, but to fit for scene
text recognition task, the resolution(image_size) here is changed
from 300 to 64.
"""
GlobalParams = namedtuple('GlobalParams', [
'drop_connect_rate', 'width_coefficient', 'depth_coefficient',
'depth_divisor', 'image_size'
])
global_params = GlobalParams(
drop_connect_rate=0.3,
width_coefficient=1.2,
depth_coefficient=1.4,
depth_divisor=8,
image_size=64)
return global_params
@staticmethod
def get_block_params():
BlockParams = namedtuple('BlockParams', [
'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
'expand_ratio', 'id_skip', 'se_ratio', 'stride'
])
block_params = [
BlockParams(3, 1, 32, 16, 1, True, 0.25, 1),
BlockParams(3, 2, 16, 24, 6, True, 0.25, 2),
BlockParams(5, 2, 24, 40, 6, True, 0.25, 2),
BlockParams(3, 3, 40, 80, 6, True, 0.25, 2),
BlockParams(5, 3, 80, 112, 6, True, 0.25, 1),
BlockParams(5, 4, 112, 192, 6, True, 0.25, 2),
BlockParams(3, 1, 192, 320, 6, True, 0.25, 1)
]
return block_params
class EffUtils:
@staticmethod
def round_filters(filters, global_params):
"""Calculate and round number of filters based on depth multiplier."""
multiplier = global_params.width_coefficient
if not multiplier:
return filters
divisor = global_params.depth_divisor
filters *= multiplier
new_filters = int(filters + divisor / 2) // divisor * divisor
if new_filters < 0.9 * filters:
new_filters += divisor
return int(new_filters)
@staticmethod
def round_repeats(repeats, global_params):
"""Round number of filters based on depth multiplier."""
multiplier = global_params.depth_coefficient
if not multiplier:
return repeats
return int(math.ceil(multiplier * repeats))
class ConvBlock(nn.Layer):
def __init__(self, block_params):
super(ConvBlock, self).__init__()
self.block_args = block_params
self.has_se = (self.block_args.se_ratio is not None) and \
(0 < self.block_args.se_ratio <= 1)
self.id_skip = block_params.id_skip
# expansion phase
self.input_filters = self.block_args.input_filters
output_filters = \
self.block_args.input_filters * self.block_args.expand_ratio
if self.block_args.expand_ratio != 1:
self.expand_conv = nn.Conv2D(
self.input_filters, output_filters, 1, bias_attr=False)
self.bn0 = nn.BatchNorm(output_filters)
# depthwise conv phase
k = self.block_args.kernel_size
s = self.block_args.stride
self.depthwise_conv = nn.Conv2D(
output_filters,
output_filters,
groups=output_filters,
kernel_size=k,
stride=s,
padding='same',
bias_attr=False)
self.bn1 = nn.BatchNorm(output_filters)
# squeeze and excitation layer, if desired
if self.has_se:
num_squeezed_channels = max(1,
int(self.block_args.input_filters *
self.block_args.se_ratio))
self.se_reduce = nn.Conv2D(output_filters, num_squeezed_channels, 1)
self.se_expand = nn.Conv2D(num_squeezed_channels, output_filters, 1)
# output phase
self.final_oup = self.block_args.output_filters
self.project_conv = nn.Conv2D(
output_filters, self.final_oup, 1, bias_attr=False)
self.bn2 = nn.BatchNorm(self.final_oup)
self.swish = nn.Swish()
def drop_connect(self, inputs, p, training):
if not training:
return inputs
batch_size = inputs.shape[0]
keep_prob = 1 - p
random_tensor = keep_prob
random_tensor += paddle.rand([batch_size, 1, 1, 1], dtype=inputs.dtype)
random_tensor = paddle.to_tensor(random_tensor, place=inputs.place)
binary_tensor = paddle.floor(random_tensor)
output = inputs / keep_prob * binary_tensor
return output
def forward(self, inputs, drop_connect_rate=None):
# expansion and depthwise conv
x = inputs
if self.block_args.expand_ratio != 1:
x = self.swish(self.bn0(self.expand_conv(inputs)))
x = self.swish(self.bn1(self.depthwise_conv(x)))
# squeeze and excitation
if self.has_se:
x_squeezed = F.adaptive_avg_pool2d(x, 1)
x_squeezed = self.se_expand(self.swish(self.se_reduce(x_squeezed)))
x = F.sigmoid(x_squeezed) * x
x = self.bn2(self.project_conv(x))
# skip conntection and drop connect
if self.id_skip and self.block_args.stride == 1 and \
self.input_filters == self.final_oup:
if drop_connect_rate:
x = self.drop_connect(
x, p=drop_connect_rate, training=self.training)
x = x + inputs
return x
class EfficientNetb3_PREN(nn.Layer):
def __init__(self, in_channels):
super(EfficientNetb3_PREN, self).__init__()
self.blocks_params = EffB3Params.get_block_params()
self.global_params = EffB3Params.get_global_params()
self.out_channels = []
# stem
stem_channels = EffUtils.round_filters(32, self.global_params)
self.conv_stem = nn.Conv2D(
in_channels, stem_channels, 3, 2, padding='same', bias_attr=False)
self.bn0 = nn.BatchNorm(stem_channels)
self.blocks = []
# to extract three feature maps for fpn based on efficientnetb3 backbone
self.concerned_block_idxes = [7, 17, 25]
concerned_idx = 0
for i, block_params in enumerate(self.blocks_params):
block_params = block_params._replace(
input_filters=EffUtils.round_filters(block_params.input_filters,
self.global_params),
output_filters=EffUtils.round_filters(
block_params.output_filters, self.global_params),
num_repeat=EffUtils.round_repeats(block_params.num_repeat,
self.global_params))
self.blocks.append(
self.add_sublayer("{}-0".format(i), ConvBlock(block_params)))
concerned_idx += 1
if concerned_idx in self.concerned_block_idxes:
self.out_channels.append(block_params.output_filters)
if block_params.num_repeat > 1:
block_params = block_params._replace(
input_filters=block_params.output_filters, stride=1)
for j in range(block_params.num_repeat - 1):
self.blocks.append(
self.add_sublayer('{}-{}'.format(i, j + 1),
ConvBlock(block_params)))
concerned_idx += 1
if concerned_idx in self.concerned_block_idxes:
self.out_channels.append(block_params.output_filters)
self.swish = nn.Swish()
def forward(self, inputs):
outs = []
x = self.swish(self.bn0(self.conv_stem(inputs)))
for idx, block in enumerate(self.blocks):
drop_connect_rate = self.global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / len(self.blocks)
x = block(x, drop_connect_rate=drop_connect_rate)
if idx in self.concerned_block_idxes:
outs.append(x)
return outs

View File

@@ -1,528 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/liyunsheng13/micronet/blob/main/backbone/micronet.py
https://github.com/liyunsheng13/micronet/blob/main/backbone/activation.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
from ppocr.modeling.backbones.det_mobilenet_v3 import make_divisible
M0_cfgs = [
# s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
[2, 1, 8, 3, 2, 2, 0, 4, 8, 2, 2, 2, 0, 1, 1],
[2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 2, 1, 1],
[2, 1, 16, 5, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1],
[1, 1, 32, 5, 1, 4, 4, 4, 32, 4, 4, 2, 2, 1, 1],
[2, 1, 64, 5, 1, 4, 8, 8, 64, 8, 8, 2, 2, 1, 1],
[1, 1, 96, 3, 1, 4, 8, 8, 96, 8, 8, 2, 2, 1, 2],
[1, 1, 384, 3, 1, 4, 12, 12, 0, 0, 0, 2, 2, 1, 2],
]
M1_cfgs = [
# s, n, c, ks, c1, c2, g1, g2, c3, g3, g4
[2, 1, 8, 3, 2, 2, 0, 6, 8, 2, 2, 2, 0, 1, 1],
[2, 1, 16, 3, 2, 2, 0, 8, 16, 4, 4, 2, 2, 1, 1],
[2, 1, 16, 5, 2, 2, 0, 16, 16, 4, 4, 2, 2, 1, 1],
[1, 1, 32, 5, 1, 6, 4, 4, 32, 4, 4, 2, 2, 1, 1],
[2, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 1],
[1, 1, 96, 3, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2],
[1, 1, 576, 3, 1, 6, 12, 12, 0, 0, 0, 2, 2, 1, 2],
]
M2_cfgs = [
# s, n, c, ks, c1, c2, g1, g2, c3, g3, g4
[2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 0, 1, 1],
[2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1],
[1, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 2, 2, 1, 1],
[2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 2, 2, 1, 1],
[1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 2, 2, 1, 2],
[1, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 2],
[2, 1, 96, 5, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2],
[1, 1, 128, 3, 1, 6, 12, 12, 128, 8, 8, 2, 2, 1, 2],
[1, 1, 768, 3, 1, 6, 16, 16, 0, 0, 0, 2, 2, 1, 2],
]
M3_cfgs = [
# s, n, c, ks, c1, c2, g1, g2, c3, g3, g4
[2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 0, 2, 0, 1],
[2, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 0, 2, 0, 1],
[1, 1, 24, 3, 2, 2, 0, 24, 24, 4, 4, 0, 2, 0, 1],
[2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 0, 2, 0, 1],
[1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 0, 2, 0, 2],
[1, 1, 64, 5, 1, 6, 8, 8, 48, 8, 8, 0, 2, 0, 2],
[1, 1, 80, 5, 1, 6, 8, 8, 80, 8, 8, 0, 2, 0, 2],
[1, 1, 80, 5, 1, 6, 10, 10, 80, 8, 8, 0, 2, 0, 2],
[1, 1, 120, 5, 1, 6, 10, 10, 120, 10, 10, 0, 2, 0, 2],
[1, 1, 120, 5, 1, 6, 12, 12, 120, 10, 10, 0, 2, 0, 2],
[1, 1, 144, 3, 1, 6, 12, 12, 144, 12, 12, 0, 2, 0, 2],
[1, 1, 432, 3, 1, 3, 12, 12, 0, 0, 0, 0, 2, 0, 2],
]
def get_micronet_config(mode):
return eval(mode + '_cfgs')
class MaxGroupPooling(nn.Layer):
def __init__(self, channel_per_group=2):
super(MaxGroupPooling, self).__init__()
self.channel_per_group = channel_per_group
def forward(self, x):
if self.channel_per_group == 1:
return x
# max op
b, c, h, w = x.shape
# reshape
y = paddle.reshape(x, [b, c // self.channel_per_group, -1, h, w])
out = paddle.max(y, axis=2)
return out
class SpatialSepConvSF(nn.Layer):
def __init__(self, inp, oups, kernel_size, stride):
super(SpatialSepConvSF, self).__init__()
oup1, oup2 = oups
self.conv = nn.Sequential(
nn.Conv2D(
inp,
oup1, (kernel_size, 1), (stride, 1), (kernel_size // 2, 0),
bias_attr=False,
groups=1),
nn.BatchNorm2D(oup1),
nn.Conv2D(
oup1,
oup1 * oup2, (1, kernel_size), (1, stride),
(0, kernel_size // 2),
bias_attr=False,
groups=oup1),
nn.BatchNorm2D(oup1 * oup2),
ChannelShuffle(oup1), )
def forward(self, x):
out = self.conv(x)
return out
class ChannelShuffle(nn.Layer):
def __init__(self, groups):
super(ChannelShuffle, self).__init__()
self.groups = groups
def forward(self, x):
b, c, h, w = x.shape
channels_per_group = c // self.groups
# reshape
x = paddle.reshape(x, [b, self.groups, channels_per_group, h, w])
x = paddle.transpose(x, (0, 2, 1, 3, 4))
out = paddle.reshape(x, [b, -1, h, w])
return out
class StemLayer(nn.Layer):
def __init__(self, inp, oup, stride, groups=(4, 4)):
super(StemLayer, self).__init__()
g1, g2 = groups
self.stem = nn.Sequential(
SpatialSepConvSF(inp, groups, 3, stride),
MaxGroupPooling(2) if g1 * g2 == 2 * oup else nn.ReLU6())
def forward(self, x):
out = self.stem(x)
return out
class DepthSpatialSepConv(nn.Layer):
def __init__(self, inp, expand, kernel_size, stride):
super(DepthSpatialSepConv, self).__init__()
exp1, exp2 = expand
hidden_dim = inp * exp1
oup = inp * exp1 * exp2
self.conv = nn.Sequential(
nn.Conv2D(
inp,
inp * exp1, (kernel_size, 1), (stride, 1),
(kernel_size // 2, 0),
bias_attr=False,
groups=inp),
nn.BatchNorm2D(inp * exp1),
nn.Conv2D(
hidden_dim,
oup, (1, kernel_size),
1, (0, kernel_size // 2),
bias_attr=False,
groups=hidden_dim),
nn.BatchNorm2D(oup))
def forward(self, x):
x = self.conv(x)
return x
class GroupConv(nn.Layer):
def __init__(self, inp, oup, groups=2):
super(GroupConv, self).__init__()
self.inp = inp
self.oup = oup
self.groups = groups
self.conv = nn.Sequential(
nn.Conv2D(
inp, oup, 1, 1, 0, bias_attr=False, groups=self.groups[0]),
nn.BatchNorm2D(oup))
def forward(self, x):
x = self.conv(x)
return x
class DepthConv(nn.Layer):
def __init__(self, inp, oup, kernel_size, stride):
super(DepthConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2D(
inp,
oup,
kernel_size,
stride,
kernel_size // 2,
bias_attr=False,
groups=inp),
nn.BatchNorm2D(oup))
def forward(self, x):
out = self.conv(x)
return out
class DYShiftMax(nn.Layer):
def __init__(self,
inp,
oup,
reduction=4,
act_max=1.0,
act_relu=True,
init_a=[0.0, 0.0],
init_b=[0.0, 0.0],
relu_before_pool=False,
g=None,
expansion=False):
super(DYShiftMax, self).__init__()
self.oup = oup
self.act_max = act_max * 2
self.act_relu = act_relu
self.avg_pool = nn.Sequential(nn.ReLU() if relu_before_pool == True else
nn.Sequential(), nn.AdaptiveAvgPool2D(1))
self.exp = 4 if act_relu else 2
self.init_a = init_a
self.init_b = init_b
# determine squeeze
squeeze = make_divisible(inp // reduction, 4)
if squeeze < 4:
squeeze = 4
self.fc = nn.Sequential(
nn.Linear(inp, squeeze),
nn.ReLU(), nn.Linear(squeeze, oup * self.exp), nn.Hardsigmoid())
if g is None:
g = 1
self.g = g[1]
if self.g != 1 and expansion:
self.g = inp // self.g
self.gc = inp // self.g
index = paddle.to_tensor([range(inp)])
index = paddle.reshape(index, [1, inp, 1, 1])
index = paddle.reshape(index, [1, self.g, self.gc, 1, 1])
indexgs = paddle.split(index, [1, self.g - 1], axis=1)
indexgs = paddle.concat((indexgs[1], indexgs[0]), axis=1)
indexs = paddle.split(indexgs, [1, self.gc - 1], axis=2)
indexs = paddle.concat((indexs[1], indexs[0]), axis=2)
self.index = paddle.reshape(indexs, [inp])
self.expansion = expansion
def forward(self, x):
x_in = x
x_out = x
b, c, _, _ = x_in.shape
y = self.avg_pool(x_in)
y = paddle.reshape(y, [b, c])
y = self.fc(y)
y = paddle.reshape(y, [b, self.oup * self.exp, 1, 1])
y = (y - 0.5) * self.act_max
n2, c2, h2, w2 = x_out.shape
x2 = paddle.to_tensor(x_out.numpy()[:, self.index.numpy(), :, :])
if self.exp == 4:
temp = y.shape
a1, b1, a2, b2 = paddle.split(y, temp[1] // self.oup, axis=1)
a1 = a1 + self.init_a[0]
a2 = a2 + self.init_a[1]
b1 = b1 + self.init_b[0]
b2 = b2 + self.init_b[1]
z1 = x_out * a1 + x2 * b1
z2 = x_out * a2 + x2 * b2
out = paddle.maximum(z1, z2)
elif self.exp == 2:
temp = y.shape
a1, b1 = paddle.split(y, temp[1] // self.oup, axis=1)
a1 = a1 + self.init_a[0]
b1 = b1 + self.init_b[0]
out = x_out * a1 + x2 * b1
return out
class DYMicroBlock(nn.Layer):
def __init__(self,
inp,
oup,
kernel_size=3,
stride=1,
ch_exp=(2, 2),
ch_per_group=4,
groups_1x1=(1, 1),
depthsep=True,
shuffle=False,
activation_cfg=None):
super(DYMicroBlock, self).__init__()
self.identity = stride == 1 and inp == oup
y1, y2, y3 = activation_cfg['dy']
act_reduction = 8 * activation_cfg['ratio']
init_a = activation_cfg['init_a']
init_b = activation_cfg['init_b']
t1 = ch_exp
gs1 = ch_per_group
hidden_fft, g1, g2 = groups_1x1
hidden_dim2 = inp * t1[0] * t1[1]
if gs1[0] == 0:
self.layers = nn.Sequential(
DepthSpatialSepConv(inp, t1, kernel_size, stride),
DYShiftMax(
hidden_dim2,
hidden_dim2,
act_max=2.0,
act_relu=True if y2 == 2 else False,
init_a=init_a,
reduction=act_reduction,
init_b=init_b,
g=gs1,
expansion=False) if y2 > 0 else nn.ReLU6(),
ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
ChannelShuffle(hidden_dim2 // 2)
if shuffle and y2 != 0 else nn.Sequential(),
GroupConv(hidden_dim2, oup, (g1, g2)),
DYShiftMax(
oup,
oup,
act_max=2.0,
act_relu=False,
init_a=[1.0, 0.0],
reduction=act_reduction // 2,
init_b=[0.0, 0.0],
g=(g1, g2),
expansion=False) if y3 > 0 else nn.Sequential(),
ChannelShuffle(g2) if shuffle else nn.Sequential(),
ChannelShuffle(oup // 2)
if shuffle and oup % 2 == 0 and y3 != 0 else nn.Sequential(), )
elif g2 == 0:
self.layers = nn.Sequential(
GroupConv(inp, hidden_dim2, gs1),
DYShiftMax(
hidden_dim2,
hidden_dim2,
act_max=2.0,
act_relu=False,
init_a=[1.0, 0.0],
reduction=act_reduction,
init_b=[0.0, 0.0],
g=gs1,
expansion=False) if y3 > 0 else nn.Sequential(), )
else:
self.layers = nn.Sequential(
GroupConv(inp, hidden_dim2, gs1),
DYShiftMax(
hidden_dim2,
hidden_dim2,
act_max=2.0,
act_relu=True if y1 == 2 else False,
init_a=init_a,
reduction=act_reduction,
init_b=init_b,
g=gs1,
expansion=False) if y1 > 0 else nn.ReLU6(),
ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
DepthSpatialSepConv(hidden_dim2, (1, 1), kernel_size, stride)
if depthsep else
DepthConv(hidden_dim2, hidden_dim2, kernel_size, stride),
nn.Sequential(),
DYShiftMax(
hidden_dim2,
hidden_dim2,
act_max=2.0,
act_relu=True if y2 == 2 else False,
init_a=init_a,
reduction=act_reduction,
init_b=init_b,
g=gs1,
expansion=True) if y2 > 0 else nn.ReLU6(),
ChannelShuffle(hidden_dim2 // 4)
if shuffle and y1 != 0 and y2 != 0 else nn.Sequential()
if y1 == 0 and y2 == 0 else ChannelShuffle(hidden_dim2 // 2),
GroupConv(hidden_dim2, oup, (g1, g2)),
DYShiftMax(
oup,
oup,
act_max=2.0,
act_relu=False,
init_a=[1.0, 0.0],
reduction=act_reduction // 2
if oup < hidden_dim2 else act_reduction,
init_b=[0.0, 0.0],
g=(g1, g2),
expansion=False) if y3 > 0 else nn.Sequential(),
ChannelShuffle(g2) if shuffle else nn.Sequential(),
ChannelShuffle(oup // 2)
if shuffle and y3 != 0 else nn.Sequential(), )
def forward(self, x):
identity = x
out = self.layers(x)
if self.identity:
out = out + identity
return out
class MicroNet(nn.Layer):
"""
the MicroNet backbone network for recognition module.
Args:
mode(str): {'M0', 'M1', 'M2', 'M3'}
Four models are proposed based on four different computational costs (4M, 6M, 12M, 21M MAdds)
Default: 'M3'.
"""
def __init__(self, mode='M3', **kwargs):
super(MicroNet, self).__init__()
self.cfgs = get_micronet_config(mode)
activation_cfg = {}
if mode == 'M0':
input_channel = 4
stem_groups = 2, 2
out_ch = 384
activation_cfg['init_a'] = 1.0, 1.0
activation_cfg['init_b'] = 0.0, 0.0
elif mode == 'M1':
input_channel = 6
stem_groups = 3, 2
out_ch = 576
activation_cfg['init_a'] = 1.0, 1.0
activation_cfg['init_b'] = 0.0, 0.0
elif mode == 'M2':
input_channel = 8
stem_groups = 4, 2
out_ch = 768
activation_cfg['init_a'] = 1.0, 1.0
activation_cfg['init_b'] = 0.0, 0.0
elif mode == 'M3':
input_channel = 12
stem_groups = 4, 3
out_ch = 432
activation_cfg['init_a'] = 1.0, 0.5
activation_cfg['init_b'] = 0.0, 0.5
else:
raise NotImplementedError("mode[" + mode +
"_model] is not implemented!")
layers = [StemLayer(3, input_channel, stride=2, groups=stem_groups)]
for idx, val in enumerate(self.cfgs):
s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r = val
t1 = (c1, c2)
gs1 = (g1, g2)
gs2 = (c3, g3, g4)
activation_cfg['dy'] = [y1, y2, y3]
activation_cfg['ratio'] = r
output_channel = c
layers.append(
DYMicroBlock(
input_channel,
output_channel,
kernel_size=ks,
stride=s,
ch_exp=t1,
ch_per_group=gs1,
groups_1x1=gs2,
depthsep=True,
shuffle=True,
activation_cfg=activation_cfg, ))
input_channel = output_channel
for i in range(1, n):
layers.append(
DYMicroBlock(
input_channel,
output_channel,
kernel_size=ks,
stride=1,
ch_exp=t1,
ch_per_group=gs1,
groups_1x1=gs2,
depthsep=True,
shuffle=True,
activation_cfg=activation_cfg, ))
input_channel = output_channel
self.features = nn.Sequential(*layers)
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = make_divisible(out_ch)
def forward(self, x):
x = self.features(x)
x = self.pool(x)
return x

View File

@@ -1,138 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import nn
from ppocr.modeling.backbones.det_mobilenet_v3 import ResidualUnit, ConvBNLayer, make_divisible
__all__ = ['MobileNetV3']
class MobileNetV3(nn.Layer):
def __init__(self,
in_channels=3,
model_name='small',
scale=0.5,
large_stride=None,
small_stride=None,
disable_se=False,
**kwargs):
super(MobileNetV3, self).__init__()
self.disable_se = disable_se
if small_stride is None:
small_stride = [2, 2, 2, 2]
if large_stride is None:
large_stride = [1, 2, 2, 2]
assert isinstance(large_stride, list), "large_stride type must " \
"be list but got {}".format(type(large_stride))
assert isinstance(small_stride, list), "small_stride type must " \
"be list but got {}".format(type(small_stride))
assert len(large_stride) == 4, "large_stride length must be " \
"4 but got {}".format(len(large_stride))
assert len(small_stride) == 4, "small_stride length must be " \
"4 but got {}".format(len(small_stride))
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', large_stride[0]],
[3, 64, 24, False, 'relu', (large_stride[1], 1)],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', (large_stride[2], 1)],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hardswish', 1],
[3, 200, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hardswish', 1],
[3, 480, 112, True, 'hardswish', 1],
[3, 672, 112, True, 'hardswish', 1],
[5, 672, 160, True, 'hardswish', (large_stride[3], 1)],
[5, 960, 160, True, 'hardswish', 1],
[5, 960, 160, True, 'hardswish', 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', (small_stride[0], 1)],
[3, 72, 24, False, 'relu', (small_stride[1], 1)],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hardswish', (small_stride[2], 1)],
[5, 240, 40, True, 'hardswish', 1],
[5, 240, 40, True, 'hardswish', 1],
[5, 120, 48, True, 'hardswish', 1],
[5, 144, 48, True, 'hardswish', 1],
[5, 288, 96, True, 'hardswish', (small_stride[3], 1)],
[5, 576, 96, True, 'hardswish', 1],
[5, 576, 96, True, 'hardswish', 1],
]
cls_ch_squeeze = 576
else:
raise NotImplementedError("mode[" + model_name +
"_model] is not implemented!")
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
assert scale in supported_scale, \
"supported scales are {} but input scale is {}".format(supported_scale, scale)
inplanes = 16
# conv1
self.conv1 = ConvBNLayer(
in_channels=in_channels,
out_channels=make_divisible(inplanes * scale),
kernel_size=3,
stride=2,
padding=1,
groups=1,
if_act=True,
act='hardswish')
i = 0
block_list = []
inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in cfg:
se = se and not self.disable_se
block_list.append(
ResidualUnit(
in_channels=inplanes,
mid_channels=make_divisible(scale * exp),
out_channels=make_divisible(scale * c),
kernel_size=k,
stride=s,
use_se=se,
act=nl))
inplanes = make_divisible(scale * c)
i += 1
self.blocks = nn.Sequential(*block_list)
self.conv2 = ConvBNLayer(
in_channels=inplanes,
out_channels=make_divisible(scale * cls_ch_squeeze),
kernel_size=1,
stride=1,
padding=0,
groups=1,
if_act=True,
act='hardswish')
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = make_divisible(scale * cls_ch_squeeze)
def forward(self, x):
x = self.conv1(x)
x = self.blocks(x)
x = self.conv2(x)
x = self.pool(x)
return x

View File

@@ -1,256 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is refer from: https://github.com/PaddlePaddle/PaddleClas/blob/develop/ppcls/arch/backbone/legendary_models/pp_lcnet.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import paddle
from paddle import ParamAttr, reshape, transpose
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import KaimingNormal
from paddle.regularizer import L2Decay
from paddle.nn.functional import hardswish, hardsigmoid
class ConvBNLayer(nn.Layer):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='hard_swish'):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
weight_attr=ParamAttr(initializer=KaimingNormal()),
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class DepthwiseSeparable(nn.Layer):
def __init__(self,
num_channels,
num_filters1,
num_filters2,
num_groups,
stride,
scale,
dw_size=3,
padding=1,
use_se=False):
super(DepthwiseSeparable, self).__init__()
self.use_se = use_se
self._depthwise_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=int(num_filters1 * scale),
filter_size=dw_size,
stride=stride,
padding=padding,
num_groups=int(num_groups * scale))
if use_se:
self._se = SEModule(int(num_filters1 * scale))
self._pointwise_conv = ConvBNLayer(
num_channels=int(num_filters1 * scale),
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0)
def forward(self, inputs):
y = self._depthwise_conv(inputs)
if self.use_se:
y = self._se(y)
y = self._pointwise_conv(y)
return y
class MobileNetV1Enhance(nn.Layer):
def __init__(self,
in_channels=3,
scale=0.5,
last_conv_stride=1,
last_pool_type='max',
**kwargs):
super().__init__()
self.scale = scale
self.block_list = []
self.conv1 = ConvBNLayer(
num_channels=3,
filter_size=3,
channels=3,
num_filters=int(32 * scale),
stride=2,
padding=1)
conv2_1 = DepthwiseSeparable(
num_channels=int(32 * scale),
num_filters1=32,
num_filters2=64,
num_groups=32,
stride=1,
scale=scale)
self.block_list.append(conv2_1)
conv2_2 = DepthwiseSeparable(
num_channels=int(64 * scale),
num_filters1=64,
num_filters2=128,
num_groups=64,
stride=1,
scale=scale)
self.block_list.append(conv2_2)
conv3_1 = DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=128,
num_groups=128,
stride=1,
scale=scale)
self.block_list.append(conv3_1)
conv3_2 = DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=(2, 1),
scale=scale)
self.block_list.append(conv3_2)
conv4_1 = DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=256,
num_groups=256,
stride=1,
scale=scale)
self.block_list.append(conv4_1)
conv4_2 = DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=(2, 1),
scale=scale)
self.block_list.append(conv4_2)
for _ in range(5):
conv5 = DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
dw_size=5,
padding=2,
scale=scale,
use_se=False)
self.block_list.append(conv5)
conv5_6 = DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=(2, 1),
dw_size=5,
padding=2,
scale=scale,
use_se=True)
self.block_list.append(conv5_6)
conv6 = DepthwiseSeparable(
num_channels=int(1024 * scale),
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=last_conv_stride,
dw_size=5,
padding=2,
use_se=True,
scale=scale)
self.block_list.append(conv6)
self.block_list = nn.Sequential(*self.block_list)
if last_pool_type == 'avg':
self.pool = nn.AvgPool2D(kernel_size=2, stride=2, padding=0)
else:
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = int(1024 * scale)
def forward(self, inputs):
y = self.conv1(inputs)
y = self.block_list(y)
y = self.pool(y)
return y
class SEModule(nn.Layer):
def __init__(self, channel, reduction=4):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv1 = Conv2D(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(),
bias_attr=ParamAttr())
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(),
bias_attr=ParamAttr())
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = hardsigmoid(outputs)
return paddle.multiply(x=inputs, y=outputs)

View File

@@ -1,48 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import nn
import paddle
class MTB(nn.Layer):
def __init__(self, cnn_num, in_channels):
super(MTB, self).__init__()
self.block = nn.Sequential()
self.out_channels = in_channels
self.cnn_num = cnn_num
if self.cnn_num == 2:
for i in range(self.cnn_num):
self.block.add_sublayer(
'conv_{}'.format(i),
nn.Conv2D(
in_channels=in_channels
if i == 0 else 32 * (2**(i - 1)),
out_channels=32 * (2**i),
kernel_size=3,
stride=2,
padding=1))
self.block.add_sublayer('relu_{}'.format(i), nn.ReLU())
self.block.add_sublayer('bn_{}'.format(i),
nn.BatchNorm2D(32 * (2**i)))
def forward(self, images):
x = self.block(images)
if self.cnn_num == 2:
# (b, w, h, c)
x = paddle.transpose(x, [0, 3, 2, 1])
x_shape = paddle.shape(x)
x = paddle.reshape(
x, [x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
return x

View File

@@ -1,210 +0,0 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/layers/conv_layer.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/backbones/resnet31_ocr.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
__all__ = ["ResNet31"]
def conv3x3(in_channel, out_channel, stride=1):
return nn.Conv2D(
in_channel,
out_channel,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self, in_channels, channels, stride=1, downsample=False):
super().__init__()
self.conv1 = conv3x3(in_channels, channels, stride)
self.bn1 = nn.BatchNorm2D(channels)
self.relu = nn.ReLU()
self.conv2 = conv3x3(channels, channels)
self.bn2 = nn.BatchNorm2D(channels)
self.downsample = downsample
if downsample:
self.downsample = nn.Sequential(
nn.Conv2D(
in_channels,
channels * self.expansion,
1,
stride,
bias_attr=False),
nn.BatchNorm2D(channels * self.expansion), )
else:
self.downsample = nn.Sequential()
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet31(nn.Layer):
'''
Args:
in_channels (int): Number of channels of input image tensor.
layers (list[int]): List of BasicBlock number for each stage.
channels (list[int]): List of out_channels of Conv2d layer.
out_indices (None | Sequence[int]): Indices of output stages.
last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
'''
def __init__(self,
in_channels=3,
layers=[1, 2, 5, 3],
channels=[64, 128, 256, 256, 512, 512, 512],
out_indices=None,
last_stage_pool=False):
super(ResNet31, self).__init__()
assert isinstance(in_channels, int)
assert isinstance(last_stage_pool, bool)
self.out_indices = out_indices
self.last_stage_pool = last_stage_pool
# conv 1 (Conv Conv)
self.conv1_1 = nn.Conv2D(
in_channels, channels[0], kernel_size=3, stride=1, padding=1)
self.bn1_1 = nn.BatchNorm2D(channels[0])
self.relu1_1 = nn.ReLU()
self.conv1_2 = nn.Conv2D(
channels[0], channels[1], kernel_size=3, stride=1, padding=1)
self.bn1_2 = nn.BatchNorm2D(channels[1])
self.relu1_2 = nn.ReLU()
# conv 2 (Max-pooling, Residual block, Conv)
self.pool2 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block2 = self._make_layer(channels[1], channels[2], layers[0])
self.conv2 = nn.Conv2D(
channels[2], channels[2], kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2D(channels[2])
self.relu2 = nn.ReLU()
# conv 3 (Max-pooling, Residual block, Conv)
self.pool3 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block3 = self._make_layer(channels[2], channels[3], layers[1])
self.conv3 = nn.Conv2D(
channels[3], channels[3], kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2D(channels[3])
self.relu3 = nn.ReLU()
# conv 4 (Max-pooling, Residual block, Conv)
self.pool4 = nn.MaxPool2D(
kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
self.block4 = self._make_layer(channels[3], channels[4], layers[2])
self.conv4 = nn.Conv2D(
channels[4], channels[4], kernel_size=3, stride=1, padding=1)
self.bn4 = nn.BatchNorm2D(channels[4])
self.relu4 = nn.ReLU()
# conv 5 ((Max-pooling), Residual block, Conv)
self.pool5 = None
if self.last_stage_pool:
self.pool5 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block5 = self._make_layer(channels[4], channels[5], layers[3])
self.conv5 = nn.Conv2D(
channels[5], channels[5], kernel_size=3, stride=1, padding=1)
self.bn5 = nn.BatchNorm2D(channels[5])
self.relu5 = nn.ReLU()
self.out_channels = channels[-1]
def _make_layer(self, input_channels, output_channels, blocks):
layers = []
for _ in range(blocks):
downsample = None
if input_channels != output_channels:
downsample = nn.Sequential(
nn.Conv2D(
input_channels,
output_channels,
kernel_size=1,
stride=1,
bias_attr=False),
nn.BatchNorm2D(output_channels), )
layers.append(
BasicBlock(
input_channels, output_channels, downsample=downsample))
input_channels = output_channels
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1_1(x)
x = self.bn1_1(x)
x = self.relu1_1(x)
x = self.conv1_2(x)
x = self.bn1_2(x)
x = self.relu1_2(x)
outs = []
for i in range(4):
layer_index = i + 2
pool_layer = getattr(self, f'pool{layer_index}')
block_layer = getattr(self, f'block{layer_index}')
conv_layer = getattr(self, f'conv{layer_index}')
bn_layer = getattr(self, f'bn{layer_index}')
relu_layer = getattr(self, f'relu{layer_index}')
if pool_layer is not None:
x = pool_layer(x)
x = block_layer(x)
x = conv_layer(x)
x = bn_layer(x)
x = relu_layer(x)
outs.append(x)
if self.out_indices is not None:
return tuple([outs[i] for i in self.out_indices])
return x

View File

@@ -1,143 +0,0 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/resnet_aster.py
"""
import paddle
import paddle.nn as nn
import sys
import math
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2D(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2D(
in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False)
def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
# [n_position]
positions = paddle.arange(0, n_position)
# [feat_dim]
dim_range = paddle.arange(0, feat_dim)
dim_range = paddle.pow(wave_length, 2 * (dim_range // 2) / feat_dim)
# [n_position, feat_dim]
angles = paddle.unsqueeze(
positions, axis=1) / paddle.unsqueeze(
dim_range, axis=0)
angles = paddle.cast(angles, "float32")
angles[:, 0::2] = paddle.sin(angles[:, 0::2])
angles[:, 1::2] = paddle.cos(angles[:, 1::2])
return angles
class AsterBlock(nn.Layer):
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(AsterBlock, self).__init__()
self.conv1 = conv1x1(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2D(planes)
self.relu = nn.ReLU()
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2D(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet_ASTER(nn.Layer):
"""For aster or crnn"""
def __init__(self, with_lstm=True, n_group=1, in_channels=3):
super(ResNet_ASTER, self).__init__()
self.with_lstm = with_lstm
self.n_group = n_group
self.layer0 = nn.Sequential(
nn.Conv2D(
in_channels,
32,
kernel_size=(3, 3),
stride=1,
padding=1,
bias_attr=False),
nn.BatchNorm2D(32),
nn.ReLU())
self.inplanes = 32
self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50]
self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25]
self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25]
self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25]
self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25]
if with_lstm:
self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2)
self.out_channels = 2 * 256
else:
self.out_channels = 512
def _make_layer(self, planes, blocks, stride):
downsample = None
if stride != [1, 1] or self.inplanes != planes:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
layers = []
layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
self.inplanes = planes
for _ in range(1, blocks):
layers.append(AsterBlock(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x0 = self.layer0(x)
x1 = self.layer1(x0)
x2 = self.layer2(x1)
x3 = self.layer3(x2)
x4 = self.layer4(x3)
x5 = self.layer5(x4)
cnn_feat = x5.squeeze(2) # [N, c, w]
cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1])
if self.with_lstm:
rnn_feat, _ = self.rnn(cnn_feat)
return rnn_feat
else:
return cnn_feat

Some files were not shown because too many files have changed in this diff Show More