Files
stree_datasets/analysis_mysql.py

373 lines
10 KiB
Python

import argparse
from typing import Tuple
import numpy as np
import xlsxwriter
from experimentation.Sets import Datasets
from experimentation.Utils import TextColor
from experimentation.Database import MySQL
# Names of the optional output files written by the script section.
report_csv = "report.csv"
table_tex = "table.tex"

# Model identifiers grouped by family.  These stay plain lists because the
# script section removes entries from ``models_tree`` in place.
models_tree = ["stree", "stree_default", "wodt", "j48svm"]
models_tree += ["oc1", "cart", "baseRaF"]
models_ensemble = [
    "odte",
    "adaBoost",
    "bagging",
    "TBRaF",
    "TBRoF",
    "TBRRoF",
]

# Keys of the per-dataset line dictionary: dataset description columns and
# tree complexity columns.
description = ["samp", "var", "cls"]
complexity = ["nodes", "leaves", "depth"]

title = "Best model results"

# Column widths of the console report: seven fixed columns followed by seven
# model columns (12 wide for the accuracy report, 17 wide for the time one).
lengths_acc = [30, 4, 3, 3, 6, 6, 5] + [12] * 7
lengths_time = [30, 4, 3, 3, 6, 6, 5] + [17] * 7
def _str_to_bool(value: str) -> bool:
    """Convert a command line token into a boolean.

    ``type=bool`` turns every non-empty string (including ``"False"``) into
    ``True``.  This converter keeps every value that used to be accepted as
    ``True`` and only treats explicit negative spellings (and the empty
    string) as ``False``.
    """
    negatives = {"", "0", "false", "f", "no", "n", "off"}
    return value.strip().lower() not in negatives


def parse_arguments(argv=None) -> Tuple[str, str, bool, bool, int, bool, str]:
    """Parse the command line options of the report script.

    :param argv: optional list of argument strings; when ``None`` (the
        default) argparse reads ``sys.argv[1:]`` as before.
    :return: tuple ``(experiment, model, csv_output, tex_output, compare,
        time, excel)`` in that order.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "-e",
        "--experiment",
        type=str,
        choices=["gridsearch", "crossval", "any"],
        required=False,
        default="crossval",
    )
    ap.add_argument(
        "-m",
        "--model",
        type=str,
        choices=["tree", "ensemble"],
        required=False,
        default="tree",
    )
    # The three boolean switches take an explicit value (e.g. ``-c True``);
    # _str_to_bool makes ``-c False`` actually evaluate to False.
    ap.add_argument(
        "-c",
        "--csv-output",
        type=_str_to_bool,
        required=False,
        default=False,
    )
    ap.add_argument(
        "-t",
        "--tex-output",
        type=_str_to_bool,
        required=False,
        default=False,
    )
    ap.add_argument(
        "-o",
        "--compare",
        type=int,
        required=False,
        default=1,
        help="1=stree optimized, 2=stree_default, 3=both",
    )
    ap.add_argument(
        "-i",
        "--time",
        type=_str_to_bool,
        required=False,
        default=False,
    )
    ap.add_argument(
        "-x",
        "--excel",
        type=str,
        default="",
        required=False,
        help="generate excel file",
    )
    args = ap.parse_args(argv)
    return (
        args.experiment,
        args.model,
        args.csv_output,
        args.tex_output,
        args.compare,
        args.time,
        args.excel,
    )
def excel_write_header(book, sheet):
    """Write the first row of the Excel report and size its columns.

    Column 0 holds the "Dataset" caption (width 30); the following columns
    hold one caption per entry of ``models_tree`` (width 10 each).

    :param book: workbook used to create the cell formats
    :param sheet: worksheet that receives the header row
    """
    # These two formats are created but not applied to any cell here; the
    # calls are kept so the workbook receives the same sequence of formats.
    title_format = book.add_format()
    title_format.set_font_size(18)
    subtitle_format = book.add_format()
    subtitle_format.set_font_size(16)
    caption = book.add_format({"bold": True, "font_size": 14})
    sheet.write(0, 0, "Dataset", caption)
    sheet.set_column(0, 0, 30)
    widths = [10] * 7
    # zip stops at the shorter sequence, so at most seven model captions
    for column, (name, width) in enumerate(zip(models_tree, widths), start=1):
        sheet.write(0, column, name, caption)
        sheet.set_column(column, column, width)
def excel_write_line(book, sheet, dataset, line):
    """Append one dataset row to the Excel report.

    The target row is tracked in the function attribute
    ``excel_write_line.row``: the first call writes to row 1 (row 0 is the
    header) and every later call writes to the next row.

    :param book: workbook used to create the cell formats
    :param sheet: worksheet that receives the row
    :param dataset: text written in column 0
    :param line: mapping whose values are written, in iteration order, in
        columns 1, 2, ...
    """
    row = getattr(excel_write_line, "row", 0) + 1
    excel_write_line.row = row
    font_size = 14
    decimal = book.add_format(
        {"num_format": "0.000000", "font_size": font_size}
    )
    normal = book.add_format({"font_size": font_size})
    # Write through the worksheet received as argument; the previous code
    # used the global ``excel_ws`` here and ignored ``sheet`` for this cell.
    sheet.write(row, 0, dataset, normal)
    for column, value in enumerate(line.values(), start=1):
        sheet.write(row, column, value, decimal)
def excel_write_footer(book):
    """Finish the Excel report by closing the workbook.

    :param book: workbook to close
    """
    book.close()
def print_header_tex(file_tex, second=False):
    """Write the opening of the LaTeX ``sidewaystable`` to ``file_tex``.

    :param file_tex: writable text file object
    :param second: when true the caption gets a " (cont.)" suffix and the
        label becomes ``table:datasets2``
    """
    suffix, index = (" (cont.)", "2") if second else ("", "")
    rows = [
        "\\begin{sidewaystable}[ht]",
        "\\centering",
        "\\renewcommand{\\arraystretch}{1.2}",
        "\\renewcommand{\\tabcolsep}{0.07cm}",
        f"\\caption{{Datasets used during the experimentation{suffix}}}",
        f"\\label{{table:datasets{index}}}",
        "\\resizebox{0.95\\textwidth}{!}{",
        "\\begin{tabular}{rlrrrccccccc}\\hline",
        "\\# & Dataset & \\#S & \\#F & \\#L & STree & STree default & WODT & "
        "J48SVM-ODT & OC1 & CART & TBSVM-ODT\\\\",
        "\\hline",
    ]
    # The header text itself ends with a newline and print() adds another
    print("\n".join(rows) + "\n", file=file_tex)
def print_line_tex(number, dataset, line, file_tex):
    """Write one table row of the LaTeX report.

    The row holds the running number, the dataset name (underscores
    escaped), the ``samp``/``var``/``cls`` entries of ``line`` and one cell
    per entry of the module-level ``models`` list, separated by `` & `` and
    terminated by a LaTeX line break.

    :param number: row number shown in the first column
    :param dataset: dataset name
    :param line: mapping with the keys ``samp``, ``var``, ``cls`` and one key
        per model
    :param file_tex: writable text file object
    """
    escaped = dataset.replace("_", "\\_")
    cells = [number, escaped, line["samp"], line["var"], line["cls"]]
    cells.extend(line[name] for name in models)
    row = " & ".join(f"{cell}" for cell in cells) + "\\\\"
    print(row, file=file_tex)
def print_footer_tex(file_tex):
    """Write the closing lines of the LaTeX ``sidewaystable``.

    :param file_tex: writable text file object
    """
    closing = ("\\hline", "\\end{tabular}}", "\\end{sidewaystable}")
    # The footer text itself ends with a newline and print() adds another
    print("\n".join(closing) + "\n", file=file_tex)
def report_header_content(title, experiment, model_type):
    """Build the text of the console report header.

    The result is a starred banner with the title, experiment, model type
    and report kind, followed by the centred column names and a row of
    ``=`` rules.  Column names and widths come from the module-level
    ``fields`` and ``lengths``; the report kind depends on ``time_info``.

    :param title: report title
    :param experiment: experiment name shown in the banner
    :param model_type: model family shown in the banner
    :return: the header as a single string
    """
    width = sum(lengths) + len(lengths) - 1
    stars = "*" * width
    text_length = len(title) + len(experiment) + len(model_type) + 21
    left_pad = (width - text_length) // 2 - 10
    right_pad = width - text_length - left_pad - 20
    report_type = " --Times--" if time_info else "Accuracies"
    banner = (
        f"*{' ' * left_pad}"
        f"{title} - Experiment: {experiment} - Models: {model_type} - "
        f"{report_type}{' ' * right_pad}*\n"
    )
    names = "".join(
        f"{name:^{lengths[index]}} " for index, name in enumerate(fields)
    )
    rules = "".join("=" * lengths[index] + " " for index in range(len(fields)))
    return f"\n{stars}\n{banner}{stars}\n\n{names}\n{rules}"
def report_header(title, experiment, model_type):
    """Print the console report header wrapped in the header colour.

    :param title: report title
    :param experiment: experiment name shown in the banner
    :param model_type: model family shown in the banner
    """
    opening = TextColor.HEADER
    content = report_header_content(title, experiment, model_type)
    print(opening + content + TextColor.ENDC)
def report_line(line):
    """Format one dataset row of the console report.

    Cells are emitted in this order: the dataset name, the ``description``
    entries as integers, the ``complexity`` entries with two decimals and
    one string cell per entry of ``models``.  Widths are read from the
    module-level ``lengths``.

    :param line: mapping with the keys ``dataset``, the ``description`` and
        ``complexity`` keys and one key per model
    :return: the formatted row
    """
    # The dataset cell is 5 wider than its column; presumably to make room
    # for the colour prefix stored with the name -- TODO confirm.
    cells = [f"{line['dataset']:{lengths[0] + 5}s} "]
    cells.extend(
        f"{line[key]:{lengths[position]}d} "
        for position, key in enumerate(description, start=1)
    )
    first_complexity = len(description) + 1
    cells.extend(
        f"{line[key]:{lengths[position]}.2f} "
        for position, key in enumerate(complexity, start=first_complexity)
    )
    cells.extend(
        f"{line[name]:{lengths[position]}s} "
        for position, name in enumerate(models, start=7)
    )
    return "".join(cells)
def report_footer(agg):
    """Print the summary table with the number of best results per model.

    :param agg: mapping ``model -> {"best": int}`` for every entry of the
        module-level ``models`` list
    """
    width = sum(lengths) + len(lengths) - 1
    print("-" * width)
    # The caption and the even rows use LINE2, the odd rows LINE1
    shades = (TextColor.LINE2, TextColor.LINE1)
    print(shades[0] + "|{0:15s}|{1:6s}|".format("Classifier", "# Best"))
    print(shades[0] + "=" * 24)
    for position, name in enumerate(models):
        shade = shades[position % 2]
        print(shade + f"|{name:15s}", end="|")
        print(shade + f"{agg[name]['best']:6d}|")
    print("-" * 24)
# ---------------------------------------------------------------------------
# Script section: build the report of the best stored result of every model
# for every dataset, on the console and optionally as CSV, LaTeX and Excel.
# The names ``time_info``, ``fields``, ``lengths``, ``models`` and
# ``excel_ws`` stay at module level because the functions above read them.
# ---------------------------------------------------------------------------
(
    experiment,
    model_type,
    csv_output,
    tex_output,
    compare,
    time_info,
    excel,
) = parse_arguments()
dbh = MySQL()
database = dbh.get_connection()
# Declared up front so the ``finally`` clause can release whatever was opened
file_csv = None
file_tex = None
try:
    dt = Datasets(False, False, "tanveer")
    fields = (
        "Dataset",
        "Samp",
        "Var",
        "Cls",
        "Nodes",
        "Leave",
        "Depth",
    )
    # Work on a copy so the pop() below does not alter the module constants
    lengths = list(lengths_time if time_info else lengths_acc)
    reference_model = "stree"
    if tex_output:
        # We need the stree & stree_std column for the tex output
        compare = 3
    if compare != 3:
        # remove stree_default from fields list and lengths
        if compare == 1:
            models_tree.pop(1)
        else:
            models_tree.pop(0)
            reference_model = "stree_default"
        lengths.pop(7)
    models = models_tree if model_type == "tree" else models_ensemble
    for item in models:
        fields += (f"{item}",)
    report_header(title, experiment, model_type)
    color = TextColor.LINE1
    if excel != "":
        file_name = f"{excel}.xlsx"
        excel_wb = xlsxwriter.Workbook(file_name)
        excel_ws = excel_wb.add_worksheet("exreport")
        excel_write_header(excel_wb, excel_ws)
    agg = {}
    for item in ["better", "worse"] + models:
        agg[item] = {"best": 0}
    if csv_output:
        file_csv = open(report_csv, "w")
        print("dataset, classifier, accuracy", file=file_csv)
    if tex_output:
        file_tex = open(table_tex, "w")
        print_header_tex(file_tex, second=False)
    for number, dataset in enumerate(dt):
        find_one = False
        # Look for max accuracy for any given dataset
        line = {"dataset": color + dataset[0]}
        X, y = dt.load(dataset[0])  # type: ignore
        line["samp"], line["var"] = X.shape
        line["cls"] = len(np.unique(y))
        record = dbh.find_best(dataset[0], models, experiment)
        max_accuracy = 0.0 if record is None else record[5]
        line["nodes"] = 0.0
        line["leaves"] = 0.0
        line["depth"] = 0.0
        line_tex = line.copy()
        line_excel = {}
        for column, model in enumerate(models):
            record = dbh.find_best(dataset[0], model, experiment)
            if record is None:
                line[model] = color + "-" * 12
                # Placeholders keep the LaTeX row complete (print_line_tex
                # indexes every model) and keep the Excel values under the
                # right column header (None is written as a blank cell).
                line_tex[model] = "-"
                if excel != "":
                    line_excel[model] = None
                continue
            if model == reference_model:
                # Complexity columns are taken from the reference model
                line["nodes"] = record[12]
                line["leaves"] = record[13]
                line["depth"] = record[14]
            # With --time the report shows columns 9/10 of the record
            # instead of the accuracy columns 5/11
            accuracy = record[9] if time_info else record[5]
            acc_std = record[10] if time_info else record[11]
            find_one = True
            item = f"{accuracy:{lengths[column + 7]-6}.4f}±{acc_std:.3f}"
            line_tex[model] = item
            # The best model is always decided on record[5], also when the
            # time columns are being displayed
            if round(record[5], 4) == round(max_accuracy, 4):
                line_tex[model] = "\\textbf{" + item + "}"
                line[model] = (
                    TextColor.GREEN + TextColor.BOLD + item + TextColor.ENDC
                )
                agg[model]["best"] += 1
            else:
                line[model] = color + item
            if csv_output:
                print(f"{dataset[0]}, {model}, {accuracy}", file=file_csv)
            if excel != "":
                line_excel[model] = accuracy
        if excel != "":
            excel_write_line(excel_wb, excel_ws, dataset[0], line_excel)
        if tex_output:
            print_line_tex(number + 1, dataset[0], line_tex, file_tex)
            if number == 24:
                # After the 25th row the table is closed and a second,
                # continued table is opened
                print_footer_tex(file_tex)
                print_header_tex(file_tex, second=True)
        if not find_one:
            print(TextColor.FAIL + f"*No results found for {dataset[0]}")
        else:
            color = (
                TextColor.LINE2
                if color == TextColor.LINE1
                else TextColor.LINE1
            )
            print(report_line(line))
    if excel != "":
        excel_write_footer(excel_wb)
    report_footer(agg)
    if csv_output:
        file_csv.close()
        print(f"{report_csv} file generated")
    if tex_output:
        print_footer_tex(file_tex)
        file_tex.close()
        print(f"{table_tex} file generated")
finally:
    # Release the output files and the database handle even when the report
    # is interrupted by an exception
    for handle in (file_csv, file_tex):
        if handle is not None and not handle.closed:
            handle.close()
    dbh.close()