Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 :
8 : #include "bayesnetUtils.h"
9 : namespace bayesnet {
10 : // Return the indices in descending order
11 203 : std::vector<int> argsort(std::vector<double>& nums)
12 : {
13 203 : int n = nums.size();
14 203 : std::vector<int> indices(n);
15 203 : iota(indices.begin(), indices.end(), 0);
16 4244 : sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
17 406 : return indices;
18 203 : }
19 32 : std::vector<std::vector<double>> tensorToVectorDouble(torch::Tensor& dtensor)
20 : {
21 : // convert mxn tensor to mxn std::vector
22 32 : std::vector<std::vector<double>> result;
23 : // Iterate over cols
24 8072 : for (int i = 0; i < dtensor.size(0); ++i) {
25 24120 : auto col_tensor = dtensor.index({ i, "..." });
26 8040 : auto col = std::vector<double>(col_tensor.data_ptr<float>(), col_tensor.data_ptr<float>() + dtensor.size(1));
27 8040 : result.push_back(col);
28 8040 : }
29 64 : return result;
30 8072 : }
31 40 : torch::Tensor vectorToTensor(std::vector<std::vector<int>>& vector, bool transpose)
32 : {
33 : // convert nxm std::vector to mxn tensor if transpose
34 40 : long int m = transpose ? vector[0].size() : vector.size();
35 40 : long int n = transpose ? vector.size() : vector[0].size();
36 40 : auto tensor = torch::zeros({ m, n }, torch::kInt32);
37 276 : for (int i = 0; i < m; ++i) {
38 57664 : for (int j = 0; j < n; ++j) {
39 57428 : tensor[i][j] = transpose ? vector[j][i] : vector[i][j];
40 : }
41 : }
42 80 : return tensor;
43 40 : }
44 : }
|