Begin AdaBoost integration
This commit is contained in:
@@ -21,9 +21,9 @@ namespace bayesnet {
|
||||
std::vector<std::string> graph(const std::string& title = "") const override;
|
||||
|
||||
// AdaBoost specific methods
|
||||
void setNEstimators(int n_estimators) { this->n_estimators = n_estimators; }
|
||||
void setNEstimators(int n_estimators) { this->n_estimators = n_estimators; checkValues(); }
|
||||
int getNEstimators() const { return n_estimators; }
|
||||
void setBaseMaxDepth(int depth) { this->base_max_depth = depth; }
|
||||
void setBaseMaxDepth(int depth) { this->base_max_depth = depth; checkValues(); }
|
||||
int getBaseMaxDepth() const { return base_max_depth; }
|
||||
|
||||
// Get the weight of each base estimator
|
||||
@@ -35,6 +35,11 @@ namespace bayesnet {
|
||||
// Override setHyperparameters from BaseClassifier
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||
|
||||
torch::Tensor predict(torch::Tensor& X) override;
|
||||
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
|
||||
torch::Tensor predict_proba(torch::Tensor& X) override;
|
||||
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X);
|
||||
|
||||
protected:
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override;
|
||||
@@ -45,6 +50,8 @@ namespace bayesnet {
|
||||
std::vector<double> alphas; // Weight of each base estimator
|
||||
std::vector<double> training_errors; // Training error at each iteration
|
||||
torch::Tensor sample_weights; // Current sample weights
|
||||
int n_classes; // Number of classes in the target variable
|
||||
int n; // Number of features
|
||||
|
||||
// Train a single base estimator
|
||||
std::unique_ptr<Classifier> trainBaseEstimator(const torch::Tensor& weights);
|
||||
@@ -57,6 +64,15 @@ namespace bayesnet {
|
||||
|
||||
// Normalize weights to sum to 1
|
||||
void normalizeWeights();
|
||||
|
||||
// Check if hyperparameters values are valid
|
||||
void checkValues() const;
|
||||
|
||||
// Make predictions for a single sample
|
||||
int predictSample(const torch::Tensor& x) const;
|
||||
|
||||
// Make probabilistic predictions for a single sample
|
||||
torch::Tensor predictProbaSample(const torch::Tensor& x) const;
|
||||
};
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user