feat(rife): add support for frame interpolation and RIFE (#1244)

* feat: add RIFE files and processor/interpolator abstractions
* feat: add `rife` as processor option
* feat: add frame interpolation math except first frame
* feat: complete motion interpolation and add scene detection
* feat: improve Vulkan device validation
* fix: fix casting issues and variable names
* refactor: improve error-checking; add abstractions and factories
* refactor: improve readability of the frames processor
* docs: update changelog

Signed-off-by: k4yt3x <i@k4yt3x.com>
This commit is contained in:
K4YT3X
2024-12-01 09:55:56 +00:00
committed by GitHub
parent 2fc89e3883
commit 627f3d84a4
84 changed files with 4914 additions and 615 deletions

View File

@@ -2,12 +2,12 @@
#include <chrono>
#include <csignal>
#include <cstdarg>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
@@ -62,9 +62,9 @@ struct Arguments {
// General options
std::filesystem::path in_fname;
std::filesystem::path out_fname;
StringType filter_type;
uint32_t gpu_id = 0;
StringType processor_type;
StringType hwaccel = STR("none");
uint32_t vk_device_index = 0;
bool no_copy_streams = false;
bool benchmark = false;
@@ -85,14 +85,22 @@ struct Arguments {
int delay = 0;
std::vector<std::pair<StringType, StringType>> extra_options;
// libplacebo options
std::filesystem::path shader_path;
// General processing options
int width = 0;
int height = 0;
int scaling_factor = 0;
int frm_rate_mul = 2;
float scn_det_thresh = 0.0f;
// libplacebo options
std::filesystem::path libplacebo_shader_path;
// RealESRGAN options
StringType model_name;
int scaling_factor = 0;
StringType realesrgan_model_name = STR("realesr-animevideov3");
// RIFE options
StringType rife_model_name = STR("rife-v4.6");
bool rife_uhd_mode = false;
};
// Set UNIX terminal input to non-blocking mode
@@ -156,6 +164,23 @@ bool is_valid_realesrgan_model(const StringType &model) {
return valid_realesrgan_models.count(model) > 0;
}
bool is_valid_rife_model(const StringType &model) {
static const std::unordered_set<StringType> valid_realesrgan_models = {
STR("rife"),
STR("rife-HD"),
STR("rife-UHD"),
STR("rife-anime"),
STR("rife-v2"),
STR("rife-v2.3"),
STR("rife-v2.4"),
STR("rife-v3.0"),
STR("rife-v3.1"),
STR("rife-v4"),
STR("rife-v4.6"),
};
return valid_realesrgan_models.count(model) > 0;
}
enum Libvideo2xLogLevel parse_log_level(const StringType &level_name) {
if (level_name == STR("trace")) {
return LIBVIDEO2X_LOG_LEVEL_TRACE;
@@ -177,48 +202,54 @@ enum Libvideo2xLogLevel parse_log_level(const StringType &level_name) {
}
}
int list_gpus() {
int enumerate_vulkan_devices(VkInstance *instance, std::vector<VkPhysicalDevice> &devices) {
// Create a Vulkan instance
VkInstance instance;
VkInstanceCreateInfo create_info{};
create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
if (vkCreateInstance(&create_info, nullptr, &instance) != VK_SUCCESS) {
spdlog::critical("Failed to create Vulkan instance.");
VkResult result = vkCreateInstance(&create_info, nullptr, instance);
if (result != VK_SUCCESS) {
spdlog::error("Failed to create Vulkan instance.");
return -1;
}
// Enumerate physical devices
uint32_t device_count = 0;
VkResult result = vkEnumeratePhysicalDevices(instance, &device_count, nullptr);
result = vkEnumeratePhysicalDevices(*instance, &device_count, nullptr);
if (result != VK_SUCCESS || device_count == 0) {
spdlog::error("Failed to enumerate Vulkan physical devices or no devices available.");
vkDestroyInstance(*instance, nullptr);
return -1;
}
devices.resize(device_count);
result = vkEnumeratePhysicalDevices(*instance, &device_count, devices.data());
if (result != VK_SUCCESS) {
spdlog::critical("Failed to enumerate Vulkan physical devices.");
vkDestroyInstance(instance, nullptr);
spdlog::error("Failed to retrieve Vulkan physical devices.");
vkDestroyInstance(*instance, nullptr);
return -1;
}
// Check if any devices are found
if (device_count == 0) {
spdlog::critical("No Vulkan physical devices found.");
vkDestroyInstance(instance, nullptr);
return -1;
return 0;
}
int list_vulkan_devices() {
VkInstance instance;
std::vector<VkPhysicalDevice> physical_devices;
int result = enumerate_vulkan_devices(&instance, physical_devices);
if (result != 0) {
return result;
}
// Get physical device properties
std::vector<VkPhysicalDevice> physical_devices(device_count);
result = vkEnumeratePhysicalDevices(instance, &device_count, physical_devices.data());
if (result != VK_SUCCESS) {
spdlog::critical("Failed to enumerate Vulkan physical devices.");
vkDestroyInstance(instance, nullptr);
return -1;
}
uint32_t device_count = static_cast<uint32_t>(physical_devices.size());
// List GPU information
// List Vulkan device information
for (uint32_t i = 0; i < device_count; i++) {
VkPhysicalDevice device = physical_devices[i];
VkPhysicalDeviceProperties device_properties;
vkGetPhysicalDeviceProperties(device, &device_properties);
// Print GPU ID and name
// Print Vulkan device ID and name
std::cout << i << ". " << device_properties.deviceName << std::endl;
std::cout << "\tType: ";
switch (device_properties.deviceType) {
@@ -256,32 +287,34 @@ int list_gpus() {
return 0;
}
int is_valid_gpu_id(uint32_t gpu_id) {
// Create a Vulkan instance
VkInstance instance;
VkInstanceCreateInfo create_info{};
create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
if (vkCreateInstance(&create_info, nullptr, &instance) != VK_SUCCESS) {
spdlog::error("Failed to create Vulkan instance.");
int get_vulkan_device_prop(uint32_t vk_device_index, VkPhysicalDeviceProperties *dev_props) {
if (dev_props == nullptr) {
spdlog::error("Invalid device properties pointer.");
return -1;
}
// Enumerate physical devices
uint32_t device_count = 0;
VkResult result = vkEnumeratePhysicalDevices(instance, &device_count, nullptr);
if (result != VK_SUCCESS) {
spdlog::error("Failed to enumerate Vulkan physical devices.");
vkDestroyInstance(instance, nullptr);
return -1;
VkInstance instance;
std::vector<VkPhysicalDevice> devices;
int result = enumerate_vulkan_devices(&instance, devices);
if (result != 0) {
return result;
}
uint32_t device_count = static_cast<uint32_t>(devices.size());
// Check if the Vulkan device ID is valid
if (vk_device_index >= device_count) {
vkDestroyInstance(instance, nullptr);
return -2;
}
// Get device properties for the specified Vulkan device ID
vkGetPhysicalDeviceProperties(devices[vk_device_index], dev_props);
// Clean up Vulkan instance
vkDestroyInstance(instance, nullptr);
if (gpu_id >= device_count) {
return 0;
}
return 1;
return 0;
}
// Wrapper function for video processing thread
@@ -289,7 +322,7 @@ void process_video_thread(
Arguments *arguments,
int *proc_ret,
AVHWDeviceType hw_device_type,
FilterConfig *filter_config,
ProcessorConfig *filter_config,
EncoderConfig *encoder_config,
VideoProcessingContext *proc_ctx
) {
@@ -314,7 +347,7 @@ void process_video_thread(
out_fname,
log_level,
arguments->benchmark,
arguments->gpu_id,
arguments->vk_device_index,
hw_device_type,
filter_config,
encoder_config,
@@ -355,17 +388,17 @@ int main(int argc, char **argv) {
"info"), "Set verbosity level (trace, debug, info, warn, error, critical, none)")
("no-progress", po::bool_switch(&arguments.no_progress),
"Do not display the progress bar")
("list-gpus,l", "List the available GPUs")
("list-devices,l", "List the available Vulkan devices (GPUs)")
// General Processing Options
("input,i", PO_STR_VALUE<StringType>(), "Input video file path")
("output,o", PO_STR_VALUE<StringType>(), "Output video file path")
("filter,f", PO_STR_VALUE<StringType>(&arguments.filter_type),
"Filter to use: 'libplacebo' or 'realesrgan'")
("gpu,g", po::value<uint32_t>(&arguments.gpu_id)->default_value(0),
"GPU ID (Vulkan device index)")
("processor,p", PO_STR_VALUE<StringType>(&arguments.processor_type),
"Processor to use: 'libplacebo', 'realesrgan', or 'rife'")
("hwaccel,a", PO_STR_VALUE<StringType>(&arguments.hwaccel)->default_value(STR("none"),
"none"), "Hardware acceleration method (mostly for decoding)")
("device,d", po::value<uint32_t>(&arguments.vk_device_index)->default_value(0),
"Vulkan device index (GPU ID)")
("benchmark,b", po::bool_switch(&arguments.benchmark),
"Discard processed frames and calculate average FPS; "
"useful for detecting encoder bottlenecks")
@@ -407,27 +440,56 @@ int main(int argc, char **argv) {
"Additional AVOption(s) for the encoder (format: -e key=value)")
;
po::options_description libplacebo_opts("libplacebo options");
libplacebo_opts.add_options()
("shader,s", PO_STR_VALUE<StringType>(), "Name/path of the GLSL shader file to use")
po::options_description upscale_opts("Upscaling options");
upscale_opts.add_options()
("width,w", po::value<int>(&arguments.width), "Output width")
("height,h", po::value<int>(&arguments.height), "Output height")
("scaling-factor,s", po::value<int>(&arguments.scaling_factor), "Scaling factor")
;
po::options_description interp_opts("Frame interpolation options");
interp_opts.add_options()
("frame-rate-mul,m",
po::value<int>(&arguments.frm_rate_mul)->default_value(2),
"Frame rate multiplier")
("scene-thresh,t", po::value<float>(&arguments.scn_det_thresh)->default_value(10.0f),
"Scene detection threshold")
;
po::options_description libplacebo_opts("libplacebo options");
libplacebo_opts.add_options()
("libplacebo-shader", PO_STR_VALUE<StringType>(),
"Name/path of the GLSL shader file to use")
;
// RealESRGAN options
po::options_description realesrgan_opts("RealESRGAN options");
realesrgan_opts.add_options()
("model,m", PO_STR_VALUE<StringType>(&arguments.model_name), "Name of the model to use")
("scale,r", po::value<int>(&arguments.scaling_factor), "Scaling factor (2, 3, or 4)")
("realesrgan-model", PO_STR_VALUE<StringType>(&arguments.realesrgan_model_name),
"Name of the RealESRGAN model to use")
;
// RIFE options
po::options_description rife_opts("RIFE options");
rife_opts.add_options()
("rife-model", PO_STR_VALUE<StringType>(&arguments.rife_model_name),
"Name of the RIFE model to use")
("rife-uhd", po::bool_switch(&arguments.rife_uhd_mode),
"Enable Ultra HD mode")
;
// clang-format on
// Combine all options
all_opts.add(encoder_opts).add(libplacebo_opts).add(realesrgan_opts);
all_opts.add(encoder_opts)
.add(upscale_opts)
.add(interp_opts)
.add(libplacebo_opts)
.add(realesrgan_opts)
.add(rife_opts);
// Positional arguments
po::positional_options_description p;
p.add("input", 1).add("output", 1).add("filter", 1);
p.add("input", 1).add("output", 1).add("processor", 1);
#ifdef _WIN32
po::variables_map vm;
@@ -460,13 +522,19 @@ int main(int argc, char **argv) {
return 0;
}
if (vm.count("list-gpus")) {
return list_gpus();
if (vm.count("list-devices")) {
return list_vulkan_devices();
}
// Print program banner
spdlog::info("Video2X version {}", LIBVIDEO2X_VERSION_STRING);
// spdlog::info("Copyright (C) 2018-2024 K4YT3X and contributors.");
// spdlog::info("Licensed under GNU AGPL version 3.");
// Assign positional arguments
if (vm.count("input")) {
arguments.in_fname = std::filesystem::path(vm["input"].as<StringType>());
spdlog::info("Processing file: {}", arguments.in_fname.u8string());
} else {
spdlog::critical("Input file path is required.");
return 1;
@@ -479,12 +547,12 @@ int main(int argc, char **argv) {
return 1;
}
if (!vm.count("filter")) {
spdlog::critical("Filter type is required (libplacebo or realesrgan).");
if (!vm.count("processor")) {
spdlog::critical("Processor type is required (libplacebo, realesrgan, or rife).");
return 1;
}
// Parse avoptions
// Parse extra AVOptions
if (vm.count("extra-encoder-option")) {
for (const auto &opt : vm["extra-encoder-option"].as<std::vector<StringType>>()) {
size_t eq_pos = opt.find('=');
@@ -499,16 +567,21 @@ int main(int argc, char **argv) {
}
}
if (vm.count("shader")) {
arguments.shader_path = std::filesystem::path(vm["shader"].as<StringType>());
if (vm.count("libplacebo-shader")) {
arguments.libplacebo_shader_path =
std::filesystem::path(vm["libplacebo-shader"].as<StringType>());
}
if (vm.count("model")) {
if (!is_valid_realesrgan_model(vm["model"].as<StringType>())) {
spdlog::critical(
"Invalid model specified. Must be 'realesrgan-plus', "
"'realesrgan-plus-anime', or 'realesr-animevideov3'."
);
// Validate the RealESRGAN model name only when --realesrgan-model was given.
// The option is registered under the key "realesrgan-model"; guarding on any
// other key never matches, which would silently skip this validation.
if (vm.count("realesrgan-model")) {
    if (!is_valid_realesrgan_model(vm["realesrgan-model"].as<StringType>())) {
        spdlog::critical("Invalid model specified.");
        return 1;
    }
}
if (vm.count("rife-model")) {
if (!is_valid_rife_model(vm["rife-model"].as<StringType>())) {
spdlog::critical("Invalid RIFE model specified.");
return 1;
}
}
@@ -521,36 +594,59 @@ int main(int argc, char **argv) {
}
// Additional validations
if (arguments.filter_type == STR("libplacebo")) {
if (arguments.shader_path.empty() || arguments.width == 0 || arguments.height == 0) {
spdlog::critical(
"For libplacebo, shader name/path (-s), width (-w), "
"and height (-h) are required."
);
return 1;
}
} else if (arguments.filter_type == STR("realesrgan")) {
if (arguments.scaling_factor == 0 || arguments.model_name.empty()) {
spdlog::critical("For realesrgan, scaling factor (-r) and model (-m) are required.");
if (arguments.width < 0 || arguments.height < 0) {
spdlog::critical("Invalid output resolution specified.");
return 1;
}
if (arguments.scaling_factor < 0) {
spdlog::critical("Invalid scaling factor specified.");
return 1;
}
if (arguments.frm_rate_mul <= 1) {
spdlog::critical("Invalid frame rate multiplier specified.");
return 1;
}
if (arguments.scn_det_thresh < 0.0f || arguments.scn_det_thresh > 100.0f) {
spdlog::critical("Invalid scene detection threshold specified.");
return 1;
}
if (arguments.processor_type == STR("libplacebo")) {
if (arguments.libplacebo_shader_path.empty() || arguments.width == 0 ||
arguments.height == 0) {
spdlog::critical("Shader name/path, width, and height are required for libplacebo.");
return 1;
}
} else if (arguments.processor_type == STR("realesrgan")) {
if (arguments.scaling_factor != 2 && arguments.scaling_factor != 3 &&
arguments.scaling_factor != 4) {
spdlog::critical("Scaling factor must be 2, 3, or 4.");
spdlog::critical("Scaling factor must be 2, 3, or 4 for RealESRGAN.");
return 1;
}
} else {
spdlog::critical("Invalid filter type specified. Must be 'libplacebo' or 'realesrgan'.");
} else if (arguments.processor_type != STR("rife")) {
spdlog::critical(
"Invalid processor specified. Must be 'libplacebo', 'realesrgan', or 'rife'."
);
return 1;
}
// Validate GPU ID
int gpu_status = is_valid_gpu_id(arguments.gpu_id);
if (gpu_status < 0) {
spdlog::warn("Unable to validate GPU ID.");
} else if (arguments.gpu_id > 0 && gpu_status == 0) {
spdlog::critical("Invalid GPU ID specified.");
return 1;
// Look up the properties of the Vulkan device selected by index.
// A return value of -2 is handled as an invalid device index; any other
// non-zero value is handled as a failure to validate the device.
VkPhysicalDeviceProperties dev_props;
int get_vulkan_dev_ret = get_vulkan_device_prop(arguments.vk_device_index, &dev_props);
if (get_vulkan_dev_ret != 0) {
if (get_vulkan_dev_ret == -2) {
spdlog::critical("Invalid Vulkan device ID specified.");
return 1;
} else {
// NOTE(review): this path is logged at warn level but still aborts with
// exit code 1 -- confirm whether it should be critical or non-fatal.
spdlog::warn("Unable to validate Vulkan device ID.");
return 1;
}
} else {
// Report which Vulkan device will be used (name and device ID)
spdlog::info("Using Vulkan device: {} ({})", dev_props.deviceName, dev_props.deviceID);
// Warn if the selected device is a CPU
if (dev_props.deviceType == VK_PHYSICAL_DEVICE_TYPE_CPU) {
spdlog::warn("The selected Vulkan device is a CPU device.");
}
}
// Validate bitrate
@@ -605,36 +701,59 @@ int main(int argc, char **argv) {
break;
}
// Print program version and processing information
spdlog::info("Video2X version {}", LIBVIDEO2X_VERSION_STRING);
spdlog::info("Processing file: {}", arguments.in_fname.u8string());
#ifdef _WIN32
std::wstring shader_path_str = arguments.shader_path.wstring();
std::wstring shader_path_str = arguments.libplacebo_shader_path.wstring();
#else
std::string shader_path_str = arguments.shader_path.string();
std::string shader_path_str = arguments.libplacebo_shader_path.string();
#endif
// Setup filter configurations based on the parsed arguments
FilterConfig filter_config;
if (arguments.filter_type == STR("libplacebo")) {
filter_config.filter_type = FILTER_LIBPLACEBO;
filter_config.config.libplacebo.out_width = arguments.width;
filter_config.config.libplacebo.out_height = arguments.height;
filter_config.config.libplacebo.shader_path = shader_path_str.c_str();
} else if (arguments.filter_type == STR("realesrgan")) {
filter_config.filter_type = FILTER_REALESRGAN;
filter_config.config.realesrgan.tta_mode = false;
filter_config.config.realesrgan.scaling_factor = arguments.scaling_factor;
filter_config.config.realesrgan.model_name = arguments.model_name.c_str();
ProcessorConfig processor_config;
processor_config.width = arguments.width;
processor_config.height = arguments.height;
processor_config.scaling_factor = arguments.scaling_factor;
processor_config.frm_rate_mul = arguments.frm_rate_mul;
processor_config.scn_det_thresh = arguments.scn_det_thresh;
if (arguments.processor_type == STR("libplacebo")) {
processor_config.processor_type = PROCESSOR_LIBPLACEBO;
processor_config.config.libplacebo.shader_path = shader_path_str.c_str();
} else if (arguments.processor_type == STR("realesrgan")) {
processor_config.processor_type = PROCESSOR_REALESRGAN;
processor_config.config.realesrgan.tta_mode = false;
processor_config.config.realesrgan.model_name = arguments.realesrgan_model_name.c_str();
} else if (arguments.processor_type == STR("rife")) {
processor_config.processor_type = PROCESSOR_RIFE;
processor_config.config.rife.tta_mode = false;
processor_config.config.rife.tta_temporal_mode = false;
processor_config.config.rife.uhd_mode = arguments.rife_uhd_mode;
processor_config.config.rife.num_threads = 0;
processor_config.config.rife.model_name = arguments.rife_model_name.c_str();
bool rife_v2 = false;
bool rife_v4 = false;
if (arguments.rife_model_name.find(STR("rife-v2")) != StringType::npos) {
rife_v2 = true;
} else if (arguments.rife_model_name.find(STR("rife-v3")) != StringType::npos) {
rife_v2 = true;
} else if (arguments.rife_model_name.find(STR("rife-v4")) != StringType::npos) {
rife_v4 = true;
} else if (arguments.rife_model_name.find(STR("rife")) == StringType::npos) {
spdlog::critical("Unknown RIFE model generation.");
return 1;
}
processor_config.config.rife.rife_v2 = rife_v2;
processor_config.config.rife.rife_v4 = rife_v4;
}
// Setup encoder configuration
EncoderConfig encoder_config;
encoder_config.codec = codec->id;
encoder_config.copy_streams = !arguments.no_copy_streams;
encoder_config.width = arguments.width;
encoder_config.height = arguments.height;
encoder_config.width = 0;
encoder_config.height = 0;
encoder_config.pix_fmt = pix_fmt;
encoder_config.bit_rate = arguments.bit_rate;
encoder_config.rc_buffer_size = arguments.rc_buffer_size;
@@ -713,7 +832,7 @@ int main(int argc, char **argv) {
&arguments,
&proc_ret,
hw_device_type,
&filter_config,
&processor_config,
&encoder_config,
&proc_ctx
);