mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-17 08:26:00 +00:00
First approach
Added max_depth, tol, weighted samples
This commit is contained in:
@@ -104,23 +104,28 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
"""
|
||||
__folder = 'data/'
|
||||
|
||||
def __init__(self, C: float = 1.0, max_iter: int = 1000, random_state: int = 0, use_predictions: bool = False):
|
||||
def __init__(self, C: float = 1.0, max_iter: int = 1000, random_state: int = 0,
|
||||
max_depth: int=None, tol: float=1e-4, use_predictions: bool = False):
|
||||
self.max_iter = max_iter
|
||||
self.C = C
|
||||
self.random_state = random_state
|
||||
self.random_state = random_state
|
||||
self.use_predictions = use_predictions
|
||||
self.max_depth = max_depth
|
||||
self.tol = tol
|
||||
|
||||
def get_params(self, deep=True):
|
||||
def get_params(self, deep: bool=True) -> dict:
|
||||
"""Get dict with hyperparameters and its values to accomplish sklearn rules
|
||||
"""
|
||||
return {
|
||||
'C': self.C,
|
||||
'random_state': self.random_state,
|
||||
'max_iter': self.max_iter,
|
||||
'use_predictions': self.use_predictions
|
||||
'use_predictions': self.use_predictions,
|
||||
'max_depth': self.max_depth,
|
||||
'tol': self.tol
|
||||
}
|
||||
|
||||
def set_params(self, **parameters):
|
||||
def set_params(self, **parameters: dict):
|
||||
"""Set hyperparmeters as specified by sklearn, needed in Gridsearchs
|
||||
"""
|
||||
for parameter, value in parameters.items():
|
||||
@@ -128,13 +133,18 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
return self
|
||||
|
||||
# Added binary_only tag as required by sklearn check_estimator
|
||||
def _more_tags(self):
|
||||
def _more_tags(self) -> dict:
|
||||
return {'binary_only': True}
|
||||
|
||||
def _linear_function(self, data: np.array, node: Snode) -> np.array:
|
||||
coef = node._vector[0, :].reshape(-1, data.shape[1])
|
||||
return data.dot(coef.T) + node._interceptor[0]
|
||||
|
||||
def _split_array(self, origin: np.array, down: np.array) -> list:
|
||||
up = ~down
|
||||
return origin[up[:, 0]] if any(up) else None, \
|
||||
origin[down[:, 0]] if any(down) else None
|
||||
|
||||
def _split_data(self, node: Snode, data: np.ndarray, indices: np.ndarray) -> list:
|
||||
if self.use_predictions:
|
||||
yp = node._clf.predict(data)
|
||||
@@ -145,25 +155,30 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
# computes positition of every sample is w.r.t. the hyperplane
|
||||
res = self._linear_function(data, node)
|
||||
down = res > 0
|
||||
up = ~down
|
||||
data_down = data[down[:, 0]] if any(down) else None
|
||||
indices_down = indices[down[:, 0]] if any(down) else None
|
||||
res_down = res[down[:, 0]] if any(down) else None
|
||||
data_up = data[up[:, 0]] if any(up) else None
|
||||
indices_up = indices[up[:, 0]] if any(up) else None
|
||||
res_up = res[up[:, 0]] if any(up) else None
|
||||
data_up, data_down = self._split_array(data, down)
|
||||
indices_up, indices_down = self._split_array(indices, down)
|
||||
res_up, res_down = self._split_array(res, down)
|
||||
return [data_up, indices_up, data_down, indices_down, res_up, res_down]
|
||||
|
||||
def fit(self, X: np.ndarray, y: np.ndarray, title: str = 'root') -> 'Stree':
|
||||
def fit(self, X: np.ndarray, y: np.ndarray, weighted_samples: np.array=None, **fitparams: dict) -> 'Stree':
|
||||
from sklearn.utils.multiclass import check_classification_targets
|
||||
if fitparams is not None:
|
||||
self.set_params(**fitparams)
|
||||
if type(y).__name__ == 'np.ndarray':
|
||||
y = y.ravel()
|
||||
if self.C < 0:
|
||||
raise ValueError(f"Penalty term must be positive... got (C={self.C:f})")
|
||||
self.__max_depth = np.iinfo(np.int32).max if self.max_depth is None else self.max_depth
|
||||
if self.__max_depth < 1:
|
||||
raise ValueError(f"Maximum depth has to be greater than 1... got (max_depth={self.max_depth})")
|
||||
check_classification_targets(y)
|
||||
X, y = check_X_y(X, y)
|
||||
self.classes_ = np.unique(y)
|
||||
self.n_iter_ = self.max_iter
|
||||
self.depth_ = 0
|
||||
check_classification_targets(y)
|
||||
self.n_features_in_ = X.shape[1]
|
||||
self.tree_ = self.train(X, y.ravel(), title)
|
||||
self.tree_ = self.train(X, y, 1, 'root')
|
||||
self._build_predictor()
|
||||
return self
|
||||
|
||||
@@ -180,8 +195,11 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
|
||||
run_tree(self.tree_)
|
||||
|
||||
def train(self, X: np.ndarray, y: np.ndarray, title: str = 'root') -> Snode:
|
||||
if np.unique(y).shape[0] == 1:
|
||||
def train(self, X: np.ndarray, y: np.ndarray, depth: int, title: str = 'root') -> Snode:
|
||||
|
||||
if depth > self.__max_depth:
|
||||
return None
|
||||
if np.unique(y).shape[0] == 1 :
|
||||
# only 1 class => pure dataset
|
||||
return Snode(None, X, y, title + ', <pure>')
|
||||
# Train the model
|
||||
@@ -189,12 +207,13 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
random_state=self.random_state)
|
||||
clf.fit(X, y)
|
||||
tree = Snode(clf, X, y, title)
|
||||
self.depth_ = max(depth, self.depth_)
|
||||
X_U, y_u, X_D, y_d, _, _ = self._split_data(tree, X, y)
|
||||
if X_U is None or X_D is None:
|
||||
# didn't part anything
|
||||
return Snode(clf, X, y, title + ', <cgaf>')
|
||||
tree.set_up(self.train(X_U, y_u, title + ' - Up'))
|
||||
tree.set_down(self.train(X_D, y_d, title + ' - Down'))
|
||||
tree.set_up(self.train(X_U, y_u, depth + 1, title + ' - Up'))
|
||||
tree.set_down(self.train(X_D, y_d, depth + 1, title + ' - Down'))
|
||||
return tree
|
||||
|
||||
def _reorder_results(self, y: np.array, indices: np.array, proba=False) -> np.array:
|
||||
@@ -273,8 +292,9 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
result = result.reshape(X.shape[0], 2)
|
||||
# Turn distances to hyperplane into probabilities based on fitting distances
|
||||
# of samples to its hyperplane that classified them, to the sigmoid function
|
||||
result[:, 1] = 1 / (1 + np.exp(-result[:, 1])) # Probability of being 1
|
||||
result[:, 0] = 1 - result[:, 1] # Probability of being 0
|
||||
# Probability of being 1
|
||||
result[:, 1] = 1 / (1 + np.exp(-result[:, 1]))
|
||||
result[:, 0] = 1 - result[:, 1] # Probability of being 0
|
||||
return self._reorder_results(result, indices, proba=True)
|
||||
|
||||
def score(self, X: np.array, y: np.array) -> float:
|
||||
@@ -286,8 +306,12 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
right = (yp == y).astype(int)
|
||||
return np.sum(right) / len(y)
|
||||
|
||||
def __iter__(self):
|
||||
return Siterator(self.tree_)
|
||||
def __iter__(self) -> Siterator:
|
||||
try:
|
||||
tree = self.tree_
|
||||
except:
|
||||
tree = None
|
||||
return Siterator(tree)
|
||||
|
||||
def __str__(self) -> str:
|
||||
output = ''
|
||||
@@ -295,6 +319,9 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
output += str(i) + '\n'
|
||||
return output
|
||||
|
||||
def get_folder(self) -> str:
|
||||
return self.__folder
|
||||
|
||||
def _save_datasets(self, tree: Snode, catalog: typing.TextIO, number: int):
|
||||
"""Save the dataset of the node in a csv file
|
||||
|
||||
@@ -324,4 +351,3 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
os.mkdir(self.__folder)
|
||||
with open(self.get_catalog_name(), 'w', encoding='utf-8') as catalog:
|
||||
self._save_datasets(self.tree_, catalog, 1)
|
||||
|
||||
|
@@ -118,6 +118,12 @@ class Stree_test(unittest.TestCase):
|
||||
x_file, y_file = self._get_file_data(row[0])
|
||||
y_original = np.array(self._find_out(x_file, X, y), dtype=int)
|
||||
self.assertTrue(np.array_equal(y_file, y_original))
|
||||
if os.path.isdir(self._clf.get_folder()):
|
||||
try:
|
||||
os.remove(f"{self._clf.get_folder()}*")
|
||||
os.rmdir(self._clf.get_folder())
|
||||
except:
|
||||
pass
|
||||
|
||||
def test_single_prediction(self):
|
||||
X, y = self._get_Xy()
|
||||
@@ -253,6 +259,30 @@ class Stree_test(unittest.TestCase):
|
||||
from sklearn.utils.estimator_checks import check_estimator
|
||||
check_estimator(Stree())
|
||||
|
||||
def test_exception_if_C_is_negative(self):
|
||||
tclf = Stree(C=-1)
|
||||
with self.assertRaises(ValueError):
|
||||
tclf.fit(*self._get_Xy())
|
||||
|
||||
def test_check_max_depth_is_positive_or_None(self):
|
||||
tcl = Stree()
|
||||
self.assertIsNone(tcl.max_depth)
|
||||
tcl = Stree(max_depth=1)
|
||||
self.assertGreaterEqual(1, tcl.max_depth)
|
||||
with self.assertRaises(ValueError):
|
||||
tcl = Stree(max_depth=-1)
|
||||
tcl.fit(*self._get_Xy())
|
||||
|
||||
def test_check_max_depth(self):
|
||||
depth = 3
|
||||
tcl = Stree(random_state=self._random_state, max_depth=depth)
|
||||
tcl.fit(*self._get_Xy())
|
||||
self.assertEqual(depth, tcl.depth_)
|
||||
|
||||
def test_unfitted_tree_is_iterable(self):
|
||||
tcl = Stree()
|
||||
self.assertEqual(0, len(list(tcl)))
|
||||
|
||||
class Snode_test(unittest.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
Reference in New Issue
Block a user