Fix Xspode

This commit is contained in:
2025-03-10 21:29:47 +01:00
parent a26522e62f
commit 5919fbfd34
5 changed files with 5 additions and 12 deletions

View File

@@ -10,7 +10,6 @@
namespace bayesnet {
Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
const std::string CLASSIFIER_NOT_FITTED = "Classifier has not been fitted";
Classifier& Classifier::build(const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const Smoothing_t smoothing)
{
this->features = features;
@@ -22,11 +21,8 @@ namespace bayesnet {
auto n_classes = states.at(className).size();
metrics = Metrics(dataset, features, className, n_classes);
model.initialize();
std::cout << "Ahora buildmodel"<< std::endl;
buildModel(weights);
std::cout << "Ahora trainmodel"<< std::endl;
trainModel(weights, smoothing);
std::cout << "Después de trainmodel"<< std::endl;
fitted = true;
return *this;
}

View File

@@ -50,6 +50,7 @@ namespace bayesnet {
virtual void buildModel(const torch::Tensor& weights) = 0;
void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override;
void buildDataset(torch::Tensor& y);
const std::string CLASSIFIER_NOT_FITTED = "Classifier has not been fitted";
private:
Classifier& build(const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const Smoothing_t smoothing);
};

View File

@@ -110,7 +110,6 @@ namespace bayesnet {
instance[nFeatures_] = dataset[-1][i].item<int>();
addSample(instance, weights[i].item<double>());
}
switch (smoothing) {
case bayesnet::Smoothing_t::ORIGINAL:
alpha_ = 1.0 / m;
@@ -414,9 +413,6 @@ namespace bayesnet {
}
float XSpode::score(std::vector<std::vector<int>>& X, std::vector<int>& y)
{
if (!fitted) {
throw std::logic_error(CLASSIFIER_NOT_FITTED);
}
auto y_pred = this->predict(X);
int correct = 0;
for (int i = 0; i < y_pred.size(); ++i) {

View File

@@ -51,8 +51,6 @@ namespace bayesnet {
int statesClass_;
std::vector<int> states_; // [states_feat0, ..., states_feat(N-1)] (class not included in this array)
const std::string CLASSIFIER_NOT_FITTED = "Classifier has not been fitted";
// Class counts
std::vector<double> classCounts_; // [c], accumulative
std::vector<double> classPriors_; // [c], after normalization

View File

@@ -58,10 +58,12 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
auto raw = RawDatasets(file_name, discretize);
if (name == "XSPODE") {
std::cout << "Fitting XSPODE" << std::endl;
} else {
std::cout << "Fitting something else [" << name << "]" << std::endl;
}
clf->fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
clf->fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
auto score = clf->score(raw.Xt, raw.yt);
std::cout << "Classifier: " << name << " File: " << file_name << " Score: " << score << " expected = " << scores[{file_name, name}] << std::endl;
std::cout << "Classifier: " << name << " File: " << file_name << " Score: " << score << " expected = " << scores[{file_name, name}] << std::endl;
INFO("Classifier: " << name << " File: " << file_name);
REQUIRE(score == Catch::Approx(scores[{file_name, name}]).epsilon(raw.epsilon));
REQUIRE(clf->getStatus() == bayesnet::NORMAL);