Fix stratified folding mistake in remainders
This commit is contained in:
parent
3d8fea7a37
commit
7f7ddad36a
9
.vscode/launch.json
vendored
9
.vscode/launch.json
vendored
@ -26,12 +26,13 @@
|
|||||||
"-m",
|
"-m",
|
||||||
"TAN",
|
"TAN",
|
||||||
"-p",
|
"-p",
|
||||||
"datasets",
|
"/Users/rmontanana/Code/discretizbench/datasets",
|
||||||
"--discretize",
|
"--discretize",
|
||||||
"-f",
|
"--stratified",
|
||||||
"5",
|
|
||||||
"--title",
|
"--title",
|
||||||
"Debug test"
|
"Debug test",
|
||||||
|
"-d",
|
||||||
|
"ionosphere"
|
||||||
],
|
],
|
||||||
"cwd": "${workspaceFolder}/build/src/Platform",
|
"cwd": "${workspaceFolder}/build/src/Platform",
|
||||||
},
|
},
|
||||||
|
@ -94,7 +94,7 @@ namespace bayesnet {
|
|||||||
totalWeight += 1;
|
totalWeight += 1;
|
||||||
}
|
}
|
||||||
if (totalWeight == 0)
|
if (totalWeight == 0)
|
||||||
throw invalid_argument("Total weight should not be zero");
|
return 0;
|
||||||
double entropyValue = 0;
|
double entropyValue = 0;
|
||||||
for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
|
for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
|
||||||
double p_f = featureCounts[value].item<double>() / totalWeight;
|
double p_f = featureCounts[value].item<double>() / totalWeight;
|
||||||
|
@ -103,7 +103,7 @@ namespace bayesnet {
|
|||||||
|
|
||||||
// Make a complete graph
|
// Make a complete graph
|
||||||
for (int i = 0; i < num_features - 1; ++i) {
|
for (int i = 0; i < num_features - 1; ++i) {
|
||||||
for (int j = i; j < num_features; ++j) {
|
for (int j = i + 1; j < num_features; ++j) {
|
||||||
g.addEdge(i, j, weights[i][j].item<float>());
|
g.addEdge(i, j, weights[i][j].item<float>());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -89,6 +89,10 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return datasets[name]->getTensors();
|
return datasets[name]->getTensors();
|
||||||
}
|
}
|
||||||
|
bool Datasets::isDataset(string name)
|
||||||
|
{
|
||||||
|
return datasets.find(name) != datasets.end();
|
||||||
|
}
|
||||||
Dataset::Dataset(Dataset& dataset)
|
Dataset::Dataset(Dataset& dataset)
|
||||||
{
|
{
|
||||||
path = dataset.path;
|
path = dataset.path;
|
||||||
|
@ -58,6 +58,7 @@ namespace platform {
|
|||||||
pair<vector<vector<float>>&, vector<int>&> getVectors(string name);
|
pair<vector<vector<float>>&, vector<int>&> getVectors(string name);
|
||||||
pair<vector<vector<int>>&, vector<int>&> getVectorsDiscretized(string name);
|
pair<vector<vector<int>>&, vector<int>&> getVectorsDiscretized(string name);
|
||||||
pair<torch::Tensor&, torch::Tensor&> getTensors(string name);
|
pair<torch::Tensor&, torch::Tensor&> getTensors(string name);
|
||||||
|
bool isDataset(string name);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ void StratifiedKFold::build()
|
|||||||
}
|
}
|
||||||
while (remainder_samples_to_take > 0) {
|
while (remainder_samples_to_take > 0) {
|
||||||
int fold = (rand() % static_cast<int>(k));
|
int fold = (rand() % static_cast<int>(k));
|
||||||
if (stratified_indices[fold].size() == fold_size) {
|
if (stratified_indices[fold].size() == fold_size + 1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto it = next(class_indices[label].begin(), 1);
|
auto it = next(class_indices[label].begin(), 1);
|
||||||
|
@ -81,6 +81,10 @@ int main(int argc, char** argv)
|
|||||||
vector<string> filesToProcess;
|
vector<string> filesToProcess;
|
||||||
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
||||||
if (file_name != "") {
|
if (file_name != "") {
|
||||||
|
if (!datasets.isDataset(file_name)) {
|
||||||
|
cerr << "Dataset " << file_name << " not found" << endl;
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
filesToProcess.push_back(file_name);
|
filesToProcess.push_back(file_name);
|
||||||
} else {
|
} else {
|
||||||
filesToProcess = platform::Datasets(path, true, platform::ARFF).getNames();
|
filesToProcess = platform::Datasets(path, true, platform::ARFF).getNames();
|
||||||
|
Loading…
Reference in New Issue
Block a user