Initial commit as Claude developed it
Some checks failed
CI/CD Pipeline / Code Linting (push) Failing after 22s
CI/CD Pipeline / Build and Test (Debug, clang, ubuntu-latest) (push) Failing after 5m44s
CI/CD Pipeline / Build and Test (Debug, gcc, ubuntu-latest) (push) Failing after 5m33s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-20.04) (push) Failing after 6m12s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-latest) (push) Failing after 5m13s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-20.04) (push) Failing after 5m30s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-latest) (push) Failing after 5m33s
CI/CD Pipeline / Docker Build Test (push) Failing after 13s
CI/CD Pipeline / Performance Benchmarks (push) Has been skipped
CI/CD Pipeline / Build Documentation (push) Successful in 31s
CI/CD Pipeline / Create Release Package (push) Has been skipped

This commit is contained in:
2025-06-22 12:50:10 +02:00
parent 270c540556
commit d6dc083a5a
38 changed files with 10197 additions and 6 deletions

129
tests/CMakeLists.txt Normal file
View File

@@ -0,0 +1,129 @@
# Tests CMakeLists.txt
#
# Builds the Catch2-based test suite for the svm_classifier library and
# registers helper targets for grouped test runs, coverage, valgrind
# memory checking and perf profiling.

# Find Catch2 (should already be available from main CMakeLists.txt)
find_package(Catch2 3 REQUIRED)

# Include Catch2 extras for automatic test discovery (catch_discover_tests)
include(Catch)

# Test sources (explicit list; file(GLOB) would miss newly added files)
set(TEST_SOURCES
    test_main.cpp
    test_svm_classifier.cpp
    test_data_converter.cpp
    test_multiclass_strategy.cpp
    test_kernel_parameters.cpp
)

# Create test executable
add_executable(svm_classifier_tests ${TEST_SOURCES})

# Link with the main library and Catch2 (Catch2WithMain supplies main())
target_link_libraries(svm_classifier_tests
    PRIVATE
        svm_classifier
        Catch2::Catch2WithMain
)

# Set include directories.  PROJECT_SOURCE_DIR (not CMAKE_SOURCE_DIR) keeps
# these paths correct when this project is consumed as a subproject.
target_include_directories(svm_classifier_tests
    PRIVATE
        ${PROJECT_SOURCE_DIR}/include
        ${PROJECT_SOURCE_DIR}/external/libsvm
        ${PROJECT_SOURCE_DIR}/external/liblinear
)

# Compiler flags for tests
target_compile_features(svm_classifier_tests PRIVATE cxx_std_17)

# Warning flags (GCC/Clang only; MSVC would reject these spellings)
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
    target_compile_options(svm_classifier_tests PRIVATE
        -Wall -Wextra -pedantic -Wno-unused-parameter
    )
endif()

# Discover tests automatically.  Per-test properties must be attached here:
# catch_discover_tests() registers one CTest entry per Catch2 test case at
# build time, so a later set_tests_properties(svm_classifier_tests ...) call
# would fail — no CTest test named after the executable ever exists.
catch_discover_tests(svm_classifier_tests
    PROPERTIES
        TIMEOUT 300                          # 5 minutes per test
        ENVIRONMENT "TORCH_NUM_THREADS=1"    # single-threaded for reproducible results
)

# Convenience targets for the label-based test groups
add_custom_target(test_unit
    COMMAND ${CMAKE_CTEST_COMMAND} -L "unit" --output-on-failure
    DEPENDS svm_classifier_tests
    COMMENT "Running unit tests"
)
add_custom_target(test_integration
    COMMAND ${CMAKE_CTEST_COMMAND} -L "integration" --output-on-failure
    DEPENDS svm_classifier_tests
    COMMENT "Running integration tests"
)
add_custom_target(test_performance
    COMMAND ${CMAKE_CTEST_COMMAND} -L "performance" --output-on-failure
    DEPENDS svm_classifier_tests
    COMMENT "Running performance tests"
)
add_custom_target(test_all
    COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure
    DEPENDS svm_classifier_tests
    COMMENT "Running all tests"
)

# Coverage target (Debug builds only, requires gcov + lcov + genhtml)
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
    find_program(GCOV_EXECUTABLE gcov)
    find_program(LCOV_EXECUTABLE lcov)
    find_program(GENHTML_EXECUTABLE genhtml)
    if(GCOV_EXECUTABLE AND LCOV_EXECUTABLE AND GENHTML_EXECUTABLE)
        target_compile_options(svm_classifier_tests PRIVATE --coverage)
        target_link_options(svm_classifier_tests PRIVATE --coverage)
        add_custom_target(coverage
            COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure
            COMMAND ${LCOV_EXECUTABLE} --capture --directory . --output-file coverage.info
            COMMAND ${LCOV_EXECUTABLE} --remove coverage.info '/usr/*' '*/external/*' '*/tests/*' --output-file coverage_filtered.info
            COMMAND ${GENHTML_EXECUTABLE} coverage_filtered.info --output-directory coverage_html
            DEPENDS svm_classifier_tests
            WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
            COMMENT "Generating code coverage report"
        )
        message(STATUS "Code coverage target 'coverage' available")
    endif()
endif()

# Memory checking with valgrind, when installed
find_program(VALGRIND_EXECUTABLE valgrind)
if(VALGRIND_EXECUTABLE)
    add_custom_target(test_memcheck
        COMMAND ${VALGRIND_EXECUTABLE} --tool=memcheck --leak-check=full --show-leak-kinds=all
                --track-origins=yes --verbose --error-exitcode=1
                $<TARGET_FILE:svm_classifier_tests>
        DEPENDS svm_classifier_tests
        COMMENT "Running tests with valgrind memory check"
    )
    message(STATUS "Memory check target 'test_memcheck' available")
endif()

# Performance profiling with perf, when installed.  "[performance]" is a
# Catch2 tag filter passed to the test binary; quoted so it stays one argument.
find_program(PERF_EXECUTABLE perf)
if(PERF_EXECUTABLE)
    add_custom_target(test_profile
        COMMAND ${PERF_EXECUTABLE} record -g $<TARGET_FILE:svm_classifier_tests> "[performance]"
        COMMAND ${PERF_EXECUTABLE} report
        DEPENDS svm_classifier_tests
        COMMENT "Running performance tests with profiling"
    )
    message(STATUS "Performance profiling target 'test_profile' available")
endif()

View File

@@ -0,0 +1,360 @@
/**
* @file test_data_converter.cpp
* @brief Unit tests for DataConverter class
*/
#include <catch2/catch_all.hpp>
#include <svm_classifier/data_converter.hpp>
#include <torch/torch.h>
using namespace svm_classifier;
/**
 * @brief Sanity checks for tensor validation in DataConverter.
 */
TEST_CASE("DataConverter Basic Functionality", "[unit][data_converter]") {
    DataConverter converter;

    SECTION("Tensor validation") {
        // A well-formed (samples x features) matrix with matching labels.
        auto features = torch::randn({ 10, 5 });
        auto labels = torch::randint(0, 3, { 10 });
        REQUIRE_NOTHROW(converter.validate_tensors(features, labels));

        // A 1-D tensor is not an acceptable feature matrix.
        auto bad_rank = torch::randn({ 10 });
        REQUIRE_THROWS_AS(converter.validate_tensors(bad_rank, labels), std::invalid_argument);

        // Label count must equal the number of feature rows.
        auto bad_labels = torch::randint(0, 3, { 5 });
        REQUIRE_THROWS_AS(converter.validate_tensors(features, bad_labels), std::invalid_argument);

        // Zero samples and zero features are both rejected.
        auto no_samples = torch::empty({ 0, 5 });
        REQUIRE_THROWS_AS(converter.validate_tensors(no_samples, labels), std::invalid_argument);
        auto no_features = torch::empty({ 10, 0 });
        REQUIRE_THROWS_AS(converter.validate_tensors(no_features, labels), std::invalid_argument);
    }

    SECTION("NaN and Inf detection") {
        auto features = torch::randn({ 5, 3 });
        auto labels = torch::randint(0, 2, { 5 });

        // A single NaN anywhere in the matrix must be rejected.
        features[0][0] = std::numeric_limits<float>::quiet_NaN();
        REQUIRE_THROWS_AS(converter.validate_tensors(features, labels), std::invalid_argument);

        // Likewise for infinities.
        features[0][0] = std::numeric_limits<float>::infinity();
        REQUIRE_THROWS_AS(converter.validate_tensors(features, labels), std::invalid_argument);
    }
}
/**
 * @brief Conversion of torch tensors into libsvm problem structures.
 */
TEST_CASE("DataConverter SVM Problem Conversion", "[unit][data_converter]") {
    DataConverter converter;

    SECTION("Basic conversion") {
        auto features = torch::tensor({ {1.0, 2.0, 3.0},
                                        {4.0, 5.0, 6.0},
                                        {7.0, 8.0, 9.0} });
        auto labels = torch::tensor({ 0, 1, 2 });

        auto problem = converter.to_svm_problem(features, labels);
        REQUIRE(problem != nullptr);
        REQUIRE(problem->l == 3); // sample count
        REQUIRE(converter.get_n_samples() == 3);
        REQUIRE(converter.get_n_features() == 3);

        // Labels are stored as doubles inside the svm_problem.
        REQUIRE(problem->y[0] == Catch::Approx(0.0));
        REQUIRE(problem->y[1] == Catch::Approx(1.0));
        REQUIRE(problem->y[2] == Catch::Approx(2.0));
    }

    SECTION("Conversion without labels") {
        auto features = torch::tensor({ {1.0, 2.0},
                                        {3.0, 4.0} });
        auto problem = converter.to_svm_problem(features);
        REQUIRE(problem != nullptr);
        REQUIRE(problem->l == 2);
        REQUIRE(converter.get_n_samples() == 2);
        REQUIRE(converter.get_n_features() == 2);
    }

    SECTION("Sparse features handling") {
        // Entries smaller than the sparse threshold should be dropped from
        // the sparse representation.
        auto features = torch::tensor({ {1.0, 1e-10, 2.0},
                                        {0.0, 3.0, 1e-9} });
        auto labels = torch::tensor({ 0, 1 });
        converter.set_sparse_threshold(1e-8);

        auto problem = converter.to_svm_problem(features, labels);
        REQUIRE(problem != nullptr);
        REQUIRE(problem->l == 2);
        // Verifying exactly which nodes were dropped would require walking
        // the svm_node arrays directly (implementation specific).
    }
}
/**
 * @brief Conversion of torch tensors into liblinear problem structures.
 */
TEST_CASE("DataConverter Linear Problem Conversion", "[unit][data_converter]") {
    DataConverter converter;

    SECTION("Basic conversion") {
        auto features = torch::tensor({ {1.0, 2.0},
                                        {3.0, 4.0},
                                        {5.0, 6.0} });
        auto labels = torch::tensor({ -1, 1, -1 });

        auto problem = converter.to_linear_problem(features, labels);
        REQUIRE(problem != nullptr);
        REQUIRE(problem->l == 3);     // sample count
        REQUIRE(problem->n == 2);     // feature count
        REQUIRE(problem->bias == -1); // bias term disabled

        // Labels must survive the conversion unchanged.
        REQUIRE(problem->y[0] == Catch::Approx(-1.0));
        REQUIRE(problem->y[1] == Catch::Approx(1.0));
        REQUIRE(problem->y[2] == Catch::Approx(-1.0));
    }

    SECTION("Different tensor dtypes") {
        // Both integer and double tensors must be accepted.
        auto int_features = torch::tensor({ {1, 2}, {3, 4} }, torch::kInt32);
        auto int_labels = torch::tensor({ 0, 1 }, torch::kInt32);
        REQUIRE_NOTHROW(converter.to_linear_problem(int_features, int_labels));

        auto dbl_features = torch::tensor({ {1.0, 2.0}, {3.0, 4.0} }, torch::kFloat64);
        auto dbl_labels = torch::tensor({ 0.0, 1.0 }, torch::kFloat64);
        REQUIRE_NOTHROW(converter.to_linear_problem(dbl_features, dbl_labels));
    }
}
/**
 * @brief Conversion of individual 1-D samples to node arrays.
 */
TEST_CASE("DataConverter Single Sample Conversion", "[unit][data_converter]") {
    DataConverter converter;

    SECTION("SVM node conversion") {
        auto sample = torch::tensor({ 1.0, 0.0, 3.0, 0.0, 5.0 });
        auto nodes = converter.to_svm_node(sample);
        REQUIRE(nodes != nullptr);
        // Exact layout (non-zero entries plus a terminator node) depends on
        // the converter's sparse handling, so only non-null is asserted.
    }

    SECTION("Feature node conversion") {
        auto sample = torch::tensor({ 2.0, 4.0, 6.0 });
        auto nodes = converter.to_feature_node(sample);
        REQUIRE(nodes != nullptr);
    }

    SECTION("Invalid single sample") {
        // Single-sample conversion expects a 1-D tensor; a matrix must throw.
        auto matrix_sample = torch::tensor({ {1.0, 2.0} });
        REQUIRE_THROWS_AS(converter.to_svm_node(matrix_sample), std::invalid_argument);
        REQUIRE_THROWS_AS(converter.to_feature_node(matrix_sample), std::invalid_argument);
    }
}
/**
 * @brief Conversion of raw library outputs back into torch tensors.
 */
TEST_CASE("DataConverter Result Conversion", "[unit][data_converter]") {
    DataConverter converter;

    SECTION("Predictions conversion") {
        std::vector<double> raw = { 0.0, 1.0, 2.0, 1.0, 0.0 };
        auto tensor = converter.from_predictions(raw);
        REQUIRE(tensor.dtype() == torch::kInt32);
        REQUIRE(tensor.size(0) == 5);
        // Each prediction is converted to an int label.
        for (int idx = 0; idx < 5; ++idx) {
            REQUIRE(tensor[idx].item<int>() == static_cast<int>(raw[idx]));
        }
    }

    SECTION("Probabilities conversion") {
        std::vector<std::vector<double>> raw = {
            {0.7, 0.2, 0.1},
            {0.1, 0.8, 0.1},
            {0.3, 0.3, 0.4}
        };
        auto tensor = converter.from_probabilities(raw);
        REQUIRE(tensor.dtype() == torch::kFloat64);
        REQUIRE(tensor.size(0) == 3); // samples
        REQUIRE(tensor.size(1) == 3); // classes
        // Every entry must round-trip exactly.
        for (int row = 0; row < 3; ++row) {
            for (int col = 0; col < 3; ++col) {
                REQUIRE(tensor[row][col].item<double>() == Catch::Approx(raw[row][col]));
            }
        }
    }

    SECTION("Decision values conversion") {
        std::vector<std::vector<double>> raw = {
            {1.5, -0.5},
            {-1.0, 2.0}
        };
        auto tensor = converter.from_decision_values(raw);
        REQUIRE(tensor.dtype() == torch::kFloat64);
        REQUIRE(tensor.size(0) == 2); // samples
        REQUIRE(tensor.size(1) == 2); // decision values per sample
        for (int row = 0; row < 2; ++row) {
            for (int col = 0; col < 2; ++col) {
                REQUIRE(tensor[row][col].item<double>() == Catch::Approx(raw[row][col]));
            }
        }
    }

    SECTION("Empty results") {
        // Empty inputs should yield empty tensors rather than throwing.
        std::vector<double> no_predictions;
        auto pred_tensor = converter.from_predictions(no_predictions);
        REQUIRE(pred_tensor.size(0) == 0);

        std::vector<std::vector<double>> no_probabilities;
        auto prob_tensor = converter.from_probabilities(no_probabilities);
        REQUIRE(prob_tensor.size(0) == 0);
        REQUIRE(prob_tensor.size(1) == 0);
    }
}
/**
 * @brief Cleanup and reuse behavior of DataConverter.
 */
TEST_CASE("DataConverter Memory Management", "[unit][data_converter]") {
    DataConverter converter;

    SECTION("Cleanup functionality") {
        auto features = torch::randn({ 100, 50 });
        auto labels = torch::randint(0, 5, { 100 });

        // Build both problem representations from the same data.
        auto svm_problem = converter.to_svm_problem(features, labels);
        auto linear_problem = converter.to_linear_problem(features, labels);
        REQUIRE(converter.get_n_samples() == 100);
        REQUIRE(converter.get_n_features() == 50);

        // After cleanup the converter should report no cached data.
        converter.cleanup();
        REQUIRE(converter.get_n_samples() == 0);
        REQUIRE(converter.get_n_features() == 0);
    }

    SECTION("Multiple conversions") {
        // The converter must be reusable across repeated conversions.
        for (int round = 0; round < 5; ++round) {
            auto features = torch::randn({ 10, 3 });
            auto labels = torch::randint(0, 2, { 10 });
            REQUIRE_NOTHROW(converter.to_svm_problem(features, labels));
            REQUIRE_NOTHROW(converter.to_linear_problem(features, labels));
        }
    }
}
/**
 * @brief Configuration and effect of the sparse-value threshold.
 */
TEST_CASE("DataConverter Sparse Threshold", "[unit][data_converter]") {
    DataConverter converter;

    SECTION("Sparse threshold configuration") {
        // Default threshold, then two explicit overrides.
        REQUIRE(converter.get_sparse_threshold() == Catch::Approx(1e-8));
        converter.set_sparse_threshold(1e-6);
        REQUIRE(converter.get_sparse_threshold() == Catch::Approx(1e-6));
        converter.set_sparse_threshold(0.0);
        REQUIRE(converter.get_sparse_threshold() == Catch::Approx(0.0));
    }

    SECTION("Sparse threshold effect") {
        auto features = torch::tensor({ {1.0, 1e-7, 1e-5},
                                        {1e-9, 2.0, 1e-4} });
        auto labels = torch::tensor({ 0, 1 });

        // With threshold 1e-8 only the 1e-9 entry is treated as zero.
        converter.set_sparse_threshold(1e-8);
        auto tight = converter.to_svm_problem(features, labels);

        // With threshold 1e-6 the 1e-7 entry is dropped as well.
        converter.set_sparse_threshold(1e-6);
        auto loose = converter.to_svm_problem(features, labels);

        // Both conversions must succeed; their sparse layouts may differ.
        REQUIRE(tight != nullptr);
        REQUIRE(loose != nullptr);
    }
}
/**
 * @brief CPU/CUDA tensor handling in DataConverter.
 */
TEST_CASE("DataConverter Device Handling", "[unit][data_converter]") {
    DataConverter converter;

    SECTION("CPU tensors") {
        auto features = torch::randn({ 5, 3 }, torch::device(torch::kCPU));
        auto labels = torch::randint(0, 2, { 5 }, torch::device(torch::kCPU));
        REQUIRE_NOTHROW(converter.to_svm_problem(features, labels));
    }

    SECTION("GPU tensors (if available)") {
        if (torch::cuda::is_available()) {
            auto features = torch::randn({ 5, 3 }, torch::device(torch::kCUDA));
            auto labels = torch::randint(0, 2, { 5 }, torch::device(torch::kCUDA));
            // The converter is expected to move CUDA data to the CPU itself.
            REQUIRE_NOTHROW(converter.to_svm_problem(features, labels));
        }
    }

    SECTION("Mixed device tensors") {
        auto features = torch::randn({ 5, 3 }, torch::device(torch::kCPU));
        if (torch::cuda::is_available()) {
            auto labels = torch::randint(0, 2, { 5 }, torch::device(torch::kCUDA));
            // CPU features + CUDA labels: both should be normalized to CPU.
            REQUIRE_NOTHROW(converter.to_svm_problem(features, labels));
        }
    }
}

View File

@@ -0,0 +1,406 @@
/**
* @file test_kernel_parameters.cpp
* @brief Unit tests for KernelParameters class
*/
#include <catch2/catch_all.hpp>
#include <svm_classifier/kernel_parameters.hpp>
#include <nlohmann/json.hpp>
using namespace svm_classifier;
using json = nlohmann::json;
/**
 * @brief Defaults established by the KernelParameters default constructor.
 */
TEST_CASE("KernelParameters Default Constructor", "[unit][kernel_parameters]") {
    KernelParameters params;

    SECTION("Default values are set correctly") {
        // General defaults: linear kernel, C = 1, OvR multiclass.
        REQUIRE(params.get_kernel_type() == KernelType::LINEAR);
        REQUIRE(params.get_C() == Catch::Approx(1.0));
        REQUIRE(params.get_tolerance() == Catch::Approx(1e-3));
        REQUIRE(params.get_probability() == false);
        REQUIRE(params.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_REST);
    }

    SECTION("Kernel-specific parameters have defaults") {
        REQUIRE(params.get_gamma() == Catch::Approx(-1.0)); // -1 requests auto gamma
        REQUIRE(params.get_degree() == 3);
        REQUIRE(params.get_coef0() == Catch::Approx(0.0));
        REQUIRE(params.get_cache_size() == Catch::Approx(200.0));
    }
}
/**
 * @brief Construction of KernelParameters from JSON configurations.
 */
TEST_CASE("KernelParameters JSON Constructor", "[unit][kernel_parameters]") {
    SECTION("Linear kernel configuration") {
        json cfg = {
            {"kernel", "linear"},
            {"C", 10.0},
            {"tolerance", 1e-4},
            {"probability", true}
        };
        KernelParameters parsed(cfg);
        REQUIRE(parsed.get_kernel_type() == KernelType::LINEAR);
        REQUIRE(parsed.get_C() == Catch::Approx(10.0));
        REQUIRE(parsed.get_tolerance() == Catch::Approx(1e-4));
        REQUIRE(parsed.get_probability() == true);
    }

    SECTION("RBF kernel configuration") {
        json cfg = {
            {"kernel", "rbf"},
            {"C", 1.0},
            {"gamma", 0.1},
            {"multiclass_strategy", "ovo"}
        };
        KernelParameters parsed(cfg);
        REQUIRE(parsed.get_kernel_type() == KernelType::RBF);
        REQUIRE(parsed.get_C() == Catch::Approx(1.0));
        REQUIRE(parsed.get_gamma() == Catch::Approx(0.1));
        // "ovo" maps to the One-vs-One strategy.
        REQUIRE(parsed.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_ONE);
    }

    SECTION("Polynomial kernel configuration") {
        json cfg = {
            {"kernel", "polynomial"},
            {"C", 5.0},
            {"degree", 4},
            {"gamma", 0.5},
            {"coef0", 1.0}
        };
        KernelParameters parsed(cfg);
        REQUIRE(parsed.get_kernel_type() == KernelType::POLYNOMIAL);
        REQUIRE(parsed.get_degree() == 4);
        REQUIRE(parsed.get_gamma() == Catch::Approx(0.5));
        REQUIRE(parsed.get_coef0() == Catch::Approx(1.0));
    }

    SECTION("Sigmoid kernel configuration") {
        json cfg = {
            {"kernel", "sigmoid"},
            {"gamma", 0.01},
            {"coef0", -1.0}
        };
        KernelParameters parsed(cfg);
        REQUIRE(parsed.get_kernel_type() == KernelType::SIGMOID);
        REQUIRE(parsed.get_gamma() == Catch::Approx(0.01));
        REQUIRE(parsed.get_coef0() == Catch::Approx(-1.0));
    }
}
/**
 * @brief Setter/getter round trips and setter-side validation.
 */
TEST_CASE("KernelParameters Setters and Getters", "[unit][kernel_parameters]") {
    KernelParameters params;

    SECTION("Set and get C parameter") {
        params.set_C(5.0);
        REQUIRE(params.get_C() == Catch::Approx(5.0));
        // C must be strictly positive.
        REQUIRE_THROWS_AS(params.set_C(-1.0), std::invalid_argument);
        REQUIRE_THROWS_AS(params.set_C(0.0), std::invalid_argument);
    }

    SECTION("Set and get gamma parameter") {
        params.set_gamma(0.25);
        REQUIRE(params.get_gamma() == Catch::Approx(0.25));
        // A negative gamma is legal: it requests automatic gamma selection.
        params.set_gamma(-1.0);
        REQUIRE(params.get_gamma() == Catch::Approx(-1.0));
    }

    SECTION("Set and get degree parameter") {
        params.set_degree(5);
        REQUIRE(params.get_degree() == 5);
        // Degree must be at least 1.
        REQUIRE_THROWS_AS(params.set_degree(0), std::invalid_argument);
        REQUIRE_THROWS_AS(params.set_degree(-1), std::invalid_argument);
    }

    SECTION("Set and get tolerance") {
        params.set_tolerance(1e-6);
        REQUIRE(params.get_tolerance() == Catch::Approx(1e-6));
        // Tolerance must be strictly positive.
        REQUIRE_THROWS_AS(params.set_tolerance(-1e-3), std::invalid_argument);
        REQUIRE_THROWS_AS(params.set_tolerance(0.0), std::invalid_argument);
    }

    SECTION("Set and get cache size") {
        params.set_cache_size(500.0);
        REQUIRE(params.get_cache_size() == Catch::Approx(500.0));
        // Negative cache sizes are rejected.
        REQUIRE_THROWS_AS(params.set_cache_size(-100.0), std::invalid_argument);
    }
}
/**
 * @brief validate() acceptance of complete per-kernel configurations.
 */
TEST_CASE("KernelParameters Validation", "[unit][kernel_parameters]") {
    SECTION("Valid linear kernel parameters") {
        KernelParameters linear_params;
        linear_params.set_kernel_type(KernelType::LINEAR);
        linear_params.set_C(1.0);
        linear_params.set_tolerance(1e-3);
        REQUIRE_NOTHROW(linear_params.validate());
    }

    SECTION("Valid RBF kernel parameters") {
        KernelParameters rbf_params;
        rbf_params.set_kernel_type(KernelType::RBF);
        rbf_params.set_C(1.0);
        rbf_params.set_gamma(0.1);
        REQUIRE_NOTHROW(rbf_params.validate());
    }

    SECTION("Valid polynomial kernel parameters") {
        KernelParameters poly_params;
        poly_params.set_kernel_type(KernelType::POLYNOMIAL);
        poly_params.set_C(1.0);
        poly_params.set_degree(3);
        poly_params.set_gamma(0.1);
        poly_params.set_coef0(0.0);
        REQUIRE_NOTHROW(poly_params.validate());
    }

    SECTION("Invalid parameters throw exceptions") {
        // NOTE(review): the setter tests expect set_C(-1.0)/set_tolerance(-1e-3)
        // to throw in the setter itself; if so, these bare calls (not validate())
        // would raise — confirm which layer is meant to do the checking.
        KernelParameters params;
        params.set_kernel_type(KernelType::LINEAR);

        // Negative C should fail validation.
        params.set_C(-1.0);
        REQUIRE_THROWS_AS(params.validate(), std::invalid_argument);

        // Restore C, then break the tolerance instead.
        params.set_C(1.0);
        params.set_tolerance(-1e-3);
        REQUIRE_THROWS_AS(params.validate(), std::invalid_argument);
    }
}
/**
 * @brief Export to JSON and round-trip fidelity of all parameters.
 */
TEST_CASE("KernelParameters JSON Serialization", "[unit][kernel_parameters]") {
    SECTION("Get parameters as JSON") {
        KernelParameters params;
        params.set_kernel_type(KernelType::RBF);
        params.set_C(2.0);
        params.set_gamma(0.5);
        params.set_probability(true);

        auto exported = params.get_parameters();
        REQUIRE(exported["kernel"] == "rbf");
        REQUIRE(exported["C"] == Catch::Approx(2.0));
        REQUIRE(exported["gamma"] == Catch::Approx(0.5));
        REQUIRE(exported["probability"] == true);
    }

    SECTION("Round-trip JSON serialization") {
        json original_config = {
            {"kernel", "polynomial"},
            {"C", 3.0},
            {"degree", 4},
            {"gamma", 0.25},
            {"coef0", 1.5},
            {"multiclass_strategy", "ovo"},
            {"probability", true},
            {"tolerance", 1e-5}
        };

        // config -> params -> JSON -> params must preserve every field.
        KernelParameters first(original_config);
        auto serialized_config = first.get_parameters();
        KernelParameters second(serialized_config);

        REQUIRE(second.get_kernel_type() == first.get_kernel_type());
        REQUIRE(second.get_C() == Catch::Approx(first.get_C()));
        REQUIRE(second.get_degree() == first.get_degree());
        REQUIRE(second.get_gamma() == Catch::Approx(first.get_gamma()));
        REQUIRE(second.get_coef0() == Catch::Approx(first.get_coef0()));
        REQUIRE(second.get_multiclass_strategy() == first.get_multiclass_strategy());
        REQUIRE(second.get_probability() == first.get_probability());
        REQUIRE(second.get_tolerance() == Catch::Approx(first.get_tolerance()));
    }
}
/**
 * @brief Static per-kernel default parameter sets and reset behavior.
 */
TEST_CASE("KernelParameters Default Parameters", "[unit][kernel_parameters]") {
    SECTION("Linear kernel defaults") {
        auto defaults = KernelParameters::get_default_parameters(KernelType::LINEAR);
        REQUIRE(defaults["kernel"] == "linear");
        REQUIRE(defaults["C"] == 1.0);
        REQUIRE(defaults["tolerance"] == 1e-3);
        REQUIRE(defaults["probability"] == false);
    }

    SECTION("RBF kernel defaults") {
        auto defaults = KernelParameters::get_default_parameters(KernelType::RBF);
        REQUIRE(defaults["kernel"] == "rbf");
        REQUIRE(defaults["gamma"] == -1.0); // -1 requests automatic gamma
        REQUIRE(defaults["cache_size"] == 200.0);
    }

    SECTION("Polynomial kernel defaults") {
        auto defaults = KernelParameters::get_default_parameters(KernelType::POLYNOMIAL);
        REQUIRE(defaults["kernel"] == "polynomial");
        REQUIRE(defaults["degree"] == 3);
        REQUIRE(defaults["coef0"] == 0.0);
    }

    SECTION("Reset to defaults") {
        KernelParameters params;
        // Move away from the defaults first.
        params.set_kernel_type(KernelType::RBF);
        params.set_C(10.0);
        params.set_gamma(0.1);

        // reset_to_defaults() keeps the kernel type but restores its defaults.
        params.reset_to_defaults();
        REQUIRE(params.get_kernel_type() == KernelType::RBF);
        REQUIRE(params.get_C() == Catch::Approx(1.0));
        REQUIRE(params.get_gamma() == Catch::Approx(-1.0));
    }
}
/**
 * @brief Free-function conversions between enums, strings and libraries.
 */
TEST_CASE("KernelParameters Type Conversions", "[unit][kernel_parameters]") {
    SECTION("Kernel type to string conversion") {
        REQUIRE(kernel_type_to_string(KernelType::LINEAR) == "linear");
        REQUIRE(kernel_type_to_string(KernelType::RBF) == "rbf");
        REQUIRE(kernel_type_to_string(KernelType::POLYNOMIAL) == "polynomial");
        REQUIRE(kernel_type_to_string(KernelType::SIGMOID) == "sigmoid");
    }

    SECTION("String to kernel type conversion") {
        REQUIRE(string_to_kernel_type("linear") == KernelType::LINEAR);
        REQUIRE(string_to_kernel_type("rbf") == KernelType::RBF);
        REQUIRE(string_to_kernel_type("polynomial") == KernelType::POLYNOMIAL);
        REQUIRE(string_to_kernel_type("poly") == KernelType::POLYNOMIAL); // accepted alias
        REQUIRE(string_to_kernel_type("sigmoid") == KernelType::SIGMOID);
        // Unknown names are rejected.
        REQUIRE_THROWS_AS(string_to_kernel_type("invalid"), std::invalid_argument);
    }

    SECTION("Multiclass strategy conversions") {
        REQUIRE(multiclass_strategy_to_string(MulticlassStrategy::ONE_VS_REST) == "ovr");
        REQUIRE(multiclass_strategy_to_string(MulticlassStrategy::ONE_VS_ONE) == "ovo");
        // Both the short and long spellings must parse.
        REQUIRE(string_to_multiclass_strategy("ovr") == MulticlassStrategy::ONE_VS_REST);
        REQUIRE(string_to_multiclass_strategy("one_vs_rest") == MulticlassStrategy::ONE_VS_REST);
        REQUIRE(string_to_multiclass_strategy("ovo") == MulticlassStrategy::ONE_VS_ONE);
        REQUIRE(string_to_multiclass_strategy("one_vs_one") == MulticlassStrategy::ONE_VS_ONE);
        REQUIRE_THROWS_AS(string_to_multiclass_strategy("invalid"), std::invalid_argument);
    }

    SECTION("SVM library selection") {
        // Only the linear kernel is routed to liblinear; every other kernel
        // uses libsvm.
        REQUIRE(get_svm_library(KernelType::LINEAR) == SVMLibrary::LIBLINEAR);
        REQUIRE(get_svm_library(KernelType::RBF) == SVMLibrary::LIBSVM);
        REQUIRE(get_svm_library(KernelType::POLYNOMIAL) == SVMLibrary::LIBSVM);
        REQUIRE(get_svm_library(KernelType::SIGMOID) == SVMLibrary::LIBSVM);
    }
}
/**
 * @brief Edge cases: empty/invalid/partial configs and extreme values.
 */
TEST_CASE("KernelParameters Edge Cases", "[unit][kernel_parameters]") {
    SECTION("Empty JSON configuration") {
        json empty_config = json::object();
        // An empty object is valid and yields all defaults.
        REQUIRE_NOTHROW(KernelParameters(empty_config));
        KernelParameters params(empty_config);
        REQUIRE(params.get_kernel_type() == KernelType::LINEAR);
        REQUIRE(params.get_C() == Catch::Approx(1.0));
    }

    SECTION("Invalid JSON values") {
        json bad_config = {
            {"kernel", "invalid_kernel"},
            {"C", -1.0}
        };
        REQUIRE_THROWS_AS(KernelParameters(bad_config), std::invalid_argument);
    }

    SECTION("Partial JSON configuration") {
        // gamma omitted on purpose: the default (auto) should be used.
        json partial_config = {
            {"kernel", "rbf"},
            {"C", 5.0}
        };
        KernelParameters params(partial_config);
        REQUIRE(params.get_kernel_type() == KernelType::RBF);
        REQUIRE(params.get_C() == Catch::Approx(5.0));
        REQUIRE(params.get_gamma() == Catch::Approx(-1.0));
    }

    SECTION("Maximum and minimum valid values") {
        KernelParameters params;

        // A tiny but positive C is accepted...
        params.set_C(1e-10);
        REQUIRE(params.get_C() == Catch::Approx(1e-10));

        // ...as is a huge C...
        params.set_C(1e10);
        REQUIRE(params.get_C() == Catch::Approx(1e10));

        // ...and an extremely small tolerance.
        params.set_tolerance(1e-15);
        REQUIRE(params.get_tolerance() == Catch::Approx(1e-15));
    }
}

44
tests/test_main.cpp Normal file
View File

@@ -0,0 +1,44 @@
/**
* @file test_main.cpp
* @brief Main entry point for Catch2 test suite
*
* This file contains global test configuration and setup for the SVM classifier
* test suite. Catch2 will automatically generate the main() function.
*/
// NOTE: Catch2 v3 ignores CATCH_CONFIG_MAIN; main() is provided by linking
// Catch2::Catch2WithMain (see tests/CMakeLists.txt), so the legacy define
// has been removed.
#include <catch2/catch_all.hpp>
#include <torch/torch.h>
#include <iostream>
/**
 * @brief Global test setup/teardown.
 *
 * A single static instance is constructed before main() runs and destroyed
 * after it returns, so this configuration applies to the whole suite.
 */
struct GlobalTestSetup {
    GlobalTestSetup()
    {
        // Single-threaded PyTorch for reproducible test results.
        torch::set_num_threads(1);
        // Fixed seed for reproducibility.
        torch::manual_seed(42);
        // Select the FBGEMM quantization engine.  setQEngine() throws when
        // the engine is not compiled into this libtorch build; an uncaught
        // exception during static initialization would terminate the whole
        // process, so it is caught and reported instead.
        try {
            torch::globalContext().setQEngine(at::QEngine::FBGEMM);
        }
        catch (const std::exception& e) {
            std::cout << "Note: FBGEMM quantization engine unavailable: " << e.what() << std::endl;
        }
        std::cout << "SVM Classifier Test Suite" << std::endl;
        std::cout << "=========================" << std::endl;
        std::cout << "PyTorch version: " << TORCH_VERSION << std::endl;
        std::cout << "Using " << torch::get_num_threads() << " thread(s)" << std::endl;
        std::cout << std::endl;
    }
    ~GlobalTestSetup()
    {
        std::cout << std::endl;
        std::cout << "Test suite completed." << std::endl;
    }
};
// Global setup instance (constructed before, destroyed after, main())
static GlobalTestSetup global_setup;

View File

@@ -0,0 +1,516 @@
/**
* @file test_multiclass_strategy.cpp
* @brief Unit tests for multiclass strategy classes
*/
#include <catch2/catch_all.hpp>
#include <svm_classifier/multiclass_strategy.hpp>
#include <svm_classifier/kernel_parameters.hpp>
#include <svm_classifier/data_converter.hpp>
#include <torch/torch.h>
using namespace svm_classifier;
/**
 * @brief Generate simple test data for multiclass testing.
 *
 * Produces a fixed-seed random feature matrix and integer labels, then
 * shifts each sample by 0.5 * its class label so the classes gain enough
 * structure to be learnable.
 */
std::pair<torch::Tensor, torch::Tensor> generate_multiclass_data(int n_samples = 60,
    int n_features = 2,
    int n_classes = 3)
{
    torch::manual_seed(42);
    auto features = torch::randn({ n_samples, n_features });
    auto labels = torch::randint(0, n_classes, { n_samples });
    // Shift every sample towards its class to make classification easier.
    for (int row = 0; row < n_samples; ++row) {
        int label = labels[row].item<int>();
        features[row] += label * 0.5;
    }
    return { features, labels };
}
/**
 * @brief The factory must produce the requested, untrained strategy type.
 */
TEST_CASE("MulticlassStrategy Factory Function", "[unit][multiclass_strategy]")
{
    SECTION("Create One-vs-Rest strategy")
    {
        auto strategy = create_multiclass_strategy(MulticlassStrategy::ONE_VS_REST);
        REQUIRE(strategy != nullptr);
        REQUIRE(strategy->get_strategy_type() == MulticlassStrategy::ONE_VS_REST);
        // Not trained yet: no classes reported.  (Replaces the original
        // double negative `REQUIRE_FALSE(... .empty() == false)`, which
        // asserted the same condition in a confusing way.)
        REQUIRE(strategy->get_classes().empty());
        REQUIRE(strategy->get_n_classes() == 0);
    }
    SECTION("Create One-vs-One strategy")
    {
        auto strategy = create_multiclass_strategy(MulticlassStrategy::ONE_VS_ONE);
        REQUIRE(strategy != nullptr);
        REQUIRE(strategy->get_strategy_type() == MulticlassStrategy::ONE_VS_ONE);
        REQUIRE(strategy->get_n_classes() == 0);
    }
}
/**
 * @brief Initial state and training of the One-vs-Rest strategy.
 */
TEST_CASE("OneVsRestStrategy Basic Functionality", "[unit][multiclass_strategy]") {
    OneVsRestStrategy strategy;
    DataConverter converter;
    KernelParameters params;

    SECTION("Initial state") {
        // Before fit() the strategy reports no classes and no probabilities.
        REQUIRE(strategy.get_strategy_type() == MulticlassStrategy::ONE_VS_REST);
        REQUIRE(strategy.get_n_classes() == 0);
        REQUIRE(strategy.get_classes().empty());
        REQUIRE_FALSE(strategy.supports_probability());
    }

    SECTION("Training with linear kernel") {
        auto [features, labels] = generate_multiclass_data(60, 3, 3);
        params.set_kernel_type(KernelType::LINEAR);
        params.set_C(1.0);

        auto metrics = strategy.fit(features, labels, params, converter);
        REQUIRE(metrics.status == TrainingStatus::SUCCESS);
        REQUIRE(metrics.training_time >= 0.0);
        REQUIRE(strategy.get_n_classes() == 3);

        // Class labels should come back sorted in ascending order.
        auto classes = strategy.get_classes();
        REQUIRE(classes.size() == 3);
        REQUIRE(std::is_sorted(classes.begin(), classes.end()));
    }

    SECTION("Training with RBF kernel") {
        auto [features, labels] = generate_multiclass_data(50, 2, 2);
        params.set_kernel_type(KernelType::RBF);
        params.set_C(1.0);
        params.set_gamma(0.1);

        auto metrics = strategy.fit(features, labels, params, converter);
        REQUIRE(metrics.status == TrainingStatus::SUCCESS);
        REQUIRE(strategy.get_n_classes() == 2);
    }
}
/**
 * @brief Prediction and decision values of a trained OvR strategy.
 */
TEST_CASE("OneVsRestStrategy Prediction", "[unit][multiclass_strategy]") {
    OneVsRestStrategy strategy;
    DataConverter converter;
    KernelParameters params;

    auto [X, y] = generate_multiclass_data(80, 3, 3);
    // First 60 samples train the model; the remaining 20 are held out.
    auto X_train = X.slice(0, 0, 60);
    auto y_train = y.slice(0, 0, 60);
    auto X_test = X.slice(0, 60);
    auto y_test = y.slice(0, 60);

    params.set_kernel_type(KernelType::LINEAR);
    strategy.fit(X_train, y_train, params, converter);

    SECTION("Basic prediction") {
        auto predictions = strategy.predict(X_test, converter);
        REQUIRE(predictions.size() == X_test.size(0));

        // Every predicted label must be one of the trained classes.
        auto classes = strategy.get_classes();
        for (int label : predictions) {
            bool known = std::find(classes.begin(), classes.end(), label) != classes.end();
            REQUIRE(known);
        }
    }

    SECTION("Decision function") {
        auto scores = strategy.decision_function(X_test, converter);
        REQUIRE(scores.size() == X_test.size(0));
        REQUIRE(scores[0].size() == strategy.get_n_classes());

        // All decision values must be finite real numbers.
        for (const auto& row : scores) {
            for (double value : row) {
                REQUIRE(std::isfinite(value));
            }
        }
    }

    SECTION("Prediction without training") {
        // A freshly constructed strategy cannot predict.
        OneVsRestStrategy untrained;
        REQUIRE_THROWS_AS(untrained.predict(X_test, converter), std::runtime_error);
    }
}
/**
 * @brief Probability estimates of the OvR strategy, when supported.
 */
TEST_CASE("OneVsRestStrategy Probability Prediction", "[unit][multiclass_strategy]") {
    OneVsRestStrategy strategy;
    DataConverter converter;
    KernelParameters params;

    auto [X, y] = generate_multiclass_data(60, 2, 3);

    SECTION("With probability enabled") {
        params.set_kernel_type(KernelType::RBF);
        params.set_probability(true);
        strategy.fit(X, y, params, converter);

        if (strategy.supports_probability()) {
            auto probabilities = strategy.predict_proba(X, converter);
            REQUIRE(probabilities.size() == X.size(0));
            REQUIRE(probabilities[0].size() == 3); // one column per class

            // Each row must be a valid probability distribution.
            for (const auto& row : probabilities) {
                double total = 0.0;
                for (double p : row) {
                    REQUIRE(p >= 0.0);
                    REQUIRE(p <= 1.0);
                    total += p;
                }
                REQUIRE(total == Catch::Approx(1.0).margin(1e-6));
            }
        }
    }

    SECTION("Without probability enabled") {
        params.set_kernel_type(KernelType::LINEAR);
        params.set_probability(false);
        strategy.fit(X, y, params, converter);

        // If probability estimates were not trained, predict_proba must throw.
        if (!strategy.supports_probability()) {
            REQUIRE_THROWS_AS(strategy.predict_proba(X, converter), std::runtime_error);
        }
    }
}
/**
 * @brief Initial state and training of the One-vs-One strategy.
 */
TEST_CASE("OneVsOneStrategy Basic Functionality", "[unit][multiclass_strategy]") {
    OneVsOneStrategy strategy;
    DataConverter converter;
    KernelParameters params;

    SECTION("Initial state") {
        REQUIRE(strategy.get_strategy_type() == MulticlassStrategy::ONE_VS_ONE);
        REQUIRE(strategy.get_n_classes() == 0);
        REQUIRE(strategy.get_classes().empty());
    }

    SECTION("Training with multiple classes") {
        // Four classes, so OvO conceptually builds C(4,2) = 6 binary models
        // (an implementation detail not directly observable here).
        auto [features, labels] = generate_multiclass_data(80, 3, 4);
        params.set_kernel_type(KernelType::LINEAR);
        params.set_C(1.0);

        auto metrics = strategy.fit(features, labels, params, converter);
        REQUIRE(metrics.status == TrainingStatus::SUCCESS);
        REQUIRE(strategy.get_n_classes() == 4);

        auto classes = strategy.get_classes();
        REQUIRE(classes.size() == 4);
    }

    SECTION("Binary classification") {
        auto [features, labels] = generate_multiclass_data(50, 2, 2);
        params.set_kernel_type(KernelType::RBF);
        params.set_gamma(0.1);

        auto metrics = strategy.fit(features, labels, params, converter);
        REQUIRE(metrics.status == TrainingStatus::SUCCESS);
        REQUIRE(strategy.get_n_classes() == 2);
    }
}
// Prediction-path checks for a trained One-vs-One strategy: class labels,
// pairwise decision values, and voting-based probability estimates.
TEST_CASE("OneVsOneStrategy Prediction", "[unit][multiclass_strategy]")
{
    OneVsOneStrategy strategy;
    DataConverter converter;
    KernelParameters params;
    auto [X, y] = generate_multiclass_data(90, 2, 3);
    // 70/20 train/test split along dimension 0.
    auto X_train = X.slice(0, 0, 70);
    auto y_train = y.slice(0, 0, 70);
    auto X_test = X.slice(0, 70);
    params.set_kernel_type(KernelType::LINEAR);
    strategy.fit(X_train, y_train, params, converter);
    SECTION("Basic prediction")
    {
        auto predictions = strategy.predict(X_test, converter);
        REQUIRE(predictions.size() == X_test.size(0));
        // Every predicted label must be one of the classes seen at fit time.
        auto classes = strategy.get_classes();
        for (int pred : predictions) {
            REQUIRE(std::find(classes.begin(), classes.end(), pred) != classes.end());
        }
    }
    SECTION("Decision function")
    {
        auto decision_values = strategy.decision_function(X_test, converter);
        REQUIRE(decision_values.size() == X_test.size(0));
        // For 3 classes, OvO should have C(3,2) = 3 pairwise comparisons
        REQUIRE(decision_values[0].size() == 3);
        for (const auto& sample_decisions : decision_values) {
            for (double value : sample_decisions) {
                REQUIRE(std::isfinite(value));
            }
        }
    }
    SECTION("Probability prediction")
    {
        // OvO probability estimation is more complex
        auto probabilities = strategy.predict_proba(X_test, converter);
        REQUIRE(probabilities.size() == X_test.size(0));
        REQUIRE(probabilities[0].size() == 3); // 3 classes
        // Check basic probability constraints
        for (const auto& sample_probs : probabilities) {
            double sum = 0.0;
            for (double prob : sample_probs) {
                REQUIRE(prob >= 0.0);
                REQUIRE(prob <= 1.0);
                sum += prob;
            }
            // OvO probability might not sum exactly to 1 due to voting mechanism
            REQUIRE(sum == Catch::Approx(1.0).margin(0.1));
        }
    }
}
// Cross-checks OvR and OvO trained on the same split: both must expose the
// same class set, emit valid labels, and produce the expected decision-value
// widths (per-class for OvR, per-pair for OvO).
TEST_CASE("MulticlassStrategy Comparison", "[integration][multiclass_strategy]")
{
    auto [X, y] = generate_multiclass_data(100, 3, 3);
    auto X_train = X.slice(0, 0, 80);
    auto y_train = y.slice(0, 0, 80);
    auto X_test = X.slice(0, 80);
    auto y_test = y.slice(0, 80);
    // Separate converters so each strategy keeps its own conversion state.
    DataConverter converter1, converter2;
    KernelParameters params;
    params.set_kernel_type(KernelType::LINEAR);
    params.set_C(1.0);
    SECTION("Compare OvR vs OvO predictions")
    {
        OneVsRestStrategy ovr_strategy;
        OneVsOneStrategy ovo_strategy;
        ovr_strategy.fit(X_train, y_train, params, converter1);
        ovo_strategy.fit(X_train, y_train, params, converter2);
        auto ovr_predictions = ovr_strategy.predict(X_test, converter1);
        auto ovo_predictions = ovo_strategy.predict(X_test, converter2);
        REQUIRE(ovr_predictions.size() == ovo_predictions.size());
        // Both should predict valid class labels
        auto ovr_classes = ovr_strategy.get_classes();
        auto ovo_classes = ovo_strategy.get_classes();
        REQUIRE(ovr_classes == ovo_classes); // Should have same classes
        for (size_t i = 0; i < ovr_predictions.size(); ++i) {
            REQUIRE(std::find(ovr_classes.begin(), ovr_classes.end(), ovr_predictions[i]) != ovr_classes.end());
            REQUIRE(std::find(ovo_classes.begin(), ovo_classes.end(), ovo_predictions[i]) != ovo_classes.end());
        }
    }
    SECTION("Compare decision function outputs")
    {
        OneVsRestStrategy ovr_strategy;
        OneVsOneStrategy ovo_strategy;
        ovr_strategy.fit(X_train, y_train, params, converter1);
        ovo_strategy.fit(X_train, y_train, params, converter2);
        auto ovr_decisions = ovr_strategy.decision_function(X_test, converter1);
        auto ovo_decisions = ovo_strategy.decision_function(X_test, converter2);
        REQUIRE(ovr_decisions.size() == ovo_decisions.size());
        // OvR should have one decision value per class
        REQUIRE(ovr_decisions[0].size() == 3);
        // OvO should have one decision value per class pair: C(3,2) = 3
        REQUIRE(ovo_decisions[0].size() == 3);
    }
}
// Degenerate-input behavior: single-class data, a two-sample dataset, and a
// heavily imbalanced class distribution must all train without error.
TEST_CASE("MulticlassStrategy Edge Cases", "[unit][multiclass_strategy]")
{
    DataConverter converter;
    KernelParameters params;
    params.set_kernel_type(KernelType::LINEAR);
    SECTION("Single class dataset")
    {
        auto X = torch::randn({ 20, 2 });
        auto y = torch::zeros({ 20 }, torch::kInt32); // All same class
        OneVsRestStrategy strategy;
        // Should handle single class gracefully
        auto metrics = strategy.fit(X, y, params, converter);
        REQUIRE(metrics.status == TrainingStatus::SUCCESS);
        // Implementation might extend to binary case
        auto predictions = strategy.predict(X, converter);
        REQUIRE(predictions.size() == X.size(0));
    }
    SECTION("Very small dataset")
    {
        // Minimal viable dataset: one sample per class.
        auto X = torch::tensor({ {1.0, 2.0}, {3.0, 4.0} });
        auto y = torch::tensor({ 0, 1 });
        OneVsOneStrategy strategy;
        auto metrics = strategy.fit(X, y, params, converter);
        REQUIRE(metrics.status == TrainingStatus::SUCCESS);
        auto predictions = strategy.predict(X, converter);
        REQUIRE(predictions.size() == 2);
    }
    SECTION("Imbalanced classes")
    {
        // Create dataset with very imbalanced classes (80 vs 5 samples).
        auto X1 = torch::randn({ 80, 2 });
        auto y1 = torch::zeros({ 80 }, torch::kInt32);
        auto X2 = torch::randn({ 5, 2 });
        auto y2 = torch::ones({ 5 }, torch::kInt32);
        auto X = torch::cat({ X1, X2 }, 0);
        auto y = torch::cat({ y1, y2 }, 0);
        OneVsRestStrategy strategy;
        auto metrics = strategy.fit(X, y, params, converter);
        REQUIRE(metrics.status == TrainingStatus::SUCCESS);
        REQUIRE(strategy.get_n_classes() == 2);
        auto predictions = strategy.predict(X, converter);
        REQUIRE(predictions.size() == X.size(0));
    }
}
// Failure-path checks: invalid hyperparameters, mismatched X/y lengths, and
// calls on a strategy that was never fitted must all throw.
TEST_CASE("MulticlassStrategy Error Handling", "[unit][multiclass_strategy]")
{
    DataConverter data_converter;
    KernelParameters kernel_params;
    SECTION("Invalid parameters")
    {
        OneVsRestStrategy ovr;
        auto [X, y] = generate_multiclass_data(50, 2, 2);
        // A negative regularization constant is invalid and must be rejected.
        kernel_params.set_kernel_type(KernelType::LINEAR);
        kernel_params.set_C(-1.0); // Invalid
        REQUIRE_THROWS(ovr.fit(X, y, kernel_params, data_converter));
    }
    SECTION("Mismatched tensor dimensions")
    {
        OneVsOneStrategy ovo;
        auto features = torch::randn({ 50, 3 });
        auto labels = torch::randint(0, 2, { 40 }); // Wrong number of labels
        kernel_params.set_kernel_type(KernelType::LINEAR);
        kernel_params.set_C(1.0);
        REQUIRE_THROWS_AS(ovo.fit(features, labels, kernel_params, data_converter), std::invalid_argument);
    }
    SECTION("Prediction on untrained strategy")
    {
        OneVsRestStrategy untrained;
        auto features = torch::randn({ 10, 2 });
        REQUIRE_THROWS_AS(untrained.predict(features, data_converter), std::runtime_error);
        REQUIRE_THROWS_AS(untrained.decision_function(features, data_converter), std::runtime_error);
    }
}
// Lifetime checks: a factory-created strategy must clean up via its
// destructor, and repeated fit() calls must not leak or corrupt state.
TEST_CASE("MulticlassStrategy Memory Management", "[unit][multiclass_strategy]")
{
    SECTION("Strategy destruction")
    {
        // Test that strategies clean up properly
        auto strategy = create_multiclass_strategy(MulticlassStrategy::ONE_VS_REST);
        DataConverter converter;
        KernelParameters params;
        auto [X, y] = generate_multiclass_data(50, 2, 3);
        params.set_kernel_type(KernelType::LINEAR);
        strategy->fit(X, y, params, converter);
        REQUIRE(strategy->get_n_classes() == 3);
        // Strategy should clean up automatically when destroyed
    }
    SECTION("Multiple training rounds")
    {
        OneVsRestStrategy strategy;
        DataConverter converter;
        KernelParameters params;
        params.set_kernel_type(KernelType::LINEAR);
        // Train multiple times with different data; each round must fully
        // replace the previous model.
        for (int i = 0; i < 3; ++i) {
            auto [X, y] = generate_multiclass_data(40, 2, 2, i); // Different seed
            auto metrics = strategy.fit(X, y, params, converter);
            REQUIRE(metrics.status == TrainingStatus::SUCCESS);
            auto predictions = strategy.predict(X, converter);
            REQUIRE(predictions.size() == X.size(0));
        }
    }
}

483
tests/test_performance.cpp Normal file
View File

@@ -0,0 +1,483 @@
/**
* @file test_performance.cpp
* @brief Performance benchmarks for SVMClassifier
*/
#include <catch2/catch_all.hpp>
#include <svm_classifier/svm_classifier.hpp>
#include <torch/torch.h>
#include <chrono>
#include <iostream>
#include <iomanip>
using namespace svm_classifier;
/**
 * @brief Generate a large synthetic dataset for performance testing.
 *
 * @param n_samples  Number of rows in the returned feature matrix.
 * @param n_features Number of columns per sample.
 * @param n_classes  Number of distinct integer labels in [0, n_classes).
 * @param seed       RNG seed for reproducible datasets.
 * @return {X, y} where X is (n_samples, n_features) float and y is integer labels.
 */
std::pair<torch::Tensor, torch::Tensor> generate_large_dataset(int n_samples,
                                                               int n_features,
                                                               int n_classes = 2,
                                                               int seed = 42)
{
    torch::manual_seed(seed);
    auto X = torch::randn({ n_samples, n_features });
    auto y = torch::randint(0, n_classes, { n_samples });
    // Add class-dependent structure so the problem is non-trivial. This is
    // done as a single broadcasted op: the original per-sample loop allocated
    // a fresh (n_features,) tensor per row, which dominated setup time for the
    // 10K-20K sample benchmark datasets.
    auto noise = torch::randn({ n_samples, n_features }) * 0.3;
    X += y.to(torch::kFloat32).unsqueeze(1) * noise;
    return { X, y };
}
/**
 * @brief RAII benchmark timer: starts timing on construction and prints
 *        "<name>: <elapsed> ms" to stdout on destruction.
 *
 * Uses std::chrono::steady_clock rather than high_resolution_clock: the
 * latter is allowed to alias system_clock, which is not monotonic and can
 * jump backwards (NTP adjustments), producing bogus benchmark durations.
 */
class Benchmark {
public:
    explicit Benchmark(const std::string& name) : name_(name)
    {
        start_time_ = std::chrono::steady_clock::now();
    }

    ~Benchmark()
    {
        auto end_time = std::chrono::steady_clock::now();
        auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time_);
        // Fixed-width columns so consecutive benchmark lines align in a table.
        std::cout << std::setw(40) << std::left << name_
                  << ": " << std::setw(8) << std::right << duration.count() << " ms" << std::endl;
    }

private:
    std::string name_;                                  // label printed next to the timing
    std::chrono::steady_clock::time_point start_time_;  // captured at construction
};
// Wall-clock training benchmarks per kernel at increasing dataset sizes.
// Results are printed by the Benchmark RAII timers; there are no accuracy
// assertions here, only timing output.
TEST_CASE("Performance Benchmarks - Training Speed", "[performance][training]")
{
    std::cout << "\n=== Training Performance Benchmarks ===" << std::endl;
    std::cout << std::setw(40) << std::left << "Test Name" << " " << "Time" << std::endl;
    std::cout << std::string(50, '-') << std::endl;
    SECTION("Linear kernel performance")
    {
        auto [X_small, y_small] = generate_large_dataset(1000, 20, 2);
        auto [X_medium, y_medium] = generate_large_dataset(5000, 50, 3);
        auto [X_large, y_large] = generate_large_dataset(10000, 100, 2);
        // Each inner scope times construction + fit via Benchmark's destructor.
        {
            Benchmark bench("Linear SVM - 1K samples, 20 features");
            SVMClassifier svm(KernelType::LINEAR, 1.0);
            svm.fit(X_small, y_small);
        }
        {
            Benchmark bench("Linear SVM - 5K samples, 50 features");
            SVMClassifier svm(KernelType::LINEAR, 1.0);
            svm.fit(X_medium, y_medium);
        }
        {
            Benchmark bench("Linear SVM - 10K samples, 100 features");
            SVMClassifier svm(KernelType::LINEAR, 1.0);
            svm.fit(X_large, y_large);
        }
    }
    SECTION("RBF kernel performance")
    {
        // Smaller datasets than the linear section: RBF training cost grows
        // super-linearly with sample count.
        auto [X_small, y_small] = generate_large_dataset(500, 10, 2);
        auto [X_medium, y_medium] = generate_large_dataset(1000, 20, 2);
        auto [X_large, y_large] = generate_large_dataset(2000, 30, 2);
        {
            Benchmark bench("RBF SVM - 500 samples, 10 features");
            SVMClassifier svm(KernelType::RBF, 1.0);
            svm.fit(X_small, y_small);
        }
        {
            Benchmark bench("RBF SVM - 1K samples, 20 features");
            SVMClassifier svm(KernelType::RBF, 1.0);
            svm.fit(X_medium, y_medium);
        }
        {
            Benchmark bench("RBF SVM - 2K samples, 30 features");
            SVMClassifier svm(KernelType::RBF, 1.0);
            svm.fit(X_large, y_large);
        }
    }
    SECTION("Polynomial kernel performance")
    {
        auto [X_small, y_small] = generate_large_dataset(300, 8, 2);
        auto [X_medium, y_medium] = generate_large_dataset(800, 15, 2);
        {
            Benchmark bench("Poly SVM (deg=2) - 300 samples, 8 features");
            json config = { {"kernel", "polynomial"}, {"degree", 2}, {"C", 1.0} };
            SVMClassifier svm(config);
            svm.fit(X_small, y_small);
        }
        {
            Benchmark bench("Poly SVM (deg=3) - 800 samples, 15 features");
            json config = { {"kernel", "polynomial"}, {"degree", 3}, {"C", 1.0} };
            SVMClassifier svm(config);
            svm.fit(X_medium, y_medium);
        }
    }
}
// Wall-clock prediction benchmarks on a model trained once per section.
//
// FIX: the original declared `auto [X_test_small, _]`, `[X_test_medium, _]`
// and `[X_test_large, _]` in the same scope. `_` is an ordinary identifier
// before C++26, so the second and third structured bindings are redefinitions
// of `_` — a hard compile error. Each unused label binding now gets a unique
// name.
TEST_CASE("Performance Benchmarks - Prediction Speed", "[performance][prediction]")
{
    std::cout << "\n=== Prediction Performance Benchmarks ===" << std::endl;
    std::cout << std::setw(40) << std::left << "Test Name" << " " << "Time" << std::endl;
    std::cout << std::string(50, '-') << std::endl;
    SECTION("Linear kernel prediction")
    {
        auto [X_train, y_train] = generate_large_dataset(2000, 50, 3);
        auto [X_test_small, y_unused_small] = generate_large_dataset(100, 50, 3, 123);
        auto [X_test_medium, y_unused_medium] = generate_large_dataset(1000, 50, 3, 124);
        auto [X_test_large, y_unused_large] = generate_large_dataset(5000, 50, 3, 125);
        SVMClassifier svm(KernelType::LINEAR, 1.0);
        svm.fit(X_train, y_train);
        {
            Benchmark bench("Linear prediction - 100 samples");
            auto predictions = svm.predict(X_test_small);
        }
        {
            Benchmark bench("Linear prediction - 1K samples");
            auto predictions = svm.predict(X_test_medium);
        }
        {
            Benchmark bench("Linear prediction - 5K samples");
            auto predictions = svm.predict(X_test_large);
        }
    }
    SECTION("RBF kernel prediction")
    {
        auto [X_train, y_train] = generate_large_dataset(1000, 20, 2);
        auto [X_test_small, y_unused_small] = generate_large_dataset(50, 20, 2, 123);
        auto [X_test_medium, y_unused_medium] = generate_large_dataset(500, 20, 2, 124);
        auto [X_test_large, y_unused_large] = generate_large_dataset(2000, 20, 2, 125);
        SVMClassifier svm(KernelType::RBF, 1.0);
        svm.fit(X_train, y_train);
        {
            Benchmark bench("RBF prediction - 50 samples");
            auto predictions = svm.predict(X_test_small);
        }
        {
            Benchmark bench("RBF prediction - 500 samples");
            auto predictions = svm.predict(X_test_medium);
        }
        {
            Benchmark bench("RBF prediction - 2K samples");
            auto predictions = svm.predict(X_test_large);
        }
    }
}
// Compares One-vs-Rest vs One-vs-One fit times on the same data, for linear
// and RBF kernels. Output only — no timing assertions.
TEST_CASE("Performance Benchmarks - Multiclass Strategies", "[performance][multiclass]")
{
    std::cout << "\n=== Multiclass Strategy Performance ===" << std::endl;
    std::cout << std::setw(40) << std::left << "Test Name" << " " << "Time" << std::endl;
    std::cout << std::string(50, '-') << std::endl;
    auto [X, y] = generate_large_dataset(2000, 30, 5); // 5 classes
    SECTION("One-vs-Rest vs One-vs-One")
    {
        {
            Benchmark bench("OvR Linear - 5 classes, 2K samples");
            json config = { {"kernel", "linear"}, {"multiclass_strategy", "ovr"} };
            SVMClassifier svm_ovr(config);
            svm_ovr.fit(X, y);
        }
        {
            Benchmark bench("OvO Linear - 5 classes, 2K samples");
            json config = { {"kernel", "linear"}, {"multiclass_strategy", "ovo"} };
            SVMClassifier svm_ovo(config);
            svm_ovo.fit(X, y);
        }
        // Smaller dataset for RBF due to computational complexity
        auto [X_rbf, y_rbf] = generate_large_dataset(800, 15, 4);
        {
            Benchmark bench("OvR RBF - 4 classes, 800 samples");
            json config = { {"kernel", "rbf"}, {"multiclass_strategy", "ovr"} };
            SVMClassifier svm_ovr(config);
            svm_ovr.fit(X_rbf, y_rbf);
        }
        {
            Benchmark bench("OvO RBF - 4 classes, 800 samples");
            json config = { {"kernel", "rbf"}, {"multiclass_strategy", "ovo"} };
            SVMClassifier svm_ovo(config);
            svm_ovo.fit(X_rbf, y_rbf);
        }
    }
}
// Scaling smoke tests: progressively larger sample counts and feature counts
// must train and predict without failure. Prediction is capped at 1000 rows
// to keep the memory-focused loop fast.
TEST_CASE("Performance Benchmarks - Memory Usage", "[performance][memory]")
{
    std::cout << "\n=== Memory Usage Benchmarks ===" << std::endl;
    SECTION("Large dataset handling")
    {
        // Test with progressively larger datasets
        std::vector<int> dataset_sizes = { 1000, 5000, 10000, 20000 };
        for (int size : dataset_sizes) {
            auto [X, y] = generate_large_dataset(size, 50, 2);
            {
                Benchmark bench("Dataset size " + std::to_string(size) + " - Linear");
                SVMClassifier svm(KernelType::LINEAR, 1.0);
                svm.fit(X, y);
                // Test prediction memory usage
                auto predictions = svm.predict(X.slice(0, 0, std::min(1000, size)));
                REQUIRE(predictions.size(0) == std::min(1000, size));
            }
        }
    }
    SECTION("High-dimensional data")
    {
        std::vector<int> feature_sizes = { 100, 500, 1000, 2000 };
        for (int n_features : feature_sizes) {
            auto [X, y] = generate_large_dataset(1000, n_features, 2);
            {
                Benchmark bench("Features " + std::to_string(n_features) + " - Linear");
                SVMClassifier svm(KernelType::LINEAR, 1.0);
                svm.fit(X, y);
                auto predictions = svm.predict(X.slice(0, 0, 100));
                REQUIRE(predictions.size(0) == 100);
            }
        }
    }
}
// Times k-fold cross-validation at several fold counts on one shared model.
TEST_CASE("Performance Benchmarks - Cross-Validation", "[performance][cv]")
{
    std::cout << "\n=== Cross-Validation Performance ===" << std::endl;
    std::cout << std::setw(40) << std::left << "Test Name" << " " << "Time" << std::endl;
    std::cout << std::string(50, '-') << std::endl;
    auto [X, y] = generate_large_dataset(2000, 25, 3);
    SECTION("Different CV folds")
    {
        SVMClassifier svm(KernelType::LINEAR, 1.0);
        // Table-driven: the same 3-, 5- and 10-fold runs as before, with the
        // benchmark label built from the fold count.
        for (int folds : { 3, 5, 10 }) {
            Benchmark bench(std::to_string(folds) + "-fold CV - 2K samples");
            auto scores = svm.cross_validate(X, y, folds);
            // cross_validate must return exactly one score per fold.
            REQUIRE(scores.size() == static_cast<size_t>(folds));
        }
    }
}
// Times grid_search() over grids of increasing size (3, 6, and 24 parameter
// combinations), each with 3-fold CV, on a deliberately small dataset.
TEST_CASE("Performance Benchmarks - Grid Search", "[performance][grid_search]")
{
    std::cout << "\n=== Grid Search Performance ===" << std::endl;
    std::cout << std::setw(40) << std::left << "Test Name" << " " << "Time" << std::endl;
    std::cout << std::string(50, '-') << std::endl;
    auto [X, y] = generate_large_dataset(1000, 20, 2); // Smaller dataset for grid search
    SVMClassifier svm;
    SECTION("Small parameter grid")
    {
        json param_grid = {
            {"kernel", {"linear"}},
            {"C", {0.1, 1.0, 10.0}}
        };
        {
            Benchmark bench("Grid search - 3 parameters");
            auto results = svm.grid_search(X, y, param_grid, 3);
            REQUIRE(results.contains("best_params"));
        }
    }
    SECTION("Medium parameter grid")
    {
        json param_grid = {
            {"kernel", {"linear", "rbf"}},
            {"C", {0.1, 1.0, 10.0}}
        };
        {
            Benchmark bench("Grid search - 6 parameters");
            auto results = svm.grid_search(X, y, param_grid, 3);
            REQUIRE(results.contains("best_params"));
        }
    }
    SECTION("Large parameter grid")
    {
        // 2 kernels x 4 C values x 3 gammas = 24 combinations.
        json param_grid = {
            {"kernel", {"linear", "rbf"}},
            {"C", {0.1, 1.0, 10.0, 100.0}},
            {"gamma", {0.01, 0.1, 1.0}}
        };
        {
            Benchmark bench("Grid search - 24 parameters");
            auto results = svm.grid_search(X, y, param_grid, 3);
            REQUIRE(results.contains("best_params"));
        }
    }
}
// Times the tensor -> libsvm and tensor -> liblinear problem conversions at
// increasing sizes; asserts only the converted row count (problem->l).
TEST_CASE("Performance Benchmarks - Data Conversion", "[performance][data_conversion]")
{
    std::cout << "\n=== Data Conversion Performance ===" << std::endl;
    std::cout << std::setw(40) << std::left << "Test Name" << " " << "Time" << std::endl;
    std::cout << std::string(50, '-') << std::endl;
    DataConverter converter;
    SECTION("Tensor to SVM format conversion")
    {
        auto [X_small, y_small] = generate_large_dataset(1000, 50, 2);
        auto [X_medium, y_medium] = generate_large_dataset(5000, 100, 2);
        auto [X_large, y_large] = generate_large_dataset(10000, 200, 2);
        {
            Benchmark bench("SVM conversion - 1K x 50");
            auto problem = converter.to_svm_problem(X_small, y_small);
            REQUIRE(problem->l == 1000);
        }
        {
            Benchmark bench("SVM conversion - 5K x 100");
            auto problem = converter.to_svm_problem(X_medium, y_medium);
            REQUIRE(problem->l == 5000);
        }
        {
            Benchmark bench("SVM conversion - 10K x 200");
            auto problem = converter.to_svm_problem(X_large, y_large);
            REQUIRE(problem->l == 10000);
        }
    }
    SECTION("Tensor to Linear format conversion")
    {
        auto [X_small, y_small] = generate_large_dataset(1000, 50, 2);
        auto [X_medium, y_medium] = generate_large_dataset(5000, 100, 2);
        auto [X_large, y_large] = generate_large_dataset(10000, 200, 2);
        {
            Benchmark bench("Linear conversion - 1K x 50");
            auto problem = converter.to_linear_problem(X_small, y_small);
            REQUIRE(problem->l == 1000);
        }
        {
            Benchmark bench("Linear conversion - 5K x 100");
            auto problem = converter.to_linear_problem(X_medium, y_medium);
            REQUIRE(problem->l == 5000);
        }
        {
            Benchmark bench("Linear conversion - 10K x 200");
            auto problem = converter.to_linear_problem(X_large, y_large);
            REQUIRE(problem->l == 10000);
        }
    }
}
// Times predict_proba() for linear and RBF models trained with probability
// estimation enabled. The supports_probability() guard means the timed block
// may be a no-op if the implementation reports no support.
TEST_CASE("Performance Benchmarks - Probability Prediction", "[performance][probability]")
{
    std::cout << "\n=== Probability Prediction Performance ===" << std::endl;
    std::cout << std::setw(40) << std::left << "Test Name" << " " << "Time" << std::endl;
    std::cout << std::string(50, '-') << std::endl;
    auto [X_train, y_train] = generate_large_dataset(1000, 20, 3);
    auto [X_test, _] = generate_large_dataset(500, 20, 3, 999);
    SECTION("Linear kernel with probability")
    {
        json config = { {"kernel", "linear"}, {"probability", true} };
        SVMClassifier svm(config);
        svm.fit(X_train, y_train);
        {
            Benchmark bench("Linear probability prediction");
            if (svm.supports_probability()) {
                auto probabilities = svm.predict_proba(X_test);
                REQUIRE(probabilities.size(0) == X_test.size(0));
            }
        }
    }
    SECTION("RBF kernel with probability")
    {
        json config = { {"kernel", "rbf"}, {"probability", true} };
        SVMClassifier svm(config);
        svm.fit(X_train, y_train);
        {
            Benchmark bench("RBF probability prediction");
            if (svm.supports_probability()) {
                auto probabilities = svm.predict_proba(X_test);
                REQUIRE(probabilities.size(0) == X_test.size(0));
            }
        }
    }
}
// Prints a fixed human-readable summary after the benchmark suite; purely
// informational, no assertions.
TEST_CASE("Performance Summary", "[performance][summary]")
{
    // Line table keeps the text in one place; output is identical to the
    // previous chained-stream version (each line flushed via std::endl).
    const char* summary_lines[] = {
        "\n=== Performance Summary ===",
        "All performance benchmarks completed successfully!",
        "\nKey Observations:",
        "- Linear kernels are fastest for training and prediction",
        "- RBF kernels provide good accuracy but slower training",
        "- One-vs-Rest is generally faster than One-vs-One",
        "- Memory usage scales linearly with dataset size",
        "- Data conversion overhead is minimal",
        "\nFor production use:",
        "- Use linear kernels for large datasets (>10K samples)",
        "- Use RBF kernels for smaller, complex datasets",
        "- Consider One-vs-Rest for many classes (>5)",
        "- Enable probability only when needed"
    };
    for (const char* line : summary_lines) {
        std::cout << line << std::endl;
    }
}

View File

@@ -0,0 +1,679 @@
/**
* @file test_svm_classifier.cpp
* @brief Integration tests for SVMClassifier class
*/
#include <catch2/catch_all.hpp>
#include <svm_classifier/svm_classifier.hpp>
#include <torch/torch.h>
#include <nlohmann/json.hpp>
using namespace svm_classifier;
using json = nlohmann::json;
/**
 * @brief Generate a synthetic classification dataset for the integration tests.
 *
 * @param n_samples  Number of rows in the feature matrix.
 * @param n_features Number of columns per sample.
 * @param n_classes  Number of distinct integer labels in [0, n_classes).
 * @param seed       RNG seed for reproducibility.
 * @return {X, y} with X of shape (n_samples, n_features) and integer labels y.
 */
std::pair<torch::Tensor, torch::Tensor> generate_test_data(int n_samples = 100,
                                                           int n_features = 4,
                                                           int n_classes = 3,
                                                           int seed = 42)
{
    torch::manual_seed(seed);
    auto X = torch::randn({ n_samples, n_features });
    auto y = torch::randint(0, n_classes, { n_samples });
    // Shift each sample toward its label so classes are separable enough for
    // the accuracy checks downstream.
    for (int row = 0; row < n_samples; ++row) {
        int label = y[row].item<int>();
        X[row] += torch::randn({ n_features }) * 0.5 + label;
    }
    return { X, y };
}
// Constructor contracts: default, explicit-parameter, and JSON-config
// construction must leave the classifier unfitted with the requested setup.
TEST_CASE("SVMClassifier Construction", "[integration][svm_classifier]")
{
    SECTION("Default constructor")
    {
        SVMClassifier classifier;
        // Defaults: linear kernel, OvR multiclass handling, nothing fitted.
        REQUIRE(classifier.get_kernel_type() == KernelType::LINEAR);
        REQUIRE_FALSE(classifier.is_fitted());
        REQUIRE(classifier.get_n_classes() == 0);
        REQUIRE(classifier.get_n_features() == 0);
        REQUIRE(classifier.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_REST);
    }
    SECTION("Constructor with parameters")
    {
        SVMClassifier classifier(KernelType::RBF, 10.0, MulticlassStrategy::ONE_VS_ONE);
        REQUIRE(classifier.get_kernel_type() == KernelType::RBF);
        REQUIRE(classifier.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_ONE);
        REQUIRE_FALSE(classifier.is_fitted());
    }
    SECTION("JSON constructor")
    {
        json cfg = {
            {"kernel", "polynomial"},
            {"C", 5.0},
            {"degree", 4},
            {"multiclass_strategy", "ovo"}
        };
        SVMClassifier classifier(cfg);
        REQUIRE(classifier.get_kernel_type() == KernelType::POLYNOMIAL);
        REQUIRE(classifier.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_ONE);
    }
}
// set_parameters()/get_parameters() round-trip, rejection of invalid values,
// and the rule that changing parameters invalidates a fitted model.
TEST_CASE("SVMClassifier Parameter Management", "[integration][svm_classifier]")
{
    SVMClassifier svm;
    SECTION("Set and get parameters")
    {
        json new_params = {
            {"kernel", "rbf"},
            {"C", 2.0},
            {"gamma", 0.1},
            {"probability", true}
        };
        svm.set_parameters(new_params);
        auto current_params = svm.get_parameters();
        REQUIRE(current_params["kernel"] == "rbf");
        REQUIRE(current_params["C"] == Catch::Approx(2.0));
        REQUIRE(current_params["gamma"] == Catch::Approx(0.1));
        REQUIRE(current_params["probability"] == true);
    }
    SECTION("Invalid parameters")
    {
        // Unknown kernel names and non-positive C must be rejected.
        json invalid_params = {
            {"kernel", "invalid_kernel"}
        };
        REQUIRE_THROWS_AS(svm.set_parameters(invalid_params), std::invalid_argument);
        json invalid_C = {
            {"C", -1.0}
        };
        REQUIRE_THROWS_AS(svm.set_parameters(invalid_C), std::invalid_argument);
    }
    SECTION("Parameter changes reset fitted state")
    {
        auto [X, y] = generate_test_data(50, 3, 2);
        svm.fit(X, y);
        REQUIRE(svm.is_fitted());
        // Any parameter update must drop the trained model.
        json new_params = { {"kernel", "rbf"} };
        svm.set_parameters(new_params);
        REQUIRE_FALSE(svm.is_fitted());
    }
}
// Linear-kernel training paths; the linear kernel is expected to route to the
// liblinear backend.
TEST_CASE("SVMClassifier Linear Kernel Training", "[integration][svm_classifier]")
{
    SVMClassifier svm(KernelType::LINEAR, 1.0);
    auto [X, y] = generate_test_data(100, 4, 3);
    SECTION("Basic training")
    {
        auto metrics = svm.fit(X, y);
        REQUIRE(svm.is_fitted());
        REQUIRE(svm.get_n_features() == 4);
        REQUIRE(svm.get_n_classes() == 3);
        // Linear kernels should be dispatched to liblinear, not libsvm.
        REQUIRE(svm.get_svm_library() == SVMLibrary::LIBLINEAR);
        REQUIRE(metrics.status == TrainingStatus::SUCCESS);
        REQUIRE(metrics.training_time >= 0.0);
    }
    SECTION("Training with probability")
    {
        json config = {
            {"kernel", "linear"},
            {"probability", true}
        };
        svm.set_parameters(config);
        auto metrics = svm.fit(X, y);
        REQUIRE(svm.is_fitted());
        REQUIRE(svm.supports_probability());
    }
    SECTION("Binary classification")
    {
        auto [X_binary, y_binary] = generate_test_data(50, 3, 2);
        auto metrics = svm.fit(X_binary, y_binary);
        REQUIRE(svm.is_fitted());
        REQUIRE(svm.get_n_classes() == 2);
    }
}
// RBF-kernel training paths; RBF is expected to route to the libsvm backend,
// with explicit, "auto", and default gamma settings all accepted.
TEST_CASE("SVMClassifier RBF Kernel Training", "[integration][svm_classifier]")
{
    SVMClassifier svm(KernelType::RBF, 1.0);
    auto [X, y] = generate_test_data(80, 3, 2);
    SECTION("Basic RBF training")
    {
        auto metrics = svm.fit(X, y);
        REQUIRE(svm.is_fitted());
        REQUIRE(svm.get_svm_library() == SVMLibrary::LIBSVM);
        REQUIRE(metrics.status == TrainingStatus::SUCCESS);
    }
    SECTION("RBF with custom gamma")
    {
        json config = {
            {"kernel", "rbf"},
            {"gamma", 0.5}
        };
        svm.set_parameters(config);
        auto metrics = svm.fit(X, y);
        REQUIRE(svm.is_fitted());
    }
    SECTION("RBF with auto gamma")
    {
        // "auto" asks the implementation to derive gamma from the data.
        json config = {
            {"kernel", "rbf"},
            {"gamma", "auto"}
        };
        svm.set_parameters(config);
        auto metrics = svm.fit(X, y);
        REQUIRE(svm.is_fitted());
    }
}
// Polynomial-kernel training with explicit degree/gamma/coef0, plus a sweep
// over several degrees; polynomial routes to the libsvm backend.
TEST_CASE("SVMClassifier Polynomial Kernel Training", "[integration][svm_classifier]")
{
    SVMClassifier svm;
    auto [X, y] = generate_test_data(60, 2, 2);
    SECTION("Polynomial kernel")
    {
        json config = {
            {"kernel", "polynomial"},
            {"degree", 3},
            {"gamma", 0.1},
            {"coef0", 1.0}
        };
        svm.set_parameters(config);
        auto metrics = svm.fit(X, y);
        REQUIRE(svm.is_fitted());
        REQUIRE(svm.get_kernel_type() == KernelType::POLYNOMIAL);
        REQUIRE(svm.get_svm_library() == SVMLibrary::LIBSVM);
    }
    SECTION("Different degrees")
    {
        // Each degree gets a fresh classifier so configs do not accumulate.
        for (int degree : {2, 4, 5}) {
            json config = {
                {"kernel", "polynomial"},
                {"degree", degree}
            };
            SVMClassifier poly_svm(config);
            REQUIRE_NOTHROW(poly_svm.fit(X, y));
            REQUIRE(poly_svm.is_fitted());
        }
    }
}
// Sigmoid-kernel training smoke test; sigmoid routes to the libsvm backend.
TEST_CASE("SVMClassifier Sigmoid Kernel Training", "[integration][svm_classifier]")
{
    SVMClassifier classifier;
    auto [X, y] = generate_test_data(50, 2, 2);
    json sigmoid_config = {
        {"kernel", "sigmoid"},
        {"gamma", 0.01},
        {"coef0", 0.5}
    };
    classifier.set_parameters(sigmoid_config);
    auto fit_metrics = classifier.fit(X, y);
    REQUIRE(classifier.is_fitted());
    REQUIRE(classifier.get_kernel_type() == KernelType::SIGMOID);
    REQUIRE(classifier.get_svm_library() == SVMLibrary::LIBSVM);
}
// End-to-end predict()/score() on an 80/20 split: label validity, score range,
// and training-set evaluation.
TEST_CASE("SVMClassifier Prediction", "[integration][svm_classifier]")
{
    SVMClassifier svm(KernelType::LINEAR);
    auto [X, y] = generate_test_data(100, 3, 3);
    // Split data
    auto X_train = X.slice(0, 0, 80);
    auto y_train = y.slice(0, 0, 80);
    auto X_test = X.slice(0, 80);
    auto y_test = y.slice(0, 80);
    svm.fit(X_train, y_train);
    SECTION("Basic prediction")
    {
        auto predictions = svm.predict(X_test);
        REQUIRE(predictions.dtype() == torch::kInt32);
        REQUIRE(predictions.size(0) == X_test.size(0));
        // Check that predictions are valid class labels
        auto unique_preds = torch::unique(predictions);
        for (int i = 0; i < unique_preds.size(0); ++i) {
            int pred_class = unique_preds[i].item<int>();
            auto classes = svm.get_classes();
            REQUIRE(std::find(classes.begin(), classes.end(), pred_class) != classes.end());
        }
    }
    SECTION("Prediction accuracy")
    {
        double accuracy = svm.score(X_test, y_test);
        REQUIRE(accuracy >= 0.0);
        REQUIRE(accuracy <= 1.0);
        // For this synthetic dataset, we expect reasonable accuracy
        REQUIRE(accuracy > 0.3); // Very loose bound
    }
    SECTION("Prediction on training data")
    {
        auto train_predictions = svm.predict(X_train);
        double train_accuracy = svm.score(X_train, y_train);
        REQUIRE(train_accuracy >= 0.0);
        REQUIRE(train_accuracy <= 1.0);
    }
}
// predict_proba() contract: a valid per-row distribution when probability is
// enabled, std::runtime_error when the model is untrained or probability is
// not enabled.
TEST_CASE("SVMClassifier Probability Prediction", "[integration][svm_classifier]")
{
    json config = {
        {"kernel", "rbf"},
        {"probability", true}
    };
    SVMClassifier svm(config);
    auto [X, y] = generate_test_data(80, 3, 3);
    svm.fit(X, y);
    SECTION("Probability predictions")
    {
        REQUIRE(svm.supports_probability());
        auto probabilities = svm.predict_proba(X);
        REQUIRE(probabilities.dtype() == torch::kFloat64);
        REQUIRE(probabilities.size(0) == X.size(0));
        REQUIRE(probabilities.size(1) == 3); // 3 classes
        // Check that probabilities sum to 1
        auto prob_sums = probabilities.sum(1);
        for (int i = 0; i < prob_sums.size(0); ++i) {
            REQUIRE(prob_sums[i].item<double>() == Catch::Approx(1.0).margin(1e-6));
        }
        // Check that all probabilities are non-negative
        REQUIRE(torch::all(probabilities >= 0.0).item<bool>());
    }
    SECTION("Probability without training")
    {
        SVMClassifier untrained_svm(config);
        REQUIRE_THROWS_AS(untrained_svm.predict_proba(X), std::runtime_error);
    }
    SECTION("Probability not supported")
    {
        SVMClassifier no_prob_svm(KernelType::LINEAR); // No probability
        no_prob_svm.fit(X, y);
        REQUIRE_FALSE(no_prob_svm.supports_probability());
        REQUIRE_THROWS_AS(no_prob_svm.predict_proba(X), std::runtime_error);
    }
}
// decision_function() shape/dtype checks and an OvR consistency check.
//
// FIX: the original consistency loop computed `max_indices` and then did
// nothing with it (dead statement + unused-variable warning). The loop now
// asserts that the per-row argmax is a valid column index. A stronger check
// (predicted label == class at the argmax column) would assume a particular
// column-to-class ordering, which is implementation-defined, so it is not
// asserted here.
TEST_CASE("SVMClassifier Decision Function", "[integration][svm_classifier]")
{
    SVMClassifier svm(KernelType::RBF);
    auto [X, y] = generate_test_data(60, 2, 3);
    svm.fit(X, y);
    SECTION("Decision function values")
    {
        auto decision_values = svm.decision_function(X);
        REQUIRE(decision_values.dtype() == torch::kFloat64);
        REQUIRE(decision_values.size(0) == X.size(0));
        // Decision function output depends on multiclass strategy
        REQUIRE(decision_values.size(1) > 0);
    }
    SECTION("Decision function consistency with predictions")
    {
        auto predictions = svm.predict(X);
        auto decision_values = svm.decision_function(X);
        // For OvR strategy, the predicted class should correspond to max decision value
        if (svm.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_REST) {
            for (int i = 0; i < X.size(0); ++i) {
                auto max_index = std::get<1>(torch::max(decision_values[i], 0)).item<int>();
                REQUIRE(max_index >= 0);
                REQUIRE(max_index < decision_values.size(1));
            }
        }
    }
}
// OvR and OvO strategy selection via JSON config and constructor, on a
// 4-class problem; both must fit and emit one prediction per sample.
TEST_CASE("SVMClassifier Multiclass Strategies", "[integration][svm_classifier]")
{
    auto [X, y] = generate_test_data(80, 3, 4); // 4 classes
    SECTION("One-vs-Rest strategy")
    {
        json config = {
            {"kernel", "linear"},
            {"multiclass_strategy", "ovr"}
        };
        SVMClassifier svm_ovr(config);
        auto metrics = svm_ovr.fit(X, y);
        REQUIRE(svm_ovr.is_fitted());
        REQUIRE(svm_ovr.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_REST);
        REQUIRE(svm_ovr.get_n_classes() == 4);
        auto predictions = svm_ovr.predict(X);
        REQUIRE(predictions.size(0) == X.size(0));
    }
    SECTION("One-vs-One strategy")
    {
        json config = {
            {"kernel", "rbf"},
            {"multiclass_strategy", "ovo"}
        };
        SVMClassifier svm_ovo(config);
        auto metrics = svm_ovo.fit(X, y);
        REQUIRE(svm_ovo.is_fitted());
        REQUIRE(svm_ovo.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_ONE);
        REQUIRE(svm_ovo.get_n_classes() == 4);
        auto predictions = svm_ovo.predict(X);
        REQUIRE(predictions.size(0) == X.size(0));
    }
    SECTION("Compare strategies")
    {
        SVMClassifier svm_ovr(KernelType::LINEAR, 1.0, MulticlassStrategy::ONE_VS_REST);
        SVMClassifier svm_ovo(KernelType::LINEAR, 1.0, MulticlassStrategy::ONE_VS_ONE);
        svm_ovr.fit(X, y);
        svm_ovo.fit(X, y);
        auto pred_ovr = svm_ovr.predict(X);
        auto pred_ovo = svm_ovo.predict(X);
        // Both should produce valid predictions
        REQUIRE(pred_ovr.size(0) == X.size(0));
        REQUIRE(pred_ovo.size(0) == X.size(0));
    }
}
// evaluate() returns accuracy/precision/recall/F1 in [0, 1] plus an
// n_classes x n_classes confusion matrix; a trivially separable dataset
// should score highly.
TEST_CASE("SVMClassifier Evaluation Metrics", "[integration][svm_classifier]")
{
    SVMClassifier svm(KernelType::LINEAR);
    auto [X, y] = generate_test_data(100, 3, 3);
    svm.fit(X, y);
    SECTION("Detailed evaluation")
    {
        auto metrics = svm.evaluate(X, y);
        REQUIRE(metrics.accuracy >= 0.0);
        REQUIRE(metrics.accuracy <= 1.0);
        REQUIRE(metrics.precision >= 0.0);
        REQUIRE(metrics.precision <= 1.0);
        REQUIRE(metrics.recall >= 0.0);
        REQUIRE(metrics.recall <= 1.0);
        REQUIRE(metrics.f1_score >= 0.0);
        REQUIRE(metrics.f1_score <= 1.0);
        // Check confusion matrix dimensions
        REQUIRE(metrics.confusion_matrix.size() == 3); // 3 classes
        for (const auto& row : metrics.confusion_matrix) {
            REQUIRE(row.size() == 3);
        }
    }
    SECTION("Perfect predictions metrics")
    {
        // Create simple dataset where perfect classification is possible
        auto X_simple = torch::tensor({ {0.0, 0.0}, {1.0, 1.0}, {2.0, 2.0} });
        auto y_simple = torch::tensor({ 0, 1, 2 });
        SVMClassifier simple_svm(KernelType::LINEAR);
        simple_svm.fit(X_simple, y_simple);
        auto metrics = simple_svm.evaluate(X_simple, y_simple);
        // Should have perfect or near-perfect accuracy on this simple dataset
        REQUIRE(metrics.accuracy > 0.8); // Very achievable for this data
    }
}
TEST_CASE("SVMClassifier Cross-Validation", "[integration][svm_classifier]")
{
    SVMClassifier svm(KernelType::LINEAR);
    auto [X, y] = generate_test_data(100, 3, 2);
    SECTION("5-fold cross-validation")
    {
        auto cv_scores = svm.cross_validate(X, y, 5);
        REQUIRE(cv_scores.size() == 5);
        // Every fold score is an accuracy, hence bounded by [0, 1].
        for (double score : cv_scores) {
            REQUIRE(score >= 0.0);
            REQUIRE(score <= 1.0);
        }
        // Mean fold accuracy must stay inside the same bounds.
        double mean = std::accumulate(cv_scores.begin(), cv_scores.end(), 0.0) / cv_scores.size();
        REQUIRE(mean >= 0.0);
        REQUIRE(mean <= 1.0);
        // Spread across folds. The original comment promised a std-dev check but
        // never computed one; variance is checked instead (no <cmath> needed).
        // For values confined to [0, 1] the variance is at most 0.25.
        double sq_dev = 0.0;
        for (double score : cv_scores) {
            sq_dev += (score - mean) * (score - mean);
        }
        double variance = sq_dev / cv_scores.size();
        REQUIRE(variance >= 0.0);
        REQUIRE(variance <= 0.25);
    }
    SECTION("Invalid CV folds")
    {
        // Fewer than 2 folds is meaningless and must be rejected up front.
        REQUIRE_THROWS_AS(svm.cross_validate(X, y, 1), std::invalid_argument);
        REQUIRE_THROWS_AS(svm.cross_validate(X, y, 0), std::invalid_argument);
    }
    SECTION("CV preserves original state")
    {
        // Fit first so there is fitted state that CV could accidentally clobber.
        svm.fit(X, y);
        auto original_classes = svm.get_classes();
        auto cv_scores = svm.cross_validate(X, y, 3);
        // The fitted model must survive cross-validation untouched.
        REQUIRE(svm.is_fitted());
        REQUIRE(svm.get_classes() == original_classes);
    }
}
TEST_CASE("SVMClassifier Grid Search", "[integration][svm_classifier]")
{
    SVMClassifier svm;
    // Keep the dataset small so the exhaustive parameter sweep stays fast.
    auto [X, y] = generate_test_data(60, 2, 2);
    SECTION("Simple grid search")
    {
        json param_grid = {
            {"kernel", {"linear", "rbf"}},
            {"C", {0.1, 1.0, 10.0}}
        };
        auto results = svm.grid_search(X, y, param_grid, 3);
        // The result object must expose the winning configuration and all scores.
        for (const auto* key : { "best_params", "best_score", "cv_results" }) {
            REQUIRE(results.contains(key));
        }
        auto best_score = results["best_score"].get<double>();
        REQUIRE(best_score >= 0.0);
        REQUIRE(best_score <= 1.0);
        // One CV score per grid point: 2 kernels x 3 C values.
        auto cv_results = results["cv_results"].get<std::vector<double>>();
        REQUIRE(cv_results.size() == 6);
    }
    SECTION("RBF-specific grid search")
    {
        json param_grid = {
            {"kernel", {"rbf"}},
            {"C", {1.0, 10.0}},
            {"gamma", {0.01, 0.1}}
        };
        auto results = svm.grid_search(X, y, param_grid, 3);
        // The best configuration must come from the searched grid.
        auto best_params = results["best_params"];
        REQUIRE(best_params["kernel"] == "rbf");
        REQUIRE(best_params.contains("C"));
        REQUIRE(best_params.contains("gamma"));
    }
}
TEST_CASE("SVMClassifier Error Handling", "[integration][svm_classifier]")
{
    SVMClassifier svm;
    SECTION("Prediction before training")
    {
        // Any inference call on an unfitted model must fail loudly.
        auto X = torch::randn({ 5, 3 });
        REQUIRE_THROWS_AS(svm.predict(X), std::runtime_error);
        REQUIRE_THROWS_AS(svm.predict_proba(X), std::runtime_error);
        REQUIRE_THROWS_AS(svm.decision_function(X), std::runtime_error);
    }
    SECTION("Inconsistent feature dimensions")
    {
        auto X_train = torch::randn({ 50, 3 });
        auto y_train = torch::randint(0, 2, { 50 });
        auto X_test = torch::randn({ 10, 5 }); // 5 features vs. the 3 used in training
        svm.fit(X_train, y_train);
        REQUIRE_THROWS_AS(svm.predict(X_test), std::invalid_argument);
    }
    SECTION("Invalid training data")
    {
        // NaN feature values must be rejected at fit time.
        auto X_invalid = torch::tensor({ {std::numeric_limits<float>::quiet_NaN(), 1.0} });
        auto y_invalid = torch::tensor({ 0 });
        REQUIRE_THROWS_AS(svm.fit(X_invalid, y_invalid), std::invalid_argument);
    }
    SECTION("Empty datasets")
    {
        auto X_empty = torch::empty({ 0, 3 });
        // Use an integer label tensor, consistent with the other tests: an empty
        // float label tensor could trip a dtype check before the emptiness check
        // and make this test pass for the wrong reason.
        auto y_empty = torch::empty({ 0 }, torch::kInt64);
        REQUIRE_THROWS_AS(svm.fit(X_empty, y_empty), std::invalid_argument);
    }
}
TEST_CASE("SVMClassifier Move Semantics", "[integration][svm_classifier]")
{
    SECTION("Move constructor")
    {
        SVMClassifier source(KernelType::RBF, 2.0);
        auto [X, y] = generate_test_data(50, 2, 2);
        source.fit(X, y);
        const auto trained_classes = source.get_classes();
        const bool trained = source.is_fitted();
        SVMClassifier destination = std::move(source);
        // The destination takes over the fitted state wholesale.
        REQUIRE(destination.is_fitted() == trained);
        REQUIRE(destination.get_classes() == trained_classes);
        REQUIRE(destination.get_kernel_type() == KernelType::RBF);
        // The moved-from object must report itself as no longer fitted.
        REQUIRE_FALSE(source.is_fitted());
    }
    SECTION("Move assignment")
    {
        SVMClassifier source(KernelType::POLYNOMIAL);
        SVMClassifier destination(KernelType::LINEAR);
        auto [X, y] = generate_test_data(40, 2, 2);
        source.fit(X, y);
        const auto trained_classes = source.get_classes();
        destination = std::move(source);
        // Assignment replaces the destination's kernel and fitted state.
        REQUIRE(destination.is_fitted());
        REQUIRE(destination.get_classes() == trained_classes);
        REQUIRE(destination.get_kernel_type() == KernelType::POLYNOMIAL);
    }
}
TEST_CASE("SVMClassifier Reset Functionality", "[integration][svm_classifier]")
{
    SVMClassifier svm(KernelType::RBF);
    auto [X, y] = generate_test_data(50, 3, 2);
    // Train once so there is real learned state to clear.
    svm.fit(X, y);
    REQUIRE(svm.is_fitted());
    REQUIRE(svm.get_n_features() > 0);
    REQUIRE(svm.get_n_classes() > 0);
    // reset() must wipe all learned state ...
    svm.reset();
    REQUIRE_FALSE(svm.is_fitted());
    REQUIRE(svm.get_n_features() == 0);
    REQUIRE(svm.get_n_classes() == 0);
    // ... while leaving the object reusable for a fresh fit.
    REQUIRE_NOTHROW(svm.fit(X, y));
    REQUIRE(svm.is_fitted());
}