Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include <sstream>
8 : #include "bayesnet/utils/bayesnetUtils.h"
9 : #include "Classifier.h"
10 :
11 : namespace bayesnet {
12 4750 : Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
13 : const std::string CLASSIFIER_NOT_FITTED = "Classifier has not been fitted";
14 3413 : Classifier& Classifier::build(const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
15 : {
16 3413 : this->features = features;
17 3413 : this->className = className;
18 3413 : this->states = states;
19 3413 : m = dataset.size(1);
20 3413 : n = features.size();
21 3413 : checkFitParameters();
22 3325 : auto n_classes = states.at(className).size();
23 3325 : metrics = Metrics(dataset, features, className, n_classes);
24 3325 : model.initialize();
25 3325 : buildModel(weights);
26 3325 : trainModel(weights);
27 3277 : fitted = true;
28 3277 : return *this;
29 : }
30 888 : void Classifier::buildDataset(torch::Tensor& ytmp)
31 : {
32 : try {
33 888 : auto yresized = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1);
34 2752 : dataset = torch::cat({ dataset, yresized }, 0);
35 888 : }
36 44 : catch (const std::exception& e) {
37 44 : std::stringstream oss;
38 44 : oss << "* Error in X and y dimensions *\n";
39 44 : oss << "X dimensions: " << dataset.sizes() << "\n";
40 44 : oss << "y dimensions: " << ytmp.sizes();
41 44 : throw std::runtime_error(oss.str());
42 88 : }
43 1776 : }
44 2951 : void Classifier::trainModel(const torch::Tensor& weights)
45 : {
46 2951 : model.fit(dataset, weights, features, className, states);
47 2951 : }
48 : // X is nxm where n is the number of features and m the number of samples
49 322 : Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
50 : {
51 322 : dataset = X;
52 322 : buildDataset(y);
53 300 : const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
54 512 : return build(features, className, states, weights);
55 300 : }
56 : // X is nxm where n is the number of features and m the number of samples
57 360 : Classifier& Classifier::fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
58 : {
59 360 : dataset = torch::zeros({ static_cast<int>(X.size()), static_cast<int>(X[0].size()) }, torch::kInt32);
60 5883 : for (int i = 0; i < X.size(); ++i) {
61 22092 : dataset.index_put_({ i, "..." }, torch::tensor(X[i], torch::kInt32));
62 : }
63 360 : auto ytmp = torch::tensor(y, torch::kInt32);
64 360 : buildDataset(ytmp);
65 338 : const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
66 628 : return build(features, className, states, weights);
67 5931 : }
68 1089 : Classifier& Classifier::fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
69 : {
70 1089 : this->dataset = dataset;
71 1089 : const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
72 2178 : return build(features, className, states, weights);
73 1089 : }
74 1686 : Classifier& Classifier::fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
75 : {
76 1686 : this->dataset = dataset;
77 1686 : return build(features, className, states, weights);
78 : }
79 3413 : void Classifier::checkFitParameters()
80 : {
81 3413 : if (torch::is_floating_point(dataset)) {
82 22 : throw std::invalid_argument("dataset (X, y) must be of type Integer");
83 : }
84 3391 : if (dataset.size(0) - 1 != features.size()) {
85 22 : throw std::invalid_argument("Classifier: X " + std::to_string(dataset.size(0) - 1) + " and features " + std::to_string(features.size()) + " must have the same number of features");
86 : }
87 3369 : if (states.find(className) == states.end()) {
88 22 : throw std::invalid_argument("class name not found in states");
89 : }
90 124581 : for (auto feature : features) {
91 121256 : if (states.find(feature) == states.end()) {
92 22 : throw std::invalid_argument("feature [" + feature + "] not found in states");
93 : }
94 121256 : }
95 3325 : }
96 3262 : torch::Tensor Classifier::predict(torch::Tensor& X)
97 : {
98 3262 : if (!fitted) {
99 44 : throw std::logic_error(CLASSIFIER_NOT_FITTED);
100 : }
101 3218 : return model.predict(X);
102 : }
103 44 : std::vector<int> Classifier::predict(std::vector<std::vector<int>>& X)
104 : {
105 44 : if (!fitted) {
106 22 : throw std::logic_error(CLASSIFIER_NOT_FITTED);
107 : }
108 22 : auto m_ = X[0].size();
109 22 : auto n_ = X.size();
110 22 : std::vector<std::vector<int>> Xd(n_, std::vector<int>(m_, 0));
111 110 : for (auto i = 0; i < n_; i++) {
112 176 : Xd[i] = std::vector<int>(X[i].begin(), X[i].end());
113 : }
114 22 : auto yp = model.predict(Xd);
115 44 : return yp;
116 22 : }
117 3562 : torch::Tensor Classifier::predict_proba(torch::Tensor& X)
118 : {
119 3562 : if (!fitted) {
120 22 : throw std::logic_error(CLASSIFIER_NOT_FITTED);
121 : }
122 3540 : return model.predict_proba(X);
123 : }
124 766 : std::vector<std::vector<double>> Classifier::predict_proba(std::vector<std::vector<int>>& X)
125 : {
126 766 : if (!fitted) {
127 22 : throw std::logic_error(CLASSIFIER_NOT_FITTED);
128 : }
129 744 : auto m_ = X[0].size();
130 744 : auto n_ = X.size();
131 744 : std::vector<std::vector<int>> Xd(n_, std::vector<int>(m_, 0));
132 : // Convert to nxm vector
133 9722 : for (auto i = 0; i < n_; i++) {
134 17956 : Xd[i] = std::vector<int>(X[i].begin(), X[i].end());
135 : }
136 744 : auto yp = model.predict_proba(Xd);
137 1488 : return yp;
138 744 : }
139 308 : float Classifier::score(torch::Tensor& X, torch::Tensor& y)
140 : {
141 308 : torch::Tensor y_pred = predict(X);
142 572 : return (y_pred == y).sum().item<float>() / y.size(0);
143 286 : }
144 44 : float Classifier::score(std::vector<std::vector<int>>& X, std::vector<int>& y)
145 : {
146 44 : if (!fitted) {
147 22 : throw std::logic_error(CLASSIFIER_NOT_FITTED);
148 : }
149 22 : return model.score(X, y);
150 : }
151 66 : std::vector<std::string> Classifier::show() const
152 : {
153 66 : return model.show();
154 : }
155 2951 : void Classifier::addNodes()
156 : {
157 : // Add all nodes to the network
158 116009 : for (const auto& feature : features) {
159 113058 : model.addNode(feature);
160 : }
161 2951 : model.addNode(className);
162 2951 : }
163 475 : int Classifier::getNumberOfNodes() const
164 : {
165 : // Features does not include class
166 475 : return fitted ? model.getFeatures().size() : 0;
167 : }
168 475 : int Classifier::getNumberOfEdges() const
169 : {
170 475 : return fitted ? model.getNumEdges() : 0;
171 : }
172 66 : int Classifier::getNumberOfStates() const
173 : {
174 66 : return fitted ? model.getStates() : 0;
175 : }
176 877 : int Classifier::getClassNumStates() const
177 : {
178 877 : return fitted ? model.getClassNumStates() : 0;
179 : }
180 11 : std::vector<std::string> Classifier::topological_order()
181 : {
182 11 : return model.topological_sort();
183 : }
184 11 : std::string Classifier::dump_cpt() const
185 : {
186 11 : return model.dump_cpt();
187 : }
188 231 : void Classifier::setHyperparameters(const nlohmann::json& hyperparameters)
189 : {
190 231 : if (!hyperparameters.empty()) {
191 22 : throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
192 : }
193 209 : }
194 : }
|