# Mirror of https://github.com/Doctorado-ML/Stree_datasets.git
# (synced 2025-08-15 07:26:02 +00:00; 131 lines, 5.6 KiB, Python)
import time
import warnings

import numpy as np
from sklearn.model_selection import KFold, cross_validate
from stree import Stree

from experimentation.Sets import Datasets
from experimentation.Utils import TextColor

# Kernels under test; the order matches the numeric columns of ``results``.
kernels = ["linear", "sigmoid", "poly", "rbf"]

# Reference accuracies the script checks its own runs against.
# Each row is (dataset name, linear, sigmoid, poly, rbf); the main loop
# compares each value with the mean cross-validation score rounded to six
# decimals.
results = [
    ("balance-scale", 0.91072, 0.8456, 0.55824, 0.76864),
    ("balloons", 0.653333, 0.696667, 0.595, 0.581667),
    ("breast-cancer-wisc-diag", 0.968898, 0.94798, 0.920394, 0.972762),
    ("breast-cancer-wisc-prog", 0.802051, 0.762679, 0.755103, 0.773295),
    ("breast-cancer-wisc", 0.966661, 0.964666, 0.955221, 0.967809),
    ("breast-cancer", 0.734211, 0.71043, 0.736624, 0.731754),
    ("cardiotocography-10clases", 0.791487, 0.403616, 0.373194, 0.420877),
    ("cardiotocography-3clases", 0.900613, 0.798167, 0.84365, 0.81289),
    ("conn-bench-sonar-mines-rocks", 0.755528, 0.752439, 0.781243, 0.832091),
    ("cylinder-bands", 0.715049, 0.690042, 0.697238, 0.747613),
    ("dermatology", 0.966087, 0.531725, 0.576912, 0.381129),
    ("echocardiogram", 0.808832, 0.844501, 0.79245, 0.825299),
    ("fertility", 0.866, 0.88, 0.852, 0.88),
    ("haberman-survival", 0.735637, 0.733718, 0.728477, 0.731713),
    ("heart-hungarian", 0.817674, 0.807832, 0.811198, 0.823448),
    ("hepatitis", 0.796129, 0.781935, 0.806452, 0.825161),
    ("ilpd-indian-liver", 0.723498, 0.70739, 0.707907, 0.709788),
    ("ionosphere", 0.866056, 0.85528, 0.77293, 0.940744),
    ("iris", 0.965333, 0.832667, 0.952667, 0.952667),
    ("led-display", 0.703, 0.4156, 0.2601, 0.3011),
    ("libras", 0.747778, 0.165278, 0.108333, 0.177222),
    ("low-res-spect", 0.853102, 0.522254, 0.529979, 0.527917),
    ("lymphography", 0.773793, 0.547057, 0.547057, 0.547057),
    ("mammographic", 0.81915, 0.796662, 0.817173, 0.826747),
    ("molec-biol-promoter", 0.764416, 0.781039, 0.696017, 0.827143),
    ("musk-1", 0.843463, 0.732531, 0.900004, 0.895811),
    ("oocytes_merluccius_nucleus_4d", 0.810657, 0.702055, 0.714768, 0.770059),
    ("oocytes_merluccius_states_2f", 0.915365, 0.74883, 0.710081, 0.718894),
    ("oocytes_trisopterus_nucleus_2f", 0.800986, 0.674258, 0.690322, 0.799127),
    ("oocytes_trisopterus_states_5b", 0.916655, 0.602868, 0.637082, 0.588284),
    ("parkinsons", 0.882051, 0.839487, 0.864615, 0.874359),
    ("pima", 0.766651, 0.745009, 0.741266, 0.756369),
    ("pittsburg-bridges-MATERIAL", 0.791255, 0.854372, 0.830693, 0.846797),
    ("pittsburg-bridges-REL-L", 0.632238, 0.472, 0.509429, 0.484476),
    ("pittsburg-bridges-SPAN", 0.630234, 0.578129, 0.588596, 0.593275),
    ("pittsburg-bridges-T-OR-D", 0.861619, 0.85881, 0.867762, 0.86481),
    ("planning", 0.70455, 0.712207, 0.690751, 0.713258),
    ("post-operative", 0.675556, 0.711111, 0.711111, 0.711111),
    ("seeds", 0.949048, 0.890952, 0.9, 0.933333),
    ("statlog-australian-credit", 0.667246, 0.668261, 0.664638, 0.672319),
    ("statlog-german-credit", 0.7625, 0.7363, 0.7344, 0.758),
    ("statlog-heart", 0.822963, 0.838148, 0.830741, 0.827037),
    ("statlog-image", 0.952641, 0.383896, 0.420346, 0.379134),
    ("statlog-vehicle", 0.793028, 0.445035, 0.415464, 0.57556),
    ("synthetic-control", 0.938833, 0.511333, 0.439667, 0.5675),
    ("tic-tac-toe", 0.983296, 0.752095, 0.984028, 0.986324),
    ("vertebral-column-2clases", 0.852903, 0.812581, 0.730323, 0.845161),
    ("wine", 0.97581, 0.571635, 0.562397, 0.926175),
    ("zoo", 0.947619, 0.590667, 0.664095, 0.523524),
]


def process_dataset(dataset, kernel):
    """Cross-validate a default-parameter Stree on one dataset and kernel.

    Runs one shuffled 5-fold cross validation per seed in the module-level
    ``random_seeds``. Each seed is used both for the fold split and as the
    classifier's ``random_state``.

    Relies on the module-level globals ``dt`` (dataset loader) and
    ``random_seeds``; both must be defined before the first call.

    :param dataset: dataset name, passed to ``dt.load``
    :param kernel: kernel name, passed to ``Stree(kernel=...)``
    :return: tuple ``(scores, times)``; each is a list with one entry per
        seed, holding the ``test_score`` / ``fit_time`` values reported by
        ``cross_validate`` for that seed
    """
    X, y = dt.load(dataset)
    scores = []
    times = []
    for random_state in random_seeds:
        kfold = KFold(shuffle=True, random_state=random_state, n_splits=5)
        clf = Stree(kernel=kernel, random_state=random_state)
        # Only scores and fit times are read below, so the fitted estimators
        # are not requested (keeping them would just hold them in memory).
        res = cross_validate(clf, X, y, cv=kfold)
        scores.append(res["test_score"])
        times.append(res["fit_time"])
    return scores, times


# Seeds used by process_dataset: one 5-fold cross validation is run per seed.
random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
# Dataset loader shared with process_dataset.
dt = Datasets(normalize=False, standardize=False, set_of_files="tanveer")
# Wall-clock start, used for the elapsed-time line printed at the very end.
start = time.time()
print(
    TextColor.MAGENTA
    + "Testing all datasets accuracies with default hyperparameters and all "
    "kernels"
)
print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}\n")
header_cols = [
    "Dataset",
    "Linear",
    "Sigmoid",
    "Poly",
    "RBF",
]
# One width per header column (the dataset name plus the four kernels).
header_lengths = [30, 7, 7, 7, 7]
line_col = ""
# Counters updated by the verification loop and reported in the summary.
mistakes = correct = 0
check_mark = "\N{heavy check mark}"
cross_mark = "\N{heavy ballot x}"
# Print the header row and build the "=====" rule that goes under it.
for field, underscore in zip(header_cols, header_lengths):
    print(f"{field:{underscore}s} ", end="", flush=True)
    line_col += "=" * underscore + " "
print(f"\n{line_col}")
# Row colour; the verification loop alternates it between LINE1 and LINE2.
color = ""
warnings.filterwarnings("ignore", message="Solver terminated early")

# Re-run every dataset/kernel combination and compare the mean accuracy,
# rounded to six decimals, with the stored reference value.
for name, linear, sigmoid, poly, rbf in results:
    # Alternate the row colour so consecutive datasets are easy to tell apart.
    color = TextColor.LINE1 if color == TextColor.LINE2 else TextColor.LINE2
    results_dataset = dict(linear=linear, sigmoid=sigmoid, poly=poly, rbf=rbf)
    print(
        color + f"{name:30s} ",
        end="",
    )
    for kernel in kernels:
        # process_dataset loads the data itself; fit times are not used here.
        scores, _ = process_dataset(name, kernel)
        if round(np.mean(scores), 6) != results_dataset[kernel]:
            mistakes += 1
            item = cross_mark
            item_color = TextColor.FAIL
        else:
            correct += 1
            item = check_mark
            item_color = TextColor.SUCCESS
        item = item.center(7)
        # Flush so each mark shows up as soon as its kernel finishes.
        print(item_color + f"{item} ", end="", flush=True)
    # Finish the row for this dataset.
    print("")
print(TextColor.SUCCESS + f"Correct results : {correct:3d}")
print(TextColor.FAIL + f"Mistaken results: {mistakes:3d}")
stop = time.time()
# Break the elapsed seconds down into hours, minutes and seconds.
hours, rem = divmod(stop - start, 3600)
minutes, seconds = divmod(rem, 60)
print(color + f"Time: {int(hours):2d}h {int(minutes):2d}m {int(seconds):2d}s")