diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc
index 4ddf30d..aeae235 100644
--- a/src/BayesNet/BoostAODE.cc
+++ b/src/BayesNet/BoostAODE.cc
@@ -11,30 +11,7 @@ namespace bayesnet {
     void BoostAODE::buildModel(const torch::Tensor& weights)
     {
         // Models shall be built in trainModel
-    }
-    void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
-    {
-        // Check if hyperparameters are valid
-        const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending", "convergence" };
-        checkHyperparameters(validKeys, hyperparameters);
-        if (hyperparameters.contains("repeatSparent")) {
-            repeatSparent = hyperparameters["repeatSparent"];
-        }
-        if (hyperparameters.contains("maxModels")) {
-            maxModels = hyperparameters["maxModels"];
-        }
-        if (hyperparameters.contains("ascending")) {
-            ascending = hyperparameters["ascending"];
-        }
-        if (hyperparameters.contains("convergence")) {
-            convergence = hyperparameters["convergence"];
-        }
-        if (hyperparameters.contains("cfs")) {
-            cfs = hyperparameters["cfs"];
-        }
-    }
-    void BoostAODE::validationInit()
-    {
+        // Prepare the validation dataset
         auto y_ = dataset.index({ -1, "..." });
         if (convergence) {
             // Prepare train & validation sets from train data
@@ -60,12 +37,43 @@ namespace bayesnet {
             X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
             y_train = y_;
         }
-
+        if (cfs != "") {
+            initializeModels();
+        }
+    }
+    void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
+    {
+        // Check if hyperparameters are valid
+        const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending", "convergence", "cfs" };
+        checkHyperparameters(validKeys, hyperparameters);
+        if (hyperparameters.contains("repeatSparent")) {
+            repeatSparent = hyperparameters["repeatSparent"];
+        }
+        if (hyperparameters.contains("maxModels")) {
+            maxModels = hyperparameters["maxModels"];
+        }
+        if (hyperparameters.contains("ascending")) {
+            ascending = hyperparameters["ascending"];
+        }
+        if (hyperparameters.contains("convergence")) {
+            convergence = hyperparameters["convergence"];
+        }
+        if (hyperparameters.contains("cfs")) {
+            cfs = hyperparameters["cfs"];
+        }
     }
     void BoostAODE::initializeModels()
     {
         ifstream file(cfs + ".json");
         if (file.is_open()) {
+            nlohmann::json data;
+            file >> data;
+            file.close();
+            auto model = "iris"; // has to come in when building object
+            auto features = data[model];
+            cout << "features: " << features.dump() << endl;
+        } else {
+            throw runtime_error("File " + cfs + ".json not found");
         }
     }
     void BoostAODE::trainModel(const torch::Tensor& weights)
@@ -74,11 +82,7 @@ namespace bayesnet {
         n_models = 0;
         if (maxModels == 0)
             maxModels = .1 * n > 10 ? .1 * n : n;
-        validationInit();
         Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
-        if (cfs != "") {
-            initializeModels();
-        }
         bool exitCondition = false;
         unordered_set<int> featuresUsed;
         // Variables to control the accuracy finish condition
diff --git a/src/BayesNet/BoostAODE.h b/src/BayesNet/BoostAODE.h
index 5c99145..f3fa5bd 100644
--- a/src/BayesNet/BoostAODE.h
+++ b/src/BayesNet/BoostAODE.h
@@ -15,7 +15,6 @@ namespace bayesnet {
     private:
         torch::Tensor dataset_;
         torch::Tensor X_train, y_train, X_test, y_test;
-        void validationInit();
         void initializeModels();
         // Hyperparameters
         bool repeatSparent = false; // if true, a feature can be selected more than once
diff --git a/src/Platform/ReportConsole.cc b/src/Platform/ReportConsole.cc
index aaba840..c8e6890 100644
--- a/src/Platform/ReportConsole.cc
+++ b/src/Platform/ReportConsole.cc
@@ -53,13 +53,7 @@ namespace platform {
            const string status = compareResult(r["dataset"].get<string>(), r["score"].get<double>());
            cout << status;
            cout << setw(12) << right << setprecision(6) << fixed << r["time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get<double>() << " ";
-            try {
-                cout << r["hyperparameters"].get<string>();
-            }
-            catch (...) {
-                //cout << r["hyperparameters"];
-                cout << "Arrggggghhhh!" << endl;
-            }
+            cout << r["hyperparameters"].dump();
            cout << endl;
            cout << flush;
            lastResult = r;
diff --git a/src/Platform/main.cc b/src/Platform/main.cc
index 1101e2b..ecdf258 100644
--- a/src/Platform/main.cc
+++ b/src/Platform/main.cc
@@ -12,7 +12,7 @@
 using namespace std;
 using json = nlohmann::json;
 
-argparse::ArgumentParser manageArguments(int argc, char** argv)
+argparse::ArgumentParser manageArguments()
 {
     auto env = platform::DotEnv();
     argparse::ArgumentParser program("main");
@@ -48,43 +48,40 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
        }});
     auto seed_values = env.getSeeds();
     program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
+    return program;
+}
+
+int main(int argc, char** argv)
+{
+    string file_name, model_name, title;
+    json hyperparameters_json;
+    bool discretize_dataset, stratified, saveResults;
+    vector<int> seeds;
+    vector<string> filesToTest;
+    int n_folds;
+    auto program = manageArguments();
     try {
         program.parse_args(argc, argv);
-        auto file_name = program.get<string>("dataset");
-        auto model_name = program.get<string>("model");
-        auto discretize_dataset = program.get<bool>("discretize");
-        auto stratified = program.get<bool>("stratified");
-        auto n_folds = program.get<int>("folds");
-        auto seeds = program.get<vector<int>>("seeds");
-        auto title = program.get<string>("title");
+        file_name = program.get<string>("dataset");
+        model_name = program.get<string>("model");
+        discretize_dataset = program.get<bool>("discretize");
+        stratified = program.get<bool>("stratified");
+        n_folds = program.get<int>("folds");
+        seeds = program.get<vector<int>>("seeds");
         auto hyperparameters = program.get<string>("hyperparameters");
-        auto saveResults = program.get<bool>("save");
+        hyperparameters_json = json::parse(hyperparameters);
+        title = program.get<string>("title");
         if (title == "" && file_name == "") {
             throw runtime_error("title is mandatory if dataset is not provided");
         }
+        saveResults = program.get<bool>("save");
     }
     catch (const exception& err) {
         cerr << err.what() << endl;
         cerr << program;
         exit(1);
     }
-    return program;
-}
-
-int main(int argc, char** argv)
-{
-    auto program = manageArguments(argc, argv);
-    auto file_name = program.get<string>("dataset");
-    auto model_name = program.get<string>("model");
-    auto discretize_dataset = program.get<bool>("discretize");
-    auto stratified = program.get<bool>("stratified");
-    auto n_folds = program.get<int>("folds");
-    auto seeds = program.get<vector<int>>("seeds");
-    auto hyperparameters = program.get<string>("hyperparameters");
-    vector<string> filesToTest;
     auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets());
-    auto title = program.get<string>("title");
-    auto saveResults = program.get<bool>("save");
     if (file_name != "") {
         if (!datasets.isDataset(file_name)) {
             cerr << "Dataset " << file_name << " not found" << endl;
@@ -106,7 +103,7 @@ int main(int argc, char** argv)
     experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3");
     experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));
     experiment.setStratified(stratified).setNFolds(n_folds).setScoreName("accuracy");
-    experiment.setHyperparameters(json::parse(hyperparameters));
+    experiment.setHyperparameters(hyperparameters_json);
     for (auto seed : seeds) {
         experiment.addRandomSeed(seed);
     }