Compare commits

...

2 Commits

Author SHA1 Message Date
f26ea1f0ac Add weights to BayesMetrics 2023-08-13 12:56:06 +02:00
af0419c9da First approx with const 1 weights 2023-08-13 00:59:02 +02:00
12 changed files with 60 additions and 51 deletions

5
.vscode/launch.json vendored
View File

@@ -25,12 +25,13 @@
"program": "${workspaceFolder}/build/src/Platform/main",
"args": [
"-m",
"SPODELd",
"SPODE",
"-p",
"/Users/rmontanana/Code/discretizbench/datasets",
"--stratified",
"--discretize",
"-d",
"iris"
"letter"
],
"cwd": "/Users/rmontanana/Code/discretizbench",
},

View File

@@ -32,7 +32,7 @@ namespace bayesnet {
}
return result;
}
torch::Tensor Metrics::conditionalEdge()
torch::Tensor Metrics::conditionalEdge(const torch::Tensor& weights)
{
auto result = vector<double>();
auto source = vector<string>(features);
@@ -52,7 +52,7 @@ namespace bayesnet {
auto mask = samples.index({ -1, "..." }) == value;
auto first_dataset = samples.index({ index_first, mask });
auto second_dataset = samples.index({ index_second, mask });
auto mi = mutualInformation(first_dataset, second_dataset);
auto mi = mutualInformation(first_dataset, second_dataset, weights);
auto pb = margin[value].item<float>();
accumulated += pb * mi;
}
@@ -70,15 +70,16 @@ namespace bayesnet {
return matrix;
}
// To use in Python
vector<float> Metrics::conditionalEdgeWeights()
vector<float> Metrics::conditionalEdgeWeights(vector<float>& weights_)
{
auto matrix = conditionalEdge();
const torch::Tensor weights = torch::tensor(weights_);
auto matrix = conditionalEdge(weights);
std::vector<float> v(matrix.data_ptr<float>(), matrix.data_ptr<float>() + matrix.numel());
return v;
}
double Metrics::entropy(const torch::Tensor& feature)
double Metrics::entropy(const torch::Tensor& feature, const torch::Tensor& weights)
{
torch::Tensor counts = feature.bincount();
torch::Tensor counts = feature.bincount(weights);
int totalWeight = counts.sum().item<int>();
torch::Tensor probs = counts.to(torch::kFloat) / totalWeight;
torch::Tensor logProbs = torch::log(probs);
@@ -86,15 +87,15 @@ namespace bayesnet {
return entropy.nansum().item<double>();
}
// H(Y|X) = sum_{x in X} p(x) H(Y|X=x)
double Metrics::conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature)
double Metrics::conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights)
{
int numSamples = firstFeature.sizes()[0];
torch::Tensor featureCounts = secondFeature.bincount();
torch::Tensor featureCounts = secondFeature.bincount(weights);
unordered_map<int, unordered_map<int, double>> jointCounts;
double totalWeight = 0;
for (auto i = 0; i < numSamples; i++) {
jointCounts[secondFeature[i].item<int>()][firstFeature[i].item<int>()] += 1;
totalWeight += 1;
totalWeight += weights[i].item<float>();
}
if (totalWeight == 0)
return 0;
@@ -115,9 +116,9 @@ namespace bayesnet {
return entropyValue;
}
// I(X;Y) = H(Y) - H(Y|X)
double Metrics::mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature)
double Metrics::mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights)
{
return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature);
return entropy(firstFeature, weights) - conditionalEntropy(firstFeature, secondFeature, weights);
}
/*
Compute the maximum spanning tree considering the weights as distances

View File

@@ -12,16 +12,16 @@ namespace bayesnet {
vector<string> features;
string className;
int classNumStates = 0;
double entropy(const Tensor& feature, const Tensor& weights);
double conditionalEntropy(const Tensor& firstFeature, const Tensor& secondFeature, const Tensor& weights);
vector<pair<string, string>> doCombinations(const vector<string>&);
public:
Metrics() = default;
Metrics(const Tensor&, const vector<string>&, const string&, const int);
Metrics(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const int);
double entropy(const Tensor&);
double conditionalEntropy(const Tensor&, const Tensor&);
double mutualInformation(const Tensor&, const Tensor&);
vector<float> conditionalEdgeWeights(); // To use in Python
Tensor conditionalEdge();
vector<pair<string, string>> doCombinations(const vector<string>&);
Metrics(const torch::Tensor& samples, const vector<string>& features, const string& className, const int classNumStates);
Metrics(const vector<vector<int>>& vsamples, const vector<int>& labels, const vector<string>& features, const string& className, const int classNumStates);
double mutualInformation(const Tensor& firstFeature, const Tensor& secondFeature, const Tensor& weights);
vector<float> conditionalEdgeWeights(vector<float>& weights); // To use in Python
Tensor conditionalEdge(const torch::Tensor& weights);
vector<pair<int, int>> maximumSpanningTree(const vector<string>& features, const Tensor& weights, const int root);
};
}

View File

@@ -37,7 +37,8 @@ namespace bayesnet {
}
void Classifier::trainModel()
{
model.fit(dataset, features, className, states);
const torch::Tensor weights = torch::ones({ m });
model.fit(dataset, weights, features, className, states);
}
// X is nxm where n is the number of features and m the number of samples
Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)

View File

@@ -14,13 +14,14 @@ namespace bayesnet {
Classifier& build(vector<string>& features, string className, map<string, vector<int>>& states);
protected:
bool fitted;
Network model;
int m, n; // m: number of samples, n: number of features
Tensor dataset; // (n+1)xm tensor
Network model;
Metrics metrics;
vector<string> features;
string className;
map<string, vector<int>> states;
Tensor dataset; // (n+1)xm tensor
Tensor weights;
void checkFitParameters();
virtual void buildModel() = 0;
void trainModel() override;

View File

@@ -32,10 +32,10 @@ namespace bayesnet {
vector <float> mi;
for (auto i = 0; i < features.size(); i++) {
Tensor firstFeature = dataset.index({ i, "..." });
mi.push_back(metrics.mutualInformation(firstFeature, y));
mi.push_back(metrics.mutualInformation(firstFeature, y, weights));
}
// 2. Compute class conditional mutual information I(Xi;XjIC), f or each
auto conditionalEdgeWeights = metrics.conditionalEdge();
auto conditionalEdgeWeights = metrics.conditionalEdge(weights);
// 3. Let the used variable list, S, be empty.
vector<int> S;
// 4. Let the DAG network being constructed, BN, begin with a single

View File

@@ -104,8 +104,11 @@ namespace bayesnet {
{
return nodes;
}
void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states, const torch::Tensor& weights)
{
if (weights.size(0) != n_samples) {
throw invalid_argument("Weights must have the same number of elements as samples in Network::fit");
}
if (n_samples != n_samples_y) {
throw invalid_argument("X and y must have the same number of samples in Network::fit (" + to_string(n_samples) + " != " + to_string(n_samples_y) + ")");
}
@@ -136,28 +139,29 @@ namespace bayesnet {
classNumStates = nodes[className]->getNumStates();
}
// X comes in nxm, where n is the number of features and m the number of samples
void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
{
checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states);
checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights);
this->className = className;
Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
samples = torch::cat({ X , ytmp }, 0);
for (int i = 0; i < featureNames.size(); ++i) {
auto row_feature = X.index({ i, "..." });
}
completeFit(states);
completeFit(states, weights);
}
void Network::fit(const torch::Tensor& samples, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
{
checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states);
checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights);
this->className = className;
this->samples = samples;
completeFit(states);
completeFit(states, weights);
}
// input_data comes in nxm, where n is the number of features and m the number of samples
void Network::fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
void Network::fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<float>& weights_, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
{
checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states);
const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights);
this->className = className;
// Build tensor of samples (nxm) (n+1 because of the class)
samples = torch::zeros({ static_cast<int>(input_data.size() + 1), static_cast<int>(input_data[0].size()) }, torch::kInt32);
@@ -165,9 +169,9 @@ namespace bayesnet {
samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32));
}
samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
completeFit(states);
completeFit(states, weights);
}
void Network::completeFit(const map<string, vector<int>>& states)
void Network::completeFit(const map<string, vector<int>>& states, const torch::Tensor& weights)
{
setStates(states);
int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
@@ -182,7 +186,7 @@ namespace bayesnet {
while (nextNodeIndex < nodes.size()) {
unique_lock<mutex> lock(mtx);
cv.wait(lock, [&activeThreads, &maxThreadsRunning]() { return activeThreads < maxThreadsRunning; });
threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads]() {
threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads, &weights]() {
while (true) {
unique_lock<mutex> lock(mtx);
if (nextNodeIndex >= nodes.size()) {
@@ -191,7 +195,7 @@ namespace bayesnet {
auto& pair = *std::next(nodes.begin(), nextNodeIndex);
++nextNodeIndex;
lock.unlock();
pair.second->computeCPT(samples, features, laplaceSmoothing);
pair.second->computeCPT(samples, features, laplaceSmoothing, weights);
lock.lock();
nodes[pair.first] = std::move(pair.second);
lock.unlock();

View File

@@ -20,8 +20,8 @@ namespace bayesnet {
vector<double> predict_sample(const torch::Tensor&);
vector<double> exactInference(map<string, int>&);
double computeFactor(map<string, int>&);
void completeFit(const map<string, vector<int>>&);
void checkFitData(int n_features, int n_samples, int n_samples_y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>&);
void completeFit(const map<string, vector<int>>& states, const torch::Tensor& weights);
void checkFitData(int n_features, int n_samples, int n_samples_y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states, const torch::Tensor& weights);
void setStates(const map<string, vector<int>>&);
public:
Network();
@@ -39,9 +39,9 @@ namespace bayesnet {
int getNumEdges() const;
int getClassNumStates() const;
string getClassName() const;
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const map<string, vector<int>>&);
void fit(const torch::Tensor&, const torch::Tensor&, const vector<string>&, const string&, const map<string, vector<int>>&);
void fit(const torch::Tensor&, const vector<string>&, const string&, const map<string, vector<int>>&);
void fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<float>& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states);
void fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states);
void fit(const torch::Tensor& samples, const torch::Tensor& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states);
vector<int> predict(const vector<vector<int>>&); // Return mx1 vector of predictions
torch::Tensor predict(const torch::Tensor&); // Return mx1 tensor of predictions
torch::Tensor predict_tensor(const torch::Tensor& samples, const bool proba);

View File

@@ -84,7 +84,7 @@ namespace bayesnet {
}
return result;
}
void Node::computeCPT(const torch::Tensor& dataset, const vector<string>& features, const int laplaceSmoothing)
void Node::computeCPT(const torch::Tensor& dataset, const vector<string>& features, const int laplaceSmoothing, const torch::Tensor& weights)
{
dimensions.clear();
// Get dimensions of the CPT
@@ -111,7 +111,7 @@ namespace bayesnet {
coordinates.push_back(dataset.index({ parent_index, n_sample }));
}
// Increment the count of the corresponding coordinate
cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + 1);
cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<float>());
}
// Normalize the counts
cpTable = cpTable / cpTable.sum(0);

View File

@@ -26,7 +26,7 @@ namespace bayesnet {
vector<Node*>& getParents();
vector<Node*>& getChildren();
torch::Tensor& getCPT();
void computeCPT(const torch::Tensor&, const vector<string>&, const int);
void computeCPT(const torch::Tensor& dataset, const vector<string>& features, const int laplaceSmoothing, const torch::Tensor& weights);
int getNumStates() const;
void setNumStates(int);
unsigned minFill();

View File

@@ -65,7 +65,8 @@ namespace bayesnet {
//Update new states of the feature/node
states[pFeatures[index]] = xStates;
}
model.fit(pDataset, pFeatures, pClassName, states);
const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat);
model.fit(pDataset, weights, pFeatures, pClassName, states);
}
return states;
}

View File

@@ -15,15 +15,15 @@ namespace bayesnet {
Tensor class_dataset = dataset.index({ -1, "..." });
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
Tensor feature_dataset = dataset.index({ i, "..." });
auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset);
auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset, weights);
mi.push_back({ i, mi_value });
}
sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;});
auto root = mi[mi.size() - 1].first;
// 2. Compute mutual information between each feature and the class
auto weights = metrics.conditionalEdge();
auto weights_matrix = metrics.conditionalEdge(weights);
// 3. Compute the maximum spanning tree
auto mst = metrics.maximumSpanningTree(features, weights, root);
auto mst = metrics.maximumSpanningTree(features, weights_matrix, root);
// 4. Add edges from the maximum spanning tree to the model
for (auto i = 0; i < mst.size(); ++i) {
auto [from, to] = mst[i];