BayesNet/bayesnet/classifiers/KDB.cc

111 lines
4.5 KiB
C++
Raw Normal View History

2024-04-11 16:02:49 +00:00
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
2023-07-13 01:15:42 +00:00
#include "KDB.h"
#include <algorithm>
#include <stdexcept>
#include <utility>
#include <vector>
namespace bayesnet {
2023-11-19 21:36:27 +00:00
KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta)
{
validHyperparameters = { "k", "theta" };
}
// Applies the KDB-specific hyperparameters "k" and "theta" when present and
// forwards every remaining key to Classifier::setHyperparameters.
// The caller's object is never modified: a private copy is trimmed instead.
void KDB::setHyperparameters(const nlohmann::json& hyperparameters_)
{
    auto remaining = hyperparameters_;
    // Moves `key` (if present) into `target` and removes it from `remaining`,
    // so the base class only sees the keys it is responsible for.
    auto consume = [&remaining](const std::string& key, auto& target) {
        if (!remaining.contains(key)) {
            return;
        }
        target = remaining[key];
        remaining.erase(key);
    };
    consume("k", k);
    consume("theta", theta);
    Classifier::setHyperparameters(remaining);
}
2023-08-15 13:04:56 +00:00
// Learns the KDB network structure from `dataset`; `weights` is forwarded to
// every metrics computation. Algorithm:
//   1. For each feature Xi, compute the mutual information I(Xi;C), where C is
//      the class.
//   2. Compute the class-conditional mutual information I(Xi;Xj|C) for each
//      pair of features Xi and Xj, where i != j.
//   3. Let the used-variable list S be empty.
//   4. Let the DAG being constructed, BN, begin with a single class node C.
//   5. Repeat until S includes all domain features:
//      5.1. Select the feature Xmax that is not in S and has the largest
//           value of I(Xmax;C).
//      5.2. Add a node to BN representing Xmax.
//      5.3. Add an arc from C to Xmax in BN.
//      5.4. Add m = min(|S|, k) arcs from m distinct features Xj in S with
//           the highest values of I(Xmax;Xj|C) (only arcs whose value is
//           strictly greater than theta are considered).
//      5.5. Add Xmax to S.
// The conditional probabilities implied by BN are then estimated from the
// data counts; that step is not part of this method, which only builds the
// structure.
void KDB::buildModel(const torch::Tensor& weights)
{
    // Steps 4 and 5.2: addNodes() presumably registers the class node and all
    // feature nodes up front, so the loop below only adds arcs.
    // NOTE(review): confirm against addNodes().
    addNodes();

    // 1. I(Xi;C) for every feature. The last row of `dataset` is used as the
    //    class labels; rows 0..n_features-1 are the features.
    const torch::Tensor y = dataset.index({ -1, "..." });
    const int n_features = static_cast<int>(features.size());
    std::vector<double> mi;
    mi.reserve(features.size());
    for (int i = 0; i < n_features; ++i) {
        torch::Tensor feature = dataset.index({ i, "..." });
        mi.push_back(metrics.mutualInformation(feature, y, weights));
    }

    // 2. I(Xi;Xj|C) for every pair of features.
    auto conditionalEdgeWeights = metrics.conditionalEdge(weights);

    // 3. S: indices of the features already placed in the network.
    std::vector<int> S;
    S.reserve(features.size());

    // 5.1. Visit the features in the order returned by argsort(mi).
    //      NOTE(review): step 5.1 requires decreasing I(Xi;C); confirm that
    //      argsort sorts in descending order.
    auto order = argsort(mi);
    for (auto idx : order) {
        // 5.3. Arc from the class node to Xmax.
        model.addEdge(className, features[idx]);
        // 5.4. Up to min(|S|, k) arcs from the features in S with the highest
        //      I(Xmax;Xj|C).
        add_m_edges(idx, S, conditionalEdgeWeights);
        // 5.5. Xmax is now part of S.
        S.push_back(idx);
    }
}
2023-11-08 17:45:35 +00:00
void KDB::add_m_edges(int idx, std::vector<int>& S, torch::Tensor& weights)
2023-07-13 08:58:27 +00:00
{
2023-11-08 17:45:35 +00:00
auto n_edges = std::min(k, static_cast<int>(S.size()));
2023-07-13 08:58:27 +00:00
auto cond_w = clone(weights);
bool exit_cond = k == 0;
int num = 0;
while (!exit_cond) {
2023-07-13 14:59:06 +00:00
auto max_minfo = argmax(cond_w.index({ idx, "..." })).item<int>();
2023-07-13 08:58:27 +00:00
auto belongs = find(S.begin(), S.end(), max_minfo) != S.end();
if (belongs && cond_w.index({ idx, max_minfo }).item<float>() > theta) {
try {
2023-07-13 14:59:06 +00:00
model.addEdge(features[max_minfo], features[idx]);
2023-07-13 08:58:27 +00:00
num++;
}
2023-11-08 17:45:35 +00:00
catch (const std::invalid_argument& e) {
2023-07-13 08:58:27 +00:00
// Loops are not allowed
}
}
2023-07-13 14:59:06 +00:00
cond_w.index_put_({ idx, max_minfo }, -1);
auto candidates_mask = cond_w.index({ idx, "..." }).gt(theta);
auto candidates = candidates_mask.nonzero();
2023-07-13 08:58:27 +00:00
exit_cond = num == n_edges || candidates.size(0) == 0;
}
}
2023-11-08 17:45:35 +00:00
std::vector<std::string> KDB::graph(const std::string& title) const
2023-07-15 23:20:47 +00:00
{
2023-11-08 17:45:35 +00:00
std::string header{ title };
2023-07-15 23:20:47 +00:00
if (title == "KDB") {
2023-11-08 17:45:35 +00:00
header += " (k=" + std::to_string(k) + ", theta=" + std::to_string(theta) + ")";
2023-07-15 23:20:47 +00:00
}
2023-07-31 17:53:55 +00:00
return model.graph(header);
2023-07-15 23:20:47 +00:00
}
2023-07-13 01:15:42 +00:00
}