First implemented aproximation
This commit is contained in:
@@ -9,11 +9,12 @@ if(ENABLE_TESTING)
|
||||
)
|
||||
file(GLOB_RECURSE BayesNet_SOURCES "${BayesNet_SOURCE_DIR}/bayesnet/*.cc")
|
||||
add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesClassifier.cc
|
||||
TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestBoostAODE.cc TestA2DE.cc
|
||||
TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestBoostAODE.cc TestA2DE.cc TestWA2DE.cc
|
||||
TestUtils.cc TestBayesEnsemble.cc TestModulesVersions.cc TestBoostA2DE.cc TestMST.cc ${BayesNet_SOURCES})
|
||||
target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" fimdlp PRIVATE Catch2::Catch2WithMain)
|
||||
add_test(NAME BayesNetworkTest COMMAND TestBayesNet)
|
||||
add_test(NAME A2DE COMMAND TestBayesNet "[A2DE]")
|
||||
add_test(NAME WA2DE COMMAND TestBayesNet "[WA2DE]")
|
||||
add_test(NAME BoostA2DE COMMAND TestBayesNet "[BoostA2DE]")
|
||||
add_test(NAME BoostAODE COMMAND TestBayesNet "[BoostAODE]")
|
||||
add_test(NAME Classifier COMMAND TestBayesNet "[Classifier]")
|
||||
|
31
tests/TestWA2DE.cc
Normal file
31
tests/TestWA2DE.cc
Normal file
@@ -0,0 +1,31 @@
|
||||
// ***************************************************************
|
||||
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
|
||||
#include <type_traits>
|
||||
#include <catch2/catch_test_macros.hpp>
|
||||
#include <catch2/catch_approx.hpp>
|
||||
#include <catch2/generators/catch_generators.hpp>
|
||||
#include "bayesnet/ensembles/WA2DE.h"
|
||||
#include "TestUtils.h"
|
||||
|
||||
|
||||
TEST_CASE("Fit and Score", "[WA2DE]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto clf = bayesnet::WA2DE();
|
||||
clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
REQUIRE(clf.score(raw.Xt, raw.yt) == Catch::Approx(0.831776).epsilon(raw.epsilon));
|
||||
}
|
||||
TEST_CASE("Test graph", "[WA2DE]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto clf = bayesnet::WA2DE();
|
||||
clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
auto graph = clf.graph("BayesNet WA2DE");
|
||||
REQUIRE(graph.size() == 2);
|
||||
REQUIRE(graph[0] == "BayesNet WA2DE");
|
||||
REQUIRE(graph[1] == "Graph visualization not implemented.");
|
||||
}
|
Reference in New Issue
Block a user