SVM Classifier C++ 1.0.0
High-performance Support Vector Machine classifier with scikit-learn compatible API
svm_classifier::MulticlassStrategyBase Class Reference (abstract)

Abstract base class for multiclass classification strategies.

#include <multiclass_strategy.hpp>

Inheritance diagram for svm_classifier::MulticlassStrategyBase (image not reproduced): implemented by svm_classifier::OneVsRestStrategy and svm_classifier::OneVsOneStrategy.

Public Member Functions

virtual ~MulticlassStrategyBase ()=default
 Virtual destructor.
 
virtual TrainingMetrics fit (const torch::Tensor &X, const torch::Tensor &y, const KernelParameters &params, DataConverter &converter)=0
 Train the multiclass classifier.
 
virtual std::vector< int > predict (const torch::Tensor &X, DataConverter &converter)=0
 Predict class labels.
 
virtual std::vector< std::vector< double > > predict_proba (const torch::Tensor &X, DataConverter &converter)=0
 Predict class probabilities.
 
virtual std::vector< std::vector< double > > decision_function (const torch::Tensor &X, DataConverter &converter)=0
 Get decision function values.
 
virtual std::vector< int > get_classes () const =0
 Get unique class labels.
 
virtual bool supports_probability () const =0
 Check if the model supports probability prediction.
 
virtual int get_n_classes () const =0
 Get number of classes.
 
virtual MulticlassStrategy get_strategy_type () const =0
 Get strategy type.
 

Protected Attributes

std::vector< int > classes_
 Unique class labels.
 
bool is_trained_ = false
 Whether the model is trained.
 

Detailed Description

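Abstract base class for multiclass classification strategies. It declares fit, predict, predict_proba, decision_function and the accessors below as pure virtual functions, so it cannot be instantiated directly; svm_classifier::OneVsRestStrategy and svm_classifier::OneVsOneStrategy provide the concrete implementations.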

Definition at line 20 of file multiclass_strategy.hpp.
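Because every member of this class is pure virtual, callers normally hold a concrete strategy through a base-class reference or pointer. The sketch below illustrates that pattern with a deliberately simplified, hypothetical interface: ToyStrategyBase, MajorityStrategy and train_and_predict are invented for illustration, plain std::vector rows stand in for torch::Tensor, and the KernelParameters and DataConverter arguments are dropped so the example compiles without libtorch (C++17).

```cpp
#include <map>
#include <vector>

// Plain rows of doubles stand in for torch::Tensor in this sketch.
using Matrix = std::vector<std::vector<double>>;

// Hypothetical, simplified mirror of MulticlassStrategyBase.
class ToyStrategyBase {
public:
    virtual ~ToyStrategyBase() = default;
    virtual void fit(const Matrix& X, const std::vector<int>& y) = 0;
    virtual std::vector<int> predict(const Matrix& X) = 0;
    virtual std::vector<int> get_classes() const = 0;
    virtual int get_n_classes() const = 0;

protected:
    std::vector<int> classes_;  // unique class labels
    bool is_trained_ = false;   // set by fit()
};

// Hypothetical concrete strategy: always predicts the most frequent label.
class MajorityStrategy : public ToyStrategyBase {
public:
    void fit(const Matrix& /*X*/, const std::vector<int>& y) override {
        std::map<int, int> counts;
        for (int label : y) ++counts[label];
        classes_.clear();
        int best_count = -1;
        for (const auto& [label, count] : counts) {
            classes_.push_back(label);  // std::map iterates keys in sorted order
            if (count > best_count) {   // ties keep the smaller label
                best_count = count;
                majority_ = label;
            }
        }
        is_trained_ = true;
    }
    std::vector<int> predict(const Matrix& X) override {
        return std::vector<int>(X.size(), majority_);
    }
    std::vector<int> get_classes() const override { return classes_; }
    int get_n_classes() const override { return static_cast<int>(classes_.size()); }

private:
    int majority_ = 0;
};

// The call site depends only on the base class, so any concrete
// strategy can be passed in without changing this function.
std::vector<int> train_and_predict(ToyStrategyBase& strategy, const Matrix& X,
                                   const std::vector<int>& y) {
    strategy.fit(X, y);
    return strategy.predict(X);
}
```

In the real class the same call site would also pass KernelParameters and a DataConverter, and the object would be one of the two documented subclasses.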

Member Function Documentation

◆ decision_function()

virtual std::vector< std::vector< double > > svm_classifier::MulticlassStrategyBase::decision_function (const torch::Tensor &X, DataConverter &converter)
pure virtual

Get decision function values.

Parameters
X: Feature tensor of shape (n_samples, n_features)
converter: Data converter instance
Returns
Decision function values, one row per sample; the length of each row is determined by the concrete strategy

Implemented in svm_classifier::OneVsRestStrategy and svm_classifier::OneVsOneStrategy.
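If a strategy returns one score per class (the usual one-vs-rest layout), a label can be recovered by taking the argmax of each row. This is a sketch under that assumption: labels_from_ovr_scores is a hypothetical helper, and the column order is assumed to match get_classes().

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <stdexcept>
#include <vector>

// Assumes each row holds one score per class, in get_classes() order
// (an assumption, not confirmed by this reference).
std::vector<int> labels_from_ovr_scores(
    const std::vector<std::vector<double>>& scores,
    const std::vector<int>& classes) {
    std::vector<int> labels;
    labels.reserve(scores.size());
    for (const auto& row : scores) {
        if (row.empty() || row.size() != classes.size()) {
            throw std::invalid_argument("row length must equal the number of classes");
        }
        // The class whose binary classifier is most confident wins;
        // max_element returns the first maximum, so ties go to the lower index.
        auto best = std::max_element(row.begin(), row.end());
        labels.push_back(classes[static_cast<std::size_t>(std::distance(row.begin(), best))]);
    }
    return labels;
}
```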

◆ fit()

virtual TrainingMetrics svm_classifier::MulticlassStrategyBase::fit (const torch::Tensor &X, const torch::Tensor &y, const KernelParameters &params, DataConverter &converter)
pure virtual

Train the multiclass classifier.

Parameters
X: Feature tensor of shape (n_samples, n_features)
y: Target tensor of shape (n_samples,)
params: Kernel parameters
converter: Data converter instance
Returns
Training metrics

Implemented in svm_classifier::OneVsRestStrategy and svm_classifier::OneVsOneStrategy.
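Under the one-vs-rest decomposition, fitting amounts to solving one binary problem per class. The sketch below shows only the relabelling step, with the +1/-1 encoding as an assumption; ovr_binary_targets is a hypothetical helper, and the binary SVM training itself (done through params and converter in the real API) is omitted.

```cpp
#include <utility>
#include <vector>

// For each class c, samples labelled c become +1 and all others -1.
// Row i of the result holds the binary targets for classes[i].
std::vector<std::vector<double>> ovr_binary_targets(const std::vector<int>& y,
                                                    const std::vector<int>& classes) {
    std::vector<std::vector<double>> targets;
    targets.reserve(classes.size());
    for (int c : classes) {
        std::vector<double> t;
        t.reserve(y.size());
        for (int label : y) t.push_back(label == c ? 1.0 : -1.0);
        targets.push_back(std::move(t));
    }
    return targets;
}
```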

◆ get_classes()

virtual std::vector< int > svm_classifier::MulticlassStrategyBase::get_classes ( ) const
pure virtual

Get unique class labels.

Returns
Vector of unique class labels

Implemented in svm_classifier::OneVsRestStrategy and svm_classifier::OneVsOneStrategy.
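One common way to build the list returned by get_classes() is to sort the training labels and drop duplicates, which is also how scikit-learn orders its classes_ attribute. Whether this library sorts is an assumption; unique_classes and class_index are hypothetical helpers, and class_index relies on the list being sorted.

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Sorted, de-duplicated copy of the training labels.
std::vector<int> unique_classes(std::vector<int> y) {
    std::sort(y.begin(), y.end());
    y.erase(std::unique(y.begin(), y.end()), y.end());
    return y;
}

// Position of a label in a sorted class list (binary search).
std::size_t class_index(const std::vector<int>& classes, int label) {
    auto it = std::lower_bound(classes.begin(), classes.end(), label);
    if (it == classes.end() || *it != label) {
        throw std::out_of_range("label not seen during fit");
    }
    return static_cast<std::size_t>(it - classes.begin());
}
```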

◆ get_n_classes()

virtual int svm_classifier::MulticlassStrategyBase::get_n_classes ( ) const
pure virtual

Get number of classes.

Returns
Number of classes

Implemented in svm_classifier::OneVsRestStrategy and svm_classifier::OneVsOneStrategy.
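The class count k determines how many binary SVMs each decomposition trains: one per class for one-vs-rest and one per unordered pair of classes, k(k-1)/2, for one-vs-one. These formulas follow from the definitions of the two strategies rather than from this library's source, and the function names are hypothetical.

```cpp
// Binary models trained for k >= 3 classes. With k == 2, implementations
// commonly train a single binary model instead.
constexpr long long ovr_model_count(long long k) { return k; }
constexpr long long ovo_model_count(long long k) { return k * (k - 1) / 2; }
```

For k = 10 this gives 10 one-vs-rest models versus 45 one-vs-one models, although each one-vs-one model is trained only on the samples of its two classes.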

◆ get_strategy_type()

virtual MulticlassStrategy svm_classifier::MulticlassStrategyBase::get_strategy_type ( ) const
pure virtual

Get strategy type.

Returns
Multiclass strategy type

Implemented in svm_classifier::OneVsRestStrategy and svm_classifier::OneVsOneStrategy.
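The enumerators of MulticlassStrategy are not listed on this page, so the sketch below uses a hypothetical StrategyKind enum to show how a strategy type might be mapped to and from a configuration string. The short names "ovr" and "ovo" are borrowed from scikit-learn's decision_function_shape values; whether this library accepts them is an assumption.

```cpp
#include <stdexcept>
#include <string>
#include <string_view>

// Hypothetical stand-in for MulticlassStrategy.
enum class StrategyKind { OneVsRest, OneVsOne };

std::string strategy_name(StrategyKind kind) {
    switch (kind) {
        case StrategyKind::OneVsRest: return "ovr";
        case StrategyKind::OneVsOne:  return "ovo";
    }
    throw std::invalid_argument("unknown StrategyKind");
}

StrategyKind parse_strategy(std::string_view name) {
    if (name == "ovr") return StrategyKind::OneVsRest;
    if (name == "ovo") return StrategyKind::OneVsOne;
    throw std::invalid_argument("expected \"ovr\" or \"ovo\"");
}
```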

◆ predict()

virtual std::vector< int > svm_classifier::MulticlassStrategyBase::predict (const torch::Tensor &X, DataConverter &converter)
pure virtual

Predict class labels.

Parameters
X: Feature tensor of shape (n_samples, n_features)
converter: Data converter instance
Returns
Predicted class labels

Implemented in svm_classifier::OneVsRestStrategy and svm_classifier::OneVsOneStrategy.
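For one-vs-one, prediction is conventionally a vote over all pairwise classifiers. This sketch of the voting step for a single sample assumes libsvm's conventions (pair scores ordered (0,1), (0,2), ..., (k-2,k-1); a positive score votes for the first class of the pair; ties go to the lower index). ovo_vote is hypothetical, and OneVsOneStrategy may order pairs or break ties differently.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <stdexcept>
#include <vector>

// pair_scores holds one decision value per class pair (i, j) with i < j.
int ovo_vote(const std::vector<double>& pair_scores, const std::vector<int>& classes) {
    const std::size_t k = classes.size();
    if (k < 2 || pair_scores.size() != k * (k - 1) / 2) {
        throw std::invalid_argument("need k >= 2 classes and k*(k-1)/2 pair scores");
    }
    std::vector<int> votes(k, 0);
    std::size_t p = 0;  // index of the current pair in pair_scores
    for (std::size_t i = 0; i < k; ++i) {
        for (std::size_t j = i + 1; j < k; ++j, ++p) {
            ++votes[pair_scores[p] > 0.0 ? i : j];
        }
    }
    // max_element returns the first maximum, so ties go to the lower index.
    auto best = std::max_element(votes.begin(), votes.end());
    return classes[static_cast<std::size_t>(std::distance(votes.begin(), best))];
}
```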

◆ predict_proba()

virtual std::vector< std::vector< double > > svm_classifier::MulticlassStrategyBase::predict_proba (const torch::Tensor &X, DataConverter &converter)
pure virtual

Predict class probabilities.

Parameters
X: Feature tensor of shape (n_samples, n_features)
converter: Data converter instance
Returns
Class probabilities for each sample

Implemented in svm_classifier::OneVsRestStrategy and svm_classifier::OneVsOneStrategy.
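Independent one-vs-rest probabilities do not in general sum to 1 across classes. scikit-learn's OneVsRestClassifier handles the single-label multiclass case by dividing each row by its sum, and the sketch below shows that normalization. Whether OneVsRestStrategy does the same is an assumption, and normalize_rows is a hypothetical helper.

```cpp
#include <numeric>
#include <stdexcept>
#include <vector>

// Divide each row of per-class probabilities by its sum so rows sum to 1.
std::vector<std::vector<double>> normalize_rows(std::vector<std::vector<double>> probs) {
    for (auto& row : probs) {
        const double total = std::accumulate(row.begin(), row.end(), 0.0);
        if (total <= 0.0) throw std::invalid_argument("row must have a positive sum");
        for (double& p : row) p /= total;
    }
    return probs;
}
```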

◆ supports_probability()

virtual bool svm_classifier::MulticlassStrategyBase::supports_probability ( ) const
pure virtual

Check if the model supports probability prediction.

Returns
True if probabilities are supported

Implemented in svm_classifier::OneVsRestStrategy and svm_classifier::OneVsOneStrategy.
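Because predict_proba() is only meaningful when supports_probability() returns true, callers may want to check first. The helper below is a hypothetical sketch: try_predict_proba is templated over any type providing those two members, and StubStrategy and the Matrix alias stand in for a real strategy and for the torch::Tensor and DataConverter arguments (C++17).

```cpp
#include <optional>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Returns std::nullopt instead of calling predict_proba() on a strategy
// that reports no probability support.
template <typename Strategy, typename Input>
std::optional<Matrix> try_predict_proba(Strategy& s, const Input& X) {
    if (!s.supports_probability()) return std::nullopt;
    return s.predict_proba(X);
}

// Hypothetical stub used only to exercise the helper.
struct StubStrategy {
    bool probability = false;
    bool supports_probability() const { return probability; }
    Matrix predict_proba(const Matrix& X) const {
        return Matrix(X.size(), std::vector<double>{0.5, 0.5});
    }
};
```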

Member Data Documentation

◆ classes_

std::vector<int> svm_classifier::MulticlassStrategyBase::classes_
protected

Unique class labels.

Definition at line 92 of file multiclass_strategy.hpp.

◆ is_trained_

bool svm_classifier::MulticlassStrategyBase::is_trained_ = false
protected

Whether the model is trained.

Definition at line 93 of file multiclass_strategy.hpp.
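The two protected members let derived strategies share training state. The sketch below shows one plausible use: fit() records the labels and sets is_trained_, and other members refuse to run before that. StrategyState, DemoStrategy and require_trained are hypothetical, and throwing std::logic_error is an assumed error policy, not documented behavior.

```cpp
#include <stdexcept>
#include <vector>

// Hypothetical base mirroring only the two protected members documented above.
class StrategyState {
protected:
    std::vector<int> classes_;
    bool is_trained_ = false;

    // Guard for members that need a fitted model (assumed error policy).
    void require_trained() const {
        if (!is_trained_) throw std::logic_error("strategy has not been fitted");
    }
};

// Hypothetical derived class showing how fit() would set both members.
class DemoStrategy : public StrategyState {
public:
    void fit(const std::vector<int>& sorted_unique_labels) {
        classes_ = sorted_unique_labels;
        is_trained_ = true;
    }
    int get_n_classes() const {
        require_trained();
        return static_cast<int>(classes_.size());
    }
};
```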


The documentation for this class was generated from the following file: