Files
stree_datasets/analysis_mysql.py
Ricardo Montañana 116db3f528 Add oc1 and cart models
Update analysis and report mysql
2021-03-03 12:47:47 +01:00

145 lines
4.1 KiB
Python

import argparse
from typing import Tuple
from experimentation.Sets import Datasets
from experimentation.Utils import TextColor
from experimentation.Database import MySQL
models = ["stree", "oc1", "cart", "odte", "adaBoost", "bagging"]
title = "Best model results"
lengths = (30, 9, 11, 11, 11, 11, 11, 11)
def parse_arguments() -> Tuple[str, str, str, bool, bool]:
ap = argparse.ArgumentParser()
ap.add_argument(
"-e",
"--experiment",
type=str,
choices=["any", "gridsearch", "crossval"],
required=False,
default="any",
)
args = ap.parse_args()
return args.experiment
def report_header_content(title, experiment):
length = sum(lengths) + len(lengths) - 1
output = "\n" + "*" * length + "\n"
num = (length - len(title) - len(experiment) - 2) // 2
num2 = length - len(title) - len(experiment) - 5 - 2 * num
output += (
"*"
+ " " * num
+ f"{title} - {experiment}"
+ " " * (num + num2)
+ "*\n"
)
output += "*" * length + "\n\n"
lines = ""
for item, data in enumerate(fields):
output += f"{fields[item]:{lengths[item]}} "
lines += "=" * lengths[item] + " "
output += f"\n{lines}"
return output
def report_header(title, experiment):
print(
TextColor.HEADER
+ report_header_content(title, experiment)
+ TextColor.ENDC
)
def report_line(line):
output = f"{line['dataset']:{lengths[0] + 5}s} "
data = models.copy()
data.insert(0, "reference")
for key, model in enumerate(data):
output += f"{line[model]:{lengths[key + 1]}s} "
return output
def report_footer(agg):
print(
TextColor.GREEN
+ f"we have better results {agg['better']['items']:2d} times"
)
print(
TextColor.RED
+ f"we have worse results {agg['worse']['items']:2d} times"
)
color = TextColor.LINE1
for item in models:
print(
color + f"{item:10s} used {agg[item]['items']:2d} times ", end=""
)
print(
color + f"better {agg[item]['better']:2d} times ",
end="",
)
print(color + f"worse {agg[item]['worse']:2d} times ")
color = (
TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1
)
experiment = parse_arguments()
dbh = MySQL()
database = dbh.get_connection()
dt = Datasets(False, False, "tanveer")
fields = ("Dataset", "Reference")
for model in models:
fields += (f"{model}",)
report_header(title, experiment)
color = TextColor.LINE1
agg = {}
for item in [
"better",
"worse",
] + models:
agg[item] = {}
agg[item]["items"] = 0
agg[item]["better"] = 0
agg[item]["worse"] = 0
for dataset in dt:
find_one = False
# Look for max accuracy for any given dataset
line = {"dataset": color + dataset[0]}
record = dbh.find_best(dataset[0], "any", experiment)
max_accuracy = 0.0 if record is None else record[5]
for model in models:
record = dbh.find_best(dataset[0], model, experiment)
if record is None:
line[model] = color + "-" * 9 + " "
else:
reference = record[13]
accuracy = record[5]
find_one = True
agg[model]["items"] += 1
if accuracy > reference:
sign = "+"
agg["better"]["items"] += 1
agg[model]["better"] += 1
else:
sign = "-"
agg["worse"]["items"] += 1
agg[model]["worse"] += 1
item = f"{accuracy:9.7} {sign}"
line["reference"] = f"{reference:9.7}"
line[model] = (
TextColor.GREEN + TextColor.BOLD + item + TextColor.ENDC
if accuracy == max_accuracy
else color + item
)
if not find_one:
print(TextColor.FAIL + f"*No results found for {dataset[0]}")
else:
color = (
TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1
)
print(report_line(line))
report_footer(agg)
dbh.close()