mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-15 23:46:03 +00:00
Update analysis and report mysql
This commit is contained in:
@@ -4,9 +4,10 @@ from experimentation.Sets import Datasets
|
|||||||
from experimentation.Utils import TextColor
|
from experimentation.Utils import TextColor
|
||||||
from experimentation.Database import MySQL
|
from experimentation.Database import MySQL
|
||||||
|
|
||||||
models = ["stree", "oc1", "cart", "odte", "adaBoost", "bagging"]
|
models_tree = ["stree", "oc1", "cart"]
|
||||||
|
models_ensemble = ["odte", "adaBoost", "bagging"]
|
||||||
title = "Best model results"
|
title = "Best model results"
|
||||||
lengths = (30, 9, 11, 11, 11, 11, 11, 11)
|
lengths = (30, 9, 11, 11, 11)
|
||||||
|
|
||||||
|
|
||||||
def parse_arguments() -> Tuple[str, str, str, bool, bool]:
|
def parse_arguments() -> Tuple[str, str, str, bool, bool]:
|
||||||
@@ -15,23 +16,32 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]:
|
|||||||
"-e",
|
"-e",
|
||||||
"--experiment",
|
"--experiment",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["any", "gridsearch", "crossval"],
|
choices=["gridsearch", "crossval", "any"],
|
||||||
required=False,
|
required=False,
|
||||||
default="any",
|
default="gridsearch",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"-m",
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
choices=["tree", "ensemble"],
|
||||||
|
required=False,
|
||||||
|
default="tree",
|
||||||
)
|
)
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
return args.experiment
|
return (args.experiment, args.model)
|
||||||
|
|
||||||
|
|
||||||
def report_header_content(title, experiment):
|
def report_header_content(title, experiment, model_type):
|
||||||
length = sum(lengths) + len(lengths) - 1
|
length = sum(lengths) + len(lengths) - 1
|
||||||
output = "\n" + "*" * length + "\n"
|
output = "\n" + "*" * length + "\n"
|
||||||
num = (length - len(title) - len(experiment) - 2) // 2
|
titles_length = len(title) + len(experiment) + len(model_type) + 21
|
||||||
num2 = length - len(title) - len(experiment) - 5 - 2 * num
|
num = (length - titles_length) // 2 - 3
|
||||||
|
num2 = length - titles_length - 7 - 2 * num
|
||||||
output += (
|
output += (
|
||||||
"*"
|
"*"
|
||||||
+ " " * num
|
+ " " * num
|
||||||
+ f"{title} - {experiment}"
|
+ f"{title} - Experiment: {experiment} - Models: {model_type}"
|
||||||
+ " " * (num + num2)
|
+ " " * (num + num2)
|
||||||
+ "*\n"
|
+ "*\n"
|
||||||
)
|
)
|
||||||
@@ -44,10 +54,10 @@ def report_header_content(title, experiment):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def report_header(title, experiment):
|
def report_header(title, experiment, model_type):
|
||||||
print(
|
print(
|
||||||
TextColor.HEADER
|
TextColor.HEADER
|
||||||
+ report_header_content(title, experiment)
|
+ report_header_content(title, experiment, model_type)
|
||||||
+ TextColor.ENDC
|
+ TextColor.ENDC
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -85,14 +95,15 @@ def report_footer(agg):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
experiment = parse_arguments()
|
(experiment, model_type) = parse_arguments()
|
||||||
dbh = MySQL()
|
dbh = MySQL()
|
||||||
database = dbh.get_connection()
|
database = dbh.get_connection()
|
||||||
dt = Datasets(False, False, "tanveer")
|
dt = Datasets(False, False, "tanveer")
|
||||||
fields = ("Dataset", "Reference")
|
fields = ("Dataset", "Reference")
|
||||||
for model in models:
|
models = models_tree if model_type == "tree" else models_ensemble
|
||||||
fields += (f"{model}",)
|
for item in models:
|
||||||
report_header(title, experiment)
|
fields += (f"{item}",)
|
||||||
|
report_header(title, experiment, model_type)
|
||||||
color = TextColor.LINE1
|
color = TextColor.LINE1
|
||||||
agg = {}
|
agg = {}
|
||||||
for item in [
|
for item in [
|
||||||
@@ -107,7 +118,7 @@ for dataset in dt:
|
|||||||
find_one = False
|
find_one = False
|
||||||
# Look for max accuracy for any given dataset
|
# Look for max accuracy for any given dataset
|
||||||
line = {"dataset": color + dataset[0]}
|
line = {"dataset": color + dataset[0]}
|
||||||
record = dbh.find_best(dataset[0], "any", experiment)
|
record = dbh.find_best(dataset[0], models, experiment)
|
||||||
max_accuracy = 0.0 if record is None else record[5]
|
max_accuracy = 0.0 if record is None else record[5]
|
||||||
for model in models:
|
for model in models:
|
||||||
record = dbh.find_best(dataset[0], model, experiment)
|
record = dbh.find_best(dataset[0], model, experiment)
|
||||||
|
@@ -46,18 +46,21 @@ class MySQL:
|
|||||||
|
|
||||||
def find_best(self, dataset, classifier="any", experiment="any"):
|
def find_best(self, dataset, classifier="any", experiment="any"):
|
||||||
cursor = self._database.cursor(buffered=True)
|
cursor = self._database.cursor(buffered=True)
|
||||||
if classifier == "any":
|
date_from = "2021-01-20"
|
||||||
command = (
|
command = (
|
||||||
f"select * from results r inner join reference e on "
|
f"select * from results r inner join reference e on "
|
||||||
f"r.dataset=e.dataset where r.dataset='{dataset}' and "
|
f"r.dataset=e.dataset where r.dataset='{dataset}'"
|
||||||
f"date>='2021-01-20'"
|
)
|
||||||
)
|
if isinstance(classifier, list):
|
||||||
else:
|
classifier_set = "("
|
||||||
command = (
|
for i, item in enumerate(classifier):
|
||||||
f"select * from results r inner join reference e on "
|
comma = "" if i == 0 else ","
|
||||||
f"r.dataset=e.dataset where r.dataset='{dataset}' and "
|
classifier_set += f"{comma}'{item}'"
|
||||||
f"classifier='{classifier}' and date>='2021-01-20'"
|
classifier_set += ")"
|
||||||
)
|
command += f" and r.classifier in {classifier_set}"
|
||||||
|
elif classifier != "any":
|
||||||
|
command += f" and r.classifier='{classifier}'"
|
||||||
|
command += f" and date>='{date_from}'"
|
||||||
command += "" if experiment == "any" else f" and type='{experiment}'"
|
command += "" if experiment == "any" else f" and type='{experiment}'"
|
||||||
command += (
|
command += (
|
||||||
" order by r.dataset, accuracy desc, classifier desc, "
|
" order by r.dataset, accuracy desc, classifier desc, "
|
||||||
@@ -182,7 +185,8 @@ class BD(ABC):
|
|||||||
database = dbh.get_connection()
|
database = dbh.get_connection()
|
||||||
command_insert = (
|
command_insert = (
|
||||||
"replace into results (date, time, type, accuracy, "
|
"replace into results (date, time, type, accuracy, "
|
||||||
"dataset, classifier, norm, stand, parameters, accuracy_std, time_spent, time_spent_std) values (%s, %s, "
|
"dataset, classifier, norm, stand, parameters, accuracy_std, "
|
||||||
|
"time_spent, time_spent_std) values (%s, %s, "
|
||||||
"%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"
|
"%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"
|
||||||
)
|
)
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
@@ -54,7 +54,7 @@ def report_header(exclude_params):
|
|||||||
|
|
||||||
def report_line(record, agg):
|
def report_line(record, agg):
|
||||||
accuracy = record[5]
|
accuracy = record[5]
|
||||||
expected = record[10]
|
expected = record[13]
|
||||||
if accuracy < expected:
|
if accuracy < expected:
|
||||||
agg["worse"] += 1
|
agg["worse"] += 1
|
||||||
sign = "-"
|
sign = "-"
|
||||||
@@ -84,7 +84,7 @@ def report_footer(agg):
|
|||||||
TextColor.MAGENTA + f"we have equal results {agg['equal']:2d} times"
|
TextColor.MAGENTA + f"we have equal results {agg['equal']:2d} times"
|
||||||
)
|
)
|
||||||
color = TextColor.LINE1
|
color = TextColor.LINE1
|
||||||
for item in ["stree", "bagging", "adaBoost", "odte"]:
|
for item in models:
|
||||||
print(color + f"{item:10s} used {agg[item]:2d} times")
|
print(color + f"{item:10s} used {agg[item]:2d} times")
|
||||||
color = (
|
color = (
|
||||||
TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1
|
TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1
|
||||||
|
Reference in New Issue
Block a user