diff --git a/sample/sample.cc b/sample/sample.cc index 9d7175f..10b80ba 100644 --- a/sample/sample.cc +++ b/sample/sample.cc @@ -4,7 +4,7 @@ #include #include #include -#include "ArffFiles.h" +#include "ArffFiles.h" #include "BayesMetrics.h" #include "CPPFImdlp.h" #include "Folding.h" diff --git a/src/BayesNet/Network.cc b/src/BayesNet/Network.cc index bcf4301..6434d5d 100644 --- a/src/BayesNet/Network.cc +++ b/src/BayesNet/Network.cc @@ -157,7 +157,7 @@ namespace bayesnet { completeFit(states, weights); } // input_data comes in nxm, where n is the number of features and m the number of samples - void Network::fit(const vector>& input_data, const vector& labels, const vector& weights_, const vector& featureNames, const string& className, const map>& states) + void Network::fit(const vector>& input_data, const vector& labels, const vector& weights_, const vector& featureNames, const string& className, const map>& states) { const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64); checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights); diff --git a/src/BayesNet/Network.h b/src/BayesNet/Network.h index 0bf1b08..e720c52 100644 --- a/src/BayesNet/Network.h +++ b/src/BayesNet/Network.h @@ -39,7 +39,7 @@ namespace bayesnet { int getNumEdges() const; int getClassNumStates() const; string getClassName() const; - void fit(const vector>& input_data, const vector& labels, const vector& weights, const vector& featureNames, const string& className, const map>& states); + void fit(const vector>& input_data, const vector& labels, const vector& weights, const vector& featureNames, const string& className, const map>& states); void fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const vector& featureNames, const string& className, const map>& states); void fit(const torch::Tensor& samples, const torch::Tensor& weights, const vector& featureNames, const string& className, const 
map>& states); vector predict(const vector>&); // Return mx1 vector of predictions diff --git a/tests/TestBayesNetwork.cc b/tests/TestBayesNetwork.cc index 0a31ca8..128bd16 100644 --- a/tests/TestBayesNetwork.cc +++ b/tests/TestBayesNetwork.cc @@ -3,11 +3,12 @@ #include #include #include "TestUtils.h" -#include "KDB.h" +#include "Network.h" TEST_CASE("Test Bayesian Network", "[BayesNet]") { - auto [Xd, y, features, className, states] = loadFile("iris"); + + auto raw = RawDatasets("iris", true); SECTION("Test get features") { @@ -27,7 +28,144 @@ TEST_CASE("Test Bayesian Network", "[BayesNet]") net.addEdge("A", "B"); net.addEdge("B", "C"); REQUIRE(net.getEdges() == vector>{ {"A", "B"}, { "B", "C" } }); + REQUIRE(net.getNumEdges() == 2); net.addEdge("A", "C"); REQUIRE(net.getEdges() == vector>{ {"A", "B"}, { "A", "C" }, { "B", "C" } }); + REQUIRE(net.getNumEdges() == 3); } + SECTION("Test getNodes") + { + auto net = bayesnet::Network(); + net.addNode("A"); + net.addNode("B"); + auto& nodes = net.getNodes(); + REQUIRE(nodes.count("A") == 1); + REQUIRE(nodes.count("B") == 1); + } + + SECTION("Test fit") + { + auto net = bayesnet::Network(); + // net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv); + net.fit(raw.Xt, raw.yt, raw.weights, raw.featurest, raw.classNamet, raw.statest); + REQUIRE(net.getClassName() == "class"); + } + + // SECTION("Test predict") + // { + // auto net = bayesnet::Network(); + // net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv); + // vector> test = { {1, 2, 0, 1}, {0, 1, 2, 0}, {1, 1, 1, 1}, {0, 0, 0, 0}, {2, 2, 2, 2} }; + // vector y_test = { 0, 1, 1, 0, 2 }; + // auto y_pred = net.predict(test); + // REQUIRE(y_pred == y_test); + // } + + // SECTION("Test predict_proba") + // { + // auto net = bayesnet::Network(); + // net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv); + // vector> test = { {1, 2, 0, 1}, {0, 1, 2, 0}, {1, 1, 1, 1}, {0, 0, 0, 0}, {2, 2, 
2, 2} }; + // auto y_test = { 0, 1, 1, 0, 2 }; + // auto y_pred = net.predict(test); + // REQUIRE(y_pred == y_test); + // } } + +// SECTION("Test score") +// { +// auto net = bayesnet::Network(); +// net.fit(Xd, y, weights, features, className, states); +// auto test = { {1, 2, 0, 1}, {0, 1, 2, 0}, {1, 1, 1, 1}, {0, 0, 0, 0}, {2, 2, 2, 2} }; +// auto score = net.score(X, y); +// REQUIRE(score == Catch::Approx(); +// } + +// SECTION("Test topological_sort") +// { +// auto net = bayesnet::Network(); +// net.addNode("A"); +// net.addNode("B"); +// net.addNode("C"); +// net.addEdge("A", "B"); +// net.addEdge("A", "C"); +// auto sorted = net.topological_sort(); +// REQUIRE(sorted.size() == 3); +// REQUIRE(sorted[0] == "A"); +// REQUIRE((sorted[1] == "B" && sorted[2] == "C") || (sorted[1] == "C" && sorted[2] == "B")); +// } + +// SECTION("Test show") +// { +// auto net = bayesnet::Network(); +// net.addNode("A"); +// net.addNode("B"); +// net.addNode("C"); +// net.addEdge("A", "B"); +// net.addEdge("A", "C"); +// auto str = net.show(); +// REQUIRE(str.size() == 3); +// REQUIRE(str[0] == "A"); +// REQUIRE(str[1] == "B -> C"); +// REQUIRE(str[2] == "C"); +// } + +// SECTION("Test graph") +// { +// auto net = bayesnet::Network(); +// net.addNode("A"); +// net.addNode("B"); +// net.addNode("C"); +// net.addEdge("A", "B"); +// net.addEdge("A", "C"); +// auto str = net.graph("Test Graph"); +// REQUIRE(str.size() == 6); +// REQUIRE(str[0] == "digraph \"Test Graph\" {"); +// REQUIRE(str[1] == " A -> B;"); +// REQUIRE(str[2] == " A -> C;"); +// REQUIRE(str[3] == " B [shape=ellipse];"); +// REQUIRE(str[4] == " C [shape=ellipse];"); +// REQUIRE(str[5] == "}"); +// } + +// SECTION("Test initialize") +// { +// auto net = bayesnet::Network(); +// net.addNode("A"); +// net.addNode("B"); +// net.addNode("C"); +// net.addEdge("A", "B"); +// net.addEdge("A", "C"); +// net.initialize(); +// REQUIRE(net.getNodes().size() == 0); +// REQUIRE(net.getEdges().size() == 0); +// 
REQUIRE(net.getFeatures().size() == 0); +// REQUIRE(net.getClassNumStates() == 0); +// REQUIRE(net.getClassName().empty()); +// REQUIRE(net.getStates() == 0); +// REQUIRE(net.getSamples().numel() == 0); +// } + +// SECTION("Test dump_cpt") +// { +// auto net = bayesnet::Network(); +// net.addNode("A"); +// net.addNode("B"); +// net.addNode("C"); +// net.addEdge("A", "B"); +// net.addEdge("A", "C"); +// net.setClassName("C"); +// net.setStates({ {"A", {0, 1}}, {"B", {0, 1}}, {"C", {0, 1, 2}} }); +// net.fit({ {0, 0}, {0, 1}, {1, 0}, {1, 1} }, { 0, 1, 1, 2 }, {}, { "A", "B" }, "C", { {"A", {0, 1}}, {"B", {0, 1}}, {"C", {0, 1, 2}} }); +// net.dump_cpt(); +// // TODO: Check that the file was created and contains the expected data +// } + +// SECTION("Test version") +// { +// auto net = bayesnet::Network(); +// REQUIRE(net.version() == "0.2.0"); +// } +// } + +// } diff --git a/tests/TestUtils.h b/tests/TestUtils.h index f442f85..1d091a7 100644 --- a/tests/TestUtils.h +++ b/tests/TestUtils.h @@ -27,10 +27,12 @@ public: dataset = torch::cat({ Xt, yresized }, 0); nSamples = dataset.size(1); weights = torch::full({ nSamples }, 1.0 / nSamples, torch::kDouble); + weightsv = vector(nSamples, 1.0 / nSamples); classNumStates = discretize ? statest.at(classNamet).size() : 0; } torch::Tensor Xt, yt, dataset, weights; vector> Xv; + vector weightsv; vector yv; vector featurest, featuresv; map> statest, statesv;