mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-15 23:55:57 +00:00
Fix AODENew tests
This commit is contained in:
@@ -607,25 +607,23 @@ class AODENew(AODE):
|
||||
def states_(self):
|
||||
if hasattr(self, "fitted_"):
|
||||
return sum(
|
||||
[
|
||||
len(item)
|
||||
for model in self.models_
|
||||
for _, item in model.states.items()
|
||||
]
|
||||
[model.estimator.states_ for model in self.models_]
|
||||
) / len(self.models_)
|
||||
return 0
|
||||
|
||||
@property
|
||||
def depth_(self):
|
||||
return self.states_
|
||||
|
||||
def nodes_edges(self):
|
||||
nodes = 0
|
||||
edges = 0
|
||||
nodes = [0]
|
||||
edges = [0]
|
||||
if hasattr(self, "fitted_"):
|
||||
nodes = sum([len(x.estimator.dag_) for x in self.models_])
|
||||
edges = sum([len(x.estimator.dag_.edges()) for x in self.models_])
|
||||
return nodes, edges
|
||||
nodes, edges = zip(
|
||||
*[model.estimator.nodes_edges() for model in self.models_]
|
||||
)
|
||||
return sum(nodes), sum(edges)
|
||||
|
||||
def plot(self, title=""):
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
for idx, model in enumerate(self.models_):
|
||||
model.estimator.plot(title=f"{idx} {title}")
|
||||
|
||||
|
||||
class Proposal:
|
||||
|
Binary file not shown.
After Width: | Height: | Size: 55 KiB |
Binary file not shown.
After Width: | Height: | Size: 55 KiB |
@@ -40,7 +40,9 @@ def test_AODENew_default_hyperparameters(data, clf):
|
||||
|
||||
|
||||
@image_comparison(
|
||||
baseline_images=["line_dashes_AODE"], remove_text=True, extensions=["png"]
|
||||
baseline_images=["line_dashes_AODENew"],
|
||||
remove_text=True,
|
||||
extensions=["png"],
|
||||
)
|
||||
def test_AODENew_plot(data, clf):
|
||||
# mpl_test_settings will automatically clean these internal side effects
|
||||
@@ -64,7 +66,7 @@ def test_AODENew_nodes_edges(clf, data):
|
||||
def test_AODENew_states(clf, data):
|
||||
assert clf.states_ == 0
|
||||
clf.fit(*data)
|
||||
assert clf.states_ == 23
|
||||
assert clf.states_ == 22.75
|
||||
assert clf.depth_ == clf.states_
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user