Fix weights mistakes in computation

Ricardo Montañana Gómez 2023-08-16 12:32:51 +02:00
parent 4d4780c1d5
commit 80b20f35b4
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
16 changed files with 262 additions and 75 deletions

.vscode/launch.json vendored
View File

@@ -25,9 +25,10 @@
    "program": "${workspaceFolder}/build/src/Platform/main",
    "args": [
        "-m",
        "TANLd",
        "BoostAODE",
        "-p",
        "/Users/rmontanana/Code/discretizbench/datasets",
        "--discretize",
        "--stratified",
        "-d",
        "iris"
CMakeLists.txt
View File

@@ -60,6 +60,7 @@ add_git_submodule("lib/json")
# --------------
add_subdirectory(config)
add_subdirectory(lib/Files)
add_subdirectory(lib/FeatureSelect)
add_subdirectory(src/BayesNet)
add_subdirectory(src/Platform)
add_subdirectory(sample)

lib/Files/CMakeLists.txt
View File

@@ -1,2 +1 @@
add_library(ArffFiles ArffFiles.cc)
#target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
add_library(ArffFiles ArffFiles.cc)

lib/FeatureSelect/CMakeLists.txt
View File

@@ -0,0 +1 @@
add_library(FeatureSelect FeatureSelect.cpp)

lib/FeatureSelect/FeatureSelect.cpp
View File

@@ -0,0 +1,119 @@
#include <algorithm> // sort, unique
#include <cmath> // log, log2
#include <functional> // greater
#include <stdexcept> // invalid_argument, logic_error
#include <unordered_map>
#include "FeatureSelect.h"
namespace features {
    SelectKBestWeighted::SelectKBestWeighted(samples_t& samples, labels_t& labels, weights_t& weights, int k, bool nat)
        : samples(samples), labels(labels), weights(weights), k(k), nat(nat)
    {
        if (samples.size() == 0 || samples[0].size() == 0)
            throw invalid_argument("features must be a non-empty matrix");
        if (samples.size() != labels.size())
            throw invalid_argument("number of samples (" + to_string(samples.size()) + ") and labels (" + to_string(labels.size()) + ") must be equal");
        if (samples.size() != weights.size())
            throw invalid_argument("number of samples and weights must be equal");
        if (k < 1 || k > static_cast<int>(samples[0].size()))
            throw invalid_argument("k must be between 1 and the number of features");
        numFeatures = 0;
        numClasses = 0;
        numSamples = 0;
        fitted = false;
    }
    SelectKBestWeighted& SelectKBestWeighted::fit()
    {
        auto labelsCopy = labels;
        numFeatures = samples[0].size();
        numSamples = samples.size();
        // compute the number of classes
        sort(labelsCopy.begin(), labelsCopy.end());
        auto last = unique(labelsCopy.begin(), labelsCopy.end());
        labelsCopy.erase(last, labelsCopy.end());
        numClasses = labelsCopy.size();
        // compute the score of each feature
        scores.reserve(numFeatures);
        for (int i = 0; i < numFeatures; ++i) {
            scores.push_back(MutualInformation(i));
            features.push_back(i);
        }
        // sort & reduce scores and features
        sort(features.begin(), features.end(), [&](int i, int j)
            { return scores[i] > scores[j]; });
        sort(scores.begin(), scores.end(), greater<precision_t>());
        features.resize(k);
        scores.resize(k);
        fitted = true;
        return *this;
    }
    precision_t SelectKBestWeighted::entropyLabel()
    {
        return entropy(labels);
    }
    precision_t SelectKBestWeighted::entropy(const sample_t& data)
    {
        precision_t ventropy = 0, totalWeight = 0;
        score_t counts(numClasses + 1, 0);
        for (auto i = 0; i < static_cast<int>(data.size()); ++i) {
            counts[data[i]] += weights[i];
            totalWeight += weights[i];
        }
        for (auto count : counts) {
            precision_t p = count / totalWeight;
            if (p > 0) {
                if (nat) {
                    ventropy -= p * log(p);
                } else {
                    ventropy -= p * log2(p);
                }
            }
        }
        return ventropy;
    }
    // H(Y|X) = sum_{x in X} p(x) H(Y|X=x)
    precision_t SelectKBestWeighted::conditionalEntropy(const int feature)
    {
        unordered_map<value_t, precision_t> featureCounts;
        unordered_map<value_t, unordered_map<value_t, precision_t>> jointCounts;
        precision_t totalWeight = 0;
        for (auto i = 0; i < numSamples; i++) {
            featureCounts[samples[i][feature]] += weights[i];
            jointCounts[samples[i][feature]][labels[i]] += weights[i];
            totalWeight += weights[i];
        }
        if (totalWeight == 0)
            throw invalid_argument("Total weight should not be zero");
        precision_t entropy = 0;
        for (auto& [feat, count] : featureCounts) {
            auto p_f = count / totalWeight;
            precision_t entropy_f = 0;
            for (auto& [label, jointCount] : jointCounts[feat]) {
                auto p_l_f = jointCount / count;
                if (p_l_f > 0) {
                    if (nat) {
                        entropy_f -= p_l_f * log(p_l_f);
                    } else {
                        entropy_f -= p_l_f * log2(p_l_f);
                    }
                }
            }
            entropy += p_f * entropy_f;
        }
        return entropy;
    }
    // I(X;Y) = H(Y) - H(Y|X)
    precision_t SelectKBestWeighted::MutualInformation(const int i)
    {
        return entropyLabel() - conditionalEntropy(i);
    }
    score_t SelectKBestWeighted::getScores() const
    {
        if (!fitted)
            throw logic_error("score not fitted");
        return scores;
    }
    // return the indices of the selected features
    labels_t SelectKBestWeighted::getFeatures() const
    {
        if (!fitted)
            throw logic_error("score not fitted");
        return features;
    }
}
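For reference, these are the quantities the class computes, with the sample weights w_i written out (the log is natural or base 2 depending on the nat flag):

H_w(Y) = -\sum_y p_w(y)\, \log p_w(y), \qquad p_w(y) = \frac{\sum_{i:\, y_i = y} w_i}{\sum_i w_i}
H_w(Y \mid X_j) = \sum_x p_w(X_j = x)\, H_w(Y \mid X_j = x)
I_w(X_j; Y) = H_w(Y) - H_w(Y \mid X_j)

fit() ranks the features by I_w(X_j; Y) and keeps the k best.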

lib/FeatureSelect/FeatureSelect.h
View File

@@ -0,0 +1,38 @@
#ifndef SELECT_K_BEST_WEIGHTED_H
#define SELECT_K_BEST_WEIGHTED_H
#include <map>
#include <string>
#include <vector>
using namespace std;
namespace features {
    typedef float precision_t;
    typedef int value_t;
    typedef vector<value_t> sample_t;
    typedef vector<sample_t> samples_t;
    typedef vector<value_t> labels_t;
    typedef vector<precision_t> score_t, weights_t;
    class SelectKBestWeighted {
    private:
        const samples_t samples;
        const labels_t labels;
        const weights_t weights;
        const int k;
        bool nat; // use natural log or log2
        int numFeatures, numClasses, numSamples;
        bool fitted;
        score_t scores; // scores of the features
        labels_t features; // indices of the selected features
        precision_t entropyLabel();
        precision_t entropy(const sample_t&);
        precision_t conditionalEntropy(const int);
        precision_t MutualInformation(const int);
    public:
        SelectKBestWeighted(samples_t&, labels_t&, weights_t&, int, bool);
        SelectKBestWeighted& fit();
        score_t getScores() const;
        labels_t getFeatures() const; // return the indices of the selected features
        static inline string version() { return "0.1.0"; }
    };
}
#endif
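A minimal usage sketch of this API, with a hypothetical 4-sample, 3-feature toy matrix (the data and names below are illustrative, not from the commit):

#include <iostream>
#include "FeatureSelect.h"

int main()
{
    features::samples_t X = { { 0, 1, 0 }, { 1, 1, 0 }, { 0, 0, 1 }, { 1, 0, 1 } }; // 4 samples x 3 features
    features::labels_t y = { 0, 0, 1, 1 };
    features::weights_t w(4, 0.25); // uniform weights summing to 1
    features::SelectKBestWeighted selector(X, y, w, 2, false); // keep the 2 highest-scoring features
    selector.fit();
    for (auto f : selector.getFeatures())
        std::cout << "selected feature " << f << std::endl;
    return 0;
}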

src/Platform/main.cc
View File

@@ -178,59 +178,59 @@ int main(int argc, char** argv)
    cout << "end." << endl;
    auto score = clf->score(Xd, y);
    cout << "Score: " << score << endl;
    // auto graph = clf->graph();
    // auto dot_file = model_name + "_" + file_name;
    // ofstream file(dot_file + ".dot");
    // file << graph;
    // file.close();
    // cout << "Graph saved in " << model_name << "_" << file_name << ".dot" << endl;
    // cout << "dot -Tpng -o " + dot_file + ".png " + dot_file + ".dot " << endl;
    // string stratified_string = stratified ? " Stratified" : "";
    // cout << nFolds << " Folds" << stratified_string << " Cross validation" << endl;
    // cout << "==========================================" << endl;
    // torch::Tensor Xt = torch::zeros({ static_cast<int>(Xd.size()), static_cast<int>(Xd[0].size()) }, torch::kInt32);
    // torch::Tensor yt = torch::tensor(y, torch::kInt32);
    // for (int i = 0; i < features.size(); ++i) {
    //     Xt.index_put_({ i, "..." }, torch::tensor(Xd[i], torch::kInt32));
    // }
    // float total_score = 0, total_score_train = 0, score_train, score_test;
    // Fold* fold;
    // if (stratified)
    //     fold = new StratifiedKFold(nFolds, y, seed);
    // else
    //     fold = new KFold(nFolds, y.size(), seed);
    // for (auto i = 0; i < nFolds; ++i) {
    //     auto [train, test] = fold->getFold(i);
    //     cout << "Fold: " << i + 1 << endl;
    //     if (tensors) {
    //         auto ttrain = torch::tensor(train, torch::kInt64);
    //         auto ttest = torch::tensor(test, torch::kInt64);
    //         torch::Tensor Xtraint = torch::index_select(Xt, 1, ttrain);
    //         torch::Tensor ytraint = yt.index({ ttrain });
    //         torch::Tensor Xtestt = torch::index_select(Xt, 1, ttest);
    //         torch::Tensor ytestt = yt.index({ ttest });
    //         clf->fit(Xtraint, ytraint, features, className, states);
    //         auto temp = clf->predict(Xtraint);
    //         score_train = clf->score(Xtraint, ytraint);
    //         score_test = clf->score(Xtestt, ytestt);
    //     } else {
    //         auto [Xtrain, ytrain] = extract_indices(train, Xd, y);
    //         auto [Xtest, ytest] = extract_indices(test, Xd, y);
    //         clf->fit(Xtrain, ytrain, features, className, states);
    //         score_train = clf->score(Xtrain, ytrain);
    //         score_test = clf->score(Xtest, ytest);
    //     }
    //     if (dump_cpt) {
    //         cout << "--- CPT Tables ---" << endl;
    //         clf->dump_cpt();
    //     }
    //     total_score_train += score_train;
    //     total_score += score_test;
    //     cout << "Score Train: " << score_train << endl;
    //     cout << "Score Test : " << score_test << endl;
    //     cout << "-------------------------------------------------------------------------------" << endl;
    // }
    // cout << "**********************************************************************************" << endl;
    // cout << "Average Score Train: " << total_score_train / nFolds << endl;
    // cout << "Average Score Test : " << total_score / nFolds << endl;
    // return 0;
    auto graph = clf->graph();
    auto dot_file = model_name + "_" + file_name;
    ofstream file(dot_file + ".dot");
    file << graph;
    file.close();
    cout << "Graph saved in " << model_name << "_" << file_name << ".dot" << endl;
    cout << "dot -Tpng -o " + dot_file + ".png " + dot_file + ".dot " << endl;
    string stratified_string = stratified ? " Stratified" : "";
    cout << nFolds << " Folds" << stratified_string << " Cross validation" << endl;
    cout << "==========================================" << endl;
    torch::Tensor Xt = torch::zeros({ static_cast<int>(Xd.size()), static_cast<int>(Xd[0].size()) }, torch::kInt32);
    torch::Tensor yt = torch::tensor(y, torch::kInt32);
    for (int i = 0; i < features.size(); ++i) {
        Xt.index_put_({ i, "..." }, torch::tensor(Xd[i], torch::kInt32));
    }
    float total_score = 0, total_score_train = 0, score_train, score_test;
    Fold* fold;
    if (stratified)
        fold = new StratifiedKFold(nFolds, y, seed);
    else
        fold = new KFold(nFolds, y.size(), seed);
    for (auto i = 0; i < nFolds; ++i) {
        auto [train, test] = fold->getFold(i);
        cout << "Fold: " << i + 1 << endl;
        if (tensors) {
            auto ttrain = torch::tensor(train, torch::kInt64);
            auto ttest = torch::tensor(test, torch::kInt64);
            torch::Tensor Xtraint = torch::index_select(Xt, 1, ttrain);
            torch::Tensor ytraint = yt.index({ ttrain });
            torch::Tensor Xtestt = torch::index_select(Xt, 1, ttest);
            torch::Tensor ytestt = yt.index({ ttest });
            clf->fit(Xtraint, ytraint, features, className, states);
            auto temp = clf->predict(Xtraint);
            score_train = clf->score(Xtraint, ytraint);
            score_test = clf->score(Xtestt, ytestt);
        } else {
            auto [Xtrain, ytrain] = extract_indices(train, Xd, y);
            auto [Xtest, ytest] = extract_indices(test, Xd, y);
            clf->fit(Xtrain, ytrain, features, className, states);
            score_train = clf->score(Xtrain, ytrain);
            score_test = clf->score(Xtest, ytest);
        }
        if (dump_cpt) {
            cout << "--- CPT Tables ---" << endl;
            clf->dump_cpt();
        }
        total_score_train += score_train;
        total_score += score_test;
        cout << "Score Train: " << score_train << endl;
        cout << "Score Test : " << score_test << endl;
        cout << "-------------------------------------------------------------------------------" << endl;
    }
    cout << "**********************************************************************************" << endl;
    cout << "Average Score Train: " << total_score_train / nFolds << endl;
    cout << "Average Score Test : " << total_score / nFolds << endl;
    return 0;
}
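extract_indices is called above but its definition is not part of this diff; the following is a plausible sketch of its shape — assuming Xd stores one vector per feature, as the tensor-building loop above implies — not necessarily the project's actual implementation:

// Hypothetical reconstruction: select the columns (samples) listed in indices.
pair<vector<vector<int>>, vector<int>> extract_indices(vector<int>& indices, vector<vector<int>>& Xd, vector<int>& y)
{
    vector<vector<int>> Xr(Xd.size());
    vector<int> yr;
    for (auto index : indices) {
        for (size_t feature = 0; feature < Xd.size(); ++feature) {
            Xr[feature].push_back(Xd[feature][index]);
        }
        yr.push_back(y[index]);
    }
    return { Xr, yr };
}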

src/BayesNet/BayesMetrics.cc
View File

@@ -38,12 +38,14 @@ namespace bayesnet {
        auto source = vector<string>(features);
        source.push_back(className);
        auto combinations = doCombinations(source);
        double totalWeight = weights.sum().item<double>();
        // Compute class prior
        auto margin = torch::zeros({ classNumStates });
        auto margin = torch::zeros({ classNumStates }, torch::kFloat);
        for (int value = 0; value < classNumStates; ++value) {
            auto mask = samples.index({ -1, "..." }) == value;
            margin[value] = mask.sum().item<float>() / samples.size(1);
            margin[value] = mask.sum().item<double>() / samples.size(1);
        }
        cout << "Margin: " << margin;
        for (auto [first, second] : combinations) {
            int index_first = find(features.begin(), features.end(), first) - features.begin();
            int index_second = find(features.begin(), features.end(), second) - features.begin();
@@ -54,7 +56,7 @@ namespace bayesnet {
                auto second_dataset = samples.index({ index_second, mask });
                auto weights_dataset = weights.index({ mask });
                auto mi = mutualInformation(first_dataset, second_dataset, weights_dataset);
                auto pb = margin[value].item<float>();
                auto pb = margin[value].item<double>();
                accumulated += pb * mi;
            }
            result.push_back(accumulated);
@@ -81,7 +83,7 @@ namespace bayesnet {
    double Metrics::entropy(const torch::Tensor& feature, const torch::Tensor& weights)
    {
        torch::Tensor counts = feature.bincount(weights);
        int totalWeight = counts.sum().item<int>();
        double totalWeight = counts.sum().item<double>();
        torch::Tensor probs = counts.to(torch::kFloat) / totalWeight;
        torch::Tensor logProbs = torch::log(probs);
        torch::Tensor entropy = -probs * logProbs;
@@ -95,7 +97,7 @@ namespace bayesnet {
        unordered_map<int, unordered_map<int, double>> jointCounts;
        double totalWeight = 0;
        for (auto i = 0; i < numSamples; i++) {
            jointCounts[secondFeature[i].item<int>()][firstFeature[i].item<int>()] += 1;
            jointCounts[secondFeature[i].item<int>()][firstFeature[i].item<int>()] += weights[i].item<double>();
            totalWeight += weights[i].item<float>();
        }
        if (totalWeight == 0)
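The value conditionalEdge accumulates for each feature pair is the class-conditional mutual information, with P(c) taken from the margin vector computed above:

I(X_i; X_j \mid C) = \sum_c P(C = c)\, I(X_i; X_j \mid C = c)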

src/BayesNet/BoostAODE.cc
View File

@@ -1,10 +1,32 @@
#include "BoostAODE.h"
#include "FeatureSelect.h"
namespace bayesnet {
    BoostAODE::BoostAODE() : Ensemble() {}
    void BoostAODE::buildModel(const torch::Tensor& weights)
    {
        models.clear();
        int n_samples = dataset.size(1);
        int n_features = dataset.size(0);
        features::samples_t vsamples;
        for (auto i = 0; i < n_samples; ++i) {
            auto row = dataset.index({ "...", i });
            // convert the sample row to a std::vector<int>
            auto vrow = vector<int>(row.data_ptr<int>(), row.data_ptr<int>() + row.numel());
            vsamples.push_back(vrow);
        }
        auto vweights = features::weights_t(n_samples, 1.0 / n_samples);
        auto row = dataset.index({ -1, "..." });
        auto yv = features::labels_t(row.data_ptr<int>(), row.data_ptr<int>() + row.numel());
        auto featureSelection = features::SelectKBestWeighted(vsamples, yv, vweights, n_features, true);
        auto features = featureSelection.fit().getFeatures();
        // Python reference:
        // features = (
        //     CSelectKBestWeighted(
        //         self.X_, self.y_, weights, k = self.n_features_in_
        //     )
        //     .fit()
        //     .get_features()
        // )
        auto scores = features::score_t(n_features, 0.0);
        for (int i = 0; i < features.size(); ++i) {
            models.push_back(std::make_unique<SPODE>(i));
        }
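Note the loop above seeds the SPODEs with the raw index i; a sketch of the variant the commented Python reference suggests, using the ranked indices instead — hypothetical, assuming SPODE's constructor takes the super-parent feature index:

for (size_t i = 0; i < features.size(); ++i) {
    // order the ensemble by the weighted mutual-information ranking
    models.push_back(std::make_unique<SPODE>(features[i]));
}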

src/BayesNet/CMakeLists.txt
View File

@@ -1,7 +1,9 @@
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
include_directories(${BayesNet_SOURCE_DIR}/lib/FeatureSelect)
include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet)
include_directories(${BayesNet_SOURCE_DIR}/src/Platform)
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc
KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANLd.cc KDBLd.cc SPODELd.cc AODELd.cc BoostAODE.cc Mst.cc Proposal.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc)
target_link_libraries(BayesNet mdlp ArffFiles "${TORCH_LIBRARIES}")
KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANLd.cc KDBLd.cc SPODELd.cc AODELd.cc BoostAODE.cc
Mst.cc Proposal.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc)
target_link_libraries(BayesNet mdlp FeatureSelect "${TORCH_LIBRARIES}")

src/BayesNet/Classifier.cc
View File

@@ -43,7 +43,7 @@ namespace bayesnet {
    {
        dataset = X;
        buildDataset(y);
        const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat);
        const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat);
        return build(features, className, states, weights);
    }
    // X is nxm where n is the number of features and m the number of samples
@@ -55,13 +55,13 @@ namespace bayesnet {
        }
        auto ytmp = torch::tensor(y, kInt32);
        buildDataset(ytmp);
        const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat);
        const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat);
        return build(features, className, states, weights);
    }
    Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states)
    {
        this->dataset = dataset;
        const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat);
        const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat);
        return build(features, className, states, weights);
    }
    Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights)
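The three fit overloads above replace unnormalized all-ones weights with a uniform distribution over the m samples, so downstream weighted counts are already normalized; a minimal standalone check (assuming libtorch):

#include <torch/torch.h>
#include <iostream>

int main()
{
    int m = 150; // e.g. the number of samples in iris
    auto weights = torch::full({ m }, 1.0 / m, torch::kFloat);
    std::cout << weights.sum().item<float>() << std::endl; // prints ~1, a proper probability mass
    return 0;
}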

src/BayesNet/Network.cc
View File

@@ -5,7 +5,6 @@
namespace bayesnet {
    Network::Network() : features(vector<string>()), className(""), classNumStates(0), fitted(false) {}
    Network::Network(float maxT) : features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
    Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
    Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads()), fitted(other.fitted)
    {
@@ -174,6 +173,7 @@ namespace bayesnet {
    void Network::completeFit(const map<string, vector<int>>& states, const torch::Tensor& weights)
    {
        setStates(states);
        laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation
        int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
        if (maxThreadsRunning < 1) {
            maxThreadsRunning = 1;
@@ -347,7 +347,7 @@ namespace bayesnet {
        }
        // Normalize result
        double sum = accumulate(result.begin(), result.end(), 0.0);
        transform(result.begin(), result.end(), result.begin(), [sum](double& value) { return value / sum; });
        transform(result.begin(), result.end(), result.begin(), [sum](const double& value) { return value / sum; });
        return result;
    }
    vector<string> Network::show() const
@@ -435,6 +435,7 @@ namespace bayesnet {
    {
        for (auto& node : nodes) {
            cout << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << endl;
            cout << node.second->getCPT() << endl;
        }
    }
}
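completeFit now derives the smoothing term from the data: with m = samples.size(1), each CPT cell is offset by α = 1/m before the weighted counts N_w are accumulated and normalized (see Node::computeCPT below). Assuming the table is indeed initialized to α — which is what the parameter is passed down for — the resulting estimate for a node X with parent configuration pa is the smoothed weighted frequency:

P(X = x \mid pa) = \frac{N_w(x, pa) + \alpha}{\sum_{x'} \left( N_w(x', pa) + \alpha \right)}, \qquad \alpha = \frac{1}{m}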

src/BayesNet/Network.h
View File

@@ -13,7 +13,7 @@ namespace bayesnet {
        int classNumStates;
        vector<string> features; // Including classname
        string className;
        int laplaceSmoothing = 1;
        double laplaceSmoothing;
        torch::Tensor samples; // nxm tensor used to fit the model
        bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
        vector<double> predict_sample(const vector<int>&);
@@ -25,7 +25,6 @@ namespace bayesnet {
        void setStates(const map<string, vector<int>>&);
    public:
        Network();
        explicit Network(float, int);
        explicit Network(float);
        explicit Network(Network&);
        torch::Tensor& getSamples();

src/BayesNet/Node.cc
View File

@@ -84,7 +84,7 @@ namespace bayesnet {
        }
        return result;
    }
    void Node::computeCPT(const torch::Tensor& dataset, const vector<string>& features, const int laplaceSmoothing, const torch::Tensor& weights)
    void Node::computeCPT(const torch::Tensor& dataset, const vector<string>& features, const double laplaceSmoothing, const torch::Tensor& weights)
    {
        dimensions.clear();
        // Get dimensions of the CPT
@@ -111,7 +111,7 @@ namespace bayesnet {
                coordinates.push_back(dataset.index({ parent_index, n_sample }));
            }
            // Increment the count of the corresponding coordinate
            cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<float>());
            cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<double>());
        }
        // Normalize the counts
        cpTable = cpTable / cpTable.sum(0);

src/BayesNet/Node.h
View File

@@ -26,7 +26,7 @@ namespace bayesnet {
        vector<Node*>& getParents();
        vector<Node*>& getChildren();
        torch::Tensor& getCPT();
        void computeCPT(const torch::Tensor& dataset, const vector<string>& features, const int laplaceSmoothing, const torch::Tensor& weights);
        void computeCPT(const torch::Tensor& dataset, const vector<string>& features, const double laplaceSmoothing, const torch::Tensor& weights);
        int getNumStates() const;
        void setNumStates(int);
        unsigned minFill();

src/BayesNet/TAN.cc
View File

@@ -22,6 +22,8 @@ namespace bayesnet {
        auto root = mi[mi.size() - 1].first;
        // 2. Compute the conditional mutual information between each pair of features given the class
        auto weights_matrix = metrics.conditionalEdge(weights);
        cout << "*** Weights matrix ***\n";
        cout << weights_matrix << "\n";
        // 3. Compute the maximum spanning tree
        auto mst = metrics.maximumSpanningTree(features, weights_matrix, root);
        // 4. Add edges from the maximum spanning tree to the model
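For context, taking the maximum spanning tree over that matrix realizes the Chow-Liu style objective behind TAN: among all trees T over the features, keep the one that preserves the most class-conditional dependence,

T^* = \arg\max_T \sum_{(i,j) \in T} I(X_i; X_j \mid C)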