Some quality refactoring

This commit is contained in:
2021-04-19 23:15:17 +02:00
parent fec094a75f
commit a2df31628d
4 changed files with 12 additions and 44 deletions

View File

@@ -6,7 +6,7 @@ overage:
comment:
layout: "reach, diff, flags, files"
behavior: default
require_changes: false
require_changes: false
require_base: yes
require_head: yes
branches: null
require_head: yes
branches: null

29
main.py
View File

@@ -1,29 +0,0 @@
import time
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from stree import Stree
random_state = 1
X, y = load_iris(return_X_y=True)
Xtrain, Xtest, ytrain, ytest = train_test_split(
X, y, test_size=0.3, random_state=random_state
)
now = time.time()
print("Predicting with max_features=sqrt(n_features)")
clf = Stree(C=0.01, random_state=random_state, max_features="auto")
clf.fit(Xtrain, ytrain)
print(f"Took {time.time() - now:.2f} seconds to train")
print(clf)
print(f"Classifier's accuracy (train): {clf.score(Xtrain, ytrain):.4f}")
print(f"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}")
print("=" * 40)
print("Predicting with max_features=n_features")
clf = Stree(C=0.01, random_state=random_state)
clf.fit(Xtrain, ytrain)
print(f"Took {time.time() - now:.2f} seconds to train")
print(clf)
print(f"Classifier's accuracy (train): {clf.score(Xtrain, ytrain):.4f}")
print(f"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}")

View File

@@ -144,12 +144,11 @@ class Snode:
f"{self._belief: .6f} impurity={self._impurity:.4f} "
f"counts={count_values}"
)
else:
return (
f"{self._title} feaures={self._features} impurity="
f"{self._impurity:.4f} "
f"counts={count_values}"
)
return (
f"{self._title} feaures={self._features} impurity="
f"{self._impurity:.4f} "
f"counts={count_values}"
)
class Siterator:
@@ -384,10 +383,8 @@ class Splitter:
if self._splitter_type == "random":
index = random.randint(0, len(features_sets) - 1)
return features_sets[index]
else:
return self._select_best_set(dataset, labels, features_sets)
else:
return features_sets[0]
return self._select_best_set(dataset, labels, features_sets)
return features_sets[0]
def get_subspace(
self, dataset: np.array, labels: np.array, max_features: int

View File

@@ -484,13 +484,13 @@ class Stree_test(unittest.TestCase):
clf.fit(X, y)
nodes, leaves = clf.nodes_leaves()
self.assertEqual(25, nodes)
self.assertEquals(13, leaves)
self.assertEqual(13, leaves)
X, y = load_wine(return_X_y=True)
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
nodes, leaves = clf.nodes_leaves()
self.assertEqual(9, nodes)
self.assertEquals(5, leaves)
self.assertEqual(5, leaves)
def test_nodes_leaves_artificial(self):
n1 = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test1")