Add show experiment

This commit is contained in:
Ricardo Montañana Gómez 2023-07-29 16:31:36 +02:00
parent 85cb447283
commit b9e76becce
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 31 additions and 10 deletions

View File

@ -43,7 +43,7 @@ namespace platform {
result["discretized"] = discretized; result["discretized"] = discretized;
result["stratified"] = stratified; result["stratified"] = stratified;
result["folds"] = nfolds; result["folds"] = nfolds;
result["seeds"] = random_seeds; result["seeds"] = randomSeeds;
result["duration"] = duration; result["duration"] = duration;
result["results"] = json::array(); result["results"] = json::array();
for (auto& r : results) { for (auto& r : results) {
@ -65,6 +65,10 @@ namespace platform {
j["test_time_std"] = r.getTestTimeStd(); j["test_time_std"] = r.getTestTimeStd();
j["time"] = r.getTestTime() + r.getTrainTime(); j["time"] = r.getTestTime() + r.getTrainTime();
j["time_std"] = r.getTestTimeStd() + r.getTrainTimeStd(); j["time_std"] = r.getTestTimeStd() + r.getTrainTimeStd();
j["scores_train"] = r.getScoresTrain();
j["scores_test"] = r.getScoresTest();
j["times_train"] = r.getTimesTrain();
j["times_test"] = r.getTimesTest();
j["nodes"] = r.getNodes(); j["nodes"] = r.getNodes();
j["leaves"] = r.getLeaves(); j["leaves"] = r.getLeaves();
j["depth"] = r.getDepth(); j["depth"] = r.getDepth();
@ -79,6 +83,11 @@ namespace platform {
file << data; file << data;
file.close(); file.close();
} }
void Experiment::show()
{
json data = build_json();
cout << data.dump(4) << endl;
}
Result cross_validation(Fold* fold, string model_name, torch::Tensor& Xt, torch::Tensor& y, vector<string> features, string className, map<string, vector<int>> states) Result cross_validation(Fold* fold, string model_name, torch::Tensor& Xt, torch::Tensor& y, vector<string> features, string className, map<string, vector<int>> states)
{ {
auto classifiers = map<string, bayesnet::BaseClassifier*>({ auto classifiers = map<string, bayesnet::BaseClassifier*>({
@ -101,7 +110,6 @@ namespace platform {
cout << "doing Fold: " << flush; cout << "doing Fold: " << flush;
for (int i = 0; i < k; i++) { for (int i = 0; i < k; i++) {
bayesnet::BaseClassifier* model = classifiers[model_name]; bayesnet::BaseClassifier* model = classifiers[model_name];
result.setModelVersion(model->getVersion());
train_timer.start(); train_timer.start();
auto [train, test] = fold->getFold(i); auto [train, test] = fold->getFold(i);
auto train_t = torch::tensor(train); auto train_t = torch::tensor(train);
@ -122,8 +130,9 @@ namespace platform {
test_time[i] = test_timer.getDuration(); test_time[i] = test_timer.getDuration();
accuracy_train[i] = accuracy_train_value; accuracy_train[i] = accuracy_train_value;
accuracy_test[i] = accuracy_test_value; accuracy_test[i] = accuracy_test_value;
} }
cout << "end." << endl; cout << "end. " << flush;
result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>()); result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>()); result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>()); result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());

View File

@ -33,6 +33,7 @@ namespace platform {
int samples, features, classes; int samples, features, classes;
double score_train, score_test, score_train_std, score_test_std, train_time, train_time_std, test_time, test_time_std; double score_train, score_test, score_train_std, score_test_std, train_time, train_time_std, test_time, test_time_std;
float nodes, leaves, depth; float nodes, leaves, depth;
vector<double> scores_train, scores_test, times_train, times_test;
public: public:
Result() = default; Result() = default;
Result& setDataset(string dataset) { this->dataset = dataset; return *this; } Result& setDataset(string dataset) { this->dataset = dataset; return *this; }
@ -51,7 +52,10 @@ namespace platform {
Result& setNodes(float nodes) { this->nodes = nodes; return *this; } Result& setNodes(float nodes) { this->nodes = nodes; return *this; }
Result& setLeaves(float leaves) { this->leaves = leaves; return *this; } Result& setLeaves(float leaves) { this->leaves = leaves; return *this; }
Result& setDepth(float depth) { this->depth = depth; return *this; } Result& setDepth(float depth) { this->depth = depth; return *this; }
Result& setModelVersion(string model_version) { this->model_version = model_version; return *this; } Result& addScoreTrain(double score) { scores_train.push_back(score); return *this; }
Result& addScoreTest(double score) { scores_test.push_back(score); return *this; }
Result& addTimeTrain(double time) { times_train.push_back(time); return *this; }
Result& addTimeTest(double time) { times_test.push_back(time); return *this; }
const float get_score_train() const { return score_train; } const float get_score_train() const { return score_train; }
float get_score_test() { return score_test; } float get_score_test() { return score_test; }
const string& getDataset() const { return dataset; } const string& getDataset() const { return dataset; }
@ -70,14 +74,17 @@ namespace platform {
const float getNodes() const { return nodes; } const float getNodes() const { return nodes; }
const float getLeaves() const { return leaves; } const float getLeaves() const { return leaves; }
const float getDepth() const { return depth; } const float getDepth() const { return depth; }
const string& getModelVersion() const { return model_version; } const vector<double>& getScoresTrain() const { return scores_train; }
const vector<double>& getScoresTest() const { return scores_test; }
const vector<double>& getTimesTrain() const { return times_train; }
const vector<double>& getTimesTest() const { return times_test; }
}; };
class Experiment { class Experiment {
private: private:
string title, model, platform, score_name, model_version, language_version, language; string title, model, platform, score_name, model_version, language_version, language;
bool discretized, stratified; bool discretized, stratified;
vector<Result> results; vector<Result> results;
vector<int> random_seeds; vector<int> randomSeeds;
int nfolds; int nfolds;
float duration; float duration;
json build_json(); json build_json();
@ -94,11 +101,12 @@ namespace platform {
Experiment& setStratified(bool stratified) { this->stratified = stratified; return *this; } Experiment& setStratified(bool stratified) { this->stratified = stratified; return *this; }
Experiment& setNFolds(int nfolds) { this->nfolds = nfolds; return *this; } Experiment& setNFolds(int nfolds) { this->nfolds = nfolds; return *this; }
Experiment& addResult(Result result) { results.push_back(result); return *this; } Experiment& addResult(Result result) { results.push_back(result); return *this; }
Experiment& addRandomSeed(int random_seed) { random_seeds.push_back(random_seed); return *this; } Experiment& addRandomSeed(int randomSeed) { randomSeeds.push_back(randomSeed); return *this; }
Experiment& setDuration(float duration) { this->duration = duration; return *this; } Experiment& setDuration(float duration) { this->duration = duration; return *this; }
string get_file_name(); string get_file_name();
void save(string path); void save(string path);
void show() { cout << "Showing experiment..." << "Score Test: " << results[0].get_score_test() << " Score Train: " << results[0].get_score_train() << endl; } Result cross_validation(const string& path, const string& fileName);
void show();
}; };
Result cross_validation(Fold* fold, string model_name, torch::Tensor& X, torch::Tensor& y, vector<string> features, string className, map<string, vector<int>> states); Result cross_validation(Fold* fold, string model_name, torch::Tensor& X, torch::Tensor& y, vector<string> features, string className, map<string, vector<int>> states);
} }

View File

@ -135,13 +135,17 @@ int main(int argc, char** argv)
fold = new KFold(n_folds, samples, seed); fold = new KFold(n_folds, samples, seed);
auto result = platform::cross_validation(fold, model_name, X, y, features, className, states); auto result = platform::cross_validation(fold, model_name, X, y, features, className, states);
result.setDataset(fileName); result.setDataset(fileName);
experiment.setModelVersion(result.getModelVersion()); experiment.setModelVersion("-FIXME-");
experiment.addResult(result); experiment.addResult(result);
delete fold; delete fold;
} }
cout << endl;
} }
experiment.setDuration(timer.getDuration()); experiment.setDuration(timer.getDuration());
if (saveResults)
experiment.save(PATH_RESULTS); experiment.save(PATH_RESULTS);
else
experiment.show();
cout << "Done!" << endl; cout << "Done!" << endl;
return 0; return 0;
} }