diff --git a/analysis_mysql.py b/analysis_mysql.py index 9c6a8bf..4453da6 100644 --- a/analysis_mysql.py +++ b/analysis_mysql.py @@ -1,5 +1,6 @@ import argparse from typing import Tuple +import numpy as np from experimentation.Sets import Datasets from experimentation.Utils import TextColor from experimentation.Database import MySQL @@ -14,8 +15,10 @@ models_tree = [ "baseRaF", ] models_ensemble = ["odte", "adaBoost", "bagging", "TBRaF", "TBRoF", "TBRRoF"] +description = ["samp", "var", "cls"] +complexity = ["nodes", "leaves", "depth"] title = "Best model results" -lengths = (30, 12, 12, 12, 12, 12, 12) +lengths = (30, 4, 3, 3, 3, 3, 3, 12, 12, 12, 12, 12, 12) def parse_arguments() -> Tuple[str, str, str, bool, bool]: @@ -79,9 +82,11 @@ def report_header(title, experiment, model_type): def report_line(line): output = f"{line['dataset']:{lengths[0] + 5}s} " + for key, item in enumerate(description + complexity): + output += f"{line[item]:{lengths[key + 1]}d} " data = models.copy() for key, model in enumerate(data): - output += f"{line[model]:{lengths[key + 1]}s} " + output += f"{line[model]:{lengths[key + 7]}s} " return output @@ -101,7 +106,15 @@ def report_footer(agg): dbh = MySQL() database = dbh.get_connection() dt = Datasets(False, False, "tanveer") -fields = ("Dataset",) +fields = ( + "Dataset", + "Samp", + "Var", + "Cls", + "Nod", + "Lea", + "Dep", +) models = models_tree if model_type == "tree" else models_ensemble for item in models: fields += (f"{item}",) @@ -121,13 +134,23 @@ for dataset in dt: find_one = False # Look for max accuracy for any given dataset line = {"dataset": color + dataset[0]} + X, y = dt.load(dataset[0]) # type: ignore + line["samp"], line["var"] = X.shape + line["cls"] = len(np.unique(y)) record = dbh.find_best(dataset[0], models, experiment) max_accuracy = 0.0 if record is None else record[5] + line["nodes"] = 0 + line["leaves"] = 0 + line["depth"] = 0 for model in models: record = dbh.find_best(dataset[0], model, experiment) if record is None: line[model] = color + "-" * 12 else: + if model == "stree": + line["nodes"] = record[12] + line["leaves"] = record[13] + line["depth"] = record[14] reference = record[13] accuracy = record[5] acc_std = record[11] diff --git a/experimentation/Database.py b/experimentation/Database.py index ca1b90b..5a2cfe3 100644 --- a/experimentation/Database.py +++ b/experimentation/Database.py @@ -176,6 +176,7 @@ class BD(ABC): accuracy, time_spent, parameters, + complexity, ) -> None: """Create a record in MySQL database @@ -187,8 +188,8 @@ class BD(ABC): command_insert = ( "replace into results (date, time, type, accuracy, " "dataset, classifier, norm, stand, parameters, accuracy_std, " - "time_spent, time_spent_std) values (%s, %s, " - "%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)" + "time_spent, time_spent_std, nodes, leaves, depth) values (%s, %s," + " %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)" ) now = datetime.now() date = now.strftime("%Y-%m-%d") @@ -206,6 +207,9 @@ class BD(ABC): accuracy[1], time_spent[0], time_spent[1], + complexity["nodes"], + complexity["leaves"], + complexity["depth"], ) cursor = database.cursor() cursor.execute(command_insert, values) @@ -319,7 +323,9 @@ class Outcomes(BD): self._table = "outcomes" super().__init__(host=host, model=model) - def store(self, dataset, normalize, standardize, parameters, results): + def store( + self, dataset, normalize, standardize, parameters, results, complexity + ): outcomes = ["fit_time", "score_time", "train_score", "test_score"] data = "" for index in outcomes: @@ -350,6 +356,7 @@ class Outcomes(BD): float(np.std(results["fit_time"])), ], parameters, + complexity, ) def report(self, dataset, exclude_params): diff --git a/experimentation/Experiments.py b/experimentation/Experiments.py index 3227df3..732ee9c 100644 --- a/experimentation/Experiments.py +++ b/experimentation/Experiments.py @@ -2,9 +2,8 @@ import json import os import time import warnings - +import numpy as np from sklearn.model_selection import GridSearchCV, cross_validate - from . import Models from .Database import Hyperparameters, MySQL, Outcomes from .Sets import Datasets @@ -94,15 +93,25 @@ class Experiment: X, y, return_train_score=True, + return_estimator=True, n_jobs=self._threads, cv=kfold, ) for item in outcomes: total[item].append(results[item]) print("end") + if type(model).__name__ == "Stree": + best_model = results["estimator"][np.argmax(results["test_score"])] + nodes, leaves = best_model.nodes_leaves() + depth = best_model.depth_ + else: + nodes = leaves = depth = 0 + complexity = dict(nodes=nodes, leaves=leaves, depth=depth) outcomes = Outcomes(host=self._host, model=self._model_name) parameters = json.dumps(parameters, sort_keys=True) - outcomes.store(dataset, normalize, standardize, parameters, total) + outcomes.store( + dataset, normalize, standardize, parameters, total, complexity + ) if self._num_warnings > 0: print(f"{self._num_warnings} warnings have happend") diff --git a/report.csv b/report.csv index c554b81..15a523f 100644 --- a/report.csv +++ b/report.csv @@ -1,107 +1,107 @@ dataset, classifier, accuracy -balance-scale, stree, 0.91184 +balance-scale, stree, 0.97056 balance-scale, wodt, 0.912 balance-scale, j48svm, 0.94 balance-scale, oc1, 0.9192 balance-scale, cart, 0.78816 balance-scale, baseRaF, 0.706738 -balloons, stree, 0.653333 +balloons, stree, 0.86 balloons, wodt, 0.688333 balloons, j48svm, 0.595 balloons, oc1, 0.62 balloons, cart, 0.671667 balloons, baseRaF, 0.605 -breast-cancer-wisc-diag, stree, 0.968898 +breast-cancer-wisc-diag, stree, 0.972764 breast-cancer-wisc-diag, wodt, 0.967317 breast-cancer-wisc-diag, j48svm, 0.952878 breast-cancer-wisc-diag, oc1, 0.933477 breast-cancer-wisc-diag, cart, 0.93953 breast-cancer-wisc-diag, baseRaF, 0.965694 -breast-cancer-wisc-prog, stree, 0.802051 +breast-cancer-wisc-prog, stree, 0.811128 breast-cancer-wisc-prog, wodt, 0.710141 breast-cancer-wisc-prog, j48svm, 0.724038 breast-cancer-wisc-prog, oc1, 0.71 breast-cancer-wisc-prog, cart, 0.699833 breast-cancer-wisc-prog, baseRaF, 0.74485 -breast-cancer-wisc, stree, 0.966661 +breast-cancer-wisc, stree, 0.965802 breast-cancer-wisc, wodt, 0.946208 breast-cancer-wisc, j48svm, 0.967674 breast-cancer-wisc, oc1, 0.940194 breast-cancer-wisc, cart, 0.940629 breast-cancer-wisc, baseRaF, 0.942857 -breast-cancer, stree, 0.734211 +breast-cancer, stree, 0.733158 breast-cancer, wodt, 0.650236 breast-cancer, j48svm, 0.707719 breast-cancer, oc1, 0.649728 breast-cancer, cart, 0.65444 breast-cancer, baseRaF, 0.656438 -cardiotocography-10clases, stree, 0.552558 +cardiotocography-10clases, stree, 0.712009 cardiotocography-10clases, wodt, 0.773706 cardiotocography-10clases, j48svm, 0.830812 cardiotocography-10clases, oc1, 0.795528 cardiotocography-10clases, cart, 0.818864 cardiotocography-10clases, baseRaF, 0.774788 -cardiotocography-3clases, stree, 0.35207 +cardiotocography-3clases, stree, 0.891956 cardiotocography-3clases, wodt, 0.897509 cardiotocography-3clases, j48svm, 0.927327 cardiotocography-3clases, oc1, 0.899811 cardiotocography-3clases, cart, 0.929258 cardiotocography-3clases, baseRaF, 0.896715 -conn-bench-sonar-mines-rocks, stree, 0.755528 +conn-bench-sonar-mines-rocks, stree, 0.71439 conn-bench-sonar-mines-rocks, wodt, 0.824959 conn-bench-sonar-mines-rocks, j48svm, 0.73892 conn-bench-sonar-mines-rocks, oc1, 0.710798 conn-bench-sonar-mines-rocks, cart, 0.728711 conn-bench-sonar-mines-rocks, baseRaF, 0.772981 -cylinder-bands, stree, 0.715049 +cylinder-bands, stree, 0.687101 cylinder-bands, wodt, 0.704074 cylinder-bands, j48svm, 0.726351 cylinder-bands, oc1, 0.67106 cylinder-bands, cart, 0.712703 cylinder-bands, baseRaF, 0.675117 -dermatology, stree, 0.966087 +dermatology, stree, 0.971833 dermatology, wodt, 0.965557 dermatology, j48svm, 0.955735 dermatology, oc1, 0.916087 dermatology, cart, 0.932766 dermatology, baseRaF, 0.970723 -echocardiogram, stree, 0.808832 +echocardiogram, stree, 0.814758 echocardiogram, wodt, 0.733875 echocardiogram, j48svm, 0.805527 echocardiogram, oc1, 0.748291 echocardiogram, cart, 0.745043 echocardiogram, baseRaF, 0.753522 -fertility, stree, 0.866 +fertility, stree, 0.88 fertility, wodt, 0.785 fertility, j48svm, 0.857 fertility, oc1, 0.793 fertility, cart, 0.8 fertility, baseRaF, 0.798 -haberman-survival, stree, 0.735637 +haberman-survival, stree, 0.727795 haberman-survival, wodt, 0.664707 haberman-survival, j48svm, 0.714056 haberman-survival, oc1, 0.651634 haberman-survival, cart, 0.65 haberman-survival, baseRaF, 0.720133 -heart-hungarian, stree, 0.817674 +heart-hungarian, stree, 0.827522 heart-hungarian, wodt, 0.764909 heart-hungarian, j48svm, 0.785026 heart-hungarian, oc1, 0.758298 heart-hungarian, cart, 0.760508 heart-hungarian, baseRaF, 0.779804 -hepatitis, stree, 0.796129 +hepatitis, stree, 0.824516 hepatitis, wodt, 0.785806 hepatitis, j48svm, 0.761935 hepatitis, oc1, 0.756774 hepatitis, cart, 0.765161 hepatitis, baseRaF, 0.773671 -ilpd-indian-liver, stree, 0.723498 +ilpd-indian-liver, stree, 0.719207 ilpd-indian-liver, wodt, 0.676176 ilpd-indian-liver, j48svm, 0.690339 ilpd-indian-liver, oc1, 0.660139 ilpd-indian-liver, cart, 0.663423 ilpd-indian-liver, baseRaF, 0.696685 -ionosphere, stree, 0.866056 +ionosphere, stree, 0.953276 ionosphere, wodt, 0.88008 ionosphere, j48svm, 0.891984 ionosphere, oc1, 0.879742 @@ -113,49 +113,49 @@ iris, j48svm, 0.947333 iris, oc1, 0.948 iris, cart, 0.938667 iris, baseRaF, 0.953413 -led-display, stree, 0.7007 +led-display, stree, 0.703 led-display, wodt, 0.7049 led-display, j48svm, 0.7204 led-display, oc1, 0.6993 led-display, cart, 0.7037 led-display, baseRaF, 0.70178 -libras, stree, 0.747778 +libras, stree, 0.788333 libras, wodt, 0.764167 libras, j48svm, 0.66 libras, oc1, 0.645 libras, cart, 0.655 libras, baseRaF, 0.726722 -low-res-spect, stree, 0.853102 +low-res-spect, stree, 0.865713 low-res-spect, wodt, 0.856459 low-res-spect, j48svm, 0.83358 low-res-spect, oc1, 0.824671 low-res-spect, cart, 0.829206 low-res-spect, baseRaF, 0.790875 -lymphography, stree, 0.77046 +lymphography, stree, 0.823425 lymphography, wodt, 0.808782 lymphography, j48svm, 0.778552 lymphography, oc1, 0.734634 lymphography, cart, 0.766276 lymphography, baseRaF, 0.761622 -mammographic, stree, 0.81915 +mammographic, stree, 0.817068 mammographic, wodt, 0.759839 mammographic, j48svm, 0.821435 mammographic, oc1, 0.768805 mammographic, cart, 0.757131 mammographic, baseRaF, 0.780206 -molec-biol-promoter, stree, 0.764416 +molec-biol-promoter, stree, 0.767056 molec-biol-promoter, wodt, 0.798528 molec-biol-promoter, j48svm, 0.744935 molec-biol-promoter, oc1, 0.734805 molec-biol-promoter, cart, 0.748701 molec-biol-promoter, baseRaF, 0.667239 -musk-1, stree, 0.843463 +musk-1, stree, 0.916388 musk-1, wodt, 0.838914 musk-1, j48svm, 0.82693 musk-1, oc1, 0.776401 musk-1, cart, 0.780215 musk-1, baseRaF, 0.834034 -oocytes_merluccius_nucleus_4d, stree, 0.810657 +oocytes_merluccius_nucleus_4d, stree, 0.835125 oocytes_merluccius_nucleus_4d, wodt, 0.737673 oocytes_merluccius_nucleus_4d, j48svm, 0.741766 oocytes_merluccius_nucleus_4d, oc1, 0.743199 @@ -167,127 +167,127 @@ oocytes_merluccius_states_2f, j48svm, 0.901374 oocytes_merluccius_states_2f, oc1, 0.889223 oocytes_merluccius_states_2f, cart, 0.891193 oocytes_merluccius_states_2f, baseRaF, 0.910551 -oocytes_trisopterus_nucleus_2f, stree, 0.800986 +oocytes_trisopterus_nucleus_2f, stree, 0.799995 oocytes_trisopterus_nucleus_2f, wodt, 0.751431 oocytes_trisopterus_nucleus_2f, j48svm, 0.756587 oocytes_trisopterus_nucleus_2f, oc1, 0.747697 oocytes_trisopterus_nucleus_2f, cart, 0.734313 oocytes_trisopterus_nucleus_2f, baseRaF, 0.76193 -oocytes_trisopterus_states_5b, stree, 0.9023 +oocytes_trisopterus_states_5b, stree, 0.924441 oocytes_trisopterus_states_5b, wodt, 0.89165 oocytes_trisopterus_states_5b, j48svm, 0.887943 oocytes_trisopterus_states_5b, oc1, 0.86393 oocytes_trisopterus_states_5b, cart, 0.870263 oocytes_trisopterus_states_5b, baseRaF, 0.922149 -parkinsons, stree, 0.882051 +parkinsons, stree, 0.865641 parkinsons, wodt, 0.901538 parkinsons, j48svm, 0.844615 parkinsons, oc1, 0.865641 parkinsons, cart, 0.855897 parkinsons, baseRaF, 0.87924 -pima, stree, 0.766651 +pima, stree, 0.764053 pima, wodt, 0.681591 pima, j48svm, 0.749876 pima, oc1, 0.693027 pima, cart, 0.701172 pima, baseRaF, 0.697005 -pittsburg-bridges-MATERIAL, stree, 0.787446 +pittsburg-bridges-MATERIAL, stree, 0.867749 pittsburg-bridges-MATERIAL, wodt, 0.79961 pittsburg-bridges-MATERIAL, j48svm, 0.855844 pittsburg-bridges-MATERIAL, oc1, 0.81026 pittsburg-bridges-MATERIAL, cart, 0.783593 pittsburg-bridges-MATERIAL, baseRaF, 0.81136 -pittsburg-bridges-REL-L, stree, 0.62519 +pittsburg-bridges-REL-L, stree, 0.564048 pittsburg-bridges-REL-L, wodt, 0.617143 pittsburg-bridges-REL-L, j48svm, 0.645048 pittsburg-bridges-REL-L, oc1, 0.604957 pittsburg-bridges-REL-L, cart, 0.625333 pittsburg-bridges-REL-L, baseRaF, 0.622107 -pittsburg-bridges-SPAN, stree, 0.630234 +pittsburg-bridges-SPAN, stree, 0.658713 pittsburg-bridges-SPAN, wodt, 0.606959 pittsburg-bridges-SPAN, j48svm, 0.621579 pittsburg-bridges-SPAN, oc1, 0.579333 pittsburg-bridges-SPAN, cart, 0.557544 pittsburg-bridges-SPAN, baseRaF, 0.630217 -pittsburg-bridges-T-OR-D, stree, 0.861619 +pittsburg-bridges-T-OR-D, stree, 0.849952 pittsburg-bridges-T-OR-D, wodt, 0.818429 pittsburg-bridges-T-OR-D, j48svm, 0.838333 pittsburg-bridges-T-OR-D, oc1, 0.831545 pittsburg-bridges-T-OR-D, cart, 0.821619 pittsburg-bridges-T-OR-D, baseRaF, 0.821007 -planning, stree, 0.70455 +planning, stree, 0.73527 planning, wodt, 0.576847 planning, j48svm, 0.711381 planning, oc1, 0.566988 planning, cart, 0.586712 planning, baseRaF, 0.590586 -post-operative, stree, 0.573333 +post-operative, stree, 0.703333 post-operative, wodt, 0.535556 post-operative, j48svm, 0.701111 post-operative, oc1, 0.542222 post-operative, cart, 0.567778 post-operative, baseRaF, 0.539375 -seeds, stree, 0.949048 +seeds, stree, 0.952857 seeds, wodt, 0.940476 seeds, j48svm, 0.909524 seeds, oc1, 0.932381 seeds, cart, 0.900476 seeds, baseRaF, 0.942518 -statlog-australian-credit, stree, 0.667246 +statlog-australian-credit, stree, 0.678261 statlog-australian-credit, wodt, 0.561594 statlog-australian-credit, j48svm, 0.66029 statlog-australian-credit, oc1, 0.573913 statlog-australian-credit, cart, 0.595507 statlog-australian-credit, baseRaF, 0.678261 -statlog-german-credit, stree, 0.7625 +statlog-german-credit, stree, 0.7569 statlog-german-credit, wodt, 0.6929 statlog-german-credit, j48svm, 0.7244 statlog-german-credit, oc1, 0.6874 statlog-german-credit, cart, 0.6738 statlog-german-credit, baseRaF, 0.68762 -statlog-heart, stree, 0.822963 +statlog-heart, stree, 0.822222 statlog-heart, wodt, 0.777778 statlog-heart, j48svm, 0.795926 statlog-heart, oc1, 0.749259 statlog-heart, cart, 0.762222 statlog-heart, baseRaF, 0.747605 -statlog-image, stree, 0.850649 +statlog-image, stree, 0.956623 statlog-image, wodt, 0.954632 statlog-image, j48svm, 0.967403 statlog-image, oc1, 0.95013 statlog-image, cart, 0.964892 statlog-image, baseRaF, 0.953604 -statlog-vehicle, stree, 0.695151 +statlog-vehicle, stree, 0.788537 statlog-vehicle, wodt, 0.726492 statlog-vehicle, j48svm, 0.729651 statlog-vehicle, oc1, 0.708496 statlog-vehicle, cart, 0.728367 statlog-vehicle, baseRaF, 0.789572 -synthetic-control, stree, 0.938833 +synthetic-control, stree, 0.95 synthetic-control, wodt, 0.973167 synthetic-control, j48svm, 0.922333 synthetic-control, oc1, 0.863167 synthetic-control, cart, 0.908333 synthetic-control, baseRaF, 0.971567 -tic-tac-toe, stree, 0.983296 +tic-tac-toe, stree, 0.984444 tic-tac-toe, wodt, 0.93905 tic-tac-toe, j48svm, 0.983295 tic-tac-toe, oc1, 0.91849 tic-tac-toe, cart, 0.951558 tic-tac-toe, baseRaF, 0.974906 -vertebral-column-2clases, stree, 0.852903 +vertebral-column-2clases, stree, 0.851936 vertebral-column-2clases, wodt, 0.801935 vertebral-column-2clases, j48svm, 0.84871 vertebral-column-2clases, oc1, 0.815161 vertebral-column-2clases, cart, 0.784839 vertebral-column-2clases, baseRaF, 0.822601 -wine, stree, 0.97581 +wine, stree, 0.949333 wine, wodt, 0.973048 wine, j48svm, 0.979143 wine, oc1, 0.916165 wine, cart, 0.921937 wine, baseRaF, 0.97748 -zoo, stree, 0.947619 +zoo, stree, 0.955524 zoo, wodt, 0.954429 zoo, j48svm, 0.92381 zoo, oc1, 0.890952 diff --git a/report_score.py b/report_score.py index fb7ccb6..6b70be7 100644 --- a/report_score.py +++ b/report_score.py @@ -55,22 +55,30 @@ def parse_arguments(): return (args.set_of_files, args.model, args.dataset, args.sql, args.param) -def nodes_leaves(clf): - nodes = 0 - leaves = 0 - for node in clf: - if node.is_leaf(): - leaves += 1 - else: - nodes += 1 - return nodes, leaves - - def compute_auto_hyperparams(X, y): - params = {"max_iter": 1e4, "C": 0.1} - classes = len(np.unique(y)) - if classes > 2: - params["split_criteria"] = "max_samples" + """Propuesta de auto configuración de hiperparámetros + max_it = 10e4 + (1 valor) + split = impurity si clases==2 y split=max_samples si clases > 2 + (1 valor) + kernel=linear o polinómico + (2 valores) + C = 0.1, 0.5 y 1.0 + (3 valores) + Caso 1: C=1, max_iter=1e4 + condicional split_max kernel lineal + Caso 2: C=0.5, max_iter=1e4 + condicional split_max kernel lineal + Caso 3: C=0.1, max_iter=1e4 + condicional split_max kernel lineal + Caso 4: C=1, max_iter=1e4 + condicional split_max kernel poly + Caso 5: C=0.5, max_iter=1e4 + condicional split_max kernel poly + Caso 6: C=0.1, max_iter=1e4 + condicional split_max kernel poly + Caso 7: C=1, max_iter=1e4 + condicional + kernel rbf + Caso 8: kernel rbf + """ + # params = {"max_iter": 1e4, "kernel": "rbf"} + # classes = len(np.unique(y)) + # if classes > 2: + # params["split_criteria"] = "max_samples" + params = {"kernel": "rbf"} return params @@ -97,7 +105,7 @@ def process_dataset(dataset, verbose, model, auto_params): clf = Stree(random_state=random_state) clf.set_params(**hyperparameters) res = cross_validate(clf, X, y, cv=kfold, return_estimator=True) - nodes, leaves = nodes_leaves(res["estimator"][0]) + nodes, leaves = res["estimator"][0].nodes_leaves() depth = res["estimator"][0].depth_ scores.append(res["test_score"]) times.append(res["fit_time"]) @@ -222,6 +230,9 @@ if dataset == "all": parameters = json.loads("{}") accuracy_best = 0.0 acc_best_std = 0.0 + if auto_params: + # show parameters computed + parameters = json.loads(hyperparameters) accuracy_computed = np.mean(scores) diff = accuracy_best - accuracy_computed print( @@ -243,12 +254,12 @@ else: accuracy_best = record[5] if record is not None else 0.0 acc_best_std = record[11] if record is not None else 0.0 print( - f"* Accuracy Computed : {accuracy:6.4f}±{np.std(scores):6.4f} " + f"* Accuracy Computed .: {accuracy:6.4f}±{np.std(scores):6.4f} " f"{np.mean(times):5.3f}s" ) print(f"* Accuracy Best .....: {accuracy_best:6.4f}±{acc_best_std:6.4f}") print(f"* Difference ........: {accuracy_best - accuracy:6.4f}") - print(f"* Nodes/Leaves/Depth :{nodes:2d} {leaves:2d} " f"{depth:2d} ") + print(f"* Nodes/Leaves/Depth : {nodes:2d} {leaves:2d} " f"{depth:2d} ") stop = time.time() print(f"- Auto Hyperparams ..: {hyperparameters}") hours, rem = divmod(stop - start, 3600)