mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-17 16:45:54 +00:00
Refactor cpp library methods
This commit is contained in:
@@ -89,7 +89,7 @@ namespace bayesnet {
|
|||||||
totalWeight += 1;
|
totalWeight += 1;
|
||||||
}
|
}
|
||||||
if (totalWeight == 0)
|
if (totalWeight == 0)
|
||||||
throw invalid_argument("Total weight should not be zero");
|
return 0;
|
||||||
double entropyValue = 0;
|
double entropyValue = 0;
|
||||||
for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
|
for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
|
||||||
double p_f = featureCounts[value].item<double>() / totalWeight;
|
double p_f = featureCounts[value].item<double>() / totalWeight;
|
||||||
|
@@ -21,6 +21,10 @@ namespace bayesnet {
|
|||||||
{
|
{
|
||||||
return maxThreads;
|
return maxThreads;
|
||||||
}
|
}
|
||||||
|
torch::Tensor& Network::getSamples()
|
||||||
|
{
|
||||||
|
return samples;
|
||||||
|
}
|
||||||
void Network::addNode(string name, int numStates)
|
void Network::addNode(string name, int numStates)
|
||||||
{
|
{
|
||||||
if (nodes.find(name) != nodes.end()) {
|
if (nodes.find(name) != nodes.end()) {
|
||||||
@@ -241,83 +245,5 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
double Network::mutual_info(torch::Tensor& first, torch::Tensor& second)
|
|
||||||
{
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
torch::Tensor Network::conditionalEdgeWeight()
|
|
||||||
{
|
|
||||||
auto result = vector<double>();
|
|
||||||
auto source = vector<string>(features);
|
|
||||||
source.push_back(className);
|
|
||||||
auto combinations = nodes[className]->combinations(source);
|
|
||||||
auto margin = nodes[className]->getCPT();
|
|
||||||
for (auto [first, second] : combinations) {
|
|
||||||
int64_t index_first = find(features.begin(), features.end(), first) - features.begin();
|
|
||||||
int64_t index_second = find(features.begin(), features.end(), second) - features.begin();
|
|
||||||
double accumulated = 0;
|
|
||||||
for (int value = 0; value < classNumStates; ++value) {
|
|
||||||
auto mask = samples.index({ "...", -1 }) == value;
|
|
||||||
auto first_dataset = samples.index({ mask, index_first });
|
|
||||||
auto second_dataset = samples.index({ mask, index_second });
|
|
||||||
auto mi = mutualInformation(first_dataset, second_dataset);
|
|
||||||
auto pb = margin[value].item<float>();
|
|
||||||
accumulated += pb * mi;
|
|
||||||
}
|
|
||||||
result.push_back(accumulated);
|
|
||||||
}
|
|
||||||
long n_vars = source.size();
|
|
||||||
auto matrix = torch::zeros({ n_vars, n_vars });
|
|
||||||
auto indices = torch::triu_indices(n_vars, n_vars, 1);
|
|
||||||
for (auto i = 0; i < result.size(); ++i) {
|
|
||||||
auto x = indices[0][i];
|
|
||||||
auto y = indices[1][i];
|
|
||||||
matrix[x][y] = result[i];
|
|
||||||
matrix[y][x] = result[i];
|
|
||||||
}
|
|
||||||
return matrix;
|
|
||||||
}
|
|
||||||
double Network::entropy(torch::Tensor& feature)
|
|
||||||
{
|
|
||||||
torch::Tensor counts = feature.bincount();
|
|
||||||
int totalWeight = counts.sum().item<int>();
|
|
||||||
torch::Tensor probs = counts.to(torch::kFloat) / totalWeight;
|
|
||||||
torch::Tensor logProbs = torch::log(probs);
|
|
||||||
torch::Tensor entropy = -probs * logProbs;
|
|
||||||
return entropy.nansum().item<double>();
|
|
||||||
}
|
|
||||||
// H(Y|X) = sum_{x in X} p(x) H(Y|X=x)
|
|
||||||
double Network::conditionalEntropy(torch::Tensor& firstFeature, torch::Tensor& secondFeature)
|
|
||||||
{
|
|
||||||
int numSamples = firstFeature.sizes()[0];
|
|
||||||
torch::Tensor featureCounts = secondFeature.bincount();
|
|
||||||
unordered_map<int, unordered_map<int, double>> jointCounts;
|
|
||||||
double totalWeight = 0;
|
|
||||||
for (auto i = 0; i < numSamples; i++) {
|
|
||||||
jointCounts[secondFeature[i].item<int>()][firstFeature[i].item<int>()] += 1;
|
|
||||||
totalWeight += 1;
|
|
||||||
}
|
|
||||||
if (totalWeight == 0)
|
|
||||||
throw invalid_argument("Total weight should not be zero");
|
|
||||||
double entropyValue = 0;
|
|
||||||
for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
|
|
||||||
double p_f = featureCounts[value].item<double>() / totalWeight;
|
|
||||||
double entropy_f = 0;
|
|
||||||
for (auto& [label, jointCount] : jointCounts[value]) {
|
|
||||||
double p_l_f = jointCount / featureCounts[value].item<double>();
|
|
||||||
if (p_l_f > 0) {
|
|
||||||
entropy_f -= p_l_f * log(p_l_f);
|
|
||||||
} else {
|
|
||||||
entropy_f = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
entropyValue += p_f * entropy_f;
|
|
||||||
}
|
|
||||||
return entropyValue;
|
|
||||||
}
|
|
||||||
// I(X;Y) = H(Y) - H(Y|X)
|
|
||||||
double Network::mutualInformation(torch::Tensor& firstFeature, torch::Tensor& secondFeature)
|
|
||||||
{
|
|
||||||
return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@@ -15,6 +15,7 @@ namespace bayesnet {
|
|||||||
vector<string> features;
|
vector<string> features;
|
||||||
string className;
|
string className;
|
||||||
int laplaceSmoothing;
|
int laplaceSmoothing;
|
||||||
|
torch::Tensor samples;
|
||||||
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
|
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
|
||||||
vector<double> predict_sample(const vector<int>&);
|
vector<double> predict_sample(const vector<int>&);
|
||||||
vector<double> exactInference(map<string, int>&);
|
vector<double> exactInference(map<string, int>&);
|
||||||
@@ -24,12 +25,12 @@ namespace bayesnet {
|
|||||||
double conditionalEntropy(torch::Tensor&, torch::Tensor&);
|
double conditionalEntropy(torch::Tensor&, torch::Tensor&);
|
||||||
double mutualInformation(torch::Tensor&, torch::Tensor&);
|
double mutualInformation(torch::Tensor&, torch::Tensor&);
|
||||||
public:
|
public:
|
||||||
torch::Tensor samples;
|
|
||||||
Network();
|
Network();
|
||||||
Network(float, int);
|
Network(float, int);
|
||||||
Network(float);
|
Network(float);
|
||||||
Network(Network&);
|
Network(Network&);
|
||||||
~Network();
|
~Network();
|
||||||
|
torch::Tensor& getSamples();
|
||||||
float getmaxThreads();
|
float getmaxThreads();
|
||||||
void addNode(string, int);
|
void addNode(string, int);
|
||||||
void addEdge(const string, const string);
|
void addEdge(const string, const string);
|
||||||
|
Reference in New Issue
Block a user