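Fix numeric_features problem: key the numeric flags by attribute name (std::map<std::string, bool>) instead of by position (std::vector<bool>)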
This commit is contained in:
@@ -69,7 +69,7 @@ public:
|
||||
}
|
||||
// Returns a mutable reference to the feature matrix X (no copy is made), so
// callers can modify the loaded data in place.
// NOTE(review): the loader code visible in this diff writes X[xIndex][i], which
// suggests the layout is X[feature][sample] -- confirm against the full file.
std::vector<std::vector<float>>& getX() { return X; }
|
||||
// Returns a mutable reference to the integer label vector y (no copy is made).
std::vector<int>& getY() { return y; }
|
||||
// Returns a copy of numeric_features as a positional flag vector: entry i is
// true when attribute i's declared type is REAL, INTEGER or NUMERIC.
// NOTE(review): a second getNumericAttributes() returning
// std::map<std::string, bool> follows directly below; this vector<bool> form
// looks like the pre-change line of the diff -- confirm only one is kept.
std::vector<bool> getNumericAttributes() const { return numeric_features; }
|
||||
// Returns a copy of numeric_features: a map from attribute name to a flag that
// is true when the attribute's declared type (compared case-insensitively) is
// REAL, INTEGER or NUMERIC. The code that fills the map skips the attribute
// whose name equals className, so the class attribute has no entry.
std::map<std::string, bool> getNumericAttributes() const { return numeric_features; }
|
||||
// Returns a copy of the attribute list as (first, second) string pairs.
// Elsewhere in this file .first is used as the attribute name and .second as
// its declared type string (e.g. REAL / INTEGER / NUMERIC).
std::vector<std::pair<std::string, std::string>> getAttributes() const { return attributes; };
|
||||
std::vector<std::string> split(const std::string& text, char delimiter)
|
||||
{
|
||||
@@ -84,7 +84,7 @@ public:
|
||||
// Returns the library version as a std::string built from VERSION
// (VERSION is defined outside the lines shown here).
std::string version() const { return VERSION; }
|
||||
protected:
|
||||
std::vector<std::string> lines;
|
||||
std::vector<bool> numeric_features;
|
||||
std::map<std::string, bool> numeric_features;
|
||||
std::vector<std::pair<std::string, std::string>> attributes;
|
||||
std::string className;
|
||||
std::string classType;
|
||||
@@ -98,14 +98,14 @@ private:
|
||||
//
|
||||
// Learn the numeric features
|
||||
//
|
||||
numeric_features = std::vector<bool>(attributes.size(), false);
|
||||
for (size_t i = 0; i < attributes.size(); i++) {
|
||||
if (i == labelIndex) {
|
||||
numeric_features.clear();
|
||||
for (const auto& attribute : attributes) {
|
||||
auto feature = attribute.first;
|
||||
if (feature == className)
|
||||
continue;
|
||||
}
|
||||
std::string values = attributes.at(i).second;
|
||||
auto values = attribute.second;
|
||||
std::transform(values.begin(), values.end(), values.begin(), ::toupper);
|
||||
numeric_features[i] = values == "REAL" || values == "INTEGER" || values == "NUMERIC";
|
||||
numeric_features[feature] = values == "REAL" || values == "INTEGER" || values == "NUMERIC";
|
||||
}
|
||||
}
|
||||
std::vector<int> factorize(const std::string feature, const std::vector<std::string>& labels_t)
|
||||
@@ -143,7 +143,7 @@ private:
|
||||
if (pos++ == labelIndex) {
|
||||
yy[i] = token;
|
||||
} else {
|
||||
if (numeric_features[xIndex]) {
|
||||
if (numeric_features[attributes[xIndex].first]) {
|
||||
X[xIndex][i] = stof(token);
|
||||
} else {
|
||||
Xs[xIndex][i] = token;
|
||||
@@ -153,7 +153,7 @@ private:
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < attributes.size(); i++) {
|
||||
if (!numeric_features[i]) {
|
||||
if (!numeric_features[attributes[i].first]) {
|
||||
auto data = factorize(attributes[i].first, Xs[i]);
|
||||
std::transform(data.begin(), data.end(), X[i].begin(), [](int x) { return float(x);});
|
||||
}
|
||||
|
@@ -108,9 +108,10 @@ TEST_CASE("Load with class name as first attribute", "[ArffFiles]")
|
||||
{1.86094, 1.89165, 1.93921, 1.71752},
|
||||
{-0.207383, -0.193249, -0.239664, -0.218572} }
|
||||
};
|
||||
auto X = arff.getX();
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
for (int j = 0; j < 4; ++j)
|
||||
REQUIRE(arff.getX()[i][j] == Catch::Approx(expected[i][j]));
|
||||
REQUIRE(X[i][j] == Catch::Approx(expected[i][j]));
|
||||
}
|
||||
auto expected_y = std::vector<int>{ 0, 0, 0, 0 };
|
||||
for (int i = 120; i < 124; ++i)
|
||||
@@ -132,21 +133,16 @@ TEST_CASE("Adult dataset", "[ArffFiles]")
|
||||
REQUIRE(X[0][0] == 25);
|
||||
REQUIRE(X[1][0] == 0);
|
||||
REQUIRE(X[2][0] == 226802);
|
||||
auto states = arff.getStates();
|
||||
auto numeric = arff.getNumericAttributes();
|
||||
auto attributes = arff.getAttributes();
|
||||
for (size_t i = 0; i < numeric.size(); ++i) {
|
||||
auto feature = attributes.at(i).first;
|
||||
auto state = states.at(feature);
|
||||
if (!numeric.at(i)) {
|
||||
std::cout << feature << ": ";
|
||||
for (const auto& s : state) {
|
||||
std::cout << s << ", ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
} else {
|
||||
std::cout << feature << " size: " << state.size() << std::endl;
|
||||
}
|
||||
}
|
||||
REQUIRE(X[3][0] == 0);
|
||||
REQUIRE(X[4][0] == 7);
|
||||
REQUIRE(X[5][0] == 0);
|
||||
REQUIRE(X[6][0] == 0);
|
||||
REQUIRE(X[7][0] == 0);
|
||||
REQUIRE(X[8][0] == 0);
|
||||
REQUIRE(X[9][0] == 0);
|
||||
REQUIRE(X[10][0] == 0);
|
||||
REQUIRE(X[11][0] == 0);
|
||||
REQUIRE(X[12][0] == 40);
|
||||
REQUIRE(X[13][0] == 0);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user