4#include "kernel_parameters.hpp"
5#include "data_converter.hpp"
6#include "multiclass_strategy.hpp"
7#include <torch/torch.h>
8#include <nlohmann/json.hpp>
12namespace svm_classifier {
42 MulticlassStrategy multiclass_strategy = MulticlassStrategy::ONE_VS_REST);
85 torch::Tensor
predict(const torch::Tensor& X);
110 double score(const torch::Tensor& X, const torch::Tensor& y_true);
209 const torch::Tensor& y,
221 const torch::Tensor& y,
222 const nlohmann::json& param_grid,
238 KernelParameters params_;
239 std::unique_ptr<MulticlassStrategyBase> multiclass_strategy_;
240 std::unique_ptr<DataConverter> data_converter_;
252 void validate_input(
const torch::Tensor& X,
253 const torch::Tensor& y = torch::Tensor(),
254 bool check_fitted =
false);
259 void initialize_multiclass_strategy();
267 std::vector<std::vector<int>> calculate_confusion_matrix(
const std::vector<int>& y_true,
268 const std::vector<int>& y_pred);
275 std::tuple<double, double, double> calculate_metrics_from_confusion_matrix(
276 const std::vector<std::vector<int>>& confusion_matrix);
286 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
287 split_for_cv(
const torch::Tensor& X,
const torch::Tensor& y,
int fold,
int n_folds);
294 std::vector<nlohmann::json> generate_param_combinations(
const nlohmann::json& param_grid);
Support Vector Machine classifier with a scikit-learn-compatible API.
double score(const torch::Tensor &X, const torch::Tensor &y_true)
Calculate accuracy score on test data.
TrainingMetrics get_training_metrics() const
Get training metrics from last fit.
MulticlassStrategy get_multiclass_strategy() const
Get multiclass strategy.
SVMClassifier & operator=(const SVMClassifier &)=delete
Copy assignment operator (deleted: models are not copyable).
~SVMClassifier()
Destructor.
torch::Tensor get_feature_importance() const
Get feature importance (for linear kernels only)
SVMClassifier(const nlohmann::json &config)
Constructor with JSON parameters.
SVMClassifier(const SVMClassifier &)=delete
Copy constructor (deleted: models are not copyable).
SVMLibrary get_svm_library() const
Get SVM library being used.
EvaluationMetrics evaluate(const torch::Tensor &X, const torch::Tensor &y_true)
Calculate detailed evaluation metrics.
SVMClassifier()
Default constructor with default parameters.
bool supports_probability() const
Check if the current model supports probability prediction.
std::vector< double > cross_validate(const torch::Tensor &X, const torch::Tensor &y, int cv=5)
Perform cross-validation.
KernelType get_kernel_type() const
Get kernel type.
void load_model(const std::string &filename)
Load model from file.
torch::Tensor predict(const torch::Tensor &X)
Predict class labels for samples.
bool is_fitted() const
Check if the model is fitted/trained.
int get_n_classes() const
Get the number of classes.
int get_n_features() const
Get the number of features.
nlohmann::json get_parameters() const
Get current parameters as JSON.
TrainingMetrics fit(const torch::Tensor &X, const torch::Tensor &y)
Train the SVM classifier.
SVMClassifier(KernelType kernel, double C=1.0, MulticlassStrategy multiclass_strategy=MulticlassStrategy::ONE_VS_REST)
Constructor with explicit parameters.
void reset()
Reset the classifier, clearing any trained model.
torch::Tensor predict_proba(const torch::Tensor &X)
Predict class probabilities for samples.
void save_model(const std::string &filename) const
Save model to file.
torch::Tensor decision_function(const torch::Tensor &X)
Get decision function values.
void set_parameters(const nlohmann::json &config)
Set parameters from JSON configuration.
SVMClassifier(SVMClassifier &&) noexcept
Move constructor.
std::vector< int > get_classes() const
Get unique class labels.
nlohmann::json grid_search(const torch::Tensor &X, const torch::Tensor &y, const nlohmann::json &param_grid, int cv=5)
Find optimal hyperparameters using grid search.
Model evaluation metrics.
Training metrics structure.