From a60b06e2f26b647239064d0c76d85f7ceafd55d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Wed, 12 Jul 2023 01:05:24 +0200 Subject: [PATCH] refactor to use in python --- sample/main.cc | 7 ++++++- sample/test.cc | 3 +++ src/Metrics.cc | 23 +++++++++++++++++++++-- src/Metrics.hpp | 5 +++-- 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/sample/main.cc b/sample/main.cc index c1b1e1f..69b8a5c 100644 --- a/sample/main.cc +++ b/sample/main.cc @@ -232,6 +232,11 @@ int main(int argc, char** argv) unsigned int nthreads = std::thread::hardware_concurrency(); cout << "Computer has " << nthreads << " cores." << endl; auto metrics = bayesnet::Metrics(network.getSamples(), features, className, network.getClassNumStates()); - cout << "conditionalEdgeWeight " << endl << metrics.conditionalEdgeWeight() << endl; + cout << "conditionalEdgeWeight " << endl; + auto conditional = metrics.conditionalEdgeWeights(); + cout << conditional << endl; + long m = features.size() + 1; + auto matrix = torch::from_blob(conditional.data(), { m, m }); + cout << matrix << endl; return 0; } \ No newline at end of file diff --git a/sample/test.cc b/sample/test.cc index e6b946f..1757026 100644 --- a/sample/test.cc +++ b/sample/test.cc @@ -167,6 +167,9 @@ int main() } // Print the resulting 3x3 tensor std::cout << tensor_3x3 << std::endl; + vector v = { 1,2,3,4,5 }; + torch::Tensor t = torch::tensor(v); + cout << t << endl; diff --git a/src/Metrics.cc b/src/Metrics.cc index 40e485f..7448f46 100644 --- a/src/Metrics.cc +++ b/src/Metrics.cc @@ -1,12 +1,30 @@ #include "Metrics.hpp" using namespace std; namespace bayesnet { + vector linearize(const vector>& vec_vec) + { + vector vec; + for (const auto& v : vec_vec) { + for (auto d : v) { + vec.push_back(d); + } + } + return vec; + } Metrics::Metrics(torch::Tensor& samples, vector& features, string& className, int classNumStates) : samples(samples) , features(features) , className(className) , classNumStates(classNumStates) { + + } + Metrics::Metrics(vector>& vsamples, int m, int n, vector& features, string& className, int classNumStates) + : features(features) + , className(className) + , classNumStates(classNumStates) + { + samples = torch::from_blob(linearize(vsamples).data(), { m, n }); } vector> Metrics::doCombinations(const vector& source) { @@ -19,7 +37,7 @@ namespace bayesnet { } return result; } - torch::Tensor Metrics::conditionalEdgeWeight() + vector Metrics::conditionalEdgeWeights() { auto result = vector(); auto source = vector(features); @@ -54,7 +72,8 @@ namespace bayesnet { matrix[x][y] = result[i]; matrix[y][x] = result[i]; } - return matrix; + std::vector v(matrix.data_ptr(), matrix.data_ptr() + matrix.numel()); + return v; } double Metrics::entropy(torch::Tensor& feature) { diff --git a/src/Metrics.hpp b/src/Metrics.hpp index c37e951..58ace76 100644 --- a/src/Metrics.hpp +++ b/src/Metrics.hpp @@ -7,7 +7,7 @@ using namespace std; namespace bayesnet { class Metrics { private: - torch::Tensor& samples; + torch::Tensor samples; vector& features; string& className; int classNumStates; @@ -17,7 +17,8 @@ namespace bayesnet { double mutualInformation(torch::Tensor&, torch::Tensor&); public: Metrics(torch::Tensor&, vector&, string&, int); - torch::Tensor conditionalEdgeWeight(); + Metrics(vector>&, int, int, vector&, string&, int); + vector conditionalEdgeWeights(); }; } #endif \ No newline at end of file