Begin adding cfs to BoostAODE

This commit is contained in:
Ricardo Montañana Gómez 2023-10-10 11:52:39 +02:00
parent 7d8aca4f59
commit f288bbd6fa
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
13 changed files with 59 additions and 28 deletions

View File

@ -4,6 +4,7 @@
#include "Colors.h"
#include "Folding.h"
#include <limits.h>
#include "Paths.h"
namespace bayesnet {
BoostAODE::BoostAODE() : Ensemble() {}
@ -28,6 +29,9 @@ namespace bayesnet {
if (hyperparameters.contains("convergence")) {
convergence = hyperparameters["convergence"];
}
if (hyperparameters.contains("cfs")) {
cfs = hyperparameters["cfs"];
}
}
void BoostAODE::validationInit()
{
@ -58,6 +62,12 @@ namespace bayesnet {
}
}
void BoostAODE::initializeModels()
{
ifstream file(cfs + ".json");
if (file.is_open()) {
}
}
void BoostAODE::trainModel(const torch::Tensor& weights)
{
models.clear();
@ -66,6 +76,9 @@ namespace bayesnet {
maxModels = .1 * n > 10 ? .1 * n : n;
validationInit();
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
if (cfs != "") {
initializeModels();
}
bool exitCondition = false;
unordered_set<int> featuresUsed;
// Variables to control the accuracy finish condition

View File

@ -16,10 +16,13 @@ namespace bayesnet {
torch::Tensor dataset_;
torch::Tensor X_train, y_train, X_test, y_test;
void validationInit();
bool repeatSparent = false;
void initializeModels();
// Hyperparameters
bool repeatSparent = false; // if true, a feature can be selected more than once
int maxModels = 0;
bool ascending = false; //Process KBest features ascending or descending order
bool convergence = false; //if true, stop when the model does not improve
string cfs = ""; // if not empty, use CFS to select features
};
}
#endif

View File

@ -212,14 +212,4 @@ namespace platform {
}
return Xd;
}
vector<string> Dataset::split(const string& text, char delimiter)
{
vector<string> result;
stringstream ss(text);
string token;
while (getline(ss, token, delimiter)) {
result.push_back(token);
}
return result;
}
}

View File

@ -5,6 +5,7 @@
#include <vector>
#include <string>
#include "CPPFImdlp.h"
#include "Utils.h"
namespace platform {
using namespace std;
@ -62,7 +63,6 @@ namespace platform {
public:
Dataset(const string& path, const string& name, const string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {};
explicit Dataset(const Dataset&);
static vector<string> split(const string& text, char delimiter);
string getName() const;
string getClassName() const;
vector<string> getFeatures() const;

View File

@ -13,7 +13,7 @@ namespace platform {
if (line.empty() || line[0] == '#') {
continue;
}
vector<string> tokens = Dataset::split(line, ',');
vector<string> tokens = split(line, ',');
string name = tokens[0];
string className;
if (tokens.size() == 1) {

View File

@ -4,7 +4,10 @@
#include <map>
#include <fstream>
#include <sstream>
#include "Dataset.h"
#include <iostream>
#include "Utils.h"
//#include "Dataset.h"
namespace platform {
class DotEnv {
private:
@ -51,7 +54,7 @@ namespace platform {
auto seeds_str = env["seeds"];
seeds_str = trim(seeds_str);
seeds_str = seeds_str.substr(1, seeds_str.size() - 2);
auto seeds_str_split = Dataset::split(seeds_str, ',');
auto seeds_str_split = split(seeds_str, ',');
transform(seeds_str_split.begin(), seeds_str_split.end(), back_inserter(seeds), [](const std::string& str) {
return stoi(str);
});

View File

@ -3,7 +3,7 @@
#include "Datasets.h"
#include "Models.h"
#include "ReportConsole.h"
#include "DotEnv.h"
#include "Paths.h"
namespace platform {
using json = nlohmann::json;
string get_date()
@ -134,8 +134,7 @@ namespace platform {
}
void Experiment::cross_validation(const string& fileName)
{
auto env = platform::DotEnv();
auto datasets = platform::Datasets(discretized, env.get("source_data"));
auto datasets = platform::Datasets(discretized, Paths::datasets());
// Get dataset
auto [X, y] = datasets.getTensors(fileName);
auto states = datasets.getStates(fileName);

View File

@ -1,11 +1,17 @@
#ifndef PATHS_H
#define PATHS_H
#include <string>
#include "DotEnv.h"
namespace platform {
class Paths {
public:
static std::string results() { return "results/"; }
static std::string excel() { return "excel/"; }
static std::string datasets()
{
auto env = platform::DotEnv();
return env.get("source_data");
}
};
}
#endif

View File

@ -58,8 +58,7 @@ namespace platform {
}
} else {
if (data["score_name"].get<string>() == "accuracy") {
auto env = platform::DotEnv();
auto dt = Datasets(false, env.get("source_data"));
auto dt = Datasets(false, Paths::datasets());
dt.loadDataset(dataset);
auto numClasses = dt.getNClasses(dataset);
if (numClasses == 2) {

View File

@ -56,10 +56,12 @@ namespace platform {
try {
cout << r["hyperparameters"].get<string>();
}
catch (const exception& err) {
cout << r["hyperparameters"];
catch (...) {
//cout << r["hyperparameters"];
cout << "Arrggggghhhh!" << endl;
}
cout << endl;
cout << flush;
lastResult = r;
totalScore += r["score"].get<double>();
odd = !odd;

19
src/Platform/Utils.h Normal file
View File

@ -0,0 +1,19 @@
#ifndef UTILS_H
#define UTILS_H
#include <sstream>
#include <string>
#include <vector>
namespace platform {
//static vector<string> split(const string& text, char delimiter);
static std::vector<std::string> split(const std::string& text, char delimiter)
{
std::vector<std::string> result;
std::stringstream ss(text);
std::string token;
while (std::getline(ss, token, delimiter)) {
result.push_back(token);
}
return result;
}
}
#endif

View File

@ -3,7 +3,6 @@
#include "Paths.h"
#include "Colors.h"
#include "Datasets.h"
#include "DotEnv.h"
using namespace std;
const int BALANCE_LENGTH = 75;
@ -28,8 +27,7 @@ void outputBalance(const string& balance)
int main(int argc, char** argv)
{
auto env = platform::DotEnv();
auto data = platform::Datasets(false, env.get("source_data"));
auto data = platform::Datasets(false, platform::Paths::datasets());
locale mylocale(cout.getloc(), new separated);
locale::global(mylocale);
cout.imbue(mylocale);

View File

@ -82,8 +82,7 @@ int main(int argc, char** argv)
auto seeds = program.get<vector<int>>("seeds");
auto hyperparameters = program.get<string>("hyperparameters");
vector<string> filesToTest;
auto env = platform::DotEnv();
auto datasets = platform::Datasets(discretize_dataset, env.get("source_data"));
auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets());
auto title = program.get<string>("title");
auto saveResults = program.get<bool>("save");
if (file_name != "") {
@ -102,7 +101,7 @@ int main(int argc, char** argv)
/*
* Begin Processing
*/
auto env = platform::DotEnv();
auto experiment = platform::Experiment();
experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3");
experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));