From 8e9c3483aa2a68ca7a99eff609b7d9c80ca824f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 29 Jun 2023 22:00:41 +0200 Subject: [PATCH] Begin Network build --- .gitignore | 2 + .vscode/launch.json | 25 +++++ .vscode/settings.json | 86 ++++++++++++++++ .vscode/tasks.json | 16 +++ ArffFiles.cc | 132 +++++++++++++++++++++++++ ArffFiles.h | 34 +++++++ CMakeLists.txt | 16 +++ Network.cc | 79 +++++++++++++++ Network.h | 21 ++++ Node.cc | 41 ++++++++ Node.h | 31 ++++++ iris.arff | 225 ++++++++++++++++++++++++++++++++++++++++++ main.cc | 43 ++++++++ simple/Network.cc | 47 +++++++++ simple/Network.h | 21 ++++ simple/Node.cc | 14 +++ simple/Node.h | 18 ++++ test.cc | 23 +++++ 18 files changed, 874 insertions(+) create mode 100644 .vscode/launch.json create mode 100644 .vscode/settings.json create mode 100644 .vscode/tasks.json create mode 100644 ArffFiles.cc create mode 100644 ArffFiles.h create mode 100644 CMakeLists.txt create mode 100644 Network.cc create mode 100644 Network.h create mode 100644 Node.cc create mode 100644 Node.h create mode 100755 iris.arff create mode 100644 main.cc create mode 100644 simple/Network.cc create mode 100644 simple/Network.h create mode 100644 simple/Node.cc create mode 100644 simple/Node.h create mode 100644 test.cc diff --git a/.gitignore b/.gitignore index e257658..9d84258 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,6 @@ *.exe *.out *.app +build/ +*.dSYM/** diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..ae7831f --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,25 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "bayesnet", + "program": "${workspaceFolder}/build/bayesnet", + "args": [], + "cwd": "${workspaceFolder}", + "preLaunchTask": "CMake: build" + }, + { + "type": "lldb", + "request": "launch", + "name": "aout", + "program": "${workspaceFolder}/a.out", + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..b4be811 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,86 @@ +{ + "files.associations": { + "*.rmd": "markdown", + "*.py": "python", + "vector": "cpp", + "__bit_reference": "cpp", + "__bits": "cpp", + "__config": "cpp", + "__debug": "cpp", + "__errc": "cpp", + "__hash_table": "cpp", + "__locale": "cpp", + "__mutex_base": "cpp", + "__node_handle": "cpp", + "__nullptr": "cpp", + "__split_buffer": "cpp", + "__string": "cpp", + "__threading_support": "cpp", + "__tuple": "cpp", + "array": "cpp", + "atomic": "cpp", + "bitset": "cpp", + "cctype": "cpp", + "chrono": "cpp", + "clocale": "cpp", + "cmath": "cpp", + "compare": "cpp", + "complex": "cpp", + "concepts": "cpp", + "cstdarg": "cpp", + "cstddef": "cpp", + "cstdint": "cpp", + "cstdio": "cpp", + "cstdlib": "cpp", + "cstring": "cpp", + "ctime": "cpp", + "cwchar": "cpp", + "cwctype": "cpp", + "exception": "cpp", + "initializer_list": "cpp", + "ios": "cpp", + "iosfwd": "cpp", + "istream": "cpp", + "limits": "cpp", + "locale": "cpp", + "memory": "cpp", + "mutex": "cpp", + "new": "cpp", + "optional": "cpp", + "ostream": "cpp", + "ratio": "cpp", + "sstream": "cpp", + "stdexcept": "cpp", + "streambuf": "cpp", + "string": "cpp", + "string_view": "cpp", + "system_error": "cpp", + "tuple": "cpp", + "type_traits": "cpp", + "typeinfo": "cpp", + "unordered_map": "cpp", + "variant": "cpp", + "algorithm": "cpp", + "iostream": "cpp", + "iomanip": "cpp", + "numeric": "cpp", + "set": "cpp", + "__tree": "cpp", + "deque": "cpp", + "list": "cpp", + "map": "cpp", + "unordered_set": "cpp", + "any": "cpp", + "condition_variable": "cpp", + "forward_list": "cpp", + "fstream": "cpp", + "stack": "cpp", + "thread": "cpp", + "__memory": "cpp", + "filesystem": "cpp", + "*.toml": "toml", + "utility": "cpp", + "__verbose_abort": "cpp", + "bit": "cpp" + } +} \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 0000000..4edc48b --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,16 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "type": "cmake", + "label": "CMake: build", + "command": "build", + "targets": [ + "all" + ], + "group": "build", + "problemMatcher": [], + "detail": "CMake template build task" + } + ] +} \ No newline at end of file diff --git a/ArffFiles.cc b/ArffFiles.cc new file mode 100644 index 0000000..b576699 --- /dev/null +++ b/ArffFiles.cc @@ -0,0 +1,132 @@ +#include "ArffFiles.h" +#include +#include +#include + +using namespace std; + +ArffFiles::ArffFiles() = default; + +vector ArffFiles::getLines() const +{ + return lines; +} + +unsigned long int ArffFiles::getSize() const +{ + return lines.size(); +} + +vector> ArffFiles::getAttributes() const +{ + return attributes; +} + +string ArffFiles::getClassName() const +{ + return className; +} + +string ArffFiles::getClassType() const +{ + return classType; +} + +vector>& ArffFiles::getX() +{ + return X; +} + +vector& ArffFiles::getY() +{ + return y; +} + +void ArffFiles::load(const string& fileName, bool classLast) +{ + ifstream file(fileName); + if (!file.is_open()) { + throw invalid_argument("Unable to open file"); + } + string line; + string keyword; + string attribute; + string type; + string type_w; + while (getline(file, line)) { + if (line.empty() || line[0] == '%' || line == "\r" || line == " ") { + continue; + } + if (line.find("@attribute") != string::npos || line.find("@ATTRIBUTE") != string::npos) { + stringstream ss(line); + ss >> keyword >> attribute; + type = ""; + while (ss >> type_w) + type += type_w + " "; + attributes.emplace_back(attribute, trim(type)); + continue; + } + if (line[0] == '@') { + continue; + } + lines.push_back(line); + } + file.close(); + if (attributes.empty()) + throw invalid_argument("No attributes found"); + if (classLast) { + className = get<0>(attributes.back()); + classType = get<1>(attributes.back()); + attributes.pop_back(); + } else { + className = get<0>(attributes.front()); + classType = get<1>(attributes.front()); + attributes.erase(attributes.begin()); + } + generateDataset(classLast); + +} + +void ArffFiles::generateDataset(bool classLast) +{ + X = vector>(attributes.size(), vector(lines.size())); + auto yy = vector(lines.size(), ""); + int labelIndex = classLast ? static_cast(attributes.size()) : 0; + for (size_t i = 0; i < lines.size(); i++) { + stringstream ss(lines[i]); + string value; + int pos = 0; + int xIndex = 0; + while (getline(ss, value, ',')) { + if (pos++ == labelIndex) { + yy[i] = value; + } else { + X[xIndex++][i] = stof(value); + } + } + } + y = factorize(yy); +} + +string ArffFiles::trim(const string& source) +{ + string s(source); + s.erase(0, s.find_first_not_of(" \n\r\t")); + s.erase(s.find_last_not_of(" \n\r\t") + 1); + return s; +} + +vector ArffFiles::factorize(const vector& labels_t) +{ + vector yy; + yy.reserve(labels_t.size()); + map labelMap; + int i = 0; + for (const string& label : labels_t) { + if (labelMap.find(label) == labelMap.end()) { + labelMap[label] = i++; + } + yy.push_back(labelMap[label]); + } + return yy; +} \ No newline at end of file diff --git a/ArffFiles.h b/ArffFiles.h new file mode 100644 index 0000000..ff8bbc5 --- /dev/null +++ b/ArffFiles.h @@ -0,0 +1,34 @@ +#ifndef ARFFFILES_H +#define ARFFFILES_H + +#include +#include + +using namespace std; + +class ArffFiles { +private: + vector lines; + vector> attributes; + string className; + string classType; + vector> X; + vector y; + + void generateDataset(bool); + +public: + ArffFiles(); + void load(const string&, bool = true); + vector getLines() const; + unsigned long int getSize() const; + string getClassName() const; + string getClassType() const; + static string trim(const string&); + vector>& getX(); + vector& getY(); + vector> getAttributes() const; + static vector factorize(const vector& labels_t); +}; + +#endif \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..2f7ffd6 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.20) + +project(bayesnet) +find_package(Torch REQUIRED) + +if (POLICY CMP0135) + cmake_policy(SET CMP0135 NEW) +endif () + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + +# add_library(BayesNet Node.cc Network.cc) +# add_executable(BayesNet main.cc Node.cc Network.cc ArffFiles.cc) +add_executable(BayesNet main.cc ArffFiles.cc Node.cc Network.cc) +target_link_libraries(BayesNet "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/Network.cc b/Network.cc new file mode 100644 index 0000000..a5c41c2 --- /dev/null +++ b/Network.cc @@ -0,0 +1,79 @@ +#include "Network.h" +namespace bayesnet { + Network::~Network() + { + for (auto& pair : nodes) { + delete pair.second; + } + } + void Network::addNode(std::string name, int numStates) + { + nodes[name] = new Node(name, numStates); + } + void Network::addEdge(const std::string parent, const std::string child) + { + if (nodes.find(parent) == nodes.end()) { + throw std::invalid_argument("Parent node " + parent + " does not exist"); + } + if (nodes.find(child) == nodes.end()) { + throw std::invalid_argument("Child node " + child + " does not exist"); + } + nodes[parent]->addChild(nodes[child]); + nodes[child]->addParent(nodes[parent]); + } + std::map& Network::getNodes() + { + return nodes; + } + void Network::fit(const std::vector>& dataset, const int smoothing) + { + auto jointCounts = [](const std::vector>& data, const std::vector& indices, int numStates) { + int size = indices.size(); + std::vector sizes(size, numStates); + torch::Tensor counts = torch::zeros(sizes, torch::kLong); + + for (const auto& row : data) { + int idx = 0; + for (int i = 0; i < size; ++i) { + idx = idx * numStates + row[indices[i]]; + } + counts.view({ -1 }).add_(idx, 1); + } + + return counts; + }; + + auto marginalCounts = [](const torch::Tensor& jointCounts) { + return jointCounts.sum(-1); + }; + + for (auto& pair : nodes) { + Node* node = pair.second; + + std::vector indices; + for (const auto& parent : node->getParents()) { + indices.push_back(nodes[parent->getName()]->getId()); + } + indices.push_back(node->getId()); + + for (auto& child : node->getChildren()) { + torch::Tensor counts = jointCounts(dataset, indices, node->getNumStates()) + smoothing; + torch::Tensor parentCounts = marginalCounts(counts); + parentCounts = parentCounts.unsqueeze(-1); + + torch::Tensor cpt = counts.to(torch::kDouble) / parentCounts.to(torch::kDouble); + setCPD(node->getCPDKey(child), cpt); + } + } + } + + torch::Tensor& Network::getCPD(const std::string& key) + { + return cpds[key]; + } + + void Network::setCPD(const std::string& key, const torch::Tensor& cpt) + { + cpds[key] = cpt; + } +} diff --git a/Network.h b/Network.h new file mode 100644 index 0000000..94bafdc --- /dev/null +++ b/Network.h @@ -0,0 +1,21 @@ +#ifndef NETWORK_H +#define NETWORK_H +#include "Node.h" +#include +#include +namespace bayesnet { + class Network { + private: + map nodes; + map cpds; // Map from CPD key to CPD tensor + public: + ~Network(); + void addNode(string, int); + void addEdge(const string, const string); + map& getNodes(); + void fit(const vector>&, const int); + torch::Tensor& getCPD(const string&); + void setCPD(const string&, const torch::Tensor&); + }; +} +#endif \ No newline at end of file diff --git a/Node.cc b/Node.cc new file mode 100644 index 0000000..85b1f1d --- /dev/null +++ b/Node.cc @@ -0,0 +1,41 @@ +#include "Node.h" + +namespace bayesnet { + int Node::next_id = 0; + + Node::Node(const std::string& name, int numStates) + : id(next_id++), name(name), numStates(numStates), cpt(torch::Tensor()), parents(vector()), children(vector()) + { + } + + string Node::getName() const + { + return name; + } + + void Node::addParent(Node* parent) + { + parents.push_back(parent); + } + + void Node::addChild(Node* child) + { + children.push_back(child); + } + vector& Node::getParents() + { + return parents; + } + vector& Node::getChildren() + { + return children; + } + int Node::getNumStates() const + { + return numStates; + } + string Node::getCPDKey(const Node* child) const + { + return name + "-" + child->getName(); + } +} \ No newline at end of file diff --git a/Node.h b/Node.h new file mode 100644 index 0000000..a3ab4f3 --- /dev/null +++ b/Node.h @@ -0,0 +1,31 @@ +#ifndef NODE_H +#define NODE_H +#include +#include +#include +namespace bayesnet { + using namespace std; + class Node { + private: + static int next_id; + const int id; + string name; + vector parents; + vector children; + int numStates; + torch::Tensor cpt; + public: + Node(const std::string& name, int numStates); + void addParent(Node* parent); + void addChild(Node* child); + string getName() const; + vector& getParents(); + vector& getChildren(); + torch::Tensor& getCPT(); + void setCPT(const torch::Tensor& cpt); + int getNumStates() const; + int getId() const { return id; } + string getCPDKey(const Node*) const; + }; +} +#endif \ No newline at end of file diff --git a/iris.arff b/iris.arff new file mode 100755 index 0000000..780480c --- /dev/null +++ b/iris.arff @@ -0,0 +1,225 @@ +% 1. Title: Iris Plants Database +% +% 2. Sources: +% (a) Creator: R.A. Fisher +% (b) Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov) +% (c) Date: July, 1988 +% +% 3. Past Usage: +% - Publications: too many to mention!!! Here are a few. +% 1. Fisher,R.A. "The use of multiple measurements in taxonomic problems" +% Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions +% to Mathematical Statistics" (John Wiley, NY, 1950). +% 2. Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis. +% (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218. +% 3. Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System +% Structure and Classification Rule for Recognition in Partially Exposed +% Environments". IEEE Transactions on Pattern Analysis and Machine +% Intelligence, Vol. PAMI-2, No. 1, 67-71. +% -- Results: +% -- very low misclassification rates (0% for the setosa class) +% 4. Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE +% Transactions on Information Theory, May 1972, 431-433. +% -- Results: +% -- very low misclassification rates again +% 5. See also: 1988 MLC Proceedings, 54-64. Cheeseman et al's AUTOCLASS II +% conceptual clustering system finds 3 classes in the data. +% +% 4. Relevant Information: +% --- This is perhaps the best known database to be found in the pattern +% recognition literature. Fisher's paper is a classic in the field +% and is referenced frequently to this day. (See Duda & Hart, for +% example.) The data set contains 3 classes of 50 instances each, +% where each class refers to a type of iris plant. One class is +% linearly separable from the other 2; the latter are NOT linearly +% separable from each other. +% --- Predicted attribute: class of iris plant. +% --- This is an exceedingly simple domain. +% +% 5. Number of Instances: 150 (50 in each of three classes) +% +% 6. Number of Attributes: 4 numeric, predictive attributes and the class +% +% 7. Attribute Information: +% 1. sepal length in cm +% 2. sepal width in cm +% 3. petal length in cm +% 4. petal width in cm +% 5. class: +% -- Iris Setosa +% -- Iris Versicolour +% -- Iris Virginica +% +% 8. Missing Attribute Values: None +% +% Summary Statistics: +% Min Max Mean SD Class Correlation +% sepal length: 4.3 7.9 5.84 0.83 0.7826 +% sepal width: 2.0 4.4 3.05 0.43 -0.4194 +% petal length: 1.0 6.9 3.76 1.76 0.9490 (high!) +% petal width: 0.1 2.5 1.20 0.76 0.9565 (high!) +% +% 9. Class Distribution: 33.3% for each of 3 classes. + +@RELATION iris + +@ATTRIBUTE sepallength REAL +@ATTRIBUTE sepalwidth REAL +@ATTRIBUTE petallength REAL +@ATTRIBUTE petalwidth REAL +@ATTRIBUTE class {Iris-setosa,Iris-versicolor,Iris-virginica} + +@DATA +5.1,3.5,1.4,0.2,Iris-setosa +4.9,3.0,1.4,0.2,Iris-setosa +4.7,3.2,1.3,0.2,Iris-setosa +4.6,3.1,1.5,0.2,Iris-setosa +5.0,3.6,1.4,0.2,Iris-setosa +5.4,3.9,1.7,0.4,Iris-setosa +4.6,3.4,1.4,0.3,Iris-setosa +5.0,3.4,1.5,0.2,Iris-setosa +4.4,2.9,1.4,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +5.4,3.7,1.5,0.2,Iris-setosa +4.8,3.4,1.6,0.2,Iris-setosa +4.8,3.0,1.4,0.1,Iris-setosa +4.3,3.0,1.1,0.1,Iris-setosa +5.8,4.0,1.2,0.2,Iris-setosa +5.7,4.4,1.5,0.4,Iris-setosa +5.4,3.9,1.3,0.4,Iris-setosa +5.1,3.5,1.4,0.3,Iris-setosa +5.7,3.8,1.7,0.3,Iris-setosa +5.1,3.8,1.5,0.3,Iris-setosa +5.4,3.4,1.7,0.2,Iris-setosa +5.1,3.7,1.5,0.4,Iris-setosa +4.6,3.6,1.0,0.2,Iris-setosa +5.1,3.3,1.7,0.5,Iris-setosa +4.8,3.4,1.9,0.2,Iris-setosa +5.0,3.0,1.6,0.2,Iris-setosa +5.0,3.4,1.6,0.4,Iris-setosa +5.2,3.5,1.5,0.2,Iris-setosa +5.2,3.4,1.4,0.2,Iris-setosa +4.7,3.2,1.6,0.2,Iris-setosa +4.8,3.1,1.6,0.2,Iris-setosa +5.4,3.4,1.5,0.4,Iris-setosa +5.2,4.1,1.5,0.1,Iris-setosa +5.5,4.2,1.4,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +5.0,3.2,1.2,0.2,Iris-setosa +5.5,3.5,1.3,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +4.4,3.0,1.3,0.2,Iris-setosa +5.1,3.4,1.5,0.2,Iris-setosa +5.0,3.5,1.3,0.3,Iris-setosa +4.5,2.3,1.3,0.3,Iris-setosa +4.4,3.2,1.3,0.2,Iris-setosa +5.0,3.5,1.6,0.6,Iris-setosa +5.1,3.8,1.9,0.4,Iris-setosa +4.8,3.0,1.4,0.3,Iris-setosa +5.1,3.8,1.6,0.2,Iris-setosa +4.6,3.2,1.4,0.2,Iris-setosa +5.3,3.7,1.5,0.2,Iris-setosa +5.0,3.3,1.4,0.2,Iris-setosa +7.0,3.2,4.7,1.4,Iris-versicolor +6.4,3.2,4.5,1.5,Iris-versicolor +6.9,3.1,4.9,1.5,Iris-versicolor +5.5,2.3,4.0,1.3,Iris-versicolor +6.5,2.8,4.6,1.5,Iris-versicolor +5.7,2.8,4.5,1.3,Iris-versicolor +6.3,3.3,4.7,1.6,Iris-versicolor +4.9,2.4,3.3,1.0,Iris-versicolor +6.6,2.9,4.6,1.3,Iris-versicolor +5.2,2.7,3.9,1.4,Iris-versicolor +5.0,2.0,3.5,1.0,Iris-versicolor +5.9,3.0,4.2,1.5,Iris-versicolor +6.0,2.2,4.0,1.0,Iris-versicolor +6.1,2.9,4.7,1.4,Iris-versicolor +5.6,2.9,3.6,1.3,Iris-versicolor +6.7,3.1,4.4,1.4,Iris-versicolor +5.6,3.0,4.5,1.5,Iris-versicolor +5.8,2.7,4.1,1.0,Iris-versicolor +6.2,2.2,4.5,1.5,Iris-versicolor +5.6,2.5,3.9,1.1,Iris-versicolor +5.9,3.2,4.8,1.8,Iris-versicolor +6.1,2.8,4.0,1.3,Iris-versicolor +6.3,2.5,4.9,1.5,Iris-versicolor +6.1,2.8,4.7,1.2,Iris-versicolor +6.4,2.9,4.3,1.3,Iris-versicolor +6.6,3.0,4.4,1.4,Iris-versicolor +6.8,2.8,4.8,1.4,Iris-versicolor +6.7,3.0,5.0,1.7,Iris-versicolor +6.0,2.9,4.5,1.5,Iris-versicolor +5.7,2.6,3.5,1.0,Iris-versicolor +5.5,2.4,3.8,1.1,Iris-versicolor +5.5,2.4,3.7,1.0,Iris-versicolor +5.8,2.7,3.9,1.2,Iris-versicolor +6.0,2.7,5.1,1.6,Iris-versicolor +5.4,3.0,4.5,1.5,Iris-versicolor +6.0,3.4,4.5,1.6,Iris-versicolor +6.7,3.1,4.7,1.5,Iris-versicolor +6.3,2.3,4.4,1.3,Iris-versicolor +5.6,3.0,4.1,1.3,Iris-versicolor +5.5,2.5,4.0,1.3,Iris-versicolor +5.5,2.6,4.4,1.2,Iris-versicolor +6.1,3.0,4.6,1.4,Iris-versicolor +5.8,2.6,4.0,1.2,Iris-versicolor +5.0,2.3,3.3,1.0,Iris-versicolor +5.6,2.7,4.2,1.3,Iris-versicolor +5.7,3.0,4.2,1.2,Iris-versicolor +5.7,2.9,4.2,1.3,Iris-versicolor +6.2,2.9,4.3,1.3,Iris-versicolor +5.1,2.5,3.0,1.1,Iris-versicolor +5.7,2.8,4.1,1.3,Iris-versicolor +6.3,3.3,6.0,2.5,Iris-virginica +5.8,2.7,5.1,1.9,Iris-virginica +7.1,3.0,5.9,2.1,Iris-virginica +6.3,2.9,5.6,1.8,Iris-virginica +6.5,3.0,5.8,2.2,Iris-virginica +7.6,3.0,6.6,2.1,Iris-virginica +4.9,2.5,4.5,1.7,Iris-virginica +7.3,2.9,6.3,1.8,Iris-virginica +6.7,2.5,5.8,1.8,Iris-virginica +7.2,3.6,6.1,2.5,Iris-virginica +6.5,3.2,5.1,2.0,Iris-virginica +6.4,2.7,5.3,1.9,Iris-virginica +6.8,3.0,5.5,2.1,Iris-virginica +5.7,2.5,5.0,2.0,Iris-virginica +5.8,2.8,5.1,2.4,Iris-virginica +6.4,3.2,5.3,2.3,Iris-virginica +6.5,3.0,5.5,1.8,Iris-virginica +7.7,3.8,6.7,2.2,Iris-virginica +7.7,2.6,6.9,2.3,Iris-virginica +6.0,2.2,5.0,1.5,Iris-virginica +6.9,3.2,5.7,2.3,Iris-virginica +5.6,2.8,4.9,2.0,Iris-virginica +7.7,2.8,6.7,2.0,Iris-virginica +6.3,2.7,4.9,1.8,Iris-virginica +6.7,3.3,5.7,2.1,Iris-virginica +7.2,3.2,6.0,1.8,Iris-virginica +6.2,2.8,4.8,1.8,Iris-virginica +6.1,3.0,4.9,1.8,Iris-virginica +6.4,2.8,5.6,2.1,Iris-virginica +7.2,3.0,5.8,1.6,Iris-virginica +7.4,2.8,6.1,1.9,Iris-virginica +7.9,3.8,6.4,2.0,Iris-virginica +6.4,2.8,5.6,2.2,Iris-virginica +6.3,2.8,5.1,1.5,Iris-virginica +6.1,2.6,5.6,1.4,Iris-virginica +7.7,3.0,6.1,2.3,Iris-virginica +6.3,3.4,5.6,2.4,Iris-virginica +6.4,3.1,5.5,1.8,Iris-virginica +6.0,3.0,4.8,1.8,Iris-virginica +6.9,3.1,5.4,2.1,Iris-virginica +6.7,3.1,5.6,2.4,Iris-virginica +6.9,3.1,5.1,2.3,Iris-virginica +5.8,2.7,5.1,1.9,Iris-virginica +6.8,3.2,5.9,2.3,Iris-virginica +6.7,3.3,5.7,2.5,Iris-virginica +6.7,3.0,5.2,2.3,Iris-virginica +6.3,2.5,5.0,1.9,Iris-virginica +6.5,3.0,5.2,2.0,Iris-virginica +6.2,3.4,5.4,2.3,Iris-virginica +5.9,3.0,5.1,1.8,Iris-virginica +% +% +% diff --git a/main.cc b/main.cc new file mode 100644 index 0000000..8a1d6d1 --- /dev/null +++ b/main.cc @@ -0,0 +1,43 @@ +#include +#include +#include +#include "ArffFiles.h" +#include "Network.h" + +using namespace std; + +int main() +{ + auto handler = ArffFiles(); + handler.load("iris.arff"); + auto X = handler.getX(); + auto y = handler.getY(); + auto className = handler.getClassName(); + vector> edges = { {className, "sepallength"}, {className, "sepalwidth"}, {className, "petallength"}, {className, "petalwidth"} }; + auto network = bayesnet::Network(); + // Add nodes to the network + for (auto feature : handler.getAttributes()) { + cout << "Adding feature: " << feature.first << endl; + network.addNode(feature.first, 7); + } + network.addNode(className, 3); + for (auto item : edges) { + network.addEdge(item.first, item.second); + } + cout << "Hello, Bayesian Networks!" << endl; + torch::Tensor tensor = torch::eye(3); + cout << tensor << std::endl; + cout << "Nodes:" << endl; + for (auto [name, item] : network.getNodes()) { + cout << "*" << item->getName() << endl; + cout << "-Parents:" << endl; + for (auto parent : item->getParents()) { + cout << " " << parent->getName() << endl; + } + cout << "-Children:" << endl; + for (auto child : item->getChildren()) { + cout << " " << child->getName() << endl; + } + } + return 0; +} \ No newline at end of file diff --git a/simple/Network.cc b/simple/Network.cc new file mode 100644 index 0000000..f16a1dd --- /dev/null +++ b/simple/Network.cc @@ -0,0 +1,47 @@ +#include +#include +#include +#include "Network.h" + +namespace bayesnet { + Network::Network() {} + + Network::~Network() + { + for (auto& pair : nodes) { + delete pair.second; + } + } + + void Network::addNode(std::string name) + { + nodes[name] = new Node(name); + } + + void Network::addEdge(std::string parentName, std::string childName) + { + Node* parent = nodes[parentName]; + Node* child = nodes[childName]; + + if (parent == nullptr || child == nullptr) { + throw std::invalid_argument("Parent or child node not found."); + } + + child->addParent(parent); + } + + // to be implemented + void Network::fit(const std::vector>& dataset) + { + // ... learn parameters (i.e., CPTs) using the dataset + } + + // to be implemented + std::vector Network::predict(const std::vector>& testset) + { + std::vector predictions; + // ... use the CPTs and network structure to predict values + return predictions; + } +} + diff --git a/simple/Network.h b/simple/Network.h new file mode 100644 index 0000000..d09ac72 --- /dev/null +++ b/simple/Network.h @@ -0,0 +1,21 @@ +#ifndef NETWORK_H +#define NETWORK_H +#include +#include +#include +#include "Node.h" + +namespace bayesnet { + class Network { + private: + std::map nodes; + public: + Network(); + ~Network(); + void addNode(std::string); + void addEdge(std::string, std::string); + void fit(const std::vector>&); + std::vector predict(const std::vector>&); + }; +} +#endif diff --git a/simple/Node.cc b/simple/Node.cc new file mode 100644 index 0000000..0d04c87 --- /dev/null +++ b/simple/Node.cc @@ -0,0 +1,14 @@ +#include +#include +#include +#include "Node.h" + +namespace bayesnet { + Node::Node(std::string name) : name(name) {} + + void Node::addParent(Node* parent) + { + parents.push_back(parent); + parent->children.push_back(this); + } +} diff --git a/simple/Node.h b/simple/Node.h new file mode 100644 index 0000000..7f017f9 --- /dev/null +++ b/simple/Node.h @@ -0,0 +1,18 @@ +#ifndef NODE_H +#define NODE_H +#include +#include +#include +namespace bayesnet { + class Node { + private: + std::string name; + std::vector parents; + std::vector children; + std::map, double> cpt; // Conditional Probability Table + public: + Node(std::string); + void addParent(Node*); + }; +} +#endif diff --git a/test.cc b/test.cc new file mode 100644 index 0000000..adbbb08 --- /dev/null +++ b/test.cc @@ -0,0 +1,23 @@ +#include +#include +#include + +using namespace std; + +int main(int argc, char const* argv[]) +{ + map m; + m["a"] = 1; + m["b"] = 2; + m["c"] = 3; + if (m.find("b") != m.end()) { + cout << "Found b" << endl; + } else { + cout << "Not found b" << endl; + } + // for (auto [key, value] : m) { + // cout << key << " " << value << endl; + // } + + return 0; +}