diff --git a/README.md b/README.md
index bee1822..6844723 100644
--- a/README.md
+++ b/README.md
@@ -50,7 +50,8 @@ Can be found in [stree.readthedocs.io](https://stree.readthedocs.io/en/stable/)
| | criterion | {“gini”, “entropy”} | entropy | The function to measure the quality of a split (only used if max_features != num_features).
Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain. |
| | min_samples_split | \ | 0 | The minimum number of samples required to split an internal node. 0 (default) for any |
| | max_features | \, \
or {“auto”, “sqrt”, “log2”} | None | The number of features to consider when looking for the split:
If int, then consider max_features features at each split.
If float, then max_features is a fraction and int(max_features \* n_features) features are considered at each split.
If “auto”, then max_features=sqrt(n_features).
If “sqrt”, then max_features=sqrt(n_features).
If “log2”, then max_features=log2(n_features).
If None, then max_features=n_features. |
-| | splitter | {"best", "random", "trandom", "mutual", "cfs", "fcbf", "iwss"} | "random" | The strategy used to choose the feature set at each node (only used if max_features < num_features). Supported strategies are: **“best”**: sklearn SelectKBest algorithm is used in every node to choose the max_features best features. **“random”**: The algorithm generates 5 candidates and choose the best (max. info. gain) of them. **“trandom”**: The algorithm generates only one random combination. **"mutual"**: Chooses the best features w.r.t. their mutual info with the label. **"cfs"**: Apply Correlation-based Feature Selection. **"fcbf"**: Apply Fast Correlation-Based Filter. **"iwss"**: IWSS based algorithm |
+| | splitter | {"best", "random", "trandom", "mutual", "cfs", "fcbf", "iwss"} | "random" | The strategy used to choose the feature set at each node (only used if max_features < num_features).
+Supported strategies are: **“best”**: sklearn SelectKBest algorithm is used in every node to choose the max_features best features. **“random”**: The algorithm generates 5 candidates and chooses the best (max. info. gain) of them. **“trandom”**: The algorithm generates only one random combination. **"mutual"**: Chooses the best features w.r.t. their mutual info with the label. **"cfs"**: Apply Correlation-based Feature Selection. **"fcbf"**: Apply Fast Correlation-Based Filter. **"iwss"**: IWSS based algorithm |
| | normalize | \ | False | If standardization of features should be applied on each node with the samples that reach it |
| \* | multiclass_strategy | {"ovo", "ovr"} | "ovo" | Strategy to use with multiclass datasets, **"ovo"**: one versus one. **"ovr"**: one versus rest |
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 2d08757..dd859b3 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -12,19 +12,18 @@
#
import os
import sys
-import stree
+from stree._version import __version__
sys.path.insert(0, os.path.abspath("../../stree/"))
-
# -- Project information -----------------------------------------------------
project = "STree"
-copyright = "2020 - 2021, Ricardo Montañana Gómez"
+copyright = "2020 - 2022, Ricardo Montañana Gómez"
author = "Ricardo Montañana Gómez"
# The full version, including alpha/beta/rc tags
-version = stree.__version__
+version = __version__
release = version
diff --git a/docs/source/hyperparameters.md b/docs/source/hyperparameters.md
index 1a5ed86..0264063 100644
--- a/docs/source/hyperparameters.md
+++ b/docs/source/hyperparameters.md
@@ -3,20 +3,20 @@
| | **Hyperparameter** | **Type/Values** | **Default** | **Meaning** |
| --- | ------------------- | -------------------------------------------------------------- | ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| \* | C | \ | 1.0 | Regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive. |
-| \* | kernel | {"liblinear", "linear", "poly", "rbf", "sigmoid"} | linear | Specifies the kernel type to be used in the algorithm. It must be one of ‘liblinear’, ‘linear’, ‘poly’ or ‘rbf’. liblinear uses [liblinear](https://www.csie.ntu.edu.tw/~cjlin/liblinear/) library and the rest uses [libsvm](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) library through scikit-learn library |
+| \* | kernel | {"liblinear", "linear", "poly", "rbf", "sigmoid"} | linear | Specifies the kernel type to be used in the algorithm. It must be one of ‘liblinear’, ‘linear’, ‘poly’, ‘rbf’ or ‘sigmoid’.
liblinear uses [liblinear](https://www.csie.ntu.edu.tw/~cjlin/liblinear/) library and the rest uses [libsvm](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) library through scikit-learn library |
| \* | max_iter | \ | 1e5 | Hard limit on iterations within solver, or -1 for no limit. |
| \* | random_state | \ | None | Controls the pseudo random number generation for shuffling the data for probability estimates. Ignored when probability is False.
Pass an int for reproducible output across multiple function calls |
| | max_depth | \ | None | Specifies the maximum depth of the tree |
| \* | tol | \ | 1e-4 | Tolerance for stopping criterion. |
| \* | degree | \ | 3 | Degree of the polynomial kernel function (‘poly’). Ignored by all other kernels. |
| \* | gamma | {"scale", "auto"} or \ | scale | Kernel coefficient for ‘rbf’, ‘poly’ and ‘sigmoid’.
if gamma='scale' (default) is passed then it uses 1 / (n_features \* X.var()) as value of gamma,
if ‘auto’, uses 1 / n_features. |
-| | split_criteria | {"impurity", "max_samples"} | impurity | Decides (just in case of a multi class classification) which column (class) use to split the dataset in a node\*\*. max_samples is incompatible with 'ovo' multiclass_strategy |
-| | criterion | {“gini”, “entropy”} | entropy | The function to measure the quality of a split (only used if max_features != num_features).
Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain. |
+| | split_criteria | {"impurity", "max_samples"} | impurity | Decides (just in case of a multi class classification) which column (class) to use to split the dataset in a node\*\*.
max_samples is incompatible with 'ovo' multiclass_strategy |
+| | criterion | {“gini”, “entropy”} | entropy | The function to measure the quality of a split (only used if max_features != num_features).
Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain. |
| | min_samples_split | \ | 0 | The minimum number of samples required to split an internal node. 0 (default) for any |
| | max_features | \, \
or {“auto”, “sqrt”, “log2”} | None | The number of features to consider when looking for the split:
If int, then consider max_features features at each split.
If float, then max_features is a fraction and int(max_features \* n_features) features are considered at each split.
If “auto”, then max_features=sqrt(n_features).
If “sqrt”, then max_features=sqrt(n_features).
If “log2”, then max_features=log2(n_features).
If None, then max_features=n_features. |
-| | splitter | {"best", "random", "trandom", "mutual", "cfs", "fcbf", "iwss"} | "random" | The strategy used to choose the feature set at each node (only used if max_features < num_features). Supported strategies are: **“best”**: sklearn SelectKBest algorithm is used in every node to choose the max_features best features. **“random”**: The algorithm generates 5 candidates and choose the best (max. info. gain) of them. **“trandom”**: The algorithm generates only one random combination. **"mutual"**: Chooses the best features w.r.t. their mutual info with the label. **"cfs"**: Apply Correlation-based Feature Selection. **"fcbf"**: Apply Fast Correlation-Based Filter. **"iwss"**: IWSS based algorithm |
+| | splitter | {"best", "random", "trandom", "mutual", "cfs", "fcbf", "iwss"} | "random" | The strategy used to choose the feature set at each node (only used if max_features < num_features).
Supported strategies are:
**“best”**: sklearn SelectKBest algorithm is used in every node to choose the max_features best features.
**“random”**: The algorithm generates 5 candidates and chooses the best (max. info. gain) of them.
**“trandom”**: The algorithm generates only one random combination.
**"mutual"**: Chooses the best features w.r.t. their mutual info with the label.
**"cfs"**: Apply Correlation-based Feature Selection.
**"fcbf"**: Apply Fast Correlation-Based Filter.
**"iwss"**: IWSS based algorithm |
| | normalize | \ | False | If standardization of features should be applied on each node with the samples that reach it |
-| \* | multiclass_strategy | {"ovo", "ovr"} | "ovo" | Strategy to use with multiclass datasets, **"ovo"**: one versus one. **"ovr"**: one versus rest |
+| \* | multiclass_strategy | {"ovo", "ovr"} | "ovo" | Strategy to use with multiclass datasets:
**"ovo"**: one versus one.
**"ovr"**: one versus rest |
\* Hyperparameter used by the support vector classifier of every node
diff --git a/setup.py b/setup.py
index b58d071..7333fb0 100644
--- a/setup.py
+++ b/setup.py
@@ -7,9 +7,8 @@ def readme():
return f.read()
-def get_data(field):
+def get_data(field, file_name="__init__.py"):
item = ""
- file_name = "_version.py" if field == "version" else "__init__.py"
with open(os.path.join("stree", file_name)) as f:
for line in f.readlines():
if line.startswith(f"__{field}__"):
@@ -21,9 +20,14 @@ def get_data(field):
return item
+def get_requirements():
+ with open("requirements.txt") as f:
+ return f.read().splitlines()
+
+
setuptools.setup(
name="STree",
- version=get_data("version"),
+ version=get_data("version", "_version.py"),
license=get_data("license"),
description="Oblique decision tree with svm nodes",
long_description=readme(),
@@ -46,7 +50,7 @@ setuptools.setup(
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Intended Audience :: Science/Research",
],
- install_requires=["scikit-learn", "mufs"],
+ install_requires=get_requirements(),
test_suite="stree.tests",
zip_safe=False,
)
diff --git a/stree/Strees.py b/stree/Strees.py
index 1f93252..23e0290 100644
--- a/stree/Strees.py
+++ b/stree/Strees.py
@@ -368,6 +368,21 @@ class Stree(BaseEstimator, ClassifierMixin):
)
def __predict_class(self, X: np.array) -> np.array:
+ """Compute the predicted class for the samples in X. Returns the number
+ of samples of each class in the corresponding leaf node.
+
+ Parameters
+ ----------
+ X : np.array
+ Array of samples
+
+ Returns
+ -------
+ np.array
+ Array of shape (n_samples, n_classes) with the number of samples
+ of each class in the corresponding leaf node
+ """
+
def compute_prediction(xp, indices, node):
if xp is None:
return
@@ -388,6 +403,25 @@ class Stree(BaseEstimator, ClassifierMixin):
return result
def check_predict(self, X) -> np.array:
+        """Check predict and predict_proba preconditions. If the input X is
+        not an np.array, convert it to one.
+
+ Parameters
+ ----------
+ X : np.ndarray
+ Array of samples
+
+ Returns
+ -------
+ np.array
+ Array of samples
+
+ Raises
+ ------
+ ValueError
+            If the number of features of X differs from the number of
+            features in the training data
+ """
check_is_fitted(self, ["tree_"])
# Input validation
X = check_array(X)
diff --git a/stree/_version.py b/stree/_version.py
index b3f9ac7..67bc602 100644
--- a/stree/_version.py
+++ b/stree/_version.py
@@ -1 +1 @@
-__version__ = "1.2.4"
+__version__ = "1.3.0"