Add SPODE
This commit is contained in:
parent
002aa30672
commit
e8b8fa29c8
8
.vscode/settings.json
vendored
8
.vscode/settings.json
vendored
@ -90,7 +90,13 @@
|
|||||||
"format": "cpp",
|
"format": "cpp",
|
||||||
"valarray": "cpp",
|
"valarray": "cpp",
|
||||||
"regex": "cpp",
|
"regex": "cpp",
|
||||||
"span": "cpp"
|
"span": "cpp",
|
||||||
|
"cfenv": "cpp",
|
||||||
|
"cinttypes": "cpp",
|
||||||
|
"csetjmp": "cpp",
|
||||||
|
"future": "cpp",
|
||||||
|
"queue": "cpp",
|
||||||
|
"typeindex": "cpp"
|
||||||
},
|
},
|
||||||
"cmake.configureOnOpen": false,
|
"cmake.configureOnOpen": false,
|
||||||
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"
|
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"
|
||||||
|
@ -1,2 +1,2 @@
|
|||||||
add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc)
|
add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc)
|
||||||
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
|
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
|
@ -121,7 +121,7 @@ namespace bayesnet {
|
|||||||
and the indices of the weights as nodes of this square matrix using
|
and the indices of the weights as nodes of this square matrix using
|
||||||
Kruskal algorithm
|
Kruskal algorithm
|
||||||
*/
|
*/
|
||||||
vector<pair<int, int>> Metrics::maximumSpanningTree(Tensor& weights)
|
vector<pair<int, int>> Metrics::maximumSpanningTree(int root, Tensor& weights)
|
||||||
{
|
{
|
||||||
auto result = vector<pair<int, int>>();
|
auto result = vector<pair<int, int>>();
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ namespace bayesnet {
|
|||||||
vector<float> conditionalEdgeWeights();
|
vector<float> conditionalEdgeWeights();
|
||||||
Tensor conditionalEdge();
|
Tensor conditionalEdge();
|
||||||
vector<pair<string, string>> doCombinations(const vector<string>&);
|
vector<pair<string, string>> doCombinations(const vector<string>&);
|
||||||
vector<pair<int, int>> maximumSpanningTree(Tensor& weights);
|
vector<pair<int, int>> maximumSpanningTree(int root, Tensor& weights);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
20
src/SPODE.cc
Normal file
20
src/SPODE.cc
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
#include "SPODE.h"
|
||||||
|
|
||||||
|
namespace bayesnet {
|
||||||
|
|
||||||
|
SPODE::SPODE(int root) : BaseClassifier(Network()), root(root) {}
|
||||||
|
|
||||||
|
void SPODE::train()
|
||||||
|
{
|
||||||
|
// 0. Add all nodes to the model
|
||||||
|
addNodes();
|
||||||
|
// 1. Add edges from the class node to all other nodes
|
||||||
|
// 2. Add edges from the root node to all other nodes
|
||||||
|
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
|
||||||
|
model.addEdge(className, features[i]);
|
||||||
|
if (i != root) {
|
||||||
|
model.addEdge(features[root], features[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
14
src/SPODE.h
Normal file
14
src/SPODE.h
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
#ifndef SPODE_H
|
||||||
|
#define SPODE_H
|
||||||
|
#include "BaseClassifier.h"
|
||||||
|
namespace bayesnet {
|
||||||
|
class SPODE : public BaseClassifier {
|
||||||
|
private:
|
||||||
|
int root;
|
||||||
|
protected:
|
||||||
|
void train() override;
|
||||||
|
public:
|
||||||
|
SPODE(int root);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
#endif
|
21
src/TAN.cc
21
src/TAN.cc
@ -10,16 +10,25 @@ namespace bayesnet {
|
|||||||
{
|
{
|
||||||
// 0. Add all nodes to the model
|
// 0. Add all nodes to the model
|
||||||
addNodes();
|
addNodes();
|
||||||
// 1. Compute mutual information between each feature and the class
|
// 1. Compute mutual information between each feature and the class and set the root node
|
||||||
|
// as the highest mutual information with the class
|
||||||
|
auto mi = vector <pair<int, float >>();
|
||||||
|
Tensor class_dataset = dataset.index({ "...", -1 });
|
||||||
|
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
|
||||||
|
Tensor feature_dataset = dataset.index({ "...", i });
|
||||||
|
auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset);
|
||||||
|
mi.push_back({ i, mi_value });
|
||||||
|
}
|
||||||
|
sort(mi.begin(), mi.end());
|
||||||
|
auto root = mi[mi.size() - 1].first;
|
||||||
|
// 2. Compute mutual information between each feature and the class
|
||||||
auto weights = metrics.conditionalEdge();
|
auto weights = metrics.conditionalEdge();
|
||||||
// 2. Compute the maximum spanning tree
|
// 3. Compute the maximum spanning tree
|
||||||
auto mst = metrics.maximumSpanningTree(weights);
|
auto mst = metrics.maximumSpanningTree(root, weights);
|
||||||
// 3. Add edges from the maximum spanning tree to the model
|
// 4. Add edges from the maximum spanning tree to the model
|
||||||
for (auto i = 0; i < mst.size(); ++i) {
|
for (auto i = 0; i < mst.size(); ++i) {
|
||||||
auto [from, to] = mst[i];
|
auto [from, to] = mst[i];
|
||||||
model.addEdge(features[from], features[to]);
|
model.addEdge(features[from], features[to]);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user