Ensemble Experiment, Folding, Classifiers and Network together
This commit is contained in:
@@ -17,16 +17,19 @@
|
||||
|
||||
using namespace std;
|
||||
|
||||
pair<float, float> cross_validation(Fold* fold, bayesnet::BaseClassifier* model, Tensor& X, Tensor& y, int k)
|
||||
pair<float, float> cross_validation(Fold* fold, bayesnet::BaseClassifier* model, Tensor& X, Tensor& y, vector<string> features, string className, map<string, vector<int>> states)
|
||||
{
|
||||
auto k = fold->getNumberOfFolds();
|
||||
float accuracy = 0.0;
|
||||
for (int i = 0; i < k; i++) {
|
||||
auto [train, test] = fold->getFold(i);
|
||||
auto X_train = X.indices{ train };
|
||||
auto y_train = y.indices{ train };
|
||||
auto X_test = X.indices{ test };
|
||||
auto y_test = y.indices{ test };
|
||||
model->fit(X_train, y_train);
|
||||
auto train_t = torch::tensor(train);
|
||||
auto test_t = torch::tensor(test);
|
||||
auto X_train = X.index({ train_t });
|
||||
auto y_train = y.index({ train_t });
|
||||
auto X_test = X.index({ test_t });
|
||||
auto y_test = y.index({ test_t });
|
||||
model->fit(X_train, y_train, features, className, states);
|
||||
auto acc = model->score(X_test, y_test);
|
||||
accuracy += acc;
|
||||
}
|
||||
@@ -97,9 +100,12 @@ int main(int argc, char** argv)
|
||||
/*
|
||||
* Begin Processing
|
||||
*/
|
||||
auto [X, y, features] = loadDataset(file_name, discretize_dataset);
|
||||
auto [X, y, features, className] = loadDataset(file_name, discretize_dataset, class_last);
|
||||
auto states = map<string, vector<int>>();
|
||||
if (discretize_dataset) {
|
||||
auto [discretized, maxes] = discretize(X, y, features);
|
||||
auto [Xd, maxes] = discretizeTorch(X, y, features);
|
||||
states = get_states(Xd, y, features, className);
|
||||
X = Xd;
|
||||
}
|
||||
auto fold = StratifiedKFold(5, y, -1);
|
||||
auto classifiers = map<string, bayesnet::BaseClassifier*>({
|
||||
@@ -108,7 +114,7 @@ int main(int argc, char** argv)
|
||||
}
|
||||
);
|
||||
bayesnet::BaseClassifier* model = classifiers[model_name];
|
||||
auto results = cross_validation(model, X, y, fold, 5);
|
||||
auto results = cross_validation(&fold, model, X, y, features, className, states);
|
||||
cout << "Accuracy: " << results.first << endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user