Refactor cross_validation
This commit is contained in:
parent
b9e76becce
commit
adc0ca238f
@ -88,7 +88,7 @@ namespace platform {
|
|||||||
json data = build_json();
|
json data = build_json();
|
||||||
cout << data.dump(4) << endl;
|
cout << data.dump(4) << endl;
|
||||||
}
|
}
|
||||||
Result cross_validation(Fold* fold, string model_name, torch::Tensor& Xt, torch::Tensor& y, vector<string> features, string className, map<string, vector<int>> states)
|
Result Experiment::cross_validation(string model_name, torch::Tensor& Xt, torch::Tensor& y, vector<string> features, string className, map<string, vector<int>> states)
|
||||||
{
|
{
|
||||||
auto classifiers = map<string, bayesnet::BaseClassifier*>({
|
auto classifiers = map<string, bayesnet::BaseClassifier*>({
|
||||||
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
||||||
@ -98,41 +98,57 @@ namespace platform {
|
|||||||
auto result = Result();
|
auto result = Result();
|
||||||
auto [values, counts] = at::_unique(y);
|
auto [values, counts] = at::_unique(y);
|
||||||
result.setSamples(Xt.size(1)).setFeatures(Xt.size(0)).setClasses(values.size(0));
|
result.setSamples(Xt.size(1)).setFeatures(Xt.size(0)).setClasses(values.size(0));
|
||||||
auto k = fold->getNumberOfFolds();
|
int nResults = nfolds * static_cast<int>(randomSeeds.size());
|
||||||
auto accuracy_test = torch::zeros({ k }, torch::kFloat64);
|
auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
auto accuracy_train = torch::zeros({ k }, torch::kFloat64);
|
auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
auto train_time = torch::zeros({ k }, torch::kFloat64);
|
auto train_time = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
auto test_time = torch::zeros({ k }, torch::kFloat64);
|
auto test_time = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
auto nodes = torch::zeros({ k }, torch::kFloat64);
|
auto nodes = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
auto edges = torch::zeros({ k }, torch::kFloat64);
|
auto edges = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
auto num_states = torch::zeros({ k }, torch::kFloat64);
|
auto num_states = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
Timer train_timer, test_timer;
|
Timer train_timer, test_timer;
|
||||||
|
int item = 0;
|
||||||
|
for (auto seed : randomSeeds) {
|
||||||
|
cout << "(" << seed << ") " << flush;
|
||||||
|
Fold* fold;
|
||||||
|
if (stratified)
|
||||||
|
fold = new StratifiedKFold(nfolds, y, seed);
|
||||||
|
else
|
||||||
|
fold = new KFold(nfolds, y.size(0), seed);
|
||||||
cout << "doing Fold: " << flush;
|
cout << "doing Fold: " << flush;
|
||||||
for (int i = 0; i < k; i++) {
|
for (int nfold = 0; nfold < nfolds; nfold++) {
|
||||||
bayesnet::BaseClassifier* model = classifiers[model_name];
|
bayesnet::BaseClassifier* clf = classifiers[model];
|
||||||
|
setModelVersion(clf->getVersion());
|
||||||
train_timer.start();
|
train_timer.start();
|
||||||
auto [train, test] = fold->getFold(i);
|
auto [train, test] = fold->getFold(nfold);
|
||||||
auto train_t = torch::tensor(train);
|
auto train_t = torch::tensor(train);
|
||||||
auto test_t = torch::tensor(test);
|
auto test_t = torch::tensor(test);
|
||||||
auto X_train = Xt.index({ "...", train_t });
|
auto X_train = Xt.index({ "...", train_t });
|
||||||
auto y_train = y.index({ train_t });
|
auto y_train = y.index({ train_t });
|
||||||
auto X_test = Xt.index({ "...", test_t });
|
auto X_test = Xt.index({ "...", test_t });
|
||||||
auto y_test = y.index({ test_t });
|
auto y_test = y.index({ test_t });
|
||||||
cout << i + 1 << ", " << flush;
|
cout << nfold + 1 << ", " << flush;
|
||||||
model->fit(X_train, y_train, features, className, states);
|
clf->fit(X_train, y_train, features, className, states);
|
||||||
nodes[i] = model->getNumberOfNodes();
|
nodes[item] = clf->getNumberOfNodes();
|
||||||
edges[i] = model->getNumberOfEdges();
|
edges[item] = clf->getNumberOfEdges();
|
||||||
num_states[i] = model->getNumberOfStates();
|
num_states[item] = clf->getNumberOfStates();
|
||||||
train_time[i] = train_timer.getDuration();
|
train_time[item] = train_timer.getDuration();
|
||||||
auto accuracy_train_value = model->score(X_train, y_train);
|
auto accuracy_train_value = clf->score(X_train, y_train);
|
||||||
test_timer.start();
|
test_timer.start();
|
||||||
auto accuracy_test_value = model->score(X_test, y_test);
|
auto accuracy_test_value = clf->score(X_test, y_test);
|
||||||
test_time[i] = test_timer.getDuration();
|
test_time[item] = test_timer.getDuration();
|
||||||
accuracy_train[i] = accuracy_train_value;
|
accuracy_train[item] = accuracy_train_value;
|
||||||
accuracy_test[i] = accuracy_test_value;
|
accuracy_test[item] = accuracy_test_value;
|
||||||
|
// Store results and times in vector
|
||||||
|
result.addScoreTrain(accuracy_train_value);
|
||||||
|
result.addScoreTest(accuracy_test_value);
|
||||||
|
result.addTimeTrain(train_time[item].item<double>());
|
||||||
|
result.addTimeTest(test_time[item].item<double>());
|
||||||
|
item++;
|
||||||
}
|
}
|
||||||
cout << "end. " << flush;
|
cout << "end. " << flush;
|
||||||
|
delete fold;
|
||||||
|
}
|
||||||
result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
|
result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
|
||||||
result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
|
result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
|
||||||
result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
|
result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
|
||||||
|
@ -105,9 +105,9 @@ namespace platform {
|
|||||||
Experiment& setDuration(float duration) { this->duration = duration; return *this; }
|
Experiment& setDuration(float duration) { this->duration = duration; return *this; }
|
||||||
string get_file_name();
|
string get_file_name();
|
||||||
void save(string path);
|
void save(string path);
|
||||||
Result cross_validation(const string& path, const string& fileName);
|
//Result cross_validation(const string& path, const string& fileName);
|
||||||
|
Result cross_validation(string model_name, torch::Tensor& X, torch::Tensor& y, vector<string> features, string className, map<string, vector<int>> states);
|
||||||
void show();
|
void show();
|
||||||
};
|
};
|
||||||
Result cross_validation(Fold* fold, string model_name, torch::Tensor& X, torch::Tensor& y, vector<string> features, string className, map<string, vector<int>> states);
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -126,19 +126,9 @@ int main(int argc, char** argv)
|
|||||||
auto samples = datasets.getNSamples(fileName);
|
auto samples = datasets.getNSamples(fileName);
|
||||||
auto className = datasets.getClassName(fileName);
|
auto className = datasets.getClassName(fileName);
|
||||||
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
|
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
|
||||||
for (auto seed : seeds) {
|
auto result = experiment.cross_validation(model_name, X, y, features, className, states);
|
||||||
cout << "(" << seed << ") " << flush;
|
|
||||||
Fold* fold;
|
|
||||||
if (stratified)
|
|
||||||
fold = new StratifiedKFold(n_folds, y, seed);
|
|
||||||
else
|
|
||||||
fold = new KFold(n_folds, samples, seed);
|
|
||||||
auto result = platform::cross_validation(fold, model_name, X, y, features, className, states);
|
|
||||||
result.setDataset(fileName);
|
result.setDataset(fileName);
|
||||||
experiment.setModelVersion("-FIXME-");
|
|
||||||
experiment.addResult(result);
|
experiment.addResult(result);
|
||||||
delete fold;
|
|
||||||
}
|
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
experiment.setDuration(timer.getDuration());
|
experiment.setDuration(timer.getDuration());
|
||||||
|
Loading…
Reference in New Issue
Block a user