diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc
index 61d013d..d918172 100644
--- a/src/BayesNet/BoostAODE.cc
+++ b/src/BayesNet/BoostAODE.cc
@@ -2,6 +2,7 @@
 #include
 #include
 #include "BayesMetrics.h"
 #include "Colors.h"
+#include "Folding.h"
 namespace bayesnet {
     BoostAODE::BoostAODE() : Ensemble() {}
@@ -24,15 +25,46 @@ namespace bayesnet {
             ascending = hyperparameters["ascending"];
         }
     }
+    /// Splits `dataset` into a training part and a held-out validation part and
+    /// rebuilds the classifier state (dataset, m, metrics) from the training part.
+    ///
+    /// Uses fold 0 of platform::StratifiedKFold(5, y_, 271), where y_ is the last
+    /// row of `dataset`. A copy of the incoming dataset is kept in `dataset_`.
+    /// X_train/X_test receive the feature rows and y_train/y_test the label row
+    /// of the selected columns.
+    void BoostAODE::validationInit()
+    {
+        auto y_ = dataset.index({ -1, "..." });
+        auto fold = platform::StratifiedKFold(5, y_, 271);
+        // save input dataset
+        dataset_ = torch::clone(dataset);
+        auto [train, test] = fold.getFold(0);
+        auto train_t = torch::tensor(train);
+        auto test_t = torch::tensor(test);
+        // The last row of dataset is the label row (see y_ above), so only the
+        // rows before it are features. Indexing with "..." would leave the labels
+        // inside X_train/X_test, which are later passed to predict() as features.
+        auto feature_rows = torch::indexing::Slice(0, dataset.size(0) - 1);
+        // Get train and validation sets
+        X_train = dataset.index({ feature_rows, train_t });
+        y_train = dataset.index({ -1, train_t });
+        X_test = dataset.index({ feature_rows, test_t });
+        y_test = dataset.index({ -1, test_t });
+        // Build dataset with train data
+        dataset = X_train;
+        buildDataset(y_train);
+        m = X_train.size(1);
+        auto n_classes = states.at(className).size();
+        metrics = Metrics(dataset, features, className, n_classes);
+    }
     void BoostAODE::trainModel(const torch::Tensor& weights)
     {
         models.clear();
         n_models = 0;
         if (maxModels == 0)
             maxModels = .1 * n > 10 ? .1 * n : n;
+        validationInit();
         Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
-        auto X_ = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
-        auto y_ = dataset.index({ -1, "..." });
         bool exitCondition = false;
         unordered_set featuresUsed;
         // Step 0: Set the finish condition
@@ -62,10 +94,10 @@ namespace bayesnet {
             model = std::make_unique(feature);
             n_models++;
             model->fit(dataset, features, className, states, weights_);
-            auto ypred = model->predict(X_);
+            auto ypred = model->predict(X_train);
             // Step 3.1: Compute the classifier amout of say
-            auto mask_wrong = ypred != y_;
-            auto mask_right = ypred == y_;
+            auto mask_wrong = ypred != y_train;
+            auto mask_right = ypred == y_train;
             auto masked_weights = weights_ * mask_wrong.to(weights_.dtype());
             double epsilon_t = masked_weights.sum().item();
             double wt = (1 - epsilon_t) / epsilon_t;
diff --git a/src/BayesNet/BoostAODE.h b/src/BayesNet/BoostAODE.h
index 508086b..74ca02a 100644
--- a/src/BayesNet/BoostAODE.h
+++ b/src/BayesNet/BoostAODE.h
@@ -13,9 +13,12 @@ namespace bayesnet {
         void buildModel(const torch::Tensor& weights) override;
         void trainModel(const torch::Tensor& weights) override;
     private:
-        bool repeatSparent=false;
-        int maxModels=0;
-        bool ascending=false; //Process KBest features ascending or descending order
+        torch::Tensor dataset_;
+        torch::Tensor X_train, y_train, X_test, y_test;
+        void validationInit();
+        bool repeatSparent = false;
+        int maxModels = 0;
+        bool ascending = false; //Process KBest features ascending or descending order
     };
 }
 #endif
\ No newline at end of file
diff --git a/src/BayesNet/Classifier.cc b/src/BayesNet/Classifier.cc
index aca6ea5..7727939 100644
--- a/src/BayesNet/Classifier.cc
+++ b/src/BayesNet/Classifier.cc
@@ -75,7 +75,7 @@ namespace bayesnet {
             throw invalid_argument("dataset (X, y) must be of type Integer");
         }
         if (n != features.size()) {
-            throw invalid_argument("X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features");
+            throw invalid_argument("Classifier: X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features");
         }
         if (states.find(className) == states.end()) {
             throw invalid_argument("className not found in states");
diff --git a/src/BayesNet/Classifier.h b/src/BayesNet/Classifier.h
index 011987b..5dd3040 100644
--- a/src/BayesNet/Classifier.h
+++ b/src/BayesNet/Classifier.h
@@ -10,7 +10,6 @@ using namespace torch;
 namespace bayesnet {
     class Classifier : public BaseClassifier {
     private:
-        void buildDataset(torch::Tensor& y);
         Classifier& build(const vector& features, const string& className, map>& states, const torch::Tensor& weights);
     protected:
         bool fitted;
@@ -26,6 +25,7 @@ namespace bayesnet {
         virtual void buildModel(const torch::Tensor& weights) = 0;
         void trainModel(const torch::Tensor& weights) override;
         void checkHyperparameters(const vector& validKeys, nlohmann::json& hyperparameters);
+        void buildDataset(torch::Tensor& y);
     public:
         Classifier(Network model);
         virtual ~Classifier() = default;