1#include "svm_classifier/multiclass_strategy.hpp"
2#include "svm.h" // libsvm
3#include "linear.h" // liblinear
5#include <unordered_map>
6#include <unordered_set>
10namespace svm_classifier {
12 // OneVsRestStrategy Implementation
// Constructor member-initializer: library_type_ starts as LIBLINEAR;
// fit() later overwrites it with get_svm_library(<kernel type>).
14 : library_type_(SVMLibrary::LIBLINEAR)
// Destructor. NOTE(review): its body is not visible in this excerpt;
// presumably it releases the trained models (cleanup_models()) -- confirm.
18 OneVsRestStrategy::~OneVsRestStrategy()
// Fits the one-vs-rest ensemble on features X and class labels y.
//  1. Selects the backend (libsvm vs liblinear) from the kernel type.
//  2. Collects the distinct labels of y (read on the CPU) into classes_,
//     sorted ascending.
//  3. With <= 2 classes, trains a single binary classifier in model slot 0;
//     otherwise trains one "classes_[i] vs. rest" classifier in slot i.
// Returns TrainingMetrics with status SUCCESS and training_time set to the
// wall-clock duration of this call in seconds (millisecond resolution).
23 TrainingMetrics OneVsRestStrategy::fit(
const torch::Tensor& X,
24 const torch::Tensor& y,
25 const KernelParameters& params,
26 DataConverter& converter)
30 auto start_time = std::chrono::high_resolution_clock::now();
32 // Store parameters and determine library type
34 library_type_ = get_svm_library(params.get_kernel_type());
36 // Extract unique classes
37 auto y_cpu = y.to(torch::kCPU);
38 auto unique_classes_tensor = torch::unique(y_cpu);
// Copy each distinct label out as an int (labels are assumed integral;
// item<int>() would truncate anything else).
41 for (
int i = 0; i < unique_classes_tensor.size(0); ++i) {
42 classes_.push_back(unique_classes_tensor[i].item<int>());
45 std::sort(classes_.begin(), classes_.end());
47 // Handle binary classification case
48 if (classes_.size() <= 2) {
49 // For binary classification, train a single classifier
// NOTE(review): this resize runs before the size() == 1 test below, so that
// test appears unreachable. A single input class c leaves classes_ == {c, 0}
// (the new slot is value-initialized) instead of {c, c + 1}, and the
// dummy-sample path is never taken. Consider moving the resize after the
// single-class handling.
50 classes_.resize(2);
// Ensure we have exactly 2 classes
53 if (classes_.size() == 1) {
54 // Edge case: only one class, create dummy binary problem
55 classes_.push_back(classes_[0] + 1);
// Appends one synthetic all-zero sample labelled classes_[1] so that the
// training set contains two distinct labels. (binary_y is declared on a
// line not shown in this excerpt.)
56 binary_y = torch::cat({ y, torch::full({1}, classes_[1], y.options()) });
57 auto dummy_x = torch::zeros({ 1, X.size(1) }, X.options());
58 auto extended_X = torch::cat({ X, dummy_x });
60 double training_time = train_binary_classifier(extended_X, binary_y, params, converter, 0);
62 double training_time = train_binary_classifier(X, binary_y, params, converter, 0);
65 // Multiclass case: train one classifier per class
// One model slot per class, in the container of the selected backend.
66 if (library_type_ == SVMLibrary::LIBSVM) {
67 svm_models_.resize(classes_.size());
69 linear_models_.resize(classes_.size());
// NOTE(review): this per-model sum (like training_time above) is not what
// metrics.training_time reports in the visible lines; that value comes from
// the overall start/end timestamps.
72 double total_training_time = 0.0;
74 for (
size_t i = 0; i < classes_.size(); ++i) {
// +1 for samples of classes_[i], -1 for every other class.
75 auto binary_y = create_binary_labels(y, classes_[i]);
76 total_training_time += train_binary_classifier(X, binary_y, params, converter, i);
80 auto end_time = std::chrono::high_resolution_clock::now();
81 auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
85 TrainingMetrics metrics;
86 metrics.training_time = duration.count() / 1000.0;
87 metrics.status = TrainingStatus::SUCCESS;
// Predicts one label per row of X.
// For each row of decision_function(X), picks the column with the largest
// decision value and maps it back to classes_[column]. std::max_element
// returns the first maximum, so ties go to the lower class index.
// Throws std::runtime_error("Model is not trained"), presumably when the
// model has not been fitted (the guarding condition is on a line not shown).
92 std::vector<int> OneVsRestStrategy::predict(
const torch::Tensor& X, DataConverter& converter)
95 throw std::runtime_error(
"Model is not trained");
98 auto decision_values = decision_function(X, converter);
99 std::vector<int> predictions;
100 predictions.reserve(X.size(0));
102 for (
const auto& decision_row : decision_values) {
103 // Find the class with maximum decision value
104 auto max_it = std::max_element(decision_row.begin(), decision_row.end());
105 int predicted_class_idx = std::distance(decision_row.begin(), max_it);
106 predictions.push_back(classes_[predicted_class_idx]);
// Returns per-class probability-like scores for each row of X, in classes_
// order. Each score is prob_estimates[0] of that class's binary model, or
// 0.0 for an empty model slot. Each row is then rescaled using its sum, or
// replaced by a uniform 1/K row when every score is zero.
// Throws std::runtime_error if supports_probability() is false, or when the
// model is not trained (that condition is on a line not shown).
112 std::vector<std::vector<double>> OneVsRestStrategy::predict_proba(
const torch::Tensor& X,
113 DataConverter& converter)
115 if (!supports_probability()) {
116 throw std::runtime_error(
"Probability prediction not supported for current configuration");
120 throw std::runtime_error(
"Model is not trained");
123 std::vector<std::vector<double>> probabilities;
124 probabilities.reserve(X.size(0));
126 for (
int i = 0; i < X.size(0); ++i) {
// NOTE(review): `sample` is declared on a line not shown; presumably row i
// of X.
128 std::vector<double> sample_probs;
129 sample_probs.reserve(classes_.size());
// NOTE(review): in the <= 2-class path fit() trains only model slot 0, yet
// the loops below visit classes_.size() == 2 slots. Slot 1 is then either
// empty (contributing 0.0, so the row presumably normalises to [1, 0]
// whenever slot 0 is non-zero) or out of range if only one slot was
// allocated. Verify how fit() sizes the model vectors for binary problems.
// NOTE(review): the node conversion does not depend on j; it could be done
// once per row, above the per-class loop.
131 if (library_type_ == SVMLibrary::LIBSVM) {
132 for (
size_t j = 0; j < classes_.size(); ++j) {
133 if (svm_models_[j]) {
134 auto sample_node = converter.to_svm_node(sample);
135 double prob_estimates[2];
// prob_estimates follows the model's internal label order. Index 0 is the
// +1 class only if libsvm ordered the labels as {+1, -1}, which libsvm does
// for -1/+1 problems. Confirm for the bundled version.
136 svm_predict_probability(svm_models_[j].get(), sample_node, prob_estimates);
137 sample_probs.push_back(prob_estimates[0]);
// Probability of positive class
139 sample_probs.push_back(0.0);
143 for (
size_t j = 0; j < classes_.size(); ++j) {
144 if (linear_models_[j]) {
145 auto sample_node = converter.to_feature_node(sample);
146 double prob_estimates[2];
// liblinear only writes prob_estimates when check_probability_model() is
// true for the model (logistic-regression solvers); otherwise the array is
// left uninitialized.
147 predict_probability(linear_models_[j].get(), sample_node, prob_estimates);
148 sample_probs.push_back(prob_estimates[0]);
// Probability of positive class
150 sample_probs.push_back(0.0);
155 // Normalize probabilities
156 double sum = std::accumulate(sample_probs.begin(), sample_probs.end(), 0.0);
// Presumably divides each entry by `sum` (the loop body is on a line not
// shown).
158 for (
auto& prob : sample_probs) {
162 // Uniform distribution if all probabilities are zero
163 std::fill(sample_probs.begin(), sample_probs.end(), 1.0 / classes_.size());
166 probabilities.push_back(sample_probs);
169 return probabilities;
// Returns, for each row of X, one raw decision value per class in classes_
// order: the output of that class's binary model, or 0.0 for an empty slot.
// Throws std::runtime_error("Model is not trained"), presumably when the
// model has not been fitted (the guarding condition is on a line not shown).
172 std::vector<std::vector<double>> OneVsRestStrategy::decision_function(
const torch::Tensor& X,
173 DataConverter& converter)
176 throw std::runtime_error(
"Model is not trained");
179 std::vector<std::vector<double>> decision_values;
180 decision_values.reserve(X.size(0));
182 for (
int i = 0; i < X.size(0); ++i) {
// NOTE(review): `sample` is declared on a line not shown; presumably row i
// of X.
184 std::vector<double> sample_decisions;
185 sample_decisions.reserve(classes_.size());
187 if (library_type_ == SVMLibrary::LIBSVM) {
188 for (
size_t j = 0; j < classes_.size(); ++j) {
189 if (svm_models_[j]) {
// NOTE(review): loop-invariant conversion; it could be hoisted above the
// j loop and reused for every class model.
190 auto sample_node = converter.to_svm_node(sample);
191 double decision_value;
// libsvm writes nr_class*(nr_class-1)/2 values, i.e. exactly one for a
// model trained on two labels.
192 svm_predict_values(svm_models_[j].get(), sample_node, &decision_value);
193 sample_decisions.push_back(decision_value);
195 sample_decisions.push_back(0.0);
199 for (
size_t j = 0; j < classes_.size(); ++j) {
200 if (linear_models_[j]) {
201 auto sample_node = converter.to_feature_node(sample);
202 double decision_value;
// liblinear writes a single value for a two-class model trained with a
// non Crammer-Singer solver.
203 predict_values(linear_models_[j].get(), sample_node, &decision_value);
204 sample_decisions.push_back(decision_value);
206 sample_decisions.push_back(0.0);
211 decision_values.push_back(sample_decisions);
214 return decision_values;
// Reports whether predict_proba() may be called.
// The visible early return of params_.get_probability() is presumably taken
// when no model is trained yet (its condition is on a line not shown).
// Otherwise it appears to return true as soon as one trained model of the
// active backend carries probability information
// (svm_check_probability_model / check_probability_model).
// NOTE(review): liblinear's check_probability_model() is true only for
// logistic-regression solvers. train_binary_classifier uses
// L2R_L2LOSS_SVC_DUAL, so liblinear-backed models never qualify here.
217 bool OneVsRestStrategy::supports_probability()
const
220 return params_.get_probability();
223 // Check if any model supports probability
224 if (library_type_ == SVMLibrary::LIBSVM) {
225 for (
const auto& model : svm_models_) {
226 if (model && svm_check_probability_model(model.get())) {
231 for (
const auto& model : linear_models_) {
232 if (model && check_probability_model(model.get())) {
241 torch::Tensor OneVsRestStrategy::create_binary_labels(
const torch::Tensor& y,
int positive_class)
243 auto binary_labels = torch::ones_like(y) * (-1);
// Initialize with -1 (negative class)
244 auto positive_mask = (y == positive_class);
245 binary_labels.masked_fill_(positive_mask, 1);
// Set positive class to +1
247 return binary_labels;
// Trains one binary classifier on (X, y_binary) and stores it in model slot
// class_idx of the container for the active backend.
//  - libsvm: C_SVC with the RBF / POLY / SIGMOID kernel mapped from params;
//    any other kernel type throws std::runtime_error.
//  - liblinear: solver L2R_L2LOSS_SVC_DUAL with C and tolerance from params.
// Throws std::runtime_error if the library rejects the parameters or training
// returns no model. Returns the elapsed wall-clock time in seconds
// (millisecond resolution).
250 double OneVsRestStrategy::train_binary_classifier(
const torch::Tensor& X,
251 const torch::Tensor& y_binary,
252 const KernelParameters& params,
253 DataConverter& converter,
256 auto start_time = std::chrono::high_resolution_clock::now();
258 if (library_type_ == SVMLibrary::LIBSVM) {
// NOTE(review): libsvm's svm_model keeps pointers into the svm_problem's
// feature nodes. The storage behind `problem` must therefore outlive
// svm_models_[class_idx]. Confirm it is retained (e.g. by the converter)
// rather than released when this function returns.
260 auto problem = converter.to_svm_problem(X, y_binary);
262 // Setup SVM parameters
263 svm_parameter svm_params;
264 svm_params.svm_type = C_SVC;
// NOTE(review): each case must end in `break;` (those lines are not shown
// here); otherwise control falls through to the throw.
266 switch (params.get_kernel_type()) {
267 case KernelType::RBF:
268 svm_params.kernel_type = RBF;
270 case KernelType::POLYNOMIAL:
271 svm_params.kernel_type = POLY;
273 case KernelType::SIGMOID:
274 svm_params.kernel_type = SIGMOID;
277 throw std::runtime_error(
"Invalid kernel type for libsvm");
280 svm_params.degree = params.get_degree();
// A gamma of -1.0 is treated as "auto": 1 / number of feature columns.
281 svm_params.gamma = (params.get_gamma() == -1.0) ? 1.0 / X.size(1) : params.get_gamma();
282 svm_params.coef0 = params.get_coef0();
283 svm_params.cache_size = params.get_cache_size();
284 svm_params.eps = params.get_tolerance();
285 svm_params.C = params.get_C();
// No per-class weights.
286 svm_params.nr_weight = 0;
287 svm_params.weight_label =
nullptr;
288 svm_params.weight =
nullptr;
291 svm_params.shrinking = 1;
292 svm_params.probability = params.get_probability() ? 1 : 0;
295 const char* error_msg = svm_check_parameter(problem.get(), &svm_params);
297 throw std::runtime_error(
"SVM parameter error: " + std::string(error_msg));
301 auto model = svm_train(problem.get(), &svm_params);
303 throw std::runtime_error(
"Failed to train SVM model");
// NOTE(review): libsvm/liblinear allocate models with malloc, and
// std::unique_ptr<svm_model> uses the default deleter (`delete`). Releasing
// through the unique_ptr would be a mismatched deallocation. Models should
// only be freed via svm_free_and_destroy_model / free_and_destroy_model.
306 svm_models_[class_idx] = std::unique_ptr<svm_model>(model);
310 auto problem = converter.to_linear_problem(X, y_binary);
312 // Setup linear parameters
313 parameter linear_params;
314 linear_params.solver_type = L2R_L2LOSS_SVC_DUAL;
// Default solver for C-SVC
315 linear_params.C = params.get_C();
316 linear_params.eps = params.get_tolerance();
317 linear_params.nr_weight = 0;
318 linear_params.weight_label =
nullptr;
319 linear_params.weight =
nullptr;
// p (SVR epsilon) and nu (one-class SVM) are not used by this solver.
320 linear_params.p = 0.1;
321 linear_params.nu = 0.5;
322 linear_params.init_sol =
nullptr;
// NOTE(review): liblinear's check_parameter() rejects regularize_bias == 0
// unless prob->bias == 1 and the solver is one of L2R_LR, L2R_L2LOSS_SVC,
// L1R_L2LOSS_SVC, L1R_LR or L2R_L2LOSS_SVR. L2R_L2LOSS_SVC_DUAL is not in
// that list, so this setting looks like it makes every liblinear fit throw
// "Linear parameter error". liblinear's default is regularize_bias = 1;
// verify against the bundled liblinear version.
323 linear_params.regularize_bias = 0;
326 const char* error_msg = check_parameter(problem.get(), &linear_params);
328 throw std::runtime_error(
"Linear parameter error: " + std::string(error_msg));
332 auto model = train(problem.get(), &linear_params);
334 throw std::runtime_error(
"Failed to train linear model");
337 linear_models_[class_idx] = std::unique_ptr<::model>(model);
340 auto end_time = std::chrono::high_resolution_clock::now();
341 auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
343 return duration.count() / 1000.0;
// Releases every trained libsvm and liblinear model. Empties linear_models_
// (and presumably svm_models_, on a line not shown).
// NOTE(review): svm_free_and_destroy_model() / free_and_destroy_model() take
// svm_model** / struct model**. train_binary_classifier stores
// std::unique_ptr elements, so `&model` is a std::unique_ptr*, which does not
// convert. Confirm the element type, or release() into a raw local first so
// the unique_ptr does not free the model a second time.
346 void OneVsRestStrategy::cleanup_models()
348 for (
auto& model : svm_models_) {
350 svm_free_and_destroy_model(&model);
355 for (
auto& model : linear_models_) {
357 free_and_destroy_model(&model);
360 linear_models_.clear();
365 // OneVsOneStrategy Implementation
// Constructor: library_type_ starts as LIBLINEAR; fit() overwrites it with
// get_svm_library(<kernel type>).
366 OneVsOneStrategy::OneVsOneStrategy()
367 : library_type_(SVMLibrary::LIBLINEAR)
// Destructor. NOTE(review): body not visible here; presumably releases the
// trained pairwise models -- confirm.
371 OneVsOneStrategy::~OneVsOneStrategy()
// Fits the one-vs-one ensemble on features X and class labels y.
// Collects the distinct labels of y (read on the CPU) into sorted classes_.
// Builds every pair (classes_[i], classes_[j]) with i < j, which gives
// K * (K - 1) / 2 pairs for K classes. Trains one pairwise classifier per
// pair, in the model slot with the same index.
// Returns TrainingMetrics with status SUCCESS and training_time set to the
// wall-clock duration of this call in seconds (millisecond resolution).
// NOTE(review): with fewer than two distinct labels no pair is generated and
// no classifier is trained. Nothing visible here rejects such input.
376 TrainingMetrics OneVsOneStrategy::fit(
const torch::Tensor& X,
377 const torch::Tensor& y,
378 const KernelParameters& params,
379 DataConverter& converter)
383 auto start_time = std::chrono::high_resolution_clock::now();
385 // Store parameters and determine library type
387 library_type_ = get_svm_library(params.get_kernel_type());
389 // Extract unique classes
390 auto y_cpu = y.to(torch::kCPU);
391 auto unique_classes_tensor = torch::unique(y_cpu);
// Copy each distinct label out as an int (labels are assumed integral).
394 for (
int i = 0; i < unique_classes_tensor.size(0); ++i) {
395 classes_.push_back(unique_classes_tensor[i].item<int>());
398 std::sort(classes_.begin(), classes_.end());
400 // Generate all class pairs
401 class_pairs_.clear();
402 for (
size_t i = 0; i < classes_.size(); ++i) {
403 for (
size_t j = i + 1; j < classes_.size(); ++j) {
404 class_pairs_.emplace_back(classes_[i], classes_[j]);
408 // Initialize model storage
409 if (library_type_ == SVMLibrary::LIBSVM) {
410 svm_models_.resize(class_pairs_.size());
412 linear_models_.resize(class_pairs_.size());
// NOTE(review): this per-model sum is not what metrics.training_time
// reports in the visible lines; that value comes from start/end timestamps.
415 double total_training_time = 0.0;
417 // Train one classifier for each class pair
418 for (
size_t i = 0; i < class_pairs_.size(); ++i) {
419 auto [class1, class2] = class_pairs_[i];
420 total_training_time += train_pairwise_classifier(X, y, class1, class2, params, converter, i);
423 auto end_time = std::chrono::high_resolution_clock::now();
424 auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
428 TrainingMetrics metrics;
429 metrics.training_time = duration.count() / 1000.0;
430 metrics.status = TrainingStatus::SUCCESS;
// Predicts one label per row of X. Computes the pairwise decision values and
// delegates to vote_predictions() (not shown) to turn them into labels.
// Throws std::runtime_error("Model is not trained"), presumably when the
// model has not been fitted (the guarding condition is on a line not shown).
435 std::vector<int> OneVsOneStrategy::predict(
const torch::Tensor& X, DataConverter& converter)
438 throw std::runtime_error(
"Model is not trained");
441 auto decision_values = decision_function(X, converter);
442 return vote_predictions(decision_values);
// Returns vote-share scores per row of X, in classes_ order. Each pairwise
// classifier adds 1.0 to one of its two classes; the choice is presumably
// made by the sign of its decision value, on lines not shown. The counts are
// then rescaled using their sum, or replaced by a uniform 1/K row when there
// are no votes.
// NOTE(review): despite the comment below, decision magnitudes are not used,
// only vote counts, so these are not calibrated probabilities.
// NOTE(review): unlike the other entry points, no "Model is not trained" or
// supports-probability check is visible here; it relies on
// decision_function() to reject an untrained model.
445 std::vector<std::vector<double>> OneVsOneStrategy::predict_proba(
const torch::Tensor& X,
446 DataConverter& converter)
448 // OvO probability estimation is more complex and typically done via
449 // pairwise coupling (Hastie & Tibshirani, 1998)
450 // For simplicity, we'll use decision function values and normalize
452 auto decision_values = decision_function(X, converter);
453 std::vector<std::vector<double>> probabilities;
454 probabilities.reserve(X.size(0));
456 for (
const auto& decision_row : decision_values) {
457 std::vector<double> class_scores(classes_.size(), 0.0);
459 // Aggregate decision values for each class
460 for (
size_t i = 0; i < class_pairs_.size(); ++i) {
461 auto [class1, class2] = class_pairs_[i];
462 double decision = decision_row[i];
// NOTE(review): two linear searches per pair per row. fit() builds the pairs
// from sorted classes_ indices, so both indices could be stored once instead.
464 auto it1 = std::find(classes_.begin(), classes_.end(), class1);
465 auto it2 = std::find(classes_.begin(), classes_.end(), class2);
467 if (it1 != classes_.end() && it2 != classes_.end()) {
468 size_t idx1 = std::distance(classes_.begin(), it1);
469 size_t idx2 = std::distance(classes_.begin(), it2);
472 class_scores[idx1] += 1.0;
474 class_scores[idx2] += 1.0;
479 // Convert scores to probabilities
480 double sum = std::accumulate(class_scores.begin(), class_scores.end(), 0.0);
// Presumably divides each score by `sum` (loop body not shown).
482 for (
auto& score : class_scores) {
486 std::fill(class_scores.begin(), class_scores.end(), 1.0 / classes_.size());
489 probabilities.push_back(class_scores);
492 return probabilities;
495 std::vector<std::vector<double>> OneVsOneStrategy::decision_function(
const torch::Tensor& X,
OneVsRestStrategy()
Constructor.