Fix AODENew tests

This commit is contained in:
2023-03-30 21:03:42 +02:00
parent 3af05c9511
commit c9afafbf60
4 changed files with 16 additions and 16 deletions

View File

@@ -607,25 +607,23 @@ class AODENew(AODE):
def states_(self):
if hasattr(self, "fitted_"):
return sum(
[
len(item)
for model in self.models_
for _, item in model.states.items()
]
[model.estimator.states_ for model in self.models_]
) / len(self.models_)
return 0
@property
def depth_(self):
return self.states_
def nodes_edges(self):
nodes = 0
edges = 0
nodes = [0]
edges = [0]
if hasattr(self, "fitted_"):
nodes = sum([len(x.estimator.dag_) for x in self.models_])
edges = sum([len(x.estimator.dag_.edges()) for x in self.models_])
return nodes, edges
nodes, edges = zip(
*[model.estimator.nodes_edges() for model in self.models_]
)
return sum(nodes), sum(edges)
def plot(self, title=""):
warnings.simplefilter("ignore", UserWarning)
for idx, model in enumerate(self.models_):
model.estimator.plot(title=f"{idx} {title}")
class Proposal:

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 KiB

View File

@@ -40,7 +40,9 @@ def test_AODENew_default_hyperparameters(data, clf):
@image_comparison(
baseline_images=["line_dashes_AODE"], remove_text=True, extensions=["png"]
baseline_images=["line_dashes_AODENew"],
remove_text=True,
extensions=["png"],
)
def test_AODENew_plot(data, clf):
# mpl_test_settings will automatically clean these internal side effects
@@ -64,7 +66,7 @@ def test_AODENew_nodes_edges(clf, data):
def test_AODENew_states(clf, data):
assert clf.states_ == 0
clf.fit(*data)
assert clf.states_ == 23
assert clf.states_ == 22.75
assert clf.depth_ == clf.states_