Add mdlp as a library in lib/
Fix tests to reach 99.1% of coverage

Reviewed-on: #31
This commit is contained in:
2024-11-23 17:22:41 +00:00
parent f0f3d9ad6e
commit 86f2bc44fc
26 changed files with 5183 additions and 261 deletions

View File

@@ -4,10 +4,9 @@ include_directories(
${BayesNet_SOURCE_DIR}/lib/json/include
${BayesNet_SOURCE_DIR}
${CMAKE_BINARY_DIR}/configured_files/include
${FImdlp_INCLUDE_DIRS}
)
file(GLOB_RECURSE Sources "*.cc")
add_library(BayesNet ${Sources})
target_link_libraries(BayesNet ${FImdlp} "${TORCH_LIBRARIES}")
target_link_libraries(BayesNet fimdlp "${TORCH_LIBRARIES}")

View File

@@ -9,7 +9,7 @@
#include <string>
#include <map>
#include <torch/torch.h>
#include <fimdlp/CPPFImdlp.h>
#include <CPPFImdlp.h>
#include "bayesnet/network/Network.h"
#include "Classifier.h"

View File

@@ -59,6 +59,9 @@ namespace bayesnet {
std::vector<int> featuresUsed;
if (selectFeatures) {
featuresUsed = initializeModels(smoothing);
if (featuresUsed.size() == 0) {
return;
}
auto ypred = predict(X_train);
std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
// Update significance of the models

View File

@@ -209,7 +209,7 @@ namespace bayesnet {
pthread_setname_np(threadName.c_str());
#endif
double numStates = static_cast<double>(node.second->getNumStates());
double smoothing_factor = 0.0;
double smoothing_factor;
switch (smoothing) {
case Smoothing_t::ORIGINAL:
smoothing_factor = 1.0 / n_samples;
@@ -221,7 +221,7 @@ namespace bayesnet {
smoothing_factor = 1 / numStates;
break;
default:
throw std::invalid_argument("Smoothing method not recognized " + std::to_string(static_cast<int>(smoothing)));
smoothing_factor = 0.0; // No smoothing
}
node.second->computeCPT(samples, features, smoothing_factor, weights);
semaphore.release();
@@ -234,16 +234,6 @@ namespace bayesnet {
for (auto& thread : threads) {
thread.join();
}
// std::fstream file;
// file.open("cpt.txt", std::fstream::out | std::fstream::app);
// file << std::string(80, '*') << std::endl;
// for (const auto& item : graph("Test")) {
// file << item << std::endl;
// }
// file << std::string(80, '-') << std::endl;
// file << dump_cpt() << std::endl;
// file << std::string(80, '=') << std::endl;
// file.close();
fitted = true;
}
torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)

View File

@@ -53,14 +53,14 @@ namespace bayesnet {
}
}
void insertElement(std::list<int>& variables, int variable)
// Prepend `variable` to `variables`, but only if it is not already in the
// list — keeps the list duplicate-free. A std::list is used (rather than a
// set) because callers rely on the insertion order of the elements.
void MST::insertElement(std::list<int>& variables, int variable)
{
    const bool alreadyPresent =
        std::find(variables.cbegin(), variables.cend(), variable) != variables.cend();
    if (!alreadyPresent) {
        variables.push_front(variable);
    }
}
std::vector<std::pair<int, int>> reorder(std::vector<std::pair<float, std::pair<int, int>>> T, int root_original)
std::vector<std::pair<int, int>> MST::reorder(std::vector<std::pair<float, std::pair<int, int>>> T, int root_original)
{
// Create the edges of a DAG from the MST
// replacing unordered_set with list because unordered_set cannot guarantee the order of the elements inserted

View File

@@ -14,6 +14,8 @@ namespace bayesnet {
public:
MST() = default;
MST(const std::vector<std::string>& features, const torch::Tensor& weights, const int root);
void insertElement(std::list<int>& variables, int variable);
std::vector<std::pair<int, int>> reorder(std::vector<std::pair<float, std::pair<int, int>>> T, int root_original);
std::vector<std::pair<int, int>> maximumSpanningTree();
private:
torch::Tensor weights;