Merge pull request 'Add quiet parameter to Stratified KFold' (#1) from quiet into main

Reviewed-on: #1
This parameters enables/disables the output in std::err the Warning messages that produces trying to create a fold that lacks any values of the class
This commit is contained in:
2024-12-13 17:29:28 +00:00
10 changed files with 54 additions and 33 deletions

View File

@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [1.1.1] 2024-12-13
- Added a new parameter `quiet` to enable/disable the warning messages in the Stratified K-Fold partitioning. Default value `true`.
## [1.1.0] 2024-05-11 ## [1.1.0] 2024-05-11
### Fixed ### Fixed

View File

@@ -1,7 +1,6 @@
cmake_minimum_required(VERSION 3.20) cmake_minimum_required(VERSION 3.20)
project(Folding project(Folding
VERSION 1.1.0
DESCRIPTION "Folding utility for BayesNet library" DESCRIPTION "Folding utility for BayesNet library"
HOMEPAGE_URL "https://github.com/rmontanana/folding" HOMEPAGE_URL "https://github.com/rmontanana/folding"
LANGUAGES CXX LANGUAGES CXX
@@ -33,7 +32,6 @@ include(AddGitSubmodule)
# Subdirectories # Subdirectories
# -------------- # --------------
add_subdirectory(config)
# Testing # Testing
# ------- # -------

View File

@@ -1,4 +0,0 @@
configure_file(
"config.h.in"
"${CMAKE_BINARY_DIR}/configured_files/include/folding_config.h" ESCAPE_QUOTES
)

View File

@@ -1,14 +0,0 @@
#pragma once
#include <string>
#include <string_view>
#define PROJECT_VERSION_MAJOR @PROJECT_VERSION_MAJOR @
#define PROJECT_VERSION_MINOR @PROJECT_VERSION_MINOR @
#define PROJECT_VERSION_PATCH @PROJECT_VERSION_PATCH @
static constexpr std::string_view folding_project_name = "@PROJECT_NAME@";
static constexpr std::string_view folding_project_version = "@PROJECT_VERSION@";
static constexpr std::string_view folding_project_description = "@PROJECT_DESCRIPTION@";
static constexpr std::string_view folding_data_path = "@Folding_SOURCE_DIR@/tests/data/";
static constexpr std::string_view folding_csv_path = "@Folding_SOURCE_DIR@/tests/csv/";

View File

@@ -11,7 +11,7 @@
#include <random> #include <random>
#include <vector> #include <vector>
namespace folding { namespace folding {
const std::string FOLDING_VERSION = "1.1.0"; const std::string FOLDING_VERSION = "1.1.1";
class Fold { class Fold {
public: public:
inline Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed) inline Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed)
@@ -59,16 +59,18 @@ namespace folding {
}; };
class StratifiedKFold : public Fold { class StratifiedKFold : public Fold {
public: public:
inline StratifiedKFold(int k, const std::vector<int>& y, int seed = -1) : Fold(k, y.size(), seed) inline StratifiedKFold(int k, const std::vector<int>& y, int seed = -1, bool quiet = true) : Fold(k, y.size(), seed)
{ {
this->y = y; this->y = y;
n = y.size(); n = y.size();
this->quiet = quiet;
build(); build();
} }
inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed) inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1, bool quiet = true) : Fold(k, y.numel(), seed)
{ {
n = y.numel(); n = y.numel();
this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n); this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n);
this->quiet = quiet;
build(); build();
} }
@@ -90,6 +92,7 @@ namespace folding {
std::vector<int> y; std::vector<int> y;
std::vector<std::vector<int>> stratified_indices; std::vector<std::vector<int>> stratified_indices;
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds. bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
bool quiet = true; // Enable or disable warning messages
void build() void build()
{ {
stratified_indices = std::vector<std::vector<int>>(k); stratified_indices = std::vector<std::vector<int>>(k);
@@ -105,6 +108,7 @@ namespace folding {
int num_samples_to_take = num_samples / k; int num_samples_to_take = num_samples / k;
int remainder_samples_to_take = num_samples % k; int remainder_samples_to_take = num_samples % k;
if (num_samples_to_take == 0) { if (num_samples_to_take == 0) {
if (!quiet)
std::cerr << "Warning! The number of samples in class " << label << " (" << num_samples std::cerr << "Warning! The number of samples in class " << label << " (" << num_samples
<< ") is less than the number of folds (" << k << ")." << std::endl; << ") is less than the number of folds (" << k << ")." << std::endl;
faulty = true; faulty = true;

View File

@@ -1,12 +1,12 @@
if(ENABLE_TESTING) if(ENABLE_TESTING)
include_directories( include_directories(
${Folding_SOURCE_DIR} ${Folding_SOURCE_DIR}
lib/Files
lib/mdlp
${CMAKE_BINARY_DIR}/configured_files/include ${CMAKE_BINARY_DIR}/configured_files/include
lib/Files
lib/mdlp/src
) )
set(TEST_FOLDING "unit_tests_folding") set(TEST_FOLDING "unit_tests_folding")
add_executable(${TEST_FOLDING} TestFolding.cc TestUtils.cc) add_executable(${TEST_FOLDING} TestFolding.cc TestUtils.cc)
target_link_libraries(${TEST_FOLDING} PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain) target_link_libraries(${TEST_FOLDING} PUBLIC "${TORCH_LIBRARIES}" ArffFiles fimdlp Catch2::Catch2WithMain)
add_test(NAME ${TEST_FOLDING} COMMAND ${TEST_FOLDING}) add_test(NAME ${TEST_FOLDING} COMMAND ${TEST_FOLDING})
endif(ENABLE_TESTING) endif(ENABLE_TESTING)

View File

@@ -12,7 +12,7 @@
TEST_CASE("Version Test", "[Folding]") TEST_CASE("Version Test", "[Folding]")
{ {
std::string actual_version = { folding_project_version.begin(), folding_project_version.end() }; std::string actual_version = "1.1.1";
auto data = std::vector<int>(100); auto data = std::vector<int>(100);
folding::StratifiedKFold stratified_kfold(5, data, 17); folding::StratifiedKFold stratified_kfold(5, data, 17);
REQUIRE(stratified_kfold.version() == actual_version); REQUIRE(stratified_kfold.version() == actual_version);
@@ -187,3 +187,37 @@ TEST_CASE("StratifiedKFold Test", "[Folding]")
} }
} }
} }
TEST_CASE("Stratified KFold quiet parameter", "[Folding]")
{
auto raw = RawDatasets("glass", true);
std::string expected = "Warning! The number of samples in class 2 (9) is less than the number of folds (10).\n";
SECTION("With vectors")
{
// Redirect cerr to a stringstream
std::streambuf* originalCerrBuffer = std::cerr.rdbuf();
std::stringstream capturedOutput;
std::cerr.rdbuf(capturedOutput.rdbuf());
// StratifiedKFold with quiet parameter set to false
folding::StratifiedKFold stratified_kfold(10, raw.yv, 17, false);
// Restore the original cerr buffer
std::cerr.rdbuf(originalCerrBuffer);
// Check the captured output
REQUIRE(capturedOutput.str() == expected);
REQUIRE(stratified_kfold.isFaulty());
}
SECTION("With tensors")
{
// Redirect cerr to a stringstream
std::streambuf* originalCerrBuffer = std::cerr.rdbuf();
std::stringstream capturedOutput;
std::cerr.rdbuf(capturedOutput.rdbuf());
// StratifiedKFold with quiet parameter set to false
folding::StratifiedKFold stratified_kfold(10, raw.yt, 17, false);
// Restore the original cerr buffer
std::cerr.rdbuf(originalCerrBuffer);
// Check the captured output
REQUIRE(capturedOutput.str() == expected);
REQUIRE(stratified_kfold.isFaulty());
}
}

View File

@@ -8,7 +8,6 @@
#include <tuple> #include <tuple>
#include "ArffFiles.h" #include "ArffFiles.h"
#include "CPPFImdlp.h" #include "CPPFImdlp.h"
#include "folding_config.h"
bool file_exists(const std::string& name); bool file_exists(const std::string& name);
std::pair<vector<mdlp::labels_t>, map<std::string, int>> discretize(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y, std::vector<string> features); std::pair<vector<mdlp::labels_t>, map<std::string, int>> discretize(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y, std::vector<string> features);
@@ -45,11 +44,11 @@ class Paths {
public: public:
static std::string datasets() static std::string datasets()
{ {
return { folding_data_path.begin(), folding_data_path.end() }; return "../../tests/data/";
} }
static std::string csv() static std::string csv()
{ {
return { folding_csv_path.begin(), folding_csv_path.end() }; return "../../tests/csv/";
} }
}; };
class CSVFiles { class CSVFiles {