From f65814997727fcb85b474ffb478a19edda409239 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 12 Feb 2025 20:55:35 +0100 Subject: [PATCH] Add dump_cpt to Ensemble --- bayesnet/BaseClassifier.h | 6 +++--- bayesnet/ensembles/Ensemble.h | 7 ++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/bayesnet/BaseClassifier.h b/bayesnet/BaseClassifier.h index 81fbe26..60a0c8e 100644 --- a/bayesnet/BaseClassifier.h +++ b/bayesnet/BaseClassifier.h @@ -28,8 +28,8 @@ namespace bayesnet { status_t virtual getStatus() const = 0; float virtual score(std::vector>& X, std::vector& y) = 0; float virtual score(torch::Tensor& X, torch::Tensor& y) = 0; - int virtual getNumberOfNodes()const = 0; - int virtual getNumberOfEdges()const = 0; + int virtual getNumberOfNodes() const = 0; + int virtual getNumberOfEdges() const = 0; int virtual getNumberOfStates() const = 0; int virtual getClassNumStates() const = 0; std::vector virtual show() const = 0; @@ -37,7 +37,7 @@ namespace bayesnet { virtual std::string getVersion() = 0; std::vector virtual topological_order() = 0; std::vector virtual getNotes() const = 0; - std::string virtual dump_cpt()const = 0; + std::string virtual dump_cpt() const = 0; virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0; std::vector& getValidHyperparameters() { return validHyperparameters; } protected: diff --git a/bayesnet/ensembles/Ensemble.h b/bayesnet/ensembles/Ensemble.h index 5172a40..c046f54 100644 --- a/bayesnet/ensembles/Ensemble.h +++ b/bayesnet/ensembles/Ensemble.h @@ -33,7 +33,12 @@ namespace bayesnet { } std::string dump_cpt() const override { - return ""; + std::string output; + for (auto& model : models) { + output += model->dump_cpt(); + output += std::string(80, '-') + "\n"; + } + return output; } protected: torch::Tensor predict_average_voting(torch::Tensor& X);