mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-16 08:05:57 +00:00
Remove unneeded files from template
This commit is contained in:
33
examples/example.py
Normal file
33
examples/example.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import time
|
||||
from sklearn.model_selection import cross_val_score, StratifiedKFold
|
||||
from sklearn.preprocessing import KBinsDiscretizer
|
||||
from sklearn.datasets import load_wine
|
||||
from bayesclass.clfs import TAN
|
||||
import warnings
|
||||
|
||||
|
||||
# Warnings are not errors
|
||||
warnings.simplefilter("ignore")
|
||||
start = time.time()
|
||||
random_state = 17
|
||||
n_folds = 5
|
||||
print(f"Accuracy in {n_folds} folds stratified crossvalidation")
|
||||
dataset_start = time.time()
|
||||
dataset = load_wine()
|
||||
Xc = dataset.data
|
||||
enc = KBinsDiscretizer(encode="ordinal")
|
||||
X = enc.fit_transform(Xc)
|
||||
y = dataset.target
|
||||
clf = TAN(random_state=random_state)
|
||||
fit_params = dict(features=dataset.feature_names, class_name="class", head=0)
|
||||
kfold = StratifiedKFold(
|
||||
n_splits=n_folds, shuffle=True, random_state=random_state
|
||||
)
|
||||
score = cross_val_score(clf, X, y, cv=kfold, fit_params=fit_params)
|
||||
print(
|
||||
f"wine {'.' * 10}{score.mean():9.7f} "
|
||||
f"({time.time()-dataset_start:7.2f} seconds)"
|
||||
)
|
||||
clf.fit(X, y, **fit_params)
|
||||
clf.plot("TAN wine")
|
||||
print(f"Took {time.time()-start:.2f} seconds")
|
@@ -1,44 +0,0 @@
|
||||
"""
|
||||
============================
|
||||
Plotting Template Classifier
|
||||
============================
|
||||
|
||||
An example plot of :class:`bayesclass.TAN`
|
||||
"""
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
from bayesclass.bayesclass import TAN
|
||||
|
||||
X = [[0, 0], [1, 1]]
|
||||
y = [0, 1]
|
||||
clf = TAN()
|
||||
clf.fit(X, y)
|
||||
|
||||
rng = np.random.RandomState(13)
|
||||
X_test = rng.randint(2, size=(500, 2))
|
||||
y_pred = clf.predict(X_test)
|
||||
|
||||
X_0 = X_test[y_pred == 0]
|
||||
X_1 = X_test[y_pred == 1]
|
||||
|
||||
|
||||
p0 = plt.scatter(0, 0, c="red", s=100)
|
||||
p1 = plt.scatter(1, 1, c="blue", s=100)
|
||||
|
||||
ax0 = plt.scatter(X_0[:, 0], X_0[:, 1], c="crimson", s=50)
|
||||
ax1 = plt.scatter(X_1[:, 0], X_1[:, 1], c="deepskyblue", s=50)
|
||||
|
||||
leg = plt.legend(
|
||||
[p0, p1, ax0, ax1],
|
||||
["Point 0", "Point 1", "Class 0", "Class 1"],
|
||||
loc="upper left",
|
||||
fancybox=True,
|
||||
scatterpoints=1,
|
||||
)
|
||||
leg.get_frame().set_alpha(0.5)
|
||||
|
||||
plt.xlabel("Feature 1")
|
||||
plt.ylabel("Feature 2")
|
||||
plt.xlim([-0.5, 1.5])
|
||||
|
||||
plt.show()
|
Reference in New Issue
Block a user