Update analysis and report mysql

This commit is contained in:
2021-03-05 10:41:18 +01:00
parent 116db3f528
commit a075e5e95a
3 changed files with 46 additions and 31 deletions

View File

@@ -4,9 +4,10 @@ from experimentation.Sets import Datasets
from experimentation.Utils import TextColor from experimentation.Utils import TextColor
from experimentation.Database import MySQL from experimentation.Database import MySQL
models = ["stree", "oc1", "cart", "odte", "adaBoost", "bagging"] models_tree = ["stree", "oc1", "cart"]
models_ensemble = ["odte", "adaBoost", "bagging"]
title = "Best model results" title = "Best model results"
lengths = (30, 9, 11, 11, 11, 11, 11, 11) lengths = (30, 9, 11, 11, 11)
def parse_arguments() -> Tuple[str, str, str, bool, bool]: def parse_arguments() -> Tuple[str, str, str, bool, bool]:
@@ -15,23 +16,32 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]:
"-e", "-e",
"--experiment", "--experiment",
type=str, type=str,
choices=["any", "gridsearch", "crossval"], choices=["gridsearch", "crossval", "any"],
required=False, required=False,
default="any", default="gridsearch",
)
ap.add_argument(
"-m",
"--model",
type=str,
choices=["tree", "ensemble"],
required=False,
default="tree",
) )
args = ap.parse_args() args = ap.parse_args()
return args.experiment return (args.experiment, args.model)
def report_header_content(title, experiment): def report_header_content(title, experiment, model_type):
length = sum(lengths) + len(lengths) - 1 length = sum(lengths) + len(lengths) - 1
output = "\n" + "*" * length + "\n" output = "\n" + "*" * length + "\n"
num = (length - len(title) - len(experiment) - 2) // 2 titles_length = len(title) + len(experiment) + len(model_type) + 21
num2 = length - len(title) - len(experiment) - 5 - 2 * num num = (length - titles_length) // 2 - 3
num2 = length - titles_length - 7 - 2 * num
output += ( output += (
"*" "*"
+ " " * num + " " * num
+ f"{title} - {experiment}" + f"{title} - Experiment: {experiment} - Models: {model_type}"
+ " " * (num + num2) + " " * (num + num2)
+ "*\n" + "*\n"
) )
@@ -44,10 +54,10 @@ def report_header_content(title, experiment):
return output return output
def report_header(title, experiment): def report_header(title, experiment, model_type):
print( print(
TextColor.HEADER TextColor.HEADER
+ report_header_content(title, experiment) + report_header_content(title, experiment, model_type)
+ TextColor.ENDC + TextColor.ENDC
) )
@@ -85,14 +95,15 @@ def report_footer(agg):
) )
experiment = parse_arguments() (experiment, model_type) = parse_arguments()
dbh = MySQL() dbh = MySQL()
database = dbh.get_connection() database = dbh.get_connection()
dt = Datasets(False, False, "tanveer") dt = Datasets(False, False, "tanveer")
fields = ("Dataset", "Reference") fields = ("Dataset", "Reference")
for model in models: models = models_tree if model_type == "tree" else models_ensemble
fields += (f"{model}",) for item in models:
report_header(title, experiment) fields += (f"{item}",)
report_header(title, experiment, model_type)
color = TextColor.LINE1 color = TextColor.LINE1
agg = {} agg = {}
for item in [ for item in [
@@ -107,7 +118,7 @@ for dataset in dt:
find_one = False find_one = False
# Look for max accuracy for any given dataset # Look for max accuracy for any given dataset
line = {"dataset": color + dataset[0]} line = {"dataset": color + dataset[0]}
record = dbh.find_best(dataset[0], "any", experiment) record = dbh.find_best(dataset[0], models, experiment)
max_accuracy = 0.0 if record is None else record[5] max_accuracy = 0.0 if record is None else record[5]
for model in models: for model in models:
record = dbh.find_best(dataset[0], model, experiment) record = dbh.find_best(dataset[0], model, experiment)

View File

@@ -46,18 +46,21 @@ class MySQL:
def find_best(self, dataset, classifier="any", experiment="any"): def find_best(self, dataset, classifier="any", experiment="any"):
cursor = self._database.cursor(buffered=True) cursor = self._database.cursor(buffered=True)
if classifier == "any": date_from = "2021-01-20"
command = ( command = (
f"select * from results r inner join reference e on " f"select * from results r inner join reference e on "
f"r.dataset=e.dataset where r.dataset='{dataset}' and " f"r.dataset=e.dataset where r.dataset='{dataset}'"
f"date>='2021-01-20'" )
) if isinstance(classifier, list):
else: classifier_set = "("
command = ( for i, item in enumerate(classifier):
f"select * from results r inner join reference e on " comma = "" if i == 0 else ","
f"r.dataset=e.dataset where r.dataset='{dataset}' and " classifier_set += f"{comma}'{item}'"
f"classifier='{classifier}' and date>='2021-01-20'" classifier_set += ")"
) command += f" and r.classifier in {classifier_set}"
elif classifier != "any":
command += f" and r.classifier='{classifier}'"
command += f" and date>='{date_from}'"
command += "" if experiment == "any" else f" and type='{experiment}'" command += "" if experiment == "any" else f" and type='{experiment}'"
command += ( command += (
" order by r.dataset, accuracy desc, classifier desc, " " order by r.dataset, accuracy desc, classifier desc, "
@@ -182,7 +185,8 @@ class BD(ABC):
database = dbh.get_connection() database = dbh.get_connection()
command_insert = ( command_insert = (
"replace into results (date, time, type, accuracy, " "replace into results (date, time, type, accuracy, "
"dataset, classifier, norm, stand, parameters, accuracy_std, time_spent, time_spent_std) values (%s, %s, " "dataset, classifier, norm, stand, parameters, accuracy_std, "
"time_spent, time_spent_std) values (%s, %s, "
"%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)" "%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"
) )
now = datetime.now() now = datetime.now()

View File

@@ -54,7 +54,7 @@ def report_header(exclude_params):
def report_line(record, agg): def report_line(record, agg):
accuracy = record[5] accuracy = record[5]
expected = record[10] expected = record[13]
if accuracy < expected: if accuracy < expected:
agg["worse"] += 1 agg["worse"] += 1
sign = "-" sign = "-"
@@ -84,7 +84,7 @@ def report_footer(agg):
TextColor.MAGENTA + f"we have equal results {agg['equal']:2d} times" TextColor.MAGENTA + f"we have equal results {agg['equal']:2d} times"
) )
color = TextColor.LINE1 color = TextColor.LINE1
for item in ["stree", "bagging", "adaBoost", "odte"]: for item in models:
print(color + f"{item:10s} used {agg[item]:2d} times") print(color + f"{item:10s} used {agg[item]:2d} times")
color = ( color = (
TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1 TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1