mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-15 07:26:02 +00:00
Add tunnel to mysql
add any kernel to script generator
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -132,3 +132,4 @@ dmypy.json
|
||||
.vscode
|
||||
.pre-commit-config.yaml
|
||||
experimentation/.myconfig
|
||||
experimentation/.tunnel
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from experimentation.Sets import Datasets
|
||||
from experimentation.Utils import TextColor, MySQL
|
||||
from experimentation.Utils import TextColor
|
||||
from experimentation.Database import MySQL
|
||||
|
||||
models = ["stree", "odte", "adaBoost", "bagging"]
|
||||
title = "Best model results"
|
||||
@@ -78,7 +79,8 @@ def report_footer(agg):
|
||||
)
|
||||
|
||||
|
||||
database = MySQL.get_connection()
|
||||
dbh = MySQL()
|
||||
database = dbh.get_connection()
|
||||
dt = Datasets(False, False, "tanveer")
|
||||
fields = ("Dataset", "Reference")
|
||||
for model in models:
|
||||
@@ -131,3 +133,4 @@ for dataset in dt:
|
||||
)
|
||||
print(report_line(line))
|
||||
report_footer(agg)
|
||||
dbh.close()
|
||||
|
18
cross_all.py
18
cross_all.py
@@ -32,22 +32,18 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]:
|
||||
default="aaai",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
return (
|
||||
args.host,
|
||||
args.model,
|
||||
args.set_of_files,
|
||||
)
|
||||
return (args.host, args.model, args.set_of_files)
|
||||
|
||||
|
||||
(
|
||||
host,
|
||||
model,
|
||||
set_of_files,
|
||||
) = parse_arguments()
|
||||
(host, model, set_of_files) = parse_arguments()
|
||||
datasets = Datasets(False, False, set_of_files)
|
||||
clf = None
|
||||
experiment = Experiment(
|
||||
random_state=1, model=model, host=host, set_of_files=set_of_files
|
||||
random_state=1,
|
||||
model=model,
|
||||
host=host,
|
||||
set_of_files=set_of_files,
|
||||
kernel="any",
|
||||
)
|
||||
for dataset in datasets:
|
||||
print(f"-Cross validation on {dataset[0]}")
|
||||
|
@@ -1,5 +1,4 @@
|
||||
host=<server>
|
||||
port=3306
|
||||
user=stree
|
||||
password=<password>
|
||||
database=stree_experiments
|
4
experimentation/.tunnel.dist
Normal file
4
experimentation/.tunnel.dist
Normal file
@@ -0,0 +1,4 @@
|
||||
ssh_address_or_host=(<host>, <port>)
|
||||
ssh_username=<user>
|
||||
ssh_private_key=<path_to>/id_rsa
|
||||
remote_bind_address=('127.0.0.1', 3306)
|
@@ -3,9 +3,43 @@ import sqlite3
|
||||
from datetime import datetime
|
||||
from abc import ABC
|
||||
from typing import List
|
||||
|
||||
import mysql.connector
|
||||
from ast import literal_eval as make_tuple
|
||||
from sshtunnel import SSHTunnelForwarder
|
||||
from .Models import ModelBase
|
||||
from .Utils import TextColor, MySQL
|
||||
from .Utils import TextColor
|
||||
|
||||
|
||||
class MySQL:
|
||||
def __init__(self):
|
||||
self.server = None
|
||||
|
||||
def get_connection(self):
|
||||
config_db = dict()
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
with open(os.path.join(dir_path, ".myconfig")) as f:
|
||||
for line in f.read().splitlines():
|
||||
key, value = line.split("=")
|
||||
config_db[key] = value
|
||||
config_tunnel = dict()
|
||||
with open(os.path.join(dir_path, ".tunnel")) as f:
|
||||
for line in f.read().splitlines():
|
||||
key, value = line.split("=")
|
||||
config_tunnel[key] = value
|
||||
config_tunnel["remote_bind_address"] = make_tuple(
|
||||
config_tunnel["remote_bind_address"]
|
||||
)
|
||||
config_tunnel["ssh_address_or_host"] = make_tuple(
|
||||
config_tunnel["ssh_address_or_host"]
|
||||
)
|
||||
self.server = SSHTunnelForwarder(**config_tunnel)
|
||||
self.server.daemon_forward_servers = True
|
||||
self.server.start()
|
||||
config_db["port"] = self.server.local_bind_port
|
||||
return mysql.connector.connect(**config_db)
|
||||
|
||||
def close(self):
|
||||
self.server.close()
|
||||
|
||||
|
||||
class BD(ABC):
|
||||
@@ -108,7 +142,8 @@ class BD(ABC):
|
||||
:param record: data to insert in database
|
||||
:type record: dict
|
||||
"""
|
||||
database = MySQL.get_connection()
|
||||
dbh = MySQL()
|
||||
database = dbh.get_connection()
|
||||
command_insert = (
|
||||
"replace into results (date, time, type, accuracy, "
|
||||
"dataset, classifier, norm, stand, parameters) values (%s, %s, "
|
||||
@@ -131,6 +166,7 @@ class BD(ABC):
|
||||
cursor = database.cursor()
|
||||
cursor.execute(command_insert, values)
|
||||
database.commit()
|
||||
dbh.close()
|
||||
|
||||
def execute(self, command: str) -> None:
|
||||
c = self._con.cursor()
|
||||
|
@@ -1,7 +1,3 @@
|
||||
import os
|
||||
import mysql.connector
|
||||
|
||||
|
||||
class TextColor:
|
||||
BLUE = "\033[94m"
|
||||
CYAN = "\033[96m"
|
||||
@@ -18,15 +14,3 @@ class TextColor:
|
||||
ENDC = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
UNDERLINE = "\033[4m"
|
||||
|
||||
|
||||
class MySQL:
|
||||
@staticmethod
|
||||
def get_connection():
|
||||
config = dict()
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
with open(os.path.join(dir_path, ".myconfig")) as f:
|
||||
for line in f.read().splitlines():
|
||||
key, value = line.split("=")
|
||||
config[key] = value
|
||||
return mysql.connector.connect(**config)
|
||||
|
@@ -30,10 +30,10 @@
|
||||
"import json\n",
|
||||
"import sqlite3\n",
|
||||
"import mysql.connector\n",
|
||||
"from experimentation.Utils import MySQL\n",
|
||||
"from experimentation.Database import MySQL\n",
|
||||
"from experimentation.Sets import Datasets\n",
|
||||
"\n",
|
||||
"database = MySQL.get_connection()"
|
||||
"dbh = MySQL()\n",
|
||||
"database = dbh.get_connection()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -375,6 +375,15 @@
|
||||
"find_values('max_features', 'linear')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dbh.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
@@ -1,7 +1,8 @@
|
||||
import argparse
|
||||
from typing import Tuple
|
||||
from experimentation.Sets import Datasets
|
||||
from experimentation.Utils import TextColor, MySQL
|
||||
from experimentation.Utils import TextColor
|
||||
from experimentation.Database import MySQL
|
||||
|
||||
models = ["stree", "adaBoost", "bagging", "odte"]
|
||||
|
||||
@@ -114,7 +115,8 @@ def report_footer(agg):
|
||||
classifier,
|
||||
exclude_parameters,
|
||||
) = parse_arguments()
|
||||
database = MySQL.get_connection()
|
||||
dbh = MySQL()
|
||||
database = dbh.get_connection()
|
||||
dt = Datasets(False, False, "tanveer")
|
||||
title = "Best Hyperparameters found for datasets"
|
||||
lengths = (10, 8, 10, 10, 30, 3, 3, 9, 11)
|
||||
@@ -151,3 +153,4 @@ for dataset in dt:
|
||||
)
|
||||
print(color + report_line(record, agg))
|
||||
report_footer(agg)
|
||||
dbh.close()
|
||||
|
@@ -2,7 +2,7 @@
|
||||
for i in gridsearch gridbest cross; do
|
||||
echo "*** Building $i experiments"
|
||||
for j in stree odte bagging adaBoost; do
|
||||
for k in linear poly rbf; do
|
||||
for k in linear poly rbf any; do
|
||||
./genjobs.sh $i $j $k
|
||||
done
|
||||
done
|
||||
|
2
setup.py
2
setup.py
@@ -36,6 +36,8 @@ setuptools.setup(
|
||||
"ipympl",
|
||||
"stree",
|
||||
"odte",
|
||||
"sshtunnel",
|
||||
"mysql-connector-python",
|
||||
],
|
||||
zip_safe=False,
|
||||
)
|
||||
|
@@ -21,11 +21,12 @@
|
||||
"dataset_name = \"cylinder-bands\"\n",
|
||||
"dataset_name = \"pima\"\n",
|
||||
"dataset_name = \"conn-bench-sonar-mines-rocks\"\n",
|
||||
"dataset_name = \"libras\"\n",
|
||||
"parameters = {\"C\": .15, \"degree\": 6, \"gamma\": .7, \"kernel\": \"poly\", \"max_features\": None, \"max_iter\": 100000.0, \"random_state\": 0}\n",
|
||||
"parameters = {'C': 7, 'degree': 7, 'gamma': 0.1, 'kernel': 'poly', 'max_features': 'auto', 'max_iter': 10000.0, 'random_state': 1, 'split_criteria': 'impurity'}\n",
|
||||
"parameters = {\"C\": 0.2, \"max_iter\": 10000.0, \"random_state\": 1}\n",
|
||||
"parameters = {\"C\": 0.55, \"gamma\": 0.1, \"kernel\": \"rbf\", \"max_iter\": 10000.0, \"random_state\": 1}\n",
|
||||
"parameters = {\"C\": 55, \"max_iter\": 10000.0, \"random_state\": 1}"
|
||||
"#parameters = {'C': .17, 'degree': 5, 'gamma': 0.1, 'kernel': 'poly', 'max_features': 'auto', 'max_iter': 10000.0, 'random_state': 1, 'split_criteria': 'impurity'}\n",
|
||||
"#parameters = {\"C\": 0.2, \"max_iter\": 10000.0, \"random_state\": 1}\n",
|
||||
"#parameters = {\"C\": 0.55, \"gamma\": 0.1, \"kernel\": \"rbf\", \"max_iter\": 10000.0, \"random_state\": 1}\n",
|
||||
"#parameters = {\"C\": 55, \"max_iter\": 10000.0, \"random_state\": 1}"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -40,61 +41,102 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"clf = Stree(**parameters)\n",
|
||||
"results = cross_validate(clf, X, y, n_jobs=1, return_estimator=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/rmontanana/.virtualenvs/general/lib/python3.8/site-packages/sklearn/svm/_base.py:976: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
|
||||
" warnings.warn(\"Liblinear failed to converge, increase \"\n"
|
||||
"root feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9999 counts=(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), array([19, 19, 19, 19, 20, 19, 19, 19, 19, 20, 19, 19, 19, 19, 20]))\n",
|
||||
"root - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9613 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), array([19, 19, 18, 11, 10, 18, 4, 19, 20, 8, 19, 3, 18, 11]))\n",
|
||||
"root - Down - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.8726 counts=(array([ 5, 13, 14]), array([ 3, 4, 10]))\n",
|
||||
"root - Down - Down - Down, <pure> - Leaf class=5 belief= 1.000000 impurity=0.0000 counts=(array([5]), array([2]))\n",
|
||||
"root - Down - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.7312 counts=(array([ 5, 13, 14]), array([ 1, 4, 10]))\n",
|
||||
"root - Down - Down - Up - Down, <pure> - Leaf class=5 belief= 1.000000 impurity=0.0000 counts=(array([5]), array([1]))\n",
|
||||
"root - Down - Down - Up - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.8631 counts=(array([13, 14]), array([ 4, 10]))\n",
|
||||
"root - Down - Down - Up - Up - Down, <pure> - Leaf class=13 belief= 1.000000 impurity=0.0000 counts=(array([13]), array([4]))\n",
|
||||
"root - Down - Down - Up - Up - Up, <pure> - Leaf class=14 belief= 1.000000 impurity=0.0000 counts=(array([14]), array([10]))\n",
|
||||
"root - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9359 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), array([19, 19, 18, 11, 7, 18, 4, 19, 20, 8, 19, 3, 14, 1]))\n",
|
||||
"root - Down - Up - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9548 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]), array([19, 19, 18, 11, 7, 18, 4, 19, 20, 8, 19, 3, 14]))\n",
|
||||
"root - Down - Up - Down - Down, <pure> - Leaf class=12 belief= 1.000000 impurity=0.0000 counts=(array([12]), array([3]))\n",
|
||||
"root - Down - Up - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9675 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 13]), array([19, 19, 18, 11, 7, 18, 4, 19, 20, 8, 19, 14]))\n",
|
||||
"root - Down - Up - Down - Up - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.7248 counts=(array([5, 6, 7]), array([5, 1, 1]))\n",
|
||||
"root - Down - Up - Down - Up - Down - Down, <pure> - Leaf class=6 belief= 1.000000 impurity=0.0000 counts=(array([6]), array([1]))\n",
|
||||
"root - Down - Up - Down - Up - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.6500 counts=(array([5, 7]), array([5, 1]))\n",
|
||||
"root - Down - Up - Down - Up - Down - Up - Down, <pure> - Leaf class=5 belief= 1.000000 impurity=0.0000 counts=(array([5]), array([5]))\n",
|
||||
"root - Down - Up - Down - Up - Down - Up - Up, <pure> - Leaf class=7 belief= 1.000000 impurity=0.0000 counts=(array([7]), array([1]))\n",
|
||||
"root - Down - Up - Down - Up - Up, <cgaf> - Leaf class=9 belief= 0.118343 impurity=0.9488 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 13]), array([19, 19, 18, 11, 2, 17, 3, 19, 20, 8, 19, 14]))\n",
|
||||
"root - Down - Up - Up, <pure> - Leaf class=14 belief= 1.000000 impurity=0.0000 counts=(array([14]), array([1]))\n",
|
||||
"root - Up, <cgaf> - Leaf class=3 belief= 0.208791 impurity=0.8775 counts=(array([ 2, 3, 4, 5, 6, 7, 10, 12, 13, 14]), array([ 1, 19, 9, 9, 1, 15, 11, 16, 1, 9]))\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"clf = Stree(**parameters)\n",
|
||||
"results = cross_validate(clf, X, y, n_jobs=1)"
|
||||
"print(results['estimator'][0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.tree import DecisionTreeClassifier\n",
|
||||
"dt = DecisionTreeClassifier(random_state=1)\n",
|
||||
"resdt = cross_validate(dt, X, y, n_jobs=1, return_estimator=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'fit_time': array([0.0078361 , 0.03171897, 0.01422501, 0.06850815, 0.05387974]),\n",
|
||||
" 'score_time': array([0.0005939 , 0.00044203, 0.00043583, 0.00050902, 0.00044012]),\n",
|
||||
" 'test_score': array([0.4047619 , 0.61904762, 0.66666667, 0.92682927, 0.58536585])}"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0.5416666666666667\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"results"
|
||||
"print(resdt['test_score'].mean()) ki"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.640534262485482"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[0.09722222 0.25 0.54166667 0.31944444 0.13888889]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(results['test_score'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"results['test_score'].mean()"
|
||||
]
|
||||
@@ -371,9 +413,9 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"clf = Stree()\n",
|
||||
"model = GridSearchCV(clf, n_jobs=1, verbose=10, param_grid=param_grid, cv=2)\n",
|
||||
"model.fit(X, y)"
|
||||
"#clf = Stree()\n",
|
||||
"#model = GridSearchCV(clf, n_jobs=1, verbose=10, param_grid=param_grid, cv=2)\n",
|
||||
"#model.fit(X, y)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -677,6 +719,68 @@
|
||||
"b._base_model.get_model_name()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _get_subspaces_set(\n",
|
||||
" self, dataset: np.array, labels: np.array, max_features: int\n",
|
||||
" ) -> np.array:\n",
|
||||
" features = range(dataset.shape[1])\n",
|
||||
" features_sets = list(combinations(features, max_features))\n",
|
||||
" if len(features_sets) > 1:\n",
|
||||
" if self._splitter_type == \"random\":\n",
|
||||
" index = random.randint(0, len(features_sets) - 1)\n",
|
||||
" return features_sets[index]\n",
|
||||
" else:\n",
|
||||
" # get only 3 sets at most\n",
|
||||
" if len(features_sets) > 3:\n",
|
||||
" features_sets = random.sample(features_sets, 3)\n",
|
||||
" return self._select_best_set(dataset, labels, features_sets)\n",
|
||||
" else:\n",
|
||||
" return features_sets[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import random\n",
|
||||
"def generate_subspaces(features: int, max_features: int) -> list:\n",
|
||||
" combs = set()\n",
|
||||
" # take 3 combinations at most\n",
|
||||
" combinations = 1 if max_features == features else 3\n",
|
||||
" while len(combs) < combinations:\n",
|
||||
" combs.add(tuple(sorted(random.sample(range(features), max_features))))\n",
|
||||
" return list(combs)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[(2, 17, 51, 66, 79, 88, 98, 105, 145, 150),\n",
|
||||
" (1, 11, 27, 48, 86, 117, 124, 157, 180, 194),\n",
|
||||
" (11, 31, 42, 43, 61, 79, 138, 149, 150, 189)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"generate_subspaces(200, 10)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -706,4 +810,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user