From af0419c9dab9b3dabf4909833037d674a578e993 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Sun, 13 Aug 2023 00:59:02 +0200 Subject: [PATCH] First approx with const 1 weights --- .vscode/launch.json | 5 +++-- src/BayesNet/Classifier.cc | 3 ++- src/BayesNet/Network.cc | 30 +++++++++++++++++------------- src/BayesNet/Network.h | 10 +++++----- src/BayesNet/Node.cc | 4 ++-- src/BayesNet/Node.h | 2 +- src/BayesNet/Proposal.cc | 3 ++- 7 files changed, 32 insertions(+), 25 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index ba01ca6..a42c076 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -25,12 +25,13 @@ "program": "${workspaceFolder}/build/src/Platform/main", "args": [ "-m", - "SPODELd", + "SPODE", "-p", "/Users/rmontanana/Code/discretizbench/datasets", "--stratified", + "--discretize", "-d", - "iris" + "letter" ], "cwd": "/Users/rmontanana/Code/discretizbench", }, diff --git a/src/BayesNet/Classifier.cc b/src/BayesNet/Classifier.cc index b3317f4..87bae91 100644 --- a/src/BayesNet/Classifier.cc +++ b/src/BayesNet/Classifier.cc @@ -37,7 +37,8 @@ namespace bayesnet { } void Classifier::trainModel() { - model.fit(dataset, features, className, states); + const torch::Tensor weights = torch::ones({ m }); + model.fit(dataset, weights, features, className, states); } // X is nxm where n is the number of features and m the number of samples Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, vector& features, string className, map>& states) diff --git a/src/BayesNet/Network.cc b/src/BayesNet/Network.cc index 8a4106c..fbb62cc 100644 --- a/src/BayesNet/Network.cc +++ b/src/BayesNet/Network.cc @@ -104,8 +104,11 @@ namespace bayesnet { { return nodes; } - void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector& featureNames, const string& className, const map>& states) + void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector& featureNames, const string& className, const map>& states, const torch::Tensor& weights) { + if (weights.size(0) != n_samples) { + throw invalid_argument("Weights must have the same number of elements as samples in Network::fit"); + } if (n_samples != n_samples_y) { throw invalid_argument("X and y must have the same number of samples in Network::fit (" + to_string(n_samples) + " != " + to_string(n_samples_y) + ")"); } @@ -136,28 +139,29 @@ namespace bayesnet { classNumStates = nodes[className]->getNumStates(); } // X comes in nxm, where n is the number of features and m the number of samples - void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const vector& featureNames, const string& className, const map>& states) + void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const vector& featureNames, const string& className, const map>& states) { - checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states); + checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights); this->className = className; Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1); samples = torch::cat({ X , ytmp }, 0); for (int i = 0; i < featureNames.size(); ++i) { auto row_feature = X.index({ i, "..." }); } - completeFit(states); + completeFit(states, weights); } - void Network::fit(const torch::Tensor& samples, const vector& featureNames, const string& className, const map>& states) + void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const vector& featureNames, const string& className, const map>& states) { - checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states); + checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights); this->className = className; this->samples = samples; - completeFit(states); + completeFit(states, weights); } // input_data comes in nxm, where n is the number of features and m the number of samples - void Network::fit(const vector>& input_data, const vector& labels, const vector& featureNames, const string& className, const map>& states) + void Network::fit(const vector>& input_data, const vector& labels, const vector& weights_, const vector& featureNames, const string& className, const map>& states) { - checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states); + const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64); + checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights); this->className = className; // Build tensor of samples (nxm) (n+1 because of the class) samples = torch::zeros({ static_cast(input_data.size() + 1), static_cast(input_data[0].size()) }, torch::kInt32); @@ -165,9 +169,9 @@ namespace bayesnet { samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32)); } samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32)); - completeFit(states); + completeFit(states, weights); } - void Network::completeFit(const map>& states) + void Network::completeFit(const map>& states, const torch::Tensor& weights) { setStates(states); int maxThreadsRunning = static_cast(std::thread::hardware_concurrency() * maxThreads); @@ -182,7 +186,7 @@ namespace bayesnet { while (nextNodeIndex < nodes.size()) { unique_lock lock(mtx); cv.wait(lock, [&activeThreads, &maxThreadsRunning]() { return activeThreads < maxThreadsRunning; }); - threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads]() { + threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads, &weights]() { while (true) { unique_lock lock(mtx); if (nextNodeIndex >= nodes.size()) { @@ -191,7 +195,7 @@ namespace bayesnet { auto& pair = *std::next(nodes.begin(), nextNodeIndex); ++nextNodeIndex; lock.unlock(); - pair.second->computeCPT(samples, features, laplaceSmoothing); + pair.second->computeCPT(samples, features, laplaceSmoothing, weights); lock.lock(); nodes[pair.first] = std::move(pair.second); lock.unlock(); diff --git a/src/BayesNet/Network.h b/src/BayesNet/Network.h index d8db620..5ea94ec 100644 --- a/src/BayesNet/Network.h +++ b/src/BayesNet/Network.h @@ -20,8 +20,8 @@ namespace bayesnet { vector predict_sample(const torch::Tensor&); vector exactInference(map&); double computeFactor(map&); - void completeFit(const map>&); - void checkFitData(int n_features, int n_samples, int n_samples_y, const vector& featureNames, const string& className, const map>&); + void completeFit(const map>& states, const torch::Tensor& weights); + void checkFitData(int n_features, int n_samples, int n_samples_y, const vector& featureNames, const string& className, const map>& states, const torch::Tensor& weights); void setStates(const map>&); public: Network(); @@ -39,9 +39,9 @@ namespace bayesnet { int getNumEdges() const; int getClassNumStates() const; string getClassName() const; - void fit(const vector>&, const vector&, const vector&, const string&, const map>&); - void fit(const torch::Tensor&, const torch::Tensor&, const vector&, const string&, const map>&); - void fit(const torch::Tensor&, const vector&, const string&, const map>&); + void fit(const vector>& input_data, const vector& labels, const vector& weights, const vector& featureNames, const string& className, const map>& states); + void fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const vector& featureNames, const string& className, const map>& states); + void fit(const torch::Tensor& samples, const torch::Tensor& weights, const vector& featureNames, const string& className, const map>& states); vector predict(const vector>&); // Return mx1 vector of predictions torch::Tensor predict(const torch::Tensor&); // Return mx1 tensor of predictions torch::Tensor predict_tensor(const torch::Tensor& samples, const bool proba); diff --git a/src/BayesNet/Node.cc b/src/BayesNet/Node.cc index 6669819..10f26b8 100644 --- a/src/BayesNet/Node.cc +++ b/src/BayesNet/Node.cc @@ -84,7 +84,7 @@ namespace bayesnet { } return result; } - void Node::computeCPT(const torch::Tensor& dataset, const vector& features, const int laplaceSmoothing) + void Node::computeCPT(const torch::Tensor& dataset, const vector& features, const int laplaceSmoothing, const torch::Tensor& weights) { dimensions.clear(); // Get dimensions of the CPT @@ -111,7 +111,7 @@ namespace bayesnet { coordinates.push_back(dataset.index({ parent_index, n_sample })); } // Increment the count of the corresponding coordinate - cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + 1); + cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item()); } // Normalize the counts cpTable = cpTable / cpTable.sum(0); diff --git a/src/BayesNet/Node.h b/src/BayesNet/Node.h index f4eb320..83c4b1a 100644 --- a/src/BayesNet/Node.h +++ b/src/BayesNet/Node.h @@ -26,7 +26,7 @@ namespace bayesnet { vector& getParents(); vector& getChildren(); torch::Tensor& getCPT(); - void computeCPT(const torch::Tensor&, const vector&, const int); + void computeCPT(const torch::Tensor& dataset, const vector& features, const int laplaceSmoothing, const torch::Tensor& weights); int getNumStates() const; void setNumStates(int); unsigned minFill(); diff --git a/src/BayesNet/Proposal.cc b/src/BayesNet/Proposal.cc index eef0088..d95e701 100644 --- a/src/BayesNet/Proposal.cc +++ b/src/BayesNet/Proposal.cc @@ -65,7 +65,8 @@ namespace bayesnet { //Update new states of the feature/node states[pFeatures[index]] = xStates; } - model.fit(pDataset, pFeatures, pClassName, states); + const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat); + model.fit(pDataset, weights, pFeatures, pClassName, states); } return states; }