Complete TAN with Maximum Spanning Tree
This commit is contained in:
parent
e311c27d43
commit
e3863387bb
@ -10,6 +10,7 @@
|
|||||||
#include "KDB.h"
|
#include "KDB.h"
|
||||||
#include "SPODE.h"
|
#include "SPODE.h"
|
||||||
#include "AODE.h"
|
#include "AODE.h"
|
||||||
|
#include "TAN.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -282,5 +283,13 @@ int main(int argc, char** argv)
|
|||||||
}
|
}
|
||||||
cout << "Score: " << aode.score(Xd, y) << endl;
|
cout << "Score: " << aode.score(Xd, y) << endl;
|
||||||
cout << "****************** AODE ******************" << endl;
|
cout << "****************** AODE ******************" << endl;
|
||||||
|
cout << "****************** TAN ******************" << endl;
|
||||||
|
auto tan = bayesnet::TAN();
|
||||||
|
tan.fit(Xd, y, features, className, states);
|
||||||
|
for (auto line : tan.show()) {
|
||||||
|
cout << line << endl;
|
||||||
|
}
|
||||||
|
cout << "Score: " << tan.score(Xd, y) << endl;
|
||||||
|
cout << "****************** TAN ******************" << endl;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
@ -1,2 +1,2 @@
|
|||||||
add_library(BayesNet utils.cc Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc)
|
add_library(BayesNet utils.cc Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc Mst.cc)
|
||||||
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
|
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
|
@ -1,4 +1,5 @@
|
|||||||
#include "Metrics.hpp"
|
#include "Metrics.hpp"
|
||||||
|
#include "Mst.h"
|
||||||
using namespace std;
|
using namespace std;
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
Metrics::Metrics(torch::Tensor& samples, vector<string>& features, string& className, int classNumStates)
|
Metrics::Metrics(torch::Tensor& samples, vector<string>& features, string& className, int classNumStates)
|
||||||
@ -121,14 +122,11 @@ namespace bayesnet {
|
|||||||
and the indices of the weights as nodes of this square matrix using
|
and the indices of the weights as nodes of this square matrix using
|
||||||
Kruskal algorithm
|
Kruskal algorithm
|
||||||
*/
|
*/
|
||||||
vector<pair<int, int>> Metrics::maximumSpanningTree(int root, Tensor& weights)
|
vector<pair<int, int>> Metrics::maximumSpanningTree(vector<string> features, Tensor& weights, int root)
|
||||||
{
|
{
|
||||||
auto result = vector<pair<int, int>>();
|
auto result = vector<pair<int, int>>();
|
||||||
|
auto mst = MST(features, weights, root);
|
||||||
|
return mst.maximumSpanningTree();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -22,7 +22,7 @@ namespace bayesnet {
|
|||||||
vector<float> conditionalEdgeWeights();
|
vector<float> conditionalEdgeWeights();
|
||||||
Tensor conditionalEdge();
|
Tensor conditionalEdge();
|
||||||
vector<pair<string, string>> doCombinations(const vector<string>&);
|
vector<pair<string, string>> doCombinations(const vector<string>&);
|
||||||
vector<pair<int, int>> maximumSpanningTree(int root, Tensor& weights);
|
vector<pair<int, int>> maximumSpanningTree(vector<string> features, Tensor& weights, int root);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
116
src/Mst.cc
Normal file
116
src/Mst.cc
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
#include "Mst.h"
|
||||||
|
#include <vector>
|
||||||
|
/*
|
||||||
|
Based on the code from https://www.softwaretestinghelp.com/minimum-spanning-tree-tutorial/
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
namespace bayesnet {
|
||||||
|
using namespace std;
|
||||||
|
Graph::Graph(int V)
|
||||||
|
{
|
||||||
|
parent = vector<int>(V);
|
||||||
|
for (int i = 0; i < V; i++)
|
||||||
|
parent[i] = i;
|
||||||
|
G.clear();
|
||||||
|
T.clear();
|
||||||
|
}
|
||||||
|
void Graph::addEdge(int u, int v, float wt)
|
||||||
|
{
|
||||||
|
G.push_back({ wt, { u, v } });
|
||||||
|
}
|
||||||
|
int Graph::find_set(int i)
|
||||||
|
{
|
||||||
|
// If i is the parent of itself
|
||||||
|
if (i == parent[i])
|
||||||
|
return i;
|
||||||
|
else
|
||||||
|
//else recursively find the parent of i
|
||||||
|
return find_set(parent[i]);
|
||||||
|
}
|
||||||
|
void Graph::union_set(int u, int v)
|
||||||
|
{
|
||||||
|
parent[u] = parent[v];
|
||||||
|
}
|
||||||
|
void Graph::kruskal_algorithm()
|
||||||
|
{
|
||||||
|
int i, uSt, vEd;
|
||||||
|
// sort the edges ordered on decreasing weight
|
||||||
|
sort(G.begin(), G.end(), [](auto& left, auto& right) {return left.first > right.first;});
|
||||||
|
for (i = 0; i < G.size(); i++) {
|
||||||
|
uSt = find_set(G[i].second.first);
|
||||||
|
vEd = find_set(G[i].second.second);
|
||||||
|
if (uSt != vEd) {
|
||||||
|
T.push_back(G[i]); // add to mst vector
|
||||||
|
union_set(uSt, vEd);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void Graph::display_mst()
|
||||||
|
{
|
||||||
|
cout << "Edge :" << " Weight" << endl;
|
||||||
|
for (int i = 0; i < T.size(); i++) {
|
||||||
|
cout << T[i].second.first << " - " << T[i].second.second << " : "
|
||||||
|
<< T[i].first;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<pair<int, int>> reorder(vector<pair<float, pair<int, int>>> T, int root_original)
|
||||||
|
{
|
||||||
|
auto result = vector<pair<int, int>>();
|
||||||
|
auto visited = vector<int>();
|
||||||
|
auto nextVariables = unordered_set<int>();
|
||||||
|
nextVariables.emplace(root_original);
|
||||||
|
while (nextVariables.size() > 0) {
|
||||||
|
int root = *nextVariables.begin();
|
||||||
|
nextVariables.erase(nextVariables.begin());
|
||||||
|
for (int i = 0; i < T.size(); ++i) {
|
||||||
|
auto [weight, edge] = T[i];
|
||||||
|
auto [from, to] = edge;
|
||||||
|
if (from == root || to == root) {
|
||||||
|
visited.insert(visited.begin(), i);
|
||||||
|
if (from == root) {
|
||||||
|
result.push_back({ from, to });
|
||||||
|
nextVariables.emplace(to);
|
||||||
|
} else {
|
||||||
|
result.push_back({ to, from });
|
||||||
|
nextVariables.emplace(from);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Remove visited
|
||||||
|
for (int i = 0; i < visited.size(); ++i) {
|
||||||
|
T.erase(T.begin() + visited[i]);
|
||||||
|
}
|
||||||
|
visited.clear();
|
||||||
|
}
|
||||||
|
if (T.size() > 0) {
|
||||||
|
for (int i = 0; i < T.size(); ++i) {
|
||||||
|
auto [weight, edge] = T[i];
|
||||||
|
auto [from, to] = edge;
|
||||||
|
result.push_back({ from, to });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
MST::MST(vector<string>& features, Tensor& weights, int root) : features(features), weights(weights), root(root) {}
|
||||||
|
vector<pair<int, int>> MST::maximumSpanningTree()
|
||||||
|
{
|
||||||
|
auto num_features = features.size();
|
||||||
|
Graph g(num_features);
|
||||||
|
|
||||||
|
// Make a complete graph
|
||||||
|
for (int i = 0; i < num_features - 1; ++i) {
|
||||||
|
for (int j = i; j < num_features; ++j) {
|
||||||
|
g.addEdge(i, j, weights[i][j].item<float>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
g.kruskal_algorithm();
|
||||||
|
//g.display_mst();
|
||||||
|
auto mst = g.get_mst();
|
||||||
|
return reorder(mst, root);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
35
src/Mst.h
Normal file
35
src/Mst.h
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
#ifndef MST_H
|
||||||
|
#define MST_H
|
||||||
|
#include <torch/torch.h>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
namespace bayesnet {
|
||||||
|
using namespace std;
|
||||||
|
using namespace torch;
|
||||||
|
class MST {
|
||||||
|
private:
|
||||||
|
Tensor weights;
|
||||||
|
vector<string> features;
|
||||||
|
int root;
|
||||||
|
public:
|
||||||
|
MST() = default;
|
||||||
|
MST(vector<string>& features, Tensor& weights, int root);
|
||||||
|
vector<pair<int, int>> maximumSpanningTree();
|
||||||
|
};
|
||||||
|
class Graph {
|
||||||
|
private:
|
||||||
|
int V; // number of nodes in graph
|
||||||
|
vector <pair<float, pair<int, int>>> G; // vector for graph
|
||||||
|
vector <pair<float, pair<int, int>>> T; // vector for mst
|
||||||
|
vector<int> parent;
|
||||||
|
public:
|
||||||
|
Graph(int V);
|
||||||
|
void addEdge(int u, int v, float wt);
|
||||||
|
int find_set(int i);
|
||||||
|
void union_set(int u, int v);
|
||||||
|
void kruskal_algorithm();
|
||||||
|
void display_mst();
|
||||||
|
vector <pair<float, pair<int, int>>> get_mst() { return T; }
|
||||||
|
};
|
||||||
|
}
|
||||||
|
#endif
|
@ -19,16 +19,20 @@ namespace bayesnet {
|
|||||||
auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset);
|
auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset);
|
||||||
mi.push_back({ i, mi_value });
|
mi.push_back({ i, mi_value });
|
||||||
}
|
}
|
||||||
sort(mi.begin(), mi.end());
|
sort(mi.begin(), mi.end(), [](auto& left, auto& right) {return left.second < right.second;});
|
||||||
auto root = mi[mi.size() - 1].first;
|
auto root = mi[mi.size() - 1].first;
|
||||||
// 2. Compute mutual information between each feature and the class
|
// 2. Compute mutual information between each feature and the class
|
||||||
auto weights = metrics.conditionalEdge();
|
auto weights = metrics.conditionalEdge();
|
||||||
// 3. Compute the maximum spanning tree
|
// 3. Compute the maximum spanning tree
|
||||||
auto mst = metrics.maximumSpanningTree(root, weights);
|
auto mst = metrics.maximumSpanningTree(features, weights, root);
|
||||||
// 4. Add edges from the maximum spanning tree to the model
|
// 4. Add edges from the maximum spanning tree to the model
|
||||||
for (auto i = 0; i < mst.size(); ++i) {
|
for (auto i = 0; i < mst.size(); ++i) {
|
||||||
auto [from, to] = mst[i];
|
auto [from, to] = mst[i];
|
||||||
model.addEdge(features[from], features[to]);
|
model.addEdge(features[from], features[to]);
|
||||||
}
|
}
|
||||||
|
// 5. Add edges from the class to all features
|
||||||
|
for (auto feature : features) {
|
||||||
|
model.addEdge(className, feature);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user