From 08fb237001f2a78e37da1e5e5c6d042056bd219c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Mon, 22 Mar 2021 11:02:53 +0100 Subject: [PATCH] Non stratified experiments Remove reference column in analysis --- analysis_mysql.py | 44 +-- comparewodt/compare.sh | 15 + experimentation/Experiments.py | 11 +- report.csv | 484 ++++++++++++++++----------------- stats_stree.py | 40 +++ testwodt.py | 5 +- 6 files changed, 318 insertions(+), 281 deletions(-) create mode 100755 comparewodt/compare.sh create mode 100644 stats_stree.py diff --git a/analysis_mysql.py b/analysis_mysql.py index 89ee21d..9c6a8bf 100644 --- a/analysis_mysql.py +++ b/analysis_mysql.py @@ -15,7 +15,7 @@ models_tree = [ ] models_ensemble = ["odte", "adaBoost", "bagging", "TBRaF", "TBRoF", "TBRRoF"] title = "Best model results" -lengths = (30, 9, 11, 11, 11, 11, 11, 11) +lengths = (30, 12, 12, 12, 12, 12, 12) def parse_arguments() -> Tuple[str, str, str, bool, bool]: @@ -63,7 +63,7 @@ def report_header_content(title, experiment, model_type): output += "*" * length + "\n\n" lines = "" for item, data in enumerate(fields): - output += f"{fields[item]:{lengths[item]}} " + output += f"{fields[item]:^{lengths[item]}} " lines += "=" * lengths[item] + " " output += f"\n{lines}" return output @@ -80,31 +80,17 @@ def report_header(title, experiment, model_type): def report_line(line): output = f"{line['dataset']:{lengths[0] + 5}s} " data = models.copy() - data.insert(0, "reference") for key, model in enumerate(data): output += f"{line[model]:{lengths[key + 1]}s} " return output def report_footer(agg): - print( - TextColor.GREEN - + f"we have better results {agg['better']['items']:2d} times" - ) - print( - TextColor.RED - + f"we have worse results {agg['worse']['items']:2d} times" - ) + length = sum(lengths) + len(lengths) - 1 + print("-" * length) color = TextColor.LINE1 for item in models: - print( - color + f"{item:10s} used {agg[item]['items']:2d} times ", end="" - ) - print( - color + f"better than reference {agg[item]['better']:2d} times ", - end="", - ) - print(color + f"worse {agg[item]['worse']:2d} times ", end="") + print(color + f"{item:10s} ", end="") print(color + f"best of models {agg[item]['best']:2d} times") color = ( TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1 @@ -115,7 +101,7 @@ def report_footer(agg): dbh = MySQL() database = dbh.get_connection() dt = Datasets(False, False, "tanveer") -fields = ("Dataset", "Reference") +fields = ("Dataset",) models = models_tree if model_type == "tree" else models_ensemble for item in models: fields += (f"{item}",) @@ -127,9 +113,6 @@ for item in [ "worse", ] + models: agg[item] = {} - agg[item]["items"] = 0 - agg[item]["better"] = 0 - agg[item]["worse"] = 0 agg[item]["best"] = 0 if csv_output: f = open(report_csv, "w") @@ -143,22 +126,13 @@ for dataset in dt: for model in models: record = dbh.find_best(dataset[0], model, experiment) if record is None: - line[model] = color + "-" * 9 + " " + line[model] = color + "-" * 12 else: reference = record[13] accuracy = record[5] + acc_std = record[11] find_one = True - agg[model]["items"] += 1 - if accuracy > reference: - sign = "+" - agg["better"]["items"] += 1 - agg[model]["better"] += 1 - else: - sign = "-" - agg["worse"]["items"] += 1 - agg[model]["worse"] += 1 - item = f"{accuracy:9.7} {sign}" - line["reference"] = f"{reference:9.7}" + item = f"{accuracy:.4f}±{acc_std:.3f}" if accuracy == max_accuracy: line[model] = ( TextColor.GREEN + TextColor.BOLD + item + TextColor.ENDC diff --git a/comparewodt/compare.sh b/comparewodt/compare.sh new file mode 100755 index 0000000..1ae4435 --- /dev/null +++ b/comparewodt/compare.sh @@ -0,0 +1,15 @@ +#!/bin/bash +function busca_resultado() { + res=`grep -w $2 $1|cut -d";" -f2` +} +estratificado="estratificado.txt" +no_estratificado="no_estratificado.txt" +busca_resultado $estratificado "wine" +echo $res +busca_resultado $no_estratificado "zoo" +echo $res +cat $estratificado|while read dataset accuracy +do + busca_resultado $no_estratificado $dataset + echo "$dataset E[$accuracy] NE[$res]" +done diff --git a/experimentation/Experiments.py b/experimentation/Experiments.py index f2b85a3..3227df3 100644 --- a/experimentation/Experiments.py +++ b/experimentation/Experiments.py @@ -6,8 +6,9 @@ import warnings from sklearn.model_selection import GridSearchCV, cross_validate from . import Models -from .Database import Hyperparameters, Outcomes, MySQL +from .Database import Hyperparameters, MySQL, Outcomes from .Sets import Datasets +from sklearn.model_selection._split import KFold class Experiment: @@ -81,6 +82,7 @@ class Experiment: for item in outcomes: total[item] = [] for random_state in [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]: + kfold = KFold(shuffle=True, random_state=random_state, n_splits=5) model.set_params(**{"random_state": random_state}) print(f"{random_state}, ", end="", flush=True) with warnings.catch_warnings(): @@ -88,7 +90,12 @@ class Experiment: # Also affect subprocesses os.environ["PYTHONWARNINGS"] = "ignore" results = cross_validate( - model, X, y, return_train_score=True, n_jobs=self._threads + model, + X, + y, + return_train_score=True, + n_jobs=self._threads, + cv=kfold, ) for item in outcomes: total[item].append(results[item]) diff --git a/report.csv b/report.csv index 8ead9f2..772818b 100644 --- a/report.csv +++ b/report.csv @@ -1,295 +1,295 @@ dataset, classifier, accuracy -balance-scale, stree, 0.9488 -balance-scale, wodt, 0.86016 -balance-scale, j48svm, 0.94128 +balance-scale, stree, 0.97056 +balance-scale, wodt, 0.912 +balance-scale, j48svm, 0.94 balance-scale, oc1, 0.9192 -balance-scale, cart, 0.57312 -balance-scale, baseRaF, 0.790543 -balloons, stree, 0.866667 -balloons, wodt, 0.626667 -balloons, j48svm, 0.511667 +balance-scale, cart, 0.78816 +balance-scale, baseRaF, 0.706738 +balloons, stree, 0.86 +balloons, wodt, 0.688333 +balloons, j48svm, 0.595 balloons, oc1, 0.62 -balloons, cart, 0.683333 -balloons, baseRaF, 0.5375 -breast-cancer-wisc-diag, stree, 0.978932 -breast-cancer-wisc-diag, wodt, 0.962546 -breast-cancer-wisc-diag, j48svm, 0.956397 +balloons, cart, 0.671667 +balloons, baseRaF, 0.605 +breast-cancer-wisc-diag, stree, 0.972764 +breast-cancer-wisc-diag, wodt, 0.967317 +breast-cancer-wisc-diag, j48svm, 0.952878 breast-cancer-wisc-diag, oc1, 0.933477 -breast-cancer-wisc-diag, cart, 0.933925 -breast-cancer-wisc-diag, baseRaF, 0.94808 -breast-cancer-wisc-prog, stree, 0.828462 -breast-cancer-wisc-prog, wodt, 0.689654 -breast-cancer-wisc-prog, j48svm, 0.697013 +breast-cancer-wisc-diag, cart, 0.93953 +breast-cancer-wisc-diag, baseRaF, 0.965694 +breast-cancer-wisc-prog, stree, 0.811128 +breast-cancer-wisc-prog, wodt, 0.710141 +breast-cancer-wisc-prog, j48svm, 0.724038 breast-cancer-wisc-prog, oc1, 0.71 -breast-cancer-wisc-prog, cart, 0.749462 -breast-cancer-wisc-prog, baseRaF, 0.70087 -breast-cancer-wisc, stree, 0.965694 -breast-cancer-wisc, wodt, 0.936199 -breast-cancer-wisc, j48svm, 0.967529 +breast-cancer-wisc-prog, cart, 0.699833 +breast-cancer-wisc-prog, baseRaF, 0.74485 +breast-cancer-wisc, stree, 0.965802 +breast-cancer-wisc, wodt, 0.946208 +breast-cancer-wisc, j48svm, 0.967674 breast-cancer-wisc, oc1, 0.940194 -breast-cancer-wisc, cart, 0.937074 -breast-cancer-wisc, baseRaF, 0.946961 -breast-cancer, stree, 0.730853 -breast-cancer, wodt, 0.619673 -breast-cancer, j48svm, 0.712976 +breast-cancer-wisc, cart, 0.940629 +breast-cancer-wisc, baseRaF, 0.942857 +breast-cancer, stree, 0.733158 +breast-cancer, wodt, 0.650236 +breast-cancer, j48svm, 0.707719 breast-cancer, oc1, 0.649728 -breast-cancer, cart, 0.637205 -breast-cancer, baseRaF, 0.685839 -cardiotocography-10clases, stree, 0.666522 -cardiotocography-10clases, wodt, 0.627577 -cardiotocography-10clases, j48svm, 0.832552 +breast-cancer, cart, 0.65444 +breast-cancer, baseRaF, 0.656438 +cardiotocography-10clases, stree, 0.712009 +cardiotocography-10clases, wodt, 0.773706 +cardiotocography-10clases, j48svm, 0.830812 cardiotocography-10clases, oc1, 0.795528 -cardiotocography-10clases, cart, 0.716373 -cardiotocography-10clases, baseRaF, 0.679912 -cardiotocography-3clases, stree, 0.848074 -cardiotocography-3clases, wodt, 0.803063 -cardiotocography-3clases, j48svm, 0.9278 +cardiotocography-10clases, cart, 0.818864 +cardiotocography-10clases, baseRaF, 0.774788 +cardiotocography-3clases, stree, 0.891956 +cardiotocography-3clases, wodt, 0.897509 +cardiotocography-3clases, j48svm, 0.927327 cardiotocography-3clases, oc1, 0.899811 -cardiotocography-3clases, cart, 0.844726 -cardiotocography-3clases, baseRaF, 0.880937 -conn-bench-sonar-mines-rocks, stree, 0.597433 -conn-bench-sonar-mines-rocks, wodt, 0.649501 -conn-bench-sonar-mines-rocks, j48svm, 0.728897 +cardiotocography-3clases, cart, 0.929258 +cardiotocography-3clases, baseRaF, 0.896715 +conn-bench-sonar-mines-rocks, stree, 0.71439 +conn-bench-sonar-mines-rocks, wodt, 0.824959 +conn-bench-sonar-mines-rocks, j48svm, 0.73892 conn-bench-sonar-mines-rocks, oc1, 0.710798 -conn-bench-sonar-mines-rocks, cart, 0.630511 -conn-bench-sonar-mines-rocks, baseRaF, 0.727885 -cylinder-bands, stree, 0.628081 -cylinder-bands, wodt, 0.570849 -cylinder-bands, j48svm, 0.736126 +conn-bench-sonar-mines-rocks, cart, 0.728711 +conn-bench-sonar-mines-rocks, baseRaF, 0.772981 +cylinder-bands, stree, 0.687101 +cylinder-bands, wodt, 0.704074 +cylinder-bands, j48svm, 0.726351 cylinder-bands, oc1, 0.67106 -cylinder-bands, cart, 0.583602 -cylinder-bands, baseRaF, 0.647188 -dermatology, stree, 0.975454 -dermatology, wodt, 0.951925 -dermatology, j48svm, 0.955472 +cylinder-bands, cart, 0.712703 +cylinder-bands, baseRaF, 0.675117 +dermatology, stree, 0.971833 +dermatology, wodt, 0.965557 +dermatology, j48svm, 0.955735 dermatology, oc1, 0.916087 -dermatology, cart, 0.918064 -dermatology, baseRaF, 0.886384 -echocardiogram, stree, 0.82094 -echocardiogram, wodt, 0.723077 -echocardiogram, j48svm, 0.835726 +dermatology, cart, 0.932766 +dermatology, baseRaF, 0.970723 +echocardiogram, stree, 0.814758 +echocardiogram, wodt, 0.733875 +echocardiogram, j48svm, 0.805527 echocardiogram, oc1, 0.748291 -echocardiogram, cart, 0.757493 -echocardiogram, baseRaF, 0.775732 +echocardiogram, cart, 0.745043 +echocardiogram, baseRaF, 0.753522 fertility, stree, 0.88 -fertility, wodt, 0.763 -fertility, j48svm, 0.864 +fertility, wodt, 0.785 +fertility, j48svm, 0.857 fertility, oc1, 0.793 -fertility, cart, 0.752 -fertility, baseRaF, 0.837 -haberman-survival, stree, 0.764675 -haberman-survival, wodt, 0.647827 -haberman-survival, j48svm, 0.708847 +fertility, cart, 0.8 +fertility, baseRaF, 0.798 +haberman-survival, stree, 0.727795 +haberman-survival, wodt, 0.664707 +haberman-survival, j48svm, 0.714056 haberman-survival, oc1, 0.651634 -haberman-survival, cart, 0.640899 -haberman-survival, baseRaF, 0.733443 -heart-hungarian, stree, 0.829924 -heart-hungarian, wodt, 0.758328 -heart-hungarian, j48svm, 0.785061 +haberman-survival, cart, 0.65 +haberman-survival, baseRaF, 0.720133 +heart-hungarian, stree, 0.827522 +heart-hungarian, wodt, 0.764909 +heart-hungarian, j48svm, 0.785026 heart-hungarian, oc1, 0.758298 -heart-hungarian, cart, 0.731297 -heart-hungarian, baseRaF, 0.778289 -hepatitis, stree, 0.83871 -hepatitis, wodt, 0.774839 +heart-hungarian, cart, 0.760508 +heart-hungarian, baseRaF, 0.779804 +hepatitis, stree, 0.824516 +hepatitis, wodt, 0.785806 hepatitis, j48svm, 0.761935 hepatitis, oc1, 0.756774 -hepatitis, cart, 0.766452 -hepatitis, baseRaF, 0.764477 -ilpd-indian-liver, stree, 0.742691 -ilpd-indian-liver, wodt, 0.650159 -ilpd-indian-liver, j48svm, 0.692116 +hepatitis, cart, 0.765161 +hepatitis, baseRaF, 0.773671 +ilpd-indian-liver, stree, 0.719207 +ilpd-indian-liver, wodt, 0.676176 +ilpd-indian-liver, j48svm, 0.690339 ilpd-indian-liver, oc1, 0.660139 -ilpd-indian-liver, cart, 0.653587 -ilpd-indian-liver, baseRaF, 0.698181 -ionosphere, stree, 0.948732 -ionosphere, wodt, 0.857328 -ionosphere, j48svm, 0.891445 +ilpd-indian-liver, cart, 0.663423 +ilpd-indian-liver, baseRaF, 0.696685 +ionosphere, stree, 0.953276 +ionosphere, wodt, 0.88008 +ionosphere, j48svm, 0.891984 ionosphere, oc1, 0.879742 -ionosphere, cart, 0.876125 -ionosphere, baseRaF, 0.87236 -iris, stree, 0.98 -iris, wodt, 0.96 -iris, j48svm, 0.941333 +ionosphere, cart, 0.895771 +ionosphere, baseRaF, 0.875389 +iris, stree, 0.965333 +iris, wodt, 0.946 +iris, j48svm, 0.947333 iris, oc1, 0.948 -iris, cart, 0.956667 -iris, baseRaF, 0.944726 -led-display, stree, 0.7071 -led-display, wodt, 0.7053 -led-display, j48svm, 0.7177 +iris, cart, 0.938667 +iris, baseRaF, 0.953413 +led-display, stree, 0.703 +led-display, wodt, 0.7049 +led-display, j48svm, 0.7204 led-display, oc1, 0.6993 -led-display, cart, 0.7073 -led-display, baseRaF, 0.56058 -libras, stree, 0.761111 -libras, wodt, 0.671111 -libras, j48svm, 0.664167 +led-display, cart, 0.7037 +led-display, baseRaF, 0.70178 +libras, stree, 0.788333 +libras, wodt, 0.764167 +libras, j48svm, 0.66 libras, oc1, 0.645 -libras, cart, 0.555556 -libras, baseRaF, 0.657278 -low-res-spect, stree, 0.879492 -low-res-spect, wodt, 0.845585 -low-res-spect, j48svm, 0.831852 +libras, cart, 0.655 +libras, baseRaF, 0.726722 +low-res-spect, stree, 0.865713 +low-res-spect, wodt, 0.856459 +low-res-spect, j48svm, 0.83358 low-res-spect, oc1, 0.824671 -low-res-spect, cart, 0.826327 -low-res-spect, baseRaF, 0.765601 -lymphography, stree, 0.864828 -lymphography, wodt, 0.784598 -lymphography, j48svm, 0.772552 +low-res-spect, cart, 0.829206 +low-res-spect, baseRaF, 0.790875 +lymphography, stree, 0.823425 +lymphography, wodt, 0.808782 +lymphography, j48svm, 0.778552 lymphography, oc1, 0.734634 -lymphography, cart, 0.79331 -lymphography, baseRaF, 0.718919 -mammographic, stree, 0.819062 -mammographic, wodt, 0.76379 -mammographic, j48svm, 0.816863 +lymphography, cart, 0.766276 +lymphography, baseRaF, 0.761622 +mammographic, stree, 0.817068 +mammographic, wodt, 0.759839 +mammographic, j48svm, 0.821435 mammographic, oc1, 0.768805 -mammographic, cart, 0.766706 -mammographic, baseRaF, 0.802937 -molec-biol-promoter, stree, 0.810822 -molec-biol-promoter, wodt, 0.741905 -molec-biol-promoter, j48svm, 0.785455 +mammographic, cart, 0.757131 +mammographic, baseRaF, 0.780206 +molec-biol-promoter, stree, 0.767056 +molec-biol-promoter, wodt, 0.798528 +molec-biol-promoter, j48svm, 0.744935 molec-biol-promoter, oc1, 0.734805 -molec-biol-promoter, cart, 0.739437 -molec-biol-promoter, baseRaF, 0.644409 -musk-1, stree, 0.75432 -musk-1, wodt, 0.734763 -musk-1, j48svm, 0.806143 +molec-biol-promoter, cart, 0.748701 +molec-biol-promoter, baseRaF, 0.667239 +musk-1, stree, 0.916388 +musk-1, wodt, 0.838914 +musk-1, j48svm, 0.82693 musk-1, oc1, 0.776401 -musk-1, cart, 0.683419 -musk-1, baseRaF, 0.764916 -oocytes_merluccius_nucleus_4d, stree, 0.812142 -oocytes_merluccius_nucleus_4d, wodt, 0.723538 -oocytes_merluccius_nucleus_4d, j48svm, 0.740807 +musk-1, cart, 0.780215 +musk-1, baseRaF, 0.834034 +oocytes_merluccius_nucleus_4d, stree, 0.835125 +oocytes_merluccius_nucleus_4d, wodt, 0.737673 +oocytes_merluccius_nucleus_4d, j48svm, 0.741766 oocytes_merluccius_nucleus_4d, oc1, 0.743199 -oocytes_merluccius_nucleus_4d, cart, 0.706999 -oocytes_merluccius_nucleus_4d, baseRaF, 0.743156 -oocytes_merluccius_states_2f, stree, 0.921688 -oocytes_merluccius_states_2f, wodt, 0.884993 -oocytes_merluccius_states_2f, j48svm, 0.900002 +oocytes_merluccius_nucleus_4d, cart, 0.728265 +oocytes_merluccius_nucleus_4d, baseRaF, 0.792313 +oocytes_merluccius_states_2f, stree, 0.87359 +oocytes_merluccius_states_2f, wodt, 0.895115 +oocytes_merluccius_states_2f, j48svm, 0.901374 oocytes_merluccius_states_2f, oc1, 0.889223 -oocytes_merluccius_states_2f, cart, 0.877563 -oocytes_merluccius_states_2f, baseRaF, 0.87948 -oocytes_trisopterus_nucleus_2f, stree, 0.747691 -oocytes_trisopterus_nucleus_2f, wodt, 0.654345 -oocytes_trisopterus_nucleus_2f, j48svm, 0.755697 +oocytes_merluccius_states_2f, cart, 0.891193 +oocytes_merluccius_states_2f, baseRaF, 0.910551 +oocytes_trisopterus_nucleus_2f, stree, 0.799995 +oocytes_trisopterus_nucleus_2f, wodt, 0.751431 +oocytes_trisopterus_nucleus_2f, j48svm, 0.756587 oocytes_trisopterus_nucleus_2f, oc1, 0.747697 -oocytes_trisopterus_nucleus_2f, cart, 0.704823 -oocytes_trisopterus_nucleus_2f, baseRaF, 0.721601 -oocytes_trisopterus_states_5b, stree, 0.845361 -oocytes_trisopterus_states_5b, wodt, 0.769139 -oocytes_trisopterus_states_5b, j48svm, 0.885075 +oocytes_trisopterus_nucleus_2f, cart, 0.734313 +oocytes_trisopterus_nucleus_2f, baseRaF, 0.76193 +oocytes_trisopterus_states_5b, stree, 0.924441 +oocytes_trisopterus_states_5b, wodt, 0.89165 +oocytes_trisopterus_states_5b, j48svm, 0.887943 oocytes_trisopterus_states_5b, oc1, 0.86393 -oocytes_trisopterus_states_5b, cart, 0.757974 -oocytes_trisopterus_states_5b, baseRaF, 0.862434 -parkinsons, stree, 0.835897 -parkinsons, wodt, 0.811795 -parkinsons, j48svm, 0.859487 +oocytes_trisopterus_states_5b, cart, 0.870263 +oocytes_trisopterus_states_5b, baseRaF, 0.922149 +parkinsons, stree, 0.865641 +parkinsons, wodt, 0.901538 +parkinsons, j48svm, 0.844615 parkinsons, oc1, 0.865641 -parkinsons, cart, 0.725128 -parkinsons, baseRaF, 0.847298 -pima, stree, 0.780002 -pima, wodt, 0.697832 -pima, j48svm, 0.748314 +parkinsons, cart, 0.855897 +parkinsons, baseRaF, 0.87924 +pima, stree, 0.764053 +pima, wodt, 0.681591 +pima, j48svm, 0.749876 pima, oc1, 0.693027 -pima, cart, 0.712883 -pima, baseRaF, 0.70849 -pittsburg-bridges-MATERIAL, stree, 0.886147 -pittsburg-bridges-MATERIAL, wodt, 0.762208 -pittsburg-bridges-MATERIAL, j48svm, 0.84645 +pima, cart, 0.701172 +pima, baseRaF, 0.697005 +pittsburg-bridges-MATERIAL, stree, 0.867749 +pittsburg-bridges-MATERIAL, wodt, 0.79961 +pittsburg-bridges-MATERIAL, j48svm, 0.855844 pittsburg-bridges-MATERIAL, oc1, 0.81026 -pittsburg-bridges-MATERIAL, cart, 0.730087 -pittsburg-bridges-MATERIAL, baseRaF, 0.800316 -pittsburg-bridges-REL-L, stree, 0.578143 -pittsburg-bridges-REL-L, wodt, 0.574429 -pittsburg-bridges-REL-L, j48svm, 0.653571 +pittsburg-bridges-MATERIAL, cart, 0.783593 +pittsburg-bridges-MATERIAL, baseRaF, 0.81136 +pittsburg-bridges-REL-L, stree, 0.564048 +pittsburg-bridges-REL-L, wodt, 0.617143 +pittsburg-bridges-REL-L, j48svm, 0.645048 pittsburg-bridges-REL-L, oc1, 0.604957 -pittsburg-bridges-REL-L, cart, 0.581762 -pittsburg-bridges-REL-L, baseRaF, 0.623964 -pittsburg-bridges-SPAN, stree, 0.677193 -pittsburg-bridges-SPAN, wodt, 0.529357 -pittsburg-bridges-SPAN, j48svm, 0.626784 +pittsburg-bridges-REL-L, cart, 0.625333 +pittsburg-bridges-REL-L, baseRaF, 0.622107 +pittsburg-bridges-SPAN, stree, 0.658713 +pittsburg-bridges-SPAN, wodt, 0.606959 +pittsburg-bridges-SPAN, j48svm, 0.621579 pittsburg-bridges-SPAN, oc1, 0.579333 -pittsburg-bridges-SPAN, cart, 0.536023 -pittsburg-bridges-SPAN, baseRaF, 0.593913 -pittsburg-bridges-T-OR-D, stree, 0.902381 -pittsburg-bridges-T-OR-D, wodt, 0.79 -pittsburg-bridges-T-OR-D, j48svm, 0.835619 +pittsburg-bridges-SPAN, cart, 0.557544 +pittsburg-bridges-SPAN, baseRaF, 0.630217 +pittsburg-bridges-T-OR-D, stree, 0.849952 +pittsburg-bridges-T-OR-D, wodt, 0.818429 +pittsburg-bridges-T-OR-D, j48svm, 0.838333 pittsburg-bridges-T-OR-D, oc1, 0.831545 -pittsburg-bridges-T-OR-D, cart, 0.721667 -pittsburg-bridges-T-OR-D, baseRaF, 0.841081 -planning, stree, 0.725525 -planning, wodt, 0.552192 -planning, j48svm, 0.711246 +pittsburg-bridges-T-OR-D, cart, 0.821619 +pittsburg-bridges-T-OR-D, baseRaF, 0.821007 +planning, stree, 0.73527 +planning, wodt, 0.576847 +planning, j48svm, 0.711381 planning, oc1, 0.566988 -planning, cart, 0.574384 -planning, baseRaF, 0.626404 -post-operative, stree, 0.722222 -post-operative, wodt, 0.56 -post-operative, j48svm, 0.692222 +planning, cart, 0.586712 +planning, baseRaF, 0.590586 +post-operative, stree, 0.703333 +post-operative, wodt, 0.535556 +post-operative, j48svm, 0.701111 post-operative, oc1, 0.542222 -post-operative, cart, 0.586667 -post-operative, baseRaF, 0.669413 -seeds, stree, 0.949048 -seeds, wodt, 0.925238 -seeds, j48svm, 0.912381 +post-operative, cart, 0.567778 +post-operative, baseRaF, 0.539375 +seeds, stree, 0.952857 +seeds, wodt, 0.940476 +seeds, j48svm, 0.909524 seeds, oc1, 0.932381 -seeds, cart, 0.879524 -seeds, baseRaF, 0.904209 -statlog-australian-credit, stree, 0.678116 -statlog-australian-credit, wodt, 0.571739 -statlog-australian-credit, j48svm, 0.655652 +seeds, cart, 0.900476 +seeds, baseRaF, 0.942518 +statlog-australian-credit, stree, 0.678261 +statlog-australian-credit, wodt, 0.561594 +statlog-australian-credit, j48svm, 0.66029 statlog-australian-credit, oc1, 0.573913 -statlog-australian-credit, cart, 0.606377 +statlog-australian-credit, cart, 0.595507 statlog-australian-credit, baseRaF, 0.678261 -statlog-german-credit, stree, 0.7472 -statlog-german-credit, wodt, 0.6878 -statlog-german-credit, j48svm, 0.7261 +statlog-german-credit, stree, 0.7569 +statlog-german-credit, wodt, 0.6929 +statlog-german-credit, j48svm, 0.7244 statlog-german-credit, oc1, 0.6874 -statlog-german-credit, cart, 0.6834 -statlog-german-credit, baseRaF, 0.69528 -statlog-heart, stree, 0.848148 -statlog-heart, wodt, 0.773333 -statlog-heart, j48svm, 0.815556 +statlog-german-credit, cart, 0.6738 +statlog-german-credit, baseRaF, 0.68762 +statlog-heart, stree, 0.822222 +statlog-heart, wodt, 0.777778 +statlog-heart, j48svm, 0.795926 statlog-heart, oc1, 0.749259 -statlog-heart, cart, 0.758519 -statlog-heart, baseRaF, 0.767883 -statlog-image, stree, 0.959307 -statlog-image, wodt, 0.955671 -statlog-image, j48svm, 0.966797 +statlog-heart, cart, 0.762222 +statlog-heart, baseRaF, 0.747605 +statlog-image, stree, 0.956623 +statlog-image, wodt, 0.954632 +statlog-image, j48svm, 0.967403 statlog-image, oc1, 0.95013 -statlog-image, cart, 0.963377 -statlog-image, baseRaF, 0.825938 -statlog-vehicle, stree, 0.801413 -statlog-vehicle, wodt, 0.731811 -statlog-vehicle, j48svm, 0.730389 +statlog-image, cart, 0.964892 +statlog-image, baseRaF, 0.953604 +statlog-vehicle, stree, 0.788537 +statlog-vehicle, wodt, 0.726492 +statlog-vehicle, j48svm, 0.729651 statlog-vehicle, oc1, 0.708496 -statlog-vehicle, cart, 0.728592 -statlog-vehicle, baseRaF, 0.683698 -synthetic-control, stree, 0.971667 -synthetic-control, wodt, 0.979 -synthetic-control, j48svm, 0.921667 +statlog-vehicle, cart, 0.728367 +statlog-vehicle, baseRaF, 0.789572 +synthetic-control, stree, 0.95 +synthetic-control, wodt, 0.973167 +synthetic-control, j48svm, 0.922333 synthetic-control, oc1, 0.863167 -synthetic-control, cart, 0.906333 -synthetic-control, baseRaF, 0.8999 -tic-tac-toe, stree, 0.987435 -tic-tac-toe, wodt, 0.849967 -tic-tac-toe, j48svm, 0.983301 +synthetic-control, cart, 0.908333 +synthetic-control, baseRaF, 0.971567 +tic-tac-toe, stree, 0.984444 +tic-tac-toe, wodt, 0.93905 +tic-tac-toe, j48svm, 0.983295 tic-tac-toe, oc1, 0.91849 -tic-tac-toe, cart, 0.836177 -tic-tac-toe, baseRaF, 0.836562 -vertebral-column-2clases, stree, 0.829032 -vertebral-column-2clases, wodt, 0.793548 -vertebral-column-2clases, j48svm, 0.850645 +tic-tac-toe, cart, 0.951558 +tic-tac-toe, baseRaF, 0.974906 +vertebral-column-2clases, stree, 0.851936 +vertebral-column-2clases, wodt, 0.801935 +vertebral-column-2clases, j48svm, 0.84871 vertebral-column-2clases, oc1, 0.815161 -vertebral-column-2clases, cart, 0.775161 -vertebral-column-2clases, baseRaF, 0.794591 -wine, stree, 0.977778 -wine, wodt, 0.968079 -wine, j48svm, 0.983778 +vertebral-column-2clases, cart, 0.784839 +vertebral-column-2clases, baseRaF, 0.822601 +wine, stree, 0.949333 +wine, wodt, 0.973048 +wine, j48svm, 0.979143 wine, oc1, 0.916165 -wine, cart, 0.897524 -wine, baseRaF, 0.923513 -zoo, stree, 0.96 -zoo, wodt, 0.945 -zoo, j48svm, 0.920857 +wine, cart, 0.921937 +wine, baseRaF, 0.97748 +zoo, stree, 0.955524 +zoo, wodt, 0.954429 +zoo, j48svm, 0.92381 zoo, oc1, 0.890952 -zoo, cart, 0.958 -zoo, baseRaF, 0.8861 +zoo, cart, 0.957476 +zoo, baseRaF, 0.936262 diff --git a/stats_stree.py b/stats_stree.py new file mode 100644 index 0000000..43ef876 --- /dev/null +++ b/stats_stree.py @@ -0,0 +1,40 @@ +from stree import Stree +from experimentation.Sets import Datasets + + +def nodes_leaves(clf): + nodes = 0 + leaves = 0 + for node in clf: + if node.is_leaf(): + leaves += 1 + else: + nodes += 1 + return nodes, leaves + + +def compute_depth(node, depth): + if node is None: + return depth + if node.is_leaf(): + return depth + 1 + + return max( + compute_depth(node.get_up(), depth + 1), + compute_depth(node.get_down(), depth + 1), + ) + + +dt = Datasets(True, False, "tanveer") +for dataset in dt: + dataset_name = dataset[0] + X, y = dt.load(dataset_name) + clf = Stree(random_state=1) + clf.fit(X, y) + accuracy = clf.score(X, y) + nodes, leaves = nodes_leaves(clf) + depth = compute_depth(clf.tree_, 0) + print( + f"{dataset_name:30s} {nodes:5d} {leaves:5d} {clf.depth_:5d} " + f"{depth:5d} {accuracy:7.5f}" + ) diff --git a/testwodt.py b/testwodt.py index 97115d4..3d4fa26 100644 --- a/testwodt.py +++ b/testwodt.py @@ -1,6 +1,6 @@ import argparse from wodt import TreeClassifier -from sklearn.model_selection import cross_val_score +from sklearn.model_selection import KFold, cross_val_score import numpy as np import random from experimentation.Sets import Datasets @@ -85,8 +85,9 @@ def process_dataset(dataset, verbose): for random_state in random_seeds: random.seed(random_state) np.random.seed(random_state) + kfold = KFold(shuffle=True, random_state=random_state, n_splits=5) clf = TreeClassifier(random_state=random_state) - res = cross_val_score(clf, X, y, cv=5) + res = cross_val_score(clf, X, y, cv=kfold) scores.append(res) if verbose: print(