mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-22 01:38:19 +08:00
init
This commit is contained in:
131
backend/ppocr/utils/utility.py
Executable file
131
backend/ppocr/utils/utility.py
Executable file
@@ -0,0 +1,131 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import imghdr
|
||||
import cv2
|
||||
import random
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
|
||||
def print_dict(d, logger, delimiter=0):
    """
    Log the contents of a (possibly nested) dict, one entry per line.

    Entries are visited in sorted key order. A value that is a dict, or a
    non-empty list whose first element is a dict, is logged as a header line
    followed by its contents indented four more spaces; any other value is
    logged inline as ``key : value``.

    Args:
        d: The dict to display.
        logger: Object exposing an ``info(message)`` method.
        delimiter: Number of leading spaces for the current nesting level.
    """
    indent = delimiter * " "
    for key, value in sorted(d.items()):
        if isinstance(value, dict):
            children = [value]
        elif isinstance(value, list) and len(value) >= 1 and isinstance(
                value[0], dict):
            children = value
        else:
            # Leaf value: key and value share a single line.
            logger.info("{}{} : {}".format(indent, key, value))
            continue

        # Container value: header line first, then each child one level deeper.
        logger.info("{}{} : ".format(indent, str(key)))
        for child in children:
            print_dict(child, logger, delimiter + 4)
|
||||
|
||||
|
||||
def get_check_global_params(mode):
    """
    Return the names of the global parameters associated with ``mode``.

    Args:
        mode: Mode string. ``"train_eval"`` adds both the train and the test
            per-card batch-size names, ``"test"`` adds only the test one; any
            other value yields just the common names.

    Returns:
        A new list of parameter-name strings, each listed exactly once.
    """
    # Every name appears once; a repeated 'image_shape' entry adds nothing.
    check_params = [
        'use_gpu', 'max_text_length', 'image_shape', 'character_type',
        'loss_type'
    ]
    if mode == "train_eval":
        check_params += [
            'train_batch_size_per_card', 'test_batch_size_per_card'
        ]
    elif mode == "test":
        check_params.append('test_batch_size_per_card')
    return check_params
|
||||
|
||||
|
||||
def _check_image_file(path):
    """
    Tell whether ``path`` carries one of the recognised image extensions.

    The comparison is case-insensitive and requires the ``.`` separator, so a
    name that merely ends with the letters of an extension (for example
    ``"argb"`` or ``"imagejpg"``) is not reported as an image.

    Args:
        path: File name or path string.

    Returns:
        True if the lower-cased path ends with a known image extension,
        otherwise False.
    """
    img_end = ('.jpg', '.bmp', '.png', '.jpeg', '.rgb', '.tif', '.tiff',
               '.gif')
    # str.endswith accepts a tuple of suffixes, so one call covers them all.
    return path.lower().endswith(img_end)
|
||||
|
||||
|
||||
def get_image_file_list(img_file):
    """
    Collect the image file paths designated by ``img_file``.

    Args:
        img_file: Path of a single file or of a directory. A directory is
            listed non-recursively and only regular files are considered.

    Returns:
        Sorted list of the paths accepted by ``_check_image_file``.

    Raises:
        Exception: If ``img_file`` is None, does not exist, or no image file
            is found.
    """
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    imgs_lists = []
    if os.path.isfile(img_file):
        if _check_image_file(img_file):
            imgs_lists.append(img_file)
    elif os.path.isdir(img_file):
        for single_file in os.listdir(img_file):
            file_path = os.path.join(img_file, single_file)
            if os.path.isfile(file_path) and _check_image_file(file_path):
                imgs_lists.append(file_path)

    if not imgs_lists:
        raise Exception("not found any img file in {}".format(img_file))
    return sorted(imgs_lists)
|
||||
|
||||
|
||||
def check_and_read_gif(img_path):
    """
    Read the first frame of a GIF file through OpenCV's video reader.

    Args:
        img_path: Path of the image file. The file is treated as a GIF when
            the last three characters of its base name are ``gif`` in any
            letter case.

    Returns:
        ``(image, True)`` on success, where ``image`` is the first frame with
        its channel axis reversed (OpenCV decodes as BGR, so this presumably
        yields RGB); ``(None, False)`` if the path is not a GIF or the frame
        cannot be read.
    """
    if os.path.basename(img_path)[-3:].lower() != 'gif':
        return None, False

    gif = cv2.VideoCapture(img_path)
    try:
        ret, frame = gif.read()
    finally:
        # Always free the capture handle, even if reading raises.
        gif.release()

    if not ret:
        logger = logging.getLogger('ppocr')
        logger.info("Cannot read {}. This gif image maybe corrupted.".format(
            img_path))
        return None, False

    # Expand single-channel frames to three channels before reversing them.
    if len(frame.shape) == 2 or frame.shape[-1] == 1:
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    imgvalue = frame[:, :, ::-1]
    return imgvalue, True
|
||||
|
||||
|
||||
def load_vqa_bio_label_maps(label_map_path):
    """
    Build BIO tag/id lookup tables from a label list file.

    The file is read as UTF-8 with one label name per line (surrounding
    whitespace stripped). ``"O"`` is placed first when the file does not
    contain it. Every name other than ``"O"`` contributes a ``B-<name>`` tag
    followed by an ``I-<name>`` tag; ids are assigned in that order from 0.

    Args:
        label_map_path: Path of the label list file.

    Returns:
        Tuple ``(label2id_map, id2label_map)``.
    """
    with open(label_map_path, "r", encoding='utf-8') as fin:
        names = [raw.strip() for raw in fin.readlines()]
    if "O" not in names:
        names.insert(0, "O")

    labels = []
    for name in names:
        if name == "O":
            labels.append("O")
            continue
        labels.extend(("B-" + name, "I-" + name))

    id2label_map = dict(enumerate(labels))
    label2id_map = {}
    for idx, label in enumerate(labels):
        label2id_map[label] = idx
    return label2id_map, id2label_map
|
||||
|
||||
|
||||
def set_seed(seed=1024):
    """
    Seed Python's ``random``, NumPy's global generator and Paddle with ``seed``.

    Args:
        seed: Seed value handed to each of the three seeding functions.
    """
    for seeder in (random.seed, np.random.seed, paddle.seed):
        seeder(seed)
|
||||
|
||||
|
||||
class AverageMeter:
    """
    Keep the latest value and the running weighted average of a quantity.

    Instance attributes:
        val: Value passed to the most recent ``update`` call.
        avg: ``sum / count``, or 0 while ``count`` is 0.
        sum: Accumulated ``val * n`` over all updates.
        count: Accumulated ``n`` over all updates.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: The new observation.
            n: Weight (number of samples) that ``val`` stands for.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # With a total weight of zero (e.g. a first update with n=0) there is
        # no mean to compute; report 0 rather than raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
|
||||
Reference in New Issue
Block a user