diff --git a/bayesnet/BaseClassifier.h b/bayesnet/BaseClassifier.h index 69f4b29..c237349 100644 --- a/bayesnet/BaseClassifier.h +++ b/bayesnet/BaseClassifier.h @@ -30,7 +30,7 @@ namespace bayesnet { virtual std::string getVersion() = 0; std::vector virtual topological_order() = 0; std::vector virtual getNotes() const = 0; - void virtual dump_cpt()const = 0; + std::string virtual dump_cpt()const = 0; virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0; std::vector& getValidHyperparameters() { return validHyperparameters; } protected: diff --git a/bayesnet/classifiers/Classifier.cc b/bayesnet/classifiers/Classifier.cc index 19f337a..a8bf6ef 100644 --- a/bayesnet/classifiers/Classifier.cc +++ b/bayesnet/classifiers/Classifier.cc @@ -75,11 +75,11 @@ namespace bayesnet { if (torch::is_floating_point(dataset)) { throw std::invalid_argument("dataset (X, y) must be of type Integer"); } - if (n != features.size()) { - throw std::invalid_argument("Classifier: X " + std::to_string(n) + " and features " + std::to_string(features.size()) + " must have the same number of features"); + if (dataset.size(0) - 1 != features.size()) { + throw std::invalid_argument("Classifier: X " + std::to_string(dataset.size(0) - 1) + " and features " + std::to_string(features.size()) + " must have the same number of features"); } if (states.find(className) == states.end()) { - throw std::invalid_argument("className not found in states"); + throw std::invalid_argument("class name not found in states"); } for (auto feature : features) { if (states.find(feature) == states.end()) { @@ -175,9 +175,9 @@ namespace bayesnet { { return model.topological_sort(); } - void Classifier::dump_cpt() const + std::string Classifier::dump_cpt() const { - model.dump_cpt(); + return model.dump_cpt(); } void Classifier::setHyperparameters(const nlohmann::json& hyperparameters) { diff --git a/bayesnet/classifiers/Classifier.h b/bayesnet/classifiers/Classifier.h index c7685a2..2511c4d 100644 --- a/bayesnet/classifiers/Classifier.h +++ b/bayesnet/classifiers/Classifier.h @@ -30,7 +30,7 @@ namespace bayesnet { std::vector show() const override; std::vector topological_order() override; std::vector getNotes() const override { return notes; } - void dump_cpt() const override; + std::string dump_cpt() const override; void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters protected: bool fitted; diff --git a/bayesnet/ensembles/Ensemble.h b/bayesnet/ensembles/Ensemble.h index cb4220a..bab4d25 100644 --- a/bayesnet/ensembles/Ensemble.h +++ b/bayesnet/ensembles/Ensemble.h @@ -25,8 +25,9 @@ namespace bayesnet { { return std::vector(); } - void dump_cpt() const override + std::string dump_cpt() const override { + return ""; } protected: torch::Tensor predict_average_voting(torch::Tensor& X); diff --git a/tests/TestBayesClassifier.cc b/tests/TestBayesClassifier.cc index d07adbd..1e33e67 100644 --- a/tests/TestBayesClassifier.cc +++ b/tests/TestBayesClassifier.cc @@ -20,4 +20,67 @@ TEST_CASE("Test Cannot build dataset with wrong data tensor", "[Classifier]") auto yshort = torch::zeros({ 149 }, torch::kInt32); REQUIRE_THROWS_AS(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), std::runtime_error); REQUIRE_THROWS_WITH(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]"); +} +TEST_CASE("Invalid data type", "[Classifier]") +{ + auto model = bayesnet::TAN(); + auto raw = RawDatasets("iris", false); + REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), std::invalid_argument); + REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), "dataset (X, y) must be of type Integer"); +} +TEST_CASE("Invalid number of features", "[Classifier]") +{ + auto model = bayesnet::TAN(); + auto raw = RawDatasets("iris", true); + auto Xt = torch::cat({ raw.Xt, torch::zeros({ 1, 150 }, torch::kInt32) }, 0); + REQUIRE_THROWS_AS(model.fit(Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), std::invalid_argument); + REQUIRE_THROWS_WITH(model.fit(Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), "Classifier: X 5 and features 4 must have the same number of features"); +} +TEST_CASE("Invalid class name", "[Classifier]") +{ + auto model = bayesnet::TAN(); + auto raw = RawDatasets("iris", true); + REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, "duck", raw.statest), std::invalid_argument); + REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, "duck", raw.statest), "class name not found in states"); +} +TEST_CASE("Invalid feature name", "[Classifier]") +{ + auto model = bayesnet::TAN(); + auto raw = RawDatasets("iris", true); + auto statest = raw.statest; + statest.erase("petallength"); + REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, statest), std::invalid_argument); + REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, statest), "feature [petallength] not found in states"); +} +TEST_CASE("Topological order", "[Classifier]") +{ + auto model = bayesnet::TAN(); + auto raw = RawDatasets("iris", true); + model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest); + auto order = model.topological_order(); + REQUIRE(order.size() == 4); + REQUIRE(order[0] == "petallength"); + REQUIRE(order[1] == "sepallength"); + REQUIRE(order[2] == "sepalwidth"); + REQUIRE(order[3] == "petalwidth"); +} +TEST_CASE("Not fitted model", "[Classifier]") +{ + auto model = bayesnet::TAN(); + auto raw = RawDatasets("iris", true); + auto message = "Classifier has not been fitted"; + // tensors + REQUIRE_THROWS_AS(model.predict(raw.Xt), std::logic_error); + REQUIRE_THROWS_WITH(model.predict(raw.Xt), message); + REQUIRE_THROWS_AS(model.predict_proba(raw.Xt), std::logic_error); + REQUIRE_THROWS_WITH(model.predict_proba(raw.Xt), message); + REQUIRE_THROWS_AS(model.score(raw.Xt, raw.yt), std::logic_error); + REQUIRE_THROWS_WITH(model.score(raw.Xt, raw.yt), message); + // vectors + REQUIRE_THROWS_AS(model.predict(raw.Xv), std::logic_error); + REQUIRE_THROWS_WITH(model.predict(raw.Xv), message); + REQUIRE_THROWS_AS(model.predict_proba(raw.Xv), std::logic_error); + REQUIRE_THROWS_WITH(model.predict_proba(raw.Xv), message); + REQUIRE_THROWS_AS(model.score(raw.Xv, raw.yv), std::logic_error); + REQUIRE_THROWS_WITH(model.score(raw.Xv, raw.yv), message); } \ No newline at end of file diff --git a/tests/TestBayesModels.cc b/tests/TestBayesModels.cc index bfa7169..51d5091 100644 --- a/tests/TestBayesModels.cc +++ b/tests/TestBayesModels.cc @@ -246,7 +246,7 @@ TEST_CASE("BoostAODE voting-proba", "[Models]") REQUIRE(score_voting == Catch::Approx(0.98).epsilon(raw.epsilon)); REQUIRE(pred_voting[83][2] == Catch::Approx(0.552091).epsilon(raw.epsilon)); REQUIRE(pred_proba[83][2] == Catch::Approx(0.546017).epsilon(raw.epsilon)); - clf.dump_cpt(); + REQUIRE(clf.dump_cpt() == ""); REQUIRE(clf.topological_order() == std::vector()); } TEST_CASE("AODE voting-proba", "[Models]")