Add KDBNew model and fit_feature hyperparameter

This commit is contained in:
2023-02-04 18:29:10 +01:00
parent d454a318fc
commit 75ed3e8f6e
4 changed files with 28 additions and 4 deletions

View File

@@ -112,6 +112,17 @@ class Arguments(argparse.ArgumentParser):
"help": "Generate Excel File",
},
],
"fit_features": [
("--fit_features",),
{
"action": EnvDefault,
"envvar": "fit_features",
"required": True,
"help": "Include features in fit call",
"const": "1",
"nargs": "?",
},
],
"grid_paramfile": [
("-g", "--grid_paramfile"),
{

View File

@@ -113,8 +113,10 @@ class Experiment:
title,
progress_bar=True,
ignore_nan=True,
fit_features=None,
folds=5,
):
env_data = EnvData.load()
today = datetime.now()
self.time = today.strftime("%H:%M:%S")
self.date = today.strftime("%Y-%m-%d")
@@ -134,6 +136,11 @@ class Experiment:
self.title = title
self.ignore_nan = ignore_nan
self.stratified = stratified == "1"
self.fit_features = (
env_data["fit_features"] == "1"
if fit_features is None
else fit_features == "1"
)
self.stratified_class = StratifiedKFold if self.stratified else KFold
self.datasets = datasets
dictionary = json.loads(hyperparams_dict)
@@ -187,11 +194,14 @@ class Experiment:
self.depths = []
def _build_fit_params(self, name):
if not self.fit_features:
return None
res = dict(features=self.datasets.get_features())
states = self.datasets.get_states(name)
if states is None:
return None
features = self.datasets.get_features()
return {"state_names": states, "features": features}
return res
res["state_names"] = states
return res
def _n_fold_crossval(self, name, X, y, hyperparameters):
if self.scores != []:

View File

@@ -8,7 +8,7 @@ from sklearn.ensemble import (
)
from sklearn.svm import SVC
from stree import Stree
from bayesclass.clfs import TAN, KDB, AODE
from bayesclass.clfs import TAN, KDB, AODE, KDBNew
from wodt import Wodt
from odte import Odte
from xgboost import XGBClassifier
@@ -41,6 +41,7 @@ class Models:
"STree": Stree(random_state=random_state),
"TAN": TAN(random_state=random_state),
"KDB": KDB(k=2),
"KDBNew": KDBNew(k=2),
"AODE": AODE(random_state=random_state),
"Cart": DecisionTreeClassifier(random_state=random_state),
"ExtraTree": ExtraTreeClassifier(random_state=random_state),

View File

@@ -14,6 +14,7 @@ def main(args_test=None):
arguments.xset("stratified").xset("score").xset("model", mandatory=True)
arguments.xset("n_folds").xset("platform").xset("quiet").xset("title")
arguments.xset("report").xset("ignore_nan").xset("discretize")
arguments.xset("fit_features")
arguments.add_exclusive(
["grid_paramfile", "best_paramfile", "hyperparameters"]
)
@@ -40,6 +41,7 @@ def main(args_test=None):
ignore_nan=args.ignore_nan,
title=args.title,
folds=args.n_folds,
fit_features=args.fit_features,
)
job.do_experiment()
except ValueError as e: