SVM Classifier C++ 1.0.0
High-performance Support Vector Machine classifier with a scikit-learn-compatible API
Loading...
Searching...
No Matches
kernel_parameters.cpp
#include "svm_classifier/kernel_parameters.hpp"

#include <cmath>
#include <cstdint>
#include <limits>
#include <stdexcept>
4
5namespace svm_classifier {
6
7 KernelParameters::KernelParameters()
8 : kernel_type_(KernelType::LINEAR)
9 , multiclass_strategy_(MulticlassStrategy::ONE_VS_REST)
10 , C_(1.0)
11 , tolerance_(1e-3)
12 , max_iterations_(-1)
13 , probability_(false)
14 , gamma_(-1.0) // Auto gamma
15 , degree_(3)
16 , coef0_(0.0)
17 , cache_size_(200.0)
18 {
19 }
20
21 KernelParameters::KernelParameters(const nlohmann::json& config) : KernelParameters()
22 {
23 set_parameters(config);
24 }
25
26 void KernelParameters::set_parameters(const nlohmann::json& config)
27 {
28 // Set kernel type first as it affects validation
29 if (config.contains("kernel")) {
30 if (config["kernel"].is_string()) {
31 set_kernel_type(string_to_kernel_type(config["kernel"]));
32 } else {
33 throw std::invalid_argument("Kernel must be a string");
34 }
35 }
36
37 // Set multiclass strategy
38 if (config.contains("multiclass_strategy")) {
39 if (config["multiclass_strategy"].is_string()) {
40 set_multiclass_strategy(string_to_multiclass_strategy(config["multiclass_strategy"]));
41 } else {
42 throw std::invalid_argument("Multiclass strategy must be a string");
43 }
44 }
45
46 // Set common parameters
47 if (config.contains("C")) {
48 if (config["C"].is_number()) {
49 set_C(config["C"]);
50 } else {
51 throw std::invalid_argument("C must be a number");
52 }
53 }
54
55 if (config.contains("tolerance")) {
56 if (config["tolerance"].is_number()) {
57 set_tolerance(config["tolerance"]);
58 } else {
59 throw std::invalid_argument("Tolerance must be a number");
60 }
61 }
62
63 if (config.contains("max_iterations")) {
64 if (config["max_iterations"].is_number_integer()) {
65 set_max_iterations(config["max_iterations"]);
66 } else {
67 throw std::invalid_argument("Max iterations must be an integer");
68 }
69 }
70
71 if (config.contains("probability")) {
72 if (config["probability"].is_boolean()) {
73 set_probability(config["probability"]);
74 } else {
75 throw std::invalid_argument("Probability must be a boolean");
76 }
77 }
78
79 // Set kernel-specific parameters
80 if (config.contains("gamma")) {
81 if (config["gamma"].is_number()) {
82 set_gamma(config["gamma"]);
83 } else if (config["gamma"].is_string() && config["gamma"] == "auto") {
84 set_gamma(-1.0); // Auto gamma
85 } else {
86 throw std::invalid_argument("Gamma must be a number or 'auto'");
87 }
88 }
89
90 if (config.contains("degree")) {
91 if (config["degree"].is_number_integer()) {
92 set_degree(config["degree"]);
93 } else {
94 throw std::invalid_argument("Degree must be an integer");
95 }
96 }
97
98 if (config.contains("coef0")) {
99 if (config["coef0"].is_number()) {
100 set_coef0(config["coef0"]);
101 } else {
102 throw std::invalid_argument("Coef0 must be a number");
103 }
104 }
105
106 if (config.contains("cache_size")) {
107 if (config["cache_size"].is_number()) {
108 set_cache_size(config["cache_size"]);
109 } else {
110 throw std::invalid_argument("Cache size must be a number");
111 }
112 }
113
114 // Validate all parameters
115 validate();
116 }
117
118 nlohmann::json KernelParameters::get_parameters() const
119 {
120 nlohmann::json params = {
121 {"kernel", kernel_type_to_string(kernel_type_)},
122 {"multiclass_strategy", multiclass_strategy_to_string(multiclass_strategy_)},
123 {"C", C_},
124 {"tolerance", tolerance_},
125 {"max_iterations", max_iterations_},
126 {"probability", probability_},
127 {"cache_size", cache_size_}
128 };
129
130 // Add kernel-specific parameters
131 switch (kernel_type_) {
132 case KernelType::LINEAR:
133 // No additional parameters for linear kernel
134 break;
135
136 case KernelType::RBF:
137 params["gamma"] = is_gamma_auto() ? "auto" : gamma_;
138 break;
139
140 case KernelType::POLYNOMIAL:
141 params["degree"] = degree_;
142 params["gamma"] = is_gamma_auto() ? "auto" : gamma_;
143 params["coef0"] = coef0_;
144 break;
145
146 case KernelType::SIGMOID:
147 params["gamma"] = is_gamma_auto() ? "auto" : gamma_;
148 params["coef0"] = coef0_;
149 break;
150 }
151
152 return params;
153 }
154
155 void KernelParameters::set_kernel_type(KernelType kernel)
156 {
157 kernel_type_ = kernel;
158
159 // Reset kernel-specific parameters to defaults when kernel changes
160 auto defaults = get_default_parameters(kernel);
161
162 if (defaults.contains("gamma")) {
163 gamma_ = defaults["gamma"];
164 }
165 if (defaults.contains("degree")) {
166 degree_ = defaults["degree"];
167 }
168 if (defaults.contains("coef0")) {
169 coef0_ = defaults["coef0"];
170 }
171 }
172
173 void KernelParameters::set_C(double c)
174 {
175 if (c <= 0.0) {
176 throw std::invalid_argument("C must be positive (C > 0)");
177 }
178 C_ = c;
179 }
180
181 void KernelParameters::set_gamma(double gamma)
182 {
183 // Allow negative values for auto gamma (-1.0)
184 if (gamma > 0.0 || gamma == -1.0) {
185 gamma_ = gamma;
186 } else {
187 throw std::invalid_argument("Gamma must be positive or -1 for auto");
188 }
189 }
190
191 void KernelParameters::set_degree(int degree)
192 {
193 if (degree < 1) {
194 throw std::invalid_argument("Degree must be >= 1");
195 }
196 degree_ = degree;
197 }
198
199 void KernelParameters::set_coef0(double coef0)
200 {
201 coef0_ = coef0;
202 }
203
204 void KernelParameters::set_tolerance(double tol)
205 {
206 if (tol <= 0.0) {
207 throw std::invalid_argument("Tolerance must be positive (tolerance > 0)");
208 }
209 tolerance_ = tol;
210 }
211
212 void KernelParameters::set_max_iterations(int max_iter)
213 {
214 if (max_iter <= 0 && max_iter != -1) {
215 throw std::invalid_argument("Max iterations must be positive or -1 for no limit");
216 }
217 max_iterations_ = max_iter;
218 }
219
220 void KernelParameters::set_cache_size(double cache_size)
221 {
222 if (cache_size < 0.0) {
223 throw std::invalid_argument("Cache size must be non-negative");
224 }
225 cache_size_ = cache_size;
226 }
227
228 void KernelParameters::set_probability(bool probability)
229 {
230 probability_ = probability;
231 }
232
233 void KernelParameters::set_multiclass_strategy(MulticlassStrategy strategy)
234 {
235 multiclass_strategy_ = strategy;
236 }
237
238 void KernelParameters::validate() const
239 {
240 // Validate common parameters
241 if (C_ <= 0.0) {
242 throw std::invalid_argument("C must be positive");
243 }
244
245 if (tolerance_ <= 0.0) {
246 throw std::invalid_argument("Tolerance must be positive");
247 }
248
249 if (max_iterations_ <= 0 && max_iterations_ != -1) {
250 throw std::invalid_argument("Max iterations must be positive or -1");
251 }
252
253 if (cache_size_ < 0.0) {
254 throw std::invalid_argument("Cache size must be non-negative");
255 }
256
257 // Validate kernel-specific parameters
258 validate_kernel_parameters();
259 }
260
261 void KernelParameters::validate_kernel_parameters() const
262 {
263 switch (kernel_type_) {
264 case KernelType::LINEAR:
265 // Linear kernel has no additional parameters to validate
266 break;
267
268 case KernelType::RBF:
269 if (gamma_ > 0.0 || gamma_ == -1.0) {
270 // Valid gamma (positive or auto)
271 } else {
272 throw std::invalid_argument("RBF kernel gamma must be positive or auto (-1)");
273 }
274 break;
275
276 case KernelType::POLYNOMIAL:
277 if (degree_ < 1) {
278 throw std::invalid_argument("Polynomial degree must be >= 1");
279 }
280 if (gamma_ > 0.0 || gamma_ == -1.0) {
281 // Valid gamma
282 } else {
283 throw std::invalid_argument("Polynomial kernel gamma must be positive or auto (-1)");
284 }
285 // coef0 can be any real number
286 break;
287
288 case KernelType::SIGMOID:
289 if (gamma_ > 0.0 || gamma_ == -1.0) {
290 // Valid gamma
291 } else {
292 throw std::invalid_argument("Sigmoid kernel gamma must be positive or auto (-1)");
293 }
294 // coef0 can be any real number
295 break;
296 }
297 }
298
299 nlohmann::json KernelParameters::get_default_parameters(KernelType kernel)
300 {
301 nlohmann::json defaults = {
302 {"C", 1.0},
303 {"tolerance", 1e-3},
304 {"max_iterations", -1},
305 {"probability", false},
306 {"multiclass_strategy", "ovr"},
307 {"cache_size", 200.0}
308 };
309
310 switch (kernel) {
311 case KernelType::LINEAR:
312 defaults["kernel"] = "linear";
313 break;
314
315 case KernelType::RBF:
316 defaults["kernel"] = "rbf";
317 defaults["gamma"] = -1.0; // Auto gamma
318 break;
319
320 case KernelType::POLYNOMIAL:
321 defaults["kernel"] = "polynomial";
322 defaults["degree"] = 3;
323 defaults["gamma"] = -1.0; // Auto gamma
324 defaults["coef0"] = 0.0;
325 break;
326
327 case KernelType::SIGMOID:
328 defaults["kernel"] = "sigmoid";
329 defaults["gamma"] = -1.0; // Auto gamma
330 defaults["coef0"] = 0.0;
331 break;
332 }
333
334 return defaults;
335 }
336
337 void KernelParameters::reset_to_defaults()
338 {
339 auto defaults = get_default_parameters(kernel_type_);
340 set_parameters(defaults);
341 }
342
343 void KernelParameters::set_gamma_auto()
344 {
345 gamma_ = -1.0;
346 }
347
348} // namespace svm_classifier