Abstract base class for multiclass classification strategies.
More...
#include <multiclass_strategy.hpp>
|
virtual | ~MulticlassStrategyBase ()=default |
| Virtual destructor.
|
|
virtual TrainingMetrics | fit (const torch::Tensor &X, const torch::Tensor &y, const KernelParameters &params, DataConverter &converter)=0 |
| Train the multiclass classifier.
|
|
virtual std::vector< int > | predict (const torch::Tensor &X, DataConverter &converter)=0 |
| Predict class labels.
|
|
virtual std::vector< std::vector< double > > | predict_proba (const torch::Tensor &X, DataConverter &converter)=0 |
| Predict class probabilities.
|
|
virtual std::vector< std::vector< double > > | decision_function (const torch::Tensor &X, DataConverter &converter)=0 |
| Get decision function values.
|
|
virtual std::vector< int > | get_classes () const =0 |
| Get unique class labels.
|
|
virtual bool | supports_probability () const =0 |
| Check if the model supports probability prediction.
|
|
virtual int | get_n_classes () const =0 |
| Get number of classes.
|
|
virtual MulticlassStrategy | get_strategy_type () const =0 |
| Get strategy type.
|
|
|
std::vector< int > | classes_ |
| Unique class labels.
|
|
bool | is_trained_ = false |
| Whether the model is trained.
|
|
Abstract base class for multiclass classification strategies.
Definition at line 20 of file multiclass_strategy.hpp.
◆ decision_function()
virtual std::vector< std::vector< double > > svm_classifier::MulticlassStrategyBase::decision_function |
( |
const torch::Tensor & |
X, |
|
|
DataConverter & |
converter |
|
) |
| |
|
pure virtual |
◆ fit()
virtual TrainingMetrics svm_classifier::MulticlassStrategyBase::fit |
( |
const torch::Tensor & |
X, |
|
|
const torch::Tensor & |
y, |
|
|
const KernelParameters & |
params, |
|
|
DataConverter & |
converter |
|
) |
| |
|
pure virtual |
◆ get_classes()
virtual std::vector< int > svm_classifier::MulticlassStrategyBase::get_classes |
( |
| ) |
const |
|
pure virtual |
◆ get_n_classes()
virtual int svm_classifier::MulticlassStrategyBase::get_n_classes |
( |
| ) |
const |
|
pure virtual |
◆ get_strategy_type()
virtual MulticlassStrategy svm_classifier::MulticlassStrategyBase::get_strategy_type |
( |
| ) |
const |
|
pure virtual |
◆ predict()
virtual std::vector< int > svm_classifier::MulticlassStrategyBase::predict |
( |
const torch::Tensor & |
X, |
|
|
DataConverter & |
converter |
|
) |
| |
|
pure virtual |
◆ predict_proba()
virtual std::vector< std::vector< double > > svm_classifier::MulticlassStrategyBase::predict_proba |
( |
const torch::Tensor & |
X, |
|
|
DataConverter & |
converter |
|
) |
| |
|
pure virtual |
◆ supports_probability()
virtual bool svm_classifier::MulticlassStrategyBase::supports_probability |
( |
| ) |
const |
|
pure virtual |
◆ classes_
std::vector<int> svm_classifier::MulticlassStrategyBase::classes_ |
|
protected |
◆ is_trained_
bool svm_classifier::MulticlassStrategyBase::is_trained_ = false |
|
protected |
The documentation for this class was generated from the following file: multiclass_strategy.hpp