2023-10-04 21:19:23 +00:00
|
|
|
#ifndef TEST_UTILS_H
|
|
|
|
#define TEST_UTILS_H
|
|
|
|
#include <torch/torch.h>
|
|
|
|
#include <string>
|
|
|
|
#include <vector>
|
|
|
|
#include <map>
|
2024-01-07 18:58:22 +00:00
|
|
|
#include <tuple>
|
2024-03-08 00:13:30 +00:00
|
|
|
#include <ArffFiles.h>
|
|
|
|
#include <CPPFImdlp.h>
|
2023-10-04 21:19:23 +00:00
|
|
|
|
2024-01-07 18:58:22 +00:00
|
|
|
// Declaration only; the definition lives outside this header.
// NOTE(review): presumably returns true when a file called `name` exists on
// disk -- confirm against the implementation.
bool file_exists(const std::string& name);
|
2023-11-08 17:45:35 +00:00
|
|
|
std::pair<vector<mdlp::labels_t>, map<std::string, int>> discretize(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y, std::vector<string> features);
|
|
|
|
// Declaration only; the definition lives outside this header.
// Takes the sample columns `X` and the labels `y` by non-const reference and
// returns a vector of mdlp::labels_t.
// NOTE(review): presumably one discretized column per input feature -- confirm
// against the implementation.
std::vector<mdlp::labels_t> discretizeDataset(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y);
|
|
|
|
// Declaration only; the definition lives outside this header.
// Loads the dataset identified by `name` into plain standard containers and
// returns, in order: an int matrix, an int vector, a list of strings, a single
// string, and a map from string to int vector. RawDatasets unpacks these as
// (Xv, yv, featuresv, classNamev, statesv).
// All standard-library names are spelled with std:: so this header does not
// depend on a using-directive leaking in from another include.
std::tuple<std::vector<std::vector<int>>, std::vector<int>, std::vector<std::string>, std::string, std::map<std::string, std::vector<int>>> loadFile(const std::string& name);
|
|
|
|
// Declaration only; the definition lives outside this header.
// Loads the dataset identified by `name` and returns, in order: two tensors, a
// list of strings, a single string, and a map from string to int vector.
// RawDatasets unpacks these as (Xt, yt, featurest, classNamet, statest).
// NOTE(review): `class_last` presumably says whether the class column is the
// last one in the file, and `discretize_dataset` whether the features are
// discretized before being returned -- confirm against the implementation.
// All standard-library names are spelled with std:: so this header does not
// depend on a using-directive leaking in from another include.
std::tuple<torch::Tensor, torch::Tensor, std::vector<std::string>, std::string, std::map<std::string, std::vector<int>>> loadDataset(const std::string& name, bool class_last, bool discretize_dataset);
|
2023-10-04 21:19:23 +00:00
|
|
|
|
2023-10-06 15:08:54 +00:00
|
|
|
class RawDatasets {
|
|
|
|
public:
|
2023-11-08 17:45:35 +00:00
|
|
|
RawDatasets(const std::string& file_name, bool discretize)
|
2023-10-06 15:08:54 +00:00
|
|
|
{
|
|
|
|
// Xt can be either discretized or not
|
|
|
|
tie(Xt, yt, featurest, classNamet, statest) = loadDataset(file_name, true, discretize);
|
|
|
|
// Xv is always discretized
|
|
|
|
tie(Xv, yv, featuresv, classNamev, statesv) = loadFile(file_name);
|
|
|
|
auto yresized = torch::transpose(yt.view({ yt.size(0), 1 }), 0, 1);
|
|
|
|
dataset = torch::cat({ Xt, yresized }, 0);
|
|
|
|
nSamples = dataset.size(1);
|
|
|
|
weights = torch::full({ nSamples }, 1.0 / nSamples, torch::kDouble);
|
2023-11-08 17:45:35 +00:00
|
|
|
weightsv = std::vector<double>(nSamples, 1.0 / nSamples);
|
2023-10-06 15:08:54 +00:00
|
|
|
classNumStates = discretize ? statest.at(classNamet).size() : 0;
|
|
|
|
}
|
|
|
|
torch::Tensor Xt, yt, dataset, weights;
|
2023-11-08 17:45:35 +00:00
|
|
|
std::vector<vector<int>> Xv;
|
|
|
|
std::vector<double> weightsv;
|
|
|
|
std::vector<int> yv;
|
|
|
|
std::vector<string> featurest, featuresv;
|
|
|
|
map<std::string, std::vector<int>> statest, statesv;
|
|
|
|
std::string classNamet, classNamev;
|
2023-10-06 15:08:54 +00:00
|
|
|
int nSamples, classNumStates;
|
|
|
|
double epsilon = 1e-5;
|
|
|
|
};
|
|
|
|
|
|
|
|
#endif //TEST_UTILS_H
|