From 1aa3b609e51944b20af12f5edfd4d785db426772 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 21 Aug 2025 19:01:10 +0200 Subject: [PATCH] Fix adult numeric features mistake --- tests/TestBayesModels.cc | 11 +++++++++++ tests/TestUtils.cc | 2 ++ tests/data/all.txt | 2 +- 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/TestBayesModels.cc b/tests/TestBayesModels.cc index 3a572fe..28c3df7 100644 --- a/tests/TestBayesModels.cc +++ b/tests/TestBayesModels.cc @@ -531,6 +531,11 @@ TEST_CASE("Test Dataset Loading", "[Datasets]") } std::cout << "| " << dataset.yt[sample].item() << std::endl; } + auto features = dataset.features; + std::cout << "States:" << std::endl; + for (int i = 0; i < 14; i++) { + std::cout << i << " has " << dataset.states.at(features[i]).size() << " states." << std::endl; + } dataset = RawDatasets("adult", false); std::cout << "Dataset adult raw " << std::endl; for (int sample = 0; sample < max_sample; sample++) { @@ -539,4 +544,10 @@ TEST_CASE("Test Dataset Loading", "[Datasets]") } std::cout << "| " << dataset.yt[sample].item() << std::endl; } + std::cout << "States:" << std::endl; + for (int i = 0; i < 14; i++) { + std::cout << i << " has " << dataset.states.at(features[i]).size() << " states." << std::endl; + } + auto clf = bayesnet::TANLd(); + clf.fit(dataset.Xt, dataset.yt, dataset.features, dataset.className, dataset.states, dataset.smoothing); } diff --git a/tests/TestUtils.cc b/tests/TestUtils.cc index 79ebad0..fd021c8 100644 --- a/tests/TestUtils.cc +++ b/tests/TestUtils.cc @@ -213,6 +213,8 @@ void RawDatasets::loadDataset(const std::string& name, bool class_last) if (!is_numeric.at(i)) { states[features[i]] = std::vector(maxValues[features[i]]); iota(begin(states.at(features[i])), end(states.at(features[i])), 0); + } else { + states[features[i]] = std::vector(); } } yt = torch::tensor(yv, torch::kInt32); diff --git a/tests/data/all.txt b/tests/data/all.txt index 8fc547e..74909fc 100644 --- a/tests/data/all.txt +++ b/tests/data/all.txt @@ -1,4 +1,4 @@ -adult;class;[0,2,4,11,12,13] +adult;class;[0,2,4,10,11,12] balance-scale;class; all breast-w;Class; all diabetes;class; all