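"""Report the best stored result of every model against its reference value.

Part of https://github.com/Doctorado-ML/Stree_datasets. The summary below is
inferred from the code in this script rather than from separate
documentation: for every dataset in the Datasets collection it queries the
MySQL results store for the best recorded accuracy of each model, compares it
with the stored reference accuracy, highlights the best model per dataset,
prints aggregate counters at the end, and can optionally dump the raw numbers
to report.csv.
"""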
import argparse

from typing import Tuple

from experimentation.Sets import Datasets
from experimentation.Utils import TextColor
from experimentation.Database import MySQL
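
# The three imports above come from the repository's own "experimentation"
# package. Their APIs are not documented in this file; the sketch below only
# records how this script uses them and should be read as an assumption:
#
#   Datasets(..., "tanveer")  iterable of dataset records whose first element
#                             is the dataset name
#   TextColor                 ANSI escape constants (HEADER, GREEN, RED, FAIL,
#                             BOLD, ENDC, LINE1, LINE2)
#   MySQL()                   results store exposing get_connection(),
#                             find_best(dataset, model_or_models, experiment)
#                             and close()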

report_csv = "report.csv"
models_tree = [
    "stree",
    "wodt",
    "j48svm",
    "oc1",
    "cart",
    "baseRaF",
    "baseRoF",
    "baseRRoF",
]
models_ensemble = ["odte", "adaBoost", "bagging", "TBRaF", "TBRoF", "TBRRoF"]
title = "Best model results"
# Column widths: Dataset, Reference and then one column per model
lengths = (30, 9, 11, 11, 11, 11, 11, 11, 11, 11)


def parse_arguments() -> Tuple[str, str, bool]:
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "-e",
        "--experiment",
        type=str,
        choices=["gridsearch", "crossval", "any"],
        required=False,
        default="gridsearch",
    )
    ap.add_argument(
        "-m",
        "--model",
        type=str,
        choices=["tree", "ensemble"],
        required=False,
        default="tree",
    )
    ap.add_argument(
        "-c",
        "--csv-output",
        # argparse's type=bool treats any non-empty string (even "False") as
        # True, so the flag is declared with action="store_true" instead
        action="store_true",
        required=False,
        default=False,
    )
    args = ap.parse_args()
    return (args.experiment, args.model, args.csv_output)
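
# Minimal usage sketch; the file name "report_best.py" is an assumption, not
# taken from the repository:
#
#   python report_best.py                           gridsearch, tree models
#   python report_best.py -e crossval -m ensemble   crossval, ensemble models
#   python report_best.py -c                        also write report.csv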


def report_header_content(title, experiment, model_type):
    # "fields" and "lengths" are module-level globals; "fields" is built in
    # the script body further down, before this function is first called
    length = sum(lengths) + len(lengths) - 1
    output = "\n" + "*" * length + "\n"
    titles_length = len(title) + len(experiment) + len(model_type) + 21
    num = (length - titles_length) // 2 - 3
    num2 = length - titles_length - 7 - 2 * num
    output += (
        "*"
        + " " * num
        + f"{title} - Experiment: {experiment} - Models: {model_type}"
        + " " * (num + num2)
        + "*\n"
    )
    output += "*" * length + "\n\n"
    lines = ""
    for item, _ in enumerate(fields):
        output += f"{fields[item]:{lengths[item]}} "
        lines += "=" * lengths[item] + " "
    output += f"\n{lines}"
    return output
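
# The header produced above looks roughly like the sketch below (widths come
# from the "lengths" tuple; this is illustrative, not captured output):
#
#   *************************************************************
#   *   Best model results - Experiment: ... - Models: ...      *
#   *************************************************************
#
#   Dataset                        Reference stree       ...
#   ============================== ========= =========== ...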


def report_header(title, experiment, model_type):
    print(
        TextColor.HEADER
        + report_header_content(title, experiment, model_type)
        + TextColor.ENDC
    )


def report_line(line):
    # lengths[0] + 5 likely compensates for the ANSI color escape prepended
    # to the dataset name in the main loop (an assumption, not documented)
    output = f"{line['dataset']:{lengths[0] + 5}s} "
    data = models.copy()
    data.insert(0, "reference")
    for key, model in enumerate(data):
        output += f"{line[model]:{lengths[key + 1]}s} "
    return output
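
# report_line() formats the dict built in the main loop below; for the tree
# report its shape is roughly (values already carry their ANSI color codes):
#
#   line = {
#       "dataset": <color + dataset name>,
#       "reference": <formatted reference accuracy>,
#       "stree": <formatted accuracy plus "+"/"-">,
#       ...one entry per model in "models"...
#   }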


def report_footer(agg):
    print(
        TextColor.GREEN
        + f"we have better results {agg['better']['items']:2d} times"
    )
    print(
        TextColor.RED
        + f"we have worse results {agg['worse']['items']:2d} times"
    )
    color = TextColor.LINE1
    for item in models:
        print(
            color + f"{item:10s} used {agg[item]['items']:2d} times ", end=""
        )
        print(
            color + f"better {agg[item]['better']:2d} times ",
            end="",
        )
        print(color + f"worse {agg[item]['worse']:2d} times ")
        color = (
            TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1
        )


# Script body: query the database and print one report line per dataset
(experiment, model_type, csv_output) = parse_arguments()
dbh = MySQL()
database = dbh.get_connection()
dt = Datasets(False, False, "tanveer")
fields = ("Dataset", "Reference")
models = models_tree if model_type == "tree" else models_ensemble
for item in models:
    fields += (f"{item}",)
report_header(title, experiment, model_type)
color = TextColor.LINE1
# Aggregated counters: overall better/worse plus per-model statistics
agg = {}
for item in [
    "better",
    "worse",
] + models:
    agg[item] = {}
    agg[item]["items"] = 0
    agg[item]["better"] = 0
    agg[item]["worse"] = 0
if csv_output:
    f = open(report_csv, "w")
    print("dataset, classifier, accuracy", file=f)
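
# The record returned by dbh.find_best() is indexed positionally below:
# record[5] is read as the accuracy and record[13] as the reference value.
# These indices are inferred from this script alone; the actual column order
# is defined by the MySQL helper in the experimentation package.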
for dataset in dt:
    find_one = False
    # Look for the maximum accuracy obtained for this dataset by any model
    line = {"dataset": color + dataset[0]}
    record = dbh.find_best(dataset[0], models, experiment)
    max_accuracy = 0.0 if record is None else record[5]
    for model in models:
        record = dbh.find_best(dataset[0], model, experiment)
        if record is None:
            line[model] = color + "-" * 9 + " "
        else:
            reference = record[13]
            accuracy = record[5]
            find_one = True
            agg[model]["items"] += 1
            # "+" marks a result better than the reference, "-" a worse one
            if accuracy > reference:
                sign = "+"
                agg["better"]["items"] += 1
                agg[model]["better"] += 1
            else:
                sign = "-"
                agg["worse"]["items"] += 1
                agg[model]["worse"] += 1
            item = f"{accuracy:9.7} {sign}"
            line["reference"] = f"{reference:9.7}"
            # Highlight the best accuracy of the row in bold green
            line[model] = (
                TextColor.GREEN + TextColor.BOLD + item + TextColor.ENDC
                if accuracy == max_accuracy
                else color + item
            )
            if csv_output:
                print(f"{dataset[0]}, {model}, {accuracy}", file=f)
    if not find_one:
        print(TextColor.FAIL + f"*No results found for {dataset[0]}")
    else:
        color = (
            TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1
        )
        print(report_line(line))
report_footer(agg)
if csv_output:
    f.close()
    print(f"{report_csv} file generated")
dbh.close()
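
# With -c, report.csv starts with a "dataset, classifier, accuracy" header
# row followed by one row per (dataset, model) pair that had a stored result,
# e.g. (placeholders, not real values):
#
#   dataset, classifier, accuracy
#   <dataset name>, stree, <best accuracy>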