mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-17 16:45:54 +00:00
Fix AODENew tests
This commit is contained in:
@@ -607,25 +607,23 @@ class AODENew(AODE):
|
|||||||
def states_(self):
|
def states_(self):
|
||||||
if hasattr(self, "fitted_"):
|
if hasattr(self, "fitted_"):
|
||||||
return sum(
|
return sum(
|
||||||
[
|
[model.estimator.states_ for model in self.models_]
|
||||||
len(item)
|
|
||||||
for model in self.models_
|
|
||||||
for _, item in model.states.items()
|
|
||||||
]
|
|
||||||
) / len(self.models_)
|
) / len(self.models_)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@property
|
|
||||||
def depth_(self):
|
|
||||||
return self.states_
|
|
||||||
|
|
||||||
def nodes_edges(self):
|
def nodes_edges(self):
|
||||||
nodes = 0
|
nodes = [0]
|
||||||
edges = 0
|
edges = [0]
|
||||||
if hasattr(self, "fitted_"):
|
if hasattr(self, "fitted_"):
|
||||||
nodes = sum([len(x.estimator.dag_) for x in self.models_])
|
nodes, edges = zip(
|
||||||
edges = sum([len(x.estimator.dag_.edges()) for x in self.models_])
|
*[model.estimator.nodes_edges() for model in self.models_]
|
||||||
return nodes, edges
|
)
|
||||||
|
return sum(nodes), sum(edges)
|
||||||
|
|
||||||
|
def plot(self, title=""):
|
||||||
|
warnings.simplefilter("ignore", UserWarning)
|
||||||
|
for idx, model in enumerate(self.models_):
|
||||||
|
model.estimator.plot(title=f"{idx} {title}")
|
||||||
|
|
||||||
|
|
||||||
class Proposal:
|
class Proposal:
|
||||||
|
Binary file not shown.
After Width: | Height: | Size: 55 KiB |
Binary file not shown.
After Width: | Height: | Size: 55 KiB |
@@ -40,7 +40,9 @@ def test_AODENew_default_hyperparameters(data, clf):
|
|||||||
|
|
||||||
|
|
||||||
@image_comparison(
|
@image_comparison(
|
||||||
baseline_images=["line_dashes_AODE"], remove_text=True, extensions=["png"]
|
baseline_images=["line_dashes_AODENew"],
|
||||||
|
remove_text=True,
|
||||||
|
extensions=["png"],
|
||||||
)
|
)
|
||||||
def test_AODENew_plot(data, clf):
|
def test_AODENew_plot(data, clf):
|
||||||
# mpl_test_settings will automatically clean these internal side effects
|
# mpl_test_settings will automatically clean these internal side effects
|
||||||
@@ -64,7 +66,7 @@ def test_AODENew_nodes_edges(clf, data):
|
|||||||
def test_AODENew_states(clf, data):
|
def test_AODENew_states(clf, data):
|
||||||
assert clf.states_ == 0
|
assert clf.states_ == 0
|
||||||
clf.fit(*data)
|
clf.fit(*data)
|
||||||
assert clf.states_ == 23
|
assert clf.states_ == 22.75
|
||||||
assert clf.depth_ == clf.states_
|
assert clf.depth_ == clf.states_
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user