Fix some more lint warnings

Ricardo Montañana Gómez 2023-07-30 00:04:18 +02:00
parent 8b2ed26ab7
commit b882569169
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
5 changed files with 18 additions and 28 deletions

View File

@@ -7,7 +7,7 @@
 namespace bayesnet {
     using namespace std;
-    Graph::Graph(int V)
+    Graph::Graph(int V) : V(V)
     {
         parent = vector<int>(V);
         for (int i = 0; i < V; i++)
@@ -34,10 +34,10 @@ namespace bayesnet {
     }
     void Graph::kruskal_algorithm()
     {
-        int i, uSt, vEd;
         // sort the edges ordered on decreasing weight
-        sort(G.begin(), G.end(), [](auto& left, auto& right) {return left.first > right.first;});
-        for (i = 0; i < G.size(); i++) {
+        sort(G.begin(), G.end(), [](const auto& left, const auto& right) {return left.first > right.first;});
+        for (int i = 0; i < G.size(); i++) {
+            int uSt, vEd;
             uSt = find_set(G[i].second.first);
             vEd = find_set(G[i].second.second);
             if (uSt != vEd) {
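Note on this change: the warnings fixed here concern initializing the member V in the constructor's initializer list, passing the sort comparator's arguments by const reference, and declaring uSt and vEd in the narrowest scope, inside the loop body. A minimal standalone sketch of the same pattern, using a simplified edge type that is not part of the repository:

    #include <algorithm>
    #include <iostream>
    #include <utility>
    #include <vector>

    // hypothetical edge layout mirroring the pair-of-pairs used above: (weight, (u, v))
    using Edge = std::pair<float, std::pair<int, int>>;

    int main()
    {
        std::vector<Edge> G = { {0.3f, {0, 1}}, {0.9f, {1, 2}}, {0.5f, {0, 2}} };
        // const auto& avoids copying each pair and silences the non-const-reference warning
        std::sort(G.begin(), G.end(), [](const auto& left, const auto& right) { return left.first > right.first; });
        for (size_t i = 0; i < G.size(); i++) {
            // loop-local variables instead of function-wide declarations
            int uSt = G[i].second.first;
            int vEd = G[i].second.second;
            std::cout << G[i].first << ": " << uSt << " -> " << vEd << "\n";
        }
        return 0;
    }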

View File

@@ -88,18 +88,15 @@
     {
         // Get dimensions of the CPT
         dimensions.push_back(numStates);
-        for (auto father : getParents()) {
-            dimensions.push_back(father->getNumStates());
-        }
+        transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
         // Create a tensor of zeros with the dimensions of the CPT
         cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
         // Fill table with counts
         for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
             torch::List<c10::optional<torch::Tensor>> coordinates;
             coordinates.push_back(torch::tensor(dataset[name][n_sample]));
-            for (auto father : getParents()) {
-                coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample]));
-            }
+            transform(parents.begin(), parents.end(), back_inserter(coordinates), [&dataset, &n_sample](const auto& parent) { return torch::tensor(dataset[parent->getName()][n_sample]); });
             // Increment the count of the corresponding coordinate
             cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + 1);
         }
@@ -111,19 +108,15 @@ namespace bayesnet {
         torch::List<c10::optional<torch::Tensor>> coordinates;
         // following predetermined order of indices in the cpTable (see Node.h)
         coordinates.push_back(torch::tensor(evidence[name]));
-        for (auto parent : getParents()) {
-            coordinates.push_back(torch::tensor(evidence[parent->getName()]));
-        }
+        transform(parents.begin(), parents.end(), back_inserter(coordinates), [&evidence](const auto& parent) { return torch::tensor(evidence[parent->getName()]); });
         return cpTable.index({ coordinates }).item<float>();
     }
-    vector<string> Node::graph(string className)
+    vector<string> Node::graph(const string& className)
     {
         auto output = vector<string>();
         auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
         output.push_back(name + " [shape=circle" + suffix + "] \n");
-        for (auto& child : children) {
-            output.push_back(name + " -> " + child->getName());
-        }
+        transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); });
         return output;
     }
 }
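Note on this change: each of the three loops over parents or children is rewritten with std::transform plus std::back_inserter, which appends one transformed element per input element to the destination container, exactly what the push_back loops did. A self-contained sketch of the pattern with simplified types (Parent here is a stand-in, not the repository's Node class):

    #include <algorithm>
    #include <cstdint>
    #include <iterator>
    #include <string>
    #include <vector>

    struct Parent { std::string name; int numStates; };

    int main()
    {
        Parent a{ "A", 3 }, b{ "B", 2 };
        std::vector<Parent*> parents{ &a, &b };

        // equivalent of: for (auto father : parents) dimensions.push_back(father->numStates);
        std::vector<int64_t> dimensions;
        std::transform(parents.begin(), parents.end(), std::back_inserter(dimensions),
            [](const auto& parent) { return parent->numStates; });
        // dimensions now holds { 3, 2 }
        return 0;
    }

The behavior is unchanged; the algorithm form simply states the intent (map one range into another) and removes the range-for lint warnings.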

View File

@@ -16,7 +16,7 @@ namespace bayesnet {
         vector<int64_t> dimensions; // dimensions of the cpTable
     public:
         vector<pair<string, string>> combinations(const vector<string>&);
-        Node(const std::string&, int);
+        Node(const string&, int);
         void clear();
         void addParent(Node*);
         void addChild(Node*);
@@ -30,7 +30,7 @@ namespace bayesnet {
         int getNumStates() const;
         void setNumStates(int);
         unsigned minFill();
-        vector<string> graph(const string& clasName); // Returns a vector of strings representing the graph in graphviz format
         float getFactorValue(map<string, int>&);
     };
 }
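Note on this change: both declarations drop a by-value string parameter in favour of const string&, and the redundant std:: qualifier is removed (the header already uses unqualified vector and string). A tiny illustration of the difference, not taken from the repository:

    #include <iostream>
    #include <string>
    using namespace std;

    // by value: the argument is copied on every call
    string labelByValue(string className) { return className + " [shape=circle]"; }
    // by const reference: the caller's string is read in place, no copy
    string labelByRef(const string& className) { return className + " [shape=circle]"; }

    int main()
    {
        string name = "class";
        cout << labelByValue(name) << "\n";
        cout << labelByRef(name) << "\n";
        return 0;
    }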

View File

@@ -21,9 +21,7 @@ namespace platform {
     vector<string> Datasets::getNames()
     {
         vector<string> result;
-        for (auto& d : datasets) {
-            result.push_back(d.first);
-        }
+        transform(datasets.begin(), datasets.end(), back_inserter(result), [](const auto& d) { return d.first; });
         return result;
     }
     vector<string> Datasets::getFeatures(string name)
@@ -79,7 +77,7 @@ namespace platform {
         }
         return datasets[name]->getTensors();
     }
-    bool Datasets::isDataset(string name)
+    bool Datasets::isDataset(const string& name)
     {
         return datasets.find(name) != datasets.end();
     }
@@ -193,9 +191,8 @@ namespace platform {
         yv = arff.getY();
         // Get className & Features
         className = arff.getClassName();
-        for (auto feature : arff.getAttributes()) {
-            features.push_back(feature.first);
-        }
+        auto attributes = arff.getAttributes();
+        transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& attribute) { return attribute.first; });
     }
     void Dataset::load()
     {
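Note on this change: getNames() and the ARFF attribute loop get the same transform/back_inserter rewrite; since the containers hold key/value pairs (the datasets member is looked up with find above), d.first and attribute.first are the keys being collected. Materializing arff.getAttributes() into a local attributes variable also keeps begin() and end() referring to the same object rather than two separate temporaries. A standalone sketch of collecting map keys this way (the map below is a stand-in, not the real datasets member):

    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <map>
    #include <string>
    #include <vector>

    int main()
    {
        // stand-in for the datasets member; the mapped type is irrelevant to the pattern
        std::map<std::string, int> datasets = { {"iris", 0}, {"wine", 1} };

        std::vector<std::string> result;
        // each map element is a pair, so d.first is the dataset name
        std::transform(datasets.begin(), datasets.end(), std::back_inserter(result),
            [](const auto& d) { return d.first; });

        for (const auto& name : result) std::cout << name << "\n";   // prints: iris, wine
        return 0;
    }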

View File

@@ -49,7 +49,7 @@ namespace platform {
         bool discretize;
         void load(); // Loads the list of datasets
     public:
-        Datasets(const string& path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); };
+        explicit Datasets(const string& path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); };
         vector<string> getNames();
         vector<string> getFeatures(string name);
         int getNSamples(string name);
@@ -58,7 +58,7 @@ namespace platform {
         pair<vector<vector<float>>&, vector<int>&> getVectors(string name);
         pair<vector<vector<int>>&, vector<int>&> getVectorsDiscretized(string name);
         pair<torch::Tensor&, torch::Tensor&> getTensors(string name);
-        bool isDataset(string name);
+        bool isDataset(const string& name);
     };
 };
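Note on this change: marking the Datasets constructor explicit stops a lone string (the path) from converting implicitly into a Datasets object, the usual lint advice for constructors callable with a single argument. A minimal sketch of the effect, using a hypothetical Config type rather than the real class:

    #include <string>
    using namespace std;

    struct Config {
        string path;
        bool discretize;
        // explicit: a string can no longer silently become a Config
        explicit Config(const string& path, bool discretize = false) : path(path), discretize(discretize) {}
    };

    void run(const Config&) {}

    int main()
    {
        string path = "datasets/";
        // run(path);            // compiled before, via an implicit string -> Config conversion; now rejected
        run(Config(path));       // with explicit, the construction has to be spelled out
        return 0;
    }

The constructor body and default arguments are unchanged; only accidental conversions at call sites are ruled out.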