First cfs working version
This commit is contained in:
parent
ca833a34f5
commit
e7ded68267
@ -13,10 +13,10 @@ namespace bayesnet {
|
||||
void BoostAODE::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
// Models shall be built in trainModel
|
||||
models.clear();
|
||||
n_models = 0;
|
||||
// Prepare the validation dataset
|
||||
auto y_ = dataset.index({ -1, "..." });
|
||||
int nSamples = dataset.size(1);
|
||||
int nFeatures = dataset.size(0) - 1;
|
||||
if (convergence) {
|
||||
// Prepare train & validation sets from train data
|
||||
auto fold = platform::StratifiedKFold(5, y_, 271);
|
||||
@ -41,8 +41,8 @@ namespace bayesnet {
|
||||
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
|
||||
y_train = y_;
|
||||
}
|
||||
if (cfs != "") {
|
||||
initializeModels(nSamples, nFeatures);
|
||||
if (cfs) {
|
||||
initializeModels();
|
||||
}
|
||||
}
|
||||
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
|
||||
@ -82,18 +82,18 @@ namespace bayesnet {
|
||||
EVP_MD_CTX_free(mdctx);
|
||||
stringstream oss;
|
||||
for (unsigned int i = 0; i < hash_len; i++) {
|
||||
oss << hex << (int)hash[i];
|
||||
oss << hex << setfill('0') << setw(2) << (int)hash[i];
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
void BoostAODE::initializeModels(int nSamples, int nFeatures)
|
||||
void BoostAODE::initializeModels()
|
||||
{
|
||||
// Read the CFS features
|
||||
string output = "[", prefix = "";
|
||||
bool first = true;
|
||||
for (const auto& feature : features) {
|
||||
output += prefix + feature;
|
||||
output += prefix + "'" + feature + "'";
|
||||
if (first) {
|
||||
prefix = ", ";
|
||||
first = false;
|
||||
@ -103,21 +103,30 @@ namespace bayesnet {
|
||||
// std::size_t str_hash = std::hash<std::string>{}(output);
|
||||
string str_hash = sha256(output);
|
||||
stringstream oss;
|
||||
oss << "cfs/" << str_hash << ".json";
|
||||
oss << platform::Paths::cfs() << str_hash << ".json";
|
||||
string name = oss.str();
|
||||
ifstream file(name);
|
||||
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
||||
if (file.is_open()) {
|
||||
nlohmann::json features = nlohmann::json::parse(file);
|
||||
nlohmann::json cfsFeatures = nlohmann::json::parse(file);
|
||||
file.close();
|
||||
cout << "features: " << features.dump() << endl;
|
||||
for (const string& feature : cfsFeatures) {
|
||||
// cout << "Feature: [" << feature << "]" << endl;
|
||||
auto pos = find(features.begin(), features.end(), feature);
|
||||
if (pos == features.end())
|
||||
throw runtime_error("Feature " + feature + " not found in dataset");
|
||||
int numFeature = pos - features.begin();
|
||||
cout << "Feature: [" << feature << "] " << numFeature << endl;
|
||||
models.push_back(std::make_unique<SPODE>(numFeature));
|
||||
models.back()->fit(dataset, features, className, states, weights_);
|
||||
n_models++;
|
||||
}
|
||||
} else {
|
||||
throw runtime_error("File " + name + " not found");
|
||||
}
|
||||
}
|
||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||
{
|
||||
models.clear();
|
||||
n_models = 0;
|
||||
if (maxModels == 0)
|
||||
maxModels = .1 * n > 10 ? .1 * n : n;
|
||||
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
||||
|
@ -15,13 +15,13 @@ namespace bayesnet {
|
||||
private:
|
||||
torch::Tensor dataset_;
|
||||
torch::Tensor X_train, y_train, X_test, y_test;
|
||||
void initializeModels(int nSamples, int nFeatures);
|
||||
void initializeModels();
|
||||
// Hyperparameters
|
||||
bool repeatSparent = false; // if true, a feature can be selected more than once
|
||||
int maxModels = 0;
|
||||
bool ascending = false; //Process KBest features ascending or descending order
|
||||
bool convergence = false; //if true, stop when the model does not improve
|
||||
string cfs = ""; // if not empty, use CFS to select features
|
||||
bool cfs = false; // if true use CFS to select features stored in cfs folder with sha256(features) file_name
|
||||
};
|
||||
}
|
||||
#endif
|
@ -7,6 +7,7 @@ namespace platform {
|
||||
public:
|
||||
static std::string results() { return "results/"; }
|
||||
static std::string excel() { return "excel/"; }
|
||||
static std::string cfs() { return "cfs/"; }
|
||||
static std::string datasets()
|
||||
{
|
||||
auto env = platform::DotEnv();
|
||||
|
Loading…
Reference in New Issue
Block a user