From e791d2edf5dc423090e38575adb8c6f7dea502ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Wed, 10 Mar 2021 16:42:18 +0100 Subject: [PATCH] Add 10 random seeds run in crossval Add testwodt comparison --- experimentation/Database.py | 13 +++++---- experimentation/Experiments.py | 25 +++++++++++------ testwodt.py | 2 +- testwodt_output.txt | 51 ++++++++++++++++++++++++++++++++++ wodt_comparecomputedpaper.txt | 51 ++++++++++++++++++++++++++++++++++ 5 files changed, 127 insertions(+), 15 deletions(-) create mode 100644 testwodt_output.txt create mode 100644 wodt_comparecomputedpaper.txt diff --git a/experimentation/Database.py b/experimentation/Database.py index e753aaf..ca1b90b 100644 --- a/experimentation/Database.py +++ b/experimentation/Database.py @@ -3,6 +3,7 @@ import sqlite3 from datetime import datetime from abc import ABC from typing import List +import numpy as np import mysql.connector from ast import literal_eval as make_tuple from sshtunnel import SSHTunnelForwarder @@ -322,8 +323,8 @@ class Outcomes(BD): outcomes = ["fit_time", "score_time", "train_score", "test_score"] data = "" for index in outcomes: - data += ", " + str(results[index].mean()) + ", " - data += str(results[index].std()) + data += ", " + str(np.mean(results[index])) + ", " + data += str(np.std(results[index])) command = ( f"insert or replace into {self._table} ('dataset', 'parameters', " "'date', 'normalize', 'standardize'" @@ -341,12 +342,12 @@ class Outcomes(BD): normalize, standardize, [ - float(results["test_score"].mean()), - float(results["test_score"].std()), + float(np.mean(results["test_score"])), + float(np.std(results["test_score"])), ], [ - float(results["fit_time"].mean()), - float(results["fit_time"].std()), + float(np.mean(results["fit_time"])), + float(np.std(results["fit_time"])), ], parameters, ) diff --git a/experimentation/Experiments.py b/experimentation/Experiments.py index 8d1390b..4f63f72 100644 --- a/experimentation/Experiments.py +++ b/experimentation/Experiments.py @@ -75,16 +75,25 @@ class Experiment: model.set_params(**parameters) self._num_warnings = 0 warnings.warn = self._warn - with warnings.catch_warnings(): - warnings.filterwarnings("ignore") - # Also affect subprocesses - os.environ["PYTHONWARNINGS"] = "ignore" - results = cross_validate( - model, X, y, return_train_score=True, n_jobs=self._threads - ) + # Execute with 10 different seeds to ease the random effect + total = {} + outcomes = ["fit_time", "score_time", "train_score", "test_score"] + for item in outcomes: + total[item] = [] + for random_state in [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]: + model.set_params(**{"random_state": random_state}) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + # Also affect subprocesses + os.environ["PYTHONWARNINGS"] = "ignore" + results = cross_validate( + model, X, y, return_train_score=True, n_jobs=self._threads + ) + for item in outcomes: + total[item].append(results[item]) outcomes = Outcomes(host=self._host, model=self._model_name) parameters = json.dumps(parameters, sort_keys=True) - outcomes.store(dataset, normalize, standardize, parameters, results) + outcomes.store(dataset, normalize, standardize, parameters, total) if self._num_warnings > 0: print(f"{self._num_warnings} warnings have happend") diff --git a/testwodt.py b/testwodt.py index 07a5342..97115d4 100644 --- a/testwodt.py +++ b/testwodt.py @@ -117,7 +117,7 @@ if dataset == "all": ) print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}") for dataset in dt: - print(f"- {dataset[0]:20s} ", end="") + print(f"- {dataset[0]:30s} ", end="") scores = process_dataset(dataset[0], verbose=False) print(f"{np.mean(scores):6.4f}±{np.std(scores):6.4f}") else: diff --git a/testwodt_output.txt b/testwodt_output.txt new file mode 100644 index 0000000..9874235 --- /dev/null +++ b/testwodt_output.txt @@ -0,0 +1,51 @@ +* Process all datasets set: tanveer [-1, 1]: True norm: False std: False +5 Fold Cross Validation with 10 random seeds [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] +- balance-scale 0.8462±0.0853 +- balloons 0.6133±0.2012 +- breast-cancer-wisc-diag 0.9399±0.0194 +- breast-cancer-wisc-prog 0.6898±0.0833 +- breast-cancer-wisc 0.9414±0.0307 +- breast-cancer 0.6336±0.0707 +- cardiotocography-10clases 0.6022±0.0335 +- cardiotocography-3clases 0.8217±0.0565 +- conn-bench-sonar-mines-rocks 0.6235±0.1017 +- cylinder-bands 0.5593±0.0437 +- dermatology 0.9528±0.0252 +- echocardiogram 0.7475±0.0872 +- fertility 0.7550±0.1404 +- haberman-survival 0.6699±0.0626 +- heart-hungarian 0.7654±0.0531 +- hepatitis 0.7806±0.0793 +- ilpd-indian-liver 0.6450±0.0445 +- ionosphere 0.8696±0.0439 +- iris 0.9460±0.0339 +- led-display 0.7059±0.0168 +- libras 0.6803±0.1177 +- low-res-spect 0.8563±0.0310 +- lymphography 0.8151±0.0600 +- mammographic 0.7715±0.0378 +- molec-biol-promoter 0.7339±0.1046 +- musk-1 0.7231±0.0721 +- oocytes_merluccius_nucleus_4d 0.7013±0.0376 +- oocytes_merluccius_states_2f 0.8839±0.0314 +- oocytes_trisopterus_nucleus_2f 0.6516±0.0534 +- oocytes_trisopterus_states_5b 0.7711±0.1160 +- parkinsons 0.8262±0.0892 +- pima 0.7007±0.0341 +- pittsburg-bridges-MATERIAL 0.7817±0.1466 +- pittsburg-bridges-REL-L 0.5570±0.1307 +- pittsburg-bridges-SPAN 0.5491±0.0985 +- pittsburg-bridges-T-OR-D 0.7930±0.1684 +- planning 0.5406±0.0762 +- post-operative 0.5744±0.0690 +- seeds 0.9071±0.0784 +- statlog-australian-credit 0.5654±0.0353 +- statlog-german-credit 0.6963±0.0243 +- statlog-heart 0.7811±0.0462 +- statlog-image 0.9472±0.0119 +- statlog-vehicle 0.7084±0.0310 +- synthetic-control 0.9825±0.0149 +- tic-tac-toe 0.8356±0.1050 +- vertebral-column-2clases 0.7865±0.1203 +- wine 0.9659±0.0210 +- zoo 0.9402±0.0448 diff --git a/wodt_comparecomputedpaper.txt b/wodt_comparecomputedpaper.txt new file mode 100644 index 0000000..fbe0248 --- /dev/null +++ b/wodt_comparecomputedpaper.txt @@ -0,0 +1,51 @@ + In DB Computed as in paper + ========= ==================== +balance-scale 0.8528 0.8462±0.0853 +balloons 0.633333 0.6133±0.2012 +breast-cancer-wisc-diag 0.963065 0.9399±0.0194 +breast-cancer-wisc-prog 0.666282 0.6898±0.0833 +breast-cancer-wisc 0.941367 0.9414±0.0307 +breast-cancer 0.573382 0.6336±0.0707 +cardiotocography-10clases 0.627007 0.6022±0.0335 +cardiotocography-3clases 0.800097 0.8217±0.0565 +conn-bench-sonar-mines-rocks 0.63043 0.6235±0.1017 +cylinder-bands 0.541043 0.5593±0.0437 +dermatology 0.950907 0.9528±0.0252 +echocardiogram 0.640456 0.7475±0.0872 +fertility 0.69 0.7550±0.1404 +haberman-survival 0.659757 0.6699±0.0626 +heart-hungarian 0.758212 0.7654±0.0531 +hepatitis 0.76129 0.7806±0.0793 +ilpd-indian-liver 0.6673 0.6450±0.0445 +ionosphere 0.83497 0.8696±0.0439 +iris 0.966667 0.9460±0.0339 +led-display 0.704 0.7059±0.0168 +libras 0.658333 0.6803±0.1177 +low-res-spect 0.870111 0.8563±0.0310 +lymphography 0.736092 0.8151±0.0600 +mammographic 0.764848 0.7715±0.0378 +molec-biol-promoter 0.772294 0.7339±0.1046 +musk-1 0.722566 0.7231±0.0721 +oocytes_merluccius_nucleus_4d 0.718125 0.7013±0.0376 +oocytes_merluccius_states_2f 0.888431 0.8839±0.0314 +oocytes_trisopterus_nucleus_2f 0.674263 0.6516±0.0534 +oocytes_trisopterus_states_5b 0.798199 0.7711±0.1160 +parkinsons 0.835897 0.8262±0.0892 +pima 0.692793 0.7007±0.0341 +pittsburg-bridges-MATERIAL 0.785714 0.7817±0.1466 +pittsburg-bridges-REL-L 0.590476 0.5570±0.1307 +pittsburg-bridges-SPAN 0.545029 0.5491±0.0985 +pittsburg-bridges-T-OR-D 0.742857 0.7930±0.1684 +planning 0.478228 0.5406±0.0762 +post-operative 0.588889 0.5744±0.0690 +seeds 0.928571 0.9071±0.0784 +statlog-australian-credit 0.55942 0.5654±0.0353 +statlog-german-credit 0.696 0.6963±0.0243 +statlog-heart 0.788889 0.7811±0.0462 +statlog-image 0.955844 0.9472±0.0119 +statlog-vehicle 0.719875 0.7084±0.0310 +synthetic-control 0.97 0.9825±0.0149 +tic-tac-toe 0.866416 0.8356±0.1050 +vertebral-column-2clases 0.790323 0.7865±0.1203 +wine 0.972063 0.9659±0.0210 +zoo 0.94 0.9402±0.0448