diff --git a/data/glass.net b/data/glass.net new file mode 100644 index 0000000..39592f8 --- /dev/null +++ b/data/glass.net @@ -0,0 +1,25 @@ +Type Si +Type Fe +Type RI +Type Na +Type Ba +Type Ca +Type Al +Type K +Type Mg +Fe RI +Fe Ba +Fe Ca +RI Na +RI Ba +RI Ca +RI Al +RI K +RI Mg +Ba Ca +Ba Al +Ca Al +Ca K +Ca Mg +Al K +K Mg \ No newline at end of file diff --git a/sample/main copy.cc b/sample/main copy.cc deleted file mode 100644 index 5ec155d..0000000 --- a/sample/main copy.cc +++ /dev/null @@ -1,54 +0,0 @@ -#include -#include -#include -#include "ArffFiles.h" -#include "Network.h" - -using namespace std; - -int main() -{ - auto handler = ArffFiles(); - handler.load("iris.arff"); - auto X = handler.getX(); - auto y = handler.getY(); - auto className = handler.getClassName(); - vector> edges = { {className, "sepallength"}, {className, "sepalwidth"}, {className, "petallength"}, {className, "petalwidth"} }; - auto network = bayesnet::Network(); - // Add nodes to the network - for (auto feature : handler.getAttributes()) { - cout << "Adding feature: " << feature.first << endl; - network.addNode(feature.first, 7); - } - network.addNode(className, 3); - for (auto item : edges) { - network.addEdge(item.first, item.second); - } - cout << "Hello, Bayesian Networks!" << endl; - torch::Tensor tensor = torch::eye(3); - cout << "Now I'll add a cycle" << endl; - try { - network.addEdge("petallength", className); - } - catch (invalid_argument& e) { - cout << e.what() << endl; - } - cout << tensor << std::endl; - cout << "Nodes:" << endl; - for (auto [name, item] : network.getNodes()) { - cout << "*" << item->getName() << endl; - cout << "-Parents:" << endl; - for (auto parent : item->getParents()) { - cout << " " << parent->getName() << endl; - } - cout << "-Children:" << endl; - for (auto child : item->getChildren()) { - cout << " " << child->getName() << endl; - } - } - cout << "Root: " << network.getRoot()->getName() << endl; - network.setRoot(className); - cout << "Now Root should be class: " << network.getRoot()->getName() << endl; - cout << "PyTorch version: " << TORCH_VERSION << endl; - return 0; -} \ No newline at end of file diff --git a/sample/main.cc b/sample/main.cc index 06a22bb..b5cec25 100644 --- a/sample/main.cc +++ b/sample/main.cc @@ -94,15 +94,17 @@ void showNodesInfo(bayesnet::Network& network, string className) { cout << "Nodes:" << endl; for (auto [name, item] : network.getNodes()) { - cout << "*" << item->getName() << " -> " << item->getNumStates() << endl; - cout << "-Parents:" << endl; + cout << "*" << item->getName() << " States -> " << item->getNumStates() << endl; + cout << "-Parents:"; for (auto parent : item->getParents()) { - cout << " " << parent->getName() << endl; + cout << " " << parent->getName(); } - cout << "-Children:" << endl; + cout << endl; + cout << "-Children:"; for (auto child : item->getChildren()) { - cout << " " << child->getName() << endl; + cout << " " << child->getName(); } + cout << endl; } cout << "Root: " << network.getRoot()->getName() << endl; network.setRoot(className); @@ -149,7 +151,7 @@ pair get_options(int argc, char** argv) string path; string network_name; tie(file_name, path, network_name) = parse_arguments(argc, argv); - if (datasets.find(file_name) == datasets.end() && file_name != "all") { + if (datasets.find(file_name) == datasets.end()) { cout << "Invalid file name: " << file_name << endl; usage(argv[0]); exit(1);