bisection proposal #24

Merged
rmontanana merged 23 commits from bisection into main 2024-04-08 14:29:26 +00:00
2 changed files with 0 additions and 3 deletions
Showing only changes of commit bc0b938cfc - Show all commits

View File

@ -40,8 +40,6 @@ namespace bayesnet {
if (convergence) {
// Prepare train & validation sets from train data
auto fold = folding::StratifiedKFold(5, y_, 271);
// save input dataset
dataset_ = torch::clone(dataset);
auto [train, test] = fold.getFold(0);
auto train_t = torch::tensor(train);
auto test_t = torch::tensor(test);

View File

@ -16,7 +16,6 @@ namespace bayesnet {
void trainModel(const torch::Tensor& weights) override;
private:
std::vector<int> initializeModels();
torch::Tensor dataset_; // Backup the original dataset
torch::Tensor X_train, y_train, X_test, y_test;
// Hyperparameters
bool bisection = false; // if true, use bisection stratety to add k models at once to the ensemble