mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-09 08:14:40 +08:00
Compare commits
85 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a577efd3e1 | ||
|
|
2bb2d54314 | ||
|
|
7049a24883 | ||
|
|
7e0ed62b9e | ||
|
|
7a5384ad95 | ||
|
|
b40ab7d921 | ||
|
|
7304905a84 | ||
|
|
150923b409 | ||
|
|
746db4bced | ||
|
|
30e7913981 | ||
|
|
6fc6d35584 | ||
|
|
87d8d9d3d7 | ||
|
|
3770ccdcfd | ||
|
|
29873c33ea | ||
|
|
196a1a8e7b | ||
|
|
53baf28326 | ||
|
|
991294c00f | ||
|
|
ea7e01e3aa | ||
|
|
acdb150aa2 | ||
|
|
285bfbafa7 | ||
|
|
97b4159d38 | ||
|
|
bb80445cf4 | ||
|
|
7e8d0b818b | ||
|
|
77758d258b | ||
|
|
c60234f4ec | ||
|
|
38ff91fad7 | ||
|
|
9f9cded1ff | ||
|
|
54027ceeb0 | ||
|
|
1dc9036dee | ||
|
|
09dbfa47f2 | ||
|
|
aa83db0f98 | ||
|
|
f86c8c9fe8 | ||
|
|
3dc8f3bfe0 | ||
|
|
b0ca454473 | ||
|
|
9f7fd5b341 | ||
|
|
535fdecef4 | ||
|
|
019f7f4517 | ||
|
|
8c5ea2e19d | ||
|
|
330cf54e1a | ||
|
|
7019572f7b | ||
|
|
ee53840adb | ||
|
|
96d744b3a7 | ||
|
|
32c47873ab | ||
|
|
99770a32b9 | ||
|
|
0f71d732e1 | ||
|
|
f3a982710d | ||
|
|
e07849ef87 | ||
|
|
96099ea2d4 | ||
|
|
f4c22dd420 | ||
|
|
a3452832ff | ||
|
|
45e80bc9b0 | ||
|
|
4a09342987 | ||
|
|
caf4cb27f4 | ||
|
|
c927476c0f | ||
|
|
61aa3d8f88 | ||
|
|
67fdacdd8b | ||
|
|
3d21963995 | ||
|
|
a3dd7b797d | ||
|
|
6b353455a0 | ||
|
|
d6736d9206 | ||
|
|
4abc3409ac | ||
|
|
2d1eb11fd6 | ||
|
|
9a65c17a50 | ||
|
|
3ce8d7409b | ||
|
|
f9dd30fddf | ||
|
|
fda9024084 | ||
|
|
19141ff5c9 | ||
|
|
97b54f6d9e | ||
|
|
584e574795 | ||
|
|
dad37eba7d | ||
|
|
063a896cb9 | ||
|
|
63d8378f36 | ||
|
|
4cbfa9ebf0 | ||
|
|
8a8088be1f | ||
|
|
757cc5bf77 | ||
|
|
e536d6af86 | ||
|
|
311701d3e6 | ||
|
|
a7e62db98a | ||
|
|
945aeb9bc8 | ||
|
|
6ea7482344 | ||
|
|
ba396d9569 | ||
|
|
22b021d9ae | ||
|
|
49ae0029f5 | ||
|
|
f89c109636 | ||
|
|
055a08403f |
14
.condarc
14
.condarc
@@ -2,13 +2,9 @@ channels:
|
||||
- defaults
|
||||
show_channel_urls: true
|
||||
default_channels:
|
||||
- http://mirrors.aliyun.com/anaconda/pkgs/main
|
||||
- http://mirrors.aliyun.com/anaconda/pkgs/r
|
||||
- http://mirrors.aliyun.com/anaconda/pkgs/msys2
|
||||
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
|
||||
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r
|
||||
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2
|
||||
custom_channels:
|
||||
conda-forge: http://mirrors.aliyun.com/anaconda/cloud
|
||||
msys2: http://mirrors.aliyun.com/anaconda/cloud
|
||||
bioconda: http://mirrors.aliyun.com/anaconda/cloud
|
||||
menpo: http://mirrors.aliyun.com/anaconda/cloud
|
||||
pytorch: http://mirrors.aliyun.com/anaconda/cloud
|
||||
simpleitk: http://mirrors.aliyun.com/anaconda/cloud
|
||||
conda-forge: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
|
||||
pytorch: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
|
||||
|
||||
97
.github/workflows/build-docker.yml
vendored
Normal file
97
.github/workflows/build-docker.yml
vendored
Normal file
@@ -0,0 +1,97 @@
|
||||
name: Docker Build and Push
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- '**'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
check-secrets:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
has_secrets: ${{ steps.check.outputs.has_secrets }}
|
||||
steps:
|
||||
- id: check
|
||||
run: |
|
||||
if [[ -n "${{ secrets.DOCKERHUB_USERNAME }}" && -n "${{ secrets.DOCKERHUB_TOKEN }}" ]]; then
|
||||
echo "has_secrets=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "has_secrets=false" >> $GITHUB_OUTPUT
|
||||
echo "未设置 Docker Hub 凭据,将跳过整个 Action"
|
||||
fi
|
||||
|
||||
build-and-push:
|
||||
needs: check-secrets
|
||||
if: needs.check-secrets.outputs.has_secrets == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- type: cuda
|
||||
version: "11.8"
|
||||
- type: cuda
|
||||
version: "12.6"
|
||||
- type: cuda
|
||||
version: "12.8"
|
||||
- type: directml
|
||||
version: "latest"
|
||||
|
||||
steps:
|
||||
|
||||
- name: Show system
|
||||
run: |
|
||||
echo -e "Total CPU cores\t: $(nproc)"
|
||||
cat /proc/cpuinfo | grep 'model name'
|
||||
ulimit -a
|
||||
|
||||
- name: Free disk space
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android
|
||||
df -h
|
||||
|
||||
- name: 检出代码
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: 读取 VERSION
|
||||
id: version
|
||||
run: |
|
||||
VERSION=$(sed -n 's/^VERSION = "\(.*\)"/\1/p' backend/config.py)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
|
||||
shell: bash
|
||||
|
||||
- name: 设置 Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: 登录到 Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: 提取元数据
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/video-subtitle-remover
|
||||
tags: |
|
||||
type=raw,value=${{ env.VERSION }}-${{ matrix.type }}${{ matrix.type == 'cuda' && matrix.version || '' }}
|
||||
|
||||
- name: 构建并推送
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
build-args: |
|
||||
${{ matrix.type == 'cuda' && format('CUDA_VERSION={0}', matrix.version) || '' }}
|
||||
${{ matrix.type == 'directml' && 'USE_DIRECTML=1' || '' }}
|
||||
|
||||
- name: Docker Hub Description
|
||||
uses: peter-evans/dockerhub-description@v4
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
repository: ${{ secrets.DOCKERHUB_USERNAME }}/video-subtitle-remover
|
||||
94
.github/workflows/build-windows-cuda-11.8.yml
vendored
Normal file
94
.github/workflows/build-windows-cuda-11.8.yml
vendored
Normal file
@@ -0,0 +1,94 @@
|
||||
name: Build Windows CUDA 11.8
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- '**'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
ssh:
|
||||
description: 'SSH connection to Actions'
|
||||
required: false
|
||||
default: false
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: windows-2019
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: 读取 VERSION
|
||||
id: version
|
||||
run: |
|
||||
VERSION=$(sed -n 's/^VERSION = "\(.*\)"/\1/p' backend/config.py)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
|
||||
shell: bash
|
||||
# - name: 检查 tag 是否已存在
|
||||
# run: |
|
||||
# TAG_NAME="${VERSION}"
|
||||
# if git ls-remote --tags origin | grep -q "refs/tags/$TAG_NAME"; then
|
||||
# echo "Tag $TAG_NAME 已存在,发布中止"
|
||||
# exit 1
|
||||
# fi
|
||||
# shell: bash
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- run: pip install paddlepaddle==3.0.0
|
||||
- run: pip install -r requirements.txt
|
||||
- run: pip freeze > requirements.txt
|
||||
- run: pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu118
|
||||
- run: pip install QPT==1.0b8 setuptools
|
||||
- name: 获取 site-packages 路径
|
||||
shell: bash
|
||||
run: |
|
||||
SITE_PACKAGES=$(python -c "import site, os; print(os.path.join(site.getsitepackages()[0], 'Lib', 'site-packages'))")
|
||||
SITE_PACKAGES_UNIX=$(cygpath -u "$SITE_PACKAGES")
|
||||
echo "site-packages路径: $SITE_PACKAGES"
|
||||
echo "site-packages UNIX路径: $SITE_PACKAGES_UNIX"
|
||||
echo "SITE_PACKAGES_UNIX=$SITE_PACKAGES_UNIX" >> $GITHUB_ENV
|
||||
echo "SITE_PACKAGES=$SITE_PACKAGES" >> $GITHUB_ENV
|
||||
- name: 修复QPT内部错误
|
||||
run: sed -i '98c\ try:\n dep = pkg.requires()\n except TypeError:\n continue' ${SITE_PACKAGES_UNIX}/qpt/kernel/qpackage.py
|
||||
shell: bash
|
||||
- name: Start SSH via tmate
|
||||
if: (github.event.inputs.ssh == 'true' && github.event.inputs.ssh != 'false') || contains(github.event.action, 'ssh')
|
||||
uses: mxschmitt/action-tmate@v3
|
||||
- run: |
|
||||
python backend/tools/makedist.py --cuda 11.8 && \
|
||||
mv ../vsr_out ./vsr_out && \
|
||||
cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/
|
||||
env:
|
||||
QPT_Action: "True"
|
||||
shell: bash
|
||||
- name: 上传 Debug 文件夹到 Artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8-debug
|
||||
path: vsr_out/Debug/
|
||||
- name: 上传 Release 文件夹到 Artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8-release
|
||||
path: vsr_out/Release/
|
||||
- name: 打包 Release 文件夹
|
||||
run: |
|
||||
cd vsr_out/Release
|
||||
7z a -t7z -mx=9 -m0=LZMA2 -ms=on -mfb=64 -md=32m -mmt=on -v2000m vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8.7z * && \
|
||||
# 检测是否只有一个分卷
|
||||
if [ -f vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8.7z.001 ] && [ ! -f vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8.7z.002 ]; then \
|
||||
mv vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8.7z.001 vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8.7z; fi
|
||||
shell: bash
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
with:
|
||||
prerelease: true
|
||||
tag_name: ${{ env.VERSION }}
|
||||
target_commitish: ${{ github.sha }}
|
||||
name: 硬字幕去除器 v${{ env.VERSION }}
|
||||
files: |
|
||||
vsr_out/Release/vsr-v${{ env.VERSION }}-windows-nvidia-cuda-11.8.7z*
|
||||
94
.github/workflows/build-windows-cuda-12.6.yml
vendored
Normal file
94
.github/workflows/build-windows-cuda-12.6.yml
vendored
Normal file
@@ -0,0 +1,94 @@
|
||||
name: Build Windows CUDA 12.6
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- '**'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
ssh:
|
||||
description: 'SSH connection to Actions'
|
||||
required: false
|
||||
default: false
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: windows-2019
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: 读取 VERSION
|
||||
id: version
|
||||
run: |
|
||||
VERSION=$(sed -n 's/^VERSION = "\(.*\)"/\1/p' backend/config.py)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
|
||||
shell: bash
|
||||
# - name: 检查 tag 是否已存在
|
||||
# run: |
|
||||
# TAG_NAME="${VERSION}"
|
||||
# if git ls-remote --tags origin | grep -q "refs/tags/$TAG_NAME"; then
|
||||
# echo "Tag $TAG_NAME 已存在,发布中止"
|
||||
# exit 1
|
||||
# fi
|
||||
# shell: bash
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- run: pip install paddlepaddle==3.0.0
|
||||
- run: pip install -r requirements.txt
|
||||
- run: pip freeze > requirements.txt
|
||||
- run: pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126
|
||||
- run: pip install QPT==1.0b8 setuptools
|
||||
- name: 获取 site-packages 路径
|
||||
shell: bash
|
||||
run: |
|
||||
SITE_PACKAGES=$(python -c "import site, os; print(os.path.join(site.getsitepackages()[0], 'Lib', 'site-packages'))")
|
||||
SITE_PACKAGES_UNIX=$(cygpath -u "$SITE_PACKAGES")
|
||||
echo "site-packages路径: $SITE_PACKAGES"
|
||||
echo "site-packages UNIX路径: $SITE_PACKAGES_UNIX"
|
||||
echo "SITE_PACKAGES_UNIX=$SITE_PACKAGES_UNIX" >> $GITHUB_ENV
|
||||
echo "SITE_PACKAGES=$SITE_PACKAGES" >> $GITHUB_ENV
|
||||
- name: 修复QPT内部错误
|
||||
run: sed -i '98c\ try:\n dep = pkg.requires()\n except TypeError:\n continue' ${SITE_PACKAGES_UNIX}/qpt/kernel/qpackage.py
|
||||
shell: bash
|
||||
- name: Start SSH via tmate
|
||||
if: (github.event.inputs.ssh == 'true' && github.event.inputs.ssh != 'false') || contains(github.event.action, 'ssh')
|
||||
uses: mxschmitt/action-tmate@v3
|
||||
- run: |
|
||||
python backend/tools/makedist.py --cuda 12.6 && \
|
||||
mv ../vsr_out ./vsr_out && \
|
||||
cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/
|
||||
env:
|
||||
QPT_Action: "True"
|
||||
shell: bash
|
||||
- name: 上传 Debug 文件夹到 Artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6-debug
|
||||
path: vsr_out/Debug/
|
||||
- name: 上传 Release 文件夹到 Artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6-release
|
||||
path: vsr_out/Release/
|
||||
- name: 打包 Release 文件夹
|
||||
run: |
|
||||
cd vsr_out/Release
|
||||
7z a -t7z -mx=9 -m0=LZMA2 -ms=on -mfb=64 -md=32m -mmt=on -v2000m vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6.7z * && \
|
||||
# 检测是否只有一个分卷
|
||||
if [ -f vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6.7z.001 ] && [ ! -f vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6.7z.002 ]; then \
|
||||
mv vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6.7z.001 vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6.7z; fi
|
||||
shell: bash
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
with:
|
||||
prerelease: true
|
||||
tag_name: ${{ env.VERSION }}
|
||||
target_commitish: ${{ github.sha }}
|
||||
name: 硬字幕去除器 v${{ env.VERSION }}
|
||||
files: |
|
||||
vsr_out/Release/vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.6.7z*
|
||||
94
.github/workflows/build-windows-cuda-12.8.yml
vendored
Normal file
94
.github/workflows/build-windows-cuda-12.8.yml
vendored
Normal file
@@ -0,0 +1,94 @@
|
||||
name: Build Windows CUDA 12.8
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- '**'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
ssh:
|
||||
description: 'SSH connection to Actions'
|
||||
required: false
|
||||
default: false
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: windows-2019
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: 读取 VERSION
|
||||
id: version
|
||||
run: |
|
||||
VERSION=$(sed -n 's/^VERSION = "\(.*\)"/\1/p' backend/config.py)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
|
||||
shell: bash
|
||||
# - name: 检查 tag 是否已存在
|
||||
# run: |
|
||||
# TAG_NAME="${VERSION}"
|
||||
# if git ls-remote --tags origin | grep -q "refs/tags/$TAG_NAME"; then
|
||||
# echo "Tag $TAG_NAME 已存在,发布中止"
|
||||
# exit 1
|
||||
# fi
|
||||
# shell: bash
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- run: pip install paddlepaddle==3.0.0
|
||||
- run: pip install -r requirements.txt
|
||||
- run: pip freeze > requirements.txt
|
||||
- run: pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu128
|
||||
- run: pip install QPT==1.0b8 setuptools
|
||||
- name: 获取 site-packages 路径
|
||||
shell: bash
|
||||
run: |
|
||||
SITE_PACKAGES=$(python -c "import site, os; print(os.path.join(site.getsitepackages()[0], 'Lib', 'site-packages'))")
|
||||
SITE_PACKAGES_UNIX=$(cygpath -u "$SITE_PACKAGES")
|
||||
echo "site-packages路径: $SITE_PACKAGES"
|
||||
echo "site-packages UNIX路径: $SITE_PACKAGES_UNIX"
|
||||
echo "SITE_PACKAGES_UNIX=$SITE_PACKAGES_UNIX" >> $GITHUB_ENV
|
||||
echo "SITE_PACKAGES=$SITE_PACKAGES" >> $GITHUB_ENV
|
||||
- name: 修复QPT内部错误
|
||||
run: sed -i '98c\ try:\n dep = pkg.requires()\n except TypeError:\n continue' ${SITE_PACKAGES_UNIX}/qpt/kernel/qpackage.py
|
||||
shell: bash
|
||||
- name: Start SSH via tmate
|
||||
if: (github.event.inputs.ssh == 'true' && github.event.inputs.ssh != 'false') || contains(github.event.action, 'ssh')
|
||||
uses: mxschmitt/action-tmate@v3
|
||||
- run: |
|
||||
python backend/tools/makedist.py --cuda 12.8 && \
|
||||
mv ../vsr_out ./vsr_out && \
|
||||
cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/
|
||||
env:
|
||||
QPT_Action: "True"
|
||||
shell: bash
|
||||
- name: 上传 Debug 文件夹到 Artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8-debug
|
||||
path: vsr_out/Debug/
|
||||
- name: 上传 Release 文件夹到 Artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8-release
|
||||
path: vsr_out/Release/
|
||||
- name: 打包 Release 文件夹
|
||||
run: |
|
||||
cd vsr_out/Release
|
||||
7z a -t7z -mx=9 -m0=LZMA2 -ms=on -mfb=64 -md=32m -mmt=on -v2000m vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8.7z * && \
|
||||
# 检测是否只有一个分卷
|
||||
if [ -f vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8.7z.001 ] && [ ! -f vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8.7z.002 ]; then \
|
||||
mv vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8.7z.001 vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8.7z; fi
|
||||
shell: bash
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
with:
|
||||
prerelease: true
|
||||
tag_name: ${{ env.VERSION }}
|
||||
target_commitish: ${{ github.sha }}
|
||||
name: 硬字幕去除器 v${{ env.VERSION }}
|
||||
files: |
|
||||
vsr_out/Release/vsr-v${{ env.VERSION }}-windows-nvidia-cuda-12.8.7z*
|
||||
94
.github/workflows/build-windows-directml.yml
vendored
Normal file
94
.github/workflows/build-windows-directml.yml
vendored
Normal file
@@ -0,0 +1,94 @@
|
||||
name: Build Windows DirectML
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- '**'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
ssh:
|
||||
description: 'SSH connection to Actions'
|
||||
required: false
|
||||
default: false
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: windows-2019
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: 读取 VERSION
|
||||
id: version
|
||||
run: |
|
||||
VERSION=$(sed -n 's/^VERSION = "\(.*\)"/\1/p' backend/config.py)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
|
||||
shell: bash
|
||||
# - name: 检查 tag 是否已存在
|
||||
# run: |
|
||||
# TAG_NAME="${VERSION}"
|
||||
# if git ls-remote --tags origin | grep -q "refs/tags/$TAG_NAME"; then
|
||||
# echo "Tag $TAG_NAME 已存在,发布中止"
|
||||
# exit 1
|
||||
# fi
|
||||
# shell: bash
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- run: pip install paddlepaddle==3.0.0
|
||||
- run: pip install -r requirements.txt
|
||||
- run: pip freeze > requirements.txt
|
||||
- run: pip install torch_directml==0.2.5.dev240914
|
||||
- run: pip install QPT==1.0b8 setuptools
|
||||
- name: 获取 site-packages 路径
|
||||
shell: bash
|
||||
run: |
|
||||
SITE_PACKAGES=$(python -c "import site, os; print(os.path.join(site.getsitepackages()[0], 'Lib', 'site-packages'))")
|
||||
SITE_PACKAGES_UNIX=$(cygpath -u "$SITE_PACKAGES")
|
||||
echo "site-packages路径: $SITE_PACKAGES"
|
||||
echo "site-packages UNIX路径: $SITE_PACKAGES_UNIX"
|
||||
echo "SITE_PACKAGES_UNIX=$SITE_PACKAGES_UNIX" >> $GITHUB_ENV
|
||||
echo "SITE_PACKAGES=$SITE_PACKAGES" >> $GITHUB_ENV
|
||||
- name: 修复QPT内部错误
|
||||
run: sed -i '98c\ try:\n dep = pkg.requires()\n except TypeError:\n continue' ${SITE_PACKAGES_UNIX}/qpt/kernel/qpackage.py
|
||||
shell: bash
|
||||
- name: Start SSH via tmate
|
||||
if: (github.event.inputs.ssh == 'true' && github.event.inputs.ssh != 'false') || contains(github.event.action, 'ssh')
|
||||
uses: mxschmitt/action-tmate@v3
|
||||
- run: |
|
||||
python backend/tools/makedist.py --directml && \
|
||||
mv ../vsr_out ./vsr_out && \
|
||||
cp ./vsr_out/Debug/Debug-进入虚拟环境.cmd ./vsr_out/Release/
|
||||
env:
|
||||
QPT_Action: "True"
|
||||
shell: bash
|
||||
- name: 上传 Debug 文件夹到 Artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: vsr-v${{ env.VERSION }}-windows-directml-debug
|
||||
path: vsr_out/Debug/
|
||||
- name: 上传 Release 文件夹到 Artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: vsr-v${{ env.VERSION }}-windows-directml-release
|
||||
path: vsr_out/Release/
|
||||
- name: 打包 Release 文件夹
|
||||
run: |
|
||||
cd vsr_out/Release
|
||||
7z a -t7z -mx=9 -m0=LZMA2 -ms=on -mfb=64 -md=32m -mmt=on -v2000m vsr-v${{ env.VERSION }}-windows-directml.7z * && \
|
||||
# 检测是否只有一个分卷
|
||||
if [ -f vsr-v${{ env.VERSION }}-windows-directml.7z.001 ] && [ ! -f vsr-v${{ env.VERSION }}-windows-directml.7z.002 ]; then \
|
||||
mv vsr-v${{ env.VERSION }}-windows-directml.7z.001 vsr-v${{ env.VERSION }}-windows-directml.7z; fi
|
||||
shell: bash
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
with:
|
||||
prerelease: true
|
||||
tag_name: ${{ env.VERSION }}
|
||||
target_commitish: ${{ github.sha }}
|
||||
name: 硬字幕去除器 v${{ env.VERSION }}
|
||||
files: |
|
||||
vsr_out/Release/vsr-v${{ env.VERSION }}-windows-directml.7z*
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -372,3 +372,5 @@ test*_no_sub*.mp4
|
||||
/backend/models/video/ProPainter.pth
|
||||
/backend/models/big-lama/big-lama.pt
|
||||
/test/debug/
|
||||
/backend/tools/train/release_model/
|
||||
model.onnx
|
||||
273
README.md
273
README.md
@@ -3,7 +3,7 @@
|
||||
## 项目简介
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
Video-subtitle-remover (VSR) 是一款基于AI技术,将视频中的硬字幕去除的软件。
|
||||
@@ -12,13 +12,14 @@ Video-subtitle-remover (VSR) 是一款基于AI技术,将视频中的硬字幕
|
||||
- 通过超强AI算法模型,对去除字幕文本的区域进行填充(非相邻像素填充与马赛克去除)
|
||||
- 支持自定义字幕位置,仅去除定义位置中的字幕(传入位置)
|
||||
- 支持全视频自动去除所有文本(不传入位置)
|
||||
- 支持多选图片批量去除水印文本
|
||||
|
||||
<p style="text-align:center;"><img src="https://github.com/YaoFANGUK/video-subtitle-remover/raw/main/design/demo.png" alt="demo.png"/></p>
|
||||
|
||||
**使用说明:**
|
||||
|
||||
- 有使用问题请加群讨论,QQ群:806152575
|
||||
- 直接下载压缩包解压运行,如果不能运行再按照下面的教程,尝试源码安装conda环境运行
|
||||
- 有使用问题请加群讨论,QQ群:210150985(已满)、806152575(已满)、816881808(已满)、295894827
|
||||
- 直接下载压缩包解压运行,如果不能运行再按照下面的教程,尝试源码安装conda环境运行
|
||||
|
||||
**下载地址:**
|
||||
|
||||
@@ -26,9 +27,36 @@ Windows GPU版本v1.1.0(GPU):
|
||||
|
||||
- 百度网盘: <a href="https://pan.baidu.com/s/1zR6CjRztmOGBbOkqK8R1Ng?pwd=vsr1">vsr_windows_gpu_v1.1.0.zip</a> 提取码:**vsr1**
|
||||
|
||||
- Google Drive: <a href="https://drive.google.com/drive/folders/1NRgLNoHHOmdO4GxLhkPbHsYfMOB_3Elr?usp=sharing">vsr_windows_gpu_v1.1.0.zip</a>
|
||||
- Google Drive: <a href="https://drive.google.com/drive/folders/1NRgLNoHHOmdO4GxLhkPbHsYfMOB_3Elr?usp=sharing">vsr_windows_gpu_v1.1.0.zip</a>
|
||||
|
||||
> 仅供具有Nvidia显卡的用户使用(AMD的显卡不行)
|
||||
**预构建包对比说明**:
|
||||
| 预构建包名 | Python | Paddle | Torch | 环境 | 支持的计算能力范围|
|
||||
|---------------|------------|--------------|--------------|-----------------------------|----------|
|
||||
| `vsr-windows-directml.7z` | 3.12 | 3.0.0 | 2.4.1 | Windows 非Nvidia显卡 | 通用 |
|
||||
| `vsr-windows-nvidia-cuda-11.8.7z` | 3.12 | 3.0.0 | 2.7.0 | CUDA 11.8 | 3.5 – 8.9 |
|
||||
| `vsr-windows-nvidia-cuda-12.6.7z` | 3.12 | 3.0.0 | 2.7.0 | CUDA 12.6 | 5.0 – 8.9 |
|
||||
| `vsr-windows-nvidia-cuda-12.8.7z` | 3.12 | 3.0.0 | 2.7.0 | CUDA 12.8 | 5.0 – 9.0+ |
|
||||
|
||||
> NVIDIA官方提供了各GPU型号的计算能力列表,您可以参考链接: [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) 查看你的GPU适合哪个CUDA版本
|
||||
|
||||
**Docker版本:**
|
||||
```shell
|
||||
# Nvidia 10 20 30系显卡
|
||||
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda11.8
|
||||
|
||||
# Nvidia 40系显卡
|
||||
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda12.6
|
||||
|
||||
# Nvidia 50系显卡
|
||||
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda12.8
|
||||
|
||||
# AMD / Intel 独显 集显
|
||||
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-directml
|
||||
|
||||
# 演示视频, 输入
|
||||
/vsr/test/test.mp4
|
||||
docker cp vsr:/vsr/test/test_no_sub.mp4 ./
|
||||
```
|
||||
|
||||
## 演示
|
||||
|
||||
@@ -42,114 +70,98 @@ Windows GPU版本v1.1.0(GPU):
|
||||
|
||||
## 源码使用说明
|
||||
|
||||
> **无Nvidia显卡请勿使用本项目**,最低配置:
|
||||
>
|
||||
> **GPU**:GTX 1060或以上显卡
|
||||
>
|
||||
> CPU: 支持AVX指令集
|
||||
|
||||
#### 1. 下载安装Miniconda
|
||||
#### 1. 安装 Python
|
||||
|
||||
- Windows: <a href="https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Windows-x86_64.exe">Miniconda3-py38_4.11.0-Windows-x86_64.exe</a>
|
||||
请确保您已经安装了 Python 3.12+。
|
||||
|
||||
- Linux: <a href="https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Linux-x86_64.sh">Miniconda3-py38_4.11.0-Linux-x86_64.sh</a>
|
||||
- Windows 用户可以前往 [Python 官网](https://www.python.org/downloads/windows/) 下载并安装 Python。
|
||||
- MacOS 用户可以使用 Homebrew 安装:
|
||||
```shell
|
||||
brew install python@3.12
|
||||
```
|
||||
- Linux 用户可以使用包管理器安装,例如 Ubuntu/Debian:
|
||||
```shell
|
||||
sudo apt update && sudo apt install python3.12 python3.12-venv python3.12-dev
|
||||
```
|
||||
|
||||
#### 2. 创建并激活虚机环境
|
||||
#### 2. 安装依赖文件
|
||||
|
||||
(1)切换到源码所在目录:
|
||||
请使用虚拟环境来管理项目依赖,避免与系统环境冲突。
|
||||
|
||||
(1)创建虚拟环境并激活
|
||||
```shell
|
||||
python -m venv videoEnv
|
||||
```
|
||||
|
||||
- Windows:
|
||||
```shell
|
||||
videoEnv\\Scripts\\activate
|
||||
```
|
||||
- MacOS/Linux:
|
||||
```shell
|
||||
source videoEnv/bin/activate
|
||||
```
|
||||
|
||||
#### 3. 创建并激活项目目录
|
||||
|
||||
切换到源码所在目录:
|
||||
```shell
|
||||
cd <源码所在目录>
|
||||
```
|
||||
> 例如:如果你的源代码放在D盘的tools文件下,并且源代码的文件夹名为video-subtitle-remover,就输入 ```cd D:/tools/video-subtitle-remover-main```
|
||||
> 例如:如果您的源代码放在 D 盘的 tools 文件夹下,并且源代码的文件夹名为 video-subtitle-remover,则输入:
|
||||
> ```shell
|
||||
> cd D:/tools/video-subtitle-remover-main
|
||||
> ```
|
||||
|
||||
(2)创建激活conda环境
|
||||
```shell
|
||||
conda create -n videoEnv python=3.8
|
||||
```
|
||||
#### 4. 安装合适的运行环境
|
||||
|
||||
```shell
|
||||
conda activate videoEnv
|
||||
```
|
||||
本项目支持 CUDA(NVIDIA显卡加速)和 DirectML(AMD、Intel等GPU/APU加速)两种运行模式。
|
||||
|
||||
#### 3. 安装依赖文件
|
||||
##### (1) CUDA(NVIDIA 显卡用户)
|
||||
|
||||
请确保你已经安装 python 3.8+,使用conda创建项目虚拟环境并激活环境 (建议创建虚拟环境运行,以免后续出现问题)
|
||||
> 请确保您的 NVIDIA 显卡驱动支持所选 CUDA 版本。
|
||||
|
||||
- 安装CUDA和cuDNN
|
||||
- 推荐 CUDA 11.8,对应 cuDNN 8.6.0。
|
||||
|
||||
<details>
|
||||
<summary>Linux用户</summary>
|
||||
<h5>(1) 下载CUDA 11.7</h5>
|
||||
<pre><code>wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run</code></pre>
|
||||
<h5>(2) 安装CUDA 11.7</h5>
|
||||
<pre><code>sudo sh cuda_11.7.0_515.43.04_linux.run</code></pre>
|
||||
<p>1. 输入accept</p>
|
||||
<img src="https://i.328888.xyz/2023/03/31/iwVoeH.png" width="500" alt="">
|
||||
<p>2. 选中CUDA Toolkit 11.7(如果你没有安装nvidia驱动则选中Driver,如果你已经安装了nvidia驱动请不要选中driver),之后选中install,回车</p>
|
||||
<img src="https://i.328888.xyz/2023/03/31/iwVThJ.png" width="500" alt="">
|
||||
<p>3. 添加环境变量</p>
|
||||
<p>在 ~/.bashrc 加入以下内容</p>
|
||||
<pre><code># CUDA
|
||||
export PATH=/usr/local/cuda-11.7/bin${PATH:+:${PATH}}
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda-11.7/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}</code></pre>
|
||||
<p>使其生效</p>
|
||||
<pre><code>source ~/.bashrc</code></pre>
|
||||
<h5>(3) 下载cuDNN 8.4.1</h5>
|
||||
<p>国内:<a href="https://pan.baidu.com/s/1Gd_pSVzWfX1G7zCuqz6YYA">cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz</a> 提取码:57mg</p>
|
||||
<p>国外:<a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz">cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz</a></p>
|
||||
<h5>(4) 安装cuDNN 8.4.1</h5>
|
||||
<pre><code> tar -xf cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz
|
||||
mv cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive cuda
|
||||
sudo cp ./cuda/include/* /usr/local/cuda-11.7/include/
|
||||
sudo cp ./cuda/lib/* /usr/local/cuda-11.7/lib64/
|
||||
sudo chmod a+r /usr/local/cuda-11.7/lib64/*
|
||||
sudo chmod a+r /usr/local/cuda-11.7/include/*</code></pre>
|
||||
</details>
|
||||
- 安装 CUDA:
|
||||
- Windows:[CUDA 11.8 下载](https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe)
|
||||
- Linux:
|
||||
```shell
|
||||
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
|
||||
sudo sh cuda_11.8.0_520.61.05_linux.run
|
||||
```
|
||||
- MacOS 不支持 CUDA。
|
||||
|
||||
<details>
|
||||
<summary>Windows用户</summary>
|
||||
<h5>(1) 下载CUDA 11.7</h5>
|
||||
<a href="https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_516.01_windows.exe">cuda_11.7.0_516.01_windows.exe</a>
|
||||
<h5>(2) 安装CUDA 11.7</h5>
|
||||
<h5>(3) 下载cuDNN 8.2.4</h5>
|
||||
<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-windows-x64-v8.2.4.15.zip">cudnn-windows-x64-v8.2.4.15.zip</a></p>
|
||||
<h5>(4) 安装cuDNN 8.2.4</h5>
|
||||
<p>
|
||||
将cuDNN解压后的cuda文件夹中的bin, include, lib目录下的文件复制到C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\对应目录下
|
||||
</p>
|
||||
</details>
|
||||
- 安装 cuDNN(CUDA 11.8 对应 cuDNN 8.6.0):
|
||||
- [Windows cuDNN 8.6.0 下载](https://developer.download.nvidia.cn/compute/redist/cudnn/v8.6.0/local_installers/11.8/cudnn-windows-x86_64-8.6.0.163_cuda11-archive.zip)
|
||||
- [Linux cuDNN 8.6.0 下载](https://developer.download.nvidia.cn/compute/redist/cudnn/v8.6.0/local_installers/11.8/cudnn-linux-x86_64-8.6.0.163_cuda11-archive.tar.xz)
|
||||
- 安装方法请参考 NVIDIA 官方文档。
|
||||
|
||||
|
||||
- 安装GPU版本Paddlepaddle:
|
||||
|
||||
- windows:
|
||||
|
||||
```shell
|
||||
python -m pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/windows/mkl/avx/stable.html
|
||||
```
|
||||
|
||||
- Linux:
|
||||
|
||||
```shell
|
||||
python -m pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
|
||||
```
|
||||
|
||||
- 安装GPU版本Pytorch:
|
||||
|
||||
```shell
|
||||
conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia
|
||||
- 安装 PaddlePaddle GPU 版本(CUDA 11.8):
|
||||
```shell
|
||||
pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
|
||||
```
|
||||
或者使用
|
||||
```shell
|
||||
pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu117
|
||||
- 安装 Torch GPU 版本(CUDA 11.8):
|
||||
```shell
|
||||
pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
- 安装其他依赖:
|
||||
|
||||
- 安装其他依赖
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
##### (2) DirectML(AMD、Intel等GPU/APU加速卡用户)
|
||||
|
||||
- 适用于 Windows 设备的 AMD/NVIDIA/Intel GPU。
|
||||
- 安装 ONNX Runtime DirectML 版本:
|
||||
```shell
|
||||
pip install paddlepaddle==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
|
||||
pip install -r requirements.txt
|
||||
pip install torch_directml==0.2.5.dev240914
|
||||
```
|
||||
|
||||
|
||||
#### 4. 运行程序
|
||||
|
||||
@@ -166,29 +178,80 @@ python ./backend/main.py
|
||||
```
|
||||
|
||||
## 常见问题
|
||||
1. CondaHTTPError
|
||||
1. 提取速度慢怎么办
|
||||
|
||||
修改backend/config.py中的参数,可以大幅度提高去除速度
|
||||
```python
|
||||
MODE = InpaintMode.STTN # 设置为STTN算法
|
||||
STTN_SKIP_DETECTION = True # 跳过字幕检测,跳过后可能会导致要去除的字幕遗漏或者误伤不需要去除字幕的视频帧
|
||||
```
|
||||
|
||||
2. 视频去除效果不好怎么办
|
||||
|
||||
修改backend/config.py中的参数,尝试不同的去除算法,算法介绍
|
||||
|
||||
> - InpaintMode.STTN 算法:对于真人视频效果较好,速度快,可以跳过字幕检测
|
||||
> - InpaintMode.LAMA 算法:对于图片效果最好,对动画类视频效果好,速度一般,不可以跳过字幕检测
|
||||
> - InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好
|
||||
|
||||
- 使用STTN算法
|
||||
|
||||
```python
|
||||
MODE = InpaintMode.STTN # 设置为STTN算法
|
||||
# 相邻帧数, 调大会增加显存占用,效果变好
|
||||
STTN_NEIGHBOR_STRIDE = 10
|
||||
# 参考帧长度, 调大会增加显存占用,效果变好
|
||||
STTN_REFERENCE_LENGTH = 10
|
||||
# 设置STTN算法最大同时处理的帧数量,设置越大速度越慢,但效果越好
|
||||
# 要保证STTN_MAX_LOAD_NUM大于STTN_NEIGHBOR_STRIDE和STTN_REFERENCE_LENGTH
|
||||
STTN_MAX_LOAD_NUM = 30
|
||||
```
|
||||
- 使用LAMA算法
|
||||
```python
|
||||
MODE = InpaintMode.LAMA # 设置为STTN算法
|
||||
LAMA_SUPER_FAST = False # 保证效果
|
||||
```
|
||||
|
||||
> 如果对模型去字幕的效果不满意,可以查看design文件夹里面的训练方法,利用backend/tools/train里面的代码进行训练,然后将训练的模型替换旧模型即可
|
||||
|
||||
3. CondaHTTPError
|
||||
|
||||
将项目中的.condarc放在用户目录下(C:/Users/<你的用户名>),如果用户目录已经存在该文件则覆盖
|
||||
|
||||
解决方案:https://zhuanlan.zhihu.com/p/260034241
|
||||
|
||||
2. 7z文件解压错误
|
||||
4. 7z文件解压错误
|
||||
|
||||
解决方案:升级7-zip解压程序到最新版本
|
||||
|
||||
3. 4090使用cuda 11.7跑不起来
|
||||
|
||||
解决方案:改用cuda 11.8
|
||||
|
||||
## 赞助
|
||||
<img src="https://i.imgur.com/EMCP5Lv.jpeg" width="600">
|
||||
|
||||
| 捐赠者 | 累计捐赠金额 | 赞助席位 |
|
||||
| --- | --- | --- |
|
||||
| 坤V | 400.00 RMB | 金牌赞助席位 |
|
||||
| 陈凯 | 50.00 RMB | 银牌赞助席位 |
|
||||
| Tshuang | 20.00 RMB | 银牌赞助席位 |
|
||||
| 很奇异| 15.00 RMB | 银牌赞助席位 |
|
||||
| 何斐| 10.00 RMB | 银牌赞助席位 |
|
||||
| 长缨在手| 6.00 RMB | 银牌赞助席位 |
|
||||
| Leo| 1.00 RMB | 铜牌赞助席位 |
|
||||
<img src="https://github.com/YaoFANGUK/video-subtitle-extractor/raw/main/design/sponsor.png" width="600">
|
||||
|
||||
| 捐赠者 | 累计捐赠金额 | 赞助席位 |
|
||||
|---------------------------|------------| --- |
|
||||
| 坤V | 400.00 RMB | 金牌赞助席位 |
|
||||
| Jenkit | 200.00 RMB | 金牌赞助席位 |
|
||||
| 子车松兰 | 188.00 RMB | 金牌赞助席位 |
|
||||
| 落花未逝 | 100.00 RMB | 金牌赞助席位 |
|
||||
| 张音乐 | 100.00 RMB | 金牌赞助席位 |
|
||||
| 麦格 | 100.00 RMB | 金牌赞助席位 |
|
||||
| 无痕 | 100.00 RMB | 金牌赞助席位 |
|
||||
| wr | 100.00 RMB | 金牌赞助席位 |
|
||||
| 陈 | 100.00 RMB | 金牌赞助席位 |
|
||||
| lyons | 100.00 RMB | 金牌赞助席位 |
|
||||
| TalkLuv | 50.00 RMB | 银牌赞助席位 |
|
||||
| 陈凯 | 50.00 RMB | 银牌赞助席位 |
|
||||
| Freeman | 30.00 RMB | 银牌赞助席位 |
|
||||
| Tshuang | 20.00 RMB | 银牌赞助席位 |
|
||||
| 很奇异 | 15.00 RMB | 银牌赞助席位 |
|
||||
| 郭鑫 | 12.00 RMB | 银牌赞助席位 |
|
||||
| 生活不止眼前的苟且 | 10.00 RMB | 铜牌赞助席位 |
|
||||
| 何斐 | 10.00 RMB | 铜牌赞助席位 |
|
||||
| 老猫 | 8.80 RMB | 铜牌赞助席位 |
|
||||
| 伍六七 | 7.77 RMB | 铜牌赞助席位 |
|
||||
| 长缨在手 | 6.00 RMB | 铜牌赞助席位 |
|
||||
| 无忌 | 6.00 RMB | 铜牌赞助席位 |
|
||||
| Stephen | 2.00 RMB | 铜牌赞助席位 |
|
||||
| Leo | 1.00 RMB | 铜牌赞助席位 |
|
||||
|
||||
229
README_en.md
229
README_en.md
@@ -3,7 +3,7 @@
|
||||
## Project Introduction
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
Video-subtitle-remover (VSR) is an AI-based software that removes hardcoded subtitles from videos. It mainly implements the following functionalities:
|
||||
@@ -12,6 +12,7 @@ Video-subtitle-remover (VSR) is an AI-based software that removes hardcoded subt
|
||||
- Fills in the removed subtitle text area using a powerful AI algorithm model (non-adjacent pixel filling and mosaic removal).
|
||||
- Supports custom subtitle positions by only removing subtitles in the defined location (input position).
|
||||
- Supports automatic removal of all text throughout the entire video (without inputting a position).
|
||||
- Supports multi-selection of images for batch removal of watermark text.
|
||||
|
||||
<p style="text-align:center;"><img src="https://github.com/YaoFANGUK/video-subtitle-remover/raw/main/design/demo.png" alt="demo.png"/></p>
|
||||
|
||||
@@ -25,7 +26,36 @@ Windows GPU Version v1.1.0 (GPU):
|
||||
|
||||
- Google Drive: <a href="https://drive.google.com/drive/folders/1NRgLNoHHOmdO4GxLhkPbHsYfMOB_3Elr?usp=sharing">vsr_windows_gpu_v1.1.0.zip</a>
|
||||
|
||||
> For use only by users with Nvidia graphics cards (AMD graphics cards are not supported).
|
||||
|
||||
**Pre-built Package Comparison**:
|
||||
|
||||
| Pre-built Package Name | Python | Paddle | Torch | Environment | Supported Compute Capability Range |
|
||||
|----------------------------------|--------|--------|--------|-----------------------------------|------------------------------------|
|
||||
| `vse-windows-directml.7z` | 3.12 | 3.0.0 | 2.4.1 | Windows without Nvidia GPU | Universal |
|
||||
| `vse-windows-nvidia-cuda-11.8.7z`| 3.12 | 3.0.0 | 2.7.0 | CUDA 11.8 | 3.5 – 8.9 |
|
||||
| `vse-windows-nvidia-cuda-12.6.7z`| 3.12 | 3.0.0 | 2.7.0 | CUDA 12.6 | 5.0 – 8.9 |
|
||||
| `vse-windows-nvidia-cuda-12.8.7z`| 3.12 | 3.0.0 | 2.7.0 | CUDA 12.8 | 5.0 – 9.0+ |
|
||||
|
||||
> NVIDIA provides a list of supported compute capabilities for each GPU model. You can refer to the following link: [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) to check which CUDA version is compatible with your GPU.
|
||||
|
||||
**Docker Versions:**
|
||||
```shell
|
||||
# Nvidia 10, 20, 30 Series Graphics Cards
|
||||
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda11.8
|
||||
|
||||
# Nvidia 40 Series Graphics Cards
|
||||
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda12.6
|
||||
|
||||
# Nvidia 50 Series Graphics Cards
|
||||
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-cuda12.8
|
||||
|
||||
# AMD / Intel Dedicated or Integrated Graphics
|
||||
docker run -it --name vsr --gpus all eritpchy/video-subtitle-remover:1.1.1-directml
|
||||
|
||||
# Demo video, input
|
||||
/vsr/test/test.mp4
|
||||
docker cp vsr:/vsr/test/test_no_sub.mp4 ./
|
||||
```
|
||||
|
||||
## Demonstration
|
||||
|
||||
@@ -39,117 +69,98 @@ Windows GPU Version v1.1.0 (GPU):
|
||||
|
||||
## Source Code Usage Instructions
|
||||
|
||||
> **Do not use this project without an Nvidia graphics card**. The minimum requirements are:
|
||||
>
|
||||
> **GPU**: GTX 1060 or higher graphics card
|
||||
>
|
||||
> CPU: Supports AVX instruction set
|
||||
#### 1. Install Python
|
||||
|
||||
#### 1. Download and install Miniconda
|
||||
Please ensure that you have installed Python 3.12+.
|
||||
|
||||
- Windows: <a href="https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Windows-x86_64.exe">Miniconda3-py38_4.11.0-Windows-x86_64.exe</a>
|
||||
- Windows users can go to the [Python official website](https://www.python.org/downloads/windows/) to download and install Python.
|
||||
- MacOS users can install using Homebrew:
|
||||
```shell
|
||||
brew install python@3.12
|
||||
```
|
||||
- Linux users can install via the package manager, such as on Ubuntu/Debian:
|
||||
```shell
|
||||
sudo apt update && sudo apt install python3.12 python3.12-venv python3.12-dev
|
||||
```
|
||||
|
||||
- Linux: <a href="https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Linux-x86_64.sh">Miniconda3-py38_4.11.0-Linux-x86_64.sh</a>
|
||||
#### 2. Install Dependencies
|
||||
|
||||
#### 2. Create and activate a virtual environment
|
||||
It is recommended to use a virtual environment to manage project dependencies to avoid conflicts with the system environment.
|
||||
|
||||
(1) Switch to the source code directory:
|
||||
(1) Create and activate the virtual environment:
|
||||
```shell
|
||||
python -m venv videoEnv
|
||||
```
|
||||
|
||||
- Windows:
|
||||
```shell
|
||||
videoEnv\\Scripts\\activate
|
||||
```
|
||||
- MacOS/Linux:
|
||||
```shell
|
||||
source videoEnv/bin/activate
|
||||
```
|
||||
|
||||
#### 3. Create and Activate Project Directory
|
||||
|
||||
Change to the directory where your source code is located:
|
||||
```shell
|
||||
cd <source_code_directory>
|
||||
```
|
||||
> For example, if your source code is in the `tools` folder on the D drive and the folder name is `video-subtitle-remover`, use:
|
||||
> ```shell
|
||||
> cd D:/tools/video-subtitle-remover-main
|
||||
> ```
|
||||
|
||||
> For example, if your source code is in the `tools` folder on drive D, and the source code folder name is `video-subtitle-remover`, enter `cd D:/tools/video-subtitle-remover-main`.
|
||||
#### 4. Install the Appropriate Runtime Environment
|
||||
|
||||
(2) Create and activate the conda environment:
|
||||
This project supports two runtime modes: CUDA (NVIDIA GPU acceleration) and DirectML (AMD, Intel, and other GPUs/APUs).
|
||||
|
||||
```shell
|
||||
conda create -n videoEnv python=3.8
|
||||
```
|
||||
##### (1) CUDA (For NVIDIA GPU users)
|
||||
|
||||
```shell
|
||||
conda activate videoEnv
|
||||
```
|
||||
> Make sure your NVIDIA GPU driver supports the selected CUDA version.
|
||||
|
||||
#### 3. Install dependencies
|
||||
|
||||
Please make sure you have already installed Python 3.8+, use conda to create a project virtual environment and activate the environment (it is recommended to create a virtual environment to run to avoid subsequent problems).
|
||||
|
||||
- Install **CUDA** and **cuDNN**
|
||||
|
||||
<details>
|
||||
<summary>Linux</summary>
|
||||
<h5>(1) Download CUDA 11.7</h5>
|
||||
<pre><code>wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run</code></pre>
|
||||
<h5>(2) Install CUDA 11.7</h5>
|
||||
<pre><code>sudo sh cuda_11.7.0_515.43.04_linux.run</code></pre>
|
||||
<p>1. Input accept</p>
|
||||
<img src="https://i.328888.xyz/2023/03/31/iwVoeH.png" width="500" alt="">
|
||||
<p>2. make sure CUDA Toolkit 11.7 is chosen (If you have already installed driver, do not select Driver)</p>
|
||||
<img src="https://i.328888.xyz/2023/03/31/iwVThJ.png" width="500" alt="">
|
||||
<p>3. Add environment variables</p>
|
||||
<p>add the following content in <strong>~/.bashrc</strong></p>
|
||||
<pre><code># CUDA
|
||||
export PATH=/usr/local/cuda-11.7/bin${PATH:+:${PATH}}
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda-11.7/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}</code></pre>
|
||||
<p>Make sure it works</p>
|
||||
<pre><code>source ~/.bashrc</code></pre>
|
||||
<h5>(3) Download cuDNN 8.4.1</h5>
|
||||
<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz">cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz</a></p>
|
||||
<h5>(4) Install cuDNN 8.4.1</h5>
|
||||
<pre><code> tar -xf cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz
|
||||
mv cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive cuda
|
||||
sudo cp ./cuda/include/* /usr/local/cuda-11.7/include/
|
||||
sudo cp ./cuda/lib/* /usr/local/cuda-11.7/lib64/
|
||||
sudo chmod a+r /usr/local/cuda-11.7/lib64/*
|
||||
sudo chmod a+r /usr/local/cuda-11.7/include/*</code></pre>
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Windows</summary>
|
||||
<h5>(1) Download CUDA 11.7</h5>
|
||||
<a href="https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_516.01_windows.exe">cuda_11.7.0_516.01_windows.exe</a>
|
||||
<h5>(2) Install CUDA 11.7</h5>
|
||||
<h5>(3) Download cuDNN 8.2.4</h5>
|
||||
<p><a href="https://github.com/YaoFANGUK/video-subtitle-extractor/releases/download/1.0.0/cudnn-windows-x64-v8.2.4.15.zip">cudnn-windows-x64-v8.2.4.15.zip</a></p>
|
||||
<h5>(4) Install cuDNN 8.2.4</h5>
|
||||
<p>
|
||||
unzip "cudnn-windows-x64-v8.2.4.15.zip", then move all files in "bin, include, lib" in cuda
|
||||
directory to C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\
|
||||
</p>
|
||||
</details>
|
||||
|
||||
|
||||
- Install GPU version of Paddlepaddle:
|
||||
- windows:
|
||||
|
||||
```shell
|
||||
python -m pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/windows/mkl/avx/stable.html
|
||||
```
|
||||
- Recommended CUDA 11.8, corresponding to cuDNN 8.6.0.
|
||||
|
||||
- Install CUDA:
|
||||
- Windows: [Download CUDA 11.8](https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe)
|
||||
- Linux:
|
||||
```shell
|
||||
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
|
||||
sudo sh cuda_11.8.0_520.61.05_linux.run
|
||||
```
|
||||
- CUDA is not supported on MacOS.
|
||||
|
||||
```shell
|
||||
python -m pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
|
||||
```
|
||||
- Install cuDNN (CUDA 11.8 corresponds to cuDNN 8.6.0):
|
||||
- [Windows cuDNN 8.6.0 Download](https://developer.download.nvidia.cn/compute/redist/cudnn/v8.6.0/local_installers/11.8/cudnn-windows-x86_64-8.6.0.163_cuda11-archive.zip)
|
||||
- [Linux cuDNN 8.6.0 Download](https://developer.download.nvidia.cn/compute/redist/cudnn/v8.6.0/local_installers/11.8/cudnn-linux-x86_64-8.6.0.163_cuda11-archive.tar.xz)
|
||||
- Follow the installation guide in the NVIDIA official documentation.
|
||||
|
||||
- Install GPU version of Pytorch:
|
||||
|
||||
```shell
|
||||
conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia
|
||||
- Install PaddlePaddle GPU version (CUDA 11.8):
|
||||
```shell
|
||||
pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
|
||||
```
|
||||
or use
|
||||
|
||||
```shell
|
||||
pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu117
|
||||
|
||||
- Install Torch GPU version (CUDA 11.8):
|
||||
```shell
|
||||
pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
- Install other dependencies:
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
##### (2) DirectML (For AMD, Intel, and other GPU/APU users)
|
||||
|
||||
- Suitable for Windows devices with AMD/NVIDIA/Intel GPUs.
|
||||
- Install ONNX Runtime DirectML version:
|
||||
```shell
|
||||
pip install paddlepaddle==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
|
||||
pip install -r requirements.txt
|
||||
pip install -r requirements_directml.txt
|
||||
```
|
||||
|
||||
|
||||
#### 4. Run the program
|
||||
|
||||
@@ -166,12 +177,52 @@ python ./backend/main.py
|
||||
```
|
||||
|
||||
## Common Issues
|
||||
1. CondaHTTPError
|
||||
|
||||
1. How to deal with slow removal speed
|
||||
|
||||
You can greatly increase the removal speed by modifying the parameters in backend/config.py:
|
||||
|
||||
```python
|
||||
MODE = InpaintMode.STTN # Set to STTN algorithm
|
||||
STTN_SKIP_DETECTION = True # Skip subtitle detection
|
||||
```
|
||||
|
||||
2. What to do if the video removal results are not satisfactory
|
||||
|
||||
Modify the values in backend/config.py and try different removal algorithms. Here is an introduction to the algorithms:
|
||||
|
||||
> - **InpaintMode.STTN** algorithm: Good for live-action videos and fast in speed, capable of skipping subtitle detection
|
||||
> - **InpaintMode.LAMA** algorithm: Best for images and effective for animated videos, moderate speed, unable to skip subtitle detection
|
||||
> - **InpaintMode.PROPAINTER** algorithm: Consumes a significant amount of VRAM, slower in speed, works better for videos with very intense movement
|
||||
|
||||
- Using the STTN algorithm
|
||||
|
||||
```python
|
||||
MODE = InpaintMode.STTN # Set to STTN algorithm
|
||||
# Number of neighboring frames, increasing this will increase memory usage and improve the result
|
||||
STTN_NEIGHBOR_STRIDE = 10
|
||||
# Length of reference frames, increasing this will increase memory usage and improve the result
|
||||
STTN_REFERENCE_LENGTH = 10
|
||||
# Set the maximum number of frames processed simultaneously by the STTN algorithm, a larger value leads to slower processing but better results
|
||||
# Ensure that STTN_MAX_LOAD_NUM is greater than STTN_NEIGHBOR_STRIDE and STTN_REFERENCE_LENGTH
|
||||
STTN_MAX_LOAD_NUM = 30
|
||||
```
|
||||
- Using the LAMA algorithm
|
||||
|
||||
```python
|
||||
MODE = InpaintMode.LAMA # Set to LAMA algorithm
|
||||
LAMA_SUPER_FAST = False # Ensure quality
|
||||
```
|
||||
|
||||
|
||||
3. CondaHTTPError
|
||||
|
||||
Place the .condarc file from the project in the user directory (C:/Users/<your_username>). If the file already exists in the user directory, overwrite it.
|
||||
|
||||
Solution: https://zhuanlan.zhihu.com/p/260034241
|
||||
|
||||
2. 7z file extraction error
|
||||
4. 7z file extraction error
|
||||
|
||||
Solution: Upgrade the 7-zip extraction program to the latest version.
|
||||
|
||||
|
||||
|
||||
@@ -7,12 +7,20 @@ import logging
|
||||
import platform
|
||||
import stat
|
||||
from fsplit.filesplit import Filesplit
|
||||
import paddle
|
||||
import onnxruntime as ort
|
||||
|
||||
# 项目版本号
|
||||
VERSION = "1.1.1"
|
||||
# ×××××××××××××××××××× [不要改] start ××××××××××××××××××××
|
||||
paddle.disable_signal_handler()
|
||||
logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印
|
||||
logging.disable(logging.WARNING) # 关闭WARNING日志的打印
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
try:
|
||||
import torch_directml
|
||||
device = torch_directml.device(torch_directml.default_device())
|
||||
USE_DML = True
|
||||
except:
|
||||
USE_DML = False
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
LAMA_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'big-lama')
|
||||
STTN_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn', 'infer_model.pth')
|
||||
@@ -50,6 +58,27 @@ if 'ffmpeg.exe' not in os.listdir(os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64'
|
||||
# 将ffmpeg添加可执行权限
|
||||
os.chmod(FFMPEG_PATH, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
|
||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
||||
|
||||
# 是否使用ONNX(DirectML/AMD/Intel)
|
||||
ONNX_PROVIDERS = []
|
||||
available_providers = ort.get_available_providers()
|
||||
for provider in available_providers:
|
||||
if provider in [
|
||||
"CPUExecutionProvider"
|
||||
]:
|
||||
continue
|
||||
if provider not in [
|
||||
"DmlExecutionProvider", # DirectML,适用于 Windows GPU
|
||||
"ROCMExecutionProvider", # AMD ROCm
|
||||
"MIGraphXExecutionProvider", # AMD MIGraphX
|
||||
"VitisAIExecutionProvider", # AMD VitisAI,适用于 RyzenAI & Windows, 实测和DirectML性能似乎差不多
|
||||
"OpenVINOExecutionProvider", # Intel GPU
|
||||
"MetalExecutionProvider", # Apple macOS
|
||||
"CoreMLExecutionProvider", # Apple macOS
|
||||
"CUDAExecutionProvider", # Nvidia GPU
|
||||
]:
|
||||
continue
|
||||
ONNX_PROVIDERS.append(provider)
|
||||
# ×××××××××××××××××××× [不要改] end ××××××××××××××××××××
|
||||
|
||||
|
||||
@@ -64,12 +93,17 @@ class InpaintMode(Enum):
|
||||
|
||||
|
||||
# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
|
||||
# 是否使用h264编码,如果需要安卓手机分享生成的视频,请打开该选项
|
||||
USE_H264 = True
|
||||
|
||||
# ×××××××××× 通用设置 start ××××××××××
|
||||
"""
|
||||
MODE可选算法类型
|
||||
- InpaintMode.STTN 算法:对于真人视频效果较好,速度快,可以跳过字幕检测
|
||||
- InpaintMode.LAMA 算法:对于动画类视频效果好,速度一般,不可以跳过字幕检测
|
||||
- InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好
|
||||
"""
|
||||
# 【设置inpaint算法】
|
||||
# - InpaintMode.STTN 算法:对于真人视频效果较好,速度快,可以跳过字幕检测
|
||||
# - InpaintMode.LAMA 算法:对于动画类视频效果好,速度一般,不可以字幕检测
|
||||
# - InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好
|
||||
MODE = InpaintMode.STTN
|
||||
# 【设置像素点偏差】
|
||||
# 用于判断是不是非字幕区域(一般认为字幕文本框的长度是要大于宽度的,如果字幕框的高大于宽,且大于的幅度超过指定像素点大小,则认为是错误检测)
|
||||
@@ -85,17 +119,33 @@ PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差的像素点数
|
||||
|
||||
# ×××××××××× InpaintMode.STTN算法设置 start ××××××××××
|
||||
# 以下参数仅适用STTN算法时,才生效
|
||||
# 是否使用跳过检测,跳过字幕检测会省去很大时间,但是可能误伤无字幕的视频帧
|
||||
"""
|
||||
1. STTN_SKIP_DETECTION
|
||||
含义:是否使用跳过检测
|
||||
效果:设置为True跳过字幕检测,会省去很大时间,但是可能误伤无字幕的视频帧或者会导致去除的字幕漏了
|
||||
|
||||
2. STTN_NEIGHBOR_STRIDE
|
||||
含义:相邻帧数步长, 如果需要为第50帧填充缺失的区域,STTN_NEIGHBOR_STRIDE=5,那么算法会使用第45帧、第40帧等作为参照。
|
||||
效果:用于控制参考帧选择的密度,较大的步长意味着使用更少、更分散的参考帧,较小的步长意味着使用更多、更集中的参考帧。
|
||||
|
||||
3. STTN_REFERENCE_LENGTH
|
||||
含义:参数帧数量,STTN算法会查看每个待修复帧的前后若干帧来获得用于修复的上下文信息
|
||||
效果:调大会增加显存占用,处理效果变好,但是处理速度变慢
|
||||
|
||||
4. STTN_MAX_LOAD_NUM
|
||||
含义:STTN算法每次最多加载的视频帧数量
|
||||
效果:设置越大速度越慢,但效果越好
|
||||
注意:要保证STTN_MAX_LOAD_NUM大于STTN_NEIGHBOR_STRIDE和STTN_REFERENCE_LENGTH
|
||||
"""
|
||||
STTN_SKIP_DETECTION = True
|
||||
# 相邻帧数, 调大会增加显存占用,效果变好
|
||||
STTN_NEIGHBOR_STRIDE = 10
|
||||
# 参考帧长度, 调大会增加显存占用,效果变好
|
||||
# 参考帧步长
|
||||
STTN_NEIGHBOR_STRIDE = 5
|
||||
# 参考帧长度(数量)
|
||||
STTN_REFERENCE_LENGTH = 10
|
||||
# 设置STTN算法最大同时处理的帧数量,设置越大速度越慢,但效果越好
|
||||
# 要保证STTN_MAX_LOAD_NUM大于STTN_NEIGHBOR_STRIDE和STTN_REFERENCE_LENGTH
|
||||
STTN_MAX_LOAD_NUM = 30
|
||||
if STTN_MAX_LOAD_NUM < max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH):
|
||||
STTN_MAX_LOAD_NUM = max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH)
|
||||
# 设置STTN算法最大同时处理的帧数量
|
||||
STTN_MAX_LOAD_NUM = 50
|
||||
if STTN_MAX_LOAD_NUM < STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE:
|
||||
STTN_MAX_LOAD_NUM = STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE
|
||||
# ×××××××××× InpaintMode.STTN算法设置 end ××××××××××
|
||||
|
||||
# ×××××××××× InpaintMode.PROPAINTER算法设置 start ××××××××××
|
||||
|
||||
@@ -27,7 +27,7 @@ class STTNInpaint:
|
||||
# 1. 创建InpaintGenerator模型实例并装载到选择的设备上
|
||||
self.model = InpaintGenerator().to(self.device)
|
||||
# 2. 载入预训练模型的权重,转载模型的状态字典
|
||||
self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location=self.device)['netG'])
|
||||
self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location='cpu')['netG'])
|
||||
# 3. # 将模型设置为评估模式
|
||||
self.model.eval()
|
||||
# 模型输入用的宽和高
|
||||
@@ -93,6 +93,7 @@ class STTNInpaint:
|
||||
@staticmethod
|
||||
def read_mask(path):
|
||||
img = cv2.imread(path, 0)
|
||||
# 转为binary mask
|
||||
ret, img = cv2.threshold(img, 127, 1, cv2.THRESH_BINARY)
|
||||
img = img[:, :, None]
|
||||
return img
|
||||
@@ -200,6 +201,24 @@ class STTNInpaint:
|
||||
to_H -= h
|
||||
return inpaint_area # 返回绘画区域列表
|
||||
|
||||
@staticmethod
|
||||
def get_inpaint_area_by_selection(input_sub_area, mask):
|
||||
print('use selection area for inpainting')
|
||||
height, width = mask.shape[:2]
|
||||
ymin, ymax, _, _ = input_sub_area
|
||||
interval_size = 135
|
||||
# 存储结果的列表
|
||||
inpaint_area = []
|
||||
# 计算并存储标准区间
|
||||
for i in range(ymin, ymax, interval_size):
|
||||
inpaint_area.append((i, i + interval_size))
|
||||
# 检查最后一个区间是否达到了最大值
|
||||
if inpaint_area[-1][1] != ymax:
|
||||
# 如果没有,则创建一个新的区间,开始于最后一个区间的结束,结束于扩大后的值
|
||||
if inpaint_area[-1][1] + interval_size <= height:
|
||||
inpaint_area.append((inpaint_area[-1][1], inpaint_area[-1][1] + interval_size))
|
||||
return inpaint_area # 返回绘画区域列表
|
||||
|
||||
|
||||
class STTNVideoInpaint:
|
||||
|
||||
@@ -234,70 +253,106 @@ class STTNVideoInpaint:
|
||||
self.clip_gap = clip_gap
|
||||
|
||||
def __call__(self, input_mask=None, input_sub_remover=None, tbar=None):
|
||||
# 读取视频帧信息
|
||||
reader, frame_info = self.read_frame_info_from_video()
|
||||
if input_sub_remover is not None:
|
||||
writer = input_sub_remover.video_writer
|
||||
else:
|
||||
# 创建视频写入对象,用于输出修复后的视频
|
||||
writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori']))
|
||||
# 计算需要迭代修复视频的次数
|
||||
rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1
|
||||
# 计算分割高度,用于确定修复区域的大小
|
||||
split_h = int(frame_info['W_ori'] * 3 / 16)
|
||||
if input_mask is None:
|
||||
# 读取掩码
|
||||
mask = self.sttn_inpaint.read_mask(self.mask_path)
|
||||
else:
|
||||
_, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY)
|
||||
mask = mask[:, :, None]
|
||||
# 得到修复区域位置
|
||||
inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
|
||||
# 遍历每一次的迭代次数
|
||||
for i in range(rec_time):
|
||||
start_f = i * self.clip_gap # 起始帧位置
|
||||
end_f = min((i + 1) * self.clip_gap, frame_info['len']) # 结束帧位置
|
||||
print('Processing:', start_f + 1, '-', end_f, ' / Total:', frame_info['len'])
|
||||
frames_hr = [] # 高分辨率帧列表
|
||||
frames = {} # 帧字典,用于存储裁剪后的图像
|
||||
comps = {} # 组合字典,用于存储修复后的图像
|
||||
# 初始化帧字典
|
||||
for k in range(len(inpaint_area)):
|
||||
frames[k] = []
|
||||
# 读取和修复高分辨率帧
|
||||
for j in range(start_f, end_f):
|
||||
success, image = reader.read()
|
||||
frames_hr.append(image)
|
||||
reader = None
|
||||
writer = None
|
||||
try:
|
||||
# 读取视频帧信息
|
||||
reader, frame_info = self.read_frame_info_from_video()
|
||||
if input_sub_remover is not None:
|
||||
writer = input_sub_remover.video_writer
|
||||
else:
|
||||
# 创建视频写入对象,用于输出修复后的视频
|
||||
writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori']))
|
||||
|
||||
# 计算需要迭代修复视频的次数
|
||||
rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1
|
||||
# 计算分割高度,用于确定修复区域的大小
|
||||
split_h = int(frame_info['W_ori'] * 3 / 16)
|
||||
|
||||
if input_mask is None:
|
||||
# 读取掩码
|
||||
mask = self.sttn_inpaint.read_mask(self.mask_path)
|
||||
else:
|
||||
_, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY)
|
||||
mask = mask[:, :, None]
|
||||
|
||||
# 得到修复区域位置
|
||||
inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
|
||||
|
||||
# 遍历每一次的迭代次数
|
||||
for i in range(rec_time):
|
||||
start_f = i * self.clip_gap # 起始帧位置
|
||||
end_f = min((i + 1) * self.clip_gap, frame_info['len']) # 结束帧位置
|
||||
print('Processing:', start_f + 1, '-', end_f, ' / Total:', frame_info['len'])
|
||||
|
||||
frames_hr = [] # 高分辨率帧列表
|
||||
frames = {} # 帧字典,用于存储裁剪后的图像
|
||||
comps = {} # 组合字典,用于存储修复后的图像
|
||||
|
||||
# 初始化帧字典
|
||||
for k in range(len(inpaint_area)):
|
||||
# 裁剪、缩放并添加到帧字典
|
||||
image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :]
|
||||
image_resize = cv2.resize(image_crop, (self.sttn_inpaint.model_input_width, self.sttn_inpaint.model_input_height))
|
||||
frames[k].append(image_resize)
|
||||
# 对每个修复区域运行修复
|
||||
for k in range(len(inpaint_area)):
|
||||
comps[k] = self.sttn_inpaint.inpaint(frames[k])
|
||||
# 如果有要修复的区域
|
||||
if inpaint_area is not []:
|
||||
for j in range(end_f - start_f):
|
||||
if input_sub_remover is not None and input_sub_remover.gui_mode:
|
||||
original_frame = copy.deepcopy(frames_hr[j])
|
||||
else:
|
||||
original_frame = None
|
||||
frame = frames_hr[j]
|
||||
frames[k] = []
|
||||
|
||||
# 读取和修复高分辨率帧
|
||||
valid_frames_count = 0
|
||||
for j in range(start_f, end_f):
|
||||
success, image = reader.read()
|
||||
if not success:
|
||||
print(f"Warning: Failed to read frame {j}.")
|
||||
break
|
||||
|
||||
frames_hr.append(image)
|
||||
valid_frames_count += 1
|
||||
|
||||
for k in range(len(inpaint_area)):
|
||||
# 将修复的图像重新扩展到原始分辨率,并融合到原始帧
|
||||
comp = cv2.resize(comps[k][j], (frame_info['W_ori'], split_h))
|
||||
comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB)
|
||||
mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :]
|
||||
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
|
||||
writer.write(frame)
|
||||
if input_sub_remover is not None and input_sub_remover.gui_mode:
|
||||
if tbar is not None:
|
||||
input_sub_remover.update_progress(tbar, increment=1)
|
||||
if original_frame is not None:
|
||||
input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame])
|
||||
# 释放视频写入对象
|
||||
writer.release()
|
||||
# 裁剪、缩放并添加到帧字典
|
||||
image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :]
|
||||
image_resize = cv2.resize(image_crop, (self.sttn_inpaint.model_input_width, self.sttn_inpaint.model_input_height))
|
||||
frames[k].append(image_resize)
|
||||
|
||||
# 如果没有读取到有效帧,则跳过当前迭代
|
||||
if valid_frames_count == 0:
|
||||
print(f"Warning: No valid frames found in range {start_f+1}-{end_f}. Skipping this segment.")
|
||||
continue
|
||||
|
||||
# 对每个修复区域运行修复
|
||||
for k in range(len(inpaint_area)):
|
||||
if len(frames[k]) > 0: # 确保有帧可以处理
|
||||
comps[k] = self.sttn_inpaint.inpaint(frames[k])
|
||||
else:
|
||||
comps[k] = []
|
||||
|
||||
# 如果有要修复的区域
|
||||
if inpaint_area and valid_frames_count > 0:
|
||||
for j in range(valid_frames_count):
|
||||
if input_sub_remover is not None and input_sub_remover.gui_mode:
|
||||
original_frame = copy.deepcopy(frames_hr[j])
|
||||
else:
|
||||
original_frame = None
|
||||
|
||||
frame = frames_hr[j]
|
||||
|
||||
for k in range(len(inpaint_area)):
|
||||
if j < len(comps[k]): # 确保索引有效
|
||||
# 将修复的图像重新扩展到原始分辨率,并融合到原始帧
|
||||
comp = cv2.resize(comps[k][j], (frame_info['W_ori'], split_h))
|
||||
comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB)
|
||||
mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :]
|
||||
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
|
||||
|
||||
writer.write(frame)
|
||||
|
||||
if input_sub_remover is not None:
|
||||
if tbar is not None:
|
||||
input_sub_remover.update_progress(tbar, increment=1)
|
||||
if original_frame is not None and input_sub_remover.gui_mode:
|
||||
input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame])
|
||||
except Exception as e:
|
||||
print(f"Error during video processing: {str(e)}")
|
||||
# 不抛出异常,允许程序继续执行
|
||||
finally:
|
||||
if writer:
|
||||
writer.release()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -53,8 +53,16 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None
|
||||
return logger
|
||||
|
||||
|
||||
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
|
||||
torch.__version__)[0][:3])] >= [1, 12, 0]
|
||||
def get_version_numbers(version_str):
|
||||
# 匹配主要版本号(支持 2.8.0 或 2.8.0.dev20250422+cu128 等格式)
|
||||
pattern = r"^(\d+)\.(\d+)\.(\d+)"
|
||||
match = re.match(pattern, version_str)
|
||||
if match:
|
||||
return [int(x) for x in match.groups()]
|
||||
return [0, 0, 0] # 如果无法匹配,返回默认值
|
||||
|
||||
# 使用示例
|
||||
IS_HIGH_VERSION = get_version_numbers(torch.__version__) >= [1, 12, 0]
|
||||
|
||||
|
||||
def gpu_is_available():
|
||||
|
||||
163
backend/main.py
163
backend/main.py
@@ -1,3 +1,4 @@
|
||||
import torch
|
||||
import shutil
|
||||
import subprocess
|
||||
import os
|
||||
@@ -5,9 +6,12 @@ from pathlib import Path
|
||||
import threading
|
||||
import cv2
|
||||
import sys
|
||||
from functools import cached_property
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
import config
|
||||
from backend.tools.common_tools import is_video_or_image, is_image_file
|
||||
from backend.scenedetect import scene_detect
|
||||
from backend.scenedetect.detectors import ContentDetector
|
||||
from backend.inpaint.sttn_inpaint import STTNInpaint, STTNVideoInpaint
|
||||
@@ -17,13 +21,10 @@ from backend.tools.inpaint_tools import create_mask, batch_generator
|
||||
import importlib
|
||||
import platform
|
||||
import tempfile
|
||||
import torch
|
||||
import multiprocessing
|
||||
from shapely.geometry import Polygon
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
from tools.infer import utility
|
||||
from tools.infer.predict_det import TextDetector
|
||||
|
||||
|
||||
class SubtitleDetect:
|
||||
@@ -32,14 +33,23 @@ class SubtitleDetect:
|
||||
"""
|
||||
|
||||
def __init__(self, video_path, sub_area=None):
|
||||
self.video_path = video_path
|
||||
self.sub_area = sub_area
|
||||
|
||||
@cached_property
|
||||
def text_detector(self):
|
||||
import paddle
|
||||
paddle.disable_signal_handler()
|
||||
from paddleocr.tools.infer import utility
|
||||
from paddleocr.tools.infer.predict_det import TextDetector
|
||||
# 获取参数对象
|
||||
importlib.reload(config)
|
||||
args = utility.parse_args()
|
||||
args.det_algorithm = 'DB'
|
||||
args.det_model_dir = config.DET_MODEL_PATH
|
||||
self.text_detector = TextDetector(args)
|
||||
self.video_path = video_path
|
||||
self.sub_area = sub_area
|
||||
args.det_model_dir = self.convertToOnnxModelIfNeeded(config.DET_MODEL_PATH)
|
||||
args.use_onnx=len(config.ONNX_PROVIDERS) > 0
|
||||
args.onnx_providers=config.ONNX_PROVIDERS
|
||||
return TextDetector(args)
|
||||
|
||||
def detect_subtitle(self, img):
|
||||
dt_boxes, elapse = self.text_detector(img)
|
||||
@@ -119,6 +129,52 @@ class SubtitleDetect:
|
||||
new_subtitle_frame_no_box_dict[key] = subtitle_frame_no_box_dict[key]
|
||||
return new_subtitle_frame_no_box_dict
|
||||
|
||||
def convertToOnnxModelIfNeeded(self, model_dir, model_filename="inference.pdmodel", params_filename="inference.pdiparams", opset_version=14):
|
||||
"""Converts a Paddle model to ONNX if ONNX providers are available and the model does not already exist."""
|
||||
|
||||
if not config.ONNX_PROVIDERS:
|
||||
return model_dir
|
||||
|
||||
onnx_model_path = os.path.join(model_dir, "model.onnx")
|
||||
|
||||
if os.path.exists(onnx_model_path):
|
||||
print(f"ONNX model already exists: {onnx_model_path}. Skipping conversion.")
|
||||
return onnx_model_path
|
||||
|
||||
print(f"Converting Paddle model {model_dir} to ONNX...")
|
||||
model_file = os.path.join(model_dir, model_filename)
|
||||
params_file = os.path.join(model_dir, params_filename) if params_filename else ""
|
||||
|
||||
try:
|
||||
import paddle2onnx
|
||||
# Ensure the target directory exists
|
||||
os.makedirs(os.path.dirname(onnx_model_path), exist_ok=True)
|
||||
|
||||
# Convert and save the model
|
||||
onnx_model = paddle2onnx.export(
|
||||
model_filename=model_file,
|
||||
params_filename=params_file,
|
||||
save_file=onnx_model_path,
|
||||
opset_version=opset_version,
|
||||
auto_upgrade_opset=True,
|
||||
verbose=True,
|
||||
enable_onnx_checker=True,
|
||||
enable_experimental_op=True,
|
||||
enable_optimize=True,
|
||||
custom_op_info={},
|
||||
deploy_backend="onnxruntime",
|
||||
calibration_file="calibration.cache",
|
||||
external_file=os.path.join(model_dir, "external_data"),
|
||||
export_fp16_model=False,
|
||||
)
|
||||
|
||||
print(f"Conversion successful. ONNX model saved to: {onnx_model_path}")
|
||||
return onnx_model_path
|
||||
except Exception as e:
|
||||
print(f"Error during conversion: {e}")
|
||||
return model_dir
|
||||
|
||||
|
||||
@staticmethod
|
||||
def split_range_by_scene(intervals, points):
|
||||
# 确保离散值列表是有序的
|
||||
@@ -257,7 +313,43 @@ class SubtitleDetect:
|
||||
return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]])
|
||||
|
||||
@staticmethod
|
||||
def expand_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
|
||||
def expand_and_merge_intervals(intervals, expand_size=config.STTN_NEIGHBOR_STRIDE*config.STTN_REFERENCE_LENGTH, max_length=config.STTN_MAX_LOAD_NUM):
|
||||
# 初始化输出区间列表
|
||||
expanded_intervals = []
|
||||
|
||||
# 对每个原始区间进行扩展
|
||||
for interval in intervals:
|
||||
start, end = interval
|
||||
|
||||
# 扩展至至少 'expand_size' 个单位,但不超过 'max_length' 个单位
|
||||
expansion_amount = max(expand_size - (end - start + 1), 0)
|
||||
|
||||
# 在保证包含原区间的前提下尽可能平分前后扩展量
|
||||
expand_start = max(start - expansion_amount // 2, 1) # 确保起始点不小于1
|
||||
expand_end = end + expansion_amount // 2
|
||||
|
||||
# 如果扩展后的区间超出了最大长度,进行调整
|
||||
if (expand_end - expand_start + 1) > max_length:
|
||||
expand_end = expand_start + max_length - 1
|
||||
|
||||
# 对于单点的处理,需额外保证有至少 'expand_size' 长度
|
||||
if start == end:
|
||||
if expand_end - expand_start + 1 < expand_size:
|
||||
expand_end = expand_start + expand_size - 1
|
||||
|
||||
# 检查与前一个区间是否有重叠并进行相应的合并
|
||||
if expanded_intervals and expand_start <= expanded_intervals[-1][1]:
|
||||
previous_start, previous_end = expanded_intervals.pop()
|
||||
expand_start = previous_start
|
||||
expand_end = max(expand_end, previous_end)
|
||||
|
||||
# 添加扩展后的区间至结果列表
|
||||
expanded_intervals.append((expand_start, expand_end))
|
||||
|
||||
return expanded_intervals
|
||||
|
||||
@staticmethod
|
||||
def filter_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
|
||||
"""
|
||||
合并传入的字幕起始区间,确保区间大小最低为STTN_REFERENCE_LENGTH
|
||||
"""
|
||||
@@ -290,12 +382,10 @@ class SubtitleDetect:
|
||||
for start, end in expanded[1:]:
|
||||
last_start, last_end = merged[-1]
|
||||
# 检查是否重叠
|
||||
if start <= last_end and (
|
||||
end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
|
||||
if start <= last_end and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
|
||||
# 需要合并
|
||||
merged[-1] = (last_start, max(last_end, end)) # 合并区间
|
||||
elif start == last_end + 1 and (
|
||||
end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
|
||||
elif start == last_end + 1 and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
|
||||
# 相邻区间也需要合并的场景
|
||||
merged[-1] = (last_start, end)
|
||||
else:
|
||||
@@ -483,7 +573,7 @@ class SubtitleRemover:
|
||||
self.gui_mode = gui_mode
|
||||
# 判断是否为图片
|
||||
self.is_picture = False
|
||||
if str(vd_path).endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
|
||||
if is_image_file(str(vd_path)):
|
||||
self.sub_area = None
|
||||
self.is_picture = True
|
||||
# 视频路径
|
||||
@@ -505,8 +595,7 @@ class SubtitleRemover:
|
||||
# 创建视频临时对象,windows下delete=True会有permission denied的报错
|
||||
self.video_temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
|
||||
# 创建视频写对象
|
||||
self.video_writer = cv2.VideoWriter(self.video_temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps,
|
||||
self.size)
|
||||
self.video_writer = cv2.VideoWriter(self.video_temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size)
|
||||
self.video_out_name = os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4')
|
||||
self.video_inpaint = None
|
||||
self.lama_inpaint = None
|
||||
@@ -518,6 +607,14 @@ class SubtitleRemover:
|
||||
self.video_out_name = os.path.join(pic_dir, f'{self.vd_name}{self.ext}')
|
||||
if torch.cuda.is_available():
|
||||
print('use GPU for acceleration')
|
||||
if config.USE_DML:
|
||||
print('use DirectML for acceleration')
|
||||
if config.MODE != config.InpaintMode.STTN:
|
||||
print('Warning: DirectML acceleration is only available for STTN model. Falling back to CPU for other models.')
|
||||
for provider in config.ONNX_PROVIDERS:
|
||||
print(f"Detected execution provider: {provider}")
|
||||
|
||||
|
||||
# 总处理进度
|
||||
self.progress_total = 0
|
||||
self.progress_remover = 0
|
||||
@@ -647,16 +744,14 @@ class SubtitleRemover:
|
||||
self.lama_inpaint = LamaInpaint()
|
||||
inpainted_frame = self.lama_inpaint(frame, single_mask)
|
||||
self.video_writer.write(inpainted_frame)
|
||||
print(
|
||||
f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
|
||||
print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
|
||||
inner_index += 1
|
||||
self.update_progress(tbar, increment=1)
|
||||
elif len(batch) > 1:
|
||||
inpainted_frames = self.video_inpaint.inpaint(batch, mask)
|
||||
for i, inpainted_frame in enumerate(inpainted_frames):
|
||||
self.video_writer.write(inpainted_frame)
|
||||
print(
|
||||
f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
|
||||
print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
|
||||
inner_index += 1
|
||||
if self.gui_mode:
|
||||
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
|
||||
@@ -670,12 +765,13 @@ class SubtitleRemover:
|
||||
print('[Processing] start removing subtitles...')
|
||||
if self.sub_area is not None:
|
||||
ymin, ymax, xmin, xmax = self.sub_area
|
||||
mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
|
||||
mask = create_mask(self.mask_size, mask_area_coordinates)
|
||||
sttn_video_inpaint = STTNVideoInpaint(self.video_path)
|
||||
sttn_video_inpaint(input_mask=mask, input_sub_remover=self, tbar=tbar)
|
||||
else:
|
||||
print('please set subtitle area first')
|
||||
print('[Info] No subtitle area has been set. Video will be processed in full screen. As a result, the final outcome might be suboptimal.')
|
||||
ymin, ymax, xmin, xmax = 0, self.frame_height, 0, self.frame_width
|
||||
mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
|
||||
mask = create_mask(self.mask_size, mask_area_coordinates)
|
||||
sttn_video_inpaint = STTNVideoInpaint(self.video_path)
|
||||
sttn_video_inpaint(input_mask=mask, input_sub_remover=self, tbar=tbar)
|
||||
|
||||
def sttn_mode(self, tbar):
|
||||
# 是否跳过字幕帧寻找
|
||||
@@ -688,7 +784,7 @@ class SubtitleRemover:
|
||||
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
|
||||
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
|
||||
print(continuous_frame_no_list)
|
||||
continuous_frame_no_list = self.sub_detector.expand_and_merge_intervals(continuous_frame_no_list)
|
||||
continuous_frame_no_list = self.sub_detector.filter_and_merge_intervals(continuous_frame_no_list)
|
||||
print(continuous_frame_no_list)
|
||||
start_end_map = dict()
|
||||
for interval in continuous_frame_no_list:
|
||||
@@ -847,7 +943,7 @@ class SubtitleRemover:
|
||||
audio_merge_command = [config.FFMPEG_PATH,
|
||||
"-y", "-i", self.video_temp_file.name,
|
||||
"-i", temp.name,
|
||||
"-vcodec", "copy",
|
||||
"-vcodec", "libx264" if config.USE_H264 else "copy",
|
||||
"-acodec", "copy",
|
||||
"-loglevel", "error", self.video_out_name]
|
||||
try:
|
||||
@@ -859,7 +955,10 @@ class SubtitleRemover:
|
||||
try:
|
||||
os.remove(temp.name)
|
||||
except Exception:
|
||||
print(f'failed to delete temp file {temp.name}')
|
||||
if platform.system() in ['Windows']:
|
||||
pass
|
||||
else:
|
||||
print(f'failed to delete temp file {temp.name}')
|
||||
self.is_successful_merged = True
|
||||
finally:
|
||||
temp.close()
|
||||
@@ -874,9 +973,13 @@ class SubtitleRemover:
|
||||
if __name__ == '__main__':
|
||||
multiprocessing.set_start_method("spawn")
|
||||
# 1. 提示用户输入视频路径
|
||||
video_path = input(f"Please input video file path: ").strip()
|
||||
video_path = input(f"Please input video or image file path: ").strip()
|
||||
# 判断视频路径是不是一个目录,是目录的化,批量处理改目录下的所有视频文件
|
||||
# 2. 按以下顺序传入字幕区域
|
||||
# sub_area = (ymin, ymax, xmin, xmax)
|
||||
# 3. 新建字幕提取对象
|
||||
sd = SubtitleRemover(video_path, sub_area=None)
|
||||
sd.run()
|
||||
if is_video_or_image(video_path):
|
||||
sd = SubtitleRemover(video_path, sub_area=None)
|
||||
sd.run()
|
||||
else:
|
||||
print(f'Invalid video path: {video_path}')
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", category=Warning)
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
@@ -1,109 +0,0 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import skimage
|
||||
import paddle
|
||||
import signal
|
||||
import random
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import copy
|
||||
from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
|
||||
import paddle.distributed as dist
|
||||
|
||||
from ppocr.data.imaug import transform, create_operators
|
||||
from ppocr.data.simple_dataset import SimpleDataSet
|
||||
from ppocr.data.lmdb_dataset import LMDBDataSet
|
||||
from ppocr.data.pgnet_dataset import PGDataSet
|
||||
from ppocr.data.pubtab_dataset import PubTabDataSet
|
||||
|
||||
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
||||
|
||||
|
||||
def term_mp(sig_num, frame):
|
||||
""" kill all child processes
|
||||
"""
|
||||
pid = os.getpid()
|
||||
pgid = os.getpgid(os.getpid())
|
||||
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
|
||||
|
||||
def build_dataloader(config, mode, device, logger, seed=None):
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
support_dict = [
|
||||
'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet'
|
||||
]
|
||||
module_name = config[mode]['dataset']['name']
|
||||
assert module_name in support_dict, Exception(
|
||||
'DataSet only support {}'.format(support_dict))
|
||||
assert mode in ['Train', 'Eval', 'Test'
|
||||
], "Mode should be Train, Eval or Test."
|
||||
|
||||
dataset = eval(module_name)(config, mode, logger, seed)
|
||||
loader_config = config[mode]['loader']
|
||||
batch_size = loader_config['batch_size_per_card']
|
||||
drop_last = loader_config['drop_last']
|
||||
shuffle = loader_config['shuffle']
|
||||
num_workers = loader_config['num_workers']
|
||||
if 'use_shared_memory' in loader_config.keys():
|
||||
use_shared_memory = loader_config['use_shared_memory']
|
||||
else:
|
||||
use_shared_memory = True
|
||||
|
||||
if mode == "Train":
|
||||
# Distribute data to multiple cards
|
||||
batch_sampler = DistributedBatchSampler(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last)
|
||||
else:
|
||||
# Distribute data to single card
|
||||
batch_sampler = BatchSampler(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last)
|
||||
|
||||
if 'collate_fn' in loader_config:
|
||||
from . import collate_fn
|
||||
collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
|
||||
else:
|
||||
collate_fn = None
|
||||
data_loader = DataLoader(
|
||||
dataset=dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
places=device,
|
||||
num_workers=num_workers,
|
||||
return_list=True,
|
||||
use_shared_memory=use_shared_memory,
|
||||
collate_fn=collate_fn)
|
||||
|
||||
# support exit using ctrl+c
|
||||
signal.signal(signal.SIGINT, term_mp)
|
||||
signal.signal(signal.SIGTERM, term_mp)
|
||||
|
||||
return data_loader
|
||||
@@ -1,72 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import numbers
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class DictCollator(object):
|
||||
"""
|
||||
data batch
|
||||
"""
|
||||
|
||||
def __call__(self, batch):
|
||||
# todo:support batch operators
|
||||
data_dict = defaultdict(list)
|
||||
to_tensor_keys = []
|
||||
for sample in batch:
|
||||
for k, v in sample.items():
|
||||
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
|
||||
if k not in to_tensor_keys:
|
||||
to_tensor_keys.append(k)
|
||||
data_dict[k].append(v)
|
||||
for k in to_tensor_keys:
|
||||
data_dict[k] = paddle.to_tensor(data_dict[k])
|
||||
return data_dict
|
||||
|
||||
|
||||
class ListCollator(object):
|
||||
"""
|
||||
data batch
|
||||
"""
|
||||
|
||||
def __call__(self, batch):
|
||||
# todo:support batch operators
|
||||
data_dict = defaultdict(list)
|
||||
to_tensor_idxs = []
|
||||
for sample in batch:
|
||||
for idx, v in enumerate(sample):
|
||||
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
|
||||
if idx not in to_tensor_idxs:
|
||||
to_tensor_idxs.append(idx)
|
||||
data_dict[idx].append(v)
|
||||
for idx in to_tensor_idxs:
|
||||
data_dict[idx] = paddle.to_tensor(data_dict[idx])
|
||||
return list(data_dict.values())
|
||||
|
||||
|
||||
class SSLRotateCollate(object):
|
||||
"""
|
||||
bach: [
|
||||
[(4*3xH*W), (4,)]
|
||||
[(4*3xH*W), (4,)]
|
||||
...
|
||||
]
|
||||
"""
|
||||
|
||||
def __call__(self, batch):
|
||||
output = [np.concatenate(d, axis=0) for d in zip(*batch)]
|
||||
return output
|
||||
@@ -1,26 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from paddle.vision.transforms import ColorJitter as pp_ColorJitter
|
||||
|
||||
__all__ = ['ColorJitter']
|
||||
|
||||
class ColorJitter(object):
|
||||
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0,**kwargs):
|
||||
self.aug = pp_ColorJitter(brightness, contrast, saturation, hue)
|
||||
|
||||
def __call__(self, data):
|
||||
image = data['image']
|
||||
image = self.aug(image)
|
||||
data['image'] = image
|
||||
return data
|
||||
@@ -1,74 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from .iaa_augment import IaaAugment
|
||||
from .make_border_map import MakeBorderMap
|
||||
from .make_shrink_map import MakeShrinkMap
|
||||
from .random_crop_data import EastRandomCropData, RandomCropImgMask
|
||||
from .make_pse_gt import MakePseGt
|
||||
|
||||
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
||||
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
|
||||
from .ssl_img_aug import SSLRotateResize
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
from .ColorJitter import ColorJitter
|
||||
from .operators import *
|
||||
from .label_ops import *
|
||||
|
||||
from .east_process import *
|
||||
from .sast_process import *
|
||||
from .pg_process import *
|
||||
from .gen_table_mask import *
|
||||
|
||||
from .vqa import *
|
||||
|
||||
from .fce_aug import *
|
||||
from .fce_targets import FCENetTargets
|
||||
|
||||
|
||||
def transform(data, ops=None):
|
||||
""" transform """
|
||||
if ops is None:
|
||||
ops = []
|
||||
for op in ops:
|
||||
data = op(data)
|
||||
if data is None:
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
def create_operators(op_param_list, global_config=None):
|
||||
"""
|
||||
create operators based on the config
|
||||
|
||||
Args:
|
||||
params(list): a dict list, used to create some operators
|
||||
"""
|
||||
assert isinstance(op_param_list, list), ('operator config should be a list')
|
||||
ops = []
|
||||
for operator in op_param_list:
|
||||
assert isinstance(operator,
|
||||
dict) and len(operator) == 1, "yaml format error"
|
||||
op_name = list(operator)[0]
|
||||
param = {} if operator[op_name] is None else operator[op_name]
|
||||
if global_config is not None:
|
||||
param.update(global_config)
|
||||
op = eval(op_name)(**param)
|
||||
ops.append(op)
|
||||
return ops
|
||||
@@ -1,170 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import cv2
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
from ppocr.data.imaug.iaa_augment import IaaAugment
|
||||
from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
|
||||
from tools.infer.utility import get_rotate_crop_image
|
||||
|
||||
|
||||
class CopyPaste(object):
|
||||
def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
|
||||
self.ext_data_num = 1
|
||||
self.objects_paste_ratio = objects_paste_ratio
|
||||
self.limit_paste = limit_paste
|
||||
augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
|
||||
self.aug = IaaAugment(augmenter_args)
|
||||
|
||||
def __call__(self, data):
|
||||
point_num = data['polys'].shape[1]
|
||||
src_img = data['image']
|
||||
src_polys = data['polys'].tolist()
|
||||
src_ignores = data['ignore_tags'].tolist()
|
||||
ext_data = data['ext_data'][0]
|
||||
ext_image = ext_data['image']
|
||||
ext_polys = ext_data['polys']
|
||||
ext_ignores = ext_data['ignore_tags']
|
||||
|
||||
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
|
||||
select_num = max(
|
||||
1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
|
||||
|
||||
random.shuffle(indexs)
|
||||
select_idxs = indexs[:select_num]
|
||||
select_polys = ext_polys[select_idxs]
|
||||
select_ignores = ext_ignores[select_idxs]
|
||||
|
||||
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
|
||||
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
|
||||
src_img = Image.fromarray(src_img).convert('RGBA')
|
||||
for poly, tag in zip(select_polys, select_ignores):
|
||||
box_img = get_rotate_crop_image(ext_image, poly)
|
||||
|
||||
src_img, box = self.paste_img(src_img, box_img, src_polys)
|
||||
if box is not None:
|
||||
box = box.tolist()
|
||||
for _ in range(len(box), point_num):
|
||||
box.append(box[-1])
|
||||
src_polys.append(box)
|
||||
src_ignores.append(tag)
|
||||
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
|
||||
h, w = src_img.shape[:2]
|
||||
src_polys = np.array(src_polys)
|
||||
src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
|
||||
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
|
||||
data['image'] = src_img
|
||||
data['polys'] = src_polys
|
||||
data['ignore_tags'] = np.array(src_ignores)
|
||||
return data
|
||||
|
||||
def paste_img(self, src_img, box_img, src_polys):
|
||||
box_img_pil = Image.fromarray(box_img).convert('RGBA')
|
||||
src_w, src_h = src_img.size
|
||||
box_w, box_h = box_img_pil.size
|
||||
|
||||
angle = np.random.randint(0, 360)
|
||||
box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
|
||||
box = rotate_bbox(box_img, box, angle)[0]
|
||||
box_img_pil = box_img_pil.rotate(angle, expand=1)
|
||||
box_w, box_h = box_img_pil.width, box_img_pil.height
|
||||
if src_w - box_w < 0 or src_h - box_h < 0:
|
||||
return src_img, None
|
||||
|
||||
paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
|
||||
src_h - box_h)
|
||||
if paste_x is None:
|
||||
return src_img, None
|
||||
box[:, 0] += paste_x
|
||||
box[:, 1] += paste_y
|
||||
r, g, b, A = box_img_pil.split()
|
||||
src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
|
||||
|
||||
return src_img, box
|
||||
|
||||
def select_coord(self, src_polys, box, endx, endy):
|
||||
if self.limit_paste:
|
||||
xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min(
|
||||
), box[:, 0].max(), box[:, 1].max()
|
||||
for _ in range(50):
|
||||
paste_x = random.randint(0, endx)
|
||||
paste_y = random.randint(0, endy)
|
||||
xmin1 = xmin + paste_x
|
||||
xmax1 = xmax + paste_x
|
||||
ymin1 = ymin + paste_y
|
||||
ymax1 = ymax + paste_y
|
||||
|
||||
num_poly_in_rect = 0
|
||||
for poly in src_polys:
|
||||
if not is_poly_outside_rect(poly, xmin1, ymin1,
|
||||
xmax1 - xmin1, ymax1 - ymin1):
|
||||
num_poly_in_rect += 1
|
||||
break
|
||||
if num_poly_in_rect == 0:
|
||||
return paste_x, paste_y
|
||||
return None, None
|
||||
else:
|
||||
paste_x = random.randint(0, endx)
|
||||
paste_y = random.randint(0, endy)
|
||||
return paste_x, paste_y
|
||||
|
||||
|
||||
def get_union(pD, pG):
|
||||
return Polygon(pD).union(Polygon(pG)).area
|
||||
|
||||
|
||||
def get_intersection_over_union(pD, pG):
|
||||
return get_intersection(pD, pG) / get_union(pD, pG)
|
||||
|
||||
|
||||
def get_intersection(pD, pG):
|
||||
return Polygon(pD).intersection(Polygon(pG)).area
|
||||
|
||||
|
||||
def rotate_bbox(img, text_polys, angle, scale=1):
|
||||
"""
|
||||
from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
|
||||
Args:
|
||||
img: np.ndarray
|
||||
text_polys: np.ndarray N*4*2
|
||||
angle: int
|
||||
scale: int
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
w = img.shape[1]
|
||||
h = img.shape[0]
|
||||
|
||||
rangle = np.deg2rad(angle)
|
||||
nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
|
||||
nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
|
||||
rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
|
||||
rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
|
||||
rot_mat[0, 2] += rot_move[0]
|
||||
rot_mat[1, 2] += rot_move[1]
|
||||
|
||||
# ---------------------- rotate box ----------------------
|
||||
rot_text_polys = list()
|
||||
for bbox in text_polys:
|
||||
point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
|
||||
point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
|
||||
point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
|
||||
point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
|
||||
rot_text_polys.append([point1, point2, point3, point4])
|
||||
return np.array(rot_text_polys, dtype=np.float32)
|
||||
@@ -1,436 +0,0 @@
|
||||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
"""
|
||||
This code is refered from:
|
||||
https://github.com/songdejia/EAST/blob/master/data_utils.py
|
||||
"""
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
|
||||
__all__ = ['EASTProcessTrain']
|
||||
|
||||
|
||||
class EASTProcessTrain(object):
|
||||
def __init__(self,
|
||||
image_shape=[512, 512],
|
||||
background_ratio=0.125,
|
||||
min_crop_side_ratio=0.1,
|
||||
min_text_size=10,
|
||||
**kwargs):
|
||||
self.input_size = image_shape[1]
|
||||
self.random_scale = np.array([0.5, 1, 2.0, 3.0])
|
||||
self.background_ratio = background_ratio
|
||||
self.min_crop_side_ratio = min_crop_side_ratio
|
||||
self.min_text_size = min_text_size
|
||||
|
||||
def preprocess(self, im):
|
||||
input_size = self.input_size
|
||||
im_shape = im.shape
|
||||
im_size_min = np.min(im_shape[0:2])
|
||||
im_size_max = np.max(im_shape[0:2])
|
||||
im_scale = float(input_size) / float(im_size_max)
|
||||
im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale)
|
||||
img_mean = [0.485, 0.456, 0.406]
|
||||
img_std = [0.229, 0.224, 0.225]
|
||||
# im = im[:, :, ::-1].astype(np.float32)
|
||||
im = im / 255
|
||||
im -= img_mean
|
||||
im /= img_std
|
||||
new_h, new_w, _ = im.shape
|
||||
im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32)
|
||||
im_padded[:new_h, :new_w, :] = im
|
||||
im_padded = im_padded.transpose((2, 0, 1))
|
||||
im_padded = im_padded[np.newaxis, :]
|
||||
return im_padded, im_scale
|
||||
|
||||
def rotate_im_poly(self, im, text_polys):
|
||||
"""
|
||||
rotate image with 90 / 180 / 270 degre
|
||||
"""
|
||||
im_w, im_h = im.shape[1], im.shape[0]
|
||||
dst_im = im.copy()
|
||||
dst_polys = []
|
||||
rand_degree_ratio = np.random.rand()
|
||||
rand_degree_cnt = 1
|
||||
if 0.333 < rand_degree_ratio < 0.666:
|
||||
rand_degree_cnt = 2
|
||||
elif rand_degree_ratio > 0.666:
|
||||
rand_degree_cnt = 3
|
||||
for i in range(rand_degree_cnt):
|
||||
dst_im = np.rot90(dst_im)
|
||||
rot_degree = -90 * rand_degree_cnt
|
||||
rot_angle = rot_degree * math.pi / 180.0
|
||||
n_poly = text_polys.shape[0]
|
||||
cx, cy = 0.5 * im_w, 0.5 * im_h
|
||||
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
|
||||
for i in range(n_poly):
|
||||
wordBB = text_polys[i]
|
||||
poly = []
|
||||
for j in range(4):
|
||||
sx, sy = wordBB[j][0], wordBB[j][1]
|
||||
dx = math.cos(rot_angle) * (sx - cx)\
|
||||
- math.sin(rot_angle) * (sy - cy) + ncx
|
||||
dy = math.sin(rot_angle) * (sx - cx)\
|
||||
+ math.cos(rot_angle) * (sy - cy) + ncy
|
||||
poly.append([dx, dy])
|
||||
dst_polys.append(poly)
|
||||
dst_polys = np.array(dst_polys, dtype=np.float32)
|
||||
return dst_im, dst_polys
|
||||
|
||||
def polygon_area(self, poly):
    """Signed area of a quadrilateral via the shoelace formula.

    The sign encodes the vertex winding order; callers use ``p_area > 0``
    to detect polygons stored in the wrong direction.

    :param poly: four (x, y) vertices.
    :return: signed area (half the summed cross terms).
    """
    edges = [
        (poly[(k + 1) % 4][0] - poly[k][0]) *
        (poly[(k + 1) % 4][1] + poly[k][1])
        for k in range(4)
    ]
    return np.sum(edges) / 2.
|
||||
|
||||
def check_and_validate_polys(self, polys, tags, img_height, img_width):
    """Clip polygons to the image, normalize their winding direction and
    drop degenerate (near zero-area) ones.

    :param polys: polygons, shape (N, 4, 2); clipped in place.
    :param tags: per-polygon ignore flags.
    :param img_height: image height used for clipping.
    :param img_width: image width used for clipping.
    :return: (validated_polys, validated_tags) as numpy arrays.
    """
    h, w = img_height, img_width
    if polys.shape[0] == 0:
        # BUG FIX: previously returned only ``polys``, but callers unpack
        # two values (``polys, tags = ...``); return the pair consistently.
        return polys, tags
    polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
    polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)

    validated_polys = []
    validated_tags = []
    for poly, tag in zip(polys, tags):
        p_area = self.polygon_area(poly)
        # Invalid (degenerate) polygon: skip it entirely.
        if abs(p_area) < 1:
            continue
        if p_area > 0:
            # Positive area means the vertices are wound in the wrong
            # direction: mark as ignored and reverse to canonical order.
            if not tag:
                tag = True  # reversed cases should be ignored
            poly = poly[(0, 3, 2, 1), :]
        validated_polys.append(poly)
        validated_tags.append(tag)
    return np.array(validated_polys), np.array(validated_tags)
|
||||
|
||||
def draw_img_polys(self, img, polys):
    """Debug helper: de-normalize *img*, draw *polys* on it, and dump the
    result to ``tmp_<k>.jpg`` in the working directory.

    The added constants undo a prior mean subtraction — presumably the
    ImageNet means scaled by 255; TODO confirm the channel order against
    the preprocessing actually applied upstream.
    """
    import random

    if len(img.shape) == 4:
        img = np.squeeze(img, axis=0)
    if img.shape[0] == 3:
        img = img.transpose((1, 2, 0))
    img[:, :, 2] += 123.68
    img[:, :, 1] += 116.78
    img[:, :, 0] += 103.94
    # Round-trip through JPEG so the overlay is drawn on uint8 data.
    cv2.imwrite("tmp.jpg", img)
    img = cv2.imread("tmp.jpg")
    for box in polys:
        pts = box.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(img, [pts], True, color=(255, 255, 0), thickness=2)
    cv2.imwrite("tmp_%d.jpg" % random.randint(0, 100), img)
    return
|
||||
|
||||
def shrink_poly(self, poly, r):
    """
    Fit a shrunk poly inside the original poly (EAST paper), used to
    generate the score map.

    The edges of the longer opposing pair are moved first, then the
    shorter pair; each vertex slides along its edge by ``R * r[i]``.

    :param poly: (4, 2) vertex array, modified in place.
    :param r: per-vertex reference length r in the paper (min distance to
        a neighbouring vertex).
    :return: the shrunk poly (same object as the input).
    """
    # shrink ratio from the paper
    R = 0.3

    def _shrink_edge(ia, ib, horizontal):
        # Move poly[ia] towards poly[ib] (and vice versa) along the edge
        # direction. ``horizontal`` selects the arctan2 argument order /
        # trig pairing the original per-edge-family formulation used.
        pa, pb = poly[ia], poly[ib]
        if horizontal:
            theta = np.arctan2(pb[1] - pa[1], pb[0] - pa[0])
            ux, uy = np.cos(theta), np.sin(theta)
        else:
            theta = np.arctan2(pb[0] - pa[0], pb[1] - pa[1])
            ux, uy = np.sin(theta), np.cos(theta)
        pa[0] += R * r[ia] * ux
        pa[1] += R * r[ia] * uy
        pb[0] -= R * r[ib] * ux
        pb[1] -= R * r[ib] * uy

    # find the longer pair of opposing edges
    dist0 = np.linalg.norm(poly[0] - poly[1])
    dist1 = np.linalg.norm(poly[2] - poly[3])
    dist2 = np.linalg.norm(poly[0] - poly[3])
    dist3 = np.linalg.norm(poly[1] - poly[2])
    if dist0 + dist1 > dist2 + dist3:
        # first move (p0, p1) and (p3, p2), then (p0, p3) and (p1, p2);
        # later moves use the already-shifted vertices, as in the original.
        _shrink_edge(0, 1, True)
        _shrink_edge(3, 2, True)
        _shrink_edge(0, 3, False)
        _shrink_edge(1, 2, False)
    else:
        _shrink_edge(0, 3, False)
        _shrink_edge(1, 2, False)
        _shrink_edge(0, 1, True)
        _shrink_edge(3, 2, True)
    return poly
|
||||
|
||||
def generate_quad(self, im_size, polys, tags):
    """
    Generate quadrangle ground-truth maps for EAST-style training.

    :param im_size: (h, w) of the (already resized) image.
    :param polys: text polygons, shape (N, 4, 2).
    :param tags: per-polygon ignore flags (True = ignore in the loss).
    :return: (score_map, geo_map, training_mask). geo_map channels 2k and
        2k+1 hold the (x, y) offset from each in-polygon pixel to corner
        k; channel 8 holds 1 / short-edge for length normalization.
    """
    h, w = im_size
    # poly_mask labels each pixel with (poly index + 1) so pixels can be
    # mapped back to their polygon for the geo map below.
    poly_mask = np.zeros((h, w), dtype=np.uint8)
    score_map = np.zeros((h, w), dtype=np.uint8)
    # (x1, y1, ..., x4, y4, short_edge_norm)
    geo_map = np.zeros((h, w, 9), dtype=np.float32)
    # mask used during training, to ignore some hard areas
    training_mask = np.ones((h, w), dtype=np.uint8)
    for poly_idx, poly_tag in enumerate(zip(polys, tags)):
        poly = poly_tag[0]
        tag = poly_tag[1]

        # r[i]: distance from vertex i to its nearest neighbouring vertex,
        # used as the per-vertex shrink reference length.
        r = [None, None, None, None]
        for i in range(4):
            dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4])
            dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
            r[i] = min(dist1, dist2)
        # score map: positive region is the shrunk polygon.
        shrinked_poly = self.shrink_poly(
            poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
        cv2.fillPoly(score_map, shrinked_poly, 1)
        cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
        # if the poly is too small, then ignore it during training
        poly_h = min(
            np.linalg.norm(poly[0] - poly[3]),
            np.linalg.norm(poly[1] - poly[2]))
        poly_w = min(
            np.linalg.norm(poly[0] - poly[1]),
            np.linalg.norm(poly[2] - poly[3]))
        if min(poly_h, poly_w) < self.min_text_size:
            cv2.fillPoly(training_mask,
                         poly.astype(np.int32)[np.newaxis, :, :], 0)

        # explicitly ignored polygons are masked out of the loss as well
        if tag:
            cv2.fillPoly(training_mask,
                         poly.astype(np.int32)[np.newaxis, :, :], 0)

        xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
        # geo map.
        y_in_poly = xy_in_poly[:, 0]
        x_in_poly = xy_in_poly[:, 1]
        # clamp corners into the image before computing offsets
        poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w)
        poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
        for pno in range(4):
            geo_channel_beg = pno * 2
            geo_map[y_in_poly, x_in_poly, geo_channel_beg] =\
                x_in_poly - poly[pno, 0]
            geo_map[y_in_poly, x_in_poly, geo_channel_beg+1] =\
                y_in_poly - poly[pno, 1]
        # normalization channel: inverse of the short edge (>= 1 px)
        geo_map[y_in_poly, x_in_poly, 8] = \
            1.0 / max(min(poly_h, poly_w), 1.0)
    return score_map, geo_map, training_mask
|
||||
|
||||
def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
    """
    make random crop from the input image
    :param im: image array of shape (h, w, c)
    :param polys: text polygons, shape (N, 4, 2)
    :param tags: per-polygon ignore flags
    :param crop_background: when True, accept (and return) a crop that
        contains no text at all
    :param max_tries: number of random attempts before giving up
    :return: (image, polys, tags) of the crop; falls back to the original
        inputs when no valid crop is found
    """
    h, w, _ = im.shape
    pad_h = h // 10
    pad_w = w // 10
    # 1 marks rows/columns covered by some polygon (in padded
    # coordinates); crop edges may only be sampled from 0 entries.
    h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
    w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
    for poly in polys:
        poly = np.round(poly, decimals=0).astype(np.int32)
        minx = np.min(poly[:, 0])
        maxx = np.max(poly[:, 0])
        w_array[minx + pad_w:maxx + pad_w] = 1
        miny = np.min(poly[:, 1])
        maxy = np.max(poly[:, 1])
        h_array[miny + pad_h:maxy + pad_h] = 1
    # ensure the cropped area not across a text
    h_axis = np.where(h_array == 0)[0]
    w_axis = np.where(w_array == 0)[0]
    if len(h_axis) == 0 or len(w_axis) == 0:
        # text spans every row or every column: cropping is impossible
        return im, polys, tags

    for i in range(max_tries):
        # sample two candidate edges per axis from the text-free lines,
        # then shift back from padded to image coordinates and clamp
        xx = np.random.choice(w_axis, size=2)
        xmin = np.min(xx) - pad_w
        xmax = np.max(xx) - pad_w
        xmin = np.clip(xmin, 0, w - 1)
        xmax = np.clip(xmax, 0, w - 1)
        yy = np.random.choice(h_axis, size=2)
        ymin = np.min(yy) - pad_h
        ymax = np.max(yy) - pad_h
        ymin = np.clip(ymin, 0, h - 1)
        ymax = np.clip(ymax, 0, h - 1)
        if xmax - xmin < self.min_crop_side_ratio * w or \
                ymax - ymin < self.min_crop_side_ratio * h:
            # area too small
            continue
        if polys.shape[0] != 0:
            # a polygon survives only if all 4 vertices are inside the crop
            poly_axis_in_area = (polys[:, :, 0] >= xmin)\
                & (polys[:, :, 0] <= xmax)\
                & (polys[:, :, 1] >= ymin)\
                & (polys[:, :, 1] <= ymax)
            selected_polys = np.where(
                np.sum(poly_axis_in_area, axis=1) == 4)[0]
        else:
            selected_polys = []

        if len(selected_polys) == 0:
            # no text in this area
            if crop_background:
                im = im[ymin:ymax + 1, xmin:xmax + 1, :]
                # NOTE(review): plain lists are returned here while the
                # other paths return numpy arrays; the visible caller only
                # uses len() on them — confirm before changing.
                polys = []
                tags = []
                return im, polys, tags
            else:
                continue

        im = im[ymin:ymax + 1, xmin:xmax + 1, :]
        polys = polys[selected_polys]
        tags = tags[selected_polys]
        # shift surviving polygons into the crop's coordinate frame
        polys[:, :, 0] -= xmin
        polys[:, :, 1] -= ymin
        return im, polys, tags
    # no acceptable crop found within max_tries: return inputs unchanged
    return im, polys, tags
|
||||
|
||||
def crop_background_infor(self, im, text_polys, text_tags):
    """Sample a background (text-free) crop and build all-empty targets.

    :return: ``None`` when the sampled crop still contains text, otherwise
        (image, score_map, geo_map, training_mask) with zero score/geo
        maps and an all-ones training mask.
    """
    im, text_polys, text_tags = self.crop_area(
        im, text_polys, text_tags, crop_background=True)

    # The crop is only usable as a background sample if no text survived.
    if len(text_polys) > 0:
        return None
    # pad and resize image
    size = self.input_size
    im, ratio = self.preprocess(im)
    empty_score = np.zeros((size, size), dtype=np.float32)
    empty_geo = np.zeros((size, size, 9), dtype=np.float32)
    full_mask = np.ones((size, size), dtype=np.float32)
    return im, empty_score, empty_geo, full_mask
|
||||
|
||||
def crop_foreground_infor(self, im, text_polys, text_tags):
    """Sample a crop that contains text and build its training targets.

    :return: ``None`` when the crop has no usable text (no polygons, or
        every polygon flagged as ignore), otherwise
        (image, score_map, geo_map, training_mask).
    """
    im, text_polys, text_tags = self.crop_area(
        im, text_polys, text_tags, crop_background=False)

    if text_polys.shape[0] == 0:
        return None
    # Skip samples where every surviving polygon is an ignore case.
    if np.sum((text_tags * 1.0)) >= text_tags.size:
        return None
    # pad and resize image
    input_size = self.input_size  # kept for parity; unused on this path
    im, ratio = self.preprocess(im)
    # Bring the polygons into the resized coordinate frame.
    text_polys[:, :, 0] *= ratio
    text_polys[:, :, 1] *= ratio
    _, _, new_h, new_w = im.shape
    score_map, geo_map, training_mask = self.generate_quad(
        (new_h, new_w), text_polys, text_tags)
    return im, score_map, geo_map, training_mask
|
||||
|
||||
def __call__(self, data):
    """Build EAST-style training targets for one sample.

    Expects ``data`` to carry 'image', 'polys' and 'ignore_tags'; returns
    the dict augmented with 'score_map', 'geo_map' and 'training_mask'
    (all downsampled 4x), or ``None`` when the sample is unusable.
    """
    image = data['image']
    polys = data['polys']
    tags = data['ignore_tags']
    if image is None:
        return None
    if polys.shape[0] == 0:
        return None

    # Augmentation: random multiple-of-90-degree rotation half the time.
    if np.random.rand() < 0.5:
        image, polys = self.rotate_im_poly(image, polys)
    h, w, _ = image.shape
    polys, tags = self.check_and_validate_polys(polys, tags, h, w)
    if polys.shape[0] == 0:
        return None

    # Random global rescale of the image and its polygons.
    scale = np.random.choice(self.random_scale)
    image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
    polys *= scale
    # Either a pure-background sample or a text-bearing one.
    if np.random.rand() < self.background_ratio:
        targets = self.crop_background_infor(image, polys, tags)
    else:
        targets = self.crop_foreground_infor(image, polys, tags)

    if targets is None:
        return None
    image, score_map, geo_map, training_mask = targets
    # Targets are produced at 1/4 resolution; geo_map goes HWC -> CHW.
    score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
    geo_map = np.swapaxes(geo_map, 1, 2)
    geo_map = np.swapaxes(geo_map, 1, 0)
    geo_map = geo_map[:, ::4, ::4].astype(np.float32)
    training_mask = training_mask[np.newaxis, ::4, ::4]
    training_mask = training_mask.astype(np.float32)

    # preprocess() returned an NCHW batch of one; store the CHW image.
    data['image'] = image[0]
    data['score_map'] = score_map
    data['geo_map'] = geo_map
    data['training_mask'] = training_mask
    return data
|
||||
@@ -1,564 +0,0 @@
|
||||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/transforms.py
|
||||
"""
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
import cv2
|
||||
from shapely.geometry import Polygon
|
||||
import math
|
||||
from ppocr.utils.poly_nms import poly_intersection
|
||||
|
||||
|
||||
class RandomScaling:
    """Randomly rescale the image (keeping aspect) and its polygons.

    Args:
        size (int): Base size before scaling.
        scale (tuple(float) | float): The range of scaling; a single float
            ``f`` is expanded to ``(1 - f, 1 + f)``.
    """

    def __init__(self, size=800, scale=(3. / 4, 5. / 2), **kwargs):
        assert isinstance(size, int)
        assert isinstance(scale, (float, tuple))
        self.size = size
        if isinstance(scale, tuple):
            self.scale = scale
        else:
            self.scale = (1 - scale, 1 + scale)

    def __call__(self, data):
        image = data['image']
        polys = data['polys']
        h, w, _ = image.shape

        # Draw one factor and apply it to both axes equally.
        factor = np.random.uniform(min(self.scale), max(self.scale))
        ratio = self.size * 1.0 / max(h, w) * factor
        scales = np.array([ratio, ratio])
        out_size = (int(h * scales[1]), int(w * scales[0]))
        image = cv2.resize(image, out_size[::-1])

        data['image'] = image
        polys[:, :, 0::2] = polys[:, :, 0::2] * scales[1]
        polys[:, :, 1::2] = polys[:, :, 1::2] * scales[0]
        data['polys'] = polys
        return data
|
||||
|
||||
|
||||
class RandomCropFlip:
    """Random crop-and-flip augmentation.

    A candidate region that does not cut through any polygon is cropped,
    flipped (horizontally, vertically or both) and pasted back; polygons
    fully inside the region are flipped with it, polygons fully outside
    are kept unchanged.

    Args:
        pad_ratio (float): Padding ratio used when computing valid crop lines.
        crop_ratio (float): The probability of applying the transform.
        iter_num (int): Number of operations.
        min_area_ratio (float): Minimal area ratio between cropped patch
            and original image.
    """

    def __init__(self,
                 pad_ratio=0.1,
                 crop_ratio=0.5,
                 iter_num=1,
                 min_area_ratio=0.2,
                 **kwargs):
        assert isinstance(crop_ratio, float)
        assert isinstance(iter_num, int)
        assert isinstance(min_area_ratio, float)

        self.pad_ratio = pad_ratio
        # Tolerance for deciding whether a polygon is fully inside /
        # fully outside the candidate crop region.
        self.epsilon = 1e-2
        self.crop_ratio = crop_ratio
        self.iter_num = iter_num
        self.min_area_ratio = min_area_ratio

    def __call__(self, results):
        for i in range(self.iter_num):
            results = self.random_crop_flip(results)

        return results

    def random_crop_flip(self, results):
        """Apply one crop-and-flip attempt; returns *results* unchanged
        when the transform is skipped or no valid region can be found."""
        image = results['image']
        polygons = results['polys']
        ignore_tags = results['ignore_tags']
        if len(polygons) == 0:
            return results

        if np.random.random() >= self.crop_ratio:
            return results

        h, w, _ = image.shape
        area = h * w
        pad_h = int(h * self.pad_ratio)
        pad_w = int(w * self.pad_ratio)
        h_axis, w_axis = self.generate_crop_target(image, polygons, pad_h,
                                                   pad_w)
        if len(h_axis) == 0 or len(w_axis) == 0:
            return results

        found = False
        attempt = 0
        while attempt < 50:
            attempt += 1
            polys_keep = []
            polys_new = []
            ignore_tags_keep = []
            ignore_tags_new = []
            xx = np.random.choice(w_axis, size=2)
            xmin = np.min(xx) - pad_w
            xmax = np.max(xx) - pad_w
            xmin = np.clip(xmin, 0, w - 1)
            xmax = np.clip(xmax, 0, w - 1)
            yy = np.random.choice(h_axis, size=2)
            ymin = np.min(yy) - pad_h
            ymax = np.max(yy) - pad_h
            ymin = np.clip(ymin, 0, h - 1)
            ymax = np.clip(ymax, 0, h - 1)
            if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
                # area too small
                continue

            pts = np.stack([[xmin, xmax, xmax, xmin],
                            [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
            pp = Polygon(pts)
            fail_flag = False
            for polygon, ignore_tag in zip(polygons, ignore_tags):
                ppi = Polygon(polygon.reshape(-1, 2))
                ppiou, _ = poly_intersection(ppi, pp, buffer=0)
                if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
                        np.abs(ppiou) > self.epsilon:
                    # the crop boundary cuts through this polygon
                    fail_flag = True
                    break
                elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
                    # polygon fully inside the crop: will be flipped
                    polys_new.append(polygon)
                    ignore_tags_new.append(ignore_tag)
                else:
                    # polygon fully outside the crop: keep as-is
                    polys_keep.append(polygon)
                    ignore_tags_keep.append(ignore_tag)

            if fail_flag:
                continue
            else:
                found = True
                break

        # BUG FIX: previously, after 50 failed attempts the last (invalid)
        # candidate region was cropped and flipped anyway, which could flip
        # areas crossing text and corrupt the labels. Bail out instead.
        if not found:
            return results

        cropped = image[ymin:ymax, xmin:xmax, :]
        select_type = np.random.randint(3)
        if select_type == 0:
            img = np.ascontiguousarray(cropped[:, ::-1])
        elif select_type == 1:
            img = np.ascontiguousarray(cropped[::-1, :])
        else:
            img = np.ascontiguousarray(cropped[::-1, ::-1])
        image[ymin:ymax, xmin:xmax, :] = img
        # BUG FIX: write back under the 'image' key used everywhere else.
        # The old code stored it under a stray 'img' key and only worked
        # because ``image`` is modified in place.
        results['image'] = image

        if len(polys_new) != 0:
            height, width, _ = cropped.shape
            if select_type == 0:
                # horizontal flip: mirror x within the crop
                for idx, polygon in enumerate(polys_new):
                    poly = polygon.reshape(-1, 2)
                    poly[:, 0] = width - poly[:, 0] + 2 * xmin
                    polys_new[idx] = poly
            elif select_type == 1:
                # vertical flip: mirror y within the crop
                for idx, polygon in enumerate(polys_new):
                    poly = polygon.reshape(-1, 2)
                    poly[:, 1] = height - poly[:, 1] + 2 * ymin
                    polys_new[idx] = poly
            else:
                # both axes flipped
                for idx, polygon in enumerate(polys_new):
                    poly = polygon.reshape(-1, 2)
                    poly[:, 0] = width - poly[:, 0] + 2 * xmin
                    poly[:, 1] = height - poly[:, 1] + 2 * ymin
                    polys_new[idx] = poly
            polygons = polys_keep + polys_new
            ignore_tags = ignore_tags_keep + ignore_tags_new
            results['polys'] = np.array(polygons)
            results['ignore_tags'] = ignore_tags

        return results

    def generate_crop_target(self, image, all_polys, pad_h, pad_w):
        """Generate crop target and make sure not to crop the polygon
        instances.

        Args:
            image (ndarray): The image waited to be crop.
            all_polys (list[list[ndarray]]): All polygons including ground
                truth polygons and ground truth ignored polygons.
            pad_h (int): Padding length of height.
            pad_w (int): Padding length of width.
        Returns:
            h_axis (ndarray): Vertical cropping range.
            w_axis (ndarray): Horizontal cropping range.
        """
        h, w, _ = image.shape
        h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
        w_array = np.zeros((w + pad_w * 2), dtype=np.int32)

        # Replace each polygon by its minimum-area bounding rectangle.
        text_polys = []
        for polygon in all_polys:
            rect = cv2.minAreaRect(polygon.astype(np.int32).reshape(-1, 2))
            box = cv2.boxPoints(rect)
            box = np.int0(box)
            text_polys.append([box[0], box[1], box[2], box[3]])

        polys = np.array(text_polys, dtype=np.int32)
        # Mark rows/columns covered by any text box as invalid crop lines.
        for poly in polys:
            poly = np.round(poly, decimals=0).astype(np.int32)
            minx = np.min(poly[:, 0])
            maxx = np.max(poly[:, 0])
            w_array[minx + pad_w:maxx + pad_w] = 1
            miny = np.min(poly[:, 1])
            maxy = np.max(poly[:, 1])
            h_array[miny + pad_h:maxy + pad_h] = 1

        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        return h_axis, w_axis
|
||||
|
||||
|
||||
class RandomCropPolyInstances:
    """Randomly crop images and make sure to contain at least one intact
    instance."""

    def __init__(self, crop_ratio=5.0 / 8.0, min_side_ratio=0.4, **kwargs):
        super().__init__()
        self.crop_ratio = crop_ratio
        self.min_side_ratio = min_side_ratio

    def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):
        """Sample a (start, end) pair along one axis.

        Both endpoints are drawn from runs of 1s in *valid_array* (lines a
        crop edge may pass through); the span is forced to be at least
        *min_len* long and to bracket the [max_start, min_end] window.
        """
        assert isinstance(min_len, int)
        assert len(valid_array) > min_len

        lo_mask = valid_array.copy()
        max_start = min(len(lo_mask) - min_len, max_start)
        lo_mask[max_start:] = 0
        lo_mask[0] = 1
        # Edge-detect the runs of 1s: a negative diff marks a run start,
        # a positive diff the position just past a run end.
        edges = np.hstack([0, lo_mask]) - np.hstack([lo_mask, 0])
        run_lo = np.where(edges < 0)[0]
        run_hi = np.where(edges > 0)[0]
        pick = np.random.randint(0, len(run_lo))
        start = np.random.randint(run_lo[pick], run_hi[pick])

        hi_mask = valid_array.copy()
        min_end = max(start + min_len, min_end)
        hi_mask[:min_end] = 0
        hi_mask[-1] = 1
        edges = np.hstack([0, hi_mask]) - np.hstack([hi_mask, 0])
        run_lo = np.where(edges < 0)[0]
        run_hi = np.where(edges > 0)[0]
        pick = np.random.randint(0, len(run_lo))
        end = np.random.randint(run_lo[pick], run_hi[pick])
        return start, end

    def sample_crop_box(self, img_size, results):
        """Generate crop box and make sure not to crop the polygon instances.

        Args:
            img_size (tuple(int)): The image size (h, w).
            results (dict): The results dict.
        """
        assert isinstance(img_size, tuple)
        h, w = img_size[:2]

        key_masks = results['polys']

        x_valid_array = np.ones(w, dtype=np.int32)
        y_valid_array = np.ones(h, dtype=np.int32)

        # Pick one instance that must survive the crop intact; the crop
        # must start before and end after its bounding box.
        selected_mask = key_masks[np.random.randint(0, len(key_masks))]
        selected_mask = selected_mask.reshape((-1, 2)).astype(np.int32)
        max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0)
        min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1)
        max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
        min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)

        # Invalidate every line that would cut through any instance,
        # with a 2-pixel safety margin.
        for mask in key_masks:
            mask = mask.reshape((-1, 2)).astype(np.int32)
            clip_x = np.clip(mask[:, 0], 0, w - 1)
            clip_y = np.clip(mask[:, 1], 0, h - 1)
            min_x, max_x = np.min(clip_x), np.max(clip_x)
            min_y, max_y = np.min(clip_y), np.max(clip_y)
            x_valid_array[min_x - 2:max_x + 3] = 0
            y_valid_array[min_y - 2:max_y + 3] = 0

        min_w = int(w * self.min_side_ratio)
        min_h = int(h * self.min_side_ratio)

        x1, x2 = self.sample_valid_start_end(x_valid_array, min_w, max_x_start,
                                             min_x_end)
        y1, y2 = self.sample_valid_start_end(y_valid_array, min_h, max_y_start,
                                             min_y_end)
        return np.array([x1, y1, x2, y2])

    def crop_img(self, img, bbox):
        """Return the [x1, y1, x2, y2] sub-image of *img*."""
        assert img.ndim == 3
        h, w, _ = img.shape
        assert 0 <= bbox[1] < bbox[3] <= h
        assert 0 <= bbox[0] < bbox[2] <= w
        return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]

    def __call__(self, results):
        image = results['image']
        polygons = results['polys']
        ignore_tags = results['ignore_tags']
        if len(polygons) < 1:
            return results

        # Apply the crop only with probability ``crop_ratio``.
        if np.random.random_sample() >= self.crop_ratio:
            return results

        crop_box = self.sample_crop_box(image.shape, results)
        results['image'] = self.crop_img(image, crop_box)
        # crop and filter masks
        x1, y1, x2, y2 = crop_box
        w = max(x2 - x1, 1)
        h = max(y2 - y1, 1)
        polygons[:, :, 0::2] = polygons[:, :, 0::2] - x1
        polygons[:, :, 1::2] = polygons[:, :, 1::2] - y1

        kept_polys = []
        kept_tags = []
        for ind, polygon in enumerate(polygons):
            # Keep a polygon only when it lies (almost) entirely inside
            # the crop, with a 4-pixel tolerance; then clamp it.
            inside_x = (polygon[:, ::2] > -4).all() and \
                (polygon[:, ::2] < w + 4).all()
            inside_y = (polygon[:, 1::2] > -4).all() and \
                (polygon[:, 1::2] < h + 4).all()
            if inside_x and inside_y:
                polygon[:, ::2] = np.clip(polygon[:, ::2], 0, w)
                polygon[:, 1::2] = np.clip(polygon[:, 1::2], 0, h)
                kept_polys.append(polygon)
                kept_tags.append(ignore_tags[ind])

        results['polys'] = np.array(kept_polys)
        results['ignore_tags'] = kept_tags

        return results

    def __repr__(self):
        return self.__class__.__name__
|
||||
|
||||
|
||||
class RandomRotatePolyInstances:
    """Randomly rotate images and polygon masks.

    Args:
        rotate_ratio (float): The ratio of samples to operate rotation.
        max_angle (int): The maximum rotation angle.
        pad_with_fixed_color (bool): Whether to pad the rotated image with
            a fixed color; when False, the border is filled with a resized
            random patch of the image itself.
        pad_value (tuple(int)): The color value for padding rotated image.
    """

    def __init__(self,
                 rotate_ratio=0.5,
                 max_angle=10,
                 pad_with_fixed_color=False,
                 pad_value=(0, 0, 0),
                 **kwargs):
        self.rotate_ratio = rotate_ratio
        self.max_angle = max_angle
        self.pad_with_fixed_color = pad_with_fixed_color
        self.pad_value = pad_value

    def rotate(self, center, points, theta, center_shift=(0, 0)):
        """Rotate *points* (flat x/y pairs, modified in place) by *theta*
        degrees around *center*, then translate by *center_shift*."""
        cx, cy = center
        cy = -cy  # work in a y-up frame, flip back at the end
        xs, ys = points[:, ::2], -points[:, 1::2]

        rad = theta / 180 * math.pi
        cos_t, sin_t = math.cos(rad), math.sin(rad)

        dx = xs - cx
        dy = ys - cy

        new_x = cx + dx * cos_t - dy * sin_t + center_shift[0]
        new_y = -(cy + dx * sin_t + dy * cos_t) + center_shift[1]

        points[:, ::2], points[:, 1::2] = new_x, new_y
        return points

    def cal_canvas_size(self, ori_size, degree):
        """Smallest (h, w) canvas that fits *ori_size* rotated by *degree*."""
        assert isinstance(ori_size, tuple)
        rad = degree * math.pi / 180.0
        h, w = ori_size[:2]

        abs_cos = math.fabs(math.cos(rad))
        abs_sin = math.fabs(math.sin(rad))
        canvas_h = int(w * abs_sin + h * abs_cos)
        canvas_w = int(w * abs_cos + h * abs_sin)

        return (canvas_h, canvas_w)

    def sample_angle(self, max_angle):
        """Uniform random angle in [-max_angle, max_angle)."""
        return np.random.random_sample() * 2 * max_angle - max_angle

    def rotate_img(self, img, angle, canvas_size):
        """Rotate *img* onto a canvas of *canvas_size*, padding the border
        either with a fixed color or with a random patch of the image."""
        h, w = img.shape[:2]
        matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
        # Shift so the rotated image is centered on the larger canvas.
        matrix[0, 2] += int((canvas_size[1] - w) / 2)
        matrix[1, 2] += int((canvas_size[0] - h) / 2)

        if self.pad_with_fixed_color:
            target_img = cv2.warpAffine(
                img,
                matrix, (canvas_size[1], canvas_size[0]),
                flags=cv2.INTER_NEAREST,
                borderValue=self.pad_value)
        else:
            # Border mask ends up 1 wherever warpAffine left the canvas
            # empty; those pixels are filled from a resized random patch.
            mask = np.zeros_like(img)
            (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
                              np.random.randint(0, w * 7 // 8))
            img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
            img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0]))

            mask = cv2.warpAffine(
                mask,
                matrix, (canvas_size[1], canvas_size[0]),
                borderValue=[1, 1, 1])
            target_img = cv2.warpAffine(
                img,
                matrix, (canvas_size[1], canvas_size[0]),
                borderValue=[0, 0, 0])
            target_img = target_img + img_cut * mask

        return target_img

    def __call__(self, results):
        if np.random.random_sample() < self.rotate_ratio:
            image = results['image']
            polygons = results['polys']
            h, w = image.shape[:2]

            angle = self.sample_angle(self.max_angle)
            canvas_size = self.cal_canvas_size((h, w), angle)
            center_shift = (int((canvas_size[1] - w) / 2),
                            int((canvas_size[0] - h) / 2))
            results['image'] = self.rotate_img(image, angle, canvas_size)
            # rotate polygons with the same transform
            results['polys'] = np.array([
                self.rotate((w / 2, h / 2), mask, angle, center_shift)
                for mask in polygons
            ])

        return results

    def __repr__(self):
        return self.__class__.__name__
|
||||
|
||||
|
||||
class SquareResizePad:
    """Resize or pad images to be square shape.

    Args:
        target_size (int): The target size of square shaped image.
        pad_ratio (float): Probability of taking the keep-ratio + pad
            branch instead of a plain square resize.
        pad_with_fixed_color (bool): Whether to pad with a fixed color;
            when False, padding is filled with a resized random patch of
            the image itself.
        pad_value (tuple(int)): The color value for padding.
    """

    def __init__(self,
                 target_size,
                 pad_ratio=0.6,
                 pad_with_fixed_color=False,
                 pad_value=(0, 0, 0),
                 **kwargs):
        assert isinstance(target_size, int)
        assert isinstance(pad_ratio, float)
        assert isinstance(pad_with_fixed_color, bool)
        assert isinstance(pad_value, tuple)

        self.target_size = target_size
        self.pad_ratio = pad_ratio
        self.pad_with_fixed_color = pad_with_fixed_color
        self.pad_value = pad_value

    def resize_img(self, img, keep_ratio=True):
        """Resize so the longer side equals target_size (keep_ratio=True)
        or to an exact square. Returns (image, (new_h, new_w))."""
        h, w, _ = img.shape
        if keep_ratio:
            t_h = self.target_size if h >= w else int(h * self.target_size / w)
            t_w = self.target_size if h <= w else int(w * self.target_size / h)
        else:
            t_h = t_w = self.target_size
        img = cv2.resize(img, (t_w, t_h))
        return img, (t_h, t_w)

    def square_pad(self, img):
        """Pad *img* to a square; returns (padded, (x_offset, y_offset))."""
        h, w = img.shape[:2]
        if h == w:
            return img, (0, 0)
        pad_size = max(h, w)
        if self.pad_with_fixed_color:
            expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8)
            expand_img[:] = self.pad_value
        else:
            # Fill the canvas with a random patch of the image itself so
            # the padding blends with the content.
            (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
                              np.random.randint(0, w * 7 // 8))
            img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
            expand_img = cv2.resize(img_cut, (pad_size, pad_size))
        if h > w:
            y0, x0 = 0, (h - w) // 2
        else:
            y0, x0 = (w - h) // 2, 0
        expand_img[y0:y0 + h, x0:x0 + w] = img
        offset = (x0, y0)

        return expand_img, offset

    def square_pad_mask(self, points, offset):
        """Shift a flat [x0, y0, x1, y1, ...] point array by *offset*."""
        x0, y0 = offset
        pad_points = points.copy()
        pad_points[::2] = pad_points[::2] + x0
        pad_points[1::2] = pad_points[1::2] + y0
        return pad_points

    def __call__(self, results):
        image = results['image']
        polygons = results['polys']
        h, w = image.shape[:2]

        if np.random.random_sample() < self.pad_ratio:
            image, out_size = self.resize_img(image, keep_ratio=True)
            image, offset = self.square_pad(image)
        else:
            image, out_size = self.resize_img(image, keep_ratio=False)
            offset = (0, 0)
        results['image'] = image
        try:
            polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[
                1] / w + offset[0]
            polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[
                0] / h + offset[1]
        except Exception:
            # BUG FIX: was a bare ``except:``, which also swallowed
            # SystemExit / KeyboardInterrupt. Polygon arrays may be empty
            # or ragged and reject this fancy indexing; skipping the
            # rescale in that case is the intended best-effort behavior.
            pass
        results['polys'] = polygons

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str
||||
@@ -1,658 +0,0 @@
|
||||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from numpy.fft import fft
|
||||
from numpy.linalg import norm
|
||||
import sys
|
||||
|
||||
|
||||
class FCENetTargets:
|
||||
"""Generate the ground truth targets of FCENet: Fourier Contour Embedding
|
||||
for Arbitrary-Shaped Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2104.10442]
|
||||
|
||||
Args:
|
||||
fourier_degree (int): The maximum Fourier transform degree k.
|
||||
resample_step (float): The step size for resampling the text center
|
||||
line (TCL). It's better not to exceed half of the minimum width.
|
||||
center_region_shrink_ratio (float): The shrink ratio of text center
|
||||
region.
|
||||
level_size_divisors (tuple(int)): The downsample ratio on each level.
|
||||
level_proportion_range (tuple(tuple(int))): The range of text sizes
|
||||
assigned to each level.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
fourier_degree=5,
|
||||
resample_step=4.0,
|
||||
center_region_shrink_ratio=0.3,
|
||||
level_size_divisors=(8, 16, 32),
|
||||
level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0)),
|
||||
orientation_thr=2.0,
|
||||
**kwargs):
|
||||
|
||||
super().__init__()
|
||||
assert isinstance(level_size_divisors, tuple)
|
||||
assert isinstance(level_proportion_range, tuple)
|
||||
assert len(level_size_divisors) == len(level_proportion_range)
|
||||
self.fourier_degree = fourier_degree
|
||||
self.resample_step = resample_step
|
||||
self.center_region_shrink_ratio = center_region_shrink_ratio
|
||||
self.level_size_divisors = level_size_divisors
|
||||
self.level_proportion_range = level_proportion_range
|
||||
|
||||
self.orientation_thr = orientation_thr
|
||||
|
||||
def vector_angle(self, vec1, vec2):
|
||||
if vec1.ndim > 1:
|
||||
unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
|
||||
else:
|
||||
unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
|
||||
if vec2.ndim > 1:
|
||||
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
|
||||
else:
|
||||
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
|
||||
return np.arccos(
|
||||
np.clip(
|
||||
np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
|
||||
|
||||
def resample_line(self, line, n):
|
||||
"""Resample n points on a line.
|
||||
|
||||
Args:
|
||||
line (ndarray): The points composing a line.
|
||||
n (int): The resampled points number.
|
||||
|
||||
Returns:
|
||||
resampled_line (ndarray): The points composing the resampled line.
|
||||
"""
|
||||
|
||||
assert line.ndim == 2
|
||||
assert line.shape[0] >= 2
|
||||
assert line.shape[1] == 2
|
||||
assert isinstance(n, int)
|
||||
assert n > 0
|
||||
|
||||
length_list = [
|
||||
norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
|
||||
]
|
||||
total_length = sum(length_list)
|
||||
length_cumsum = np.cumsum([0.0] + length_list)
|
||||
delta_length = total_length / (float(n) + 1e-8)
|
||||
|
||||
current_edge_ind = 0
|
||||
resampled_line = [line[0]]
|
||||
|
||||
for i in range(1, n):
|
||||
current_line_len = i * delta_length
|
||||
|
||||
while current_line_len >= length_cumsum[current_edge_ind + 1]:
|
||||
current_edge_ind += 1
|
||||
current_edge_end_shift = current_line_len - length_cumsum[
|
||||
current_edge_ind]
|
||||
end_shift_ratio = current_edge_end_shift / length_list[
|
||||
current_edge_ind]
|
||||
current_point = line[current_edge_ind] + (line[current_edge_ind + 1]
|
||||
- line[current_edge_ind]
|
||||
) * end_shift_ratio
|
||||
resampled_line.append(current_point)
|
||||
|
||||
resampled_line.append(line[-1])
|
||||
resampled_line = np.array(resampled_line)
|
||||
|
||||
return resampled_line
|
||||
|
||||
def reorder_poly_edge(self, points):
|
||||
"""Get the respective points composing head edge, tail edge, top
|
||||
sideline and bottom sideline.
|
||||
|
||||
Args:
|
||||
points (ndarray): The points composing a text polygon.
|
||||
|
||||
Returns:
|
||||
head_edge (ndarray): The two points composing the head edge of text
|
||||
polygon.
|
||||
tail_edge (ndarray): The two points composing the tail edge of text
|
||||
polygon.
|
||||
top_sideline (ndarray): The points composing top curved sideline of
|
||||
text polygon.
|
||||
bot_sideline (ndarray): The points composing bottom curved sideline
|
||||
of text polygon.
|
||||
"""
|
||||
|
||||
assert points.ndim == 2
|
||||
assert points.shape[0] >= 4
|
||||
assert points.shape[1] == 2
|
||||
|
||||
head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
|
||||
head_edge, tail_edge = points[head_inds], points[tail_inds]
|
||||
|
||||
pad_points = np.vstack([points, points])
|
||||
if tail_inds[1] < 1:
|
||||
tail_inds[1] = len(points)
|
||||
sideline1 = pad_points[head_inds[1]:tail_inds[1]]
|
||||
sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
|
||||
sideline_mean_shift = np.mean(
|
||||
sideline1, axis=0) - np.mean(
|
||||
sideline2, axis=0)
|
||||
|
||||
if sideline_mean_shift[1] > 0:
|
||||
top_sideline, bot_sideline = sideline2, sideline1
|
||||
else:
|
||||
top_sideline, bot_sideline = sideline1, sideline2
|
||||
|
||||
return head_edge, tail_edge, top_sideline, bot_sideline
|
||||
|
||||
def find_head_tail(self, points, orientation_thr):
|
||||
"""Find the head edge and tail edge of a text polygon.
|
||||
|
||||
Args:
|
||||
points (ndarray): The points composing a text polygon.
|
||||
orientation_thr (float): The threshold for distinguishing between
|
||||
head edge and tail edge among the horizontal and vertical edges
|
||||
of a quadrangle.
|
||||
|
||||
Returns:
|
||||
head_inds (list): The indexes of two points composing head edge.
|
||||
tail_inds (list): The indexes of two points composing tail edge.
|
||||
"""
|
||||
|
||||
assert points.ndim == 2
|
||||
assert points.shape[0] >= 4
|
||||
assert points.shape[1] == 2
|
||||
assert isinstance(orientation_thr, float)
|
||||
|
||||
if len(points) > 4:
|
||||
pad_points = np.vstack([points, points[0]])
|
||||
edge_vec = pad_points[1:] - pad_points[:-1]
|
||||
|
||||
theta_sum = []
|
||||
adjacent_vec_theta = []
|
||||
for i, edge_vec1 in enumerate(edge_vec):
|
||||
adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
|
||||
adjacent_edge_vec = edge_vec[adjacent_ind]
|
||||
temp_theta_sum = np.sum(
|
||||
self.vector_angle(edge_vec1, adjacent_edge_vec))
|
||||
temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
|
||||
adjacent_edge_vec[1])
|
||||
theta_sum.append(temp_theta_sum)
|
||||
adjacent_vec_theta.append(temp_adjacent_theta)
|
||||
theta_sum_score = np.array(theta_sum) / np.pi
|
||||
adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
|
||||
poly_center = np.mean(points, axis=0)
|
||||
edge_dist = np.maximum(
|
||||
norm(
|
||||
pad_points[1:] - poly_center, axis=-1),
|
||||
norm(
|
||||
pad_points[:-1] - poly_center, axis=-1))
|
||||
dist_score = edge_dist / np.max(edge_dist)
|
||||
position_score = np.zeros(len(edge_vec))
|
||||
score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
|
||||
score += 0.35 * dist_score
|
||||
if len(points) % 2 == 0:
|
||||
position_score[(len(score) // 2 - 1)] += 1
|
||||
position_score[-1] += 1
|
||||
score += 0.1 * position_score
|
||||
pad_score = np.concatenate([score, score])
|
||||
score_matrix = np.zeros((len(score), len(score) - 3))
|
||||
x = np.arange(len(score) - 3) / float(len(score) - 4)
|
||||
gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
|
||||
(x - 0.5) / 0.5, 2.) / 2)
|
||||
gaussian = gaussian / np.max(gaussian)
|
||||
for i in range(len(score)):
|
||||
score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
|
||||
score) - 1)] * gaussian * 0.3
|
||||
|
||||
head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
|
||||
score_matrix.shape)
|
||||
tail_start = (head_start + tail_increment + 2) % len(points)
|
||||
head_end = (head_start + 1) % len(points)
|
||||
tail_end = (tail_start + 1) % len(points)
|
||||
|
||||
if head_end > tail_end:
|
||||
head_start, tail_start = tail_start, head_start
|
||||
head_end, tail_end = tail_end, head_end
|
||||
head_inds = [head_start, head_end]
|
||||
tail_inds = [tail_start, tail_end]
|
||||
else:
|
||||
if self.vector_slope(points[1] - points[0]) + self.vector_slope(
|
||||
points[3] - points[2]) < self.vector_slope(points[
|
||||
2] - points[1]) + self.vector_slope(points[0] - points[
|
||||
3]):
|
||||
horizontal_edge_inds = [[0, 1], [2, 3]]
|
||||
vertical_edge_inds = [[3, 0], [1, 2]]
|
||||
else:
|
||||
horizontal_edge_inds = [[3, 0], [1, 2]]
|
||||
vertical_edge_inds = [[0, 1], [2, 3]]
|
||||
|
||||
vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
|
||||
vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
|
||||
0]] - points[vertical_edge_inds[1][1]])
|
||||
horizontal_len_sum = norm(points[horizontal_edge_inds[0][
|
||||
0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
|
||||
horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
|
||||
[1]])
|
||||
|
||||
if vertical_len_sum > horizontal_len_sum * orientation_thr:
|
||||
head_inds = horizontal_edge_inds[0]
|
||||
tail_inds = horizontal_edge_inds[1]
|
||||
else:
|
||||
head_inds = vertical_edge_inds[0]
|
||||
tail_inds = vertical_edge_inds[1]
|
||||
|
||||
return head_inds, tail_inds
|
||||
|
||||
def resample_sidelines(self, sideline1, sideline2, resample_step):
|
||||
"""Resample two sidelines to be of the same points number according to
|
||||
step size.
|
||||
|
||||
Args:
|
||||
sideline1 (ndarray): The points composing a sideline of a text
|
||||
polygon.
|
||||
sideline2 (ndarray): The points composing another sideline of a
|
||||
text polygon.
|
||||
resample_step (float): The resampled step size.
|
||||
|
||||
Returns:
|
||||
resampled_line1 (ndarray): The resampled line 1.
|
||||
resampled_line2 (ndarray): The resampled line 2.
|
||||
"""
|
||||
|
||||
assert sideline1.ndim == sideline2.ndim == 2
|
||||
assert sideline1.shape[1] == sideline2.shape[1] == 2
|
||||
assert sideline1.shape[0] >= 2
|
||||
assert sideline2.shape[0] >= 2
|
||||
assert isinstance(resample_step, float)
|
||||
|
||||
length1 = sum([
|
||||
norm(sideline1[i + 1] - sideline1[i])
|
||||
for i in range(len(sideline1) - 1)
|
||||
])
|
||||
length2 = sum([
|
||||
norm(sideline2[i + 1] - sideline2[i])
|
||||
for i in range(len(sideline2) - 1)
|
||||
])
|
||||
|
||||
total_length = (length1 + length2) / 2
|
||||
resample_point_num = max(int(float(total_length) / resample_step), 1)
|
||||
|
||||
resampled_line1 = self.resample_line(sideline1, resample_point_num)
|
||||
resampled_line2 = self.resample_line(sideline2, resample_point_num)
|
||||
|
||||
return resampled_line1, resampled_line2
|
||||
|
||||
def generate_center_region_mask(self, img_size, text_polys):
|
||||
"""Generate text center region mask.
|
||||
|
||||
Args:
|
||||
img_size (tuple): The image size of (height, width).
|
||||
text_polys (list[list[ndarray]]): The list of text polygons.
|
||||
|
||||
Returns:
|
||||
center_region_mask (ndarray): The text center region mask.
|
||||
"""
|
||||
|
||||
assert isinstance(img_size, tuple)
|
||||
# assert check_argument.is_2dlist(text_polys)
|
||||
|
||||
h, w = img_size
|
||||
|
||||
center_region_mask = np.zeros((h, w), np.uint8)
|
||||
|
||||
center_region_boxes = []
|
||||
for poly in text_polys:
|
||||
# assert len(poly) == 1
|
||||
polygon_points = poly.reshape(-1, 2)
|
||||
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
|
||||
resampled_top_line, resampled_bot_line = self.resample_sidelines(
|
||||
top_line, bot_line, self.resample_step)
|
||||
resampled_bot_line = resampled_bot_line[::-1]
|
||||
center_line = (resampled_top_line + resampled_bot_line) / 2
|
||||
|
||||
line_head_shrink_len = norm(resampled_top_line[0] -
|
||||
resampled_bot_line[0]) / 4.0
|
||||
line_tail_shrink_len = norm(resampled_top_line[-1] -
|
||||
resampled_bot_line[-1]) / 4.0
|
||||
head_shrink_num = int(line_head_shrink_len // self.resample_step)
|
||||
tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
|
||||
if len(center_line) > head_shrink_num + tail_shrink_num + 2:
|
||||
center_line = center_line[head_shrink_num:len(center_line) -
|
||||
tail_shrink_num]
|
||||
resampled_top_line = resampled_top_line[head_shrink_num:len(
|
||||
resampled_top_line) - tail_shrink_num]
|
||||
resampled_bot_line = resampled_bot_line[head_shrink_num:len(
|
||||
resampled_bot_line) - tail_shrink_num]
|
||||
|
||||
for i in range(0, len(center_line) - 1):
|
||||
tl = center_line[i] + (resampled_top_line[i] - center_line[i]
|
||||
) * self.center_region_shrink_ratio
|
||||
tr = center_line[i + 1] + (resampled_top_line[i + 1] -
|
||||
center_line[i + 1]
|
||||
) * self.center_region_shrink_ratio
|
||||
br = center_line[i + 1] + (resampled_bot_line[i + 1] -
|
||||
center_line[i + 1]
|
||||
) * self.center_region_shrink_ratio
|
||||
bl = center_line[i] + (resampled_bot_line[i] - center_line[i]
|
||||
) * self.center_region_shrink_ratio
|
||||
current_center_box = np.vstack([tl, tr, br,
|
||||
bl]).astype(np.int32)
|
||||
center_region_boxes.append(current_center_box)
|
||||
|
||||
cv2.fillPoly(center_region_mask, center_region_boxes, 1)
|
||||
return center_region_mask
|
||||
|
||||
def resample_polygon(self, polygon, n=400):
|
||||
"""Resample one polygon with n points on its boundary.
|
||||
|
||||
Args:
|
||||
polygon (list[float]): The input polygon.
|
||||
n (int): The number of resampled points.
|
||||
Returns:
|
||||
resampled_polygon (list[float]): The resampled polygon.
|
||||
"""
|
||||
length = []
|
||||
|
||||
for i in range(len(polygon)):
|
||||
p1 = polygon[i]
|
||||
if i == len(polygon) - 1:
|
||||
p2 = polygon[0]
|
||||
else:
|
||||
p2 = polygon[i + 1]
|
||||
length.append(((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5)
|
||||
|
||||
total_length = sum(length)
|
||||
n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
|
||||
n_on_each_line = n_on_each_line.astype(np.int32)
|
||||
new_polygon = []
|
||||
|
||||
for i in range(len(polygon)):
|
||||
num = n_on_each_line[i]
|
||||
p1 = polygon[i]
|
||||
if i == len(polygon) - 1:
|
||||
p2 = polygon[0]
|
||||
else:
|
||||
p2 = polygon[i + 1]
|
||||
|
||||
if num == 0:
|
||||
continue
|
||||
|
||||
dxdy = (p2 - p1) / num
|
||||
for j in range(num):
|
||||
point = p1 + dxdy * j
|
||||
new_polygon.append(point)
|
||||
|
||||
return np.array(new_polygon)
|
||||
|
||||
def normalize_polygon(self, polygon):
|
||||
"""Normalize one polygon so that its start point is at right most.
|
||||
|
||||
Args:
|
||||
polygon (list[float]): The origin polygon.
|
||||
Returns:
|
||||
new_polygon (lost[float]): The polygon with start point at right.
|
||||
"""
|
||||
temp_polygon = polygon - polygon.mean(axis=0)
|
||||
x = np.abs(temp_polygon[:, 0])
|
||||
y = temp_polygon[:, 1]
|
||||
index_x = np.argsort(x)
|
||||
index_y = np.argmin(y[index_x[:8]])
|
||||
index = index_x[index_y]
|
||||
new_polygon = np.concatenate([polygon[index:], polygon[:index]])
|
||||
return new_polygon
|
||||
|
||||
def poly2fourier(self, polygon, fourier_degree):
|
||||
"""Perform Fourier transformation to generate Fourier coefficients ck
|
||||
from polygon.
|
||||
|
||||
Args:
|
||||
polygon (ndarray): An input polygon.
|
||||
fourier_degree (int): The maximum Fourier degree K.
|
||||
Returns:
|
||||
c (ndarray(complex)): Fourier coefficients.
|
||||
"""
|
||||
points = polygon[:, 0] + polygon[:, 1] * 1j
|
||||
c_fft = fft(points) / len(points)
|
||||
c = np.hstack((c_fft[-fourier_degree:], c_fft[:fourier_degree + 1]))
|
||||
return c
|
||||
|
||||
def clockwise(self, c, fourier_degree):
|
||||
"""Make sure the polygon reconstructed from Fourier coefficients c in
|
||||
the clockwise direction.
|
||||
|
||||
Args:
|
||||
polygon (list[float]): The origin polygon.
|
||||
Returns:
|
||||
new_polygon (lost[float]): The polygon in clockwise point order.
|
||||
"""
|
||||
if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]):
|
||||
return c
|
||||
elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]):
|
||||
return c[::-1]
|
||||
else:
|
||||
if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]):
|
||||
return c
|
||||
else:
|
||||
return c[::-1]
|
||||
|
||||
def cal_fourier_signature(self, polygon, fourier_degree):
|
||||
"""Calculate Fourier signature from input polygon.
|
||||
|
||||
Args:
|
||||
polygon (ndarray): The input polygon.
|
||||
fourier_degree (int): The maximum Fourier degree K.
|
||||
Returns:
|
||||
fourier_signature (ndarray): An array shaped (2k+1, 2) containing
|
||||
real part and image part of 2k+1 Fourier coefficients.
|
||||
"""
|
||||
resampled_polygon = self.resample_polygon(polygon)
|
||||
resampled_polygon = self.normalize_polygon(resampled_polygon)
|
||||
|
||||
fourier_coeff = self.poly2fourier(resampled_polygon, fourier_degree)
|
||||
fourier_coeff = self.clockwise(fourier_coeff, fourier_degree)
|
||||
|
||||
real_part = np.real(fourier_coeff).reshape((-1, 1))
|
||||
image_part = np.imag(fourier_coeff).reshape((-1, 1))
|
||||
fourier_signature = np.hstack([real_part, image_part])
|
||||
|
||||
return fourier_signature
|
||||
|
||||
def generate_fourier_maps(self, img_size, text_polys):
|
||||
"""Generate Fourier coefficient maps.
|
||||
|
||||
Args:
|
||||
img_size (tuple): The image size of (height, width).
|
||||
text_polys (list[list[ndarray]]): The list of text polygons.
|
||||
|
||||
Returns:
|
||||
fourier_real_map (ndarray): The Fourier coefficient real part maps.
|
||||
fourier_image_map (ndarray): The Fourier coefficient image part
|
||||
maps.
|
||||
"""
|
||||
|
||||
assert isinstance(img_size, tuple)
|
||||
|
||||
h, w = img_size
|
||||
k = self.fourier_degree
|
||||
real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
|
||||
imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
|
||||
|
||||
for poly in text_polys:
|
||||
mask = np.zeros((h, w), dtype=np.uint8)
|
||||
polygon = np.array(poly).reshape((1, -1, 2))
|
||||
cv2.fillPoly(mask, polygon.astype(np.int32), 1)
|
||||
fourier_coeff = self.cal_fourier_signature(polygon[0], k)
|
||||
for i in range(-k, k + 1):
|
||||
if i != 0:
|
||||
real_map[i + k, :, :] = mask * fourier_coeff[i + k, 0] + (
|
||||
1 - mask) * real_map[i + k, :, :]
|
||||
imag_map[i + k, :, :] = mask * fourier_coeff[i + k, 1] + (
|
||||
1 - mask) * imag_map[i + k, :, :]
|
||||
else:
|
||||
yx = np.argwhere(mask > 0.5)
|
||||
k_ind = np.ones((len(yx)), dtype=np.int64) * k
|
||||
y, x = yx[:, 0], yx[:, 1]
|
||||
real_map[k_ind, y, x] = fourier_coeff[k, 0] - x
|
||||
imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y
|
||||
|
||||
return real_map, imag_map
|
||||
|
||||
def generate_text_region_mask(self, img_size, text_polys):
|
||||
"""Generate text center region mask and geometry attribute maps.
|
||||
|
||||
Args:
|
||||
img_size (tuple): The image size (height, width).
|
||||
text_polys (list[list[ndarray]]): The list of text polygons.
|
||||
|
||||
Returns:
|
||||
text_region_mask (ndarray): The text region mask.
|
||||
"""
|
||||
|
||||
assert isinstance(img_size, tuple)
|
||||
|
||||
h, w = img_size
|
||||
text_region_mask = np.zeros((h, w), dtype=np.uint8)
|
||||
|
||||
for poly in text_polys:
|
||||
polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
|
||||
cv2.fillPoly(text_region_mask, polygon, 1)
|
||||
|
||||
return text_region_mask
|
||||
|
||||
def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
|
||||
"""Generate effective mask by setting the ineffective regions to 0 and
|
||||
effective regions to 1.
|
||||
|
||||
Args:
|
||||
mask_size (tuple): The mask size.
|
||||
polygons_ignore (list[[ndarray]]: The list of ignored text
|
||||
polygons.
|
||||
|
||||
Returns:
|
||||
mask (ndarray): The effective mask of (height, width).
|
||||
"""
|
||||
|
||||
mask = np.ones(mask_size, dtype=np.uint8)
|
||||
|
||||
for poly in polygons_ignore:
|
||||
instance = poly.reshape(-1, 2).astype(np.int32).reshape(1, -1, 2)
|
||||
cv2.fillPoly(mask, instance, 0)
|
||||
|
||||
return mask
|
||||
|
||||
def generate_level_targets(self, img_size, text_polys, ignore_polys):
|
||||
"""Generate ground truth target on each level.
|
||||
|
||||
Args:
|
||||
img_size (list[int]): Shape of input image.
|
||||
text_polys (list[list[ndarray]]): A list of ground truth polygons.
|
||||
ignore_polys (list[list[ndarray]]): A list of ignored polygons.
|
||||
Returns:
|
||||
level_maps (list(ndarray)): A list of ground target on each level.
|
||||
"""
|
||||
h, w = img_size
|
||||
lv_size_divs = self.level_size_divisors
|
||||
lv_proportion_range = self.level_proportion_range
|
||||
lv_text_polys = [[] for i in range(len(lv_size_divs))]
|
||||
lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
|
||||
level_maps = []
|
||||
for poly in text_polys:
|
||||
polygon = np.array(poly, dtype=np.int).reshape((1, -1, 2))
|
||||
_, _, box_w, box_h = cv2.boundingRect(polygon)
|
||||
proportion = max(box_h, box_w) / (h + 1e-8)
|
||||
|
||||
for ind, proportion_range in enumerate(lv_proportion_range):
|
||||
if proportion_range[0] < proportion < proportion_range[1]:
|
||||
lv_text_polys[ind].append(poly / lv_size_divs[ind])
|
||||
|
||||
for ignore_poly in ignore_polys:
|
||||
polygon = np.array(ignore_poly, dtype=np.int).reshape((1, -1, 2))
|
||||
_, _, box_w, box_h = cv2.boundingRect(polygon)
|
||||
proportion = max(box_h, box_w) / (h + 1e-8)
|
||||
|
||||
for ind, proportion_range in enumerate(lv_proportion_range):
|
||||
if proportion_range[0] < proportion < proportion_range[1]:
|
||||
lv_ignore_polys[ind].append(ignore_poly / lv_size_divs[ind])
|
||||
|
||||
for ind, size_divisor in enumerate(lv_size_divs):
|
||||
current_level_maps = []
|
||||
level_img_size = (h // size_divisor, w // size_divisor)
|
||||
|
||||
text_region = self.generate_text_region_mask(
|
||||
level_img_size, lv_text_polys[ind])[None]
|
||||
current_level_maps.append(text_region)
|
||||
|
||||
center_region = self.generate_center_region_mask(
|
||||
level_img_size, lv_text_polys[ind])[None]
|
||||
current_level_maps.append(center_region)
|
||||
|
||||
effective_mask = self.generate_effective_mask(
|
||||
level_img_size, lv_ignore_polys[ind])[None]
|
||||
current_level_maps.append(effective_mask)
|
||||
|
||||
fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
|
||||
level_img_size, lv_text_polys[ind])
|
||||
current_level_maps.append(fourier_real_map)
|
||||
current_level_maps.append(fourier_image_maps)
|
||||
|
||||
level_maps.append(np.concatenate(current_level_maps))
|
||||
|
||||
return level_maps
|
||||
|
||||
def generate_targets(self, results):
|
||||
"""Generate the ground truth targets for FCENet.
|
||||
|
||||
Args:
|
||||
results (dict): The input result dictionary.
|
||||
|
||||
Returns:
|
||||
results (dict): The output result dictionary.
|
||||
"""
|
||||
|
||||
assert isinstance(results, dict)
|
||||
image = results['image']
|
||||
polygons = results['polys']
|
||||
ignore_tags = results['ignore_tags']
|
||||
h, w, _ = image.shape
|
||||
|
||||
polygon_masks = []
|
||||
polygon_masks_ignore = []
|
||||
for tag, polygon in zip(ignore_tags, polygons):
|
||||
if tag is True:
|
||||
polygon_masks_ignore.append(polygon)
|
||||
else:
|
||||
polygon_masks.append(polygon)
|
||||
|
||||
level_maps = self.generate_level_targets((h, w), polygon_masks,
|
||||
polygon_masks_ignore)
|
||||
|
||||
mapping = {
|
||||
'p3_maps': level_maps[0],
|
||||
'p4_maps': level_maps[1],
|
||||
'p5_maps': level_maps[2]
|
||||
}
|
||||
for key, value in mapping.items():
|
||||
results[key] = value
|
||||
|
||||
return results
|
||||
|
||||
def __call__(self, results):
|
||||
results = self.generate_targets(results)
|
||||
return results
|
||||
@@ -1,244 +0,0 @@
|
||||
"""
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import sys
|
||||
import six
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class GenTableMask(object):
|
||||
""" gen table mask """
|
||||
|
||||
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
|
||||
self.shrink_h_max = 5
|
||||
self.shrink_w_max = 5
|
||||
self.mask_type = mask_type
|
||||
|
||||
def projection(self, erosion, h, w, spilt_threshold=0):
|
||||
# 水平投影
|
||||
projection_map = np.ones_like(erosion)
|
||||
project_val_array = [0 for _ in range(0, h)]
|
||||
|
||||
for j in range(0, h):
|
||||
for i in range(0, w):
|
||||
if erosion[j, i] == 255:
|
||||
project_val_array[j] += 1
|
||||
# 根据数组,获取切割点
|
||||
start_idx = 0 # 记录进入字符区的索引
|
||||
end_idx = 0 # 记录进入空白区域的索引
|
||||
in_text = False # 是否遍历到了字符区内
|
||||
box_list = []
|
||||
for i in range(len(project_val_array)):
|
||||
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
|
||||
in_text = True
|
||||
start_idx = i
|
||||
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
|
||||
end_idx = i
|
||||
in_text = False
|
||||
if end_idx - start_idx <= 2:
|
||||
continue
|
||||
box_list.append((start_idx, end_idx + 1))
|
||||
|
||||
if in_text:
|
||||
box_list.append((start_idx, h - 1))
|
||||
# 绘制投影直方图
|
||||
for j in range(0, h):
|
||||
for i in range(0, project_val_array[j]):
|
||||
projection_map[j, i] = 0
|
||||
return box_list, projection_map
|
||||
|
||||
def projection_cx(self, box_img):
|
||||
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
|
||||
h, w = box_gray_img.shape
|
||||
# 灰度图片进行二值化处理
|
||||
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
|
||||
# 纵向腐蚀
|
||||
if h < w:
|
||||
kernel = np.ones((2, 1), np.uint8)
|
||||
erode = cv2.erode(thresh1, kernel, iterations=1)
|
||||
else:
|
||||
erode = thresh1
|
||||
# 水平膨胀
|
||||
kernel = np.ones((1, 5), np.uint8)
|
||||
erosion = cv2.dilate(erode, kernel, iterations=1)
|
||||
# 水平投影
|
||||
projection_map = np.ones_like(erosion)
|
||||
project_val_array = [0 for _ in range(0, h)]
|
||||
|
||||
for j in range(0, h):
|
||||
for i in range(0, w):
|
||||
if erosion[j, i] == 255:
|
||||
project_val_array[j] += 1
|
||||
# 根据数组,获取切割点
|
||||
start_idx = 0 # 记录进入字符区的索引
|
||||
end_idx = 0 # 记录进入空白区域的索引
|
||||
in_text = False # 是否遍历到了字符区内
|
||||
box_list = []
|
||||
spilt_threshold = 0
|
||||
for i in range(len(project_val_array)):
|
||||
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
|
||||
in_text = True
|
||||
start_idx = i
|
||||
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
|
||||
end_idx = i
|
||||
in_text = False
|
||||
if end_idx - start_idx <= 2:
|
||||
continue
|
||||
box_list.append((start_idx, end_idx + 1))
|
||||
|
||||
if in_text:
|
||||
box_list.append((start_idx, h - 1))
|
||||
# 绘制投影直方图
|
||||
for j in range(0, h):
|
||||
for i in range(0, project_val_array[j]):
|
||||
projection_map[j, i] = 0
|
||||
split_bbox_list = []
|
||||
if len(box_list) > 1:
|
||||
for i, (h_start, h_end) in enumerate(box_list):
|
||||
if i == 0:
|
||||
h_start = 0
|
||||
if i == len(box_list):
|
||||
h_end = h
|
||||
word_img = erosion[h_start:h_end + 1, :]
|
||||
word_h, word_w = word_img.shape
|
||||
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
|
||||
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
|
||||
if h_start > 0:
|
||||
h_start -= 1
|
||||
h_end += 1
|
||||
word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
|
||||
split_bbox_list.append([w_start, h_start, w_end, h_end])
|
||||
else:
|
||||
split_bbox_list.append([0, 0, w, h])
|
||||
return split_bbox_list
|
||||
|
||||
def shrink_bbox(self, bbox):
|
||||
left, top, right, bottom = bbox
|
||||
sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
|
||||
sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
|
||||
left_new = left + sh_w
|
||||
right_new = right - sh_w
|
||||
top_new = top + sh_h
|
||||
bottom_new = bottom - sh_h
|
||||
if left_new >= right_new:
|
||||
left_new = left
|
||||
right_new = right
|
||||
if top_new >= bottom_new:
|
||||
top_new = top
|
||||
bottom_new = bottom
|
||||
return [left_new, top_new, right_new, bottom_new]
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
cells = data['cells']
|
||||
height, width = img.shape[0:2]
|
||||
if self.mask_type == 1:
|
||||
mask_img = np.zeros((height, width), dtype=np.float32)
|
||||
else:
|
||||
mask_img = np.zeros((height, width, 3), dtype=np.float32)
|
||||
cell_num = len(cells)
|
||||
for cno in range(cell_num):
|
||||
if "bbox" in cells[cno]:
|
||||
bbox = cells[cno]['bbox']
|
||||
left, top, right, bottom = bbox
|
||||
box_img = img[top:bottom, left:right, :].copy()
|
||||
split_bbox_list = self.projection_cx(box_img)
|
||||
for sno in range(len(split_bbox_list)):
|
||||
split_bbox_list[sno][0] += left
|
||||
split_bbox_list[sno][1] += top
|
||||
split_bbox_list[sno][2] += left
|
||||
split_bbox_list[sno][3] += top
|
||||
|
||||
for sno in range(len(split_bbox_list)):
|
||||
left, top, right, bottom = split_bbox_list[sno]
|
||||
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
|
||||
if self.mask_type == 1:
|
||||
mask_img[top:bottom, left:right] = 1.0
|
||||
data['mask_img'] = mask_img
|
||||
else:
|
||||
mask_img[top:bottom, left:right, :] = (255, 255, 255)
|
||||
data['image'] = mask_img
|
||||
return data
|
||||
|
||||
class ResizeTableImage(object):
|
||||
def __init__(self, max_len, **kwargs):
|
||||
super(ResizeTableImage, self).__init__()
|
||||
self.max_len = max_len
|
||||
|
||||
def get_img_bbox(self, cells):
|
||||
bbox_list = []
|
||||
if len(cells) == 0:
|
||||
return bbox_list
|
||||
cell_num = len(cells)
|
||||
for cno in range(cell_num):
|
||||
if "bbox" in cells[cno]:
|
||||
bbox = cells[cno]['bbox']
|
||||
bbox_list.append(bbox)
|
||||
return bbox_list
|
||||
|
||||
def resize_img_table(self, img, bbox_list, max_len):
|
||||
height, width = img.shape[0:2]
|
||||
ratio = max_len / (max(height, width) * 1.0)
|
||||
resize_h = int(height * ratio)
|
||||
resize_w = int(width * ratio)
|
||||
img_new = cv2.resize(img, (resize_w, resize_h))
|
||||
bbox_list_new = []
|
||||
for bno in range(len(bbox_list)):
|
||||
left, top, right, bottom = bbox_list[bno].copy()
|
||||
left = int(left * ratio)
|
||||
top = int(top * ratio)
|
||||
right = int(right * ratio)
|
||||
bottom = int(bottom * ratio)
|
||||
bbox_list_new.append([left, top, right, bottom])
|
||||
return img_new, bbox_list_new
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if 'cells' not in data:
|
||||
cells = []
|
||||
else:
|
||||
cells = data['cells']
|
||||
bbox_list = self.get_img_bbox(cells)
|
||||
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
|
||||
data['image'] = img_new
|
||||
cell_num = len(cells)
|
||||
bno = 0
|
||||
for cno in range(cell_num):
|
||||
if "bbox" in data['cells'][cno]:
|
||||
data['cells'][cno]['bbox'] = bbox_list_new[bno]
|
||||
bno += 1
|
||||
data['max_len'] = self.max_len
|
||||
return data
|
||||
|
||||
class PaddingTableImage(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(PaddingTableImage, self).__init__()
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
max_len = data['max_len']
|
||||
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
|
||||
height, width = img.shape[0:2]
|
||||
padding_img[0:height, 0:width, :] = img.copy()
|
||||
data['image'] = padding_img
|
||||
return data
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/iaa_augment.py
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import numpy as np
|
||||
import imgaug
|
||||
import imgaug.augmenters as iaa
|
||||
|
||||
|
||||
class AugmenterBuilder(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def build(self, args, root=True):
|
||||
if args is None or len(args) == 0:
|
||||
return None
|
||||
elif isinstance(args, list):
|
||||
if root:
|
||||
sequence = [self.build(value, root=False) for value in args]
|
||||
return iaa.Sequential(sequence)
|
||||
else:
|
||||
return getattr(iaa, args[0])(
|
||||
*[self.to_tuple_if_list(a) for a in args[1:]])
|
||||
elif isinstance(args, dict):
|
||||
cls = getattr(iaa, args['type'])
|
||||
return cls(**{
|
||||
k: self.to_tuple_if_list(v)
|
||||
for k, v in args['args'].items()
|
||||
})
|
||||
else:
|
||||
raise RuntimeError('unknown augmenter arg: ' + str(args))
|
||||
|
||||
def to_tuple_if_list(self, obj):
|
||||
if isinstance(obj, list):
|
||||
return tuple(obj)
|
||||
return obj
|
||||
|
||||
|
||||
class IaaAugment():
|
||||
def __init__(self, augmenter_args=None, **kwargs):
|
||||
if augmenter_args is None:
|
||||
augmenter_args = [{
|
||||
'type': 'Fliplr',
|
||||
'args': {
|
||||
'p': 0.5
|
||||
}
|
||||
}, {
|
||||
'type': 'Affine',
|
||||
'args': {
|
||||
'rotate': [-10, 10]
|
||||
}
|
||||
}, {
|
||||
'type': 'Resize',
|
||||
'args': {
|
||||
'size': [0.5, 3]
|
||||
}
|
||||
}]
|
||||
self.augmenter = AugmenterBuilder().build(augmenter_args)
|
||||
|
||||
def __call__(self, data):
|
||||
image = data['image']
|
||||
shape = image.shape
|
||||
|
||||
if self.augmenter:
|
||||
aug = self.augmenter.to_deterministic()
|
||||
data['image'] = aug.augment_image(image)
|
||||
data = self.may_augment_annotation(aug, data, shape)
|
||||
return data
|
||||
|
||||
def may_augment_annotation(self, aug, data, shape):
|
||||
if aug is None:
|
||||
return data
|
||||
|
||||
line_polys = []
|
||||
for poly in data['polys']:
|
||||
new_poly = self.may_augment_poly(aug, shape, poly)
|
||||
line_polys.append(new_poly)
|
||||
data['polys'] = np.array(line_polys)
|
||||
return data
|
||||
|
||||
def may_augment_poly(self, aug, img_shape, poly):
|
||||
keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
|
||||
keypoints = aug.augment_keypoints(
|
||||
[imgaug.KeypointsOnImage(
|
||||
keypoints, shape=img_shape)])[0].keypoints
|
||||
poly = [(p.x, p.y) for p in keypoints]
|
||||
return poly
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,173 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
np.seterr(divide='ignore', invalid='ignore')
|
||||
import pyclipper
|
||||
from shapely.geometry import Polygon
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
__all__ = ['MakeBorderMap']
|
||||
|
||||
|
||||
class MakeBorderMap(object):
|
||||
def __init__(self,
|
||||
shrink_ratio=0.4,
|
||||
thresh_min=0.3,
|
||||
thresh_max=0.7,
|
||||
**kwargs):
|
||||
self.shrink_ratio = shrink_ratio
|
||||
self.thresh_min = thresh_min
|
||||
self.thresh_max = thresh_max
|
||||
|
||||
def __call__(self, data):
|
||||
|
||||
img = data['image']
|
||||
text_polys = data['polys']
|
||||
ignore_tags = data['ignore_tags']
|
||||
|
||||
canvas = np.zeros(img.shape[:2], dtype=np.float32)
|
||||
mask = np.zeros(img.shape[:2], dtype=np.float32)
|
||||
|
||||
for i in range(len(text_polys)):
|
||||
if ignore_tags[i]:
|
||||
continue
|
||||
self.draw_border_map(text_polys[i], canvas, mask=mask)
|
||||
canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
|
||||
|
||||
data['threshold_map'] = canvas
|
||||
data['threshold_mask'] = mask
|
||||
return data
|
||||
|
||||
def draw_border_map(self, polygon, canvas, mask):
|
||||
polygon = np.array(polygon)
|
||||
assert polygon.ndim == 2
|
||||
assert polygon.shape[1] == 2
|
||||
|
||||
polygon_shape = Polygon(polygon)
|
||||
if polygon_shape.area <= 0:
|
||||
return
|
||||
distance = polygon_shape.area * (
|
||||
1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
|
||||
subject = [tuple(l) for l in polygon]
|
||||
padding = pyclipper.PyclipperOffset()
|
||||
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
||||
|
||||
padded_polygon = np.array(padding.Execute(distance)[0])
|
||||
cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
|
||||
|
||||
xmin = padded_polygon[:, 0].min()
|
||||
xmax = padded_polygon[:, 0].max()
|
||||
ymin = padded_polygon[:, 1].min()
|
||||
ymax = padded_polygon[:, 1].max()
|
||||
width = xmax - xmin + 1
|
||||
height = ymax - ymin + 1
|
||||
|
||||
polygon[:, 0] = polygon[:, 0] - xmin
|
||||
polygon[:, 1] = polygon[:, 1] - ymin
|
||||
|
||||
xs = np.broadcast_to(
|
||||
np.linspace(
|
||||
0, width - 1, num=width).reshape(1, width), (height, width))
|
||||
ys = np.broadcast_to(
|
||||
np.linspace(
|
||||
0, height - 1, num=height).reshape(height, 1), (height, width))
|
||||
|
||||
distance_map = np.zeros(
|
||||
(polygon.shape[0], height, width), dtype=np.float32)
|
||||
for i in range(polygon.shape[0]):
|
||||
j = (i + 1) % polygon.shape[0]
|
||||
absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
|
||||
distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
|
||||
distance_map = distance_map.min(axis=0)
|
||||
|
||||
xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
|
||||
xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
|
||||
ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
|
||||
ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
|
||||
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
|
||||
1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
|
||||
xmin_valid - xmin:xmax_valid - xmax + width],
|
||||
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
|
||||
|
||||
def _distance(self, xs, ys, point_1, point_2):
|
||||
'''
|
||||
compute the distance from point to a line
|
||||
ys: coordinates in the first axis
|
||||
xs: coordinates in the second axis
|
||||
point_1, point_2: (x, y), the end of the line
|
||||
'''
|
||||
height, width = xs.shape[:2]
|
||||
square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[
|
||||
1])
|
||||
square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[
|
||||
1])
|
||||
square_distance = np.square(point_1[0] - point_2[0]) + np.square(
|
||||
point_1[1] - point_2[1])
|
||||
|
||||
cosin = (square_distance - square_distance_1 - square_distance_2) / (
|
||||
2 * np.sqrt(square_distance_1 * square_distance_2))
|
||||
square_sin = 1 - np.square(cosin)
|
||||
square_sin = np.nan_to_num(square_sin)
|
||||
result = np.sqrt(square_distance_1 * square_distance_2 * square_sin /
|
||||
square_distance)
|
||||
|
||||
result[cosin <
|
||||
0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin
|
||||
< 0]
|
||||
# self.extend_line(point_1, point_2, result)
|
||||
return result
|
||||
|
||||
def extend_line(self, point_1, point_2, result, shrink_ratio):
|
||||
ex_point_1 = (int(
|
||||
round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
|
||||
int(
|
||||
round(point_1[1] + (point_1[1] - point_2[1]) * (
|
||||
1 + shrink_ratio))))
|
||||
cv2.line(
|
||||
result,
|
||||
tuple(ex_point_1),
|
||||
tuple(point_1),
|
||||
4096.0,
|
||||
1,
|
||||
lineType=cv2.LINE_AA,
|
||||
shift=0)
|
||||
ex_point_2 = (int(
|
||||
round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
|
||||
int(
|
||||
round(point_2[1] + (point_2[1] - point_1[1]) * (
|
||||
1 + shrink_ratio))))
|
||||
cv2.line(
|
||||
result,
|
||||
tuple(ex_point_2),
|
||||
tuple(point_2),
|
||||
4096.0,
|
||||
1,
|
||||
lineType=cv2.LINE_AA,
|
||||
shift=0)
|
||||
return ex_point_1, ex_point_2
|
||||
@@ -1,106 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pyclipper
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
__all__ = ['MakePseGt']
|
||||
|
||||
|
||||
class MakePseGt(object):
|
||||
def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
|
||||
self.kernel_num = kernel_num
|
||||
self.min_shrink_ratio = min_shrink_ratio
|
||||
self.size = size
|
||||
|
||||
def __call__(self, data):
|
||||
|
||||
image = data['image']
|
||||
text_polys = data['polys']
|
||||
ignore_tags = data['ignore_tags']
|
||||
|
||||
h, w, _ = image.shape
|
||||
short_edge = min(h, w)
|
||||
if short_edge < self.size:
|
||||
# keep short_size >= self.size
|
||||
scale = self.size / short_edge
|
||||
image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
|
||||
text_polys *= scale
|
||||
|
||||
gt_kernels = []
|
||||
for i in range(1, self.kernel_num + 1):
|
||||
# s1->sn, from big to small
|
||||
rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1
|
||||
) * i
|
||||
text_kernel, ignore_tags = self.generate_kernel(
|
||||
image.shape[0:2], rate, text_polys, ignore_tags)
|
||||
gt_kernels.append(text_kernel)
|
||||
|
||||
training_mask = np.ones(image.shape[0:2], dtype='uint8')
|
||||
for i in range(text_polys.shape[0]):
|
||||
if ignore_tags[i]:
|
||||
cv2.fillPoly(training_mask,
|
||||
text_polys[i].astype(np.int32)[np.newaxis, :, :],
|
||||
0)
|
||||
|
||||
gt_kernels = np.array(gt_kernels)
|
||||
gt_kernels[gt_kernels > 0] = 1
|
||||
|
||||
data['image'] = image
|
||||
data['polys'] = text_polys
|
||||
data['gt_kernels'] = gt_kernels[0:]
|
||||
data['gt_text'] = gt_kernels[0]
|
||||
data['mask'] = training_mask.astype('float32')
|
||||
return data
|
||||
|
||||
def generate_kernel(self,
|
||||
img_size,
|
||||
shrink_ratio,
|
||||
text_polys,
|
||||
ignore_tags=None):
|
||||
"""
|
||||
Refer to part of the code:
|
||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py
|
||||
"""
|
||||
|
||||
h, w = img_size
|
||||
text_kernel = np.zeros((h, w), dtype=np.float32)
|
||||
for i, poly in enumerate(text_polys):
|
||||
polygon = Polygon(poly)
|
||||
distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (
|
||||
polygon.length + 1e-6)
|
||||
subject = [tuple(l) for l in poly]
|
||||
pco = pyclipper.PyclipperOffset()
|
||||
pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
||||
shrinked = np.array(pco.Execute(-distance))
|
||||
|
||||
if len(shrinked) == 0 or shrinked.size == 0:
|
||||
if ignore_tags is not None:
|
||||
ignore_tags[i] = True
|
||||
continue
|
||||
try:
|
||||
shrinked = np.array(shrinked[0]).reshape(-1, 2)
|
||||
except:
|
||||
if ignore_tags is not None:
|
||||
ignore_tags[i] = True
|
||||
continue
|
||||
cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
|
||||
return text_kernel, ignore_tags
|
||||
@@ -1,123 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_shrink_map.py
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from shapely.geometry import Polygon
|
||||
import pyclipper
|
||||
|
||||
__all__ = ['MakeShrinkMap']
|
||||
|
||||
|
||||
class MakeShrinkMap(object):
|
||||
r'''
|
||||
Making binary mask from detection data with ICDAR format.
|
||||
Typically following the process of class `MakeICDARData`.
|
||||
'''
|
||||
|
||||
def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
|
||||
self.min_text_size = min_text_size
|
||||
self.shrink_ratio = shrink_ratio
|
||||
|
||||
def __call__(self, data):
|
||||
image = data['image']
|
||||
text_polys = data['polys']
|
||||
ignore_tags = data['ignore_tags']
|
||||
|
||||
h, w = image.shape[:2]
|
||||
text_polys, ignore_tags = self.validate_polygons(text_polys,
|
||||
ignore_tags, h, w)
|
||||
gt = np.zeros((h, w), dtype=np.float32)
|
||||
mask = np.ones((h, w), dtype=np.float32)
|
||||
for i in range(len(text_polys)):
|
||||
polygon = text_polys[i]
|
||||
height = max(polygon[:, 1]) - min(polygon[:, 1])
|
||||
width = max(polygon[:, 0]) - min(polygon[:, 0])
|
||||
if ignore_tags[i] or min(height, width) < self.min_text_size:
|
||||
cv2.fillPoly(mask,
|
||||
polygon.astype(np.int32)[np.newaxis, :, :], 0)
|
||||
ignore_tags[i] = True
|
||||
else:
|
||||
polygon_shape = Polygon(polygon)
|
||||
subject = [tuple(l) for l in polygon]
|
||||
padding = pyclipper.PyclipperOffset()
|
||||
padding.AddPath(subject, pyclipper.JT_ROUND,
|
||||
pyclipper.ET_CLOSEDPOLYGON)
|
||||
shrinked = []
|
||||
|
||||
# Increase the shrink ratio every time we get multiple polygon returned back
|
||||
possible_ratios = np.arange(self.shrink_ratio, 1,
|
||||
self.shrink_ratio)
|
||||
np.append(possible_ratios, 1)
|
||||
# print(possible_ratios)
|
||||
for ratio in possible_ratios:
|
||||
# print(f"Change shrink ratio to {ratio}")
|
||||
distance = polygon_shape.area * (
|
||||
1 - np.power(ratio, 2)) / polygon_shape.length
|
||||
shrinked = padding.Execute(-distance)
|
||||
if len(shrinked) == 1:
|
||||
break
|
||||
|
||||
if shrinked == []:
|
||||
cv2.fillPoly(mask,
|
||||
polygon.astype(np.int32)[np.newaxis, :, :], 0)
|
||||
ignore_tags[i] = True
|
||||
continue
|
||||
|
||||
for each_shirnk in shrinked:
|
||||
shirnk = np.array(each_shirnk).reshape(-1, 2)
|
||||
cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
|
||||
|
||||
data['shrink_map'] = gt
|
||||
data['shrink_mask'] = mask
|
||||
return data
|
||||
|
||||
def validate_polygons(self, polygons, ignore_tags, h, w):
|
||||
'''
|
||||
polygons (numpy.array, required): of shape (num_instances, num_points, 2)
|
||||
'''
|
||||
if len(polygons) == 0:
|
||||
return polygons, ignore_tags
|
||||
assert len(polygons) == len(ignore_tags)
|
||||
for polygon in polygons:
|
||||
polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
|
||||
polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
|
||||
|
||||
for i in range(len(polygons)):
|
||||
area = self.polygon_area(polygons[i])
|
||||
if abs(area) < 1:
|
||||
ignore_tags[i] = True
|
||||
if area > 0:
|
||||
polygons[i] = polygons[i][::-1, :]
|
||||
return polygons, ignore_tags
|
||||
|
||||
def polygon_area(self, polygon):
|
||||
"""
|
||||
compute polygon area
|
||||
"""
|
||||
area = 0
|
||||
q = polygon[-1]
|
||||
for p in polygon:
|
||||
area += p[0] * q[1] - p[1] * q[0]
|
||||
q = p
|
||||
return area / 2.0
|
||||
@@ -1,468 +0,0 @@
|
||||
"""
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import sys
|
||||
import six
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
|
||||
class DecodeImage(object):
|
||||
""" decode image """
|
||||
|
||||
def __init__(self,
|
||||
img_mode='RGB',
|
||||
channel_first=False,
|
||||
ignore_orientation=False,
|
||||
**kwargs):
|
||||
self.img_mode = img_mode
|
||||
self.channel_first = channel_first
|
||||
self.ignore_orientation = ignore_orientation
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if six.PY2:
|
||||
assert type(img) is str and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
else:
|
||||
assert type(img) is bytes and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
img = np.frombuffer(img, dtype='uint8')
|
||||
if self.ignore_orientation:
|
||||
img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
|
||||
cv2.IMREAD_COLOR)
|
||||
else:
|
||||
img = cv2.imdecode(img, 1)
|
||||
if img is None:
|
||||
return None
|
||||
if self.img_mode == 'GRAY':
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
elif self.img_mode == 'RGB':
|
||||
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
|
||||
img = img[:, :, ::-1]
|
||||
|
||||
if self.channel_first:
|
||||
img = img.transpose((2, 0, 1))
|
||||
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class NRTRDecodeImage(object):
|
||||
""" decode image """
|
||||
|
||||
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
|
||||
self.img_mode = img_mode
|
||||
self.channel_first = channel_first
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if six.PY2:
|
||||
assert type(img) is str and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
else:
|
||||
assert type(img) is bytes and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
img = np.frombuffer(img, dtype='uint8')
|
||||
|
||||
img = cv2.imdecode(img, 1)
|
||||
|
||||
if img is None:
|
||||
return None
|
||||
if self.img_mode == 'GRAY':
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
elif self.img_mode == 'RGB':
|
||||
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
|
||||
img = img[:, :, ::-1]
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
if self.channel_first:
|
||||
img = img.transpose((2, 0, 1))
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
""" normalize image such as substract mean, divide std
|
||||
"""
|
||||
|
||||
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
|
||||
if isinstance(scale, str):
|
||||
scale = eval(scale)
|
||||
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
|
||||
mean = mean if mean is not None else [0.485, 0.456, 0.406]
|
||||
std = std if std is not None else [0.229, 0.224, 0.225]
|
||||
|
||||
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
|
||||
self.mean = np.array(mean).reshape(shape).astype('float32')
|
||||
self.std = np.array(std).reshape(shape).astype('float32')
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
assert isinstance(img,
|
||||
np.ndarray), "invalid input 'img' in NormalizeImage"
|
||||
data['image'] = (
|
||||
img.astype('float32') * self.scale - self.mean) / self.std
|
||||
return data
|
||||
|
||||
|
||||
class ToCHWImage(object):
|
||||
""" convert hwc image to chw image
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
data['image'] = img.transpose((2, 0, 1))
|
||||
return data
|
||||
|
||||
|
||||
class Fasttext(object):
|
||||
def __init__(self, path="None", **kwargs):
|
||||
import fasttext
|
||||
self.fast_model = fasttext.load_model(path)
|
||||
|
||||
def __call__(self, data):
|
||||
label = data['label']
|
||||
fast_label = self.fast_model[label]
|
||||
data['fast_label'] = fast_label
|
||||
return data
|
||||
|
||||
|
||||
class KeepKeys(object):
|
||||
def __init__(self, keep_keys, **kwargs):
|
||||
self.keep_keys = keep_keys
|
||||
|
||||
def __call__(self, data):
|
||||
data_list = []
|
||||
for key in self.keep_keys:
|
||||
data_list.append(data[key])
|
||||
return data_list
|
||||
|
||||
|
||||
class Pad(object):
|
||||
def __init__(self, size=None, size_div=32, **kwargs):
|
||||
if size is not None and not isinstance(size, (int, list, tuple)):
|
||||
raise TypeError("Type of target_size is invalid. Now is {}".format(
|
||||
type(size)))
|
||||
if isinstance(size, int):
|
||||
size = [size, size]
|
||||
self.size = size
|
||||
self.size_div = size_div
|
||||
|
||||
def __call__(self, data):
|
||||
|
||||
img = data['image']
|
||||
img_h, img_w = img.shape[0], img.shape[1]
|
||||
if self.size:
|
||||
resize_h2, resize_w2 = self.size
|
||||
assert (
|
||||
img_h < resize_h2 and img_w < resize_w2
|
||||
), '(h, w) of target size should be greater than (img_h, img_w)'
|
||||
else:
|
||||
resize_h2 = max(
|
||||
int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
|
||||
self.size_div)
|
||||
resize_w2 = max(
|
||||
int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
|
||||
self.size_div)
|
||||
img = cv2.copyMakeBorder(
|
||||
img,
|
||||
0,
|
||||
resize_h2 - img_h,
|
||||
0,
|
||||
resize_w2 - img_w,
|
||||
cv2.BORDER_CONSTANT,
|
||||
value=0)
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class Resize(object):
|
||||
def __init__(self, size=(640, 640), **kwargs):
|
||||
self.size = size
|
||||
|
||||
def resize_image(self, img):
|
||||
resize_h, resize_w = self.size
|
||||
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
||||
ratio_h = float(resize_h) / ori_h
|
||||
ratio_w = float(resize_w) / ori_w
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if 'polys' in data:
|
||||
text_polys = data['polys']
|
||||
|
||||
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
|
||||
if 'polys' in data:
|
||||
new_boxes = []
|
||||
for box in text_polys:
|
||||
new_box = []
|
||||
for cord in box:
|
||||
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
|
||||
new_boxes.append(new_box)
|
||||
data['polys'] = np.array(new_boxes, dtype=np.float32)
|
||||
data['image'] = img_resize
|
||||
return data
|
||||
|
||||
|
||||
class DetResizeForTest(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(DetResizeForTest, self).__init__()
|
||||
self.resize_type = 0
|
||||
if 'image_shape' in kwargs:
|
||||
self.image_shape = kwargs['image_shape']
|
||||
self.resize_type = 1
|
||||
elif 'limit_side_len' in kwargs:
|
||||
self.limit_side_len = kwargs['limit_side_len']
|
||||
self.limit_type = kwargs.get('limit_type', 'min')
|
||||
elif 'resize_long' in kwargs:
|
||||
self.resize_type = 2
|
||||
self.resize_long = kwargs.get('resize_long', 960)
|
||||
else:
|
||||
self.limit_side_len = 736
|
||||
self.limit_type = 'min'
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
src_h, src_w, _ = img.shape
|
||||
|
||||
if self.resize_type == 0:
|
||||
# img, shape = self.resize_image_type0(img)
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
|
||||
elif self.resize_type == 2:
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
|
||||
else:
|
||||
# img, shape = self.resize_image_type1(img)
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
|
||||
data['image'] = img
|
||||
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
||||
return data
|
||||
|
||||
def resize_image_type1(self, img):
|
||||
resize_h, resize_w = self.image_shape
|
||||
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
||||
ratio_h = float(resize_h) / ori_h
|
||||
ratio_w = float(resize_w) / ori_w
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
# return img, np.array([ori_h, ori_w])
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type0(self, img):
|
||||
"""
|
||||
resize image to a size multiple of 32 which is required by the network
|
||||
args:
|
||||
img(array): array with shape [h, w, c]
|
||||
return(tuple):
|
||||
img, (ratio_h, ratio_w)
|
||||
"""
|
||||
limit_side_len = self.limit_side_len
|
||||
h, w, c = img.shape
|
||||
|
||||
# limit the max side
|
||||
if self.limit_type == 'max':
|
||||
if max(h, w) > limit_side_len:
|
||||
if h > w:
|
||||
ratio = float(limit_side_len) / h
|
||||
else:
|
||||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
elif self.limit_type == 'min':
|
||||
if min(h, w) < limit_side_len:
|
||||
if h < w:
|
||||
ratio = float(limit_side_len) / h
|
||||
else:
|
||||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
elif self.limit_type == 'resize_long':
|
||||
ratio = float(limit_side_len) / max(h, w)
|
||||
else:
|
||||
raise Exception('not support limit type, image ')
|
||||
resize_h = int(h * ratio)
|
||||
resize_w = int(w * ratio)
|
||||
|
||||
resize_h = max(int(round(resize_h / 32) * 32), 32)
|
||||
resize_w = max(int(round(resize_w / 32) * 32), 32)
|
||||
|
||||
try:
|
||||
if int(resize_w) <= 0 or int(resize_h) <= 0:
|
||||
return None, (None, None)
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
except:
|
||||
print(img.shape, resize_w, resize_h)
|
||||
sys.exit(0)
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type2(self, img):
|
||||
h, w, _ = img.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
if resize_h > resize_w:
|
||||
ratio = float(self.resize_long) / resize_h
|
||||
else:
|
||||
ratio = float(self.resize_long) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
|
||||
class E2EResizeForTest(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(E2EResizeForTest, self).__init__()
|
||||
self.max_side_len = kwargs['max_side_len']
|
||||
self.valid_set = kwargs['valid_set']
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
src_h, src_w, _ = img.shape
|
||||
if self.valid_set == 'totaltext':
|
||||
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
|
||||
img, max_side_len=self.max_side_len)
|
||||
else:
|
||||
im_resized, (ratio_h, ratio_w) = self.resize_image(
|
||||
img, max_side_len=self.max_side_len)
|
||||
data['image'] = im_resized
|
||||
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
||||
return data
|
||||
|
||||
def resize_image_for_totaltext(self, im, max_side_len=512):
|
||||
|
||||
h, w, _ = im.shape
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
ratio = 1.25
|
||||
if h * ratio > max_side_len:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
def resize_image(self, im, max_side_len=512):
|
||||
"""
|
||||
resize image to a size multiple of max_stride which is required by the network
|
||||
:param im: the resized image
|
||||
:param max_side_len: limit of max image size to avoid out of memory in gpu
|
||||
:return: the resized image and the resize ratio
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
# Fix the longer side
|
||||
if resize_h > resize_w:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
else:
|
||||
ratio = float(max_side_len) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
|
||||
class KieResize(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(KieResize, self).__init__()
|
||||
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
|
||||
'img_scale'][1]
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
points = data['points']
|
||||
src_h, src_w, _ = img.shape
|
||||
im_resized, scale_factor, [ratio_h, ratio_w
|
||||
], [new_h, new_w] = self.resize_image(img)
|
||||
resize_points = self.resize_boxes(img, points, scale_factor)
|
||||
data['ori_image'] = img
|
||||
data['ori_boxes'] = points
|
||||
data['points'] = resize_points
|
||||
data['image'] = im_resized
|
||||
data['shape'] = np.array([new_h, new_w])
|
||||
return data
|
||||
|
||||
def resize_image(self, img):
|
||||
norm_img = np.zeros([1024, 1024, 3], dtype='float32')
|
||||
scale = [512, 1024]
|
||||
h, w = img.shape[:2]
|
||||
max_long_edge = max(scale)
|
||||
max_short_edge = min(scale)
|
||||
scale_factor = min(max_long_edge / max(h, w),
|
||||
max_short_edge / min(h, w))
|
||||
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
|
||||
scale_factor) + 0.5)
|
||||
max_stride = 32
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(img, (resize_w, resize_h))
|
||||
new_h, new_w = im.shape[:2]
|
||||
w_scale = new_w / w
|
||||
h_scale = new_h / h
|
||||
scale_factor = np.array(
|
||||
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
|
||||
norm_img[:new_h, :new_w, :] = im
|
||||
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
|
||||
|
||||
def resize_boxes(self, im, points, scale_factor):
|
||||
points = points * scale_factor
|
||||
img_shape = im.shape[:2]
|
||||
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
|
||||
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
|
||||
return points
|
||||
@@ -1,906 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['PGProcessTrain']
|
||||
|
||||
|
||||
class PGProcessTrain(object):
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
max_text_length,
|
||||
max_text_nums,
|
||||
tcl_len,
|
||||
batch_size=14,
|
||||
min_crop_size=24,
|
||||
min_text_size=4,
|
||||
max_text_size=512,
|
||||
**kwargs):
|
||||
self.tcl_len = tcl_len
|
||||
self.max_text_length = max_text_length
|
||||
self.max_text_nums = max_text_nums
|
||||
self.batch_size = batch_size
|
||||
self.min_crop_size = min_crop_size
|
||||
self.min_text_size = min_text_size
|
||||
self.max_text_size = max_text_size
|
||||
self.Lexicon_Table = self.get_dict(character_dict_path)
|
||||
self.pad_num = len(self.Lexicon_Table)
|
||||
self.img_id = 0
|
||||
|
||||
def get_dict(self, character_dict_path):
|
||||
character_str = ""
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
character_str += line
|
||||
dict_character = list(character_str)
|
||||
return dict_character
|
||||
|
||||
def quad_area(self, poly):
|
||||
"""
|
||||
compute area of a polygon
|
||||
:param poly:
|
||||
:return:
|
||||
"""
|
||||
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
||||
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
||||
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
||||
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
|
||||
return np.sum(edge) / 2.
|
||||
|
||||
def gen_quad_from_poly(self, poly):
|
||||
"""
|
||||
Generate min area quad from poly.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
||||
rect = cv2.minAreaRect(poly.astype(
|
||||
np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
box = np.array(cv2.boxPoints(rect))
|
||||
|
||||
first_point_idx = 0
|
||||
min_dist = 1e4
|
||||
for i in range(4):
|
||||
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
||||
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
||||
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
||||
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
first_point_idx = i
|
||||
for i in range(4):
|
||||
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
||||
|
||||
return min_area_quad
|
||||
|
||||
def check_and_validate_polys(self, polys, tags, im_size):
|
||||
"""
|
||||
check so that the text poly is in the same direction,
|
||||
and also filter some invalid polygons
|
||||
:param polys:
|
||||
:param tags:
|
||||
:return:
|
||||
"""
|
||||
(h, w) = im_size
|
||||
if polys.shape[0] == 0:
|
||||
return polys, np.array([]), np.array([])
|
||||
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
||||
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
||||
|
||||
validated_polys = []
|
||||
validated_tags = []
|
||||
hv_tags = []
|
||||
for poly, tag in zip(polys, tags):
|
||||
quad = self.gen_quad_from_poly(poly)
|
||||
p_area = self.quad_area(quad)
|
||||
if abs(p_area) < 1:
|
||||
print('invalid poly')
|
||||
continue
|
||||
if p_area > 0:
|
||||
if tag == False:
|
||||
print('poly in wrong direction')
|
||||
tag = True # reversed cases should be ignore
|
||||
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
|
||||
1), :]
|
||||
quad = quad[(0, 3, 2, 1), :]
|
||||
|
||||
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
|
||||
quad[2])
|
||||
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
|
||||
quad[2])
|
||||
hv_tag = 1
|
||||
|
||||
if len_w * 2.0 < len_h:
|
||||
hv_tag = 0
|
||||
|
||||
validated_polys.append(poly)
|
||||
validated_tags.append(tag)
|
||||
hv_tags.append(hv_tag)
|
||||
return np.array(validated_polys), np.array(validated_tags), np.array(
|
||||
hv_tags)
|
||||
|
||||
def crop_area(self,
|
||||
im,
|
||||
polys,
|
||||
tags,
|
||||
hv_tags,
|
||||
txts,
|
||||
crop_background=False,
|
||||
max_tries=25):
|
||||
"""
|
||||
make random crop from the input image
|
||||
:param im:
|
||||
:param polys: [b,4,2]
|
||||
:param tags:
|
||||
:param crop_background:
|
||||
:param max_tries: 50 -> 25
|
||||
:return:
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
pad_h = h // 10
|
||||
pad_w = w // 10
|
||||
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
||||
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
||||
for poly in polys:
|
||||
poly = np.round(poly, decimals=0).astype(np.int32)
|
||||
minx = np.min(poly[:, 0])
|
||||
maxx = np.max(poly[:, 0])
|
||||
w_array[minx + pad_w:maxx + pad_w] = 1
|
||||
miny = np.min(poly[:, 1])
|
||||
maxy = np.max(poly[:, 1])
|
||||
h_array[miny + pad_h:maxy + pad_h] = 1
|
||||
# ensure the cropped area not across a text
|
||||
h_axis = np.where(h_array == 0)[0]
|
||||
w_axis = np.where(w_array == 0)[0]
|
||||
if len(h_axis) == 0 or len(w_axis) == 0:
|
||||
return im, polys, tags, hv_tags, txts
|
||||
for i in range(max_tries):
|
||||
xx = np.random.choice(w_axis, size=2)
|
||||
xmin = np.min(xx) - pad_w
|
||||
xmax = np.max(xx) - pad_w
|
||||
xmin = np.clip(xmin, 0, w - 1)
|
||||
xmax = np.clip(xmax, 0, w - 1)
|
||||
yy = np.random.choice(h_axis, size=2)
|
||||
ymin = np.min(yy) - pad_h
|
||||
ymax = np.max(yy) - pad_h
|
||||
ymin = np.clip(ymin, 0, h - 1)
|
||||
ymax = np.clip(ymax, 0, h - 1)
|
||||
if xmax - xmin < self.min_crop_size or \
|
||||
ymax - ymin < self.min_crop_size:
|
||||
continue
|
||||
if polys.shape[0] != 0:
|
||||
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
|
||||
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
|
||||
selected_polys = np.where(
|
||||
np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
||||
else:
|
||||
selected_polys = []
|
||||
if len(selected_polys) == 0:
|
||||
# no text in this area
|
||||
if crop_background:
|
||||
txts_tmp = []
|
||||
for selected_poly in selected_polys:
|
||||
txts_tmp.append(txts[selected_poly])
|
||||
txts = txts_tmp
|
||||
return im[ymin: ymax + 1, xmin: xmax + 1, :], \
|
||||
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
|
||||
else:
|
||||
continue
|
||||
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
|
||||
polys = polys[selected_polys]
|
||||
tags = tags[selected_polys]
|
||||
hv_tags = hv_tags[selected_polys]
|
||||
txts_tmp = []
|
||||
for selected_poly in selected_polys:
|
||||
txts_tmp.append(txts[selected_poly])
|
||||
txts = txts_tmp
|
||||
polys[:, :, 0] -= xmin
|
||||
polys[:, :, 1] -= ymin
|
||||
return im, polys, tags, hv_tags, txts
|
||||
|
||||
return im, polys, tags, hv_tags, txts
|
||||
|
||||
def fit_and_gather_tcl_points_v2(self,
|
||||
min_area_quad,
|
||||
poly,
|
||||
max_h,
|
||||
max_w,
|
||||
fixed_point_num=64,
|
||||
img_id=0,
|
||||
reference_height=3):
|
||||
"""
|
||||
Find the center point of poly as key_points, then fit and gather.
|
||||
"""
|
||||
key_point_xys = []
|
||||
point_num = poly.shape[0]
|
||||
for idx in range(point_num // 2):
|
||||
center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0
|
||||
key_point_xys.append(center_point)
|
||||
|
||||
tmp_image = np.zeros(
|
||||
shape=(
|
||||
max_h,
|
||||
max_w, ), dtype='float32')
|
||||
cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')],
|
||||
False, 1.0)
|
||||
ys, xs = np.where(tmp_image > 0)
|
||||
xy_text = np.array(list(zip(xs, ys)), dtype='float32')
|
||||
|
||||
left_center_pt = (
|
||||
(min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
|
||||
right_center_pt = (
|
||||
(min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
|
||||
proj_unit_vec = (right_center_pt - left_center_pt) / (
|
||||
np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
|
||||
proj_unit_vec_tile = np.tile(proj_unit_vec,
|
||||
(xy_text.shape[0], 1)) # (n, 2)
|
||||
left_center_pt_tile = np.tile(left_center_pt,
|
||||
(xy_text.shape[0], 1)) # (n, 2)
|
||||
xy_text_to_left_center = xy_text - left_center_pt_tile
|
||||
proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1)
|
||||
xy_text = xy_text[np.argsort(proj_value)]
|
||||
|
||||
# convert to np and keep the num of point not greater then fixed_point_num
|
||||
pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx
|
||||
point_num = len(pos_info)
|
||||
if point_num > fixed_point_num:
|
||||
keep_ids = [
|
||||
int((point_num * 1.0 / fixed_point_num) * x)
|
||||
for x in range(fixed_point_num)
|
||||
]
|
||||
pos_info = pos_info[keep_ids, :]
|
||||
|
||||
keep = int(min(len(pos_info), fixed_point_num))
|
||||
if np.random.rand() < 0.2 and reference_height >= 3:
|
||||
dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
|
||||
random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape(
|
||||
[keep, 1])
|
||||
pos_info += random_float
|
||||
pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
|
||||
pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
|
||||
|
||||
# padding to fixed length
|
||||
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
|
||||
pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
|
||||
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
|
||||
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
|
||||
pos_m[:keep] = 1.0
|
||||
return pos_l, pos_m
|
||||
|
||||
def generate_direction_map(self, poly_quads, n_char, direction_map):
|
||||
"""
|
||||
"""
|
||||
width_list = []
|
||||
height_list = []
|
||||
for quad in poly_quads:
|
||||
quad_w = (np.linalg.norm(quad[0] - quad[1]) +
|
||||
np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
width_list.append(quad_w)
|
||||
height_list.append(quad_h)
|
||||
norm_width = max(sum(width_list) / n_char, 1.0)
|
||||
average_height = max(sum(height_list) / len(height_list), 1.0)
|
||||
k = 1
|
||||
for quad in poly_quads:
|
||||
direct_vector_full = (
|
||||
(quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
|
||||
direct_vector = direct_vector_full / (
|
||||
np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
|
||||
direction_label = tuple(
|
||||
map(float,
|
||||
[direct_vector[0], direct_vector[1], 1.0 / average_height]))
|
||||
cv2.fillPoly(direction_map,
|
||||
quad.round().astype(np.int32)[np.newaxis, :, :],
|
||||
direction_label)
|
||||
k += 1
|
||||
return direction_map
|
||||
|
||||
def calculate_average_height(self, poly_quads):
|
||||
"""
|
||||
"""
|
||||
height_list = []
|
||||
for quad in poly_quads:
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
height_list.append(quad_h)
|
||||
average_height = max(sum(height_list) / len(height_list), 1.0)
|
||||
return average_height
|
||||
|
||||
def generate_tcl_ctc_label(self,
|
||||
h,
|
||||
w,
|
||||
polys,
|
||||
tags,
|
||||
text_strs,
|
||||
ds_ratio,
|
||||
tcl_ratio=0.3,
|
||||
shrink_ratio_of_width=0.15):
|
||||
"""
|
||||
Generate polygon.
|
||||
"""
|
||||
score_map_big = np.zeros(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
||||
polys = polys * ds_ratio
|
||||
|
||||
score_map = np.zeros(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
score_label_map = np.zeros(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
|
||||
training_mask = np.ones(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
|
||||
[1, 1, 3]).astype(np.float32)
|
||||
|
||||
label_idx = 0
|
||||
score_label_map_text_label_list = []
|
||||
pos_list, pos_mask, label_list = [], [], []
|
||||
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
||||
poly = poly_tag[0]
|
||||
tag = poly_tag[1]
|
||||
|
||||
# generate min_area_quad
|
||||
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
||||
min_area_quad_h = 0.5 * (
|
||||
np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
||||
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
||||
min_area_quad_w = 0.5 * (
|
||||
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
||||
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
||||
|
||||
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
|
||||
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
|
||||
continue
|
||||
|
||||
if tag:
|
||||
cv2.fillPoly(training_mask,
|
||||
poly.astype(np.int32)[np.newaxis, :, :], 0.15)
|
||||
else:
|
||||
text_label = text_strs[poly_idx]
|
||||
text_label = self.prepare_text_label(text_label,
|
||||
self.Lexicon_Table)
|
||||
|
||||
text_label_index_list = [[self.Lexicon_Table.index(c_)]
|
||||
for c_ in text_label
|
||||
if c_ in self.Lexicon_Table]
|
||||
if len(text_label_index_list) < 1:
|
||||
continue
|
||||
|
||||
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
||||
tcl_quads = self.poly2quads(tcl_poly)
|
||||
poly_quads = self.poly2quads(poly)
|
||||
|
||||
stcl_quads, quad_index = self.shrink_poly_along_width(
|
||||
tcl_quads,
|
||||
shrink_ratio_of_width=shrink_ratio_of_width,
|
||||
expand_height_ratio=1.0 / tcl_ratio)
|
||||
|
||||
cv2.fillPoly(score_map,
|
||||
np.round(stcl_quads).astype(np.int32), 1.0)
|
||||
cv2.fillPoly(score_map_big,
|
||||
np.round(stcl_quads / ds_ratio).astype(np.int32),
|
||||
1.0)
|
||||
|
||||
for idx, quad in enumerate(stcl_quads):
|
||||
quad_mask = np.zeros((h, w), dtype=np.float32)
|
||||
quad_mask = cv2.fillPoly(
|
||||
quad_mask,
|
||||
np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
|
||||
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
|
||||
quad_mask, tbo_map)
|
||||
|
||||
# score label map and score_label_map_text_label_list for refine
|
||||
if label_idx == 0:
|
||||
text_pos_list_ = [[len(self.Lexicon_Table)], ]
|
||||
score_label_map_text_label_list.append(text_pos_list_)
|
||||
|
||||
label_idx += 1
|
||||
cv2.fillPoly(score_label_map,
|
||||
np.round(poly_quads).astype(np.int32), label_idx)
|
||||
score_label_map_text_label_list.append(text_label_index_list)
|
||||
|
||||
# direction info, fix-me
|
||||
n_char = len(text_label_index_list)
|
||||
direction_map = self.generate_direction_map(poly_quads, n_char,
|
||||
direction_map)
|
||||
|
||||
# pos info
|
||||
average_shrink_height = self.calculate_average_height(
|
||||
stcl_quads)
|
||||
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
|
||||
min_area_quad,
|
||||
poly,
|
||||
max_h=h,
|
||||
max_w=w,
|
||||
fixed_point_num=64,
|
||||
img_id=self.img_id,
|
||||
reference_height=average_shrink_height)
|
||||
|
||||
label_l = text_label_index_list
|
||||
if len(text_label_index_list) < 2:
|
||||
continue
|
||||
|
||||
pos_list.append(pos_l)
|
||||
pos_mask.append(pos_m)
|
||||
label_list.append(label_l)
|
||||
|
||||
# use big score_map for smooth tcl lines
|
||||
score_map_big_resized = cv2.resize(
|
||||
score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio)
|
||||
score_map = np.array(score_map_big_resized > 1e-3, dtype='float32')
|
||||
|
||||
return score_map, score_label_map, tbo_map, direction_map, training_mask, \
|
||||
pos_list, pos_mask, label_list, score_label_map_text_label_list
|
||||
|
||||
def adjust_point(self, poly):
|
||||
"""
|
||||
adjust point order.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
if point_num == 4:
|
||||
len_1 = np.linalg.norm(poly[0] - poly[1])
|
||||
len_2 = np.linalg.norm(poly[1] - poly[2])
|
||||
len_3 = np.linalg.norm(poly[2] - poly[3])
|
||||
len_4 = np.linalg.norm(poly[3] - poly[0])
|
||||
|
||||
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
|
||||
poly = poly[[1, 2, 3, 0], :]
|
||||
|
||||
elif point_num > 4:
|
||||
vector_1 = poly[0] - poly[1]
|
||||
vector_2 = poly[1] - poly[2]
|
||||
cos_theta = np.dot(vector_1, vector_2) / (
|
||||
np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
|
||||
theta = np.arccos(np.round(cos_theta, decimals=4))
|
||||
|
||||
if abs(theta) > (70 / 180 * math.pi):
|
||||
index = list(range(1, point_num)) + [0]
|
||||
poly = poly[np.array(index), :]
|
||||
return poly
|
||||
|
||||
def gen_min_area_quad_from_poly(self, poly):
|
||||
"""
|
||||
Generate min area quad from poly.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
||||
if point_num == 4:
|
||||
min_area_quad = poly
|
||||
center_point = np.sum(poly, axis=0) / 4
|
||||
else:
|
||||
rect = cv2.minAreaRect(poly.astype(
|
||||
np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
center_point = rect[0]
|
||||
box = np.array(cv2.boxPoints(rect))
|
||||
|
||||
first_point_idx = 0
|
||||
min_dist = 1e4
|
||||
for i in range(4):
|
||||
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
||||
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
||||
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
||||
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
first_point_idx = i
|
||||
|
||||
for i in range(4):
|
||||
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
||||
|
||||
return min_area_quad, center_point
|
||||
|
||||
def shrink_quad_along_width(self,
|
||||
quad,
|
||||
begin_width_ratio=0.,
|
||||
end_width_ratio=1.):
|
||||
"""
|
||||
Generate shrink_quad_along_width.
|
||||
"""
|
||||
ratio_pair = np.array(
|
||||
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||
|
||||
def shrink_poly_along_width(self,
|
||||
quads,
|
||||
shrink_ratio_of_width,
|
||||
expand_height_ratio=1.0):
|
||||
"""
|
||||
shrink poly with given length.
|
||||
"""
|
||||
upper_edge_list = []
|
||||
|
||||
def get_cut_info(edge_len_list, cut_len):
|
||||
for idx, edge_len in enumerate(edge_len_list):
|
||||
cut_len -= edge_len
|
||||
if cut_len <= 0.000001:
|
||||
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
|
||||
return idx, ratio
|
||||
|
||||
for quad in quads:
|
||||
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
|
||||
upper_edge_list.append(upper_edge_len)
|
||||
|
||||
# length of left edge and right edge.
|
||||
left_length = np.linalg.norm(quads[0][0] - quads[0][
|
||||
3]) * expand_height_ratio
|
||||
right_length = np.linalg.norm(quads[-1][1] - quads[-1][
|
||||
2]) * expand_height_ratio
|
||||
|
||||
shrink_length = min(left_length, right_length,
|
||||
sum(upper_edge_list)) * shrink_ratio_of_width
|
||||
# shrinking length
|
||||
upper_len_left = shrink_length
|
||||
upper_len_right = sum(upper_edge_list) - shrink_length
|
||||
|
||||
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
|
||||
left_quad = self.shrink_quad_along_width(
|
||||
quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
|
||||
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
|
||||
right_quad = self.shrink_quad_along_width(
|
||||
quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
|
||||
|
||||
out_quad_list = []
|
||||
if left_idx == right_idx:
|
||||
out_quad_list.append(
|
||||
[left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
|
||||
else:
|
||||
out_quad_list.append(left_quad)
|
||||
for idx in range(left_idx + 1, right_idx):
|
||||
out_quad_list.append(quads[idx])
|
||||
out_quad_list.append(right_quad)
|
||||
|
||||
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
|
||||
|
||||
def prepare_text_label(self, label_str, Lexicon_Table):
|
||||
"""
|
||||
Prepare text lablel by given Lexicon_Table.
|
||||
"""
|
||||
if len(Lexicon_Table) == 36:
|
||||
return label_str.lower()
|
||||
else:
|
||||
return label_str
|
||||
|
||||
def vector_angle(self, A, B):
|
||||
"""
|
||||
Calculate the angle between vector AB and x-axis positive direction.
|
||||
"""
|
||||
AB = np.array([B[1] - A[1], B[0] - A[0]])
|
||||
return np.arctan2(*AB)
|
||||
|
||||
def theta_line_cross_point(self, theta, point):
|
||||
"""
|
||||
Calculate the line through given point and angle in ax + by + c =0 form.
|
||||
"""
|
||||
x, y = point
|
||||
cos = np.cos(theta)
|
||||
sin = np.sin(theta)
|
||||
return [sin, -cos, cos * y - sin * x]
|
||||
|
||||
def line_cross_two_point(self, A, B):
|
||||
"""
|
||||
Calculate the line through given point A and B in ax + by + c =0 form.
|
||||
"""
|
||||
angle = self.vector_angle(A, B)
|
||||
return self.theta_line_cross_point(angle, A)
|
||||
|
||||
def average_angle(self, poly):
|
||||
"""
|
||||
Calculate the average angle between left and right edge in given poly.
|
||||
"""
|
||||
p0, p1, p2, p3 = poly
|
||||
angle30 = self.vector_angle(p3, p0)
|
||||
angle21 = self.vector_angle(p2, p1)
|
||||
return (angle30 + angle21) / 2
|
||||
|
||||
def line_cross_point(self, line1, line2):
|
||||
"""
|
||||
line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
|
||||
"""
|
||||
a1, b1, c1 = line1
|
||||
a2, b2, c2 = line2
|
||||
d = a1 * b2 - a2 * b1
|
||||
|
||||
if d == 0:
|
||||
print('Cross point does not exist')
|
||||
return np.array([0, 0], dtype=np.float32)
|
||||
else:
|
||||
x = (b1 * c2 - b2 * c1) / d
|
||||
y = (a2 * c1 - a1 * c2) / d
|
||||
|
||||
return np.array([x, y], dtype=np.float32)
|
||||
|
||||
def quad2tcl(self, poly, ratio):
|
||||
"""
|
||||
Generate center line by poly clock-wise point. (4, 2)
|
||||
"""
|
||||
ratio_pair = np.array(
|
||||
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
|
||||
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
|
||||
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
|
||||
|
||||
def poly2tcl(self, poly, ratio):
|
||||
"""
|
||||
Generate center line by poly clock-wise point.
|
||||
"""
|
||||
ratio_pair = np.array(
|
||||
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
tcl_poly = np.zeros_like(poly)
|
||||
point_num = poly.shape[0]
|
||||
|
||||
for idx in range(point_num // 2):
|
||||
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
|
||||
) * ratio_pair
|
||||
tcl_poly[idx] = point_pair[0]
|
||||
tcl_poly[point_num - 1 - idx] = point_pair[1]
|
||||
return tcl_poly
|
||||
|
||||
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
|
||||
"""
|
||||
Generate tbo_map for give quad.
|
||||
"""
|
||||
# upper and lower line function: ax + by + c = 0;
|
||||
up_line = self.line_cross_two_point(quad[0], quad[1])
|
||||
lower_line = self.line_cross_two_point(quad[3], quad[2])
|
||||
|
||||
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[1] - quad[2]))
|
||||
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
|
||||
np.linalg.norm(quad[2] - quad[3]))
|
||||
|
||||
# average angle of left and right line.
|
||||
angle = self.average_angle(quad)
|
||||
|
||||
xy_in_poly = np.argwhere(tcl_mask == 1)
|
||||
for y, x in xy_in_poly:
|
||||
point = (x, y)
|
||||
line = self.theta_line_cross_point(angle, point)
|
||||
cross_point_upper = self.line_cross_point(up_line, line)
|
||||
cross_point_lower = self.line_cross_point(lower_line, line)
|
||||
##FIX, offset reverse
|
||||
upper_offset_x, upper_offset_y = cross_point_upper - point
|
||||
lower_offset_x, lower_offset_y = cross_point_lower - point
|
||||
tbo_map[y, x, 0] = upper_offset_y
|
||||
tbo_map[y, x, 1] = upper_offset_x
|
||||
tbo_map[y, x, 2] = lower_offset_y
|
||||
tbo_map[y, x, 3] = lower_offset_x
|
||||
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
|
||||
return tbo_map
|
||||
|
||||
def poly2quads(self, poly):
|
||||
"""
|
||||
Split poly into quads.
|
||||
"""
|
||||
quad_list = []
|
||||
point_num = poly.shape[0]
|
||||
|
||||
# point pair
|
||||
point_pair_list = []
|
||||
for idx in range(point_num // 2):
|
||||
point_pair = [poly[idx], poly[point_num - 1 - idx]]
|
||||
point_pair_list.append(point_pair)
|
||||
|
||||
quad_num = point_num // 2 - 1
|
||||
for idx in range(quad_num):
|
||||
# reshape and adjust to clock-wise
|
||||
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
|
||||
).reshape(4, 2)[[0, 2, 3, 1]])
|
||||
|
||||
return np.array(quad_list)
|
||||
|
||||
def rotate_im_poly(self, im, text_polys):
|
||||
"""
|
||||
rotate image with 90 / 180 / 270 degre
|
||||
"""
|
||||
im_w, im_h = im.shape[1], im.shape[0]
|
||||
dst_im = im.copy()
|
||||
dst_polys = []
|
||||
rand_degree_ratio = np.random.rand()
|
||||
rand_degree_cnt = 1
|
||||
if rand_degree_ratio > 0.5:
|
||||
rand_degree_cnt = 3
|
||||
for i in range(rand_degree_cnt):
|
||||
dst_im = np.rot90(dst_im)
|
||||
rot_degree = -90 * rand_degree_cnt
|
||||
rot_angle = rot_degree * math.pi / 180.0
|
||||
n_poly = text_polys.shape[0]
|
||||
cx, cy = 0.5 * im_w, 0.5 * im_h
|
||||
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
|
||||
for i in range(n_poly):
|
||||
wordBB = text_polys[i]
|
||||
poly = []
|
||||
for j in range(4): # 16->4
|
||||
sx, sy = wordBB[j][0], wordBB[j][1]
|
||||
dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (
|
||||
sy - cy) + ncx
|
||||
dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (
|
||||
sy - cy) + ncy
|
||||
poly.append([dx, dy])
|
||||
dst_polys.append(poly)
|
||||
return dst_im, np.array(dst_polys, dtype=np.float32)
|
||||
|
||||
def __call__(self, data):
|
||||
input_size = 512
|
||||
im = data['image']
|
||||
text_polys = data['polys']
|
||||
text_tags = data['ignore_tags']
|
||||
text_strs = data['texts']
|
||||
h, w, _ = im.shape
|
||||
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
|
||||
text_polys, text_tags, (h, w))
|
||||
if text_polys.shape[0] <= 0:
|
||||
return None
|
||||
# set aspect ratio and keep area fix
|
||||
asp_scales = np.arange(1.0, 1.55, 0.1)
|
||||
asp_scale = np.random.choice(asp_scales)
|
||||
if np.random.rand() < 0.5:
|
||||
asp_scale = 1.0 / asp_scale
|
||||
asp_scale = math.sqrt(asp_scale)
|
||||
|
||||
asp_wx = asp_scale
|
||||
asp_hy = 1.0 / asp_scale
|
||||
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
|
||||
text_polys[:, :, 0] *= asp_wx
|
||||
text_polys[:, :, 1] *= asp_hy
|
||||
|
||||
h, w, _ = im.shape
|
||||
if max(h, w) > 2048:
|
||||
rd_scale = 2048.0 / max(h, w)
|
||||
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
||||
text_polys *= rd_scale
|
||||
h, w, _ = im.shape
|
||||
if min(h, w) < 16:
|
||||
return None
|
||||
|
||||
# no background
|
||||
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
|
||||
im,
|
||||
text_polys,
|
||||
text_tags,
|
||||
hv_tags,
|
||||
text_strs,
|
||||
crop_background=False)
|
||||
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
# # continue for all ignore case
|
||||
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
||||
return None
|
||||
new_h, new_w, _ = im.shape
|
||||
if (new_h is None) or (new_w is None):
|
||||
return None
|
||||
# resize image
|
||||
std_ratio = float(input_size) / max(new_w, new_h)
|
||||
rand_scales = np.array(
|
||||
[0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
|
||||
rz_scale = std_ratio * np.random.choice(rand_scales)
|
||||
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
|
||||
text_polys[:, :, 0] *= rz_scale
|
||||
text_polys[:, :, 1] *= rz_scale
|
||||
|
||||
# add gaussian blur
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
ks = np.random.permutation(5)[0] + 1
|
||||
ks = int(ks / 2) * 2 + 1
|
||||
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
|
||||
# add brighter
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
im = im * (1.0 + np.random.rand() * 0.5)
|
||||
im = np.clip(im, 0.0, 255.0)
|
||||
# add darker
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
im = im * (1.0 - np.random.rand() * 0.5)
|
||||
im = np.clip(im, 0.0, 255.0)
|
||||
|
||||
# Padding the im to [input_size, input_size]
|
||||
new_h, new_w, _ = im.shape
|
||||
if min(new_w, new_h) < input_size * 0.5:
|
||||
return None
|
||||
im_padded = np.ones((input_size, input_size, 3), dtype=np.float32)
|
||||
im_padded[:, :, 2] = 0.485 * 255
|
||||
im_padded[:, :, 1] = 0.456 * 255
|
||||
im_padded[:, :, 0] = 0.406 * 255
|
||||
|
||||
# Random the start position
|
||||
del_h = input_size - new_h
|
||||
del_w = input_size - new_w
|
||||
sh, sw = 0, 0
|
||||
if del_h > 1:
|
||||
sh = int(np.random.rand() * del_h)
|
||||
if del_w > 1:
|
||||
sw = int(np.random.rand() * del_w)
|
||||
|
||||
# Padding
|
||||
im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
|
||||
text_polys[:, :, 0] += sw
|
||||
text_polys[:, :, 1] += sh
|
||||
|
||||
score_map, score_label_map, border_map, direction_map, training_mask, \
|
||||
pos_list, pos_mask, label_list, score_label_map_text_label = self.generate_tcl_ctc_label(input_size,
|
||||
input_size,
|
||||
text_polys,
|
||||
text_tags,
|
||||
text_strs, 0.25)
|
||||
if len(label_list) <= 0: # eliminate negative samples
|
||||
return None
|
||||
pos_list_temp = np.zeros([64, 3])
|
||||
pos_mask_temp = np.zeros([64, 1])
|
||||
label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num
|
||||
|
||||
for i, label in enumerate(label_list):
|
||||
n = len(label)
|
||||
if n > self.max_text_length:
|
||||
label_list[i] = label[:self.max_text_length]
|
||||
continue
|
||||
while n < self.max_text_length:
|
||||
label.append([self.pad_num])
|
||||
n += 1
|
||||
|
||||
for i in range(len(label_list)):
|
||||
label_list[i] = np.array(label_list[i])
|
||||
|
||||
if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums:
|
||||
return None
|
||||
for __ in range(self.max_text_nums - len(pos_list), 0, -1):
|
||||
pos_list.append(pos_list_temp)
|
||||
pos_mask.append(pos_mask_temp)
|
||||
label_list.append(label_list_temp)
|
||||
|
||||
if self.img_id == self.batch_size - 1:
|
||||
self.img_id = 0
|
||||
else:
|
||||
self.img_id += 1
|
||||
|
||||
im_padded[:, :, 2] -= 0.485 * 255
|
||||
im_padded[:, :, 1] -= 0.456 * 255
|
||||
im_padded[:, :, 0] -= 0.406 * 255
|
||||
im_padded[:, :, 2] /= (255.0 * 0.229)
|
||||
im_padded[:, :, 1] /= (255.0 * 0.224)
|
||||
im_padded[:, :, 0] /= (255.0 * 0.225)
|
||||
im_padded = im_padded.transpose((2, 0, 1))
|
||||
images = im_padded[::-1, :, :]
|
||||
tcl_maps = score_map[np.newaxis, :, :]
|
||||
tcl_label_maps = score_label_map[np.newaxis, :, :]
|
||||
border_maps = border_map.transpose((2, 0, 1))
|
||||
direction_maps = direction_map.transpose((2, 0, 1))
|
||||
training_masks = training_mask[np.newaxis, :, :]
|
||||
pos_list = np.array(pos_list)
|
||||
pos_mask = np.array(pos_mask)
|
||||
label_list = np.array(label_list)
|
||||
data['images'] = images
|
||||
data['tcl_maps'] = tcl_maps
|
||||
data['tcl_label_maps'] = tcl_label_maps
|
||||
data['border_maps'] = border_maps
|
||||
data['direction_maps'] = direction_maps
|
||||
data['training_masks'] = training_masks
|
||||
data['label_list'] = label_list
|
||||
data['pos_list'] = pos_list
|
||||
data['pos_mask'] = pos_mask
|
||||
return data
|
||||
@@ -1,143 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from PIL import Image, ImageEnhance, ImageOps
|
||||
import numpy as np
|
||||
import random
|
||||
import six
|
||||
|
||||
|
||||
class RawRandAugment(object):
|
||||
def __init__(self,
|
||||
num_layers=2,
|
||||
magnitude=5,
|
||||
fillcolor=(128, 128, 128),
|
||||
**kwargs):
|
||||
self.num_layers = num_layers
|
||||
self.magnitude = magnitude
|
||||
self.max_level = 10
|
||||
|
||||
abso_level = self.magnitude / self.max_level
|
||||
self.level_map = {
|
||||
"shearX": 0.3 * abso_level,
|
||||
"shearY": 0.3 * abso_level,
|
||||
"translateX": 150.0 / 331 * abso_level,
|
||||
"translateY": 150.0 / 331 * abso_level,
|
||||
"rotate": 30 * abso_level,
|
||||
"color": 0.9 * abso_level,
|
||||
"posterize": int(4.0 * abso_level),
|
||||
"solarize": 256.0 * abso_level,
|
||||
"contrast": 0.9 * abso_level,
|
||||
"sharpness": 0.9 * abso_level,
|
||||
"brightness": 0.9 * abso_level,
|
||||
"autocontrast": 0,
|
||||
"equalize": 0,
|
||||
"invert": 0
|
||||
}
|
||||
|
||||
# from https://stackoverflow.com/questions/5252170/
|
||||
# specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
|
||||
def rotate_with_fill(img, magnitude):
|
||||
rot = img.convert("RGBA").rotate(magnitude)
|
||||
return Image.composite(rot,
|
||||
Image.new("RGBA", rot.size, (128, ) * 4),
|
||||
rot).convert(img.mode)
|
||||
|
||||
rnd_ch_op = random.choice
|
||||
|
||||
self.func = {
|
||||
"shearX": lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
|
||||
Image.BICUBIC,
|
||||
fillcolor=fillcolor),
|
||||
"shearY": lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
|
||||
Image.BICUBIC,
|
||||
fillcolor=fillcolor),
|
||||
"translateX": lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0),
|
||||
fillcolor=fillcolor),
|
||||
"translateY": lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])),
|
||||
fillcolor=fillcolor),
|
||||
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
|
||||
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
|
||||
1 + magnitude * rnd_ch_op([-1, 1])),
|
||||
"posterize": lambda img, magnitude:
|
||||
ImageOps.posterize(img, magnitude),
|
||||
"solarize": lambda img, magnitude:
|
||||
ImageOps.solarize(img, magnitude),
|
||||
"contrast": lambda img, magnitude:
|
||||
ImageEnhance.Contrast(img).enhance(
|
||||
1 + magnitude * rnd_ch_op([-1, 1])),
|
||||
"sharpness": lambda img, magnitude:
|
||||
ImageEnhance.Sharpness(img).enhance(
|
||||
1 + magnitude * rnd_ch_op([-1, 1])),
|
||||
"brightness": lambda img, magnitude:
|
||||
ImageEnhance.Brightness(img).enhance(
|
||||
1 + magnitude * rnd_ch_op([-1, 1])),
|
||||
"autocontrast": lambda img, magnitude:
|
||||
ImageOps.autocontrast(img),
|
||||
"equalize": lambda img, magnitude: ImageOps.equalize(img),
|
||||
"invert": lambda img, magnitude: ImageOps.invert(img)
|
||||
}
|
||||
|
||||
def __call__(self, img):
|
||||
avaiable_op_names = list(self.level_map.keys())
|
||||
for layer_num in range(self.num_layers):
|
||||
op_name = np.random.choice(avaiable_op_names)
|
||||
img = self.func[op_name](img, self.level_map[op_name])
|
||||
return img
|
||||
|
||||
|
||||
class RandAugment(RawRandAugment):
|
||||
""" RandAugment wrapper to auto fit different img types """
|
||||
|
||||
def __init__(self, prob=0.5, *args, **kwargs):
|
||||
self.prob = prob
|
||||
if six.PY2:
|
||||
super(RandAugment, self).__init__(*args, **kwargs)
|
||||
else:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, data):
|
||||
if np.random.rand() > self.prob:
|
||||
return data
|
||||
img = data['image']
|
||||
if not isinstance(img, Image.Image):
|
||||
img = np.ascontiguousarray(img)
|
||||
img = Image.fromarray(img)
|
||||
|
||||
if six.PY2:
|
||||
img = super(RandAugment, self).__call__(img)
|
||||
else:
|
||||
img = super().__call__(img)
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.asarray(img)
|
||||
data['image'] = img
|
||||
return data
|
||||
@@ -1,234 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import random
|
||||
|
||||
|
||||
def is_poly_in_rect(poly, x, y, w, h):
|
||||
poly = np.array(poly)
|
||||
if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
|
||||
return False
|
||||
if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_poly_outside_rect(poly, x, y, w, h):
|
||||
poly = np.array(poly)
|
||||
if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
|
||||
return True
|
||||
if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def split_regions(axis):
|
||||
regions = []
|
||||
min_axis = 0
|
||||
for i in range(1, axis.shape[0]):
|
||||
if axis[i] != axis[i - 1] + 1:
|
||||
region = axis[min_axis:i]
|
||||
min_axis = i
|
||||
regions.append(region)
|
||||
return regions
|
||||
|
||||
|
||||
def random_select(axis, max_size):
|
||||
xx = np.random.choice(axis, size=2)
|
||||
xmin = np.min(xx)
|
||||
xmax = np.max(xx)
|
||||
xmin = np.clip(xmin, 0, max_size - 1)
|
||||
xmax = np.clip(xmax, 0, max_size - 1)
|
||||
return xmin, xmax
|
||||
|
||||
|
||||
def region_wise_random_select(regions, max_size):
|
||||
selected_index = list(np.random.choice(len(regions), 2))
|
||||
selected_values = []
|
||||
for index in selected_index:
|
||||
axis = regions[index]
|
||||
xx = int(np.random.choice(axis, size=1))
|
||||
selected_values.append(xx)
|
||||
xmin = min(selected_values)
|
||||
xmax = max(selected_values)
|
||||
return xmin, xmax
|
||||
|
||||
|
||||
def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
|
||||
h, w, _ = im.shape
|
||||
h_array = np.zeros(h, dtype=np.int32)
|
||||
w_array = np.zeros(w, dtype=np.int32)
|
||||
for points in text_polys:
|
||||
points = np.round(points, decimals=0).astype(np.int32)
|
||||
minx = np.min(points[:, 0])
|
||||
maxx = np.max(points[:, 0])
|
||||
w_array[minx:maxx] = 1
|
||||
miny = np.min(points[:, 1])
|
||||
maxy = np.max(points[:, 1])
|
||||
h_array[miny:maxy] = 1
|
||||
# ensure the cropped area not across a text
|
||||
h_axis = np.where(h_array == 0)[0]
|
||||
w_axis = np.where(w_array == 0)[0]
|
||||
|
||||
if len(h_axis) == 0 or len(w_axis) == 0:
|
||||
return 0, 0, w, h
|
||||
|
||||
h_regions = split_regions(h_axis)
|
||||
w_regions = split_regions(w_axis)
|
||||
|
||||
for i in range(max_tries):
|
||||
if len(w_regions) > 1:
|
||||
xmin, xmax = region_wise_random_select(w_regions, w)
|
||||
else:
|
||||
xmin, xmax = random_select(w_axis, w)
|
||||
if len(h_regions) > 1:
|
||||
ymin, ymax = region_wise_random_select(h_regions, h)
|
||||
else:
|
||||
ymin, ymax = random_select(h_axis, h)
|
||||
|
||||
if xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h:
|
||||
# area too small
|
||||
continue
|
||||
num_poly_in_rect = 0
|
||||
for poly in text_polys:
|
||||
if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
|
||||
ymax - ymin):
|
||||
num_poly_in_rect += 1
|
||||
break
|
||||
|
||||
if num_poly_in_rect > 0:
|
||||
return xmin, ymin, xmax - xmin, ymax - ymin
|
||||
|
||||
return 0, 0, w, h
|
||||
|
||||
|
||||
class EastRandomCropData(object):
|
||||
def __init__(self,
|
||||
size=(640, 640),
|
||||
max_tries=10,
|
||||
min_crop_side_ratio=0.1,
|
||||
keep_ratio=True,
|
||||
**kwargs):
|
||||
self.size = size
|
||||
self.max_tries = max_tries
|
||||
self.min_crop_side_ratio = min_crop_side_ratio
|
||||
self.keep_ratio = keep_ratio
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
text_polys = data['polys']
|
||||
ignore_tags = data['ignore_tags']
|
||||
texts = data['texts']
|
||||
all_care_polys = [
|
||||
text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
|
||||
]
|
||||
# 计算crop区域
|
||||
crop_x, crop_y, crop_w, crop_h = crop_area(
|
||||
img, all_care_polys, self.min_crop_side_ratio, self.max_tries)
|
||||
# crop 图片 保持比例填充
|
||||
scale_w = self.size[0] / crop_w
|
||||
scale_h = self.size[1] / crop_h
|
||||
scale = min(scale_w, scale_h)
|
||||
h = int(crop_h * scale)
|
||||
w = int(crop_w * scale)
|
||||
if self.keep_ratio:
|
||||
padimg = np.zeros((self.size[1], self.size[0], img.shape[2]),
|
||||
img.dtype)
|
||||
padimg[:h, :w] = cv2.resize(
|
||||
img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
|
||||
img = padimg
|
||||
else:
|
||||
img = cv2.resize(
|
||||
img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
|
||||
tuple(self.size))
|
||||
# crop 文本框
|
||||
text_polys_crop = []
|
||||
ignore_tags_crop = []
|
||||
texts_crop = []
|
||||
for poly, text, tag in zip(text_polys, texts, ignore_tags):
|
||||
poly = ((poly - (crop_x, crop_y)) * scale).tolist()
|
||||
if not is_poly_outside_rect(poly, 0, 0, w, h):
|
||||
text_polys_crop.append(poly)
|
||||
ignore_tags_crop.append(tag)
|
||||
texts_crop.append(text)
|
||||
data['image'] = img
|
||||
data['polys'] = np.array(text_polys_crop)
|
||||
data['ignore_tags'] = ignore_tags_crop
|
||||
data['texts'] = texts_crop
|
||||
return data
|
||||
|
||||
|
||||
class RandomCropImgMask(object):
|
||||
def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs):
|
||||
self.size = size
|
||||
self.main_key = main_key
|
||||
self.crop_keys = crop_keys
|
||||
self.p = p
|
||||
|
||||
def __call__(self, data):
|
||||
image = data['image']
|
||||
|
||||
h, w = image.shape[0:2]
|
||||
th, tw = self.size
|
||||
if w == tw and h == th:
|
||||
return data
|
||||
|
||||
mask = data[self.main_key]
|
||||
if np.max(mask) > 0 and random.random() > self.p:
|
||||
# make sure to crop the text region
|
||||
tl = np.min(np.where(mask > 0), axis=1) - (th, tw)
|
||||
tl[tl < 0] = 0
|
||||
br = np.max(np.where(mask > 0), axis=1) - (th, tw)
|
||||
br[br < 0] = 0
|
||||
|
||||
br[0] = min(br[0], h - th)
|
||||
br[1] = min(br[1], w - tw)
|
||||
|
||||
i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
|
||||
j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
|
||||
else:
|
||||
i = random.randint(0, h - th) if h - th > 0 else 0
|
||||
j = random.randint(0, w - tw) if w - tw > 0 else 0
|
||||
|
||||
# return i, j, th, tw
|
||||
for k in data:
|
||||
if k in self.crop_keys:
|
||||
if len(data[k].shape) == 3:
|
||||
if np.argmin(data[k].shape) == 0:
|
||||
img = data[k][:, i:i + th, j:j + tw]
|
||||
if img.shape[1] != img.shape[2]:
|
||||
a = 1
|
||||
elif np.argmin(data[k].shape) == 2:
|
||||
img = data[k][i:i + th, j:j + tw, :]
|
||||
if img.shape[1] != img.shape[0]:
|
||||
a = 1
|
||||
else:
|
||||
img = data[k]
|
||||
else:
|
||||
img = data[k][i:i + th, j:j + tw]
|
||||
if img.shape[0] != img.shape[1]:
|
||||
a = 1
|
||||
data[k] = img
|
||||
return data
|
||||
@@ -1,601 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
import random
|
||||
import copy
|
||||
from PIL import Image
|
||||
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
|
||||
|
||||
|
||||
class RecAug(object):
|
||||
def __init__(self, use_tia=True, aug_prob=0.4, **kwargs):
|
||||
self.use_tia = use_tia
|
||||
self.aug_prob = aug_prob
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
img = warp(img, 10, self.use_tia, self.aug_prob)
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class RecConAug(object):
|
||||
def __init__(self,
|
||||
prob=0.5,
|
||||
image_shape=(32, 320, 3),
|
||||
max_text_length=25,
|
||||
ext_data_num=1,
|
||||
**kwargs):
|
||||
self.ext_data_num = ext_data_num
|
||||
self.prob = prob
|
||||
self.max_text_length = max_text_length
|
||||
self.image_shape = image_shape
|
||||
self.max_wh_ratio = self.image_shape[1] / self.image_shape[0]
|
||||
|
||||
def merge_ext_data(self, data, ext_data):
|
||||
ori_w = round(data['image'].shape[1] / data['image'].shape[0] *
|
||||
self.image_shape[0])
|
||||
ext_w = round(ext_data['image'].shape[1] / ext_data['image'].shape[0] *
|
||||
self.image_shape[0])
|
||||
data['image'] = cv2.resize(data['image'], (ori_w, self.image_shape[0]))
|
||||
ext_data['image'] = cv2.resize(ext_data['image'],
|
||||
(ext_w, self.image_shape[0]))
|
||||
data['image'] = np.concatenate(
|
||||
[data['image'], ext_data['image']], axis=1)
|
||||
data["label"] += ext_data["label"]
|
||||
return data
|
||||
|
||||
def __call__(self, data):
|
||||
rnd_num = random.random()
|
||||
if rnd_num > self.prob:
|
||||
return data
|
||||
for idx, ext_data in enumerate(data["ext_data"]):
|
||||
if len(data["label"]) + len(ext_data[
|
||||
"label"]) > self.max_text_length:
|
||||
break
|
||||
concat_ratio = data['image'].shape[1] / data['image'].shape[
|
||||
0] + ext_data['image'].shape[1] / ext_data['image'].shape[0]
|
||||
if concat_ratio > self.max_wh_ratio:
|
||||
break
|
||||
data = self.merge_ext_data(data, ext_data)
|
||||
data.pop("ext_data")
|
||||
return data
|
||||
|
||||
|
||||
class ClsResizeImg(object):
|
||||
def __init__(self, image_shape, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
norm_img, _ = resize_norm_img(img, self.image_shape)
|
||||
data['image'] = norm_img
|
||||
return data
|
||||
|
||||
|
||||
class NRTRRecResizeImg(object):
|
||||
def __init__(self, image_shape, resize_type, padding=False, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.resize_type = resize_type
|
||||
self.padding = padding
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
image_shape = self.image_shape
|
||||
if self.padding:
|
||||
imgC, imgH, imgW = image_shape
|
||||
# todo: change to 0 and modified image shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
norm_img = np.expand_dims(resized_image, -1)
|
||||
norm_img = norm_img.transpose((2, 0, 1))
|
||||
resized_image = norm_img.astype(np.float32) / 128. - 1.
|
||||
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
data['image'] = padding_im
|
||||
return data
|
||||
if self.resize_type == 'PIL':
|
||||
image_pil = Image.fromarray(np.uint8(img))
|
||||
img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
|
||||
img = np.array(img)
|
||||
if self.resize_type == 'OpenCV':
|
||||
img = cv2.resize(img, self.image_shape)
|
||||
norm_img = np.expand_dims(img, -1)
|
||||
norm_img = norm_img.transpose((2, 0, 1))
|
||||
data['image'] = norm_img.astype(np.float32) / 128. - 1.
|
||||
return data
|
||||
|
||||
|
||||
class RecResizeImg(object):
|
||||
def __init__(self,
|
||||
image_shape,
|
||||
infer_mode=False,
|
||||
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
|
||||
padding=True,
|
||||
**kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.infer_mode = infer_mode
|
||||
self.character_dict_path = character_dict_path
|
||||
self.padding = padding
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if self.infer_mode and self.character_dict_path is not None:
|
||||
norm_img, valid_ratio = resize_norm_img_chinese(img,
|
||||
self.image_shape)
|
||||
else:
|
||||
norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
|
||||
self.padding)
|
||||
data['image'] = norm_img
|
||||
data['valid_ratio'] = valid_ratio
|
||||
return data
|
||||
|
||||
|
||||
class SRNRecResizeImg(object):
|
||||
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.num_heads = num_heads
|
||||
self.max_text_length = max_text_length
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
norm_img = resize_norm_img_srn(img, self.image_shape)
|
||||
data['image'] = norm_img
|
||||
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||
srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
|
||||
|
||||
data['encoder_word_pos'] = encoder_word_pos
|
||||
data['gsrm_word_pos'] = gsrm_word_pos
|
||||
data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
|
||||
data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
|
||||
return data
|
||||
|
||||
|
||||
class SARRecResizeImg(object):
|
||||
def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.width_downsample_ratio = width_downsample_ratio
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
|
||||
img, self.image_shape, self.width_downsample_ratio)
|
||||
data['image'] = norm_img
|
||||
data['resized_shape'] = resize_shape
|
||||
data['pad_shape'] = pad_shape
|
||||
data['valid_ratio'] = valid_ratio
|
||||
return data
|
||||
|
||||
|
||||
class PRENResizeImg(object):
|
||||
def __init__(self, image_shape, **kwargs):
|
||||
"""
|
||||
Accroding to original paper's realization, it's a hard resize method here.
|
||||
So maybe you should optimize it to fit for your task better.
|
||||
"""
|
||||
self.dst_h, self.dst_w = image_shape
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
resized_img = cv2.resize(
|
||||
img, (self.dst_w, self.dst_h), interpolation=cv2.INTER_LINEAR)
|
||||
resized_img = resized_img.transpose((2, 0, 1)) / 255
|
||||
resized_img -= 0.5
|
||||
resized_img /= 0.5
|
||||
data['image'] = resized_img.astype(np.float32)
|
||||
return data
|
||||
|
||||
|
||||
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
valid_ratio = 1.0
|
||||
# make sure new_width is an integral multiple of width_divisor.
|
||||
width_divisor = int(1 / width_downsample_ratio)
|
||||
# resize
|
||||
ratio = w / float(h)
|
||||
resize_w = math.ceil(imgH * ratio)
|
||||
if resize_w % width_divisor != 0:
|
||||
resize_w = round(resize_w / width_divisor) * width_divisor
|
||||
if imgW_min is not None:
|
||||
resize_w = max(imgW_min, resize_w)
|
||||
if imgW_max is not None:
|
||||
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
|
||||
resize_w = min(imgW_max, resize_w)
|
||||
resized_image = cv2.resize(img, (resize_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
# norm
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
resize_shape = resized_image.shape
|
||||
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
|
||||
padding_im[:, :, 0:resize_w] = resized_image
|
||||
pad_shape = padding_im.shape
|
||||
|
||||
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||
|
||||
|
||||
def resize_norm_img(img, image_shape, padding=True):
|
||||
imgC, imgH, imgW = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
if not padding:
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_w = imgW
|
||||
else:
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
valid_ratio = min(1.0, float(resized_w / imgW))
|
||||
return padding_im, valid_ratio
|
||||
|
||||
|
||||
def resize_norm_img_chinese(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
# todo: change to 0 and modified image shape
|
||||
max_wh_ratio = imgW * 1.0 / imgH
|
||||
h, w = img.shape[0], img.shape[1]
|
||||
ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, ratio)
|
||||
imgW = int(imgH * max_wh_ratio)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
valid_ratio = min(1.0, float(resized_w / imgW))
|
||||
return padding_im, valid_ratio
|
||||
|
||||
|
||||
def resize_norm_img_srn(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
img_black = np.zeros((imgH, imgW))
|
||||
im_hei = img.shape[0]
|
||||
im_wid = img.shape[1]
|
||||
|
||||
if im_wid <= im_hei * 1:
|
||||
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||
elif im_wid <= im_hei * 2:
|
||||
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||
elif im_wid <= im_hei * 3:
|
||||
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||
else:
|
||||
img_new = cv2.resize(img, (imgW, imgH))
|
||||
|
||||
img_np = np.asarray(img_new)
|
||||
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||
img_black[:, 0:img_np.shape[1]] = img_np
|
||||
img_black = img_black[:, :, np.newaxis]
|
||||
|
||||
row, col, c = img_black.shape
|
||||
c = 1
|
||||
|
||||
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||
|
||||
|
||||
def srn_other_inputs(image_shape, num_heads, max_text_length):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||
|
||||
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
||||
(feature_dim, 1)).astype('int64')
|
||||
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
||||
(max_text_length, 1)).astype('int64')
|
||||
|
||||
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||
[1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
|
||||
[num_heads, 1, 1]) * [-1e9]
|
||||
|
||||
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||
[1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
|
||||
[num_heads, 1, 1]) * [-1e9]
|
||||
|
||||
return [
|
||||
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2
|
||||
]
|
||||
|
||||
|
||||
def flag():
|
||||
"""
|
||||
flag
|
||||
"""
|
||||
return 1 if random.random() > 0.5000001 else -1
|
||||
|
||||
|
||||
def cvtColor(img):
|
||||
"""
|
||||
cvtColor
|
||||
"""
|
||||
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
||||
delta = 0.001 * random.random() * flag()
|
||||
hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta)
|
||||
new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
|
||||
return new_img
|
||||
|
||||
|
||||
def blur(img):
|
||||
"""
|
||||
blur
|
||||
"""
|
||||
h, w, _ = img.shape
|
||||
if h > 10 and w > 10:
|
||||
return cv2.GaussianBlur(img, (5, 5), 1)
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
def jitter(img):
|
||||
"""
|
||||
jitter
|
||||
"""
|
||||
w, h, _ = img.shape
|
||||
if h > 10 and w > 10:
|
||||
thres = min(w, h)
|
||||
s = int(random.random() * thres * 0.01)
|
||||
src_img = img.copy()
|
||||
for i in range(s):
|
||||
img[i:, i:, :] = src_img[:w - i, :h - i, :]
|
||||
return img
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
def add_gasuss_noise(image, mean=0, var=0.1):
|
||||
"""
|
||||
Gasuss noise
|
||||
"""
|
||||
|
||||
noise = np.random.normal(mean, var**0.5, image.shape)
|
||||
out = image + 0.5 * noise
|
||||
out = np.clip(out, 0, 255)
|
||||
out = np.uint8(out)
|
||||
return out
|
||||
|
||||
|
||||
def get_crop(image):
|
||||
"""
|
||||
random crop
|
||||
"""
|
||||
h, w, _ = image.shape
|
||||
top_min = 1
|
||||
top_max = 8
|
||||
top_crop = int(random.randint(top_min, top_max))
|
||||
top_crop = min(top_crop, h - 1)
|
||||
crop_img = image.copy()
|
||||
ratio = random.randint(0, 1)
|
||||
if ratio:
|
||||
crop_img = crop_img[top_crop:h, :, :]
|
||||
else:
|
||||
crop_img = crop_img[0:h - top_crop, :, :]
|
||||
return crop_img
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Config
|
||||
"""
|
||||
|
||||
def __init__(self, use_tia):
|
||||
self.anglex = random.random() * 30
|
||||
self.angley = random.random() * 15
|
||||
self.anglez = random.random() * 10
|
||||
self.fov = 42
|
||||
self.r = 0
|
||||
self.shearx = random.random() * 0.3
|
||||
self.sheary = random.random() * 0.05
|
||||
self.borderMode = cv2.BORDER_REPLICATE
|
||||
self.use_tia = use_tia
|
||||
|
||||
def make(self, w, h, ang):
|
||||
"""
|
||||
make
|
||||
"""
|
||||
self.anglex = random.random() * 5 * flag()
|
||||
self.angley = random.random() * 5 * flag()
|
||||
self.anglez = -1 * random.random() * int(ang) * flag()
|
||||
self.fov = 42
|
||||
self.r = 0
|
||||
self.shearx = 0
|
||||
self.sheary = 0
|
||||
self.borderMode = cv2.BORDER_REPLICATE
|
||||
self.w = w
|
||||
self.h = h
|
||||
|
||||
self.perspective = self.use_tia
|
||||
self.stretch = self.use_tia
|
||||
self.distort = self.use_tia
|
||||
|
||||
self.crop = True
|
||||
self.affine = False
|
||||
self.reverse = True
|
||||
self.noise = True
|
||||
self.jitter = True
|
||||
self.blur = True
|
||||
self.color = True
|
||||
|
||||
|
||||
def rad(x):
|
||||
"""
|
||||
rad
|
||||
"""
|
||||
return x * np.pi / 180
|
||||
|
||||
|
||||
def get_warpR(config):
|
||||
"""
|
||||
get_warpR
|
||||
"""
|
||||
anglex, angley, anglez, fov, w, h, r = \
|
||||
config.anglex, config.angley, config.anglez, config.fov, config.w, config.h, config.r
|
||||
if w > 69 and w < 112:
|
||||
anglex = anglex * 1.5
|
||||
|
||||
z = np.sqrt(w**2 + h**2) / 2 / np.tan(rad(fov / 2))
|
||||
# Homogeneous coordinate transformation matrix
|
||||
rx = np.array([[1, 0, 0, 0],
|
||||
[0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0], [
|
||||
0,
|
||||
-np.sin(rad(anglex)),
|
||||
np.cos(rad(anglex)),
|
||||
0,
|
||||
], [0, 0, 0, 1]], np.float32)
|
||||
ry = np.array([[np.cos(rad(angley)), 0, np.sin(rad(angley)), 0],
|
||||
[0, 1, 0, 0], [
|
||||
-np.sin(rad(angley)),
|
||||
0,
|
||||
np.cos(rad(angley)),
|
||||
0,
|
||||
], [0, 0, 0, 1]], np.float32)
|
||||
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0],
|
||||
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0],
|
||||
[0, 0, 1, 0], [0, 0, 0, 1]], np.float32)
|
||||
r = rx.dot(ry).dot(rz)
|
||||
# generate 4 points
|
||||
pcenter = np.array([h / 2, w / 2, 0, 0], np.float32)
|
||||
p1 = np.array([0, 0, 0, 0], np.float32) - pcenter
|
||||
p2 = np.array([w, 0, 0, 0], np.float32) - pcenter
|
||||
p3 = np.array([0, h, 0, 0], np.float32) - pcenter
|
||||
p4 = np.array([w, h, 0, 0], np.float32) - pcenter
|
||||
dst1 = r.dot(p1)
|
||||
dst2 = r.dot(p2)
|
||||
dst3 = r.dot(p3)
|
||||
dst4 = r.dot(p4)
|
||||
list_dst = np.array([dst1, dst2, dst3, dst4])
|
||||
org = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
|
||||
dst = np.zeros((4, 2), np.float32)
|
||||
# Project onto the image plane
|
||||
dst[:, 0] = list_dst[:, 0] * z / (z - list_dst[:, 2]) + pcenter[0]
|
||||
dst[:, 1] = list_dst[:, 1] * z / (z - list_dst[:, 2]) + pcenter[1]
|
||||
|
||||
warpR = cv2.getPerspectiveTransform(org, dst)
|
||||
|
||||
dst1, dst2, dst3, dst4 = dst
|
||||
r1 = int(min(dst1[1], dst2[1]))
|
||||
r2 = int(max(dst3[1], dst4[1]))
|
||||
c1 = int(min(dst1[0], dst3[0]))
|
||||
c2 = int(max(dst2[0], dst4[0]))
|
||||
|
||||
try:
|
||||
ratio = min(1.0 * h / (r2 - r1), 1.0 * w / (c2 - c1))
|
||||
|
||||
dx = -c1
|
||||
dy = -r1
|
||||
T1 = np.float32([[1., 0, dx], [0, 1., dy], [0, 0, 1.0 / ratio]])
|
||||
ret = T1.dot(warpR)
|
||||
except:
|
||||
ratio = 1.0
|
||||
T1 = np.float32([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
|
||||
ret = T1
|
||||
return ret, (-r1, -c1), ratio, dst
|
||||
|
||||
|
||||
def get_warpAffine(config):
|
||||
"""
|
||||
get_warpAffine
|
||||
"""
|
||||
anglez = config.anglez
|
||||
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0],
|
||||
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32)
|
||||
return rz
|
||||
|
||||
|
||||
def warp(img, ang, use_tia=True, prob=0.4):
    """Apply the configured set of random augmentations to an image.

    Each enabled augmentation fires independently with probability *prob*
    (``jitter`` is unconditional when enabled). The TIA warps and the crop
    additionally require the image to be at least 20x20.

    :param img: HxWxC input image.
    :param ang: rotation angle forwarded to ``Config.make``.
    :param use_tia: enable the TIA distort/stretch/perspective family.
    :param prob: per-augmentation trigger probability.
    :return: augmented image (may be the input unchanged).
    """
    h, w, _ = img.shape
    config = Config(use_tia=use_tia)
    config.make(w, h, ang)
    out = img

    # NOTE: the random.random() draw happens BEFORE the size check, exactly
    # as in the original short-circuit evaluation, so the RNG stream is
    # consumed identically.
    def _roll():
        return random.random() <= prob

    def _big_enough():
        img_height, img_width = img.shape[0:2]
        return img_height >= 20 and img_width >= 20

    if config.distort and _roll() and _big_enough():
        out = tia_distort(out, random.randint(3, 6))

    if config.stretch and _roll() and _big_enough():
        out = tia_stretch(out, random.randint(3, 6))

    if config.perspective and _roll():
        out = tia_perspective(out)

    if config.crop and _roll() and _big_enough():
        out = get_crop(out)

    if config.blur and _roll():
        out = blur(out)
    if config.color and _roll():
        out = cvtColor(out)
    if config.jitter:
        out = jitter(out)
    if config.noise and _roll():
        out = add_gasuss_noise(out)
    if config.reverse and _roll():
        # Invert intensities (negative image).
        out = 255 - out
    return out
|
||||
@@ -1,777 +0,0 @@
|
||||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
"""
|
||||
This part of the code is referred from:
|
||||
https://github.com/songdejia/EAST/blob/master/data_utils.py
|
||||
"""
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
|
||||
__all__ = ['SASTProcessTrain']
|
||||
|
||||
|
||||
class SASTProcessTrain(object):
    """Training-time label generator for the SAST text detector.

    From polygon annotations it produces the score map, border (tbo) map,
    vertex-offset (tvo) map, centroid-offset (tco) map and a training mask.
    """

    def __init__(self,
                 image_shape=[512, 512],
                 min_crop_size=24,
                 min_crop_side_ratio=0.3,
                 min_text_size=10,
                 max_text_size=512,
                 **kwargs):
        # NOTE(review): only image_shape[1] is used, so a square network
        # input is assumed — confirm against the config that builds this op.
        self.input_size = image_shape[1]
        self.min_crop_size = min_crop_size
        self.min_crop_side_ratio = min_crop_side_ratio
        self.min_text_size = min_text_size
        self.max_text_size = max_text_size

    def quad_area(self, poly):
        """
        Signed area of a 4-point polygon (shoelace formula).
        :param poly: 4x2 array of corner points
        :return: signed area; the sign encodes the winding direction
        """
        edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
                (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
                (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
                (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
        return np.sum(edge) / 2.

    def gen_quad_from_poly(self, poly):
        """
        Generate min area quad from poly.

        The min-area rectangle corners are rotated so that corner 0 best
        aligns with the polygon's own start/middle/end points.
        """
        point_num = poly.shape[0]
        min_area_quad = np.zeros((4, 2), dtype=np.float32)
        if True:
            rect = cv2.minAreaRect(poly.astype(
                np.int32))  # (center (x,y), (width, height), angle of rotation)
            center_point = rect[0]  # NOTE: unused in this method
            box = np.array(cv2.boxPoints(rect))

            # Pick the cyclic shift of the box whose corners lie closest to
            # the polygon's first / middle / last points.
            first_point_idx = 0
            min_dist = 1e4
            for i in range(4):
                dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
                       np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
                       np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
                       np.linalg.norm(box[(i + 3) % 4] - poly[-1])
                if dist < min_dist:
                    min_dist = dist
                    first_point_idx = i
            for i in range(4):
                min_area_quad[i] = box[(first_point_idx + i) % 4]

        return min_area_quad

    def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
        """
        check so that the text poly is in the same direction,
        and also filter some invalid polygons
        :param polys: Nx16x2 polygon array (mutated in place by clipping)
        :param tags: per-polygon ignore flags
        :param xxx_todo_changeme: (h, w) image size used for clipping
        :return: (valid polys, valid tags, horizontal/vertical tags)
        """
        (h, w) = xxx_todo_changeme
        if polys.shape[0] == 0:
            return polys, np.array([]), np.array([])
        # Clip coordinates into the image.
        polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
        polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)

        validated_polys = []
        validated_tags = []
        hv_tags = []
        for poly, tag in zip(polys, tags):
            quad = self.gen_quad_from_poly(poly)
            p_area = self.quad_area(quad)
            # Degenerate (near-zero area) polygons are dropped.
            if abs(p_area) < 1:
                print('invalid poly')
                continue
            if p_area > 0:
                # Positive area means reversed winding: mark as ignored and
                # flip the point order back to the expected direction.
                if tag == False:
                    print('poly in wrong direction')
                    tag = True  # reversed cases should be ignore
                poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
                             1), :]
                quad = quad[(0, 3, 2, 1), :]

            len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
                                                                       quad[2])
            len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
                                                                       quad[2])
            # hv_tag: 1 = horizontal-ish text, 0 = clearly vertical.
            hv_tag = 1

            if len_w * 2.0 < len_h:
                hv_tag = 0

            validated_polys.append(poly)
            validated_tags.append(tag)
            hv_tags.append(hv_tag)
        return np.array(validated_polys), np.array(validated_tags), np.array(
            hv_tags)

    def crop_area(self,
                  im,
                  polys,
                  tags,
                  hv_tags,
                  crop_background=False,
                  max_tries=25):
        """
        make random crop from the input image
        :param im:
        :param polys:
        :param tags:
        :param crop_background: if True, a text-free crop may be returned
        :param max_tries: number of random attempts (was 50, reduced to 25)
        :return: (image, polys, tags, hv_tags) — the originals when no
            acceptable crop was found
        """
        h, w, _ = im.shape
        pad_h = h // 10
        pad_w = w // 10
        # Mark which rows/columns are covered by any text polygon so crop
        # boundaries can be chosen in the gaps between texts.
        h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
        w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
        for poly in polys:
            poly = np.round(poly, decimals=0).astype(np.int32)
            minx = np.min(poly[:, 0])
            maxx = np.max(poly[:, 0])
            w_array[minx + pad_w:maxx + pad_w] = 1
            miny = np.min(poly[:, 1])
            maxy = np.max(poly[:, 1])
            h_array[miny + pad_h:maxy + pad_h] = 1
        # ensure the cropped area not across a text
        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        if len(h_axis) == 0 or len(w_axis) == 0:
            return im, polys, tags, hv_tags
        for i in range(max_tries):
            xx = np.random.choice(w_axis, size=2)
            xmin = np.min(xx) - pad_w
            xmax = np.max(xx) - pad_w
            xmin = np.clip(xmin, 0, w - 1)
            xmax = np.clip(xmax, 0, w - 1)
            yy = np.random.choice(h_axis, size=2)
            ymin = np.min(yy) - pad_h
            ymax = np.max(yy) - pad_h
            ymin = np.clip(ymin, 0, h - 1)
            ymax = np.clip(ymax, 0, h - 1)
            # if xmax - xmin < ARGS.min_crop_side_ratio * w or \
            #     ymax - ymin < ARGS.min_crop_side_ratio * h:
            if xmax - xmin < self.min_crop_size or \
                    ymax - ymin < self.min_crop_size:
                # area too small
                continue
            if polys.shape[0] != 0:
                # A polygon is kept only when all 4 of its extremal checks
                # fall fully inside the candidate crop.
                poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
                                    & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
                selected_polys = np.where(
                    np.sum(poly_axis_in_area, axis=1) == 4)[0]
            else:
                selected_polys = []
            if len(selected_polys) == 0:
                # no text in this area
                if crop_background:
                    return im[ymin : ymax + 1, xmin : xmax + 1, :], \
                        polys[selected_polys], tags[selected_polys], hv_tags[selected_polys]
                else:
                    continue
            im = im[ymin:ymax + 1, xmin:xmax + 1, :]
            polys = polys[selected_polys]
            tags = tags[selected_polys]
            hv_tags = hv_tags[selected_polys]
            # Shift polygon coordinates into the crop's frame.
            polys[:, :, 0] -= xmin
            polys[:, :, 1] -= ymin
            return im, polys, tags, hv_tags

        return im, polys, tags, hv_tags

    def generate_direction_map(self, poly_quads, direction_map):
        """
        Rasterize a per-quad reading-direction label (dx, dy, 1/height)
        into direction_map for every quad of a polygon.
        """
        width_list = []
        height_list = []
        for quad in poly_quads:
            quad_w = (np.linalg.norm(quad[0] - quad[1]) +
                      np.linalg.norm(quad[2] - quad[3])) / 2.0
            quad_h = (np.linalg.norm(quad[0] - quad[3]) +
                      np.linalg.norm(quad[2] - quad[1])) / 2.0
            width_list.append(quad_w)
            height_list.append(quad_h)
        norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
        average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)

        for quad in poly_quads:
            # Direction = right-edge midpoint minus left-edge midpoint,
            # normalized then rescaled to the average quad width.
            direct_vector_full = (
                (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
            direct_vector = direct_vector_full / (
                np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
            direction_label = tuple(
                map(float, [
                    direct_vector[0], direct_vector[1], 1.0 / (average_height +
                                                               1e-6)
                ]))
            cv2.fillPoly(direction_map,
                         quad.round().astype(np.int32)[np.newaxis, :, :],
                         direction_label)
        return direction_map

    def calculate_average_height(self, poly_quads):
        """
        Average quad height over the polygon, floored at 1.0.
        """
        height_list = []
        for quad in poly_quads:
            quad_h = (np.linalg.norm(quad[0] - quad[3]) +
                      np.linalg.norm(quad[2] - quad[1])) / 2.0
            height_list.append(quad_h)
        average_height = max(sum(height_list) / len(height_list), 1.0)
        return average_height

    def generate_tcl_label(self,
                           hw,
                           polys,
                           tags,
                           ds_ratio,
                           tcl_ratio=0.3,
                           shrink_ratio_of_width=0.15):
        """
        Generate polygon (text-center-line) labels.
        :param hw: (h, w) of the network input
        :param ds_ratio: downsample ratio of the label maps
        :param tcl_ratio: height fraction of the poly kept as center line
        :param shrink_ratio_of_width: lengthwise shrink of the center line
        :return: (score_map, tbo_map, training_mask)
        """
        h, w = hw
        h, w = int(h * ds_ratio), int(w * ds_ratio)
        polys = polys * ds_ratio

        score_map = np.zeros(
            (
                h,
                w, ), dtype=np.float32)
        tbo_map = np.zeros((h, w, 5), dtype=np.float32)
        training_mask = np.ones(
            (
                h,
                w, ), dtype=np.float32)
        # Default direction label is (0, 0, 1).
        direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
            [1, 1, 3]).astype(np.float32)

        for poly_idx, poly_tag in enumerate(zip(polys, tags)):
            poly = poly_tag[0]
            tag = poly_tag[1]

            # generate min_area_quad
            min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
            min_area_quad_h = 0.5 * (
                np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
                np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
            min_area_quad_w = 0.5 * (
                np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
                np.linalg.norm(min_area_quad[2] - min_area_quad[3]))

            # Skip texts outside the trainable size range.
            if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
                    or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
                continue

            if tag:
                # Ignored text: down-weight its region in the training mask.
                # continue
                cv2.fillPoly(training_mask,
                             poly.astype(np.int32)[np.newaxis, :, :], 0.15)
            else:
                tcl_poly = self.poly2tcl(poly, tcl_ratio)
                tcl_quads = self.poly2quads(tcl_poly)
                poly_quads = self.poly2quads(poly)
                # stcl map
                stcl_quads, quad_index = self.shrink_poly_along_width(
                    tcl_quads,
                    shrink_ratio_of_width=shrink_ratio_of_width,
                    expand_height_ratio=1.0 / tcl_ratio)
                # generate tcl map
                cv2.fillPoly(score_map,
                             np.round(stcl_quads).astype(np.int32), 1.0)

                # generate tbo map
                for idx, quad in enumerate(stcl_quads):
                    quad_mask = np.zeros((h, w), dtype=np.float32)
                    quad_mask = cv2.fillPoly(
                        quad_mask,
                        np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
                    tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
                                                quad_mask, tbo_map)
        return score_map, tbo_map, training_mask

    def generate_tvo_and_tco(self,
                             hw,
                             polys,
                             tags,
                             tcl_ratio=0.3,
                             ds_ratio=0.25):
        """
        Generate tcl map, tvo map and tbo map.
        :return: (tvo_map HxWx9, tco_map HxWx3)
        """
        h, w = hw
        h, w = int(h * ds_ratio), int(w * ds_ratio)
        polys = polys * ds_ratio
        poly_mask = np.zeros((h, w), dtype=np.float32)

        # Channels 0..7 hold the pixel's own (x, y) repeated for the 4
        # vertices; the vertex coordinates are subtracted below to obtain
        # offsets. Last channel is a normalization term.
        tvo_map = np.ones((9, h, w), dtype=np.float32)
        tvo_map[0:-1:2] = np.tile(np.arange(0, w), (h, 1))
        tvo_map[1:-1:2] = np.tile(np.arange(0, w), (h, 1)).T
        poly_tv_xy_map = np.zeros((8, h, w), dtype=np.float32)

        # tco map
        tco_map = np.ones((3, h, w), dtype=np.float32)
        tco_map[0] = np.tile(np.arange(0, w), (h, 1))
        tco_map[1] = np.tile(np.arange(0, w), (h, 1)).T
        poly_tc_xy_map = np.zeros((2, h, w), dtype=np.float32)

        poly_short_edge_map = np.ones((h, w), dtype=np.float32)

        for poly, poly_tag in zip(polys, tags):

            # Ignored texts contribute no tvo/tco supervision.
            if poly_tag == True:
                continue

            # adjust point order for vertical poly
            poly = self.adjust_point(poly)

            # generate min_area_quad
            min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
            min_area_quad_h = 0.5 * (
                np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
                np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
            min_area_quad_w = 0.5 * (
                np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
                np.linalg.norm(min_area_quad[2] - min_area_quad[3]))

            # generate tcl map and text, 128 * 128
            tcl_poly = self.poly2tcl(poly, tcl_ratio)

            # generate poly_tv_xy_map
            for idx in range(4):
                cv2.fillPoly(
                    poly_tv_xy_map[2 * idx],
                    np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
                    float(min(max(min_area_quad[idx, 0], 0), w)))
                cv2.fillPoly(
                    poly_tv_xy_map[2 * idx + 1],
                    np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
                    float(min(max(min_area_quad[idx, 1], 0), h)))

            # generate poly_tc_xy_map
            for idx in range(2):
                cv2.fillPoly(
                    poly_tc_xy_map[idx],
                    np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
                    float(center_point[idx]))

            # generate poly_short_edge_map
            cv2.fillPoly(
                poly_short_edge_map,
                np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
                float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))

            # generate poly_mask and training_mask
            cv2.fillPoly(poly_mask,
                         np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
                         1)

        # Offsets = pixel position - vertex/center position, normalized by
        # the short edge; zero outside text regions.
        tvo_map *= poly_mask
        tvo_map[:8] -= poly_tv_xy_map
        tvo_map[-1] /= poly_short_edge_map
        tvo_map = tvo_map.transpose((1, 2, 0))

        tco_map *= poly_mask
        tco_map[:2] -= poly_tc_xy_map
        tco_map[-1] /= poly_short_edge_map
        tco_map = tco_map.transpose((1, 2, 0))

        return tvo_map, tco_map

    def adjust_point(self, poly):
        """
        adjust point order.

        Rotates the point sequence so the long (reading) direction comes
        first, for vertical or sharply-bent polygons.
        """
        point_num = poly.shape[0]
        if point_num == 4:
            len_1 = np.linalg.norm(poly[0] - poly[1])
            len_2 = np.linalg.norm(poly[1] - poly[2])
            len_3 = np.linalg.norm(poly[2] - poly[3])
            len_4 = np.linalg.norm(poly[3] - poly[0])

            if (len_1 + len_3) * 1.5 < (len_2 + len_4):
                poly = poly[[1, 2, 3, 0], :]

        elif point_num > 4:
            vector_1 = poly[0] - poly[1]
            vector_2 = poly[1] - poly[2]
            cos_theta = np.dot(vector_1, vector_2) / (
                np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
            theta = np.arccos(np.round(cos_theta, decimals=4))

            if abs(theta) > (70 / 180 * math.pi):
                index = list(range(1, point_num)) + [0]
                poly = poly[np.array(index), :]
        return poly

    def gen_min_area_quad_from_poly(self, poly):
        """
        Generate min area quad from poly.
        :return: (quad 4x2, center point (x, y))
        """
        point_num = poly.shape[0]
        min_area_quad = np.zeros((4, 2), dtype=np.float32)
        if point_num == 4:
            min_area_quad = poly
            center_point = np.sum(poly, axis=0) / 4
        else:
            rect = cv2.minAreaRect(poly.astype(
                np.int32))  # (center (x,y), (width, height), angle of rotation)
            center_point = rect[0]
            box = np.array(cv2.boxPoints(rect))

            # Align box corner 0 with the polygon's start, as in
            # gen_quad_from_poly.
            first_point_idx = 0
            min_dist = 1e4
            for i in range(4):
                dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
                       np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
                       np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
                       np.linalg.norm(box[(i + 3) % 4] - poly[-1])
                if dist < min_dist:
                    min_dist = dist
                    first_point_idx = i

            for i in range(4):
                min_area_quad[i] = box[(first_point_idx + i) % 4]

        return min_area_quad, center_point

    def shrink_quad_along_width(self,
                                quad,
                                begin_width_ratio=0.,
                                end_width_ratio=1.):
        """
        Generate shrink_quad_along_width.

        Keeps only the [begin_width_ratio, end_width_ratio] slice of the
        quad along its width (reading) direction.
        """
        ratio_pair = np.array(
            [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
        p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
        p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
        return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])

    def shrink_poly_along_width(self,
                                quads,
                                shrink_ratio_of_width,
                                expand_height_ratio=1.0):
        """
        shrink poly with given length.
        :return: (shrunk quads, indices of the source quads kept)
        """
        upper_edge_list = []

        def get_cut_info(edge_len_list, cut_len):
            # Walk the edges until cut_len is consumed; returns the edge
            # index and the fractional position within that edge.
            # NOTE(review): implicitly returns None if cut_len exceeds the
            # total edge length — callers rely on it never doing so.
            for idx, edge_len in enumerate(edge_len_list):
                cut_len -= edge_len
                if cut_len <= 0.000001:
                    ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
                    return idx, ratio

        for quad in quads:
            upper_edge_len = np.linalg.norm(quad[0] - quad[1])
            upper_edge_list.append(upper_edge_len)

        # length of left edge and right edge.
        left_length = np.linalg.norm(quads[0][0] - quads[0][
            3]) * expand_height_ratio
        right_length = np.linalg.norm(quads[-1][1] - quads[-1][
            2]) * expand_height_ratio

        shrink_length = min(left_length, right_length,
                            sum(upper_edge_list)) * shrink_ratio_of_width
        # shrinking length
        upper_len_left = shrink_length
        upper_len_right = sum(upper_edge_list) - shrink_length

        left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
        left_quad = self.shrink_quad_along_width(
            quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
        right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
        right_quad = self.shrink_quad_along_width(
            quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)

        out_quad_list = []
        if left_idx == right_idx:
            out_quad_list.append(
                [left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
        else:
            out_quad_list.append(left_quad)
            for idx in range(left_idx + 1, right_idx):
                out_quad_list.append(quads[idx])
            out_quad_list.append(right_quad)

        return np.array(out_quad_list), list(range(left_idx, right_idx + 1))

    def vector_angle(self, A, B):
        """
        Calculate the angle between vector AB and x-axis positive direction.
        """
        AB = np.array([B[1] - A[1], B[0] - A[0]])
        return np.arctan2(*AB)

    def theta_line_cross_point(self, theta, point):
        """
        Calculate the line through given point and angle in ax + by + c =0 form.
        """
        x, y = point
        cos = np.cos(theta)
        sin = np.sin(theta)
        return [sin, -cos, cos * y - sin * x]

    def line_cross_two_point(self, A, B):
        """
        Calculate the line through given point A and B in ax + by + c =0 form.
        """
        angle = self.vector_angle(A, B)
        return self.theta_line_cross_point(angle, A)

    def average_angle(self, poly):
        """
        Calculate the average angle between left and right edge in given poly.
        """
        p0, p1, p2, p3 = poly
        angle30 = self.vector_angle(p3, p0)
        angle21 = self.vector_angle(p2, p1)
        return (angle30 + angle21) / 2

    def line_cross_point(self, line1, line2):
        """
        line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
        """
        a1, b1, c1 = line1
        a2, b2, c2 = line2
        d = a1 * b2 - a2 * b1

        if d == 0:
            # Parallel lines: fall back to the origin rather than raising.
            #print("line1", line1)
            #print("line2", line2)
            print('Cross point does not exist')
            return np.array([0, 0], dtype=np.float32)
        else:
            x = (b1 * c2 - b2 * c1) / d
            y = (a2 * c1 - a1 * c2) / d

        return np.array([x, y], dtype=np.float32)

    def quad2tcl(self, poly, ratio):
        """
        Generate center line by poly clock-wise point. (4, 2)
        """
        ratio_pair = np.array(
            [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
        p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
        p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
        return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])

    def poly2tcl(self, poly, ratio):
        """
        Generate center line by poly clock-wise point.
        """
        ratio_pair = np.array(
            [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
        tcl_poly = np.zeros_like(poly)
        point_num = poly.shape[0]

        # Each point is paired with its mirror on the opposite edge and
        # pulled toward the mid line by ratio.
        for idx in range(point_num // 2):
            point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
                                      ) * ratio_pair
            tcl_poly[idx] = point_pair[0]
            tcl_poly[point_num - 1 - idx] = point_pair[1]
        return tcl_poly

    def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
        """
        Generate tbo_map for give quad.

        For every pixel inside tcl_mask, store offsets to the upper and
        lower quad border along the quad's average direction, plus a
        height-normalization term.
        """
        # upper and lower line function: ax + by + c = 0;
        up_line = self.line_cross_two_point(quad[0], quad[1])
        lower_line = self.line_cross_two_point(quad[3], quad[2])

        quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
                        np.linalg.norm(quad[1] - quad[2]))
        quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
                        np.linalg.norm(quad[2] - quad[3]))

        # average angle of left and right line.
        angle = self.average_angle(quad)

        xy_in_poly = np.argwhere(tcl_mask == 1)
        for y, x in xy_in_poly:
            point = (x, y)
            line = self.theta_line_cross_point(angle, point)
            cross_point_upper = self.line_cross_point(up_line, line)
            cross_point_lower = self.line_cross_point(lower_line, line)
            ##FIX, offset reverse
            upper_offset_x, upper_offset_y = cross_point_upper - point
            lower_offset_x, lower_offset_y = cross_point_lower - point
            tbo_map[y, x, 0] = upper_offset_y
            tbo_map[y, x, 1] = upper_offset_x
            tbo_map[y, x, 2] = lower_offset_y
            tbo_map[y, x, 3] = lower_offset_x
            tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
        return tbo_map

    def poly2quads(self, poly):
        """
        Split poly into quads.
        """
        quad_list = []
        point_num = poly.shape[0]

        # point pair
        point_pair_list = []
        for idx in range(point_num // 2):
            point_pair = [poly[idx], poly[point_num - 1 - idx]]
            point_pair_list.append(point_pair)

        quad_num = point_num // 2 - 1
        for idx in range(quad_num):
            # reshape and adjust to clock-wise
            quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
                              ).reshape(4, 2)[[0, 2, 3, 1]])

        return np.array(quad_list)

    def __call__(self, data):
        """Build all SAST training targets for one sample dict.

        Expects data['image'], data['polys'], data['ignore_tags']; returns
        the dict extended with score/border/tvo/tco maps and the training
        mask, or None when the sample must be skipped.
        """
        im = data['image']
        text_polys = data['polys']
        text_tags = data['ignore_tags']
        if im is None:
            return None
        if text_polys.shape[0] == 0:
            return None

        h, w, _ = im.shape
        text_polys, text_tags, hv_tags = self.check_and_validate_polys(
            text_polys, text_tags, (h, w))

        if text_polys.shape[0] == 0:
            return None

        #set aspect ratio and keep area fix
        asp_scales = np.arange(1.0, 1.55, 0.1)
        asp_scale = np.random.choice(asp_scales)

        if np.random.rand() < 0.5:
            asp_scale = 1.0 / asp_scale
        asp_scale = math.sqrt(asp_scale)

        asp_wx = asp_scale
        asp_hy = 1.0 / asp_scale
        im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
        text_polys[:, :, 0] *= asp_wx
        text_polys[:, :, 1] *= asp_hy

        h, w, _ = im.shape
        if max(h, w) > 2048:
            # Cap the long side to bound memory/compute.
            rd_scale = 2048.0 / max(h, w)
            im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
            text_polys *= rd_scale
        h, w, _ = im.shape
        if min(h, w) < 16:
            return None

        #no background
        im, text_polys, text_tags, hv_tags = self.crop_area(im, \
            text_polys, text_tags, hv_tags, crop_background=False)

        if text_polys.shape[0] == 0:
            return None
        #continue for all ignore case
        if np.sum((text_tags * 1.0)) >= text_tags.size:
            return None
        new_h, new_w, _ = im.shape
        if (new_h is None) or (new_w is None):
            return None
        #resize image
        std_ratio = float(self.input_size) / max(new_w, new_h)
        rand_scales = np.array(
            [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
        rz_scale = std_ratio * np.random.choice(rand_scales)
        im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
        text_polys[:, :, 0] *= rz_scale
        text_polys[:, :, 1] *= rz_scale

        #add gaussian blur
        if np.random.rand() < 0.1 * 0.5:
            ks = np.random.permutation(5)[0] + 1
            ks = int(ks / 2) * 2 + 1  # force an odd kernel size
            im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
        #add brighter
        if np.random.rand() < 0.1 * 0.5:
            im = im * (1.0 + np.random.rand() * 0.5)
            im = np.clip(im, 0.0, 255.0)
        #add darker
        if np.random.rand() < 0.1 * 0.5:
            im = im * (1.0 - np.random.rand() * 0.5)
            im = np.clip(im, 0.0, 255.0)

        # Padding the im to [input_size, input_size]
        new_h, new_w, _ = im.shape
        if min(new_w, new_h) < self.input_size * 0.5:
            return None

        # Pad with the ImageNet channel means (BGR order).
        im_padded = np.ones(
            (self.input_size, self.input_size, 3), dtype=np.float32)
        im_padded[:, :, 2] = 0.485 * 255
        im_padded[:, :, 1] = 0.456 * 255
        im_padded[:, :, 0] = 0.406 * 255

        # Random the start position
        del_h = self.input_size - new_h
        del_w = self.input_size - new_w
        sh, sw = 0, 0
        if del_h > 1:
            sh = int(np.random.rand() * del_h)
        if del_w > 1:
            sw = int(np.random.rand() * del_w)

        # Padding
        im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
        text_polys[:, :, 0] += sw
        text_polys[:, :, 1] += sh

        score_map, border_map, training_mask = self.generate_tcl_label(
            (self.input_size, self.input_size), text_polys, text_tags, 0.25)

        # SAST head
        tvo_map, tco_map = self.generate_tvo_and_tco(
            (self.input_size, self.input_size),
            text_polys,
            text_tags,
            tcl_ratio=0.3,
            ds_ratio=0.25)
        # print("test--------tvo_map shape:", tvo_map.shape)

        # ImageNet normalization, then HWC -> CHW.
        im_padded[:, :, 2] -= 0.485 * 255
        im_padded[:, :, 1] -= 0.456 * 255
        im_padded[:, :, 0] -= 0.406 * 255
        im_padded[:, :, 2] /= (255.0 * 0.229)
        im_padded[:, :, 1] /= (255.0 * 0.224)
        im_padded[:, :, 0] /= (255.0 * 0.225)
        im_padded = im_padded.transpose((2, 0, 1))

        # BGR -> RGB channel flip on the CHW tensor.
        data['image'] = im_padded[::-1, :, :]
        data['score_map'] = score_map[np.newaxis, :, :]
        data['border_map'] = border_map.transpose((2, 0, 1))
        data['training_mask'] = training_mask[np.newaxis, :, :]
        data['tvo_map'] = tvo_map.transpose((2, 0, 1))
        data['tco_map'] = tco_map.transpose((2, 0, 1))
        return data
|
||||
@@ -1,60 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
import random
|
||||
from PIL import Image
|
||||
|
||||
from .rec_img_aug import resize_norm_img
|
||||
|
||||
|
||||
class SSLRotateResize(object):
    """Self-supervised rotation pretext-task transform.

    Produces 0/90/180/270-degree rotated, resized-and-normalized copies of
    the input image, stacked along a new leading axis, with rotation-index
    labels 0..3.
    """

    def __init__(self,
                 image_shape,
                 padding=False,
                 select_all=True,
                 mode="train",
                 **kwargs):
        self.image_shape = image_shape
        self.padding = padding
        self.select_all = select_all
        self.mode = mode

    def __call__(self, data):
        img = data["image"]

        # Build the three rotations by chaining quarter turns.
        data["image_r90"] = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        data["image_r180"] = cv2.rotate(data["image_r90"],
                                        cv2.ROTATE_90_CLOCKWISE)
        data["image_r270"] = cv2.rotate(data["image_r180"],
                                        cv2.ROTATE_90_CLOCKWISE)

        # Normalize each rotation and remove the intermediates from data.
        rotation_keys = ("image", "image_r90", "image_r180", "image_r270")
        normalized = [
            resize_norm_img(
                data.pop(key),
                image_shape=self.image_shape,
                padding=self.padding)[0] for key in rotation_keys
        ]
        data["image"] = np.stack(normalized, axis=0)
        data["label"] = np.array(list(range(4)))

        if not self.select_all:
            data["image"] = data["image"][0::2]  # just choose 0 and 180
            data["label"] = data["label"][0:2]  # label needs to be continuous
        if self.mode == "test":
            # Test mode keeps only the unrotated sample.
            data["image"] = data["image"][0]
            data["label"] = data["label"][0]
        return data
|
||||
@@ -1,17 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .augment import tia_perspective, tia_distort, tia_stretch
|
||||
|
||||
__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
|
||||
@@ -1,120 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is referred from:
|
||||
https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from .warp_mls import WarpMLS
|
||||
|
||||
|
||||
def tia_distort(src, segment=4):
|
||||
img_h, img_w = src.shape[:2]
|
||||
|
||||
cut = img_w // segment
|
||||
thresh = cut // 3
|
||||
|
||||
src_pts = list()
|
||||
dst_pts = list()
|
||||
|
||||
src_pts.append([0, 0])
|
||||
src_pts.append([img_w, 0])
|
||||
src_pts.append([img_w, img_h])
|
||||
src_pts.append([0, img_h])
|
||||
|
||||
dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
|
||||
dst_pts.append(
|
||||
[img_w - np.random.randint(thresh), np.random.randint(thresh)])
|
||||
dst_pts.append(
|
||||
[img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)])
|
||||
dst_pts.append(
|
||||
[np.random.randint(thresh), img_h - np.random.randint(thresh)])
|
||||
|
||||
half_thresh = thresh * 0.5
|
||||
|
||||
for cut_idx in np.arange(1, segment, 1):
|
||||
src_pts.append([cut * cut_idx, 0])
|
||||
src_pts.append([cut * cut_idx, img_h])
|
||||
dst_pts.append([
|
||||
cut * cut_idx + np.random.randint(thresh) - half_thresh,
|
||||
np.random.randint(thresh) - half_thresh
|
||||
])
|
||||
dst_pts.append([
|
||||
cut * cut_idx + np.random.randint(thresh) - half_thresh,
|
||||
img_h + np.random.randint(thresh) - half_thresh
|
||||
])
|
||||
|
||||
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
|
||||
dst = trans.generate()
|
||||
|
||||
return dst
|
||||
|
||||
|
||||
def tia_stretch(src, segment=4):
|
||||
img_h, img_w = src.shape[:2]
|
||||
|
||||
cut = img_w // segment
|
||||
thresh = cut * 4 // 5
|
||||
|
||||
src_pts = list()
|
||||
dst_pts = list()
|
||||
|
||||
src_pts.append([0, 0])
|
||||
src_pts.append([img_w, 0])
|
||||
src_pts.append([img_w, img_h])
|
||||
src_pts.append([0, img_h])
|
||||
|
||||
dst_pts.append([0, 0])
|
||||
dst_pts.append([img_w, 0])
|
||||
dst_pts.append([img_w, img_h])
|
||||
dst_pts.append([0, img_h])
|
||||
|
||||
half_thresh = thresh * 0.5
|
||||
|
||||
for cut_idx in np.arange(1, segment, 1):
|
||||
move = np.random.randint(thresh) - half_thresh
|
||||
src_pts.append([cut * cut_idx, 0])
|
||||
src_pts.append([cut * cut_idx, img_h])
|
||||
dst_pts.append([cut * cut_idx + move, 0])
|
||||
dst_pts.append([cut * cut_idx + move, img_h])
|
||||
|
||||
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
|
||||
dst = trans.generate()
|
||||
|
||||
return dst
|
||||
|
||||
|
||||
def tia_perspective(src):
|
||||
img_h, img_w = src.shape[:2]
|
||||
|
||||
thresh = img_h // 2
|
||||
|
||||
src_pts = list()
|
||||
dst_pts = list()
|
||||
|
||||
src_pts.append([0, 0])
|
||||
src_pts.append([img_w, 0])
|
||||
src_pts.append([img_w, img_h])
|
||||
src_pts.append([0, img_h])
|
||||
|
||||
dst_pts.append([0, np.random.randint(thresh)])
|
||||
dst_pts.append([img_w, np.random.randint(thresh)])
|
||||
dst_pts.append([img_w, img_h - np.random.randint(thresh)])
|
||||
dst_pts.append([0, img_h - np.random.randint(thresh)])
|
||||
|
||||
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
|
||||
dst = trans.generate()
|
||||
|
||||
return dst
|
||||
@@ -1,168 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/warp_mls.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class WarpMLS:
|
||||
def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
|
||||
self.src = src
|
||||
self.src_pts = src_pts
|
||||
self.dst_pts = dst_pts
|
||||
self.pt_count = len(self.dst_pts)
|
||||
self.dst_w = dst_w
|
||||
self.dst_h = dst_h
|
||||
self.trans_ratio = trans_ratio
|
||||
self.grid_size = 100
|
||||
self.rdx = np.zeros((self.dst_h, self.dst_w))
|
||||
self.rdy = np.zeros((self.dst_h, self.dst_w))
|
||||
|
||||
@staticmethod
|
||||
def __bilinear_interp(x, y, v11, v12, v21, v22):
|
||||
return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
|
||||
(1 - y) + v22 * y) * x
|
||||
|
||||
def generate(self):
|
||||
self.calc_delta()
|
||||
return self.gen_img()
|
||||
|
||||
def calc_delta(self):
|
||||
w = np.zeros(self.pt_count, dtype=np.float32)
|
||||
|
||||
if self.pt_count < 2:
|
||||
return
|
||||
|
||||
i = 0
|
||||
while 1:
|
||||
if self.dst_w <= i < self.dst_w + self.grid_size - 1:
|
||||
i = self.dst_w - 1
|
||||
elif i >= self.dst_w:
|
||||
break
|
||||
|
||||
j = 0
|
||||
while 1:
|
||||
if self.dst_h <= j < self.dst_h + self.grid_size - 1:
|
||||
j = self.dst_h - 1
|
||||
elif j >= self.dst_h:
|
||||
break
|
||||
|
||||
sw = 0
|
||||
swp = np.zeros(2, dtype=np.float32)
|
||||
swq = np.zeros(2, dtype=np.float32)
|
||||
new_pt = np.zeros(2, dtype=np.float32)
|
||||
cur_pt = np.array([i, j], dtype=np.float32)
|
||||
|
||||
k = 0
|
||||
for k in range(self.pt_count):
|
||||
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
|
||||
break
|
||||
|
||||
w[k] = 1. / (
|
||||
(i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) +
|
||||
(j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]))
|
||||
|
||||
sw += w[k]
|
||||
swp = swp + w[k] * np.array(self.dst_pts[k])
|
||||
swq = swq + w[k] * np.array(self.src_pts[k])
|
||||
|
||||
if k == self.pt_count - 1:
|
||||
pstar = 1 / sw * swp
|
||||
qstar = 1 / sw * swq
|
||||
|
||||
miu_s = 0
|
||||
for k in range(self.pt_count):
|
||||
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
|
||||
continue
|
||||
pt_i = self.dst_pts[k] - pstar
|
||||
miu_s += w[k] * np.sum(pt_i * pt_i)
|
||||
|
||||
cur_pt -= pstar
|
||||
cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])
|
||||
|
||||
for k in range(self.pt_count):
|
||||
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
|
||||
continue
|
||||
|
||||
pt_i = self.dst_pts[k] - pstar
|
||||
pt_j = np.array([-pt_i[1], pt_i[0]])
|
||||
|
||||
tmp_pt = np.zeros(2, dtype=np.float32)
|
||||
tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \
|
||||
np.sum(pt_j * cur_pt) * self.src_pts[k][1]
|
||||
tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \
|
||||
np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
|
||||
tmp_pt *= (w[k] / miu_s)
|
||||
new_pt += tmp_pt
|
||||
|
||||
new_pt += qstar
|
||||
else:
|
||||
new_pt = self.src_pts[k]
|
||||
|
||||
self.rdx[j, i] = new_pt[0] - i
|
||||
self.rdy[j, i] = new_pt[1] - j
|
||||
|
||||
j += self.grid_size
|
||||
i += self.grid_size
|
||||
|
||||
def gen_img(self):
|
||||
src_h, src_w = self.src.shape[:2]
|
||||
dst = np.zeros_like(self.src, dtype=np.float32)
|
||||
|
||||
for i in np.arange(0, self.dst_h, self.grid_size):
|
||||
for j in np.arange(0, self.dst_w, self.grid_size):
|
||||
ni = i + self.grid_size
|
||||
nj = j + self.grid_size
|
||||
w = h = self.grid_size
|
||||
if ni >= self.dst_h:
|
||||
ni = self.dst_h - 1
|
||||
h = ni - i + 1
|
||||
if nj >= self.dst_w:
|
||||
nj = self.dst_w - 1
|
||||
w = nj - j + 1
|
||||
|
||||
di = np.reshape(np.arange(h), (-1, 1))
|
||||
dj = np.reshape(np.arange(w), (1, -1))
|
||||
delta_x = self.__bilinear_interp(
|
||||
di / h, dj / w, self.rdx[i, j], self.rdx[i, nj],
|
||||
self.rdx[ni, j], self.rdx[ni, nj])
|
||||
delta_y = self.__bilinear_interp(
|
||||
di / h, dj / w, self.rdy[i, j], self.rdy[i, nj],
|
||||
self.rdy[ni, j], self.rdy[ni, nj])
|
||||
nx = j + dj + delta_x * self.trans_ratio
|
||||
ny = i + di + delta_y * self.trans_ratio
|
||||
nx = np.clip(nx, 0, src_w - 1)
|
||||
ny = np.clip(ny, 0, src_h - 1)
|
||||
nxi = np.array(np.floor(nx), dtype=np.int32)
|
||||
nyi = np.array(np.floor(ny), dtype=np.int32)
|
||||
nxi1 = np.array(np.ceil(nx), dtype=np.int32)
|
||||
nyi1 = np.array(np.ceil(ny), dtype=np.int32)
|
||||
|
||||
if len(self.src.shape) == 3:
|
||||
x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3))
|
||||
y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3))
|
||||
else:
|
||||
x = ny - nyi
|
||||
y = nx - nxi
|
||||
dst[i:i + h, j:j + w] = self.__bilinear_interp(
|
||||
x, y, self.src[nyi, nxi], self.src[nyi, nxi1],
|
||||
self.src[nyi1, nxi], self.src[nyi1, nxi1])
|
||||
|
||||
dst = np.clip(dst, 0, 255)
|
||||
dst = np.array(dst, dtype=np.uint8)
|
||||
|
||||
return dst
|
||||
@@ -1,19 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
|
||||
|
||||
__all__ = [
|
||||
'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation'
|
||||
]
|
||||
@@ -1,17 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
|
||||
from .vqa_token_pad import VQATokenPad
|
||||
from .vqa_token_relation import VQAReTokenRelation
|
||||
@@ -1,122 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class VQASerTokenChunk(object):
|
||||
def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
|
||||
self.max_seq_len = max_seq_len
|
||||
self.infer_mode = infer_mode
|
||||
|
||||
def __call__(self, data):
|
||||
encoded_inputs_all = []
|
||||
seq_len = len(data['input_ids'])
|
||||
for index in range(0, seq_len, self.max_seq_len):
|
||||
chunk_beg = index
|
||||
chunk_end = min(index + self.max_seq_len, seq_len)
|
||||
encoded_inputs_example = {}
|
||||
for key in data:
|
||||
if key in [
|
||||
'label', 'input_ids', 'labels', 'token_type_ids',
|
||||
'bbox', 'attention_mask'
|
||||
]:
|
||||
if self.infer_mode and key == 'labels':
|
||||
encoded_inputs_example[key] = data[key]
|
||||
else:
|
||||
encoded_inputs_example[key] = data[key][chunk_beg:
|
||||
chunk_end]
|
||||
else:
|
||||
encoded_inputs_example[key] = data[key]
|
||||
|
||||
encoded_inputs_all.append(encoded_inputs_example)
|
||||
if len(encoded_inputs_all) == 0:
|
||||
return None
|
||||
return encoded_inputs_all[0]
|
||||
|
||||
|
||||
class VQAReTokenChunk(object):
|
||||
def __init__(self,
|
||||
max_seq_len=512,
|
||||
entities_labels=None,
|
||||
infer_mode=False,
|
||||
**kwargs):
|
||||
self.max_seq_len = max_seq_len
|
||||
self.entities_labels = {
|
||||
'HEADER': 0,
|
||||
'QUESTION': 1,
|
||||
'ANSWER': 2
|
||||
} if entities_labels is None else entities_labels
|
||||
self.infer_mode = infer_mode
|
||||
|
||||
def __call__(self, data):
|
||||
# prepare data
|
||||
entities = data.pop('entities')
|
||||
relations = data.pop('relations')
|
||||
encoded_inputs_all = []
|
||||
for index in range(0, len(data["input_ids"]), self.max_seq_len):
|
||||
item = {}
|
||||
for key in data:
|
||||
if key in [
|
||||
'label', 'input_ids', 'labels', 'token_type_ids',
|
||||
'bbox', 'attention_mask'
|
||||
]:
|
||||
if self.infer_mode and key == 'labels':
|
||||
item[key] = data[key]
|
||||
else:
|
||||
item[key] = data[key][index:index + self.max_seq_len]
|
||||
else:
|
||||
item[key] = data[key]
|
||||
# select entity in current chunk
|
||||
entities_in_this_span = []
|
||||
global_to_local_map = {} #
|
||||
for entity_id, entity in enumerate(entities):
|
||||
if (index <= entity["start"] < index + self.max_seq_len and
|
||||
index <= entity["end"] < index + self.max_seq_len):
|
||||
entity["start"] = entity["start"] - index
|
||||
entity["end"] = entity["end"] - index
|
||||
global_to_local_map[entity_id] = len(entities_in_this_span)
|
||||
entities_in_this_span.append(entity)
|
||||
|
||||
# select relations in current chunk
|
||||
relations_in_this_span = []
|
||||
for relation in relations:
|
||||
if (index <= relation["start_index"] < index + self.max_seq_len
|
||||
and index <= relation["end_index"] <
|
||||
index + self.max_seq_len):
|
||||
relations_in_this_span.append({
|
||||
"head": global_to_local_map[relation["head"]],
|
||||
"tail": global_to_local_map[relation["tail"]],
|
||||
"start_index": relation["start_index"] - index,
|
||||
"end_index": relation["end_index"] - index,
|
||||
})
|
||||
item.update({
|
||||
"entities": self.reformat(entities_in_this_span),
|
||||
"relations": self.reformat(relations_in_this_span),
|
||||
})
|
||||
if len(item['entities']) > 0:
|
||||
item['entities']['label'] = [
|
||||
self.entities_labels[x] for x in item['entities']['label']
|
||||
]
|
||||
encoded_inputs_all.append(item)
|
||||
if len(encoded_inputs_all) == 0:
|
||||
return None
|
||||
return encoded_inputs_all[0]
|
||||
|
||||
def reformat(self, data):
|
||||
new_data = defaultdict(list)
|
||||
for item in data:
|
||||
for k, v in item.items():
|
||||
new_data[k].append(v)
|
||||
return new_data
|
||||
@@ -1,104 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import paddle
|
||||
import numpy as np
|
||||
|
||||
|
||||
class VQATokenPad(object):
|
||||
def __init__(self,
|
||||
max_seq_len=512,
|
||||
pad_to_max_seq_len=True,
|
||||
return_attention_mask=True,
|
||||
return_token_type_ids=True,
|
||||
truncation_strategy="longest_first",
|
||||
return_overflowing_tokens=False,
|
||||
return_special_tokens_mask=False,
|
||||
infer_mode=False,
|
||||
**kwargs):
|
||||
self.max_seq_len = max_seq_len
|
||||
self.pad_to_max_seq_len = max_seq_len
|
||||
self.return_attention_mask = return_attention_mask
|
||||
self.return_token_type_ids = return_token_type_ids
|
||||
self.truncation_strategy = truncation_strategy
|
||||
self.return_overflowing_tokens = return_overflowing_tokens
|
||||
self.return_special_tokens_mask = return_special_tokens_mask
|
||||
self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
|
||||
self.infer_mode = infer_mode
|
||||
|
||||
def __call__(self, data):
|
||||
needs_to_be_padded = self.pad_to_max_seq_len and len(data[
|
||||
"input_ids"]) < self.max_seq_len
|
||||
|
||||
if needs_to_be_padded:
|
||||
if 'tokenizer_params' in data:
|
||||
tokenizer_params = data.pop('tokenizer_params')
|
||||
else:
|
||||
tokenizer_params = dict(
|
||||
padding_side='right', pad_token_type_id=0, pad_token_id=1)
|
||||
|
||||
difference = self.max_seq_len - len(data["input_ids"])
|
||||
if tokenizer_params['padding_side'] == 'right':
|
||||
if self.return_attention_mask:
|
||||
data["attention_mask"] = [1] * len(data[
|
||||
"input_ids"]) + [0] * difference
|
||||
if self.return_token_type_ids:
|
||||
data["token_type_ids"] = (
|
||||
data["token_type_ids"] +
|
||||
[tokenizer_params['pad_token_type_id']] * difference)
|
||||
if self.return_special_tokens_mask:
|
||||
data["special_tokens_mask"] = data[
|
||||
"special_tokens_mask"] + [1] * difference
|
||||
data["input_ids"] = data["input_ids"] + [
|
||||
tokenizer_params['pad_token_id']
|
||||
] * difference
|
||||
if not self.infer_mode:
|
||||
data["labels"] = data[
|
||||
"labels"] + [self.pad_token_label_id] * difference
|
||||
data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
|
||||
elif tokenizer_params['padding_side'] == 'left':
|
||||
if self.return_attention_mask:
|
||||
data["attention_mask"] = [0] * difference + [
|
||||
1
|
||||
] * len(data["input_ids"])
|
||||
if self.return_token_type_ids:
|
||||
data["token_type_ids"] = (
|
||||
[tokenizer_params['pad_token_type_id']] * difference +
|
||||
data["token_type_ids"])
|
||||
if self.return_special_tokens_mask:
|
||||
data["special_tokens_mask"] = [
|
||||
1
|
||||
] * difference + data["special_tokens_mask"]
|
||||
data["input_ids"] = [tokenizer_params['pad_token_id']
|
||||
] * difference + data["input_ids"]
|
||||
if not self.infer_mode:
|
||||
data["labels"] = [self.pad_token_label_id
|
||||
] * difference + data["labels"]
|
||||
data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
|
||||
else:
|
||||
if self.return_attention_mask:
|
||||
data["attention_mask"] = [1] * len(data["input_ids"])
|
||||
|
||||
for key in data:
|
||||
if key in [
|
||||
'input_ids', 'labels', 'token_type_ids', 'bbox',
|
||||
'attention_mask'
|
||||
]:
|
||||
if self.infer_mode:
|
||||
if key != 'labels':
|
||||
length = min(len(data[key]), self.max_seq_len)
|
||||
data[key] = data[key][:length]
|
||||
else:
|
||||
continue
|
||||
data[key] = np.array(data[key], dtype='int64')
|
||||
return data
|
||||
@@ -1,67 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class VQAReTokenRelation(object):
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, data):
|
||||
"""
|
||||
build relations
|
||||
"""
|
||||
entities = data['entities']
|
||||
relations = data['relations']
|
||||
id2label = data.pop('id2label')
|
||||
empty_entity = data.pop('empty_entity')
|
||||
entity_id_to_index_map = data.pop('entity_id_to_index_map')
|
||||
|
||||
relations = list(set(relations))
|
||||
relations = [
|
||||
rel for rel in relations
|
||||
if rel[0] not in empty_entity and rel[1] not in empty_entity
|
||||
]
|
||||
kv_relations = []
|
||||
for rel in relations:
|
||||
pair = [id2label[rel[0]], id2label[rel[1]]]
|
||||
if pair == ["question", "answer"]:
|
||||
kv_relations.append({
|
||||
"head": entity_id_to_index_map[rel[0]],
|
||||
"tail": entity_id_to_index_map[rel[1]]
|
||||
})
|
||||
elif pair == ["answer", "question"]:
|
||||
kv_relations.append({
|
||||
"head": entity_id_to_index_map[rel[1]],
|
||||
"tail": entity_id_to_index_map[rel[0]]
|
||||
})
|
||||
else:
|
||||
continue
|
||||
relations = sorted(
|
||||
[{
|
||||
"head": rel["head"],
|
||||
"tail": rel["tail"],
|
||||
"start_index": self.get_relation_span(rel, entities)[0],
|
||||
"end_index": self.get_relation_span(rel, entities)[1],
|
||||
} for rel in kv_relations],
|
||||
key=lambda x: x["head"], )
|
||||
|
||||
data['relations'] = relations
|
||||
return data
|
||||
|
||||
def get_relation_span(self, rel, entities):
|
||||
bound = []
|
||||
for entity_index in [rel["head"], rel["tail"]]:
|
||||
bound.append(entities[entity_index]["start"])
|
||||
bound.append(entities[entity_index]["end"])
|
||||
return min(bound), max(bound)
|
||||
@@ -1,118 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import os
|
||||
from paddle.io import Dataset
|
||||
import lmdb
|
||||
import cv2
|
||||
|
||||
from .imaug import transform, create_operators
|
||||
|
||||
|
||||
class LMDBDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
super(LMDBDataSet, self).__init__()
|
||||
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
batch_size = loader_config['batch_size_per_card']
|
||||
data_dir = dataset_config['data_dir']
|
||||
self.do_shuffle = loader_config['shuffle']
|
||||
|
||||
self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
|
||||
logger.info("Initialize indexs of datasets:%s" % data_dir)
|
||||
self.data_idx_order_list = self.dataset_traversal()
|
||||
if self.do_shuffle:
|
||||
np.random.shuffle(self.data_idx_order_list)
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
|
||||
ratio_list = dataset_config.get("ratio_list", [1.0])
|
||||
self.need_reset = True in [x < 1 for x in ratio_list]
|
||||
|
||||
def load_hierarchical_lmdb_dataset(self, data_dir):
|
||||
lmdb_sets = {}
|
||||
dataset_idx = 0
|
||||
for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
|
||||
if not dirnames:
|
||||
env = lmdb.open(
|
||||
dirpath,
|
||||
max_readers=32,
|
||||
readonly=True,
|
||||
lock=False,
|
||||
readahead=False,
|
||||
meminit=False)
|
||||
txn = env.begin(write=False)
|
||||
num_samples = int(txn.get('num-samples'.encode()))
|
||||
lmdb_sets[dataset_idx] = {"dirpath":dirpath, "env":env, \
|
||||
"txn":txn, "num_samples":num_samples}
|
||||
dataset_idx += 1
|
||||
return lmdb_sets
|
||||
|
||||
def dataset_traversal(self):
|
||||
lmdb_num = len(self.lmdb_sets)
|
||||
total_sample_num = 0
|
||||
for lno in range(lmdb_num):
|
||||
total_sample_num += self.lmdb_sets[lno]['num_samples']
|
||||
data_idx_order_list = np.zeros((total_sample_num, 2))
|
||||
beg_idx = 0
|
||||
for lno in range(lmdb_num):
|
||||
tmp_sample_num = self.lmdb_sets[lno]['num_samples']
|
||||
end_idx = beg_idx + tmp_sample_num
|
||||
data_idx_order_list[beg_idx:end_idx, 0] = lno
|
||||
data_idx_order_list[beg_idx:end_idx, 1] \
|
||||
= list(range(tmp_sample_num))
|
||||
data_idx_order_list[beg_idx:end_idx, 1] += 1
|
||||
beg_idx = beg_idx + tmp_sample_num
|
||||
return data_idx_order_list
|
||||
|
||||
def get_img_data(self, value):
|
||||
"""get_img_data"""
|
||||
if not value:
|
||||
return None
|
||||
imgdata = np.frombuffer(value, dtype='uint8')
|
||||
if imgdata is None:
|
||||
return None
|
||||
imgori = cv2.imdecode(imgdata, 1)
|
||||
if imgori is None:
|
||||
return None
|
||||
return imgori
|
||||
|
||||
def get_lmdb_sample_info(self, txn, index):
|
||||
label_key = 'label-%09d'.encode() % index
|
||||
label = txn.get(label_key)
|
||||
if label is None:
|
||||
return None
|
||||
label = label.decode('utf-8')
|
||||
img_key = 'image-%09d'.encode() % index
|
||||
imgbuf = txn.get(img_key)
|
||||
return imgbuf, label
|
||||
|
||||
def __getitem__(self, idx):
|
||||
lmdb_idx, file_idx = self.data_idx_order_list[idx]
|
||||
lmdb_idx = int(lmdb_idx)
|
||||
file_idx = int(file_idx)
|
||||
sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
|
||||
file_idx)
|
||||
if sample_info is None:
|
||||
return self.__getitem__(np.random.randint(self.__len__()))
|
||||
img, label = sample_info
|
||||
data = {'image': img, 'label': label}
|
||||
outs = transform(data, self.ops)
|
||||
if outs is None:
|
||||
return self.__getitem__(np.random.randint(self.__len__()))
|
||||
return outs
|
||||
|
||||
def __len__(self):
|
||||
return self.data_idx_order_list.shape[0]
|
||||
@@ -1,106 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import os
|
||||
from paddle.io import Dataset
|
||||
from .imaug import transform, create_operators
|
||||
import random
|
||||
|
||||
|
||||
class PGDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
super(PGDataSet, self).__init__()
|
||||
|
||||
self.logger = logger
|
||||
self.seed = seed
|
||||
self.mode = mode
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
|
||||
self.delimiter = dataset_config.get('delimiter', '\t')
|
||||
label_file_list = dataset_config.pop('label_file_list')
|
||||
data_source_num = len(label_file_list)
|
||||
ratio_list = dataset_config.get("ratio_list", [1.0])
|
||||
if isinstance(ratio_list, (float, int)):
|
||||
ratio_list = [float(ratio_list)] * int(data_source_num)
|
||||
assert len(
|
||||
ratio_list
|
||||
) == data_source_num, "The length of ratio_list should be the same as the file_list."
|
||||
self.data_dir = dataset_config['data_dir']
|
||||
self.do_shuffle = loader_config['shuffle']
|
||||
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if mode.lower() == "train":
|
||||
self.shuffle_data_random()
|
||||
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
|
||||
self.need_reset = True in [x < 1 for x in ratio_list]
|
||||
|
||||
def shuffle_data_random(self):
|
||||
if self.do_shuffle:
|
||||
random.seed(self.seed)
|
||||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
def get_image_info_list(self, file_list, ratio_list):
|
||||
if isinstance(file_list, str):
|
||||
file_list = [file_list]
|
||||
data_lines = []
|
||||
for idx, file in enumerate(file_list):
|
||||
with open(file, "rb") as f:
|
||||
lines = f.readlines()
|
||||
if self.mode == "train" or ratio_list[idx] < 1.0:
|
||||
random.seed(self.seed)
|
||||
lines = random.sample(lines,
|
||||
round(len(lines) * ratio_list[idx]))
|
||||
data_lines.extend(lines)
|
||||
return data_lines
|
||||
|
||||
def __getitem__(self, idx):
|
||||
file_idx = self.data_idx_order_list[idx]
|
||||
data_line = self.data_lines[file_idx]
|
||||
img_id = 0
|
||||
try:
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
if self.mode.lower() == 'eval':
|
||||
try:
|
||||
img_id = int(data_line.split(".")[0][7:])
|
||||
except:
|
||||
img_id = 0
|
||||
data = {'img_path': img_path, 'label': label, 'img_id': img_id}
|
||||
if not os.path.exists(img_path):
|
||||
raise Exception("{} does not exist!".format(img_path))
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
outs = transform(data, self.ops)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"When parsing line {}, error happened with msg: {}".format(
|
||||
self.data_idx_order_list[idx], e))
|
||||
outs = None
|
||||
if outs is None:
|
||||
return self.__getitem__(np.random.randint(self.__len__()))
|
||||
return outs
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_idx_order_list)
|
||||
@@ -1,114 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
from paddle.io import Dataset
|
||||
import json
|
||||
|
||||
from .imaug import transform, create_operators
|
||||
|
||||
|
||||
class PubTabDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
super(PubTabDataSet, self).__init__()
|
||||
self.logger = logger
|
||||
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
|
||||
label_file_path = dataset_config.pop('label_file_path')
|
||||
|
||||
self.data_dir = dataset_config['data_dir']
|
||||
self.do_shuffle = loader_config['shuffle']
|
||||
self.do_hard_select = False
|
||||
if 'hard_select' in loader_config:
|
||||
self.do_hard_select = loader_config['hard_select']
|
||||
self.hard_prob = loader_config['hard_prob']
|
||||
if self.do_hard_select:
|
||||
self.img_select_prob = self.load_hard_select_prob()
|
||||
self.table_select_type = None
|
||||
if 'table_select_type' in loader_config:
|
||||
self.table_select_type = loader_config['table_select_type']
|
||||
self.table_select_prob = loader_config['table_select_prob']
|
||||
|
||||
self.seed = seed
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_path)
|
||||
with open(label_file_path, "rb") as f:
|
||||
self.data_lines = f.readlines()
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if mode.lower() == "train":
|
||||
self.shuffle_data_random()
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
|
||||
ratio_list = dataset_config.get("ratio_list", [1.0])
|
||||
self.need_reset = True in [x < 1 for x in ratio_list]
|
||||
|
||||
def shuffle_data_random(self):
|
||||
if self.do_shuffle:
|
||||
random.seed(self.seed)
|
||||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
def __getitem__(self, idx):
|
||||
try:
|
||||
data_line = self.data_lines[idx]
|
||||
data_line = data_line.decode('utf-8').strip("\n")
|
||||
info = json.loads(data_line)
|
||||
file_name = info['filename']
|
||||
select_flag = True
|
||||
if self.do_hard_select:
|
||||
prob = self.img_select_prob[file_name]
|
||||
if prob < random.uniform(0, 1):
|
||||
select_flag = False
|
||||
|
||||
if self.table_select_type:
|
||||
structure = info['html']['structure']['tokens'].copy()
|
||||
structure_str = ''.join(structure)
|
||||
table_type = "simple"
|
||||
if 'colspan' in structure_str or 'rowspan' in structure_str:
|
||||
table_type = "complex"
|
||||
if table_type == "complex":
|
||||
if self.table_select_prob < random.uniform(0, 1):
|
||||
select_flag = False
|
||||
|
||||
if select_flag:
|
||||
cells = info['html']['cells'].copy()
|
||||
structure = info['html']['structure'].copy()
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
data = {
|
||||
'img_path': img_path,
|
||||
'cells': cells,
|
||||
'structure': structure
|
||||
}
|
||||
if not os.path.exists(img_path):
|
||||
raise Exception("{} does not exist!".format(img_path))
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
outs = transform(data, self.ops)
|
||||
else:
|
||||
outs = None
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"When parsing line {}, error happened with msg: {}".format(
|
||||
data_line, e))
|
||||
outs = None
|
||||
if outs is None:
|
||||
return self.__getitem__(np.random.randint(self.__len__()))
|
||||
return outs
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_idx_order_list)
|
||||
@@ -1,151 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import traceback
|
||||
from paddle.io import Dataset
|
||||
from .imaug import transform, create_operators
|
||||
|
||||
|
||||
class SimpleDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
super(SimpleDataSet, self).__init__()
|
||||
self.logger = logger
|
||||
self.mode = mode.lower()
|
||||
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
|
||||
self.delimiter = dataset_config.get('delimiter', '\t')
|
||||
label_file_list = dataset_config.pop('label_file_list')
|
||||
data_source_num = len(label_file_list)
|
||||
ratio_list = dataset_config.get("ratio_list", [1.0])
|
||||
if isinstance(ratio_list, (float, int)):
|
||||
ratio_list = [float(ratio_list)] * int(data_source_num)
|
||||
|
||||
assert len(
|
||||
ratio_list
|
||||
) == data_source_num, "The length of ratio_list should be the same as the file_list."
|
||||
self.data_dir = dataset_config['data_dir']
|
||||
self.do_shuffle = loader_config['shuffle']
|
||||
self.seed = seed
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if self.mode == "train" and self.do_shuffle:
|
||||
self.shuffle_data_random()
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
|
||||
2)
|
||||
self.need_reset = True in [x < 1 for x in ratio_list]
|
||||
|
||||
def get_image_info_list(self, file_list, ratio_list):
|
||||
if isinstance(file_list, str):
|
||||
file_list = [file_list]
|
||||
data_lines = []
|
||||
for idx, file in enumerate(file_list):
|
||||
with open(file, "rb") as f:
|
||||
lines = f.readlines()
|
||||
if self.mode == "train" or ratio_list[idx] < 1.0:
|
||||
random.seed(self.seed)
|
||||
lines = random.sample(lines,
|
||||
round(len(lines) * ratio_list[idx]))
|
||||
data_lines.extend(lines)
|
||||
return data_lines
|
||||
|
||||
def shuffle_data_random(self):
|
||||
random.seed(self.seed)
|
||||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
def _try_parse_filename_list(self, file_name):
|
||||
# multiple images -> one gt label
|
||||
if len(file_name) > 0 and file_name[0] == "[":
|
||||
try:
|
||||
info = json.loads(file_name)
|
||||
file_name = random.choice(info)
|
||||
except:
|
||||
pass
|
||||
return file_name
|
||||
|
||||
def get_ext_data(self):
|
||||
ext_data_num = 0
|
||||
for op in self.ops:
|
||||
if hasattr(op, 'ext_data_num'):
|
||||
ext_data_num = getattr(op, 'ext_data_num')
|
||||
break
|
||||
load_data_ops = self.ops[:self.ext_op_transform_idx]
|
||||
ext_data = []
|
||||
|
||||
while len(ext_data) < ext_data_num:
|
||||
file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
|
||||
))]
|
||||
data_line = self.data_lines[file_idx]
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
file_name = self._try_parse_filename_list(file_name)
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
data = {'img_path': img_path, 'label': label}
|
||||
if not os.path.exists(img_path):
|
||||
continue
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
data = transform(data, load_data_ops)
|
||||
|
||||
if data is None:
|
||||
continue
|
||||
if 'polys' in data.keys():
|
||||
if data['polys'].shape[1] != 4:
|
||||
continue
|
||||
ext_data.append(data)
|
||||
return ext_data
|
||||
|
||||
def __getitem__(self, idx):
|
||||
file_idx = self.data_idx_order_list[idx]
|
||||
data_line = self.data_lines[file_idx]
|
||||
try:
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
file_name = self._try_parse_filename_list(file_name)
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
data = {'img_path': img_path, 'label': label}
|
||||
if not os.path.exists(img_path):
|
||||
raise Exception("{} does not exist!".format(img_path))
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
data['ext_data'] = self.get_ext_data()
|
||||
outs = transform(data, self.ops)
|
||||
except:
|
||||
self.logger.error(
|
||||
"When parsing line {}, error happened with msg: {}".format(
|
||||
data_line, traceback.format_exc()))
|
||||
outs = None
|
||||
if outs is None:
|
||||
# during evaluation, we should fix the idx to get same results for many times of evaluation.
|
||||
rnd_idx = np.random.randint(self.__len__(
|
||||
)) if self.mode == "train" else (idx + 1) % self.__len__()
|
||||
return self.__getitem__(rnd_idx)
|
||||
return outs
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_idx_order_list)
|
||||
@@ -1,71 +0,0 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
# basic_loss
|
||||
from .basic_loss import LossFromOutput
|
||||
|
||||
# det loss
|
||||
from .det_db_loss import DBLoss
|
||||
from .det_east_loss import EASTLoss
|
||||
from .det_sast_loss import SASTLoss
|
||||
from .det_pse_loss import PSELoss
|
||||
from .det_fce_loss import FCELoss
|
||||
|
||||
# rec loss
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .rec_att_loss import AttentionLoss
|
||||
from .rec_srn_loss import SRNLoss
|
||||
from .rec_nrtr_loss import NRTRLoss
|
||||
from .rec_sar_loss import SARLoss
|
||||
from .rec_aster_loss import AsterLoss
|
||||
from .rec_pren_loss import PRENLoss
|
||||
from .rec_multi_loss import MultiLoss
|
||||
|
||||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
||||
# e2e loss
|
||||
from .e2e_pg_loss import PGLoss
|
||||
from .kie_sdmgr_loss import SDMGRLoss
|
||||
|
||||
# basic loss function
|
||||
from .basic_loss import DistanceLoss
|
||||
|
||||
# combined loss function
|
||||
from .combined_loss import CombinedLoss
|
||||
|
||||
# table loss
|
||||
from .table_att_loss import TableAttentionLoss
|
||||
|
||||
# vqa token loss
|
||||
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
|
||||
|
||||
|
||||
def build_loss(config):
|
||||
support_dict = [
|
||||
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
|
||||
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
|
||||
'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
|
||||
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('loss only support {}'.format(
|
||||
support_dict))
|
||||
module_class = eval(module_name)(**config)
|
||||
return module_class
|
||||
@@ -1,52 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This code is refer from: https://github.com/viig99/LS-ACELoss
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
|
||||
class ACELoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(
|
||||
weight=None,
|
||||
ignore_index=0,
|
||||
reduction='none',
|
||||
soft_label=True,
|
||||
axis=-1)
|
||||
|
||||
def __call__(self, predicts, batch):
|
||||
if isinstance(predicts, (list, tuple)):
|
||||
predicts = predicts[-1]
|
||||
|
||||
B, N = predicts.shape[:2]
|
||||
div = paddle.to_tensor([N]).astype('float32')
|
||||
|
||||
predicts = nn.functional.softmax(predicts, axis=-1)
|
||||
aggregation_preds = paddle.sum(predicts, axis=1)
|
||||
aggregation_preds = paddle.divide(aggregation_preds, div)
|
||||
|
||||
length = batch[2].astype("float32")
|
||||
batch = batch[3].astype("float32")
|
||||
batch[:, 0] = paddle.subtract(div, length)
|
||||
batch = paddle.divide(batch, div)
|
||||
|
||||
loss = self.loss_func(aggregation_preds, batch)
|
||||
return {"loss_ace": loss}
|
||||
@@ -1,155 +0,0 @@
|
||||
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from paddle.nn import L1Loss
|
||||
from paddle.nn import MSELoss as L2Loss
|
||||
from paddle.nn import SmoothL1Loss
|
||||
|
||||
|
||||
class CELoss(nn.Layer):
|
||||
def __init__(self, epsilon=None):
|
||||
super().__init__()
|
||||
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
|
||||
epsilon = None
|
||||
self.epsilon = epsilon
|
||||
|
||||
def _labelsmoothing(self, target, class_num):
|
||||
if target.shape[-1] != class_num:
|
||||
one_hot_target = F.one_hot(target, class_num)
|
||||
else:
|
||||
one_hot_target = target
|
||||
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
|
||||
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
|
||||
return soft_target
|
||||
|
||||
def forward(self, x, label):
|
||||
loss_dict = {}
|
||||
if self.epsilon is not None:
|
||||
class_num = x.shape[-1]
|
||||
label = self._labelsmoothing(label, class_num)
|
||||
x = -F.log_softmax(x, axis=-1)
|
||||
loss = paddle.sum(x * label, axis=-1)
|
||||
else:
|
||||
if label.shape[-1] == x.shape[-1]:
|
||||
label = F.softmax(label, axis=-1)
|
||||
soft_label = True
|
||||
else:
|
||||
soft_label = False
|
||||
loss = F.cross_entropy(x, label=label, soft_label=soft_label)
|
||||
return loss
|
||||
|
||||
|
||||
class KLJSLoss(object):
|
||||
def __init__(self, mode='kl'):
|
||||
assert mode in ['kl', 'js', 'KL', 'JS'
|
||||
], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
||||
self.mode = mode
|
||||
|
||||
def __call__(self, p1, p2, reduction="mean"):
|
||||
|
||||
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
|
||||
|
||||
if self.mode.lower() == "js":
|
||||
loss += paddle.multiply(
|
||||
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
|
||||
loss *= 0.5
|
||||
if reduction == "mean":
|
||||
loss = paddle.mean(loss, axis=[1, 2])
|
||||
elif reduction == "none" or reduction is None:
|
||||
return loss
|
||||
else:
|
||||
loss = paddle.sum(loss, axis=[1, 2])
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class DMLLoss(nn.Layer):
|
||||
"""
|
||||
DMLLoss
|
||||
"""
|
||||
|
||||
def __init__(self, act=None, use_log=False):
|
||||
super().__init__()
|
||||
if act is not None:
|
||||
assert act in ["softmax", "sigmoid"]
|
||||
if act == "softmax":
|
||||
self.act = nn.Softmax(axis=-1)
|
||||
elif act == "sigmoid":
|
||||
self.act = nn.Sigmoid()
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
self.use_log = use_log
|
||||
self.jskl_loss = KLJSLoss(mode="js")
|
||||
|
||||
def _kldiv(self, x, target):
|
||||
eps = 1.0e-10
|
||||
loss = target * (paddle.log(target + eps) - x)
|
||||
# batch mean loss
|
||||
loss = paddle.sum(loss) / loss.shape[0]
|
||||
return loss
|
||||
|
||||
def forward(self, out1, out2):
|
||||
if self.act is not None:
|
||||
out1 = self.act(out1) + 1e-10
|
||||
out2 = self.act(out2) + 1e-10
|
||||
if self.use_log:
|
||||
# for recognition distillation, log is needed for feature map
|
||||
log_out1 = paddle.log(out1)
|
||||
log_out2 = paddle.log(out2)
|
||||
loss = (
|
||||
self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
|
||||
else:
|
||||
# for detection distillation log is not needed
|
||||
loss = self.jskl_loss(out1, out2)
|
||||
return loss
|
||||
|
||||
|
||||
class DistanceLoss(nn.Layer):
|
||||
"""
|
||||
DistanceLoss:
|
||||
mode: loss mode
|
||||
"""
|
||||
|
||||
def __init__(self, mode="l2", **kargs):
|
||||
super().__init__()
|
||||
assert mode in ["l1", "l2", "smooth_l1"]
|
||||
if mode == "l1":
|
||||
self.loss_func = nn.L1Loss(**kargs)
|
||||
elif mode == "l2":
|
||||
self.loss_func = nn.MSELoss(**kargs)
|
||||
elif mode == "smooth_l1":
|
||||
self.loss_func = nn.SmoothL1Loss(**kargs)
|
||||
|
||||
def forward(self, x, y):
|
||||
return self.loss_func(x, y)
|
||||
|
||||
|
||||
class LossFromOutput(nn.Layer):
|
||||
def __init__(self, key='loss', reduction='none'):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss = predicts[self.key]
|
||||
if self.reduction == 'mean':
|
||||
loss = paddle.mean(loss)
|
||||
elif self.reduction == 'sum':
|
||||
loss = paddle.sum(loss)
|
||||
return {'loss': loss}
|
||||
@@ -1,88 +0,0 @@
|
||||
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
# This code is refer from: https://github.com/KaiyangZhou/pytorch-center-loss
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class CenterLoss(nn.Layer):
|
||||
"""
|
||||
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.feat_dim = feat_dim
|
||||
self.centers = paddle.randn(
|
||||
shape=[self.num_classes, self.feat_dim]).astype("float64")
|
||||
|
||||
if center_file_path is not None:
|
||||
assert os.path.exists(
|
||||
center_file_path
|
||||
), f"center path({center_file_path}) must exist when it is not None."
|
||||
with open(center_file_path, 'rb') as f:
|
||||
char_dict = pickle.load(f)
|
||||
for key in char_dict.keys():
|
||||
self.centers[key] = paddle.to_tensor(char_dict[key])
|
||||
|
||||
def __call__(self, predicts, batch):
|
||||
assert isinstance(predicts, (list, tuple))
|
||||
features, predicts = predicts
|
||||
|
||||
feats_reshape = paddle.reshape(
|
||||
features, [-1, features.shape[-1]]).astype("float64")
|
||||
label = paddle.argmax(predicts, axis=2)
|
||||
label = paddle.reshape(label, [label.shape[0] * label.shape[1]])
|
||||
|
||||
batch_size = feats_reshape.shape[0]
|
||||
|
||||
#calc l2 distance between feats and centers
|
||||
square_feat = paddle.sum(paddle.square(feats_reshape),
|
||||
axis=1,
|
||||
keepdim=True)
|
||||
square_feat = paddle.expand(square_feat, [batch_size, self.num_classes])
|
||||
|
||||
square_center = paddle.sum(paddle.square(self.centers),
|
||||
axis=1,
|
||||
keepdim=True)
|
||||
square_center = paddle.expand(
|
||||
square_center, [self.num_classes, batch_size]).astype("float64")
|
||||
square_center = paddle.transpose(square_center, [1, 0])
|
||||
|
||||
distmat = paddle.add(square_feat, square_center)
|
||||
feat_dot_center = paddle.matmul(feats_reshape,
|
||||
paddle.transpose(self.centers, [1, 0]))
|
||||
distmat = distmat - 2.0 * feat_dot_center
|
||||
|
||||
#generate the mask
|
||||
classes = paddle.arange(self.num_classes).astype("int64")
|
||||
label = paddle.expand(
|
||||
paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
|
||||
mask = paddle.equal(
|
||||
paddle.expand(classes, [batch_size, self.num_classes]),
|
||||
label).astype("float64")
|
||||
dist = paddle.multiply(distmat, mask)
|
||||
|
||||
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
|
||||
return {'loss_center': loss}
|
||||
@@ -1,30 +0,0 @@
|
||||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class ClsLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(ClsLoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(reduction='mean')
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
label = batch[1].astype("int64")
|
||||
loss = self.loss_func(input=predicts, label=label)
|
||||
return {'loss': loss}
|
||||
@@ -1,69 +0,0 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .center_loss import CenterLoss
|
||||
from .ace_loss import ACELoss
|
||||
from .rec_sar_loss import SARLoss
|
||||
|
||||
from .distillation_loss import DistillationCTCLoss
|
||||
from .distillation_loss import DistillationSARLoss
|
||||
from .distillation_loss import DistillationDMLLoss
|
||||
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
|
||||
|
||||
|
||||
class CombinedLoss(nn.Layer):
|
||||
"""
|
||||
CombinedLoss:
|
||||
a combionation of loss function
|
||||
"""
|
||||
|
||||
def __init__(self, loss_config_list=None):
|
||||
super().__init__()
|
||||
self.loss_func = []
|
||||
self.loss_weight = []
|
||||
assert isinstance(loss_config_list, list), (
|
||||
'operator config should be a list')
|
||||
for config in loss_config_list:
|
||||
assert isinstance(config,
|
||||
dict) and len(config) == 1, "yaml format error"
|
||||
name = list(config)[0]
|
||||
param = config[name]
|
||||
assert "weight" in param, "weight must be in param, but param just contains {}".format(
|
||||
param.keys())
|
||||
self.loss_weight.append(param.pop("weight"))
|
||||
self.loss_func.append(eval(name)(**param))
|
||||
|
||||
def forward(self, input, batch, **kargs):
|
||||
loss_dict = {}
|
||||
loss_all = 0.
|
||||
for idx, loss_func in enumerate(self.loss_func):
|
||||
loss = loss_func(input, batch, **kargs)
|
||||
if isinstance(loss, paddle.Tensor):
|
||||
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
||||
|
||||
weight = self.loss_weight[idx]
|
||||
|
||||
loss = {key: loss[key] * weight for key in loss}
|
||||
|
||||
if "loss" in loss:
|
||||
loss_all += loss["loss"]
|
||||
else:
|
||||
loss_all += paddle.add_n(list(loss.values()))
|
||||
loss_dict.update(loss)
|
||||
loss_dict["loss"] = loss_all
|
||||
return loss_dict
|
||||
@@ -1,153 +0,0 @@
|
||||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/basic_loss.py
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class BalanceLoss(nn.Layer):
|
||||
def __init__(self,
|
||||
balance_loss=True,
|
||||
main_loss_type='DiceLoss',
|
||||
negative_ratio=3,
|
||||
return_origin=False,
|
||||
eps=1e-6,
|
||||
**kwargs):
|
||||
"""
|
||||
The BalanceLoss for Differentiable Binarization text detection
|
||||
args:
|
||||
balance_loss (bool): whether balance loss or not, default is True
|
||||
main_loss_type (str): can only be one of ['CrossEntropy','DiceLoss',
|
||||
'Euclidean','BCELoss', 'MaskL1Loss'], default is 'DiceLoss'.
|
||||
negative_ratio (int|float): float, default is 3.
|
||||
return_origin (bool): whether return unbalanced loss or not, default is False.
|
||||
eps (float): default is 1e-6.
|
||||
"""
|
||||
super(BalanceLoss, self).__init__()
|
||||
self.balance_loss = balance_loss
|
||||
self.main_loss_type = main_loss_type
|
||||
self.negative_ratio = negative_ratio
|
||||
self.return_origin = return_origin
|
||||
self.eps = eps
|
||||
|
||||
if self.main_loss_type == "CrossEntropy":
|
||||
self.loss = nn.CrossEntropyLoss()
|
||||
elif self.main_loss_type == "Euclidean":
|
||||
self.loss = nn.MSELoss()
|
||||
elif self.main_loss_type == "DiceLoss":
|
||||
self.loss = DiceLoss(self.eps)
|
||||
elif self.main_loss_type == "BCELoss":
|
||||
self.loss = BCELoss(reduction='none')
|
||||
elif self.main_loss_type == "MaskL1Loss":
|
||||
self.loss = MaskL1Loss(self.eps)
|
||||
else:
|
||||
loss_type = [
|
||||
'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss', 'MaskL1Loss'
|
||||
]
|
||||
raise Exception(
|
||||
"main_loss_type in BalanceLoss() can only be one of {}".format(
|
||||
loss_type))
|
||||
|
||||
def forward(self, pred, gt, mask=None):
|
||||
"""
|
||||
The BalanceLoss for Differentiable Binarization text detection
|
||||
args:
|
||||
pred (variable): predicted feature maps.
|
||||
gt (variable): ground truth feature maps.
|
||||
mask (variable): masked maps.
|
||||
return: (variable) balanced loss
|
||||
"""
|
||||
positive = gt * mask
|
||||
negative = (1 - gt) * mask
|
||||
|
||||
positive_count = int(positive.sum())
|
||||
negative_count = int(
|
||||
min(negative.sum(), positive_count * self.negative_ratio))
|
||||
loss = self.loss(pred, gt, mask=mask)
|
||||
|
||||
if not self.balance_loss:
|
||||
return loss
|
||||
|
||||
positive_loss = positive * loss
|
||||
negative_loss = negative * loss
|
||||
negative_loss = paddle.reshape(negative_loss, shape=[-1])
|
||||
if negative_count > 0:
|
||||
sort_loss = negative_loss.sort(descending=True)
|
||||
negative_loss = sort_loss[:negative_count]
|
||||
# negative_loss, _ = paddle.topk(negative_loss, k=negative_count_int)
|
||||
balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
|
||||
positive_count + negative_count + self.eps)
|
||||
else:
|
||||
balance_loss = positive_loss.sum() / (positive_count + self.eps)
|
||||
if self.return_origin:
|
||||
return balance_loss, loss
|
||||
|
||||
return balance_loss
|
||||
|
||||
|
||||
class DiceLoss(nn.Layer):
|
||||
def __init__(self, eps=1e-6):
|
||||
super(DiceLoss, self).__init__()
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, pred, gt, mask, weights=None):
|
||||
"""
|
||||
DiceLoss function.
|
||||
"""
|
||||
|
||||
assert pred.shape == gt.shape
|
||||
assert pred.shape == mask.shape
|
||||
if weights is not None:
|
||||
assert weights.shape == mask.shape
|
||||
mask = weights * mask
|
||||
intersection = paddle.sum(pred * gt * mask)
|
||||
|
||||
union = paddle.sum(pred * mask) + paddle.sum(gt * mask) + self.eps
|
||||
loss = 1 - 2.0 * intersection / union
|
||||
assert loss <= 1
|
||||
return loss
|
||||
|
||||
|
||||
class MaskL1Loss(nn.Layer):
|
||||
def __init__(self, eps=1e-6):
|
||||
super(MaskL1Loss, self).__init__()
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, pred, gt, mask):
|
||||
"""
|
||||
Mask L1 Loss
|
||||
"""
|
||||
loss = (paddle.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
|
||||
loss = paddle.mean(loss)
|
||||
return loss
|
||||
|
||||
|
||||
class BCELoss(nn.Layer):
|
||||
def __init__(self, reduction='mean'):
|
||||
super(BCELoss, self).__init__()
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, input, label, mask=None, weight=None, name=None):
|
||||
loss = F.binary_cross_entropy(input, label, reduction=self.reduction)
|
||||
return loss
|
||||
@@ -1,76 +0,0 @@
|
||||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/DB_loss.py
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
|
||||
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
|
||||
|
||||
|
||||
class DBLoss(nn.Layer):
|
||||
"""
|
||||
Differentiable Binarization (DB) Loss Function
|
||||
args:
|
||||
param (dict): the super paramter for DB Loss
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
balance_loss=True,
|
||||
main_loss_type='DiceLoss',
|
||||
alpha=5,
|
||||
beta=10,
|
||||
ohem_ratio=3,
|
||||
eps=1e-6,
|
||||
**kwargs):
|
||||
super(DBLoss, self).__init__()
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.dice_loss = DiceLoss(eps=eps)
|
||||
self.l1_loss = MaskL1Loss(eps=eps)
|
||||
self.bce_loss = BalanceLoss(
|
||||
balance_loss=balance_loss,
|
||||
main_loss_type=main_loss_type,
|
||||
negative_ratio=ohem_ratio)
|
||||
|
||||
def forward(self, predicts, labels):
|
||||
predict_maps = predicts['maps']
|
||||
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[
|
||||
1:]
|
||||
shrink_maps = predict_maps[:, 0, :, :]
|
||||
threshold_maps = predict_maps[:, 1, :, :]
|
||||
binary_maps = predict_maps[:, 2, :, :]
|
||||
|
||||
loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,
|
||||
label_shrink_mask)
|
||||
loss_threshold_maps = self.l1_loss(threshold_maps, label_threshold_map,
|
||||
label_threshold_mask)
|
||||
loss_binary_maps = self.dice_loss(binary_maps, label_shrink_map,
|
||||
label_shrink_mask)
|
||||
loss_shrink_maps = self.alpha * loss_shrink_maps
|
||||
loss_threshold_maps = self.beta * loss_threshold_maps
|
||||
|
||||
loss_all = loss_shrink_maps + loss_threshold_maps \
|
||||
+ loss_binary_maps
|
||||
losses = {'loss': loss_all, \
|
||||
"loss_shrink_maps": loss_shrink_maps, \
|
||||
"loss_threshold_maps": loss_threshold_maps, \
|
||||
"loss_binary_maps": loss_binary_maps}
|
||||
return losses
|
||||
@@ -1,63 +0,0 @@
|
||||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from .det_basic_loss import DiceLoss
|
||||
|
||||
|
||||
class EASTLoss(nn.Layer):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
eps=1e-6,
|
||||
**kwargs):
|
||||
super(EASTLoss, self).__init__()
|
||||
self.dice_loss = DiceLoss(eps=eps)
|
||||
|
||||
def forward(self, predicts, labels):
|
||||
l_score, l_geo, l_mask = labels[1:]
|
||||
f_score = predicts['f_score']
|
||||
f_geo = predicts['f_geo']
|
||||
|
||||
dice_loss = self.dice_loss(f_score, l_score, l_mask)
|
||||
|
||||
#smoooth_l1_loss
|
||||
channels = 8
|
||||
l_geo_split = paddle.split(
|
||||
l_geo, num_or_sections=channels + 1, axis=1)
|
||||
f_geo_split = paddle.split(f_geo, num_or_sections=channels, axis=1)
|
||||
smooth_l1 = 0
|
||||
for i in range(0, channels):
|
||||
geo_diff = l_geo_split[i] - f_geo_split[i]
|
||||
abs_geo_diff = paddle.abs(geo_diff)
|
||||
smooth_l1_sign = paddle.less_than(abs_geo_diff, l_score)
|
||||
smooth_l1_sign = paddle.cast(smooth_l1_sign, dtype='float32')
|
||||
in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + \
|
||||
(abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign)
|
||||
out_loss = l_geo_split[-1] / channels * in_loss * l_score
|
||||
smooth_l1 += out_loss
|
||||
smooth_l1_loss = paddle.mean(smooth_l1 * l_score)
|
||||
|
||||
dice_loss = dice_loss * 0.01
|
||||
total_loss = dice_loss + smooth_l1_loss
|
||||
losses = {"loss":total_loss, \
|
||||
"dice_loss":dice_loss,\
|
||||
"smooth_l1_loss":smooth_l1_loss}
|
||||
return losses
|
||||
@@ -1,227 +0,0 @@
|
||||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/losses/fce_loss.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from paddle import nn
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from functools import partial
|
||||
|
||||
|
||||
def multi_apply(func, *args, **kwargs):
|
||||
pfunc = partial(func, **kwargs) if kwargs else func
|
||||
map_results = map(pfunc, *args)
|
||||
return tuple(map(list, zip(*map_results)))
|
||||
|
||||
|
||||
class FCELoss(nn.Layer):
|
||||
"""The class for implementing FCENet loss
|
||||
FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped
|
||||
Text Detection
|
||||
|
||||
[https://arxiv.org/abs/2104.10442]
|
||||
|
||||
Args:
|
||||
fourier_degree (int) : The maximum Fourier transform degree k.
|
||||
num_sample (int) : The sampling points number of regression
|
||||
loss. If it is too small, fcenet tends to be overfitting.
|
||||
ohem_ratio (float): the negative/positive ratio in OHEM.
|
||||
"""
|
||||
|
||||
def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
|
||||
super().__init__()
|
||||
self.fourier_degree = fourier_degree
|
||||
self.num_sample = num_sample
|
||||
self.ohem_ratio = ohem_ratio
|
||||
|
||||
def forward(self, preds, labels):
|
||||
assert isinstance(preds, dict)
|
||||
preds = preds['levels']
|
||||
|
||||
p3_maps, p4_maps, p5_maps = labels[1:]
|
||||
assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
|
||||
'fourier degree not equal in FCEhead and FCEtarget'
|
||||
|
||||
# to tensor
|
||||
gts = [p3_maps, p4_maps, p5_maps]
|
||||
for idx, maps in enumerate(gts):
|
||||
gts[idx] = paddle.to_tensor(np.stack(maps))
|
||||
|
||||
losses = multi_apply(self.forward_single, preds, gts)
|
||||
|
||||
loss_tr = paddle.to_tensor(0.).astype('float32')
|
||||
loss_tcl = paddle.to_tensor(0.).astype('float32')
|
||||
loss_reg_x = paddle.to_tensor(0.).astype('float32')
|
||||
loss_reg_y = paddle.to_tensor(0.).astype('float32')
|
||||
loss_all = paddle.to_tensor(0.).astype('float32')
|
||||
|
||||
for idx, loss in enumerate(losses):
|
||||
loss_all += sum(loss)
|
||||
if idx == 0:
|
||||
loss_tr += sum(loss)
|
||||
elif idx == 1:
|
||||
loss_tcl += sum(loss)
|
||||
elif idx == 2:
|
||||
loss_reg_x += sum(loss)
|
||||
else:
|
||||
loss_reg_y += sum(loss)
|
||||
|
||||
results = dict(
|
||||
loss=loss_all,
|
||||
loss_text=loss_tr,
|
||||
loss_center=loss_tcl,
|
||||
loss_reg_x=loss_reg_x,
|
||||
loss_reg_y=loss_reg_y, )
|
||||
return results
|
||||
|
||||
def forward_single(self, pred, gt):
|
||||
cls_pred = paddle.transpose(pred[0], (0, 2, 3, 1))
|
||||
reg_pred = paddle.transpose(pred[1], (0, 2, 3, 1))
|
||||
gt = paddle.transpose(gt, (0, 2, 3, 1))
|
||||
|
||||
k = 2 * self.fourier_degree + 1
|
||||
tr_pred = paddle.reshape(cls_pred[:, :, :, :2], (-1, 2))
|
||||
tcl_pred = paddle.reshape(cls_pred[:, :, :, 2:], (-1, 2))
|
||||
x_pred = paddle.reshape(reg_pred[:, :, :, 0:k], (-1, k))
|
||||
y_pred = paddle.reshape(reg_pred[:, :, :, k:2 * k], (-1, k))
|
||||
|
||||
tr_mask = gt[:, :, :, :1].reshape([-1])
|
||||
tcl_mask = gt[:, :, :, 1:2].reshape([-1])
|
||||
train_mask = gt[:, :, :, 2:3].reshape([-1])
|
||||
x_map = paddle.reshape(gt[:, :, :, 3:3 + k], (-1, k))
|
||||
y_map = paddle.reshape(gt[:, :, :, 3 + k:], (-1, k))
|
||||
|
||||
tr_train_mask = (train_mask * tr_mask).astype('bool')
|
||||
tr_train_mask2 = paddle.concat(
|
||||
[tr_train_mask.unsqueeze(1), tr_train_mask.unsqueeze(1)], axis=1)
|
||||
# tr loss
|
||||
loss_tr = self.ohem(tr_pred, tr_mask, train_mask)
|
||||
# tcl loss
|
||||
loss_tcl = paddle.to_tensor(0.).astype('float32')
|
||||
tr_neg_mask = tr_train_mask.logical_not()
|
||||
tr_neg_mask2 = paddle.concat(
|
||||
[tr_neg_mask.unsqueeze(1), tr_neg_mask.unsqueeze(1)], axis=1)
|
||||
if tr_train_mask.sum().item() > 0:
|
||||
loss_tcl_pos = F.cross_entropy(
|
||||
tcl_pred.masked_select(tr_train_mask2).reshape([-1, 2]),
|
||||
tcl_mask.masked_select(tr_train_mask).astype('int64'))
|
||||
loss_tcl_neg = F.cross_entropy(
|
||||
tcl_pred.masked_select(tr_neg_mask2).reshape([-1, 2]),
|
||||
tcl_mask.masked_select(tr_neg_mask).astype('int64'))
|
||||
loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg
|
||||
|
||||
# regression loss
|
||||
loss_reg_x = paddle.to_tensor(0.).astype('float32')
|
||||
loss_reg_y = paddle.to_tensor(0.).astype('float32')
|
||||
if tr_train_mask.sum().item() > 0:
|
||||
weight = (tr_mask.masked_select(tr_train_mask.astype('bool'))
|
||||
.astype('float32') + tcl_mask.masked_select(
|
||||
tr_train_mask.astype('bool')).astype('float32')) / 2
|
||||
weight = weight.reshape([-1, 1])
|
||||
|
||||
ft_x, ft_y = self.fourier2poly(x_map, y_map)
|
||||
ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)
|
||||
|
||||
dim = ft_x.shape[1]
|
||||
|
||||
tr_train_mask3 = paddle.concat(
|
||||
[tr_train_mask.unsqueeze(1) for i in range(dim)], axis=1)
|
||||
|
||||
loss_reg_x = paddle.mean(weight * F.smooth_l1_loss(
|
||||
ft_x_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
|
||||
ft_x.masked_select(tr_train_mask3).reshape([-1, dim]),
|
||||
reduction='none'))
|
||||
loss_reg_y = paddle.mean(weight * F.smooth_l1_loss(
|
||||
ft_y_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
|
||||
ft_y.masked_select(tr_train_mask3).reshape([-1, dim]),
|
||||
reduction='none'))
|
||||
|
||||
return loss_tr, loss_tcl, loss_reg_x, loss_reg_y
|
||||
|
||||
def ohem(self, predict, target, train_mask):
|
||||
|
||||
pos = (target * train_mask).astype('bool')
|
||||
neg = ((1 - target) * train_mask).astype('bool')
|
||||
|
||||
pos2 = paddle.concat([pos.unsqueeze(1), pos.unsqueeze(1)], axis=1)
|
||||
neg2 = paddle.concat([neg.unsqueeze(1), neg.unsqueeze(1)], axis=1)
|
||||
|
||||
n_pos = pos.astype('float32').sum()
|
||||
|
||||
if n_pos.item() > 0:
|
||||
loss_pos = F.cross_entropy(
|
||||
predict.masked_select(pos2).reshape([-1, 2]),
|
||||
target.masked_select(pos).astype('int64'),
|
||||
reduction='sum')
|
||||
loss_neg = F.cross_entropy(
|
||||
predict.masked_select(neg2).reshape([-1, 2]),
|
||||
target.masked_select(neg).astype('int64'),
|
||||
reduction='none')
|
||||
n_neg = min(
|
||||
int(neg.astype('float32').sum().item()),
|
||||
int(self.ohem_ratio * n_pos.astype('float32')))
|
||||
else:
|
||||
loss_pos = paddle.to_tensor(0.)
|
||||
loss_neg = F.cross_entropy(
|
||||
predict.masked_select(neg2).reshape([-1, 2]),
|
||||
target.masked_select(neg).astype('int64'),
|
||||
reduction='none')
|
||||
n_neg = 100
|
||||
if len(loss_neg) > n_neg:
|
||||
loss_neg, _ = paddle.topk(loss_neg, n_neg)
|
||||
|
||||
return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).astype('float32')
|
||||
|
||||
def fourier2poly(self, real_maps, imag_maps):
|
||||
"""Transform Fourier coefficient maps to polygon maps.
|
||||
|
||||
Args:
|
||||
real_maps (tensor): A map composed of the real parts of the
|
||||
Fourier coefficients, whose shape is (-1, 2k+1)
|
||||
imag_maps (tensor):A map composed of the imag parts of the
|
||||
Fourier coefficients, whose shape is (-1, 2k+1)
|
||||
|
||||
Returns
|
||||
x_maps (tensor): A map composed of the x value of the polygon
|
||||
represented by n sample points (xn, yn), whose shape is (-1, n)
|
||||
y_maps (tensor): A map composed of the y value of the polygon
|
||||
represented by n sample points (xn, yn), whose shape is (-1, n)
|
||||
"""
|
||||
|
||||
k_vect = paddle.arange(
|
||||
-self.fourier_degree, self.fourier_degree + 1,
|
||||
dtype='float32').reshape([-1, 1])
|
||||
i_vect = paddle.arange(
|
||||
0, self.num_sample, dtype='float32').reshape([1, -1])
|
||||
|
||||
transform_matrix = 2 * np.pi / self.num_sample * paddle.matmul(k_vect,
|
||||
i_vect)
|
||||
|
||||
x1 = paddle.einsum('ak, kn-> an', real_maps,
|
||||
paddle.cos(transform_matrix))
|
||||
x2 = paddle.einsum('ak, kn-> an', imag_maps,
|
||||
paddle.sin(transform_matrix))
|
||||
y1 = paddle.einsum('ak, kn-> an', real_maps,
|
||||
paddle.sin(transform_matrix))
|
||||
y2 = paddle.einsum('ak, kn-> an', imag_maps,
|
||||
paddle.cos(transform_matrix))
|
||||
|
||||
x_maps = x1 - x2
|
||||
y_maps = y1 + y2
|
||||
|
||||
return x_maps, y_maps
|
||||
@@ -1,149 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py
|
||||
"""
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
from ppocr.utils.iou import iou
|
||||
|
||||
|
||||
class PSELoss(nn.Layer):
|
||||
def __init__(self,
|
||||
alpha,
|
||||
ohem_ratio=3,
|
||||
kernel_sample_mask='pred',
|
||||
reduction='sum',
|
||||
eps=1e-6,
|
||||
**kwargs):
|
||||
"""Implement PSE Loss.
|
||||
"""
|
||||
super(PSELoss, self).__init__()
|
||||
assert reduction in ['sum', 'mean', 'none']
|
||||
self.alpha = alpha
|
||||
self.ohem_ratio = ohem_ratio
|
||||
self.kernel_sample_mask = kernel_sample_mask
|
||||
self.reduction = reduction
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, outputs, labels):
|
||||
predicts = outputs['maps']
|
||||
predicts = F.interpolate(predicts, scale_factor=4)
|
||||
|
||||
texts = predicts[:, 0, :, :]
|
||||
kernels = predicts[:, 1:, :, :]
|
||||
gt_texts, gt_kernels, training_masks = labels[1:]
|
||||
|
||||
# text loss
|
||||
selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
|
||||
|
||||
loss_text = self.dice_loss(texts, gt_texts, selected_masks)
|
||||
iou_text = iou((texts > 0).astype('int64'),
|
||||
gt_texts,
|
||||
training_masks,
|
||||
reduce=False)
|
||||
losses = dict(loss_text=loss_text, iou_text=iou_text)
|
||||
|
||||
# kernel loss
|
||||
loss_kernels = []
|
||||
if self.kernel_sample_mask == 'gt':
|
||||
selected_masks = gt_texts * training_masks
|
||||
elif self.kernel_sample_mask == 'pred':
|
||||
selected_masks = (
|
||||
F.sigmoid(texts) > 0.5).astype('float32') * training_masks
|
||||
|
||||
for i in range(kernels.shape[1]):
|
||||
kernel_i = kernels[:, i, :, :]
|
||||
gt_kernel_i = gt_kernels[:, i, :, :]
|
||||
loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i,
|
||||
selected_masks)
|
||||
loss_kernels.append(loss_kernel_i)
|
||||
loss_kernels = paddle.mean(paddle.stack(loss_kernels, axis=1), axis=1)
|
||||
iou_kernel = iou((kernels[:, -1, :, :] > 0).astype('int64'),
|
||||
gt_kernels[:, -1, :, :],
|
||||
training_masks * gt_texts,
|
||||
reduce=False)
|
||||
losses.update(dict(loss_kernels=loss_kernels, iou_kernel=iou_kernel))
|
||||
loss = self.alpha * loss_text + (1 - self.alpha) * loss_kernels
|
||||
losses['loss'] = loss
|
||||
if self.reduction == 'sum':
|
||||
losses = {x: paddle.sum(v) for x, v in losses.items()}
|
||||
elif self.reduction == 'mean':
|
||||
losses = {x: paddle.mean(v) for x, v in losses.items()}
|
||||
return losses
|
||||
|
||||
def dice_loss(self, input, target, mask):
|
||||
input = F.sigmoid(input)
|
||||
|
||||
input = input.reshape([input.shape[0], -1])
|
||||
target = target.reshape([target.shape[0], -1])
|
||||
mask = mask.reshape([mask.shape[0], -1])
|
||||
|
||||
input = input * mask
|
||||
target = target * mask
|
||||
|
||||
a = paddle.sum(input * target, 1)
|
||||
b = paddle.sum(input * input, 1) + self.eps
|
||||
c = paddle.sum(target * target, 1) + self.eps
|
||||
d = (2 * a) / (b + c)
|
||||
return 1 - d
|
||||
|
||||
def ohem_single(self, score, gt_text, training_mask, ohem_ratio=3):
|
||||
pos_num = int(paddle.sum((gt_text > 0.5).astype('float32'))) - int(
|
||||
paddle.sum(
|
||||
paddle.logical_and((gt_text > 0.5), (training_mask <= 0.5))
|
||||
.astype('float32')))
|
||||
|
||||
if pos_num == 0:
|
||||
selected_mask = training_mask
|
||||
selected_mask = selected_mask.reshape(
|
||||
[1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
|
||||
'float32')
|
||||
return selected_mask
|
||||
|
||||
neg_num = int(paddle.sum((gt_text <= 0.5).astype('float32')))
|
||||
neg_num = int(min(pos_num * ohem_ratio, neg_num))
|
||||
|
||||
if neg_num == 0:
|
||||
selected_mask = training_mask
|
||||
selected_mask = selected_mask.reshape(
|
||||
[1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
|
||||
'float32')
|
||||
return selected_mask
|
||||
|
||||
neg_score = paddle.masked_select(score, gt_text <= 0.5)
|
||||
neg_score_sorted = paddle.sort(-neg_score)
|
||||
threshold = -neg_score_sorted[neg_num - 1]
|
||||
|
||||
selected_mask = paddle.logical_and(
|
||||
paddle.logical_or((score >= threshold), (gt_text > 0.5)),
|
||||
(training_mask > 0.5))
|
||||
selected_mask = selected_mask.reshape(
|
||||
[1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
|
||||
'float32')
|
||||
return selected_mask
|
||||
|
||||
def ohem_batch(self, scores, gt_texts, training_masks, ohem_ratio=3):
|
||||
selected_masks = []
|
||||
for i in range(scores.shape[0]):
|
||||
selected_masks.append(
|
||||
self.ohem_single(scores[i, :, :], gt_texts[i, :, :],
|
||||
training_masks[i, :, :], ohem_ratio))
|
||||
|
||||
selected_masks = paddle.concat(selected_masks, 0).astype('float32')
|
||||
return selected_masks
|
||||
@@ -1,121 +0,0 @@
|
||||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from .det_basic_loss import DiceLoss
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SASTLoss(nn.Layer):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, eps=1e-6, **kwargs):
|
||||
super(SASTLoss, self).__init__()
|
||||
self.dice_loss = DiceLoss(eps=eps)
|
||||
|
||||
def forward(self, predicts, labels):
|
||||
"""
|
||||
tcl_pos: N x 128 x 3
|
||||
tcl_mask: N x 128 x 1
|
||||
tcl_label: N x X list or LoDTensor
|
||||
"""
|
||||
|
||||
f_score = predicts['f_score']
|
||||
f_border = predicts['f_border']
|
||||
f_tvo = predicts['f_tvo']
|
||||
f_tco = predicts['f_tco']
|
||||
|
||||
l_score, l_border, l_mask, l_tvo, l_tco = labels[1:]
|
||||
|
||||
#score_loss
|
||||
intersection = paddle.sum(f_score * l_score * l_mask)
|
||||
union = paddle.sum(f_score * l_mask) + paddle.sum(l_score * l_mask)
|
||||
score_loss = 1.0 - 2 * intersection / (union + 1e-5)
|
||||
|
||||
#border loss
|
||||
l_border_split, l_border_norm = paddle.split(
|
||||
l_border, num_or_sections=[4, 1], axis=1)
|
||||
f_border_split = f_border
|
||||
border_ex_shape = l_border_norm.shape * np.array([1, 4, 1, 1])
|
||||
l_border_norm_split = paddle.expand(
|
||||
x=l_border_norm, shape=border_ex_shape)
|
||||
l_border_score = paddle.expand(x=l_score, shape=border_ex_shape)
|
||||
l_border_mask = paddle.expand(x=l_mask, shape=border_ex_shape)
|
||||
|
||||
border_diff = l_border_split - f_border_split
|
||||
abs_border_diff = paddle.abs(border_diff)
|
||||
border_sign = abs_border_diff < 1.0
|
||||
border_sign = paddle.cast(border_sign, dtype='float32')
|
||||
border_sign.stop_gradient = True
|
||||
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
|
||||
(abs_border_diff - 0.5) * (1.0 - border_sign)
|
||||
border_out_loss = l_border_norm_split * border_in_loss
|
||||
border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
|
||||
(paddle.sum(l_border_score * l_border_mask) + 1e-5)
|
||||
|
||||
#tvo_loss
|
||||
l_tvo_split, l_tvo_norm = paddle.split(
|
||||
l_tvo, num_or_sections=[8, 1], axis=1)
|
||||
f_tvo_split = f_tvo
|
||||
tvo_ex_shape = l_tvo_norm.shape * np.array([1, 8, 1, 1])
|
||||
l_tvo_norm_split = paddle.expand(x=l_tvo_norm, shape=tvo_ex_shape)
|
||||
l_tvo_score = paddle.expand(x=l_score, shape=tvo_ex_shape)
|
||||
l_tvo_mask = paddle.expand(x=l_mask, shape=tvo_ex_shape)
|
||||
#
|
||||
tvo_geo_diff = l_tvo_split - f_tvo_split
|
||||
abs_tvo_geo_diff = paddle.abs(tvo_geo_diff)
|
||||
tvo_sign = abs_tvo_geo_diff < 1.0
|
||||
tvo_sign = paddle.cast(tvo_sign, dtype='float32')
|
||||
tvo_sign.stop_gradient = True
|
||||
tvo_in_loss = 0.5 * abs_tvo_geo_diff * abs_tvo_geo_diff * tvo_sign + \
|
||||
(abs_tvo_geo_diff - 0.5) * (1.0 - tvo_sign)
|
||||
tvo_out_loss = l_tvo_norm_split * tvo_in_loss
|
||||
tvo_loss = paddle.sum(tvo_out_loss * l_tvo_score * l_tvo_mask) / \
|
||||
(paddle.sum(l_tvo_score * l_tvo_mask) + 1e-5)
|
||||
|
||||
#tco_loss
|
||||
l_tco_split, l_tco_norm = paddle.split(
|
||||
l_tco, num_or_sections=[2, 1], axis=1)
|
||||
f_tco_split = f_tco
|
||||
tco_ex_shape = l_tco_norm.shape * np.array([1, 2, 1, 1])
|
||||
l_tco_norm_split = paddle.expand(x=l_tco_norm, shape=tco_ex_shape)
|
||||
l_tco_score = paddle.expand(x=l_score, shape=tco_ex_shape)
|
||||
l_tco_mask = paddle.expand(x=l_mask, shape=tco_ex_shape)
|
||||
|
||||
tco_geo_diff = l_tco_split - f_tco_split
|
||||
abs_tco_geo_diff = paddle.abs(tco_geo_diff)
|
||||
tco_sign = abs_tco_geo_diff < 1.0
|
||||
tco_sign = paddle.cast(tco_sign, dtype='float32')
|
||||
tco_sign.stop_gradient = True
|
||||
tco_in_loss = 0.5 * abs_tco_geo_diff * abs_tco_geo_diff * tco_sign + \
|
||||
(abs_tco_geo_diff - 0.5) * (1.0 - tco_sign)
|
||||
tco_out_loss = l_tco_norm_split * tco_in_loss
|
||||
tco_loss = paddle.sum(tco_out_loss * l_tco_score * l_tco_mask) / \
|
||||
(paddle.sum(l_tco_score * l_tco_mask) + 1e-5)
|
||||
|
||||
# total loss
|
||||
tvo_lw, tco_lw = 1.5, 1.5
|
||||
score_lw, border_lw = 1.0, 1.0
|
||||
total_loss = score_loss * score_lw + border_loss * border_lw + \
|
||||
tvo_loss * tvo_lw + tco_loss * tco_lw
|
||||
|
||||
losses = {'loss':total_loss, "score_loss":score_loss,\
|
||||
"border_loss":border_loss, 'tvo_loss':tvo_loss, 'tco_loss':tco_loss}
|
||||
return losses
|
||||
@@ -1,324 +0,0 @@
|
||||
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .rec_sar_loss import SARLoss
|
||||
from .basic_loss import DMLLoss
|
||||
from .basic_loss import DistanceLoss
|
||||
from .det_db_loss import DBLoss
|
||||
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
|
||||
|
||||
|
||||
def _sum_loss(loss_dict):
|
||||
if "loss" in loss_dict.keys():
|
||||
return loss_dict
|
||||
else:
|
||||
loss_dict["loss"] = 0.
|
||||
for k, value in loss_dict.items():
|
||||
if k == "loss":
|
||||
continue
|
||||
else:
|
||||
loss_dict["loss"] += value
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDMLLoss(DMLLoss):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_name_pairs=[],
|
||||
act=None,
|
||||
use_log=False,
|
||||
key=None,
|
||||
multi_head=False,
|
||||
dis_head='ctc',
|
||||
maps_name=None,
|
||||
name="dml"):
|
||||
super().__init__(act=act, use_log=use_log)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.multi_head = multi_head
|
||||
self.dis_head = dis_head
|
||||
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
||||
self.name = name
|
||||
self.maps_name = self._check_maps_name(maps_name)
|
||||
|
||||
def _check_model_name_pairs(self, model_name_pairs):
|
||||
if not isinstance(model_name_pairs, list):
|
||||
return []
|
||||
elif isinstance(model_name_pairs[0], list) and isinstance(
|
||||
model_name_pairs[0][0], str):
|
||||
return model_name_pairs
|
||||
else:
|
||||
return [model_name_pairs]
|
||||
|
||||
def _check_maps_name(self, maps_name):
|
||||
if maps_name is None:
|
||||
return None
|
||||
elif type(maps_name) == str:
|
||||
return [maps_name]
|
||||
elif type(maps_name) == list:
|
||||
return [maps_name]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _slice_out(self, outs):
|
||||
new_outs = {}
|
||||
for k in self.maps_name:
|
||||
if k == "thrink_maps":
|
||||
new_outs[k] = outs[:, 0, :, :]
|
||||
elif k == "threshold_maps":
|
||||
new_outs[k] = outs[:, 1, :, :]
|
||||
elif k == "binary_maps":
|
||||
new_outs[k] = outs[:, 2, :, :]
|
||||
else:
|
||||
continue
|
||||
return new_outs
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, pair in enumerate(self.model_name_pairs):
|
||||
out1 = predicts[pair[0]]
|
||||
out2 = predicts[pair[1]]
|
||||
if self.key is not None:
|
||||
out1 = out1[self.key]
|
||||
out2 = out2[self.key]
|
||||
|
||||
if self.maps_name is None:
|
||||
if self.multi_head:
|
||||
loss = super().forward(out1[self.dis_head],
|
||||
out2[self.dis_head])
|
||||
else:
|
||||
loss = super().forward(out1, out2)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
|
||||
idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||
else:
|
||||
outs1 = self._slice_out(out1)
|
||||
outs2 = self._slice_out(out2)
|
||||
for _c, k in enumerate(outs1.keys()):
|
||||
loss = super().forward(outs1[k], outs2[k])
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||
0], pair[1], self.maps_name, idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name[
|
||||
_c], idx)] = loss
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationCTCLoss(CTCLoss):
|
||||
def __init__(self,
|
||||
model_name_list=[],
|
||||
key=None,
|
||||
multi_head=False,
|
||||
name="loss_ctc"):
|
||||
super().__init__()
|
||||
self.model_name_list = model_name_list
|
||||
self.key = key
|
||||
self.name = name
|
||||
self.multi_head = multi_head
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, model_name in enumerate(self.model_name_list):
|
||||
out = predicts[model_name]
|
||||
if self.key is not None:
|
||||
out = out[self.key]
|
||||
if self.multi_head:
|
||||
assert 'ctc' in out, 'multi head has multi out'
|
||||
loss = super().forward(out['ctc'], batch[:2] + batch[3:])
|
||||
else:
|
||||
loss = super().forward(out, batch)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}".format(self.name, model_name,
|
||||
idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, model_name)] = loss
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationSARLoss(SARLoss):
|
||||
def __init__(self,
|
||||
model_name_list=[],
|
||||
key=None,
|
||||
multi_head=False,
|
||||
name="loss_sar",
|
||||
**kwargs):
|
||||
ignore_index = kwargs.get('ignore_index', 92)
|
||||
super().__init__(ignore_index=ignore_index)
|
||||
self.model_name_list = model_name_list
|
||||
self.key = key
|
||||
self.name = name
|
||||
self.multi_head = multi_head
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, model_name in enumerate(self.model_name_list):
|
||||
out = predicts[model_name]
|
||||
if self.key is not None:
|
||||
out = out[self.key]
|
||||
if self.multi_head:
|
||||
assert 'sar' in out, 'multi head has multi out'
|
||||
loss = super().forward(out['sar'], batch[:1] + batch[2:])
|
||||
else:
|
||||
loss = super().forward(out, batch)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}".format(self.name, model_name,
|
||||
idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, model_name)] = loss
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDBLoss(DBLoss):
|
||||
def __init__(self,
|
||||
model_name_list=[],
|
||||
balance_loss=True,
|
||||
main_loss_type='DiceLoss',
|
||||
alpha=5,
|
||||
beta=10,
|
||||
ohem_ratio=3,
|
||||
eps=1e-6,
|
||||
name="db",
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.model_name_list = model_name_list
|
||||
self.name = name
|
||||
self.key = None
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = {}
|
||||
for idx, model_name in enumerate(self.model_name_list):
|
||||
out = predicts[model_name]
|
||||
if self.key is not None:
|
||||
out = out[self.key]
|
||||
loss = super().forward(out, batch)
|
||||
|
||||
if isinstance(loss, dict):
|
||||
for key in loss.keys():
|
||||
if key == "loss":
|
||||
continue
|
||||
name = "{}_{}_{}".format(self.name, model_name, key)
|
||||
loss_dict[name] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, model_name)] = loss
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDilaDBLoss(DBLoss):
|
||||
def __init__(self,
|
||||
model_name_pairs=[],
|
||||
key=None,
|
||||
balance_loss=True,
|
||||
main_loss_type='DiceLoss',
|
||||
alpha=5,
|
||||
beta=10,
|
||||
ohem_ratio=3,
|
||||
eps=1e-6,
|
||||
name="dila_dbloss"):
|
||||
super().__init__()
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.name = name
|
||||
self.key = key
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, pair in enumerate(self.model_name_pairs):
|
||||
stu_outs = predicts[pair[0]]
|
||||
tch_outs = predicts[pair[1]]
|
||||
if self.key is not None:
|
||||
stu_preds = stu_outs[self.key]
|
||||
tch_preds = tch_outs[self.key]
|
||||
|
||||
stu_shrink_maps = stu_preds[:, 0, :, :]
|
||||
stu_binary_maps = stu_preds[:, 2, :, :]
|
||||
|
||||
# dilation to teacher prediction
|
||||
dilation_w = np.array([[1, 1], [1, 1]])
|
||||
th_shrink_maps = tch_preds[:, 0, :, :]
|
||||
th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
|
||||
dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
|
||||
for i in range(th_shrink_maps.shape[0]):
|
||||
dilate_maps[i] = cv2.dilate(
|
||||
th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
|
||||
th_shrink_maps = paddle.to_tensor(dilate_maps)
|
||||
|
||||
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
|
||||
1:]
|
||||
|
||||
# calculate the shrink map loss
|
||||
bce_loss = self.alpha * self.bce_loss(
|
||||
stu_shrink_maps, th_shrink_maps, label_shrink_mask)
|
||||
loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
|
||||
label_shrink_mask)
|
||||
|
||||
# k = f"{self.name}_{pair[0]}_{pair[1]}"
|
||||
k = "{}_{}_{}".format(self.name, pair[0], pair[1])
|
||||
loss_dict[k] = bce_loss + loss_binary_maps
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDistanceLoss(DistanceLoss):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
mode="l2",
|
||||
model_name_pairs=[],
|
||||
key=None,
|
||||
name="loss_distance",
|
||||
**kargs):
|
||||
super().__init__(mode=mode, **kargs)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.name = name + "_l2"
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, pair in enumerate(self.model_name_pairs):
|
||||
out1 = predicts[pair[0]]
|
||||
out2 = predicts[pair[1]]
|
||||
if self.key is not None:
|
||||
out1 = out1[self.key]
|
||||
out2 = out2[self.key]
|
||||
loss = super().forward(out1, out2)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
|
||||
key]
|
||||
else:
|
||||
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
|
||||
idx)] = loss
|
||||
return loss_dict
|
||||
@@ -1,140 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
import paddle
|
||||
|
||||
from .det_basic_loss import DiceLoss
|
||||
from ppocr.utils.e2e_utils.extract_batchsize import pre_process
|
||||
|
||||
|
||||
class PGLoss(nn.Layer):
|
||||
def __init__(self,
|
||||
tcl_bs,
|
||||
max_text_length,
|
||||
max_text_nums,
|
||||
pad_num,
|
||||
eps=1e-6,
|
||||
**kwargs):
|
||||
super(PGLoss, self).__init__()
|
||||
self.tcl_bs = tcl_bs
|
||||
self.max_text_nums = max_text_nums
|
||||
self.max_text_length = max_text_length
|
||||
self.pad_num = pad_num
|
||||
self.dice_loss = DiceLoss(eps=eps)
|
||||
|
||||
def border_loss(self, f_border, l_border, l_score, l_mask):
|
||||
l_border_split, l_border_norm = paddle.tensor.split(
|
||||
l_border, num_or_sections=[4, 1], axis=1)
|
||||
f_border_split = f_border
|
||||
b, c, h, w = l_border_norm.shape
|
||||
l_border_norm_split = paddle.expand(
|
||||
x=l_border_norm, shape=[b, 4 * c, h, w])
|
||||
b, c, h, w = l_score.shape
|
||||
l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w])
|
||||
b, c, h, w = l_mask.shape
|
||||
l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w])
|
||||
border_diff = l_border_split - f_border_split
|
||||
abs_border_diff = paddle.abs(border_diff)
|
||||
border_sign = abs_border_diff < 1.0
|
||||
border_sign = paddle.cast(border_sign, dtype='float32')
|
||||
border_sign.stop_gradient = True
|
||||
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
|
||||
(abs_border_diff - 0.5) * (1.0 - border_sign)
|
||||
border_out_loss = l_border_norm_split * border_in_loss
|
||||
border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
|
||||
(paddle.sum(l_border_score * l_border_mask) + 1e-5)
|
||||
return border_loss
|
||||
|
||||
def direction_loss(self, f_direction, l_direction, l_score, l_mask):
|
||||
l_direction_split, l_direction_norm = paddle.tensor.split(
|
||||
l_direction, num_or_sections=[2, 1], axis=1)
|
||||
f_direction_split = f_direction
|
||||
b, c, h, w = l_direction_norm.shape
|
||||
l_direction_norm_split = paddle.expand(
|
||||
x=l_direction_norm, shape=[b, 2 * c, h, w])
|
||||
b, c, h, w = l_score.shape
|
||||
l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w])
|
||||
b, c, h, w = l_mask.shape
|
||||
l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w])
|
||||
direction_diff = l_direction_split - f_direction_split
|
||||
abs_direction_diff = paddle.abs(direction_diff)
|
||||
direction_sign = abs_direction_diff < 1.0
|
||||
direction_sign = paddle.cast(direction_sign, dtype='float32')
|
||||
direction_sign.stop_gradient = True
|
||||
direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \
|
||||
(abs_direction_diff - 0.5) * (1.0 - direction_sign)
|
||||
direction_out_loss = l_direction_norm_split * direction_in_loss
|
||||
direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \
|
||||
(paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
|
||||
return direction_loss
|
||||
|
||||
def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
|
||||
f_char = paddle.transpose(f_char, [0, 2, 3, 1])
|
||||
tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
|
||||
tcl_pos = paddle.cast(tcl_pos, dtype=int)
|
||||
f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
|
||||
f_tcl_char = paddle.reshape(f_tcl_char,
|
||||
[-1, 64, 37]) # len(Lexicon_Table)+1
|
||||
f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2)
|
||||
f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
|
||||
b, c, l = tcl_mask.shape
|
||||
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l])
|
||||
tcl_mask_fg.stop_gradient = True
|
||||
f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
|
||||
-20.0)
|
||||
f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2)
|
||||
f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
|
||||
N, B, _ = f_tcl_char_ld.shape
|
||||
input_lengths = paddle.to_tensor([N] * B, dtype='int64')
|
||||
cost = paddle.nn.functional.ctc_loss(
|
||||
log_probs=f_tcl_char_ld,
|
||||
labels=tcl_label,
|
||||
input_lengths=input_lengths,
|
||||
label_lengths=label_t,
|
||||
blank=self.pad_num,
|
||||
reduction='none')
|
||||
cost = cost.mean()
|
||||
return cost
|
||||
|
||||
def forward(self, predicts, labels):
|
||||
images, tcl_maps, tcl_label_maps, border_maps \
|
||||
, direction_maps, training_masks, label_list, pos_list, pos_mask = labels
|
||||
# for all the batch_size
|
||||
pos_list, pos_mask, label_list, label_t = pre_process(
|
||||
label_list, pos_list, pos_mask, self.max_text_length,
|
||||
self.max_text_nums, self.pad_num, self.tcl_bs)
|
||||
|
||||
f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \
|
||||
predicts['f_char']
|
||||
score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
|
||||
border_loss = self.border_loss(f_border, border_maps, tcl_maps,
|
||||
training_masks)
|
||||
direction_loss = self.direction_loss(f_direction, direction_maps,
|
||||
tcl_maps, training_masks)
|
||||
ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t)
|
||||
loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss
|
||||
|
||||
losses = {
|
||||
'loss': loss_all,
|
||||
"score_loss": score_loss,
|
||||
"border_loss": border_loss,
|
||||
"direction_loss": direction_loss,
|
||||
"ctc_loss": ctc_loss
|
||||
}
|
||||
return losses
|
||||
@@ -1,115 +0,0 @@
|
||||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# reference from : https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/losses/sdmgr_loss.py
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
import paddle
|
||||
|
||||
|
||||
class SDMGRLoss(nn.Layer):
|
||||
def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0):
|
||||
super().__init__()
|
||||
self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
|
||||
self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
|
||||
self.node_weight = node_weight
|
||||
self.edge_weight = edge_weight
|
||||
self.ignore = ignore
|
||||
|
||||
def pre_process(self, gts, tag):
|
||||
gts, tag = gts.numpy(), tag.numpy().tolist()
|
||||
temp_gts = []
|
||||
batch = len(tag)
|
||||
for i in range(batch):
|
||||
num, recoder_len = tag[i][0], tag[i][1]
|
||||
temp_gts.append(
|
||||
paddle.to_tensor(
|
||||
gts[i, :num, :num + 1], dtype='int64'))
|
||||
return temp_gts
|
||||
|
||||
def accuracy(self, pred, target, topk=1, thresh=None):
|
||||
"""Calculate accuracy according to the prediction and target.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The model prediction, shape (N, num_class)
|
||||
target (torch.Tensor): The target of each prediction, shape (N, )
|
||||
topk (int | tuple[int], optional): If the predictions in ``topk``
|
||||
matches the target, the predictions will be regarded as
|
||||
correct ones. Defaults to 1.
|
||||
thresh (float, optional): If not None, predictions with scores under
|
||||
this threshold are considered incorrect. Default to None.
|
||||
|
||||
Returns:
|
||||
float | tuple[float]: If the input ``topk`` is a single integer,
|
||||
the function will return a single float as accuracy. If
|
||||
``topk`` is a tuple containing multiple integers, the
|
||||
function will return a tuple containing accuracies of
|
||||
each ``topk`` number.
|
||||
"""
|
||||
assert isinstance(topk, (int, tuple))
|
||||
if isinstance(topk, int):
|
||||
topk = (topk, )
|
||||
return_single = True
|
||||
else:
|
||||
return_single = False
|
||||
|
||||
maxk = max(topk)
|
||||
if pred.shape[0] == 0:
|
||||
accu = [pred.new_tensor(0.) for i in range(len(topk))]
|
||||
return accu[0] if return_single else accu
|
||||
pred_value, pred_label = paddle.topk(pred, maxk, axis=1)
|
||||
pred_label = pred_label.transpose(
|
||||
[1, 0]) # transpose to shape (maxk, N)
|
||||
correct = paddle.equal(pred_label,
|
||||
(target.reshape([1, -1]).expand_as(pred_label)))
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = paddle.sum(correct[:k].reshape([-1]).astype('float32'),
|
||||
axis=0,
|
||||
keepdim=True)
|
||||
res.append(
|
||||
paddle.multiply(correct_k,
|
||||
paddle.to_tensor(100.0 / pred.shape[0])))
|
||||
return res[0] if return_single else res
|
||||
|
||||
def forward(self, pred, batch):
|
||||
node_preds, edge_preds = pred
|
||||
gts, tag = batch[4], batch[5]
|
||||
gts = self.pre_process(gts, tag)
|
||||
node_gts, edge_gts = [], []
|
||||
for gt in gts:
|
||||
node_gts.append(gt[:, 0])
|
||||
edge_gts.append(gt[:, 1:].reshape([-1]))
|
||||
node_gts = paddle.concat(node_gts)
|
||||
edge_gts = paddle.concat(edge_gts)
|
||||
|
||||
node_valids = paddle.nonzero(node_gts != self.ignore).reshape([-1])
|
||||
edge_valids = paddle.nonzero(edge_gts != -1).reshape([-1])
|
||||
loss_node = self.loss_node(node_preds, node_gts)
|
||||
loss_edge = self.loss_edge(edge_preds, edge_gts)
|
||||
loss = self.node_weight * loss_node + self.edge_weight * loss_edge
|
||||
return dict(
|
||||
loss=loss,
|
||||
loss_node=loss_node,
|
||||
loss_edge=loss_edge,
|
||||
acc_node=self.accuracy(
|
||||
paddle.gather(node_preds, node_valids),
|
||||
paddle.gather(node_gts, node_valids)),
|
||||
acc_edge=self.accuracy(
|
||||
paddle.gather(edge_preds, edge_valids),
|
||||
paddle.gather(edge_gts, edge_valids)))
|
||||
@@ -1,99 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class CosineEmbeddingLoss(nn.Layer):
|
||||
def __init__(self, margin=0.):
|
||||
super(CosineEmbeddingLoss, self).__init__()
|
||||
self.margin = margin
|
||||
self.epsilon = 1e-12
|
||||
|
||||
def forward(self, x1, x2, target):
|
||||
similarity = paddle.fluid.layers.reduce_sum(
|
||||
x1 * x2, dim=-1) / (paddle.norm(
|
||||
x1, axis=-1) * paddle.norm(
|
||||
x2, axis=-1) + self.epsilon)
|
||||
one_list = paddle.full_like(target, fill_value=1)
|
||||
out = paddle.fluid.layers.reduce_mean(
|
||||
paddle.where(
|
||||
paddle.equal(target, one_list), 1. - similarity,
|
||||
paddle.maximum(
|
||||
paddle.zeros_like(similarity), similarity - self.margin)))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class AsterLoss(nn.Layer):
|
||||
def __init__(self,
|
||||
weight=None,
|
||||
size_average=True,
|
||||
ignore_index=-100,
|
||||
sequence_normalize=False,
|
||||
sample_normalize=True,
|
||||
**kwargs):
|
||||
super(AsterLoss, self).__init__()
|
||||
self.weight = weight
|
||||
self.size_average = size_average
|
||||
self.ignore_index = ignore_index
|
||||
self.sequence_normalize = sequence_normalize
|
||||
self.sample_normalize = sample_normalize
|
||||
self.loss_sem = CosineEmbeddingLoss()
|
||||
self.is_cosin_loss = True
|
||||
self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none')
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
targets = batch[1].astype("int64")
|
||||
label_lengths = batch[2].astype('int64')
|
||||
sem_target = batch[3].astype('float32')
|
||||
embedding_vectors = predicts['embedding_vectors']
|
||||
rec_pred = predicts['rec_pred']
|
||||
|
||||
if not self.is_cosin_loss:
|
||||
sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target))
|
||||
else:
|
||||
label_target = paddle.ones([embedding_vectors.shape[0]])
|
||||
sem_loss = paddle.sum(
|
||||
self.loss_sem(embedding_vectors, sem_target, label_target))
|
||||
|
||||
# rec loss
|
||||
batch_size, def_max_length = targets.shape[0], targets.shape[1]
|
||||
|
||||
mask = paddle.zeros([batch_size, def_max_length])
|
||||
for i in range(batch_size):
|
||||
mask[i, :label_lengths[i]] = 1
|
||||
mask = paddle.cast(mask, "float32")
|
||||
max_length = max(label_lengths)
|
||||
assert max_length == rec_pred.shape[1]
|
||||
targets = targets[:, :max_length]
|
||||
mask = mask[:, :max_length]
|
||||
rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]])
|
||||
input = nn.functional.log_softmax(rec_pred, axis=1)
|
||||
targets = paddle.reshape(targets, [-1, 1])
|
||||
mask = paddle.reshape(mask, [-1, 1])
|
||||
output = -paddle.index_sample(input, index=targets) * mask
|
||||
output = paddle.sum(output)
|
||||
if self.sequence_normalize:
|
||||
output = output / paddle.sum(mask)
|
||||
if self.sample_normalize:
|
||||
output = output / batch_size
|
||||
|
||||
loss = output + sem_loss * 0.1
|
||||
return {'loss': loss}
|
||||
@@ -1,39 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class AttentionLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(AttentionLoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
targets = batch[1].astype("int64")
|
||||
label_lengths = batch[2].astype('int64')
|
||||
batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[
|
||||
1], predicts.shape[2]
|
||||
assert len(targets.shape) == len(list(predicts.shape)) - 1, \
|
||||
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
|
||||
|
||||
inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
|
||||
targets = paddle.reshape(targets, [-1])
|
||||
|
||||
return {'loss': paddle.sum(self.loss_func(inputs, targets))}
|
||||
@@ -1,45 +0,0 @@
|
||||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class CTCLoss(nn.Layer):
|
||||
def __init__(self, use_focal_loss=False, **kwargs):
|
||||
super(CTCLoss, self).__init__()
|
||||
self.loss_func = nn.CTCLoss(blank=0, reduction='none')
|
||||
self.use_focal_loss = use_focal_loss
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
if isinstance(predicts, (list, tuple)):
|
||||
predicts = predicts[-1]
|
||||
predicts = predicts.transpose((1, 0, 2))
|
||||
N, B, _ = predicts.shape
|
||||
preds_lengths = paddle.to_tensor(
|
||||
[N] * B, dtype='int64', place=paddle.CPUPlace())
|
||||
labels = batch[1].astype("int32")
|
||||
label_lengths = batch[2].astype('int64')
|
||||
loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
|
||||
if self.use_focal_loss:
|
||||
weight = paddle.exp(-loss)
|
||||
weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
|
||||
weight = paddle.square(weight)
|
||||
loss = paddle.multiply(loss, weight)
|
||||
loss = loss.mean()
|
||||
return {'loss': loss}
|
||||
@@ -1,70 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from .ace_loss import ACELoss
|
||||
from .center_loss import CenterLoss
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
|
||||
|
||||
class EnhancedCTCLoss(nn.Layer):
|
||||
def __init__(self,
|
||||
use_focal_loss=False,
|
||||
use_ace_loss=False,
|
||||
ace_loss_weight=0.1,
|
||||
use_center_loss=False,
|
||||
center_loss_weight=0.05,
|
||||
num_classes=6625,
|
||||
feat_dim=96,
|
||||
init_center=False,
|
||||
center_file_path=None,
|
||||
**kwargs):
|
||||
super(EnhancedCTCLoss, self).__init__()
|
||||
self.ctc_loss_func = CTCLoss(use_focal_loss=use_focal_loss)
|
||||
|
||||
self.use_ace_loss = False
|
||||
if use_ace_loss:
|
||||
self.use_ace_loss = use_ace_loss
|
||||
self.ace_loss_func = ACELoss()
|
||||
self.ace_loss_weight = ace_loss_weight
|
||||
|
||||
self.use_center_loss = False
|
||||
if use_center_loss:
|
||||
self.use_center_loss = use_center_loss
|
||||
self.center_loss_func = CenterLoss(
|
||||
num_classes=num_classes,
|
||||
feat_dim=feat_dim,
|
||||
init_center=init_center,
|
||||
center_file_path=center_file_path)
|
||||
self.center_loss_weight = center_loss_weight
|
||||
|
||||
def __call__(self, predicts, batch):
|
||||
loss = self.ctc_loss_func(predicts, batch)["loss"]
|
||||
|
||||
if self.use_center_loss:
|
||||
center_loss = self.center_loss_func(
|
||||
predicts, batch)["loss_center"] * self.center_loss_weight
|
||||
loss = loss + center_loss
|
||||
|
||||
if self.use_ace_loss:
|
||||
ace_loss = self.ace_loss_func(
|
||||
predicts, batch)["loss_ace"] * self.ace_loss_weight
|
||||
loss = loss + ace_loss
|
||||
|
||||
return {'enhanced_ctc_loss': loss}
|
||||
@@ -1,58 +0,0 @@
|
||||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .rec_sar_loss import SARLoss
|
||||
|
||||
|
||||
class MultiLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.loss_funcs = {}
|
||||
self.loss_list = kwargs.pop('loss_config_list')
|
||||
self.weight_1 = kwargs.get('weight_1', 1.0)
|
||||
self.weight_2 = kwargs.get('weight_2', 1.0)
|
||||
self.gtc_loss = kwargs.get('gtc_loss', 'sar')
|
||||
for loss_info in self.loss_list:
|
||||
for name, param in loss_info.items():
|
||||
if param is not None:
|
||||
kwargs.update(param)
|
||||
loss = eval(name)(**kwargs)
|
||||
self.loss_funcs[name] = loss
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
self.total_loss = {}
|
||||
total_loss = 0.0
|
||||
# batch [image, label_ctc, label_sar, length, valid_ratio]
|
||||
for name, loss_func in self.loss_funcs.items():
|
||||
if name == 'CTCLoss':
|
||||
loss = loss_func(predicts['ctc'],
|
||||
batch[:2] + batch[3:])['loss'] * self.weight_1
|
||||
elif name == 'SARLoss':
|
||||
loss = loss_func(predicts['sar'],
|
||||
batch[:1] + batch[2:])['loss'] * self.weight_2
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'{} is not supported in MultiLoss yet'.format(name))
|
||||
self.total_loss[name] = loss
|
||||
total_loss += loss
|
||||
self.total_loss['loss'] = total_loss
|
||||
return self.total_loss
|
||||
@@ -1,30 +0,0 @@
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class NRTRLoss(nn.Layer):
|
||||
def __init__(self, smoothing=True, **kwargs):
|
||||
super(NRTRLoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
|
||||
self.smoothing = smoothing
|
||||
|
||||
def forward(self, pred, batch):
|
||||
pred = pred.reshape([-1, pred.shape[2]])
|
||||
max_len = batch[2].max()
|
||||
tgt = batch[1][:, 1:2 + max_len]
|
||||
tgt = tgt.reshape([-1])
|
||||
if self.smoothing:
|
||||
eps = 0.1
|
||||
n_class = pred.shape[1]
|
||||
one_hot = F.one_hot(tgt, pred.shape[1])
|
||||
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
|
||||
log_prb = F.log_softmax(pred, axis=1)
|
||||
non_pad_mask = paddle.not_equal(
|
||||
tgt, paddle.zeros(
|
||||
tgt.shape, dtype=tgt.dtype))
|
||||
loss = -(one_hot * log_prb).sum(axis=1)
|
||||
loss = loss.masked_select(non_pad_mask).mean()
|
||||
else:
|
||||
loss = self.loss_func(pred, tgt)
|
||||
return {'loss': loss}
|
||||
@@ -1,30 +0,0 @@
|
||||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class PRENLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(PRENLoss, self).__init__()
|
||||
# note: 0 is padding idx
|
||||
self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss = self.loss_func(predicts, batch[1].astype('int64'))
|
||||
return {'loss': loss}
|
||||
@@ -1,29 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class SARLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(SARLoss, self).__init__()
|
||||
ignore_index = kwargs.get('ignore_index', 92) # 6626
|
||||
self.loss_func = paddle.nn.loss.CrossEntropyLoss(
|
||||
reduction="mean", ignore_index=ignore_index)
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
predict = predicts[:, :
|
||||
-1, :] # ignore last index of outputs to be in same seq_len with targets
|
||||
label = batch[1].astype(
|
||||
"int64")[:, 1:] # ignore first index of target in loss calculation
|
||||
batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
|
||||
1], predict.shape[2]
|
||||
assert len(label.shape) == len(list(predict.shape)) - 1, \
|
||||
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
|
||||
|
||||
inputs = paddle.reshape(predict, [-1, num_classes])
|
||||
targets = paddle.reshape(label, [-1])
|
||||
loss = self.loss_func(inputs, targets)
|
||||
return {'loss': loss}
|
||||
@@ -1,47 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class SRNLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(SRNLoss, self).__init__()
|
||||
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
predict = predicts['predict']
|
||||
word_predict = predicts['word_out']
|
||||
gsrm_predict = predicts['gsrm_out']
|
||||
label = batch[1]
|
||||
|
||||
casted_label = paddle.cast(x=label, dtype='int64')
|
||||
casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
|
||||
|
||||
cost_word = self.loss_func(word_predict, label=casted_label)
|
||||
cost_gsrm = self.loss_func(gsrm_predict, label=casted_label)
|
||||
cost_vsfd = self.loss_func(predict, label=casted_label)
|
||||
|
||||
cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1])
|
||||
cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
|
||||
cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
|
||||
|
||||
sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
|
||||
|
||||
return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}
|
||||
@@ -1,109 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from paddle import fluid
|
||||
|
||||
class TableAttentionLoss(nn.Layer):
|
||||
def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
|
||||
super(TableAttentionLoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
|
||||
self.structure_weight = structure_weight
|
||||
self.loc_weight = loc_weight
|
||||
self.use_giou = use_giou
|
||||
self.giou_weight = giou_weight
|
||||
|
||||
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
|
||||
'''
|
||||
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
|
||||
:param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
|
||||
:return: loss
|
||||
'''
|
||||
ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0])
|
||||
iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1])
|
||||
ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2])
|
||||
iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3])
|
||||
|
||||
iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10)
|
||||
ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10)
|
||||
|
||||
# overlap
|
||||
inters = iw * ih
|
||||
|
||||
# union
|
||||
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
|
||||
) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
|
||||
bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
|
||||
|
||||
# ious
|
||||
ious = inters / uni
|
||||
|
||||
ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0])
|
||||
ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1])
|
||||
ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2])
|
||||
ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3])
|
||||
ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10)
|
||||
eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10)
|
||||
|
||||
# enclose erea
|
||||
enclose = ew * eh + eps
|
||||
giou = ious - (enclose - uni) / enclose
|
||||
|
||||
loss = 1 - giou
|
||||
|
||||
if reduction == 'mean':
|
||||
loss = paddle.mean(loss)
|
||||
elif reduction == 'sum':
|
||||
loss = paddle.sum(loss)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return loss
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
structure_probs = predicts['structure_probs']
|
||||
structure_targets = batch[1].astype("int64")
|
||||
structure_targets = structure_targets[:, 1:]
|
||||
if len(batch) == 6:
|
||||
structure_mask = batch[5].astype("int64")
|
||||
structure_mask = structure_mask[:, 1:]
|
||||
structure_mask = paddle.reshape(structure_mask, [-1])
|
||||
structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
|
||||
structure_targets = paddle.reshape(structure_targets, [-1])
|
||||
structure_loss = self.loss_func(structure_probs, structure_targets)
|
||||
|
||||
if len(batch) == 6:
|
||||
structure_loss = structure_loss * structure_mask
|
||||
|
||||
# structure_loss = paddle.sum(structure_loss) * self.structure_weight
|
||||
structure_loss = paddle.mean(structure_loss) * self.structure_weight
|
||||
|
||||
loc_preds = predicts['loc_preds']
|
||||
loc_targets = batch[2].astype("float32")
|
||||
loc_targets_mask = batch[4].astype("float32")
|
||||
loc_targets = loc_targets[:, 1:, :]
|
||||
loc_targets_mask = loc_targets_mask[:, 1:, :]
|
||||
loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
|
||||
if self.use_giou:
|
||||
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
|
||||
total_loss = structure_loss + loc_loss + loc_loss_giou
|
||||
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
|
||||
else:
|
||||
total_loss = structure_loss + loc_loss
|
||||
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
|
||||
@@ -1,42 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class VQASerTokenLayoutLMLoss(nn.Layer):
|
||||
def __init__(self, num_classes):
|
||||
super().__init__()
|
||||
self.loss_class = nn.CrossEntropyLoss()
|
||||
self.num_classes = num_classes
|
||||
self.ignore_index = self.loss_class.ignore_index
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
labels = batch[1]
|
||||
attention_mask = batch[4]
|
||||
if attention_mask is not None:
|
||||
active_loss = attention_mask.reshape([-1, ]) == 1
|
||||
active_outputs = predicts.reshape(
|
||||
[-1, self.num_classes])[active_loss]
|
||||
active_labels = labels.reshape([-1, ])[active_loss]
|
||||
loss = self.loss_class(active_outputs, active_labels)
|
||||
else:
|
||||
loss = self.loss_class(
|
||||
predicts.reshape([-1, self.num_classes]),
|
||||
labels.reshape([-1, ]))
|
||||
return {'loss': loss}
|
||||
@@ -1,47 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import copy
|
||||
|
||||
__all__ = ["build_metric"]
|
||||
|
||||
from .det_metric import DetMetric, DetFCEMetric
|
||||
from .rec_metric import RecMetric
|
||||
from .cls_metric import ClsMetric
|
||||
from .e2e_metric import E2EMetric
|
||||
from .distillation_metric import DistillationMetric
|
||||
from .table_metric import TableMetric
|
||||
from .kie_metric import KIEMetric
|
||||
from .vqa_token_ser_metric import VQASerTokenMetric
|
||||
from .vqa_token_re_metric import VQAReTokenMetric
|
||||
|
||||
|
||||
def build_metric(config):
|
||||
support_dict = [
|
||||
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
|
||||
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
|
||||
'VQAReTokenMetric'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop("name")
|
||||
assert module_name in support_dict, Exception(
|
||||
"metric only support {}".format(support_dict))
|
||||
module_class = eval(module_name)(**config)
|
||||
return module_class
|
||||
@@ -1,46 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class ClsMetric(object):
|
||||
def __init__(self, main_indicator='acc', **kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.eps = 1e-5
|
||||
self.reset()
|
||||
|
||||
def __call__(self, pred_label, *args, **kwargs):
|
||||
preds, labels = pred_label
|
||||
correct_num = 0
|
||||
all_num = 0
|
||||
for (pred, pred_conf), (target, _) in zip(preds, labels):
|
||||
if pred == target:
|
||||
correct_num += 1
|
||||
all_num += 1
|
||||
self.correct_num += correct_num
|
||||
self.all_num += all_num
|
||||
return {'acc': correct_num / (all_num + self.eps), }
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {
|
||||
'acc': 0
|
||||
}
|
||||
"""
|
||||
acc = self.correct_num / (self.all_num + self.eps)
|
||||
self.reset()
|
||||
return {'acc': acc}
|
||||
|
||||
def reset(self):
|
||||
self.correct_num = 0
|
||||
self.all_num = 0
|
||||
@@ -1,154 +0,0 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
__all__ = ['DetMetric', 'DetFCEMetric']
|
||||
|
||||
from .eval_det_iou import DetectionIoUEvaluator
|
||||
|
||||
|
||||
class DetMetric(object):
|
||||
def __init__(self, main_indicator='hmean', **kwargs):
|
||||
self.evaluator = DetectionIoUEvaluator()
|
||||
self.main_indicator = main_indicator
|
||||
self.reset()
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
'''
|
||||
batch: a list produced by dataloaders.
|
||||
image: np.ndarray of shape (N, C, H, W).
|
||||
ratio_list: np.ndarray of shape(N,2)
|
||||
polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
|
||||
ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
|
||||
preds: a list of dict produced by post process
|
||||
points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
|
||||
'''
|
||||
gt_polyons_batch = batch[2]
|
||||
ignore_tags_batch = batch[3]
|
||||
for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch,
|
||||
ignore_tags_batch):
|
||||
# prepare gt
|
||||
gt_info_list = [{
|
||||
'points': gt_polyon,
|
||||
'text': '',
|
||||
'ignore': ignore_tag
|
||||
} for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)]
|
||||
# prepare det
|
||||
det_info_list = [{
|
||||
'points': det_polyon,
|
||||
'text': ''
|
||||
} for det_polyon in pred['points']]
|
||||
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {
|
||||
'precision': 0,
|
||||
'recall': 0,
|
||||
'hmean': 0
|
||||
}
|
||||
"""
|
||||
|
||||
metrics = self.evaluator.combine_results(self.results)
|
||||
self.reset()
|
||||
return metrics
|
||||
|
||||
def reset(self):
|
||||
self.results = [] # clear results
|
||||
|
||||
|
||||
class DetFCEMetric(object):
|
||||
def __init__(self, main_indicator='hmean', **kwargs):
|
||||
self.evaluator = DetectionIoUEvaluator()
|
||||
self.main_indicator = main_indicator
|
||||
self.reset()
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
'''
|
||||
batch: a list produced by dataloaders.
|
||||
image: np.ndarray of shape (N, C, H, W).
|
||||
ratio_list: np.ndarray of shape(N,2)
|
||||
polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
|
||||
ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
|
||||
preds: a list of dict produced by post process
|
||||
points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
|
||||
'''
|
||||
gt_polyons_batch = batch[2]
|
||||
ignore_tags_batch = batch[3]
|
||||
|
||||
for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch,
|
||||
ignore_tags_batch):
|
||||
# prepare gt
|
||||
gt_info_list = [{
|
||||
'points': gt_polyon,
|
||||
'text': '',
|
||||
'ignore': ignore_tag
|
||||
} for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)]
|
||||
# prepare det
|
||||
det_info_list = [{
|
||||
'points': det_polyon,
|
||||
'text': '',
|
||||
'score': score
|
||||
} for det_polyon, score in zip(pred['points'], pred['scores'])]
|
||||
|
||||
for score_thr in self.results.keys():
|
||||
det_info_list_thr = [
|
||||
det_info for det_info in det_info_list
|
||||
if det_info['score'] >= score_thr
|
||||
]
|
||||
result = self.evaluator.evaluate_image(gt_info_list,
|
||||
det_info_list_thr)
|
||||
self.results[score_thr].append(result)
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {'heman':0,
|
||||
'thr 0.3':'precision: 0 recall: 0 hmean: 0',
|
||||
'thr 0.4':'precision: 0 recall: 0 hmean: 0',
|
||||
'thr 0.5':'precision: 0 recall: 0 hmean: 0',
|
||||
'thr 0.6':'precision: 0 recall: 0 hmean: 0',
|
||||
'thr 0.7':'precision: 0 recall: 0 hmean: 0',
|
||||
'thr 0.8':'precision: 0 recall: 0 hmean: 0',
|
||||
'thr 0.9':'precision: 0 recall: 0 hmean: 0',
|
||||
}
|
||||
"""
|
||||
metrics = {}
|
||||
hmean = 0
|
||||
for score_thr in self.results.keys():
|
||||
metric = self.evaluator.combine_results(self.results[score_thr])
|
||||
# for key, value in metric.items():
|
||||
# metrics['{}_{}'.format(key, score_thr)] = value
|
||||
metric_str = 'precision:{:.5f} recall:{:.5f} hmean:{:.5f}'.format(
|
||||
metric['precision'], metric['recall'], metric['hmean'])
|
||||
metrics['thr {}'.format(score_thr)] = metric_str
|
||||
hmean = max(hmean, metric['hmean'])
|
||||
metrics['hmean'] = hmean
|
||||
|
||||
self.reset()
|
||||
return metrics
|
||||
|
||||
def reset(self):
|
||||
self.results = {
|
||||
0.3: [],
|
||||
0.4: [],
|
||||
0.5: [],
|
||||
0.6: [],
|
||||
0.7: [],
|
||||
0.8: [],
|
||||
0.9: []
|
||||
} # clear results
|
||||
@@ -1,73 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import copy
|
||||
|
||||
from .rec_metric import RecMetric
|
||||
from .det_metric import DetMetric
|
||||
from .e2e_metric import E2EMetric
|
||||
from .cls_metric import ClsMetric
|
||||
|
||||
|
||||
class DistillationMetric(object):
|
||||
def __init__(self,
|
||||
key=None,
|
||||
base_metric_name=None,
|
||||
main_indicator=None,
|
||||
**kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.key = key
|
||||
self.main_indicator = main_indicator
|
||||
self.base_metric_name = base_metric_name
|
||||
self.kwargs = kwargs
|
||||
self.metrics = None
|
||||
|
||||
def _init_metrcis(self, preds):
|
||||
self.metrics = dict()
|
||||
mod = importlib.import_module(__name__)
|
||||
for key in preds:
|
||||
self.metrics[key] = getattr(mod, self.base_metric_name)(
|
||||
main_indicator=self.main_indicator, **self.kwargs)
|
||||
self.metrics[key].reset()
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
assert isinstance(preds, dict)
|
||||
if self.metrics is None:
|
||||
self._init_metrcis(preds)
|
||||
output = dict()
|
||||
for key in preds:
|
||||
self.metrics[key].__call__(preds[key], batch, **kwargs)
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {
|
||||
'acc': 0,
|
||||
'norm_edit_dis': 0,
|
||||
}
|
||||
"""
|
||||
output = dict()
|
||||
for key in self.metrics:
|
||||
metric = self.metrics[key].get_metric()
|
||||
# main indicator
|
||||
if key == self.key:
|
||||
output.update(metric)
|
||||
else:
|
||||
for sub_key in metric:
|
||||
output["{}_{}".format(key, sub_key)] = metric[sub_key]
|
||||
return output
|
||||
|
||||
def reset(self):
|
||||
for key in self.metrics:
|
||||
self.metrics[key].reset()
|
||||
@@ -1,86 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
__all__ = ['E2EMetric']
|
||||
|
||||
from ppocr.utils.e2e_metric.Deteval import get_socre_A, get_socre_B, combine_results
|
||||
from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
|
||||
|
||||
|
||||
class E2EMetric(object):
|
||||
def __init__(self,
|
||||
mode,
|
||||
gt_mat_dir,
|
||||
character_dict_path,
|
||||
main_indicator='f_score_e2e',
|
||||
**kwargs):
|
||||
self.mode = mode
|
||||
self.gt_mat_dir = gt_mat_dir
|
||||
self.label_list = get_dict(character_dict_path)
|
||||
self.max_index = len(self.label_list)
|
||||
self.main_indicator = main_indicator
|
||||
self.reset()
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
if self.mode == 'A':
|
||||
gt_polyons_batch = batch[2]
|
||||
temp_gt_strs_batch = batch[3][0]
|
||||
ignore_tags_batch = batch[4]
|
||||
gt_strs_batch = []
|
||||
|
||||
for temp_list in temp_gt_strs_batch:
|
||||
t = ""
|
||||
for index in temp_list:
|
||||
if index < self.max_index:
|
||||
t += self.label_list[index]
|
||||
gt_strs_batch.append(t)
|
||||
|
||||
for pred, gt_polyons, gt_strs, ignore_tags in zip(
|
||||
[preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
|
||||
# prepare gt
|
||||
gt_info_list = [{
|
||||
'points': gt_polyon,
|
||||
'text': gt_str,
|
||||
'ignore': ignore_tag
|
||||
} for gt_polyon, gt_str, ignore_tag in
|
||||
zip(gt_polyons, gt_strs, ignore_tags)]
|
||||
# prepare det
|
||||
e2e_info_list = [{
|
||||
'points': det_polyon,
|
||||
'texts': pred_str
|
||||
} for det_polyon, pred_str in
|
||||
zip(pred['points'], pred['texts'])]
|
||||
|
||||
result = get_socre_A(gt_info_list, e2e_info_list)
|
||||
self.results.append(result)
|
||||
else:
|
||||
img_id = batch[5][0]
|
||||
e2e_info_list = [{
|
||||
'points': det_polyon,
|
||||
'texts': pred_str
|
||||
} for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
|
||||
result = get_socre_B(self.gt_mat_dir, img_id, e2e_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
def get_metric(self):
|
||||
metrics = combine_results(self.results)
|
||||
self.reset()
|
||||
return metrics
|
||||
|
||||
def reset(self):
|
||||
self.results = [] # clear results
|
||||
@@ -1,225 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from collections import namedtuple
|
||||
import numpy as np
|
||||
from shapely.geometry import Polygon
|
||||
"""
|
||||
reference from :
|
||||
https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
|
||||
"""
|
||||
|
||||
|
||||
class DetectionIoUEvaluator(object):
|
||||
def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
|
||||
self.iou_constraint = iou_constraint
|
||||
self.area_precision_constraint = area_precision_constraint
|
||||
|
||||
def evaluate_image(self, gt, pred):
|
||||
def get_union(pD, pG):
|
||||
return Polygon(pD).union(Polygon(pG)).area
|
||||
|
||||
def get_intersection_over_union(pD, pG):
|
||||
return get_intersection(pD, pG) / get_union(pD, pG)
|
||||
|
||||
def get_intersection(pD, pG):
|
||||
return Polygon(pD).intersection(Polygon(pG)).area
|
||||
|
||||
def compute_ap(confList, matchList, numGtCare):
|
||||
correct = 0
|
||||
AP = 0
|
||||
if len(confList) > 0:
|
||||
confList = np.array(confList)
|
||||
matchList = np.array(matchList)
|
||||
sorted_ind = np.argsort(-confList)
|
||||
confList = confList[sorted_ind]
|
||||
matchList = matchList[sorted_ind]
|
||||
for n in range(len(confList)):
|
||||
match = matchList[n]
|
||||
if match:
|
||||
correct += 1
|
||||
AP += float(correct) / (n + 1)
|
||||
|
||||
if numGtCare > 0:
|
||||
AP /= numGtCare
|
||||
|
||||
return AP
|
||||
|
||||
perSampleMetrics = {}
|
||||
|
||||
matchedSum = 0
|
||||
|
||||
Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
|
||||
|
||||
numGlobalCareGt = 0
|
||||
numGlobalCareDet = 0
|
||||
|
||||
arrGlobalConfidences = []
|
||||
arrGlobalMatches = []
|
||||
|
||||
recall = 0
|
||||
precision = 0
|
||||
hmean = 0
|
||||
|
||||
detMatched = 0
|
||||
|
||||
iouMat = np.empty([1, 1])
|
||||
|
||||
gtPols = []
|
||||
detPols = []
|
||||
|
||||
gtPolPoints = []
|
||||
detPolPoints = []
|
||||
|
||||
# Array of Ground Truth Polygons' keys marked as don't Care
|
||||
gtDontCarePolsNum = []
|
||||
# Array of Detected Polygons' matched with a don't Care GT
|
||||
detDontCarePolsNum = []
|
||||
|
||||
pairs = []
|
||||
detMatchedNums = []
|
||||
|
||||
arrSampleConfidences = []
|
||||
arrSampleMatch = []
|
||||
|
||||
evaluationLog = ""
|
||||
|
||||
# print(len(gt))
|
||||
for n in range(len(gt)):
|
||||
points = gt[n]['points']
|
||||
# transcription = gt[n]['text']
|
||||
dontCare = gt[n]['ignore']
|
||||
# points = Polygon(points)
|
||||
# points = points.buffer(0)
|
||||
if not Polygon(points).is_valid or not Polygon(points).is_simple:
|
||||
continue
|
||||
|
||||
gtPol = points
|
||||
gtPols.append(gtPol)
|
||||
gtPolPoints.append(points)
|
||||
if dontCare:
|
||||
gtDontCarePolsNum.append(len(gtPols) - 1)
|
||||
|
||||
evaluationLog += "GT polygons: " + str(len(gtPols)) + (
|
||||
" (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
|
||||
if len(gtDontCarePolsNum) > 0 else "\n")
|
||||
|
||||
for n in range(len(pred)):
|
||||
points = pred[n]['points']
|
||||
# points = Polygon(points)
|
||||
# points = points.buffer(0)
|
||||
if not Polygon(points).is_valid or not Polygon(points).is_simple:
|
||||
continue
|
||||
|
||||
detPol = points
|
||||
detPols.append(detPol)
|
||||
detPolPoints.append(points)
|
||||
if len(gtDontCarePolsNum) > 0:
|
||||
for dontCarePol in gtDontCarePolsNum:
|
||||
dontCarePol = gtPols[dontCarePol]
|
||||
intersected_area = get_intersection(dontCarePol, detPol)
|
||||
pdDimensions = Polygon(detPol).area
|
||||
precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
|
||||
if (precision > self.area_precision_constraint):
|
||||
detDontCarePolsNum.append(len(detPols) - 1)
|
||||
break
|
||||
|
||||
evaluationLog += "DET polygons: " + str(len(detPols)) + (
|
||||
" (" + str(len(detDontCarePolsNum)) + " don't care)\n"
|
||||
if len(detDontCarePolsNum) > 0 else "\n")
|
||||
|
||||
if len(gtPols) > 0 and len(detPols) > 0:
|
||||
# Calculate IoU and precision matrixs
|
||||
outputShape = [len(gtPols), len(detPols)]
|
||||
iouMat = np.empty(outputShape)
|
||||
gtRectMat = np.zeros(len(gtPols), np.int8)
|
||||
detRectMat = np.zeros(len(detPols), np.int8)
|
||||
for gtNum in range(len(gtPols)):
|
||||
for detNum in range(len(detPols)):
|
||||
pG = gtPols[gtNum]
|
||||
pD = detPols[detNum]
|
||||
iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
|
||||
|
||||
for gtNum in range(len(gtPols)):
|
||||
for detNum in range(len(detPols)):
|
||||
if gtRectMat[gtNum] == 0 and detRectMat[
|
||||
detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
|
||||
if iouMat[gtNum, detNum] > self.iou_constraint:
|
||||
gtRectMat[gtNum] = 1
|
||||
detRectMat[detNum] = 1
|
||||
detMatched += 1
|
||||
pairs.append({'gt': gtNum, 'det': detNum})
|
||||
detMatchedNums.append(detNum)
|
||||
evaluationLog += "Match GT #" + \
|
||||
str(gtNum) + " with Det #" + str(detNum) + "\n"
|
||||
|
||||
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
|
||||
numDetCare = (len(detPols) - len(detDontCarePolsNum))
|
||||
if numGtCare == 0:
|
||||
recall = float(1)
|
||||
precision = float(0) if numDetCare > 0 else float(1)
|
||||
else:
|
||||
recall = float(detMatched) / numGtCare
|
||||
precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
|
||||
|
||||
hmean = 0 if (precision + recall) == 0 else 2.0 * \
|
||||
precision * recall / (precision + recall)
|
||||
|
||||
matchedSum += detMatched
|
||||
numGlobalCareGt += numGtCare
|
||||
numGlobalCareDet += numDetCare
|
||||
|
||||
perSampleMetrics = {
|
||||
'gtCare': numGtCare,
|
||||
'detCare': numDetCare,
|
||||
'detMatched': detMatched,
|
||||
}
|
||||
return perSampleMetrics
|
||||
|
||||
def combine_results(self, results):
|
||||
numGlobalCareGt = 0
|
||||
numGlobalCareDet = 0
|
||||
matchedSum = 0
|
||||
for result in results:
|
||||
numGlobalCareGt += result['gtCare']
|
||||
numGlobalCareDet += result['detCare']
|
||||
matchedSum += result['detMatched']
|
||||
|
||||
methodRecall = 0 if numGlobalCareGt == 0 else float(
|
||||
matchedSum) / numGlobalCareGt
|
||||
methodPrecision = 0 if numGlobalCareDet == 0 else float(
|
||||
matchedSum) / numGlobalCareDet
|
||||
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
|
||||
methodRecall * methodPrecision / (
|
||||
methodRecall + methodPrecision)
|
||||
# print(methodRecall, methodPrecision, methodHmean)
|
||||
# sys.exit(-1)
|
||||
methodMetrics = {
|
||||
'precision': methodPrecision,
|
||||
'recall': methodRecall,
|
||||
'hmean': methodHmean
|
||||
}
|
||||
|
||||
return methodMetrics
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
evaluator = DetectionIoUEvaluator()
|
||||
gts = [[{
|
||||
'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
|
||||
'text': 1234,
|
||||
'ignore': False,
|
||||
}, {
|
||||
'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
|
||||
'text': 5678,
|
||||
'ignore': False,
|
||||
}]]
|
||||
preds = [[{
|
||||
'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
|
||||
'text': 123,
|
||||
'ignore': False,
|
||||
}]]
|
||||
results = []
|
||||
for gt, pred in zip(gts, preds):
|
||||
results.append(evaluator.evaluate_image(gt, pred))
|
||||
metrics = evaluator.combine_results(results)
|
||||
print(metrics)
|
||||
@@ -1,71 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# The code is refer from: https://github.com/open-mmlab/mmocr/blob/main/mmocr/core/evaluation/kie_metric.py
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
__all__ = ['KIEMetric']
|
||||
|
||||
|
||||
class KIEMetric(object):
|
||||
def __init__(self, main_indicator='hmean', **kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.reset()
|
||||
self.node = []
|
||||
self.gt = []
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
nodes, _ = preds
|
||||
gts, tag = batch[4].squeeze(0), batch[5].tolist()[0]
|
||||
gts = gts[:tag[0], :1].reshape([-1])
|
||||
self.node.append(nodes.numpy())
|
||||
self.gt.append(gts)
|
||||
# result = self.compute_f1_score(nodes, gts)
|
||||
# self.results.append(result)
|
||||
|
||||
def compute_f1_score(self, preds, gts):
|
||||
ignores = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25]
|
||||
C = preds.shape[1]
|
||||
classes = np.array(sorted(set(range(C)) - set(ignores)))
|
||||
hist = np.bincount(
|
||||
(gts * C).astype('int64') + preds.argmax(1), minlength=C
|
||||
**2).reshape([C, C]).astype('float32')
|
||||
diag = np.diag(hist)
|
||||
recalls = diag / hist.sum(1).clip(min=1)
|
||||
precisions = diag / hist.sum(0).clip(min=1)
|
||||
f1 = 2 * recalls * precisions / (recalls + precisions).clip(min=1e-8)
|
||||
return f1[classes]
|
||||
|
||||
def combine_results(self, results):
|
||||
node = np.concatenate(self.node, 0)
|
||||
gts = np.concatenate(self.gt, 0)
|
||||
results = self.compute_f1_score(node, gts)
|
||||
data = {'hmean': results.mean()}
|
||||
return data
|
||||
|
||||
def get_metric(self):
|
||||
|
||||
metrics = self.combine_results(self.results)
|
||||
self.reset()
|
||||
return metrics
|
||||
|
||||
def reset(self):
|
||||
self.results = [] # clear results
|
||||
self.node = []
|
||||
self.gt = []
|
||||
@@ -1,76 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import Levenshtein
|
||||
import string
|
||||
|
||||
|
||||
class RecMetric(object):
    """Accumulates text-recognition accuracy and normalized edit distance.

    Counts are summed across calls until get_metric() is invoked, which
    also resets the internal counters.
    """

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 ignore_space=True,
                 **kwargs):
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.ignore_space = ignore_space
        # Guards against division by zero when no samples were seen.
        self.eps = 1e-5
        self.reset()

    def _normalize_text(self, text):
        """Keep only ASCII letters/digits and lower-case the result."""
        allowed = string.digits + string.ascii_letters
        return ''.join(ch for ch in text if ch in allowed).lower()

    def __call__(self, pred_label, *args, **kwargs):
        """Update running counters with one batch of (preds, labels).

        Each element of preds is a (text, confidence) pair; each element
        of labels is a (text, _) pair.  Returns the batch-local metrics.
        """
        preds, labels = pred_label
        batch_correct = 0
        batch_total = 0
        batch_edit_dis = 0.0
        for (pred, _conf), (target, _) in zip(preds, labels):
            if self.ignore_space:
                pred = pred.replace(" ", "")
                target = target.replace(" ", "")
            if self.is_filter:
                pred = self._normalize_text(pred)
                target = self._normalize_text(target)
            # Normalize the edit distance by the longer string (at least 1).
            denom = max(len(pred), len(target), 1)
            batch_edit_dis += Levenshtein.distance(pred, target) / denom
            if pred == target:
                batch_correct += 1
            batch_total += 1
        self.correct_num += batch_correct
        self.all_num += batch_total
        self.norm_edit_dis += batch_edit_dis
        return {
            'acc': batch_correct / (batch_total + self.eps),
            'norm_edit_dis': 1 - batch_edit_dis / (batch_total + self.eps)
        }

    def get_metric(self):
        """Return accumulated {'acc', 'norm_edit_dis'} and reset counters."""
        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
        norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
        self.reset()
        return {'acc': acc, 'norm_edit_dis': norm_edit_dis}

    def reset(self):
        """Zero all running counters."""
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0
|
||||
@@ -1,51 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TableMetric(object):
    """Exact-match accuracy for table structure prediction.

    A sample counts as correct only when every predicted structure token
    equals the label sequence.
    """

    def __init__(self, main_indicator='acc', **kwargs):
        self.main_indicator = main_indicator
        # Avoids division by zero when no samples have been accumulated.
        self.eps = 1e-5
        self.reset()

    def __call__(self, pred, batch, *args, **kwargs):
        """Update counters with one batch; return the batch accuracy."""
        structure_probs = pred['structure_probs'].numpy()
        structure_labels = batch[1]
        # Greedy decode: most likely token at every position.
        structure_ids = np.argmax(structure_probs, axis=2)
        # Drop the leading (start) token from the labels.
        structure_labels = structure_labels[:, 1:]
        batch_correct = 0
        batch_total = 0
        for sample_idx in range(structure_ids.shape[0]):
            batch_total += 1
            if (structure_ids[sample_idx] == structure_labels[sample_idx]).all():
                batch_correct += 1
        self.correct_num += batch_correct
        self.all_num += batch_total
        return {'acc': batch_correct * 1.0 / (batch_total + self.eps), }

    def get_metric(self):
        """Return accumulated {'acc'} and reset the counters."""
        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
        self.reset()
        return {'acc': acc}

    def reset(self):
        """Zero the running counters."""
        self.correct_num = 0
        self.all_num = 0
|
||||
@@ -1,176 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
__all__ = ['KIEMetric']
|
||||
|
||||
|
||||
class VQAReTokenMetric(object):
    """Precision/recall/F1 for relation extraction (RE) in VQA tasks.

    Predictions and ground truth are buffered across calls; get_metric()
    scores everything at once and then clears the buffers.
    """

    def __init__(self, main_indicator='hmean', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        """Buffer one batch of (pred_relations, relations, entities)."""
        pred_relations, relations, entities = preds
        self.pred_relations_list.extend(pred_relations)
        self.relations_list.extend(relations)
        self.entities_list.extend(entities)

    def get_metric(self):
        """Build gt relation dicts, score them, reset, return metrics."""
        gt_relations = []
        for idx in range(len(self.relations_list)):
            sent_rels = []
            heads = self.relations_list[idx]["head"]
            tails = self.relations_list[idx]["tail"]
            entities = self.entities_list[idx]
            for head, tail in zip(heads, tails):
                # Expand entity ids into (start, end) spans plus labels.
                rel = {
                    "head_id": head,
                    "head": (entities["start"][head], entities["end"][head]),
                    "head_type": entities["label"][head],
                    "tail_id": tail,
                    "tail": (entities["start"][tail], entities["end"][tail]),
                    "tail_type": entities["label"][tail],
                    "type": 1,
                }
                sent_rels.append(rel)
            gt_relations.append(sent_rels)
        re_metrics = self.re_score(
            self.pred_relations_list, gt_relations, mode="boundaries")
        metrics = {
            "precision": re_metrics["ALL"]["p"],
            "recall": re_metrics["ALL"]["r"],
            "hmean": re_metrics["ALL"]["f1"],
        }
        self.reset()
        return metrics

    def reset(self):
        """Clear all buffered predictions and ground truth."""
        self.pred_relations_list = []
        self.relations_list = []
        self.entities_list = []

    def re_score(self, pred_relations, gt_relations, mode="strict"):
        """Evaluate RE predictions.

        Args:
            pred_relations (list): per-sentence lists of predicted relation dicts
            gt_relations (list): per-sentence lists of ground-truth relation dicts

                rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
                        "tail": (start_idx (inclusive), end_idx (exclusive)),
                        "head_type": ent_type,
                        "tail_type": ent_type,
                        "type": rel_type}

            mode (str): "strict" compares argument types too;
                "boundaries" compares only the argument spans.

        Returns:
            dict keyed by relation type plus "ALL" with tp/fp/fn/p/r/f1
            and micro/macro aggregates under "ALL".
        """
        assert mode in ["strict", "boundaries"]

        # Relation type 0 means "no relation"; only type 1 is scored.
        relation_types = [v for v in [0, 1] if v != 0]
        scores = {
            rel: {"tp": 0, "fp": 0, "fn": 0}
            for rel in relation_types + ["ALL"]
        }

        # Count TP, FP and FN per type.
        for pred_sent, gt_sent in zip(pred_relations, gt_relations):
            for rel_type in relation_types:
                if mode == "strict":
                    # Strict mode takes argument types into account.
                    pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                  rel["tail_type"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                rel["tail_type"])
                               for rel in gt_sent if rel["type"] == rel_type}
                elif mode == "boundaries":
                    # Boundaries mode only takes argument spans into account.
                    pred_rels = {(rel["head"], rel["tail"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["tail"])
                               for rel in gt_sent if rel["type"] == rel_type}

                scores[rel_type]["tp"] += len(pred_rels & gt_rels)
                scores[rel_type]["fp"] += len(pred_rels - gt_rels)
                scores[rel_type]["fn"] += len(gt_rels - pred_rels)

        # Per-type precision / recall / F1 (including the "ALL" entry,
        # which is overwritten with micro averages below).
        for rel_type in scores.keys():
            tp_t = scores[rel_type]["tp"]
            if tp_t:
                scores[rel_type]["p"] = tp_t / (scores[rel_type]["fp"] + tp_t)
                scores[rel_type]["r"] = tp_t / (scores[rel_type]["fn"] + tp_t)
            else:
                scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0

            if scores[rel_type]["p"] + scores[rel_type]["r"] != 0:
                scores[rel_type]["f1"] = (
                    2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
                    (scores[rel_type]["p"] + scores[rel_type]["r"]))
            else:
                scores[rel_type]["f1"] = 0

        # Micro-averaged scores over all relation types.
        tp = sum(scores[rel_type]["tp"] for rel_type in relation_types)
        fp = sum(scores[rel_type]["fp"] for rel_type in relation_types)
        fn = sum(scores[rel_type]["fn"] for rel_type in relation_types)

        if tp:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            f1 = 2 * precision * recall / (precision + recall)
        else:
            precision, recall, f1 = 0, 0, 0

        scores["ALL"]["p"] = precision
        scores["ALL"]["r"] = recall
        scores["ALL"]["f1"] = f1
        scores["ALL"]["tp"] = tp
        scores["ALL"]["fp"] = fp
        scores["ALL"]["fn"] = fn

        # Macro-averaged scores over relation types.
        scores["ALL"]["Macro_f1"] = np.mean(
            [scores[ent_type]["f1"] for ent_type in relation_types])
        scores["ALL"]["Macro_p"] = np.mean(
            [scores[ent_type]["p"] for ent_type in relation_types])
        scores["ALL"]["Macro_r"] = np.mean(
            [scores[ent_type]["r"] for ent_type in relation_types])

        return scores
|
||||
@@ -1,47 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
__all__ = ['KIEMetric']
|
||||
|
||||
|
||||
class VQASerTokenMetric(object):
    """Entity-level precision/recall/F1 for SER (token labeling) tasks.

    Predictions and labels are buffered across calls; get_metric() scores
    them with seqeval and clears the buffers.
    """

    def __init__(self, main_indicator='hmean', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        """Buffer one batch of (predicted tag sequences, label sequences)."""
        pred_tags, gt_tags = preds
        self.pred_list.extend(pred_tags)
        self.gt_list.extend(gt_tags)

    def get_metric(self):
        """Score buffered sequences with seqeval, reset, return metrics."""
        # Imported lazily so this module loads without seqeval installed.
        from seqeval.metrics import f1_score, precision_score, recall_score
        metrics = {
            "precision": precision_score(self.gt_list, self.pred_list),
            "recall": recall_score(self.gt_list, self.pred_list),
            "hmean": f1_score(self.gt_list, self.pred_list),
        }
        self.reset()
        return metrics

    def reset(self):
        """Clear buffered predictions and labels."""
        self.pred_list = []
        self.gt_list = []
|
||||
@@ -1,32 +0,0 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import importlib
|
||||
|
||||
from .base_model import BaseModel
|
||||
from .distillation_model import DistillationModel
|
||||
|
||||
__all__ = ['build_model']
|
||||
|
||||
|
||||
def build_model(config):
    """Build an OCR architecture from its config dict.

    If the config carries a "name" key, the named architecture class from
    this module is instantiated with the remaining config; otherwise a
    plain BaseModel is built.  The config is deep-copied first so the
    caller's dict is never mutated.
    """
    config = copy.deepcopy(config)
    if "name" not in config:
        return BaseModel(config)
    arch_name = config.pop("name")
    this_module = importlib.import_module(__name__)
    return getattr(this_module, arch_name)(config)
|
||||
@@ -1,100 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from paddle import nn
|
||||
from ppocr.modeling.transforms import build_transform
|
||||
from ppocr.modeling.backbones import build_backbone
|
||||
from ppocr.modeling.necks import build_neck
|
||||
from ppocr.modeling.heads import build_head
|
||||
|
||||
__all__ = ['BaseModel']
|
||||
|
||||
|
||||
class BaseModel(nn.Layer):
    """Generic OCR model: optional transform -> backbone -> optional neck
    -> optional head.

    Each stage is instantiated from its section of the config dict, and
    the output channel count of one stage feeds the next.
    """

    def __init__(self, config):
        """
        Args:
            config (dict): the super parameters for module.
        """
        super(BaseModel, self).__init__()
        in_channels = config.get('in_channels', 3)
        model_type = config['model_type']

        # Transform stage (e.g. TPS for rec); det/cls normally use None.
        transform_cfg = config.get('Transform')
        if transform_cfg is None:
            self.use_transform = False
        else:
            self.use_transform = True
            transform_cfg['in_channels'] = in_channels
            self.transform = build_transform(transform_cfg)
            in_channels = self.transform.out_channels

        # Backbone is required for det, rec and cls.
        config["Backbone"]['in_channels'] = in_channels
        self.backbone = build_backbone(config["Backbone"], model_type)
        in_channels = self.backbone.out_channels

        # Neck (rec: cnn/rnn/reshape; det: FPN-like; cls: none).
        neck_cfg = config.get('Neck')
        if neck_cfg is None:
            self.use_neck = False
        else:
            self.use_neck = True
            neck_cfg['in_channels'] = in_channels
            self.neck = build_neck(neck_cfg)
            in_channels = self.neck.out_channels

        # Head is required for det, rec and cls.
        head_cfg = config.get('Head')
        if head_cfg is None:
            self.use_head = False
        else:
            self.use_head = True
            head_cfg['in_channels'] = in_channels
            self.head = build_head(head_cfg)

        self.return_all_feats = config.get("return_all_feats", False)

    def forward(self, x, data=None):
        """Run the pipeline; optionally return intermediate features."""
        feats = dict()
        if self.use_transform:
            x = self.transform(x)
        x = self.backbone(x)
        feats["backbone_out"] = x
        if self.use_neck:
            x = self.neck(x)
            feats["neck_out"] = x
        if self.use_head:
            x = self.head(x, targets=data)
            if isinstance(x, dict) and 'ctc_neck' in x.keys():
                # For multi-head models keep the CTC neck output for UDML.
                feats["neck_out"] = x["ctc_neck"]
                feats["head_out"] = x
            elif isinstance(x, dict):
                feats.update(x)
            else:
                feats["head_out"] = x
        if self.return_all_feats:
            # In eval mode only the head output is returned.
            if self.training:
                return feats
            return {"head_out": feats["head_out"]}
        return x
|
||||
@@ -1,60 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
from ppocr.modeling.transforms import build_transform
|
||||
from ppocr.modeling.backbones import build_backbone
|
||||
from ppocr.modeling.necks import build_neck
|
||||
from ppocr.modeling.heads import build_head
|
||||
from .base_model import BaseModel
|
||||
from ppocr.utils.save_load import load_pretrained_params
|
||||
|
||||
__all__ = ['DistillationModel']
|
||||
|
||||
|
||||
class DistillationModel(nn.Layer):
    """Container of several BaseModel sub-models used for distillation.

    Each sub-model is built from config["Models"]; it may load pretrained
    weights and may be frozen (its parameters made untrainable).
    """

    def __init__(self, config):
        """
        Args:
            config (dict): the super parameters for module.
        """
        super().__init__()
        self.model_list = []
        self.model_name_list = []
        for model_name in config["Models"]:
            sub_cfg = config["Models"][model_name]
            freeze_params = False
            pretrained = None
            if "freeze_params" in sub_cfg:
                freeze_params = sub_cfg.pop("freeze_params")
            if "pretrained" in sub_cfg:
                pretrained = sub_cfg.pop("pretrained")
            sub_model = BaseModel(sub_cfg)
            if pretrained is not None:
                load_pretrained_params(sub_model, pretrained)
            if freeze_params:
                # A frozen sub-model acts as a fixed teacher.
                for param in sub_model.parameters():
                    param.trainable = False
            self.model_list.append(self.add_sublayer(model_name, sub_model))
            self.model_name_list.append(model_name)

    def forward(self, x, data=None):
        """Run every sub-model on x; return {model_name: output}."""
        outputs = dict()
        for idx, model_name in enumerate(self.model_name_list):
            outputs[model_name] = self.model_list[idx](x, data)
        return outputs
|
||||
@@ -1,64 +0,0 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ["build_backbone"]
|
||||
|
||||
|
||||
def build_backbone(config, model_type):
    """Instantiate a backbone network by name for the given model type.

    Args:
        config (dict): backbone config; its "name" key selects the class
            and the remaining keys are passed as constructor kwargs.
        model_type (str): one of "det", "table", "rec", "cls", "e2e",
            "kie" or "vqa".

    Returns:
        The constructed backbone module.

    Raises:
        NotImplementedError: if model_type is unknown.
        AssertionError: if the named backbone is unsupported for the type.
    """
    # NOTE: "table" is handled together with "det"; a later duplicate
    # `elif model_type == "table"` branch was unreachable and was removed.
    if model_type == "det" or model_type == "table":
        from .det_mobilenet_v3 import MobileNetV3
        from .det_resnet_vd import ResNet
        from .det_resnet_vd_sast import ResNet_SAST
        support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"]
    elif model_type == "rec" or model_type == "cls":
        from .rec_mobilenet_v3 import MobileNetV3
        from .rec_resnet_vd import ResNet
        from .rec_resnet_fpn import ResNetFPN
        from .rec_mv1_enhance import MobileNetV1Enhance
        from .rec_nrtr_mtb import MTB
        from .rec_resnet_31 import ResNet31
        from .rec_resnet_aster import ResNet_ASTER
        from .rec_micronet import MicroNet
        from .rec_efficientb3_pren import EfficientNetb3_PREN
        from .rec_svtrnet import SVTRNet
        support_dict = [
            'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
            "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
            'SVTRNet'
        ]
    elif model_type == "e2e":
        from .e2e_resnet_vd_pg import ResNet
        support_dict = ['ResNet']
    elif model_type == 'kie':
        from .kie_unet_sdmgr import Kie_backbone
        support_dict = ['Kie_backbone']
    elif model_type == 'vqa':
        from .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForRe
        support_dict = [
            "LayoutLMForSer", "LayoutLMv2ForSer", 'LayoutLMv2ForRe',
            "LayoutXLMForSer", 'LayoutXLMForRe'
        ]
    else:
        raise NotImplementedError

    module_name = config.pop("name")
    assert module_name in support_dict, Exception(
        "when model type is {}, backbone only support {}".format(model_type,
                                                                 support_dict))
    # eval() resolves one of the classes imported locally above; the name
    # has already been validated against support_dict.
    module_class = eval(module_name)(**config)
    return module_class
|
||||
@@ -1,268 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
__all__ = ['MobileNetV3']
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None):
    """Round v to the nearest multiple of divisor.

    The result never drops below min_value (defaults to divisor) and
    never below 90% of the original value — the channel-rounding rule
    used by MobileNet-family models.
    """
    if min_value is None:
        min_value = divisor
    rounded = int(v + divisor / 2) // divisor * divisor
    result = max(min_value, rounded)
    # Rounding down must not shrink the value by more than 10%.
    if result < 0.9 * v:
        result += divisor
    return result
|
||||
|
||||
|
||||
class MobileNetV3(nn.Layer):
    """MobileNetV3 backbone for text detection.

    Emits the feature maps of four stages so an FPN-style neck can
    consume them; `out_channels` lists the channel count of each stage.
    """

    def __init__(self,
                 in_channels=3,
                 model_name='large',
                 scale=0.5,
                 disable_se=False,
                 **kwargs):
        """
        the MobilenetV3 backbone network for detection module.
        Args:
            params(dict): the super parameters for build network
        """
        super(MobileNetV3, self).__init__()

        self.disable_se = disable_se

        if model_name == "large":
            cfg = [
                # k, exp, c, se, nl, s,
                [3, 16, 16, False, 'relu', 1],
                [3, 64, 24, False, 'relu', 2],
                [3, 72, 24, False, 'relu', 1],
                [5, 72, 40, True, 'relu', 2],
                [5, 120, 40, True, 'relu', 1],
                [5, 120, 40, True, 'relu', 1],
                [3, 240, 80, False, 'hardswish', 2],
                [3, 200, 80, False, 'hardswish', 1],
                [3, 184, 80, False, 'hardswish', 1],
                [3, 184, 80, False, 'hardswish', 1],
                [3, 480, 112, True, 'hardswish', 1],
                [3, 672, 112, True, 'hardswish', 1],
                [5, 672, 160, True, 'hardswish', 2],
                [5, 960, 160, True, 'hardswish', 1],
                [5, 960, 160, True, 'hardswish', 1],
            ]
            cls_ch_squeeze = 960
        elif model_name == "small":
            cfg = [
                # k, exp, c, se, nl, s,
                [3, 16, 16, True, 'relu', 2],
                [3, 72, 24, False, 'relu', 2],
                [3, 88, 24, False, 'relu', 1],
                [5, 96, 40, True, 'hardswish', 2],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 120, 48, True, 'hardswish', 1],
                [5, 144, 48, True, 'hardswish', 1],
                [5, 288, 96, True, 'hardswish', 2],
                [5, 576, 96, True, 'hardswish', 1],
                [5, 576, 96, True, 'hardswish', 1],
            ]
            cls_ch_squeeze = 576
        else:
            raise NotImplementedError("mode[" + model_name +
                                      "_model] is not implemented!")

        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
        assert scale in supported_scale, \
            "supported scale are {} but input scale is {}".format(supported_scale, scale)
        inplanes = 16
        # Stem convolution halves the spatial resolution.
        self.conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=make_divisible(inplanes * scale),
            kernel_size=3,
            stride=2,
            padding=1,
            groups=1,
            if_act=True,
            act='hardswish')

        self.stages = []
        self.out_channels = []
        block_list = []
        inplanes = make_divisible(inplanes * scale)
        # A new stage starts at every stride-2 unit past start_idx.
        start_idx = 2 if model_name == 'large' else 0
        for unit_idx, (k, exp, c, se, nl, s) in enumerate(cfg):
            use_se = se and not self.disable_se
            if s == 2 and unit_idx > start_idx:
                self.out_channels.append(inplanes)
                self.stages.append(nn.Sequential(*block_list))
                block_list = []
            block_list.append(
                ResidualUnit(
                    in_channels=inplanes,
                    mid_channels=make_divisible(scale * exp),
                    out_channels=make_divisible(scale * c),
                    kernel_size=k,
                    stride=s,
                    use_se=use_se,
                    act=nl))
            inplanes = make_divisible(scale * c)
        # Final 1x1 conv closes the last stage.
        block_list.append(
            ConvBNLayer(
                in_channels=inplanes,
                out_channels=make_divisible(scale * cls_ch_squeeze),
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                if_act=True,
                act='hardswish'))
        self.stages.append(nn.Sequential(*block_list))
        self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
        for stage_idx, stage in enumerate(self.stages):
            self.add_sublayer(sublayer=stage, name="stage{}".format(stage_idx))

    def forward(self, x):
        """Return the list of per-stage feature maps."""
        x = self.conv(x)
        out_list = []
        for stage in self.stages:
            x = stage(x)
            out_list.append(x)
        return out_list
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
    """Conv2D + BatchNorm with an optional activation (relu/hardswish)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups=1,
                 if_act=True,
                 act=None):
        super(ConvBNLayer, self).__init__()
        self.if_act = if_act
        self.act = act
        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias_attr=False)
        # Activation is applied manually in forward, not fused into BN.
        self.bn = nn.BatchNorm(num_channels=out_channels, act=None)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if not self.if_act:
            return x
        if self.act == "relu":
            return F.relu(x)
        if self.act == "hardswish":
            return F.hardswish(x)
        # Unknown activation name: report and abort, as upstream does.
        print("The activation function({}) is selected incorrectly.".
              format(self.act))
        exit()
|
||||
|
||||
|
||||
class ResidualUnit(nn.Layer):
    """MobileNetV3 inverted-residual block.

    Expand (1x1) -> depthwise conv -> optional SE -> project (1x1),
    with an identity shortcut when the shape is preserved.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 use_se,
                 act=None):
        super(ResidualUnit, self).__init__()
        # Shortcut is only valid when spatial size and width are kept.
        self.if_shortcut = stride == 1 and in_channels == out_channels
        self.if_se = use_se

        self.expand_conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            if_act=True,
            act=act)
        self.bottleneck_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=int((kernel_size - 1) // 2),
            groups=mid_channels,
            if_act=True,
            act=act)
        if self.if_se:
            self.mid_se = SEModule(mid_channels)
        self.linear_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            if_act=False,
            act=None)

    def forward(self, inputs):
        out = self.expand_conv(inputs)
        out = self.bottleneck_conv(out)
        if self.if_se:
            out = self.mid_se(out)
        out = self.linear_conv(out)
        if self.if_shortcut:
            out = paddle.add(inputs, out)
        return out
|
||||
|
||||
|
||||
class SEModule(nn.Layer):
    """Squeeze-and-excitation gate.

    Global average pool -> 1x1 reduce -> relu -> 1x1 expand ->
    hardsigmoid; the result scales the input channel-wise.
    """

    def __init__(self, in_channels, reduction=4):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        squeezed = in_channels // reduction
        self.conv1 = nn.Conv2D(
            in_channels=in_channels,
            out_channels=squeezed,
            kernel_size=1,
            stride=1,
            padding=0)
        self.conv2 = nn.Conv2D(
            in_channels=squeezed,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            padding=0)

    def forward(self, inputs):
        gate = self.avg_pool(inputs)
        gate = F.relu(self.conv1(gate))
        gate = self.conv2(gate)
        # hardsigmoid maps the gate into [0, 1].
        gate = F.hardsigmoid(gate, slope=0.2, offset=0.5)
        return inputs * gate
|
||||
@@ -1,351 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from paddle.vision.ops import DeformConv2D
|
||||
from paddle.regularizer import L2Decay
|
||||
from paddle.nn.initializer import Normal, Constant, XavierUniform
|
||||
|
||||
__all__ = ["ResNet"]
|
||||
|
||||
|
||||
class DeformableConvV2(nn.Layer):
    """Deformable convolution v2 with a learned offset/mask branch.

    A plain Conv2D predicts per-position sampling offsets and modulation
    masks, which are fed to paddle's DeformConv2D.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 weight_attr=None,
                 bias_attr=None,
                 lr_scale=1,
                 regularizer=None,
                 skip_quant=False,
                 dcn_bias_regularizer=L2Decay(0.),
                 dcn_bias_lr_scale=2.):
        super(DeformableConvV2, self).__init__()
        # 2 offset coordinates per sampling point; 1 mask value per point.
        self.offset_channel = 2 * kernel_size**2 * groups
        self.mask_channel = kernel_size**2 * groups

        if bias_attr:
            # in FCOS-DCN head, specifically need learning_rate and regularizer
            dcn_bias_attr = ParamAttr(
                initializer=Constant(value=0),
                regularizer=dcn_bias_regularizer,
                learning_rate=dcn_bias_lr_scale)
        else:
            # in ResNet backbone, do not need bias
            dcn_bias_attr = False
        self.conv_dcn = DeformConv2D(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2 * dilation,
            dilation=dilation,
            deformable_groups=groups,
            weight_attr=weight_attr,
            bias_attr=dcn_bias_attr)

        if lr_scale == 1 and regularizer is None:
            offset_bias_attr = ParamAttr(initializer=Constant(0.))
        else:
            offset_bias_attr = ParamAttr(
                initializer=Constant(0.),
                learning_rate=lr_scale,
                regularizer=regularizer)
        # Offset branch: 3 channels per sampling point = (dx, dy) + mask;
        # weights start at zero so the layer initially behaves like a
        # regular convolution.
        self.conv_offset = nn.Conv2D(
            in_channels,
            groups * 3 * kernel_size**2,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            weight_attr=ParamAttr(initializer=Constant(0.0)),
            bias_attr=offset_bias_attr)
        if skip_quant:
            # Marker consumed by quantization tooling to skip this conv.
            self.conv_offset.skip_quant = True

    def forward(self, x):
        offset_mask = self.conv_offset(x)
        # Split the joint prediction into offsets and modulation masks.
        offset, mask = paddle.split(
            offset_mask,
            num_or_sections=[self.offset_channel, self.mask_channel],
            axis=1)
        mask = F.sigmoid(mask)
        y = self.conv_dcn(x, offset, mask=mask)
        return y
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
    """Conv2D (plain or deformable) followed by BatchNorm.

    When ``is_vd_mode`` is set, a 2x2 ceil-mode average pool is applied
    to the input first (the ResNet-vd downsampling trick).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 groups=1,
                 is_vd_mode=False,
                 act=None,
                 is_dcn=False):
        super(ConvBNLayer, self).__init__()

        self.is_vd_mode = is_vd_mode
        self._pool2d_avg = nn.AvgPool2D(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        same_pad = (kernel_size - 1) // 2
        if is_dcn:
            # NOTE(review): deformable groups are hard-coded to 2 rather
            # than using the ``groups`` argument — mirrors the original;
            # confirm this is intentional.
            self._conv = DeformableConvV2(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=same_pad,
                groups=2,  # groups
                bias_attr=False)
        else:
            self._conv = nn.Conv2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=same_pad,
                groups=groups,
                bias_attr=False)
        self._batch_norm = nn.BatchNorm(out_channels, act=act)

    def forward(self, inputs):
        x = self._pool2d_avg(inputs) if self.is_vd_mode else inputs
        return self._batch_norm(self._conv(x))
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Layer):
    """ResNet bottleneck: 1x1 -> 3x3 (optionally DCN) -> 1x1, plus residual."""

    def __init__(
            self,
            in_channels,
            out_channels,
            stride,
            shortcut=True,
            if_first=False,
            is_dcn=False, ):
        super(BottleneckBlock, self).__init__()

        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            act='relu')
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            is_dcn=is_dcn)
        self.conv2 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels * 4,
            kernel_size=1,
            act=None)

        if not shortcut:
            # Projection shortcut; vd-mode (avg-pool) everywhere except
            # the very first block of the network.
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels * 4,
                kernel_size=1,
                stride=1,
                is_vd_mode=not if_first)

        self.shortcut = shortcut

    def forward(self, inputs):
        residual = inputs if self.shortcut else self.short(inputs)
        out = self.conv2(self.conv1(self.conv0(inputs)))
        return F.relu(paddle.add(x=residual, y=out))
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
    """ResNet basic block: two 3x3 conv-BN layers with a residual add."""

    def __init__(
            self,
            in_channels,
            out_channels,
            stride,
            shortcut=True,
            if_first=False, ):
        super(BasicBlock, self).__init__()

        self.stride = stride
        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu')
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            act=None)

        if not shortcut:
            # Projection shortcut; vd avg-pool mode except for the first block.
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                is_vd_mode=not if_first)

        self.shortcut = shortcut

    def forward(self, inputs):
        residual = inputs if self.shortcut else self.short(inputs)
        out = self.conv1(self.conv0(inputs))
        return F.relu(paddle.add(x=residual, y=out))
|
||||
|
||||
|
||||
class ResNet(nn.Layer):
    """ResNet-vd backbone with optional per-stage deformable convolution.

    Args:
        in_channels: channels of the input image (default 3).
        layers: depth variant; one of 18/34/50/101/152/200.
        dcn_stage: per-stage booleans enabling DCN in that stage
            (default: all False).
        out_indices: which stage outputs to return (default: all four).

    ``forward`` returns the feature maps of the stages in ``out_indices``;
    their channel counts are recorded in ``self.out_channels``.
    """

    def __init__(self,
                 in_channels=3,
                 layers=50,
                 dcn_stage=None,
                 out_indices=None,
                 **kwargs):
        super(ResNet, self).__init__()

        self.layers = layers
        supported_layers = [18, 34, 50, 101, 152, 200]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(
                supported_layers, layers)

        # Number of residual blocks in each of the four stages.
        if layers == 18:
            depth = [2, 2, 2, 2]
        elif layers == 34 or layers == 50:
            depth = [3, 4, 6, 3]
        elif layers == 101:
            depth = [3, 4, 23, 3]
        elif layers == 152:
            depth = [3, 8, 36, 3]
        elif layers == 200:
            depth = [3, 12, 48, 3]
        # Bottleneck variants (>=50) carry 4x expanded stage inputs.
        num_channels = [64, 256, 512,
                        1024] if layers >= 50 else [64, 64, 128, 256]
        num_filters = [64, 128, 256, 512]

        # Defaults: no DCN anywhere, return every stage.
        self.dcn_stage = dcn_stage if dcn_stage is not None else [
            False, False, False, False
        ]
        self.out_indices = out_indices if out_indices is not None else [
            0, 1, 2, 3
        ]

        # vd stem: three 3x3 convs in place of the classic single 7x7.
        self.conv1_1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=32,
            kernel_size=3,
            stride=2,
            act='relu')
        self.conv1_2 = ConvBNLayer(
            in_channels=32,
            out_channels=32,
            kernel_size=3,
            stride=1,
            act='relu')
        self.conv1_3 = ConvBNLayer(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            stride=1,
            act='relu')
        self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

        self.stages = []
        self.out_channels = []
        if layers >= 50:
            # Bottleneck stages. Registration order and the 'bb_%d_%d'
            # names determine the state-dict layout — do not reorder.
            for block in range(len(depth)):
                block_list = []
                shortcut = False
                is_dcn = self.dcn_stage[block]
                for i in range(depth[block]):
                    bottleneck_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BottleneckBlock(
                            in_channels=num_channels[block]
                            if i == 0 else num_filters[block] * 4,
                            out_channels=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            if_first=block == i == 0,
                            is_dcn=is_dcn))
                    shortcut = True
                    block_list.append(bottleneck_block)
                if block in self.out_indices:
                    self.out_channels.append(num_filters[block] * 4)
                self.stages.append(nn.Sequential(*block_list))
        else:
            # Basic-block stages for the shallow variants (18/34).
            for block in range(len(depth)):
                block_list = []
                shortcut = False
                # is_dcn = self.dcn_stage[block]
                for i in range(depth[block]):
                    basic_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BasicBlock(
                            in_channels=num_channels[block]
                            if i == 0 else num_filters[block],
                            out_channels=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            if_first=block == i == 0))
                    shortcut = True
                    block_list.append(basic_block)
                if block in self.out_indices:
                    self.out_channels.append(num_filters[block])
                self.stages.append(nn.Sequential(*block_list))

    def forward(self, inputs):
        """Run the stem and all stages; collect outputs of ``out_indices``."""
        y = self.conv1_1(inputs)
        y = self.conv1_2(y)
        y = self.conv1_3(y)
        y = self.pool2d_max(y)
        out = []
        for i, block in enumerate(self.stages):
            y = block(y)
            if i in self.out_indices:
                out.append(y)
        return out
|
||||
@@ -1,285 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
__all__ = ["ResNet_SAST"]
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
    """Conv2D + BatchNorm with explicitly named parameters.

    The ``name``-derived parameter names must match pretrained
    checkpoints, so they are part of this layer's contract.
    When ``is_vd_mode`` is set, a 2x2 average pool runs before the conv
    (ResNet-vd downsampling).
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            groups=1,
            is_vd_mode=False,
            act=None,
            name=None, ):
        super(ConvBNLayer, self).__init__()

        self.is_vd_mode = is_vd_mode
        self._pool2d_avg = nn.AvgPool2D(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self._conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        # Checkpoint naming convention: "conv1" -> "bn_conv1",
        # any other "xxxNNN..." -> "bn" + suffix after the first 3 chars.
        if name == "conv1":
            bn_name = "bn_" + name
        else:
            bn_name = "bn" + name[3:]
        self._batch_norm = nn.BatchNorm(
            out_channels,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def forward(self, inputs):
        # vd mode: average-pool before the conv (used on shortcut paths).
        if self.is_vd_mode:
            inputs = self._pool2d_avg(inputs)
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Layer):
    """Named ResNet bottleneck block (1x1 -> 3x3 -> 1x1 with residual).

    ``name`` seeds the checkpoint-compatible parameter names of the
    three branch convs and the optional projection shortcut.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 shortcut=True,
                 if_first=False,
                 name=None):
        super(BottleneckBlock, self).__init__()

        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            act='relu',
            name=name + "_branch2a")
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2b")
        self.conv2 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels * 4,
            kernel_size=1,
            act=None,
            name=name + "_branch2c")

        if not shortcut:
            # Projection shortcut; avg-pool (vd mode) except on the first block.
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels * 4,
                kernel_size=1,
                stride=1,
                is_vd_mode=False if if_first else True,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)
        y = paddle.add(x=short, y=conv2)
        y = F.relu(y)
        return y
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
    """Named ResNet basic block (two 3x3 convs with residual add)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 shortcut=True,
                 if_first=False,
                 name=None):
        super(BasicBlock, self).__init__()
        self.stride = stride
        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2a")
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            act=None,
            name=name + "_branch2b")

        if not shortcut:
            # Projection shortcut; avg-pool (vd mode) except on the first block.
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                is_vd_mode=False if if_first else True,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)
        y = paddle.add(x=short, y=conv1)
        y = F.relu(y)
        return y
|
||||
|
||||
|
||||
class ResNet_SAST(nn.Layer):
    """ResNet-vd backbone variant for the SAST text detector.

    Differs from the stock backbone in two ways: a fifth residual stage
    is appended (for layers 34/50), and ``forward`` returns the raw
    input and the stem output in addition to every stage output.
    ``self.out_channels`` lists the channel count of each returned map.
    """

    def __init__(self, in_channels=3, layers=50, **kwargs):
        super(ResNet_SAST, self).__init__()

        self.layers = layers
        supported_layers = [18, 34, 50, 101, 152, 200]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(
                supported_layers, layers)

        if layers == 18:
            depth = [2, 2, 2, 2]
        elif layers == 34 or layers == 50:
            # SAST appends an extra fifth stage (stock ResNet: [3, 4, 6, 3]).
            depth = [3, 4, 6, 3, 3]
        elif layers == 101:
            depth = [3, 4, 23, 3]
        elif layers == 152:
            depth = [3, 8, 36, 3]
        elif layers == 200:
            depth = [3, 12, 48, 3]
        # Five entries to match the extra stage; shallow variants keep four.
        num_channels = [64, 256, 512,
                        1024, 2048] if layers >= 50 else [64, 64, 128, 256]
        num_filters = [64, 128, 256, 512, 512]

        # vd stem: three 3x3 convs.
        self.conv1_1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=32,
            kernel_size=3,
            stride=2,
            act='relu',
            name="conv1_1")
        self.conv1_2 = ConvBNLayer(
            in_channels=32,
            out_channels=32,
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv1_2")
        self.conv1_3 = ConvBNLayer(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv1_3")
        self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

        self.stages = []
        # First two entries are the raw input (3) and stem output (64).
        self.out_channels = [3, 64]
        if layers >= 50:
            for block in range(len(depth)):
                block_list = []
                shortcut = False
                for i in range(depth[block]):
                    # Checkpoint naming: deep variants use "b<i>" suffixes
                    # in stage 3; otherwise letters a, b, c, ...
                    if layers in [101, 152] and block == 2:
                        if i == 0:
                            conv_name = "res" + str(block + 2) + "a"
                        else:
                            conv_name = "res" + str(block + 2) + "b" + str(i)
                    else:
                        conv_name = "res" + str(block + 2) + chr(97 + i)
                    bottleneck_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BottleneckBlock(
                            in_channels=num_channels[block]
                            if i == 0 else num_filters[block] * 4,
                            out_channels=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            if_first=block == i == 0,
                            name=conv_name))
                    shortcut = True
                    block_list.append(bottleneck_block)
                self.out_channels.append(num_filters[block] * 4)
                self.stages.append(nn.Sequential(*block_list))
        else:
            for block in range(len(depth)):
                block_list = []
                shortcut = False
                for i in range(depth[block]):
                    conv_name = "res" + str(block + 2) + chr(97 + i)
                    basic_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BasicBlock(
                            in_channels=num_channels[block]
                            if i == 0 else num_filters[block],
                            out_channels=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            if_first=block == i == 0,
                            name=conv_name))
                    shortcut = True
                    block_list.append(basic_block)
                self.out_channels.append(num_filters[block])
                self.stages.append(nn.Sequential(*block_list))

    def forward(self, inputs):
        """Return [input, stem output, stage outputs...] for the FPN-style head."""
        out = [inputs]
        y = self.conv1_1(inputs)
        y = self.conv1_2(y)
        y = self.conv1_3(y)
        out.append(y)
        y = self.pool2d_max(y)
        for block in self.stages:
            y = block(y)
            out.append(y)
        return out
|
||||
@@ -1,265 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
__all__ = ["ResNet"]
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
    """Conv2D + BatchNorm with checkpoint-compatible parameter names.

    NOTE(review): ``is_vd_mode`` and ``_pool2d_avg`` are stored/created
    but never used in ``forward`` in this variant — presumably kept for
    signature compatibility with the vd version; confirm before removal.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            groups=1,
            is_vd_mode=False,
            act=None,
            name=None, ):
        super(ConvBNLayer, self).__init__()

        self.is_vd_mode = is_vd_mode
        self._pool2d_avg = nn.AvgPool2D(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self._conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        # Checkpoint naming: "conv1" -> "bn_conv1"; others -> "bn" + suffix.
        if name == "conv1":
            bn_name = "bn_" + name
        else:
            bn_name = "bn" + name[3:]
        self._batch_norm = nn.BatchNorm(
            out_channels,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Layer):
    """Named ResNet bottleneck block (non-vd variant).

    Unlike the vd version, the projection shortcut downsamples with a
    strided 1x1 conv (``stride=stride``) instead of an average pool.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 shortcut=True,
                 if_first=False,
                 name=None):
        super(BottleneckBlock, self).__init__()

        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            act='relu',
            name=name + "_branch2a")
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2b")
        self.conv2 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels * 4,
            kernel_size=1,
            act=None,
            name=name + "_branch2c")

        if not shortcut:
            # Strided projection shortcut (classic ResNet downsampling).
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels * 4,
                kernel_size=1,
                stride=stride,
                is_vd_mode=False if if_first else True,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)
        y = paddle.add(x=short, y=conv2)
        y = F.relu(y)
        return y
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
    """Named ResNet basic block (two 3x3 convs with residual add)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 shortcut=True,
                 if_first=False,
                 name=None):
        super(BasicBlock, self).__init__()
        self.stride = stride
        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2a")
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            act=None,
            name=name + "_branch2b")

        if not shortcut:
            # 1x1 projection shortcut to match channel counts.
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                is_vd_mode=False if if_first else True,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)
        y = paddle.add(x=short, y=conv1)
        y = F.relu(y)
        return y
|
||||
|
||||
|
||||
class ResNet(nn.Layer):
    """Classic-stem ResNet backbone with a fifth stage (for 34/50).

    Uses a single 7x7 stride-2 conv stem rather than the vd triple-3x3
    stem. ``forward`` returns the input, the stem output, and every
    stage output; ``self.out_channels`` records their channel counts.
    """

    def __init__(self, in_channels=3, layers=50, **kwargs):
        super(ResNet, self).__init__()

        self.layers = layers
        supported_layers = [18, 34, 50, 101, 152, 200]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(
                supported_layers, layers)

        if layers == 18:
            depth = [2, 2, 2, 2]
        elif layers == 34 or layers == 50:
            # An extra fifth stage is appended (stock ResNet: [3, 4, 6, 3]).
            depth = [3, 4, 6, 3, 3]
        elif layers == 101:
            depth = [3, 4, 23, 3]
        elif layers == 152:
            depth = [3, 8, 36, 3]
        elif layers == 200:
            depth = [3, 12, 48, 3]
        # Five entries to cover the extra stage for deep variants.
        num_channels = [64, 256, 512, 1024,
                        2048] if layers >= 50 else [64, 64, 128, 256]
        num_filters = [64, 128, 256, 512, 512]

        # Classic 7x7 stride-2 stem.
        self.conv1_1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=64,
            kernel_size=7,
            stride=2,
            act='relu',
            name="conv1_1")
        self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

        self.stages = []
        # First two entries: raw input (3 ch) and stem output (64 ch).
        self.out_channels = [3, 64]
        if layers >= 50:
            for block in range(len(depth)):
                block_list = []
                shortcut = False
                for i in range(depth[block]):
                    # Checkpoint naming: deep variants use "b<i>" suffixes
                    # in stage 3; otherwise letters a, b, c, ...
                    if layers in [101, 152] and block == 2:
                        if i == 0:
                            conv_name = "res" + str(block + 2) + "a"
                        else:
                            conv_name = "res" + str(block + 2) + "b" + str(i)
                    else:
                        conv_name = "res" + str(block + 2) + chr(97 + i)
                    bottleneck_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BottleneckBlock(
                            in_channels=num_channels[block]
                            if i == 0 else num_filters[block] * 4,
                            out_channels=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            if_first=block == i == 0,
                            name=conv_name))
                    shortcut = True
                    block_list.append(bottleneck_block)
                self.out_channels.append(num_filters[block] * 4)
                self.stages.append(nn.Sequential(*block_list))
        else:
            for block in range(len(depth)):
                block_list = []
                shortcut = False
                for i in range(depth[block]):
                    conv_name = "res" + str(block + 2) + chr(97 + i)
                    basic_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BasicBlock(
                            in_channels=num_channels[block]
                            if i == 0 else num_filters[block],
                            out_channels=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            if_first=block == i == 0,
                            name=conv_name))
                    shortcut = True
                    block_list.append(basic_block)
                self.out_channels.append(num_filters[block])
                self.stages.append(nn.Sequential(*block_list))

    def forward(self, inputs):
        """Return [input, stem output, stage outputs...]."""
        out = [inputs]
        y = self.conv1_1(inputs)
        out.append(y)
        y = self.pool2d_max(y)
        for block in self.stages:
            y = block(y)
            out.append(y)
        return out
|
||||
@@ -1,186 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
__all__ = ["Kie_backbone"]
|
||||
|
||||
|
||||
class Encoder(nn.Layer):
    """U-Net encoder stage: two conv-BN-ReLU layers plus a stride-2 max pool.

    ``forward`` returns both the pre-pool feature map (used later as a
    skip connection) and the pooled map (fed to the next stage).
    """

    def __init__(self, num_channels, num_filters):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2D(
            num_channels,
            num_filters,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn1 = nn.BatchNorm(num_filters, act='relu')

        self.conv2 = nn.Conv2D(
            num_filters,
            num_filters,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn2 = nn.BatchNorm(num_filters, act='relu')

        self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

    def forward(self, inputs):
        features = self.bn1(self.conv1(inputs))
        features = self.bn2(self.conv2(features))
        return features, self.pool(features)
|
||||
|
||||
|
||||
class Decoder(nn.Layer):
    """U-Net decoder stage: reduce + 2x upsample, fuse with skip, refine.

    ``forward(inputs_prev, inputs)`` upsamples ``inputs`` (the deeper
    features), concatenates them with the skip map ``inputs_prev``, and
    applies two conv-BN-ReLU layers.
    """

    def __init__(self, num_channels, num_filters):
        super(Decoder, self).__init__()

        self.conv1 = nn.Conv2D(
            num_channels,
            num_filters,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn1 = nn.BatchNorm(num_filters, act='relu')

        self.conv2 = nn.Conv2D(
            num_filters,
            num_filters,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn2 = nn.BatchNorm(num_filters, act='relu')

        self.conv0 = nn.Conv2D(
            num_channels,
            num_filters,
            kernel_size=1,
            stride=1,
            padding=0,
            bias_attr=False)
        self.bn0 = nn.BatchNorm(num_filters, act='relu')

    def forward(self, inputs_prev, inputs):
        # 1x1 reduce + bilinear 2x upsample of the deeper features.
        up = self.bn0(self.conv0(inputs))
        up = paddle.nn.functional.interpolate(
            up, scale_factor=2, mode='bilinear', align_corners=False)
        # Fuse with the skip connection along the channel axis, then refine.
        fused = paddle.concat([inputs_prev, up], axis=1)
        fused = self.bn1(self.conv1(fused))
        return self.bn2(self.conv2(fused))
|
||||
|
||||
|
||||
class UNet(nn.Layer):
    """Small five-level U-Net producing a 16-channel full-resolution map."""

    def __init__(self):
        super(UNet, self).__init__()
        self.down1 = Encoder(num_channels=3, num_filters=16)
        self.down2 = Encoder(num_channels=16, num_filters=32)
        self.down3 = Encoder(num_channels=32, num_filters=64)
        self.down4 = Encoder(num_channels=64, num_filters=128)
        self.down5 = Encoder(num_channels=128, num_filters=256)

        self.up1 = Decoder(32, 16)
        self.up2 = Decoder(64, 32)
        self.up3 = Decoder(128, 64)
        self.up4 = Decoder(256, 128)
        self.out_channels = 16

    def forward(self, inputs):
        # Encoder path. Only the first stage's pre-pool map is kept as a
        # full-resolution skip; deeper stages pass their pooled output on.
        skip1, _ = self.down1(inputs)
        _, enc2 = self.down2(skip1)
        _, enc3 = self.down3(enc2)
        _, enc4 = self.down4(enc3)
        _, enc5 = self.down5(enc4)

        # Decoder path: fuse from the deepest level back up.
        feat = self.up4(enc4, enc5)
        feat = self.up3(enc3, feat)
        feat = self.up2(enc2, feat)
        return self.up1(skip1, feat)
|
||||
|
||||
|
||||
class Kie_backbone(nn.Layer):
    """Visual backbone for SDMGR key-information extraction.

    Runs a small U-Net over the image, RoI-aligns per-box 7x7 features,
    and max-pools them to one 16-d vector per ground-truth box.
    """

    def __init__(self, in_channels, **kwargs):
        super(Kie_backbone, self).__init__()
        self.out_channels = 16
        self.img_feat = UNet()
        self.maxpool = nn.MaxPool2D(kernel_size=7)

    def bbox2roi(self, bbox_list):
        # Concatenate per-image box tensors into one RoI tensor, with a
        # parallel int32 tensor of per-image box counts.
        rois_list = []
        rois_num = []
        for img_id, bboxes in enumerate(bbox_list):
            rois_num.append(bboxes.shape[0])
            rois_list.append(bboxes)
        rois = paddle.concat(rois_list, 0)
        rois_num = paddle.to_tensor(rois_num, dtype='int32')
        return rois, rois_num

    def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):
        # Round-trip through numpy to crop the padded batch down to its
        # real size and to slice each sample's valid entries.
        img, relations, texts, gt_bboxes, tag, img_size = img.numpy(
        ), relations.numpy(), texts.numpy(), gt_bboxes.numpy(), tag.numpy(
        ).tolist(), img_size.numpy()
        temp_relations, temp_texts, temp_gt_bboxes = [], [], []
        # Crop the batch to the largest real (h, w) in it.
        h, w = int(np.max(img_size[:, 0])), int(np.max(img_size[:, 1]))
        img = paddle.to_tensor(img[:, :, :h, :w])
        batch = len(tag)
        for i in range(batch):
            # tag[i] = (num_valid_boxes, recorded_text_length) — per sample.
            num, recoder_len = tag[i][0], tag[i][1]
            temp_relations.append(
                paddle.to_tensor(
                    relations[i, :num, :num, :], dtype='float32'))
            temp_texts.append(
                paddle.to_tensor(
                    texts[i, :num, :recoder_len], dtype='float32'))
            temp_gt_bboxes.append(
                paddle.to_tensor(
                    gt_bboxes[i, :num, ...], dtype='float32'))
        return img, temp_relations, temp_texts, temp_gt_bboxes

    def forward(self, inputs):
        # inputs is a positional list from the data loader:
        # [img, relations, texts, gt_bboxes, ?, tag, ..., img_size].
        img = inputs[0]
        relations, texts, gt_bboxes, tag, img_size = inputs[1], inputs[
            2], inputs[3], inputs[5], inputs[-1]
        img, relations, texts, gt_bboxes = self.pre_process(
            img, relations, texts, gt_bboxes, tag, img_size)
        x = self.img_feat(img)
        boxes, rois_num = self.bbox2roi(gt_bboxes)
        # NOTE(review): paddle.fluid is deprecated; newer Paddle exposes
        # this as paddle.vision.ops.roi_align — migrate when possible.
        feats = paddle.fluid.layers.roi_align(
            x,
            boxes,
            spatial_scale=1.0,
            pooled_height=7,
            pooled_width=7,
            rois_num=rois_num)
        # 7x7 max pool collapses each RoI to a single 16-d feature vector.
        feats = self.maxpool(feats).squeeze(-1).squeeze(-1)
        return [relations, texts, feats]
|
||||
@@ -1,228 +0,0 @@
|
||||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Code is refer from:
|
||||
https://github.com/RuijieJ/pren/blob/main/Nets/EfficientNet.py
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
from collections import namedtuple
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
__all__ = ['EfficientNetb3']
|
||||
|
||||
|
||||
class EffB3Params:
    """Static EfficientNet-B3 architecture hyper-parameters.

    The input resolution is reduced from the stock 300 to 64 so the
    backbone fits the scene-text-recognition input size.
    """

    @staticmethod
    def get_global_params():
        """Return the network-wide hyper-parameters as a namedtuple."""
        GlobalParams = namedtuple('GlobalParams', [
            'drop_connect_rate', 'width_coefficient', 'depth_coefficient',
            'depth_divisor', 'image_size'
        ])
        return GlobalParams(
            drop_connect_rate=0.3,
            width_coefficient=1.2,
            depth_coefficient=1.4,
            depth_divisor=8,
            image_size=64)

    @staticmethod
    def get_block_params():
        """Return the per-stage MBConv block configurations."""
        BlockParams = namedtuple('BlockParams', [
            'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
            'expand_ratio', 'id_skip', 'se_ratio', 'stride'
        ])
        stage_configs = [
            (3, 1, 32, 16, 1, True, 0.25, 1),
            (3, 2, 16, 24, 6, True, 0.25, 2),
            (5, 2, 24, 40, 6, True, 0.25, 2),
            (3, 3, 40, 80, 6, True, 0.25, 2),
            (5, 3, 80, 112, 6, True, 0.25, 1),
            (5, 4, 112, 192, 6, True, 0.25, 2),
            (3, 1, 192, 320, 6, True, 0.25, 1),
        ]
        return [BlockParams(*cfg) for cfg in stage_configs]
|
||||
|
||||
|
||||
class EffUtils:
|
||||
@staticmethod
|
||||
def round_filters(filters, global_params):
|
||||
"""Calculate and round number of filters based on depth multiplier."""
|
||||
multiplier = global_params.width_coefficient
|
||||
if not multiplier:
|
||||
return filters
|
||||
divisor = global_params.depth_divisor
|
||||
filters *= multiplier
|
||||
new_filters = int(filters + divisor / 2) // divisor * divisor
|
||||
if new_filters < 0.9 * filters:
|
||||
new_filters += divisor
|
||||
return int(new_filters)
|
||||
|
||||
@staticmethod
|
||||
def round_repeats(repeats, global_params):
|
||||
"""Round number of filters based on depth multiplier."""
|
||||
multiplier = global_params.depth_coefficient
|
||||
if not multiplier:
|
||||
return repeats
|
||||
return int(math.ceil(multiplier * repeats))
|
||||
|
||||
|
||||
class ConvBlock(nn.Layer):
    """One MBConv block: expansion 1x1 conv, depthwise conv, optional
    squeeze-and-excitation, and a 1x1 projection, with an optional
    identity skip plus drop-connect.

    Args:
        block_params: a ``BlockParams`` namedtuple (see
            ``EffB3Params.get_block_params``) describing this block.
    """

    def __init__(self, block_params):
        super(ConvBlock, self).__init__()
        self.block_args = block_params
        # SE is only enabled for a ratio strictly inside (0, 1].
        self.has_se = (self.block_args.se_ratio is not None) and \
            (0 < self.block_args.se_ratio <= 1)
        self.id_skip = block_params.id_skip

        # Expansion phase: 1x1 conv widening channels by expand_ratio.
        # Skipped entirely when the ratio is 1 (no bn0/expand_conv then).
        self.input_filters = self.block_args.input_filters
        output_filters = \
            self.block_args.input_filters * self.block_args.expand_ratio
        if self.block_args.expand_ratio != 1:
            self.expand_conv = nn.Conv2D(
                self.input_filters, output_filters, 1, bias_attr=False)
            self.bn0 = nn.BatchNorm(output_filters)

        # Depthwise conv phase (groups == channels makes it depthwise).
        k = self.block_args.kernel_size
        s = self.block_args.stride
        self.depthwise_conv = nn.Conv2D(
            output_filters,
            output_filters,
            groups=output_filters,
            kernel_size=k,
            stride=s,
            padding='same',
            bias_attr=False)
        self.bn1 = nn.BatchNorm(output_filters)

        # Squeeze-and-excitation: squeeze width is derived from the
        # *input* filters (per the EfficientNet reference), min 1.
        if self.has_se:
            num_squeezed_channels = max(1,
                                        int(self.block_args.input_filters *
                                            self.block_args.se_ratio))
            self.se_reduce = nn.Conv2D(output_filters, num_squeezed_channels, 1)
            self.se_expand = nn.Conv2D(num_squeezed_channels, output_filters, 1)

        # Output phase: 1x1 projection back down to output_filters.
        self.final_oup = self.block_args.output_filters
        self.project_conv = nn.Conv2D(
            output_filters, self.final_oup, 1, bias_attr=False)
        self.bn2 = nn.BatchNorm(self.final_oup)
        self.swish = nn.Swish()

    def drop_connect(self, inputs, p, training):
        """Randomly zero whole samples with probability ``p`` and rescale
        the survivors by 1/keep_prob (stochastic depth). Identity when
        not training.
        """
        if not training:
            return inputs

        batch_size = inputs.shape[0]
        keep_prob = 1 - p
        # floor(keep_prob + U[0,1)) is 1 with probability keep_prob.
        random_tensor = keep_prob
        random_tensor += paddle.rand([batch_size, 1, 1, 1], dtype=inputs.dtype)
        random_tensor = paddle.to_tensor(random_tensor, place=inputs.place)
        binary_tensor = paddle.floor(random_tensor)
        output = inputs / keep_prob * binary_tensor
        return output

    def forward(self, inputs, drop_connect_rate=None):
        """Run the block; applies drop-connect and the residual add only
        when stride is 1 and in/out channel counts match.
        """
        # Expansion and depthwise conv.
        x = inputs
        if self.block_args.expand_ratio != 1:
            x = self.swish(self.bn0(self.expand_conv(inputs)))
        x = self.swish(self.bn1(self.depthwise_conv(x)))

        # Squeeze and excitation: channel-wise gating from a global pool.
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self.se_expand(self.swish(self.se_reduce(x_squeezed)))
            x = F.sigmoid(x_squeezed) * x
        x = self.bn2(self.project_conv(x))

        # Skip connection and drop connect.
        if self.id_skip and self.block_args.stride == 1 and \
                self.input_filters == self.final_oup:
            if drop_connect_rate:
                x = self.drop_connect(
                    x, p=drop_connect_rate, training=self.training)
            x = x + inputs
        return x
|
||||
|
||||
|
||||
class EfficientNetb3_PREN(nn.Layer):
    """EfficientNet-B3 backbone for PREN text recognition.

    Builds the stem plus the scaled MBConv stages and, in ``forward``,
    returns the three intermediate feature maps (after blocks 7, 17 and
    25) that feed the FPN.

    Args:
        in_channels (int): number of channels of the input image.
    """

    def __init__(self, in_channels):
        super(EfficientNetb3_PREN, self).__init__()
        self.blocks_params = EffB3Params.get_block_params()
        self.global_params = EffB3Params.get_global_params()
        # Filled while building: channel counts of the three FPN taps.
        self.out_channels = []
        # Stem: 3x3 stride-2 conv.
        stem_channels = EffUtils.round_filters(32, self.global_params)
        self.conv_stem = nn.Conv2D(
            in_channels, stem_channels, 3, 2, padding='same', bias_attr=False)
        self.bn0 = nn.BatchNorm(stem_channels)

        self.blocks = []
        # Indices (over the flattened block list) whose outputs are
        # extracted as the three FPN feature maps.
        self.concerned_block_idxes = [7, 17, 25]
        concerned_idx = 0
        for i, block_params in enumerate(self.blocks_params):
            # Apply compound scaling to this stage's channels/repeats.
            block_params = block_params._replace(
                input_filters=EffUtils.round_filters(block_params.input_filters,
                                                     self.global_params),
                output_filters=EffUtils.round_filters(
                    block_params.output_filters, self.global_params),
                num_repeat=EffUtils.round_repeats(block_params.num_repeat,
                                                  self.global_params))
            # add_sublayer registers the block's parameters; the Python
            # list keeps them in execution order for forward().
            self.blocks.append(
                self.add_sublayer("{}-0".format(i), ConvBlock(block_params)))
            concerned_idx += 1
            if concerned_idx in self.concerned_block_idxes:
                self.out_channels.append(block_params.output_filters)
            if block_params.num_repeat > 1:
                # Repeats after the first keep stride 1 and matching
                # in/out channels (enables the identity skip).
                block_params = block_params._replace(
                    input_filters=block_params.output_filters, stride=1)
                for j in range(block_params.num_repeat - 1):
                    self.blocks.append(
                        self.add_sublayer('{}-{}'.format(i, j + 1),
                                          ConvBlock(block_params)))
                    concerned_idx += 1
                    if concerned_idx in self.concerned_block_idxes:
                        self.out_channels.append(block_params.output_filters)

        self.swish = nn.Swish()

    def forward(self, inputs):
        """Return the list of the three tapped feature maps."""
        outs = []

        x = self.swish(self.bn0(self.conv_stem(inputs)))
        for idx, block in enumerate(self.blocks):
            # Drop-connect rate ramps linearly with block depth.
            drop_connect_rate = self.global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self.blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
            if idx in self.concerned_block_idxes:
                outs.append(x)
        return outs
|
||||
@@ -1,528 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/liyunsheng13/micronet/blob/main/backbone/micronet.py
|
||||
https://github.com/liyunsheng13/micronet/blob/main/backbone/activation.py
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
from ppocr.modeling.backbones.det_mobilenet_v3 import make_divisible
|
||||
|
||||
# Per-stage MicroNet configurations for the four model sizes.
# Column legend (15 values per row):
#   s  - stride of the first block in the stage
#   n  - number of blocks in the stage
#   c  - output channels
#   ks - depthwise kernel size
#   c1, c2     - channel expansion factors (ch_exp)
#   g1, g2     - group sizes for the grouped 1x1 conv (ch_per_group)
#   c3, g3, g4 - hidden size / groups for the second 1x1 conv (groups_1x1)
#   y1, y2, y3 - DYShiftMax activation switches per position
#   r  - reduction ratio multiplier for the activation's squeeze
M0_cfgs = [
    # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
    [2, 1, 8, 3, 2, 2, 0, 4, 8, 2, 2, 2, 0, 1, 1],
    [2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 2, 1, 1],
    [2, 1, 16, 5, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1],
    [1, 1, 32, 5, 1, 4, 4, 4, 32, 4, 4, 2, 2, 1, 1],
    [2, 1, 64, 5, 1, 4, 8, 8, 64, 8, 8, 2, 2, 1, 1],
    [1, 1, 96, 3, 1, 4, 8, 8, 96, 8, 8, 2, 2, 1, 2],
    [1, 1, 384, 3, 1, 4, 12, 12, 0, 0, 0, 2, 2, 1, 2],
]
M1_cfgs = [
    # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
    [2, 1, 8, 3, 2, 2, 0, 6, 8, 2, 2, 2, 0, 1, 1],
    [2, 1, 16, 3, 2, 2, 0, 8, 16, 4, 4, 2, 2, 1, 1],
    [2, 1, 16, 5, 2, 2, 0, 16, 16, 4, 4, 2, 2, 1, 1],
    [1, 1, 32, 5, 1, 6, 4, 4, 32, 4, 4, 2, 2, 1, 1],
    [2, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 1],
    [1, 1, 96, 3, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2],
    [1, 1, 576, 3, 1, 6, 12, 12, 0, 0, 0, 2, 2, 1, 2],
]
M2_cfgs = [
    # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
    [2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 0, 1, 1],
    [2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1],
    [1, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 2, 2, 1, 1],
    [2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 2, 2, 1, 1],
    [1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 2, 2, 1, 2],
    [1, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 2],
    [2, 1, 96, 5, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2],
    [1, 1, 128, 3, 1, 6, 12, 12, 128, 8, 8, 2, 2, 1, 2],
    [1, 1, 768, 3, 1, 6, 16, 16, 0, 0, 0, 2, 2, 1, 2],
]
M3_cfgs = [
    # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
    [2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 0, 2, 0, 1],
    [2, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 0, 2, 0, 1],
    [1, 1, 24, 3, 2, 2, 0, 24, 24, 4, 4, 0, 2, 0, 1],
    [2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 0, 2, 0, 1],
    [1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 0, 2, 0, 2],
    [1, 1, 64, 5, 1, 6, 8, 8, 48, 8, 8, 0, 2, 0, 2],
    [1, 1, 80, 5, 1, 6, 8, 8, 80, 8, 8, 0, 2, 0, 2],
    [1, 1, 80, 5, 1, 6, 10, 10, 80, 8, 8, 0, 2, 0, 2],
    [1, 1, 120, 5, 1, 6, 10, 10, 120, 10, 10, 0, 2, 0, 2],
    [1, 1, 120, 5, 1, 6, 12, 12, 120, 10, 10, 0, 2, 0, 2],
    [1, 1, 144, 3, 1, 6, 12, 12, 144, 12, 12, 0, 2, 0, 2],
    [1, 1, 432, 3, 1, 3, 12, 12, 0, 0, 0, 0, 2, 0, 2],
]
|
||||
|
||||
|
||||
def get_micronet_config(mode):
    """Return the per-stage configuration list for a MicroNet variant.

    Args:
        mode (str): one of ``'M0'``, ``'M1'``, ``'M2'``, ``'M3'``.

    Returns:
        list: the matching ``M*_cfgs`` table (rows of
        s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r).

    Raises:
        KeyError: if ``mode`` is not a known variant.
    """
    # Explicit lookup instead of eval(mode + '_cfgs'): no code execution
    # from the argument, and unknown modes fail fast with a clear key.
    cfgs = {'M0': M0_cfgs, 'M1': M1_cfgs, 'M2': M2_cfgs, 'M3': M3_cfgs}
    return cfgs[mode]
|
||||
|
||||
|
||||
class MaxGroupPooling(nn.Layer):
    """Channel-wise max pooling over groups of ``channel_per_group``
    adjacent channels, reducing C to C // channel_per_group.

    Args:
        channel_per_group (int): group size; 1 makes this a no-op.
    """

    def __init__(self, channel_per_group=2):
        super(MaxGroupPooling, self).__init__()
        self.channel_per_group = channel_per_group

    def forward(self, x):
        if self.channel_per_group == 1:
            # Group size 1: nothing to reduce.
            return x
        # max op
        b, c, h, w = x.shape

        # Reshape to (b, groups, channel_per_group, h, w) and take the
        # max within each group.
        y = paddle.reshape(x, [b, c // self.channel_per_group, -1, h, w])
        out = paddle.max(y, axis=2)
        return out
|
||||
|
||||
|
||||
class SpatialSepConvSF(nn.Layer):
    """Spatially separable conv: a (k,1) conv followed by a grouped
    (1,k) conv and a channel shuffle, producing oup1*oup2 channels.

    Args:
        inp (int): input channels.
        oups (tuple[int, int]): (oup1, oup2); the output has oup1*oup2
            channels.
        kernel_size (int): spatial kernel size k.
        stride (int): stride, applied to height in the first conv and
            width in the second.
    """

    def __init__(self, inp, oups, kernel_size, stride):
        super(SpatialSepConvSF, self).__init__()

        oup1, oup2 = oups
        self.conv = nn.Sequential(
            # Vertical (k,1) conv with (stride,1).
            nn.Conv2D(
                inp,
                oup1, (kernel_size, 1), (stride, 1), (kernel_size // 2, 0),
                bias_attr=False,
                groups=1),
            nn.BatchNorm2D(oup1),
            # Horizontal (1,k) conv with (1,stride), grouped by oup1.
            nn.Conv2D(
                oup1,
                oup1 * oup2, (1, kernel_size), (1, stride),
                (0, kernel_size // 2),
                bias_attr=False,
                groups=oup1),
            nn.BatchNorm2D(oup1 * oup2),
            # Shuffle so channels from different oup1 groups interleave.
            ChannelShuffle(oup1), )

    def forward(self, x):
        out = self.conv(x)
        return out
|
||||
|
||||
|
||||
class ChannelShuffle(nn.Layer):
    """ShuffleNet-style channel shuffle: interleave channels across
    ``groups`` groups so grouped convs can mix information.

    Args:
        groups (int): number of groups; C must be divisible by it.
    """

    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        b, c, h, w = x.shape

        channels_per_group = c // self.groups

        # Reshape to (b, groups, channels_per_group, h, w) ...
        x = paddle.reshape(x, [b, self.groups, channels_per_group, h, w])

        # ... swap the group axes and flatten back, interleaving.
        x = paddle.transpose(x, (0, 2, 1, 3, 4))
        out = paddle.reshape(x, [b, -1, h, w])

        return out
|
||||
|
||||
|
||||
class StemLayer(nn.Layer):
    """MicroNet stem: a spatially separable conv followed by either
    group max-pooling or ReLU6.

    Args:
        inp (int): input channels.
        oup (int): desired output channels.
        stride (int): spatial stride of the stem conv.
        groups (tuple[int, int]): (g1, g2) passed to SpatialSepConvSF,
            which yields g1*g2 channels.
    """

    def __init__(self, inp, oup, stride, groups=(4, 4)):
        super(StemLayer, self).__init__()

        g1, g2 = groups
        self.stem = nn.Sequential(
            SpatialSepConvSF(inp, groups, 3, stride),
            # If the conv produced twice the target channels, halve them
            # with pairwise max pooling; otherwise just activate.
            MaxGroupPooling(2) if g1 * g2 == 2 * oup else nn.ReLU6())

    def forward(self, x):
        out = self.stem(x)
        return out
|
||||
|
||||
|
||||
class DepthSpatialSepConv(nn.Layer):
    """Depthwise spatially separable conv: a per-channel (k,1) conv that
    expands by exp1, then a grouped (1,k) conv expanding by exp2, giving
    inp*exp1*exp2 output channels.

    Args:
        inp (int): input channels.
        expand (tuple[int, int]): (exp1, exp2) channel expansion factors.
        kernel_size (int): spatial kernel size k.
        stride (int): stride (applied to height only; the (1,k) conv
            uses stride 1).
    """

    def __init__(self, inp, expand, kernel_size, stride):
        super(DepthSpatialSepConv, self).__init__()

        exp1, exp2 = expand

        hidden_dim = inp * exp1
        oup = inp * exp1 * exp2

        self.conv = nn.Sequential(
            # Vertical (k,1) depthwise conv (groups=inp), expand by exp1.
            nn.Conv2D(
                inp,
                inp * exp1, (kernel_size, 1), (stride, 1),
                (kernel_size // 2, 0),
                bias_attr=False,
                groups=inp),
            nn.BatchNorm2D(inp * exp1),
            # Horizontal (1,k) depthwise conv (groups=hidden_dim),
            # expand by exp2; stride 1 in both dimensions.
            nn.Conv2D(
                hidden_dim,
                oup, (1, kernel_size),
                1, (0, kernel_size // 2),
                bias_attr=False,
                groups=hidden_dim),
            nn.BatchNorm2D(oup))

    def forward(self, x):
        x = self.conv(x)
        return x
|
||||
|
||||
|
||||
class GroupConv(nn.Layer):
    """Grouped 1x1 conv + batch norm.

    Args:
        inp (int): input channels.
        oup (int): output channels.
        groups (tuple): only ``groups[0]`` is used as the conv's group
            count (callers pass (g1, g2) tuples).
    """

    def __init__(self, inp, oup, groups=2):
        super(GroupConv, self).__init__()
        self.inp = inp
        self.oup = oup
        self.groups = groups
        self.conv = nn.Sequential(
            # NOTE(review): only groups[0] is consumed; the default
            # groups=2 (an int) would fail on subscripting — callers
            # always pass a tuple.
            nn.Conv2D(
                inp, oup, 1, 1, 0, bias_attr=False, groups=self.groups[0]),
            nn.BatchNorm2D(oup))

    def forward(self, x):
        x = self.conv(x)
        return x
|
||||
|
||||
|
||||
class DepthConv(nn.Layer):
    """Plain depthwise kxk conv (groups=inp) with batch norm.

    Args:
        inp (int): input channels (also the group count).
        oup (int): output channels.
        kernel_size (int): spatial kernel size.
        stride (int): spatial stride.
    """

    def __init__(self, inp, oup, kernel_size, stride):
        super(DepthConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2D(
                inp,
                oup,
                kernel_size,
                stride,
                kernel_size // 2,
                bias_attr=False,
                groups=inp),
            nn.BatchNorm2D(oup))

    def forward(self, x):
        out = self.conv(x)
        return out
|
||||
|
||||
|
||||
class DYShiftMax(nn.Layer):
    """Dynamic Shift-Max activation (MicroNet).

    Learns input-dependent coefficients via a squeeze-excitation-style
    FC head, then combines each channel with a group-shifted copy of
    the input: out = max(x*a1 + x_shifted*b1, x*a2 + x_shifted*b2)
    (or the single-branch form when act_relu is False).

    Args:
        inp (int): input channels.
        oup (int): output channels (coefficients are produced per oup).
        reduction (int): squeeze reduction ratio for the FC head.
        act_max (float): coefficient range scale (stored doubled).
        act_relu (bool): True -> 4 coefficient maps (max of two
            branches); False -> 2 maps (single branch).
        init_a, init_b (list[float]): additive initial offsets for the
            a/b coefficients.
        relu_before_pool (bool): apply ReLU before the global pool.
        g (tuple): group spec; g[1] defines the channel-shift group.
        expansion (bool): reinterpret g relative to inp when set.
    """

    def __init__(self,
                 inp,
                 oup,
                 reduction=4,
                 act_max=1.0,
                 act_relu=True,
                 init_a=[0.0, 0.0],
                 init_b=[0.0, 0.0],
                 relu_before_pool=False,
                 g=None,
                 expansion=False):
        super(DYShiftMax, self).__init__()
        self.oup = oup
        # Doubled because the FC output is mapped from [0,1] to
        # [-act_max, act_max] via (y - 0.5) * self.act_max in forward.
        self.act_max = act_max * 2
        self.act_relu = act_relu
        self.avg_pool = nn.Sequential(nn.ReLU() if relu_before_pool == True else
                                      nn.Sequential(), nn.AdaptiveAvgPool2D(1))

        # 4 coefficient maps (a1,b1,a2,b2) with the max branch, else 2.
        self.exp = 4 if act_relu else 2
        self.init_a = init_a
        self.init_b = init_b

        # Determine squeeze width (at least 4 channels).
        squeeze = make_divisible(inp // reduction, 4)
        if squeeze < 4:
            squeeze = 4

        self.fc = nn.Sequential(
            nn.Linear(inp, squeeze),
            nn.ReLU(), nn.Linear(squeeze, oup * self.exp), nn.Hardsigmoid())

        # NOTE(review): if g is None the fallback g = 1 is an int and
        # g[1] below would raise — callers always pass a tuple; confirm.
        if g is None:
            g = 1
        self.g = g[1]
        if self.g != 1 and expansion:
            self.g = inp // self.g

        # Precompute a fixed channel permutation that rotates both the
        # group axis and the within-group axis by one position; used in
        # forward to build the "shifted" copy of the input.
        self.gc = inp // self.g
        index = paddle.to_tensor([range(inp)])
        index = paddle.reshape(index, [1, inp, 1, 1])
        index = paddle.reshape(index, [1, self.g, self.gc, 1, 1])
        indexgs = paddle.split(index, [1, self.g - 1], axis=1)
        indexgs = paddle.concat((indexgs[1], indexgs[0]), axis=1)
        indexs = paddle.split(indexgs, [1, self.gc - 1], axis=2)
        indexs = paddle.concat((indexs[1], indexs[0]), axis=2)
        self.index = paddle.reshape(indexs, [inp])
        self.expansion = expansion

    def forward(self, x):
        x_in = x
        x_out = x

        # Squeeze: global pool + FC head -> per-sample coefficients
        # in [-act_max, act_max].
        b, c, _, _ = x_in.shape
        y = self.avg_pool(x_in)
        y = paddle.reshape(y, [b, c])
        y = self.fc(y)
        y = paddle.reshape(y, [b, self.oup * self.exp, 1, 1])
        y = (y - 0.5) * self.act_max

        # Shifted copy of the input via the precomputed permutation.
        # NOTE(review): the numpy round-trip breaks the autograd graph
        # for this gather on some paddle versions — confirm intended.
        n2, c2, h2, w2 = x_out.shape
        x2 = paddle.to_tensor(x_out.numpy()[:, self.index.numpy(), :, :])

        if self.exp == 4:
            # Two affine branches combined with an elementwise max.
            temp = y.shape
            a1, b1, a2, b2 = paddle.split(y, temp[1] // self.oup, axis=1)

            a1 = a1 + self.init_a[0]
            a2 = a2 + self.init_a[1]

            b1 = b1 + self.init_b[0]
            b2 = b2 + self.init_b[1]

            z1 = x_out * a1 + x2 * b1
            z2 = x_out * a2 + x2 * b2

            out = paddle.maximum(z1, z2)

        elif self.exp == 2:
            # Single affine branch.
            temp = y.shape
            a1, b1 = paddle.split(y, temp[1] // self.oup, axis=1)
            a1 = a1 + self.init_a[0]
            b1 = b1 + self.init_b[0]
            out = x_out * a1 + x2 * b1

        # NOTE(review): self.exp is always 2 or 4 here; any other value
        # would leave `out` unbound.
        return out
|
||||
|
||||
|
||||
class DYMicroBlock(nn.Layer):
    """One Micro-Block: micro-factorized convs interleaved with
    DYShiftMax activations and channel shuffles.

    Three layouts are assembled depending on the group config:
    gs1[0] == 0 -> depthwise-separable conv first; g2 == 0 -> a single
    grouped 1x1 conv; otherwise the full 1x1 -> depthwise -> 1x1 stack.

    Args:
        inp (int): input channels.
        oup (int): output channels.
        kernel_size (int): depthwise kernel size.
        stride (int): spatial stride; an identity skip is added when
            stride == 1 and inp == oup.
        ch_exp (tuple): (exp1, exp2) channel expansion factors.
        ch_per_group (tuple): group sizes for the grouped 1x1 convs.
        groups_1x1 (tuple): (hidden, g1, g2) group spec for the output
            1x1 conv.
        depthsep (bool): use DepthSpatialSepConv instead of DepthConv.
        shuffle (bool): insert channel shuffles between stages.
        activation_cfg (dict): keys 'dy' (y1, y2, y3 switches), 'ratio',
            'init_a', 'init_b' controlling the DYShiftMax activations.
    """

    def __init__(self,
                 inp,
                 oup,
                 kernel_size=3,
                 stride=1,
                 ch_exp=(2, 2),
                 ch_per_group=4,
                 groups_1x1=(1, 1),
                 depthsep=True,
                 shuffle=False,
                 activation_cfg=None):
        super(DYMicroBlock, self).__init__()

        # Residual add is only valid for stride-1, shape-preserving blocks.
        self.identity = stride == 1 and inp == oup

        y1, y2, y3 = activation_cfg['dy']
        act_reduction = 8 * activation_cfg['ratio']
        init_a = activation_cfg['init_a']
        init_b = activation_cfg['init_b']

        t1 = ch_exp
        gs1 = ch_per_group
        hidden_fft, g1, g2 = groups_1x1
        hidden_dim2 = inp * t1[0] * t1[1]

        if gs1[0] == 0:
            # Layout A: depthwise spatial-sep conv, activation, shuffle,
            # grouped 1x1 projection, output activation.
            self.layers = nn.Sequential(
                DepthSpatialSepConv(inp, t1, kernel_size, stride),
                DYShiftMax(
                    hidden_dim2,
                    hidden_dim2,
                    act_max=2.0,
                    act_relu=True if y2 == 2 else False,
                    init_a=init_a,
                    reduction=act_reduction,
                    init_b=init_b,
                    g=gs1,
                    expansion=False) if y2 > 0 else nn.ReLU6(),
                ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
                ChannelShuffle(hidden_dim2 // 2)
                if shuffle and y2 != 0 else nn.Sequential(),
                GroupConv(hidden_dim2, oup, (g1, g2)),
                DYShiftMax(
                    oup,
                    oup,
                    act_max=2.0,
                    act_relu=False,
                    init_a=[1.0, 0.0],
                    reduction=act_reduction // 2,
                    init_b=[0.0, 0.0],
                    g=(g1, g2),
                    expansion=False) if y3 > 0 else nn.Sequential(),
                ChannelShuffle(g2) if shuffle else nn.Sequential(),
                ChannelShuffle(oup // 2)
                if shuffle and oup % 2 == 0 and y3 != 0 else nn.Sequential(), )
        elif g2 == 0:
            # Layout B: single grouped 1x1 expansion + optional activation.
            self.layers = nn.Sequential(
                GroupConv(inp, hidden_dim2, gs1),
                DYShiftMax(
                    hidden_dim2,
                    hidden_dim2,
                    act_max=2.0,
                    act_relu=False,
                    init_a=[1.0, 0.0],
                    reduction=act_reduction,
                    init_b=[0.0, 0.0],
                    g=gs1,
                    expansion=False) if y3 > 0 else nn.Sequential(), )
        else:
            # Layout C: grouped 1x1 expansion, depthwise conv, grouped
            # 1x1 projection, with activations/shuffles between stages.
            self.layers = nn.Sequential(
                GroupConv(inp, hidden_dim2, gs1),
                DYShiftMax(
                    hidden_dim2,
                    hidden_dim2,
                    act_max=2.0,
                    act_relu=True if y1 == 2 else False,
                    init_a=init_a,
                    reduction=act_reduction,
                    init_b=init_b,
                    g=gs1,
                    expansion=False) if y1 > 0 else nn.ReLU6(),
                ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
                DepthSpatialSepConv(hidden_dim2, (1, 1), kernel_size, stride)
                if depthsep else
                DepthConv(hidden_dim2, hidden_dim2, kernel_size, stride),
                nn.Sequential(),
                DYShiftMax(
                    hidden_dim2,
                    hidden_dim2,
                    act_max=2.0,
                    act_relu=True if y2 == 2 else False,
                    init_a=init_a,
                    reduction=act_reduction,
                    init_b=init_b,
                    g=gs1,
                    expansion=True) if y2 > 0 else nn.ReLU6(),
                ChannelShuffle(hidden_dim2 // 4)
                if shuffle and y1 != 0 and y2 != 0 else nn.Sequential()
                if y1 == 0 and y2 == 0 else ChannelShuffle(hidden_dim2 // 2),
                GroupConv(hidden_dim2, oup, (g1, g2)),
                DYShiftMax(
                    oup,
                    oup,
                    act_max=2.0,
                    act_relu=False,
                    init_a=[1.0, 0.0],
                    reduction=act_reduction // 2
                    if oup < hidden_dim2 else act_reduction,
                    init_b=[0.0, 0.0],
                    g=(g1, g2),
                    expansion=False) if y3 > 0 else nn.Sequential(),
                ChannelShuffle(g2) if shuffle else nn.Sequential(),
                ChannelShuffle(oup // 2)
                if shuffle and y3 != 0 else nn.Sequential(), )

    def forward(self, x):
        identity = x
        out = self.layers(x)

        # Residual connection for shape-preserving blocks.
        if self.identity:
            out = out + identity

        return out
|
||||
|
||||
|
||||
class MicroNet(nn.Layer):
    """
    the MicroNet backbone network for recognition module.
    Args:
        mode(str): {'M0', 'M1', 'M2', 'M3'}
            Four models are proposed based on four different computational costs (4M, 6M, 12M, 21M MAdds)
            Default: 'M3'.
    """

    def __init__(self, mode='M3', **kwargs):
        super(MicroNet, self).__init__()

        self.cfgs = get_micronet_config(mode)

        activation_cfg = {}
        # Per-variant stem width/groups, final channel count, and
        # DYShiftMax initial coefficients.
        if mode == 'M0':
            input_channel = 4
            stem_groups = 2, 2
            out_ch = 384
            activation_cfg['init_a'] = 1.0, 1.0
            activation_cfg['init_b'] = 0.0, 0.0
        elif mode == 'M1':
            input_channel = 6
            stem_groups = 3, 2
            out_ch = 576
            activation_cfg['init_a'] = 1.0, 1.0
            activation_cfg['init_b'] = 0.0, 0.0
        elif mode == 'M2':
            input_channel = 8
            stem_groups = 4, 2
            out_ch = 768
            activation_cfg['init_a'] = 1.0, 1.0
            activation_cfg['init_b'] = 0.0, 0.0
        elif mode == 'M3':
            input_channel = 12
            stem_groups = 4, 3
            out_ch = 432
            activation_cfg['init_a'] = 1.0, 0.5
            activation_cfg['init_b'] = 0.0, 0.5
        else:
            raise NotImplementedError("mode[" + mode +
                                      "_model] is not implemented!")

        # Stem always takes 3-channel (RGB) input.
        layers = [StemLayer(3, input_channel, stride=2, groups=stem_groups)]

        for idx, val in enumerate(self.cfgs):
            # Unpack one config row (see the M*_cfgs legend).
            s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r = val

            t1 = (c1, c2)
            gs1 = (g1, g2)
            gs2 = (c3, g3, g4)
            activation_cfg['dy'] = [y1, y2, y3]
            activation_cfg['ratio'] = r

            output_channel = c
            # First block of the stage carries the stride s ...
            layers.append(
                DYMicroBlock(
                    input_channel,
                    output_channel,
                    kernel_size=ks,
                    stride=s,
                    ch_exp=t1,
                    ch_per_group=gs1,
                    groups_1x1=gs2,
                    depthsep=True,
                    shuffle=True,
                    activation_cfg=activation_cfg, ))
            input_channel = output_channel
            # ... remaining n-1 blocks keep stride 1.
            for i in range(1, n):
                layers.append(
                    DYMicroBlock(
                        input_channel,
                        output_channel,
                        kernel_size=ks,
                        stride=1,
                        ch_exp=t1,
                        ch_per_group=gs1,
                        groups_1x1=gs2,
                        depthsep=True,
                        shuffle=True,
                        activation_cfg=activation_cfg, ))
                input_channel = output_channel
        self.features = nn.Sequential(*layers)

        # Final 2x2 max pool halves the spatial resolution.
        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)

        self.out_channels = make_divisible(out_ch)

    def forward(self, x):
        """Run the stem + micro-blocks, then the final pool."""
        x = self.features(x)
        x = self.pool(x)
        return x
|
||||
@@ -1,138 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from paddle import nn
|
||||
|
||||
from ppocr.modeling.backbones.det_mobilenet_v3 import ResidualUnit, ConvBNLayer, make_divisible
|
||||
|
||||
__all__ = ['MobileNetV3']
|
||||
|
||||
|
||||
class MobileNetV3(nn.Layer):
    """MobileNetV3 backbone for text recognition.

    Strides in the residual stages are (s, 1) tuples so the feature map
    is downsampled in height but (mostly) preserved in width, as text
    recognition requires.

    Args:
        in_channels (int): input image channels.
        model_name (str): 'large' or 'small' variant.
        scale (float): width multiplier; must be one of
            {0.35, 0.5, 0.75, 1.0, 1.25}.
        large_stride (list[int] | None): per-stage strides for 'large'
            (defaults to [1, 2, 2, 2]).
        small_stride (list[int] | None): per-stage strides for 'small'
            (defaults to [2, 2, 2, 2]).
        disable_se (bool): drop all squeeze-excitation modules.
    """

    def __init__(self,
                 in_channels=3,
                 model_name='small',
                 scale=0.5,
                 large_stride=None,
                 small_stride=None,
                 disable_se=False,
                 **kwargs):
        super(MobileNetV3, self).__init__()
        self.disable_se = disable_se
        if small_stride is None:
            small_stride = [2, 2, 2, 2]
        if large_stride is None:
            large_stride = [1, 2, 2, 2]

        assert isinstance(large_stride, list), "large_stride type must " \
            "be list but got {}".format(type(large_stride))
        assert isinstance(small_stride, list), "small_stride type must " \
            "be list but got {}".format(type(small_stride))
        assert len(large_stride) == 4, "large_stride length must be " \
            "4 but got {}".format(len(large_stride))
        assert len(small_stride) == 4, "small_stride length must be " \
            "4 but got {}".format(len(small_stride))

        if model_name == "large":
            cfg = [
                # k, exp, c, se, nl, s,
                [3, 16, 16, False, 'relu', large_stride[0]],
                [3, 64, 24, False, 'relu', (large_stride[1], 1)],
                [3, 72, 24, False, 'relu', 1],
                [5, 72, 40, True, 'relu', (large_stride[2], 1)],
                [5, 120, 40, True, 'relu', 1],
                [5, 120, 40, True, 'relu', 1],
                [3, 240, 80, False, 'hardswish', 1],
                [3, 200, 80, False, 'hardswish', 1],
                [3, 184, 80, False, 'hardswish', 1],
                [3, 184, 80, False, 'hardswish', 1],
                [3, 480, 112, True, 'hardswish', 1],
                [3, 672, 112, True, 'hardswish', 1],
                [5, 672, 160, True, 'hardswish', (large_stride[3], 1)],
                [5, 960, 160, True, 'hardswish', 1],
                [5, 960, 160, True, 'hardswish', 1],
            ]
            cls_ch_squeeze = 960
        elif model_name == "small":
            cfg = [
                # k, exp, c, se, nl, s,
                [3, 16, 16, True, 'relu', (small_stride[0], 1)],
                [3, 72, 24, False, 'relu', (small_stride[1], 1)],
                [3, 88, 24, False, 'relu', 1],
                [5, 96, 40, True, 'hardswish', (small_stride[2], 1)],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 120, 48, True, 'hardswish', 1],
                [5, 144, 48, True, 'hardswish', 1],
                [5, 288, 96, True, 'hardswish', (small_stride[3], 1)],
                [5, 576, 96, True, 'hardswish', 1],
                [5, 576, 96, True, 'hardswish', 1],
            ]
            cls_ch_squeeze = 576
        else:
            raise NotImplementedError("mode[" + model_name +
                                      "_model] is not implemented!")

        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
        assert scale in supported_scale, \
            "supported scales are {} but input scale is {}".format(supported_scale, scale)

        inplanes = 16
        # conv1: 3x3 stride-2 stem.
        self.conv1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=make_divisible(inplanes * scale),
            kernel_size=3,
            stride=2,
            padding=1,
            groups=1,
            if_act=True,
            act='hardswish')
        i = 0
        block_list = []
        inplanes = make_divisible(inplanes * scale)
        for (k, exp, c, se, nl, s) in cfg:
            # SE can be globally disabled regardless of the cfg flag.
            se = se and not self.disable_se
            block_list.append(
                ResidualUnit(
                    in_channels=inplanes,
                    mid_channels=make_divisible(scale * exp),
                    out_channels=make_divisible(scale * c),
                    kernel_size=k,
                    stride=s,
                    use_se=se,
                    act=nl))
            inplanes = make_divisible(scale * c)
            i += 1
        self.blocks = nn.Sequential(*block_list)

        # Final 1x1 conv widening to the squeeze channel count.
        self.conv2 = ConvBNLayer(
            in_channels=inplanes,
            out_channels=make_divisible(scale * cls_ch_squeeze),
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            if_act=True,
            act='hardswish')

        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.out_channels = make_divisible(scale * cls_ch_squeeze)

    def forward(self, x):
        """Stem conv -> residual stages -> 1x1 conv -> 2x2 max pool."""
        x = self.conv1(x)
        x = self.blocks(x)
        x = self.conv2(x)
        x = self.pool(x)
        return x
|
||||
@@ -1,256 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This code is refer from: https://github.com/PaddlePaddle/PaddleClas/blob/develop/ppcls/arch/backbone/legendary_models/pp_lcnet.py
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import ParamAttr, reshape, transpose
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
|
||||
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
|
||||
from paddle.nn.initializer import KaimingNormal
|
||||
from paddle.regularizer import L2Decay
|
||||
from paddle.nn.functional import hardswish, hardsigmoid
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
    """Conv2D + BatchNorm, with the activation fused into the BN layer.

    Args:
        num_channels (int): input channels.
        filter_size (int): convolution kernel size.
        num_filters (int): output channels.
        stride (int | tuple): convolution stride.
        padding (int): convolution padding.
        channels: unused; kept for call-site compatibility.
        num_groups (int): group count for the convolution.
        act (str): activation name passed to BatchNorm.
    """

    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 padding,
                 channels=None,
                 num_groups=1,
                 act='hard_swish'):
        super(ConvBNLayer, self).__init__()

        self._conv = Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=ParamAttr(initializer=KaimingNormal()),
            bias_attr=False)

        # L2Decay(0.0) exempts BN scale/shift from weight decay.
        self._batch_norm = BatchNorm(
            num_filters,
            act=act,
            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y
|
||||
|
||||
|
||||
class DepthwiseSeparable(nn.Layer):
    """Depthwise separable conv: depthwise kxk conv, optional SE module,
    then a pointwise 1x1 conv.

    Args:
        num_channels (int): input channels.
        num_filters1 (int): depthwise conv output channels (pre-scale).
        num_filters2 (int): pointwise conv output channels (pre-scale).
        num_groups (int): depthwise group count (pre-scale).
        stride (int | tuple): depthwise conv stride.
        scale (float): width multiplier applied to all channel counts.
        dw_size (int): depthwise kernel size.
        padding (int): depthwise conv padding.
        use_se (bool): insert an SEModule after the depthwise conv.
            (SEModule is defined elsewhere in this module.)
    """

    def __init__(self,
                 num_channels,
                 num_filters1,
                 num_filters2,
                 num_groups,
                 stride,
                 scale,
                 dw_size=3,
                 padding=1,
                 use_se=False):
        super(DepthwiseSeparable, self).__init__()
        self.use_se = use_se
        self._depthwise_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=int(num_filters1 * scale),
            filter_size=dw_size,
            stride=stride,
            padding=padding,
            num_groups=int(num_groups * scale))
        if use_se:
            self._se = SEModule(int(num_filters1 * scale))
        self._pointwise_conv = ConvBNLayer(
            num_channels=int(num_filters1 * scale),
            filter_size=1,
            num_filters=int(num_filters2 * scale),
            stride=1,
            padding=0)

    def forward(self, inputs):
        y = self._depthwise_conv(inputs)
        if self.use_se:
            y = self._se(y)
        y = self._pointwise_conv(y)
        return y
|
||||
|
||||
|
||||
class MobileNetV1Enhance(nn.Layer):
    """Enhanced MobileNetV1 backbone for text recognition.

    Compared with vanilla MobileNetV1, the deeper stages use 5x5 depthwise
    kernels, SE modules, and (2, 1) strides so the feature map's width
    (time axis) is preserved while height is reduced — appropriate for
    sequence decoding.

    Args:
        in_channels: channel count of the input image (default 3).
        scale: width multiplier applied to every stage.
        last_conv_stride: stride of the final depthwise-separable block.
        last_pool_type: 'avg' for average pooling at the end, anything
            else selects max pooling.
    """

    def __init__(self,
                 in_channels=3,
                 scale=0.5,
                 last_conv_stride=1,
                 last_pool_type='max',
                 **kwargs):
        super().__init__()
        self.scale = scale
        self.block_list = []

        # Stem: 3x3 stride-2 conv.
        # Bug fix: the stem previously hard-coded num_channels=3, silently
        # ignoring the ``in_channels`` argument. Backward compatible since
        # the default is 3.
        self.conv1 = ConvBNLayer(
            num_channels=in_channels,
            filter_size=3,
            channels=3,
            num_filters=int(32 * scale),
            stride=2,
            padding=1)

        conv2_1 = DepthwiseSeparable(
            num_channels=int(32 * scale),
            num_filters1=32,
            num_filters2=64,
            num_groups=32,
            stride=1,
            scale=scale)
        self.block_list.append(conv2_1)

        conv2_2 = DepthwiseSeparable(
            num_channels=int(64 * scale),
            num_filters1=64,
            num_filters2=128,
            num_groups=64,
            stride=1,
            scale=scale)
        self.block_list.append(conv2_2)

        conv3_1 = DepthwiseSeparable(
            num_channels=int(128 * scale),
            num_filters1=128,
            num_filters2=128,
            num_groups=128,
            stride=1,
            scale=scale)
        self.block_list.append(conv3_1)

        # (2, 1) stride: halve height, keep width.
        conv3_2 = DepthwiseSeparable(
            num_channels=int(128 * scale),
            num_filters1=128,
            num_filters2=256,
            num_groups=128,
            stride=(2, 1),
            scale=scale)
        self.block_list.append(conv3_2)

        conv4_1 = DepthwiseSeparable(
            num_channels=int(256 * scale),
            num_filters1=256,
            num_filters2=256,
            num_groups=256,
            stride=1,
            scale=scale)
        self.block_list.append(conv4_1)

        conv4_2 = DepthwiseSeparable(
            num_channels=int(256 * scale),
            num_filters1=256,
            num_filters2=512,
            num_groups=256,
            stride=(2, 1),
            scale=scale)
        self.block_list.append(conv4_2)

        # Five 512-channel blocks with 5x5 depthwise kernels.
        for _ in range(5):
            conv5 = DepthwiseSeparable(
                num_channels=int(512 * scale),
                num_filters1=512,
                num_filters2=512,
                num_groups=512,
                stride=1,
                dw_size=5,
                padding=2,
                scale=scale,
                use_se=False)
            self.block_list.append(conv5)

        conv5_6 = DepthwiseSeparable(
            num_channels=int(512 * scale),
            num_filters1=512,
            num_filters2=1024,
            num_groups=512,
            stride=(2, 1),
            dw_size=5,
            padding=2,
            scale=scale,
            use_se=True)
        self.block_list.append(conv5_6)

        conv6 = DepthwiseSeparable(
            num_channels=int(1024 * scale),
            num_filters1=1024,
            num_filters2=1024,
            num_groups=1024,
            stride=last_conv_stride,
            dw_size=5,
            padding=2,
            use_se=True,
            scale=scale)
        self.block_list.append(conv6)

        self.block_list = nn.Sequential(*self.block_list)
        if last_pool_type == 'avg':
            self.pool = nn.AvgPool2D(kernel_size=2, stride=2, padding=0)
        else:
            self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.out_channels = int(1024 * scale)

    def forward(self, inputs):
        y = self.conv1(inputs)
        y = self.block_list(y)
        y = self.pool(y)
        return y
|
||||
|
||||
|
||||
class SEModule(nn.Layer):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools the input, passes it through a 1x1 bottleneck
    (``channel // reduction``), and gates the input channel-wise with a
    hard-sigmoid.
    """

    def __init__(self, channel, reduction=4):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2D(1)
        self.conv1 = Conv2D(
            in_channels=channel,
            out_channels=channel // reduction,
            kernel_size=1,
            stride=1,
            padding=0,
            weight_attr=ParamAttr(),
            bias_attr=ParamAttr())
        self.conv2 = Conv2D(
            in_channels=channel // reduction,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            padding=0,
            weight_attr=ParamAttr(),
            bias_attr=ParamAttr())

    def forward(self, inputs):
        # Squeeze spatial dims, excite through the bottleneck, gate input.
        attn = self.avg_pool(inputs)
        attn = F.relu(self.conv1(attn))
        attn = hardsigmoid(self.conv2(attn))
        return paddle.multiply(x=inputs, y=attn)
|
||||
@@ -1,48 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from paddle import nn
|
||||
import paddle
|
||||
|
||||
|
||||
class MTB(nn.Layer):
    """Small CNN front-end that turns an image into a feature sequence.

    When ``cnn_num == 2`` the input is downsampled twice by stride-2 convs
    and flattened along (height x channel); otherwise the (empty) block is
    an identity and the input passes through unchanged.
    """

    def __init__(self, cnn_num, in_channels):
        super(MTB, self).__init__()
        self.block = nn.Sequential()
        self.out_channels = in_channels
        self.cnn_num = cnn_num
        if self.cnn_num == 2:
            for idx in range(self.cnn_num):
                # First conv consumes the raw input channels; each later
                # conv doubles the width: 32, 64, ...
                prev = in_channels if idx == 0 else 32 * (2**(idx - 1))
                self.block.add_sublayer(
                    'conv_{}'.format(idx),
                    nn.Conv2D(
                        in_channels=prev,
                        out_channels=32 * (2**idx),
                        kernel_size=3,
                        stride=2,
                        padding=1))
                self.block.add_sublayer('relu_{}'.format(idx), nn.ReLU())
                self.block.add_sublayer('bn_{}'.format(idx),
                                        nn.BatchNorm2D(32 * (2**idx)))

    def forward(self, images):
        feats = self.block(images)
        if self.cnn_num == 2:
            # (b, c, h, w) -> (b, w, h, c), then merge h and c into one axis.
            feats = paddle.transpose(feats, [0, 3, 2, 1])
            shape = paddle.shape(feats)
            feats = paddle.reshape(
                feats, [shape[0], shape[1], shape[2] * shape[3]])
        return feats
|
||||
@@ -1,210 +0,0 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/layers/conv_layer.py
|
||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/backbones/resnet31_ocr.py
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["ResNet31"]
|
||||
|
||||
|
||||
def conv3x3(in_channel, out_channel, stride=1):
    """Return a bias-free 3x3 Conv2D with padding 1."""
    conv_kwargs = {
        "kernel_size": 3,
        "stride": stride,
        "padding": 1,
        "bias_attr": False,
    }
    return nn.Conv2D(in_channel, out_channel, **conv_kwargs)
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
    """Two-3x3-conv residual block used by ResNet31."""

    expansion = 1

    def __init__(self, in_channels, channels, stride=1, downsample=False):
        super().__init__()
        self.conv1 = conv3x3(in_channels, channels, stride)
        self.bn1 = nn.BatchNorm2D(channels)
        self.relu = nn.ReLU()
        self.conv2 = conv3x3(channels, channels)
        self.bn2 = nn.BatchNorm2D(channels)
        self.downsample = downsample
        if downsample:
            # Skip-path projection so shapes match the residual sum.
            self.downsample = nn.Sequential(
                nn.Conv2D(
                    in_channels,
                    channels * self.expansion,
                    1,
                    stride,
                    bias_attr=False),
                nn.BatchNorm2D(channels * self.expansion), )
        else:
            # An empty Sequential is falsy, so forward() skips projection.
            self.downsample = nn.Sequential()
        self.stride = stride

    def forward(self, x):
        shortcut = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if self.downsample:
            shortcut = self.downsample(x)

        out += shortcut
        return self.relu(out)
|
||||
|
||||
|
||||
class ResNet31(nn.Layer):
    '''
    ResNet31 OCR backbone: a two-conv stem followed by four stages, each
    being (optional max-pool) -> residual blocks -> conv-bn-relu.

    Args:
        in_channels (int): Number of channels of input image tensor.
        layers (list[int]): List of BasicBlock number for each stage.
        channels (list[int]): List of out_channels of Conv2d layer.
        out_indices (None | Sequence[int]): Indices of output stages.
        last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
    '''

    def __init__(self,
                 in_channels=3,
                 layers=[1, 2, 5, 3],
                 channels=[64, 128, 256, 256, 512, 512, 512],
                 out_indices=None,
                 last_stage_pool=False):
        super(ResNet31, self).__init__()
        assert isinstance(in_channels, int)
        assert isinstance(last_stage_pool, bool)

        self.out_indices = out_indices
        self.last_stage_pool = last_stage_pool

        # conv 1 (Conv Conv) — plain stem without residuals.
        self.conv1_1 = nn.Conv2D(
            in_channels, channels[0], kernel_size=3, stride=1, padding=1)
        self.bn1_1 = nn.BatchNorm2D(channels[0])
        self.relu1_1 = nn.ReLU()

        self.conv1_2 = nn.Conv2D(
            channels[0], channels[1], kernel_size=3, stride=1, padding=1)
        self.bn1_2 = nn.BatchNorm2D(channels[1])
        self.relu1_2 = nn.ReLU()

        # conv 2 (Max-pooling, Residual block, Conv)
        self.pool2 = nn.MaxPool2D(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self.block2 = self._make_layer(channels[1], channels[2], layers[0])
        self.conv2 = nn.Conv2D(
            channels[2], channels[2], kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2D(channels[2])
        self.relu2 = nn.ReLU()

        # conv 3 (Max-pooling, Residual block, Conv)
        self.pool3 = nn.MaxPool2D(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self.block3 = self._make_layer(channels[2], channels[3], layers[1])
        self.conv3 = nn.Conv2D(
            channels[3], channels[3], kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2D(channels[3])
        self.relu3 = nn.ReLU()

        # conv 4 (Max-pooling, Residual block, Conv)
        # (2, 1) pooling halves height only, preserving the width (time)
        # axis for downstream sequence decoding.
        self.pool4 = nn.MaxPool2D(
            kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
        self.block4 = self._make_layer(channels[3], channels[4], layers[2])
        self.conv4 = nn.Conv2D(
            channels[4], channels[4], kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2D(channels[4])
        self.relu4 = nn.ReLU()

        # conv 5 ((Max-pooling), Residual block, Conv) — pooling optional.
        self.pool5 = None
        if self.last_stage_pool:
            self.pool5 = nn.MaxPool2D(
                kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self.block5 = self._make_layer(channels[4], channels[5], layers[3])
        self.conv5 = nn.Conv2D(
            channels[5], channels[5], kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2D(channels[5])
        self.relu5 = nn.ReLU()

        self.out_channels = channels[-1]

    def _make_layer(self, input_channels, output_channels, blocks):
        """Stack ``blocks`` BasicBlocks; only the first gets a 1x1
        projection (when the channel count changes)."""
        layers = []
        for _ in range(blocks):
            downsample = None
            if input_channels != output_channels:
                downsample = nn.Sequential(
                    nn.Conv2D(
                        input_channels,
                        output_channels,
                        kernel_size=1,
                        stride=1,
                        bias_attr=False),
                    nn.BatchNorm2D(output_channels), )

            layers.append(
                BasicBlock(
                    input_channels, output_channels, downsample=downsample))
            input_channels = output_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem: two plain conv-bn-relu layers.
        x = self.conv1_1(x)
        x = self.bn1_1(x)
        x = self.relu1_1(x)

        x = self.conv1_2(x)
        x = self.bn1_2(x)
        x = self.relu1_2(x)

        outs = []
        # Stages 2..5, fetched reflectively by index:
        # (optional pool) -> residual blocks -> conv-bn-relu.
        for i in range(4):
            layer_index = i + 2
            pool_layer = getattr(self, f'pool{layer_index}')
            block_layer = getattr(self, f'block{layer_index}')
            conv_layer = getattr(self, f'conv{layer_index}')
            bn_layer = getattr(self, f'bn{layer_index}')
            relu_layer = getattr(self, f'relu{layer_index}')

            if pool_layer is not None:
                x = pool_layer(x)
            x = block_layer(x)
            x = conv_layer(x)
            x = bn_layer(x)
            x = relu_layer(x)

            outs.append(x)

        if self.out_indices is not None:
            # Return the requested intermediate stage outputs.
            return tuple([outs[i] for i in self.out_indices])

        return x
|
||||
@@ -1,143 +0,0 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/resnet_aster.py
|
||||
"""
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
import sys
|
||||
import math
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2D(
        in_planes, out_planes, kernel_size=3, stride=stride,
        padding=1, bias_attr=False)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    proj_kwargs = {"kernel_size": 1, "stride": stride, "bias_attr": False}
    return nn.Conv2D(in_planes, out_planes, **proj_kwargs)
|
||||
|
||||
|
||||
def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
    """Build the [n_position, feat_dim] sinusoidal position-encoding table."""
    # [n_position] position indices and [feat_dim] per-dimension divisors.
    pos = paddle.arange(0, n_position)
    dim_range = paddle.arange(0, feat_dim)
    dim_range = paddle.pow(wave_length, 2 * (dim_range // 2) / feat_dim)
    # Outer division -> [n_position, feat_dim] angle matrix.
    table = paddle.unsqueeze(
        pos, axis=1) / paddle.unsqueeze(
            dim_range, axis=0)
    table = paddle.cast(table, "float32")
    table[:, 0::2] = paddle.sin(table[:, 0::2])  # even dims: sine
    table[:, 1::2] = paddle.cos(table[:, 1::2])  # odd dims: cosine
    return table
|
||||
|
||||
|
||||
class AsterBlock(nn.Layer):
    """ASTER residual block: 1x1 conv (with stride) followed by 3x3 conv."""

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(AsterBlock, self).__init__()
        self.conv1 = conv1x1(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2D(planes)
        self.relu = nn.ReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2D(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Project the skip path only when a downsample module was supplied.
        if self.downsample is not None:
            shortcut = self.downsample(x)
        out += shortcut
        return self.relu(out)
|
||||
|
||||
|
||||
class ResNet_ASTER(nn.Layer):
    """For aster or crnn

    A five-stage residual CNN whose later strides are (2, 1) so height
    collapses to 1 while width is kept as the sequence axis; optionally
    followed by a two-layer bidirectional LSTM.
    """

    def __init__(self, with_lstm=True, n_group=1, in_channels=3):
        super(ResNet_ASTER, self).__init__()
        self.with_lstm = with_lstm
        self.n_group = n_group

        # Stem: 3x3 conv-bn-relu, stride 1, bias-free.
        self.layer0 = nn.Sequential(
            nn.Conv2D(
                in_channels,
                32,
                kernel_size=(3, 3),
                stride=1,
                padding=1,
                bias_attr=False),
            nn.BatchNorm2D(32),
            nn.ReLU())

        self.inplanes = 32
        # Stride pairs are (h, w); trailing comments give the expected
        # feature-map size for a typical recognition input.
        self.layer1 = self._make_layer(32, 3, [2, 2])  # [16, 50]
        self.layer2 = self._make_layer(64, 4, [2, 2])  # [8, 25]
        self.layer3 = self._make_layer(128, 6, [2, 1])  # [4, 25]
        self.layer4 = self._make_layer(256, 6, [2, 1])  # [2, 25]
        self.layer5 = self._make_layer(512, 3, [2, 1])  # [1, 25]

        if with_lstm:
            # Bidirectional LSTM doubles the channel count (2 * 256).
            self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2)
            self.out_channels = 2 * 256
        else:
            self.out_channels = 512

    def _make_layer(self, planes, blocks, stride):
        # Only the first block in a stage may change shape, so only it
        # receives a 1x1 projection on the skip path.
        downsample = None
        if stride != [1, 1] or self.inplanes != planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))

        layers = []
        layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(AsterBlock(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x0 = self.layer0(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x5 = self.layer5(x4)

        # Height is 1 after layer5; squeeze it and put width first.
        cnn_feat = x5.squeeze(2)  # [N, c, w]
        cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1])
        if self.with_lstm:
            rnn_feat, _ = self.rnn(cnn_feat)
            return rnn_feat
        else:
            return cnn_feat
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user