Begin AdaBoost integration

This commit is contained in:
2025-06-18 11:27:11 +02:00
parent 023d5613b4
commit 415a7ae608
10 changed files with 1001 additions and 56 deletions

View File

@@ -30,6 +30,9 @@ namespace bayesnet {
void setMaxDepth(int depth) { max_depth = depth; checkValues(); }
void setMinSamplesSplit(int samples) { min_samples_split = samples; checkValues(); }
void setMinSamplesLeaf(int samples) { min_samples_leaf = samples; checkValues(); }
int getMaxDepth() const { return max_depth; }
int getMinSamplesSplit() const { return min_samples_split; }
int getMinSamplesLeaf() const { return min_samples_leaf; }
// Override setHyperparameters
void setHyperparameters(const nlohmann::json& hyperparameters) override;
@@ -39,6 +42,12 @@ namespace bayesnet {
torch::Tensor predict_proba(torch::Tensor& X) override;
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X);
// Make predictions for a single sample
int predictSample(const torch::Tensor& x) const;
// Make probabilistic predictions for a single sample
torch::Tensor predictProbaSample(const torch::Tensor& x) const;
protected:
void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override
@@ -88,11 +97,7 @@ namespace bayesnet {
const torch::Tensor& sample_weights
);
// Make predictions for a single sample
int predictSample(const torch::Tensor& x) const;
// Make probabilistic predictions for a single sample
torch::Tensor predictProbaSample(const torch::Tensor& x) const;
// Traverse tree to find leaf node
const TreeNode* traverseTree(const torch::Tensor& x, const TreeNode* node) const;