Add tests for Classifier class
This commit is contained in:
parent
9014649a0d
commit
50543e7929
@ -30,7 +30,7 @@ namespace bayesnet {
|
||||
virtual std::string getVersion() = 0;
|
||||
std::vector<std::string> virtual topological_order() = 0;
|
||||
std::vector<std::string> virtual getNotes() const = 0;
|
||||
void virtual dump_cpt()const = 0;
|
||||
std::string virtual dump_cpt()const = 0;
|
||||
virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0;
|
||||
std::vector<std::string>& getValidHyperparameters() { return validHyperparameters; }
|
||||
protected:
|
||||
|
@ -75,11 +75,11 @@ namespace bayesnet {
|
||||
if (torch::is_floating_point(dataset)) {
|
||||
throw std::invalid_argument("dataset (X, y) must be of type Integer");
|
||||
}
|
||||
if (n != features.size()) {
|
||||
throw std::invalid_argument("Classifier: X " + std::to_string(n) + " and features " + std::to_string(features.size()) + " must have the same number of features");
|
||||
if (dataset.size(0) - 1 != features.size()) {
|
||||
throw std::invalid_argument("Classifier: X " + std::to_string(dataset.size(0) - 1) + " and features " + std::to_string(features.size()) + " must have the same number of features");
|
||||
}
|
||||
if (states.find(className) == states.end()) {
|
||||
throw std::invalid_argument("className not found in states");
|
||||
throw std::invalid_argument("class name not found in states");
|
||||
}
|
||||
for (auto feature : features) {
|
||||
if (states.find(feature) == states.end()) {
|
||||
@ -175,9 +175,9 @@ namespace bayesnet {
|
||||
{
|
||||
return model.topological_sort();
|
||||
}
|
||||
void Classifier::dump_cpt() const
|
||||
std::string Classifier::dump_cpt() const
|
||||
{
|
||||
model.dump_cpt();
|
||||
return model.dump_cpt();
|
||||
}
|
||||
void Classifier::setHyperparameters(const nlohmann::json& hyperparameters)
|
||||
{
|
||||
|
@ -30,7 +30,7 @@ namespace bayesnet {
|
||||
std::vector<std::string> show() const override;
|
||||
std::vector<std::string> topological_order() override;
|
||||
std::vector<std::string> getNotes() const override { return notes; }
|
||||
void dump_cpt() const override;
|
||||
std::string dump_cpt() const override;
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
|
||||
protected:
|
||||
bool fitted;
|
||||
|
@ -25,8 +25,9 @@ namespace bayesnet {
|
||||
{
|
||||
return std::vector<std::string>();
|
||||
}
|
||||
void dump_cpt() const override
|
||||
std::string dump_cpt() const override
|
||||
{
|
||||
return "";
|
||||
}
|
||||
protected:
|
||||
torch::Tensor predict_average_voting(torch::Tensor& X);
|
||||
|
@ -20,4 +20,67 @@ TEST_CASE("Test Cannot build dataset with wrong data tensor", "[Classifier]")
|
||||
auto yshort = torch::zeros({ 149 }, torch::kInt32);
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), std::runtime_error);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]");
|
||||
}
|
||||
TEST_CASE("Invalid data type", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", false);
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), "dataset (X, y) must be of type Integer");
|
||||
}
|
||||
TEST_CASE("Invalid number of features", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto Xt = torch::cat({ raw.Xt, torch::zeros({ 1, 150 }, torch::kInt32) }, 0);
|
||||
REQUIRE_THROWS_AS(model.fit(Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.fit(Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), "Classifier: X 5 and features 4 must have the same number of features");
|
||||
}
|
||||
TEST_CASE("Invalid class name", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", true);
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, "duck", raw.statest), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, "duck", raw.statest), "class name not found in states");
|
||||
}
|
||||
TEST_CASE("Invalid feature name", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto statest = raw.statest;
|
||||
statest.erase("petallength");
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, statest), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, statest), "feature [petallength] not found in states");
|
||||
}
|
||||
TEST_CASE("Topological order", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", true);
|
||||
model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
||||
auto order = model.topological_order();
|
||||
REQUIRE(order.size() == 4);
|
||||
REQUIRE(order[0] == "petallength");
|
||||
REQUIRE(order[1] == "sepallength");
|
||||
REQUIRE(order[2] == "sepalwidth");
|
||||
REQUIRE(order[3] == "petalwidth");
|
||||
}
|
||||
TEST_CASE("Not fitted model", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto message = "Classifier has not been fitted";
|
||||
// tensors
|
||||
REQUIRE_THROWS_AS(model.predict(raw.Xt), std::logic_error);
|
||||
REQUIRE_THROWS_WITH(model.predict(raw.Xt), message);
|
||||
REQUIRE_THROWS_AS(model.predict_proba(raw.Xt), std::logic_error);
|
||||
REQUIRE_THROWS_WITH(model.predict_proba(raw.Xt), message);
|
||||
REQUIRE_THROWS_AS(model.score(raw.Xt, raw.yt), std::logic_error);
|
||||
REQUIRE_THROWS_WITH(model.score(raw.Xt, raw.yt), message);
|
||||
// vectors
|
||||
REQUIRE_THROWS_AS(model.predict(raw.Xv), std::logic_error);
|
||||
REQUIRE_THROWS_WITH(model.predict(raw.Xv), message);
|
||||
REQUIRE_THROWS_AS(model.predict_proba(raw.Xv), std::logic_error);
|
||||
REQUIRE_THROWS_WITH(model.predict_proba(raw.Xv), message);
|
||||
REQUIRE_THROWS_AS(model.score(raw.Xv, raw.yv), std::logic_error);
|
||||
REQUIRE_THROWS_WITH(model.score(raw.Xv, raw.yv), message);
|
||||
}
|
@ -246,7 +246,7 @@ TEST_CASE("BoostAODE voting-proba", "[Models]")
|
||||
REQUIRE(score_voting == Catch::Approx(0.98).epsilon(raw.epsilon));
|
||||
REQUIRE(pred_voting[83][2] == Catch::Approx(0.552091).epsilon(raw.epsilon));
|
||||
REQUIRE(pred_proba[83][2] == Catch::Approx(0.546017).epsilon(raw.epsilon));
|
||||
clf.dump_cpt();
|
||||
REQUIRE(clf.dump_cpt() == "");
|
||||
REQUIRE(clf.topological_order() == std::vector<std::string>());
|
||||
}
|
||||
TEST_CASE("AODE voting-proba", "[Models]")
|
||||
|
Loading…
Reference in New Issue
Block a user