diff --git a/src/grid/GridSearch.cpp b/src/grid/GridSearch.cpp index 6d8699c..2f79938 100644 --- a/src/grid/GridSearch.cpp +++ b/src/grid/GridSearch.cpp @@ -142,6 +142,7 @@ namespace platform { auto states = dataset.getStates(); // Get the states of the features Once they are discretized double best_fold_score = 0.0; int best_idx_combination = -1; + bayesnet::Smoothing_t smoothing = bayesnet::Smoothing_t::NONE; json best_fold_hyper; for (int idx_combination = 0; idx_combination < combinations.size(); ++idx_combination) { auto hyperparam_line = combinations[idx_combination]; @@ -167,7 +168,7 @@ namespace platform { hyperparameters.check(valid, dataset_name); clf->setHyperparameters(hyperparameters.get(dataset_name)); // Train model - clf->fit(X_nested_train, y_nested_train, features, className, states); + clf->fit(X_nested_train, y_nested_train, features, className, states, smoothing); // Test model score += clf->score(X_nested_test, y_nested_test); } @@ -186,7 +187,7 @@ namespace platform { auto valid = clf->getValidHyperparameters(); hyperparameters.check(valid, dataset_name); clf->setHyperparameters(best_fold_hyper); - clf->fit(X_train, y_train, features, className, states); + clf->fit(X_train, y_train, features, className, states, smoothing); best_fold_score = clf->score(X_test, y_test); // Return the result result->idx_dataset = task["idx_dataset"].get(); diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 3866c97..5e9fab2 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -194,8 +194,7 @@ namespace platform { // // Train model // - clf->setSmoothing(smooth_type); - clf->fit(X_train, y_train, features, className, states); + clf->fit(X_train, y_train, features, className, states, smooth_type); if (!quiet) showProgress(nfold + 1, getColor(clf->getStatus()), "b"); auto clf_notes = clf->getNotes();