Add hyperparameter convergence_best
move test libraries to test folder
This commit is contained in:
@@ -18,47 +18,47 @@ TEST_CASE("Test Cannot build dataset with wrong data vector", "[Classifier]")
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", true);
|
||||
raw.yv.pop_back();
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv), std::runtime_error);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]");
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states), std::runtime_error);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]");
|
||||
}
|
||||
TEST_CASE("Test Cannot build dataset with wrong data tensor", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto yshort = torch::zeros({ 149 }, torch::kInt32);
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), std::runtime_error);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]");
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xt, yshort, raw.features, raw.className, raw.states), std::runtime_error);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xt, yshort, raw.features, raw.className, raw.states), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]");
|
||||
}
|
||||
TEST_CASE("Invalid data type", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", false);
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), "dataset (X, y) must be of type Integer");
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states), "dataset (X, y) must be of type Integer");
|
||||
}
|
||||
TEST_CASE("Invalid number of features", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto Xt = torch::cat({ raw.Xt, torch::zeros({ 1, 150 }, torch::kInt32) }, 0);
|
||||
REQUIRE_THROWS_AS(model.fit(Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.fit(Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest), "Classifier: X 5 and features 4 must have the same number of features");
|
||||
REQUIRE_THROWS_AS(model.fit(Xt, raw.yt, raw.features, raw.className, raw.states), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.fit(Xt, raw.yt, raw.features, raw.className, raw.states), "Classifier: X 5 and features 4 must have the same number of features");
|
||||
}
|
||||
TEST_CASE("Invalid class name", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", true);
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, "duck", raw.statest), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, "duck", raw.statest), "class name not found in states");
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.features, "duck", raw.states), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.features, "duck", raw.states), "class name not found in states");
|
||||
}
|
||||
TEST_CASE("Invalid feature name", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto statest = raw.statest;
|
||||
auto statest = raw.states;
|
||||
statest.erase("petallength");
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, statest), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, statest), "feature [petallength] not found in states");
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.features, raw.className, statest), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.features, raw.className, statest), "feature [petallength] not found in states");
|
||||
}
|
||||
TEST_CASE("Invalid hyperparameter", "[Classifier]")
|
||||
{
|
||||
@@ -71,7 +71,7 @@ TEST_CASE("Topological order", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", true);
|
||||
model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
||||
model.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states);
|
||||
auto order = model.topological_order();
|
||||
REQUIRE(order.size() == 4);
|
||||
REQUIRE(order[0] == "petallength");
|
||||
@@ -83,7 +83,7 @@ TEST_CASE("Dump_cpt", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
auto raw = RawDatasets("iris", true);
|
||||
model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
||||
model.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states);
|
||||
auto cpt = model.dump_cpt();
|
||||
REQUIRE(cpt.size() == 1713);
|
||||
}
|
||||
@@ -111,7 +111,7 @@ TEST_CASE("KDB Graph", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::KDB(2);
|
||||
auto raw = RawDatasets("iris", true);
|
||||
model.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||
model.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states);
|
||||
auto graph = model.graph();
|
||||
REQUIRE(graph.size() == 15);
|
||||
}
|
||||
@@ -119,7 +119,7 @@ TEST_CASE("KDBLd Graph", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::KDBLd(2);
|
||||
auto raw = RawDatasets("iris", false);
|
||||
model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
||||
model.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states);
|
||||
auto graph = model.graph();
|
||||
REQUIRE(graph.size() == 15);
|
||||
}
|
Reference in New Issue
Block a user