Add SPODE

This commit is contained in:
Ricardo Montañana Gómez 2023-07-14 12:59:47 +02:00
parent 002aa30672
commit e8b8fa29c8
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
7 changed files with 59 additions and 10 deletions

View File

@ -90,7 +90,13 @@
"format": "cpp", "format": "cpp",
"valarray": "cpp", "valarray": "cpp",
"regex": "cpp", "regex": "cpp",
"span": "cpp" "span": "cpp",
"cfenv": "cpp",
"cinttypes": "cpp",
"csetjmp": "cpp",
"future": "cpp",
"queue": "cpp",
"typeindex": "cpp"
}, },
"cmake.configureOnOpen": false, "cmake.configureOnOpen": false,
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools" "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"

View File

@ -1,2 +1,2 @@
add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc) add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc)
target_link_libraries(BayesNet "${TORCH_LIBRARIES}") target_link_libraries(BayesNet "${TORCH_LIBRARIES}")

View File

@ -121,7 +121,7 @@ namespace bayesnet {
and the indices of the weights as nodes of this square matrix using and the indices of the weights as nodes of this square matrix using
Kruskal algorithm Kruskal algorithm
*/ */
vector<pair<int, int>> Metrics::maximumSpanningTree(Tensor& weights) vector<pair<int, int>> Metrics::maximumSpanningTree(int root, Tensor& weights)
{ {
auto result = vector<pair<int, int>>(); auto result = vector<pair<int, int>>();

View File

@ -22,7 +22,7 @@ namespace bayesnet {
vector<float> conditionalEdgeWeights(); vector<float> conditionalEdgeWeights();
Tensor conditionalEdge(); Tensor conditionalEdge();
vector<pair<string, string>> doCombinations(const vector<string>&); vector<pair<string, string>> doCombinations(const vector<string>&);
vector<pair<int, int>> maximumSpanningTree(Tensor& weights); vector<pair<int, int>> maximumSpanningTree(int root, Tensor& weights);
}; };
} }
#endif #endif

20
src/SPODE.cc Normal file
View File

@ -0,0 +1,20 @@
#include "SPODE.h"
namespace bayesnet {
SPODE::SPODE(int root) : BaseClassifier(Network()), root(root) {}
void SPODE::train()
{
// 0. Add all nodes to the model
addNodes();
// 1. Add edges from the class node to all other nodes
// 2. Add edges from the root node to all other nodes
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
model.addEdge(className, features[i]);
if (i != root) {
model.addEdge(features[root], features[i]);
}
}
}
}

14
src/SPODE.h Normal file
View File

@ -0,0 +1,14 @@
#ifndef SPODE_H
#define SPODE_H
#include "BaseClassifier.h"
namespace bayesnet {
class SPODE : public BaseClassifier {
private:
int root;
protected:
void train() override;
public:
SPODE(int root);
};
}
#endif

View File

@ -10,16 +10,25 @@ namespace bayesnet {
{ {
// 0. Add all nodes to the model // 0. Add all nodes to the model
addNodes(); addNodes();
// 1. Compute mutual information between each feature and the class // 1. Compute mutual information between each feature and the class and set the root node
// as the highest mutual information with the class
auto mi = vector <pair<int, float >>();
Tensor class_dataset = dataset.index({ "...", -1 });
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
Tensor feature_dataset = dataset.index({ "...", i });
auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset);
mi.push_back({ i, mi_value });
}
sort(mi.begin(), mi.end());
auto root = mi[mi.size() - 1].first;
// 2. Compute mutual information between each feature and the class
auto weights = metrics.conditionalEdge(); auto weights = metrics.conditionalEdge();
// 2. Compute the maximum spanning tree // 3. Compute the maximum spanning tree
auto mst = metrics.maximumSpanningTree(weights); auto mst = metrics.maximumSpanningTree(root, weights);
// 3. Add edges from the maximum spanning tree to the model // 4. Add edges from the maximum spanning tree to the model
for (auto i = 0; i < mst.size(); ++i) { for (auto i = 0; i < mst.size(); ++i) {
auto [from, to] = mst[i]; auto [from, to] = mst[i];
model.addEdge(features[from], features[to]); model.addEdge(features[from], features[to]);
} }
} }
} }