From ea251aca0578a573894335855df1887cf1499b47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Thu, 23 Mar 2023 22:15:38 +0100 Subject: [PATCH] Begin AODE implementation --- bayesclass/clfs.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/bayesclass/clfs.py b/bayesclass/clfs.py index b6a83a1..ca0baf0 100644 --- a/bayesclass/clfs.py +++ b/bayesclass/clfs.py @@ -510,7 +510,23 @@ class KDBNew(KDB): class AODENew(AODE): - pass + def fit(self, X, y, **kwargs): + self.estimator = Proposal(self) + return self.estimator.fit(X, y, **kwargs) + + def predict(self, X: np.ndarray) -> np.ndarray: + check_is_fitted(self, ["X_", "y_", "fitted_"]) + # Input validation + X = check_array(X) + n_samples = X.shape[0] + n_estimators = len(self.models_) + result = np.empty((n_samples, n_estimators)) + dataset = pd.DataFrame( + X, columns=self.feature_names_in_, dtype=np.int32 + ) + for index, model in enumerate(self.models_): + result[:, index] = model.predict(dataset).values.ravel() + return mode(result, axis=1, keepdims=False).mode.ravel() class Proposal: