Add train/test split in grid search

This commit is contained in:
2024-06-07 11:05:59 +02:00
parent 5dd3deca1a
commit 361c51d864
8 changed files with 213 additions and 247 deletions

View File

@@ -4,27 +4,30 @@
#include <map>
#include <vector>
#include <string>
#include <tuple>
#include <common/DiscretizationRegister.h>
#include "Utils.h"
#include "SourceData.h"
namespace platform {
class Dataset {
public:
Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType, std::vector<int> numericFeaturesIdx) :
Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType, std::vector<int> numericFeaturesIdx, std::string discretizer_algo = "none") :
path(path), name(name), className(className), discretize(discretize),
loaded(false), fileType(fileType), numericFeaturesIdx(numericFeaturesIdx)
loaded(false), fileType(fileType), numericFeaturesIdx(numericFeaturesIdx), discretizer_algorithm(discretizer_algo)
{
};
explicit Dataset(const Dataset&);
std::string getName() const;
std::string getClassName() const;
std::vector<std::string> getLabels() const { return labels; }
int getNClasses() const;
std::vector<std::string> getLabels() const; // return the labels factorization result
std::vector<int> getClassesCounts() const;
std::vector<string> getFeatures() const;
std::map<std::string, std::vector<int>> getStates() const;
std::pair<vector<std::vector<float>>&, std::vector<int>&> getVectors();
std::pair<vector<std::vector<int>>&, std::vector<int>&> getVectorsDiscretized();
std::pair<torch::Tensor&, torch::Tensor&> getDiscretizedTrainTestTensors();
std::pair<torch::Tensor&, torch::Tensor&> getTensors();
std::tuple<torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&> getTrainTestTensors(std::vector<int>& train, std::vector<int>& test);
int getNFeatures() const;
int getNSamples() const;
// Returns a mutable reference to the per-feature flags (true if the feature is numeric);
// callers can modify the dataset's internal vector through it.
std::vector<bool>& getNumericFeatures() { return numericFeatures; }
@@ -37,6 +40,7 @@ namespace platform {
std::string className;
int n_samples{ 0 }, n_features{ 0 };
std::vector<int> numericFeaturesIdx;
std::string discretizer_algorithm;
std::vector<bool> numericFeatures; // true if feature is numeric
std::vector<std::string> features;
std::vector<std::string> labels;
@@ -44,11 +48,10 @@ namespace platform {
bool loaded;
bool discretize;
torch::Tensor X, y;
torch::Tensor X_train, X_test;
torch::Tensor X_train, X_test, y_train, y_test;
std::vector<std::vector<float>> Xv;
std::vector<std::vector<int>> Xd;
std::vector<int> yv;
void buildTensors();
void load_csv();
void load_arff();
void load_rdata();