Files
SVMClassifier/multiclass__strategy_8cpp_source.html
2025-06-22 11:25:27 +00:00

74 KiB

<html xmlns="http://www.w3.org/1999/xhtml" lang="en-US"> <head> <script type="text/javascript" src="jquery.js"></script> <script type="text/javascript" src="dynsections.js"></script> <script type="text/javascript" src="search/searchdata.js"></script> <script type="text/javascript" src="search/search.js"></script> </head>
SVM Classifier C++ 1.0.0
High-performance Support Vector Machine classifier with scikit-learn compatible API
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ var searchBox = new SearchBox("searchBox", "search/",'.html'); /* @license-end */ </script> <script type="text/javascript" src="menudata.js"></script> <script type="text/javascript" src="menu.js"></script> <script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function() { initMenu('',true,false,'search.php','Search'); $(document).ready(function() { init_search(); }); }); /* @license-end */ </script>
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(document).ready(function() { init_codefold(0); }); /* @license-end */ </script>
Loading...
Searching...
No Matches
multiclass_strategy.cpp
#include "svm_classifier/multiclass_strategy.hpp"
#include "svm.h"    // libsvm
#include "linear.h" // liblinear

#include <algorithm>
#include <chrono>
#include <cmath>
#include <numeric>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
9
10namespace svm_classifier {
11
12 // OneVsRestStrategy Implementation
14 : library_type_(SVMLibrary::LIBLINEAR)
15 {
16 }
17
18 OneVsRestStrategy::~OneVsRestStrategy()
19 {
20 cleanup_models();
21 }
22
23 TrainingMetrics OneVsRestStrategy::fit(const torch::Tensor& X,
24 const torch::Tensor& y,
25 const KernelParameters& params,
26 DataConverter& converter)
27 {
28 cleanup_models();
29
30 auto start_time = std::chrono::high_resolution_clock::now();
31
32 // Store parameters and determine library type
33 params_ = params;
34 library_type_ = get_svm_library(params.get_kernel_type());
35
36 // Extract unique classes
37 auto y_cpu = y.to(torch::kCPU);
38 auto unique_classes_tensor = torch::unique(y_cpu);
39 classes_.clear();
40
41 for (int i = 0; i < unique_classes_tensor.size(0); ++i) {
42 classes_.push_back(unique_classes_tensor[i].item<int>());
43 }
44
45 std::sort(classes_.begin(), classes_.end());
46
47 // Handle binary classification case
48 if (classes_.size() <= 2) {
49 // For binary classification, train a single classifier
50 classes_.resize(2); // Ensure we have exactly 2 classes
51
52 auto binary_y = y;
53 if (classes_.size() == 1) {
54 // Edge case: only one class, create dummy binary problem
55 classes_.push_back(classes_[0] + 1);
56 binary_y = torch::cat({ y, torch::full({1}, classes_[1], y.options()) });
57 auto dummy_x = torch::zeros({ 1, X.size(1) }, X.options());
58 auto extended_X = torch::cat({ X, dummy_x });
59
60 double training_time = train_binary_classifier(extended_X, binary_y, params, converter, 0);
61 } else {
62 double training_time = train_binary_classifier(X, binary_y, params, converter, 0);
63 }
64 } else {
65 // Multiclass case: train one classifier per class
66 if (library_type_ == SVMLibrary::LIBSVM) {
67 svm_models_.resize(classes_.size());
68 } else {
69 linear_models_.resize(classes_.size());
70 }
71
72 double total_training_time = 0.0;
73
74 for (size_t i = 0; i < classes_.size(); ++i) {
75 auto binary_y = create_binary_labels(y, classes_[i]);
76 total_training_time += train_binary_classifier(X, binary_y, params, converter, i);
77 }
78 }
79
80 auto end_time = std::chrono::high_resolution_clock::now();
81 auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
82
83 is_trained_ = true;
84
85 TrainingMetrics metrics;
86 metrics.training_time = duration.count() / 1000.0;
87 metrics.status = TrainingStatus::SUCCESS;
88
89 return metrics;
90 }
91
92 std::vector<int> OneVsRestStrategy::predict(const torch::Tensor& X, DataConverter& converter)
93 {
94 if (!is_trained_) {
95 throw std::runtime_error("Model is not trained");
96 }
97
98 auto decision_values = decision_function(X, converter);
99 std::vector<int> predictions;
100 predictions.reserve(X.size(0));
101
102 for (const auto& decision_row : decision_values) {
103 // Find the class with maximum decision value
104 auto max_it = std::max_element(decision_row.begin(), decision_row.end());
105 int predicted_class_idx = std::distance(decision_row.begin(), max_it);
106 predictions.push_back(classes_[predicted_class_idx]);
107 }
108
109 return predictions;
110 }
111
112 std::vector<std::vector<double>> OneVsRestStrategy::predict_proba(const torch::Tensor& X,
113 DataConverter& converter)
114 {
115 if (!supports_probability()) {
116 throw std::runtime_error("Probability prediction not supported for current configuration");
117 }
118
119 if (!is_trained_) {
120 throw std::runtime_error("Model is not trained");
121 }
122
123 std::vector<std::vector<double>> probabilities;
124 probabilities.reserve(X.size(0));
125
126 for (int i = 0; i < X.size(0); ++i) {
127 auto sample = X[i];
128 std::vector<double> sample_probs;
129 sample_probs.reserve(classes_.size());
130
131 if (library_type_ == SVMLibrary::LIBSVM) {
132 for (size_t j = 0; j < classes_.size(); ++j) {
133 if (svm_models_[j]) {
134 auto sample_node = converter.to_svm_node(sample);
135 double prob_estimates[2];
136 svm_predict_probability(svm_models_[j].get(), sample_node, prob_estimates);
137 sample_probs.push_back(prob_estimates[0]); // Probability of positive class
138 } else {
139 sample_probs.push_back(0.0);
140 }
141 }
142 } else {
143 for (size_t j = 0; j < classes_.size(); ++j) {
144 if (linear_models_[j]) {
145 auto sample_node = converter.to_feature_node(sample);
146 double prob_estimates[2];
147 predict_probability(linear_models_[j].get(), sample_node, prob_estimates);
148 sample_probs.push_back(prob_estimates[0]); // Probability of positive class
149 } else {
150 sample_probs.push_back(0.0);
151 }
152 }
153 }
154
155 // Normalize probabilities
156 double sum = std::accumulate(sample_probs.begin(), sample_probs.end(), 0.0);
157 if (sum > 0.0) {
158 for (auto& prob : sample_probs) {
159 prob /= sum;
160 }
161 } else {
162 // Uniform distribution if all probabilities are zero
163 std::fill(sample_probs.begin(), sample_probs.end(), 1.0 / classes_.size());
164 }
165
166 probabilities.push_back(sample_probs);
167 }
168
169 return probabilities;
170 }
171
172 std::vector<std::vector<double>> OneVsRestStrategy::decision_function(const torch::Tensor& X,
173 DataConverter& converter)
174 {
175 if (!is_trained_) {
176 throw std::runtime_error("Model is not trained");
177 }
178
179 std::vector<std::vector<double>> decision_values;
180 decision_values.reserve(X.size(0));
181
182 for (int i = 0; i < X.size(0); ++i) {
183 auto sample = X[i];
184 std::vector<double> sample_decisions;
185 sample_decisions.reserve(classes_.size());
186
187 if (library_type_ == SVMLibrary::LIBSVM) {
188 for (size_t j = 0; j < classes_.size(); ++j) {
189 if (svm_models_[j]) {
190 auto sample_node = converter.to_svm_node(sample);
191 double decision_value;
192 svm_predict_values(svm_models_[j].get(), sample_node, &decision_value);
193 sample_decisions.push_back(decision_value);
194 } else {
195 sample_decisions.push_back(0.0);
196 }
197 }
198 } else {
199 for (size_t j = 0; j < classes_.size(); ++j) {
200 if (linear_models_[j]) {
201 auto sample_node = converter.to_feature_node(sample);
202 double decision_value;
203 predict_values(linear_models_[j].get(), sample_node, &decision_value);
204 sample_decisions.push_back(decision_value);
205 } else {
206 sample_decisions.push_back(0.0);
207 }
208 }
209 }
210
211 decision_values.push_back(sample_decisions);
212 }
213
214 return decision_values;
215 }
216
217 bool OneVsRestStrategy::supports_probability() const
218 {
219 if (!is_trained_) {
220 return params_.get_probability();
221 }
222
223 // Check if any model supports probability
224 if (library_type_ == SVMLibrary::LIBSVM) {
225 for (const auto& model : svm_models_) {
226 if (model && svm_check_probability_model(model.get())) {
227 return true;
228 }
229 }
230 } else {
231 for (const auto& model : linear_models_) {
232 if (model && check_probability_model(model.get())) {
233 return true;
234 }
235 }
236 }
237
238 return false;
239 }
240
241 torch::Tensor OneVsRestStrategy::create_binary_labels(const torch::Tensor& y, int positive_class)
242 {
243 auto binary_labels = torch::ones_like(y) * (-1); // Initialize with -1 (negative class)
244 auto positive_mask = (y == positive_class);
245 binary_labels.masked_fill_(positive_mask, 1); // Set positive class to +1
246
247 return binary_labels;
248 }
249
250 double OneVsRestStrategy::train_binary_classifier(const torch::Tensor& X,
251 const torch::Tensor& y_binary,
252 const KernelParameters& params,
253 DataConverter& converter,
254 int class_idx)
255 {
256 auto start_time = std::chrono::high_resolution_clock::now();
257
258 if (library_type_ == SVMLibrary::LIBSVM) {
259 // Use libsvm
260 auto problem = converter.to_svm_problem(X, y_binary);
261
262 // Setup SVM parameters
263 svm_parameter svm_params;
264 svm_params.svm_type = C_SVC;
265
266 switch (params.get_kernel_type()) {
267 case KernelType::RBF:
268 svm_params.kernel_type = RBF;
269 break;
270 case KernelType::POLYNOMIAL:
271 svm_params.kernel_type = POLY;
272 break;
273 case KernelType::SIGMOID:
274 svm_params.kernel_type = SIGMOID;
275 break;
276 default:
277 throw std::runtime_error("Invalid kernel type for libsvm");
278 }
279
280 svm_params.degree = params.get_degree();
281 svm_params.gamma = (params.get_gamma() == -1.0) ? 1.0 / X.size(1) : params.get_gamma();
282 svm_params.coef0 = params.get_coef0();
283 svm_params.cache_size = params.get_cache_size();
284 svm_params.eps = params.get_tolerance();
285 svm_params.C = params.get_C();
286 svm_params.nr_weight = 0;
287 svm_params.weight_label = nullptr;
288 svm_params.weight = nullptr;
289 svm_params.nu = 0.5;
290 svm_params.p = 0.1;
291 svm_params.shrinking = 1;
292 svm_params.probability = params.get_probability() ? 1 : 0;
293
294 // Check parameters
295 const char* error_msg = svm_check_parameter(problem.get(), &svm_params);
296 if (error_msg) {
297 throw std::runtime_error("SVM parameter error: " + std::string(error_msg));
298 }
299
300 // Train model
301 auto model = svm_train(problem.get(), &svm_params);
302 if (!model) {
303 throw std::runtime_error("Failed to train SVM model");
304 }
305
306 svm_models_[class_idx] = std::unique_ptr<svm_model>(model);
307
308 } else {
309 // Use liblinear
310 auto problem = converter.to_linear_problem(X, y_binary);
311
312 // Setup linear parameters
313 parameter linear_params;
314 linear_params.solver_type = L2R_L2LOSS_SVC_DUAL; // Default solver for C-SVC
315 linear_params.C = params.get_C();
316 linear_params.eps = params.get_tolerance();
317 linear_params.nr_weight = 0;
318 linear_params.weight_label = nullptr;
319 linear_params.weight = nullptr;
320 linear_params.p = 0.1;
321 linear_params.nu = 0.5;
322 linear_params.init_sol = nullptr;
323 linear_params.regularize_bias = 0;
324
325 // Check parameters
326 const char* error_msg = check_parameter(problem.get(), &linear_params);
327 if (error_msg) {
328 throw std::runtime_error("Linear parameter error: " + std::string(error_msg));
329 }
330
331 // Train model
332 auto model = train(problem.get(), &linear_params);
333 if (!model) {
334 throw std::runtime_error("Failed to train linear model");
335 }
336
337 linear_models_[class_idx] = std::unique_ptr<::model>(model);
338 }
339
340 auto end_time = std::chrono::high_resolution_clock::now();
341 auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
342
343 return duration.count() / 1000.0;
344 }
345
346 void OneVsRestStrategy::cleanup_models()
347 {
348 for (auto& model : svm_models_) {
349 if (model) {
350 svm_free_and_destroy_model(&model);
351 }
352 }
353 svm_models_.clear();
354
355 for (auto& model : linear_models_) {
356 if (model) {
357 free_and_destroy_model(&model);
358 }
359 }
360 linear_models_.clear();
361
362 is_trained_ = false;
363 }
364
365 // OneVsOneStrategy Implementation
366 OneVsOneStrategy::OneVsOneStrategy()
367 : library_type_(SVMLibrary::LIBLINEAR)
368 {
369 }
370
371 OneVsOneStrategy::~OneVsOneStrategy()
372 {
373 cleanup_models();
374 }
375
376 TrainingMetrics OneVsOneStrategy::fit(const torch::Tensor& X,
377 const torch::Tensor& y,
378 const KernelParameters& params,
379 DataConverter& converter)
380 {
381 cleanup_models();
382
383 auto start_time = std::chrono::high_resolution_clock::now();
384
385 // Store parameters and determine library type
386 params_ = params;
387 library_type_ = get_svm_library(params.get_kernel_type());
388
389 // Extract unique classes
390 auto y_cpu = y.to(torch::kCPU);
391 auto unique_classes_tensor = torch::unique(y_cpu);
392 classes_.clear();
393
394 for (int i = 0; i < unique_classes_tensor.size(0); ++i) {
395 classes_.push_back(unique_classes_tensor[i].item<int>());
396 }
397
398 std::sort(classes_.begin(), classes_.end());
399
400 // Generate all class pairs
401 class_pairs_.clear();
402 for (size_t i = 0; i < classes_.size(); ++i) {
403 for (size_t j = i + 1; j < classes_.size(); ++j) {
404 class_pairs_.emplace_back(classes_[i], classes_[j]);
405 }
406 }
407
408 // Initialize model storage
409 if (library_type_ == SVMLibrary::LIBSVM) {
410 svm_models_.resize(class_pairs_.size());
411 } else {
412 linear_models_.resize(class_pairs_.size());
413 }
414
415 double total_training_time = 0.0;
416
417 // Train one classifier for each class pair
418 for (size_t i = 0; i < class_pairs_.size(); ++i) {
419 auto [class1, class2] = class_pairs_[i];
420 total_training_time += train_pairwise_classifier(X, y, class1, class2, params, converter, i);
421 }
422
423 auto end_time = std::chrono::high_resolution_clock::now();
424 auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
425
426 is_trained_ = true;
427
428 TrainingMetrics metrics;
429 metrics.training_time = duration.count() / 1000.0;
430 metrics.status = TrainingStatus::SUCCESS;
431
432 return metrics;
433 }
434
435 std::vector<int> OneVsOneStrategy::predict(const torch::Tensor& X, DataConverter& converter)
436 {
437 if (!is_trained_) {
438 throw std::runtime_error("Model is not trained");
439 }
440
441 auto decision_values = decision_function(X, converter);
442 return vote_predictions(decision_values);
443 }
444
445 std::vector<std::vector<double>> OneVsOneStrategy::predict_proba(const torch::Tensor& X,
446 DataConverter& converter)
447 {
448 // OvO probability estimation is more complex and typically done via
449 // pairwise coupling (Hastie & Tibshirani, 1998)
450 // For simplicity, we'll use decision function values and normalize
451
452 auto decision_values = decision_function(X, converter);
453 std::vector<std::vector<double>> probabilities;
454 probabilities.reserve(X.size(0));
455
456 for (const auto& decision_row : decision_values) {
457 std::vector<double> class_scores(classes_.size(), 0.0);
458
459 // Aggregate decision values for each class
460 for (size_t i = 0; i < class_pairs_.size(); ++i) {
461 auto [class1, class2] = class_pairs_[i];
462 double decision = decision_row[i];
463
464 auto it1 = std::find(classes_.begin(), classes_.end(), class1);
465 auto it2 = std::find(classes_.begin(), classes_.end(), class2);
466
467 if (it1 != classes_.end() && it2 != classes_.end()) {
468 size_t idx1 = std::distance(classes_.begin(), it1);
469 size_t idx2 = std::distance(classes_.begin(), it2);
470
471 if (decision > 0) {
472 class_scores[idx1] += 1.0;
473 } else {
474 class_scores[idx2] += 1.0;
475 }
476 }
477 }
478
479 // Convert scores to probabilities
480 double sum = std::accumulate(class_scores.begin(), class_scores.end(), 0.0);
481 if (sum > 0.0) {
482 for (auto& score : class_scores) {
483 score /= sum;
484 }
485 } else {
486 std::fill(class_scores.begin(), class_scores.end(), 1.0 / classes_.size());
487 }
488
489 probabilities.push_back(class_scores);
490 }
491
492 return probabilities;
493 }
494
495 std::vector<std::vector<double>> OneVsOneStrategy::decision_function(const torch::Tensor& X,

Generated on Sun Jun 22 2025 11:25:27 for SVM Classifier C++ by doxygen 1.9.8 </html>