# stree_datasets/report_score.py
import argparse
import json
import random
import time
from datetime import datetime

import numpy as np
from sklearn.model_selection import KFold, cross_validate
from sklearn.tree import DecisionTreeClassifier
from stree import Stree
from wodt import TreeClassifier

from experimentation.Sets import Datasets
from experimentation.Database import MySQL

def parse_arguments():
    """Parse the command line and return (set_of_files, model, dataset, sql)."""
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "-S",
        "--set-of-files",
        type=str,
        choices=["aaai", "tanveer"],
        required=False,
        default="tanveer",
    )
    ap.add_argument(
        "-m",
        "--model",
        type=str,
        required=False,
        default="stree_default",
        help="model name, default stree_default",
    )
    ap.add_argument(
        "-d",
        "--dataset",
        type=str,
        required=True,
        help="dataset to process, 'all' to process every dataset in the set",
    )
    ap.add_argument(
        "-s",
        "--sql",
        # type=bool is unreliable in argparse (any non-empty string parses
        # as True), so expose this as a proper on/off flag instead
        action="store_true",
        required=False,
        help="generate <model>.sql with the results",
    )
    args = ap.parse_args()
    return (args.set_of_files, args.model, args.dataset, args.sql)
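
# Example invocations (dataset names depend on the installed set and are
# illustrative only):
#   python report_score.py -d all -m stree -s
#   python report_score.py -d balloons -m cart -S tanveer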

def get_classifier(model, random_state, hyperparameters):
    """Build the classifier requested in model, seeded with random_state."""
    if model == "stree" or model == "stree_default":
        clf = Stree(random_state=random_state)
        clf.set_params(**hyperparameters)
    elif model == "wodt":
        clf = TreeClassifier(random_state=random_state)
    elif model == "cart":
        clf = DecisionTreeClassifier(random_state=random_state)
    else:
        # the original fell through silently and would raise a NameError
        raise ValueError(f"Unknown model: {model}")
    return clf
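
# e.g. get_classifier("cart", 57, {}) returns
# DecisionTreeClassifier(random_state=57); hyperparameters only apply to stree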

def process_dataset(dataset, verbose, model):
    """Cross-validate model on dataset once per random seed.

    Returns the test scores, fit times, the hyperparameters used (as a JSON
    string) and the nodes/leaves/depth of every fitted estimator. Relies on
    the module-level dt, dbh, set_of_files and random_seeds.
    """
    X, y = dt.load(dataset)
    scores = []
    times = []
    nodes = []
    leaves = []
    depths = []
    if verbose:
        print(
            f"* Processing dataset [{dataset}] from Set: {set_of_files} with "
            f"{model}"
        )
        print(f"X.shape: {X.shape}")
        print(f"{X[:4]}")
        print(f"Random seeds: {random_seeds}")
    hyperparameters = {}
    if model == "stree":
        # Get the optimized hyperparameters found by grid search
        # (assumes a gridsearch record exists for this dataset/model)
        record = dbh.find_best(dataset, model, "gridsearch")
        hyperparameters = json.loads(record[8] if record[8] != "" else "{}")
        hyperparameters.pop("random_state", None)
    for random_state in random_seeds:
        random.seed(random_state)
        np.random.seed(random_state)
        kfold = KFold(shuffle=True, random_state=random_state, n_splits=5)
        clf = get_classifier(model, random_state, hyperparameters)
        res = cross_validate(clf, X, y, cv=kfold, return_estimator=True)
        scores.append(res["test_score"])
        times.append(res["fit_time"])
        for result_item in res["estimator"]:
            if model == "cart":
                nodes_item = result_item.tree_.node_count
                depth_item = result_item.tree_.max_depth
                leaves_item = result_item.get_n_leaves()
            else:
                nodes_item, leaves_item = result_item.nodes_leaves()
                depth_item = result_item.depth_
            nodes.append(nodes_item)
            leaves.append(leaves_item)
            depths.append(depth_item)
        if verbose:
            print(
                f"Random seed: {random_state:5d} Accuracy: "
                f"{res['test_score'].mean():6.4f}±"
                f"{res['test_score'].std():6.4f} "
                f"{res['fit_time'].mean():5.3f}s"
            )
    return scores, times, json.dumps(hyperparameters), nodes, leaves, depths

def store_string(
    dataset, model, accuracy, time_spent, hyperparameters, complexity
):
    """Build the SQL statement that stores one result in the database."""
    attributes = [
        "date",
        "time",
        "type",
        "accuracy",
        "accuracy_std",
        "dataset",
        "classifier",
        "norm",
        "stand",
        "time_spent",
        "time_spent_std",
        "parameters",
        "nodes",
        "leaves",
        "depth",
    ]
    command_insert = (
        "replace into results ("
        + ",".join(attributes)
        + ") values("
        + ("'%s'," * len(attributes))[:-1]
        + ");"
    )
    now = datetime.now()
    date = now.strftime("%Y-%m-%d")
    time_stamp = now.strftime("%H:%M:%S")  # avoid shadowing the time module
    # complexity is built with the keys nodes, leaves, depth, in that order
    nodes, leaves, depth = complexity.values()
    values = (
        date,
        time_stamp,
        "crossval",
        np.mean(accuracy),
        np.std(accuracy),
        dataset,
        model,
        1,
        0,
        np.mean(time_spent),
        np.std(time_spent),
        hyperparameters,
        nodes,
        leaves,
        depth,
    )
    result = command_insert % values
    return result
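
# The generated statement has this shape (values illustrative; every field,
# numbers included, is quoted as a string by the '%s' template):
#   replace into results (date,time,type,...,depth)
#   values('2024-05-01','12:00:00','crossval',...,'4.2');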

random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
normalize = True
standardize = False
(set_of_files, model, dataset, sql) = parse_arguments()
dbh = MySQL()
if sql:
    sql_output = open(f"{model}.sql", "w")
database = dbh.get_connection()
dt = Datasets(normalize, standardize, set_of_files)
start = time.time()
if dataset == "all":
    print(
        f"* Process all datasets set with {model}: {set_of_files} "
        f"norm: {normalize} std: {standardize} store in: {model}"
    )
    print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}\n")
    header_cols = [
        "Dataset",
        "Samp",
        "Var",
        "Cls",
        "Nodes",
        "Leaves",
        "Depth",
        "Accuracy",
        "Time",
        "Parameters",
    ]
    header_lengths = [30, 5, 3, 3, 7, 7, 7, 15, 15, 10]
    line_col = ""
    for field, underscore in zip(header_cols, header_lengths):
        print(f"{field:{underscore}s} ", end="")
        line_col += "=" * underscore + " "
    print(f"\n{line_col}")
    for dataset in dt:
        X, y = dt.load(dataset[0])  # type: ignore
        samples, features = X.shape
        classes = len(np.unique(y))
        print(
            f"{dataset[0]:30s} {samples:5d} {features:3d} {classes:3d} ",
            end="",
        )
        scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
            dataset[0], verbose=False, model=model
        )
        complexity = dict(
            nodes=float(np.mean(nodes)),
            leaves=float(np.mean(leaves)),
            depth=float(np.mean(depth)),
        )
        nodes_item, leaves_item, depth_item = complexity.values()
        print(
            f"{nodes_item:7.2f} {leaves_item:7.2f} {depth_item:7.2f} ",
            end="",
        )
        print(f"{np.mean(scores):8.6f}±{np.std(scores):6.4f} ", end="")
        print(f"{np.mean(times):8.6f}±{np.std(times):6.4f} {hyperparameters}")
        if sql:
            command = store_string(
                dataset[0], model, scores, times, hyperparameters, complexity
            )
            print(command, file=sql_output)
else:
    scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
        dataset, verbose=True, model=model
    )
    record = dbh.find_best(dataset, model, "crossval")
    accuracy = np.mean(scores)
    accuracy_best = record[5] if record is not None else 0.0
    acc_best_std = record[11] if record is not None else 0.0
    print(f"* Normalize/Standard.: {normalize} / {standardize}")
    print(
        f"* Accuracy Computed .: {accuracy:6.4f}±{np.std(scores):6.4f} "
        f"{np.mean(times):5.3f}s"
    )
    print(f"* Accuracy Best .....: {accuracy_best:6.4f}±{acc_best_std:6.4f}")
    print(f"* Difference ........: {accuracy_best - accuracy:6.4f}")
    print(
        f"* Nodes/Leaves/Depth : {np.mean(nodes):.2f} {np.mean(leaves):.2f} "
        f"{np.mean(depth):.2f} "
    )
    print(f"- Hyperparameters ...: {hyperparameters}")
stop = time.time()
hours, rem = divmod(stop - start, 3600)
minutes, seconds = divmod(rem, 60)
print(f"Time: {int(hours):2d}h {int(minutes):2d}m {int(seconds):2d}s")
if sql:
    sql_output.close()
dbh.close()