mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-03-02 00:22:17 +08:00
init
This commit is contained in:
172
backend/ppocr/modeling/backbones/vqa_layoutlm.py
Normal file
172
backend/ppocr/modeling/backbones/vqa_layoutlm.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
from paddle import nn
|
||||
|
||||
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
|
||||
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
|
||||
from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction
|
||||
|
||||
# Public API of this module. The original list omitted four of the public
# task classes defined below; include them all so ``from ... import *`` and
# API tooling see the complete set (superset of the old list, so this is
# backward-compatible).
__all__ = [
    "LayoutXLMForSer", "LayoutLMForSer", "LayoutLMv2ForSer",
    "LayoutLMv2ForRe", "LayoutXLMForRe"
]

# Maps each supported backbone class to the name of its default pretrained
# weights, consumed by NLPBaseModel when no checkpoint path is given.
pretrained_model_dict = {
    LayoutXLMModel: 'layoutxlm-base-uncased',
    LayoutLMModel: 'layoutlm-base-uncased',
    LayoutLMv2Model: 'layoutlmv2-base-uncased'
}
|
||||
|
||||
|
||||
class NLPBaseModel(nn.Layer):
    """Common wrapper around a paddlenlp layout-model task head.

    Builds ``self.model`` either by restoring the full task model from a
    checkpoint directory, or by constructing it from a backbone (pretrained
    weights or a randomly initialized network with the canonical pretrained
    configuration).

    Args:
        base_model_class: backbone class, a key of ``pretrained_model_dict``.
        model_class: task-head class wrapping the backbone.
        type: ``'ser'`` (token classification, needs ``num_classes`` in
            kwargs) or anything else for relation extraction.
        pretrained: load pretrained backbone weights when no checkpoint.
        checkpoints: optional checkpoint path; takes precedence over all
            backbone construction.
    """

    def __init__(self,
                 base_model_class,
                 model_class,
                 type='ser',
                 pretrained=True,
                 checkpoints=None,
                 **kwargs):
        super(NLPBaseModel, self).__init__()
        if checkpoints is not None:
            # A checkpoint restores the complete task model in one step.
            self.model = model_class.from_pretrained(checkpoints)
        else:
            model_name = pretrained_model_dict[base_model_class]
            if pretrained:
                backbone = base_model_class.from_pretrained(model_name)
            else:
                # Random initialization, but with the canonical config of
                # the pretrained variant so shapes match.
                config = base_model_class.pretrained_init_configuration[
                    model_name]
                backbone = base_model_class(**config)
            if type == 'ser':
                self.model = model_class(
                    backbone, num_classes=kwargs['num_classes'], dropout=None)
            else:
                self.model = model_class(backbone, dropout=None)
        # Placeholder channel count expected by the surrounding framework.
        self.out_channels = 1
|
||||
|
||||
|
||||
class LayoutLMForSer(NLPBaseModel):
    """LayoutLM backbone configured for semantic entity recognition (SER)."""

    def __init__(self, num_classes, pretrained=True, checkpoints=None,
                 **kwargs):
        super(LayoutLMForSer, self).__init__(
            LayoutLMModel,
            LayoutLMForTokenClassification,
            type='ser',
            pretrained=pretrained,
            checkpoints=checkpoints,
            num_classes=num_classes)

    def forward(self, x):
        # The batch is packed positionally: 0=input_ids, 2=bbox,
        # 4=attention_mask, 5=token_type_ids (indices 1 and 3 are not
        # consumed here).
        output = self.model(
            input_ids=x[0],
            bbox=x[2],
            attention_mask=x[4],
            token_type_ids=x[5],
            position_ids=None,
            output_hidden_states=False)
        return output
|
||||
|
||||
|
||||
class LayoutLMv2ForSer(NLPBaseModel):
    """LayoutLMv2 backbone configured for semantic entity recognition (SER)."""

    def __init__(self, num_classes, pretrained=True, checkpoints=None,
                 **kwargs):
        super(LayoutLMv2ForSer, self).__init__(
            LayoutLMv2Model,
            LayoutLMv2ForTokenClassification,
            type='ser',
            pretrained=pretrained,
            checkpoints=checkpoints,
            num_classes=num_classes)

    def forward(self, x):
        # Batch layout: 0=input_ids, 2=bbox, 3=image, 4=attention_mask,
        # 5=token_type_ids (index 1 is not consumed here).
        outputs = self.model(
            input_ids=x[0],
            bbox=x[2],
            image=x[3],
            attention_mask=x[4],
            token_type_ids=x[5],
            position_ids=None,
            head_mask=None,
            labels=None)
        # Only the first element of the model output is propagated.
        return outputs[0]
|
||||
|
||||
|
||||
class LayoutXLMForSer(NLPBaseModel):
    """LayoutXLM backbone configured for semantic entity recognition (SER)."""

    def __init__(self, num_classes, pretrained=True, checkpoints=None,
                 **kwargs):
        super(LayoutXLMForSer, self).__init__(
            LayoutXLMModel,
            LayoutXLMForTokenClassification,
            type='ser',
            pretrained=pretrained,
            checkpoints=checkpoints,
            num_classes=num_classes)

    def forward(self, x):
        # Batch layout: 0=input_ids, 2=bbox, 3=image, 4=attention_mask,
        # 5=token_type_ids (index 1 is not consumed here).
        outputs = self.model(
            input_ids=x[0],
            bbox=x[2],
            image=x[3],
            attention_mask=x[4],
            token_type_ids=x[5],
            position_ids=None,
            head_mask=None,
            labels=None)
        # Only the first element of the model output is propagated.
        return outputs[0]
|
||||
|
||||
|
||||
class LayoutLMv2ForRe(NLPBaseModel):
    """LayoutLMv2 backbone configured for relation extraction (RE)."""

    def __init__(self, pretrained=True, checkpoints=None, **kwargs):
        super(LayoutLMv2ForRe, self).__init__(
            LayoutLMv2Model,
            LayoutLMv2ForRelationExtraction,
            type='re',
            pretrained=pretrained,
            checkpoints=checkpoints)

    def forward(self, x):
        # Batch layout: 0=input_ids, 1=bbox, 2=image, 3=attention_mask,
        # 4=token_type_ids, 5=entities, 6=relations.
        output = self.model(
            input_ids=x[0],
            bbox=x[1],
            labels=None,
            image=x[2],
            attention_mask=x[3],
            token_type_ids=x[4],
            position_ids=None,
            head_mask=None,
            entities=x[5],
            relations=x[6])
        return output
|
||||
|
||||
|
||||
class LayoutXLMForRe(NLPBaseModel):
    """LayoutXLM backbone configured for relation extraction (RE)."""

    def __init__(self, pretrained=True, checkpoints=None, **kwargs):
        super(LayoutXLMForRe, self).__init__(
            LayoutXLMModel,
            LayoutXLMForRelationExtraction,
            type='re',
            pretrained=pretrained,
            checkpoints=checkpoints)

    def forward(self, x):
        # Batch layout: 0=input_ids, 1=bbox, 2=image, 3=attention_mask,
        # 4=token_type_ids, 5=entities, 6=relations.
        output = self.model(
            input_ids=x[0],
            bbox=x[1],
            labels=None,
            image=x[2],
            attention_mask=x[3],
            token_type_ids=x[4],
            position_ids=None,
            head_mask=None,
            entities=x[5],
            relations=x[6])
        return output
|
||||
Reference in New Issue
Block a user