Begin Network build

This commit is contained in:
Ricardo Montañana Gómez 2023-06-29 22:00:41 +02:00
parent 00ed1e0be1
commit 8e9c3483aa
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
18 changed files with 874 additions and 0 deletions

2
.gitignore vendored
View File

@ -31,4 +31,6 @@
*.exe *.exe
*.out *.out
*.app *.app
build/
*.dSYM/**

25
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,25 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "bayesnet",
"program": "${workspaceFolder}/build/bayesnet",
"args": [],
"cwd": "${workspaceFolder}",
"preLaunchTask": "CMake: build"
},
{
"type": "lldb",
"request": "launch",
"name": "aout",
"program": "${workspaceFolder}/a.out",
"args": [],
"cwd": "${workspaceFolder}"
}
]
}

86
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,86 @@
{
"files.associations": {
"*.rmd": "markdown",
"*.py": "python",
"vector": "cpp",
"__bit_reference": "cpp",
"__bits": "cpp",
"__config": "cpp",
"__debug": "cpp",
"__errc": "cpp",
"__hash_table": "cpp",
"__locale": "cpp",
"__mutex_base": "cpp",
"__node_handle": "cpp",
"__nullptr": "cpp",
"__split_buffer": "cpp",
"__string": "cpp",
"__threading_support": "cpp",
"__tuple": "cpp",
"array": "cpp",
"atomic": "cpp",
"bitset": "cpp",
"cctype": "cpp",
"chrono": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"compare": "cpp",
"complex": "cpp",
"concepts": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"exception": "cpp",
"initializer_list": "cpp",
"ios": "cpp",
"iosfwd": "cpp",
"istream": "cpp",
"limits": "cpp",
"locale": "cpp",
"memory": "cpp",
"mutex": "cpp",
"new": "cpp",
"optional": "cpp",
"ostream": "cpp",
"ratio": "cpp",
"sstream": "cpp",
"stdexcept": "cpp",
"streambuf": "cpp",
"string": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"typeinfo": "cpp",
"unordered_map": "cpp",
"variant": "cpp",
"algorithm": "cpp",
"iostream": "cpp",
"iomanip": "cpp",
"numeric": "cpp",
"set": "cpp",
"__tree": "cpp",
"deque": "cpp",
"list": "cpp",
"map": "cpp",
"unordered_set": "cpp",
"any": "cpp",
"condition_variable": "cpp",
"forward_list": "cpp",
"fstream": "cpp",
"stack": "cpp",
"thread": "cpp",
"__memory": "cpp",
"filesystem": "cpp",
"*.toml": "toml",
"utility": "cpp",
"__verbose_abort": "cpp",
"bit": "cpp"
}
}

16
.vscode/tasks.json vendored Normal file
View File

@ -0,0 +1,16 @@
{
"version": "2.0.0",
"tasks": [
{
"type": "cmake",
"label": "CMake: build",
"command": "build",
"targets": [
"all"
],
"group": "build",
"problemMatcher": [],
"detail": "CMake template build task"
}
]
}

132
ArffFiles.cc Normal file
View File

@ -0,0 +1,132 @@
#include "ArffFiles.h"
#include <fstream>
#include <sstream>
#include <map>
using namespace std;
ArffFiles::ArffFiles() = default;
vector<string> ArffFiles::getLines() const
{
return lines;
}
unsigned long int ArffFiles::getSize() const
{
return lines.size();
}
vector<pair<string, string>> ArffFiles::getAttributes() const
{
return attributes;
}
string ArffFiles::getClassName() const
{
return className;
}
string ArffFiles::getClassType() const
{
return classType;
}
vector<vector<float>>& ArffFiles::getX()
{
return X;
}
vector<int>& ArffFiles::getY()
{
return y;
}
void ArffFiles::load(const string& fileName, bool classLast)
{
ifstream file(fileName);
if (!file.is_open()) {
throw invalid_argument("Unable to open file");
}
string line;
string keyword;
string attribute;
string type;
string type_w;
while (getline(file, line)) {
if (line.empty() || line[0] == '%' || line == "\r" || line == " ") {
continue;
}
if (line.find("@attribute") != string::npos || line.find("@ATTRIBUTE") != string::npos) {
stringstream ss(line);
ss >> keyword >> attribute;
type = "";
while (ss >> type_w)
type += type_w + " ";
attributes.emplace_back(attribute, trim(type));
continue;
}
if (line[0] == '@') {
continue;
}
lines.push_back(line);
}
file.close();
if (attributes.empty())
throw invalid_argument("No attributes found");
if (classLast) {
className = get<0>(attributes.back());
classType = get<1>(attributes.back());
attributes.pop_back();
} else {
className = get<0>(attributes.front());
classType = get<1>(attributes.front());
attributes.erase(attributes.begin());
}
generateDataset(classLast);
}
void ArffFiles::generateDataset(bool classLast)
{
X = vector<vector<float>>(attributes.size(), vector<float>(lines.size()));
auto yy = vector<string>(lines.size(), "");
int labelIndex = classLast ? static_cast<int>(attributes.size()) : 0;
for (size_t i = 0; i < lines.size(); i++) {
stringstream ss(lines[i]);
string value;
int pos = 0;
int xIndex = 0;
while (getline(ss, value, ',')) {
if (pos++ == labelIndex) {
yy[i] = value;
} else {
X[xIndex++][i] = stof(value);
}
}
}
y = factorize(yy);
}
string ArffFiles::trim(const string& source)
{
string s(source);
s.erase(0, s.find_first_not_of(" \n\r\t"));
s.erase(s.find_last_not_of(" \n\r\t") + 1);
return s;
}
vector<int> ArffFiles::factorize(const vector<string>& labels_t)
{
vector<int> yy;
yy.reserve(labels_t.size());
map<string, int> labelMap;
int i = 0;
for (const string& label : labels_t) {
if (labelMap.find(label) == labelMap.end()) {
labelMap[label] = i++;
}
yy.push_back(labelMap[label]);
}
return yy;
}

34
ArffFiles.h Normal file
View File

@ -0,0 +1,34 @@
#ifndef ARFFFILES_H
#define ARFFFILES_H
#include <string>
#include <vector>
using namespace std;
class ArffFiles {
private:
vector<string> lines;
vector<pair<string, string>> attributes;
string className;
string classType;
vector<vector<float>> X;
vector<int> y;
void generateDataset(bool);
public:
ArffFiles();
void load(const string&, bool = true);
vector<string> getLines() const;
unsigned long int getSize() const;
string getClassName() const;
string getClassType() const;
static string trim(const string&);
vector<vector<float>>& getX();
vector<int>& getY();
vector<pair<string, string>> getAttributes() const;
static vector<int> factorize(const vector<string>& labels_t);
};
#endif

16
CMakeLists.txt Normal file
View File

@ -0,0 +1,16 @@
cmake_minimum_required(VERSION 3.20)
project(bayesnet)
find_package(Torch REQUIRED)
if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
endif ()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
# add_library(BayesNet Node.cc Network.cc)
# add_executable(BayesNet main.cc Node.cc Network.cc ArffFiles.cc)
add_executable(BayesNet main.cc ArffFiles.cc Node.cc Network.cc)
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")

79
Network.cc Normal file
View File

@ -0,0 +1,79 @@
#include "Network.h"
namespace bayesnet {
Network::~Network()
{
for (auto& pair : nodes) {
delete pair.second;
}
}
void Network::addNode(std::string name, int numStates)
{
nodes[name] = new Node(name, numStates);
}
void Network::addEdge(const std::string parent, const std::string child)
{
if (nodes.find(parent) == nodes.end()) {
throw std::invalid_argument("Parent node " + parent + " does not exist");
}
if (nodes.find(child) == nodes.end()) {
throw std::invalid_argument("Child node " + child + " does not exist");
}
nodes[parent]->addChild(nodes[child]);
nodes[child]->addParent(nodes[parent]);
}
std::map<std::string, Node*>& Network::getNodes()
{
return nodes;
}
void Network::fit(const std::vector<std::vector<int>>& dataset, const int smoothing)
{
auto jointCounts = [](const std::vector<std::vector<int>>& data, const std::vector<int>& indices, int numStates) {
int size = indices.size();
std::vector<int64_t> sizes(size, numStates);
torch::Tensor counts = torch::zeros(sizes, torch::kLong);
for (const auto& row : data) {
int idx = 0;
for (int i = 0; i < size; ++i) {
idx = idx * numStates + row[indices[i]];
}
counts.view({ -1 }).add_(idx, 1);
}
return counts;
};
auto marginalCounts = [](const torch::Tensor& jointCounts) {
return jointCounts.sum(-1);
};
for (auto& pair : nodes) {
Node* node = pair.second;
std::vector<int> indices;
for (const auto& parent : node->getParents()) {
indices.push_back(nodes[parent->getName()]->getId());
}
indices.push_back(node->getId());
for (auto& child : node->getChildren()) {
torch::Tensor counts = jointCounts(dataset, indices, node->getNumStates()) + smoothing;
torch::Tensor parentCounts = marginalCounts(counts);
parentCounts = parentCounts.unsqueeze(-1);
torch::Tensor cpt = counts.to(torch::kDouble) / parentCounts.to(torch::kDouble);
setCPD(node->getCPDKey(child), cpt);
}
}
}
torch::Tensor& Network::getCPD(const std::string& key)
{
return cpds[key];
}
void Network::setCPD(const std::string& key, const torch::Tensor& cpt)
{
cpds[key] = cpt;
}
}

21
Network.h Normal file
View File

@ -0,0 +1,21 @@
#ifndef NETWORK_H
#define NETWORK_H
#include "Node.h"
#include <map>
#include <vector>
namespace bayesnet {
class Network {
private:
map<string, Node*> nodes;
map<string, torch::Tensor> cpds; // Map from CPD key to CPD tensor
public:
~Network();
void addNode(string, int);
void addEdge(const string, const string);
map<string, Node*>& getNodes();
void fit(const vector<vector<int>>&, const int);
torch::Tensor& getCPD(const string&);
void setCPD(const string&, const torch::Tensor&);
};
}
#endif

41
Node.cc Normal file
View File

@ -0,0 +1,41 @@
#include "Node.h"
namespace bayesnet {
int Node::next_id = 0;
Node::Node(const std::string& name, int numStates)
: id(next_id++), name(name), numStates(numStates), cpt(torch::Tensor()), parents(vector<Node*>()), children(vector<Node*>())
{
}
string Node::getName() const
{
return name;
}
void Node::addParent(Node* parent)
{
parents.push_back(parent);
}
void Node::addChild(Node* child)
{
children.push_back(child);
}
vector<Node*>& Node::getParents()
{
return parents;
}
vector<Node*>& Node::getChildren()
{
return children;
}
int Node::getNumStates() const
{
return numStates;
}
string Node::getCPDKey(const Node* child) const
{
return name + "-" + child->getName();
}
}

31
Node.h Normal file
View File

@ -0,0 +1,31 @@
#ifndef NODE_H
#define NODE_H
#include <torch/torch.h>
#include <vector>
#include <string>
namespace bayesnet {
using namespace std;
class Node {
private:
static int next_id;
const int id;
string name;
vector<Node*> parents;
vector<Node*> children;
int numStates;
torch::Tensor cpt;
public:
Node(const std::string& name, int numStates);
void addParent(Node* parent);
void addChild(Node* child);
string getName() const;
vector<Node*>& getParents();
vector<Node*>& getChildren();
torch::Tensor& getCPT();
void setCPT(const torch::Tensor& cpt);
int getNumStates() const;
int getId() const { return id; }
string getCPDKey(const Node*) const;
};
}
#endif

225
iris.arff Executable file
View File

@ -0,0 +1,225 @@
% 1. Title: Iris Plants Database
%
% 2. Sources:
% (a) Creator: R.A. Fisher
% (b) Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
% (c) Date: July, 1988
%
% 3. Past Usage:
% - Publications: too many to mention!!! Here are a few.
% 1. Fisher,R.A. "The use of multiple measurements in taxonomic problems"
% Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions
% to Mathematical Statistics" (John Wiley, NY, 1950).
% 2. Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
% (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.
% 3. Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
% Structure and Classification Rule for Recognition in Partially Exposed
% Environments". IEEE Transactions on Pattern Analysis and Machine
% Intelligence, Vol. PAMI-2, No. 1, 67-71.
% -- Results:
% -- very low misclassification rates (0% for the setosa class)
% 4. Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE
% Transactions on Information Theory, May 1972, 431-433.
% -- Results:
% -- very low misclassification rates again
% 5. See also: 1988 MLC Proceedings, 54-64. Cheeseman et al's AUTOCLASS II
% conceptual clustering system finds 3 classes in the data.
%
% 4. Relevant Information:
% --- This is perhaps the best known database to be found in the pattern
% recognition literature. Fisher's paper is a classic in the field
% and is referenced frequently to this day. (See Duda & Hart, for
% example.) The data set contains 3 classes of 50 instances each,
% where each class refers to a type of iris plant. One class is
% linearly separable from the other 2; the latter are NOT linearly
% separable from each other.
% --- Predicted attribute: class of iris plant.
% --- This is an exceedingly simple domain.
%
% 5. Number of Instances: 150 (50 in each of three classes)
%
% 6. Number of Attributes: 4 numeric, predictive attributes and the class
%
% 7. Attribute Information:
% 1. sepal length in cm
% 2. sepal width in cm
% 3. petal length in cm
% 4. petal width in cm
% 5. class:
% -- Iris Setosa
% -- Iris Versicolour
% -- Iris Virginica
%
% 8. Missing Attribute Values: None
%
% Summary Statistics:
% Min Max Mean SD Class Correlation
% sepal length: 4.3 7.9 5.84 0.83 0.7826
% sepal width: 2.0 4.4 3.05 0.43 -0.4194
% petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)
% petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)
%
% 9. Class Distribution: 33.3% for each of 3 classes.
@RELATION iris
@ATTRIBUTE sepallength REAL
@ATTRIBUTE sepalwidth REAL
@ATTRIBUTE petallength REAL
@ATTRIBUTE petalwidth REAL
@ATTRIBUTE class {Iris-setosa,Iris-versicolor,Iris-virginica}
@DATA
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica
%
%
%

43
main.cc Normal file
View File

@ -0,0 +1,43 @@
#include <iostream>
#include <string>
#include <torch/torch.h>
#include "ArffFiles.h"
#include "Network.h"
using namespace std;
int main()
{
auto handler = ArffFiles();
handler.load("iris.arff");
auto X = handler.getX();
auto y = handler.getY();
auto className = handler.getClassName();
vector<pair<string, string>> edges = { {className, "sepallength"}, {className, "sepalwidth"}, {className, "petallength"}, {className, "petalwidth"} };
auto network = bayesnet::Network();
// Add nodes to the network
for (auto feature : handler.getAttributes()) {
cout << "Adding feature: " << feature.first << endl;
network.addNode(feature.first, 7);
}
network.addNode(className, 3);
for (auto item : edges) {
network.addEdge(item.first, item.second);
}
cout << "Hello, Bayesian Networks!" << endl;
torch::Tensor tensor = torch::eye(3);
cout << tensor << std::endl;
cout << "Nodes:" << endl;
for (auto [name, item] : network.getNodes()) {
cout << "*" << item->getName() << endl;
cout << "-Parents:" << endl;
for (auto parent : item->getParents()) {
cout << " " << parent->getName() << endl;
}
cout << "-Children:" << endl;
for (auto child : item->getChildren()) {
cout << " " << child->getName() << endl;
}
}
return 0;
}

47
simple/Network.cc Normal file
View File

@ -0,0 +1,47 @@
#include <string>
#include <vector>
#include <map>
#include "Network.h"
namespace bayesnet {
Network::Network() {}
Network::~Network()
{
for (auto& pair : nodes) {
delete pair.second;
}
}
void Network::addNode(std::string name)
{
nodes[name] = new Node(name);
}
void Network::addEdge(std::string parentName, std::string childName)
{
Node* parent = nodes[parentName];
Node* child = nodes[childName];
if (parent == nullptr || child == nullptr) {
throw std::invalid_argument("Parent or child node not found.");
}
child->addParent(parent);
}
// to be implemented
void Network::fit(const std::vector<std::vector<double>>& dataset)
{
// ... learn parameters (i.e., CPTs) using the dataset
}
// to be implemented
std::vector<double> Network::predict(const std::vector<std::vector<double>>& testset)
{
std::vector<double> predictions;
// ... use the CPTs and network structure to predict values
return predictions;
}
}

21
simple/Network.h Normal file
View File

@ -0,0 +1,21 @@
#ifndef NETWORK_H
#define NETWORK_H
#include <string>
#include <vector>
#include <map>
#include "Node.h"
namespace bayesnet {
class Network {
private:
std::map<std::string, Node*> nodes;
public:
Network();
~Network();
void addNode(std::string);
void addEdge(std::string, std::string);
void fit(const std::vector<std::vector<double>>&);
std::vector<double> predict(const std::vector<std::vector<double>>&);
};
}
#endif

14
simple/Node.cc Normal file
View File

@ -0,0 +1,14 @@
#include <string>
#include <vector>
#include <map>
#include "Node.h"
namespace bayesnet {
Node::Node(std::string name) : name(name) {}
void Node::addParent(Node* parent)
{
parents.push_back(parent);
parent->children.push_back(this);
}
}

18
simple/Node.h Normal file
View File

@ -0,0 +1,18 @@
#ifndef NODE_H
#define NODE_H
#include <string>
#include <vector>
#include <map>
namespace bayesnet {
class Node {
private:
std::string name;
std::vector<Node*> parents;
std::vector<Node*> children;
std::map<std::vector<bool>, double> cpt; // Conditional Probability Table
public:
Node(std::string);
void addParent(Node*);
};
}
#endif

23
test.cc Normal file
View File

@ -0,0 +1,23 @@
#include <map>
#include <string>
#include <iostream>
using namespace std;
int main(int argc, char const* argv[])
{
map<string, int> m;
m["a"] = 1;
m["b"] = 2;
m["c"] = 3;
if (m.find("b") != m.end()) {
cout << "Found b" << endl;
} else {
cout << "Not found b" << endl;
}
// for (auto [key, value] : m) {
// cout << key << " " << value << endl;
// }
return 0;
}