Fix XSPode

This commit is contained in:
2025-03-10 15:55:48 +01:00
parent 86cccb6c7b
commit a26522e62f
4 changed files with 23 additions and 3 deletions

View File

@@ -407,5 +407,24 @@ namespace bayesnet {
} }
return result; return result;
} }
float XSpode::score(torch::Tensor& X, torch::Tensor& y)
{
torch::Tensor y_pred = predict(X);
return (y_pred == y).sum().item<float>() / y.size(0);
}
float XSpode::score(std::vector<std::vector<int>>& X, std::vector<int>& y)
{
if (!fitted) {
throw std::logic_error(CLASSIFIER_NOT_FITTED);
}
auto y_pred = this->predict(X);
int correct = 0;
for (int i = 0; i < y_pred.size(); ++i) {
if (y_pred[i] == y[i]) {
correct++;
}
}
return (double)correct / y_pred.size();
}
} }

View File

@@ -38,6 +38,8 @@ namespace bayesnet {
torch::Tensor predict(torch::Tensor& X) override; torch::Tensor predict(torch::Tensor& X) override;
std::vector<int> predict(std::vector<std::vector<int>>& X) override; std::vector<int> predict(std::vector<std::vector<int>>& X) override;
torch::Tensor predict_proba(torch::Tensor& X) override; torch::Tensor predict_proba(torch::Tensor& X) override;
float score(torch::Tensor& X, torch::Tensor& y) override;
float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;
protected: protected:
void buildModel(const torch::Tensor& weights) override; void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override; void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;

View File

@@ -30,7 +30,7 @@ namespace bayesnet {
protected: protected:
// Model-building function // Model-building function
void buildModel(const torch::Tensor& weights) override; void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& data, const Smoothing_t smoothing) override;
private: private:
int num_classes_; // Number of classes int num_classes_; // Number of classes
int num_attributes_; // Number of attributes int num_attributes_; // Number of attributes
@@ -46,7 +46,6 @@ namespace bayesnet {
bool weighted_a2de_; // Whether to use weighted A2DE bool weighted_a2de_; // Whether to use weighted A2DE
double smoothing_factor_; // Smoothing parameter (default: Laplace) double smoothing_factor_; // Smoothing parameter (default: Laplace)
torch::Tensor AODEConditionalProb(const torch::Tensor& data); torch::Tensor AODEConditionalProb(const torch::Tensor& data);
void trainModel(const torch::Tensor& data, const Smoothing_t smoothing);
int toIntValue(int attributeIndex, float value) const; int toIntValue(int attributeIndex, float value) const;
}; };
} }

View File

@@ -22,4 +22,4 @@ include_directories(
) )
add_executable(bayesnet_sample sample.cc) add_executable(bayesnet_sample sample.cc)
target_link_libraries(bayesnet_sample fimdlp "${TORCH_LIBRARIES}" "${BayesNet}") target_link_libraries(bayesnet_sample ${FImdlp} "${TORCH_LIBRARIES}" "${BayesNet}")