LCOV - code coverage report
Current view: top level - bayesnet/classifiers - Classifier.h (source / functions) Coverage Total Hit
Test: BayesNet Coverage Report Lines: 100.0 % 4 4
Test Date: 2024-05-06 17:54:04 Functions: 100.0 % 4 4
Legend: Lines: hit not hit

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #ifndef CLASSIFIER_H
       8              : #define CLASSIFIER_H
       9              : #include <torch/torch.h>
      10              : #include "bayesnet/utils/BayesMetrics.h"
      11              : #include "bayesnet/network/Network.h"
      12              : #include "bayesnet/BaseClassifier.h"
      13              : 
      14              : namespace bayesnet {
      15              :     class Classifier : public BaseClassifier { // Abstract (buildModel is pure virtual) BaseClassifier implementation that holds a Network model
      16              :     public:
      17              :         Classifier(Network model); // Takes the Network by value. NOTE(review): not marked explicit, so a Network converts implicitly - confirm this is intended
      18         1680 :         virtual ~Classifier() = default; // Virtual so derived classifiers are destroyed correctly through a base pointer
      19              :         Classifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override; // fit from std::vector data; returns a Classifier& (presumably *this for chaining - confirm in the .cc)
      20              :         Classifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override; // fit from separate X and y tensors
      21              :         Classifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override; // fit from a single dataset tensor (presumably features plus class labels, see the dataset member - confirm)
      22              :         Classifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) override; // same as the previous overload with an extra weights tensor (presumably per-sample weights - confirm)
      23              :         void addNodes(); // Non-virtual; defined outside this header (presumably registers the feature/class nodes in model - confirm)
      24              :         int getNumberOfNodes() const override;
      25              :         int getNumberOfEdges() const override;
      26              :         int getNumberOfStates() const override;
      27              :         int getClassNumStates() const override;
      28              :         torch::Tensor predict(torch::Tensor& X) override; // tensor in, tensor out
      29              :         std::vector<int> predict(std::vector<std::vector<int>>& X) override; // std::vector counterpart of predict
      30              :         torch::Tensor predict_proba(torch::Tensor& X) override; // tensor in, tensor out
      31              :         std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override; // std::vector counterpart of predict_proba
      32          128 :         status_t getStatus() const override { return status; } // Returns the current value of the status member
      33           96 :         std::string getVersion() override { return { project_version.begin(), project_version.end() }; }; // Builds a std::string from the character range of project_version (declared outside this header)
      34              :         float score(torch::Tensor& X, torch::Tensor& y) override; // presumably the accuracy of predict(X) against y - confirm in the .cc
      35              :         float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override; // std::vector counterpart of score
      36              :         std::vector<std::string> show() const override;
      37              :         std::vector<std::string> topological_order()  override; // Not const, unlike show()
      38           80 :         std::vector<std::string> getNotes() const override { return notes; } // Returns a copy of the notes member
      39              :         std::string dump_cpt() const override; // presumably a textual dump of the conditional probability tables - confirm
      40              :         void setHyperparameters(const nlohmann::json& hyperparameters) override; // For classifiers that don't have hyperparameters
      41              :     protected:
      42              :         bool fitted; // No in-class initializer. NOTE(review): confirm the constructor sets it before any read
      43              :         unsigned int m, n; // m: number of samples, n: number of features (no in-class initializers)
      44              :         Network model; // Network held by value
      45              :         Metrics metrics;
      46              :         std::vector<std::string> features; // Same type as the features argument of fit
      47              :         std::string className; // Same type as the className argument of fit
      48              :         std::map<std::string, std::vector<int>> states; // Same type as the states argument of fit
      49              :         torch::Tensor dataset; // (n+1)xm tensor
      50              :         status_t status = NORMAL; // Starts as NORMAL; exposed through getStatus()
      51              :         std::vector<std::string> notes; // Stores messages that occurred during the fit process; exposed through getNotes()
      52              :         void checkFitParameters();
      53              :         virtual void buildModel(const torch::Tensor& weights) = 0; // Pure virtual: every concrete classifier must provide it
      54              :         void trainModel(const torch::Tensor& weights) override;
      55              :         void buildDataset(torch::Tensor& y); // presumably combines the stored data with y into dataset - confirm
      56              :     private:
      57              :         Classifier& build(const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights); // Private helper taking the trailing parameters of the weighted fit overload (presumably shared by all fit overloads - confirm)
      58              :     };
      59              : }
      60              : #endif
      61              : 
      62              : 
      63              : 
      64              : 
      65              : 
        

Generated by: LCOV version 2.0-1