import argparse
import random
import time
import json

import numpy as np
import xlsxwriter
from stree import Stree
from sklearn.model_selection import KFold, cross_validate

from experimentation.Sets import Datasets
from experimentation.Database import MySQL
from experimentation.Utils import TextColor

CHECK_MARK = "\N{heavy check mark}"
EXCLAMATION_MARK = "\N{heavy exclamation mark symbol}"
BLACK_STAR = "\N{black star}"

# Value the summed accuracy is divided by in the final summary line.
# NOTE(review): presumably the accuracy total obtained by stree_default over
# the same set of datasets -- confirm before changing the dataset list.
REFERENCE_ACCURACY_TOTAL = 40.282203

# Zero-based sheet row that holds the column titles; data rows start below it.
EXCEL_HEADER_ROW = 3


def parse_arguments():
    """Parse the command line and return ``(parameters, excel)``.

    ``parameters`` is the raw string given with ``-p`` (``"{}"`` by default;
    it is decoded with ``json.loads`` later on) and ``excel`` is the string
    given with ``-x`` (empty by default, which disables the Excel report).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "-x",
        "--excel",
        type=str,
        default="",
        required=False,
        help="generate excel file",
    )
    ap.add_argument(
        "-p", "--parameters", type=str, required=False, default="{}"
    )
    args = ap.parse_args()
    return (
        args.parameters,
        args.excel,
    )


def get_classifier(model, random_state, hyperparameters):
    """Return an ``Stree`` built with *random_state* and *hyperparameters*.

    *hyperparameters* is a dict applied through ``set_params``. *model* is
    accepted for interface compatibility but is not used here.
    """
    clf = Stree(random_state=random_state)
    clf.set_params(**hyperparameters)
    return clf


def process_dataset(dataset, verbose, model, params):
    """Cross-validate the classifier on *dataset* once per random seed.

    The data is loaded through the module-level ``dt`` object and a shuffled
    5-fold ``KFold`` is run for every seed in the module-level
    ``random_seeds``. *params* is a JSON string with the hyperparameters.
    *verbose* is not used by this function.

    Returns a tuple ``(scores, times, hyperparameters, nodes, leaves,
    depths)``: ``scores`` and ``times`` hold one ``test_score`` /
    ``fit_time`` entry per seed, ``hyperparameters`` is the re-serialized
    JSON string, and ``nodes``, ``leaves`` and ``depths`` hold one value per
    fitted estimator.
    """
    X, y = dt.load(dataset)
    scores = []
    times = []
    nodes = []
    leaves = []
    depths = []
    hyperparameters = json.loads(params)
    for random_state in random_seeds:
        # Seed every random source so each run is repeatable.
        random.seed(random_state)
        np.random.seed(random_state)
        kfold = KFold(shuffle=True, random_state=random_state, n_splits=5)
        clf = get_classifier(model, random_state, hyperparameters)
        res = cross_validate(clf, X, y, cv=kfold, return_estimator=True)
        scores.append(res["test_score"])
        times.append(res["fit_time"])
        for fitted in res["estimator"]:
            nodes_item, leaves_item = fitted.nodes_leaves()
            nodes.append(nodes_item)
            leaves.append(leaves_item)
            depths.append(fitted.depth_)
    return scores, times, json.dumps(hyperparameters), nodes, leaves, depths


def excel_write_line(
    book,
    sheet,
    name,
    samples,
    features,
    classes,
    accuracy,
    times,
    hyperparameters,
    complexity,
    status,
):
    """Write the results of one dataset on the next free row of *sheet*.

    The current row is kept in the function attribute
    ``excel_write_line.row``: the first call uses the row right below the
    column titles and every later call advances by one. *complexity* is a
    dict with the keys ``nodes``, ``leaves`` and ``depth``; *status* is the
    console marker, which is translated with ``excel_status`` before writing.
    """
    try:
        excel_write_line.row += 1
    except AttributeError:
        # First call: start right below the column titles.
        excel_write_line.row = EXCEL_HEADER_ROW + 1
    row = excel_write_line.row
    size_n = 14
    decimal = book.add_format({"num_format": "0.000000", "font_size": size_n})
    integer = book.add_format({"num_format": "#,###", "font_size": size_n})
    normal = book.add_format({"font_size": size_n})
    status, _ = excel_status(status)
    # One (value, format) pair per column, in sheet order.
    cells = [
        (name, normal),
        (samples, integer),
        (features, normal),
        (classes, normal),
        (complexity["nodes"], normal),
        (complexity["leaves"], normal),
        (complexity["depth"], normal),
        (accuracy, decimal),
        (status, normal),
        (np.mean(times), decimal),
        (hyperparameters, normal),
    ]
    for col, (value, cell_format) in enumerate(cells):
        sheet.write(row, col, value, cell_format)


def excel_write_header(book, sheet, parameters):
    """Write the title lines and the column titles, and set column widths."""
    header = book.add_format()
    header.set_font_size(18)
    subheader = book.add_format()
    subheader.set_font_size(16)
    sheet.write(
        0,
        0,
        f"Process all datasets set with STree: Parameters: {parameters} ",
        header,
    )
    sheet.write(
        1,
        0,
        "5 Fold Cross Validation with 10 random seeds",
        subheader,
    )
    sheet.write(1, 5, f"{random_seeds}", subheader)
    # (title, column width) for every column of the report.
    header_cols = [
        ("Dataset", 30),
        ("Samples", 10),
        ("Variables", 7),
        ("Classes", 7),
        ("Nodes", 7),
        ("Leaves", 7),
        ("Depth", 7),
        ("Accuracy", 10),
        ("Stat", 3),
        ("Time", 10),
        ("Parameters", 50),
    ]
    bold = book.add_format({"bold": True, "font_size": 14})
    for col, (item, length) in enumerate(header_cols):
        sheet.write(EXCEL_HEADER_ROW, col, item, bold)
        sheet.set_column(col, col, length)


def excel_status(status):
    """Translate a console status marker into ``(symbol, description)``.

    The console markers built in ``compute_status`` wrap the symbol in
    ``TextColor`` codes; this returns a plain symbol for the sheet together
    with its legend text. A blank status maps to ``(" ", "")``.
    """
    if status == TextColor.GREEN + CHECK_MARK + TextColor.ENDC:
        return EXCLAMATION_MARK, "Accuracy better than stree optimized"
    elif status == TextColor.RED + BLACK_STAR + TextColor.ENDC:
        return BLACK_STAR, "Best accuracy of al models"
    elif status != " ":
        return CHECK_MARK, "Accuracy better than original stree_default"
    return " ", ""


def _accuracy_summary(accuracy_total):
    """Return the summary line relating *accuracy_total* to the reference."""
    return (
        f"** Accuracy compared to stree_default (liblinear-ovr) .: "
        f"{accuracy_total/REFERENCE_ACCURACY_TOTAL:7.4f}"
    )


def excel_write_totals(book, sheet, totals, start, accuracy_total):
    """Write the status counters, the accuracy summary and the elapsed time.

    The block starts two rows below the last row written by
    ``excel_write_line``. If no data row was written yet the column-title
    row is used as the reference instead of failing. *totals* maps console
    status markers to their counts and *start* is the ``time.time()`` value
    taken when processing began.
    """
    bold = book.add_format({"bold": True, "font_size": 16})
    last_row = getattr(excel_write_line, "row", EXCEL_HEADER_ROW)
    row = last_row + 2
    for key, total in totals.items():
        status, text = excel_status(key)
        sheet.write(row, 1, status, bold)
        sheet.write(row, 2, total, bold)
        sheet.write(row, 3, text, bold)
        row += 1
    sheet.write(row + 1, 0, _accuracy_summary(accuracy_total), bold)
    time_spent = get_time(start, time.time())
    sheet.write(row + 3, 0, time_spent, bold)


def compute_status(dbh, name, model, accuracy):
    """Return the console marker describing how *accuracy* ranks.

    The accuracy, rounded to 6 digits, is compared with the best stored
    score for *model*, for ``"stree"`` and for the module-level
    ``models_tree`` list. The last comparison that succeeds wins, so the
    result is, in increasing priority: a blank, a plain check mark, a green
    check mark or a red star.
    """
    n_dig = 6
    ac_round = round(accuracy, n_dig)
    best_default, _ = get_best_score(dbh, name, model)
    best_stree, _ = get_best_score(dbh, name, "stree")
    best_all, _ = get_best_score(dbh, name, models_tree)
    status = " "
    if ac_round > round(best_default, n_dig):
        status = CHECK_MARK
    if ac_round > round(best_stree, n_dig):
        status = TextColor.GREEN + CHECK_MARK + TextColor.ENDC
    if ac_round > round(best_all, n_dig):
        status = TextColor.RED + BLACK_STAR + TextColor.ENDC
    return status


def get_best_score(dbh, name, model):
    """Return ``(accuracy, acc_std)`` of the best ``"crossval"`` record.

    Both values are ``0.0`` when ``dbh.find_best`` returns ``None``.
    Otherwise they are read from positions 5 and 11 of the record.
    """
    record = dbh.find_best(name, model, "crossval")
    if record is None:
        return 0.0, 0.0
    return record[5], record[11]


def get_time(start, stop):
    """Format the seconds elapsed between *start* and *stop* as h/m/s."""
    hours, rem = divmod(stop - start, 3600)
    minutes, seconds = divmod(rem, 60)
    return f"Time: {int(hours):2d}h {int(minutes):2d}m {int(seconds):2d}s"


random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
models_tree = [
    "stree",
    "stree_default",
    "wodt",
    "j48svm",
    "oc1",
    "cart",
    "baseRaF",
]
(
    parameters,
    excel,
) = parse_arguments()
dbh = MySQL()
if excel != "":
    file_name = f"{excel}.xlsx"
    excel_wb = xlsxwriter.Workbook(file_name)
    excel_ws = excel_wb.add_worksheet("STree")
    excel_write_header(excel_wb, excel_ws, parameters)
# The returned object is not used below; the call itself is kept.
# NOTE(review): presumably this is what opens the connection -- confirm.
database = dbh.get_connection()
try:
    dt = Datasets(normalize=False, standardize=False, set_of_files="tanveer")
    start = time.time()
    print(f"* Process all datasets set with STree: {parameters}")
    print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}\n")
    header_cols = [
        "Dataset",
        "Samp",
        "Var",
        "Cls",
        "Nodes",
        "Leaves",
        "Depth",
        "Accuracy",
        "Time",
        "Parameters",
    ]
    header_lengths = [30, 5, 3, 3, 7, 7, 7, 15, 15, 10]
    model = "stree_default"
    # Normalize the JSON text so it matches what process_dataset returns.
    parameters = json.dumps(json.loads(parameters))
    if parameters != "{}" and len(parameters) > 10:
        # Widen the last column so the parameters fit under their title.
        header_lengths[-1] = len(parameters)
    for field, width in zip(header_cols, header_lengths):
        print(f"{field:{width}s} ", end="")
    line_col = "".join("=" * width + " " for width in header_lengths)
    print(f"\n{line_col}")
    totals = {}
    accuracy_total = 0.0
    for dataset in dt:
        name = dataset[0]
        X, y = dt.load(name)  # type: ignore
        samples, features = X.shape
        classes = len(np.unique(y))
        print(
            f"{name:30s} {samples:5d} {features:3d} {classes:3d} ",
            end="",
        )
        scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
            name, verbose=False, model=model, params=parameters
        )
        complexity = dict(
            nodes=float(np.mean(nodes)),
            leaves=float(np.mean(leaves)),
            depth=float(np.mean(depth)),
        )
        nodes_item, leaves_item, depth_item = complexity.values()
        print(
            f"{nodes_item:7.2f} {leaves_item:7.2f} {depth_item:7.2f} ",
            end="",
        )
        accuracy = np.mean(scores)
        accuracy_total += accuracy
        status = (
            compute_status(dbh, name, model, accuracy)
            if model == "stree_default"
            else " "
        )
        if status != " ":
            totals[status] = totals.get(status, 0) + 1
        print(f"{accuracy:8.6f}±{np.std(scores):6.4f}{status}", end="")
        print(f"{np.mean(times):8.6f}±{np.std(times):6.4f} {hyperparameters}")
        if excel != "":
            excel_write_line(
                excel_wb,
                excel_ws,
                name,
                samples,
                features,
                classes,
                accuracy,
                times,
                hyperparameters,
                complexity,
                status,
            )
    for key, value in totals.items():
        print(f"{key} .....: {value:2d}")
    print(_accuracy_summary(accuracy_total))
    if excel != "":
        excel_write_totals(excel_wb, excel_ws, totals, start, accuracy_total)
        excel_wb.close()
    stop = time.time()
    time_spent = get_time(start, stop)
    print(f"{time_spent}")
finally:
    # Release the database handle even when a dataset fails to process.
    dbh.close()