feat: Add discretize and fix stratified hyperparameters in be_main

This commit is contained in:
Ricardo Montañana Gómez
2023-01-21 22:17:25 +01:00
parent 520f8807e5
commit 5ff6265a08
3 changed files with 23 additions and 4 deletions

View File

@@ -92,6 +92,17 @@ class Arguments(argparse.ArgumentParser):
"help": "dataset to work with",
},
],
"discretize": [
("--discretize",),
{
"action": EnvDefault,
"envvar": "discretize",
"required": True,
"help": "Discretize dataset",
"const": "1",
"nargs": "?",
},
],
"excel": [
("-x", "--excel"),
{
@@ -260,6 +271,8 @@ class Arguments(argparse.ArgumentParser):
"envvar": "stratified",
"required": True,
"help": "Stratified",
"const": "1",
"nargs": "?",
},
],
"tex_output": [

View File

@@ -108,14 +108,18 @@ class DatasetsSurcov:
class Datasets:
def __init__(self, dataset_name=None):
def __init__(self, dataset_name=None, discretize=None):
envData = EnvData.load()
# DatasetsSurcov, DatasetsTanveer, DatasetsArff,...
source_name = getattr(
__import__(__name__),
f"Datasets{envData['source_data']}",
)
self.discretize = envData["discretize"] == "1"
self.discretize = (
envData["discretize"] == "1"
if discretize is None
else discretize == "1"
)
self.dataset = source_name()
# initialize self.class_names & self.data_sets
class_names, sets = self._init_names(dataset_name)

View File

@@ -13,7 +13,7 @@ def main(args_test=None):
arguments = Arguments(prog="be_main")
arguments.xset("stratified").xset("score").xset("model", mandatory=True)
arguments.xset("n_folds").xset("platform").xset("quiet").xset("title")
arguments.xset("report").xset("ignore_nan")
arguments.xset("report").xset("ignore_nan").xset("discretize")
arguments.add_exclusive(
["grid_paramfile", "best_paramfile", "hyperparameters"]
)
@@ -29,7 +29,9 @@ def main(args_test=None):
score_name=args.score,
model_name=args.model,
stratified=args.stratified,
datasets=Datasets(dataset_name=args.dataset),
datasets=Datasets(
dataset_name=args.dataset, discretize=args.discretize
),
hyperparams_dict=args.hyperparameters,
hyperparams_file=args.best_paramfile,
grid_paramfile=args.grid_paramfile,