mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-16 07:56:07 +00:00
Add Excel generation and Fayyad discretization to report_score
Add stree_default comparison at the end of the report
This commit is contained in:
203
report_score.py
203
report_score.py
@@ -4,6 +4,7 @@ import time
|
||||
from datetime import datetime
|
||||
import json
|
||||
import numpy as np
|
||||
import xlsxwriter
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from stree import Stree
|
||||
from sklearn.model_selection import KFold, cross_validate
|
||||
@@ -11,6 +12,12 @@ from experimentation.Sets import Datasets
|
||||
from experimentation.Database import MySQL
|
||||
from wodt import TreeClassifier
|
||||
from experimentation.Utils import TextColor
|
||||
from mdlp import MDLP
|
||||
|
||||
|
||||
CHECK_MARK = "\N{heavy check mark}"
|
||||
EXCLAMATION_MARK = "\N{heavy exclamation mark symbol}"
|
||||
BLACK_STAR = "\N{black star}"
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
@@ -52,6 +59,22 @@ def parse_arguments():
|
||||
type=int,
|
||||
required=True,
|
||||
)
|
||||
ap.add_argument(
|
||||
"-x",
|
||||
"--excel",
|
||||
type=str,
|
||||
default="",
|
||||
required=False,
|
||||
help="generate excel file",
|
||||
)
|
||||
ap.add_argument(
|
||||
"-di",
|
||||
"--discretize",
|
||||
type=bool,
|
||||
default=False,
|
||||
required=False,
|
||||
help="Discretize datasets",
|
||||
)
|
||||
ap.add_argument(
|
||||
"-p", "--parameters", type=str, required=False, default="{}"
|
||||
)
|
||||
@@ -63,6 +86,8 @@ def parse_arguments():
|
||||
args.sql,
|
||||
bool(args.normalize),
|
||||
args.parameters,
|
||||
args.excel,
|
||||
args.discretize,
|
||||
)
|
||||
|
||||
|
||||
@@ -79,6 +104,9 @@ def get_classifier(model, random_state, hyperparameters):
|
||||
|
||||
def process_dataset(dataset, verbose, model, params):
|
||||
X, y = dt.load(dataset)
|
||||
if discretize:
|
||||
mdlp = MDLP(random_state=1)
|
||||
X = mdlp.fit_transform(X, y)
|
||||
scores = []
|
||||
times = []
|
||||
nodes = []
|
||||
@@ -179,16 +207,118 @@ def store_string(
|
||||
return result
|
||||
|
||||
|
||||
def excel_write_line(
    book,
    sheet,
    name,
    samples,
    features,
    classes,
    accuracy,
    times,
    hyperparameters,
    complexity,
    status,
):
    """Write the results of one dataset as a new row of the Excel report.

    The target row is remembered between calls in the function attribute
    ``excel_write_line.row``: the first call writes to row index 4 and every
    later call writes one row further down.

    Parameters
    ----------
    book : workbook object used to create the cell formats
    sheet : worksheet object that receives the cells
    name, samples, features, classes : dataset description columns
    accuracy : value written with six decimals
    times : collection of timings; its mean is written
    hyperparameters : value written verbatim in the last column
    complexity : mapping with the keys "nodes", "leaves" and "depth"
    status : console status marker, converted with ``excel_status``
    """
    # Advance the persistent row cursor (3 + 1 == 4 on the very first call).
    current_row = getattr(excel_write_line, "row", 3) + 1
    excel_write_line.row = current_row
    font_size = 14
    decimal_fmt = book.add_format(
        {"num_format": "0.000000", "font_size": font_size}
    )
    integer_fmt = book.add_format(
        {"num_format": "#,###", "font_size": font_size}
    )
    plain_fmt = book.add_format({"font_size": font_size})
    # Only the symbol is needed here; the legend text is used in the totals.
    symbol, _ = excel_status(status)
    # One (value, format) pair per column, in left-to-right order.
    cells = (
        (name, plain_fmt),
        (samples, integer_fmt),
        (features, plain_fmt),
        (classes, plain_fmt),
        (complexity["nodes"], plain_fmt),
        (complexity["leaves"], plain_fmt),
        (complexity["depth"], plain_fmt),
        (accuracy, decimal_fmt),
        (symbol, plain_fmt),
        (np.mean(times), decimal_fmt),
        (hyperparameters, plain_fmt),
    )
    for column, (value, cell_format) in enumerate(cells):
        sheet.write(current_row, column, value, cell_format)
|
||||
|
||||
|
||||
def excel_write_header(book, sheet):
    """Write the title lines and the column headings of the Excel report.

    The title is built from the module-level run settings (``model``,
    ``set_of_files``, ``normalize``, ``standardize``, ``discretize``) and the
    list of ``random_seeds`` is written next to the subtitle. Column headings
    go to row index 3 and each column is given its display width.

    Parameters
    ----------
    book : workbook object used to create the cell formats
    sheet : worksheet object that receives the cells
    """
    title_fmt = book.add_format()
    title_fmt.set_font_size(18)
    subtitle_fmt = book.add_format()
    subtitle_fmt.set_font_size(16)
    title = (
        f"Process all datasets set with {model}: {set_of_files} "
        f"norm: {normalize} std: {standardize} discretize: {discretize} "
        f"store in: {model}"
    )
    sheet.write(0, 0, title, title_fmt)
    sheet.write(
        1, 0, "5 Fold Cross Validation with 10 random seeds", subtitle_fmt
    )
    sheet.write(1, 5, f"{random_seeds}", subtitle_fmt)
    # (heading text, column width) for every report column, left to right.
    columns = (
        ("Dataset", 30),
        ("Samples", 10),
        ("Variables", 7),
        ("Classes", 7),
        ("Nodes", 7),
        ("Leaves", 7),
        ("Depth", 7),
        ("Accuracy", 10),
        ("Stat", 3),
        ("Time", 10),
        ("Parameters", 50),
    )
    heading_fmt = book.add_format({"bold": True, "font_size": 14})
    for index, (label, width) in enumerate(columns):
        sheet.write(3, index, label, heading_fmt)
        sheet.set_column(index, index, width)
|
||||
|
||||
|
||||
def excel_status(status):
    """Map a console status marker to its Excel symbol and legend text.

    Parameters
    ----------
    status : str
        Marker as produced for the console report: a plain check mark, a
        green check mark, a red star, or a single space for "no status".

    Returns
    -------
    tuple of (str, str)
        The symbol to write in the sheet and the legend describing it.
        ``(" ", "")`` is returned when ``status`` is a single space.
    """
    # The green check mark is written as an exclamation mark, presumably
    # because the terminal colour codes cannot distinguish it from the plain
    # check mark inside a spreadsheet cell.
    if status == TextColor.GREEN + CHECK_MARK + TextColor.ENDC:
        return EXCLAMATION_MARK, "Accuracy better than stree optimized"
    if status == TextColor.RED + BLACK_STAR + TextColor.ENDC:
        return BLACK_STAR, "Best accuracy of all models"
    # Any other non-blank marker is treated as the plain check mark.
    if status != " ":
        return CHECK_MARK, "Accuracy better than original stree_default"
    return " ", ""
|
||||
|
||||
|
||||
def excel_write_totals(book, sheet, totals, start):
    """Write the per-status totals and the elapsed time below the data rows.

    The totals start two rows below the last row written by
    ``excel_write_line`` (read from ``excel_write_line.row``), one status per
    row, and the elapsed-time text follows after one more empty row.

    Parameters
    ----------
    book : workbook object used to create the cell format
    sheet : worksheet object that receives the cells
    totals : mapping of console status marker -> number of datasets
    start : timestamp the run started at, as accepted by ``get_time``
    """
    totals_fmt = book.add_format({"bold": True, "font_size": 16})
    first_row = excel_write_line.row + 2
    for offset, (key, total) in enumerate(totals.items()):
        symbol, legend = excel_status(key)
        row = first_row + offset
        sheet.write(row, 1, symbol, totals_fmt)
        sheet.write(row, 2, total, totals_fmt)
        sheet.write(row, 3, legend, totals_fmt)
    elapsed = get_time(start, time.time())
    # Leave one empty row between the last total and the elapsed time.
    sheet.write(first_row + len(totals) + 1, 0, elapsed, totals_fmt)
|
||||
|
||||
|
||||
def compute_status(dbh, name, model, accuracy):
    """Return the status marker for an accuracy compared to stored results.

    The accuracy is compared, after rounding to six decimals, with the best
    stored scores returned by ``get_best_score`` for three references: the
    given ``model``, ``"stree"`` and the whole ``models_tree`` list.

    Parameters
    ----------
    dbh : database handler passed through to ``get_best_score``
    name : dataset name
    model : model name used for the first comparison
    accuracy : accuracy obtained in the current run

    Returns
    -------
    str
        A plain check mark when the accuracy beats the best score of
        ``model``, a green check mark when it beats the best "stree" score,
        a red star when it beats the best score of all models, or a single
        space when none of the comparisons holds.
    """
    n_dig = 6
    ac_round = round(accuracy, n_dig)
    better_default = CHECK_MARK
    better_stree = TextColor.GREEN + CHECK_MARK + TextColor.ENDC
    best = TextColor.RED + BLACK_STAR + TextColor.ENDC
    best_default, _ = get_best_score(dbh, name, model)
    best_stree, _ = get_best_score(dbh, name, "stree")
    best_all, _ = get_best_score(dbh, name, models_tree)
    # Each later check overrides the earlier ones, so the last comparison
    # that holds decides the marker. The comparison is strict on rounded
    # values, so a score equal to the stored one up to six decimals does not
    # count as an improvement.
    status = better_default if ac_round > round(best_default, n_dig) else " "
    status = better_stree if ac_round > round(best_stree, n_dig) else status
    status = best if ac_round > round(best_all, n_dig) else status
    return status
|
||||
|
||||
|
||||
@@ -199,6 +329,12 @@ def get_best_score(dbh, name, model):
|
||||
return accuracy, acc_std
|
||||
|
||||
|
||||
def get_time(start, stop):
    """Format the time elapsed between two timestamps.

    Parameters
    ----------
    start, stop : numbers expressed in seconds

    Returns
    -------
    str
        Text of the form ``"Time:  1h  2m  3s"``; every field is truncated
        to an integer and right-aligned to a width of two.
    """
    elapsed = stop - start
    whole_hours = int(elapsed // 3600)
    whole_minutes, whole_seconds = (
        int(part) for part in divmod(elapsed % 3600, 60)
    )
    return f"Time: {whole_hours:2d}h {whole_minutes:2d}m {whole_seconds:2d}s"
|
||||
|
||||
|
||||
random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
|
||||
models_tree = [
|
||||
"stree",
|
||||
@@ -210,17 +346,32 @@ models_tree = [
|
||||
"baseRaF",
|
||||
]
|
||||
standardize = False
|
||||
(set_of_files, model, dataset, sql, normalize, parameters) = parse_arguments()
|
||||
(
|
||||
set_of_files,
|
||||
model,
|
||||
dataset,
|
||||
sql,
|
||||
normalize,
|
||||
parameters,
|
||||
excel,
|
||||
discretize,
|
||||
) = parse_arguments()
|
||||
dbh = MySQL()
|
||||
if sql:
|
||||
sql_output = open(f"{model}.sql", "w")
|
||||
if excel != "":
|
||||
file_name = f"{excel}.xlsx"
|
||||
excel_wb = xlsxwriter.Workbook(file_name)
|
||||
excel_ws = excel_wb.add_worksheet(model)
|
||||
excel_write_header(excel_wb, excel_ws)
|
||||
database = dbh.get_connection()
|
||||
dt = Datasets(normalize, standardize, set_of_files)
|
||||
start = time.time()
|
||||
if dataset == "all":
|
||||
print(
|
||||
f"* Process all datasets set with {model}: {set_of_files} "
|
||||
f"norm: {normalize} std: {standardize} store in: {model}"
|
||||
f"norm: {normalize} std: {standardize} discretize: {discretize}"
|
||||
f" store in: {model}"
|
||||
)
|
||||
print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}\n")
|
||||
header_cols = [
|
||||
@@ -245,6 +396,8 @@ if dataset == "all":
|
||||
print(f"{field:{underscore}s} ", end="")
|
||||
line_col += "=" * underscore + " "
|
||||
print(f"\n{line_col}")
|
||||
totals = {}
|
||||
accuracy_total = 0.0
|
||||
for dataset in dt:
|
||||
name = dataset[0]
|
||||
X, y = dt.load(name) # type: ignore
|
||||
@@ -268,11 +421,17 @@ if dataset == "all":
|
||||
end="",
|
||||
)
|
||||
accuracy = np.mean(scores)
|
||||
accuracy_total += accuracy
|
||||
status = (
|
||||
compute_status(dbh, name, model, accuracy)
|
||||
if model == "stree_default"
|
||||
else " "
|
||||
)
|
||||
if status != " ":
|
||||
if status not in totals:
|
||||
totals[status] = 1
|
||||
else:
|
||||
totals[status] += 1
|
||||
print(f"{accuracy:8.6f}±{np.std(scores):6.4f}{status}", end="")
|
||||
print(f"{np.mean(times):8.6f}±{np.std(times):6.4f} {hyperparameters}")
|
||||
if sql:
|
||||
@@ -280,6 +439,29 @@ if dataset == "all":
|
||||
name, model, scores, times, hyperparameters, complexity
|
||||
)
|
||||
print(command, file=sql_output)
|
||||
if excel != "":
|
||||
excel_write_line(
|
||||
excel_wb,
|
||||
excel_ws,
|
||||
name,
|
||||
samples,
|
||||
features,
|
||||
classes,
|
||||
accuracy,
|
||||
times,
|
||||
hyperparameters,
|
||||
complexity,
|
||||
status,
|
||||
)
|
||||
for key, value in totals.items():
|
||||
print(f"{key} .....: {value:2d}")
|
||||
print(
|
||||
f"** Accuracy compared to stree_default (liblinear-ovr) .: "
|
||||
f"{accuracy_total/40.282203:7.4f}"
|
||||
)
|
||||
if excel != "":
|
||||
excel_write_totals(excel_wb, excel_ws, totals, start)
|
||||
excel_wb.close()
|
||||
else:
|
||||
scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
|
||||
dataset, verbose=True, model=model, params=parameters
|
||||
@@ -302,9 +484,8 @@ else:
|
||||
)
|
||||
print(f"- Hyperparameters ...: {hyperparameters}")
|
||||
stop = time.time()
|
||||
hours, rem = divmod(stop - start, 3600)
|
||||
minutes, seconds = divmod(rem, 60)
|
||||
print(f"Time: {int(hours):2d}h {int(minutes):2d}m {int(seconds):2d}s")
|
||||
time_spent = get_time(start, time.time())
|
||||
print(f"{time_spent}")
|
||||
if sql:
|
||||
sql_output.close()
|
||||
dbh.close()
|
||||
|
Reference in New Issue
Block a user