Add big test and feature selection notebook

This commit is contained in:
2021-04-26 01:08:57 +02:00
parent 84795b4c43
commit b061d40355
4 changed files with 650 additions and 8 deletions

View File

@@ -51,6 +51,9 @@ def parse_arguments():
type=int,
required=True,
)
ap.add_argument(
"-p", "--parameters", type=str, required=False, default="{}"
)
args = ap.parse_args()
return (
args.set_of_files,
@@ -58,6 +61,7 @@ def parse_arguments():
args.dataset,
args.sql,
bool(args.normalize),
args.parameters,
)
@@ -72,7 +76,7 @@ def get_classifier(model, random_state, hyperparameters):
return clf
def process_dataset(dataset, verbose, model):
def process_dataset(dataset, verbose, model, params):
X, y = dt.load(dataset)
scores = []
times = []
@@ -87,7 +91,7 @@ def process_dataset(dataset, verbose, model):
print(f"X.shape: {X.shape}")
print(f"{X[:4]}")
print(f"Random seeds: {random_seeds}")
hyperparameters = json.loads("{}")
hyperparameters = json.loads(params)
if model == "stree":
# Get the optimized parameters
record = dbh.find_best(dataset, model, "gridsearch")
@@ -176,7 +180,7 @@ def store_string(
random_seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
standardize = False
(set_of_files, model, dataset, sql, normalize) = parse_arguments()
(set_of_files, model, dataset, sql, normalize, parameters) = parse_arguments()
dbh = MySQL()
if sql:
sql_output = open(f"{model}.sql", "w")
@@ -216,7 +220,7 @@ if dataset == "all":
end="",
)
scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
dataset[0], verbose=False, model=model
dataset[0], verbose=False, model=model, params=parameters
)
complexity = dict(
nodes=float(np.mean(nodes)),
@@ -237,7 +241,7 @@ if dataset == "all":
print(command, file=sql_output)
else:
scores, times, hyperparameters, nodes, leaves, depth = process_dataset(
dataset, verbose=True, model=model
dataset, verbose=True, model=model, params=parameters
)
record = dbh.find_best(dataset, model, "crossval")
accuracy = np.mean(scores)