diff --git a/.gitignore b/.gitignore index e257658..9d84258 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,6 @@ *.exe *.out *.app +build/ +*.dSYM/** diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..ae7831f --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,25 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "bayesnet", + "program": "${workspaceFolder}/build/bayesnet", + "args": [], + "cwd": "${workspaceFolder}", + "preLaunchTask": "CMake: build" + }, + { + "type": "lldb", + "request": "launch", + "name": "aout", + "program": "${workspaceFolder}/a.out", + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..b4be811 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,86 @@ +{ + "files.associations": { + "*.rmd": "markdown", + "*.py": "python", + "vector": "cpp", + "__bit_reference": "cpp", + "__bits": "cpp", + "__config": "cpp", + "__debug": "cpp", + "__errc": "cpp", + "__hash_table": "cpp", + "__locale": "cpp", + "__mutex_base": "cpp", + "__node_handle": "cpp", + "__nullptr": "cpp", + "__split_buffer": "cpp", + "__string": "cpp", + "__threading_support": "cpp", + "__tuple": "cpp", + "array": "cpp", + "atomic": "cpp", + "bitset": "cpp", + "cctype": "cpp", + "chrono": "cpp", + "clocale": "cpp", + "cmath": "cpp", + "compare": "cpp", + "complex": "cpp", + "concepts": "cpp", + "cstdarg": "cpp", + "cstddef": "cpp", + "cstdint": "cpp", + "cstdio": "cpp", + "cstdlib": "cpp", + "cstring": "cpp", + "ctime": "cpp", + "cwchar": "cpp", + "cwctype": "cpp", + "exception": "cpp", + "initializer_list": "cpp", + "ios": "cpp", + "iosfwd": "cpp", + "istream": "cpp", + "limits": "cpp", + "locale": "cpp", + "memory": "cpp", + "mutex": "cpp", + "new": "cpp", + "optional": "cpp", + "ostream": "cpp", + "ratio": "cpp", + "sstream": "cpp", + "stdexcept": "cpp", + "streambuf": "cpp", + "string": "cpp", + "string_view": "cpp", + "system_error": "cpp", + "tuple": "cpp", + "type_traits": "cpp", + "typeinfo": "cpp", + "unordered_map": "cpp", + "variant": "cpp", + "algorithm": "cpp", + "iostream": "cpp", + "iomanip": "cpp", + "numeric": "cpp", + "set": "cpp", + "__tree": "cpp", + "deque": "cpp", + "list": "cpp", + "map": "cpp", + "unordered_set": "cpp", + "any": "cpp", + "condition_variable": "cpp", + "forward_list": "cpp", + "fstream": "cpp", + "stack": "cpp", + "thread": "cpp", + "__memory": "cpp", + "filesystem": "cpp", + "*.toml": "toml", + "utility": "cpp", + "__verbose_abort": "cpp", + "bit": "cpp" + } +} \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 0000000..4edc48b --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,16 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "type": "cmake", + "label": "CMake: build", + "command": "build", + "targets": [ + "all" + ], + "group": "build", + "problemMatcher": [], + "detail": "CMake template build task" + } + ] +} \ No newline at end of file diff --git a/ArffFiles.cc b/ArffFiles.cc new file mode 100644 index 0000000..b576699 --- /dev/null +++ b/ArffFiles.cc @@ -0,0 +1,132 @@ +#include "ArffFiles.h" +#include +#include +#include + +using namespace std; + +ArffFiles::ArffFiles() = default; + +vector ArffFiles::getLines() const +{ + return lines; +} + +unsigned long int ArffFiles::getSize() const +{ + return lines.size(); +} + +vector> ArffFiles::getAttributes() const +{ + return attributes; +} + +string ArffFiles::getClassName() const +{ + return className; +} + +string ArffFiles::getClassType() const +{ + return classType; +} + +vector>& ArffFiles::getX() +{ + return X; +} + +vector& ArffFiles::getY() +{ + return y; +} + +void ArffFiles::load(const string& fileName, bool classLast) +{ + ifstream file(fileName); + if (!file.is_open()) { + throw invalid_argument("Unable to open file"); + } + string line; + string keyword; + string attribute; + string type; + string type_w; + while (getline(file, line)) { + if (line.empty() || line[0] == '%' || line == "\r" || line == " ") { + continue; + } + if (line.find("@attribute") != string::npos || line.find("@ATTRIBUTE") != string::npos) { + stringstream ss(line); + ss >> keyword >> attribute; + type = ""; + while (ss >> type_w) + type += type_w + " "; + attributes.emplace_back(attribute, trim(type)); + continue; + } + if (line[0] == '@') { + continue; + } + lines.push_back(line); + } + file.close(); + if (attributes.empty()) + throw invalid_argument("No attributes found"); + if (classLast) { + className = get<0>(attributes.back()); + classType = get<1>(attributes.back()); + attributes.pop_back(); + } else { + className = get<0>(attributes.front()); + classType = get<1>(attributes.front()); + attributes.erase(attributes.begin()); + } + generateDataset(classLast); + +} + +void ArffFiles::generateDataset(bool classLast) +{ + X = vector>(attributes.size(), vector(lines.size())); + auto yy = vector(lines.size(), ""); + int labelIndex = classLast ? static_cast(attributes.size()) : 0; + for (size_t i = 0; i < lines.size(); i++) { + stringstream ss(lines[i]); + string value; + int pos = 0; + int xIndex = 0; + while (getline(ss, value, ',')) { + if (pos++ == labelIndex) { + yy[i] = value; + } else { + X[xIndex++][i] = stof(value); + } + } + } + y = factorize(yy); +} + +string ArffFiles::trim(const string& source) +{ + string s(source); + s.erase(0, s.find_first_not_of(" \n\r\t")); + s.erase(s.find_last_not_of(" \n\r\t") + 1); + return s; +} + +vector ArffFiles::factorize(const vector& labels_t) +{ + vector yy; + yy.reserve(labels_t.size()); + map labelMap; + int i = 0; + for (const string& label : labels_t) { + if (labelMap.find(label) == labelMap.end()) { + labelMap[label] = i++; + } + yy.push_back(labelMap[label]); + } + return yy; +} \ No newline at end of file diff --git a/ArffFiles.h b/ArffFiles.h new file mode 100644 index 0000000..ff8bbc5 --- /dev/null +++ b/ArffFiles.h @@ -0,0 +1,34 @@ +#ifndef ARFFFILES_H +#define ARFFFILES_H + +#include +#include + +using namespace std; + +class ArffFiles { +private: + vector lines; + vector> attributes; + string className; + string classType; + vector> X; + vector y; + + void generateDataset(bool); + +public: + ArffFiles(); + void load(const string&, bool = true); + vector getLines() const; + unsigned long int getSize() const; + string getClassName() const; + string getClassType() const; + static string trim(const string&); + vector>& getX(); + vector& getY(); + vector> getAttributes() const; + static vector factorize(const vector& labels_t); +}; + +#endif \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..2f7ffd6 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.20) + +project(bayesnet) +find_package(Torch REQUIRED) + +if (POLICY CMP0135) + cmake_policy(SET CMP0135 NEW) +endif () + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + +# add_library(BayesNet Node.cc Network.cc) +# add_executable(BayesNet main.cc Node.cc Network.cc ArffFiles.cc) +add_executable(BayesNet main.cc ArffFiles.cc Node.cc Network.cc) +target_link_libraries(BayesNet "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/Network.cc b/Network.cc new file mode 100644 index 0000000..a5c41c2 --- /dev/null +++ b/Network.cc @@ -0,0 +1,79 @@ +#include "Network.h" +namespace bayesnet { + Network::~Network() + { + for (auto& pair : nodes) { + delete pair.second; + } + } + void Network::addNode(std::string name, int numStates) + { + nodes[name] = new Node(name, numStates); + } + void Network::addEdge(const std::string parent, const std::string child) + { + if (nodes.find(parent) == nodes.end()) { + throw std::invalid_argument("Parent node " + parent + " does not exist"); + } + if (nodes.find(child) == nodes.end()) { + throw std::invalid_argument("Child node " + child + " does not exist"); + } + nodes[parent]->addChild(nodes[child]); + nodes[child]->addParent(nodes[parent]); + } + std::map& Network::getNodes() + { + return nodes; + } + void Network::fit(const std::vector>& dataset, const int smoothing) + { + auto jointCounts = [](const std::vector>& data, const std::vector& indices, int numStates) { + int size = indices.size(); + std::vector sizes(size, numStates); + torch::Tensor counts = torch::zeros(sizes, torch::kLong); + + for (const auto& row : data) { + int idx = 0; + for (int i = 0; i < size; ++i) { + idx = idx * numStates + row[indices[i]]; + } + counts.view({ -1 }).add_(idx, 1); + } + + return counts; + }; + + auto marginalCounts = [](const torch::Tensor& jointCounts) { + return jointCounts.sum(-1); + }; + + for (auto& pair : nodes) { + Node* node = pair.second; + + std::vector indices; + for (const auto& parent : node->getParents()) { + indices.push_back(nodes[parent->getName()]->getId()); + } + indices.push_back(node->getId()); + + for (auto& child : node->getChildren()) { + torch::Tensor counts = jointCounts(dataset, indices, node->getNumStates()) + smoothing; + torch::Tensor parentCounts = marginalCounts(counts); + parentCounts = parentCounts.unsqueeze(-1); + + torch::Tensor cpt = counts.to(torch::kDouble) / parentCounts.to(torch::kDouble); + setCPD(node->getCPDKey(child), cpt); + } + } + } + + torch::Tensor& Network::getCPD(const std::string& key) + { + return cpds[key]; + } + + void Network::setCPD(const std::string& key, const torch::Tensor& cpt) + { + cpds[key] = cpt; + } +} diff --git a/Network.h b/Network.h new file mode 100644 index 0000000..94bafdc --- /dev/null +++ b/Network.h @@ -0,0 +1,21 @@ +#ifndef NETWORK_H +#define NETWORK_H +#include "Node.h" +#include +#include +namespace bayesnet { + class Network { + private: + map nodes; + map cpds; // Map from CPD key to CPD tensor + public: + ~Network(); + void addNode(string, int); + void addEdge(const string, const string); + map& getNodes(); + void fit(const vector>&, const int); + torch::Tensor& getCPD(const string&); + void setCPD(const string&, const torch::Tensor&); + }; +} +#endif \ No newline at end of file diff --git a/Node.cc b/Node.cc new file mode 100644 index 0000000..85b1f1d --- /dev/null +++ b/Node.cc @@ -0,0 +1,41 @@ +#include "Node.h" + +namespace bayesnet { + int Node::next_id = 0; + + Node::Node(const std::string& name, int numStates) + : id(next_id++), name(name), numStates(numStates), cpt(torch::Tensor()), parents(vector()), children(vector()) + { + } + + string Node::getName() const + { + return name; + } + + void Node::addParent(Node* parent) + { + parents.push_back(parent); + } + + void Node::addChild(Node* child) + { + children.push_back(child); + } + vector& Node::getParents() + { + return parents; + } + vector& Node::getChildren() + { + return children; + } + int Node::getNumStates() const + { + return numStates; + } + string Node::getCPDKey(const Node* child) const + { + return name + "-" + child->getName(); + } +} \ No newline at end of file diff --git a/Node.h b/Node.h new file mode 100644 index 0000000..a3ab4f3 --- /dev/null +++ b/Node.h @@ -0,0 +1,31 @@ +#ifndef NODE_H +#define NODE_H +#include +#include +#include +namespace bayesnet { + using namespace std; + class Node { + private: + static int next_id; + const int id; + string name; + vector parents; + vector children; + int numStates; + torch::Tensor cpt; + public: + Node(const std::string& name, int numStates); + void addParent(Node* parent); + void addChild(Node* child); + string getName() const; + vector& getParents(); + vector& getChildren(); + torch::Tensor& getCPT(); + void setCPT(const torch::Tensor& cpt); + int getNumStates() const; + int getId() const { return id; } + string getCPDKey(const Node*) const; + }; +} +#endif \ No newline at end of file diff --git a/iris.arff b/iris.arff new file mode 100755 index 0000000..780480c --- /dev/null +++ b/iris.arff @@ -0,0 +1,225 @@ +% 1. Title: Iris Plants Database +% +% 2. Sources: +% (a) Creator: R.A. Fisher +% (b) Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov) +% (c) Date: July, 1988 +% +% 3. Past Usage: +% - Publications: too many to mention!!! Here are a few. +% 1. Fisher,R.A. "The use of multiple measurements in taxonomic problems" +% Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions +% to Mathematical Statistics" (John Wiley, NY, 1950). +% 2. Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis. +% (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218. +% 3. Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System +% Structure and Classification Rule for Recognition in Partially Exposed +% Environments". IEEE Transactions on Pattern Analysis and Machine +% Intelligence, Vol. PAMI-2, No. 1, 67-71. +% -- Results: +% -- very low misclassification rates (0% for the setosa class) +% 4. Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE +% Transactions on Information Theory, May 1972, 431-433. +% -- Results: +% -- very low misclassification rates again +% 5. See also: 1988 MLC Proceedings, 54-64. Cheeseman et al's AUTOCLASS II +% conceptual clustering system finds 3 classes in the data. +% +% 4. Relevant Information: +% --- This is perhaps the best known database to be found in the pattern +% recognition literature. Fisher's paper is a classic in the field +% and is referenced frequently to this day. (See Duda & Hart, for +% example.) The data set contains 3 classes of 50 instances each, +% where each class refers to a type of iris plant. One class is +% linearly separable from the other 2; the latter are NOT linearly +% separable from each other. +% --- Predicted attribute: class of iris plant. +% --- This is an exceedingly simple domain. +% +% 5. Number of Instances: 150 (50 in each of three classes) +% +% 6. Number of Attributes: 4 numeric, predictive attributes and the class +% +% 7. Attribute Information: +% 1. sepal length in cm +% 2. sepal width in cm +% 3. petal length in cm +% 4. petal width in cm +% 5. class: +% -- Iris Setosa +% -- Iris Versicolour +% -- Iris Virginica +% +% 8. Missing Attribute Values: None +% +% Summary Statistics: +% Min Max Mean SD Class Correlation +% sepal length: 4.3 7.9 5.84 0.83 0.7826 +% sepal width: 2.0 4.4 3.05 0.43 -0.4194 +% petal length: 1.0 6.9 3.76 1.76 0.9490 (high!) +% petal width: 0.1 2.5 1.20 0.76 0.9565 (high!) +% +% 9. Class Distribution: 33.3% for each of 3 classes. + +@RELATION iris + +@ATTRIBUTE sepallength REAL +@ATTRIBUTE sepalwidth REAL +@ATTRIBUTE petallength REAL +@ATTRIBUTE petalwidth REAL +@ATTRIBUTE class {Iris-setosa,Iris-versicolor,Iris-virginica} + +@DATA +5.1,3.5,1.4,0.2,Iris-setosa +4.9,3.0,1.4,0.2,Iris-setosa +4.7,3.2,1.3,0.2,Iris-setosa +4.6,3.1,1.5,0.2,Iris-setosa +5.0,3.6,1.4,0.2,Iris-setosa +5.4,3.9,1.7,0.4,Iris-setosa +4.6,3.4,1.4,0.3,Iris-setosa +5.0,3.4,1.5,0.2,Iris-setosa +4.4,2.9,1.4,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +5.4,3.7,1.5,0.2,Iris-setosa +4.8,3.4,1.6,0.2,Iris-setosa +4.8,3.0,1.4,0.1,Iris-setosa +4.3,3.0,1.1,0.1,Iris-setosa +5.8,4.0,1.2,0.2,Iris-setosa +5.7,4.4,1.5,0.4,Iris-setosa +5.4,3.9,1.3,0.4,Iris-setosa +5.1,3.5,1.4,0.3,Iris-setosa +5.7,3.8,1.7,0.3,Iris-setosa +5.1,3.8,1.5,0.3,Iris-setosa +5.4,3.4,1.7,0.2,Iris-setosa +5.1,3.7,1.5,0.4,Iris-setosa +4.6,3.6,1.0,0.2,Iris-setosa +5.1,3.3,1.7,0.5,Iris-setosa +4.8,3.4,1.9,0.2,Iris-setosa +5.0,3.0,1.6,0.2,Iris-setosa +5.0,3.4,1.6,0.4,Iris-setosa +5.2,3.5,1.5,0.2,Iris-setosa +5.2,3.4,1.4,0.2,Iris-setosa +4.7,3.2,1.6,0.2,Iris-setosa +4.8,3.1,1.6,0.2,Iris-setosa +5.4,3.4,1.5,0.4,Iris-setosa +5.2,4.1,1.5,0.1,Iris-setosa +5.5,4.2,1.4,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +5.0,3.2,1.2,0.2,Iris-setosa +5.5,3.5,1.3,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +4.4,3.0,1.3,0.2,Iris-setosa +5.1,3.4,1.5,0.2,Iris-setosa +5.0,3.5,1.3,0.3,Iris-setosa +4.5,2.3,1.3,0.3,Iris-setosa +4.4,3.2,1.3,0.2,Iris-setosa +5.0,3.5,1.6,0.6,Iris-setosa +5.1,3.8,1.9,0.4,Iris-setosa +4.8,3.0,1.4,0.3,Iris-setosa +5.1,3.8,1.6,0.2,Iris-setosa +4.6,3.2,1.4,0.2,Iris-setosa +5.3,3.7,1.5,0.2,Iris-setosa +5.0,3.3,1.4,0.2,Iris-setosa +7.0,3.2,4.7,1.4,Iris-versicolor +6.4,3.2,4.5,1.5,Iris-versicolor +6.9,3.1,4.9,1.5,Iris-versicolor +5.5,2.3,4.0,1.3,Iris-versicolor +6.5,2.8,4.6,1.5,Iris-versicolor +5.7,2.8,4.5,1.3,Iris-versicolor +6.3,3.3,4.7,1.6,Iris-versicolor +4.9,2.4,3.3,1.0,Iris-versicolor +6.6,2.9,4.6,1.3,Iris-versicolor +5.2,2.7,3.9,1.4,Iris-versicolor +5.0,2.0,3.5,1.0,Iris-versicolor +5.9,3.0,4.2,1.5,Iris-versicolor +6.0,2.2,4.0,1.0,Iris-versicolor +6.1,2.9,4.7,1.4,Iris-versicolor +5.6,2.9,3.6,1.3,Iris-versicolor +6.7,3.1,4.4,1.4,Iris-versicolor +5.6,3.0,4.5,1.5,Iris-versicolor +5.8,2.7,4.1,1.0,Iris-versicolor +6.2,2.2,4.5,1.5,Iris-versicolor +5.6,2.5,3.9,1.1,Iris-versicolor +5.9,3.2,4.8,1.8,Iris-versicolor +6.1,2.8,4.0,1.3,Iris-versicolor +6.3,2.5,4.9,1.5,Iris-versicolor +6.1,2.8,4.7,1.2,Iris-versicolor +6.4,2.9,4.3,1.3,Iris-versicolor +6.6,3.0,4.4,1.4,Iris-versicolor +6.8,2.8,4.8,1.4,Iris-versicolor +6.7,3.0,5.0,1.7,Iris-versicolor +6.0,2.9,4.5,1.5,Iris-versicolor +5.7,2.6,3.5,1.0,Iris-versicolor +5.5,2.4,3.8,1.1,Iris-versicolor +5.5,2.4,3.7,1.0,Iris-versicolor +5.8,2.7,3.9,1.2,Iris-versicolor +6.0,2.7,5.1,1.6,Iris-versicolor +5.4,3.0,4.5,1.5,Iris-versicolor +6.0,3.4,4.5,1.6,Iris-versicolor +6.7,3.1,4.7,1.5,Iris-versicolor +6.3,2.3,4.4,1.3,Iris-versicolor +5.6,3.0,4.1,1.3,Iris-versicolor +5.5,2.5,4.0,1.3,Iris-versicolor +5.5,2.6,4.4,1.2,Iris-versicolor +6.1,3.0,4.6,1.4,Iris-versicolor +5.8,2.6,4.0,1.2,Iris-versicolor +5.0,2.3,3.3,1.0,Iris-versicolor +5.6,2.7,4.2,1.3,Iris-versicolor +5.7,3.0,4.2,1.2,Iris-versicolor +5.7,2.9,4.2,1.3,Iris-versicolor +6.2,2.9,4.3,1.3,Iris-versicolor +5.1,2.5,3.0,1.1,Iris-versicolor +5.7,2.8,4.1,1.3,Iris-versicolor +6.3,3.3,6.0,2.5,Iris-virginica +5.8,2.7,5.1,1.9,Iris-virginica +7.1,3.0,5.9,2.1,Iris-virginica +6.3,2.9,5.6,1.8,Iris-virginica +6.5,3.0,5.8,2.2,Iris-virginica +7.6,3.0,6.6,2.1,Iris-virginica +4.9,2.5,4.5,1.7,Iris-virginica +7.3,2.9,6.3,1.8,Iris-virginica +6.7,2.5,5.8,1.8,Iris-virginica +7.2,3.6,6.1,2.5,Iris-virginica +6.5,3.2,5.1,2.0,Iris-virginica +6.4,2.7,5.3,1.9,Iris-virginica +6.8,3.0,5.5,2.1,Iris-virginica +5.7,2.5,5.0,2.0,Iris-virginica +5.8,2.8,5.1,2.4,Iris-virginica +6.4,3.2,5.3,2.3,Iris-virginica +6.5,3.0,5.5,1.8,Iris-virginica +7.7,3.8,6.7,2.2,Iris-virginica +7.7,2.6,6.9,2.3,Iris-virginica +6.0,2.2,5.0,1.5,Iris-virginica +6.9,3.2,5.7,2.3,Iris-virginica +5.6,2.8,4.9,2.0,Iris-virginica +7.7,2.8,6.7,2.0,Iris-virginica +6.3,2.7,4.9,1.8,Iris-virginica +6.7,3.3,5.7,2.1,Iris-virginica +7.2,3.2,6.0,1.8,Iris-virginica +6.2,2.8,4.8,1.8,Iris-virginica +6.1,3.0,4.9,1.8,Iris-virginica +6.4,2.8,5.6,2.1,Iris-virginica +7.2,3.0,5.8,1.6,Iris-virginica +7.4,2.8,6.1,1.9,Iris-virginica +7.9,3.8,6.4,2.0,Iris-virginica +6.4,2.8,5.6,2.2,Iris-virginica +6.3,2.8,5.1,1.5,Iris-virginica +6.1,2.6,5.6,1.4,Iris-virginica +7.7,3.0,6.1,2.3,Iris-virginica +6.3,3.4,5.6,2.4,Iris-virginica +6.4,3.1,5.5,1.8,Iris-virginica +6.0,3.0,4.8,1.8,Iris-virginica +6.9,3.1,5.4,2.1,Iris-virginica +6.7,3.1,5.6,2.4,Iris-virginica +6.9,3.1,5.1,2.3,Iris-virginica +5.8,2.7,5.1,1.9,Iris-virginica +6.8,3.2,5.9,2.3,Iris-virginica +6.7,3.3,5.7,2.5,Iris-virginica +6.7,3.0,5.2,2.3,Iris-virginica +6.3,2.5,5.0,1.9,Iris-virginica +6.5,3.0,5.2,2.0,Iris-virginica +6.2,3.4,5.4,2.3,Iris-virginica +5.9,3.0,5.1,1.8,Iris-virginica +% +% +% diff --git a/main.cc b/main.cc new file mode 100644 index 0000000..8a1d6d1 --- /dev/null +++ b/main.cc @@ -0,0 +1,43 @@ +#include +#include +#include +#include "ArffFiles.h" +#include "Network.h" + +using namespace std; + +int main() +{ + auto handler = ArffFiles(); + handler.load("iris.arff"); + auto X = handler.getX(); + auto y = handler.getY(); + auto className = handler.getClassName(); + vector> edges = { {className, "sepallength"}, {className, "sepalwidth"}, {className, "petallength"}, {className, "petalwidth"} }; + auto network = bayesnet::Network(); + // Add nodes to the network + for (auto feature : handler.getAttributes()) { + cout << "Adding feature: " << feature.first << endl; + network.addNode(feature.first, 7); + } + network.addNode(className, 3); + for (auto item : edges) { + network.addEdge(item.first, item.second); + } + cout << "Hello, Bayesian Networks!" << endl; + torch::Tensor tensor = torch::eye(3); + cout << tensor << std::endl; + cout << "Nodes:" << endl; + for (auto [name, item] : network.getNodes()) { + cout << "*" << item->getName() << endl; + cout << "-Parents:" << endl; + for (auto parent : item->getParents()) { + cout << " " << parent->getName() << endl; + } + cout << "-Children:" << endl; + for (auto child : item->getChildren()) { + cout << " " << child->getName() << endl; + } + } + return 0; +} \ No newline at end of file diff --git a/simple/Network.cc b/simple/Network.cc new file mode 100644 index 0000000..f16a1dd --- /dev/null +++ b/simple/Network.cc @@ -0,0 +1,47 @@ +#include +#include +#include +#include "Network.h" + +namespace bayesnet { + Network::Network() {} + + Network::~Network() + { + for (auto& pair : nodes) { + delete pair.second; + } + } + + void Network::addNode(std::string name) + { + nodes[name] = new Node(name); + } + + void Network::addEdge(std::string parentName, std::string childName) + { + Node* parent = nodes[parentName]; + Node* child = nodes[childName]; + + if (parent == nullptr || child == nullptr) { + throw std::invalid_argument("Parent or child node not found."); + } + + child->addParent(parent); + } + + // to be implemented + void Network::fit(const std::vector>& dataset) + { + // ... learn parameters (i.e., CPTs) using the dataset + } + + // to be implemented + std::vector Network::predict(const std::vector>& testset) + { + std::vector predictions; + // ... use the CPTs and network structure to predict values + return predictions; + } +} + diff --git a/simple/Network.h b/simple/Network.h new file mode 100644 index 0000000..d09ac72 --- /dev/null +++ b/simple/Network.h @@ -0,0 +1,21 @@ +#ifndef NETWORK_H +#define NETWORK_H +#include +#include +#include +#include "Node.h" + +namespace bayesnet { + class Network { + private: + std::map nodes; + public: + Network(); + ~Network(); + void addNode(std::string); + void addEdge(std::string, std::string); + void fit(const std::vector>&); + std::vector predict(const std::vector>&); + }; +} +#endif diff --git a/simple/Node.cc b/simple/Node.cc new file mode 100644 index 0000000..0d04c87 --- /dev/null +++ b/simple/Node.cc @@ -0,0 +1,14 @@ +#include +#include +#include +#include "Node.h" + +namespace bayesnet { + Node::Node(std::string name) : name(name) {} + + void Node::addParent(Node* parent) + { + parents.push_back(parent); + parent->children.push_back(this); + } +} diff --git a/simple/Node.h b/simple/Node.h new file mode 100644 index 0000000..7f017f9 --- /dev/null +++ b/simple/Node.h @@ -0,0 +1,18 @@ +#ifndef NODE_H +#define NODE_H +#include +#include +#include +namespace bayesnet { + class Node { + private: + std::string name; + std::vector parents; + std::vector children; + std::map, double> cpt; // Conditional Probability Table + public: + Node(std::string); + void addParent(Node*); + }; +} +#endif diff --git a/test.cc b/test.cc new file mode 100644 index 0000000..adbbb08 --- /dev/null +++ b/test.cc @@ -0,0 +1,23 @@ +#include +#include +#include + +using namespace std; + +int main(int argc, char const* argv[]) +{ + map m; + m["a"] = 1; + m["b"] = 2; + m["c"] = 3; + if (m.find("b") != m.end()) { + cout << "Found b" << endl; + } else { + cout << "Not found b" << endl; + } + // for (auto [key, value] : m) { + // cout << key << " " << value << endl; + // } + + return 0; +}