diff --git a/bayesclass/clfs.py b/bayesclass/clfs.py index b9cf8de..0126c36 100644 --- a/bayesclass/clfs.py +++ b/bayesclass/clfs.py @@ -475,7 +475,6 @@ class KDBNew(KDB): return self.estimator.fit(X, y, **kwargs) def predict(self, X): - self.plot() return self.estimator.predict(X) @@ -492,14 +491,14 @@ class Proposal: # Build the model super(self.class_type, self.estimator).fit(self.Xd, y, **kwargs) self.check_integrity("f", self.Xd) - # # Local discretization based on the model - # features = kwargs["features"] - # # assign indices to feature names - # self.idx_features_ = dict(list(zip(features, range(len(features))))) - # upgraded, self.Xd = self._local_discretization() - # if upgraded: - # kwargs = self.update_kwargs(y, kwargs) - # super(self.class_type, self.estimator).fit(self.Xd, y, **kwargs) + # Local discretization based on the model + features = kwargs["features"] + # assign indices to feature names + self.idx_features_ = dict(list(zip(features, range(len(features))))) + upgraded, self.Xd = self._local_discretization() + if upgraded: + kwargs = self.update_kwargs(y, kwargs) + super(self.class_type, self.estimator).fit(self.Xd, y, **kwargs) def predict(self, X): self.check_integrity("p", self.discretizer.transform(X)) @@ -534,14 +533,14 @@ class Proposal: """Discretize each feature with its fathers and the class""" res = self.Xd.copy() upgraded = False - print("-" * 80) + # print("-" * 80) for idx, feature in enumerate(self.estimator.feature_names_in_): fathers = self.estimator.dag_.get_parents(feature) if len(fathers) > 1: - print( - "Discretizing " + feature + " with " + str(fathers), - end=" ", - ) + # print( + # "Discretizing " + feature + " with " + str(fathers), + # end=" ", + # ) # First remove the class name as it will be added later fathers.remove(self.estimator.class_name_) # Get the fathers indices @@ -550,12 +549,12 @@ class Proposal: res[:, idx] = self.discretizer.join_fit( target=idx, features=features, data=self.Xd ) - print(self.discretizer.y_join[:5]) + # print(self.discretizer.y_join[:5]) upgraded = True return upgraded, res def check_integrity(self, source, X): - print(f"Checking integrity of {source} data") + # print(f"Checking integrity of {source} data") for i in range(X.shape[1]): if not set(np.unique(X[:, i]).tolist()).issubset( set(self.state_names_[self.features_[i]])