Fix predict_proba declaration

This commit is contained in:
2025-02-26 21:08:33 +01:00
parent 0d1e4b3c6f
commit b055065e59

View File

@@ -32,7 +32,7 @@ namespace platform {
torch::Tensor predict(torch::Tensor& X) override; torch::Tensor predict(torch::Tensor& X) override;
torch::Tensor predict_proba(torch::Tensor& X) override; torch::Tensor predict_proba(torch::Tensor& X) override;
std::vector<int> predict_spode(std::vector<std::vector<int>>& test_data, int parent); std::vector<int> predict_spode(std::vector<std::vector<int>>& test_data, int parent);
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override; std::vector<std::vector<double>> predict_proba(const std::vector<std::vector<int>>& X);
float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override; float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;
float score(torch::Tensor& X, torch::Tensor& y) override; float score(torch::Tensor& X, torch::Tensor& y) override;
int getNumberOfNodes() const override; int getNumberOfNodes() const override;
@@ -47,7 +47,7 @@ namespace platform {
std::vector<std::string> getNotes() const override { return notes; } std::vector<std::string> getNotes() const override { return notes; }
std::vector<std::string> graph(const std::string& title = "") const override { return {}; } std::vector<std::string> graph(const std::string& title = "") const override { return {}; }
void setHyperparameters(const nlohmann::json& hyperparameters) override; void setHyperparameters(const nlohmann::json& hyperparameters) override;
void set_active_parents(std::vector<int> active_parents) { for (const auto& parent : active_parents) aode_.add_active_parent(parent); } void set_active_parents(const std::vector<int> active_parents) { for (const auto& parent : active_parents) aode_.add_active_parent(parent); }
void add_active_parent(int parent) { aode_.add_active_parent(parent); } void add_active_parent(int parent) { aode_.add_active_parent(parent); }
void remove_last_parent() { aode_.remove_last_parent(); } void remove_last_parent() { aode_.remove_last_parent(); }
protected: protected: