3 Commits

4 changed files with 40 additions and 6 deletions

View File

@@ -7,7 +7,7 @@
namespace mdlp { namespace mdlp {
CPPFImdlp::CPPFImdlp(size_t min_length_, int max_depth_, float proposed): min_length(min_length_), CPPFImdlp::CPPFImdlp(size_t min_length_, int max_depth_, float proposed) : min_length(min_length_),
max_depth(max_depth_), max_depth(max_depth_),
proposed_cuts(proposed) proposed_cuts(proposed)
{ {
@@ -37,6 +37,7 @@ namespace mdlp {
y = y_; y = y_;
num_cut_points = compute_max_num_cut_points(); num_cut_points = compute_max_num_cut_points();
depth = 0; depth = 0;
discretizedData.clear();
cutPoints.clear(); cutPoints.clear();
if (X.size() != y.size()) { if (X.size() != y.size()) {
throw invalid_argument("X and y must have the same size"); throw invalid_argument("X and y must have the same size");
@@ -208,4 +209,13 @@ namespace mdlp {
} }
cutPoints.erase(cutPoints.begin() + static_cast<long>(maxEntropyIdx)); cutPoints.erase(cutPoints.begin() + static_cast<long>(maxEntropyIdx));
} }
labels_t& CPPFImdlp::transform(const samples_t& data)
{
discretizedData.reserve(data.size());
for (const precision_t& item : data) {
auto upper = upper_bound(cutPoints.begin(), cutPoints.end(), item);
discretizedData.push_back(upper - cutPoints.begin());
}
return discretizedData;
}
} }

View File

@@ -20,6 +20,7 @@ namespace mdlp {
Metrics metrics = Metrics(y, indices); Metrics metrics = Metrics(y, indices);
cutPoints_t cutPoints; cutPoints_t cutPoints;
size_t num_cut_points = numeric_limits<size_t>::max(); size_t num_cut_points = numeric_limits<size_t>::max();
labels_t discretizedData = labels_t();
static indices_t sortIndices(samples_t&, labels_t&); static indices_t sortIndices(samples_t&, labels_t&);
@@ -36,6 +37,7 @@ namespace mdlp {
~CPPFImdlp(); ~CPPFImdlp();
void fit(samples_t&, labels_t&); void fit(samples_t&, labels_t&);
inline cutPoints_t getCutPoints() const { return cutPoints; }; inline cutPoints_t getCutPoints() const { return cutPoints; };
labels_t& transform(const samples_t&);
inline int get_depth() const { return depth; }; inline int get_depth() const { return depth; };
static inline string version() { return "1.1.2"; }; static inline string version() { return "1.1.2"; };
}; };

View File

@@ -63,7 +63,7 @@ void ArffFiles::load(const string& fileName, bool classLast)
type = ""; type = "";
while (ss >> type_w) while (ss >> type_w)
type += type_w + " "; type += type_w + " ";
attributes.emplace_back(attribute, trim(type)); attributes.emplace_back(trim(attribute), trim(type));
continue; continue;
} }
if (line[0] == '@') { if (line[0] == '@') {
@@ -111,8 +111,8 @@ void ArffFiles::generateDataset(bool classLast)
string ArffFiles::trim(const string& source) string ArffFiles::trim(const string& source)
{ {
string s(source); string s(source);
s.erase(0, s.find_first_not_of(" \n\r\t")); s.erase(0, s.find_first_not_of(" '\n\r\t"));
s.erase(s.find_last_not_of(" \n\r\t") + 1); s.erase(s.find_last_not_of(" '\n\r\t") + 1);
return s; return s;
} }

View File

@@ -15,11 +15,11 @@ throw; \
, etype) , etype)
namespace mdlp { namespace mdlp {
class TestFImdlp: public CPPFImdlp, public testing::Test { class TestFImdlp : public CPPFImdlp, public testing::Test {
public: public:
precision_t precision = 0.000001f; precision_t precision = 0.000001f;
TestFImdlp(): CPPFImdlp() {} TestFImdlp() : CPPFImdlp() {}
string data_path; string data_path;
@@ -329,4 +329,26 @@ namespace mdlp {
} }
} }
TEST_F(TestFImdlp, TransformTest)
{
labels_t expected = {
5, 3, 4, 4, 5, 5, 5, 5, 2, 4, 5, 5, 3, 3, 5, 5, 5, 5, 5, 5, 5, 5,
5, 4, 5, 3, 5, 5, 5, 4, 4, 5, 5, 5, 4, 4, 5, 4, 3, 5, 5, 0, 4, 5,
5, 3, 5, 4, 5, 4, 4, 4, 4, 0, 1, 1, 4, 0, 2, 0, 0, 3, 0, 2, 2, 4,
3, 0, 0, 0, 4, 1, 0, 1, 2, 3, 1, 3, 2, 0, 0, 0, 0, 0, 3, 5, 4, 0,
3, 0, 0, 3, 0, 0, 0, 3, 2, 2, 0, 1, 4, 0, 3, 2, 3, 3, 0, 2, 0, 5,
4, 0, 3, 0, 1, 4, 3, 5, 0, 0, 4, 1, 1, 0, 4, 4, 1, 3, 1, 3, 1, 5,
1, 1, 0, 3, 5, 4, 3, 4, 4, 4, 0, 4, 4, 3, 0, 3, 5, 3
};
ArffFiles file;
file.load(data_path + "iris.arff", true);
vector<samples_t>& X = file.getX();
labels_t& y = file.getY();
fit(X[1], y);
auto computed = transform(X[1]);
EXPECT_EQ(computed.size(), expected.size());
for (unsigned long i = 0; i < computed.size(); i++) {
EXPECT_EQ(computed[i], expected[i]);
}
}
} }