BayesNet 1.0.5
Bayesian Network Classifiers using libtorch from scratch
Loading...
Searching...
No Matches
Classifier.cc
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
7#include <sstream>
8#include "bayesnet/utils/bayesnetUtils.h"
9#include "Classifier.h"
10
11namespace bayesnet {
12 Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
// Message carried by the std::logic_error thrown when a prediction or scoring
// method is called before the classifier has been fitted.
const std::string CLASSIFIER_NOT_FITTED{ "Classifier has not been fitted" };
14 Classifier& Classifier::build(const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
15 {
16 this->features = features;
17 this->className = className;
18 this->states = states;
19 m = dataset.size(1);
20 n = features.size();
21 checkFitParameters();
22 auto n_classes = states.at(className).size();
23 metrics = Metrics(dataset, features, className, n_classes);
24 model.initialize();
25 buildModel(weights);
26 trainModel(weights);
27 fitted = true;
28 return *this;
29 }
30 void Classifier::buildDataset(torch::Tensor& ytmp)
31 {
32 try {
33 auto yresized = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1);
34 dataset = torch::cat({ dataset, yresized }, 0);
35 }
36 catch (const std::exception& e) {
37 std::stringstream oss;
38 oss << "* Error in X and y dimensions *\n";
39 oss << "X dimensions: " << dataset.sizes() << "\n";
40 oss << "y dimensions: " << ytmp.sizes();
41 throw std::runtime_error(oss.str());
42 }
43 }
44 void Classifier::trainModel(const torch::Tensor& weights)
45 {
46 model.fit(dataset, weights, features, className, states);
47 }
48 // X is nxm where n is the number of features and m the number of samples
49 Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
50 {
51 dataset = X;
52 buildDataset(y);
53 const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
54 return build(features, className, states, weights);
55 }
56 // X is nxm where n is the number of features and m the number of samples
57 Classifier& Classifier::fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
58 {
59 dataset = torch::zeros({ static_cast<int>(X.size()), static_cast<int>(X[0].size()) }, torch::kInt32);
60 for (int i = 0; i < X.size(); ++i) {
61 dataset.index_put_({ i, "..." }, torch::tensor(X[i], torch::kInt32));
62 }
63 auto ytmp = torch::tensor(y, torch::kInt32);
64 buildDataset(ytmp);
65 const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
66 return build(features, className, states, weights);
67 }
68 Classifier& Classifier::fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
69 {
70 this->dataset = dataset;
71 const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
72 return build(features, className, states, weights);
73 }
74 Classifier& Classifier::fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
75 {
76 this->dataset = dataset;
77 return build(features, className, states, weights);
78 }
79 void Classifier::checkFitParameters()
80 {
81 if (torch::is_floating_point(dataset)) {
82 throw std::invalid_argument("dataset (X, y) must be of type Integer");
83 }
84 if (dataset.size(0) - 1 != features.size()) {
85 throw std::invalid_argument("Classifier: X " + std::to_string(dataset.size(0) - 1) + " and features " + std::to_string(features.size()) + " must have the same number of features");
86 }
87 if (states.find(className) == states.end()) {
88 throw std::invalid_argument("class name not found in states");
89 }
90 for (auto feature : features) {
91 if (states.find(feature) == states.end()) {
92 throw std::invalid_argument("feature [" + feature + "] not found in states");
93 }
94 }
95 }
96 torch::Tensor Classifier::predict(torch::Tensor& X)
97 {
98 if (!fitted) {
99 throw std::logic_error(CLASSIFIER_NOT_FITTED);
100 }
101 return model.predict(X);
102 }
103 std::vector<int> Classifier::predict(std::vector<std::vector<int>>& X)
104 {
105 if (!fitted) {
106 throw std::logic_error(CLASSIFIER_NOT_FITTED);
107 }
108 auto m_ = X[0].size();
109 auto n_ = X.size();
110 std::vector<std::vector<int>> Xd(n_, std::vector<int>(m_, 0));
111 for (auto i = 0; i < n_; i++) {
112 Xd[i] = std::vector<int>(X[i].begin(), X[i].end());
113 }
114 auto yp = model.predict(Xd);
115 return yp;
116 }
117 torch::Tensor Classifier::predict_proba(torch::Tensor& X)
118 {
119 if (!fitted) {
120 throw std::logic_error(CLASSIFIER_NOT_FITTED);
121 }
122 return model.predict_proba(X);
123 }
124 std::vector<std::vector<double>> Classifier::predict_proba(std::vector<std::vector<int>>& X)
125 {
126 if (!fitted) {
127 throw std::logic_error(CLASSIFIER_NOT_FITTED);
128 }
129 auto m_ = X[0].size();
130 auto n_ = X.size();
131 std::vector<std::vector<int>> Xd(n_, std::vector<int>(m_, 0));
132 // Convert to nxm vector
133 for (auto i = 0; i < n_; i++) {
134 Xd[i] = std::vector<int>(X[i].begin(), X[i].end());
135 }
136 auto yp = model.predict_proba(Xd);
137 return yp;
138 }
139 float Classifier::score(torch::Tensor& X, torch::Tensor& y)
140 {
141 torch::Tensor y_pred = predict(X);
142 return (y_pred == y).sum().item<float>() / y.size(0);
143 }
144 float Classifier::score(std::vector<std::vector<int>>& X, std::vector<int>& y)
145 {
146 if (!fitted) {
147 throw std::logic_error(CLASSIFIER_NOT_FITTED);
148 }
149 return model.score(X, y);
150 }
151 std::vector<std::string> Classifier::show() const
152 {
153 return model.show();
154 }
155 void Classifier::addNodes()
156 {
157 // Add all nodes to the network
158 for (const auto& feature : features) {
159 model.addNode(feature);
160 }
161 model.addNode(className);
162 }
163 int Classifier::getNumberOfNodes() const
164 {
165 // Features does not include class
166 return fitted ? model.getFeatures().size() : 0;
167 }
168 int Classifier::getNumberOfEdges() const
169 {
170 return fitted ? model.getNumEdges() : 0;
171 }
172 int Classifier::getNumberOfStates() const
173 {
174 return fitted ? model.getStates() : 0;
175 }
176 int Classifier::getClassNumStates() const
177 {
178 return fitted ? model.getClassNumStates() : 0;
179 }
180 std::vector<std::string> Classifier::topological_order()
181 {
182 return model.topological_sort();
183 }
184 std::string Classifier::dump_cpt() const
185 {
186 return model.dump_cpt();
187 }
188 void Classifier::setHyperparameters(const nlohmann::json& hyperparameters)
189 {
190 if (!hyperparameters.empty()) {
191 throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
192 }
193 }
194}