Remove dataset clone in BoostAODE
This commit is contained in:
parent
58d5a35a35
commit
bc0b938cfc
@ -40,8 +40,6 @@ namespace bayesnet {
|
|||||||
if (convergence) {
|
if (convergence) {
|
||||||
// Prepare train & validation sets from train data
|
// Prepare train & validation sets from train data
|
||||||
auto fold = folding::StratifiedKFold(5, y_, 271);
|
auto fold = folding::StratifiedKFold(5, y_, 271);
|
||||||
// save input dataset
|
|
||||||
dataset_ = torch::clone(dataset);
|
|
||||||
auto [train, test] = fold.getFold(0);
|
auto [train, test] = fold.getFold(0);
|
||||||
auto train_t = torch::tensor(train);
|
auto train_t = torch::tensor(train);
|
||||||
auto test_t = torch::tensor(test);
|
auto test_t = torch::tensor(test);
|
||||||
|
@ -16,7 +16,6 @@ namespace bayesnet {
|
|||||||
void trainModel(const torch::Tensor& weights) override;
|
void trainModel(const torch::Tensor& weights) override;
|
||||||
private:
|
private:
|
||||||
std::vector<int> initializeModels();
|
std::vector<int> initializeModels();
|
||||||
torch::Tensor dataset_; // Backup the original dataset
|
|
||||||
torch::Tensor X_train, y_train, X_test, y_test;
|
torch::Tensor X_train, y_train, X_test, y_test;
|
||||||
// Hyperparameters
|
// Hyperparameters
|
||||||
bool bisection = false; // if true, use bisection stratety to add k models at once to the ensemble
|
bool bisection = false; // if true, use bisection stratety to add k models at once to the ensemble
|
||||||
|
Loading…
Reference in New Issue
Block a user