Remove unneeded loop in sortIndices

Add some static casts
This commit is contained in:
2023-03-19 19:13:37 +01:00
parent f0845c5bd1
commit cfade7a556
9 changed files with 210 additions and 146 deletions

View File

@@ -1,6 +1,10 @@
cmake_minimum_required(VERSION 3.20) cmake_minimum_required(VERSION 3.20)
project(mdlp) project(mdlp)
if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
endif ()
set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD 11)
add_library(mdlp CPPFImdlp.cpp Metrics.cpp sample/sample.cpp) add_library(mdlp CPPFImdlp.cpp Metrics.cpp sample/sample.cpp)

View File

@@ -5,20 +5,22 @@
#include <limits> #include <limits>
#include "CPPFImdlp.h" #include "CPPFImdlp.h"
#include "Metrics.h" #include "Metrics.h"
namespace mdlp { namespace mdlp {
CPPFImdlp::CPPFImdlp(size_t min_length_, int max_depth_, float proposed): min_length(min_length_), CPPFImdlp::CPPFImdlp(size_t min_length_, int max_depth_, float proposed) : min_length(min_length_),
max_depth(max_depth_), proposed_cuts(proposed) max_depth(max_depth_),
{ proposed_cuts(proposed) {
} }
CPPFImdlp::CPPFImdlp() = default; CPPFImdlp::CPPFImdlp() = default;
CPPFImdlp::~CPPFImdlp() = default; CPPFImdlp::~CPPFImdlp() = default;
size_t CPPFImdlp::compute_max_num_cut_points() const size_t CPPFImdlp::compute_max_num_cut_points() const {
{
// Set the actual maximum number of cut points as a number or as a percentage of the number of samples // Set the actual maximum number of cut points as a number or as a percentage of the number of samples
if (proposed_cuts == 0) { if (proposed_cuts == 0) {
return numeric_limits<size_t>::max(); return numeric_limits<size_t>::max();
} }
if (proposed_cuts < 0 || proposed_cuts > static_cast<float>(X.size())) { if (proposed_cuts < 0 || proposed_cuts > static_cast<float>(X.size())) {
throw invalid_argument("wrong proposed num_cuts value"); throw invalid_argument("wrong proposed num_cuts value");
@@ -28,8 +30,7 @@ namespace mdlp {
return static_cast<size_t>(proposed_cuts); return static_cast<size_t>(proposed_cuts);
} }
void CPPFImdlp::fit(samples_t& X_, labels_t& y_) void CPPFImdlp::fit(samples_t &X_, labels_t &y_) {
{
X = X_; X = X_;
y = y_; y = y_;
num_cut_points = compute_max_num_cut_points(); num_cut_points = compute_max_num_cut_points();
@@ -52,12 +53,15 @@ namespace mdlp {
computeCutPoints(0, X.size(), 1); computeCutPoints(0, X.size(), 1);
} }
pair<precision_t, size_t> CPPFImdlp::valueCutPoint(size_t start, size_t cut, size_t end) pair<precision_t, size_t> CPPFImdlp::valueCutPoint(size_t start, size_t cut, size_t end) {
{ size_t n;
size_t n, m, idxPrev = cut - 1 >= start ? cut - 1 : cut; size_t m;
size_t idxPrev = cut - 1 >= start ? cut - 1 : cut;
size_t idxNext = cut + 1 < end ? cut + 1 : cut; size_t idxNext = cut + 1 < end ? cut + 1 : cut;
bool backWall; // true if duplicates reach beginning of the interval bool backWall; // true if duplicates reach beginning of the interval
precision_t previous, actual, next; precision_t previous;
precision_t actual;
precision_t next;
previous = X[indices[idxPrev]]; previous = X[indices[idxPrev]];
actual = X[indices[cut]]; actual = X[indices[cut]];
next = X[indices[idxNext]]; next = X[indices[idxNext]];
@@ -78,11 +82,10 @@ namespace mdlp {
// Decide which values to use // Decide which values to use
cut = cut + (backWall ? m + 1 : -n); cut = cut + (backWall ? m + 1 : -n);
actual = X[indices[cut]]; actual = X[indices[cut]];
return { (actual + previous) / 2, cut }; return {(actual + previous) / 2, cut};
} }
void CPPFImdlp::computeCutPoints(size_t start, size_t end, int depth_) void CPPFImdlp::computeCutPoints(size_t start, size_t end, int depth_) {
{
size_t cut; size_t cut;
pair<precision_t, size_t> result; pair<precision_t, size_t> result;
if (cutPoints.size() == num_cut_points) if (cutPoints.size() == num_cut_points)
@@ -103,13 +106,15 @@ namespace mdlp {
} }
} }
size_t CPPFImdlp::getCandidate(size_t start, size_t end) size_t CPPFImdlp::getCandidate(size_t start, size_t end) {
{
/* Definition 1: A binary discretization for A is determined by selecting the cut point TA for which /* Definition 1: A binary discretization for A is determined by selecting the cut point TA for which
E(A, TA; S) is minimal amongst all the candidate cut points. */ E(A, TA; S) is minimal amongst all the candidate cut points. */
size_t candidate = numeric_limits<size_t>::max(), elements = end - start; size_t candidate = numeric_limits<size_t>::max();
size_t elements = end - start;
bool sameValues = true; bool sameValues = true;
precision_t entropy_left, entropy_right, minEntropy; precision_t entropy_left;
precision_t entropy_right;
precision_t minEntropy;
// Check if all the values of the variable in the interval are the same // Check if all the values of the variable in the interval are the same
for (size_t idx = start + 1; idx < end; idx++) { for (size_t idx = start + 1; idx < end; idx++) {
if (X[indices[idx]] != X[indices[start]]) { if (X[indices[idx]] != X[indices[start]]) {
@@ -134,11 +139,15 @@ namespace mdlp {
return candidate; return candidate;
} }
bool CPPFImdlp::mdlp(size_t start, size_t cut, size_t end) bool CPPFImdlp::mdlp(size_t start, size_t cut, size_t end) {
{ int k;
int k, k1, k2; int k1;
precision_t ig, delta; int k2;
precision_t ent, ent1, ent2; precision_t ig;
precision_t delta;
precision_t ent;
precision_t ent1;
precision_t ent2;
auto N = precision_t(end - start); auto N = precision_t(end - start);
k = metrics.computeNumClasses(start, end); k = metrics.computeNumClasses(start, end);
k1 = metrics.computeNumClasses(start, cut); k1 = metrics.computeNumClasses(start, cut);
@@ -148,33 +157,30 @@ namespace mdlp {
ent2 = metrics.entropy(cut, end); ent2 = metrics.entropy(cut, end);
ig = metrics.informationGain(start, cut, end); ig = metrics.informationGain(start, cut, end);
delta = static_cast<float>(log2(pow(3, precision_t(k)) - 2) - delta = static_cast<float>(log2(pow(3, precision_t(k)) - 2) -
(precision_t(k) * ent - precision_t(k1) * ent1 - precision_t(k2) * ent2)); (precision_t(k) * ent - precision_t(k1) * ent1 - precision_t(k2) * ent2));
precision_t term = 1 / N * (log2(N - 1) + delta); precision_t term = 1 / N * (log2(N - 1) + delta);
return ig > term; return ig > term;
} }
// Argsort from https://stackoverflow.com/questions/1577475/c-sorting-and-keeping-track-of-indexes // Argsort from https://stackoverflow.com/questions/1577475/c-sorting-and-keeping-track-of-indexes
indices_t CPPFImdlp::sortIndices(samples_t& X_, labels_t& y_) indices_t CPPFImdlp::sortIndices(samples_t &X_, labels_t &y_) {
{
indices_t idx(X_.size()); indices_t idx(X_.size());
iota(idx.begin(), idx.end(), 0); iota(idx.begin(), idx.end(), 0);
for (size_t i = 0; i < X_.size(); i++) stable_sort(idx.begin(), idx.end(), [&X_, &y_](size_t i1, size_t i2) {
stable_sort(idx.begin(), idx.end(), [&X_, &y_](size_t i1, size_t i2) {
if (X_[i1] == X_[i2]) if (X_[i1] == X_[i2])
return y_[i1] < y_[i2]; return y_[i1] < y_[i2];
else else
return X_[i1] < X_[i2]; return X_[i1] < X_[i2];
}); });
return idx; return idx;
} }
cutPoints_t CPPFImdlp::getCutPoints() cutPoints_t CPPFImdlp::getCutPoints() {
{
sort(cutPoints.begin(), cutPoints.end()); sort(cutPoints.begin(), cutPoints.end());
return cutPoints; return cutPoints;
} }
int CPPFImdlp::get_depth() const
{ int CPPFImdlp::get_depth() const {
return depth; return depth;
} }
} }

View File

@@ -1,9 +1,11 @@
#ifndef CPPFIMDLP_H #ifndef CPPFIMDLP_H
#define CPPFIMDLP_H #define CPPFIMDLP_H
#include "typesFImdlp.h" #include "typesFImdlp.h"
#include "Metrics.h" #include "Metrics.h"
#include <utility> #include <utility>
#include <string> #include <string>
namespace mdlp { namespace mdlp {
class CPPFImdlp { class CPPFImdlp {
protected: protected:
@@ -18,20 +20,32 @@ namespace mdlp {
cutPoints_t cutPoints; cutPoints_t cutPoints;
size_t num_cut_points = numeric_limits<size_t>::max(); size_t num_cut_points = numeric_limits<size_t>::max();
static indices_t sortIndices(samples_t&, labels_t&); static indices_t sortIndices(samples_t &, labels_t &);
void computeCutPoints(size_t, size_t, int); void computeCutPoints(size_t, size_t, int);
bool mdlp(size_t, size_t, size_t); bool mdlp(size_t, size_t, size_t);
size_t getCandidate(size_t, size_t); size_t getCandidate(size_t, size_t);
size_t compute_max_num_cut_points() const; size_t compute_max_num_cut_points() const;
pair<precision_t, size_t> valueCutPoint(size_t, size_t, size_t); pair<precision_t, size_t> valueCutPoint(size_t, size_t, size_t);
public: public:
CPPFImdlp(); CPPFImdlp();
CPPFImdlp(size_t, int, float); CPPFImdlp(size_t, int, float);
~CPPFImdlp(); ~CPPFImdlp();
void fit(samples_t&, labels_t&);
void fit(samples_t &, labels_t &);
cutPoints_t getCutPoints(); cutPoints_t getCutPoints();
int get_depth() const; int get_depth() const;
inline string version() { return "1.1.1"; };
inline string version() const { return "1.1.1"; };
}; };
} }
#endif #endif

View File

@@ -1,63 +1,71 @@
#include "Metrics.h" #include "Metrics.h"
#include <set> #include <set>
#include <cmath> #include <cmath>
using namespace std; using namespace std;
namespace mdlp { namespace mdlp {
Metrics::Metrics(labels_t& y_, indices_t& indices_): y(y_), indices(indices_), numClasses(computeNumClasses(0, indices.size())), entropyCache(cacheEnt_t()), igCache(cacheIg_t()) Metrics::Metrics(labels_t &y_, indices_t &indices_) : y(y_), indices(indices_),
{ numClasses(computeNumClasses(0, indices.size())) {
} }
int Metrics::computeNumClasses(size_t start, size_t end)
{ int Metrics::computeNumClasses(size_t start, size_t end) {
set<int> nClasses; set<int> nClasses;
for (auto i = start; i < end; ++i) { for (auto i = start; i < end; ++i) {
nClasses.insert(y[indices[i]]); nClasses.insert(y[indices[i]]);
} }
return nClasses.size(); return static_cast<int>(nClasses.size());
} }
void Metrics::setData(labels_t& y_, indices_t& indices_)
{ void Metrics::setData(const labels_t &y_, const indices_t &indices_) {
indices = indices_; indices = indices_;
y = y_; y = y_;
numClasses = computeNumClasses(0, indices.size()); numClasses = computeNumClasses(0, indices.size());
entropyCache.clear(); entropyCache.clear();
igCache.clear(); igCache.clear();
} }
precision_t Metrics::entropy(size_t start, size_t end)
{ precision_t Metrics::entropy(size_t start, size_t end) {
precision_t p, ventropy = 0; precision_t p;
precision_t ventropy = 0;
int nElements = 0; int nElements = 0;
labels_t counts(numClasses + 1, 0); labels_t counts(numClasses + 1, 0);
if (end - start < 2) if (end - start < 2)
return 0; return 0;
if (entropyCache.find({ start, end }) != entropyCache.end()) { if (entropyCache.find({start, end}) != entropyCache.end()) {
return entropyCache[{start, end}]; return entropyCache[{start, end}];
} }
for (auto i = &indices[start]; i != &indices[end]; ++i) { for (auto i = &indices[start]; i != &indices[end]; ++i) {
counts[y[*i]]++; counts[y[*i]]++;
nElements++; nElements++;
} }
for (auto count : counts) { for (auto count: counts) {
if (count > 0) { if (count > 0) {
p = (precision_t)count / nElements; p = static_cast<precision_t>(count) / static_cast<precision_t>(nElements);
ventropy -= p * log2(p); ventropy -= p * log2(p);
} }
} }
entropyCache[{start, end}] = ventropy; entropyCache[{start, end}] = ventropy;
return ventropy; return ventropy;
} }
precision_t Metrics::informationGain(size_t start, size_t cut, size_t end)
{ precision_t Metrics::informationGain(size_t start, size_t cut, size_t end) {
precision_t iGain; precision_t iGain;
precision_t entropyInterval, entropyLeft, entropyRight; precision_t entropyInterval;
int nElementsLeft = cut - start, nElementsRight = end - cut; precision_t entropyLeft;
int nElements = end - start; precision_t entropyRight;
size_t nElementsLeft = cut - start;
size_t nElementsRight = end - cut;
size_t nElements = end - start;
if (igCache.find(make_tuple(start, cut, end)) != igCache.end()) { if (igCache.find(make_tuple(start, cut, end)) != igCache.end()) {
return igCache[make_tuple(start, cut, end)]; return igCache[make_tuple(start, cut, end)];
} }
entropyInterval = entropy(start, end); entropyInterval = entropy(start, end);
entropyLeft = entropy(start, cut); entropyLeft = entropy(start, cut);
entropyRight = entropy(cut, end); entropyRight = entropy(cut, end);
iGain = entropyInterval - ((precision_t)nElementsLeft * entropyLeft + (precision_t)nElementsRight * entropyRight) / nElements; iGain = entropyInterval -
(static_cast<precision_t>(nElementsLeft) * entropyLeft +
static_cast<precision_t>(nElementsRight) * entropyRight) /
static_cast<precision_t>(nElements);
igCache[make_tuple(start, cut, end)] = iGain; igCache[make_tuple(start, cut, end)] = iGain;
return iGain; return iGain;
} }

View File

@@ -1,19 +1,25 @@
#ifndef CCMETRICS_H #ifndef CCMETRICS_H
#define CCMETRICS_H #define CCMETRICS_H
#include "typesFImdlp.h" #include "typesFImdlp.h"
namespace mdlp { namespace mdlp {
class Metrics { class Metrics {
protected: protected:
labels_t& y; labels_t &y;
indices_t& indices; indices_t &indices;
int numClasses; int numClasses;
cacheEnt_t entropyCache; cacheEnt_t entropyCache = cacheEnt_t();
cacheIg_t igCache; cacheIg_t igCache = cacheIg_t();
public: public:
Metrics(labels_t&, indices_t&); Metrics(labels_t &, indices_t &);
void setData(labels_t&, indices_t&);
void setData(const labels_t &, const indices_t &);
int computeNumClasses(size_t, size_t); int computeNumClasses(size_t, size_t);
precision_t entropy(size_t, size_t); precision_t entropy(size_t, size_t);
precision_t informationGain(size_t, size_t, size_t); precision_t informationGain(size_t, size_t, size_t);
}; };
} }

View File

@@ -38,17 +38,17 @@ tuple<string, string, int, int, float> parse_arguments(int argc, char **argv) {
int max_depth = numeric_limits<int>::max(); int max_depth = numeric_limits<int>::max();
int min_length = 3; int min_length = 3;
float max_cutpoints = 0; float max_cutpoints = 0;
static struct option long_options[] = { const option long_options[] = {
{"help", no_argument, nullptr, 'h'}, {"help", no_argument, nullptr, 'h'},
{"file", required_argument, nullptr, 'f'}, {"file", required_argument, nullptr, 'f'},
{"path", required_argument, nullptr, 'p'}, {"path", required_argument, nullptr, 'p'},
{"max_depth", required_argument, nullptr, 'm'}, {"max_depth", required_argument, nullptr, 'm'},
{"max_cutpoints", required_argument, nullptr, 'c'}, {"max_cutpoints", required_argument, nullptr, 'c'},
{"min_length", required_argument, nullptr, 'n'}, {"min_length", required_argument, nullptr, 'n'},
{nullptr, 0, nullptr, 0} {nullptr, no_argument, nullptr, 0}
}; };
while (true) { while (true) {
auto c = getopt_long(argc, argv, "hf:p:m:c:n:", long_options, nullptr); const auto c = getopt_long(argc, argv, "hf:p:m:c:n:", long_options, nullptr);
if (c == -1) if (c == -1)
break; break;
switch (c) { switch (c) {
@@ -56,16 +56,16 @@ tuple<string, string, int, int, float> parse_arguments(int argc, char **argv) {
usage(argv[0]); usage(argv[0]);
exit(0); exit(0);
case 'f': case 'f':
file_name = optarg; file_name = string(optarg);
break; break;
case 'm': case 'm':
max_depth = (int) strtol(optarg, nullptr, 10); max_depth = stoi(optarg);
break; break;
case 'n': case 'n':
min_length = (int) strtol(optarg, nullptr, 10); min_length = stoi(optarg);
break; break;
case 'c': case 'c':
max_cutpoints = strtof(optarg, nullptr); max_cutpoints = stof(optarg);
break; break;
case 'p': case 'p':
path = optarg; path = optarg;
@@ -92,7 +92,7 @@ void process_file(const string &path, const string &file_name, bool class_last,
file.load(path + file_name + ".arff", class_last); file.load(path + file_name + ".arff", class_last);
auto attributes = file.getAttributes(); auto attributes = file.getAttributes();
int items = file.getSize(); auto items = file.getSize();
cout << "Number of lines: " << items << endl; cout << "Number of lines: " << items << endl;
cout << "Attributes: " << endl; cout << "Attributes: " << endl;
for (auto attribute: attributes) { for (auto attribute: attributes) {
@@ -109,7 +109,7 @@ void process_file(const string &path, const string &file_name, bool class_last,
} }
cout << y[i] << endl; cout << y[i] << endl;
} }
mdlp::CPPFImdlp test = mdlp::CPPFImdlp(min_length, max_depth, max_cutpoints); auto test = mdlp::CPPFImdlp(min_length, max_depth, max_cutpoints);
auto total = 0; auto total = 0;
for (auto i = 0; i < attributes.size(); i++) { for (auto i = 0; i < attributes.size(); i++) {
auto min_max = minmax_element(X[i].begin(), X[i].end()); auto min_max = minmax_element(X[i].begin(), X[i].end());
@@ -141,7 +141,7 @@ void process_all_files(const map<string, bool> &datasets, const string &path, in
size_t timing = 0; size_t timing = 0;
int cut_points = 0; int cut_points = 0;
for (auto i = 0; i < attributes.size(); i++) { for (auto i = 0; i < attributes.size(); i++) {
mdlp::CPPFImdlp test = mdlp::CPPFImdlp(min_length, max_depth, max_cutpoints); auto test = mdlp::CPPFImdlp(min_length, max_depth, max_cutpoints);
std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
test.fit(X[i], y); test.fit(X[i], y);
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
@@ -163,8 +163,10 @@ int main(int argc, char **argv) {
{"mfeat-factors", true}, {"mfeat-factors", true},
{"test", true} {"test", true}
}; };
string file_name, path; string file_name;
int max_depth, min_length; string path;
int max_depth;
int min_length;
float max_cutpoints; float max_cutpoints;
tie(file_name, path, max_depth, min_length, max_cutpoints) = parse_arguments(argc, argv); tie(file_name, path, max_depth, min_length, max_cutpoints) = parse_arguments(argc, argv);
if (datasets.find(file_name) == datasets.end() && file_name != "all") { if (datasets.find(file_name) == datasets.end() && file_name != "all") {

View File

@@ -6,82 +6,88 @@
using namespace std; using namespace std;
ArffFiles::ArffFiles() = default; ArffFiles::ArffFiles() = default;
vector<string> ArffFiles::getLines()
{ vector<string> ArffFiles::getLines() const {
return lines; return lines;
} }
unsigned long int ArffFiles::getSize()
{ unsigned long int ArffFiles::getSize() const {
return lines.size(); return lines.size();
} }
vector<pair<string, string>> ArffFiles::getAttributes()
{ vector<pair<string, string>> ArffFiles::getAttributes() const {
return attributes; return attributes;
} }
string ArffFiles::getClassName()
{ string ArffFiles::getClassName() const {
return className; return className;
} }
string ArffFiles::getClassType()
{ string ArffFiles::getClassType() const {
return classType; return classType;
} }
vector<vector<float>>& ArffFiles::getX()
{ vector<vector<float>> &ArffFiles::getX() {
return X; return X;
} }
vector<int>& ArffFiles::getY()
{ vector<int> &ArffFiles::getY() {
return y; return y;
} }
void ArffFiles::load(const string fileName, bool classLast)
{ void ArffFiles::load(const string &fileName, bool classLast) {
ifstream file(fileName); ifstream file(fileName);
if (file.is_open()) { if (!file.is_open()) {
string line, keyword, attribute, type, type_w;
while (getline(file, line)) {
if (line.empty() || line[0] == '%' || line == "\r" || line == " ") {
continue;
}
if (line.find("@attribute") != string::npos || line.find("@ATTRIBUTE") != string::npos) {
stringstream ss(line);
ss >> keyword >> attribute;
type = "";
while(ss >> type_w)
type += type_w + " ";
attributes.emplace_back(attribute, type );
continue;
}
if (line[0] == '@') {
continue;
}
lines.push_back(line);
}
file.close();
if (attributes.empty())
throw invalid_argument("No attributes found");
if (classLast) {
className = get<0>(attributes.back());
classType = get<1>(attributes.back());
attributes.pop_back();
} else {
className = get<0>(attributes.front());
classType = get<1>(attributes.front());
attributes.erase(attributes.begin());
}
generateDataset(classLast);
} else
throw invalid_argument("Unable to open file"); throw invalid_argument("Unable to open file");
}
string line;
string keyword;
string attribute;
string type;
string type_w;
while (getline(file, line)) {
if (line.empty() || line[0] == '%' || line == "\r" || line == " ") {
continue;
}
if (line.find("@attribute") != string::npos || line.find("@ATTRIBUTE") != string::npos) {
stringstream ss(line);
ss >> keyword >> attribute;
type = "";
while (ss >> type_w)
type += type_w + " ";
attributes.emplace_back(attribute, type);
continue;
}
if (line[0] == '@') {
continue;
}
lines.push_back(line);
}
file.close();
if (attributes.empty())
throw invalid_argument("No attributes found");
if (classLast) {
className = get<0>(attributes.back());
classType = get<1>(attributes.back());
attributes.pop_back();
} else {
className = get<0>(attributes.front());
classType = get<1>(attributes.front());
attributes.erase(attributes.begin());
}
generateDataset(classLast);
} }
void ArffFiles::generateDataset(bool classLast)
{ void ArffFiles::generateDataset(bool classLast) {
X = vector<vector<float>>(attributes.size(), vector<float>(lines.size())); X = vector<vector<float>>(attributes.size(), vector<float>(lines.size()));
vector<string> yy = vector<string>(lines.size(), ""); auto yy = vector<string>(lines.size(), "");
int labelIndex = classLast ? static_cast<int>(attributes.size()) : 0; int labelIndex = classLast ? static_cast<int>(attributes.size()) : 0;
for (size_t i = 0; i < lines.size(); i++) { for (size_t i = 0; i < lines.size(); i++) {
stringstream ss(lines[i]); stringstream ss(lines[i]);
string value; string value;
int pos = 0, xIndex = 0; int pos = 0;
int xIndex = 0;
while (getline(ss, value, ',')) { while (getline(ss, value, ',')) {
if (pos++ == labelIndex) { if (pos++ == labelIndex) {
yy[i] = value; yy[i] = value;
@@ -92,20 +98,20 @@ void ArffFiles::generateDataset(bool classLast)
} }
y = factorize(yy); y = factorize(yy);
} }
string ArffFiles::trim(const string& source)
{ string ArffFiles::trim(const string &source) {
string s(source); string s(source);
s.erase(0, s.find_first_not_of(" \n\r\t")); s.erase(0, s.find_first_not_of(" \n\r\t"));
s.erase(s.find_last_not_of(" \n\r\t") + 1); s.erase(s.find_last_not_of(" \n\r\t") + 1);
return s; return s;
} }
vector<int> ArffFiles::factorize(const vector<string>& labels_t)
{ vector<int> ArffFiles::factorize(const vector<string> &labels_t) {
vector<int> yy; vector<int> yy;
yy.reserve(labels_t.size()); yy.reserve(labels_t.size());
map<string, int> labelMap; map<string, int> labelMap;
int i = 0; int i = 0;
for (const string &label : labels_t) { for (const string &label: labels_t) {
if (labelMap.find(label) == labelMap.end()) { if (labelMap.find(label) == labelMap.end()) {
labelMap[label] = i++; labelMap[label] = i++;
} }

View File

@@ -1,27 +1,44 @@
#ifndef ARFFFILES_H #ifndef ARFFFILES_H
#define ARFFFILES_H #define ARFFFILES_H
#include <string> #include <string>
#include <vector> #include <vector>
using namespace std; using namespace std;
class ArffFiles { class ArffFiles {
private: private:
vector<string> lines; vector<string> lines;
vector<pair<string, string>> attributes; vector<pair<string, string>> attributes;
string className, classType; string className;
string classType;
vector<vector<float>> X; vector<vector<float>> X;
vector<int> y; vector<int> y;
void generateDataset(bool); void generateDataset(bool);
public: public:
ArffFiles(); ArffFiles();
void load(string, bool = true);
vector<string> getLines(); void load(const string &, bool = true);
unsigned long int getSize();
string getClassName(); vector<string> getLines() const;
string getClassType();
static string trim(const string&); unsigned long int getSize() const;
vector<vector<float>>& getX();
vector<int>& getY(); string getClassName() const;
vector<pair<string, string>> getAttributes();
static vector<int> factorize(const vector<string>& labels_t); string getClassType() const;
static string trim(const string &);
vector<vector<float>> &getX();
vector<int> &getY();
vector<pair<string, string>> getAttributes() const;
static vector<int> factorize(const vector<string> &labels_t);
}; };
#endif #endif

View File

@@ -4,9 +4,10 @@ include(FetchContent)
include_directories(${GTEST_INCLUDE_DIRS}) include_directories(${GTEST_INCLUDE_DIRS})
FetchContent_Declare( FetchContent_Declare(
googletest googletest
URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip
) )
# For Windows: Prevent overriding the parent project's compiler/linker settings # For Windows: Prevent overriding the parent project's compiler/linker settings
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)