diff --git a/report_score.py b/report_score.py index 16e536b..7535c6e 100644 --- a/report_score.py +++ b/report_score.py @@ -4,6 +4,7 @@ import time from datetime import datetime import json import numpy as np +import xlsxwriter from sklearn.tree import DecisionTreeClassifier from stree import Stree from sklearn.model_selection import KFold, cross_validate @@ -11,6 +12,12 @@ from experimentation.Sets import Datasets from experimentation.Database import MySQL from wodt import TreeClassifier from experimentation.Utils import TextColor +from mdlp import MDLP + + +CHECK_MARK = "\N{heavy check mark}" +EXCLAMATION_MARK = "\N{heavy exclamation mark symbol}" +BLACK_STAR = "\N{black star}" def parse_arguments(): @@ -52,6 +59,22 @@ def parse_arguments(): type=int, required=True, ) + ap.add_argument( + "-x", + "--excel", + type=str, + default="", + required=False, + help="generate excel file", + ) + ap.add_argument( + "-di", + "--discretize", + type=bool, + default=False, + required=False, + help="Discretize datasets", + ) ap.add_argument( "-p", "--parameters", type=str, required=False, default="{}" ) @@ -63,6 +86,8 @@ def parse_arguments(): args.sql, bool(args.normalize), args.parameters, + args.excel, + args.discretize, ) @@ -79,6 +104,9 @@ def get_classifier(model, random_state, hyperparameters): def process_dataset(dataset, verbose, model, params): X, y = dt.load(dataset) + if discretize: + mdlp = MDLP(random_state=1) + X = mdlp.fit_transform(X, y) scores = [] times = [] nodes = [] @@ -179,16 +207,118 @@ def store_string( return result +def excel_write_line( + book, + sheet, + name, + samples, + features, + classes, + accuracy, + times, + hyperparameters, + complexity, + status, +): + try: + excel_write_line.row += 1 + except AttributeError: + excel_write_line.row = 4 + size_n = 14 + decimal = book.add_format({"num_format": "0.000000", "font_size": size_n}) + integer = book.add_format({"num_format": "#,###", "font_size": size_n}) + normal = book.add_format({"font_size": size_n}) + col = 0 + status, _ = excel_status(status) + sheet.write(excel_write_line.row, col, name, normal) + sheet.write(excel_write_line.row, col + 1, samples, integer) + sheet.write(excel_write_line.row, col + 2, features, normal) + sheet.write(excel_write_line.row, col + 3, classes, normal) + sheet.write(excel_write_line.row, col + 4, complexity["nodes"], normal) + sheet.write(excel_write_line.row, col + 5, complexity["leaves"], normal) + sheet.write(excel_write_line.row, col + 6, complexity["depth"], normal) + sheet.write(excel_write_line.row, col + 7, accuracy, decimal) + sheet.write(excel_write_line.row, col + 8, status, normal) + sheet.write(excel_write_line.row, col + 9, np.mean(times), decimal) + sheet.write(excel_write_line.row, col + 10, hyperparameters, normal) + + +def excel_write_header(book, sheet): + header = book.add_format() + header.set_font_size(18) + subheader = book.add_format() + subheader.set_font_size(16) + sheet.write( + 0, + 0, + f"Process all datasets set with {model}: {set_of_files} " + f"norm: {normalize} std: {standardize} discretize: {discretize} " + f"store in: {model}", + header, + ) + sheet.write( + 1, + 0, + "5 Fold Cross Validation with 10 random seeds", + subheader, + ) + sheet.write(1, 5, f"{random_seeds}", subheader) + header_cols = [ + ("Dataset", 30), + ("Samples", 10), + ("Variables", 7), + ("Classes", 7), + ("Nodes", 7), + ("Leaves", 7), + ("Depth", 7), + ("Accuracy", 10), + ("Stat", 3), + ("Time", 10), + ("Parameters", 50), + ] + bold = book.add_format({"bold": True, "font_size": 14}) + i = 0 + for item, length in header_cols: + sheet.write(3, i, item, bold) + sheet.set_column(i, i, length) + i += 1 + + +def excel_status(status): + if status == TextColor.GREEN + CHECK_MARK + TextColor.ENDC: + return EXCLAMATION_MARK, "Accuracy better than stree optimized" + elif status == TextColor.RED + BLACK_STAR + TextColor.ENDC: + return BLACK_STAR, "Best accuracy of al models" + elif status != " ": + return CHECK_MARK, "Accuracy better than original stree_default" + return " ", "" + + +def excel_write_totals(book, sheet, totals, start): + i = 2 + bold = book.add_format({"bold": True, "font_size": 16}) + for key, total in totals.items(): + status, text = excel_status(key) + sheet.write(excel_write_line.row + i, 1, status, bold) + sheet.write(excel_write_line.row + i, 2, total, bold) + sheet.write(excel_write_line.row + i, 3, text, bold) + i += 1 + time_spent = get_time(start, time.time()) + sheet.write(excel_write_line.row + i + 1, 0, time_spent, bold) + + def compute_status(dbh, name, model, accuracy): - better_default = "\N{heavy check mark}" - better_stree = TextColor.GREEN + "\N{heavy check mark}" + TextColor.ENDC - best = TextColor.RED + "\N{black star}" + TextColor.ENDC + n_dig = 6 + ac_round = round(accuracy, n_dig) + better_default = CHECK_MARK + better_stree = TextColor.GREEN + CHECK_MARK + TextColor.ENDC + best = TextColor.RED + BLACK_STAR + TextColor.ENDC best_default, _ = get_best_score(dbh, name, model) best_stree, _ = get_best_score(dbh, name, "stree") best_all, _ = get_best_score(dbh, name, models_tree) - status = better_default if accuracy >= best_default else " " - status = better_stree if accuracy >= best_stree else status - status = best if accuracy >= best_all else status + status = better_default if ac_round > round(best_default, n_dig) else " " + status = better_stree if ac_round > round(best_stree, n_dig) else status + status = best if ac_round > round(best_all, n_dig) else status return status @@ -199,6 +329,12 @@ def get_best_score(dbh, name, model): return accuracy, acc_std +def get_time(start, stop): + hours, rem = divmod(stop - start, 3600) + minutes, seconds = divmod(rem, 60) + return f"Time: {int(hours):2d}h {int(minutes):2d}m {int(seconds):2d}s" + + random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] models_tree = [ "stree", @@ -210,17 +346,32 @@ models_tree = [ "baseRaF", ] standardize = False -(set_of_files, model, dataset, sql, normalize, parameters) = parse_arguments() +( + set_of_files, + model, + dataset, + sql, + normalize, + parameters, + excel, + discretize, +) = parse_arguments() dbh = MySQL() if sql: sql_output = open(f"{model}.sql", "w") +if excel != "": + file_name = f"{excel}.xlsx" + excel_wb = xlsxwriter.Workbook(file_name) + excel_ws = excel_wb.add_worksheet(model) + excel_write_header(excel_wb, excel_ws) database = dbh.get_connection() dt = Datasets(normalize, standardize, set_of_files) start = time.time() if dataset == "all": print( f"* Process all datasets set with {model}: {set_of_files} " - f"norm: {normalize} std: {standardize} store in: {model}" + f"norm: {normalize} std: {standardize} discretize: {discretize}" + f" store in: {model}" ) print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}\n") header_cols = [ @@ -245,6 +396,8 @@ if dataset == "all": print(f"{field:{underscore}s} ", end="") line_col += "=" * underscore + " " print(f"\n{line_col}") + totals = {} + accuracy_total = 0.0 for dataset in dt: name = dataset[0] X, y = dt.load(name) # type: ignore @@ -268,11 +421,17 @@ if dataset == "all": end="", ) accuracy = np.mean(scores) + accuracy_total += accuracy status = ( compute_status(dbh, name, model, accuracy) if model == "stree_default" else " " ) + if status != " ": + if status not in totals: + totals[status] = 1 + else: + totals[status] += 1 print(f"{accuracy:8.6f}±{np.std(scores):6.4f}{status}", end="") print(f"{np.mean(times):8.6f}±{np.std(times):6.4f} {hyperparameters}") if sql: @@ -280,6 +439,29 @@ if dataset == "all": name, model, scores, times, hyperparameters, complexity ) print(command, file=sql_output) + if excel != "": + excel_write_line( + excel_wb, + excel_ws, + name, + samples, + features, + classes, + accuracy, + times, + hyperparameters, + complexity, + status, + ) + for key, value in totals.items(): + print(f"{key} .....: {value:2d}") + print( + f"** Accuracy compared to stree_default (liblinear-ovr) .: " + f"{accuracy_total/40.282203:7.4f}" + ) + if excel != "": + excel_write_totals(excel_wb, excel_ws, totals, start) + excel_wb.close() else: scores, times, hyperparameters, nodes, leaves, depth = process_dataset( dataset, verbose=True, model=model, params=parameters @@ -302,9 +484,8 @@ else: ) print(f"- Hyperparameters ...: {hyperparameters}") stop = time.time() -hours, rem = divmod(stop - start, 3600) -minutes, seconds = divmod(rem, 60) -print(f"Time: {int(hours):2d}h {int(minutes):2d}m {int(seconds):2d}s") +time_spent = get_time(start, time.time()) +print(f"{time_spent}") if sql: sql_output.close() dbh.close()