diff --git a/src/BayesNet/BayesMetrics.cc b/src/BayesNet/BayesMetrics.cc index 8952ead..2f0de11 100644 --- a/src/BayesNet/BayesMetrics.cc +++ b/src/BayesNet/BayesMetrics.cc @@ -32,7 +32,7 @@ namespace bayesnet { } return result; } - torch::Tensor Metrics::conditionalEdge() + torch::Tensor Metrics::conditionalEdge(const torch::Tensor& weights) { auto result = vector(); auto source = vector(features); @@ -52,7 +52,7 @@ namespace bayesnet { auto mask = samples.index({ -1, "..." }) == value; auto first_dataset = samples.index({ index_first, mask }); auto second_dataset = samples.index({ index_second, mask }); - auto mi = mutualInformation(first_dataset, second_dataset); + auto mi = mutualInformation(first_dataset, second_dataset, weights); auto pb = margin[value].item(); accumulated += pb * mi; } @@ -70,15 +70,16 @@ namespace bayesnet { return matrix; } // To use in Python - vector Metrics::conditionalEdgeWeights() + vector Metrics::conditionalEdgeWeights(vector& weights_) { - auto matrix = conditionalEdge(); + const torch::Tensor weights = torch::tensor(weights_); + auto matrix = conditionalEdge(weights); std::vector v(matrix.data_ptr(), matrix.data_ptr() + matrix.numel()); return v; } - double Metrics::entropy(const torch::Tensor& feature) + double Metrics::entropy(const torch::Tensor& feature, const torch::Tensor& weights) { - torch::Tensor counts = feature.bincount(); + torch::Tensor counts = feature.bincount(weights); int totalWeight = counts.sum().item(); torch::Tensor probs = counts.to(torch::kFloat) / totalWeight; torch::Tensor logProbs = torch::log(probs); @@ -86,15 +87,15 @@ namespace bayesnet { return entropy.nansum().item(); } // H(Y|X) = sum_{x in X} p(x) H(Y|X=x) - double Metrics::conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature) + double Metrics::conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights) { int numSamples = firstFeature.sizes()[0]; - torch::Tensor featureCounts = secondFeature.bincount(); + torch::Tensor featureCounts = secondFeature.bincount(weights); unordered_map> jointCounts; double totalWeight = 0; for (auto i = 0; i < numSamples; i++) { jointCounts[secondFeature[i].item()][firstFeature[i].item()] += 1; - totalWeight += 1; + totalWeight += weights[i].item(); } if (totalWeight == 0) return 0; @@ -115,9 +116,9 @@ namespace bayesnet { return entropyValue; } // I(X;Y) = H(Y) - H(Y|X) - double Metrics::mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature) + double Metrics::mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights) { - return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature); + return entropy(firstFeature, weights) - conditionalEntropy(firstFeature, secondFeature, weights); } /* Compute the maximum spanning tree considering the weights as distances diff --git a/src/BayesNet/BayesMetrics.h b/src/BayesNet/BayesMetrics.h index 2a2fff3..5bd25b6 100644 --- a/src/BayesNet/BayesMetrics.h +++ b/src/BayesNet/BayesMetrics.h @@ -12,16 +12,16 @@ namespace bayesnet { vector features; string className; int classNumStates = 0; + double entropy(const Tensor& feature, const Tensor& weights); + double conditionalEntropy(const Tensor& firstFeature, const Tensor& secondFeature, const Tensor& weights); + vector> doCombinations(const vector&); public: Metrics() = default; - Metrics(const Tensor&, const vector&, const string&, const int); - Metrics(const vector>&, const vector&, const vector&, const string&, const int); - double entropy(const Tensor&); - double conditionalEntropy(const Tensor&, const Tensor&); - double mutualInformation(const Tensor&, const Tensor&); - vector conditionalEdgeWeights(); // To use in Python - Tensor conditionalEdge(); - vector> doCombinations(const vector&); + Metrics(const torch::Tensor& samples, const vector& features, const string& className, const int classNumStates); + Metrics(const vector>& vsamples, const vector& labels, const vector& features, const string& className, const int classNumStates); + double mutualInformation(const Tensor& firstFeature, const Tensor& secondFeature, const Tensor& weights); + vector conditionalEdgeWeights(vector& weights); // To use in Python + Tensor conditionalEdge(const torch::Tensor& weights); vector> maximumSpanningTree(const vector& features, const Tensor& weights, const int root); }; } diff --git a/src/BayesNet/Classifier.h b/src/BayesNet/Classifier.h index 2e736a3..6d00928 100644 --- a/src/BayesNet/Classifier.h +++ b/src/BayesNet/Classifier.h @@ -14,13 +14,14 @@ namespace bayesnet { Classifier& build(vector& features, string className, map>& states); protected: bool fitted; - Network model; int m, n; // m: number of samples, n: number of features - Tensor dataset; // (n+1)xm tensor + Network model; Metrics metrics; vector features; string className; map> states; + Tensor dataset; // (n+1)xm tensor + Tensor weights; void checkFitParameters(); virtual void buildModel() = 0; void trainModel() override; diff --git a/src/BayesNet/KDB.cc b/src/BayesNet/KDB.cc index 74566b0..874e08a 100644 --- a/src/BayesNet/KDB.cc +++ b/src/BayesNet/KDB.cc @@ -32,10 +32,10 @@ namespace bayesnet { vector mi; for (auto i = 0; i < features.size(); i++) { Tensor firstFeature = dataset.index({ i, "..." }); - mi.push_back(metrics.mutualInformation(firstFeature, y)); + mi.push_back(metrics.mutualInformation(firstFeature, y, weights)); } // 2. Compute class conditional mutual information I(Xi;XjIC), f or each - auto conditionalEdgeWeights = metrics.conditionalEdge(); + auto conditionalEdgeWeights = metrics.conditionalEdge(weights); // 3. Let the used variable list, S, be empty. vector S; // 4. Let the DAG network being constructed, BN, begin with a single diff --git a/src/BayesNet/TAN.cc b/src/BayesNet/TAN.cc index 7b3e3a6..843a5e6 100644 --- a/src/BayesNet/TAN.cc +++ b/src/BayesNet/TAN.cc @@ -15,15 +15,15 @@ namespace bayesnet { Tensor class_dataset = dataset.index({ -1, "..." }); for (int i = 0; i < static_cast(features.size()); ++i) { Tensor feature_dataset = dataset.index({ i, "..." }); - auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset); + auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset, weights); mi.push_back({ i, mi_value }); } sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;}); auto root = mi[mi.size() - 1].first; // 2. Compute mutual information between each feature and the class - auto weights = metrics.conditionalEdge(); + auto weights_matrix = metrics.conditionalEdge(weights); // 3. Compute the maximum spanning tree - auto mst = metrics.maximumSpanningTree(features, weights, root); + auto mst = metrics.maximumSpanningTree(features, weights_matrix, root); // 4. Add edges from the maximum spanning tree to the model for (auto i = 0; i < mst.size(); ++i) { auto [from, to] = mst[i];