12 KiB
12 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | ||||||||||||||||||||||
![]() | ||||||||||||||||||||||
|
||||||||||||||||||||||
![]() |
Line data Source code 1 : // *************************************************************** 2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez 3 : // SPDX-FileType: SOURCE 4 : // SPDX-License-Identifier: MIT 5 : // *************************************************************** 6 : 7 : #ifndef CLASSIFIER_H 8 : #define CLASSIFIER_H 9 : #include <torch/torch.h> 10 : #include "bayesnet/utils/BayesMetrics.h" 11 : #include "bayesnet/network/Network.h" 12 : #include "bayesnet/BaseClassifier.h" 13 : 14 : namespace bayesnet { 15 : class Classifier : public BaseClassifier { 16 : public: 17 : Classifier(Network model); 18 606 : virtual ~Classifier() = default; 19 : Classifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override; 20 : Classifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override; 21 : Classifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override; 22 : Classifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) override; 23 : void addNodes(); 24 : int getNumberOfNodes() const override; 25 : int getNumberOfEdges() const override; 26 : int getNumberOfStates() const override; 27 : int getClassNumStates() const override; 28 : torch::Tensor predict(torch::Tensor& X) override; 29 : std::vector<int> predict(std::vector<std::vector<int>>& X) override; 30 : torch::Tensor predict_proba(torch::Tensor& X) override; 31 : std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override; 32 64 : status_t getStatus() const override { return status; } 33 48 : std::string getVersion() override { return { project_version.begin(), project_version.end() }; }; 34 : float score(torch::Tensor& X, torch::Tensor& y) override; 35 : float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override; 36 : std::vector<std::string> show() const override; 37 : std::vector<std::string> topological_order() override; 38 38 : std::vector<std::string> getNotes() const override { return notes; } 39 : std::string dump_cpt() const override; 40 : void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters 41 : protected: 42 : bool fitted; 43 : unsigned int m, n; // m: number of samples, n: number of features 44 : Network model; 45 : Metrics metrics; 46 : std::vector<std::string> features; 47 : std::string className; 48 : std::map<std::string, std::vector<int>> states; 49 : torch::Tensor dataset; // (n+1)xm tensor 50 : status_t status = NORMAL; 51 : std::vector<std::string> notes; // Used to store messages occurred during the fit process 52 : void checkFitParameters(); 53 : virtual void buildModel(const torch::Tensor& weights) = 0; 54 : void trainModel(const torch::Tensor& weights) override; 55 : void buildDataset(torch::Tensor& y); 56 : private: 57 : Classifier& build(const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights); 58 : }; 59 : } 60 : #endif 61 : 62 : 63 : 64 : 65 : |
![]() |
Generated by: LCOV version 2.0-1 |
</html>