mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-15 23:46:03 +00:00
Add nodes, leaves and depth to report_score
This commit is contained in:
@@ -9,8 +9,6 @@ from sklearn.model_selection import KFold, cross_validate
|
||||
from experimentation.Sets import Datasets
|
||||
from experimentation.Database import MySQL
|
||||
|
||||
8
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
ap = argparse.ArgumentParser()
|
||||
@@ -57,6 +55,17 @@ def parse_arguments():
|
||||
return (args.set_of_files, args.model, args.dataset, args.sql, args.param)
|
||||
|
||||
|
||||
def nodes_leaves(clf):
|
||||
nodes = 0
|
||||
leaves = 0
|
||||
for node in clf:
|
||||
if node.is_leaf():
|
||||
leaves += 1
|
||||
else:
|
||||
nodes += 1
|
||||
return nodes, leaves
|
||||
|
||||
|
||||
def compute_auto_hyperparams(X, y):
|
||||
params = {"max_iter": 1e4, "C": 0.1}
|
||||
classes = len(np.unique(y))
|
||||
@@ -87,7 +96,9 @@ def process_dataset(dataset, verbose, model, auto_params):
|
||||
kfold = KFold(shuffle=True, random_state=random_state, n_splits=5)
|
||||
clf = Stree(random_state=random_state)
|
||||
clf.set_params(**hyperparameters)
|
||||
res = cross_validate(clf, X, y, cv=kfold)
|
||||
res = cross_validate(clf, X, y, cv=kfold, return_estimator=True)
|
||||
nodes, leaves = nodes_leaves(res["estimator"][0])
|
||||
depth = res["estimator"][0].depth_
|
||||
scores.append(res["test_score"])
|
||||
times.append(res["fit_time"])
|
||||
if verbose:
|
||||
@@ -97,7 +108,7 @@ def process_dataset(dataset, verbose, model, auto_params):
|
||||
f"{res['test_score'].std():6.4f} "
|
||||
f"{res['fit_time'].mean():5.3f}s"
|
||||
)
|
||||
return scores, times, json.dumps(hyperparameters)
|
||||
return scores, times, json.dumps(hyperparameters), nodes, leaves, depth
|
||||
|
||||
|
||||
def store_string(dataset, model, accuracy, time_spent, hyperparameters):
|
||||
@@ -160,7 +171,8 @@ if dataset == "all":
|
||||
)
|
||||
print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}\n")
|
||||
print(
|
||||
"{0:30s} {5:4s} {6:3s} {7:2s} {1:13s} {2:13s} {3:8s} {4:90s}".format(
|
||||
"{0:30s} {5:4s} {6:3s} {7:2s} {8:2s} {9:2s} {10:2s} {1:13s} {2:13s} "
|
||||
"{3:8s} {4:90s}".format(
|
||||
"Dataset",
|
||||
"Acc. computed",
|
||||
"Best Accuracy",
|
||||
@@ -169,12 +181,18 @@ if dataset == "all":
|
||||
"Samp",
|
||||
"Var",
|
||||
"Cls",
|
||||
"N",
|
||||
"L",
|
||||
"D",
|
||||
)
|
||||
)
|
||||
print("=" * 30, end=" ")
|
||||
print("=" * 4, end=" ")
|
||||
print("=" * 3, end=" ")
|
||||
print("=" * 3, end=" ")
|
||||
print("=" * 2, end=" ")
|
||||
print("=" * 2, end=" ")
|
||||
print("=" * 2, end=" ")
|
||||
print("=" * 13, end=" ")
|
||||
print("=" * 13, end=" ")
|
||||
print("=" * 8, end=" ")
|
||||
@@ -187,9 +205,13 @@ if dataset == "all":
|
||||
f"{dataset[0]:30s} {samples:4d} {features:3d} " f"{classes:3d} ",
|
||||
end="",
|
||||
)
|
||||
scores, times, hyperparameters = process_dataset(
|
||||
scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
|
||||
dataset[0], verbose=False, model=model, auto_params=auto_params
|
||||
)
|
||||
print(
|
||||
f"{nodes:2d} {leaves:2d} " f"{depth:2d} ",
|
||||
end="",
|
||||
)
|
||||
record = dbh.find_best(dataset[0], model, "crossval")
|
||||
if record is not None:
|
||||
parameters = json.loads(record[8] if record[8] != "" else "{}")
|
||||
@@ -213,7 +235,7 @@ if dataset == "all":
|
||||
)
|
||||
print(command, file=sql_output)
|
||||
else:
|
||||
scores, times, hyperparameters = process_dataset(
|
||||
scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
|
||||
dataset, verbose=True, model=model, auto_params=auto_params
|
||||
)
|
||||
record = dbh.find_best(dataset, model, "crossval")
|
||||
@@ -224,10 +246,11 @@ else:
|
||||
f"* Accuracy Computed : {accuracy:6.4f}±{np.std(scores):6.4f} "
|
||||
f"{np.mean(times):5.3f}s"
|
||||
)
|
||||
print(f"* Accuracy Best ....: {accuracy_best:6.4f}±{acc_best_std:6.4f}")
|
||||
print(f"* Difference .......: {accuracy_best - accuracy:6.4f}")
|
||||
print(f"* Accuracy Best .....: {accuracy_best:6.4f}±{acc_best_std:6.4f}")
|
||||
print(f"* Difference ........: {accuracy_best - accuracy:6.4f}")
|
||||
print(f"* Nodes/Leaves/Depth :{nodes:2d} {leaves:2d} " f"{depth:2d} ")
|
||||
stop = time.time()
|
||||
print(f"- Auto Hyperparams .: {hyperparameters}")
|
||||
print(f"- Auto Hyperparams ..: {hyperparameters}")
|
||||
hours, rem = divmod(stop - start, 3600)
|
||||
minutes, seconds = divmod(rem, 60)
|
||||
print(f"Time: {int(hours):2d}h {int(minutes):2d}m {int(seconds):2d}s")
|
||||
|
Reference in New Issue
Block a user