Fix Regularization error
Some checks failed
CI/CD Pipeline / Code Linting (push) Failing after 25s
CI/CD Pipeline / Build and Test (Debug, clang, ubuntu-latest) (push) Failing after 5m18s
CI/CD Pipeline / Build and Test (Debug, gcc, ubuntu-latest) (push) Failing after 6m17s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-20.04) (push) Failing after 6m15s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-latest) (push) Failing after 5m5s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-20.04) (push) Failing after 6m14s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-latest) (push) Failing after 6m9s
CI/CD Pipeline / Docker Build Test (push) Failing after 1m12s
CI/CD Pipeline / Performance Benchmarks (push) Has been skipped
CI/CD Pipeline / Build Documentation (push) Failing after 20s
CI/CD Pipeline / Create Release Package (push) Has been skipped
Some checks failed
CI/CD Pipeline / Code Linting (push) Failing after 25s
CI/CD Pipeline / Build and Test (Debug, clang, ubuntu-latest) (push) Failing after 5m18s
CI/CD Pipeline / Build and Test (Debug, gcc, ubuntu-latest) (push) Failing after 6m17s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-20.04) (push) Failing after 6m15s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-latest) (push) Failing after 5m5s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-20.04) (push) Failing after 6m14s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-latest) (push) Failing after 6m9s
CI/CD Pipeline / Docker Build Test (push) Failing after 1m12s
CI/CD Pipeline / Performance Benchmarks (push) Has been skipped
CI/CD Pipeline / Build Documentation (push) Failing after 20s
CI/CD Pipeline / Create Release Package (push) Has been skipped
This commit is contained in:
36
test_param_change.cpp
Normal file
36
test_param_change.cpp
Normal file
@@ -0,0 +1,36 @@
|
||||
#include <svm_classifier/svm_classifier.hpp>
|
||||
#include <torch/torch.h>
|
||||
#include <iostream>
|
||||
|
||||
using namespace svm_classifier;
|
||||
|
||||
int main() {
|
||||
try {
|
||||
std::cout << "Creating linear SVM..." << std::endl;
|
||||
SVMClassifier svm(KernelType::LINEAR, 1.0);
|
||||
|
||||
std::cout << "Generating data..." << std::endl;
|
||||
torch::manual_seed(42);
|
||||
auto X = torch::randn({50, 3});
|
||||
auto y = torch::randint(0, 2, {50});
|
||||
|
||||
std::cout << "Training SVM..." << std::endl;
|
||||
auto metrics = svm.fit(X, y);
|
||||
std::cout << "Training completed successfully!" << std::endl;
|
||||
std::cout << "Is fitted: " << svm.is_fitted() << std::endl;
|
||||
|
||||
std::cout << "\nChanging to RBF kernel..." << std::endl;
|
||||
nlohmann::json new_params = {{"kernel", "rbf"}};
|
||||
svm.set_parameters(new_params);
|
||||
std::cout << "Parameters changed successfully!" << std::endl;
|
||||
std::cout << "Is fitted after param change: " << svm.is_fitted() << std::endl;
|
||||
|
||||
std::cout << "\nAll operations completed successfully!" << std::endl;
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Error: " << e.what() << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
Reference in New Issue
Block a user