Continue refactoring
This commit is contained in:
@@ -28,6 +28,7 @@ add_executable(
|
||||
results/Result.cpp
|
||||
experimental_clfs/XA1DE.cpp
|
||||
experimental_clfs/XBAODE.cpp
|
||||
experimental_clfs/ExpClf.cpp
|
||||
)
|
||||
target_link_libraries(b_best Boost::boost "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy "${XLSXWRITER_LIB}")
|
||||
|
||||
@@ -41,6 +42,7 @@ add_executable(b_grid commands/b_grid.cpp ${grid_sources}
|
||||
results/Result.cpp
|
||||
experimental_clfs/XA1DE.cpp
|
||||
experimental_clfs/XBAODE.cpp
|
||||
experimental_clfs/ExpClf.cpp
|
||||
)
|
||||
target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy)
|
||||
|
||||
@@ -52,6 +54,7 @@ add_executable(b_list commands/b_list.cpp
|
||||
results/Result.cpp results/ResultsDatasetExcel.cpp results/ResultsDataset.cpp results/ResultsDatasetConsole.cpp
|
||||
experimental_clfs/XA1DE.cpp
|
||||
experimental_clfs/XBAODE.cpp
|
||||
experimental_clfs/ExpClf.cpp
|
||||
)
|
||||
target_link_libraries(b_list "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy "${XLSXWRITER_LIB}")
|
||||
|
||||
@@ -64,6 +67,7 @@ add_executable(b_main commands/b_main.cpp ${main_sources}
|
||||
results/Result.cpp
|
||||
experimental_clfs/XA1DE.cpp
|
||||
experimental_clfs/XBAODE.cpp
|
||||
experimental_clfs/ExpClf.cpp
|
||||
)
|
||||
target_link_libraries(b_main "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy)
|
||||
|
||||
|
@@ -8,7 +8,7 @@
|
||||
#include "TensorUtils.hpp"
|
||||
|
||||
namespace platform {
|
||||
ExpClf::ExpClf() : semaphore_{ CountingSemaphore::getInstance() }
|
||||
ExpClf::ExpClf() : semaphore_{ CountingSemaphore::getInstance() }, Boost(false)
|
||||
{
|
||||
}
|
||||
void ExpClf::setHyperparameters(const nlohmann::json& hyperparameters)
|
||||
|
@@ -12,14 +12,14 @@
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include "bayesnet/BaseClassifier.h"
|
||||
#include "bayesnet/ensembles/Boost.h"
|
||||
#include "common/Timer.hpp"
|
||||
#include "CountingSemaphore.hpp"
|
||||
#include "Xaode.hpp"
|
||||
|
||||
namespace platform {
|
||||
|
||||
class ExpClf : public bayesnet::BaseClassifier {
|
||||
class ExpClf : public bayesnet::Boost {
|
||||
public:
|
||||
ExpClf();
|
||||
virtual ~ExpClf() = default;
|
||||
|
@@ -13,14 +13,14 @@
|
||||
#include <loguru.hpp>
|
||||
|
||||
namespace platform {
|
||||
XBAODE::XBAODE() : Boost(false)
|
||||
XBAODE::XBAODE()
|
||||
{
|
||||
Boost::validHyperparameters = { "alpha_block", "order", "convergence", "convergence_best", "bisection", "threshold", "maxTolerance",
|
||||
validHyperparameters = { "alpha_block", "order", "convergence", "convergence_best", "bisection", "threshold", "maxTolerance",
|
||||
"predict_voting", "select_features" };
|
||||
}
|
||||
void XBAODE::trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing)
|
||||
{
|
||||
Boost::fitted = true;
|
||||
fitted = true;
|
||||
X_train_ = TensorUtils::to_matrix(X_train);
|
||||
y_train_ = TensorUtils::to_vector<int>(y_train);
|
||||
X_test_ = TensorUtils::to_matrix(X_test);
|
||||
@@ -44,7 +44,7 @@ namespace platform {
|
||||
if (selectFeatures) {
|
||||
featuresUsed = featureSelection(weights_);
|
||||
set_active_parents(featuresUsed);
|
||||
Boost::notes.push_back("Used features in initialization: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()) + " with " + select_features_algorithm);
|
||||
notes.push_back("Used features in initialization: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()) + " with " + select_features_algorithm);
|
||||
auto ypred = ExpClf::predict(X_train);
|
||||
std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
|
||||
// Update significance of the models
|
||||
@@ -146,7 +146,7 @@ namespace platform {
|
||||
}
|
||||
if (tolerance > maxTolerance) {
|
||||
if (numItemsPack < n_models) {
|
||||
Boost::notes.push_back("Convergence threshold reached & " + std::to_string(numItemsPack) + " models eliminated");
|
||||
notes.push_back("Convergence threshold reached & " + std::to_string(numItemsPack) + " models eliminated");
|
||||
VLOG_SCOPE_F(4, "Convergence threshold reached & %d models eliminated of %d", numItemsPack, n_models);
|
||||
for (int i = featuresUsed.size() - 1; i >= featuresUsed.size() - numItemsPack; --i) {
|
||||
aode_.remove_last_parent();
|
||||
@@ -155,15 +155,15 @@ namespace platform {
|
||||
}
|
||||
VLOG_SCOPE_F(4, "*Convergence threshold %d models left & %d features used.", n_models, featuresUsed.size());
|
||||
} else {
|
||||
Boost::notes.push_back("Convergence threshold reached & 0 models eliminated");
|
||||
notes.push_back("Convergence threshold reached & 0 models eliminated");
|
||||
VLOG_SCOPE_F(4, "Convergence threshold reached & 0 models eliminated n_models=%d numItemsPack=%d", n_models, numItemsPack);
|
||||
}
|
||||
}
|
||||
if (featuresUsed.size() != features.size()) {
|
||||
Boost::notes.push_back("Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()));
|
||||
Boost::status = bayesnet::WARNING;
|
||||
notes.push_back("Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()));
|
||||
status = bayesnet::WARNING;
|
||||
}
|
||||
Boost::notes.push_back("Number of models: " + std::to_string(n_models));
|
||||
notes.push_back("Number of models: " + std::to_string(n_models));
|
||||
return;
|
||||
}
|
||||
}
|
@@ -12,11 +12,10 @@
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include "common/Timer.hpp"
|
||||
#include "bayesnet/ensembles/Boost.h"
|
||||
#include "ExpClf.h"
|
||||
|
||||
namespace platform {
|
||||
class XBAODE : public bayesnet::Boost, public ExpClf {
|
||||
class XBAODE : public ExpClf {
|
||||
public:
|
||||
XBAODE();
|
||||
virtual ~XBAODE() = default;
|
||||
|
Reference in New Issue
Block a user