Files
stree_datasets/report_mysql.py

157 lines
3.8 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
from typing import Tuple
from experimentation.Sets import Datasets
from experimentation.Utils import TextColor
from experimentation.Database import MySQL
# Classifier names offered as --model choices (together with "any") and
# tallied one by one in the report footer.
models = [
    "stree", "stree_default", "adaBoost",
    "bagging", "odte", "cart",
    "oc1", "j48svm", "wodt",
]
def parse_arguments() -> Tuple[str, bool, str]:
    """Parse the command-line options of the report script.

    Options:
        -m/--model: one of ``"any"`` or the names in ``models``
            (default ``"any"``).
        -e/--experiment: ``"gridsearch"`` or ``"crossval"``
            (default ``"crossval"``).
        -x/--excludeparams: flag; when given, parameters are left out of
            the report.

    Returns:
        A ``(model, excludeparams, experiment)`` tuple, in that order.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "-m",
        "--model",
        type=str,
        choices=["any"] + models,
        required=False,
        default="any",
    )
    ap.add_argument(
        "-e",
        "--experiment",
        type=str,
        choices=["gridsearch", "crossval"],
        required=False,
        default="crossval",
    )
    ap.add_argument(
        "-x",
        "--excludeparams",
        default=False,
        required=False,
        action="store_true",
        help="Exclude parameters in reports",
    )
    args = ap.parse_args()
    # The order matters: the caller unpacks this tuple positionally.
    return (
        args.model,
        args.excludeparams,
        args.experiment,
    )
def report_header_content(title):
    """Build the banner and column headings shown at the top of the report.

    Reads the module-level ``lengths``, ``fields`` and ``classifier``.

    Args:
        title: Text centered in the banner; a " -- <classifier> classifier --"
            suffix is appended to it.

    Returns:
        The header as one string: a starred banner, a blank line, the column
        names and a row of ``=`` underlines (no trailing newline).
    """
    # Total width: every column plus one separating space between columns.
    length = sum(lengths) + len(lengths) - 1
    border = "*" * length
    caption = title + f" -- {classifier} classifier --"
    # Room left between the two border stars; an odd leftover space goes to
    # the right-hand side.
    padding = length - len(caption) - 2
    left = padding // 2
    right = padding - left
    # Pair each column name with its width instead of indexing both tuples.
    columns = list(zip(fields, lengths))
    names = "".join(f"{name:{width}} " for name, width in columns)
    rules = "".join("=" * width + " " for _, width in columns)
    return (
        f"\n{border}\n"
        f"*{' ' * left}{caption}{' ' * right}*\n"
        f"{border}\n\n"
        f"{names}\n{rules}"
    )
def report_header(exclude_params):
    """Print the report header wrapped in the HEADER/ENDC color codes.

    NOTE(review): ``exclude_params`` is never read here; the header text is
    built from the module-level ``title``. The parameter is kept so existing
    calls keep working.
    """
    header_text = report_header_content(title)
    print(TextColor.HEADER + header_text + TextColor.ENDC)
def report_line(record, agg):
    """Format one result row and update the running tallies in *agg*.

    Args:
        record: Indexable row. Index 5 is the accuracy and index 16 the
            value it is compared against; index 3 is the classifier name
            and index 8 the parameters. Index 0 is formatted as a date and
            indexes 6 and 7 as integers.
        agg: Mutable mapping of counters. The ``"better"``/``"worse"``/
            ``"equal"`` entry and the entry for the row's classifier are
            each incremented by one.

    Returns:
        The formatted line. The parameters are appended unless the
        module-level ``exclude_parameters`` is true.
    """
    accuracy = record[5]
    expected = record[16]
    if accuracy < expected:
        outcome, sign = "worse", "-"
    elif accuracy > expected:
        outcome, sign = "better", "+"
    else:
        outcome, sign = "equal", "="
    model = record[3]
    # Start missing counters at zero: a row whose classifier was not
    # pre-seeded in agg would otherwise abort the report with a KeyError.
    agg[outcome] = agg.get(outcome, 0) + 1
    agg[model] = agg.get(model, 0) + 1
    output = (
        f"{record[0]:%Y-%m-%d} {str(record[1]):>8s} {record[2]:10s} "
        f"{model:10s} {record[4]:30s} "
        f"{record[6]:3d} {record[7]:3d} {accuracy:8.7f} {expected:8.7f} "
        f"{sign}"
    )
    if not exclude_parameters:
        output += f" {record[8]}"
    return output
def report_footer(agg):
    """Print the summary counters gathered while building the report.

    Prints the better/worse/equal totals first, then one usage line per
    entry of the module-level ``models``, alternating between the LINE1 and
    LINE2 colors and starting with LINE1.

    Args:
        agg: Mapping holding integer counters for ``"better"``, ``"worse"``,
            ``"equal"`` and every name in ``models``.
    """
    print(TextColor.GREEN + f"we have better results {agg['better']:2d} times")
    print(TextColor.RED + f"we have worse results {agg['worse']:2d} times")
    print(
        TextColor.MAGENTA + f"we have equal results {agg['equal']:2d} times"
    )
    for position, name in enumerate(models):
        # Even rows use LINE1, odd rows LINE2.
        shade = TextColor.LINE2 if position % 2 else TextColor.LINE1
        print(shade + f"{name:10s} used {agg[name]:2d} times")
# Script body. Every name below stays at module level because the report_*
# functions read classifier, exclude_parameters, title, lengths and fields
# as globals.
(
    classifier,
    exclude_parameters,
    experiment,
) = parse_arguments()
dbh = MySQL()
# The returned handle is not used below; the call is kept because it may be
# what actually opens the connection -- TODO confirm against MySQL.
database = dbh.get_connection()
try:
    dt = Datasets(False, False, "tanveer")
    title = "Best Hyperparameters found for datasets"
    # Column widths and names, position by position.
    lengths = (10, 8, 10, 10, 30, 3, 3, 9, 11)
    fields = (
        "Date",
        "Time",
        "Type",
        "Classifier",
        "Dataset",
        "Nor",
        "Std",
        "Accuracy",
        "Reference",
    )
    if not exclude_parameters:
        fields += ("Parameters",)
        lengths += (30,)
    report_header(title)
    color = TextColor.LINE1
    # One zeroed counter per outcome and per known classifier.
    agg = dict.fromkeys(["equal", "better", "worse"] + models, 0)
    for dataset in dt:
        record = dbh.find_best(dataset[0], classifier, experiment=experiment)
        if record is None:
            print(TextColor.FAIL + f"*No results found for {dataset[0]}")
        else:
            # Alternate the row color on every printed result line.
            color = (
                TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1
            )
            print(color + report_line(record, agg))
    report_footer(agg)
finally:
    # Release the database handle even when the report fails midway.
    dbh.close()