Implement Cestnik & Laplace smoothing
This commit is contained in:
@@ -7,17 +7,18 @@
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <sstream>
|
||||
#include <numeric>
|
||||
#include "Network.h"
|
||||
#include "bayesnet/utils/bayesnetUtils.h"
|
||||
namespace bayesnet {
|
||||
Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
|
||||
Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }, smoothing{ Smoothing_t::LAPLACE }
|
||||
{
|
||||
}
|
||||
Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
|
||||
Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }, smoothing{ Smoothing_t::LAPLACE }
|
||||
{
|
||||
|
||||
}
|
||||
Network::Network(const Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
|
||||
Network::Network(const Network& other) : smoothing(other.smoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
|
||||
maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples)
|
||||
{
|
||||
if (samples.defined())
|
||||
@@ -164,14 +165,14 @@ namespace bayesnet {
|
||||
for (int i = 0; i < featureNames.size(); ++i) {
|
||||
auto row_feature = X.index({ i, "..." });
|
||||
}
|
||||
completeFit(states, weights);
|
||||
completeFit(states, X.size(0), weights);
|
||||
}
|
||||
void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
|
||||
{
|
||||
checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights);
|
||||
this->className = className;
|
||||
this->samples = samples;
|
||||
completeFit(states, weights);
|
||||
completeFit(states, samples.size(1), weights);
|
||||
}
|
||||
// input_data comes in nxm, where n is the number of features and m the number of samples
|
||||
void Network::fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights_, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
|
||||
@@ -185,16 +186,17 @@ namespace bayesnet {
|
||||
samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32));
|
||||
}
|
||||
samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
|
||||
completeFit(states, weights);
|
||||
completeFit(states, input_data[0].size(), weights);
|
||||
}
|
||||
void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
|
||||
void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const int n_samples, const torch::Tensor& weights)
|
||||
{
|
||||
setStates(states);
|
||||
laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation
|
||||
std::vector<std::thread> threads;
|
||||
for (auto& node : nodes) {
|
||||
threads.emplace_back([this, &node, &weights]() {
|
||||
node.second->computeCPT(samples, features, laplaceSmoothing, weights);
|
||||
threads.emplace_back([this, &node, &weights, n_samples]() {
|
||||
auto numStates = node.second->getNumStates();
|
||||
double smoothing_factor = smoothing == Smoothing_t::CESTNIK ? static_cast<double>(n_samples) / numStates : 1.0 / static_cast<double>(n_samples);
|
||||
node.second->computeCPT(samples, features, smoothing_factor, weights);
|
||||
});
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
@@ -337,7 +339,7 @@ namespace bayesnet {
|
||||
thread.join();
|
||||
}
|
||||
// Normalize result
|
||||
double sum = accumulate(result.begin(), result.end(), 0.0);
|
||||
double sum = std::accumulate(result.begin(), result.end(), 0.0);
|
||||
transform(result.begin(), result.end(), result.begin(), [sum](const double& value) { return value / sum; });
|
||||
return result;
|
||||
}
|
||||
|
Reference in New Issue
Block a user