mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-27 22:24:42 +08:00
init
This commit is contained in:
67
backend/ppocr/data/imaug/vqa/token/vqa_token_relation.py
Normal file
67
backend/ppocr/data/imaug/vqa/token/vqa_token_relation.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class VQAReTokenRelation(object):
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, data):
|
||||
"""
|
||||
build relations
|
||||
"""
|
||||
entities = data['entities']
|
||||
relations = data['relations']
|
||||
id2label = data.pop('id2label')
|
||||
empty_entity = data.pop('empty_entity')
|
||||
entity_id_to_index_map = data.pop('entity_id_to_index_map')
|
||||
|
||||
relations = list(set(relations))
|
||||
relations = [
|
||||
rel for rel in relations
|
||||
if rel[0] not in empty_entity and rel[1] not in empty_entity
|
||||
]
|
||||
kv_relations = []
|
||||
for rel in relations:
|
||||
pair = [id2label[rel[0]], id2label[rel[1]]]
|
||||
if pair == ["question", "answer"]:
|
||||
kv_relations.append({
|
||||
"head": entity_id_to_index_map[rel[0]],
|
||||
"tail": entity_id_to_index_map[rel[1]]
|
||||
})
|
||||
elif pair == ["answer", "question"]:
|
||||
kv_relations.append({
|
||||
"head": entity_id_to_index_map[rel[1]],
|
||||
"tail": entity_id_to_index_map[rel[0]]
|
||||
})
|
||||
else:
|
||||
continue
|
||||
relations = sorted(
|
||||
[{
|
||||
"head": rel["head"],
|
||||
"tail": rel["tail"],
|
||||
"start_index": self.get_relation_span(rel, entities)[0],
|
||||
"end_index": self.get_relation_span(rel, entities)[1],
|
||||
} for rel in kv_relations],
|
||||
key=lambda x: x["head"], )
|
||||
|
||||
data['relations'] = relations
|
||||
return data
|
||||
|
||||
def get_relation_span(self, rel, entities):
|
||||
bound = []
|
||||
for entity_index in [rel["head"], rel["tail"]]:
|
||||
bound.append(entities[entity_index]["start"])
|
||||
bound.append(entities[entity_index]["end"])
|
||||
return min(bound), max(bound)
|
||||
Reference in New Issue
Block a user