mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-15 23:55:57 +00:00
Fix fit/build/train mistake
This commit is contained in:
215
balance-scale.csv
Normal file
215
balance-scale.csv
Normal file
@@ -0,0 +1,215 @@
|
||||
RI,Na,Mg,Al,Si,'K',Ca,Ba,Fe,Type
|
||||
35,11,11,16,41,29,18,0,0,0
|
||||
22,3,11,23,40,25,13,0,0,1
|
||||
35,21,11,26,23,26,6,0,0,0
|
||||
3,41,5,28,48,0,3,0,0,2
|
||||
68,4,0,13,0,10,31,0,6,3
|
||||
22,9,6,27,42,25,19,1,5,3
|
||||
34,26,11,5,41,2,20,0,0,1
|
||||
46,20,6,23,38,24,20,0,0,0
|
||||
8,34,0,43,45,5,20,2,2,4
|
||||
35,20,12,23,20,24,8,0,6,3
|
||||
22,24,11,27,28,16,3,0,0,3
|
||||
32,4,7,17,46,27,20,0,6,3
|
||||
63,21,11,9,10,10,28,0,0,0
|
||||
57,32,11,5,7,9,24,0,0,1
|
||||
23,20,11,32,22,26,4,0,5,1
|
||||
27,26,11,30,22,27,3,0,0,3
|
||||
28,41,0,38,41,0,13,3,4,4
|
||||
22,9,7,27,43,30,3,0,0,3
|
||||
50,23,0,32,41,17,30,0,0,5
|
||||
40,16,7,27,40,26,19,1,0,3
|
||||
58,19,11,12,18,13,27,0,5,0
|
||||
66,0,0,37,17,33,31,0,6,3
|
||||
48,16,11,15,20,29,20,0,5,3
|
||||
33,25,11,19,34,25,3,0,5,0
|
||||
52,12,4,42,15,32,25,1,8,5
|
||||
10,22,11,27,42,15,3,0,0,3
|
||||
14,11,11,37,39,30,3,0,0,3
|
||||
24,41,0,38,43,0,10,3,3,4
|
||||
25,22,11,27,36,25,3,0,0,3
|
||||
16,24,11,23,25,25,4,0,0,1
|
||||
23,16,11,30,42,30,3,0,0,3
|
||||
45,24,8,28,14,25,20,0,0,1
|
||||
0,41,0,1,48,0,1,0,0,2
|
||||
22,26,11,28,22,30,3,0,0,3
|
||||
33,17,11,23,41,25,5,0,5,0
|
||||
11,9,11,29,42,30,3,0,6,0
|
||||
14,11,11,30,40,29,3,0,6,0
|
||||
30,4,6,30,40,30,20,0,0,3
|
||||
23,12,11,27,41,30,3,1,6,3
|
||||
5,37,7,39,20,33,3,0,0,3
|
||||
37,9,11,23,40,29,18,0,0,0
|
||||
38,17,11,13,42,25,4,0,5,3
|
||||
22,17,11,28,35,26,3,0,0,3
|
||||
14,22,8,27,42,16,3,0,0,3
|
||||
49,26,11,8,24,11,20,1,6,1
|
||||
33,9,11,20,42,26,17,0,0,0
|
||||
6,31,5,44,0,34,0,4,0,5
|
||||
33,21,11,23,22,25,3,0,0,0
|
||||
35,19,11,23,39,26,10,0,0,0
|
||||
60,20,11,17,31,23,10,0,0,3
|
||||
33,7,11,23,45,26,14,0,3,0
|
||||
48,20,11,13,35,25,6,1,5,3
|
||||
32,23,11,16,42,26,3,0,0,0
|
||||
14,20,11,28,42,30,3,0,0,3
|
||||
22,41,0,43,38,0,23,2,0,4
|
||||
31,17,11,30,30,23,8,0,3,0
|
||||
64,41,5,38,1,32,26,0,0,4
|
||||
55,27,11,16,10,6,22,0,0,0
|
||||
34,26,11,15,33,9,16,0,0,1
|
||||
48,25,12,23,22,21,4,0,0,3
|
||||
48,26,12,23,10,23,4,0,6,3
|
||||
49,26,11,15,23,10,18,0,0,0
|
||||
9,23,11,20,31,25,15,0,0,0
|
||||
63,35,11,2,7,9,24,0,0,0
|
||||
64,27,11,2,6,6,28,0,5,0
|
||||
8,28,0,43,42,10,23,3,2,4
|
||||
49,20,7,23,21,26,20,0,5,0
|
||||
62,35,11,13,5,13,20,0,7,1
|
||||
68,0,0,41,0,26,31,4,6,3
|
||||
58,19,11,12,20,13,27,0,5,0
|
||||
42,41,5,30,22,0,21,0,0,2
|
||||
48,26,11,23,22,25,3,0,5,3
|
||||
48,41,2,30,22,0,27,0,0,2
|
||||
42,22,12,26,20,24,4,0,5,3
|
||||
64,23,11,9,10,10,28,0,2,0
|
||||
22,26,11,27,22,29,3,0,0,3
|
||||
33,7,11,27,42,25,14,0,0,0
|
||||
2,17,11,16,41,27,4,0,6,0
|
||||
22,18,10,23,41,21,15,0,0,1
|
||||
29,16,11,23,41,25,6,0,0,0
|
||||
33,11,11,23,41,26,15,0,0,0
|
||||
31,23,11,21,42,24,3,0,0,0
|
||||
57,39,12,11,5,0,24,0,0,1
|
||||
34,21,9,23,31,26,15,0,0,0
|
||||
58,1,5,29,39,17,30,0,0,5
|
||||
38,27,12,28,8,23,3,0,5,3
|
||||
68,8,0,6,10,2,31,0,0,3
|
||||
33,11,11,27,31,23,10,0,5,0
|
||||
33,17,11,20,41,30,13,0,0,0
|
||||
60,27,3,23,17,15,29,0,0,3
|
||||
22,41,0,36,42,0,17,3,0,4
|
||||
35,9,11,19,40,27,18,0,6,0
|
||||
58,20,11,12,18,13,27,0,5,0
|
||||
49,27,5,19,31,0,27,0,0,2
|
||||
6,41,0,43,46,0,5,2,0,4
|
||||
59,26,11,12,10,11,24,0,3,0
|
||||
30,41,0,33,41,0,17,3,0,4
|
||||
51,29,3,30,6,16,29,0,5,3
|
||||
15,16,11,27,42,16,3,0,0,3
|
||||
48,20,12,19,22,26,6,0,0,3
|
||||
33,29,11,23,30,18,3,0,0,0
|
||||
23,23,11,28,22,30,3,0,6,3
|
||||
64,41,5,23,1,14,17,3,0,4
|
||||
24,41,0,38,42,0,5,3,0,4
|
||||
22,41,0,38,42,0,4,3,0,4
|
||||
4,17,0,44,2,35,2,0,0,5
|
||||
27,17,11,33,28,30,3,0,0,3
|
||||
30,41,0,43,43,0,20,2,0,4
|
||||
49,26,8,20,13,26,20,0,0,0
|
||||
49,8,0,30,47,15,30,0,0,5
|
||||
40,8,6,11,47,15,23,0,5,3
|
||||
18,41,0,43,43,0,18,2,0,4
|
||||
49,29,11,19,13,2,20,0,0,0
|
||||
22,41,0,38,46,0,9,3,0,4
|
||||
26,15,11,23,22,26,19,0,0,1
|
||||
64,26,8,20,22,26,20,0,0,4
|
||||
54,26,5,30,15,22,24,1,5,3
|
||||
47,39,7,43,4,34,0,3,0,4
|
||||
40,27,0,3,47,0,29,0,0,3
|
||||
34,5,6,23,46,25,20,0,6,0
|
||||
23,17,7,20,40,26,19,0,6,3
|
||||
13,16,11,24,43,30,3,0,6,0
|
||||
66,27,7,18,3,5,29,0,0,3
|
||||
68,27,7,6,2,5,30,0,0,3
|
||||
56,17,1,27,45,10,30,0,6,5
|
||||
33,15,11,23,42,23,4,0,5,0
|
||||
22,2,0,19,48,34,20,0,0,4
|
||||
21,34,0,43,22,5,20,3,0,4
|
||||
55,26,13,14,7,2,18,0,0,0
|
||||
33,7,11,23,43,26,10,0,0,0
|
||||
14,17,11,28,42,30,3,0,0,3
|
||||
23,11,11,28,44,30,3,0,0,3
|
||||
53,41,0,38,45,0,8,3,0,4
|
||||
33,9,11,23,42,26,18,0,5,0
|
||||
65,26,0,29,18,15,30,0,0,5
|
||||
33,20,11,13,42,25,3,0,0,0
|
||||
33,26,11,18,41,26,3,0,0,0
|
||||
28,16,11,29,40,26,3,0,0,3
|
||||
61,27,11,7,6,11,25,0,0,0
|
||||
14,20,11,28,40,30,3,0,4,3
|
||||
35,9,11,17,42,26,18,0,0,0
|
||||
49,29,11,23,8,21,18,1,0,0
|
||||
49,27,11,23,6,10,17,2,0,0
|
||||
23,15,0,35,47,33,28,0,0,5
|
||||
22,24,11,29,40,26,3,0,0,3
|
||||
48,16,11,29,22,26,14,0,5,3
|
||||
27,27,11,34,12,29,3,0,0,3
|
||||
54,27,5,27,10,19,27,0,5,3
|
||||
12,41,11,30,8,11,3,0,5,3
|
||||
40,25,12,19,22,26,3,0,0,3
|
||||
1,27,7,34,35,34,0,3,0,4
|
||||
63,35,11,9,5,0,25,0,0,0
|
||||
66,27,0,23,3,13,31,0,5,3
|
||||
40,24,11,22,34,21,3,0,0,3
|
||||
22,25,9,23,23,21,17,0,0,1
|
||||
33,11,11,23,41,27,15,0,0,0
|
||||
6,41,0,43,46,0,4,2,0,4
|
||||
49,9,5,35,26,26,28,0,0,5
|
||||
49,41,11,0,10,1,20,0,0,0
|
||||
48,22,11,23,22,26,6,0,0,3
|
||||
66,0,0,8,42,0,31,0,0,3
|
||||
59,26,11,12,7,13,24,0,5,0
|
||||
15,41,0,43,43,0,18,2,4,4
|
||||
4,17,0,44,2,35,2,0,0,5
|
||||
68,0,0,8,42,0,31,0,0,3
|
||||
63,35,11,2,7,9,24,0,0,0
|
||||
33,11,11,16,42,25,14,0,0,0
|
||||
48,12,11,21,22,27,18,0,6,3
|
||||
22,25,11,22,35,30,3,0,0,3
|
||||
15,41,0,43,42,1,20,2,0,4
|
||||
23,16,11,23,31,25,16,0,0,3
|
||||
12,20,11,29,42,3,5,0,5,3
|
||||
67,29,11,7,6,1,27,0,5,0
|
||||
44,41,0,34,39,34,0,4,0,4
|
||||
49,33,11,23,13,25,4,0,0,0
|
||||
18,28,5,33,42,0,17,3,0,4
|
||||
61,41,11,12,5,11,20,0,0,0
|
||||
41,16,11,23,40,26,6,0,0,0
|
||||
58,0,4,29,45,26,30,0,0,5
|
||||
49,41,0,3,46,0,29,0,0,2
|
||||
19,17,11,27,40,26,3,0,0,3
|
||||
22,25,11,28,24,30,3,0,5,3
|
||||
36,26,8,30,9,25,19,0,4,1
|
||||
63,41,0,13,25,7,30,0,4,3
|
||||
35,9,11,23,40,25,18,0,0,0
|
||||
28,36,0,39,44,0,17,3,0,4
|
||||
31,9,11,23,34,26,18,0,0,0
|
||||
39,25,6,19,36,24,20,0,0,0
|
||||
23,22,11,23,27,25,8,0,5,1
|
||||
52,25,0,24,19,15,30,0,0,5
|
||||
49,26,11,23,10,24,20,0,0,0
|
||||
34,21,6,23,41,21,20,0,4,0
|
||||
49,30,5,29,22,0,24,0,0,2
|
||||
8,41,0,43,42,1,20,2,0,4
|
||||
49,34,0,40,31,0,29,0,0,2
|
||||
48,17,11,13,20,29,20,0,5,3
|
||||
14,17,11,27,42,30,3,0,0,3
|
||||
14,22,11,27,42,26,3,0,0,3
|
||||
22,6,11,36,42,28,3,0,4,3
|
||||
23,16,11,30,40,29,3,0,5,3
|
||||
25,24,11,30,22,30,3,0,0,3
|
||||
48,20,9,19,28,25,20,0,5,0
|
||||
34,26,11,28,11,26,19,0,0,1
|
||||
34,12,11,20,40,26,15,1,5,0
|
||||
48,24,11,27,20,21,16,0,0,3
|
||||
29,25,11,17,38,20,6,0,0,0
|
||||
21,35,0,43,46,1,20,2,4,4
|
||||
19,26,11,28,41,16,3,0,0,0
|
||||
33,11,11,20,42,26,5,0,0,0
|
||||
16,25,10,20,26,26,4,0,0,1
|
||||
14,15,11,41,24,30,3,0,0,3
|
||||
18,29,11,22,40,15,3,0,5,3
|
||||
25,9,7,30,42,30,14,0,0,3
|
||||
48,34,5,30,25,0,21,0,0,2
|
|
@@ -7,11 +7,12 @@ import pandas as pd
|
||||
from sklearn.base import ClassifierMixin, BaseEstimator
|
||||
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
||||
from sklearn.utils.multiclass import unique_labels
|
||||
from sklearn.exceptions import NotFittedError
|
||||
import networkx as nx
|
||||
from pgmpy.estimators import (
|
||||
TreeSearch,
|
||||
BayesianEstimator,
|
||||
MaximumLikelihoodEstimator,
|
||||
# MaximumLikelihoodEstimator,
|
||||
)
|
||||
from pgmpy.models import BayesianNetwork
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -21,10 +22,12 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
"""An example classifier which implements a 1-NN algorithm.
|
||||
For more information regarding how to build your own classifier, read more
|
||||
in the :ref:`User Guide <user_guide>`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
demo_param : str, default='demo'
|
||||
A parameter used for demonstation of how to pass and store paramters.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
X_ : ndarray, shape (n_samples, n_features)
|
||||
@@ -44,6 +47,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
|
||||
def fit(self, X, y, **kwargs):
|
||||
"""A reference implementation of a fitting function for a classifier.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array-like, shape (n_samples, n_features)
|
||||
@@ -55,6 +59,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
features: list (default=None) List of features
|
||||
head: int (default=None) Index of the head node. Default value
|
||||
gets the node with the highest sum of weights (mutual_info)
|
||||
|
||||
Returns
|
||||
-------
|
||||
self : object
|
||||
@@ -86,8 +91,17 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
raise ValueError("Head index out of range")
|
||||
|
||||
self.X_ = X
|
||||
self.y_ = y
|
||||
self.y_ = y.astype(int)
|
||||
self.dataset_ = pd.DataFrame(
|
||||
self.X_, columns=self.features_, dtype="int16"
|
||||
)
|
||||
self.dataset_[self.class_name_] = self.y_
|
||||
try:
|
||||
check_is_fitted(self, ["X_", "y_", "fitted_"])
|
||||
except NotFittedError:
|
||||
self.__build()
|
||||
self.__train()
|
||||
self.fitted_ = True
|
||||
# Return the classifier
|
||||
return self
|
||||
|
||||
@@ -101,6 +115,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
Marco Zaffalon,
|
||||
Learning extended tree augmented naive structures,
|
||||
International Journal of Approximate Reasoning,
|
||||
|
||||
Returns
|
||||
-------
|
||||
List
|
||||
@@ -121,14 +136,12 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
]
|
||||
return list(combinations(reordered, 2))
|
||||
|
||||
def __train(self):
|
||||
def __build(self):
|
||||
# Initialize a Naive Bayes model
|
||||
net = [(self.class_name_, feature) for feature in self.features_]
|
||||
self.model_ = BayesianNetwork(net)
|
||||
# initialize a complete network with all edges
|
||||
self.model_.add_edges_from(self.__initial_edges())
|
||||
self.dataset_ = pd.DataFrame(self.X_, columns=self.features_)
|
||||
self.dataset_[self.class_name_] = self.y_
|
||||
# learn graph structure
|
||||
root_node = None if self.head_ is None else self.features_[self.head_]
|
||||
est = TreeSearch(self.dataset_, root_node=root_node)
|
||||
@@ -139,12 +152,17 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
)
|
||||
if self.head_ is None:
|
||||
self.head_ = est.root_node
|
||||
self.model_ = BayesianNetwork(dag.edges())
|
||||
self.model_ = BayesianNetwork(
|
||||
dag.edges(), show_progress=self.show_progress
|
||||
)
|
||||
|
||||
def __train(self):
|
||||
self.model_.fit(
|
||||
self.dataset_,
|
||||
# estimator=MaximumLikelihoodEstimator,
|
||||
estimator=BayesianEstimator,
|
||||
prior_type="K2",
|
||||
n_jobs=1,
|
||||
)
|
||||
|
||||
def plot(self, title=""):
|
||||
@@ -161,20 +179,54 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
|
||||
def predict(self, X):
|
||||
"""A reference implementation of a prediction for a classifier.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array-like, shape (n_samples, n_features)
|
||||
The input samples.
|
||||
|
||||
Returns
|
||||
-------
|
||||
y : ndarray, shape (n_samples,)
|
||||
The label for each sample is the label of the closest sample
|
||||
seen during fit.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> import pandas as pd
|
||||
>>> from bayesclass import TAN
|
||||
>>> features = ['A', 'B', 'C', 'D', 'E']
|
||||
>>> np.random.seed(17)
|
||||
>>> values = pd.DataFrame(np.random.randint(low=0, high=2,
|
||||
... size=(1000, 5)), columns=features)
|
||||
>>> train_data = values[:800]
|
||||
>>> train_y = train_data['E']
|
||||
>>> predict_data = values[800:]
|
||||
>>> train_data.drop('E', axis=1, inplace=True)
|
||||
>>> model = TAN(random_state=17)
|
||||
>>> features.remove('E')
|
||||
>>> model.fit(train_data, train_y, features=features, class_name='E')
|
||||
TAN(random_state=17)
|
||||
>>> predict_data = predict_data.copy()
|
||||
>>> predict_data.drop('E', axis=1, inplace=True)
|
||||
>>> y_pred = model.predict(predict_data)
|
||||
>>> y_pred[:10]
|
||||
array([[0],
|
||||
[0],
|
||||
[1],
|
||||
[1],
|
||||
[0],
|
||||
[1],
|
||||
[1],
|
||||
[1],
|
||||
[0],
|
||||
[1]])
|
||||
"""
|
||||
# Check is fit had been called
|
||||
check_is_fitted(self, ["X_", "y_"])
|
||||
check_is_fitted(self, ["X_", "y_", "fitted_"])
|
||||
|
||||
# Input validation
|
||||
X = check_array(X)
|
||||
dataset = pd.DataFrame(X, columns=self.features_)
|
||||
return self.model_.predict(dataset).to_numpy()
|
||||
dataset = pd.DataFrame(X, columns=self.features_, dtype="int16")
|
||||
return self.model_.predict(dataset, n_jobs=1).to_numpy()
|
||||
|
1
bayesclass/test.r
Normal file
1
bayesclass/test.r
Normal file
@@ -0,0 +1 @@
|
||||
m0 <- ulam(alist(height ~ dnorm(mu, sigma), mu <- a, a ~ dnorm(186, 10), sigma ~ dexp(1)), data = d, chains = 4, iter = 2000, cores = 4, log_lik=TRUE)
|
215
glass.csv
Normal file
215
glass.csv
Normal file
@@ -0,0 +1,215 @@
|
||||
RI,Na,Mg,Al,Si,'K',Ca,Ba,Fe,Type
|
||||
35,11,11,16,41,29,18,0,0,0
|
||||
22,3,11,23,40,25,13,0,0,1
|
||||
35,21,11,26,23,26,6,0,0,0
|
||||
3,41,5,28,48,0,3,0,0,2
|
||||
68,4,0,13,0,10,31,0,6,3
|
||||
22,9,6,27,42,25,19,1,5,3
|
||||
34,26,11,5,41,2,20,0,0,1
|
||||
46,20,6,23,38,24,20,0,0,0
|
||||
8,34,0,43,45,5,20,2,2,4
|
||||
35,20,12,23,20,24,8,0,6,3
|
||||
22,24,11,27,28,16,3,0,0,3
|
||||
32,4,7,17,46,27,20,0,6,3
|
||||
63,21,11,9,10,10,28,0,0,0
|
||||
57,32,11,5,7,9,24,0,0,1
|
||||
23,20,11,32,22,26,4,0,5,1
|
||||
27,26,11,30,22,27,3,0,0,3
|
||||
28,41,0,38,41,0,13,3,4,4
|
||||
22,9,7,27,43,30,3,0,0,3
|
||||
50,23,0,32,41,17,30,0,0,5
|
||||
40,16,7,27,40,26,19,1,0,3
|
||||
58,19,11,12,18,13,27,0,5,0
|
||||
66,0,0,37,17,33,31,0,6,3
|
||||
48,16,11,15,20,29,20,0,5,3
|
||||
33,25,11,19,34,25,3,0,5,0
|
||||
52,12,4,42,15,32,25,1,8,5
|
||||
10,22,11,27,42,15,3,0,0,3
|
||||
14,11,11,37,39,30,3,0,0,3
|
||||
24,41,0,38,43,0,10,3,3,4
|
||||
25,22,11,27,36,25,3,0,0,3
|
||||
16,24,11,23,25,25,4,0,0,1
|
||||
23,16,11,30,42,30,3,0,0,3
|
||||
45,24,8,28,14,25,20,0,0,1
|
||||
0,41,0,1,48,0,1,0,0,2
|
||||
22,26,11,28,22,30,3,0,0,3
|
||||
33,17,11,23,41,25,5,0,5,0
|
||||
11,9,11,29,42,30,3,0,6,0
|
||||
14,11,11,30,40,29,3,0,6,0
|
||||
30,4,6,30,40,30,20,0,0,3
|
||||
23,12,11,27,41,30,3,1,6,3
|
||||
5,37,7,39,20,33,3,0,0,3
|
||||
37,9,11,23,40,29,18,0,0,0
|
||||
38,17,11,13,42,25,4,0,5,3
|
||||
22,17,11,28,35,26,3,0,0,3
|
||||
14,22,8,27,42,16,3,0,0,3
|
||||
49,26,11,8,24,11,20,1,6,1
|
||||
33,9,11,20,42,26,17,0,0,0
|
||||
6,31,5,44,0,34,0,4,0,5
|
||||
33,21,11,23,22,25,3,0,0,0
|
||||
35,19,11,23,39,26,10,0,0,0
|
||||
60,20,11,17,31,23,10,0,0,3
|
||||
33,7,11,23,45,26,14,0,3,0
|
||||
48,20,11,13,35,25,6,1,5,3
|
||||
32,23,11,16,42,26,3,0,0,0
|
||||
14,20,11,28,42,30,3,0,0,3
|
||||
22,41,0,43,38,0,23,2,0,4
|
||||
31,17,11,30,30,23,8,0,3,0
|
||||
64,41,5,38,1,32,26,0,0,4
|
||||
55,27,11,16,10,6,22,0,0,0
|
||||
34,26,11,15,33,9,16,0,0,1
|
||||
48,25,12,23,22,21,4,0,0,3
|
||||
48,26,12,23,10,23,4,0,6,3
|
||||
49,26,11,15,23,10,18,0,0,0
|
||||
9,23,11,20,31,25,15,0,0,0
|
||||
63,35,11,2,7,9,24,0,0,0
|
||||
64,27,11,2,6,6,28,0,5,0
|
||||
8,28,0,43,42,10,23,3,2,4
|
||||
49,20,7,23,21,26,20,0,5,0
|
||||
62,35,11,13,5,13,20,0,7,1
|
||||
68,0,0,41,0,26,31,4,6,3
|
||||
58,19,11,12,20,13,27,0,5,0
|
||||
42,41,5,30,22,0,21,0,0,2
|
||||
48,26,11,23,22,25,3,0,5,3
|
||||
48,41,2,30,22,0,27,0,0,2
|
||||
42,22,12,26,20,24,4,0,5,3
|
||||
64,23,11,9,10,10,28,0,2,0
|
||||
22,26,11,27,22,29,3,0,0,3
|
||||
33,7,11,27,42,25,14,0,0,0
|
||||
2,17,11,16,41,27,4,0,6,0
|
||||
22,18,10,23,41,21,15,0,0,1
|
||||
29,16,11,23,41,25,6,0,0,0
|
||||
33,11,11,23,41,26,15,0,0,0
|
||||
31,23,11,21,42,24,3,0,0,0
|
||||
57,39,12,11,5,0,24,0,0,1
|
||||
34,21,9,23,31,26,15,0,0,0
|
||||
58,1,5,29,39,17,30,0,0,5
|
||||
38,27,12,28,8,23,3,0,5,3
|
||||
68,8,0,6,10,2,31,0,0,3
|
||||
33,11,11,27,31,23,10,0,5,0
|
||||
33,17,11,20,41,30,13,0,0,0
|
||||
60,27,3,23,17,15,29,0,0,3
|
||||
22,41,0,36,42,0,17,3,0,4
|
||||
35,9,11,19,40,27,18,0,6,0
|
||||
58,20,11,12,18,13,27,0,5,0
|
||||
49,27,5,19,31,0,27,0,0,2
|
||||
6,41,0,43,46,0,5,2,0,4
|
||||
59,26,11,12,10,11,24,0,3,0
|
||||
30,41,0,33,41,0,17,3,0,4
|
||||
51,29,3,30,6,16,29,0,5,3
|
||||
15,16,11,27,42,16,3,0,0,3
|
||||
48,20,12,19,22,26,6,0,0,3
|
||||
33,29,11,23,30,18,3,0,0,0
|
||||
23,23,11,28,22,30,3,0,6,3
|
||||
64,41,5,23,1,14,17,3,0,4
|
||||
24,41,0,38,42,0,5,3,0,4
|
||||
22,41,0,38,42,0,4,3,0,4
|
||||
4,17,0,44,2,35,2,0,0,5
|
||||
27,17,11,33,28,30,3,0,0,3
|
||||
30,41,0,43,43,0,20,2,0,4
|
||||
49,26,8,20,13,26,20,0,0,0
|
||||
49,8,0,30,47,15,30,0,0,5
|
||||
40,8,6,11,47,15,23,0,5,3
|
||||
18,41,0,43,43,0,18,2,0,4
|
||||
49,29,11,19,13,2,20,0,0,0
|
||||
22,41,0,38,46,0,9,3,0,4
|
||||
26,15,11,23,22,26,19,0,0,1
|
||||
64,26,8,20,22,26,20,0,0,4
|
||||
54,26,5,30,15,22,24,1,5,3
|
||||
47,39,7,43,4,34,0,3,0,4
|
||||
40,27,0,3,47,0,29,0,0,3
|
||||
34,5,6,23,46,25,20,0,6,0
|
||||
23,17,7,20,40,26,19,0,6,3
|
||||
13,16,11,24,43,30,3,0,6,0
|
||||
66,27,7,18,3,5,29,0,0,3
|
||||
68,27,7,6,2,5,30,0,0,3
|
||||
56,17,1,27,45,10,30,0,6,5
|
||||
33,15,11,23,42,23,4,0,5,0
|
||||
22,2,0,19,48,34,20,0,0,4
|
||||
21,34,0,43,22,5,20,3,0,4
|
||||
55,26,13,14,7,2,18,0,0,0
|
||||
33,7,11,23,43,26,10,0,0,0
|
||||
14,17,11,28,42,30,3,0,0,3
|
||||
23,11,11,28,44,30,3,0,0,3
|
||||
53,41,0,38,45,0,8,3,0,4
|
||||
33,9,11,23,42,26,18,0,5,0
|
||||
65,26,0,29,18,15,30,0,0,5
|
||||
33,20,11,13,42,25,3,0,0,0
|
||||
33,26,11,18,41,26,3,0,0,0
|
||||
28,16,11,29,40,26,3,0,0,3
|
||||
61,27,11,7,6,11,25,0,0,0
|
||||
14,20,11,28,40,30,3,0,4,3
|
||||
35,9,11,17,42,26,18,0,0,0
|
||||
49,29,11,23,8,21,18,1,0,0
|
||||
49,27,11,23,6,10,17,2,0,0
|
||||
23,15,0,35,47,33,28,0,0,5
|
||||
22,24,11,29,40,26,3,0,0,3
|
||||
48,16,11,29,22,26,14,0,5,3
|
||||
27,27,11,34,12,29,3,0,0,3
|
||||
54,27,5,27,10,19,27,0,5,3
|
||||
12,41,11,30,8,11,3,0,5,3
|
||||
40,25,12,19,22,26,3,0,0,3
|
||||
1,27,7,34,35,34,0,3,0,4
|
||||
63,35,11,9,5,0,25,0,0,0
|
||||
66,27,0,23,3,13,31,0,5,3
|
||||
40,24,11,22,34,21,3,0,0,3
|
||||
22,25,9,23,23,21,17,0,0,1
|
||||
33,11,11,23,41,27,15,0,0,0
|
||||
6,41,0,43,46,0,4,2,0,4
|
||||
49,9,5,35,26,26,28,0,0,5
|
||||
49,41,11,0,10,1,20,0,0,0
|
||||
48,22,11,23,22,26,6,0,0,3
|
||||
66,0,0,8,42,0,31,0,0,3
|
||||
59,26,11,12,7,13,24,0,5,0
|
||||
15,41,0,43,43,0,18,2,4,4
|
||||
4,17,0,44,2,35,2,0,0,5
|
||||
68,0,0,8,42,0,31,0,0,3
|
||||
63,35,11,2,7,9,24,0,0,0
|
||||
33,11,11,16,42,25,14,0,0,0
|
||||
48,12,11,21,22,27,18,0,6,3
|
||||
22,25,11,22,35,30,3,0,0,3
|
||||
15,41,0,43,42,1,20,2,0,4
|
||||
23,16,11,23,31,25,16,0,0,3
|
||||
12,20,11,29,42,3,5,0,5,3
|
||||
67,29,11,7,6,1,27,0,5,0
|
||||
44,41,0,34,39,34,0,4,0,4
|
||||
49,33,11,23,13,25,4,0,0,0
|
||||
18,28,5,33,42,0,17,3,0,4
|
||||
61,41,11,12,5,11,20,0,0,0
|
||||
41,16,11,23,40,26,6,0,0,0
|
||||
58,0,4,29,45,26,30,0,0,5
|
||||
49,41,0,3,46,0,29,0,0,2
|
||||
19,17,11,27,40,26,3,0,0,3
|
||||
22,25,11,28,24,30,3,0,5,3
|
||||
36,26,8,30,9,25,19,0,4,1
|
||||
63,41,0,13,25,7,30,0,4,3
|
||||
35,9,11,23,40,25,18,0,0,0
|
||||
28,36,0,39,44,0,17,3,0,4
|
||||
31,9,11,23,34,26,18,0,0,0
|
||||
39,25,6,19,36,24,20,0,0,0
|
||||
23,22,11,23,27,25,8,0,5,1
|
||||
52,25,0,24,19,15,30,0,0,5
|
||||
49,26,11,23,10,24,20,0,0,0
|
||||
34,21,6,23,41,21,20,0,4,0
|
||||
49,30,5,29,22,0,24,0,0,2
|
||||
8,41,0,43,42,1,20,2,0,4
|
||||
49,34,0,40,31,0,29,0,0,2
|
||||
48,17,11,13,20,29,20,0,5,3
|
||||
14,17,11,27,42,30,3,0,0,3
|
||||
14,22,11,27,42,26,3,0,0,3
|
||||
22,6,11,36,42,28,3,0,4,3
|
||||
23,16,11,30,40,29,3,0,5,3
|
||||
25,24,11,30,22,30,3,0,0,3
|
||||
48,20,9,19,28,25,20,0,5,0
|
||||
34,26,11,28,11,26,19,0,0,1
|
||||
34,12,11,20,40,26,15,1,5,0
|
||||
48,24,11,27,20,21,16,0,0,3
|
||||
29,25,11,17,38,20,6,0,0,0
|
||||
21,35,0,43,46,1,20,2,4,4
|
||||
19,26,11,28,41,16,3,0,0,0
|
||||
33,11,11,20,42,26,5,0,0,0
|
||||
16,25,10,20,26,26,4,0,0,1
|
||||
14,15,11,41,24,30,3,0,0,3
|
||||
18,29,11,22,40,15,3,0,5,3
|
||||
25,9,7,30,42,30,14,0,0,3
|
||||
48,34,5,30,25,0,21,0,0,2
|
|
1193
test.ipynb
1193
test.ipynb
File diff suppressed because one or more lines are too long
111
test.py
Normal file
111
test.py
Normal file
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# In[1]:
|
||||
|
||||
|
||||
from mdlp import MDLP
|
||||
import pandas as pd
|
||||
from benchmark import Datasets
|
||||
from bayesclass import TAN
|
||||
from sklearn.model_selection import (
|
||||
cross_validate,
|
||||
StratifiedKFold,
|
||||
KFold,
|
||||
cross_val_score,
|
||||
train_test_split,
|
||||
)
|
||||
import numpy as np
|
||||
import warnings
|
||||
from stree import Stree
|
||||
|
||||
# In[2]:
|
||||
|
||||
|
||||
# Get data as a dataset
|
||||
dt = Datasets()
|
||||
data = dt.load("glass", dataframe=True)
|
||||
features = dt.dataset.features
|
||||
class_name = dt.dataset.class_name
|
||||
factorization, class_factors = pd.factorize(data[class_name])
|
||||
data[class_name] = factorization
|
||||
data.head()
|
||||
|
||||
|
||||
# In[3]:
|
||||
|
||||
|
||||
# Fayyad Irani
|
||||
discretiz = MDLP()
|
||||
Xdisc = discretiz.fit_transform(
|
||||
data[features].to_numpy(), data[class_name].to_numpy()
|
||||
)
|
||||
features_discretized = pd.DataFrame(Xdisc, columns=features)
|
||||
dataset_discretized = features_discretized.copy()
|
||||
dataset_discretized[class_name] = data[class_name]
|
||||
X = dataset_discretized[features]
|
||||
y = dataset_discretized[class_name]
|
||||
dataset_discretized
|
||||
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
n_folds = 5
|
||||
score_name = "accuracy"
|
||||
random_state = 17
|
||||
test_size = 0.3
|
||||
|
||||
|
||||
def validate_classifier(model, X, y, stratified, fit_params):
|
||||
stratified_class = StratifiedKFold if stratified else KFold
|
||||
kfold = stratified_class(
|
||||
shuffle=True, random_state=random_state, n_splits=n_folds
|
||||
)
|
||||
# return cross_validate(model, X, y, cv=kfold, return_estimator=True,
|
||||
# scoring=score_name)
|
||||
return cross_val_score(model, X, y, fit_params=fit_params)
|
||||
|
||||
|
||||
def split_data(X, y, stratified):
|
||||
if stratified:
|
||||
return train_test_split(
|
||||
X,
|
||||
y,
|
||||
test_size=test_size,
|
||||
random_state=random_state,
|
||||
stratify=y,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
return train_test_split(
|
||||
X, y, test_size=test_size, random_state=random_state, shuffle=True
|
||||
)
|
||||
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
for simple_init in [False, True]:
|
||||
model = TAN(simple_init=simple_init)
|
||||
for head in range(4):
|
||||
X_train, X_test, y_train, y_test = split_data(X, y, stratified=False)
|
||||
model.fit(
|
||||
X_train,
|
||||
y_train,
|
||||
head=head,
|
||||
features=features,
|
||||
class_name=class_name,
|
||||
)
|
||||
y = model.predict(X_test)
|
||||
model.plot()
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
model = TAN(simple_init=simple_init)
|
||||
model.fit(X, y, features=features, class_name=class_name)
|
||||
model.plot(
|
||||
f"**simple_init={simple_init} head={head} score={model.score(X, y)}"
|
||||
)
|
Reference in New Issue
Block a user