Refactor kdb with new BayesNetwork

This commit is contained in:
2023-07-08 10:40:33 +02:00
parent 260997c872
commit 36cc875615

View File

@@ -16,6 +16,7 @@ from pgmpy.base import DAG
import matplotlib.pyplot as plt
from fimdlp.mdlp import FImdlp
from .cppSelectFeatures import CSelectKBestWeighted
from .BayesNet import BayesNetwork
from ._version import __version__
@@ -166,17 +167,27 @@ class BayesBase(BaseEstimator, ClassifierMixin):
kwargs : dict
fit parameters
"""
self.model_ = BayesianNetwork(
self.dag_.edges(), show_progress=self.show_progress
)
states = dict(state_names=kwargs.pop("state_names", []))
self.model_.fit(
self.dataset_,
estimator=BayesianEstimator,
prior_type="K2",
weighted=self.weighted_,
**states,
)
# self.model_ = BayesianNetwork(
# self.dag_.edges(), show_progress=self.show_progress
# )
# states = dict(state_names=kwargs.pop("state_names", []))
# self.model_.fit(
# self.dataset_,
# estimator=BayesianEstimator,
# prior_type="K2",
# weighted=self.weighted_,
# **states,
# )
self.model_ = BayesNetwork()
features = kwargs["features"]
for i, feature in enumerate(features):
maxf = max(self.X_[:, i] + 1)
self.model_.addNode(feature, maxf)
class_name = kwargs["class_name"]
self.model_.addNode(class_name, max(self.y_) + 1)
for source, destination in self.dag_.edges():
self.model_.addEdge(source, destination)
self.model_.fit(self.X_, self.y_, features, class_name)
def predict(self, X):
"""A reference implementation of a prediction for a classifier.
@@ -228,10 +239,11 @@ class BayesBase(BaseEstimator, ClassifierMixin):
check_is_fitted(self, ["X_", "y_", "fitted_"])
# Input validation
X = check_array(X)
dataset = pd.DataFrame(
X, columns=self.feature_names_in_, dtype=np.int32
)
return self.model_.predict(dataset).values.ravel()
# dataset = pd.DataFrame(
# X, columns=self.feature_names_in_, dtype=np.int32
# )
# return self.model_.predict(dataset).values.ravel()
return self.model_.predict(X)
def plot(self, title="", node_size=800):
warnings.simplefilter("ignore", UserWarning)
@@ -398,7 +410,7 @@ class KDB(BayesBase):
# 3. Let the used variable list, S, be empty.
S_nodes = []
# 4. Let the DAG being constructed, BN, begin with a single class node
dag = BayesianNetwork(show_progress=self.show_progress)
dag = BayesianNetwork()
dag.add_node(self.class_name_) # , state_names=self.classes_)
# 5. Repeat until S includes all domain features
# 5.1 Select feature Xmax which is not in S and has the largest value