Fix classifier build in proposal

This commit is contained in:
2025-07-07 02:10:08 +02:00
parent 0ce7f664b4
commit 2c7352ac38
7 changed files with 29 additions and 16 deletions

View File

@@ -237,7 +237,7 @@ sample: ## Build sample with Conan
@if [ -d ./sample/build ]; then rm -rf ./sample/build; fi
@cd sample && conan install . --output-folder=build --build=missing -s build_type=$(build_type) -o "&:enable_coverage=False" -o "&:enable_testing=False"
@cd sample && cmake -B build -S . -DCMAKE_BUILD_TYPE=$(build_type) -DCMAKE_TOOLCHAIN_FILE=build/conan_toolchain.cmake && \
cmake --build build -t bayesnet_sample
cmake --build build -t bayesnet_sample --parallel $(JOBS)
sample/build/bayesnet_sample $(fname) $(model)
@echo ">>> Done";

View File

@@ -37,6 +37,7 @@ namespace bayesnet {
std::vector<std::string> getNotes() const override { return notes; }
std::string dump_cpt() const override;
void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
Network& getModel() { return model; }
protected:
bool fitted;
unsigned int m, n; // m: number of samples, n: number of features

View File

@@ -5,6 +5,7 @@
// ***************************************************************
#include "KDBLd.h"
#include <memory>
namespace bayesnet {
KDBLd::KDBLd(int k) : KDB(k), Proposal(dataset, features, className)
@@ -35,7 +36,7 @@ namespace bayesnet {
y = y_;
// Use iterative local discretization instead of the two-phase approach
states = iterativeLocalDiscretization(y, this, dataset, features, className, states_, smoothing);
states = iterativeLocalDiscretization(y, static_cast<KDB*>(this), dataset, features, className, states_, smoothing);
// Final fit with converged discretization
KDB::fit(dataset, features, className, states, smoothing);
@@ -56,4 +57,4 @@ namespace bayesnet {
{
return KDB::graph(name);
}
}
}

View File

@@ -8,6 +8,11 @@
#include <iostream>
#include <cmath>
#include <limits>
#include "Classifier.h"
#include "KDB.h"
#include "TAN.h"
#include "KDBLd.h"
#include "TANLd.h"
namespace bayesnet {
Proposal::Proposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_) : pDataset(dataset_), pFeatures(features_), pClassName(className_)
@@ -180,7 +185,7 @@ namespace bayesnet {
map<std::string, std::vector<int>> Proposal::iterativeLocalDiscretization(
const torch::Tensor& y,
Classifier* classifier,
const torch::Tensor& dataset,
torch::Tensor& dataset,
const std::vector<std::string>& features,
const std::string& className,
const map<std::string, std::vector<int>>& initialStates,
@@ -196,19 +201,20 @@ namespace bayesnet {
<< convergence_params.maxIterations << " max iterations" << std::endl;
}
const torch::Tensor weights = torch::full({ pDataset.size(1) }, 1.0 / pDataset.size(1), torch::kDouble);
for (int iteration = 0; iteration < convergence_params.maxIterations; ++iteration) {
if (convergence_params.verbose) {
std::cout << "Iteration " << (iteration + 1) << "/" << convergence_params.maxIterations << std::endl;
}
// Phase 2: Build model with current discretization
classifier->fit(dataset, features, className, currentStates, smoothing);
classifier->fit(dataset, features, className, currentStates, weights, smoothing);
// Phase 3: Network-aware discretization refinement
currentStates = localDiscretizationProposal(currentStates, classifier->model);
currentStates = localDiscretizationProposal(currentStates, classifier->getModel());
// Check convergence
if (iteration > 0 && previousModel == classifier->model) {
if (iteration > 0 && previousModel == classifier->getModel()) {
if (convergence_params.verbose) {
std::cout << "Converged after " << (iteration + 1) << " iterations" << std::endl;
}
@@ -216,7 +222,7 @@ namespace bayesnet {
}
// Update for next iteration
previousModel = classifier->model;
previousModel = classifier->getModel();
}
return currentStates;
@@ -262,7 +268,11 @@ namespace bayesnet {
}
// Explicit template instantiation for common classifier types
// template map<std::string, std::vector<int>> Proposal::iterativeLocalDiscretization<Classifier>(
// const torch::Tensor&, Classifier*, const torch::Tensor&, const std::vector<std::string>&,
// const std::string&, const map<std::string, std::vector<int>>&, Smoothing_t);
template map<std::string, std::vector<int>> Proposal::iterativeLocalDiscretization<KDB>(
const torch::Tensor&, KDB*, torch::Tensor&, const std::vector<std::string>&,
const std::string&, const map<std::string, std::vector<int>>&, Smoothing_t);
template map<std::string, std::vector<int>> Proposal::iterativeLocalDiscretization<TAN>(
const torch::Tensor&, TAN*, torch::Tensor&, const std::vector<std::string>&,
const std::string&, const map<std::string, std::vector<int>>&, Smoothing_t);
}

View File

@@ -31,7 +31,7 @@ namespace bayesnet {
map<std::string, std::vector<int>> iterativeLocalDiscretization(
const torch::Tensor& y,
Classifier* classifier,
const torch::Tensor& dataset,
torch::Tensor& dataset,
const std::vector<std::string>& features,
const std::string& className,
const map<std::string, std::vector<int>>& initialStates,

View File

@@ -5,6 +5,7 @@
// ***************************************************************
#include "TANLd.h"
#include <memory>
namespace bayesnet {
TANLd::TANLd() : TAN(), Proposal(dataset, features, className) {}
@@ -17,7 +18,7 @@ namespace bayesnet {
y = y_;
// Use iterative local discretization instead of the two-phase approach
states = iterativeLocalDiscretization(y, this, dataset, features, className, states_, smoothing);
states = iterativeLocalDiscretization(y, static_cast<TAN*>(this), dataset, features, className, states_, smoothing);
// Final fit with converged discretization
TAN::fit(dataset, features, className, states, smoothing);
@@ -38,4 +39,4 @@ namespace bayesnet {
{
return TAN::graph(name);
}
}
}

View File

@@ -8,7 +8,7 @@ if(ENABLE_TESTING)
add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesClassifier.cc TestXSPnDE.cc TestXBA2DE.cc
TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestBoostAODE.cc TestXBAODE.cc TestA2DE.cc
TestUtils.cc TestBayesEnsemble.cc TestModulesVersions.cc TestBoostA2DE.cc TestMST.cc TestXSPODE.cc ${BayesNet_SOURCES})
target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" fimdlp::fimdlp PRIVATE Catch2::Catch2WithMain folding::folding)
target_link_libraries(TestBayesNet PRIVATE torch::torch fimdlp::fimdlp Catch2::Catch2WithMain folding::folding)
add_test(NAME BayesNetworkTest COMMAND TestBayesNet)
add_test(NAME A2DE COMMAND TestBayesNet "[A2DE]")
add_test(NAME BoostA2DE COMMAND TestBayesNet "[BoostA2DE]")