This commit is contained in:
YaoFANGUK
2023-10-25 16:38:16 +08:00
commit 2b9360c299
602 changed files with 152490 additions and 0 deletions

View File

@@ -0,0 +1,61 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import copy
__all__ = ['build_post_process']
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
def build_post_process(config, global_config=None):
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'FCEPostProcess',
'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode'
]
if config['name'] == 'PSEPostProcess':
from .pse_postprocess import PSEPostProcess
support_dict.append('PSEPostProcess')
config = copy.deepcopy(config)
module_name = config.pop('name')
if module_name == "None":
return
if global_config is not None:
config.update(global_config)
assert module_name in support_dict, Exception(
'post process only support {}'.format(support_dict))
module_class = eval(module_name)(**config)
return module_class

View File

@@ -0,0 +1,42 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
class ClsPostProcess(object):
""" Convert between text-label and text-index """
def __init__(self, label_list=None, key=None, **kwargs):
super(ClsPostProcess, self).__init__()
self.label_list = label_list
self.key = key
def __call__(self, preds, label=None, *args, **kwargs):
if self.key is not None:
preds = preds[self.key]
label_list = self.label_list
if label_list is None:
label_list = {idx: idx for idx in range(preds.shape[-1])}
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
pred_idxs = preds.argmax(axis=1)
decode_out = [(label_list[idx], preds[i, idx])
for i, idx in enumerate(pred_idxs)]
if label is None:
return decode_out
label = [(label_list[idx], 1.0) for idx in label]
return decode_out, label

View File

@@ -0,0 +1,220 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refered from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import cv2
import paddle
from shapely.geometry import Polygon
import pyclipper
class DBPostProcess(object):
"""
The post process for Differentiable Binarization (DB).
"""
def __init__(self,
thresh=0.3,
box_thresh=0.7,
max_candidates=1000,
unclip_ratio=2.0,
use_dilation=False,
score_mode="fast",
**kwargs):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio
self.min_size = 3
self.score_mode = score_mode
assert score_mode in [
"slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]])
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]
num_contours = min(len(contours), self.max_candidates)
boxes = []
scores = []
for index in range(num_contours):
contour = contours[index]
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
if self.score_mode == "fast":
score = self.box_score_fast(pred, points.reshape(-1, 2))
else:
score = self.box_score_slow(pred, contour)
if self.box_thresh > score:
continue
box = self.unclip(points).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.astype(np.int16))
scores.append(score)
return np.array(boxes, dtype=np.int16), scores
def unclip(self, box):
unclip_ratio = self.unclip_ratio
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [
points[index_1], points[index_2], points[index_3], points[index_4]
]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
'''
box_score_fast: use bbox mean score as the mean score
'''
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
'''
box_score_slow: use polyon mean score as the mean score
'''
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
if self.dilation_kernel is not None:
mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8),
self.dilation_kernel)
else:
mask = segmentation[batch_index]
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h)
boxes_batch.append({'points': boxes})
return boxes_batch
class DistillationDBPostProcess(object):
def __init__(self,
model_name=["student"],
key=None,
thresh=0.3,
box_thresh=0.6,
max_candidates=1000,
unclip_ratio=1.5,
use_dilation=False,
score_mode="fast",
**kwargs):
self.model_name = model_name
self.key = key
self.post_process = DBPostProcess(
thresh=thresh,
box_thresh=box_thresh,
max_candidates=max_candidates,
unclip_ratio=unclip_ratio,
use_dilation=use_dilation,
score_mode=score_mode)
def __call__(self, predicts, shape_list):
results = {}
for k in self.model_name:
results[k] = self.post_process(predicts[k], shape_list=shape_list)
return results

View File

@@ -0,0 +1,143 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from .locality_aware_nms import nms_locality
import cv2
import paddle
import os
import sys
class EASTPostProcess(object):
"""
The post process for EAST.
"""
def __init__(self,
score_thresh=0.8,
cover_thresh=0.1,
nms_thresh=0.2,
**kwargs):
self.score_thresh = score_thresh
self.cover_thresh = cover_thresh
self.nms_thresh = nms_thresh
def restore_rectangle_quad(self, origin, geometry):
"""
Restore rectangle from quadrangle.
"""
# quad
origin_concat = np.concatenate(
(origin, origin, origin, origin), axis=1) # (n, 8)
pred_quads = origin_concat - geometry
pred_quads = pred_quads.reshape((-1, 4, 2)) # (n, 4, 2)
return pred_quads
def detect(self,
score_map,
geo_map,
score_thresh=0.8,
cover_thresh=0.1,
nms_thresh=0.2):
"""
restore text boxes from score map and geo map
"""
score_map = score_map[0]
geo_map = np.swapaxes(geo_map, 1, 0)
geo_map = np.swapaxes(geo_map, 1, 2)
# filter the score map
xy_text = np.argwhere(score_map > score_thresh)
if len(xy_text) == 0:
return []
# sort the text boxes via the y axis
xy_text = xy_text[np.argsort(xy_text[:, 0])]
#restore quad proposals
text_box_restored = self.restore_rectangle_quad(
xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :])
boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
boxes[:, :8] = text_box_restored.reshape((-1, 8))
boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
try:
import lanms
boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
except:
print(
'you should install lanms by pip3 install lanms-nova to speed up nms_locality'
)
boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
if boxes.shape[0] == 0:
return []
# Here we filter some low score boxes by the average score map,
# this is different from the orginal paper.
for i, box in enumerate(boxes):
mask = np.zeros_like(score_map, dtype=np.uint8)
cv2.fillPoly(mask, box[:8].reshape(
(-1, 4, 2)).astype(np.int32) // 4, 1)
boxes[i, 8] = cv2.mean(score_map, mask)[0]
boxes = boxes[boxes[:, 8] > cover_thresh]
return boxes
def sort_poly(self, p):
"""
Sort polygons.
"""
min_axis = np.argmin(np.sum(p, axis=1))
p = p[[min_axis, (min_axis + 1) % 4,\
(min_axis + 2) % 4, (min_axis + 3) % 4]]
if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
return p
else:
return p[[0, 3, 2, 1]]
def __call__(self, outs_dict, shape_list):
score_list = outs_dict['f_score']
geo_list = outs_dict['f_geo']
if isinstance(score_list, paddle.Tensor):
score_list = score_list.numpy()
geo_list = geo_list.numpy()
img_num = len(shape_list)
dt_boxes_list = []
for ino in range(img_num):
score = score_list[ino]
geo = geo_list[ino]
boxes = self.detect(
score_map=score,
geo_map=geo,
score_thresh=self.score_thresh,
cover_thresh=self.cover_thresh,
nms_thresh=self.nms_thresh)
boxes_norm = []
if len(boxes) > 0:
h, w = score.shape[1:]
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
boxes = boxes[:, :8].reshape((-1, 4, 2))
boxes[:, :, 0] /= ratio_w
boxes[:, :, 1] /= ratio_h
for i_box, box in enumerate(boxes):
box = self.sort_poly(box.astype(np.int32))
if np.linalg.norm(box[0] - box[1]) < 5 \
or np.linalg.norm(box[3] - box[0]) < 5:
continue
boxes_norm.append(box)
dt_boxes_list.append({'points': np.array(boxes_norm)})
return dt_boxes_list

View File

@@ -0,0 +1,241 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/v0.3.0/mmocr/models/textdet/postprocess/wrapper.py
"""
import cv2
import paddle
import numpy as np
from numpy.fft import ifft
from ppocr.utils.poly_nms import poly_nms, valid_boundary
def fill_hole(input_mask):
h, w = input_mask.shape
canvas = np.zeros((h + 2, w + 2), np.uint8)
canvas[1:h + 1, 1:w + 1] = input_mask.copy()
mask = np.zeros((h + 4, w + 4), np.uint8)
cv2.floodFill(canvas, mask, (0, 0), 1)
canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool)
return ~canvas | input_mask
def fourier2poly(fourier_coeff, num_reconstr_points=50):
""" Inverse Fourier transform
Args:
fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1),
with n and k being candidates number and Fourier degree
respectively.
num_reconstr_points (int): Number of reconstructed polygon points.
Returns:
Polygons (ndarray): The reconstructed polygons shaped (n, n')
"""
a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype='complex')
k = (len(fourier_coeff[0]) - 1) // 2
a[:, 0:k + 1] = fourier_coeff[:, k:]
a[:, -k:] = fourier_coeff[:, :k]
poly_complex = ifft(a) * num_reconstr_points
polygon = np.zeros((len(fourier_coeff), num_reconstr_points, 2))
polygon[:, :, 0] = poly_complex.real
polygon[:, :, 1] = poly_complex.imag
return polygon.astype('int32').reshape((len(fourier_coeff), -1))
class FCEPostProcess(object):
"""
The post process for FCENet.
"""
def __init__(self,
scales,
fourier_degree=5,
num_reconstr_points=50,
decoding_type='fcenet',
score_thr=0.3,
nms_thr=0.1,
alpha=1.0,
beta=1.0,
box_type='poly',
**kwargs):
self.scales = scales
self.fourier_degree = fourier_degree
self.num_reconstr_points = num_reconstr_points
self.decoding_type = decoding_type
self.score_thr = score_thr
self.nms_thr = nms_thr
self.alpha = alpha
self.beta = beta
self.box_type = box_type
def __call__(self, preds, shape_list):
score_maps = []
for key, value in preds.items():
if isinstance(value, paddle.Tensor):
value = value.numpy()
cls_res = value[:, :4, :, :]
reg_res = value[:, 4:, :, :]
score_maps.append([cls_res, reg_res])
return self.get_boundary(score_maps, shape_list)
def resize_boundary(self, boundaries, scale_factor):
"""Rescale boundaries via scale_factor.
Args:
boundaries (list[list[float]]): The boundary list. Each boundary
with size 2k+1 with k>=4.
scale_factor(ndarray): The scale factor of size (4,).
Returns:
boundaries (list[list[float]]): The scaled boundaries.
"""
boxes = []
scores = []
for b in boundaries:
sz = len(b)
valid_boundary(b, True)
scores.append(b[-1])
b = (np.array(b[:sz - 1]) *
(np.tile(scale_factor[:2], int(
(sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
boxes.append(np.array(b).reshape([-1, 2]))
return np.array(boxes, dtype=np.float32), scores
def get_boundary(self, score_maps, shape_list):
assert len(score_maps) == len(self.scales)
boundaries = []
for idx, score_map in enumerate(score_maps):
scale = self.scales[idx]
boundaries = boundaries + self._get_boundary_single(score_map,
scale)
# nms
boundaries = poly_nms(boundaries, self.nms_thr)
boundaries, scores = self.resize_boundary(
boundaries, (1 / shape_list[0, 2:]).tolist()[::-1])
boxes_batch = [dict(points=boundaries, scores=scores)]
return boxes_batch
def _get_boundary_single(self, score_map, scale):
assert len(score_map) == 2
assert score_map[1].shape[1] == 4 * self.fourier_degree + 2
return self.fcenet_decode(
preds=score_map,
fourier_degree=self.fourier_degree,
num_reconstr_points=self.num_reconstr_points,
scale=scale,
alpha=self.alpha,
beta=self.beta,
box_type=self.box_type,
score_thr=self.score_thr,
nms_thr=self.nms_thr)
def fcenet_decode(self,
preds,
fourier_degree,
num_reconstr_points,
scale,
alpha=1.0,
beta=2.0,
box_type='poly',
score_thr=0.3,
nms_thr=0.1):
"""Decoding predictions of FCENet to instances.
Args:
preds (list(Tensor)): The head output tensors.
fourier_degree (int): The maximum Fourier transform degree k.
num_reconstr_points (int): The points number of the polygon
reconstructed from predicted Fourier coefficients.
scale (int): The down-sample scale of the prediction.
alpha (float) : The parameter to calculate final scores. Score_{final}
= (Score_{text region} ^ alpha)
* (Score_{text center region}^ beta)
beta (float) : The parameter to calculate final score.
box_type (str): Boundary encoding type 'poly' or 'quad'.
score_thr (float) : The threshold used to filter out the final
candidates.
nms_thr (float) : The threshold of nms.
Returns:
boundaries (list[list[float]]): The instance boundary and confidence
list.
"""
assert isinstance(preds, list)
assert len(preds) == 2
assert box_type in ['poly', 'quad']
cls_pred = preds[0][0]
tr_pred = cls_pred[0:2]
tcl_pred = cls_pred[2:]
reg_pred = preds[1][0].transpose([1, 2, 0])
x_pred = reg_pred[:, :, :2 * fourier_degree + 1]
y_pred = reg_pred[:, :, 2 * fourier_degree + 1:]
score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta)
tr_pred_mask = (score_pred) > score_thr
tr_mask = fill_hole(tr_pred_mask)
tr_contours, _ = cv2.findContours(
tr_mask.astype(np.uint8), cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE) # opencv4
mask = np.zeros_like(tr_mask)
boundaries = []
for cont in tr_contours:
deal_map = mask.copy().astype(np.int8)
cv2.drawContours(deal_map, [cont], -1, 1, -1)
score_map = score_pred * deal_map
score_mask = score_map > 0
xy_text = np.argwhere(score_mask)
dxy = xy_text[:, 1] + xy_text[:, 0] * 1j
x, y = x_pred[score_mask], y_pred[score_mask]
c = x + y * 1j
c[:, fourier_degree] = c[:, fourier_degree] + dxy
c *= scale
polygons = fourier2poly(c, num_reconstr_points)
score = score_map[score_mask].reshape(-1, 1)
polygons = poly_nms(np.hstack((polygons, score)).tolist(), nms_thr)
boundaries = boundaries + polygons
boundaries = poly_nms(boundaries, nms_thr)
if box_type == 'quad':
new_boundaries = []
for boundary in boundaries:
poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32)
score = boundary[-1]
points = cv2.boxPoints(cv2.minAreaRect(poly))
points = np.int0(points)
new_boundaries.append(points.reshape(-1).tolist() + [score])
boundaries = new_boundaries
return boundaries

View File

@@ -0,0 +1,200 @@
"""
Locality aware nms.
This code is refered from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py
"""
import numpy as np
from shapely.geometry import Polygon
def intersection(g, p):
"""
Intersection.
"""
g = Polygon(g[:8].reshape((4, 2)))
p = Polygon(p[:8].reshape((4, 2)))
g = g.buffer(0)
p = p.buffer(0)
if not g.is_valid or not p.is_valid:
return 0
inter = Polygon(g).intersection(Polygon(p)).area
union = g.area + p.area - inter
if union == 0:
return 0
else:
return inter / union
def intersection_iog(g, p):
"""
Intersection_iog.
"""
g = Polygon(g[:8].reshape((4, 2)))
p = Polygon(p[:8].reshape((4, 2)))
if not g.is_valid or not p.is_valid:
return 0
inter = Polygon(g).intersection(Polygon(p)).area
#union = g.area + p.area - inter
union = p.area
if union == 0:
print("p_area is very small")
return 0
else:
return inter / union
def weighted_merge(g, p):
"""
Weighted merge.
"""
g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8])
g[8] = (g[8] + p[8])
return g
def standard_nms(S, thres):
"""
Standard nms.
"""
order = np.argsort(S[:, 8])[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
inds = np.where(ovr <= thres)[0]
order = order[inds + 1]
return S[keep]
def standard_nms_inds(S, thres):
"""
Standard nms, retun inds.
"""
order = np.argsort(S[:, 8])[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
inds = np.where(ovr <= thres)[0]
order = order[inds + 1]
return keep
def nms(S, thres):
"""
nms.
"""
order = np.argsort(S[:, 8])[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
inds = np.where(ovr <= thres)[0]
order = order[inds + 1]
return keep
def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2):
"""
soft_nms
:para boxes_in, N x 9 (coords + score)
:para threshould, eliminate cases min score(0.001)
:para Nt_thres, iou_threshi
:para sigma, gaussian weght
:method, linear or gaussian
"""
boxes = boxes_in.copy()
N = boxes.shape[0]
if N is None or N < 1:
return np.array([])
pos, maxpos = 0, 0
weight = 0.0
inds = np.arange(N)
tbox, sbox = boxes[0].copy(), boxes[0].copy()
for i in range(N):
maxscore = boxes[i, 8]
maxpos = i
tbox = boxes[i].copy()
ti = inds[i]
pos = i + 1
#get max box
while pos < N:
if maxscore < boxes[pos, 8]:
maxscore = boxes[pos, 8]
maxpos = pos
pos = pos + 1
#add max box as a detection
boxes[i, :] = boxes[maxpos, :]
inds[i] = inds[maxpos]
#swap
boxes[maxpos, :] = tbox
inds[maxpos] = ti
tbox = boxes[i].copy()
pos = i + 1
#NMS iteration
while pos < N:
sbox = boxes[pos].copy()
ts_iou_val = intersection(tbox, sbox)
if ts_iou_val > 0:
if method == 1:
if ts_iou_val > Nt_thres:
weight = 1 - ts_iou_val
else:
weight = 1
elif method == 2:
weight = np.exp(-1.0 * ts_iou_val**2 / sigma)
else:
if ts_iou_val > Nt_thres:
weight = 0
else:
weight = 1
boxes[pos, 8] = weight * boxes[pos, 8]
#if box score falls below thresold, discard the box by
#swaping last box update N
if boxes[pos, 8] < threshold:
boxes[pos, :] = boxes[N - 1, :]
inds[pos] = inds[N - 1]
N = N - 1
pos = pos - 1
pos = pos + 1
return boxes[:N]
def nms_locality(polys, thres=0.3):
"""
locality aware nms of EAST
:param polys: a N*9 numpy array. first 8 coordinates, then prob
:return: boxes after nms
"""
S = []
p = None
for g in polys:
if p is not None and intersection(g, p) > thres:
p = weighted_merge(g, p)
else:
if p is not None:
S.append(p)
p = g
if p is not None:
S.append(p)
if len(S) == 0:
return np.array([])
return standard_nms(np.array(S), thres)
if __name__ == '__main__':
# 343,350,448,135,474,143,369,359
print(
Polygon(np.array([[343, 350], [448, 135], [474, 143], [369, 359]]))
.area)

View File

@@ -0,0 +1,52 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
from ppocr.utils.e2e_utils.pgnet_pp_utils import PGNet_PostProcess
class PGPostProcess(object):
"""
The post process for PGNet.
"""
def __init__(self, character_dict_path, valid_set, score_thresh, mode,
**kwargs):
self.character_dict_path = character_dict_path
self.valid_set = valid_set
self.score_thresh = score_thresh
self.mode = mode
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
def __call__(self, outs_dict, shape_list):
post = PGNet_PostProcess(self.character_dict_path, self.valid_set,
self.score_thresh, outs_dict, shape_list)
if self.mode == 'fast':
data = post.pg_postprocess_fast()
else:
data = post.pg_postprocess_slow()
return data

View File

@@ -0,0 +1,15 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .pse_postprocess import PSEPostProcess

View File

@@ -0,0 +1,6 @@
## 编译
This code is refer from:
https://github.com/whai362/PSENet/blob/python3/models/post_processing/pse
```python
python3 setup.py build_ext --inplace
```

View File

@@ -0,0 +1,29 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import subprocess
python_path = sys.executable
ori_path = os.getcwd()
os.chdir('ppocr/postprocess/pse_postprocess/pse')
if subprocess.call(
'{} setup.py build_ext --inplace'.format(python_path), shell=True) != 0:
raise RuntimeError(
'Cannot compile pse: {}, if your system is windows, you need to install all the default components of `desktop development using C++` in visual studio 2019+'.
format(os.path.dirname(os.path.realpath(__file__))))
os.chdir(ori_path)
from .pse import pse

View File

@@ -0,0 +1,70 @@
import numpy as np
import cv2
cimport numpy as np
cimport cython
cimport libcpp
cimport libcpp.pair
cimport libcpp.queue
from libcpp.pair cimport *
from libcpp.queue cimport *
@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
np.ndarray[np.int32_t, ndim=2] label,
int kernel_num,
int label_num,
float min_area=0):
cdef np.ndarray[np.int32_t, ndim=2] pred
pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)
for label_idx in range(1, label_num):
if np.sum(label == label_idx) < min_area:
label[label == label_idx] = 0
cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \
queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \
queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
cdef np.int16_t* dx = [-1, 1, 0, 0]
cdef np.int16_t* dy = [0, 0, -1, 1]
cdef np.int16_t tmpx, tmpy
points = np.array(np.where(label > 0)).transpose((1, 0))
for point_idx in range(points.shape[0]):
tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
pred[tmpx, tmpy] = label[tmpx, tmpy]
cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur
cdef int cur_label
for kernel_idx in range(kernel_num - 1, -1, -1):
while not que.empty():
cur = que.front()
que.pop()
cur_label = pred[cur.first, cur.second]
is_edge = True
for j in range(4):
tmpx = cur.first + dx[j]
tmpy = cur.second + dy[j]
if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
continue
if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
continue
que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
pred[tmpx, tmpy] = cur_label
is_edge = False
if is_edge:
nxt_que.push(cur)
que, nxt_que = nxt_que, que
return pred
def pse(kernels, min_area):
kernel_num = kernels.shape[0]
label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
return _pse(kernels[:-1], label, kernel_num, label_num, min_area)

View File

@@ -0,0 +1,14 @@
from distutils.core import setup, Extension
from Cython.Build import cythonize
import numpy
setup(ext_modules=cythonize(Extension(
'pse',
sources=['pse.pyx'],
language='c++',
include_dirs=[numpy.get_include()],
library_dirs=[],
libraries=[],
extra_compile_args=['-O3'],
extra_link_args=[]
)))

View File

@@ -0,0 +1,118 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import cv2
import paddle
from paddle.nn import functional as F
from ppocr.postprocess.pse_postprocess.pse import pse
class PSEPostProcess(object):
"""
The post process for PSE.
"""
def __init__(self,
thresh=0.5,
box_thresh=0.85,
min_area=16,
box_type='quad',
scale=4,
**kwargs):
assert box_type in ['quad', 'poly'], 'Only quad and poly is supported'
self.thresh = thresh
self.box_thresh = box_thresh
self.min_area = min_area
self.box_type = box_type
self.scale = scale
def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
if not isinstance(pred, paddle.Tensor):
pred = paddle.to_tensor(pred)
pred = F.interpolate(
pred, scale_factor=4 // self.scale, mode='bilinear')
score = F.sigmoid(pred[:, 0, :, :])
kernels = (pred > self.thresh).astype('float32')
text_mask = kernels[:, 0, :, :]
kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask
score = score.numpy()
kernels = kernels.numpy().astype(np.uint8)
boxes_batch = []
for batch_index in range(pred.shape[0]):
boxes, scores = self.boxes_from_bitmap(score[batch_index],
kernels[batch_index],
shape_list[batch_index])
boxes_batch.append({'points': boxes, 'scores': scores})
return boxes_batch
def boxes_from_bitmap(self, score, kernels, shape):
label = pse(kernels, self.min_area)
return self.generate_box(score, label, shape)
def generate_box(self, score, label, shape):
src_h, src_w, ratio_h, ratio_w = shape
label_num = np.max(label) + 1
boxes = []
scores = []
for i in range(1, label_num):
ind = label == i
points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1]
if points.shape[0] < self.min_area:
label[ind] = 0
continue
score_i = np.mean(score[ind])
if score_i < self.box_thresh:
label[ind] = 0
continue
if self.box_type == 'quad':
rect = cv2.minAreaRect(points)
bbox = cv2.boxPoints(rect)
elif self.box_type == 'poly':
box_height = np.max(points[:, 1]) + 10
box_width = np.max(points[:, 0]) + 10
mask = np.zeros((box_height, box_width), np.uint8)
mask[points[:, 1], points[:, 0]] = 255
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE)
bbox = np.squeeze(contours[0], 1)
else:
raise NotImplementedError
bbox[:, 0] = np.clip(np.round(bbox[:, 0] / ratio_w), 0, src_w)
bbox[:, 1] = np.clip(np.round(bbox[:, 1] / ratio_h), 0, src_h)
boxes.append(bbox)
scores.append(score_i)
return boxes, scores

View File

@@ -0,0 +1,754 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle.nn import functional as F
import re
class BaseRecLabelDecode(object):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False):
self.beg_str = "sos"
self.end_str = "eos"
self.character_str = []
if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
else:
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str.append(line)
if use_space_char:
self.character_str.append(" ")
dict_character = list(self.character_str)
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
if is_remove_duplicate:
selection[1:] = text_index[batch_idx][1:] != text_index[
batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
char_list = [
self.character[text_id]
for text_id in text_index[batch_idx][selection]
]
if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
else:
conf_list = [1] * len(selection)
if len(conf_list) == 0:
conf_list = [0]
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def get_ignored_tokens(self):
return [0] # for ctc blank
class CTCLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None:
return text
label = self.decode(label)
return text, label
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character
class DistillationCTCLabelDecode(CTCLabelDecode):
"""
Convert
Convert between text-label and text-index
"""
def __init__(self,
character_dict_path=None,
use_space_char=False,
model_name=["student"],
key=None,
multi_head=False,
**kwargs):
super(DistillationCTCLabelDecode, self).__init__(character_dict_path,
use_space_char)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
self.key = key
self.multi_head = multi_head
def __call__(self, preds, label=None, *args, **kwargs):
output = dict()
for name in self.model_name:
pred = preds[name]
if self.key is not None:
pred = pred[self.key]
if self.multi_head and isinstance(pred, dict):
pred = pred['ctc']
output[name] = super().__call__(pred, label=label, *args, **kwargs)
return output
class NRTRLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
super(NRTRLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if len(preds) == 2:
preds_id = preds[0]
preds_prob = preds[1]
if isinstance(preds_id, paddle.Tensor):
preds_id = preds_id.numpy()
if isinstance(preds_prob, paddle.Tensor):
preds_prob = preds_prob.numpy()
if preds_id[0][0] == 2:
preds_idx = preds_id[:, 1:]
preds_prob = preds_prob[:, 1:]
else:
preds_idx = preds_id
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:, 1:])
else:
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:, 1:])
return text, label
def add_special_char(self, dict_character):
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] == 3: # end
break
try:
char_list.append(self.character[int(text_index[batch_idx][
idx])])
except:
continue
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text.lower(), np.mean(conf_list).tolist()))
return result_list
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(AttnLabelDecode, self).__init__(character_dict_path,
use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
dict_character = dict_character
dict_character = [self.beg_str] + dict_character + [self.end_str]
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
[beg_idx, end_idx] = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if int(text_index[batch_idx][idx]) == int(end_idx):
break
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label, is_remove_duplicate=False)
return text, label
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
class SEEDLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SEEDLabelDecode, self).__init__(character_dict_path,
use_space_char)
def add_special_char(self, dict_character):
self.padding_str = "padding"
self.end_str = "eos"
self.unknown = "unknown"
dict_character = dict_character + [
self.end_str, self.padding_str, self.unknown
]
return dict_character
def get_ignored_tokens(self):
end_idx = self.get_beg_end_flag_idx("eos")
return [end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "sos":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "eos":
idx = np.array(self.dict[self.end_str])
else:
assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
return idx
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
[end_idx] = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if int(text_index[batch_idx][idx]) == int(end_idx):
break
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
preds_idx = preds["rec_pred"]
if isinstance(preds_idx, paddle.Tensor):
preds_idx = preds_idx.numpy()
if "rec_pred_scores" in preds:
preds_idx = preds["rec_pred"]
preds_prob = preds["rec_pred_scores"]
else:
preds_idx = preds["rec_pred"].argmax(axis=2)
preds_prob = preds["rec_pred"].max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label, is_remove_duplicate=False)
return text, label
class SRNLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SRNLabelDecode, self).__init__(character_dict_path,
use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25)
def __call__(self, preds, label=None, *args, **kwargs):
pred = preds['predict']
char_num = len(self.character_str) + 2
if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
pred = np.reshape(pred, [-1, char_num])
preds_idx = np.argmax(pred, axis=1)
preds_prob = np.max(pred, axis=1)
preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
text = self.decode(preds_idx, preds_prob)
if label is None:
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
return text
label = self.decode(label)
return text, label
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
return dict_character
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
class TableLabelDecode(object):
""" """
def __init__(self, character_dict_path, **kwargs):
list_character, list_elem = self.load_char_elem_dict(
character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
self.dict_idx_character = {}
for i, char in enumerate(list_character):
self.dict_idx_character[i] = char
self.dict_character[char] = i
self.dict_elem = {}
self.dict_idx_elem = {}
for i, elem in enumerate(list_elem):
self.dict_idx_elem[i] = elem
self.dict_elem[elem] = i
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
"\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(loc_preds, paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
structure_idx, structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
for bno in range(batch_num):
res_loc = []
for sno in range(len(structure_str[bno])):
text = structure_str[bno][sno]
if text in ['<td>', '<td']:
pos = structure_pos[bno][sno]
res_loc.append(loc_preds[bno, pos])
res_html_code = ''.join(structure_str[bno])
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
return {
'res_html_code': res_html_code_list,
'res_loc': res_loc_list,
'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,
'structure_str_list': structure_str
}
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
current_dict = self.dict_idx_character
else:
current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens
result_list = []
result_pos_list = []
result_score_list = []
result_elem_idx_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
elem_pos_list = []
elem_idx_list = []
score_list = []
for idx in range(len(text_index[batch_idx])):
tmp_elem_idx = int(text_index[batch_idx][idx])
if idx > 0 and tmp_elem_idx == end_idx:
break
if tmp_elem_idx in ignored_tokens:
continue
char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx])
elem_idx_list.append(tmp_elem_idx)
result_list.append(char_list)
result_pos_list.append(elem_pos_list)
result_score_list.append(score_list)
result_elem_idx_list.append(elem_idx_list)
return result_list, result_pos_list, result_score_list, result_elem_idx_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = self.dict_character[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_character[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = self.dict_elem[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_elem[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SARLabelDecode, self).__init__(character_dict_path,
use_space_char)
self.rm_symbol = kwargs.get('rm_symbol', False)
def add_special_char(self, dict_character):
beg_end_str = "<BOS/EOS>"
unknown_str = "<UKN>"
padding_str = "<PAD>"
dict_character = dict_character + [unknown_str]
self.unknown_idx = len(dict_character) - 1
dict_character = dict_character + [beg_end_str]
self.start_idx = len(dict_character) - 1
self.end_idx = len(dict_character) - 1
dict_character = dict_character + [padding_str]
self.padding_idx = len(dict_character) - 1
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if int(text_index[batch_idx][idx]) == int(self.end_idx):
if text_prob is None and idx == 0:
continue
else:
break
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
if self.rm_symbol:
comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
text = text.lower()
text = comp.sub('', text)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label, is_remove_duplicate=False)
return text, label
def get_ignored_tokens(self):
return [self.padding_idx]
class DistillationSARLabelDecode(SARLabelDecode):
"""
Convert
Convert between text-label and text-index
"""
def __init__(self,
character_dict_path=None,
use_space_char=False,
model_name=["student"],
key=None,
multi_head=False,
**kwargs):
super(DistillationSARLabelDecode, self).__init__(character_dict_path,
use_space_char)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
self.key = key
self.multi_head = multi_head
def __call__(self, preds, label=None, *args, **kwargs):
output = dict()
for name in self.model_name:
pred = preds[name]
if self.key is not None:
pred = pred[self.key]
if self.multi_head and isinstance(pred, dict):
pred = pred['sar']
output[name] = super().__call__(pred, label=label, *args, **kwargs)
return output
class PRENLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(PRENLabelDecode, self).__init__(character_dict_path,
use_space_char)
def add_special_char(self, dict_character):
padding_str = '<PAD>' # 0
end_str = '<EOS>' # 1
unknown_str = '<UNK>' # 2
dict_character = [padding_str, end_str, unknown_str] + dict_character
self.padding_idx = 0
self.end_idx = 1
self.unknown_idx = 2
return dict_character
def decode(self, text_index, text_prob=None):
""" convert text-index into text-label. """
result_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] == self.end_idx:
break
if text_index[batch_idx][idx] in \
[self.padding_idx, self.unknown_idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
if len(text) > 0:
result_list.append((text, np.mean(conf_list).tolist()))
else:
# here confidence of empty recog result is 1
result_list.append(('', 1))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob)
if label is None:
return text
label = self.decode(label)
return text, label

View File

@@ -0,0 +1,355 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
import numpy as np
from .locality_aware_nms import nms_locality
import paddle
import cv2
import time
class SASTPostProcess(object):
"""
The post process for SAST.
"""
def __init__(self,
score_thresh=0.5,
nms_thresh=0.2,
sample_pts_num=2,
shrink_ratio_of_width=0.3,
expand_scale=1.0,
tcl_map_thresh=0.5,
**kwargs):
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.sample_pts_num = sample_pts_num
self.shrink_ratio_of_width = shrink_ratio_of_width
self.expand_scale = expand_scale
self.tcl_map_thresh = tcl_map_thresh
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
def point_pair2poly(self, point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
# constract poly
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2)
def shrink_quad_along_width(self,
quad,
begin_width_ratio=0.,
end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
"""Restore quad."""
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
xy_text = xy_text[:, ::-1] # (n, 2)
# Sort the text boxes via the y axis
xy_text = xy_text[np.argsort(xy_text[:, 1])]
scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
scores = scores[:, np.newaxis]
# Restore
point_num = int(tvo_map.shape[-1] / 2)
assert point_num == 4
tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
quads = xy_text_tile - tvo_map
return scores, quads, xy_text
def quad_area(self, quad):
"""
compute area of a quad.
"""
edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
return np.sum(edge) / 2.
def nms(self, dets):
if self.is_python35:
import lanms
dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
else:
dets = nms_locality(dets, self.nms_thresh)
return dets
def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
"""
Cluster pixels in tcl_map based on quads.
"""
instance_count = quads.shape[0] + 1 # contain background
instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
if instance_count == 1:
return instance_count, instance_label_map
# predict text center
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
n = xy_text.shape[0]
xy_text = xy_text[:, ::-1] # (n, 2)
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
pred_tc = xy_text - tco
# get gt text center
m = quads.shape[0]
gt_tc = np.mean(quads, axis=1) # (m, 2)
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :],
(1, m, 1)) # (n, m, 2)
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
return instance_count, instance_label_map
def estimate_sample_pts_num(self, quad, xy_text):
"""
Estimate sample points number.
"""
eh = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[1] - quad[2])) / 2.0
ew = (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3])) / 2.0
dense_sample_pts_num = max(2, int(ew))
dense_xy_center_line = xy_text[np.linspace(
0,
xy_text.shape[0] - 1,
dense_sample_pts_num,
endpoint=True,
dtype=np.float32).astype(np.int32)]
dense_xy_center_line_diff = dense_xy_center_line[
1:] - dense_xy_center_line[:-1]
estimate_arc_len = np.sum(
np.linalg.norm(
dense_xy_center_line_diff, axis=1))
sample_pts_num = max(2, int(estimate_arc_len / eh))
return sample_pts_num
def detect_sast(self,
tcl_map,
tvo_map,
tbo_map,
tco_map,
ratio_w,
ratio_h,
src_w,
src_h,
shrink_ratio_of_width=0.3,
tcl_map_thresh=0.5,
offset_expand=1.0,
out_strid=4.0):
"""
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
"""
# restore quad
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
tvo_map)
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
dets = self.nms(dets)
if dets.shape[0] == 0:
return []
quads = dets[:, :-1].reshape(-1, 4, 2)
# Compute quad area
quad_areas = []
for quad in quads:
quad_areas.append(-self.quad_area(quad))
# instance segmentation
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
instance_count, instance_label_map = self.cluster_by_quads_tco(
tcl_map, tcl_map_thresh, quads, tco_map)
# restore single poly with tcl instance.
poly_list = []
for instance_idx in range(1, instance_count):
xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
quad = quads[instance_idx - 1]
q_area = quad_areas[instance_idx - 1]
if q_area < 5:
continue
#
len1 = float(np.linalg.norm(quad[0] - quad[1]))
len2 = float(np.linalg.norm(quad[1] - quad[2]))
min_len = min(len1, len2)
if min_len < 3:
continue
# filter small CC
if xy_text.shape[0] <= 0:
continue
# filter low confidence instance
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
continue
# sort xy_text
left_center_pt = np.array(
[[(quad[0, 0] + quad[-1, 0]) / 2.0,
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
right_center_pt = np.array(
[[(quad[1, 0] + quad[2, 0]) / 2.0,
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
proj_unit_vec = (right_center_pt - left_center_pt) / \
(np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
xy_text = xy_text[np.argsort(proj_value)]
# Sample pts in tcl map
if self.sample_pts_num == 0:
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
else:
sample_pts_num = self.sample_pts_num
xy_center_line = xy_text[np.linspace(
0,
xy_text.shape[0] - 1,
sample_pts_num,
endpoint=True,
dtype=np.float32).astype(np.int32)]
point_pair_list = []
for x, y in xy_center_line:
# get corresponding offset
offset = tbo_map[y, x, :].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal
# original point
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
# ndarry: (x, 2), expand poly along width
detected_poly = self.point_pair2poly(point_pair_list)
detected_poly = self.expand_poly_along_width(detected_poly,
shrink_ratio_of_width)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
poly_list.append(detected_poly)
return poly_list
def __call__(self, outs_dict, shape_list):
score_list = outs_dict['f_score']
border_list = outs_dict['f_border']
tvo_list = outs_dict['f_tvo']
tco_list = outs_dict['f_tco']
if isinstance(score_list, paddle.Tensor):
score_list = score_list.numpy()
border_list = border_list.numpy()
tvo_list = tvo_list.numpy()
tco_list = tco_list.numpy()
img_num = len(shape_list)
poly_lists = []
for ino in range(img_num):
p_score = score_list[ino].transpose((1, 2, 0))
p_border = border_list[ino].transpose((1, 2, 0))
p_tvo = tvo_list[ino].transpose((1, 2, 0))
p_tco = tco_list[ino].transpose((1, 2, 0))
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
poly_list = self.detect_sast(
p_score,
p_tvo,
p_border,
p_tco,
ratio_w,
ratio_h,
src_w,
src_h,
shrink_ratio_of_width=self.shrink_ratio_of_width,
tcl_map_thresh=self.tcl_map_thresh,
offset_expand=self.expand_scale)
poly_lists.append({'points': np.array(poly_list)})
return poly_lists

View File

@@ -0,0 +1,51 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
class VQAReTokenLayoutLMPostProcess(object):
""" Convert between text-label and text-index """
def __init__(self, **kwargs):
super(VQAReTokenLayoutLMPostProcess, self).__init__()
def __call__(self, preds, label=None, *args, **kwargs):
if label is not None:
return self._metric(preds, label)
else:
return self._infer(preds, *args, **kwargs)
def _metric(self, preds, label):
return preds['pred_relations'], label[6], label[5]
def _infer(self, preds, *args, **kwargs):
ser_results = kwargs['ser_results']
entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
pred_relations = preds['pred_relations']
# merge relations and ocr info
results = []
for pred_relation, ser_result, entity_idx_dict in zip(
pred_relations, ser_results, entity_idx_dict_batch):
result = []
used_tail_id = []
for relation in pred_relation:
if relation['tail_id'] in used_tail_id:
continue
used_tail_id.append(relation['tail_id'])
ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
result.append((ocr_info_head, ocr_info_tail))
results.append(result)
return results

View File

@@ -0,0 +1,93 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from ppocr.utils.utility import load_vqa_bio_label_maps
class VQASerTokenLayoutLMPostProcess(object):
""" Convert between text-label and text-index """
def __init__(self, class_path, **kwargs):
super(VQASerTokenLayoutLMPostProcess, self).__init__()
label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)
self.label2id_map_for_draw = dict()
for key in label2id_map:
if key.startswith("I-"):
self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
else:
self.label2id_map_for_draw[key] = label2id_map[key]
self.id2label_map_for_show = dict()
for key in self.label2id_map_for_draw:
val = self.label2id_map_for_draw[key]
if key == "O":
self.id2label_map_for_show[val] = key
if key.startswith("B-") or key.startswith("I-"):
self.id2label_map_for_show[val] = key[2:]
else:
self.id2label_map_for_show[val] = key
def __call__(self, preds, batch=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if batch is not None:
return self._metric(preds, batch[1])
else:
return self._infer(preds, **kwargs)
def _metric(self, preds, label):
pred_idxs = preds.argmax(axis=2)
decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
for i in range(pred_idxs.shape[0]):
for j in range(pred_idxs.shape[1]):
if label[i, j] != -100:
label_decode_out_list[i].append(self.id2label_map[label[i,
j]])
decode_out_list[i].append(self.id2label_map[pred_idxs[i,
j]])
return decode_out_list, label_decode_out_list
def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
results = []
for pred, attention_mask, segment_offset_id, ocr_info in zip(
preds, attention_masks, segment_offset_ids, ocr_infos):
pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred]
for idx in range(len(segment_offset_id)):
if idx == 0:
start_id = 0
else:
start_id = segment_offset_id[idx - 1]
end_id = segment_offset_id[idx]
curr_pred = pred[start_id:end_id]
curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]
if len(curr_pred) <= 0:
pred_id = 0
else:
counts = np.bincount(curr_pred)
pred_id = np.argmax(counts)
ocr_info[idx]["pred_id"] = int(pred_id)
ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
results.append(ocr_info)
return results