35 lines
2.1 KiB
C++
35 lines
2.1 KiB
C++
#ifndef BASE_H
|
|
#define BASE_H
|
|
#include <map>
#include <string>
#include <vector>

#include <torch/torch.h>
#include <nlohmann/json.hpp>
|
|
namespace bayesnet {
|
|
/// Status value returned by BaseClassifier::getStatus().
/// Enumerator values are written out explicitly; they match the implicit
/// 0/1/2 numbering of an unannotated enum.
enum status_t {
    NORMAL = 0,
    WARNING = 1,
    ERROR = 2
};
|
|
class BaseClassifier {
|
|
protected:
|
|
virtual void trainModel(const torch::Tensor& weights) = 0;
|
|
public:
|
|
// X is nxm std::vector, y is nx1 std::vector
|
|
virtual BaseClassifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0;
|
|
// X is nxm tensor, y is nx1 tensor
|
|
virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0;
|
|
virtual BaseClassifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0;
|
|
virtual BaseClassifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) = 0;
|
|
virtual ~BaseClassifier() = default;
|
|
torch::Tensor virtual predict(torch::Tensor& X) = 0;
|
|
std::vector<int> virtual predict(std::vector<std::vector<int >>& X) = 0;
|
|
status_t virtual getStatus() const = 0;
|
|
float virtual score(std::vector<std::vector<int>>& X, std::vector<int>& y) = 0;
|
|
float virtual score(torch::Tensor& X, torch::Tensor& y) = 0;
|
|
int virtual getNumberOfNodes()const = 0;
|
|
int virtual getNumberOfEdges()const = 0;
|
|
int virtual getNumberOfStates() const = 0;
|
|
std::vector<std::string> virtual show() const = 0;
|
|
std::vector<std::string> virtual graph(const std::string& title = "") const = 0;
|
|
const std::string inline getVersion() const { return "0.2.0"; };
|
|
std::vector<std::string> virtual topological_order() = 0;
|
|
void virtual dump_cpt()const = 0;
|
|
virtual void setHyperparameters(nlohmann::json& hyperparameters) = 0;
|
|
};
|
|
}
|
|
#endif |