Fix mistakes in feature selection in SPnDE

Complete the first A2DE test
Update version number
This commit is contained in:
Ricardo Montañana Gómez 2024-05-05 11:14:01 +02:00
parent f806015b29
commit 0ec53f405f
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
7 changed files with 18 additions and 15 deletions

2
.vscode/launch.json vendored
View File

@ -16,7 +16,7 @@
"name": "test",
"program": "${workspaceFolder}/build_debug/tests/TestBayesNet",
"args": [
"\"Bisection Best\""
"[A2DE]"
],
"cwd": "${workspaceFolder}/build_debug/tests"
},

View File

@ -9,9 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Add the Library logo generated with <https://openart.ai> to README.md
- Add link to the coverage report in the README.md coverage label.
- Add the *convergence_best* hyperparameter to the BoostAODE class, to control the way the prior accuracy is computed if convergence is set. Default value is *false*.
- Library logo generated with <https://openart.ai> to README.md
- Link to the coverage report in the README.md coverage label.
- *convergence_best* hyperparameter to the BoostAODE class, to control the way the prior accuracy is computed if convergence is set. Default value is *false*.
- SPnDE model.
- A2DE model.
### Internal
@ -19,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Refactor catch2 library location to test/lib
- Refactor loadDataset function in tests.
- Remove conditionalEdgeWeights method in BayesMetrics.
- A2DE & SPnDE tests.
## [1.0.5] 2024-04-20

View File

@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.20)
project(BayesNet
VERSION 1.0.5
VERSION 1.0.5.1
DESCRIPTION "Bayesian Network and basic classifiers Library."
HOMEPAGE_URL "https://github.com/rmontanana/bayesnet"
LANGUAGES CXX

View File

@ -16,7 +16,7 @@ namespace bayesnet {
addNodes();
std::vector<int> attributes;
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
if (std::find(parents.begin(), parents.end(), i) != parents.end()) {
if (std::find(parents.begin(), parents.end(), i) == parents.end()) {
attributes.push_back(i);
}
}
@ -25,6 +25,7 @@ namespace bayesnet {
for (const auto& attribute : attributes) {
model.addEdge(className, features[attribute]);
for (const auto& root : parents) {
model.addEdge(features[root], features[attribute]);
}
}

View File

@ -27,10 +27,11 @@ namespace bayesnet {
significanceModels.clear();
for (int i = 0; i < features.size() - 1; ++i) {
for (int j = i + 1; j < features.size(); ++j) {
models.push_back(std::make_unique<SPnDE>(std::vector<int>({ i, j })));
auto model = std::make_unique<SPnDE>(std::vector<int>({ i, j }));
models.push_back(std::move(model));
}
}
n_models = models.size();
n_models = static_cast<unsigned>(models.size());
significanceModels = std::vector<double>(n_models, 1.0);
}
std::vector<std::string> A2DE::graph(const std::string& title) const

View File

@ -9,7 +9,7 @@ if(ENABLE_TESTING)
)
file(GLOB_RECURSE BayesNet_SOURCES "${BayesNet_SOURCE_DIR}/bayesnet/*.cc")
add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesClassifier.cc
TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestBoostAODE.cc
TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestBoostAODE.cc TestA2DE.cc
TestUtils.cc TestBayesEnsemble.cc ${BayesNet_SOURCES})
target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp PRIVATE Catch2::Catch2WithMain)
add_test(NAME BayesNetworkTest COMMAND TestBayesNet)

View File

@ -17,10 +17,8 @@ TEST_CASE("Fit and Score", "[A2DE]")
auto raw = RawDatasets("glass", true);
auto clf = bayesnet::A2DE();
clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states);
std::cout << "Score A2DE: " << clf.score(raw.Xv, raw.yv) << std::endl;
// REQUIRE(clf.getNumberOfNodes() == 90);
// REQUIRE(clf.getNumberOfEdges() == 153);
// REQUIRE(clf.getNotes().size() == 2);
// REQUIRE(clf.getNotes()[0] == "Used features in initialization: 6 of 9 with CFS");
// REQUIRE(clf.getNotes()[1] == "Number of models: 9");
REQUIRE(clf.score(raw.Xv, raw.yv) == Catch::Approx(0.831776).epsilon(raw.epsilon));
REQUIRE(clf.getNumberOfNodes() == 360);
REQUIRE(clf.getNumberOfEdges() == 756);
REQUIRE(clf.getNotes().size() == 0);
}