mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-16 07:56:07 +00:00
Add time analysis
This commit is contained in:
@@ -20,7 +20,8 @@ models_ensemble = ["odte", "adaBoost", "bagging", "TBRaF", "TBRoF", "TBRRoF"]
|
||||
description = ["samp", "var", "cls"]
|
||||
complexity = ["nodes", "leaves", "depth"]
|
||||
title = "Best model results"
|
||||
lengths = [30, 4, 3, 3, 3, 3, 3, 12, 12, 12, 12, 12, 12, 12]
|
||||
lengths_acc = [30, 4, 3, 3, 3, 3, 3, 12, 12, 12, 12, 12, 12, 12]
|
||||
lengths_time = [30, 4, 3, 3, 3, 3, 3, 17, 17, 17, 17, 17, 17, 17]
|
||||
|
||||
|
||||
def parse_arguments() -> Tuple[str, str, str, bool, bool]:
|
||||
@@ -62,6 +63,13 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]:
|
||||
required=False,
|
||||
default=False,
|
||||
)
|
||||
ap.add_argument(
|
||||
"-i",
|
||||
"--time",
|
||||
type=bool,
|
||||
required=False,
|
||||
default=False,
|
||||
)
|
||||
args = ap.parse_args()
|
||||
return (
|
||||
args.experiment,
|
||||
@@ -69,6 +77,7 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]:
|
||||
args.csv_output,
|
||||
args.tex_output,
|
||||
args.compare,
|
||||
args.time,
|
||||
)
|
||||
|
||||
|
||||
@@ -137,13 +146,15 @@ def report_header_content(title, experiment, model_type):
|
||||
length = sum(lengths) + len(lengths) - 1
|
||||
output = "\n" + "*" * length + "\n"
|
||||
titles_length = len(title) + len(experiment) + len(model_type) + 21
|
||||
num = (length - titles_length) // 2 - 3
|
||||
num2 = length - titles_length - 7 - 2 * num
|
||||
num = (length - titles_length) // 2 - 10
|
||||
num2 = length - titles_length - num - 20
|
||||
report_type = " --Times--" if time_info else "Accuracies"
|
||||
output += (
|
||||
"*"
|
||||
+ " " * num
|
||||
+ f"{title} - Experiment: {experiment} - Models: {model_type}"
|
||||
+ " " * (num + num2)
|
||||
+ f"{title} - Experiment: {experiment} - Models: {model_type} - "
|
||||
+ f"{report_type}"
|
||||
+ " " * num2
|
||||
+ "*\n"
|
||||
)
|
||||
output += "*" * length + "\n\n"
|
||||
@@ -188,7 +199,14 @@ def report_footer(agg):
|
||||
print("-" * 24)
|
||||
|
||||
|
||||
(experiment, model_type, csv_output, tex_output, compare) = parse_arguments()
|
||||
(
|
||||
experiment,
|
||||
model_type,
|
||||
csv_output,
|
||||
tex_output,
|
||||
compare,
|
||||
time_info,
|
||||
) = parse_arguments()
|
||||
dbh = MySQL()
|
||||
database = dbh.get_connection()
|
||||
dt = Datasets(False, False, "tanveer")
|
||||
@@ -201,6 +219,7 @@ fields = (
|
||||
"Lea",
|
||||
"Dep",
|
||||
)
|
||||
lengths = lengths_time if time_info else lengths_acc
|
||||
if tex_output:
|
||||
# We need the stree_std column for the tex output
|
||||
compare = True
|
||||
@@ -233,13 +252,15 @@ for number, dataset in enumerate(dt):
|
||||
X, y = dt.load(dataset[0]) # type: ignore
|
||||
line["samp"], line["var"] = X.shape
|
||||
line["cls"] = len(np.unique(y))
|
||||
record = dbh.find_best(dataset[0], models, experiment)
|
||||
max_accuracy = 0.0 if record is None else record[5]
|
||||
record = dbh.find_best(dataset[0], models, experiment, time_info)
|
||||
max_accuracy = (
|
||||
0.0 if record is None else record[9] if time_info else record[5]
|
||||
)
|
||||
line["nodes"] = 0
|
||||
line["leaves"] = 0
|
||||
line["depth"] = 0
|
||||
line_tex = line.copy()
|
||||
for model in models:
|
||||
for column, model in enumerate(models):
|
||||
record = dbh.find_best(dataset[0], model, experiment)
|
||||
if record is None:
|
||||
line[model] = color + "-" * 12
|
||||
@@ -249,10 +270,11 @@ for number, dataset in enumerate(dt):
|
||||
line["leaves"] = record[13]
|
||||
line["depth"] = record[14]
|
||||
reference = record[13]
|
||||
accuracy = record[5]
|
||||
acc_std = record[11]
|
||||
accuracy = record[9] if time_info else record[5]
|
||||
acc_std = record[10] if time_info else record[11]
|
||||
find_one = True
|
||||
item = f"{accuracy:.4f}±{acc_std:.3f}"
|
||||
item = f"{accuracy:{lengths[column + 7]-6}.4f}±{acc_std:.3f}"
|
||||
# item = f"{accuracy:.4f}±{acc_std:.3f}"
|
||||
line_tex[model] = item
|
||||
if round(accuracy, 4) == round(max_accuracy, 4):
|
||||
line_tex[model] = "\\textbf{" + item + "}"
|
||||
|
Reference in New Issue
Block a user