Add topological order to Network

This commit is contained in:
Ricardo Montañana Gómez 2023-08-02 00:56:52 +02:00
parent f63a9a64f9
commit cdfb45d2cb
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
7 changed files with 58 additions and 0 deletions

View File

@ -158,6 +158,7 @@ int main(int argc, char** argv)
states[feature] = vector<int>(maxes[feature]);
}
states[className] = vector<int>(maxes[className]);
auto clf = platform::Models::instance()->create(model_name);
clf->fit(Xd, y, features, className, states);
auto score = clf->score(Xd, y);
@ -166,6 +167,11 @@ int main(int argc, char** argv)
for (auto line : lines) {
cout << line << endl;
}
cout << "--- Topological Order ---" << endl;
for (auto name : clf->topological_order()) {
cout << name << ", ";
}
cout << "end." << endl;
cout << "Score: " << score << endl;
auto dot_file = model_name + "_" + file_name;
ofstream file(dot_file + ".dot");

View File

@ -19,6 +19,7 @@ namespace bayesnet {
vector<string> virtual graph(const string& title = "") = 0;
virtual ~BaseClassifier() = default;
const string inline getVersion() const { return "0.1.0"; };
vector<string> virtual topological_order() = 0;
};
}
#endif

View File

@ -134,4 +134,8 @@ namespace bayesnet {
{
return fitted ? model.getStates() : 0;
}
vector<string> Classifier::topological_order()
{
return model.topological_sort();
}
}

View File

@ -40,6 +40,7 @@ namespace bayesnet {
float score(Tensor& X, Tensor& y) override;
float score(vector<vector<int>>& X, vector<int>& y) override;
vector<string> show() override;
vector<string> topological_order() override;
};
}
#endif

View File

@ -41,6 +41,10 @@ namespace bayesnet {
int getNumberOfStates() override;
vector<string> show() override;
vector<string> graph(const string& title) override;
vector<string> topological_order() override
{
return vector<string>();
}
};
}
#endif

View File

@ -341,4 +341,45 @@ namespace bayesnet {
}
return edges;
}
vector<string> Network::topological_sort()
{
/* Check if al the fathers of every node are before the node */
auto result = features;
bool ending{ false };
int idx = 0;
while (!ending) {
ending = true;
for (auto feature : features) {
if (feature == className) {
continue;
}
auto fathers = nodes[feature]->getParents();
for (const auto& father : fathers) {
auto fatherName = father->getName();
if (fatherName == className) {
continue;
}
auto it = find(result.begin(), result.end(), fatherName);
if (it != result.end()) {
auto it2 = find(result.begin(), result.end(), feature);
if (it2 != result.end()) {
if (distance(it, it2) < 0) {
result.erase(remove(result.begin(), result.end(), fatherName), result.end());
result.insert(it2, fatherName);
ending = false;
}
} else {
throw logic_error("Error in topological sort because of node " + feature + " is not in result");
}
} else {
throw logic_error("Error in topological sort because of node father " + fatherName + " is not in result");
}
}
}
}
return result;
}
}

View File

@ -50,6 +50,7 @@ namespace bayesnet {
vector<vector<double>> predict_proba(const vector<vector<int>>&);
torch::Tensor predict_proba(const torch::Tensor&);
double score(const vector<vector<int>>&, const vector<int>&);
vector<string> topological_sort();
vector<string> show();
vector<string> graph(const string& title); // Returns a vector of strings representing the graph in graphviz format
inline string version() { return "0.1.0"; }