mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-03-06 18:57:36 +08:00
init
This commit is contained in:
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import paddle
|
||||
from ppocr.utils.utility import load_vqa_bio_label_maps
|
||||
|
||||
|
||||
class VQASerTokenLayoutLMPostProcess(object):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, class_path, **kwargs):
|
||||
super(VQASerTokenLayoutLMPostProcess, self).__init__()
|
||||
label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)
|
||||
|
||||
self.label2id_map_for_draw = dict()
|
||||
for key in label2id_map:
|
||||
if key.startswith("I-"):
|
||||
self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
|
||||
else:
|
||||
self.label2id_map_for_draw[key] = label2id_map[key]
|
||||
|
||||
self.id2label_map_for_show = dict()
|
||||
for key in self.label2id_map_for_draw:
|
||||
val = self.label2id_map_for_draw[key]
|
||||
if key == "O":
|
||||
self.id2label_map_for_show[val] = key
|
||||
if key.startswith("B-") or key.startswith("I-"):
|
||||
self.id2label_map_for_show[val] = key[2:]
|
||||
else:
|
||||
self.id2label_map_for_show[val] = key
|
||||
|
||||
def __call__(self, preds, batch=None, *args, **kwargs):
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
|
||||
if batch is not None:
|
||||
return self._metric(preds, batch[1])
|
||||
else:
|
||||
return self._infer(preds, **kwargs)
|
||||
|
||||
def _metric(self, preds, label):
|
||||
pred_idxs = preds.argmax(axis=2)
|
||||
decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
|
||||
label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
|
||||
|
||||
for i in range(pred_idxs.shape[0]):
|
||||
for j in range(pred_idxs.shape[1]):
|
||||
if label[i, j] != -100:
|
||||
label_decode_out_list[i].append(self.id2label_map[label[i,
|
||||
j]])
|
||||
decode_out_list[i].append(self.id2label_map[pred_idxs[i,
|
||||
j]])
|
||||
return decode_out_list, label_decode_out_list
|
||||
|
||||
def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
|
||||
results = []
|
||||
|
||||
for pred, attention_mask, segment_offset_id, ocr_info in zip(
|
||||
preds, attention_masks, segment_offset_ids, ocr_infos):
|
||||
pred = np.argmax(pred, axis=1)
|
||||
pred = [self.id2label_map[idx] for idx in pred]
|
||||
|
||||
for idx in range(len(segment_offset_id)):
|
||||
if idx == 0:
|
||||
start_id = 0
|
||||
else:
|
||||
start_id = segment_offset_id[idx - 1]
|
||||
|
||||
end_id = segment_offset_id[idx]
|
||||
|
||||
curr_pred = pred[start_id:end_id]
|
||||
curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]
|
||||
|
||||
if len(curr_pred) <= 0:
|
||||
pred_id = 0
|
||||
else:
|
||||
counts = np.bincount(curr_pred)
|
||||
pred_id = np.argmax(counts)
|
||||
ocr_info[idx]["pred_id"] = int(pred_id)
|
||||
ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
|
||||
results.append(ocr_info)
|
||||
return results
|
||||
Reference in New Issue
Block a user