Fix tests

This commit is contained in:
2024-06-21 13:58:42 +02:00
parent 02bcab01be
commit 8e9090d283
5 changed files with 76 additions and 75 deletions

View File

@@ -21,11 +21,9 @@ namespace bayesnet {
class Network {
public:
Network();
explicit Network(float);
explicit Network(const Network&);
~Network() = default;
torch::Tensor& getSamples();
float getMaxThreads() const;
void addNode(const std::string&);
void addEdge(const std::string&, const std::string&);
std::map<std::string, std::unique_ptr<Node>>& getNodes();
@@ -64,7 +62,6 @@ namespace bayesnet {
std::vector<double> predict_sample(const std::vector<int>&);
std::vector<double> predict_sample(const torch::Tensor&);
std::vector<double> exactInference(std::map<std::string, int>&);
double computeFactor(std::map<std::string, int>&);
void completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const Smoothing_t smoothing);
void checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
void setStates(const std::map<std::string, std::vector<int>>&);