Begin Test Folding
This commit is contained in:
@@ -22,19 +22,13 @@ TEST_CASE("Metrics Test", "[BayesNet]")
|
||||
{"diabetes", 0.0345470614}
|
||||
};
|
||||
map<string, vector<pair<int, int>>> resultsMST = {
|
||||
{"glass", {{0,6}, {0,5}, {0,3}, {5,1}, {5,8}, {6,2}, {6,7}, {7,4}}},
|
||||
{"glass", {{0,6}, {0,5}, {0,3}, {6,2}, {6,7}, {5,1}, {5,8}, {5,4}}},
|
||||
{"iris", {{0,1},{0,2},{1,3}}},
|
||||
{"ecoli", {{0,1}, {0,2}, {1,5}, {1,3}, {5,6}, {5,4}}},
|
||||
{"diabetes", {{0,7}, {0,2}, {0,6}, {2,3}, {3,4}, {3,5}, {4,1}}}
|
||||
};
|
||||
auto [XDisc, yDisc, featuresDisc, classNameDisc, statesDisc] = loadDataset(file_name, true, true);
|
||||
int classNumStates = statesDisc.at(classNameDisc).size();
|
||||
auto yresized = torch::transpose(yDisc.view({ yDisc.size(0), 1 }), 0, 1);
|
||||
torch::Tensor dataset = torch::cat({ XDisc, yresized }, 0);
|
||||
int nSamples = dataset.size(1);
|
||||
double epsilon = 1e-5;
|
||||
torch::Tensor weights = torch::full({ nSamples }, 1.0 / nSamples, torch::kDouble);
|
||||
bayesnet::Metrics metrics(dataset, featuresDisc, classNameDisc, classNumStates);
|
||||
auto raw = RawDatasets(file_name, true);
|
||||
bayesnet::Metrics metrics(raw.dataset, raw.featurest, raw.classNamet, raw.classNumStates);
|
||||
|
||||
SECTION("Test Constructor")
|
||||
{
|
||||
@@ -43,21 +37,21 @@ TEST_CASE("Metrics Test", "[BayesNet]")
|
||||
|
||||
SECTION("Test SelectKBestWeighted")
|
||||
{
|
||||
vector<int> kBest = metrics.SelectKBestWeighted(weights, true, resultsKBest.at(file_name).first);
|
||||
vector<int> kBest = metrics.SelectKBestWeighted(raw.weights, true, resultsKBest.at(file_name).first);
|
||||
REQUIRE(kBest.size() == resultsKBest.at(file_name).first);
|
||||
REQUIRE(kBest == resultsKBest.at(file_name).second);
|
||||
}
|
||||
|
||||
SECTION("Test Mutual Information")
|
||||
{
|
||||
auto result = metrics.mutualInformation(dataset.index({ 1, "..." }), dataset.index({ 2, "..." }), weights);
|
||||
REQUIRE(result == Catch::Approx(resultsMI.at(file_name)).epsilon(epsilon));
|
||||
auto result = metrics.mutualInformation(raw.dataset.index({ 1, "..." }), raw.dataset.index({ 2, "..." }), raw.weights);
|
||||
REQUIRE(result == Catch::Approx(resultsMI.at(file_name)).epsilon(raw.epsilon));
|
||||
}
|
||||
|
||||
SECTION("Test Maximum Spanning Tree")
|
||||
{
|
||||
auto weights_matrix = metrics.conditionalEdge(weights);
|
||||
auto result = metrics.maximumSpanningTree(featuresDisc, weights_matrix, 0);
|
||||
auto weights_matrix = metrics.conditionalEdge(raw.weights);
|
||||
auto result = metrics.maximumSpanningTree(raw.featurest, weights_matrix, 0);
|
||||
REQUIRE(result == resultsMST.at(file_name));
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user