mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-15 07:26:02 +00:00
98 lines
3.2 KiB
Python
Executable File
98 lines
3.2 KiB
Python
Executable File
import random
|
|
import time
|
|
import numpy as np
|
|
from sklearn.model_selection import KFold, cross_validate
|
|
from experimentation.Sets import Datasets
|
|
from stree import Stree
|
|
from experimentation.Utils import TextColor
|
|
from experimentation.Database import MySQL
|
|
|
|
|
|
def normalize(data: np.ndarray) -> np.ndarray:
    """Min-max scale ``data`` to the [0, 1] range using its global extremes.

    The minimum and maximum are taken over every element of ``data`` (not
    per column), so a 2-D input is rescaled as a whole.

    Parameters
    ----------
    data : np.ndarray
        Numeric array to rescale. It is not modified.

    Returns
    -------
    np.ndarray
        New array with the same shape holding ``(data - min) / (max - min)``.
        When every element is equal the range is zero; an array of zeros is
        returned in that case instead of dividing by zero.
    """
    min_data = data.min()
    span = data.max() - min_data
    if span == 0:
        # Constant input: (data - min) / 0 would evaluate 0/0 and yield NaN.
        return np.zeros_like(data, dtype=float)
    return (data - min_data) / span
|
|
|
|
|
|
def normalize_rows(data: np.ndarray) -> np.ndarray:
    """Min-max scale every column of ``data`` independently to [0, 1].

    Despite the name, the scaling is applied column by column (feature-wise):
    each column is shifted by its own minimum and divided by its own range.

    Parameters
    ----------
    data : np.ndarray
        2-D numeric array. It is not modified.

    Returns
    -------
    np.ndarray
        New floating-point array with the same shape as ``data``. Columns
        whose values are all equal (zero range) are mapped to zeros instead
        of dividing by zero.
    """
    values = np.asarray(data)
    if not np.issubdtype(values.dtype, np.floating):
        # Integer input must be promoted first: writing scaled values back
        # into an integer array would truncate them to 0 or 1.
        values = values.astype(float)
    col_min = values.min(axis=0)
    col_range = values.max(axis=0) - col_min
    # Replace zero ranges by 1 so constant columns become 0 rather than NaN.
    safe_range = np.where(col_range == 0, 1.0, col_range)
    return (values - col_min) / safe_range
|
|
|
|
|
|
def header():
    """Print the report banner, the column titles and the ruler line."""
    # (title, field width) for every column of the report, left to right.
    columns = (
        ("Dataset", 30),
        ("No Norm.", 9),
        ("Normaliz.", 9),
        ("Col.Norm.", 9),
        ("Context B", 9),
        ("Best score in crossval", 25),
    )
    titles = " ".join(f"{title:{width}s}" for title, width in columns)
    ruler = " ".join("=" * width for _, width in columns)
    print("Processing Datasets with stree default.\n")
    print(titles)
    print(ruler)
|
|
|
|
|
|
def process_dataset(X, y, normalize):
    """Return the mean cross-validation score of Stree over all seeds.

    For each seed in the module-level ``random_seeds`` list, the Python and
    NumPy global random generators are seeded, a new ``Stree`` classifier is
    created and it is evaluated with a shuffled 5-fold ``KFold`` through
    ``cross_validate``.

    Parameters
    ----------
    X :
        Feature data handed to ``cross_validate``.
    y :
        Target values handed to ``cross_validate``.
    normalize :
        Value forwarded as ``Stree(normalize=...)``. Note that inside this
        function the name shadows the module-level ``normalize`` function.

    Returns
    -------
    float
        Mean of every fold's ``test_score`` across all seeds.
    """
    scores = []
    for random_state in random_seeds:
        random.seed(random_state)
        clf_test = Stree(random_state=random_state, normalize=normalize)
        np.random.seed(random_state)
        kfold = KFold(shuffle=True, random_state=random_state, n_splits=5)
        # Only the test scores are used, so the fitted estimators are not
        # requested from cross_validate (avoids keeping them in memory).
        res = cross_validate(clf_test, X, y, cv=kfold)
        scores.append(res["test_score"])
    return np.mean(scores)
|
|
|
|
|
|
# Script body: compare Stree accuracy on every dataset under four input
# treatments (raw, globally min-max scaled, per-column min-max scaled, and
# Stree's own normalize=True) against the best stored "crossval" result.
start = time.time()
# Models passed to dbh.find_best when looking up the stored best result.
models_tree = [
    "stree",
    "stree_default",
    "wodt",
    "j48svm",
    "oc1",
    "cart",
    "baseRaF",
]
dbh = MySQL()
database = dbh.get_connection()
# Kept at module level: process_dataset reads this list as a global.
random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
try:
    dt = Datasets(normalize=False, standardize=False, set_of_files="tanveer")
    header()
    # Number of datasets won by each of the four treatments.
    total = [0, 0, 0, 0]
    line = TextColor.LINE1
    for data in dt:
        name = data[0]
        X, y = dt.load(name)
        record = dbh.find_best(name, models_tree, "crossval")
        X2 = normalize(X)
        X3 = normalize_rows(X)
        ac1 = process_dataset(X, y, False)
        ac2 = process_dataset(X2, y, False)
        ac3 = process_dataset(X3, y, False)
        ac4 = process_dataset(X, y, True)
        accuracies = [ac1, ac2, ac3, ac4]
        max_value = round(max(accuracies), 6)
        # Alternate the row colour between consecutive datasets.
        line = TextColor.LINE2 if line == TextColor.LINE1 else TextColor.LINE1
        print(line + f"{name:30s} ", end="", flush=True)
        total[np.argmax(accuracies)] += 1
        # Compare at the displayed precision: max_value is rounded, so the
        # raw accuracies must be rounded too or the winner is never marked.
        color1, color2, color3, color4 = (
            TextColor.SUCCESS if round(accuracy, 6) == max_value else line
            for accuracy in accuracies
        )
        print(color1 + f"{ac1:9.6f} " + TextColor.ENDC, end="", flush=True)
        print(color2 + f"{ac2:9.6f} " + TextColor.ENDC, end="", flush=True)
        print(color3 + f"{ac3:9.6f} " + TextColor.ENDC, end="", flush=True)
        print(color4 + f"{ac4:9.6f}" + TextColor.ENDC, end="", flush=True)
        # NOTE(review): presumably record[5] is the stored accuracy and
        # record[3] identifies the model; verify against MySQL.find_best.
        best_accuracy = round(record[5], 6)
        best_color = TextColor.UNDERLINE if best_accuracy >= max_value else ""
        print(
            line
            + best_color
            + f"{best_accuracy:9.6f} {record[3]}"
            + TextColor.ENDC
        )
    print(
        f"{'Total':30s} {total[0]:9d} {total[1]:9d} {total[2]:9d} {total[3]:9d}"
    )
    stop = time.time()
    hours, rem = divmod(stop - start, 3600)
    minutes, seconds = divmod(rem, 60)
    print(f"Time: {int(hours):2d}h {int(minutes):2d}m {int(seconds):2d}s")
finally:
    # Always release the database handle, even if a dataset fails.
    dbh.close()
|