Implement Cestnik & Laplace smoothing

This commit is contained in:
2024-06-09 17:19:38 +02:00
parent 6d9badc33b
commit 684443a788
9 changed files with 28 additions and 22 deletions

View File

@@ -12,6 +12,10 @@
#include "Node.h"
namespace bayesnet {
enum class Smoothing_t {
LAPLACE,
CESTNIK
};
class Network {
public:
Network();
@@ -54,15 +58,15 @@ namespace bayesnet {
int classNumStates;
std::vector<std::string> features; // Including classname
std::string className;
double laplaceSmoothing;
Smoothing_t smoothing;
torch::Tensor samples; // n+1xm tensor used to fit the model
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
std::vector<double> predict_sample(const std::vector<int>&);
std::vector<double> predict_sample(const torch::Tensor&);
std::vector<double> exactInference(std::map<std::string, int>&);
double computeFactor(std::map<std::string, int>&);
void completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
void checkFitData(int n_features, int n_samples, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
void completeFit(const std::map<std::string, std::vector<int>>& states, const int n_samples, const torch::Tensor& weights);
void checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
void setStates(const std::map<std::string, std::vector<int>>&);
};
}