Compare commits
10 Commits
Author | SHA1 | Date | |
---|---|---|---|
a63a35df3f
|
|||
c7555dac3f
|
|||
f3b8150e2c
|
|||
03f8b8653b
|
|||
2163e95c4a
|
|||
b33da34655
|
|||
e17aee7bdb
|
|||
37c31ee4c2
|
|||
80afdc06f7
|
|||
|
666782217e |
5
.gitmodules
vendored
5
.gitmodules
vendored
@@ -8,11 +8,6 @@
|
|||||||
main = v2.x
|
main = v2.x
|
||||||
update = merge
|
update = merge
|
||||||
url = https://github.com/catchorg/Catch2.git
|
url = https://github.com/catchorg/Catch2.git
|
||||||
[submodule "lib/argparse"]
|
|
||||||
path = lib/argparse
|
|
||||||
url = https://github.com/p-ranav/argparse
|
|
||||||
master = master
|
|
||||||
update = merge
|
|
||||||
[submodule "lib/json"]
|
[submodule "lib/json"]
|
||||||
path = lib/json
|
path = lib/json
|
||||||
url = https://github.com/nlohmann/json.git
|
url = https://github.com/nlohmann/json.git
|
||||||
|
6
.vscode/launch.json
vendored
6
.vscode/launch.json
vendored
@@ -106,12 +106,12 @@
|
|||||||
"type": "lldb",
|
"type": "lldb",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"name": "test",
|
"name": "test",
|
||||||
"program": "${workspaceFolder}/build_debug/tests/unit_tests",
|
"program": "${workspaceFolder}/build_debug/tests/unit_tests_bayesnet",
|
||||||
"args": [
|
"args": [
|
||||||
"-c=\"Metrics Test\"",
|
//"-c=\"Metrics Test\"",
|
||||||
// "-s",
|
// "-s",
|
||||||
],
|
],
|
||||||
"cwd": "${workspaceFolder}/build/tests",
|
"cwd": "${workspaceFolder}/build_debug/tests",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Build & debug active file",
|
"name": "Build & debug active file",
|
||||||
|
32
CHANGELOG.md
Normal file
32
CHANGELOG.md
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
# Changelog
|
||||||
|
|
||||||
|
All notable changes to this project will be documented in this file.
|
||||||
|
|
||||||
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||||
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
## [Unreleased]
|
||||||
|
|
||||||
|
## [1.0.2] - 2024-02-20
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Fix bug in BoostAODE: do not include the model if epsilon sub t is greater than 0.5
|
||||||
|
- Fix bug in BoostAODE: compare accuracy with previous accuracy instead of the first of the ensemble if convergence true
|
||||||
|
|
||||||
|
## [1.0.1] - 2024-02-12
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- Notes in Classifier class
|
||||||
|
- BoostAODE: Add note with used features in initialization with feature selection
|
||||||
|
- BoostAODE: Add note with the number of models
|
||||||
|
- BoostAODE: Add note with the number of features used to create models if not all features are used
|
||||||
|
- Test version number in TestBayesModels
|
||||||
|
- Add tests with feature_select and notes on BoostAODE
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Network predict test
|
||||||
|
- Network predict_proba test
|
||||||
|
- Network score test
|
@@ -1,7 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 3.20)
|
cmake_minimum_required(VERSION 3.20)
|
||||||
|
|
||||||
project(BayesNet
|
project(BayesNet
|
||||||
VERSION 1.0.0
|
VERSION 1.0.2
|
||||||
DESCRIPTION "Bayesian Network and basic classifiers Library."
|
DESCRIPTION "Bayesian Network and basic classifiers Library."
|
||||||
HOMEPAGE_URL "https://github.com/rmontanana/bayesnet"
|
HOMEPAGE_URL "https://github.com/rmontanana/bayesnet"
|
||||||
LANGUAGES CXX
|
LANGUAGES CXX
|
||||||
@@ -52,7 +52,6 @@ endif (ENABLE_CLANG_TIDY)
|
|||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
# include(FetchContent)
|
# include(FetchContent)
|
||||||
add_git_submodule("lib/mdlp")
|
add_git_submodule("lib/mdlp")
|
||||||
add_git_submodule("lib/argparse")
|
|
||||||
add_git_submodule("lib/json")
|
add_git_submodule("lib/json")
|
||||||
|
|
||||||
# Subdirectories
|
# Subdirectories
|
||||||
|
BIN
docs/BoostAODE.docx
Normal file
BIN
docs/BoostAODE.docx
Normal file
Binary file not shown.
Submodule lib/catch2 updated: 766541d12d...863c662c0e
Submodule lib/folding updated: a3a2977996...37316a54e0
2
lib/json
2
lib/json
Submodule lib/json updated: edffad036d...a259ecc51e
1
lib/libxlsxwriter
Submodule
1
lib/libxlsxwriter
Submodule
Submodule lib/libxlsxwriter added at 29355a0887
@@ -26,6 +26,7 @@ namespace bayesnet {
|
|||||||
std::vector<std::string> virtual graph(const std::string& title = "") const = 0;
|
std::vector<std::string> virtual graph(const std::string& title = "") const = 0;
|
||||||
virtual std::string getVersion() = 0;
|
virtual std::string getVersion() = 0;
|
||||||
std::vector<std::string> virtual topological_order() = 0;
|
std::vector<std::string> virtual topological_order() = 0;
|
||||||
|
std::vector<std::string> virtual getNotes() const = 0;
|
||||||
void virtual dump_cpt()const = 0;
|
void virtual dump_cpt()const = 0;
|
||||||
virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0;
|
virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0;
|
||||||
std::vector<std::string>& getValidHyperparameters() { return validHyperparameters; }
|
std::vector<std::string>& getValidHyperparameters() { return validHyperparameters; }
|
||||||
|
@@ -115,11 +115,15 @@ namespace bayesnet {
|
|||||||
significanceModels.push_back(1.0);
|
significanceModels.push_back(1.0);
|
||||||
n_models++;
|
n_models++;
|
||||||
}
|
}
|
||||||
|
notes.push_back("Used features in initialization: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()) + " with " + algorithm);
|
||||||
delete featureSelector;
|
delete featureSelector;
|
||||||
return featuresUsed;
|
return featuresUsed;
|
||||||
}
|
}
|
||||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
|
fitted = true;
|
||||||
|
// Algorithm based on the adaboost algorithm for classification
|
||||||
|
// as explained in Ensemble methods (Zhi-Hua Zhou, 2012)
|
||||||
std::unordered_set<int> featuresUsed;
|
std::unordered_set<int> featuresUsed;
|
||||||
if (selectFeatures) {
|
if (selectFeatures) {
|
||||||
featuresUsed = initializeModels();
|
featuresUsed = initializeModels();
|
||||||
@@ -131,9 +135,8 @@ namespace bayesnet {
|
|||||||
// Variables to control the accuracy finish condition
|
// Variables to control the accuracy finish condition
|
||||||
double priorAccuracy = 0.0;
|
double priorAccuracy = 0.0;
|
||||||
double delta = 1.0;
|
double delta = 1.0;
|
||||||
double threshold = 1e-4;
|
double convergence_threshold = 1e-4;
|
||||||
int count = 0; // number of times the accuracy is lower than the threshold
|
int count = 0; // number of times the accuracy is lower than the convergence_threshold
|
||||||
fitted = true; // to enable predict
|
|
||||||
// Step 0: Set the finish condition
|
// Step 0: Set the finish condition
|
||||||
// if not repeatSparent a finish condition is run out of features
|
// if not repeatSparent a finish condition is run out of features
|
||||||
// n_models == maxModels
|
// n_models == maxModels
|
||||||
@@ -159,7 +162,6 @@ namespace bayesnet {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
featuresUsed.insert(feature);
|
|
||||||
model = std::make_unique<SPODE>(feature);
|
model = std::make_unique<SPODE>(feature);
|
||||||
model->fit(dataset, features, className, states, weights_);
|
model->fit(dataset, features, className, states, weights_);
|
||||||
auto ypred = model->predict(X_train);
|
auto ypred = model->predict(X_train);
|
||||||
@@ -168,6 +170,12 @@ namespace bayesnet {
|
|||||||
auto mask_right = ypred == y_train;
|
auto mask_right = ypred == y_train;
|
||||||
auto masked_weights = weights_ * mask_wrong.to(weights_.dtype());
|
auto masked_weights = weights_ * mask_wrong.to(weights_.dtype());
|
||||||
double epsilon_t = masked_weights.sum().item<double>();
|
double epsilon_t = masked_weights.sum().item<double>();
|
||||||
|
if (epsilon_t > 0.5) {
|
||||||
|
// Inverse the weights policy (plot ln(wt))
|
||||||
|
// "In each round of AdaBoost, there is a sanity check to ensure that the current base
|
||||||
|
// learner is better than random guess" (Zhi-Hua Zhou, 2012)
|
||||||
|
break;
|
||||||
|
}
|
||||||
double wt = (1 - epsilon_t) / epsilon_t;
|
double wt = (1 - epsilon_t) / epsilon_t;
|
||||||
double alpha_t = epsilon_t == 0 ? 1 : 0.5 * log(wt);
|
double alpha_t = epsilon_t == 0 ? 1 : 0.5 * log(wt);
|
||||||
// Step 3.2: Update weights for next classifier
|
// Step 3.2: Update weights for next classifier
|
||||||
@@ -179,6 +187,7 @@ namespace bayesnet {
|
|||||||
double totalWeights = torch::sum(weights_).item<double>();
|
double totalWeights = torch::sum(weights_).item<double>();
|
||||||
weights_ = weights_ / totalWeights;
|
weights_ = weights_ / totalWeights;
|
||||||
// Step 3.4: Store classifier and its accuracy to weigh its future vote
|
// Step 3.4: Store classifier and its accuracy to weigh its future vote
|
||||||
|
featuresUsed.insert(feature);
|
||||||
models.push_back(std::move(model));
|
models.push_back(std::move(model));
|
||||||
significanceModels.push_back(alpha_t);
|
significanceModels.push_back(alpha_t);
|
||||||
n_models++;
|
n_models++;
|
||||||
@@ -190,15 +199,18 @@ namespace bayesnet {
|
|||||||
} else {
|
} else {
|
||||||
delta = accuracy - priorAccuracy;
|
delta = accuracy - priorAccuracy;
|
||||||
}
|
}
|
||||||
if (delta < threshold) {
|
if (delta < convergence_threshold) {
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
priorAccuracy = accuracy;
|
||||||
}
|
}
|
||||||
exitCondition = n_models >= maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
|
exitCondition = n_models >= maxModels && repeatSparent || count > tolerance;
|
||||||
}
|
}
|
||||||
if (featuresUsed.size() != features.size()) {
|
if (featuresUsed.size() != features.size()) {
|
||||||
|
notes.push_back("Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()));
|
||||||
status = WARNING;
|
status = WARNING;
|
||||||
}
|
}
|
||||||
|
notes.push_back("Number of models: " + std::to_string(n_models));
|
||||||
}
|
}
|
||||||
std::vector<std::string> BoostAODE::graph(const std::string& title) const
|
std::vector<std::string> BoostAODE::graph(const std::string& title) const
|
||||||
{
|
{
|
||||||
|
@@ -19,6 +19,7 @@ namespace bayesnet {
|
|||||||
std::map<std::string, std::vector<int>> states;
|
std::map<std::string, std::vector<int>> states;
|
||||||
torch::Tensor dataset; // (n+1)xm tensor
|
torch::Tensor dataset; // (n+1)xm tensor
|
||||||
status_t status = NORMAL;
|
status_t status = NORMAL;
|
||||||
|
std::vector<std::string> notes; // Used to store messages occurred during the fit process
|
||||||
void checkFitParameters();
|
void checkFitParameters();
|
||||||
virtual void buildModel(const torch::Tensor& weights) = 0;
|
virtual void buildModel(const torch::Tensor& weights) = 0;
|
||||||
void trainModel(const torch::Tensor& weights) override;
|
void trainModel(const torch::Tensor& weights) override;
|
||||||
@@ -36,12 +37,13 @@ namespace bayesnet {
|
|||||||
int getNumberOfStates() const override;
|
int getNumberOfStates() const override;
|
||||||
torch::Tensor predict(torch::Tensor& X) override;
|
torch::Tensor predict(torch::Tensor& X) override;
|
||||||
status_t getStatus() const override { return status; }
|
status_t getStatus() const override { return status; }
|
||||||
std::string getVersion() override { return "0.2.0"; };
|
std::string getVersion() override { return { project_version.begin(), project_version.end() }; };
|
||||||
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
|
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
|
||||||
float score(torch::Tensor& X, torch::Tensor& y) override;
|
float score(torch::Tensor& X, torch::Tensor& y) override;
|
||||||
float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;
|
float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;
|
||||||
std::vector<std::string> show() const override;
|
std::vector<std::string> show() const override;
|
||||||
std::vector<std::string> topological_order() override;
|
std::vector<std::string> topological_order() override;
|
||||||
|
std::vector<std::string> getNotes() const override { return notes; }
|
||||||
void dump_cpt() const override;
|
void dump_cpt() const override;
|
||||||
void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
|
void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
|
||||||
};
|
};
|
||||||
|
@@ -7,7 +7,6 @@ if(ENABLE_TESTING)
|
|||||||
${BayesNet_SOURCE_DIR}/lib/mdlp
|
${BayesNet_SOURCE_DIR}/lib/mdlp
|
||||||
${BayesNet_SOURCE_DIR}/lib/folding
|
${BayesNet_SOURCE_DIR}/lib/folding
|
||||||
${BayesNet_SOURCE_DIR}/lib/json/include
|
${BayesNet_SOURCE_DIR}/lib/json/include
|
||||||
${BayesNet_SOURCE_DIR}/lib/argparse/include
|
|
||||||
${CMAKE_BINARY_DIR}/configured_files/include
|
${CMAKE_BINARY_DIR}/configured_files/include
|
||||||
)
|
)
|
||||||
set(TEST_SOURCES_BAYESNET TestBayesModels.cc TestBayesNetwork.cc TestBayesMetrics.cc TestUtils.cc ${BayesNet_SOURCES})
|
set(TEST_SOURCES_BAYESNET TestBayesModels.cc TestBayesNetwork.cc TestBayesMetrics.cc TestUtils.cc ${BayesNet_SOURCES})
|
||||||
|
@@ -16,6 +16,11 @@
|
|||||||
#include "AODELd.h"
|
#include "AODELd.h"
|
||||||
#include "TestUtils.h"
|
#include "TestUtils.h"
|
||||||
|
|
||||||
|
TEST_CASE("Library check version", "[BayesNet]")
|
||||||
|
{
|
||||||
|
auto clf = bayesnet::KDB(2);
|
||||||
|
REQUIRE(clf.getVersion() == "1.0.2");
|
||||||
|
}
|
||||||
TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]")
|
TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]")
|
||||||
{
|
{
|
||||||
map <pair<std::string, std::string>, float> scores = {
|
map <pair<std::string, std::string>, float> scores = {
|
||||||
@@ -138,4 +143,35 @@ TEST_CASE("Get num features & num edges", "[BayesNet]")
|
|||||||
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
REQUIRE(clf.getNumberOfNodes() == 5);
|
REQUIRE(clf.getNumberOfNodes() == 5);
|
||||||
REQUIRE(clf.getNumberOfEdges() == 8);
|
REQUIRE(clf.getNumberOfEdges() == 8);
|
||||||
}
|
}
|
||||||
|
TEST_CASE("BoostAODE feature_select CFS", "[BayesNet]")
|
||||||
|
{
|
||||||
|
auto raw = RawDatasets("glass", true);
|
||||||
|
auto clf = bayesnet::BoostAODE();
|
||||||
|
clf.setHyperparameters({ {"select_features", "CFS"} });
|
||||||
|
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
|
REQUIRE(clf.getNumberOfNodes() == 90);
|
||||||
|
REQUIRE(clf.getNumberOfEdges() == 153);
|
||||||
|
REQUIRE(clf.getNotes().size() == 2);
|
||||||
|
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 6 of 9 with CFS");
|
||||||
|
REQUIRE(clf.getNotes()[1] == "Number of models: 9");
|
||||||
|
}
|
||||||
|
TEST_CASE("BoostAODE test used features in train note", "[BayesNet]")
|
||||||
|
{
|
||||||
|
auto raw = RawDatasets("diabetes", true);
|
||||||
|
auto clf = bayesnet::BoostAODE();
|
||||||
|
clf.setHyperparameters({
|
||||||
|
{"ascending",true},
|
||||||
|
{"convergence", true},
|
||||||
|
{"repeatSparent",true},
|
||||||
|
{"select_features","CFS"},
|
||||||
|
{"tolerance", 3}
|
||||||
|
});
|
||||||
|
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
|
REQUIRE(clf.getNumberOfNodes() == 72);
|
||||||
|
REQUIRE(clf.getNumberOfEdges() == 120);
|
||||||
|
REQUIRE(clf.getNotes().size() == 3);
|
||||||
|
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 6 of 8 with CFS");
|
||||||
|
REQUIRE(clf.getNotes()[1] == "Used features in train: 7 of 8");
|
||||||
|
REQUIRE(clf.getNotes()[2] == "Number of models: 8");
|
||||||
|
}
|
||||||
|
@@ -25,6 +25,7 @@ TEST_CASE("Test Bayesian Network", "[BayesNet]")
|
|||||||
|
|
||||||
auto raw = RawDatasets("iris", true);
|
auto raw = RawDatasets("iris", true);
|
||||||
auto net = bayesnet::Network();
|
auto net = bayesnet::Network();
|
||||||
|
double threshold = 1e-4;
|
||||||
|
|
||||||
SECTION("Test get features")
|
SECTION("Test get features")
|
||||||
{
|
{
|
||||||
@@ -167,97 +168,44 @@ TEST_CASE("Test Bayesian Network", "[BayesNet]")
|
|||||||
REQUIRE(str[5] == "C [shape=circle] \n");
|
REQUIRE(str[5] == "C [shape=circle] \n");
|
||||||
REQUIRE(str[6] == "}\n");
|
REQUIRE(str[6] == "}\n");
|
||||||
}
|
}
|
||||||
|
SECTION("Test predict")
|
||||||
|
{
|
||||||
// SECTION("Test predict")
|
auto net = bayesnet::Network();
|
||||||
// {
|
buildModel(net, raw.featuresv, raw.classNamev);
|
||||||
// auto net = bayesnet::Network();
|
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
// net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
std::vector<std::vector<int>> test = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1}, {2, 2, 2, 2, 1} };
|
||||||
// std::vector<std::vector<int>> test = { {1, 2, 0, 1}, {0, 1, 2, 0}, {1, 1, 1, 1}, {0, 0, 0, 0}, {2, 2, 2, 2} };
|
std::vector<int> y_test = { 2, 2, 0, 2, 1 };
|
||||||
// std::vector<int> y_test = { 0, 1, 1, 0, 2 };
|
auto y_pred = net.predict(test);
|
||||||
// auto y_pred = net.predict(test);
|
REQUIRE(y_pred == y_test);
|
||||||
// REQUIRE(y_pred == y_test);
|
}
|
||||||
// }
|
SECTION("Test predict_proba")
|
||||||
|
{
|
||||||
// SECTION("Test predict_proba")
|
auto net = bayesnet::Network();
|
||||||
// {
|
buildModel(net, raw.featuresv, raw.classNamev);
|
||||||
// auto net = bayesnet::Network();
|
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
// net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
std::vector<std::vector<int>> test = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1}, {2, 2, 2, 2, 1} };
|
||||||
// std::vector<std::vector<int>> test = { {1, 2, 0, 1}, {0, 1, 2, 0}, {1, 1, 1, 1}, {0, 0, 0, 0}, {2, 2, 2, 2} };
|
std::vector<std::vector<double>> y_test = {
|
||||||
// auto y_test = { 0, 1, 1, 0, 2 };
|
{0.450237, 0.0866621, 0.463101},
|
||||||
// auto y_pred = net.predict(test);
|
{0.244443, 0.0925922, 0.662964},
|
||||||
// REQUIRE(y_pred == y_test);
|
{0.913441, 0.0125857, 0.0739732},
|
||||||
// }
|
{0.450237, 0.0866621, 0.463101},
|
||||||
}
|
{0.0135226, 0.971726, 0.0147519}
|
||||||
|
};
|
||||||
// SECTION("Test score")
|
auto y_pred = net.predict_proba(test);
|
||||||
// {
|
REQUIRE(y_pred.size() == 5);
|
||||||
// auto net = bayesnet::Network();
|
REQUIRE(y_pred[0].size() == 3);
|
||||||
// net.fit(Xd, y, weights, features, className, states);
|
for (int i = 0; i < y_pred.size(); ++i) {
|
||||||
// auto test = { {1, 2, 0, 1}, {0, 1, 2, 0}, {1, 1, 1, 1}, {0, 0, 0, 0}, {2, 2, 2, 2} };
|
for (int j = 0; j < y_pred[i].size(); ++j) {
|
||||||
// auto score = net.score(X, y);
|
REQUIRE(y_pred[i][j] == Catch::Approx(y_test[i][j]).margin(threshold));
|
||||||
// REQUIRE(score == Catch::Approx();
|
}
|
||||||
// }
|
}
|
||||||
|
}
|
||||||
//
|
SECTION("Test score")
|
||||||
//
|
{
|
||||||
|
auto net = bayesnet::Network();
|
||||||
// SECTION("Test graph")
|
buildModel(net, raw.featuresv, raw.classNamev);
|
||||||
// {
|
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
// auto net = bayesnet::Network();
|
auto score = net.score(raw.Xv, raw.yv);
|
||||||
// net.addNode("A");
|
REQUIRE(score == Catch::Approx(0.97333333).margin(threshold));
|
||||||
// net.addNode("B");
|
}
|
||||||
// net.addNode("C");
|
}
|
||||||
// net.addEdge("A", "B");
|
|
||||||
// net.addEdge("A", "C");
|
|
||||||
// auto str = net.graph("Test Graph");
|
|
||||||
// REQUIRE(str.size() == 6);
|
|
||||||
// REQUIRE(str[0] == "digraph \"Test Graph\" {");
|
|
||||||
// REQUIRE(str[1] == " A -> B;");
|
|
||||||
// REQUIRE(str[2] == " A -> C;");
|
|
||||||
// REQUIRE(str[3] == " B [shape=ellipse];");
|
|
||||||
// REQUIRE(str[4] == " C [shape=ellipse];");
|
|
||||||
// REQUIRE(str[5] == "}");
|
|
||||||
// }
|
|
||||||
|
|
||||||
// SECTION("Test initialize")
|
|
||||||
// {
|
|
||||||
// auto net = bayesnet::Network();
|
|
||||||
// net.addNode("A");
|
|
||||||
// net.addNode("B");
|
|
||||||
// net.addNode("C");
|
|
||||||
// net.addEdge("A", "B");
|
|
||||||
// net.addEdge("A", "C");
|
|
||||||
// net.initialize();
|
|
||||||
// REQUIRE(net.getNodes().size() == 0);
|
|
||||||
// REQUIRE(net.getEdges().size() == 0);
|
|
||||||
// REQUIRE(net.getFeatures().size() == 0);
|
|
||||||
// REQUIRE(net.getClassNumStates() == 0);
|
|
||||||
// REQUIRE(net.getClassName().empty());
|
|
||||||
// REQUIRE(net.getStates() == 0);
|
|
||||||
// REQUIRE(net.getSamples().numel() == 0);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// SECTION("Test dump_cpt")
|
|
||||||
// {
|
|
||||||
// auto net = bayesnet::Network();
|
|
||||||
// net.addNode("A");
|
|
||||||
// net.addNode("B");
|
|
||||||
// net.addNode("C");
|
|
||||||
// net.addEdge("A", "B");
|
|
||||||
// net.addEdge("A", "C");
|
|
||||||
// net.setClassName("C");
|
|
||||||
// net.setStates({ {"A", {0, 1}}, {"B", {0, 1}}, {"C", {0, 1, 2}} });
|
|
||||||
// net.fit({ {0, 0}, {0, 1}, {1, 0}, {1, 1} }, { 0, 1, 1, 2 }, {}, { "A", "B" }, "C", { {"A", {0, 1}}, {"B", {0, 1}}, {"C", {0, 1, 2}} });
|
|
||||||
// net.dump_cpt();
|
|
||||||
// // TODO: Check that the file was created and contains the expected data
|
|
||||||
// }
|
|
||||||
|
|
||||||
// SECTION("Test version")
|
|
||||||
// {
|
|
||||||
// auto net = bayesnet::Network();
|
|
||||||
// REQUIRE(net.version() == "0.2.0");
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// }
|
|
Reference in New Issue
Block a user