Implement 3 types of smoothing

This commit is contained in:
2024-06-10 15:49:01 +02:00
parent 684443a788
commit 27a3e5a5e0
11 changed files with 37 additions and 9 deletions

View File

@@ -165,14 +165,14 @@ namespace bayesnet {
for (int i = 0; i < featureNames.size(); ++i) {
auto row_feature = X.index({ i, "..." });
}
completeFit(states, X.size(0), weights);
completeFit(states, weights);
}
void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
{
checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights);
this->className = className;
this->samples = samples;
completeFit(states, samples.size(1), weights);
completeFit(states, weights);
}
// input_data comes in nxm, where n is the number of features and m the number of samples
void Network::fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights_, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
@@ -186,16 +186,30 @@ namespace bayesnet {
samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32));
}
samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
completeFit(states, input_data[0].size(), weights);
completeFit(states, weights);
}
void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const int n_samples, const torch::Tensor& weights)
void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
{
setStates(states);
std::vector<std::thread> threads;
const double n_samples = static_cast<double>(samples.size(1));
for (auto& node : nodes) {
threads.emplace_back([this, &node, &weights, n_samples]() {
auto numStates = node.second->getNumStates();
double smoothing_factor = smoothing == Smoothing_t::CESTNIK ? static_cast<double>(n_samples) / numStates : 1.0 / static_cast<double>(n_samples);
double numStates = static_cast<double>(node.second->getNumStates());
double smoothing_factor = 0.0;
switch (smoothing) {
case Smoothing_t::OLD_LAPLACE:
smoothing_factor = 1.0 / n_samples;
break;
case Smoothing_t::LAPLACE:
smoothing_factor = 1.0;
break;
case Smoothing_t::CESTNIK:
smoothing_factor = n_samples / numStates;
break;
default:
throw std::invalid_argument("Smoothing method not recognized " + std::to_string(static_cast<int>(smoothing)));
}
node.second->computeCPT(samples, features, smoothing_factor, weights);
});
}