From 9e1ef5bce22d10d278be84b2fa378048048f0105 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Wed, 12 Jun 2024 21:10:53 +0200 Subject: [PATCH] Complete adding string features --- ArffFiles.hpp | 132 ++++++++++++++--------------------------- tests/TestArffFiles.cc | 40 +++++++++---- 2 files changed, 74 insertions(+), 98 deletions(-) diff --git a/ArffFiles.hpp b/ArffFiles.hpp index 1b32a9f..ea356bd 100644 --- a/ArffFiles.hpp +++ b/ArffFiles.hpp @@ -7,7 +7,9 @@ #include #include #include // std::isdigit -#include // std::all_of +#include // std::all_of std::transform + +#include // TODO remove class ArffFiles { const std::string VERSION = "1.1.0"; @@ -28,48 +30,8 @@ public: attributes.erase(attributes.begin()); labelIndex = 0; } + preprocessDataset(labelIndex); generateDataset(labelIndex); - }; - void preprocessDataset(int labelIndex) - { - // - // Learn the type of features - // - bool goodLine = false; - std::vector tokens; - auto removeLines = std::vector(); // Lines with missing values - int i = 0; - // Select a line with no missing values - while (!goodLine && i < lines.size()) { - tokens = split(lines[i], ','); - for (const auto& token : tokens) { - goodLine = true; - if (token == "?") { - goodLine = false; - removeLines.push_back(i); - break; - } - - } - i++; - } - // Remove lines in reverse order - std::sort(removeLines.begin(), removeLines.end(), std::greater()); - for (auto i : removeLines) { - lines.erase(lines.begin() + i); - } - numeric_features = std::vector(attributes.size(), true); - for (size_t i = 0; i < tokens.size(); i++) { - if (i == labelIndex) { - continue; - } - try { - stof(tokens[i]); - } - catch (std::invalid_argument& e) { - numeric_features[i] = false; - } - } } void load(const std::string& fileName, const std::string& name) { @@ -91,12 +53,12 @@ public: } preprocessDataset(labelIndex); generateDataset(labelIndex); - }; - std::vector getLines() const { return lines; }; - unsigned long int getSize() const { return lines.size(); }; - std::string getClassName() const { return className; }; - std::string getClassType() const { return classType; }; - std::map> getStates() const { return states; }; + } + std::vector getLines() const { return lines; } + unsigned long int getSize() const { return lines.size(); } + std::string getClassName() const { return className; } + std::string getClassType() const { return classType; } + std::map> getStates() const { return states; } std::vector getLabels() const { return states.at(className); } static std::string trim(const std::string& source) { @@ -104,9 +66,10 @@ public: s.erase(0, s.find_first_not_of(" '\n\r\t")); s.erase(s.find_last_not_of(" '\n\r\t") + 1); return s; - }; - std::vector>& getX() { return X; }; + } + std::vector>& getX() { return X; } std::vector& getY() { return y; } + std::vector getNumericAttributes() const { return numeric_features; } std::vector> getAttributes() const { return attributes; }; std::vector split(const std::string& text, char delimiter) { @@ -118,7 +81,7 @@ public: } return result; } - std::string version() const { return VERSION; }; + std::string version() const { return VERSION; } protected: std::vector lines; std::vector numeric_features; @@ -128,8 +91,23 @@ protected: std::vector> X; std::vector> Xs; std::vector y; - std::map> states; + std::map> states; private: + void preprocessDataset(int labelIndex) + { + // + // Learn the numeric features + // + numeric_features = std::vector(attributes.size(), false); + for (size_t i = 0; i < attributes.size(); i++) { + if (i == labelIndex) { + continue; + } + std::string values = attributes.at(i).second; + std::transform(values.begin(), values.end(), values.begin(), ::toupper); + numeric_features[i] = values == "REAL" || values == "INTEGER" || values == "NUMERIC"; + } + } std::vector factorize(const std::string feature, const std::vector& labels_t) { std::vector yy; @@ -149,13 +127,12 @@ private: yy.push_back(labelMap[label]); } return yy; - }; + } void generateDataset(int labelIndex) { X = std::vector>(attributes.size(), std::vector(lines.size())); Xs = std::vector>(attributes.size(), std::vector(lines.size())); auto yy = std::vector(lines.size(), ""); - auto removeLines = std::vector(); // Lines with missing values for (size_t i = 0; i < lines.size(); i++) { std::stringstream ss(lines[i]); std::string value; @@ -166,17 +143,12 @@ private: if (pos++ == labelIndex) { yy[i] = token; } else { - if (token == "?") { - X[xIndex++][i] = -1; - removeLines.push_back(i); + if (numeric_features[xIndex]) { + X[xIndex][i] = stof(token); } else { - if (numeric_features[xIndex]) { - X[xIndex][i] = stof(token); - } else { - Xs[xIndex][i] = token; - } - xIndex++; + Xs[xIndex][i] = token; } + xIndex++; } } } @@ -184,33 +156,10 @@ private: if (!numeric_features[i]) { auto data = factorize(attributes[i].first, Xs[i]); std::transform(data.begin(), data.end(), X[i].begin(), [](int x) { return float(x);}); - } else { - states[attributes[i].first] = std::vector(Xs[i].begin(), Xs[i].end()); - } - } - // Remove lines in reverse order - std::sort(removeLines.begin(), removeLines.end(), std::greater()); - for (auto i : removeLines) { - yy.erase(yy.begin() + i); - for (auto& x : X) { - try { - x.erase(x.begin() + i); - } - catch (std::out_of_range& e) { - continue; - } - } - for (auto& x : Xs) { - try { - x.erase(x.begin() + i); - } - catch (std::out_of_range& e) { - continue; - } } } y = factorize(className, yy); - }; + } void loadCommon(std::string fileName) { std::ifstream file(fileName); @@ -238,12 +187,19 @@ private: if (line[0] == '@') { continue; } + if (line.find("?", 0) != std::string::npos) { + // ignore lines with missing values + continue; + } lines.push_back(line); } file.close(); + for (const auto& attribute : attributes) { + states[attribute.first] = std::vector(); + } if (attributes.empty()) throw std::invalid_argument("No attributes found"); - }; + } }; #endif \ No newline at end of file diff --git a/tests/TestArffFiles.cc b/tests/TestArffFiles.cc index 744227c..b2025f5 100644 --- a/tests/TestArffFiles.cc +++ b/tests/TestArffFiles.cc @@ -18,7 +18,7 @@ public: TEST_CASE("Version Test", "[ArffFiles]") { ArffFiles arff; - REQUIRE(arff.version() == "1.0.0"); + REQUIRE(arff.version() == "1.1.0"); } TEST_CASE("Load Test", "[ArffFiles]") { @@ -65,14 +65,14 @@ TEST_CASE("Load Test", "[ArffFiles]") TEST_CASE("Load with class name", "[ArffFiles]") { ArffFiles arff; - arff.load(Paths::datasets("glass"), "Type"); + arff.load(Paths::datasets("glass"), std::string("Type")); REQUIRE(arff.getClassName() == "Type"); REQUIRE(arff.getClassType() == "{ 'build wind float', 'build wind non-float', 'vehic wind float', 'vehic wind non-float', containers, tableware, headlamps}"); REQUIRE(arff.getLabels().size() == 6); - REQUIRE(arff.getLabels()[0] == "'build wind float'"); - REQUIRE(arff.getLabels()[1] == "'vehic wind float'"); + REQUIRE(arff.getLabels()[0] == "build wind float"); + REQUIRE(arff.getLabels()[1] == "vehic wind float"); REQUIRE(arff.getLabels()[2] == "tableware"); - REQUIRE(arff.getLabels()[3] == "'build wind non-float'"); + REQUIRE(arff.getLabels()[3] == "build wind non-float"); REQUIRE(arff.getLabels()[4] == "headlamps"); REQUIRE(arff.getLabels()[5] == "containers"); REQUIRE(arff.getSize() == 214); @@ -119,14 +119,34 @@ TEST_CASE("Load with class name as first attribute", "[ArffFiles]") TEST_CASE("Adult dataset", "[ArffFiles]") { ArffFiles arff; - arff.load(Paths::datasets("adult")); + arff.load(Paths::datasets("adult"), std::string("class")); REQUIRE(arff.getClassName() == "class"); - REQUIRE(arff.getClassType() == "{ <=50K, >50K}"); + REQUIRE(arff.getClassType() == "{ >50K, <=50K }"); REQUIRE(arff.getLabels().size() == 2); REQUIRE(arff.getLabels()[0] == "<=50K"); REQUIRE(arff.getLabels()[1] == ">50K"); - REQUIRE(arff.getSize() == 32561); - REQUIRE(arff.getLines().size() == 32561); - REQUIRE(arff.getLines()[0] == "39, State-gov, 77516, Bachelors, 13, Never-married, Adm-clerical, Not-in-family, White "); + REQUIRE(arff.getSize() == 45222); + REQUIRE(arff.getLines().size() == 45222); + REQUIRE(arff.getLines()[0] == "25, Private, 226802, 11th, 7, Never-married, Machine-op-inspct, Own-child, Black, Male, 0, 0, 40, United-States, <=50K"); + auto X = arff.getX(); + REQUIRE(X[0][0] == 25); + REQUIRE(X[1][0] == 0); + REQUIRE(X[2][0] == 226802); + auto states = arff.getStates(); + auto numeric = arff.getNumericAttributes(); + auto attributes = arff.getAttributes(); + for (size_t i = 0; i < numeric.size(); ++i) { + auto feature = attributes.at(i).first; + auto state = states.at(feature); + if (!numeric.at(i)) { + std::cout << feature << ": "; + for (const auto& s : state) { + std::cout << s << ", "; + } + std::cout << std::endl; + } else { + std::cout << feature << " size: " << state.size() << std::endl; + } + } }