Add labels to confusion_matrices in results

This commit is contained in:
2024-05-10 17:12:11 +02:00
parent 6d4117d188
commit 69b9609154
8 changed files with 57 additions and 35 deletions

View File

@@ -155,12 +155,14 @@ std::string ArffFiles::trim(const std::string& source)
std::vector<int> ArffFiles::factorize(const std::vector<std::string>& labels_t) std::vector<int> ArffFiles::factorize(const std::vector<std::string>& labels_t)
{ {
std::vector<int> yy; std::vector<int> yy;
labels.clear();
yy.reserve(labels_t.size()); yy.reserve(labels_t.size());
std::map<std::string, int> labelMap; std::map<std::string, int> labelMap;
int i = 0; int i = 0;
for (const std::string& label : labels_t) { for (const std::string& label : labels_t) {
if (labelMap.find(label) == labelMap.end()) { if (labelMap.find(label) == labelMap.end()) {
labelMap[label] = i++; labelMap[label] = i++;
labels.push_back(label);
} }
yy.push_back(labelMap[label]); yy.push_back(labelMap[label]);
} }

View File

@@ -5,15 +5,6 @@
#include <vector> #include <vector>
class ArffFiles { class ArffFiles {
private:
std::vector<std::string> lines;
std::vector<std::pair<std::string, std::string>> attributes;
std::string className;
std::string classType;
std::vector<std::vector<float>> X;
std::vector<int> y;
void generateDataset(int);
void loadCommon(std::string);
public: public:
ArffFiles(); ArffFiles();
void load(const std::string&, bool = true); void load(const std::string&, bool = true);
@@ -22,11 +13,22 @@ public:
unsigned long int getSize() const; unsigned long int getSize() const;
std::string getClassName() const; std::string getClassName() const;
std::string getClassType() const; std::string getClassType() const;
std::vector<std::string> getLabels() const { return labels; }
static std::string trim(const std::string&); static std::string trim(const std::string&);
std::vector<std::vector<float>>& getX(); std::vector<std::vector<float>>& getX();
std::vector<int>& getY(); std::vector<int>& getY();
std::vector<std::pair<std::string, std::string>> getAttributes() const; std::vector<std::pair<std::string, std::string>> getAttributes() const;
static std::vector<int> factorize(const std::vector<std::string>& labels_t); std::vector<int> factorize(const std::vector<std::string>& labels_t);
private:
std::vector<std::string> lines;
std::vector<std::pair<std::string, std::string>> attributes;
std::string className;
std::string classType;
std::vector<std::vector<float>> X;
std::vector<int> y;
std::vector<std::string> labels;
void generateDataset(int);
void loadCommon(std::string);
}; };
#endif #endif

View File

@@ -91,6 +91,7 @@ namespace platform {
} }
yv.push_back(stoi(tokens.back())); yv.push_back(stoi(tokens.back()));
} }
labels.clear();
file.close(); file.close();
} else { } else {
throw std::invalid_argument("Unable to open dataset file."); throw std::invalid_argument("Unable to open dataset file.");
@@ -117,6 +118,7 @@ namespace platform {
className = arff.getClassName(); className = arff.getClassName();
auto attributes = arff.getAttributes(); auto attributes = arff.getAttributes();
transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& attribute) { return attribute.first; }); transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& attribute) { return attribute.first; });
labels = arff.getLabels();
} }
std::vector<std::string> tokenize(std::string line) std::vector<std::string> tokenize(std::string line)
{ {
@@ -160,6 +162,7 @@ namespace platform {
} }
yv.push_back(stoi(tokens.back())); yv.push_back(stoi(tokens.back()));
} }
labels.clear();
file.close(); file.close();
} else { } else {
throw std::invalid_argument("Unable to open dataset file."); throw std::invalid_argument("Unable to open dataset file.");

View File

@@ -10,6 +10,21 @@
namespace platform { namespace platform {
class Dataset { class Dataset {
public:
Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {};
explicit Dataset(const Dataset&);
std::string getName() const;
std::string getClassName() const;
std::vector<std::string> getLabels() const { return labels; }
std::vector<string> getFeatures() const;
std::map<std::string, std::vector<int>> getStates() const;
std::pair<vector<std::vector<float>>&, std::vector<int>&> getVectors();
std::pair<vector<std::vector<int>>&, std::vector<int>&> getVectorsDiscretized();
std::pair<torch::Tensor&, torch::Tensor&> getTensors();
int getNFeatures() const;
int getNSamples() const;
void load();
const bool inline isLoaded() const { return loaded; };
private: private:
std::string path; std::string path;
std::string name; std::string name;
@@ -17,6 +32,7 @@ namespace platform {
std::string className; std::string className;
int n_samples{ 0 }, n_features{ 0 }; int n_samples{ 0 }, n_features{ 0 };
std::vector<std::string> features; std::vector<std::string> features;
std::vector<std::string> labels;
std::map<std::string, std::vector<int>> states; std::map<std::string, std::vector<int>> states;
bool loaded; bool loaded;
bool discretize; bool discretize;
@@ -30,20 +46,6 @@ namespace platform {
void load_rdata(); void load_rdata();
void computeStates(); void computeStates();
std::vector<mdlp::labels_t> discretizeDataset(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y); std::vector<mdlp::labels_t> discretizeDataset(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y);
public:
Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {};
explicit Dataset(const Dataset&);
std::string getName() const;
std::string getClassName() const;
std::vector<string> getFeatures() const;
std::map<std::string, std::vector<int>> getStates() const;
std::pair<vector<std::vector<float>>&, std::vector<int>&> getVectors();
std::pair<vector<std::vector<int>>&, std::vector<int>&> getVectorsDiscretized();
std::pair<torch::Tensor&, torch::Tensor&> getTensors();
int getNFeatures() const;
int getNSamples() const;
void load();
const bool inline isLoaded() const { return loaded; };
}; };
}; };

View File

@@ -42,6 +42,14 @@ namespace platform {
throw std::invalid_argument("Dataset not loaded."); throw std::invalid_argument("Dataset not loaded.");
} }
} }
std::vector<std::string> Datasets::getLabels(const std::string& name) const
{
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getLabels();
} else {
throw std::invalid_argument("Dataset not loaded.");
}
}
map<std::string, std::vector<int>> Datasets::getStates(const std::string& name) const map<std::string, std::vector<int>> Datasets::getStates(const std::string& name) const
{ {
if (datasets.at(name)->isLoaded()) { if (datasets.at(name)->isLoaded()) {

View File

@@ -3,18 +3,12 @@
#include "Dataset.h" #include "Dataset.h"
namespace platform { namespace platform {
class Datasets { class Datasets {
private:
std::string path;
fileType_t fileType;
std::string sfileType;
std::map<std::string, std::unique_ptr<Dataset>> datasets;
bool discretize;
void load(); // Loads the list of datasets
public: public:
explicit Datasets(bool discretize, std::string sfileType) : discretize(discretize), sfileType(sfileType) { load(); }; explicit Datasets(bool discretize, std::string sfileType) : discretize(discretize), sfileType(sfileType) { load(); };
std::vector<string> getNames(); std::vector<std::string> getNames();
std::vector<string> getFeatures(const std::string& name) const; std::vector<std::string> getFeatures(const std::string& name) const;
int getNSamples(const std::string& name) const; int getNSamples(const std::string& name) const;
std::vector<std::string> getLabels(const std::string& name) const;
std::string getClassName(const std::string& name) const; std::string getClassName(const std::string& name) const;
int getNClasses(const std::string& name); int getNClasses(const std::string& name);
std::vector<int> getClassesCounts(const std::string& name) const; std::vector<int> getClassesCounts(const std::string& name) const;
@@ -25,5 +19,12 @@ namespace platform {
bool isDataset(const std::string& name) const; bool isDataset(const std::string& name) const;
void loadDataset(const std::string& name) const; void loadDataset(const std::string& name) const;
std::string toString() const; std::string toString() const;
private:
std::string path;
fileType_t fileType;
std::string sfileType;
std::map<std::string, std::unique_ptr<Dataset>> datasets;
bool discretize;
void load(); // Loads the list of datasets
}; };
}; };

View File

@@ -83,6 +83,7 @@ namespace platform {
auto features = datasets.getFeatures(fileName); auto features = datasets.getFeatures(fileName);
auto samples = datasets.getNSamples(fileName); auto samples = datasets.getNSamples(fileName);
auto className = datasets.getClassName(fileName); auto className = datasets.getClassName(fileName);
auto labels = datasets.getLabels(fileName);
if (!quiet) { if (!quiet) {
std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush; std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush;
} }
@@ -156,12 +157,12 @@ namespace platform {
showProgress(nfold + 1, getColor(clf->getStatus()), "c"); showProgress(nfold + 1, getColor(clf->getStatus()), "c");
test_timer.start(); test_timer.start();
auto y_predict = clf->predict(X_test); auto y_predict = clf->predict(X_test);
Scores scores(y_test, y_predict, states[className].size()); Scores scores(y_test, y_predict, states[className].size(), labels);
auto accuracy_test_value = scores.accuracy(); auto accuracy_test_value = scores.accuracy();
test_time[item] = test_timer.getDuration(); test_time[item] = test_timer.getDuration();
accuracy_train[item] = accuracy_train_value; accuracy_train[item] = accuracy_train_value;
accuracy_test[item] = accuracy_test_value; accuracy_test[item] = accuracy_test_value;
confusion_matrices.push_back(scores.get_confusion_matrix_json()); confusion_matrices.push_back(scores.get_confusion_matrix_json(true));
if (!quiet) if (!quiet)
std::cout << "\b\b\b, " << flush; std::cout << "\b\b\b, " << flush;
// Store results and times in std::vector // Store results and times in std::vector

View File

@@ -147,6 +147,9 @@ TEST_CASE("Classification Report", "[Scores]")
weighted avg 0.8250000 0.6000000 0.6400000 10 weighted avg 0.8250000 0.6000000 0.6400000 10
)"; )";
REQUIRE(scores.classification_report() == expected); REQUIRE(scores.classification_report() == expected);
auto json_matrix = scores.get_confusion_matrix_json(true);
platform::Scores scores2(json_matrix);
REQUIRE(scores.classification_report() == scores2.classification_report());
} }
TEST_CASE("JSON constructor", "[Scores]") TEST_CASE("JSON constructor", "[Scores]")
{ {