Add Excel generation and Fayyad (MDLP) discretization to report_score

Add stree_default comparison at the end of the report
This commit is contained in:
2021-05-12 19:19:43 +02:00
parent d9f5bfee6c
commit 44ba4f05b9

View File

@@ -4,6 +4,7 @@ import time
from datetime import datetime from datetime import datetime
import json import json
import numpy as np import numpy as np
import xlsxwriter
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from stree import Stree from stree import Stree
from sklearn.model_selection import KFold, cross_validate from sklearn.model_selection import KFold, cross_validate
@@ -11,6 +12,12 @@ from experimentation.Sets import Datasets
from experimentation.Database import MySQL from experimentation.Database import MySQL
from wodt import TreeClassifier from wodt import TreeClassifier
from experimentation.Utils import TextColor from experimentation.Utils import TextColor
from mdlp import MDLP
CHECK_MARK = "\N{heavy check mark}"
EXCLAMATION_MARK = "\N{heavy exclamation mark symbol}"
BLACK_STAR = "\N{black star}"
def parse_arguments(): def parse_arguments():
@@ -52,6 +59,22 @@ def parse_arguments():
type=int, type=int,
required=True, required=True,
) )
ap.add_argument(
"-x",
"--excel",
type=str,
default="",
required=False,
help="generate excel file",
)
ap.add_argument(
"-di",
"--discretize",
type=bool,
default=False,
required=False,
help="Discretize datasets",
)
ap.add_argument( ap.add_argument(
"-p", "--parameters", type=str, required=False, default="{}" "-p", "--parameters", type=str, required=False, default="{}"
) )
@@ -63,6 +86,8 @@ def parse_arguments():
args.sql, args.sql,
bool(args.normalize), bool(args.normalize),
args.parameters, args.parameters,
args.excel,
args.discretize,
) )
@@ -79,6 +104,9 @@ def get_classifier(model, random_state, hyperparameters):
def process_dataset(dataset, verbose, model, params): def process_dataset(dataset, verbose, model, params):
X, y = dt.load(dataset) X, y = dt.load(dataset)
if discretize:
mdlp = MDLP(random_state=1)
X = mdlp.fit_transform(X, y)
scores = [] scores = []
times = [] times = []
nodes = [] nodes = []
@@ -179,16 +207,118 @@ def store_string(
return result return result
def excel_write_line(
    book,
    sheet,
    name,
    samples,
    features,
    classes,
    accuracy,
    times,
    hyperparameters,
    complexity,
    status,
):
    """Append one dataset result row to the Excel worksheet.

    The current row index is kept as an attribute on the function itself
    (``excel_write_line.row``); ``excel_write_totals`` reads it afterwards
    to place the summary below the table.
    """
    # First data row is 4 — rows 0-3 hold the titles/column headers
    # written by excel_write_header.
    excel_write_line.row = getattr(excel_write_line, "row", 3) + 1
    font_size = 14
    fmt_decimal = book.add_format(
        {"num_format": "0.000000", "font_size": font_size}
    )
    fmt_integer = book.add_format({"num_format": "#,###", "font_size": font_size})
    fmt_normal = book.add_format({"font_size": font_size})
    mark, _ = excel_status(status)
    row = excel_write_line.row
    # One (value, format) pair per column, in sheet order.
    cells = [
        (name, fmt_normal),
        (samples, fmt_integer),
        (features, fmt_normal),
        (classes, fmt_normal),
        (complexity["nodes"], fmt_normal),
        (complexity["leaves"], fmt_normal),
        (complexity["depth"], fmt_normal),
        (accuracy, fmt_decimal),
        (mark, fmt_normal),
        (np.mean(times), fmt_decimal),
        (hyperparameters, fmt_normal),
    ]
    for column, (value, cell_format) in enumerate(cells):
        sheet.write(row, column, value, cell_format)
def excel_write_header(book, sheet):
    """Write the report title, subtitle and bold column headers to *sheet*.

    NOTE(review): relies on module-level globals (model, set_of_files,
    normalize, standardize, discretize, random_seeds) being assigned before
    this is called.
    """
    title_fmt = book.add_format()
    title_fmt.set_font_size(18)
    subtitle_fmt = book.add_format()
    subtitle_fmt.set_font_size(16)
    sheet.write(
        0,
        0,
        f"Process all datasets set with {model}: {set_of_files} "
        f"norm: {normalize} std: {standardize} discretize: {discretize} "
        f"store in: {model}",
        title_fmt,
    )
    sheet.write(
        1,
        0,
        "5 Fold Cross Validation with 10 random seeds",
        subtitle_fmt,
    )
    sheet.write(1, 5, f"{random_seeds}", subtitle_fmt)
    # (label, column width) pairs, one per output column.
    columns = (
        ("Dataset", 30),
        ("Samples", 10),
        ("Variables", 7),
        ("Classes", 7),
        ("Nodes", 7),
        ("Leaves", 7),
        ("Depth", 7),
        ("Accuracy", 10),
        ("Stat", 3),
        ("Time", 10),
        ("Parameters", 50),
    )
    bold_fmt = book.add_format({"bold": True, "font_size": 14})
    for index, (label, width) in enumerate(columns):
        sheet.write(3, index, label, bold_fmt)
        sheet.set_column(index, index, width)
def excel_status(status):
    """Map a console status marker to an (excel symbol, description) pair.

    The console report colors its markers with ANSI escapes (TextColor);
    Excel cells cannot hold those, so this translates each colored marker
    into a plain symbol plus a human-readable legend entry.

    Returns (" ", "") for the neutral / unknown marker.
    """
    if status == TextColor.GREEN + CHECK_MARK + TextColor.ENDC:
        return EXCLAMATION_MARK, "Accuracy better than stree optimized"
    elif status == TextColor.RED + BLACK_STAR + TextColor.ENDC:
        # Fixed typo in the legend text: "al models" -> "all models".
        return BLACK_STAR, "Best accuracy of all models"
    elif status != " ":
        return CHECK_MARK, "Accuracy better than original stree_default"
    return " ", ""
def excel_write_totals(book, sheet, totals, start):
    """Write per-status counts and the elapsed time below the results table.

    *totals* maps console status markers to occurrence counts; *start* is
    the epoch timestamp taken when processing began.
    """
    bold = book.add_format({"bold": True, "font_size": 16})
    # Summary starts two rows below the last data row written by
    # excel_write_line.
    offset = 2
    for key, count in totals.items():
        mark, description = excel_status(key)
        row = excel_write_line.row + offset
        sheet.write(row, 1, mark, bold)
        sheet.write(row, 2, count, bold)
        sheet.write(row, 3, description, bold)
        offset += 1
    sheet.write(
        excel_write_line.row + offset + 1,
        0,
        get_time(start, time.time()),
        bold,
    )
def compute_status(dbh, name, model, accuracy): def compute_status(dbh, name, model, accuracy):
better_default = "\N{heavy check mark}" n_dig = 6
better_stree = TextColor.GREEN + "\N{heavy check mark}" + TextColor.ENDC ac_round = round(accuracy, n_dig)
best = TextColor.RED + "\N{black star}" + TextColor.ENDC better_default = CHECK_MARK
better_stree = TextColor.GREEN + CHECK_MARK + TextColor.ENDC
best = TextColor.RED + BLACK_STAR + TextColor.ENDC
best_default, _ = get_best_score(dbh, name, model) best_default, _ = get_best_score(dbh, name, model)
best_stree, _ = get_best_score(dbh, name, "stree") best_stree, _ = get_best_score(dbh, name, "stree")
best_all, _ = get_best_score(dbh, name, models_tree) best_all, _ = get_best_score(dbh, name, models_tree)
status = better_default if accuracy >= best_default else " " status = better_default if ac_round > round(best_default, n_dig) else " "
status = better_stree if accuracy >= best_stree else status status = better_stree if ac_round > round(best_stree, n_dig) else status
status = best if accuracy >= best_all else status status = best if ac_round > round(best_all, n_dig) else status
return status return status
@@ -199,6 +329,12 @@ def get_best_score(dbh, name, model):
return accuracy, acc_std return accuracy, acc_std
def get_time(start, stop):
    """Format the elapsed time between *start* and *stop* (epoch seconds).

    Returns "Time: Hh Mm Ss" with each field space-padded to width 2 and
    truncated to whole units.
    """
    elapsed = int(stop - start)
    minutes, seconds = divmod(elapsed, 60)
    hours, minutes = divmod(minutes, 60)
    return f"Time: {hours:2d}h {minutes:2d}m {seconds:2d}s"
random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
models_tree = [ models_tree = [
"stree", "stree",
@@ -210,17 +346,32 @@ models_tree = [
"baseRaF", "baseRaF",
] ]
standardize = False standardize = False
(set_of_files, model, dataset, sql, normalize, parameters) = parse_arguments() (
set_of_files,
model,
dataset,
sql,
normalize,
parameters,
excel,
discretize,
) = parse_arguments()
dbh = MySQL() dbh = MySQL()
if sql: if sql:
sql_output = open(f"{model}.sql", "w") sql_output = open(f"{model}.sql", "w")
if excel != "":
file_name = f"{excel}.xlsx"
excel_wb = xlsxwriter.Workbook(file_name)
excel_ws = excel_wb.add_worksheet(model)
excel_write_header(excel_wb, excel_ws)
database = dbh.get_connection() database = dbh.get_connection()
dt = Datasets(normalize, standardize, set_of_files) dt = Datasets(normalize, standardize, set_of_files)
start = time.time() start = time.time()
if dataset == "all": if dataset == "all":
print( print(
f"* Process all datasets set with {model}: {set_of_files} " f"* Process all datasets set with {model}: {set_of_files} "
f"norm: {normalize} std: {standardize} store in: {model}" f"norm: {normalize} std: {standardize} discretize: {discretize}"
f" store in: {model}"
) )
print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}\n") print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}\n")
header_cols = [ header_cols = [
@@ -245,6 +396,8 @@ if dataset == "all":
print(f"{field:{underscore}s} ", end="") print(f"{field:{underscore}s} ", end="")
line_col += "=" * underscore + " " line_col += "=" * underscore + " "
print(f"\n{line_col}") print(f"\n{line_col}")
totals = {}
accuracy_total = 0.0
for dataset in dt: for dataset in dt:
name = dataset[0] name = dataset[0]
X, y = dt.load(name) # type: ignore X, y = dt.load(name) # type: ignore
@@ -268,11 +421,17 @@ if dataset == "all":
end="", end="",
) )
accuracy = np.mean(scores) accuracy = np.mean(scores)
accuracy_total += accuracy
status = ( status = (
compute_status(dbh, name, model, accuracy) compute_status(dbh, name, model, accuracy)
if model == "stree_default" if model == "stree_default"
else " " else " "
) )
if status != " ":
if status not in totals:
totals[status] = 1
else:
totals[status] += 1
print(f"{accuracy:8.6f}±{np.std(scores):6.4f}{status}", end="") print(f"{accuracy:8.6f}±{np.std(scores):6.4f}{status}", end="")
print(f"{np.mean(times):8.6f}±{np.std(times):6.4f} {hyperparameters}") print(f"{np.mean(times):8.6f}±{np.std(times):6.4f} {hyperparameters}")
if sql: if sql:
@@ -280,6 +439,29 @@ if dataset == "all":
name, model, scores, times, hyperparameters, complexity name, model, scores, times, hyperparameters, complexity
) )
print(command, file=sql_output) print(command, file=sql_output)
if excel != "":
excel_write_line(
excel_wb,
excel_ws,
name,
samples,
features,
classes,
accuracy,
times,
hyperparameters,
complexity,
status,
)
for key, value in totals.items():
print(f"{key} .....: {value:2d}")
print(
f"** Accuracy compared to stree_default (liblinear-ovr) .: "
f"{accuracy_total/40.282203:7.4f}"
)
if excel != "":
excel_write_totals(excel_wb, excel_ws, totals, start)
excel_wb.close()
else: else:
scores, times, hyperparameters, nodes, leaves, depth = process_dataset( scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
dataset, verbose=True, model=model, params=parameters dataset, verbose=True, model=model, params=parameters
@@ -302,9 +484,8 @@ else:
) )
print(f"- Hyperparameters ...: {hyperparameters}") print(f"- Hyperparameters ...: {hyperparameters}")
stop = time.time() stop = time.time()
hours, rem = divmod(stop - start, 3600) time_spent = get_time(start, time.time())
minutes, seconds = divmod(rem, 60) print(f"{time_spent}")
print(f"Time: {int(hours):2d}h {int(minutes):2d}m {int(seconds):2d}s")
if sql: if sql:
sql_output.close() sql_output.close()
dbh.close() dbh.close()