Mirror of https://github.com/Doctorado-ML/Stree_datasets.git, synced 2025-08-15 15:36:01 +00:00
Add wodt clf
Add execution results of RaF, RoF and RRoF. Fix fit time in database records.
testwodt.py (new file, +125 lines)
@@ -0,0 +1,125 @@
import argparse
import random

import numpy as np
from sklearn.model_selection import cross_val_score

from experimentation.Sets import Datasets
from wodt import TreeClassifier


def parse_arguments():
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "-S",
        "--set-of-files",
        type=str,
        choices=["aaai", "tanveer"],
        required=False,
        default="aaai",
    )
    ap.add_argument(
        "-d",
        "--dataset",
        type=str,
        required=False,
        help="Dataset name",
    )
    # argparse's type=bool is a pitfall (bool("False") is True, so any value
    # passed on the command line enables the option); use on/off flags instead.
    ap.add_argument(
        "-n",
        "--normalize",
        action="store_true",
        required=False,
        help="Normalize dataset",
    )
    ap.add_argument(
        "-s",
        "--standardize",
        action="store_true",
        required=False,
        help="Standardize dataset",
    )
    ap.add_argument(
        "-p",
        "--paper-norm",
        action="store_true",
        required=False,
        help="[-1, 1] normalization as in the paper",
    )
    ap.add_argument(
        "-r",
        "--random-set",
        default=0,
        type=int,
        required=False,
        help="Set of random seeds: {0, 1}",
    )
    args = ap.parse_args()
    return (
        args.set_of_files,
        args.dataset,
        args.normalize,
        args.standardize,
        args.paper_norm,
        args.random_set,
    )


def normalize_paper(data):
    # Rescale with the dataset-wide min/max (not per feature) to [-1, 1].
    min_data = data.min()
    return 2 * (data - min_data) / (data.max() - min_data) - 1
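
# A minimal illustration with assumed values: for data = np.array([0., 5., 10.]),
# min is 0 and max is 10, so normalize_paper returns [-1., 0., 1.]: the global
# minimum maps to -1 and the global maximum to +1.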


def process_dataset(dataset, verbose):
    X, y = dt.load(dataset)
    if paper_norm:
        X = normalize_paper(X)
    scores = []
    if verbose:
        print(f"* Processing dataset [{dataset}] from Set: {set_of_files}")
        print(f"X.shape: {X.shape}")
        print(f"{X[:4]}")
        print(f"Random seeds: {random_seeds}")
        print(f"[-1, 1]: {paper_norm} norm: {normalize} std: {standardize}")
    for random_state in random_seeds:
        # Seed both the stdlib and NumPy RNGs so each run is reproducible.
        random.seed(random_state)
        np.random.seed(random_state)
        clf = TreeClassifier(random_state=random_state)
        res = cross_val_score(clf, X, y, cv=5)
        scores.append(res)
        if verbose:
            print(
                f"Random seed: {random_state:5d} Accuracy: {res.mean():6.4f}"
                f"±{res.std():6.4f}"
            )
    return scores


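# Note: process_dataset returns a list of ten 5-element fold-score arrays,
# so np.mean(scores) and np.std(scores) below aggregate over all 50 runs.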
(
    set_of_files,
    dataset,
    normalize,
    standardize,
    paper_norm,
    random_set,
) = parse_arguments()
# Two alternative sets of ten seeds, selected with -r {0, 1}.
random_seeds = (
    [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
    if random_set == 0
    else [32, 24, 56, 18, 2, 94, 1256, 84, 156, 42]
)
dt = Datasets(normalize, standardize, set_of_files)
if dataset == "all":
    print(
        f"* Process all datasets set: {set_of_files} [-1, 1]: {paper_norm} "
        f"norm: {normalize} std: {standardize}"
    )
    print(f"5 Fold Cross Validation with 10 random seeds {random_seeds}")
    # Iterating over Datasets yields tuples whose first element is the name.
    for dataset in dt:
        print(f"- {dataset[0]:20s} ", end="")
        scores = process_dataset(dataset[0], verbose=False)
        print(f"{np.mean(scores):6.4f}±{np.std(scores):6.4f}")
else:
    scores = process_dataset(dataset, verbose=True)
    print(f"* Accuracy: {np.mean(scores):6.4f}±{np.std(scores):6.4f}")
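
A typical run might look like this (a minimal sketch: the dataset name "iris" is assumed for illustration and must be one of the names exposed by experimentation.Sets.Datasets; note that -n, -s and -p are on/off flags):

python testwodt.py -d all
python testwodt.py -S tanveer -d iris -p -r 1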