Add sigmoid kernel

This commit is contained in:
2021-04-22 18:09:27 +02:00
parent 3af7864278
commit 881777c38c
3 changed files with 19 additions and 2 deletions

View File

@@ -37,7 +37,7 @@ Can be found in
| | **Hyperparameter** | **Type/Values** | **Default** | **Meaning** |
| --- | ------------------ | ------------------------------------------------------ | ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| \* | C | \<float\> | 1.0 | Regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive. |
| \* | kernel | {"linear", "poly", "rbf"} | linear | Specifies the kernel type to be used in the algorithm. It must be one of linear, poly or rbf. |
| \* | kernel | {"linear", "poly", "rbf", "sigmoid"} | linear | Specifies the kernel type to be used in the algorithm. It must be one of linear, poly or rbf. |
| \* | max_iter | \<int\> | 1e5 | Hard limit on iterations within solver, or -1 for no limit. |
| \* | random_state | \<int\> | None | Controls the pseudo random number generation for shuffling the data for probability estimates. Ignored when probability is False.<br>Pass an int for reproducible output across multiple function calls |
| | max_depth | \<int\> | None | Specifies the maximum depth of the tree |

View File

@@ -619,7 +619,9 @@ class Stree(BaseEstimator, ClassifierMixin):
f"Maximum depth has to be greater than 1... got (max_depth=\
{self.max_depth})"
)
kernels = ["linear", "rbf", "poly", "sigmoid"]
if self.kernel not in kernels:
raise ValueError(f"Kernel {self.kernel} not in {kernels}")
check_classification_targets(y)
X, y = check_X_y(X, y)
sample_weight = _check_sample_weight(

View File

@@ -21,6 +21,21 @@ class Stree_test(unittest.TestCase):
def setUp(cls):
os.environ["TESTING"] = "1"
def test_valid_kernels(self):
valid_kernels = ["linear", "rbf", "poly", "sigmoid"]
X, y = load_dataset()
for kernel in valid_kernels:
clf = Stree(kernel=kernel)
clf.fit(X, y)
self.assertIsNotNone(clf.tree_)
def test_bogus_kernel(self):
kernel = "other"
X, y = load_dataset()
clf = Stree(kernel=kernel)
with self.assertRaises(ValueError):
clf.fit(X, y)
def _check_tree(self, node: Snode):
"""Check recursively that the nodes that are not leaves have the
correct number of labels and its sons have the right number of elements