From e7ded6826792d14abc7aeeaa1ca877c8e0382b4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Tue, 10 Oct 2023 23:00:38 +0200 Subject: [PATCH] First cfs working version --- src/BayesNet/BoostAODE.cc | 33 +++++++++++++++++++++------------ src/BayesNet/BoostAODE.h | 4 ++-- src/Platform/Paths.h | 1 + 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc index d3e8901..0952a7a 100644 --- a/src/BayesNet/BoostAODE.cc +++ b/src/BayesNet/BoostAODE.cc @@ -13,10 +13,10 @@ namespace bayesnet { void BoostAODE::buildModel(const torch::Tensor& weights) { // Models shall be built in trainModel + models.clear(); + n_models = 0; // Prepare the validation dataset auto y_ = dataset.index({ -1, "..." }); - int nSamples = dataset.size(1); - int nFeatures = dataset.size(0) - 1; if (convergence) { // Prepare train & validation sets from train data auto fold = platform::StratifiedKFold(5, y_, 271); @@ -41,8 +41,8 @@ namespace bayesnet { X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." }); y_train = y_; } - if (cfs != "") { - initializeModels(nSamples, nFeatures); + if (cfs) { + initializeModels(); } } void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters) @@ -82,18 +82,18 @@ namespace bayesnet { EVP_MD_CTX_free(mdctx); stringstream oss; for (unsigned int i = 0; i < hash_len; i++) { - oss << hex << (int)hash[i]; + oss << hex << setfill('0') << setw(2) << (int)hash[i]; } return oss.str(); } - void BoostAODE::initializeModels(int nSamples, int nFeatures) + void BoostAODE::initializeModels() { // Read the CFS features string output = "[", prefix = ""; bool first = true; for (const auto& feature : features) { - output += prefix + feature; + output += prefix + "'" + feature + "'"; if (first) { prefix = ", "; first = false; @@ -103,21 +103,30 @@ namespace bayesnet { // std::size_t str_hash = std::hash{}(output); string str_hash = sha256(output); stringstream oss; - oss << "cfs/" << str_hash << ".json"; + oss << platform::Paths::cfs() << str_hash << ".json"; string name = oss.str(); ifstream file(name); + Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64); if (file.is_open()) { - nlohmann::json features = nlohmann::json::parse(file); + nlohmann::json cfsFeatures = nlohmann::json::parse(file); file.close(); - cout << "features: " << features.dump() << endl; + for (const string& feature : cfsFeatures) { + // cout << "Feature: [" << feature << "]" << endl; + auto pos = find(features.begin(), features.end(), feature); + if (pos == features.end()) + throw runtime_error("Feature " + feature + " not found in dataset"); + int numFeature = pos - features.begin(); + cout << "Feature: [" << feature << "] " << numFeature << endl; + models.push_back(std::make_unique(numFeature)); + models.back()->fit(dataset, features, className, states, weights_); + n_models++; + } } else { throw runtime_error("File " + name + " not found"); } } void BoostAODE::trainModel(const torch::Tensor& weights) { - models.clear(); - n_models = 0; if (maxModels == 0) maxModels = .1 * n > 10 ? .1 * n : n; Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64); diff --git a/src/BayesNet/BoostAODE.h b/src/BayesNet/BoostAODE.h index 3464a7d..683cb99 100644 --- a/src/BayesNet/BoostAODE.h +++ b/src/BayesNet/BoostAODE.h @@ -15,13 +15,13 @@ namespace bayesnet { private: torch::Tensor dataset_; torch::Tensor X_train, y_train, X_test, y_test; - void initializeModels(int nSamples, int nFeatures); + void initializeModels(); // Hyperparameters bool repeatSparent = false; // if true, a feature can be selected more than once int maxModels = 0; bool ascending = false; //Process KBest features ascending or descending order bool convergence = false; //if true, stop when the model does not improve - string cfs = ""; // if not empty, use CFS to select features + bool cfs = false; // if true use CFS to select features stored in cfs folder with sha256(features) file_name }; } #endif \ No newline at end of file diff --git a/src/Platform/Paths.h b/src/Platform/Paths.h index a1eb00c..16d459c 100644 --- a/src/Platform/Paths.h +++ b/src/Platform/Paths.h @@ -7,6 +7,7 @@ namespace platform { public: static std::string results() { return "results/"; } static std::string excel() { return "excel/"; } + static std::string cfs() { return "cfs/"; } static std::string datasets() { auto env = platform::DotEnv();