Refactor BaseClassifier and begin TAN impl.
This commit is contained in:
@@ -4,7 +4,7 @@ namespace bayesnet {
|
||||
using namespace std;
|
||||
using namespace torch;
|
||||
|
||||
BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0) {}
|
||||
BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0), metrics(Metrics()) {}
|
||||
BaseClassifier& BaseClassifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
|
||||
{
|
||||
|
||||
@@ -13,6 +13,8 @@ namespace bayesnet {
|
||||
this->className = className;
|
||||
this->states = states;
|
||||
checkFitParameters();
|
||||
auto n_classes = states[className].size();
|
||||
metrics = Metrics(dataset, features, className, n_classes);
|
||||
train();
|
||||
return *this;
|
||||
}
|
||||
@@ -51,6 +53,14 @@ namespace bayesnet {
|
||||
}
|
||||
}
|
||||
}
|
||||
vector<int> BaseClassifier::argsort(vector<float>& nums)
|
||||
{
|
||||
int n = nums.size();
|
||||
vector<int> indices(n);
|
||||
iota(indices.begin(), indices.end(), 0);
|
||||
sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
|
||||
return indices;
|
||||
}
|
||||
vector<vector<int>> tensorToVector(const torch::Tensor& tensor)
|
||||
{
|
||||
// convert mxn tensor to nxm vector
|
||||
@@ -86,8 +96,16 @@ namespace bayesnet {
|
||||
Tensor y_pred = predict(X);
|
||||
return (y_pred == y).sum().item<float>() / y.size(0);
|
||||
}
|
||||
void BaseClassifier::show()
|
||||
vector<string> BaseClassifier::show()
|
||||
{
|
||||
model.show();
|
||||
return model.show();
|
||||
}
|
||||
void BaseClassifier::addNodes()
|
||||
{
|
||||
// Add all nodes to the network
|
||||
for (auto feature : features) {
|
||||
model.addNode(feature, states[feature].size());
|
||||
}
|
||||
model.addNode(className, states[className].size());
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user