mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-15 23:55:57 +00:00
Fix fit/build/train mistake
This commit is contained in:
215
balance-scale.csv
Normal file
215
balance-scale.csv
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
RI,Na,Mg,Al,Si,'K',Ca,Ba,Fe,Type
|
||||||
|
35,11,11,16,41,29,18,0,0,0
|
||||||
|
22,3,11,23,40,25,13,0,0,1
|
||||||
|
35,21,11,26,23,26,6,0,0,0
|
||||||
|
3,41,5,28,48,0,3,0,0,2
|
||||||
|
68,4,0,13,0,10,31,0,6,3
|
||||||
|
22,9,6,27,42,25,19,1,5,3
|
||||||
|
34,26,11,5,41,2,20,0,0,1
|
||||||
|
46,20,6,23,38,24,20,0,0,0
|
||||||
|
8,34,0,43,45,5,20,2,2,4
|
||||||
|
35,20,12,23,20,24,8,0,6,3
|
||||||
|
22,24,11,27,28,16,3,0,0,3
|
||||||
|
32,4,7,17,46,27,20,0,6,3
|
||||||
|
63,21,11,9,10,10,28,0,0,0
|
||||||
|
57,32,11,5,7,9,24,0,0,1
|
||||||
|
23,20,11,32,22,26,4,0,5,1
|
||||||
|
27,26,11,30,22,27,3,0,0,3
|
||||||
|
28,41,0,38,41,0,13,3,4,4
|
||||||
|
22,9,7,27,43,30,3,0,0,3
|
||||||
|
50,23,0,32,41,17,30,0,0,5
|
||||||
|
40,16,7,27,40,26,19,1,0,3
|
||||||
|
58,19,11,12,18,13,27,0,5,0
|
||||||
|
66,0,0,37,17,33,31,0,6,3
|
||||||
|
48,16,11,15,20,29,20,0,5,3
|
||||||
|
33,25,11,19,34,25,3,0,5,0
|
||||||
|
52,12,4,42,15,32,25,1,8,5
|
||||||
|
10,22,11,27,42,15,3,0,0,3
|
||||||
|
14,11,11,37,39,30,3,0,0,3
|
||||||
|
24,41,0,38,43,0,10,3,3,4
|
||||||
|
25,22,11,27,36,25,3,0,0,3
|
||||||
|
16,24,11,23,25,25,4,0,0,1
|
||||||
|
23,16,11,30,42,30,3,0,0,3
|
||||||
|
45,24,8,28,14,25,20,0,0,1
|
||||||
|
0,41,0,1,48,0,1,0,0,2
|
||||||
|
22,26,11,28,22,30,3,0,0,3
|
||||||
|
33,17,11,23,41,25,5,0,5,0
|
||||||
|
11,9,11,29,42,30,3,0,6,0
|
||||||
|
14,11,11,30,40,29,3,0,6,0
|
||||||
|
30,4,6,30,40,30,20,0,0,3
|
||||||
|
23,12,11,27,41,30,3,1,6,3
|
||||||
|
5,37,7,39,20,33,3,0,0,3
|
||||||
|
37,9,11,23,40,29,18,0,0,0
|
||||||
|
38,17,11,13,42,25,4,0,5,3
|
||||||
|
22,17,11,28,35,26,3,0,0,3
|
||||||
|
14,22,8,27,42,16,3,0,0,3
|
||||||
|
49,26,11,8,24,11,20,1,6,1
|
||||||
|
33,9,11,20,42,26,17,0,0,0
|
||||||
|
6,31,5,44,0,34,0,4,0,5
|
||||||
|
33,21,11,23,22,25,3,0,0,0
|
||||||
|
35,19,11,23,39,26,10,0,0,0
|
||||||
|
60,20,11,17,31,23,10,0,0,3
|
||||||
|
33,7,11,23,45,26,14,0,3,0
|
||||||
|
48,20,11,13,35,25,6,1,5,3
|
||||||
|
32,23,11,16,42,26,3,0,0,0
|
||||||
|
14,20,11,28,42,30,3,0,0,3
|
||||||
|
22,41,0,43,38,0,23,2,0,4
|
||||||
|
31,17,11,30,30,23,8,0,3,0
|
||||||
|
64,41,5,38,1,32,26,0,0,4
|
||||||
|
55,27,11,16,10,6,22,0,0,0
|
||||||
|
34,26,11,15,33,9,16,0,0,1
|
||||||
|
48,25,12,23,22,21,4,0,0,3
|
||||||
|
48,26,12,23,10,23,4,0,6,3
|
||||||
|
49,26,11,15,23,10,18,0,0,0
|
||||||
|
9,23,11,20,31,25,15,0,0,0
|
||||||
|
63,35,11,2,7,9,24,0,0,0
|
||||||
|
64,27,11,2,6,6,28,0,5,0
|
||||||
|
8,28,0,43,42,10,23,3,2,4
|
||||||
|
49,20,7,23,21,26,20,0,5,0
|
||||||
|
62,35,11,13,5,13,20,0,7,1
|
||||||
|
68,0,0,41,0,26,31,4,6,3
|
||||||
|
58,19,11,12,20,13,27,0,5,0
|
||||||
|
42,41,5,30,22,0,21,0,0,2
|
||||||
|
48,26,11,23,22,25,3,0,5,3
|
||||||
|
48,41,2,30,22,0,27,0,0,2
|
||||||
|
42,22,12,26,20,24,4,0,5,3
|
||||||
|
64,23,11,9,10,10,28,0,2,0
|
||||||
|
22,26,11,27,22,29,3,0,0,3
|
||||||
|
33,7,11,27,42,25,14,0,0,0
|
||||||
|
2,17,11,16,41,27,4,0,6,0
|
||||||
|
22,18,10,23,41,21,15,0,0,1
|
||||||
|
29,16,11,23,41,25,6,0,0,0
|
||||||
|
33,11,11,23,41,26,15,0,0,0
|
||||||
|
31,23,11,21,42,24,3,0,0,0
|
||||||
|
57,39,12,11,5,0,24,0,0,1
|
||||||
|
34,21,9,23,31,26,15,0,0,0
|
||||||
|
58,1,5,29,39,17,30,0,0,5
|
||||||
|
38,27,12,28,8,23,3,0,5,3
|
||||||
|
68,8,0,6,10,2,31,0,0,3
|
||||||
|
33,11,11,27,31,23,10,0,5,0
|
||||||
|
33,17,11,20,41,30,13,0,0,0
|
||||||
|
60,27,3,23,17,15,29,0,0,3
|
||||||
|
22,41,0,36,42,0,17,3,0,4
|
||||||
|
35,9,11,19,40,27,18,0,6,0
|
||||||
|
58,20,11,12,18,13,27,0,5,0
|
||||||
|
49,27,5,19,31,0,27,0,0,2
|
||||||
|
6,41,0,43,46,0,5,2,0,4
|
||||||
|
59,26,11,12,10,11,24,0,3,0
|
||||||
|
30,41,0,33,41,0,17,3,0,4
|
||||||
|
51,29,3,30,6,16,29,0,5,3
|
||||||
|
15,16,11,27,42,16,3,0,0,3
|
||||||
|
48,20,12,19,22,26,6,0,0,3
|
||||||
|
33,29,11,23,30,18,3,0,0,0
|
||||||
|
23,23,11,28,22,30,3,0,6,3
|
||||||
|
64,41,5,23,1,14,17,3,0,4
|
||||||
|
24,41,0,38,42,0,5,3,0,4
|
||||||
|
22,41,0,38,42,0,4,3,0,4
|
||||||
|
4,17,0,44,2,35,2,0,0,5
|
||||||
|
27,17,11,33,28,30,3,0,0,3
|
||||||
|
30,41,0,43,43,0,20,2,0,4
|
||||||
|
49,26,8,20,13,26,20,0,0,0
|
||||||
|
49,8,0,30,47,15,30,0,0,5
|
||||||
|
40,8,6,11,47,15,23,0,5,3
|
||||||
|
18,41,0,43,43,0,18,2,0,4
|
||||||
|
49,29,11,19,13,2,20,0,0,0
|
||||||
|
22,41,0,38,46,0,9,3,0,4
|
||||||
|
26,15,11,23,22,26,19,0,0,1
|
||||||
|
64,26,8,20,22,26,20,0,0,4
|
||||||
|
54,26,5,30,15,22,24,1,5,3
|
||||||
|
47,39,7,43,4,34,0,3,0,4
|
||||||
|
40,27,0,3,47,0,29,0,0,3
|
||||||
|
34,5,6,23,46,25,20,0,6,0
|
||||||
|
23,17,7,20,40,26,19,0,6,3
|
||||||
|
13,16,11,24,43,30,3,0,6,0
|
||||||
|
66,27,7,18,3,5,29,0,0,3
|
||||||
|
68,27,7,6,2,5,30,0,0,3
|
||||||
|
56,17,1,27,45,10,30,0,6,5
|
||||||
|
33,15,11,23,42,23,4,0,5,0
|
||||||
|
22,2,0,19,48,34,20,0,0,4
|
||||||
|
21,34,0,43,22,5,20,3,0,4
|
||||||
|
55,26,13,14,7,2,18,0,0,0
|
||||||
|
33,7,11,23,43,26,10,0,0,0
|
||||||
|
14,17,11,28,42,30,3,0,0,3
|
||||||
|
23,11,11,28,44,30,3,0,0,3
|
||||||
|
53,41,0,38,45,0,8,3,0,4
|
||||||
|
33,9,11,23,42,26,18,0,5,0
|
||||||
|
65,26,0,29,18,15,30,0,0,5
|
||||||
|
33,20,11,13,42,25,3,0,0,0
|
||||||
|
33,26,11,18,41,26,3,0,0,0
|
||||||
|
28,16,11,29,40,26,3,0,0,3
|
||||||
|
61,27,11,7,6,11,25,0,0,0
|
||||||
|
14,20,11,28,40,30,3,0,4,3
|
||||||
|
35,9,11,17,42,26,18,0,0,0
|
||||||
|
49,29,11,23,8,21,18,1,0,0
|
||||||
|
49,27,11,23,6,10,17,2,0,0
|
||||||
|
23,15,0,35,47,33,28,0,0,5
|
||||||
|
22,24,11,29,40,26,3,0,0,3
|
||||||
|
48,16,11,29,22,26,14,0,5,3
|
||||||
|
27,27,11,34,12,29,3,0,0,3
|
||||||
|
54,27,5,27,10,19,27,0,5,3
|
||||||
|
12,41,11,30,8,11,3,0,5,3
|
||||||
|
40,25,12,19,22,26,3,0,0,3
|
||||||
|
1,27,7,34,35,34,0,3,0,4
|
||||||
|
63,35,11,9,5,0,25,0,0,0
|
||||||
|
66,27,0,23,3,13,31,0,5,3
|
||||||
|
40,24,11,22,34,21,3,0,0,3
|
||||||
|
22,25,9,23,23,21,17,0,0,1
|
||||||
|
33,11,11,23,41,27,15,0,0,0
|
||||||
|
6,41,0,43,46,0,4,2,0,4
|
||||||
|
49,9,5,35,26,26,28,0,0,5
|
||||||
|
49,41,11,0,10,1,20,0,0,0
|
||||||
|
48,22,11,23,22,26,6,0,0,3
|
||||||
|
66,0,0,8,42,0,31,0,0,3
|
||||||
|
59,26,11,12,7,13,24,0,5,0
|
||||||
|
15,41,0,43,43,0,18,2,4,4
|
||||||
|
4,17,0,44,2,35,2,0,0,5
|
||||||
|
68,0,0,8,42,0,31,0,0,3
|
||||||
|
63,35,11,2,7,9,24,0,0,0
|
||||||
|
33,11,11,16,42,25,14,0,0,0
|
||||||
|
48,12,11,21,22,27,18,0,6,3
|
||||||
|
22,25,11,22,35,30,3,0,0,3
|
||||||
|
15,41,0,43,42,1,20,2,0,4
|
||||||
|
23,16,11,23,31,25,16,0,0,3
|
||||||
|
12,20,11,29,42,3,5,0,5,3
|
||||||
|
67,29,11,7,6,1,27,0,5,0
|
||||||
|
44,41,0,34,39,34,0,4,0,4
|
||||||
|
49,33,11,23,13,25,4,0,0,0
|
||||||
|
18,28,5,33,42,0,17,3,0,4
|
||||||
|
61,41,11,12,5,11,20,0,0,0
|
||||||
|
41,16,11,23,40,26,6,0,0,0
|
||||||
|
58,0,4,29,45,26,30,0,0,5
|
||||||
|
49,41,0,3,46,0,29,0,0,2
|
||||||
|
19,17,11,27,40,26,3,0,0,3
|
||||||
|
22,25,11,28,24,30,3,0,5,3
|
||||||
|
36,26,8,30,9,25,19,0,4,1
|
||||||
|
63,41,0,13,25,7,30,0,4,3
|
||||||
|
35,9,11,23,40,25,18,0,0,0
|
||||||
|
28,36,0,39,44,0,17,3,0,4
|
||||||
|
31,9,11,23,34,26,18,0,0,0
|
||||||
|
39,25,6,19,36,24,20,0,0,0
|
||||||
|
23,22,11,23,27,25,8,0,5,1
|
||||||
|
52,25,0,24,19,15,30,0,0,5
|
||||||
|
49,26,11,23,10,24,20,0,0,0
|
||||||
|
34,21,6,23,41,21,20,0,4,0
|
||||||
|
49,30,5,29,22,0,24,0,0,2
|
||||||
|
8,41,0,43,42,1,20,2,0,4
|
||||||
|
49,34,0,40,31,0,29,0,0,2
|
||||||
|
48,17,11,13,20,29,20,0,5,3
|
||||||
|
14,17,11,27,42,30,3,0,0,3
|
||||||
|
14,22,11,27,42,26,3,0,0,3
|
||||||
|
22,6,11,36,42,28,3,0,4,3
|
||||||
|
23,16,11,30,40,29,3,0,5,3
|
||||||
|
25,24,11,30,22,30,3,0,0,3
|
||||||
|
48,20,9,19,28,25,20,0,5,0
|
||||||
|
34,26,11,28,11,26,19,0,0,1
|
||||||
|
34,12,11,20,40,26,15,1,5,0
|
||||||
|
48,24,11,27,20,21,16,0,0,3
|
||||||
|
29,25,11,17,38,20,6,0,0,0
|
||||||
|
21,35,0,43,46,1,20,2,4,4
|
||||||
|
19,26,11,28,41,16,3,0,0,0
|
||||||
|
33,11,11,20,42,26,5,0,0,0
|
||||||
|
16,25,10,20,26,26,4,0,0,1
|
||||||
|
14,15,11,41,24,30,3,0,0,3
|
||||||
|
18,29,11,22,40,15,3,0,5,3
|
||||||
|
25,9,7,30,42,30,14,0,0,3
|
||||||
|
48,34,5,30,25,0,21,0,0,2
|
|
@@ -7,11 +7,12 @@ import pandas as pd
|
|||||||
from sklearn.base import ClassifierMixin, BaseEstimator
|
from sklearn.base import ClassifierMixin, BaseEstimator
|
||||||
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
||||||
from sklearn.utils.multiclass import unique_labels
|
from sklearn.utils.multiclass import unique_labels
|
||||||
|
from sklearn.exceptions import NotFittedError
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pgmpy.estimators import (
|
from pgmpy.estimators import (
|
||||||
TreeSearch,
|
TreeSearch,
|
||||||
BayesianEstimator,
|
BayesianEstimator,
|
||||||
MaximumLikelihoodEstimator,
|
# MaximumLikelihoodEstimator,
|
||||||
)
|
)
|
||||||
from pgmpy.models import BayesianNetwork
|
from pgmpy.models import BayesianNetwork
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@@ -21,10 +22,12 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
"""An example classifier which implements a 1-NN algorithm.
|
"""An example classifier which implements a 1-NN algorithm.
|
||||||
For more information regarding how to build your own classifier, read more
|
For more information regarding how to build your own classifier, read more
|
||||||
in the :ref:`User Guide <user_guide>`.
|
in the :ref:`User Guide <user_guide>`.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
demo_param : str, default='demo'
|
demo_param : str, default='demo'
|
||||||
A parameter used for demonstation of how to pass and store paramters.
|
A parameter used for demonstation of how to pass and store paramters.
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
X_ : ndarray, shape (n_samples, n_features)
|
X_ : ndarray, shape (n_samples, n_features)
|
||||||
@@ -44,6 +47,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
|
|
||||||
def fit(self, X, y, **kwargs):
|
def fit(self, X, y, **kwargs):
|
||||||
"""A reference implementation of a fitting function for a classifier.
|
"""A reference implementation of a fitting function for a classifier.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
X : array-like, shape (n_samples, n_features)
|
X : array-like, shape (n_samples, n_features)
|
||||||
@@ -55,6 +59,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
features: list (default=None) List of features
|
features: list (default=None) List of features
|
||||||
head: int (default=None) Index of the head node. Default value
|
head: int (default=None) Index of the head node. Default value
|
||||||
gets the node with the highest sum of weights (mutual_info)
|
gets the node with the highest sum of weights (mutual_info)
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
self : object
|
self : object
|
||||||
@@ -86,8 +91,17 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
raise ValueError("Head index out of range")
|
raise ValueError("Head index out of range")
|
||||||
|
|
||||||
self.X_ = X
|
self.X_ = X
|
||||||
self.y_ = y
|
self.y_ = y.astype(int)
|
||||||
|
self.dataset_ = pd.DataFrame(
|
||||||
|
self.X_, columns=self.features_, dtype="int16"
|
||||||
|
)
|
||||||
|
self.dataset_[self.class_name_] = self.y_
|
||||||
|
try:
|
||||||
|
check_is_fitted(self, ["X_", "y_", "fitted_"])
|
||||||
|
except NotFittedError:
|
||||||
|
self.__build()
|
||||||
self.__train()
|
self.__train()
|
||||||
|
self.fitted_ = True
|
||||||
# Return the classifier
|
# Return the classifier
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -101,6 +115,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
Marco Zaffalon,
|
Marco Zaffalon,
|
||||||
Learning extended tree augmented naive structures,
|
Learning extended tree augmented naive structures,
|
||||||
International Journal of Approximate Reasoning,
|
International Journal of Approximate Reasoning,
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
List
|
List
|
||||||
@@ -121,14 +136,12 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
]
|
]
|
||||||
return list(combinations(reordered, 2))
|
return list(combinations(reordered, 2))
|
||||||
|
|
||||||
def __train(self):
|
def __build(self):
|
||||||
# Initialize a Naive Bayes model
|
# Initialize a Naive Bayes model
|
||||||
net = [(self.class_name_, feature) for feature in self.features_]
|
net = [(self.class_name_, feature) for feature in self.features_]
|
||||||
self.model_ = BayesianNetwork(net)
|
self.model_ = BayesianNetwork(net)
|
||||||
# initialize a complete network with all edges
|
# initialize a complete network with all edges
|
||||||
self.model_.add_edges_from(self.__initial_edges())
|
self.model_.add_edges_from(self.__initial_edges())
|
||||||
self.dataset_ = pd.DataFrame(self.X_, columns=self.features_)
|
|
||||||
self.dataset_[self.class_name_] = self.y_
|
|
||||||
# learn graph structure
|
# learn graph structure
|
||||||
root_node = None if self.head_ is None else self.features_[self.head_]
|
root_node = None if self.head_ is None else self.features_[self.head_]
|
||||||
est = TreeSearch(self.dataset_, root_node=root_node)
|
est = TreeSearch(self.dataset_, root_node=root_node)
|
||||||
@@ -139,12 +152,17 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
)
|
)
|
||||||
if self.head_ is None:
|
if self.head_ is None:
|
||||||
self.head_ = est.root_node
|
self.head_ = est.root_node
|
||||||
self.model_ = BayesianNetwork(dag.edges())
|
self.model_ = BayesianNetwork(
|
||||||
|
dag.edges(), show_progress=self.show_progress
|
||||||
|
)
|
||||||
|
|
||||||
|
def __train(self):
|
||||||
self.model_.fit(
|
self.model_.fit(
|
||||||
self.dataset_,
|
self.dataset_,
|
||||||
# estimator=MaximumLikelihoodEstimator,
|
# estimator=MaximumLikelihoodEstimator,
|
||||||
estimator=BayesianEstimator,
|
estimator=BayesianEstimator,
|
||||||
prior_type="K2",
|
prior_type="K2",
|
||||||
|
n_jobs=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
def plot(self, title=""):
|
def plot(self, title=""):
|
||||||
@@ -161,20 +179,54 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
|
|
||||||
def predict(self, X):
|
def predict(self, X):
|
||||||
"""A reference implementation of a prediction for a classifier.
|
"""A reference implementation of a prediction for a classifier.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
X : array-like, shape (n_samples, n_features)
|
X : array-like, shape (n_samples, n_features)
|
||||||
The input samples.
|
The input samples.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
y : ndarray, shape (n_samples,)
|
y : ndarray, shape (n_samples,)
|
||||||
The label for each sample is the label of the closest sample
|
The label for each sample is the label of the closest sample
|
||||||
seen during fit.
|
seen during fit.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import numpy as np
|
||||||
|
>>> import pandas as pd
|
||||||
|
>>> from bayesclass import TAN
|
||||||
|
>>> features = ['A', 'B', 'C', 'D', 'E']
|
||||||
|
>>> np.random.seed(17)
|
||||||
|
>>> values = pd.DataFrame(np.random.randint(low=0, high=2,
|
||||||
|
... size=(1000, 5)), columns=features)
|
||||||
|
>>> train_data = values[:800]
|
||||||
|
>>> train_y = train_data['E']
|
||||||
|
>>> predict_data = values[800:]
|
||||||
|
>>> train_data.drop('E', axis=1, inplace=True)
|
||||||
|
>>> model = TAN(random_state=17)
|
||||||
|
>>> features.remove('E')
|
||||||
|
>>> model.fit(train_data, train_y, features=features, class_name='E')
|
||||||
|
TAN(random_state=17)
|
||||||
|
>>> predict_data = predict_data.copy()
|
||||||
|
>>> predict_data.drop('E', axis=1, inplace=True)
|
||||||
|
>>> y_pred = model.predict(predict_data)
|
||||||
|
>>> y_pred[:10]
|
||||||
|
array([[0],
|
||||||
|
[0],
|
||||||
|
[1],
|
||||||
|
[1],
|
||||||
|
[0],
|
||||||
|
[1],
|
||||||
|
[1],
|
||||||
|
[1],
|
||||||
|
[0],
|
||||||
|
[1]])
|
||||||
"""
|
"""
|
||||||
# Check is fit had been called
|
# Check is fit had been called
|
||||||
check_is_fitted(self, ["X_", "y_"])
|
check_is_fitted(self, ["X_", "y_", "fitted_"])
|
||||||
|
|
||||||
# Input validation
|
# Input validation
|
||||||
X = check_array(X)
|
X = check_array(X)
|
||||||
dataset = pd.DataFrame(X, columns=self.features_)
|
dataset = pd.DataFrame(X, columns=self.features_, dtype="int16")
|
||||||
return self.model_.predict(dataset).to_numpy()
|
return self.model_.predict(dataset, n_jobs=1).to_numpy()
|
||||||
|
1
bayesclass/test.r
Normal file
1
bayesclass/test.r
Normal file
@@ -0,0 +1 @@
|
|||||||
|
m0 <- ulam(alist(height ~ dnorm(mu, sigma), mu <- a, a ~ dnorm(186, 10), sigma ~ dexp(1)), data = d, chains = 4, iter = 2000, cores = 4, log_lik=TRUE)
|
215
glass.csv
Normal file
215
glass.csv
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
RI,Na,Mg,Al,Si,'K',Ca,Ba,Fe,Type
|
||||||
|
35,11,11,16,41,29,18,0,0,0
|
||||||
|
22,3,11,23,40,25,13,0,0,1
|
||||||
|
35,21,11,26,23,26,6,0,0,0
|
||||||
|
3,41,5,28,48,0,3,0,0,2
|
||||||
|
68,4,0,13,0,10,31,0,6,3
|
||||||
|
22,9,6,27,42,25,19,1,5,3
|
||||||
|
34,26,11,5,41,2,20,0,0,1
|
||||||
|
46,20,6,23,38,24,20,0,0,0
|
||||||
|
8,34,0,43,45,5,20,2,2,4
|
||||||
|
35,20,12,23,20,24,8,0,6,3
|
||||||
|
22,24,11,27,28,16,3,0,0,3
|
||||||
|
32,4,7,17,46,27,20,0,6,3
|
||||||
|
63,21,11,9,10,10,28,0,0,0
|
||||||
|
57,32,11,5,7,9,24,0,0,1
|
||||||
|
23,20,11,32,22,26,4,0,5,1
|
||||||
|
27,26,11,30,22,27,3,0,0,3
|
||||||
|
28,41,0,38,41,0,13,3,4,4
|
||||||
|
22,9,7,27,43,30,3,0,0,3
|
||||||
|
50,23,0,32,41,17,30,0,0,5
|
||||||
|
40,16,7,27,40,26,19,1,0,3
|
||||||
|
58,19,11,12,18,13,27,0,5,0
|
||||||
|
66,0,0,37,17,33,31,0,6,3
|
||||||
|
48,16,11,15,20,29,20,0,5,3
|
||||||
|
33,25,11,19,34,25,3,0,5,0
|
||||||
|
52,12,4,42,15,32,25,1,8,5
|
||||||
|
10,22,11,27,42,15,3,0,0,3
|
||||||
|
14,11,11,37,39,30,3,0,0,3
|
||||||
|
24,41,0,38,43,0,10,3,3,4
|
||||||
|
25,22,11,27,36,25,3,0,0,3
|
||||||
|
16,24,11,23,25,25,4,0,0,1
|
||||||
|
23,16,11,30,42,30,3,0,0,3
|
||||||
|
45,24,8,28,14,25,20,0,0,1
|
||||||
|
0,41,0,1,48,0,1,0,0,2
|
||||||
|
22,26,11,28,22,30,3,0,0,3
|
||||||
|
33,17,11,23,41,25,5,0,5,0
|
||||||
|
11,9,11,29,42,30,3,0,6,0
|
||||||
|
14,11,11,30,40,29,3,0,6,0
|
||||||
|
30,4,6,30,40,30,20,0,0,3
|
||||||
|
23,12,11,27,41,30,3,1,6,3
|
||||||
|
5,37,7,39,20,33,3,0,0,3
|
||||||
|
37,9,11,23,40,29,18,0,0,0
|
||||||
|
38,17,11,13,42,25,4,0,5,3
|
||||||
|
22,17,11,28,35,26,3,0,0,3
|
||||||
|
14,22,8,27,42,16,3,0,0,3
|
||||||
|
49,26,11,8,24,11,20,1,6,1
|
||||||
|
33,9,11,20,42,26,17,0,0,0
|
||||||
|
6,31,5,44,0,34,0,4,0,5
|
||||||
|
33,21,11,23,22,25,3,0,0,0
|
||||||
|
35,19,11,23,39,26,10,0,0,0
|
||||||
|
60,20,11,17,31,23,10,0,0,3
|
||||||
|
33,7,11,23,45,26,14,0,3,0
|
||||||
|
48,20,11,13,35,25,6,1,5,3
|
||||||
|
32,23,11,16,42,26,3,0,0,0
|
||||||
|
14,20,11,28,42,30,3,0,0,3
|
||||||
|
22,41,0,43,38,0,23,2,0,4
|
||||||
|
31,17,11,30,30,23,8,0,3,0
|
||||||
|
64,41,5,38,1,32,26,0,0,4
|
||||||
|
55,27,11,16,10,6,22,0,0,0
|
||||||
|
34,26,11,15,33,9,16,0,0,1
|
||||||
|
48,25,12,23,22,21,4,0,0,3
|
||||||
|
48,26,12,23,10,23,4,0,6,3
|
||||||
|
49,26,11,15,23,10,18,0,0,0
|
||||||
|
9,23,11,20,31,25,15,0,0,0
|
||||||
|
63,35,11,2,7,9,24,0,0,0
|
||||||
|
64,27,11,2,6,6,28,0,5,0
|
||||||
|
8,28,0,43,42,10,23,3,2,4
|
||||||
|
49,20,7,23,21,26,20,0,5,0
|
||||||
|
62,35,11,13,5,13,20,0,7,1
|
||||||
|
68,0,0,41,0,26,31,4,6,3
|
||||||
|
58,19,11,12,20,13,27,0,5,0
|
||||||
|
42,41,5,30,22,0,21,0,0,2
|
||||||
|
48,26,11,23,22,25,3,0,5,3
|
||||||
|
48,41,2,30,22,0,27,0,0,2
|
||||||
|
42,22,12,26,20,24,4,0,5,3
|
||||||
|
64,23,11,9,10,10,28,0,2,0
|
||||||
|
22,26,11,27,22,29,3,0,0,3
|
||||||
|
33,7,11,27,42,25,14,0,0,0
|
||||||
|
2,17,11,16,41,27,4,0,6,0
|
||||||
|
22,18,10,23,41,21,15,0,0,1
|
||||||
|
29,16,11,23,41,25,6,0,0,0
|
||||||
|
33,11,11,23,41,26,15,0,0,0
|
||||||
|
31,23,11,21,42,24,3,0,0,0
|
||||||
|
57,39,12,11,5,0,24,0,0,1
|
||||||
|
34,21,9,23,31,26,15,0,0,0
|
||||||
|
58,1,5,29,39,17,30,0,0,5
|
||||||
|
38,27,12,28,8,23,3,0,5,3
|
||||||
|
68,8,0,6,10,2,31,0,0,3
|
||||||
|
33,11,11,27,31,23,10,0,5,0
|
||||||
|
33,17,11,20,41,30,13,0,0,0
|
||||||
|
60,27,3,23,17,15,29,0,0,3
|
||||||
|
22,41,0,36,42,0,17,3,0,4
|
||||||
|
35,9,11,19,40,27,18,0,6,0
|
||||||
|
58,20,11,12,18,13,27,0,5,0
|
||||||
|
49,27,5,19,31,0,27,0,0,2
|
||||||
|
6,41,0,43,46,0,5,2,0,4
|
||||||
|
59,26,11,12,10,11,24,0,3,0
|
||||||
|
30,41,0,33,41,0,17,3,0,4
|
||||||
|
51,29,3,30,6,16,29,0,5,3
|
||||||
|
15,16,11,27,42,16,3,0,0,3
|
||||||
|
48,20,12,19,22,26,6,0,0,3
|
||||||
|
33,29,11,23,30,18,3,0,0,0
|
||||||
|
23,23,11,28,22,30,3,0,6,3
|
||||||
|
64,41,5,23,1,14,17,3,0,4
|
||||||
|
24,41,0,38,42,0,5,3,0,4
|
||||||
|
22,41,0,38,42,0,4,3,0,4
|
||||||
|
4,17,0,44,2,35,2,0,0,5
|
||||||
|
27,17,11,33,28,30,3,0,0,3
|
||||||
|
30,41,0,43,43,0,20,2,0,4
|
||||||
|
49,26,8,20,13,26,20,0,0,0
|
||||||
|
49,8,0,30,47,15,30,0,0,5
|
||||||
|
40,8,6,11,47,15,23,0,5,3
|
||||||
|
18,41,0,43,43,0,18,2,0,4
|
||||||
|
49,29,11,19,13,2,20,0,0,0
|
||||||
|
22,41,0,38,46,0,9,3,0,4
|
||||||
|
26,15,11,23,22,26,19,0,0,1
|
||||||
|
64,26,8,20,22,26,20,0,0,4
|
||||||
|
54,26,5,30,15,22,24,1,5,3
|
||||||
|
47,39,7,43,4,34,0,3,0,4
|
||||||
|
40,27,0,3,47,0,29,0,0,3
|
||||||
|
34,5,6,23,46,25,20,0,6,0
|
||||||
|
23,17,7,20,40,26,19,0,6,3
|
||||||
|
13,16,11,24,43,30,3,0,6,0
|
||||||
|
66,27,7,18,3,5,29,0,0,3
|
||||||
|
68,27,7,6,2,5,30,0,0,3
|
||||||
|
56,17,1,27,45,10,30,0,6,5
|
||||||
|
33,15,11,23,42,23,4,0,5,0
|
||||||
|
22,2,0,19,48,34,20,0,0,4
|
||||||
|
21,34,0,43,22,5,20,3,0,4
|
||||||
|
55,26,13,14,7,2,18,0,0,0
|
||||||
|
33,7,11,23,43,26,10,0,0,0
|
||||||
|
14,17,11,28,42,30,3,0,0,3
|
||||||
|
23,11,11,28,44,30,3,0,0,3
|
||||||
|
53,41,0,38,45,0,8,3,0,4
|
||||||
|
33,9,11,23,42,26,18,0,5,0
|
||||||
|
65,26,0,29,18,15,30,0,0,5
|
||||||
|
33,20,11,13,42,25,3,0,0,0
|
||||||
|
33,26,11,18,41,26,3,0,0,0
|
||||||
|
28,16,11,29,40,26,3,0,0,3
|
||||||
|
61,27,11,7,6,11,25,0,0,0
|
||||||
|
14,20,11,28,40,30,3,0,4,3
|
||||||
|
35,9,11,17,42,26,18,0,0,0
|
||||||
|
49,29,11,23,8,21,18,1,0,0
|
||||||
|
49,27,11,23,6,10,17,2,0,0
|
||||||
|
23,15,0,35,47,33,28,0,0,5
|
||||||
|
22,24,11,29,40,26,3,0,0,3
|
||||||
|
48,16,11,29,22,26,14,0,5,3
|
||||||
|
27,27,11,34,12,29,3,0,0,3
|
||||||
|
54,27,5,27,10,19,27,0,5,3
|
||||||
|
12,41,11,30,8,11,3,0,5,3
|
||||||
|
40,25,12,19,22,26,3,0,0,3
|
||||||
|
1,27,7,34,35,34,0,3,0,4
|
||||||
|
63,35,11,9,5,0,25,0,0,0
|
||||||
|
66,27,0,23,3,13,31,0,5,3
|
||||||
|
40,24,11,22,34,21,3,0,0,3
|
||||||
|
22,25,9,23,23,21,17,0,0,1
|
||||||
|
33,11,11,23,41,27,15,0,0,0
|
||||||
|
6,41,0,43,46,0,4,2,0,4
|
||||||
|
49,9,5,35,26,26,28,0,0,5
|
||||||
|
49,41,11,0,10,1,20,0,0,0
|
||||||
|
48,22,11,23,22,26,6,0,0,3
|
||||||
|
66,0,0,8,42,0,31,0,0,3
|
||||||
|
59,26,11,12,7,13,24,0,5,0
|
||||||
|
15,41,0,43,43,0,18,2,4,4
|
||||||
|
4,17,0,44,2,35,2,0,0,5
|
||||||
|
68,0,0,8,42,0,31,0,0,3
|
||||||
|
63,35,11,2,7,9,24,0,0,0
|
||||||
|
33,11,11,16,42,25,14,0,0,0
|
||||||
|
48,12,11,21,22,27,18,0,6,3
|
||||||
|
22,25,11,22,35,30,3,0,0,3
|
||||||
|
15,41,0,43,42,1,20,2,0,4
|
||||||
|
23,16,11,23,31,25,16,0,0,3
|
||||||
|
12,20,11,29,42,3,5,0,5,3
|
||||||
|
67,29,11,7,6,1,27,0,5,0
|
||||||
|
44,41,0,34,39,34,0,4,0,4
|
||||||
|
49,33,11,23,13,25,4,0,0,0
|
||||||
|
18,28,5,33,42,0,17,3,0,4
|
||||||
|
61,41,11,12,5,11,20,0,0,0
|
||||||
|
41,16,11,23,40,26,6,0,0,0
|
||||||
|
58,0,4,29,45,26,30,0,0,5
|
||||||
|
49,41,0,3,46,0,29,0,0,2
|
||||||
|
19,17,11,27,40,26,3,0,0,3
|
||||||
|
22,25,11,28,24,30,3,0,5,3
|
||||||
|
36,26,8,30,9,25,19,0,4,1
|
||||||
|
63,41,0,13,25,7,30,0,4,3
|
||||||
|
35,9,11,23,40,25,18,0,0,0
|
||||||
|
28,36,0,39,44,0,17,3,0,4
|
||||||
|
31,9,11,23,34,26,18,0,0,0
|
||||||
|
39,25,6,19,36,24,20,0,0,0
|
||||||
|
23,22,11,23,27,25,8,0,5,1
|
||||||
|
52,25,0,24,19,15,30,0,0,5
|
||||||
|
49,26,11,23,10,24,20,0,0,0
|
||||||
|
34,21,6,23,41,21,20,0,4,0
|
||||||
|
49,30,5,29,22,0,24,0,0,2
|
||||||
|
8,41,0,43,42,1,20,2,0,4
|
||||||
|
49,34,0,40,31,0,29,0,0,2
|
||||||
|
48,17,11,13,20,29,20,0,5,3
|
||||||
|
14,17,11,27,42,30,3,0,0,3
|
||||||
|
14,22,11,27,42,26,3,0,0,3
|
||||||
|
22,6,11,36,42,28,3,0,4,3
|
||||||
|
23,16,11,30,40,29,3,0,5,3
|
||||||
|
25,24,11,30,22,30,3,0,0,3
|
||||||
|
48,20,9,19,28,25,20,0,5,0
|
||||||
|
34,26,11,28,11,26,19,0,0,1
|
||||||
|
34,12,11,20,40,26,15,1,5,0
|
||||||
|
48,24,11,27,20,21,16,0,0,3
|
||||||
|
29,25,11,17,38,20,6,0,0,0
|
||||||
|
21,35,0,43,46,1,20,2,4,4
|
||||||
|
19,26,11,28,41,16,3,0,0,0
|
||||||
|
33,11,11,20,42,26,5,0,0,0
|
||||||
|
16,25,10,20,26,26,4,0,0,1
|
||||||
|
14,15,11,41,24,30,3,0,0,3
|
||||||
|
18,29,11,22,40,15,3,0,5,3
|
||||||
|
25,9,7,30,42,30,14,0,0,3
|
||||||
|
48,34,5,30,25,0,21,0,0,2
|
|
1193
test.ipynb
1193
test.ipynb
File diff suppressed because one or more lines are too long
111
test.py
Normal file
111
test.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
# In[1]:
|
||||||
|
|
||||||
|
|
||||||
|
from mdlp import MDLP
|
||||||
|
import pandas as pd
|
||||||
|
from benchmark import Datasets
|
||||||
|
from bayesclass import TAN
|
||||||
|
from sklearn.model_selection import (
|
||||||
|
cross_validate,
|
||||||
|
StratifiedKFold,
|
||||||
|
KFold,
|
||||||
|
cross_val_score,
|
||||||
|
train_test_split,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
|
import warnings
|
||||||
|
from stree import Stree
|
||||||
|
|
||||||
|
# In[2]:
|
||||||
|
|
||||||
|
|
||||||
|
# Get data as a dataset
|
||||||
|
dt = Datasets()
|
||||||
|
data = dt.load("glass", dataframe=True)
|
||||||
|
features = dt.dataset.features
|
||||||
|
class_name = dt.dataset.class_name
|
||||||
|
factorization, class_factors = pd.factorize(data[class_name])
|
||||||
|
data[class_name] = factorization
|
||||||
|
data.head()
|
||||||
|
|
||||||
|
|
||||||
|
# In[3]:
|
||||||
|
|
||||||
|
|
||||||
|
# Fayyad Irani
|
||||||
|
discretiz = MDLP()
|
||||||
|
Xdisc = discretiz.fit_transform(
|
||||||
|
data[features].to_numpy(), data[class_name].to_numpy()
|
||||||
|
)
|
||||||
|
features_discretized = pd.DataFrame(Xdisc, columns=features)
|
||||||
|
dataset_discretized = features_discretized.copy()
|
||||||
|
dataset_discretized[class_name] = data[class_name]
|
||||||
|
X = dataset_discretized[features]
|
||||||
|
y = dataset_discretized[class_name]
|
||||||
|
dataset_discretized
|
||||||
|
|
||||||
|
|
||||||
|
# In[4]:
|
||||||
|
|
||||||
|
|
||||||
|
n_folds = 5
|
||||||
|
score_name = "accuracy"
|
||||||
|
random_state = 17
|
||||||
|
test_size = 0.3
|
||||||
|
|
||||||
|
|
||||||
|
def validate_classifier(model, X, y, stratified, fit_params):
|
||||||
|
stratified_class = StratifiedKFold if stratified else KFold
|
||||||
|
kfold = stratified_class(
|
||||||
|
shuffle=True, random_state=random_state, n_splits=n_folds
|
||||||
|
)
|
||||||
|
# return cross_validate(model, X, y, cv=kfold, return_estimator=True,
|
||||||
|
# scoring=score_name)
|
||||||
|
return cross_val_score(model, X, y, fit_params=fit_params)
|
||||||
|
|
||||||
|
|
||||||
|
def split_data(X, y, stratified):
|
||||||
|
if stratified:
|
||||||
|
return train_test_split(
|
||||||
|
X,
|
||||||
|
y,
|
||||||
|
test_size=test_size,
|
||||||
|
random_state=random_state,
|
||||||
|
stratify=y,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return train_test_split(
|
||||||
|
X, y, test_size=test_size, random_state=random_state, shuffle=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# In[5]:
|
||||||
|
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
for simple_init in [False, True]:
|
||||||
|
model = TAN(simple_init=simple_init)
|
||||||
|
for head in range(4):
|
||||||
|
X_train, X_test, y_train, y_test = split_data(X, y, stratified=False)
|
||||||
|
model.fit(
|
||||||
|
X_train,
|
||||||
|
y_train,
|
||||||
|
head=head,
|
||||||
|
features=features,
|
||||||
|
class_name=class_name,
|
||||||
|
)
|
||||||
|
y = model.predict(X_test)
|
||||||
|
model.plot()
|
||||||
|
|
||||||
|
# In[ ]:
|
||||||
|
|
||||||
|
|
||||||
|
model = TAN(simple_init=simple_init)
|
||||||
|
model.fit(X, y, features=features, class_name=class_name)
|
||||||
|
model.plot(
|
||||||
|
f"**simple_init={simple_init} head={head} score={model.score(X, y)}"
|
||||||
|
)
|
Reference in New Issue
Block a user