First cfs working version

Ricardo Montañana Gómez 2023-10-10 23:00:38 +02:00
parent ca833a34f5
commit e7ded68267
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 24 additions and 14 deletions

View File

@@ -13,10 +13,10 @@ namespace bayesnet {
     void BoostAODE::buildModel(const torch::Tensor& weights)
     {
         // Models shall be built in trainModel
+        models.clear();
+        n_models = 0;
         // Prepare the validation dataset
         auto y_ = dataset.index({ -1, "..." });
-        int nSamples = dataset.size(1);
-        int nFeatures = dataset.size(0) - 1;
         if (convergence) {
             // Prepare train & validation sets from train data
             auto fold = platform::StratifiedKFold(5, y_, 271);
@@ -41,8 +41,8 @@ namespace bayesnet {
             X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
             y_train = y_;
         }
-        if (cfs != "") {
-            initializeModels(nSamples, nFeatures);
+        if (cfs) {
+            initializeModels();
         }
     }
     void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
@@ -82,18 +82,18 @@ namespace bayesnet {
         EVP_MD_CTX_free(mdctx);
         stringstream oss;
         for (unsigned int i = 0; i < hash_len; i++) {
-            oss << hex << (int)hash[i];
+            oss << hex << setfill('0') << setw(2) << (int)hash[i];
         }
         return oss.str();
     }
-    void BoostAODE::initializeModels(int nSamples, int nFeatures)
+    void BoostAODE::initializeModels()
     {
         // Read the CFS features
         string output = "[", prefix = "";
         bool first = true;
         for (const auto& feature : features) {
-            output += prefix + feature;
+            output += prefix + "'" + feature + "'";
             if (first) {
                 prefix = ", ";
                 first = false;
@@ -103,21 +103,30 @@ namespace bayesnet {
         // std::size_t str_hash = std::hash<std::string>{}(output);
         string str_hash = sha256(output);
         stringstream oss;
-        oss << "cfs/" << str_hash << ".json";
+        oss << platform::Paths::cfs() << str_hash << ".json";
         string name = oss.str();
         ifstream file(name);
+        Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
         if (file.is_open()) {
-            nlohmann::json features = nlohmann::json::parse(file);
+            nlohmann::json cfsFeatures = nlohmann::json::parse(file);
             file.close();
-            cout << "features: " << features.dump() << endl;
+            for (const string& feature : cfsFeatures) {
+                // cout << "Feature: [" << feature << "]" << endl;
+                auto pos = find(features.begin(), features.end(), feature);
+                if (pos == features.end())
+                    throw runtime_error("Feature " + feature + " not found in dataset");
+                int numFeature = pos - features.begin();
+                cout << "Feature: [" << feature << "] " << numFeature << endl;
+                models.push_back(std::make_unique<SPODE>(numFeature));
+                models.back()->fit(dataset, features, className, states, weights_);
+                n_models++;
+            }
         } else {
             throw runtime_error("File " + name + " not found");
         }
     }
     void BoostAODE::trainModel(const torch::Tensor& weights)
     {
-        models.clear();
-        n_models = 0;
         if (maxModels == 0)
             maxModels = .1 * n > 10 ? .1 * n : n;
         Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
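Side note on the digest formatting change in sha256() above: streaming a byte with only hex drops the leading zero for values below 0x10, so the string built from the digest had a variable length and two different digests could collapse into the same file name. The new setfill('0') and setw(2) pad every byte to two hex digits. A minimal standalone sketch of the difference (sample byte values, not a real digest):

    #include <iomanip>
    #include <iostream>
    #include <sstream>
    using namespace std;

    int main()
    {
        unsigned char bytes[] = { 0x0a, 0x5f, 0x03 }; // sample values, not a real digest
        stringstream old_style, new_style;
        for (unsigned char b : bytes) {
            old_style << hex << (int)b;                            // emits "a", "5f", "3"
            new_style << hex << setfill('0') << setw(2) << (int)b; // emits "0a", "5f", "03"
        }
        cout << old_style.str() << " vs " << new_style.str() << endl; // prints "a5f3 vs 0a5f03"
        return 0;
    }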

View File

@@ -15,13 +15,13 @@ namespace bayesnet {
     private:
         torch::Tensor dataset_;
         torch::Tensor X_train, y_train, X_test, y_test;
-        void initializeModels(int nSamples, int nFeatures);
+        void initializeModels();
         // Hyperparameters
         bool repeatSparent = false; // if true, a feature can be selected more than once
         int maxModels = 0;
         bool ascending = false; //Process KBest features ascending or descending order
         bool convergence = false; //if true, stop when the model does not improve
-        string cfs = ""; // if not empty, use CFS to select features
+        bool cfs = false; // if true use CFS to select features stored in cfs folder with sha256(features) file_name
     };
 }
 #endif
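With cfs now a bool, enabling the CFS initialization is presumably done through setHyperparameters(), whose signature appears in the first file; the key name and its parsing are not part of this diff, so the following is only an assumed usage sketch:

    #include <nlohmann/json.hpp>
    #include "BoostAODE.h" // assumed header name for bayesnet::BoostAODE

    // Assumed: setHyperparameters reads a "cfs" key into the new bool member.
    void enableCfs(bayesnet::BoostAODE& clf)
    {
        nlohmann::json hyperparameters = {
            { "cfs", true },         // assumed key, matching the new bool cfs member
            { "convergence", true }  // another hyperparameter declared in this header
        };
        clf.setHyperparameters(hyperparameters);
    }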

View File

@@ -7,6 +7,7 @@ namespace platform {
     public:
         static std::string results() { return "results/"; }
         static std::string excel() { return "excel/"; }
+        static std::string cfs() { return "cfs/"; }
         static std::string datasets()
         {
             auto env = platform::DotEnv();
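For reference, initializeModels() in the first file expects the file named by Paths::cfs() plus the sha256 hash to contain a flat JSON array of feature names present in the dataset. A rough sketch of writing such a file, with placeholder feature names and the hash taken as given (the CFS selection step itself is outside this commit):

    #include <fstream>
    #include <string>
    #include <vector>
    #include <nlohmann/json.hpp>
    #include "Paths.h" // assumed header name for platform::Paths

    // Sketch: persist a selected feature subset the way initializeModels() reads it back,
    // i.e. a JSON array of feature names stored as cfs/<sha256 of the feature list>.json.
    void saveCfsSubset(const std::vector<std::string>& selected, const std::string& hash)
    {
        nlohmann::json cfsFeatures = selected; // e.g. ["featureA", "featureB"] (placeholders)
        std::ofstream file(platform::Paths::cfs() + hash + ".json");
        file << cfsFeatures.dump();
    }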