diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc index c976408..4ddf30d 100644 --- a/src/BayesNet/BoostAODE.cc +++ b/src/BayesNet/BoostAODE.cc @@ -4,6 +4,7 @@ #include "Colors.h" #include "Folding.h" #include +#include "Paths.h" namespace bayesnet { BoostAODE::BoostAODE() : Ensemble() {} @@ -28,6 +29,9 @@ namespace bayesnet { if (hyperparameters.contains("convergence")) { convergence = hyperparameters["convergence"]; } + if (hyperparameters.contains("cfs")) { + cfs = hyperparameters["cfs"]; + } } void BoostAODE::validationInit() { @@ -58,6 +62,12 @@ namespace bayesnet { } } + void BoostAODE::initializeModels() + { + ifstream file(cfs + ".json"); + if (file.is_open()) { + } + } void BoostAODE::trainModel(const torch::Tensor& weights) { models.clear(); @@ -66,6 +76,9 @@ namespace bayesnet { maxModels = .1 * n > 10 ? .1 * n : n; validationInit(); Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64); + if (cfs != "") { + initializeModels(); + } bool exitCondition = false; unordered_set featuresUsed; // Variables to control the accuracy finish condition diff --git a/src/BayesNet/BoostAODE.h b/src/BayesNet/BoostAODE.h index 61e2e95..5c99145 100644 --- a/src/BayesNet/BoostAODE.h +++ b/src/BayesNet/BoostAODE.h @@ -16,10 +16,13 @@ namespace bayesnet { torch::Tensor dataset_; torch::Tensor X_train, y_train, X_test, y_test; void validationInit(); - bool repeatSparent = false; + void initializeModels(); + // Hyperparameters + bool repeatSparent = false; // if true, a feature can be selected more than once int maxModels = 0; bool ascending = false; //Process KBest features ascending or descending order bool convergence = false; //if true, stop when the model does not improve + string cfs = ""; // if not empty, use CFS to select features }; } #endif \ No newline at end of file diff --git a/src/Platform/Dataset.cc b/src/Platform/Dataset.cc index 02a36f9..f75fdbc 100644 --- a/src/Platform/Dataset.cc +++ b/src/Platform/Dataset.cc @@ -212,14 +212,4 @@ namespace platform { } return Xd; } - vector Dataset::split(const string& text, char delimiter) - { - vector result; - stringstream ss(text); - string token; - while (getline(ss, token, delimiter)) { - result.push_back(token); - } - return result; - } } \ No newline at end of file diff --git a/src/Platform/Dataset.h b/src/Platform/Dataset.h index fbc577e..21b619e 100644 --- a/src/Platform/Dataset.h +++ b/src/Platform/Dataset.h @@ -5,6 +5,7 @@ #include #include #include "CPPFImdlp.h" +#include "Utils.h" namespace platform { using namespace std; @@ -62,7 +63,6 @@ namespace platform { public: Dataset(const string& path, const string& name, const string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {}; explicit Dataset(const Dataset&); - static vector split(const string& text, char delimiter); string getName() const; string getClassName() const; vector getFeatures() const; diff --git a/src/Platform/Datasets.cc b/src/Platform/Datasets.cc index 717ccbc..4f53a2b 100644 --- a/src/Platform/Datasets.cc +++ b/src/Platform/Datasets.cc @@ -13,7 +13,7 @@ namespace platform { if (line.empty() || line[0] == '#') { continue; } - vector tokens = Dataset::split(line, ','); + vector tokens = split(line, ','); string name = tokens[0]; string className; if (tokens.size() == 1) { diff --git a/src/Platform/DotEnv.h b/src/Platform/DotEnv.h index c481310..87ec50e 100644 --- a/src/Platform/DotEnv.h +++ b/src/Platform/DotEnv.h @@ -4,7 +4,10 @@ #include #include #include -#include "Dataset.h" +#include +#include "Utils.h" + +//#include "Dataset.h" namespace platform { class DotEnv { private: @@ -51,7 +54,7 @@ namespace platform { auto seeds_str = env["seeds"]; seeds_str = trim(seeds_str); seeds_str = seeds_str.substr(1, seeds_str.size() - 2); - auto seeds_str_split = Dataset::split(seeds_str, ','); + auto seeds_str_split = split(seeds_str, ','); transform(seeds_str_split.begin(), seeds_str_split.end(), back_inserter(seeds), [](const std::string& str) { return stoi(str); }); diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index dced445..311dbc7 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -3,7 +3,7 @@ #include "Datasets.h" #include "Models.h" #include "ReportConsole.h" -#include "DotEnv.h" +#include "Paths.h" namespace platform { using json = nlohmann::json; string get_date() @@ -134,8 +134,7 @@ namespace platform { } void Experiment::cross_validation(const string& fileName) { - auto env = platform::DotEnv(); - auto datasets = platform::Datasets(discretized, env.get("source_data")); + auto datasets = platform::Datasets(discretized, Paths::datasets()); // Get dataset auto [X, y] = datasets.getTensors(fileName); auto states = datasets.getStates(fileName); diff --git a/src/Platform/Paths.h b/src/Platform/Paths.h index 926568e..a1eb00c 100644 --- a/src/Platform/Paths.h +++ b/src/Platform/Paths.h @@ -1,11 +1,17 @@ #ifndef PATHS_H #define PATHS_H #include +#include "DotEnv.h" namespace platform { class Paths { public: static std::string results() { return "results/"; } static std::string excel() { return "excel/"; } + static std::string datasets() + { + auto env = platform::DotEnv(); + return env.get("source_data"); + } }; } #endif \ No newline at end of file diff --git a/src/Platform/ReportBase.cc b/src/Platform/ReportBase.cc index 5f113a5..acb5581 100644 --- a/src/Platform/ReportBase.cc +++ b/src/Platform/ReportBase.cc @@ -58,8 +58,7 @@ namespace platform { } } else { if (data["score_name"].get() == "accuracy") { - auto env = platform::DotEnv(); - auto dt = Datasets(false, env.get("source_data")); + auto dt = Datasets(false, Paths::datasets()); dt.loadDataset(dataset); auto numClasses = dt.getNClasses(dataset); if (numClasses == 2) { diff --git a/src/Platform/ReportConsole.cc b/src/Platform/ReportConsole.cc index bb08ef3..aaba840 100644 --- a/src/Platform/ReportConsole.cc +++ b/src/Platform/ReportConsole.cc @@ -56,10 +56,12 @@ namespace platform { try { cout << r["hyperparameters"].get(); } - catch (const exception& err) { - cout << r["hyperparameters"]; + catch (...) { + //cout << r["hyperparameters"]; + cout << "Arrggggghhhh!" << endl; } cout << endl; + cout << flush; lastResult = r; totalScore += r["score"].get(); odd = !odd; diff --git a/src/Platform/Utils.h b/src/Platform/Utils.h new file mode 100644 index 0000000..3e24f05 --- /dev/null +++ b/src/Platform/Utils.h @@ -0,0 +1,19 @@ +#ifndef UTILS_H +#define UTILS_H +#include +#include +#include +namespace platform { + //static vector split(const string& text, char delimiter); + static std::vector split(const std::string& text, char delimiter) + { + std::vector result; + std::stringstream ss(text); + std::string token; + while (std::getline(ss, token, delimiter)) { + result.push_back(token); + } + return result; + } +} +#endif \ No newline at end of file diff --git a/src/Platform/list.cc b/src/Platform/list.cc index 8c386a5..581ee5f 100644 --- a/src/Platform/list.cc +++ b/src/Platform/list.cc @@ -3,7 +3,6 @@ #include "Paths.h" #include "Colors.h" #include "Datasets.h" -#include "DotEnv.h" using namespace std; const int BALANCE_LENGTH = 75; @@ -28,8 +27,7 @@ void outputBalance(const string& balance) int main(int argc, char** argv) { - auto env = platform::DotEnv(); - auto data = platform::Datasets(false, env.get("source_data")); + auto data = platform::Datasets(false, platform::Paths::datasets()); locale mylocale(cout.getloc(), new separated); locale::global(mylocale); cout.imbue(mylocale); diff --git a/src/Platform/main.cc b/src/Platform/main.cc index 62470c5..1101e2b 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -82,8 +82,7 @@ int main(int argc, char** argv) auto seeds = program.get>("seeds"); auto hyperparameters = program.get("hyperparameters"); vector filesToTest; - auto env = platform::DotEnv(); - auto datasets = platform::Datasets(discretize_dataset, env.get("source_data")); + auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets()); auto title = program.get("title"); auto saveResults = program.get("save"); if (file_name != "") { @@ -102,7 +101,7 @@ int main(int argc, char** argv) /* * Begin Processing */ - + auto env = platform::DotEnv(); auto experiment = platform::Experiment(); experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3"); experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));