mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-16 07:56:07 +00:00
Add 10 random seeds run in crossval
Add testwodt comparison
This commit is contained in:
@@ -3,6 +3,7 @@ import sqlite3
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import List
|
from typing import List
|
||||||
|
import numpy as np
|
||||||
import mysql.connector
|
import mysql.connector
|
||||||
from ast import literal_eval as make_tuple
|
from ast import literal_eval as make_tuple
|
||||||
from sshtunnel import SSHTunnelForwarder
|
from sshtunnel import SSHTunnelForwarder
|
||||||
@@ -322,8 +323,8 @@ class Outcomes(BD):
|
|||||||
outcomes = ["fit_time", "score_time", "train_score", "test_score"]
|
outcomes = ["fit_time", "score_time", "train_score", "test_score"]
|
||||||
data = ""
|
data = ""
|
||||||
for index in outcomes:
|
for index in outcomes:
|
||||||
data += ", " + str(results[index].mean()) + ", "
|
data += ", " + str(np.mean(results[index])) + ", "
|
||||||
data += str(results[index].std())
|
data += str(np.std(results[index]))
|
||||||
command = (
|
command = (
|
||||||
f"insert or replace into {self._table} ('dataset', 'parameters', "
|
f"insert or replace into {self._table} ('dataset', 'parameters', "
|
||||||
"'date', 'normalize', 'standardize'"
|
"'date', 'normalize', 'standardize'"
|
||||||
@@ -341,12 +342,12 @@ class Outcomes(BD):
|
|||||||
normalize,
|
normalize,
|
||||||
standardize,
|
standardize,
|
||||||
[
|
[
|
||||||
float(results["test_score"].mean()),
|
float(np.mean(results["test_score"])),
|
||||||
float(results["test_score"].std()),
|
float(np.std(results["test_score"])),
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
float(results["fit_time"].mean()),
|
float(np.mean(results["fit_time"])),
|
||||||
float(results["fit_time"].std()),
|
float(np.std(results["fit_time"])),
|
||||||
],
|
],
|
||||||
parameters,
|
parameters,
|
||||||
)
|
)
|
||||||
|
@@ -75,16 +75,25 @@ class Experiment:
|
|||||||
model.set_params(**parameters)
|
model.set_params(**parameters)
|
||||||
self._num_warnings = 0
|
self._num_warnings = 0
|
||||||
warnings.warn = self._warn
|
warnings.warn = self._warn
|
||||||
with warnings.catch_warnings():
|
# Execute with 10 different seeds to ease the random effect
|
||||||
warnings.filterwarnings("ignore")
|
total = {}
|
||||||
# Also affect subprocesses
|
outcomes = ["fit_time", "score_time", "train_score", "test_score"]
|
||||||
os.environ["PYTHONWARNINGS"] = "ignore"
|
for item in outcomes:
|
||||||
results = cross_validate(
|
total[item] = []
|
||||||
model, X, y, return_train_score=True, n_jobs=self._threads
|
for random_state in [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]:
|
||||||
)
|
model.set_params(**{"random_state": random_state})
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
# Also affect subprocesses
|
||||||
|
os.environ["PYTHONWARNINGS"] = "ignore"
|
||||||
|
results = cross_validate(
|
||||||
|
model, X, y, return_train_score=True, n_jobs=self._threads
|
||||||
|
)
|
||||||
|
for item in outcomes:
|
||||||
|
total[item].append(results[item])
|
||||||
outcomes = Outcomes(host=self._host, model=self._model_name)
|
outcomes = Outcomes(host=self._host, model=self._model_name)
|
||||||
parameters = json.dumps(parameters, sort_keys=True)
|
parameters = json.dumps(parameters, sort_keys=True)
|
||||||
outcomes.store(dataset, normalize, standardize, parameters, results)
|
outcomes.store(dataset, normalize, standardize, parameters, total)
|
||||||
if self._num_warnings > 0:
|
if self._num_warnings > 0:
|
||||||
print(f"{self._num_warnings} warnings have happend")
|
print(f"{self._num_warnings} warnings have happend")
|
||||||
|
|
||||||
|
@@ -117,7 +117,7 @@ if dataset == "all":
|
|||||||
)
|
)
|
||||||
print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}")
|
print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}")
|
||||||
for dataset in dt:
|
for dataset in dt:
|
||||||
print(f"- {dataset[0]:20s} ", end="")
|
print(f"- {dataset[0]:30s} ", end="")
|
||||||
scores = process_dataset(dataset[0], verbose=False)
|
scores = process_dataset(dataset[0], verbose=False)
|
||||||
print(f"{np.mean(scores):6.4f}±{np.std(scores):6.4f}")
|
print(f"{np.mean(scores):6.4f}±{np.std(scores):6.4f}")
|
||||||
else:
|
else:
|
||||||
|
51
testwodt_output.txt
Normal file
51
testwodt_output.txt
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
* Process all datasets set: tanveer [-1, 1]: True norm: False std: False
|
||||||
|
5 Fold Cross Validation with 10 random seeds [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
|
||||||
|
- balance-scale 0.8462±0.0853
|
||||||
|
- balloons 0.6133±0.2012
|
||||||
|
- breast-cancer-wisc-diag 0.9399±0.0194
|
||||||
|
- breast-cancer-wisc-prog 0.6898±0.0833
|
||||||
|
- breast-cancer-wisc 0.9414±0.0307
|
||||||
|
- breast-cancer 0.6336±0.0707
|
||||||
|
- cardiotocography-10clases 0.6022±0.0335
|
||||||
|
- cardiotocography-3clases 0.8217±0.0565
|
||||||
|
- conn-bench-sonar-mines-rocks 0.6235±0.1017
|
||||||
|
- cylinder-bands 0.5593±0.0437
|
||||||
|
- dermatology 0.9528±0.0252
|
||||||
|
- echocardiogram 0.7475±0.0872
|
||||||
|
- fertility 0.7550±0.1404
|
||||||
|
- haberman-survival 0.6699±0.0626
|
||||||
|
- heart-hungarian 0.7654±0.0531
|
||||||
|
- hepatitis 0.7806±0.0793
|
||||||
|
- ilpd-indian-liver 0.6450±0.0445
|
||||||
|
- ionosphere 0.8696±0.0439
|
||||||
|
- iris 0.9460±0.0339
|
||||||
|
- led-display 0.7059±0.0168
|
||||||
|
- libras 0.6803±0.1177
|
||||||
|
- low-res-spect 0.8563±0.0310
|
||||||
|
- lymphography 0.8151±0.0600
|
||||||
|
- mammographic 0.7715±0.0378
|
||||||
|
- molec-biol-promoter 0.7339±0.1046
|
||||||
|
- musk-1 0.7231±0.0721
|
||||||
|
- oocytes_merluccius_nucleus_4d 0.7013±0.0376
|
||||||
|
- oocytes_merluccius_states_2f 0.8839±0.0314
|
||||||
|
- oocytes_trisopterus_nucleus_2f 0.6516±0.0534
|
||||||
|
- oocytes_trisopterus_states_5b 0.7711±0.1160
|
||||||
|
- parkinsons 0.8262±0.0892
|
||||||
|
- pima 0.7007±0.0341
|
||||||
|
- pittsburg-bridges-MATERIAL 0.7817±0.1466
|
||||||
|
- pittsburg-bridges-REL-L 0.5570±0.1307
|
||||||
|
- pittsburg-bridges-SPAN 0.5491±0.0985
|
||||||
|
- pittsburg-bridges-T-OR-D 0.7930±0.1684
|
||||||
|
- planning 0.5406±0.0762
|
||||||
|
- post-operative 0.5744±0.0690
|
||||||
|
- seeds 0.9071±0.0784
|
||||||
|
- statlog-australian-credit 0.5654±0.0353
|
||||||
|
- statlog-german-credit 0.6963±0.0243
|
||||||
|
- statlog-heart 0.7811±0.0462
|
||||||
|
- statlog-image 0.9472±0.0119
|
||||||
|
- statlog-vehicle 0.7084±0.0310
|
||||||
|
- synthetic-control 0.9825±0.0149
|
||||||
|
- tic-tac-toe 0.8356±0.1050
|
||||||
|
- vertebral-column-2clases 0.7865±0.1203
|
||||||
|
- wine 0.9659±0.0210
|
||||||
|
- zoo 0.9402±0.0448
|
51
wodt_comparecomputedpaper.txt
Normal file
51
wodt_comparecomputedpaper.txt
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
In DB Computed as in paper
|
||||||
|
========= ====================
|
||||||
|
balance-scale 0.8528 0.8462±0.0853
|
||||||
|
balloons 0.633333 0.6133±0.2012
|
||||||
|
breast-cancer-wisc-diag 0.963065 0.9399±0.0194
|
||||||
|
breast-cancer-wisc-prog 0.666282 0.6898±0.0833
|
||||||
|
breast-cancer-wisc 0.941367 0.9414±0.0307
|
||||||
|
breast-cancer 0.573382 0.6336±0.0707
|
||||||
|
cardiotocography-10clases 0.627007 0.6022±0.0335
|
||||||
|
cardiotocography-3clases 0.800097 0.8217±0.0565
|
||||||
|
conn-bench-sonar-mines-rocks 0.63043 0.6235±0.1017
|
||||||
|
cylinder-bands 0.541043 0.5593±0.0437
|
||||||
|
dermatology 0.950907 0.9528±0.0252
|
||||||
|
echocardiogram 0.640456 0.7475±0.0872
|
||||||
|
fertility 0.69 0.7550±0.1404
|
||||||
|
haberman-survival 0.659757 0.6699±0.0626
|
||||||
|
heart-hungarian 0.758212 0.7654±0.0531
|
||||||
|
hepatitis 0.76129 0.7806±0.0793
|
||||||
|
ilpd-indian-liver 0.6673 0.6450±0.0445
|
||||||
|
ionosphere 0.83497 0.8696±0.0439
|
||||||
|
iris 0.966667 0.9460±0.0339
|
||||||
|
led-display 0.704 0.7059±0.0168
|
||||||
|
libras 0.658333 0.6803±0.1177
|
||||||
|
low-res-spect 0.870111 0.8563±0.0310
|
||||||
|
lymphography 0.736092 0.8151±0.0600
|
||||||
|
mammographic 0.764848 0.7715±0.0378
|
||||||
|
molec-biol-promoter 0.772294 0.7339±0.1046
|
||||||
|
musk-1 0.722566 0.7231±0.0721
|
||||||
|
oocytes_merluccius_nucleus_4d 0.718125 0.7013±0.0376
|
||||||
|
oocytes_merluccius_states_2f 0.888431 0.8839±0.0314
|
||||||
|
oocytes_trisopterus_nucleus_2f 0.674263 0.6516±0.0534
|
||||||
|
oocytes_trisopterus_states_5b 0.798199 0.7711±0.1160
|
||||||
|
parkinsons 0.835897 0.8262±0.0892
|
||||||
|
pima 0.692793 0.7007±0.0341
|
||||||
|
pittsburg-bridges-MATERIAL 0.785714 0.7817±0.1466
|
||||||
|
pittsburg-bridges-REL-L 0.590476 0.5570±0.1307
|
||||||
|
pittsburg-bridges-SPAN 0.545029 0.5491±0.0985
|
||||||
|
pittsburg-bridges-T-OR-D 0.742857 0.7930±0.1684
|
||||||
|
planning 0.478228 0.5406±0.0762
|
||||||
|
post-operative 0.588889 0.5744±0.0690
|
||||||
|
seeds 0.928571 0.9071±0.0784
|
||||||
|
statlog-australian-credit 0.55942 0.5654±0.0353
|
||||||
|
statlog-german-credit 0.696 0.6963±0.0243
|
||||||
|
statlog-heart 0.788889 0.7811±0.0462
|
||||||
|
statlog-image 0.955844 0.9472±0.0119
|
||||||
|
statlog-vehicle 0.719875 0.7084±0.0310
|
||||||
|
synthetic-control 0.97 0.9825±0.0149
|
||||||
|
tic-tac-toe 0.866416 0.8356±0.1050
|
||||||
|
vertebral-column-2clases 0.790323 0.7865±0.1203
|
||||||
|
wine 0.972063 0.9659±0.0210
|
||||||
|
zoo 0.94 0.9402±0.0448
|
Reference in New Issue
Block a user