Add tests for the quiet parameter and fix initialization mistake
This commit is contained in:
10
folding.hpp
10
folding.hpp
@@ -11,7 +11,7 @@
|
|||||||
#include <random>
|
#include <random>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
namespace folding {
|
namespace folding {
|
||||||
const std::string FOLDING_VERSION = "1.1.0";
|
const std::string FOLDING_VERSION = "1.1.1";
|
||||||
class Fold {
|
class Fold {
|
||||||
public:
|
public:
|
||||||
inline Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed)
|
inline Fold(int k, int n, int seed = -1) : k(k), n(n), seed(seed)
|
||||||
@@ -63,12 +63,14 @@ namespace folding {
|
|||||||
{
|
{
|
||||||
this->y = y;
|
this->y = y;
|
||||||
n = y.size();
|
n = y.size();
|
||||||
|
this->quiet = quiet;
|
||||||
build();
|
build();
|
||||||
}
|
}
|
||||||
inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1) : Fold(k, y.numel(), seed)
|
inline StratifiedKFold(int k, torch::Tensor& y, int seed = -1, bool quiet = true) : Fold(k, y.numel(), seed)
|
||||||
{
|
{
|
||||||
n = y.numel();
|
n = y.numel();
|
||||||
this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n);
|
this->y = std::vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + n);
|
||||||
|
this->quiet = quiet;
|
||||||
build();
|
build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,6 +92,7 @@ namespace folding {
|
|||||||
std::vector<int> y;
|
std::vector<int> y;
|
||||||
std::vector<std::vector<int>> stratified_indices;
|
std::vector<std::vector<int>> stratified_indices;
|
||||||
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
|
bool faulty = false; // Only true if the number of samples of any class is less than the number of folds.
|
||||||
|
bool quiet = true; // Enable or disable warning messages
|
||||||
void build()
|
void build()
|
||||||
{
|
{
|
||||||
stratified_indices = std::vector<std::vector<int>>(k);
|
stratified_indices = std::vector<std::vector<int>>(k);
|
||||||
@@ -105,7 +108,8 @@ namespace folding {
|
|||||||
int num_samples_to_take = num_samples / k;
|
int num_samples_to_take = num_samples / k;
|
||||||
int remainder_samples_to_take = num_samples % k;
|
int remainder_samples_to_take = num_samples % k;
|
||||||
if (num_samples_to_take == 0) {
|
if (num_samples_to_take == 0) {
|
||||||
std::cerr << "Warning! The number of samples in class " << label << " (" << num_samples
|
if (!quiet)
|
||||||
|
std::cerr << "Warning! The number of samples in class " << label << " (" << num_samples
|
||||||
<< ") is less than the number of folds (" << k << ")." << std::endl;
|
<< ") is less than the number of folds (" << k << ")." << std::endl;
|
||||||
faulty = true;
|
faulty = true;
|
||||||
}
|
}
|
||||||
|
@@ -1,11 +1,12 @@
|
|||||||
if(ENABLE_TESTING)
|
if(ENABLE_TESTING)
|
||||||
include_directories(
|
include_directories(
|
||||||
${Folding_SOURCE_DIR}
|
${Folding_SOURCE_DIR}
|
||||||
|
${CMAKE_BINARY_DIR}/configured_files/include
|
||||||
lib/Files
|
lib/Files
|
||||||
lib/mdlp/src
|
lib/mdlp/src
|
||||||
)
|
)
|
||||||
set(TEST_FOLDING "unit_tests_folding")
|
set(TEST_FOLDING "unit_tests_folding")
|
||||||
add_executable(${TEST_FOLDING} TestFolding.cc TestUtils.cc)
|
add_executable(${TEST_FOLDING} TestFolding.cc TestUtils.cc)
|
||||||
target_link_libraries(${TEST_FOLDING} PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain)
|
target_link_libraries(${TEST_FOLDING} PUBLIC "${TORCH_LIBRARIES}" ArffFiles fimdlp Catch2::Catch2WithMain)
|
||||||
add_test(NAME ${TEST_FOLDING} COMMAND ${TEST_FOLDING})
|
add_test(NAME ${TEST_FOLDING} COMMAND ${TEST_FOLDING})
|
||||||
endif(ENABLE_TESTING)
|
endif(ENABLE_TESTING)
|
||||||
|
@@ -12,7 +12,7 @@
|
|||||||
|
|
||||||
TEST_CASE("Version Test", "[Folding]")
|
TEST_CASE("Version Test", "[Folding]")
|
||||||
{
|
{
|
||||||
std::string actual_version = { folding_project_version.begin(), folding_project_version.end() };
|
std::string actual_version = "1.1.1";
|
||||||
auto data = std::vector<int>(100);
|
auto data = std::vector<int>(100);
|
||||||
folding::StratifiedKFold stratified_kfold(5, data, 17);
|
folding::StratifiedKFold stratified_kfold(5, data, 17);
|
||||||
REQUIRE(stratified_kfold.version() == actual_version);
|
REQUIRE(stratified_kfold.version() == actual_version);
|
||||||
@@ -187,3 +187,37 @@ TEST_CASE("StratifiedKFold Test", "[Folding]")
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
TEST_CASE("Stratified KFold quiet parameter", "[Folding]")
|
||||||
|
{
|
||||||
|
auto raw = RawDatasets("glass", true);
|
||||||
|
std::string expected = "Warning! The number of samples in class 2 (9) is less than the number of folds (10).\n";
|
||||||
|
|
||||||
|
SECTION("With vectors")
|
||||||
|
{
|
||||||
|
// Redirect cerr to a stringstream
|
||||||
|
std::streambuf* originalCerrBuffer = std::cerr.rdbuf();
|
||||||
|
std::stringstream capturedOutput;
|
||||||
|
std::cerr.rdbuf(capturedOutput.rdbuf());
|
||||||
|
// StratifiedKFold with quiet parameter set to false
|
||||||
|
folding::StratifiedKFold stratified_kfold(10, raw.yv, 17, false);
|
||||||
|
// Restore the original cerr buffer
|
||||||
|
std::cerr.rdbuf(originalCerrBuffer);
|
||||||
|
// Check the captured output
|
||||||
|
REQUIRE(capturedOutput.str() == expected);
|
||||||
|
REQUIRE(stratified_kfold.isFaulty());
|
||||||
|
}
|
||||||
|
SECTION("With tensors")
|
||||||
|
{
|
||||||
|
// Redirect cerr to a stringstream
|
||||||
|
std::streambuf* originalCerrBuffer = std::cerr.rdbuf();
|
||||||
|
std::stringstream capturedOutput;
|
||||||
|
std::cerr.rdbuf(capturedOutput.rdbuf());
|
||||||
|
// StratifiedKFold with quiet parameter set to false
|
||||||
|
folding::StratifiedKFold stratified_kfold(10, raw.yt, 17, false);
|
||||||
|
// Restore the original cerr buffer
|
||||||
|
std::cerr.rdbuf(originalCerrBuffer);
|
||||||
|
// Check the captured output
|
||||||
|
REQUIRE(capturedOutput.str() == expected);
|
||||||
|
REQUIRE(stratified_kfold.isFaulty());
|
||||||
|
}
|
||||||
|
}
|
@@ -8,7 +8,6 @@
|
|||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include "ArffFiles.h"
|
#include "ArffFiles.h"
|
||||||
#include "CPPFImdlp.h"
|
#include "CPPFImdlp.h"
|
||||||
#include "folding_config.h"
|
|
||||||
|
|
||||||
bool file_exists(const std::string& name);
|
bool file_exists(const std::string& name);
|
||||||
std::pair<vector<mdlp::labels_t>, map<std::string, int>> discretize(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y, std::vector<string> features);
|
std::pair<vector<mdlp::labels_t>, map<std::string, int>> discretize(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y, std::vector<string> features);
|
||||||
@@ -45,11 +44,11 @@ class Paths {
|
|||||||
public:
|
public:
|
||||||
static std::string datasets()
|
static std::string datasets()
|
||||||
{
|
{
|
||||||
return { folding_data_path.begin(), folding_data_path.end() };
|
return "../../tests/data/";
|
||||||
}
|
}
|
||||||
static std::string csv()
|
static std::string csv()
|
||||||
{
|
{
|
||||||
return { folding_csv_path.begin(), folding_csv_path.end() };
|
return "../../tests/csv/";
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
class CSVFiles {
|
class CSVFiles {
|
||||||
|
Reference in New Issue
Block a user