mirror of
https://github.com/rmontanana/mdlp.git
synced 2025-08-16 07:55:58 +00:00
Add max_depth and min_length as hyperparams
This commit is contained in:
@@ -8,7 +8,11 @@
|
||||
|
||||
namespace mdlp {
|
||||
|
||||
CPPFImdlp::CPPFImdlp(): indices(indices_t()), X(samples_t()), y(labels_t()),
|
||||
CPPFImdlp::CPPFImdlp():depth(0), max_depth(numeric_limits<int>::max()), min_length(3), indices(indices_t()), X(samples_t()), y(labels_t()),
|
||||
metrics(Metrics(y, indices))
|
||||
{
|
||||
}
|
||||
CPPFImdlp::CPPFImdlp(int min_length_, int max_depth_): depth(0), max_depth(max_depth_), min_length(min_length_), indices(indices_t()), X(samples_t()), y(labels_t()),
|
||||
metrics(Metrics(y, indices))
|
||||
{
|
||||
}
|
||||
@@ -25,9 +29,15 @@ namespace mdlp {
|
||||
if (X.empty() || y.empty()) {
|
||||
throw invalid_argument("X and y must have at least one element");
|
||||
}
|
||||
if (min_length < 3) {
|
||||
throw invalid_argument("min_length must be greater than 2");
|
||||
}
|
||||
if (max_depth < 1) {
|
||||
throw invalid_argument("max_depth must be greater than 0");
|
||||
}
|
||||
indices = sortIndices(X_, y_);
|
||||
metrics.setData(y, indices);
|
||||
computeCutPoints(0, X.size());
|
||||
computeCutPoints(0, X.size(), 1);
|
||||
return *this;
|
||||
}
|
||||
|
||||
@@ -60,12 +70,14 @@ namespace mdlp {
|
||||
return { (actual + previous) / 2, cut };
|
||||
}
|
||||
|
||||
void CPPFImdlp::computeCutPoints(size_t start, size_t end)
|
||||
void CPPFImdlp::computeCutPoints(size_t start, size_t end, int depth_)
|
||||
{
|
||||
size_t cut;
|
||||
pair<precision_t, size_t> result;
|
||||
if (end - start < 3)
|
||||
// Check if the interval length and the depth are Ok
|
||||
if (end - start < min_length || depth_ > max_depth)
|
||||
return;
|
||||
depth = depth_ > depth ? depth_ : depth;
|
||||
cut = getCandidate(start, end);
|
||||
if (cut == numeric_limits<size_t>::max())
|
||||
return;
|
||||
@@ -73,8 +85,8 @@ namespace mdlp {
|
||||
result = valueCutPoint(start, cut, end);
|
||||
cut = result.second;
|
||||
cutPoints.push_back(result.first);
|
||||
computeCutPoints(start, cut);
|
||||
computeCutPoints(cut, end);
|
||||
computeCutPoints(start, cut, depth_ + 1);
|
||||
computeCutPoints(cut, end, depth_ + 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,4 +170,8 @@ namespace mdlp {
|
||||
sort(output.begin(), output.end());
|
||||
return output;
|
||||
}
|
||||
int CPPFImdlp::get_depth()
|
||||
{
|
||||
return depth;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user