4#include "kernel_parameters.hpp"
5#include "data_converter.hpp"
6#include <torch/torch.h>
9#include <unordered_map>
15namespace svm_classifier {
36 const torch::Tensor& y,
37 const KernelParameters& params,
46 virtual std::vector<int>
predict(
const torch::Tensor& X,
55 virtual std::vector<std::vector<double>>
predict_proba(
const torch::Tensor& X,
112 const torch::Tensor& y,
113 const KernelParameters& params,
116 std::vector<int>
predict(
const torch::Tensor& X,
131 MulticlassStrategy
get_strategy_type()
const override {
return MulticlassStrategy::ONE_VS_REST; }
134 std::vector<std::unique_ptr<svm_model>> svm_models_;
135 std::vector<std::unique_ptr<model>> linear_models_;
136 KernelParameters params_;
137 SVMLibrary library_type_;
145 torch::Tensor create_binary_labels(
const torch::Tensor& y,
int positive_class);
156 double train_binary_classifier(
const torch::Tensor& X,
157 const torch::Tensor& y_binary,
158 const KernelParameters& params,
/// Clean up the stored models.
/// NOTE(review): implementation not shown — presumably releases svm_models_ and
/// linear_models_; confirm.
void cleanup_models();
184 const torch::Tensor& y,
185 const KernelParameters& params,
188 std::vector<int>
predict(
const torch::Tensor& X,
203 MulticlassStrategy
get_strategy_type()
const override {
return MulticlassStrategy::ONE_VS_ONE; }
206 std::vector<std::unique_ptr<svm_model>> svm_models_;
207 std::vector<std::unique_ptr<model>> linear_models_;
208 std::vector<std::pair<int, int>> class_pairs_;
209 KernelParameters params_;
210 SVMLibrary library_type_;
220 std::pair<torch::Tensor, torch::Tensor> extract_binary_data(
const torch::Tensor& X,
221 const torch::Tensor& y,
236 double train_pairwise_classifier(
const torch::Tensor& X,
237 const torch::Tensor& y,
240 const KernelParameters& params,
/**
 * @brief Turn pairwise decision values into class predictions by voting.
 * @param decisions Decision values; presumably one row per sample and one
 *        column per entry of class_pairs_ — confirm.
 * @return One predicted class label per row of @p decisions.
 */
std::vector<int> vote_predictions(
    const std::vector<std::vector<double>>& decisions);

/// Clean up the stored models.
/// NOTE(review): implementation not shown — presumably releases svm_models_ and
/// linear_models_; confirm.
void cleanup_models();
262 std::unique_ptr<MulticlassStrategyBase> create_multiclass_strategy(MulticlassStrategy strategy);
Data converter between libtorch tensors and SVM library formats.
Abstract base class for multiclass classification strategies.
std::vector< int > classes_
Unique class labels.
virtual int get_n_classes() const =0
Get number of classes.
virtual bool supports_probability() const =0
Check if the model supports probability prediction.
virtual MulticlassStrategy get_strategy_type() const =0
Get strategy type.
virtual std::vector< int > get_classes() const =0
Get unique class labels.
virtual TrainingMetrics fit(const torch::Tensor &X, const torch::Tensor &y, const KernelParameters &params, DataConverter &converter)=0
Train the multiclass classifier.
virtual std::vector< int > predict(const torch::Tensor &X, DataConverter &converter)=0
Predict class labels.
virtual ~MulticlassStrategyBase()=default
Virtual destructor.
bool is_trained_
Whether the model is trained.
virtual std::vector< std::vector< double > > predict_proba(const torch::Tensor &X, DataConverter &converter)=0
Predict class probabilities.
virtual std::vector< std::vector< double > > decision_function(const torch::Tensor &X, DataConverter &converter)=0
Get decision function values.
One-vs-One (OvO) multiclass strategy.
int get_n_classes() const override
Get number of classes.
std::vector< int > get_classes() const override
Get unique class labels.
std::vector< std::vector< double > > decision_function(const torch::Tensor &X, DataConverter &converter) override
Get decision function values.
OneVsOneStrategy()
Constructor.
bool supports_probability() const override
Check if the model supports probability prediction.
MulticlassStrategy get_strategy_type() const override
Get strategy type.
std::vector< int > predict(const torch::Tensor &X, DataConverter &converter) override
Predict class labels.
std::vector< std::vector< double > > predict_proba(const torch::Tensor &X, DataConverter &converter) override
Predict class probabilities.
TrainingMetrics fit(const torch::Tensor &X, const torch::Tensor &y, const KernelParameters &params, DataConverter &converter) override
Train the multiclass classifier.
~OneVsOneStrategy() override
Destructor.
One-vs-Rest (OvR) multiclass strategy.
bool supports_probability() const override
Check if the model supports probability prediction.
OneVsRestStrategy()
Constructor.
int get_n_classes() const override
Get number of classes.
std::vector< std::vector< double > > predict_proba(const torch::Tensor &X, DataConverter &converter) override
Predict class probabilities.
std::vector< int > get_classes() const override
Get unique class labels.
std::vector< int > predict(const torch::Tensor &X, DataConverter &converter) override
Predict class labels.
std::vector< std::vector< double > > decision_function(const torch::Tensor &X, DataConverter &converter) override
Get decision function values.
TrainingMetrics fit(const torch::Tensor &X, const torch::Tensor &y, const KernelParameters &params, DataConverter &converter) override
Train the multiclass classifier.
~OneVsRestStrategy() override
Destructor.
MulticlassStrategy get_strategy_type() const override
Get strategy type.
Training metrics structure.