/**
* @file test_svm_classifier.cpp
* @brief Integration tests for SVMClassifier class
*/
#include <catch2/catch_all.hpp>
#include <svm_classifier/svm_classifier.hpp>
#include <torch/torch.h>
#include <nlohmann/json.hpp>
// Standard headers for std::find, std::accumulate, std::sqrt and std::numeric_limits
#include <algorithm>
#include <cmath>
#include <limits>
#include <numeric>
using namespace svm_classifier;
using json = nlohmann::json;
/**
* @brief Generate synthetic classification dataset
*/
std::pair<torch::Tensor, torch::Tensor> generate_test_data(int n_samples = 100,
int n_features = 4,
int n_classes = 3,
int seed = 42)
{
torch::manual_seed(seed);
auto X = torch::randn({ n_samples, n_features });
auto y = torch::randint(0, n_classes, { n_samples });
// Add some structure to make classification meaningful
for (int i = 0; i < n_samples; ++i) {
int target_class = y[i].item<int>();
// Bias features toward the target class
X[i] += torch::randn({ n_features }) * 0.5 + target_class;
}
return { X, y };
}
TEST_CASE("SVMClassifier Construction", "[integration][svm_classifier]")
{
SECTION("Default constructor")
{
SVMClassifier svm;
REQUIRE(svm.get_kernel_type() == KernelType::LINEAR);
REQUIRE_FALSE(svm.is_fitted());
REQUIRE(svm.get_n_classes() == 0);
REQUIRE(svm.get_n_features() == 0);
REQUIRE(svm.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_REST);
}
SECTION("Constructor with parameters")
{
SVMClassifier svm(KernelType::RBF, 10.0, MulticlassStrategy::ONE_VS_ONE);
REQUIRE(svm.get_kernel_type() == KernelType::RBF);
REQUIRE(svm.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_ONE);
REQUIRE_FALSE(svm.is_fitted());
}
SECTION("JSON constructor")
{
json config = {
{"kernel", "polynomial"},
{"C", 5.0},
{"degree", 4},
{"multiclass_strategy", "ovo"}
};
SVMClassifier svm(config);
REQUIRE(svm.get_kernel_type() == KernelType::POLYNOMIAL);
REQUIRE(svm.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_ONE);
}
}
TEST_CASE("SVMClassifier Parameter Management", "[integration][svm_classifier]")
{
SVMClassifier svm;
SECTION("Set and get parameters")
{
json new_params = {
{"kernel", "rbf"},
{"C", 2.0},
{"gamma", 0.1},
{"probability", true}
};
svm.set_parameters(new_params);
auto current_params = svm.get_parameters();
REQUIRE(current_params["kernel"] == "rbf");
REQUIRE(current_params["C"] == Catch::Approx(2.0));
REQUIRE(current_params["gamma"] == Catch::Approx(0.1));
REQUIRE(current_params["probability"] == true);
}
SECTION("Invalid parameters")
{
json invalid_params = {
{"kernel", "invalid_kernel"}
};
REQUIRE_THROWS_AS(svm.set_parameters(invalid_params), std::invalid_argument);
json invalid_C = {
{"C", -1.0}
};
REQUIRE_THROWS_AS(svm.set_parameters(invalid_C), std::invalid_argument);
}
SECTION("Parameter changes reset fitted state")
{
auto [X, y] = generate_test_data(50, 3, 2);
svm.fit(X, y);
REQUIRE(svm.is_fitted());
json new_params = { {"kernel", "rbf"} };
svm.set_parameters(new_params);
REQUIRE_FALSE(svm.is_fitted());
}
}
TEST_CASE("SVMClassifier Linear Kernel Training", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::LINEAR, 1.0);
auto [X, y] = generate_test_data(100, 4, 3);
SECTION("Basic training")
{
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_n_features() == 4);
REQUIRE(svm.get_n_classes() == 3);
REQUIRE(svm.get_svm_library() == SVMLibrary::LIBLINEAR);
REQUIRE(metrics.status == TrainingStatus::SUCCESS);
REQUIRE(metrics.training_time >= 0.0);
}
SECTION("Training with probability")
{
json config = {
{"kernel", "linear"},
{"probability", true}
};
svm.set_parameters(config);
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
REQUIRE(svm.supports_probability());
}
SECTION("Binary classification")
{
auto [X_binary, y_binary] = generate_test_data(50, 3, 2);
auto metrics = svm.fit(X_binary, y_binary);
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_n_classes() == 2);
}
}
TEST_CASE("SVMClassifier RBF Kernel Training", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::RBF, 1.0);
auto [X, y] = generate_test_data(80, 3, 2);
SECTION("Basic RBF training")
{
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_svm_library() == SVMLibrary::LIBSVM);
REQUIRE(metrics.status == TrainingStatus::SUCCESS);
}
SECTION("RBF with custom gamma")
{
json config = {
{"kernel", "rbf"},
{"gamma", 0.5}
};
svm.set_parameters(config);
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
}
SECTION("RBF with auto gamma")
{
json config = {
{"kernel", "rbf"},
{"gamma", "auto"}
};
svm.set_parameters(config);
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
}
}
TEST_CASE("SVMClassifier Polynomial Kernel Training", "[integration][svm_classifier]")
{
SVMClassifier svm;
auto [X, y] = generate_test_data(60, 2, 2);
SECTION("Polynomial kernel")
{
json config = {
{"kernel", "polynomial"},
{"degree", 3},
{"gamma", 0.1},
{"coef0", 1.0}
};
svm.set_parameters(config);
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_kernel_type() == KernelType::POLYNOMIAL);
REQUIRE(svm.get_svm_library() == SVMLibrary::LIBSVM);
}
SECTION("Different degrees")
{
for (int degree : {2, 4, 5}) {
json config = {
{"kernel", "polynomial"},
{"degree", degree}
};
SVMClassifier poly_svm(config);
REQUIRE_NOTHROW(poly_svm.fit(X, y));
REQUIRE(poly_svm.is_fitted());
}
}
}
TEST_CASE("SVMClassifier Sigmoid Kernel Training", "[integration][svm_classifier]")
{
SVMClassifier svm;
auto [X, y] = generate_test_data(50, 2, 2);
json config = {
{"kernel", "sigmoid"},
{"gamma", 0.01},
{"coef0", 0.5}
};
svm.set_parameters(config);
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_kernel_type() == KernelType::SIGMOID);
REQUIRE(svm.get_svm_library() == SVMLibrary::LIBSVM);
}
TEST_CASE("SVMClassifier Prediction", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::LINEAR);
auto [X, y] = generate_test_data(100, 3, 3);
// Split data
auto X_train = X.slice(0, 0, 80);
auto y_train = y.slice(0, 0, 80);
auto X_test = X.slice(0, 80);
auto y_test = y.slice(0, 80);
svm.fit(X_train, y_train);
SECTION("Basic prediction")
{
auto predictions = svm.predict(X_test);
REQUIRE(predictions.dtype() == torch::kInt32);
REQUIRE(predictions.size(0) == X_test.size(0));
// Check that predictions are valid class labels
// torch::unique is not exposed in the C++ frontend; _unique returns a (values, inverse) tuple
auto unique_preds = std::get<0>(torch::_unique(predictions));
for (int i = 0; i < unique_preds.size(0); ++i) {
int pred_class = unique_preds[i].item<int>();
auto classes = svm.get_classes();
REQUIRE(std::find(classes.begin(), classes.end(), pred_class) != classes.end());
}
}
SECTION("Prediction accuracy")
{
double accuracy = svm.score(X_test, y_test);
REQUIRE(accuracy >= 0.0);
REQUIRE(accuracy <= 1.0);
// For this synthetic dataset, we expect reasonable accuracy
REQUIRE(accuracy > 0.3); // Very loose bound
}
SECTION("Prediction on training data")
{
auto train_predictions = svm.predict(X_train);
double train_accuracy = svm.score(X_train, y_train);
REQUIRE(train_accuracy >= 0.0);
REQUIRE(train_accuracy <= 1.0);
}
}
TEST_CASE("SVMClassifier Probability Prediction", "[integration][svm_classifier]")
{
json config = {
{"kernel", "rbf"},
{"probability", true}
};
SVMClassifier svm(config);
auto [X, y] = generate_test_data(80, 3, 3);
svm.fit(X, y);
SECTION("Probability predictions")
{
REQUIRE(svm.supports_probability());
auto probabilities = svm.predict_proba(X);
REQUIRE(probabilities.dtype() == torch::kFloat64);
REQUIRE(probabilities.size(0) == X.size(0));
REQUIRE(probabilities.size(1) == 3); // 3 classes
// Check that probabilities sum to 1
auto prob_sums = probabilities.sum(1);
for (int i = 0; i < prob_sums.size(0); ++i) {
REQUIRE(prob_sums[i].item<double>() == Catch::Approx(1.0).margin(1e-6));
}
// Check that all probabilities are non-negative
REQUIRE(torch::all(probabilities >= 0.0).item<bool>());
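// Extra sanity check: with 3 non-negative entries per row summing to 1, the
// largest per-row probability can never fall below 1/3, so this lower bound
// is always safe to assert.
auto max_probs = std::get<0>(torch::max(probabilities, 1));
REQUIRE(torch::all(max_probs >= 1.0 / 3.0 - 1e-9).item<bool>());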
}
SECTION("Probability without training")
{
SVMClassifier untrained_svm(config);
REQUIRE_THROWS_AS(untrained_svm.predict_proba(X), std::runtime_error);
}
SECTION("Probability not supported")
{
SVMClassifier no_prob_svm(KernelType::LINEAR); // No probability
no_prob_svm.fit(X, y);
REQUIRE_FALSE(no_prob_svm.supports_probability());
REQUIRE_THROWS_AS(no_prob_svm.predict_proba(X), std::runtime_error);
}
}
TEST_CASE("SVMClassifier Decision Function", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::RBF);
auto [X, y] = generate_test_data(60, 2, 3);
svm.fit(X, y);
SECTION("Decision function values")
{
auto decision_values = svm.decision_function(X);
REQUIRE(decision_values.dtype() == torch::kFloat64);
REQUIRE(decision_values.size(0) == X.size(0));
// Decision function output depends on multiclass strategy
REQUIRE(decision_values.size(1) > 0);
}
SECTION("Decision function consistency with predictions")
{
auto predictions = svm.predict(X);
auto decision_values = svm.decision_function(X);
REQUIRE(predictions.size(0) == decision_values.size(0));
// For the OvR strategy, the predicted class should correspond to the column with the
// maximum decision value. The column-to-class mapping is implementation-specific,
// so only the validity of the argmax index is asserted here.
if (svm.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_REST) {
for (int i = 0; i < X.size(0); ++i) {
auto max_index = std::get<1>(torch::max(decision_values[i], 0)).item<int64_t>();
REQUIRE(max_index >= 0);
REQUIRE(max_index < decision_values.size(1));
}
}
}
}
TEST_CASE("SVMClassifier Multiclass Strategies", "[integration][svm_classifier]")
{
auto [X, y] = generate_test_data(80, 3, 4); // 4 classes
SECTION("One-vs-Rest strategy")
{
json config = {
{"kernel", "linear"},
{"multiclass_strategy", "ovr"}
};
SVMClassifier svm_ovr(config);
auto metrics = svm_ovr.fit(X, y);
REQUIRE(svm_ovr.is_fitted());
REQUIRE(svm_ovr.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_REST);
REQUIRE(svm_ovr.get_n_classes() == 4);
auto predictions = svm_ovr.predict(X);
REQUIRE(predictions.size(0) == X.size(0));
}
SECTION("One-vs-One strategy")
{
json config = {
{"kernel", "rbf"},
{"multiclass_strategy", "ovo"}
};
SVMClassifier svm_ovo(config);
auto metrics = svm_ovo.fit(X, y);
REQUIRE(svm_ovo.is_fitted());
REQUIRE(svm_ovo.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_ONE);
REQUIRE(svm_ovo.get_n_classes() == 4);
auto predictions = svm_ovo.predict(X);
REQUIRE(predictions.size(0) == X.size(0));
}
SECTION("Compare strategies")
{
SVMClassifier svm_ovr(KernelType::LINEAR, 1.0, MulticlassStrategy::ONE_VS_REST);
SVMClassifier svm_ovo(KernelType::LINEAR, 1.0, MulticlassStrategy::ONE_VS_ONE);
svm_ovr.fit(X, y);
svm_ovo.fit(X, y);
auto pred_ovr = svm_ovr.predict(X);
auto pred_ovo = svm_ovo.predict(X);
// Both should produce valid predictions
REQUIRE(pred_ovr.size(0) == X.size(0));
REQUIRE(pred_ovo.size(0) == X.size(0));
}
}
TEST_CASE("SVMClassifier Evaluation Metrics", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::LINEAR);
auto [X, y] = generate_test_data(100, 3, 3);
svm.fit(X, y);
SECTION("Detailed evaluation")
{
auto metrics = svm.evaluate(X, y);
REQUIRE(metrics.accuracy >= 0.0);
REQUIRE(metrics.accuracy <= 1.0);
REQUIRE(metrics.precision >= 0.0);
REQUIRE(metrics.precision <= 1.0);
REQUIRE(metrics.recall >= 0.0);
REQUIRE(metrics.recall <= 1.0);
REQUIRE(metrics.f1_score >= 0.0);
REQUIRE(metrics.f1_score <= 1.0);
// Check confusion matrix dimensions
REQUIRE(metrics.confusion_matrix.size() == 3); // 3 classes
for (const auto& row : metrics.confusion_matrix) {
REQUIRE(row.size() == 3);
}
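// Assuming evaluate() counts every evaluated sample exactly once, the
// confusion matrix entries should add up to the number of samples.
int total = 0;
for (const auto& row : metrics.confusion_matrix) {
total += std::accumulate(row.begin(), row.end(), 0);
}
REQUIRE(total == X.size(0));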
}
SECTION("Perfect predictions metrics")
{
// Create simple dataset where perfect classification is possible
auto X_simple = torch::tensor({ {0.0, 0.0}, {1.0, 1.0}, {2.0, 2.0} });
auto y_simple = torch::tensor({ 0, 1, 2 });
SVMClassifier simple_svm(KernelType::LINEAR);
simple_svm.fit(X_simple, y_simple);
auto metrics = simple_svm.evaluate(X_simple, y_simple);
// Should have perfect or near-perfect accuracy on this simple dataset
REQUIRE(metrics.accuracy > 0.8); // Very achievable for this data
}
}
TEST_CASE("SVMClassifier Cross-Validation", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::LINEAR);
auto [X, y] = generate_test_data(100, 3, 2);
SECTION("5-fold cross-validation")
{
auto cv_scores = svm.cross_validate(X, y, 5);
REQUIRE(cv_scores.size() == 5);
for (double score : cv_scores) {
REQUIRE(score >= 0.0);
REQUIRE(score <= 1.0);
}
// Calculate mean and std
double mean = std::accumulate(cv_scores.begin(), cv_scores.end(), 0.0) / cv_scores.size();
REQUIRE(mean >= 0.0);
REQUIRE(mean <= 1.0);
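// Also compute the population standard deviation promised by the comment above;
// for accuracies confined to [0, 1] it can never exceed 0.5.
double sq_sum = std::inner_product(cv_scores.begin(), cv_scores.end(), cv_scores.begin(), 0.0);
double std_dev = std::sqrt(std::max(0.0, sq_sum / cv_scores.size() - mean * mean));
REQUIRE(std_dev >= 0.0);
REQUIRE(std_dev <= 0.5);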
}
SECTION("Invalid CV folds")
{
REQUIRE_THROWS_AS(svm.cross_validate(X, y, 1), std::invalid_argument);
REQUIRE_THROWS_AS(svm.cross_validate(X, y, 0), std::invalid_argument);
}
SECTION("CV preserves original state")
{
// Fit the model first
svm.fit(X, y);
auto original_classes = svm.get_classes();
// Run CV
auto cv_scores = svm.cross_validate(X, y, 3);
// Should still be fitted with same classes
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_classes() == original_classes);
}
}
TEST_CASE("SVMClassifier Grid Search", "[integration][svm_classifier]")
{
SVMClassifier svm;
auto [X, y] = generate_test_data(60, 2, 2); // Smaller dataset for faster testing
SECTION("Simple grid search")
{
json param_grid = {
{"kernel", {"linear", "rbf"}},
{"C", {0.1, 1.0, 10.0}}
};
auto results = svm.grid_search(X, y, param_grid, 3);
REQUIRE(results.contains("best_params"));
REQUIRE(results.contains("best_score"));
REQUIRE(results.contains("cv_results"));
auto best_score = results["best_score"].get<double>();
REQUIRE(best_score >= 0.0);
REQUIRE(best_score <= 1.0);
auto cv_results = results["cv_results"].get<std::vector<double>>();
REQUIRE(cv_results.size() == 6); // 2 kernels × 3 C values
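// Assuming best_params echoes one combination from the supplied grid, the
// winning kernel and C should both come from the values listed above.
auto best_params = results["best_params"];
bool kernel_in_grid = best_params["kernel"] == "linear" || best_params["kernel"] == "rbf";
double best_C = best_params["C"].get<double>();
bool c_in_grid = best_C == 0.1 || best_C == 1.0 || best_C == 10.0;
REQUIRE(kernel_in_grid);
REQUIRE(c_in_grid);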
}
SECTION("RBF-specific grid search")
{
json param_grid = {
{"kernel", {"rbf"}},
{"C", {1.0, 10.0}},
{"gamma", {0.01, 0.1}}
};
auto results = svm.grid_search(X, y, param_grid, 3);
auto best_params = results["best_params"];
REQUIRE(best_params["kernel"] == "rbf");
REQUIRE(best_params.contains("C"));
REQUIRE(best_params.contains("gamma"));
}
}
TEST_CASE("SVMClassifier Error Handling", "[integration][svm_classifier]")
{
SVMClassifier svm;
SECTION("Prediction before training")
{
auto X = torch::randn({ 5, 3 });
REQUIRE_THROWS_AS(svm.predict(X), std::runtime_error);
REQUIRE_THROWS_AS(svm.predict_proba(X), std::runtime_error);
REQUIRE_THROWS_AS(svm.decision_function(X), std::runtime_error);
}
SECTION("Inconsistent feature dimensions")
{
auto X_train = torch::randn({ 50, 3 });
auto y_train = torch::randint(0, 2, { 50 });
auto X_test = torch::randn({ 10, 5 }); // Different number of features
svm.fit(X_train, y_train);
REQUIRE_THROWS_AS(svm.predict(X_test), std::invalid_argument);
}
SECTION("Invalid training data")
{
auto X_invalid = torch::tensor({ {std::numeric_limits<float>::quiet_NaN(), 1.0} });
auto y_invalid = torch::tensor({ 0 });
REQUIRE_THROWS_AS(svm.fit(X_invalid, y_invalid), std::invalid_argument);
}
SECTION("Empty datasets")
{
auto X_empty = torch::empty({ 0, 3 });
auto y_empty = torch::empty({ 0 });
REQUIRE_THROWS_AS(svm.fit(X_empty, y_empty), std::invalid_argument);
}
}
TEST_CASE("SVMClassifier Move Semantics", "[integration][svm_classifier]")
{
SECTION("Move constructor")
{
SVMClassifier svm1(KernelType::RBF, 2.0);
auto [X, y] = generate_test_data(50, 2, 2);
svm1.fit(X, y);
auto original_classes = svm1.get_classes();
bool was_fitted = svm1.is_fitted();
SVMClassifier svm2 = std::move(svm1);
REQUIRE(svm2.is_fitted() == was_fitted);
REQUIRE(svm2.get_classes() == original_classes);
REQUIRE(svm2.get_kernel_type() == KernelType::RBF);
// Original should be in valid but unspecified state
REQUIRE_FALSE(svm1.is_fitted());
}
SECTION("Move assignment")
{
SVMClassifier svm1(KernelType::POLYNOMIAL);
SVMClassifier svm2(KernelType::LINEAR);
auto [X, y] = generate_test_data(40, 2, 2);
svm1.fit(X, y);
auto original_classes = svm1.get_classes();
svm2 = std::move(svm1);
REQUIRE(svm2.is_fitted());
REQUIRE(svm2.get_classes() == original_classes);
REQUIRE(svm2.get_kernel_type() == KernelType::POLYNOMIAL);
}
}
TEST_CASE("SVMClassifier Reset Functionality", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::RBF);
auto [X, y] = generate_test_data(50, 3, 2);
svm.fit(X, y);
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_n_features() > 0);
REQUIRE(svm.get_n_classes() > 0);
svm.reset();
REQUIRE_FALSE(svm.is_fitted());
REQUIRE(svm.get_n_features() == 0);
REQUIRE(svm.get_n_classes() == 0);
// Should be able to train again after reset
REQUIRE_NOTHROW(svm.fit(X, y));
REQUIRE(svm.is_fitted());
}