diff --git a/bayesclass/clfs.py b/bayesclass/clfs.py
index 166a2d3..5868027 100644
--- a/bayesclass/clfs.py
+++ b/bayesclass/clfs.py
@@ -47,6 +47,12 @@ class BayesBase(BaseEstimator, ClassifierMixin):
     def default_class_name():
         return "class"
 
+    def build_dataset(self):
+        self.dataset_ = pd.DataFrame(
+            self.X_, columns=self.feature_names_in_, dtype=np.int32
+        )
+        self.dataset_[self.class_name_] = self.y_
+
     def _check_params_fit(self, X, y, expected_args, kwargs):
         """Check the common parameters passed to fit"""
         # Check that X and y have correct shape
@@ -64,6 +70,10 @@ class BayesBase(BaseEstimator, ClassifierMixin):
             else:
                 raise ValueError(f"Unexpected argument: {key}")
         self.feature_names_in_ = self.features_
+        # used for local discretization
+        self.indexed_features_ = {
+            feature: i for i, feature in enumerate(self.features_)
+        }
         if self.random_state is not None:
             random.seed(self.random_state)
         if len(self.feature_names_in_) != X.shape[1]:
@@ -125,10 +135,7 @@ class BayesBase(BaseEstimator, ClassifierMixin):
         # Store the information needed to build the model
         self.X_ = X_
         self.y_ = y_
-        self.dataset_ = pd.DataFrame(
-            self.X_, columns=self.feature_names_in_, dtype=np.int32
-        )
-        self.dataset_[self.class_name_] = self.y_
+        self.build_dataset()
         # Build the DAG
         self._build()
         # Train the model
@@ -660,14 +667,8 @@ class Proposal(BaseEstimator):
         # Build the model
         super(self.class_type, self.estimator).fit(self.Xd, y, **kwargs)
         # Local discretization based on the model
-        features = kwargs["features"]
-        # assign indices to feature names
-        self.idx_features_ = dict(list(zip(features, range(len(features)))))
-        upgraded, self.Xd = self._local_discretization()
+        self._local_discretization()
         # self.check_integrity("fit", self.Xd)
-        if upgraded:
-            kwargs = self.update_kwargs(y, kwargs)
-            super(self.class_type, self.estimator).fit(self.Xd, y, **kwargs)
         self.fitted_ = True
         return self
 
@@ -705,27 +706,47 @@ class Proposal(BaseEstimator):
 
     def _local_discretization(self):
        """Discretize each feature with its fathers and the class"""
-        res = self.Xd.copy()
-        upgraded = False
-        # print("-" * 80)
-        for idx, feature in enumerate(self.estimator.feature_names_in_):
+        upgrade = False
+        # the order of local discretization is important; a plain
+        # 0, 1, 2... scan over the features is not valid.
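+        # Features are visited in the DAG's ancestral (topological) order,
+        # so each feature is re-discretized only after all of its fathers.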
+        ancestral_order = list(nx.topological_sort(self.estimator.dag_))
+        for feature in ancestral_order:
+            if feature == self.estimator.class_name_:
+                continue
+            idx = self.estimator.indexed_features_[feature]
             fathers = self.estimator.dag_.get_parents(feature)
             if len(fathers) > 1:
-                # print(
-                #     "Discretizing " + feature + " with " + str(fathers),
-                #     end=" ",
-                # )
                 # First remove the class name as it will be added later
                 fathers.remove(self.estimator.class_name_)
                 # Get the fathers indices
-                features = [self.idx_features_[f] for f in fathers]
+                features = [
+                    self.estimator.indexed_features_[f] for f in fathers
+                ]
                 # Update the discretization of the feature
-                res[:, idx] = self.discretizer_.join_fit(
-                    target=idx, features=features, data=self.Xd
+                self.Xd[:, idx] = self.discretizer_.join_fit(
+                    # each feature has to use the updated data in self.Xd
+                    target=idx,
+                    features=features,
+                    data=self.Xd,
                 )
-                # print(self.discretizer.y_join[:5])
-                upgraded = True
-        return upgraded, res
+                upgrade = True
+        if upgrade:
+            # Update the dataset
+            self.estimator.X_ = self.Xd
+            self.estimator.build_dataset()
+            self.state_names_ = {
+                key: self.discretizer_.get_states_feature(value)
+                for key, value in self.estimator.indexed_features_.items()
+            }
+            states = {"state_names": self.state_names_}
+            # Update the model
+            self.estimator.model_.fit(
+                self.estimator.dataset_,
+                estimator=BayesianEstimator,
+                prior_type="K2",
+                **states,
+            )
 
     # def check_integrity(self, source, X):
     #     # print(f"Checking integrity of {source} data")
diff --git a/bayesclass/test.py b/bayesclass/test.py
new file mode 100644
index 0000000..fd983d6
--- /dev/null
+++ b/bayesclass/test.py
@@ -0,0 +1,19 @@
+from bayesclass.clfs import AODENew, TANNew, KDBNew, AODE
+from benchmark.datasets import Datasets
+import os
+
+os.chdir("../discretizbench")
+dt = Datasets()
+clfan = AODENew()
+clftn = TANNew()
+clfkn = KDBNew()
+# clfa = AODE()
+X, y = dt.load("iris")
+# clfa.fit(X, y)
+clfan.fit(X, y)
+clftn.fit(X, y)
+clfkn.fit(X, y)
+
+
+# self.discretizer_.target_
+# self.estimator.indexed_features_
diff --git a/bayesclass/tests/baseline_images/test_KDB/line_dashes_KDB.png b/bayesclass/tests/baseline_images/test_KDB/line_dashes_KDB.png
index 45f293b..376c3de 100644
Binary files a/bayesclass/tests/baseline_images/test_KDB/line_dashes_KDB.png and b/bayesclass/tests/baseline_images/test_KDB/line_dashes_KDB.png differ
diff --git a/bayesclass/tests/baseline_images/test_KDBNew/line_dashes_KDBNew.png b/bayesclass/tests/baseline_images/test_KDBNew/line_dashes_KDBNew.png
index 0fd3312..376c3de 100644
Binary files a/bayesclass/tests/baseline_images/test_KDBNew/line_dashes_KDBNew.png and b/bayesclass/tests/baseline_images/test_KDBNew/line_dashes_KDBNew.png differ
diff --git a/bayesclass/tests/baseline_images/test_TANNew/line_dashes_TANNew.png b/bayesclass/tests/baseline_images/test_TANNew/line_dashes_TANNew.png
index 3f55599..b9fe3b0 100644
Binary files a/bayesclass/tests/baseline_images/test_TANNew/line_dashes_TANNew.png and b/bayesclass/tests/baseline_images/test_TANNew/line_dashes_TANNew.png differ
diff --git a/bayesclass/tests/conftest.py b/bayesclass/tests/conftest.py
new file mode 100644
index 0000000..6447b12
--- /dev/null
+++ b/bayesclass/tests/conftest.py
@@ -0,0 +1,39 @@
+import pytest
+from sklearn.datasets import load_iris
+from fimdlp.mdlp import FImdlp
+
+
+@pytest.fixture
+def iris():
+    dataset = load_iris()
+    X = dataset["data"]
+    y = dataset["target"]
+    features = dataset["feature_names"]
+    # patch the dataset so it has the same values as our iris.arff dataset
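+    # maps (row, column) -> (original value, patched value)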
+    patch = {(34, 3): (0.2, 0.1), (37, 1): (3.6, 3.1), (37, 2): (1.4, 1.5)}
+    for key, value in patch.items():
+        X[key] = value[1]
+    return X, y, features
+
+
+@pytest.fixture
+def data(iris):
+    return iris[0], iris[1]
+
+
+@pytest.fixture
+def features(iris):
+    return iris[2]
+
+
+@pytest.fixture
+def class_name():
+    return "class"
+
+
+@pytest.fixture
+def data_disc(data):
+    clf = FImdlp()
+    X, y = data
+    return clf.fit_transform(X, y), y
diff --git a/bayesclass/tests/test_AODE.py b/bayesclass/tests/test_AODE.py
index ff79665..bc1e954 100644
--- a/bayesclass/tests/test_AODE.py
+++ b/bayesclass/tests/test_AODE.py
@@ -1,6 +1,5 @@
 import pytest
 import numpy as np
-from sklearn.datasets import load_iris
 from sklearn.preprocessing import KBinsDiscretizer
 from matplotlib.testing.decorators import image_comparison
 from matplotlib.testing.conftest import mpl_test_settings
@@ -10,26 +9,19 @@ from bayesclass.clfs import AODE
 from .._version import __version__
 
 
-@pytest.fixture
-def data():
-    X, y = load_iris(return_X_y=True)
-    enc = KBinsDiscretizer(encode="ordinal")
-    return enc.fit_transform(X), y
-
-
 @pytest.fixture
 def clf():
     return AODE(random_state=17)
 
 
-def test_AODE_default_hyperparameters(data, clf):
+def test_AODE_default_hyperparameters(data_disc, clf):
     # Test default values of hyperparameters
     assert not clf.show_progress
     assert clf.random_state == 17
     clf = AODE(show_progress=True)
     assert clf.show_progress
     assert clf.random_state is None
-    clf.fit(*data)
+    clf.fit(*data_disc)
     assert clf.class_name_ == "class"
     assert clf.feature_names_in_ == [
         "feature_0",
@@ -42,37 +34,35 @@ def test_AODE_default_hyperparameters(data, clf):
 @image_comparison(
     baseline_images=["line_dashes_AODE"], remove_text=True, extensions=["png"]
 )
-def test_AODE_plot(data, clf):
+def test_AODE_plot(data_disc, features, clf):
     # mpl_test_settings will automatically clean these internal side effects
     mpl_test_settings
-    dataset = load_iris(as_frame=True)
-    clf.fit(*data, features=dataset["feature_names"])
+    clf.fit(*data_disc, features=features)
     clf.plot("AODE Iris")
 
 
-def test_AODE_version(clf, data):
+def test_AODE_version(clf, features, data_disc):
     """Check AODE version."""
     assert __version__ == clf.version()
-    dataset = load_iris(as_frame=True)
-    clf.fit(*data, features=dataset["feature_names"])
+    clf.fit(*data_disc, features=features)
     assert __version__ == clf.version()
 
 
-def test_AODE_nodes_edges(clf, data):
+def test_AODE_nodes_edges(clf, data_disc):
     assert clf.nodes_edges() == (0, 0)
-    clf.fit(*data)
+    clf.fit(*data_disc)
     assert clf.nodes_leaves() == (20, 28)
 
 
-def test_AODE_states(clf, data):
+def test_AODE_states(clf, data_disc):
     assert clf.states_ == 0
-    clf.fit(*data)
-    assert clf.states_ == 23
+    clf.fit(*data_disc)
+    assert clf.states_ == 19
     assert clf.depth_ == clf.states_
 
 
-def test_AODE_classifier(data, clf):
-    clf.fit(*data)
+def test_AODE_classifier(data_disc, clf):
+    clf.fit(*data_disc)
     attribs = [
         "feature_names_in_",
         "class_name_",
@@ -82,28 +72,29 @@ def test_AODE_classifier(data, clf):
     ]
     for attr in attribs:
         assert hasattr(clf, attr)
-    X = data[0]
-    y = data[1]
+    X = data_disc[0]
+    y = data_disc[1]
     y_pred = clf.predict(X)
     assert y_pred.shape == (X.shape[0],)
-    assert sum(y == y_pred) == 147
+    assert sum(y == y_pred) == 146
 
 
-def test_AODE_wrong_num_features(data, clf):
+def test_AODE_wrong_num_features(data_disc, clf):
     with pytest.raises(
         ValueError,
         match="Number of features does not match the number of columns in X",
     ):
-        clf.fit(*data, features=["feature_1", "feature_2"])
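+        # only two feature names are given for the four columns in X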
+        clf.fit(*data_disc, features=["feature_1", "feature_2"])
 
 
-def test_AODE_wrong_hyperparam(data, clf):
+def test_AODE_wrong_hyperparam(data_disc, clf):
     with pytest.raises(ValueError, match="Unexpected argument: wrong_param"):
-        clf.fit(*data, wrong_param="wrong_param")
+        clf.fit(*data_disc, wrong_param="wrong_param")
 
 
-def test_AODE_error_size_predict(data, clf):
-    X, y = data
+def test_AODE_error_size_predict(data_disc, clf):
+    X, y = data_disc
     clf.fit(X, y)
     with pytest.raises(ValueError):
         X_diff_size = np.ones((10, X.shape[1] + 1))
diff --git a/bayesclass/tests/test_AODENew.py b/bayesclass/tests/test_AODENew.py
index 72a73cb..11ee782 100644
--- a/bayesclass/tests/test_AODENew.py
+++ b/bayesclass/tests/test_AODENew.py
@@ -1,7 +1,5 @@
 import pytest
 import numpy as np
-from sklearn.datasets import load_iris
-from sklearn.preprocessing import KBinsDiscretizer
 from matplotlib.testing.decorators import image_comparison
 from matplotlib.testing.conftest import mpl_test_settings
@@ -10,13 +8,6 @@ from bayesclass.clfs import AODENew
 from .._version import __version__
 
 
-@pytest.fixture
-def data():
-    X, y = load_iris(return_X_y=True)
-    enc = KBinsDiscretizer(encode="ordinal")
-    return enc.fit_transform(X), y
-
-
 @pytest.fixture
 def clf():
     return AODENew(random_state=17)
@@ -44,19 +35,17 @@ def test_AODENew_default_hyperparameters(data, clf):
     remove_text=True,
     extensions=["png"],
 )
-def test_AODENew_plot(data, clf):
+def test_AODENew_plot(data, features, clf):
     # mpl_test_settings will automatically clean these internal side effects
     mpl_test_settings
-    dataset = load_iris(as_frame=True)
-    clf.fit(*data, features=dataset["feature_names"])
+    clf.fit(*data, features=features)
     clf.plot("AODE Iris")
 
 
 def test_AODENew_version(clf, data):
     """Check AODENew version."""
     assert __version__ == clf.version()
-    dataset = load_iris(as_frame=True)
-    clf.fit(*data, features=dataset["feature_names"])
+    clf.fit(*data)
     assert __version__ == clf.version()
 
 
@@ -69,7 +58,7 @@ def test_AODENew_nodes_edges(clf, data):
 def test_AODENew_states(clf, data):
     assert clf.states_ == 0
     clf.fit(*data)
-    assert clf.states_ == 22.75
+    assert clf.states_ == 17.75
     assert clf.depth_ == clf.states_
 
 
@@ -88,17 +77,17 @@ def test_AODENew_classifier(data, clf):
     y = data[1]
     y_pred = clf.predict(X)
     assert y_pred.shape == (X.shape[0],)
-    assert sum(y == y_pred) == 147
+    assert sum(y == y_pred) == 146
 
 
-def test_AODENew_local_discretization(clf, data):
+def test_AODENew_local_discretization(clf, data_disc):
     expected_data = [
         [-1, [0, -1], [0, -1], [0, -1]],
         [[1, -1], -1, [1, -1], [1, -1]],
         [[2, -1], [2, -1], -1, [2, -1]],
         [[3, -1], [3, -1], [3, -1], -1],
     ]
-    clf.fit(*data)
+    clf.fit(*data_disc)
     for idx, estimator in enumerate(clf.estimators_):
         expected = expected_data[idx]
         for feature in range(4):
diff --git a/bayesclass/tests/test_KDB.py b/bayesclass/tests/test_KDB.py
index acda865..aa35751 100644
--- a/bayesclass/tests/test_KDB.py
+++ b/bayesclass/tests/test_KDB.py
@@ -1,6 +1,5 @@
 import pytest
 import numpy as np
-from sklearn.datasets import load_iris
 from sklearn.preprocessing import KBinsDiscretizer
 from matplotlib.testing.decorators import image_comparison
 from matplotlib.testing.conftest import mpl_test_settings
@@ -11,19 +10,12 @@ from bayesclass.clfs import KDB
 from .._version import __version__
 
 
-@pytest.fixture
-def data():
-    X, y = load_iris(return_X_y=True)
-    enc = KBinsDiscretizer(encode="ordinal")
-    return enc.fit_transform(X), y
-
-
 @pytest.fixture
 def clf():
     return KDB(k=3)
 
 
-def test_KDB_default_hyperparameters(data, clf):
+def test_KDB_default_hyperparameters(data_disc, clf):
     # Test default values of hyperparameters
     assert not clf.show_progress
     assert clf.random_state is None
@@ -32,7 +24,7 @@ def test_KDB_default_hyperparameters(data, clf):
     assert clf.show_progress
     assert clf.random_state == 17
     assert clf.k == 3
-    clf.fit(*data)
+    clf.fit(*data_disc)
     assert clf.class_name_ == "class"
     assert clf.feature_names_in_ == [
         "feature_0",
@@ -47,57 +39,56 @@ def test_KDB_version(clf):
     assert __version__ == clf.version()
 
 
-def test_KDB_nodes_edges(clf, data):
+def test_KDB_nodes_edges(clf, data_disc):
     assert clf.nodes_edges() == (0, 0)
-    clf.fit(*data)
-    assert clf.nodes_leaves() == (5, 10)
+    clf.fit(*data_disc)
+    assert clf.nodes_leaves() == (5, 9)
 
 
-def test_KDB_states(clf, data):
+def test_KDB_states(clf, data_disc):
     assert clf.states_ == 0
-    clf.fit(*data)
-    assert clf.states_ == 23
+    clf.fit(*data_disc)
+    assert clf.states_ == 19
     assert clf.depth_ == clf.states_
 
 
-def test_KDB_classifier(data, clf):
-    clf.fit(*data)
+def test_KDB_classifier(data_disc, clf):
+    clf.fit(*data_disc)
     attribs = ["classes_", "X_", "y_", "feature_names_in_", "class_name_"]
     for attr in attribs:
         assert hasattr(clf, attr)
-    X = data[0]
-    y = data[1]
+    X = data_disc[0]
+    y = data_disc[1]
     y_pred = clf.predict(X)
     assert y_pred.shape == (X.shape[0],)
-    assert sum(y == y_pred) == 148
+    assert sum(y == y_pred) == 146
 
 
 @image_comparison(
     baseline_images=["line_dashes_KDB"], remove_text=True, extensions=["png"]
 )
-def test_KDB_plot(data, clf):
+def test_KDB_plot(data_disc, features, clf):
     # mpl_test_settings will automatically clean these internal side effects
     mpl_test_settings
-    dataset = load_iris(as_frame=True)
-    clf.fit(*data, features=dataset["feature_names"])
+    clf.fit(*data_disc, features=features)
     clf.plot("KDB Iris")
 
 
-def test_KDB_wrong_num_features(data, clf):
+def test_KDB_wrong_num_features(data_disc, clf):
     with pytest.raises(
         ValueError,
         match="Number of features does not match the number of columns in X",
     ):
-        clf.fit(*data, features=["feature_1", "feature_2"])
+        clf.fit(*data_disc, features=["feature_1", "feature_2"])
 
 
-def test_KDB_wrong_hyperparam(data, clf):
+def test_KDB_wrong_hyperparam(data_disc, clf):
     with pytest.raises(ValueError, match="Unexpected argument: wrong_param"):
-        clf.fit(*data, wrong_param="wrong_param")
+        clf.fit(*data_disc, wrong_param="wrong_param")
 
 
-def test_KDB_error_size_predict(data, clf):
-    X, y = data
+def test_KDB_error_size_predict(data_disc, clf):
+    X, y = data_disc
     clf.fit(X, y)
     with pytest.raises(ValueError):
         X_diff_size = np.ones((10, X.shape[1] + 1))
diff --git a/bayesclass/tests/test_KDBNew.py b/bayesclass/tests/test_KDBNew.py
index b14d731..36de4da 100644
--- a/bayesclass/tests/test_KDBNew.py
+++ b/bayesclass/tests/test_KDBNew.py
@@ -1,7 +1,5 @@
 import pytest
 import numpy as np
-from sklearn.datasets import load_iris
-from sklearn.preprocessing import KBinsDiscretizer
 from matplotlib.testing.decorators import image_comparison
 from matplotlib.testing.conftest import mpl_test_settings
 from pgmpy.models import BayesianNetwork
@@ -11,13 +9,6 @@ from bayesclass.clfs import KDBNew
 from .._version import __version__
 
 
-@pytest.fixture
-def data():
-    X, y = load_iris(return_X_y=True)
-    enc = KBinsDiscretizer(encode="ordinal")
-    return enc.fit_transform(X), y
-
-
 @pytest.fixture
 def clf():
     return KDBNew(k=3)
@@ -50,13 +41,13 @@ def test_KDBNew_version(clf):
 def test_KDBNew_nodes_edges(clf, data):
     assert clf.nodes_edges() == (0, 0)
     clf.fit(*data)
-    assert clf.nodes_leaves() == (5, 10)
+    assert clf.nodes_leaves() == (5, 9)
 
 
 def test_KDBNew_states(clf, data):
     assert clf.states_ == 0
     clf.fit(*data)
-    assert clf.states_ == 23
+    assert clf.states_ == 22
     assert clf.depth_ == clf.states_
 
 
@@ -69,14 +60,14 @@ def test_KDBNew_classifier(data, clf):
     y = data[1]
     y_pred = clf.predict(X)
     assert y_pred.shape == (X.shape[0],)
-    assert sum(y == y_pred) == 148
+    assert sum(y == y_pred) == 145
 
 
 def test_KDBNew_local_discretization(clf, data):
-    expected = [[1, -1], -1, [0, 1, 3, -1], [1, 0, -1]]
+    expected = [[1, -1], -1, [0, 1, 3, -1], [1, -1]]
     clf.fit(*data)
     for feature in range(4):
         computed = clf.estimator_.discretizer_.target_[feature]
         if type(computed) == list:
             for j, k in zip(expected[feature], computed):
                 assert j == k
@@ -92,11 +84,10 @@ def test_KDBNew_local_discretization(clf, data):
     remove_text=True,
     extensions=["png"],
 )
-def test_KDBNew_plot(data, clf):
+def test_KDBNew_plot(data, features, class_name, clf):
     # mpl_test_settings will automatically clean these internal side effects
     mpl_test_settings
-    dataset = load_iris(as_frame=True)
-    clf.fit(*data, features=dataset["feature_names"])
+    clf.fit(*data, features=features, class_name=class_name)
     clf.plot("KDBNew Iris")
 
 
diff --git a/bayesclass/tests/test_TAN.py b/bayesclass/tests/test_TAN.py
index 95ec46a..8f10461 100644
--- a/bayesclass/tests/test_TAN.py
+++ b/bayesclass/tests/test_TAN.py
@@ -1,7 +1,5 @@
 import pytest
 import numpy as np
-from sklearn.datasets import load_iris
-from sklearn.preprocessing import KBinsDiscretizer
 from matplotlib.testing.decorators import image_comparison
 from matplotlib.testing.conftest import mpl_test_settings
@@ -10,26 +8,19 @@ from bayesclass.clfs import TAN
 from .._version import __version__
 
 
-@pytest.fixture
-def data():
-    X, y = load_iris(return_X_y=True)
-    enc = KBinsDiscretizer(encode="ordinal")
-    return enc.fit_transform(X), y
-
-
 @pytest.fixture
 def clf():
     return TAN(random_state=17)
 
 
-def test_TAN_default_hyperparameters(data, clf):
+def test_TAN_default_hyperparameters(data_disc, clf):
     # Test default values of hyperparameters
     assert not clf.show_progress
     assert clf.random_state == 17
     clf = TAN(show_progress=True)
     assert clf.show_progress
     assert clf.random_state is None
-    clf.fit(*data)
+    clf.fit(*data_disc)
     assert clf.head_ == 0
     assert clf.class_name_ == "class"
     assert clf.feature_names_in_ == [
@@ -45,26 +36,26 @@ def test_TAN_version(clf):
     assert __version__ == clf.version()
 
 
-def test_TAN_nodes_edges(clf, data):
+def test_TAN_nodes_edges(clf, data_disc):
     assert clf.nodes_edges() == (0, 0)
-    clf.fit(*data, head="random")
+    clf.fit(*data_disc, head="random")
     assert clf.nodes_leaves() == (5, 7)
 
 
-def test_TAN_states(clf, data):
+def test_TAN_states(clf, data_disc):
     assert clf.states_ == 0
-    clf.fit(*data)
-    assert clf.states_ == 23
+    clf.fit(*data_disc)
+    assert clf.states_ == 19
     assert clf.depth_ == clf.states_
 
 
-def test_TAN_random_head(clf, data):
-    clf.fit(*data, head="random")
+def test_TAN_random_head(clf, data_disc):
+    clf.fit(*data_disc, head="random")
     assert clf.head_ == 3
 
 
-def test_TAN_classifier(data, clf):
-    clf.fit(*data)
+def test_TAN_classifier(data_disc, clf):
+    clf.fit(*data_disc)
     attribs = [
         "classes_",
         "X_",
@@ -75,44 +66,43 @@ def test_TAN_classifier(data, clf):
     ]
     for attr in attribs:
         assert hasattr(clf, attr)
-    X = data[0]
-    y = data[1]
+    X = data_disc[0]
+    y = data_disc[1]
     y_pred = clf.predict(X)
     assert y_pred.shape == (X.shape[0],)
-    assert sum(y == y_pred) == 147
+    assert sum(y == y_pred) == 146
 
 
 @image_comparison(
     baseline_images=["line_dashes_TAN"], remove_text=True, extensions=["png"]
extensions=["png"] ) -def test_TAN_plot(data, clf): +def test_TAN_plot(data_disc, features, clf): # mpl_test_settings will automatically clean these internal side effects mpl_test_settings - dataset = load_iris(as_frame=True) - clf.fit(*data, features=dataset["feature_names"], head=0) + clf.fit(*data_disc, features=features, head=0) clf.plot("TAN Iris head=0") -def test_TAN_wrong_num_features(data, clf): +def test_TAN_wrong_num_features(data_disc, clf): with pytest.raises( ValueError, match="Number of features does not match the number of columns in X", ): - clf.fit(*data, features=["feature_1", "feature_2"]) + clf.fit(*data_disc, features=["feature_1", "feature_2"]) -def test_TAN_wrong_hyperparam(data, clf): +def test_TAN_wrong_hyperparam(data_disc, clf): with pytest.raises(ValueError, match="Unexpected argument: wrong_param"): - clf.fit(*data, wrong_param="wrong_param") + clf.fit(*data_disc, wrong_param="wrong_param") -def test_TAN_head_out_of_range(data, clf): +def test_TAN_head_out_of_range(data_disc, clf): with pytest.raises(ValueError, match="Head index out of range"): - clf.fit(*data, head=4) + clf.fit(*data_disc, head=4) -def test_TAN_error_size_predict(data, clf): - X, y = data +def test_TAN_error_size_predict(data_disc, clf): + X, y = data_disc clf.fit(X, y) with pytest.raises(ValueError): X_diff_size = np.ones((10, X.shape[1] + 1)) diff --git a/bayesclass/tests/test_TANNew.py b/bayesclass/tests/test_TANNew.py index 406222e..506330d 100644 --- a/bayesclass/tests/test_TANNew.py +++ b/bayesclass/tests/test_TANNew.py @@ -1,7 +1,5 @@ import pytest import numpy as np -from sklearn.datasets import load_iris -from sklearn.preprocessing import KBinsDiscretizer from matplotlib.testing.decorators import image_comparison from matplotlib.testing.conftest import mpl_test_settings @@ -10,13 +8,6 @@ from bayesclass.clfs import TANNew from .._version import __version__ -@pytest.fixture -def data(): - X, y = load_iris(return_X_y=True) - enc = KBinsDiscretizer(encode="ordinal") - return enc.fit_transform(X), y - - @pytest.fixture def clf(): return TANNew(random_state=17) @@ -54,7 +45,7 @@ def test_TANNew_nodes_edges(clf, data): def test_TANNew_states(clf, data): assert clf.states_ == 0 clf.fit(*data) - assert clf.states_ == 22 + assert clf.states_ == 18 assert clf.depth_ == clf.states_ @@ -88,7 +79,7 @@ def test_TANNew_classifier(data, clf): y = data[1] y_pred = clf.predict(X) assert y_pred.shape == (X.shape[0],) - assert sum(y == y_pred) == 145 + assert sum(y == y_pred) == 146 @image_comparison( @@ -96,11 +87,10 @@ def test_TANNew_classifier(data, clf): remove_text=True, extensions=["png"], ) -def test_TANNew_plot(data, clf): +def test_TANNew_plot(data, features, clf): # mpl_test_settings will automatically clean these internal side effects mpl_test_settings - dataset = load_iris(as_frame=True) - clf.fit(*data, features=dataset["feature_names"], head=0) + clf.fit(*data, features=features, head=0) clf.plot("TANNew Iris head=0")