/**
 * @file test_svm_classifier.cpp
 * @brief Integration tests for SVMClassifier class
 */

#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>
#include <torch/torch.h>
#include <nlohmann/json.hpp>
#include <svm_classifier/svm_classifier.hpp>  // project header; exact path assumed

#include <algorithm>
#include <limits>
#include <numeric>

using namespace svm_classifier;
using json = nlohmann::json;

/**
 * @brief Generate synthetic classification dataset
 */
std::pair<torch::Tensor, torch::Tensor> generate_test_data(int n_samples = 100,
                                                           int n_features = 4,
                                                           int n_classes = 3,
                                                           int seed = 42) {
    torch::manual_seed(seed);

    auto X = torch::randn({ n_samples, n_features });
    auto y = torch::randint(0, n_classes, { n_samples });

    // Add some structure to make classification meaningful
    for (int i = 0; i < n_samples; ++i) {
        int target_class = y[i].item<int>();
        // Bias features toward the target class
        X[i] += torch::randn({ n_features }) * 0.5 + target_class;
    }

    return { X, y };
}

TEST_CASE("SVMClassifier Construction", "[integration][svm_classifier]") {
    SECTION("Default constructor") {
        SVMClassifier svm;

        REQUIRE(svm.get_kernel_type() == KernelType::LINEAR);
        REQUIRE_FALSE(svm.is_fitted());
        REQUIRE(svm.get_n_classes() == 0);
        REQUIRE(svm.get_n_features() == 0);
        REQUIRE(svm.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_REST);
    }

    SECTION("Constructor with parameters") {
        SVMClassifier svm(KernelType::RBF, 10.0, MulticlassStrategy::ONE_VS_ONE);

        REQUIRE(svm.get_kernel_type() == KernelType::RBF);
        REQUIRE(svm.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_ONE);
        REQUIRE_FALSE(svm.is_fitted());
    }

    SECTION("JSON constructor") {
        json config = {
            {"kernel", "polynomial"},
            {"C", 5.0},
            {"degree", 4},
            {"multiclass_strategy", "ovo"}
        };

        SVMClassifier svm(config);

        REQUIRE(svm.get_kernel_type() == KernelType::POLYNOMIAL);
        REQUIRE(svm.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_ONE);
    }
}

TEST_CASE("SVMClassifier Parameter Management", "[integration][svm_classifier]") {
    SVMClassifier svm;

    SECTION("Set and get parameters") {
        json new_params = {
            {"kernel", "rbf"},
            {"C", 2.0},
            {"gamma", 0.1},
            {"probability", true}
        };

        svm.set_parameters(new_params);
        auto current_params = svm.get_parameters();

        REQUIRE(current_params["kernel"] == "rbf");
        REQUIRE(current_params["C"] == Catch::Approx(2.0));
        REQUIRE(current_params["gamma"] == Catch::Approx(0.1));
        REQUIRE(current_params["probability"] == true);
    }

    SECTION("Invalid parameters") {
        json invalid_params = {
            {"kernel", "invalid_kernel"}
        };

        REQUIRE_THROWS_AS(svm.set_parameters(invalid_params), std::invalid_argument);

        json invalid_C = {
            {"C", -1.0}
        };

        REQUIRE_THROWS_AS(svm.set_parameters(invalid_C), std::invalid_argument);
    }

    SECTION("Parameter changes reset fitted state") {
        auto [X, y] = generate_test_data(50, 3, 2);

        svm.fit(X, y);
        REQUIRE(svm.is_fitted());

        json new_params = { {"kernel", "rbf"} };
        svm.set_parameters(new_params);

        REQUIRE_FALSE(svm.is_fitted());
    }
}

TEST_CASE("SVMClassifier Linear Kernel Training", "[integration][svm_classifier]") {
    SVMClassifier svm(KernelType::LINEAR, 1.0);
    auto [X, y] = generate_test_data(100, 4, 3);

    SECTION("Basic training") {
        auto metrics = svm.fit(X, y);

        REQUIRE(svm.is_fitted());
        REQUIRE(svm.get_n_features() == 4);
        REQUIRE(svm.get_n_classes() == 3);
        REQUIRE(svm.get_svm_library() == SVMLibrary::LIBLINEAR);
        REQUIRE(metrics.status == TrainingStatus::SUCCESS);
        REQUIRE(metrics.training_time >= 0.0);
    }

    SECTION("Training with probability") {
        json config = {
            {"kernel", "linear"},
            {"probability", true}
        };
        svm.set_parameters(config);

        auto metrics = svm.fit(X, y);

        REQUIRE(svm.is_fitted());
        REQUIRE(svm.supports_probability());
    }

    SECTION("Binary classification") {
        auto [X_binary, y_binary] = generate_test_data(50, 3, 2);

        auto metrics = svm.fit(X_binary, y_binary);
        REQUIRE(svm.is_fitted());
        REQUIRE(svm.get_n_classes() == 2);
    }
}

TEST_CASE("SVMClassifier RBF Kernel Training", "[integration][svm_classifier]") {
    SVMClassifier svm(KernelType::RBF, 1.0);
    auto [X, y] = generate_test_data(80, 3, 2);

    SECTION("Basic RBF training") {
        auto metrics = svm.fit(X, y);

        REQUIRE(svm.is_fitted());
        REQUIRE(svm.get_svm_library() == SVMLibrary::LIBSVM);
        REQUIRE(metrics.status == TrainingStatus::SUCCESS);
    }

    SECTION("RBF with custom gamma") {
        json config = {
            {"kernel", "rbf"},
            {"gamma", 0.5}
        };
        svm.set_parameters(config);

        auto metrics = svm.fit(X, y);
        REQUIRE(svm.is_fitted());
    }

    SECTION("RBF with auto gamma") {
        json config = {
            {"kernel", "rbf"},
            {"gamma", "auto"}
        };
        svm.set_parameters(config);

        auto metrics = svm.fit(X, y);
        REQUIRE(svm.is_fitted());
    }
}

TEST_CASE("SVMClassifier Polynomial Kernel Training", "[integration][svm_classifier]") {
    SVMClassifier svm;
    auto [X, y] = generate_test_data(60, 2, 2);

    SECTION("Polynomial kernel") {
        json config = {
            {"kernel", "polynomial"},
            {"degree", 3},
            {"gamma", 0.1},
            {"coef0", 1.0}
        };
        svm.set_parameters(config);

        auto metrics = svm.fit(X, y);

        REQUIRE(svm.is_fitted());
        REQUIRE(svm.get_kernel_type() == KernelType::POLYNOMIAL);
        REQUIRE(svm.get_svm_library() == SVMLibrary::LIBSVM);
    }

    SECTION("Different degrees") {
        for (int degree : {2, 4, 5}) {
            json config = {
                {"kernel", "polynomial"},
                {"degree", degree}
            };

            SVMClassifier poly_svm(config);
            REQUIRE_NOTHROW(poly_svm.fit(X, y));
            REQUIRE(poly_svm.is_fitted());
        }
    }
}

TEST_CASE("SVMClassifier Sigmoid Kernel Training", "[integration][svm_classifier]") {
    SVMClassifier svm;
    auto [X, y] = generate_test_data(50, 2, 2);

    json config = {
        {"kernel", "sigmoid"},
        {"gamma", 0.01},
        {"coef0", 0.5}
    };
    svm.set_parameters(config);

    auto metrics = svm.fit(X, y);

    REQUIRE(svm.is_fitted());
    REQUIRE(svm.get_kernel_type() == KernelType::SIGMOID);
    REQUIRE(svm.get_svm_library() == SVMLibrary::LIBSVM);
}

TEST_CASE("SVMClassifier Prediction", "[integration][svm_classifier]") {
    SVMClassifier svm(KernelType::LINEAR);
    auto [X, y] = generate_test_data(100, 3, 3);

    // Split data
    auto X_train = X.slice(0, 0, 80);
    auto y_train = y.slice(0, 0, 80);
    auto X_test = X.slice(0, 80);
    auto y_test = y.slice(0, 80);

    svm.fit(X_train, y_train);

    SECTION("Basic prediction") {
        auto predictions = svm.predict(X_test);

        REQUIRE(predictions.dtype() == torch::kInt32);
        REQUIRE(predictions.size(0) == X_test.size(0));

        // Check that predictions are valid class labels
        auto unique_preds = std::get<0>(at::_unique(predictions));
        for (int i = 0; i < unique_preds.size(0); ++i) {
            int pred_class = unique_preds[i].item<int>();
            auto classes = svm.get_classes();
            REQUIRE(std::find(classes.begin(), classes.end(), pred_class) != classes.end());
        }
    }

    SECTION("Prediction accuracy") {
        double accuracy = svm.score(X_test, y_test);

        REQUIRE(accuracy >= 0.0);
        REQUIRE(accuracy <= 1.0);
        // For this synthetic dataset, we expect reasonable accuracy
        REQUIRE(accuracy > 0.3);  // Very loose bound
    }

    SECTION("Prediction on training data") {
        auto train_predictions = svm.predict(X_train);
        double train_accuracy = svm.score(X_train, y_train);

        REQUIRE(train_accuracy >= 0.0);
        REQUIRE(train_accuracy <= 1.0);
    }
}

TEST_CASE("SVMClassifier Probability Prediction", "[integration][svm_classifier]") {
    json config = {
        {"kernel", "rbf"},
        {"probability", true}
    };
    SVMClassifier svm(config);
    auto [X, y] = generate_test_data(80, 3, 3);

    svm.fit(X, y);

    SECTION("Probability predictions") {
        REQUIRE(svm.supports_probability());

        auto probabilities = svm.predict_proba(X);

        REQUIRE(probabilities.dtype() == torch::kFloat64);
        REQUIRE(probabilities.size(0) == X.size(0));
        REQUIRE(probabilities.size(1) == 3);  // 3 classes

        // Check that probabilities sum to 1
        auto prob_sums = probabilities.sum(1);
        for (int i = 0; i < prob_sums.size(0); ++i) {
            REQUIRE(prob_sums[i].item<double>() == Catch::Approx(1.0).margin(1e-6));
        }

        // Check that all probabilities are non-negative
        REQUIRE(torch::all(probabilities >= 0.0).item<bool>());
    }

    SECTION("Probability without training") {
        SVMClassifier untrained_svm(config);
        REQUIRE_THROWS_AS(untrained_svm.predict_proba(X), std::runtime_error);
    }

    SECTION("Probability not supported") {
        SVMClassifier no_prob_svm(KernelType::LINEAR);  // No probability
        no_prob_svm.fit(X, y);

        REQUIRE_FALSE(no_prob_svm.supports_probability());
        REQUIRE_THROWS_AS(no_prob_svm.predict_proba(X), std::runtime_error);
    }
}

TEST_CASE("SVMClassifier Decision Function", "[integration][svm_classifier]") {
    SVMClassifier svm(KernelType::RBF);
    auto [X, y] = generate_test_data(60, 2, 3);

    svm.fit(X, y);

    SECTION("Decision function values") {
        auto decision_values = svm.decision_function(X);

        REQUIRE(decision_values.dtype() == torch::kFloat64);
        REQUIRE(decision_values.size(0) == X.size(0));
        // Decision function output depends on multiclass strategy
        REQUIRE(decision_values.size(1) > 0);
    }

    SECTION("Decision function consistency with predictions") {
        auto predictions = svm.predict(X);
        auto decision_values = svm.decision_function(X);

        // For OvR strategy, the predicted class should correspond to max decision value
        if (svm.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_REST) {
            for (int i = 0; i < X.size(0); ++i) {
                auto max_indices = std::get<1>(torch::max(decision_values[i], 0));
                // This is a simplified check - actual implementation might be more complex
            }
        }
    }
}

TEST_CASE("SVMClassifier Multiclass Strategies", "[integration][svm_classifier]") {
    auto [X, y] = generate_test_data(80, 3, 4);  // 4 classes

    SECTION("One-vs-Rest strategy") {
        json config = {
            {"kernel", "linear"},
            {"multiclass_strategy", "ovr"}
        };
        SVMClassifier svm_ovr(config);

        auto metrics = svm_ovr.fit(X, y);

        REQUIRE(svm_ovr.is_fitted());
        REQUIRE(svm_ovr.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_REST);
        REQUIRE(svm_ovr.get_n_classes() == 4);

        auto predictions = svm_ovr.predict(X);
        REQUIRE(predictions.size(0) == X.size(0));
    }

    SECTION("One-vs-One strategy") {
        json config = {
            {"kernel", "rbf"},
            {"multiclass_strategy", "ovo"}
        };
        SVMClassifier svm_ovo(config);

        auto metrics = svm_ovo.fit(X, y);

        REQUIRE(svm_ovo.is_fitted());
        REQUIRE(svm_ovo.get_multiclass_strategy() == MulticlassStrategy::ONE_VS_ONE);
        REQUIRE(svm_ovo.get_n_classes() == 4);

        auto predictions = svm_ovo.predict(X);
        REQUIRE(predictions.size(0) == X.size(0));
    }

    SECTION("Compare strategies") {
        SVMClassifier svm_ovr(KernelType::LINEAR, 1.0, MulticlassStrategy::ONE_VS_REST);
        SVMClassifier svm_ovo(KernelType::LINEAR, 1.0, MulticlassStrategy::ONE_VS_ONE);

        svm_ovr.fit(X, y);
        svm_ovo.fit(X, y);

        auto pred_ovr = svm_ovr.predict(X);
        auto pred_ovo = svm_ovo.predict(X);

        // Both should produce valid predictions
        REQUIRE(pred_ovr.size(0) == X.size(0));
        REQUIRE(pred_ovo.size(0) == X.size(0));
    }
}

TEST_CASE("SVMClassifier Evaluation Metrics", "[integration][svm_classifier]") {
    SVMClassifier svm(KernelType::LINEAR);
    auto [X, y] = generate_test_data(100, 3, 3);

    svm.fit(X, y);

    SECTION("Detailed evaluation") {
        auto metrics = svm.evaluate(X, y);

        REQUIRE(metrics.accuracy >= 0.0);
        REQUIRE(metrics.accuracy <= 1.0);
        REQUIRE(metrics.precision >= 0.0);
        REQUIRE(metrics.precision <= 1.0);
        REQUIRE(metrics.recall >= 0.0);
        REQUIRE(metrics.recall <= 1.0);
        REQUIRE(metrics.f1_score >= 0.0);
        REQUIRE(metrics.f1_score <= 1.0);

        // Check confusion matrix dimensions
        REQUIRE(metrics.confusion_matrix.size() == 3);  // 3 classes
        for (const auto& row : metrics.confusion_matrix) {
            REQUIRE(row.size() == 3);
        }
    }

    SECTION("Perfect predictions metrics") {
        // Create simple dataset where perfect classification is possible
        auto X_simple = torch::tensor({ {0.0, 0.0}, {1.0, 1.0}, {2.0, 2.0} });
        auto y_simple = torch::tensor({ 0, 1, 2 });

        SVMClassifier simple_svm(KernelType::LINEAR);
        simple_svm.fit(X_simple, y_simple);

        auto metrics = simple_svm.evaluate(X_simple, y_simple);

        // Should have perfect or near-perfect accuracy on this simple dataset
        REQUIRE(metrics.accuracy > 0.8);  // Very achievable for this data
    }
}

TEST_CASE("SVMClassifier Cross-Validation", "[integration][svm_classifier]") {
    SVMClassifier svm(KernelType::LINEAR);
    auto [X, y] = generate_test_data(100, 3, 2);

    SECTION("5-fold cross-validation") {
        auto cv_scores = svm.cross_validate(X, y, 5);

        REQUIRE(cv_scores.size() == 5);
        for (double score : cv_scores) {
            REQUIRE(score >= 0.0);
            REQUIRE(score <= 1.0);
        }

        // Calculate mean score
        double mean = std::accumulate(cv_scores.begin(), cv_scores.end(), 0.0) / cv_scores.size();
        REQUIRE(mean >= 0.0);
        REQUIRE(mean <= 1.0);
    }

    SECTION("Invalid CV folds") {
        REQUIRE_THROWS_AS(svm.cross_validate(X, y, 1), std::invalid_argument);
        REQUIRE_THROWS_AS(svm.cross_validate(X, y, 0), std::invalid_argument);
    }

    SECTION("CV preserves original state") {
        // Fit the model first
        svm.fit(X, y);
        auto original_classes = svm.get_classes();

        // Run CV
        auto cv_scores = svm.cross_validate(X, y, 3);

        // Should still be fitted with same classes
        REQUIRE(svm.is_fitted());
        REQUIRE(svm.get_classes() == original_classes);
    }
}

TEST_CASE("SVMClassifier Grid Search", "[integration][svm_classifier]") {
    SVMClassifier svm;
    auto [X, y] = generate_test_data(60, 2, 2);  // Smaller dataset for faster testing

    SECTION("Simple grid search") {
        json param_grid = {
            {"kernel", {"linear", "rbf"}},
            {"C", {0.1, 1.0, 10.0}}
        };

        auto results = svm.grid_search(X, y, param_grid, 3);

        REQUIRE(results.contains("best_params"));
        REQUIRE(results.contains("best_score"));
        REQUIRE(results.contains("cv_results"));

        auto best_score = results["best_score"].get<double>();
        REQUIRE(best_score >= 0.0);
        REQUIRE(best_score <= 1.0);

        // One entry per parameter combination (element type assumed to be json)
        auto cv_results = results["cv_results"].get<std::vector<json>>();
        REQUIRE(cv_results.size() == 6);  // 2 kernels × 3 C values
    }

    SECTION("RBF-specific grid search") {
        json param_grid = {
            {"kernel", {"rbf"}},
            {"C", {1.0, 10.0}},
            {"gamma", {0.01, 0.1}}
        };

        auto results = svm.grid_search(X, y, param_grid, 3);

        auto best_params = results["best_params"];
        REQUIRE(best_params["kernel"] == "rbf");
        REQUIRE(best_params.contains("C"));
        REQUIRE(best_params.contains("gamma"));
    }
}

TEST_CASE("SVMClassifier Error Handling", "[integration][svm_classifier]") {
    SVMClassifier svm;

    SECTION("Prediction before training") {
        auto X = torch::randn({ 5, 3 });

        REQUIRE_THROWS_AS(svm.predict(X), std::runtime_error);
        REQUIRE_THROWS_AS(svm.predict_proba(X), std::runtime_error);
        REQUIRE_THROWS_AS(svm.decision_function(X), std::runtime_error);
    }

    SECTION("Inconsistent feature dimensions") {
        auto X_train = torch::randn({ 50, 3 });
        auto y_train = torch::randint(0, 2, { 50 });
        auto X_test = torch::randn({ 10, 5 });  // Different number of features

        svm.fit(X_train, y_train);
        REQUIRE_THROWS_AS(svm.predict(X_test), std::invalid_argument);
    }

    SECTION("Invalid training data") {
        auto X_invalid = torch::tensor({ {std::numeric_limits<double>::quiet_NaN(), 1.0} });
        auto y_invalid = torch::tensor({ 0 });
        REQUIRE_THROWS_AS(svm.fit(X_invalid, y_invalid), std::invalid_argument);
    }

    SECTION("Empty datasets") {
        auto X_empty = torch::empty({ 0, 3 });
        auto y_empty = torch::empty({ 0 });

        REQUIRE_THROWS_AS(svm.fit(X_empty, y_empty), std::invalid_argument);
    }
}

TEST_CASE("SVMClassifier Move Semantics", "[integration][svm_classifier]") {
    SECTION("Move constructor") {
        SVMClassifier svm1(KernelType::RBF, 2.0);
        auto [X, y] = generate_test_data(50, 2, 2);

        svm1.fit(X, y);
        auto original_classes = svm1.get_classes();
        bool was_fitted = svm1.is_fitted();

        SVMClassifier svm2 = std::move(svm1);

        REQUIRE(svm2.is_fitted() == was_fitted);
        REQUIRE(svm2.get_classes() == original_classes);
        REQUIRE(svm2.get_kernel_type() == KernelType::RBF);

        // Original should be in valid but unspecified state
        REQUIRE_FALSE(svm1.is_fitted());
    }

    SECTION("Move assignment") {
        SVMClassifier svm1(KernelType::POLYNOMIAL);
        SVMClassifier svm2(KernelType::LINEAR);
        auto [X, y] = generate_test_data(40, 2, 2);

        svm1.fit(X, y);
        auto original_classes = svm1.get_classes();

        svm2 = std::move(svm1);

        REQUIRE(svm2.is_fitted());
        REQUIRE(svm2.get_classes() == original_classes);
        REQUIRE(svm2.get_kernel_type() == KernelType::POLYNOMIAL);
    }
}

TEST_CASE("SVMClassifier Reset Functionality", "[integration][svm_classifier]") {
    SVMClassifier svm(KernelType::RBF);
    auto [X, y] = generate_test_data(50, 3, 2);

    svm.fit(X, y);
    REQUIRE(svm.is_fitted());
    REQUIRE(svm.get_n_features() > 0);
    REQUIRE(svm.get_n_classes() > 0);

    svm.reset();

    REQUIRE_FALSE(svm.is_fitted());
    REQUIRE(svm.get_n_features() == 0);
    REQUIRE(svm.get_n_classes() == 0);

    // Should be able to train again after reset
    REQUIRE_NOTHROW(svm.fit(X, y));
    REQUIRE(svm.is_fitted());
}
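
// The test cases above exercise each part of the public API in isolation. As a
// compact reference, the sketch below strings the same calls together into one
// end-to-end smoke test. It is a minimal illustration built only from methods
// already used in this file (JSON construction, fit, predict, predict_proba,
// score); the particular parameter values and dataset sizes are arbitrary.
TEST_CASE("SVMClassifier End-to-End Smoke Test", "[integration][svm_classifier]") {
    // Probabilistic RBF model configured through JSON, mirroring the
    // "Probability Prediction" test case above
    json config = {
        {"kernel", "rbf"},
        {"C", 1.0},
        {"gamma", 0.1},
        {"probability", true}
    };
    SVMClassifier svm(config);

    auto [X, y] = generate_test_data(60, 3, 2);
    auto metrics = svm.fit(X, y);
    REQUIRE(metrics.status == TrainingStatus::SUCCESS);
    REQUIRE(svm.is_fitted());

    // Hard predictions, class probabilities, and accuracy on the training data
    auto predictions = svm.predict(X);
    REQUIRE(predictions.size(0) == X.size(0));

    REQUIRE(svm.supports_probability());
    auto probabilities = svm.predict_proba(X);
    REQUIRE(probabilities.size(0) == X.size(0));

    double accuracy = svm.score(X, y);
    REQUIRE(accuracy >= 0.0);
    REQUIRE(accuracy <= 1.0);
}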