Add example of usage

2022-11-12 13:16:11 +01:00
parent 180365d727
commit 8ca3646f8c
12 changed files with 88 additions and 1293 deletions

2
.gitignore vendored

@@ -65,3 +65,5 @@ doc/generated/
 # PyBuilder
 target/
+
+.ipynb_checkpoints

Makefile

@@ -20,6 +20,7 @@ push: ## Push code with tags
 	git push && git push --tags

 test: ## Run tests
+	python -m doctest bayesclass/bayesclass.py
 	pytest

 doc: ## Update documentation

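The added line makes `make test` run the module's doctests (including the new fit example below) before the unit tests. A minimal way to reproduce that step from inside Python — a sketch, assuming the module is importable as bayesclass.bayesclass, which matches the path in the recipe:

import doctest

import bayesclass.bayesclass

# Equivalent to `python -m doctest bayesclass/bayesclass.py`;
# prints nothing unless an example fails
results = doctest.testmod(bayesclass.bayesclass)
print(results)  # e.g. TestResults(failed=0, attempted=...)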

bayesclass/bayesclass.py

@@ -7,13 +7,8 @@ import pandas as pd
 from sklearn.base import ClassifierMixin, BaseEstimator
 from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
 from sklearn.utils.multiclass import unique_labels
-from sklearn.exceptions import NotFittedError
 import networkx as nx
-from pgmpy.estimators import (
-    TreeSearch,
-    BayesianEstimator,
-    # MaximumLikelihoodEstimator,
-)
+from pgmpy.estimators import TreeSearch, BayesianEstimator
 from pgmpy.models import BayesianNetwork
 import matplotlib.pyplot as plt
@@ -39,32 +34,13 @@ class TAN(ClassifierMixin, BaseEstimator):
     """

     def __init__(
-        self, simple_init=False, show_progress=False, random_state=None
+        self, simple_init=True, show_progress=False, random_state=None
     ):
         self.simple_init = simple_init
         self.show_progress = show_progress
         self.random_state = random_state

-    def fit(self, X, y, **kwargs):
-        """A reference implementation of a fitting function for a classifier.
-
-        Parameters
-        ----------
-        X : array-like, shape (n_samples, n_features)
-            The training input samples.
-        y : array-like, shape (n_samples,)
-            The target values. An array of int.
-        **kwargs : dict
-            class_name : str (default='class') Name of the class column
-            features: list (default=None) List of features
-            head: int (default=None) Index of the head node. Default value
-                gets the node with the highest sum of weights (mutual_info)
-
-        Returns
-        -------
-        self : object
-            Returns self.
-        """
+    def __check_params_fit(self, X, y, kwargs):
         # Check that X and y have correct shape
         X, y = check_X_y(X, y)
         # Store the classes seen during fit
@@ -90,16 +66,55 @@ class TAN(ClassifierMixin, BaseEstimator):
         if self.head_ is not None and self.head_ >= len(self.features_):
             raise ValueError("Head index out of range")

+    def fit(self, X, y, **kwargs):
+        """A reference implementation of a fitting function for a classifier.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            The training input samples.
+        y : array-like, shape (n_samples,)
+            The target values. An array of int.
+        **kwargs : dict
+            class_name : str (default='class') Name of the class column
+            features: list (default=None) List of features
+            head: int (default=None) Index of the head node. Default value
+                gets the node with the highest sum of weights (mutual_info)
+
+        Returns
+        -------
+        self : object
+            Returns self.
+
+        Examples
+        --------
+        >>> import numpy as np
+        >>> import pandas as pd
+        >>> from bayesclass import TAN
+        >>> features = ['A', 'B', 'C', 'D', 'E']
+        >>> np.random.seed(17)
+        >>> values = pd.DataFrame(np.random.randint(low=0, high=2,
+        ...                       size=(1000, 5)), columns=features)
+        >>> train_data = values[:800]
+        >>> train_y = train_data['E']
+        >>> predict_data = values[800:]
+        >>> train_data = train_data.drop('E', axis=1)
+        >>> model = TAN(random_state=17)
+        >>> features.remove('E')
+        >>> model.fit(train_data, train_y, features=features, class_name='E')
+        TAN(random_state=17)
+        """
+        self.__check_params_fit(X, y, kwargs)
+        # Store the information needed to build the model
         self.X_ = X
         self.y_ = y.astype(int)
         self.dataset_ = pd.DataFrame(
             self.X_, columns=self.features_, dtype="int16"
         )
         self.dataset_[self.class_name_] = self.y_
-        try:
-            check_is_fitted(self, ["X_", "y_", "fitted_"])
-        except NotFittedError:
-            self.__build()
+        # Build the DAG
+        self.__build()
+        # Train the model
         self.__train()
         self.fitted_ = True
         # Return the classifier
@@ -145,24 +160,23 @@ class TAN(ClassifierMixin, BaseEstimator):
         # learn graph structure
         root_node = None if self.head_ is None else self.features_[self.head_]
         est = TreeSearch(self.dataset_, root_node=root_node)
-        dag = est.estimate(
+        self.dag_ = est.estimate(
             estimator_type="tan",
             class_node=self.class_name_,
             show_progress=self.show_progress,
         )
         if self.head_ is None:
             self.head_ = est.root_node
-        self.model_ = BayesianNetwork(
-            dag.edges(), show_progress=self.show_progress
-        )

     def __train(self):
+        self.model_ = BayesianNetwork(
+            self.dag_.edges(), show_progress=self.show_progress
+        )
         self.model_.fit(
             self.dataset_,
-            # estimator=MaximumLikelihoodEstimator,
             estimator=BayesianEstimator,
             prior_type="K2",
+            n_jobs=1,
         )

     def plot(self, title=""):
@@ -203,7 +217,7 @@ class TAN(ClassifierMixin, BaseEstimator):
         >>> train_data = values[:800]
         >>> train_y = train_data['E']
         >>> predict_data = values[800:]
-        >>> train_data.drop('E', axis=1, inplace=True)
+        >>> train_data = train_data.drop('E', axis=1)
         >>> model = TAN(random_state=17)
         >>> features.remove('E')
         >>> model.fit(train_data, train_y, features=features, class_name='E')
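The docstring example added above stops at fit. A short continuation showing the full train/predict round trip — a sketch, not part of the diff; it assumes TAN.predict follows the usual scikit-learn contract of one label per input row (test.py below calls it the same way):

import numpy as np
import pandas as pd
from bayesclass import TAN

features = ['A', 'B', 'C', 'D', 'E']
np.random.seed(17)
values = pd.DataFrame(
    np.random.randint(low=0, high=2, size=(1000, 5)), columns=features
)
train_data, predict_data = values[:800], values[800:]
train_y = train_data['E']  # keep the labels before dropping the column
train_data = train_data.drop('E', axis=1)
features.remove('E')

model = TAN(random_state=17)
model.fit(train_data, train_y, features=features, class_name='E')
y_pred = model.predict(predict_data.drop('E', axis=1))
# Holdout accuracy; ravel() guards against a column-shaped prediction array
print((np.asarray(y_pred).ravel() == predict_data['E'].to_numpy()).mean())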


@@ -1 +0,0 @@
m0 <- ulam(alist(height ~ dnorm(mu, sigma), mu <- a, a ~ dnorm(186, 10), sigma ~ dexp(1)), data = d, chains = 4, iter = 2000, cores = 4, log_lik=TRUE)


@@ -19,7 +19,7 @@ def data():
 def test_TAN_constructor():
     clf = TAN()
     # Test default values of hyperparameters
-    assert not clf.simple_init
+    assert clf.simple_init
     assert not clf.show_progress
     assert clf.random_state is None
     clf = TAN(simple_init=True, show_progress=True, random_state=17)
@@ -34,6 +34,14 @@ def test_TAN_random_head(data):
     assert clf.head_ == 3

+
+def test_TAN_dag_initializer(data):
+    clf_not_simple = TAN(simple_init=False)
+    clf_simple = TAN(simple_init=True)
+    clf_not_simple.fit(*data, head=0)
+    clf_simple.fit(*data, head=0)
+    assert clf_simple.dag_.edges == clf_not_simple.dag_.edges
+
 def test_TAN_classifier(data):
     clf = TAN()
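test_TAN_dag_initializer pins down what the new simple_init default might otherwise obscure: both initialization paths must yield a DAG with the same edges. To run just that test from Python — a sketch, assuming pytest is installed, as the Makefile's test target implies:

import pytest

# -k selects the new structure test by name; exit code 0 means it passed
raise SystemExit(pytest.main(["-k", "test_TAN_dag_initializer", "-v"]))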

24
example.py Normal file

@@ -0,0 +1,24 @@
from benchmark import Discretizer
from bayesclass import TAN
import sys
from sklearn.model_selection import cross_val_score, StratifiedKFold

if len(sys.argv) < 2:
    print("Usage: python3 example.py <dataset> [n_folds]")
    exit(1)

random_state = 17
name = sys.argv[1]
n_folds = int(sys.argv[2]) if len(sys.argv) == 3 else 5
dt = Discretizer()
X, y = dt.load(name)
clf = TAN(random_state=random_state)
fit_params = dict(
    features=dt.get_features(), class_name=dt.get_class_name(), head=0
)
kfold = StratifiedKFold(
    n_splits=n_folds, shuffle=True, random_state=random_state
)
score = cross_val_score(clf, X, y, cv=kfold, fit_params=fit_params)
print(f"Accuracy in {n_folds} folds stratified crossvalidation")
print(f"{name}{'.' * 10}{score.mean():9.7f}")
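Run as the usage string suggests — for example `python3 example.py glass` with the default 5 folds — the script prints output of this shape, per the two print calls (the digits are placeholders; the actual score depends on the benchmark package's Discretizer and the dataset):

Accuracy in 5 folds stratified crossvalidation
glass..........0.xxxxxxx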

215
glass.csv

@@ -1,215 +0,0 @@
RI,Na,Mg,Al,Si,'K',Ca,Ba,Fe,Type
35,11,11,16,41,29,18,0,0,0
22,3,11,23,40,25,13,0,0,1
35,21,11,26,23,26,6,0,0,0
3,41,5,28,48,0,3,0,0,2
68,4,0,13,0,10,31,0,6,3
22,9,6,27,42,25,19,1,5,3
34,26,11,5,41,2,20,0,0,1
46,20,6,23,38,24,20,0,0,0
8,34,0,43,45,5,20,2,2,4
35,20,12,23,20,24,8,0,6,3
22,24,11,27,28,16,3,0,0,3
32,4,7,17,46,27,20,0,6,3
63,21,11,9,10,10,28,0,0,0
57,32,11,5,7,9,24,0,0,1
23,20,11,32,22,26,4,0,5,1
27,26,11,30,22,27,3,0,0,3
28,41,0,38,41,0,13,3,4,4
22,9,7,27,43,30,3,0,0,3
50,23,0,32,41,17,30,0,0,5
40,16,7,27,40,26,19,1,0,3
58,19,11,12,18,13,27,0,5,0
66,0,0,37,17,33,31,0,6,3
48,16,11,15,20,29,20,0,5,3
33,25,11,19,34,25,3,0,5,0
52,12,4,42,15,32,25,1,8,5
10,22,11,27,42,15,3,0,0,3
14,11,11,37,39,30,3,0,0,3
24,41,0,38,43,0,10,3,3,4
25,22,11,27,36,25,3,0,0,3
16,24,11,23,25,25,4,0,0,1
23,16,11,30,42,30,3,0,0,3
45,24,8,28,14,25,20,0,0,1
0,41,0,1,48,0,1,0,0,2
22,26,11,28,22,30,3,0,0,3
33,17,11,23,41,25,5,0,5,0
11,9,11,29,42,30,3,0,6,0
14,11,11,30,40,29,3,0,6,0
30,4,6,30,40,30,20,0,0,3
23,12,11,27,41,30,3,1,6,3
5,37,7,39,20,33,3,0,0,3
37,9,11,23,40,29,18,0,0,0
38,17,11,13,42,25,4,0,5,3
22,17,11,28,35,26,3,0,0,3
14,22,8,27,42,16,3,0,0,3
49,26,11,8,24,11,20,1,6,1
33,9,11,20,42,26,17,0,0,0
6,31,5,44,0,34,0,4,0,5
33,21,11,23,22,25,3,0,0,0
35,19,11,23,39,26,10,0,0,0
60,20,11,17,31,23,10,0,0,3
33,7,11,23,45,26,14,0,3,0
48,20,11,13,35,25,6,1,5,3
32,23,11,16,42,26,3,0,0,0
14,20,11,28,42,30,3,0,0,3
22,41,0,43,38,0,23,2,0,4
31,17,11,30,30,23,8,0,3,0
64,41,5,38,1,32,26,0,0,4
55,27,11,16,10,6,22,0,0,0
34,26,11,15,33,9,16,0,0,1
48,25,12,23,22,21,4,0,0,3
48,26,12,23,10,23,4,0,6,3
49,26,11,15,23,10,18,0,0,0
9,23,11,20,31,25,15,0,0,0
63,35,11,2,7,9,24,0,0,0
64,27,11,2,6,6,28,0,5,0
8,28,0,43,42,10,23,3,2,4
49,20,7,23,21,26,20,0,5,0
62,35,11,13,5,13,20,0,7,1
68,0,0,41,0,26,31,4,6,3
58,19,11,12,20,13,27,0,5,0
42,41,5,30,22,0,21,0,0,2
48,26,11,23,22,25,3,0,5,3
48,41,2,30,22,0,27,0,0,2
42,22,12,26,20,24,4,0,5,3
64,23,11,9,10,10,28,0,2,0
22,26,11,27,22,29,3,0,0,3
33,7,11,27,42,25,14,0,0,0
2,17,11,16,41,27,4,0,6,0
22,18,10,23,41,21,15,0,0,1
29,16,11,23,41,25,6,0,0,0
33,11,11,23,41,26,15,0,0,0
31,23,11,21,42,24,3,0,0,0
57,39,12,11,5,0,24,0,0,1
34,21,9,23,31,26,15,0,0,0
58,1,5,29,39,17,30,0,0,5
38,27,12,28,8,23,3,0,5,3
68,8,0,6,10,2,31,0,0,3
33,11,11,27,31,23,10,0,5,0
33,17,11,20,41,30,13,0,0,0
60,27,3,23,17,15,29,0,0,3
22,41,0,36,42,0,17,3,0,4
35,9,11,19,40,27,18,0,6,0
58,20,11,12,18,13,27,0,5,0
49,27,5,19,31,0,27,0,0,2
6,41,0,43,46,0,5,2,0,4
59,26,11,12,10,11,24,0,3,0
30,41,0,33,41,0,17,3,0,4
51,29,3,30,6,16,29,0,5,3
15,16,11,27,42,16,3,0,0,3
48,20,12,19,22,26,6,0,0,3
33,29,11,23,30,18,3,0,0,0
23,23,11,28,22,30,3,0,6,3
64,41,5,23,1,14,17,3,0,4
24,41,0,38,42,0,5,3,0,4
22,41,0,38,42,0,4,3,0,4
4,17,0,44,2,35,2,0,0,5
27,17,11,33,28,30,3,0,0,3
30,41,0,43,43,0,20,2,0,4
49,26,8,20,13,26,20,0,0,0
49,8,0,30,47,15,30,0,0,5
40,8,6,11,47,15,23,0,5,3
18,41,0,43,43,0,18,2,0,4
49,29,11,19,13,2,20,0,0,0
22,41,0,38,46,0,9,3,0,4
26,15,11,23,22,26,19,0,0,1
64,26,8,20,22,26,20,0,0,4
54,26,5,30,15,22,24,1,5,3
47,39,7,43,4,34,0,3,0,4
40,27,0,3,47,0,29,0,0,3
34,5,6,23,46,25,20,0,6,0
23,17,7,20,40,26,19,0,6,3
13,16,11,24,43,30,3,0,6,0
66,27,7,18,3,5,29,0,0,3
68,27,7,6,2,5,30,0,0,3
56,17,1,27,45,10,30,0,6,5
33,15,11,23,42,23,4,0,5,0
22,2,0,19,48,34,20,0,0,4
21,34,0,43,22,5,20,3,0,4
55,26,13,14,7,2,18,0,0,0
33,7,11,23,43,26,10,0,0,0
14,17,11,28,42,30,3,0,0,3
23,11,11,28,44,30,3,0,0,3
53,41,0,38,45,0,8,3,0,4
33,9,11,23,42,26,18,0,5,0
65,26,0,29,18,15,30,0,0,5
33,20,11,13,42,25,3,0,0,0
33,26,11,18,41,26,3,0,0,0
28,16,11,29,40,26,3,0,0,3
61,27,11,7,6,11,25,0,0,0
14,20,11,28,40,30,3,0,4,3
35,9,11,17,42,26,18,0,0,0
49,29,11,23,8,21,18,1,0,0
49,27,11,23,6,10,17,2,0,0
23,15,0,35,47,33,28,0,0,5
22,24,11,29,40,26,3,0,0,3
48,16,11,29,22,26,14,0,5,3
27,27,11,34,12,29,3,0,0,3
54,27,5,27,10,19,27,0,5,3
12,41,11,30,8,11,3,0,5,3
40,25,12,19,22,26,3,0,0,3
1,27,7,34,35,34,0,3,0,4
63,35,11,9,5,0,25,0,0,0
66,27,0,23,3,13,31,0,5,3
40,24,11,22,34,21,3,0,0,3
22,25,9,23,23,21,17,0,0,1
33,11,11,23,41,27,15,0,0,0
6,41,0,43,46,0,4,2,0,4
49,9,5,35,26,26,28,0,0,5
49,41,11,0,10,1,20,0,0,0
48,22,11,23,22,26,6,0,0,3
66,0,0,8,42,0,31,0,0,3
59,26,11,12,7,13,24,0,5,0
15,41,0,43,43,0,18,2,4,4
4,17,0,44,2,35,2,0,0,5
68,0,0,8,42,0,31,0,0,3
63,35,11,2,7,9,24,0,0,0
33,11,11,16,42,25,14,0,0,0
48,12,11,21,22,27,18,0,6,3
22,25,11,22,35,30,3,0,0,3
15,41,0,43,42,1,20,2,0,4
23,16,11,23,31,25,16,0,0,3
12,20,11,29,42,3,5,0,5,3
67,29,11,7,6,1,27,0,5,0
44,41,0,34,39,34,0,4,0,4
49,33,11,23,13,25,4,0,0,0
18,28,5,33,42,0,17,3,0,4
61,41,11,12,5,11,20,0,0,0
41,16,11,23,40,26,6,0,0,0
58,0,4,29,45,26,30,0,0,5
49,41,0,3,46,0,29,0,0,2
19,17,11,27,40,26,3,0,0,3
22,25,11,28,24,30,3,0,5,3
36,26,8,30,9,25,19,0,4,1
63,41,0,13,25,7,30,0,4,3
35,9,11,23,40,25,18,0,0,0
28,36,0,39,44,0,17,3,0,4
31,9,11,23,34,26,18,0,0,0
39,25,6,19,36,24,20,0,0,0
23,22,11,23,27,25,8,0,5,1
52,25,0,24,19,15,30,0,0,5
49,26,11,23,10,24,20,0,0,0
34,21,6,23,41,21,20,0,4,0
49,30,5,29,22,0,24,0,0,2
8,41,0,43,42,1,20,2,0,4
49,34,0,40,31,0,29,0,0,2
48,17,11,13,20,29,20,0,5,3
14,17,11,27,42,30,3,0,0,3
14,22,11,27,42,26,3,0,0,3
22,6,11,36,42,28,3,0,4,3
23,16,11,30,40,29,3,0,5,3
25,24,11,30,22,30,3,0,0,3
48,20,9,19,28,25,20,0,5,0
34,26,11,28,11,26,19,0,0,1
34,12,11,20,40,26,15,1,5,0
48,24,11,27,20,21,16,0,0,3
29,25,11,17,38,20,6,0,0,0
21,35,0,43,46,1,20,2,4,4
19,26,11,28,41,16,3,0,0,0
33,11,11,20,42,26,5,0,0,0
16,25,10,20,26,26,4,0,0,1
14,15,11,41,24,30,3,0,0,3
18,29,11,22,40,15,3,0,5,3
25,9,7,30,42,30,14,0,0,3
48,34,5,30,25,0,21,0,0,2

Binary file not shown (image, 45 KiB before and after).

File diff suppressed because one or more lines are too long

111
test.py

@@ -1,111 +0,0 @@
#!/usr/bin/env python
# coding: utf-8

# In[1]:

from mdlp import MDLP
import pandas as pd
from benchmark import Datasets
from bayesclass import TAN
from sklearn.model_selection import (
    cross_validate,
    StratifiedKFold,
    KFold,
    cross_val_score,
    train_test_split,
)
import numpy as np
import warnings
from stree import Stree

# In[2]:

# Get data as a dataset
dt = Datasets()
data = dt.load("glass", dataframe=True)
features = dt.dataset.features
class_name = dt.dataset.class_name
factorization, class_factors = pd.factorize(data[class_name])
data[class_name] = factorization
data.head()

# In[3]:

# Fayyad-Irani (MDLP) discretization
discretiz = MDLP()
Xdisc = discretiz.fit_transform(
    data[features].to_numpy(), data[class_name].to_numpy()
)
features_discretized = pd.DataFrame(Xdisc, columns=features)
dataset_discretized = features_discretized.copy()
dataset_discretized[class_name] = data[class_name]
X = dataset_discretized[features]
y = dataset_discretized[class_name]
dataset_discretized

# In[4]:

n_folds = 5
score_name = "accuracy"
random_state = 17
test_size = 0.3


def validate_classifier(model, X, y, stratified, fit_params):
    stratified_class = StratifiedKFold if stratified else KFold
    kfold = stratified_class(
        shuffle=True, random_state=random_state, n_splits=n_folds
    )
    # return cross_validate(model, X, y, cv=kfold, return_estimator=True,
    #                       scoring=score_name)
    return cross_val_score(model, X, y, fit_params=fit_params)


def split_data(X, y, stratified):
    if stratified:
        return train_test_split(
            X,
            y,
            test_size=test_size,
            random_state=random_state,
            stratify=y,
            shuffle=True,
        )
    else:
        return train_test_split(
            X, y, test_size=test_size, random_state=random_state, shuffle=True
        )


# In[5]:

warnings.filterwarnings("ignore")
for simple_init in [False, True]:
    model = TAN(simple_init=simple_init)
    for head in range(4):
        X_train, X_test, y_train, y_test = split_data(X, y, stratified=False)
        model.fit(
            X_train,
            y_train,
            head=head,
            features=features,
            class_name=class_name,
        )
        y = model.predict(X_test)  # note: rebinds y to this fold's predictions
        model.plot()

# In[ ]:

model = TAN(simple_init=simple_init)
model.fit(X, y, features=features, class_name=class_name)
model.plot(
    f"**simple_init={simple_init} head={head} score={model.score(X, y)}"
)