library comile complete and begin tests
Some checks failed
CI/CD Pipeline / Create Release Package (push) Has been cancelled
CI/CD Pipeline / Code Linting (push) Has been cancelled
CI/CD Pipeline / Build and Test (Debug, clang, ubuntu-latest) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Debug, gcc, ubuntu-latest) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-20.04) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-latest) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-20.04) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-latest) (push) Has been cancelled
CI/CD Pipeline / Docker Build Test (push) Has been cancelled
CI/CD Pipeline / Performance Benchmarks (push) Has been cancelled
CI/CD Pipeline / Build Documentation (push) Has been cancelled

This commit is contained in:
2025-06-23 12:05:35 +02:00
parent e07eb4d2ed
commit 7b27d5c1f3
9 changed files with 361 additions and 172 deletions

View File

@@ -6,6 +6,7 @@
#include <unordered_set>
#include <chrono>
#include <cmath>
#include <numeric> // for std::accumulate
namespace svm_classifier {
@@ -31,11 +32,11 @@ namespace svm_classifier {
// Store parameters and determine library type
params_ = params;
library_type_ = get_svm_library(params.get_kernel_type());
library_type_ = ::svm_classifier::get_svm_library(params.get_kernel_type());
// Extract unique classes
auto y_cpu = y.to(torch::kCPU);
auto unique_classes_tensor = torch::unique(y_cpu);
auto unique_classes_tensor = std::get<0>(at::_unique(y_cpu));
classes_.clear();
for (int i = 0; i < unique_classes_tensor.size(0); ++i) {
@@ -347,14 +348,16 @@ namespace svm_classifier {
{
for (auto& model : svm_models_) {
if (model) {
svm_free_and_destroy_model(&model);
auto raw_model = model.release();
svm_free_and_destroy_model(&raw_model);
}
}
svm_models_.clear();
for (auto& model : linear_models_) {
if (model) {
free_and_destroy_model(&model);
auto raw_model = model.release();
free_and_destroy_model(&raw_model);
}
}
linear_models_.clear();
@@ -384,11 +387,11 @@ namespace svm_classifier {
// Store parameters and determine library type
params_ = params;
library_type_ = get_svm_library(params.get_kernel_type());
library_type_ = ::svm_classifier::get_svm_library(params.get_kernel_type());
// Extract unique classes
auto y_cpu = y.to(torch::kCPU);
auto unique_classes_tensor = torch::unique(y_cpu);
auto unique_classes_tensor = std::get<0>(at::_unique(y_cpu));
classes_.clear();
for (int i = 0; i < unique_classes_tensor.size(0); ++i) {
@@ -492,4 +495,231 @@ namespace svm_classifier {
return probabilities;
}
std::vector<std::vector<double>> OneVsOneStrategy::decision_function(const torch::Tensor& X,
std::vector<std::vector<double>> OneVsOneStrategy::decision_function(const torch::Tensor& X,
DataConverter& converter)
{
if (!is_trained_) {
throw std::runtime_error("Model is not trained");
}
std::vector<std::vector<double>> decision_values;
decision_values.reserve(X.size(0));
for (int i = 0; i < X.size(0); ++i) {
auto sample = X[i];
std::vector<double> sample_decisions;
sample_decisions.reserve(class_pairs_.size());
for (size_t j = 0; j < class_pairs_.size(); ++j) {
if (library_type_ == SVMLibrary::LIBSVM && svm_models_[j]) {
auto sample_node = converter.to_svm_node(sample);
double decision_value;
svm_predict_values(svm_models_[j].get(), sample_node, &decision_value);
sample_decisions.push_back(decision_value);
} else if (library_type_ == SVMLibrary::LIBLINEAR && linear_models_[j]) {
auto sample_node = converter.to_feature_node(sample);
double decision_value;
predict_values(linear_models_[j].get(), sample_node, &decision_value);
sample_decisions.push_back(decision_value);
} else {
sample_decisions.push_back(0.0);
}
}
decision_values.push_back(sample_decisions);
}
return decision_values;
}
bool OneVsOneStrategy::supports_probability() const
{
return params_.get_probability();
}
std::pair<torch::Tensor, torch::Tensor> OneVsOneStrategy::extract_binary_data(const torch::Tensor& X,
const torch::Tensor& y,
int class1,
int class2)
{
auto mask = (y == class1) | (y == class2);
auto filtered_X = X.index_select(0, torch::nonzero(mask).squeeze());
auto filtered_y = y.index_select(0, torch::nonzero(mask).squeeze());
// Convert to binary labels: class1 -> +1, class2 -> -1
auto binary_y = torch::where(filtered_y == class1, torch::ones_like(filtered_y), torch::full_like(filtered_y, -1));
return std::make_pair(filtered_X, binary_y);
}
double OneVsOneStrategy::train_pairwise_classifier(const torch::Tensor& X,
const torch::Tensor& y,
int class1,
int class2,
const KernelParameters& params,
DataConverter& converter,
int model_idx)
{
auto start_time = std::chrono::high_resolution_clock::now();
auto [filtered_X, binary_y] = extract_binary_data(X, y, class1, class2);
if (library_type_ == SVMLibrary::LIBSVM) {
// Use libsvm
auto problem = converter.to_svm_problem(filtered_X, binary_y);
// Setup SVM parameters (similar to OneVsRest)
svm_parameter svm_params;
svm_params.svm_type = C_SVC;
switch (params.get_kernel_type()) {
case KernelType::RBF:
svm_params.kernel_type = RBF;
break;
case KernelType::POLYNOMIAL:
svm_params.kernel_type = POLY;
break;
case KernelType::SIGMOID:
svm_params.kernel_type = SIGMOID;
break;
default:
throw std::runtime_error("Invalid kernel type for libsvm");
}
svm_params.degree = params.get_degree();
svm_params.gamma = (params.get_gamma() == -1.0) ? 1.0 / filtered_X.size(1) : params.get_gamma();
svm_params.coef0 = params.get_coef0();
svm_params.cache_size = params.get_cache_size();
svm_params.eps = params.get_tolerance();
svm_params.C = params.get_C();
svm_params.nr_weight = 0;
svm_params.weight_label = nullptr;
svm_params.weight = nullptr;
svm_params.nu = 0.5;
svm_params.p = 0.1;
svm_params.shrinking = 1;
svm_params.probability = params.get_probability() ? 1 : 0;
// Check parameters
const char* error_msg = svm_check_parameter(problem.get(), &svm_params);
if (error_msg) {
throw std::runtime_error("SVM parameter error: " + std::string(error_msg));
}
// Train model
auto model = svm_train(problem.get(), &svm_params);
if (!model) {
throw std::runtime_error("Failed to train SVM model");
}
svm_models_[model_idx] = std::unique_ptr<svm_model>(model);
} else {
// Use liblinear
auto problem = converter.to_linear_problem(filtered_X, binary_y);
// Setup linear parameters
parameter linear_params;
linear_params.solver_type = L2R_L2LOSS_SVC_DUAL;
linear_params.C = params.get_C();
linear_params.eps = params.get_tolerance();
linear_params.nr_weight = 0;
linear_params.weight_label = nullptr;
linear_params.weight = nullptr;
linear_params.p = 0.1;
linear_params.nu = 0.5;
linear_params.init_sol = nullptr;
linear_params.regularize_bias = 0;
// Check parameters
const char* error_msg = check_parameter(problem.get(), &linear_params);
if (error_msg) {
throw std::runtime_error("Linear parameter error: " + std::string(error_msg));
}
// Train model
auto model = train(problem.get(), &linear_params);
if (!model) {
throw std::runtime_error("Failed to train linear model");
}
linear_models_[model_idx] = std::unique_ptr<::model>(model);
}
auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
return duration.count() / 1000.0;
}
std::vector<int> OneVsOneStrategy::vote_predictions(const std::vector<std::vector<double>>& decisions)
{
std::vector<int> predictions;
predictions.reserve(decisions.size());
for (const auto& decision_row : decisions) {
std::vector<int> votes(classes_.size(), 0);
// Count votes from pairwise decisions
for (size_t i = 0; i < class_pairs_.size(); ++i) {
auto [class1, class2] = class_pairs_[i];
double decision = decision_row[i];
auto it1 = std::find(classes_.begin(), classes_.end(), class1);
auto it2 = std::find(classes_.begin(), classes_.end(), class2);
if (it1 != classes_.end() && it2 != classes_.end()) {
size_t idx1 = std::distance(classes_.begin(), it1);
size_t idx2 = std::distance(classes_.begin(), it2);
if (decision > 0) {
votes[idx1]++;
} else {
votes[idx2]++;
}
}
}
// Find class with most votes
auto max_it = std::max_element(votes.begin(), votes.end());
int predicted_class_idx = std::distance(votes.begin(), max_it);
predictions.push_back(classes_[predicted_class_idx]);
}
return predictions;
}
void OneVsOneStrategy::cleanup_models()
{
for (auto& model : svm_models_) {
if (model) {
auto raw_model = model.release();
svm_free_and_destroy_model(&raw_model);
}
}
svm_models_.clear();
for (auto& model : linear_models_) {
if (model) {
auto raw_model = model.release();
free_and_destroy_model(&raw_model);
}
}
linear_models_.clear();
is_trained_ = false;
}
// Factory function
std::unique_ptr<MulticlassStrategyBase> create_multiclass_strategy(MulticlassStrategy strategy)
{
switch (strategy) {
case MulticlassStrategy::ONE_VS_REST:
return std::make_unique<OneVsRestStrategy>();
case MulticlassStrategy::ONE_VS_ONE:
return std::make_unique<OneVsOneStrategy>();
default:
throw std::invalid_argument("Unknown multiclass strategy");
}
}
} // namespace svm_classifier