From 7f75115fa9cd995e47b9549dca14861945d738ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Fri, 26 Mar 2021 00:06:48 +0100 Subject: [PATCH] Add stree default to analysis add experiment to report_mysql fix crosval experiment to get the best "gridsearch" parameters --- analysis_mysql.py | 18 +++++++++++++++--- experimentation/Experiments.py | 4 +++- report.csv | 32 ++++++++++++++++---------------- report_mysql.py | 14 ++++++++++++-- 4 files changed, 46 insertions(+), 22 deletions(-) diff --git a/analysis_mysql.py b/analysis_mysql.py index 4453da6..2c0d692 100644 --- a/analysis_mysql.py +++ b/analysis_mysql.py @@ -8,6 +8,7 @@ from experimentation.Database import MySQL report_csv = "report.csv" models_tree = [ "stree", + "stree_default", "wodt", "j48svm", "oc1", @@ -18,7 +19,7 @@ models_ensemble = ["odte", "adaBoost", "bagging", "TBRaF", "TBRoF", "TBRRoF"] description = ["samp", "var", "cls"] complexity = ["nodes", "leaves", "depth"] title = "Best model results" -lengths = (30, 4, 3, 3, 3, 3, 3, 12, 12, 12, 12, 12, 12) +lengths = [30, 4, 3, 3, 3, 3, 3, 12, 12, 12, 12, 12, 12, 12] def parse_arguments() -> Tuple[str, str, str, bool, bool]: @@ -46,8 +47,15 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]: required=False, default=False, ) + ap.add_argument( + "-o", + "--compare", + type=bool, + required=False, + default=False, + ) args = ap.parse_args() - return (args.experiment, args.model, args.csv_output) + return (args.experiment, args.model, args.csv_output, args.compare) def report_header_content(title, experiment, model_type): @@ -102,7 +110,7 @@ def report_footer(agg): ) -(experiment, model_type, csv_output) = parse_arguments() +(experiment, model_type, csv_output, compare) = parse_arguments() dbh = MySQL() database = dbh.get_connection() dt = Datasets(False, False, "tanveer") @@ -115,6 +123,10 @@ fields = ( "Lea", "Dep", ) +if not compare: + # remove stree_default from fields list and lengths + models_tree.pop(1) + lengths.pop(7) models = models_tree if model_type == "tree" else models_ensemble for item in models: fields += (f"{item}",) diff --git a/experimentation/Experiments.py b/experimentation/Experiments.py index 732ee9c..22b2345 100644 --- a/experimentation/Experiments.py +++ b/experimentation/Experiments.py @@ -46,7 +46,9 @@ class Experiment: model = self._clf.get_model() hyperparams = MySQL() hyperparams.get_connection() - record = hyperparams.find_best(dataset, self._model_name) + record = hyperparams.find_best( + dataset, self._model_name, experiment="gridsearch" + ) hyperparams.close() if record is None: try: diff --git a/report.csv b/report.csv index 15a523f..48aef0a 100644 --- a/report.csv +++ b/report.csv @@ -23,13 +23,13 @@ breast-cancer-wisc-prog, j48svm, 0.724038 breast-cancer-wisc-prog, oc1, 0.71 breast-cancer-wisc-prog, cart, 0.699833 breast-cancer-wisc-prog, baseRaF, 0.74485 -breast-cancer-wisc, stree, 0.965802 +breast-cancer-wisc, stree, 0.966661 breast-cancer-wisc, wodt, 0.946208 breast-cancer-wisc, j48svm, 0.967674 breast-cancer-wisc, oc1, 0.940194 breast-cancer-wisc, cart, 0.940629 breast-cancer-wisc, baseRaF, 0.942857 -breast-cancer, stree, 0.733158 +breast-cancer, stree, 0.734211 breast-cancer, wodt, 0.650236 breast-cancer, j48svm, 0.707719 breast-cancer, oc1, 0.649728 @@ -47,13 +47,13 @@ cardiotocography-3clases, j48svm, 0.927327 cardiotocography-3clases, oc1, 0.899811 cardiotocography-3clases, cart, 0.929258 cardiotocography-3clases, baseRaF, 0.896715 -conn-bench-sonar-mines-rocks, stree, 0.71439 +conn-bench-sonar-mines-rocks, stree, 0.755528 conn-bench-sonar-mines-rocks, wodt, 0.824959 conn-bench-sonar-mines-rocks, j48svm, 0.73892 conn-bench-sonar-mines-rocks, oc1, 0.710798 conn-bench-sonar-mines-rocks, cart, 0.728711 conn-bench-sonar-mines-rocks, baseRaF, 0.772981 -cylinder-bands, stree, 0.687101 +cylinder-bands, stree, 0.715049 cylinder-bands, wodt, 0.704074 cylinder-bands, j48svm, 0.726351 cylinder-bands, oc1, 0.67106 @@ -77,7 +77,7 @@ fertility, j48svm, 0.857 fertility, oc1, 0.793 fertility, cart, 0.8 fertility, baseRaF, 0.798 -haberman-survival, stree, 0.727795 +haberman-survival, stree, 0.735637 haberman-survival, wodt, 0.664707 haberman-survival, j48svm, 0.714056 haberman-survival, oc1, 0.651634 @@ -95,7 +95,7 @@ hepatitis, j48svm, 0.761935 hepatitis, oc1, 0.756774 hepatitis, cart, 0.765161 hepatitis, baseRaF, 0.773671 -ilpd-indian-liver, stree, 0.719207 +ilpd-indian-liver, stree, 0.723498 ilpd-indian-liver, wodt, 0.676176 ilpd-indian-liver, j48svm, 0.690339 ilpd-indian-liver, oc1, 0.660139 @@ -137,7 +137,7 @@ lymphography, j48svm, 0.778552 lymphography, oc1, 0.734634 lymphography, cart, 0.766276 lymphography, baseRaF, 0.761622 -mammographic, stree, 0.817068 +mammographic, stree, 0.81915 mammographic, wodt, 0.759839 mammographic, j48svm, 0.821435 mammographic, oc1, 0.768805 @@ -167,7 +167,7 @@ oocytes_merluccius_states_2f, j48svm, 0.901374 oocytes_merluccius_states_2f, oc1, 0.889223 oocytes_merluccius_states_2f, cart, 0.891193 oocytes_merluccius_states_2f, baseRaF, 0.910551 -oocytes_trisopterus_nucleus_2f, stree, 0.799995 +oocytes_trisopterus_nucleus_2f, stree, 0.800986 oocytes_trisopterus_nucleus_2f, wodt, 0.751431 oocytes_trisopterus_nucleus_2f, j48svm, 0.756587 oocytes_trisopterus_nucleus_2f, oc1, 0.747697 @@ -179,13 +179,13 @@ oocytes_trisopterus_states_5b, j48svm, 0.887943 oocytes_trisopterus_states_5b, oc1, 0.86393 oocytes_trisopterus_states_5b, cart, 0.870263 oocytes_trisopterus_states_5b, baseRaF, 0.922149 -parkinsons, stree, 0.865641 +parkinsons, stree, 0.882051 parkinsons, wodt, 0.901538 parkinsons, j48svm, 0.844615 parkinsons, oc1, 0.865641 parkinsons, cart, 0.855897 parkinsons, baseRaF, 0.87924 -pima, stree, 0.764053 +pima, stree, 0.766651 pima, wodt, 0.681591 pima, j48svm, 0.749876 pima, oc1, 0.693027 @@ -197,7 +197,7 @@ pittsburg-bridges-MATERIAL, j48svm, 0.855844 pittsburg-bridges-MATERIAL, oc1, 0.81026 pittsburg-bridges-MATERIAL, cart, 0.783593 pittsburg-bridges-MATERIAL, baseRaF, 0.81136 -pittsburg-bridges-REL-L, stree, 0.564048 +pittsburg-bridges-REL-L, stree, 0.62519 pittsburg-bridges-REL-L, wodt, 0.617143 pittsburg-bridges-REL-L, j48svm, 0.645048 pittsburg-bridges-REL-L, oc1, 0.604957 @@ -209,7 +209,7 @@ pittsburg-bridges-SPAN, j48svm, 0.621579 pittsburg-bridges-SPAN, oc1, 0.579333 pittsburg-bridges-SPAN, cart, 0.557544 pittsburg-bridges-SPAN, baseRaF, 0.630217 -pittsburg-bridges-T-OR-D, stree, 0.849952 +pittsburg-bridges-T-OR-D, stree, 0.861619 pittsburg-bridges-T-OR-D, wodt, 0.818429 pittsburg-bridges-T-OR-D, j48svm, 0.838333 pittsburg-bridges-T-OR-D, oc1, 0.831545 @@ -239,13 +239,13 @@ statlog-australian-credit, j48svm, 0.66029 statlog-australian-credit, oc1, 0.573913 statlog-australian-credit, cart, 0.595507 statlog-australian-credit, baseRaF, 0.678261 -statlog-german-credit, stree, 0.7569 +statlog-german-credit, stree, 0.7625 statlog-german-credit, wodt, 0.6929 statlog-german-credit, j48svm, 0.7244 statlog-german-credit, oc1, 0.6874 statlog-german-credit, cart, 0.6738 statlog-german-credit, baseRaF, 0.68762 -statlog-heart, stree, 0.822222 +statlog-heart, stree, 0.822963 statlog-heart, wodt, 0.777778 statlog-heart, j48svm, 0.795926 statlog-heart, oc1, 0.749259 @@ -275,13 +275,13 @@ tic-tac-toe, j48svm, 0.983295 tic-tac-toe, oc1, 0.91849 tic-tac-toe, cart, 0.951558 tic-tac-toe, baseRaF, 0.974906 -vertebral-column-2clases, stree, 0.851936 +vertebral-column-2clases, stree, 0.852903 vertebral-column-2clases, wodt, 0.801935 vertebral-column-2clases, j48svm, 0.84871 vertebral-column-2clases, oc1, 0.815161 vertebral-column-2clases, cart, 0.784839 vertebral-column-2clases, baseRaF, 0.822601 -wine, stree, 0.949333 +wine, stree, 0.97581 wine, wodt, 0.973048 wine, j48svm, 0.979143 wine, oc1, 0.916165 diff --git a/report_mysql.py b/report_mysql.py index 85ed7cb..09e527d 100644 --- a/report_mysql.py +++ b/report_mysql.py @@ -17,6 +17,14 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]: required=False, default="any", ) + ap.add_argument( + "-e", + "--experiment", + type=str, + choices=["gridsearch", "crossval"], + required=False, + default="crossval", + ) ap.add_argument( "-x", "--excludeparams", @@ -29,6 +37,7 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]: return ( args.model, args.excludeparams, + args.experiment, ) @@ -54,7 +63,7 @@ def report_header(exclude_params): def report_line(record, agg): accuracy = record[5] - expected = record[13] + expected = record[16] if accuracy < expected: agg["worse"] += 1 sign = "-" @@ -94,6 +103,7 @@ def report_footer(agg): ( classifier, exclude_parameters, + experiment, ) = parse_arguments() dbh = MySQL() database = dbh.get_connection() @@ -124,7 +134,7 @@ for item in [ ] + models: agg[item] = 0 for dataset in dt: - record = dbh.find_best(dataset[0], classifier) + record = dbh.find_best(dataset[0], classifier, experiment=experiment) if record is None: print(TextColor.FAIL + f"*No results found for {dataset[0]}") else: