Refactor Tests and add BayesMetrics test
This commit is contained in:
parent
5e938d5cca
commit
3448fb1299
@ -1,87 +0,0 @@
|
||||
#define CATCH_CONFIG_MAIN // This tells Catch to provide a main() - only do this in one cpp file
|
||||
#include <catch2/catch_test_macros.hpp>
|
||||
#include <catch2/catch_approx.hpp>
|
||||
#include <catch2/generators/catch_generators.hpp>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "KDB.h"
|
||||
#include "TAN.h"
|
||||
#include "SPODE.h"
|
||||
#include "AODE.h"
|
||||
|
||||
TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]")
{
    // Reference accuracies keyed by {dataset, classifier}. Every model below is
    // fitted and scored on the same discretized data returned by loadFile().
    map <pair<string, string>, float> scores = {
        {{"diabetes", "AODE"}, 0.811198}, {{"diabetes", "KDB"}, 0.852865}, {{"diabetes", "SPODE"}, 0.802083}, {{"diabetes", "TAN"}, 0.821615},
        {{"ecoli", "AODE"}, 0.889881}, {{"ecoli", "KDB"}, 0.889881}, {{"ecoli", "SPODE"}, 0.880952}, {{"ecoli", "TAN"}, 0.892857},
        {{"glass", "AODE"}, 0.78972}, {{"glass", "KDB"}, 0.827103}, {{"glass", "SPODE"}, 0.775701}, {{"glass", "TAN"}, 0.827103},
        {{"iris", "AODE"}, 0.973333}, {{"iris", "KDB"}, 0.973333}, {{"iris", "SPODE"}, 0.973333}, {{"iris", "TAN"}, 0.973333}
    };

    // One pass of the test case per dataset name.
    string file_name = GENERATE("glass", "iris", "ecoli", "diabetes");
    auto [X, labels, featureNames, classLabel, stateMap] = loadFile(file_name);

    SECTION("Test TAN classifier (" + file_name + ")")
    {
        auto model = bayesnet::TAN();
        model.fit(X, labels, featureNames, classLabel, stateMap);
        const auto accuracy = model.score(X, labels);
        const float expected = scores[{file_name, "TAN"}];
        REQUIRE(accuracy == Catch::Approx(expected).epsilon(1e-6));
    }
    SECTION("Test KDB classifier (" + file_name + ")")
    {
        auto model = bayesnet::KDB(2);
        model.fit(X, labels, featureNames, classLabel, stateMap);
        const auto accuracy = model.score(X, labels);
        const float expected = scores[{file_name, "KDB"}];
        REQUIRE(accuracy == Catch::Approx(expected).epsilon(1e-6));
    }
    SECTION("Test SPODE classifier (" + file_name + ")")
    {
        auto model = bayesnet::SPODE(1);
        model.fit(X, labels, featureNames, classLabel, stateMap);
        const auto accuracy = model.score(X, labels);
        const float expected = scores[{file_name, "SPODE"}];
        REQUIRE(accuracy == Catch::Approx(expected).epsilon(1e-6));
    }
    SECTION("Test AODE classifier (" + file_name + ")")
    {
        auto model = bayesnet::AODE();
        model.fit(X, labels, featureNames, classLabel, stateMap);
        const auto accuracy = model.score(X, labels);
        const float expected = scores[{file_name, "AODE"}];
        REQUIRE(accuracy == Catch::Approx(expected).epsilon(1e-6));
    }
    // To regenerate the table: store each measured accuracy back into
    // scores[{file_name, "<classifier>"}] inside its section, then dump with:
    // for (auto scores : scores) {
    //     cout << "{{\"" << scores.first.first << "\", \"" << scores.first.second << "\"}, " << scores.second << "}, ";
    // }
}
|
||||
TEST_CASE("Models features")
{
    // Expected result of graph("Test") for a TAN model fitted on iris.
    const vector<string> expectedGraph = {
        "digraph BayesNet {\nlabel=<BayesNet Test>\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n",
        "class [shape=circle, fontcolor=red, fillcolor=lightblue, style=filled ] \n",
        "class -> sepallength", "class -> sepalwidth", "class -> petallength", "class -> petalwidth", "petallength [shape=circle] \n",
        "petallength -> sepallength", "petalwidth [shape=circle] \n", "sepallength [shape=circle] \n",
        "sepallength -> sepalwidth", "sepalwidth [shape=circle] \n", "sepalwidth -> petalwidth", "}\n"
    };
    // Expected result of show(): one "<node> -> <children>, " entry per node.
    const vector<string> expectedShow = {
        "class -> sepallength, sepalwidth, petallength, petalwidth, ",
        "petallength -> sepallength, ",
        "petalwidth -> ",
        "sepallength -> sepalwidth, ",
        "sepalwidth -> petalwidth, "
    };

    auto model = bayesnet::TAN();
    auto [X, labels, featureNames, classLabel, stateMap] = loadFile("iris");
    model.fit(X, labels, featureNames, classLabel, stateMap);

    REQUIRE(model.getNumberOfNodes() == 5);
    REQUIRE(model.getNumberOfEdges() == 7);
    REQUIRE(model.show() == expectedShow);
    REQUIRE(model.graph("Test") == expectedGraph);
}
|
||||
TEST_CASE("Get num features & num edges")
{
    // Fit a KDB model (constructed with argument 2) on iris and check the
    // size of the resulting network.
    auto [X, labels, featureNames, classLabel, stateMap] = loadFile("iris");
    bayesnet::KDB model(2);
    model.fit(X, labels, featureNames, classLabel, stateMap);

    REQUIRE(model.getNumberOfNodes() == 5);
    REQUIRE(model.getNumberOfEdges() == 8);
}
|
@ -5,7 +5,7 @@ if(ENABLE_TESTING)
|
||||
include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
|
||||
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
|
||||
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
|
||||
set(TEST_SOURCES BayesModels.cc BayesNetwork.cc ${BayesNet_SOURCES})
|
||||
set(TEST_SOURCES TestBayesModels.cc TestBayesNetwork.cc TestBayesMetrics.cc TestUtils.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc ${BayesNet_SOURCES})
|
||||
add_executable(${TEST_MAIN} ${TEST_SOURCES})
|
||||
target_link_libraries(${TEST_MAIN} PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain)
|
||||
add_test(NAME ${TEST_MAIN} COMMAND ${TEST_MAIN})
|
||||
|
55
tests/TestBayesMetrics.cc
Normal file
55
tests/TestBayesMetrics.cc
Normal file
@ -0,0 +1,55 @@
|
||||
#include <catch2/catch_test_macros.hpp>
|
||||
#include <catch2/catch_approx.hpp>
|
||||
#include <catch2/generators/catch_generators.hpp>
|
||||
#include "BayesMetrics.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
TEST_CASE("Metrics Test", "[Metrics]")
{
    // Shared fixture. Catch2 re-executes this preamble for every SECTION, so
    // each section still works on its own freshly drawn random sample tensor
    // and its own Metrics instance.
    torch::Tensor samples = torch::rand({ 10, 5 });
    vector<string> features = { "feature1", "feature2", "feature3", "feature4", "feature5" };
    string className = "class1";
    int classNumStates = 2;
    bayesnet::Metrics metrics(samples, features, className, classNumStates);

    SECTION("Test Constructor")
    {
        // A newly built object reports no k-best scores.
        REQUIRE(metrics.getScoresKBest().size() == 0);
    }

    SECTION("Test SelectKBestWeighted")
    {
        torch::Tensor uniformWeights = torch::ones({ 5 });
        vector<int> selected = metrics.SelectKBestWeighted(uniformWeights, true, 3);
        // Asking for 3 features must yield exactly 3 indices.
        REQUIRE(selected.size() == 3);
    }

    SECTION("Test mutualInformation")
    {
        // Columns 0 and 1 of the sample tensor.
        torch::Tensor columnA = samples.select(1, 0);
        torch::Tensor columnB = samples.select(1, 1);
        torch::Tensor uniformWeights = torch::ones({ 10 });
        double information = metrics.mutualInformation(columnA, columnB, uniformWeights);
        REQUIRE(information >= 0);
    }
}
|
143
tests/TestBayesModels.cc
Normal file
143
tests/TestBayesModels.cc
Normal file
@ -0,0 +1,143 @@
|
||||
#define CATCH_CONFIG_MAIN // This tells Catch to provide a main() - only do this in one cpp file
|
||||
#include <catch2/catch_test_macros.hpp>
|
||||
#include <catch2/catch_approx.hpp>
|
||||
#include <catch2/generators/catch_generators.hpp>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "KDB.h"
|
||||
#include "TAN.h"
|
||||
#include "SPODE.h"
|
||||
#include "AODE.h"
|
||||
#include "BoostAODE.h"
|
||||
#include "TANLd.h"
|
||||
#include "KDBLd.h"
|
||||
#include "SPODELd.h"
|
||||
#include "AODELd.h"
|
||||
#include "TestUtils.h"
|
||||
|
||||
TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]")
{
    // Reference accuracies keyed by {dataset, classifier}. Each classifier is
    // fitted and then scored on the very same data.
    // NOTE(review): every "*Ld" entry carries the same value as its discrete
    // counterpart (and BoostAODE the same as TAN) - confirm these were measured
    // rather than copied, and regenerate them if the assertions fail.
    map <pair<string, string>, float> scores = {
        // Diabetes
        {{"diabetes", "AODE"}, 0.811198}, {{"diabetes", "KDB"}, 0.852865}, {{"diabetes", "SPODE"}, 0.802083}, {{"diabetes", "TAN"}, 0.821615},
        {{"diabetes", "AODELd"}, 0.811198}, {{"diabetes", "KDBLd"}, 0.852865}, {{"diabetes", "SPODELd"}, 0.802083}, {{"diabetes", "TANLd"}, 0.821615}, {{"diabetes", "BoostAODE"}, 0.821615},
        // Ecoli
        {{"ecoli", "AODE"}, 0.889881}, {{"ecoli", "KDB"}, 0.889881}, {{"ecoli", "SPODE"}, 0.880952}, {{"ecoli", "TAN"}, 0.892857},
        {{"ecoli", "AODELd"}, 0.889881}, {{"ecoli", "KDBLd"}, 0.889881}, {{"ecoli", "SPODELd"}, 0.880952}, {{"ecoli", "TANLd"}, 0.892857}, {{"ecoli", "BoostAODE"}, 0.892857},
        // Glass
        {{"glass", "AODE"}, 0.78972}, {{"glass", "KDB"}, 0.827103}, {{"glass", "SPODE"}, 0.775701}, {{"glass", "TAN"}, 0.827103},
        {{"glass", "AODELd"}, 0.78972}, {{"glass", "KDBLd"}, 0.827103}, {{"glass", "SPODELd"}, 0.775701}, {{"glass", "TANLd"}, 0.827103}, {{"glass", "BoostAODE"}, 0.827103},
        // Iris
        {{"iris", "AODE"}, 0.973333}, {{"iris", "KDB"}, 0.973333}, {{"iris", "SPODE"}, 0.973333}, {{"iris", "TAN"}, 0.973333},
        {{"iris", "AODELd"}, 0.973333}, {{"iris", "KDBLd"}, 0.973333}, {{"iris", "SPODELd"}, 0.973333}, {{"iris", "TANLd"}, 0.973333}, {{"iris", "BoostAODE"}, 0.973333}
    };

    // One pass of the test case per dataset name.
    string file_name = GENERATE("glass", "iris", "ecoli", "diabetes");
    // Undiscretized tensors for the "*Ld" classifiers ...
    auto [XCont, yCont, featuresCont, classNameCont, statesCont] = loadDataset(file_name, true, false);
    // ... and pre-discretized vectors for the remaining classifiers.
    auto [XDisc, yDisc, featuresDisc, className, statesDisc] = loadFile(file_name);

    SECTION("Test TAN classifier (" + file_name + ")")
    {
        auto clf = bayesnet::TAN();
        clf.fit(XDisc, yDisc, featuresDisc, className, statesDisc);
        auto score = clf.score(XDisc, yDisc);
        REQUIRE(score == Catch::Approx(scores[{file_name, "TAN"}]).epsilon(1e-6));
    }
    SECTION("Test TANLd classifier (" + file_name + ")")
    {
        auto clf = bayesnet::TANLd();
        clf.fit(XCont, yCont, featuresCont, classNameCont, statesCont);
        auto score = clf.score(XCont, yCont);
        REQUIRE(score == Catch::Approx(scores[{file_name, "TANLd"}]).epsilon(1e-6));
    }
    SECTION("Test KDB classifier (" + file_name + ")")
    {
        auto clf = bayesnet::KDB(2);
        clf.fit(XDisc, yDisc, featuresDisc, className, statesDisc);
        auto score = clf.score(XDisc, yDisc);
        REQUIRE(score == Catch::Approx(scores[{file_name, "KDB"}]).epsilon(1e-6));
    }
    SECTION("Test KDBLd classifier (" + file_name + ")")
    {
        auto clf = bayesnet::KDBLd(2);
        clf.fit(XCont, yCont, featuresCont, classNameCont, statesCont);
        auto score = clf.score(XCont, yCont);
        REQUIRE(score == Catch::Approx(scores[{file_name, "KDBLd"}]).epsilon(1e-6));
    }
    SECTION("Test SPODE classifier (" + file_name + ")")
    {
        auto clf = bayesnet::SPODE(1);
        clf.fit(XDisc, yDisc, featuresDisc, className, statesDisc);
        auto score = clf.score(XDisc, yDisc);
        REQUIRE(score == Catch::Approx(scores[{file_name, "SPODE"}]).epsilon(1e-6));
    }
    SECTION("Test SPODELd classifier (" + file_name + ")")
    {
        auto clf = bayesnet::SPODELd(1);
        clf.fit(XCont, yCont, featuresCont, classNameCont, statesCont);
        auto score = clf.score(XCont, yCont);
        REQUIRE(score == Catch::Approx(scores[{file_name, "SPODELd"}]).epsilon(1e-6));
    }
    SECTION("Test AODE classifier (" + file_name + ")")
    {
        auto clf = bayesnet::AODE();
        clf.fit(XDisc, yDisc, featuresDisc, className, statesDisc);
        auto score = clf.score(XDisc, yDisc);
        REQUIRE(score == Catch::Approx(scores[{file_name, "AODE"}]).epsilon(1e-6));
    }
    SECTION("Test AODELd classifier (" + file_name + ")")
    {
        // Must build the AODELd model itself; a plain AODE here would leave
        // AODELd untested.
        auto clf = bayesnet::AODELd();
        clf.fit(XCont, yCont, featuresCont, classNameCont, statesCont);
        auto score = clf.score(XCont, yCont);
        REQUIRE(score == Catch::Approx(scores[{file_name, "AODELd"}]).epsilon(1e-6));
    }
    SECTION("Test BoostAODE classifier (" + file_name + ")")
    {
        auto clf = bayesnet::BoostAODE();
        clf.fit(XDisc, yDisc, featuresDisc, className, statesDisc);
        auto score = clf.score(XDisc, yDisc);
        REQUIRE(score == Catch::Approx(scores[{file_name, "BoostAODE"}]).epsilon(1e-6));
    }
    // To regenerate the table: inside each section store the measured value
    // with `scores[{file_name, "<classifier>"}] = score;`, then dump with:
    // for (auto scores : scores) {
    //     cout << "{{\"" << scores.first.first << "\", \"" << scores.first.second << "\"}, " << scores.second << "}, ";
    // }
}
|
||||
TEST_CASE("Models featuresDisc")
{
    // Expected result of graph("Test") for a TAN model fitted on iris.
    const vector<string> expectedGraph = {
        "digraph BayesNet {\nlabel=<BayesNet Test>\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n",
        "class [shape=circle, fontcolor=red, fillcolor=lightblue, style=filled ] \n",
        "class -> sepallength", "class -> sepalwidth", "class -> petallength", "class -> petalwidth", "petallength [shape=circle] \n",
        "petallength -> sepallength", "petalwidth [shape=circle] \n", "sepallength [shape=circle] \n",
        "sepallength -> sepalwidth", "sepalwidth [shape=circle] \n", "sepalwidth -> petalwidth", "}\n"
    };
    // Expected result of show(): one "<node> -> <children>, " entry per node.
    const vector<string> expectedShow = {
        "class -> sepallength, sepalwidth, petallength, petalwidth, ",
        "petallength -> sepallength, ",
        "petalwidth -> ",
        "sepallength -> sepalwidth, ",
        "sepalwidth -> petalwidth, "
    };

    auto model = bayesnet::TAN();
    auto [X, labels, featureNames, classLabel, stateMap] = loadFile("iris");
    model.fit(X, labels, featureNames, classLabel, stateMap);

    REQUIRE(model.getNumberOfNodes() == 5);
    REQUIRE(model.getNumberOfEdges() == 7);
    REQUIRE(model.show() == expectedShow);
    REQUIRE(model.graph("Test") == expectedGraph);
}
|
||||
TEST_CASE("Get num featuresDisc & num edges")
{
    // Fit a KDB model (constructed with argument 2) on the discretized iris
    // data and check the size of the resulting network.
    auto [X, labels, featureNames, classLabel, stateMap] = loadFile("iris");
    bayesnet::KDB model(2);
    model.fit(X, labels, featureNames, classLabel, stateMap);

    REQUIRE(model.getNumberOfNodes() == 5);
    REQUIRE(model.getNumberOfEdges() == 8);
}
|
@ -2,6 +2,7 @@
|
||||
#include <catch2/catch_approx.hpp>
|
||||
#include <catch2/generators/catch_generators.hpp>
|
||||
#include <string>
|
||||
#include "TestUtils.h"
|
||||
#include "KDB.h"
|
||||
|
||||
TEST_CASE("Test Bayesian Network")
|
106
tests/TestUtils.cc
Normal file
106
tests/TestUtils.cc
Normal file
@ -0,0 +1,106 @@
|
||||
#include "TestUtils.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace torch;
|
||||
// Central place for the file-system locations used by the test helpers.
struct Paths {
    // Directory holding the .arff datasets (relative path, with trailing slash).
    static string datasets()
    {
        return string("../data/");
    }
};
|
||||
|
||||
// Discretizes every entry of X against the labels y with a CPPFImdlp instance.
// Returns the discretized entries (in the order of X) together with a map
// from features[i] to (largest discretized value of X[i]) + 1.
pair<vector<mdlp::labels_t>, map<string, int>> discretize(vector<mdlp::samples_t>& X, mdlp::labels_t& y, vector<string> features)
{
    vector<mdlp::labels_t> discretized;
    map<string, int> numStates;
    auto discretizer = mdlp::CPPFImdlp();
    for (size_t column = 0; column < X.size(); ++column) {
        // The same discretizer object is re-fitted for each entry.
        discretizer.fit(X[column], y);
        mdlp::labels_t& bins = discretizer.transform(X[column]);
        const auto largest = max_element(bins.begin(), bins.end());
        numStates[features[column]] = *largest + 1;
        // push_back copies the values out of the discretizer's buffer.
        discretized.push_back(bins);
    }
    return { discretized, numStates };
}
|
||||
|
||||
// Discretizes every entry of X against the labels y with a CPPFImdlp instance
// and returns the discretized entries in the same order as X.
vector<mdlp::labels_t> discretizeDataset(vector<mdlp::samples_t>& X, mdlp::labels_t& y)
{
    vector<mdlp::labels_t> discretized;
    auto discretizer = mdlp::CPPFImdlp();
    for (size_t column = 0; column < X.size(); ++column) {
        // The same discretizer object is re-fitted for each entry.
        discretizer.fit(X[column], y);
        mdlp::labels_t& bins = discretizer.transform(X[column]);
        // push_back copies the values out of the discretizer's buffer.
        discretized.push_back(bins);
    }
    return discretized;
}
|
||||
|
||||
// Returns true when `name` can be opened for reading, false otherwise.
bool file_exists(const string& name)
{
    FILE* handle = fopen(name.c_str(), "r");
    if (handle == nullptr) {
        return false;
    }
    // Only a probe: release the handle right away.
    fclose(handle);
    return true;
}
|
||||
|
||||
tuple<Tensor, Tensor, vector<string>, string, map<string, vector<int>>> loadDataset(const string& name, bool class_last, bool discretize_dataset)
|
||||
{
|
||||
auto handler = ArffFiles();
|
||||
handler.load(Paths::datasets() + static_cast<string>(name) + ".arff", class_last);
|
||||
// Get Dataset X, y
|
||||
vector<mdlp::samples_t>& X = handler.getX();
|
||||
mdlp::labels_t& y = handler.getY();
|
||||
// Get className & Features
|
||||
auto className = handler.getClassName();
|
||||
vector<string> features;
|
||||
auto attributes = handler.getAttributes();
|
||||
transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& pair) { return pair.first; });
|
||||
Tensor Xd;
|
||||
auto states = map<string, vector<int>>();
|
||||
if (discretize_dataset) {
|
||||
auto Xr = discretizeDataset(X, y);
|
||||
Xd = torch::zeros({ static_cast<int>(Xr[0].size()), static_cast<int>(Xr.size()) }, torch::kInt32);
|
||||
for (int i = 0; i < features.size(); ++i) {
|
||||
states[features[i]] = vector<int>(*max_element(Xr[i].begin(), Xr[i].end()) + 1);
|
||||
auto item = states.at(features[i]);
|
||||
iota(begin(item), end(item), 0);
|
||||
Xd.index_put_({ "...", i }, torch::tensor(Xr[i], torch::kInt32));
|
||||
}
|
||||
states[className] = vector<int>(*max_element(y.begin(), y.end()) + 1);
|
||||
iota(begin(states.at(className)), end(states.at(className)), 0);
|
||||
} else {
|
||||
Xd = torch::zeros({ static_cast<int>(X[0].size()), static_cast<int>(X.size()) }, torch::kFloat32);
|
||||
for (int i = 0; i < features.size(); ++i) {
|
||||
Xd.index_put_({ "...", i }, torch::tensor(X[i]));
|
||||
}
|
||||
}
|
||||
return { Xd, torch::tensor(y, torch::kInt32), features, className, states };
|
||||
}
|
||||
|
||||
tuple<vector<vector<int>>, vector<int>, vector<string>, string, map<string, vector<int>>> loadFile(const string& name)
|
||||
{
|
||||
auto handler = ArffFiles();
|
||||
handler.load(Paths::datasets() + static_cast<string>(name) + ".arff");
|
||||
// Get Dataset X, y
|
||||
vector<mdlp::samples_t>& X = handler.getX();
|
||||
mdlp::labels_t& y = handler.getY();
|
||||
// Get className & Features
|
||||
auto className = handler.getClassName();
|
||||
vector<string> features;
|
||||
auto attributes = handler.getAttributes();
|
||||
transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& pair) { return pair.first; });
|
||||
// Discretize Dataset
|
||||
vector<mdlp::labels_t> Xd;
|
||||
map<string, int> maxes;
|
||||
tie(Xd, maxes) = discretize(X, y, features);
|
||||
maxes[className] = *max_element(y.begin(), y.end()) + 1;
|
||||
map<string, vector<int>> states;
|
||||
for (auto feature : features) {
|
||||
states[feature] = vector<int>(maxes[feature]);
|
||||
}
|
||||
states[className] = vector<int>(maxes[className]);
|
||||
return { Xd, y, features, className, states };
|
||||
}
|
19
tests/TestUtils.h
Normal file
19
tests/TestUtils.h
Normal file
@ -0,0 +1,19 @@
|
||||
#ifndef TEST_UTILS_H
|
||||
#define TEST_UTILS_H
|
||||
#include <torch/torch.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <tuple>
|
||||
#include "ArffFiles.h"
|
||||
#include "CPPFImdlp.h"
|
||||
using namespace std;
|
||||
|
||||
bool file_exists(const std::string& name);
|
||||
pair<vector<mdlp::labels_t>, map<string, int>> discretize(vector<mdlp::samples_t>& X, mdlp::labels_t& y, vector<string> features);
|
||||
vector<mdlp::labels_t> discretizeDataset(vector<mdlp::samples_t>& X, mdlp::labels_t& y);
|
||||
tuple<vector<vector<int>>, vector<int>, vector<string>, string, map<string, vector<int>>> loadFile(const string& name);
|
||||
tuple<torch::Tensor, torch::Tensor, vector<string>, string, map<string, vector<int>>> loadDataset(const string& name, bool class_last, bool discretize_dataset);
|
||||
#endif //TEST_UTILS_H
|
||||
|
||||
#
|
Loading…
Reference in New Issue
Block a user