Complete predict & predict_proba with voting & probabilities

This commit is contained in:
2024-02-23 23:11:14 +01:00
parent 52abd2d670
commit 8477698d8d
5 changed files with 161 additions and 283 deletions

View File

@@ -10,28 +10,39 @@ namespace bayesnet {
sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
return indices;
}
template<typename T>
std::vector<std::vector<T>> tensorToVector(torch::Tensor& dtensor)
std::vector<std::vector<int>> tensorToVector(torch::Tensor& dtensor)
{
// convert mxn tensor to nxm std::vector
std::vector<std::vector<T>> result;
std::vector<std::vector<int>> result;
// Iterate over cols
for (int i = 0; i < dtensor.size(1); ++i) {
auto col_tensor = dtensor.index({ "...", i });
auto col = std::vector<T>(col_tensor.data_ptr<T>(), col_tensor.data_ptr<T>() + dtensor.size(0));
auto col = std::vector<int>(col_tensor.data_ptr<int>(), col_tensor.data_ptr<int>() + dtensor.size(0));
result.push_back(col);
}
return result;
}
torch::Tensor vectorToTensor(std::vector<std::vector<int>>& vector)
std::vector<std::vector<double>> tensorToVectorDouble(torch::Tensor& dtensor)
{
// convert nxm std::vector to mxn tensor
long int m = vector[0].size();
long int n = vector.size();
// convert mxn tensor to mxn std::vector
std::vector<std::vector<double>> result;
// Iterate over cols
for (int i = 0; i < dtensor.size(0); ++i) {
auto col_tensor = dtensor.index({ i, "..." });
auto col = std::vector<double>(col_tensor.data_ptr<float>(), col_tensor.data_ptr<float>() + dtensor.size(1));
result.push_back(col);
}
return result;
}
torch::Tensor vectorToTensor(std::vector<std::vector<int>>& vector, bool transpose)
{
// convert nxm std::vector to mxn tensor if transpose
long int m = transpose ? vector[0].size() : vector.size();
long int n = transpose ? vector.size() : vector[0].size();
auto tensor = torch::zeros({ m, n }, torch::kInt32);
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
tensor[i][j] = vector[j][i];
tensor[i][j] = transpose ? vector[j][i] : vector[i][j];
}
}
return tensor;