First cfs working version
This commit is contained in:
parent
ca833a34f5
commit
e7ded68267
@ -13,10 +13,10 @@ namespace bayesnet {
|
|||||||
void BoostAODE::buildModel(const torch::Tensor& weights)
|
void BoostAODE::buildModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
// Models shall be built in trainModel
|
// Models shall be built in trainModel
|
||||||
|
models.clear();
|
||||||
|
n_models = 0;
|
||||||
// Prepare the validation dataset
|
// Prepare the validation dataset
|
||||||
auto y_ = dataset.index({ -1, "..." });
|
auto y_ = dataset.index({ -1, "..." });
|
||||||
int nSamples = dataset.size(1);
|
|
||||||
int nFeatures = dataset.size(0) - 1;
|
|
||||||
if (convergence) {
|
if (convergence) {
|
||||||
// Prepare train & validation sets from train data
|
// Prepare train & validation sets from train data
|
||||||
auto fold = platform::StratifiedKFold(5, y_, 271);
|
auto fold = platform::StratifiedKFold(5, y_, 271);
|
||||||
@ -41,8 +41,8 @@ namespace bayesnet {
|
|||||||
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
|
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
|
||||||
y_train = y_;
|
y_train = y_;
|
||||||
}
|
}
|
||||||
if (cfs != "") {
|
if (cfs) {
|
||||||
initializeModels(nSamples, nFeatures);
|
initializeModels();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
|
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
|
||||||
@ -82,18 +82,18 @@ namespace bayesnet {
|
|||||||
EVP_MD_CTX_free(mdctx);
|
EVP_MD_CTX_free(mdctx);
|
||||||
stringstream oss;
|
stringstream oss;
|
||||||
for (unsigned int i = 0; i < hash_len; i++) {
|
for (unsigned int i = 0; i < hash_len; i++) {
|
||||||
oss << hex << (int)hash[i];
|
oss << hex << setfill('0') << setw(2) << (int)hash[i];
|
||||||
}
|
}
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
void BoostAODE::initializeModels(int nSamples, int nFeatures)
|
void BoostAODE::initializeModels()
|
||||||
{
|
{
|
||||||
// Read the CFS features
|
// Read the CFS features
|
||||||
string output = "[", prefix = "";
|
string output = "[", prefix = "";
|
||||||
bool first = true;
|
bool first = true;
|
||||||
for (const auto& feature : features) {
|
for (const auto& feature : features) {
|
||||||
output += prefix + feature;
|
output += prefix + "'" + feature + "'";
|
||||||
if (first) {
|
if (first) {
|
||||||
prefix = ", ";
|
prefix = ", ";
|
||||||
first = false;
|
first = false;
|
||||||
@ -103,21 +103,30 @@ namespace bayesnet {
|
|||||||
// std::size_t str_hash = std::hash<std::string>{}(output);
|
// std::size_t str_hash = std::hash<std::string>{}(output);
|
||||||
string str_hash = sha256(output);
|
string str_hash = sha256(output);
|
||||||
stringstream oss;
|
stringstream oss;
|
||||||
oss << "cfs/" << str_hash << ".json";
|
oss << platform::Paths::cfs() << str_hash << ".json";
|
||||||
string name = oss.str();
|
string name = oss.str();
|
||||||
ifstream file(name);
|
ifstream file(name);
|
||||||
|
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
||||||
if (file.is_open()) {
|
if (file.is_open()) {
|
||||||
nlohmann::json features = nlohmann::json::parse(file);
|
nlohmann::json cfsFeatures = nlohmann::json::parse(file);
|
||||||
file.close();
|
file.close();
|
||||||
cout << "features: " << features.dump() << endl;
|
for (const string& feature : cfsFeatures) {
|
||||||
|
// cout << "Feature: [" << feature << "]" << endl;
|
||||||
|
auto pos = find(features.begin(), features.end(), feature);
|
||||||
|
if (pos == features.end())
|
||||||
|
throw runtime_error("Feature " + feature + " not found in dataset");
|
||||||
|
int numFeature = pos - features.begin();
|
||||||
|
cout << "Feature: [" << feature << "] " << numFeature << endl;
|
||||||
|
models.push_back(std::make_unique<SPODE>(numFeature));
|
||||||
|
models.back()->fit(dataset, features, className, states, weights_);
|
||||||
|
n_models++;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
throw runtime_error("File " + name + " not found");
|
throw runtime_error("File " + name + " not found");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
models.clear();
|
|
||||||
n_models = 0;
|
|
||||||
if (maxModels == 0)
|
if (maxModels == 0)
|
||||||
maxModels = .1 * n > 10 ? .1 * n : n;
|
maxModels = .1 * n > 10 ? .1 * n : n;
|
||||||
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
||||||
|
@ -15,13 +15,13 @@ namespace bayesnet {
|
|||||||
private:
|
private:
|
||||||
torch::Tensor dataset_;
|
torch::Tensor dataset_;
|
||||||
torch::Tensor X_train, y_train, X_test, y_test;
|
torch::Tensor X_train, y_train, X_test, y_test;
|
||||||
void initializeModels(int nSamples, int nFeatures);
|
void initializeModels();
|
||||||
// Hyperparameters
|
// Hyperparameters
|
||||||
bool repeatSparent = false; // if true, a feature can be selected more than once
|
bool repeatSparent = false; // if true, a feature can be selected more than once
|
||||||
int maxModels = 0;
|
int maxModels = 0;
|
||||||
bool ascending = false; //Process KBest features ascending or descending order
|
bool ascending = false; //Process KBest features ascending or descending order
|
||||||
bool convergence = false; //if true, stop when the model does not improve
|
bool convergence = false; //if true, stop when the model does not improve
|
||||||
string cfs = ""; // if not empty, use CFS to select features
|
bool cfs = false; // if true use CFS to select features stored in cfs folder with sha256(features) file_name
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -7,6 +7,7 @@ namespace platform {
|
|||||||
public:
|
public:
|
||||||
static std::string results() { return "results/"; }
|
static std::string results() { return "results/"; }
|
||||||
static std::string excel() { return "excel/"; }
|
static std::string excel() { return "excel/"; }
|
||||||
|
static std::string cfs() { return "cfs/"; }
|
||||||
static std::string datasets()
|
static std::string datasets()
|
||||||
{
|
{
|
||||||
auto env = platform::DotEnv();
|
auto env = platform::DotEnv();
|
||||||
|
Loading…
Reference in New Issue
Block a user