Fix fit/build/train mistake

2022-11-10 09:37:12 +01:00
parent 02110f7608
commit 180365d727
6 changed files with 1198 additions and 607 deletions

balance-scale.csv (new file, 215 lines)

@@ -0,0 +1,215 @@
RI,Na,Mg,Al,Si,'K',Ca,Ba,Fe,Type
35,11,11,16,41,29,18,0,0,0
22,3,11,23,40,25,13,0,0,1
35,21,11,26,23,26,6,0,0,0
3,41,5,28,48,0,3,0,0,2
68,4,0,13,0,10,31,0,6,3
22,9,6,27,42,25,19,1,5,3
34,26,11,5,41,2,20,0,0,1
46,20,6,23,38,24,20,0,0,0
8,34,0,43,45,5,20,2,2,4
35,20,12,23,20,24,8,0,6,3
22,24,11,27,28,16,3,0,0,3
32,4,7,17,46,27,20,0,6,3
63,21,11,9,10,10,28,0,0,0
57,32,11,5,7,9,24,0,0,1
23,20,11,32,22,26,4,0,5,1
27,26,11,30,22,27,3,0,0,3
28,41,0,38,41,0,13,3,4,4
22,9,7,27,43,30,3,0,0,3
50,23,0,32,41,17,30,0,0,5
40,16,7,27,40,26,19,1,0,3
58,19,11,12,18,13,27,0,5,0
66,0,0,37,17,33,31,0,6,3
48,16,11,15,20,29,20,0,5,3
33,25,11,19,34,25,3,0,5,0
52,12,4,42,15,32,25,1,8,5
10,22,11,27,42,15,3,0,0,3
14,11,11,37,39,30,3,0,0,3
24,41,0,38,43,0,10,3,3,4
25,22,11,27,36,25,3,0,0,3
16,24,11,23,25,25,4,0,0,1
23,16,11,30,42,30,3,0,0,3
45,24,8,28,14,25,20,0,0,1
0,41,0,1,48,0,1,0,0,2
22,26,11,28,22,30,3,0,0,3
33,17,11,23,41,25,5,0,5,0
11,9,11,29,42,30,3,0,6,0
14,11,11,30,40,29,3,0,6,0
30,4,6,30,40,30,20,0,0,3
23,12,11,27,41,30,3,1,6,3
5,37,7,39,20,33,3,0,0,3
37,9,11,23,40,29,18,0,0,0
38,17,11,13,42,25,4,0,5,3
22,17,11,28,35,26,3,0,0,3
14,22,8,27,42,16,3,0,0,3
49,26,11,8,24,11,20,1,6,1
33,9,11,20,42,26,17,0,0,0
6,31,5,44,0,34,0,4,0,5
33,21,11,23,22,25,3,0,0,0
35,19,11,23,39,26,10,0,0,0
60,20,11,17,31,23,10,0,0,3
33,7,11,23,45,26,14,0,3,0
48,20,11,13,35,25,6,1,5,3
32,23,11,16,42,26,3,0,0,0
14,20,11,28,42,30,3,0,0,3
22,41,0,43,38,0,23,2,0,4
31,17,11,30,30,23,8,0,3,0
64,41,5,38,1,32,26,0,0,4
55,27,11,16,10,6,22,0,0,0
34,26,11,15,33,9,16,0,0,1
48,25,12,23,22,21,4,0,0,3
48,26,12,23,10,23,4,0,6,3
49,26,11,15,23,10,18,0,0,0
9,23,11,20,31,25,15,0,0,0
63,35,11,2,7,9,24,0,0,0
64,27,11,2,6,6,28,0,5,0
8,28,0,43,42,10,23,3,2,4
49,20,7,23,21,26,20,0,5,0
62,35,11,13,5,13,20,0,7,1
68,0,0,41,0,26,31,4,6,3
58,19,11,12,20,13,27,0,5,0
42,41,5,30,22,0,21,0,0,2
48,26,11,23,22,25,3,0,5,3
48,41,2,30,22,0,27,0,0,2
42,22,12,26,20,24,4,0,5,3
64,23,11,9,10,10,28,0,2,0
22,26,11,27,22,29,3,0,0,3
33,7,11,27,42,25,14,0,0,0
2,17,11,16,41,27,4,0,6,0
22,18,10,23,41,21,15,0,0,1
29,16,11,23,41,25,6,0,0,0
33,11,11,23,41,26,15,0,0,0
31,23,11,21,42,24,3,0,0,0
57,39,12,11,5,0,24,0,0,1
34,21,9,23,31,26,15,0,0,0
58,1,5,29,39,17,30,0,0,5
38,27,12,28,8,23,3,0,5,3
68,8,0,6,10,2,31,0,0,3
33,11,11,27,31,23,10,0,5,0
33,17,11,20,41,30,13,0,0,0
60,27,3,23,17,15,29,0,0,3
22,41,0,36,42,0,17,3,0,4
35,9,11,19,40,27,18,0,6,0
58,20,11,12,18,13,27,0,5,0
49,27,5,19,31,0,27,0,0,2
6,41,0,43,46,0,5,2,0,4
59,26,11,12,10,11,24,0,3,0
30,41,0,33,41,0,17,3,0,4
51,29,3,30,6,16,29,0,5,3
15,16,11,27,42,16,3,0,0,3
48,20,12,19,22,26,6,0,0,3
33,29,11,23,30,18,3,0,0,0
23,23,11,28,22,30,3,0,6,3
64,41,5,23,1,14,17,3,0,4
24,41,0,38,42,0,5,3,0,4
22,41,0,38,42,0,4,3,0,4
4,17,0,44,2,35,2,0,0,5
27,17,11,33,28,30,3,0,0,3
30,41,0,43,43,0,20,2,0,4
49,26,8,20,13,26,20,0,0,0
49,8,0,30,47,15,30,0,0,5
40,8,6,11,47,15,23,0,5,3
18,41,0,43,43,0,18,2,0,4
49,29,11,19,13,2,20,0,0,0
22,41,0,38,46,0,9,3,0,4
26,15,11,23,22,26,19,0,0,1
64,26,8,20,22,26,20,0,0,4
54,26,5,30,15,22,24,1,5,3
47,39,7,43,4,34,0,3,0,4
40,27,0,3,47,0,29,0,0,3
34,5,6,23,46,25,20,0,6,0
23,17,7,20,40,26,19,0,6,3
13,16,11,24,43,30,3,0,6,0
66,27,7,18,3,5,29,0,0,3
68,27,7,6,2,5,30,0,0,3
56,17,1,27,45,10,30,0,6,5
33,15,11,23,42,23,4,0,5,0
22,2,0,19,48,34,20,0,0,4
21,34,0,43,22,5,20,3,0,4
55,26,13,14,7,2,18,0,0,0
33,7,11,23,43,26,10,0,0,0
14,17,11,28,42,30,3,0,0,3
23,11,11,28,44,30,3,0,0,3
53,41,0,38,45,0,8,3,0,4
33,9,11,23,42,26,18,0,5,0
65,26,0,29,18,15,30,0,0,5
33,20,11,13,42,25,3,0,0,0
33,26,11,18,41,26,3,0,0,0
28,16,11,29,40,26,3,0,0,3
61,27,11,7,6,11,25,0,0,0
14,20,11,28,40,30,3,0,4,3
35,9,11,17,42,26,18,0,0,0
49,29,11,23,8,21,18,1,0,0
49,27,11,23,6,10,17,2,0,0
23,15,0,35,47,33,28,0,0,5
22,24,11,29,40,26,3,0,0,3
48,16,11,29,22,26,14,0,5,3
27,27,11,34,12,29,3,0,0,3
54,27,5,27,10,19,27,0,5,3
12,41,11,30,8,11,3,0,5,3
40,25,12,19,22,26,3,0,0,3
1,27,7,34,35,34,0,3,0,4
63,35,11,9,5,0,25,0,0,0
66,27,0,23,3,13,31,0,5,3
40,24,11,22,34,21,3,0,0,3
22,25,9,23,23,21,17,0,0,1
33,11,11,23,41,27,15,0,0,0
6,41,0,43,46,0,4,2,0,4
49,9,5,35,26,26,28,0,0,5
49,41,11,0,10,1,20,0,0,0
48,22,11,23,22,26,6,0,0,3
66,0,0,8,42,0,31,0,0,3
59,26,11,12,7,13,24,0,5,0
15,41,0,43,43,0,18,2,4,4
4,17,0,44,2,35,2,0,0,5
68,0,0,8,42,0,31,0,0,3
63,35,11,2,7,9,24,0,0,0
33,11,11,16,42,25,14,0,0,0
48,12,11,21,22,27,18,0,6,3
22,25,11,22,35,30,3,0,0,3
15,41,0,43,42,1,20,2,0,4
23,16,11,23,31,25,16,0,0,3
12,20,11,29,42,3,5,0,5,3
67,29,11,7,6,1,27,0,5,0
44,41,0,34,39,34,0,4,0,4
49,33,11,23,13,25,4,0,0,0
18,28,5,33,42,0,17,3,0,4
61,41,11,12,5,11,20,0,0,0
41,16,11,23,40,26,6,0,0,0
58,0,4,29,45,26,30,0,0,5
49,41,0,3,46,0,29,0,0,2
19,17,11,27,40,26,3,0,0,3
22,25,11,28,24,30,3,0,5,3
36,26,8,30,9,25,19,0,4,1
63,41,0,13,25,7,30,0,4,3
35,9,11,23,40,25,18,0,0,0
28,36,0,39,44,0,17,3,0,4
31,9,11,23,34,26,18,0,0,0
39,25,6,19,36,24,20,0,0,0
23,22,11,23,27,25,8,0,5,1
52,25,0,24,19,15,30,0,0,5
49,26,11,23,10,24,20,0,0,0
34,21,6,23,41,21,20,0,4,0
49,30,5,29,22,0,24,0,0,2
8,41,0,43,42,1,20,2,0,4
49,34,0,40,31,0,29,0,0,2
48,17,11,13,20,29,20,0,5,3
14,17,11,27,42,30,3,0,0,3
14,22,11,27,42,26,3,0,0,3
22,6,11,36,42,28,3,0,4,3
23,16,11,30,40,29,3,0,5,3
25,24,11,30,22,30,3,0,0,3
48,20,9,19,28,25,20,0,5,0
34,26,11,28,11,26,19,0,0,1
34,12,11,20,40,26,15,1,5,0
48,24,11,27,20,21,16,0,0,3
29,25,11,17,38,20,6,0,0,0
21,35,0,43,46,1,20,2,4,4
19,26,11,28,41,16,3,0,0,0
33,11,11,20,42,26,5,0,0,0
16,25,10,20,26,26,4,0,0,1
14,15,11,41,24,30,3,0,0,3
18,29,11,22,40,15,3,0,5,3
25,9,7,30,42,30,14,0,0,3
48,34,5,30,25,0,21,0,0,2

Changed file (bayesclass TAN classifier)

@@ -7,11 +7,12 @@ import pandas as pd
 from sklearn.base import ClassifierMixin, BaseEstimator
 from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
 from sklearn.utils.multiclass import unique_labels
+from sklearn.exceptions import NotFittedError
 import networkx as nx
 from pgmpy.estimators import (
     TreeSearch,
     BayesianEstimator,
-    MaximumLikelihoodEstimator,
+    # MaximumLikelihoodEstimator,
 )
 from pgmpy.models import BayesianNetwork
 import matplotlib.pyplot as plt
@@ -21,10 +22,12 @@ class TAN(ClassifierMixin, BaseEstimator):
     """An example classifier which implements a 1-NN algorithm.
     For more information regarding how to build your own classifier, read more
     in the :ref:`User Guide <user_guide>`.
+
     Parameters
     ----------
     demo_param : str, default='demo'
         A parameter used for demonstation of how to pass and store paramters.
+
     Attributes
     ----------
     X_ : ndarray, shape (n_samples, n_features)
@@ -44,6 +47,7 @@ class TAN(ClassifierMixin, BaseEstimator):
     def fit(self, X, y, **kwargs):
         """A reference implementation of a fitting function for a classifier.
+
         Parameters
         ----------
         X : array-like, shape (n_samples, n_features)
@@ -55,6 +59,7 @@ class TAN(ClassifierMixin, BaseEstimator):
         features: list (default=None) List of features
         head: int (default=None) Index of the head node. Default value
             gets the node with the highest sum of weights (mutual_info)
+
         Returns
         -------
         self : object
@@ -86,8 +91,17 @@ class TAN(ClassifierMixin, BaseEstimator):
             raise ValueError("Head index out of range")
         self.X_ = X
-        self.y_ = y
+        self.y_ = y.astype(int)
+        self.dataset_ = pd.DataFrame(
+            self.X_, columns=self.features_, dtype="int16"
+        )
+        self.dataset_[self.class_name_] = self.y_
+        try:
+            check_is_fitted(self, ["X_", "y_", "fitted_"])
+        except NotFittedError:
+            self.__build()
+            self.__train()
+            self.fitted_ = True
         # Return the classifier
         return self
@@ -101,6 +115,7 @@ class TAN(ClassifierMixin, BaseEstimator):
             Marco Zaffalon,
             Learning extended tree augmented naive structures,
             International Journal of Approximate Reasoning,
+
         Returns
         -------
         List
@@ -121,14 +136,12 @@ class TAN(ClassifierMixin, BaseEstimator):
         ]
         return list(combinations(reordered, 2))

-    def __train(self):
+    def __build(self):
         # Initialize a Naive Bayes model
         net = [(self.class_name_, feature) for feature in self.features_]
         self.model_ = BayesianNetwork(net)
         # initialize a complete network with all edges
         self.model_.add_edges_from(self.__initial_edges())
-        self.dataset_ = pd.DataFrame(self.X_, columns=self.features_)
-        self.dataset_[self.class_name_] = self.y_
         # learn graph structure
         root_node = None if self.head_ is None else self.features_[self.head_]
         est = TreeSearch(self.dataset_, root_node=root_node)
@@ -139,12 +152,17 @@ class TAN(ClassifierMixin, BaseEstimator):
         )
         if self.head_ is None:
             self.head_ = est.root_node
-        self.model_ = BayesianNetwork(dag.edges())
+        self.model_ = BayesianNetwork(
+            dag.edges(), show_progress=self.show_progress
+        )
+
+    def __train(self):
         self.model_.fit(
             self.dataset_,
+            # estimator=MaximumLikelihoodEstimator,
             estimator=BayesianEstimator,
             prior_type="K2",
+            n_jobs=1,
         )

     def plot(self, title=""):
@@ -161,20 +179,54 @@ class TAN(ClassifierMixin, BaseEstimator):
     def predict(self, X):
         """A reference implementation of a prediction for a classifier.
+
         Parameters
         ----------
         X : array-like, shape (n_samples, n_features)
             The input samples.
+
         Returns
         -------
         y : ndarray, shape (n_samples,)
             The label for each sample is the label of the closest sample
             seen during fit.
+
+        Examples
+        --------
+        >>> import numpy as np
+        >>> import pandas as pd
+        >>> from bayesclass import TAN
+        >>> features = ['A', 'B', 'C', 'D', 'E']
+        >>> np.random.seed(17)
+        >>> values = pd.DataFrame(np.random.randint(low=0, high=2,
+        ...                       size=(1000, 5)), columns=features)
+        >>> train_data = values[:800]
+        >>> train_y = train_data['E']
+        >>> predict_data = values[800:]
+        >>> train_data.drop('E', axis=1, inplace=True)
+        >>> model = TAN(random_state=17)
+        >>> features.remove('E')
+        >>> model.fit(train_data, train_y, features=features, class_name='E')
+        TAN(random_state=17)
+        >>> predict_data = predict_data.copy()
+        >>> predict_data.drop('E', axis=1, inplace=True)
+        >>> y_pred = model.predict(predict_data)
+        >>> y_pred[:10]
+        array([[0],
+               [0],
+               [1],
+               [1],
+               [0],
+               [1],
+               [1],
+               [1],
+               [0],
+               [1]])
+        """
         # Check is fit had been called
-        check_is_fitted(self, ["X_", "y_"])
+        check_is_fitted(self, ["X_", "y_", "fitted_"])
         # Input validation
         X = check_array(X)
-        dataset = pd.DataFrame(X, columns=self.features_)
-        return self.model_.predict(dataset).to_numpy()
+        dataset = pd.DataFrame(X, columns=self.features_, dtype="int16")
+        return self.model_.predict(dataset, n_jobs=1).to_numpy()
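The net effect of the change is that fit() now performs structure learning (__build) and parameter estimation (__train) exactly once, guarded by check_is_fitted. A minimal stand-alone sketch of that guard idiom, with hypothetical class and method names (not the bayesclass source):

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted


class GuardedClassifier(ClassifierMixin, BaseEstimator):
    """Sketch of the fit/build/train split used in the diff above."""

    def fit(self, X, y):
        self.X_, self.y_ = X, y
        try:
            # Raises NotFittedError until fitted_ has been set once
            check_is_fitted(self, ["X_", "y_", "fitted_"])
        except NotFittedError:
            self._build()  # learn the model structure (once)
            self._train()  # then estimate its parameters
            self.fitted_ = True
        return self

    def _build(self):
        pass  # structure learning would go here

    def _train(self):
        pass  # parameter estimation would go here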

bayesclass/test.r (new file, 1 line)

@@ -0,0 +1 @@
m0 <- ulam(alist(height ~ dnorm(mu, sigma), mu <- a, a ~ dnorm(186, 10), sigma ~ dexp(1)), data = d, chains = 4, iter = 2000, cores = 4, log_lik=TRUE)
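For readers following along in Python rather than R, the same intercept-only Gaussian model can be sketched with PyMC. This translation is an assumption (the repository snippet uses rethinking::ulam), and d stands for the same data frame as in the R call:

import pymc as pm

# Hypothetical PyMC translation of the ulam model above:
# height ~ Normal(mu, sigma), mu = a, a ~ Normal(186, 10), sigma ~ Exponential(1)
with pm.Model() as m0:
    a = pm.Normal("a", mu=186, sigma=10)
    sigma = pm.Exponential("sigma", 1.0)
    height = pm.Normal("height", mu=a, sigma=sigma, observed=d["height"])
    # ulam's iter=2000 includes warmup, so 1000 tuning steps + 1000 kept draws;
    # log_lik=TRUE roughly corresponds to idata_kwargs={"log_likelihood": True}
    idata = pm.sample(draws=1000, tune=1000, chains=4, cores=4)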

glass.csv (new file, 215 lines)

@@ -0,0 +1,215 @@
RI,Na,Mg,Al,Si,'K',Ca,Ba,Fe,Type
35,11,11,16,41,29,18,0,0,0
22,3,11,23,40,25,13,0,0,1
35,21,11,26,23,26,6,0,0,0
3,41,5,28,48,0,3,0,0,2
68,4,0,13,0,10,31,0,6,3
22,9,6,27,42,25,19,1,5,3
34,26,11,5,41,2,20,0,0,1
46,20,6,23,38,24,20,0,0,0
8,34,0,43,45,5,20,2,2,4
35,20,12,23,20,24,8,0,6,3
22,24,11,27,28,16,3,0,0,3
32,4,7,17,46,27,20,0,6,3
63,21,11,9,10,10,28,0,0,0
57,32,11,5,7,9,24,0,0,1
23,20,11,32,22,26,4,0,5,1
27,26,11,30,22,27,3,0,0,3
28,41,0,38,41,0,13,3,4,4
22,9,7,27,43,30,3,0,0,3
50,23,0,32,41,17,30,0,0,5
40,16,7,27,40,26,19,1,0,3
58,19,11,12,18,13,27,0,5,0
66,0,0,37,17,33,31,0,6,3
48,16,11,15,20,29,20,0,5,3
33,25,11,19,34,25,3,0,5,0
52,12,4,42,15,32,25,1,8,5
10,22,11,27,42,15,3,0,0,3
14,11,11,37,39,30,3,0,0,3
24,41,0,38,43,0,10,3,3,4
25,22,11,27,36,25,3,0,0,3
16,24,11,23,25,25,4,0,0,1
23,16,11,30,42,30,3,0,0,3
45,24,8,28,14,25,20,0,0,1
0,41,0,1,48,0,1,0,0,2
22,26,11,28,22,30,3,0,0,3
33,17,11,23,41,25,5,0,5,0
11,9,11,29,42,30,3,0,6,0
14,11,11,30,40,29,3,0,6,0
30,4,6,30,40,30,20,0,0,3
23,12,11,27,41,30,3,1,6,3
5,37,7,39,20,33,3,0,0,3
37,9,11,23,40,29,18,0,0,0
38,17,11,13,42,25,4,0,5,3
22,17,11,28,35,26,3,0,0,3
14,22,8,27,42,16,3,0,0,3
49,26,11,8,24,11,20,1,6,1
33,9,11,20,42,26,17,0,0,0
6,31,5,44,0,34,0,4,0,5
33,21,11,23,22,25,3,0,0,0
35,19,11,23,39,26,10,0,0,0
60,20,11,17,31,23,10,0,0,3
33,7,11,23,45,26,14,0,3,0
48,20,11,13,35,25,6,1,5,3
32,23,11,16,42,26,3,0,0,0
14,20,11,28,42,30,3,0,0,3
22,41,0,43,38,0,23,2,0,4
31,17,11,30,30,23,8,0,3,0
64,41,5,38,1,32,26,0,0,4
55,27,11,16,10,6,22,0,0,0
34,26,11,15,33,9,16,0,0,1
48,25,12,23,22,21,4,0,0,3
48,26,12,23,10,23,4,0,6,3
49,26,11,15,23,10,18,0,0,0
9,23,11,20,31,25,15,0,0,0
63,35,11,2,7,9,24,0,0,0
64,27,11,2,6,6,28,0,5,0
8,28,0,43,42,10,23,3,2,4
49,20,7,23,21,26,20,0,5,0
62,35,11,13,5,13,20,0,7,1
68,0,0,41,0,26,31,4,6,3
58,19,11,12,20,13,27,0,5,0
42,41,5,30,22,0,21,0,0,2
48,26,11,23,22,25,3,0,5,3
48,41,2,30,22,0,27,0,0,2
42,22,12,26,20,24,4,0,5,3
64,23,11,9,10,10,28,0,2,0
22,26,11,27,22,29,3,0,0,3
33,7,11,27,42,25,14,0,0,0
2,17,11,16,41,27,4,0,6,0
22,18,10,23,41,21,15,0,0,1
29,16,11,23,41,25,6,0,0,0
33,11,11,23,41,26,15,0,0,0
31,23,11,21,42,24,3,0,0,0
57,39,12,11,5,0,24,0,0,1
34,21,9,23,31,26,15,0,0,0
58,1,5,29,39,17,30,0,0,5
38,27,12,28,8,23,3,0,5,3
68,8,0,6,10,2,31,0,0,3
33,11,11,27,31,23,10,0,5,0
33,17,11,20,41,30,13,0,0,0
60,27,3,23,17,15,29,0,0,3
22,41,0,36,42,0,17,3,0,4
35,9,11,19,40,27,18,0,6,0
58,20,11,12,18,13,27,0,5,0
49,27,5,19,31,0,27,0,0,2
6,41,0,43,46,0,5,2,0,4
59,26,11,12,10,11,24,0,3,0
30,41,0,33,41,0,17,3,0,4
51,29,3,30,6,16,29,0,5,3
15,16,11,27,42,16,3,0,0,3
48,20,12,19,22,26,6,0,0,3
33,29,11,23,30,18,3,0,0,0
23,23,11,28,22,30,3,0,6,3
64,41,5,23,1,14,17,3,0,4
24,41,0,38,42,0,5,3,0,4
22,41,0,38,42,0,4,3,0,4
4,17,0,44,2,35,2,0,0,5
27,17,11,33,28,30,3,0,0,3
30,41,0,43,43,0,20,2,0,4
49,26,8,20,13,26,20,0,0,0
49,8,0,30,47,15,30,0,0,5
40,8,6,11,47,15,23,0,5,3
18,41,0,43,43,0,18,2,0,4
49,29,11,19,13,2,20,0,0,0
22,41,0,38,46,0,9,3,0,4
26,15,11,23,22,26,19,0,0,1
64,26,8,20,22,26,20,0,0,4
54,26,5,30,15,22,24,1,5,3
47,39,7,43,4,34,0,3,0,4
40,27,0,3,47,0,29,0,0,3
34,5,6,23,46,25,20,0,6,0
23,17,7,20,40,26,19,0,6,3
13,16,11,24,43,30,3,0,6,0
66,27,7,18,3,5,29,0,0,3
68,27,7,6,2,5,30,0,0,3
56,17,1,27,45,10,30,0,6,5
33,15,11,23,42,23,4,0,5,0
22,2,0,19,48,34,20,0,0,4
21,34,0,43,22,5,20,3,0,4
55,26,13,14,7,2,18,0,0,0
33,7,11,23,43,26,10,0,0,0
14,17,11,28,42,30,3,0,0,3
23,11,11,28,44,30,3,0,0,3
53,41,0,38,45,0,8,3,0,4
33,9,11,23,42,26,18,0,5,0
65,26,0,29,18,15,30,0,0,5
33,20,11,13,42,25,3,0,0,0
33,26,11,18,41,26,3,0,0,0
28,16,11,29,40,26,3,0,0,3
61,27,11,7,6,11,25,0,0,0
14,20,11,28,40,30,3,0,4,3
35,9,11,17,42,26,18,0,0,0
49,29,11,23,8,21,18,1,0,0
49,27,11,23,6,10,17,2,0,0
23,15,0,35,47,33,28,0,0,5
22,24,11,29,40,26,3,0,0,3
48,16,11,29,22,26,14,0,5,3
27,27,11,34,12,29,3,0,0,3
54,27,5,27,10,19,27,0,5,3
12,41,11,30,8,11,3,0,5,3
40,25,12,19,22,26,3,0,0,3
1,27,7,34,35,34,0,3,0,4
63,35,11,9,5,0,25,0,0,0
66,27,0,23,3,13,31,0,5,3
40,24,11,22,34,21,3,0,0,3
22,25,9,23,23,21,17,0,0,1
33,11,11,23,41,27,15,0,0,0
6,41,0,43,46,0,4,2,0,4
49,9,5,35,26,26,28,0,0,5
49,41,11,0,10,1,20,0,0,0
48,22,11,23,22,26,6,0,0,3
66,0,0,8,42,0,31,0,0,3
59,26,11,12,7,13,24,0,5,0
15,41,0,43,43,0,18,2,4,4
4,17,0,44,2,35,2,0,0,5
68,0,0,8,42,0,31,0,0,3
63,35,11,2,7,9,24,0,0,0
33,11,11,16,42,25,14,0,0,0
48,12,11,21,22,27,18,0,6,3
22,25,11,22,35,30,3,0,0,3
15,41,0,43,42,1,20,2,0,4
23,16,11,23,31,25,16,0,0,3
12,20,11,29,42,3,5,0,5,3
67,29,11,7,6,1,27,0,5,0
44,41,0,34,39,34,0,4,0,4
49,33,11,23,13,25,4,0,0,0
18,28,5,33,42,0,17,3,0,4
61,41,11,12,5,11,20,0,0,0
41,16,11,23,40,26,6,0,0,0
58,0,4,29,45,26,30,0,0,5
49,41,0,3,46,0,29,0,0,2
19,17,11,27,40,26,3,0,0,3
22,25,11,28,24,30,3,0,5,3
36,26,8,30,9,25,19,0,4,1
63,41,0,13,25,7,30,0,4,3
35,9,11,23,40,25,18,0,0,0
28,36,0,39,44,0,17,3,0,4
31,9,11,23,34,26,18,0,0,0
39,25,6,19,36,24,20,0,0,0
23,22,11,23,27,25,8,0,5,1
52,25,0,24,19,15,30,0,0,5
49,26,11,23,10,24,20,0,0,0
34,21,6,23,41,21,20,0,4,0
49,30,5,29,22,0,24,0,0,2
8,41,0,43,42,1,20,2,0,4
49,34,0,40,31,0,29,0,0,2
48,17,11,13,20,29,20,0,5,3
14,17,11,27,42,30,3,0,0,3
14,22,11,27,42,26,3,0,0,3
22,6,11,36,42,28,3,0,4,3
23,16,11,30,40,29,3,0,5,3
25,24,11,30,22,30,3,0,0,3
48,20,9,19,28,25,20,0,5,0
34,26,11,28,11,26,19,0,0,1
34,12,11,20,40,26,15,1,5,0
48,24,11,27,20,21,16,0,0,3
29,25,11,17,38,20,6,0,0,0
21,35,0,43,46,1,20,2,4,4
19,26,11,28,41,16,3,0,0,0
33,11,11,20,42,26,5,0,0,0
16,25,10,20,26,26,4,0,0,1
14,15,11,41,24,30,3,0,0,3
18,29,11,22,40,15,3,0,5,3
25,9,7,30,42,30,14,0,0,3
48,34,5,30,25,0,21,0,0,2

test.ipynb (1209 changed lines; file diff suppressed because one or more lines are too long)

test.py (new file, 111 lines)

@@ -0,0 +1,111 @@
#!/usr/bin/env python
# coding: utf-8

# In[1]:

from mdlp import MDLP
import pandas as pd
from benchmark import Datasets
from bayesclass import TAN
from sklearn.model_selection import (
    cross_validate,
    StratifiedKFold,
    KFold,
    cross_val_score,
    train_test_split,
)
import numpy as np
import warnings
from stree import Stree

# In[2]:

# Get data as a dataset
dt = Datasets()
data = dt.load("glass", dataframe=True)
features = dt.dataset.features
class_name = dt.dataset.class_name
factorization, class_factors = pd.factorize(data[class_name])
data[class_name] = factorization
data.head()

# In[3]:

# Discretize the features with the Fayyad-Irani MDLP algorithm
discretiz = MDLP()
Xdisc = discretiz.fit_transform(
    data[features].to_numpy(), data[class_name].to_numpy()
)
features_discretized = pd.DataFrame(Xdisc, columns=features)
dataset_discretized = features_discretized.copy()
dataset_discretized[class_name] = data[class_name]
X = dataset_discretized[features]
y = dataset_discretized[class_name]
dataset_discretized

# In[4]:

n_folds = 5
score_name = "accuracy"
random_state = 17
test_size = 0.3


def validate_classifier(model, X, y, stratified, fit_params):
    stratified_class = StratifiedKFold if stratified else KFold
    kfold = stratified_class(
        shuffle=True, random_state=random_state, n_splits=n_folds
    )
    # return cross_validate(model, X, y, cv=kfold, return_estimator=True,
    #                       scoring=score_name)
    # pass the fold splitter built above (it was previously unused)
    return cross_val_score(model, X, y, cv=kfold, fit_params=fit_params)


def split_data(X, y, stratified):
    if stratified:
        return train_test_split(
            X,
            y,
            test_size=test_size,
            random_state=random_state,
            stratify=y,
            shuffle=True,
        )
    else:
        return train_test_split(
            X, y, test_size=test_size, random_state=random_state, shuffle=True
        )

# In[5]:

warnings.filterwarnings("ignore")
for simple_init in [False, True]:
    model = TAN(simple_init=simple_init)
    for head in range(4):
        X_train, X_test, y_train, y_test = split_data(X, y, stratified=False)
        model.fit(
            X_train,
            y_train,
            head=head,
            features=features,
            class_name=class_name,
        )
        # store predictions in y_pred so the labels y stay intact
        # for the final cell below
        y_pred = model.predict(X_test)
        model.plot()

# In[ ]:

model = TAN(simple_init=simple_init)
model.fit(X, y, features=features, class_name=class_name)
model.plot(
    f"**simple_init={simple_init} head={head} score={model.score(X, y)}"
)
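validate_classifier is defined in test.py but never called; a plausible invocation, continuing the names defined above (an assumption about the intended use, not part of the commit), would be:

# Hypothetical follow-up cell: 5-fold stratified cross-validation of TAN,
# forwarding the structure-learning arguments to fit() via fit_params.
scores = validate_classifier(
    TAN(),
    X,
    y,
    stratified=True,
    fit_params={"features": features, "class_name": class_name},
)
print(f"{score_name}: {scores.mean():.4f} (+/- {scores.std():.4f})")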