// *************************************************************** // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez // SPDX-FileType: SOURCE // SPDX-License-Identifier: MIT // *************************************************************** #ifndef NETWORK_H #define NETWORK_H #include #include #include "bayesnet/config.h" #include "Node.h" namespace bayesnet { class Network { public: Network(); explicit Network(float); explicit Network(const Network&); ~Network() = default; torch::Tensor& getSamples(); float getMaxThreads() const; void addNode(const std::string&); void addEdge(const std::string&, const std::string&); std::map>& getNodes(); std::vector getFeatures() const; int getStates() const; std::vector> getEdges() const; int getNumEdges() const; int getClassNumStates() const; std::string getClassName() const; /* Notice: Nodes have to be inserted in the same order as they are in the dataset, i.e., first node is first column and so on. */ void fit(const std::vector>& input_data, const std::vector& labels, const std::vector& weights, const std::vector& featureNames, const std::string& className, const std::map>& states); void fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector& featureNames, const std::string& className, const std::map>& states); void fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector& featureNames, const std::string& className, const std::map>& states); std::vector predict(const std::vector>&); // Return mx1 std::vector of predictions torch::Tensor predict(const torch::Tensor&); // Return mx1 tensor of predictions torch::Tensor predict_tensor(const torch::Tensor& samples, const bool proba); std::vector> predict_proba(const std::vector>&); // Return mxn std::vector of probabilities torch::Tensor predict_proba(const torch::Tensor&); // Return mxn tensor of probabilities double score(const std::vector>&, const std::vector&); std::vector topological_sort(); std::vector show() const; std::vector graph(const std::string& title) const; // Returns a std::vector of std::strings representing the graph in graphviz format void initialize(); std::string dump_cpt() const; inline std::string version() { return { project_version.begin(), project_version.end() }; } private: std::map> nodes; bool fitted; float maxThreads = 0.95; int classNumStates; std::vector features; // Including classname std::string className; double laplaceSmoothing; torch::Tensor samples; // n+1xm tensor used to fit the model bool isCyclic(const std::string&, std::unordered_set&, std::unordered_set&); std::vector predict_sample(const std::vector&); std::vector predict_sample(const torch::Tensor&); std::vector exactInference(std::map&); double computeFactor(std::map&); void completeFit(const std::map>& states, const torch::Tensor& weights); void checkFitData(int n_features, int n_samples, int n_samples_y, const std::vector& featureNames, const std::string& className, const std::map>& states, const torch::Tensor& weights); void setStates(const std::map>&); }; } #endif