Files
BayesNet/docs/manual/_classifier_8h_source.html

18 KiB

<html xmlns="http://www.w3.org/1999/xhtml" lang="en-US"> <head> <script type="text/javascript" src="jquery.js"></script> <script type="text/javascript" src="dynsections.js"></script> <script type="text/javascript" src="clipboard.js"></script> <script type="text/javascript" src="navtreedata.js"></script> <script type="text/javascript" src="navtree.js"></script> <script type="text/javascript" src="resize.js"></script> <script type="text/javascript" src="cookie.js"></script> <script type="text/javascript" src="search/searchdata.js"></script> <script type="text/javascript" src="search/search.js"></script> </head>
BayesNet 1.0.5
Bayesian Network Classifiers using libtorch from scratch
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ var searchBox = new SearchBox("searchBox", "search/",'.html'); /* @license-end */ </script> <script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function() { codefold.init(0); }); /* @license-end */ </script> <script type="text/javascript" src="menudata.js"></script> <script type="text/javascript" src="menu.js"></script> <script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function() { initMenu('',true,false,'search.php','Search',true); $(function() { init_search(); }); }); /* @license-end */ </script>
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function(){initNavTree('_classifier_8h_source.html',''); initResizable(true); }); /* @license-end */ </script>
Loading...
Searching...
No Matches
Classifier.h
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
7#ifndef CLASSIFIER_H
8#define CLASSIFIER_H
9#include <torch/torch.h>
10#include "bayesnet/utils/BayesMetrics.h"
11#include "bayesnet/network/Network.h"
12#include "bayesnet/BaseClassifier.h"
13
14namespace bayesnet {
15 class Classifier : public BaseClassifier {
16 public:
17 Classifier(Network model);
18 virtual ~Classifier() = default;
19 Classifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override;
20 Classifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override;
21 Classifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override;
22 Classifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) override;
23 void addNodes();
24 int getNumberOfNodes() const override;
25 int getNumberOfEdges() const override;
26 int getNumberOfStates() const override;
27 int getClassNumStates() const override;
28 torch::Tensor predict(torch::Tensor& X) override;
29 std::vector<int> predict(std::vector<std::vector<int>>& X) override;
30 torch::Tensor predict_proba(torch::Tensor& X) override;
31 std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
32 status_t getStatus() const override { return status; }
33 std::string getVersion() override { return { project_version.begin(), project_version.end() }; };
34 float score(torch::Tensor& X, torch::Tensor& y) override;
35 float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;
36 std::vector<std::string> show() const override;
37 std::vector<std::string> topological_order() override;
38 std::vector<std::string> getNotes() const override { return notes; }
39 std::string dump_cpt() const override;
40 void setHyperparameters(const nlohmann::json& hyperparameters) override; // Default implementation, for classifiers that don't have hyperparameters
41 protected:
42 bool fitted;
43 unsigned int m, n; // m: number of samples, n: number of features
44 Network model;
45 Metrics metrics;
46 std::vector<std::string> features;
47 std::string className;
48 std::map<std::string, std::vector<int>> states;
49 torch::Tensor dataset; // (n+1)xm tensor
50 status_t status = NORMAL;
51 std::vector<std::string> notes; // Used to store messages that occurred during the fit process
52 void checkFitParameters();
53 virtual void buildModel(const torch::Tensor& weights) = 0;
54 void trainModel(const torch::Tensor& weights) override;
55 void buildDataset(torch::Tensor& y);
56 private:
57 Classifier& build(const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
58 };
59}
60#endif
61
62
63
64
65
</html>