diff --git a/src/common/Dataset.cpp b/src/common/Dataset.cpp index bbeba35..0af0bfa 100644 --- a/src/common/Dataset.cpp +++ b/src/common/Dataset.cpp @@ -187,12 +187,16 @@ namespace platform { } n_samples = Xv[0].size(); n_features = Xv.size(); - if (numericFeaturesIdx.at(0) == -1) { - numericFeatures = std::vector(n_features, true); - } else { + if (numericFeaturesIdx.size() == 0) { numericFeatures = std::vector(n_features, false); - for (auto i : numericFeaturesIdx) { - numericFeatures[i] = true; + } else { + if (numericFeaturesIdx.at(0) == -1) { + numericFeatures = std::vector(n_features, true); + } else { + numericFeatures = std::vector(n_features, false); + for (auto i : numericFeaturesIdx) { + numericFeatures[i] = true; + } } } if (discretize) { diff --git a/src/common/Datasets.cpp b/src/common/Datasets.cpp index 535268e..e21d741 100644 --- a/src/common/Datasets.cpp +++ b/src/common/Datasets.cpp @@ -12,41 +12,51 @@ namespace platform { path = sd.getPath(); ifstream catalog(path + "all.txt"); std::vector numericFeaturesIdx; - if (catalog.is_open()) { - std::string line; - while (getline(catalog, line)) { - if (line.empty() || line[0] == '#') { - continue; - } - std::vector tokens = split(line, ';'); - std::string name = tokens[0]; - std::string className; - numericFeaturesIdx.clear(); - if (tokens.size() == 1) { + if (!catalog.is_open()) { + throw std::invalid_argument("Unable to open catalog file. [" + path + "all.txt" + "]"); + } + std::string line; + while (getline(catalog, line)) { + if (line.empty() || line[0] == '#') { + continue; + } + std::vector tokens = split(line, ';'); + std::string name = tokens[0]; + std::string className; + numericFeaturesIdx.clear(); + int size = tokens.size(); + switch (size) { + case 1: className = "-1"; numericFeaturesIdx.push_back(-1); - } else { + break; + case 2: className = tokens[1]; - if (tokens.size() > 2) { + numericFeaturesIdx.push_back(-1); + break; + case 3: + { + className = tokens[1]; auto numericFeatures = tokens[2]; if (numericFeatures == "all") { numericFeaturesIdx.push_back(-1); } else { - auto features = json::parse(numericFeatures); - for (auto& f : features) { - numericFeaturesIdx.push_back(f); + if (numericFeatures != "none") { + auto features = json::parse(numericFeatures); + for (auto& f : features) { + numericFeaturesIdx.push_back(f); + } } } - } else { - numericFeaturesIdx.push_back(-1); } - } - datasets[name] = make_unique(path, name, className, discretize, fileType, numericFeaturesIdx); + break; + default: + throw std::invalid_argument("Invalid catalog file format."); + } - catalog.close(); - } else { - throw std::invalid_argument("Unable to open catalog file. [" + path + "all.txt" + "]"); + datasets[name] = make_unique(path, name, className, discretize, fileType, numericFeaturesIdx); } + catalog.close(); } std::vector Datasets::getNames() { diff --git a/src/common/DotEnv.h b/src/common/DotEnv.h index 07cc921..c2c70b7 100644 --- a/src/common/DotEnv.h +++ b/src/common/DotEnv.h @@ -13,9 +13,30 @@ namespace platform { class DotEnv { private: std::map env; + std::map> valid; public: DotEnv(bool create = false) { + valid = + { + {"source_data", {"Arff", "Tanveer", "Surcov"}}, + {"experiment", {"discretiz", "odte", "covid"}}, + {"fit_features", {"0", "1"}}, + {"discretize", {"0", "1"}}, + {"ignore_nan", {"0", "1"}}, + {"stratified", {"0", "1"}}, + {"score", {"accuracy"}}, + {"framework", {"bulma", "bootstrap"}}, + {"margin", {"0.1", "0.2", "0.3"}}, + {"n_folds", {"5", "10"}}, + {"discretiz_algo", {"mdlp", "bin3u", "bin3q"}}, + {"platform", {"any"}}, + {"model", {"any"}}, + {"seeds", {"any"}}, + {"nodes", {"any"}}, + {"leaves", {"any"}}, + {"depth", {"any"}}, + }; if (create) { // For testing purposes std::ofstream file(".env"); @@ -37,7 +58,39 @@ namespace platform { std::istringstream iss(line); std::string key, value; if (std::getline(iss, key, '=') && std::getline(iss, value)) { - env[trim(key)] = trim(value); + key = trim(key); + value = trim(value); + parse(key, value); + env[key] = value; + } + } + parseEnv(); + } + void parse(const std::string& key, const std::string& value) + { + if (valid.find(key) == valid.end()) { + std::cerr << "Invalid key in .env: " << key << std::endl; + exit(1); + } + if (valid[key].front() == "any") { + return; + } + if (std::find(valid[key].begin(), valid[key].end(), value) == valid[key].end()) { + std::cerr << "Invalid value in .env: " << key << " = " << value << std::endl; + exit(1); + } + } + void parseEnv() + { + for (auto& [key, values] : valid) { + if (env.find(key) == env.end()) { + std::string valid_values = "", sep = ""; + for (const auto& value : values) { + valid_values += sep + value; + sep = ", "; + } + std::cerr << "Key not found in .env: " << key << ", valid values: " << valid_values << std::endl; + exit(1); } } }