Add tunnel to mysql

add any kernel to script generator
This commit is contained in:
2020-12-10 17:23:07 +01:00
parent 74841d5047
commit f3afcd00ba
12 changed files with 215 additions and 74 deletions

1
.gitignore vendored
View File

@@ -132,3 +132,4 @@ dmypy.json
.vscode .vscode
.pre-commit-config.yaml .pre-commit-config.yaml
experimentation/.myconfig experimentation/.myconfig
experimentation/.tunnel

View File

@@ -1,5 +1,6 @@
from experimentation.Sets import Datasets from experimentation.Sets import Datasets
from experimentation.Utils import TextColor, MySQL from experimentation.Utils import TextColor
from experimentation.Database import MySQL
models = ["stree", "odte", "adaBoost", "bagging"] models = ["stree", "odte", "adaBoost", "bagging"]
title = "Best model results" title = "Best model results"
@@ -78,7 +79,8 @@ def report_footer(agg):
) )
database = MySQL.get_connection() dbh = MySQL()
database = dbh.get_connection()
dt = Datasets(False, False, "tanveer") dt = Datasets(False, False, "tanveer")
fields = ("Dataset", "Reference") fields = ("Dataset", "Reference")
for model in models: for model in models:
@@ -131,3 +133,4 @@ for dataset in dt:
) )
print(report_line(line)) print(report_line(line))
report_footer(agg) report_footer(agg)
dbh.close()

View File

@@ -32,22 +32,18 @@ def parse_arguments() -> Tuple[str, str, str, bool, bool]:
default="aaai", default="aaai",
) )
args = ap.parse_args() args = ap.parse_args()
return ( return (args.host, args.model, args.set_of_files)
args.host,
args.model,
args.set_of_files,
)
( (host, model, set_of_files) = parse_arguments()
host,
model,
set_of_files,
) = parse_arguments()
datasets = Datasets(False, False, set_of_files) datasets = Datasets(False, False, set_of_files)
clf = None clf = None
experiment = Experiment( experiment = Experiment(
random_state=1, model=model, host=host, set_of_files=set_of_files random_state=1,
model=model,
host=host,
set_of_files=set_of_files,
kernel="any",
) )
for dataset in datasets: for dataset in datasets:
print(f"-Cross validation on {dataset[0]}") print(f"-Cross validation on {dataset[0]}")

View File

@@ -1,5 +1,4 @@
host=<server> host=<server>
port=3306
user=stree user=stree
password=<password> password=<password>
database=stree_experiments database=stree_experiments

View File

@@ -0,0 +1,4 @@
ssh_address_or_host=(<host>, <port>)
ssh_username=<user>
ssh_private_key=<path_to>/id_rsa
remote_bind_address=('127.0.0.1', 3306)

View File

@@ -3,9 +3,43 @@ import sqlite3
from datetime import datetime from datetime import datetime
from abc import ABC from abc import ABC
from typing import List from typing import List
import mysql.connector
from ast import literal_eval as make_tuple
from sshtunnel import SSHTunnelForwarder
from .Models import ModelBase from .Models import ModelBase
from .Utils import TextColor, MySQL from .Utils import TextColor
class MySQL:
def __init__(self):
self.server = None
def get_connection(self):
config_db = dict()
dir_path = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(dir_path, ".myconfig")) as f:
for line in f.read().splitlines():
key, value = line.split("=")
config_db[key] = value
config_tunnel = dict()
with open(os.path.join(dir_path, ".tunnel")) as f:
for line in f.read().splitlines():
key, value = line.split("=")
config_tunnel[key] = value
config_tunnel["remote_bind_address"] = make_tuple(
config_tunnel["remote_bind_address"]
)
config_tunnel["ssh_address_or_host"] = make_tuple(
config_tunnel["ssh_address_or_host"]
)
self.server = SSHTunnelForwarder(**config_tunnel)
self.server.daemon_forward_servers = True
self.server.start()
config_db["port"] = self.server.local_bind_port
return mysql.connector.connect(**config_db)
def close(self):
self.server.close()
class BD(ABC): class BD(ABC):
@@ -108,7 +142,8 @@ class BD(ABC):
:param record: data to insert in database :param record: data to insert in database
:type record: dict :type record: dict
""" """
database = MySQL.get_connection() dbh = MySQL()
database = dbh.get_connection()
command_insert = ( command_insert = (
"replace into results (date, time, type, accuracy, " "replace into results (date, time, type, accuracy, "
"dataset, classifier, norm, stand, parameters) values (%s, %s, " "dataset, classifier, norm, stand, parameters) values (%s, %s, "
@@ -131,6 +166,7 @@ class BD(ABC):
cursor = database.cursor() cursor = database.cursor()
cursor.execute(command_insert, values) cursor.execute(command_insert, values)
database.commit() database.commit()
dbh.close()
def execute(self, command: str) -> None: def execute(self, command: str) -> None:
c = self._con.cursor() c = self._con.cursor()

View File

@@ -1,7 +1,3 @@
import os
import mysql.connector
class TextColor: class TextColor:
BLUE = "\033[94m" BLUE = "\033[94m"
CYAN = "\033[96m" CYAN = "\033[96m"
@@ -18,15 +14,3 @@ class TextColor:
ENDC = "\033[0m" ENDC = "\033[0m"
BOLD = "\033[1m" BOLD = "\033[1m"
UNDERLINE = "\033[4m" UNDERLINE = "\033[4m"
class MySQL:
@staticmethod
def get_connection():
config = dict()
dir_path = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(dir_path, ".myconfig")) as f:
for line in f.read().splitlines():
key, value = line.split("=")
config[key] = value
return mysql.connector.connect(**config)

View File

@@ -30,10 +30,10 @@
"import json\n", "import json\n",
"import sqlite3\n", "import sqlite3\n",
"import mysql.connector\n", "import mysql.connector\n",
"from experimentation.Utils import MySQL\n", "from experimentation.Database import MySQL\n",
"from experimentation.Sets import Datasets\n", "from experimentation.Sets import Datasets\n",
"\n", "dbh = MySQL()\n",
"database = MySQL.get_connection()" "database = dbh.get_connection()"
] ]
}, },
{ {
@@ -375,6 +375,15 @@
"find_values('max_features', 'linear')" "find_values('max_features', 'linear')"
] ]
}, },
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"dbh.close()"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,

View File

@@ -1,7 +1,8 @@
import argparse import argparse
from typing import Tuple from typing import Tuple
from experimentation.Sets import Datasets from experimentation.Sets import Datasets
from experimentation.Utils import TextColor, MySQL from experimentation.Utils import TextColor
from experimentation.Database import MySQL
models = ["stree", "adaBoost", "bagging", "odte"] models = ["stree", "adaBoost", "bagging", "odte"]
@@ -114,7 +115,8 @@ def report_footer(agg):
classifier, classifier,
exclude_parameters, exclude_parameters,
) = parse_arguments() ) = parse_arguments()
database = MySQL.get_connection() dbh = MySQL()
database = dbh.get_connection()
dt = Datasets(False, False, "tanveer") dt = Datasets(False, False, "tanveer")
title = "Best Hyperparameters found for datasets" title = "Best Hyperparameters found for datasets"
lengths = (10, 8, 10, 10, 30, 3, 3, 9, 11) lengths = (10, 8, 10, 10, 30, 3, 3, 9, 11)
@@ -151,3 +153,4 @@ for dataset in dt:
) )
print(color + report_line(record, agg)) print(color + report_line(record, agg))
report_footer(agg) report_footer(agg)
dbh.close()

View File

@@ -2,7 +2,7 @@
for i in gridsearch gridbest cross; do for i in gridsearch gridbest cross; do
echo "*** Building $i experiments" echo "*** Building $i experiments"
for j in stree odte bagging adaBoost; do for j in stree odte bagging adaBoost; do
for k in linear poly rbf; do for k in linear poly rbf any; do
./genjobs.sh $i $j $k ./genjobs.sh $i $j $k
done done
done done

View File

@@ -36,6 +36,8 @@ setuptools.setup(
"ipympl", "ipympl",
"stree", "stree",
"odte", "odte",
"sshtunnel",
"mysql-connector-python",
], ],
zip_safe=False, zip_safe=False,
) )

View File

@@ -21,11 +21,12 @@
"dataset_name = \"cylinder-bands\"\n", "dataset_name = \"cylinder-bands\"\n",
"dataset_name = \"pima\"\n", "dataset_name = \"pima\"\n",
"dataset_name = \"conn-bench-sonar-mines-rocks\"\n", "dataset_name = \"conn-bench-sonar-mines-rocks\"\n",
"dataset_name = \"libras\"\n",
"parameters = {\"C\": .15, \"degree\": 6, \"gamma\": .7, \"kernel\": \"poly\", \"max_features\": None, \"max_iter\": 100000.0, \"random_state\": 0}\n", "parameters = {\"C\": .15, \"degree\": 6, \"gamma\": .7, \"kernel\": \"poly\", \"max_features\": None, \"max_iter\": 100000.0, \"random_state\": 0}\n",
"parameters = {'C': 7, 'degree': 7, 'gamma': 0.1, 'kernel': 'poly', 'max_features': 'auto', 'max_iter': 10000.0, 'random_state': 1, 'split_criteria': 'impurity'}\n", "#parameters = {'C': .17, 'degree': 5, 'gamma': 0.1, 'kernel': 'poly', 'max_features': 'auto', 'max_iter': 10000.0, 'random_state': 1, 'split_criteria': 'impurity'}\n",
"parameters = {\"C\": 0.2, \"max_iter\": 10000.0, \"random_state\": 1}\n", "#parameters = {\"C\": 0.2, \"max_iter\": 10000.0, \"random_state\": 1}\n",
"parameters = {\"C\": 0.55, \"gamma\": 0.1, \"kernel\": \"rbf\", \"max_iter\": 10000.0, \"random_state\": 1}\n", "#parameters = {\"C\": 0.55, \"gamma\": 0.1, \"kernel\": \"rbf\", \"max_iter\": 10000.0, \"random_state\": 1}\n",
"parameters = {\"C\": 55, \"max_iter\": 10000.0, \"random_state\": 1}" "#parameters = {\"C\": 55, \"max_iter\": 10000.0, \"random_state\": 1}"
] ]
}, },
{ {
@@ -40,61 +41,102 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"clf = Stree(**parameters)\n",
"results = cross_validate(clf, X, y, n_jobs=1, return_estimator=True)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"/Users/rmontanana/.virtualenvs/general/lib/python3.8/site-packages/sklearn/svm/_base.py:976: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n", "root feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9999 counts=(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), array([19, 19, 19, 19, 20, 19, 19, 19, 19, 20, 19, 19, 19, 19, 20]))\n",
" warnings.warn(\"Liblinear failed to converge, increase \"\n" "root - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9613 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), array([19, 19, 18, 11, 10, 18, 4, 19, 20, 8, 19, 3, 18, 11]))\n",
"root - Down - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.8726 counts=(array([ 5, 13, 14]), array([ 3, 4, 10]))\n",
"root - Down - Down - Down, <pure> - Leaf class=5 belief= 1.000000 impurity=0.0000 counts=(array([5]), array([2]))\n",
"root - Down - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.7312 counts=(array([ 5, 13, 14]), array([ 1, 4, 10]))\n",
"root - Down - Down - Up - Down, <pure> - Leaf class=5 belief= 1.000000 impurity=0.0000 counts=(array([5]), array([1]))\n",
"root - Down - Down - Up - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.8631 counts=(array([13, 14]), array([ 4, 10]))\n",
"root - Down - Down - Up - Up - Down, <pure> - Leaf class=13 belief= 1.000000 impurity=0.0000 counts=(array([13]), array([4]))\n",
"root - Down - Down - Up - Up - Up, <pure> - Leaf class=14 belief= 1.000000 impurity=0.0000 counts=(array([14]), array([10]))\n",
"root - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9359 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), array([19, 19, 18, 11, 7, 18, 4, 19, 20, 8, 19, 3, 14, 1]))\n",
"root - Down - Up - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9548 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]), array([19, 19, 18, 11, 7, 18, 4, 19, 20, 8, 19, 3, 14]))\n",
"root - Down - Up - Down - Down, <pure> - Leaf class=12 belief= 1.000000 impurity=0.0000 counts=(array([12]), array([3]))\n",
"root - Down - Up - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.9675 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 13]), array([19, 19, 18, 11, 7, 18, 4, 19, 20, 8, 19, 14]))\n",
"root - Down - Up - Down - Up - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.7248 counts=(array([5, 6, 7]), array([5, 1, 1]))\n",
"root - Down - Up - Down - Up - Down - Down, <pure> - Leaf class=6 belief= 1.000000 impurity=0.0000 counts=(array([6]), array([1]))\n",
"root - Down - Up - Down - Up - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89) impurity=0.6500 counts=(array([5, 7]), array([5, 1]))\n",
"root - Down - Up - Down - Up - Down - Up - Down, <pure> - Leaf class=5 belief= 1.000000 impurity=0.0000 counts=(array([5]), array([5]))\n",
"root - Down - Up - Down - Up - Down - Up - Up, <pure> - Leaf class=7 belief= 1.000000 impurity=0.0000 counts=(array([7]), array([1]))\n",
"root - Down - Up - Down - Up - Up, <cgaf> - Leaf class=9 belief= 0.118343 impurity=0.9488 counts=(array([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 13]), array([19, 19, 18, 11, 2, 17, 3, 19, 20, 8, 19, 14]))\n",
"root - Down - Up - Up, <pure> - Leaf class=14 belief= 1.000000 impurity=0.0000 counts=(array([14]), array([1]))\n",
"root - Up, <cgaf> - Leaf class=3 belief= 0.208791 impurity=0.8775 counts=(array([ 2, 3, 4, 5, 6, 7, 10, 12, 13, 14]), array([ 1, 19, 9, 9, 1, 15, 11, 16, 1, 9]))\n",
"\n"
] ]
} }
], ],
"source": [ "source": [
"clf = Stree(**parameters)\n", "print(results['estimator'][0])"
"results = cross_validate(clf, X, y, n_jobs=1)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.tree import DecisionTreeClassifier\n",
"dt = DecisionTreeClassifier(random_state=1)\n",
"resdt = cross_validate(dt, X, y, n_jobs=1, return_estimator=True)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "name": "stdout",
"text/plain": [ "output_type": "stream",
"{'fit_time': array([0.0078361 , 0.03171897, 0.01422501, 0.06850815, 0.05387974]),\n", "text": [
" 'score_time': array([0.0005939 , 0.00044203, 0.00043583, 0.00050902, 0.00044012]),\n", "0.5416666666666667\n"
" 'test_score': array([0.4047619 , 0.61904762, 0.66666667, 0.92682927, 0.58536585])}" ]
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
} }
], ],
"source": [ "source": [
"results" "print(resdt['test_score'].mean()) ki"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "name": "stdout",
"text/plain": [ "output_type": "stream",
"0.640534262485482" "text": [
] "[0.09722222 0.25 0.54166667 0.31944444 0.13888889]\n"
}, ]
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
} }
], ],
"source": [
"print(results['test_score'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"results['test_score'].mean()" "results['test_score'].mean()"
] ]
@@ -371,9 +413,9 @@
} }
], ],
"source": [ "source": [
"clf = Stree()\n", "#clf = Stree()\n",
"model = GridSearchCV(clf, n_jobs=1, verbose=10, param_grid=param_grid, cv=2)\n", "#model = GridSearchCV(clf, n_jobs=1, verbose=10, param_grid=param_grid, cv=2)\n",
"model.fit(X, y)" "#model.fit(X, y)"
] ]
}, },
{ {
@@ -677,6 +719,68 @@
"b._base_model.get_model_name()" "b._base_model.get_model_name()"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _get_subspaces_set(\n",
" self, dataset: np.array, labels: np.array, max_features: int\n",
" ) -> np.array:\n",
" features = range(dataset.shape[1])\n",
" features_sets = list(combinations(features, max_features))\n",
" if len(features_sets) > 1:\n",
" if self._splitter_type == \"random\":\n",
" index = random.randint(0, len(features_sets) - 1)\n",
" return features_sets[index]\n",
" else:\n",
" # get only 3 sets at most\n",
" if len(features_sets) > 3:\n",
" features_sets = random.sample(features_sets, 3)\n",
" return self._select_best_set(dataset, labels, features_sets)\n",
" else:\n",
" return features_sets[0]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"def generate_subspaces(features: int, max_features: int) -> list:\n",
" combs = set()\n",
" # take 3 combinations at most\n",
" combinations = 1 if max_features == features else 3\n",
" while len(combs) < combinations:\n",
" combs.add(tuple(sorted(random.sample(range(features), max_features))))\n",
" return list(combs)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(2, 17, 51, 66, 79, 88, 98, 105, 145, 150),\n",
" (1, 11, 27, 48, 86, 117, 124, 157, 180, 194),\n",
" (11, 31, 42, 43, 61, 79, 138, 149, 150, 189)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generate_subspaces(200, 10)"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,