From 84795b4c43d951154500fbf765b1972687554836 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 15 Apr 2021 00:30:16 +0200 Subject: [PATCH] Fix normalization & standard. columnwise --- checkdatasets.py | 10 +++- checknormalize.py | 40 ++++++++------ experimentation/Sets.py | 26 +++++++-- .../stree_default_translated_scale.txt | 55 +++++++++++++++++++ 4 files changed, 109 insertions(+), 22 deletions(-) create mode 100644 results/normalizados/stree_default_translated_scale.txt diff --git a/checkdatasets.py b/checkdatasets.py index c2d2bd9..13d6864 100755 --- a/checkdatasets.py +++ b/checkdatasets.py @@ -1,4 +1,5 @@ from experimentation.Sets import Datasets +import numpy as np dt = Datasets(normalize=False, standardize=False, set_of_files="tanveer") print("Checking normalized datasets: ") @@ -7,6 +8,8 @@ for data in dt: X, y = dt.load(name) min_value = X.min() max_value = X.max() + media = np.mean(X) + desv = np.std(X) resultado = ( "Normalizado" if min_value <= 1 @@ -15,4 +18,9 @@ for data in dt: and max_value >= 0 else "No Normalizado" ) - print(f"{name:30s}: {resultado} ") + resultado2 = ( + "Standardizado" + if round(media) == 0 and round(desv) == 1 + else "No Standardizado" + ) + print(f"{name:30s}: {resultado} {resultado2}") diff --git a/checknormalize.py b/checknormalize.py index 01e9ab9..92ca9b7 100755 --- a/checknormalize.py +++ b/checknormalize.py @@ -24,21 +24,20 @@ def header(): print("Processing Datasets with stree default.\n") print( f"{'Dataset':30s} {'No Norm.':9s} {'Normaliz.':9s} " - f"{'Col.Norm.':9s} {'Best score in crossval':25s}" + f"{'Col.Norm.':9s} {'Context B':9s} {'Best score in crossval':25s}" ) - print("=" * 30 + " " + ("=" * 9 + " ") * 3 + "=" * 25) + print("=" * 30 + " " + ("=" * 9 + " ") * 4 + "=" * 25) -def process_dataset(X, y): +def process_dataset(X, y, normalize): scores = [] # return random.uniform(0, 1) - # Get the optimized parameters for random_state in random_seeds: random.seed(random_state) + clf_test = Stree(random_state=random_state, normalize=normalize) np.random.seed(random_state) kfold = KFold(shuffle=True, random_state=random_state, n_splits=5) - clf = Stree(random_state=random_state) - res = cross_validate(clf, X, y, cv=kfold, return_estimator=True) + res = cross_validate(clf_test, X, y, cv=kfold, return_estimator=True) scores.append(res["test_score"]) return np.mean(scores) @@ -58,7 +57,7 @@ database = dbh.get_connection() random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] dt = Datasets(normalize=False, standardize=False, set_of_files="tanveer") header() -total = [0, 0, 0] +total = [0, 0, 0, 0] line = TextColor.LINE1 for data in dt: name = data[0] @@ -66,22 +65,31 @@ for data in dt: record = dbh.find_best(name, models_tree, "crossval") X2 = normalize(X) X3 = normalize_rows(X) - clf = Stree(random_state=1) - ac1 = process_dataset(X, y) - ac2 = process_dataset(X2, y) - ac3 = process_dataset(X3, y) - max_value = max(ac1, ac2, ac3) + ac1 = process_dataset(X, y, False) + ac2 = process_dataset(X2, y, False) + ac3 = process_dataset(X3, y, False) + ac4 = process_dataset(X, y, True) + max_value = round(max(ac1, ac2, ac3, ac4), 6) line = TextColor.LINE2 if line == TextColor.LINE1 else TextColor.LINE1 print(line + f"{name:30s} ", end="", flush=True) - total[np.argmax([ac1, ac2, ac3])] += 1 + total[np.argmax([ac1, ac2, ac3, ac4])] += 1 color1 = TextColor.SUCCESS if ac1 == max_value else line color2 = TextColor.SUCCESS if ac2 == max_value else line color3 = TextColor.SUCCESS if ac3 == max_value else line + color4 = TextColor.SUCCESS if ac4 == max_value else line print(color1 + f"{ac1:9.6f} " + TextColor.ENDC, end="", flush=True) print(color2 + f"{ac2:9.6f} " + TextColor.ENDC, end="", flush=True) - print(color3 + f"{ac3:9.6f}" + TextColor.ENDC, end="", flush=True) - print(line + f"{record[5]:9.6f} {record[3]}" + TextColor.ENDC) -print(f"{'Total':30s} {total[0]:9d} {total[1]:9d} {total[2]:9d}") + print(color3 + f"{ac3:9.6f} " + TextColor.ENDC, end="", flush=True) + print(color4 + f"{ac4:9.6f}" + TextColor.ENDC, end="", flush=True) + best_accuracy = round(record[5], 6) + best_color = TextColor.UNDERLINE if best_accuracy >= max_value else "" + print( + line + + best_color + + f"{best_accuracy:9.6f} {record[3]}" + + TextColor.ENDC + ) +print(f"{'Total':30s} {total[0]:9d} {total[1]:9d} {total[2]:9d} {total[3]:9d}") stop = time.time() hours, rem = divmod(stop - start, 3600) minutes, seconds = divmod(rem, 60) diff --git a/experimentation/Sets.py b/experimentation/Sets.py index 9b7d382..553808c 100644 --- a/experimentation/Sets.py +++ b/experimentation/Sets.py @@ -35,23 +35,39 @@ class Dataset_Base: """ pass - def normalize(self, data: np.array) -> np.array: + @staticmethod + def normalize(data: np.array) -> np.array: min_data = data.min() return (data - min_data) / (data.max() - min_data) - def standardize(self, data: np.array) -> np.array: + @staticmethod + def normalize_rows(data: np.array) -> np.array: + res = data.copy() + for col in range(res.shape[1]): + res[:, col] = Dataset_Base.normalize(res[:, col]) + return res + + @staticmethod + def standardize(data: np.array) -> np.array: return (data - data.mean()) / data.std() + @staticmethod + def standardize_rows(data: np.array) -> np.array: + res = data.copy() + for col in range(res.shape[1]): + res[:, col] = Dataset_Base.standardize(res[:, col]) + return res + def get_params(self) -> str: return f"normalize={self._normalize}, standardize={self._standardize}" def post_process(self, X: np.array, y: np.array) -> tdataset: if self._standardize and self._normalize: - X = self.standardize(self.normalize(X)) + X = self.standardize_rows(self.normalize_rows(X)) elif self._standardize: - X = self.standardize(X) + X = self.standardize_rows(X) elif self._normalize: - X = self.normalize(X) + X = self.normalize_rows(X) return X, y def __iter__(self) -> Diterator: diff --git a/results/normalizados/stree_default_translated_scale.txt b/results/normalizados/stree_default_translated_scale.txt new file mode 100644 index 0000000..15f6d0a --- /dev/null +++ b/results/normalizados/stree_default_translated_scale.txt @@ -0,0 +1,55 @@ +* Process all datasets set with stree_default: tanveer norm: False std: False store in: stree_default +5 Fold Cross Validation with 10 random seeds [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] + +Dataset Samp Var Cls Nodes Leaves Depth Accuracy Time Parameters +============================== ===== === === ======= ======= ======= =============== =============== ========== +balance-scale 625 4 3 27.52 14.26 8.06 0.909600±0.0265 0.028588±0.0084 {} +balloons 16 4 2 3.32 2.16 2.08 0.666667±0.2544 0.001797±0.0006 {} +breast-cancer-wisc-diag 569 30 2 5.40 3.20 2.90 0.968364±0.0176 0.011921±0.0026 {} +breast-cancer-wisc-prog 198 33 2 7.20 4.10 3.22 0.807038±0.0519 0.030704±0.0080 {} +breast-cancer-wisc 699 9 2 8.64 4.82 4.08 0.966805±0.0136 0.012557±0.0025 {} +breast-cancer 286 9 2 21.52 11.26 5.92 0.731428±0.0463 0.037908±0.0124 {} +cardiotocography-10clases 2126 21 10 176.20 88.60 23.48 0.795250±0.0203 3.384487±0.3374 {} +cardiotocography-3clases 2126 21 3 51.04 26.02 9.46 0.900566±0.0145 1.187683±0.1610 {} +conn-bench-sonar-mines-rocks 208 60 2 6.04 3.52 2.84 0.748351±0.0692 0.013475±0.0036 {} +cylinder-bands 512 35 2 26.32 13.66 6.84 0.708785±0.0412 0.323182±0.1296 {} +dermatology 366 34 6 11.08 6.04 6.04 0.964720±0.0219 0.016816±0.0010 {} +echocardiogram 131 10 2 9.60 5.30 4.04 0.802593±0.0769 0.009271±0.0038 {} +fertility 100 9 2 2.68 1.84 1.78 0.868000±0.0606 0.003031±0.0014 {} +haberman-survival 306 3 2 26.32 13.66 6.26 0.736610±0.0456 0.024594±0.0054 {} +heart-hungarian 294 12 2 13.80 7.40 4.48 0.819719±0.0534 0.023521±0.0042 {} +hepatitis 155 19 2 11.32 6.16 4.52 0.789677±0.0745 0.011099±0.0035 {} +ilpd-indian-liver 583 9 2 17.28 9.14 5.54 0.717137±0.0385 0.037171±0.0176 {} +ionosphere 351 33 2 7.24 4.12 3.60 0.872040±0.0411 0.023299±0.0057 {} +iris 150 4 3 5.32 3.16 3.06 0.964667±0.0323 0.004864±0.0008 {} +led-display 1000 7 10 47.36 24.18 17.74 0.703900±0.0294 0.345882±0.0343 {} +libras 360 90 15 54.64 27.82 23.90 0.744444±0.0485 5.653415±0.9524 {} +low-res-spect 531 100 9 23.80 12.40 10.14 0.854031±0.0356 0.786059±0.0930 {} +lymphography 148 18 4 14.80 7.90 5.42 0.778437±0.0793 0.016700±0.0037 {} +mammographic 961 5 2 8.76 4.88 4.42 0.819148±0.0227 0.048142±0.0106 {} +molec-biol-promoter 106 57 2 3.00 2.00 2.00 0.765238±0.0835 0.001846±0.0001 {} +musk-1 476 166 2 6.72 3.86 3.00 0.843890±0.0322 0.259591±0.0426 {} +oocytes_merluccius_nucleus_4d 1022 41 2 15.36 8.18 4.78 0.812123±0.0248 1.502560±0.2593 {} +oocytes_merluccius_states_2f 1022 25 3 19.96 10.48 5.50 0.915366±0.0192 0.244292±0.0545 {} +oocytes_trisopterus_nucleus_2f 912 25 2 30.20 15.60 7.54 0.800437±0.0237 0.775162±0.1727 {} +oocytes_trisopterus_states_5b 912 32 3 11.40 6.20 4.50 0.916876±0.0191 0.670656±0.2637 {} +parkinsons 195 22 2 8.24 4.62 3.64 0.883077±0.0456 0.010391±0.0025 {} +pima 768 8 2 17.20 9.10 5.66 0.765999±0.0311 0.076647±0.0243 {} +pittsburg-bridges-MATERIAL 106 7 3 10.68 5.84 4.46 0.800736±0.0635 0.009084±0.0019 {} +pittsburg-bridges-REL-L 103 7 3 17.20 9.10 6.28 0.628429±0.1009 0.017778±0.0043 {} +pittsburg-bridges-SPAN 92 7 3 13.72 7.36 5.12 0.629181±0.1008 0.011592±0.0026 {} +pittsburg-bridges-T-OR-D 102 7 2 5.16 3.08 2.90 0.860571±0.0679 0.004102±0.0009 {} +planning 182 12 2 4.00 2.50 2.28 0.703468±0.0751 0.011304±0.0055 {} +post-operative 90 8 3 3.68 2.34 2.26 0.675556±0.0919 0.006637±0.0026 {} +seeds 210 7 3 9.40 5.20 4.32 0.948571±0.0345 0.008227±0.0011 {} +statlog-australian-credit 690 14 2 10.88 5.94 4.48 0.667391±0.0371 0.257995±0.1398 {} +statlog-german-credit 1000 24 2 21.88 11.44 6.44 0.762200±0.0264 0.307171±0.0651 {} +statlog-heart 270 13 2 13.52 7.26 4.72 0.821481±0.0433 0.018700±0.0046 {} +statlog-image 2310 18 7 28.96 14.98 10.14 0.958225±0.0081 1.604884±0.4019 {} +statlog-vehicle 846 18 4 21.28 11.14 6.66 0.783314±0.0397 0.335944±0.0666 {} +synthetic-control 600 60 6 12.64 6.82 6.30 0.941167±0.0317 0.413927±0.0837 {} +tic-tac-toe 958 9 2 3.00 2.00 2.00 0.983296±0.0084 0.015215±0.0018 {} +vertebral-column-2clases 310 6 2 9.32 5.16 4.48 0.850323±0.0414 0.010402±0.0018 {} +wine 178 13 3 5.00 3.00 3.00 0.975254±0.0243 0.002708±0.0001 {} +zoo 101 16 7 13.00 7.00 7.00 0.946619±0.0453 0.009794±0.0004 {} +Time: 0h 15m 36s