Fix stratified folding mistake in remainders

This commit is contained in:
Ricardo Montañana Gómez 2023-07-27 16:51:27 +02:00
parent 3d8fea7a37
commit 7f7ddad36a
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
8 changed files with 18 additions and 8 deletions

9
.vscode/launch.json vendored
View File

@ -26,12 +26,13 @@
"-m",
"TAN",
"-p",
"datasets",
"/Users/rmontanana/Code/discretizbench/datasets",
"--discretize",
"-f",
"5",
"--stratified",
"--title",
"Debug test"
"Debug test",
"-d",
"ionosphere"
],
"cwd": "${workspaceFolder}/build/src/Platform",
},

View File

@ -94,7 +94,7 @@ namespace bayesnet {
totalWeight += 1;
}
if (totalWeight == 0)
throw invalid_argument("Total weight should not be zero");
return 0;
double entropyValue = 0;
for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
double p_f = featureCounts[value].item<double>() / totalWeight;

View File

@ -103,7 +103,7 @@ namespace bayesnet {
// Make a complete graph
for (int i = 0; i < num_features - 1; ++i) {
for (int j = i; j < num_features; ++j) {
for (int j = i + 1; j < num_features; ++j) {
g.addEdge(i, j, weights[i][j].item<float>());
}
}

View File

@ -5,4 +5,4 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc)
target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES} ")
target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")

View File

@ -89,6 +89,10 @@ namespace platform {
}
return datasets[name]->getTensors();
}
// Returns true when `name` is a key present in `datasets`, false otherwise.
bool Datasets::isDataset(string name)
{
    const bool known = datasets.count(name) != 0;
    return known;
}
Dataset::Dataset(Dataset& dataset)
{
path = dataset.path;

View File

@ -58,6 +58,7 @@ namespace platform {
pair<vector<vector<float>>&, vector<int>&> getVectors(string name);
pair<vector<vector<int>>&, vector<int>&> getVectorsDiscretized(string name);
pair<torch::Tensor&, torch::Tensor&> getTensors(string name);
bool isDataset(string name);
};
};

View File

@ -72,7 +72,7 @@ void StratifiedKFold::build()
}
while (remainder_samples_to_take > 0) {
int fold = (rand() % static_cast<int>(k));
if (stratified_indices[fold].size() == fold_size) {
if (stratified_indices[fold].size() == fold_size + 1) {
continue;
}
auto it = next(class_indices[label].begin(), 1);

View File

@ -81,6 +81,10 @@ int main(int argc, char** argv)
vector<string> filesToProcess;
auto datasets = platform::Datasets(path, true, platform::ARFF);
if (file_name != "") {
if (!datasets.isDataset(file_name)) {
cerr << "Dataset " << file_name << " not found" << endl;
exit(1);
}
filesToProcess.push_back(file_name);
} else {
filesToProcess = platform::Datasets(path, true, platform::ARFF).getNames();