From 47078208bc04f26a5425bde80ae3d92e556d1626 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 23 Mar 2021 00:50:04 +0100 Subject: [PATCH] Add report score for stree update param_analysis for stree only --- .gitignore | 2 + param_analysis.py | 6 +- report.csv | 96 +++++++++---------- report_score.py | 236 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 291 insertions(+), 49 deletions(-) create mode 100644 report_score.py diff --git a/.gitignore b/.gitignore index 39f1558..2fea4c8 100644 --- a/.gitignore +++ b/.gitignore @@ -134,3 +134,5 @@ dmypy.json experimentation/.myconfig experimentation/.tunnel results +report_score.sql +datasets_types \ No newline at end of file diff --git a/param_analysis.py b/param_analysis.py index 5b909ea..e459f0a 100644 --- a/param_analysis.py +++ b/param_analysis.py @@ -14,6 +14,7 @@ class Aggregation: self._dbh = dbh self._report = {} self._model_names = ["stree", "adaBoost", "bagging", "odte"] + self._model_names = ["stree"] self._kernel_names = kernel_names def find_values(self, dataset, parameter): @@ -35,10 +36,13 @@ class Aggregation: print("Aggregating data of best results ...") for dataset in dt: if result := self._dbh.find_best(dataset[0]): + json_string = result[8] if result[8] != "" else "{}" accuracy = result[5] expected = result[10] model = result[3] - json_result = json.loads(result[8]) + if model != "stree": + continue + json_result = json.loads(json_string) if "kernel" in json_result.keys(): kernel = json_result["kernel"] elif "base_estimator__kernel" in json_result.keys(): diff --git a/report.csv b/report.csv index 772818b..c554b81 100644 --- a/report.csv +++ b/report.csv @@ -1,107 +1,107 @@ dataset, classifier, accuracy -balance-scale, stree, 0.97056 +balance-scale, stree, 0.91184 balance-scale, wodt, 0.912 balance-scale, j48svm, 0.94 balance-scale, oc1, 0.9192 balance-scale, cart, 0.78816 balance-scale, baseRaF, 0.706738 -balloons, stree, 0.86 +balloons, stree, 0.653333 
balloons, wodt, 0.688333 balloons, j48svm, 0.595 balloons, oc1, 0.62 balloons, cart, 0.671667 balloons, baseRaF, 0.605 -breast-cancer-wisc-diag, stree, 0.972764 +breast-cancer-wisc-diag, stree, 0.968898 breast-cancer-wisc-diag, wodt, 0.967317 breast-cancer-wisc-diag, j48svm, 0.952878 breast-cancer-wisc-diag, oc1, 0.933477 breast-cancer-wisc-diag, cart, 0.93953 breast-cancer-wisc-diag, baseRaF, 0.965694 -breast-cancer-wisc-prog, stree, 0.811128 +breast-cancer-wisc-prog, stree, 0.802051 breast-cancer-wisc-prog, wodt, 0.710141 breast-cancer-wisc-prog, j48svm, 0.724038 breast-cancer-wisc-prog, oc1, 0.71 breast-cancer-wisc-prog, cart, 0.699833 breast-cancer-wisc-prog, baseRaF, 0.74485 -breast-cancer-wisc, stree, 0.965802 +breast-cancer-wisc, stree, 0.966661 breast-cancer-wisc, wodt, 0.946208 breast-cancer-wisc, j48svm, 0.967674 breast-cancer-wisc, oc1, 0.940194 breast-cancer-wisc, cart, 0.940629 breast-cancer-wisc, baseRaF, 0.942857 -breast-cancer, stree, 0.733158 +breast-cancer, stree, 0.734211 breast-cancer, wodt, 0.650236 breast-cancer, j48svm, 0.707719 breast-cancer, oc1, 0.649728 breast-cancer, cart, 0.65444 breast-cancer, baseRaF, 0.656438 -cardiotocography-10clases, stree, 0.712009 +cardiotocography-10clases, stree, 0.552558 cardiotocography-10clases, wodt, 0.773706 cardiotocography-10clases, j48svm, 0.830812 cardiotocography-10clases, oc1, 0.795528 cardiotocography-10clases, cart, 0.818864 cardiotocography-10clases, baseRaF, 0.774788 -cardiotocography-3clases, stree, 0.891956 +cardiotocography-3clases, stree, 0.35207 cardiotocography-3clases, wodt, 0.897509 cardiotocography-3clases, j48svm, 0.927327 cardiotocography-3clases, oc1, 0.899811 cardiotocography-3clases, cart, 0.929258 cardiotocography-3clases, baseRaF, 0.896715 -conn-bench-sonar-mines-rocks, stree, 0.71439 +conn-bench-sonar-mines-rocks, stree, 0.755528 conn-bench-sonar-mines-rocks, wodt, 0.824959 conn-bench-sonar-mines-rocks, j48svm, 0.73892 conn-bench-sonar-mines-rocks, oc1, 0.710798 
conn-bench-sonar-mines-rocks, cart, 0.728711 conn-bench-sonar-mines-rocks, baseRaF, 0.772981 -cylinder-bands, stree, 0.687101 +cylinder-bands, stree, 0.715049 cylinder-bands, wodt, 0.704074 cylinder-bands, j48svm, 0.726351 cylinder-bands, oc1, 0.67106 cylinder-bands, cart, 0.712703 cylinder-bands, baseRaF, 0.675117 -dermatology, stree, 0.971833 +dermatology, stree, 0.966087 dermatology, wodt, 0.965557 dermatology, j48svm, 0.955735 dermatology, oc1, 0.916087 dermatology, cart, 0.932766 dermatology, baseRaF, 0.970723 -echocardiogram, stree, 0.814758 +echocardiogram, stree, 0.808832 echocardiogram, wodt, 0.733875 echocardiogram, j48svm, 0.805527 echocardiogram, oc1, 0.748291 echocardiogram, cart, 0.745043 echocardiogram, baseRaF, 0.753522 -fertility, stree, 0.88 +fertility, stree, 0.866 fertility, wodt, 0.785 fertility, j48svm, 0.857 fertility, oc1, 0.793 fertility, cart, 0.8 fertility, baseRaF, 0.798 -haberman-survival, stree, 0.727795 +haberman-survival, stree, 0.735637 haberman-survival, wodt, 0.664707 haberman-survival, j48svm, 0.714056 haberman-survival, oc1, 0.651634 haberman-survival, cart, 0.65 haberman-survival, baseRaF, 0.720133 -heart-hungarian, stree, 0.827522 +heart-hungarian, stree, 0.817674 heart-hungarian, wodt, 0.764909 heart-hungarian, j48svm, 0.785026 heart-hungarian, oc1, 0.758298 heart-hungarian, cart, 0.760508 heart-hungarian, baseRaF, 0.779804 -hepatitis, stree, 0.824516 +hepatitis, stree, 0.796129 hepatitis, wodt, 0.785806 hepatitis, j48svm, 0.761935 hepatitis, oc1, 0.756774 hepatitis, cart, 0.765161 hepatitis, baseRaF, 0.773671 -ilpd-indian-liver, stree, 0.719207 +ilpd-indian-liver, stree, 0.723498 ilpd-indian-liver, wodt, 0.676176 ilpd-indian-liver, j48svm, 0.690339 ilpd-indian-liver, oc1, 0.660139 ilpd-indian-liver, cart, 0.663423 ilpd-indian-liver, baseRaF, 0.696685 -ionosphere, stree, 0.953276 +ionosphere, stree, 0.866056 ionosphere, wodt, 0.88008 ionosphere, j48svm, 0.891984 ionosphere, oc1, 0.879742 @@ -113,181 +113,181 @@ iris, j48svm, 
0.947333 iris, oc1, 0.948 iris, cart, 0.938667 iris, baseRaF, 0.953413 -led-display, stree, 0.703 +led-display, stree, 0.7007 led-display, wodt, 0.7049 led-display, j48svm, 0.7204 led-display, oc1, 0.6993 led-display, cart, 0.7037 led-display, baseRaF, 0.70178 -libras, stree, 0.788333 +libras, stree, 0.747778 libras, wodt, 0.764167 libras, j48svm, 0.66 libras, oc1, 0.645 libras, cart, 0.655 libras, baseRaF, 0.726722 -low-res-spect, stree, 0.865713 +low-res-spect, stree, 0.853102 low-res-spect, wodt, 0.856459 low-res-spect, j48svm, 0.83358 low-res-spect, oc1, 0.824671 low-res-spect, cart, 0.829206 low-res-spect, baseRaF, 0.790875 -lymphography, stree, 0.823425 +lymphography, stree, 0.77046 lymphography, wodt, 0.808782 lymphography, j48svm, 0.778552 lymphography, oc1, 0.734634 lymphography, cart, 0.766276 lymphography, baseRaF, 0.761622 -mammographic, stree, 0.817068 +mammographic, stree, 0.81915 mammographic, wodt, 0.759839 mammographic, j48svm, 0.821435 mammographic, oc1, 0.768805 mammographic, cart, 0.757131 mammographic, baseRaF, 0.780206 -molec-biol-promoter, stree, 0.767056 +molec-biol-promoter, stree, 0.764416 molec-biol-promoter, wodt, 0.798528 molec-biol-promoter, j48svm, 0.744935 molec-biol-promoter, oc1, 0.734805 molec-biol-promoter, cart, 0.748701 molec-biol-promoter, baseRaF, 0.667239 -musk-1, stree, 0.916388 +musk-1, stree, 0.843463 musk-1, wodt, 0.838914 musk-1, j48svm, 0.82693 musk-1, oc1, 0.776401 musk-1, cart, 0.780215 musk-1, baseRaF, 0.834034 -oocytes_merluccius_nucleus_4d, stree, 0.835125 +oocytes_merluccius_nucleus_4d, stree, 0.810657 oocytes_merluccius_nucleus_4d, wodt, 0.737673 oocytes_merluccius_nucleus_4d, j48svm, 0.741766 oocytes_merluccius_nucleus_4d, oc1, 0.743199 oocytes_merluccius_nucleus_4d, cart, 0.728265 oocytes_merluccius_nucleus_4d, baseRaF, 0.792313 -oocytes_merluccius_states_2f, stree, 0.87359 +oocytes_merluccius_states_2f, stree, 0.912434 oocytes_merluccius_states_2f, wodt, 0.895115 oocytes_merluccius_states_2f, j48svm, 0.901374 
oocytes_merluccius_states_2f, oc1, 0.889223 oocytes_merluccius_states_2f, cart, 0.891193 oocytes_merluccius_states_2f, baseRaF, 0.910551 -oocytes_trisopterus_nucleus_2f, stree, 0.799995 +oocytes_trisopterus_nucleus_2f, stree, 0.800986 oocytes_trisopterus_nucleus_2f, wodt, 0.751431 oocytes_trisopterus_nucleus_2f, j48svm, 0.756587 oocytes_trisopterus_nucleus_2f, oc1, 0.747697 oocytes_trisopterus_nucleus_2f, cart, 0.734313 oocytes_trisopterus_nucleus_2f, baseRaF, 0.76193 -oocytes_trisopterus_states_5b, stree, 0.924441 +oocytes_trisopterus_states_5b, stree, 0.9023 oocytes_trisopterus_states_5b, wodt, 0.89165 oocytes_trisopterus_states_5b, j48svm, 0.887943 oocytes_trisopterus_states_5b, oc1, 0.86393 oocytes_trisopterus_states_5b, cart, 0.870263 oocytes_trisopterus_states_5b, baseRaF, 0.922149 -parkinsons, stree, 0.865641 +parkinsons, stree, 0.882051 parkinsons, wodt, 0.901538 parkinsons, j48svm, 0.844615 parkinsons, oc1, 0.865641 parkinsons, cart, 0.855897 parkinsons, baseRaF, 0.87924 -pima, stree, 0.764053 +pima, stree, 0.766651 pima, wodt, 0.681591 pima, j48svm, 0.749876 pima, oc1, 0.693027 pima, cart, 0.701172 pima, baseRaF, 0.697005 -pittsburg-bridges-MATERIAL, stree, 0.867749 +pittsburg-bridges-MATERIAL, stree, 0.787446 pittsburg-bridges-MATERIAL, wodt, 0.79961 pittsburg-bridges-MATERIAL, j48svm, 0.855844 pittsburg-bridges-MATERIAL, oc1, 0.81026 pittsburg-bridges-MATERIAL, cart, 0.783593 pittsburg-bridges-MATERIAL, baseRaF, 0.81136 -pittsburg-bridges-REL-L, stree, 0.564048 +pittsburg-bridges-REL-L, stree, 0.62519 pittsburg-bridges-REL-L, wodt, 0.617143 pittsburg-bridges-REL-L, j48svm, 0.645048 pittsburg-bridges-REL-L, oc1, 0.604957 pittsburg-bridges-REL-L, cart, 0.625333 pittsburg-bridges-REL-L, baseRaF, 0.622107 -pittsburg-bridges-SPAN, stree, 0.658713 +pittsburg-bridges-SPAN, stree, 0.630234 pittsburg-bridges-SPAN, wodt, 0.606959 pittsburg-bridges-SPAN, j48svm, 0.621579 pittsburg-bridges-SPAN, oc1, 0.579333 pittsburg-bridges-SPAN, cart, 0.557544 
pittsburg-bridges-SPAN, baseRaF, 0.630217 -pittsburg-bridges-T-OR-D, stree, 0.849952 +pittsburg-bridges-T-OR-D, stree, 0.861619 pittsburg-bridges-T-OR-D, wodt, 0.818429 pittsburg-bridges-T-OR-D, j48svm, 0.838333 pittsburg-bridges-T-OR-D, oc1, 0.831545 pittsburg-bridges-T-OR-D, cart, 0.821619 pittsburg-bridges-T-OR-D, baseRaF, 0.821007 -planning, stree, 0.73527 +planning, stree, 0.70455 planning, wodt, 0.576847 planning, j48svm, 0.711381 planning, oc1, 0.566988 planning, cart, 0.586712 planning, baseRaF, 0.590586 -post-operative, stree, 0.703333 +post-operative, stree, 0.573333 post-operative, wodt, 0.535556 post-operative, j48svm, 0.701111 post-operative, oc1, 0.542222 post-operative, cart, 0.567778 post-operative, baseRaF, 0.539375 -seeds, stree, 0.952857 +seeds, stree, 0.949048 seeds, wodt, 0.940476 seeds, j48svm, 0.909524 seeds, oc1, 0.932381 seeds, cart, 0.900476 seeds, baseRaF, 0.942518 -statlog-australian-credit, stree, 0.678261 +statlog-australian-credit, stree, 0.667246 statlog-australian-credit, wodt, 0.561594 statlog-australian-credit, j48svm, 0.66029 statlog-australian-credit, oc1, 0.573913 statlog-australian-credit, cart, 0.595507 statlog-australian-credit, baseRaF, 0.678261 -statlog-german-credit, stree, 0.7569 +statlog-german-credit, stree, 0.7625 statlog-german-credit, wodt, 0.6929 statlog-german-credit, j48svm, 0.7244 statlog-german-credit, oc1, 0.6874 statlog-german-credit, cart, 0.6738 statlog-german-credit, baseRaF, 0.68762 -statlog-heart, stree, 0.822222 +statlog-heart, stree, 0.822963 statlog-heart, wodt, 0.777778 statlog-heart, j48svm, 0.795926 statlog-heart, oc1, 0.749259 statlog-heart, cart, 0.762222 statlog-heart, baseRaF, 0.747605 -statlog-image, stree, 0.956623 +statlog-image, stree, 0.850649 statlog-image, wodt, 0.954632 statlog-image, j48svm, 0.967403 statlog-image, oc1, 0.95013 statlog-image, cart, 0.964892 statlog-image, baseRaF, 0.953604 -statlog-vehicle, stree, 0.788537 +statlog-vehicle, stree, 0.695151 statlog-vehicle, wodt, 
0.726492 statlog-vehicle, j48svm, 0.729651 statlog-vehicle, oc1, 0.708496 statlog-vehicle, cart, 0.728367 statlog-vehicle, baseRaF, 0.789572 -synthetic-control, stree, 0.95 +synthetic-control, stree, 0.938833 synthetic-control, wodt, 0.973167 synthetic-control, j48svm, 0.922333 synthetic-control, oc1, 0.863167 synthetic-control, cart, 0.908333 synthetic-control, baseRaF, 0.971567 -tic-tac-toe, stree, 0.984444 +tic-tac-toe, stree, 0.983296 tic-tac-toe, wodt, 0.93905 tic-tac-toe, j48svm, 0.983295 tic-tac-toe, oc1, 0.91849 tic-tac-toe, cart, 0.951558 tic-tac-toe, baseRaF, 0.974906 -vertebral-column-2clases, stree, 0.851936 +vertebral-column-2clases, stree, 0.852903 vertebral-column-2clases, wodt, 0.801935 vertebral-column-2clases, j48svm, 0.84871 vertebral-column-2clases, oc1, 0.815161 vertebral-column-2clases, cart, 0.784839 vertebral-column-2clases, baseRaF, 0.822601 -wine, stree, 0.949333 +wine, stree, 0.97581 wine, wodt, 0.973048 wine, j48svm, 0.979143 wine, oc1, 0.916165 wine, cart, 0.921937 wine, baseRaF, 0.97748 -zoo, stree, 0.955524 +zoo, stree, 0.947619 zoo, wodt, 0.954429 zoo, j48svm, 0.92381 zoo, oc1, 0.890952 diff --git a/report_score.py b/report_score.py new file mode 100644 index 0000000..714be80 --- /dev/null +++ b/report_score.py @@ -0,0 +1,236 @@ +import argparse +import random +import time +from datetime import datetime +import json +import numpy as np +from stree import Stree +from sklearn.model_selection import KFold, cross_validate +from experimentation.Sets import Datasets +from experimentation.Database import MySQL + +8 + + +def parse_arguments(): + ap = argparse.ArgumentParser() + ap.add_argument( + "-S", + "--set-of-files", + type=str, + choices=["aaai", "tanveer"], + required=False, + default="tanveer", + ) + ap.add_argument( + "-m", + "--model", + type=str, + required=False, + default="stree", + help="model name, default stree", + ) + ap.add_argument( + "-d", + "--dataset", + type=str, + required=True, + help="dataset to process, all for 
everyone", + ) + ap.add_argument( + "-s", + "--sql", + default=False, + type=bool, + required=False, + help="generate report_score.sql", + ) + ap.add_argument( + "-p", + "--param", + default=False, + type=bool, + required=False, + help="Auto generate params", + ) + args = ap.parse_args() + return (args.set_of_files, args.model, args.dataset, args.sql, args.param) + + +def compute_auto_hyperparams(X, y): + params = {"max_iter": 1e4, "C": 0.1} + classes = len(np.unique(y)) + if classes > 2: + params["split_criteria"] = "max_samples" + return params + + +def process_dataset(dataset, verbose, model, auto_params): + X, y = dt.load(dataset) + scores = [] + times = [] + if verbose: + print( + f"* Processing dataset [{dataset}] from Set: {set_of_files} with " + f"{model}" + ) + print(f"X.shape: {X.shape}") + print(f"{X[:4]}") + print(f"Random seeds: {random_seeds}") + if auto_params: + hyperparameters = compute_auto_hyperparams(X, y) + else: + hyperparameters = {} + for random_state in random_seeds: + random.seed(random_state) + np.random.seed(random_state) + kfold = KFold(shuffle=True, random_state=random_state, n_splits=5) + clf = Stree(random_state=random_state) + clf.set_params(**hyperparameters) + res = cross_validate(clf, X, y, cv=kfold) + scores.append(res["test_score"]) + times.append(res["fit_time"]) + if verbose: + print( + f"Random seed: {random_state:5d} Accuracy: " + f"{res['test_score'].mean():6.4f}±" + f"{res['test_score'].std():6.4f} " + f"{res['fit_time'].mean():5.3f}s" + ) + return scores, times, json.dumps(hyperparameters) + + +def store_string(dataset, model, accuracy, time_spent, hyperparameters): + attributes = [ + "date", + "time", + "type", + "accuracy", + "accuracy_std", + "dataset", + "classifier", + "norm", + "stand", + "time_spent", + "time_spent_std", + "parameters", + ] + command_insert = ( + "replace into results (" + + ",".join(attributes) + + ") values(" + + ("'%s'," * len(attributes))[:-1] + + ");" + ) + now = datetime.now() + date = 
now.strftime("%Y-%m-%d") + time = now.strftime("%H:%M:%S") + values = ( + date, + time, + "crossval", + np.mean(accuracy), + np.std(accuracy), + dataset, + model, + True, + False, + np.mean(time_spent), + np.std(time_spent), + hyperparameters, + ) + result = command_insert % values + return result + + +random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] +normalize = True +standardize = False +(set_of_files, model, dataset, sql, auto_params) = parse_arguments() +dbh = MySQL() +if sql: + sql_output = open("report_score.sql", "w") +database = dbh.get_connection() +dt = Datasets(normalize, standardize, set_of_files) +start = time.time() +if dataset == "all": + print( + f"* Process all datasets set with {model}: {set_of_files} " + f"norm: {normalize} std: {standardize}" + ) + print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}\n") + print( + "{0:30s} {5:4s} {6:3s} {7:2s} {1:13s} {2:13s} {3:8s} {4:90s}".format( + "Dataset", + "Acc. computed", + "Best Accuracy", + "Diff.", + "Best accuracy hyperparameters", + "Samp", + "Var", + "Cls", + ) + ) + print("=" * 30, end=" ") + print("=" * 4, end=" ") + print("=" * 3, end=" ") + print("=" * 3, end=" ") + print("=" * 13, end=" ") + print("=" * 13, end=" ") + print("=" * 8, end=" ") + print("=" * 90) + for dataset in dt: + X, y = dt.load(dataset[0]) # type: ignore + samples, features = X.shape + classes = len(np.unique(y)) + print( + f"{dataset[0]:30s} {samples:4d} {features:3d} " f"{classes:3d} ", + end="", + ) + scores, times, hyperparameters = process_dataset( + dataset[0], verbose=False, model=model, auto_params=auto_params + ) + record = dbh.find_best(dataset[0], model, "crossval") + if record is not None: + parameters = json.loads(record[8] if record[8] != "" else "{}") + parameters.pop("random_state", None) + accuracy_best = record[5] + acc_best_std = record[11] + else: + parameters = json.loads("{}") + accuracy_best = 0.0 + acc_best_std = 0.0 + accuracy_computed = np.mean(scores) + diff = accuracy_best 
- accuracy_computed + print( + f"{accuracy_computed:6.4f}±{np.std(scores):6.4f} " + f"{accuracy_best:6.4f}±{acc_best_std:6.4f} {diff:8.5f} " + f"{json.dumps(parameters):40s}" + ) + if sql: + command = store_string( + dataset[0], model, scores, times, hyperparameters + ) + print(command, file=sql_output) +else: + scores, times, hyperparameters = process_dataset( + dataset, verbose=True, model=model, auto_params=auto_params + ) + record = dbh.find_best(dataset, model, "crossval") + accuracy = np.mean(scores) + accuracy_best = record[5] if record is not None else 0.0 + acc_best_std = record[11] if record is not None else 0.0 + print( + f"* Accuracy Computed : {accuracy:6.4f}±{np.std(scores):6.4f} " + f"{np.mean(times):5.3f}s" + ) + print(f"* Accuracy Best ....: {accuracy_best:6.4f}±{acc_best_std:6.4f}") + print(f"* Difference .......: {accuracy_best - accuracy:6.4f}") +stop = time.time() +print(f"- Auto Hyperparams .: {hyperparameters}") +hours, rem = divmod(stop - start, 3600) +minutes, seconds = divmod(rem, 60) +print(f"Time: {int(hours):2d}h {int(minutes):2d}m {int(seconds):2d}s") +if sql: + sql_output.close() +dbh.close()