Refactor Library renaming Base classes

This commit is contained in:
2023-07-22 23:07:56 +02:00
parent 41cceece20
commit 9981ad1811
16 changed files with 102 additions and 134 deletions

View File

@@ -2,7 +2,9 @@
#include <string>
#include <torch/torch.h>
#include <thread>
#include <map>
#include <argparse/argparse.hpp>
#include "BaseClassifier.h"
#include "ArffFiles.h"
#include "Network.h"
#include "BayesMetrics.h"
@@ -143,38 +145,12 @@ int main(int argc, char** argv)
states[className] = vector<int>(
maxes[className]);
double score;
vector<string> lines;
vector<string> graph;
auto kdb = bayesnet::KDB(2);
auto aode = bayesnet::AODE();
auto spode = bayesnet::SPODE(2);
auto tan = bayesnet::TAN();
switch (hash_conv(model_name)) {
case "AODE"_sh:
aode.fit(Xd, y, features, className, states);
lines = aode.show();
score = aode.score(Xd, y);
graph = aode.graph();
break;
case "KDB"_sh:
kdb.fit(Xd, y, features, className, states);
lines = kdb.show();
score = kdb.score(Xd, y);
graph = kdb.graph();
break;
case "SPODE"_sh:
spode.fit(Xd, y, features, className, states);
lines = spode.show();
score = spode.score(Xd, y);
graph = spode.graph();
break;
case "TAN"_sh:
tan.fit(Xd, y, features, className, states);
lines = tan.show();
score = tan.score(Xd, y);
graph = tan.graph();
break;
}
auto classifiers = map<string, bayesnet::BaseClassifier*>({ { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) }, { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() } });
bayesnet::BaseClassifier* clf = classifiers[model_name];
clf->fit(Xd, y, features, className, states);
score = clf->score(Xd, y);
auto lines = clf->show();
auto graph = clf->graph();
for (auto line : lines) {
cout << line << endl;
}