Replace len(self.features_) by self.n_features_in_

This commit is contained in:
2023-01-27 12:34:34 +01:00
parent 4d416959ad
commit a4edc74e8d

View File

@@ -248,8 +248,8 @@ class TAN(BayesBase):
expected_args = ["class_name", "features", "head", "state_names"]
X, y = self._check_params_fit(X, y, expected_args, kwargs)
if self.head_ == "random":
self.head_ = random.randint(0, len(self.features_) - 1)
if self.head_ is not None and self.head_ >= len(self.features_):
self.head_ = random.randint(0, self.n_features_in_ - 1)
if self.head_ is not None and self.head_ >= self.n_features_in_:
raise ValueError("Head index out of range")
return X, y
@@ -398,7 +398,7 @@ class AODE(BayesBase, BaseEnsemble):
self.models_ = []
class_edges = [(self.class_name_, f) for f in self.features_]
states = dict(state_names=kwargs.pop("state_names", []))
for idx in range(len(self.features_)):
for idx in range(self.n_features_in_):
feature_edges = [
(self.features_[idx], f)
for f in self.features_