Fix some lint warnings
This commit is contained in:
parent
7222119dfb
commit
9a0449c12d
@ -148,9 +148,9 @@ int main(int argc, char** argv)
|
||||
// Get className & Features
|
||||
auto className = handler.getClassName();
|
||||
vector<string> features;
|
||||
for (auto feature : handler.getAttributes()) {
|
||||
features.push_back(feature.first);
|
||||
}
|
||||
auto attributes = handler.getAttributes();
|
||||
transform(attributes.begin(), attributes.end(), back_inserter(features),
|
||||
[](const pair<string, string>& item) { return item.first; });
|
||||
// Discretize Dataset
|
||||
auto [Xd, maxes] = discretize(X, y, features);
|
||||
maxes[className] = *max_element(y.begin(), y.end()) + 1;
|
||||
@ -159,12 +159,7 @@ int main(int argc, char** argv)
|
||||
states[feature] = vector<int>(maxes[feature]);
|
||||
}
|
||||
states[className] = vector<int>(maxes[className]);
|
||||
auto classifiers = map<string, bayesnet::BaseClassifier*>({
|
||||
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
||||
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
||||
}
|
||||
);
|
||||
bayesnet::BaseClassifier* clf = classifiers[model_name];
|
||||
auto clf = platform::Models::instance()->create(model_name);
|
||||
clf->fit(Xd, y, features, className, states);
|
||||
auto score = clf->score(Xd, y);
|
||||
auto lines = clf->show();
|
||||
|
@ -12,8 +12,8 @@ namespace bayesnet {
|
||||
: features(features)
|
||||
, className(className)
|
||||
, classNumStates(classNumStates)
|
||||
, samples(torch::zeros({ static_cast<int>(vsamples[0].size()), static_cast<int>(vsamples.size() + 1) }, torch::kInt32))
|
||||
{
|
||||
samples = torch::zeros({ static_cast<int>(vsamples[0].size()), static_cast<int>(vsamples.size() + 1) }, torch::kInt32);
|
||||
for (int i = 0; i < vsamples.size(); ++i) {
|
||||
samples.index_put_({ "...", i }, torch::tensor(vsamples[i], torch::kInt32));
|
||||
}
|
||||
@ -123,7 +123,6 @@ namespace bayesnet {
|
||||
*/
|
||||
vector<pair<int, int>> Metrics::maximumSpanningTree(vector<string> features, Tensor& weights, int root)
|
||||
{
|
||||
auto result = vector<pair<int, int>>();
|
||||
auto mst = MST(features, weights, root);
|
||||
return mst.maximumSpanningTree();
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ namespace bayesnet {
|
||||
Tensor samples;
|
||||
vector<string> features;
|
||||
string className;
|
||||
int classNumStates;
|
||||
int classNumStates = 0;
|
||||
public:
|
||||
Metrics() = default;
|
||||
Metrics(Tensor&, vector<string>&, string&, int);
|
||||
|
@ -13,7 +13,7 @@ namespace bayesnet {
|
||||
protected:
|
||||
void train() override;
|
||||
public:
|
||||
KDB(int k, float theta = 0.03);
|
||||
explicit KDB(int k, float theta = 0.03);
|
||||
virtual ~KDB() {};
|
||||
vector<string> graph(string name = "KDB") override;
|
||||
};
|
||||
|
@ -10,7 +10,7 @@ namespace bayesnet {
|
||||
private:
|
||||
Tensor weights;
|
||||
vector<string> features;
|
||||
int root;
|
||||
int root = 0;
|
||||
public:
|
||||
MST() = default;
|
||||
MST(vector<string>& features, Tensor& weights, int root);
|
||||
@ -23,7 +23,7 @@ namespace bayesnet {
|
||||
vector <pair<float, pair<int, int>>> T; // vector for mst
|
||||
vector<int> parent;
|
||||
public:
|
||||
Graph(int V);
|
||||
explicit Graph(int V);
|
||||
void addEdge(int u, int v, float wt);
|
||||
int find_set(int i);
|
||||
void union_set(int u, int v);
|
||||
|
@ -27,9 +27,9 @@ namespace bayesnet {
|
||||
void completeFit();
|
||||
public:
|
||||
Network();
|
||||
Network(float, int);
|
||||
Network(float);
|
||||
Network(Network&);
|
||||
explicit Network(float, int);
|
||||
explicit Network(float);
|
||||
explicit Network(Network&);
|
||||
torch::Tensor& getSamples();
|
||||
float getmaxThreads();
|
||||
void addNode(string, int);
|
||||
|
@ -9,7 +9,7 @@ namespace bayesnet {
|
||||
protected:
|
||||
void train() override;
|
||||
public:
|
||||
SPODE(int root);
|
||||
explicit SPODE(int root);
|
||||
virtual ~SPODE() {};
|
||||
vector<string> graph(string name = "SPODE") override;
|
||||
};
|
||||
|
@ -2,25 +2,6 @@
|
||||
namespace platform {
|
||||
using namespace std;
|
||||
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
|
||||
// shared_ptr<bayesnet::BaseClassifier> Models::createInstance(const string& name)
|
||||
// {
|
||||
// bayesnet::BaseClassifier* instance = nullptr;
|
||||
// if (name == "AODE") {
|
||||
// instance = new bayesnet::AODE();
|
||||
// } else if (name == "KDB") {
|
||||
// instance = new bayesnet::KDB(2);
|
||||
// } else if (name == "SPODE") {
|
||||
// instance = new bayesnet::SPODE(2);
|
||||
// } else if (name == "TAN") {
|
||||
// instance = new bayesnet::TAN();
|
||||
// } else {
|
||||
// throw runtime_error("Model " + name + " not found");
|
||||
// }
|
||||
// if (instance != nullptr)
|
||||
// return shared_ptr<bayesnet::BaseClassifier>(instance);
|
||||
// else
|
||||
// return nullptr;
|
||||
// }
|
||||
Models* Models::factory = nullptr;;
|
||||
Models* Models::instance()
|
||||
{
|
||||
|
@ -76,10 +76,6 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||
}
|
||||
return program;
|
||||
}
|
||||
void registerModels()
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
@ -92,7 +88,7 @@ int main(int argc, char** argv)
|
||||
auto stratified = program.get<bool>("stratified");
|
||||
auto n_folds = program.get<int>("folds");
|
||||
auto seeds = program.get<vector<int>>("seeds");
|
||||
vector<string> filesToProcess;
|
||||
vector<string> filesToTest;
|
||||
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
||||
auto title = program.get<string>("title");
|
||||
if (file_name != "") {
|
||||
@ -103,9 +99,9 @@ int main(int argc, char** argv)
|
||||
if (title == "") {
|
||||
title = "Test " + file_name + " " + model_name + " " + to_string(n_folds) + " folds";
|
||||
}
|
||||
filesToProcess.push_back(file_name);
|
||||
filesToTest.push_back(file_name);
|
||||
} else {
|
||||
filesToProcess = platform::Datasets(path, true, platform::ARFF).getNames();
|
||||
filesToTest = platform::Datasets(path, true, platform::ARFF).getNames();
|
||||
saveResults = true;
|
||||
}
|
||||
|
||||
@ -121,7 +117,7 @@ int main(int argc, char** argv)
|
||||
}
|
||||
platform::Timer timer;
|
||||
timer.start();
|
||||
experiment.go(filesToProcess, path);
|
||||
experiment.go(filesToTest, path);
|
||||
experiment.setDuration(timer.getDuration());
|
||||
if (saveResults)
|
||||
experiment.save(PATH_RESULTS);
|
||||
|
Loading…
Reference in New Issue
Block a user