Add getNumberOfNodes & getNumberOfEdges to Models

Add some more tests
This commit is contained in:
2023-07-19 15:05:44 +02:00
parent 1a21015492
commit 2f5bd0ea7e
8 changed files with 66 additions and 6 deletions

View File

@@ -8,7 +8,6 @@ namespace bayesnet {
BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
BaseClassifier& BaseClassifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
{
dataset = torch::cat({ X, y.view({y.size(0), 1}) }, 1);
this->features = features;
this->className = className;
@@ -116,4 +115,13 @@ namespace bayesnet {
}
model.addNode(className, states[className].size());
}
int BaseClassifier::getNumberOfNodes()
{
// Features does not include class
return fitted ? model.getFeatures().size() + 1 : 0;
}
int BaseClassifier::getNumberOfEdges()
{
return fitted ? model.getEdges().size() : 0;
}
}

View File

@@ -30,6 +30,8 @@ namespace bayesnet {
virtual ~BaseClassifier() = default;
BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
void addNodes();
int getNumberOfNodes();
int getNumberOfEdges();
Tensor predict(Tensor& X);
vector<int> predict(vector<vector<int>>& X);
float score(Tensor& X, Tensor& y);

View File

@@ -275,5 +275,17 @@ namespace bayesnet {
output.push_back("}\n");
return output;
}
vector<pair<string, string>> Network::getEdges()
{
auto edges = vector<pair<string, string>>();
for (const auto& node : nodes) {
auto head = node.first;
for (const auto& child : node.second->getChildren()) {
auto tail = child->getName();
edges.push_back({ head, tail });
}
}
return edges;
}
}

View File

@@ -36,6 +36,7 @@ namespace bayesnet {
map<string, std::unique_ptr<Node>>& getNodes();
vector<string> getFeatures();
int getStates();
vector<pair<string, string>> getEdges();
int getClassNumStates();
string getClassName();
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);