9.7 KiB
9.7 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | |||||||||||||||||||||||||
![]() | |||||||||||||||||||||||||
|
|||||||||||||||||||||||||
![]() |
Line data Source code 1 : // *************************************************************** 2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez 3 : // SPDX-FileType: SOURCE 4 : // SPDX-License-Identifier: MIT 5 : // *************************************************************** 6 : 7 : #pragma once 8 : #include <vector> 9 : #include <torch/torch.h> 10 : #include <nlohmann/json.hpp> 11 : namespace bayesnet { 12 : enum status_t { NORMAL, WARNING, ERROR }; 13 : class BaseClassifier { 14 : public: 15 : // X is nxm std::vector, y is nx1 std::vector 16 : virtual BaseClassifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0; 17 : // X is nxm tensor, y is nx1 tensor 18 : virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0; 19 : virtual BaseClassifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0; 20 : virtual BaseClassifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) = 0; 21 1680 : virtual ~BaseClassifier() = default; 22 : torch::Tensor virtual predict(torch::Tensor& X) = 0; 23 : std::vector<int> virtual predict(std::vector<std::vector<int >>& X) = 0; 24 : torch::Tensor virtual predict_proba(torch::Tensor& X) = 0; 25 : std::vector<std::vector<double>> virtual predict_proba(std::vector<std::vector<int >>& X) = 0; 26 : status_t virtual getStatus() const = 0; 27 : float virtual score(std::vector<std::vector<int>>& X, std::vector<int>& y) = 0; 28 : float virtual score(torch::Tensor& X, torch::Tensor& y) = 0; 29 : int virtual getNumberOfNodes()const = 0; 30 : int virtual getNumberOfEdges()const = 0; 31 : int virtual getNumberOfStates() const = 0; 32 : int virtual getClassNumStates() const = 0; 33 : std::vector<std::string> virtual show() const = 0; 34 : std::vector<std::string> virtual graph(const std::string& title = "") const = 0; 35 : virtual std::string getVersion() = 0; 36 : std::vector<std::string> virtual topological_order() = 0; 37 : std::vector<std::string> virtual getNotes() const = 0; 38 : std::string virtual dump_cpt()const = 0; 39 : virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0; 40 : std::vector<std::string>& getValidHyperparameters() { return validHyperparameters; } 41 : protected: 42 : virtual void trainModel(const torch::Tensor& weights) = 0; 43 : std::vector<std::string> validHyperparameters; 44 : }; 45 : } |
![]() |
Generated by: LCOV version 2.0-1 |
</html>