mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-15 07:26:02 +00:00
Add sqlite config
Add sqlite to report_score
This commit is contained in:
BIN
data/stree.sqlite
Normal file
BIN
data/stree.sqlite
Normal file
Binary file not shown.
@@ -2,4 +2,5 @@ host=<server>
|
||||
port=tunnel
|
||||
user=stree
|
||||
password=<password>
|
||||
database=stree_experiments
|
||||
database=stree_experiments
|
||||
sqlite=None
|
||||
|
@@ -42,13 +42,21 @@ class MySQL:
|
||||
if self._tunnel:
|
||||
self._server.start()
|
||||
self._config_db["port"] = self._server.local_bind_port
|
||||
self._database = mysql.connector.connect(**self._config_db)
|
||||
if self._config_db["sqlite"] == "None":
|
||||
del self._config_db["sqlite"]
|
||||
self._config_db["buffered"] = True
|
||||
self._database = mysql.connector.connect(**self._config_db)
|
||||
else:
|
||||
self._database = sqlite3.connect(self._config_db["sqlite"])
|
||||
# return dict as a result of select
|
||||
self._database.row_factory = sqlite3.Row
|
||||
|
||||
return self._database
|
||||
|
||||
def find_best(
|
||||
self, dataset, classifier="any", experiment="any", time_info=False
|
||||
):
|
||||
cursor = self._database.cursor(buffered=True)
|
||||
cursor = self._database.cursor()
|
||||
date_from = "2021-01-20"
|
||||
# date_to = "2021-04-07"
|
||||
command = (
|
||||
|
6
experimentation/mysql/.myconfig
Normal file
6
experimentation/mysql/.myconfig
Normal file
@@ -0,0 +1,6 @@
|
||||
host=127.0.0.1
|
||||
port=3306
|
||||
user=stree
|
||||
password=xtree
|
||||
database=stree
|
||||
sqlite=None
|
5
experimentation/mysql/.tunnel
Normal file
5
experimentation/mysql/.tunnel
Normal file
@@ -0,0 +1,5 @@
|
||||
ssh_address_or_host=('atenea.rmontanana.es', 31427)
|
||||
ssh_username=rmontanana
|
||||
ssh_private_key=/home/rmontanana/.ssh/id_rsa
|
||||
remote_bind_address=('127.0.0.1', 3306)
|
||||
enabled=1
|
6
experimentation/sqlite/.myconfig
Normal file
6
experimentation/sqlite/.myconfig
Normal file
@@ -0,0 +1,6 @@
|
||||
host=127.0.0.1
|
||||
port=3306
|
||||
user=stree
|
||||
password=xtree
|
||||
database=stree
|
||||
sqlite=./data/stree.sqlite
|
5
experimentation/sqlite/.tunnel
Normal file
5
experimentation/sqlite/.tunnel
Normal file
@@ -0,0 +1,5 @@
|
||||
ssh_address_or_host=('atenea.rmontanana.es', 31427)
|
||||
ssh_username=rmontanana
|
||||
ssh_private_key=/home/rmontanana/.ssh/id_rsa
|
||||
remote_bind_address=('127.0.0.1', 3306)
|
||||
enabled=0
|
@@ -126,11 +126,13 @@ def process_dataset(dataset, verbose, model, params):
|
||||
record = dbh.find_best(dataset, model, "gridsearch")
|
||||
hyperparameters = json.loads(record[8] if record[8] != "" else "{}")
|
||||
hyperparameters.pop("random_state", None)
|
||||
print("*" * 100)
|
||||
for random_state in random_seeds:
|
||||
random.seed(random_state)
|
||||
np.random.seed(random_state)
|
||||
kfold = KFold(shuffle=True, random_state=random_state, n_splits=5)
|
||||
clf = get_classifier(model, random_state, hyperparameters)
|
||||
print(hyperparameters)
|
||||
res = cross_validate(clf, X, y, cv=kfold, return_estimator=True)
|
||||
scores.append(res["test_score"])
|
||||
times.append(res["fit_time"])
|
||||
@@ -361,7 +363,7 @@ standardize = False
|
||||
excel,
|
||||
discretize,
|
||||
) = parse_arguments()
|
||||
# parameters = '{"splitter":"cfs","max_features":"auto"}'
|
||||
# parameters = '{"kernel":"rbf","max_features":"auto"}'
|
||||
dbh = MySQL()
|
||||
if sql:
|
||||
sql_output = open(f"{model}.sql", "w")
|
||||
|
Reference in New Issue
Block a user