Files
SVMClassifier/include/svm_classifier/multiclass_strategy.hpp
Ricardo Montañana Gómez 7b27d5c1f3
Some checks failed
CI/CD Pipeline / Create Release Package (push) Has been cancelled
CI/CD Pipeline / Code Linting (push) Has been cancelled
CI/CD Pipeline / Build and Test (Debug, clang, ubuntu-latest) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Debug, gcc, ubuntu-latest) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-20.04) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-latest) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-20.04) (push) Has been cancelled
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-latest) (push) Has been cancelled
CI/CD Pipeline / Docker Build Test (push) Has been cancelled
CI/CD Pipeline / Performance Benchmarks (push) Has been cancelled
CI/CD Pipeline / Build Documentation (push) Has been cancelled
library compile complete and begin tests
2025-06-23 12:05:35 +02:00

264 lines
8.9 KiB
C++

#pragma once
#include "types.hpp"
#include "kernel_parameters.hpp"
#include "data_converter.hpp"
#include <torch/torch.h>
#include <vector>
#include <memory>
#include <unordered_map>
// Forward declarations for external library structures
struct svm_model;
struct model;
namespace svm_classifier {
/**
 * @brief Abstract base class for multiclass classification strategies
 *
 * Concrete strategies (one-vs-rest, one-vs-one) implement training,
 * prediction and introspection behind this interface and are meant to be
 * used polymorphically (e.g. through std::unique_ptr<MulticlassStrategyBase>).
 * Copying through the base is therefore disabled: it would copy only the
 * state declared here and silently drop the derived classifiers (slicing).
 */
class MulticlassStrategyBase {
public:
    /**
     * @brief Virtual destructor so derived strategies are destroyed
     *        correctly when deleted through a base pointer
     */
    virtual ~MulticlassStrategyBase() = default;

    // Polymorphic base: forbid copy construction/assignment to prevent slicing
    // (`base_a = base_b` would otherwise copy classes_/is_trained_ only).
    MulticlassStrategyBase(const MulticlassStrategyBase&) = delete;
    MulticlassStrategyBase& operator=(const MulticlassStrategyBase&) = delete;

    /**
     * @brief Train the multiclass classifier
     * @param X Feature tensor of shape (n_samples, n_features)
     * @param y Target tensor of shape (n_samples,)
     * @param params Kernel parameters
     * @param converter Data converter instance
     * @return Training metrics
     */
    virtual TrainingMetrics fit(const torch::Tensor& X,
                                const torch::Tensor& y,
                                const KernelParameters& params,
                                DataConverter& converter) = 0;

    /**
     * @brief Predict class labels
     * @param X Feature tensor of shape (n_samples, n_features)
     * @param converter Data converter instance
     * @return Predicted class labels
     */
    virtual std::vector<int> predict(const torch::Tensor& X,
                                     DataConverter& converter) = 0;

    /**
     * @brief Predict class probabilities
     * @param X Feature tensor of shape (n_samples, n_features)
     * @param converter Data converter instance
     * @return Class probabilities for each sample
     */
    virtual std::vector<std::vector<double>> predict_proba(const torch::Tensor& X,
                                                           DataConverter& converter) = 0;

    /**
     * @brief Get decision function values
     * @param X Feature tensor of shape (n_samples, n_features)
     * @param converter Data converter instance
     * @return Decision function values
     */
    virtual std::vector<std::vector<double>> decision_function(const torch::Tensor& X,
                                                               DataConverter& converter) = 0;

    /**
     * @brief Get unique class labels
     * @return Vector of unique class labels (a copy)
     */
    virtual std::vector<int> get_classes() const = 0;

    /**
     * @brief Check if the model supports probability prediction
     * @return True if probabilities are supported
     */
    virtual bool supports_probability() const = 0;

    /**
     * @brief Get number of classes
     * @return Number of classes
     */
    virtual int get_n_classes() const = 0;

    /**
     * @brief Get strategy type
     * @return Multiclass strategy type
     */
    virtual MulticlassStrategy get_strategy_type() const = 0;

protected:
    /**
     * @brief Default constructor; only derived strategies construct the base
     *        (it is abstract and cannot be instantiated on its own)
     */
    MulticlassStrategyBase() = default;

    std::vector<int> classes_;   ///< Unique class labels
    bool is_trained_ = false;    ///< Whether the model is trained
};
/**
 * @brief One-vs-Rest (OvR) multiclass strategy
 *
 * Trains one binary classifier per class (that class vs. all others).
 * The per-class models are stored in svm_models_ or linear_models_;
 * presumably which vector is populated depends on library_type_ — confirm
 * in the implementation.
 *
 * Instances own their trained models and are non-copyable.
 */
class OneVsRestStrategy : public MulticlassStrategyBase {
public:
    /**
     * @brief Constructor
     */
    OneVsRestStrategy();

    /**
     * @brief Destructor
     *
     * Declared out of line on purpose: svm_model and model are only forward
     * declared in this header, and std::unique_ptr needs the complete type
     * wherever its deleter is instantiated.
     */
    ~OneVsRestStrategy() override;

    // The owned models cannot be duplicated. Deleting the copy operations makes
    // that explicit: std::is_copy_constructible reports false, and misuse fails
    // with a clear diagnostic instead of deep inside std::vector's copy
    // constructor. (Moves are not implicitly declared either, because of the
    // user-declared destructor.)
    OneVsRestStrategy(const OneVsRestStrategy&) = delete;
    OneVsRestStrategy& operator=(const OneVsRestStrategy&) = delete;

    TrainingMetrics fit(const torch::Tensor& X,
                        const torch::Tensor& y,
                        const KernelParameters& params,
                        DataConverter& converter) override;

    std::vector<int> predict(const torch::Tensor& X,
                             DataConverter& converter) override;

    std::vector<std::vector<double>> predict_proba(const torch::Tensor& X,
                                                   DataConverter& converter) override;

    std::vector<std::vector<double>> decision_function(const torch::Tensor& X,
                                                       DataConverter& converter) override;

    std::vector<int> get_classes() const override { return classes_; }

    bool supports_probability() const override;

    int get_n_classes() const override { return static_cast<int>(classes_.size()); }

    MulticlassStrategy get_strategy_type() const override { return MulticlassStrategy::ONE_VS_REST; }

private:
    // NOTE(review): std::unique_ptr<T> with the default deleter frees with
    // `delete`. If these wrap libsvm / liblinear models (allocated by those C
    // libraries), they must be released with svm_free_and_destroy_model /
    // free_and_destroy_model instead. Confirm that cleanup_models() releases
    // the pointers through those APIs before the vectors are destroyed.
    std::vector<std::unique_ptr<svm_model>> svm_models_; ///< SVM models (one per class)
    std::vector<std::unique_ptr<model>> linear_models_;  ///< Linear models (one per class)
    KernelParameters params_;                             ///< Stored parameters
    SVMLibrary library_type_{};                           ///< Which library is being used (value-initialized so it is never read indeterminate)

    /**
     * @brief Create binary labels for one-vs-rest
     * @param y Original labels
     * @param positive_class Positive class label
     * @return Binary labels (+1 for positive class, -1 for others)
     */
    torch::Tensor create_binary_labels(const torch::Tensor& y, int positive_class);

    /**
     * @brief Train a single binary classifier
     * @param X Feature tensor
     * @param y_binary Binary labels
     * @param params Kernel parameters
     * @param converter Data converter
     * @param class_idx Index of the class being trained
     * @return Training time for this classifier
     */
    double train_binary_classifier(const torch::Tensor& X,
                                   const torch::Tensor& y_binary,
                                   const KernelParameters& params,
                                   DataConverter& converter,
                                   int class_idx);

    /**
     * @brief Clean up all models
     */
    void cleanup_models();
};
/**
 * @brief One-vs-One (OvO) multiclass strategy
 *
 * Trains one binary classifier per pair of classes. The pair each model was
 * trained on is recorded in class_pairs_. Predictions are combined by
 * vote_predictions().
 *
 * The pairwise models are stored in svm_models_ or linear_models_;
 * presumably which vector is populated depends on library_type_ — confirm
 * in the implementation.
 *
 * Instances own their trained models and are non-copyable.
 */
class OneVsOneStrategy : public MulticlassStrategyBase {
public:
    /**
     * @brief Constructor
     */
    OneVsOneStrategy();

    /**
     * @brief Destructor
     *
     * Declared out of line on purpose: svm_model and model are only forward
     * declared in this header, and std::unique_ptr needs the complete type
     * wherever its deleter is instantiated.
     */
    ~OneVsOneStrategy() override;

    // The owned models cannot be duplicated. Deleting the copy operations makes
    // that explicit: std::is_copy_constructible reports false, and misuse fails
    // with a clear diagnostic instead of deep inside std::vector's copy
    // constructor. (Moves are not implicitly declared either, because of the
    // user-declared destructor.)
    OneVsOneStrategy(const OneVsOneStrategy&) = delete;
    OneVsOneStrategy& operator=(const OneVsOneStrategy&) = delete;

    TrainingMetrics fit(const torch::Tensor& X,
                        const torch::Tensor& y,
                        const KernelParameters& params,
                        DataConverter& converter) override;

    std::vector<int> predict(const torch::Tensor& X,
                             DataConverter& converter) override;

    std::vector<std::vector<double>> predict_proba(const torch::Tensor& X,
                                                   DataConverter& converter) override;

    std::vector<std::vector<double>> decision_function(const torch::Tensor& X,
                                                       DataConverter& converter) override;

    std::vector<int> get_classes() const override { return classes_; }

    bool supports_probability() const override;

    int get_n_classes() const override { return static_cast<int>(classes_.size()); }

    MulticlassStrategy get_strategy_type() const override { return MulticlassStrategy::ONE_VS_ONE; }

private:
    // NOTE(review): std::unique_ptr<T> with the default deleter frees with
    // `delete`. If these wrap libsvm / liblinear models (allocated by those C
    // libraries), they must be released with svm_free_and_destroy_model /
    // free_and_destroy_model instead. Confirm that cleanup_models() releases
    // the pointers through those APIs before the vectors are destroyed.
    std::vector<std::unique_ptr<svm_model>> svm_models_; ///< SVM models (one per pair)
    std::vector<std::unique_ptr<model>> linear_models_;  ///< Linear models (one per pair)
    std::vector<std::pair<int, int>> class_pairs_;        ///< Class pairs for each model
    KernelParameters params_;                             ///< Stored parameters
    SVMLibrary library_type_{};                           ///< Which library is being used (value-initialized so it is never read indeterminate)

    /**
     * @brief Extract samples for a specific class pair
     * @param X Feature tensor
     * @param y Label tensor
     * @param class1 First class
     * @param class2 Second class
     * @return Pair of (filtered_X, filtered_y)
     */
    std::pair<torch::Tensor, torch::Tensor> extract_binary_data(const torch::Tensor& X,
                                                                const torch::Tensor& y,
                                                                int class1,
                                                                int class2);

    /**
     * @brief Train a single pairwise classifier
     * @param X Feature tensor
     * @param y Labels
     * @param class1 First class
     * @param class2 Second class
     * @param params Kernel parameters
     * @param converter Data converter
     * @param model_idx Index of the model being trained
     * @return Training time for this classifier
     */
    double train_pairwise_classifier(const torch::Tensor& X,
                                     const torch::Tensor& y,
                                     int class1,
                                     int class2,
                                     const KernelParameters& params,
                                     DataConverter& converter,
                                     int model_idx);

    /**
     * @brief Voting mechanism for OvO predictions
     * @param decisions Matrix of pairwise decisions (layout defined by the
     *        implementation — presumably one row per sample; TODO confirm)
     * @return Predicted class for each sample
     */
    std::vector<int> vote_predictions(const std::vector<std::vector<double>>& decisions);

    /**
     * @brief Clean up all models
     */
    void cleanup_models();
};
/**
 * @brief Factory function to create a multiclass strategy
 * @param strategy Strategy type
 * @return Unique pointer owning the created strategy. Ownership passes to the
 *         caller. Discarding the result would destroy the strategy
 *         immediately, hence [[nodiscard]].
 */
[[nodiscard]] std::unique_ptr<MulticlassStrategyBase> create_multiclass_strategy(MulticlassStrategy strategy);
} // namespace svm_classifier