mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-15 15:36:01 +00:00
Add big test and feature selection notebook
This commit is contained in:
@@ -51,6 +51,9 @@ def parse_arguments():
|
||||
type=int,
|
||||
required=True,
|
||||
)
|
||||
ap.add_argument(
|
||||
"-p", "--parameters", type=str, required=False, default="{}"
|
||||
)
|
||||
args = ap.parse_args()
|
||||
return (
|
||||
args.set_of_files,
|
||||
@@ -58,6 +61,7 @@ def parse_arguments():
|
||||
args.dataset,
|
||||
args.sql,
|
||||
bool(args.normalize),
|
||||
args.parameters,
|
||||
)
|
||||
|
||||
|
||||
@@ -72,7 +76,7 @@ def get_classifier(model, random_state, hyperparameters):
|
||||
return clf
|
||||
|
||||
|
||||
def process_dataset(dataset, verbose, model):
|
||||
def process_dataset(dataset, verbose, model, params):
|
||||
X, y = dt.load(dataset)
|
||||
scores = []
|
||||
times = []
|
||||
@@ -87,7 +91,7 @@ def process_dataset(dataset, verbose, model):
|
||||
print(f"X.shape: {X.shape}")
|
||||
print(f"{X[:4]}")
|
||||
print(f"Random seeds: {random_seeds}")
|
||||
hyperparameters = json.loads("{}")
|
||||
hyperparameters = json.loads(params)
|
||||
if model == "stree":
|
||||
# Get the optimized parameters
|
||||
record = dbh.find_best(dataset, model, "gridsearch")
|
||||
@@ -176,7 +180,7 @@ def store_string(
|
||||
|
||||
random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
|
||||
standardize = False
|
||||
(set_of_files, model, dataset, sql, normalize) = parse_arguments()
|
||||
(set_of_files, model, dataset, sql, normalize, parameters) = parse_arguments()
|
||||
dbh = MySQL()
|
||||
if sql:
|
||||
sql_output = open(f"{model}.sql", "w")
|
||||
@@ -216,7 +220,7 @@ if dataset == "all":
|
||||
end="",
|
||||
)
|
||||
scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
|
||||
dataset[0], verbose=False, model=model
|
||||
dataset[0], verbose=False, model=model, params=parameters
|
||||
)
|
||||
complexity = dict(
|
||||
nodes=float(np.mean(nodes)),
|
||||
@@ -237,7 +241,7 @@ if dataset == "all":
|
||||
print(command, file=sql_output)
|
||||
else:
|
||||
scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
|
||||
dataset, verbose=True, model=model
|
||||
dataset, verbose=True, model=model, params=parameters
|
||||
)
|
||||
record = dbh.find_best(dataset, model, "crossval")
|
||||
accuracy = np.mean(scores)
|
||||
|
Reference in New Issue
Block a user