Begin adding cfs to BoostAODE
This commit is contained in:
parent
7d8aca4f59
commit
f288bbd6fa
@ -4,6 +4,7 @@
|
|||||||
#include "Colors.h"
|
#include "Colors.h"
|
||||||
#include "Folding.h"
|
#include "Folding.h"
|
||||||
#include <limits.h>
|
#include <limits.h>
|
||||||
|
#include "Paths.h"
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
BoostAODE::BoostAODE() : Ensemble() {}
|
BoostAODE::BoostAODE() : Ensemble() {}
|
||||||
@ -28,6 +29,9 @@ namespace bayesnet {
|
|||||||
if (hyperparameters.contains("convergence")) {
|
if (hyperparameters.contains("convergence")) {
|
||||||
convergence = hyperparameters["convergence"];
|
convergence = hyperparameters["convergence"];
|
||||||
}
|
}
|
||||||
|
if (hyperparameters.contains("cfs")) {
|
||||||
|
cfs = hyperparameters["cfs"];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
void BoostAODE::validationInit()
|
void BoostAODE::validationInit()
|
||||||
{
|
{
|
||||||
@ -58,6 +62,12 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
void BoostAODE::initializeModels()
|
||||||
|
{
|
||||||
|
ifstream file(cfs + ".json");
|
||||||
|
if (file.is_open()) {
|
||||||
|
}
|
||||||
|
}
|
||||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
models.clear();
|
models.clear();
|
||||||
@ -66,6 +76,9 @@ namespace bayesnet {
|
|||||||
maxModels = .1 * n > 10 ? .1 * n : n;
|
maxModels = .1 * n > 10 ? .1 * n : n;
|
||||||
validationInit();
|
validationInit();
|
||||||
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
||||||
|
if (cfs != "") {
|
||||||
|
initializeModels();
|
||||||
|
}
|
||||||
bool exitCondition = false;
|
bool exitCondition = false;
|
||||||
unordered_set<int> featuresUsed;
|
unordered_set<int> featuresUsed;
|
||||||
// Variables to control the accuracy finish condition
|
// Variables to control the accuracy finish condition
|
||||||
|
@ -16,10 +16,13 @@ namespace bayesnet {
|
|||||||
torch::Tensor dataset_;
|
torch::Tensor dataset_;
|
||||||
torch::Tensor X_train, y_train, X_test, y_test;
|
torch::Tensor X_train, y_train, X_test, y_test;
|
||||||
void validationInit();
|
void validationInit();
|
||||||
bool repeatSparent = false;
|
void initializeModels();
|
||||||
|
// Hyperparameters
|
||||||
|
bool repeatSparent = false; // if true, a feature can be selected more than once
|
||||||
int maxModels = 0;
|
int maxModels = 0;
|
||||||
bool ascending = false; //Process KBest features ascending or descending order
|
bool ascending = false; //Process KBest features ascending or descending order
|
||||||
bool convergence = false; //if true, stop when the model does not improve
|
bool convergence = false; //if true, stop when the model does not improve
|
||||||
|
string cfs = ""; // if not empty, use CFS to select features
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -212,14 +212,4 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return Xd;
|
return Xd;
|
||||||
}
|
}
|
||||||
// Splits `text` into the pieces separated by `delimiter` and returns them in
// order of appearance. Pieces are read with getline, so an empty `text`
// produces an empty vector and a single trailing delimiter does not add a
// trailing empty piece.
vector<string> Dataset::split(const string& text, char delimiter)
{
    vector<string> pieces;
    stringstream input(text);
    for (string piece; getline(input, piece, delimiter);) {
        pieces.push_back(piece);
    }
    return pieces;
}
|
|
||||||
}
|
}
|
@ -5,6 +5,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "CPPFImdlp.h"
|
#include "CPPFImdlp.h"
|
||||||
|
#include "Utils.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
@ -62,7 +63,6 @@ namespace platform {
|
|||||||
public:
|
public:
|
||||||
Dataset(const string& path, const string& name, const string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {};
|
Dataset(const string& path, const string& name, const string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {};
|
||||||
explicit Dataset(const Dataset&);
|
explicit Dataset(const Dataset&);
|
||||||
static vector<string> split(const string& text, char delimiter);
|
|
||||||
string getName() const;
|
string getName() const;
|
||||||
string getClassName() const;
|
string getClassName() const;
|
||||||
vector<string> getFeatures() const;
|
vector<string> getFeatures() const;
|
||||||
|
@ -13,7 +13,7 @@ namespace platform {
|
|||||||
if (line.empty() || line[0] == '#') {
|
if (line.empty() || line[0] == '#') {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
vector<string> tokens = Dataset::split(line, ',');
|
vector<string> tokens = split(line, ',');
|
||||||
string name = tokens[0];
|
string name = tokens[0];
|
||||||
string className;
|
string className;
|
||||||
if (tokens.size() == 1) {
|
if (tokens.size() == 1) {
|
||||||
|
@ -4,7 +4,10 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include "Dataset.h"
|
#include <iostream>
|
||||||
|
#include "Utils.h"
|
||||||
|
|
||||||
|
//#include "Dataset.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
class DotEnv {
|
class DotEnv {
|
||||||
private:
|
private:
|
||||||
@ -51,7 +54,7 @@ namespace platform {
|
|||||||
auto seeds_str = env["seeds"];
|
auto seeds_str = env["seeds"];
|
||||||
seeds_str = trim(seeds_str);
|
seeds_str = trim(seeds_str);
|
||||||
seeds_str = seeds_str.substr(1, seeds_str.size() - 2);
|
seeds_str = seeds_str.substr(1, seeds_str.size() - 2);
|
||||||
auto seeds_str_split = Dataset::split(seeds_str, ',');
|
auto seeds_str_split = split(seeds_str, ',');
|
||||||
transform(seeds_str_split.begin(), seeds_str_split.end(), back_inserter(seeds), [](const std::string& str) {
|
transform(seeds_str_split.begin(), seeds_str_split.end(), back_inserter(seeds), [](const std::string& str) {
|
||||||
return stoi(str);
|
return stoi(str);
|
||||||
});
|
});
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include "Datasets.h"
|
#include "Datasets.h"
|
||||||
#include "Models.h"
|
#include "Models.h"
|
||||||
#include "ReportConsole.h"
|
#include "ReportConsole.h"
|
||||||
#include "DotEnv.h"
|
#include "Paths.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
string get_date()
|
string get_date()
|
||||||
@ -134,8 +134,7 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
void Experiment::cross_validation(const string& fileName)
|
void Experiment::cross_validation(const string& fileName)
|
||||||
{
|
{
|
||||||
auto env = platform::DotEnv();
|
auto datasets = platform::Datasets(discretized, Paths::datasets());
|
||||||
auto datasets = platform::Datasets(discretized, env.get("source_data"));
|
|
||||||
// Get dataset
|
// Get dataset
|
||||||
auto [X, y] = datasets.getTensors(fileName);
|
auto [X, y] = datasets.getTensors(fileName);
|
||||||
auto states = datasets.getStates(fileName);
|
auto states = datasets.getStates(fileName);
|
||||||
|
@ -1,11 +1,17 @@
|
|||||||
#ifndef PATHS_H
|
#ifndef PATHS_H
|
||||||
#define PATHS_H
|
#define PATHS_H
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include "DotEnv.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
class Paths {
|
class Paths {
|
||||||
public:
|
public:
|
||||||
static std::string results() { return "results/"; }
|
static std::string results() { return "results/"; }
|
||||||
static std::string excel() { return "excel/"; }
|
static std::string excel() { return "excel/"; }
|
||||||
|
static std::string datasets()
|
||||||
|
{
|
||||||
|
auto env = platform::DotEnv();
|
||||||
|
return env.get("source_data");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -58,8 +58,7 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (data["score_name"].get<string>() == "accuracy") {
|
if (data["score_name"].get<string>() == "accuracy") {
|
||||||
auto env = platform::DotEnv();
|
auto dt = Datasets(false, Paths::datasets());
|
||||||
auto dt = Datasets(false, env.get("source_data"));
|
|
||||||
dt.loadDataset(dataset);
|
dt.loadDataset(dataset);
|
||||||
auto numClasses = dt.getNClasses(dataset);
|
auto numClasses = dt.getNClasses(dataset);
|
||||||
if (numClasses == 2) {
|
if (numClasses == 2) {
|
||||||
|
@ -56,10 +56,12 @@ namespace platform {
|
|||||||
try {
|
try {
|
||||||
cout << r["hyperparameters"].get<string>();
|
cout << r["hyperparameters"].get<string>();
|
||||||
}
|
}
|
||||||
catch (const exception& err) {
|
catch (...) {
|
||||||
cout << r["hyperparameters"];
|
//cout << r["hyperparameters"];
|
||||||
|
cout << "Arrggggghhhh!" << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
|
cout << flush;
|
||||||
lastResult = r;
|
lastResult = r;
|
||||||
totalScore += r["score"].get<double>();
|
totalScore += r["score"].get<double>();
|
||||||
odd = !odd;
|
odd = !odd;
|
||||||
|
19
src/Platform/Utils.h
Normal file
19
src/Platform/Utils.h
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
#ifndef UTILS_H
|
||||||
|
#define UTILS_H
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
namespace platform {
|
||||||
|
//static vector<string> split(const string& text, char delimiter);
|
||||||
|
/**
 * @brief Split a string into tokens on a single-character delimiter.
 *
 * Tokens are extracted with std::getline, which gives these edge cases:
 *  - an empty input yields an empty vector;
 *  - a leading delimiter or two consecutive delimiters yield an empty token;
 *  - a single trailing delimiter does NOT yield a trailing empty token.
 *
 * Declared `inline` instead of `static`: this definition lives in a header,
 * and `static` would give every translation unit that includes it a private
 * copy (and an unused-function warning wherever it is never called).
 *
 * @param text      String to split.
 * @param delimiter Character that separates tokens.
 * @return The tokens, in order of appearance.
 */
inline std::vector<std::string> split(const std::string& text, char delimiter)
{
    std::vector<std::string> result;
    std::stringstream ss(text);
    std::string token;
    while (std::getline(ss, token, delimiter)) {
        result.push_back(token);
    }
    return result;
}
|
||||||
|
}
|
||||||
|
#endif
|
@ -3,7 +3,6 @@
|
|||||||
#include "Paths.h"
|
#include "Paths.h"
|
||||||
#include "Colors.h"
|
#include "Colors.h"
|
||||||
#include "Datasets.h"
|
#include "Datasets.h"
|
||||||
#include "DotEnv.h"
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
const int BALANCE_LENGTH = 75;
|
const int BALANCE_LENGTH = 75;
|
||||||
@ -28,8 +27,7 @@ void outputBalance(const string& balance)
|
|||||||
|
|
||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
auto env = platform::DotEnv();
|
auto data = platform::Datasets(false, platform::Paths::datasets());
|
||||||
auto data = platform::Datasets(false, env.get("source_data"));
|
|
||||||
locale mylocale(cout.getloc(), new separated);
|
locale mylocale(cout.getloc(), new separated);
|
||||||
locale::global(mylocale);
|
locale::global(mylocale);
|
||||||
cout.imbue(mylocale);
|
cout.imbue(mylocale);
|
||||||
|
@ -82,8 +82,7 @@ int main(int argc, char** argv)
|
|||||||
auto seeds = program.get<vector<int>>("seeds");
|
auto seeds = program.get<vector<int>>("seeds");
|
||||||
auto hyperparameters = program.get<string>("hyperparameters");
|
auto hyperparameters = program.get<string>("hyperparameters");
|
||||||
vector<string> filesToTest;
|
vector<string> filesToTest;
|
||||||
auto env = platform::DotEnv();
|
auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets());
|
||||||
auto datasets = platform::Datasets(discretize_dataset, env.get("source_data"));
|
|
||||||
auto title = program.get<string>("title");
|
auto title = program.get<string>("title");
|
||||||
auto saveResults = program.get<bool>("save");
|
auto saveResults = program.get<bool>("save");
|
||||||
if (file_name != "") {
|
if (file_name != "") {
|
||||||
@ -102,7 +101,7 @@ int main(int argc, char** argv)
|
|||||||
/*
|
/*
|
||||||
* Begin Processing
|
* Begin Processing
|
||||||
*/
|
*/
|
||||||
|
auto env = platform::DotEnv();
|
||||||
auto experiment = platform::Experiment();
|
auto experiment = platform::Experiment();
|
||||||
experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3");
|
experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3");
|
||||||
experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));
|
experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));
|
||||||
|
Loading…
Reference in New Issue
Block a user