library compile complete and begin tests
Some checks failed
CI/CD Pipeline / Create Release Package (push) Has been cancelled
CI/CD Pipeline / Code Linting (push) Has been cancelled
CI/CD Pipeline / Build and Test (Debug, clang, ubuntu-latest) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Debug, gcc, ubuntu-latest) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-20.04) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-latest) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-20.04) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-latest) (push) Has been cancelled
CI/CD Pipeline / Docker Build Test (push) Has been cancelled
CI/CD Pipeline / Performance Benchmarks (push) Has been cancelled
CI/CD Pipeline / Build Documentation (push) Has been cancelled
@@ -1,195 +1,117 @@
#pragma once

#include "types.hpp"
#include <torch/torch.h>
#include <vector>
#include <memory>

// Forward declarations for libsvm and liblinear structures
struct svm_node;
struct svm_problem;
struct feature_node;
struct problem;
#include <nlohmann/json.hpp>

namespace svm_classifier {

/**
* @brief Data converter between libtorch tensors and SVM library formats
* @brief Kernel parameters configuration class
*
* This class handles the conversion between PyTorch tensors and the data structures
* required by libsvm and liblinear libraries. It manages memory allocation and
* provides efficient conversion methods.
* This class manages all parameters for SVM kernels including kernel type,
* regularization parameters, optimization settings, and kernel-specific parameters.
*/
class DataConverter {
class KernelParameters {
public:
/**
* @brief Default constructor
* @brief Default constructor with default parameters
*/
DataConverter();
KernelParameters();

/**
* @brief Destructor - cleans up allocated memory
* @brief Constructor with JSON configuration
* @param config JSON configuration object
*/
~DataConverter();
explicit KernelParameters(const nlohmann::json& config);

/**
* @brief Convert PyTorch tensors to libsvm format
* @param X Feature tensor of shape (n_samples, n_features)
* @param y Target tensor of shape (n_samples,) - optional for prediction
* @return Pointer to svm_problem structure
* @brief Set parameters from JSON configuration
* @param config JSON configuration object
* @throws std::invalid_argument if parameters are invalid
*/
std::unique_ptr<svm_problem> to_svm_problem(const torch::Tensor& X,
const torch::Tensor& y = torch::Tensor());
void set_parameters(const nlohmann::json& config);

/**
* @brief Convert PyTorch tensors to liblinear format
* @param X Feature tensor of shape (n_samples, n_features)
* @param y Target tensor of shape (n_samples,) - optional for prediction
* @return Pointer to problem structure
* @brief Get current parameters as JSON
* @return JSON object with current parameters
*/
std::unique_ptr<problem> to_linear_problem(const torch::Tensor& X,
const torch::Tensor& y = torch::Tensor());
nlohmann::json get_parameters() const;

// Kernel type
void set_kernel_type(KernelType kernel);
KernelType get_kernel_type() const { return kernel_type_; }

// Multiclass strategy
void set_multiclass_strategy(MulticlassStrategy strategy);
MulticlassStrategy get_multiclass_strategy() const { return multiclass_strategy_; }

// Common parameters
void set_C(double c);
double get_C() const { return C_; }

void set_tolerance(double tol);
double get_tolerance() const { return tolerance_; }

void set_max_iterations(int max_iter);
int get_max_iterations() const { return max_iterations_; }

void set_probability(bool probability);
bool get_probability() const { return probability_; }

void set_cache_size(double cache_size);
double get_cache_size() const { return cache_size_; }

// Kernel-specific parameters
void set_gamma(double gamma);
double get_gamma() const { return gamma_; }
bool is_gamma_auto() const { return gamma_ == -1.0; }
void set_gamma_auto();

void set_degree(int degree);
int get_degree() const { return degree_; }

void set_coef0(double coef0);
double get_coef0() const { return coef0_; }

/**
* @brief Convert single sample to libsvm format
* @param sample Feature tensor of shape (n_features,)
* @return Pointer to svm_node array
* @brief Validate all parameters for consistency
* @throws std::invalid_argument if parameters are invalid
*/
svm_node* to_svm_node(const torch::Tensor& sample);
void validate() const;

/**
* @brief Convert single sample to liblinear format
* @param sample Feature tensor of shape (n_features,)
* @return Pointer to feature_node array
* @brief Get default parameters for a specific kernel type
* @param kernel Kernel type
* @return JSON object with default parameters
*/
feature_node* to_feature_node(const torch::Tensor& sample);
static nlohmann::json get_default_parameters(KernelType kernel);

/**
* @brief Convert predictions back to PyTorch tensor
* @param predictions Vector of predictions
* @return PyTorch tensor with predictions
* @brief Reset all parameters to defaults for current kernel type
*/
torch::Tensor from_predictions(const std::vector<double>& predictions);

/**
* @brief Convert probabilities back to PyTorch tensor
* @param probabilities 2D vector of class probabilities
* @return PyTorch tensor with probabilities of shape (n_samples, n_classes)
*/
torch::Tensor from_probabilities(const std::vector<std::vector<double>>& probabilities);

/**
* @brief Convert decision values back to PyTorch tensor
* @param decision_values 2D vector of decision function values
* @return PyTorch tensor with decision values
*/
torch::Tensor from_decision_values(const std::vector<std::vector<double>>& decision_values);

/**
* @brief Validate input tensors
* @param X Feature tensor
* @param y Target tensor (optional)
* @throws std::invalid_argument if tensors are invalid
*/
void validate_tensors(const torch::Tensor& X, const torch::Tensor& y = torch::Tensor());

/**
* @brief Get number of features from last conversion
* @return Number of features
*/
int get_n_features() const { return n_features_; }

/**
* @brief Get number of samples from last conversion
* @return Number of samples
*/
int get_n_samples() const { return n_samples_; }

/**
* @brief Clean up all allocated memory
*/
void cleanup();

/**
* @brief Set sparse threshold (features with absolute value below this are ignored)
* @param threshold Sparse threshold (default: 1e-8)
*/
void set_sparse_threshold(double threshold) { sparse_threshold_ = threshold; }

/**
* @brief Get sparse threshold
* @return Current sparse threshold
*/
double get_sparse_threshold() const { return sparse_threshold_; }
void reset_to_defaults();

private:
int n_features_; ///< Number of features
int n_samples_; ///< Number of samples
double sparse_threshold_; ///< Threshold for sparse features
KernelType kernel_type_; ///< Kernel type
MulticlassStrategy multiclass_strategy_; ///< Multiclass strategy

// Common parameters
double C_; ///< Regularization parameter
double tolerance_; ///< Convergence tolerance
int max_iterations_; ///< Maximum iterations (-1 for no limit)
bool probability_; ///< Enable probability estimates
double cache_size_; ///< Cache size in MB

// Memory management for libsvm structures
std::vector<std::vector<svm_node>> svm_nodes_storage_;
std::vector<svm_node*> svm_x_space_;
std::vector<double> svm_y_space_;

// Memory management for liblinear structures
std::vector<std::vector<feature_node>> linear_nodes_storage_;
std::vector<feature_node*> linear_x_space_;
std::vector<double> linear_y_space_;

// Single sample storage (for prediction)
std::vector<svm_node> single_svm_nodes_;
std::vector<feature_node> single_linear_nodes_;
// Kernel-specific parameters
double gamma_; ///< Gamma parameter (-1 for auto)
int degree_; ///< Polynomial degree
double coef0_; ///< Independent term in polynomial/sigmoid

/**
* @brief Convert tensor data to libsvm nodes for multiple samples
* @param X Feature tensor
* @return Vector of svm_node vectors
* @brief Validate kernel-specific parameters
* @throws std::invalid_argument if kernel parameters are invalid
*/
std::vector<std::vector<svm_node>> tensor_to_svm_nodes(const torch::Tensor& X);

/**
* @brief Convert tensor data to liblinear nodes for multiple samples
* @param X Feature tensor
* @return Vector of feature_node vectors
*/
std::vector<std::vector<feature_node>> tensor_to_linear_nodes(const torch::Tensor& X);

/**
* @brief Convert single tensor sample to svm_node vector
* @param sample Feature tensor of shape (n_features,)
* @return Vector of svm_node structures
*/
std::vector<svm_node> sample_to_svm_nodes(const torch::Tensor& sample);

/**
* @brief Convert single tensor sample to feature_node vector
* @param sample Feature tensor of shape (n_features,)
* @return Vector of feature_node structures
*/
std::vector<feature_node> sample_to_linear_nodes(const torch::Tensor& sample);

/**
* @brief Extract labels from target tensor
* @param y Target tensor
* @return Vector of double labels
*/
std::vector<double> extract_labels(const torch::Tensor& y);

/**
* @brief Check if tensor is on CPU and convert if necessary
* @param tensor Input tensor
* @return Tensor guaranteed to be on CPU
*/
torch::Tensor ensure_cpu_tensor(const torch::Tensor& tensor);

/**
* @brief Validate tensor dimensions and data type
* @param tensor Tensor to validate
* @param expected_dims Expected number of dimensions
* @param name Tensor name for error messages
*/
void validate_tensor_properties(const torch::Tensor& tensor, int expected_dims, const std::string& name);
void validate_kernel_parameters() const;
};

} // namespace svm_classifier
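Note: the new KernelParameters class is driven either by its setters or by a nlohmann::json object. Below is a minimal usage sketch based only on the interface declared above; the header path and the JSON key names ("kernel", "C", "gamma") are assumptions, since the implementation of set_parameters() is not part of this hunk.

#include <nlohmann/json.hpp>
#include "svm_classifier/kernel_parameters.hpp" // assumed include path

int main() {
    using namespace svm_classifier;

    // Hypothetical configuration keys; the accepted names are defined by set_parameters().
    nlohmann::json config = {{"kernel", "rbf"}, {"C", 10.0}, {"gamma", "auto"}};

    KernelParameters params(config);   // explicit JSON constructor declared above
    params.set_tolerance(1e-4);        // individual setters are also available
    params.validate();                 // throws std::invalid_argument on inconsistent values

    nlohmann::json current = params.get_parameters();
    auto defaults = KernelParameters::get_default_parameters(KernelType::RBF);
    (void)current; (void)defaults;
    return 0;
}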
@@ -8,7 +8,7 @@
#include <memory>
#include <unordered_map>

// Forward declarations
// Forward declarations for external library structures
struct svm_model;
struct model;
@@ -196,7 +196,7 @@ namespace svm_classifier {
* @brief Get SVM library being used
* @return SVM library type
*/
SVMLibrary get_svm_library() const { return get_svm_library(params_.get_kernel_type()); }
SVMLibrary get_svm_library() const { return ::svm_classifier::get_svm_library(params_.get_kernel_type()); }

/**
* @brief Perform cross-validation
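Note: the added ::svm_classifier:: qualification is required because the member function get_svm_library() hides the namespace-scope helper of the same name; unqualified lookup inside the class finds the zero-argument member and stops, so the call taking a KernelType cannot resolve to the free function. A small illustration of that lookup rule (the names here are illustrative, not from the library):

namespace ns {
int helper(int x) { return x + 1; }    // namespace-scope function

struct Widget {
    int helper() const { return 0; }   // member with the same name hides ns::helper
    int call() const {
        // return helper(42);          // error: lookup finds only the member, which takes no argument
        return ::ns::helper(42);       // OK: explicit qualification reaches the free function
    }
};
} // namespace ns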
@@ -134,17 +134,29 @@ namespace svm_classifier {
break;

case KernelType::RBF:
params["gamma"] = is_gamma_auto() ? "auto" : gamma_;
if (is_gamma_auto()) {
params["gamma"] = "auto";
} else {
params["gamma"] = gamma_;
}
break;

case KernelType::POLYNOMIAL:
params["degree"] = degree_;
params["gamma"] = is_gamma_auto() ? "auto" : gamma_;
if (is_gamma_auto()) {
params["gamma"] = "auto";
} else {
params["gamma"] = gamma_;
}
params["coef0"] = coef0_;
break;

case KernelType::SIGMOID:
params["gamma"] = is_gamma_auto() ? "auto" : gamma_;
if (is_gamma_auto()) {
params["gamma"] = "auto";
} else {
params["gamma"] = gamma_;
}
params["coef0"] = coef0_;
break;
}
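Note: the removed one-liner params["gamma"] = is_gamma_auto() ? "auto" : gamma_; does not compile because the two branches of the conditional operator (a string literal and a double) have no common type; the if/else stores a JSON string in one case and a JSON number in the other. If a single expression is preferred, both branches can be wrapped explicitly, for example:

// Both operands now have the same type (nlohmann::json), so the conditional operator is well-formed.
params["gamma"] = is_gamma_auto() ? nlohmann::json("auto") : nlohmann::json(gamma_);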
@@ -6,6 +6,7 @@
#include <unordered_set>
#include <chrono>
#include <cmath>
#include <numeric> // for std::accumulate

namespace svm_classifier {
@@ -31,11 +32,11 @@ namespace svm_classifier {

// Store parameters and determine library type
params_ = params;
library_type_ = get_svm_library(params.get_kernel_type());
library_type_ = ::svm_classifier::get_svm_library(params.get_kernel_type());

// Extract unique classes
auto y_cpu = y.to(torch::kCPU);
auto unique_classes_tensor = torch::unique(y_cpu);
auto unique_classes_tensor = std::get<0>(at::_unique(y_cpu));
classes_.clear();

for (int i = 0; i < unique_classes_tensor.size(0); ++i) {
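Note: at::_unique() returns a std::tuple whose first element holds the unique values (sorted by default), which is why std::get<0> is applied to recover just the distinct class labels. A standalone sketch of the same call:

#include <torch/torch.h>
#include <iostream>

int main() {
    // Class labels with duplicates, as seen by the training path above.
    auto y = torch::tensor({2, 0, 1, 2, 1, 0}, torch::kLong);

    // First tuple element: the unique values.
    auto unique_vals = std::get<0>(at::_unique(y));

    for (int64_t i = 0; i < unique_vals.size(0); ++i) {
        std::cout << unique_vals[i].item<int64_t>() << " ";  // prints the distinct labels, sorted
    }
    std::cout << std::endl;
    return 0;
}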
@@ -347,14 +348,16 @@
{
for (auto& model : svm_models_) {
if (model) {
svm_free_and_destroy_model(&model);
auto raw_model = model.release();
svm_free_and_destroy_model(&raw_model);
}
}
svm_models_.clear();

for (auto& model : linear_models_) {
if (model) {
free_and_destroy_model(&model);
auto raw_model = model.release();
free_and_destroy_model(&raw_model);
}
}
linear_models_.clear();
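Note: svm_free_and_destroy_model() and free_and_destroy_model() are C APIs declared to take the address of a raw model pointer, so &model does not compile when model is a std::unique_ptr, and freeing through the C API while the unique_ptr still owns the object would end in a double free once its deleter runs. Releasing first hands ownership to a raw pointer so exactly one cleanup path executes. The pattern in isolation, for the libsvm case:

#include <memory>
#include "svm.h" // declares svm_model and svm_free_and_destroy_model

void destroy(std::unique_ptr<svm_model>& owned) {
    if (owned) {
        svm_model* raw = owned.release();  // unique_ptr gives up ownership
        svm_free_and_destroy_model(&raw);  // libsvm frees the model and nulls the raw pointer
    }
}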
@@ -384,11 +387,11 @@

// Store parameters and determine library type
params_ = params;
library_type_ = get_svm_library(params.get_kernel_type());
library_type_ = ::svm_classifier::get_svm_library(params.get_kernel_type());

// Extract unique classes
auto y_cpu = y.to(torch::kCPU);
auto unique_classes_tensor = torch::unique(y_cpu);
auto unique_classes_tensor = std::get<0>(at::_unique(y_cpu));
classes_.clear();

for (int i = 0; i < unique_classes_tensor.size(0); ++i) {
@@ -492,4 +495,231 @@
return probabilities;
}

std::vector<std::vector<double>> OneVsOneStrategy::decision_function(const torch::Tensor& X,
DataConverter& converter)
{
if (!is_trained_) {
throw std::runtime_error("Model is not trained");
}

std::vector<std::vector<double>> decision_values;
decision_values.reserve(X.size(0));

for (int i = 0; i < X.size(0); ++i) {
auto sample = X[i];
std::vector<double> sample_decisions;
sample_decisions.reserve(class_pairs_.size());

for (size_t j = 0; j < class_pairs_.size(); ++j) {
if (library_type_ == SVMLibrary::LIBSVM && svm_models_[j]) {
auto sample_node = converter.to_svm_node(sample);
double decision_value;
svm_predict_values(svm_models_[j].get(), sample_node, &decision_value);
sample_decisions.push_back(decision_value);
} else if (library_type_ == SVMLibrary::LIBLINEAR && linear_models_[j]) {
auto sample_node = converter.to_feature_node(sample);
double decision_value;
predict_values(linear_models_[j].get(), sample_node, &decision_value);
sample_decisions.push_back(decision_value);
} else {
sample_decisions.push_back(0.0);
}
}

decision_values.push_back(sample_decisions);
}

return decision_values;
}

bool OneVsOneStrategy::supports_probability() const
{
return params_.get_probability();
}

std::pair<torch::Tensor, torch::Tensor> OneVsOneStrategy::extract_binary_data(const torch::Tensor& X,
const torch::Tensor& y,
int class1,
int class2)
{
auto mask = (y == class1) | (y == class2);
auto filtered_X = X.index_select(0, torch::nonzero(mask).squeeze());
auto filtered_y = y.index_select(0, torch::nonzero(mask).squeeze());

// Convert to binary labels: class1 -> +1, class2 -> -1
auto binary_y = torch::where(filtered_y == class1, torch::ones_like(filtered_y), torch::full_like(filtered_y, -1));

return std::make_pair(filtered_X, binary_y);
}

double OneVsOneStrategy::train_pairwise_classifier(const torch::Tensor& X,
const torch::Tensor& y,
int class1,
int class2,
const KernelParameters& params,
DataConverter& converter,
int model_idx)
{
auto start_time = std::chrono::high_resolution_clock::now();

auto [filtered_X, binary_y] = extract_binary_data(X, y, class1, class2);

if (library_type_ == SVMLibrary::LIBSVM) {
// Use libsvm
auto problem = converter.to_svm_problem(filtered_X, binary_y);

// Setup SVM parameters (similar to OneVsRest)
svm_parameter svm_params;
svm_params.svm_type = C_SVC;

switch (params.get_kernel_type()) {
case KernelType::RBF:
svm_params.kernel_type = RBF;
break;
case KernelType::POLYNOMIAL:
svm_params.kernel_type = POLY;
break;
case KernelType::SIGMOID:
svm_params.kernel_type = SIGMOID;
break;
default:
throw std::runtime_error("Invalid kernel type for libsvm");
}

svm_params.degree = params.get_degree();
svm_params.gamma = (params.get_gamma() == -1.0) ? 1.0 / filtered_X.size(1) : params.get_gamma();
svm_params.coef0 = params.get_coef0();
svm_params.cache_size = params.get_cache_size();
svm_params.eps = params.get_tolerance();
svm_params.C = params.get_C();
svm_params.nr_weight = 0;
svm_params.weight_label = nullptr;
svm_params.weight = nullptr;
svm_params.nu = 0.5;
svm_params.p = 0.1;
svm_params.shrinking = 1;
svm_params.probability = params.get_probability() ? 1 : 0;

// Check parameters
const char* error_msg = svm_check_parameter(problem.get(), &svm_params);
if (error_msg) {
throw std::runtime_error("SVM parameter error: " + std::string(error_msg));
}

// Train model
auto model = svm_train(problem.get(), &svm_params);
if (!model) {
throw std::runtime_error("Failed to train SVM model");
}

svm_models_[model_idx] = std::unique_ptr<svm_model>(model);
} else {
// Use liblinear
auto problem = converter.to_linear_problem(filtered_X, binary_y);

// Setup linear parameters
parameter linear_params;
linear_params.solver_type = L2R_L2LOSS_SVC_DUAL;
linear_params.C = params.get_C();
linear_params.eps = params.get_tolerance();
linear_params.nr_weight = 0;
linear_params.weight_label = nullptr;
linear_params.weight = nullptr;
linear_params.p = 0.1;
linear_params.nu = 0.5;
linear_params.init_sol = nullptr;
linear_params.regularize_bias = 0;

// Check parameters
const char* error_msg = check_parameter(problem.get(), &linear_params);
if (error_msg) {
throw std::runtime_error("Linear parameter error: " + std::string(error_msg));
}

// Train model
auto model = train(problem.get(), &linear_params);
if (!model) {
throw std::runtime_error("Failed to train linear model");
}

linear_models_[model_idx] = std::unique_ptr<::model>(model);
}

auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);

return duration.count() / 1000.0;
}

std::vector<int> OneVsOneStrategy::vote_predictions(const std::vector<std::vector<double>>& decisions)
{
std::vector<int> predictions;
predictions.reserve(decisions.size());

for (const auto& decision_row : decisions) {
std::vector<int> votes(classes_.size(), 0);

// Count votes from pairwise decisions
for (size_t i = 0; i < class_pairs_.size(); ++i) {
auto [class1, class2] = class_pairs_[i];
double decision = decision_row[i];

auto it1 = std::find(classes_.begin(), classes_.end(), class1);
auto it2 = std::find(classes_.begin(), classes_.end(), class2);

if (it1 != classes_.end() && it2 != classes_.end()) {
size_t idx1 = std::distance(classes_.begin(), it1);
size_t idx2 = std::distance(classes_.begin(), it2);

if (decision > 0) {
votes[idx1]++;
} else {
votes[idx2]++;
}
}
}

// Find class with most votes
auto max_it = std::max_element(votes.begin(), votes.end());
int predicted_class_idx = std::distance(votes.begin(), max_it);
predictions.push_back(classes_[predicted_class_idx]);
}

return predictions;
}

void OneVsOneStrategy::cleanup_models()
{
for (auto& model : svm_models_) {
if (model) {
auto raw_model = model.release();
svm_free_and_destroy_model(&raw_model);
}
}
svm_models_.clear();

for (auto& model : linear_models_) {
if (model) {
auto raw_model = model.release();
free_and_destroy_model(&raw_model);
}
}
linear_models_.clear();

is_trained_ = false;
}

// Factory function
std::unique_ptr<MulticlassStrategyBase> create_multiclass_strategy(MulticlassStrategy strategy)
{
switch (strategy) {
case MulticlassStrategy::ONE_VS_REST:
return std::make_unique<OneVsRestStrategy>();
case MulticlassStrategy::ONE_VS_ONE:
return std::make_unique<OneVsOneStrategy>();
default:
throw std::invalid_argument("Unknown multiclass strategy");
}
}

} // namespace svm_classifier
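Note: the one-vs-one strategy added above trains one binary classifier per pair of classes and predicts by voting: a positive decision value is a vote for the pair's first class, a negative one for the second, and the class with the most votes wins. For three classes with pairs (0,1), (0,2), (1,2), decisions {+0.5, +0.3, -0.8} give votes {2, 0, 1}, so class 0 is predicted. A compact sketch of that rule, decoupled from the SVM backends; unlike the real code, the pairs here already hold class indices rather than labels:

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// decisions[i] is the decision value of the classifier trained on class_pairs[i];
// returns the class index that collects the most pairwise votes.
int vote(const std::vector<double>& decisions,
         const std::vector<std::pair<int, int>>& class_pairs,
         int n_classes) {
    std::vector<int> votes(n_classes, 0);
    for (std::size_t i = 0; i < class_pairs.size(); ++i) {
        // Positive margin -> first class of the pair, otherwise the second.
        int winner = decisions[i] > 0 ? class_pairs[i].first : class_pairs[i].second;
        ++votes[winner];
    }
    return static_cast<int>(std::distance(
        votes.begin(), std::max_element(votes.begin(), votes.end())));
}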
@@ -23,16 +23,31 @@ target_link_libraries(svm_classifier_tests
PRIVATE
svm_classifier
Catch2::Catch2WithMain
nlohmann_json::nlohmann_json
)

# Set include directories
# Set include directories - Handle external libraries dynamically
target_include_directories(svm_classifier_tests
PRIVATE
${CMAKE_SOURCE_DIR}/include
${CMAKE_SOURCE_DIR}/external/libsvm
${CMAKE_SOURCE_DIR}/external/liblinear
)

# Add libsvm include directory if available
if(EXISTS "${CMAKE_CURRENT_BINARY_DIR}/../_deps/libsvm-src")
target_include_directories(svm_classifier_tests
PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../_deps/libsvm-src"
)
endif()

# Add liblinear include directories if available
if(EXISTS "${CMAKE_CURRENT_BINARY_DIR}/../_deps/liblinear-src")
target_include_directories(svm_classifier_tests
PRIVATE
"${CMAKE_CURRENT_BINARY_DIR}/../_deps/liblinear-src"
"${CMAKE_CURRENT_BINARY_DIR}/../_deps/liblinear-src/blas"
)
endif()

# Compiler flags for tests
target_compile_features(svm_classifier_tests PRIVATE cxx_std_17)
@@ -7,6 +7,10 @@
#include <svm_classifier/data_converter.hpp>
#include <torch/torch.h>

// Include the actual headers for complete struct definitions
#include "svm.h" // libsvm structures
#include "linear.h" // liblinear structures

using namespace svm_classifier;

TEST_CASE("DataConverter Basic Functionality", "[unit][data_converter]")
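Note: the library header only forward-declares svm_node and the other external structs because it just stores pointers to them; the tests include svm.h and linear.h directly, presumably because they inspect the converted structures, and dereferencing requires the complete type. A minimal illustration of the difference (Widget is purely illustrative):

struct Widget;                 // forward declaration: an incomplete type

Widget* keep(Widget* w) {      // fine: pointers to incomplete types are allowed
    return w;
}

// int broken(Widget* w) { return w->value; }  // error here: dereferencing needs the full definition

struct Widget { int value; };  // complete definition

int read(Widget* w) { return w->value; }       // OK once the type is complete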
@@ -10,7 +10,13 @@
#include <iostream>
#include <iomanip>

// Include the actual headers for complete struct definitions
#include "svm.h" // libsvm structures
#include "linear.h" // liblinear structures
#include <nlohmann/json.hpp>

using namespace svm_classifier;
using json = nlohmann::json;

/**
* @brief Generate large synthetic dataset for performance testing
@@ -283,7 +283,7 @@ TEST_CASE("SVMClassifier Prediction", "[integration][svm_classifier]")
REQUIRE(predictions.size(0) == X_test.size(0));

// Check that predictions are valid class labels
auto unique_preds = torch::unique(predictions);
auto unique_preds = std::get<0>(at::_unique(predictions));
for (int i = 0; i < unique_preds.size(0); ++i) {
int pred_class = unique_preds[i].item<int>();
auto classes = svm.get_classes();