From 7f7ddad36a22fb6e5c59c4edcd04980167b44c18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 27 Jul 2023 16:51:27 +0200 Subject: [PATCH] Fix stratified folding mistake in remainders --- .vscode/launch.json | 9 +++++---- src/BayesNet/BayesMetrics.cc | 2 +- src/BayesNet/Mst.cc | 2 +- src/Platform/CMakeLists.txt | 2 +- src/Platform/Datasets.cc | 4 ++++ src/Platform/Datasets.h | 1 + src/Platform/Folding.cc | 2 +- src/Platform/main.cc | 4 ++++ 8 files changed, 18 insertions(+), 8 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 1b983e6..14f9cc8 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -26,12 +26,13 @@ "-m", "TAN", "-p", - "datasets", + "/Users/rmontanana/Code/discretizbench/datasets", "--discretize", - "-f", - "5", + "--stratified", "--title", - "Debug test" + "Debug test", + "-d", + "ionosphere" ], "cwd": "${workspaceFolder}/build/src/Platform", }, diff --git a/src/BayesNet/BayesMetrics.cc b/src/BayesNet/BayesMetrics.cc index 8c46582..6671995 100644 --- a/src/BayesNet/BayesMetrics.cc +++ b/src/BayesNet/BayesMetrics.cc @@ -94,7 +94,7 @@ namespace bayesnet { totalWeight += 1; } if (totalWeight == 0) - throw invalid_argument("Total weight should not be zero"); + return 0; double entropyValue = 0; for (int value = 0; value < featureCounts.sizes()[0]; ++value) { double p_f = featureCounts[value].item() / totalWeight; diff --git a/src/BayesNet/Mst.cc b/src/BayesNet/Mst.cc index 56a0558..b86812b 100644 --- a/src/BayesNet/Mst.cc +++ b/src/BayesNet/Mst.cc @@ -103,7 +103,7 @@ namespace bayesnet { // Make a complete graph for (int i = 0; i < num_features - 1; ++i) { - for (int j = i; j < num_features; ++j) { + for (int j = i + 1; j < num_features; ++j) { g.addEdge(i, j, weights[i][j].item()); } } diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index b0e135a..f1fea17 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -5,4 +5,4 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc) -target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES} ") \ No newline at end of file +target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/Platform/Datasets.cc b/src/Platform/Datasets.cc index 7943bad..0c09c59 100644 --- a/src/Platform/Datasets.cc +++ b/src/Platform/Datasets.cc @@ -89,6 +89,10 @@ namespace platform { } return datasets[name]->getTensors(); } + bool Datasets::isDataset(string name) + { + return datasets.find(name) != datasets.end(); + } Dataset::Dataset(Dataset& dataset) { path = dataset.path; diff --git a/src/Platform/Datasets.h b/src/Platform/Datasets.h index b593e24..f6a4c5b 100644 --- a/src/Platform/Datasets.h +++ b/src/Platform/Datasets.h @@ -58,6 +58,7 @@ namespace platform { pair>&, vector&> getVectors(string name); pair>&, vector&> getVectorsDiscretized(string name); pair getTensors(string name); + bool isDataset(string name); }; }; diff --git a/src/Platform/Folding.cc b/src/Platform/Folding.cc index 86aaaec..a6687f6 100644 --- a/src/Platform/Folding.cc +++ b/src/Platform/Folding.cc @@ -72,7 +72,7 @@ void StratifiedKFold::build() } while (remainder_samples_to_take > 0) { int fold = (rand() % static_cast(k)); - if (stratified_indices[fold].size() == fold_size) { + if (stratified_indices[fold].size() == fold_size + 1) { continue; } auto it = next(class_indices[label].begin(), 1); diff --git a/src/Platform/main.cc b/src/Platform/main.cc index f1f6ed8..327d02e 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -81,6 +81,10 @@ int main(int argc, char** argv) vector filesToProcess; auto datasets = platform::Datasets(path, true, platform::ARFF); if (file_name != "") { + if (!datasets.isDataset(file_name)) { + cerr << "Dataset " << file_name << " not found" << endl; + exit(1); + } filesToProcess.push_back(file_name); } else { filesToProcess = platform::Datasets(path, true, platform::ARFF).getNames();