Add train/test split in gridsearch
@@ -15,10 +15,6 @@ namespace platform {
     {
         return name;
     }
-    std::string Dataset::getClassName() const
-    {
-        return className;
-    }
     std::vector<std::string> Dataset::getFeatures() const
     {
         if (loaded) {
@@ -43,6 +39,42 @@ namespace platform {
             throw std::invalid_argument(message_dataset_not_loaded);
         }
     }
+    std::string Dataset::getClassName() const
+    {
+        return className;
+    }
+    int Dataset::getNClasses() const
+    {
+        if (loaded) {
+            if (discretize) {
+                return states.at(className).size();
+            }
+            return *std::max_element(yv.begin(), yv.end()) + 1;
+        } else {
+            throw std::invalid_argument(message_dataset_not_loaded);
+        }
+    }
+    std::vector<std::string> Dataset::getLabels() const
+    {
+        // Return the labels factorization result
+        if (loaded) {
+            return labels;
+        } else {
+            throw std::invalid_argument(message_dataset_not_loaded);
+        }
+    }
+    std::vector<int> Dataset::getClassesCounts() const
+    {
+        if (loaded) {
+            std::vector<int> counts(*std::max_element(yv.begin(), yv.end()) + 1);
+            for (auto y : yv) {
+                counts[y]++;
+            }
+            return counts;
+        } else {
+            throw std::invalid_argument(message_dataset_not_loaded);
+        }
+    }
     std::map<std::string, std::vector<int>> Dataset::getStates() const
     {
         if (loaded) {
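The accessors added here read class structure directly off the label vector yv: getNClasses() is max(yv) + 1 (or the stored state count when discretizing) and getClassesCounts() is a plain histogram. A minimal standalone sketch of that counting idiom, assuming class ids are consecutive integers starting at 0 (the same assumption the code above makes); this is illustrative only, not part of the commit:

#include <algorithm>
#include <cassert>
#include <vector>

// One bucket per class id in [0, max(y)], as in Dataset::getClassesCounts().
std::vector<int> class_counts(const std::vector<int>& y)
{
    std::vector<int> counts(*std::max_element(y.begin(), y.end()) + 1);
    for (auto v : y) {
        counts[v]++;  // tally each sample into its class bucket
    }
    return counts;
}

int main()
{
    // Three classes: one sample of class 0, two of class 1, three of class 2.
    assert((class_counts({ 0, 1, 1, 2, 2, 2 }) == std::vector<int>{ 1, 2, 3 }));
    return 0;
}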
@@ -70,7 +102,6 @@ namespace platform {
     pair<torch::Tensor&, torch::Tensor&> Dataset::getTensors()
     {
         if (loaded) {
-            buildTensors();
             return { X, y };
         } else {
             throw std::invalid_argument(message_dataset_not_loaded);
@@ -79,29 +110,32 @@ namespace platform {
     void Dataset::load_csv()
     {
         ifstream file(path + "/" + name + ".csv");
-        if (file.is_open()) {
-            std::string line;
-            getline(file, line);
-            std::vector<std::string> tokens = split(line, ',');
-            features = std::vector<std::string>(tokens.begin(), tokens.end() - 1);
-            if (className == "-1") {
-                className = tokens.back();
-            }
-            for (auto i = 0; i < features.size(); ++i) {
-                Xv.push_back(std::vector<float>());
-            }
-            while (getline(file, line)) {
-                tokens = split(line, ',');
-                for (auto i = 0; i < features.size(); ++i) {
-                    Xv[i].push_back(stof(tokens[i]));
-                }
-                yv.push_back(stoi(tokens.back()));
-            }
-            labels.clear();
-            file.close();
-        } else {
+        if (!file.is_open()) {
             throw std::invalid_argument("Unable to open dataset file.");
         }
+        labels.clear();
+        std::string line;
+        getline(file, line);
+        std::vector<std::string> tokens = split(line, ',');
+        features = std::vector<std::string>(tokens.begin(), tokens.end() - 1);
+        if (className == "-1") {
+            className = tokens.back();
+        }
+        for (auto i = 0; i < features.size(); ++i) {
+            Xv.push_back(std::vector<float>());
+        }
+        while (getline(file, line)) {
+            tokens = split(line, ',');
+            for (auto i = 0; i < features.size(); ++i) {
+                Xv[i].push_back(stof(tokens[i]));
+            }
+            auto label = trim(tokens.back());
+            if (find(labels.begin(), labels.end(), label) == labels.end()) {
+                labels.push_back(label);
+            }
+            yv.push_back(stoi(label));
+        }
+        file.close();
     }
     void Dataset::computeStates()
     {
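The rewritten loader flattens the old if/else by throwing early, and factorizes the class column on the fly: each distinct label string is stored once in labels, in order of first appearance, while yv still takes stoi(label), so the class column is expected to hold integer codes. A self-contained sketch of that factorization pattern (illustrative, not from the commit):

#include <algorithm>
#include <string>
#include <vector>

// Collect distinct labels in order of first appearance, as load_csv()
// and load_rdata() now do while streaming the file.
std::vector<std::string> factorize(const std::vector<std::string>& column)
{
    std::vector<std::string> labels;
    for (const auto& label : column) {
        if (std::find(labels.begin(), labels.end(), label) == labels.end()) {
            labels.push_back(label);  // unseen label: record it once
        }
    }
    return labels;
}

int main()
{
    auto labels = factorize({ "2", "0", "2", "1" });
    // labels == {"2", "0", "1"}
    return labels.size() == 3 ? 0 : 1;
}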
@@ -147,32 +181,35 @@ namespace platform {
     void Dataset::load_rdata()
     {
         ifstream file(path + "/" + name + "_R.dat");
-        if (file.is_open()) {
-            std::string line;
-            getline(file, line);
-            line = ArffFiles::trim(line);
-            std::vector<std::string> tokens = tokenize(line);
-            transform(tokens.begin(), tokens.end() - 1, back_inserter(features), [](const auto& attribute) { return ArffFiles::trim(attribute); });
-            if (className == "-1") {
-                className = ArffFiles::trim(tokens.back());
-            }
-            for (auto i = 0; i < features.size(); ++i) {
-                Xv.push_back(std::vector<float>());
-            }
-            while (getline(file, line)) {
-                tokens = tokenize(line);
-                // We have to skip the first token, which is the instance number.
-                for (auto i = 1; i < features.size() + 1; ++i) {
-                    const float value = stof(tokens[i]);
-                    Xv[i - 1].push_back(value);
-                }
-                yv.push_back(stoi(tokens.back()));
-            }
-            labels.clear();
-            file.close();
-        } else {
+        if (!file.is_open()) {
             throw std::invalid_argument("Unable to open dataset file.");
         }
+        std::string line;
+        labels.clear();
+        getline(file, line);
+        line = ArffFiles::trim(line);
+        std::vector<std::string> tokens = tokenize(line);
+        transform(tokens.begin(), tokens.end() - 1, back_inserter(features), [](const auto& attribute) { return ArffFiles::trim(attribute); });
+        if (className == "-1") {
+            className = ArffFiles::trim(tokens.back());
+        }
+        for (auto i = 0; i < features.size(); ++i) {
+            Xv.push_back(std::vector<float>());
+        }
+        while (getline(file, line)) {
+            tokens = tokenize(line);
+            // We have to skip the first token, which is the instance number.
+            for (auto i = 1; i < features.size() + 1; ++i) {
+                const float value = stof(tokens[i]);
+                Xv[i - 1].push_back(value);
+            }
+            auto label = trim(tokens.back());
+            if (find(labels.begin(), labels.end(), label) == labels.end()) {
+                labels.push_back(label);
+            }
+            yv.push_back(stoi(label));
+        }
+        file.close();
     }
     void Dataset::load()
     {
@@ -200,27 +237,13 @@ namespace platform {
                 }
             }
         }
-        if (discretize) {
-            Xd = discretizeDataset(Xv, yv);
-            computeStates();
-        }
-        loaded = true;
-    }
-    void Dataset::buildTensors()
-    {
-        if (discretize) {
-            X = torch::zeros({ static_cast<int>(n_features), static_cast<int>(n_samples) }, torch::kInt32);
-        } else {
-            X = torch::zeros({ static_cast<int>(n_features), static_cast<int>(n_samples) }, torch::kFloat32);
-        }
+        // Build Tensors
+        X = torch::zeros({ n_features, n_samples }, torch::kFloat32);
         for (int i = 0; i < features.size(); ++i) {
-            if (discretize) {
-                X.index_put_({ i, "..." }, torch::tensor(Xd[i], torch::kInt32));
-            } else {
-                X.index_put_({ i, "..." }, torch::tensor(Xv[i], torch::kFloat32));
-            }
+            X.index_put_({ i, "..." }, torch::tensor(Xv[i], torch::kFloat32));
        }
         y = torch::tensor(yv, torch::kInt32);
+        loaded = true;
     }
     std::vector<mdlp::labels_t> Dataset::discretizeDataset(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y)
     {
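With this hunk, load() always materializes X as a float32 tensor holding the raw feature values and buildTensors() disappears. Discretization, when enabled, moves to getTrainTestTensors() in the following hunk, where the discretizer is fitted on the training fold only and then applied to both folds.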
@@ -233,9 +256,40 @@ namespace platform {
         }
         return Xd;
     }
-    std::pair <torch::Tensor&, torch::Tensor&> Dataset::getDiscretizedTrainTestTensors()
+    std::tuple<torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&> Dataset::getTrainTestTensors(std::vector<int>& train, std::vector<int>& test)
     {
-        auto discretizer = Discretization::instance()->create("mdlp");
-        return { X_train, X_test };
+        if (!loaded) {
+            throw std::invalid_argument(message_dataset_not_loaded);
+        }
+        auto train_t = torch::tensor(train);
+        int samples_train = train.size();
+        int samples_test = test.size();
+        auto test_t = torch::tensor(test);
+        X_train = X.index({ "...", train_t });
+        y_train = y.index({ train_t });
+        X_test = X.index({ "...", test_t });
+        y_test = y.index({ test_t });
+        if (discretize) {
+            auto discretizer = Discretization::instance()->create(discretizer_algorithm);
+            auto X_train_d = torch::zeros({ n_features, samples_train }, torch::kInt32);
+            auto X_test_d = torch::zeros({ n_features, samples_test }, torch::kInt32);
+            for (int feature = 0; feature < n_features; ++feature) {
+                if (numericFeatures[feature]) {
+                    auto X_train_feature = X_train.index({ feature, "..." }).to(torch::kFloat32);
+                    auto X_test_feature = X_test.index({ feature, "..." }).to(torch::kFloat32);
+                    discretizer->fit(X_train_feature, y_train);
+                    auto X_train_feature_d = discretizer->transform(X_train_feature);
+                    auto X_test_feature_d = discretizer->transform(X_test_feature);
+                    X_train_d.index_put_({ feature, "..." }, X_train_feature_d.to(torch::kInt32));
+                    X_test_d.index_put_({ feature, "..." }, X_test_feature_d.to(torch::kInt32));
+                } else {
+                    X_train_d.index_put_({ feature, "..." }, X_train.index({ feature, "..." }).to(torch::kInt32));
+                    X_test_d.index_put_({ feature, "..." }, X_test.index({ feature, "..." }).to(torch::kInt32));
+                }
+            }
+            X_train = X_train_d;
+            X_test = X_test_d;
+        }
+        return { X_train, X_test, y_train, y_test };
     }
 }
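A sketch of how a grid-search caller might consume the new API. The index-splitting helper, its parameters, and the seed below are hypothetical illustrations, not from this commit; only the getTrainTestTensors signature above is:

#include <algorithm>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// Hypothetical helper: shuffle indices 0..n-1 and cut off a test share.
std::pair<std::vector<int>, std::vector<int>> split_indices(int n, double test_ratio, unsigned seed)
{
    std::vector<int> idx(n);
    std::iota(idx.begin(), idx.end(), 0);
    std::shuffle(idx.begin(), idx.end(), std::mt19937{ seed });
    const int n_test = static_cast<int>(n * test_ratio);
    std::vector<int> test(idx.begin(), idx.begin() + n_test);
    std::vector<int> train(idx.begin() + n_test, idx.end());
    return { train, test };
}

// Usage against a loaded platform::Dataset (sketch):
//   auto [train, test] = split_indices(n_samples, 0.3, 271);
//   auto [X_train, X_test, y_train, y_test] = dataset.getTrainTestTensors(train, test);
// When discretization is enabled, the discretizer is fitted on the train fold
// only inside getTrainTestTensors, so no test information leaks into the cut points.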