Add DecisionTree with tests
This commit is contained in:
@@ -9,13 +9,12 @@
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <torch/torch.h>
|
||||
#include <bayesnet/ensembles/Ensemble.h>
|
||||
#include "bayesnet/ensembles/Ensemble.h"
|
||||
|
||||
namespace platform {
|
||||
class AdaBoost : public bayesnet::Ensemble {
|
||||
namespace bayesnet {
|
||||
class AdaBoost : public Ensemble {
|
||||
public:
|
||||
explicit AdaBoost(int n_estimators = 100);
|
||||
explicit AdaBoost(int n_estimators = 50, int max_depth = 1);
|
||||
virtual ~AdaBoost() = default;
|
||||
|
||||
// Override base class methods
|
||||
@@ -24,10 +23,15 @@ namespace platform {
|
||||
// AdaBoost specific methods
|
||||
void setNEstimators(int n_estimators) { this->n_estimators = n_estimators; }
|
||||
int getNEstimators() const { return n_estimators; }
|
||||
void setBaseMaxDepth(int depth) { this->base_max_depth = depth; }
|
||||
int getBaseMaxDepth() const { return base_max_depth; }
|
||||
|
||||
// Get the weight of each base estimator
|
||||
std::vector<double> getEstimatorWeights() const { return alphas; }
|
||||
|
||||
// Get training errors for each iteration
|
||||
std::vector<double> getTrainingErrors() const { return training_errors; }
|
||||
|
||||
// Override setHyperparameters from BaseClassifier
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||
|
||||
@@ -37,6 +41,7 @@ namespace platform {
|
||||
|
||||
private:
|
||||
int n_estimators;
|
||||
int base_max_depth; // Max depth for base decision trees
|
||||
std::vector<double> alphas; // Weight of each base estimator
|
||||
std::vector<double> training_errors; // Training error at each iteration
|
||||
torch::Tensor sample_weights; // Current sample weights
|
||||
|
Reference in New Issue
Block a user