diff --git a/src/BayesNet/Network.cc b/src/BayesNet/Network.cc index eb3ffeb..1c8abeb 100644 --- a/src/BayesNet/Network.cc +++ b/src/BayesNet/Network.cc @@ -20,7 +20,7 @@ namespace bayesnet { { return samples; } - void Network::addNode(string name, int numStates) + void Network::addNode(const string& name, int numStates) { if (find(features.begin(), features.end(), name) == features.end()) { features.push_back(name); @@ -69,7 +69,7 @@ namespace bayesnet { recStack.erase(nodeId); // remove node from recursion stack before function ends return false; } - void Network::addEdge(const string parent, const string child) + void Network::addEdge(const string& parent, const string& child) { if (nodes.find(parent) == nodes.end()) { throw invalid_argument("Parent node " + parent + " does not exist"); @@ -105,8 +105,8 @@ namespace bayesnet { for (int i = 0; i < featureNames.size(); ++i) { auto column = torch::flatten(X.index({ "...", i })); auto k = vector(); - for (auto i = 0; i < X.size(0); ++i) { - k.push_back(column[i].item()); + for (auto z = 0; z < X.size(0); ++z) { + k.push_back(column[z].item()); } dataset[featureNames[i]] = k; } @@ -280,7 +280,7 @@ namespace bayesnet { } return result; } - vector Network::graph(string title) + vector Network::graph(const string& title) { auto output = vector(); auto prefix = "digraph BayesNet {\nlabel=>& getNodes(); vector getFeatures(); int getStates(); @@ -48,7 +48,7 @@ namespace bayesnet { vector> predict_proba(const vector>&); double score(const vector>&, const vector&); vector show(); - vector graph(string title); // Returns a vector of strings representing the graph in graphviz format + vector graph(const string& title); // Returns a vector of strings representing the graph in graphviz format inline string version() { return "0.1.0"; } }; } diff --git a/src/Platform/DotEnv.h b/src/Platform/DotEnv.h index af6eda2..a7e3e36 100644 --- a/src/Platform/DotEnv.h +++ b/src/Platform/DotEnv.h @@ -52,9 +52,9 @@ namespace platform { seeds_str = trim(seeds_str); seeds_str = seeds_str.substr(1, seeds_str.size() - 2); auto seeds_str_split = split(seeds_str, ','); - for (auto seed_str : seeds_str_split) { - seeds.push_back(stoi(seed_str)); - } + transform(seeds_str_split.begin(), seeds_str_split.end(), back_inserter(seeds), [](const std::string& str) { + return stoi(str); + }); return seeds; } }; diff --git a/src/Platform/Models.cc b/src/Platform/Models.cc index df1b517..1a66156 100644 --- a/src/Platform/Models.cc +++ b/src/Platform/Models.cc @@ -40,7 +40,7 @@ namespace platform { string Models::toString() { string result = ""; - for (auto& pair : functionRegistry) { + for (const auto& pair : functionRegistry) { result += pair.first + ", "; } return "{" + result.substr(0, result.size() - 2) + "}"; diff --git a/src/Platform/main.cc b/src/Platform/main.cc index d9dfb40..55c0cfe 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -49,22 +49,17 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) }}); auto seed_values = env.getSeeds(); program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values); - bool class_last, discretize_dataset, stratified; - int n_folds; - vector seeds; - string model_name, file_name, path, complete_file_name, title; try { program.parse_args(argc, argv); - file_name = program.get("dataset"); - path = program.get("path"); - model_name = program.get("model"); - discretize_dataset = program.get("discretize"); - stratified = program.get("stratified"); - n_folds = program.get("folds"); - seeds = program.get>("seeds"); - complete_file_name = path + file_name + ".arff"; - class_last = false;//datasets[file_name]; - title = program.get("title"); + auto file_name = program.get("dataset"); + auto path = program.get("path"); + auto model_name = program.get("model"); + auto discretize_dataset = program.get("discretize"); + auto stratified = program.get("stratified"); + auto n_folds = program.get("folds"); + auto seeds = program.get>("seeds"); + auto complete_file_name = path + file_name + ".arff"; + auto title = program.get("title"); if (title == "" && file_name == "") { throw runtime_error("title is mandatory if dataset is not provided"); } diff --git a/src/Platform/platformUtils.cc b/src/Platform/platformUtils.cc index f318831..6fca9d9 100644 --- a/src/Platform/platformUtils.cc +++ b/src/Platform/platformUtils.cc @@ -2,7 +2,7 @@ using namespace torch; -vector split(string text, char delimiter) +vector split(const string& text, char delimiter) { vector result; stringstream ss(text); @@ -39,7 +39,7 @@ vector discretizeDataset(vector& X, mdlp::label return Xd; } -bool file_exists(const std::string& name) +bool file_exists(const string& name) { if (FILE* file = fopen(name.c_str(), "r")) { fclose(file); @@ -49,7 +49,7 @@ bool file_exists(const std::string& name) } } -tuple, string, map>> loadDataset(string path, string name, bool class_last, bool discretize_dataset) +tuple, string, map>> loadDataset(const string& path, const string& name, bool class_last, bool discretize_dataset) { auto handler = ArffFiles(); handler.load(path + static_cast(name) + ".arff", class_last); @@ -59,9 +59,8 @@ tuple, string, map>> loadData // Get className & Features auto className = handler.getClassName(); vector features; - for (auto feature : handler.getAttributes()) { - features.push_back(feature.first); - } + auto attributes = handler.getAttributes(); + transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& pair) { return pair.first; }); Tensor Xd; auto states = map>(); if (discretize_dataset) { @@ -83,7 +82,7 @@ tuple, string, map>> loadData return { Xd, torch::tensor(y, torch::kInt32), features, className, states }; } -tuple>, vector, vector, string, map>> loadFile(string name) +tuple>, vector, vector, string, map>> loadFile(const string& name) { auto handler = ArffFiles(); handler.load(PATH + static_cast(name) + ".arff"); @@ -93,9 +92,8 @@ tuple>, vector, vector, string, map features; - for (auto feature : handler.getAttributes()) { - features.push_back(feature.first); - } + auto attributes = handler.getAttributes(); + transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& pair) { return pair.first; }); // Discretize Dataset vector Xd; map maxes; diff --git a/src/Platform/platformUtils.h b/src/Platform/platformUtils.h index 9515bbf..2b4ca54 100644 --- a/src/Platform/platformUtils.h +++ b/src/Platform/platformUtils.h @@ -11,11 +11,11 @@ using namespace std; const string PATH = "../../data/"; bool file_exists(const std::string& name); -vector split(string text, char delimiter); +vector split(const string& text, char delimiter); pair, map> discretize(vector& X, mdlp::labels_t& y, vector features); vector discretizeDataset(vector& X, mdlp::labels_t& y); -pair>> discretizeTorch(torch::Tensor& X, torch::Tensor& y, vector& features, string className); -tuple>, vector, vector, string, map>> loadFile(string name); -tuple, string, map>> loadDataset(string path, string name, bool class_last, bool discretize_dataset); +pair>> discretizeTorch(torch::Tensor& X, torch::Tensor& y, vector& features, const string& className); +tuple>, vector, vector, string, map>> loadFile(const string& name); +tuple, string, map>> loadDataset(const string& path, const string& name, bool class_last, bool discretize_dataset); map> get_states(vector& features, string className, map& maxes); #endif //PLATFORM_UTILS_H