Compare commits
2 Commits
0d6a081d01
...
50543e7929
Author | SHA1 | Date | |
---|---|---|---|
50543e7929
|
|||
9014649a0d
|
@@ -30,7 +30,7 @@ namespace bayesnet {
|
|||||||
virtual std::string getVersion() = 0;
|
virtual std::string getVersion() = 0;
|
||||||
std::vector<std::string> virtual topological_order() = 0;
|
std::vector<std::string> virtual topological_order() = 0;
|
||||||
std::vector<std::string> virtual getNotes() const = 0;
|
std::vector<std::string> virtual getNotes() const = 0;
|
||||||
void virtual dump_cpt()const = 0;
|
std::string virtual dump_cpt()const = 0;
|
||||||
virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0;
|
virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0;
|
||||||
std::vector<std::string>& getValidHyperparameters() { return validHyperparameters; }
|
std::vector<std::string>& getValidHyperparameters() { return validHyperparameters; }
|
||||||
protected:
|
protected:
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
#include <sstream>
|
||||||
#include "bayesnet/utils/bayesnetUtils.h"
|
#include "bayesnet/utils/bayesnetUtils.h"
|
||||||
#include "Classifier.h"
|
#include "Classifier.h"
|
||||||
|
|
||||||
@@ -27,10 +28,11 @@ namespace bayesnet {
|
|||||||
dataset = torch::cat({ dataset, yresized }, 0);
|
dataset = torch::cat({ dataset, yresized }, 0);
|
||||||
}
|
}
|
||||||
catch (const std::exception& e) {
|
catch (const std::exception& e) {
|
||||||
std::cerr << e.what() << '\n';
|
std::stringstream oss;
|
||||||
std::cout << "X dimensions: " << dataset.sizes() << "\n";
|
oss << "* Error in X and y dimensions *\n";
|
||||||
std::cout << "y dimensions: " << ytmp.sizes() << "\n";
|
oss << "X dimensions: " << dataset.sizes() << "\n";
|
||||||
exit(1);
|
oss << "y dimensions: " << ytmp.sizes();
|
||||||
|
throw std::runtime_error(oss.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void Classifier::trainModel(const torch::Tensor& weights)
|
void Classifier::trainModel(const torch::Tensor& weights)
|
||||||
@@ -73,11 +75,11 @@ namespace bayesnet {
|
|||||||
if (torch::is_floating_point(dataset)) {
|
if (torch::is_floating_point(dataset)) {
|
||||||
throw std::invalid_argument("dataset (X, y) must be of type Integer");
|
throw std::invalid_argument("dataset (X, y) must be of type Integer");
|
||||||
}
|
}
|
||||||
if (n != features.size()) {
|
if (dataset.size(0) - 1 != features.size()) {
|
||||||
throw std::invalid_argument("Classifier: X " + std::to_string(n) + " and features " + std::to_string(features.size()) + " must have the same number of features");
|
throw std::invalid_argument("Classifier: X " + std::to_string(dataset.size(0) - 1) + " and features " + std::to_string(features.size()) + " must have the same number of features");
|
||||||
}
|
}
|
||||||
if (states.find(className) == states.end()) {
|
if (states.find(className) == states.end()) {
|
||||||
throw std::invalid_argument("className not found in states");
|
throw std::invalid_argument("class name not found in states");
|
||||||
}
|
}
|
||||||
for (auto feature : features) {
|
for (auto feature : features) {
|
||||||
if (states.find(feature) == states.end()) {
|
if (states.find(feature) == states.end()) {
|
||||||
@@ -173,12 +175,14 @@ namespace bayesnet {
|
|||||||
{
|
{
|
||||||
return model.topological_sort();
|
return model.topological_sort();
|
||||||
}
|
}
|
||||||
void Classifier::dump_cpt() const
|
std::string Classifier::dump_cpt() const
|
||||||
{
|
{
|
||||||
model.dump_cpt();
|
return model.dump_cpt();
|
||||||
}
|
}
|
||||||
void Classifier::setHyperparameters(const nlohmann::json& hyperparameters)
|
void Classifier::setHyperparameters(const nlohmann::json& hyperparameters)
|
||||||
{
|
{
|
||||||
//For classifiers that don't have hyperparameters
|
if (!hyperparameters.empty()) {
|
||||||
|
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
@@ -30,7 +30,7 @@ namespace bayesnet {
|
|||||||
std::vector<std::string> show() const override;
|
std::vector<std::string> show() const override;
|
||||||
std::vector<std::string> topological_order() override;
|
std::vector<std::string> topological_order() override;
|
||||||
std::vector<std::string> getNotes() const override { return notes; }
|
std::vector<std::string> getNotes() const override { return notes; }
|
||||||
void dump_cpt() const override;
|
std::string dump_cpt() const override;
|
||||||
void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
|
void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
|
||||||
protected:
|
protected:
|
||||||
bool fitted;
|
bool fitted;
|
||||||
|
@@ -6,14 +6,18 @@ namespace bayesnet {
|
|||||||
validHyperparameters = { "k", "theta" };
|
validHyperparameters = { "k", "theta" };
|
||||||
|
|
||||||
}
|
}
|
||||||
void KDB::setHyperparameters(const nlohmann::json& hyperparameters)
|
void KDB::setHyperparameters(const nlohmann::json& hyperparameters_)
|
||||||
{
|
{
|
||||||
|
auto hyperparameters = hyperparameters_;
|
||||||
if (hyperparameters.contains("k")) {
|
if (hyperparameters.contains("k")) {
|
||||||
k = hyperparameters["k"];
|
k = hyperparameters["k"];
|
||||||
|
hyperparameters.erase("k");
|
||||||
}
|
}
|
||||||
if (hyperparameters.contains("theta")) {
|
if (hyperparameters.contains("theta")) {
|
||||||
theta = hyperparameters["theta"];
|
theta = hyperparameters["theta"];
|
||||||
|
hyperparameters.erase("theta");
|
||||||
}
|
}
|
||||||
|
Classifier::setHyperparameters(hyperparameters);
|
||||||
}
|
}
|
||||||
void KDB::buildModel(const torch::Tensor& weights)
|
void KDB::buildModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
|
@@ -14,7 +14,7 @@ namespace bayesnet {
|
|||||||
public:
|
public:
|
||||||
explicit KDB(int k, float theta = 0.03);
|
explicit KDB(int k, float theta = 0.03);
|
||||||
virtual ~KDB() = default;
|
virtual ~KDB() = default;
|
||||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
|
||||||
std::vector<std::string> graph(const std::string& name = "KDB") const override;
|
std::vector<std::string> graph(const std::string& name = "KDB") const override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@@ -13,9 +13,7 @@ namespace bayesnet {
|
|||||||
predict_voting = hyperparameters["predict_voting"];
|
predict_voting = hyperparameters["predict_voting"];
|
||||||
hyperparameters.erase("predict_voting");
|
hyperparameters.erase("predict_voting");
|
||||||
}
|
}
|
||||||
if (!hyperparameters.empty()) {
|
Classifier::setHyperparameters(hyperparameters);
|
||||||
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
void AODE::buildModel(const torch::Tensor& weights)
|
void AODE::buildModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
|
@@ -94,9 +94,7 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
hyperparameters.erase("select_features");
|
hyperparameters.erase("select_features");
|
||||||
}
|
}
|
||||||
if (!hyperparameters.empty()) {
|
Classifier::setHyperparameters(hyperparameters);
|
||||||
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
std::tuple<torch::Tensor&, double, bool> update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights)
|
std::tuple<torch::Tensor&, double, bool> update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
|
@@ -20,7 +20,7 @@ namespace bayesnet {
|
|||||||
BoostAODE(bool predict_voting = false);
|
BoostAODE(bool predict_voting = false);
|
||||||
virtual ~BoostAODE() = default;
|
virtual ~BoostAODE() = default;
|
||||||
std::vector<std::string> graph(const std::string& title = "BoostAODE") const override;
|
std::vector<std::string> graph(const std::string& title = "BoostAODE") const override;
|
||||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
|
||||||
protected:
|
protected:
|
||||||
void buildModel(const torch::Tensor& weights) override;
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
void trainModel(const torch::Tensor& weights) override;
|
void trainModel(const torch::Tensor& weights) override;
|
||||||
|
@@ -25,8 +25,9 @@ namespace bayesnet {
|
|||||||
{
|
{
|
||||||
return std::vector<std::string>();
|
return std::vector<std::string>();
|
||||||
}
|
}
|
||||||
void dump_cpt() const override
|
std::string dump_cpt() const override
|
||||||
{
|
{
|
||||||
|
return "";
|
||||||
}
|
}
|
||||||
protected:
|
protected:
|
||||||
torch::Tensor predict_average_voting(torch::Tensor& X);
|
torch::Tensor predict_average_voting(torch::Tensor& X);
|
||||||
|
@@ -8,12 +8,13 @@ if(ENABLE_TESTING)
|
|||||||
${CMAKE_BINARY_DIR}/configured_files/include
|
${CMAKE_BINARY_DIR}/configured_files/include
|
||||||
)
|
)
|
||||||
file(GLOB_RECURSE BayesNet_SOURCES "${BayesNet_SOURCE_DIR}/bayesnet/*.cc")
|
file(GLOB_RECURSE BayesNet_SOURCES "${BayesNet_SOURCE_DIR}/bayesnet/*.cc")
|
||||||
add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestUtils.cc ${BayesNet_SOURCES})
|
add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesClassifier.cc TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestUtils.cc ${BayesNet_SOURCES})
|
||||||
target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain )
|
target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain )
|
||||||
add_test(NAME BayesNetworkTest COMMAND TestBayesNet)
|
add_test(NAME BayesNetworkTest COMMAND TestBayesNet)
|
||||||
add_test(NAME Network COMMAND TestBayesNet "[Network]")
|
add_test(NAME Network COMMAND TestBayesNet "[Network]")
|
||||||
add_test(NAME Node COMMAND TestBayesNet "[Node]")
|
add_test(NAME Node COMMAND TestBayesNet "[Node]")
|
||||||
add_test(NAME Metrics COMMAND TestBayesNet "[Metrics]")
|
add_test(NAME Metrics COMMAND TestBayesNet "[Metrics]")
|
||||||
add_test(NAME FeatureSelection COMMAND TestBayesNet "[FeatureSelection]")
|
add_test(NAME FeatureSelection COMMAND TestBayesNet "[FeatureSelection]")
|
||||||
|
add_test(NAME Classifier COMMAND TestBayesNet "[Classifier]")
|
||||||
add_test(NAME Models COMMAND TestBayesNet "[Models]")
|
add_test(NAME Models COMMAND TestBayesNet "[Models]")
|
||||||
endif(ENABLE_TESTING)
|
endif(ENABLE_TESTING)
|
||||||
|
86
tests/TestBayesClassifier.cc
Normal file
86
tests/TestBayesClassifier.cc
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
#include <catch2/catch_test_macros.hpp>
|
||||||
|
#include <catch2/matchers/catch_matchers.hpp>
|
||||||
|
#include <string>
|
||||||
|
#include "TestUtils.h"
|
||||||
|
#include "bayesnet/classifiers/TAN.h"
|
||||||
|
|
||||||
|
|
||||||
|
TEST_CASE("Test Cannot build dataset with wrong data vector", "[Classifier]")
|
||||||
|
{
|
||||||
|
auto model = bayesnet::TAN();
|
||||||
|
auto raw = RawDatasets("iris", true);
|
||||||
|
raw.yv.pop_back();
|
||||||
|
REQUIRE_THROWS_AS(model.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv), std::runtime_error);
|
||||||
|
REQUIRE_THROWS_WITH(model.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]");
|
||||||
|
}
|
||||||
|
TEST_CASE("Test Cannot build dataset with wrong data tensor", "[Classifier]")
|
||||||
|
{
|
||||||
|
auto model = bayesnet::TAN();
|
||||||
|
auto raw = RawDatasets("iris", true);
|
||||||
|
auto yshort = torch::zeros({ 149 }, torch::kInt32);
|
||||||
|
REQUIRE_THROWS_AS(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), std::runtime_error);
|
||||||
|
REQUIRE_THROWS_WITH(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]");
|
||||||
|
}
|
||||||
|
TEST_CASE("Invalid data type", "[Classifier]")
|
||||||
|
{
|
||||||
|
auto model = bayesnet::TAN();
|
||||||
|
auto raw = RawDatasets("iris", false);
|
||||||
|
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), "dataset (X, y) must be of type Integer");
|
||||||
|
}
|
||||||
|
TEST_CASE("Invalid number of features", "[Classifier]")
|
||||||
|
{
|
||||||
|
auto model = bayesnet::TAN();
|
||||||
|
auto raw = RawDatasets("iris", true);
|
||||||
|
auto Xt = torch::cat({ raw.Xt, torch::zeros({ 1, 150 }, torch::kInt32) }, 0);
|
||||||
|
REQUIRE_THROWS_AS(model.fit(Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(model.fit(Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), "Classifier: X 5 and features 4 must have the same number of features");
|
||||||
|
}
|
||||||
|
TEST_CASE("Invalid class name", "[Classifier]")
|
||||||
|
{
|
||||||
|
auto model = bayesnet::TAN();
|
||||||
|
auto raw = RawDatasets("iris", true);
|
||||||
|
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, "duck", raw.statest), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, "duck", raw.statest), "class name not found in states");
|
||||||
|
}
|
||||||
|
TEST_CASE("Invalid feature name", "[Classifier]")
|
||||||
|
{
|
||||||
|
auto model = bayesnet::TAN();
|
||||||
|
auto raw = RawDatasets("iris", true);
|
||||||
|
auto statest = raw.statest;
|
||||||
|
statest.erase("petallength");
|
||||||
|
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, statest), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, statest), "feature [petallength] not found in states");
|
||||||
|
}
|
||||||
|
TEST_CASE("Topological order", "[Classifier]")
|
||||||
|
{
|
||||||
|
auto model = bayesnet::TAN();
|
||||||
|
auto raw = RawDatasets("iris", true);
|
||||||
|
model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
||||||
|
auto order = model.topological_order();
|
||||||
|
REQUIRE(order.size() == 4);
|
||||||
|
REQUIRE(order[0] == "petallength");
|
||||||
|
REQUIRE(order[1] == "sepallength");
|
||||||
|
REQUIRE(order[2] == "sepalwidth");
|
||||||
|
REQUIRE(order[3] == "petalwidth");
|
||||||
|
}
|
||||||
|
TEST_CASE("Not fitted model", "[Classifier]")
|
||||||
|
{
|
||||||
|
auto model = bayesnet::TAN();
|
||||||
|
auto raw = RawDatasets("iris", true);
|
||||||
|
auto message = "Classifier has not been fitted";
|
||||||
|
// tensors
|
||||||
|
REQUIRE_THROWS_AS(model.predict(raw.Xt), std::logic_error);
|
||||||
|
REQUIRE_THROWS_WITH(model.predict(raw.Xt), message);
|
||||||
|
REQUIRE_THROWS_AS(model.predict_proba(raw.Xt), std::logic_error);
|
||||||
|
REQUIRE_THROWS_WITH(model.predict_proba(raw.Xt), message);
|
||||||
|
REQUIRE_THROWS_AS(model.score(raw.Xt, raw.yt), std::logic_error);
|
||||||
|
REQUIRE_THROWS_WITH(model.score(raw.Xt, raw.yt), message);
|
||||||
|
// vectors
|
||||||
|
REQUIRE_THROWS_AS(model.predict(raw.Xv), std::logic_error);
|
||||||
|
REQUIRE_THROWS_WITH(model.predict(raw.Xv), message);
|
||||||
|
REQUIRE_THROWS_AS(model.predict_proba(raw.Xv), std::logic_error);
|
||||||
|
REQUIRE_THROWS_WITH(model.predict_proba(raw.Xv), message);
|
||||||
|
REQUIRE_THROWS_AS(model.score(raw.Xv, raw.yv), std::logic_error);
|
||||||
|
REQUIRE_THROWS_WITH(model.score(raw.Xv, raw.yv), message);
|
||||||
|
}
|
@@ -246,7 +246,7 @@ TEST_CASE("BoostAODE voting-proba", "[Models]")
|
|||||||
REQUIRE(score_voting == Catch::Approx(0.98).epsilon(raw.epsilon));
|
REQUIRE(score_voting == Catch::Approx(0.98).epsilon(raw.epsilon));
|
||||||
REQUIRE(pred_voting[83][2] == Catch::Approx(0.552091).epsilon(raw.epsilon));
|
REQUIRE(pred_voting[83][2] == Catch::Approx(0.552091).epsilon(raw.epsilon));
|
||||||
REQUIRE(pred_proba[83][2] == Catch::Approx(0.546017).epsilon(raw.epsilon));
|
REQUIRE(pred_proba[83][2] == Catch::Approx(0.546017).epsilon(raw.epsilon));
|
||||||
clf.dump_cpt();
|
REQUIRE(clf.dump_cpt() == "");
|
||||||
REQUIRE(clf.topological_order() == std::vector<std::string>());
|
REQUIRE(clf.topological_order() == std::vector<std::string>());
|
||||||
}
|
}
|
||||||
TEST_CASE("AODE voting-proba", "[Models]")
|
TEST_CASE("AODE voting-proba", "[Models]")
|
||||||
|
Reference in New Issue
Block a user