Add ticks to report_score

This commit is contained in:
2021-04-27 09:50:02 +02:00
parent b061d40355
commit d9f5bfee6c
2 changed files with 446 additions and 200 deletions

View File

@@ -10,6 +10,7 @@ from sklearn.model_selection import KFold, cross_validate
from experimentation.Sets import Datasets
from experimentation.Database import MySQL
from wodt import TreeClassifier
from experimentation.Utils import TextColor
def parse_arguments():
@@ -178,7 +179,36 @@ def store_string(
return result
def compute_status(dbh, name, model, accuracy):
    """Return a one-character marker ranking ``accuracy`` against stored bests.

    Three reference scores are fetched through ``get_best_score`` for the
    dataset ``name``: the best of ``model``, the best of ``"stree"`` and the
    best over every entry of ``models_tree``.  The strongest reference that
    ``accuracy`` reaches (``>=``) decides the marker:

    * best over ``models_tree`` -> black star wrapped in ``TextColor.RED``
    * best of ``"stree"``       -> heavy check mark wrapped in
      ``TextColor.GREEN``
    * best of ``model``         -> plain heavy check mark
    * none of them              -> a single space

    ``get_best_score`` reports 0.0 when nothing is stored, so a missing
    reference is reached by any non-negative accuracy.
    """
    tick = "\N{heavy check mark}"
    green_tick = TextColor.GREEN + tick + TextColor.ENDC
    red_star = TextColor.RED + "\N{black star}" + TextColor.ENDC

    # All three lookups are issued up front, in this order, before ranking.
    model_best, _ = get_best_score(dbh, name, model)
    stree_best, _ = get_best_score(dbh, name, "stree")
    overall_best, _ = get_best_score(dbh, name, models_tree)

    # Check from the most demanding reference down to the least demanding.
    if accuracy >= overall_best:
        return red_star
    if accuracy >= stree_best:
        return green_tick
    if accuracy >= model_best:
        return tick
    return " "
def get_best_score(dbh, name, model):
    """Fetch the best stored "crossval" score for a dataset/model pair.

    Calls ``dbh.find_best(name, model, "crossval")`` and returns a
    ``(accuracy, acc_std)`` tuple taken from positions 5 and 11 of the
    returned record.  When no record is found (``None``) both values are
    0.0.

    NOTE(review): the meaning of columns 5 and 11 is taken from how the
    values are used here; confirm against the ``find_best`` query.
    """
    best = dbh.find_best(name, model, "crossval")
    if best is None:
        return 0.0, 0.0
    return best[5], best[11]
# Fixed list of integer seeds; presumably one per experiment repetition so
# runs are reproducible -- TODO confirm where it is consumed.
random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
# Model names handed as a group to get_best_score() when looking for the best
# stored result across all models rather than for a single one.
models_tree = [
"stree",
"stree_default",
"wodt",
"j48svm",
"oc1",
"cart",
"baseRaF",
]
# Reported next to `normalize` in the single-dataset output.
standardize = False
# parse_arguments() yields six values, unpacked here into module-level names.
(set_of_files, model, dataset, sql, normalize, parameters) = parse_arguments()
# Database handle used for the best-score lookups below.
dbh = MySQL()
@@ -206,17 +236,22 @@ if dataset == "all":
"Parameters",
]
header_lengths = [30, 5, 3, 3, 7, 7, 7, 15, 15, 10]
parameters = json.dumps(json.loads(parameters))
if parameters != "{}" and len(parameters) > 10:
header_lengths.pop()
header_lengths.append(len(parameters))
line_col = ""
for field, underscore in zip(header_cols, header_lengths):
print(f"{field:{underscore}s} ", end="")
line_col += "=" * underscore + " "
print(f"\n{line_col}")
for dataset in dt:
X, y = dt.load(dataset[0]) # type: ignore
name = dataset[0]
X, y = dt.load(name) # type: ignore
samples, features = X.shape
classes = len(np.unique(y))
print(
f"{dataset[0]:30s} {samples:5d} {features:3d} {classes:3d} ",
f"{name:30s} {samples:5d} {features:3d} {classes:3d} ",
end="",
)
scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
@@ -232,28 +267,35 @@ if dataset == "all":
f"{nodes_item:7.2f} {leaves_item:7.2f} {depth_item:7.2f} ",
end="",
)
print(f"{np.mean(scores):8.6f}±{np.std(scores):6.4f} ", end="")
accuracy = np.mean(scores)
status = (
compute_status(dbh, name, model, accuracy)
if model == "stree_default"
else " "
)
print(f"{accuracy:8.6f}±{np.std(scores):6.4f}{status}", end="")
print(f"{np.mean(times):8.6f}±{np.std(times):6.4f} {hyperparameters}")
if sql:
command = store_string(
dataset[0], model, scores, times, hyperparameters, complexity
name, model, scores, times, hyperparameters, complexity
)
print(command, file=sql_output)
else:
scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
dataset, verbose=True, model=model, params=parameters
)
record = dbh.find_best(dataset, model, "crossval")
best_accuracy, acc_best_std = get_best_score(dbh, dataset, model)
accuracy = np.mean(scores)
accuracy_best = record[5] if record is not None else 0.0
acc_best_std = record[11] if record is not None else 0.0
print(f"* Normalize/Standard.: {normalize} / {standardize}")
print(
f"* Accuracy Computed .: {accuracy:6.4f}±{np.std(scores):6.4f} "
f"{np.mean(times):5.3f}s"
)
print(f"* Accuracy Best .....: {accuracy_best:6.4f}±{acc_best_std:6.4f}")
print(f"* Difference ........: {accuracy_best - accuracy:6.4f}")
print(f"* Best Accuracy model: {best_accuracy:6.4f}±{acc_best_std:6.4f}")
print(f"* Difference ........: {best_accuracy - accuracy:6.4f}")
best_accuracy, acc_best_std = get_best_score(dbh, dataset, models_tree)
print(f"* Best Accuracy .....: {best_accuracy:6.4f}±{acc_best_std:6.4f}")
print(f"* Difference ........: {best_accuracy - accuracy:6.4f}")
print(
f"* Nodes/Leaves/Depth : {np.mean(nodes):.2f} {np.mean(leaves):.2f} "
f"{np.mean(depth):.2f} "