Add sqlite config

Add sqlite to report_score
This commit is contained in:
2021-06-30 09:23:58 +02:00
parent 641defd109
commit e36fed491b
8 changed files with 37 additions and 4 deletions

BIN
data/stree.sqlite Normal file

Binary file not shown.

View File

@@ -2,4 +2,5 @@ host=<server>
port=tunnel
user=stree
password=<password>
database=stree_experiments
database=stree_experiments
sqlite=None

View File

@@ -42,13 +42,21 @@ class MySQL:
if self._tunnel:
self._server.start()
self._config_db["port"] = self._server.local_bind_port
self._database = mysql.connector.connect(**self._config_db)
if self._config_db["sqlite"] == "None":
del self._config_db["sqlite"]
self._config_db["buffered"] = True
self._database = mysql.connector.connect(**self._config_db)
else:
self._database = sqlite3.connect(self._config_db["sqlite"])
# return dict as a result of select
self._database.row_factory = sqlite3.Row
return self._database
def find_best(
self, dataset, classifier="any", experiment="any", time_info=False
):
cursor = self._database.cursor(buffered=True)
cursor = self._database.cursor()
date_from = "2021-01-20"
# date_to = "2021-04-07"
command = (

View File

@@ -0,0 +1,6 @@
host=127.0.0.1
port=3306
user=stree
password=xtree
database=stree
sqlite=None

View File

@@ -0,0 +1,5 @@
ssh_address_or_host=('atenea.rmontanana.es', 31427)
ssh_username=rmontanana
ssh_private_key=/home/rmontanana/.ssh/id_rsa
remote_bind_address=('127.0.0.1', 3306)
enabled=1

View File

@@ -0,0 +1,6 @@
host=127.0.0.1
port=3306
user=stree
password=xtree
database=stree
sqlite=./data/stree.sqlite

View File

@@ -0,0 +1,5 @@
ssh_address_or_host=('atenea.rmontanana.es', 31427)
ssh_username=rmontanana
ssh_private_key=/home/rmontanana/.ssh/id_rsa
remote_bind_address=('127.0.0.1', 3306)
enabled=0

View File

@@ -126,11 +126,13 @@ def process_dataset(dataset, verbose, model, params):
record = dbh.find_best(dataset, model, "gridsearch")
hyperparameters = json.loads(record[8] if record[8] != "" else "{}")
hyperparameters.pop("random_state", None)
print("*" * 100)
for random_state in random_seeds:
random.seed(random_state)
np.random.seed(random_state)
kfold = KFold(shuffle=True, random_state=random_state, n_splits=5)
clf = get_classifier(model, random_state, hyperparameters)
print(hyperparameters)
res = cross_validate(clf, X, y, cv=kfold, return_estimator=True)
scores.append(res["test_score"])
times.append(res["fit_time"])
@@ -361,7 +363,7 @@ standardize = False
excel,
discretize,
) = parse_arguments()
# parameters = '{"splitter":"cfs","max_features":"auto"}'
# parameters = '{"kernel":"rbf","max_features":"auto"}'
dbh = MySQL()
if sql:
sql_output = open(f"{model}.sql", "w")