diff --git a/Makefile b/Makefile
index bad7fa9..5563222 100644
--- a/Makefile
+++ b/Makefile
@@ -237,7 +237,7 @@ sample: ## Build sample with Conan
 	@if [ -d ./sample/build ]; then rm -rf ./sample/build; fi
 	@cd sample && conan install . --output-folder=build --build=missing -s build_type=$(build_type) -o "&:enable_coverage=False" -o "&:enable_testing=False"
 	@cd sample && cmake -B build -S . -DCMAKE_BUILD_TYPE=$(build_type) -DCMAKE_TOOLCHAIN_FILE=build/conan_toolchain.cmake && \
-	cmake --build build -t bayesnet_sample
+	cmake --build build -t bayesnet_sample --parallel $(JOBS)
 	sample/build/bayesnet_sample $(fname) $(model)
 	@echo ">>> Done";
diff --git a/bayesnet/classifiers/Classifier.h b/bayesnet/classifiers/Classifier.h
index 0a10a1f..95b6da2 100644
--- a/bayesnet/classifiers/Classifier.h
+++ b/bayesnet/classifiers/Classifier.h
@@ -37,6 +37,7 @@ namespace bayesnet {
         std::vector<std::string> getNotes() const override { return notes; }
         std::string dump_cpt() const override;
         void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
+        Network& getModel() { return model; }
     protected:
         bool fitted;
         unsigned int m, n; // m: number of samples, n: number of features
diff --git a/bayesnet/classifiers/KDBLd.cc b/bayesnet/classifiers/KDBLd.cc
index 541e005..32aa690 100644
--- a/bayesnet/classifiers/KDBLd.cc
+++ b/bayesnet/classifiers/KDBLd.cc
@@ -5,6 +5,7 @@
 // ***************************************************************
 
 #include "KDBLd.h"
+#include
 namespace bayesnet {
     KDBLd::KDBLd(int k) : KDB(k), Proposal(dataset, features, className)
     {
@@ -35,7 +36,7 @@ namespace bayesnet {
         y = y_;
 
         // Use iterative local discretization instead of the two-phase approach
-        states = iterativeLocalDiscretization(y, this, dataset, features, className, states_, smoothing);
+        states = iterativeLocalDiscretization(y, static_cast<KDB*>(this), dataset, features, className, states_, smoothing);
 
         // Final fit with converged discretization
         KDB::fit(dataset, features, className, states, smoothing);
@@ -56,4 +57,4 @@ namespace bayesnet {
     {
         return KDB::graph(name);
     }
-}
\ No newline at end of file
+}
diff --git a/bayesnet/classifiers/Proposal.cc b/bayesnet/classifiers/Proposal.cc
index aa0698d..4dde35a 100644
--- a/bayesnet/classifiers/Proposal.cc
+++ b/bayesnet/classifiers/Proposal.cc
@@ -8,6 +8,11 @@
 #include
 #include
 #include
+#include "Classifier.h"
+#include "KDB.h"
+#include "TAN.h"
+#include "KDBLd.h"
+#include "TANLd.h"
 namespace bayesnet {
     Proposal::Proposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_) : pDataset(dataset_), pFeatures(features_), pClassName(className_)
@@ -180,7 +185,7 @@ namespace bayesnet {
     map<std::string, std::vector<int>> Proposal::iterativeLocalDiscretization(
         const torch::Tensor& y,
         Classifier* classifier,
-        const torch::Tensor& dataset,
+        torch::Tensor& dataset,
         const std::vector<std::string>& features,
         const std::string& className,
         const map<std::string, std::vector<int>>& initialStates,
@@ -196,19 +201,20 @@ namespace bayesnet {
                 << convergence_params.maxIterations << " max iterations" << std::endl;
         }
 
+        const torch::Tensor weights = torch::full({ pDataset.size(1) }, 1.0 / pDataset.size(1), torch::kDouble);
         for (int iteration = 0; iteration < convergence_params.maxIterations; ++iteration) {
             if (convergence_params.verbose) {
                 std::cout << "Iteration " << (iteration + 1) << "/" << convergence_params.maxIterations << std::endl;
             }
 
             // Phase 2: Build model with current discretization
-            classifier->fit(dataset, features, className, currentStates, smoothing);
-
+            classifier->fit(dataset, features, className, currentStates, weights, smoothing);
+
             // Phase 3: Network-aware discretization refinement
-            currentStates = localDiscretizationProposal(currentStates, classifier->model);
+            currentStates = localDiscretizationProposal(currentStates, classifier->getModel());
 
             // Check convergence
-            if (iteration > 0 && previousModel == classifier->model) {
+            if (iteration > 0 && previousModel == classifier->getModel()) {
                 if (convergence_params.verbose) {
                     std::cout << "Converged after " << (iteration + 1) << " iterations" << std::endl;
                 }
@@ -216,7 +222,7 @@ namespace bayesnet {
             }
 
             // Update for next iteration
-            previousModel = classifier->model;
+            previousModel = classifier->getModel();
         }
 
         return currentStates;
@@ -262,7 +268,11 @@ namespace bayesnet {
     }
 
     // Explicit template instantiation for common classifier types
-    // template map<std::string, std::vector<int>> Proposal::iterativeLocalDiscretization<KDB>(
-    //     const torch::Tensor&, Classifier*, const torch::Tensor&, const std::vector<std::string>&,
-    //     const std::string&, const map<std::string, std::vector<int>>&, Smoothing_t);
+    template map<std::string, std::vector<int>> Proposal::iterativeLocalDiscretization<KDB>(
+        const torch::Tensor&, KDB*, torch::Tensor&, const std::vector<std::string>&,
+        const std::string&, const map<std::string, std::vector<int>>&, Smoothing_t);
+
+    template map<std::string, std::vector<int>> Proposal::iterativeLocalDiscretization<TAN>(
+        const torch::Tensor&, TAN*, torch::Tensor&, const std::vector<std::string>&,
+        const std::string&, const map<std::string, std::vector<int>>&, Smoothing_t);
 }
diff --git a/bayesnet/classifiers/Proposal.h b/bayesnet/classifiers/Proposal.h
index 150508a..b5685d9 100644
--- a/bayesnet/classifiers/Proposal.h
+++ b/bayesnet/classifiers/Proposal.h
@@ -31,7 +31,7 @@ namespace bayesnet {
         map<std::string, std::vector<int>> iterativeLocalDiscretization(
             const torch::Tensor& y,
             Classifier* classifier,
-            const torch::Tensor& dataset,
+            torch::Tensor& dataset,
             const std::vector<std::string>& features,
             const std::string& className,
             const map<std::string, std::vector<int>>& initialStates,
diff --git a/bayesnet/classifiers/TANLd.cc b/bayesnet/classifiers/TANLd.cc
index d5a8dda..783681c 100644
--- a/bayesnet/classifiers/TANLd.cc
+++ b/bayesnet/classifiers/TANLd.cc
@@ -5,6 +5,7 @@
 // ***************************************************************
 
 #include "TANLd.h"
+#include
 namespace bayesnet {
     TANLd::TANLd() : TAN(), Proposal(dataset, features, className) {}
 
@@ -17,7 +18,7 @@ namespace bayesnet {
         y = y_;
 
         // Use iterative local discretization instead of the two-phase approach
-        states = iterativeLocalDiscretization(y, this, dataset, features, className, states_, smoothing);
+        states = iterativeLocalDiscretization(y, static_cast<TAN*>(this), dataset, features, className, states_, smoothing);
 
         // Final fit with converged discretization
         TAN::fit(dataset, features, className, states, smoothing);
@@ -38,4 +39,4 @@ namespace bayesnet {
     {
         return TAN::graph(name);
     }
-}
\ No newline at end of file
+}
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 788a6eb..6f9959f 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -8,7 +8,7 @@ if(ENABLE_TESTING)
     add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesClassifier.cc TestXSPnDE.cc TestXBA2DE.cc
     TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestBoostAODE.cc TestXBAODE.cc TestA2DE.cc
     TestUtils.cc TestBayesEnsemble.cc TestModulesVersions.cc TestBoostA2DE.cc TestMST.cc TestXSPODE.cc ${BayesNet_SOURCES})
-    target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" fimdlp::fimdlp PRIVATE Catch2::Catch2WithMain folding::folding)
+    target_link_libraries(TestBayesNet PRIVATE torch::torch fimdlp::fimdlp Catch2::Catch2WithMain folding::folding)
     add_test(NAME BayesNetworkTest COMMAND TestBayesNet)
     add_test(NAME A2DE COMMAND TestBayesNet "[A2DE]")
    add_test(NAME BoostA2DE COMMAND TestBayesNet "[BoostA2DE]")