Ensemble Experiment, Folding, Classifiers and Network together

This commit is contained in:
2023-07-23 14:10:28 +02:00
parent 644b6c9be0
commit 0c226371cc
12 changed files with 116 additions and 31 deletions

View File

@@ -17,16 +17,19 @@
using namespace std;
pair<float, float> cross_validation(Fold* fold, bayesnet::BaseClassifier* model, Tensor& X, Tensor& y, int k)
pair<float, float> cross_validation(Fold* fold, bayesnet::BaseClassifier* model, Tensor& X, Tensor& y, vector<string> features, string className, map<string, vector<int>> states)
{
auto k = fold->getNumberOfFolds();
float accuracy = 0.0;
for (int i = 0; i < k; i++) {
auto [train, test] = fold->getFold(i);
auto X_train = X.indices{ train };
auto y_train = y.indices{ train };
auto X_test = X.indices{ test };
auto y_test = y.indices{ test };
model->fit(X_train, y_train);
auto train_t = torch::tensor(train);
auto test_t = torch::tensor(test);
auto X_train = X.index({ train_t });
auto y_train = y.index({ train_t });
auto X_test = X.index({ test_t });
auto y_test = y.index({ test_t });
model->fit(X_train, y_train, features, className, states);
auto acc = model->score(X_test, y_test);
accuracy += acc;
}
@@ -97,9 +100,12 @@ int main(int argc, char** argv)
/*
* Begin Processing
*/
auto [X, y, features] = loadDataset(file_name, discretize_dataset);
auto [X, y, features, className] = loadDataset(file_name, discretize_dataset, class_last);
auto states = map<string, vector<int>>();
if (discretize_dataset) {
auto [discretized, maxes] = discretize(X, y, features);
auto [Xd, maxes] = discretizeTorch(X, y, features);
states = get_states(Xd, y, features, className);
X = Xd;
}
auto fold = StratifiedKFold(5, y, -1);
auto classifiers = map<string, bayesnet::BaseClassifier*>({
@@ -108,7 +114,7 @@ int main(int argc, char** argv)
}
);
bayesnet::BaseClassifier* model = classifiers[model_name];
auto results = cross_validation(model, X, y, fold, 5);
auto results = cross_validation(&fold, model, X, y, features, className, states);
cout << "Accuracy: " << results.first << endl;
return 0;
}
}