Fix some more lint warnings
This commit is contained in:
parent
8b2ed26ab7
commit
b882569169
@ -7,7 +7,7 @@
|
||||
|
||||
namespace bayesnet {
|
||||
using namespace std;
|
||||
Graph::Graph(int V)
|
||||
Graph::Graph(int V) : V(V)
|
||||
{
|
||||
parent = vector<int>(V);
|
||||
for (int i = 0; i < V; i++)
|
||||
@ -34,10 +34,10 @@ namespace bayesnet {
|
||||
}
|
||||
void Graph::kruskal_algorithm()
|
||||
{
|
||||
int i, uSt, vEd;
|
||||
// sort the edges ordered on decreasing weight
|
||||
sort(G.begin(), G.end(), [](auto& left, auto& right) {return left.first > right.first;});
|
||||
for (i = 0; i < G.size(); i++) {
|
||||
sort(G.begin(), G.end(), [](const auto& left, const auto& right) {return left.first > right.first;});
|
||||
for (int i = 0; i < G.size(); i++) {
|
||||
int uSt, vEd;
|
||||
uSt = find_set(G[i].second.first);
|
||||
vEd = find_set(G[i].second.second);
|
||||
if (uSt != vEd) {
|
||||
|
@ -88,18 +88,15 @@ namespace bayesnet {
|
||||
{
|
||||
// Get dimensions of the CPT
|
||||
dimensions.push_back(numStates);
|
||||
for (auto father : getParents()) {
|
||||
dimensions.push_back(father->getNumStates());
|
||||
}
|
||||
transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
|
||||
|
||||
// Create a tensor of zeros with the dimensions of the CPT
|
||||
cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
|
||||
// Fill table with counts
|
||||
for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
|
||||
torch::List<c10::optional<torch::Tensor>> coordinates;
|
||||
coordinates.push_back(torch::tensor(dataset[name][n_sample]));
|
||||
for (auto father : getParents()) {
|
||||
coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample]));
|
||||
}
|
||||
transform(parents.begin(), parents.end(), back_inserter(coordinates), [&dataset, &n_sample](const auto& parent) { return torch::tensor(dataset[parent->getName()][n_sample]); });
|
||||
// Increment the count of the corresponding coordinate
|
||||
cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + 1);
|
||||
}
|
||||
@ -111,19 +108,15 @@ namespace bayesnet {
|
||||
torch::List<c10::optional<torch::Tensor>> coordinates;
|
||||
// following predetermined order of indices in the cpTable (see Node.h)
|
||||
coordinates.push_back(torch::tensor(evidence[name]));
|
||||
for (auto parent : getParents()) {
|
||||
coordinates.push_back(torch::tensor(evidence[parent->getName()]));
|
||||
}
|
||||
transform(parents.begin(), parents.end(), back_inserter(coordinates), [&evidence](const auto& parent) { return torch::tensor(evidence[parent->getName()]); });
|
||||
return cpTable.index({ coordinates }).item<float>();
|
||||
}
|
||||
vector<string> Node::graph(string className)
|
||||
vector<string> Node::graph(const string& className)
|
||||
{
|
||||
auto output = vector<string>();
|
||||
auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
|
||||
output.push_back(name + " [shape=circle" + suffix + "] \n");
|
||||
for (auto& child : children) {
|
||||
output.push_back(name + " -> " + child->getName());
|
||||
}
|
||||
transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); });
|
||||
return output;
|
||||
}
|
||||
}
|
@ -16,7 +16,7 @@ namespace bayesnet {
|
||||
vector<int64_t> dimensions; // dimensions of the cpTable
|
||||
public:
|
||||
vector<pair<string, string>> combinations(const vector<string>&);
|
||||
Node(const std::string&, int);
|
||||
Node(const string&, int);
|
||||
void clear();
|
||||
void addParent(Node*);
|
||||
void addChild(Node*);
|
||||
@ -30,7 +30,7 @@ namespace bayesnet {
|
||||
int getNumStates() const;
|
||||
void setNumStates(int);
|
||||
unsigned minFill();
|
||||
vector<string> graph(string clasName); // Returns a vector of strings representing the graph in graphviz format
|
||||
vector<string> graph(const string& clasName); // Returns a vector of strings representing the graph in graphviz format
|
||||
float getFactorValue(map<string, int>&);
|
||||
};
|
||||
}
|
||||
|
@ -21,9 +21,7 @@ namespace platform {
|
||||
vector<string> Datasets::getNames()
|
||||
{
|
||||
vector<string> result;
|
||||
for (auto& d : datasets) {
|
||||
result.push_back(d.first);
|
||||
}
|
||||
transform(datasets.begin(), datasets.end(), back_inserter(result), [](const auto& d) { return d.first; });
|
||||
return result;
|
||||
}
|
||||
vector<string> Datasets::getFeatures(string name)
|
||||
@ -79,7 +77,7 @@ namespace platform {
|
||||
}
|
||||
return datasets[name]->getTensors();
|
||||
}
|
||||
bool Datasets::isDataset(string name)
|
||||
bool Datasets::isDataset(const string& name)
|
||||
{
|
||||
return datasets.find(name) != datasets.end();
|
||||
}
|
||||
@ -193,9 +191,8 @@ namespace platform {
|
||||
yv = arff.getY();
|
||||
// Get className & Features
|
||||
className = arff.getClassName();
|
||||
for (auto feature : arff.getAttributes()) {
|
||||
features.push_back(feature.first);
|
||||
}
|
||||
auto attributes = arff.getAttributes();
|
||||
transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& attribute) { return attribute.first; });
|
||||
}
|
||||
void Dataset::load()
|
||||
{
|
||||
|
@ -49,7 +49,7 @@ namespace platform {
|
||||
bool discretize;
|
||||
void load(); // Loads the list of datasets
|
||||
public:
|
||||
Datasets(const string& path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); };
|
||||
explicit Datasets(const string& path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); };
|
||||
vector<string> getNames();
|
||||
vector<string> getFeatures(string name);
|
||||
int getNSamples(string name);
|
||||
@ -58,7 +58,7 @@ namespace platform {
|
||||
pair<vector<vector<float>>&, vector<int>&> getVectors(string name);
|
||||
pair<vector<vector<int>>&, vector<int>&> getVectorsDiscretized(string name);
|
||||
pair<torch::Tensor&, torch::Tensor&> getTensors(string name);
|
||||
bool isDataset(string name);
|
||||
bool isDataset(const string& name);
|
||||
};
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user