Files
stree_datasets/analysis_mysql.py
Ricardo Montañana 7f75115fa9 Add stree default to analysis
add experiment to report_mysql
fix crosval experiment to get the best "gridsearch" parameters
2021-03-26 00:06:48 +01:00

192 lines
5.3 KiB
Python

import argparse
from typing import Tuple
import numpy as np
from experimentation.Sets import Datasets
from experimentation.Utils import TextColor
from experimentation.Database import MySQL
# Output path used when a CSV export is requested.
report_csv = "report.csv"
# Single-tree models reported with "--model tree".
models_tree = ["stree", "stree_default", "wodt", "j48svm", "oc1", "cart", "baseRaF"]
# Ensemble models reported with "--model ensemble".
models_ensemble = ["odte", "adaBoost", "bagging", "TBRaF", "TBRoF", "TBRRoF"]
# Row keys for the dataset description columns (samples, variables, classes).
description = ["samp", "var", "cls"]
# Row keys for the tree complexity columns.
complexity = ["nodes", "leaves", "depth"]
title = "Best model results"
# Column widths: dataset name (30), samples (4), five 3-wide numeric columns,
# then one 12-wide result column per model. Entry 7 is the stree_default
# column and is removed by the script when --compare is not given.
lengths = [30, 4] + [3] * 5 + [12] * 7
def parse_arguments() -> Tuple[str, str, str, bool, bool]:
ap = argparse.ArgumentParser()
ap.add_argument(
"-e",
"--experiment",
type=str,
choices=["gridsearch", "crossval", "any"],
required=False,
default="gridsearch",
)
ap.add_argument(
"-m",
"--model",
type=str,
choices=["tree", "ensemble"],
required=False,
default="tree",
)
ap.add_argument(
"-c",
"--csv-output",
type=bool,
required=False,
default=False,
)
ap.add_argument(
"-o",
"--compare",
type=bool,
required=False,
default=False,
)
args = ap.parse_args()
return (args.experiment, args.model, args.csv_output, args.compare)
def report_header_content(title, experiment, model_type):
    """Build the boxed banner and the column heading rows of the report.

    Uses the module-level ``fields`` (column titles) and ``lengths`` (column
    widths).

    Args:
        title: report title shown in the banner.
        experiment: experiment name shown in the banner.
        model_type: model family shown in the banner.

    Returns:
        A string made of a leading newline, a line of asterisks, the banner
        text padded between two border asterisks, another line of asterisks,
        a blank line, the column titles centred in their widths, and an
        underline row of ``=`` per column.
    """
    # Total row width: every column plus one separating space between columns.
    width = sum(lengths) + len(lengths) - 1
    output = "\n" + "*" * width + "\n"
    banner = f"{title} - Experiment: {experiment} - Models: {model_type}"
    # Padding available inside the two border asterisks; when it is odd the
    # extra space goes on the left. A negative value yields no padding.
    padding = width - 2 - len(banner)
    output += (
        "*"
        + " " * ((padding + 1) // 2)
        + banner
        + " " * (padding // 2)
        + "*\n"
    )
    output += "*" * width + "\n\n"
    headings = ""
    underline = ""
    for name, size in zip(fields, lengths):
        headings += f"{name:^{size}} "
        underline += "=" * size + " "
    return output + headings + f"\n{underline}"
def report_header(title, experiment, model_type):
    """Print the report banner and column headings in the header colour."""
    header_text = report_header_content(title, experiment, model_type)
    colored_header = TextColor.HEADER + header_text + TextColor.ENDC
    print(colored_header)
def report_line(line):
    """Format one dataset row of the report.

    Uses the module-level ``description``, ``complexity``, ``models`` and
    ``lengths``.

    Args:
        line: mapping holding the (colour-prefixed) "dataset" name, an integer
            for every key in ``description`` and ``complexity``, and one
            pre-formatted string cell per entry of ``models``.

    Returns:
        The row text, with every cell followed by a single space.
    """
    numeric_keys = description + complexity
    # NOTE(review): the dataset cell is widened by 5, presumably to absorb the
    # colour escape prefix added to the name -- confirm against TextColor.
    cells = [f"{line['dataset']:{lengths[0] + 5}s}"]
    cells.extend(
        f"{line[key]:{lengths[col]}d}"
        for col, key in enumerate(numeric_keys, start=1)
    )
    # Model result columns follow the dataset and numeric columns.
    first_model_col = 1 + len(numeric_keys)
    cells.extend(
        f"{line[model]:{lengths[col]}s}"
        for col, model in enumerate(models, start=first_model_col)
    )
    return " ".join(cells) + " "
def report_footer(agg):
    """Print a separator followed by one summary row per reported model.

    Args:
        agg: mapping from model name to a dict whose "best" entry counts the
            datasets on which that model matched the best accuracy.
    """
    separator_width = len(lengths) - 1 + sum(lengths)
    print("-" * separator_width)
    row_color = TextColor.LINE1
    for model_name in models:
        best_count = agg[model_name]["best"]
        print(
            row_color
            + f"{model_name:10s} "
            + row_color
            + f"best of models {best_count:2d} times"
        )
        # Alternate the colour of consecutive rows.
        if row_color == TextColor.LINE1:
            row_color = TextColor.LINE2
        else:
            row_color = TextColor.LINE1
# Script body: build the best-model report from the results stored in MySQL.
(experiment, model_type, csv_output, compare) = parse_arguments()
dbh = MySQL()
database = dbh.get_connection()
dt = Datasets(False, False, "tanveer")
# Fixed leading report columns; one result column per model is appended below.
fields = (
    "Dataset",
    "Samp",
    "Var",
    "Cls",
    "Nod",
    "Lea",
    "Dep",
)
if not compare:
    # stree_default and its column width (lengths[7]) are only reported
    # with --compare.
    models_tree.remove("stree_default")
    lengths.pop(7)
models = models_tree if model_type == "tree" else models_ensemble
fields += tuple(models)
report_header(title, experiment, model_type)
color = TextColor.LINE1
# agg[model]["best"] counts the datasets on which the model's accuracy equals
# the best accuracy found for that dataset.
agg = {item: {"best": 0} for item in ["better", "worse"] + models}
csv_file = open(report_csv, "w") if csv_output else None
try:
    if csv_file is not None:
        print("dataset, classifier, accuracy", file=csv_file)
    for dataset in dt:
        dataset_name = dataset[0]
        found_any = False
        line = {"dataset": color + dataset_name}
        X, y = dt.load(dataset_name)  # type: ignore
        line["samp"], line["var"] = X.shape
        line["cls"] = len(np.unique(y))
        # NOTE(review): querying with the whole models list presumably returns
        # the best record across all of them; every model whose accuracy
        # equals it is highlighted below.
        record = dbh.find_best(dataset_name, models, experiment)
        max_accuracy = 0.0 if record is None else record[5]
        line["nodes"] = 0
        line["leaves"] = 0
        line["depth"] = 0
        for model in models:
            record = dbh.find_best(dataset_name, model, experiment)
            if record is None:
                line[model] = color + "-" * 12
                continue
            if model == "stree":
                # The complexity columns are taken from the stree record.
                line["nodes"] = record[12]
                line["leaves"] = record[13]
                line["depth"] = record[14]
            accuracy = record[5]
            acc_std = record[11]
            found_any = True
            item = f"{accuracy:.4f}±{acc_std:.3f}"
            if accuracy == max_accuracy:
                line[model] = (
                    TextColor.GREEN + TextColor.BOLD + item + TextColor.ENDC
                )
                agg[model]["best"] += 1
            else:
                line[model] = color + item
            if csv_file is not None:
                print(f"{dataset_name}, {model}, {accuracy}", file=csv_file)
        if not found_any:
            print(TextColor.FAIL + f"*No results found for {dataset_name}")
        else:
            color = (
                TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1
            )
            print(report_line(line))
    report_footer(agg)
finally:
    # Close the CSV file even if a dataset load or query fails midway.
    if csv_file is not None:
        csv_file.close()
if csv_output:
    print(f"{report_csv} file generated")
dbh.close()