/** * @file test_multiclass_strategy.cpp * @brief Unit tests for multiclass strategy classes */ #include #include #include #include #include using namespace svm_classifier; /** * @brief Generate simple test data for multiclass testing */ std::pair generate_multiclass_data(int n_samples = 60, int n_features = 2, int n_classes = 3, int seed = 42) { torch::manual_seed(seed); auto X = torch::randn({ n_samples, n_features }); auto y = torch::randint(0, n_classes, { n_samples }); // Create some structure in the data for (int i = 0; i < n_samples; ++i) { int class_label = y[i].item(); // Add class-specific bias to make classification easier X[i] += class_label * 0.5; } return { X, y }; } TEST_CASE("MulticlassStrategy Factory Function", "[unit][multiclass_strategy]") { SECTION("Create One-vs-Rest strategy") { auto strategy = create_multiclass_strategy(MulticlassStrategy::ONE_VS_REST); REQUIRE(strategy != nullptr); REQUIRE(strategy->get_strategy_type() == MulticlassStrategy::ONE_VS_REST); REQUIRE_FALSE(strategy->get_classes().empty() == false); // Not trained yet REQUIRE(strategy->get_n_classes() == 0); } SECTION("Create One-vs-One strategy") { auto strategy = create_multiclass_strategy(MulticlassStrategy::ONE_VS_ONE); REQUIRE(strategy != nullptr); REQUIRE(strategy->get_strategy_type() == MulticlassStrategy::ONE_VS_ONE); REQUIRE(strategy->get_n_classes() == 0); } } TEST_CASE("OneVsRestStrategy Basic Functionality", "[unit][multiclass_strategy]") { OneVsRestStrategy strategy; DataConverter converter; KernelParameters params; SECTION("Initial state") { REQUIRE(strategy.get_strategy_type() == MulticlassStrategy::ONE_VS_REST); REQUIRE(strategy.get_n_classes() == 0); REQUIRE(strategy.get_classes().empty()); REQUIRE_FALSE(strategy.supports_probability()); } SECTION("Training with linear kernel") { auto [X, y] = generate_multiclass_data(60, 3, 3); params.set_kernel_type(KernelType::LINEAR); params.set_C(1.0); auto metrics = strategy.fit(X, y, params, converter); REQUIRE(metrics.status == TrainingStatus::SUCCESS); REQUIRE(metrics.training_time >= 0.0); REQUIRE(strategy.get_n_classes() == 3); auto classes = strategy.get_classes(); REQUIRE(classes.size() == 3); REQUIRE(std::is_sorted(classes.begin(), classes.end())); } SECTION("Training with RBF kernel") { auto [X, y] = generate_multiclass_data(50, 2, 2); params.set_kernel_type(KernelType::RBF); params.set_C(1.0); params.set_gamma(0.1); auto metrics = strategy.fit(X, y, params, converter); REQUIRE(metrics.status == TrainingStatus::SUCCESS); REQUIRE(strategy.get_n_classes() == 2); } } TEST_CASE("OneVsRestStrategy Prediction", "[unit][multiclass_strategy]") { OneVsRestStrategy strategy; DataConverter converter; KernelParameters params; auto [X, y] = generate_multiclass_data(80, 3, 3); // Split data auto X_train = X.slice(0, 0, 60); auto y_train = y.slice(0, 0, 60); auto X_test = X.slice(0, 60); auto y_test = y.slice(0, 60); params.set_kernel_type(KernelType::LINEAR); strategy.fit(X_train, y_train, params, converter); SECTION("Basic prediction") { auto predictions = strategy.predict(X_test, converter); REQUIRE(static_cast(predictions.size()) == X_test.size(0)); // Check that all predictions are valid class labels auto classes = strategy.get_classes(); for (int pred : predictions) { REQUIRE(std::find(classes.begin(), classes.end(), pred) != classes.end()); } } SECTION("Decision function") { auto decision_values = strategy.decision_function(X_test, converter); REQUIRE(static_cast(decision_values.size()) == X_test.size(0)); REQUIRE(static_cast(decision_values[0].size()) == strategy.get_n_classes()); // Decision values should be real numbers for (const auto& sample_decisions : decision_values) { for (double value : sample_decisions) { REQUIRE(std::isfinite(value)); } } } SECTION("Prediction without training") { OneVsRestStrategy untrained_strategy; REQUIRE_THROWS_AS(untrained_strategy.predict(X_test, converter), std::runtime_error); } } TEST_CASE("OneVsRestStrategy Probability Prediction", "[unit][multiclass_strategy]") { OneVsRestStrategy strategy; DataConverter converter; KernelParameters params; auto [X, y] = generate_multiclass_data(60, 2, 3); SECTION("With probability enabled") { params.set_kernel_type(KernelType::RBF); params.set_probability(true); strategy.fit(X, y, params, converter); if (strategy.supports_probability()) { auto probabilities = strategy.predict_proba(X, converter); REQUIRE(static_cast(probabilities.size()) == X.size(0)); REQUIRE(probabilities[0].size() == 3); // 3 classes // Check probability constraints for (const auto& sample_probs : probabilities) { double sum = 0.0; for (double prob : sample_probs) { REQUIRE(prob >= 0.0); REQUIRE(prob <= 1.0); sum += prob; } REQUIRE(sum == Catch::Approx(1.0).margin(1e-6)); } } } SECTION("Without probability enabled") { params.set_kernel_type(KernelType::LINEAR); params.set_probability(false); strategy.fit(X, y, params, converter); // May or may not support probability depending on implementation // If not supported, should throw if (!strategy.supports_probability()) { REQUIRE_THROWS_AS(strategy.predict_proba(X, converter), std::runtime_error); } } } TEST_CASE("OneVsOneStrategy Basic Functionality", "[unit][multiclass_strategy]") { OneVsOneStrategy strategy; DataConverter converter; KernelParameters params; SECTION("Initial state") { REQUIRE(strategy.get_strategy_type() == MulticlassStrategy::ONE_VS_ONE); REQUIRE(strategy.get_n_classes() == 0); REQUIRE(strategy.get_classes().empty()); } SECTION("Training with multiple classes") { auto [X, y] = generate_multiclass_data(80, 3, 4); // 4 classes for OvO params.set_kernel_type(KernelType::LINEAR); params.set_C(1.0); auto metrics = strategy.fit(X, y, params, converter); REQUIRE(metrics.status == TrainingStatus::SUCCESS); REQUIRE(strategy.get_n_classes() == 4); auto classes = strategy.get_classes(); REQUIRE(classes.size() == 4); // For 4 classes, OvO should train C(4,2) = 6 binary classifiers // This is implementation detail but good to verify the concept } SECTION("Binary classification") { auto [X, y] = generate_multiclass_data(50, 2, 2); params.set_kernel_type(KernelType::RBF); params.set_gamma(0.1); auto metrics = strategy.fit(X, y, params, converter); REQUIRE(metrics.status == TrainingStatus::SUCCESS); REQUIRE(strategy.get_n_classes() == 2); } } TEST_CASE("OneVsOneStrategy Prediction", "[unit][multiclass_strategy]") { OneVsOneStrategy strategy; DataConverter converter; KernelParameters params; auto [X, y] = generate_multiclass_data(90, 2, 3); auto X_train = X.slice(0, 0, 70); auto y_train = y.slice(0, 0, 70); auto X_test = X.slice(0, 70); params.set_kernel_type(KernelType::LINEAR); strategy.fit(X_train, y_train, params, converter); SECTION("Basic prediction") { auto predictions = strategy.predict(X_test, converter); REQUIRE(static_cast(predictions.size()) == X_test.size(0)); auto classes = strategy.get_classes(); for (int pred : predictions) { REQUIRE(std::find(classes.begin(), classes.end(), pred) != classes.end()); } } SECTION("Decision function") { auto decision_values = strategy.decision_function(X_test, converter); REQUIRE(static_cast(decision_values.size()) == X_test.size(0)); // For 3 classes, OvO should have C(3,2) = 3 pairwise comparisons REQUIRE(decision_values[0].size() == 3); for (const auto& sample_decisions : decision_values) { for (double value : sample_decisions) { REQUIRE(std::isfinite(value)); } } } SECTION("Probability prediction") { // OvO probability estimation is more complex auto probabilities = strategy.predict_proba(X_test, converter); REQUIRE(static_cast(probabilities.size()) == X_test.size(0)); REQUIRE(probabilities[0].size() == 3); // 3 classes // Check basic probability constraints for (const auto& sample_probs : probabilities) { double sum = 0.0; for (double prob : sample_probs) { REQUIRE(prob >= 0.0); REQUIRE(prob <= 1.0); sum += prob; } // OvO probability might not sum exactly to 1 due to voting mechanism REQUIRE(sum == Catch::Approx(1.0).margin(0.1)); } } } TEST_CASE("MulticlassStrategy Comparison", "[integration][multiclass_strategy]") { auto [X, y] = generate_multiclass_data(100, 3, 3); auto X_train = X.slice(0, 0, 80); auto y_train = y.slice(0, 0, 80); auto X_test = X.slice(0, 80); auto y_test = y.slice(0, 80); DataConverter converter1, converter2; KernelParameters params; params.set_kernel_type(KernelType::LINEAR); params.set_C(1.0); SECTION("Compare OvR vs OvO predictions") { OneVsRestStrategy ovr_strategy; OneVsOneStrategy ovo_strategy; ovr_strategy.fit(X_train, y_train, params, converter1); ovo_strategy.fit(X_train, y_train, params, converter2); auto ovr_predictions = ovr_strategy.predict(X_test, converter1); auto ovo_predictions = ovo_strategy.predict(X_test, converter2); REQUIRE(ovr_predictions.size() == ovo_predictions.size()); // Both should predict valid class labels auto ovr_classes = ovr_strategy.get_classes(); auto ovo_classes = ovo_strategy.get_classes(); REQUIRE(ovr_classes == ovo_classes); // Should have same classes for (size_t i = 0; i < ovr_predictions.size(); ++i) { REQUIRE(std::find(ovr_classes.begin(), ovr_classes.end(), ovr_predictions[i]) != ovr_classes.end()); REQUIRE(std::find(ovo_classes.begin(), ovo_classes.end(), ovo_predictions[i]) != ovo_classes.end()); } } SECTION("Compare decision function outputs") { OneVsRestStrategy ovr_strategy; OneVsOneStrategy ovo_strategy; ovr_strategy.fit(X_train, y_train, params, converter1); ovo_strategy.fit(X_train, y_train, params, converter2); auto ovr_decisions = ovr_strategy.decision_function(X_test, converter1); auto ovo_decisions = ovo_strategy.decision_function(X_test, converter2); REQUIRE(ovr_decisions.size() == ovo_decisions.size()); // OvR should have one decision value per class REQUIRE(ovr_decisions[0].size() == 3); // OvO should have one decision value per class pair: C(3,2) = 3 REQUIRE(ovo_decisions[0].size() == 3); } } TEST_CASE("MulticlassStrategy Edge Cases", "[unit][multiclass_strategy]") { DataConverter converter; KernelParameters params; params.set_kernel_type(KernelType::LINEAR); SECTION("Single class dataset") { auto X = torch::randn({ 20, 2 }); auto y = torch::zeros({ 20 }, torch::kInt32); // All same class OneVsRestStrategy strategy; // Should handle single class gracefully auto metrics = strategy.fit(X, y, params, converter); REQUIRE(metrics.status == TrainingStatus::SUCCESS); // Implementation might extend to binary case auto predictions = strategy.predict(X, converter); REQUIRE(static_cast(predictions.size()) == X.size(0)); } SECTION("Very small dataset") { auto X = torch::tensor({ {1.0, 2.0}, {3.0, 4.0} }); auto y = torch::tensor({ 0, 1 }); OneVsOneStrategy strategy; auto metrics = strategy.fit(X, y, params, converter); REQUIRE(metrics.status == TrainingStatus::SUCCESS); auto predictions = strategy.predict(X, converter); REQUIRE(predictions.size() == 2); } SECTION("Imbalanced classes") { // Create dataset with very imbalanced classes auto X1 = torch::randn({ 80, 2 }); auto y1 = torch::zeros({ 80 }, torch::kInt32); auto X2 = torch::randn({ 5, 2 }); auto y2 = torch::ones({ 5 }, torch::kInt32); auto X = torch::cat({ X1, X2 }, 0); auto y = torch::cat({ y1, y2 }, 0); OneVsRestStrategy strategy; auto metrics = strategy.fit(X, y, params, converter); REQUIRE(metrics.status == TrainingStatus::SUCCESS); REQUIRE(strategy.get_n_classes() == 2); auto predictions = strategy.predict(X, converter); REQUIRE(static_cast(predictions.size()) == X.size(0)); } } TEST_CASE("MulticlassStrategy Error Handling", "[unit][multiclass_strategy]") { DataConverter converter; KernelParameters params; SECTION("Invalid parameters") { OneVsRestStrategy strategy; auto [X, y] = generate_multiclass_data(50, 2, 2); // Invalid C parameter params.set_kernel_type(KernelType::LINEAR); params.set_C(-1.0); // Invalid REQUIRE_THROWS(strategy.fit(X, y, params, converter)); } SECTION("Mismatched tensor dimensions") { OneVsOneStrategy strategy; auto X = torch::randn({ 50, 3 }); auto y = torch::randint(0, 2, { 40 }); // Wrong number of labels params.set_kernel_type(KernelType::LINEAR); params.set_C(1.0); REQUIRE_THROWS_AS(strategy.fit(X, y, params, converter), std::invalid_argument); } SECTION("Prediction on untrained strategy") { OneVsRestStrategy strategy; auto X = torch::randn({ 10, 2 }); REQUIRE_THROWS_AS(strategy.predict(X, converter), std::runtime_error); REQUIRE_THROWS_AS(strategy.decision_function(X, converter), std::runtime_error); } } TEST_CASE("MulticlassStrategy Memory Management", "[unit][multiclass_strategy]") { SECTION("Strategy destruction") { // Test that strategies clean up properly auto strategy = create_multiclass_strategy(MulticlassStrategy::ONE_VS_REST); DataConverter converter; KernelParameters params; auto [X, y] = generate_multiclass_data(50, 2, 3); params.set_kernel_type(KernelType::LINEAR); strategy->fit(X, y, params, converter); REQUIRE(strategy->get_n_classes() == 3); // Strategy should clean up automatically when destroyed } SECTION("Multiple training rounds") { OneVsRestStrategy strategy; DataConverter converter; KernelParameters params; params.set_kernel_type(KernelType::LINEAR); // Train multiple times with different data for (int i = 0; i < 3; ++i) { auto [X, y] = generate_multiclass_data(40, 2, 2, i); // Different seed auto metrics = strategy.fit(X, y, params, converter); REQUIRE(metrics.status == TrainingStatus::SUCCESS); auto predictions = strategy.predict(X, converter); REQUIRE(static_cast(predictions.size()) == X.size(0)); } } }