# Mirror of https://github.com/Doctorado-ML/Stree_datasets.git
# Synced 2025-08-15 07:26:02 +00:00 (505 lines, 15 KiB, Python).
import argparse
|
|
import random
|
|
import time
|
|
import warnings
|
|
from datetime import datetime
|
|
import json
|
|
import numpy as np
|
|
import xlsxwriter
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
from stree import Stree
|
|
from sklearn.model_selection import KFold, cross_validate
|
|
from experimentation.Sets import Datasets
|
|
from experimentation.Database import MySQL
|
|
from wodt import TreeClassifier
|
|
from experimentation.Utils import TextColor
|
|
from mdlp import MDLP
|
|
|
|
|
|
# Result markers shown next to an accuracy figure (console and spreadsheet).
CHECK_MARK = "\N{heavy check mark}"  # beats a stored reference result
EXCLAMATION_MARK = "\N{heavy exclamation mark symbol}"  # spreadsheet form of the "better than stree" mark
BLACK_STAR = "\N{black star}"  # best accuracy among all stored models
|
|
|
|
|
|
def parse_arguments():
    """Parse the command line and return the run configuration.

    Returns:
        A tuple ``(set_of_files, model, dataset, sql, normalize, parameters,
        excel, discretize)``. ``normalize`` is converted from the ``-n``
        integer to a bool; ``sql`` and ``discretize`` are parsed as booleans
        from strings such as ``1``/``0`` or ``true``/``false``.
    """

    def str2bool(value):
        # argparse's ``type=bool`` would make every non-empty string
        # (including "False" and "0") True, so booleans are parsed explicitly.
        text = value.strip().lower()
        if text in ("1", "true", "t", "yes", "y", "on"):
            return True
        if text in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value (true/false, 1/0), got {value!r}"
        )

    ap = argparse.ArgumentParser()
    ap.add_argument(
        "-S",
        "--set-of-files",
        type=str,
        choices=["aaai", "tanveer"],
        required=False,
        default="tanveer",
        help="set of datasets to use, default tanveer",
    )
    ap.add_argument(
        "-m",
        "--model",
        type=str,
        required=False,
        default="stree_default",
        help="model name, default stree_default",
    )
    ap.add_argument(
        "-d",
        "--dataset",
        type=str,
        required=True,
        help="dataset to process, all for everyone",
    )
    ap.add_argument(
        "-s",
        "--sql",
        default=False,
        type=str2bool,
        required=False,
        help="generate report_score.sql",
    )
    ap.add_argument(
        "-n",
        "--normalize",
        type=int,
        required=True,
        help="1 to normalize the datasets, 0 otherwise",
    )
    ap.add_argument(
        "-x",
        "--excel",
        type=str,
        default="",
        required=False,
        help="generate excel file",
    )
    ap.add_argument(
        "-di",
        "--discretize",
        type=str2bool,
        default=False,
        required=False,
        help="Discretize datasets",
    )
    ap.add_argument(
        "-p",
        "--parameters",
        type=str,
        required=False,
        default="{}",
        help="hyperparameters as a JSON object, default {}",
    )
    args = ap.parse_args()
    return (
        args.set_of_files,
        args.model,
        args.dataset,
        args.sql,
        bool(args.normalize),
        args.parameters,
        args.excel,
        args.discretize,
    )
|
|
|
|
|
|
def get_classifier(model, random_state, hyperparameters):
    """Instantiate the classifier named *model*, seeded with *random_state*.

    Only the Stree variants (``stree``, ``stree_default``) receive
    *hyperparameters*, via ``set_params``; ``wodt`` and ``cart`` are built
    with their defaults.

    Raises:
        ValueError: if *model* is not one of the supported names (previously
            this surfaced as an UnboundLocalError).
    """
    if model in ("stree", "stree_default"):
        clf = Stree(random_state=random_state)
        clf.set_params(**hyperparameters)
        return clf
    if model == "wodt":
        return TreeClassifier(random_state=random_state)
    if model == "cart":
        return DecisionTreeClassifier(random_state=random_state)
    raise ValueError(
        f"Unsupported model {model!r}; expected one of "
        "stree, stree_default, wodt, cart"
    )
|
|
|
|
|
|
def _resolve_hyperparameters(dataset, model, params):
    """Return the hyperparameters dict used to build *model* for *dataset*.

    *params* (a JSON object string) is always parsed first, so invalid JSON
    fails early. For ``model == "stree"`` the result is then replaced by the
    parameters of the best ``gridsearch`` record stored for *dataset* (looked
    up through the module-level ``dbh``). ``random_state`` is always removed
    because every run sets its own seed.

    Raises:
        ValueError: if ``model == "stree"`` and no gridsearch record exists.
    """
    hyperparameters = json.loads(params)
    if model == "stree":
        # Get the optimized parameters
        record = dbh.find_best(dataset, model, "gridsearch")
        if record is None:
            raise ValueError(
                f"No gridsearch result stored for model {model} "
                f"on dataset {dataset}"
            )
        stored = record[8]
        # Stored text may contain escaped quotes (\"); an empty or NULL
        # column means no explicit parameters.
        hyperparameters = json.loads(
            stored.replace('\\"', '"') if stored else "{}"
        )
    hyperparameters.pop("random_state", None)
    return hyperparameters


def _estimator_complexity(estimator, model):
    """Return ``(nodes, leaves, depth)`` for one fitted fold estimator.

    CART exposes these through ``tree_`` and ``get_n_leaves()``; the other
    models are queried with ``nodes_leaves()`` and an optional ``depth_``
    attribute (0 when the estimator has none).
    """
    if model == "cart":
        return (
            estimator.tree_.node_count,
            estimator.get_n_leaves(),
            estimator.tree_.max_depth,
        )
    n_nodes, n_leaves = estimator.nodes_leaves()
    return n_nodes, n_leaves, getattr(estimator, "depth_", 0)


def process_dataset(dataset, verbose, model, params):
    """Cross-validate *model* on *dataset* once per seed in ``random_seeds``.

    Relies on the module-level ``dt`` (dataset loader), ``discretize``,
    ``set_of_files``, ``random_seeds`` and ``dbh``. Each seed runs a shuffled
    5-fold ``KFold`` with warnings silenced.

    Args:
        dataset: name of the dataset to load with ``dt.load``.
        verbose: when true, print dataset details and per-seed accuracy.
        model: classifier name understood by ``get_classifier``.
        params: JSON object string with hyperparameters (replaced by the
            stored gridsearch result when ``model == "stree"``).

    Returns:
        ``(scores, times, hyperparameters_json, nodes, leaves, depths)``.
        ``scores`` and ``times`` hold one array of fold results per seed.
        ``nodes``, ``leaves`` and ``depths`` hold one value per fitted fold
        estimator.

    Raises:
        ValueError: for ``model == "stree"`` when no gridsearch record exists.
    """
    X, y = dt.load(dataset)
    if discretize:
        X = MDLP(random_state=1).fit_transform(X, y)
    if verbose:
        print(
            f"* Processing dataset [{dataset}] from Set: {set_of_files} with "
            f"{model}"
        )
        print(f"X.shape: {X.shape}")
        print(f"{X[:4]}")
        print(f"Random seeds: {random_seeds}")
    hyperparameters = _resolve_hyperparameters(dataset, model, params)
    scores, times = [], []
    nodes, leaves, depths = [], [], []
    for random_state in random_seeds:
        # Also seed the global RNGs, presumably for code that draws from them
        # instead of taking a random_state argument.
        random.seed(random_state)
        np.random.seed(random_state)
        kfold = KFold(shuffle=True, random_state=random_state, n_splits=5)
        clf = get_classifier(model, random_state, hyperparameters)
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            res = cross_validate(clf, X, y, cv=kfold, return_estimator=True)
        scores.append(res["test_score"])
        times.append(res["fit_time"])
        for estimator in res["estimator"]:
            n_nodes, n_leaves, depth = _estimator_complexity(estimator, model)
            nodes.append(n_nodes)
            leaves.append(n_leaves)
            depths.append(depth)
        if verbose:
            print(
                f"Random seed: {random_state:5d} Accuracy: "
                f"{res['test_score'].mean():6.4f}±"
                f"{res['test_score'].std():6.4f} "
                f"{res['fit_time'].mean():5.3f}s"
            )
    return scores, times, json.dumps(hyperparameters), nodes, leaves, depths
|
|
|
|
|
|
def store_string(
    dataset, model, accuracy, time_spent, hyperparameters, complexity
):
    """Build a ``replace into results`` SQL statement for one run.

    Args:
        dataset: dataset name.
        model: classifier name stored in the ``classifier`` column.
        accuracy: per-fold accuracy scores; their mean and std are stored.
        time_spent: per-fold fit times; their mean and std are stored.
        hyperparameters: hyperparameters as a JSON string.
        complexity: mapping with ``nodes``, ``leaves`` and ``depth`` keys.
            Values are read by key, so the mapping's order does not matter.

    Returns:
        The SQL statement as a string. Every value is emitted as a
        single-quoted literal. Backslashes and single quotes are escaped,
        as MySQL's default sql_mode expects, so a quote inside the
        hyperparameters no longer breaks the statement.
    """
    attributes = [
        "date",
        "time",
        "type",
        "accuracy",
        "accuracy_std",
        "dataset",
        "classifier",
        "norm",
        "stand",
        "time_spent",
        "time_spent_std",
        "parameters",
        "nodes",
        "leaves",
        "depth",
    ]

    def quote(value):
        text = str(value).replace("\\", "\\\\").replace("'", "''")
        return f"'{text}'"

    now = datetime.now()
    values = (
        now.strftime("%Y-%m-%d"),
        now.strftime("%H:%M:%S"),
        "crossval",
        np.mean(accuracy),
        np.std(accuracy),
        dataset,
        model,
        1,  # norm
        0,  # stand
        np.mean(time_spent),
        np.std(time_spent),
        hyperparameters,
        complexity["nodes"],
        complexity["leaves"],
        complexity["depth"],
    )
    return (
        f"replace into results ({','.join(attributes)}) "
        f"values({','.join(quote(value) for value in values)});"
    )
|
|
|
|
|
|
def excel_write_line(
    book,
    sheet,
    name,
    samples,
    features,
    classes,
    accuracy,
    times,
    hyperparameters,
    complexity,
    status,
):
    """Append one dataset's results as the next row of the results sheet.

    The current row lives in the function attribute ``row``. The first call
    writes row 4 and every later call moves one row further down. ``status``
    is the console marker; it is translated with ``excel_status`` before
    being written. ``times`` is reduced to its mean.
    """
    if hasattr(excel_write_line, "row"):
        excel_write_line.row += 1
    else:
        excel_write_line.row = 4
    font_size = 14
    decimal_fmt = book.add_format(
        {"num_format": "0.000000", "font_size": font_size}
    )
    integer_fmt = book.add_format({"num_format": "#,###", "font_size": font_size})
    plain_fmt = book.add_format({"font_size": font_size})
    mark, _ = excel_status(status)
    # Cells in column order: one (value, format) pair per column.
    cells = (
        (name, plain_fmt),
        (samples, integer_fmt),
        (features, plain_fmt),
        (classes, plain_fmt),
        (complexity["nodes"], plain_fmt),
        (complexity["leaves"], plain_fmt),
        (complexity["depth"], plain_fmt),
        (accuracy, decimal_fmt),
        (mark, plain_fmt),
        (np.mean(times), decimal_fmt),
        (hyperparameters, plain_fmt),
    )
    row = excel_write_line.row
    for column, (value, cell_format) in enumerate(cells):
        sheet.write(row, column, value, cell_format)
|
|
|
|
|
|
def excel_write_header(book, sheet):
    """Write the title rows and column headers of the results worksheet.

    Reads the module-level run configuration: ``model``, ``set_of_files``,
    ``normalize``, ``standardize``, ``discretize`` and ``random_seeds``.

    Layout:
        row 0: run description;
        row 1: validation protocol, with the seed list in column 5;
        row 3: column titles, whose widths are also set here.
    """
    header = book.add_format()
    header.set_font_size(18)
    subheader = book.add_format()
    subheader.set_font_size(16)
    sheet.write(
        0,
        0,
        f"Process all datasets set with {model}: {set_of_files} "
        f"norm: {normalize} std: {standardize} discretize: {discretize} "
        f"store in: {model}",
        header,
    )
    # Derive the seed count instead of hard-coding it, so the text stays
    # correct if random_seeds changes.
    sheet.write(
        1,
        0,
        f"5 Fold Cross Validation with {len(random_seeds)} random seeds",
        subheader,
    )
    sheet.write(1, 5, f"{random_seeds}", subheader)
    header_cols = [
        ("Dataset", 30),
        ("Samples", 10),
        ("Variables", 7),
        ("Classes", 7),
        ("Nodes", 7),
        ("Leaves", 7),
        ("Depth", 7),
        ("Accuracy", 10),
        ("Stat", 3),
        ("Time", 10),
        ("Parameters", 50),
    ]
    bold = book.add_format({"bold": True, "font_size": 14})
    for column, (title, width) in enumerate(header_cols):
        sheet.write(3, column, title, bold)
        sheet.set_column(column, column, width)
|
|
|
|
|
|
def excel_status(status):
    """Translate a console status marker into a spreadsheet symbol and legend.

    *status* is one of the markers built by ``compute_status``: a red star,
    a green check mark, a plain check mark, or ``" "``. Colour escape codes
    cannot be stored in a cell, so the green check is shown as an
    exclamation mark.

    Returns:
        ``(symbol, description)``, or ``(" ", "")`` when there is nothing
        to flag.
    """
    if status == TextColor.GREEN + CHECK_MARK + TextColor.ENDC:
        return EXCLAMATION_MARK, "Accuracy better than stree optimized"
    if status == TextColor.RED + BLACK_STAR + TextColor.ENDC:
        return BLACK_STAR, "Best accuracy of all models"
    if status != " ":
        return CHECK_MARK, "Accuracy better than original stree_default"
    return " ", ""
|
|
|
|
|
|
def excel_write_totals(book, sheet, totals, start, accuracy_total):
    """Write the summary block below the last data row.

    The block contains one line per status in *totals* (symbol, count,
    legend), then the accumulated accuracy ratio, then the time elapsed
    since *start*.

    Args:
        book: workbook used to create the bold format.
        sheet: worksheet to write into.
        totals: mapping of console status marker -> number of datasets.
        start: ``time.time()`` value taken when processing began.
        accuracy_total: sum of the mean accuracies of all datasets.
    """
    # Use the header row (3) when no data row was ever written, so an empty
    # run does not crash with AttributeError.
    last_row = getattr(excel_write_line, "row", 3)
    bold = book.add_format({"bold": True, "font_size": 16})
    offset = 2
    for key, total in totals.items():
        symbol, legend = excel_status(key)
        sheet.write(last_row + offset, 1, symbol, bold)
        sheet.write(last_row + offset, 2, total, bold)
        sheet.write(last_row + offset, 3, legend, bold)
        offset += 1
    # NOTE(review): 40.282203 is presumably the summed stree_default
    # (liblinear-ovr) accuracy over the dataset set — confirm.
    message = (
        f"** Accuracy compared to stree_default (liblinear-ovr) .: "
        f"{accuracy_total/40.282203:7.4f}"
    )
    sheet.write(last_row + offset + 1, 0, message, bold)
    sheet.write(last_row + offset + 3, 0, get_time(start, time.time()), bold)
|
|
|
|
|
|
def compute_status(dbh, name, model, accuracy):
    """Classify *accuracy* against the best stored cross-validation results.

    All comparisons use values rounded to 6 decimals. Three references are
    looked up, in this order: the best result of *model*, of ``"stree"``,
    and of every model in ``models_tree``.

    Returns:
        The marker for the strongest comparison won:
        - red star: beats every model;
        - green check: beats ``stree``;
        - plain check: beats *model*'s own best;
        - ``" "``: none of the above.
    """
    digits = 6
    rounded = round(accuracy, digits)
    best_default, _ = get_best_score(dbh, name, model)
    best_stree, _ = get_best_score(dbh, name, "stree")
    best_all, _ = get_best_score(dbh, name, models_tree)
    if rounded > round(best_all, digits):
        return TextColor.RED + BLACK_STAR + TextColor.ENDC
    if rounded > round(best_stree, digits):
        return TextColor.GREEN + CHECK_MARK + TextColor.ENDC
    if rounded > round(best_default, digits):
        return CHECK_MARK
    return " "
|
|
|
|
|
|
def get_best_score(dbh, name, model):
    """Return ``(accuracy, accuracy_std)`` of the best crossval record.

    *model* is forwarded as-is to ``dbh.find_best``; callers pass either a
    single model name or a list of names. ``(0.0, 0.0)`` is returned when no
    record exists. The values are taken from columns 5 and 11 of the record.
    """
    best = dbh.find_best(name, model, "crossval")
    if best is None:
        return 0.0, 0.0
    return best[5], best[11]
|
|
|
|
|
|
def get_time(start, stop):
    """Format the time between two ``time.time()`` stamps as hours/min/sec.

    Each component is truncated to an int and right-aligned in two columns,
    e.g. ``"Time:  1h  2m  3s"``.
    """
    elapsed = stop - start
    hours, remainder = divmod(elapsed, 3600)
    minutes, seconds = divmod(remainder, 60)
    parts = (int(hours), int(minutes), int(seconds))
    return "Time: {:2d}h {:2d}m {:2d}s".format(*parts)
|
|
|
|
|
|
# One cross-validation run is performed per seed.
random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
# Models whose stored results are compared when looking for the best accuracy.
models_tree = ["stree", "stree_default", "wodt", "j48svm", "oc1", "cart", "baseRaF"]
|
|
# Script entry: cross-validate one dataset (-d NAME) or every dataset of the
# set (-d all). Results are printed and, optionally, written as SQL
# statements and/or an Excel workbook. The names bound here stay
# module-level because the functions above read them as globals.
standardize = False
(
    set_of_files,
    model,
    dataset,
    sql,
    normalize,
    parameters,
    excel,
    discretize,
) = parse_arguments()
# Example --parameters value: '{"kernel":"rbf","max_features":"auto"}'
dbh = MySQL()
sql_output = None
excel_wb = None
# try/finally so the SQL file, the workbook and the DB handle are released
# even when processing fails part-way.
try:
    if sql:
        sql_output = open(f"{model}.sql", "w", encoding="utf-8")
    if excel != "":
        file_name = f"{excel}.xlsx"
        excel_wb = xlsxwriter.Workbook(file_name)
        excel_ws = excel_wb.add_worksheet(model)
        excel_write_header(excel_wb, excel_ws)
    # Return value unused here; kept for the call's side effect
    # (presumably opening the connection — TODO confirm).
    database = dbh.get_connection()
    dt = Datasets(normalize, standardize, set_of_files)
    start = time.time()
    if dataset == "all":
        print(
            f"* Process all datasets set with {model}: {set_of_files} "
            f"norm: {normalize} std: {standardize} discretize: {discretize}"
            f" store in: {model}"
        )
        print(
            f"5 Fold Cross Validation with {len(random_seeds)} random seeds "
            f"{random_seeds}\n"
        )
        header_cols = [
            "Dataset",
            "Samp",
            "Var",
            "Cls",
            "Nodes",
            "Leaves",
            "Depth",
            "Accuracy",
            "Time",
            "Parameters",
        ]
        header_lengths = [30, 5, 3, 3, 7, 7, 7, 15, 15, 10]
        # Re-serialise so the parameters text has canonical JSON spacing.
        parameters = json.dumps(json.loads(parameters))
        if parameters != "{}" and len(parameters) > 10:
            # Widen the last column so explicit parameters fit.
            header_lengths[-1] = len(parameters)
        print(
            "".join(
                f"{field:{width}s} "
                for field, width in zip(header_cols, header_lengths)
            )
        )
        print("".join("=" * width + " " for width in header_lengths))
        totals = {}
        accuracy_total = 0.0
        for dataset in dt:
            name = dataset[0]
            # Loaded only for the shape/class summary; process_dataset
            # reloads the data itself.
            X, y = dt.load(name)  # type: ignore
            samples, features = X.shape
            classes = len(np.unique(y))
            print(
                f"{name:30s} {samples:5d} {features:3d} {classes:3d} ",
                end="",
            )
            scores, times, hyperparameters, nodes, leaves, depth = (
                process_dataset(
                    name, verbose=False, model=model, params=parameters
                )
            )
            complexity = dict(
                nodes=float(np.mean(nodes)),
                leaves=float(np.mean(leaves)),
                depth=float(np.mean(depth)),
            )
            print(
                f"{complexity['nodes']:7.2f} {complexity['leaves']:7.2f} "
                f"{complexity['depth']:7.2f} ",
                end="",
            )
            accuracy = np.mean(scores)
            accuracy_total += accuracy
            # Status markers are only computed for stree_default runs.
            status = (
                compute_status(dbh, name, model, accuracy)
                if model == "stree_default"
                else " "
            )
            if status != " ":
                totals[status] = totals.get(status, 0) + 1
            print(f"{accuracy:8.6f}±{np.std(scores):6.4f}{status}", end="")
            print(
                f"{np.mean(times):8.6f}±{np.std(times):6.4f} {hyperparameters}"
            )
            if sql:
                command = store_string(
                    name, model, scores, times, hyperparameters, complexity
                )
                print(command, file=sql_output)
            if excel != "":
                excel_write_line(
                    excel_wb,
                    excel_ws,
                    name,
                    samples,
                    features,
                    classes,
                    accuracy,
                    times,
                    hyperparameters,
                    complexity,
                    status,
                )
        for key, value in totals.items():
            print(f"{key} .....: {value:2d}")
        # NOTE(review): 40.282203 is presumably the summed stree_default
        # (liblinear-ovr) accuracy over the dataset set — confirm.
        print(
            f"** Accuracy compared to stree_default (liblinear-ovr) .: "
            f"{accuracy_total/40.282203:7.4f}"
        )
        if excel != "":
            excel_write_totals(
                excel_wb, excel_ws, totals, start, accuracy_total
            )
    else:
        scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
            dataset, verbose=True, model=model, params=parameters
        )
        best_accuracy, acc_best_std = get_best_score(dbh, dataset, model)
        accuracy = np.mean(scores)
        print(f"* Normalize/Standard.: {normalize} / {standardize}")
        print(
            f"* Accuracy Computed .: {accuracy:6.4f}±{np.std(scores):6.4f} "
            f"{np.mean(times):5.3f}s"
        )
        print(f"* Best Accuracy model: {best_accuracy:6.4f}±{acc_best_std:6.4f}")
        print(f"* Difference ........: {best_accuracy - accuracy:6.4f}")
        best_accuracy, acc_best_std = get_best_score(dbh, dataset, models_tree)
        print(f"* Best Accuracy .....: {best_accuracy:6.4f}±{acc_best_std:6.4f}")
        print(f"* Difference ........: {best_accuracy - accuracy:6.4f}")
        print(
            f"* Nodes/Leaves/Depth : {np.mean(nodes):.2f} {np.mean(leaves):.2f} "
            f"{np.mean(depth):.2f} "
        )
        print(f"- Hyperparameters ...: {hyperparameters}")
    time_spent = get_time(start, time.time())
    print(f"{time_spent}")
finally:
    # The workbook is written to disk only on close(). Closing it here also
    # covers the single-dataset mode, which previously never closed it.
    if excel_wb is not None:
        excel_wb.close()
    if sql_output is not None:
        sql_output.close()
    dbh.close()
|