diff --git a/sample/main.cc b/sample/main.cc index 69b8a5c..1b412df 100644 --- a/sample/main.cc +++ b/sample/main.cc @@ -231,6 +231,7 @@ int main(int argc, char** argv) cout << "BayesNet version: " << network.version() << endl; unsigned int nthreads = std::thread::hardware_concurrency(); cout << "Computer has " << nthreads << " cores." << endl; + cout << "****************** First ******************" << endl; auto metrics = bayesnet::Metrics(network.getSamples(), features, className, network.getClassNumStates()); cout << "conditionalEdgeWeight " << endl; auto conditional = metrics.conditionalEdgeWeights(); @@ -238,5 +239,13 @@ int main(int argc, char** argv) long m = features.size() + 1; auto matrix = torch::from_blob(conditional.data(), { m, m }); cout << matrix << endl; + cout << "****************** Second ******************" << endl; + auto metrics2 = bayesnet::Metrics(Xd, y, features, className, network.getClassNumStates()); + cout << "conditionalEdgeWeight " << endl; + auto conditional2 = metrics2.conditionalEdgeWeights(); + cout << conditional2 << endl; + long m2 = features.size() + 1; + auto matrix2 = torch::from_blob(conditional2.data(), { m, m }); + cout << matrix2 << endl; return 0; } \ No newline at end of file diff --git a/src/Metrics.cc b/src/Metrics.cc index 7448f46..d782d56 100644 --- a/src/Metrics.cc +++ b/src/Metrics.cc @@ -1,30 +1,23 @@ #include "Metrics.hpp" using namespace std; namespace bayesnet { - vector linearize(const vector>& vec_vec) - { - vector vec; - for (const auto& v : vec_vec) { - for (auto d : v) { - vec.push_back(d); - } - } - return vec; - } Metrics::Metrics(torch::Tensor& samples, vector& features, string& className, int classNumStates) : samples(samples) , features(features) , className(className) , classNumStates(classNumStates) { - } - Metrics::Metrics(vector>& vsamples, int m, int n, vector& features, string& className, int classNumStates) + Metrics::Metrics(vector>& vsamples, vector& labels, vector& features, string& className, int classNumStates) : features(features) , className(className) , classNumStates(classNumStates) { - samples = torch::from_blob(linearize(vsamples).data(), { m, n }); + samples = torch::zeros({ static_cast(vsamples[0].size()), static_cast(vsamples.size() + 1) }, torch::kInt64); + for (int i = 0; i < vsamples.size(); ++i) { + samples.index_put_({ "...", i }, torch::tensor(vsamples[i], torch::kInt64)); + } + samples.index_put_({ "...", -1 }, torch::tensor(labels, torch::kInt64)); } vector> Metrics::doCombinations(const vector& source) { diff --git a/src/Metrics.hpp b/src/Metrics.hpp index 58ace76..79ba46a 100644 --- a/src/Metrics.hpp +++ b/src/Metrics.hpp @@ -17,7 +17,7 @@ namespace bayesnet { double mutualInformation(torch::Tensor&, torch::Tensor&); public: Metrics(torch::Tensor&, vector&, string&, int); - Metrics(vector>&, int, int, vector&, string&, int); + Metrics(vector>&, vector&, vector&, string&, int); vector conditionalEdgeWeights(); }; }