diff --git a/sample/sample.cc b/sample/sample.cc index bd74edc..869b910 100644 --- a/sample/sample.cc +++ b/sample/sample.cc @@ -158,6 +158,7 @@ int main(int argc, char** argv) states[feature] = vector(maxes[feature]); } states[className] = vector(maxes[className]); + auto clf = platform::Models::instance()->create(model_name); clf->fit(Xd, y, features, className, states); auto score = clf->score(Xd, y); @@ -166,6 +167,11 @@ int main(int argc, char** argv) for (auto line : lines) { cout << line << endl; } + cout << "--- Topological Order ---" << endl; + for (auto name : clf->topological_order()) { + cout << name << ", "; + } + cout << "end." << endl; cout << "Score: " << score << endl; auto dot_file = model_name + "_" + file_name; ofstream file(dot_file + ".dot"); diff --git a/src/BayesNet/BaseClassifier.h b/src/BayesNet/BaseClassifier.h index 6ade380..2aae0a4 100644 --- a/src/BayesNet/BaseClassifier.h +++ b/src/BayesNet/BaseClassifier.h @@ -19,6 +19,7 @@ namespace bayesnet { vector virtual graph(const string& title = "") = 0; virtual ~BaseClassifier() = default; const string inline getVersion() const { return "0.1.0"; }; + vector virtual topological_order() = 0; }; } #endif \ No newline at end of file diff --git a/src/BayesNet/Classifier.cc b/src/BayesNet/Classifier.cc index 77624fc..3099979 100644 --- a/src/BayesNet/Classifier.cc +++ b/src/BayesNet/Classifier.cc @@ -134,4 +134,8 @@ namespace bayesnet { { return fitted ? model.getStates() : 0; } + vector Classifier::topological_order() + { + return model.topological_sort(); + } } \ No newline at end of file diff --git a/src/BayesNet/Classifier.h b/src/BayesNet/Classifier.h index f5ee534..ddfa251 100644 --- a/src/BayesNet/Classifier.h +++ b/src/BayesNet/Classifier.h @@ -40,6 +40,7 @@ namespace bayesnet { float score(Tensor& X, Tensor& y) override; float score(vector>& X, vector& y) override; vector show() override; + vector topological_order() override; }; } #endif diff --git a/src/BayesNet/Ensemble.h b/src/BayesNet/Ensemble.h index 4f5c7c6..45f4d5a 100644 --- a/src/BayesNet/Ensemble.h +++ b/src/BayesNet/Ensemble.h @@ -41,6 +41,10 @@ namespace bayesnet { int getNumberOfStates() override; vector show() override; vector graph(const string& title) override; + vector topological_order() override + { + return vector(); + } }; } #endif diff --git a/src/BayesNet/Network.cc b/src/BayesNet/Network.cc index 5f9aa93..1c212be 100644 --- a/src/BayesNet/Network.cc +++ b/src/BayesNet/Network.cc @@ -341,4 +341,45 @@ namespace bayesnet { } return edges; } + vector Network::topological_sort() + { + /* Check if al the fathers of every node are before the node */ + auto result = features; + bool ending{ false }; + int idx = 0; + while (!ending) { + ending = true; + for (auto feature : features) { + if (feature == className) { + continue; + } + auto fathers = nodes[feature]->getParents(); + for (const auto& father : fathers) { + auto fatherName = father->getName(); + if (fatherName == className) { + continue; + } + auto it = find(result.begin(), result.end(), fatherName); + if (it != result.end()) { + auto it2 = find(result.begin(), result.end(), feature); + if (it2 != result.end()) { + if (distance(it, it2) < 0) { + result.erase(remove(result.begin(), result.end(), fatherName), result.end()); + result.insert(it2, fatherName); + ending = false; + } + } else { + throw logic_error("Error in topological sort because of node " + feature + " is not in result"); + } + } else { + throw logic_error("Error in topological sort because of node father " + fatherName + " is not in result"); + } + } + } + } + + + + return result; + } } diff --git a/src/BayesNet/Network.h b/src/BayesNet/Network.h index 203daa7..39ff5a6 100644 --- a/src/BayesNet/Network.h +++ b/src/BayesNet/Network.h @@ -50,6 +50,7 @@ namespace bayesnet { vector> predict_proba(const vector>&); torch::Tensor predict_proba(const torch::Tensor&); double score(const vector>&, const vector&); + vector topological_sort(); vector show(); vector graph(const string& title); // Returns a vector of strings representing the graph in graphviz format inline string version() { return "0.1.0"; }