// *************************************************************** // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez // SPDX-FileType: SOURCE // SPDX-License-Identifier: MIT // *************************************************************** #pragma once #include #include #include namespace bayesnet { enum status_t { NORMAL, WARNING, ERROR }; class BaseClassifier { public: // X is nxm std::vector, y is nx1 std::vector virtual BaseClassifier& fit(std::vector>& X, std::vector& y, const std::vector& features, const std::string& className, std::map>& states) = 0; // X is nxm tensor, y is nx1 tensor virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states) = 0; virtual BaseClassifier& fit(torch::Tensor& dataset, const std::vector& features, const std::string& className, std::map>& states) = 0; virtual BaseClassifier& fit(torch::Tensor& dataset, const std::vector& features, const std::string& className, std::map>& states, const torch::Tensor& weights) = 0; virtual ~BaseClassifier() = default; torch::Tensor virtual predict(torch::Tensor& X) = 0; std::vector virtual predict(std::vector>& X) = 0; torch::Tensor virtual predict_proba(torch::Tensor& X) = 0; std::vector> virtual predict_proba(std::vector>& X) = 0; status_t virtual getStatus() const = 0; float virtual score(std::vector>& X, std::vector& y) = 0; float virtual score(torch::Tensor& X, torch::Tensor& y) = 0; int virtual getNumberOfNodes()const = 0; int virtual getNumberOfEdges()const = 0; int virtual getNumberOfStates() const = 0; int virtual getClassNumStates() const = 0; std::vector virtual show() const = 0; std::vector virtual graph(const std::string& title = "") const = 0; virtual std::string getVersion() = 0; std::vector virtual topological_order() = 0; std::vector virtual getNotes() const = 0; std::string virtual dump_cpt()const = 0; virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0; std::vector& getValidHyperparameters() { return validHyperparameters; } protected: virtual void trainModel(const torch::Tensor& weights) = 0; std::vector validHyperparameters; }; }