"""Print a per-dataset score table for Stree with and without feature selection.

Usage: score_all_cfs.py [cfs|fcbs]

For every dataset returned by ``Datasets`` the script:

1. transforms the features with ``MDLP.fit_transform`` (timed),
2. selects features from the transformed data with the chosen ``MFS``
   filter (timed),
3. fits an ``Stree`` classifier on the raw, transformed and selected
   features and reports the score obtained on the same data used to fit.

The filter defaults to ``cfs`` when no argument is given.
"""
import sys
import time
import warnings

from experimentation.Sets import Datasets
from stree import Stree
from mdlp import MDLP
from mfs import MFS

# Filters accepted as the first command-line argument.
VALID_FILTERS = ("cfs", "fcbs")


def header(filter_name):
    """Print the title and the two header lines of the score table.

    Parameters
    ----------
    filter_name : str
        Name of the feature-selection filter; shown upper-cased as the
        title of the third score column.
    """
    print("Score files")
    initial = f"{'Dataset':30s} T. Disc T.Selec "
    sec_line = "=" * 30 + " ======= ======= "
    # One 10-character column per score, matching the ``.8f`` values
    # printed for each dataset.
    for item in ["Normal", "Discret.", filter_name.upper()]:
        initial += f"{item:10s} "
        sec_line += "=" * 10 + " "
    initial += "Reduction"
    sec_line += "========="
    print(initial)
    print(sec_line)


def _training_accuracy(X, y):
    """Fit a fresh Stree on (X, y) and return its score on that same data."""
    clf = Stree(random_state=1, multiclass_strategy="ovo")
    return clf.fit(X, y).score(X, y)


warnings.filterwarnings("ignore")
filter_name = sys.argv[1] if len(sys.argv) > 1 else "cfs"
if filter_name not in VALID_FILTERS:
    print("First parameter has to be one of: {cfs, fcbs}")
    # Stop here: continuing would silently run fcbs under a wrong label.
    sys.exit(1)
datasets = Datasets(False, False, "tanveer")
header(filter_name)
for dataset in datasets:
    name = dataset[0]
    mdlp = MDLP(random_state=1)
    X, y = datasets.load(name)
    mfs = MFS()

    # Time the discretization step.
    now_disc = time.time()
    X_disc = mdlp.fit_transform(X, y)
    time_disc = time.time() - now_disc

    # Time the feature-selection step on the discretized data.
    now_selec = time.time()
    if filter_name == "cfs":
        features_selected = mfs.cfs(X_disc, y).get_results()
    else:
        # NOTE(review): 5e-2 is passed as the third positional argument of
        # fcbs; presumably a threshold -- confirm against the MFS API.
        features_selected = mfs.fcbs(X_disc, y, 5e-2).get_results()
    time_selec = time.time() - now_selec

    score_norm = _training_accuracy(X, y)
    score_disc = _training_accuracy(X_disc, y)
    # len() rather than truthiness: the result may be an array-like whose
    # truth value is ambiguous.
    if len(features_selected) > 0:
        score_fs = _training_accuracy(X_disc[:, features_selected], y)
    else:
        # Nothing selected, so there is nothing to fit; report zero.
        score_fs = 0.0

    output = f"{name:30s} {time_disc:7.3f} {time_selec:7.3f} "
    output += f"{score_norm:.8f} "
    output += f"{score_disc:.8f} {score_fs:.8f} "
    # "Reduction" column: original feature count - selected feature count.
    output += f"{X.shape[1]:3} - {len(features_selected):3}"
    print(output)