Initial commit as Claude developed it

2025-06-22 12:50:10 +02:00
parent 270c540556
commit d6dc083a5a
38 changed files with 10197 additions and 6 deletions

test_svm_classifier.cpp

@@ -0,0 +1,679 @@
/**
* @file test_svm_classifier.cpp
* @brief Integration tests for SVMClassifier class
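*
* Covers construction, parameter handling, training with linear, RBF,
* polynomial and sigmoid kernels, prediction, probability estimates,
* decision values, multiclass strategies, evaluation metrics,
* cross-validation, grid search, error handling, move semantics and reset.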
*/
#include <catch2/catch_all.hpp>
#include <svm_classifier/svm_classifier.hpp>
#include <torch/torch.h>
#include <nlohmann/json.hpp>
#include <algorithm> // std::find
#include <limits>    // std::numeric_limits
#include <numeric>   // std::accumulate
using namespace svm_classifier;
using json = nlohmann::json;
/**
* @brief Generate synthetic classification dataset
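* @param n_samples  Number of samples to generate
* @param n_features Number of features per sample
* @param n_classes  Number of distinct class labels
* @param seed       Random seed for reproducibility
* @return Pair of feature tensor X and label tensor y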
*/
std::pair<torch::Tensor, torch::Tensor> generate_test_data(int n_samples = 100,
int n_features = 4,
int n_classes = 3,
int seed = 42)
{
torch::manual_seed(seed);
auto X = torch::randn({ n_samples, n_features });
auto y = torch::randint(0, n_classes, { n_samples });
// Add some structure to make classification meaningful
for (int i = 0; i < n_samples; ++i) {
int target_class = y[i].item<int>();
// Bias features toward the target class
X[i] += torch::randn({ n_features }) * 0.5 + target_class;
}
return { X, y };
}
TEST_CASE("SVMClassifier Construction", "[integration][svm_classifier]")
{
SECTION("Default constructor")
{
SVMClassifier svm;
REQUIRE(svm.get_kernel_type() == KernelType::LINEAR);
REQUIRE_FALSE(svm.is_fitted());
REQUIRE(svm.get_n_classes() == 0);
REQUIRE(svm.get_n_features() == 0);
REQUIRE(svm.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_REST);
}
SECTION("Constructor with parameters")
{
SVMClassifier svm(KernelType::RBF, 10.0, MulticlassStrategy::ONE_VS_ONE);
REQUIRE(svm.get_kernel_type() == KernelType::RBF);
REQUIRE(svm.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_ONE);
REQUIRE_FALSE(svm.is_fitted());
}
SECTION("JSON constructor")
{
json config = {
{"kernel", "polynomial"},
{"C", 5.0},
{"degree", 4},
{"multiclass_strategy", "ovo"}
};
SVMClassifier svm(config);
REQUIRE(svm.get_kernel_type() == KernelType::POLYNOMIAL);
REQUIRE(svm.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_ONE);
}
}
TEST_CASE("SVMClassifier Parameter Management", "[integration][svm_classifier]")
{
SVMClassifier svm;
SECTION("Set and get parameters")
{
json new_params = {
{"kernel", "rbf"},
{"C", 2.0},
{"gamma", 0.1},
{"probability", true}
};
svm.set_parameters(new_params);
auto current_params = svm.get_parameters();
REQUIRE(current_params["kernel"] == "rbf");
REQUIRE(current_params["C"] == Catch::Approx(2.0));
REQUIRE(current_params["gamma"] == Catch::Approx(0.1));
REQUIRE(current_params["probability"] == true);
}
SECTION("Invalid parameters")
{
json invalid_params = {
{"kernel", "invalid_kernel"}
};
REQUIRE_THROWS_AS(svm.set_parameters(invalid_params), std::invalid_argument);
json invalid_C = {
{"C", -1.0}
};
REQUIRE_THROWS_AS(svm.set_parameters(invalid_C), std::invalid_argument);
}
SECTION("Parameter changes reset fitted state")
{
auto [X, y] = generate_test_data(50, 3, 2);
svm.fit(X, y);
REQUIRE(svm.is_fitted());
json new_params = { {"kernel", "rbf"} };
svm.set_parameters(new_params);
REQUIRE_FALSE(svm.is_fitted());
}
}
TEST_CASE("SVMClassifier Linear Kernel Training", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::LINEAR, 1.0);
auto [X, y] = generate_test_data(100, 4, 3);
SECTION("Basic training")
{
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_n_features() == 4);
REQUIRE(svm.get_n_classes() == 3);
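// Linear kernels are expected to use liblinear; the other kernels (RBF, polynomial, sigmoid) use libsvm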
REQUIRE(svm.get_svm_library() == SVMLibrary::LIBLINEAR);
REQUIRE(metrics.status == TrainingStatus::SUCCESS);
REQUIRE(metrics.training_time >= 0.0);
}
SECTION("Training with probability")
{
json config = {
{"kernel", "linear"},
{"probability", true}
};
svm.set_parameters(config);
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
REQUIRE(svm.supports_probability());
}
SECTION("Binary classification")
{
auto [X_binary, y_binary] = generate_test_data(50, 3, 2);
auto metrics = svm.fit(X_binary, y_binary);
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_n_classes() == 2);
}
}
TEST_CASE("SVMClassifier RBF Kernel Training", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::RBF, 1.0);
auto [X, y] = generate_test_data(80, 3, 2);
SECTION("Basic RBF training")
{
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_svm_library() == SVMLibrary::LIBSVM);
REQUIRE(metrics.status == TrainingStatus::SUCCESS);
}
SECTION("RBF with custom gamma")
{
json config = {
{"kernel", "rbf"},
{"gamma", 0.5}
};
svm.set_parameters(config);
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
}
SECTION("RBF with auto gamma")
{
json config = {
{"kernel", "rbf"},
{"gamma", "auto"}
};
svm.set_parameters(config);
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
}
}
TEST_CASE("SVMClassifier Polynomial Kernel Training", "[integration][svm_classifier]")
{
SVMClassifier svm;
auto [X, y] = generate_test_data(60, 2, 2);
SECTION("Polynomial kernel")
{
json config = {
{"kernel", "polynomial"},
{"degree", 3},
{"gamma", 0.1},
{"coef0", 1.0}
};
svm.set_parameters(config);
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_kernel_type() == KernelType::POLYNOMIAL);
REQUIRE(svm.get_svm_library() == SVMLibrary::LIBSVM);
}
SECTION("Different degrees")
{
for (int degree : {2, 4, 5}) {
json config = {
{"kernel", "polynomial"},
{"degree", degree}
};
SVMClassifier poly_svm(config);
REQUIRE_NOTHROW(poly_svm.fit(X, y));
REQUIRE(poly_svm.is_fitted());
}
}
}
TEST_CASE("SVMClassifier Sigmoid Kernel Training", "[integration][svm_classifier]")
{
SVMClassifier svm;
auto [X, y] = generate_test_data(50, 2, 2);
json config = {
{"kernel", "sigmoid"},
{"gamma", 0.01},
{"coef0", 0.5}
};
svm.set_parameters(config);
auto metrics = svm.fit(X, y);
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_kernel_type() == KernelType::SIGMOID);
REQUIRE(svm.get_svm_library() == SVMLibrary::LIBSVM);
}
TEST_CASE("SVMClassifier Prediction", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::LINEAR);
auto [X, y] = generate_test_data(100, 3, 3);
// Split data: first 80 samples for training, remaining 20 for testing
auto X_train = X.slice(0, 0, 80);
auto y_train = y.slice(0, 0, 80);
auto X_test = X.slice(0, 80);
auto y_test = y.slice(0, 80);
svm.fit(X_train, y_train);
SECTION("Basic prediction")
{
auto predictions = svm.predict(X_test);
REQUIRE(predictions.dtype() == torch::kInt32);
REQUIRE(predictions.size(0) == X_test.size(0));
// Check that predictions are valid class labels
// (torch::unique is not exposed in the C++ frontend, so use torch::_unique)
auto unique_preds = std::get<0>(torch::_unique(predictions));
for (int i = 0; i < unique_preds.size(0); ++i) {
int pred_class = unique_preds[i].item<int>();
auto classes = svm.get_classes();
REQUIRE(std::find(classes.begin(), classes.end(), pred_class) != classes.end());
}
}
SECTION("Prediction accuracy")
{
double accuracy = svm.score(X_test, y_test);
REQUIRE(accuracy >= 0.0);
REQUIRE(accuracy <= 1.0);
// For this synthetic dataset, we expect reasonable accuracy
REQUIRE(accuracy > 0.3); // Very loose bound
}
SECTION("Prediction on training data")
{
auto train_predictions = svm.predict(X_train);
double train_accuracy = svm.score(X_train, y_train);
REQUIRE(train_accuracy >= 0.0);
REQUIRE(train_accuracy <= 1.0);
}
}
TEST_CASE("SVMClassifier Probability Prediction", "[integration][svm_classifier]")
{
json config = {
{"kernel", "rbf"},
{"probability", true}
};
SVMClassifier svm(config);
auto [X, y] = generate_test_data(80, 3, 3);
svm.fit(X, y);
SECTION("Probability predictions")
{
REQUIRE(svm.supports_probability());
auto probabilities = svm.predict_proba(X);
REQUIRE(probabilities.dtype() == torch::kFloat64);
REQUIRE(probabilities.size(0) == X.size(0));
REQUIRE(probabilities.size(1) == 3); // 3 classes
// Check that probabilities sum to 1
auto prob_sums = probabilities.sum(1);
for (int i = 0; i < prob_sums.size(0); ++i) {
REQUIRE(prob_sums[i].item<double>() == Catch::Approx(1.0).margin(1e-6));
}
// Check that all probabilities are non-negative
REQUIRE(torch::all(probabilities >= 0.0).item<bool>());
}
SECTION("Probability without training")
{
SVMClassifier untrained_svm(config);
REQUIRE_THROWS_AS(untrained_svm.predict_proba(X), std::runtime_error);
}
SECTION("Probability not supported")
{
SVMClassifier no_prob_svm(KernelType::LINEAR); // No probability
no_prob_svm.fit(X, y);
REQUIRE_FALSE(no_prob_svm.supports_probability());
REQUIRE_THROWS_AS(no_prob_svm.predict_proba(X), std::runtime_error);
}
}
TEST_CASE("SVMClassifier Decision Function", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::RBF);
auto [X, y] = generate_test_data(60, 2, 3);
svm.fit(X, y);
SECTION("Decision function values")
{
auto decision_values = svm.decision_function(X);
REQUIRE(decision_values.dtype() == torch::kFloat64);
REQUIRE(decision_values.size(0) == X.size(0));
// Decision function output depends on multiclass strategy
REQUIRE(decision_values.size(1) > 0);
}
SECTION("Decision function consistency with predictions")
{
auto predictions = svm.predict(X);
auto decision_values = svm.decision_function(X);
// For OvR strategy, the predicted class should correspond to the max decision value
if (svm.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_REST) {
auto classes = svm.get_classes();
for (int i = 0; i < X.size(0); ++i) {
// Assumes decision_function() columns follow the same order as get_classes()
auto max_index = std::get<1>(torch::max(decision_values[i], 0)).item<int64_t>();
REQUIRE(predictions[i].item<int>() == classes[max_index]);
}
}
}
}
TEST_CASE("SVMClassifier Multiclass Strategies", "[integration][svm_classifier]")
{
auto [X, y] = generate_test_data(80, 3, 4); // 4 classes
SECTION("One-vs-Rest strategy")
{
json config = {
{"kernel", "linear"},
{"multiclass_strategy", "ovr"}
};
SVMClassifier svm_ovr(config);
auto metrics = svm_ovr.fit(X, y);
REQUIRE(svm_ovr.is_fitted());
REQUIRE(svm_ovr.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_REST);
REQUIRE(svm_ovr.get_n_classes() == 4);
auto predictions = svm_ovr.predict(X);
REQUIRE(predictions.size(0) == X.size(0));
}
SECTION("One-vs-One strategy")
{
json config = {
{"kernel", "rbf"},
{"multiclass_strategy", "ovo"}
};
SVMClassifier svm_ovo(config);
auto metrics = svm_ovo.fit(X, y);
REQUIRE(svm_ovo.is_fitted());
REQUIRE(svm_ovo.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_ONE);
REQUIRE(svm_ovo.get_n_classes() == 4);
auto predictions = svm_ovo.predict(X);
REQUIRE(predictions.size(0) == X.size(0));
}
SECTION("Compare strategies")
{
SVMClassifier svm_ovr(KernelType::LINEAR, 1.0, MulticlassStrategy::ONE_VS_REST);
SVMClassifier svm_ovo(KernelType::LINEAR, 1.0, MulticlassStrategy::ONE_VS_ONE);
svm_ovr.fit(X, y);
svm_ovo.fit(X, y);
auto pred_ovr = svm_ovr.predict(X);
auto pred_ovo = svm_ovo.predict(X);
// Both should produce valid predictions
REQUIRE(pred_ovr.size(0) == X.size(0));
REQUIRE(pred_ovo.size(0) == X.size(0));
}
}
TEST_CASE("SVMClassifier Evaluation Metrics", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::LINEAR);
auto [X, y] = generate_test_data(100, 3, 3);
svm.fit(X, y);
SECTION("Detailed evaluation")
{
auto metrics = svm.evaluate(X, y);
REQUIRE(metrics.accuracy >= 0.0);
REQUIRE(metrics.accuracy <= 1.0);
REQUIRE(metrics.precision >= 0.0);
REQUIRE(metrics.precision <= 1.0);
REQUIRE(metrics.recall >= 0.0);
REQUIRE(metrics.recall <= 1.0);
REQUIRE(metrics.f1_score >= 0.0);
REQUIRE(metrics.f1_score <= 1.0);
// Check confusion matrix dimensions
REQUIRE(metrics.confusion_matrix.size() == 3); // 3 classes
for (const auto& row : metrics.confusion_matrix) {
REQUIRE(row.size() == 3);
}
}
SECTION("Perfect predictions metrics")
{
// Create a simple, well-separated dataset where perfect classification is possible
auto X_simple = torch::tensor({ {0.0, 0.0}, {5.0, 0.0}, {0.0, 5.0} });
auto y_simple = torch::tensor({ 0, 1, 2 });
SVMClassifier simple_svm(KernelType::LINEAR);
simple_svm.fit(X_simple, y_simple);
auto metrics = simple_svm.evaluate(X_simple, y_simple);
// Should have perfect or near-perfect accuracy on this simple dataset
REQUIRE(metrics.accuracy > 0.8); // Very achievable for this data
}
}
TEST_CASE("SVMClassifier Cross-Validation", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::LINEAR);
auto [X, y] = generate_test_data(100, 3, 2);
SECTION("5-fold cross-validation")
{
auto cv_scores = svm.cross_validate(X, y, 5);
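// Expect one accuracy score per fold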
REQUIRE(cv_scores.size() == 5);
for (double score : cv_scores) {
REQUIRE(score >= 0.0);
REQUIRE(score <= 1.0);
}
// Calculate the mean score across folds
double mean = std::accumulate(cv_scores.begin(), cv_scores.end(), 0.0) / cv_scores.size();
REQUIRE(mean >= 0.0);
REQUIRE(mean <= 1.0);
}
SECTION("Invalid CV folds")
{
REQUIRE_THROWS_AS(svm.cross_validate(X, y, 1), std::invalid_argument);
REQUIRE_THROWS_AS(svm.cross_validate(X, y, 0), std::invalid_argument);
}
SECTION("CV preserves original state")
{
// Fit the model first
svm.fit(X, y);
auto original_classes = svm.get_classes();
// Run CV
auto cv_scores = svm.cross_validate(X, y, 3);
// Should still be fitted with same classes
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_classes() == original_classes);
}
}
TEST_CASE("SVMClassifier Grid Search", "[integration][svm_classifier]")
{
SVMClassifier svm;
auto [X, y] = generate_test_data(60, 2, 2); // Smaller dataset for faster testing
SECTION("Simple grid search")
{
json param_grid = {
{"kernel", {"linear", "rbf"}},
{"C", {0.1, 1.0, 10.0}}
};
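// 2 kernels × 3 C values = 6 parameter combinations, each scored with 3-fold CV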
auto results = svm.grid_search(X, y, param_grid, 3);
REQUIRE(results.contains("best_params"));
REQUIRE(results.contains("best_score"));
REQUIRE(results.contains("cv_results"));
auto best_score = results["best_score"].get<double>();
REQUIRE(best_score >= 0.0);
REQUIRE(best_score <= 1.0);
auto cv_results = results["cv_results"].get<std::vector<double>>();
REQUIRE(cv_results.size() == 6); // 2 kernels × 3 C values
}
SECTION("RBF-specific grid search")
{
json param_grid = {
{"kernel", {"rbf"}},
{"C", {1.0, 10.0}},
{"gamma", {0.01, 0.1}}
};
auto results = svm.grid_search(X, y, param_grid, 3);
auto best_params = results["best_params"];
REQUIRE(best_params["kernel"] == "rbf");
REQUIRE(best_params.contains("C"));
REQUIRE(best_params.contains("gamma"));
}
}
TEST_CASE("SVMClassifier Error Handling", "[integration][svm_classifier]")
{
SVMClassifier svm;
SECTION("Prediction before training")
{
auto X = torch::randn({ 5, 3 });
REQUIRE_THROWS_AS(svm.predict(X), std::runtime_error);
REQUIRE_THROWS_AS(svm.predict_proba(X), std::runtime_error);
REQUIRE_THROWS_AS(svm.decision_function(X), std::runtime_error);
}
SECTION("Inconsistent feature dimensions")
{
auto X_train = torch::randn({ 50, 3 });
auto y_train = torch::randint(0, 2, { 50 });
auto X_test = torch::randn({ 10, 5 }); // Different number of features
svm.fit(X_train, y_train);
REQUIRE_THROWS_AS(svm.predict(X_test), std::invalid_argument);
}
SECTION("Invalid training data")
{
auto X_invalid = torch::tensor({ {std::numeric_limits<double>::quiet_NaN(), 1.0} });
auto y_invalid = torch::tensor({ 0 });
REQUIRE_THROWS_AS(svm.fit(X_invalid, y_invalid), std::invalid_argument);
}
SECTION("Empty datasets")
{
auto X_empty = torch::empty({ 0, 3 });
auto y_empty = torch::empty({ 0 });
REQUIRE_THROWS_AS(svm.fit(X_empty, y_empty), std::invalid_argument);
}
}
TEST_CASE("SVMClassifier Move Semantics", "[integration][svm_classifier]")
{
SECTION("Move constructor")
{
SVMClassifier svm1(KernelType::RBF, 2.0);
auto [X, y] = generate_test_data(50, 2, 2);
svm1.fit(X, y);
auto original_classes = svm1.get_classes();
bool was_fitted = svm1.is_fitted();
SVMClassifier svm2 = std::move(svm1);
REQUIRE(svm2.is_fitted() == was_fitted);
REQUIRE(svm2.get_classes() == original_classes);
REQUIRE(svm2.get_kernel_type() == KernelType::RBF);
// Original should be in valid but unspecified state
REQUIRE_FALSE(svm1.is_fitted());
}
SECTION("Move assignment")
{
SVMClassifier svm1(KernelType::POLYNOMIAL);
SVMClassifier svm2(KernelType::LINEAR);
auto [X, y] = generate_test_data(40, 2, 2);
svm1.fit(X, y);
auto original_classes = svm1.get_classes();
svm2 = std::move(svm1);
REQUIRE(svm2.is_fitted());
REQUIRE(svm2.get_classes() == original_classes);
REQUIRE(svm2.get_kernel_type() == KernelType::POLYNOMIAL);
}
}
TEST_CASE("SVMClassifier Reset Functionality", "[integration][svm_classifier]")
{
SVMClassifier svm(KernelType::RBF);
auto [X, y] = generate_test_data(50, 3, 2);
svm.fit(X, y);
REQUIRE(svm.is_fitted());
REQUIRE(svm.get_n_features() > 0);
REQUIRE(svm.get_n_classes() > 0);
svm.reset();
REQUIRE_FALSE(svm.is_fitted());
REQUIRE(svm.get_n_features() == 0);
REQUIRE(svm.get_n_classes() == 0);
// Should be able to train again after reset
REQUIRE_NOTHROW(svm.fit(X, y));
REQUIRE(svm.is_fitted());
}