Fix Regularization error
Some checks failed
CI/CD Pipeline / Code Linting (push) Failing after 25s
CI/CD Pipeline / Build and Test (Debug, clang, ubuntu-latest) (push) Failing after 5m18s
CI/CD Pipeline / Build and Test (Debug, gcc, ubuntu-latest) (push) Failing after 6m17s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-20.04) (push) Failing after 6m15s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-latest) (push) Failing after 5m5s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-20.04) (push) Failing after 6m14s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-latest) (push) Failing after 6m9s
CI/CD Pipeline / Docker Build Test (push) Failing after 1m12s
CI/CD Pipeline / Performance Benchmarks (push) Has been skipped
CI/CD Pipeline / Build Documentation (push) Failing after 20s
CI/CD Pipeline / Create Release Package (push) Has been skipped
Some checks failed
CI/CD Pipeline / Code Linting (push) Failing after 25s
CI/CD Pipeline / Build and Test (Debug, clang, ubuntu-latest) (push) Failing after 5m18s
CI/CD Pipeline / Build and Test (Debug, gcc, ubuntu-latest) (push) Failing after 6m17s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-20.04) (push) Failing after 6m15s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-latest) (push) Failing after 5m5s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-20.04) (push) Failing after 6m14s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-latest) (push) Failing after 6m9s
CI/CD Pipeline / Docker Build Test (push) Failing after 1m12s
CI/CD Pipeline / Performance Benchmarks (push) Has been skipped
CI/CD Pipeline / Build Documentation (push) Failing after 20s
CI/CD Pipeline / Create Release Package (push) Has been skipped
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -36,3 +36,4 @@ build_Release
|
||||
build_Debug
|
||||
build
|
||||
libtorch
|
||||
Testing
|
||||
|
@@ -138,6 +138,7 @@ namespace svm_classifier {
|
||||
std::vector<double> linear_y_space_;
|
||||
|
||||
// Single sample storage (for prediction)
|
||||
// Thread-local storage for single sample conversions
|
||||
std::vector<svm_node> single_svm_nodes_;
|
||||
std::vector<feature_node> single_linear_nodes_;
|
||||
|
||||
|
@@ -92,7 +92,7 @@ namespace svm_classifier {
|
||||
linear_problem->n = n_features_;
|
||||
linear_problem->x = linear_x_space_.data();
|
||||
linear_problem->y = linear_y_space_.data();
|
||||
linear_problem->bias = -1; // No bias term by default
|
||||
linear_problem->bias = 1.0; // Add bias term with value 1.0
|
||||
|
||||
return linear_problem;
|
||||
}
|
||||
|
@@ -321,7 +321,8 @@ namespace svm_classifier {
|
||||
linear_params.p = 0.1;
|
||||
linear_params.nu = 0.5;
|
||||
linear_params.init_sol = nullptr;
|
||||
linear_params.regularize_bias = 0;
|
||||
linear_params.regularize_bias = 1;
|
||||
linear_params.w_recalc = false;
|
||||
|
||||
// Check parameters
|
||||
const char* error_msg = check_parameter(problem.get(), &linear_params);
|
||||
@@ -628,7 +629,8 @@ namespace svm_classifier {
|
||||
linear_params.p = 0.1;
|
||||
linear_params.nu = 0.5;
|
||||
linear_params.init_sol = nullptr;
|
||||
linear_params.regularize_bias = 0;
|
||||
linear_params.regularize_bias = 1;
|
||||
linear_params.w_recalc = false;
|
||||
|
||||
// Check parameters
|
||||
const char* error_msg = check_parameter(problem.get(), &linear_params);
|
||||
|
36
test_param_change.cpp
Normal file
36
test_param_change.cpp
Normal file
@@ -0,0 +1,36 @@
|
||||
#include <svm_classifier/svm_classifier.hpp>
|
||||
#include <torch/torch.h>
|
||||
#include <iostream>
|
||||
|
||||
using namespace svm_classifier;
|
||||
|
||||
int main() {
|
||||
try {
|
||||
std::cout << "Creating linear SVM..." << std::endl;
|
||||
SVMClassifier svm(KernelType::LINEAR, 1.0);
|
||||
|
||||
std::cout << "Generating data..." << std::endl;
|
||||
torch::manual_seed(42);
|
||||
auto X = torch::randn({50, 3});
|
||||
auto y = torch::randint(0, 2, {50});
|
||||
|
||||
std::cout << "Training SVM..." << std::endl;
|
||||
auto metrics = svm.fit(X, y);
|
||||
std::cout << "Training completed successfully!" << std::endl;
|
||||
std::cout << "Is fitted: " << svm.is_fitted() << std::endl;
|
||||
|
||||
std::cout << "\nChanging to RBF kernel..." << std::endl;
|
||||
nlohmann::json new_params = {{"kernel", "rbf"}};
|
||||
svm.set_parameters(new_params);
|
||||
std::cout << "Parameters changed successfully!" << std::endl;
|
||||
std::cout << "Is fitted after param change: " << svm.is_fitted() << std::endl;
|
||||
|
||||
std::cout << "\nAll operations completed successfully!" << std::endl;
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Error: " << e.what() << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
Reference in New Issue
Block a user