Fix stratified folding mistake in remainders

This commit is contained in:
Ricardo Montañana Gómez 2023-07-27 16:51:27 +02:00
parent 3d8fea7a37
commit 7f7ddad36a
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
8 changed files with 18 additions and 8 deletions

9
.vscode/launch.json vendored
View File

@ -26,12 +26,13 @@
"-m", "-m",
"TAN", "TAN",
"-p", "-p",
"datasets", "/Users/rmontanana/Code/discretizbench/datasets",
"--discretize", "--discretize",
"-f", "--stratified",
"5",
"--title", "--title",
"Debug test" "Debug test",
"-d",
"ionosphere"
], ],
"cwd": "${workspaceFolder}/build/src/Platform", "cwd": "${workspaceFolder}/build/src/Platform",
}, },

View File

@ -94,7 +94,7 @@ namespace bayesnet {
totalWeight += 1; totalWeight += 1;
} }
if (totalWeight == 0) if (totalWeight == 0)
throw invalid_argument("Total weight should not be zero"); return 0;
double entropyValue = 0; double entropyValue = 0;
for (int value = 0; value < featureCounts.sizes()[0]; ++value) { for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
double p_f = featureCounts[value].item<double>() / totalWeight; double p_f = featureCounts[value].item<double>() / totalWeight;

View File

@ -103,7 +103,7 @@ namespace bayesnet {
// Make a complete graph // Make a complete graph
for (int i = 0; i < num_features - 1; ++i) { for (int i = 0; i < num_features - 1; ++i) {
for (int j = i; j < num_features; ++j) { for (int j = i + 1; j < num_features; ++j) {
g.addEdge(i, j, weights[i][j].item<float>()); g.addEdge(i, j, weights[i][j].item<float>());
} }
} }

View File

@ -5,4 +5,4 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc) add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc)
target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES} ") target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")

View File

@ -89,6 +89,10 @@ namespace platform {
} }
return datasets[name]->getTensors(); return datasets[name]->getTensors();
} }
bool Datasets::isDataset(string name)
{
return datasets.find(name) != datasets.end();
}
Dataset::Dataset(Dataset& dataset) Dataset::Dataset(Dataset& dataset)
{ {
path = dataset.path; path = dataset.path;

View File

@ -58,6 +58,7 @@ namespace platform {
pair<vector<vector<float>>&, vector<int>&> getVectors(string name); pair<vector<vector<float>>&, vector<int>&> getVectors(string name);
pair<vector<vector<int>>&, vector<int>&> getVectorsDiscretized(string name); pair<vector<vector<int>>&, vector<int>&> getVectorsDiscretized(string name);
pair<torch::Tensor&, torch::Tensor&> getTensors(string name); pair<torch::Tensor&, torch::Tensor&> getTensors(string name);
bool isDataset(string name);
}; };
}; };

View File

@ -72,7 +72,7 @@ void StratifiedKFold::build()
} }
while (remainder_samples_to_take > 0) { while (remainder_samples_to_take > 0) {
int fold = (rand() % static_cast<int>(k)); int fold = (rand() % static_cast<int>(k));
if (stratified_indices[fold].size() == fold_size) { if (stratified_indices[fold].size() == fold_size + 1) {
continue; continue;
} }
auto it = next(class_indices[label].begin(), 1); auto it = next(class_indices[label].begin(), 1);

View File

@ -81,6 +81,10 @@ int main(int argc, char** argv)
vector<string> filesToProcess; vector<string> filesToProcess;
auto datasets = platform::Datasets(path, true, platform::ARFF); auto datasets = platform::Datasets(path, true, platform::ARFF);
if (file_name != "") { if (file_name != "") {
if (!datasets.isDataset(file_name)) {
cerr << "Dataset " << file_name << " not found" << endl;
exit(1);
}
filesToProcess.push_back(file_name); filesToProcess.push_back(file_name);
} else { } else {
filesToProcess = platform::Datasets(path, true, platform::ARFF).getNames(); filesToProcess = platform::Datasets(path, true, platform::ARFF).getNames();