mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-18 17:15:53 +00:00
Add example of usage
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -65,3 +65,5 @@ doc/generated/
|
|||||||
|
|
||||||
# PyBuilder
|
# PyBuilder
|
||||||
target/
|
target/
|
||||||
|
|
||||||
|
.ipynb_checkpoints
|
1
Makefile
1
Makefile
@@ -20,6 +20,7 @@ push: ## Push code with tags
|
|||||||
git push && git push --tags
|
git push && git push --tags
|
||||||
|
|
||||||
test: ## Run tests
|
test: ## Run tests
|
||||||
|
python -m doctest bayesclass/bayesclass.py
|
||||||
pytest
|
pytest
|
||||||
|
|
||||||
doc: ## Update documentation
|
doc: ## Update documentation
|
||||||
|
@@ -1,215 +0,0 @@
|
|||||||
RI,Na,Mg,Al,Si,'K',Ca,Ba,Fe,Type
|
|
||||||
35,11,11,16,41,29,18,0,0,0
|
|
||||||
22,3,11,23,40,25,13,0,0,1
|
|
||||||
35,21,11,26,23,26,6,0,0,0
|
|
||||||
3,41,5,28,48,0,3,0,0,2
|
|
||||||
68,4,0,13,0,10,31,0,6,3
|
|
||||||
22,9,6,27,42,25,19,1,5,3
|
|
||||||
34,26,11,5,41,2,20,0,0,1
|
|
||||||
46,20,6,23,38,24,20,0,0,0
|
|
||||||
8,34,0,43,45,5,20,2,2,4
|
|
||||||
35,20,12,23,20,24,8,0,6,3
|
|
||||||
22,24,11,27,28,16,3,0,0,3
|
|
||||||
32,4,7,17,46,27,20,0,6,3
|
|
||||||
63,21,11,9,10,10,28,0,0,0
|
|
||||||
57,32,11,5,7,9,24,0,0,1
|
|
||||||
23,20,11,32,22,26,4,0,5,1
|
|
||||||
27,26,11,30,22,27,3,0,0,3
|
|
||||||
28,41,0,38,41,0,13,3,4,4
|
|
||||||
22,9,7,27,43,30,3,0,0,3
|
|
||||||
50,23,0,32,41,17,30,0,0,5
|
|
||||||
40,16,7,27,40,26,19,1,0,3
|
|
||||||
58,19,11,12,18,13,27,0,5,0
|
|
||||||
66,0,0,37,17,33,31,0,6,3
|
|
||||||
48,16,11,15,20,29,20,0,5,3
|
|
||||||
33,25,11,19,34,25,3,0,5,0
|
|
||||||
52,12,4,42,15,32,25,1,8,5
|
|
||||||
10,22,11,27,42,15,3,0,0,3
|
|
||||||
14,11,11,37,39,30,3,0,0,3
|
|
||||||
24,41,0,38,43,0,10,3,3,4
|
|
||||||
25,22,11,27,36,25,3,0,0,3
|
|
||||||
16,24,11,23,25,25,4,0,0,1
|
|
||||||
23,16,11,30,42,30,3,0,0,3
|
|
||||||
45,24,8,28,14,25,20,0,0,1
|
|
||||||
0,41,0,1,48,0,1,0,0,2
|
|
||||||
22,26,11,28,22,30,3,0,0,3
|
|
||||||
33,17,11,23,41,25,5,0,5,0
|
|
||||||
11,9,11,29,42,30,3,0,6,0
|
|
||||||
14,11,11,30,40,29,3,0,6,0
|
|
||||||
30,4,6,30,40,30,20,0,0,3
|
|
||||||
23,12,11,27,41,30,3,1,6,3
|
|
||||||
5,37,7,39,20,33,3,0,0,3
|
|
||||||
37,9,11,23,40,29,18,0,0,0
|
|
||||||
38,17,11,13,42,25,4,0,5,3
|
|
||||||
22,17,11,28,35,26,3,0,0,3
|
|
||||||
14,22,8,27,42,16,3,0,0,3
|
|
||||||
49,26,11,8,24,11,20,1,6,1
|
|
||||||
33,9,11,20,42,26,17,0,0,0
|
|
||||||
6,31,5,44,0,34,0,4,0,5
|
|
||||||
33,21,11,23,22,25,3,0,0,0
|
|
||||||
35,19,11,23,39,26,10,0,0,0
|
|
||||||
60,20,11,17,31,23,10,0,0,3
|
|
||||||
33,7,11,23,45,26,14,0,3,0
|
|
||||||
48,20,11,13,35,25,6,1,5,3
|
|
||||||
32,23,11,16,42,26,3,0,0,0
|
|
||||||
14,20,11,28,42,30,3,0,0,3
|
|
||||||
22,41,0,43,38,0,23,2,0,4
|
|
||||||
31,17,11,30,30,23,8,0,3,0
|
|
||||||
64,41,5,38,1,32,26,0,0,4
|
|
||||||
55,27,11,16,10,6,22,0,0,0
|
|
||||||
34,26,11,15,33,9,16,0,0,1
|
|
||||||
48,25,12,23,22,21,4,0,0,3
|
|
||||||
48,26,12,23,10,23,4,0,6,3
|
|
||||||
49,26,11,15,23,10,18,0,0,0
|
|
||||||
9,23,11,20,31,25,15,0,0,0
|
|
||||||
63,35,11,2,7,9,24,0,0,0
|
|
||||||
64,27,11,2,6,6,28,0,5,0
|
|
||||||
8,28,0,43,42,10,23,3,2,4
|
|
||||||
49,20,7,23,21,26,20,0,5,0
|
|
||||||
62,35,11,13,5,13,20,0,7,1
|
|
||||||
68,0,0,41,0,26,31,4,6,3
|
|
||||||
58,19,11,12,20,13,27,0,5,0
|
|
||||||
42,41,5,30,22,0,21,0,0,2
|
|
||||||
48,26,11,23,22,25,3,0,5,3
|
|
||||||
48,41,2,30,22,0,27,0,0,2
|
|
||||||
42,22,12,26,20,24,4,0,5,3
|
|
||||||
64,23,11,9,10,10,28,0,2,0
|
|
||||||
22,26,11,27,22,29,3,0,0,3
|
|
||||||
33,7,11,27,42,25,14,0,0,0
|
|
||||||
2,17,11,16,41,27,4,0,6,0
|
|
||||||
22,18,10,23,41,21,15,0,0,1
|
|
||||||
29,16,11,23,41,25,6,0,0,0
|
|
||||||
33,11,11,23,41,26,15,0,0,0
|
|
||||||
31,23,11,21,42,24,3,0,0,0
|
|
||||||
57,39,12,11,5,0,24,0,0,1
|
|
||||||
34,21,9,23,31,26,15,0,0,0
|
|
||||||
58,1,5,29,39,17,30,0,0,5
|
|
||||||
38,27,12,28,8,23,3,0,5,3
|
|
||||||
68,8,0,6,10,2,31,0,0,3
|
|
||||||
33,11,11,27,31,23,10,0,5,0
|
|
||||||
33,17,11,20,41,30,13,0,0,0
|
|
||||||
60,27,3,23,17,15,29,0,0,3
|
|
||||||
22,41,0,36,42,0,17,3,0,4
|
|
||||||
35,9,11,19,40,27,18,0,6,0
|
|
||||||
58,20,11,12,18,13,27,0,5,0
|
|
||||||
49,27,5,19,31,0,27,0,0,2
|
|
||||||
6,41,0,43,46,0,5,2,0,4
|
|
||||||
59,26,11,12,10,11,24,0,3,0
|
|
||||||
30,41,0,33,41,0,17,3,0,4
|
|
||||||
51,29,3,30,6,16,29,0,5,3
|
|
||||||
15,16,11,27,42,16,3,0,0,3
|
|
||||||
48,20,12,19,22,26,6,0,0,3
|
|
||||||
33,29,11,23,30,18,3,0,0,0
|
|
||||||
23,23,11,28,22,30,3,0,6,3
|
|
||||||
64,41,5,23,1,14,17,3,0,4
|
|
||||||
24,41,0,38,42,0,5,3,0,4
|
|
||||||
22,41,0,38,42,0,4,3,0,4
|
|
||||||
4,17,0,44,2,35,2,0,0,5
|
|
||||||
27,17,11,33,28,30,3,0,0,3
|
|
||||||
30,41,0,43,43,0,20,2,0,4
|
|
||||||
49,26,8,20,13,26,20,0,0,0
|
|
||||||
49,8,0,30,47,15,30,0,0,5
|
|
||||||
40,8,6,11,47,15,23,0,5,3
|
|
||||||
18,41,0,43,43,0,18,2,0,4
|
|
||||||
49,29,11,19,13,2,20,0,0,0
|
|
||||||
22,41,0,38,46,0,9,3,0,4
|
|
||||||
26,15,11,23,22,26,19,0,0,1
|
|
||||||
64,26,8,20,22,26,20,0,0,4
|
|
||||||
54,26,5,30,15,22,24,1,5,3
|
|
||||||
47,39,7,43,4,34,0,3,0,4
|
|
||||||
40,27,0,3,47,0,29,0,0,3
|
|
||||||
34,5,6,23,46,25,20,0,6,0
|
|
||||||
23,17,7,20,40,26,19,0,6,3
|
|
||||||
13,16,11,24,43,30,3,0,6,0
|
|
||||||
66,27,7,18,3,5,29,0,0,3
|
|
||||||
68,27,7,6,2,5,30,0,0,3
|
|
||||||
56,17,1,27,45,10,30,0,6,5
|
|
||||||
33,15,11,23,42,23,4,0,5,0
|
|
||||||
22,2,0,19,48,34,20,0,0,4
|
|
||||||
21,34,0,43,22,5,20,3,0,4
|
|
||||||
55,26,13,14,7,2,18,0,0,0
|
|
||||||
33,7,11,23,43,26,10,0,0,0
|
|
||||||
14,17,11,28,42,30,3,0,0,3
|
|
||||||
23,11,11,28,44,30,3,0,0,3
|
|
||||||
53,41,0,38,45,0,8,3,0,4
|
|
||||||
33,9,11,23,42,26,18,0,5,0
|
|
||||||
65,26,0,29,18,15,30,0,0,5
|
|
||||||
33,20,11,13,42,25,3,0,0,0
|
|
||||||
33,26,11,18,41,26,3,0,0,0
|
|
||||||
28,16,11,29,40,26,3,0,0,3
|
|
||||||
61,27,11,7,6,11,25,0,0,0
|
|
||||||
14,20,11,28,40,30,3,0,4,3
|
|
||||||
35,9,11,17,42,26,18,0,0,0
|
|
||||||
49,29,11,23,8,21,18,1,0,0
|
|
||||||
49,27,11,23,6,10,17,2,0,0
|
|
||||||
23,15,0,35,47,33,28,0,0,5
|
|
||||||
22,24,11,29,40,26,3,0,0,3
|
|
||||||
48,16,11,29,22,26,14,0,5,3
|
|
||||||
27,27,11,34,12,29,3,0,0,3
|
|
||||||
54,27,5,27,10,19,27,0,5,3
|
|
||||||
12,41,11,30,8,11,3,0,5,3
|
|
||||||
40,25,12,19,22,26,3,0,0,3
|
|
||||||
1,27,7,34,35,34,0,3,0,4
|
|
||||||
63,35,11,9,5,0,25,0,0,0
|
|
||||||
66,27,0,23,3,13,31,0,5,3
|
|
||||||
40,24,11,22,34,21,3,0,0,3
|
|
||||||
22,25,9,23,23,21,17,0,0,1
|
|
||||||
33,11,11,23,41,27,15,0,0,0
|
|
||||||
6,41,0,43,46,0,4,2,0,4
|
|
||||||
49,9,5,35,26,26,28,0,0,5
|
|
||||||
49,41,11,0,10,1,20,0,0,0
|
|
||||||
48,22,11,23,22,26,6,0,0,3
|
|
||||||
66,0,0,8,42,0,31,0,0,3
|
|
||||||
59,26,11,12,7,13,24,0,5,0
|
|
||||||
15,41,0,43,43,0,18,2,4,4
|
|
||||||
4,17,0,44,2,35,2,0,0,5
|
|
||||||
68,0,0,8,42,0,31,0,0,3
|
|
||||||
63,35,11,2,7,9,24,0,0,0
|
|
||||||
33,11,11,16,42,25,14,0,0,0
|
|
||||||
48,12,11,21,22,27,18,0,6,3
|
|
||||||
22,25,11,22,35,30,3,0,0,3
|
|
||||||
15,41,0,43,42,1,20,2,0,4
|
|
||||||
23,16,11,23,31,25,16,0,0,3
|
|
||||||
12,20,11,29,42,3,5,0,5,3
|
|
||||||
67,29,11,7,6,1,27,0,5,0
|
|
||||||
44,41,0,34,39,34,0,4,0,4
|
|
||||||
49,33,11,23,13,25,4,0,0,0
|
|
||||||
18,28,5,33,42,0,17,3,0,4
|
|
||||||
61,41,11,12,5,11,20,0,0,0
|
|
||||||
41,16,11,23,40,26,6,0,0,0
|
|
||||||
58,0,4,29,45,26,30,0,0,5
|
|
||||||
49,41,0,3,46,0,29,0,0,2
|
|
||||||
19,17,11,27,40,26,3,0,0,3
|
|
||||||
22,25,11,28,24,30,3,0,5,3
|
|
||||||
36,26,8,30,9,25,19,0,4,1
|
|
||||||
63,41,0,13,25,7,30,0,4,3
|
|
||||||
35,9,11,23,40,25,18,0,0,0
|
|
||||||
28,36,0,39,44,0,17,3,0,4
|
|
||||||
31,9,11,23,34,26,18,0,0,0
|
|
||||||
39,25,6,19,36,24,20,0,0,0
|
|
||||||
23,22,11,23,27,25,8,0,5,1
|
|
||||||
52,25,0,24,19,15,30,0,0,5
|
|
||||||
49,26,11,23,10,24,20,0,0,0
|
|
||||||
34,21,6,23,41,21,20,0,4,0
|
|
||||||
49,30,5,29,22,0,24,0,0,2
|
|
||||||
8,41,0,43,42,1,20,2,0,4
|
|
||||||
49,34,0,40,31,0,29,0,0,2
|
|
||||||
48,17,11,13,20,29,20,0,5,3
|
|
||||||
14,17,11,27,42,30,3,0,0,3
|
|
||||||
14,22,11,27,42,26,3,0,0,3
|
|
||||||
22,6,11,36,42,28,3,0,4,3
|
|
||||||
23,16,11,30,40,29,3,0,5,3
|
|
||||||
25,24,11,30,22,30,3,0,0,3
|
|
||||||
48,20,9,19,28,25,20,0,5,0
|
|
||||||
34,26,11,28,11,26,19,0,0,1
|
|
||||||
34,12,11,20,40,26,15,1,5,0
|
|
||||||
48,24,11,27,20,21,16,0,0,3
|
|
||||||
29,25,11,17,38,20,6,0,0,0
|
|
||||||
21,35,0,43,46,1,20,2,4,4
|
|
||||||
19,26,11,28,41,16,3,0,0,0
|
|
||||||
33,11,11,20,42,26,5,0,0,0
|
|
||||||
16,25,10,20,26,26,4,0,0,1
|
|
||||||
14,15,11,41,24,30,3,0,0,3
|
|
||||||
18,29,11,22,40,15,3,0,5,3
|
|
||||||
25,9,7,30,42,30,14,0,0,3
|
|
||||||
48,34,5,30,25,0,21,0,0,2
|
|
|
@@ -7,13 +7,8 @@ import pandas as pd
|
|||||||
from sklearn.base import ClassifierMixin, BaseEstimator
|
from sklearn.base import ClassifierMixin, BaseEstimator
|
||||||
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
||||||
from sklearn.utils.multiclass import unique_labels
|
from sklearn.utils.multiclass import unique_labels
|
||||||
from sklearn.exceptions import NotFittedError
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pgmpy.estimators import (
|
from pgmpy.estimators import TreeSearch, BayesianEstimator
|
||||||
TreeSearch,
|
|
||||||
BayesianEstimator,
|
|
||||||
# MaximumLikelihoodEstimator,
|
|
||||||
)
|
|
||||||
from pgmpy.models import BayesianNetwork
|
from pgmpy.models import BayesianNetwork
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
@@ -39,32 +34,13 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, simple_init=False, show_progress=False, random_state=None
|
self, simple_init=True, show_progress=False, random_state=None
|
||||||
):
|
):
|
||||||
self.simple_init = simple_init
|
self.simple_init = simple_init
|
||||||
self.show_progress = show_progress
|
self.show_progress = show_progress
|
||||||
self.random_state = random_state
|
self.random_state = random_state
|
||||||
|
|
||||||
def fit(self, X, y, **kwargs):
|
def __check_params_fit(self, X, y, kwargs):
|
||||||
"""A reference implementation of a fitting function for a classifier.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
X : array-like, shape (n_samples, n_features)
|
|
||||||
The training input samples.
|
|
||||||
y : array-like, shape (n_samples,)
|
|
||||||
The target values. An array of int.
|
|
||||||
**kwargs : dict
|
|
||||||
class_name : str (default='class') Name of the class column
|
|
||||||
features: list (default=None) List of features
|
|
||||||
head: int (default=None) Index of the head node. Default value
|
|
||||||
gets the node with the highest sum of weights (mutual_info)
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
self : object
|
|
||||||
Returns self.
|
|
||||||
"""
|
|
||||||
# Check that X and y have correct shape
|
# Check that X and y have correct shape
|
||||||
X, y = check_X_y(X, y)
|
X, y = check_X_y(X, y)
|
||||||
# Store the classes seen during fit
|
# Store the classes seen during fit
|
||||||
@@ -90,16 +66,55 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
if self.head_ is not None and self.head_ >= len(self.features_):
|
if self.head_ is not None and self.head_ >= len(self.features_):
|
||||||
raise ValueError("Head index out of range")
|
raise ValueError("Head index out of range")
|
||||||
|
|
||||||
|
def fit(self, X, y, **kwargs):
|
||||||
|
"""A reference implementation of a fitting function for a classifier.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : array-like, shape (n_samples, n_features)
|
||||||
|
The training input samples.
|
||||||
|
y : array-like, shape (n_samples,)
|
||||||
|
The target values. An array of int.
|
||||||
|
**kwargs : dict
|
||||||
|
class_name : str (default='class') Name of the class column
|
||||||
|
features: list (default=None) List of features
|
||||||
|
head: int (default=None) Index of the head node. Default value
|
||||||
|
gets the node with the highest sum of weights (mutual_info)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
self : object
|
||||||
|
Returns self.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import numpy as np
|
||||||
|
>>> import pandas as pd
|
||||||
|
>>> from bayesclass import TAN
|
||||||
|
>>> features = ['A', 'B', 'C', 'D', 'E']
|
||||||
|
>>> np.random.seed(17)
|
||||||
|
>>> values = pd.DataFrame(np.random.randint(low=0, high=2,
|
||||||
|
... size=(1000, 5)), columns=features)
|
||||||
|
>>> train_data = values[:800]
|
||||||
|
>>> train_y = train_data['E']
|
||||||
|
>>> predict_data = values[800:]
|
||||||
|
>>> train_data = train_data.drop('E', axis=1)
|
||||||
|
>>> model = TAN(random_state=17)
|
||||||
|
>>> features.remove('E')
|
||||||
|
>>> model.fit(train_data, train_y, features=features, class_name='E')
|
||||||
|
TAN(random_state=17)
|
||||||
|
"""
|
||||||
|
self.__check_params_fit(X, y, kwargs)
|
||||||
|
# Store the information needed to build the model
|
||||||
self.X_ = X
|
self.X_ = X
|
||||||
self.y_ = y.astype(int)
|
self.y_ = y.astype(int)
|
||||||
self.dataset_ = pd.DataFrame(
|
self.dataset_ = pd.DataFrame(
|
||||||
self.X_, columns=self.features_, dtype="int16"
|
self.X_, columns=self.features_, dtype="int16"
|
||||||
)
|
)
|
||||||
self.dataset_[self.class_name_] = self.y_
|
self.dataset_[self.class_name_] = self.y_
|
||||||
try:
|
# Build the DAG
|
||||||
check_is_fitted(self, ["X_", "y_", "fitted_"])
|
|
||||||
except NotFittedError:
|
|
||||||
self.__build()
|
self.__build()
|
||||||
|
# Train the model
|
||||||
self.__train()
|
self.__train()
|
||||||
self.fitted_ = True
|
self.fitted_ = True
|
||||||
# Return the classifier
|
# Return the classifier
|
||||||
@@ -145,24 +160,23 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
# learn graph structure
|
# learn graph structure
|
||||||
root_node = None if self.head_ is None else self.features_[self.head_]
|
root_node = None if self.head_ is None else self.features_[self.head_]
|
||||||
est = TreeSearch(self.dataset_, root_node=root_node)
|
est = TreeSearch(self.dataset_, root_node=root_node)
|
||||||
dag = est.estimate(
|
self.dag_ = est.estimate(
|
||||||
estimator_type="tan",
|
estimator_type="tan",
|
||||||
class_node=self.class_name_,
|
class_node=self.class_name_,
|
||||||
show_progress=self.show_progress,
|
show_progress=self.show_progress,
|
||||||
)
|
)
|
||||||
if self.head_ is None:
|
if self.head_ is None:
|
||||||
self.head_ = est.root_node
|
self.head_ = est.root_node
|
||||||
self.model_ = BayesianNetwork(
|
|
||||||
dag.edges(), show_progress=self.show_progress
|
|
||||||
)
|
|
||||||
|
|
||||||
def __train(self):
|
def __train(self):
|
||||||
|
self.model_ = BayesianNetwork(
|
||||||
|
self.dag_.edges(), show_progress=self.show_progress
|
||||||
|
)
|
||||||
self.model_.fit(
|
self.model_.fit(
|
||||||
self.dataset_,
|
self.dataset_,
|
||||||
# estimator=MaximumLikelihoodEstimator,
|
# estimator=MaximumLikelihoodEstimator,
|
||||||
estimator=BayesianEstimator,
|
estimator=BayesianEstimator,
|
||||||
prior_type="K2",
|
prior_type="K2",
|
||||||
n_jobs=1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def plot(self, title=""):
|
def plot(self, title=""):
|
||||||
@@ -203,7 +217,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
>>> train_data = values[:800]
|
>>> train_data = values[:800]
|
||||||
>>> train_y = train_data['E']
|
>>> train_y = train_data['E']
|
||||||
>>> predict_data = values[800:]
|
>>> predict_data = values[800:]
|
||||||
>>> train_data.drop('E', axis=1, inplace=True)
|
>>> train_data = train_data.drop('E', axis=1)
|
||||||
>>> model = TAN(random_state=17)
|
>>> model = TAN(random_state=17)
|
||||||
>>> features.remove('E')
|
>>> features.remove('E')
|
||||||
>>> model.fit(train_data, train_y, features=features, class_name='E')
|
>>> model.fit(train_data, train_y, features=features, class_name='E')
|
||||||
|
@@ -1 +0,0 @@
|
|||||||
m0 <- ulam(alist(height ~ dnorm(mu, sigma), mu <- a, a ~ dnorm(186, 10), sigma ~ dexp(1)), data = d, chains = 4, iter = 2000, cores = 4, log_lik=TRUE)
|
|
@@ -19,7 +19,7 @@ def data():
|
|||||||
def test_TAN_constructor():
|
def test_TAN_constructor():
|
||||||
clf = TAN()
|
clf = TAN()
|
||||||
# Test default values of hyperparameters
|
# Test default values of hyperparameters
|
||||||
assert not clf.simple_init
|
assert clf.simple_init
|
||||||
assert not clf.show_progress
|
assert not clf.show_progress
|
||||||
assert clf.random_state is None
|
assert clf.random_state is None
|
||||||
clf = TAN(simple_init=True, show_progress=True, random_state=17)
|
clf = TAN(simple_init=True, show_progress=True, random_state=17)
|
||||||
@@ -34,6 +34,14 @@ def test_TAN_random_head(data):
|
|||||||
assert clf.head_ == 3
|
assert clf.head_ == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_TAN_dag_initializer(data):
|
||||||
|
clf_not_simple = TAN(simple_init=False)
|
||||||
|
clf_simple = TAN(simple_init=True)
|
||||||
|
clf_not_simple.fit(*data, head=0)
|
||||||
|
clf_simple.fit(*data, head=0)
|
||||||
|
assert clf_simple.dag_.edges == clf_not_simple.dag_.edges
|
||||||
|
|
||||||
|
|
||||||
def test_TAN_classifier(data):
|
def test_TAN_classifier(data):
|
||||||
clf = TAN()
|
clf = TAN()
|
||||||
|
|
||||||
|
24
example.py
Normal file
24
example.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
from benchmark import Discretizer
|
||||||
|
from bayesclass import TAN
|
||||||
|
import sys
|
||||||
|
from sklearn.model_selection import cross_val_score, StratifiedKFold
|
||||||
|
|
||||||
|
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
print("Usage: python3 example.py <dataset> [n_folds]")
|
||||||
|
exit(1)
|
||||||
|
random_state = 17
|
||||||
|
name = sys.argv[1]
|
||||||
|
n_folds = int(sys.argv[2]) if len(sys.argv) == 3 else 5
|
||||||
|
dt = Discretizer()
|
||||||
|
X, y = dt.load(name)
|
||||||
|
clf = TAN(random_state=random_state)
|
||||||
|
fit_params = dict(
|
||||||
|
features=dt.get_features(), class_name=dt.get_class_name(), head=0
|
||||||
|
)
|
||||||
|
kfold = StratifiedKFold(
|
||||||
|
n_splits=n_folds, shuffle=True, random_state=random_state
|
||||||
|
)
|
||||||
|
score = cross_val_score(clf, X, y, cv=kfold, fit_params=fit_params)
|
||||||
|
print(f"Accuracy in {n_folds} folds stratified crossvalidation")
|
||||||
|
print(f"{name}{'.' * 10}{score.mean():9.7f}")
|
215
glass.csv
215
glass.csv
@@ -1,215 +0,0 @@
|
|||||||
RI,Na,Mg,Al,Si,'K',Ca,Ba,Fe,Type
|
|
||||||
35,11,11,16,41,29,18,0,0,0
|
|
||||||
22,3,11,23,40,25,13,0,0,1
|
|
||||||
35,21,11,26,23,26,6,0,0,0
|
|
||||||
3,41,5,28,48,0,3,0,0,2
|
|
||||||
68,4,0,13,0,10,31,0,6,3
|
|
||||||
22,9,6,27,42,25,19,1,5,3
|
|
||||||
34,26,11,5,41,2,20,0,0,1
|
|
||||||
46,20,6,23,38,24,20,0,0,0
|
|
||||||
8,34,0,43,45,5,20,2,2,4
|
|
||||||
35,20,12,23,20,24,8,0,6,3
|
|
||||||
22,24,11,27,28,16,3,0,0,3
|
|
||||||
32,4,7,17,46,27,20,0,6,3
|
|
||||||
63,21,11,9,10,10,28,0,0,0
|
|
||||||
57,32,11,5,7,9,24,0,0,1
|
|
||||||
23,20,11,32,22,26,4,0,5,1
|
|
||||||
27,26,11,30,22,27,3,0,0,3
|
|
||||||
28,41,0,38,41,0,13,3,4,4
|
|
||||||
22,9,7,27,43,30,3,0,0,3
|
|
||||||
50,23,0,32,41,17,30,0,0,5
|
|
||||||
40,16,7,27,40,26,19,1,0,3
|
|
||||||
58,19,11,12,18,13,27,0,5,0
|
|
||||||
66,0,0,37,17,33,31,0,6,3
|
|
||||||
48,16,11,15,20,29,20,0,5,3
|
|
||||||
33,25,11,19,34,25,3,0,5,0
|
|
||||||
52,12,4,42,15,32,25,1,8,5
|
|
||||||
10,22,11,27,42,15,3,0,0,3
|
|
||||||
14,11,11,37,39,30,3,0,0,3
|
|
||||||
24,41,0,38,43,0,10,3,3,4
|
|
||||||
25,22,11,27,36,25,3,0,0,3
|
|
||||||
16,24,11,23,25,25,4,0,0,1
|
|
||||||
23,16,11,30,42,30,3,0,0,3
|
|
||||||
45,24,8,28,14,25,20,0,0,1
|
|
||||||
0,41,0,1,48,0,1,0,0,2
|
|
||||||
22,26,11,28,22,30,3,0,0,3
|
|
||||||
33,17,11,23,41,25,5,0,5,0
|
|
||||||
11,9,11,29,42,30,3,0,6,0
|
|
||||||
14,11,11,30,40,29,3,0,6,0
|
|
||||||
30,4,6,30,40,30,20,0,0,3
|
|
||||||
23,12,11,27,41,30,3,1,6,3
|
|
||||||
5,37,7,39,20,33,3,0,0,3
|
|
||||||
37,9,11,23,40,29,18,0,0,0
|
|
||||||
38,17,11,13,42,25,4,0,5,3
|
|
||||||
22,17,11,28,35,26,3,0,0,3
|
|
||||||
14,22,8,27,42,16,3,0,0,3
|
|
||||||
49,26,11,8,24,11,20,1,6,1
|
|
||||||
33,9,11,20,42,26,17,0,0,0
|
|
||||||
6,31,5,44,0,34,0,4,0,5
|
|
||||||
33,21,11,23,22,25,3,0,0,0
|
|
||||||
35,19,11,23,39,26,10,0,0,0
|
|
||||||
60,20,11,17,31,23,10,0,0,3
|
|
||||||
33,7,11,23,45,26,14,0,3,0
|
|
||||||
48,20,11,13,35,25,6,1,5,3
|
|
||||||
32,23,11,16,42,26,3,0,0,0
|
|
||||||
14,20,11,28,42,30,3,0,0,3
|
|
||||||
22,41,0,43,38,0,23,2,0,4
|
|
||||||
31,17,11,30,30,23,8,0,3,0
|
|
||||||
64,41,5,38,1,32,26,0,0,4
|
|
||||||
55,27,11,16,10,6,22,0,0,0
|
|
||||||
34,26,11,15,33,9,16,0,0,1
|
|
||||||
48,25,12,23,22,21,4,0,0,3
|
|
||||||
48,26,12,23,10,23,4,0,6,3
|
|
||||||
49,26,11,15,23,10,18,0,0,0
|
|
||||||
9,23,11,20,31,25,15,0,0,0
|
|
||||||
63,35,11,2,7,9,24,0,0,0
|
|
||||||
64,27,11,2,6,6,28,0,5,0
|
|
||||||
8,28,0,43,42,10,23,3,2,4
|
|
||||||
49,20,7,23,21,26,20,0,5,0
|
|
||||||
62,35,11,13,5,13,20,0,7,1
|
|
||||||
68,0,0,41,0,26,31,4,6,3
|
|
||||||
58,19,11,12,20,13,27,0,5,0
|
|
||||||
42,41,5,30,22,0,21,0,0,2
|
|
||||||
48,26,11,23,22,25,3,0,5,3
|
|
||||||
48,41,2,30,22,0,27,0,0,2
|
|
||||||
42,22,12,26,20,24,4,0,5,3
|
|
||||||
64,23,11,9,10,10,28,0,2,0
|
|
||||||
22,26,11,27,22,29,3,0,0,3
|
|
||||||
33,7,11,27,42,25,14,0,0,0
|
|
||||||
2,17,11,16,41,27,4,0,6,0
|
|
||||||
22,18,10,23,41,21,15,0,0,1
|
|
||||||
29,16,11,23,41,25,6,0,0,0
|
|
||||||
33,11,11,23,41,26,15,0,0,0
|
|
||||||
31,23,11,21,42,24,3,0,0,0
|
|
||||||
57,39,12,11,5,0,24,0,0,1
|
|
||||||
34,21,9,23,31,26,15,0,0,0
|
|
||||||
58,1,5,29,39,17,30,0,0,5
|
|
||||||
38,27,12,28,8,23,3,0,5,3
|
|
||||||
68,8,0,6,10,2,31,0,0,3
|
|
||||||
33,11,11,27,31,23,10,0,5,0
|
|
||||||
33,17,11,20,41,30,13,0,0,0
|
|
||||||
60,27,3,23,17,15,29,0,0,3
|
|
||||||
22,41,0,36,42,0,17,3,0,4
|
|
||||||
35,9,11,19,40,27,18,0,6,0
|
|
||||||
58,20,11,12,18,13,27,0,5,0
|
|
||||||
49,27,5,19,31,0,27,0,0,2
|
|
||||||
6,41,0,43,46,0,5,2,0,4
|
|
||||||
59,26,11,12,10,11,24,0,3,0
|
|
||||||
30,41,0,33,41,0,17,3,0,4
|
|
||||||
51,29,3,30,6,16,29,0,5,3
|
|
||||||
15,16,11,27,42,16,3,0,0,3
|
|
||||||
48,20,12,19,22,26,6,0,0,3
|
|
||||||
33,29,11,23,30,18,3,0,0,0
|
|
||||||
23,23,11,28,22,30,3,0,6,3
|
|
||||||
64,41,5,23,1,14,17,3,0,4
|
|
||||||
24,41,0,38,42,0,5,3,0,4
|
|
||||||
22,41,0,38,42,0,4,3,0,4
|
|
||||||
4,17,0,44,2,35,2,0,0,5
|
|
||||||
27,17,11,33,28,30,3,0,0,3
|
|
||||||
30,41,0,43,43,0,20,2,0,4
|
|
||||||
49,26,8,20,13,26,20,0,0,0
|
|
||||||
49,8,0,30,47,15,30,0,0,5
|
|
||||||
40,8,6,11,47,15,23,0,5,3
|
|
||||||
18,41,0,43,43,0,18,2,0,4
|
|
||||||
49,29,11,19,13,2,20,0,0,0
|
|
||||||
22,41,0,38,46,0,9,3,0,4
|
|
||||||
26,15,11,23,22,26,19,0,0,1
|
|
||||||
64,26,8,20,22,26,20,0,0,4
|
|
||||||
54,26,5,30,15,22,24,1,5,3
|
|
||||||
47,39,7,43,4,34,0,3,0,4
|
|
||||||
40,27,0,3,47,0,29,0,0,3
|
|
||||||
34,5,6,23,46,25,20,0,6,0
|
|
||||||
23,17,7,20,40,26,19,0,6,3
|
|
||||||
13,16,11,24,43,30,3,0,6,0
|
|
||||||
66,27,7,18,3,5,29,0,0,3
|
|
||||||
68,27,7,6,2,5,30,0,0,3
|
|
||||||
56,17,1,27,45,10,30,0,6,5
|
|
||||||
33,15,11,23,42,23,4,0,5,0
|
|
||||||
22,2,0,19,48,34,20,0,0,4
|
|
||||||
21,34,0,43,22,5,20,3,0,4
|
|
||||||
55,26,13,14,7,2,18,0,0,0
|
|
||||||
33,7,11,23,43,26,10,0,0,0
|
|
||||||
14,17,11,28,42,30,3,0,0,3
|
|
||||||
23,11,11,28,44,30,3,0,0,3
|
|
||||||
53,41,0,38,45,0,8,3,0,4
|
|
||||||
33,9,11,23,42,26,18,0,5,0
|
|
||||||
65,26,0,29,18,15,30,0,0,5
|
|
||||||
33,20,11,13,42,25,3,0,0,0
|
|
||||||
33,26,11,18,41,26,3,0,0,0
|
|
||||||
28,16,11,29,40,26,3,0,0,3
|
|
||||||
61,27,11,7,6,11,25,0,0,0
|
|
||||||
14,20,11,28,40,30,3,0,4,3
|
|
||||||
35,9,11,17,42,26,18,0,0,0
|
|
||||||
49,29,11,23,8,21,18,1,0,0
|
|
||||||
49,27,11,23,6,10,17,2,0,0
|
|
||||||
23,15,0,35,47,33,28,0,0,5
|
|
||||||
22,24,11,29,40,26,3,0,0,3
|
|
||||||
48,16,11,29,22,26,14,0,5,3
|
|
||||||
27,27,11,34,12,29,3,0,0,3
|
|
||||||
54,27,5,27,10,19,27,0,5,3
|
|
||||||
12,41,11,30,8,11,3,0,5,3
|
|
||||||
40,25,12,19,22,26,3,0,0,3
|
|
||||||
1,27,7,34,35,34,0,3,0,4
|
|
||||||
63,35,11,9,5,0,25,0,0,0
|
|
||||||
66,27,0,23,3,13,31,0,5,3
|
|
||||||
40,24,11,22,34,21,3,0,0,3
|
|
||||||
22,25,9,23,23,21,17,0,0,1
|
|
||||||
33,11,11,23,41,27,15,0,0,0
|
|
||||||
6,41,0,43,46,0,4,2,0,4
|
|
||||||
49,9,5,35,26,26,28,0,0,5
|
|
||||||
49,41,11,0,10,1,20,0,0,0
|
|
||||||
48,22,11,23,22,26,6,0,0,3
|
|
||||||
66,0,0,8,42,0,31,0,0,3
|
|
||||||
59,26,11,12,7,13,24,0,5,0
|
|
||||||
15,41,0,43,43,0,18,2,4,4
|
|
||||||
4,17,0,44,2,35,2,0,0,5
|
|
||||||
68,0,0,8,42,0,31,0,0,3
|
|
||||||
63,35,11,2,7,9,24,0,0,0
|
|
||||||
33,11,11,16,42,25,14,0,0,0
|
|
||||||
48,12,11,21,22,27,18,0,6,3
|
|
||||||
22,25,11,22,35,30,3,0,0,3
|
|
||||||
15,41,0,43,42,1,20,2,0,4
|
|
||||||
23,16,11,23,31,25,16,0,0,3
|
|
||||||
12,20,11,29,42,3,5,0,5,3
|
|
||||||
67,29,11,7,6,1,27,0,5,0
|
|
||||||
44,41,0,34,39,34,0,4,0,4
|
|
||||||
49,33,11,23,13,25,4,0,0,0
|
|
||||||
18,28,5,33,42,0,17,3,0,4
|
|
||||||
61,41,11,12,5,11,20,0,0,0
|
|
||||||
41,16,11,23,40,26,6,0,0,0
|
|
||||||
58,0,4,29,45,26,30,0,0,5
|
|
||||||
49,41,0,3,46,0,29,0,0,2
|
|
||||||
19,17,11,27,40,26,3,0,0,3
|
|
||||||
22,25,11,28,24,30,3,0,5,3
|
|
||||||
36,26,8,30,9,25,19,0,4,1
|
|
||||||
63,41,0,13,25,7,30,0,4,3
|
|
||||||
35,9,11,23,40,25,18,0,0,0
|
|
||||||
28,36,0,39,44,0,17,3,0,4
|
|
||||||
31,9,11,23,34,26,18,0,0,0
|
|
||||||
39,25,6,19,36,24,20,0,0,0
|
|
||||||
23,22,11,23,27,25,8,0,5,1
|
|
||||||
52,25,0,24,19,15,30,0,0,5
|
|
||||||
49,26,11,23,10,24,20,0,0,0
|
|
||||||
34,21,6,23,41,21,20,0,4,0
|
|
||||||
49,30,5,29,22,0,24,0,0,2
|
|
||||||
8,41,0,43,42,1,20,2,0,4
|
|
||||||
49,34,0,40,31,0,29,0,0,2
|
|
||||||
48,17,11,13,20,29,20,0,5,3
|
|
||||||
14,17,11,27,42,30,3,0,0,3
|
|
||||||
14,22,11,27,42,26,3,0,0,3
|
|
||||||
22,6,11,36,42,28,3,0,4,3
|
|
||||||
23,16,11,30,40,29,3,0,5,3
|
|
||||||
25,24,11,30,22,30,3,0,0,3
|
|
||||||
48,20,9,19,28,25,20,0,5,0
|
|
||||||
34,26,11,28,11,26,19,0,0,1
|
|
||||||
34,12,11,20,40,26,15,1,5,0
|
|
||||||
48,24,11,27,20,21,16,0,0,3
|
|
||||||
29,25,11,17,38,20,6,0,0,0
|
|
||||||
21,35,0,43,46,1,20,2,4,4
|
|
||||||
19,26,11,28,41,16,3,0,0,0
|
|
||||||
33,11,11,20,42,26,5,0,0,0
|
|
||||||
16,25,10,20,26,26,4,0,0,1
|
|
||||||
14,15,11,41,24,30,3,0,0,3
|
|
||||||
18,29,11,22,40,15,3,0,5,3
|
|
||||||
25,9,7,30,42,30,14,0,0,3
|
|
||||||
48,34,5,30,25,0,21,0,0,2
|
|
|
Binary file not shown.
Before Width: | Height: | Size: 45 KiB After Width: | Height: | Size: 45 KiB |
712
test.ipynb
712
test.ipynb
File diff suppressed because one or more lines are too long
111
test.py
111
test.py
@@ -1,111 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# coding: utf-8
|
|
||||||
|
|
||||||
# In[1]:
|
|
||||||
|
|
||||||
|
|
||||||
from mdlp import MDLP
|
|
||||||
import pandas as pd
|
|
||||||
from benchmark import Datasets
|
|
||||||
from bayesclass import TAN
|
|
||||||
from sklearn.model_selection import (
|
|
||||||
cross_validate,
|
|
||||||
StratifiedKFold,
|
|
||||||
KFold,
|
|
||||||
cross_val_score,
|
|
||||||
train_test_split,
|
|
||||||
)
|
|
||||||
import numpy as np
|
|
||||||
import warnings
|
|
||||||
from stree import Stree
|
|
||||||
|
|
||||||
# In[2]:
|
|
||||||
|
|
||||||
|
|
||||||
# Get data as a dataset
|
|
||||||
dt = Datasets()
|
|
||||||
data = dt.load("glass", dataframe=True)
|
|
||||||
features = dt.dataset.features
|
|
||||||
class_name = dt.dataset.class_name
|
|
||||||
factorization, class_factors = pd.factorize(data[class_name])
|
|
||||||
data[class_name] = factorization
|
|
||||||
data.head()
|
|
||||||
|
|
||||||
|
|
||||||
# In[3]:
|
|
||||||
|
|
||||||
|
|
||||||
# Fayyad Irani
|
|
||||||
discretiz = MDLP()
|
|
||||||
Xdisc = discretiz.fit_transform(
|
|
||||||
data[features].to_numpy(), data[class_name].to_numpy()
|
|
||||||
)
|
|
||||||
features_discretized = pd.DataFrame(Xdisc, columns=features)
|
|
||||||
dataset_discretized = features_discretized.copy()
|
|
||||||
dataset_discretized[class_name] = data[class_name]
|
|
||||||
X = dataset_discretized[features]
|
|
||||||
y = dataset_discretized[class_name]
|
|
||||||
dataset_discretized
|
|
||||||
|
|
||||||
|
|
||||||
# In[4]:
|
|
||||||
|
|
||||||
|
|
||||||
n_folds = 5
|
|
||||||
score_name = "accuracy"
|
|
||||||
random_state = 17
|
|
||||||
test_size = 0.3
|
|
||||||
|
|
||||||
|
|
||||||
def validate_classifier(model, X, y, stratified, fit_params):
|
|
||||||
stratified_class = StratifiedKFold if stratified else KFold
|
|
||||||
kfold = stratified_class(
|
|
||||||
shuffle=True, random_state=random_state, n_splits=n_folds
|
|
||||||
)
|
|
||||||
# return cross_validate(model, X, y, cv=kfold, return_estimator=True,
|
|
||||||
# scoring=score_name)
|
|
||||||
return cross_val_score(model, X, y, fit_params=fit_params)
|
|
||||||
|
|
||||||
|
|
||||||
def split_data(X, y, stratified):
|
|
||||||
if stratified:
|
|
||||||
return train_test_split(
|
|
||||||
X,
|
|
||||||
y,
|
|
||||||
test_size=test_size,
|
|
||||||
random_state=random_state,
|
|
||||||
stratify=y,
|
|
||||||
shuffle=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return train_test_split(
|
|
||||||
X, y, test_size=test_size, random_state=random_state, shuffle=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# In[5]:
|
|
||||||
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
|
||||||
for simple_init in [False, True]:
|
|
||||||
model = TAN(simple_init=simple_init)
|
|
||||||
for head in range(4):
|
|
||||||
X_train, X_test, y_train, y_test = split_data(X, y, stratified=False)
|
|
||||||
model.fit(
|
|
||||||
X_train,
|
|
||||||
y_train,
|
|
||||||
head=head,
|
|
||||||
features=features,
|
|
||||||
class_name=class_name,
|
|
||||||
)
|
|
||||||
y = model.predict(X_test)
|
|
||||||
model.plot()
|
|
||||||
|
|
||||||
# In[ ]:
|
|
||||||
|
|
||||||
|
|
||||||
model = TAN(simple_init=simple_init)
|
|
||||||
model.fit(X, y, features=features, class_name=class_name)
|
|
||||||
model.plot(
|
|
||||||
f"**simple_init={simple_init} head={head} score={model.score(X, y)}"
|
|
||||||
)
|
|
Reference in New Issue
Block a user