Initial commit as Claude developed it
Some checks failed
CI/CD Pipeline / Code Linting (push) Failing after 22s
CI/CD Pipeline / Build and Test (Debug, clang, ubuntu-latest) (push) Failing after 5m44s
CI/CD Pipeline / Build and Test (Debug, gcc, ubuntu-latest) (push) Failing after 5m33s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-20.04) (push) Failing after 6m12s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-latest) (push) Failing after 5m13s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-20.04) (push) Failing after 5m30s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-latest) (push) Failing after 5m33s
CI/CD Pipeline / Docker Build Test (push) Failing after 13s
CI/CD Pipeline / Performance Benchmarks (push) Has been skipped
CI/CD Pipeline / Build Documentation (push) Successful in 31s
CI/CD Pipeline / Create Release Package (push) Has been skipped
378
src/data_converter.cpp
Normal file
@@ -0,0 +1,378 @@
#include "svm_classifier/data_converter.hpp"
#include "svm.h"    // libsvm
#include "linear.h" // liblinear
#include <stdexcept>
#include <iostream>
#include <cmath>

namespace svm_classifier {

DataConverter::DataConverter()
    : n_features_(0)
    , n_samples_(0)
    , sparse_threshold_(1e-8)
{
}

DataConverter::~DataConverter()
{
    cleanup();
}

std::unique_ptr<svm_problem> DataConverter::to_svm_problem(const torch::Tensor& X,
                                                           const torch::Tensor& y)
{
    validate_tensors(X, y);

    auto X_cpu = ensure_cpu_tensor(X);

    n_samples_ = X_cpu.size(0);
    n_features_ = X_cpu.size(1);

    // Convert tensor data to svm_node structures
    svm_nodes_storage_ = tensor_to_svm_nodes(X_cpu);

    // Prepare pointers for svm_problem
    svm_x_space_.clear();
    svm_x_space_.reserve(n_samples_);

    for (auto& nodes : svm_nodes_storage_) {
        svm_x_space_.push_back(nodes.data());
    }

    // Extract labels if provided
    if (y.defined() && y.numel() > 0) {
        svm_y_space_ = extract_labels(y);
    } else {
        svm_y_space_.clear();
        svm_y_space_.resize(n_samples_, 0.0); // Dummy labels for prediction
    }

    // Create svm_problem
    auto problem = std::make_unique<svm_problem>();
    problem->l = n_samples_;
    problem->x = svm_x_space_.data();
    problem->y = svm_y_space_.data();

    return problem;
}

std::unique_ptr<problem> DataConverter::to_linear_problem(const torch::Tensor& X,
                                                          const torch::Tensor& y)
{
    validate_tensors(X, y);

    auto X_cpu = ensure_cpu_tensor(X);

    n_samples_ = X_cpu.size(0);
    n_features_ = X_cpu.size(1);

    // Convert tensor data to feature_node structures
    linear_nodes_storage_ = tensor_to_linear_nodes(X_cpu);

    // Prepare pointers for problem
    linear_x_space_.clear();
    linear_x_space_.reserve(n_samples_);

    for (auto& nodes : linear_nodes_storage_) {
        linear_x_space_.push_back(nodes.data());
    }

    // Extract labels if provided
    if (y.defined() && y.numel() > 0) {
        linear_y_space_ = extract_labels(y);
    } else {
        linear_y_space_.clear();
        linear_y_space_.resize(n_samples_, 0.0); // Dummy labels for prediction
    }

    // Create problem
    auto linear_problem = std::make_unique<problem>();
    linear_problem->l = n_samples_;
    linear_problem->n = n_features_;
    linear_problem->x = linear_x_space_.data();
    linear_problem->y = linear_y_space_.data();
    linear_problem->bias = -1; // No bias term by default

    return linear_problem;
}

svm_node* DataConverter::to_svm_node(const torch::Tensor& sample)
{
    validate_tensor_properties(sample, 1, "sample");

    auto sample_cpu = ensure_cpu_tensor(sample);
    single_svm_nodes_ = sample_to_svm_nodes(sample_cpu);

    return single_svm_nodes_.data();
}

feature_node* DataConverter::to_feature_node(const torch::Tensor& sample)
{
    validate_tensor_properties(sample, 1, "sample");

    auto sample_cpu = ensure_cpu_tensor(sample);
    single_linear_nodes_ = sample_to_linear_nodes(sample_cpu);

    return single_linear_nodes_.data();
}

torch::Tensor DataConverter::from_predictions(const std::vector<double>& predictions)
{
    auto options = torch::TensorOptions().dtype(torch::kInt32);
    auto tensor = torch::zeros({ static_cast<int64_t>(predictions.size()) }, options);

    for (size_t i = 0; i < predictions.size(); ++i) {
        tensor[i] = static_cast<int>(predictions[i]);
    }

    return tensor;
}

torch::Tensor DataConverter::from_probabilities(const std::vector<std::vector<double>>& probabilities)
{
    if (probabilities.empty()) {
        return torch::empty({ 0, 0 });
    }

    int n_samples = static_cast<int>(probabilities.size());
    int n_classes = static_cast<int>(probabilities[0].size());

    auto tensor = torch::zeros({ n_samples, n_classes }, torch::kFloat64);

    for (int i = 0; i < n_samples; ++i) {
        for (int j = 0; j < n_classes; ++j) {
            tensor[i][j] = probabilities[i][j];
        }
    }

    return tensor;
}

torch::Tensor DataConverter::from_decision_values(const std::vector<std::vector<double>>& decision_values)
{
    if (decision_values.empty()) {
        return torch::empty({ 0, 0 });
    }

    int n_samples = static_cast<int>(decision_values.size());
    int n_values = static_cast<int>(decision_values[0].size());

    auto tensor = torch::zeros({ n_samples, n_values }, torch::kFloat64);

    for (int i = 0; i < n_samples; ++i) {
        for (int j = 0; j < n_values; ++j) {
            tensor[i][j] = decision_values[i][j];
        }
    }

    return tensor;
}

void DataConverter::validate_tensors(const torch::Tensor& X, const torch::Tensor& y)
{
    validate_tensor_properties(X, 2, "X");

    if (y.defined() && y.numel() > 0) {
        validate_tensor_properties(y, 1, "y");

        // Check that number of samples match
        if (X.size(0) != y.size(0)) {
            throw std::invalid_argument(
                "Number of samples in X (" + std::to_string(X.size(0)) +
                ") does not match number of labels in y (" + std::to_string(y.size(0)) + ")"
            );
        }
    }

    // Check for reasonable dimensions
    if (X.size(0) == 0) {
        throw std::invalid_argument("X cannot have 0 samples");
    }

    if (X.size(1) == 0) {
        throw std::invalid_argument("X cannot have 0 features");
    }
}

void DataConverter::cleanup()
{
    svm_nodes_storage_.clear();
    svm_x_space_.clear();
    svm_y_space_.clear();

    linear_nodes_storage_.clear();
    linear_x_space_.clear();
    linear_y_space_.clear();

    single_svm_nodes_.clear();
    single_linear_nodes_.clear();

    n_features_ = 0;
    n_samples_ = 0;
}

std::vector<std::vector<svm_node>> DataConverter::tensor_to_svm_nodes(const torch::Tensor& X)
{
    std::vector<std::vector<svm_node>> nodes_storage;
    nodes_storage.reserve(X.size(0));

    for (int i = 0; i < X.size(0); ++i) {
        nodes_storage.push_back(sample_to_svm_nodes(X[i]));
    }

    return nodes_storage;
}

std::vector<std::vector<feature_node>> DataConverter::tensor_to_linear_nodes(const torch::Tensor& X)
{
    std::vector<std::vector<feature_node>> nodes_storage;
    nodes_storage.reserve(X.size(0));

    for (int i = 0; i < X.size(0); ++i) {
        nodes_storage.push_back(sample_to_linear_nodes(X[i]));
    }

    return nodes_storage;
}

std::vector<svm_node> DataConverter::sample_to_svm_nodes(const torch::Tensor& sample)
{
    std::vector<svm_node> nodes;

    auto sample_acc = sample.accessor<float, 1>();

    // Reserve space (worst case: all features are non-sparse)
    nodes.reserve(sample.size(0) + 1); // +1 for terminator

    for (int j = 0; j < sample.size(0); ++j) {
        double value = static_cast<double>(sample_acc[j]);

        // Skip sparse features
        if (std::abs(value) > sparse_threshold_) {
            svm_node node;
            node.index = j + 1; // libsvm uses 1-based indexing
            node.value = value;
            nodes.push_back(node);
        }
    }

    // Add terminator
    svm_node terminator;
    terminator.index = -1;
    terminator.value = 0;
    nodes.push_back(terminator);

    return nodes;
}

std::vector<feature_node> DataConverter::sample_to_linear_nodes(const torch::Tensor& sample)
{
    std::vector<feature_node> nodes;

    auto sample_acc = sample.accessor<float, 1>();

    // Reserve space (worst case: all features are non-sparse)
    nodes.reserve(sample.size(0) + 1); // +1 for terminator

    for (int j = 0; j < sample.size(0); ++j) {
        double value = static_cast<double>(sample_acc[j]);

        // Skip sparse features
        if (std::abs(value) > sparse_threshold_) {
            feature_node node;
            node.index = j + 1; // liblinear uses 1-based indexing
            node.value = value;
            nodes.push_back(node);
        }
    }

    // Add terminator
    feature_node terminator;
    terminator.index = -1;
    terminator.value = 0;
    nodes.push_back(terminator);

    return nodes;
}

std::vector<double> DataConverter::extract_labels(const torch::Tensor& y)
{
    auto y_cpu = ensure_cpu_tensor(y);
    std::vector<double> labels;
    labels.reserve(y_cpu.size(0));

    // Handle different tensor types
    if (y_cpu.dtype() == torch::kInt32) {
        auto y_acc = y_cpu.accessor<int32_t, 1>();
        for (int i = 0; i < y_cpu.size(0); ++i) {
            labels.push_back(static_cast<double>(y_acc[i]));
        }
    } else if (y_cpu.dtype() == torch::kInt64) {
        auto y_acc = y_cpu.accessor<int64_t, 1>();
        for (int i = 0; i < y_cpu.size(0); ++i) {
            labels.push_back(static_cast<double>(y_acc[i]));
        }
    } else if (y_cpu.dtype() == torch::kFloat32) {
        auto y_acc = y_cpu.accessor<float, 1>();
        for (int i = 0; i < y_cpu.size(0); ++i) {
            labels.push_back(static_cast<double>(y_acc[i]));
        }
    } else if (y_cpu.dtype() == torch::kFloat64) {
        auto y_acc = y_cpu.accessor<double, 1>();
        for (int i = 0; i < y_cpu.size(0); ++i) {
            labels.push_back(y_acc[i]);
        }
    } else {
        throw std::invalid_argument("Unsupported label tensor dtype");
    }

    return labels;
}

torch::Tensor DataConverter::ensure_cpu_tensor(const torch::Tensor& tensor)
{
    auto result = tensor;

    // Move to CPU if needed
    if (result.device().type() != torch::kCPU) {
        result = result.to(torch::kCPU);
    }

    // Convert to float32 if not already (the accessors above assume float32)
    if (result.dtype() != torch::kFloat32) {
        result = result.to(torch::kFloat32);
    }

    return result;
}

void DataConverter::validate_tensor_properties(const torch::Tensor& tensor,
                                               int expected_dims,
                                               const std::string& name)
{
    if (!tensor.defined()) {
        throw std::invalid_argument(name + " tensor is not defined");
    }

    if (tensor.dim() != expected_dims) {
        throw std::invalid_argument(
            name + " must have " + std::to_string(expected_dims) +
            " dimensions, got " + std::to_string(tensor.dim())
        );
    }

    if (tensor.numel() == 0) {
        throw std::invalid_argument(name + " tensor cannot be empty");
    }

    // Check for NaN or Inf values
    if (torch::any(torch::isnan(tensor)).item<bool>()) {
        throw std::invalid_argument(name + " contains NaN values");
    }

    if (torch::any(torch::isinf(tensor)).item<bool>()) {
        throw std::invalid_argument(name + " contains infinite values");
    }
}

} // namespace svm_classifier
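For orientation, here is a minimal usage sketch of the converter above. It is illustrative only and not part of the commit: it assumes the declarations in svm_classifier/data_converter.hpp match these definitions, and the svm_parameter setup shown is just one reasonable configuration.

#include "svm_classifier/data_converter.hpp"
#include "svm.h"
#include <torch/torch.h>
#include <iostream>

int main()
{
    auto X = torch::rand({ 10, 4 });                       // 10 samples, 4 features
    auto y = torch::randint(0, 2, { 10 }, torch::kInt64);  // binary labels

    svm_classifier::DataConverter converter;
    auto problem = converter.to_svm_problem(X, y);  // converter owns the node storage

    svm_parameter param{};  // zero-initialize, then fill the fields libsvm needs
    param.svm_type = C_SVC;
    param.kernel_type = RBF;
    param.gamma = 1.0 / X.size(1);
    param.C = 1.0;
    param.cache_size = 200;
    param.eps = 1e-3;

    if (svm_check_parameter(problem.get(), &param) != nullptr) {
        return 1;  // invalid configuration
    }

    svm_model* model = svm_train(problem.get(), &param);
    double pred = svm_predict(model, converter.to_svm_node(X[0]));  // predict one sample
    std::cout << "predicted label: " << pred << std::endl;

    svm_free_and_destroy_model(&model);
    return 0;
}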
348
src/kernel_parameters.cpp
Normal file
@@ -0,0 +1,348 @@
#include "svm_classifier/kernel_parameters.hpp"
#include <stdexcept>
#include <cmath>

namespace svm_classifier {

KernelParameters::KernelParameters()
    : kernel_type_(KernelType::LINEAR)
    , multiclass_strategy_(MulticlassStrategy::ONE_VS_REST)
    , C_(1.0)
    , tolerance_(1e-3)
    , max_iterations_(-1)
    , probability_(false)
    , gamma_(-1.0) // Auto gamma
    , degree_(3)
    , coef0_(0.0)
    , cache_size_(200.0)
{
}

KernelParameters::KernelParameters(const nlohmann::json& config) : KernelParameters()
{
    set_parameters(config);
}

void KernelParameters::set_parameters(const nlohmann::json& config)
{
    // Set kernel type first as it affects validation
    if (config.contains("kernel")) {
        if (config["kernel"].is_string()) {
            set_kernel_type(string_to_kernel_type(config["kernel"]));
        } else {
            throw std::invalid_argument("Kernel must be a string");
        }
    }

    // Set multiclass strategy
    if (config.contains("multiclass_strategy")) {
        if (config["multiclass_strategy"].is_string()) {
            set_multiclass_strategy(string_to_multiclass_strategy(config["multiclass_strategy"]));
        } else {
            throw std::invalid_argument("Multiclass strategy must be a string");
        }
    }

    // Set common parameters
    if (config.contains("C")) {
        if (config["C"].is_number()) {
            set_C(config["C"]);
        } else {
            throw std::invalid_argument("C must be a number");
        }
    }

    if (config.contains("tolerance")) {
        if (config["tolerance"].is_number()) {
            set_tolerance(config["tolerance"]);
        } else {
            throw std::invalid_argument("Tolerance must be a number");
        }
    }

    if (config.contains("max_iterations")) {
        if (config["max_iterations"].is_number_integer()) {
            set_max_iterations(config["max_iterations"]);
        } else {
            throw std::invalid_argument("Max iterations must be an integer");
        }
    }

    if (config.contains("probability")) {
        if (config["probability"].is_boolean()) {
            set_probability(config["probability"]);
        } else {
            throw std::invalid_argument("Probability must be a boolean");
        }
    }

    // Set kernel-specific parameters
    if (config.contains("gamma")) {
        if (config["gamma"].is_number()) {
            set_gamma(config["gamma"]);
        } else if (config["gamma"].is_string() && config["gamma"] == "auto") {
            set_gamma(-1.0); // Auto gamma
        } else {
            throw std::invalid_argument("Gamma must be a number or 'auto'");
        }
    }

    if (config.contains("degree")) {
        if (config["degree"].is_number_integer()) {
            set_degree(config["degree"]);
        } else {
            throw std::invalid_argument("Degree must be an integer");
        }
    }

    if (config.contains("coef0")) {
        if (config["coef0"].is_number()) {
            set_coef0(config["coef0"]);
        } else {
            throw std::invalid_argument("Coef0 must be a number");
        }
    }

    if (config.contains("cache_size")) {
        if (config["cache_size"].is_number()) {
            set_cache_size(config["cache_size"]);
        } else {
            throw std::invalid_argument("Cache size must be a number");
        }
    }

    // Validate all parameters
    validate();
}

nlohmann::json KernelParameters::get_parameters() const
{
    nlohmann::json params = {
        {"kernel", kernel_type_to_string(kernel_type_)},
        {"multiclass_strategy", multiclass_strategy_to_string(multiclass_strategy_)},
        {"C", C_},
        {"tolerance", tolerance_},
        {"max_iterations", max_iterations_},
        {"probability", probability_},
        {"cache_size", cache_size_}
    };

    // Add kernel-specific parameters (both ternary branches must be json,
    // otherwise the string/double mix does not compile)
    switch (kernel_type_) {
        case KernelType::LINEAR:
            // No additional parameters for linear kernel
            break;

        case KernelType::RBF:
            params["gamma"] = is_gamma_auto() ? nlohmann::json("auto") : nlohmann::json(gamma_);
            break;

        case KernelType::POLYNOMIAL:
            params["degree"] = degree_;
            params["gamma"] = is_gamma_auto() ? nlohmann::json("auto") : nlohmann::json(gamma_);
            params["coef0"] = coef0_;
            break;

        case KernelType::SIGMOID:
            params["gamma"] = is_gamma_auto() ? nlohmann::json("auto") : nlohmann::json(gamma_);
            params["coef0"] = coef0_;
            break;
    }

    return params;
}

void KernelParameters::set_kernel_type(KernelType kernel)
{
    kernel_type_ = kernel;

    // Reset kernel-specific parameters to defaults when kernel changes
    auto defaults = get_default_parameters(kernel);

    if (defaults.contains("gamma")) {
        gamma_ = defaults["gamma"];
    }
    if (defaults.contains("degree")) {
        degree_ = defaults["degree"];
    }
    if (defaults.contains("coef0")) {
        coef0_ = defaults["coef0"];
    }
}

void KernelParameters::set_C(double c)
{
    if (c <= 0.0) {
        throw std::invalid_argument("C must be positive (C > 0)");
    }
    C_ = c;
}

void KernelParameters::set_gamma(double gamma)
{
    // Allow negative values for auto gamma (-1.0)
    if (gamma > 0.0 || gamma == -1.0) {
        gamma_ = gamma;
    } else {
        throw std::invalid_argument("Gamma must be positive or -1 for auto");
    }
}

void KernelParameters::set_degree(int degree)
{
    if (degree < 1) {
        throw std::invalid_argument("Degree must be >= 1");
    }
    degree_ = degree;
}

void KernelParameters::set_coef0(double coef0)
{
    coef0_ = coef0;
}

void KernelParameters::set_tolerance(double tol)
{
    if (tol <= 0.0) {
        throw std::invalid_argument("Tolerance must be positive (tolerance > 0)");
    }
    tolerance_ = tol;
}

void KernelParameters::set_max_iterations(int max_iter)
{
    if (max_iter <= 0 && max_iter != -1) {
        throw std::invalid_argument("Max iterations must be positive or -1 for no limit");
    }
    max_iterations_ = max_iter;
}

void KernelParameters::set_cache_size(double cache_size)
{
    if (cache_size < 0.0) {
        throw std::invalid_argument("Cache size must be non-negative");
    }
    cache_size_ = cache_size;
}

void KernelParameters::set_probability(bool probability)
{
    probability_ = probability;
}

void KernelParameters::set_multiclass_strategy(MulticlassStrategy strategy)
{
    multiclass_strategy_ = strategy;
}

void KernelParameters::validate() const
{
    // Validate common parameters
    if (C_ <= 0.0) {
        throw std::invalid_argument("C must be positive");
    }

    if (tolerance_ <= 0.0) {
        throw std::invalid_argument("Tolerance must be positive");
    }

    if (max_iterations_ <= 0 && max_iterations_ != -1) {
        throw std::invalid_argument("Max iterations must be positive or -1");
    }

    if (cache_size_ < 0.0) {
        throw std::invalid_argument("Cache size must be non-negative");
    }

    // Validate kernel-specific parameters
    validate_kernel_parameters();
}

void KernelParameters::validate_kernel_parameters() const
{
    switch (kernel_type_) {
        case KernelType::LINEAR:
            // Linear kernel has no additional parameters to validate
            break;

        case KernelType::RBF:
            if (!(gamma_ > 0.0 || gamma_ == -1.0)) {
                throw std::invalid_argument("RBF kernel gamma must be positive or auto (-1)");
            }
            break;

        case KernelType::POLYNOMIAL:
            if (degree_ < 1) {
                throw std::invalid_argument("Polynomial degree must be >= 1");
            }
            if (!(gamma_ > 0.0 || gamma_ == -1.0)) {
                throw std::invalid_argument("Polynomial kernel gamma must be positive or auto (-1)");
            }
            // coef0 can be any real number
            break;

        case KernelType::SIGMOID:
            if (!(gamma_ > 0.0 || gamma_ == -1.0)) {
                throw std::invalid_argument("Sigmoid kernel gamma must be positive or auto (-1)");
            }
            // coef0 can be any real number
            break;
    }
}

nlohmann::json KernelParameters::get_default_parameters(KernelType kernel)
{
    nlohmann::json defaults = {
        {"C", 1.0},
        {"tolerance", 1e-3},
        {"max_iterations", -1},
        {"probability", false},
        {"multiclass_strategy", "ovr"},
        {"cache_size", 200.0}
    };

    switch (kernel) {
        case KernelType::LINEAR:
            defaults["kernel"] = "linear";
            break;

        case KernelType::RBF:
            defaults["kernel"] = "rbf";
            defaults["gamma"] = -1.0; // Auto gamma
            break;

        case KernelType::POLYNOMIAL:
            defaults["kernel"] = "polynomial";
            defaults["degree"] = 3;
            defaults["gamma"] = -1.0; // Auto gamma
            defaults["coef0"] = 0.0;
            break;

        case KernelType::SIGMOID:
            defaults["kernel"] = "sigmoid";
            defaults["gamma"] = -1.0; // Auto gamma
            defaults["coef0"] = 0.0;
            break;
    }

    return defaults;
}

void KernelParameters::reset_to_defaults()
{
    auto defaults = get_default_parameters(kernel_type_);
    set_parameters(defaults);
}

void KernelParameters::set_gamma_auto()
{
    gamma_ = -1.0;
}

} // namespace svm_classifier
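Similarly, a hedged configuration sketch for the parameters class above; it assumes string_to_kernel_type and string_to_multiclass_strategy accept the strings used in get_default_parameters (e.g. "rbf", "ovr"):

#include "svm_classifier/kernel_parameters.hpp"
#include <nlohmann/json.hpp>
#include <iostream>

int main()
{
    nlohmann::json config = {
        { "kernel", "rbf" },
        { "C", 10.0 },
        { "gamma", "auto" },               // "auto" is stored as the -1.0 sentinel
        { "probability", true },
        { "multiclass_strategy", "ovr" }
    };

    // The json constructor forwards to set_parameters(), which validates every field
    svm_classifier::KernelParameters params(config);

    // Round-trip: get_parameters() reports gamma as "auto" again
    std::cout << params.get_parameters().dump(2) << std::endl;
    return 0;
}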
495
src/multiclass_strategy.cpp
Normal file
@@ -0,0 +1,495 @@
#include "svm_classifier/multiclass_strategy.hpp"
#include "svm.h"    // libsvm
#include "linear.h" // liblinear
#include <algorithm>
#include <numeric>   // std::accumulate
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <chrono>
#include <cmath>

namespace svm_classifier {

// OneVsRestStrategy Implementation
OneVsRestStrategy::OneVsRestStrategy()
    : library_type_(SVMLibrary::LIBLINEAR)
{
}

OneVsRestStrategy::~OneVsRestStrategy()
{
    cleanup_models();
}

TrainingMetrics OneVsRestStrategy::fit(const torch::Tensor& X,
                                       const torch::Tensor& y,
                                       const KernelParameters& params,
                                       DataConverter& converter)
{
    cleanup_models();

    auto start_time = std::chrono::high_resolution_clock::now();

    // Store parameters and determine library type
    params_ = params;
    library_type_ = get_svm_library(params.get_kernel_type());

    // Extract unique classes
    auto y_cpu = y.to(torch::kCPU);
    auto unique_classes_tensor = torch::unique(y_cpu);
    classes_.clear();

    for (int i = 0; i < unique_classes_tensor.size(0); ++i) {
        classes_.push_back(unique_classes_tensor[i].item<int>());
    }

    std::sort(classes_.begin(), classes_.end());

    // Edge case: only one class present. Add a dummy second class (with a
    // single all-zero sample) so a proper binary problem can still be formed.
    auto X_train = X;
    auto y_train = y;
    if (classes_.size() == 1) {
        classes_.push_back(classes_[0] + 1);
        y_train = torch::cat({ y, torch::full({ 1 }, classes_[1], y.options()) });
        auto dummy_x = torch::zeros({ 1, X.size(1) }, X.options());
        X_train = torch::cat({ X, dummy_x });
    }

    // Train one binary (one-vs-rest) classifier per class
    if (library_type_ == SVMLibrary::LIBSVM) {
        svm_models_.resize(classes_.size());
    } else {
        linear_models_.resize(classes_.size());
    }

    double total_training_time = 0.0;

    for (size_t i = 0; i < classes_.size(); ++i) {
        auto binary_y = create_binary_labels(y_train, classes_[i]);
        total_training_time += train_binary_classifier(X_train, binary_y, params, converter, i);
    }

    auto end_time = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);

    is_trained_ = true;

    TrainingMetrics metrics;
    metrics.training_time = duration.count() / 1000.0;
    metrics.status = TrainingStatus::SUCCESS;

    return metrics;
}

std::vector<int> OneVsRestStrategy::predict(const torch::Tensor& X, DataConverter& converter)
{
    if (!is_trained_) {
        throw std::runtime_error("Model is not trained");
    }

    auto decision_values = decision_function(X, converter);
    std::vector<int> predictions;
    predictions.reserve(X.size(0));

    for (const auto& decision_row : decision_values) {
        // Find the class with maximum decision value
        auto max_it = std::max_element(decision_row.begin(), decision_row.end());
        int predicted_class_idx = std::distance(decision_row.begin(), max_it);
        predictions.push_back(classes_[predicted_class_idx]);
    }

    return predictions;
}

std::vector<std::vector<double>> OneVsRestStrategy::predict_proba(const torch::Tensor& X,
                                                                  DataConverter& converter)
{
    if (!supports_probability()) {
        throw std::runtime_error("Probability prediction not supported for current configuration");
    }

    if (!is_trained_) {
        throw std::runtime_error("Model is not trained");
    }

    std::vector<std::vector<double>> probabilities;
    probabilities.reserve(X.size(0));

    for (int i = 0; i < X.size(0); ++i) {
        auto sample = X[i];
        std::vector<double> sample_probs;
        sample_probs.reserve(classes_.size());

        if (library_type_ == SVMLibrary::LIBSVM) {
            for (size_t j = 0; j < classes_.size(); ++j) {
                if (svm_models_[j]) {
                    auto sample_node = converter.to_svm_node(sample);
                    double prob_estimates[2];
                    svm_predict_probability(svm_models_[j].get(), sample_node, prob_estimates);
                    sample_probs.push_back(prob_estimates[0]); // Probability of positive class
                } else {
                    sample_probs.push_back(0.0);
                }
            }
        } else {
            for (size_t j = 0; j < classes_.size(); ++j) {
                if (linear_models_[j]) {
                    auto sample_node = converter.to_feature_node(sample);
                    double prob_estimates[2];
                    predict_probability(linear_models_[j].get(), sample_node, prob_estimates);
                    sample_probs.push_back(prob_estimates[0]); // Probability of positive class
                } else {
                    sample_probs.push_back(0.0);
                }
            }
        }

        // Normalize probabilities
        double sum = std::accumulate(sample_probs.begin(), sample_probs.end(), 0.0);
        if (sum > 0.0) {
            for (auto& prob : sample_probs) {
                prob /= sum;
            }
        } else {
            // Uniform distribution if all probabilities are zero
            std::fill(sample_probs.begin(), sample_probs.end(), 1.0 / classes_.size());
        }

        probabilities.push_back(sample_probs);
    }

    return probabilities;
}

std::vector<std::vector<double>> OneVsRestStrategy::decision_function(const torch::Tensor& X,
                                                                      DataConverter& converter)
{
    if (!is_trained_) {
        throw std::runtime_error("Model is not trained");
    }

    std::vector<std::vector<double>> decision_values;
    decision_values.reserve(X.size(0));

    for (int i = 0; i < X.size(0); ++i) {
        auto sample = X[i];
        std::vector<double> sample_decisions;
        sample_decisions.reserve(classes_.size());

        if (library_type_ == SVMLibrary::LIBSVM) {
            for (size_t j = 0; j < classes_.size(); ++j) {
                if (svm_models_[j]) {
                    auto sample_node = converter.to_svm_node(sample);
                    double decision_value;
                    svm_predict_values(svm_models_[j].get(), sample_node, &decision_value);
                    sample_decisions.push_back(decision_value);
                } else {
                    sample_decisions.push_back(0.0);
                }
            }
        } else {
            for (size_t j = 0; j < classes_.size(); ++j) {
                if (linear_models_[j]) {
                    auto sample_node = converter.to_feature_node(sample);
                    double decision_value;
                    predict_values(linear_models_[j].get(), sample_node, &decision_value);
                    sample_decisions.push_back(decision_value);
                } else {
                    sample_decisions.push_back(0.0);
                }
            }
        }

        decision_values.push_back(sample_decisions);
    }

    return decision_values;
}

bool OneVsRestStrategy::supports_probability() const
{
    if (!is_trained_) {
        return params_.get_probability();
    }

    // Check if any model supports probability
    if (library_type_ == SVMLibrary::LIBSVM) {
        for (const auto& model : svm_models_) {
            if (model && svm_check_probability_model(model.get())) {
                return true;
            }
        }
    } else {
        for (const auto& model : linear_models_) {
            if (model && check_probability_model(model.get())) {
                return true;
            }
        }
    }

    return false;
}

torch::Tensor OneVsRestStrategy::create_binary_labels(const torch::Tensor& y, int positive_class)
{
    auto binary_labels = torch::ones_like(y) * (-1); // Initialize with -1 (negative class)
    auto positive_mask = (y == positive_class);
    binary_labels.masked_fill_(positive_mask, 1); // Set positive class to +1

    return binary_labels;
}

double OneVsRestStrategy::train_binary_classifier(const torch::Tensor& X,
                                                  const torch::Tensor& y_binary,
                                                  const KernelParameters& params,
                                                  DataConverter& converter,
                                                  int class_idx)
{
    auto start_time = std::chrono::high_resolution_clock::now();

    if (library_type_ == SVMLibrary::LIBSVM) {
        // Use libsvm
        auto problem = converter.to_svm_problem(X, y_binary);

        // Setup SVM parameters
        svm_parameter svm_params;
        svm_params.svm_type = C_SVC;

        switch (params.get_kernel_type()) {
            case KernelType::RBF:
                svm_params.kernel_type = RBF;
                break;
            case KernelType::POLYNOMIAL:
                svm_params.kernel_type = POLY;
                break;
            case KernelType::SIGMOID:
                svm_params.kernel_type = SIGMOID;
                break;
            default:
                throw std::runtime_error("Invalid kernel type for libsvm");
        }

        svm_params.degree = params.get_degree();
        svm_params.gamma = (params.get_gamma() == -1.0) ? 1.0 / X.size(1) : params.get_gamma();
        svm_params.coef0 = params.get_coef0();
        svm_params.cache_size = params.get_cache_size();
        svm_params.eps = params.get_tolerance();
        svm_params.C = params.get_C();
        svm_params.nr_weight = 0;
        svm_params.weight_label = nullptr;
        svm_params.weight = nullptr;
        svm_params.nu = 0.5;
        svm_params.p = 0.1;
        svm_params.shrinking = 1;
        svm_params.probability = params.get_probability() ? 1 : 0;

        // Check parameters
        const char* error_msg = svm_check_parameter(problem.get(), &svm_params);
        if (error_msg) {
            throw std::runtime_error("SVM parameter error: " + std::string(error_msg));
        }

        // Train model
        auto model = svm_train(problem.get(), &svm_params);
        if (!model) {
            throw std::runtime_error("Failed to train SVM model");
        }

        svm_models_[class_idx] = std::unique_ptr<svm_model>(model);

    } else {
        // Use liblinear
        auto problem = converter.to_linear_problem(X, y_binary);

        // Setup linear parameters
        parameter linear_params;
        linear_params.solver_type = L2R_L2LOSS_SVC_DUAL; // Default solver for C-SVC
        linear_params.C = params.get_C();
        linear_params.eps = params.get_tolerance();
        linear_params.nr_weight = 0;
        linear_params.weight_label = nullptr;
        linear_params.weight = nullptr;
        linear_params.p = 0.1;
        linear_params.nu = 0.5;
        linear_params.init_sol = nullptr;
        linear_params.regularize_bias = 0;

        // Check parameters
        const char* error_msg = check_parameter(problem.get(), &linear_params);
        if (error_msg) {
            throw std::runtime_error("Linear parameter error: " + std::string(error_msg));
        }

        // Train model
        auto model = train(problem.get(), &linear_params);
        if (!model) {
            throw std::runtime_error("Failed to train linear model");
        }

        linear_models_[class_idx] = std::unique_ptr<::model>(model);
    }

    auto end_time = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);

    return duration.count() / 1000.0;
}

void OneVsRestStrategy::cleanup_models()
{
    // libsvm/liblinear models must be released through the libraries' own
    // destroy functions, so detach them from the unique_ptrs first
    for (auto& model : svm_models_) {
        if (model) {
            auto* raw_model = model.release();
            svm_free_and_destroy_model(&raw_model);
        }
    }
    svm_models_.clear();

    for (auto& model : linear_models_) {
        if (model) {
            auto* raw_model = model.release();
            free_and_destroy_model(&raw_model);
        }
    }
    linear_models_.clear();

    is_trained_ = false;
}

// OneVsOneStrategy Implementation
OneVsOneStrategy::OneVsOneStrategy()
    : library_type_(SVMLibrary::LIBLINEAR)
{
}

OneVsOneStrategy::~OneVsOneStrategy()
{
    cleanup_models();
}

TrainingMetrics OneVsOneStrategy::fit(const torch::Tensor& X,
                                      const torch::Tensor& y,
                                      const KernelParameters& params,
                                      DataConverter& converter)
{
    cleanup_models();

    auto start_time = std::chrono::high_resolution_clock::now();

    // Store parameters and determine library type
    params_ = params;
    library_type_ = get_svm_library(params.get_kernel_type());

    // Extract unique classes
    auto y_cpu = y.to(torch::kCPU);
    auto unique_classes_tensor = torch::unique(y_cpu);
    classes_.clear();

    for (int i = 0; i < unique_classes_tensor.size(0); ++i) {
        classes_.push_back(unique_classes_tensor[i].item<int>());
    }

    std::sort(classes_.begin(), classes_.end());

    // Generate all class pairs
    class_pairs_.clear();
    for (size_t i = 0; i < classes_.size(); ++i) {
        for (size_t j = i + 1; j < classes_.size(); ++j) {
            class_pairs_.emplace_back(classes_[i], classes_[j]);
        }
    }

    // Initialize model storage
    if (library_type_ == SVMLibrary::LIBSVM) {
        svm_models_.resize(class_pairs_.size());
    } else {
        linear_models_.resize(class_pairs_.size());
    }

    double total_training_time = 0.0;

    // Train one classifier for each class pair
    for (size_t i = 0; i < class_pairs_.size(); ++i) {
        auto [class1, class2] = class_pairs_[i];
        total_training_time += train_pairwise_classifier(X, y, class1, class2, params, converter, i);
    }

    auto end_time = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);

    is_trained_ = true;

    TrainingMetrics metrics;
    metrics.training_time = duration.count() / 1000.0;
    metrics.status = TrainingStatus::SUCCESS;

    return metrics;
}

std::vector<int> OneVsOneStrategy::predict(const torch::Tensor& X, DataConverter& converter)
{
    if (!is_trained_) {
        throw std::runtime_error("Model is not trained");
    }

    auto decision_values = decision_function(X, converter);
    return vote_predictions(decision_values);
}

std::vector<std::vector<double>> OneVsOneStrategy::predict_proba(const torch::Tensor& X,
                                                                 DataConverter& converter)
{
    // OvO probability estimation is more complex and typically done via
    // pairwise coupling (Hastie & Tibshirani, 1998).
    // For simplicity, we'll use decision function values and normalize.

    auto decision_values = decision_function(X, converter);
    std::vector<std::vector<double>> probabilities;
    probabilities.reserve(X.size(0));

    for (const auto& decision_row : decision_values) {
        std::vector<double> class_scores(classes_.size(), 0.0);

        // Aggregate decision values for each class
        for (size_t i = 0; i < class_pairs_.size(); ++i) {
            auto [class1, class2] = class_pairs_[i];
            double decision = decision_row[i];

            auto it1 = std::find(classes_.begin(), classes_.end(), class1);
            auto it2 = std::find(classes_.begin(), classes_.end(), class2);

            if (it1 != classes_.end() && it2 != classes_.end()) {
                size_t idx1 = std::distance(classes_.begin(), it1);
                size_t idx2 = std::distance(classes_.begin(), it2);

                if (decision > 0) {
                    class_scores[idx1] += 1.0;
                } else {
                    class_scores[idx2] += 1.0;
                }
            }
        }

        // Convert scores to probabilities
        double sum = std::accumulate(class_scores.begin(), class_scores.end(), 0.0);
        if (sum > 0.0) {
            for (auto& score : class_scores) {
                score /= sum;
            }
        } else {
            std::fill(class_scores.begin(), class_scores.end(), 1.0 / classes_.size());
        }

        probabilities.push_back(class_scores);
    }

    return probabilities;
}

std::vector<std::vector<double>> OneVsOneStrategy::decision_function(const torch::Tensor& X,