29 KiB
29 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | ||||||||||||||||||||||
![]() | ||||||||||||||||||||||
|
||||||||||||||||||||||
![]() |
Line data Source code 1 : // *************************************************************** 2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez 3 : // SPDX-FileType: SOURCE 4 : // SPDX-License-Identifier: MIT 5 : // *************************************************************** 6 : 7 : #include <sstream> 8 : #include "bayesnet/utils/bayesnetUtils.h" 9 : #include "Classifier.h" 10 : 11 : namespace bayesnet { 12 4750 : Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {} 13 : const std::string CLASSIFIER_NOT_FITTED = "Classifier has not been fitted"; 14 3413 : Classifier& Classifier::build(const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) 15 : { 16 3413 : this->features = features; 17 3413 : this->className = className; 18 3413 : this->states = states; 19 3413 : m = dataset.size(1); 20 3413 : n = features.size(); 21 3413 : checkFitParameters(); 22 3325 : auto n_classes = states.at(className).size(); 23 3325 : metrics = Metrics(dataset, features, className, n_classes); 24 3325 : model.initialize(); 25 3325 : buildModel(weights); 26 3325 : trainModel(weights); 27 3277 : fitted = true; 28 3277 : return *this; 29 : } 30 888 : void Classifier::buildDataset(torch::Tensor& ytmp) 31 : { 32 : try { 33 888 : auto yresized = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1); 34 2752 : dataset = torch::cat({ dataset, yresized }, 0); 35 888 : } 36 44 : catch (const std::exception& e) { 37 44 : std::stringstream oss; 38 44 : oss << "* Error in X and y dimensions *\n"; 39 44 : oss << "X dimensions: " << dataset.sizes() << "\n"; 40 44 : oss << "y dimensions: " << ytmp.sizes(); 41 44 : throw std::runtime_error(oss.str()); 42 88 : } 43 1776 : } 44 2951 : void Classifier::trainModel(const torch::Tensor& weights) 45 : { 46 2951 : model.fit(dataset, weights, features, className, states); 47 2951 : } 48 : // X is nxm where n is the number of features and m the number of samples 49 322 : Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) 50 : { 51 322 : dataset = X; 52 322 : buildDataset(y); 53 300 : const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble); 54 512 : return build(features, className, states, weights); 55 300 : } 56 : // X is nxm where n is the number of features and m the number of samples 57 360 : Classifier& Classifier::fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) 58 : { 59 360 : dataset = torch::zeros({ static_cast<int>(X.size()), static_cast<int>(X[0].size()) }, torch::kInt32); 60 5883 : for (int i = 0; i < X.size(); ++i) { 61 22092 : dataset.index_put_({ i, "..." }, torch::tensor(X[i], torch::kInt32)); 62 : } 63 360 : auto ytmp = torch::tensor(y, torch::kInt32); 64 360 : buildDataset(ytmp); 65 338 : const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble); 66 628 : return build(features, className, states, weights); 67 5931 : } 68 1089 : Classifier& Classifier::fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) 69 : { 70 1089 : this->dataset = dataset; 71 1089 : const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble); 72 2178 : return build(features, className, states, weights); 73 1089 : } 74 1686 : Classifier& Classifier::fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) 75 : { 76 1686 : this->dataset = dataset; 77 1686 : return build(features, className, states, weights); 78 : } 79 3413 : void Classifier::checkFitParameters() 80 : { 81 3413 : if (torch::is_floating_point(dataset)) { 82 22 : throw std::invalid_argument("dataset (X, y) must be of type Integer"); 83 : } 84 3391 : if (dataset.size(0) - 1 != features.size()) { 85 22 : throw std::invalid_argument("Classifier: X " + std::to_string(dataset.size(0) - 1) + " and features " + std::to_string(features.size()) + " must have the same number of features"); 86 : } 87 3369 : if (states.find(className) == states.end()) { 88 22 : throw std::invalid_argument("class name not found in states"); 89 : } 90 124581 : for (auto feature : features) { 91 121256 : if (states.find(feature) == states.end()) { 92 22 : throw std::invalid_argument("feature [" + feature + "] not found in states"); 93 : } 94 121256 : } 95 3325 : } 96 3262 : torch::Tensor Classifier::predict(torch::Tensor& X) 97 : { 98 3262 : if (!fitted) { 99 44 : throw std::logic_error(CLASSIFIER_NOT_FITTED); 100 : } 101 3218 : return model.predict(X); 102 : } 103 44 : std::vector<int> Classifier::predict(std::vector<std::vector<int>>& X) 104 : { 105 44 : if (!fitted) { 106 22 : throw std::logic_error(CLASSIFIER_NOT_FITTED); 107 : } 108 22 : auto m_ = X[0].size(); 109 22 : auto n_ = X.size(); 110 22 : std::vector<std::vector<int>> Xd(n_, std::vector<int>(m_, 0)); 111 110 : for (auto i = 0; i < n_; i++) { 112 176 : Xd[i] = std::vector<int>(X[i].begin(), X[i].end()); 113 : } 114 22 : auto yp = model.predict(Xd); 115 44 : return yp; 116 22 : } 117 3562 : torch::Tensor Classifier::predict_proba(torch::Tensor& X) 118 : { 119 3562 : if (!fitted) { 120 22 : throw std::logic_error(CLASSIFIER_NOT_FITTED); 121 : } 122 3540 : return model.predict_proba(X); 123 : } 124 766 : std::vector<std::vector<double>> Classifier::predict_proba(std::vector<std::vector<int>>& X) 125 : { 126 766 : if (!fitted) { 127 22 : throw std::logic_error(CLASSIFIER_NOT_FITTED); 128 : } 129 744 : auto m_ = X[0].size(); 130 744 : auto n_ = X.size(); 131 744 : std::vector<std::vector<int>> Xd(n_, std::vector<int>(m_, 0)); 132 : // Convert to nxm vector 133 9722 : for (auto i = 0; i < n_; i++) { 134 17956 : Xd[i] = std::vector<int>(X[i].begin(), X[i].end()); 135 : } 136 744 : auto yp = model.predict_proba(Xd); 137 1488 : return yp; 138 744 : } 139 308 : float Classifier::score(torch::Tensor& X, torch::Tensor& y) 140 : { 141 308 : torch::Tensor y_pred = predict(X); 142 572 : return (y_pred == y).sum().item<float>() / y.size(0); 143 286 : } 144 44 : float Classifier::score(std::vector<std::vector<int>>& X, std::vector<int>& y) 145 : { 146 44 : if (!fitted) { 147 22 : throw std::logic_error(CLASSIFIER_NOT_FITTED); 148 : } 149 22 : return model.score(X, y); 150 : } 151 66 : std::vector<std::string> Classifier::show() const 152 : { 153 66 : return model.show(); 154 : } 155 2951 : void Classifier::addNodes() 156 : { 157 : // Add all nodes to the network 158 116009 : for (const auto& feature : features) { 159 113058 : model.addNode(feature); 160 : } 161 2951 : model.addNode(className); 162 2951 : } 163 475 : int Classifier::getNumberOfNodes() const 164 : { 165 : // Features does not include class 166 475 : return fitted ? model.getFeatures().size() : 0; 167 : } 168 475 : int Classifier::getNumberOfEdges() const 169 : { 170 475 : return fitted ? model.getNumEdges() : 0; 171 : } 172 66 : int Classifier::getNumberOfStates() const 173 : { 174 66 : return fitted ? model.getStates() : 0; 175 : } 176 877 : int Classifier::getClassNumStates() const 177 : { 178 877 : return fitted ? model.getClassNumStates() : 0; 179 : } 180 11 : std::vector<std::string> Classifier::topological_order() 181 : { 182 11 : return model.topological_sort(); 183 : } 184 11 : std::string Classifier::dump_cpt() const 185 : { 186 11 : return model.dump_cpt(); 187 : } 188 231 : void Classifier::setHyperparameters(const nlohmann::json& hyperparameters) 189 : { 190 231 : if (!hyperparameters.empty()) { 191 22 : throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump()); 192 : } 193 209 : } 194 : } |
![]() |
Generated by: LCOV version 2.0-1 |
</html>