Refactor Library renaming Base classes
This commit is contained in:
@@ -2,7 +2,9 @@
|
||||
#include <string>
|
||||
#include <torch/torch.h>
|
||||
#include <thread>
|
||||
#include <map>
|
||||
#include <argparse/argparse.hpp>
|
||||
#include "BaseClassifier.h"
|
||||
#include "ArffFiles.h"
|
||||
#include "Network.h"
|
||||
#include "BayesMetrics.h"
|
||||
@@ -143,38 +145,12 @@ int main(int argc, char** argv)
|
||||
states[className] = vector<int>(
|
||||
maxes[className]);
|
||||
double score;
|
||||
vector<string> lines;
|
||||
vector<string> graph;
|
||||
auto kdb = bayesnet::KDB(2);
|
||||
auto aode = bayesnet::AODE();
|
||||
auto spode = bayesnet::SPODE(2);
|
||||
auto tan = bayesnet::TAN();
|
||||
switch (hash_conv(model_name)) {
|
||||
case "AODE"_sh:
|
||||
aode.fit(Xd, y, features, className, states);
|
||||
lines = aode.show();
|
||||
score = aode.score(Xd, y);
|
||||
graph = aode.graph();
|
||||
break;
|
||||
case "KDB"_sh:
|
||||
kdb.fit(Xd, y, features, className, states);
|
||||
lines = kdb.show();
|
||||
score = kdb.score(Xd, y);
|
||||
graph = kdb.graph();
|
||||
break;
|
||||
case "SPODE"_sh:
|
||||
spode.fit(Xd, y, features, className, states);
|
||||
lines = spode.show();
|
||||
score = spode.score(Xd, y);
|
||||
graph = spode.graph();
|
||||
break;
|
||||
case "TAN"_sh:
|
||||
tan.fit(Xd, y, features, className, states);
|
||||
lines = tan.show();
|
||||
score = tan.score(Xd, y);
|
||||
graph = tan.graph();
|
||||
break;
|
||||
}
|
||||
auto classifiers = map<string, bayesnet::BaseClassifier*>({ { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) }, { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() } });
|
||||
bayesnet::BaseClassifier* clf = classifiers[model_name];
|
||||
clf->fit(Xd, y, features, className, states);
|
||||
score = clf->score(Xd, y);
|
||||
auto lines = clf->show();
|
||||
auto graph = clf->graph();
|
||||
for (auto line : lines) {
|
||||
cout << line << endl;
|
||||
}
|
||||
|
Reference in New Issue
Block a user