Add parsing to DotEnv

This commit is contained in:
2024-06-06 17:55:39 +02:00
parent a7ec930fa0
commit c4f4e332f6
3 changed files with 96 additions and 29 deletions

View File

@@ -187,12 +187,16 @@ namespace platform {
} }
n_samples = Xv[0].size(); n_samples = Xv[0].size();
n_features = Xv.size(); n_features = Xv.size();
if (numericFeaturesIdx.at(0) == -1) { if (numericFeaturesIdx.size() == 0) {
numericFeatures = std::vector<bool>(n_features, true);
} else {
numericFeatures = std::vector<bool>(n_features, false); numericFeatures = std::vector<bool>(n_features, false);
for (auto i : numericFeaturesIdx) { } else {
numericFeatures[i] = true; if (numericFeaturesIdx.at(0) == -1) {
numericFeatures = std::vector<bool>(n_features, true);
} else {
numericFeatures = std::vector<bool>(n_features, false);
for (auto i : numericFeaturesIdx) {
numericFeatures[i] = true;
}
} }
} }
if (discretize) { if (discretize) {

View File

@@ -12,41 +12,51 @@ namespace platform {
path = sd.getPath(); path = sd.getPath();
ifstream catalog(path + "all.txt"); ifstream catalog(path + "all.txt");
std::vector<int> numericFeaturesIdx; std::vector<int> numericFeaturesIdx;
if (catalog.is_open()) { if (!catalog.is_open()) {
std::string line; throw std::invalid_argument("Unable to open catalog file. [" + path + "all.txt" + "]");
while (getline(catalog, line)) { }
if (line.empty() || line[0] == '#') { std::string line;
continue; while (getline(catalog, line)) {
} if (line.empty() || line[0] == '#') {
std::vector<std::string> tokens = split(line, ';'); continue;
std::string name = tokens[0]; }
std::string className; std::vector<std::string> tokens = split(line, ';');
numericFeaturesIdx.clear(); std::string name = tokens[0];
if (tokens.size() == 1) { std::string className;
numericFeaturesIdx.clear();
int size = tokens.size();
switch (size) {
case 1:
className = "-1"; className = "-1";
numericFeaturesIdx.push_back(-1); numericFeaturesIdx.push_back(-1);
} else { break;
case 2:
className = tokens[1]; className = tokens[1];
if (tokens.size() > 2) { numericFeaturesIdx.push_back(-1);
break;
case 3:
{
className = tokens[1];
auto numericFeatures = tokens[2]; auto numericFeatures = tokens[2];
if (numericFeatures == "all") { if (numericFeatures == "all") {
numericFeaturesIdx.push_back(-1); numericFeaturesIdx.push_back(-1);
} else { } else {
auto features = json::parse(numericFeatures); if (numericFeatures != "none") {
for (auto& f : features) { auto features = json::parse(numericFeatures);
numericFeaturesIdx.push_back(f); for (auto& f : features) {
numericFeaturesIdx.push_back(f);
}
} }
} }
} else {
numericFeaturesIdx.push_back(-1);
} }
} break;
datasets[name] = make_unique<Dataset>(path, name, className, discretize, fileType, numericFeaturesIdx); default:
throw std::invalid_argument("Invalid catalog file format.");
} }
catalog.close(); datasets[name] = make_unique<Dataset>(path, name, className, discretize, fileType, numericFeaturesIdx);
} else {
throw std::invalid_argument("Unable to open catalog file. [" + path + "all.txt" + "]");
} }
catalog.close();
} }
std::vector<std::string> Datasets::getNames() std::vector<std::string> Datasets::getNames()
{ {

View File

@@ -13,9 +13,30 @@ namespace platform {
class DotEnv { class DotEnv {
private: private:
std::map<std::string, std::string> env; std::map<std::string, std::string> env;
std::map<std::string, std::vector<std::string>> valid;
public: public:
DotEnv(bool create = false) DotEnv(bool create = false)
{ {
valid =
{
{"source_data", {"Arff", "Tanveer", "Surcov"}},
{"experiment", {"discretiz", "odte", "covid"}},
{"fit_features", {"0", "1"}},
{"discretize", {"0", "1"}},
{"ignore_nan", {"0", "1"}},
{"stratified", {"0", "1"}},
{"score", {"accuracy"}},
{"framework", {"bulma", "bootstrap"}},
{"margin", {"0.1", "0.2", "0.3"}},
{"n_folds", {"5", "10"}},
{"discretiz_algo", {"mdlp", "bin3u", "bin3q"}},
{"platform", {"any"}},
{"model", {"any"}},
{"seeds", {"any"}},
{"nodes", {"any"}},
{"leaves", {"any"}},
{"depth", {"any"}},
};
if (create) { if (create) {
// For testing purposes // For testing purposes
std::ofstream file(".env"); std::ofstream file(".env");
@@ -37,7 +58,39 @@ namespace platform {
std::istringstream iss(line); std::istringstream iss(line);
std::string key, value; std::string key, value;
if (std::getline(iss, key, '=') && std::getline(iss, value)) { if (std::getline(iss, key, '=') && std::getline(iss, value)) {
env[trim(key)] = trim(value); key = trim(key);
value = trim(value);
parse(key, value);
env[key] = value;
}
}
parseEnv();
}
// Validates one key/value pair read from the .env file against the `valid` table.
//
// @param key   Already-trimmed key name; must be present in `valid`.
// @param value Already-trimmed value; must be one of the strings listed for `key`,
//              unless the first listed entry for that key is "any".
//
// On an unknown key or a disallowed value a diagnostic is written to std::cerr
// and the process is terminated with exit(1); otherwise the function returns.
void parse(const std::string& key, const std::string& value)
{
    const auto entry = valid.find(key);
    if (entry == valid.end()) {
        std::cerr << "Invalid key in .env: " << key << std::endl;
        exit(1);
    }
    const auto& allowed = entry->second;
    // A leading "any" marks the key as free-form: no value check is performed.
    if (allowed.front() == "any") {
        return;
    }
    const bool accepted = std::find(allowed.begin(), allowed.end(), value) != allowed.end();
    if (!accepted) {
        std::cerr << "Invalid value in .env: " << key << " = " << value << std::endl;
        exit(1);
    }
}
// Checks that every key listed in `valid` is present in `env`.
// For the first key that is missing, writes a diagnostic to std::cerr that
// names the key and lists its entries from `valid`, then terminates the
// process with exit(1). Returns normally when no key is missing.
void parseEnv()
{
for (auto& [key, values] : valid) {
if (env.find(key) == env.end()) {
// Join the listed values with ", " for the error message.
std::string valid_values = "", sep = "";
for (const auto& value : values) {
valid_values += sep + value;
sep = ", ";
}
std::cerr << "Key not found in .env: " << key << ", valid values: " << valid_values << std::endl;
exit(1);
} }
} }
} }