Files
stree/main.py
Ricardo Montañana Gómez 3bdac9bd60 Complete source comments (#22)
* Add Hyperparameters description to README
Comment get_subspace method
Add environment info for binder (runtime.txt)

* Complete source comments
Change docstring type to numpy
update hyperparameters table and explanation

* Update Jupyter notebooks
2021-01-19 10:44:59 +01:00

30 lines
1008 B
Python

import time
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from stree import Stree
# Demo script: train Stree on the iris dataset twice -- once with
# max_features="auto" and once with the default max_features -- and print the
# training time and the train/test accuracy of each run.

# One seed shared by the split and both classifiers.
random_state = 1
X, y = load_iris(return_X_y=True)
# Hold out 30% of the samples for testing.
Xtrain, Xtest, ytrain, ytest = train_test_split(
    X, y, test_size=0.3, random_state=random_state
)

# --- Run 1: max_features="auto" (announced below as sqrt(n_features)) ---
# perf_counter() is a monotonic clock, so the elapsed time cannot be skewed
# by system clock adjustments.
now = time.perf_counter()
print("Predicting with max_features=sqrt(n_features)")
clf = Stree(C=0.01, random_state=random_state, max_features="auto")
clf.fit(Xtrain, ytrain)
print(f"Took {time.perf_counter() - now:.2f} seconds to train")
print(clf)
print(f"Classifier's accuracy (train): {clf.score(Xtrain, ytrain):.4f}")
print(f"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}")
print("=" * 40)

# --- Run 2: default max_features (announced below as n_features) ---
# Restart the timer: without this the second "Took ..." line would also count
# the first run's training, scoring and printing.
now = time.perf_counter()
print("Predicting with max_features=n_features")
clf = Stree(C=0.01, random_state=random_state)
clf.fit(Xtrain, ytrain)
print(f"Took {time.perf_counter() - now:.2f} seconds to train")
print(clf)
print(f"Classifier's accuracy (train): {clf.score(Xtrain, ytrain):.4f}")
print(f"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}")