Begin Network build
This commit is contained in:
parent
00ed1e0be1
commit
8e9c3483aa
2
.gitignore
vendored
2
.gitignore
vendored
@ -31,4 +31,6 @@
|
||||
*.exe
|
||||
*.out
|
||||
*.app
|
||||
build/
|
||||
*.dSYM/**
|
||||
|
||||
|
25
.vscode/launch.json
vendored
Normal file
25
.vscode/launch.json
vendored
Normal file
@ -0,0 +1,25 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "bayesnet",
|
||||
"program": "${workspaceFolder}/build/bayesnet",
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"preLaunchTask": "CMake: build"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "aout",
|
||||
"program": "${workspaceFolder}/a.out",
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
}
|
||||
]
|
||||
}
|
86
.vscode/settings.json
vendored
Normal file
86
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,86 @@
|
||||
{
|
||||
"files.associations": {
|
||||
"*.rmd": "markdown",
|
||||
"*.py": "python",
|
||||
"vector": "cpp",
|
||||
"__bit_reference": "cpp",
|
||||
"__bits": "cpp",
|
||||
"__config": "cpp",
|
||||
"__debug": "cpp",
|
||||
"__errc": "cpp",
|
||||
"__hash_table": "cpp",
|
||||
"__locale": "cpp",
|
||||
"__mutex_base": "cpp",
|
||||
"__node_handle": "cpp",
|
||||
"__nullptr": "cpp",
|
||||
"__split_buffer": "cpp",
|
||||
"__string": "cpp",
|
||||
"__threading_support": "cpp",
|
||||
"__tuple": "cpp",
|
||||
"array": "cpp",
|
||||
"atomic": "cpp",
|
||||
"bitset": "cpp",
|
||||
"cctype": "cpp",
|
||||
"chrono": "cpp",
|
||||
"clocale": "cpp",
|
||||
"cmath": "cpp",
|
||||
"compare": "cpp",
|
||||
"complex": "cpp",
|
||||
"concepts": "cpp",
|
||||
"cstdarg": "cpp",
|
||||
"cstddef": "cpp",
|
||||
"cstdint": "cpp",
|
||||
"cstdio": "cpp",
|
||||
"cstdlib": "cpp",
|
||||
"cstring": "cpp",
|
||||
"ctime": "cpp",
|
||||
"cwchar": "cpp",
|
||||
"cwctype": "cpp",
|
||||
"exception": "cpp",
|
||||
"initializer_list": "cpp",
|
||||
"ios": "cpp",
|
||||
"iosfwd": "cpp",
|
||||
"istream": "cpp",
|
||||
"limits": "cpp",
|
||||
"locale": "cpp",
|
||||
"memory": "cpp",
|
||||
"mutex": "cpp",
|
||||
"new": "cpp",
|
||||
"optional": "cpp",
|
||||
"ostream": "cpp",
|
||||
"ratio": "cpp",
|
||||
"sstream": "cpp",
|
||||
"stdexcept": "cpp",
|
||||
"streambuf": "cpp",
|
||||
"string": "cpp",
|
||||
"string_view": "cpp",
|
||||
"system_error": "cpp",
|
||||
"tuple": "cpp",
|
||||
"type_traits": "cpp",
|
||||
"typeinfo": "cpp",
|
||||
"unordered_map": "cpp",
|
||||
"variant": "cpp",
|
||||
"algorithm": "cpp",
|
||||
"iostream": "cpp",
|
||||
"iomanip": "cpp",
|
||||
"numeric": "cpp",
|
||||
"set": "cpp",
|
||||
"__tree": "cpp",
|
||||
"deque": "cpp",
|
||||
"list": "cpp",
|
||||
"map": "cpp",
|
||||
"unordered_set": "cpp",
|
||||
"any": "cpp",
|
||||
"condition_variable": "cpp",
|
||||
"forward_list": "cpp",
|
||||
"fstream": "cpp",
|
||||
"stack": "cpp",
|
||||
"thread": "cpp",
|
||||
"__memory": "cpp",
|
||||
"filesystem": "cpp",
|
||||
"*.toml": "toml",
|
||||
"utility": "cpp",
|
||||
"__verbose_abort": "cpp",
|
||||
"bit": "cpp"
|
||||
}
|
||||
}
|
16
.vscode/tasks.json
vendored
Normal file
16
.vscode/tasks.json
vendored
Normal file
@ -0,0 +1,16 @@
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"type": "cmake",
|
||||
"label": "CMake: build",
|
||||
"command": "build",
|
||||
"targets": [
|
||||
"all"
|
||||
],
|
||||
"group": "build",
|
||||
"problemMatcher": [],
|
||||
"detail": "CMake template build task"
|
||||
}
|
||||
]
|
||||
}
|
132
ArffFiles.cc
Normal file
132
ArffFiles.cc
Normal file
@ -0,0 +1,132 @@
|
||||
#include "ArffFiles.h"
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <map>
|
||||
|
||||
using namespace std;
|
||||
|
||||
ArffFiles::ArffFiles() = default;
|
||||
|
||||
vector<string> ArffFiles::getLines() const
|
||||
{
|
||||
return lines;
|
||||
}
|
||||
|
||||
unsigned long int ArffFiles::getSize() const
|
||||
{
|
||||
return lines.size();
|
||||
}
|
||||
|
||||
vector<pair<string, string>> ArffFiles::getAttributes() const
|
||||
{
|
||||
return attributes;
|
||||
}
|
||||
|
||||
string ArffFiles::getClassName() const
|
||||
{
|
||||
return className;
|
||||
}
|
||||
|
||||
string ArffFiles::getClassType() const
|
||||
{
|
||||
return classType;
|
||||
}
|
||||
|
||||
vector<vector<float>>& ArffFiles::getX()
|
||||
{
|
||||
return X;
|
||||
}
|
||||
|
||||
vector<int>& ArffFiles::getY()
|
||||
{
|
||||
return y;
|
||||
}
|
||||
|
||||
void ArffFiles::load(const string& fileName, bool classLast)
|
||||
{
|
||||
ifstream file(fileName);
|
||||
if (!file.is_open()) {
|
||||
throw invalid_argument("Unable to open file");
|
||||
}
|
||||
string line;
|
||||
string keyword;
|
||||
string attribute;
|
||||
string type;
|
||||
string type_w;
|
||||
while (getline(file, line)) {
|
||||
if (line.empty() || line[0] == '%' || line == "\r" || line == " ") {
|
||||
continue;
|
||||
}
|
||||
if (line.find("@attribute") != string::npos || line.find("@ATTRIBUTE") != string::npos) {
|
||||
stringstream ss(line);
|
||||
ss >> keyword >> attribute;
|
||||
type = "";
|
||||
while (ss >> type_w)
|
||||
type += type_w + " ";
|
||||
attributes.emplace_back(attribute, trim(type));
|
||||
continue;
|
||||
}
|
||||
if (line[0] == '@') {
|
||||
continue;
|
||||
}
|
||||
lines.push_back(line);
|
||||
}
|
||||
file.close();
|
||||
if (attributes.empty())
|
||||
throw invalid_argument("No attributes found");
|
||||
if (classLast) {
|
||||
className = get<0>(attributes.back());
|
||||
classType = get<1>(attributes.back());
|
||||
attributes.pop_back();
|
||||
} else {
|
||||
className = get<0>(attributes.front());
|
||||
classType = get<1>(attributes.front());
|
||||
attributes.erase(attributes.begin());
|
||||
}
|
||||
generateDataset(classLast);
|
||||
|
||||
}
|
||||
|
||||
void ArffFiles::generateDataset(bool classLast)
|
||||
{
|
||||
X = vector<vector<float>>(attributes.size(), vector<float>(lines.size()));
|
||||
auto yy = vector<string>(lines.size(), "");
|
||||
int labelIndex = classLast ? static_cast<int>(attributes.size()) : 0;
|
||||
for (size_t i = 0; i < lines.size(); i++) {
|
||||
stringstream ss(lines[i]);
|
||||
string value;
|
||||
int pos = 0;
|
||||
int xIndex = 0;
|
||||
while (getline(ss, value, ',')) {
|
||||
if (pos++ == labelIndex) {
|
||||
yy[i] = value;
|
||||
} else {
|
||||
X[xIndex++][i] = stof(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
y = factorize(yy);
|
||||
}
|
||||
|
||||
string ArffFiles::trim(const string& source)
|
||||
{
|
||||
string s(source);
|
||||
s.erase(0, s.find_first_not_of(" \n\r\t"));
|
||||
s.erase(s.find_last_not_of(" \n\r\t") + 1);
|
||||
return s;
|
||||
}
|
||||
|
||||
vector<int> ArffFiles::factorize(const vector<string>& labels_t)
|
||||
{
|
||||
vector<int> yy;
|
||||
yy.reserve(labels_t.size());
|
||||
map<string, int> labelMap;
|
||||
int i = 0;
|
||||
for (const string& label : labels_t) {
|
||||
if (labelMap.find(label) == labelMap.end()) {
|
||||
labelMap[label] = i++;
|
||||
}
|
||||
yy.push_back(labelMap[label]);
|
||||
}
|
||||
return yy;
|
||||
}
|
34
ArffFiles.h
Normal file
34
ArffFiles.h
Normal file
@ -0,0 +1,34 @@
|
||||
#ifndef ARFFFILES_H
|
||||
#define ARFFFILES_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
|
||||
class ArffFiles {
|
||||
private:
|
||||
vector<string> lines;
|
||||
vector<pair<string, string>> attributes;
|
||||
string className;
|
||||
string classType;
|
||||
vector<vector<float>> X;
|
||||
vector<int> y;
|
||||
|
||||
void generateDataset(bool);
|
||||
|
||||
public:
|
||||
ArffFiles();
|
||||
void load(const string&, bool = true);
|
||||
vector<string> getLines() const;
|
||||
unsigned long int getSize() const;
|
||||
string getClassName() const;
|
||||
string getClassType() const;
|
||||
static string trim(const string&);
|
||||
vector<vector<float>>& getX();
|
||||
vector<int>& getY();
|
||||
vector<pair<string, string>> getAttributes() const;
|
||||
static vector<int> factorize(const vector<string>& labels_t);
|
||||
};
|
||||
|
||||
#endif
|
16
CMakeLists.txt
Normal file
16
CMakeLists.txt
Normal file
@ -0,0 +1,16 @@
|
||||
cmake_minimum_required(VERSION 3.20)
|
||||
|
||||
project(bayesnet)
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
if (POLICY CMP0135)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
endif ()
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
|
||||
|
||||
# add_library(BayesNet Node.cc Network.cc)
|
||||
# add_executable(BayesNet main.cc Node.cc Network.cc ArffFiles.cc)
|
||||
add_executable(BayesNet main.cc ArffFiles.cc Node.cc Network.cc)
|
||||
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
|
79
Network.cc
Normal file
79
Network.cc
Normal file
@ -0,0 +1,79 @@
|
||||
#include "Network.h"
|
||||
namespace bayesnet {
|
||||
Network::~Network()
|
||||
{
|
||||
for (auto& pair : nodes) {
|
||||
delete pair.second;
|
||||
}
|
||||
}
|
||||
void Network::addNode(std::string name, int numStates)
|
||||
{
|
||||
nodes[name] = new Node(name, numStates);
|
||||
}
|
||||
void Network::addEdge(const std::string parent, const std::string child)
|
||||
{
|
||||
if (nodes.find(parent) == nodes.end()) {
|
||||
throw std::invalid_argument("Parent node " + parent + " does not exist");
|
||||
}
|
||||
if (nodes.find(child) == nodes.end()) {
|
||||
throw std::invalid_argument("Child node " + child + " does not exist");
|
||||
}
|
||||
nodes[parent]->addChild(nodes[child]);
|
||||
nodes[child]->addParent(nodes[parent]);
|
||||
}
|
||||
std::map<std::string, Node*>& Network::getNodes()
|
||||
{
|
||||
return nodes;
|
||||
}
|
||||
void Network::fit(const std::vector<std::vector<int>>& dataset, const int smoothing)
|
||||
{
|
||||
auto jointCounts = [](const std::vector<std::vector<int>>& data, const std::vector<int>& indices, int numStates) {
|
||||
int size = indices.size();
|
||||
std::vector<int64_t> sizes(size, numStates);
|
||||
torch::Tensor counts = torch::zeros(sizes, torch::kLong);
|
||||
|
||||
for (const auto& row : data) {
|
||||
int idx = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
idx = idx * numStates + row[indices[i]];
|
||||
}
|
||||
counts.view({ -1 }).add_(idx, 1);
|
||||
}
|
||||
|
||||
return counts;
|
||||
};
|
||||
|
||||
auto marginalCounts = [](const torch::Tensor& jointCounts) {
|
||||
return jointCounts.sum(-1);
|
||||
};
|
||||
|
||||
for (auto& pair : nodes) {
|
||||
Node* node = pair.second;
|
||||
|
||||
std::vector<int> indices;
|
||||
for (const auto& parent : node->getParents()) {
|
||||
indices.push_back(nodes[parent->getName()]->getId());
|
||||
}
|
||||
indices.push_back(node->getId());
|
||||
|
||||
for (auto& child : node->getChildren()) {
|
||||
torch::Tensor counts = jointCounts(dataset, indices, node->getNumStates()) + smoothing;
|
||||
torch::Tensor parentCounts = marginalCounts(counts);
|
||||
parentCounts = parentCounts.unsqueeze(-1);
|
||||
|
||||
torch::Tensor cpt = counts.to(torch::kDouble) / parentCounts.to(torch::kDouble);
|
||||
setCPD(node->getCPDKey(child), cpt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor& Network::getCPD(const std::string& key)
|
||||
{
|
||||
return cpds[key];
|
||||
}
|
||||
|
||||
void Network::setCPD(const std::string& key, const torch::Tensor& cpt)
|
||||
{
|
||||
cpds[key] = cpt;
|
||||
}
|
||||
}
|
21
Network.h
Normal file
21
Network.h
Normal file
@ -0,0 +1,21 @@
|
||||
#ifndef NETWORK_H
|
||||
#define NETWORK_H
|
||||
#include "Node.h"
|
||||
#include <map>
|
||||
#include <vector>
|
||||
namespace bayesnet {
|
||||
class Network {
|
||||
private:
|
||||
map<string, Node*> nodes;
|
||||
map<string, torch::Tensor> cpds; // Map from CPD key to CPD tensor
|
||||
public:
|
||||
~Network();
|
||||
void addNode(string, int);
|
||||
void addEdge(const string, const string);
|
||||
map<string, Node*>& getNodes();
|
||||
void fit(const vector<vector<int>>&, const int);
|
||||
torch::Tensor& getCPD(const string&);
|
||||
void setCPD(const string&, const torch::Tensor&);
|
||||
};
|
||||
}
|
||||
#endif
|
41
Node.cc
Normal file
41
Node.cc
Normal file
@ -0,0 +1,41 @@
|
||||
#include "Node.h"
|
||||
|
||||
namespace bayesnet {
|
||||
int Node::next_id = 0;
|
||||
|
||||
Node::Node(const std::string& name, int numStates)
|
||||
: id(next_id++), name(name), numStates(numStates), cpt(torch::Tensor()), parents(vector<Node*>()), children(vector<Node*>())
|
||||
{
|
||||
}
|
||||
|
||||
string Node::getName() const
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
void Node::addParent(Node* parent)
|
||||
{
|
||||
parents.push_back(parent);
|
||||
}
|
||||
|
||||
void Node::addChild(Node* child)
|
||||
{
|
||||
children.push_back(child);
|
||||
}
|
||||
vector<Node*>& Node::getParents()
|
||||
{
|
||||
return parents;
|
||||
}
|
||||
vector<Node*>& Node::getChildren()
|
||||
{
|
||||
return children;
|
||||
}
|
||||
int Node::getNumStates() const
|
||||
{
|
||||
return numStates;
|
||||
}
|
||||
string Node::getCPDKey(const Node* child) const
|
||||
{
|
||||
return name + "-" + child->getName();
|
||||
}
|
||||
}
|
31
Node.h
Normal file
31
Node.h
Normal file
@ -0,0 +1,31 @@
|
||||
#ifndef NODE_H
|
||||
#define NODE_H
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
namespace bayesnet {
|
||||
using namespace std;
|
||||
class Node {
|
||||
private:
|
||||
static int next_id;
|
||||
const int id;
|
||||
string name;
|
||||
vector<Node*> parents;
|
||||
vector<Node*> children;
|
||||
int numStates;
|
||||
torch::Tensor cpt;
|
||||
public:
|
||||
Node(const std::string& name, int numStates);
|
||||
void addParent(Node* parent);
|
||||
void addChild(Node* child);
|
||||
string getName() const;
|
||||
vector<Node*>& getParents();
|
||||
vector<Node*>& getChildren();
|
||||
torch::Tensor& getCPT();
|
||||
void setCPT(const torch::Tensor& cpt);
|
||||
int getNumStates() const;
|
||||
int getId() const { return id; }
|
||||
string getCPDKey(const Node*) const;
|
||||
};
|
||||
}
|
||||
#endif
|
225
iris.arff
Executable file
225
iris.arff
Executable file
@ -0,0 +1,225 @@
|
||||
% 1. Title: Iris Plants Database
|
||||
%
|
||||
% 2. Sources:
|
||||
% (a) Creator: R.A. Fisher
|
||||
% (b) Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
|
||||
% (c) Date: July, 1988
|
||||
%
|
||||
% 3. Past Usage:
|
||||
% - Publications: too many to mention!!! Here are a few.
|
||||
% 1. Fisher,R.A. "The use of multiple measurements in taxonomic problems"
|
||||
% Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions
|
||||
% to Mathematical Statistics" (John Wiley, NY, 1950).
|
||||
% 2. Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
|
||||
% (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.
|
||||
% 3. Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
|
||||
% Structure and Classification Rule for Recognition in Partially Exposed
|
||||
% Environments". IEEE Transactions on Pattern Analysis and Machine
|
||||
% Intelligence, Vol. PAMI-2, No. 1, 67-71.
|
||||
% -- Results:
|
||||
% -- very low misclassification rates (0% for the setosa class)
|
||||
% 4. Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE
|
||||
% Transactions on Information Theory, May 1972, 431-433.
|
||||
% -- Results:
|
||||
% -- very low misclassification rates again
|
||||
% 5. See also: 1988 MLC Proceedings, 54-64. Cheeseman et al's AUTOCLASS II
|
||||
% conceptual clustering system finds 3 classes in the data.
|
||||
%
|
||||
% 4. Relevant Information:
|
||||
% --- This is perhaps the best known database to be found in the pattern
|
||||
% recognition literature. Fisher's paper is a classic in the field
|
||||
% and is referenced frequently to this day. (See Duda & Hart, for
|
||||
% example.) The data set contains 3 classes of 50 instances each,
|
||||
% where each class refers to a type of iris plant. One class is
|
||||
% linearly separable from the other 2; the latter are NOT linearly
|
||||
% separable from each other.
|
||||
% --- Predicted attribute: class of iris plant.
|
||||
% --- This is an exceedingly simple domain.
|
||||
%
|
||||
% 5. Number of Instances: 150 (50 in each of three classes)
|
||||
%
|
||||
% 6. Number of Attributes: 4 numeric, predictive attributes and the class
|
||||
%
|
||||
% 7. Attribute Information:
|
||||
% 1. sepal length in cm
|
||||
% 2. sepal width in cm
|
||||
% 3. petal length in cm
|
||||
% 4. petal width in cm
|
||||
% 5. class:
|
||||
% -- Iris Setosa
|
||||
% -- Iris Versicolour
|
||||
% -- Iris Virginica
|
||||
%
|
||||
% 8. Missing Attribute Values: None
|
||||
%
|
||||
% Summary Statistics:
|
||||
% Min Max Mean SD Class Correlation
|
||||
% sepal length: 4.3 7.9 5.84 0.83 0.7826
|
||||
% sepal width: 2.0 4.4 3.05 0.43 -0.4194
|
||||
% petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)
|
||||
% petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)
|
||||
%
|
||||
% 9. Class Distribution: 33.3% for each of 3 classes.
|
||||
|
||||
@RELATION iris
|
||||
|
||||
@ATTRIBUTE sepallength REAL
|
||||
@ATTRIBUTE sepalwidth REAL
|
||||
@ATTRIBUTE petallength REAL
|
||||
@ATTRIBUTE petalwidth REAL
|
||||
@ATTRIBUTE class {Iris-setosa,Iris-versicolor,Iris-virginica}
|
||||
|
||||
@DATA
|
||||
5.1,3.5,1.4,0.2,Iris-setosa
|
||||
4.9,3.0,1.4,0.2,Iris-setosa
|
||||
4.7,3.2,1.3,0.2,Iris-setosa
|
||||
4.6,3.1,1.5,0.2,Iris-setosa
|
||||
5.0,3.6,1.4,0.2,Iris-setosa
|
||||
5.4,3.9,1.7,0.4,Iris-setosa
|
||||
4.6,3.4,1.4,0.3,Iris-setosa
|
||||
5.0,3.4,1.5,0.2,Iris-setosa
|
||||
4.4,2.9,1.4,0.2,Iris-setosa
|
||||
4.9,3.1,1.5,0.1,Iris-setosa
|
||||
5.4,3.7,1.5,0.2,Iris-setosa
|
||||
4.8,3.4,1.6,0.2,Iris-setosa
|
||||
4.8,3.0,1.4,0.1,Iris-setosa
|
||||
4.3,3.0,1.1,0.1,Iris-setosa
|
||||
5.8,4.0,1.2,0.2,Iris-setosa
|
||||
5.7,4.4,1.5,0.4,Iris-setosa
|
||||
5.4,3.9,1.3,0.4,Iris-setosa
|
||||
5.1,3.5,1.4,0.3,Iris-setosa
|
||||
5.7,3.8,1.7,0.3,Iris-setosa
|
||||
5.1,3.8,1.5,0.3,Iris-setosa
|
||||
5.4,3.4,1.7,0.2,Iris-setosa
|
||||
5.1,3.7,1.5,0.4,Iris-setosa
|
||||
4.6,3.6,1.0,0.2,Iris-setosa
|
||||
5.1,3.3,1.7,0.5,Iris-setosa
|
||||
4.8,3.4,1.9,0.2,Iris-setosa
|
||||
5.0,3.0,1.6,0.2,Iris-setosa
|
||||
5.0,3.4,1.6,0.4,Iris-setosa
|
||||
5.2,3.5,1.5,0.2,Iris-setosa
|
||||
5.2,3.4,1.4,0.2,Iris-setosa
|
||||
4.7,3.2,1.6,0.2,Iris-setosa
|
||||
4.8,3.1,1.6,0.2,Iris-setosa
|
||||
5.4,3.4,1.5,0.4,Iris-setosa
|
||||
5.2,4.1,1.5,0.1,Iris-setosa
|
||||
5.5,4.2,1.4,0.2,Iris-setosa
|
||||
4.9,3.1,1.5,0.1,Iris-setosa
|
||||
5.0,3.2,1.2,0.2,Iris-setosa
|
||||
5.5,3.5,1.3,0.2,Iris-setosa
|
||||
4.9,3.1,1.5,0.1,Iris-setosa
|
||||
4.4,3.0,1.3,0.2,Iris-setosa
|
||||
5.1,3.4,1.5,0.2,Iris-setosa
|
||||
5.0,3.5,1.3,0.3,Iris-setosa
|
||||
4.5,2.3,1.3,0.3,Iris-setosa
|
||||
4.4,3.2,1.3,0.2,Iris-setosa
|
||||
5.0,3.5,1.6,0.6,Iris-setosa
|
||||
5.1,3.8,1.9,0.4,Iris-setosa
|
||||
4.8,3.0,1.4,0.3,Iris-setosa
|
||||
5.1,3.8,1.6,0.2,Iris-setosa
|
||||
4.6,3.2,1.4,0.2,Iris-setosa
|
||||
5.3,3.7,1.5,0.2,Iris-setosa
|
||||
5.0,3.3,1.4,0.2,Iris-setosa
|
||||
7.0,3.2,4.7,1.4,Iris-versicolor
|
||||
6.4,3.2,4.5,1.5,Iris-versicolor
|
||||
6.9,3.1,4.9,1.5,Iris-versicolor
|
||||
5.5,2.3,4.0,1.3,Iris-versicolor
|
||||
6.5,2.8,4.6,1.5,Iris-versicolor
|
||||
5.7,2.8,4.5,1.3,Iris-versicolor
|
||||
6.3,3.3,4.7,1.6,Iris-versicolor
|
||||
4.9,2.4,3.3,1.0,Iris-versicolor
|
||||
6.6,2.9,4.6,1.3,Iris-versicolor
|
||||
5.2,2.7,3.9,1.4,Iris-versicolor
|
||||
5.0,2.0,3.5,1.0,Iris-versicolor
|
||||
5.9,3.0,4.2,1.5,Iris-versicolor
|
||||
6.0,2.2,4.0,1.0,Iris-versicolor
|
||||
6.1,2.9,4.7,1.4,Iris-versicolor
|
||||
5.6,2.9,3.6,1.3,Iris-versicolor
|
||||
6.7,3.1,4.4,1.4,Iris-versicolor
|
||||
5.6,3.0,4.5,1.5,Iris-versicolor
|
||||
5.8,2.7,4.1,1.0,Iris-versicolor
|
||||
6.2,2.2,4.5,1.5,Iris-versicolor
|
||||
5.6,2.5,3.9,1.1,Iris-versicolor
|
||||
5.9,3.2,4.8,1.8,Iris-versicolor
|
||||
6.1,2.8,4.0,1.3,Iris-versicolor
|
||||
6.3,2.5,4.9,1.5,Iris-versicolor
|
||||
6.1,2.8,4.7,1.2,Iris-versicolor
|
||||
6.4,2.9,4.3,1.3,Iris-versicolor
|
||||
6.6,3.0,4.4,1.4,Iris-versicolor
|
||||
6.8,2.8,4.8,1.4,Iris-versicolor
|
||||
6.7,3.0,5.0,1.7,Iris-versicolor
|
||||
6.0,2.9,4.5,1.5,Iris-versicolor
|
||||
5.7,2.6,3.5,1.0,Iris-versicolor
|
||||
5.5,2.4,3.8,1.1,Iris-versicolor
|
||||
5.5,2.4,3.7,1.0,Iris-versicolor
|
||||
5.8,2.7,3.9,1.2,Iris-versicolor
|
||||
6.0,2.7,5.1,1.6,Iris-versicolor
|
||||
5.4,3.0,4.5,1.5,Iris-versicolor
|
||||
6.0,3.4,4.5,1.6,Iris-versicolor
|
||||
6.7,3.1,4.7,1.5,Iris-versicolor
|
||||
6.3,2.3,4.4,1.3,Iris-versicolor
|
||||
5.6,3.0,4.1,1.3,Iris-versicolor
|
||||
5.5,2.5,4.0,1.3,Iris-versicolor
|
||||
5.5,2.6,4.4,1.2,Iris-versicolor
|
||||
6.1,3.0,4.6,1.4,Iris-versicolor
|
||||
5.8,2.6,4.0,1.2,Iris-versicolor
|
||||
5.0,2.3,3.3,1.0,Iris-versicolor
|
||||
5.6,2.7,4.2,1.3,Iris-versicolor
|
||||
5.7,3.0,4.2,1.2,Iris-versicolor
|
||||
5.7,2.9,4.2,1.3,Iris-versicolor
|
||||
6.2,2.9,4.3,1.3,Iris-versicolor
|
||||
5.1,2.5,3.0,1.1,Iris-versicolor
|
||||
5.7,2.8,4.1,1.3,Iris-versicolor
|
||||
6.3,3.3,6.0,2.5,Iris-virginica
|
||||
5.8,2.7,5.1,1.9,Iris-virginica
|
||||
7.1,3.0,5.9,2.1,Iris-virginica
|
||||
6.3,2.9,5.6,1.8,Iris-virginica
|
||||
6.5,3.0,5.8,2.2,Iris-virginica
|
||||
7.6,3.0,6.6,2.1,Iris-virginica
|
||||
4.9,2.5,4.5,1.7,Iris-virginica
|
||||
7.3,2.9,6.3,1.8,Iris-virginica
|
||||
6.7,2.5,5.8,1.8,Iris-virginica
|
||||
7.2,3.6,6.1,2.5,Iris-virginica
|
||||
6.5,3.2,5.1,2.0,Iris-virginica
|
||||
6.4,2.7,5.3,1.9,Iris-virginica
|
||||
6.8,3.0,5.5,2.1,Iris-virginica
|
||||
5.7,2.5,5.0,2.0,Iris-virginica
|
||||
5.8,2.8,5.1,2.4,Iris-virginica
|
||||
6.4,3.2,5.3,2.3,Iris-virginica
|
||||
6.5,3.0,5.5,1.8,Iris-virginica
|
||||
7.7,3.8,6.7,2.2,Iris-virginica
|
||||
7.7,2.6,6.9,2.3,Iris-virginica
|
||||
6.0,2.2,5.0,1.5,Iris-virginica
|
||||
6.9,3.2,5.7,2.3,Iris-virginica
|
||||
5.6,2.8,4.9,2.0,Iris-virginica
|
||||
7.7,2.8,6.7,2.0,Iris-virginica
|
||||
6.3,2.7,4.9,1.8,Iris-virginica
|
||||
6.7,3.3,5.7,2.1,Iris-virginica
|
||||
7.2,3.2,6.0,1.8,Iris-virginica
|
||||
6.2,2.8,4.8,1.8,Iris-virginica
|
||||
6.1,3.0,4.9,1.8,Iris-virginica
|
||||
6.4,2.8,5.6,2.1,Iris-virginica
|
||||
7.2,3.0,5.8,1.6,Iris-virginica
|
||||
7.4,2.8,6.1,1.9,Iris-virginica
|
||||
7.9,3.8,6.4,2.0,Iris-virginica
|
||||
6.4,2.8,5.6,2.2,Iris-virginica
|
||||
6.3,2.8,5.1,1.5,Iris-virginica
|
||||
6.1,2.6,5.6,1.4,Iris-virginica
|
||||
7.7,3.0,6.1,2.3,Iris-virginica
|
||||
6.3,3.4,5.6,2.4,Iris-virginica
|
||||
6.4,3.1,5.5,1.8,Iris-virginica
|
||||
6.0,3.0,4.8,1.8,Iris-virginica
|
||||
6.9,3.1,5.4,2.1,Iris-virginica
|
||||
6.7,3.1,5.6,2.4,Iris-virginica
|
||||
6.9,3.1,5.1,2.3,Iris-virginica
|
||||
5.8,2.7,5.1,1.9,Iris-virginica
|
||||
6.8,3.2,5.9,2.3,Iris-virginica
|
||||
6.7,3.3,5.7,2.5,Iris-virginica
|
||||
6.7,3.0,5.2,2.3,Iris-virginica
|
||||
6.3,2.5,5.0,1.9,Iris-virginica
|
||||
6.5,3.0,5.2,2.0,Iris-virginica
|
||||
6.2,3.4,5.4,2.3,Iris-virginica
|
||||
5.9,3.0,5.1,1.8,Iris-virginica
|
||||
%
|
||||
%
|
||||
%
|
43
main.cc
Normal file
43
main.cc
Normal file
@ -0,0 +1,43 @@
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <torch/torch.h>
|
||||
#include "ArffFiles.h"
|
||||
#include "Network.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
int main()
|
||||
{
|
||||
auto handler = ArffFiles();
|
||||
handler.load("iris.arff");
|
||||
auto X = handler.getX();
|
||||
auto y = handler.getY();
|
||||
auto className = handler.getClassName();
|
||||
vector<pair<string, string>> edges = { {className, "sepallength"}, {className, "sepalwidth"}, {className, "petallength"}, {className, "petalwidth"} };
|
||||
auto network = bayesnet::Network();
|
||||
// Add nodes to the network
|
||||
for (auto feature : handler.getAttributes()) {
|
||||
cout << "Adding feature: " << feature.first << endl;
|
||||
network.addNode(feature.first, 7);
|
||||
}
|
||||
network.addNode(className, 3);
|
||||
for (auto item : edges) {
|
||||
network.addEdge(item.first, item.second);
|
||||
}
|
||||
cout << "Hello, Bayesian Networks!" << endl;
|
||||
torch::Tensor tensor = torch::eye(3);
|
||||
cout << tensor << std::endl;
|
||||
cout << "Nodes:" << endl;
|
||||
for (auto [name, item] : network.getNodes()) {
|
||||
cout << "*" << item->getName() << endl;
|
||||
cout << "-Parents:" << endl;
|
||||
for (auto parent : item->getParents()) {
|
||||
cout << " " << parent->getName() << endl;
|
||||
}
|
||||
cout << "-Children:" << endl;
|
||||
for (auto child : item->getChildren()) {
|
||||
cout << " " << child->getName() << endl;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
47
simple/Network.cc
Normal file
47
simple/Network.cc
Normal file
@ -0,0 +1,47 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "Network.h"
|
||||
|
||||
namespace bayesnet {
|
||||
Network::Network() {}
|
||||
|
||||
Network::~Network()
|
||||
{
|
||||
for (auto& pair : nodes) {
|
||||
delete pair.second;
|
||||
}
|
||||
}
|
||||
|
||||
void Network::addNode(std::string name)
|
||||
{
|
||||
nodes[name] = new Node(name);
|
||||
}
|
||||
|
||||
void Network::addEdge(std::string parentName, std::string childName)
|
||||
{
|
||||
Node* parent = nodes[parentName];
|
||||
Node* child = nodes[childName];
|
||||
|
||||
if (parent == nullptr || child == nullptr) {
|
||||
throw std::invalid_argument("Parent or child node not found.");
|
||||
}
|
||||
|
||||
child->addParent(parent);
|
||||
}
|
||||
|
||||
// to be implemented
|
||||
void Network::fit(const std::vector<std::vector<double>>& dataset)
|
||||
{
|
||||
// ... learn parameters (i.e., CPTs) using the dataset
|
||||
}
|
||||
|
||||
// to be implemented
|
||||
std::vector<double> Network::predict(const std::vector<std::vector<double>>& testset)
|
||||
{
|
||||
std::vector<double> predictions;
|
||||
// ... use the CPTs and network structure to predict values
|
||||
return predictions;
|
||||
}
|
||||
}
|
||||
|
21
simple/Network.h
Normal file
21
simple/Network.h
Normal file
@ -0,0 +1,21 @@
|
||||
#ifndef NETWORK_H
|
||||
#define NETWORK_H
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "Node.h"
|
||||
|
||||
namespace bayesnet {
|
||||
class Network {
|
||||
private:
|
||||
std::map<std::string, Node*> nodes;
|
||||
public:
|
||||
Network();
|
||||
~Network();
|
||||
void addNode(std::string);
|
||||
void addEdge(std::string, std::string);
|
||||
void fit(const std::vector<std::vector<double>>&);
|
||||
std::vector<double> predict(const std::vector<std::vector<double>>&);
|
||||
};
|
||||
}
|
||||
#endif
|
14
simple/Node.cc
Normal file
14
simple/Node.cc
Normal file
@ -0,0 +1,14 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "Node.h"
|
||||
|
||||
namespace bayesnet {
|
||||
Node::Node(std::string name) : name(name) {}
|
||||
|
||||
void Node::addParent(Node* parent)
|
||||
{
|
||||
parents.push_back(parent);
|
||||
parent->children.push_back(this);
|
||||
}
|
||||
}
|
18
simple/Node.h
Normal file
18
simple/Node.h
Normal file
@ -0,0 +1,18 @@
|
||||
#ifndef NODE_H
|
||||
#define NODE_H
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
namespace bayesnet {
|
||||
class Node {
|
||||
private:
|
||||
std::string name;
|
||||
std::vector<Node*> parents;
|
||||
std::vector<Node*> children;
|
||||
std::map<std::vector<bool>, double> cpt; // Conditional Probability Table
|
||||
public:
|
||||
Node(std::string);
|
||||
void addParent(Node*);
|
||||
};
|
||||
}
|
||||
#endif
|
23
test.cc
Normal file
23
test.cc
Normal file
@ -0,0 +1,23 @@
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
int main(int argc, char const* argv[])
|
||||
{
|
||||
map<string, int> m;
|
||||
m["a"] = 1;
|
||||
m["b"] = 2;
|
||||
m["c"] = 3;
|
||||
if (m.find("b") != m.end()) {
|
||||
cout << "Found b" << endl;
|
||||
} else {
|
||||
cout << "Not found b" << endl;
|
||||
}
|
||||
// for (auto [key, value] : m) {
|
||||
// cout << key << " " << value << endl;
|
||||
// }
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
Reference in New Issue
Block a user