Refactor Makefile
This commit is contained in:
@@ -9,7 +9,7 @@ using namespace std;
|
||||
TEST_CASE("Metrics Test", "[Metrics]")
|
||||
{
|
||||
string file_name = GENERATE("glass", "iris", "ecoli", "diabetes");
|
||||
map<string, pair<int, vector<int>>> results = {
|
||||
map<string, pair<int, vector<int>>> resultsKBest = {
|
||||
{"glass", {7, { 3, 2, 0, 1, 6, 7, 5 }}},
|
||||
{"iris", {3, { 1, 0, 2 }} },
|
||||
{"ecoli", {6, { 2, 3, 1, 0, 4, 5 }}},
|
||||
@@ -31,26 +31,16 @@ TEST_CASE("Metrics Test", "[Metrics]")
|
||||
{
|
||||
bayesnet::Metrics metrics(XDisc, featuresDisc, classNameDisc, classNumStates);
|
||||
torch::Tensor weights = torch::full({ nSamples }, 1.0 / nSamples, torch::kDouble);
|
||||
vector<int> kBest = metrics.SelectKBestWeighted(weights, true, results.at(file_name).first);
|
||||
REQUIRE(kBest.size() == results.at(file_name).first);
|
||||
REQUIRE(kBest == results.at(file_name).second);
|
||||
vector<int> kBest = metrics.SelectKBestWeighted(weights, true, resultsKBest.at(file_name).first);
|
||||
REQUIRE(kBest.size() == resultsKBest.at(file_name).first);
|
||||
REQUIRE(kBest == resultsKBest.at(file_name).second);
|
||||
}
|
||||
|
||||
SECTION("Test mutualInformation")
|
||||
{
|
||||
// torch::Tensor samples = torch::rand({ 10, 5 });
|
||||
// vector<string> features = { "feature1", "feature2", "feature3", "feature4", "feature5" };
|
||||
// string className = "class1";
|
||||
// int classNumStates = 2;
|
||||
|
||||
// bayesnet::Metrics obj(samples, features, className, classNumStates);
|
||||
|
||||
// torch::Tensor firstFeature = samples.select(1, 0);
|
||||
// torch::Tensor secondFeature = samples.select(1, 1);
|
||||
// torch::Tensor weights = torch::ones({ 10 });
|
||||
|
||||
// double mi = obj.mutualInformation(firstFeature, secondFeature, weights);
|
||||
|
||||
// REQUIRE(mi >= 0);
|
||||
bayesnet::Metrics metrics(XDisc, featuresDisc, classNameDisc, classNumStates);
|
||||
torch::Tensor weights = torch::full({ nSamples }, 1.0 / nSamples, torch::kDouble);
|
||||
auto result = metrics.mutualInformation(dataset.index({ 1, "..." }), dataset.index({ 2, "..." }), weights);
|
||||
REQUIRE(result == 0.87);
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user