From c2479bc7c75cb9d7fdf80d3c22e9fb6b9c525c8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Tue, 24 Jun 2025 12:57:15 +0200 Subject: [PATCH] Fix Regularization error --- .gitignore | 3 +- include/svm_classifier/data_converter.hpp | 1 + src/data_converter.cpp | 2 +- src/multiclass_strategy.cpp | 6 ++-- test_param_change.cpp | 36 +++++++++++++++++++++++ 5 files changed, 44 insertions(+), 4 deletions(-) create mode 100644 test_param_change.cpp diff --git a/.gitignore b/.gitignore index 6cb333e..347f8f9 100644 --- a/.gitignore +++ b/.gitignore @@ -35,4 +35,5 @@ build_Release build_Debug build -libtorch \ No newline at end of file +libtorch +Testing diff --git a/include/svm_classifier/data_converter.hpp b/include/svm_classifier/data_converter.hpp index 3306fdc..9696c73 100644 --- a/include/svm_classifier/data_converter.hpp +++ b/include/svm_classifier/data_converter.hpp @@ -138,6 +138,7 @@ namespace svm_classifier { std::vector linear_y_space_; // Single sample storage (for prediction) + // Thread-local storage for single sample conversions std::vector single_svm_nodes_; std::vector single_linear_nodes_; diff --git a/src/data_converter.cpp b/src/data_converter.cpp index 596354b..7fb7255 100644 --- a/src/data_converter.cpp +++ b/src/data_converter.cpp @@ -92,7 +92,7 @@ namespace svm_classifier { linear_problem->n = n_features_; linear_problem->x = linear_x_space_.data(); linear_problem->y = linear_y_space_.data(); - linear_problem->bias = -1; // No bias term by default + linear_problem->bias = 1.0; // Add bias term with value 1.0 return linear_problem; } diff --git a/src/multiclass_strategy.cpp b/src/multiclass_strategy.cpp index baee468..9424062 100644 --- a/src/multiclass_strategy.cpp +++ b/src/multiclass_strategy.cpp @@ -321,7 +321,8 @@ namespace svm_classifier { linear_params.p = 0.1; linear_params.nu = 0.5; linear_params.init_sol = nullptr; - linear_params.regularize_bias = 0; + linear_params.regularize_bias = 1; + linear_params.w_recalc = false; // Check parameters const char* error_msg = check_parameter(problem.get(), &linear_params); @@ -628,7 +629,8 @@ namespace svm_classifier { linear_params.p = 0.1; linear_params.nu = 0.5; linear_params.init_sol = nullptr; - linear_params.regularize_bias = 0; + linear_params.regularize_bias = 1; + linear_params.w_recalc = false; // Check parameters const char* error_msg = check_parameter(problem.get(), &linear_params); diff --git a/test_param_change.cpp b/test_param_change.cpp new file mode 100644 index 0000000..8fb7a41 --- /dev/null +++ b/test_param_change.cpp @@ -0,0 +1,36 @@ +#include +#include +#include + +using namespace svm_classifier; + +int main() { + try { + std::cout << "Creating linear SVM..." << std::endl; + SVMClassifier svm(KernelType::LINEAR, 1.0); + + std::cout << "Generating data..." << std::endl; + torch::manual_seed(42); + auto X = torch::randn({50, 3}); + auto y = torch::randint(0, 2, {50}); + + std::cout << "Training SVM..." << std::endl; + auto metrics = svm.fit(X, y); + std::cout << "Training completed successfully!" << std::endl; + std::cout << "Is fitted: " << svm.is_fitted() << std::endl; + + std::cout << "\nChanging to RBF kernel..." << std::endl; + nlohmann::json new_params = {{"kernel", "rbf"}}; + svm.set_parameters(new_params); + std::cout << "Parameters changed successfully!" << std::endl; + std::cout << "Is fitted after param change: " << svm.is_fitted() << std::endl; + + std::cout << "\nAll operations completed successfully!" << std::endl; + + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + return 1; + } + + return 0; +} \ No newline at end of file