Update analysis and report mysql

This commit is contained in:
2021-03-05 10:41:18 +01:00
parent 116db3f528
commit a075e5e95a
3 changed files with 46 additions and 31 deletions

View File

@@ -4,9 +4,10 @@ from experimentation.Sets import Datasets
from experimentation.Utils import TextColor
from experimentation.Database import MySQL
models = ["stree", "oc1", "cart", "odte", "adaBoost", "bagging"]
models_tree = ["stree", "oc1", "cart"]
models_ensemble = ["odte", "adaBoost", "bagging"]
title = "Best model results"
lengths = (30, 9, 11, 11, 11, 11, 11, 11)
lengths = (30, 9, 11, 11, 11)
def parse_arguments() -> Tuple[str, str, str, bool, bool]:
@@ -15,23 +16,32 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]:
"-e",
"--experiment",
type=str,
choices=["any", "gridsearch", "crossval"],
choices=["gridsearch", "crossval", "any"],
required=False,
default="any",
default="gridsearch",
)
ap.add_argument(
"-m",
"--model",
type=str,
choices=["tree", "ensemble"],
required=False,
default="tree",
)
args = ap.parse_args()
return args.experiment
return (args.experiment, args.model)
def report_header_content(title, experiment):
def report_header_content(title, experiment, model_type):
length = sum(lengths) + len(lengths) - 1
output = "\n" + "*" * length + "\n"
num = (length - len(title) - len(experiment) - 2) // 2
num2 = length - len(title) - len(experiment) - 5 - 2 * num
titles_length = len(title) + len(experiment) + len(model_type) + 21
num = (length - titles_length) // 2 - 3
num2 = length - titles_length - 7 - 2 * num
output += (
"*"
+ " " * num
+ f"{title} - {experiment}"
+ f"{title} - Experiment: {experiment} - Models: {model_type}"
+ " " * (num + num2)
+ "*\n"
)
@@ -44,10 +54,10 @@ def report_header_content(title, experiment):
return output
def report_header(title, experiment):
def report_header(title, experiment, model_type):
print(
TextColor.HEADER
+ report_header_content(title, experiment)
+ report_header_content(title, experiment, model_type)
+ TextColor.ENDC
)
@@ -85,14 +95,15 @@ def report_footer(agg):
)
experiment = parse_arguments()
(experiment, model_type) = parse_arguments()
dbh = MySQL()
database = dbh.get_connection()
dt = Datasets(False, False, "tanveer")
fields = ("Dataset", "Reference")
for model in models:
fields += (f"{model}",)
report_header(title, experiment)
models = models_tree if model_type == "tree" else models_ensemble
for item in models:
fields += (f"{item}",)
report_header(title, experiment, model_type)
color = TextColor.LINE1
agg = {}
for item in [
@@ -107,7 +118,7 @@ for dataset in dt:
find_one = False
# Look for max accuracy for any given dataset
line = {"dataset": color + dataset[0]}
record = dbh.find_best(dataset[0], "any", experiment)
record = dbh.find_best(dataset[0], models, experiment)
max_accuracy = 0.0 if record is None else record[5]
for model in models:
record = dbh.find_best(dataset[0], model, experiment)