Add graphviz output to models

This commit is contained in:
Ricardo Montañana Gómez 2023-07-16 01:20:47 +02:00
parent 29aca0b35f
commit f530e69dae
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
16 changed files with 99 additions and 25 deletions

View File

@ -259,30 +259,40 @@ int main(int argc, char** argv)
}
states[className] = vector<int>(
maxes[className]);
// cout << "****************** KDB ******************" << endl;
// auto kdb = bayesnet::KDB(2);
// kdb.fit(Xd, y, features, className, states);
// for (auto line : kdb.show()) {
// cout << line << endl;
// }
// cout << "Score: " << kdb.score(Xd, y) << endl;
// cout << "****************** KDB ******************" << endl;
// cout << "****************** SPODE ******************" << endl;
// auto spode = bayesnet::SPODE(2);
// spode.fit(Xd, y, features, className, states);
// for (auto line : spode.show()) {
// cout << line << endl;
// }
// cout << "Score: " << spode.score(Xd, y) << endl;
// cout << "****************** SPODE ******************" << endl;
// cout << "****************** AODE ******************" << endl;
// auto aode = bayesnet::AODE();
// aode.fit(Xd, y, features, className, states);
// for (auto line : aode.show()) {
// cout << line << endl;
// }
// cout << "Score: " << aode.score(Xd, y) << endl;
// cout << "****************** AODE ******************" << endl;
cout << "****************** KDB ******************" << endl;
auto kdb = bayesnet::KDB(2);
kdb.fit(Xd, y, features, className, states);
for (auto line : kdb.show()) {
cout << line << endl;
}
cout << "Score: " << kdb.score(Xd, y) << endl;
ofstream file("kdb.dot");
file << kdb.graph();
file.close();
cout << "****************** KDB ******************" << endl;
cout << "****************** SPODE ******************" << endl;
auto spode = bayesnet::SPODE(2);
spode.fit(Xd, y, features, className, states);
for (auto line : spode.show()) {
cout << line << endl;
}
cout << "Score: " << spode.score(Xd, y) << endl;
file.open("spode.dot");
file << spode.graph();
file.close();
cout << "****************** SPODE ******************" << endl;
cout << "****************** AODE ******************" << endl;
auto aode = bayesnet::AODE();
aode.fit(Xd, y, features, className, states);
for (auto line : aode.show()) {
cout << line << endl;
}
cout << "Score: " << aode.score(Xd, y) << endl;
file.open("aode.dot");
for (auto line : aode.graph())
file << line;
file.close();
cout << "****************** AODE ******************" << endl;
cout << "****************** TAN ******************" << endl;
auto tan = bayesnet::TAN();
tan.fit(Xd, y, features, className, states);
@ -290,6 +300,9 @@ int main(int argc, char** argv)
cout << line << endl;
}
cout << "Score: " << tan.score(Xd, y) << endl;
file.open("tan.dot");
file << tan.graph();
file.close();
cout << "****************** TAN ******************" << endl;
return 0;
}

View File

@ -9,4 +9,8 @@ namespace bayesnet {
models.push_back(std::make_unique<SPODE>(i));
}
}
vector<string> AODE::graph(string title)
{
return Ensemble::graph(title);
}
}

View File

@ -8,6 +8,7 @@ namespace bayesnet {
void train() override;
public:
AODE();
vector<string> graph(string title = "AODE");
};
}
#endif

View File

@ -35,6 +35,7 @@ namespace bayesnet {
float score(Tensor& X, Tensor& y);
float score(vector<vector<int>>& X, vector<int>& y);
vector<string> show();
virtual vector<string> graph(string title) = 0;
};
}
#endif

View File

@ -93,11 +93,20 @@ namespace bayesnet {
}
vector<string> Ensemble::show()
{
vector<string> result;
auto result = vector<string>();
for (auto i = 0; i < n_models; ++i) {
auto res = models[i]->show();
result.insert(result.end(), res.begin(), res.end());
}
return result;
}
vector<string> Ensemble::graph(string title)
{
auto result = vector<string>();
for (auto i = 0; i < n_models; ++i) {
auto res = models[i]->graph(title + "_" + to_string(i));
result.insert(result.end(), res.begin(), res.end());
}
return result;
}
}

View File

@ -36,6 +36,7 @@ namespace bayesnet {
float score(Tensor& X, Tensor& y);
float score(vector<vector<int>>& X, vector<int>& y);
vector<string> show();
vector<string> graph(string title);
};
}
#endif

View File

@ -80,4 +80,11 @@ namespace bayesnet {
exit_cond = num == n_edges || candidates.size(0) == 0;
}
}
vector<string> KDB::graph(string title)
{
if (title == "KDB") {
title += " (k=" + to_string(k) + ", theta=" + to_string(theta) + ")";
}
return model.graph(title);
}
}

View File

@ -14,6 +14,7 @@ namespace bayesnet {
void train() override;
public:
KDB(int k, float theta = 0.03);
vector<string> graph(string name = "KDB") override;
};
}
#endif

View File

@ -258,5 +258,19 @@ namespace bayesnet {
}
return result;
}
vector<string> Network::graph(string title)
{
auto output = vector<string>();
auto prefix = "digraph BayesNet {\nlabel=<BayesNet ";
auto suffix = ">\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n";
string header = prefix + title + suffix;
output.push_back(header);
for (auto& node : nodes) {
auto result = node.second->graph(className);
output.insert(output.end(), result.begin(), result.end());
}
output.push_back("}\n");
return output;
}
}

View File

@ -45,6 +45,7 @@ namespace bayesnet {
vector<vector<double>> predict_proba(const vector<vector<int>>&);
double score(const vector<vector<int>>&, const vector<int>&);
vector<string> show();
vector<string> graph(string title); // Returns a vector of strings representing the graph in graphviz format
inline string version() { return "0.1.0"; }
};
}

View File

@ -109,4 +109,14 @@ namespace bayesnet {
}
return cpTable.index({ coordinates }).item<float>();
}
vector<string> Node::graph(string className)
{
auto output = vector<string>();
auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
output.push_back(name + " [shape=circle" + suffix + "] \n");
for (auto& child : children) {
output.push_back(name + " -> " + child->getName());
}
return output;
}
}

View File

@ -29,6 +29,7 @@ namespace bayesnet {
int getNumStates() const;
void setNumStates(int);
unsigned minFill();
vector<string> graph(string clasName); // Returns a vector of strings representing the graph in graphviz format
float getFactorValue(map<string, int>&);
};
}

View File

@ -17,4 +17,9 @@ namespace bayesnet {
}
}
}
vector<string> SPODE::graph(string name )
{
return model.graph(name);
}
}

View File

@ -9,6 +9,7 @@ namespace bayesnet {
void train() override;
public:
SPODE(int root);
vector<string> graph(string name = "SPODE") override;
};
}
#endif

View File

@ -35,4 +35,8 @@ namespace bayesnet {
model.addEdge(className, feature);
}
}
vector<string> TAN::graph(string title)
{
return model.graph(title);
}
}

View File

@ -10,6 +10,7 @@ namespace bayesnet {
void train() override;
public:
TAN();
vector<string> graph(string name = "TAN") override;
};
}
#endif