set min_length as protected

This commit is contained in:
2023-02-26 12:07:52 +01:00
parent 552b03afc9
commit a7d13f602d
3 changed files with 11 additions and 28 deletions

View File

@@ -12,7 +12,7 @@ namespace mdlp {
metrics(Metrics(y, indices)) metrics(Metrics(y, indices))
{ {
} }
CPPFImdlp::CPPFImdlp(int min_length_, int max_depth_): depth(0), max_depth(max_depth_), min_length(min_length_), indices(indices_t()), X(samples_t()), y(labels_t()), CPPFImdlp::CPPFImdlp(size_t min_length_, int max_depth_): depth(0), max_depth(max_depth_), min_length(min_length_), indices(indices_t()), X(samples_t()), y(labels_t()),
metrics(Metrics(y, indices)) metrics(Metrics(y, indices))
{ {
} }

View File

@@ -7,10 +7,11 @@
namespace mdlp { namespace mdlp {
class CPPFImdlp { class CPPFImdlp {
protected: protected:
indices_t indices; size_t min_length;
int depth, max_depth;
samples_t X; samples_t X;
labels_t y; labels_t y;
int depth, max_depth; indices_t indices;
Metrics metrics; Metrics metrics;
cutPoints_t cutPoints; cutPoints_t cutPoints;
@@ -20,9 +21,8 @@ namespace mdlp {
size_t getCandidate(size_t, size_t); size_t getCandidate(size_t, size_t);
pair<precision_t, size_t> valueCutPoint(size_t, size_t, size_t); pair<precision_t, size_t> valueCutPoint(size_t, size_t, size_t);
public: public:
int min_length;
CPPFImdlp(); CPPFImdlp();
CPPFImdlp(int, int); CPPFImdlp(size_t, int);
~CPPFImdlp(); ~CPPFImdlp();
CPPFImdlp& fit(samples_t&, labels_t&); CPPFImdlp& fit(samples_t&, labels_t&);
cutPoints_t getCutPoints(); cutPoints_t getCutPoints();

View File

@@ -57,7 +57,7 @@ namespace mdlp {
void test_dataset(CPPFImdlp& test, string filename, vector<cutPoints_t>& expected, int depths[]) void test_dataset(CPPFImdlp& test, string filename, vector<cutPoints_t>& expected, int depths[])
{ {
ArffFiles file; ArffFiles file;
file.load("../datasets/" + filename, true); file.load("../datasets/" + filename + ".arff", true);
vector<samples_t>& X = file.getX(); vector<samples_t>& X = file.getX();
labels_t& y = file.getY(); labels_t& y = file.getY();
auto attributes = file.getAttributes(); auto attributes = file.getAttributes();
@@ -203,10 +203,11 @@ namespace mdlp {
{0.8 } {0.8 }
}; };
int depths[] = { 1, 1, 1, 1 }; int depths[] = { 1, 1, 1, 1 };
test_dataset(test, "iris.arff", expected, depths); test_dataset(test, "iris", expected, depths);
} }
TEST_F(TestFImdlp, MinLength) TEST_F(TestFImdlp, MinLength)
{ {
auto test = CPPFImdlp(75, 100);
// Set min_length to 75 // Set min_length to 75
vector<cutPoints_t> expected = { vector<cutPoints_t> expected = {
{ 5.45, 5.75 }, { 5.45, 5.75 },
@@ -214,26 +215,8 @@ namespace mdlp {
{ 2.45, 4.75 }, { 2.45, 4.75 },
{ 0.8, 1.75 } { 0.8, 1.75 }
}; };
int depths[] = { 2, 2, 2, 2 }; int depths[] = { 3, 2, 2, 2 };
//test_dataset(test, "iris", expected, depths); test_dataset(test, "iris", expected, depths);
ArffFiles file;
file.load("../datasets/iris.arff", true);
vector<samples_t>& X = file.getX();
labels_t& y = file.getY();
auto attributes = file.getAttributes();
for (auto feature = 0; feature < attributes.size(); feature++) {
auto test = CPPFImdlp(75, 100);
test.fit(X[feature], y);
cout << "Feature: " << feature << " Depth: " << test.get_depth() << endl;
//EXPECT_EQ(test.get_depth(), depths[feature]);
auto computed = test.getCutPoints();
for (auto item : test.getCutPoints()) {
cout << item << " ";
}
cout << endl;
//checkCutPoints(computed, expected[feature]);
}
FAIL();
} }
TEST_F(TestFImdlp, MinLengthMaxDepth) TEST_F(TestFImdlp, MinLengthMaxDepth)
{ {
@@ -246,6 +229,6 @@ namespace mdlp {
{ 0.8, 1.75 } { 0.8, 1.75 }
}; };
int depths[] = { 2, 2, 2, 2 }; int depths[] = { 2, 2, 2, 2 };
test_dataset(test, "iris.arff", expected, depths); test_dataset(test, "iris", expected, depths);
} }
} }