Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #ifndef CLASSIFIER_H
8 : #define CLASSIFIER_H
9 : #include <torch/torch.h>
10 : #include "bayesnet/utils/BayesMetrics.h"
11 : #include "bayesnet/network/Network.h"
12 : #include "bayesnet/BaseClassifier.h"
13 :
14 : namespace bayesnet {
15 : class Classifier : public BaseClassifier {
16 : public:
17 : Classifier(Network model);
18 1680 : virtual ~Classifier() = default;
19 : Classifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override;
20 : Classifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override;
21 : Classifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override;
22 : Classifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) override;
23 : void addNodes();
24 : int getNumberOfNodes() const override;
25 : int getNumberOfEdges() const override;
26 : int getNumberOfStates() const override;
27 : int getClassNumStates() const override;
28 : torch::Tensor predict(torch::Tensor& X) override;
29 : std::vector<int> predict(std::vector<std::vector<int>>& X) override;
30 : torch::Tensor predict_proba(torch::Tensor& X) override;
31 : std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
32 128 : status_t getStatus() const override { return status; }
33 96 : std::string getVersion() override { return { project_version.begin(), project_version.end() }; };
34 : float score(torch::Tensor& X, torch::Tensor& y) override;
35 : float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;
36 : std::vector<std::string> show() const override;
37 : std::vector<std::string> topological_order() override;
38 80 : std::vector<std::string> getNotes() const override { return notes; }
39 : std::string dump_cpt() const override;
40 : void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
41 : protected:
42 : bool fitted;
43 : unsigned int m, n; // m: number of samples, n: number of features
44 : Network model;
45 : Metrics metrics;
46 : std::vector<std::string> features;
47 : std::string className;
48 : std::map<std::string, std::vector<int>> states;
49 : torch::Tensor dataset; // (n+1)xm tensor
50 : status_t status = NORMAL;
51 : std::vector<std::string> notes; // Used to store messages occurred during the fit process
52 : void checkFitParameters();
53 : virtual void buildModel(const torch::Tensor& weights) = 0;
54 : void trainModel(const torch::Tensor& weights) override;
55 : void buildDataset(torch::Tensor& y);
56 : private:
57 : Classifier& build(const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
58 : };
59 : }
60 : #endif
61 :
62 :
63 :
64 :
65 :
|