mirror of https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-15 23:46:03 +00:00
289 lines
8.1 KiB
Python
import argparse
from typing import Tuple

import numpy as np

from experimentation.Sets import Datasets
from experimentation.Utils import TextColor
from experimentation.Database import MySQL

report_csv = "report.csv"
table_tex = "table.tex"
models_tree = [
    "stree",
    "stree_default",
    "wodt",
    "j48svm",
    "oc1",
    "cart",
    "baseRaF",
]
models_ensemble = ["odte", "adaBoost", "bagging", "TBRaF", "TBRoF", "TBRRoF"]
description = ["samp", "var", "cls"]
complexity = ["nodes", "leaves", "depth"]
title = "Best model results"
lengths = [30, 4, 3, 3, 3, 3, 3, 12, 12, 12, 12, 12, 12, 12]
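# lengths holds the console column widths and mirrors the fields tuple built
# below: one entry for the dataset name, one for each descriptor
# (samp/var/cls and nodes/leaves/depth) and one 12-character column per model.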


def parse_arguments() -> Tuple[str, str, bool, bool, bool]:
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "-e",
        "--experiment",
        type=str,
        choices=["gridsearch", "crossval", "any"],
        required=False,
        default="gridsearch",
    )
    ap.add_argument(
        "-m",
        "--model",
        type=str,
        choices=["tree", "ensemble"],
        required=False,
        default="tree",
    )
    # Boolean switches: passing the flag enables the option.
    ap.add_argument(
        "-c",
        "--csv-output",
        action="store_true",
        required=False,
        default=False,
    )
    ap.add_argument(
        "-t",
        "--tex-output",
        action="store_true",
        required=False,
        default=False,
    )
    ap.add_argument(
        "-o",
        "--compare",
        action="store_true",
        required=False,
        default=False,
    )
    args = ap.parse_args()
    return (
        args.experiment,
        args.model,
        args.csv_output,
        args.tex_output,
        args.compare,
    )
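
# Example invocation (the script name is hypothetical, shown only for
# illustration):
#   python best_results.py --experiment crossval --model ensemble --tex-output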


def print_header_tex(file_tex, second=False):
    # old_header = (
    #     "\\begin{table}[ht]\n"
    #     "\\centering"
    #     "\\resizebox{\\textwidth}{!}{\\begin{tabular}{|r|l|r|r|r|c|c|c|c|c|c|"
    #     "c"
    #     "|}"
    #     "\\hline\n"
    #     "\\# & Dataset & Samples & Features & Classes & stree & stree def. & "
    #     "wodt & j48svm & oc1 & cart & baseRaF\\\\\n"
    #     "\\hline"
    # )
    cont = ""
    num = ""
    if second:
        cont = " (cont.)"
        num = "2"
    header = (
        "\\begin{sidewaystable}[ht]\n"
        "\\centering\n"
        "\\renewcommand{\\arraystretch}{1.2}\n"
        "\\renewcommand{\\tabcolsep}{0.07cm}\n"
        "\\caption{Datasets used during the experimentation" + cont + "}\n"
        "\\label{table:datasets" + num + "}\n"
        "\\resizebox{0.95\\textwidth}{!}{\n"
        "\\begin{tabular}{rlrrrccccccc}\\hline\n"
        "\\# & Dataset & \\#S & \\#F & \\#L & stree & stree default & wodt & "
        "j48svm & oc1 & cart & baseRaF\\\\\n"
        "\\hline\n"
    )
    print(header, file=file_tex)
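
# print_footer_tex below emits the matching closing fragment for the
# sidewaystable opened here; the driver code splits the dataset table into a
# second sidewaystable after 25 rows (see the number == 24 check in the main
# loop).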


def print_line_tex(number, dataset, line, file_tex):
    dataset_name = dataset.replace("_", "\\_")
    print_line = (
        f"{number} & {dataset_name} & {line['samp']} & {line['var']} "
        f"& {line['cls']}"
    )
    for model in models:
        item = line[model]
        print_line += f" & {item}"
    print_line += "\\\\"
    print(f"{print_line}", file=file_tex)


def print_footer_tex(file_tex):
    # old_footer = (
    #     "\\hline\n"
    #     "\\csname @@input\\endcsname wintieloss\n"
    #     "\\hline\n"
    #     "\\end{tabular}}\n"
    #     "\\caption{Datasets used during the experimentation}\n"
    #     "\\label{table:datasets}\n"
    #     "\\end{table}"
    # )
    footer = "\\hline\n\\end{tabular}}\n\\end{sidewaystable}\n"
    print(f"{footer}", file=file_tex)


def report_header_content(title, experiment, model_type):
    length = sum(lengths) + len(lengths) - 1
    output = "\n" + "*" * length + "\n"
    titles_length = len(title) + len(experiment) + len(model_type) + 21
    num = (length - titles_length) // 2 - 3
    num2 = length - titles_length - 7 - 2 * num
    output += (
        "*"
        + " " * num
        + f"{title} - Experiment: {experiment} - Models: {model_type}"
        + " " * (num + num2)
        + "*\n"
    )
    output += "*" * length + "\n\n"
    lines = ""
    for item, field in enumerate(fields):
        output += f"{field:^{lengths[item]}} "
        lines += "=" * lengths[item] + " "
    output += f"\n{lines}"
    return output


def report_header(title, experiment, model_type):
    print(
        TextColor.HEADER
        + report_header_content(title, experiment, model_type)
        + TextColor.ENDC
    )


def report_line(line):
    output = f"{line['dataset']:{lengths[0] + 5}s} "
    for key, item in enumerate(description + complexity):
        output += f"{line[item]:{lengths[key + 1]}d} "
    data = models.copy()
    for key, model in enumerate(data):
        output += f"{line[model]:{lengths[key + 7]}s} "
    return output


def report_footer(agg):
    length = sum(lengths) + len(lengths) - 1
    print("-" * length)
    color = TextColor.LINE2
    print(color + "|{0:15s}|{1:6s}|".format("Classifier", "# Best"))
    print(color + "=" * 24)
    for item in models:
        print(color + f"|{item:15s}", end="|")
        print(color + f"{agg[item]['best']:6d}|")
        color = (
            TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1
        )
    print("-" * 24)


(experiment, model_type, csv_output, tex_output, compare) = parse_arguments()
dbh = MySQL()
database = dbh.get_connection()
dt = Datasets(False, False, "tanveer")
fields = (
    "Dataset",
    "Samp",
    "Var",
    "Cls",
    "Nod",
    "Lea",
    "Dep",
)
if tex_output:
    # We need the stree_std column for the tex output
    compare = True
if not compare:
    # remove stree_default from fields list and lengths
    models_tree.pop(1)
    lengths.pop(7)
models = models_tree if model_type == "tree" else models_ensemble
for item in models:
    fields += (f"{item}",)
report_header(title, experiment, model_type)
color = TextColor.LINE1
agg = {}
for item in [
    "better",
    "worse",
] + models:
    agg[item] = {}
    agg[item]["best"] = 0
if csv_output:
    file_csv = open(report_csv, "w")
    print("dataset, classifier, accuracy", file=file_csv)
if tex_output:
    file_tex = open(table_tex, "w")
    print_header_tex(file_tex, second=False)
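
# Main loop: one console row (and optionally one CSV/LaTeX row) per dataset.
# The records returned by dbh.find_best() are indexed positionally below;
# judging from their usage, column 5 holds the accuracy, column 11 its
# standard deviation and columns 12-14 the stree nodes/leaves/depth counts.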
for number, dataset in enumerate(dt):
    find_one = False
    # Look for max accuracy for any given dataset
    line = {"dataset": color + dataset[0]}
    X, y = dt.load(dataset[0])  # type: ignore
    line["samp"], line["var"] = X.shape
    line["cls"] = len(np.unique(y))
    record = dbh.find_best(dataset[0], models, experiment)
    max_accuracy = 0.0 if record is None else record[5]
    line["nodes"] = 0
    line["leaves"] = 0
    line["depth"] = 0
    line_tex = line.copy()
    for model in models:
        record = dbh.find_best(dataset[0], model, experiment)
        if record is None:
            line[model] = color + "-" * 12
        else:
            if model == "stree":
                line["nodes"] = record[12]
                line["leaves"] = record[13]
                line["depth"] = record[14]
                reference = record[13]
            accuracy = record[5]
            acc_std = record[11]
            find_one = True
            item = f"{accuracy:.4f}±{acc_std:.3f}"
            line_tex[model] = item
            if round(accuracy, 4) == round(max_accuracy, 4):
                line_tex[model] = "\\textbf{" + item + "}"
            if accuracy == max_accuracy:
                line[model] = (
                    TextColor.GREEN + TextColor.BOLD + item + TextColor.ENDC
                )
                agg[model]["best"] += 1
            else:
                line[model] = color + item
            if csv_output:
                print(f"{dataset[0]}, {model}, {accuracy}", file=file_csv)
    if tex_output:
        print_line_tex(number + 1, dataset[0], line_tex, file_tex)
        if number == 24:
            print_footer_tex(file_tex)
            print_header_tex(file_tex, second=True)
    if not find_one:
        print(TextColor.FAIL + f"*No results found for {dataset[0]}")
    else:
        color = (
            TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1
        )
        print(report_line(line))
report_footer(agg)
if csv_output:
    file_csv.close()
    print(f"{report_csv} file generated")
if tex_output:
    print_footer_tex(file_tex)
    file_tex.close()
    print(f"{table_tex} file generated")
dbh.close()