Refactor BaseClassifier and begin TAN impl.

This commit is contained in:
2023-07-14 00:10:55 +02:00
parent e52fdc718f
commit 3f09d474f9
9 changed files with 87 additions and 49 deletions

View File

@@ -4,7 +4,7 @@ namespace bayesnet {
using namespace std;
using namespace torch;
BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0) {}
BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0), metrics(Metrics()) {}
BaseClassifier& BaseClassifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
{
@@ -13,6 +13,8 @@ namespace bayesnet {
this->className = className;
this->states = states;
checkFitParameters();
auto n_classes = states[className].size();
metrics = Metrics(dataset, features, className, n_classes);
train();
return *this;
}
@@ -51,6 +53,14 @@ namespace bayesnet {
}
}
}
vector<int> BaseClassifier::argsort(vector<float>& nums)
{
int n = nums.size();
vector<int> indices(n);
iota(indices.begin(), indices.end(), 0);
sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
return indices;
}
vector<vector<int>> tensorToVector(const torch::Tensor& tensor)
{
// convert mxn tensor to nxm vector
@@ -86,8 +96,16 @@ namespace bayesnet {
Tensor y_pred = predict(X);
return (y_pred == y).sum().item<float>() / y.size(0);
}
void BaseClassifier::show()
vector<string> BaseClassifier::show()
{
model.show();
return model.show();
}
void BaseClassifier::addNodes()
{
// Add all nodes to the network
for (auto feature : features) {
model.addNode(feature, states[feature].size());
}
model.addNode(className, states[className].size());
}
}