diff --git a/src/commands/b_main.cpp b/src/commands/b_main.cpp index 6288ebc..087854f 100644 --- a/src/commands/b_main.cpp +++ b/src/commands/b_main.cpp @@ -48,6 +48,7 @@ void manageArguments(argparse::ArgumentParser& program) program.add_argument("--title").default_value("").help("Experiment title"); program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); program.add_argument("--discretize-algo").help("Discretize input dataset").default_value(env.get("discretize_algo")); + program.add_argument("--smooth-strat").help("Discretize input dataset").default_value(env.get("smooth_strat")); program.add_argument("--generate-fold-files").help("generate fold information in datasets_experiment folder").default_value(false).implicit_value(true); program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); @@ -75,7 +76,7 @@ int main(int argc, char** argv) { argparse::ArgumentParser program("b_main", { platform_project_version.begin(), platform_project_version.end() }); manageArguments(program); - std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo; + std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat; json hyperparameters_json; bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files; std::vector seeds; @@ -90,6 +91,7 @@ int main(int argc, char** argv) model_name = program.get("model"); discretize_dataset = program.get("discretize"); discretize_algo = program.get("discretize-algo"); + smooth_strat = program.get("smooth-strat"); stratified = program.get("stratified"); quiet = program.get("quiet"); n_folds = program.get("folds"); @@ -180,7 +182,7 @@ int main(int argc, char** argv) auto env = platform::DotEnv(); auto experiment = platform::Experiment(); experiment.setTitle(title).setLanguage("c++").setLanguageVersion("gcc 14.1.1"); - experiment.setDiscretizationAlgorithm(discretize_algo); + experiment.setDiscretizationAlgorithm(discretize_algo).setSmoothSrategy(smooth_strat); experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform")); experiment.setStratified(stratified).setNFolds(n_folds).setScoreName("accuracy"); experiment.setHyperparameters(test_hyperparams); diff --git a/src/common/DotEnv.h b/src/common/DotEnv.h index a50ad5e..4fea7bf 100644 --- a/src/common/DotEnv.h +++ b/src/common/DotEnv.h @@ -30,7 +30,8 @@ namespace platform { {"margin", {"0.1", "0.2", "0.3"}}, {"n_folds", {"5", "10"}}, {"discretize_algo", {"mdlp", "bin3u", "bin3q", "bin4u", "bin4q"}}, - {"platform", {"any"}}, + {"smooth_strat", {"OLD_LAPLACE", "LAPLACE", "CESTNIK"}}, + { "platform", {"any"} }, {"model", {"any"}}, {"seeds", {"any"}}, {"nodes", {"any"}}, diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 64affe1..3866c97 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -194,6 +194,7 @@ namespace platform { // // Train model // + clf->setSmoothing(smooth_type); clf->fit(X_train, y_train, features, className, states); if (!quiet) showProgress(nfold + 1, getColor(clf->getStatus()), "b"); diff --git a/src/main/Experiment.h b/src/main/Experiment.h index ef85cc0..29d208b 100644 --- a/src/main/Experiment.h +++ b/src/main/Experiment.h @@ -7,6 +7,7 @@ #include "bayesnet/BaseClassifier.h" #include "HyperParameters.h" #include "results/Result.h" +#include "bayesnet/network/Network.h" namespace platform { using json = nlohmann::ordered_json; @@ -24,6 +25,17 @@ namespace platform { { this->discretization_algo = discretization_algo; this->result.setDiscretizationAlgorithm(discretization_algo); return *this; } + Experiment& setSmoothSrategy(const std::string& smooth_strategy) + { + this->smooth_strategy = smooth_strategy; this->result.setSmoothStrategy(smooth_strategy); + if (smooth_strategy == "OLD_LAPLACE") + smooth_type = bayesnet::Smoothing_t::OLD_LAPLACE; + else if (smooth_strategy == "LAPLACE") + smooth_type = bayesnet::Smoothing_t::LAPLACE; + else if (smooth_strategy == "CESTNIK") + smooth_type = bayesnet::Smoothing_t::CESTNIK; + return *this; + } Experiment& setLanguageVersion(const std::string& language_version) { this->result.setLanguageVersion(language_version); return *this; } Experiment& setDiscretized(bool discretized) { this->discretized = discretized; result.setDiscretized(discretized); return *this; } Experiment& setStratified(bool stratified) { this->stratified = stratified; result.setStratified(stratified); return *this; } @@ -43,6 +55,8 @@ namespace platform { std::vector results; std::vector randomSeeds; std::string discretization_algo; + std::string smooth_strategy; + bayesnet::Smoothing_t smooth_type = bayesnet::Smoothing_t::OLD_LAPLACE; HyperParameters hyperparameters; int nfolds{ 0 }; int max_name{ 7 }; // max length of dataset name for formatting (default 7) diff --git a/src/reports/ReportConsole.cpp b/src/reports/ReportConsole.cpp index 92b670c..3c9ddbb 100644 --- a/src/reports/ReportConsole.cpp +++ b/src/reports/ReportConsole.cpp @@ -23,10 +23,12 @@ namespace platform { + " random seeds. " + data["date"].get() + " " + data["time"].get() ); sheader << headerLine(data["title"].get()); - std::string algorithm = data["discretized"].get() ? " (" + data["discretization_algorithm"].get() + ")" : ""; + std::string discretiz_algo = data.find("discretization_algorithm") != data.end() ? data["discretization_algorithm"].get() : "OLD_LAPLACE"; + std::string algorithm = data["discretized"].get() ? " (" + discretiz_algo + ")" : ""; + std::string smooth = data.find("smooth_strategy") != data.end() ? data["smooth_strategy"].get() : "OLD_LAPLACE"; sheader << headerLine( "Random seeds: " + fromVector("seeds") + " Discretized: " + (data["discretized"].get() ? "True" : "False") + algorithm - + " Stratified: " + (data["stratified"].get() ? "True" : "False") + + " Stratified: " + (data["stratified"].get() ? "True" : "False") + " Smooth Strategy: " + smooth ); oss << "Execution took " << std::setprecision(2) << std::fixed << data["duration"].get() << " seconds, " << data["duration"].get() / 3600 << " hours, on " << data["platform"].get(); diff --git a/src/reports/ReportExcel.cpp b/src/reports/ReportExcel.cpp index 0a63b5a..0acd714 100644 --- a/src/reports/ReportExcel.cpp +++ b/src/reports/ReportExcel.cpp @@ -49,7 +49,10 @@ namespace platform { worksheet_merge_range(worksheet, 0, 0, 0, 12, message.c_str(), styles["headerFirst"]); worksheet_merge_range(worksheet, 1, 0, 1, 12, data["title"].get().c_str(), styles["headerRest"]); worksheet_merge_range(worksheet, 2, 0, 3, 0, ("Score is " + data["score_name"].get()).c_str(), styles["headerRest"]); - worksheet_merge_range(worksheet, 2, 1, 3, 3, "Execution time", styles["headerRest"]); + writeString(2, 1, "Smooth", "headerRest"); + std::string smooth = data.find("smooth_strategy") != data.end() ? data["smooth_strategy"].get() : "OLD_LAPLACE"; + writeString(3, 1, smooth, "headerSmall"); + worksheet_merge_range(worksheet, 2, 2, 3, 3, "Execution time", styles["headerRest"]); oss << std::setprecision(2) << std::fixed << data["duration"].get() << " s"; worksheet_merge_range(worksheet, 2, 4, 2, 5, oss.str().c_str(), styles["headerRest"]); oss.str(""); @@ -65,7 +68,8 @@ namespace platform { worksheet_merge_range(worksheet, 3, 10, 3, 11, oss.str().c_str(), styles["headerSmall"]); oss.str(""); oss.clear(); - std::string algorithm = data["discretized"].get() ? " (" + data["discretization_algorithm"].get() + ")" : ""; + std::string discretiz_algo = data.find("discretization_algorithm") != data.end() ? data["discretization_algorithm"].get() : "mdlp"; + std::string algorithm = data["discretized"].get() ? " (" + discretiz_algo + ")" : ""; oss << "Discretized: " << (data["discretized"].get() ? "True" : "False") << algorithm; worksheet_write_string(worksheet, 3, 12, oss.str().c_str(), styles["headerSmall"]); } diff --git a/src/results/Result.h b/src/results/Result.h index 4a9f4f2..70d6c6b 100644 --- a/src/results/Result.h +++ b/src/results/Result.h @@ -32,6 +32,7 @@ namespace platform { json getData() const { return data; } // Setters void setTitle(const std::string& title) { data["title"] = title; }; + void setSmoothStrategy(const std::string& smooth_strategy) { data["smooth_strategy"] = smooth_strategy; }; void setDiscretizationAlgorithm(const std::string& discretization_algo) { data["discretization_algorithm"] = discretization_algo; }; void setLanguage(const std::string& language) { data["language"] = language; }; void setLanguageVersion(const std::string& language_version) { data["language_version"] = language_version; }; @@ -45,7 +46,6 @@ namespace platform { void setStratified(bool stratified) { data["stratified"] = stratified; }; void setNFolds(int nfolds) { data["folds"] = nfolds; }; void setPlatform(const std::string& platform_name) { data["platform"] = platform_name; }; - private: json data; bool complete;