Continue test Network
This commit is contained in:
parent
4b732e76c2
commit
e3ae073333
@ -4,7 +4,7 @@
|
||||
#include <map>
|
||||
#include <argparse/argparse.hpp>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include "ArffFiles.h"
|
||||
#include "ArffFiles.h"v
|
||||
#include "BayesMetrics.h"
|
||||
#include "CPPFImdlp.h"
|
||||
#include "Folding.h"
|
||||
|
@ -157,7 +157,7 @@ namespace bayesnet {
|
||||
completeFit(states, weights);
|
||||
}
|
||||
// input_data comes in nxm, where n is the number of features and m the number of samples
|
||||
void Network::fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<float>& weights_, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
|
||||
void Network::fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<double>& weights_, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
|
||||
{
|
||||
const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
|
||||
checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights);
|
||||
|
@ -39,7 +39,7 @@ namespace bayesnet {
|
||||
int getNumEdges() const;
|
||||
int getClassNumStates() const;
|
||||
string getClassName() const;
|
||||
void fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<float>& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states);
|
||||
void fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<double>& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states);
|
||||
void fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states);
|
||||
void fit(const torch::Tensor& samples, const torch::Tensor& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states);
|
||||
vector<int> predict(const vector<vector<int>>&); // Return mx1 vector of predictions
|
||||
|
@ -3,11 +3,12 @@
|
||||
#include <catch2/generators/catch_generators.hpp>
|
||||
#include <string>
|
||||
#include "TestUtils.h"
|
||||
#include "KDB.h"
|
||||
#include "Network.h"
|
||||
|
||||
TEST_CASE("Test Bayesian Network", "[BayesNet]")
|
||||
{
|
||||
auto [Xd, y, features, className, states] = loadFile("iris");
|
||||
|
||||
auto raw = RawDatasets("iris", true);
|
||||
|
||||
SECTION("Test get features")
|
||||
{
|
||||
@ -27,7 +28,144 @@ TEST_CASE("Test Bayesian Network", "[BayesNet]")
|
||||
net.addEdge("A", "B");
|
||||
net.addEdge("B", "C");
|
||||
REQUIRE(net.getEdges() == vector<pair<string, string>>{ {"A", "B"}, { "B", "C" } });
|
||||
REQUIRE(net.getNumEdges() == 2);
|
||||
net.addEdge("A", "C");
|
||||
REQUIRE(net.getEdges() == vector<pair<string, string>>{ {"A", "B"}, { "A", "C" }, { "B", "C" } });
|
||||
REQUIRE(net.getNumEdges() == 3);
|
||||
}
|
||||
SECTION("Test getNodes")
|
||||
{
|
||||
auto net = bayesnet::Network();
|
||||
net.addNode("A");
|
||||
net.addNode("B");
|
||||
auto& nodes = net.getNodes();
|
||||
REQUIRE(nodes.count("A") == 1);
|
||||
REQUIRE(nodes.count("B") == 1);
|
||||
}
|
||||
|
||||
SECTION("Test fit")
|
||||
{
|
||||
auto net = bayesnet::Network();
|
||||
// net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||
net.fit(raw.Xt, raw.yt, raw.weights, raw.featurest, raw.classNamet, raw.statest);
|
||||
REQUIRE(net.getClassName() == "class");
|
||||
}
|
||||
|
||||
// SECTION("Test predict")
|
||||
// {
|
||||
// auto net = bayesnet::Network();
|
||||
// net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||
// vector<vector<int>> test = { {1, 2, 0, 1}, {0, 1, 2, 0}, {1, 1, 1, 1}, {0, 0, 0, 0}, {2, 2, 2, 2} };
|
||||
// vector<int> y_test = { 0, 1, 1, 0, 2 };
|
||||
// auto y_pred = net.predict(test);
|
||||
// REQUIRE(y_pred == y_test);
|
||||
// }
|
||||
|
||||
// SECTION("Test predict_proba")
|
||||
// {
|
||||
// auto net = bayesnet::Network();
|
||||
// net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||
// vector<vector<int>> test = { {1, 2, 0, 1}, {0, 1, 2, 0}, {1, 1, 1, 1}, {0, 0, 0, 0}, {2, 2, 2, 2} };
|
||||
// auto y_test = { 0, 1, 1, 0, 2 };
|
||||
// auto y_pred = net.predict(test);
|
||||
// REQUIRE(y_pred == y_test);
|
||||
// }
|
||||
}
|
||||
|
||||
// SECTION("Test score")
|
||||
// {
|
||||
// auto net = bayesnet::Network();
|
||||
// net.fit(Xd, y, weights, features, className, states);
|
||||
// auto test = { {1, 2, 0, 1}, {0, 1, 2, 0}, {1, 1, 1, 1}, {0, 0, 0, 0}, {2, 2, 2, 2} };
|
||||
// auto score = net.score(X, y);
|
||||
// REQUIRE(score == Catch::Approx();
|
||||
// }
|
||||
|
||||
// SECTION("Test topological_sort")
|
||||
// {
|
||||
// auto net = bayesnet::Network();
|
||||
// net.addNode("A");
|
||||
// net.addNode("B");
|
||||
// net.addNode("C");
|
||||
// net.addEdge("A", "B");
|
||||
// net.addEdge("A", "C");
|
||||
// auto sorted = net.topological_sort();
|
||||
// REQUIRE(sorted.size() == 3);
|
||||
// REQUIRE(sorted[0] == "A");
|
||||
// REQUIRE((sorted[1] == "B" && sorted[2] == "C") || (sorted[1] == "C" && sorted[2] == "B"));
|
||||
// }
|
||||
|
||||
// SECTION("Test show")
|
||||
// {
|
||||
// auto net = bayesnet::Network();
|
||||
// net.addNode("A");
|
||||
// net.addNode("B");
|
||||
// net.addNode("C");
|
||||
// net.addEdge("A", "B");
|
||||
// net.addEdge("A", "C");
|
||||
// auto str = net.show();
|
||||
// REQUIRE(str.size() == 3);
|
||||
// REQUIRE(str[0] == "A");
|
||||
// REQUIRE(str[1] == "B -> C");
|
||||
// REQUIRE(str[2] == "C");
|
||||
// }
|
||||
|
||||
// SECTION("Test graph")
|
||||
// {
|
||||
// auto net = bayesnet::Network();
|
||||
// net.addNode("A");
|
||||
// net.addNode("B");
|
||||
// net.addNode("C");
|
||||
// net.addEdge("A", "B");
|
||||
// net.addEdge("A", "C");
|
||||
// auto str = net.graph("Test Graph");
|
||||
// REQUIRE(str.size() == 6);
|
||||
// REQUIRE(str[0] == "digraph \"Test Graph\" {");
|
||||
// REQUIRE(str[1] == " A -> B;");
|
||||
// REQUIRE(str[2] == " A -> C;");
|
||||
// REQUIRE(str[3] == " B [shape=ellipse];");
|
||||
// REQUIRE(str[4] == " C [shape=ellipse];");
|
||||
// REQUIRE(str[5] == "}");
|
||||
// }
|
||||
|
||||
// SECTION("Test initialize")
|
||||
// {
|
||||
// auto net = bayesnet::Network();
|
||||
// net.addNode("A");
|
||||
// net.addNode("B");
|
||||
// net.addNode("C");
|
||||
// net.addEdge("A", "B");
|
||||
// net.addEdge("A", "C");
|
||||
// net.initialize();
|
||||
// REQUIRE(net.getNodes().size() == 0);
|
||||
// REQUIRE(net.getEdges().size() == 0);
|
||||
// REQUIRE(net.getFeatures().size() == 0);
|
||||
// REQUIRE(net.getClassNumStates() == 0);
|
||||
// REQUIRE(net.getClassName().empty());
|
||||
// REQUIRE(net.getStates() == 0);
|
||||
// REQUIRE(net.getSamples().numel() == 0);
|
||||
// }
|
||||
|
||||
// SECTION("Test dump_cpt")
|
||||
// {
|
||||
// auto net = bayesnet::Network();
|
||||
// net.addNode("A");
|
||||
// net.addNode("B");
|
||||
// net.addNode("C");
|
||||
// net.addEdge("A", "B");
|
||||
// net.addEdge("A", "C");
|
||||
// net.setClassName("C");
|
||||
// net.setStates({ {"A", {0, 1}}, {"B", {0, 1}}, {"C", {0, 1, 2}} });
|
||||
// net.fit({ {0, 0}, {0, 1}, {1, 0}, {1, 1} }, { 0, 1, 1, 2 }, {}, { "A", "B" }, "C", { {"A", {0, 1}}, {"B", {0, 1}}, {"C", {0, 1, 2}} });
|
||||
// net.dump_cpt();
|
||||
// // TODO: Check that the file was created and contains the expected data
|
||||
// }
|
||||
|
||||
// SECTION("Test version")
|
||||
// {
|
||||
// auto net = bayesnet::Network();
|
||||
// REQUIRE(net.version() == "0.2.0");
|
||||
// }
|
||||
// }
|
||||
|
||||
// }
|
||||
|
@ -27,10 +27,12 @@ public:
|
||||
dataset = torch::cat({ Xt, yresized }, 0);
|
||||
nSamples = dataset.size(1);
|
||||
weights = torch::full({ nSamples }, 1.0 / nSamples, torch::kDouble);
|
||||
weightsv = vector<double>(nSamples, 1.0 / nSamples);
|
||||
classNumStates = discretize ? statest.at(classNamet).size() : 0;
|
||||
}
|
||||
torch::Tensor Xt, yt, dataset, weights;
|
||||
vector<vector<int>> Xv;
|
||||
vector<double> weightsv;
|
||||
vector<int> yv;
|
||||
vector<string> featurest, featuresv;
|
||||
map<string, vector<int>> statest, statesv;
|
||||
|
Loading…
Reference in New Issue
Block a user