diff --git a/src/BayesNet/Mst.cc b/src/BayesNet/Mst.cc index b86812b..5e94dc3 100644 --- a/src/BayesNet/Mst.cc +++ b/src/BayesNet/Mst.cc @@ -7,7 +7,7 @@ namespace bayesnet { using namespace std; - Graph::Graph(int V) + Graph::Graph(int V) : V(V) { parent = vector(V); for (int i = 0; i < V; i++) @@ -34,10 +34,10 @@ namespace bayesnet { } void Graph::kruskal_algorithm() { - int i, uSt, vEd; // sort the edges ordered on decreasing weight - sort(G.begin(), G.end(), [](auto& left, auto& right) {return left.first > right.first;}); - for (i = 0; i < G.size(); i++) { + sort(G.begin(), G.end(), [](const auto& left, const auto& right) {return left.first > right.first;}); + for (int i = 0; i < G.size(); i++) { + int uSt, vEd; uSt = find_set(G[i].second.first); vEd = find_set(G[i].second.second); if (uSt != vEd) { diff --git a/src/BayesNet/Node.cc b/src/BayesNet/Node.cc index d33fecf..095cff7 100644 --- a/src/BayesNet/Node.cc +++ b/src/BayesNet/Node.cc @@ -88,18 +88,15 @@ namespace bayesnet { { // Get dimensions of the CPT dimensions.push_back(numStates); - for (auto father : getParents()) { - dimensions.push_back(father->getNumStates()); - } + transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); }); + // Create a tensor of zeros with the dimensions of the CPT cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing; // Fill table with counts for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) { torch::List> coordinates; coordinates.push_back(torch::tensor(dataset[name][n_sample])); - for (auto father : getParents()) { - coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample])); - } + transform(parents.begin(), parents.end(), back_inserter(coordinates), [&dataset, &n_sample](const auto& parent) { return torch::tensor(dataset[parent->getName()][n_sample]); }); // Increment the count of the corresponding coordinate cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + 1); } @@ -111,19 +108,15 @@ namespace bayesnet { torch::List> coordinates; // following predetermined order of indices in the cpTable (see Node.h) coordinates.push_back(torch::tensor(evidence[name])); - for (auto parent : getParents()) { - coordinates.push_back(torch::tensor(evidence[parent->getName()])); - } + transform(parents.begin(), parents.end(), back_inserter(coordinates), [&evidence](const auto& parent) { return torch::tensor(evidence[parent->getName()]); }); return cpTable.index({ coordinates }).item(); } - vector Node::graph(string className) + vector Node::graph(const string& className) { auto output = vector(); auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : ""; output.push_back(name + " [shape=circle" + suffix + "] \n"); - for (auto& child : children) { - output.push_back(name + " -> " + child->getName()); - } + transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); }); return output; } } \ No newline at end of file diff --git a/src/BayesNet/Node.h b/src/BayesNet/Node.h index 5c5932a..3a5bbe6 100644 --- a/src/BayesNet/Node.h +++ b/src/BayesNet/Node.h @@ -16,7 +16,7 @@ namespace bayesnet { vector dimensions; // dimensions of the cpTable public: vector> combinations(const vector&); - Node(const std::string&, int); + Node(const string&, int); void clear(); void addParent(Node*); void addChild(Node*); @@ -30,7 +30,7 @@ namespace bayesnet { int getNumStates() const; void setNumStates(int); unsigned minFill(); - vector graph(string clasName); // Returns a vector of strings representing the graph in graphviz format + vector graph(const string& clasName); // Returns a vector of strings representing the graph in graphviz format float getFactorValue(map&); }; } diff --git a/src/Platform/Datasets.cc b/src/Platform/Datasets.cc index 1fef8d0..11b83ac 100644 --- a/src/Platform/Datasets.cc +++ b/src/Platform/Datasets.cc @@ -21,9 +21,7 @@ namespace platform { vector Datasets::getNames() { vector result; - for (auto& d : datasets) { - result.push_back(d.first); - } + transform(datasets.begin(), datasets.end(), back_inserter(result), [](const auto& d) { return d.first; }); return result; } vector Datasets::getFeatures(string name) @@ -79,7 +77,7 @@ namespace platform { } return datasets[name]->getTensors(); } - bool Datasets::isDataset(string name) + bool Datasets::isDataset(const string& name) { return datasets.find(name) != datasets.end(); } @@ -193,9 +191,8 @@ namespace platform { yv = arff.getY(); // Get className & Features className = arff.getClassName(); - for (auto feature : arff.getAttributes()) { - features.push_back(feature.first); - } + auto attributes = arff.getAttributes(); + transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& attribute) { return attribute.first; }); } void Dataset::load() { diff --git a/src/Platform/Datasets.h b/src/Platform/Datasets.h index 7b68c71..4ccd1f0 100644 --- a/src/Platform/Datasets.h +++ b/src/Platform/Datasets.h @@ -49,7 +49,7 @@ namespace platform { bool discretize; void load(); // Loads the list of datasets public: - Datasets(const string& path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); }; + explicit Datasets(const string& path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); }; vector getNames(); vector getFeatures(string name); int getNSamples(string name); @@ -58,7 +58,7 @@ namespace platform { pair>&, vector&> getVectors(string name); pair>&, vector&> getVectorsDiscretized(string name); pair getTensors(string name); - bool isDataset(string name); + bool isDataset(const string& name); }; };