Complete TAN with Maximum Spanning Tree

This commit is contained in:
Ricardo Montañana Gómez 2023-07-15 18:31:50 +02:00
parent e311c27d43
commit e3863387bb
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
7 changed files with 172 additions and 10 deletions

View File

@ -10,6 +10,7 @@
#include "KDB.h"
#include "SPODE.h"
#include "AODE.h"
#include "TAN.h"
using namespace std;
@ -282,5 +283,13 @@ int main(int argc, char** argv)
}
cout << "Score: " << aode.score(Xd, y) << endl;
cout << "****************** AODE ******************" << endl;
cout << "****************** TAN ******************" << endl;
auto tan = bayesnet::TAN();
tan.fit(Xd, y, features, className, states);
for (auto line : tan.show()) {
cout << line << endl;
}
cout << "Score: " << tan.score(Xd, y) << endl;
cout << "****************** TAN ******************" << endl;
return 0;
}

View File

@ -1,2 +1,2 @@
add_library(BayesNet utils.cc Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc)
add_library(BayesNet utils.cc Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc Mst.cc)
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")

View File

@ -1,4 +1,5 @@
#include "Metrics.hpp"
#include "Mst.h"
using namespace std;
namespace bayesnet {
Metrics::Metrics(torch::Tensor& samples, vector<string>& features, string& className, int classNumStates)
@ -121,14 +122,11 @@ namespace bayesnet {
and the indices of the weights as nodes of this square matrix using
Kruskal algorithm
*/
vector<pair<int, int>> Metrics::maximumSpanningTree(int root, Tensor& weights)
vector<pair<int, int>> Metrics::maximumSpanningTree(vector<string> features, Tensor& weights, int root)
{
auto result = vector<pair<int, int>>();
auto mst = MST(features, weights, root);
return mst.maximumSpanningTree();
return result;
}
}

View File

@ -22,7 +22,7 @@ namespace bayesnet {
vector<float> conditionalEdgeWeights();
Tensor conditionalEdge();
vector<pair<string, string>> doCombinations(const vector<string>&);
vector<pair<int, int>> maximumSpanningTree(int root, Tensor& weights);
vector<pair<int, int>> maximumSpanningTree(vector<string> features, Tensor& weights, int root);
};
}
#endif

116
src/Mst.cc Normal file
View File

@ -0,0 +1,116 @@
#include "Mst.h"
#include <vector>
/*
Based on the code from https://www.softwaretestinghelp.com/minimum-spanning-tree-tutorial/
*/
namespace bayesnet {
using namespace std;
Graph::Graph(int V)
{
parent = vector<int>(V);
for (int i = 0; i < V; i++)
parent[i] = i;
G.clear();
T.clear();
}
void Graph::addEdge(int u, int v, float wt)
{
G.push_back({ wt, { u, v } });
}
int Graph::find_set(int i)
{
// If i is the parent of itself
if (i == parent[i])
return i;
else
//else recursively find the parent of i
return find_set(parent[i]);
}
void Graph::union_set(int u, int v)
{
parent[u] = parent[v];
}
void Graph::kruskal_algorithm()
{
int i, uSt, vEd;
// sort the edges ordered on decreasing weight
sort(G.begin(), G.end(), [](auto& left, auto& right) {return left.first > right.first;});
for (i = 0; i < G.size(); i++) {
uSt = find_set(G[i].second.first);
vEd = find_set(G[i].second.second);
if (uSt != vEd) {
T.push_back(G[i]); // add to mst vector
union_set(uSt, vEd);
}
}
}
void Graph::display_mst()
{
cout << "Edge :" << " Weight" << endl;
for (int i = 0; i < T.size(); i++) {
cout << T[i].second.first << " - " << T[i].second.second << " : "
<< T[i].first;
cout << endl;
}
}
vector<pair<int, int>> reorder(vector<pair<float, pair<int, int>>> T, int root_original)
{
auto result = vector<pair<int, int>>();
auto visited = vector<int>();
auto nextVariables = unordered_set<int>();
nextVariables.emplace(root_original);
while (nextVariables.size() > 0) {
int root = *nextVariables.begin();
nextVariables.erase(nextVariables.begin());
for (int i = 0; i < T.size(); ++i) {
auto [weight, edge] = T[i];
auto [from, to] = edge;
if (from == root || to == root) {
visited.insert(visited.begin(), i);
if (from == root) {
result.push_back({ from, to });
nextVariables.emplace(to);
} else {
result.push_back({ to, from });
nextVariables.emplace(from);
}
}
}
// Remove visited
for (int i = 0; i < visited.size(); ++i) {
T.erase(T.begin() + visited[i]);
}
visited.clear();
}
if (T.size() > 0) {
for (int i = 0; i < T.size(); ++i) {
auto [weight, edge] = T[i];
auto [from, to] = edge;
result.push_back({ from, to });
}
}
return result;
}
MST::MST(vector<string>& features, Tensor& weights, int root) : features(features), weights(weights), root(root) {}
vector<pair<int, int>> MST::maximumSpanningTree()
{
auto num_features = features.size();
Graph g(num_features);
// Make a complete graph
for (int i = 0; i < num_features - 1; ++i) {
for (int j = i; j < num_features; ++j) {
g.addEdge(i, j, weights[i][j].item<float>());
}
}
g.kruskal_algorithm();
//g.display_mst();
auto mst = g.get_mst();
return reorder(mst, root);
}
}

35
src/Mst.h Normal file
View File

@ -0,0 +1,35 @@
#ifndef MST_H
#define MST_H
#include <torch/torch.h>
#include <vector>
#include <string>
namespace bayesnet {
using namespace std;
using namespace torch;
class MST {
private:
Tensor weights;
vector<string> features;
int root;
public:
MST() = default;
MST(vector<string>& features, Tensor& weights, int root);
vector<pair<int, int>> maximumSpanningTree();
};
class Graph {
private:
int V; // number of nodes in graph
vector <pair<float, pair<int, int>>> G; // vector for graph
vector <pair<float, pair<int, int>>> T; // vector for mst
vector<int> parent;
public:
Graph(int V);
void addEdge(int u, int v, float wt);
int find_set(int i);
void union_set(int u, int v);
void kruskal_algorithm();
void display_mst();
vector <pair<float, pair<int, int>>> get_mst() { return T; }
};
}
#endif

View File

@ -19,16 +19,20 @@ namespace bayesnet {
auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset);
mi.push_back({ i, mi_value });
}
sort(mi.begin(), mi.end());
sort(mi.begin(), mi.end(), [](auto& left, auto& right) {return left.second < right.second;});
auto root = mi[mi.size() - 1].first;
// 2. Compute mutual information between each feature and the class
auto weights = metrics.conditionalEdge();
// 3. Compute the maximum spanning tree
auto mst = metrics.maximumSpanningTree(root, weights);
auto mst = metrics.maximumSpanningTree(features, weights, root);
// 4. Add edges from the maximum spanning tree to the model
for (auto i = 0; i < mst.size(); ++i) {
auto [from, to] = mst[i];
model.addEdge(features[from], features[to]);
}
// 5. Add edges from the class to all features
for (auto feature : features) {
model.addEdge(className, feature);
}
}
}