2024-04-11 16:02:49 +00:00
|
|
|
// ***************************************************************
|
|
|
|
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
|
|
|
// SPDX-FileType: SOURCE
|
|
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
// ***************************************************************
|
|
|
|
|
2023-06-29 20:00:41 +00:00
|
|
|
#ifndef NETWORK_H
|
|
|
|
#define NETWORK_H
|
|
|
|
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
|
2024-03-08 21:20:54 +00:00
|
|
|
#include "bayesnet/config.h"
|
|
|
|
#include "Node.h"
|
2023-06-30 19:24:12 +00:00
|
|
|
|
2023-06-29 20:00:41 +00:00
|
|
|
namespace bayesnet {
|
|
|
|
class Network {
|
|
|
|
public:
|
2023-06-30 00:46:06 +00:00
|
|
|
Network();
|
2023-07-29 17:38:42 +00:00
|
|
|
explicit Network(float);
|
2024-04-07 22:13:59 +00:00
|
|
|
explicit Network(const Network&);
|
2023-08-31 18:30:28 +00:00
|
|
|
~Network() = default;
|
2023-07-11 20:23:49 +00:00
|
|
|
torch::Tensor& getSamples();
|
2024-04-07 22:13:59 +00:00
|
|
|
float getMaxThreads() const;
|
2023-11-08 17:45:35 +00:00
|
|
|
void addNode(const std::string&);
|
|
|
|
void addEdge(const std::string&, const std::string&);
|
|
|
|
std::map<std::string, std::unique_ptr<Node>>& getNodes();
|
|
|
|
std::vector<std::string> getFeatures() const;
|
2023-08-07 23:53:41 +00:00
|
|
|
int getStates() const;
|
2023-11-08 17:45:35 +00:00
|
|
|
std::vector<std::pair<std::string, std::string>> getEdges() const;
|
2023-08-07 23:53:41 +00:00
|
|
|
int getNumEdges() const;
|
|
|
|
int getClassNumStates() const;
|
2023-11-08 17:45:35 +00:00
|
|
|
std::string getClassName() const;
|
2023-10-09 09:25:30 +00:00
|
|
|
/*
|
|
|
|
Notice: Nodes have to be inserted in the same order as they are in the dataset, i.e., first node is first column and so on.
|
|
|
|
*/
|
2023-11-08 17:45:35 +00:00
|
|
|
void fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states);
|
|
|
|
void fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states);
|
|
|
|
void fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states);
|
|
|
|
std::vector<int> predict(const std::vector<std::vector<int>>&); // Return mx1 std::vector of predictions
|
2023-08-03 18:22:33 +00:00
|
|
|
torch::Tensor predict(const torch::Tensor&); // Return mx1 tensor of predictions
|
|
|
|
torch::Tensor predict_tensor(const torch::Tensor& samples, const bool proba);
|
2023-11-08 17:45:35 +00:00
|
|
|
std::vector<std::vector<double>> predict_proba(const std::vector<std::vector<int>>&); // Return mxn std::vector of probabilities
|
2023-08-03 18:22:33 +00:00
|
|
|
torch::Tensor predict_proba(const torch::Tensor&); // Return mxn tensor of probabilities
|
2023-11-08 17:45:35 +00:00
|
|
|
double score(const std::vector<std::vector<int>>&, const std::vector<int>&);
|
|
|
|
std::vector<std::string> topological_sort();
|
|
|
|
std::vector<std::string> show() const;
|
|
|
|
std::vector<std::string> graph(const std::string& title) const; // Returns a std::vector of std::strings representing the graph in graphviz format
|
2023-08-03 18:22:33 +00:00
|
|
|
void initialize();
|
2024-04-07 22:13:59 +00:00
|
|
|
std::string dump_cpt() const;
|
2024-01-07 18:58:22 +00:00
|
|
|
inline std::string version() { return { project_version.begin(), project_version.end() }; }
|
2024-02-22 10:45:40 +00:00
|
|
|
private:
|
|
|
|
std::map<std::string, std::unique_ptr<Node>> nodes;
|
|
|
|
bool fitted;
|
|
|
|
float maxThreads = 0.95;
|
|
|
|
int classNumStates;
|
|
|
|
std::vector<std::string> features; // Including classname
|
|
|
|
std::string className;
|
|
|
|
double laplaceSmoothing;
|
2024-04-07 22:13:59 +00:00
|
|
|
torch::Tensor samples; // n+1xm tensor used to fit the model
|
2024-02-22 10:45:40 +00:00
|
|
|
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
|
|
|
|
std::vector<double> predict_sample(const std::vector<int>&);
|
|
|
|
std::vector<double> predict_sample(const torch::Tensor&);
|
|
|
|
std::vector<double> exactInference(std::map<std::string, int>&);
|
|
|
|
double computeFactor(std::map<std::string, int>&);
|
|
|
|
void completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
|
|
|
|
void checkFitData(int n_features, int n_samples, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
|
|
|
|
void setStates(const std::map<std::string, std::vector<int>>&);
|
2023-06-29 20:00:41 +00:00
|
|
|
};
|
|
|
|
}
|
|
|
|
#endif
|