Add hyperparameters management in experiments
This commit is contained in:
@@ -25,6 +25,7 @@ namespace platform {
|
||||
oss << std::put_time(timeinfo, "%H:%M:%S");
|
||||
return oss.str();
|
||||
}
|
||||
Experiment::Experiment() : hyperparameters(json::parse("{}")) {}
|
||||
string Experiment::get_file_name()
|
||||
{
|
||||
string result = "results_" + score_name + "_" + model + "_" + platform + "_" + get_date() + "_" + get_time() + "_" + (stratified ? "1" : "0") + ".json";
|
||||
@@ -124,6 +125,8 @@ namespace platform {
|
||||
auto result = Result();
|
||||
auto [values, counts] = at::_unique(y);
|
||||
result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0));
|
||||
result.setHyperparameters(hyperparameters);
|
||||
// Initialize results vectors
|
||||
int nResults = nfolds * static_cast<int>(randomSeeds.size());
|
||||
auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64);
|
||||
@@ -144,6 +147,10 @@ namespace platform {
|
||||
for (int nfold = 0; nfold < nfolds; nfold++) {
|
||||
auto clf = Models::instance()->create(model);
|
||||
setModelVersion(clf->getVersion());
|
||||
if (hyperparameters.size() != 0) {
|
||||
clf->setHyperparameters(hyperparameters);
|
||||
}
|
||||
// Split train - test dataset
|
||||
train_timer.start();
|
||||
auto [train, test] = fold->getFold(nfold);
|
||||
auto train_t = torch::tensor(train);
|
||||
@@ -153,12 +160,14 @@ namespace platform {
|
||||
auto X_test = X.index({ "...", test_t });
|
||||
auto y_test = y.index({ test_t });
|
||||
cout << nfold + 1 << ", " << flush;
|
||||
// Train model
|
||||
clf->fit(X_train, y_train, features, className, states);
|
||||
nodes[item] = clf->getNumberOfNodes();
|
||||
edges[item] = clf->getNumberOfEdges();
|
||||
num_states[item] = clf->getNumberOfStates();
|
||||
train_time[item] = train_timer.getDuration();
|
||||
auto accuracy_train_value = clf->score(X_train, y_train);
|
||||
// Test model
|
||||
test_timer.start();
|
||||
auto accuracy_test_value = clf->score(X_test, y_test);
|
||||
test_time[item] = test_timer.getDuration();
|
||||
|
Reference in New Issue
Block a user