Fix some more lint warnings
This commit is contained in:
parent
8b2ed26ab7
commit
b882569169
@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
using namespace std;
|
using namespace std;
|
||||||
Graph::Graph(int V)
|
Graph::Graph(int V) : V(V)
|
||||||
{
|
{
|
||||||
parent = vector<int>(V);
|
parent = vector<int>(V);
|
||||||
for (int i = 0; i < V; i++)
|
for (int i = 0; i < V; i++)
|
||||||
@ -34,10 +34,10 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
void Graph::kruskal_algorithm()
|
void Graph::kruskal_algorithm()
|
||||||
{
|
{
|
||||||
int i, uSt, vEd;
|
|
||||||
// sort the edges ordered on decreasing weight
|
// sort the edges ordered on decreasing weight
|
||||||
sort(G.begin(), G.end(), [](auto& left, auto& right) {return left.first > right.first;});
|
sort(G.begin(), G.end(), [](const auto& left, const auto& right) {return left.first > right.first;});
|
||||||
for (i = 0; i < G.size(); i++) {
|
for (int i = 0; i < G.size(); i++) {
|
||||||
|
int uSt, vEd;
|
||||||
uSt = find_set(G[i].second.first);
|
uSt = find_set(G[i].second.first);
|
||||||
vEd = find_set(G[i].second.second);
|
vEd = find_set(G[i].second.second);
|
||||||
if (uSt != vEd) {
|
if (uSt != vEd) {
|
||||||
|
@ -88,18 +88,15 @@ namespace bayesnet {
|
|||||||
{
|
{
|
||||||
// Get dimensions of the CPT
|
// Get dimensions of the CPT
|
||||||
dimensions.push_back(numStates);
|
dimensions.push_back(numStates);
|
||||||
for (auto father : getParents()) {
|
transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
|
||||||
dimensions.push_back(father->getNumStates());
|
|
||||||
}
|
|
||||||
// Create a tensor of zeros with the dimensions of the CPT
|
// Create a tensor of zeros with the dimensions of the CPT
|
||||||
cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
|
cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
|
||||||
// Fill table with counts
|
// Fill table with counts
|
||||||
for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
|
for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
|
||||||
torch::List<c10::optional<torch::Tensor>> coordinates;
|
torch::List<c10::optional<torch::Tensor>> coordinates;
|
||||||
coordinates.push_back(torch::tensor(dataset[name][n_sample]));
|
coordinates.push_back(torch::tensor(dataset[name][n_sample]));
|
||||||
for (auto father : getParents()) {
|
transform(parents.begin(), parents.end(), back_inserter(coordinates), [&dataset, &n_sample](const auto& parent) { return torch::tensor(dataset[parent->getName()][n_sample]); });
|
||||||
coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample]));
|
|
||||||
}
|
|
||||||
// Increment the count of the corresponding coordinate
|
// Increment the count of the corresponding coordinate
|
||||||
cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + 1);
|
cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + 1);
|
||||||
}
|
}
|
||||||
@ -111,19 +108,15 @@ namespace bayesnet {
|
|||||||
torch::List<c10::optional<torch::Tensor>> coordinates;
|
torch::List<c10::optional<torch::Tensor>> coordinates;
|
||||||
// following predetermined order of indices in the cpTable (see Node.h)
|
// following predetermined order of indices in the cpTable (see Node.h)
|
||||||
coordinates.push_back(torch::tensor(evidence[name]));
|
coordinates.push_back(torch::tensor(evidence[name]));
|
||||||
for (auto parent : getParents()) {
|
transform(parents.begin(), parents.end(), back_inserter(coordinates), [&evidence](const auto& parent) { return torch::tensor(evidence[parent->getName()]); });
|
||||||
coordinates.push_back(torch::tensor(evidence[parent->getName()]));
|
|
||||||
}
|
|
||||||
return cpTable.index({ coordinates }).item<float>();
|
return cpTable.index({ coordinates }).item<float>();
|
||||||
}
|
}
|
||||||
vector<string> Node::graph(string className)
|
vector<string> Node::graph(const string& className)
|
||||||
{
|
{
|
||||||
auto output = vector<string>();
|
auto output = vector<string>();
|
||||||
auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
|
auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
|
||||||
output.push_back(name + " [shape=circle" + suffix + "] \n");
|
output.push_back(name + " [shape=circle" + suffix + "] \n");
|
||||||
for (auto& child : children) {
|
transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); });
|
||||||
output.push_back(name + " -> " + child->getName());
|
|
||||||
}
|
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -16,7 +16,7 @@ namespace bayesnet {
|
|||||||
vector<int64_t> dimensions; // dimensions of the cpTable
|
vector<int64_t> dimensions; // dimensions of the cpTable
|
||||||
public:
|
public:
|
||||||
vector<pair<string, string>> combinations(const vector<string>&);
|
vector<pair<string, string>> combinations(const vector<string>&);
|
||||||
Node(const std::string&, int);
|
Node(const string&, int);
|
||||||
void clear();
|
void clear();
|
||||||
void addParent(Node*);
|
void addParent(Node*);
|
||||||
void addChild(Node*);
|
void addChild(Node*);
|
||||||
@ -30,7 +30,7 @@ namespace bayesnet {
|
|||||||
int getNumStates() const;
|
int getNumStates() const;
|
||||||
void setNumStates(int);
|
void setNumStates(int);
|
||||||
unsigned minFill();
|
unsigned minFill();
|
||||||
vector<string> graph(string clasName); // Returns a vector of strings representing the graph in graphviz format
|
vector<string> graph(const string& clasName); // Returns a vector of strings representing the graph in graphviz format
|
||||||
float getFactorValue(map<string, int>&);
|
float getFactorValue(map<string, int>&);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -21,9 +21,7 @@ namespace platform {
|
|||||||
vector<string> Datasets::getNames()
|
vector<string> Datasets::getNames()
|
||||||
{
|
{
|
||||||
vector<string> result;
|
vector<string> result;
|
||||||
for (auto& d : datasets) {
|
transform(datasets.begin(), datasets.end(), back_inserter(result), [](const auto& d) { return d.first; });
|
||||||
result.push_back(d.first);
|
|
||||||
}
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
vector<string> Datasets::getFeatures(string name)
|
vector<string> Datasets::getFeatures(string name)
|
||||||
@ -79,7 +77,7 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return datasets[name]->getTensors();
|
return datasets[name]->getTensors();
|
||||||
}
|
}
|
||||||
bool Datasets::isDataset(string name)
|
bool Datasets::isDataset(const string& name)
|
||||||
{
|
{
|
||||||
return datasets.find(name) != datasets.end();
|
return datasets.find(name) != datasets.end();
|
||||||
}
|
}
|
||||||
@ -193,9 +191,8 @@ namespace platform {
|
|||||||
yv = arff.getY();
|
yv = arff.getY();
|
||||||
// Get className & Features
|
// Get className & Features
|
||||||
className = arff.getClassName();
|
className = arff.getClassName();
|
||||||
for (auto feature : arff.getAttributes()) {
|
auto attributes = arff.getAttributes();
|
||||||
features.push_back(feature.first);
|
transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& attribute) { return attribute.first; });
|
||||||
}
|
|
||||||
}
|
}
|
||||||
void Dataset::load()
|
void Dataset::load()
|
||||||
{
|
{
|
||||||
|
@ -49,7 +49,7 @@ namespace platform {
|
|||||||
bool discretize;
|
bool discretize;
|
||||||
void load(); // Loads the list of datasets
|
void load(); // Loads the list of datasets
|
||||||
public:
|
public:
|
||||||
Datasets(const string& path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); };
|
explicit Datasets(const string& path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); };
|
||||||
vector<string> getNames();
|
vector<string> getNames();
|
||||||
vector<string> getFeatures(string name);
|
vector<string> getFeatures(string name);
|
||||||
int getNSamples(string name);
|
int getNSamples(string name);
|
||||||
@ -58,7 +58,7 @@ namespace platform {
|
|||||||
pair<vector<vector<float>>&, vector<int>&> getVectors(string name);
|
pair<vector<vector<float>>&, vector<int>&> getVectors(string name);
|
||||||
pair<vector<vector<int>>&, vector<int>&> getVectorsDiscretized(string name);
|
pair<vector<vector<int>>&, vector<int>&> getVectorsDiscretized(string name);
|
||||||
pair<torch::Tensor&, torch::Tensor&> getTensors(string name);
|
pair<torch::Tensor&, torch::Tensor&> getTensors(string name);
|
||||||
bool isDataset(string name);
|
bool isDataset(const string& name);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user