Fix some lint warnings

This commit is contained in:
Ricardo Montañana Gómez 2023-07-29 20:37:51 +02:00
parent 5efa3beaee
commit 8b2ed26ab7
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
7 changed files with 34 additions and 51 deletions

View File

@ -18,7 +18,7 @@ namespace bayesnet {
auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset); auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset);
mi.push_back({ i, mi_value }); mi.push_back({ i, mi_value });
} }
sort(mi.begin(), mi.end(), [](auto& left, auto& right) {return left.second < right.second;}); sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;});
auto root = mi[mi.size() - 1].first; auto root = mi[mi.size() - 1].first;
// 2. Compute mutual information between each feature and the class // 2. Compute mutual information between each feature and the class
auto weights = metrics.conditionalEdge(); auto weights = metrics.conditionalEdge();

View File

@ -4,9 +4,9 @@
namespace platform { namespace platform {
void Datasets::load() void Datasets::load()
{ {
string line;
ifstream catalog(path + "/all.txt"); ifstream catalog(path + "/all.txt");
if (catalog.is_open()) { if (catalog.is_open()) {
string line;
while (getline(catalog, line)) { while (getline(catalog, line)) {
vector<string> tokens = split(line, ','); vector<string> tokens = split(line, ',');
string name = tokens[0]; string name = tokens[0];
@ -83,23 +83,8 @@ namespace platform {
{ {
return datasets.find(name) != datasets.end(); return datasets.find(name) != datasets.end();
} }
Dataset::Dataset(Dataset& dataset) Dataset::Dataset(const Dataset& dataset) : path(dataset.path), name(dataset.name), className(dataset.className), n_samples(dataset.n_samples), n_features(dataset.n_features), features(dataset.features), states(dataset.states), loaded(dataset.loaded), discretize(dataset.discretize), X(dataset.X), y(dataset.y), Xv(dataset.Xv), Xd(dataset.Xd), yv(dataset.yv), fileType(dataset.fileType)
{ {
path = dataset.path;
name = dataset.name;
className = dataset.className;
n_samples = dataset.n_samples;
n_features = dataset.n_features;
features = dataset.features;
states = dataset.states;
loaded = dataset.loaded;
discretize = dataset.discretize;
X = dataset.X;
y = dataset.y;
Xv = dataset.Xv;
Xd = dataset.Xd;
yv = dataset.yv;
fileType = dataset.fileType;
} }
string Dataset::getName() string Dataset::getName()
{ {
@ -168,9 +153,9 @@ namespace platform {
} }
void Dataset::load_csv() void Dataset::load_csv()
{ {
string line;
ifstream file(path + "/" + name + ".csv"); ifstream file(path + "/" + name + ".csv");
if (file.is_open()) { if (file.is_open()) {
string line;
getline(file, line); getline(file, line);
vector<string> tokens = split(line, ','); vector<string> tokens = split(line, ',');
features = vector<string>(tokens.begin(), tokens.end() - 1); features = vector<string>(tokens.begin(), tokens.end() - 1);

View File

@ -13,7 +13,7 @@ namespace platform {
string name; string name;
fileType_t fileType; fileType_t fileType;
string className; string className;
int n_samples, n_features; int n_samples{ 0 }, n_features{ 0 };
vector<string> features; vector<string> features;
map<string, vector<int>> states; map<string, vector<int>> states;
bool loaded; bool loaded;
@ -27,8 +27,8 @@ namespace platform {
void load_arff(); void load_arff();
void computeStates(); void computeStates();
public: public:
Dataset(string path, string name, string className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {}; Dataset(const string& path, const string& name, const string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {};
Dataset(Dataset&); explicit Dataset(const Dataset&);
string getName(); string getName();
string getClassName(); string getClassName();
vector<string> getFeatures(); vector<string> getFeatures();
@ -49,7 +49,7 @@ namespace platform {
bool discretize; bool discretize;
void load(); // Loads the list of datasets void load(); // Loads the list of datasets
public: public:
Datasets(string path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); }; Datasets(const string& path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); };
vector<string> getNames(); vector<string> getNames();
vector<string> getFeatures(string name); vector<string> getFeatures(string name);
int getNSamples(string name); int getNSamples(string name);

View File

@ -48,7 +48,7 @@ namespace platform {
result["seeds"] = randomSeeds; result["seeds"] = randomSeeds;
result["duration"] = duration; result["duration"] = duration;
result["results"] = json::array(); result["results"] = json::array();
for (auto& r : results) { for (const auto& r : results) {
json j; json j;
j["dataset"] = r.getDataset(); j["dataset"] = r.getDataset();
j["hyperparameters"] = r.getHyperparameters(); j["hyperparameters"] = r.getHyperparameters();
@ -78,7 +78,7 @@ namespace platform {
} }
return result; return result;
} }
void Experiment::save(string path) void Experiment::save(const string& path)
{ {
json data = build_json(); json data = build_json();
ofstream file(path + "/" + get_file_name()); ofstream file(path + "/" + get_file_name());
@ -97,14 +97,12 @@ namespace platform {
cout << "*** Starting experiment: " << title << " ***" << endl; cout << "*** Starting experiment: " << title << " ***" << endl;
for (auto fileName : filesToProcess) { for (auto fileName : filesToProcess) {
cout << "- " << setw(20) << left << fileName << " " << right << flush; cout << "- " << setw(20) << left << fileName << " " << right << flush;
auto result = cross_validation(path, fileName); cross_validation(path, fileName);
result.setDataset(fileName);
addResult(result);
cout << endl; cout << endl;
} }
} }
Result Experiment::cross_validation(const string& path, const string& fileName) void Experiment::cross_validation(const string& path, const string& fileName)
{ {
auto datasets = platform::Datasets(path, true, platform::ARFF); auto datasets = platform::Datasets(path, true, platform::ARFF);
// Get dataset // Get dataset
@ -172,6 +170,7 @@ namespace platform {
result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>()); result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>()); result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>()); result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
return result; result.setDataset(fileName);
addResult(result);
} }
} }

View File

@ -30,14 +30,14 @@ namespace platform {
class Result { class Result {
private: private:
string dataset, hyperparameters, model_version; string dataset, hyperparameters, model_version;
int samples, features, classes; int samples{ 0 }, features{ 0 }, classes{ 0 };
double score_train, score_test, score_train_std, score_test_std, train_time, train_time_std, test_time, test_time_std; double score_train{ 0 }, score_test{ 0 }, score_train_std{ 0 }, score_test_std{ 0 }, train_time{ 0 }, train_time_std{ 0 }, test_time{ 0 }, test_time_std{ 0 };
float nodes, leaves, depth; float nodes{ 0 }, leaves{ 0 }, depth{ 0 };
vector<double> scores_train, scores_test, times_train, times_test; vector<double> scores_train, scores_test, times_train, times_test;
public: public:
Result() = default; Result() = default;
Result& setDataset(string dataset) { this->dataset = dataset; return *this; } Result& setDataset(const string& dataset) { this->dataset = dataset; return *this; }
Result& setHyperparameters(string hyperparameters) { this->hyperparameters = hyperparameters; return *this; } Result& setHyperparameters(const string& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
Result& setSamples(int samples) { this->samples = samples; return *this; } Result& setSamples(int samples) { this->samples = samples; return *this; }
Result& setFeatures(int features) { this->features = features; return *this; } Result& setFeatures(int features) { this->features = features; return *this; }
Result& setClasses(int classes) { this->classes = classes; return *this; } Result& setClasses(int classes) { this->classes = classes; return *this; }
@ -82,21 +82,21 @@ namespace platform {
class Experiment { class Experiment {
private: private:
string title, model, platform, score_name, model_version, language_version, language; string title, model, platform, score_name, model_version, language_version, language;
bool discretized, stratified; bool discretized{ false }, stratified{ false };
vector<Result> results; vector<Result> results;
vector<int> randomSeeds; vector<int> randomSeeds;
int nfolds; int nfolds{ 0 };
float duration; float duration{ 0 };
json build_json(); json build_json();
public: public:
Experiment() = default; Experiment() = default;
Experiment& setTitle(string title) { this->title = title; return *this; } Experiment& setTitle(const string& title) { this->title = title; return *this; }
Experiment& setModel(string model) { this->model = model; return *this; } Experiment& setModel(const string& model) { this->model = model; return *this; }
Experiment& setPlatform(string platform) { this->platform = platform; return *this; } Experiment& setPlatform(const string& platform) { this->platform = platform; return *this; }
Experiment& setScoreName(string score_name) { this->score_name = score_name; return *this; } Experiment& setScoreName(const string& score_name) { this->score_name = score_name; return *this; }
Experiment& setModelVersion(string model_version) { this->model_version = model_version; return *this; } Experiment& setModelVersion(const string& model_version) { this->model_version = model_version; return *this; }
Experiment& setLanguage(string language) { this->language = language; return *this; } Experiment& setLanguage(const string& language) { this->language = language; return *this; }
Experiment& setLanguageVersion(string language_version) { this->language_version = language_version; return *this; } Experiment& setLanguageVersion(const string& language_version) { this->language_version = language_version; return *this; }
Experiment& setDiscretized(bool discretized) { this->discretized = discretized; return *this; } Experiment& setDiscretized(bool discretized) { this->discretized = discretized; return *this; }
Experiment& setStratified(bool stratified) { this->stratified = stratified; return *this; } Experiment& setStratified(bool stratified) { this->stratified = stratified; return *this; }
Experiment& setNFolds(int nfolds) { this->nfolds = nfolds; return *this; } Experiment& setNFolds(int nfolds) { this->nfolds = nfolds; return *this; }
@ -104,8 +104,8 @@ namespace platform {
Experiment& addRandomSeed(int randomSeed) { randomSeeds.push_back(randomSeed); return *this; } Experiment& addRandomSeed(int randomSeed) { randomSeeds.push_back(randomSeed); return *this; }
Experiment& setDuration(float duration) { this->duration = duration; return *this; } Experiment& setDuration(float duration) { this->duration = duration; return *this; }
string get_file_name(); string get_file_name();
void save(string path); void save(const string& path);
Result cross_validation(const string& path, const string& fileName); void cross_validation(const string& path, const string& fileName);
void go(vector<string> filesToProcess, const string& path); void go(vector<string> filesToProcess, const string& path);
void show(); void show();
}; };

View File

@ -7,9 +7,8 @@ Fold::Fold(int k, int n, int seed) : k(k), n(n), seed(seed)
random_seed = default_random_engine(seed == -1 ? rd() : seed); random_seed = default_random_engine(seed == -1 ? rd() : seed);
srand(seed == -1 ? time(0) : seed); srand(seed == -1 ? time(0) : seed);
} }
KFold::KFold(int k, int n, int seed) : Fold(k, n, seed) KFold::KFold(int k, int n, int seed) : Fold(k, n, seed), indices(vector<int>(n))
{ {
indices = vector<int>(n);
iota(begin(indices), end(indices), 0); // fill with 0, 1, ..., n - 1 iota(begin(indices), end(indices), 0); // fill with 0, 1, ..., n - 1
shuffle(indices.begin(), indices.end(), random_seed); shuffle(indices.begin(), indices.end(), random_seed);
} }

View File

@ -22,7 +22,7 @@ private:
vector<int> indices; vector<int> indices;
public: public:
KFold(int k, int n, int seed = -1); KFold(int k, int n, int seed = -1);
pair<vector<int>, vector<int>> getFold(int nFold); pair<vector<int>, vector<int>> getFold(int nFold) override;
}; };
class StratifiedKFold : public Fold { class StratifiedKFold : public Fold {
private: private:
@ -32,6 +32,6 @@ private:
public: public:
StratifiedKFold(int k, const vector<int>& y, int seed = -1); StratifiedKFold(int k, const vector<int>& y, int seed = -1);
StratifiedKFold(int k, torch::Tensor& y, int seed = -1); StratifiedKFold(int k, torch::Tensor& y, int seed = -1);
pair<vector<int>, vector<int>> getFold(int nFold); pair<vector<int>, vector<int>> getFold(int nFold) override;
}; };
#endif #endif