Fix XSPode

This commit is contained in:
2025-03-10 15:55:48 +01:00
parent 86cccb6c7b
commit a26522e62f
4 changed files with 23 additions and 3 deletions

View File

@@ -30,7 +30,7 @@ namespace bayesnet {
protected:
// Model-building function
void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& data, const Smoothing_t smoothing) override;
private:
int num_classes_; // Number of classes
int num_attributes_; // Number of attributes
@@ -46,7 +46,6 @@ namespace bayesnet {
bool weighted_a2de_; // Whether to use weighted A2DE
double smoothing_factor_; // Smoothing parameter (default: Laplace)
torch::Tensor AODEConditionalProb(const torch::Tensor& data);
void trainModel(const torch::Tensor& data, const Smoothing_t smoothing);
int toIntValue(int attributeIndex, float value) const;
};
}