Add glass kdb net
This commit is contained in:
parent
ba08b8dd3d
commit
a0114da70c
25
data/glass.net
Normal file
25
data/glass.net
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
Type Si
|
||||||
|
Type Fe
|
||||||
|
Type RI
|
||||||
|
Type Na
|
||||||
|
Type Ba
|
||||||
|
Type Ca
|
||||||
|
Type Al
|
||||||
|
Type K
|
||||||
|
Type Mg
|
||||||
|
Fe RI
|
||||||
|
Fe Ba
|
||||||
|
Fe Ca
|
||||||
|
RI Na
|
||||||
|
RI Ba
|
||||||
|
RI Ca
|
||||||
|
RI Al
|
||||||
|
RI K
|
||||||
|
RI Mg
|
||||||
|
Ba Ca
|
||||||
|
Ba Al
|
||||||
|
Ca Al
|
||||||
|
Ca K
|
||||||
|
Ca Mg
|
||||||
|
Al K
|
||||||
|
K Mg
|
@ -1,54 +0,0 @@
|
|||||||
#include <iostream>
|
|
||||||
#include <string>
|
|
||||||
#include <torch/torch.h>
|
|
||||||
#include "ArffFiles.h"
|
|
||||||
#include "Network.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
int main()
|
|
||||||
{
|
|
||||||
auto handler = ArffFiles();
|
|
||||||
handler.load("iris.arff");
|
|
||||||
auto X = handler.getX();
|
|
||||||
auto y = handler.getY();
|
|
||||||
auto className = handler.getClassName();
|
|
||||||
vector<pair<string, string>> edges = { {className, "sepallength"}, {className, "sepalwidth"}, {className, "petallength"}, {className, "petalwidth"} };
|
|
||||||
auto network = bayesnet::Network();
|
|
||||||
// Add nodes to the network
|
|
||||||
for (auto feature : handler.getAttributes()) {
|
|
||||||
cout << "Adding feature: " << feature.first << endl;
|
|
||||||
network.addNode(feature.first, 7);
|
|
||||||
}
|
|
||||||
network.addNode(className, 3);
|
|
||||||
for (auto item : edges) {
|
|
||||||
network.addEdge(item.first, item.second);
|
|
||||||
}
|
|
||||||
cout << "Hello, Bayesian Networks!" << endl;
|
|
||||||
torch::Tensor tensor = torch::eye(3);
|
|
||||||
cout << "Now I'll add a cycle" << endl;
|
|
||||||
try {
|
|
||||||
network.addEdge("petallength", className);
|
|
||||||
}
|
|
||||||
catch (invalid_argument& e) {
|
|
||||||
cout << e.what() << endl;
|
|
||||||
}
|
|
||||||
cout << tensor << std::endl;
|
|
||||||
cout << "Nodes:" << endl;
|
|
||||||
for (auto [name, item] : network.getNodes()) {
|
|
||||||
cout << "*" << item->getName() << endl;
|
|
||||||
cout << "-Parents:" << endl;
|
|
||||||
for (auto parent : item->getParents()) {
|
|
||||||
cout << " " << parent->getName() << endl;
|
|
||||||
}
|
|
||||||
cout << "-Children:" << endl;
|
|
||||||
for (auto child : item->getChildren()) {
|
|
||||||
cout << " " << child->getName() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cout << "Root: " << network.getRoot()->getName() << endl;
|
|
||||||
network.setRoot(className);
|
|
||||||
cout << "Now Root should be class: " << network.getRoot()->getName() << endl;
|
|
||||||
cout << "PyTorch version: " << TORCH_VERSION << endl;
|
|
||||||
return 0;
|
|
||||||
}
|
|
@ -94,15 +94,17 @@ void showNodesInfo(bayesnet::Network& network, string className)
|
|||||||
{
|
{
|
||||||
cout << "Nodes:" << endl;
|
cout << "Nodes:" << endl;
|
||||||
for (auto [name, item] : network.getNodes()) {
|
for (auto [name, item] : network.getNodes()) {
|
||||||
cout << "*" << item->getName() << " -> " << item->getNumStates() << endl;
|
cout << "*" << item->getName() << " States -> " << item->getNumStates() << endl;
|
||||||
cout << "-Parents:" << endl;
|
cout << "-Parents:";
|
||||||
for (auto parent : item->getParents()) {
|
for (auto parent : item->getParents()) {
|
||||||
cout << " " << parent->getName() << endl;
|
cout << " " << parent->getName();
|
||||||
}
|
}
|
||||||
cout << "-Children:" << endl;
|
cout << endl;
|
||||||
|
cout << "-Children:";
|
||||||
for (auto child : item->getChildren()) {
|
for (auto child : item->getChildren()) {
|
||||||
cout << " " << child->getName() << endl;
|
cout << " " << child->getName();
|
||||||
}
|
}
|
||||||
|
cout << endl;
|
||||||
}
|
}
|
||||||
cout << "Root: " << network.getRoot()->getName() << endl;
|
cout << "Root: " << network.getRoot()->getName() << endl;
|
||||||
network.setRoot(className);
|
network.setRoot(className);
|
||||||
@ -149,7 +151,7 @@ pair<string, string> get_options(int argc, char** argv)
|
|||||||
string path;
|
string path;
|
||||||
string network_name;
|
string network_name;
|
||||||
tie(file_name, path, network_name) = parse_arguments(argc, argv);
|
tie(file_name, path, network_name) = parse_arguments(argc, argv);
|
||||||
if (datasets.find(file_name) == datasets.end() && file_name != "all") {
|
if (datasets.find(file_name) == datasets.end()) {
|
||||||
cout << "Invalid file name: " << file_name << endl;
|
cout << "Invalid file name: " << file_name << endl;
|
||||||
usage(argv[0]);
|
usage(argv[0]);
|
||||||
exit(1);
|
exit(1);
|
||||||
|
Loading…
Reference in New Issue
Block a user