# File: stree_datasets/big_test.py (131 lines, 5.6 KiB, Python)

import time
import warnings
import numpy as np
from stree import Stree
from sklearn.model_selection import KFold, cross_validate
from experimentation.Sets import Datasets
from experimentation.Utils import TextColor
# Kernels tried on every dataset. The order matches the accuracy columns
# of each ``results`` row below.
kernels = ["linear", "sigmoid", "poly", "rbf"]

# Recorded mean cross-validation accuracies, one row per dataset:
#     (dataset name, linear, sigmoid, poly, rbf)
# The script recomputes each mean, rounds it to 6 decimals and checks it
# against the matching value here.
results = [
    ("balance-scale", 0.91072, 0.8456, 0.55824, 0.76864),
    ("balloons", 0.653333, 0.696667, 0.595, 0.581667),
    ("breast-cancer-wisc-diag", 0.968898, 0.94798, 0.920394, 0.972762),
    ("breast-cancer-wisc-prog", 0.802051, 0.762679, 0.755103, 0.773295),
    ("breast-cancer-wisc", 0.966661, 0.964666, 0.955221, 0.967809),
    ("breast-cancer", 0.734211, 0.71043, 0.736624, 0.731754),
    ("cardiotocography-10clases", 0.791487, 0.403616, 0.373194, 0.420877),
    ("cardiotocography-3clases", 0.900613, 0.798167, 0.84365, 0.81289),
    ("conn-bench-sonar-mines-rocks", 0.755528, 0.752439, 0.781243, 0.832091),
    ("cylinder-bands", 0.715049, 0.690042, 0.697238, 0.747613),
    ("dermatology", 0.966087, 0.531725, 0.576912, 0.381129),
    ("echocardiogram", 0.808832, 0.844501, 0.79245, 0.825299),
    ("fertility", 0.866, 0.88, 0.852, 0.88),
    ("haberman-survival", 0.735637, 0.733718, 0.728477, 0.731713),
    ("heart-hungarian", 0.817674, 0.807832, 0.811198, 0.823448),
    ("hepatitis", 0.796129, 0.781935, 0.806452, 0.825161),
    ("ilpd-indian-liver", 0.723498, 0.70739, 0.707907, 0.709788),
    ("ionosphere", 0.866056, 0.85528, 0.77293, 0.940744),
    ("iris", 0.965333, 0.832667, 0.952667, 0.952667),
    ("led-display", 0.703, 0.4156, 0.2601, 0.3011),
    ("libras", 0.747778, 0.165278, 0.108333, 0.177222),
    ("low-res-spect", 0.853102, 0.522254, 0.529979, 0.527917),
    ("lymphography", 0.773793, 0.547057, 0.547057, 0.547057),
    ("mammographic", 0.81915, 0.796662, 0.817173, 0.826747),
    ("molec-biol-promoter", 0.764416, 0.781039, 0.696017, 0.827143),
    ("musk-1", 0.843463, 0.732531, 0.900004, 0.895811),
    ("oocytes_merluccius_nucleus_4d", 0.810657, 0.702055, 0.714768, 0.770059),
    ("oocytes_merluccius_states_2f", 0.915365, 0.74883, 0.710081, 0.718894),
    ("oocytes_trisopterus_nucleus_2f", 0.800986, 0.674258, 0.690322, 0.799127),
    ("oocytes_trisopterus_states_5b", 0.916655, 0.602868, 0.637082, 0.588284),
    ("parkinsons", 0.882051, 0.839487, 0.864615, 0.874359),
    ("pima", 0.766651, 0.745009, 0.741266, 0.756369),
    ("pittsburg-bridges-MATERIAL", 0.791255, 0.854372, 0.830693, 0.846797),
    ("pittsburg-bridges-REL-L", 0.632238, 0.472, 0.509429, 0.484476),
    ("pittsburg-bridges-SPAN", 0.630234, 0.578129, 0.588596, 0.593275),
    ("pittsburg-bridges-T-OR-D", 0.861619, 0.85881, 0.867762, 0.86481),
    ("planning", 0.70455, 0.712207, 0.690751, 0.713258),
    ("post-operative", 0.675556, 0.711111, 0.711111, 0.711111),
    ("seeds", 0.949048, 0.890952, 0.9, 0.933333),
    ("statlog-australian-credit", 0.667246, 0.668261, 0.664638, 0.672319),
    ("statlog-german-credit", 0.7625, 0.7363, 0.7344, 0.758),
    ("statlog-heart", 0.822963, 0.838148, 0.830741, 0.827037),
    ("statlog-image", 0.952641, 0.383896, 0.420346, 0.379134),
    ("statlog-vehicle", 0.793028, 0.445035, 0.415464, 0.57556),
    ("synthetic-control", 0.938833, 0.511333, 0.439667, 0.5675),
    ("tic-tac-toe", 0.983296, 0.752095, 0.984028, 0.986324),
    ("vertebral-column-2clases", 0.852903, 0.812581, 0.730323, 0.845161),
    ("wine", 0.97581, 0.571635, 0.562397, 0.926175),
    ("zoo", 0.947619, 0.590667, 0.664095, 0.523524),
]
def process_dataset(dataset, kernel, *, seeds=None, n_splits=5):
    """Cross-validate an Stree classifier on one dataset, once per seed.

    Args:
        dataset: Name of the dataset, loaded through the module-level
            ``dt`` collection.
        kernel: Kernel name passed to ``Stree``.
        seeds: Random seeds to iterate over. Each seed drives both the
            fold shuffling and the classifier. Defaults to the
            module-level ``random_seeds``.
        n_splits: Number of cross-validation folds (default 5).

    Returns:
        ``(scores, times)``: two lists with one entry per seed. Each entry
        holds the per-fold ``test_score`` and ``fit_time`` values reported
        by ``cross_validate``, respectively.
    """
    if seeds is None:
        seeds = random_seeds
    X, y = dt.load(dataset)
    scores = []
    times = []
    for random_state in seeds:
        kfold = KFold(
            shuffle=True, random_state=random_state, n_splits=n_splits
        )
        clf = Stree(kernel=kernel, random_state=random_state)
        # The fitted estimators are never inspected, so don't ask
        # cross_validate to keep every fold's model alive in the result.
        res = cross_validate(clf, X, y, cv=kfold)
        scores.append(res["test_score"])
        times.append(res["fit_time"])
    return scores, times
# Seeds used for both the fold shuffling and the classifier in every run.
random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
# Dataset collection with normalization and standardization disabled.
dt = Datasets(normalize=False, standardize=False, set_of_files="tanveer")
start = time.time()
print(
    TextColor.MAGENTA
    + "Testing all datasets accuracies with default hyperparameters and all "
    "kernels"
)
# Derive the seed count from the list so the banner cannot drift from it.
print(
    "5 Fold Cross Validation with "
    f"{len(random_seeds)} random seeds {random_seeds}\n"
)
header_cols = [
    "Dataset",
    "Linear",
    "Sigmoid",
    "Poly",
    "RBF",
]
# Exactly one width per entry of header_cols: 30 characters for the dataset
# name and 7 for each kernel column.
header_lengths = [30, 7, 7, 7, 7]
# Running tallies of results that do / do not match the recorded values.
mistakes = correct = 0
check_mark = "\N{heavy check mark}"
cross_mark = "\N{heavy ballot x}"
header_line = "".join(
    f"{field:{width}s} " for field, width in zip(header_cols, header_lengths)
)
print(header_line, end="", flush=True)
# Underline each column with "=" of the same width.
line_col = "".join("=" * width + " " for width in header_lengths)
print(f"\n{line_col}")
color = ""
# Ignore warnings whose message starts with "Solver terminated early" so
# they do not clutter the result table.
warnings.filterwarnings("ignore", message="Solver terminated early")
for name, linear, sigmoid, poly, rbf in results:
    # Alternate between the two row colors for consecutive datasets.
    color = TextColor.LINE1 if color == TextColor.LINE2 else TextColor.LINE2
    results_dataset = dict(linear=linear, sigmoid=sigmoid, poly=poly, rbf=rbf)
    # Flush now so the dataset name shows up before the slow
    # cross-validation runs start.
    print(color + f"{name:30s} ", end="", flush=True)
    for kernel in kernels:
        scores, _ = process_dataset(name, kernel)
        # Recorded values carry at most 6 decimals; compare at that precision.
        if round(np.mean(scores), 6) != results_dataset[kernel]:
            mistakes += 1
            item = cross_mark
            item_color = TextColor.FAIL
        else:
            correct += 1
            item = check_mark
            item_color = TextColor.SUCCESS
        print(item_color + f"{item.center(7)} ", end="", flush=True)
    print()
# Report how many (dataset, kernel) accuracies matched the recorded values.
for tone, message in (
    (TextColor.SUCCESS, f"Correct results : {correct:3d}"),
    (TextColor.FAIL, f"Mistaken results: {mistakes:3d}"),
):
    print(tone + message)
# Split the wall-clock time elapsed since ``start`` into h/m/s for display.
stop = time.time()
elapsed = stop - start
hours, rem = divmod(elapsed, 3600)
minutes, seconds = divmod(rem, 60)
print(color + f"Time: {int(hours):2d}h {int(minutes):2d}m {int(seconds):2d}s")