Refactor Cestnik smoothing factor assuming m=1
parent b34869cc61
commit ca0ae4dacf
@@ -7,7 +7,7 @@
 [![Security Rating](https://sonarcloud.io/api/project_badges/measure?project=rmontanana_BayesNet&metric=security_rating)](https://sonarcloud.io/summary/new_code?id=rmontanana_BayesNet)
 [![Reliability Rating](https://sonarcloud.io/api/project_badges/measure?project=rmontanana_BayesNet&metric=reliability_rating)](https://sonarcloud.io/summary/new_code?id=rmontanana_BayesNet)
 ![Gitea Last Commit](https://img.shields.io/gitea/last-commit/rmontanana/bayesnet?gitea_url=https://gitea.rmontanana.es:3000&logo=gitea)
-[![Coverage Badge](https://img.shields.io/badge/Coverage-96,9%25-green)](html/index.html)
+[![Coverage Badge](https://img.shields.io/badge/Coverage-97,0%25-green)](html/index.html)

 Bayesian Network Classifiers using libtorch from scratch

@@ -204,8 +204,8 @@ namespace bayesnet {
        case Smoothing_t::LAPLACE:
            smoothing_factor = 1.0;
            break;
-       case Smoothing_t::CESTNIK:
-           smoothing_factor = n_samples / numStates;
+       case Smoothing_t::CESTNIK: // Considering m=1 pa = 1/numStates
+           smoothing_factor = 1 / numStates;
            break;
        default:
            throw std::invalid_argument("Smoothing method not recognized " + std::to_string(static_cast<int>(smoothing)));

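For context, the Cestnik method is the m-estimate: P(a | c) is estimated as (n_ac + m * p_a) / (n_c + m), where p_a is the prior probability of value a. Taking m = 1 and a uniform prior p_a = 1/numStates, the additive term per count becomes 1/numStates, which is the factor assigned above. A minimal standalone sketch of that estimate (mEstimate is an illustrative helper, not a BayesNet function):

#include <iostream>

// m-estimate of P(a | c): (count + m * prior) / (total + m).
// With m = 1 and prior = 1.0 / numStates this matches the CESTNIK factor above;
// with m = numStates and the same prior it reduces to plain Laplace smoothing.
double mEstimate(int count, int total, double m, double prior)
{
    return (count + m * prior) / (total + m);
}

int main()
{
    const int numStates = 2; // cardinality of the variable being estimated (A in the new test below)
    // P(A = 1 | C = 0) for the test data added below: 3 of the 4 samples with C = 0 have A = 1
    std::cout << mEstimate(3, 4, 1.0, 1.0 / numStates) << "\n";       // Cestnik, m = 1: 3.5 / 5 = 0.7
    std::cout << mEstimate(3, 4, numStates, 1.0 / numStates) << "\n"; // Laplace:        4 / 6  = 0.6667
    return 0;
}
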
@@ -15,6 +15,7 @@
 #include "bayesnet/network/Node.h"
 #include "bayesnet/utils/bayesnetUtils.h"

+const double threshold = 1e-4;
 void buildModel(bayesnet::Network& net, const std::vector<std::string>& features, const std::string& className)
 {
     std::vector<pair<int, int>> network = { {0, 1}, {0, 2}, {1, 3} };
@@ -29,13 +30,11 @@ void buildModel(bayesnet::Network& net, const std::vector<std::string>& features
         net.addEdge(className, feature);
     }
 }

 TEST_CASE("Test Bayesian Network", "[Network]")
 {
-
     auto raw = RawDatasets("iris", true);
     auto net = bayesnet::Network();
-    double threshold = 1e-4;

     SECTION("Test get features")
     {
@@ -459,3 +458,69 @@ TEST_CASE("Dump CPT", "[Network]")
     REQUIRE(res == expected);
 }

+TEST_CASE("Test Smoothing", "[Network]")
+{
+    /*
+
+       Taking m = 1 and Pa = 0.5 (uniform prior over the two states of A).
+       If I am computing P(A | C), with C in { 0, 1, 2 }, and I have:
+       AC = { 11, 12, 11, 10, 10, 12, 10, 01, 00, 02 }
+       Then:
+       P(A = 1 | C = 0) = (3 + 1/2 * 1) / (4 + 1) = 3.5 / 5
+       P(A = 0 | C = 0) = (1 + 1/2 * 1) / (4 + 1) = 1.5 / 5
+       Here the 4 in the denominator is the number of samples with C = 0, the conditioning variable, while the prior is again over A, the variable whose conditional distribution we are estimating.
+       P(A = 1 | C = 1) = (2 + 1/2 * 1) / (3 + 1) = 2.5 / 4
+       P(A = 0 | C = 1) = (1 + 1/2 * 1) / (3 + 1) = 1.5 / 4
+       P(A = 1 | C = 2) = (2 + 1/2 * 1) / (3 + 1) = 2.5 / 4
+       P(A = 0 | C = 2) = (1 + 1/2 * 1) / (3 + 1) = 1.5 / 4
+       It is actually similar to Laplace, which in this case, e.g. with C = 0, would be:
+       P(A = 1 | C = 0) = (3 + 1) / (4 + 2) = 4 / 6
+       P(A = 0 | C = 0) = (1 + 1) / (4 + 2) = 2 / 6
+    */
+    auto net = bayesnet::Network();
+    net.addNode("A");
+    net.addNode("C");
+    net.addEdge("C", "A");
+    std::vector<int> C = { 1, 2, 1, 0, 0, 2, 0, 1, 0, 2 };
+    std::vector<std::vector<int>> A = { { 1, 1, 1, 1, 1, 1, 1, 0, 0, 0 } };
+    std::map<std::string, std::vector<int>> states = { { "A", {0, 1} }, { "C", {0, 1, 2} } };
+    auto weights = std::vector<double>(C.size(), 1);
+    //
+    // Laplace
+    //
+    net.fit(A, C, weights, { "A" }, "C", states, bayesnet::Smoothing_t::LAPLACE);
+    auto cpt_c_laplace = net.getNodes().at("C")->getCPT();
+    REQUIRE(cpt_c_laplace.size(0) == 3);
+    auto laplace_c = std::vector<float>({ 0.3846, 0.3077, 0.3077 });
+    for (int i = 0; i < laplace_c.size(); ++i) {
+        REQUIRE(cpt_c_laplace.index({ i }).item<float>() == Catch::Approx(laplace_c[i]).margin(threshold));
+    }
+    auto cpt_a_laplace = net.getNodes().at("A")->getCPT();
+    REQUIRE(cpt_a_laplace.size(0) == 2);
+    REQUIRE(cpt_a_laplace.size(1) == 3);
+    auto laplace_a = std::vector<std::vector<float>>({ {0.3333, 0.4000, 0.4000}, {0.6667, 0.6000, 0.6000} });
+    for (int i = 0; i < 2; ++i) {
+        for (int j = 0; j < 3; ++j) {
+            REQUIRE(cpt_a_laplace.index({ i, j }).item<float>() == Catch::Approx(laplace_a[i][j]).margin(threshold));
+        }
+    }
+    //
+    // Cestnik
+    //
+    net.fit(A, C, weights, { "A" }, "C", states, bayesnet::Smoothing_t::CESTNIK);
+    auto cpt_c_cestnik = net.getNodes().at("C")->getCPT();
+    REQUIRE(cpt_c_cestnik.size(0) == 3);
+    auto cestnik_c = std::vector<float>({ 0.3939, 0.3030, 0.3030 });
+    for (int i = 0; i < cestnik_c.size(); ++i) {
+        REQUIRE(cpt_c_cestnik.index({ i }).item<float>() == Catch::Approx(cestnik_c[i]).margin(threshold));
+    }
+    auto cpt_a_cestnik = net.getNodes().at("A")->getCPT();
+    REQUIRE(cpt_a_cestnik.size(0) == 2);
+    REQUIRE(cpt_a_cestnik.size(1) == 3);
+    auto cestnik_a = std::vector<std::vector<float>>({ {0.3000, 0.3750, 0.3750}, {0.7000, 0.6250, 0.6250} });
+    for (int i = 0; i < 2; ++i) {
+        for (int j = 0; j < 3; ++j) {
+            REQUIRE(cpt_a_cestnik.index({ i, j }).item<float>() == Catch::Approx(cestnik_a[i][j]).margin(threshold));
+        }
+    }
+}
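As a cross-check of the expected vectors hard-coded in the test, the smoothed estimates can be reproduced by hand. Judging by those values, each CPT column gets the smoothing factor added to every raw count and is then renormalised, with factor 1 for Laplace and 1/numStates (states of the estimated variable) for Cestnik with m = 1. A small standalone sketch of that computation (smoothedColumn is an illustrative helper, not part of BayesNet):

#include <iostream>
#include <vector>

// Smooth one CPT column: add `factor` to every count and renormalise.
// factor = 1.0             -> Laplace
// factor = 1.0 / numStates -> Cestnik with m = 1
std::vector<double> smoothedColumn(const std::vector<int>& counts, double factor)
{
    double total = 0.0;
    for (int c : counts) total += c + factor;
    std::vector<double> probs;
    for (int c : counts) probs.push_back((c + factor) / total);
    return probs;
}

int main()
{
    // Marginal of C: counts {4, 3, 3} over the 10 samples, 3 states
    for (double p : smoothedColumn({ 4, 3, 3 }, 1.0)) std::cout << p << " ";     // 0.3846 0.3077 0.3077
    std::cout << "\n";
    for (double p : smoothedColumn({ 4, 3, 3 }, 1.0 / 3)) std::cout << p << " "; // 0.3939 0.3030 0.3030
    std::cout << "\n";
    // P(A | C = 0): counts {1, 3} for A = 0 and A = 1 among the 4 samples with C = 0
    for (double p : smoothedColumn({ 1, 3 }, 1.0)) std::cout << p << " ";        // 0.3333 0.6667
    std::cout << "\n";
    for (double p : smoothedColumn({ 1, 3 }, 1.0 / 2)) std::cout << p << " ";    // 0.3000 0.7000
    std::cout << "\n";
    return 0;
}

These are exactly the laplace_c and cestnik_c vectors and the first columns of laplace_a and cestnik_a used in the REQUIRE checks above.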