diff --git a/ArffFiles.hpp b/ArffFiles.hpp index ea356bd..e4bf81f 100644 --- a/ArffFiles.hpp +++ b/ArffFiles.hpp @@ -69,7 +69,7 @@ public: } std::vector>& getX() { return X; } std::vector& getY() { return y; } - std::vector getNumericAttributes() const { return numeric_features; } + std::map getNumericAttributes() const { return numeric_features; } std::vector> getAttributes() const { return attributes; }; std::vector split(const std::string& text, char delimiter) { @@ -84,7 +84,7 @@ public: std::string version() const { return VERSION; } protected: std::vector lines; - std::vector numeric_features; + std::map numeric_features; std::vector> attributes; std::string className; std::string classType; @@ -98,14 +98,14 @@ private: // // Learn the numeric features // - numeric_features = std::vector(attributes.size(), false); - for (size_t i = 0; i < attributes.size(); i++) { - if (i == labelIndex) { + numeric_features.clear(); + for (const auto& attribute : attributes) { + auto feature = attribute.first; + if (feature == className) continue; - } - std::string values = attributes.at(i).second; + auto values = attribute.second; std::transform(values.begin(), values.end(), values.begin(), ::toupper); - numeric_features[i] = values == "REAL" || values == "INTEGER" || values == "NUMERIC"; + numeric_features[feature] = values == "REAL" || values == "INTEGER" || values == "NUMERIC"; } } std::vector factorize(const std::string feature, const std::vector& labels_t) @@ -143,7 +143,7 @@ private: if (pos++ == labelIndex) { yy[i] = token; } else { - if (numeric_features[xIndex]) { + if (numeric_features[attributes[xIndex].first]) { X[xIndex][i] = stof(token); } else { Xs[xIndex][i] = token; @@ -153,7 +153,7 @@ private: } } for (size_t i = 0; i < attributes.size(); i++) { - if (!numeric_features[i]) { + if (!numeric_features[attributes[i].first]) { auto data = factorize(attributes[i].first, Xs[i]); std::transform(data.begin(), data.end(), X[i].begin(), [](int x) { return float(x);}); } diff --git a/tests/TestArffFiles.cc b/tests/TestArffFiles.cc index b2025f5..9ff6405 100644 --- a/tests/TestArffFiles.cc +++ b/tests/TestArffFiles.cc @@ -108,9 +108,10 @@ TEST_CASE("Load with class name as first attribute", "[ArffFiles]") {1.86094, 1.89165, 1.93921, 1.71752}, {-0.207383, -0.193249, -0.239664, -0.218572} } }; + auto X = arff.getX(); for (int i = 0; i < 4; ++i) { for (int j = 0; j < 4; ++j) - REQUIRE(arff.getX()[i][j] == Catch::Approx(expected[i][j])); + REQUIRE(X[i][j] == Catch::Approx(expected[i][j])); } auto expected_y = std::vector{ 0, 0, 0, 0 }; for (int i = 120; i < 124; ++i) @@ -132,21 +133,16 @@ TEST_CASE("Adult dataset", "[ArffFiles]") REQUIRE(X[0][0] == 25); REQUIRE(X[1][0] == 0); REQUIRE(X[2][0] == 226802); - auto states = arff.getStates(); - auto numeric = arff.getNumericAttributes(); - auto attributes = arff.getAttributes(); - for (size_t i = 0; i < numeric.size(); ++i) { - auto feature = attributes.at(i).first; - auto state = states.at(feature); - if (!numeric.at(i)) { - std::cout << feature << ": "; - for (const auto& s : state) { - std::cout << s << ", "; - } - std::cout << std::endl; - } else { - std::cout << feature << " size: " << state.size() << std::endl; - } - } + REQUIRE(X[3][0] == 0); + REQUIRE(X[4][0] == 7); + REQUIRE(X[5][0] == 0); + REQUIRE(X[6][0] == 0); + REQUIRE(X[7][0] == 0); + REQUIRE(X[8][0] == 0); + REQUIRE(X[9][0] == 0); + REQUIRE(X[10][0] == 0); + REQUIRE(X[11][0] == 0); + REQUIRE(X[12][0] == 40); + REQUIRE(X[13][0] == 0); }