diff --git a/sample/main.cc b/sample/main.cc index de0e928..826b272 100644 --- a/sample/main.cc +++ b/sample/main.cc @@ -259,30 +259,40 @@ int main(int argc, char** argv) } states[className] = vector( maxes[className]); - // cout << "****************** KDB ******************" << endl; - // auto kdb = bayesnet::KDB(2); - // kdb.fit(Xd, y, features, className, states); - // for (auto line : kdb.show()) { - // cout << line << endl; - // } - // cout << "Score: " << kdb.score(Xd, y) << endl; - // cout << "****************** KDB ******************" << endl; - // cout << "****************** SPODE ******************" << endl; - // auto spode = bayesnet::SPODE(2); - // spode.fit(Xd, y, features, className, states); - // for (auto line : spode.show()) { - // cout << line << endl; - // } - // cout << "Score: " << spode.score(Xd, y) << endl; - // cout << "****************** SPODE ******************" << endl; - // cout << "****************** AODE ******************" << endl; - // auto aode = bayesnet::AODE(); - // aode.fit(Xd, y, features, className, states); - // for (auto line : aode.show()) { - // cout << line << endl; - // } - // cout << "Score: " << aode.score(Xd, y) << endl; - // cout << "****************** AODE ******************" << endl; + cout << "****************** KDB ******************" << endl; + auto kdb = bayesnet::KDB(2); + kdb.fit(Xd, y, features, className, states); + for (auto line : kdb.show()) { + cout << line << endl; + } + cout << "Score: " << kdb.score(Xd, y) << endl; + ofstream file("kdb.dot"); + file << kdb.graph(); + file.close(); + cout << "****************** KDB ******************" << endl; + cout << "****************** SPODE ******************" << endl; + auto spode = bayesnet::SPODE(2); + spode.fit(Xd, y, features, className, states); + for (auto line : spode.show()) { + cout << line << endl; + } + cout << "Score: " << spode.score(Xd, y) << endl; + file.open("spode.dot"); + file << spode.graph(); + file.close(); + cout << "****************** SPODE ******************" << endl; + cout << "****************** AODE ******************" << endl; + auto aode = bayesnet::AODE(); + aode.fit(Xd, y, features, className, states); + for (auto line : aode.show()) { + cout << line << endl; + } + cout << "Score: " << aode.score(Xd, y) << endl; + file.open("aode.dot"); + for (auto line : aode.graph()) + file << line; + file.close(); + cout << "****************** AODE ******************" << endl; cout << "****************** TAN ******************" << endl; auto tan = bayesnet::TAN(); tan.fit(Xd, y, features, className, states); @@ -290,6 +300,9 @@ int main(int argc, char** argv) cout << line << endl; } cout << "Score: " << tan.score(Xd, y) << endl; + file.open("tan.dot"); + file << tan.graph(); + file.close(); cout << "****************** TAN ******************" << endl; return 0; } \ No newline at end of file diff --git a/src/AODE.cc b/src/AODE.cc index ef17c77..f81b191 100644 --- a/src/AODE.cc +++ b/src/AODE.cc @@ -9,4 +9,8 @@ namespace bayesnet { models.push_back(std::make_unique(i)); } } + vector AODE::graph(string title) + { + return Ensemble::graph(title); + } } \ No newline at end of file diff --git a/src/AODE.h b/src/AODE.h index 061230a..a53ec8a 100644 --- a/src/AODE.h +++ b/src/AODE.h @@ -8,6 +8,7 @@ namespace bayesnet { void train() override; public: AODE(); + vector graph(string title = "AODE"); }; } #endif \ No newline at end of file diff --git a/src/BaseClassifier.h b/src/BaseClassifier.h index 301f847..730f3cd 100644 --- a/src/BaseClassifier.h +++ b/src/BaseClassifier.h @@ -35,6 +35,7 @@ namespace bayesnet { float score(Tensor& X, Tensor& y); float score(vector>& X, vector& y); vector show(); + virtual vector graph(string title) = 0; }; } #endif diff --git a/src/Ensemble.cc b/src/Ensemble.cc index c37d2a0..8a971c3 100644 --- a/src/Ensemble.cc +++ b/src/Ensemble.cc @@ -93,11 +93,20 @@ namespace bayesnet { } vector Ensemble::show() { - vector result; + auto result = vector(); for (auto i = 0; i < n_models; ++i) { auto res = models[i]->show(); result.insert(result.end(), res.begin(), res.end()); } return result; } + vector Ensemble::graph(string title) + { + auto result = vector(); + for (auto i = 0; i < n_models; ++i) { + auto res = models[i]->graph(title + "_" + to_string(i)); + result.insert(result.end(), res.begin(), res.end()); + } + return result; + } } \ No newline at end of file diff --git a/src/Ensemble.h b/src/Ensemble.h index 118f5e0..8db299d 100644 --- a/src/Ensemble.h +++ b/src/Ensemble.h @@ -36,6 +36,7 @@ namespace bayesnet { float score(Tensor& X, Tensor& y); float score(vector>& X, vector& y); vector show(); + vector graph(string title); }; } #endif diff --git a/src/KDB.cc b/src/KDB.cc index d28d805..f1023f6 100644 --- a/src/KDB.cc +++ b/src/KDB.cc @@ -80,4 +80,11 @@ namespace bayesnet { exit_cond = num == n_edges || candidates.size(0) == 0; } } + vector KDB::graph(string title) + { + if (title == "KDB") { + title += " (k=" + to_string(k) + ", theta=" + to_string(theta) + ")"; + } + return model.graph(title); + } } \ No newline at end of file diff --git a/src/KDB.h b/src/KDB.h index a0a2825..f58a7a5 100644 --- a/src/KDB.h +++ b/src/KDB.h @@ -14,6 +14,7 @@ namespace bayesnet { void train() override; public: KDB(int k, float theta = 0.03); + vector graph(string name = "KDB") override; }; } #endif \ No newline at end of file diff --git a/src/Network.cc b/src/Network.cc index d2ed9f3..e3246d4 100644 --- a/src/Network.cc +++ b/src/Network.cc @@ -258,5 +258,19 @@ namespace bayesnet { } return result; } + vector Network::graph(string title) + { + auto output = vector(); + auto prefix = "digraph BayesNet {\nlabel=graph(className); + output.insert(output.end(), result.begin(), result.end()); + } + output.push_back("}\n"); + return output; + } } diff --git a/src/Network.h b/src/Network.h index 9d3b3f4..e3b3b19 100644 --- a/src/Network.h +++ b/src/Network.h @@ -45,6 +45,7 @@ namespace bayesnet { vector> predict_proba(const vector>&); double score(const vector>&, const vector&); vector show(); + vector graph(string title); // Returns a vector of strings representing the graph in graphviz format inline string version() { return "0.1.0"; } }; } diff --git a/src/Node.cc b/src/Node.cc index a24bec7..6f1ba75 100644 --- a/src/Node.cc +++ b/src/Node.cc @@ -109,4 +109,14 @@ namespace bayesnet { } return cpTable.index({ coordinates }).item(); } + vector Node::graph(string className) + { + auto output = vector(); + auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : ""; + output.push_back(name + " [shape=circle" + suffix + "] \n"); + for (auto& child : children) { + output.push_back(name + " -> " + child->getName()); + } + return output; + } } \ No newline at end of file diff --git a/src/Node.h b/src/Node.h index c7961aa..45c9c02 100644 --- a/src/Node.h +++ b/src/Node.h @@ -29,6 +29,7 @@ namespace bayesnet { int getNumStates() const; void setNumStates(int); unsigned minFill(); + vector graph(string clasName); // Returns a vector of strings representing the graph in graphviz format float getFactorValue(map&); }; } diff --git a/src/SPODE.cc b/src/SPODE.cc index 4b0431b..dc661e7 100644 --- a/src/SPODE.cc +++ b/src/SPODE.cc @@ -17,4 +17,9 @@ namespace bayesnet { } } } + vector SPODE::graph(string name ) + { + return model.graph(name); + } + } \ No newline at end of file diff --git a/src/SPODE.h b/src/SPODE.h index f796d19..dae600b 100644 --- a/src/SPODE.h +++ b/src/SPODE.h @@ -9,6 +9,7 @@ namespace bayesnet { void train() override; public: SPODE(int root); + vector graph(string name = "SPODE") override; }; } #endif \ No newline at end of file diff --git a/src/TAN.cc b/src/TAN.cc index bb0d561..fb0e533 100644 --- a/src/TAN.cc +++ b/src/TAN.cc @@ -35,4 +35,8 @@ namespace bayesnet { model.addEdge(className, feature); } } + vector TAN::graph(string title) + { + return model.graph(title); + } } \ No newline at end of file diff --git a/src/TAN.h b/src/TAN.h index d1477b6..f438d91 100644 --- a/src/TAN.h +++ b/src/TAN.h @@ -10,6 +10,7 @@ namespace bayesnet { void train() override; public: TAN(); + vector graph(string name = "TAN") override; }; } #endif \ No newline at end of file