"""Print a table with the best result of every model for each dataset.

Results are read from the experiments MySQL database (via ``MySQL.find_best``);
optionally a CSV file with one ``dataset, classifier, accuracy`` row per
result found is written to ``report_csv``.
"""
import argparse
from typing import Tuple

import numpy as np

from experimentation.Sets import Datasets
from experimentation.Utils import TextColor
from experimentation.Database import MySQL

report_csv = "report.csv"
models_tree = [
    "stree",
    "stree_default",
    "wodt",
    "j48svm",
    "oc1",
    "cart",
    "baseRaF",
]
models_ensemble = ["odte", "adaBoost", "bagging", "TBRaF", "TBRoF", "TBRRoF"]
description = ["samp", "var", "cls"]
complexity = ["nodes", "leaves", "depth"]
title = "Best model results"
# Column widths: dataset name, the three description and three complexity
# columns, then one 12-char column per model. The model part is trimmed
# below so that there is exactly one width per model actually displayed.
lengths = [30, 4, 3, 3, 3, 3, 3, 12, 12, 12, 12, 12, 12, 12]


def _str_to_bool(value: str) -> bool:
    """Convert a command-line string into a bool.

    ``argparse`` with ``type=bool`` maps every non-empty string (including
    ``"False"``) to ``True``, so the value is interpreted explicitly here.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}"
    )


def parse_arguments() -> Tuple[str, str, bool, bool]:
    """Parse the command line.

    Returns:
        ``(experiment, model_type, csv_output, compare)``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "-e",
        "--experiment",
        type=str,
        choices=["gridsearch", "crossval", "any"],
        required=False,
        default="gridsearch",
    )
    ap.add_argument(
        "-m",
        "--model",
        type=str,
        choices=["tree", "ensemble"],
        required=False,
        default="tree",
    )
    # nargs="?" + const=True: a bare "-c" enables the option, while the
    # historical "-c True" / "-c False" forms keep working (and "False"
    # now really means False).
    ap.add_argument(
        "-c",
        "--csv-output",
        type=_str_to_bool,
        nargs="?",
        const=True,
        required=False,
        default=False,
    )
    ap.add_argument(
        "-o",
        "--compare",
        type=_str_to_bool,
        nargs="?",
        const=True,
        required=False,
        default=False,
    )
    args = ap.parse_args()
    return (args.experiment, args.model, args.csv_output, args.compare)


def report_header_content(title, experiment, model_type):
    """Build the report banner plus the column-title and underline rows.

    Reads the module-level ``fields`` (column titles) and ``lengths``
    (column widths).
    """
    length = sum(lengths) + len(lengths) - 1
    banner = f"{title} - Experiment: {experiment} - Models: {model_type}"
    inner = length - 2  # room between the two '*' borders
    left = (inner - len(banner) + 1) // 2
    right = inner - len(banner) - left
    output = "\n" + "*" * length + "\n"
    output += "*" + " " * left + banner + " " * right + "*\n"
    output += "*" * length + "\n\n"
    columns = list(zip(fields, lengths))
    output += "".join(f"{name:^{width}} " for name, width in columns)
    lines = "".join("=" * width + " " for _, width in columns)
    output += f"\n{lines}"
    return output


def report_header(title, experiment, model_type):
    """Print the report header using the terminal header color."""
    print(
        TextColor.HEADER
        + report_header_content(title, experiment, model_type)
        + TextColor.ENDC
    )


def report_line(line):
    """Format one table row from ``line`` (dict keyed by column name).

    Reads the module-level ``models`` and ``lengths``.
    """
    # NOTE(review): the +5 presumably compensates for the color escape code
    # prefixed to the dataset name - confirm against TextColor.
    output = f"{line['dataset']:{lengths[0] + 5}s} "
    for position, key in enumerate(description + complexity, start=1):
        output += f"{line[key]:{lengths[position]}d} "
    for position, model in enumerate(models, start=7):
        output += f"{line[model]:{lengths[position]}s} "
    return output


def report_footer(agg):
    """Print a rule and, for each model, how many times it was the best."""
    length = sum(lengths) + len(lengths) - 1
    print("-" * length)
    color = TextColor.LINE1
    for item in models:
        print(color + f"{item:10s} ", end="")
        print(color + f"best of models {agg[item]['best']:2d} times")
        color = (
            TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1
        )


(experiment, model_type, csv_output, compare) = parse_arguments()
dbh = MySQL()
database = dbh.get_connection()
dt = Datasets(False, False, "tanveer")
fields = (
    "Dataset",
    "Samp",
    "Var",
    "Cls",
    "Nod",
    "Lea",
    "Dep",
)
if not compare:
    # stree_default is only shown when comparing
    models_tree.remove("stree_default")
models = models_tree if model_type == "tree" else models_ensemble
# Keep exactly one width per displayed column, so banner and footer rule
# widths match the table for every model family.
del lengths[len(fields) + len(models):]
fields += tuple(models)
report_header(title, experiment, model_type)
color = TextColor.LINE1
agg = {model: {"best": 0} for model in models}
csv_file = None
try:
    if csv_output:
        csv_file = open(report_csv, "w", encoding="utf-8")
        print("dataset, classifier, accuracy", file=csv_file)
    for dataset in dt:
        find_one = False
        line = {"dataset": color + dataset[0]}
        X, y = dt.load(dataset[0])  # type: ignore
        line["samp"], line["var"] = X.shape
        line["cls"] = len(np.unique(y))
        # Best accuracy among all the models, used to highlight the winner(s)
        record = dbh.find_best(dataset[0], models, experiment)
        max_accuracy = 0.0 if record is None else record[5]
        line["nodes"] = 0
        line["leaves"] = 0
        line["depth"] = 0
        for model in models:
            record = dbh.find_best(dataset[0], model, experiment)
            if record is None:
                line[model] = color + "-" * 12
                continue
            if model == "stree":
                # Complexity columns are only reported for stree
                line["nodes"] = record[12]
                line["leaves"] = record[13]
                line["depth"] = record[14]
            accuracy = record[5]
            acc_std = record[11]
            find_one = True
            item = f"{accuracy:.4f}±{acc_std:.3f}"
            if accuracy == max_accuracy:
                line[model] = (
                    TextColor.GREEN + TextColor.BOLD + item + TextColor.ENDC
                )
                agg[model]["best"] += 1
            else:
                line[model] = color + item
            if csv_file is not None:
                print(f"{dataset[0]}, {model}, {accuracy}", file=csv_file)
        if not find_one:
            # Reset the color so FAIL does not leak into later output
            print(
                TextColor.FAIL
                + f"*No results found for {dataset[0]}"
                + TextColor.ENDC
            )
        else:
            color = (
                TextColor.LINE2
                if color == TextColor.LINE1
                else TextColor.LINE1
            )
            print(report_line(line))
    report_footer(agg)
finally:
    # Release the CSV file and the DB connection even if reporting fails
    if csv_file is not None:
        csv_file.close()
    dbh.close()
if csv_file is not None:
    print(f"{report_csv} file generated")