Add tests for Classifier class

This commit is contained in:
2024-04-08 01:25:14 +02:00
parent 9014649a0d
commit 50543e7929
6 changed files with 73 additions and 9 deletions

View File

@@ -20,4 +20,67 @@ TEST_CASE("Test Cannot build dataset with wrong data tensor", "[Classifier]")
auto yshort = torch::zeros({ 149 }, torch::kInt32);
REQUIRE_THROWS_AS(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), std::runtime_error);
REQUIRE_THROWS_WITH(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]");
}
TEST_CASE("Invalid data type", "[Classifier]")
{
auto model = bayesnet::TAN();
auto raw = RawDatasets("iris", false);
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), std::invalid_argument);
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), "dataset (X, y) must be of type Integer");
}
TEST_CASE("Invalid number of features", "[Classifier]")
{
auto model = bayesnet::TAN();
auto raw = RawDatasets("iris", true);
auto Xt = torch::cat({ raw.Xt, torch::zeros({ 1, 150 }, torch::kInt32) }, 0);
REQUIRE_THROWS_AS(model.fit(Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), std::invalid_argument);
REQUIRE_THROWS_WITH(model.fit(Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), "Classifier: X 5 and features 4 must have the same number of features");
}
TEST_CASE("Invalid class name", "[Classifier]")
{
auto model = bayesnet::TAN();
auto raw = RawDatasets("iris", true);
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, "duck", raw.statest), std::invalid_argument);
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, "duck", raw.statest), "class name not found in states");
}
TEST_CASE("Invalid feature name", "[Classifier]")
{
auto model = bayesnet::TAN();
auto raw = RawDatasets("iris", true);
auto statest = raw.statest;
statest.erase("petallength");
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, statest), std::invalid_argument);
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, statest), "feature [petallength] not found in states");
}
TEST_CASE("Topological order", "[Classifier]")
{
auto model = bayesnet::TAN();
auto raw = RawDatasets("iris", true);
model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
auto order = model.topological_order();
REQUIRE(order.size() == 4);
REQUIRE(order[0] == "petallength");
REQUIRE(order[1] == "sepallength");
REQUIRE(order[2] == "sepalwidth");
REQUIRE(order[3] == "petalwidth");
}
TEST_CASE("Not fitted model", "[Classifier]")
{
auto model = bayesnet::TAN();
auto raw = RawDatasets("iris", true);
auto message = "Classifier has not been fitted";
// tensors
REQUIRE_THROWS_AS(model.predict(raw.Xt), std::logic_error);
REQUIRE_THROWS_WITH(model.predict(raw.Xt), message);
REQUIRE_THROWS_AS(model.predict_proba(raw.Xt), std::logic_error);
REQUIRE_THROWS_WITH(model.predict_proba(raw.Xt), message);
REQUIRE_THROWS_AS(model.score(raw.Xt, raw.yt), std::logic_error);
REQUIRE_THROWS_WITH(model.score(raw.Xt, raw.yt), message);
// vectors
REQUIRE_THROWS_AS(model.predict(raw.Xv), std::logic_error);
REQUIRE_THROWS_WITH(model.predict(raw.Xv), message);
REQUIRE_THROWS_AS(model.predict_proba(raw.Xv), std::logic_error);
REQUIRE_THROWS_WITH(model.predict_proba(raw.Xv), message);
REQUIRE_THROWS_AS(model.score(raw.Xv, raw.yv), std::logic_error);
REQUIRE_THROWS_WITH(model.score(raw.Xv, raw.yv), message);
}

View File

@@ -246,7 +246,7 @@ TEST_CASE("BoostAODE voting-proba", "[Models]")
REQUIRE(score_voting == Catch::Approx(0.98).epsilon(raw.epsilon));
REQUIRE(pred_voting[83][2] == Catch::Approx(0.552091).epsilon(raw.epsilon));
REQUIRE(pred_proba[83][2] == Catch::Approx(0.546017).epsilon(raw.epsilon));
clf.dump_cpt();
REQUIRE(clf.dump_cpt() == "");
REQUIRE(clf.topological_order() == std::vector<std::string>());
}
TEST_CASE("AODE voting-proba", "[Models]")