Refactor BaseClassifier and begin TAN impl.

This commit is contained in:
2023-07-14 00:10:55 +02:00
parent e52fdc718f
commit 3f09d474f9
9 changed files with 87 additions and 49 deletions

25
src/TAN.cc Normal file
View File

@@ -0,0 +1,25 @@
#include "TAN.h"
namespace bayesnet {
using namespace std;
using namespace torch;
TAN::TAN() : BaseClassifier(Network()) {}
void TAN::train()
{
// 0. Add all nodes to the model
addNodes();
// 1. Compute mutual information between each feature and the class
auto weights = metrics.conditionalEdge();
// 2. Compute the maximum spanning tree
auto mst = metrics.maximumSpanningTree(weights);
// 3. Add edges from the maximum spanning tree to the model
for (auto i = 0; i < mst.size(); ++i) {
auto [from, to] = mst[i];
model.addEdge(features[from], features[to]);
}
}
}