Add continue from parameter to gridsearch
This commit is contained in:
parent
64069a6cb7
commit
c713c0b1df
@ -7,6 +7,26 @@
|
|||||||
#include "Colors.h"
|
#include "Colors.h"
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
|
std::string get_date()
|
||||||
|
{
|
||||||
|
time_t rawtime;
|
||||||
|
tm* timeinfo;
|
||||||
|
time(&rawtime);
|
||||||
|
timeinfo = std::localtime(&rawtime);
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << std::put_time(timeinfo, "%Y-%m-%d");
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
std::string get_time()
|
||||||
|
{
|
||||||
|
time_t rawtime;
|
||||||
|
tm* timeinfo;
|
||||||
|
time(&rawtime);
|
||||||
|
timeinfo = std::localtime(&rawtime);
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << std::put_time(timeinfo, "%H:%M:%S");
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
|
GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
|
||||||
{
|
{
|
||||||
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
|
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
|
||||||
@ -43,7 +63,6 @@ namespace platform {
|
|||||||
auto [X, y] = datasets.getTensors(fileName);
|
auto [X, y] = datasets.getTensors(fileName);
|
||||||
auto states = datasets.getStates(fileName);
|
auto states = datasets.getStates(fileName);
|
||||||
auto features = datasets.getFeatures(fileName);
|
auto features = datasets.getFeatures(fileName);
|
||||||
auto samples = datasets.getNSamples(fileName);
|
|
||||||
auto className = datasets.getClassName(fileName);
|
auto className = datasets.getClassName(fileName);
|
||||||
double totalScore = 0.0;
|
double totalScore = 0.0;
|
||||||
int numItems = 0;
|
int numItems = 0;
|
||||||
@ -86,6 +105,33 @@ namespace platform {
|
|||||||
{
|
{
|
||||||
// Load datasets
|
// Load datasets
|
||||||
auto datasets = Datasets(config.discretize, Paths::datasets());
|
auto datasets = Datasets(config.discretize, Paths::datasets());
|
||||||
|
// Load previous results
|
||||||
|
json results;
|
||||||
|
auto datasets_names = datasets.getNames();
|
||||||
|
if (config.continue_from != "no") {
|
||||||
|
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
|
||||||
|
throw std::invalid_argument("Dataset " + config.continue_from + " not found");
|
||||||
|
}
|
||||||
|
if (!config.quiet)
|
||||||
|
std::cout << "* Loading previous results" << std::endl;
|
||||||
|
try {
|
||||||
|
std::ifstream file(config.output_file);
|
||||||
|
if (file.is_open()) {
|
||||||
|
results = json::parse(file);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (const std::exception& e) {
|
||||||
|
std::cerr << "Error loading previous results: " << e.what() << std::endl;
|
||||||
|
}
|
||||||
|
// Remove datasets already processed
|
||||||
|
vector< string >::iterator it = datasets_names.begin();
|
||||||
|
while (it != datasets_names.end()) {
|
||||||
|
if (*it != config.continue_from) {
|
||||||
|
it = datasets_names.erase(it);
|
||||||
|
} else
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
// Create model
|
// Create model
|
||||||
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
|
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
|
||||||
std::cout << "input file=" << config.input_file << std::endl;
|
std::cout << "input file=" << config.input_file << std::endl;
|
||||||
@ -94,7 +140,7 @@ namespace platform {
|
|||||||
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
|
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
|
||||||
// Generate hyperparameters grid & run gridsearch
|
// Generate hyperparameters grid & run gridsearch
|
||||||
// Check each combination of hyperparameters for each dataset and each seed
|
// Check each combination of hyperparameters for each dataset and each seed
|
||||||
for (const auto& dataset : datasets.getNames()) {
|
for (const auto& dataset : datasets_names) {
|
||||||
if (!config.quiet)
|
if (!config.quiet)
|
||||||
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
||||||
int num = 0;
|
int num = 0;
|
||||||
@ -116,15 +162,17 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
results[dataset]["score"] = bestScore;
|
results[dataset]["score"] = bestScore;
|
||||||
results[dataset]["hyperparameters"] = bestHyperparameters;
|
results[dataset]["hyperparameters"] = bestHyperparameters;
|
||||||
|
results[dataset]["date"] = get_date() + " " + get_time();
|
||||||
|
// Save partial results
|
||||||
|
save(results);
|
||||||
}
|
}
|
||||||
// Save results
|
// Save final results
|
||||||
save();
|
save(results);
|
||||||
std::cout << "***************** Ending Gridsearch *******************" << std::endl;
|
std::cout << "***************** Ending Gridsearch *******************" << std::endl;
|
||||||
}
|
}
|
||||||
void GridSearch::save() const
|
void GridSearch::save(json& results) const
|
||||||
{
|
{
|
||||||
std::ofstream file(config.output_file);
|
std::ofstream file(config.output_file);
|
||||||
file << results.dump(4);
|
file << results.dump(4);
|
||||||
file.close();
|
|
||||||
}
|
}
|
||||||
} /* namespace platform */
|
} /* namespace platform */
|
@ -15,6 +15,7 @@ namespace platform {
|
|||||||
std::string path;
|
std::string path;
|
||||||
std::string input_file;
|
std::string input_file;
|
||||||
std::string output_file;
|
std::string output_file;
|
||||||
|
std::string continue_from;
|
||||||
bool quiet;
|
bool quiet;
|
||||||
bool discretize;
|
bool discretize;
|
||||||
bool stratified;
|
bool stratified;
|
||||||
@ -25,11 +26,10 @@ namespace platform {
|
|||||||
public:
|
public:
|
||||||
explicit GridSearch(struct ConfigGrid& config);
|
explicit GridSearch(struct ConfigGrid& config);
|
||||||
void go();
|
void go();
|
||||||
void save() const;
|
|
||||||
~GridSearch() = default;
|
~GridSearch() = default;
|
||||||
private:
|
private:
|
||||||
|
void save(json& results) const;
|
||||||
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
||||||
json results;
|
|
||||||
struct ConfigGrid config;
|
struct ConfigGrid config;
|
||||||
};
|
};
|
||||||
} /* namespace platform */
|
} /* namespace platform */
|
||||||
|
@ -32,5 +32,4 @@ namespace platform {
|
|||||||
bool complete;
|
bool complete;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
@ -7,7 +7,6 @@
|
|||||||
#include "Result.h"
|
#include "Result.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
|
|
||||||
class Results {
|
class Results {
|
||||||
public:
|
public:
|
||||||
Results(const std::string& path, const std::string& model, const std::string& score, bool complete, bool partial);
|
Results(const std::string& path, const std::string& model, const std::string& score, bool complete, bool partial);
|
||||||
@ -34,5 +33,4 @@ namespace platform {
|
|||||||
void load(); // Loads the list of results
|
void load(); // Loads the list of results
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
@ -24,6 +24,7 @@ argparse::ArgumentParser manageArguments(std::string program_name)
|
|||||||
);
|
);
|
||||||
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||||
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||||
|
program.add_argument("--continue").help("Continue computing from that dataset").default_value("No");
|
||||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||||
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
||||||
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
|
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
|
||||||
@ -58,6 +59,7 @@ int main(int argc, char** argv)
|
|||||||
config.n_folds = program.get<int>("folds");
|
config.n_folds = program.get<int>("folds");
|
||||||
config.quiet = program.get<bool>("quiet");
|
config.quiet = program.get<bool>("quiet");
|
||||||
config.seeds = program.get<std::vector<int>>("seeds");
|
config.seeds = program.get<std::vector<int>>("seeds");
|
||||||
|
config.continue_from = program.get<std::string>("continue");
|
||||||
}
|
}
|
||||||
catch (const exception& err) {
|
catch (const exception& err) {
|
||||||
cerr << err.what() << std::endl;
|
cerr << err.what() << std::endl;
|
||||||
@ -75,7 +77,6 @@ int main(int argc, char** argv)
|
|||||||
timer.start();
|
timer.start();
|
||||||
grid_search.go();
|
grid_search.go();
|
||||||
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
||||||
grid_search.save();
|
|
||||||
std::cout << "Done!" << std::endl;
|
std::cout << "Done!" << std::endl;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user