Claude enhancement proposal

This commit is contained in:
2025-06-28 13:17:31 +02:00
parent 059fd33b4e
commit 99b751a4d4
8 changed files with 384 additions and 37 deletions

View File

@@ -22,13 +22,15 @@ namespace mdlp {
BinDisc::~BinDisc() = default;
void BinDisc::fit(samples_t& X)
{
// y is included for compatibility with the Discretizer interface
cutPoints.clear();
// Input validation
if (X.empty()) {
cutPoints.push_back(0.0);
cutPoints.push_back(0.0);
return;
throw std::invalid_argument("Input data X cannot be empty");
}
if (X.size() < static_cast<size_t>(n_bins)) {
throw std::invalid_argument("Input data size must be at least equal to n_bins");
}
cutPoints.clear();
if (strategy == strategy_t::QUANTILE) {
direction = bound_dir_t::RIGHT;
fit_quantile(X);
@@ -39,10 +41,31 @@ namespace mdlp {
}
void BinDisc::fit(samples_t& X, labels_t& y)
{
// Input validation for supervised interface
if (X.size() != y.size()) {
throw std::invalid_argument("X and y must have the same size");
}
if (X.empty() || y.empty()) {
throw std::invalid_argument("X and y cannot be empty");
}
// BinDisc is inherently unsupervised, but we validate inputs for consistency
// Note: y parameter is validated but not used in binning strategy
fit(X);
}
std::vector<precision_t> linspace(precision_t start, precision_t end, int num)
{
// Input validation
if (num < 2) {
throw std::invalid_argument("Number of points must be at least 2 for linspace");
}
if (std::isnan(start) || std::isnan(end)) {
throw std::invalid_argument("Start and end values cannot be NaN");
}
if (std::isinf(start) || std::isinf(end)) {
throw std::invalid_argument("Start and end values cannot be infinite");
}
if (start == end) {
return { start, end };
}
@@ -60,6 +83,14 @@ namespace mdlp {
}
std::vector<precision_t> percentile(samples_t& data, const std::vector<precision_t>& percentiles)
{
// Input validation
if (data.empty()) {
throw std::invalid_argument("Data cannot be empty for percentile calculation");
}
if (percentiles.empty()) {
throw std::invalid_argument("Percentiles cannot be empty");
}
// Implementation taken from https://dpilger26.github.io/NumCpp/doxygen/html/percentile_8hpp_source.html
std::vector<precision_t> results;
bool first = true;

View File

@@ -8,6 +8,7 @@
#include <algorithm>
#include <set>
#include <cmath>
#include <stdexcept>
#include "CPPFImdlp.h"
namespace mdlp {
@@ -18,6 +19,17 @@ namespace mdlp {
max_depth(max_depth_),
proposed_cuts(proposed)
{
// Input validation for constructor parameters
if (min_length_ < 3) {
throw std::invalid_argument("min_length must be greater than 2");
}
if (max_depth_ < 1) {
throw std::invalid_argument("max_depth must be greater than 0");
}
if (proposed < 0.0f) {
throw std::invalid_argument("proposed_cuts must be non-negative");
}
direction = bound_dir_t::RIGHT;
}
@@ -49,12 +61,6 @@ namespace mdlp {
if (X.empty() || y.empty()) {
throw invalid_argument("X and y must have at least one element");
}
if (min_length < 3) {
throw invalid_argument("min_length must be greater than 2");
}
if (max_depth < 1) {
throw invalid_argument("max_depth must be greater than 0");
}
indices = sortIndices(X_, y_);
metrics.setData(y, indices);
computeCutPoints(0, X.size(), 1);
@@ -81,26 +87,32 @@ namespace mdlp {
precision_t previous;
precision_t actual;
precision_t next;
previous = X[indices[idxPrev]];
actual = X[indices[cut]];
next = X[indices[idxNext]];
previous = safe_X_access(idxPrev);
actual = safe_X_access(cut);
next = safe_X_access(idxNext);
// definition 2 of the paper => X[t-1] < X[t]
// get the first equal value of X in the interval
while (idxPrev > start && actual == previous) {
previous = X[indices[--idxPrev]];
--idxPrev;
previous = safe_X_access(idxPrev);
}
backWall = idxPrev == start && actual == previous;
// get the last equal value of X in the interval
while (idxNext < end - 1 && actual == next) {
next = X[indices[++idxNext]];
++idxNext;
next = safe_X_access(idxNext);
}
// # of duplicates before cutpoint
n = cut - 1 - idxPrev;
n = safe_subtract(safe_subtract(cut, 1), idxPrev);
// # of duplicates after cutpoint
m = idxNext - cut - 1;
m = safe_subtract(safe_subtract(idxNext, cut), 1);
// Decide which values to use
cut = cut + (backWall ? m + 1 : -n);
actual = X[indices[cut]];
if (backWall) {
cut = cut + m + 1;
} else {
cut = safe_subtract(cut, n);
}
actual = safe_X_access(cut);
return { (actual + previous) / 2, cut };
}
@@ -109,7 +121,7 @@ namespace mdlp {
size_t cut;
pair<precision_t, size_t> result;
// Check if the interval length and the depth are Ok
if (end - start < min_length || depth_ > max_depth)
if (end < start || safe_subtract(end, start) < min_length || depth_ > max_depth)
return;
depth = depth_ > depth ? depth_ : depth;
cut = getCandidate(start, end);
@@ -129,14 +141,14 @@ namespace mdlp {
/* Definition 1: A binary discretization for A is determined by selecting the cut point TA for which
E(A, TA; S) is minimal amongst all the candidate cut points. */
size_t candidate = numeric_limits<size_t>::max();
size_t elements = end - start;
size_t elements = safe_subtract(end, start);
bool sameValues = true;
precision_t entropy_left;
precision_t entropy_right;
precision_t minEntropy;
// Check if all the values of the variable in the interval are the same
for (size_t idx = start + 1; idx < end; idx++) {
if (X[indices[idx]] != X[indices[start]]) {
if (safe_X_access(idx) != safe_X_access(start)) {
sameValues = false;
break;
}
@@ -146,7 +158,7 @@ namespace mdlp {
minEntropy = metrics.entropy(start, end);
for (size_t idx = start + 1; idx < end; idx++) {
// Cutpoints are always on boundaries (definition 2)
if (y[indices[idx]] == y[indices[idx - 1]])
if (safe_y_access(idx) == safe_y_access(idx - 1))
continue;
entropy_left = precision_t(idx - start) / static_cast<precision_t>(elements) * metrics.entropy(start, idx);
entropy_right = precision_t(end - idx) / static_cast<precision_t>(elements) * metrics.entropy(idx, end);
@@ -168,7 +180,7 @@ namespace mdlp {
precision_t ent;
precision_t ent1;
precision_t ent2;
auto N = precision_t(end - start);
auto N = precision_t(safe_subtract(end, start));
k = metrics.computeNumClasses(start, end);
k1 = metrics.computeNumClasses(start, cut);
k2 = metrics.computeNumClasses(cut, end);
@@ -188,6 +200,9 @@ namespace mdlp {
indices_t idx(X_.size());
std::iota(idx.begin(), idx.end(), 0);
stable_sort(idx.begin(), idx.end(), [&X_, &y_](size_t i1, size_t i2) {
if (i1 >= X_.size() || i2 >= X_.size() || i1 >= y_.size() || i2 >= y_.size()) {
throw std::out_of_range("Index out of bounds in sort comparison");
}
if (X_[i1] == X_[i2])
return y_[i1] < y_[i2];
else
@@ -206,7 +221,7 @@ namespace mdlp {
size_t end;
for (size_t idx = 0; idx < cutPoints.size(); idx++) {
end = begin;
while (X[indices[end]] < cutPoints[idx] && end < X.size())
while (end < indices.size() && safe_X_access(end) < cutPoints[idx] && end < X.size())
end++;
entropy = metrics.entropy(begin, end);
if (entropy > maxEntropy) {

View File

@@ -39,6 +39,33 @@ namespace mdlp {
size_t getCandidate(size_t, size_t);
size_t compute_max_num_cut_points() const;
pair<precision_t, size_t> valueCutPoint(size_t, size_t, size_t);
private:
inline precision_t safe_X_access(size_t idx) const {
if (idx >= indices.size()) {
throw std::out_of_range("Index out of bounds for indices array");
}
size_t real_idx = indices[idx];
if (real_idx >= X.size()) {
throw std::out_of_range("Index out of bounds for X array");
}
return X[real_idx];
}
inline label_t safe_y_access(size_t idx) const {
if (idx >= indices.size()) {
throw std::out_of_range("Index out of bounds for indices array");
}
size_t real_idx = indices[idx];
if (real_idx >= y.size()) {
throw std::out_of_range("Index out of bounds for y array");
}
return y[real_idx];
}
inline size_t safe_subtract(size_t a, size_t b) const {
if (b > a) {
throw std::underflow_error("Subtraction would cause underflow");
}
return a - b;
}
};
}
#endif

View File

@@ -10,6 +10,14 @@ namespace mdlp {
labels_t& Discretizer::transform(const samples_t& data)
{
// Input validation
if (data.empty()) {
throw std::invalid_argument("Data for transformation cannot be empty");
}
if (cutPoints.size() < 2) {
throw std::runtime_error("Discretizer not fitted yet or no valid cut points found");
}
discretizedData.clear();
discretizedData.reserve(data.size());
// CutPoints always have at least two items
@@ -31,6 +39,26 @@ namespace mdlp {
}
void Discretizer::fit_t(const torch::Tensor& X_, const torch::Tensor& y_)
{
// Validate tensor properties for security
if (!X_.is_contiguous() || !y_.is_contiguous()) {
throw std::invalid_argument("Tensors must be contiguous");
}
if (X_.sizes().size() != 1 || y_.sizes().size() != 1) {
throw std::invalid_argument("Only 1D tensors supported");
}
if (X_.dtype() != torch::kFloat32) {
throw std::invalid_argument("X tensor must be Float32 type");
}
if (y_.dtype() != torch::kInt32) {
throw std::invalid_argument("y tensor must be Int32 type");
}
if (X_.numel() != y_.numel()) {
throw std::invalid_argument("X and y tensors must have same number of elements");
}
if (X_.numel() == 0) {
throw std::invalid_argument("Tensors cannot be empty");
}
auto num_elements = X_.numel();
samples_t X(X_.data_ptr<precision_t>(), X_.data_ptr<precision_t>() + num_elements);
labels_t y(y_.data_ptr<int>(), y_.data_ptr<int>() + num_elements);
@@ -38,6 +66,20 @@ namespace mdlp {
}
torch::Tensor Discretizer::transform_t(const torch::Tensor& X_)
{
// Validate tensor properties for security
if (!X_.is_contiguous()) {
throw std::invalid_argument("Tensor must be contiguous");
}
if (X_.sizes().size() != 1) {
throw std::invalid_argument("Only 1D tensors supported");
}
if (X_.dtype() != torch::kFloat32) {
throw std::invalid_argument("X tensor must be Float32 type");
}
if (X_.numel() == 0) {
throw std::invalid_argument("Tensor cannot be empty");
}
auto num_elements = X_.numel();
samples_t X(X_.data_ptr<precision_t>(), X_.data_ptr<precision_t>() + num_elements);
auto result = transform(X);
@@ -45,6 +87,26 @@ namespace mdlp {
}
torch::Tensor Discretizer::fit_transform_t(const torch::Tensor& X_, const torch::Tensor& y_)
{
// Validate tensor properties for security
if (!X_.is_contiguous() || !y_.is_contiguous()) {
throw std::invalid_argument("Tensors must be contiguous");
}
if (X_.sizes().size() != 1 || y_.sizes().size() != 1) {
throw std::invalid_argument("Only 1D tensors supported");
}
if (X_.dtype() != torch::kFloat32) {
throw std::invalid_argument("X tensor must be Float32 type");
}
if (y_.dtype() != torch::kInt32) {
throw std::invalid_argument("y tensor must be Int32 type");
}
if (X_.numel() != y_.numel()) {
throw std::invalid_argument("X and y tensors must have same number of elements");
}
if (X_.numel() == 0) {
throw std::invalid_argument("Tensors cannot be empty");
}
auto num_elements = X_.numel();
samples_t X(X_.data_ptr<precision_t>(), X_.data_ptr<precision_t>() + num_elements);
labels_t y(y_.data_ptr<int>(), y_.data_ptr<int>() + num_elements);

View File

@@ -26,6 +26,7 @@ namespace mdlp {
void Metrics::setData(const labels_t& y_, const indices_t& indices_)
{
std::lock_guard<std::mutex> lock(cache_mutex);
indices = indices_;
y = y_;
numClasses = computeNumClasses(0, indices.size());
@@ -35,15 +36,23 @@ namespace mdlp {
precision_t Metrics::entropy(size_t start, size_t end)
{
if (end - start < 2)
return 0;
// Check cache first with read lock
{
std::lock_guard<std::mutex> lock(cache_mutex);
if (entropyCache.find({ start, end }) != entropyCache.end()) {
return entropyCache[{start, end}];
}
}
// Compute entropy outside of lock
precision_t p;
precision_t ventropy = 0;
int nElements = 0;
labels_t counts(numClasses + 1, 0);
if (end - start < 2)
return 0;
if (entropyCache.find({ start, end }) != entropyCache.end()) {
return entropyCache[{start, end}];
}
for (auto i = &indices[start]; i != &indices[end]; ++i) {
counts[y[*i]]++;
nElements++;
@@ -54,12 +63,27 @@ namespace mdlp {
ventropy -= p * log2(p);
}
}
entropyCache[{start, end}] = ventropy;
// Update cache with write lock
{
std::lock_guard<std::mutex> lock(cache_mutex);
entropyCache[{start, end}] = ventropy;
}
return ventropy;
}
precision_t Metrics::informationGain(size_t start, size_t cut, size_t end)
{
// Check cache first with read lock
{
std::lock_guard<std::mutex> lock(cache_mutex);
if (igCache.find(make_tuple(start, cut, end)) != igCache.end()) {
return igCache[make_tuple(start, cut, end)];
}
}
// Compute information gain outside of lock
precision_t iGain;
precision_t entropyInterval;
precision_t entropyLeft;
@@ -67,9 +91,7 @@ namespace mdlp {
size_t nElementsLeft = cut - start;
size_t nElementsRight = end - cut;
size_t nElements = end - start;
if (igCache.find(make_tuple(start, cut, end)) != igCache.end()) {
return igCache[make_tuple(start, cut, end)];
}
entropyInterval = entropy(start, end);
entropyLeft = entropy(start, cut);
entropyRight = entropy(cut, end);
@@ -77,7 +99,13 @@ namespace mdlp {
(static_cast<precision_t>(nElementsLeft) * entropyLeft +
static_cast<precision_t>(nElementsRight) * entropyRight) /
static_cast<precision_t>(nElements);
igCache[make_tuple(start, cut, end)] = iGain;
// Update cache with write lock
{
std::lock_guard<std::mutex> lock(cache_mutex);
igCache[make_tuple(start, cut, end)] = iGain;
}
return iGain;
}

View File

@@ -8,6 +8,7 @@
#define CCMETRICS_H
#include "typesFImdlp.h"
#include <mutex>
namespace mdlp {
class Metrics {
@@ -15,6 +16,7 @@ namespace mdlp {
labels_t& y;
indices_t& indices;
int numClasses;
mutable std::mutex cache_mutex;
cacheEnt_t entropyCache = cacheEnt_t();
cacheIg_t igCache = cacheIg_t();
public: