Optimize BoostAODE -> XBAODE #33

Merged
rmontanana merged 27 commits from WA2DE into main 2025-03-16 17:58:10 +00:00
6 changed files with 362 additions and 5 deletions
Showing only changes of commit f658149977 - Show all commits

View File

@@ -28,8 +28,8 @@ namespace bayesnet {
status_t virtual getStatus() const = 0;
float virtual score(std::vector<std::vector<int>>& X, std::vector<int>& y) = 0;
float virtual score(torch::Tensor& X, torch::Tensor& y) = 0;
int virtual getNumberOfNodes()const = 0;
int virtual getNumberOfEdges()const = 0;
int virtual getNumberOfNodes() const = 0;
int virtual getNumberOfEdges() const = 0;
int virtual getNumberOfStates() const = 0;
int virtual getClassNumStates() const = 0;
std::vector<std::string> virtual show() const = 0;
@@ -37,7 +37,7 @@ namespace bayesnet {
virtual std::string getVersion() = 0;
std::vector<std::string> virtual topological_order() = 0;
std::vector<std::string> virtual getNotes() const = 0;
std::string virtual dump_cpt()const = 0;
std::string virtual dump_cpt() const = 0;
virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0;
std::vector<std::string>& getValidHyperparameters() { return validHyperparameters; }
protected:

View File

@@ -33,7 +33,12 @@ namespace bayesnet {
}
std::string dump_cpt() const override
{
return "";
std::string output;
for (auto& model : models) {
output += model->dump_cpt();
output += std::string(80, '-') + "\n";
}
return output;
}
protected:
torch::Tensor predict_average_voting(torch::Tensor& X);