Compare commits

..

21 Commits

Author SHA1 Message Date
92820555da Simple fix 2023-10-28 10:56:47 +02:00
5a3af51826 Activate best score in odte 2023-10-25 10:23:42 +02:00
a8f9800631 Fix mistake when no results in manage 2023-10-24 19:44:23 +02:00
84cec0c1e0 Add results files affected in best results excel 2023-10-24 16:18:52 +02:00
130139f644 Update formulas to use letters in ranges in excel 2023-10-24 13:06:31 +02:00
651f84b562 Fix mistake in conditional format in bestresults 2023-10-24 11:18:19 +02:00
553ab0fa22 Add conditional format to BestResults Excel 2023-10-24 10:56:41 +02:00
4975feabff Fix mistake in node count 2023-10-23 22:46:10 +02:00
32293af69f Fix header in manage 2023-10-23 17:04:59 +02:00
858664be2d Add total number of results in manage 2023-10-23 16:22:15 +02:00
1f705f6018 Refactor BestScore and add experiment to .env 2023-10-23 16:12:52 +02:00
7bcd2eed06 Add variable width of dataset name in reports 2023-10-22 22:58:52 +02:00
833acefbb3 Fix index limits mistake in manage 2023-10-22 20:21:50 +02:00
26b649ebae Refactor ManageResults and CommandParser 2023-10-22 20:03:34 +02:00
080eddf9cd Fix hyperparameters output in b_best 2023-10-20 22:52:48 +02:00
04e754b2f5 Adjust filename and hyperparameters in reports 2023-10-20 11:12:46 +02:00
38423048bd Add excel to best report of model 2023-10-19 18:12:55 +02:00
64fc97b892 Rename utilities sources to match final names 2023-10-19 09:57:04 +02:00
2c2159f192 Add quiet mode to b_main
Reduce output when --quiet is set, not showing fold info
2023-10-17 21:51:53 +02:00
6765552a7c Update submodule versions 2023-10-16 19:21:57 +02:00
f72aa5b9a6 Merge pull request 'Create Boost_CFS' (#11) from Boost_CFS into main
Add hyper parameter to BoostAODE. This hyper parameter decides if we select features with cfs/fcbf/iwss before start building models and build a Spode with the selected features.
The hyperparameter is select_features
2023-10-15 09:22:14 +00:00
37 changed files with 750 additions and 671 deletions

View File

@@ -1,5 +1,7 @@
# BayesNet # BayesNet
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
Bayesian Network Classifier with libtorch from scratch Bayesian Network Classifier with libtorch from scratch
## 0. Setup ## 0. Setup

View File

@@ -1,33 +0,0 @@
Weights matrix:
0.0000000, 0.0384968, 0.0795434, 0.1546867, -0.0000000, 0.1788104, 0.2214721, 0.0323837, 0.0366549,
0.0384968, 0.0000000, 0.0200662, 0.0200937, -0.0000000, 0.0637224, 0.0183005, 0.0127657, 0.0136054,
0.0795434, 0.0200662, 0.0000000, 0.0605489, -0.0000000, 0.0894469, 0.1689408, 0.0321602, 0.0223184,
0.1546867, 0.0200937, 0.0605489, 0.0000000, -0.0000000, 0.1150757, 0.1332292, 0.0422865, 0.0191138,
-0.0000000, -0.0000000, -0.0000000, -0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000,
0.1788104, 0.0637224, 0.0894469, 0.1150757, 0.0000000, 0.0000000, 0.1407102, 0.0406590, 0.0366986,
0.2214721, 0.0183005, 0.1689408, 0.1332292, 0.0000000, 0.1407102, 0.0000000, 0.0427515, 0.0349965,
0.0323837, 0.0127657, 0.0321602, 0.0422865, 0.0000000, 0.0406590, 0.0427515, 0.0000000, 0.0343376,
0.0366549, 0.0136054, 0.0223184, 0.0191138, 0.0000000, 0.0366986, 0.0349965, 0.0343376, 0.0000000,
Edge : Weight
0 - 6 : 0.2214721
0 - 5 : 0.1788104
2 - 6 : 0.1689408
0 - 3 : 0.1546867
1 - 5 : 0.0637224
6 - 7 : 0.0427515
5 - 8 : 0.0366986
4 - 5 : 0.0000000
-------------------------------------------------------------------------------
Metrics Test
Test Maximum Spanning Tree
-------------------------------------------------------------------------------
/Users/rmontanana/Code/BayesNet/tests/TestBayesMetrics.cc:58
...............................................................................
/Users/rmontanana/Code/BayesNet/tests/TestBayesMetrics.cc:69: PASSED:
REQUIRE( result == resultsMST.at(file_name) )
with expansion:
(0, 6) (0, 5) (0, 3) (5, 1) (5, 8) (5, 4) (6, 2) (6, 7)
==
(0, 6) (0, 5) (0, 3) (5, 1) (5, 8) (5, 4) (6, 2) (6, 7)

View File

@@ -125,7 +125,7 @@ namespace bayesnet {
// Step 0: Set the finish condition // Step 0: Set the finish condition
// if not repeatSparent a finish condition is run out of features // if not repeatSparent a finish condition is run out of features
// n_models == maxModels // n_models == maxModels
// epsiolon sub t > 0.5 => inverse the weights policy // epsilon sub t > 0.5 => inverse the weights policy
// validation error is not decreasing // validation error is not decreasing
while (!exitCondition) { while (!exitCondition) {
// Step 1: Build ranking with mutual information // Step 1: Build ranking with mutual information

View File

@@ -137,7 +137,7 @@ namespace bayesnet {
int Classifier::getNumberOfNodes() const int Classifier::getNumberOfNodes() const
{ {
// Features does not include class // Features does not include class
return fitted ? model.getFeatures().size() + 1 : 0; return fitted ? model.getFeatures().size() : 0;
} }
int Classifier::getNumberOfEdges() const int Classifier::getNumberOfEdges() const
{ {

View File

@@ -3,6 +3,7 @@
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <algorithm>
#include "BestResults.h" #include "BestResults.h"
#include "Result.h" #include "Result.h"
#include "Colors.h" #include "Colors.h"
@@ -27,7 +28,6 @@ std::string ftime_to_string(TP tp)
return buffer.str(); return buffer.str();
} }
namespace platform { namespace platform {
string BestResults::build() string BestResults::build()
{ {
auto files = loadResultFiles(); auto files = loadResultFiles();
@@ -65,12 +65,10 @@ namespace platform {
file.close(); file.close();
return bestFileName; return bestFileName;
} }
string BestResults::bestResultFile() string BestResults::bestResultFile()
{ {
return "best_results_" + score + "_" + model + ".json"; return "best_results_" + score + "_" + model + ".json";
} }
pair<string, string> getModelScore(string name) pair<string, string> getModelScore(string name)
{ {
// results_accuracy_BoostAODE_MacBookpro16_2023-09-06_12:27:00_1.json // results_accuracy_BoostAODE_MacBookpro16_2023-09-06_12:27:00_1.json
@@ -82,7 +80,6 @@ namespace platform {
string model = name.substr(pos2 + 1, pos - pos2 - 1); string model = name.substr(pos2 + 1, pos - pos2 - 1);
return { model, score }; return { model, score };
} }
vector<string> BestResults::loadResultFiles() vector<string> BestResults::loadResultFiles()
{ {
vector<string> files; vector<string> files;
@@ -99,7 +96,6 @@ namespace platform {
} }
return files; return files;
} }
json BestResults::loadFile(const string& fileName) json BestResults::loadFile(const string& fileName)
{ {
ifstream resultData(fileName); ifstream resultData(fileName);
@@ -136,7 +132,6 @@ namespace platform {
} }
return datasets; return datasets;
} }
void BestResults::buildAll() void BestResults::buildAll()
{ {
auto models = getModels(); auto models = getModels();
@@ -147,8 +142,7 @@ namespace platform {
} }
model = "any"; model = "any";
} }
void BestResults::listFile()
void BestResults::reportSingle()
{ {
string bestFileName = path + bestResultFile(); string bestFileName = path + bestResultFile();
if (FILE* fileTest = fopen(bestFileName.c_str(), "r")) { if (FILE* fileTest = fopen(bestFileName.c_str(), "r")) {
@@ -162,22 +156,35 @@ namespace platform {
auto data = loadFile(bestFileName); auto data = loadFile(bestFileName);
auto datasets = getDatasets(data); auto datasets = getDatasets(data);
int maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const string& a, const string& b) { return a.size() < b.size(); })).size(); int maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const string& a, const string& b) { return a.size() < b.size(); })).size();
cout << Colors::GREEN() << "Best results for " << model << " and " << score << " as of " << date << endl; int maxFileName = 0;
cout << "--------------------------------------------------------" << endl; int maxHyper = 15;
cout << Colors::GREEN() << " # " << setw(maxDatasetName + 1) << left << string("Dataset") << "Score File Hyperparameters" << endl; for (auto const& item : data.items()) {
cout << "=== " << string(maxDatasetName, '=') << " =========== ================================================================== ================================================= " << endl; maxHyper = max(maxHyper, (int)item.value().at(1).dump().size());
maxFileName = max(maxFileName, (int)item.value().at(2).get<string>().size());
}
stringstream oss;
oss << Colors::GREEN() << "Best results for " << model << " as of " << date << endl;
cout << oss.str();
cout << string(oss.str().size() - 8, '-') << endl;
cout << Colors::GREEN() << " # " << setw(maxDatasetName + 1) << left << "Dataset" << "Score " << setw(maxFileName) << "File" << " Hyperparameters" << endl;
cout << "=== " << string(maxDatasetName, '=') << " =========== " << string(maxFileName, '=') << " " << string(maxHyper, '=') << endl;
auto i = 0; auto i = 0;
bool odd = true; bool odd = true;
double total = 0;
for (auto const& item : data.items()) { for (auto const& item : data.items()) {
auto color = odd ? Colors::BLUE() : Colors::CYAN(); auto color = odd ? Colors::BLUE() : Colors::CYAN();
double value = item.value().at(0).get<double>();
cout << color << setw(3) << fixed << right << i++ << " "; cout << color << setw(3) << fixed << right << i++ << " ";
cout << setw(maxDatasetName) << left << item.key() << " "; cout << setw(maxDatasetName) << left << item.key() << " ";
cout << setw(11) << setprecision(9) << fixed << item.value().at(0).get<double>() << " "; cout << setw(11) << setprecision(9) << fixed << value << " ";
cout << setw(66) << item.value().at(2).get<string>() << " "; cout << setw(maxFileName) << item.value().at(2).get<string>() << " ";
cout << item.value().at(1) << " "; cout << item.value().at(1) << " ";
cout << endl; cout << endl;
total += value;
odd = !odd; odd = !odd;
} }
cout << Colors::GREEN() << "=== " << string(maxDatasetName, '=') << " ===========" << endl;
cout << setw(5 + maxDatasetName) << "Total.................. " << setw(11) << setprecision(8) << fixed << total << endl;
} }
json BestResults::buildTableResults(vector<string> models) json BestResults::buildTableResults(vector<string> models)
{ {
@@ -202,11 +209,12 @@ namespace platform {
table["dateTable"] = ftime_to_string(maxDate); table["dateTable"] = ftime_to_string(maxDate);
return table; return table;
} }
void BestResults::printTableResults(vector<string> models, json table) void BestResults::printTableResults(vector<string> models, json table)
{ {
cout << Colors::GREEN() << "Best results for " << score << " as of " << table.at("dateTable").get<string>() << endl; stringstream oss;
cout << "------------------------------------------------" << endl; oss << Colors::GREEN() << "Best results for " << score << " as of " << table.at("dateTable").get<string>() << endl;
cout << oss.str();
cout << string(oss.str().size() - 8, '-') << endl;
cout << Colors::GREEN() << " # " << setw(maxDatasetName + 1) << left << string("Dataset"); cout << Colors::GREEN() << " # " << setw(maxDatasetName + 1) << left << string("Dataset");
for (const auto& model : models) { for (const auto& model : models) {
cout << setw(maxModelName) << left << model << " "; cout << setw(maxModelName) << left << model << " ";
@@ -271,6 +279,19 @@ namespace platform {
} }
cout << endl; cout << endl;
} }
void BestResults::reportSingle(bool excel)
{
listFile();
if (excel) {
auto models = getModels();
// Build the table of results
json table = buildTableResults(models);
vector<string> datasets = getDatasets(table.begin().value());
BestResultsExcel excel(score, datasets);
excel.reportSingle(model, path + bestResultFile());
messageExcelFile(excel.getFileName());
}
}
void BestResults::reportAll(bool excel) void BestResults::reportAll(bool excel)
{ {
auto models = getModels(); auto models = getModels();
@@ -292,9 +313,32 @@ namespace platform {
ranksModels = stats.getRanks(); ranksModels = stats.getRanks();
} }
if (excel) { if (excel) {
BestResultsExcel excel(score, models, datasets, table, ranksModels, friedman, significance); BestResultsExcel excel(score, datasets);
excel.build(); excel.reportAll(models, table, ranksModels, friedman, significance);
cout << Colors::YELLOW() << "** Excel file generated: " << excel.getFileName() << Colors::RESET() << endl; if (friedman) {
int idx = -1;
double min = 2000;
// Find out the control model
auto totals = vector<double>(models.size(), 0.0);
for (const auto& dataset : datasets) {
for (int i = 0; i < models.size(); ++i) {
totals[i] += ranksModels[dataset][models[i]];
} }
} }
for (int i = 0; i < models.size(); ++i) {
if (totals[i] < min) {
min = totals[i];
idx = i;
}
}
model = models.at(idx);
excel.reportSingle(model, path + bestResultFile());
}
messageExcelFile(excel.getFileName());
}
}
// Tells the user which Excel file has just been generated.
void BestResults::messageExcelFile(const string& fileName)
{
    // The notice is wrapped between Colors::YELLOW() and Colors::RESET().
    cout << Colors::YELLOW() << "** Excel file generated: " << fileName;
    cout << Colors::RESET() << endl;
}
} }

View File

@@ -7,19 +7,24 @@ using json = nlohmann::json;
namespace platform { namespace platform {
class BestResults { class BestResults {
public: public:
explicit BestResults(const string& path, const string& score, const string& model, bool friedman, double significance = 0.05) : path(path), score(score), model(model), friedman(friedman), significance(significance) {} explicit BestResults(const string& path, const string& score, const string& model, bool friedman, double significance = 0.05)
: path(path), score(score), model(model), friedman(friedman), significance(significance)
{
}
string build(); string build();
void reportSingle(); void reportSingle(bool excel);
void reportAll(bool excel); void reportAll(bool excel);
void buildAll(); void buildAll();
private: private:
vector<string> getModels(); vector<string> getModels();
vector<string> getDatasets(json table); vector<string> getDatasets(json table);
vector<string> loadResultFiles(); vector<string> loadResultFiles();
void messageExcelFile(const string& fileName);
json buildTableResults(vector<string> models); json buildTableResults(vector<string> models);
void printTableResults(vector<string> models, json table); void printTableResults(vector<string> models, json table);
string bestResultFile(); string bestResultFile();
json loadFile(const string& fileName); json loadFile(const string& fileName);
void listFile();
string path; string path;
string score; string score;
string model; string model;

View File

@@ -1,21 +1,122 @@
#include <sstream> #include <sstream>
#include "BestResultsExcel.h" #include "BestResultsExcel.h"
#include "Paths.h" #include "Paths.h"
#include <map>
#include <nlohmann/json.hpp>
#include "Statistics.h" #include "Statistics.h"
#include "ReportExcel.h"
namespace platform { namespace platform {
BestResultsExcel::BestResultsExcel(const string& score, const vector<string>& models, const vector<string>& datasets, const json& table, const map<string, map<string, float>>& ranksModels, bool friedman, double significance) : json loadResultData(const string& fileName)
score(score), models(models), datasets(datasets), table(table), ranksModels(ranksModels), friedman(friedman), significance(significance) {
json data;
ifstream resultData(fileName);
if (resultData.is_open()) {
data = json::parse(resultData);
} else {
throw invalid_argument("Unable to open result file. [" + fileName + "]");
}
return data;
}
// Converts a zero-based column index into spreadsheet column letters:
// 0 -> "A", 25 -> "Z", 26 -> "AA", 27 -> "AB", 701 -> "ZZ", 702 -> "AAA".
// Returns an empty string when colNum is negative.
std::string getColumnName(int colNum)
{
    std::string columnName;
    // Column letters are a bijective base-26 numeral (there is no "zero"
    // letter), so after emitting a letter the remaining value must be reduced
    // by one extra unit; a plain base-26 conversion yields "BA" for index 26.
    for (int remaining = colNum; remaining >= 0; remaining = remaining / 26 - 1) {
        columnName.insert(columnName.begin(), static_cast<char>('A' + remaining % 26));
    }
    return columnName;
}
BestResultsExcel::BestResultsExcel(const string& score, const vector<string>& datasets) : score(score), datasets(datasets)
{ {
workbook = workbook_new((Paths::excel() + fileName).c_str()); workbook = workbook_new((Paths::excel() + fileName).c_str());
worksheet = workbook_add_worksheet(workbook, "Best Results");
setProperties("Best Results"); setProperties("Best Results");
createFormats();
int maxModelName = (*max_element(models.begin(), models.end(), [](const string& a, const string& b) { return a.size() < b.size(); })).size();
modelNameSize = max(modelNameSize, maxModelName);
int maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const string& a, const string& b) { return a.size() < b.size(); })).size(); int maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const string& a, const string& b) { return a.size() < b.size(); })).size();
datasetNameSize = max(datasetNameSize, maxDatasetName); datasetNameSize = max(datasetNameSize, maxDatasetName);
createFormats();
}
void BestResultsExcel::reportAll(const vector<string>& models, const json& table, const map<string, map<string, float>>& ranks, bool friedman, double significance)
{
this->table = table;
this->models = models;
ranksModels = ranks;
this->friedman = friedman;
this->significance = significance;
worksheet = workbook_add_worksheet(workbook, "Best Results");
int maxModelName = (*max_element(models.begin(), models.end(), [](const string& a, const string& b) { return a.size() < b.size(); })).size();
modelNameSize = max(modelNameSize, maxModelName);
formatColumns(); formatColumns();
build();
}
void BestResultsExcel::reportSingle(const string& model, const string& fileName)
{
worksheet = workbook_add_worksheet(workbook, "Report");
if (FILE* fileTest = fopen(fileName.c_str(), "r")) {
fclose(fileTest);
} else {
cerr << "File " << fileName << " doesn't exist." << endl;
exit(1);
}
json data = loadResultData(fileName);
string title = "Best results for " + model;
worksheet_merge_range(worksheet, 0, 0, 0, 4, title.c_str(), styles["headerFirst"]);
// Body header
row = 3;
int col = 1;
writeString(row, 0, "", "bodyHeader");
writeString(row, 1, "Dataset", "bodyHeader");
writeString(row, 2, "Score", "bodyHeader");
writeString(row, 3, "File", "bodyHeader");
writeString(row, 4, "Hyperparameters", "bodyHeader");
auto i = 0;
string hyperparameters;
int hypSize = 22;
map<string, string> files; // map of files imported and their tabs
for (auto const& item : data.items()) {
row++;
writeInt(row, 0, i++, "ints");
writeString(row, 1, item.key().c_str(), "text");
writeDouble(row, 2, item.value().at(0).get<double>(), "result");
auto fileName = item.value().at(2).get<string>();
string hyperlink = "";
try {
hyperlink = files.at(fileName);
}
catch (const out_of_range& oor) {
auto tabName = "table_" + to_string(i);
auto worksheetNew = workbook_add_worksheet(workbook, tabName.c_str());
json data = loadResultData(Paths::results() + fileName);
auto report = ReportExcel(data, false, workbook, worksheetNew);
report.show();
hyperlink = "#table_" + to_string(i);
files[fileName] = hyperlink;
}
hyperlink += "!H" + to_string(i + 6);
string fileNameText = "=HYPERLINK(\"" + hyperlink + "\",\"" + fileName + "\")";
worksheet_write_formula(worksheet, row, 3, fileNameText.c_str(), efectiveStyle("text"));
hyperparameters = item.value().at(1).dump();
if (hyperparameters.size() > hypSize) {
hypSize = hyperparameters.size();
}
writeString(row, 4, hyperparameters, "text");
}
row++;
// Set Totals
writeString(row, 1, "Total", "bodyHeader");
stringstream oss;
auto colName = getColumnName(2);
oss << "=sum(" << colName << "5:" << colName << row << ")";
worksheet_write_formula(worksheet, row, 2, oss.str().c_str(), styles["bodyHeader_odd"]);
// Set format
worksheet_freeze_panes(worksheet, 4, 2);
vector<int> columns_sizes = { 5, datasetNameSize, modelNameSize, 66, hypSize + 1 };
for (int i = 0; i < columns_sizes.size(); ++i) {
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
}
} }
BestResultsExcel::~BestResultsExcel() BestResultsExcel::~BestResultsExcel()
{ {
@@ -32,11 +133,30 @@ namespace platform {
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL); worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
} }
} }
// Adds a conditional format to the current worksheet that highlights, in each
// dataset row, the cells equal to `formula` applied to that row's model
// columns. `formula` is the name of a spreadsheet function; this file calls it
// with "max" for the scores sheet and "min" for the ranks sheet.
void BestResultsExcel::addConditionalFormat(string formula)
{
    // Highlight style used for the matching cells
    lxw_format* custom_format = workbook_add_format(workbook);
    format_set_bg_color(custom_format, 0xFFC7CE);
    format_set_font_color(custom_format, 0x9C0006);
    // Rule for the top-left cell (C5); the range spans column C up to the
    // column of the last model.
    const string col = getColumnName(models.size() + 1);
    const string formulaValue = "=C5=" + formula + "($C5:$" + col + "5)";
    // A zero-initialized automatic object replaces the calloc'ed one, which
    // was never freed and therefore leaked on every call.
    // NOTE(review): this relies on the library not keeping the pointer after
    // the call, as the original "a static object would also work" remark
    // indicates - confirm against the libxlsxwriter version in use.
    lxw_conditional_format conditional_format = {};
    conditional_format.type = LXW_CONDITIONAL_TYPE_FORMULA;
    conditional_format.value_string = formulaValue.c_str();
    conditional_format.format = custom_format;
    worksheet_conditional_format_range(worksheet, 4, 2, datasets.size() + 3, models.size() + 1, &conditional_format);
}
void BestResultsExcel::build() void BestResultsExcel::build()
{ {
// Create Sheet with scores // Create Sheet with scores
header(false); header(false);
body(false); body(false);
// Add conditional format for max values
addConditionalFormat("max");
footer(false); footer(false);
if (friedman) { if (friedman) {
// Create Sheet with ranks // Create Sheet with ranks
@@ -44,6 +164,7 @@ namespace platform {
formatColumns(); formatColumns();
header(true); header(true);
body(true); body(true);
addConditionalFormat("min");
footer(true); footer(true);
// Create Sheet with Friedman Test // Create Sheet with Friedman Test
doFriedman(); doFriedman();
@@ -90,7 +211,8 @@ namespace platform {
int col = 1; int col = 1;
for (const auto& model : models) { for (const auto& model : models) {
stringstream oss; stringstream oss;
oss << "=sum(indirect(address(" << 5 << "," << col + 2 << ")):indirect(address(" << row << "," << col + 2 << ")))"; auto colName = getColumnName(col + 1);
oss << "=SUM(" << colName << "5:" << colName << row << ")";
worksheet_write_formula(worksheet, row, ++col, oss.str().c_str(), styles["bodyHeader_odd"]); worksheet_write_formula(worksheet, row, ++col, oss.str().c_str(), styles["bodyHeader_odd"]);
} }
if (ranks) { if (ranks) {
@@ -98,8 +220,9 @@ namespace platform {
writeString(row, 1, "Average ranks", "bodyHeader"); writeString(row, 1, "Average ranks", "bodyHeader");
int col = 1; int col = 1;
for (const auto& model : models) { for (const auto& model : models) {
auto colName = getColumnName(col + 1);
stringstream oss; stringstream oss;
oss << "=sum(indirect(address(" << 5 << "," << col + 2 << ")):indirect(address(" << row - 1 << "," << col + 2 << ")))/" << datasets.size(); oss << "=SUM(" << colName << "5:" << colName << row - 1 << ")/" << datasets.size();
worksheet_write_formula(worksheet, row, ++col, oss.str().c_str(), styles["bodyHeader_odd"]); worksheet_write_formula(worksheet, row, ++col, oss.str().c_str(), styles["bodyHeader_odd"]);
} }
} }

View File

@@ -12,16 +12,19 @@ namespace platform {
class BestResultsExcel : ExcelFile { class BestResultsExcel : ExcelFile {
public: public:
BestResultsExcel(const string& score, const vector<string>& models, const vector<string>& datasets, const json& table, const map<string, map<string, float>>& ranks, bool friedman, double significance); BestResultsExcel(const string& score, const vector<string>& datasets);
~BestResultsExcel(); ~BestResultsExcel();
void build(); void reportAll(const vector<string>& models, const json& table, const map<string, map<string, float>>& ranks, bool friedman, double significance);
void reportSingle(const string& model, const string& fileName);
string getFileName(); string getFileName();
private: private:
void build();
void header(bool ranks); void header(bool ranks);
void body(bool ranks); void body(bool ranks);
void footer(bool ranks); void footer(bool ranks);
void formatColumns(); void formatColumns();
void doFriedman(); void doFriedman();
void addConditionalFormat(string formula);
const string fileName = "BestResults.xlsx"; const string fileName = "BestResults.xlsx";
string score; string score;
vector<string> models; vector<string> models;

View File

@@ -1,10 +1,28 @@
#ifndef BESTSCORE_H #ifndef BESTSCORE_H
#define BESTSCORE_H #define BESTSCORE_H
#include <string> #include <string>
#include <map>
#include <utility>
#include "DotEnv.h"
namespace platform {
class BestScore { class BestScore {
public: public:
static std::string title() { return "STree_default (linear-ovo)"; } static pair<string, double> getScore(const std::string& metric)
static double score() { return 22.109799; } {
static std::string scoreName() { return "accuracy"; } static map<pair<string, string>, pair<string, double>> data = {
{{"discretiz", "accuracy"}, {"STree_default (linear-ovo)", 22.109799}},
{{"odte", "accuracy"}, {"STree_default (linear-ovo)", 22.109799}},
}; };
auto env = platform::DotEnv();
string experiment = env.get("experiment");
try {
return data[{experiment, metric}];
}
catch (...) {
return { "", 0.0 };
}
}
};
}
#endif #endif

View File

@@ -5,13 +5,13 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include) include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include)
add_executable(b_main main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc ReportConsole.cc ReportBase.cc)
add_executable(b_manage manage.cc Results.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) add_executable(b_main b_main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc ReportConsole.cc ReportBase.cc)
add_executable(b_list list.cc Datasets.cc Dataset.cc) add_executable(b_manage b_manage.cc Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
add_executable(b_best best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ExcelFile.cc) add_executable(b_list b_list.cc Datasets.cc Dataset.cc)
add_executable(testx testx.cpp Datasets.cc Dataset.cc Folding.cc ) add_executable(b_best b_best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp) target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp)
target_link_libraries(b_best Boost::boost "${XLSXWRITER_LIB}") target_link_libraries(b_best Boost::boost "${XLSXWRITER_LIB}" "${TORCH_LIBRARIES}" ArffFiles mdlp)
target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}") target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}")
target_link_libraries(testx ArffFiles BayesNet "${TORCH_LIBRARIES}")

View File

@@ -0,0 +1,87 @@
#include "CommandParser.h"
#include <iostream>
#include <sstream>
#include <algorithm>
#include "Colors.h"
#include "Utils.h"
namespace platform {
// Reports a problem with the user's input on standard output.
void CommandParser::messageError(const string& message)
{
    // The text is wrapped between Colors::RED() and Colors::RESET().
    cout << Colors::RED() << message;
    cout << Colors::RESET() << endl;
}
// Prompts on standard output and reads lines from standard input until the
// user enters a valid command.
//
// color          text emitted before the prompt (e.g. a color escape sequence)
// options        one tuple per command: (label shown in the prompt, key
//                character, whether the command requires a numeric value)
// defaultCommand command assumed when the user types only a number
// maxIndex       largest accepted numeric value; the smallest is 0
//
// Returns {command, index}. The pair is also stored in the object so that
// getCommand()/getIndex() report it afterwards. `index` is only updated when
// a value has been validated, so a rejected entry never leaks into the result.
pair<char, int> CommandParser::parse(const string& color, const vector<tuple<string, char, bool>>& options, const char defaultCommand, const int maxIndex)
{
    // The prompt does not change between attempts, so build it only once.
    stringstream oss;
    oss << color << "Choose option (";
    bool first = true;
    for (const auto& option : options) {
        if (!first) {
            oss << ", ";
        }
        first = false;
        oss << get<char>(option) << "=" << get<string>(option);
    }
    oss << "): ";
    const string prompt = oss.str();
    // Converts `text` into an index within [0, maxIndex] and stores it.
    // On failure it reports the problem and leaves `index` untouched.
    auto readIndex = [&](const string& text) -> bool {
        try {
            const int value = stoi(text);
            if (value > maxIndex || value < 0) {
                messageError("Index out of range");
                return false;
            }
            index = value;
            return true;
        }
        catch (const std::invalid_argument&) {
            messageError("Invalid value: " + text);
        }
        catch (const std::out_of_range&) {
            // stoi throws this when the number does not fit in an int
            messageError("Index out of range");
        }
        return false;
    };
    // NOTE(review): if std::cin reaches end-of-file, getline() leaves `line`
    // empty on every pass and this loop never terminates; confirm how callers
    // expect EOF to be reported before changing it.
    while (true) {
        cout << prompt;
        string line;
        getline(cin, line);
        cout << Colors::RESET();
        line = trim(line);
        if (line.empty()) {
            continue;
        }
        // A bare number selects the default command. The character is widened
        // through unsigned char because ::isdigit is undefined for negative values.
        if (all_of(line.begin(), line.end(), [](unsigned char c) { return ::isdigit(c) != 0; })) {
            if (!readIndex(line)) {
                continue;
            }
            command = defaultCommand;
            break;
        }
        bool found = false;
        bool accepted = false;
        for (const auto& option : options) {
            if (line[0] != get<char>(option)) {
                continue;
            }
            // it's a match: whatever follows the key character is its value
            found = true;
            line.erase(line.begin());
            line = trim(line);
            if (get<bool>(option)) {
                // The option requires a value
                if (line.empty()) {
                    messageError("Option " + get<string>(option) + " requires a value");
                    break;
                }
                if (!readIndex(line)) {
                    break;
                }
            } else if (!line.empty()) {
                messageError("option " + get<string>(option) + " doesn't accept values");
                break;
            }
            command = get<char>(option);
            accepted = true;
            break;
        }
        if (accepted) {
            break;
        }
        if (!found) {
            messageError("I don't know " + line);
        }
    }
    return { command, index };
}
} /* namespace platform */

View File

@@ -0,0 +1,21 @@
#ifndef COMMAND_PARSER_H
#define COMMAND_PARSER_H
#include <string>
#include <vector>
#include <tuple>
using namespace std;
namespace platform {
// Interactive command prompt: parse() keeps asking on the console until the
// user enters one of the offered commands, optionally followed by an index.
class CommandParser {
public:
    CommandParser() = default;
    // Prompts until a valid command is read and returns {command, index}.
    //   color          text emitted before the prompt
    //   options        (label, key character, requires-a-value flag) per command
    //   defaultCommand command assumed when only a number is typed
    //   maxIndex       largest accepted index (the smallest is 0)
    std::pair<char, int> parse(const std::string& color, const std::vector<std::tuple<std::string, char, bool>>& options, const char defaultCommand, const int maxIndex);
    // Command selected by the last parse() call ('\0' before any call).
    char getCommand() const { return command; }
    // Index entered in the last parse() call that supplied one (0 before any).
    int getIndex() const { return index; }
private:
    // Prints an error message for the user.
    void messageError(const std::string& message);
    // Both members are initialized so the getters never read indeterminate
    // values before the first successful parse().
    char command{ '\0' };
    int index{ 0 };
};
} /* namespace platform */
#endif /* COMMAND_PARSER_H */

View File

@@ -13,17 +13,6 @@ namespace platform {
class DotEnv { class DotEnv {
private: private:
std::map<std::string, std::string> env; std::map<std::string, std::string> env;
// Returns a copy of `str` without its leading and trailing whitespace, where
// whitespace is whatever std::isspace reports for the current C locale.
std::string trim(const std::string& str)
{
    // std::isspace is only defined for values representable as unsigned char
    // (or EOF), so bytes above 0x7F must not reach it as negative ints.
    const auto notSpace = [](unsigned char ch) { return !std::isspace(ch); };
    const auto first = std::find_if(str.begin(), str.end(), notSpace);
    // Scan backwards but never past `first`; for an all-blank string both ends
    // meet at str.end() and the result is empty.
    const auto last = std::find_if(str.rbegin(), std::string::const_reverse_iterator(first), notSpace).base();
    return std::string(first, last);
}
public: public:
DotEnv() DotEnv()
{ {

View File

@@ -9,6 +9,10 @@ namespace platform {
{ {
setDefault(); setDefault();
} }
// Builds an ExcelFile on top of a workbook and worksheet supplied by the
// caller, then applies the default settings.
ExcelFile::ExcelFile(lxw_workbook* workbook, lxw_worksheet* worksheet)
    : workbook(workbook)
    , worksheet(worksheet)
{
    setDefault();
}
void ExcelFile::setDefault() void ExcelFile::setDefault()
{ {
normalSize = 14; //font size for report body normalSize = 14; //font size for report body

View File

@@ -18,6 +18,7 @@ namespace platform {
public: public:
ExcelFile(); ExcelFile();
ExcelFile(lxw_workbook* workbook); ExcelFile(lxw_workbook* workbook);
ExcelFile(lxw_workbook* workbook, lxw_worksheet* worksheet);
lxw_workbook* getWorkbook(); lxw_workbook* getWorkbook();
protected: protected:
void setProperties(string title); void setProperties(string title);

View File

@@ -102,12 +102,12 @@ namespace platform {
cout << data.dump(4) << endl; cout << data.dump(4) << endl;
} }
void Experiment::go(vector<string> filesToProcess) void Experiment::go(vector<string> filesToProcess, bool quiet)
{ {
cout << "*** Starting experiment: " << title << " ***" << endl; cout << "*** Starting experiment: " << title << " ***" << endl;
for (auto fileName : filesToProcess) { for (auto fileName : filesToProcess) {
cout << "- " << setw(20) << left << fileName << " " << right << flush; cout << "- " << setw(20) << left << fileName << " " << right << flush;
cross_validation(fileName); cross_validation(fileName, quiet);
cout << endl; cout << endl;
} }
} }
@@ -132,7 +132,7 @@ namespace platform {
cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush; cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
} }
void Experiment::cross_validation(const string& fileName) void Experiment::cross_validation(const string& fileName, bool quiet)
{ {
auto datasets = platform::Datasets(discretized, Paths::datasets()); auto datasets = platform::Datasets(discretized, Paths::datasets());
// Get dataset // Get dataset
@@ -141,7 +141,9 @@ namespace platform {
auto features = datasets.getFeatures(fileName); auto features = datasets.getFeatures(fileName);
auto samples = datasets.getNSamples(fileName); auto samples = datasets.getNSamples(fileName);
auto className = datasets.getClassName(fileName); auto className = datasets.getClassName(fileName);
if (!quiet) {
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush; cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
}
// Prepare Result // Prepare Result
auto result = Result(); auto result = Result();
auto [values, counts] = at::_unique(y); auto [values, counts] = at::_unique(y);
@@ -159,6 +161,7 @@ namespace platform {
Timer train_timer, test_timer; Timer train_timer, test_timer;
int item = 0; int item = 0;
for (auto seed : randomSeeds) { for (auto seed : randomSeeds) {
if (!quiet)
cout << "(" << seed << ") doing Fold: " << flush; cout << "(" << seed << ") doing Fold: " << flush;
Fold* fold; Fold* fold;
if (stratified) if (stratified)
@@ -180,9 +183,11 @@ namespace platform {
auto y_train = y.index({ train_t }); auto y_train = y.index({ train_t });
auto X_test = X.index({ "...", test_t }); auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t }); auto y_test = y.index({ test_t });
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "a"); showProgress(nfold + 1, getColor(clf->getStatus()), "a");
// Train model // Train model
clf->fit(X_train, y_train, features, className, states); clf->fit(X_train, y_train, features, className, states);
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "b"); showProgress(nfold + 1, getColor(clf->getStatus()), "b");
nodes[item] = clf->getNumberOfNodes(); nodes[item] = clf->getNumberOfNodes();
edges[item] = clf->getNumberOfEdges(); edges[item] = clf->getNumberOfEdges();
@@ -191,12 +196,14 @@ namespace platform {
// Score train // Score train
auto accuracy_train_value = clf->score(X_train, y_train); auto accuracy_train_value = clf->score(X_train, y_train);
// Test model // Test model
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "c"); showProgress(nfold + 1, getColor(clf->getStatus()), "c");
test_timer.start(); test_timer.start();
auto accuracy_test_value = clf->score(X_test, y_test); auto accuracy_test_value = clf->score(X_test, y_test);
test_time[item] = test_timer.getDuration(); test_time[item] = test_timer.getDuration();
accuracy_train[item] = accuracy_train_value; accuracy_train[item] = accuracy_train_value;
accuracy_test[item] = accuracy_test_value; accuracy_test[item] = accuracy_test_value;
if (!quiet)
cout << "\b\b\b, " << flush; cout << "\b\b\b, " << flush;
// Store results and times in vector // Store results and times in vector
result.addScoreTrain(accuracy_train_value); result.addScoreTrain(accuracy_train_value);
@@ -206,6 +213,7 @@ namespace platform {
item++; item++;
clf.reset(); clf.reset();
} }
if (!quiet)
cout << "end. " << flush; cout << "end. " << flush;
delete fold; delete fold;
} }

View File

@@ -108,8 +108,8 @@ namespace platform {
Experiment& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; } Experiment& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
string get_file_name(); string get_file_name();
void save(const string& path); void save(const string& path);
void cross_validation(const string& fileName); void cross_validation(const string& fileName, bool quiet);
void go(vector<string> filesToProcess); void go(vector<string> filesToProcess, bool quiet);
void show(); void show();
void report(); void report();
}; };

View File

@@ -0,0 +1,213 @@
#include "ManageResults.h"
#include "CommandParser.h"
#include <filesystem>
#include <tuple>
#include "Colors.h"
#include "CLocale.h"
#include "Paths.h"
#include "ReportConsole.h"
#include "ReportExcel.h"
namespace platform {
// Creates the results manager.
//   numFiles : maximum number of results to list/select; 0 means all of them
//   model    : model name handed to Results
//   score    : score name handed to Results
//   complete : handed to Results; also announces "only complete results" in list()
//   partial  : handed to Results; also announces "only partial results" in list()
//   compare  : forwarded to the reporters
// The Results object is built (and loaded) from Paths::results().
ManageResults::ManageResults(int numFiles, const string& model, const string& score, bool complete, bool partial, bool compare) :
    numFiles{ numFiles }, complete{ complete }, partial{ partial }, compare{ compare }, results(Results(Paths::results(), model, score, complete, partial))
{
    indexList = true;
    openExcel = false;
    workbook = NULL;
    // 0 means "everything". A request bigger than what was actually loaded is
    // clamped as well, otherwise the menu would offer indexes that do not exist
    // and list() would report more results on screen than there are.
    if (numFiles == 0 || numFiles > results.size()) {
        this->numFiles = results.size();
    }
}
// Entry point: lists the results sorted by date and runs the interactive menu.
// With no results it only prints a notice. If an Excel report was generated
// during the session the workbook is closed before leaving.
void ManageResults::doMenu()
{
    if (!results.empty()) {
        results.sortDate();
        list();
        menu();
        if (openExcel) {
            workbook_close(workbook);
        }
        cout << Colors::RESET() << "Done!" << endl;
    } else {
        cout << Colors::MAGENTA() << "No results found!" << Colors::RESET() << endl;
    }
}
void ManageResults::list()
{
auto temp = ConfigLocale();
string suffix = numFiles != results.size() ? " of " + to_string(results.size()) : "";
stringstream oss;
oss << "Results on screen: " << numFiles << suffix;
cout << Colors::GREEN() << oss.str() << endl;
cout << string(oss.str().size(), '-') << endl;
if (complete) {
cout << Colors::MAGENTA() << "Only listing complete results" << endl;
}
if (partial) {
cout << Colors::MAGENTA() << "Only listing partial results" << endl;
}
auto i = 0;
int maxModel = results.maxModelSize();
cout << Colors::GREEN() << " # Date " << setw(maxModel) << left << "Model" << " Score Name Score C/P Duration Title" << endl;
cout << "=== ========== " << string(maxModel, '=') << " =========== =========== === ========= =============================================================" << endl;
bool odd = true;
for (auto& result : results) {
auto color = odd ? Colors::BLUE() : Colors::CYAN();
cout << color << setw(3) << fixed << right << i++ << " ";
cout << result.to_string(maxModel) << endl;
if (i == numFiles) {
break;
}
odd = !odd;
}
}
// Asks the user to confirm an action on a file.
//   intent   : verb shown in the question ("delete" is printed in red, anything else in yellow)
//   fileName : name of the file the action applies to
// Keeps asking until a single 'y' or 'n' (either case) is typed.
// Returns true only for 'y'; otherwise prints "Not done!" and returns false.
// If the input stream ends or fails before an answer is read, the action is
// treated as not confirmed instead of looping forever.
bool ManageResults::confirmAction(const string& intent, const string& fileName) const
{
    string color;
    if (intent == "delete") {
        color = Colors::RED();
    } else {
        color = Colors::YELLOW();
    }
    string line;
    char answer = '\0';
    while (answer != 'y' && answer != 'n') {
        cout << color << "Really want to " << intent << " " << fileName << "? (y/n): ";
        if (!getline(cin, line)) {
            // No more input available: never assume consent
            answer = 'n';
            break;
        }
        if (line.size() == 1) {
            // tolower must receive a value representable as unsigned char
            answer = static_cast<char>(tolower(static_cast<unsigned char>(line[0])));
        }
    }
    if (answer == 'y') {
        return true;
    }
    cout << "Not done!" << endl;
    return false;
}
void ManageResults::report(const int index, const bool excelReport)
{
cout << Colors::YELLOW() << "Reporting " << results.at(index).getFilename() << endl;
auto data = results.at(index).load();
if (excelReport) {
ReportExcel reporter(data, compare, workbook);
reporter.show();
openExcel = true;
workbook = reporter.getWorkbook();
cout << "Adding sheet to " << Paths::excel() + Paths::excelResults() << endl;
} else {
ReportConsole reporter(data, compare);
reporter.show();
}
}
void ManageResults::showIndex(const int index, const int idx)
{
// Show a dataset result inside a report
auto data = results.at(index).load();
cout << Colors::YELLOW() << "Showing " << results.at(index).getFilename() << endl;
ReportConsole reporter(data, compare, idx);
reporter.show();
}
void ManageResults::sortList()
{
cout << Colors::YELLOW() << "Choose sorting field (date='d', score='s', duration='u', model='m'): ";
string line;
char option;
getline(cin, line);
if (line.size() == 0)
return;
if (line.size() > 1) {
cout << "Invalid option" << endl;
return;
}
option = line[0];
switch (option) {
case 'd':
results.sortDate();
break;
case 's':
results.sortScore();
break;
case 'u':
results.sortDuration();
break;
case 'm':
results.sortModel();
break;
default:
cout << "Invalid option" << endl;
}
}
// Interactive loop. While the list of results is on screen (indexList) the
// main options are offered and the value typed selects a result; once a
// report is on screen the reduced options are offered and the value selects
// a dataset inside that report. The loop ends with 'q'.
void ManageResults::menu()
{
    char option;
    // Initialised so they never hold an indeterminate value
    int index = 0, subIndex = 0;
    bool finished = false;
    string filename;
    // tuple<Option, digit, requires value>
    vector<tuple<string, char, bool>> mainOptions = {
        {"quit", 'q', false},
        {"list", 'l', false},
        {"delete", 'd', true},
        {"hide", 'h', true},
        {"sort", 's', false},
        {"report", 'r', true},
        {"excel", 'e', true}
    };
    vector<tuple<string, char, bool>> listOptions = {
        {"report", 'r', true},
        {"list", 'l', false},
        {"quit", 'q', false}
    };
    auto parser = CommandParser();
    while (!finished) {
        if (indexList) {
            tie(option, index) = parser.parse(Colors::GREEN(), mainOptions, 'r', numFiles - 1);
        } else {
            tie(option, subIndex) = parser.parse(Colors::MAGENTA(), listOptions, 'r', results.at(index).load()["results"].size() - 1);
        }
        switch (option) {
            case 'q':
                finished = true;
                break;
            case 'l':
                list();
                indexList = true;
                break;
            case 'd':
                filename = results.at(index).getFilename();
                if (!confirmAction("delete", filename))
                    break;
                cout << "Deleting " << filename << endl;
                results.deleteResult(index);
                cout << "File: " + filename + " deleted!" << endl;
                // One result less: keep the number of rows on screen (and the
                // highest selectable index) within what is still stored
                if (numFiles > results.size()) {
                    numFiles = results.size();
                }
                list();
                break;
            case 'h':
                filename = results.at(index).getFilename();
                if (!confirmAction("hide", filename))
                    break;
                cout << "Hiding " << filename << endl;
                results.hideResult(index, Paths::hiddenResults());
                cout << "File: " + filename + " hidden! (moved to " << Paths::hiddenResults() << ")" << endl;
                // Same adjustment as after a deletion
                if (numFiles > results.size()) {
                    numFiles = results.size();
                }
                list();
                break;
            case 's':
                sortList();
                list();
                break;
            case 'r':
                if (indexList) {
                    // Open the report; from now on values refer to its datasets
                    report(index, false);
                    indexList = false;
                } else {
                    showIndex(index, subIndex);
                }
                break;
            case 'e':
                report(index, true);
                break;
        }
    }
}
} /* namespace platform */

View File

@@ -0,0 +1,31 @@
#ifndef MANAGE_RESULTS_H
#define MANAGE_RESULTS_H
#include "Results.h"
#include "xlsxwriter.h"
namespace platform {
// Interactive console manager for stored result files: lists them and lets
// the user sort, report (console or Excel), hide or delete them.
class ManageResults {
public:
// numFiles is the maximum number of results to handle (0 = all of them);
// model, score, complete and partial are handed to Results; compare is
// forwarded to the reporters.
ManageResults(int numFiles, const string& model, const string& score, bool complete, bool partial, bool compare);
~ManageResults() = default;
// Lists the results and runs the menu loop; closes the workbook if one was opened
void doMenu();
private:
// Prints the table of results (at most numFiles rows)
void list();
// Asks for a y/n confirmation of intent over fileName; true when confirmed
bool confirmAction(const string& intent, const string& fileName) const;
// Reports the result at index, on console or as a sheet of the Excel workbook
void report(const int index, const bool excelReport);
// Shows dataset idx of the result at index
void showIndex(const int index, const int idx);
// Asks for a sorting field and sorts the results
void sortList();
// Command loop
void menu();
int numFiles; // number of results listed / selectable
bool indexList; // true while the typed value refers to the list of results, false while a report is on screen
bool openExcel; // set once an Excel report has been generated, so the workbook gets closed on exit
bool complete; // announces "only complete results" in the listing
bool partial; // announces "only partial results" in the listing
bool compare; // forwarded to ReportConsole / ReportExcel
Results results; // the stored results being managed
lxw_workbook* workbook; // workbook shared by all the Excel reports of the session (NULL until the first one)
};
}
#endif /* MANAGE_RESULTS_H */

View File

@@ -6,6 +6,7 @@ namespace platform {
class Paths { class Paths {
public: public:
static std::string results() { return "results/"; } static std::string results() { return "results/"; }
// Folder that result files are moved to when they are hidden
static std::string hiddenResults() { return "hidden_results/"; }
static std::string excel() { return "excel/"; } static std::string excel() { return "excel/"; }
static std::string cfs() { return "cfs/"; } static std::string cfs() { return "cfs/"; }
static std::string datasets() static std::string datasets()
@@ -13,6 +14,7 @@ namespace platform {
auto env = platform::DotEnv(); auto env = platform::DotEnv();
return env.get("source_data"); return env.get("source_data");
} }
// File name of the Excel workbook with the reports (used together with excel() as its folder)
static std::string excelResults() { return "some_results.xlsx"; }
}; };
} }
#endif #endif

View File

@@ -2,7 +2,6 @@
#include <locale> #include <locale>
#include "Datasets.h" #include "Datasets.h"
#include "ReportBase.h" #include "ReportBase.h"
#include "BestScore.h"
#include "DotEnv.h" #include "DotEnv.h"
namespace platform { namespace platform {

View File

@@ -1,3 +1,4 @@
#include <iostream>
#include <sstream> #include <sstream>
#include <locale> #include <locale>
#include "ReportConsole.h" #include "ReportConsole.h"
@@ -28,8 +29,15 @@ namespace platform {
void ReportConsole::body() void ReportConsole::body()
{ {
auto tmp = ConfigLocale(); auto tmp = ConfigLocale();
cout << Colors::GREEN() << " # Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; int maxHyper = 15;
cout << "=== ========================= ====== ===== === ========= ========= ========= =============== =================== ====================" << endl; int maxDataset = 7;
for (const auto& r : data["results"]) {
maxHyper = max(maxHyper, (int)r["hyperparameters"].dump().size());
maxDataset = max(maxDataset, (int)r["dataset"].get<string>().size());
}
cout << Colors::GREEN() << " # " << setw(maxDataset) << left << "Dataset" << " Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl;
cout << "=== " << string(maxDataset, '=') << " ====== ===== === ========= ========= ========= =============== =================== " << string(maxHyper, '=') << endl;
json lastResult; json lastResult;
double totalScore = 0.0; double totalScore = 0.0;
bool odd = true; bool odd = true;
@@ -41,8 +49,8 @@ namespace platform {
} }
auto color = odd ? Colors::CYAN() : Colors::BLUE(); auto color = odd ? Colors::CYAN() : Colors::BLUE();
cout << color; cout << color;
cout << setw(3) << index++ << " "; cout << setw(3) << right << index++ << " ";
cout << setw(25) << left << r["dataset"].get<string>() << " "; cout << setw(maxDataset) << left << r["dataset"].get<string>() << " ";
cout << setw(6) << right << r["samples"].get<int>() << " "; cout << setw(6) << right << r["samples"].get<int>() << " ";
cout << setw(5) << right << r["features"].get<int>() << " "; cout << setw(5) << right << r["features"].get<int>() << " ";
cout << setw(3) << right << r["classes"].get<int>() << " "; cout << setw(3) << right << r["classes"].get<int>() << " ";
@@ -87,9 +95,10 @@ namespace platform {
cout << Colors::MAGENTA() << string(MAXL, '*') << endl; cout << Colors::MAGENTA() << string(MAXL, '*') << endl;
showSummary(); showSummary();
auto score = data["score_name"].get<string>(); auto score = data["score_name"].get<string>();
if (score == BestScore::scoreName()) { auto best = BestScore::getScore(score);
if (best.first != "") {
stringstream oss; stringstream oss;
oss << score << " compared to " << BestScore::title() << " .: " << totalScore / BestScore::score(); oss << score << " compared to " << best.first << " .: " << totalScore / best.second;
cout << headerLine(oss.str()); cout << headerLine(oss.str());
} }
if (!getExistBestFile() && compare) { if (!getExistBestFile() && compare) {

View File

@@ -1,7 +1,6 @@
#ifndef REPORTCONSOLE_H #ifndef REPORTCONSOLE_H
#define REPORTCONSOLE_H #define REPORTCONSOLE_H
#include <string> #include <string>
#include <iostream>
#include "ReportBase.h" #include "ReportBase.h"
#include "Colors.h" #include "Colors.h"

View File

@@ -6,7 +6,7 @@
namespace platform { namespace platform {
ReportExcel::ReportExcel(json data_, bool compare, lxw_workbook* workbook) : ReportBase(data_, compare), ExcelFile(workbook) ReportExcel::ReportExcel(json data_, bool compare, lxw_workbook* workbook, lxw_worksheet* worksheet) : ReportBase(data_, compare), ExcelFile(workbook, worksheet)
{ {
createFile(); createFile();
} }
@@ -19,12 +19,8 @@ namespace platform {
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL); worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
} }
} }
void ReportExcel::createWorksheet()
void ReportExcel::createFile()
{ {
if (workbook == NULL) {
workbook = workbook_new((Paths::excel() + fileName).c_str());
}
const string name = data["model"].get<string>(); const string name = data["model"].get<string>();
string suffix = ""; string suffix = "";
string efectiveName; string efectiveName;
@@ -42,7 +38,16 @@ namespace platform {
throw invalid_argument("Couldn't create sheet " + efectiveName); throw invalid_argument("Couldn't create sheet " + efectiveName);
} }
} }
cout << "Adding sheet " << efectiveName << " to " << Paths::excel() + fileName << endl; }
void ReportExcel::createFile()
{
if (workbook == NULL) {
workbook = workbook_new((Paths::excel() + Paths::excelResults()).c_str());
}
if (worksheet == NULL) {
createWorksheet();
}
setProperties(data["title"].get<string>()); setProperties(data["title"].get<string>());
createFormats(); createFormats();
formatColumns(); formatColumns();
@@ -115,14 +120,7 @@ namespace platform {
writeString(row, col + 9, status, "textCentered"); writeString(row, col + 9, status, "textCentered");
writeDouble(row, col + 10, r["time"].get<double>(), "time"); writeDouble(row, col + 10, r["time"].get<double>(), "time");
writeDouble(row, col + 11, r["time_std"].get<double>(), "time"); writeDouble(row, col + 11, r["time_std"].get<double>(), "time");
try { hyperparameters = r["hyperparameters"].dump();
hyperparameters = r["hyperparameters"].get<string>();
}
catch (const exception& err) {
stringstream oss;
oss << r["hyperparameters"];
hyperparameters = oss.str();
}
if (hyperparameters.size() > hypSize) { if (hyperparameters.size() > hypSize) {
hypSize = hyperparameters.size(); hypSize = hyperparameters.size();
} }
@@ -130,7 +128,6 @@ namespace platform {
lastResult = r; lastResult = r;
totalScore += r["score"].get<double>(); totalScore += r["score"].get<double>();
row++; row++;
} }
// Set the right column width of hyperparameters with the maximum length // Set the right column width of hyperparameters with the maximum length
worksheet_set_column(worksheet, 12, 12, hypSize + 5, NULL); worksheet_set_column(worksheet, 12, 12, hypSize + 5, NULL);
@@ -171,9 +168,10 @@ namespace platform {
showSummary(); showSummary();
row += 4 + summary.size(); row += 4 + summary.size();
auto score = data["score_name"].get<string>(); auto score = data["score_name"].get<string>();
if (score == BestScore::scoreName()) { auto best = BestScore::getScore(score);
worksheet_merge_range(worksheet, row, 1, row, 5, (score + " compared to " + BestScore::title() + " .:").c_str(), efectiveStyle("text")); if (best.first != "") {
writeDouble(row, 6, totalScore / BestScore::score(), "result"); worksheet_merge_range(worksheet, row, 1, row, 5, (score + " compared to " + best.first + " .:").c_str(), efectiveStyle("text"));
writeDouble(row, 6, totalScore / best.second, "result");
} }
if (!getExistBestFile() && compare) { if (!getExistBestFile() && compare) {
worksheet_write_string(worksheet, row + 1, 0, "*** Best Results File not found. Couldn't compare any result!", styles["summaryStyle"]); worksheet_write_string(worksheet, row + 1, 0, "*** Best Results File not found. Couldn't compare any result!", styles["summaryStyle"]);

View File

@@ -9,11 +9,11 @@ namespace platform {
using namespace std; using namespace std;
class ReportExcel : public ReportBase, public ExcelFile { class ReportExcel : public ReportBase, public ExcelFile {
public: public:
explicit ReportExcel(json data_, bool compare, lxw_workbook* workbook); explicit ReportExcel(json data_, bool compare, lxw_workbook* workbook, lxw_worksheet* worksheet = NULL);
private: private:
const string fileName = "some_results.xlsx";
void formatColumns(); void formatColumns();
void createFile(); void createFile();
void createWorksheet();
void closeFile(); void closeFile();
void header() override; void header() override;
void body() override; void body() override;

View File

@@ -1,9 +1,10 @@
#include "Result.h"
#include "BestScore.h"
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include "Result.h"
#include "Colors.h" #include "Colors.h"
#include "BestScore.h" #include "DotEnv.h"
#include "CLocale.h" #include "CLocale.h"
namespace platform { namespace platform {
@@ -18,8 +19,9 @@ namespace platform {
score += result["score"].get<double>(); score += result["score"].get<double>();
} }
scoreName = data["score_name"]; scoreName = data["score_name"];
if (scoreName == BestScore::scoreName()) { auto best = BestScore::getScore(scoreName);
score /= BestScore::score(); if (best.first != "") {
score /= best.second;
} }
title = data["title"]; title = data["title"];
duration = data["duration"]; duration = data["duration"];
@@ -37,14 +39,14 @@ namespace platform {
throw invalid_argument("Unable to open result file. [" + path + "/" + filename + "]"); throw invalid_argument("Unable to open result file. [" + path + "/" + filename + "]");
} }
string Result::to_string() const string Result::to_string(int maxModel) const
{ {
auto tmp = ConfigLocale(); auto tmp = ConfigLocale();
stringstream oss; stringstream oss;
double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration; double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration;
string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s"; string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s";
oss << date << " "; oss << date << " ";
oss << setw(12) << left << model << " "; oss << setw(maxModel) << left << model << " ";
oss << setw(11) << left << scoreName << " "; oss << setw(11) << left << scoreName << " ";
oss << right << setw(11) << setprecision(7) << fixed << score << " "; oss << right << setw(11) << setprecision(7) << fixed << score << " ";
auto completeString = isComplete() ? "C" : "P"; auto completeString = isComplete() ? "C" : "P";

View File

@@ -12,7 +12,7 @@ namespace platform {
public: public:
Result(const string& path, const string& filename); Result(const string& path, const string& filename);
json load() const; json load() const;
string to_string() const; string to_string(int maxModel) const;
string getFilename() const { return filename; }; string getFilename() const { return filename; };
string getDate() const { return date; }; string getDate() const { return date; };
double getScore() const { return score; }; double getScore() const { return score; };

View File

@@ -1,11 +1,17 @@
#include <filesystem>
#include "Results.h" #include "Results.h"
#include "ReportConsole.h" #include <algorithm>
#include "ReportExcel.h"
#include "BestScore.h"
#include "Colors.h"
#include "CLocale.h"
namespace platform { namespace platform {
// Stores the selection criteria, loads the list of results and records the
// length of the longest model name among them (0 when nothing was loaded).
Results::Results(const string& path, const string& model, const string& score, bool complete, bool partial) :
    path(path), model(model), scoreName(score), complete(complete), partial(partial)
{
    load();
    maxModel = 0;
    for (const Result& file : files) {
        const int length = file.getModel().size();
        if (length > maxModel) {
            maxModel = length;
        }
    }
}
void Results::load() void Results::load()
{ {
using std::filesystem::directory_iterator; using std::filesystem::directory_iterator;
@@ -20,212 +26,22 @@ namespace platform {
files.push_back(result); files.push_back(result);
} }
} }
if (max == 0) {
max = files.size();
} }
} void Results::hideResult(int index, const string& pathHidden)
void Results::show() const
{ {
auto temp = ConfigLocale(); auto filename = files.at(index).getFilename();
cout << Colors::GREEN() << "Results found: " << files.size() << endl; rename((path + "/" + filename).c_str(), (pathHidden + "/" + filename).c_str());
cout << "-------------------" << endl; files.erase(files.begin() + index);
if (complete) {
cout << Colors::MAGENTA() << "Only listing complete results" << endl;
} }
if (partial) { void Results::deleteResult(int index)
cout << Colors::MAGENTA() << "Only listing partial results" << endl;
}
auto i = 0;
cout << Colors::GREEN() << " # Date Model Score Name Score C/P Duration Title" << endl;
cout << "=== ========== ============ =========== =========== === ========= =============================================================" << endl;
bool odd = true;
for (const auto& result : files) {
auto color = odd ? Colors::BLUE() : Colors::CYAN();
cout << color << setw(3) << fixed << right << i++ << " ";
cout << result.to_string() << endl;
if (i == max && max != 0) {
break;
}
odd = !odd;
}
}
int Results::getIndex(const string& intent) const
{ {
string color; auto filename = files.at(index).getFilename();
if (intent == "delete") {
color = Colors::RED();
} else {
color = Colors::YELLOW();
}
cout << color << "Choose result to " << intent << " (cancel=-1): ";
string line;
getline(cin, line);
int index = stoi(line);
if (index >= -1 && index < static_cast<int>(files.size())) {
return index;
}
cout << "Invalid index" << endl;
return -1;
}
void Results::report(const int index, const bool excelReport)
{
cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl;
auto data = files.at(index).load();
if (excelReport) {
ReportExcel reporter(data, compare, workbook);
reporter.show();
openExcel = true;
workbook = reporter.getWorkbook();
} else {
ReportConsole reporter(data, compare);
reporter.show();
}
}
void Results::showIndex(const int index, const int idx) const
{
auto data = files.at(index).load();
if (idx < 0 or idx >= static_cast<int>(data["results"].size())) {
cout << "Invalid index" << endl;
return;
}
cout << Colors::YELLOW() << "Showing " << files.at(index).getFilename() << endl;
ReportConsole reporter(data, compare, idx);
reporter.show();
}
void Results::menu()
{
char option;
int index;
bool finished = false;
string color, context;
string filename, line, options = "qldhsre";
while (!finished) {
if (indexList) {
color = Colors::GREEN();
context = " (quit='q', list='l', delete='d', hide='h', sort='s', report='r', excel='e'): ";
options = "qldhsre";
} else {
color = Colors::MAGENTA();
context = " (quit='q', list='l'): ";
options = "ql";
}
cout << Colors::RESET() << color;
cout << "Choose option " << context;
getline(cin, line);
if (line.size() == 0)
continue;
if (options.find(line[0]) != string::npos) {
if (line.size() > 1) {
cout << "Invalid option" << endl;
continue;
}
option = line[0];
} else {
if (all_of(line.begin(), line.end(), ::isdigit)) {
int idx = stoi(line);
if (indexList) {
// The value is about the files list
index = idx;
if (index >= 0 && index < max) {
report(index, false);
indexList = false;
continue;
}
} else {
// The value is about the result showed on screen
showIndex(index, idx);
continue;
}
}
cout << "Invalid option" << endl;
continue;
}
switch (option) {
case 'q':
finished = true;
break;
case 'l':
show();
indexList = true;
break;
case 'd':
index = getIndex("delete");
if (index == -1)
break;
filename = files[index].getFilename();
cout << "Deleting " << filename << endl;
remove((path + "/" + filename).c_str()); remove((path + "/" + filename).c_str());
files.erase(files.begin() + index); files.erase(files.begin() + index);
cout << "File: " + filename + " deleted!" << endl;
show();
indexList = true;
break;
case 'h':
index = getIndex("hide");
if (index == -1)
break;
filename = files[index].getFilename();
cout << "Hiding " << filename << endl;
rename((path + "/" + filename).c_str(), (path + "/." + filename).c_str());
files.erase(files.begin() + index);
show();
menu();
indexList = true;
break;
case 's':
sortList();
indexList = true;
show();
break;
case 'r':
index = getIndex("report");
if (index == -1)
break;
indexList = false;
report(index, false);
break;
case 'e':
index = getIndex("excel");
if (index == -1)
break;
indexList = true;
report(index, true);
break;
default:
cout << "Invalid option" << endl;
} }
} int Results::size() const
}
void Results::sortList()
{ {
cout << Colors::YELLOW() << "Choose sorting field (date='d', score='s', duration='u', model='m'): "; return files.size();
string line;
char option;
getline(cin, line);
if (line.size() == 0)
return;
if (line.size() > 1) {
cout << "Invalid option" << endl;
return;
}
option = line[0];
switch (option) {
case 'd':
sortDate();
break;
case 's':
sortScore();
break;
case 'u':
sortDuration();
break;
case 'm':
sortModel();
break;
default:
cout << "Invalid option" << endl;
}
} }
void Results::sortDate() void Results::sortDate()
{ {
@@ -251,19 +67,8 @@ namespace platform {
return a.getScore() > b.getScore(); return a.getScore() > b.getScore();
}); });
} }
void Results::manage() bool Results::empty() const
{ {
if (files.size() == 0) { return files.empty();
cout << "No results found!" << endl;
exit(0);
} }
sortDate();
show();
menu();
if (openExcel) {
workbook_close(workbook);
}
cout << Colors::RESET() << "Done!" << endl;
}
} }

View File

@@ -1,6 +1,5 @@
#ifndef RESULTS_H #ifndef RESULTS_H
#define RESULTS_H #define RESULTS_H
#include "xlsxwriter.h"
#include <map> #include <map>
#include <vector> #include <vector>
#include <string> #include <string>
@@ -12,35 +11,28 @@ namespace platform {
class Results { class Results {
public: public:
Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial, bool compare) : Results(const string& path, const string& model, const string& score, bool complete, bool partial);
path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial), compare(compare)
{
load();
};
void manage();
private:
string path;
int max;
string model;
string scoreName;
bool complete;
bool partial;
bool indexList = true;
bool openExcel = false;
bool compare;
lxw_workbook* workbook = NULL;
vector<Result> files;
void load(); // Loads the list of results
void show() const;
void report(const int index, const bool excelReport);
void showIndex(const int index, const int idx) const;
int getIndex(const string& intent) const;
void menu();
void sortList();
void sortDate(); void sortDate();
void sortScore(); void sortScore();
void sortModel(); void sortModel();
void sortDuration(); void sortDuration();
int maxModelSize() const { return maxModel; };
void hideResult(int index, const string& pathHidden);
void deleteResult(int index);
int size() const;
bool empty() const;
vector<Result>::iterator begin() { return files.begin(); };
vector<Result>::iterator end() { return files.end(); };
Result& at(int index) { return files.at(index); };
private:
string path;
string model;
string scoreName;
bool complete;
bool partial;
int maxModel;
vector<Result> files;
void load(); // Loads the list of results
}; };
}; };

View File

@@ -15,5 +15,16 @@ namespace platform {
} }
return result; return result;
} }
// Returns a copy of str without its leading and trailing whitespace
// (whitespace as classified by std::isspace). A string made only of
// whitespace, or an empty one, yields an empty string.
static std::string trim(const std::string& str)
{
    // std::isspace is only defined for values representable as unsigned char
    // (or EOF), so the characters are received as unsigned char: a plain char
    // holding a non-ASCII byte may be negative.
    const auto notSpace = [](unsigned char ch) {
        return !std::isspace(ch);
        };
    const auto first = std::find_if(str.begin(), str.end(), notSpace);
    // Search backwards but never past the first kept character
    const auto last = std::find_if(str.rbegin(), std::string::const_reverse_iterator(first), notSpace).base();
    return std::string(first, last);
}
} }
#endif #endif

View File

@@ -29,15 +29,24 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
catch (...) { catch (...) {
throw runtime_error("Number of folds must be an decimal number"); throw runtime_error("Number of folds must be an decimal number");
}}); }});
return program;
}
int main(int argc, char** argv)
{
auto program = manageArguments(argc, argv);
string model, score;
bool build, report, friedman, excel;
double level;
try { try {
program.parse_args(argc, argv); program.parse_args(argc, argv);
auto model = program.get<string>("model"); model = program.get<string>("model");
auto score = program.get<string>("score"); score = program.get<string>("score");
auto build = program.get<bool>("build"); build = program.get<bool>("build");
auto report = program.get<bool>("report"); report = program.get<bool>("report");
auto friedman = program.get<bool>("friedman"); friedman = program.get<bool>("friedman");
auto excel = program.get<bool>("excel"); excel = program.get<bool>("excel");
auto level = program.get<double>("level"); level = program.get<double>("level");
if (model == "" || score == "") { if (model == "" || score == "") {
throw runtime_error("Model and score name must be supplied"); throw runtime_error("Model and score name must be supplied");
} }
@@ -46,11 +55,6 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
cerr << program; cerr << program;
exit(1); exit(1);
} }
if (excel && model != "any") {
cerr << "Excel ourput can only be used with all models" << endl;
cerr << program;
exit(1);
}
if (!report && !build) { if (!report && !build) {
cerr << "Either build, report or both, have to be selected to do anything!" << endl; cerr << "Either build, report or both, have to be selected to do anything!" << endl;
cerr << program; cerr << program;
@@ -62,19 +66,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
cerr << program; cerr << program;
exit(1); exit(1);
} }
return program; // Generate report
}
int main(int argc, char** argv)
{
auto program = manageArguments(argc, argv);
auto model = program.get<string>("model");
auto score = program.get<string>("score");
auto build = program.get<bool>("build");
auto report = program.get<bool>("report");
auto friedman = program.get<bool>("friedman");
auto excel = program.get<bool>("excel");
auto level = program.get<double>("level");
auto results = platform::BestResults(platform::Paths::results(), score, model, friedman, level); auto results = platform::BestResults(platform::Paths::results(), score, model, friedman, level);
if (build) { if (build) {
if (model == "any") { if (model == "any") {
@@ -88,7 +80,7 @@ int main(int argc, char** argv)
if (model == "any") { if (model == "any") {
results.reportAll(excel); results.reportAll(excel);
} else { } else {
results.reportSingle(); results.reportSingle(excel);
} }
} }
return 0; return 0;

View File

@@ -30,6 +30,7 @@ argparse::ArgumentParser manageArguments()
); );
program.add_argument("--title").default_value("").help("Experiment title"); program.add_argument("--title").default_value("").help("Experiment title");
program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true); program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const string& value) { program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const string& value) {
@@ -55,7 +56,7 @@ int main(int argc, char** argv)
{ {
string file_name, model_name, title; string file_name, model_name, title;
json hyperparameters_json; json hyperparameters_json;
bool discretize_dataset, stratified, saveResults; bool discretize_dataset, stratified, saveResults, quiet;
vector<int> seeds; vector<int> seeds;
vector<string> filesToTest; vector<string> filesToTest;
int n_folds; int n_folds;
@@ -66,6 +67,7 @@ int main(int argc, char** argv)
model_name = program.get<string>("model"); model_name = program.get<string>("model");
discretize_dataset = program.get<bool>("discretize"); discretize_dataset = program.get<bool>("discretize");
stratified = program.get<bool>("stratified"); stratified = program.get<bool>("stratified");
quiet = program.get<bool>("quiet");
n_folds = program.get<int>("folds"); n_folds = program.get<int>("folds");
seeds = program.get<vector<int>>("seeds"); seeds = program.get<vector<int>>("seeds");
auto hyperparameters = program.get<string>("hyperparameters"); auto hyperparameters = program.get<string>("hyperparameters");
@@ -109,11 +111,12 @@ int main(int argc, char** argv)
} }
platform::Timer timer; platform::Timer timer;
timer.start(); timer.start();
experiment.go(filesToTest); experiment.go(filesToTest, quiet);
experiment.setDuration(timer.getDuration()); experiment.setDuration(timer.getDuration());
if (saveResults) { if (saveResults) {
experiment.save(platform::Paths::results()); experiment.save(platform::Paths::results());
} }
if (!quiet)
experiment.report(); experiment.report();
cout << "Done!" << endl; cout << "Done!" << endl;
return 0; return 0;

View File

@@ -1,7 +1,6 @@
#include <iostream> #include <iostream>
#include <argparse/argparse.hpp> #include <argparse/argparse.hpp>
#include "Paths.h" #include "ManageResults.h"
#include "Results.h"
using namespace std; using namespace std;
@@ -37,15 +36,15 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
auto program = manageArguments(argc, argv); auto program = manageArguments(argc, argv);
auto number = program.get<int>("number"); int number = program.get<int>("number");
auto model = program.get<string>("model"); string model = program.get<string>("model");
auto score = program.get<string>("score"); string score = program.get<string>("score");
auto complete = program.get<bool>("complete"); auto complete = program.get<bool>("complete");
auto partial = program.get<bool>("partial"); auto partial = program.get<bool>("partial");
auto compare = program.get<bool>("compare"); auto compare = program.get<bool>("compare");
if (complete) if (complete)
partial = false; partial = false;
auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial, compare); auto manager = platform::ManageResults(number, model, score, complete, partial, compare);
results.manage(); manager.doMenu();
return 0; return 0;
} }

View File

@@ -1,248 +0,0 @@
#include "Folding.h"
#include <torch/torch.h>
#include "nlohmann/json.hpp"
#include "map"
#include <iostream>
#include <sstream>
#include "Datasets.h"
#include "Network.h"
#include "ArffFiles.h"
#include "CPPFImdlp.h"
#include "CFS.h"
#include "IWSS.h"
#include "FCBF.h"
using namespace std;
using namespace platform;
using namespace torch;
// Summarise the class distribution of a subset of samples on a single line.
//
// y        class label of every sample
// indices  positions inside y that make up the subset (e.g. one fold)
//
// Returns one "<label> -> <pct>% (<count>) //" entry per label that actually
// occurs in the subset, in ascending label order, followed by a line break.
// Percentages are relative to indices.size() and printed with two decimals.
string counts(vector<int> y, vector<int> indices)
{
    map<int, int> frequency;
    for (const int index : indices) {
        ++frequency[y[index]];
    }
    stringstream oss;
    oss << fixed << setprecision(2);
    // Iterate over the labels that were really seen. Indexing the map with
    // 0..size()-1 would insert phantom labels through operator[] (growing the
    // map while it is being traversed) and would never reach negative labels.
    for (const auto& [label, count] : frequency) {
        oss << label << " -> " << static_cast<double>(count) * 100 / indices.size() << "% (" << count << ") //";
    }
    oss << endl;
    return oss.str();
}
// Directory names used by the loaders in this file.
class Paths {
public:
    // Relative folder (trailing slash included) that is prepended to the
    // dataset file names.
    static string datasets()
    {
        const string folder{ "datasets/" };
        return folder;
    }
};
// Discretize every entry of X against the labels y with a CPPFImdlp instance.
//
// X         samples to discretize; X[i] is the data named by features[i]
// y         class labels handed to the discretizer on every fit
// features  names used as keys of the returned map
//           (assumes features has at least X.size() entries - TODO confirm)
//
// Returns the discretized entries, in the same order as X, together with a
// map from feature name to (largest discretized value + 1).
pair<vector<mdlp::labels_t>, map<string, int>> discretize(vector<mdlp::samples_t>& X, mdlp::labels_t& y, vector<string> features)
{
    auto discretizer = mdlp::CPPFImdlp();
    vector<mdlp::labels_t> discretized;
    map<string, int> numStates;
    for (size_t column = 0; column < X.size(); ++column) {
        discretizer.fit(X[column], y);
        mdlp::labels_t& labels = discretizer.transform(X[column]);
        const auto largest = *max_element(labels.begin(), labels.end());
        numStates[features[column]] = largest + 1;
        discretized.push_back(labels);
    }
    return { discretized, numStates };
}
// Discretize every entry of X against the labels y with a CPPFImdlp instance
// and return the discretized entries in the same order as X.
vector<mdlp::labels_t> discretizeDataset(vector<mdlp::samples_t>& X, mdlp::labels_t& y)
{
    auto discretizer = mdlp::CPPFImdlp();
    vector<mdlp::labels_t> discretized;
    for (auto& column : X) {
        discretizer.fit(column, y);
        // transform() hands back a reference; push_back stores a copy of it
        discretized.push_back(discretizer.transform(column));
    }
    return discretized;
}
// Tell whether `name` can be opened for reading.
// Returns true when fopen() succeeds (the handle is closed again right away)
// and false otherwise.
bool file_exists(const string& name)
{
    FILE* handle = fopen(name.c_str(), "r");
    if (handle == nullptr) {
        return false;
    }
    fclose(handle);
    return true;
}
// Load "<Paths::datasets()><name>.arff" and return it as torch tensors.
//
// name                dataset name, without folder or extension
// class_last          forwarded unchanged to ArffFiles::load
// discretize_dataset  when true the features are discretized with
//                     discretizeDataset() and the states map is filled in
//
// Returns { X, y, features, className, states } where
//   X         has one row per feature: kInt32 when discretized, kFloat32 otherwise
//   y         the labels as a kInt32 tensor
//   features  the attribute names reported by the ARFF handler
//   className the class attribute name
//   states    for the discretized case, feature/class name -> {0, 1, ..., max};
//             left empty when discretize_dataset is false
//
// Assumes the file holds at least one attribute (X[0] is read unconditionally).
tuple<Tensor, Tensor, vector<string>, string, map<string, vector<int>>> loadDataset(const string& name, bool class_last, bool discretize_dataset)
{
    auto handler = ArffFiles();
    handler.load(Paths::datasets() + name + ".arff", class_last);
    // Get Dataset X, y
    vector<mdlp::samples_t>& X = handler.getX();
    mdlp::labels_t& y = handler.getY();
    // Get className & Features
    auto className = handler.getClassName();
    vector<string> features;
    auto attributes = handler.getAttributes();
    transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& pair) { return pair.first; });
    Tensor Xd;
    auto states = map<string, vector<int>>();
    if (discretize_dataset) {
        auto Xr = discretizeDataset(X, y);
        Xd = torch::zeros({ static_cast<int>(Xr.size()), static_cast<int>(Xr[0].size()) }, torch::kInt32);
        for (int i = 0; i < static_cast<int>(features.size()); ++i) {
            // Bind by reference: filling a copy would leave the stored
            // states of every feature as all zeros.
            auto& item = states[features[i]];
            item = vector<int>(*max_element(Xr[i].begin(), Xr[i].end()) + 1);
            iota(begin(item), end(item), 0);
            Xd.index_put_({ i, "..." }, torch::tensor(Xr[i], torch::kInt32));
        }
        auto& classStates = states[className];
        classStates = vector<int>(*max_element(y.begin(), y.end()) + 1);
        iota(begin(classStates), end(classStates), 0);
    } else {
        Xd = torch::zeros({ static_cast<int>(X.size()), static_cast<int>(X[0].size()) }, torch::kFloat32);
        for (int i = 0; i < static_cast<int>(features.size()); ++i) {
            Xd.index_put_({ i, "..." }, torch::tensor(X[i]));
        }
    }
    return { Xd, torch::tensor(y, torch::kInt32), features, className, states };
}
// Load "<Paths::datasets()><name>.arff" and return it, always discretized, as
// plain vectors.
//
// Returns { X, y, features, className, states } where X is the output of
// discretize(), and states maps every feature name and the class name to a
// zero-filled vector whose length is that variable's (largest value + 1).
tuple<vector<vector<int>>, vector<int>, vector<string>, string, map<string, vector<int>>> loadFile(const string& name)
{
    auto arff = ArffFiles();
    arff.load(Paths::datasets() + name + ".arff");
    // Raw samples and labels as parsed by the handler
    vector<mdlp::samples_t>& samples = arff.getX();
    mdlp::labels_t& y = arff.getY();
    auto className = arff.getClassName();
    // Keep only the attribute names, in the order the handler reports them
    auto attributes = arff.getAttributes();
    vector<string> features;
    for (const auto& attribute : attributes) {
        features.push_back(attribute.first);
    }
    // Discretize and record how many distinct values each variable can take
    auto [Xd, maxes] = discretize(samples, y, features);
    maxes[className] = *max_element(y.begin(), y.end()) + 1;
    map<string, vector<int>> states;
    for (const auto& feature : features) {
        states[feature] = vector<int>(maxes[feature]);
    }
    states[className] = vector<int>(maxes[className]);
    return { Xd, y, features, className, states };
}
// Holds one dataset twice: as torch tensors (members ending in "t", loaded
// with loadDataset) and as plain vectors (members ending in "v", loaded with
// loadFile), plus uniform sample weights in both representations.
class RawDatasets {
public:
    // file_name   dataset name handed to loadDataset() and loadFile()
    // discretize  whether the tensor representation is discretized
    RawDatasets(const string& file_name, bool discretize)
    {
        // Tensor flavour: discretized only on request
        tie(Xt, yt, featurest, classNamet, statest) = loadDataset(file_name, true, discretize);
        // Vector flavour: loadFile always discretizes
        tie(Xv, yv, featuresv, classNamev, statesv) = loadFile(file_name);
        // Reshape the labels into a single row and append it below the features
        auto labelColumn = yt.view({ yt.size(0), 1 });
        auto labelRow = torch::transpose(labelColumn, 0, 1);
        dataset = torch::cat({ Xt, labelRow }, 0);
        nSamples = dataset.size(1);
        // Every sample starts with the same weight
        const double uniform = 1.0 / nSamples;
        weights = torch::full({ nSamples }, uniform, torch::kDouble);
        weightsv = vector<double>(nSamples, uniform);
        // The class cardinality is only known when states were computed
        if (discretize) {
            classNumStates = statest.at(classNamet).size();
        } else {
            classNumStates = 0;
        }
    }
    torch::Tensor Xt, yt, dataset, weights;
    vector<vector<int>> Xv;
    vector<double> weightsv;
    vector<int> yv;
    vector<string> featurest, featuresv;
    map<string, vector<int>> statest, statesv;
    string classNamet, classNamev;
    int nSamples, classNumStates;
    double epsilon = 1e-5;
};
// Run the CFS, FCBF and IWSS feature selectors on every dataset listed by
// Datasets::getNames(), print how many features each one keeps, and store the
// selected features in features_cpp.json.
//
// Returns 0 on success and 1 when features_cpp.json cannot be opened or
// written.
int main()
{
    // map<string, string> balance = {
    //     {"iris", "33,33% (50) / 33,33% (50) / 33,33% (50)"},
    //     {"diabetes", "34,90% (268) / 65,10% (500)"},
    //     {"ecoli", "42,56% (143) / 22,92% (77) / 0,60% (2) / 0,60% (2) / 10,42% (35) / 5,95% (20) / 1,49% (5) / 15,48% (52)"},
    //     {"glass", "32,71% (70) / 7,94% (17) / 4,21% (9) / 35,51% (76) / 13,55% (29) / 6,07% (13)"}
    // };
    // for (const auto& file_name : { "iris", "glass", "ecoli", "diabetes" }) {
    //     auto dt = Datasets(true, "Arff");
    //     auto [X, y] = dt.getVectors(file_name);
    //     //auto fold = KFold(5, 150);
    //     auto fold = StratifiedKFold(5, y, -1);
    //     cout << "***********************************************************************************************" << endl;
    //     cout << "Dataset: " << file_name << endl;
    //     cout << "Nº Samples: " << dt.getNSamples(file_name) << endl;
    //     cout << "Class states: " << dt.getNClasses(file_name) << endl;
    //     cout << "Balance: " << balance.at(file_name) << endl;
    //     for (int i = 0; i < 5; ++i) {
    //         cout << "Fold: " << i << endl;
    //         auto [train, test] = fold.getFold(i);
    //         cout << "Train: ";
    //         cout << "(" << train.size() << "): ";
    //         // for (auto j = 0; j < static_cast<int>(train.size()); j++)
    //         //     cout << train[j] << ", ";
    //         cout << endl;
    //         cout << "Train Statistics : " << counts(y, train);
    //         cout << "-------------------------------------------------------------------------------" << endl;
    //         cout << "Test: ";
    //         cout << "(" << test.size() << "): ";
    //         // for (auto j = 0; j < static_cast<int>(test.size()); j++)
    //         //     cout << test[j] << ", ";
    //         cout << endl;
    //         cout << "Test Statistics: " << counts(y, test);
    //         cout << "==============================================================================" << endl;
    //     }
    //     cout << "***********************************************************************************************" << endl;
    // }
    // const string file_name = "iris";
    // auto net = bayesnet::Network();
    // auto dt = Datasets(true, "Arff");
    // auto raw = RawDatasets("iris", true);
    // auto [X, y] = dt.getVectors(file_name);
    // cout << "Dataset dims " << raw.dataset.sizes() << endl;
    // cout << "weights dims " << raw.weights.sizes() << endl;
    // cout << "States dims " << raw.statest.size() << endl;
    // cout << "features: ";
    // for (const auto& feature : raw.featurest) {
    //     cout << feature << ", ";
    //     net.addNode(feature);
    // }
    // net.addNode(raw.classNamet);
    // cout << endl;
    // net.fit(raw.dataset, raw.weights, raw.featurest, raw.classNamet, raw.statest);
    auto dt = Datasets(true, "Arff");
    nlohmann::json output;
    for (const auto& name : dt.getNames()) {
        // for (const auto& name : { "iris" }) {
        auto [X, y] = dt.getTensors(name);
        auto features = dt.getFeatures(name);
        auto states = dt.getStates(name);
        auto className = dt.getClassName(name);
        int maxFeatures = 0;
        auto classNumStates = states.at(className).size();
        // One uniform weight per entry along dimension 1 of X
        torch::Tensor weights = torch::full({ X.size(1) }, 1.0 / X.size(1), torch::kDouble);
        // Append the labels as an extra row below the features
        auto dataset = X;
        auto yresized = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
        dataset = torch::cat({ dataset, yresized }, 0);
        auto cfs = bayesnet::CFS(dataset, features, className, maxFeatures, classNumStates, weights);
        auto fcbf = bayesnet::FCBF(dataset, features, className, maxFeatures, classNumStates, weights, 1e-7);
        auto iwss = bayesnet::IWSS(dataset, features, className, maxFeatures, classNumStates, weights, 0.5);
        // Flush after each selector so progress is visible while the next one runs
        cout << "Dataset: " << setw(20) << name << flush;
        cfs.fit();
        cout << " CFS: " << setw(4) << cfs.getFeatures().size() << flush;
        fcbf.fit();
        cout << " FCBF: " << setw(4) << fcbf.getFeatures().size() << flush;
        iwss.fit();
        cout << " IWSS: " << setw(4) << iwss.getFeatures().size() << flush;
        cout << endl;
        output[name]["CFS"] = cfs.getFeatures();
        output[name]["FCBF"] = fcbf.getFeatures();
        output[name]["IWSS"] = iwss.getFeatures();
    }
    ofstream file("features_cpp.json");
    // Do not lose the results silently when the output file is unusable
    if (!file) {
        cerr << "Unable to open features_cpp.json for writing" << endl;
        return 1;
    }
    file << output;
    file.close();
    // close() sets failbit when the final flush fails
    if (!file) {
        cerr << "Error writing features_cpp.json" << endl;
        return 1;
    }
    return 0;
}