diff --git a/backend/main.py b/backend/main.py index 5adb4ec..5cad54d 100644 --- a/backend/main.py +++ b/backend/main.py @@ -8,6 +8,7 @@ import sys sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import config +from backend.tools.common_tools import is_video_or_image, is_image_file from backend.scenedetect import scene_detect from backend.scenedetect.detectors import ContentDetector from backend.inpaint.sttn_inpaint import STTNInpaint, STTNVideoInpaint @@ -183,7 +184,8 @@ class SubtitleDetect: new_unify_values = [] for idx, region in enumerate(current_regions): - last_standard_region = unify_value_map[last_key][idx] if idx < len(unify_value_map[last_key]) else None + last_standard_region = unify_value_map[last_key][idx] if idx < len( + unify_value_map[last_key]) else None # 如果当前的区间与前一个键的对应区间相似,我们统一它们 if last_standard_region and self.are_similar(region, last_standard_region): @@ -483,7 +485,7 @@ class SubtitleRemover: self.gui_mode = gui_mode # 判断是否为图片 self.is_picture = False - if str(vd_path).endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): + if is_image_file(str(vd_path)): self.sub_area = None self.is_picture = True # 视频路径 @@ -496,8 +498,10 @@ class SubtitleRemover: # 视频帧率 self.fps = self.video_cap.get(cv2.CAP_PROP_FPS) # 视频尺寸 - self.size = (int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))) - self.mask_size = (int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))) + self.size = ( + int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))) + self.mask_size = ( + int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))) self.frame_height = int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) self.frame_width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 创建字幕检测对象 @@ -671,7 +675,8 @@ class SubtitleRemover: if self.sub_area is not None: ymin, ymax, xmin, xmax = self.sub_area else: - print('[Info] No subtitle area has been set. Video will be processed in full screen. As a result, the final outcome might be suboptimal.') + print( + '[Info] No subtitle area has been set. Video will be processed in full screen. As a result, the final outcome might be suboptimal.') ymin, ymax, xmin, xmax = 0, self.frame_height, 0, self.frame_width mask_area_coordinates = [(xmin, xmax, ymin, ymax)] mask = create_mask(self.mask_size, mask_area_coordinates) @@ -875,9 +880,13 @@ class SubtitleRemover: if __name__ == '__main__': multiprocessing.set_start_method("spawn") # 1. 提示用户输入视频路径 - video_path = input(f"Please input video file path: ").strip() + video_path = input(f"Please input video or image file path: ").strip() + # 判断视频路径是不是一个目录,是目录的化,批量处理改目录下的所有视频文件 # 2. 按以下顺序传入字幕区域 # sub_area = (ymin, ymax, xmin, xmax) # 3. 新建字幕提取对象 - sd = SubtitleRemover(video_path, sub_area=None) - sd.run() + if is_video_or_image(video_path): + sd = SubtitleRemover(video_path, sub_area=None) + sd.run() + else: + print(f'Invalid video path: {video_path}') diff --git a/backend/tools/common_tools.py b/backend/tools/common_tools.py new file mode 100644 index 0000000..54372d5 --- /dev/null +++ b/backend/tools/common_tools.py @@ -0,0 +1,32 @@ +import os + +video_extensions = { + '.mp4', '.m4a', '.m4v', '.f4v', '.f4a', '.m4b', '.m4r', '.f4b', '.mov', + '.3gp', '.3gp2', '.3g2', '.3gpp', '.3gpp2', '.ogg', '.oga', '.ogv', '.ogx', + '.wmv', '.wma', '.asf', '.webm', '.flv', '.avi', '.gifv', '.mkv', '.rm', + '.rmvb', '.vob', '.dvd', '.mpg', '.mpeg', '.mp2', '.mpe', '.mpv', '.mpg', + '.mpeg', '.m2v', '.svi', '.3gp', '.mxf', '.roq', '.nsv', '.flv', '.f4v', + '.f4p', '.f4a', '.f4b' +} + +image_extensions = { + '.jpg', '.jpeg', '.jpe', '.jif', '.jfif', '.jfi', '.png', '.gif', + '.webp', '.tiff', '.tif', '.psd', '.raw', '.arw', '.cr2', '.nrw', + '.k25', '.bmp', '.dib', '.heif', '.heic', '.ind', '.indd', '.indt', + '.jp2', '.j2k', '.jpf', '.jpx', '.jpm', '.mj2', '.svg', '.svgz', + '.ai', '.eps', '.ico' +} + + +def is_video_file(filename): + return os.path.splitext(filename)[-1].lower() in video_extensions + + +def is_image_file(filename): + return os.path.splitext(filename)[-1].lower() in image_extensions + + +def is_video_or_image(filename): + file_extension = os.path.splitext(filename)[-1].lower() + # 检查扩展名是否在定义的视频或图片文件后缀集合中 + return file_extension in video_extensions or file_extension in image_extensions