#include #include #include #include "BayesMetrics.h" using namespace std; TEST_CASE("Metrics Test", "[Metrics]") { SECTION("Test Constructor") { torch::Tensor samples = torch::rand({ 10, 5 }); vector features = { "feature1", "feature2", "feature3", "feature4", "feature5" }; string className = "class1"; int classNumStates = 2; bayesnet::Metrics obj(samples, features, className, classNumStates); REQUIRE(obj.getScoresKBest().size() == 0); } SECTION("Test SelectKBestWeighted") { torch::Tensor samples = torch::rand({ 10, 5 }); vector features = { "feature1", "feature2", "feature3", "feature4", "feature5" }; string className = "class1"; int classNumStates = 2; bayesnet::Metrics obj(samples, features, className, classNumStates); torch::Tensor weights = torch::ones({ 5 }); vector kBest = obj.SelectKBestWeighted(weights, true, 3); REQUIRE(kBest.size() == 3); } SECTION("Test mutualInformation") { torch::Tensor samples = torch::rand({ 10, 5 }); vector features = { "feature1", "feature2", "feature3", "feature4", "feature5" }; string className = "class1"; int classNumStates = 2; bayesnet::Metrics obj(samples, features, className, classNumStates); torch::Tensor firstFeature = samples.select(1, 0); torch::Tensor secondFeature = samples.select(1, 1); torch::Tensor weights = torch::ones({ 10 }); double mi = obj.mutualInformation(firstFeature, secondFeature, weights); REQUIRE(mi >= 0); } }