Experiment working with smoothing and disc-algo

This commit is contained in:
2024-06-11 13:52:26 +02:00
parent 0f06f8971e
commit e628d80f4c
2 changed files with 4 additions and 4 deletions

View File

@@ -142,6 +142,7 @@ namespace platform {
auto states = dataset.getStates(); // Get the states of the features Once they are discretized auto states = dataset.getStates(); // Get the states of the features Once they are discretized
double best_fold_score = 0.0; double best_fold_score = 0.0;
int best_idx_combination = -1; int best_idx_combination = -1;
bayesnet::Smoothing_t smoothing = bayesnet::Smoothing_t::NONE;
json best_fold_hyper; json best_fold_hyper;
for (int idx_combination = 0; idx_combination < combinations.size(); ++idx_combination) { for (int idx_combination = 0; idx_combination < combinations.size(); ++idx_combination) {
auto hyperparam_line = combinations[idx_combination]; auto hyperparam_line = combinations[idx_combination];
@@ -167,7 +168,7 @@ namespace platform {
hyperparameters.check(valid, dataset_name); hyperparameters.check(valid, dataset_name);
clf->setHyperparameters(hyperparameters.get(dataset_name)); clf->setHyperparameters(hyperparameters.get(dataset_name));
// Train model // Train model
clf->fit(X_nested_train, y_nested_train, features, className, states); clf->fit(X_nested_train, y_nested_train, features, className, states, smoothing);
// Test model // Test model
score += clf->score(X_nested_test, y_nested_test); score += clf->score(X_nested_test, y_nested_test);
} }
@@ -186,7 +187,7 @@ namespace platform {
auto valid = clf->getValidHyperparameters(); auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, dataset_name); hyperparameters.check(valid, dataset_name);
clf->setHyperparameters(best_fold_hyper); clf->setHyperparameters(best_fold_hyper);
clf->fit(X_train, y_train, features, className, states); clf->fit(X_train, y_train, features, className, states, smoothing);
best_fold_score = clf->score(X_test, y_test); best_fold_score = clf->score(X_test, y_test);
// Return the result // Return the result
result->idx_dataset = task["idx_dataset"].get<int>(); result->idx_dataset = task["idx_dataset"].get<int>();

View File

@@ -194,8 +194,7 @@ namespace platform {
// //
// Train model // Train model
// //
clf->setSmoothing(smooth_type); clf->fit(X_train, y_train, features, className, states, smooth_type);
clf->fit(X_train, y_train, features, className, states);
if (!quiet) if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "b"); showProgress(nfold + 1, getColor(clf->getStatus()), "b");
auto clf_notes = clf->getNotes(); auto clf_notes = clf->getNotes();