From c9ab88e47519a29f90f33a4a014e7a21a9239b2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 17 Mar 2025 13:28:35 +0100 Subject: [PATCH] Update models and remove normalize weights in XA1DE --- src/experimental_clfs/XA1DE.cpp | 4 ++-- src/main/Models.h | 3 ++- src/main/modelRegister.h | 6 ++++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/experimental_clfs/XA1DE.cpp b/src/experimental_clfs/XA1DE.cpp index 24a48e8..ba7dd70 100644 --- a/src/experimental_clfs/XA1DE.cpp +++ b/src/experimental_clfs/XA1DE.cpp @@ -14,7 +14,7 @@ namespace platform { auto y = TensorUtils::to_vector(dataset.index({ -1, "..." })); int num_instances = X[0].size(); weights_ = torch::full({ num_instances }, 1.0); - normalize_weights(num_instances); + //normalize_weights(num_instances); aode_.fit(X, y, features, className, states, weights_, true, smoothing); } -} \ No newline at end of file +} diff --git a/src/main/Models.h b/src/main/Models.h index e0b44c1..565a96d 100644 --- a/src/main/Models.h +++ b/src/main/Models.h @@ -6,13 +6,14 @@ #include #include #include +#include #include #include #include #include #include #include -#include +#include #include #include #include diff --git a/src/main/modelRegister.h b/src/main/modelRegister.h index b8e1f95..0dbd269 100644 --- a/src/main/modelRegister.h +++ b/src/main/modelRegister.h @@ -37,10 +37,12 @@ namespace platform { [](void) -> bayesnet::BaseClassifier* { return new pywrap::XGBoost();}); static Registrar registrarXSPODE("XSPODE", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::XSpode(0);}); - static Registrar registrarXSPnDE("XSPnDE", - [](void) -> bayesnet::BaseClassifier* { return new bayesnet::XSpnde(0, 1);}); + static Registrar registrarXSP2DE("XSP2DE", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::XSp2de(0, 1);}); static Registrar registrarXBAODE("XBAODE", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::XBAODE();}); + static Registrar registrarXBA2DE("XBA2DE", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::XBA2DE();}); static Registrar registrarXA1DE("XA1DE", [](void) -> bayesnet::BaseClassifier* { return new XA1DE();}); }