2024-04-11 16:02:49 +00:00
|
|
|
// ***************************************************************
|
|
|
|
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
|
|
|
// SPDX-FileType: SOURCE
|
|
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
// ***************************************************************
|
|
|
|
|
|
|
|
#pragma once
|
2024-03-08 21:20:54 +00:00
|
|
|
#include <map>
#include <string>
#include <vector>
#include <torch/torch.h>
#include <nlohmann/json.hpp>
|
2023-07-13 01:15:42 +00:00
|
|
|
namespace bayesnet {
|
2023-09-05 11:39:43 +00:00
|
|
|
// Coarse model status, as reported by BaseClassifier::getStatus().
// Values are spelled out explicitly; they are identical to the implicit
// numbering (0, 1, 2).
// NOTE(review): the enumerator name ERROR may collide with the ERROR macro
// defined by <windows.h> (wingdi.h) on Windows builds — renaming would be a
// breaking change, so it is left as is.
enum status_t {
    NORMAL = 0,
    WARNING = 1,
    ERROR = 2
};
|
2023-07-13 01:15:42 +00:00
|
|
|
class BaseClassifier {
|
|
|
|
public:
|
2023-11-08 17:45:35 +00:00
|
|
|
// X is nxm std::vector, y is nx1 std::vector
|
|
|
|
virtual BaseClassifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0;
|
2023-08-03 18:22:33 +00:00
|
|
|
// X is nxm tensor, y is nx1 tensor
|
2023-11-08 17:45:35 +00:00
|
|
|
virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0;
|
|
|
|
virtual BaseClassifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0;
|
|
|
|
virtual BaseClassifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) = 0;
|
2023-08-04 17:42:18 +00:00
|
|
|
virtual ~BaseClassifier() = default;
|
2023-07-30 17:00:02 +00:00
|
|
|
torch::Tensor virtual predict(torch::Tensor& X) = 0;
|
2023-11-08 17:45:35 +00:00
|
|
|
std::vector<int> virtual predict(std::vector<std::vector<int >>& X) = 0;
|
2024-02-22 10:45:40 +00:00
|
|
|
torch::Tensor virtual predict_proba(torch::Tensor& X) = 0;
|
|
|
|
std::vector<std::vector<double>> virtual predict_proba(std::vector<std::vector<int >>& X) = 0;
|
2023-09-05 11:39:43 +00:00
|
|
|
status_t virtual getStatus() const = 0;
|
2023-11-08 17:45:35 +00:00
|
|
|
float virtual score(std::vector<std::vector<int>>& X, std::vector<int>& y) = 0;
|
2023-07-23 12:10:28 +00:00
|
|
|
float virtual score(torch::Tensor& X, torch::Tensor& y) = 0;
|
2023-08-07 23:53:41 +00:00
|
|
|
int virtual getNumberOfNodes()const = 0;
|
|
|
|
int virtual getNumberOfEdges()const = 0;
|
|
|
|
int virtual getNumberOfStates() const = 0;
|
2024-02-22 10:45:40 +00:00
|
|
|
int virtual getClassNumStates() const = 0;
|
2023-11-08 17:45:35 +00:00
|
|
|
std::vector<std::string> virtual show() const = 0;
|
|
|
|
std::vector<std::string> virtual graph(const std::string& title = "") const = 0;
|
2023-11-13 10:13:32 +00:00
|
|
|
virtual std::string getVersion() = 0;
|
2023-11-08 17:45:35 +00:00
|
|
|
std::vector<std::string> virtual topological_order() = 0;
|
2024-02-09 09:57:19 +00:00
|
|
|
std::vector<std::string> virtual getNotes() const = 0;
|
2024-04-07 23:25:14 +00:00
|
|
|
std::string virtual dump_cpt()const = 0;
|
2023-11-18 10:56:10 +00:00
|
|
|
virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0;
|
2023-11-19 21:36:27 +00:00
|
|
|
std::vector<std::string>& getValidHyperparameters() { return validHyperparameters; }
|
|
|
|
protected:
|
|
|
|
virtual void trainModel(const torch::Tensor& weights) = 0;
|
|
|
|
std::vector<std::string> validHyperparameters;
|
2023-07-13 01:15:42 +00:00
|
|
|
};
|
2024-04-11 16:02:49 +00:00
|
|
|
}
|