moved main to __main__; changed _run to use the new encoder and processor

This commit is contained in:
k4yt3x
2022-05-01 08:44:33 +00:00
parent a041a60d87
commit 9fc0aa787e

View File

@@ -27,7 +27,7 @@ __ __ _ _ ___ __ __
Name: Video2X Name: Video2X
Creator: K4YT3X Creator: K4YT3X
Date Created: February 24, 2018 Date Created: February 24, 2018
Last Modified: April 5, 2022 Last Modified: April 30, 2022
Editor: BrianPetkovsek Editor: BrianPetkovsek
Last Modified: June 17, 2019 Last Modified: June 17, 2019
@@ -39,20 +39,18 @@ Editor: 28598519a
Last Modified: March 23, 2020 Last Modified: March 23, 2020
""" """
import argparse
import ctypes import ctypes
import math import math
import multiprocessing
import os
import pathlib
import signal import signal
import sys import sys
import time import time
from enum import Enum
from multiprocessing import Manager, Pool, Queue, Value
from pathlib import Path
import ffmpeg import ffmpeg
from cv2 import cv2 from cv2 import cv2
from loguru import logger from loguru import logger
from rich import print as rich_print
from rich.console import Console from rich.console import Console
from rich.file_proxy import FileProxy from rich.file_proxy import FileProxy
from rich.progress import ( from rich.progress import (
@@ -65,41 +63,23 @@ from rich.progress import (
) )
from rich.text import Text from rich.text import Text
from video2x.processor import Processor
from . import __version__ from . import __version__
from .decoder import VideoDecoder from .decoder import VideoDecoder, VideoDecoderThread
from .encoder import VideoEncoder from .encoder import VideoEncoder
from .interpolator import Interpolator from .interpolator import Interpolator
from .upscaler import Upscaler from .upscaler import UpscalerProcessor
# for desktop environments only # for desktop environments only
# if pynput can be loaded, enable global pause hotkey support # if pynput can be loaded, enable global pause hotkey support
try: try:
import pynput from pynput.keyboard import HotKey, Listener
except ImportError: except ImportError:
ENABLE_HOTKEY = False ENABLE_HOTKEY = False
else: else:
ENABLE_HOTKEY = True ENABLE_HOTKEY = True
LEGAL_INFO = f"""Video2X\t\t{__version__}
Author:\t\tK4YT3X
License:\tGNU AGPL v3
Github Page:\thttps://github.com/k4yt3x/video2x
Contact:\ti@k4yt3x.com"""
# algorithms available for upscaling tasks
UPSCALING_ALGORITHMS = [
"waifu2x",
"srmd",
"realsr",
"realcugan",
]
# algorithms available for frame interpolation tasks
INTERPOLATION_ALGORITHMS = ["rife"]
# progress bar labels for different modes
MODE_LABELS = {"upscale": "Upscaling", "interpolate": "Interpolating"}
# format string for Loguru loggers # format string for Loguru loggers
LOGURU_FORMAT = ( LOGURU_FORMAT = (
"<green>{time:HH:mm:ss.SSSSSS!UTC}</green> | " "<green>{time:HH:mm:ss.SSSSSS!UTC}</green> | "
@@ -119,6 +99,11 @@ class ProcessingSpeedColumn(ProgressColumn):
) )
class ProcessingMode(Enum):
UPSCALE = {"label": "Upscaling", "processor": UpscalerProcessor}
INTERPOLATE = {"label": "Interpolating", "processor": Interpolator}
class Video2X: class Video2X:
""" """
Video2X class Video2X class
@@ -132,11 +117,11 @@ class Video2X:
self.version = __version__ self.version = __version__
@staticmethod @staticmethod
def _get_video_info(path: pathlib.Path) -> tuple: def _get_video_info(path: Path) -> tuple:
""" """
get video file information with FFmpeg get video file information with FFmpeg
:param path pathlib.Path: video file path :param path Path: video file path
:raises RuntimeError: raised when video stream isn't found :raises RuntimeError: raised when video stream isn't found
""" """
# probe video file info # probe video file info
@@ -160,34 +145,17 @@ class Video2X:
return video_info["width"], video_info["height"], total_frames, frame_rate return video_info["width"], video_info["height"], total_frames, frame_rate
def _toggle_pause(self, _signal_number: int = -1, _frame=None):
# print console messages and update the progress bar's status
if self.pause.value is False:
self.progress.update(self.task, description=self.description + " (paused)")
self.progress.stop_task(self.task)
logger.warning("Processing paused, press Ctrl+Alt+V again to resume")
elif self.pause.value is True:
self.progress.update(self.task, description=self.description)
logger.warning("Resuming processing")
self.progress.start_task(self.task)
# invert the value of the pause flag
with self.pause.get_lock():
self.pause.value = not self.pause.value
def _run( def _run(
self, self,
input_path: pathlib.Path, input_path: Path,
width: int, width: int,
height: int, height: int,
total_frames: int, total_frames: int,
frame_rate: float, frame_rate: float,
output_path: pathlib.Path, output_path: Path,
output_width: int, output_width: int,
output_height: int, output_height: int,
Processor: object, mode: ProcessingMode,
mode: str,
processes: int, processes: int,
processing_settings: tuple, processing_settings: tuple,
) -> None: ) -> None:
@@ -207,51 +175,40 @@ class Video2X:
logger.remove() logger.remove()
logger.add(sys.stderr, colorize=True, format=LOGURU_FORMAT) logger.add(sys.stderr, colorize=True, format=LOGURU_FORMAT)
# initialize values # TODO: add docs
self.processor_processes = [] tasks_queue = Queue(maxsize=processes * 10)
self.processing_queue = multiprocessing.Queue(maxsize=processes * 10) processed_frames = Manager().dict()
processed_frames = multiprocessing.Manager().list([None] * total_frames) pause_flag = Value(ctypes.c_bool, False)
self.processed = multiprocessing.Value("I", 0)
self.pause = multiprocessing.Value(ctypes.c_bool, False)
# set up and start decoder thread # set up and start decoder thread
logger.info("Starting video decoder") logger.info("Starting video decoder")
self.decoder = VideoDecoder( decoder = VideoDecoder(
input_path, input_path,
width, width,
height, height,
frame_rate, frame_rate,
self.processing_queue,
processing_settings,
self.pause,
) )
self.decoder.start() decoder_thread = VideoDecoderThread(tasks_queue, decoder, processing_settings)
decoder_thread.start()
# set up and start encoder thread # set up and start encoder thread
logger.info("Starting video encoder") logger.info("Starting video encoder")
self.encoder = VideoEncoder( encoder = VideoEncoder(
input_path, input_path,
frame_rate * 2 if mode == "interpolate" else frame_rate, frame_rate * 2 if mode == "interpolate" else frame_rate,
output_path, output_path,
output_width, output_width,
output_height, output_height,
total_frames,
processed_frames,
self.processed,
self.pause,
) )
self.encoder.start()
# create processor processes # create a pool of processor processes to process the queue
for process_name in range(processes): processor: Processor = mode.value["processor"](
process = Processor(self.processing_queue, processed_frames, self.pause) tasks_queue, processed_frames, pause_flag
process.name = str(process_name) )
process.daemon = True processor_pool = Pool(processes, processor.process)
process.start()
self.processor_processes.append(process)
# create progress bar # create progress bar
self.progress = Progress( progress = Progress(
"[progress.description]{task.description}", "[progress.description]{task.description}",
BarColumn(complete_style="blue", finished_style="green"), BarColumn(complete_style="blue", finished_style="green"),
"[progress.percentage]{task.percentage:>3.0f}%", "[progress.percentage]{task.percentage:>3.0f}%",
@@ -264,23 +221,42 @@ class Video2X:
speed_estimate_period=300.0, speed_estimate_period=300.0,
disable=True, disable=True,
) )
task = progress.add_task(f"[cyan]{mode.value['label']}", total=total_frames)
self.description = f"[cyan]{MODE_LABELS.get(mode, 'Unknown')}" def _toggle_pause(_signal_number: int = -1, _frame=None):
self.task = self.progress.add_task(self.description, total=total_frames)
# allow the closure to modify external immutable flag
nonlocal pause_flag
# print console messages and update the progress bar's status
if pause_flag.value is False:
progress.update(
task, description=f"[cyan]{mode.value['label']} (paused)"
)
progress.stop_task(task)
logger.warning("Processing paused, press Ctrl+Alt+V again to resume")
# the lock is already acquired
elif pause_flag.value is True:
progress.update(task, description=f"[cyan]{mode.value['label']}")
logger.warning("Resuming processing")
progress.start_task(task)
# invert the flag
with pause_flag.get_lock():
pause_flag.value = not pause_flag.value
# allow sending SIGUSR1 to pause/resume processing # allow sending SIGUSR1 to pause/resume processing
signal.signal(signal.SIGUSR1, self._toggle_pause) signal.signal(signal.SIGUSR1, _toggle_pause)
# enable global pause hotkey if it's supported # enable global pause hotkey if it's supported
if ENABLE_HOTKEY is True: if ENABLE_HOTKEY is True:
# create global pause hotkey # create global pause hotkey
pause_hotkey = pynput.keyboard.HotKey( pause_hotkey = HotKey(HotKey.parse("<ctrl>+<alt>+v"), _toggle_pause)
pynput.keyboard.HotKey.parse("<ctrl>+<alt>+v"), self._toggle_pause
)
# create global keyboard input listener # create global keyboard input listener
keyboard_listener = pynput.keyboard.Listener( keyboard_listener = Listener(
on_press=( on_press=(
lambda key: pause_hotkey.press(keyboard_listener.canonical(key)) lambda key: pause_hotkey.press(keyboard_listener.canonical(key))
), ),
@@ -293,51 +269,52 @@ class Video2X:
keyboard_listener.start() keyboard_listener.start()
# a temporary variable that stores the exception # a temporary variable that stores the exception
exception = [] exceptions = []
try: try:
# wait for jobs in queue to deplete # let the context manager automatically stop the progress bar
while self.processed.value < total_frames - 1: with progress:
time.sleep(1)
# check processor health frame_index = 0
for process in self.processor_processes: while frame_index < total_frames:
if not process.is_alive():
raise Exception("process died unexpectedly")
# check decoder health current_frame = processed_frames.get(frame_index)
if not self.decoder.is_alive() and self.decoder.exception is not None:
raise Exception("decoder died unexpectedly")
# check encoder health if pause_flag.value is True or current_frame is None:
if not self.encoder.is_alive() and self.encoder.exception is not None: time.sleep(0.1)
raise Exception("encoder died unexpectedly") continue
# show progress bar when upscale starts # show the progress bar after the processing starts
if self.progress.disable is True and self.processed.value > 0: # reduces speed estimation inaccuracies and print overlaps
self.progress.disable = False if frame_index == 0:
self.progress.start() progress.disable = False
progress.start()
# update progress if current_frame is True:
if self.pause.value is False: encoder.write(processed_frames.get(frame_index - 1))
self.progress.update(self.task, completed=self.processed.value)
self.progress.update(self.task, completed=total_frames) else:
self.progress.stop() encoder.write(current_frame)
logger.info("Processing has completed")
if frame_index > 0:
del processed_frames[frame_index - 1]
progress.update(task, completed=frame_index + 1)
frame_index += 1
# if SIGTERM is received or ^C is pressed # if SIGTERM is received or ^C is pressed
except (SystemExit, KeyboardInterrupt) as error: except (SystemExit, KeyboardInterrupt) as error:
self.progress.stop()
logger.warning("Exit signal received, exiting gracefully") logger.warning("Exit signal received, exiting gracefully")
logger.warning("Press ^C again to force terminate") logger.warning("Press ^C again to force terminate")
exception.append(error) exceptions.append(error)
except Exception as error: except Exception as error:
self.progress.stop()
logger.exception(error) logger.exception(error)
exception.append(error) exceptions.append(error)
else:
logger.info("Processing has completed")
finally: finally:
@@ -346,31 +323,30 @@ class Video2X:
keyboard_listener.stop() keyboard_listener.stop()
keyboard_listener.join() keyboard_listener.join()
# stop progress display # if errors have occurred, kill the FFmpeg processes
self.progress.stop() if len(exceptions) > 0:
decoder.kill()
encoder.kill()
# stop processor processes # stop the decoder
logger.info("Stopping processor processes") decoder_thread.stop()
for process in self.processor_processes: decoder_thread.join()
process.terminate()
# wait for processes to finish # stop the encoder
for process in self.processor_processes: encoder.join()
process.join()
# stop encoder and decoder logger.critical("ENCODER")
logger.info("Stopping decoder and encoder threads")
self.decoder.stop()
self.encoder.stop()
self.decoder.join()
self.encoder.join()
# mark processing queue as closed # clear queue and signal processors to exit
self.processing_queue.close() # multiprocessing.Queue has no Queue.queue.clear
while tasks_queue.empty() is not True:
tasks_queue.get()
for _ in range(processes):
tasks_queue.put(None)
# raise the error if there is any # close and join the process pool
if len(exception) > 0: processor_pool.close()
raise exception[0] processor_pool.join()
# restore original STDOUT and STDERR # restore original STDOUT and STDERR
sys.stdout = original_stdout sys.stdout = original_stdout
@@ -380,10 +356,14 @@ class Video2X:
logger.remove() logger.remove()
logger.add(sys.stderr, colorize=True, format=LOGURU_FORMAT) logger.add(sys.stderr, colorize=True, format=LOGURU_FORMAT)
# raise the first collected error
if len(exceptions) > 0:
raise exceptions[0]
def upscale( def upscale(
self, self,
input_path: pathlib.Path, input_path: Path,
output_path: pathlib.Path, output_path: Path,
output_width: int, output_width: int,
output_height: int, output_height: int,
noise: int, noise: int,
@@ -416,22 +396,21 @@ class Video2X:
output_path, output_path,
output_width, output_width,
output_height, output_height,
Upscaler, ProcessingMode.UPSCALE,
"upscale",
processes, processes,
( (
output_width, output_width,
output_height, output_height,
algorithm,
noise, noise,
threshold, threshold,
algorithm,
), ),
) )
def interpolate( def interpolate(
self, self,
input_path: pathlib.Path, input_path: Path,
output_path: pathlib.Path, output_path: Path,
processes: int, processes: int,
threshold: float, threshold: float,
algorithm: str, algorithm: str,
@@ -453,192 +432,7 @@ class Video2X:
output_path, output_path,
width, width,
height, height,
Interpolator, ProcessingMode.INTERPOLATE,
"interpolate",
processes, processes,
(threshold, algorithm), (threshold, algorithm),
) )
def parse_arguments() -> argparse.Namespace:
"""
parse command line arguments
:rtype argparse.Namespace: command parsing results
"""
parser = argparse.ArgumentParser(
prog="video2x",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--version", help="show version information and exit", action="store_true"
)
parser.add_argument(
"-i",
"--input",
type=pathlib.Path,
help="input file/directory path",
required=True,
)
parser.add_argument(
"-o",
"--output",
type=pathlib.Path,
help="output file/directory path",
required=True,
)
parser.add_argument(
"-p", "--processes", type=int, help="number of processes to launch", default=1
)
parser.add_argument(
"-l",
"--loglevel",
choices=["trace", "debug", "info", "success", "warning", "error", "critical"],
default="info",
)
# upscaler arguments
action = parser.add_subparsers(
help="action to perform", dest="action", required=True
)
upscale = action.add_parser(
"upscale",
help="upscale a file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
add_help=False,
)
upscale.add_argument(
"--help", action="help", help="show this help message and exit"
)
upscale.add_argument("-w", "--width", type=int, help="output width")
upscale.add_argument("-h", "--height", type=int, help="output height")
upscale.add_argument("-n", "--noise", type=int, help="denoise level", default=3)
upscale.add_argument(
"-a",
"--algorithm",
choices=UPSCALING_ALGORITHMS,
help="algorithm to use for upscaling",
default=UPSCALING_ALGORITHMS[0],
)
upscale.add_argument(
"-t",
"--threshold",
type=float,
help=(
"skip if the percent difference between two adjacent frames is below this"
" value; set to 0 to process all frames"
),
default=0,
)
# interpolator arguments
interpolate = action.add_parser(
"interpolate",
help="interpolate frames for file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
add_help=False,
)
interpolate.add_argument(
"--help", action="help", help="show this help message and exit"
)
interpolate.add_argument(
"-a",
"--algorithm",
choices=UPSCALING_ALGORITHMS,
help="algorithm to use for upscaling",
default=INTERPOLATION_ALGORITHMS[0],
)
interpolate.add_argument(
"-t",
"--threshold",
type=float,
help=(
"skip if the percent difference between two adjacent frames exceeds this"
" value; set to 100 to interpolate all frames"
),
default=10,
)
return parser.parse_args()
def main() -> int:
"""
command line entrypoint for direct CLI invocation
:rtype int: 0 if completed successfully, else other int
"""
try:
# display version and lawful informaition
if "--version" in sys.argv:
rich_print(LEGAL_INFO)
return 0
# parse command line arguments
args = parse_arguments()
# check input/output file paths
if not args.input.exists():
logger.critical(f"Cannot find input file: {args.input}")
return 1
if not args.input.is_file():
logger.critical("Input path is not a file")
return 1
if not args.output.parent.exists():
logger.critical(f"Output directory does not exist: {args.output.parent}")
return 1
# set logger level
if os.environ.get("LOGURU_LEVEL") is None:
os.environ["LOGURU_LEVEL"] = args.loglevel.upper()
# remove default handler
logger.remove()
# add new sink with custom handler
logger.add(sys.stderr, colorize=True, format=LOGURU_FORMAT)
# print package version and copyright notice
logger.opt(colors=True).info(f"<magenta>Video2X {__version__}</magenta>")
logger.opt(colors=True).info(
"<magenta>Copyright (C) 2018-2022 K4YT3X and contributors.</magenta>"
)
# initialize video2x object
video2x = Video2X()
if args.action == "upscale":
video2x.upscale(
args.input,
args.output,
args.width,
args.height,
args.noise,
args.processes,
args.threshold,
args.algorithm,
)
elif args.action == "interpolate":
video2x.interpolate(
args.input,
args.output,
args.processes,
args.threshold,
args.algorithm,
)
# don't print the traceback for manual terminations
except KeyboardInterrupt:
return 2
except Exception as error:
logger.exception(error)
return 1
# if no exceptions were produced
else:
logger.success("Processing completed successfully")
return 0