Complete predict and score of kdb
Change new/delete to make_unique
parent 6a8aad5911
commit db6908acd0
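Note on the second change in the title: raw owning pointers are replaced with std::unique_ptr. Below is a minimal standalone sketch of that pattern, not the project's code (Node is stubbed so the snippet compiles on its own): the map owns each node, make_unique replaces new, the hand-written destructor that called delete disappears, and non-owning callers receive the raw pointer via .get().

#include <map>
#include <memory>
#include <string>

// Stand-in for bayesnet::Node, only here so the sketch is self-contained.
class Node {
public:
    Node(const std::string& name, int numStates) : name(name), numStates(numStates) {}
private:
    std::string name;
    int numStates;
};

class Network {
public:
    void addNode(const std::string& name, int numStates)
    {
        // Before: nodes[name] = new Node(name, numStates); plus a ~Network() that deleted every entry.
        // After: the map owns each Node, so no user-written destructor is needed.
        nodes[name] = std::make_unique<Node>(name, numStates);
    }
    Node* getNode(const std::string& name)
    {
        // Non-owning access for callers that previously received the raw pointer.
        return nodes.at(name).get();
    }
private:
    std::map<std::string, std::unique_ptr<Node>> nodes;
};

int main()
{
    Network net;
    net.addNode("class", 3);
    return 0; // all nodes are released automatically when net goes out of scope
}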
.vscode/launch.json (vendored): 2 changed lines
@@ -8,7 +8,7 @@
     "program": "${workspaceFolder}/build/sample/main",
     "args": [
         "-f",
-        "glass"
+        "iris"
     ],
     "cwd": "${workspaceFolder}",
     "preLaunchTask": "CMake: build"
.vscode/settings.json (vendored): 3 changed lines
@@ -96,7 +96,8 @@
     "csetjmp": "cpp",
     "future": "cpp",
     "queue": "cpp",
-    "typeindex": "cpp"
+    "typeindex": "cpp",
+    "shared_mutex": "cpp"
 },
 "cmake.configureOnOpen": false,
 "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"
@@ -96,15 +96,16 @@ pair<vector<mdlp::labels_t>, map<string, int>> discretize(vector<mdlp::samples_t
 void showNodesInfo(bayesnet::Network& network, string className)
 {
     cout << "Nodes:" << endl;
-    for (auto [name, item] : network.getNodes()) {
-        cout << "*" << item->getName() << " States -> " << item->getNumStates() << endl;
+    for (auto& node : network.getNodes()) {
+        auto name = node.first;
+        cout << "*" << node.second->getName() << " States -> " << node.second->getNumStates() << endl;
         cout << "-Parents:";
-        for (auto parent : item->getParents()) {
+        for (auto parent : node.second->getParents()) {
             cout << " " << parent->getName();
         }
         cout << endl;
         cout << "-Children:";
-        for (auto child : item->getChildren()) {
+        for (auto child : node.second->getChildren()) {
             cout << " " << child->getName();
         }
         cout << endl;
@@ -113,7 +114,7 @@ void showNodesInfo(bayesnet::Network& network, string className)
 void showCPDS(bayesnet::Network& network)
 {
     cout << "CPDs:" << endl;
-    auto nodes = network.getNodes();
+    auto& nodes = network.getNodes();
     for (auto it = nodes.begin(); it != nodes.end(); it++) {
         cout << "* Name: " << it->first << " " << it->second->getName() << " -> " << it->second->getNumStates() << endl;
         cout << "Parents: ";
@@ -253,12 +254,14 @@ int main(int argc, char** argv)
     for (auto feature : features) {
         states[feature] = vector<int>(maxes[feature]);
     }
-    states[className] = vector<int>(maxes[className]);
+    states[className] = vector<int>(
+        maxes[className]);
     auto kdb = bayesnet::KDB(2);
     kdb.fit(Xd, y, features, className, states);
     for (auto line : kdb.show()) {
         cout << line << endl;
     }
+    cout << "Score: " << kdb.score(Xd, y) << endl;
     cout << "****************** KDB ******************" << endl;
     return 0;
 }
src/AODE.cc (new file): 16 lines
@@ -0,0 +1,16 @@
+#include "AODE.h"
+
+namespace bayesnet {
+
+    AODE::AODE() : Ensemble()
+    {
+        models = vector<SPODE>();
+    }
+    void AODE::train()
+    {
+        for (int i = 0; i < features.size(); ++i) {
+            SPODE model = SPODE(i);
+            models.push_back(model);
+        }
+    }
+}
src/AODE.h (new file): 13 lines
@@ -0,0 +1,13 @@
+#ifndef AODE_H
+#define AODE_H
+#include "Ensemble.h"
+#include "SPODE.h"
+namespace bayesnet {
+    class AODE : public Ensemble {
+    protected:
+        void train() override;
+    public:
+        AODE();
+    };
+}
+#endif
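The new AODE class builds one SPODE per feature and trains them through Ensemble::fit. A hypothetical smoke test follows (not part of the commit); the tiny hand-made dataset stands in for the ArffFiles/mdlp pipeline used by the sample program, and only calls declared in this diff (AODE(), fit with vectors, show) are used.

#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "AODE.h"
using namespace std;

int main()
{
    // Two already-discretized features with two states each, eight samples.
    vector<string> features = { "f0", "f1" };
    string className = "class";
    vector<vector<int>> Xd = {
        { 0, 0, 1, 1, 0, 1, 0, 1 },  // f0
        { 0, 1, 0, 1, 1, 0, 1, 0 }   // f1
    };
    vector<int> y = { 0, 0, 1, 1, 0, 1, 0, 1 };
    map<string, vector<int>> states;
    states["f0"] = vector<int>(2);
    states["f1"] = vector<int>(2);
    states[className] = vector<int>(2);
    auto clf = bayesnet::AODE();
    clf.fit(Xd, y, features, className, states);
    for (auto line : clf.show()) {
        cout << line << endl;
    }
    return 0;
}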
@@ -1,10 +1,11 @@
 #include "BaseClassifier.h"
+#include "utils.h"

 namespace bayesnet {
     using namespace std;
     using namespace torch;

-    BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0), metrics(Metrics()) {}
+    BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
     BaseClassifier& BaseClassifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
     {

@@ -16,21 +17,19 @@ namespace bayesnet {
         auto n_classes = states[className].size();
         metrics = Metrics(dataset, features, className, n_classes);
         train();
+        model.fit(Xv, yv, features, className);
+        fitted = true;
         return *this;
     }
     BaseClassifier& BaseClassifier::fit(Tensor& X, Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
     {
         this->X = X;
         this->y = y;
         return build(features, className, states);
     }
     BaseClassifier& BaseClassifier::fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states)
     {
         this->X = torch::zeros({ static_cast<int64_t>(X[0].size()), static_cast<int64_t>(X.size()) }, kInt64);
+        Xv = X;
         for (int i = 0; i < X.size(); ++i) {
             this->X.index_put_({ "...", i }, torch::tensor(X[i], kInt64));
         }
         this->y = torch::tensor(y, kInt64);
+        yv = y;
         return build(features, className, states);
     }
     void BaseClassifier::checkFitParameters()
@@ -53,55 +52,44 @@ namespace bayesnet {
             }
         }
     }
-    vector<int> BaseClassifier::argsort(vector<float>& nums)
-    {
-        int n = nums.size();
-        vector<int> indices(n);
-        iota(indices.begin(), indices.end(), 0);
-        sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
-        return indices;
-    }
-    vector<vector<int>> tensorToVector(const torch::Tensor& tensor)
-    {
-        // convert mxn tensor to nxm vector
-        vector<vector<int>> result;
-        auto tensor_accessor = tensor.accessor<int, 2>();
-
-        // Iterate over columns and rows of the tensor
-        for (int j = 0; j < tensor.size(1); ++j) {
-            vector<int> column;
-            for (int i = 0; i < tensor.size(0); ++i) {
-                column.push_back(tensor_accessor[i][j]);
-            }
-            result.push_back(column);
-        }
-
-        return result;
-    }
     Tensor BaseClassifier::predict(Tensor& X)
     {
-        auto n_models = models.size();
-        Tensor y_pred = torch::zeros({ X.size(0), n_models }, torch::kInt64);
-        for (auto i = 0; i < n_models; ++i) {
-            y_pred.index_put_({ "...", i }, models[i].predict(X));
-        }
-        auto y_pred_ = y_pred.accessor<int64_t, 2>();
-        vector<int> y_pred_final;
-        for (int i = 0; i < y_pred.size(0); ++i) {
-            vector<float> votes(states[className].size(), 0);
-            for (int j = 0; j < y_pred.size(1); ++j) {
-                votes[y_pred_[i][j]] += 1;
-            }
-            auto indices = argsort(votes);
-            y_pred_final.push_back(indices[0]);
-        }
-        return torch::tensor(y_pred_final, torch::kInt64);
+        if (!fitted) {
+            throw logic_error("Classifier has not been fitted");
+        }
+        auto m_ = X.size(0);
+        auto n_ = X.size(1);
+        vector<vector<int>> Xd(n_, vector<int>(m_, 0));
+        for (auto i = 0; i < n_; i++) {
+            auto temp = X.index({ "...", i });
+            Xd[i] = vector<int>(temp.data_ptr<int>(), temp.data_ptr<int>() + m_);
+        }
+        auto yp = model.predict(Xd);
+        auto ypred = torch::tensor(yp, torch::kInt64);
+        return ypred;
     }
     float BaseClassifier::score(Tensor& X, Tensor& y)
     {
+        if (!fitted) {
+            throw logic_error("Classifier has not been fitted");
+        }
         Tensor y_pred = predict(X);
         return (y_pred == y).sum().item<float>() / y.size(0);
     }
     float BaseClassifier::score(vector<vector<int>>& X, vector<int>& y)
     {
+        if (!fitted) {
+            throw logic_error("Classifier has not been fitted");
+        }
         auto m_ = X[0].size();
         auto n_ = X.size();
         vector<vector<int>> Xd(n_, vector<int>(m_, 0));
         for (auto i = 0; i < n_; i++) {
             Xd[i] = vector<int>(X[i].begin(), X[i].end());
         }
         return model.score(Xd, y);
     }
     vector<string> BaseClassifier::show()
     {
         return model.show();
@@ -9,12 +9,15 @@ using namespace torch;
 namespace bayesnet {
     class BaseClassifier {
     private:
+        bool fitted;
         BaseClassifier& build(vector<string>& features, string className, map<string, vector<int>>& states);
     protected:
         Network model;
         int m, n; // m: number of samples, n: number of features
         Tensor X;
+        vector<vector<int>> Xv;
         Tensor y;
+        vector<int> yv;
         Tensor dataset;
         Metrics metrics;
         vector<string> features;
@@ -24,13 +27,13 @@ namespace bayesnet {
         virtual void train() = 0;
     public:
         BaseClassifier(Network model);
         BaseClassifier& fit(Tensor& X, Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states);
+        virtual ~BaseClassifier() = default;
         BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
         void addNodes();
         Tensor predict(Tensor& X);
         float score(Tensor& X, Tensor& y);
         float score(vector<vector<int>>& X, vector<int>& y);
         vector<string> show();
-        vector<int> argsort(vector<float>& nums);
     };
 }
 #endif
@@ -1,2 +1,2 @@
-add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc)
+add_library(BayesNet utils.cc Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc)
 target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
@@ -4,17 +4,22 @@ namespace bayesnet {
     using namespace std;
     using namespace torch;

-    Ensemble::Ensemble(BaseClassifier& model) : model(model), models(vector<BaseClassifier>()), m(0), n(0), metrics(Metrics()) {}
+    Ensemble::Ensemble() : m(0), n(0), n_models(0), metrics(Metrics()) {}
     Ensemble& Ensemble::build(vector<string>& features, string className, map<string, vector<int>>& states)
     {

         dataset = torch::cat({ X, y.view({y.size(0), 1}) }, 1);
         this->features = features;
         this->className = className;
         this->states = states;
         auto n_classes = states[className].size();
         metrics = Metrics(dataset, features, className, n_classes);
+        // Build models
         train();
+        // Train models
+        n_models = models.size();
+        for (auto i = 0; i < n_models; ++i) {
+            models[i].fit(X, y, features, className, states);
+        }
         return *this;
     }
     Ensemble& Ensemble::fit(Tensor& X, Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
@@ -37,16 +42,21 @@ namespace bayesnet {
     }
     Tensor Ensemble::predict(Tensor& X)
     {
-        auto m_ = X.size(0);
-        auto n_ = X.size(1);
-        vector<vector<int>> Xd(n_, vector<int>(m_, 0));
-        for (auto i = 0; i < n_; i++) {
-            auto temp = X.index({ "...", i });
-            Xd[i] = vector<int>(temp.data_ptr<int>(), temp.data_ptr<int>() + m_);
+        Tensor y_pred = torch::zeros({ X.size(0), n_models }, torch::kInt64);
+        for (auto i = 0; i < n_models; ++i) {
+            y_pred.index_put_({ "...", i }, models[i].predict(X));
         }
-        auto yp = model.predict(Xd);
-        auto ypred = torch::tensor(yp, torch::kInt64);
-        return ypred;
+        auto y_pred_ = y_pred.accessor<int64_t, 2>();
+        vector<int> y_pred_final;
+        for (int i = 0; i < y_pred.size(0); ++i) {
+            vector<float> votes(states[className].size(), 0);
+            for (int j = 0; j < y_pred.size(1); ++j) {
+                votes[y_pred_[i][j]] += 1;
+            }
+            auto indices = argsort(votes);
+            y_pred_final.push_back(indices[0]);
+        }
+        return torch::tensor(y_pred_final, torch::kInt64);
     }
     float Ensemble::score(Tensor& X, Tensor& y)
     {
@@ -55,7 +65,11 @@ namespace bayesnet {
     }
     vector<string> Ensemble::show()
     {
-        return model.show();
+        vector<string> result;
+        for (auto i = 0; i < n_models; ++i) {
+            auto res = models[i].show();
+            result.insert(result.end(), res.begin(), res.end());
+        }
+        return result;
     }

 }
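For reference, the voting in Ensemble::predict works as follows: every model casts one vote for a class, votes are tallied per class, and argsort (descending) leaves the most-voted class at index 0. A standalone illustration of that step (not part of the commit), using the same argsort definition as utils.cc:

#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>
using namespace std;

// Same idea as bayesnet::argsort in utils.cc: indices sorted by descending value.
vector<int> argsort(vector<float>& nums)
{
    vector<int> indices(nums.size());
    iota(indices.begin(), indices.end(), 0);
    sort(indices.begin(), indices.end(), [&nums](int i, int j) { return nums[i] > nums[j]; });
    return indices;
}

int main()
{
    // Suppose 4 models predicted classes {2, 0, 2, 1} for one sample and there are 3 classes.
    vector<int> predictions = { 2, 0, 2, 1 };
    vector<float> votes(3, 0);
    for (int p : predictions) {
        votes[p] += 1;           // one vote per model
    }
    auto indices = argsort(votes);
    cout << "winner: " << indices[0] << endl;  // prints 2, the majority class
    return 0;
}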
@@ -3,15 +3,16 @@
 #include <torch/torch.h>
 #include "BaseClassifier.h"
 #include "Metrics.hpp"
+#include "utils.h"
 using namespace std;
 using namespace torch;

 namespace bayesnet {
     class Ensemble {
     private:
+        long n_models;
         Ensemble& build(vector<string>& features, string className, map<string, vector<int>>& states);
     protected:
-        BaseClassifier& model;
         vector<BaseClassifier> models;
         int m, n; // m: number of samples, n: number of features
         Tensor X;
@@ -23,7 +24,8 @@ namespace bayesnet {
         map<string, vector<int>> states;
         void virtual train() = 0;
     public:
-        Ensemble(BaseClassifier& model);
+        Ensemble();
+        virtual ~Ensemble() = default;
         Ensemble& fit(Tensor& X, Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states);
         Ensemble& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
         Tensor predict(Tensor& X);
@@ -1,6 +1,7 @@
 #ifndef KDB_H
 #define KDB_H
 #include "BaseClassifier.h"
+#include "utils.h"
 namespace bayesnet {
     using namespace std;
     using namespace torch;
@@ -2,19 +2,13 @@
 #include <mutex>
 #include "Network.h"
 namespace bayesnet {
-    Network::Network() : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8) {}
-    Network::Network(float maxT) : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
-    Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
-    Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads())
+    Network::Network() : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8), fitted(false) {}
+    Network::Network(float maxT) : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
+    Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
+    Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads()), fitted(other.fitted)
     {
         for (auto& pair : other.nodes) {
-            nodes[pair.first] = new Node(*pair.second);
+            nodes[pair.first] = make_unique<Node>(*pair.second);
         }
     }
-    Network::~Network()
-    {
-        for (auto& pair : nodes) {
-            delete pair.second;
-        }
-    }
     float Network::getmaxThreads()
@@ -32,7 +26,7 @@ namespace bayesnet {
             nodes[name]->setNumStates(numStates);
             return;
         }
-        nodes[name] = new Node(name, numStates);
+        nodes[name] = make_unique<Node>(name, numStates);
     }
     vector<string> Network::getFeatures()
     {
@@ -45,7 +39,7 @@ namespace bayesnet {
     int Network::getStates()
     {
         int result = 0;
-        for (auto node : nodes) {
+        for (auto& node : nodes) {
             result += node.second->getNumStates();
         }
         return result;
@@ -79,20 +73,20 @@ namespace bayesnet {
             throw invalid_argument("Child node " + child + " does not exist");
         }
         // Temporarily add edge to check for cycles
-        nodes[parent]->addChild(nodes[child]);
-        nodes[child]->addParent(nodes[parent]);
+        nodes[parent]->addChild(nodes[child].get());
+        nodes[child]->addParent(nodes[parent].get());
         unordered_set<string> visited;
         unordered_set<string> recStack;
         if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle
         {
             // remove problematic edge
-            nodes[parent]->removeChild(nodes[child]);
-            nodes[child]->removeParent(nodes[parent]);
+            nodes[parent]->removeChild(nodes[child].get());
+            nodes[child]->removeParent(nodes[parent].get());
             throw invalid_argument("Adding this edge forms a cycle in the graph.");
         }

     }
-    map<string, Node*>& Network::getNodes()
+    map<string, std::unique_ptr<Node>>& Network::getNodes()
     {
         return nodes;
     }
@@ -140,9 +134,8 @@ namespace bayesnet {
                 lock.unlock();
                 pair.second->computeCPT(dataset, laplaceSmoothing);
                 lock.lock();
-                nodes[pair.first] = pair.second;
+                nodes[pair.first] = std::move(pair.second);
                 lock.unlock();
             }
             lock_guard<mutex> lock(mtx);
@@ -155,10 +148,14 @@ namespace bayesnet {
         for (auto& thread : threads) {
             thread.join();
         }
+        fitted = true;
     }

     vector<int> Network::predict(const vector<vector<int>>& tsamples)
     {
+        if (!fitted) {
+            throw logic_error("You must call fit() before calling predict()");
+        }
         vector<int> predictions;
         vector<int> sample;
         for (int row = 0; row < tsamples[0].size(); ++row) {
@@ -176,6 +173,9 @@ namespace bayesnet {
     }
     vector<vector<double>> Network::predict_proba(const vector<vector<int>>& tsamples)
     {
+        if (!fitted) {
+            throw logic_error("You must call fit() before calling predict_proba()");
+        }
         vector<vector<double>> predictions;
         vector<int> sample;
         for (int row = 0; row < tsamples[0].size(); ++row) {
@@ -215,7 +215,7 @@ namespace bayesnet {
     double Network::computeFactor(map<string, int>& completeEvidence)
     {
         double result = 1.0;
-        for (auto node : getNodes()) {
+        for (auto& node : getNodes()) {
             result *= node.second->getFactorValue(completeEvidence);
         }
         return result;
@@ -249,7 +249,7 @@ namespace bayesnet {
     {
         vector<string> result;
         // Draw the network
-        for (auto node : nodes) {
+        for (auto& node : nodes) {
             string line = node.first + " -> ";
             for (auto child : node.second->getChildren()) {
                 line += child->getName() + ", ";
@@ -7,8 +7,9 @@
 namespace bayesnet {
     class Network {
     private:
-        map<string, Node*> nodes;
+        map<string, std::unique_ptr<Node>> nodes;
         map<string, vector<int>> dataset;
+        bool fitted;
         float maxThreads;
         int classNumStates;
         vector<string> features;
@@ -28,12 +29,11 @@ namespace bayesnet {
         Network(float, int);
         Network(float);
         Network(Network&);
-        ~Network();
         torch::Tensor& getSamples();
         float getmaxThreads();
         void addNode(string, int);
         void addEdge(const string, const string);
-        map<string, Node*>& getNodes();
+        map<string, std::unique_ptr<Node>>& getNodes();
         vector<string> getFeatures();
         int getStates();
         int getClassNumStates();
@@ -6,12 +6,10 @@ namespace bayesnet {
         : name(name), numStates(numStates), cpTable(torch::Tensor()), parents(vector<Node*>()), children(vector<Node*>())
     {
     }
-
     string Node::getName() const
     {
         return name;
     }
-
     void Node::addParent(Node* parent)
     {
         parents.push_back(parent);
src/utils.cc (new file): 31 lines
@@ -0,0 +1,31 @@
+#include <torch/torch.h>
+#include <vector>
+namespace bayesnet {
+    using namespace std;
+    using namespace torch;
+    vector<int> argsort(vector<float>& nums)
+    {
+        int n = nums.size();
+        vector<int> indices(n);
+        iota(indices.begin(), indices.end(), 0);
+        sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
+        return indices;
+    }
+    vector<vector<int>> tensorToVector(const Tensor& tensor)
+    {
+        // convert mxn tensor to nxm vector
+        vector<vector<int>> result;
+        auto tensor_accessor = tensor.accessor<int, 2>();
+
+        // Iterate over columns and rows of the tensor
+        for (int j = 0; j < tensor.size(1); ++j) {
+            vector<int> column;
+            for (int i = 0; i < tensor.size(0); ++i) {
+                column.push_back(tensor_accessor[i][j]);
+            }
+            result.push_back(column);
+        }
+
+        return result;
+    }
+}
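A quick check of what tensorToVector returns: an m x n tensor comes back as n columns of length m. The snippet below is illustrative only and assumes it is compiled and linked against the BayesNet library so the definition above is available.

#include <torch/torch.h>
#include <iostream>
#include <vector>

namespace bayesnet {
    // Declaration matching utils.h; the definition comes from utils.cc at link time.
    std::vector<std::vector<int>> tensorToVector(const torch::Tensor& tensor);
}

int main()
{
    // 2x3 tensor {{0,1,2},{3,4,5}} -> expect 3 columns of length 2.
    auto t = torch::arange(6, torch::kInt).reshape({ 2, 3 });
    auto cols = bayesnet::tensorToVector(t);
    std::cout << cols.size() << " columns; first column = "
              << cols[0][0] << ", " << cols[0][1] << std::endl;  // prints: 3 columns; first column = 0, 3
    return 0;
}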
src/utils.h (new file): 8 lines
@@ -0,0 +1,8 @@
+namespace bayesnet {
+    using namespace std;
+    using namespace torch;
+    vector<int> argsort(vector<float>& nums);
+
+    vector<vector<int>> tensorToVector(const Tensor& tensor);
+
+}