Begin output nested grid
This commit is contained in:
parent
03e4437fea
commit
fb9b395748
@ -174,7 +174,6 @@ namespace platform {
|
||||
auto className = datasets.getClassName(fileName);
|
||||
double bestScore = 0.0;
|
||||
json bestHyperparameters;
|
||||
int numItems = 0;
|
||||
// for dataset // for seed // for fold // for hyperparameters // for nested fold
|
||||
for (const auto& seed : config.seeds) {
|
||||
Fold* fold;
|
||||
@ -183,7 +182,9 @@ namespace platform {
|
||||
else
|
||||
fold = new KFold(config.n_folds, y.size(0), seed);
|
||||
double bestScore = 0.0;
|
||||
std::cout << "(" << seed << ") doing Fold: " << flush;
|
||||
for (int nfold = 0; nfold < config.n_folds; nfold++) {
|
||||
std::cout << nfold + 1 << flush;
|
||||
// First level fold
|
||||
auto [train, test] = fold->getFold(nfold);
|
||||
auto train_t = torch::tensor(train);
|
||||
@ -192,16 +193,19 @@ namespace platform {
|
||||
auto y_train = y.index({ train_t });
|
||||
auto X_test = X.index({ "...", test_t });
|
||||
auto y_test = y.index({ test_t });
|
||||
auto num = 0;
|
||||
json result_fold;
|
||||
double score = 0.0;
|
||||
for (const auto& hyperparam_line : combinations) {
|
||||
std::cout << "[" << ++num << "/" << combinations.size() << "] " << std::flush;
|
||||
Fold* nested_fold;
|
||||
if (config.stratified)
|
||||
nested_fold = new StratifiedKFold(config.nested, y_train, seed);
|
||||
else
|
||||
nested_fold = new KFold(config.nested, y_train.size(0), seed);
|
||||
|
||||
for (int n_nested_fold = 0; n_nested_fold < config.nested; n_nested_fold++) {
|
||||
// Nested level fold
|
||||
auto [train_nested, test_nested] = fold->getFold(n_nested_fold);
|
||||
auto [train_nested, test_nested] = nested_fold->getFold(n_nested_fold);
|
||||
auto train_nested_t = torch::tensor(train_nested);
|
||||
auto test_nested_t = torch::tensor(test_nested);
|
||||
auto X_nexted_train = X_train.index({ "...", train_nested_t });
|
||||
@ -216,14 +220,17 @@ namespace platform {
|
||||
clf->setHyperparameters(hyperparameters.get(fileName));
|
||||
// Train model
|
||||
if (!config.quiet)
|
||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
|
||||
clf->fit(X_nexted_train, y_nested_train, features, className, states);
|
||||
showProgressFold(n_nested_fold + 1, getColor(clf->getStatus()), "a");
|
||||
// clf->fit(X_nexted_train, y_nested_train, features, className, states);
|
||||
// Test model
|
||||
if (!config.quiet)
|
||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
|
||||
bestScore += clf->score(X_nested_test, y_nested_test);
|
||||
showProgressFold(n_nested_fold + 1, getColor(clf->getStatus()), "b");
|
||||
// score += clf->score(X_nested_test, y_nested_test);
|
||||
score = 0.0;
|
||||
}
|
||||
delete nested_fold;
|
||||
score = score / config.nested;
|
||||
|
||||
}
|
||||
}
|
||||
delete fold;
|
||||
|
Loading…
Reference in New Issue
Block a user