diff --git a/analysis_mysql.py b/analysis_mysql.py index d3d1076..c8970d5 100644 --- a/analysis_mysql.py +++ b/analysis_mysql.py @@ -4,9 +4,10 @@ from experimentation.Sets import Datasets from experimentation.Utils import TextColor from experimentation.Database import MySQL -models = ["stree", "oc1", "cart", "odte", "adaBoost", "bagging"] +models_tree = ["stree", "oc1", "cart"] +models_ensemble = ["odte", "adaBoost", "bagging"] title = "Best model results" -lengths = (30, 9, 11, 11, 11, 11, 11, 11) +lengths = (30, 9, 11, 11, 11) def parse_arguments() -> Tuple[str, str, str, bool, bool]: @@ -15,23 +16,32 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]: "-e", "--experiment", type=str, - choices=["any", "gridsearch", "crossval"], + choices=["gridsearch", "crossval", "any"], required=False, - default="any", + default="gridsearch", + ) + ap.add_argument( + "-m", + "--model", + type=str, + choices=["tree", "ensemble"], + required=False, + default="tree", ) args = ap.parse_args() - return args.experiment + return (args.experiment, args.model) -def report_header_content(title, experiment): +def report_header_content(title, experiment, model_type): length = sum(lengths) + len(lengths) - 1 output = "\n" + "*" * length + "\n" - num = (length - len(title) - len(experiment) - 2) // 2 - num2 = length - len(title) - len(experiment) - 5 - 2 * num + titles_length = len(title) + len(experiment) + len(model_type) + 21 + num = (length - titles_length) // 2 - 3 + num2 = length - titles_length - 7 - 2 * num output += ( "*" + " " * num - + f"{title} - {experiment}" + + f"{title} - Experiment: {experiment} - Models: {model_type}" + " " * (num + num2) + "*\n" ) @@ -44,10 +54,10 @@ def report_header_content(title, experiment): return output -def report_header(title, experiment): +def report_header(title, experiment, model_type): print( TextColor.HEADER - + report_header_content(title, experiment) + + report_header_content(title, experiment, model_type) + TextColor.ENDC ) @@ -85,14 +95,15 @@ def report_footer(agg): ) -experiment = parse_arguments() +(experiment, model_type) = parse_arguments() dbh = MySQL() database = dbh.get_connection() dt = Datasets(False, False, "tanveer") fields = ("Dataset", "Reference") -for model in models: - fields += (f"{model}",) -report_header(title, experiment) +models = models_tree if model_type == "tree" else models_ensemble +for item in models: + fields += (f"{item}",) +report_header(title, experiment, model_type) color = TextColor.LINE1 agg = {} for item in [ @@ -107,7 +118,7 @@ for dataset in dt: find_one = False # Look for max accuracy for any given dataset line = {"dataset": color + dataset[0]} - record = dbh.find_best(dataset[0], "any", experiment) + record = dbh.find_best(dataset[0], models, experiment) max_accuracy = 0.0 if record is None else record[5] for model in models: record = dbh.find_best(dataset[0], model, experiment) diff --git a/experimentation/Database.py b/experimentation/Database.py index b4014ee..eb32e38 100644 --- a/experimentation/Database.py +++ b/experimentation/Database.py @@ -46,18 +46,21 @@ class MySQL: def find_best(self, dataset, classifier="any", experiment="any"): cursor = self._database.cursor(buffered=True) - if classifier == "any": - command = ( - f"select * from results r inner join reference e on " - f"r.dataset=e.dataset where r.dataset='{dataset}' and " - f"date>='2021-01-20'" - ) - else: - command = ( - f"select * from results r inner join reference e on " - f"r.dataset=e.dataset where r.dataset='{dataset}' and " - f"classifier='{classifier}' and date>='2021-01-20'" - ) + date_from = "2021-01-20" + command = ( + f"select * from results r inner join reference e on " + f"r.dataset=e.dataset where r.dataset='{dataset}'" + ) + if isinstance(classifier, list): + classifier_set = "(" + for i, item in enumerate(classifier): + comma = "" if i == 0 else "," + classifier_set += f"{comma}'{item}'" + classifier_set += ")" + command += f" and r.classifier in {classifier_set}" + elif classifier != "any": + command += f" and r.classifier='{classifier}'" + command += f" and date>='{date_from}'" command += "" if experiment == "any" else f" and type='{experiment}'" command += ( " order by r.dataset, accuracy desc, classifier desc, " @@ -182,7 +185,8 @@ class BD(ABC): database = dbh.get_connection() command_insert = ( "replace into results (date, time, type, accuracy, " - "dataset, classifier, norm, stand, parameters, accuracy_std, time_spent, time_spent_std) values (%s, %s, " + "dataset, classifier, norm, stand, parameters, accuracy_std, " + "time_spent, time_spent_std) values (%s, %s, " "%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)" ) now = datetime.now() diff --git a/report_mysql.py b/report_mysql.py index 384dcbd..85ed7cb 100644 --- a/report_mysql.py +++ b/report_mysql.py @@ -54,7 +54,7 @@ def report_header(exclude_params): def report_line(record, agg): accuracy = record[5] - expected = record[10] + expected = record[13] if accuracy < expected: agg["worse"] += 1 sign = "-" @@ -84,7 +84,7 @@ def report_footer(agg): TextColor.MAGENTA + f"we have equal results {agg['equal']:2d} times" ) color = TextColor.LINE1 - for item in ["stree", "bagging", "adaBoost", "odte"]: + for item in models: print(color + f"{item:10s} used {agg[item]:2d} times") color = ( TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1