diff --git a/data/stree.sqlite b/data/stree.sqlite new file mode 100644 index 0000000..ea91b79 Binary files /dev/null and b/data/stree.sqlite differ diff --git a/experimentation/.myconfig.dist b/experimentation/.myconfig.dist index adcb49d..604f98e 100644 --- a/experimentation/.myconfig.dist +++ b/experimentation/.myconfig.dist @@ -2,4 +2,5 @@ host= port=tunnel user=stree password= -database=stree_experiments \ No newline at end of file +database=stree_experiments +sqlite=None diff --git a/experimentation/Database.py b/experimentation/Database.py index 282b175..22d9955 100644 --- a/experimentation/Database.py +++ b/experimentation/Database.py @@ -42,13 +42,21 @@ class MySQL: if self._tunnel: self._server.start() self._config_db["port"] = self._server.local_bind_port - self._database = mysql.connector.connect(**self._config_db) + if self._config_db["sqlite"] == "None": + del self._config_db["sqlite"] + self._config_db["buffered"] = True + self._database = mysql.connector.connect(**self._config_db) + else: + self._database = sqlite3.connect(self._config_db["sqlite"]) + # return dict as a result of select + self._database.row_factory = sqlite3.Row + return self._database def find_best( self, dataset, classifier="any", experiment="any", time_info=False ): - cursor = self._database.cursor(buffered=True) + cursor = self._database.cursor() date_from = "2021-01-20" # date_to = "2021-04-07" command = ( diff --git a/experimentation/mysql/.myconfig b/experimentation/mysql/.myconfig new file mode 100644 index 0000000..7b79ceb --- /dev/null +++ b/experimentation/mysql/.myconfig @@ -0,0 +1,6 @@ +host=127.0.0.1 +port=3306 +user=stree +password=xtree +database=stree +sqlite=None diff --git a/experimentation/mysql/.tunnel b/experimentation/mysql/.tunnel new file mode 100644 index 0000000..3c52126 --- /dev/null +++ b/experimentation/mysql/.tunnel @@ -0,0 +1,5 @@ +ssh_address_or_host=('atenea.rmontanana.es', 31427) +ssh_username=rmontanana +ssh_private_key=/home/rmontanana/.ssh/id_rsa +remote_bind_address=('127.0.0.1', 3306) +enabled=1 diff --git a/experimentation/sqlite/.myconfig b/experimentation/sqlite/.myconfig new file mode 100644 index 0000000..757273c --- /dev/null +++ b/experimentation/sqlite/.myconfig @@ -0,0 +1,6 @@ +host=127.0.0.1 +port=3306 +user=stree +password=xtree +database=stree +sqlite=./data/stree.sqlite diff --git a/experimentation/sqlite/.tunnel b/experimentation/sqlite/.tunnel new file mode 100644 index 0000000..556b9f1 --- /dev/null +++ b/experimentation/sqlite/.tunnel @@ -0,0 +1,5 @@ +ssh_address_or_host=('atenea.rmontanana.es', 31427) +ssh_username=rmontanana +ssh_private_key=/home/rmontanana/.ssh/id_rsa +remote_bind_address=('127.0.0.1', 3306) +enabled=0 diff --git a/report_score.py b/report_score.py index cd51f33..fa8143b 100644 --- a/report_score.py +++ b/report_score.py @@ -126,11 +126,13 @@ def process_dataset(dataset, verbose, model, params): record = dbh.find_best(dataset, model, "gridsearch") hyperparameters = json.loads(record[8] if record[8] != "" else "{}") hyperparameters.pop("random_state", None) + print("*" * 100) for random_state in random_seeds: random.seed(random_state) np.random.seed(random_state) kfold = KFold(shuffle=True, random_state=random_state, n_splits=5) clf = get_classifier(model, random_state, hyperparameters) + print(hyperparameters) res = cross_validate(clf, X, y, cv=kfold, return_estimator=True) scores.append(res["test_score"]) times.append(res["fit_time"]) @@ -361,7 +363,7 @@ standardize = False excel, discretize, ) = parse_arguments() -# parameters = '{"splitter":"cfs","max_features":"auto"}' +# parameters = '{"kernel":"rbf","max_features":"auto"}' dbh = MySQL() if sql: sql_output = open(f"{model}.sql", "w")