Complete TAN with Maximum Spanning Tree
This commit is contained in:
parent
e311c27d43
commit
e3863387bb
@ -10,6 +10,7 @@
|
||||
#include "KDB.h"
|
||||
#include "SPODE.h"
|
||||
#include "AODE.h"
|
||||
#include "TAN.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
@ -282,5 +283,13 @@ int main(int argc, char** argv)
|
||||
}
|
||||
cout << "Score: " << aode.score(Xd, y) << endl;
|
||||
cout << "****************** AODE ******************" << endl;
|
||||
cout << "****************** TAN ******************" << endl;
|
||||
auto tan = bayesnet::TAN();
|
||||
tan.fit(Xd, y, features, className, states);
|
||||
for (auto line : tan.show()) {
|
||||
cout << line << endl;
|
||||
}
|
||||
cout << "Score: " << tan.score(Xd, y) << endl;
|
||||
cout << "****************** TAN ******************" << endl;
|
||||
return 0;
|
||||
}
|
@ -1,2 +1,2 @@
|
||||
add_library(BayesNet utils.cc Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc)
|
||||
add_library(BayesNet utils.cc Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc Mst.cc)
|
||||
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
|
@ -1,4 +1,5 @@
|
||||
#include "Metrics.hpp"
|
||||
#include "Mst.h"
|
||||
using namespace std;
|
||||
namespace bayesnet {
|
||||
Metrics::Metrics(torch::Tensor& samples, vector<string>& features, string& className, int classNumStates)
|
||||
@ -121,14 +122,11 @@ namespace bayesnet {
|
||||
and the indices of the weights as nodes of this square matrix using
|
||||
Kruskal algorithm
|
||||
*/
|
||||
vector<pair<int, int>> Metrics::maximumSpanningTree(int root, Tensor& weights)
|
||||
vector<pair<int, int>> Metrics::maximumSpanningTree(vector<string> features, Tensor& weights, int root)
|
||||
{
|
||||
auto result = vector<pair<int, int>>();
|
||||
auto mst = MST(features, weights, root);
|
||||
return mst.maximumSpanningTree();
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
@ -22,7 +22,7 @@ namespace bayesnet {
|
||||
vector<float> conditionalEdgeWeights();
|
||||
Tensor conditionalEdge();
|
||||
vector<pair<string, string>> doCombinations(const vector<string>&);
|
||||
vector<pair<int, int>> maximumSpanningTree(int root, Tensor& weights);
|
||||
vector<pair<int, int>> maximumSpanningTree(vector<string> features, Tensor& weights, int root);
|
||||
};
|
||||
}
|
||||
#endif
|
116
src/Mst.cc
Normal file
116
src/Mst.cc
Normal file
@ -0,0 +1,116 @@
|
||||
#include "Mst.h"
|
||||
#include <vector>
|
||||
/*
|
||||
Based on the code from https://www.softwaretestinghelp.com/minimum-spanning-tree-tutorial/
|
||||
|
||||
*/
|
||||
|
||||
namespace bayesnet {
|
||||
using namespace std;
|
||||
Graph::Graph(int V)
|
||||
{
|
||||
parent = vector<int>(V);
|
||||
for (int i = 0; i < V; i++)
|
||||
parent[i] = i;
|
||||
G.clear();
|
||||
T.clear();
|
||||
}
|
||||
void Graph::addEdge(int u, int v, float wt)
|
||||
{
|
||||
G.push_back({ wt, { u, v } });
|
||||
}
|
||||
int Graph::find_set(int i)
|
||||
{
|
||||
// If i is the parent of itself
|
||||
if (i == parent[i])
|
||||
return i;
|
||||
else
|
||||
//else recursively find the parent of i
|
||||
return find_set(parent[i]);
|
||||
}
|
||||
void Graph::union_set(int u, int v)
|
||||
{
|
||||
parent[u] = parent[v];
|
||||
}
|
||||
void Graph::kruskal_algorithm()
|
||||
{
|
||||
int i, uSt, vEd;
|
||||
// sort the edges ordered on decreasing weight
|
||||
sort(G.begin(), G.end(), [](auto& left, auto& right) {return left.first > right.first;});
|
||||
for (i = 0; i < G.size(); i++) {
|
||||
uSt = find_set(G[i].second.first);
|
||||
vEd = find_set(G[i].second.second);
|
||||
if (uSt != vEd) {
|
||||
T.push_back(G[i]); // add to mst vector
|
||||
union_set(uSt, vEd);
|
||||
}
|
||||
}
|
||||
}
|
||||
void Graph::display_mst()
|
||||
{
|
||||
cout << "Edge :" << " Weight" << endl;
|
||||
for (int i = 0; i < T.size(); i++) {
|
||||
cout << T[i].second.first << " - " << T[i].second.second << " : "
|
||||
<< T[i].first;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
vector<pair<int, int>> reorder(vector<pair<float, pair<int, int>>> T, int root_original)
|
||||
{
|
||||
auto result = vector<pair<int, int>>();
|
||||
auto visited = vector<int>();
|
||||
auto nextVariables = unordered_set<int>();
|
||||
nextVariables.emplace(root_original);
|
||||
while (nextVariables.size() > 0) {
|
||||
int root = *nextVariables.begin();
|
||||
nextVariables.erase(nextVariables.begin());
|
||||
for (int i = 0; i < T.size(); ++i) {
|
||||
auto [weight, edge] = T[i];
|
||||
auto [from, to] = edge;
|
||||
if (from == root || to == root) {
|
||||
visited.insert(visited.begin(), i);
|
||||
if (from == root) {
|
||||
result.push_back({ from, to });
|
||||
nextVariables.emplace(to);
|
||||
} else {
|
||||
result.push_back({ to, from });
|
||||
nextVariables.emplace(from);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Remove visited
|
||||
for (int i = 0; i < visited.size(); ++i) {
|
||||
T.erase(T.begin() + visited[i]);
|
||||
}
|
||||
visited.clear();
|
||||
}
|
||||
if (T.size() > 0) {
|
||||
for (int i = 0; i < T.size(); ++i) {
|
||||
auto [weight, edge] = T[i];
|
||||
auto [from, to] = edge;
|
||||
result.push_back({ from, to });
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
MST::MST(vector<string>& features, Tensor& weights, int root) : features(features), weights(weights), root(root) {}
|
||||
vector<pair<int, int>> MST::maximumSpanningTree()
|
||||
{
|
||||
auto num_features = features.size();
|
||||
Graph g(num_features);
|
||||
|
||||
// Make a complete graph
|
||||
for (int i = 0; i < num_features - 1; ++i) {
|
||||
for (int j = i; j < num_features; ++j) {
|
||||
g.addEdge(i, j, weights[i][j].item<float>());
|
||||
}
|
||||
}
|
||||
g.kruskal_algorithm();
|
||||
//g.display_mst();
|
||||
auto mst = g.get_mst();
|
||||
return reorder(mst, root);
|
||||
}
|
||||
|
||||
}
|
35
src/Mst.h
Normal file
35
src/Mst.h
Normal file
@ -0,0 +1,35 @@
|
||||
#ifndef MST_H
|
||||
#define MST_H
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
namespace bayesnet {
|
||||
using namespace std;
|
||||
using namespace torch;
|
||||
class MST {
|
||||
private:
|
||||
Tensor weights;
|
||||
vector<string> features;
|
||||
int root;
|
||||
public:
|
||||
MST() = default;
|
||||
MST(vector<string>& features, Tensor& weights, int root);
|
||||
vector<pair<int, int>> maximumSpanningTree();
|
||||
};
|
||||
class Graph {
|
||||
private:
|
||||
int V; // number of nodes in graph
|
||||
vector <pair<float, pair<int, int>>> G; // vector for graph
|
||||
vector <pair<float, pair<int, int>>> T; // vector for mst
|
||||
vector<int> parent;
|
||||
public:
|
||||
Graph(int V);
|
||||
void addEdge(int u, int v, float wt);
|
||||
int find_set(int i);
|
||||
void union_set(int u, int v);
|
||||
void kruskal_algorithm();
|
||||
void display_mst();
|
||||
vector <pair<float, pair<int, int>>> get_mst() { return T; }
|
||||
};
|
||||
}
|
||||
#endif
|
@ -19,16 +19,20 @@ namespace bayesnet {
|
||||
auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset);
|
||||
mi.push_back({ i, mi_value });
|
||||
}
|
||||
sort(mi.begin(), mi.end());
|
||||
sort(mi.begin(), mi.end(), [](auto& left, auto& right) {return left.second < right.second;});
|
||||
auto root = mi[mi.size() - 1].first;
|
||||
// 2. Compute mutual information between each feature and the class
|
||||
auto weights = metrics.conditionalEdge();
|
||||
// 3. Compute the maximum spanning tree
|
||||
auto mst = metrics.maximumSpanningTree(root, weights);
|
||||
auto mst = metrics.maximumSpanningTree(features, weights, root);
|
||||
// 4. Add edges from the maximum spanning tree to the model
|
||||
for (auto i = 0; i < mst.size(); ++i) {
|
||||
auto [from, to] = mst[i];
|
||||
model.addEdge(features[from], features[to]);
|
||||
}
|
||||
// 5. Add edges from the class to all features
|
||||
for (auto feature : features) {
|
||||
model.addEdge(className, feature);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user