mirror of
https://github.com/k4yt3x/video2x.git
synced 2026-02-14 00:54:47 +08:00
feat(video2x): added the --listgpus option and GPU ID validation
Signed-off-by: k4yt3x <i@k4yt3x.com>
This commit is contained in:
179
src/video2x.cpp
179
src/video2x.cpp
@@ -33,6 +33,7 @@ extern "C" {
|
||||
}
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <vulkan/vulkan.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#define BOOST_PROGRAM_OPTIONS_WCHAR_T
|
||||
@@ -49,17 +50,22 @@ namespace po = boost::program_options;
|
||||
// Indicate if a newline needs to be printed before the next output
|
||||
std::atomic<bool> newline_required = false;
|
||||
|
||||
// Mutex for synchronizing access to VideoProcessingContext
|
||||
std::mutex proc_ctx_mutex;
|
||||
|
||||
// Structure to hold parsed arguments
|
||||
struct Arguments {
|
||||
StringType loglevel = STR("info");
|
||||
bool noprogress = false;
|
||||
|
||||
// General options
|
||||
std::filesystem::path in_fname;
|
||||
std::filesystem::path out_fname;
|
||||
StringType filter_type;
|
||||
uint32_t gpuid = 0;
|
||||
StringType hwaccel = STR("none");
|
||||
bool nocopystreams = false;
|
||||
bool benchmark = false;
|
||||
StringType loglevel = STR("info");
|
||||
bool noprogress = false;
|
||||
|
||||
// Encoder options
|
||||
StringType codec = STR("libx264");
|
||||
@@ -74,7 +80,6 @@ struct Arguments {
|
||||
int out_height = 0;
|
||||
|
||||
// RealESRGAN options
|
||||
int gpuid = 0;
|
||||
StringType model_name;
|
||||
int scaling_factor = 0;
|
||||
};
|
||||
@@ -161,8 +166,112 @@ enum Libvideo2xLogLevel parse_log_level(const StringType &level_name) {
|
||||
}
|
||||
}
|
||||
|
||||
// Mutex for synchronizing access to VideoProcessingContext
|
||||
std::mutex proc_ctx_mutex;
|
||||
int list_gpus() {
|
||||
// Create a Vulkan instance
|
||||
VkInstance instance;
|
||||
VkInstanceCreateInfo create_info{};
|
||||
create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
|
||||
if (vkCreateInstance(&create_info, nullptr, &instance) != VK_SUCCESS) {
|
||||
spdlog::critical("Failed to create Vulkan instance.");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Enumerate physical devices
|
||||
uint32_t device_count = 0;
|
||||
VkResult result = vkEnumeratePhysicalDevices(instance, &device_count, nullptr);
|
||||
if (result != VK_SUCCESS) {
|
||||
spdlog::critical("Failed to enumerate Vulkan physical devices.");
|
||||
vkDestroyInstance(instance, nullptr);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Check if any devices are found
|
||||
if (device_count == 0) {
|
||||
spdlog::critical("No Vulkan physical devices found.");
|
||||
vkDestroyInstance(instance, nullptr);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Get physical device properties
|
||||
std::vector<VkPhysicalDevice> physical_devices(device_count);
|
||||
result = vkEnumeratePhysicalDevices(instance, &device_count, physical_devices.data());
|
||||
if (result != VK_SUCCESS) {
|
||||
spdlog::critical("Failed to enumerate Vulkan physical devices.");
|
||||
vkDestroyInstance(instance, nullptr);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// List GPU information
|
||||
for (uint32_t i = 0; i < device_count; i++) {
|
||||
VkPhysicalDevice device = physical_devices[i];
|
||||
VkPhysicalDeviceProperties device_properties;
|
||||
vkGetPhysicalDeviceProperties(device, &device_properties);
|
||||
|
||||
// Print GPU ID and name
|
||||
std::cout << i << ". " << device_properties.deviceName << std::endl;
|
||||
std::cout << "\tType: ";
|
||||
switch (device_properties.deviceType) {
|
||||
case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU:
|
||||
std::cout << "Integrated GPU";
|
||||
break;
|
||||
case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU:
|
||||
std::cout << "Discrete GPU";
|
||||
break;
|
||||
case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU:
|
||||
std::cout << "Virtual GPU";
|
||||
break;
|
||||
case VK_PHYSICAL_DEVICE_TYPE_CPU:
|
||||
std::cout << "CPU";
|
||||
break;
|
||||
default:
|
||||
std::cout << "Unknown";
|
||||
break;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
// Print Vulkan API version
|
||||
std::cout << "\tVulkan API Version: " << VK_VERSION_MAJOR(device_properties.apiVersion)
|
||||
<< "." << VK_VERSION_MINOR(device_properties.apiVersion) << "."
|
||||
<< VK_VERSION_PATCH(device_properties.apiVersion) << std::endl;
|
||||
|
||||
// Print driver version
|
||||
std::cout << "\tDriver Version: " << VK_VERSION_MAJOR(device_properties.driverVersion)
|
||||
<< "." << VK_VERSION_MINOR(device_properties.driverVersion) << "."
|
||||
<< VK_VERSION_PATCH(device_properties.driverVersion) << std::endl;
|
||||
}
|
||||
|
||||
// Clean up Vulkan instance
|
||||
vkDestroyInstance(instance, nullptr);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int is_valid_gpu_id(uint32_t gpu_id) {
|
||||
// Create a Vulkan instance
|
||||
VkInstance instance;
|
||||
VkInstanceCreateInfo create_info{};
|
||||
create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
|
||||
if (vkCreateInstance(&create_info, nullptr, &instance) != VK_SUCCESS) {
|
||||
spdlog::error("Failed to create Vulkan instance.");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Enumerate physical devices
|
||||
uint32_t device_count = 0;
|
||||
VkResult result = vkEnumeratePhysicalDevices(instance, &device_count, nullptr);
|
||||
if (result != VK_SUCCESS) {
|
||||
spdlog::error("Failed to enumerate Vulkan physical devices.");
|
||||
vkDestroyInstance(instance, nullptr);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Clean up Vulkan instance
|
||||
vkDestroyInstance(instance, nullptr);
|
||||
|
||||
if (gpu_id >= device_count) {
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Wrapper function for video processing thread
|
||||
void process_video_thread(
|
||||
@@ -194,6 +303,7 @@ void process_video_thread(
|
||||
out_fname,
|
||||
log_level,
|
||||
arguments->benchmark,
|
||||
arguments->gpuid,
|
||||
hw_device_type,
|
||||
filter_config,
|
||||
encoder_config,
|
||||
@@ -224,11 +334,13 @@ int main(int argc, char **argv) {
|
||||
("version,v", "Print program version")
|
||||
("loglevel", PO_STR_VALUE<StringType>(&arguments.loglevel)->default_value(STR("info"), "info"), "Set log level (trace, debug, info, warn, error, critical, none)")
|
||||
("noprogress", po::bool_switch(&arguments.noprogress), "Do not display the progress bar")
|
||||
("listgpus", "List the available GPUs")
|
||||
|
||||
// General Processing Options
|
||||
("input,i", PO_STR_VALUE<StringType>(), "Input video file path")
|
||||
("output,o", PO_STR_VALUE<StringType>(), "Output video file path")
|
||||
("filter,f", PO_STR_VALUE<StringType>(&arguments.filter_type), "Filter to use: 'libplacebo' or 'realesrgan'")
|
||||
("gpuid,g", po::value<uint32_t>(&arguments.gpuid)->default_value(0), "Vulkan GPU ID (default: 0)")
|
||||
("hwaccel,a", PO_STR_VALUE<StringType>(&arguments.hwaccel)->default_value(STR("none"), "none"), "Hardware acceleration method (default: none)")
|
||||
("nocopystreams", po::bool_switch(&arguments.nocopystreams), "Do not copy audio and subtitle streams")
|
||||
("benchmark", po::bool_switch(&arguments.benchmark), "Discard processed frames and calculate average FPS")
|
||||
@@ -246,7 +358,6 @@ int main(int argc, char **argv) {
|
||||
("height,h", po::value<int>(&arguments.out_height), "Output height")
|
||||
|
||||
// RealESRGAN options
|
||||
("gpuid,g", po::value<int>(&arguments.gpuid)->default_value(0), "Vulkan GPU ID (default: 0)")
|
||||
("model,m", PO_STR_VALUE<StringType>(&arguments.model_name), "Name of the model to use")
|
||||
("scale,r", po::value<int>(&arguments.scaling_factor), "Scaling factor (2, 3, or 4)")
|
||||
;
|
||||
@@ -274,23 +385,27 @@ int main(int argc, char **argv) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (vm.count("listgpus")) {
|
||||
return list_gpus();
|
||||
}
|
||||
|
||||
// Assign positional arguments
|
||||
if (vm.count("input")) {
|
||||
arguments.in_fname = std::filesystem::path(vm["input"].as<StringType>());
|
||||
} else {
|
||||
spdlog::error("Error: Input file path is required.");
|
||||
spdlog::critical("Input file path is required.");
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (vm.count("output")) {
|
||||
arguments.out_fname = std::filesystem::path(vm["output"].as<StringType>());
|
||||
} else if (!arguments.benchmark) {
|
||||
spdlog::error("Error: Output file path is required.");
|
||||
spdlog::critical("Output file path is required.");
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!vm.count("filter")) {
|
||||
spdlog::error("Error: Filter type is required (libplacebo or realesrgan).");
|
||||
spdlog::critical("Filter type is required (libplacebo or realesrgan).");
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -300,18 +415,18 @@ int main(int argc, char **argv) {
|
||||
|
||||
if (vm.count("model")) {
|
||||
if (!is_valid_realesrgan_model(vm["model"].as<StringType>())) {
|
||||
spdlog::error(
|
||||
"Error: Invalid model specified. Must be 'realesrgan-plus', "
|
||||
spdlog::critical(
|
||||
"Invalid model specified. Must be 'realesrgan-plus', "
|
||||
"'realesrgan-plus-anime', or 'realesr-animevideov3'."
|
||||
);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
} catch (const po::error &e) {
|
||||
spdlog::error("Error parsing options: {}", e.what());
|
||||
spdlog::critical("Error parsing options: {}", e.what());
|
||||
return 1;
|
||||
} catch (const std::exception &e) {
|
||||
spdlog::error("Unexpected exception caught while parsing options: {}", e.what());
|
||||
spdlog::critical("Unexpected exception caught while parsing options: {}", e.what());
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -319,45 +434,52 @@ int main(int argc, char **argv) {
|
||||
if (arguments.filter_type == STR("libplacebo")) {
|
||||
if (arguments.shader_path.empty() || arguments.out_width == 0 ||
|
||||
arguments.out_height == 0) {
|
||||
spdlog::error(
|
||||
"Error: For libplacebo, shader name/path (-s), width (-w), "
|
||||
spdlog::critical(
|
||||
"For libplacebo, shader name/path (-s), width (-w), "
|
||||
"and height (-h) are required."
|
||||
);
|
||||
return 1;
|
||||
}
|
||||
} else if (arguments.filter_type == STR("realesrgan")) {
|
||||
if (arguments.scaling_factor == 0 || arguments.model_name.empty()) {
|
||||
spdlog::error("Error: For realesrgan, scaling factor (-r) and model (-m) are required."
|
||||
);
|
||||
spdlog::critical("For realesrgan, scaling factor (-r) and model (-m) are required.");
|
||||
return 1;
|
||||
}
|
||||
if (arguments.scaling_factor != 2 && arguments.scaling_factor != 3 &&
|
||||
arguments.scaling_factor != 4) {
|
||||
spdlog::error("Error: Scaling factor must be 2, 3, or 4.");
|
||||
spdlog::critical("Scaling factor must be 2, 3, or 4.");
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
spdlog::error("Error: Invalid filter type specified. Must be 'libplacebo' or 'realesrgan'."
|
||||
);
|
||||
spdlog::critical("Invalid filter type specified. Must be 'libplacebo' or 'realesrgan'.");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Validate GPU ID
|
||||
int gpu_status = is_valid_gpu_id(arguments.gpuid);
|
||||
if (gpu_status < 0) {
|
||||
spdlog::warn("Unable to validate GPU ID.");
|
||||
} else if (arguments.gpuid > 0 && gpu_status == 0) {
|
||||
spdlog::critical("Invalid GPU ID specified.");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Validate bitrate
|
||||
if (arguments.bitrate < 0) {
|
||||
spdlog::error("Error: Invalid bitrate specified.");
|
||||
spdlog::critical("Invalid bitrate specified.");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Validate CRF
|
||||
if (arguments.crf < 0.0f || arguments.crf > 51.0f) {
|
||||
spdlog::error("Error: CRF must be between 0 and 51.");
|
||||
spdlog::critical("CRF must be between 0 and 51.");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Parse codec to AVCodec
|
||||
const AVCodec *codec = avcodec_find_encoder_by_name(wstring_to_utf8(arguments.codec).c_str());
|
||||
if (!codec) {
|
||||
spdlog::error("Error: Codec '{}' not found.", wstring_to_utf8(arguments.codec));
|
||||
spdlog::critical("Codec '{}' not found.", wstring_to_utf8(arguments.codec));
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -366,7 +488,7 @@ int main(int argc, char **argv) {
|
||||
if (!arguments.pix_fmt.empty()) {
|
||||
pix_fmt = av_get_pix_fmt(wstring_to_utf8(arguments.pix_fmt).c_str());
|
||||
if (pix_fmt == AV_PIX_FMT_NONE) {
|
||||
spdlog::error("Error: Invalid pixel format '{}'.", wstring_to_utf8(arguments.pix_fmt));
|
||||
spdlog::critical("Invalid pixel format '{}'.", wstring_to_utf8(arguments.pix_fmt));
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
@@ -400,6 +522,10 @@ int main(int argc, char **argv) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Print program version and processing information
|
||||
spdlog::info("Video2X version {}", LIBVIDEO2X_VERSION_STRING);
|
||||
spdlog::info("Processing file: {}", arguments.in_fname.u8string());
|
||||
|
||||
#ifdef _WIN32
|
||||
std::wstring shader_path_str = arguments.shader_path.wstring();
|
||||
#else
|
||||
@@ -415,7 +541,6 @@ int main(int argc, char **argv) {
|
||||
filter_config.config.libplacebo.shader_path = shader_path_str.c_str();
|
||||
} else if (arguments.filter_type == STR("realesrgan")) {
|
||||
filter_config.filter_type = FILTER_REALESRGAN;
|
||||
filter_config.config.realesrgan.gpuid = arguments.gpuid;
|
||||
filter_config.config.realesrgan.tta_mode = false;
|
||||
filter_config.config.realesrgan.scaling_factor = arguments.scaling_factor;
|
||||
filter_config.config.realesrgan.model_name = arguments.model_name.c_str();
|
||||
@@ -439,8 +564,8 @@ int main(int argc, char **argv) {
|
||||
if (arguments.hwaccel != STR("none")) {
|
||||
hw_device_type = av_hwdevice_find_type_by_name(wstring_to_utf8(arguments.hwaccel).c_str());
|
||||
if (hw_device_type == AV_HWDEVICE_TYPE_NONE) {
|
||||
spdlog::error(
|
||||
"Error: Invalid hardware device type '{}'.", wstring_to_utf8(arguments.hwaccel)
|
||||
spdlog::critical(
|
||||
"Invalid hardware device type '{}'.", wstring_to_utf8(arguments.hwaccel)
|
||||
);
|
||||
return 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user