From e628d80f4ca261d9372c426d06b38d48dad60cf1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Tue, 11 Jun 2024 13:52:26 +0200 Subject: [PATCH] Experiment working with smoothing and disc-algo --- src/grid/GridSearch.cpp | 5 +++-- src/main/Experiment.cpp | 3 +-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/grid/GridSearch.cpp b/src/grid/GridSearch.cpp index 6d8699c..2f79938 100644 --- a/src/grid/GridSearch.cpp +++ b/src/grid/GridSearch.cpp @@ -142,6 +142,7 @@ namespace platform { auto states = dataset.getStates(); // Get the states of the features Once they are discretized double best_fold_score = 0.0; int best_idx_combination = -1; + bayesnet::Smoothing_t smoothing = bayesnet::Smoothing_t::NONE; json best_fold_hyper; for (int idx_combination = 0; idx_combination < combinations.size(); ++idx_combination) { auto hyperparam_line = combinations[idx_combination]; @@ -167,7 +168,7 @@ namespace platform { hyperparameters.check(valid, dataset_name); clf->setHyperparameters(hyperparameters.get(dataset_name)); // Train model - clf->fit(X_nested_train, y_nested_train, features, className, states); + clf->fit(X_nested_train, y_nested_train, features, className, states, smoothing); // Test model score += clf->score(X_nested_test, y_nested_test); } @@ -186,7 +187,7 @@ namespace platform { auto valid = clf->getValidHyperparameters(); hyperparameters.check(valid, dataset_name); clf->setHyperparameters(best_fold_hyper); - clf->fit(X_train, y_train, features, className, states); + clf->fit(X_train, y_train, features, className, states, smoothing); best_fold_score = clf->score(X_test, y_test); // Return the result result->idx_dataset = task["idx_dataset"].get(); diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 3866c97..5e9fab2 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -194,8 +194,7 @@ namespace platform { // // Train model // - clf->setSmoothing(smooth_type); - clf->fit(X_train, y_train, features, className, states); + clf->fit(X_train, y_train, features, className, states, smooth_type); if (!quiet) showProgress(nfold + 1, getColor(clf->getStatus()), "b"); auto clf_notes = clf->getNotes();