Fix metrics error in BoostAODE Convergence

Update algorithm
This commit is contained in:
2024-03-20 23:33:02 +01:00
parent 5826702fc7
commit 6e854dfda3
6 changed files with 98 additions and 65 deletions

View File

@@ -15,8 +15,8 @@ namespace bayesnet {
void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights) override;
private:
std::unordered_set<int> initializeModels();
torch::Tensor dataset_;
std::vector<int> initializeModels();
torch::Tensor dataset_; // Backup the original dataset
torch::Tensor X_train, y_train, X_test, y_test;
// Hyperparameters
bool bisection = false; // if true, use bisection stratety to add k models at once to the ensemble