mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-15 15:36:01 +00:00
Add SSH tunnel to MySQL connection
Add "any" kernel to script generator
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -132,3 +132,4 @@ dmypy.json
|
|||||||
.vscode
|
.vscode
|
||||||
.pre-commit-config.yaml
|
.pre-commit-config.yaml
|
||||||
experimentation/.myconfig
|
experimentation/.myconfig
|
||||||
|
experimentation/.tunnel
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
from experimentation.Sets import Datasets
|
from experimentation.Sets import Datasets
|
||||||
from experimentation.Utils import TextColor, MySQL
|
from experimentation.Utils import TextColor
|
||||||
|
from experimentation.Database import MySQL
|
||||||
|
|
||||||
models = ["stree", "odte", "adaBoost", "bagging"]
|
models = ["stree", "odte", "adaBoost", "bagging"]
|
||||||
title = "Best model results"
|
title = "Best model results"
|
||||||
@@ -78,7 +79,8 @@ def report_footer(agg):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
database = MySQL.get_connection()
|
dbh = MySQL()
|
||||||
|
database = dbh.get_connection()
|
||||||
dt = Datasets(False, False, "tanveer")
|
dt = Datasets(False, False, "tanveer")
|
||||||
fields = ("Dataset", "Reference")
|
fields = ("Dataset", "Reference")
|
||||||
for model in models:
|
for model in models:
|
||||||
@@ -131,3 +133,4 @@ for dataset in dt:
|
|||||||
)
|
)
|
||||||
print(report_line(line))
|
print(report_line(line))
|
||||||
report_footer(agg)
|
report_footer(agg)
|
||||||
|
dbh.close()
|
||||||
|
18
cross_all.py
18
cross_all.py
@@ -32,22 +32,18 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]:
|
|||||||
default="aaai",
|
default="aaai",
|
||||||
)
|
)
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
return (
|
return (args.host, args.model, args.set_of_files)
|
||||||
args.host,
|
|
||||||
args.model,
|
|
||||||
args.set_of_files,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
(
|
(host, model, set_of_files) = parse_arguments()
|
||||||
host,
|
|
||||||
model,
|
|
||||||
set_of_files,
|
|
||||||
) = parse_arguments()
|
|
||||||
datasets = Datasets(False, False, set_of_files)
|
datasets = Datasets(False, False, set_of_files)
|
||||||
clf = None
|
clf = None
|
||||||
experiment = Experiment(
|
experiment = Experiment(
|
||||||
random_state=1, model=model, host=host, set_of_files=set_of_files
|
random_state=1,
|
||||||
|
model=model,
|
||||||
|
host=host,
|
||||||
|
set_of_files=set_of_files,
|
||||||
|
kernel="any",
|
||||||
)
|
)
|
||||||
for dataset in datasets:
|
for dataset in datasets:
|
||||||
print(f"-Cross validation on {dataset[0]}")
|
print(f"-Cross validation on {dataset[0]}")
|
||||||
|
@@ -1,5 +1,4 @@
|
|||||||
host=<server>
|
host=<server>
|
||||||
port=3306
|
|
||||||
user=stree
|
user=stree
|
||||||
password=<password>
|
password=<password>
|
||||||
database=stree_experiments
|
database=stree_experiments
|
4
experimentation/.tunnel.dist
Normal file
4
experimentation/.tunnel.dist
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
ssh_address_or_host=(<host>, <port>)
|
||||||
|
ssh_username=<user>
|
||||||
|
ssh_private_key=<path_to>/id_rsa
|
||||||
|
remote_bind_address=('127.0.0.1', 3306)
|
@@ -3,9 +3,43 @@ import sqlite3
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import List
|
from typing import List
|
||||||
|
import mysql.connector
|
||||||
|
from ast import literal_eval as make_tuple
|
||||||
|
from sshtunnel import SSHTunnelForwarder
|
||||||
from .Models import ModelBase
|
from .Models import ModelBase
|
||||||
from .Utils import TextColor, MySQL
|
from .Utils import TextColor
|
||||||
|
|
||||||
|
|
||||||
|
class MySQL:
    """Open a MySQL connection through an SSH tunnel.

    Connection settings are read from two ``key=value`` files that live in
    the same directory as this module: ``.myconfig`` (arguments passed to
    ``mysql.connector.connect``) and ``.tunnel`` (arguments passed to
    ``SSHTunnelForwarder``).

    Typical use::

        dbh = MySQL()
        database = dbh.get_connection()
        ...
        dbh.close()
    """

    def __init__(self):
        # SSHTunnelForwarder instance while a tunnel is open, otherwise None
        self.server = None

    @staticmethod
    def _read_config(path):
        """Parse a ``key=value`` file into a dictionary.

        Blank lines are ignored and only the first ``=`` of a line separates
        key from value, so values (e.g. passwords) may contain ``=``.

        :param path: file to read
        :type path: str
        :raises ValueError: if a non blank line has no ``=``
        :return: the key/value pairs found in the file
        :rtype: dict
        """
        config = {}
        with open(path) as f:
            for line in f.read().splitlines():
                if not line.strip():
                    continue
                key, separator, value = line.partition("=")
                if not separator:
                    raise ValueError(
                        f"Malformed line in {path}: {line!r} "
                        "(expected key=value)"
                    )
                config[key] = value
        return config

    def get_connection(self):
        """Start the SSH tunnel and return a MySQL connection through it.

        Any tunnel previously opened by this object is closed first. If the
        database connection can not be established the tunnel is closed
        again before the exception is propagated.

        :return: the object returned by ``mysql.connector.connect``
        """
        dir_path = os.path.dirname(os.path.realpath(__file__))
        config_db = self._read_config(os.path.join(dir_path, ".myconfig"))
        config_tunnel = self._read_config(os.path.join(dir_path, ".tunnel"))
        # these two settings are stored as python tuple literals in the file
        for key in ("remote_bind_address", "ssh_address_or_host"):
            config_tunnel[key] = make_tuple(config_tunnel[key])
        # don't leak a tunnel left open by a previous call
        self.close()
        self.server = SSHTunnelForwarder(**config_tunnel)
        self.server.daemon_forward_servers = True
        self.server.start()
        # connect to the local end of the tunnel
        config_db["port"] = self.server.local_bind_port
        try:
            return mysql.connector.connect(**config_db)
        except Exception:
            self.close()
            raise

    def close(self):
        """Close the SSH tunnel; do nothing if no tunnel is open."""
        if self.server is not None:
            self.server.close()
            self.server = None
|
||||||
|
|
||||||
|
|
||||||
class BD(ABC):
|
class BD(ABC):
|
||||||
@@ -108,7 +142,8 @@ class BD(ABC):
|
|||||||
:param record: data to insert in database
|
:param record: data to insert in database
|
||||||
:type record: dict
|
:type record: dict
|
||||||
"""
|
"""
|
||||||
database = MySQL.get_connection()
|
dbh = MySQL()
|
||||||
|
database = dbh.get_connection()
|
||||||
command_insert = (
|
command_insert = (
|
||||||
"replace into results (date, time, type, accuracy, "
|
"replace into results (date, time, type, accuracy, "
|
||||||
"dataset, classifier, norm, stand, parameters) values (%s, %s, "
|
"dataset, classifier, norm, stand, parameters) values (%s, %s, "
|
||||||
@@ -131,6 +166,7 @@ class BD(ABC):
|
|||||||
cursor = database.cursor()
|
cursor = database.cursor()
|
||||||
cursor.execute(command_insert, values)
|
cursor.execute(command_insert, values)
|
||||||
database.commit()
|
database.commit()
|
||||||
|
dbh.close()
|
||||||
|
|
||||||
def execute(self, command: str) -> None:
|
def execute(self, command: str) -> None:
|
||||||
c = self._con.cursor()
|
c = self._con.cursor()
|
||||||
|
@@ -1,7 +1,3 @@
|
|||||||
import os
|
|
||||||
import mysql.connector
|
|
||||||
|
|
||||||
|
|
||||||
class TextColor:
|
class TextColor:
|
||||||
BLUE = "\033[94m"
|
BLUE = "\033[94m"
|
||||||
CYAN = "\033[96m"
|
CYAN = "\033[96m"
|
||||||
@@ -18,15 +14,3 @@ class TextColor:
|
|||||||
ENDC = "\033[0m"
|
ENDC = "\033[0m"
|
||||||
BOLD = "\033[1m"
|
BOLD = "\033[1m"
|
||||||
UNDERLINE = "\033[4m"
|
UNDERLINE = "\033[4m"
|
||||||
|
|
||||||
|
|
||||||
class MySQL:
|
|
||||||
@staticmethod
|
|
||||||
def get_connection():
|
|
||||||
config = dict()
|
|
||||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
|
||||||
with open(os.path.join(dir_path, ".myconfig")) as f:
|
|
||||||
for line in f.read().splitlines():
|
|
||||||
key, value = line.split("=")
|
|
||||||
config[key] = value
|
|
||||||
return mysql.connector.connect(**config)
|
|
||||||
|
@@ -30,10 +30,10 @@
|
|||||||
"import json\n",
|
"import json\n",
|
||||||
"import sqlite3\n",
|
"import sqlite3\n",
|
||||||
"import mysql.connector\n",
|
"import mysql.connector\n",
|
||||||
"from experimentation.Utils import MySQL\n",
|
"from experimentation.Database import MySQL\n",
|
||||||
"from experimentation.Sets import Datasets\n",
|
"from experimentation.Sets import Datasets\n",
|
||||||
"\n",
|
"dbh = MySQL()\n",
|
||||||
"database = MySQL.get_connection()"
|
"database = dbh.get_connection()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -375,6 +375,15 @@
|
|||||||
"find_values('max_features', 'linear')"
|
"find_values('max_features', 'linear')"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dbh.close()"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
@@ -1,7 +1,8 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from experimentation.Sets import Datasets
|
from experimentation.Sets import Datasets
|
||||||
from experimentation.Utils import TextColor, MySQL
|
from experimentation.Utils import TextColor
|
||||||
|
from experimentation.Database import MySQL
|
||||||
|
|
||||||
models = ["stree", "adaBoost", "bagging", "odte"]
|
models = ["stree", "adaBoost", "bagging", "odte"]
|
||||||
|
|
||||||
@@ -114,7 +115,8 @@ def report_footer(agg):
|
|||||||
classifier,
|
classifier,
|
||||||
exclude_parameters,
|
exclude_parameters,
|
||||||
) = parse_arguments()
|
) = parse_arguments()
|
||||||
database = MySQL.get_connection()
|
dbh = MySQL()
|
||||||
|
database = dbh.get_connection()
|
||||||
dt = Datasets(False, False, "tanveer")
|
dt = Datasets(False, False, "tanveer")
|
||||||
title = "Best Hyperparameters found for datasets"
|
title = "Best Hyperparameters found for datasets"
|
||||||
lengths = (10, 8, 10, 10, 30, 3, 3, 9, 11)
|
lengths = (10, 8, 10, 10, 30, 3, 3, 9, 11)
|
||||||
@@ -151,3 +153,4 @@ for dataset in dt:
|
|||||||
)
|
)
|
||||||
print(color + report_line(record, agg))
|
print(color + report_line(record, agg))
|
||||||
report_footer(agg)
|
report_footer(agg)
|
||||||
|
dbh.close()
|
||||||
|
@@ -2,7 +2,7 @@
|
|||||||
for i in gridsearch gridbest cross; do
|
for i in gridsearch gridbest cross; do
|
||||||
echo "*** Building $i experiments"
|
echo "*** Building $i experiments"
|
||||||
for j in stree odte bagging adaBoost; do
|
for j in stree odte bagging adaBoost; do
|
||||||
for k in linear poly rbf; do
|
for k in linear poly rbf any; do
|
||||||
./genjobs.sh $i $j $k
|
./genjobs.sh $i $j $k
|
||||||
done
|
done
|
||||||
done
|
done
|
||||||
|
2
setup.py
2
setup.py
@@ -36,6 +36,8 @@ setuptools.setup(
|
|||||||
"ipympl",
|
"ipympl",
|
||||||
"stree",
|
"stree",
|
||||||
"odte",
|
"odte",
|
||||||
|
"sshtunnel",
|
||||||
|
"mysql-connector-python",
|
||||||
],
|
],
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
)
|
)
|
||||||
|
@@ -21,11 +21,12 @@
|
|||||||
"dataset_name = \"cylinder-bands\"\n",
|
"dataset_name = \"cylinder-bands\"\n",
|
||||||
"dataset_name = \"pima\"\n",
|
"dataset_name = \"pima\"\n",
|
||||||
"dataset_name = \"conn-bench-sonar-mines-rocks\"\n",
|
"dataset_name = \"conn-bench-sonar-mines-rocks\"\n",
|
||||||
|
"dataset_name = \"libras\"\n",
|
||||||
"parameters = {\"C\": .15, \"degree\": 6, \"gamma\": .7, \"kernel\": \"poly\", \"max_features\": None, \"max_iter\": 100000.0, \"random_state\": 0}\n",
|
"parameters = {\"C\": .15, \"degree\": 6, \"gamma\": .7, \"kernel\": \"poly\", \"max_features\": None, \"max_iter\": 100000.0, \"random_state\": 0}\n",
|
||||||
"parameters = {'C': 7, 'degree': 7, 'gamma': 0.1, 'kernel': 'poly', 'max_features': 'auto', 'max_iter': 10000.0, 'random_state': 1, 'split_criteria': 'impurity'}\n",
|
"#parameters = {'C': .17, 'degree': 5, 'gamma': 0.1, 'kernel': 'poly', 'max_features': 'auto', 'max_iter': 10000.0, 'random_state': 1, 'split_criteria': 'impurity'}\n",
|
||||||
"parameters = {\"C\": 0.2, \"max_iter\": 10000.0, \"random_state\": 1}\n",
|
"#parameters = {\"C\": 0.2, \"max_iter\": 10000.0, \"random_state\": 1}\n",
|
||||||
"parameters = {\"C\": 0.55, \"gamma\": 0.1, \"kernel\": \"rbf\", \"max_iter\": 10000.0, \"random_state\": 1}\n",
|
"#parameters = {\"C\": 0.55, \"gamma\": 0.1, \"kernel\": \"rbf\", \"max_iter\": 10000.0, \"random_state\": 1}\n",
|
||||||
"parameters = {\"C\": 55, \"max_iter\": 10000.0, \"random_state\": 1}"
|
"#parameters = {\"C\": 55, \"max_iter\": 10000.0, \"random_state\": 1}"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -40,61 +41,102 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/Users/rmontanana/.virtualenvs/general/lib/python3.8/site-packages/sklearn/svm/_base.py:976: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
|
|
||||||
" warnings.warn(\"Liblinear failed to converge, increase \"\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"clf = Stree(**parameters)\n",
|
"clf = Stree(**parameters)\n",
|
||||||
"results = cross_validate(clf, X, y, n_jobs=1)"
|
"results = cross_validate(clf, X, y, n_jobs=1, return_estimator=True)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 13,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"name": "stdout",
|
||||||
"text/plain": [
|
"output_type": "stream",
|
||||||
"{'fit_time': array([0.0078361 , 0.03171897, 0.01422501, 0.06850815, 0.05387974]),\n",
|
"text": [
|
||||||
" 'score_time': array([0.0005939 , 0.00044203, 0.00043583, 0.00050902, 0.00044012]),\n",
|
"root feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9999 counts=(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), array([19, 19, 19, 19, 20, 19, 19, 19, 19, 20, 19, 19, 19, 19, 20]))\n",
|
||||||
" 'test_score': array([0.4047619 , 0.61904762, 0.66666667, 0.92682927, 0.58536585])}"
|
"root - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9613 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), array([19, 19, 18, 11, 10, 18, 4, 19, 20, 8, 19, 3, 18, 11]))\n",
|
||||||
|
"root - Down - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.8726 counts=(array([ 5, 13, 14]), array([ 3, 4, 10]))\n",
|
||||||
|
"root - Down - Down - Down, <pure> - Leaf class=5 belief= 1.000000 impurity=0.0000 counts=(array([5]), array([2]))\n",
|
||||||
|
"root - Down - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.7312 counts=(array([ 5, 13, 14]), array([ 1, 4, 10]))\n",
|
||||||
|
"root - Down - Down - Up - Down, <pure> - Leaf class=5 belief= 1.000000 impurity=0.0000 counts=(array([5]), array([1]))\n",
|
||||||
|
"root - Down - Down - Up - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.8631 counts=(array([13, 14]), array([ 4, 10]))\n",
|
||||||
|
"root - Down - Down - Up - Up - Down, <pure> - Leaf class=13 belief= 1.000000 impurity=0.0000 counts=(array([13]), array([4]))\n",
|
||||||
|
"root - Down - Down - Up - Up - Up, <pure> - Leaf class=14 belief= 1.000000 impurity=0.0000 counts=(array([14]), array([10]))\n",
|
||||||
|
"root - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9359 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), array([19, 19, 18, 11, 7, 18, 4, 19, 20, 8, 19, 3, 14, 1]))\n",
|
||||||
|
"root - Down - Up - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9548 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]), array([19, 19, 18, 11, 7, 18, 4, 19, 20, 8, 19, 3, 14]))\n",
|
||||||
|
"root - Down - Up - Down - Down, <pure> - Leaf class=12 belief= 1.000000 impurity=0.0000 counts=(array([12]), array([3]))\n",
|
||||||
|
"root - Down - Up - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9675 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 13]), array([19, 19, 18, 11, 7, 18, 4, 19, 20, 8, 19, 14]))\n",
|
||||||
|
"root - Down - Up - Down - Up - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.7248 counts=(array([5, 6, 7]), array([5, 1, 1]))\n",
|
||||||
|
"root - Down - Up - Down - Up - Down - Down, <pure> - Leaf class=6 belief= 1.000000 impurity=0.0000 counts=(array([6]), array([1]))\n",
|
||||||
|
"root - Down - Up - Down - Up - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.6500 counts=(array([5, 7]), array([5, 1]))\n",
|
||||||
|
"root - Down - Up - Down - Up - Down - Up - Down, <pure> - Leaf class=5 belief= 1.000000 impurity=0.0000 counts=(array([5]), array([5]))\n",
|
||||||
|
"root - Down - Up - Down - Up - Down - Up - Up, <pure> - Leaf class=7 belief= 1.000000 impurity=0.0000 counts=(array([7]), array([1]))\n",
|
||||||
|
"root - Down - Up - Down - Up - Up, <cgaf> - Leaf class=9 belief= 0.118343 impurity=0.9488 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 13]), array([19, 19, 18, 11, 2, 17, 3, 19, 20, 8, 19, 14]))\n",
|
||||||
|
"root - Down - Up - Up, <pure> - Leaf class=14 belief= 1.000000 impurity=0.0000 counts=(array([14]), array([1]))\n",
|
||||||
|
"root - Up, <cgaf> - Leaf class=3 belief= 0.208791 impurity=0.8775 counts=(array([ 2, 3, 4, 5, 6, 7, 10, 12, 13, 14]), array([ 1, 19, 9, 9, 1, 15, 11, 16, 1, 9]))\n",
|
||||||
|
"\n"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
"execution_count": 5,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"results"
|
"print(results['estimator'][0])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 16,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from sklearn.tree import DecisionTreeClassifier\n",
|
||||||
|
"dt = DecisionTreeClassifier(random_state=1)\n",
|
||||||
|
"resdt = cross_validate(dt, X, y, n_jobs=1, return_estimator=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 18,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"name": "stdout",
|
||||||
"text/plain": [
|
"output_type": "stream",
|
||||||
"0.640534262485482"
|
"text": [
|
||||||
|
"0.5416666666666667\n"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
"source": [
|
||||||
|
"print(resdt['test_score'].mean())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[0.09722222 0.25 0.54166667 0.31944444 0.13888889]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(results['test_score'])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"results['test_score'].mean()"
|
"results['test_score'].mean()"
|
||||||
]
|
]
|
||||||
@@ -371,9 +413,9 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"clf = Stree()\n",
|
"#clf = Stree()\n",
|
||||||
"model = GridSearchCV(clf, n_jobs=1, verbose=10, param_grid=param_grid, cv=2)\n",
|
"#model = GridSearchCV(clf, n_jobs=1, verbose=10, param_grid=param_grid, cv=2)\n",
|
||||||
"model.fit(X, y)"
|
"#model.fit(X, y)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -677,6 +719,68 @@
|
|||||||
"b._base_model.get_model_name()"
|
"b._base_model.get_model_name()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def _get_subspaces_set(\n",
|
||||||
|
" self, dataset: np.array, labels: np.array, max_features: int\n",
|
||||||
|
" ) -> np.array:\n",
|
||||||
|
" features = range(dataset.shape[1])\n",
|
||||||
|
" features_sets = list(combinations(features, max_features))\n",
|
||||||
|
" if len(features_sets) > 1:\n",
|
||||||
|
" if self._splitter_type == \"random\":\n",
|
||||||
|
" index = random.randint(0, len(features_sets) - 1)\n",
|
||||||
|
" return features_sets[index]\n",
|
||||||
|
" else:\n",
|
||||||
|
" # get only 3 sets at most\n",
|
||||||
|
" if len(features_sets) > 3:\n",
|
||||||
|
" features_sets = random.sample(features_sets, 3)\n",
|
||||||
|
" return self._select_best_set(dataset, labels, features_sets)\n",
|
||||||
|
" else:\n",
|
||||||
|
" return features_sets[0]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import random\n",
|
||||||
|
"def generate_subspaces(features: int, max_features: int) -> list:\n",
|
||||||
|
" combs = set()\n",
|
||||||
|
" # take 3 combinations at most\n",
|
||||||
|
" combinations = 1 if max_features == features else 3\n",
|
||||||
|
" while len(combs) < combinations:\n",
|
||||||
|
" combs.add(tuple(sorted(random.sample(range(features), max_features))))\n",
|
||||||
|
" return list(combs)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[(2, 17, 51, 66, 79, 88, 98, 105, 145, 150),\n",
|
||||||
|
" (1, 11, 27, 48, 86, 117, 124, 157, 180, 194),\n",
|
||||||
|
" (11, 31, 42, 43, 61, 79, 138, 149, 150, 189)]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"generate_subspaces(200, 10)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
Reference in New Issue
Block a user