mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 07:56:06 +00:00
compute predictor and store model in node
This commit is contained in:
45
tests/Snode_test.py
Normal file
45
tests/Snode_test.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import unittest
|
||||
|
||||
from sklearn.datasets import make_classification
|
||||
import numpy as np
|
||||
import csv
|
||||
|
||||
from trees.Stree import Stree, Snode
|
||||
|
||||
|
||||
class Snode_test(unittest.TestCase):
    """Checks the predictor attributes stored in the leaves of a fitted Stree."""

    def __init__(self, *args, **kwargs):
        """Fit the Stree model under test, then initialise the TestCase.

        The model is fitted on the dataset returned by ``_get_Xy`` using a
        fixed random state.
        """
        self._random_state = 1
        self._model = Stree(random_state=self._random_state,
                            use_predictions=True)
        self._model.fit(*self._get_Xy())
        super(Snode_test, self).__init__(*args, **kwargs)

    def _get_Xy(self):
        """Return the synthetic classification dataset used by the tests.

        Returns:
            tuple -- (X, y) as produced by ``make_classification`` with
            1500 samples, 3 informative features and 2 classes
        """
        X, y = make_classification(
            n_samples=1500, n_features=3, n_informative=3,
            n_redundant=0, n_repeated=0, n_classes=2,
            n_clusters_per_class=2, class_sep=1.5, flip_y=0,
            weights=[0.5, 0.5], random_state=self._random_state)
        return X, y

    def test_attributes_in_leaves(self):
        """Check if the attributes in leaves have correct values so they form a predictor
        """
        def check_leaf(node: Snode):
            """Verify ``_belief`` and ``_class`` on leaves, recursing otherwise."""
            if node.is_leaf():
                # Check Belief: recompute it from the labels held by the leaf
                classes, card = np.unique(node._y, return_counts=True)
                max_card = max(card)
                min_card = min(card)
                try:
                    # Ratio of the largest to the smallest class count; this
                    # is the value the test expects to find in node._belief
                    accuracy = max_card / min_card
                except ZeroDivisionError:
                    # Only a zero division is tolerated here, so interrupts
                    # and unrelated errors are no longer silently swallowed
                    accuracy = 0
                self.assertEqual(accuracy, node._belief)
                # Check Class: the label(s) with the highest count
                class_computed = classes[card == max_card]
                self.assertEqual(class_computed, node._class)
                return
            # Internal node: check both children
            check_leaf(node.get_down())
            check_leaf(node.get_up())

        check_leaf(self._model._tree)
|
@@ -1,35 +1,31 @@
|
||||
import unittest
|
||||
|
||||
from sklearn.svm import LinearSVC
|
||||
from sklearn.datasets import make_classification
|
||||
import numpy as np
|
||||
import csv
|
||||
|
||||
from trees.Stree import Stree, Snode
|
||||
|
||||
|
||||
class Stree_test(unittest.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._random_state = 1
|
||||
self._model_tree = Stree(random_state=self._random_state, use_predictions=True)
|
||||
self._model_tree.fit(*self._get_Xy())
|
||||
self._model_svm = LinearSVC(random_state=self._random_state, max_iter=self._model_tree._max_iter)
|
||||
self._model = Stree(random_state=self._random_state,
|
||||
use_predictions=True)
|
||||
self._model.fit(*self._get_Xy())
|
||||
super(Stree_test, self).__init__(*args, **kwargs)
|
||||
|
||||
def _get_Xy(self):
|
||||
X, y = make_classification(n_samples=1500, n_features=3, n_informative=3,
|
||||
n_redundant=0, n_repeated=0, n_classes=2, n_clusters_per_class=2,
|
||||
class_sep=1.5, flip_y=0,weights=[0.5,0.5], random_state=self._random_state)
|
||||
X, y = make_classification(n_samples=1500, n_features=3, n_informative=3,
|
||||
n_redundant=0, n_repeated=0, n_classes=2, n_clusters_per_class=2,
|
||||
class_sep=1.5, flip_y=0, weights=[0.5, 0.5], random_state=self._random_state)
|
||||
return X, y
|
||||
|
||||
def test_split_data(self):
|
||||
self.assertTrue(True)
|
||||
|
||||
def _check_tree(self, node: Snode):
|
||||
if node.is_leaf():
|
||||
return
|
||||
self._model_svm.fit(node._X, node._y)
|
||||
y_prediction = self._model_svm.predict(node._X)
|
||||
y_prediction = node._model.predict(node._X)
|
||||
y_down = node.get_down()._y
|
||||
y_up = node.get_up()._y
|
||||
# Is a correct partition in terms of cardinality?
|
||||
@@ -59,7 +55,7 @@ class Stree_test(unittest.TestCase):
|
||||
def test_build_tree(self):
|
||||
"""Check if the tree is built the same way as predictions of models
|
||||
"""
|
||||
self._check_tree(self._model_tree._tree)
|
||||
self._check_tree(self._model._tree)
|
||||
|
||||
def _get_file_data(self, file_name: str) -> tuple:
|
||||
"""Return X, y from data, y is the last column in array
|
||||
@@ -69,7 +65,7 @@ class Stree_test(unittest.TestCase):
|
||||
|
||||
Returns:
|
||||
tuple -- tuple with samples, categories
|
||||
"""
|
||||
"""
|
||||
data = np.genfromtxt(file_name, delimiter=',')
|
||||
data = np.array(data)
|
||||
column_y = data.shape[1] - 1
|
||||
@@ -87,22 +83,22 @@ class Stree_test(unittest.TestCase):
|
||||
|
||||
Returns:
|
||||
np.array -- classes of the given samples
|
||||
"""
|
||||
"""
|
||||
res = []
|
||||
for needle in px:
|
||||
for row in range(x_original.shape[0]):
|
||||
if all(x_original[row, :] == needle):
|
||||
res.append(y_original[row])
|
||||
return res
|
||||
|
||||
|
||||
def test_subdatasets(self):
|
||||
"""Check if the subdatasets files have the same predictions as the tree itself
|
||||
"""
|
||||
model = LinearSVC(random_state=self._random_state, max_iter=self._model_tree._max_iter)
|
||||
model = self._model._tree._model
|
||||
X, y = self._get_Xy()
|
||||
model.fit(X, y)
|
||||
self._model_tree.save_sub_datasets()
|
||||
with open(self._model_tree.get_catalog_name()) as cat_file:
|
||||
self._model.save_sub_datasets()
|
||||
with open(self._model.get_catalog_name()) as cat_file:
|
||||
catalog = csv.reader(cat_file, delimiter=',')
|
||||
for row in catalog:
|
||||
X, y = self._get_Xy()
|
||||
|
Reference in New Issue
Block a user