Remove dataset clone in BoostAODE
This commit is contained in:
parent
58d5a35a35
commit
bc0b938cfc
@ -40,8 +40,6 @@ namespace bayesnet {
|
||||
if (convergence) {
|
||||
// Prepare train & validation sets from train data
|
||||
auto fold = folding::StratifiedKFold(5, y_, 271);
|
||||
// save input dataset
|
||||
dataset_ = torch::clone(dataset);
|
||||
auto [train, test] = fold.getFold(0);
|
||||
auto train_t = torch::tensor(train);
|
||||
auto test_t = torch::tensor(test);
|
||||
|
@ -16,7 +16,6 @@ namespace bayesnet {
|
||||
void trainModel(const torch::Tensor& weights) override;
|
||||
private:
|
||||
std::vector<int> initializeModels();
|
||||
torch::Tensor dataset_; // Backup the original dataset
|
||||
torch::Tensor X_train, y_train, X_test, y_test;
|
||||
// Hyperparameters
|
||||
bool bisection = false; // if true, use bisection stratety to add k models at once to the ensemble
|
||||
|
Loading…
Reference in New Issue
Block a user