Add file name validation and other optimizations

This commit is contained in:
2025-06-27 22:40:32 +02:00
parent 007286983f
commit 9338c818fd
4 changed files with 428 additions and 254 deletions

View File

@@ -255,6 +255,55 @@ public:
std::string version() const { return VERSION; }
private:
// Helper function to validate file path for security
static void validateFilePath(const std::string& fileName) {
if (fileName.empty()) {
throw std::invalid_argument("File path cannot be empty");
}
// Check for path traversal attempts
if (fileName.find("..") != std::string::npos) {
throw std::invalid_argument("Path traversal detected in file path: " + fileName);
}
// Check for absolute paths starting with / (Unix) or drive letters (Windows)
if (fileName[0] == '/' || (fileName.length() >= 3 && fileName[1] == ':')) {
// Allow absolute paths but log a warning - this is for user awareness
// In production, you might want to restrict this based on your security requirements
}
// Check for suspicious characters that could be used in path manipulation
const std::string suspiciousChars = "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0b\x0c\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f";
for (char c : suspiciousChars) {
if (fileName.find(c) != std::string::npos) {
throw std::invalid_argument("Invalid character detected in file path");
}
}
// Check for excessively long paths (potential buffer overflow attempts)
constexpr size_t MAX_PATH_LENGTH = 4096; // Common filesystem limit
if (fileName.length() > MAX_PATH_LENGTH) {
throw std::invalid_argument("File path too long (exceeds " + std::to_string(MAX_PATH_LENGTH) + " characters)");
}
// Additional validation using filesystem operations when available
try {
// Check if the file exists and validate its canonical path
if (std::filesystem::exists(fileName)) {
std::filesystem::path normalizedPath = std::filesystem::canonical(fileName);
std::string normalizedStr = normalizedPath.string();
// Check if normalized path still contains traversal attempts
if (normalizedStr.find("..") != std::string::npos) {
throw std::invalid_argument("Path traversal detected after normalization: " + normalizedStr);
}
}
} catch (const std::filesystem::filesystem_error& e) {
// If filesystem operations fail, we can still proceed with basic validation
// This ensures compatibility with systems where filesystem might not be fully available
}
}
// Helper function to validate resource usage limits
static void validateResourceLimits(const std::string& fileName, size_t sampleCount = 0, size_t featureCount = 0) {
// Check file size limit
@@ -303,7 +352,15 @@ private:
continue;
auto values = attribute.second;
std::transform(values.begin(), values.end(), values.begin(), ::toupper);
numeric_features[feature] = values == "REAL" || values == "INTEGER" || values == "NUMERIC";
// Enhanced attribute type detection
bool isNumeric = values == "REAL" || values == "INTEGER" || values == "NUMERIC";
bool isDate = values.find("DATE") != std::string::npos;
bool isString = values == "STRING";
// For now, treat DATE and STRING as categorical (non-numeric)
// This provides basic compatibility while maintaining existing functionality
numeric_features[feature] = isNumeric;
}
}
std::vector<int> factorize(const std::string feature, const std::vector<std::string>& labels_t)
@@ -345,10 +402,16 @@ private:
// Pre-allocate with feature-major layout: X[feature][sample]
X.assign(numFeatures, std::vector<float>(numSamples));
// Cache feature types for fast lookup during data processing
std::vector<bool> isNumericFeature(numFeatures);
for (size_t i = 0; i < numFeatures; ++i) {
isNumericFeature[i] = numeric_features.at(attributes[i].first);
}
// Temporary storage for categorical data per feature (only for non-numeric features)
std::vector<std::vector<std::string>> categoricalData(numFeatures);
for (size_t i = 0; i < numFeatures; ++i) {
if (!numeric_features[attributes[i].first]) {
if (!isNumericFeature[i]) {
categoricalData[i].reserve(numSamples);
}
}
@@ -380,18 +443,19 @@ private:
throw std::invalid_argument("Too many feature values at sample " + std::to_string(sampleIdx));
}
const auto& featureName = attributes[featureIdx].first;
if (numeric_features.at(featureName)) {
if (isNumericFeature[featureIdx]) {
// Parse numeric value with exception handling
try {
X[featureIdx][sampleIdx] = std::stof(token);
}
catch (const std::exception& e) {
const auto& featureName = attributes[featureIdx].first;
throw std::invalid_argument("Invalid numeric value '" + token + "' at sample " + std::to_string(sampleIdx) + ", feature " + featureName);
}
} else {
// Store categorical value temporarily
if (token.empty()) {
const auto& featureName = attributes[featureIdx].first;
throw std::invalid_argument("Empty categorical value at sample " + std::to_string(sampleIdx) + ", feature " + featureName);
}
categoricalData[featureIdx].push_back(token);
@@ -403,7 +467,7 @@ private:
// Convert categorical features to numeric
for (size_t featureIdx = 0; featureIdx < numFeatures; ++featureIdx) {
if (!numeric_features[attributes[featureIdx].first]) {
if (!isNumericFeature[featureIdx]) {
const auto& featureName = attributes[featureIdx].first;
auto encodedValues = factorize(featureName, categoricalData[featureIdx]);
@@ -424,6 +488,9 @@ private:
states.clear();
numeric_features.clear();
// Validate file path for security
validateFilePath(fileName);
// Validate file size before processing
validateResourceLimits(fileName);
@@ -440,6 +507,13 @@ private:
if (line.empty() || line[0] == '%' || line == "\r" || line == " ") {
continue;
}
// Skip sparse data format for now (lines starting with '{')
// Future enhancement: implement full sparse data support
if (!line.empty() && line[0] == '{') {
continue;
}
if (line.find("@attribute") != std::string::npos || line.find("@ATTRIBUTE") != std::string::npos) {
std::stringstream ss(line);
ss >> keyword >> attribute;
@@ -553,6 +627,9 @@ private:
size_t& sampleCount,
int classIndex = -1,
const std::string& classNameToFind = "") {
// Validate file path for security
validateFilePath(fileName);
std::ifstream file(fileName);
if (!file.is_open()) {
throw std::invalid_argument("Unable to open file: " + fileName);
@@ -568,6 +645,12 @@ private:
if (line.empty() || line[0] == '%' || line == "\r" || line == " ") {
continue;
}
// Skip sparse data format for now (lines starting with '{')
if (!line.empty() && line[0] == '{') {
continue;
}
if (line.find("@attribute") != std::string::npos || line.find("@ATTRIBUTE") != std::string::npos) {
std::stringstream ss(line);
std::string keyword, attribute, type_w;
@@ -620,7 +703,7 @@ private:
// Count samples and collect unique class values
do {
if (!line.empty() && line[0] != '@' && line[0] != '%' && !containsMissingValueStatic(line)) {
if (!line.empty() && line[0] != '@' && line[0] != '%' && line[0] != '{' && !containsMissingValueStatic(line)) {
auto tokens = splitStatic(line, ',');
if (!tokens.empty()) {
std::string classValue;