Add parent hyperparameter to TAN & SPODE

This commit is contained in:
2024-12-17 10:14:14 +01:00
parent 56a2d3ead0
commit e2781ee525
7 changed files with 78 additions and 11 deletions

View File

@@ -10,14 +10,15 @@
namespace bayesnet {
class SPODE : public Classifier {
private:
int root;
protected:
void buildModel(const torch::Tensor& weights) override;
public:
explicit SPODE(int root);
virtual ~SPODE() = default;
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
std::vector<std::string> graph(const std::string& name = "SPODE") const override;
protected:
void buildModel(const torch::Tensor& weights) override;
private:
int root;
};
}
#endif