Fix XSPode
This commit is contained in:
@@ -407,5 +407,24 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
float XSpode::score(torch::Tensor& X, torch::Tensor& y)
|
||||||
|
{
|
||||||
|
torch::Tensor y_pred = predict(X);
|
||||||
|
return (y_pred == y).sum().item<float>() / y.size(0);
|
||||||
|
}
|
||||||
|
float XSpode::score(std::vector<std::vector<int>>& X, std::vector<int>& y)
|
||||||
|
{
|
||||||
|
if (!fitted) {
|
||||||
|
throw std::logic_error(CLASSIFIER_NOT_FITTED);
|
||||||
|
}
|
||||||
|
auto y_pred = this->predict(X);
|
||||||
|
int correct = 0;
|
||||||
|
for (int i = 0; i < y_pred.size(); ++i) {
|
||||||
|
if (y_pred[i] == y[i]) {
|
||||||
|
correct++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return (double)correct / y_pred.size();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -38,6 +38,8 @@ namespace bayesnet {
|
|||||||
torch::Tensor predict(torch::Tensor& X) override;
|
torch::Tensor predict(torch::Tensor& X) override;
|
||||||
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
|
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
|
||||||
torch::Tensor predict_proba(torch::Tensor& X) override;
|
torch::Tensor predict_proba(torch::Tensor& X) override;
|
||||||
|
float score(torch::Tensor& X, torch::Tensor& y) override;
|
||||||
|
float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;
|
||||||
protected:
|
protected:
|
||||||
void buildModel(const torch::Tensor& weights) override;
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;
|
void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;
|
||||||
|
@@ -30,7 +30,7 @@ namespace bayesnet {
|
|||||||
protected:
|
protected:
|
||||||
// Model-building function
|
// Model-building function
|
||||||
void buildModel(const torch::Tensor& weights) override;
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
|
void trainModel(const torch::Tensor& data, const Smoothing_t smoothing) override;
|
||||||
private:
|
private:
|
||||||
int num_classes_; // Number of classes
|
int num_classes_; // Number of classes
|
||||||
int num_attributes_; // Number of attributes
|
int num_attributes_; // Number of attributes
|
||||||
@@ -46,7 +46,6 @@ namespace bayesnet {
|
|||||||
bool weighted_a2de_; // Whether to use weighted A2DE
|
bool weighted_a2de_; // Whether to use weighted A2DE
|
||||||
double smoothing_factor_; // Smoothing parameter (default: Laplace)
|
double smoothing_factor_; // Smoothing parameter (default: Laplace)
|
||||||
torch::Tensor AODEConditionalProb(const torch::Tensor& data);
|
torch::Tensor AODEConditionalProb(const torch::Tensor& data);
|
||||||
void trainModel(const torch::Tensor& data, const Smoothing_t smoothing);
|
|
||||||
int toIntValue(int attributeIndex, float value) const;
|
int toIntValue(int attributeIndex, float value) const;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@@ -22,4 +22,4 @@ include_directories(
|
|||||||
)
|
)
|
||||||
|
|
||||||
add_executable(bayesnet_sample sample.cc)
|
add_executable(bayesnet_sample sample.cc)
|
||||||
target_link_libraries(bayesnet_sample fimdlp "${TORCH_LIBRARIES}" "${BayesNet}")
|
target_link_libraries(bayesnet_sample ${FImdlp} "${TORCH_LIBRARIES}" "${BayesNet}")
|
Reference in New Issue
Block a user