Begin adding cfs to BoostAODE
parent 7d8aca4f59
commit f288bbd6fa

@@ -4,6 +4,7 @@
#include "Colors.h"
#include "Folding.h"
#include <limits.h>
#include "Paths.h"

namespace bayesnet {
    BoostAODE::BoostAODE() : Ensemble() {}

@@ -28,6 +29,9 @@ namespace bayesnet {
        if (hyperparameters.contains("convergence")) {
            convergence = hyperparameters["convergence"];
        }
+        if (hyperparameters.contains("cfs")) {
+            cfs = hyperparameters["cfs"];
+        }
    }
    void BoostAODE::validationInit()
    {

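The three added lines read the new "cfs" hyperparameter with the same guard-then-assign pattern already used for "convergence". Below is a minimal, self-contained sketch of that pattern; it assumes the hyperparameters document is nlohmann::json (suggested by the contains()/operator[] calls above), uses made-up values, and spells out the conversions with explicit get<>() calls where the committed code relies on implicit conversion:

    #include <iostream>
    #include <string>
    #include <nlohmann/json.hpp>

    int main()
    {
        // Made-up hyperparameter document; the key names mirror BoostAODE.h.
        nlohmann::json hyperparameters = {
            { "convergence", true },
            { "cfs", "mdlp" }   // example value: feature subsets would be read from "mdlp.json"
        };
        bool convergence = false;
        std::string cfs = "";
        // Same guard-then-assign pattern as the hunk above.
        if (hyperparameters.contains("convergence")) {
            convergence = hyperparameters["convergence"].get<bool>();
        }
        if (hyperparameters.contains("cfs")) {
            cfs = hyperparameters["cfs"].get<std::string>();
        }
        std::cout << "convergence=" << convergence << " cfs=" << cfs << std::endl;
    }
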
@@ -58,6 +62,12 @@ namespace bayesnet {
        }

    }
+    void BoostAODE::initializeModels()
+    {
+        ifstream file(cfs + ".json");
+        if (file.is_open()) {
+        }
+    }
    void BoostAODE::trainModel(const torch::Tensor& weights)
    {
        models.clear();

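initializeModels() is only a stub at this point: it builds the file name from the cfs hyperparameter and opens <cfs>.json, but the body of the if is still empty. A hedged sketch of the kind of code that presumably goes there next, reading the file as JSON and extracting a list of selected features; the "features" key and the file layout are invented for illustration and are not part of this commit:

    #include <fstream>
    #include <iostream>
    #include <string>
    #include <vector>
    #include <nlohmann/json.hpp>

    // Hypothetical helper: read "<cfs>.json" and return the feature indices stored
    // under a made-up "features" key; returns an empty vector if the file is missing.
    std::vector<int> loadSelectedFeatures(const std::string& cfs)
    {
        std::vector<int> features;
        std::ifstream file(cfs + ".json");
        if (file.is_open()) {
            nlohmann::json selected = nlohmann::json::parse(file);
            features = selected["features"].get<std::vector<int>>();
        }
        return features;
    }

    int main()
    {
        for (int f : loadSelectedFeatures("mdlp")) {   // would read "mdlp.json"
            std::cout << f << " ";
        }
        std::cout << std::endl;
    }
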
@@ -66,6 +76,9 @@ namespace bayesnet {
        maxModels = .1 * n > 10 ? .1 * n : n;
        validationInit();
        Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
+        if (cfs != "") {
+            initializeModels();
+        }
        bool exitCondition = false;
        unordered_set<int> featuresUsed;
        // Variables to control the accuracy finish condition

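Two small details in this hunk are easy to miss: maxModels = .1 * n > 10 ? .1 * n : n; caps the ensemble at 10% of n (presumably the number of features) only when that 10% exceeds ten, otherwise the cap is n itself, and the boosting weights start out uniform at 1/m per sample. A standalone libtorch sketch of those two lines, with made-up sizes:

    #include <iostream>
    #include <torch/torch.h>

    int main()
    {
        int n = 150;   // made-up number of features
        int m = 1000;  // made-up number of samples
        // Same ternary as trainModel(): truncated 10% of n as the model cap,
        // but only when that is more than 10; otherwise allow n models.
        int maxModels = .1 * n > 10 ? .1 * n : n;
        // Uniform sample weights, one per instance, summing to 1.
        torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
        std::cout << "maxModels=" << maxModels
                  << " sum(weights)=" << weights_.sum().item<double>() << std::endl;
    }
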
@@ -16,10 +16,13 @@ namespace bayesnet {
        torch::Tensor dataset_;
        torch::Tensor X_train, y_train, X_test, y_test;
        void validationInit();
-        bool repeatSparent = false;
+        void initializeModels();
+        // Hyperparameters
+        bool repeatSparent = false; // if true, a feature can be selected more than once
        int maxModels = 0;
        bool ascending = false; //Process KBest features ascending or descending order
        bool convergence = false; //if true, stop when the model does not improve
+        string cfs = ""; // if not empty, use CFS to select features
    };
}
#endif

@@ -212,14 +212,4 @@ namespace platform {
        }
        return Xd;
    }
-    vector<string> Dataset::split(const string& text, char delimiter)
-    {
-        vector<string> result;
-        stringstream ss(text);
-        string token;
-        while (getline(ss, token, delimiter)) {
-            result.push_back(token);
-        }
-        return result;
-    }
}

@@ -5,6 +5,7 @@
#include <vector>
#include <string>
#include "CPPFImdlp.h"
+#include "Utils.h"
namespace platform {
    using namespace std;

@@ -62,7 +63,6 @@ namespace platform {
    public:
        Dataset(const string& path, const string& name, const string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {};
        explicit Dataset(const Dataset&);
-        static vector<string> split(const string& text, char delimiter);
        string getName() const;
        string getClassName() const;
        vector<string> getFeatures() const;

@@ -13,7 +13,7 @@ namespace platform {
            if (line.empty() || line[0] == '#') {
                continue;
            }
-            vector<string> tokens = Dataset::split(line, ',');
+            vector<string> tokens = split(line, ',');
            string name = tokens[0];
            string className;
            if (tokens.size() == 1) {

@@ -4,7 +4,10 @@
#include <map>
#include <fstream>
#include <sstream>
-#include "Dataset.h"
#include <iostream>
+#include "Utils.h"

+//#include "Dataset.h"
namespace platform {
    class DotEnv {
    private:

@@ -51,7 +54,7 @@ namespace platform {
            auto seeds_str = env["seeds"];
            seeds_str = trim(seeds_str);
            seeds_str = seeds_str.substr(1, seeds_str.size() - 2);
-            auto seeds_str_split = Dataset::split(seeds_str, ',');
+            auto seeds_str_split = split(seeds_str, ',');
            transform(seeds_str_split.begin(), seeds_str_split.end(), back_inserter(seeds), [](const std::string& str) {
                return stoi(str);
                });

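The only change in this hunk is the call site: Dataset::split becomes the free split() that this commit adds in src/Platform/Utils.h. The surrounding lines are a compact picture of how DotEnv turns a seeds entry into integers: strip the enclosing brackets, split on commas, convert each token with stoi. A self-contained sketch of that transformation, using a local copy of the new split() helper and a made-up seeds value:

    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <sstream>
    #include <string>
    #include <vector>

    // Local copy of the split() helper added in src/Platform/Utils.h by this commit.
    static std::vector<std::string> split(const std::string& text, char delimiter)
    {
        std::vector<std::string> result;
        std::stringstream ss(text);
        std::string token;
        while (std::getline(ss, token, delimiter)) {
            result.push_back(token);
        }
        return result;
    }

    int main()
    {
        std::string seeds_str = "[271,314,1642]";               // made-up .env value
        seeds_str = seeds_str.substr(1, seeds_str.size() - 2);  // drop '[' and ']'
        std::vector<int> seeds;
        auto parts = split(seeds_str, ',');
        std::transform(parts.begin(), parts.end(), std::back_inserter(seeds),
                       [](const std::string& str) { return std::stoi(str); });
        for (int s : seeds) {
            std::cout << s << " ";
        }
        std::cout << std::endl;
    }
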
@@ -3,7 +3,7 @@
#include "Datasets.h"
#include "Models.h"
#include "ReportConsole.h"
-#include "DotEnv.h"
+#include "Paths.h"
namespace platform {
    using json = nlohmann::json;
    string get_date()

@@ -134,8 +134,7 @@ namespace platform {
    }
    void Experiment::cross_validation(const string& fileName)
    {
-        auto env = platform::DotEnv();
-        auto datasets = platform::Datasets(discretized, env.get("source_data"));
+        auto datasets = platform::Datasets(discretized, Paths::datasets());
        // Get dataset
        auto [X, y] = datasets.getTensors(fileName);
        auto states = datasets.getStates(fileName);

@@ -1,11 +1,17 @@
#ifndef PATHS_H
#define PATHS_H
#include <string>
+#include "DotEnv.h"
namespace platform {
    class Paths {
    public:
        static std::string results() { return "results/"; }
        static std::string excel() { return "excel/"; }
+        static std::string datasets()
+        {
+            auto env = platform::DotEnv();
+            return env.get("source_data");
+        }
    };
}
#endif

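Paths::datasets() centralizes the source_data lookup that the call sites changed below previously performed with a local DotEnv. As written, every call constructs a new DotEnv, which presumably re-parses the .env file each time. A hedged alternative, not what the commit does, that caches the value in a function-local static; the DotEnv class here is a made-up stand-in with just a get() method:

    #include <iostream>
    #include <string>

    // Minimal stand-in for platform::DotEnv, only for illustration; the real class
    // reads key/value pairs from a .env file.
    class DotEnv {
    public:
        std::string get(const std::string&) { return "datasets/"; }  // made-up value
    };

    class Paths {
    public:
        static std::string datasets()
        {
            // Function-local static: the .env lookup happens once, on first use.
            static const std::string source = DotEnv().get("source_data");
            return source;
        }
    };

    int main()
    {
        std::cout << Paths::datasets() << std::endl;
    }
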
@@ -58,8 +58,7 @@ namespace platform {
            }
        } else {
            if (data["score_name"].get<string>() == "accuracy") {
-                auto env = platform::DotEnv();
-                auto dt = Datasets(false, env.get("source_data"));
+                auto dt = Datasets(false, Paths::datasets());
                dt.loadDataset(dataset);
                auto numClasses = dt.getNClasses(dataset);
                if (numClasses == 2) {

@@ -56,10 +56,12 @@ namespace platform {
            try {
                cout << r["hyperparameters"].get<string>();
            }
-            catch (const exception& err) {
-                cout << r["hyperparameters"];
+            catch (...) {
+                //cout << r["hyperparameters"];
+                cout << "Arrggggghhhh!" << endl;
            }
            cout << endl;
            cout << flush;
            lastResult = r;
            totalScore += r["score"].get<double>();
            odd = !odd;

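The hunk above swaps the catch (const exception&) fallback, which re-printed the raw json value, for a catch-all that only prints a placeholder; the try is there because r["hyperparameters"] is not always a string, so get<string>() can throw. A hedged alternative that avoids the exception entirely by checking the stored type first, assuming nlohmann::json and a made-up record:

    #include <iostream>
    #include <nlohmann/json.hpp>

    int main()
    {
        // Made-up result record where "hyperparameters" holds an object rather than
        // a string, which is exactly what makes get<string>() throw in the code above.
        nlohmann::json r = { { "hyperparameters", { { "cfs", "mdlp" } } } };
        const auto& h = r["hyperparameters"];
        if (h.is_string()) {
            std::cout << h.get<std::string>() << std::endl;
        } else {
            std::cout << h.dump() << std::endl;   // serialize any non-string value
        }
    }
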
src/Platform/Utils.h (new file, 19 lines)
@@ -0,0 +1,19 @@
+#ifndef UTILS_H
+#define UTILS_H
+#include <sstream>
+#include <string>
+#include <vector>
+namespace platform {
+    //static vector<string> split(const string& text, char delimiter);
+    static std::vector<std::string> split(const std::string& text, char delimiter)
+    {
+        std::vector<std::string> result;
+        std::stringstream ss(text);
+        std::string token;
+        while (std::getline(ss, token, delimiter)) {
+            result.push_back(token);
+        }
+        return result;
+    }
+}
+#endif

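One detail of the new header worth noting: split() is declared static at namespace scope, so every translation unit that includes Utils.h gets its own internal-linkage copy of the function (and possibly an unused-function warning where it is not called). A hedged alternative, not what the commit does, is to mark it inline, which keeps a single logical definition while still letting the body live in the header:

    #ifndef UTILS_H
    #define UTILS_H
    #include <sstream>
    #include <string>
    #include <vector>
    namespace platform {
        // Same body as the commit's helper, but inline instead of static.
        inline std::vector<std::string> split(const std::string& text, char delimiter)
        {
            std::vector<std::string> result;
            std::stringstream ss(text);
            std::string token;
            while (std::getline(ss, token, delimiter)) {
                result.push_back(token);
            }
            return result;
        }
    }
    #endif
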
@@ -3,7 +3,6 @@
#include "Paths.h"
#include "Colors.h"
#include "Datasets.h"
-#include "DotEnv.h"

using namespace std;
const int BALANCE_LENGTH = 75;

@@ -28,8 +27,7 @@ void outputBalance(const string& balance)

int main(int argc, char** argv)
{
-    auto env = platform::DotEnv();
-    auto data = platform::Datasets(false, env.get("source_data"));
+    auto data = platform::Datasets(false, platform::Paths::datasets());
    locale mylocale(cout.getloc(), new separated);
    locale::global(mylocale);
    cout.imbue(mylocale);

@@ -82,8 +82,7 @@ int main(int argc, char** argv)
    auto seeds = program.get<vector<int>>("seeds");
    auto hyperparameters = program.get<string>("hyperparameters");
    vector<string> filesToTest;
-    auto env = platform::DotEnv();
-    auto datasets = platform::Datasets(discretize_dataset, env.get("source_data"));
+    auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets());
    auto title = program.get<string>("title");
    auto saveResults = program.get<bool>("save");
    if (file_name != "") {

@@ -102,7 +101,7 @@ int main(int argc, char** argv)
    /*
    * Begin Processing
    */
+    auto env = platform::DotEnv();
    auto experiment = platform::Experiment();
    experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3");
    experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));