LCOV - code coverage report
Current view: top level - bayesnet - BaseClassifier.h (source / functions) Coverage Total Hit
Test: BayesNet Coverage Report Lines: 100.0 % 1 1
Test Date: 2024-05-06 17:54:04 Functions: 100.0 % 1 1
Legend: Lines: hit not hit

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #pragma once
       8              : #include <vector>
       9              : #include <torch/torch.h>
      10              : #include <nlohmann/json.hpp>
      11              : namespace bayesnet {
                            // NOTE(review): the declarations below use std::string and std::map, but
                            // the only headers included directly by this file are <vector>,
                            // <torch/torch.h> and <nlohmann/json.hpp>; presumably <string> and <map>
                            // arrive transitively -- confirm, or include them explicitly.

                            // Status codes returned by BaseClassifier::getStatus().
      12              :     enum status_t { NORMAL, WARNING, ERROR };
                            // Abstract classifier interface. Apart from the defaulted virtual
                            // destructor and the inline getValidHyperparameters() accessor, every
                            // member function is pure virtual and must be supplied by a subclass.
      13              :     class BaseClassifier {
      14              :     public:
                                // fit overloads: all return a BaseClassifier reference and share the
                                // features / className / states parameters; they differ only in how
                                // the training data is passed.
                                // NOTE(review): presumably features holds the feature names and
                                // states maps a variable name to its discrete values -- confirm
                                // against an implementation.
      15              :         // X is nxm std::vector, y is nx1 std::vector
      16              :         virtual BaseClassifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0;
      17              :         // X is nxm tensor, y is nx1 tensor
      18              :         virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0;
                                // Takes a single dataset tensor instead of separate X and y
                                // (presumably the labels are part of dataset -- TODO confirm layout).
      19              :         virtual BaseClassifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0;
                                // Same as the dataset overload, with an additional weights tensor
                                // (presumably per-sample weights -- TODO confirm).
      20              :         virtual BaseClassifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) = 0;
                                // Virtual so that a derived classifier can be destroyed through a
                                // BaseClassifier pointer or reference.
      21         1680 :         virtual ~BaseClassifier() = default;
                                // Prediction entry points, each overloaded for tensor input and for
                                // nested std::vector input. predict_proba returns floating-point
                                // values (presumably per-class probabilities -- confirm).
      22              :         torch::Tensor virtual predict(torch::Tensor& X) = 0;
      23              :         std::vector<int> virtual predict(std::vector<std::vector<int >>& X) = 0;
      24              :         torch::Tensor virtual predict_proba(torch::Tensor& X) = 0;
      25              :         std::vector<std::vector<double>> virtual predict_proba(std::vector<std::vector<int >>& X) = 0;
                                // Current status of the classifier as a status_t value.
      26              :         status_t virtual getStatus() const = 0;
                                // Evaluates the model on X against y and returns a float
                                // (presumably accuracy -- TODO confirm the metric).
      27              :         float virtual score(std::vector<std::vector<int>>& X, std::vector<int>& y) = 0;
      28              :         float virtual score(torch::Tensor& X, torch::Tensor& y) = 0;
                                // Const integer accessors describing the fitted model; the precise
                                // meaning of each count is defined by the implementing class.
      29              :         int virtual getNumberOfNodes()const = 0;
      30              :         int virtual getNumberOfEdges()const = 0;
      31              :         int virtual getNumberOfStates() const = 0;
      32              :         int virtual getClassNumStates() const = 0;
                                // Textual descriptions of the model, returned as a list of strings.
      33              :         std::vector<std::string> virtual show() const = 0;
                                // NOTE(review): graph has a default argument on a virtual function.
                                // Default arguments are chosen from the static type at the call
                                // site, so overriders should repeat the same default (empty title).
      34              :         std::vector<std::string> virtual graph(const std::string& title = "") const = 0;
                                // NOTE(review): getVersion and topological_order are not declared
                                // const, unlike the neighbouring accessors, so they cannot be called
                                // on a const BaseClassifier.
      35              :         virtual std::string getVersion() = 0;
      36              :         std::vector<std::string> virtual topological_order() = 0;
      37              :         std::vector<std::string> virtual getNotes() const = 0;
                                // Returns a string dump (presumably of the conditional probability
                                // tables, going by the name -- confirm).
      38              :         std::string virtual dump_cpt()const = 0;
                                // Configures the classifier from a JSON object (presumably the
                                // accepted keys are those in validHyperparameters -- confirm).
      39              :         virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0;
                                // Returns a non-const reference to the protected validHyperparameters
                                // member, so callers are able to modify the list in place.
      40              :         std::vector<std::string>& getValidHyperparameters() { return validHyperparameters; }
      41              :     protected:
                                // Training hook implemented by subclasses; receives a weights tensor
                                // (presumably invoked from the fit overloads -- confirm).
      42              :         virtual void trainModel(const torch::Tensor& weights) = 0;
                                // Storage exposed through getValidHyperparameters().
      43              :         std::vector<std::string> validHyperparameters;
      44              :     };
      45              : }
        

Generated by: LCOV version 2.0-1