Fix numeric_features problem

This commit is contained in:
2024-06-12 21:59:59 +02:00
parent 9e1ef5bce2
commit cf32b9ae58
2 changed files with 23 additions and 27 deletions

View File

@@ -69,7 +69,7 @@ public:
} }
std::vector<std::vector<float>>& getX() { return X; } std::vector<std::vector<float>>& getX() { return X; }
std::vector<int>& getY() { return y; } std::vector<int>& getY() { return y; }
std::vector<bool> getNumericAttributes() const { return numeric_features; } std::map<std::string, bool> getNumericAttributes() const { return numeric_features; }
std::vector<std::pair<std::string, std::string>> getAttributes() const { return attributes; }; std::vector<std::pair<std::string, std::string>> getAttributes() const { return attributes; };
std::vector<std::string> split(const std::string& text, char delimiter) std::vector<std::string> split(const std::string& text, char delimiter)
{ {
@@ -84,7 +84,7 @@ public:
std::string version() const { return VERSION; } std::string version() const { return VERSION; }
protected: protected:
std::vector<std::string> lines; std::vector<std::string> lines;
std::vector<bool> numeric_features; std::map<std::string, bool> numeric_features;
std::vector<std::pair<std::string, std::string>> attributes; std::vector<std::pair<std::string, std::string>> attributes;
std::string className; std::string className;
std::string classType; std::string classType;
@@ -98,14 +98,14 @@ private:
// //
// Learn the numeric features // Learn the numeric features
// //
numeric_features = std::vector<bool>(attributes.size(), false); numeric_features.clear();
for (size_t i = 0; i < attributes.size(); i++) { for (const auto& attribute : attributes) {
if (i == labelIndex) { auto feature = attribute.first;
if (feature == className)
continue; continue;
} auto values = attribute.second;
std::string values = attributes.at(i).second;
std::transform(values.begin(), values.end(), values.begin(), ::toupper); std::transform(values.begin(), values.end(), values.begin(), ::toupper);
numeric_features[i] = values == "REAL" || values == "INTEGER" || values == "NUMERIC"; numeric_features[feature] = values == "REAL" || values == "INTEGER" || values == "NUMERIC";
} }
} }
std::vector<int> factorize(const std::string feature, const std::vector<std::string>& labels_t) std::vector<int> factorize(const std::string feature, const std::vector<std::string>& labels_t)
@@ -143,7 +143,7 @@ private:
if (pos++ == labelIndex) { if (pos++ == labelIndex) {
yy[i] = token; yy[i] = token;
} else { } else {
if (numeric_features[xIndex]) { if (numeric_features[attributes[xIndex].first]) {
X[xIndex][i] = stof(token); X[xIndex][i] = stof(token);
} else { } else {
Xs[xIndex][i] = token; Xs[xIndex][i] = token;
@@ -153,7 +153,7 @@ private:
} }
} }
for (size_t i = 0; i < attributes.size(); i++) { for (size_t i = 0; i < attributes.size(); i++) {
if (!numeric_features[i]) { if (!numeric_features[attributes[i].first]) {
auto data = factorize(attributes[i].first, Xs[i]); auto data = factorize(attributes[i].first, Xs[i]);
std::transform(data.begin(), data.end(), X[i].begin(), [](int x) { return float(x);}); std::transform(data.begin(), data.end(), X[i].begin(), [](int x) { return float(x);});
} }

View File

@@ -108,9 +108,10 @@ TEST_CASE("Load with class name as first attribute", "[ArffFiles]")
{1.86094, 1.89165, 1.93921, 1.71752}, {1.86094, 1.89165, 1.93921, 1.71752},
{-0.207383, -0.193249, -0.239664, -0.218572} } {-0.207383, -0.193249, -0.239664, -0.218572} }
}; };
auto X = arff.getX();
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) for (int j = 0; j < 4; ++j)
REQUIRE(arff.getX()[i][j] == Catch::Approx(expected[i][j])); REQUIRE(X[i][j] == Catch::Approx(expected[i][j]));
} }
auto expected_y = std::vector<int>{ 0, 0, 0, 0 }; auto expected_y = std::vector<int>{ 0, 0, 0, 0 };
for (int i = 120; i < 124; ++i) for (int i = 120; i < 124; ++i)
@@ -132,21 +133,16 @@ TEST_CASE("Adult dataset", "[ArffFiles]")
REQUIRE(X[0][0] == 25); REQUIRE(X[0][0] == 25);
REQUIRE(X[1][0] == 0); REQUIRE(X[1][0] == 0);
REQUIRE(X[2][0] == 226802); REQUIRE(X[2][0] == 226802);
auto states = arff.getStates(); REQUIRE(X[3][0] == 0);
auto numeric = arff.getNumericAttributes(); REQUIRE(X[4][0] == 7);
auto attributes = arff.getAttributes(); REQUIRE(X[5][0] == 0);
for (size_t i = 0; i < numeric.size(); ++i) { REQUIRE(X[6][0] == 0);
auto feature = attributes.at(i).first; REQUIRE(X[7][0] == 0);
auto state = states.at(feature); REQUIRE(X[8][0] == 0);
if (!numeric.at(i)) { REQUIRE(X[9][0] == 0);
std::cout << feature << ": "; REQUIRE(X[10][0] == 0);
for (const auto& s : state) { REQUIRE(X[11][0] == 0);
std::cout << s << ", "; REQUIRE(X[12][0] == 40);
} REQUIRE(X[13][0] == 0);
std::cout << std::endl;
} else {
std::cout << feature << " size: " << state.size() << std::endl;
}
}
} }