Complete first working CFS version

This commit is contained in:
Ricardo Montañana Gómez 2023-10-11 11:33:29 +02:00
parent e7ded68267
commit 47e2b138c5
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 23 additions and 21 deletions

View File

@ -65,7 +65,8 @@ endif (ENABLE_CLANG_TIDY)
add_git_submodule("lib/mdlp")
add_git_submodule("lib/argparse")
add_git_submodule("lib/json")
find_library(XLSXWRITER_LIB libxlsxwriter.dylib PATHS /usr/local/lib)
find_library(XLSXWRITER_LIB libxlsxwriter.dylib PATHS /usr/local/lib ${HOME}/lib/usr/local/lib)
# Subdirectories
# --------------

View File

@ -41,9 +41,6 @@ namespace bayesnet {
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
y_train = y_;
}
if (cfs) {
initializeModels();
}
}
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
{
@ -87,8 +84,9 @@ namespace bayesnet {
return oss.str();
}
void BoostAODE::initializeModels()
unordered_set<int> BoostAODE::initializeModels()
{
unordered_set<int> featuresUsed;
// Read the CFS features
string output = "[", prefix = "";
bool first = true;
@ -110,28 +108,30 @@ namespace bayesnet {
if (file.is_open()) {
nlohmann::json cfsFeatures = nlohmann::json::parse(file);
file.close();
for (const string& feature : cfsFeatures) {
// cout << "Feature: [" << feature << "]" << endl;
auto pos = find(features.begin(), features.end(), feature);
if (pos == features.end())
throw runtime_error("Feature " + feature + " not found in dataset");
int numFeature = pos - features.begin();
cout << "Feature: [" << feature << "] " << numFeature << endl;
models.push_back(std::make_unique<SPODE>(numFeature));
models.back()->fit(dataset, features, className, states, weights_);
for (const int& feature : cfsFeatures) {
// cout << "Feature: [" << feature << "] " << feature << " " << features.at(feature) << endl;
featuresUsed.insert(feature);
unique_ptr<Classifier> model = std::make_unique<SPODE>(feature);
model->fit(dataset, features, className, states, weights_);
models.push_back(std::move(model));
significanceModels.push_back(1.0);
n_models++;
}
} else {
throw runtime_error("File " + name + " not found");
}
return featuresUsed;
}
void BoostAODE::trainModel(const torch::Tensor& weights)
{
unordered_set<int> featuresUsed;
if (cfs) {
featuresUsed = initializeModels();
}
if (maxModels == 0)
maxModels = .1 * n > 10 ? .1 * n : n;
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
bool exitCondition = false;
unordered_set<int> featuresUsed;
// Variables to control the accuracy finish condition
double priorAccuracy = 0.0;
double delta = 1.0;
@ -150,16 +150,16 @@ namespace bayesnet {
unique_ptr<Classifier> model;
auto feature = featureSelection[0];
if (!repeatSparent || featuresUsed.size() < featureSelection.size()) {
bool found = false;
for (auto feat : featureSelection) {
bool used = true;
for (const auto& feat : featureSelection) {
if (find(featuresUsed.begin(), featuresUsed.end(), feat) != featuresUsed.end()) {
continue;
}
found = true;
used = false;
feature = feat;
break;
}
if (!found) {
if (used) {
exitCondition = true;
continue;
}
@ -199,7 +199,7 @@ namespace bayesnet {
count++;
}
}
exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
exitCondition = n_models >= maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
}
if (featuresUsed.size() != features.size()) {
status = WARNING;

View File

@ -1,6 +1,7 @@
#ifndef BOOSTAODE_H
#define BOOSTAODE_H
#include "Ensemble.h"
#include <map>
#include "SPODE.h"
namespace bayesnet {
class BoostAODE : public Ensemble {
@ -15,7 +16,7 @@ namespace bayesnet {
private:
torch::Tensor dataset_;
torch::Tensor X_train, y_train, X_test, y_test;
void initializeModels();
unordered_set<int> initializeModels();
// Hyperparameters
bool repeatSparent = false; // if true, a feature can be selected more than once
int maxModels = 0;