Reformat code and update version

This commit is contained in:
2023-04-25 10:48:59 +02:00
parent 22997f5d69
commit a1f26a257c
8 changed files with 199 additions and 171 deletions

View File

@@ -9,14 +9,16 @@ namespace mdlp {
CPPFImdlp::CPPFImdlp(size_t min_length_, int max_depth_, float proposed): min_length(min_length_), CPPFImdlp::CPPFImdlp(size_t min_length_, int max_depth_, float proposed): min_length(min_length_),
max_depth(max_depth_), max_depth(max_depth_),
proposed_cuts(proposed) { proposed_cuts(proposed)
{
} }
CPPFImdlp::CPPFImdlp() = default; CPPFImdlp::CPPFImdlp() = default;
CPPFImdlp::~CPPFImdlp() = default; CPPFImdlp::~CPPFImdlp() = default;
size_t CPPFImdlp::compute_max_num_cut_points() const { size_t CPPFImdlp::compute_max_num_cut_points() const
{
// Set the actual maximum number of cut points as a number or as a percentage of the number of samples // Set the actual maximum number of cut points as a number or as a percentage of the number of samples
if (proposed_cuts == 0) { if (proposed_cuts == 0) {
return numeric_limits<size_t>::max(); return numeric_limits<size_t>::max();
@@ -29,7 +31,8 @@ namespace mdlp {
return static_cast<size_t>(proposed_cuts); return static_cast<size_t>(proposed_cuts);
} }
void CPPFImdlp::fit(samples_t &X_, labels_t &y_) { void CPPFImdlp::fit(samples_t& X_, labels_t& y_)
{
X = X_; X = X_;
y = y_; y = y_;
num_cut_points = compute_max_num_cut_points(); num_cut_points = compute_max_num_cut_points();
@@ -59,7 +62,8 @@ namespace mdlp {
} }
} }
pair<precision_t, size_t> CPPFImdlp::valueCutPoint(size_t start, size_t cut, size_t end) { pair<precision_t, size_t> CPPFImdlp::valueCutPoint(size_t start, size_t cut, size_t end)
{
size_t n; size_t n;
size_t m; size_t m;
size_t idxPrev = cut - 1 >= start ? cut - 1 : cut; size_t idxPrev = cut - 1 >= start ? cut - 1 : cut;
@@ -91,7 +95,8 @@ namespace mdlp {
return { (actual + previous) / 2, cut }; return { (actual + previous) / 2, cut };
} }
void CPPFImdlp::computeCutPoints(size_t start, size_t end, int depth_) { void CPPFImdlp::computeCutPoints(size_t start, size_t end, int depth_)
{
size_t cut; size_t cut;
pair<precision_t, size_t> result; pair<precision_t, size_t> result;
// Check if the interval length and the depth are Ok // Check if the interval length and the depth are Ok
@@ -110,7 +115,8 @@ namespace mdlp {
} }
} }
size_t CPPFImdlp::getCandidate(size_t start, size_t end) { size_t CPPFImdlp::getCandidate(size_t start, size_t end)
{
/* Definition 1: A binary discretization for A is determined by selecting the cut point TA for which /* Definition 1: A binary discretization for A is determined by selecting the cut point TA for which
E(A, TA; S) is minimal amongst all the candidate cut points. */ E(A, TA; S) is minimal amongst all the candidate cut points. */
size_t candidate = numeric_limits<size_t>::max(); size_t candidate = numeric_limits<size_t>::max();
@@ -143,7 +149,8 @@ namespace mdlp {
return candidate; return candidate;
} }
bool CPPFImdlp::mdlp(size_t start, size_t cut, size_t end) { bool CPPFImdlp::mdlp(size_t start, size_t cut, size_t end)
{
int k; int k;
int k1; int k1;
int k2; int k2;
@@ -167,7 +174,8 @@ namespace mdlp {
} }
// Argsort from https://stackoverflow.com/questions/1577475/c-sorting-and-keeping-track-of-indexes // Argsort from https://stackoverflow.com/questions/1577475/c-sorting-and-keeping-track-of-indexes
indices_t CPPFImdlp::sortIndices(samples_t &X_, labels_t &y_) { indices_t CPPFImdlp::sortIndices(samples_t& X_, labels_t& y_)
{
indices_t idx(X_.size()); indices_t idx(X_.size());
iota(idx.begin(), idx.end(), 0); iota(idx.begin(), idx.end(), 0);
stable_sort(idx.begin(), idx.end(), [&X_, &y_](size_t i1, size_t i2) { stable_sort(idx.begin(), idx.end(), [&X_, &y_](size_t i1, size_t i2) {
@@ -179,7 +187,8 @@ namespace mdlp {
return idx; return idx;
} }
void CPPFImdlp::resizeCutPoints() { void CPPFImdlp::resizeCutPoints()
{
//Compute entropy of each of the whole cutpoint set and discards the biggest value //Compute entropy of each of the whole cutpoint set and discards the biggest value
precision_t maxEntropy = 0; precision_t maxEntropy = 0;
precision_t entropy; precision_t entropy;

View File

@@ -24,31 +24,20 @@ namespace mdlp {
static indices_t sortIndices(samples_t&, labels_t&); static indices_t sortIndices(samples_t&, labels_t&);
void computeCutPoints(size_t, size_t, int); void computeCutPoints(size_t, size_t, int);
void resizeCutPoints(); void resizeCutPoints();
bool mdlp(size_t, size_t, size_t); bool mdlp(size_t, size_t, size_t);
size_t getCandidate(size_t, size_t); size_t getCandidate(size_t, size_t);
size_t compute_max_num_cut_points() const; size_t compute_max_num_cut_points() const;
pair<precision_t, size_t> valueCutPoint(size_t, size_t, size_t); pair<precision_t, size_t> valueCutPoint(size_t, size_t, size_t);
public: public:
CPPFImdlp(); CPPFImdlp();
CPPFImdlp(size_t, int, float); CPPFImdlp(size_t, int, float);
~CPPFImdlp(); ~CPPFImdlp();
void fit(samples_t&, labels_t&); void fit(samples_t&, labels_t&);
inline cutPoints_t getCutPoints() const { return cutPoints; }; inline cutPoints_t getCutPoints() const { return cutPoints; };
inline int get_depth() const { return depth; }; inline int get_depth() const { return depth; };
static inline string version() { return "1.1.2"; };
static inline string version() { return "1.1.1"; };
}; };
} }
#endif #endif

View File

@@ -5,10 +5,12 @@
using namespace std; using namespace std;
namespace mdlp { namespace mdlp {
Metrics::Metrics(labels_t& y_, indices_t& indices_): y(y_), indices(indices_), Metrics::Metrics(labels_t& y_, indices_t& indices_): y(y_), indices(indices_),
numClasses(computeNumClasses(0, indices.size())) { numClasses(computeNumClasses(0, indices.size()))
{
} }
int Metrics::computeNumClasses(size_t start, size_t end) { int Metrics::computeNumClasses(size_t start, size_t end)
{
set<int> nClasses; set<int> nClasses;
for (auto i = start; i < end; ++i) { for (auto i = start; i < end; ++i) {
nClasses.insert(y[indices[i]]); nClasses.insert(y[indices[i]]);
@@ -16,7 +18,8 @@ namespace mdlp {
return static_cast<int>(nClasses.size()); return static_cast<int>(nClasses.size());
} }
void Metrics::setData(const labels_t &y_, const indices_t &indices_) { void Metrics::setData(const labels_t& y_, const indices_t& indices_)
{
indices = indices_; indices = indices_;
y = y_; y = y_;
numClasses = computeNumClasses(0, indices.size()); numClasses = computeNumClasses(0, indices.size());
@@ -24,7 +27,8 @@ namespace mdlp {
igCache.clear(); igCache.clear();
} }
precision_t Metrics::entropy(size_t start, size_t end) { precision_t Metrics::entropy(size_t start, size_t end)
{
precision_t p; precision_t p;
precision_t ventropy = 0; precision_t ventropy = 0;
int nElements = 0; int nElements = 0;
@@ -48,7 +52,8 @@ namespace mdlp {
return ventropy; return ventropy;
} }
precision_t Metrics::informationGain(size_t start, size_t cut, size_t end) { precision_t Metrics::informationGain(size_t start, size_t cut, size_t end)
{
precision_t iGain; precision_t iGain;
precision_t entropyInterval; precision_t entropyInterval;
precision_t entropyLeft; precision_t entropyLeft;

View File

@@ -13,13 +13,9 @@ namespace mdlp {
cacheIg_t igCache = cacheIg_t(); cacheIg_t igCache = cacheIg_t();
public: public:
Metrics(labels_t&, indices_t&); Metrics(labels_t&, indices_t&);
void setData(const labels_t&, const indices_t&); void setData(const labels_t&, const indices_t&);
int computeNumClasses(size_t, size_t); int computeNumClasses(size_t, size_t);
precision_t entropy(size_t, size_t); precision_t entropy(size_t, size_t);
precision_t informationGain(size_t, size_t, size_t); precision_t informationGain(size_t, size_t, size_t);
}; };
} }

View File

@@ -7,35 +7,43 @@ using namespace std;
ArffFiles::ArffFiles() = default; ArffFiles::ArffFiles() = default;
vector<string> ArffFiles::getLines() const { vector<string> ArffFiles::getLines() const
{
return lines; return lines;
} }
unsigned long int ArffFiles::getSize() const { unsigned long int ArffFiles::getSize() const
{
return lines.size(); return lines.size();
} }
vector<pair<string, string>> ArffFiles::getAttributes() const { vector<pair<string, string>> ArffFiles::getAttributes() const
{
return attributes; return attributes;
} }
string ArffFiles::getClassName() const { string ArffFiles::getClassName() const
{
return className; return className;
} }
string ArffFiles::getClassType() const { string ArffFiles::getClassType() const
{
return classType; return classType;
} }
vector<mdlp::samples_t> &ArffFiles::getX() { vector<mdlp::samples_t>& ArffFiles::getX()
{
return X; return X;
} }
vector<int> &ArffFiles::getY() { vector<int>& ArffFiles::getY()
{
return y; return y;
} }
void ArffFiles::load(const string &fileName, bool classLast) { void ArffFiles::load(const string& fileName, bool classLast)
{
ifstream file(fileName); ifstream file(fileName);
if (!file.is_open()) { if (!file.is_open()) {
throw invalid_argument("Unable to open file"); throw invalid_argument("Unable to open file");
@@ -79,7 +87,8 @@ void ArffFiles::load(const string &fileName, bool classLast) {
} }
void ArffFiles::generateDataset(bool classLast) { void ArffFiles::generateDataset(bool classLast)
{
X = vector<mdlp::samples_t>(attributes.size(), mdlp::samples_t(lines.size())); X = vector<mdlp::samples_t>(attributes.size(), mdlp::samples_t(lines.size()));
auto yy = vector<string>(lines.size(), ""); auto yy = vector<string>(lines.size(), "");
int labelIndex = classLast ? static_cast<int>(attributes.size()) : 0; int labelIndex = classLast ? static_cast<int>(attributes.size()) : 0;
@@ -99,14 +108,16 @@ void ArffFiles::generateDataset(bool classLast) {
y = factorize(yy); y = factorize(yy);
} }
string ArffFiles::trim(const string &source) { string ArffFiles::trim(const string& source)
{
string s(source); string s(source);
s.erase(0, s.find_first_not_of(" \n\r\t")); s.erase(0, s.find_first_not_of(" \n\r\t"));
s.erase(s.find_last_not_of(" \n\r\t") + 1); s.erase(s.find_last_not_of(" \n\r\t") + 1);
return s; return s;
} }
vector<int> ArffFiles::factorize(const vector<string> &labels_t) { vector<int> ArffFiles::factorize(const vector<string>& labels_t)
{
vector<int> yy; vector<int> yy;
yy.reserve(labels_t.size()); yy.reserve(labels_t.size());
map<string, int> labelMap; map<string, int> labelMap;

View File

@@ -20,25 +20,15 @@ private:
public: public:
ArffFiles(); ArffFiles();
void load(const string&, bool = true); void load(const string&, bool = true);
vector<string> getLines() const; vector<string> getLines() const;
unsigned long int getSize() const; unsigned long int getSize() const;
string getClassName() const; string getClassName() const;
string getClassType() const; string getClassType() const;
static string trim(const string&); static string trim(const string&);
vector<mdlp::samples_t>& getX(); vector<mdlp::samples_t>& getX();
vector<int>& getY(); vector<int>& getY();
vector<pair<string, string>> getAttributes() const; vector<pair<string, string>> getAttributes() const;
static vector<int> factorize(const vector<string>& labels_t); static vector<int> factorize(const vector<string>& labels_t);
}; };

View File

@@ -23,7 +23,8 @@ namespace mdlp {
string data_path; string data_path;
void SetUp() override { void SetUp() override
{
X = { 4.7f, 4.7f, 4.7f, 4.7f, 4.8f, 4.8f, 4.8f, 4.8f, 4.9f, 4.95f, 5.7f, 5.3f, 5.2f, 5.1f, 5.0f, 5.6f, 5.1f, X = { 4.7f, 4.7f, 4.7f, 4.7f, 4.8f, 4.8f, 4.8f, 4.8f, 4.9f, 4.95f, 5.7f, 5.3f, 5.2f, 5.1f, 5.0f, 5.6f, 5.1f,
6.0f, 5.1f, 5.9f }; 6.0f, 5.1f, 5.9f };
y = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2 }; y = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2 };
@@ -31,7 +32,8 @@ namespace mdlp {
data_path = set_data_path(); data_path = set_data_path();
} }
static string set_data_path() { static string set_data_path()
{
string path = "../datasets/"; string path = "../datasets/";
ifstream file(path + "iris.arff"); ifstream file(path + "iris.arff");
if (file.is_open()) { if (file.is_open()) {
@@ -41,7 +43,8 @@ namespace mdlp {
return "../../tests/datasets/"; return "../../tests/datasets/";
} }
void checkSortedVector() { void checkSortedVector()
{
indices_t testSortedIndices = sortIndices(X, y); indices_t testSortedIndices = sortIndices(X, y);
precision_t prev = X[testSortedIndices[0]]; precision_t prev = X[testSortedIndices[0]];
for (unsigned long i = 0; i < X.size(); ++i) { for (unsigned long i = 0; i < X.size(); ++i) {
@@ -51,7 +54,8 @@ namespace mdlp {
} }
} }
void checkCutPoints(cutPoints_t &computed, cutPoints_t &expected) const { void checkCutPoints(cutPoints_t& computed, cutPoints_t& expected) const
{
EXPECT_EQ(computed.size(), expected.size()); EXPECT_EQ(computed.size(), expected.size());
for (unsigned long i = 0; i < computed.size(); i++) { for (unsigned long i = 0; i < computed.size(); i++) {
cout << "(" << computed[i] << ", " << expected[i] << ") "; cout << "(" << computed[i] << ", " << expected[i] << ") ";
@@ -59,7 +63,8 @@ namespace mdlp {
} }
} }
bool test_result(const samples_t &X_, size_t cut, float midPoint, size_t limit, const string &title) { bool test_result(const samples_t& X_, size_t cut, float midPoint, size_t limit, const string& title)
{
pair<precision_t, size_t> result; pair<precision_t, size_t> result;
labels_t y_ = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; labels_t y_ = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
X = X_; X = X_;
@@ -73,7 +78,8 @@ namespace mdlp {
} }
void test_dataset(CPPFImdlp& test, const string& filename, vector<cutPoints_t>& expected, void test_dataset(CPPFImdlp& test, const string& filename, vector<cutPoints_t>& expected,
vector<int> &depths) const { vector<int>& depths) const
{
ArffFiles file; ArffFiles file;
file.load(data_path + filename + ".arff", true); file.load(data_path + filename + ".arff", true);
vector<samples_t>& X = file.getX(); vector<samples_t>& X = file.getX();
@@ -90,19 +96,22 @@ namespace mdlp {
} }
}; };
TEST_F(TestFImdlp, FitErrorEmptyDataset) { TEST_F(TestFImdlp, FitErrorEmptyDataset)
{
X = samples_t(); X = samples_t();
y = labels_t(); y = labels_t();
EXPECT_THROW_WITH_MESSAGE(fit(X, y), invalid_argument, "X and y must have at least one element"); EXPECT_THROW_WITH_MESSAGE(fit(X, y), invalid_argument, "X and y must have at least one element");
} }
TEST_F(TestFImdlp, FitErrorDifferentSize) { TEST_F(TestFImdlp, FitErrorDifferentSize)
{
X = { 1, 2, 3 }; X = { 1, 2, 3 };
y = { 1, 2 }; y = { 1, 2 };
EXPECT_THROW_WITH_MESSAGE(fit(X, y), invalid_argument, "X and y must have the same size"); EXPECT_THROW_WITH_MESSAGE(fit(X, y), invalid_argument, "X and y must have the same size");
} }
TEST_F(TestFImdlp, FitErrorMinLengtMaxDepth) { TEST_F(TestFImdlp, FitErrorMinLengtMaxDepth)
{
auto testLength = CPPFImdlp(2, 10, 0); auto testLength = CPPFImdlp(2, 10, 0);
auto testDepth = CPPFImdlp(3, 0, 0); auto testDepth = CPPFImdlp(3, 0, 0);
X = { 1, 2, 3 }; X = { 1, 2, 3 };
@@ -111,7 +120,8 @@ namespace mdlp {
EXPECT_THROW_WITH_MESSAGE(testDepth.fit(X, y), invalid_argument, "max_depth must be greater than 0"); EXPECT_THROW_WITH_MESSAGE(testDepth.fit(X, y), invalid_argument, "max_depth must be greater than 0");
} }
TEST_F(TestFImdlp, JoinFit) { TEST_F(TestFImdlp, JoinFit)
{
samples_t X_ = { 1, 2, 2, 3, 4, 2, 3 }; samples_t X_ = { 1, 2, 2, 3, 4, 2, 3 };
labels_t y_ = { 0, 0, 1, 2, 3, 4, 5 }; labels_t y_ = { 0, 0, 1, 2, 3, 4, 5 };
cutPoints_t expected = { 1.5f, 2.5f }; cutPoints_t expected = { 1.5f, 2.5f };
@@ -121,7 +131,8 @@ namespace mdlp {
checkCutPoints(computed, expected); checkCutPoints(computed, expected);
} }
TEST_F(TestFImdlp, FitErrorMaxCutPoints) { TEST_F(TestFImdlp, FitErrorMaxCutPoints)
{
auto testmin = CPPFImdlp(2, 10, -1); auto testmin = CPPFImdlp(2, 10, -1);
auto testmax = CPPFImdlp(3, 0, 200); auto testmax = CPPFImdlp(3, 0, 200);
X = { 1, 2, 3 }; X = { 1, 2, 3 };
@@ -130,7 +141,8 @@ namespace mdlp {
EXPECT_THROW_WITH_MESSAGE(testmax.fit(X, y), invalid_argument, "wrong proposed num_cuts value"); EXPECT_THROW_WITH_MESSAGE(testmax.fit(X, y), invalid_argument, "wrong proposed num_cuts value");
} }
TEST_F(TestFImdlp, SortIndices) { TEST_F(TestFImdlp, SortIndices)
{
X = { 5.7f, 5.3f, 5.2f, 5.1f, 5.0f, 5.6f, 5.1f, 6.0f, 5.1f, 5.9f }; X = { 5.7f, 5.3f, 5.2f, 5.1f, 5.0f, 5.6f, 5.1f, 6.0f, 5.1f, 5.9f };
y = { 1, 1, 1, 1, 1, 2, 2, 2, 2, 2 }; y = { 1, 1, 1, 1, 1, 2, 2, 2, 2, 2 };
indices = { 4, 3, 6, 8, 2, 1, 5, 0, 9, 7 }; indices = { 4, 3, 6, 8, 2, 1, 5, 0, 9, 7 };
@@ -148,7 +160,8 @@ namespace mdlp {
indices = { 1, 2, 0 }; indices = { 1, 2, 0 };
} }
TEST_F(TestFImdlp, TestShortDatasets) { TEST_F(TestFImdlp, TestShortDatasets)
{
vector<precision_t> computed; vector<precision_t> computed;
X = { 1 }; X = { 1 };
y = { 1 }; y = { 1 };
@@ -173,7 +186,8 @@ namespace mdlp {
EXPECT_NEAR(computed[0], 1.5, precision); EXPECT_NEAR(computed[0], 1.5, precision);
} }
TEST_F(TestFImdlp, TestArtificialDataset) { TEST_F(TestFImdlp, TestArtificialDataset)
{
fit(X, y); fit(X, y);
cutPoints_t expected = { 5.05f }; cutPoints_t expected = { 5.05f };
vector<precision_t> computed = getCutPoints(); vector<precision_t> computed = getCutPoints();
@@ -183,7 +197,8 @@ namespace mdlp {
} }
} }
TEST_F(TestFImdlp, TestIris) { TEST_F(TestFImdlp, TestIris)
{
vector<cutPoints_t> expected = { vector<cutPoints_t> expected = {
{5.45f, 5.75f}, {5.45f, 5.75f},
{2.75f, 2.85f, 2.95f, 3.05f, 3.35f}, {2.75f, 2.85f, 2.95f, 3.05f, 3.35f},
@@ -195,7 +210,8 @@ namespace mdlp {
test_dataset(test, "iris", expected, depths); test_dataset(test, "iris", expected, depths);
} }
TEST_F(TestFImdlp, ComputeCutPointsGCase) { TEST_F(TestFImdlp, ComputeCutPointsGCase)
{
cutPoints_t expected; cutPoints_t expected;
expected = { 1.5 }; expected = { 1.5 };
samples_t X_ = { 0, 1, 2, 2, 2 }; samples_t X_ = { 0, 1, 2, 2, 2 };
@@ -205,7 +221,8 @@ namespace mdlp {
checkCutPoints(computed, expected); checkCutPoints(computed, expected);
} }
TEST_F(TestFImdlp, ValueCutPoint) { TEST_F(TestFImdlp, ValueCutPoint)
{
// Case titles as stated in the doc // Case titles as stated in the doc
samples_t X1a{ 3.1f, 3.2f, 3.3f, 3.4f, 3.5f, 3.6f, 3.7f, 3.8f, 3.9f, 4.0f }; samples_t X1a{ 3.1f, 3.2f, 3.3f, 3.4f, 3.5f, 3.6f, 3.7f, 3.8f, 3.9f, 4.0f };
test_result(X1a, 6, 7.3f / 2, 6, "1a"); test_result(X1a, 6, 7.3f / 2, 6, "1a");
@@ -225,7 +242,8 @@ namespace mdlp {
test_result(X4c, 4, 6.9f / 2, 2, "4c"); test_result(X4c, 4, 6.9f / 2, 2, "4c");
} }
TEST_F(TestFImdlp, MaxDepth) { TEST_F(TestFImdlp, MaxDepth)
{
// Set max_depth to 1 // Set max_depth to 1
auto test = CPPFImdlp(3, 1, 0); auto test = CPPFImdlp(3, 1, 0);
vector<cutPoints_t> expected = { vector<cutPoints_t> expected = {
@@ -238,7 +256,8 @@ namespace mdlp {
test_dataset(test, "iris", expected, depths); test_dataset(test, "iris", expected, depths);
} }
TEST_F(TestFImdlp, MinLength) { TEST_F(TestFImdlp, MinLength)
{
auto test = CPPFImdlp(75, 100, 0); auto test = CPPFImdlp(75, 100, 0);
// Set min_length to 75 // Set min_length to 75
vector<cutPoints_t> expected = { vector<cutPoints_t> expected = {
@@ -251,7 +270,8 @@ namespace mdlp {
test_dataset(test, "iris", expected, depths); test_dataset(test, "iris", expected, depths);
} }
TEST_F(TestFImdlp, MinLengthMaxDepth) { TEST_F(TestFImdlp, MinLengthMaxDepth)
{
// Set min_length to 75 // Set min_length to 75
auto test = CPPFImdlp(75, 2, 0); auto test = CPPFImdlp(75, 2, 0);
vector<cutPoints_t> expected = { vector<cutPoints_t> expected = {
@@ -264,7 +284,8 @@ namespace mdlp {
test_dataset(test, "iris", expected, depths); test_dataset(test, "iris", expected, depths);
} }
TEST_F(TestFImdlp, MaxCutPointsInteger) { TEST_F(TestFImdlp, MaxCutPointsInteger)
{
// Set min_length to 75 // Set min_length to 75
auto test = CPPFImdlp(75, 2, 1); auto test = CPPFImdlp(75, 2, 1);
vector<cutPoints_t> expected = { vector<cutPoints_t> expected = {
@@ -278,7 +299,8 @@ namespace mdlp {
} }
TEST_F(TestFImdlp, MaxCutPointsFloat) { TEST_F(TestFImdlp, MaxCutPointsFloat)
{
// Set min_length to 75 // Set min_length to 75
auto test = CPPFImdlp(75, 2, 0.2f); auto test = CPPFImdlp(75, 2, 0.2f);
vector<cutPoints_t> expected = { vector<cutPoints_t> expected = {
@@ -291,7 +313,8 @@ namespace mdlp {
test_dataset(test, "iris", expected, depths); test_dataset(test, "iris", expected, depths);
} }
TEST_F(TestFImdlp, ProposedCuts) { TEST_F(TestFImdlp, ProposedCuts)
{
vector<pair<float, size_t>> proposed_list = { {0.1f, 2}, vector<pair<float, size_t>> proposed_list = { {0.1f, 2},
{0.5f, 10}, {0.5f, 10},
{0.07f, 1}, {0.07f, 1},

View File

@@ -10,19 +10,22 @@ namespace mdlp {
TestMetrics(): Metrics(y_, indices_) {}; TestMetrics(): Metrics(y_, indices_) {};
void SetUp() override { void SetUp() override
{
setData(y_, indices_); setData(y_, indices_);
} }
}; };
TEST_F(TestMetrics, NumClasses) { TEST_F(TestMetrics, NumClasses)
{
y = { 1, 1, 1, 1, 1, 1, 1, 1, 2, 1 }; y = { 1, 1, 1, 1, 1, 1, 1, 1, 2, 1 };
EXPECT_EQ(1, computeNumClasses(4, 8)); EXPECT_EQ(1, computeNumClasses(4, 8));
EXPECT_EQ(2, computeNumClasses(0, 10)); EXPECT_EQ(2, computeNumClasses(0, 10));
EXPECT_EQ(2, computeNumClasses(8, 10)); EXPECT_EQ(2, computeNumClasses(8, 10));
} }
TEST_F(TestMetrics, Entropy) { TEST_F(TestMetrics, Entropy)
{
EXPECT_EQ(1, entropy(0, 10)); EXPECT_EQ(1, entropy(0, 10));
EXPECT_EQ(0, entropy(0, 5)); EXPECT_EQ(0, entropy(0, 5));
y = { 1, 1, 1, 1, 1, 1, 1, 1, 2, 1 }; y = { 1, 1, 1, 1, 1, 1, 1, 1, 2, 1 };
@@ -30,7 +33,8 @@ namespace mdlp {
ASSERT_NEAR(0.468996f, entropy(0, 10), precision); ASSERT_NEAR(0.468996f, entropy(0, 10), precision);
} }
TEST_F(TestMetrics, EntropyDouble) { TEST_F(TestMetrics, EntropyDouble)
{
y = { 0, 0, 1, 2, 3 }; y = { 0, 0, 1, 2, 3 };
samples_t expected_entropies = { 0.0, 0.0, 0.91829583, 1.5, 1.4575424759098898 }; samples_t expected_entropies = { 0.0, 0.0, 0.91829583, 1.5, 1.4575424759098898 };
for (auto idx = 0; idx < y.size(); ++idx) { for (auto idx = 0; idx < y.size(); ++idx) {
@@ -38,7 +42,8 @@ namespace mdlp {
} }
} }
TEST_F(TestMetrics, InformationGain) { TEST_F(TestMetrics, InformationGain)
{
ASSERT_NEAR(1, informationGain(0, 5, 10), precision); ASSERT_NEAR(1, informationGain(0, 5, 10), precision);
ASSERT_NEAR(1, informationGain(0, 5, 10), precision); // For cache ASSERT_NEAR(1, informationGain(0, 5, 10), precision); // For cache
y = { 1, 1, 1, 1, 1, 1, 1, 1, 2, 1 }; y = { 1, 1, 1, 1, 1, 1, 1, 1, 2, 1 };