518 lines
16 KiB
C++
518 lines
16 KiB
C++
/**
|
|
* @file test_multiclass_strategy.cpp
|
|
* @brief Unit tests for multiclass strategy classes
|
|
*/
|
|
|
|
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <utility>

#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>
#include <torch/torch.h>

#include <svm_classifier/multiclass_strategy.hpp>
#include <svm_classifier/kernel_parameters.hpp>
#include <svm_classifier/data_converter.hpp>

using namespace svm_classifier;
|
|
|
|
/**
|
|
* @brief Generate simple test data for multiclass testing
|
|
*/
|
|
std::pair<torch::Tensor, torch::Tensor> generate_multiclass_data(int n_samples = 60,
|
|
int n_features = 2,
|
|
int n_classes = 3,
|
|
int seed = 42)
|
|
{
|
|
torch::manual_seed(seed);
|
|
|
|
auto X = torch::randn({ n_samples, n_features });
|
|
auto y = torch::randint(0, n_classes, { n_samples });
|
|
|
|
// Create some structure in the data
|
|
for (int i = 0; i < n_samples; ++i) {
|
|
int class_label = y[i].item<int>();
|
|
// Add class-specific bias to make classification easier
|
|
X[i] += class_label * 0.5;
|
|
}
|
|
|
|
return { X, y };
|
|
}
|
|
|
|
/**
 * Verifies that create_multiclass_strategy() returns a non-null strategy of
 * the requested type, in the untrained state (no classes known).
 */
TEST_CASE("MulticlassStrategy Factory Function", "[unit][multiclass_strategy]")
{
    SECTION("Create One-vs-Rest strategy")
    {
        auto strategy = create_multiclass_strategy(MulticlassStrategy::ONE_VS_REST);

        REQUIRE(strategy != nullptr);
        REQUIRE(strategy->get_strategy_type() == MulticlassStrategy::ONE_VS_REST);
        // Not trained yet: stated directly instead of the former double negation.
        REQUIRE(strategy->get_classes().empty());
        REQUIRE(strategy->get_n_classes() == 0);
    }

    SECTION("Create One-vs-One strategy")
    {
        auto strategy = create_multiclass_strategy(MulticlassStrategy::ONE_VS_ONE);

        REQUIRE(strategy != nullptr);
        REQUIRE(strategy->get_strategy_type() == MulticlassStrategy::ONE_VS_ONE);
        REQUIRE(strategy->get_classes().empty()); // Not trained yet
        REQUIRE(strategy->get_n_classes() == 0);
    }
}
|
|
|
|
/**
 * Basic lifecycle checks for OneVsRestStrategy: untrained state, then a
 * successful fit with a linear kernel and with an RBF kernel.
 */
TEST_CASE("OneVsRestStrategy Basic Functionality", "[unit][multiclass_strategy]")
{
    // Catch2 re-runs this TEST_CASE from the top for every SECTION, so each
    // section starts with a fresh strategy, converter and default parameters.
    OneVsRestStrategy strategy;
    DataConverter converter;
    KernelParameters params;

    SECTION("Initial state")
    {
        // Before fit(): the type is fixed, no classes are known and
        // probability estimates are not offered.
        REQUIRE(strategy.get_strategy_type() == MulticlassStrategy::ONE_VS_REST);
        REQUIRE(strategy.get_n_classes() == 0);
        REQUIRE(strategy.get_classes().empty());
        REQUIRE_FALSE(strategy.supports_probability());
    }

    SECTION("Training with linear kernel")
    {
        // 60 samples, 3 features, 3 classes.
        auto [features, labels] = generate_multiclass_data(60, 3, 3);

        params.set_kernel_type(KernelType::LINEAR);
        params.set_C(1.0);

        const auto result = strategy.fit(features, labels, params, converter);

        REQUIRE(result.status == TrainingStatus::SUCCESS);
        REQUIRE(result.training_time >= 0.0);
        REQUIRE(strategy.get_n_classes() == 3);

        // The learned labels are reported in ascending order.
        const auto learned = strategy.get_classes();
        REQUIRE(learned.size() == 3);
        REQUIRE(std::is_sorted(learned.begin(), learned.end()));
    }

    SECTION("Training with RBF kernel")
    {
        // 50 samples, 2 features, 2 classes.
        auto [features, labels] = generate_multiclass_data(50, 2, 2);

        params.set_kernel_type(KernelType::RBF);
        params.set_C(1.0);
        params.set_gamma(0.1);

        const auto result = strategy.fit(features, labels, params, converter);

        REQUIRE(result.status == TrainingStatus::SUCCESS);
        REQUIRE(strategy.get_n_classes() == 2);
    }
}
|
|
|
|
/**
 * Prediction-side checks for OneVsRestStrategy: predict() yields valid labels,
 * decision_function() yields one finite score per class, and predicting with
 * an untrained strategy throws std::runtime_error.
 */
TEST_CASE("OneVsRestStrategy Prediction", "[unit][multiclass_strategy]")
{
    OneVsRestStrategy strategy;
    DataConverter converter;
    KernelParameters params;

    auto [X, y] = generate_multiclass_data(80, 3, 3);

    // Split data: the first 60 rows train; the remaining 20 are only predicted.
    // Test labels are not needed because outputs are validated structurally.
    auto X_train = X.slice(0, 0, 60);
    auto y_train = y.slice(0, 0, 60);
    auto X_test = X.slice(0, 60);

    params.set_kernel_type(KernelType::LINEAR);

    // Fail fast: every section below is meaningless if training did not succeed.
    const auto metrics = strategy.fit(X_train, y_train, params, converter);
    REQUIRE(metrics.status == TrainingStatus::SUCCESS);

    SECTION("Basic prediction")
    {
        auto predictions = strategy.predict(X_test, converter);

        REQUIRE(static_cast<int64_t>(predictions.size()) == X_test.size(0));

        // Check that all predictions are valid class labels
        auto classes = strategy.get_classes();
        for (int pred : predictions) {
            REQUIRE(std::find(classes.begin(), classes.end(), pred) != classes.end());
        }
    }

    SECTION("Decision function")
    {
        auto decision_values = strategy.decision_function(X_test, converter);

        REQUIRE(static_cast<int64_t>(decision_values.size()) == X_test.size(0));
        // One-vs-rest: one decision value per class.
        REQUIRE(static_cast<int>(decision_values[0].size()) == strategy.get_n_classes());

        // Decision values should be finite real numbers
        for (const auto& sample_decisions : decision_values) {
            for (double value : sample_decisions) {
                REQUIRE(std::isfinite(value));
            }
        }
    }

    SECTION("Prediction without training")
    {
        OneVsRestStrategy untrained_strategy;

        REQUIRE_THROWS_AS(untrained_strategy.predict(X_test, converter), std::runtime_error);
    }
}
|
|
|
|
/**
 * Probability-output checks for OneVsRestStrategy, with and without
 * probability estimation requested in KernelParameters.
 *
 * Whether probabilities are available is left to the implementation. Whatever
 * supports_probability() reports, predict_proba() must agree with it:
 * - supported   => predict_proba() returns a valid per-class distribution;
 * - unsupported => predict_proba() throws std::runtime_error.
 */
TEST_CASE("OneVsRestStrategy Probability Prediction", "[unit][multiclass_strategy]")
{
    OneVsRestStrategy strategy;
    DataConverter converter;
    KernelParameters params;

    auto [X, y] = generate_multiclass_data(60, 2, 3);
    const int64_t n_samples = X.size(0);

    // Validates predict_proba output: one row per sample, one column per class
    // (3 here), every entry in [0, 1] and every row summing to 1.
    // Captures nothing (C++17 lambdas cannot capture structured bindings).
    const auto require_valid_probabilities = [](const auto& probabilities, int64_t expected_rows) {
        REQUIRE(static_cast<int64_t>(probabilities.size()) == expected_rows);
        REQUIRE(probabilities[0].size() == 3); // 3 classes

        for (const auto& sample_probs : probabilities) {
            double sum = 0.0;
            for (double prob : sample_probs) {
                REQUIRE(prob >= 0.0);
                REQUIRE(prob <= 1.0);
                sum += prob;
            }
            REQUIRE(sum == Catch::Approx(1.0).margin(1e-6));
        }
    };

    SECTION("With probability enabled")
    {
        params.set_kernel_type(KernelType::RBF);
        params.set_probability(true);

        const auto metrics = strategy.fit(X, y, params, converter);
        REQUIRE(metrics.status == TrainingStatus::SUCCESS);

        if (strategy.supports_probability()) {
            require_valid_probabilities(strategy.predict_proba(X, converter), n_samples);
        } else {
            // Previously this branch asserted nothing; hold it to the same contract.
            REQUIRE_THROWS_AS(strategy.predict_proba(X, converter), std::runtime_error);
        }
    }

    SECTION("Without probability enabled")
    {
        params.set_kernel_type(KernelType::LINEAR);
        params.set_probability(false);

        const auto metrics = strategy.fit(X, y, params, converter);
        REQUIRE(metrics.status == TrainingStatus::SUCCESS);

        // May or may not support probability depending on implementation.
        // If not supported, it must throw; if supported, the output must be valid.
        if (!strategy.supports_probability()) {
            REQUIRE_THROWS_AS(strategy.predict_proba(X, converter), std::runtime_error);
        } else {
            require_valid_probabilities(strategy.predict_proba(X, converter), n_samples);
        }
    }
}
|
|
|
|
/**
 * Basic lifecycle checks for OneVsOneStrategy: untrained state, a four-class
 * linear fit, and a two-class RBF fit.
 */
TEST_CASE("OneVsOneStrategy Basic Functionality", "[unit][multiclass_strategy]")
{
    // Per-section fixtures: Catch2 re-executes this body once per SECTION.
    OneVsOneStrategy strategy;
    DataConverter converter;
    KernelParameters params;

    SECTION("Initial state")
    {
        REQUIRE(strategy.get_strategy_type() == MulticlassStrategy::ONE_VS_ONE);
        REQUIRE(strategy.get_n_classes() == 0);
        REQUIRE(strategy.get_classes().empty());
    }

    SECTION("Training with multiple classes")
    {
        // Four classes (80 samples, 3 features). A one-vs-one scheme would
        // typically build C(4,2) = 6 pairwise sub-models internally. That is
        // an implementation detail, so only the public class count is asserted.
        auto [features, labels] = generate_multiclass_data(80, 3, 4);

        params.set_kernel_type(KernelType::LINEAR);
        params.set_C(1.0);

        const auto result = strategy.fit(features, labels, params, converter);

        REQUIRE(result.status == TrainingStatus::SUCCESS);
        REQUIRE(strategy.get_n_classes() == 4);
        REQUIRE(strategy.get_classes().size() == 4);
    }

    SECTION("Binary classification")
    {
        // The degenerate two-class case (a single pair), using the default C.
        auto [features, labels] = generate_multiclass_data(50, 2, 2);

        params.set_kernel_type(KernelType::RBF);
        params.set_gamma(0.1);

        const auto result = strategy.fit(features, labels, params, converter);

        REQUIRE(result.status == TrainingStatus::SUCCESS);
        REQUIRE(strategy.get_n_classes() == 2);
    }
}
|
|
|
|
/**
 * Prediction-side checks for a linear OneVsOneStrategy trained on 70 of 90
 * samples (3 classes): labels, pairwise decision values and probabilities.
 */
TEST_CASE("OneVsOneStrategy Prediction", "[unit][multiclass_strategy]")
{
    OneVsOneStrategy strategy;
    DataConverter converter;
    KernelParameters params;

    // 90 samples, 2 features, 3 classes; rows [0, 70) train and rows [70, 90) are predicted.
    auto [features, labels] = generate_multiclass_data(90, 2, 3);
    auto train_features = features.slice(0, 0, 70);
    auto train_labels = labels.slice(0, 0, 70);
    auto test_features = features.slice(0, 70);

    params.set_kernel_type(KernelType::LINEAR);
    strategy.fit(train_features, train_labels, params, converter);

    SECTION("Basic prediction")
    {
        const auto predicted = strategy.predict(test_features, converter);
        REQUIRE(static_cast<int64_t>(predicted.size()) == test_features.size(0));

        // Every predicted label must be one of the classes seen during training.
        const auto known = strategy.get_classes();
        for (const int label : predicted) {
            REQUIRE(std::find(known.begin(), known.end(), label) != known.end());
        }
    }

    SECTION("Decision function")
    {
        const auto scores = strategy.decision_function(test_features, converter);
        REQUIRE(static_cast<int64_t>(scores.size()) == test_features.size(0));

        // Three classes give C(3,2) = 3 pairwise comparisons per sample.
        REQUIRE(scores[0].size() == 3);

        for (const auto& row : scores) {
            for (const double score : row) {
                REQUIRE(std::isfinite(score));
            }
        }
    }

    SECTION("Probability prediction")
    {
        // OvO probability estimation is derived from pairwise results, so the
        // row-sum tolerance below is deliberately loose.
        const auto distribution = strategy.predict_proba(test_features, converter);
        REQUIRE(static_cast<int64_t>(distribution.size()) == test_features.size(0));
        REQUIRE(distribution[0].size() == 3); // 3 classes

        for (const auto& row : distribution) {
            double total = 0.0;
            for (const double p : row) {
                REQUIRE(p >= 0.0);
                REQUIRE(p <= 1.0);
                total += p;
            }
            // Voting-based estimates may not sum exactly to 1.
            REQUIRE(total == Catch::Approx(1.0).margin(0.1));
        }
    }
}
|
|
|
|
/**
 * Integration check: OvR and OvO strategies trained on the same split agree on
 * the class set, produce valid labels, and expose decision values of the
 * expected width (per class for OvR, per class pair for OvO).
 */
TEST_CASE("MulticlassStrategy Comparison", "[integration][multiclass_strategy]")
{
    auto [X, y] = generate_multiclass_data(100, 3, 3);

    // 80/20 split. Only the test features are used because outputs are
    // compared structurally, not scored against labels.
    auto X_train = X.slice(0, 0, 80);
    auto y_train = y.slice(0, 0, 80);
    auto X_test = X.slice(0, 80);

    KernelParameters params;
    params.set_kernel_type(KernelType::LINEAR);
    params.set_C(1.0);

    // One converter per strategy; each strategy predicts with the converter it was fitted with.
    OneVsRestStrategy ovr_strategy;
    OneVsOneStrategy ovo_strategy;
    DataConverter ovr_converter;
    DataConverter ovo_converter;

    // Both sections need the same two fitted models. Catch2 re-runs this setup
    // for every SECTION, so fitting here is equivalent to fitting in each section.
    const auto ovr_metrics = ovr_strategy.fit(X_train, y_train, params, ovr_converter);
    const auto ovo_metrics = ovo_strategy.fit(X_train, y_train, params, ovo_converter);
    REQUIRE(ovr_metrics.status == TrainingStatus::SUCCESS);
    REQUIRE(ovo_metrics.status == TrainingStatus::SUCCESS);

    SECTION("Compare OvR vs OvO predictions")
    {
        auto ovr_predictions = ovr_strategy.predict(X_test, ovr_converter);
        auto ovo_predictions = ovo_strategy.predict(X_test, ovo_converter);

        REQUIRE(ovr_predictions.size() == ovo_predictions.size());

        auto ovr_classes = ovr_strategy.get_classes();
        auto ovo_classes = ovo_strategy.get_classes();

        REQUIRE(ovr_classes == ovo_classes); // Should have same classes

        // Both should predict valid class labels
        for (size_t i = 0; i < ovr_predictions.size(); ++i) {
            REQUIRE(std::find(ovr_classes.begin(), ovr_classes.end(), ovr_predictions[i]) != ovr_classes.end());
            REQUIRE(std::find(ovo_classes.begin(), ovo_classes.end(), ovo_predictions[i]) != ovo_classes.end());
        }
    }

    SECTION("Compare decision function outputs")
    {
        auto ovr_decisions = ovr_strategy.decision_function(X_test, ovr_converter);
        auto ovo_decisions = ovo_strategy.decision_function(X_test, ovo_converter);

        REQUIRE(ovr_decisions.size() == ovo_decisions.size());

        // OvR should have one decision value per class
        REQUIRE(ovr_decisions[0].size() == 3);

        // OvO should have one decision value per class pair: C(3,2) = 3
        REQUIRE(ovo_decisions[0].size() == 3);
    }
}
|
|
|
|
/**
 * Edge cases for training: a single-class dataset, a two-sample dataset, and
 * a heavily imbalanced two-class dataset (80 vs 5 samples).
 */
TEST_CASE("MulticlassStrategy Edge Cases", "[unit][multiclass_strategy]")
{
    // Seed the global torch RNG so the random features below do not depend on
    // which tests ran earlier. Every section re-enters here and reseeds.
    torch::manual_seed(42);

    DataConverter converter;
    KernelParameters params;
    params.set_kernel_type(KernelType::LINEAR);

    SECTION("Single class dataset")
    {
        auto X = torch::randn({ 20, 2 });
        auto y = torch::zeros({ 20 }, torch::kInt32); // All same class

        OneVsRestStrategy strategy;

        // Should handle single class gracefully
        auto metrics = strategy.fit(X, y, params, converter);

        REQUIRE(metrics.status == TrainingStatus::SUCCESS);
        // Implementation might extend to binary case

        auto predictions = strategy.predict(X, converter);
        REQUIRE(static_cast<int64_t>(predictions.size()) == X.size(0));
    }

    SECTION("Very small dataset")
    {
        auto X = torch::tensor({ {1.0, 2.0}, {3.0, 4.0} });
        auto y = torch::tensor({ 0, 1 });

        OneVsOneStrategy strategy;

        auto metrics = strategy.fit(X, y, params, converter);

        REQUIRE(metrics.status == TrainingStatus::SUCCESS);

        auto predictions = strategy.predict(X, converter);
        REQUIRE(predictions.size() == 2);
    }

    SECTION("Imbalanced classes")
    {
        // Create dataset with very imbalanced classes: 80 samples of class 0, 5 of class 1
        auto X1 = torch::randn({ 80, 2 });
        auto y1 = torch::zeros({ 80 }, torch::kInt32);

        auto X2 = torch::randn({ 5, 2 });
        auto y2 = torch::ones({ 5 }, torch::kInt32);

        auto X = torch::cat({ X1, X2 }, 0);
        auto y = torch::cat({ y1, y2 }, 0);

        OneVsRestStrategy strategy;
        auto metrics = strategy.fit(X, y, params, converter);

        REQUIRE(metrics.status == TrainingStatus::SUCCESS);
        REQUIRE(strategy.get_n_classes() == 2);

        auto predictions = strategy.predict(X, converter);
        REQUIRE(static_cast<int64_t>(predictions.size()) == X.size(0));
    }
}
|
|
|
|
/**
 * Error handling: invalid hyper-parameters, X/y length mismatch
 * (std::invalid_argument), and use of an untrained strategy (std::runtime_error).
 */
TEST_CASE("MulticlassStrategy Error Handling", "[unit][multiclass_strategy]")
{
    DataConverter converter;
    KernelParameters params;

    SECTION("Invalid parameters")
    {
        OneVsRestStrategy strategy;

        // Plain references rather than structured bindings, so the lambda below
        // can capture them (capturing structured bindings requires C++20).
        auto data = generate_multiclass_data(50, 2, 2);
        auto& X = data.first;
        auto& y = data.second;

        params.set_kernel_type(KernelType::LINEAR);

        // A negative C must be rejected, either eagerly by set_C() or later by fit().
        // Both calls are inside the checked expression, so an exception from set_C()
        // is counted by the assertion instead of escaping the section as an error.
        REQUIRE_THROWS([&] {
            params.set_C(-1.0); // Invalid
            strategy.fit(X, y, params, converter);
        }());
    }

    SECTION("Mismatched tensor dimensions")
    {
        OneVsOneStrategy strategy;

        auto X = torch::randn({ 50, 3 });
        auto y = torch::randint(0, 2, { 40 }); // Wrong number of labels

        params.set_kernel_type(KernelType::LINEAR);
        params.set_C(1.0);

        REQUIRE_THROWS_AS(strategy.fit(X, y, params, converter), std::invalid_argument);
    }

    SECTION("Prediction on untrained strategy")
    {
        OneVsRestStrategy strategy;
        auto X = torch::randn({ 10, 2 });

        REQUIRE_THROWS_AS(strategy.predict(X, converter), std::runtime_error);
        REQUIRE_THROWS_AS(strategy.decision_function(X, converter), std::runtime_error);
    }
}
|
|
|
|
/**
 * Lifetime checks: a trained strategy created through the factory can be
 * destroyed cleanly, and a single strategy object can be retrained repeatedly
 * on data with a different number of classes.
 */
TEST_CASE("MulticlassStrategy Memory Management", "[unit][multiclass_strategy]")
{
    SECTION("Strategy destruction")
    {
        auto strategy = create_multiclass_strategy(MulticlassStrategy::ONE_VS_REST);

        DataConverter converter;
        KernelParameters params;
        auto [X, y] = generate_multiclass_data(50, 2, 3);

        params.set_kernel_type(KernelType::LINEAR);
        strategy->fit(X, y, params, converter);

        REQUIRE(strategy->get_n_classes() == 3);

        // Destroy the trained strategy explicitly, inside this section, so a crash
        // in its destructor is attributed to this test. Run under ASan/Valgrind
        // to also detect leaks.
        strategy.reset();
        REQUIRE(strategy == nullptr);
    }

    SECTION("Multiple training rounds")
    {
        OneVsRestStrategy strategy;
        DataConverter converter;
        KernelParameters params;
        params.set_kernel_type(KernelType::LINEAR);

        // Retrain the same object on different data (different seed each round),
        // alternating between 2 and 3 classes. Stale per-class state left over
        // from a previous fit would show up as a wrong class count.
        for (int round = 0; round < 3; ++round) {
            const int n_classes = (round % 2 == 0) ? 2 : 3;
            auto [X, y] = generate_multiclass_data(40, 2, n_classes, round);

            auto metrics = strategy.fit(X, y, params, converter);
            REQUIRE(metrics.status == TrainingStatus::SUCCESS);
            REQUIRE(strategy.get_n_classes() == n_classes);

            auto predictions = strategy.predict(X, converter);
            REQUIRE(static_cast<int64_t>(predictions.size()) == X.size(0));
        }
    }
}