Add env to enable test data

This commit is contained in:
2024-04-19 10:02:59 +02:00
parent 018c94bfe6
commit 1caa39c071
10 changed files with 91 additions and 71 deletions

6
.vscode/launch.json vendored
View File

@@ -108,12 +108,12 @@
"name": "test",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/build_debug/tests/unit_tests",
"program": "${workspaceFolder}/build_debug/tests/unit_tests_platform",
"args": [
"-c=\"Metrics Test\"",
// "-c=\"Metrics Test\"",
// "-s",
],
"cwd": "${workspaceFolder}/build/tests",
"cwd": "${workspaceFolder}/build_debug/tests",
},
{
"name": "Build & debug active file",

View File

@@ -85,6 +85,9 @@ message(STATUS "Bayesnet_INCLUDE_DIRS=${Bayesnet_INCLUDE_DIRS}")
# Subdirectories
# --------------
## Configure test data path
cmake_path(SET TEST_DATA_PATH "${CMAKE_CURRENT_SOURCE_DIR}/tests/data")
configure_file(src/common/SourceData.h.in "${CMAKE_BINARY_DIR}/configured_files/include/SourceData.h")
add_subdirectory(lib/Files)
add_subdirectory(config)
add_subdirectory(src)

View File

@@ -6,37 +6,9 @@
#include <string>
#include <CPPFImdlp.h>
#include "Utils.h"
#include "SourceData.h"
namespace platform {
enum fileType_t { CSV, ARFF, RDATA };
class SourceData {
public:
SourceData(std::string source)
{
if (source == "Surcov") {
path = "datasets/";
fileType = CSV;
} else if (source == "Arff") {
path = "datasets/";
fileType = ARFF;
} else if (source == "Tanveer") {
path = "data/";
fileType = RDATA;
} else {
throw std::invalid_argument("Unknown source.");
}
}
std::string getPath()
{
return path;
}
fileType_t getFileType()
{
return fileType;
}
private:
std::string path;
fileType_t fileType;
};
class Dataset {
private:
std::string path;

View File

@@ -14,8 +14,15 @@ namespace platform {
private:
std::map<std::string, std::string> env;
public:
DotEnv()
DotEnv(bool create = false)
{
if (create) {
// For testing purposes
std::ofstream file(".env");
file << "source_data = Test" << std::endl;
file << "margin = 0.1" << std::endl;
file.close();
}
std::ifstream file(".env");
if (!file.is_open()) {
std::cerr << "File .env not found" << std::endl;
@@ -30,7 +37,7 @@ namespace platform {
std::istringstream iss(line);
std::string key, value;
if (std::getline(iss, key, '=') && std::getline(iss, value)) {
env[key] = value;
env[trim(key)] = trim(value);
}
}
}

View File

@@ -0,0 +1,38 @@
#ifndef SOURCEDATA_H
#define SOURCEDATA_H
namespace platform {
enum fileType_t { CSV, ARFF, RDATA };
class SourceData {
public:
SourceData(std::string source)
{
if (source == "Surcov") {
path = "datasets/";
fileType = CSV;
} else if (source == "Arff") {
path = "datasets/";
fileType = ARFF;
} else if (source == "Tanveer") {
path = "data/";
fileType = RDATA;
} else if (source == "Test") {
path = "@TEST_DATA_PATH@/";
fileType = ARFF;
} else {
throw std::invalid_argument("Unknown source.");
}
}
std::string getPath()
{
return path;
}
fileType_t getFileType()
{
return fileType;
}
private:
std::string path;
fileType_t fileType;
};
}
#endif

View File

@@ -69,6 +69,10 @@ namespace platform {
mark = 0.9995;
}
status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "=";
if (status == Symbols::cross) {
std::cout << "ZeroR mark: " << mark << " result=" << result << " dataset = " << dataset << std::endl;
exit(1);
}
}
}
}

View File

@@ -11,7 +11,7 @@ if(ENABLE_TESTING)
${PyClassifiers_INCLUDE_DIRS}
${Bayesnet_INCLUDE_DIRS}
)
set(TEST_SOURCES_PLATFORM TestUtils.cpp TestPlatform.cpp)
set(TEST_SOURCES_PLATFORM TestUtils.cpp TestPlatform.cpp TestResult.cpp ${Platform_SOURCE_DIR}/src/common/Datasets.cpp ${Platform_SOURCE_DIR}/src/common/Dataset.cpp)
add_executable(${TEST_PLATFORM} ${TEST_SOURCES_PLATFORM})
target_link_libraries(${TEST_PLATFORM} PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain BayesNet)
add_test(NAME ${TEST_PLATFORM} COMMAND ${TEST_PLATFORM})

View File

@@ -24,7 +24,7 @@ TEST_CASE("Test Folding library version", "[Folding]")
TEST_CASE("Test BayesNet version", "[BayesNet]")
{
std::string version = bayesnet::TAN().getVersion();
REQUIRE(version == "1.0.4");
REQUIRE(version == "1.0.4.1");
}
TEST_CASE("Test mdlp version", "[mdlp]")
{

View File

@@ -1,37 +1,25 @@
#define CATCH_CONFIG_MAIN
#include "catch.hpp"
#include "Result.h"
#include <filesystem>
#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>
#include <vector>
#include <string>
#include "TestUtils.h"
#include "results/Result.h"
#include "common/DotEnv.h"
#include "common/Datasets.h"
#include "common/Paths.h"
#include "config.h"
TEST_CASE("Result class tests", "[Result]")
{
std::string testPath = "test_data";
std::string testFile = "test.json";
SECTION("Constructor and load method")
TEST_CASE("ZeroR comparison in reports", "[Report]")
{
platform::Result result;
result.load(testPath, testFile);
REQUIRE(result.date != "");
REQUIRE(result.score >= 0);
REQUIRE(result.scoreName != "");
REQUIRE(result.title != "");
REQUIRE(result.duration >= 0);
REQUIRE(result.model != "");
}
SECTION("to_string method")
{
platform::Result result(testPath, testFile);
result.load();
std::string resultStr = result.to_string(1);
REQUIRE(resultStr != "");
}
SECTION("Exception handling in load method")
{
std::string invalidFile = "invalid.json";
auto result = platform::Result();
REQUIRE_THROWS_AS(platform::result.load(testPath, invalidFile), std::invalid_argument);
}
auto dotEnv = platform::DotEnv(true);
auto margin = 1e-2;
std::string dataset = "liver-disorders";
auto dt = platform::Datasets(false, platform::Paths::datasets());
dt.loadDataset(dataset);
std::vector<int> distribution = dt.getClassesCounts(dataset);
double nSamples = dt.getNSamples(dataset);
std::vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
double mark = *maxValue / nSamples * (1 + margin);
REQUIRE(mark == Catch::Approx(0.585507f).epsilon(1e-5));
}

8
tests/data/all.txt Normal file
View File

@@ -0,0 +1,8 @@
diabetes,class, all
ecoli,class, all
glass,Type, all
iris,class, all
kdd_JapaneseVowels,speaker, [2,3,4,5,6,7,8,9,10,11,12,13]
letter,class, all
liver-disorders,selector, all
mfeat-factors,class, all