Begin CFS initialization

This commit is contained in:
Ricardo Montañana Gómez 2023-10-10 13:39:11 +02:00
parent f288bbd6fa
commit df9b4c48d2
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
4 changed files with 57 additions and 63 deletions

View File

@@ -11,30 +11,7 @@ namespace bayesnet {
void BoostAODE::buildModel(const torch::Tensor& weights) void BoostAODE::buildModel(const torch::Tensor& weights)
{ {
// Models shall be built in trainModel // Models shall be built in trainModel
} // Prepare the validation dataset
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
{
// Check if hyperparameters are valid
const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending", "convergence" };
checkHyperparameters(validKeys, hyperparameters);
if (hyperparameters.contains("repeatSparent")) {
repeatSparent = hyperparameters["repeatSparent"];
}
if (hyperparameters.contains("maxModels")) {
maxModels = hyperparameters["maxModels"];
}
if (hyperparameters.contains("ascending")) {
ascending = hyperparameters["ascending"];
}
if (hyperparameters.contains("convergence")) {
convergence = hyperparameters["convergence"];
}
if (hyperparameters.contains("cfs")) {
cfs = hyperparameters["cfs"];
}
}
void BoostAODE::validationInit()
{
auto y_ = dataset.index({ -1, "..." }); auto y_ = dataset.index({ -1, "..." });
if (convergence) { if (convergence) {
// Prepare train & validation sets from train data // Prepare train & validation sets from train data
@@ -60,12 +37,43 @@ namespace bayesnet {
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." }); X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
y_train = y_; y_train = y_;
} }
if (cfs != "") {
initializeModels();
}
}
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
{
    // Validate the supplied hyperparameters and copy each recognized key
    // into its corresponding member; keys that are absent keep their defaults.
    const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending", "convergence", "cfs" };
    checkHyperparameters(validKeys, hyperparameters);
    // Generic setter: assigns hyperparameters[key] to the member only when present.
    auto assignIfPresent = [&hyperparameters](const string& key, auto& member) {
        if (hyperparameters.contains(key)) {
            member = hyperparameters[key];
        }
    };
    assignIfPresent("repeatSparent", repeatSparent);
    assignIfPresent("maxModels", maxModels);
    assignIfPresent("ascending", ascending);
    assignIfPresent("convergence", convergence);
    assignIfPresent("cfs", cfs);
}
void BoostAODE::initializeModels() void BoostAODE::initializeModels()
{ {
ifstream file(cfs + ".json"); ifstream file(cfs + ".json");
if (file.is_open()) { if (file.is_open()) {
nlohmann::json data;
file >> data;
file.close();
auto model = "iris"; // has to come in when building object
auto features = data[model];
cout << "features: " << features.dump() << endl;
} else {
throw runtime_error("File " + cfs + ".json not found");
} }
} }
void BoostAODE::trainModel(const torch::Tensor& weights) void BoostAODE::trainModel(const torch::Tensor& weights)
@@ -74,11 +82,7 @@ namespace bayesnet {
n_models = 0; n_models = 0;
if (maxModels == 0) if (maxModels == 0)
maxModels = .1 * n > 10 ? .1 * n : n; maxModels = .1 * n > 10 ? .1 * n : n;
validationInit();
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64); Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
if (cfs != "") {
initializeModels();
}
bool exitCondition = false; bool exitCondition = false;
unordered_set<int> featuresUsed; unordered_set<int> featuresUsed;
// Variables to control the accuracy finish condition // Variables to control the accuracy finish condition

View File

@@ -15,7 +15,6 @@ namespace bayesnet {
private: private:
torch::Tensor dataset_; torch::Tensor dataset_;
torch::Tensor X_train, y_train, X_test, y_test; torch::Tensor X_train, y_train, X_test, y_test;
void validationInit();
void initializeModels(); void initializeModels();
// Hyperparameters // Hyperparameters
bool repeatSparent = false; // if true, a feature can be selected more than once bool repeatSparent = false; // if true, a feature can be selected more than once

View File

@@ -53,13 +53,7 @@ namespace platform {
const string status = compareResult(r["dataset"].get<string>(), r["score"].get<double>()); const string status = compareResult(r["dataset"].get<string>(), r["score"].get<double>());
cout << status; cout << status;
cout << setw(12) << right << setprecision(6) << fixed << r["time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get<double>() << " "; cout << setw(12) << right << setprecision(6) << fixed << r["time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get<double>() << " ";
try { cout << r["hyperparameters"].dump();
cout << r["hyperparameters"].get<string>();
}
catch (...) {
//cout << r["hyperparameters"];
cout << "Arrggggghhhh!" << endl;
}
cout << endl; cout << endl;
cout << flush; cout << flush;
lastResult = r; lastResult = r;

View File

@@ -12,7 +12,7 @@
using namespace std; using namespace std;
using json = nlohmann::json; using json = nlohmann::json;
argparse::ArgumentParser manageArguments(int argc, char** argv) argparse::ArgumentParser manageArguments()
{ {
auto env = platform::DotEnv(); auto env = platform::DotEnv();
argparse::ArgumentParser program("main"); argparse::ArgumentParser program("main");
@@ -48,43 +48,40 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
}}); }});
auto seed_values = env.getSeeds(); auto seed_values = env.getSeeds();
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values); program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
return program;
}
int main(int argc, char** argv)
{
string file_name, model_name, title;
json hyperparameters_json;
bool discretize_dataset, stratified, saveResults;
vector<int> seeds;
vector<string> filesToTest;
int n_folds;
auto program = manageArguments();
try { try {
program.parse_args(argc, argv); program.parse_args(argc, argv);
auto file_name = program.get<string>("dataset"); file_name = program.get<string>("dataset");
auto model_name = program.get<string>("model"); model_name = program.get<string>("model");
auto discretize_dataset = program.get<bool>("discretize"); discretize_dataset = program.get<bool>("discretize");
auto stratified = program.get<bool>("stratified"); stratified = program.get<bool>("stratified");
auto n_folds = program.get<int>("folds"); n_folds = program.get<int>("folds");
auto seeds = program.get<vector<int>>("seeds"); seeds = program.get<vector<int>>("seeds");
auto title = program.get<string>("title");
auto hyperparameters = program.get<string>("hyperparameters"); auto hyperparameters = program.get<string>("hyperparameters");
auto saveResults = program.get<bool>("save"); hyperparameters_json = json::parse(hyperparameters);
title = program.get<string>("title");
if (title == "" && file_name == "") { if (title == "" && file_name == "") {
throw runtime_error("title is mandatory if dataset is not provided"); throw runtime_error("title is mandatory if dataset is not provided");
} }
saveResults = program.get<bool>("save");
} }
catch (const exception& err) { catch (const exception& err) {
cerr << err.what() << endl; cerr << err.what() << endl;
cerr << program; cerr << program;
exit(1); exit(1);
} }
return program;
}
int main(int argc, char** argv)
{
auto program = manageArguments(argc, argv);
auto file_name = program.get<string>("dataset");
auto model_name = program.get<string>("model");
auto discretize_dataset = program.get<bool>("discretize");
auto stratified = program.get<bool>("stratified");
auto n_folds = program.get<int>("folds");
auto seeds = program.get<vector<int>>("seeds");
auto hyperparameters = program.get<string>("hyperparameters");
vector<string> filesToTest;
auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets()); auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets());
auto title = program.get<string>("title");
auto saveResults = program.get<bool>("save");
if (file_name != "") { if (file_name != "") {
if (!datasets.isDataset(file_name)) { if (!datasets.isDataset(file_name)) {
cerr << "Dataset " << file_name << " not found" << endl; cerr << "Dataset " << file_name << " not found" << endl;
@@ -106,7 +103,7 @@ int main(int argc, char** argv)
experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3"); experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3");
experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform")); experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));
experiment.setStratified(stratified).setNFolds(n_folds).setScoreName("accuracy"); experiment.setStratified(stratified).setNFolds(n_folds).setScoreName("accuracy");
experiment.setHyperparameters(json::parse(hyperparameters)); experiment.setHyperparameters(hyperparameters_json);
for (auto seed : seeds) { for (auto seed : seeds) {
experiment.addRandomSeed(seed); experiment.addRandomSeed(seed);
} }