Compare commits

16 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 97ca8ac084 |  |
|  | 1c1385b768 |  |
|  | 35432b6294 |  |
|  | c59dd30e53 |  |
|  | d2da0ddb88 |  |
|  | 8066701c3c |  |
|  | 0f66ac73d0 |  |
|  | 4370bf51d7 |  |
|  | 2b7353b9e0 |  |
|  | b686b3c9c3 |  |
|  | 2dd04a6c44 |  |
|  | 1da83662d0 |  |
|  | 3ac9593c65 |  |
|  | 6b317accf1 |  |
|  | 4964aab722 |  |
|  | 7a6ec73d63 |  |
4 changes: .vscode/launch.json (vendored)

@@ -31,7 +31,9 @@
                 "--discretize",
                 "--stratified",
                 "-d",
-                "iris"
+                "glass",
+                "--hyperparameters",
+                "{\"repeatSparent\": true, \"maxModels\": 12}"
             ],
             "cwd": "/Users/rmontanana/Code/discretizbench",
         },
@@ -55,6 +55,7 @@ endif (ENABLE_CLANG_TIDY)
 add_git_submodule("lib/mdlp")
 add_git_submodule("lib/argparse")
 add_git_submodule("lib/json")
+add_git_submodule("lib/openXLSX")
 
 # Subdirectories
 # --------------
10 changes: Makefile

@@ -11,6 +11,16 @@ setup: ## Install dependencies for tests and coverage
 	pip install gcovr; \
 	fi
 
+dest ?= ../discretizbench
+copy: ## Copy binary files to selected folder
+	@echo "Destination folder: $(dest)"
+	make build
+	@echo ">>> Copying files to $(dest)"
+	@cp build/src/Platform/main $(dest)
+	@cp build/src/Platform/list $(dest)
+	@cp build/src/Platform/manage $(dest)
+	@echo ">>> Done"
+
 dependency: ## Create a dependency graph diagram of the project (build/dependency.png)
 	cd build && cmake .. --graphviz=dependency.dot && dot -Tpng dependency.dot -o dependency.png
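With this target in place, the built binaries can be copied to a benchmark checkout in one step, e.g. `make copy dest=../discretizbench`; the default destination comes from the `dest ?= ../discretizbench` assignment above.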
1 change: lib/openXLSX (new submodule)

Submodule lib/openXLSX added at b80da42d14
@@ -3,5 +3,6 @@ include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet)
 include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
 include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
 include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
+include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
 add_executable(BayesNetSample sample.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc)
 target_link_libraries(BayesNetSample BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
200 changes: sample/sample.cc

@@ -3,6 +3,7 @@
 #include <string>
 #include <map>
 #include <argparse/argparse.hpp>
+#include <nlohmann/json.hpp>
 #include "ArffFiles.h"
 #include "BayesMetrics.h"
 #include "CPPFImdlp.h"
@@ -141,111 +142,96 @@ int main(int argc, char** argv)
     /*
     * Begin Processing
     */
-    auto ypred = torch::tensor({ 1,2,3,2,2,3,4,5,2,1 });
-    auto y = torch::tensor({ 0,0,0,0,2,3,4,0,0,0 });
-    auto weights = torch::ones({ 10 }, kDouble);
-    auto mask = ypred == y;
-    cout << "ypred:" << ypred << endl;
-    cout << "y:" << y << endl;
-    cout << "weights:" << weights << endl;
-    cout << "mask:" << mask << endl;
-    double value_to_add = 0.5;
-    weights += mask.to(torch::kDouble) * value_to_add;
-    cout << "New weights:" << weights << endl;
-    auto masked_weights = weights * mask.to(weights.dtype());
-    double sum_of_weights = masked_weights.sum().item<double>();
-    cout << "Sum of weights: " << sum_of_weights << endl;
-    //weights.index_put_({ mask }, weights + 10);
-    // auto handler = ArffFiles();
-    // handler.load(complete_file_name, class_last);
-    // // Get Dataset X, y
-    // vector<mdlp::samples_t>& X = handler.getX();
-    // mdlp::labels_t& y = handler.getY();
-    // // Get className & Features
-    // auto className = handler.getClassName();
-    // vector<string> features;
-    // auto attributes = handler.getAttributes();
-    // transform(attributes.begin(), attributes.end(), back_inserter(features),
-    //     [](const pair<string, string>& item) { return item.first; });
-    // // Discretize Dataset
-    // auto [Xd, maxes] = discretize(X, y, features);
-    // maxes[className] = *max_element(y.begin(), y.end()) + 1;
-    // map<string, vector<int>> states;
-    // for (auto feature : features) {
-    //     states[feature] = vector<int>(maxes[feature]);
-    // }
-    // states[className] = vector<int>(maxes[className]);
-    // auto clf = platform::Models::instance()->create(model_name);
-    // clf->fit(Xd, y, features, className, states);
-    // if (dump_cpt) {
-    //     cout << "--- CPT Tables ---" << endl;
-    //     clf->dump_cpt();
-    // }
-    // auto lines = clf->show();
-    // for (auto line : lines) {
-    //     cout << line << endl;
-    // }
-    // cout << "--- Topological Order ---" << endl;
-    // auto order = clf->topological_order();
-    // for (auto name : order) {
-    //     cout << name << ", ";
-    // }
-    // cout << "end." << endl;
-    // auto score = clf->score(Xd, y);
-    // cout << "Score: " << score << endl;
-    // auto graph = clf->graph();
-    // auto dot_file = model_name + "_" + file_name;
-    // ofstream file(dot_file + ".dot");
-    // file << graph;
-    // file.close();
-    // cout << "Graph saved in " << model_name << "_" << file_name << ".dot" << endl;
-    // cout << "dot -Tpng -o " + dot_file + ".png " + dot_file + ".dot " << endl;
-    // string stratified_string = stratified ? " Stratified" : "";
-    // cout << nFolds << " Folds" << stratified_string << " Cross validation" << endl;
-    // cout << "==========================================" << endl;
-    // torch::Tensor Xt = torch::zeros({ static_cast<int>(Xd.size()), static_cast<int>(Xd[0].size()) }, torch::kInt32);
-    // torch::Tensor yt = torch::tensor(y, torch::kInt32);
-    // for (int i = 0; i < features.size(); ++i) {
-    //     Xt.index_put_({ i, "..." }, torch::tensor(Xd[i], torch::kInt32));
-    // }
-    // float total_score = 0, total_score_train = 0, score_train, score_test;
-    // Fold* fold;
-    // if (stratified)
-    //     fold = new StratifiedKFold(nFolds, y, seed);
-    // else
-    //     fold = new KFold(nFolds, y.size(), seed);
-    // for (auto i = 0; i < nFolds; ++i) {
-    //     auto [train, test] = fold->getFold(i);
-    //     cout << "Fold: " << i + 1 << endl;
-    //     if (tensors) {
-    //         auto ttrain = torch::tensor(train, torch::kInt64);
-    //         auto ttest = torch::tensor(test, torch::kInt64);
-    //         torch::Tensor Xtraint = torch::index_select(Xt, 1, ttrain);
-    //         torch::Tensor ytraint = yt.index({ ttrain });
-    //         torch::Tensor Xtestt = torch::index_select(Xt, 1, ttest);
-    //         torch::Tensor ytestt = yt.index({ ttest });
-    //         clf->fit(Xtraint, ytraint, features, className, states);
-    //         auto temp = clf->predict(Xtraint);
-    //         score_train = clf->score(Xtraint, ytraint);
-    //         score_test = clf->score(Xtestt, ytestt);
-    //     } else {
-    //         auto [Xtrain, ytrain] = extract_indices(train, Xd, y);
-    //         auto [Xtest, ytest] = extract_indices(test, Xd, y);
-    //         clf->fit(Xtrain, ytrain, features, className, states);
-    //         score_train = clf->score(Xtrain, ytrain);
-    //         score_test = clf->score(Xtest, ytest);
-    //     }
-    //     if (dump_cpt) {
-    //         cout << "--- CPT Tables ---" << endl;
-    //         clf->dump_cpt();
-    //     }
-    //     total_score_train += score_train;
-    //     total_score += score_test;
-    //     cout << "Score Train: " << score_train << endl;
-    //     cout << "Score Test : " << score_test << endl;
-    //     cout << "-------------------------------------------------------------------------------" << endl;
-    // }
-    // cout << "**********************************************************************************" << endl;
-    // cout << "Average Score Train: " << total_score_train / nFolds << endl;
-    // cout << "Average Score Test : " << total_score / nFolds << endl;return 0;
+    auto handler = ArffFiles();
+    handler.load(complete_file_name, class_last);
+    // Get Dataset X, y
+    vector<mdlp::samples_t>& X = handler.getX();
+    mdlp::labels_t& y = handler.getY();
+    // Get className & Features
+    auto className = handler.getClassName();
+    vector<string> features;
+    auto attributes = handler.getAttributes();
+    transform(attributes.begin(), attributes.end(), back_inserter(features),
+        [](const pair<string, string>& item) { return item.first; });
+    // Discretize Dataset
+    auto [Xd, maxes] = discretize(X, y, features);
+    maxes[className] = *max_element(y.begin(), y.end()) + 1;
+    map<string, vector<int>> states;
+    for (auto feature : features) {
+        states[feature] = vector<int>(maxes[feature]);
+    }
+    states[className] = vector<int>(maxes[className]);
+    auto clf = platform::Models::instance()->create(model_name);
+    clf->fit(Xd, y, features, className, states);
+    if (dump_cpt) {
+        cout << "--- CPT Tables ---" << endl;
+        clf->dump_cpt();
+    }
+    auto lines = clf->show();
+    for (auto line : lines) {
+        cout << line << endl;
+    }
+    cout << "--- Topological Order ---" << endl;
+    auto order = clf->topological_order();
+    for (auto name : order) {
+        cout << name << ", ";
+    }
+    cout << "end." << endl;
+    auto score = clf->score(Xd, y);
+    cout << "Score: " << score << endl;
+    auto graph = clf->graph();
+    auto dot_file = model_name + "_" + file_name;
+    ofstream file(dot_file + ".dot");
+    file << graph;
+    file.close();
+    cout << "Graph saved in " << model_name << "_" << file_name << ".dot" << endl;
+    cout << "dot -Tpng -o " + dot_file + ".png " + dot_file + ".dot " << endl;
+    string stratified_string = stratified ? " Stratified" : "";
+    cout << nFolds << " Folds" << stratified_string << " Cross validation" << endl;
+    cout << "==========================================" << endl;
+    torch::Tensor Xt = torch::zeros({ static_cast<int>(Xd.size()), static_cast<int>(Xd[0].size()) }, torch::kInt32);
+    torch::Tensor yt = torch::tensor(y, torch::kInt32);
+    for (int i = 0; i < features.size(); ++i) {
+        Xt.index_put_({ i, "..." }, torch::tensor(Xd[i], torch::kInt32));
+    }
+    float total_score = 0, total_score_train = 0, score_train, score_test;
+    Fold* fold;
+    if (stratified)
+        fold = new StratifiedKFold(nFolds, y, seed);
+    else
+        fold = new KFold(nFolds, y.size(), seed);
+    for (auto i = 0; i < nFolds; ++i) {
+        auto [train, test] = fold->getFold(i);
+        cout << "Fold: " << i + 1 << endl;
+        if (tensors) {
+            auto ttrain = torch::tensor(train, torch::kInt64);
+            auto ttest = torch::tensor(test, torch::kInt64);
+            torch::Tensor Xtraint = torch::index_select(Xt, 1, ttrain);
+            torch::Tensor ytraint = yt.index({ ttrain });
+            torch::Tensor Xtestt = torch::index_select(Xt, 1, ttest);
+            torch::Tensor ytestt = yt.index({ ttest });
+            clf->fit(Xtraint, ytraint, features, className, states);
+            auto temp = clf->predict(Xtraint);
+            score_train = clf->score(Xtraint, ytraint);
+            score_test = clf->score(Xtestt, ytestt);
+        } else {
+            auto [Xtrain, ytrain] = extract_indices(train, Xd, y);
+            auto [Xtest, ytest] = extract_indices(test, Xd, y);
+            clf->fit(Xtrain, ytrain, features, className, states);
+            score_train = clf->score(Xtrain, ytrain);
+            score_test = clf->score(Xtest, ytest);
+        }
+        if (dump_cpt) {
+            cout << "--- CPT Tables ---" << endl;
+            clf->dump_cpt();
+        }
+        total_score_train += score_train;
+        total_score += score_test;
+        cout << "Score Train: " << score_train << endl;
+        cout << "Score Test : " << score_test << endl;
+        cout << "-------------------------------------------------------------------------------" << endl;
+    }
+    cout << "**********************************************************************************" << endl;
+    cout << "Average Score Train: " << total_score_train / nFolds << endl;
+    cout << "Average Score Test : " << total_score / nFolds << endl;return 0;
 }
@@ -10,6 +10,7 @@ namespace bayesnet {
         AODE();
         virtual ~AODE() {};
         vector<string> graph(const string& title = "AODE") const override;
+        void setHyperparameters(nlohmann::json& hyperparameters) override {};
     };
 }
 #endif
@@ -16,6 +16,7 @@ namespace bayesnet {
         virtual ~AODELd() = default;
         vector<string> graph(const string& name = "AODE") const override;
         static inline string version() { return "0.0.1"; };
+        void setHyperparameters(nlohmann::json& hyperparameters) override {};
     };
 }
 #endif // !AODELD_H
@@ -1,6 +1,7 @@
 #ifndef BASE_H
 #define BASE_H
 #include <torch/torch.h>
+#include <nlohmann/json.hpp>
 #include <vector>
 namespace bayesnet {
     using namespace std;
@@ -27,6 +28,7 @@ namespace bayesnet {
         const string inline getVersion() const { return "0.1.0"; };
         vector<string> virtual topological_order() = 0;
         void virtual dump_cpt()const = 0;
+        virtual void setHyperparameters(nlohmann::json& hyperparameters) = 0;
     };
 }
 #endif
@@ -21,25 +21,39 @@ namespace bayesnet {
         }
         samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
     }
-    vector<int> Metrics::SelectKBestWeighted(const torch::Tensor& weights, unsigned k)
+    vector<int> Metrics::SelectKBestWeighted(const torch::Tensor& weights, bool ascending, unsigned k)
     {
+        // Return the K Best features
         auto n = samples.size(0) - 1;
         if (k == 0) {
             k = n;
         }
         // compute scores
-        scoresKBest.reserve(n);
+        scoresKBest.clear();
+        featuresKBest.clear();
         auto label = samples.index({ -1, "..." });
         for (int i = 0; i < n; ++i) {
             scoresKBest.push_back(mutualInformation(label, samples.index({ i, "..." }), weights));
             featuresKBest.push_back(i);
         }
         // sort & reduce scores and features
-        sort(featuresKBest.begin(), featuresKBest.end(), [&](int i, int j)
-            { return scoresKBest[i] > scoresKBest[j]; });
-        sort(scoresKBest.begin(), scoresKBest.end(), std::greater<double>());
-        featuresKBest.resize(k);
-        scoresKBest.resize(k);
+        if (ascending) {
+            sort(featuresKBest.begin(), featuresKBest.end(), [&](int i, int j)
+                { return scoresKBest[i] < scoresKBest[j]; });
+            sort(scoresKBest.begin(), scoresKBest.end(), std::less<double>());
+            if (k < n) {
+                for (int i = 0; i < n - k; ++i) {
+                    featuresKBest.erase(featuresKBest.begin());
+                    scoresKBest.erase(scoresKBest.begin());
+                }
+            }
+        } else {
+            sort(featuresKBest.begin(), featuresKBest.end(), [&](int i, int j)
+                { return scoresKBest[i] > scoresKBest[j]; });
+            sort(scoresKBest.begin(), scoresKBest.end(), std::greater<double>());
+            featuresKBest.resize(k);
+            scoresKBest.resize(k);
+        }
         return featuresKBest;
     }
     vector<double> Metrics::getScoresKBest() const
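For clarity, here is a minimal standalone sketch of the two orderings the new `SelectKBestWeighted` supports, using plain vectors instead of the class members and made-up scores: ascending sorts worst-to-best and erases from the front so the k best end up last, while descending sorts best-to-first and truncates with `resize`.

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

int main()
{
    std::vector<double> scores = { 0.4, 0.9, 0.1, 0.7 }; // e.g. mutual information per feature
    std::vector<int> features = { 0, 1, 2, 3 };
    const unsigned k = 2;
    const bool ascending = true;
    if (ascending) {
        std::sort(features.begin(), features.end(),
            [&](int i, int j) { return scores[i] < scores[j]; });
        // drop the n - k worst features from the front; the k best remain, best last
        features.erase(features.begin(), features.end() - k);
    } else {
        std::sort(features.begin(), features.end(),
            [&](int i, int j) { return scores[i] > scores[j]; });
        features.resize(k); // keep the k best, best first
    }
    for (int f : features)
        std::cout << f << " "; // ascending prints "3 1", descending prints "1 3"
    std::cout << std::endl;
}
```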
@@ -21,7 +21,7 @@ namespace bayesnet {
         Metrics() = default;
         Metrics(const torch::Tensor& samples, const vector<string>& features, const string& className, const int classNumStates);
         Metrics(const vector<vector<int>>& vsamples, const vector<int>& labels, const vector<string>& features, const string& className, const int classNumStates);
-        vector<int> SelectKBestWeighted(const torch::Tensor& weights, unsigned k = 0);
+        vector<int> SelectKBestWeighted(const torch::Tensor& weights, bool ascending=false, unsigned k = 0);
         vector<double> getScoresKBest() const;
         double mutualInformation(const Tensor& firstFeature, const Tensor& secondFeature, const Tensor& weights);
         vector<float> conditionalEdgeWeights(vector<float>& weights); // To use in Python
@@ -1,4 +1,5 @@
 #include "BoostAODE.h"
+#include <set>
 #include "BayesMetrics.h"
 
 namespace bayesnet {
@@ -7,40 +8,49 @@ namespace bayesnet {
     {
         // Models shall be built in trainModel
     }
+    void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
+    {
+        // Check if hyperparameters are valid
+        const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending" };
+        checkHyperparameters(validKeys, hyperparameters);
+        if (hyperparameters.contains("repeatSparent")) {
+            repeatSparent = hyperparameters["repeatSparent"];
+        }
+        if (hyperparameters.contains("maxModels")) {
+            maxModels = hyperparameters["maxModels"];
+        }
+        if (hyperparameters.contains("ascending")) {
+            ascending = hyperparameters["ascending"];
+        }
+    }
     void BoostAODE::trainModel(const torch::Tensor& weights)
     {
         models.clear();
         n_models = 0;
-        int max_models = .1 * n > 10 ? .1 * n : n;
+        if (maxModels == 0)
+            maxModels = .1 * n > 10 ? .1 * n : n;
         Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
         auto X_ = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
         auto y_ = dataset.index({ -1, "..." });
         bool exitCondition = false;
-        bool repeatSparent = false;
-        vector<int> featuresUsed;
+        unordered_set<int> featuresUsed;
         // Step 0: Set the finish condition
         // if not repeatSparent a finish condition is run out of features
-        // n_models == max_models
+        // n_models == maxModels
         int numClasses = states[className].size();
         while (!exitCondition) {
             // Step 1: Build ranking with mutual information
-            auto featureSelection = metrics.SelectKBestWeighted(weights_, n); // Get all the features sorted
-            auto feature = featureSelection[0];
+            auto featureSelection = metrics.SelectKBestWeighted(weights_, ascending, n); // Get all the features sorted
             unique_ptr<Classifier> model;
-            if (!repeatSparent) {
-                if (n_models == 0) {
-                    models.resize(n); // Resize for n==nfeatures SPODEs
-                    significanceModels.resize(n);
-                }
+            auto feature = featureSelection[0];
+            if (!repeatSparent || featuresUsed.size() < featureSelection.size()) {
                 bool found = false;
-                for (int i = 0; i < featureSelection.size(); ++i) {
-                    if (find(featuresUsed.begin(), featuresUsed.end(), i) != featuresUsed.end()) {
+                for (auto feat : featureSelection) {
+                    if (find(featuresUsed.begin(), featuresUsed.end(), feat) != featuresUsed.end()) {
                         continue;
                     }
                     found = true;
-                    feature = i;
-                    featuresUsed.push_back(feature);
-                    n_models++;
+                    feature = feat;
                     break;
                 }
                 if (!found) {
@@ -48,7 +58,9 @@ namespace bayesnet {
                     continue;
                 }
             }
+            featuresUsed.insert(feature);
             model = std::make_unique<SPODE>(feature);
+            n_models++;
             model->fit(dataset, features, className, states, weights_);
             auto ypred = model->predict(X_);
             // Step 3.1: Compute the classifier amout of say
@@ -63,15 +75,12 @@ namespace bayesnet {
             double totalWeights = torch::sum(weights_).item<double>();
             weights_ = weights_ / totalWeights;
             // Step 3.4: Store classifier and its accuracy to weigh its future vote
-            if (!repeatSparent) {
-                models[feature] = std::move(model);
-                significanceModels[feature] = significance;
-            } else {
-                models.push_back(std::move(model));
-                significanceModels.push_back(significance);
-                n_models++;
-            }
-            exitCondition = n_models == max_models;
+            models.push_back(std::move(model));
+            significanceModels.push_back(significance);
+            exitCondition = n_models == maxModels && repeatSparent;
+        }
+        if (featuresUsed.size() != features.size()) {
+            cout << "Warning: BoostAODE did not use all the features" << endl;
         }
         weights.copy_(weights_);
     }
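Taken together with the launch.json change at the top, the intended call path is roughly the following. This is a hedged sketch assuming the repository's includes and build setup; the JSON literal mirrors the `--hyperparameters` value added to the launch configuration.

```cpp
#include <nlohmann/json.hpp>
#include "BoostAODE.h"

int main()
{
    auto clf = bayesnet::BoostAODE();
    // Same payload as the --hyperparameters argument in .vscode/launch.json
    auto hyperparameters = nlohmann::json::parse("{\"repeatSparent\": true, \"maxModels\": 12}");
    clf.setHyperparameters(hyperparameters); // validates keys, then sets repeatSparent/maxModels/ascending
    // ... fit / score as usual
    return 0;
}
```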
@@ -4,13 +4,18 @@
 #include "SPODE.h"
 namespace bayesnet {
     class BoostAODE : public Ensemble {
-    protected:
-        void buildModel(const torch::Tensor& weights) override;
-        void trainModel(const torch::Tensor& weights) override;
     public:
         BoostAODE();
         virtual ~BoostAODE() {};
         vector<string> graph(const string& title = "BoostAODE") const override;
+        void setHyperparameters(nlohmann::json& hyperparameters) override;
+    protected:
+        void buildModel(const torch::Tensor& weights) override;
+        void trainModel(const torch::Tensor& weights) override;
+    private:
+        bool repeatSparent=false;
+        int maxModels=0;
+        bool ascending=false; //Process KBest features ascending or descending order
     };
 }
 #endif
@@ -1,5 +1,6 @@
 include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
 include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
+include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
 include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet)
 include_directories(${BayesNet_SOURCE_DIR}/src/Platform)
 add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc
@@ -152,4 +152,12 @@ namespace bayesnet {
     {
         model.dump_cpt();
     }
+    void Classifier::checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters)
+    {
+        for (const auto& item : hyperparameters.items()) {
+            if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
+                throw invalid_argument("Hyperparameter " + item.key() + " is not valid");
+            }
+        }
+    }
 }
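The effect of the new validation, sketched with a hypothetical key `maxDepth` that is not in BoostAODE's `validKeys` (a minimal illustration, not repository code):

```cpp
#include <iostream>
#include <nlohmann/json.hpp>
#include "BoostAODE.h"

int main()
{
    auto clf = bayesnet::BoostAODE();
    auto bad = nlohmann::json::parse("{\"maxDepth\": 3}"); // hypothetical key, rejected by checkHyperparameters
    try {
        clf.setHyperparameters(bad);
    } catch (const std::invalid_argument& e) {
        std::cout << e.what() << std::endl; // prints: Hyperparameter maxDepth is not valid
    }
    return 0;
}
```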
@@ -24,6 +24,7 @@ namespace bayesnet {
         void checkFitParameters();
         virtual void buildModel(const torch::Tensor& weights) = 0;
         void trainModel(const torch::Tensor& weights) override;
+        void checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters);
     public:
         Classifier(Network model);
         virtual ~Classifier() = default;
@@ -16,6 +16,7 @@ namespace bayesnet {
     public:
         explicit KDB(int k, float theta = 0.03);
         virtual ~KDB() {};
+        void setHyperparameters(nlohmann::json& hyperparameters) override {};
         vector<string> graph(const string& name = "KDB") const override;
     };
 }
@@ -13,6 +13,7 @@ namespace bayesnet {
         KDBLd& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
         vector<string> graph(const string& name = "KDB") const override;
         Tensor predict(Tensor& X) override;
+        void setHyperparameters(nlohmann::json& hyperparameters) override {};
         static inline string version() { return "0.0.1"; };
     };
 }
@@ -12,6 +12,7 @@ namespace bayesnet {
         explicit SPODE(int root);
         virtual ~SPODE() {};
         vector<string> graph(const string& name = "SPODE") const override;
+        void setHyperparameters(nlohmann::json& hyperparameters) override {};
     };
 }
 #endif
@@ -13,6 +13,7 @@ namespace bayesnet {
         SPODELd& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states) override;
         vector<string> graph(const string& name = "SPODE") const override;
         Tensor predict(Tensor& X) override;
+        void setHyperparameters(nlohmann::json& hyperparameters) override {};
         static inline string version() { return "0.0.1"; };
     };
 }
@@ -3,7 +3,6 @@
 #include "Classifier.h"
 namespace bayesnet {
     using namespace std;
-    using namespace torch;
     class TAN : public Classifier {
     private:
     protected:
@@ -12,6 +11,7 @@ namespace bayesnet {
         TAN();
         virtual ~TAN() {};
         vector<string> graph(const string& name = "TAN") const override;
+        void setHyperparameters(nlohmann::json& hyperparameters) override {};
     };
 }
 #endif
@@ -14,6 +14,7 @@ namespace bayesnet {
         vector<string> graph(const string& name = "TAN") const override;
         Tensor predict(Tensor& X) override;
         static inline string version() { return "0.0.1"; };
+        void setHyperparameters(nlohmann::json& hyperparameters) override {};
     };
 }
 #endif // !TANLD_H
@@ -4,9 +4,9 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
 include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
 include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
 include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
-add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc Report.cc)
-add_executable(manage manage.cc Results.cc Report.cc)
+add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc)
+add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc)
 add_executable(list list.cc platformUtils Datasets.cc)
 target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
-target_link_libraries(manage "${TORCH_LIBRARIES}")
+target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX)
 target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}")
@@ -1,7 +1,7 @@
 #include "Experiment.h"
 #include "Datasets.h"
 #include "Models.h"
-#include "Report.h"
+#include "ReportConsole.h"
 
 namespace platform {
     using json = nlohmann::json;
@@ -25,6 +25,7 @@ namespace platform {
         oss << std::put_time(timeinfo, "%H:%M:%S");
         return oss.str();
     }
+    Experiment::Experiment() : hyperparameters(json::parse("{}")) {}
     string Experiment::get_file_name()
     {
         string result = "results_" + score_name + "_" + model + "_" + platform + "_" + get_date() + "_" + get_time() + "_" + (stratified ? "1" : "0") + ".json";
@@ -90,7 +91,7 @@ namespace platform {
     void Experiment::report()
     {
         json data = build_json();
-        Report report(data);
+        ReportConsole report(data);
         report.show();
     }
 
@@ -124,6 +125,8 @@ namespace platform {
         auto result = Result();
         auto [values, counts] = at::_unique(y);
         result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0));
+        result.setHyperparameters(hyperparameters);
+        // Initialize results vectors
         int nResults = nfolds * static_cast<int>(randomSeeds.size());
         auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
         auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64);
@@ -144,6 +147,10 @@ namespace platform {
         for (int nfold = 0; nfold < nfolds; nfold++) {
             auto clf = Models::instance()->create(model);
             setModelVersion(clf->getVersion());
+            if (hyperparameters.size() != 0) {
+                clf->setHyperparameters(hyperparameters);
+            }
+            // Split train - test dataset
             train_timer.start();
             auto [train, test] = fold->getFold(nfold);
             auto train_t = torch::tensor(train);
@@ -153,12 +160,14 @@ namespace platform {
             auto X_test = X.index({ "...", test_t });
             auto y_test = y.index({ test_t });
             cout << nfold + 1 << ", " << flush;
+            // Train model
             clf->fit(X_train, y_train, features, className, states);
             nodes[item] = clf->getNumberOfNodes();
             edges[item] = clf->getNumberOfEdges();
             num_states[item] = clf->getNumberOfStates();
             train_time[item] = train_timer.getDuration();
             auto accuracy_train_value = clf->score(X_train, y_train);
+            // Test model
             test_timer.start();
             auto accuracy_test_value = clf->score(X_test, y_test);
             test_time[item] = test_timer.getDuration();
@@ -172,11 +181,11 @@ namespace platform {
             item++;
         }
         cout << "end. " << flush;
-        delete fold;
     }
     result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
     result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
     result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
+    result.setTestTimeStd(torch::std(test_time).item<double>()).setTrainTimeStd(torch::std(train_time).item<double>());
     result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
     result.setDataset(fileName);
     addResult(result);
@@ -29,7 +29,8 @@ namespace platform {
     };
     class Result {
     private:
-        string dataset, hyperparameters, model_version;
+        string dataset, model_version;
+        json hyperparameters;
         int samples{ 0 }, features{ 0 }, classes{ 0 };
         double score_train{ 0 }, score_test{ 0 }, score_train_std{ 0 }, score_test_std{ 0 }, train_time{ 0 }, train_time_std{ 0 }, test_time{ 0 }, test_time_std{ 0 };
         float nodes{ 0 }, leaves{ 0 }, depth{ 0 };
@@ -37,7 +38,7 @@ namespace platform {
     public:
         Result() = default;
         Result& setDataset(const string& dataset) { this->dataset = dataset; return *this; }
-        Result& setHyperparameters(const string& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
+        Result& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
         Result& setSamples(int samples) { this->samples = samples; return *this; }
         Result& setFeatures(int features) { this->features = features; return *this; }
         Result& setClasses(int classes) { this->classes = classes; return *this; }
@@ -59,7 +60,7 @@ namespace platform {
         const float get_score_train() const { return score_train; }
         float get_score_test() { return score_test; }
         const string& getDataset() const { return dataset; }
-        const string& getHyperparameters() const { return hyperparameters; }
+        const json& getHyperparameters() const { return hyperparameters; }
         const int getSamples() const { return samples; }
         const int getFeatures() const { return features; }
         const int getClasses() const { return classes; }
@@ -85,11 +86,12 @@ namespace platform {
         bool discretized{ false }, stratified{ false };
         vector<Result> results;
         vector<int> randomSeeds;
+        json hyperparameters = "{}";
        int nfolds{ 0 };
         float duration{ 0 };
         json build_json();
     public:
-        Experiment() = default;
+        Experiment();
         Experiment& setTitle(const string& title) { this->title = title; return *this; }
         Experiment& setModel(const string& model) { this->model = model; return *this; }
         Experiment& setPlatform(const string& platform) { this->platform = platform; return *this; }
@@ -103,6 +105,7 @@ namespace platform {
         Experiment& addResult(Result result) { results.push_back(result); return *this; }
         Experiment& addRandomSeed(int randomSeed) { randomSeeds.push_back(randomSeed); return *this; }
         Experiment& setDuration(float duration) { this->duration = duration; return *this; }
+        Experiment& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
         string get_file_name();
         void save(const string& path);
         void cross_validation(const string& path, const string& fileName);
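A sketch of how the new field plugs into the existing builder-style setters (names taken from the header above; the surrounding experiment setup is elided):

```cpp
auto experiment = platform::Experiment();
experiment.setTitle("BoostAODE on glass")
    .setModel("BoostAODE")
    .setHyperparameters(nlohmann::json::parse("{\"maxModels\": 12}"));
// cross_validation() now passes the hyperparameters to each classifier it creates
```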
@@ -6,6 +6,7 @@ namespace platform {
     public:
         static std::string datasets() { return "datasets/"; }
         static std::string results() { return "results/"; }
+        static std::string excel() { return "excel/"; }
     };
 }
 #endif
@@ -1,26 +0,0 @@
-#ifndef REPORT_H
-#define REPORT_H
-#include <string>
-#include <iostream>
-#include <nlohmann/json.hpp>
-#include "Colors.h"
-
-using json = nlohmann::json;
-const int MAXL = 128;
-namespace platform {
-    using namespace std;
-    class Report {
-    public:
-        explicit Report(json data_) { data = data_; };
-        virtual ~Report() = default;
-        void show();
-    private:
-        void header();
-        void body();
-        void footer();
-        string fromVector(const string& key);
-        json data;
-        double totalScore; // Total score of all results in a report
-    };
-};
-#endif
37 changes: src/Platform/ReportBase.cc (new file)

@@ -0,0 +1,37 @@
+#include <sstream>
+#include <locale>
+#include "ReportBase.h"
+#include "BestResult.h"
+
+
+namespace platform {
+    string ReportBase::fromVector(const string& key)
+    {
+        stringstream oss;
+        string sep = "";
+        oss << "[";
+        for (auto& item : data[key]) {
+            oss << sep << item.get<double>();
+            sep = ", ";
+        }
+        oss << "]";
+        return oss.str();
+    }
+    string ReportBase::fVector(const string& title, const json& data, const int width, const int precision)
+    {
+        stringstream oss;
+        string sep = "";
+        oss << title << "[";
+        for (const auto& item : data) {
+            oss << sep << fixed << setw(width) << setprecision(precision) << item.get<double>();
+            sep = ", ";
+        }
+        oss << "]";
+        return oss.str();
+    }
+    void ReportBase::show()
+    {
+        header();
+        body();
+    }
+}
23 changes: src/Platform/ReportBase.h (new file)

@@ -0,0 +1,23 @@
+#ifndef REPORTBASE_H
+#define REPORTBASE_H
+#include <string>
+#include <iostream>
+#include <nlohmann/json.hpp>
+
+using json = nlohmann::json;
+namespace platform {
+    using namespace std;
+    class ReportBase {
+    public:
+        explicit ReportBase(json data_) { data = data_; };
+        virtual ~ReportBase() = default;
+        void show();
+    protected:
+        json data;
+        string fromVector(const string& key);
+        string fVector(const string& title, const json& data, const int width, const int precision);
+        virtual void header() = 0;
+        virtual void body() = 0;
+    };
+};
+#endif
@@ -1,52 +1,24 @@
 #include <sstream>
 #include <locale>
-#include "Report.h"
+#include "ReportConsole.h"
 #include "BestResult.h"
 
 namespace platform {
-    string headerLine(const string& text)
-    {
-        int n = MAXL - text.length() - 3;
-        n = n < 0 ? 0 : n;
-        return "* " + text + string(n, ' ') + "*\n";
-    }
-    string Report::fromVector(const string& key)
-    {
-        stringstream oss;
-        string sep = "";
-        oss << "[";
-        for (auto& item : data[key]) {
-            oss << sep << item.get<double>();
-            sep = ", ";
-        }
-        oss << "]";
-        return oss.str();
-    }
-    string fVector(const string& title, const json& data, const int width, const int precision)
-    {
-        stringstream oss;
-        string sep = "";
-        oss << title << "[";
-        for (const auto& item : data) {
-            oss << sep << fixed << setw(width) << setprecision(precision) << item.get<double>();
-            sep = ", ";
-        }
-        oss << "]";
-        return oss.str();
-    }
-    void Report::show()
-    {
-        header();
-        body();
-        footer();
-    }
     struct separated : numpunct<char> {
         char do_decimal_point() const { return ','; }
         char do_thousands_sep() const { return '.'; }
         string do_grouping() const { return "\03"; }
     };
-    void Report::header()
+    string ReportConsole::headerLine(const string& text)
+    {
+        int n = MAXL - text.length() - 3;
+        n = n < 0 ? 0 : n;
+        return "* " + text + string(n, ' ') + "*\n";
+    }
+
+    void ReportConsole::header()
     {
         locale mylocale(cout.getloc(), new separated);
         locale::global(mylocale);
@@ -62,12 +34,12 @@ namespace platform {
         cout << string(MAXL, '*') << endl;
         cout << endl;
     }
-    void Report::body()
+    void ReportConsole::body()
     {
         cout << Colors::GREEN() << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl;
         cout << "============================== ====== ===== === ========= ========= ========= =============== ================== ===============" << endl;
         json lastResult;
-        totalScore = 0;
+        double totalScore = 0.0;
         bool odd = true;
         for (const auto& r : data["results"]) {
             auto color = odd ? Colors::CYAN() : Colors::BLUE();
@@ -98,9 +70,11 @@ namespace platform {
             cout << headerLine(fVector("Train times: ", lastResult["times_train"], 10, 3));
             cout << headerLine(fVector("Test times: ", lastResult["times_test"], 10, 3));
             cout << string(MAXL, '*') << endl;
+        } else {
+            footer(totalScore);
         }
     }
-    void Report::footer()
+    void ReportConsole::footer(double totalScore)
     {
         cout << Colors::MAGENTA() << string(MAXL, '*') << endl;
         auto score = data["score_name"].get<string>();
@@ -110,6 +84,5 @@ namespace platform {
             cout << headerLine(oss.str());
         }
         cout << string(MAXL, '*') << endl << Colors::RESET();
-
     }
 }
22 changes: src/Platform/ReportConsole.h (new file)

@@ -0,0 +1,22 @@
+#ifndef REPORTCONSOLE_H
+#define REPORTCONSOLE_H
+#include <string>
+#include <iostream>
+#include "ReportBase.h"
+#include "Colors.h"
+
+namespace platform {
+    using namespace std;
+    const int MAXL = 128;
+    class ReportConsole : public ReportBase{
+    public:
+        explicit ReportConsole(json data_) : ReportBase(data_) {};
+        virtual ~ReportConsole() = default;
+    private:
+        string headerLine(const string& text);
+        void header() override;
+        void body() override;
+        void footer(double totalScore);
+    };
+};
+#endif
109 changes: src/Platform/ReportExcel.cc (new file)

@@ -0,0 +1,109 @@
+#include <sstream>
+#include <locale>
+#include "ReportExcel.h"
+#include "BestResult.h"
+
+
+namespace platform {
+    struct separated : numpunct<char> {
+        char do_decimal_point() const { return ','; }
+
+        char do_thousands_sep() const { return '.'; }
+
+        string do_grouping() const { return "\03"; }
+    };
+
+    void ReportExcel::createFile()
+    {
+        doc.create(Paths::excel() + "some_results.xlsx");
+        wks = doc.workbook().worksheet("Sheet1");
+        wks.setName(data["model"].get<string>());
+    }
+
+    void ReportExcel::closeFile()
+    {
+        doc.save();
+        doc.close();
+    }
+
+    void ReportExcel::header()
+    {
+        locale mylocale(cout.getloc(), new separated);
+        locale::global(mylocale);
+        cout.imbue(mylocale);
+        stringstream oss;
+        wks.cell("A1").value().set(
+            "Report " + data["model"].get<string>() + " ver. " + data["version"].get<string>() + " with " +
+            to_string(data["folds"].get<int>()) + " Folds cross validation and " + to_string(data["seeds"].size()) +
+            " random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>());
+        wks.cell("A2").value() = data["title"].get<string>();
+        wks.cell("A3").value() = "Random seeds: " + fromVector("seeds") + " Stratified: " +
+            (data["stratified"].get<bool>() ? "True" : "False");
+        oss << "Execution took " << setprecision(2) << fixed << data["duration"].get<float>() << " seconds, "
+            << data["duration"].get<float>() / 3600 << " hours, on " << data["platform"].get<string>();
+        wks.cell("A4").value() = oss.str();
+        wks.cell("A5").value() = "Score is " + data["score_name"].get<string>();
+    }
+
+    void ReportExcel::body()
+    {
+        auto header = vector<string>(
+            { "Dataset", "Samples", "Features", "Classes", "Nodes", "Edges", "States", "Score", "Score Std.", "Time",
+              "Time Std.", "Hyperparameters" });
+        int col = 1;
+        for (const auto& item : header) {
+            wks.cell(8, col++).value() = item;
+        }
+        int row = 9;
+        col = 1;
+        json lastResult;
+        double totalScore = 0.0;
+        string hyperparameters;
+        for (const auto& r : data["results"]) {
+            wks.cell(row, col).value() = r["dataset"].get<string>();
+            wks.cell(row, col + 1).value() = r["samples"].get<int>();
+            wks.cell(row, col + 2).value() = r["features"].get<int>();
+            wks.cell(row, col + 3).value() = r["classes"].get<int>();
+            wks.cell(row, col + 4).value() = r["nodes"].get<float>();
+            wks.cell(row, col + 5).value() = r["leaves"].get<float>();
+            wks.cell(row, col + 6).value() = r["depth"].get<float>();
+            wks.cell(row, col + 7).value() = r["score"].get<double>();
+            wks.cell(row, col + 8).value() = r["score_std"].get<double>();
+            wks.cell(row, col + 9).value() = r["time"].get<double>();
+            wks.cell(row, col + 10).value() = r["time_std"].get<double>();
+            try {
+                hyperparameters = r["hyperparameters"].get<string>();
+            }
+            catch (const exception& err) {
+                stringstream oss;
+                oss << r["hyperparameters"];
+                hyperparameters = oss.str();
+            }
+            wks.cell(row, col + 11).value() = hyperparameters;
+            lastResult = r;
+            totalScore += r["score"].get<double>();
+            row++;
+        }
+        if (data["results"].size() == 1) {
+            for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) {
+                row++;
+                col = 1;
+                wks.cell(row, col).value() = group;
+                for (double item : lastResult[group]) {
+                    wks.cell(row, ++col).value() = item;
+                }
+            }
+        } else {
+            footer(totalScore, row);
+        }
+    }
+
+    void ReportExcel::footer(double totalScore, int row)
+    {
+        auto score = data["score_name"].get<string>();
+        if (score == BestResult::scoreName()) {
+            wks.cell(row + 2, 1).value() = score + " compared to " + BestResult::title() + " .: ";
+            wks.cell(row + 2, 5).value() = totalScore / BestResult::score();
+        }
+    }
+}
25
src/Platform/ReportExcel.h
Normal file
25
src/Platform/ReportExcel.h
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
#ifndef REPORTEXCEL_H
#define REPORTEXCEL_H
#include <OpenXLSX.hpp>
#include "ReportBase.h"
#include "Paths.h"
#include "Colors.h"
namespace platform {
    using namespace std;
    using namespace OpenXLSX;
    const int MAXLL = 128;
    class ReportExcel : public ReportBase {
    public:
        explicit ReportExcel(json data_) : ReportBase(data_) { createFile(); };
        virtual ~ReportExcel() { closeFile(); };
    private:
        void createFile();
        void closeFile();
        XLDocument doc;
        XLWorksheet wks;
        void header() override;
        void body() override;
        void footer(double totalScore, int row);
    };
};
#endif // !REPORTEXCEL_H
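The constructor/destructor pair above is an RAII arrangement: building a ReportExcel creates the workbook, and destruction closes it, so every exit path in the caller flushes the file. A generic sketch of the pattern, with illustrative names that are not part of the commit:

#include <iostream>

// Illustrative RAII wrapper: acquire the resource in the constructor and
// release it in the destructor, so early returns still release it.
class ScopedReport {
public:
    ScopedReport() { std::cout << "createFile()" << std::endl; }
    ~ScopedReport() { std::cout << "closeFile()" << std::endl; }
};

int main()
{
    ScopedReport report; // "closeFile()" prints automatically at end of scope
    return 0;
}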
src/Platform/Results.cc
@@ -1,7 +1,8 @@
 #include <filesystem>
 #include "platformUtils.h"
 #include "Results.h"
-#include "Report.h"
+#include "ReportConsole.h"
+#include "ReportExcel.h"
 #include "BestResult.h"
 #include "Colors.h"
 namespace platform {
@@ -94,21 +95,26 @@ namespace platform {
         cout << "Invalid index" << endl;
         return -1;
     }
-    void Results::report(const int index) const
+    void Results::report(const int index, const bool excelReport) const
     {
         cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl;
         auto data = files.at(index).load();
-        Report report(data);
-        report.show();
+        if (excelReport) {
+            ReportExcel report(data);
+            report.show();
+        } else {
+            ReportConsole report(data);
+            report.show();
+        }
     }
     void Results::menu()
     {
         char option;
         int index;
         bool finished = false;
-        string filename, line, options = "qldhsr";
+        string filename, line, options = "qldhsre";
         while (!finished) {
-            cout << Colors::RESET() << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r'): ";
+            cout << Colors::RESET() << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r', excel='e'): ";
             getline(cin, line);
             if (line.size() == 0)
                 continue;
@@ -119,12 +125,14 @@ namespace platform {
             }
             option = line[0];
         } else {
-            index = stoi(line);
-            if (index >= 0 && index < files.size()) {
-                report(index);
-            } else {
-                cout << "Invalid option" << endl;
+            if (all_of(line.begin(), line.end(), ::isdigit)) {
+                index = stoi(line);
+                if (index >= 0 && index < files.size()) {
+                    report(index, false);
+                    continue;
+                }
             }
+            cout << "Invalid option" << endl;
             continue;
         }
         switch (option) {
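The hunk above hardens the menu input path: stoi throws std::invalid_argument on non-numeric text, so the new all_of/::isdigit guard converts only pure digit strings. A standalone sketch of the same check (the input value is illustrative):

#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>

int main()
{
    std::string line = "12a"; // illustrative input
    // stoi("a12") throws and stoi("12a") silently yields 12;
    // validating every character first avoids both behaviors.
    if (!line.empty() && std::all_of(line.begin(), line.end(), ::isdigit)) {
        std::cout << "index: " << std::stoi(line) << std::endl;
    } else {
        std::cout << "not a number" << std::endl;
    }
    return 0;
}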
@@ -164,7 +172,13 @@ namespace platform {
             index = getIndex("report");
             if (index == -1)
                 break;
-            report(index);
+            report(index, false);
+            break;
+        case 'e':
+            index = getIndex("excel");
+            if (index == -1)
+                break;
+            report(index, true);
             break;
         default:
             cout << "Invalid option" << endl;
@@ -231,6 +245,7 @@ namespace platform {
             cout << "No results found!" << endl;
             exit(0);
         }
+        sortDate();
         show();
         menu();
         cout << "Done!" << endl;
src/Platform/Results.h
@@ -42,7 +42,7 @@ namespace platform {
         vector<Result> files;
         void load(); // Loads the list of results
         void show() const;
-        void report(const int index) const;
+        void report(const int index, const bool excelReport) const;
         int getIndex(const string& intent) const;
         void menu();
         void sortList();
src/Platform/main.cc
@@ -1,5 +1,6 @@
 #include <iostream>
 #include <argparse/argparse.hpp>
+#include <nlohmann/json.hpp>
 #include "platformUtils.h"
 #include "Experiment.h"
 #include "Datasets.h"
@@ -10,12 +11,14 @@


 using namespace std;
+using json = nlohmann::json;

 argparse::ArgumentParser manageArguments(int argc, char** argv)
 {
     auto env = platform::DotEnv();
     argparse::ArgumentParser program("main");
     program.add_argument("-d", "--dataset").default_value("").help("Dataset file name");
+    program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
     program.add_argument("-p", "--path")
         .help("folder where the data files are located, default")
         .default_value(string{ platform::Paths::datasets() });
@@ -31,6 +34,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
     );
     program.add_argument("--title").default_value("").help("Experiment title");
     program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
+    program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
     program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
     program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const string& value) {
         try {
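For context on the new --save switch: with the p-ranav/argparse API used throughout this file, default_value(false) applies when the flag is absent, and implicit_value(true) applies when it is present with no argument. A minimal sketch of that behavior:

#include <argparse/argparse.hpp>
#include <iostream>

int main(int argc, char** argv)
{
    argparse::ArgumentParser program("demo");
    // omitted -> false (default_value); "demo --save" -> true (implicit_value)
    program.add_argument("--save").default_value(false).implicit_value(true);
    program.parse_args(argc, argv);
    std::cout << std::boolalpha << program.get<bool>("save") << std::endl;
    return 0;
}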
@@ -59,6 +63,8 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
     auto seeds = program.get<vector<int>>("seeds");
     auto complete_file_name = path + file_name + ".arff";
     auto title = program.get<string>("title");
+    auto hyperparameters = program.get<string>("hyperparameters");
+    auto saveResults = program.get<bool>("save");
     if (title == "" && file_name == "") {
         throw runtime_error("title is mandatory if dataset is not provided");
     }
@@ -74,7 +80,6 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
 int main(int argc, char** argv)
 {
     auto program = manageArguments(argc, argv);
-    bool saveResults = false;
     auto file_name = program.get<string>("dataset");
     auto path = program.get<string>("path");
     auto model_name = program.get<string>("model");
@@ -82,9 +87,11 @@ int main(int argc, char** argv)
     auto stratified = program.get<bool>("stratified");
     auto n_folds = program.get<int>("folds");
     auto seeds = program.get<vector<int>>("seeds");
+    auto hyperparameters = program.get<string>("hyperparameters");
     vector<string> filesToTest;
     auto datasets = platform::Datasets(path, true, platform::ARFF);
     auto title = program.get<string>("title");
+    auto saveResults = program.get<bool>("save");
     if (file_name != "") {
         if (!datasets.isDataset(file_name)) {
             cerr << "Dataset " << file_name << " not found" << endl;
@@ -106,6 +113,7 @@ int main(int argc, char** argv)
     experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3");
     experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));
     experiment.setStratified(stratified).setNFolds(n_folds).setScoreName("accuracy");
+    experiment.setHyperparameters(json::parse(hyperparameters));
     for (auto seed : seeds) {
         experiment.addRandomSeed(seed);
     }
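The hyperparameters travel from the command line as one raw JSON string and only become structured data here, via json::parse, before being handed to the experiment. A minimal sketch of that round trip with nlohmann::json; the key and value are illustrative, not defaults of any model:

#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main()
{
    std::string raw = R"({"maxModels": 12})";   // e.g. the --hyperparameters argument
    json hyperparameters = json::parse(raw);    // throws json::parse_error on bad input
    std::cout << hyperparameters["maxModels"].get<int>() << std::endl;
    return 0;
}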
@@ -113,10 +121,10 @@ int main(int argc, char** argv)
     timer.start();
     experiment.go(filesToTest, path);
     experiment.setDuration(timer.getDuration());
-    if (saveResults)
+    if (saveResults) {
         experiment.save(platform::Paths::results());
-    else
+    }
     experiment.report();
     cout << "Done!" << endl;
     return 0;
 }