Add transform method to discretize values using CutPoints

This commit is contained in:
2023-07-06 16:06:52 +02:00
parent e8559faf1f
commit 5679d607e5
2 changed files with 13 additions and 1 deletions

View File

@@ -7,7 +7,7 @@
namespace mdlp { namespace mdlp {
CPPFImdlp::CPPFImdlp(size_t min_length_, int max_depth_, float proposed): min_length(min_length_), CPPFImdlp::CPPFImdlp(size_t min_length_, int max_depth_, float proposed) : min_length(min_length_),
max_depth(max_depth_), max_depth(max_depth_),
proposed_cuts(proposed) proposed_cuts(proposed)
{ {
@@ -37,6 +37,7 @@ namespace mdlp {
y = y_; y = y_;
num_cut_points = compute_max_num_cut_points(); num_cut_points = compute_max_num_cut_points();
depth = 0; depth = 0;
discretizedData.clear();
cutPoints.clear(); cutPoints.clear();
if (X.size() != y.size()) { if (X.size() != y.size()) {
throw invalid_argument("X and y must have the same size"); throw invalid_argument("X and y must have the same size");
@@ -208,4 +209,13 @@ namespace mdlp {
} }
cutPoints.erase(cutPoints.begin() + static_cast<long>(maxEntropyIdx)); cutPoints.erase(cutPoints.begin() + static_cast<long>(maxEntropyIdx));
} }
/**
 * @brief Discretize values using the cut points currently stored in the object.
 *
 * Each value is mapped to the position returned by upper_bound over
 * cutPoints, i.e. the number of cut points that are less than or equal to
 * the value (this relies on cutPoints being sorted in ascending order).
 * With no cut points every value maps to 0.
 *
 * @param data values to discretize.
 * @return reference to the internal discretizedData buffer, holding exactly
 *         one label per element of @p data. The buffer is reused, so the
 *         contents are replaced by the next call.
 */
labels_t& CPPFImdlp::transform(const samples_t& data)
{
    // Start from an empty buffer: without this, a second call would append
    // to the results of the previous one and return stale labels.
    discretizedData.clear();
    discretizedData.reserve(data.size());
    for (const precision_t& item : data) {
        // First cut point strictly greater than item; its index is the label.
        auto upper = upper_bound(cutPoints.begin(), cutPoints.end(), item);
        discretizedData.push_back(upper - cutPoints.begin());
    }
    return discretizedData;
}
} }

View File

@@ -20,6 +20,7 @@ namespace mdlp {
Metrics metrics = Metrics(y, indices); Metrics metrics = Metrics(y, indices);
cutPoints_t cutPoints; cutPoints_t cutPoints;
size_t num_cut_points = numeric_limits<size_t>::max(); size_t num_cut_points = numeric_limits<size_t>::max();
labels_t discretizedData = labels_t();
static indices_t sortIndices(samples_t&, labels_t&); static indices_t sortIndices(samples_t&, labels_t&);
@@ -36,6 +37,7 @@ namespace mdlp {
~CPPFImdlp(); ~CPPFImdlp();
void fit(samples_t&, labels_t&); void fit(samples_t&, labels_t&);
inline cutPoints_t getCutPoints() const { return cutPoints; }; inline cutPoints_t getCutPoints() const { return cutPoints; };
// Discretizes the given samples against the stored cutPoints: each sample
// is labelled with the index of the first cut point greater than it.
// Returns a reference to the internal discretizedData member, so its
// contents change on later calls that modify that member.
labels_t& transform(const samples_t&);
inline int get_depth() const { return depth; }; inline int get_depth() const { return depth; };
static inline string version() { return "1.1.2"; }; static inline string version() { return "1.1.2"; };
}; };