Fix some more lint warnings

This commit is contained in:
Ricardo Montañana Gómez 2023-07-30 00:04:18 +02:00
parent 8b2ed26ab7
commit b882569169
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
5 changed files with 18 additions and 28 deletions

View File

@ -7,7 +7,7 @@
namespace bayesnet { namespace bayesnet {
using namespace std; using namespace std;
Graph::Graph(int V) Graph::Graph(int V) : V(V)
{ {
parent = vector<int>(V); parent = vector<int>(V);
for (int i = 0; i < V; i++) for (int i = 0; i < V; i++)
@ -34,10 +34,10 @@ namespace bayesnet {
} }
void Graph::kruskal_algorithm() void Graph::kruskal_algorithm()
{ {
int i, uSt, vEd;
// sort the edges ordered on decreasing weight // sort the edges ordered on decreasing weight
sort(G.begin(), G.end(), [](auto& left, auto& right) {return left.first > right.first;}); sort(G.begin(), G.end(), [](const auto& left, const auto& right) {return left.first > right.first;});
for (i = 0; i < G.size(); i++) { for (int i = 0; i < G.size(); i++) {
int uSt, vEd;
uSt = find_set(G[i].second.first); uSt = find_set(G[i].second.first);
vEd = find_set(G[i].second.second); vEd = find_set(G[i].second.second);
if (uSt != vEd) { if (uSt != vEd) {

View File

@ -88,18 +88,15 @@ namespace bayesnet {
{ {
// Get dimensions of the CPT // Get dimensions of the CPT
dimensions.push_back(numStates); dimensions.push_back(numStates);
for (auto father : getParents()) { transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
dimensions.push_back(father->getNumStates());
}
// Create a tensor of zeros with the dimensions of the CPT // Create a tensor of zeros with the dimensions of the CPT
cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing; cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
// Fill table with counts // Fill table with counts
for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) { for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
torch::List<c10::optional<torch::Tensor>> coordinates; torch::List<c10::optional<torch::Tensor>> coordinates;
coordinates.push_back(torch::tensor(dataset[name][n_sample])); coordinates.push_back(torch::tensor(dataset[name][n_sample]));
for (auto father : getParents()) { transform(parents.begin(), parents.end(), back_inserter(coordinates), [&dataset, &n_sample](const auto& parent) { return torch::tensor(dataset[parent->getName()][n_sample]); });
coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample]));
}
// Increment the count of the corresponding coordinate // Increment the count of the corresponding coordinate
cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + 1); cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + 1);
} }
@ -111,19 +108,15 @@ namespace bayesnet {
torch::List<c10::optional<torch::Tensor>> coordinates; torch::List<c10::optional<torch::Tensor>> coordinates;
// following predetermined order of indices in the cpTable (see Node.h) // following predetermined order of indices in the cpTable (see Node.h)
coordinates.push_back(torch::tensor(evidence[name])); coordinates.push_back(torch::tensor(evidence[name]));
for (auto parent : getParents()) { transform(parents.begin(), parents.end(), back_inserter(coordinates), [&evidence](const auto& parent) { return torch::tensor(evidence[parent->getName()]); });
coordinates.push_back(torch::tensor(evidence[parent->getName()]));
}
return cpTable.index({ coordinates }).item<float>(); return cpTable.index({ coordinates }).item<float>();
} }
vector<string> Node::graph(string className) vector<string> Node::graph(const string& className)
{ {
auto output = vector<string>(); auto output = vector<string>();
auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : ""; auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
output.push_back(name + " [shape=circle" + suffix + "] \n"); output.push_back(name + " [shape=circle" + suffix + "] \n");
for (auto& child : children) { transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); });
output.push_back(name + " -> " + child->getName());
}
return output; return output;
} }
} }

View File

@ -16,7 +16,7 @@ namespace bayesnet {
vector<int64_t> dimensions; // dimensions of the cpTable vector<int64_t> dimensions; // dimensions of the cpTable
public: public:
vector<pair<string, string>> combinations(const vector<string>&); vector<pair<string, string>> combinations(const vector<string>&);
Node(const std::string&, int); Node(const string&, int);
void clear(); void clear();
void addParent(Node*); void addParent(Node*);
void addChild(Node*); void addChild(Node*);
@ -30,7 +30,7 @@ namespace bayesnet {
int getNumStates() const; int getNumStates() const;
void setNumStates(int); void setNumStates(int);
unsigned minFill(); unsigned minFill();
vector<string> graph(string clasName); // Returns a vector of strings representing the graph in graphviz format vector<string> graph(const string& clasName); // Returns a vector of strings representing the graph in graphviz format
float getFactorValue(map<string, int>&); float getFactorValue(map<string, int>&);
}; };
} }

View File

@ -21,9 +21,7 @@ namespace platform {
vector<string> Datasets::getNames() vector<string> Datasets::getNames()
{ {
vector<string> result; vector<string> result;
for (auto& d : datasets) { transform(datasets.begin(), datasets.end(), back_inserter(result), [](const auto& d) { return d.first; });
result.push_back(d.first);
}
return result; return result;
} }
vector<string> Datasets::getFeatures(string name) vector<string> Datasets::getFeatures(string name)
@ -79,7 +77,7 @@ namespace platform {
} }
return datasets[name]->getTensors(); return datasets[name]->getTensors();
} }
bool Datasets::isDataset(string name) bool Datasets::isDataset(const string& name)
{ {
return datasets.find(name) != datasets.end(); return datasets.find(name) != datasets.end();
} }
@ -193,9 +191,8 @@ namespace platform {
yv = arff.getY(); yv = arff.getY();
// Get className & Features // Get className & Features
className = arff.getClassName(); className = arff.getClassName();
for (auto feature : arff.getAttributes()) { auto attributes = arff.getAttributes();
features.push_back(feature.first); transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& attribute) { return attribute.first; });
}
} }
void Dataset::load() void Dataset::load()
{ {

View File

@ -49,7 +49,7 @@ namespace platform {
bool discretize; bool discretize;
void load(); // Loads the list of datasets void load(); // Loads the list of datasets
public: public:
Datasets(const string& path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); }; explicit Datasets(const string& path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); };
vector<string> getNames(); vector<string> getNames();
vector<string> getFeatures(string name); vector<string> getFeatures(string name);
int getNSamples(string name); int getNSamples(string name);
@ -58,7 +58,7 @@ namespace platform {
pair<vector<vector<float>>&, vector<int>&> getVectors(string name); pair<vector<vector<float>>&, vector<int>&> getVectors(string name);
pair<vector<vector<int>>&, vector<int>&> getVectorsDiscretized(string name); pair<vector<vector<int>>&, vector<int>&> getVectorsDiscretized(string name);
pair<torch::Tensor&, torch::Tensor&> getTensors(string name); pair<torch::Tensor&, torch::Tensor&> getTensors(string name);
bool isDataset(string name); bool isDataset(const string& name);
}; };
}; };