mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-18 00:46:02 +00:00
Some quality refactoring
This commit is contained in:
29
main.py
29
main.py
@@ -1,29 +0,0 @@
|
|||||||
import time
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
from sklearn.datasets import load_iris
|
|
||||||
from stree import Stree
|
|
||||||
|
|
||||||
random_state = 1
|
|
||||||
|
|
||||||
X, y = load_iris(return_X_y=True)
|
|
||||||
|
|
||||||
Xtrain, Xtest, ytrain, ytest = train_test_split(
|
|
||||||
X, y, test_size=0.3, random_state=random_state
|
|
||||||
)
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
print("Predicting with max_features=sqrt(n_features)")
|
|
||||||
clf = Stree(C=0.01, random_state=random_state, max_features="auto")
|
|
||||||
clf.fit(Xtrain, ytrain)
|
|
||||||
print(f"Took {time.time() - now:.2f} seconds to train")
|
|
||||||
print(clf)
|
|
||||||
print(f"Classifier's accuracy (train): {clf.score(Xtrain, ytrain):.4f}")
|
|
||||||
print(f"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}")
|
|
||||||
print("=" * 40)
|
|
||||||
print("Predicting with max_features=n_features")
|
|
||||||
clf = Stree(C=0.01, random_state=random_state)
|
|
||||||
clf.fit(Xtrain, ytrain)
|
|
||||||
print(f"Took {time.time() - now:.2f} seconds to train")
|
|
||||||
print(clf)
|
|
||||||
print(f"Classifier's accuracy (train): {clf.score(Xtrain, ytrain):.4f}")
|
|
||||||
print(f"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}")
|
|
@@ -144,7 +144,6 @@ class Snode:
|
|||||||
f"{self._belief: .6f} impurity={self._impurity:.4f} "
|
f"{self._belief: .6f} impurity={self._impurity:.4f} "
|
||||||
f"counts={count_values}"
|
f"counts={count_values}"
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
return (
|
return (
|
||||||
f"{self._title} feaures={self._features} impurity="
|
f"{self._title} feaures={self._features} impurity="
|
||||||
f"{self._impurity:.4f} "
|
f"{self._impurity:.4f} "
|
||||||
@@ -384,9 +383,7 @@ class Splitter:
|
|||||||
if self._splitter_type == "random":
|
if self._splitter_type == "random":
|
||||||
index = random.randint(0, len(features_sets) - 1)
|
index = random.randint(0, len(features_sets) - 1)
|
||||||
return features_sets[index]
|
return features_sets[index]
|
||||||
else:
|
|
||||||
return self._select_best_set(dataset, labels, features_sets)
|
return self._select_best_set(dataset, labels, features_sets)
|
||||||
else:
|
|
||||||
return features_sets[0]
|
return features_sets[0]
|
||||||
|
|
||||||
def get_subspace(
|
def get_subspace(
|
||||||
|
@@ -484,13 +484,13 @@ class Stree_test(unittest.TestCase):
|
|||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
nodes, leaves = clf.nodes_leaves()
|
nodes, leaves = clf.nodes_leaves()
|
||||||
self.assertEqual(25, nodes)
|
self.assertEqual(25, nodes)
|
||||||
self.assertEquals(13, leaves)
|
self.assertEqual(13, leaves)
|
||||||
X, y = load_wine(return_X_y=True)
|
X, y = load_wine(return_X_y=True)
|
||||||
clf = Stree(random_state=self._random_state)
|
clf = Stree(random_state=self._random_state)
|
||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
nodes, leaves = clf.nodes_leaves()
|
nodes, leaves = clf.nodes_leaves()
|
||||||
self.assertEqual(9, nodes)
|
self.assertEqual(9, nodes)
|
||||||
self.assertEquals(5, leaves)
|
self.assertEqual(5, leaves)
|
||||||
|
|
||||||
def test_nodes_leaves_artificial(self):
|
def test_nodes_leaves_artificial(self):
|
||||||
n1 = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test1")
|
n1 = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test1")
|
||||||
|
Reference in New Issue
Block a user