Fix some lint warnings

This commit is contained in:
Ricardo Montañana Gómez 2023-07-29 19:38:42 +02:00
parent 7222119dfb
commit 9a0449c12d
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
9 changed files with 17 additions and 46 deletions

View File

@ -148,9 +148,9 @@ int main(int argc, char** argv)
// Get className & Features
auto className = handler.getClassName();
vector<string> features;
for (auto feature : handler.getAttributes()) {
features.push_back(feature.first);
}
auto attributes = handler.getAttributes();
transform(attributes.begin(), attributes.end(), back_inserter(features),
[](const pair<string, string>& item) { return item.first; });
// Discretize Dataset
auto [Xd, maxes] = discretize(X, y, features);
maxes[className] = *max_element(y.begin(), y.end()) + 1;
@ -159,12 +159,7 @@ int main(int argc, char** argv)
states[feature] = vector<int>(maxes[feature]);
}
states[className] = vector<int>(maxes[className]);
auto classifiers = map<string, bayesnet::BaseClassifier*>({
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
}
);
bayesnet::BaseClassifier* clf = classifiers[model_name];
auto clf = platform::Models::instance()->create(model_name);
clf->fit(Xd, y, features, className, states);
auto score = clf->score(Xd, y);
auto lines = clf->show();

View File

@ -12,8 +12,8 @@ namespace bayesnet {
: features(features)
, className(className)
, classNumStates(classNumStates)
, samples(torch::zeros({ static_cast<int>(vsamples[0].size()), static_cast<int>(vsamples.size() + 1) }, torch::kInt32))
{
samples = torch::zeros({ static_cast<int>(vsamples[0].size()), static_cast<int>(vsamples.size() + 1) }, torch::kInt32);
for (int i = 0; i < vsamples.size(); ++i) {
samples.index_put_({ "...", i }, torch::tensor(vsamples[i], torch::kInt32));
}
@ -123,7 +123,6 @@ namespace bayesnet {
*/
vector<pair<int, int>> Metrics::maximumSpanningTree(vector<string> features, Tensor& weights, int root)
{
auto result = vector<pair<int, int>>();
auto mst = MST(features, weights, root);
return mst.maximumSpanningTree();
}

View File

@ -11,7 +11,7 @@ namespace bayesnet {
Tensor samples;
vector<string> features;
string className;
int classNumStates;
int classNumStates = 0;
public:
Metrics() = default;
Metrics(Tensor&, vector<string>&, string&, int);

View File

@ -13,7 +13,7 @@ namespace bayesnet {
protected:
void train() override;
public:
KDB(int k, float theta = 0.03);
explicit KDB(int k, float theta = 0.03);
virtual ~KDB() {};
vector<string> graph(string name = "KDB") override;
};

View File

@ -10,7 +10,7 @@ namespace bayesnet {
private:
Tensor weights;
vector<string> features;
int root;
int root = 0;
public:
MST() = default;
MST(vector<string>& features, Tensor& weights, int root);
@ -23,7 +23,7 @@ namespace bayesnet {
vector <pair<float, pair<int, int>>> T; // vector for mst
vector<int> parent;
public:
Graph(int V);
explicit Graph(int V);
void addEdge(int u, int v, float wt);
int find_set(int i);
void union_set(int u, int v);

View File

@ -27,9 +27,9 @@ namespace bayesnet {
void completeFit();
public:
Network();
Network(float, int);
Network(float);
Network(Network&);
explicit Network(float, int);
explicit Network(float);
explicit Network(Network&);
torch::Tensor& getSamples();
float getmaxThreads();
void addNode(string, int);

View File

@ -9,7 +9,7 @@ namespace bayesnet {
protected:
void train() override;
public:
SPODE(int root);
explicit SPODE(int root);
virtual ~SPODE() {};
vector<string> graph(string name = "SPODE") override;
};

View File

@ -2,25 +2,6 @@
namespace platform {
using namespace std;
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
// shared_ptr<bayesnet::BaseClassifier> Models::createInstance(const string& name)
// {
// bayesnet::BaseClassifier* instance = nullptr;
// if (name == "AODE") {
// instance = new bayesnet::AODE();
// } else if (name == "KDB") {
// instance = new bayesnet::KDB(2);
// } else if (name == "SPODE") {
// instance = new bayesnet::SPODE(2);
// } else if (name == "TAN") {
// instance = new bayesnet::TAN();
// } else {
// throw runtime_error("Model " + name + " not found");
// }
// if (instance != nullptr)
// return shared_ptr<bayesnet::BaseClassifier>(instance);
// else
// return nullptr;
// }
Models* Models::factory = nullptr;;
Models* Models::instance()
{

View File

@ -76,10 +76,6 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
}
return program;
}
void registerModels()
{
}
int main(int argc, char** argv)
{
@ -92,7 +88,7 @@ int main(int argc, char** argv)
auto stratified = program.get<bool>("stratified");
auto n_folds = program.get<int>("folds");
auto seeds = program.get<vector<int>>("seeds");
vector<string> filesToProcess;
vector<string> filesToTest;
auto datasets = platform::Datasets(path, true, platform::ARFF);
auto title = program.get<string>("title");
if (file_name != "") {
@ -103,9 +99,9 @@ int main(int argc, char** argv)
if (title == "") {
title = "Test " + file_name + " " + model_name + " " + to_string(n_folds) + " folds";
}
filesToProcess.push_back(file_name);
filesToTest.push_back(file_name);
} else {
filesToProcess = platform::Datasets(path, true, platform::ARFF).getNames();
filesToTest = platform::Datasets(path, true, platform::ARFF).getNames();
saveResults = true;
}
@ -121,7 +117,7 @@ int main(int argc, char** argv)
}
platform::Timer timer;
timer.start();
experiment.go(filesToProcess, path);
experiment.go(filesToTest, path);
experiment.setDuration(timer.getDuration());
if (saveResults)
experiment.save(PATH_RESULTS);