diff --git a/.vscode/launch.json b/.vscode/launch.json index 7933b41..a587132 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -108,12 +108,12 @@ "name": "test", "type": "lldb", "request": "launch", - "program": "${workspaceFolder}/build_debug/tests/unit_tests", + "program": "${workspaceFolder}/build_debug/tests/unit_tests_platform", "args": [ - "-c=\"Metrics Test\"", + // "-c=\"Metrics Test\"", // "-s", ], - "cwd": "${workspaceFolder}/build/tests", + "cwd": "${workspaceFolder}/build_debug/tests", }, { "name": "Build & debug active file", diff --git a/CMakeLists.txt b/CMakeLists.txt index a90d9a6..c23943f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,6 +85,9 @@ message(STATUS "Bayesnet_INCLUDE_DIRS=${Bayesnet_INCLUDE_DIRS}") # Subdirectories # -------------- +## Configure test data path +cmake_path(SET TEST_DATA_PATH "${CMAKE_CURRENT_SOURCE_DIR}/tests/data") +configure_file(src/common/SourceData.h.in "${CMAKE_BINARY_DIR}/configured_files/include/SourceData.h") add_subdirectory(lib/Files) add_subdirectory(config) add_subdirectory(src) diff --git a/src/common/Dataset.h b/src/common/Dataset.h index 6c89769..49ffc48 100644 --- a/src/common/Dataset.h +++ b/src/common/Dataset.h @@ -6,37 +6,9 @@ #include #include #include "Utils.h" +#include "SourceData.h" namespace platform { - enum fileType_t { CSV, ARFF, RDATA }; - class SourceData { - public: - SourceData(std::string source) - { - if (source == "Surcov") { - path = "datasets/"; - fileType = CSV; - } else if (source == "Arff") { - path = "datasets/"; - fileType = ARFF; - } else if (source == "Tanveer") { - path = "data/"; - fileType = RDATA; - } else { - throw std::invalid_argument("Unknown source."); - } - } - std::string getPath() - { - return path; - } - fileType_t getFileType() - { - return fileType; - } - private: - std::string path; - fileType_t fileType; - }; + class Dataset { private: std::string path; diff --git a/src/common/DotEnv.h b/src/common/DotEnv.h index 905e909..6882f09 100644 --- a/src/common/DotEnv.h +++ b/src/common/DotEnv.h @@ -14,8 +14,15 @@ namespace platform { private: std::map env; public: - DotEnv() + DotEnv(bool create = false) { + if (create) { + // For testing purposes + std::ofstream file(".env"); + file << "source_data = Test" << std::endl; + file << "margin = 0.1" << std::endl; + file.close(); + } std::ifstream file(".env"); if (!file.is_open()) { std::cerr << "File .env not found" << std::endl; @@ -30,7 +37,7 @@ namespace platform { std::istringstream iss(line); std::string key, value; if (std::getline(iss, key, '=') && std::getline(iss, value)) { - env[key] = value; + env[trim(key)] = trim(value); } } } diff --git a/src/common/SourceData.h.in b/src/common/SourceData.h.in new file mode 100644 index 0000000..386c7e0 --- /dev/null +++ b/src/common/SourceData.h.in @@ -0,0 +1,38 @@ +#ifndef SOURCEDATA_H +#define SOURCEDATA_H +namespace platform { + enum fileType_t { CSV, ARFF, RDATA }; + class SourceData { + public: + SourceData(std::string source) + { + if (source == "Surcov") { + path = "datasets/"; + fileType = CSV; + } else if (source == "Arff") { + path = "datasets/"; + fileType = ARFF; + } else if (source == "Tanveer") { + path = "data/"; + fileType = RDATA; + } else if (source == "Test") { + path = "@TEST_DATA_PATH@/"; + fileType = ARFF; + } else { + throw std::invalid_argument("Unknown source."); + } + } + std::string getPath() + { + return path; + } + fileType_t getFileType() + { + return fileType; + } + private: + std::string path; + fileType_t fileType; + }; +} +#endif \ No newline at end of file diff --git a/src/reports/ReportBase.cpp b/src/reports/ReportBase.cpp index 1abcc27..844f6bf 100644 --- a/src/reports/ReportBase.cpp +++ b/src/reports/ReportBase.cpp @@ -69,6 +69,10 @@ namespace platform { mark = 0.9995; } status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "="; + if (status == Symbols::cross) { + std::cout << "ZeroR mark: " << mark << " result=" << result << " dataset = " << dataset << std::endl; + exit(1); + } } } } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 1c94b52..e76bc54 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -11,7 +11,7 @@ if(ENABLE_TESTING) ${PyClassifiers_INCLUDE_DIRS} ${Bayesnet_INCLUDE_DIRS} ) - set(TEST_SOURCES_PLATFORM TestUtils.cpp TestPlatform.cpp) + set(TEST_SOURCES_PLATFORM TestUtils.cpp TestPlatform.cpp TestResult.cpp ${Platform_SOURCE_DIR}/src/common/Datasets.cpp ${Platform_SOURCE_DIR}/src/common/Dataset.cpp) add_executable(${TEST_PLATFORM} ${TEST_SOURCES_PLATFORM}) target_link_libraries(${TEST_PLATFORM} PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain BayesNet) add_test(NAME ${TEST_PLATFORM} COMMAND ${TEST_PLATFORM}) diff --git a/tests/TestPlatform.cpp b/tests/TestPlatform.cpp index ace41b2..9a44072 100644 --- a/tests/TestPlatform.cpp +++ b/tests/TestPlatform.cpp @@ -24,7 +24,7 @@ TEST_CASE("Test Folding library version", "[Folding]") TEST_CASE("Test BayesNet version", "[BayesNet]") { std::string version = bayesnet::TAN().getVersion(); - REQUIRE(version == "1.0.4"); + REQUIRE(version == "1.0.4.1"); } TEST_CASE("Test mdlp version", "[mdlp]") { diff --git a/tests/TestResult.cpp b/tests/TestResult.cpp index 05e2622..1c96bd6 100644 --- a/tests/TestResult.cpp +++ b/tests/TestResult.cpp @@ -1,37 +1,25 @@ -#define CATCH_CONFIG_MAIN -#include "catch.hpp" -#include "Result.h" -#include +#include +#include +#include +#include +#include "TestUtils.h" +#include "results/Result.h" +#include "common/DotEnv.h" +#include "common/Datasets.h" +#include "common/Paths.h" +#include "config.h" -TEST_CASE("Result class tests", "[Result]") + +TEST_CASE("ZeroR comparison in reports", "[Report]") { - std::string testPath = "test_data"; - std::string testFile = "test.json"; - - SECTION("Constructor and load method") - { - platform::Result result; - result.load(testPath, testFile); - REQUIRE(result.date != ""); - REQUIRE(result.score >= 0); - REQUIRE(result.scoreName != ""); - REQUIRE(result.title != ""); - REQUIRE(result.duration >= 0); - REQUIRE(result.model != ""); - } - - SECTION("to_string method") - { - platform::Result result(testPath, testFile); - result.load(); - std::string resultStr = result.to_string(1); - REQUIRE(resultStr != ""); - } - - SECTION("Exception handling in load method") - { - std::string invalidFile = "invalid.json"; - auto result = platform::Result(); - REQUIRE_THROWS_AS(platform::result.load(testPath, invalidFile), std::invalid_argument); - } + auto dotEnv = platform::DotEnv(true); + auto margin = 1e-2; + std::string dataset = "liver-disorders"; + auto dt = platform::Datasets(false, platform::Paths::datasets()); + dt.loadDataset(dataset); + std::vector distribution = dt.getClassesCounts(dataset); + double nSamples = dt.getNSamples(dataset); + std::vector::iterator maxValue = max_element(distribution.begin(), distribution.end()); + double mark = *maxValue / nSamples * (1 + margin); + REQUIRE(mark == Catch::Approx(0.585507f).epsilon(1e-5)); } \ No newline at end of file diff --git a/tests/data/all.txt b/tests/data/all.txt new file mode 100644 index 0000000..6963a58 --- /dev/null +++ b/tests/data/all.txt @@ -0,0 +1,8 @@ +diabetes,class, all +ecoli,class, all +glass,Type, all +iris,class, all +kdd_JapaneseVowels,speaker, [2,3,4,5,6,7,8,9,10,11,12,13] +letter,class, all +liver-disorders,selector, all +mfeat-factors,class, all \ No newline at end of file