Complete first working cfs

This commit is contained in:
Ricardo Montañana Gómez 2023-10-11 11:33:29 +02:00
parent e7ded68267
commit 47e2b138c5
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 23 additions and 21 deletions

View File

@ -65,7 +65,8 @@ endif (ENABLE_CLANG_TIDY)
add_git_submodule("lib/mdlp") add_git_submodule("lib/mdlp")
add_git_submodule("lib/argparse") add_git_submodule("lib/argparse")
add_git_submodule("lib/json") add_git_submodule("lib/json")
find_library(XLSXWRITER_LIB libxlsxwriter.dylib PATHS /usr/local/lib)
find_library(XLSXWRITER_LIB libxlsxwriter.dylib PATHS /usr/local/lib ${HOME}/lib/usr/local/lib)
# Subdirectories # Subdirectories
# -------------- # --------------

View File

@ -41,9 +41,6 @@ namespace bayesnet {
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." }); X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
y_train = y_; y_train = y_;
} }
if (cfs) {
initializeModels();
}
} }
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters) void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
{ {
@ -87,8 +84,9 @@ namespace bayesnet {
return oss.str(); return oss.str();
} }
void BoostAODE::initializeModels() unordered_set<int> BoostAODE::initializeModels()
{ {
unordered_set<int> featuresUsed;
// Read the CFS features // Read the CFS features
string output = "[", prefix = ""; string output = "[", prefix = "";
bool first = true; bool first = true;
@ -110,28 +108,30 @@ namespace bayesnet {
if (file.is_open()) { if (file.is_open()) {
nlohmann::json cfsFeatures = nlohmann::json::parse(file); nlohmann::json cfsFeatures = nlohmann::json::parse(file);
file.close(); file.close();
for (const string& feature : cfsFeatures) { for (const int& feature : cfsFeatures) {
// cout << "Feature: [" << feature << "]" << endl; // cout << "Feature: [" << feature << "] " << feature << " " << features.at(feature) << endl;
auto pos = find(features.begin(), features.end(), feature); featuresUsed.insert(feature);
if (pos == features.end()) unique_ptr<Classifier> model = std::make_unique<SPODE>(feature);
throw runtime_error("Feature " + feature + " not found in dataset"); model->fit(dataset, features, className, states, weights_);
int numFeature = pos - features.begin(); models.push_back(std::move(model));
cout << "Feature: [" << feature << "] " << numFeature << endl; significanceModels.push_back(1.0);
models.push_back(std::make_unique<SPODE>(numFeature));
models.back()->fit(dataset, features, className, states, weights_);
n_models++; n_models++;
} }
} else { } else {
throw runtime_error("File " + name + " not found"); throw runtime_error("File " + name + " not found");
} }
return featuresUsed;
} }
void BoostAODE::trainModel(const torch::Tensor& weights) void BoostAODE::trainModel(const torch::Tensor& weights)
{ {
unordered_set<int> featuresUsed;
if (cfs) {
featuresUsed = initializeModels();
}
if (maxModels == 0) if (maxModels == 0)
maxModels = .1 * n > 10 ? .1 * n : n; maxModels = .1 * n > 10 ? .1 * n : n;
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64); Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
bool exitCondition = false; bool exitCondition = false;
unordered_set<int> featuresUsed;
// Variables to control the accuracy finish condition // Variables to control the accuracy finish condition
double priorAccuracy = 0.0; double priorAccuracy = 0.0;
double delta = 1.0; double delta = 1.0;
@ -150,16 +150,16 @@ namespace bayesnet {
unique_ptr<Classifier> model; unique_ptr<Classifier> model;
auto feature = featureSelection[0]; auto feature = featureSelection[0];
if (!repeatSparent || featuresUsed.size() < featureSelection.size()) { if (!repeatSparent || featuresUsed.size() < featureSelection.size()) {
bool found = false; bool used = true;
for (auto feat : featureSelection) { for (const auto& feat : featureSelection) {
if (find(featuresUsed.begin(), featuresUsed.end(), feat) != featuresUsed.end()) { if (find(featuresUsed.begin(), featuresUsed.end(), feat) != featuresUsed.end()) {
continue; continue;
} }
found = true; used = false;
feature = feat; feature = feat;
break; break;
} }
if (!found) { if (used) {
exitCondition = true; exitCondition = true;
continue; continue;
} }
@ -199,7 +199,7 @@ namespace bayesnet {
count++; count++;
} }
} }
exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance; exitCondition = n_models >= maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
} }
if (featuresUsed.size() != features.size()) { if (featuresUsed.size() != features.size()) {
status = WARNING; status = WARNING;

View File

@ -1,6 +1,7 @@
#ifndef BOOSTAODE_H #ifndef BOOSTAODE_H
#define BOOSTAODE_H #define BOOSTAODE_H
#include "Ensemble.h" #include "Ensemble.h"
#include <map>
#include "SPODE.h" #include "SPODE.h"
namespace bayesnet { namespace bayesnet {
class BoostAODE : public Ensemble { class BoostAODE : public Ensemble {
@ -15,7 +16,7 @@ namespace bayesnet {
private: private:
torch::Tensor dataset_; torch::Tensor dataset_;
torch::Tensor X_train, y_train, X_test, y_test; torch::Tensor X_train, y_train, X_test, y_test;
void initializeModels(); unordered_set<int> initializeModels();
// Hyperparameters // Hyperparameters
bool repeatSparent = false; // if true, a feature can be selected more than once bool repeatSparent = false; // if true, a feature can be selected more than once
int maxModels = 0; int maxModels = 0;