Add getNumberOfNodes & getNumberOfEdges to Models
Add some more tests
This commit is contained in:
@@ -8,7 +8,6 @@ namespace bayesnet {
|
||||
BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
|
||||
BaseClassifier& BaseClassifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
|
||||
{
|
||||
|
||||
dataset = torch::cat({ X, y.view({y.size(0), 1}) }, 1);
|
||||
this->features = features;
|
||||
this->className = className;
|
||||
@@ -116,4 +115,13 @@ namespace bayesnet {
|
||||
}
|
||||
model.addNode(className, states[className].size());
|
||||
}
|
||||
int BaseClassifier::getNumberOfNodes()
|
||||
{
|
||||
// Features does not include class
|
||||
return fitted ? model.getFeatures().size() + 1 : 0;
|
||||
}
|
||||
int BaseClassifier::getNumberOfEdges()
|
||||
{
|
||||
return fitted ? model.getEdges().size() : 0;
|
||||
}
|
||||
}
|
@@ -30,6 +30,8 @@ namespace bayesnet {
|
||||
virtual ~BaseClassifier() = default;
|
||||
BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
|
||||
void addNodes();
|
||||
int getNumberOfNodes();
|
||||
int getNumberOfEdges();
|
||||
Tensor predict(Tensor& X);
|
||||
vector<int> predict(vector<vector<int>>& X);
|
||||
float score(Tensor& X, Tensor& y);
|
||||
|
@@ -275,5 +275,17 @@ namespace bayesnet {
|
||||
output.push_back("}\n");
|
||||
return output;
|
||||
}
|
||||
vector<pair<string, string>> Network::getEdges()
|
||||
{
|
||||
auto edges = vector<pair<string, string>>();
|
||||
for (const auto& node : nodes) {
|
||||
auto head = node.first;
|
||||
for (const auto& child : node.second->getChildren()) {
|
||||
auto tail = child->getName();
|
||||
edges.push_back({ head, tail });
|
||||
}
|
||||
}
|
||||
return edges;
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -36,6 +36,7 @@ namespace bayesnet {
|
||||
map<string, std::unique_ptr<Node>>& getNodes();
|
||||
vector<string> getFeatures();
|
||||
int getStates();
|
||||
vector<pair<string, string>> getEdges();
|
||||
int getClassNumStates();
|
||||
string getClassName();
|
||||
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
|
||||
|
Reference in New Issue
Block a user