Compare commits
23 Commits
Boost_CFS
...
64f5a7f14a
Author | SHA1 | Date | |
---|---|---|---|
64f5a7f14a
|
|||
e03efb5f63
|
|||
92820555da
|
|||
5a3af51826
|
|||
a8f9800631
|
|||
84cec0c1e0
|
|||
130139f644
|
|||
651f84b562
|
|||
553ab0fa22
|
|||
4975feabff
|
|||
32293af69f
|
|||
858664be2d
|
|||
1f705f6018
|
|||
7bcd2eed06
|
|||
833acefbb3
|
|||
26b649ebae
|
|||
080eddf9cd
|
|||
04e754b2f5
|
|||
38423048bd
|
|||
64fc97b892
|
|||
2c2159f192
|
|||
6765552a7c
|
|||
f72aa5b9a6 |
@@ -1,5 +1,7 @@
|
|||||||
# BayesNet
|
# BayesNet
|
||||||
|
|
||||||
|
[](https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
Bayesian Network Classifier with libtorch from scratch
|
Bayesian Network Classifier with libtorch from scratch
|
||||||
|
|
||||||
## 0. Setup
|
## 0. Setup
|
||||||
|
Submodule lib/catch2 updated: 9c541ca72e...766541d12d
2
lib/json
2
lib/json
Submodule lib/json updated: 5d2754306d...edffad036d
33
mac_mst.txt
33
mac_mst.txt
@@ -1,33 +0,0 @@
|
|||||||
Weights matrix:
|
|
||||||
0.0000000, 0.0384968, 0.0795434, 0.1546867, -0.0000000, 0.1788104, 0.2214721, 0.0323837, 0.0366549,
|
|
||||||
0.0384968, 0.0000000, 0.0200662, 0.0200937, -0.0000000, 0.0637224, 0.0183005, 0.0127657, 0.0136054,
|
|
||||||
0.0795434, 0.0200662, 0.0000000, 0.0605489, -0.0000000, 0.0894469, 0.1689408, 0.0321602, 0.0223184,
|
|
||||||
0.1546867, 0.0200937, 0.0605489, 0.0000000, -0.0000000, 0.1150757, 0.1332292, 0.0422865, 0.0191138,
|
|
||||||
-0.0000000, -0.0000000, -0.0000000, -0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000,
|
|
||||||
0.1788104, 0.0637224, 0.0894469, 0.1150757, 0.0000000, 0.0000000, 0.1407102, 0.0406590, 0.0366986,
|
|
||||||
0.2214721, 0.0183005, 0.1689408, 0.1332292, 0.0000000, 0.1407102, 0.0000000, 0.0427515, 0.0349965,
|
|
||||||
0.0323837, 0.0127657, 0.0321602, 0.0422865, 0.0000000, 0.0406590, 0.0427515, 0.0000000, 0.0343376,
|
|
||||||
0.0366549, 0.0136054, 0.0223184, 0.0191138, 0.0000000, 0.0366986, 0.0349965, 0.0343376, 0.0000000,
|
|
||||||
Edge : Weight
|
|
||||||
0 - 6 : 0.2214721
|
|
||||||
0 - 5 : 0.1788104
|
|
||||||
2 - 6 : 0.1689408
|
|
||||||
0 - 3 : 0.1546867
|
|
||||||
1 - 5 : 0.0637224
|
|
||||||
6 - 7 : 0.0427515
|
|
||||||
5 - 8 : 0.0366986
|
|
||||||
4 - 5 : 0.0000000
|
|
||||||
-------------------------------------------------------------------------------
|
|
||||||
Metrics Test
|
|
||||||
Test Maximum Spanning Tree
|
|
||||||
-------------------------------------------------------------------------------
|
|
||||||
/Users/rmontanana/Code/BayesNet/tests/TestBayesMetrics.cc:58
|
|
||||||
...............................................................................
|
|
||||||
|
|
||||||
/Users/rmontanana/Code/BayesNet/tests/TestBayesMetrics.cc:69: PASSED:
|
|
||||||
REQUIRE( result == resultsMST.at(file_name) )
|
|
||||||
with expansion:
|
|
||||||
(0, 6) (0, 5) (0, 3) (5, 1) (5, 8) (5, 4) (6, 2) (6, 7)
|
|
||||||
==
|
|
||||||
(0, 6) (0, 5) (0, 3) (5, 1) (5, 8) (5, 4) (6, 2) (6, 7)
|
|
||||||
|
|
@@ -4,7 +4,7 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include <argparse/argparse.hpp>
|
#include <argparse/argparse.hpp>
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
#include "ArffFiles.h"v
|
#include "ArffFiles.h"
|
||||||
#include "BayesMetrics.h"
|
#include "BayesMetrics.h"
|
||||||
#include "CPPFImdlp.h"
|
#include "CPPFImdlp.h"
|
||||||
#include "Folding.h"
|
#include "Folding.h"
|
||||||
|
@@ -108,8 +108,10 @@ namespace bayesnet {
|
|||||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
unordered_set<int> featuresUsed;
|
unordered_set<int> featuresUsed;
|
||||||
|
int tolerance = 5; // number of times the accuracy can be lower than the threshold
|
||||||
if (selectFeatures) {
|
if (selectFeatures) {
|
||||||
featuresUsed = initializeModels();
|
featuresUsed = initializeModels();
|
||||||
|
tolerance = 0; // Remove tolerance if features are selected
|
||||||
}
|
}
|
||||||
if (maxModels == 0)
|
if (maxModels == 0)
|
||||||
maxModels = .1 * n > 10 ? .1 * n : n;
|
maxModels = .1 * n > 10 ? .1 * n : n;
|
||||||
@@ -119,13 +121,13 @@ namespace bayesnet {
|
|||||||
double priorAccuracy = 0.0;
|
double priorAccuracy = 0.0;
|
||||||
double delta = 1.0;
|
double delta = 1.0;
|
||||||
double threshold = 1e-4;
|
double threshold = 1e-4;
|
||||||
int tolerance = 5; // number of times the accuracy can be lower than the threshold
|
|
||||||
int count = 0; // number of times the accuracy is lower than the threshold
|
int count = 0; // number of times the accuracy is lower than the threshold
|
||||||
fitted = true; // to enable predict
|
fitted = true; // to enable predict
|
||||||
// Step 0: Set the finish condition
|
// Step 0: Set the finish condition
|
||||||
// if not repeatSparent a finish condition is run out of features
|
// if not repeatSparent a finish condition is run out of features
|
||||||
// n_models == maxModels
|
// n_models == maxModels
|
||||||
// epsiolon sub t > 0.5 => inverse the weights policy
|
// epsilon sub t > 0.5 => inverse the weights policy
|
||||||
// validation error is not decreasing
|
// validation error is not decreasing
|
||||||
while (!exitCondition) {
|
while (!exitCondition) {
|
||||||
// Step 1: Build ranking with mutual information
|
// Step 1: Build ranking with mutual information
|
||||||
|
@@ -137,7 +137,7 @@ namespace bayesnet {
|
|||||||
int Classifier::getNumberOfNodes() const
|
int Classifier::getNumberOfNodes() const
|
||||||
{
|
{
|
||||||
// Features does not include class
|
// Features does not include class
|
||||||
return fitted ? model.getFeatures().size() + 1 : 0;
|
return fitted ? model.getFeatures().size() : 0;
|
||||||
}
|
}
|
||||||
int Classifier::getNumberOfEdges() const
|
int Classifier::getNumberOfEdges() const
|
||||||
{
|
{
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <algorithm>
|
||||||
#include "BestResults.h"
|
#include "BestResults.h"
|
||||||
#include "Result.h"
|
#include "Result.h"
|
||||||
#include "Colors.h"
|
#include "Colors.h"
|
||||||
@@ -27,7 +28,6 @@ std::string ftime_to_string(TP tp)
|
|||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
namespace platform {
|
namespace platform {
|
||||||
|
|
||||||
string BestResults::build()
|
string BestResults::build()
|
||||||
{
|
{
|
||||||
auto files = loadResultFiles();
|
auto files = loadResultFiles();
|
||||||
@@ -65,12 +65,10 @@ namespace platform {
|
|||||||
file.close();
|
file.close();
|
||||||
return bestFileName;
|
return bestFileName;
|
||||||
}
|
}
|
||||||
|
|
||||||
string BestResults::bestResultFile()
|
string BestResults::bestResultFile()
|
||||||
{
|
{
|
||||||
return "best_results_" + score + "_" + model + ".json";
|
return "best_results_" + score + "_" + model + ".json";
|
||||||
}
|
}
|
||||||
|
|
||||||
pair<string, string> getModelScore(string name)
|
pair<string, string> getModelScore(string name)
|
||||||
{
|
{
|
||||||
// results_accuracy_BoostAODE_MacBookpro16_2023-09-06_12:27:00_1.json
|
// results_accuracy_BoostAODE_MacBookpro16_2023-09-06_12:27:00_1.json
|
||||||
@@ -82,7 +80,6 @@ namespace platform {
|
|||||||
string model = name.substr(pos2 + 1, pos - pos2 - 1);
|
string model = name.substr(pos2 + 1, pos - pos2 - 1);
|
||||||
return { model, score };
|
return { model, score };
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<string> BestResults::loadResultFiles()
|
vector<string> BestResults::loadResultFiles()
|
||||||
{
|
{
|
||||||
vector<string> files;
|
vector<string> files;
|
||||||
@@ -99,7 +96,6 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return files;
|
return files;
|
||||||
}
|
}
|
||||||
|
|
||||||
json BestResults::loadFile(const string& fileName)
|
json BestResults::loadFile(const string& fileName)
|
||||||
{
|
{
|
||||||
ifstream resultData(fileName);
|
ifstream resultData(fileName);
|
||||||
@@ -136,7 +132,6 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return datasets;
|
return datasets;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BestResults::buildAll()
|
void BestResults::buildAll()
|
||||||
{
|
{
|
||||||
auto models = getModels();
|
auto models = getModels();
|
||||||
@@ -147,8 +142,7 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
model = "any";
|
model = "any";
|
||||||
}
|
}
|
||||||
|
void BestResults::listFile()
|
||||||
void BestResults::reportSingle()
|
|
||||||
{
|
{
|
||||||
string bestFileName = path + bestResultFile();
|
string bestFileName = path + bestResultFile();
|
||||||
if (FILE* fileTest = fopen(bestFileName.c_str(), "r")) {
|
if (FILE* fileTest = fopen(bestFileName.c_str(), "r")) {
|
||||||
@@ -162,22 +156,35 @@ namespace platform {
|
|||||||
auto data = loadFile(bestFileName);
|
auto data = loadFile(bestFileName);
|
||||||
auto datasets = getDatasets(data);
|
auto datasets = getDatasets(data);
|
||||||
int maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const string& a, const string& b) { return a.size() < b.size(); })).size();
|
int maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const string& a, const string& b) { return a.size() < b.size(); })).size();
|
||||||
cout << Colors::GREEN() << "Best results for " << model << " and " << score << " as of " << date << endl;
|
int maxFileName = 0;
|
||||||
cout << "--------------------------------------------------------" << endl;
|
int maxHyper = 15;
|
||||||
cout << Colors::GREEN() << " # " << setw(maxDatasetName + 1) << left << string("Dataset") << "Score File Hyperparameters" << endl;
|
for (auto const& item : data.items()) {
|
||||||
cout << "=== " << string(maxDatasetName, '=') << " =========== ================================================================== ================================================= " << endl;
|
maxHyper = max(maxHyper, (int)item.value().at(1).dump().size());
|
||||||
|
maxFileName = max(maxFileName, (int)item.value().at(2).get<string>().size());
|
||||||
|
}
|
||||||
|
stringstream oss;
|
||||||
|
oss << Colors::GREEN() << "Best results for " << model << " as of " << date << endl;
|
||||||
|
cout << oss.str();
|
||||||
|
cout << string(oss.str().size() - 8, '-') << endl;
|
||||||
|
cout << Colors::GREEN() << " # " << setw(maxDatasetName + 1) << left << "Dataset" << "Score " << setw(maxFileName) << "File" << " Hyperparameters" << endl;
|
||||||
|
cout << "=== " << string(maxDatasetName, '=') << " =========== " << string(maxFileName, '=') << " " << string(maxHyper, '=') << endl;
|
||||||
auto i = 0;
|
auto i = 0;
|
||||||
bool odd = true;
|
bool odd = true;
|
||||||
|
double total = 0;
|
||||||
for (auto const& item : data.items()) {
|
for (auto const& item : data.items()) {
|
||||||
auto color = odd ? Colors::BLUE() : Colors::CYAN();
|
auto color = odd ? Colors::BLUE() : Colors::CYAN();
|
||||||
|
double value = item.value().at(0).get<double>();
|
||||||
cout << color << setw(3) << fixed << right << i++ << " ";
|
cout << color << setw(3) << fixed << right << i++ << " ";
|
||||||
cout << setw(maxDatasetName) << left << item.key() << " ";
|
cout << setw(maxDatasetName) << left << item.key() << " ";
|
||||||
cout << setw(11) << setprecision(9) << fixed << item.value().at(0).get<double>() << " ";
|
cout << setw(11) << setprecision(9) << fixed << value << " ";
|
||||||
cout << setw(66) << item.value().at(2).get<string>() << " ";
|
cout << setw(maxFileName) << item.value().at(2).get<string>() << " ";
|
||||||
cout << item.value().at(1) << " ";
|
cout << item.value().at(1) << " ";
|
||||||
cout << endl;
|
cout << endl;
|
||||||
|
total += value;
|
||||||
odd = !odd;
|
odd = !odd;
|
||||||
}
|
}
|
||||||
|
cout << Colors::GREEN() << "=== " << string(maxDatasetName, '=') << " ===========" << endl;
|
||||||
|
cout << setw(5 + maxDatasetName) << "Total.................. " << setw(11) << setprecision(8) << fixed << total << endl;
|
||||||
}
|
}
|
||||||
json BestResults::buildTableResults(vector<string> models)
|
json BestResults::buildTableResults(vector<string> models)
|
||||||
{
|
{
|
||||||
@@ -202,11 +209,12 @@ namespace platform {
|
|||||||
table["dateTable"] = ftime_to_string(maxDate);
|
table["dateTable"] = ftime_to_string(maxDate);
|
||||||
return table;
|
return table;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BestResults::printTableResults(vector<string> models, json table)
|
void BestResults::printTableResults(vector<string> models, json table)
|
||||||
{
|
{
|
||||||
cout << Colors::GREEN() << "Best results for " << score << " as of " << table.at("dateTable").get<string>() << endl;
|
stringstream oss;
|
||||||
cout << "------------------------------------------------" << endl;
|
oss << Colors::GREEN() << "Best results for " << score << " as of " << table.at("dateTable").get<string>() << endl;
|
||||||
|
cout << oss.str();
|
||||||
|
cout << string(oss.str().size() - 8, '-') << endl;
|
||||||
cout << Colors::GREEN() << " # " << setw(maxDatasetName + 1) << left << string("Dataset");
|
cout << Colors::GREEN() << " # " << setw(maxDatasetName + 1) << left << string("Dataset");
|
||||||
for (const auto& model : models) {
|
for (const auto& model : models) {
|
||||||
cout << setw(maxModelName) << left << model << " ";
|
cout << setw(maxModelName) << left << model << " ";
|
||||||
@@ -271,6 +279,19 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
void BestResults::reportSingle(bool excel)
|
||||||
|
{
|
||||||
|
listFile();
|
||||||
|
if (excel) {
|
||||||
|
auto models = getModels();
|
||||||
|
// Build the table of results
|
||||||
|
json table = buildTableResults(models);
|
||||||
|
vector<string> datasets = getDatasets(table.begin().value());
|
||||||
|
BestResultsExcel excel(score, datasets);
|
||||||
|
excel.reportSingle(model, path + bestResultFile());
|
||||||
|
messageExcelFile(excel.getFileName());
|
||||||
|
}
|
||||||
|
}
|
||||||
void BestResults::reportAll(bool excel)
|
void BestResults::reportAll(bool excel)
|
||||||
{
|
{
|
||||||
auto models = getModels();
|
auto models = getModels();
|
||||||
@@ -292,9 +313,32 @@ namespace platform {
|
|||||||
ranksModels = stats.getRanks();
|
ranksModels = stats.getRanks();
|
||||||
}
|
}
|
||||||
if (excel) {
|
if (excel) {
|
||||||
BestResultsExcel excel(score, models, datasets, table, ranksModels, friedman, significance);
|
BestResultsExcel excel(score, datasets);
|
||||||
excel.build();
|
excel.reportAll(models, table, ranksModels, friedman, significance);
|
||||||
cout << Colors::YELLOW() << "** Excel file generated: " << excel.getFileName() << Colors::RESET() << endl;
|
if (friedman) {
|
||||||
|
int idx = -1;
|
||||||
|
double min = 2000;
|
||||||
|
// Find out the control model
|
||||||
|
auto totals = vector<double>(models.size(), 0.0);
|
||||||
|
for (const auto& dataset : datasets) {
|
||||||
|
for (int i = 0; i < models.size(); ++i) {
|
||||||
|
totals[i] += ranksModels[dataset][models[i]];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for (int i = 0; i < models.size(); ++i) {
|
||||||
|
if (totals[i] < min) {
|
||||||
|
min = totals[i];
|
||||||
|
idx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
model = models.at(idx);
|
||||||
|
excel.reportSingle(model, path + bestResultFile());
|
||||||
|
}
|
||||||
|
messageExcelFile(excel.getFileName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void BestResults::messageExcelFile(const string& fileName)
|
||||||
|
{
|
||||||
|
cout << Colors::YELLOW() << "** Excel file generated: " << fileName << Colors::RESET() << endl;
|
||||||
|
}
|
||||||
}
|
}
|
@@ -7,19 +7,24 @@ using json = nlohmann::json;
|
|||||||
namespace platform {
|
namespace platform {
|
||||||
class BestResults {
|
class BestResults {
|
||||||
public:
|
public:
|
||||||
explicit BestResults(const string& path, const string& score, const string& model, bool friedman, double significance = 0.05) : path(path), score(score), model(model), friedman(friedman), significance(significance) {}
|
explicit BestResults(const string& path, const string& score, const string& model, bool friedman, double significance = 0.05)
|
||||||
|
: path(path), score(score), model(model), friedman(friedman), significance(significance)
|
||||||
|
{
|
||||||
|
}
|
||||||
string build();
|
string build();
|
||||||
void reportSingle();
|
void reportSingle(bool excel);
|
||||||
void reportAll(bool excel);
|
void reportAll(bool excel);
|
||||||
void buildAll();
|
void buildAll();
|
||||||
private:
|
private:
|
||||||
vector<string> getModels();
|
vector<string> getModels();
|
||||||
vector<string> getDatasets(json table);
|
vector<string> getDatasets(json table);
|
||||||
vector<string> loadResultFiles();
|
vector<string> loadResultFiles();
|
||||||
|
void messageExcelFile(const string& fileName);
|
||||||
json buildTableResults(vector<string> models);
|
json buildTableResults(vector<string> models);
|
||||||
void printTableResults(vector<string> models, json table);
|
void printTableResults(vector<string> models, json table);
|
||||||
string bestResultFile();
|
string bestResultFile();
|
||||||
json loadFile(const string& fileName);
|
json loadFile(const string& fileName);
|
||||||
|
void listFile();
|
||||||
string path;
|
string path;
|
||||||
string score;
|
string score;
|
||||||
string model;
|
string model;
|
||||||
|
@@ -1,21 +1,122 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include "BestResultsExcel.h"
|
#include "BestResultsExcel.h"
|
||||||
#include "Paths.h"
|
#include "Paths.h"
|
||||||
|
#include <map>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
#include "Statistics.h"
|
#include "Statistics.h"
|
||||||
|
#include "ReportExcel.h"
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
BestResultsExcel::BestResultsExcel(const string& score, const vector<string>& models, const vector<string>& datasets, const json& table, const map<string, map<string, float>>& ranksModels, bool friedman, double significance) :
|
json loadResultData(const string& fileName)
|
||||||
score(score), models(models), datasets(datasets), table(table), ranksModels(ranksModels), friedman(friedman), significance(significance)
|
{
|
||||||
|
json data;
|
||||||
|
ifstream resultData(fileName);
|
||||||
|
if (resultData.is_open()) {
|
||||||
|
data = json::parse(resultData);
|
||||||
|
} else {
|
||||||
|
throw invalid_argument("Unable to open result file. [" + fileName + "]");
|
||||||
|
}
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
string getColumnName(int colNum)
|
||||||
|
{
|
||||||
|
string columnName = "";
|
||||||
|
if (colNum == 0)
|
||||||
|
return "A";
|
||||||
|
while (colNum > 0) {
|
||||||
|
int modulo = colNum % 26;
|
||||||
|
columnName = char(65 + modulo) + columnName;
|
||||||
|
colNum = (int)((colNum - modulo) / 26);
|
||||||
|
}
|
||||||
|
return columnName;
|
||||||
|
}
|
||||||
|
BestResultsExcel::BestResultsExcel(const string& score, const vector<string>& datasets) : score(score), datasets(datasets)
|
||||||
{
|
{
|
||||||
workbook = workbook_new((Paths::excel() + fileName).c_str());
|
workbook = workbook_new((Paths::excel() + fileName).c_str());
|
||||||
worksheet = workbook_add_worksheet(workbook, "Best Results");
|
|
||||||
setProperties("Best Results");
|
setProperties("Best Results");
|
||||||
createFormats();
|
|
||||||
int maxModelName = (*max_element(models.begin(), models.end(), [](const string& a, const string& b) { return a.size() < b.size(); })).size();
|
|
||||||
modelNameSize = max(modelNameSize, maxModelName);
|
|
||||||
int maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const string& a, const string& b) { return a.size() < b.size(); })).size();
|
int maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const string& a, const string& b) { return a.size() < b.size(); })).size();
|
||||||
datasetNameSize = max(datasetNameSize, maxDatasetName);
|
datasetNameSize = max(datasetNameSize, maxDatasetName);
|
||||||
|
createFormats();
|
||||||
|
}
|
||||||
|
void BestResultsExcel::reportAll(const vector<string>& models, const json& table, const map<string, map<string, float>>& ranks, bool friedman, double significance)
|
||||||
|
{
|
||||||
|
this->table = table;
|
||||||
|
this->models = models;
|
||||||
|
ranksModels = ranks;
|
||||||
|
this->friedman = friedman;
|
||||||
|
this->significance = significance;
|
||||||
|
worksheet = workbook_add_worksheet(workbook, "Best Results");
|
||||||
|
int maxModelName = (*max_element(models.begin(), models.end(), [](const string& a, const string& b) { return a.size() < b.size(); })).size();
|
||||||
|
modelNameSize = max(modelNameSize, maxModelName);
|
||||||
formatColumns();
|
formatColumns();
|
||||||
|
build();
|
||||||
|
}
|
||||||
|
void BestResultsExcel::reportSingle(const string& model, const string& fileName)
|
||||||
|
{
|
||||||
|
worksheet = workbook_add_worksheet(workbook, "Report");
|
||||||
|
if (FILE* fileTest = fopen(fileName.c_str(), "r")) {
|
||||||
|
fclose(fileTest);
|
||||||
|
} else {
|
||||||
|
cerr << "File " << fileName << " doesn't exist." << endl;
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
json data = loadResultData(fileName);
|
||||||
|
|
||||||
|
string title = "Best results for " + model;
|
||||||
|
worksheet_merge_range(worksheet, 0, 0, 0, 4, title.c_str(), styles["headerFirst"]);
|
||||||
|
// Body header
|
||||||
|
row = 3;
|
||||||
|
int col = 1;
|
||||||
|
writeString(row, 0, "Nº", "bodyHeader");
|
||||||
|
writeString(row, 1, "Dataset", "bodyHeader");
|
||||||
|
writeString(row, 2, "Score", "bodyHeader");
|
||||||
|
writeString(row, 3, "File", "bodyHeader");
|
||||||
|
writeString(row, 4, "Hyperparameters", "bodyHeader");
|
||||||
|
auto i = 0;
|
||||||
|
string hyperparameters;
|
||||||
|
int hypSize = 22;
|
||||||
|
map<string, string> files; // map of files imported and their tabs
|
||||||
|
for (auto const& item : data.items()) {
|
||||||
|
row++;
|
||||||
|
writeInt(row, 0, i++, "ints");
|
||||||
|
writeString(row, 1, item.key().c_str(), "text");
|
||||||
|
writeDouble(row, 2, item.value().at(0).get<double>(), "result");
|
||||||
|
auto fileName = item.value().at(2).get<string>();
|
||||||
|
string hyperlink = "";
|
||||||
|
try {
|
||||||
|
hyperlink = files.at(fileName);
|
||||||
|
}
|
||||||
|
catch (const out_of_range& oor) {
|
||||||
|
auto tabName = "table_" + to_string(i);
|
||||||
|
auto worksheetNew = workbook_add_worksheet(workbook, tabName.c_str());
|
||||||
|
json data = loadResultData(Paths::results() + fileName);
|
||||||
|
auto report = ReportExcel(data, false, workbook, worksheetNew);
|
||||||
|
report.show();
|
||||||
|
hyperlink = "#table_" + to_string(i);
|
||||||
|
files[fileName] = hyperlink;
|
||||||
|
}
|
||||||
|
hyperlink += "!H" + to_string(i + 6);
|
||||||
|
string fileNameText = "=HYPERLINK(\"" + hyperlink + "\",\"" + fileName + "\")";
|
||||||
|
worksheet_write_formula(worksheet, row, 3, fileNameText.c_str(), efectiveStyle("text"));
|
||||||
|
hyperparameters = item.value().at(1).dump();
|
||||||
|
if (hyperparameters.size() > hypSize) {
|
||||||
|
hypSize = hyperparameters.size();
|
||||||
|
}
|
||||||
|
writeString(row, 4, hyperparameters, "text");
|
||||||
|
}
|
||||||
|
row++;
|
||||||
|
// Set Totals
|
||||||
|
writeString(row, 1, "Total", "bodyHeader");
|
||||||
|
stringstream oss;
|
||||||
|
auto colName = getColumnName(2);
|
||||||
|
oss << "=sum(" << colName << "5:" << colName << row << ")";
|
||||||
|
worksheet_write_formula(worksheet, row, 2, oss.str().c_str(), styles["bodyHeader_odd"]);
|
||||||
|
// Set format
|
||||||
|
worksheet_freeze_panes(worksheet, 4, 2);
|
||||||
|
vector<int> columns_sizes = { 5, datasetNameSize, modelNameSize, 66, hypSize + 1 };
|
||||||
|
for (int i = 0; i < columns_sizes.size(); ++i) {
|
||||||
|
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
BestResultsExcel::~BestResultsExcel()
|
BestResultsExcel::~BestResultsExcel()
|
||||||
{
|
{
|
||||||
@@ -32,11 +133,30 @@ namespace platform {
|
|||||||
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
|
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
void BestResultsExcel::addConditionalFormat(string formula)
|
||||||
|
{
|
||||||
|
// Add conditional format for max/min values in scores/ranks sheets
|
||||||
|
lxw_format* custom_format = workbook_add_format(workbook);
|
||||||
|
format_set_bg_color(custom_format, 0xFFC7CE);
|
||||||
|
format_set_font_color(custom_format, 0x9C0006);
|
||||||
|
// Create a conditional format object. A static object would also work.
|
||||||
|
lxw_conditional_format* conditional_format = (lxw_conditional_format*)calloc(1, sizeof(lxw_conditional_format));
|
||||||
|
conditional_format->type = LXW_CONDITIONAL_TYPE_FORMULA;
|
||||||
|
string col = getColumnName(models.size() + 1);
|
||||||
|
stringstream oss;
|
||||||
|
oss << "=C5=" << formula << "($C5:$" << col << "5)";
|
||||||
|
auto formulaValue = oss.str();
|
||||||
|
conditional_format->value_string = formulaValue.c_str();
|
||||||
|
conditional_format->format = custom_format;
|
||||||
|
worksheet_conditional_format_range(worksheet, 4, 2, datasets.size() + 3, models.size() + 1, conditional_format);
|
||||||
|
}
|
||||||
void BestResultsExcel::build()
|
void BestResultsExcel::build()
|
||||||
{
|
{
|
||||||
// Create Sheet with scores
|
// Create Sheet with scores
|
||||||
header(false);
|
header(false);
|
||||||
body(false);
|
body(false);
|
||||||
|
// Add conditional format for max values
|
||||||
|
addConditionalFormat("max");
|
||||||
footer(false);
|
footer(false);
|
||||||
if (friedman) {
|
if (friedman) {
|
||||||
// Create Sheet with ranks
|
// Create Sheet with ranks
|
||||||
@@ -44,6 +164,7 @@ namespace platform {
|
|||||||
formatColumns();
|
formatColumns();
|
||||||
header(true);
|
header(true);
|
||||||
body(true);
|
body(true);
|
||||||
|
addConditionalFormat("min");
|
||||||
footer(true);
|
footer(true);
|
||||||
// Create Sheet with Friedman Test
|
// Create Sheet with Friedman Test
|
||||||
doFriedman();
|
doFriedman();
|
||||||
@@ -90,7 +211,8 @@ namespace platform {
|
|||||||
int col = 1;
|
int col = 1;
|
||||||
for (const auto& model : models) {
|
for (const auto& model : models) {
|
||||||
stringstream oss;
|
stringstream oss;
|
||||||
oss << "=sum(indirect(address(" << 5 << "," << col + 2 << ")):indirect(address(" << row << "," << col + 2 << ")))";
|
auto colName = getColumnName(col + 1);
|
||||||
|
oss << "=SUM(" << colName << "5:" << colName << row << ")";
|
||||||
worksheet_write_formula(worksheet, row, ++col, oss.str().c_str(), styles["bodyHeader_odd"]);
|
worksheet_write_formula(worksheet, row, ++col, oss.str().c_str(), styles["bodyHeader_odd"]);
|
||||||
}
|
}
|
||||||
if (ranks) {
|
if (ranks) {
|
||||||
@@ -98,8 +220,9 @@ namespace platform {
|
|||||||
writeString(row, 1, "Average ranks", "bodyHeader");
|
writeString(row, 1, "Average ranks", "bodyHeader");
|
||||||
int col = 1;
|
int col = 1;
|
||||||
for (const auto& model : models) {
|
for (const auto& model : models) {
|
||||||
|
auto colName = getColumnName(col + 1);
|
||||||
stringstream oss;
|
stringstream oss;
|
||||||
oss << "=sum(indirect(address(" << 5 << "," << col + 2 << ")):indirect(address(" << row - 1 << "," << col + 2 << ")))/" << datasets.size();
|
oss << "=SUM(" << colName << "5:" << colName << row - 1 << ")/" << datasets.size();
|
||||||
worksheet_write_formula(worksheet, row, ++col, oss.str().c_str(), styles["bodyHeader_odd"]);
|
worksheet_write_formula(worksheet, row, ++col, oss.str().c_str(), styles["bodyHeader_odd"]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -12,16 +12,19 @@ namespace platform {
|
|||||||
|
|
||||||
class BestResultsExcel : ExcelFile {
|
class BestResultsExcel : ExcelFile {
|
||||||
public:
|
public:
|
||||||
BestResultsExcel(const string& score, const vector<string>& models, const vector<string>& datasets, const json& table, const map<string, map<string, float>>& ranks, bool friedman, double significance);
|
BestResultsExcel(const string& score, const vector<string>& datasets);
|
||||||
~BestResultsExcel();
|
~BestResultsExcel();
|
||||||
void build();
|
void reportAll(const vector<string>& models, const json& table, const map<string, map<string, float>>& ranks, bool friedman, double significance);
|
||||||
|
void reportSingle(const string& model, const string& fileName);
|
||||||
string getFileName();
|
string getFileName();
|
||||||
private:
|
private:
|
||||||
|
void build();
|
||||||
void header(bool ranks);
|
void header(bool ranks);
|
||||||
void body(bool ranks);
|
void body(bool ranks);
|
||||||
void footer(bool ranks);
|
void footer(bool ranks);
|
||||||
void formatColumns();
|
void formatColumns();
|
||||||
void doFriedman();
|
void doFriedman();
|
||||||
|
void addConditionalFormat(string formula);
|
||||||
const string fileName = "BestResults.xlsx";
|
const string fileName = "BestResults.xlsx";
|
||||||
string score;
|
string score;
|
||||||
vector<string> models;
|
vector<string> models;
|
||||||
|
@@ -1,10 +1,28 @@
|
|||||||
#ifndef BESTSCORE_H
|
#ifndef BESTSCORE_H
|
||||||
#define BESTSCORE_H
|
#define BESTSCORE_H
|
||||||
#include <string>
|
#include <string>
|
||||||
class BestScore {
|
#include <map>
|
||||||
public:
|
#include <utility>
|
||||||
static std::string title() { return "STree_default (linear-ovo)"; }
|
#include "DotEnv.h"
|
||||||
static double score() { return 22.109799; }
|
namespace platform {
|
||||||
static std::string scoreName() { return "accuracy"; }
|
class BestScore {
|
||||||
};
|
public:
|
||||||
|
static pair<string, double> getScore(const std::string& metric)
|
||||||
|
{
|
||||||
|
static map<pair<string, string>, pair<string, double>> data = {
|
||||||
|
{{"discretiz", "accuracy"}, {"STree_default (linear-ovo)", 22.109799}},
|
||||||
|
{{"odte", "accuracy"}, {"STree_default (linear-ovo)", 22.109799}},
|
||||||
|
};
|
||||||
|
auto env = platform::DotEnv();
|
||||||
|
string experiment = env.get("experiment");
|
||||||
|
try {
|
||||||
|
return data[{experiment, metric}];
|
||||||
|
}
|
||||||
|
catch (...) {
|
||||||
|
return { "", 0.0 };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
@@ -5,13 +5,13 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
|
|||||||
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
|
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
|
||||||
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
|
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
|
||||||
include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include)
|
include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include)
|
||||||
add_executable(b_main main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc ReportConsole.cc ReportBase.cc)
|
|
||||||
add_executable(b_manage manage.cc Results.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
|
add_executable(b_main b_main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc ReportConsole.cc ReportBase.cc)
|
||||||
add_executable(b_list list.cc Datasets.cc Dataset.cc)
|
add_executable(b_manage b_manage.cc Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
|
||||||
add_executable(b_best best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ExcelFile.cc)
|
add_executable(b_list b_list.cc Datasets.cc Dataset.cc)
|
||||||
add_executable(testx testx.cpp Datasets.cc Dataset.cc Folding.cc )
|
add_executable(b_best b_best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
|
||||||
|
|
||||||
target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
|
target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
|
||||||
target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp)
|
target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp)
|
||||||
target_link_libraries(b_best Boost::boost "${XLSXWRITER_LIB}")
|
target_link_libraries(b_best Boost::boost "${XLSXWRITER_LIB}" "${TORCH_LIBRARIES}" ArffFiles mdlp)
|
||||||
target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}")
|
target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}")
|
||||||
target_link_libraries(testx ArffFiles BayesNet "${TORCH_LIBRARIES}")
|
|
87
src/Platform/CommandParser.cc
Normal file
87
src/Platform/CommandParser.cc
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
#include "CommandParser.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "Colors.h"
|
||||||
|
#include "Utils.h"
|
||||||
|
|
||||||
|
namespace platform {
|
||||||
|
void CommandParser::messageError(const string& message)
|
||||||
|
{
|
||||||
|
cout << Colors::RED() << message << Colors::RESET() << endl;
|
||||||
|
}
|
||||||
|
pair<char, int> CommandParser::parse(const string& color, const vector<tuple<string, char, bool>>& options, const char defaultCommand, const int maxIndex)
|
||||||
|
{
|
||||||
|
bool finished = false;
|
||||||
|
while (!finished) {
|
||||||
|
stringstream oss;
|
||||||
|
string line;
|
||||||
|
oss << color << "Choose option (";
|
||||||
|
bool first = true;
|
||||||
|
for (auto& option : options) {
|
||||||
|
if (first) {
|
||||||
|
first = false;
|
||||||
|
} else {
|
||||||
|
oss << ", ";
|
||||||
|
}
|
||||||
|
oss << get<char>(option) << "=" << get<string>(option);
|
||||||
|
}
|
||||||
|
oss << "): ";
|
||||||
|
cout << oss.str();
|
||||||
|
getline(cin, line);
|
||||||
|
cout << Colors::RESET();
|
||||||
|
line = trim(line);
|
||||||
|
if (line.size() == 0)
|
||||||
|
continue;
|
||||||
|
if (all_of(line.begin(), line.end(), ::isdigit)) {
|
||||||
|
command = defaultCommand;
|
||||||
|
index = stoi(line);
|
||||||
|
if (index > maxIndex || index < 0) {
|
||||||
|
messageError("Index out of range");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
finished = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
bool found = false;
|
||||||
|
for (auto& option : options) {
|
||||||
|
if (line[0] == get<char>(option)) {
|
||||||
|
found = true;
|
||||||
|
// it's a match
|
||||||
|
line.erase(line.begin());
|
||||||
|
line = trim(line);
|
||||||
|
if (get<bool>(option)) {
|
||||||
|
// The option requires a value
|
||||||
|
if (line.size() == 0) {
|
||||||
|
messageError("Option " + get<string>(option) + " requires a value");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
index = stoi(line);
|
||||||
|
if (index > maxIndex || index < 0) {
|
||||||
|
messageError("Index out of range");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (const std::invalid_argument& ia) {
|
||||||
|
messageError("Invalid value: " + line);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (line.size() > 0) {
|
||||||
|
messageError("option " + get<string>(option) + " doesn't accept values");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
command = get<char>(option);
|
||||||
|
finished = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!found) {
|
||||||
|
messageError("I don't know " + line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return { command, index };
|
||||||
|
}
|
||||||
|
} /* namespace platform */
|
21
src/Platform/CommandParser.h
Normal file
21
src/Platform/CommandParser.h
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
#ifndef COMMAND_PARSER_H
|
||||||
|
#define COMMAND_PARSER_H
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <tuple>
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
namespace platform {
|
||||||
|
class CommandParser {
|
||||||
|
public:
|
||||||
|
CommandParser() = default;
|
||||||
|
pair<char, int> parse(const string& color, const vector<tuple<string, char, bool>>& options, const char defaultCommand, const int maxIndex);
|
||||||
|
char getCommand() const { return command; };
|
||||||
|
int getIndex() const { return index; };
|
||||||
|
private:
|
||||||
|
void messageError(const string& message);
|
||||||
|
char command;
|
||||||
|
int index;
|
||||||
|
};
|
||||||
|
} /* namespace platform */
|
||||||
|
#endif /* COMMAND_PARSER_H */
|
@@ -13,17 +13,6 @@ namespace platform {
|
|||||||
class DotEnv {
|
class DotEnv {
|
||||||
private:
|
private:
|
||||||
std::map<std::string, std::string> env;
|
std::map<std::string, std::string> env;
|
||||||
std::string trim(const std::string& str)
|
|
||||||
{
|
|
||||||
std::string result = str;
|
|
||||||
result.erase(result.begin(), std::find_if(result.begin(), result.end(), [](int ch) {
|
|
||||||
return !std::isspace(ch);
|
|
||||||
}));
|
|
||||||
result.erase(std::find_if(result.rbegin(), result.rend(), [](int ch) {
|
|
||||||
return !std::isspace(ch);
|
|
||||||
}).base(), result.end());
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
public:
|
public:
|
||||||
DotEnv()
|
DotEnv()
|
||||||
{
|
{
|
||||||
|
@@ -9,6 +9,10 @@ namespace platform {
|
|||||||
{
|
{
|
||||||
setDefault();
|
setDefault();
|
||||||
}
|
}
|
||||||
|
ExcelFile::ExcelFile(lxw_workbook* workbook, lxw_worksheet* worksheet) : workbook(workbook), worksheet(worksheet)
|
||||||
|
{
|
||||||
|
setDefault();
|
||||||
|
}
|
||||||
void ExcelFile::setDefault()
|
void ExcelFile::setDefault()
|
||||||
{
|
{
|
||||||
normalSize = 14; //font size for report body
|
normalSize = 14; //font size for report body
|
||||||
|
@@ -18,6 +18,7 @@ namespace platform {
|
|||||||
public:
|
public:
|
||||||
ExcelFile();
|
ExcelFile();
|
||||||
ExcelFile(lxw_workbook* workbook);
|
ExcelFile(lxw_workbook* workbook);
|
||||||
|
ExcelFile(lxw_workbook* workbook, lxw_worksheet* worksheet);
|
||||||
lxw_workbook* getWorkbook();
|
lxw_workbook* getWorkbook();
|
||||||
protected:
|
protected:
|
||||||
void setProperties(string title);
|
void setProperties(string title);
|
||||||
|
@@ -102,12 +102,12 @@ namespace platform {
|
|||||||
cout << data.dump(4) << endl;
|
cout << data.dump(4) << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Experiment::go(vector<string> filesToProcess)
|
void Experiment::go(vector<string> filesToProcess, bool quiet)
|
||||||
{
|
{
|
||||||
cout << "*** Starting experiment: " << title << " ***" << endl;
|
cout << "*** Starting experiment: " << title << " ***" << endl;
|
||||||
for (auto fileName : filesToProcess) {
|
for (auto fileName : filesToProcess) {
|
||||||
cout << "- " << setw(20) << left << fileName << " " << right << flush;
|
cout << "- " << setw(20) << left << fileName << " " << right << flush;
|
||||||
cross_validation(fileName);
|
cross_validation(fileName, quiet);
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -132,7 +132,7 @@ namespace platform {
|
|||||||
cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
|
cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
|
||||||
|
|
||||||
}
|
}
|
||||||
void Experiment::cross_validation(const string& fileName)
|
void Experiment::cross_validation(const string& fileName, bool quiet)
|
||||||
{
|
{
|
||||||
auto datasets = platform::Datasets(discretized, Paths::datasets());
|
auto datasets = platform::Datasets(discretized, Paths::datasets());
|
||||||
// Get dataset
|
// Get dataset
|
||||||
@@ -141,7 +141,9 @@ namespace platform {
|
|||||||
auto features = datasets.getFeatures(fileName);
|
auto features = datasets.getFeatures(fileName);
|
||||||
auto samples = datasets.getNSamples(fileName);
|
auto samples = datasets.getNSamples(fileName);
|
||||||
auto className = datasets.getClassName(fileName);
|
auto className = datasets.getClassName(fileName);
|
||||||
|
if (!quiet) {
|
||||||
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
|
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
|
||||||
|
}
|
||||||
// Prepare Result
|
// Prepare Result
|
||||||
auto result = Result();
|
auto result = Result();
|
||||||
auto [values, counts] = at::_unique(y);
|
auto [values, counts] = at::_unique(y);
|
||||||
@@ -159,6 +161,7 @@ namespace platform {
|
|||||||
Timer train_timer, test_timer;
|
Timer train_timer, test_timer;
|
||||||
int item = 0;
|
int item = 0;
|
||||||
for (auto seed : randomSeeds) {
|
for (auto seed : randomSeeds) {
|
||||||
|
if (!quiet)
|
||||||
cout << "(" << seed << ") doing Fold: " << flush;
|
cout << "(" << seed << ") doing Fold: " << flush;
|
||||||
Fold* fold;
|
Fold* fold;
|
||||||
if (stratified)
|
if (stratified)
|
||||||
@@ -180,9 +183,11 @@ namespace platform {
|
|||||||
auto y_train = y.index({ train_t });
|
auto y_train = y.index({ train_t });
|
||||||
auto X_test = X.index({ "...", test_t });
|
auto X_test = X.index({ "...", test_t });
|
||||||
auto y_test = y.index({ test_t });
|
auto y_test = y.index({ test_t });
|
||||||
|
if (!quiet)
|
||||||
showProgress(nfold + 1, getColor(clf->getStatus()), "a");
|
showProgress(nfold + 1, getColor(clf->getStatus()), "a");
|
||||||
// Train model
|
// Train model
|
||||||
clf->fit(X_train, y_train, features, className, states);
|
clf->fit(X_train, y_train, features, className, states);
|
||||||
|
if (!quiet)
|
||||||
showProgress(nfold + 1, getColor(clf->getStatus()), "b");
|
showProgress(nfold + 1, getColor(clf->getStatus()), "b");
|
||||||
nodes[item] = clf->getNumberOfNodes();
|
nodes[item] = clf->getNumberOfNodes();
|
||||||
edges[item] = clf->getNumberOfEdges();
|
edges[item] = clf->getNumberOfEdges();
|
||||||
@@ -191,12 +196,14 @@ namespace platform {
|
|||||||
// Score train
|
// Score train
|
||||||
auto accuracy_train_value = clf->score(X_train, y_train);
|
auto accuracy_train_value = clf->score(X_train, y_train);
|
||||||
// Test model
|
// Test model
|
||||||
|
if (!quiet)
|
||||||
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
|
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
|
||||||
test_timer.start();
|
test_timer.start();
|
||||||
auto accuracy_test_value = clf->score(X_test, y_test);
|
auto accuracy_test_value = clf->score(X_test, y_test);
|
||||||
test_time[item] = test_timer.getDuration();
|
test_time[item] = test_timer.getDuration();
|
||||||
accuracy_train[item] = accuracy_train_value;
|
accuracy_train[item] = accuracy_train_value;
|
||||||
accuracy_test[item] = accuracy_test_value;
|
accuracy_test[item] = accuracy_test_value;
|
||||||
|
if (!quiet)
|
||||||
cout << "\b\b\b, " << flush;
|
cout << "\b\b\b, " << flush;
|
||||||
// Store results and times in vector
|
// Store results and times in vector
|
||||||
result.addScoreTrain(accuracy_train_value);
|
result.addScoreTrain(accuracy_train_value);
|
||||||
@@ -206,6 +213,7 @@ namespace platform {
|
|||||||
item++;
|
item++;
|
||||||
clf.reset();
|
clf.reset();
|
||||||
}
|
}
|
||||||
|
if (!quiet)
|
||||||
cout << "end. " << flush;
|
cout << "end. " << flush;
|
||||||
delete fold;
|
delete fold;
|
||||||
}
|
}
|
||||||
|
@@ -108,8 +108,8 @@ namespace platform {
|
|||||||
Experiment& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
|
Experiment& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
|
||||||
string get_file_name();
|
string get_file_name();
|
||||||
void save(const string& path);
|
void save(const string& path);
|
||||||
void cross_validation(const string& fileName);
|
void cross_validation(const string& fileName, bool quiet);
|
||||||
void go(vector<string> filesToProcess);
|
void go(vector<string> filesToProcess, bool quiet);
|
||||||
void show();
|
void show();
|
||||||
void report();
|
void report();
|
||||||
};
|
};
|
||||||
|
213
src/Platform/ManageResults.cc
Normal file
213
src/Platform/ManageResults.cc
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
#include "ManageResults.h"
|
||||||
|
#include "CommandParser.h"
|
||||||
|
#include <filesystem>
|
||||||
|
#include <tuple>
|
||||||
|
#include "Colors.h"
|
||||||
|
#include "CLocale.h"
|
||||||
|
#include "Paths.h"
|
||||||
|
#include "ReportConsole.h"
|
||||||
|
#include "ReportExcel.h"
|
||||||
|
|
||||||
|
namespace platform {
|
||||||
|
|
||||||
|
ManageResults::ManageResults(int numFiles, const string& model, const string& score, bool complete, bool partial, bool compare) :
|
||||||
|
numFiles{ numFiles }, complete{ complete }, partial{ partial }, compare{ compare }, results(Results(Paths::results(), model, score, complete, partial))
|
||||||
|
{
|
||||||
|
indexList = true;
|
||||||
|
openExcel = false;
|
||||||
|
workbook = NULL;
|
||||||
|
if (numFiles == 0) {
|
||||||
|
this->numFiles = results.size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void ManageResults::doMenu()
|
||||||
|
{
|
||||||
|
if (results.empty()) {
|
||||||
|
cout << Colors::MAGENTA() << "No results found!" << Colors::RESET() << endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
results.sortDate();
|
||||||
|
list();
|
||||||
|
menu();
|
||||||
|
if (openExcel) {
|
||||||
|
workbook_close(workbook);
|
||||||
|
}
|
||||||
|
cout << Colors::RESET() << "Done!" << endl;
|
||||||
|
}
|
||||||
|
void ManageResults::list()
|
||||||
|
{
|
||||||
|
auto temp = ConfigLocale();
|
||||||
|
string suffix = numFiles != results.size() ? " of " + to_string(results.size()) : "";
|
||||||
|
stringstream oss;
|
||||||
|
oss << "Results on screen: " << numFiles << suffix;
|
||||||
|
cout << Colors::GREEN() << oss.str() << endl;
|
||||||
|
cout << string(oss.str().size(), '-') << endl;
|
||||||
|
if (complete) {
|
||||||
|
cout << Colors::MAGENTA() << "Only listing complete results" << endl;
|
||||||
|
}
|
||||||
|
if (partial) {
|
||||||
|
cout << Colors::MAGENTA() << "Only listing partial results" << endl;
|
||||||
|
}
|
||||||
|
auto i = 0;
|
||||||
|
int maxModel = results.maxModelSize();
|
||||||
|
cout << Colors::GREEN() << " # Date " << setw(maxModel) << left << "Model" << " Score Name Score C/P Duration Title" << endl;
|
||||||
|
cout << "=== ========== " << string(maxModel, '=') << " =========== =========== === ========= =============================================================" << endl;
|
||||||
|
bool odd = true;
|
||||||
|
for (auto& result : results) {
|
||||||
|
auto color = odd ? Colors::BLUE() : Colors::CYAN();
|
||||||
|
cout << color << setw(3) << fixed << right << i++ << " ";
|
||||||
|
cout << result.to_string(maxModel) << endl;
|
||||||
|
if (i == numFiles) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
odd = !odd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool ManageResults::confirmAction(const string& intent, const string& fileName) const
|
||||||
|
{
|
||||||
|
string color;
|
||||||
|
if (intent == "delete") {
|
||||||
|
color = Colors::RED();
|
||||||
|
} else {
|
||||||
|
color = Colors::YELLOW();
|
||||||
|
}
|
||||||
|
string line;
|
||||||
|
bool finished = false;
|
||||||
|
while (!finished) {
|
||||||
|
cout << color << "Really want to " << intent << " " << fileName << "? (y/n): ";
|
||||||
|
getline(cin, line);
|
||||||
|
finished = line.size() == 1 && (tolower(line[0]) == 'y' || tolower(line[0] == 'n'));
|
||||||
|
}
|
||||||
|
if (tolower(line[0]) == 'y') {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
cout << "Not done!" << endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
void ManageResults::report(const int index, const bool excelReport)
|
||||||
|
{
|
||||||
|
cout << Colors::YELLOW() << "Reporting " << results.at(index).getFilename() << endl;
|
||||||
|
auto data = results.at(index).load();
|
||||||
|
if (excelReport) {
|
||||||
|
ReportExcel reporter(data, compare, workbook);
|
||||||
|
reporter.show();
|
||||||
|
openExcel = true;
|
||||||
|
workbook = reporter.getWorkbook();
|
||||||
|
cout << "Adding sheet to " << Paths::excel() + Paths::excelResults() << endl;
|
||||||
|
} else {
|
||||||
|
ReportConsole reporter(data, compare);
|
||||||
|
reporter.show();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void ManageResults::showIndex(const int index, const int idx)
|
||||||
|
{
|
||||||
|
// Show a dataset result inside a report
|
||||||
|
auto data = results.at(index).load();
|
||||||
|
cout << Colors::YELLOW() << "Showing " << results.at(index).getFilename() << endl;
|
||||||
|
ReportConsole reporter(data, compare, idx);
|
||||||
|
reporter.show();
|
||||||
|
}
|
||||||
|
void ManageResults::sortList()
|
||||||
|
{
|
||||||
|
cout << Colors::YELLOW() << "Choose sorting field (date='d', score='s', duration='u', model='m'): ";
|
||||||
|
string line;
|
||||||
|
char option;
|
||||||
|
getline(cin, line);
|
||||||
|
if (line.size() == 0)
|
||||||
|
return;
|
||||||
|
if (line.size() > 1) {
|
||||||
|
cout << "Invalid option" << endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
option = line[0];
|
||||||
|
switch (option) {
|
||||||
|
case 'd':
|
||||||
|
results.sortDate();
|
||||||
|
break;
|
||||||
|
case 's':
|
||||||
|
results.sortScore();
|
||||||
|
break;
|
||||||
|
case 'u':
|
||||||
|
results.sortDuration();
|
||||||
|
break;
|
||||||
|
case 'm':
|
||||||
|
results.sortModel();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
cout << "Invalid option" << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void ManageResults::menu()
|
||||||
|
{
|
||||||
|
char option;
|
||||||
|
int index, subIndex;
|
||||||
|
bool finished = false;
|
||||||
|
string filename;
|
||||||
|
// tuple<Option, digit, requires value>
|
||||||
|
vector<tuple<string, char, bool>> mainOptions = {
|
||||||
|
{"quit", 'q', false},
|
||||||
|
{"list", 'l', false},
|
||||||
|
{"delete", 'd', true},
|
||||||
|
{"hide", 'h', true},
|
||||||
|
{"sort", 's', false},
|
||||||
|
{"report", 'r', true},
|
||||||
|
{"excel", 'e', true}
|
||||||
|
};
|
||||||
|
vector<tuple<string, char, bool>> listOptions = {
|
||||||
|
{"report", 'r', true},
|
||||||
|
{"list", 'l', false},
|
||||||
|
{"quit", 'q', false}
|
||||||
|
};
|
||||||
|
auto parser = CommandParser();
|
||||||
|
while (!finished) {
|
||||||
|
if (indexList) {
|
||||||
|
tie(option, index) = parser.parse(Colors::GREEN(), mainOptions, 'r', numFiles - 1);
|
||||||
|
} else {
|
||||||
|
tie(option, subIndex) = parser.parse(Colors::MAGENTA(), listOptions, 'r', results.at(index).load()["results"].size() - 1);
|
||||||
|
}
|
||||||
|
switch (option) {
|
||||||
|
case 'q':
|
||||||
|
finished = true;
|
||||||
|
break;
|
||||||
|
case 'l':
|
||||||
|
list();
|
||||||
|
indexList = true;
|
||||||
|
break;
|
||||||
|
case 'd':
|
||||||
|
filename = results.at(index).getFilename();
|
||||||
|
if (!confirmAction("delete", filename))
|
||||||
|
break;
|
||||||
|
cout << "Deleting " << filename << endl;
|
||||||
|
results.deleteResult(index);
|
||||||
|
cout << "File: " + filename + " deleted!" << endl;
|
||||||
|
list();
|
||||||
|
break;
|
||||||
|
case 'h':
|
||||||
|
filename = results.at(index).getFilename();
|
||||||
|
if (!confirmAction("hide", filename))
|
||||||
|
break;
|
||||||
|
filename = results.at(index).getFilename();
|
||||||
|
cout << "Hiding " << filename << endl;
|
||||||
|
results.hideResult(index, Paths::hiddenResults());
|
||||||
|
cout << "File: " + filename + " hidden! (moved to " << Paths::hiddenResults() << ")" << endl;
|
||||||
|
list();
|
||||||
|
break;
|
||||||
|
case 's':
|
||||||
|
sortList();
|
||||||
|
list();
|
||||||
|
break;
|
||||||
|
case 'r':
|
||||||
|
if (indexList) {
|
||||||
|
report(index, false);
|
||||||
|
indexList = false;
|
||||||
|
} else {
|
||||||
|
showIndex(index, subIndex);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 'e':
|
||||||
|
report(index, true);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} /* namespace platform */
|
31
src/Platform/ManageResults.h
Normal file
31
src/Platform/ManageResults.h
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
#ifndef MANAGE_RESULTS_H
|
||||||
|
#define MANAGE_RESULTS_H
|
||||||
|
#include "Results.h"
|
||||||
|
#include "xlsxwriter.h"
|
||||||
|
|
||||||
|
namespace platform {
|
||||||
|
class ManageResults {
|
||||||
|
public:
|
||||||
|
ManageResults(int numFiles, const string& model, const string& score, bool complete, bool partial, bool compare);
|
||||||
|
~ManageResults() = default;
|
||||||
|
void doMenu();
|
||||||
|
private:
|
||||||
|
void list();
|
||||||
|
bool confirmAction(const string& intent, const string& fileName) const;
|
||||||
|
void report(const int index, const bool excelReport);
|
||||||
|
void showIndex(const int index, const int idx);
|
||||||
|
void sortList();
|
||||||
|
void menu();
|
||||||
|
int numFiles;
|
||||||
|
bool indexList;
|
||||||
|
bool openExcel;
|
||||||
|
bool complete;
|
||||||
|
bool partial;
|
||||||
|
bool compare;
|
||||||
|
Results results;
|
||||||
|
lxw_workbook* workbook;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif /* MANAGE_RESULTS_H */
|
@@ -6,6 +6,7 @@ namespace platform {
|
|||||||
class Paths {
|
class Paths {
|
||||||
public:
|
public:
|
||||||
static std::string results() { return "results/"; }
|
static std::string results() { return "results/"; }
|
||||||
|
static std::string hiddenResults() { return "hidden_results/"; }
|
||||||
static std::string excel() { return "excel/"; }
|
static std::string excel() { return "excel/"; }
|
||||||
static std::string cfs() { return "cfs/"; }
|
static std::string cfs() { return "cfs/"; }
|
||||||
static std::string datasets()
|
static std::string datasets()
|
||||||
@@ -13,6 +14,7 @@ namespace platform {
|
|||||||
auto env = platform::DotEnv();
|
auto env = platform::DotEnv();
|
||||||
return env.get("source_data");
|
return env.get("source_data");
|
||||||
}
|
}
|
||||||
|
static std::string excelResults() { return "some_results.xlsx"; }
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@@ -2,7 +2,6 @@
|
|||||||
#include <locale>
|
#include <locale>
|
||||||
#include "Datasets.h"
|
#include "Datasets.h"
|
||||||
#include "ReportBase.h"
|
#include "ReportBase.h"
|
||||||
#include "BestScore.h"
|
|
||||||
#include "DotEnv.h"
|
#include "DotEnv.h"
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <locale>
|
#include <locale>
|
||||||
#include "ReportConsole.h"
|
#include "ReportConsole.h"
|
||||||
@@ -28,8 +29,15 @@ namespace platform {
|
|||||||
void ReportConsole::body()
|
void ReportConsole::body()
|
||||||
{
|
{
|
||||||
auto tmp = ConfigLocale();
|
auto tmp = ConfigLocale();
|
||||||
cout << Colors::GREEN() << " # Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl;
|
int maxHyper = 15;
|
||||||
cout << "=== ========================= ====== ===== === ========= ========= ========= =============== =================== ====================" << endl;
|
int maxDataset = 7;
|
||||||
|
for (const auto& r : data["results"]) {
|
||||||
|
maxHyper = max(maxHyper, (int)r["hyperparameters"].dump().size());
|
||||||
|
maxDataset = max(maxDataset, (int)r["dataset"].get<string>().size());
|
||||||
|
|
||||||
|
}
|
||||||
|
cout << Colors::GREEN() << " # " << setw(maxDataset) << left << "Dataset" << " Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl;
|
||||||
|
cout << "=== " << string(maxDataset, '=') << " ====== ===== === ========= ========= ========= =============== =================== " << string(maxHyper, '=') << endl;
|
||||||
json lastResult;
|
json lastResult;
|
||||||
double totalScore = 0.0;
|
double totalScore = 0.0;
|
||||||
bool odd = true;
|
bool odd = true;
|
||||||
@@ -41,8 +49,8 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
||||||
cout << color;
|
cout << color;
|
||||||
cout << setw(3) << index++ << " ";
|
cout << setw(3) << right << index++ << " ";
|
||||||
cout << setw(25) << left << r["dataset"].get<string>() << " ";
|
cout << setw(maxDataset) << left << r["dataset"].get<string>() << " ";
|
||||||
cout << setw(6) << right << r["samples"].get<int>() << " ";
|
cout << setw(6) << right << r["samples"].get<int>() << " ";
|
||||||
cout << setw(5) << right << r["features"].get<int>() << " ";
|
cout << setw(5) << right << r["features"].get<int>() << " ";
|
||||||
cout << setw(3) << right << r["classes"].get<int>() << " ";
|
cout << setw(3) << right << r["classes"].get<int>() << " ";
|
||||||
@@ -87,9 +95,10 @@ namespace platform {
|
|||||||
cout << Colors::MAGENTA() << string(MAXL, '*') << endl;
|
cout << Colors::MAGENTA() << string(MAXL, '*') << endl;
|
||||||
showSummary();
|
showSummary();
|
||||||
auto score = data["score_name"].get<string>();
|
auto score = data["score_name"].get<string>();
|
||||||
if (score == BestScore::scoreName()) {
|
auto best = BestScore::getScore(score);
|
||||||
|
if (best.first != "") {
|
||||||
stringstream oss;
|
stringstream oss;
|
||||||
oss << score << " compared to " << BestScore::title() << " .: " << totalScore / BestScore::score();
|
oss << score << " compared to " << best.first << " .: " << totalScore / best.second;
|
||||||
cout << headerLine(oss.str());
|
cout << headerLine(oss.str());
|
||||||
}
|
}
|
||||||
if (!getExistBestFile() && compare) {
|
if (!getExistBestFile() && compare) {
|
||||||
|
@@ -1,7 +1,6 @@
|
|||||||
#ifndef REPORTCONSOLE_H
|
#ifndef REPORTCONSOLE_H
|
||||||
#define REPORTCONSOLE_H
|
#define REPORTCONSOLE_H
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <iostream>
|
|
||||||
#include "ReportBase.h"
|
#include "ReportBase.h"
|
||||||
#include "Colors.h"
|
#include "Colors.h"
|
||||||
|
|
||||||
|
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
|
|
||||||
ReportExcel::ReportExcel(json data_, bool compare, lxw_workbook* workbook) : ReportBase(data_, compare), ExcelFile(workbook)
|
ReportExcel::ReportExcel(json data_, bool compare, lxw_workbook* workbook, lxw_worksheet* worksheet) : ReportBase(data_, compare), ExcelFile(workbook, worksheet)
|
||||||
{
|
{
|
||||||
createFile();
|
createFile();
|
||||||
}
|
}
|
||||||
@@ -19,12 +19,8 @@ namespace platform {
|
|||||||
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
|
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
void ReportExcel::createWorksheet()
|
||||||
void ReportExcel::createFile()
|
|
||||||
{
|
{
|
||||||
if (workbook == NULL) {
|
|
||||||
workbook = workbook_new((Paths::excel() + fileName).c_str());
|
|
||||||
}
|
|
||||||
const string name = data["model"].get<string>();
|
const string name = data["model"].get<string>();
|
||||||
string suffix = "";
|
string suffix = "";
|
||||||
string efectiveName;
|
string efectiveName;
|
||||||
@@ -42,7 +38,16 @@ namespace platform {
|
|||||||
throw invalid_argument("Couldn't create sheet " + efectiveName);
|
throw invalid_argument("Couldn't create sheet " + efectiveName);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cout << "Adding sheet " << efectiveName << " to " << Paths::excel() + fileName << endl;
|
}
|
||||||
|
|
||||||
|
void ReportExcel::createFile()
|
||||||
|
{
|
||||||
|
if (workbook == NULL) {
|
||||||
|
workbook = workbook_new((Paths::excel() + Paths::excelResults()).c_str());
|
||||||
|
}
|
||||||
|
if (worksheet == NULL) {
|
||||||
|
createWorksheet();
|
||||||
|
}
|
||||||
setProperties(data["title"].get<string>());
|
setProperties(data["title"].get<string>());
|
||||||
createFormats();
|
createFormats();
|
||||||
formatColumns();
|
formatColumns();
|
||||||
@@ -115,14 +120,7 @@ namespace platform {
|
|||||||
writeString(row, col + 9, status, "textCentered");
|
writeString(row, col + 9, status, "textCentered");
|
||||||
writeDouble(row, col + 10, r["time"].get<double>(), "time");
|
writeDouble(row, col + 10, r["time"].get<double>(), "time");
|
||||||
writeDouble(row, col + 11, r["time_std"].get<double>(), "time");
|
writeDouble(row, col + 11, r["time_std"].get<double>(), "time");
|
||||||
try {
|
hyperparameters = r["hyperparameters"].dump();
|
||||||
hyperparameters = r["hyperparameters"].get<string>();
|
|
||||||
}
|
|
||||||
catch (const exception& err) {
|
|
||||||
stringstream oss;
|
|
||||||
oss << r["hyperparameters"];
|
|
||||||
hyperparameters = oss.str();
|
|
||||||
}
|
|
||||||
if (hyperparameters.size() > hypSize) {
|
if (hyperparameters.size() > hypSize) {
|
||||||
hypSize = hyperparameters.size();
|
hypSize = hyperparameters.size();
|
||||||
}
|
}
|
||||||
@@ -130,7 +128,6 @@ namespace platform {
|
|||||||
lastResult = r;
|
lastResult = r;
|
||||||
totalScore += r["score"].get<double>();
|
totalScore += r["score"].get<double>();
|
||||||
row++;
|
row++;
|
||||||
|
|
||||||
}
|
}
|
||||||
// Set the right column width of hyperparameters with the maximum length
|
// Set the right column width of hyperparameters with the maximum length
|
||||||
worksheet_set_column(worksheet, 12, 12, hypSize + 5, NULL);
|
worksheet_set_column(worksheet, 12, 12, hypSize + 5, NULL);
|
||||||
@@ -171,9 +168,10 @@ namespace platform {
|
|||||||
showSummary();
|
showSummary();
|
||||||
row += 4 + summary.size();
|
row += 4 + summary.size();
|
||||||
auto score = data["score_name"].get<string>();
|
auto score = data["score_name"].get<string>();
|
||||||
if (score == BestScore::scoreName()) {
|
auto best = BestScore::getScore(score);
|
||||||
worksheet_merge_range(worksheet, row, 1, row, 5, (score + " compared to " + BestScore::title() + " .:").c_str(), efectiveStyle("text"));
|
if (best.first != "") {
|
||||||
writeDouble(row, 6, totalScore / BestScore::score(), "result");
|
worksheet_merge_range(worksheet, row, 1, row, 5, (score + " compared to " + best.first + " .:").c_str(), efectiveStyle("text"));
|
||||||
|
writeDouble(row, 6, totalScore / best.second, "result");
|
||||||
}
|
}
|
||||||
if (!getExistBestFile() && compare) {
|
if (!getExistBestFile() && compare) {
|
||||||
worksheet_write_string(worksheet, row + 1, 0, "*** Best Results File not found. Couldn't compare any result!", styles["summaryStyle"]);
|
worksheet_write_string(worksheet, row + 1, 0, "*** Best Results File not found. Couldn't compare any result!", styles["summaryStyle"]);
|
||||||
|
@@ -9,11 +9,11 @@ namespace platform {
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
class ReportExcel : public ReportBase, public ExcelFile {
|
class ReportExcel : public ReportBase, public ExcelFile {
|
||||||
public:
|
public:
|
||||||
explicit ReportExcel(json data_, bool compare, lxw_workbook* workbook);
|
explicit ReportExcel(json data_, bool compare, lxw_workbook* workbook, lxw_worksheet* worksheet = NULL);
|
||||||
private:
|
private:
|
||||||
const string fileName = "some_results.xlsx";
|
|
||||||
void formatColumns();
|
void formatColumns();
|
||||||
void createFile();
|
void createFile();
|
||||||
|
void createWorksheet();
|
||||||
void closeFile();
|
void closeFile();
|
||||||
void header() override;
|
void header() override;
|
||||||
void body() override;
|
void body() override;
|
||||||
|
@@ -1,9 +1,10 @@
|
|||||||
|
#include "Result.h"
|
||||||
|
#include "BestScore.h"
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include "Result.h"
|
|
||||||
#include "Colors.h"
|
#include "Colors.h"
|
||||||
#include "BestScore.h"
|
#include "DotEnv.h"
|
||||||
#include "CLocale.h"
|
#include "CLocale.h"
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
@@ -18,8 +19,9 @@ namespace platform {
|
|||||||
score += result["score"].get<double>();
|
score += result["score"].get<double>();
|
||||||
}
|
}
|
||||||
scoreName = data["score_name"];
|
scoreName = data["score_name"];
|
||||||
if (scoreName == BestScore::scoreName()) {
|
auto best = BestScore::getScore(scoreName);
|
||||||
score /= BestScore::score();
|
if (best.first != "") {
|
||||||
|
score /= best.second;
|
||||||
}
|
}
|
||||||
title = data["title"];
|
title = data["title"];
|
||||||
duration = data["duration"];
|
duration = data["duration"];
|
||||||
@@ -37,14 +39,14 @@ namespace platform {
|
|||||||
throw invalid_argument("Unable to open result file. [" + path + "/" + filename + "]");
|
throw invalid_argument("Unable to open result file. [" + path + "/" + filename + "]");
|
||||||
}
|
}
|
||||||
|
|
||||||
string Result::to_string() const
|
string Result::to_string(int maxModel) const
|
||||||
{
|
{
|
||||||
auto tmp = ConfigLocale();
|
auto tmp = ConfigLocale();
|
||||||
stringstream oss;
|
stringstream oss;
|
||||||
double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration;
|
double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration;
|
||||||
string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s";
|
string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s";
|
||||||
oss << date << " ";
|
oss << date << " ";
|
||||||
oss << setw(12) << left << model << " ";
|
oss << setw(maxModel) << left << model << " ";
|
||||||
oss << setw(11) << left << scoreName << " ";
|
oss << setw(11) << left << scoreName << " ";
|
||||||
oss << right << setw(11) << setprecision(7) << fixed << score << " ";
|
oss << right << setw(11) << setprecision(7) << fixed << score << " ";
|
||||||
auto completeString = isComplete() ? "C" : "P";
|
auto completeString = isComplete() ? "C" : "P";
|
||||||
|
@@ -12,7 +12,7 @@ namespace platform {
|
|||||||
public:
|
public:
|
||||||
Result(const string& path, const string& filename);
|
Result(const string& path, const string& filename);
|
||||||
json load() const;
|
json load() const;
|
||||||
string to_string() const;
|
string to_string(int maxModel) const;
|
||||||
string getFilename() const { return filename; };
|
string getFilename() const { return filename; };
|
||||||
string getDate() const { return date; };
|
string getDate() const { return date; };
|
||||||
double getScore() const { return score; };
|
double getScore() const { return score; };
|
||||||
|
@@ -1,11 +1,17 @@
|
|||||||
#include <filesystem>
|
|
||||||
#include "Results.h"
|
#include "Results.h"
|
||||||
#include "ReportConsole.h"
|
#include <algorithm>
|
||||||
#include "ReportExcel.h"
|
|
||||||
#include "BestScore.h"
|
|
||||||
#include "Colors.h"
|
|
||||||
#include "CLocale.h"
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
|
Results::Results(const string& path, const string& model, const string& score, bool complete, bool partial) :
|
||||||
|
path(path), model(model), scoreName(score), complete(complete), partial(partial)
|
||||||
|
{
|
||||||
|
load();
|
||||||
|
if (!files.empty()) {
|
||||||
|
maxModel = (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getModel().size() < b.getModel().size(); })).getModel().size();
|
||||||
|
} else {
|
||||||
|
maxModel = 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
void Results::load()
|
void Results::load()
|
||||||
{
|
{
|
||||||
using std::filesystem::directory_iterator;
|
using std::filesystem::directory_iterator;
|
||||||
@@ -20,212 +26,22 @@ namespace platform {
|
|||||||
files.push_back(result);
|
files.push_back(result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (max == 0) {
|
|
||||||
max = files.size();
|
|
||||||
}
|
}
|
||||||
}
|
void Results::hideResult(int index, const string& pathHidden)
|
||||||
void Results::show() const
|
|
||||||
{
|
{
|
||||||
auto temp = ConfigLocale();
|
auto filename = files.at(index).getFilename();
|
||||||
cout << Colors::GREEN() << "Results found: " << files.size() << endl;
|
rename((path + "/" + filename).c_str(), (pathHidden + "/" + filename).c_str());
|
||||||
cout << "-------------------" << endl;
|
files.erase(files.begin() + index);
|
||||||
if (complete) {
|
|
||||||
cout << Colors::MAGENTA() << "Only listing complete results" << endl;
|
|
||||||
}
|
}
|
||||||
if (partial) {
|
void Results::deleteResult(int index)
|
||||||
cout << Colors::MAGENTA() << "Only listing partial results" << endl;
|
|
||||||
}
|
|
||||||
auto i = 0;
|
|
||||||
cout << Colors::GREEN() << " # Date Model Score Name Score C/P Duration Title" << endl;
|
|
||||||
cout << "=== ========== ============ =========== =========== === ========= =============================================================" << endl;
|
|
||||||
bool odd = true;
|
|
||||||
for (const auto& result : files) {
|
|
||||||
auto color = odd ? Colors::BLUE() : Colors::CYAN();
|
|
||||||
cout << color << setw(3) << fixed << right << i++ << " ";
|
|
||||||
cout << result.to_string() << endl;
|
|
||||||
if (i == max && max != 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
odd = !odd;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
int Results::getIndex(const string& intent) const
|
|
||||||
{
|
{
|
||||||
string color;
|
auto filename = files.at(index).getFilename();
|
||||||
if (intent == "delete") {
|
|
||||||
color = Colors::RED();
|
|
||||||
} else {
|
|
||||||
color = Colors::YELLOW();
|
|
||||||
}
|
|
||||||
cout << color << "Choose result to " << intent << " (cancel=-1): ";
|
|
||||||
string line;
|
|
||||||
getline(cin, line);
|
|
||||||
int index = stoi(line);
|
|
||||||
if (index >= -1 && index < static_cast<int>(files.size())) {
|
|
||||||
return index;
|
|
||||||
}
|
|
||||||
cout << "Invalid index" << endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
void Results::report(const int index, const bool excelReport)
|
|
||||||
{
|
|
||||||
cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl;
|
|
||||||
auto data = files.at(index).load();
|
|
||||||
if (excelReport) {
|
|
||||||
ReportExcel reporter(data, compare, workbook);
|
|
||||||
reporter.show();
|
|
||||||
openExcel = true;
|
|
||||||
workbook = reporter.getWorkbook();
|
|
||||||
} else {
|
|
||||||
ReportConsole reporter(data, compare);
|
|
||||||
reporter.show();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
void Results::showIndex(const int index, const int idx) const
|
|
||||||
{
|
|
||||||
auto data = files.at(index).load();
|
|
||||||
if (idx < 0 or idx >= static_cast<int>(data["results"].size())) {
|
|
||||||
cout << "Invalid index" << endl;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
cout << Colors::YELLOW() << "Showing " << files.at(index).getFilename() << endl;
|
|
||||||
ReportConsole reporter(data, compare, idx);
|
|
||||||
reporter.show();
|
|
||||||
}
|
|
||||||
void Results::menu()
|
|
||||||
{
|
|
||||||
char option;
|
|
||||||
int index;
|
|
||||||
bool finished = false;
|
|
||||||
string color, context;
|
|
||||||
string filename, line, options = "qldhsre";
|
|
||||||
while (!finished) {
|
|
||||||
if (indexList) {
|
|
||||||
color = Colors::GREEN();
|
|
||||||
context = " (quit='q', list='l', delete='d', hide='h', sort='s', report='r', excel='e'): ";
|
|
||||||
options = "qldhsre";
|
|
||||||
} else {
|
|
||||||
color = Colors::MAGENTA();
|
|
||||||
context = " (quit='q', list='l'): ";
|
|
||||||
options = "ql";
|
|
||||||
}
|
|
||||||
cout << Colors::RESET() << color;
|
|
||||||
|
|
||||||
cout << "Choose option " << context;
|
|
||||||
getline(cin, line);
|
|
||||||
if (line.size() == 0)
|
|
||||||
continue;
|
|
||||||
if (options.find(line[0]) != string::npos) {
|
|
||||||
if (line.size() > 1) {
|
|
||||||
cout << "Invalid option" << endl;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
option = line[0];
|
|
||||||
} else {
|
|
||||||
if (all_of(line.begin(), line.end(), ::isdigit)) {
|
|
||||||
int idx = stoi(line);
|
|
||||||
if (indexList) {
|
|
||||||
// The value is about the files list
|
|
||||||
index = idx;
|
|
||||||
if (index >= 0 && index < max) {
|
|
||||||
report(index, false);
|
|
||||||
indexList = false;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// The value is about the result showed on screen
|
|
||||||
showIndex(index, idx);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cout << "Invalid option" << endl;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
switch (option) {
|
|
||||||
case 'q':
|
|
||||||
finished = true;
|
|
||||||
break;
|
|
||||||
case 'l':
|
|
||||||
show();
|
|
||||||
indexList = true;
|
|
||||||
break;
|
|
||||||
case 'd':
|
|
||||||
index = getIndex("delete");
|
|
||||||
if (index == -1)
|
|
||||||
break;
|
|
||||||
filename = files[index].getFilename();
|
|
||||||
cout << "Deleting " << filename << endl;
|
|
||||||
remove((path + "/" + filename).c_str());
|
remove((path + "/" + filename).c_str());
|
||||||
files.erase(files.begin() + index);
|
files.erase(files.begin() + index);
|
||||||
cout << "File: " + filename + " deleted!" << endl;
|
|
||||||
show();
|
|
||||||
indexList = true;
|
|
||||||
break;
|
|
||||||
case 'h':
|
|
||||||
index = getIndex("hide");
|
|
||||||
if (index == -1)
|
|
||||||
break;
|
|
||||||
filename = files[index].getFilename();
|
|
||||||
cout << "Hiding " << filename << endl;
|
|
||||||
rename((path + "/" + filename).c_str(), (path + "/." + filename).c_str());
|
|
||||||
files.erase(files.begin() + index);
|
|
||||||
show();
|
|
||||||
menu();
|
|
||||||
indexList = true;
|
|
||||||
break;
|
|
||||||
case 's':
|
|
||||||
sortList();
|
|
||||||
indexList = true;
|
|
||||||
show();
|
|
||||||
break;
|
|
||||||
case 'r':
|
|
||||||
index = getIndex("report");
|
|
||||||
if (index == -1)
|
|
||||||
break;
|
|
||||||
indexList = false;
|
|
||||||
report(index, false);
|
|
||||||
break;
|
|
||||||
case 'e':
|
|
||||||
index = getIndex("excel");
|
|
||||||
if (index == -1)
|
|
||||||
break;
|
|
||||||
indexList = true;
|
|
||||||
report(index, true);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
cout << "Invalid option" << endl;
|
|
||||||
}
|
}
|
||||||
}
|
int Results::size() const
|
||||||
}
|
|
||||||
void Results::sortList()
|
|
||||||
{
|
{
|
||||||
cout << Colors::YELLOW() << "Choose sorting field (date='d', score='s', duration='u', model='m'): ";
|
return files.size();
|
||||||
string line;
|
|
||||||
char option;
|
|
||||||
getline(cin, line);
|
|
||||||
if (line.size() == 0)
|
|
||||||
return;
|
|
||||||
if (line.size() > 1) {
|
|
||||||
cout << "Invalid option" << endl;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
option = line[0];
|
|
||||||
switch (option) {
|
|
||||||
case 'd':
|
|
||||||
sortDate();
|
|
||||||
break;
|
|
||||||
case 's':
|
|
||||||
sortScore();
|
|
||||||
break;
|
|
||||||
case 'u':
|
|
||||||
sortDuration();
|
|
||||||
break;
|
|
||||||
case 'm':
|
|
||||||
sortModel();
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
cout << "Invalid option" << endl;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
void Results::sortDate()
|
void Results::sortDate()
|
||||||
{
|
{
|
||||||
@@ -251,19 +67,8 @@ namespace platform {
|
|||||||
return a.getScore() > b.getScore();
|
return a.getScore() > b.getScore();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
void Results::manage()
|
bool Results::empty() const
|
||||||
{
|
{
|
||||||
if (files.size() == 0) {
|
return files.empty();
|
||||||
cout << "No results found!" << endl;
|
|
||||||
exit(0);
|
|
||||||
}
|
}
|
||||||
sortDate();
|
|
||||||
show();
|
|
||||||
menu();
|
|
||||||
if (openExcel) {
|
|
||||||
workbook_close(workbook);
|
|
||||||
}
|
|
||||||
cout << Colors::RESET() << "Done!" << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
@@ -1,6 +1,5 @@
|
|||||||
#ifndef RESULTS_H
|
#ifndef RESULTS_H
|
||||||
#define RESULTS_H
|
#define RESULTS_H
|
||||||
#include "xlsxwriter.h"
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
@@ -12,35 +11,28 @@ namespace platform {
|
|||||||
|
|
||||||
class Results {
|
class Results {
|
||||||
public:
|
public:
|
||||||
Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial, bool compare) :
|
Results(const string& path, const string& model, const string& score, bool complete, bool partial);
|
||||||
path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial), compare(compare)
|
|
||||||
{
|
|
||||||
load();
|
|
||||||
};
|
|
||||||
void manage();
|
|
||||||
private:
|
|
||||||
string path;
|
|
||||||
int max;
|
|
||||||
string model;
|
|
||||||
string scoreName;
|
|
||||||
bool complete;
|
|
||||||
bool partial;
|
|
||||||
bool indexList = true;
|
|
||||||
bool openExcel = false;
|
|
||||||
bool compare;
|
|
||||||
lxw_workbook* workbook = NULL;
|
|
||||||
vector<Result> files;
|
|
||||||
void load(); // Loads the list of results
|
|
||||||
void show() const;
|
|
||||||
void report(const int index, const bool excelReport);
|
|
||||||
void showIndex(const int index, const int idx) const;
|
|
||||||
int getIndex(const string& intent) const;
|
|
||||||
void menu();
|
|
||||||
void sortList();
|
|
||||||
void sortDate();
|
void sortDate();
|
||||||
void sortScore();
|
void sortScore();
|
||||||
void sortModel();
|
void sortModel();
|
||||||
void sortDuration();
|
void sortDuration();
|
||||||
|
int maxModelSize() const { return maxModel; };
|
||||||
|
void hideResult(int index, const string& pathHidden);
|
||||||
|
void deleteResult(int index);
|
||||||
|
int size() const;
|
||||||
|
bool empty() const;
|
||||||
|
vector<Result>::iterator begin() { return files.begin(); };
|
||||||
|
vector<Result>::iterator end() { return files.end(); };
|
||||||
|
Result& at(int index) { return files.at(index); };
|
||||||
|
private:
|
||||||
|
string path;
|
||||||
|
string model;
|
||||||
|
string scoreName;
|
||||||
|
bool complete;
|
||||||
|
bool partial;
|
||||||
|
int maxModel;
|
||||||
|
vector<Result> files;
|
||||||
|
void load(); // Loads the list of results
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -15,5 +15,16 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
static std::string trim(const std::string& str)
|
||||||
|
{
|
||||||
|
std::string result = str;
|
||||||
|
result.erase(result.begin(), std::find_if(result.begin(), result.end(), [](int ch) {
|
||||||
|
return !std::isspace(ch);
|
||||||
|
}));
|
||||||
|
result.erase(std::find_if(result.rbegin(), result.rend(), [](int ch) {
|
||||||
|
return !std::isspace(ch);
|
||||||
|
}).base(), result.end());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@@ -29,15 +29,24 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
|||||||
catch (...) {
|
catch (...) {
|
||||||
throw runtime_error("Number of folds must be an decimal number");
|
throw runtime_error("Number of folds must be an decimal number");
|
||||||
}});
|
}});
|
||||||
|
return program;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char** argv)
|
||||||
|
{
|
||||||
|
auto program = manageArguments(argc, argv);
|
||||||
|
string model, score;
|
||||||
|
bool build, report, friedman, excel;
|
||||||
|
double level;
|
||||||
try {
|
try {
|
||||||
program.parse_args(argc, argv);
|
program.parse_args(argc, argv);
|
||||||
auto model = program.get<string>("model");
|
model = program.get<string>("model");
|
||||||
auto score = program.get<string>("score");
|
score = program.get<string>("score");
|
||||||
auto build = program.get<bool>("build");
|
build = program.get<bool>("build");
|
||||||
auto report = program.get<bool>("report");
|
report = program.get<bool>("report");
|
||||||
auto friedman = program.get<bool>("friedman");
|
friedman = program.get<bool>("friedman");
|
||||||
auto excel = program.get<bool>("excel");
|
excel = program.get<bool>("excel");
|
||||||
auto level = program.get<double>("level");
|
level = program.get<double>("level");
|
||||||
if (model == "" || score == "") {
|
if (model == "" || score == "") {
|
||||||
throw runtime_error("Model and score name must be supplied");
|
throw runtime_error("Model and score name must be supplied");
|
||||||
}
|
}
|
||||||
@@ -46,11 +55,6 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
|||||||
cerr << program;
|
cerr << program;
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
if (excel && model != "any") {
|
|
||||||
cerr << "Excel ourput can only be used with all models" << endl;
|
|
||||||
cerr << program;
|
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
if (!report && !build) {
|
if (!report && !build) {
|
||||||
cerr << "Either build, report or both, have to be selected to do anything!" << endl;
|
cerr << "Either build, report or both, have to be selected to do anything!" << endl;
|
||||||
cerr << program;
|
cerr << program;
|
||||||
@@ -62,19 +66,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
|||||||
cerr << program;
|
cerr << program;
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
return program;
|
// Generate report
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char** argv)
|
|
||||||
{
|
|
||||||
auto program = manageArguments(argc, argv);
|
|
||||||
auto model = program.get<string>("model");
|
|
||||||
auto score = program.get<string>("score");
|
|
||||||
auto build = program.get<bool>("build");
|
|
||||||
auto report = program.get<bool>("report");
|
|
||||||
auto friedman = program.get<bool>("friedman");
|
|
||||||
auto excel = program.get<bool>("excel");
|
|
||||||
auto level = program.get<double>("level");
|
|
||||||
auto results = platform::BestResults(platform::Paths::results(), score, model, friedman, level);
|
auto results = platform::BestResults(platform::Paths::results(), score, model, friedman, level);
|
||||||
if (build) {
|
if (build) {
|
||||||
if (model == "any") {
|
if (model == "any") {
|
||||||
@@ -88,7 +80,7 @@ int main(int argc, char** argv)
|
|||||||
if (model == "any") {
|
if (model == "any") {
|
||||||
results.reportAll(excel);
|
results.reportAll(excel);
|
||||||
} else {
|
} else {
|
||||||
results.reportSingle();
|
results.reportSingle(excel);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
@@ -30,6 +30,7 @@ argparse::ArgumentParser manageArguments()
|
|||||||
);
|
);
|
||||||
program.add_argument("--title").default_value("").help("Experiment title");
|
program.add_argument("--title").default_value("").help("Experiment title");
|
||||||
program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||||
|
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
|
program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||||
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const string& value) {
|
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const string& value) {
|
||||||
@@ -55,7 +56,7 @@ int main(int argc, char** argv)
|
|||||||
{
|
{
|
||||||
string file_name, model_name, title;
|
string file_name, model_name, title;
|
||||||
json hyperparameters_json;
|
json hyperparameters_json;
|
||||||
bool discretize_dataset, stratified, saveResults;
|
bool discretize_dataset, stratified, saveResults, quiet;
|
||||||
vector<int> seeds;
|
vector<int> seeds;
|
||||||
vector<string> filesToTest;
|
vector<string> filesToTest;
|
||||||
int n_folds;
|
int n_folds;
|
||||||
@@ -66,6 +67,7 @@ int main(int argc, char** argv)
|
|||||||
model_name = program.get<string>("model");
|
model_name = program.get<string>("model");
|
||||||
discretize_dataset = program.get<bool>("discretize");
|
discretize_dataset = program.get<bool>("discretize");
|
||||||
stratified = program.get<bool>("stratified");
|
stratified = program.get<bool>("stratified");
|
||||||
|
quiet = program.get<bool>("quiet");
|
||||||
n_folds = program.get<int>("folds");
|
n_folds = program.get<int>("folds");
|
||||||
seeds = program.get<vector<int>>("seeds");
|
seeds = program.get<vector<int>>("seeds");
|
||||||
auto hyperparameters = program.get<string>("hyperparameters");
|
auto hyperparameters = program.get<string>("hyperparameters");
|
||||||
@@ -109,11 +111,12 @@ int main(int argc, char** argv)
|
|||||||
}
|
}
|
||||||
platform::Timer timer;
|
platform::Timer timer;
|
||||||
timer.start();
|
timer.start();
|
||||||
experiment.go(filesToTest);
|
experiment.go(filesToTest, quiet);
|
||||||
experiment.setDuration(timer.getDuration());
|
experiment.setDuration(timer.getDuration());
|
||||||
if (saveResults) {
|
if (saveResults) {
|
||||||
experiment.save(platform::Paths::results());
|
experiment.save(platform::Paths::results());
|
||||||
}
|
}
|
||||||
|
if (!quiet)
|
||||||
experiment.report();
|
experiment.report();
|
||||||
cout << "Done!" << endl;
|
cout << "Done!" << endl;
|
||||||
return 0;
|
return 0;
|
@@ -1,7 +1,6 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <argparse/argparse.hpp>
|
#include <argparse/argparse.hpp>
|
||||||
#include "Paths.h"
|
#include "ManageResults.h"
|
||||||
#include "Results.h"
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
@@ -37,15 +36,15 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
|||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
auto program = manageArguments(argc, argv);
|
auto program = manageArguments(argc, argv);
|
||||||
auto number = program.get<int>("number");
|
int number = program.get<int>("number");
|
||||||
auto model = program.get<string>("model");
|
string model = program.get<string>("model");
|
||||||
auto score = program.get<string>("score");
|
string score = program.get<string>("score");
|
||||||
auto complete = program.get<bool>("complete");
|
auto complete = program.get<bool>("complete");
|
||||||
auto partial = program.get<bool>("partial");
|
auto partial = program.get<bool>("partial");
|
||||||
auto compare = program.get<bool>("compare");
|
auto compare = program.get<bool>("compare");
|
||||||
if (complete)
|
if (complete)
|
||||||
partial = false;
|
partial = false;
|
||||||
auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial, compare);
|
auto manager = platform::ManageResults(number, model, score, complete, partial, compare);
|
||||||
results.manage();
|
manager.doMenu();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
@@ -1,248 +0,0 @@
|
|||||||
#include "Folding.h"
|
|
||||||
#include <torch/torch.h>
|
|
||||||
#include "nlohmann/json.hpp"
|
|
||||||
#include "map"
|
|
||||||
#include <iostream>
|
|
||||||
#include <sstream>
|
|
||||||
#include "Datasets.h"
|
|
||||||
#include "Network.h"
|
|
||||||
#include "ArffFiles.h"
|
|
||||||
#include "CPPFImdlp.h"
|
|
||||||
#include "CFS.h"
|
|
||||||
#include "IWSS.h"
|
|
||||||
#include "FCBF.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
using namespace platform;
|
|
||||||
using namespace torch;
|
|
||||||
|
|
||||||
string counts(vector<int> y, vector<int> indices)
|
|
||||||
{
|
|
||||||
auto result = map<int, int>();
|
|
||||||
stringstream oss;
|
|
||||||
for (auto i = 0; i < indices.size(); ++i) {
|
|
||||||
result[y[indices[i]]]++;
|
|
||||||
}
|
|
||||||
string final_result = "";
|
|
||||||
for (auto i = 0; i < result.size(); ++i)
|
|
||||||
oss << i << " -> " << setprecision(2) << fixed
|
|
||||||
<< (double)result[i] * 100 / indices.size() << "% (" << result[i] << ") //";
|
|
||||||
oss << endl;
|
|
||||||
return oss.str();
|
|
||||||
}
|
|
||||||
class Paths {
|
|
||||||
public:
|
|
||||||
static string datasets()
|
|
||||||
{
|
|
||||||
return "datasets/";
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
pair<vector<mdlp::labels_t>, map<string, int>> discretize(vector<mdlp::samples_t>& X, mdlp::labels_t& y, vector<string> features)
|
|
||||||
{
|
|
||||||
vector<mdlp::labels_t> Xd;
|
|
||||||
map<string, int> maxes;
|
|
||||||
auto fimdlp = mdlp::CPPFImdlp();
|
|
||||||
for (int i = 0; i < X.size(); i++) {
|
|
||||||
fimdlp.fit(X[i], y);
|
|
||||||
mdlp::labels_t& xd = fimdlp.transform(X[i]);
|
|
||||||
maxes[features[i]] = *max_element(xd.begin(), xd.end()) + 1;
|
|
||||||
Xd.push_back(xd);
|
|
||||||
}
|
|
||||||
return { Xd, maxes };
|
|
||||||
}
|
|
||||||
|
|
||||||
vector<mdlp::labels_t> discretizeDataset(vector<mdlp::samples_t>& X, mdlp::labels_t& y)
|
|
||||||
{
|
|
||||||
vector<mdlp::labels_t> Xd;
|
|
||||||
auto fimdlp = mdlp::CPPFImdlp();
|
|
||||||
for (int i = 0; i < X.size(); i++) {
|
|
||||||
fimdlp.fit(X[i], y);
|
|
||||||
mdlp::labels_t& xd = fimdlp.transform(X[i]);
|
|
||||||
Xd.push_back(xd);
|
|
||||||
}
|
|
||||||
return Xd;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool file_exists(const string& name)
|
|
||||||
{
|
|
||||||
if (FILE* file = fopen(name.c_str(), "r")) {
|
|
||||||
fclose(file);
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tuple<Tensor, Tensor, vector<string>, string, map<string, vector<int>>> loadDataset(const string& name, bool class_last, bool discretize_dataset)
|
|
||||||
{
|
|
||||||
auto handler = ArffFiles();
|
|
||||||
handler.load(Paths::datasets() + static_cast<string>(name) + ".arff", class_last);
|
|
||||||
// Get Dataset X, y
|
|
||||||
vector<mdlp::samples_t>& X = handler.getX();
|
|
||||||
mdlp::labels_t& y = handler.getY();
|
|
||||||
// Get className & Features
|
|
||||||
auto className = handler.getClassName();
|
|
||||||
vector<string> features;
|
|
||||||
auto attributes = handler.getAttributes();
|
|
||||||
transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& pair) { return pair.first; });
|
|
||||||
Tensor Xd;
|
|
||||||
auto states = map<string, vector<int>>();
|
|
||||||
if (discretize_dataset) {
|
|
||||||
auto Xr = discretizeDataset(X, y);
|
|
||||||
Xd = torch::zeros({ static_cast<int>(Xr.size()), static_cast<int>(Xr[0].size()) }, torch::kInt32);
|
|
||||||
for (int i = 0; i < features.size(); ++i) {
|
|
||||||
states[features[i]] = vector<int>(*max_element(Xr[i].begin(), Xr[i].end()) + 1);
|
|
||||||
auto item = states.at(features[i]);
|
|
||||||
iota(begin(item), end(item), 0);
|
|
||||||
Xd.index_put_({ i, "..." }, torch::tensor(Xr[i], torch::kInt32));
|
|
||||||
}
|
|
||||||
states[className] = vector<int>(*max_element(y.begin(), y.end()) + 1);
|
|
||||||
iota(begin(states.at(className)), end(states.at(className)), 0);
|
|
||||||
} else {
|
|
||||||
Xd = torch::zeros({ static_cast<int>(X.size()), static_cast<int>(X[0].size()) }, torch::kFloat32);
|
|
||||||
for (int i = 0; i < features.size(); ++i) {
|
|
||||||
Xd.index_put_({ i, "..." }, torch::tensor(X[i]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return { Xd, torch::tensor(y, torch::kInt32), features, className, states };
|
|
||||||
}
|
|
||||||
|
|
||||||
tuple<vector<vector<int>>, vector<int>, vector<string>, string, map<string, vector<int>>> loadFile(const string& name)
|
|
||||||
{
|
|
||||||
auto handler = ArffFiles();
|
|
||||||
handler.load(Paths::datasets() + static_cast<string>(name) + ".arff");
|
|
||||||
// Get Dataset X, y
|
|
||||||
vector<mdlp::samples_t>& X = handler.getX();
|
|
||||||
mdlp::labels_t& y = handler.getY();
|
|
||||||
// Get className & Features
|
|
||||||
auto className = handler.getClassName();
|
|
||||||
vector<string> features;
|
|
||||||
auto attributes = handler.getAttributes();
|
|
||||||
transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& pair) { return pair.first; });
|
|
||||||
// Discretize Dataset
|
|
||||||
vector<mdlp::labels_t> Xd;
|
|
||||||
map<string, int> maxes;
|
|
||||||
tie(Xd, maxes) = discretize(X, y, features);
|
|
||||||
maxes[className] = *max_element(y.begin(), y.end()) + 1;
|
|
||||||
map<string, vector<int>> states;
|
|
||||||
for (auto feature : features) {
|
|
||||||
states[feature] = vector<int>(maxes[feature]);
|
|
||||||
}
|
|
||||||
states[className] = vector<int>(maxes[className]);
|
|
||||||
return { Xd, y, features, className, states };
|
|
||||||
}
|
|
||||||
class RawDatasets {
|
|
||||||
public:
|
|
||||||
RawDatasets(const string& file_name, bool discretize)
|
|
||||||
{
|
|
||||||
// Xt can be either discretized or not
|
|
||||||
tie(Xt, yt, featurest, classNamet, statest) = loadDataset(file_name, true, discretize);
|
|
||||||
// Xv is always discretized
|
|
||||||
tie(Xv, yv, featuresv, classNamev, statesv) = loadFile(file_name);
|
|
||||||
auto yresized = torch::transpose(yt.view({ yt.size(0), 1 }), 0, 1);
|
|
||||||
dataset = torch::cat({ Xt, yresized }, 0);
|
|
||||||
nSamples = dataset.size(1);
|
|
||||||
weights = torch::full({ nSamples }, 1.0 / nSamples, torch::kDouble);
|
|
||||||
weightsv = vector<double>(nSamples, 1.0 / nSamples);
|
|
||||||
classNumStates = discretize ? statest.at(classNamet).size() : 0;
|
|
||||||
}
|
|
||||||
torch::Tensor Xt, yt, dataset, weights;
|
|
||||||
vector<vector<int>> Xv;
|
|
||||||
vector<double> weightsv;
|
|
||||||
vector<int> yv;
|
|
||||||
vector<string> featurest, featuresv;
|
|
||||||
map<string, vector<int>> statest, statesv;
|
|
||||||
string classNamet, classNamev;
|
|
||||||
int nSamples, classNumStates;
|
|
||||||
double epsilon = 1e-5;
|
|
||||||
};
|
|
||||||
int main()
|
|
||||||
{
|
|
||||||
// map<string, string> balance = {
|
|
||||||
// {"iris", "33,33% (50) / 33,33% (50) / 33,33% (50)"},
|
|
||||||
// {"diabetes", "34,90% (268) / 65,10% (500)"},
|
|
||||||
// {"ecoli", "42,56% (143) / 22,92% (77) / 0,60% (2) / 0,60% (2) / 10,42% (35) / 5,95% (20) / 1,49% (5) / 15,48% (52)"},
|
|
||||||
// {"glass", "32,71% (70) / 7,94% (17) / 4,21% (9) / 35,51% (76) / 13,55% (29) / 6,07% (13)"}
|
|
||||||
// };
|
|
||||||
// for (const auto& file_name : { "iris", "glass", "ecoli", "diabetes" }) {
|
|
||||||
// auto dt = Datasets(true, "Arff");
|
|
||||||
// auto [X, y] = dt.getVectors(file_name);
|
|
||||||
// //auto fold = KFold(5, 150);
|
|
||||||
// auto fold = StratifiedKFold(5, y, -1);
|
|
||||||
// cout << "***********************************************************************************************" << endl;
|
|
||||||
// cout << "Dataset: " << file_name << endl;
|
|
||||||
// cout << "Nº Samples: " << dt.getNSamples(file_name) << endl;
|
|
||||||
// cout << "Class states: " << dt.getNClasses(file_name) << endl;
|
|
||||||
// cout << "Balance: " << balance.at(file_name) << endl;
|
|
||||||
// for (int i = 0; i < 5; ++i) {
|
|
||||||
// cout << "Fold: " << i << endl;
|
|
||||||
// auto [train, test] = fold.getFold(i);
|
|
||||||
// cout << "Train: ";
|
|
||||||
// cout << "(" << train.size() << "): ";
|
|
||||||
// // for (auto j = 0; j < static_cast<int>(train.size()); j++)
|
|
||||||
// // cout << train[j] << ", ";
|
|
||||||
// cout << endl;
|
|
||||||
// cout << "Train Statistics : " << counts(y, train);
|
|
||||||
// cout << "-------------------------------------------------------------------------------" << endl;
|
|
||||||
// cout << "Test: ";
|
|
||||||
// cout << "(" << test.size() << "): ";
|
|
||||||
// // for (auto j = 0; j < static_cast<int>(test.size()); j++)
|
|
||||||
// // cout << test[j] << ", ";
|
|
||||||
// cout << endl;
|
|
||||||
// cout << "Test Statistics: " << counts(y, test);
|
|
||||||
// cout << "==============================================================================" << endl;
|
|
||||||
// }
|
|
||||||
// cout << "***********************************************************************************************" << endl;
|
|
||||||
// }
|
|
||||||
// const string file_name = "iris";
|
|
||||||
// auto net = bayesnet::Network();
|
|
||||||
// auto dt = Datasets(true, "Arff");
|
|
||||||
// auto raw = RawDatasets("iris", true);
|
|
||||||
// auto [X, y] = dt.getVectors(file_name);
|
|
||||||
// cout << "Dataset dims " << raw.dataset.sizes() << endl;
|
|
||||||
// cout << "weights dims " << raw.weights.sizes() << endl;
|
|
||||||
// cout << "States dims " << raw.statest.size() << endl;
|
|
||||||
// cout << "features: ";
|
|
||||||
// for (const auto& feature : raw.featurest) {
|
|
||||||
// cout << feature << ", ";
|
|
||||||
// net.addNode(feature);
|
|
||||||
// }
|
|
||||||
// net.addNode(raw.classNamet);
|
|
||||||
// cout << endl;
|
|
||||||
// net.fit(raw.dataset, raw.weights, raw.featurest, raw.classNamet, raw.statest);
|
|
||||||
auto dt = Datasets(true, "Arff");
|
|
||||||
nlohmann::json output;
|
|
||||||
for (const auto& name : dt.getNames()) {
|
|
||||||
// for (const auto& name : { "iris" }) {
|
|
||||||
auto [X, y] = dt.getTensors(name);
|
|
||||||
auto features = dt.getFeatures(name);
|
|
||||||
auto states = dt.getStates(name);
|
|
||||||
auto className = dt.getClassName(name);
|
|
||||||
int maxFeatures = 0;
|
|
||||||
auto classNumStates = states.at(className).size();
|
|
||||||
torch::Tensor weights = torch::full({ X.size(1) }, 1.0 / X.size(1), torch::kDouble);
|
|
||||||
auto dataset = X;
|
|
||||||
auto yresized = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
|
|
||||||
dataset = torch::cat({ dataset, yresized }, 0);
|
|
||||||
auto cfs = bayesnet::CFS(dataset, features, className, maxFeatures, classNumStates, weights);
|
|
||||||
auto fcbf = bayesnet::FCBF(dataset, features, className, maxFeatures, classNumStates, weights, 1e-7);
|
|
||||||
auto iwss = bayesnet::IWSS(dataset, features, className, maxFeatures, classNumStates, weights, 0.5);
|
|
||||||
cout << "Dataset: " << setw(20) << name << flush;
|
|
||||||
cfs.fit();
|
|
||||||
cout << " CFS: " << setw(4) << cfs.getFeatures().size() << flush;
|
|
||||||
fcbf.fit();
|
|
||||||
cout << " FCBF: " << setw(4) << fcbf.getFeatures().size() << flush;
|
|
||||||
iwss.fit();
|
|
||||||
cout << " IWSS: " << setw(4) << iwss.getFeatures().size() << flush;
|
|
||||||
cout << endl;
|
|
||||||
output[name]["CFS"] = cfs.getFeatures();
|
|
||||||
output[name]["FCBF"] = fcbf.getFeatures();
|
|
||||||
output[name]["IWSS"] = iwss.getFeatures();
|
|
||||||
}
|
|
||||||
ofstream file("features_cpp.json");
|
|
||||||
file << output;
|
|
||||||
file.close();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
Reference in New Issue
Block a user