1#include "svm_classifier/kernel_parameters.hpp"
5namespace svm_classifier {
// Default constructor: starts with a linear kernel, the one-vs-rest
// multiclass strategy, and gamma = -1.0, the value set_gamma() accepts as the
// "auto gamma" sentinel.
// NOTE(review): this excerpt shows only some of the member initializers and no
// constructor body; confirm the remaining defaults agree with
// get_default_parameters().
7 KernelParameters::KernelParameters()
8 : kernel_type_(KernelType::LINEAR)
9 , multiclass_strategy_(MulticlassStrategy::ONE_VS_REST)
14 , gamma_(-1.0)
// Auto gamma
21 KernelParameters::KernelParameters(
const nlohmann::json& config) : KernelParameters()
23 set_parameters(config);
// Applies every recognized key present in `config` through the matching
// setter; keys that are absent are left untouched. Recognized keys:
// "kernel", "multiclass_strategy", "C", "tolerance", "max_iterations",
// "probability", "gamma", "degree", "coef0", "cache_size".
//
// Throws std::invalid_argument when a present key has the wrong JSON type
// (string for kernel/multiclass_strategy, number for C/tolerance/coef0/
// cache_size, integer for max_iterations/degree, boolean for probability,
// number or the string "auto" for gamma). Exceptions raised by the setters
// or the string_to_* converters are not caught here.
//
// NOTE(review): keys are applied one by one, so a failure part-way through
// leaves the earlier keys already written (no rollback).
// NOTE(review): this excerpt omits the statement in the "C" branch
// (presumably set_C) and the call under "Validate all parameters"
// (presumably validate()); confirm against the full file.
26 void KernelParameters::set_parameters(
const nlohmann::json& config)
28 // Set kernel type first as it affects validation
// (set_kernel_type() also resets gamma/degree/coef0 to the new kernel's
// defaults, so the kernel-specific keys below must be applied after it.)
29 if (config.contains(
"kernel")) {
30 if (config[
"kernel"].is_string()) {
31 set_kernel_type(string_to_kernel_type(config[
"kernel"]));
33 throw std::invalid_argument(
"Kernel must be a string");
37 // Set multiclass strategy
38 if (config.contains(
"multiclass_strategy")) {
39 if (config[
"multiclass_strategy"].is_string()) {
40 set_multiclass_strategy(string_to_multiclass_strategy(config[
"multiclass_strategy"]));
42 throw std::invalid_argument(
"Multiclass strategy must be a string");
46 // Set common parameters
47 if (config.contains(
"C")) {
48 if (config[
"C"].is_number()) {
51 throw std::invalid_argument(
"C must be a number");
55 if (config.contains(
"tolerance")) {
56 if (config[
"tolerance"].is_number()) {
57 set_tolerance(config[
"tolerance"]);
59 throw std::invalid_argument(
"Tolerance must be a number");
// max_iterations must be a JSON integer; a float such as 100.0 is rejected.
63 if (config.contains(
"max_iterations")) {
64 if (config[
"max_iterations"].is_number_integer()) {
65 set_max_iterations(config[
"max_iterations"]);
67 throw std::invalid_argument(
"Max iterations must be an integer");
71 if (config.contains(
"probability")) {
72 if (config[
"probability"].is_boolean()) {
73 set_probability(config[
"probability"]);
75 throw std::invalid_argument(
"Probability must be a boolean");
79 // Set kernel-specific parameters
// gamma: a number is passed to set_gamma() as-is; the string "auto" maps to
// the -1.0 sentinel; any other string is rejected.
80 if (config.contains(
"gamma")) {
81 if (config[
"gamma"].is_number()) {
82 set_gamma(config[
"gamma"]);
83 }
else if (config[
"gamma"].is_string() && config[
"gamma"] ==
"auto") {
84 set_gamma(-1.0);
// Auto gamma
86 throw std::invalid_argument(
"Gamma must be a number or 'auto'");
90 if (config.contains(
"degree")) {
91 if (config[
"degree"].is_number_integer()) {
92 set_degree(config[
"degree"]);
94 throw std::invalid_argument(
"Degree must be an integer");
98 if (config.contains(
"coef0")) {
99 if (config[
"coef0"].is_number()) {
100 set_coef0(config[
"coef0"]);
102 throw std::invalid_argument(
"Coef0 must be a number");
106 if (config.contains(
"cache_size")) {
107 if (config[
"cache_size"].is_number()) {
108 set_cache_size(config[
"cache_size"]);
110 throw std::invalid_argument(
"Cache size must be a number");
114 // Validate all parameters
// Serializes the current settings to a JSON object using the same keys that
// set_parameters() reads. The common entries are always emitted. The
// kernel-specific entries depend on kernel_type_:
//   LINEAR     - none
//   RBF        - "gamma"
//   POLYNOMIAL - "degree", "gamma", "coef0"
//   SIGMOID    - "gamma", "coef0"
// Auto gamma is meant to be written as the string "auto", which
// set_parameters() maps back to -1.0.
// NOTE(review): this excerpt does not show a "C" entry, the `break`
// statements, or the final `return params;` — confirm against the full file.
118 nlohmann::json KernelParameters::get_parameters()
const
120 nlohmann::json params = {
121 {
"kernel", kernel_type_to_string(kernel_type_)},
122 {
"multiclass_strategy", multiclass_strategy_to_string(multiclass_strategy_)},
124 {
"tolerance", tolerance_},
125 {
"max_iterations", max_iterations_},
126 {
"probability", probability_},
127 {
"cache_size", cache_size_}
130 // Add kernel-specific parameters
131 switch (kernel_type_) {
132 case KernelType::LINEAR:
133 // No additional parameters for linear kernel
// NOTE(review): `is_gamma_auto() ? "auto" : gamma_` (used in the three cases
// below) pairs a string literal (const char*) with a double. The conditional
// operator has no common type for those operands, so this is ill-formed as
// written. Make both arms JSON values instead, e.g.
// `is_gamma_auto() ? nlohmann::json("auto") : nlohmann::json(gamma_)`.
136 case KernelType::RBF:
137 params[
"gamma"] = is_gamma_auto() ?
"auto" : gamma_;
140 case KernelType::POLYNOMIAL:
141 params[
"degree"] = degree_;
142 params[
"gamma"] = is_gamma_auto() ?
"auto" : gamma_;
143 params[
"coef0"] = coef0_;
146 case KernelType::SIGMOID:
147 params[
"gamma"] = is_gamma_auto() ?
"auto" : gamma_;
148 params[
"coef0"] = coef0_;
155 void KernelParameters::set_kernel_type(KernelType kernel)
157 kernel_type_ = kernel;
159 // Reset kernel-specific parameters to defaults when kernel changes
160 auto defaults = get_default_parameters(kernel);
162 if (defaults.contains(
"gamma")) {
163 gamma_ = defaults[
"gamma"];
165 if (defaults.contains(
"degree")) {
166 degree_ = defaults[
"degree"];
168 if (defaults.contains(
"coef0")) {
169 coef0_ = defaults[
"coef0"];
// Sets the parameter C. Throws std::invalid_argument with
// "C must be positive (C > 0)" for a rejected value.
// NOTE(review): the guarding condition and the assignment are not visible in
// this excerpt; the error message indicates that c <= 0 is rejected.
173 void KernelParameters::set_C(
double c)
176 throw std::invalid_argument(
"C must be positive (C > 0)");
// Sets gamma. Accepts any value > 0, or exactly -1.0, which this file uses as
// the "auto gamma" sentinel. Every other value, including NaN (both
// comparisons are false for it), throws std::invalid_argument.
// NOTE(review): the assignment in the accepted branch is not visible in this
// excerpt.
181 void KernelParameters::set_gamma(
double gamma)
183 // Accept positive values, or exactly -1.0 as the auto-gamma sentinel
184 if (gamma > 0.0 || gamma == -1.0) {
187 throw std::invalid_argument(
"Gamma must be positive or -1 for auto");
// Sets the polynomial degree. Throws std::invalid_argument with
// "Degree must be >= 1" for a rejected value.
// NOTE(review): the condition and the assignment are not visible in this
// excerpt; the message indicates degree < 1 is rejected.
191 void KernelParameters::set_degree(
int degree)
194 throw std::invalid_argument(
"Degree must be >= 1");
// Sets coef0. validate_kernel_parameters() states that coef0 can be any real
// number, and this excerpt shows no validation here.
// NOTE(review): the body is not visible in this excerpt.
199 void KernelParameters::set_coef0(
double coef0)
// Sets the tolerance. Throws std::invalid_argument with
// "Tolerance must be positive (tolerance > 0)" for a rejected value.
// NOTE(review): the condition and the assignment are not visible in this
// excerpt; the message indicates tol <= 0 is rejected.
204 void KernelParameters::set_tolerance(
double tol)
207 throw std::invalid_argument(
"Tolerance must be positive (tolerance > 0)");
212 void KernelParameters::set_max_iterations(
int max_iter)
214 if (max_iter <= 0 && max_iter != -1) {
215 throw std::invalid_argument(
"Max iterations must be positive or -1 for no limit");
217 max_iterations_ = max_iter;
220 void KernelParameters::set_cache_size(
double cache_size)
222 if (cache_size < 0.0) {
223 throw std::invalid_argument(
"Cache size must be non-negative");
225 cache_size_ = cache_size;
228 void KernelParameters::set_probability(
bool probability)
230 probability_ = probability;
233 void KernelParameters::set_multiclass_strategy(MulticlassStrategy strategy)
235 multiclass_strategy_ = strategy;
// Checks the stored parameters and throws std::invalid_argument on the first
// violation. The order is: C, tolerance (must be > 0), max_iterations (> 0 or
// -1), cache_size (non-negative), then the kernel-specific checks in
// validate_kernel_parameters().
// NOTE(review): the condition guarding the "C must be positive" throw is not
// visible in this excerpt.
// NOTE(review): `<=` / `<` comparisons are false for NaN, so a NaN tolerance
// or cache size would pass these checks.
238 void KernelParameters::validate()
const
240 // Validate common parameters
242 throw std::invalid_argument(
"C must be positive");
245 if (tolerance_ <= 0.0) {
246 throw std::invalid_argument(
"Tolerance must be positive");
249 if (max_iterations_ <= 0 && max_iterations_ != -1) {
250 throw std::invalid_argument(
"Max iterations must be positive or -1");
253 if (cache_size_ < 0.0) {
254 throw std::invalid_argument(
"Cache size must be non-negative");
257 // Validate kernel-specific parameters
258 validate_kernel_parameters();
// Throws std::invalid_argument if a parameter used by the current kernel is
// out of range. For RBF, POLYNOMIAL and SIGMOID, gamma must be > 0 or exactly
// -1.0 (auto). POLYNOMIAL additionally requires degree >= 1 (per its error
// message). coef0 is unrestricted, and LINEAR has nothing to check.
// NOTE(review): the degree condition, the `else` branches and the `break`
// statements are not visible in this excerpt — confirm against the full file.
261 void KernelParameters::validate_kernel_parameters()
const
263 switch (kernel_type_) {
264 case KernelType::LINEAR:
265 // Linear kernel has no additional parameters to validate
268 case KernelType::RBF:
269 if (gamma_ > 0.0 || gamma_ == -1.0) {
270 // Valid gamma (positive or auto)
272 throw std::invalid_argument(
"RBF kernel gamma must be positive or auto (-1)");
276 case KernelType::POLYNOMIAL:
278 throw std::invalid_argument(
"Polynomial degree must be >= 1");
280 if (gamma_ > 0.0 || gamma_ == -1.0) {
283 throw std::invalid_argument(
"Polynomial kernel gamma must be positive or auto (-1)");
285 // coef0 can be any real number
288 case KernelType::SIGMOID:
289 if (gamma_ > 0.0 || gamma_ == -1.0) {
292 throw std::invalid_argument(
"Sigmoid kernel gamma must be positive or auto (-1)");
294 // coef0 can be any real number
// Returns the default configuration for `kernel` as a JSON object in the shape
// set_parameters() accepts. It always includes:
//   "max_iterations" = -1, "probability" = false,
//   "multiclass_strategy" = "ovr", "cache_size" = 200.0,
// plus "kernel" set to the kernel's name. Kernel-specific additions:
//   RBF        - "gamma" = -1.0 (auto)
//   POLYNOMIAL - "degree" = 3, "gamma" = -1.0 (auto), "coef0" = 0.0
//   SIGMOID    - "gamma" = -1.0 (auto), "coef0" = 0.0
// LINEAR adds no kernel-specific keys. That is why set_kernel_type() leaves
// gamma/degree/coef0 unchanged when switching to LINEAR.
// NOTE(review): two initializer entries (presumably "C" and "tolerance"), the
// `switch (kernel)` header, the `break`s and the return are not visible in
// this excerpt — confirm against the full file.
299 nlohmann::json KernelParameters::get_default_parameters(KernelType kernel)
301 nlohmann::json defaults = {
304 {
"max_iterations", -1},
305 {
"probability",
false},
306 {
"multiclass_strategy",
"ovr"},
307 {
"cache_size", 200.0}
311 case KernelType::LINEAR:
312 defaults[
"kernel"] =
"linear";
315 case KernelType::RBF:
316 defaults[
"kernel"] =
"rbf";
317 defaults[
"gamma"] = -1.0;
// Auto gamma
320 case KernelType::POLYNOMIAL:
321 defaults[
"kernel"] =
"polynomial";
322 defaults[
"degree"] = 3;
323 defaults[
"gamma"] = -1.0;
// Auto gamma
324 defaults[
"coef0"] = 0.0;
327 case KernelType::SIGMOID:
328 defaults[
"kernel"] =
"sigmoid";
329 defaults[
"gamma"] = -1.0;
// Auto gamma
330 defaults[
"coef0"] = 0.0;
337 void KernelParameters::reset_to_defaults()
339 auto defaults = get_default_parameters(kernel_type_);
340 set_parameters(defaults);
// Puts gamma into "auto" mode. Elsewhere in this file, auto gamma is stored as
// gamma_ == -1.0: the constructor default, the "auto" branch of
// set_parameters(), and the sentinel accepted by set_gamma().
// NOTE(review): the body is not visible in this excerpt — confirm it assigns
// the -1.0 sentinel.
343 void KernelParameters::set_gamma_auto()
348}
// namespace svm_classifier