diff --git a/.gitmodules b/.gitmodules index 90a50b3..e0ca40e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -18,3 +18,6 @@ url = https://github.com/catchorg/Catch2.git main = main update = merge +[submodule "tests/lib/Files"] + path = tests/lib/Files + url = https://github.com/rmontanana/ArffFiles diff --git a/bayesnet/CMakeLists.txt b/bayesnet/CMakeLists.txt index c1aebe9..2aef26a 100644 --- a/bayesnet/CMakeLists.txt +++ b/bayesnet/CMakeLists.txt @@ -1,6 +1,5 @@ include_directories( ${BayesNet_SOURCE_DIR}/lib/mdlp - ${BayesNet_SOURCE_DIR}/lib/Files ${BayesNet_SOURCE_DIR}/lib/folding ${BayesNet_SOURCE_DIR}/lib/json/include ${BayesNet_SOURCE_DIR} diff --git a/bayesnet/classifiers/Proposal.cc b/bayesnet/classifiers/Proposal.cc index d62ed6e..2dfadb7 100644 --- a/bayesnet/classifiers/Proposal.cc +++ b/bayesnet/classifiers/Proposal.cc @@ -4,7 +4,6 @@ // SPDX-License-Identifier: MIT // *************************************************************** -#include #include "Proposal.h" namespace bayesnet { @@ -54,8 +53,7 @@ namespace bayesnet { yJoinParents[i] += to_string(pDataset.index({ idx, i }).item()); } } - auto arff = ArffFiles(); - auto yxv = arff.factorize(yJoinParents); + auto yxv = factorize(yJoinParents); auto xvf_ptr = Xf.index({ index }).data_ptr(); auto xvf = std::vector(xvf_ptr, xvf_ptr + Xf.size(1)); discretizers[feature]->fit(xvf, yxv); @@ -113,4 +111,19 @@ namespace bayesnet { } return Xtd; } + std::vector Proposal::factorize(const std::vector& labels_t) + { + std::vector yy; + yy.reserve(labels_t.size()); + std::map labelMap; + int i = 0; + for (const std::string& label : labels_t) { + if (labelMap.find(label) == labelMap.end()) { + labelMap[label] = i++; + bool allDigits = std::all_of(label.begin(), label.end(), ::isdigit); + } + yy.push_back(labelMap[label]); + } + return yy; + } } \ No newline at end of file diff --git a/bayesnet/classifiers/Proposal.h b/bayesnet/classifiers/Proposal.h index 6e7c351..dd011d8 100644 --- a/bayesnet/classifiers/Proposal.h +++ b/bayesnet/classifiers/Proposal.h @@ -27,6 +27,7 @@ namespace bayesnet { torch::Tensor y; // y discrete nx1 tensor map discretizers; private: + std::vector factorize(const std::vector& labels_t); torch::Tensor& pDataset; // (n+1)xm tensor std::vector& pFeatures; std::string& pClassName; diff --git a/lib/Files/ArffFiles.hpp b/lib/Files/ArffFiles.hpp deleted file mode 100644 index 7227299..0000000 --- a/lib/Files/ArffFiles.hpp +++ /dev/null @@ -1,161 +0,0 @@ -#ifndef ARFFFILES_HPP -#define ARFFFILES_HPP - -#include -#include -#include -#include -#include -#include // std::isdigit -#include // std::all_of -#include - -class ArffFiles { -public: - ArffFiles() = default; - void load(const std::string& fileName, bool classLast = true) - { - int labelIndex; - loadCommon(fileName); - if (classLast) { - className = std::get<0>(attributes.back()); - classType = std::get<1>(attributes.back()); - attributes.pop_back(); - labelIndex = static_cast(attributes.size()); - } else { - className = std::get<0>(attributes.front()); - classType = std::get<1>(attributes.front()); - attributes.erase(attributes.begin()); - labelIndex = 0; - } - generateDataset(labelIndex); - }; - void load(const std::string& fileName, const std::string& name) - { - int labelIndex; - loadCommon(fileName); - bool found = false; - for (int i = 0; i < attributes.size(); ++i) { - if (attributes[i].first == name) { - className = std::get<0>(attributes[i]); - classType = std::get<1>(attributes[i]); - attributes.erase(attributes.begin() + i); - labelIndex = i; - found = true; - break; - } - } - if (!found) { - throw std::invalid_argument("Class name not found"); - } - generateDataset(labelIndex); - }; - std::vector getLines() const { return lines; }; - unsigned long int getSize() const { return lines.size(); }; - std::string getClassName() const { return className; }; - std::string getClassType() const { return classType; }; - std::vector getLabels() const { return labels; } - static std::string trim(const std::string& source) - { - std::string s(source); - s.erase(0, s.find_first_not_of(" '\n\r\t")); - s.erase(s.find_last_not_of(" '\n\r\t") + 1); - return s; - }; - std::vector>& getX() { return X; }; - std::vector& getY() { return y; } - std::vector> getAttributes() const { return attributes; }; - std::vector factorize(const std::vector& labels_t) - { - std::vector yy; - labels.clear(); - yy.reserve(labels_t.size()); - std::map labelMap; - int i = 0; - for (const std::string& label : labels_t) { - if (labelMap.find(label) == labelMap.end()) { - labelMap[label] = i++; - bool allDigits = std::all_of(label.begin(), label.end(), isdigit); - if (allDigits) - labels.push_back("Class " + label); - else - labels.push_back(label); - } - yy.push_back(labelMap[label]); - } - return yy; - }; -private: - void generateDataset(int labelIndex) - { - X = std::vector>(attributes.size(), std::vector(lines.size())); - auto yy = std::vector(lines.size(), ""); - auto removeLines = std::vector(); // Lines with missing values - for (size_t i = 0; i < lines.size(); i++) { - std::stringstream ss(lines[i]); - std::string value; - int pos = 0; - int xIndex = 0; - while (getline(ss, value, ',')) { - if (pos++ == labelIndex) { - yy[i] = value; - } else { - if (value == "?") { - X[xIndex++][i] = -1; - removeLines.push_back(i); - } else - X[xIndex++][i] = stof(value); - } - } - } - for (auto i : removeLines) { - yy.erase(yy.begin() + i); - for (auto& x : X) { - x.erase(x.begin() + i); - } - } - y = factorize(yy); - }; - void loadCommon(std::string fileName) - { - std::ifstream file(fileName); - if (!file.is_open()) { - throw std::invalid_argument("Unable to open file"); - } - std::string line; - std::string keyword; - std::string attribute; - std::string type; - std::string type_w; - while (getline(file, line)) { - if (line.empty() || line[0] == '%' || line == "\r" || line == " ") { - continue; - } - if (line.find("@attribute") != std::string::npos || line.find("@ATTRIBUTE") != std::string::npos) { - std::stringstream ss(line); - ss >> keyword >> attribute; - type = ""; - while (ss >> type_w) - type += type_w + " "; - attributes.emplace_back(trim(attribute), trim(type)); - continue; - } - if (line[0] == '@') { - continue; - } - lines.push_back(line); - } - file.close(); - if (attributes.empty()) - throw std::invalid_argument("No attributes found"); - }; - std::vector lines; - std::vector> attributes; - std::string className; - std::string classType; - std::vector> X; - std::vector y; - std::vector labels; -}; - -#endif \ No newline at end of file diff --git a/sample/lib/Files/ArffFiles.hpp b/sample/lib/Files/ArffFiles.hpp deleted file mode 100644 index 7227299..0000000 --- a/sample/lib/Files/ArffFiles.hpp +++ /dev/null @@ -1,161 +0,0 @@ -#ifndef ARFFFILES_HPP -#define ARFFFILES_HPP - -#include -#include -#include -#include -#include -#include // std::isdigit -#include // std::all_of -#include - -class ArffFiles { -public: - ArffFiles() = default; - void load(const std::string& fileName, bool classLast = true) - { - int labelIndex; - loadCommon(fileName); - if (classLast) { - className = std::get<0>(attributes.back()); - classType = std::get<1>(attributes.back()); - attributes.pop_back(); - labelIndex = static_cast(attributes.size()); - } else { - className = std::get<0>(attributes.front()); - classType = std::get<1>(attributes.front()); - attributes.erase(attributes.begin()); - labelIndex = 0; - } - generateDataset(labelIndex); - }; - void load(const std::string& fileName, const std::string& name) - { - int labelIndex; - loadCommon(fileName); - bool found = false; - for (int i = 0; i < attributes.size(); ++i) { - if (attributes[i].first == name) { - className = std::get<0>(attributes[i]); - classType = std::get<1>(attributes[i]); - attributes.erase(attributes.begin() + i); - labelIndex = i; - found = true; - break; - } - } - if (!found) { - throw std::invalid_argument("Class name not found"); - } - generateDataset(labelIndex); - }; - std::vector getLines() const { return lines; }; - unsigned long int getSize() const { return lines.size(); }; - std::string getClassName() const { return className; }; - std::string getClassType() const { return classType; }; - std::vector getLabels() const { return labels; } - static std::string trim(const std::string& source) - { - std::string s(source); - s.erase(0, s.find_first_not_of(" '\n\r\t")); - s.erase(s.find_last_not_of(" '\n\r\t") + 1); - return s; - }; - std::vector>& getX() { return X; }; - std::vector& getY() { return y; } - std::vector> getAttributes() const { return attributes; }; - std::vector factorize(const std::vector& labels_t) - { - std::vector yy; - labels.clear(); - yy.reserve(labels_t.size()); - std::map labelMap; - int i = 0; - for (const std::string& label : labels_t) { - if (labelMap.find(label) == labelMap.end()) { - labelMap[label] = i++; - bool allDigits = std::all_of(label.begin(), label.end(), isdigit); - if (allDigits) - labels.push_back("Class " + label); - else - labels.push_back(label); - } - yy.push_back(labelMap[label]); - } - return yy; - }; -private: - void generateDataset(int labelIndex) - { - X = std::vector>(attributes.size(), std::vector(lines.size())); - auto yy = std::vector(lines.size(), ""); - auto removeLines = std::vector(); // Lines with missing values - for (size_t i = 0; i < lines.size(); i++) { - std::stringstream ss(lines[i]); - std::string value; - int pos = 0; - int xIndex = 0; - while (getline(ss, value, ',')) { - if (pos++ == labelIndex) { - yy[i] = value; - } else { - if (value == "?") { - X[xIndex++][i] = -1; - removeLines.push_back(i); - } else - X[xIndex++][i] = stof(value); - } - } - } - for (auto i : removeLines) { - yy.erase(yy.begin() + i); - for (auto& x : X) { - x.erase(x.begin() + i); - } - } - y = factorize(yy); - }; - void loadCommon(std::string fileName) - { - std::ifstream file(fileName); - if (!file.is_open()) { - throw std::invalid_argument("Unable to open file"); - } - std::string line; - std::string keyword; - std::string attribute; - std::string type; - std::string type_w; - while (getline(file, line)) { - if (line.empty() || line[0] == '%' || line == "\r" || line == " ") { - continue; - } - if (line.find("@attribute") != std::string::npos || line.find("@ATTRIBUTE") != std::string::npos) { - std::stringstream ss(line); - ss >> keyword >> attribute; - type = ""; - while (ss >> type_w) - type += type_w + " "; - attributes.emplace_back(trim(attribute), trim(type)); - continue; - } - if (line[0] == '@') { - continue; - } - lines.push_back(line); - } - file.close(); - if (attributes.empty()) - throw std::invalid_argument("No attributes found"); - }; - std::vector lines; - std::vector> attributes; - std::string className; - std::string classType; - std::vector> X; - std::vector y; - std::vector labels; -}; - -#endif \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 99249fc..1ff33ee 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,6 +1,6 @@ if(ENABLE_TESTING) include_directories( - ${BayesNet_SOURCE_DIR}/lib/Files + ${BayesNet_SOURCE_DIR}/tests/lib/Files ${BayesNet_SOURCE_DIR}/lib/folding ${BayesNet_SOURCE_DIR}/lib/mdlp ${BayesNet_SOURCE_DIR}/lib/json/include @@ -11,7 +11,7 @@ if(ENABLE_TESTING) add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesClassifier.cc TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestBoostAODE.cc TestA2DE.cc TestUtils.cc TestBayesEnsemble.cc TestModulesVersions.cc TestBoostA2DE.cc ${BayesNet_SOURCES}) - target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp PRIVATE Catch2::Catch2WithMain) + target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" mdlp PRIVATE Catch2::Catch2WithMain) add_test(NAME BayesNetworkTest COMMAND TestBayesNet) add_test(NAME A2DE COMMAND TestBayesNet "[A2DE]") add_test(NAME BoostA2DE COMMAND TestBayesNet "[BoostA2DE]") diff --git a/tests/TestModulesVersions.cc b/tests/TestModulesVersions.cc index c9a44e0..a8b2ce2 100644 --- a/tests/TestModulesVersions.cc +++ b/tests/TestModulesVersions.cc @@ -18,7 +18,8 @@ std::map modules = { { "mdlp", "1.1.2" }, { "Folding", "1.1.0" }, - { "json", "3.11" } + { "json", "3.11" }, + { "ArffFiles", "1.0.0" } }; TEST_CASE("MDLP", "[Modules]") @@ -35,3 +36,8 @@ TEST_CASE("NLOHMANN_JSON", "[Modules]") { REQUIRE(JSON_VERSION == modules["json"]); } +TEST_CASE("ArffFiles", "[Modules]") +{ + auto handler = ArffFiles(); + REQUIRE(handler.version() == modules["ArffFiles"]); +} diff --git a/tests/TestUtils.h b/tests/TestUtils.h index 09ea305..96b6775 100644 --- a/tests/TestUtils.h +++ b/tests/TestUtils.h @@ -11,7 +11,7 @@ #include #include #include -#include +#include #include #include diff --git a/tests/lib/Files b/tests/lib/Files new file mode 160000 index 0000000..40ac380 --- /dev/null +++ b/tests/lib/Files @@ -0,0 +1 @@ +Subproject commit 40ac38011a2445e00df8a18048c67abaff16fa59