// BayesNet/tests/TestBayesMetrics.cc
// (46 lines, 1.9 KiB, C++)

#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>
#include <catch2/generators/catch_generators.hpp>
#include "BayesMetrics.h"
// 2023-10-04 23:14:16 +00:00
#include "TestUtils.h"
using namespace std;
/**
 * Tests for bayesnet::Metrics on four discretized datasets.
 *
 * GENERATE makes Catch2 run the whole test case (and each SECTION) once per
 * dataset name. Each run loads the dataset through loadDataset() (TestUtils.h)
 * and builds `dataset`: the feature tensor with the class labels appended as
 * one extra row. It then checks that:
 *   - a freshly constructed Metrics reports no k-best scores,
 *   - SelectKBestWeighted with uniform weights returns the expected feature
 *     indices (in the expected order) for the current dataset,
 *   - mutualInformation between rows 1 and 2 of `dataset` with uniform
 *     weights matches the expected value.
 */
TEST_CASE("Metrics Test", "[Metrics]")
{
    const string file_name = GENERATE("glass", "iris", "ecoli", "diabetes");
    // Expected SelectKBestWeighted output per dataset:
    // { k, feature indices in the order SelectKBestWeighted returns them }.
    const map<string, pair<int, vector<int>>> resultsKBest = {
        {"glass", {7, { 3, 2, 0, 1, 6, 7, 5 }}},
        {"iris", {3, { 1, 0, 2 }}},
        {"ecoli", {6, { 2, 3, 1, 0, 4, 5 }}},
        {"diabetes", {2, { 2, 0 }}}
    };
    auto [XDisc, yDisc, featuresDisc, classNameDisc, statesDisc] = loadDataset(file_name, true, true);
    const int classNumStates = static_cast<int>(statesDisc.at(classNameDisc).size());
    // Reshape the labels to a single (1, n) row so they can be stacked under the
    // feature rows with cat along dim 0. This is the same shape the previous
    // view({n, 1}) + transpose produced. Columns of `dataset` are samples.
    torch::Tensor yRow = yDisc.view({ 1, yDisc.size(0) });
    torch::Tensor dataset = torch::cat({ XDisc, yRow }, 0);
    const int nSamples = static_cast<int>(dataset.size(1));
    SECTION("Test Constructor")
    {
        bayesnet::Metrics metrics(XDisc, featuresDisc, classNameDisc, classNumStates);
        REQUIRE(metrics.getScoresKBest().size() == 0);
    }
    SECTION("Test SelectKBestWeighted")
    {
        bayesnet::Metrics metrics(XDisc, featuresDisc, classNameDisc, classNumStates);
        torch::Tensor weights = torch::full({ nSamples }, 1.0 / nSamples, torch::kDouble);
        const auto& [k, expectedFeatures] = resultsKBest.at(file_name);
        vector<int> kBest = metrics.SelectKBestWeighted(weights, true, k);
        // Compare size_t with size_t so Catch2's decomposed comparison does not
        // mix signed and unsigned operands.
        REQUIRE(kBest.size() == static_cast<size_t>(k));
        REQUIRE(kBest == expectedFeatures);
    }
    SECTION("Test mutualInformation")
    {
        bayesnet::Metrics metrics(XDisc, featuresDisc, classNameDisc, classNumStates);
        torch::Tensor weights = torch::full({ nSamples }, 1.0 / nSamples, torch::kDouble);
        auto result = metrics.mutualInformation(dataset.index({ 1, "..." }), dataset.index({ 2, "..." }), weights);
        // Floating-point result: exact == is unreliable. The expected literal
        // has two decimal places, so accept any value that rounds to 0.87.
        // NOTE(review): this single expected value is asserted for all four
        // generated datasets. Mutual information between rows 1 and 2 depends
        // on the data, so this probably needs per-dataset expectations, like
        // resultsKBest. TODO confirm the intended values.
        REQUIRE(result == Catch::Approx(0.87).margin(0.005));
    }
}