mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-15 15:36:01 +00:00
Add sqlite config
Add sqlite to report_score
This commit is contained in:
BIN
data/stree.sqlite
Normal file
BIN
data/stree.sqlite
Normal file
Binary file not shown.
@@ -2,4 +2,5 @@ host=<server>
|
|||||||
port=tunnel
|
port=tunnel
|
||||||
user=stree
|
user=stree
|
||||||
password=<password>
|
password=<password>
|
||||||
database=stree_experiments
|
database=stree_experiments
|
||||||
|
sqlite=None
|
||||||
|
@@ -42,13 +42,21 @@ class MySQL:
|
|||||||
if self._tunnel:
|
if self._tunnel:
|
||||||
self._server.start()
|
self._server.start()
|
||||||
self._config_db["port"] = self._server.local_bind_port
|
self._config_db["port"] = self._server.local_bind_port
|
||||||
self._database = mysql.connector.connect(**self._config_db)
|
if self._config_db["sqlite"] == "None":
|
||||||
|
del self._config_db["sqlite"]
|
||||||
|
self._config_db["buffered"] = True
|
||||||
|
self._database = mysql.connector.connect(**self._config_db)
|
||||||
|
else:
|
||||||
|
self._database = sqlite3.connect(self._config_db["sqlite"])
|
||||||
|
# return dict as a result of select
|
||||||
|
self._database.row_factory = sqlite3.Row
|
||||||
|
|
||||||
return self._database
|
return self._database
|
||||||
|
|
||||||
def find_best(
|
def find_best(
|
||||||
self, dataset, classifier="any", experiment="any", time_info=False
|
self, dataset, classifier="any", experiment="any", time_info=False
|
||||||
):
|
):
|
||||||
cursor = self._database.cursor(buffered=True)
|
cursor = self._database.cursor()
|
||||||
date_from = "2021-01-20"
|
date_from = "2021-01-20"
|
||||||
# date_to = "2021-04-07"
|
# date_to = "2021-04-07"
|
||||||
command = (
|
command = (
|
||||||
|
6
experimentation/mysql/.myconfig
Normal file
6
experimentation/mysql/.myconfig
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
host=127.0.0.1
|
||||||
|
port=3306
|
||||||
|
user=stree
|
||||||
|
password=xtree
|
||||||
|
database=stree
|
||||||
|
sqlite=None
|
5
experimentation/mysql/.tunnel
Normal file
5
experimentation/mysql/.tunnel
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
ssh_address_or_host=('atenea.rmontanana.es', 31427)
|
||||||
|
ssh_username=rmontanana
|
||||||
|
ssh_private_key=/home/rmontanana/.ssh/id_rsa
|
||||||
|
remote_bind_address=('127.0.0.1', 3306)
|
||||||
|
enabled=1
|
6
experimentation/sqlite/.myconfig
Normal file
6
experimentation/sqlite/.myconfig
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
host=127.0.0.1
|
||||||
|
port=3306
|
||||||
|
user=stree
|
||||||
|
password=xtree
|
||||||
|
database=stree
|
||||||
|
sqlite=./data/stree.sqlite
|
5
experimentation/sqlite/.tunnel
Normal file
5
experimentation/sqlite/.tunnel
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
ssh_address_or_host=('atenea.rmontanana.es', 31427)
|
||||||
|
ssh_username=rmontanana
|
||||||
|
ssh_private_key=/home/rmontanana/.ssh/id_rsa
|
||||||
|
remote_bind_address=('127.0.0.1', 3306)
|
||||||
|
enabled=0
|
@@ -126,11 +126,13 @@ def process_dataset(dataset, verbose, model, params):
|
|||||||
record = dbh.find_best(dataset, model, "gridsearch")
|
record = dbh.find_best(dataset, model, "gridsearch")
|
||||||
hyperparameters = json.loads(record[8] if record[8] != "" else "{}")
|
hyperparameters = json.loads(record[8] if record[8] != "" else "{}")
|
||||||
hyperparameters.pop("random_state", None)
|
hyperparameters.pop("random_state", None)
|
||||||
|
print("*" * 100)
|
||||||
for random_state in random_seeds:
|
for random_state in random_seeds:
|
||||||
random.seed(random_state)
|
random.seed(random_state)
|
||||||
np.random.seed(random_state)
|
np.random.seed(random_state)
|
||||||
kfold = KFold(shuffle=True, random_state=random_state, n_splits=5)
|
kfold = KFold(shuffle=True, random_state=random_state, n_splits=5)
|
||||||
clf = get_classifier(model, random_state, hyperparameters)
|
clf = get_classifier(model, random_state, hyperparameters)
|
||||||
|
print(hyperparameters)
|
||||||
res = cross_validate(clf, X, y, cv=kfold, return_estimator=True)
|
res = cross_validate(clf, X, y, cv=kfold, return_estimator=True)
|
||||||
scores.append(res["test_score"])
|
scores.append(res["test_score"])
|
||||||
times.append(res["fit_time"])
|
times.append(res["fit_time"])
|
||||||
@@ -361,7 +363,7 @@ standardize = False
|
|||||||
excel,
|
excel,
|
||||||
discretize,
|
discretize,
|
||||||
) = parse_arguments()
|
) = parse_arguments()
|
||||||
# parameters = '{"splitter":"cfs","max_features":"auto"}'
|
# parameters = '{"kernel":"rbf","max_features":"auto"}'
|
||||||
dbh = MySQL()
|
dbh = MySQL()
|
||||||
if sql:
|
if sql:
|
||||||
sql_output = open(f"{model}.sql", "w")
|
sql_output = open(f"{model}.sql", "w")
|
||||||
|
Reference in New Issue
Block a user