Update analysis and report mysql: add tree/ensemble model selection and let find_best filter by a list of classifiers

This commit is contained in:
2021-03-05 10:41:18 +01:00
parent 116db3f528
commit a075e5e95a
3 changed files with 46 additions and 31 deletions

View File

@@ -4,9 +4,10 @@ from experimentation.Sets import Datasets
from experimentation.Utils import TextColor
from experimentation.Database import MySQL
models = ["stree", "oc1", "cart", "odte", "adaBoost", "bagging"]
models_tree = ["stree", "oc1", "cart"]
models_ensemble = ["odte", "adaBoost", "bagging"]
title = "Best model results"
lengths = (30, 9, 11, 11, 11, 11, 11, 11)
lengths = (30, 9, 11, 11, 11)
def parse_arguments() -> Tuple[str, str, str, bool, bool]:
@@ -15,23 +16,32 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]:
"-e",
"--experiment",
type=str,
choices=["any", "gridsearch", "crossval"],
choices=["gridsearch", "crossval", "any"],
required=False,
default="any",
default="gridsearch",
)
ap.add_argument(
"-m",
"--model",
type=str,
choices=["tree", "ensemble"],
required=False,
default="tree",
)
args = ap.parse_args()
return args.experiment
return (args.experiment, args.model)
def report_header_content(title, experiment):
def report_header_content(title, experiment, model_type):
length = sum(lengths) + len(lengths) - 1
output = "\n" + "*" * length + "\n"
num = (length - len(title) - len(experiment) - 2) // 2
num2 = length - len(title) - len(experiment) - 5 - 2 * num
titles_length = len(title) + len(experiment) + len(model_type) + 21
num = (length - titles_length) // 2 - 3
num2 = length - titles_length - 7 - 2 * num
output += (
"*"
+ " " * num
+ f"{title} - {experiment}"
+ f"{title} - Experiment: {experiment} - Models: {model_type}"
+ " " * (num + num2)
+ "*\n"
)
@@ -44,10 +54,10 @@ def report_header_content(title, experiment):
return output
def report_header(title, experiment):
def report_header(title, experiment, model_type):
print(
TextColor.HEADER
+ report_header_content(title, experiment)
+ report_header_content(title, experiment, model_type)
+ TextColor.ENDC
)
@@ -85,14 +95,15 @@ def report_footer(agg):
)
experiment = parse_arguments()
(experiment, model_type) = parse_arguments()
dbh = MySQL()
database = dbh.get_connection()
dt = Datasets(False, False, "tanveer")
fields = ("Dataset", "Reference")
for model in models:
fields += (f"{model}",)
report_header(title, experiment)
models = models_tree if model_type == "tree" else models_ensemble
for item in models:
fields += (f"{item}",)
report_header(title, experiment, model_type)
color = TextColor.LINE1
agg = {}
for item in [
@@ -107,7 +118,7 @@ for dataset in dt:
find_one = False
# Look for max accuracy for any given dataset
line = {"dataset": color + dataset[0]}
record = dbh.find_best(dataset[0], "any", experiment)
record = dbh.find_best(dataset[0], models, experiment)
max_accuracy = 0.0 if record is None else record[5]
for model in models:
record = dbh.find_best(dataset[0], model, experiment)

View File

@@ -46,18 +46,21 @@ class MySQL:
def find_best(self, dataset, classifier="any", experiment="any"):
cursor = self._database.cursor(buffered=True)
if classifier == "any":
command = (
f"select * from results r inner join reference e on "
f"r.dataset=e.dataset where r.dataset='{dataset}' and "
f"date>='2021-01-20'"
)
else:
command = (
f"select * from results r inner join reference e on "
f"r.dataset=e.dataset where r.dataset='{dataset}' and "
f"classifier='{classifier}' and date>='2021-01-20'"
)
date_from = "2021-01-20"
command = (
f"select * from results r inner join reference e on "
f"r.dataset=e.dataset where r.dataset='{dataset}'"
)
if isinstance(classifier, list):
classifier_set = "("
for i, item in enumerate(classifier):
comma = "" if i == 0 else ","
classifier_set += f"{comma}'{item}'"
classifier_set += ")"
command += f" and r.classifier in {classifier_set}"
elif classifier != "any":
command += f" and r.classifier='{classifier}'"
command += f" and date>='{date_from}'"
command += "" if experiment == "any" else f" and type='{experiment}'"
command += (
" order by r.dataset, accuracy desc, classifier desc, "
@@ -182,7 +185,8 @@ class BD(ABC):
database = dbh.get_connection()
command_insert = (
"replace into results (date, time, type, accuracy, "
"dataset, classifier, norm, stand, parameters, accuracy_std, time_spent, time_spent_std) values (%s, %s, "
"dataset, classifier, norm, stand, parameters, accuracy_std, "
"time_spent, time_spent_std) values (%s, %s, "
"%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"
)
now = datetime.now()

View File

@@ -54,7 +54,7 @@ def report_header(exclude_params):
def report_line(record, agg):
accuracy = record[5]
expected = record[10]
expected = record[13]
if accuracy < expected:
agg["worse"] += 1
sign = "-"
@@ -84,7 +84,7 @@ def report_footer(agg):
TextColor.MAGENTA + f"we have equal results {agg['equal']:2d} times"
)
color = TextColor.LINE1
for item in ["stree", "bagging", "adaBoost", "odte"]:
for item in models:
print(color + f"{item:10s} used {agg[item]:2d} times")
color = (
TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1