diff --git a/bayesclass/clfs.py b/bayesclass/clfs.py index b368f0b..50c49b8 100644 --- a/bayesclass/clfs.py +++ b/bayesclass/clfs.py @@ -16,6 +16,7 @@ from pgmpy.base import DAG import matplotlib.pyplot as plt from fimdlp.mdlp import FImdlp from .cppSelectFeatures import CSelectKBestWeighted +from .BayesNet import BayesNetwork from ._version import __version__ @@ -166,17 +167,27 @@ class BayesBase(BaseEstimator, ClassifierMixin): kwargs : dict fit parameters """ - self.model_ = BayesianNetwork( - self.dag_.edges(), show_progress=self.show_progress - ) - states = dict(state_names=kwargs.pop("state_names", [])) - self.model_.fit( - self.dataset_, - estimator=BayesianEstimator, - prior_type="K2", - weighted=self.weighted_, - **states, - ) + # self.model_ = BayesianNetwork( + # self.dag_.edges(), show_progress=self.show_progress + # ) + # states = dict(state_names=kwargs.pop("state_names", [])) + # self.model_.fit( + # self.dataset_, + # estimator=BayesianEstimator, + # prior_type="K2", + # weighted=self.weighted_, + # **states, + # ) + self.model_ = BayesNetwork() + features = kwargs["features"] + for i, feature in enumerate(features): + maxf = max(self.X_[:, i] + 1) + self.model_.addNode(feature, maxf) + class_name = kwargs["class_name"] + self.model_.addNode(class_name, max(self.y_) + 1) + for source, destination in self.dag_.edges(): + self.model_.addEdge(source, destination) + self.model_.fit(self.X_, self.y_, features, class_name) def predict(self, X): """A reference implementation of a prediction for a classifier. @@ -228,10 +239,11 @@ class BayesBase(BaseEstimator, ClassifierMixin): check_is_fitted(self, ["X_", "y_", "fitted_"]) # Input validation X = check_array(X) - dataset = pd.DataFrame( - X, columns=self.feature_names_in_, dtype=np.int32 - ) - return self.model_.predict(dataset).values.ravel() + # dataset = pd.DataFrame( + # X, columns=self.feature_names_in_, dtype=np.int32 + # ) + # return self.model_.predict(dataset).values.ravel() + return self.model_.predict(X) def plot(self, title="", node_size=800): warnings.simplefilter("ignore", UserWarning) @@ -398,7 +410,7 @@ class KDB(BayesBase): # 3. Let the used variable list, S, be empty. S_nodes = [] # 4. Let the DAG being constructed, BN, begin with a single class node - dag = BayesianNetwork(show_progress=self.show_progress) + dag = BayesianNetwork() dag.add_node(self.class_name_) # , state_names=self.classes_) # 5. Repeat until S includes all domain features # 5.1 Select feature Xmax which is not in S and has the largest value