Fix mistakes in feature selection in SPnDE

Complete the first A2DE test
Update version number
This commit is contained in:
Ricardo Montañana Gómez 2024-05-05 11:14:01 +02:00
parent f806015b29
commit 0ec53f405f
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
7 changed files with 18 additions and 15 deletions

2
.vscode/launch.json vendored
View File

@ -16,7 +16,7 @@
"name": "test", "name": "test",
"program": "${workspaceFolder}/build_debug/tests/TestBayesNet", "program": "${workspaceFolder}/build_debug/tests/TestBayesNet",
"args": [ "args": [
"\"Bisection Best\"" "[A2DE]"
], ],
"cwd": "${workspaceFolder}/build_debug/tests" "cwd": "${workspaceFolder}/build_debug/tests"
}, },

View File

@ -9,9 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added ### Added
- Add the Library logo generated with <https://openart.ai> to README.md - Library logo generated with <https://openart.ai> to README.md
- Add link to the coverage report in the README.md coverage label. - Link to the coverage report in the README.md coverage label.
- Add the *convergence_best* hyperparameter to the BoostAODE class, to control the way the prior accuracy is computed if convergence is set. Default value is *false*. - *convergence_best* hyperparameter to the BoostAODE class, to control the way the prior accuracy is computed if convergence is set. Default value is *false*.
- SPnDE model.
- A2DE model.
### Internal ### Internal
@ -19,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Refactor catch2 library location to test/lib - Refactor catch2 library location to test/lib
- Refactor loadDataset function in tests. - Refactor loadDataset function in tests.
- Remove conditionalEdgeWeights method in BayesMetrics. - Remove conditionalEdgeWeights method in BayesMetrics.
- A2DE & SPnDE tests.
## [1.0.5] 2024-04-20 ## [1.0.5] 2024-04-20

View File

@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.20) cmake_minimum_required(VERSION 3.20)
project(BayesNet project(BayesNet
VERSION 1.0.5 VERSION 1.0.5.1
DESCRIPTION "Bayesian Network and basic classifiers Library." DESCRIPTION "Bayesian Network and basic classifiers Library."
HOMEPAGE_URL "https://github.com/rmontanana/bayesnet" HOMEPAGE_URL "https://github.com/rmontanana/bayesnet"
LANGUAGES CXX LANGUAGES CXX

View File

@ -16,7 +16,7 @@ namespace bayesnet {
addNodes(); addNodes();
std::vector<int> attributes; std::vector<int> attributes;
for (int i = 0; i < static_cast<int>(features.size()); ++i) { for (int i = 0; i < static_cast<int>(features.size()); ++i) {
if (std::find(parents.begin(), parents.end(), i) != parents.end()) { if (std::find(parents.begin(), parents.end(), i) == parents.end()) {
attributes.push_back(i); attributes.push_back(i);
} }
} }
@ -25,6 +25,7 @@ namespace bayesnet {
for (const auto& attribute : attributes) { for (const auto& attribute : attributes) {
model.addEdge(className, features[attribute]); model.addEdge(className, features[attribute]);
for (const auto& root : parents) { for (const auto& root : parents) {
model.addEdge(features[root], features[attribute]); model.addEdge(features[root], features[attribute]);
} }
} }

View File

@ -27,10 +27,11 @@ namespace bayesnet {
significanceModels.clear(); significanceModels.clear();
for (int i = 0; i < features.size() - 1; ++i) { for (int i = 0; i < features.size() - 1; ++i) {
for (int j = i + 1; j < features.size(); ++j) { for (int j = i + 1; j < features.size(); ++j) {
models.push_back(std::make_unique<SPnDE>(std::vector<int>({ i, j }))); auto model = std::make_unique<SPnDE>(std::vector<int>({ i, j }));
models.push_back(std::move(model));
} }
} }
n_models = models.size(); n_models = static_cast<unsigned>(models.size());
significanceModels = std::vector<double>(n_models, 1.0); significanceModels = std::vector<double>(n_models, 1.0);
} }
std::vector<std::string> A2DE::graph(const std::string& title) const std::vector<std::string> A2DE::graph(const std::string& title) const

View File

@ -9,7 +9,7 @@ if(ENABLE_TESTING)
) )
file(GLOB_RECURSE BayesNet_SOURCES "${BayesNet_SOURCE_DIR}/bayesnet/*.cc") file(GLOB_RECURSE BayesNet_SOURCES "${BayesNet_SOURCE_DIR}/bayesnet/*.cc")
add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesClassifier.cc add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesClassifier.cc
TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestBoostAODE.cc TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestBoostAODE.cc TestA2DE.cc
TestUtils.cc TestBayesEnsemble.cc ${BayesNet_SOURCES}) TestUtils.cc TestBayesEnsemble.cc ${BayesNet_SOURCES})
target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp PRIVATE Catch2::Catch2WithMain) target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp PRIVATE Catch2::Catch2WithMain)
add_test(NAME BayesNetworkTest COMMAND TestBayesNet) add_test(NAME BayesNetworkTest COMMAND TestBayesNet)

View File

@ -17,10 +17,8 @@ TEST_CASE("Fit and Score", "[A2DE]")
auto raw = RawDatasets("glass", true); auto raw = RawDatasets("glass", true);
auto clf = bayesnet::A2DE(); auto clf = bayesnet::A2DE();
clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states); clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states);
std::cout << "Score A2DE: " << clf.score(raw.Xv, raw.yv) << std::endl; REQUIRE(clf.score(raw.Xv, raw.yv) == Catch::Approx(0.831776).epsilon(raw.epsilon));
// REQUIRE(clf.getNumberOfNodes() == 90); REQUIRE(clf.getNumberOfNodes() == 360);
// REQUIRE(clf.getNumberOfEdges() == 153); REQUIRE(clf.getNumberOfEdges() == 756);
// REQUIRE(clf.getNotes().size() == 2); REQUIRE(clf.getNotes().size() == 0);
// REQUIRE(clf.getNotes()[0] == "Used features in initialization: 6 of 9 with CFS");
// REQUIRE(clf.getNotes()[1] == "Number of models: 9");
} }