mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-24 03:14:41 +08:00
init
This commit is contained in:
6
backend/ppocr/postprocess/pse_postprocess/pse/README.md
Normal file
6
backend/ppocr/postprocess/pse_postprocess/pse/README.md
Normal file
@@ -0,0 +1,6 @@
|
||||
## 编译
|
||||
This code is refer from:
|
||||
https://github.com/whai362/PSENet/blob/python3/models/post_processing/pse
|
||||
```python
|
||||
python3 setup.py build_ext --inplace
|
||||
```
|
||||
29
backend/ppocr/postprocess/pse_postprocess/pse/__init__.py
Normal file
29
backend/ppocr/postprocess/pse_postprocess/pse/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
python_path = sys.executable
|
||||
|
||||
ori_path = os.getcwd()
|
||||
os.chdir('ppocr/postprocess/pse_postprocess/pse')
|
||||
if subprocess.call(
|
||||
'{} setup.py build_ext --inplace'.format(python_path), shell=True) != 0:
|
||||
raise RuntimeError(
|
||||
'Cannot compile pse: {}, if your system is windows, you need to install all the default components of `desktop development using C++` in visual studio 2019+'.
|
||||
format(os.path.dirname(os.path.realpath(__file__))))
|
||||
os.chdir(ori_path)
|
||||
|
||||
from .pse import pse
|
||||
70
backend/ppocr/postprocess/pse_postprocess/pse/pse.pyx
Normal file
70
backend/ppocr/postprocess/pse_postprocess/pse/pse.pyx
Normal file
@@ -0,0 +1,70 @@
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
cimport numpy as np
|
||||
cimport cython
|
||||
cimport libcpp
|
||||
cimport libcpp.pair
|
||||
cimport libcpp.queue
|
||||
from libcpp.pair cimport *
|
||||
from libcpp.queue cimport *
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
|
||||
np.ndarray[np.int32_t, ndim=2] label,
|
||||
int kernel_num,
|
||||
int label_num,
|
||||
float min_area=0):
|
||||
cdef np.ndarray[np.int32_t, ndim=2] pred
|
||||
pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)
|
||||
|
||||
for label_idx in range(1, label_num):
|
||||
if np.sum(label == label_idx) < min_area:
|
||||
label[label == label_idx] = 0
|
||||
|
||||
cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \
|
||||
queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
|
||||
cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \
|
||||
queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
|
||||
cdef np.int16_t* dx = [-1, 1, 0, 0]
|
||||
cdef np.int16_t* dy = [0, 0, -1, 1]
|
||||
cdef np.int16_t tmpx, tmpy
|
||||
|
||||
points = np.array(np.where(label > 0)).transpose((1, 0))
|
||||
for point_idx in range(points.shape[0]):
|
||||
tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
|
||||
que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
|
||||
pred[tmpx, tmpy] = label[tmpx, tmpy]
|
||||
|
||||
cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur
|
||||
cdef int cur_label
|
||||
for kernel_idx in range(kernel_num - 1, -1, -1):
|
||||
while not que.empty():
|
||||
cur = que.front()
|
||||
que.pop()
|
||||
cur_label = pred[cur.first, cur.second]
|
||||
|
||||
is_edge = True
|
||||
for j in range(4):
|
||||
tmpx = cur.first + dx[j]
|
||||
tmpy = cur.second + dy[j]
|
||||
if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
|
||||
continue
|
||||
if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
|
||||
continue
|
||||
|
||||
que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
|
||||
pred[tmpx, tmpy] = cur_label
|
||||
is_edge = False
|
||||
if is_edge:
|
||||
nxt_que.push(cur)
|
||||
|
||||
que, nxt_que = nxt_que, que
|
||||
|
||||
return pred
|
||||
|
||||
def pse(kernels, min_area):
|
||||
kernel_num = kernels.shape[0]
|
||||
label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
|
||||
return _pse(kernels[:-1], label, kernel_num, label_num, min_area)
|
||||
14
backend/ppocr/postprocess/pse_postprocess/pse/setup.py
Normal file
14
backend/ppocr/postprocess/pse_postprocess/pse/setup.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from distutils.core import setup, Extension
|
||||
from Cython.Build import cythonize
|
||||
import numpy
|
||||
|
||||
setup(ext_modules=cythonize(Extension(
|
||||
'pse',
|
||||
sources=['pse.pyx'],
|
||||
language='c++',
|
||||
include_dirs=[numpy.get_include()],
|
||||
library_dirs=[],
|
||||
libraries=[],
|
||||
extra_compile_args=['-O3'],
|
||||
extra_link_args=[]
|
||||
)))
|
||||
Reference in New Issue
Block a user