Add Models class
This commit is contained in:
@@ -4,16 +4,11 @@
|
||||
#include <thread>
|
||||
#include <map>
|
||||
#include <argparse/argparse.hpp>
|
||||
#include "BaseClassifier.h"
|
||||
#include "ArffFiles.h"
|
||||
#include "Network.h"
|
||||
#include "BayesMetrics.h"
|
||||
#include "CPPFImdlp.h"
|
||||
#include "KDB.h"
|
||||
#include "SPODE.h"
|
||||
#include "AODE.h"
|
||||
#include "TAN.h"
|
||||
#include "Folding.h"
|
||||
#include "Models.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
@@ -91,13 +86,13 @@ int main(int argc, char** argv)
|
||||
.default_value(string{ PATH }
|
||||
);
|
||||
program.add_argument("-m", "--model")
|
||||
.help("Model to use {AODE, KDB, SPODE, TAN}")
|
||||
.help("Model to use " + platform::Models::toString())
|
||||
.action([](const std::string& value) {
|
||||
static const vector<string> choices = { "AODE", "KDB", "SPODE", "TAN" };
|
||||
static const vector<string> choices = platform::Models::getNames();
|
||||
if (find(choices.begin(), choices.end(), value) != choices.end()) {
|
||||
return value;
|
||||
}
|
||||
throw runtime_error("Model must be one of {AODE, KDB, SPODE, TAN}");
|
||||
throw runtime_error("Model must be one of " + platform::Models::toString());
|
||||
}
|
||||
);
|
||||
program.add_argument("--discretize").help("Discretize input dataset").default_value(false).implicit_value(true);
|
||||
@@ -164,12 +159,8 @@ int main(int argc, char** argv)
|
||||
states[feature] = vector<int>(maxes[feature]);
|
||||
}
|
||||
states[className] = vector<int>(maxes[className]);
|
||||
auto classifiers = map<string, bayesnet::BaseClassifier*>({
|
||||
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
||||
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
||||
}
|
||||
);
|
||||
bayesnet::BaseClassifier* clf = classifiers[model_name];
|
||||
|
||||
bayesnet::BaseClassifier* clf = platform::Models::get(model_name);
|
||||
clf->fit(Xd, y, features, className, states);
|
||||
auto score = clf->score(Xd, y);
|
||||
auto lines = clf->show();
|
||||
|
Reference in New Issue
Block a user