diff --git a/bayesclass/clfs.py b/bayesclass/clfs.py index 8fb028a..e1983eb 100644 --- a/bayesclass/clfs.py +++ b/bayesclass/clfs.py @@ -607,25 +607,23 @@ class AODENew(AODE): def states_(self): if hasattr(self, "fitted_"): return sum( - [ - len(item) - for model in self.models_ - for _, item in model.states.items() - ] + [model.estimator.states_ for model in self.models_] ) / len(self.models_) return 0 - @property - def depth_(self): - return self.states_ - def nodes_edges(self): - nodes = 0 - edges = 0 + nodes = [0] + edges = [0] if hasattr(self, "fitted_"): - nodes = sum([len(x.estimator.dag_) for x in self.models_]) - edges = sum([len(x.estimator.dag_.edges()) for x in self.models_]) - return nodes, edges + nodes, edges = zip( + *[model.estimator.nodes_edges() for model in self.models_] + ) + return sum(nodes), sum(edges) + + def plot(self, title=""): + warnings.simplefilter("ignore", UserWarning) + for idx, model in enumerate(self.models_): + model.estimator.plot(title=f"{idx} {title}") class Proposal: diff --git a/bayesclass/tests/baseline_images/test_AODENew/line_dashes_AODENew-expected.png b/bayesclass/tests/baseline_images/test_AODENew/line_dashes_AODENew-expected.png new file mode 100644 index 0000000..054b92f Binary files /dev/null and b/bayesclass/tests/baseline_images/test_AODENew/line_dashes_AODENew-expected.png differ diff --git a/bayesclass/tests/baseline_images/test_AODENew/line_dashes_AODENew.png b/bayesclass/tests/baseline_images/test_AODENew/line_dashes_AODENew.png new file mode 100644 index 0000000..054b92f Binary files /dev/null and b/bayesclass/tests/baseline_images/test_AODENew/line_dashes_AODENew.png differ diff --git a/bayesclass/tests/test_AODENew.py b/bayesclass/tests/test_AODENew.py index 9a7047b..a843326 100644 --- a/bayesclass/tests/test_AODENew.py +++ b/bayesclass/tests/test_AODENew.py @@ -40,7 +40,9 @@ def test_AODENew_default_hyperparameters(data, clf): @image_comparison( - baseline_images=["line_dashes_AODE"], remove_text=True, extensions=["png"] + baseline_images=["line_dashes_AODENew"], + remove_text=True, + extensions=["png"], ) def test_AODENew_plot(data, clf): # mpl_test_settings will automatically clean these internal side effects @@ -64,7 +66,7 @@ def test_AODENew_nodes_edges(clf, data): def test_AODENew_states(clf, data): assert clf.states_ == 0 clf.fit(*data) - assert clf.states_ == 23 + assert clf.states_ == 22.75 assert clf.depth_ == clf.states_