Reformat table.tex in analysis_mysql

This commit is contained in:
2021-04-01 01:52:08 +02:00
parent 7f75115fa9
commit 01e725a93e
3 changed files with 154 additions and 10 deletions

View File

@@ -6,6 +6,7 @@ from experimentation.Utils import TextColor
from experimentation.Database import MySQL
report_csv = "report.csv"
table_tex = "table.tex"
models_tree = [
"stree",
"stree_default",
@@ -47,6 +48,13 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]:
required=False,
default=False,
)
ap.add_argument(
"-t",
"--tex-output",
type=bool,
required=False,
default=False,
)
ap.add_argument(
"-o",
"--compare",
@@ -55,7 +63,74 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]:
default=False,
)
args = ap.parse_args()
return (args.experiment, args.model, args.csv_output, args.compare)
return (
args.experiment,
args.model,
args.csv_output,
args.tex_output,
args.compare,
)
def print_header_tex(file_tex, second=False):
    """Write the opening lines of the LaTeX datasets table to *file_tex*.

    The header opens a ``sidewaystable`` environment, sets row and column
    spacing, emits the caption and label, and starts a ``tabular`` wrapped
    in a ``\\resizebox`` with the column titles row.

    :param file_tex: writable text file object that receives the header
    :param second: when true, the caption gets a " (cont.)" suffix and the
        label a "2" suffix, so a continuation table can follow the first
    """
    caption_suffix, label_suffix = (" (cont.)", "2") if second else ("", "")
    rows = [
        "\\begin{sidewaystable}[ht]",
        "\\centering",
        "\\renewcommand{\\arraystretch}{1.2}",
        "\\renewcommand{\\tabcolsep}{0.07cm}",
        "\\caption{Datasets used during the experimentation"
        + caption_suffix
        + "}",
        "\\label{table:datasets" + label_suffix + "}",
        "\\resizebox{0.95\\textwidth}{!}{",
        "\\begin{tabular}{rlrrrccccccc}\\hline",
        "\\# & Dataset & \\#S & \\#F & \\#L & stree & stree default & wodt & "
        "j48svm & oc1 & cart & baseRaF\\\\",
        "\\hline",
    ]
    # Every row ends with a newline; print() then appends one more, which
    # leaves an empty line after the header in the generated file.
    print("\n".join(rows) + "\n", file=file_tex)
def print_line_tex(number, dataset, line, file_tex):
    """Write one row of the LaTeX datasets table to *file_tex*.

    The row holds the running number, the dataset name, the ``samp``,
    ``var`` and ``cls`` entries of *line*, and then one cell per entry of
    the module-level ``models`` list, in that order.

    :param number: row number shown in the first column
    :param dataset: dataset name; underscores are escaped for LaTeX
    :param line: mapping with the keys ``samp``, ``var``, ``cls`` and one
        key per model name
    :param file_tex: writable text file object that receives the row
    """
    escaped_name = dataset.replace("_", "\\_")
    cells = [
        f"{number}",
        f"{escaped_name}",
        f"{line['samp']}",
        f"{line['var']}",
        f"{line['cls']}",
    ]
    cells.extend(f"{line[model]}" for model in models)
    # LaTeX row terminator is a double backslash
    print(" & ".join(cells) + "\\\\", file=file_tex)
def print_footer_tex(file_tex):
    """Write the closing lines of the LaTeX datasets table to *file_tex*.

    Emits a final ``\\hline``, ends the ``tabular`` (the extra ``}`` closes
    the ``\\resizebox`` group) and ends the ``sidewaystable`` environment.

    :param file_tex: writable text file object that receives the footer
    """
    closing = ("\\hline", "\\end{tabular}}", "\\end{sidewaystable}", "")
    # The trailing empty item yields a final newline; print() adds another.
    print("\n".join(closing), file=file_tex)
def report_header_content(title, experiment, model_type):
@@ -110,7 +185,7 @@ def report_footer(agg):
)
(experiment, model_type, csv_output, compare) = parse_arguments()
(experiment, model_type, csv_output, tex_output, compare) = parse_arguments()
dbh = MySQL()
database = dbh.get_connection()
dt = Datasets(False, False, "tanveer")
@@ -123,6 +198,9 @@ fields = (
"Lea",
"Dep",
)
if tex_output:
# The tex table always has a "stree default" column, so force comparison
# mode; otherwise stree_default would be dropped from the model list
compare = True
if not compare:
# remove stree_default from fields list and lengths
models_tree.pop(1)
@@ -140,9 +218,12 @@ for item in [
agg[item] = {}
agg[item]["best"] = 0
if csv_output:
f = open(report_csv, "w")
print("dataset, classifier, accuracy", file=f)
for dataset in dt:
file_csv = open(report_csv, "w")
print("dataset, classifier, accuracy", file=file_csv)
if tex_output:
file_tex = open(table_tex, "w")
print_header_tex(file_tex, second=False)
for number, dataset in enumerate(dt):
find_one = False
# Look for max accuracy for any given dataset
line = {"dataset": color + dataset[0]}
@@ -154,6 +235,7 @@ for dataset in dt:
line["nodes"] = 0
line["leaves"] = 0
line["depth"] = 0
line_tex = line.copy()
for model in models:
record = dbh.find_best(dataset[0], model, experiment)
if record is None:
@@ -168,6 +250,9 @@ for dataset in dt:
acc_std = record[11]
find_one = True
item = f"{accuracy:.4f}±{acc_std:.3f}"
line_tex[model] = item
if round(accuracy, 4) == round(max_accuracy, 4):
line_tex[model] = "\\textbf{" + item + "}"
if accuracy == max_accuracy:
line[model] = (
TextColor.GREEN + TextColor.BOLD + item + TextColor.ENDC
@@ -176,7 +261,12 @@ for dataset in dt:
else:
line[model] = color + item
if csv_output:
print(f"{dataset[0]}, {model}, {accuracy}", file=f)
print(f"{dataset[0]}, {model}, {accuracy}", file=file_csv)
if tex_output:
print_line_tex(number + 1, dataset[0], line_tex, file_tex)
if number == 24:
print_footer_tex(file_tex)
print_header_tex(file_tex, second=True)
if not find_one:
print(TextColor.FAIL + f"*No results found for {dataset[0]}")
else:
@@ -186,6 +276,10 @@ for dataset in dt:
print(report_line(line))
report_footer(agg)
if csv_output:
f.close()
file_csv.close()
print(f"{report_csv} file generated")
if tex_output:
print_footer_tex(file_tex)
file_tex.close()
print(f"{table_tex} file generated")
dbh.close()