From 881777c38c708d79d631f915477e548133bad3a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 22 Apr 2021 18:09:27 +0200 Subject: [PATCH] Add sigmoid kernel --- README.md | 2 +- stree/Strees.py | 4 +++- stree/tests/Stree_test.py | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e13bcc9..860b901 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ Can be found in | | **Hyperparameter** | **Type/Values** | **Default** | **Meaning** | | --- | ------------------ | ------------------------------------------------------ | ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | \* | C | \ | 1.0 | Regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive. | -| \* | kernel | {"linear", "poly", "rbf"} | linear | Specifies the kernel type to be used in the algorithm. It must be one of ‘linear’, ‘poly’ or ‘rbf’. | +| \* | kernel | {"linear", "poly", "rbf", "sigmoid"} | linear | Specifies the kernel type to be used in the algorithm. It must be one of ‘linear’, ‘poly’ or ‘rbf’. | | \* | max_iter | \ | 1e5 | Hard limit on iterations within solver, or -1 for no limit. | | \* | random_state | \ | None | Controls the pseudo random number generation for shuffling the data for probability estimates. Ignored when probability is False.
Pass an int for reproducible output across multiple function calls | | | max_depth | \ | None | Specifies the maximum depth of the tree | diff --git a/stree/Strees.py b/stree/Strees.py index 1554e7d..5bb7c1a 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -619,7 +619,9 @@ class Stree(BaseEstimator, ClassifierMixin): f"Maximum depth has to be greater than 1... got (max_depth=\ {self.max_depth})" ) - + kernels = ["linear", "rbf", "poly", "sigmoid"] + if self.kernel not in kernels: + raise ValueError(f"Kernel {self.kernel} not in {kernels}") check_classification_targets(y) X, y = check_X_y(X, y) sample_weight = _check_sample_weight( diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index 379082b..51a54d6 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -21,6 +21,21 @@ class Stree_test(unittest.TestCase): def setUp(cls): os.environ["TESTING"] = "1" + def test_valid_kernels(self): + valid_kernels = ["linear", "rbf", "poly", "sigmoid"] + X, y = load_dataset() + for kernel in valid_kernels: + clf = Stree(kernel=kernel) + clf.fit(X, y) + self.assertIsNotNone(clf.tree_) + + def test_bogus_kernel(self): + kernel = "other" + X, y = load_dataset() + clf = Stree(kernel=kernel) + with self.assertRaises(ValueError): + clf.fit(X, y) + def _check_tree(self, node: Snode): """Check recursively that the nodes that are not leaves have the correct number of labels and its sons have the right number of elements