mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-15 07:26:02 +00:00
Add 10 random seeds run in crossval
Add testwodt comparison
This commit is contained in:
@@ -3,6 +3,7 @@ import sqlite3
|
||||
from datetime import datetime
|
||||
from abc import ABC
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import mysql.connector
|
||||
from ast import literal_eval as make_tuple
|
||||
from sshtunnel import SSHTunnelForwarder
|
||||
@@ -322,8 +323,8 @@ class Outcomes(BD):
|
||||
outcomes = ["fit_time", "score_time", "train_score", "test_score"]
|
||||
data = ""
|
||||
for index in outcomes:
|
||||
data += ", " + str(results[index].mean()) + ", "
|
||||
data += str(results[index].std())
|
||||
data += ", " + str(np.mean(results[index])) + ", "
|
||||
data += str(np.std(results[index]))
|
||||
command = (
|
||||
f"insert or replace into {self._table} ('dataset', 'parameters', "
|
||||
"'date', 'normalize', 'standardize'"
|
||||
@@ -341,12 +342,12 @@ class Outcomes(BD):
|
||||
normalize,
|
||||
standardize,
|
||||
[
|
||||
float(results["test_score"].mean()),
|
||||
float(results["test_score"].std()),
|
||||
float(np.mean(results["test_score"])),
|
||||
float(np.std(results["test_score"])),
|
||||
],
|
||||
[
|
||||
float(results["fit_time"].mean()),
|
||||
float(results["fit_time"].std()),
|
||||
float(np.mean(results["fit_time"])),
|
||||
float(np.std(results["fit_time"])),
|
||||
],
|
||||
parameters,
|
||||
)
|
||||
|
@@ -75,16 +75,25 @@ class Experiment:
|
||||
model.set_params(**parameters)
|
||||
self._num_warnings = 0
|
||||
warnings.warn = self._warn
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
# Also affect subprocesses
|
||||
os.environ["PYTHONWARNINGS"] = "ignore"
|
||||
results = cross_validate(
|
||||
model, X, y, return_train_score=True, n_jobs=self._threads
|
||||
)
|
||||
# Execute with 10 different seeds to ease the random effect
|
||||
total = {}
|
||||
outcomes = ["fit_time", "score_time", "train_score", "test_score"]
|
||||
for item in outcomes:
|
||||
total[item] = []
|
||||
for random_state in [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]:
|
||||
model.set_params(**{"random_state": random_state})
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
# Also affect subprocesses
|
||||
os.environ["PYTHONWARNINGS"] = "ignore"
|
||||
results = cross_validate(
|
||||
model, X, y, return_train_score=True, n_jobs=self._threads
|
||||
)
|
||||
for item in outcomes:
|
||||
total[item].append(results[item])
|
||||
outcomes = Outcomes(host=self._host, model=self._model_name)
|
||||
parameters = json.dumps(parameters, sort_keys=True)
|
||||
outcomes.store(dataset, normalize, standardize, parameters, results)
|
||||
outcomes.store(dataset, normalize, standardize, parameters, total)
|
||||
if self._num_warnings > 0:
|
||||
print(f"{self._num_warnings} warnings have happend")
|
||||
|
||||
|
@@ -117,7 +117,7 @@ if dataset == "all":
|
||||
)
|
||||
print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}")
|
||||
for dataset in dt:
|
||||
print(f"- {dataset[0]:20s} ", end="")
|
||||
print(f"- {dataset[0]:30s} ", end="")
|
||||
scores = process_dataset(dataset[0], verbose=False)
|
||||
print(f"{np.mean(scores):6.4f}±{np.std(scores):6.4f}")
|
||||
else:
|
||||
|
51
testwodt_output.txt
Normal file
51
testwodt_output.txt
Normal file
@@ -0,0 +1,51 @@
|
||||
* Process all datasets set: tanveer [-1, 1]: True norm: False std: False
|
||||
5 Fold Cross Validation with 10 random seeds [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
|
||||
- balance-scale 0.8462±0.0853
|
||||
- balloons 0.6133±0.2012
|
||||
- breast-cancer-wisc-diag 0.9399±0.0194
|
||||
- breast-cancer-wisc-prog 0.6898±0.0833
|
||||
- breast-cancer-wisc 0.9414±0.0307
|
||||
- breast-cancer 0.6336±0.0707
|
||||
- cardiotocography-10clases 0.6022±0.0335
|
||||
- cardiotocography-3clases 0.8217±0.0565
|
||||
- conn-bench-sonar-mines-rocks 0.6235±0.1017
|
||||
- cylinder-bands 0.5593±0.0437
|
||||
- dermatology 0.9528±0.0252
|
||||
- echocardiogram 0.7475±0.0872
|
||||
- fertility 0.7550±0.1404
|
||||
- haberman-survival 0.6699±0.0626
|
||||
- heart-hungarian 0.7654±0.0531
|
||||
- hepatitis 0.7806±0.0793
|
||||
- ilpd-indian-liver 0.6450±0.0445
|
||||
- ionosphere 0.8696±0.0439
|
||||
- iris 0.9460±0.0339
|
||||
- led-display 0.7059±0.0168
|
||||
- libras 0.6803±0.1177
|
||||
- low-res-spect 0.8563±0.0310
|
||||
- lymphography 0.8151±0.0600
|
||||
- mammographic 0.7715±0.0378
|
||||
- molec-biol-promoter 0.7339±0.1046
|
||||
- musk-1 0.7231±0.0721
|
||||
- oocytes_merluccius_nucleus_4d 0.7013±0.0376
|
||||
- oocytes_merluccius_states_2f 0.8839±0.0314
|
||||
- oocytes_trisopterus_nucleus_2f 0.6516±0.0534
|
||||
- oocytes_trisopterus_states_5b 0.7711±0.1160
|
||||
- parkinsons 0.8262±0.0892
|
||||
- pima 0.7007±0.0341
|
||||
- pittsburg-bridges-MATERIAL 0.7817±0.1466
|
||||
- pittsburg-bridges-REL-L 0.5570±0.1307
|
||||
- pittsburg-bridges-SPAN 0.5491±0.0985
|
||||
- pittsburg-bridges-T-OR-D 0.7930±0.1684
|
||||
- planning 0.5406±0.0762
|
||||
- post-operative 0.5744±0.0690
|
||||
- seeds 0.9071±0.0784
|
||||
- statlog-australian-credit 0.5654±0.0353
|
||||
- statlog-german-credit 0.6963±0.0243
|
||||
- statlog-heart 0.7811±0.0462
|
||||
- statlog-image 0.9472±0.0119
|
||||
- statlog-vehicle 0.7084±0.0310
|
||||
- synthetic-control 0.9825±0.0149
|
||||
- tic-tac-toe 0.8356±0.1050
|
||||
- vertebral-column-2clases 0.7865±0.1203
|
||||
- wine 0.9659±0.0210
|
||||
- zoo 0.9402±0.0448
|
51
wodt_comparecomputedpaper.txt
Normal file
51
wodt_comparecomputedpaper.txt
Normal file
@@ -0,0 +1,51 @@
|
||||
In DB Computed as in paper
|
||||
========= ====================
|
||||
balance-scale 0.8528 0.8462±0.0853
|
||||
balloons 0.633333 0.6133±0.2012
|
||||
breast-cancer-wisc-diag 0.963065 0.9399±0.0194
|
||||
breast-cancer-wisc-prog 0.666282 0.6898±0.0833
|
||||
breast-cancer-wisc 0.941367 0.9414±0.0307
|
||||
breast-cancer 0.573382 0.6336±0.0707
|
||||
cardiotocography-10clases 0.627007 0.6022±0.0335
|
||||
cardiotocography-3clases 0.800097 0.8217±0.0565
|
||||
conn-bench-sonar-mines-rocks 0.63043 0.6235±0.1017
|
||||
cylinder-bands 0.541043 0.5593±0.0437
|
||||
dermatology 0.950907 0.9528±0.0252
|
||||
echocardiogram 0.640456 0.7475±0.0872
|
||||
fertility 0.69 0.7550±0.1404
|
||||
haberman-survival 0.659757 0.6699±0.0626
|
||||
heart-hungarian 0.758212 0.7654±0.0531
|
||||
hepatitis 0.76129 0.7806±0.0793
|
||||
ilpd-indian-liver 0.6673 0.6450±0.0445
|
||||
ionosphere 0.83497 0.8696±0.0439
|
||||
iris 0.966667 0.9460±0.0339
|
||||
led-display 0.704 0.7059±0.0168
|
||||
libras 0.658333 0.6803±0.1177
|
||||
low-res-spect 0.870111 0.8563±0.0310
|
||||
lymphography 0.736092 0.8151±0.0600
|
||||
mammographic 0.764848 0.7715±0.0378
|
||||
molec-biol-promoter 0.772294 0.7339±0.1046
|
||||
musk-1 0.722566 0.7231±0.0721
|
||||
oocytes_merluccius_nucleus_4d 0.718125 0.7013±0.0376
|
||||
oocytes_merluccius_states_2f 0.888431 0.8839±0.0314
|
||||
oocytes_trisopterus_nucleus_2f 0.674263 0.6516±0.0534
|
||||
oocytes_trisopterus_states_5b 0.798199 0.7711±0.1160
|
||||
parkinsons 0.835897 0.8262±0.0892
|
||||
pima 0.692793 0.7007±0.0341
|
||||
pittsburg-bridges-MATERIAL 0.785714 0.7817±0.1466
|
||||
pittsburg-bridges-REL-L 0.590476 0.5570±0.1307
|
||||
pittsburg-bridges-SPAN 0.545029 0.5491±0.0985
|
||||
pittsburg-bridges-T-OR-D 0.742857 0.7930±0.1684
|
||||
planning 0.478228 0.5406±0.0762
|
||||
post-operative 0.588889 0.5744±0.0690
|
||||
seeds 0.928571 0.9071±0.0784
|
||||
statlog-australian-credit 0.55942 0.5654±0.0353
|
||||
statlog-german-credit 0.696 0.6963±0.0243
|
||||
statlog-heart 0.788889 0.7811±0.0462
|
||||
statlog-image 0.955844 0.9472±0.0119
|
||||
statlog-vehicle 0.719875 0.7084±0.0310
|
||||
synthetic-control 0.97 0.9825±0.0149
|
||||
tic-tac-toe 0.866416 0.8356±0.1050
|
||||
vertebral-column-2clases 0.790323 0.7865±0.1203
|
||||
wine 0.972063 0.9659±0.0210
|
||||
zoo 0.94 0.9402±0.0448
|
Reference in New Issue
Block a user