feat(video2x): added the --listgpus option and GPU ID validation

Signed-off-by: k4yt3x <i@k4yt3x.com>
This commit is contained in:
k4yt3x
2024-11-04 00:00:00 +00:00
parent 850e0fde9c
commit bcbe33d5dc
5 changed files with 191 additions and 51 deletions

View File

@@ -33,6 +33,7 @@ extern "C" {
}
#include <spdlog/spdlog.h>
#include <vulkan/vulkan.h>
#ifdef _WIN32
#define BOOST_PROGRAM_OPTIONS_WCHAR_T
@@ -49,17 +50,22 @@ namespace po = boost::program_options;
// Indicate if a newline needs to be printed before the next output
std::atomic<bool> newline_required = false;
// Mutex for synchronizing access to VideoProcessingContext
std::mutex proc_ctx_mutex;
// Structure to hold parsed arguments
struct Arguments {
StringType loglevel = STR("info");
bool noprogress = false;
// General options
std::filesystem::path in_fname;
std::filesystem::path out_fname;
StringType filter_type;
uint32_t gpuid = 0;
StringType hwaccel = STR("none");
bool nocopystreams = false;
bool benchmark = false;
StringType loglevel = STR("info");
bool noprogress = false;
// Encoder options
StringType codec = STR("libx264");
@@ -74,7 +80,6 @@ struct Arguments {
int out_height = 0;
// RealESRGAN options
int gpuid = 0;
StringType model_name;
int scaling_factor = 0;
};
@@ -161,8 +166,112 @@ enum Libvideo2xLogLevel parse_log_level(const StringType &level_name) {
}
}
// Mutex for synchronizing access to VideoProcessingContext
std::mutex proc_ctx_mutex;
int list_gpus() {
// Create a Vulkan instance
VkInstance instance;
VkInstanceCreateInfo create_info{};
create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
if (vkCreateInstance(&create_info, nullptr, &instance) != VK_SUCCESS) {
spdlog::critical("Failed to create Vulkan instance.");
return -1;
}
// Enumerate physical devices
uint32_t device_count = 0;
VkResult result = vkEnumeratePhysicalDevices(instance, &device_count, nullptr);
if (result != VK_SUCCESS) {
spdlog::critical("Failed to enumerate Vulkan physical devices.");
vkDestroyInstance(instance, nullptr);
return -1;
}
// Check if any devices are found
if (device_count == 0) {
spdlog::critical("No Vulkan physical devices found.");
vkDestroyInstance(instance, nullptr);
return -1;
}
// Get physical device properties
std::vector<VkPhysicalDevice> physical_devices(device_count);
result = vkEnumeratePhysicalDevices(instance, &device_count, physical_devices.data());
if (result != VK_SUCCESS) {
spdlog::critical("Failed to enumerate Vulkan physical devices.");
vkDestroyInstance(instance, nullptr);
return -1;
}
// List GPU information
for (uint32_t i = 0; i < device_count; i++) {
VkPhysicalDevice device = physical_devices[i];
VkPhysicalDeviceProperties device_properties;
vkGetPhysicalDeviceProperties(device, &device_properties);
// Print GPU ID and name
std::cout << i << ". " << device_properties.deviceName << std::endl;
std::cout << "\tType: ";
switch (device_properties.deviceType) {
case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU:
std::cout << "Integrated GPU";
break;
case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU:
std::cout << "Discrete GPU";
break;
case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU:
std::cout << "Virtual GPU";
break;
case VK_PHYSICAL_DEVICE_TYPE_CPU:
std::cout << "CPU";
break;
default:
std::cout << "Unknown";
break;
}
std::cout << std::endl;
// Print Vulkan API version
std::cout << "\tVulkan API Version: " << VK_VERSION_MAJOR(device_properties.apiVersion)
<< "." << VK_VERSION_MINOR(device_properties.apiVersion) << "."
<< VK_VERSION_PATCH(device_properties.apiVersion) << std::endl;
// Print driver version
std::cout << "\tDriver Version: " << VK_VERSION_MAJOR(device_properties.driverVersion)
<< "." << VK_VERSION_MINOR(device_properties.driverVersion) << "."
<< VK_VERSION_PATCH(device_properties.driverVersion) << std::endl;
}
// Clean up Vulkan instance
vkDestroyInstance(instance, nullptr);
return 0;
}
int is_valid_gpu_id(uint32_t gpu_id) {
// Create a Vulkan instance
VkInstance instance;
VkInstanceCreateInfo create_info{};
create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
if (vkCreateInstance(&create_info, nullptr, &instance) != VK_SUCCESS) {
spdlog::error("Failed to create Vulkan instance.");
return -1;
}
// Enumerate physical devices
uint32_t device_count = 0;
VkResult result = vkEnumeratePhysicalDevices(instance, &device_count, nullptr);
if (result != VK_SUCCESS) {
spdlog::error("Failed to enumerate Vulkan physical devices.");
vkDestroyInstance(instance, nullptr);
return -1;
}
// Clean up Vulkan instance
vkDestroyInstance(instance, nullptr);
if (gpu_id >= device_count) {
return 0;
}
return 1;
}
// Wrapper function for video processing thread
void process_video_thread(
@@ -194,6 +303,7 @@ void process_video_thread(
out_fname,
log_level,
arguments->benchmark,
arguments->gpuid,
hw_device_type,
filter_config,
encoder_config,
@@ -224,11 +334,13 @@ int main(int argc, char **argv) {
("version,v", "Print program version")
("loglevel", PO_STR_VALUE<StringType>(&arguments.loglevel)->default_value(STR("info"), "info"), "Set log level (trace, debug, info, warn, error, critical, none)")
("noprogress", po::bool_switch(&arguments.noprogress), "Do not display the progress bar")
("listgpus", "List the available GPUs")
// General Processing Options
("input,i", PO_STR_VALUE<StringType>(), "Input video file path")
("output,o", PO_STR_VALUE<StringType>(), "Output video file path")
("filter,f", PO_STR_VALUE<StringType>(&arguments.filter_type), "Filter to use: 'libplacebo' or 'realesrgan'")
("gpuid,g", po::value<uint32_t>(&arguments.gpuid)->default_value(0), "Vulkan GPU ID (default: 0)")
("hwaccel,a", PO_STR_VALUE<StringType>(&arguments.hwaccel)->default_value(STR("none"), "none"), "Hardware acceleration method (default: none)")
("nocopystreams", po::bool_switch(&arguments.nocopystreams), "Do not copy audio and subtitle streams")
("benchmark", po::bool_switch(&arguments.benchmark), "Discard processed frames and calculate average FPS")
@@ -246,7 +358,6 @@ int main(int argc, char **argv) {
("height,h", po::value<int>(&arguments.out_height), "Output height")
// RealESRGAN options
("gpuid,g", po::value<int>(&arguments.gpuid)->default_value(0), "Vulkan GPU ID (default: 0)")
("model,m", PO_STR_VALUE<StringType>(&arguments.model_name), "Name of the model to use")
("scale,r", po::value<int>(&arguments.scaling_factor), "Scaling factor (2, 3, or 4)")
;
@@ -274,23 +385,27 @@ int main(int argc, char **argv) {
return 0;
}
if (vm.count("listgpus")) {
return list_gpus();
}
// Assign positional arguments
if (vm.count("input")) {
arguments.in_fname = std::filesystem::path(vm["input"].as<StringType>());
} else {
spdlog::error("Error: Input file path is required.");
spdlog::critical("Input file path is required.");
return 1;
}
if (vm.count("output")) {
arguments.out_fname = std::filesystem::path(vm["output"].as<StringType>());
} else if (!arguments.benchmark) {
spdlog::error("Error: Output file path is required.");
spdlog::critical("Output file path is required.");
return 1;
}
if (!vm.count("filter")) {
spdlog::error("Error: Filter type is required (libplacebo or realesrgan).");
spdlog::critical("Filter type is required (libplacebo or realesrgan).");
return 1;
}
@@ -300,18 +415,18 @@ int main(int argc, char **argv) {
if (vm.count("model")) {
if (!is_valid_realesrgan_model(vm["model"].as<StringType>())) {
spdlog::error(
"Error: Invalid model specified. Must be 'realesrgan-plus', "
spdlog::critical(
"Invalid model specified. Must be 'realesrgan-plus', "
"'realesrgan-plus-anime', or 'realesr-animevideov3'."
);
return 1;
}
}
} catch (const po::error &e) {
spdlog::error("Error parsing options: {}", e.what());
spdlog::critical("Error parsing options: {}", e.what());
return 1;
} catch (const std::exception &e) {
spdlog::error("Unexpected exception caught while parsing options: {}", e.what());
spdlog::critical("Unexpected exception caught while parsing options: {}", e.what());
return 1;
}
@@ -319,45 +434,52 @@ int main(int argc, char **argv) {
if (arguments.filter_type == STR("libplacebo")) {
if (arguments.shader_path.empty() || arguments.out_width == 0 ||
arguments.out_height == 0) {
spdlog::error(
"Error: For libplacebo, shader name/path (-s), width (-w), "
spdlog::critical(
"For libplacebo, shader name/path (-s), width (-w), "
"and height (-h) are required."
);
return 1;
}
} else if (arguments.filter_type == STR("realesrgan")) {
if (arguments.scaling_factor == 0 || arguments.model_name.empty()) {
spdlog::error("Error: For realesrgan, scaling factor (-r) and model (-m) are required."
);
spdlog::critical("For realesrgan, scaling factor (-r) and model (-m) are required.");
return 1;
}
if (arguments.scaling_factor != 2 && arguments.scaling_factor != 3 &&
arguments.scaling_factor != 4) {
spdlog::error("Error: Scaling factor must be 2, 3, or 4.");
spdlog::critical("Scaling factor must be 2, 3, or 4.");
return 1;
}
} else {
spdlog::error("Error: Invalid filter type specified. Must be 'libplacebo' or 'realesrgan'."
);
spdlog::critical("Invalid filter type specified. Must be 'libplacebo' or 'realesrgan'.");
return 1;
}
// Validate GPU ID
int gpu_status = is_valid_gpu_id(arguments.gpuid);
if (gpu_status < 0) {
spdlog::warn("Unable to validate GPU ID.");
} else if (arguments.gpuid > 0 && gpu_status == 0) {
spdlog::critical("Invalid GPU ID specified.");
return 1;
}
// Validate bitrate
if (arguments.bitrate < 0) {
spdlog::error("Error: Invalid bitrate specified.");
spdlog::critical("Invalid bitrate specified.");
return 1;
}
// Validate CRF
if (arguments.crf < 0.0f || arguments.crf > 51.0f) {
spdlog::error("Error: CRF must be between 0 and 51.");
spdlog::critical("CRF must be between 0 and 51.");
return 1;
}
// Parse codec to AVCodec
const AVCodec *codec = avcodec_find_encoder_by_name(wstring_to_utf8(arguments.codec).c_str());
if (!codec) {
spdlog::error("Error: Codec '{}' not found.", wstring_to_utf8(arguments.codec));
spdlog::critical("Codec '{}' not found.", wstring_to_utf8(arguments.codec));
return 1;
}
@@ -366,7 +488,7 @@ int main(int argc, char **argv) {
if (!arguments.pix_fmt.empty()) {
pix_fmt = av_get_pix_fmt(wstring_to_utf8(arguments.pix_fmt).c_str());
if (pix_fmt == AV_PIX_FMT_NONE) {
spdlog::error("Error: Invalid pixel format '{}'.", wstring_to_utf8(arguments.pix_fmt));
spdlog::critical("Invalid pixel format '{}'.", wstring_to_utf8(arguments.pix_fmt));
return 1;
}
}
@@ -400,6 +522,10 @@ int main(int argc, char **argv) {
break;
}
// Print program version and processing information
spdlog::info("Video2X version {}", LIBVIDEO2X_VERSION_STRING);
spdlog::info("Processing file: {}", arguments.in_fname.u8string());
#ifdef _WIN32
std::wstring shader_path_str = arguments.shader_path.wstring();
#else
@@ -415,7 +541,6 @@ int main(int argc, char **argv) {
filter_config.config.libplacebo.shader_path = shader_path_str.c_str();
} else if (arguments.filter_type == STR("realesrgan")) {
filter_config.filter_type = FILTER_REALESRGAN;
filter_config.config.realesrgan.gpuid = arguments.gpuid;
filter_config.config.realesrgan.tta_mode = false;
filter_config.config.realesrgan.scaling_factor = arguments.scaling_factor;
filter_config.config.realesrgan.model_name = arguments.model_name.c_str();
@@ -439,8 +564,8 @@ int main(int argc, char **argv) {
if (arguments.hwaccel != STR("none")) {
hw_device_type = av_hwdevice_find_type_by_name(wstring_to_utf8(arguments.hwaccel).c_str());
if (hw_device_type == AV_HWDEVICE_TYPE_NONE) {
spdlog::error(
"Error: Invalid hardware device type '{}'.", wstring_to_utf8(arguments.hwaccel)
spdlog::critical(
"Invalid hardware device type '{}'.", wstring_to_utf8(arguments.hwaccel)
);
return 1;
}