From 69b9609154557e89d9fe79a2a4b2d936a45559b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Fri, 10 May 2024 17:12:11 +0200 Subject: [PATCH] Add labels to confusion_matrices in results --- lib/Files/ArffFiles.cc | 2 ++ lib/Files/ArffFiles.h | 22 ++++++++++++---------- src/common/Dataset.cpp | 3 +++ src/common/Dataset.h | 30 ++++++++++++++++-------------- src/common/Datasets.cpp | 8 ++++++++ src/common/Datasets.h | 19 ++++++++++--------- src/main/Experiment.cpp | 5 +++-- tests/TestScores.cpp | 3 +++ 8 files changed, 57 insertions(+), 35 deletions(-) diff --git a/lib/Files/ArffFiles.cc b/lib/Files/ArffFiles.cc index 99f29bd..1460bfa 100644 --- a/lib/Files/ArffFiles.cc +++ b/lib/Files/ArffFiles.cc @@ -155,12 +155,14 @@ std::string ArffFiles::trim(const std::string& source) std::vector ArffFiles::factorize(const std::vector& labels_t) { std::vector yy; + labels.clear(); yy.reserve(labels_t.size()); std::map labelMap; int i = 0; for (const std::string& label : labels_t) { if (labelMap.find(label) == labelMap.end()) { labelMap[label] = i++; + labels.push_back(label); } yy.push_back(labelMap[label]); } diff --git a/lib/Files/ArffFiles.h b/lib/Files/ArffFiles.h index 25e5a8c..21caa05 100644 --- a/lib/Files/ArffFiles.h +++ b/lib/Files/ArffFiles.h @@ -5,15 +5,6 @@ #include class ArffFiles { -private: - std::vector lines; - std::vector> attributes; - std::string className; - std::string classType; - std::vector> X; - std::vector y; - void generateDataset(int); - void loadCommon(std::string); public: ArffFiles(); void load(const std::string&, bool = true); @@ -22,11 +13,22 @@ public: unsigned long int getSize() const; std::string getClassName() const; std::string getClassType() const; + std::vector getLabels() const { return labels; } static std::string trim(const std::string&); std::vector>& getX(); std::vector& getY(); std::vector> getAttributes() const; - static std::vector factorize(const std::vector& labels_t); + std::vector factorize(const std::vector& labels_t); +private: + std::vector lines; + std::vector> attributes; + std::string className; + std::string classType; + std::vector> X; + std::vector y; + std::vector labels; + void generateDataset(int); + void loadCommon(std::string); }; #endif \ No newline at end of file diff --git a/src/common/Dataset.cpp b/src/common/Dataset.cpp index cb16b0b..bb1bf94 100644 --- a/src/common/Dataset.cpp +++ b/src/common/Dataset.cpp @@ -91,6 +91,7 @@ namespace platform { } yv.push_back(stoi(tokens.back())); } + labels.clear(); file.close(); } else { throw std::invalid_argument("Unable to open dataset file."); @@ -117,6 +118,7 @@ namespace platform { className = arff.getClassName(); auto attributes = arff.getAttributes(); transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& attribute) { return attribute.first; }); + labels = arff.getLabels(); } std::vector tokenize(std::string line) { @@ -160,6 +162,7 @@ namespace platform { } yv.push_back(stoi(tokens.back())); } + labels.clear(); file.close(); } else { throw std::invalid_argument("Unable to open dataset file."); diff --git a/src/common/Dataset.h b/src/common/Dataset.h index 49ffc48..b88db75 100644 --- a/src/common/Dataset.h +++ b/src/common/Dataset.h @@ -10,6 +10,21 @@ namespace platform { class Dataset { + public: + Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {}; + explicit Dataset(const Dataset&); + std::string getName() const; + std::string getClassName() const; + std::vector getLabels() const { return labels; } + std::vector getFeatures() const; + std::map> getStates() const; + std::pair>&, std::vector&> getVectors(); + std::pair>&, std::vector&> getVectorsDiscretized(); + std::pair getTensors(); + int getNFeatures() const; + int getNSamples() const; + void load(); + const bool inline isLoaded() const { return loaded; }; private: std::string path; std::string name; @@ -17,6 +32,7 @@ namespace platform { std::string className; int n_samples{ 0 }, n_features{ 0 }; std::vector features; + std::vector labels; std::map> states; bool loaded; bool discretize; @@ -30,20 +46,6 @@ namespace platform { void load_rdata(); void computeStates(); std::vector discretizeDataset(std::vector& X, mdlp::labels_t& y); - public: - Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {}; - explicit Dataset(const Dataset&); - std::string getName() const; - std::string getClassName() const; - std::vector getFeatures() const; - std::map> getStates() const; - std::pair>&, std::vector&> getVectors(); - std::pair>&, std::vector&> getVectorsDiscretized(); - std::pair getTensors(); - int getNFeatures() const; - int getNSamples() const; - void load(); - const bool inline isLoaded() const { return loaded; }; }; }; diff --git a/src/common/Datasets.cpp b/src/common/Datasets.cpp index 4e68d8a..7e87542 100644 --- a/src/common/Datasets.cpp +++ b/src/common/Datasets.cpp @@ -42,6 +42,14 @@ namespace platform { throw std::invalid_argument("Dataset not loaded."); } } + std::vector Datasets::getLabels(const std::string& name) const + { + if (datasets.at(name)->isLoaded()) { + return datasets.at(name)->getLabels(); + } else { + throw std::invalid_argument("Dataset not loaded."); + } + } map> Datasets::getStates(const std::string& name) const { if (datasets.at(name)->isLoaded()) { diff --git a/src/common/Datasets.h b/src/common/Datasets.h index 4e8b5d2..07c99d6 100644 --- a/src/common/Datasets.h +++ b/src/common/Datasets.h @@ -3,18 +3,12 @@ #include "Dataset.h" namespace platform { class Datasets { - private: - std::string path; - fileType_t fileType; - std::string sfileType; - std::map> datasets; - bool discretize; - void load(); // Loads the list of datasets public: explicit Datasets(bool discretize, std::string sfileType) : discretize(discretize), sfileType(sfileType) { load(); }; - std::vector getNames(); - std::vector getFeatures(const std::string& name) const; + std::vector getNames(); + std::vector getFeatures(const std::string& name) const; int getNSamples(const std::string& name) const; + std::vector getLabels(const std::string& name) const; std::string getClassName(const std::string& name) const; int getNClasses(const std::string& name); std::vector getClassesCounts(const std::string& name) const; @@ -25,5 +19,12 @@ namespace platform { bool isDataset(const std::string& name) const; void loadDataset(const std::string& name) const; std::string toString() const; + private: + std::string path; + fileType_t fileType; + std::string sfileType; + std::map> datasets; + bool discretize; + void load(); // Loads the list of datasets }; }; diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 1f67592..e006821 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -83,6 +83,7 @@ namespace platform { auto features = datasets.getFeatures(fileName); auto samples = datasets.getNSamples(fileName); auto className = datasets.getClassName(fileName); + auto labels = datasets.getLabels(fileName); if (!quiet) { std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush; } @@ -156,12 +157,12 @@ namespace platform { showProgress(nfold + 1, getColor(clf->getStatus()), "c"); test_timer.start(); auto y_predict = clf->predict(X_test); - Scores scores(y_test, y_predict, states[className].size()); + Scores scores(y_test, y_predict, states[className].size(), labels); auto accuracy_test_value = scores.accuracy(); test_time[item] = test_timer.getDuration(); accuracy_train[item] = accuracy_train_value; accuracy_test[item] = accuracy_test_value; - confusion_matrices.push_back(scores.get_confusion_matrix_json()); + confusion_matrices.push_back(scores.get_confusion_matrix_json(true)); if (!quiet) std::cout << "\b\b\b, " << flush; // Store results and times in std::vector diff --git a/tests/TestScores.cpp b/tests/TestScores.cpp index a58f75b..965a824 100644 --- a/tests/TestScores.cpp +++ b/tests/TestScores.cpp @@ -147,6 +147,9 @@ TEST_CASE("Classification Report", "[Scores]") weighted avg 0.8250000 0.6000000 0.6400000 10 )"; REQUIRE(scores.classification_report() == expected); + auto json_matrix = scores.get_confusion_matrix_json(true); + platform::Scores scores2(json_matrix); + REQUIRE(scores.classification_report() == scores2.classification_report()); } TEST_CASE("JSON constructor", "[Scores]") {