Add glass kdb net

This commit is contained in:
Ricardo Montañana Gómez 2023-07-05 19:09:59 +02:00
parent ba08b8dd3d
commit a0114da70c
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 33 additions and 60 deletions

25
data/glass.net Normal file
View File

@ -0,0 +1,25 @@
Type Si
Type Fe
Type RI
Type Na
Type Ba
Type Ca
Type Al
Type K
Type Mg
Fe RI
Fe Ba
Fe Ca
RI Na
RI Ba
RI Ca
RI Al
RI K
RI Mg
Ba Ca
Ba Al
Ca Al
Ca K
Ca Mg
Al K
K Mg

View File

@ -1,54 +0,0 @@
#include <iostream>
#include <string>
#include <torch/torch.h>
#include "ArffFiles.h"
#include "Network.h"
using namespace std;
int main()
{
auto handler = ArffFiles();
handler.load("iris.arff");
auto X = handler.getX();
auto y = handler.getY();
auto className = handler.getClassName();
vector<pair<string, string>> edges = { {className, "sepallength"}, {className, "sepalwidth"}, {className, "petallength"}, {className, "petalwidth"} };
auto network = bayesnet::Network();
// Add nodes to the network
for (auto feature : handler.getAttributes()) {
cout << "Adding feature: " << feature.first << endl;
network.addNode(feature.first, 7);
}
network.addNode(className, 3);
for (auto item : edges) {
network.addEdge(item.first, item.second);
}
cout << "Hello, Bayesian Networks!" << endl;
torch::Tensor tensor = torch::eye(3);
cout << "Now I'll add a cycle" << endl;
try {
network.addEdge("petallength", className);
}
catch (invalid_argument& e) {
cout << e.what() << endl;
}
cout << tensor << std::endl;
cout << "Nodes:" << endl;
for (auto [name, item] : network.getNodes()) {
cout << "*" << item->getName() << endl;
cout << "-Parents:" << endl;
for (auto parent : item->getParents()) {
cout << " " << parent->getName() << endl;
}
cout << "-Children:" << endl;
for (auto child : item->getChildren()) {
cout << " " << child->getName() << endl;
}
}
cout << "Root: " << network.getRoot()->getName() << endl;
network.setRoot(className);
cout << "Now Root should be class: " << network.getRoot()->getName() << endl;
cout << "PyTorch version: " << TORCH_VERSION << endl;
return 0;
}

View File

@ -94,15 +94,17 @@ void showNodesInfo(bayesnet::Network& network, string className)
{
cout << "Nodes:" << endl;
for (auto [name, item] : network.getNodes()) {
cout << "*" << item->getName() << " -> " << item->getNumStates() << endl;
cout << "-Parents:" << endl;
cout << "*" << item->getName() << " States -> " << item->getNumStates() << endl;
cout << "-Parents:";
for (auto parent : item->getParents()) {
cout << " " << parent->getName() << endl;
cout << " " << parent->getName();
}
cout << "-Children:" << endl;
cout << endl;
cout << "-Children:";
for (auto child : item->getChildren()) {
cout << " " << child->getName() << endl;
cout << " " << child->getName();
}
cout << endl;
}
cout << "Root: " << network.getRoot()->getName() << endl;
network.setRoot(className);
@ -149,7 +151,7 @@ pair<string, string> get_options(int argc, char** argv)
string path;
string network_name;
tie(file_name, path, network_name) = parse_arguments(argc, argv);
if (datasets.find(file_name) == datasets.end() && file_name != "all") {
if (datasets.find(file_name) == datasets.end()) {
cout << "Invalid file name: " << file_name << endl;
usage(argv[0]);
exit(1);