Run cross-validation with 10 random seeds

Add testwodt comparison
This commit is contained in:
2021-03-10 16:42:18 +01:00
parent d4cfe77b18
commit e791d2edf5
5 changed files with 127 additions and 15 deletions

View File

@@ -3,6 +3,7 @@ import sqlite3
from datetime import datetime from datetime import datetime
from abc import ABC from abc import ABC
from typing import List from typing import List
import numpy as np
import mysql.connector import mysql.connector
from ast import literal_eval as make_tuple from ast import literal_eval as make_tuple
from sshtunnel import SSHTunnelForwarder from sshtunnel import SSHTunnelForwarder
@@ -322,8 +323,8 @@ class Outcomes(BD):
outcomes = ["fit_time", "score_time", "train_score", "test_score"] outcomes = ["fit_time", "score_time", "train_score", "test_score"]
data = "" data = ""
for index in outcomes: for index in outcomes:
data += ", " + str(results[index].mean()) + ", " data += ", " + str(np.mean(results[index])) + ", "
data += str(results[index].std()) data += str(np.std(results[index]))
command = ( command = (
f"insert or replace into {self._table} ('dataset', 'parameters', " f"insert or replace into {self._table} ('dataset', 'parameters', "
"'date', 'normalize', 'standardize'" "'date', 'normalize', 'standardize'"
@@ -341,12 +342,12 @@ class Outcomes(BD):
normalize, normalize,
standardize, standardize,
[ [
float(results["test_score"].mean()), float(np.mean(results["test_score"])),
float(results["test_score"].std()), float(np.std(results["test_score"])),
], ],
[ [
float(results["fit_time"].mean()), float(np.mean(results["fit_time"])),
float(results["fit_time"].std()), float(np.std(results["fit_time"])),
], ],
parameters, parameters,
) )

View File

@@ -75,16 +75,25 @@ class Experiment:
model.set_params(**parameters) model.set_params(**parameters)
self._num_warnings = 0 self._num_warnings = 0
warnings.warn = self._warn warnings.warn = self._warn
with warnings.catch_warnings(): # Execute with 10 different seeds to ease the random effect
warnings.filterwarnings("ignore") total = {}
# Also affect subprocesses outcomes = ["fit_time", "score_time", "train_score", "test_score"]
os.environ["PYTHONWARNINGS"] = "ignore" for item in outcomes:
results = cross_validate( total[item] = []
model, X, y, return_train_score=True, n_jobs=self._threads for random_state in [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]:
) model.set_params(**{"random_state": random_state})
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
# Also affect subprocesses
os.environ["PYTHONWARNINGS"] = "ignore"
results = cross_validate(
model, X, y, return_train_score=True, n_jobs=self._threads
)
for item in outcomes:
total[item].append(results[item])
outcomes = Outcomes(host=self._host, model=self._model_name) outcomes = Outcomes(host=self._host, model=self._model_name)
parameters = json.dumps(parameters, sort_keys=True) parameters = json.dumps(parameters, sort_keys=True)
outcomes.store(dataset, normalize, standardize, parameters, results) outcomes.store(dataset, normalize, standardize, parameters, total)
if self._num_warnings > 0: if self._num_warnings > 0:
print(f"{self._num_warnings} warnings have happend") print(f"{self._num_warnings} warnings have happend")

View File

@@ -117,7 +117,7 @@ if dataset == "all":
) )
print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}") print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}")
for dataset in dt: for dataset in dt:
print(f"- {dataset[0]:20s} ", end="") print(f"- {dataset[0]:30s} ", end="")
scores = process_dataset(dataset[0], verbose=False) scores = process_dataset(dataset[0], verbose=False)
print(f"{np.mean(scores):6.4f}±{np.std(scores):6.4f}") print(f"{np.mean(scores):6.4f}±{np.std(scores):6.4f}")
else: else:

51
testwodt_output.txt Normal file
View File

@@ -0,0 +1,51 @@
* Process all datasets set: tanveer [-1, 1]: True norm: False std: False
5 Fold Cross Validation with 10 random seeds [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
- balance-scale 0.8462±0.0853
- balloons 0.6133±0.2012
- breast-cancer-wisc-diag 0.9399±0.0194
- breast-cancer-wisc-prog 0.6898±0.0833
- breast-cancer-wisc 0.9414±0.0307
- breast-cancer 0.6336±0.0707
- cardiotocography-10clases 0.6022±0.0335
- cardiotocography-3clases 0.8217±0.0565
- conn-bench-sonar-mines-rocks 0.6235±0.1017
- cylinder-bands 0.5593±0.0437
- dermatology 0.9528±0.0252
- echocardiogram 0.7475±0.0872
- fertility 0.7550±0.1404
- haberman-survival 0.6699±0.0626
- heart-hungarian 0.7654±0.0531
- hepatitis 0.7806±0.0793
- ilpd-indian-liver 0.6450±0.0445
- ionosphere 0.8696±0.0439
- iris 0.9460±0.0339
- led-display 0.7059±0.0168
- libras 0.6803±0.1177
- low-res-spect 0.8563±0.0310
- lymphography 0.8151±0.0600
- mammographic 0.7715±0.0378
- molec-biol-promoter 0.7339±0.1046
- musk-1 0.7231±0.0721
- oocytes_merluccius_nucleus_4d 0.7013±0.0376
- oocytes_merluccius_states_2f 0.8839±0.0314
- oocytes_trisopterus_nucleus_2f 0.6516±0.0534
- oocytes_trisopterus_states_5b 0.7711±0.1160
- parkinsons 0.8262±0.0892
- pima 0.7007±0.0341
- pittsburg-bridges-MATERIAL 0.7817±0.1466
- pittsburg-bridges-REL-L 0.5570±0.1307
- pittsburg-bridges-SPAN 0.5491±0.0985
- pittsburg-bridges-T-OR-D 0.7930±0.1684
- planning 0.5406±0.0762
- post-operative 0.5744±0.0690
- seeds 0.9071±0.0784
- statlog-australian-credit 0.5654±0.0353
- statlog-german-credit 0.6963±0.0243
- statlog-heart 0.7811±0.0462
- statlog-image 0.9472±0.0119
- statlog-vehicle 0.7084±0.0310
- synthetic-control 0.9825±0.0149
- tic-tac-toe 0.8356±0.1050
- vertebral-column-2clases 0.7865±0.1203
- wine 0.9659±0.0210
- zoo 0.9402±0.0448

View File

@@ -0,0 +1,51 @@
Dataset In DB Computed as in paper
======= ========= ====================
balance-scale 0.8528 0.8462±0.0853
balloons 0.633333 0.6133±0.2012
breast-cancer-wisc-diag 0.963065 0.9399±0.0194
breast-cancer-wisc-prog 0.666282 0.6898±0.0833
breast-cancer-wisc 0.941367 0.9414±0.0307
breast-cancer 0.573382 0.6336±0.0707
cardiotocography-10clases 0.627007 0.6022±0.0335
cardiotocography-3clases 0.800097 0.8217±0.0565
conn-bench-sonar-mines-rocks 0.63043 0.6235±0.1017
cylinder-bands 0.541043 0.5593±0.0437
dermatology 0.950907 0.9528±0.0252
echocardiogram 0.640456 0.7475±0.0872
fertility 0.69 0.7550±0.1404
haberman-survival 0.659757 0.6699±0.0626
heart-hungarian 0.758212 0.7654±0.0531
hepatitis 0.76129 0.7806±0.0793
ilpd-indian-liver 0.6673 0.6450±0.0445
ionosphere 0.83497 0.8696±0.0439
iris 0.966667 0.9460±0.0339
led-display 0.704 0.7059±0.0168
libras 0.658333 0.6803±0.1177
low-res-spect 0.870111 0.8563±0.0310
lymphography 0.736092 0.8151±0.0600
mammographic 0.764848 0.7715±0.0378
molec-biol-promoter 0.772294 0.7339±0.1046
musk-1 0.722566 0.7231±0.0721
oocytes_merluccius_nucleus_4d 0.718125 0.7013±0.0376
oocytes_merluccius_states_2f 0.888431 0.8839±0.0314
oocytes_trisopterus_nucleus_2f 0.674263 0.6516±0.0534
oocytes_trisopterus_states_5b 0.798199 0.7711±0.1160
parkinsons 0.835897 0.8262±0.0892
pima 0.692793 0.7007±0.0341
pittsburg-bridges-MATERIAL 0.785714 0.7817±0.1466
pittsburg-bridges-REL-L 0.590476 0.5570±0.1307
pittsburg-bridges-SPAN 0.545029 0.5491±0.0985
pittsburg-bridges-T-OR-D 0.742857 0.7930±0.1684
planning 0.478228 0.5406±0.0762
post-operative 0.588889 0.5744±0.0690
seeds 0.928571 0.9071±0.0784
statlog-australian-credit 0.55942 0.5654±0.0353
statlog-german-credit 0.696 0.6963±0.0243
statlog-heart 0.788889 0.7811±0.0462
statlog-image 0.955844 0.9472±0.0119
statlog-vehicle 0.719875 0.7084±0.0310
synthetic-control 0.97 0.9825±0.0149
tic-tac-toe 0.866416 0.8356±0.1050
vertebral-column-2clases 0.790323 0.7865±0.1203
wine 0.972063 0.9659±0.0210
zoo 0.94 0.9402±0.0448