feat: Add discretize and fix stratified hyperparameters in be_main

This commit is contained in:
Ricardo Montañana Gómez
2023-01-21 22:17:25 +01:00
parent 520f8807e5
commit 5ff6265a08
3 changed files with 23 additions and 4 deletions

View File

@@ -92,6 +92,17 @@ class Arguments(argparse.ArgumentParser):
"help": "dataset to work with", "help": "dataset to work with",
}, },
], ],
"discretize": [
("--discretize",),
{
"action": EnvDefault,
"envvar": "discretize",
"required": True,
"help": "Discretize dataset",
"const": "1",
"nargs": "?",
},
],
"excel": [ "excel": [
("-x", "--excel"), ("-x", "--excel"),
{ {
@@ -260,6 +271,8 @@ class Arguments(argparse.ArgumentParser):
"envvar": "stratified", "envvar": "stratified",
"required": True, "required": True,
"help": "Stratified", "help": "Stratified",
"const": "1",
"nargs": "?",
}, },
], ],
"tex_output": [ "tex_output": [

View File

@@ -108,14 +108,18 @@ class DatasetsSurcov:
class Datasets: class Datasets:
def __init__(self, dataset_name=None): def __init__(self, dataset_name=None, discretize=None):
envData = EnvData.load() envData = EnvData.load()
# DatasetsSurcov, DatasetsTanveer, DatasetsArff,... # DatasetsSurcov, DatasetsTanveer, DatasetsArff,...
source_name = getattr( source_name = getattr(
__import__(__name__), __import__(__name__),
f"Datasets{envData['source_data']}", f"Datasets{envData['source_data']}",
) )
self.discretize = envData["discretize"] == "1" self.discretize = (
envData["discretize"] == "1"
if discretize is None
else discretize == "1"
)
self.dataset = source_name() self.dataset = source_name()
# initialize self.class_names & self.data_sets # initialize self.class_names & self.data_sets
class_names, sets = self._init_names(dataset_name) class_names, sets = self._init_names(dataset_name)

View File

@@ -13,7 +13,7 @@ def main(args_test=None):
arguments = Arguments(prog="be_main") arguments = Arguments(prog="be_main")
arguments.xset("stratified").xset("score").xset("model", mandatory=True) arguments.xset("stratified").xset("score").xset("model", mandatory=True)
arguments.xset("n_folds").xset("platform").xset("quiet").xset("title") arguments.xset("n_folds").xset("platform").xset("quiet").xset("title")
arguments.xset("report").xset("ignore_nan") arguments.xset("report").xset("ignore_nan").xset("discretize")
arguments.add_exclusive( arguments.add_exclusive(
["grid_paramfile", "best_paramfile", "hyperparameters"] ["grid_paramfile", "best_paramfile", "hyperparameters"]
) )
@@ -29,7 +29,9 @@ def main(args_test=None):
score_name=args.score, score_name=args.score,
model_name=args.model, model_name=args.model,
stratified=args.stratified, stratified=args.stratified,
datasets=Datasets(dataset_name=args.dataset), datasets=Datasets(
dataset_name=args.dataset, discretize=args.discretize
),
hyperparams_dict=args.hyperparameters, hyperparams_dict=args.hyperparameters,
hyperparams_file=args.best_paramfile, hyperparams_file=args.best_paramfile,
grid_paramfile=args.grid_paramfile, grid_paramfile=args.grid_paramfile,