Complete predict and score of kdb

Change new/delete to make_unique
Ricardo Montañana Gómez 2023-07-15 01:05:36 +02:00
parent 6a8aad5911
commit db6908acd0
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
16 changed files with 176 additions and 98 deletions
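The memory-management half of the commit ("Change new/delete to make_unique") shows up in src/Network.cc and src/Network.h below: the node map switches from raw, manually deleted Node* to std::unique_ptr<Node> built with std::make_unique, so the hand-written destructor disappears. A minimal sketch of the same pattern, using a simplified stand-in Node rather than the library's real class:

#include <map>
#include <memory>
#include <string>

// Simplified stand-in for bayesnet::Node, only to illustrate the ownership change.
struct Node {
    std::string name;
    int numStates;
    Node(const std::string& name, int numStates) : name(name), numStates(numStates) {}
};

struct Network {
    // before: std::map<std::string, Node*> nodes;        // required delete in ~Network()
    std::map<std::string, std::unique_ptr<Node>> nodes;   // after: the map owns the nodes

    void addNode(const std::string& name, int numStates)
    {
        // before: nodes[name] = new Node(name, numStates);
        nodes[name] = std::make_unique<Node>(name, numStates);
    }
    // No explicit destructor needed: each unique_ptr releases its Node automatically.
};

Raw Node* pointers obtained with .get() remain available for the non-owning parent/child links, which is what addEdge does in the real Network below.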

.vscode/launch.json

@@ -8,7 +8,7 @@
       "program": "${workspaceFolder}/build/sample/main",
       "args": [
         "-f",
-        "glass"
+        "iris"
       ],
       "cwd": "${workspaceFolder}",
       "preLaunchTask": "CMake: build"

.vscode/settings.json

@@ -96,7 +96,8 @@
         "csetjmp": "cpp",
         "future": "cpp",
         "queue": "cpp",
-        "typeindex": "cpp"
+        "typeindex": "cpp",
+        "shared_mutex": "cpp"
     },
     "cmake.configureOnOpen": false,
     "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"

sample/main.cc

@@ -96,15 +96,16 @@ pair<vector<mdlp::labels_t>, map<string, int>> discretize(vector<mdlp::samples_t
 void showNodesInfo(bayesnet::Network& network, string className)
 {
     cout << "Nodes:" << endl;
-    for (auto [name, item] : network.getNodes()) {
-        cout << "*" << item->getName() << " States -> " << item->getNumStates() << endl;
+    for (auto& node : network.getNodes()) {
+        auto name = node.first;
+        cout << "*" << node.second->getName() << " States -> " << node.second->getNumStates() << endl;
         cout << "-Parents:";
-        for (auto parent : item->getParents()) {
+        for (auto parent : node.second->getParents()) {
             cout << " " << parent->getName();
         }
         cout << endl;
         cout << "-Children:";
-        for (auto child : item->getChildren()) {
+        for (auto child : node.second->getChildren()) {
             cout << " " << child->getName();
         }
         cout << endl;
@@ -113,7 +114,7 @@ void showNodesInfo(bayesnet::Network& network, string className)
 void showCPDS(bayesnet::Network& network)
 {
     cout << "CPDs:" << endl;
-    auto nodes = network.getNodes();
+    auto& nodes = network.getNodes();
     for (auto it = nodes.begin(); it != nodes.end(); it++) {
         cout << "* Name: " << it->first << " " << it->second->getName() << " -> " << it->second->getNumStates() << endl;
         cout << "Parents: ";
@@ -253,12 +254,14 @@ int main(int argc, char** argv)
     for (auto feature : features) {
         states[feature] = vector<int>(maxes[feature]);
     }
-    states[className] = vector<int>(maxes[className]);
+    states[className] = vector<int>(
+        maxes[className]);
     auto kdb = bayesnet::KDB(2);
     kdb.fit(Xd, y, features, className, states);
     for (auto line : kdb.show()) {
         cout << line << endl;
     }
+    cout << "Score: " << kdb.score(Xd, y) << endl;
     cout << "****************** KDB ******************" << endl;
     return 0;
 }
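Note that the sample's loops over network.getNodes() switch to reference bindings (auto& nodes, auto& node). This is forced by the Network change further down: getNodes() now returns a map of std::unique_ptr, which is move-only, so copying the map or iterating it by value no longer compiles. A self-contained sketch of that constraint, using int as a stand-in for Node:

#include <iostream>
#include <map>
#include <memory>
#include <string>

int main()
{
    std::map<std::string, std::unique_ptr<int>> nodes;
    nodes["class"] = std::make_unique<int>(3);

    // auto copy = nodes;          // error: copying a map of unique_ptr is deleted
    // for (auto node : nodes) {}  // error: each pair would have to be copied

    auto& ref = nodes;             // bind the map by reference instead
    for (auto& node : ref) {       // and each pair<const string, unique_ptr<int>> by reference
        std::cout << node.first << " states: " << *node.second << std::endl;
    }
    return 0;
}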

src/AODE.cc (new file)

@@ -0,0 +1,16 @@
#include "AODE.h"
namespace bayesnet {
    AODE::AODE() : Ensemble()
    {
        models = vector<SPODE>();
    }
    void AODE::train()
    {
        for (int i = 0; i < features.size(); ++i) {
            SPODE model = SPODE(i);
            models.push_back(model);
        }
    }
}

src/AODE.h (new file)

@@ -0,0 +1,13 @@
#ifndef AODE_H
#define AODE_H
#include "Ensemble.h"
#include "SPODE.h"
namespace bayesnet {
    class AODE : public Ensemble {
    protected:
        void train() override;
    public:
        AODE();
    };
}
#endif

src/BaseClassifier.cc

@@ -1,10 +1,11 @@
 #include "BaseClassifier.h"
+#include "utils.h"
 namespace bayesnet {
     using namespace std;
     using namespace torch;
-    BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0), metrics(Metrics()) {}
+    BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
     BaseClassifier& BaseClassifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
     {
@@ -16,21 +17,19 @@ namespace bayesnet {
         auto n_classes = states[className].size();
         metrics = Metrics(dataset, features, className, n_classes);
         train();
+        model.fit(Xv, yv, features, className);
+        fitted = true;
         return *this;
     }
-    BaseClassifier& BaseClassifier::fit(Tensor& X, Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
-    {
-        this->X = X;
-        this->y = y;
-        return build(features, className, states);
-    }
     BaseClassifier& BaseClassifier::fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states)
     {
         this->X = torch::zeros({ static_cast<int64_t>(X[0].size()), static_cast<int64_t>(X.size()) }, kInt64);
+        Xv = X;
         for (int i = 0; i < X.size(); ++i) {
             this->X.index_put_({ "...", i }, torch::tensor(X[i], kInt64));
         }
         this->y = torch::tensor(y, kInt64);
+        yv = y;
         return build(features, className, states);
     }
     void BaseClassifier::checkFitParameters()
@@ -53,55 +52,44 @@ namespace bayesnet {
             }
         }
     }
-    vector<int> BaseClassifier::argsort(vector<float>& nums)
-    {
-        int n = nums.size();
-        vector<int> indices(n);
-        iota(indices.begin(), indices.end(), 0);
-        sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
-        return indices;
-    }
-    vector<vector<int>> tensorToVector(const torch::Tensor& tensor)
-    {
-        // convert mxn tensor to nxm vector
-        vector<vector<int>> result;
-        auto tensor_accessor = tensor.accessor<int, 2>();
-        // Iterate over columns and rows of the tensor
-        for (int j = 0; j < tensor.size(1); ++j) {
-            vector<int> column;
-            for (int i = 0; i < tensor.size(0); ++i) {
-                column.push_back(tensor_accessor[i][j]);
-            }
-            result.push_back(column);
-        }
-        return result;
-    }
     Tensor BaseClassifier::predict(Tensor& X)
     {
-        auto n_models = models.size();
-        Tensor y_pred = torch::zeros({ X.size(0), n_models }, torch::kInt64);
-        for (auto i = 0; i < n_models; ++i) {
-            y_pred.index_put_({ "...", i }, models[i].predict(X));
-        }
-        auto y_pred_ = y_pred.accessor<int64_t, 2>();
-        vector<int> y_pred_final;
-        for (int i = 0; i < y_pred.size(0); ++i) {
-            vector<float> votes(states[className].size(), 0);
-            for (int j = 0; j < y_pred.size(1); ++j) {
-                votes[y_pred_[i][j]] += 1;
-            }
-            auto indices = argsort(votes);
-            y_pred_final.push_back(indices[0]);
-        }
-        return torch::tensor(y_pred_final, torch::kInt64);
+        if (!fitted) {
+            throw logic_error("Classifier has not been fitted");
+        }
+        auto m_ = X.size(0);
+        auto n_ = X.size(1);
+        vector<vector<int>> Xd(n_, vector<int>(m_, 0));
+        for (auto i = 0; i < n_; i++) {
+            auto temp = X.index({ "...", i });
+            Xd[i] = vector<int>(temp.data_ptr<int>(), temp.data_ptr<int>() + m_);
+        }
+        auto yp = model.predict(Xd);
+        auto ypred = torch::tensor(yp, torch::kInt64);
+        return ypred;
     }
     float BaseClassifier::score(Tensor& X, Tensor& y)
     {
+        if (!fitted) {
+            throw logic_error("Classifier has not been fitted");
+        }
         Tensor y_pred = predict(X);
         return (y_pred == y).sum().item<float>() / y.size(0);
     }
+    float BaseClassifier::score(vector<vector<int>>& X, vector<int>& y)
+    {
+        if (!fitted) {
+            throw logic_error("Classifier has not been fitted");
+        }
+        auto m_ = X[0].size();
+        auto n_ = X.size();
+        vector<vector<int>> Xd(n_, vector<int>(m_, 0));
+        for (auto i = 0; i < n_; i++) {
+            Xd[i] = vector<int>(X[i].begin(), X[i].end());
+        }
+        return model.score(Xd, y);
+    }
     vector<string> BaseClassifier::show()
     {
         return model.show();
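With the new fitted flag, predict and both score overloads refuse to run before fit and throw std::logic_error instead. A hedged usage sketch of the resulting calling pattern, assuming the project's KDB header and a discretized dataset prepared as in the sample program (Xd, y, features, className, states are placeholders):

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>
#include "KDB.h"

void runKDB(std::vector<std::vector<int>>& Xd, std::vector<int>& y,
            std::vector<std::string>& features, std::string className,
            std::map<std::string, std::vector<int>>& states)
{
    auto kdb = bayesnet::KDB(2);
    try {
        kdb.score(Xd, y);                    // not fitted yet -> throws
    } catch (const std::logic_error& e) {
        std::cout << e.what() << std::endl;  // "Classifier has not been fitted"
    }
    kdb.fit(Xd, y, features, className, states);
    std::cout << "Score: " << kdb.score(Xd, y) << std::endl;
}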

src/BaseClassifier.h

@@ -9,12 +9,15 @@ using namespace torch;
 namespace bayesnet {
     class BaseClassifier {
     private:
+        bool fitted;
         BaseClassifier& build(vector<string>& features, string className, map<string, vector<int>>& states);
     protected:
         Network model;
         int m, n; // m: number of samples, n: number of features
         Tensor X;
+        vector<vector<int>> Xv;
         Tensor y;
+        vector<int> yv;
         Tensor dataset;
         Metrics metrics;
         vector<string> features;
@@ -24,13 +27,13 @@ namespace bayesnet {
         virtual void train() = 0;
     public:
         BaseClassifier(Network model);
-        BaseClassifier& fit(Tensor& X, Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states);
+        virtual ~BaseClassifier() = default;
         BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
         void addNodes();
         Tensor predict(Tensor& X);
         float score(Tensor& X, Tensor& y);
+        float score(vector<vector<int>>& X, vector<int>& y);
         vector<string> show();
-        vector<int> argsort(vector<float>& nums);
     };
 }
 #endif

src/CMakeLists.txt

@@ -1,2 +1,2 @@
-add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc)
+add_library(BayesNet utils.cc Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc)
 target_link_libraries(BayesNet "${TORCH_LIBRARIES}")

src/Ensemble.cc

@@ -4,17 +4,22 @@ namespace bayesnet {
     using namespace std;
     using namespace torch;
-    Ensemble::Ensemble(BaseClassifier& model) : model(model), models(vector<BaseClassifier>()), m(0), n(0), metrics(Metrics()) {}
+    Ensemble::Ensemble() : m(0), n(0), n_models(0), metrics(Metrics()) {}
     Ensemble& Ensemble::build(vector<string>& features, string className, map<string, vector<int>>& states)
     {
         dataset = torch::cat({ X, y.view({y.size(0), 1}) }, 1);
         this->features = features;
         this->className = className;
         this->states = states;
         auto n_classes = states[className].size();
         metrics = Metrics(dataset, features, className, n_classes);
+        // Build models
         train();
+        // Train models
+        n_models = models.size();
+        for (auto i = 0; i < n_models; ++i) {
+            models[i].fit(X, y, features, className, states);
+        }
         return *this;
     }
     Ensemble& Ensemble::fit(Tensor& X, Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
@@ -37,16 +42,21 @@ namespace bayesnet {
     }
     Tensor Ensemble::predict(Tensor& X)
     {
-        auto m_ = X.size(0);
-        auto n_ = X.size(1);
-        vector<vector<int>> Xd(n_, vector<int>(m_, 0));
-        for (auto i = 0; i < n_; i++) {
-            auto temp = X.index({ "...", i });
-            Xd[i] = vector<int>(temp.data_ptr<int>(), temp.data_ptr<int>() + m_);
-        }
-        auto yp = model.predict(Xd);
-        auto ypred = torch::tensor(yp, torch::kInt64);
-        return ypred;
+        Tensor y_pred = torch::zeros({ X.size(0), n_models }, torch::kInt64);
+        for (auto i = 0; i < n_models; ++i) {
+            y_pred.index_put_({ "...", i }, models[i].predict(X));
+        }
+        auto y_pred_ = y_pred.accessor<int64_t, 2>();
+        vector<int> y_pred_final;
+        for (int i = 0; i < y_pred.size(0); ++i) {
+            vector<float> votes(states[className].size(), 0);
+            for (int j = 0; j < y_pred.size(1); ++j) {
+                votes[y_pred_[i][j]] += 1;
+            }
+            auto indices = argsort(votes);
+            y_pred_final.push_back(indices[0]);
+        }
+        return torch::tensor(y_pred_final, torch::kInt64);
     }
     float Ensemble::score(Tensor& X, Tensor& y)
     {
@@ -55,7 +65,11 @@ namespace bayesnet {
     }
     vector<string> Ensemble::show()
     {
-        return model.show();
+        vector<string> result;
+        for (auto i = 0; i < n_models; ++i) {
+            auto res = models[i].show();
+            result.insert(result.end(), res.begin(), res.end());
+        }
+        return result;
     }
 }
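Ensemble::predict now lets every model vote: it collects one column of predictions per model, counts votes per class for each sample, and keeps the class with the most votes. Because argsort (from the new utils.cc) orders indices by descending value, indices[0] is that majority class. A self-contained sketch of the counting step for a single sample, with a local copy of argsort:

#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

// Same idea as argsort in src/utils.cc: indices of nums ordered by descending value.
std::vector<int> argsort(std::vector<float>& nums)
{
    std::vector<int> indices(nums.size());
    std::iota(indices.begin(), indices.end(), 0);
    std::sort(indices.begin(), indices.end(), [&nums](int i, int j) { return nums[i] > nums[j]; });
    return indices;
}

int main()
{
    std::vector<int> model_predictions = { 2, 0, 2 };  // three hypothetical models, one sample
    std::vector<float> votes(3, 0);                    // three classes
    for (int p : model_predictions) {
        votes[p] += 1;
    }
    std::cout << "majority class: " << argsort(votes)[0] << std::endl;  // prints 2
    return 0;
}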

src/Ensemble.h

@@ -3,15 +3,16 @@
 #include <torch/torch.h>
 #include "BaseClassifier.h"
 #include "Metrics.hpp"
+#include "utils.h"
 using namespace std;
 using namespace torch;
 namespace bayesnet {
     class Ensemble {
     private:
+        long n_models;
         Ensemble& build(vector<string>& features, string className, map<string, vector<int>>& states);
     protected:
-        BaseClassifier& model;
         vector<BaseClassifier> models;
         int m, n; // m: number of samples, n: number of features
         Tensor X;
@@ -23,7 +24,8 @@ namespace bayesnet {
         map<string, vector<int>> states;
         void virtual train() = 0;
     public:
-        Ensemble(BaseClassifier& model);
+        Ensemble();
+        virtual ~Ensemble() = default;
         Ensemble& fit(Tensor& X, Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states);
         Ensemble& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
         Tensor predict(Tensor& X);

src/KDB.h

@@ -1,6 +1,7 @@
 #ifndef KDB_H
 #define KDB_H
 #include "BaseClassifier.h"
+#include "utils.h"
 namespace bayesnet {
     using namespace std;
     using namespace torch;

src/Network.cc

@@ -2,19 +2,13 @@
 #include <mutex>
 #include "Network.h"
 namespace bayesnet {
-    Network::Network() : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8) {}
-    Network::Network(float maxT) : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
-    Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
-    Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads())
+    Network::Network() : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8), fitted(false) {}
+    Network::Network(float maxT) : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
+    Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
+    Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads()), fitted(other.fitted)
     {
         for (auto& pair : other.nodes) {
-            nodes[pair.first] = new Node(*pair.second);
-        }
-    }
-    Network::~Network()
-    {
-        for (auto& pair : nodes) {
-            delete pair.second;
+            nodes[pair.first] = make_unique<Node>(*pair.second);
         }
     }
     float Network::getmaxThreads()
@@ -32,7 +26,7 @@ namespace bayesnet {
             nodes[name]->setNumStates(numStates);
             return;
         }
-        nodes[name] = new Node(name, numStates);
+        nodes[name] = make_unique<Node>(name, numStates);
     }
     vector<string> Network::getFeatures()
     {
@@ -45,7 +39,7 @@
     int Network::getStates()
     {
         int result = 0;
-        for (auto node : nodes) {
+        for (auto& node : nodes) {
             result += node.second->getNumStates();
         }
         return result;
@@ -79,20 +73,20 @@ namespace bayesnet {
            throw invalid_argument("Child node " + child + " does not exist");
        }
        // Temporarily add edge to check for cycles
-        nodes[parent]->addChild(nodes[child]);
-        nodes[child]->addParent(nodes[parent]);
+        nodes[parent]->addChild(nodes[child].get());
+        nodes[child]->addParent(nodes[parent].get());
         unordered_set<string> visited;
         unordered_set<string> recStack;
         if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle
         {
             // remove problematic edge
-            nodes[parent]->removeChild(nodes[child]);
-            nodes[child]->removeParent(nodes[parent]);
+            nodes[parent]->removeChild(nodes[child].get());
+            nodes[child]->removeParent(nodes[parent].get());
             throw invalid_argument("Adding this edge forms a cycle in the graph.");
         }
     }
-    map<string, Node*>& Network::getNodes()
+    map<string, std::unique_ptr<Node>>& Network::getNodes()
     {
         return nodes;
     }
@@ -140,9 +134,8 @@ namespace bayesnet {
                 lock.unlock();
                 pair.second->computeCPT(dataset, laplaceSmoothing);
                 lock.lock();
-                nodes[pair.first] = pair.second;
+                nodes[pair.first] = std::move(pair.second);
                 lock.unlock();
             }
             lock_guard<mutex> lock(mtx);
@@ -155,10 +148,14 @@ namespace bayesnet {
         for (auto& thread : threads) {
             thread.join();
         }
+        fitted = true;
     }
     vector<int> Network::predict(const vector<vector<int>>& tsamples)
     {
+        if (!fitted) {
+            throw logic_error("You must call fit() before calling predict()");
+        }
         vector<int> predictions;
         vector<int> sample;
         for (int row = 0; row < tsamples[0].size(); ++row) {
@@ -176,6 +173,9 @@
     }
     vector<vector<double>> Network::predict_proba(const vector<vector<int>>& tsamples)
     {
+        if (!fitted) {
+            throw logic_error("You must call fit() before calling predict_proba()");
+        }
         vector<vector<double>> predictions;
         vector<int> sample;
         for (int row = 0; row < tsamples[0].size(); ++row) {
@@ -215,7 +215,7 @@ namespace bayesnet {
     double Network::computeFactor(map<string, int>& completeEvidence)
     {
         double result = 1.0;
-        for (auto node : getNodes()) {
+        for (auto& node : getNodes()) {
             result *= node.second->getFactorValue(completeEvidence);
         }
         return result;
@@ -249,7 +249,7 @@ namespace bayesnet {
     {
         vector<string> result;
         // Draw the network
-        for (auto node : nodes) {
+        for (auto& node : nodes) {
             string line = node.first + " -> ";
             for (auto child : node.second->getChildren()) {
                 line += child->getName() + ", ";

src/Network.h

@@ -7,8 +7,9 @@
 namespace bayesnet {
     class Network {
     private:
-        map<string, Node*> nodes;
+        map<string, std::unique_ptr<Node>> nodes;
         map<string, vector<int>> dataset;
+        bool fitted;
         float maxThreads;
         int classNumStates;
         vector<string> features;
@@ -28,12 +29,11 @@ namespace bayesnet {
         Network(float, int);
         Network(float);
         Network(Network&);
-        ~Network();
         torch::Tensor& getSamples();
         float getmaxThreads();
         void addNode(string, int);
         void addEdge(const string, const string);
-        map<string, Node*>& getNodes();
+        map<string, std::unique_ptr<Node>>& getNodes();
         vector<string> getFeatures();
         int getStates();
         int getClassNumStates();

src/Node.cc

@@ -6,12 +6,10 @@ namespace bayesnet {
         : name(name), numStates(numStates), cpTable(torch::Tensor()), parents(vector<Node*>()), children(vector<Node*>())
     {
     }
-
     string Node::getName() const
     {
         return name;
     }
-
     void Node::addParent(Node* parent)
     {
         parents.push_back(parent);

src/utils.cc (new file)

@@ -0,0 +1,31 @@
#include <torch/torch.h>
#include <vector>
namespace bayesnet {
    using namespace std;
    using namespace torch;
    vector<int> argsort(vector<float>& nums)
    {
        int n = nums.size();
        vector<int> indices(n);
        iota(indices.begin(), indices.end(), 0);
        sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
        return indices;
    }
    vector<vector<int>> tensorToVector(const Tensor& tensor)
    {
        // convert mxn tensor to nxm vector
        vector<vector<int>> result;
        auto tensor_accessor = tensor.accessor<int, 2>();
        // Iterate over columns and rows of the tensor
        for (int j = 0; j < tensor.size(1); ++j) {
            vector<int> column;
            for (int i = 0; i < tensor.size(0); ++i) {
                column.push_back(tensor_accessor[i][j]);
            }
            result.push_back(column);
        }
        return result;
    }
}
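tensorToVector complements the predict path above: it transposes an m x n integer tensor into n column vectors, the same feature-major vector<vector<int>> layout that BaseClassifier::predict builds by hand before calling model.predict(Xd). A short usage sketch, assuming the declarations from src/utils.h are visible:

#include <iostream>
#include <torch/torch.h>
#include "utils.h"

int main()
{
    auto t = torch::arange(6).reshape({ 3, 2 }).to(torch::kInt32);  // 3 samples x 2 features
    auto columns = bayesnet::tensorToVector(t);                     // columns.size() == 2
    std::cout << columns[0][0] << " " << columns[0][1] << " " << columns[0][2] << std::endl;  // 0 2 4
    return 0;
}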

src/utils.h (new file)

@@ -0,0 +1,8 @@
namespace bayesnet {
    using namespace std;
    using namespace torch;
    vector<int> argsort(vector<float>& nums);
    vector<vector<int>> tensorToVector(const Tensor& tensor);
}