From 7d3a2dd71333025d95daf027460fba522d6e576d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Thu, 8 May 2025 17:15:42 +0200 Subject: [PATCH 01/27] Remove modules --- .gitignore | 1 + .gitmodules | 21 --------------------- lib/Files | 1 - lib/argparse | 1 - lib/catch2 | 1 - lib/folding | 1 - lib/json | 1 - lib/libxlsxwriter | 1 - lib/mdlp | 1 - 9 files changed, 1 insertion(+), 28 deletions(-) delete mode 100644 .gitmodules delete mode 160000 lib/Files delete mode 160000 lib/argparse delete mode 160000 lib/catch2 delete mode 160000 lib/folding delete mode 160000 lib/json delete mode 160000 lib/libxlsxwriter delete mode 160000 lib/mdlp diff --git a/.gitignore b/.gitignore index 42fbee0..f09d914 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,4 @@ puml/** diagrams/html/** diagrams/latex/** .cache +vcpkg_installed diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 4b52f53..0000000 --- a/.gitmodules +++ /dev/null @@ -1,21 +0,0 @@ -[submodule "lib/catch2"] - path = lib/catch2 - url = https://github.com/catchorg/Catch2.git -[submodule "lib/argparse"] - path = lib/argparse - url = https://github.com/p-ranav/argparse -[submodule "lib/json"] - path = lib/json - url = https://github.com/nlohmann/json -[submodule "lib/libxlsxwriter"] - path = lib/libxlsxwriter - url = https://github.com/jmcnamara/libxlsxwriter.git -[submodule "lib/folding"] - path = lib/folding - url = https://github.com/rmontanana/folding -[submodule "lib/Files"] - path = lib/Files - url = https://github.com/rmontanana/ArffFiles -[submodule "lib/mdlp"] - path = lib/mdlp - url = https://github.com/rmontanana/mdlp diff --git a/lib/Files b/lib/Files deleted file mode 160000 index 18c79f6..0000000 --- a/lib/Files +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 18c79f6d4894d6b7a6cbfad0239bf9bfd68d3bb4 diff --git a/lib/argparse b/lib/argparse deleted file mode 160000 index cbd9fd8..0000000 --- a/lib/argparse +++ /dev/null @@ -1 +0,0 @@ -Subproject commit cbd9fd8ed675ed6a2ac1bd7142d318c6ad5d3462 diff --git a/lib/catch2 b/lib/catch2 deleted file mode 160000 index 914aeec..0000000 --- a/lib/catch2 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 914aeecfe23b1e16af6ea675a4fb5dbd5a5b8d0a diff --git a/lib/folding b/lib/folding deleted file mode 160000 index 9652853..0000000 --- a/lib/folding +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 9652853d692ed3b8a38d89f70559209ffb988020 diff --git a/lib/json b/lib/json deleted file mode 160000 index 48e7b4c..0000000 --- a/lib/json +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 48e7b4c23b089c088c11e51c824d78d0f0949b40 diff --git a/lib/libxlsxwriter b/lib/libxlsxwriter deleted file mode 160000 index 14f1351..0000000 --- a/lib/libxlsxwriter +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 14f13513cb140092a913a91fce719ff7dc36e332 diff --git a/lib/mdlp b/lib/mdlp deleted file mode 160000 index cfb993f..0000000 --- a/lib/mdlp +++ /dev/null @@ -1 +0,0 @@ -Subproject commit cfb993f5ec1aabed527f524fdd4db06c6d839868 From b1965c8ae5c470784c771cde2b6bd858507e1a9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Fri, 9 May 2025 10:54:27 +0200 Subject: [PATCH 02/27] Add vcpkg config files --- remove_submodules.sh | 14 +++++++++++ vcpkg-configuration.json | 21 +++++++++++++++++ vcpkg.json | 50 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+) create mode 100644 remove_submodules.sh create mode 100644 vcpkg-configuration.json create mode 100644 vcpkg.json diff --git a/remove_submodules.sh b/remove_submodules.sh new file mode 100644 index 0000000..e5b4ce6 --- /dev/null +++ b/remove_submodules.sh @@ -0,0 +1,14 @@ +git config --file .gitmodules --get-regexp path | awk '{ print $2 }' | while read line; do + echo "Removing $line" + # Deinit the submodule + git submodule deinit -f "$line" + + # Remove the submodule from the working tree + git rm -f "$line" + + # Remove the submodule from .git/modules + rm -rf ".git/modules/$line" +done + +# Remove the .gitmodules file +git rm -f .gitmodules diff --git a/vcpkg-configuration.json b/vcpkg-configuration.json new file mode 100644 index 0000000..b241242 --- /dev/null +++ b/vcpkg-configuration.json @@ -0,0 +1,21 @@ +{ + "default-registry": { + "kind": "git", + "baseline": "760bfd0c8d7c89ec640aec4df89418b7c2745605", + "repository": "https://github.com/microsoft/vcpkg" + }, + "registries": [ + { + "kind": "git", + "repository": "https://github.com/rmontanana/vcpkg-stash", + "baseline": "1ea69243c0e8b0de77c9d1dd6e1d7593ae7f3627", + "packages": [ + "arff-files", + "bayesnet", + "fimdlp", + "folding", + "libtorch-bin" + ] + } + ] +} \ No newline at end of file diff --git a/vcpkg.json b/vcpkg.json new file mode 100644 index 0000000..cf1ea58 --- /dev/null +++ b/vcpkg.json @@ -0,0 +1,50 @@ + { + "name": "platform", + "version-string": "1.1.0", + "dependencies": [ + "arff-files", + "nlohmann-json", + "fimdlp", + "libtorch-bin", + "folding", + "bayesnet", + "argparse", + "libxlsxwriter", + "zlib", + "libzip" + ], + "overrides": [ + { + "name": "arff-files", + "version": "1.1.0" + }, + { + "name": "fimdlp", + "version": "2.0.1" + }, + { + "name": "libtorch-bin", + "version": "2.7.0" + }, + { + "name": "bayesnet", + "version": "1.1.1" + }, + { + "name": "folding", + "version": "1.1.1" + }, + { + "name": "argpase", + "version": "3.2" + }, + { + "name": "libxlsxwriter", + "version": "1.2.2" + }, + { + "name": "nlohmann-json", + "version": "3.11.3" + } + ] + } \ No newline at end of file From 16b49238510f86a66b5e414b0fda6faab487c90c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Fri, 9 May 2025 11:10:27 +0200 Subject: [PATCH 03/27] Complete configuration xlsxwriter is still with the old config --- CMakeLists.txt | 75 +- Makefile | 39 +- lib/argparse | 1 + lib/log/loguru.cpp | 2009 ---------------------------------------- lib/log/loguru.hpp | 1475 ----------------------------- lib/mdlp | 1 + sample/CMakeLists.txt | 10 +- src/CMakeLists.txt | 17 +- src/common/Dataset.cpp | 2 +- vcpkg.json | 4 +- 10 files changed, 74 insertions(+), 3559 deletions(-) create mode 160000 lib/argparse delete mode 100644 lib/log/loguru.cpp delete mode 100644 lib/log/loguru.hpp create mode 160000 lib/mdlp diff --git a/CMakeLists.txt b/CMakeLists.txt index 178e84b..31af5c0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,11 +7,7 @@ project(Platform LANGUAGES CXX ) -find_package(Torch REQUIRED) -if (POLICY CMP0135) - cmake_policy(SET CMP0135 NEW) -endif () # Global CMake variables # ---------------------- @@ -26,62 +22,69 @@ set(CMAKE_CXX_FLAGS_DEBUG " ${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage -O # Options # ------- -option(ENABLE_CLANG_TIDY "Enable to add clang tidy." OFF) option(ENABLE_TESTING "Unit testing build" OFF) option(CODE_COVERAGE "Collect coverage from test library" OFF) +# CMakes modules +# -------------- +set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules ${CMAKE_MODULE_PATH}) + # MPI find_package(MPI REQUIRED) message("MPI_CXX_LIBRARIES=${MPI_CXX_LIBRARIES}") message("MPI_CXX_INCLUDE_DIRS=${MPI_CXX_INCLUDE_DIRS}") # Boost Library +cmake_policy(SET CMP0135 NEW) +cmake_policy(SET CMP0167 NEW) # For FindBoost set(Boost_USE_STATIC_LIBS OFF) set(Boost_USE_MULTITHREADED ON) set(Boost_USE_STATIC_RUNTIME OFF) + find_package(Boost 1.66.0 REQUIRED COMPONENTS python3 numpy3) + +# # Python +find_package(Python3 REQUIRED COMPONENTS Development) +# # target_include_directories(MyTarget SYSTEM PRIVATE ${Python3_INCLUDE_DIRS}) +# message("Python_LIBRARIES=${Python_LIBRARIES}") + +# # Boost Python +# find_package(boost_python${Python3_VERSION_MAJOR}${Python3_VERSION_MINOR} CONFIG REQUIRED COMPONENTS python${Python3_VERSION_MAJOR}${Python3_VERSION_MINOR}) +# # target_link_libraries(MyTarget PRIVATE Boost::python${Python3_VERSION_MAJOR}${Python3_VERSION_MINOR}) + + if(Boost_FOUND) message("Boost_INCLUDE_DIRS=${Boost_INCLUDE_DIRS}") + message("Boost_LIBRARIES=${Boost_LIBRARIES}") + message("Boost_VERSION=${Boost_VERSION}") include_directories(${Boost_INCLUDE_DIRS}) endif() -# Python -find_package(Python3 3.11 COMPONENTS Interpreter Development REQUIRED) -message("Python3_LIBRARIES=${Python3_LIBRARIES}") -# CMakes modules -# -------------- -set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules ${CMAKE_MODULE_PATH}) -include(AddGitSubmodule) - -if (CODE_COVERAGE) - enable_testing() - include(CodeCoverage) - MESSAGE("Code coverage enabled") - SET(GCC_COVERAGE_LINK_FLAGS " ${GCC_COVERAGE_LINK_FLAGS} -lgcov --coverage") -endif (CODE_COVERAGE) - -if (ENABLE_CLANG_TIDY) - include(StaticAnalyzers) # clang-tidy -endif (ENABLE_CLANG_TIDY) # External libraries - dependencies of Platform # --------------------------------------------- -add_git_submodule("lib/argparse") -add_git_submodule("lib/mdlp") find_library(XLSXWRITER_LIB NAMES libxlsxwriter.dylib libxlsxwriter.so PATHS ${Platform_SOURCE_DIR}/lib/libxlsxwriter/lib) -message("XLSXWRITER_LIB=${XLSXWRITER_LIB}") +# find_path(XLSXWRITER_INCLUDE_DIR xlsxwriter.h) +# find_library(XLSXWRITER_LIBRARY xlsxwriter) +# message("XLSXWRITER_INCLUDE_DIR=${XLSXWRITER_INCLUDE_DIR}") +# message("XLSXWRITER_LIBRARY=${XLSXWRITER_LIBRARY}") +find_package(Torch CONFIG REQUIRED) +find_package(fimdlp CONFIG REQUIRED) +find_package(folding CONFIG REQUIRED) +find_package(argparse CONFIG REQUIRED) +find_package(Boost REQUIRED COMPONENTS python) +find_package(arff-files CONFIG REQUIRED) +find_package(bayesnet CONFIG REQUIRED) +find_package(ZLIB REQUIRED) +find_library(LIBZIP_LIBRARY NAMES zip) find_library(PyClassifiers NAMES libPyClassifiers PyClassifiers libPyClassifiers.a PATHS ${Platform_SOURCE_DIR}/../lib/lib REQUIRED) find_path(PyClassifiers_INCLUDE_DIRS REQUIRED NAMES pyclassifiers PATHS ${Platform_SOURCE_DIR}/../lib/include) -find_library(BayesNet NAMES libBayesNet BayesNet libBayesNet.a PATHS ${Platform_SOURCE_DIR}/../lib/lib REQUIRED) -find_path(Bayesnet_INCLUDE_DIRS REQUIRED NAMES bayesnet PATHS ${Platform_SOURCE_DIR}/../lib/include) message(STATUS "PyClassifiers=${PyClassifiers}") message(STATUS "PyClassifiers_INCLUDE_DIRS=${PyClassifiers_INCLUDE_DIRS}") -message(STATUS "BayesNet=${BayesNet}") -message(STATUS "Bayesnet_INCLUDE_DIRS=${Bayesnet_INCLUDE_DIRS}") # Subdirectories # -------------- @@ -90,16 +93,20 @@ cmake_path(SET TEST_DATA_PATH "${CMAKE_CURRENT_SOURCE_DIR}/tests/data") configure_file(src/common/SourceData.h.in "${CMAKE_BINARY_DIR}/configured_files/include/SourceData.h") add_subdirectory(config) add_subdirectory(src) -add_subdirectory(sample) +# add_subdirectory(sample) file(GLOB Platform_SOURCES CONFIGURE_DEPENDS ${Platform_SOURCE_DIR}/src/*.cpp) # Testing # ------- if (ENABLE_TESTING) + enable_testing() MESSAGE("Testing enabled") - if (NOT TARGET Catch2::Catch2) - add_git_submodule("lib/catch2") - endif (NOT TARGET Catch2::Catch2) + find_package(Catch2 CONFIG REQUIRED) include(CTest) add_subdirectory(tests) endif (ENABLE_TESTING) +if (CODE_COVERAGE) + include(CodeCoverage) + MESSAGE("Code coverage enabled") + SET(GCC_COVERAGE_LINK_FLAGS " ${GCC_COVERAGE_LINK_FLAGS} -lgcov --coverage") +endif (CODE_COVERAGE) diff --git a/Makefile b/Makefile index 59c603b..dbe2381 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,9 @@ SHELL := /bin/bash .DEFAULT_GOAL := help -.PHONY: coverage setup help build test clean debug release submodules buildr buildd install dependency testp testb clang-uml +.PHONY: init clean coverage setup help build test clean debug release buildr buildd install dependency testp testb clang-uml -f_release = build_release -f_debug = build_debug +f_release = build_Release +f_debug = build_Debug app_targets = b_best b_list b_main b_manage b_grid b_results test_targets = unit_tests_platform @@ -20,14 +20,22 @@ define ClearTests fi ; endef +init: ## Initialize the project installing dependencies + @echo ">>> Installing dependencies" + @vcpkg install + @echo ">>> Done"; -sub-init: ## Initialize submodules - @git submodule update --init --recursive - -sub-update: ## Initialize submodules - @git submodule update --remote --merge - @git submodule foreach git pull origin master - +clean: ## Clean the project + @echo ">>> Cleaning the project..." + @if test -f CMakeCache.txt ; then echo "- Deleting CMakeCache.txt"; rm -f CMakeCache.txt; fi + @for folder in $(f_release) $(f_debug) vpcpkg_installed install_test ; do \ + if test -d "$$folder" ; then \ + echo "- Deleting $$folder folder" ; \ + rm -rf "$$folder"; \ + fi; \ + done + $(call ClearTests) + @echo ">>> Done"; setup: ## Install dependencies for tests and coverage @if [ "$(shell uname)" = "Darwin" ]; then \ brew install gcovr; \ @@ -60,16 +68,11 @@ dependency: ## Create a dependency graph diagram of the project (build/dependenc cd $(f_debug) && cmake .. --graphviz=dependency.dot && dot -Tpng dependency.dot -o dependency.png buildd: ## Build the debug targets - @cmake --build $(f_debug) -t $(app_targets) PlatformSample --parallel + @cmake --build $(f_debug) -t $(app_targets) PlatformSample --parallel buildr: ## Build the release targets @cmake --build $(f_release) -t $(app_targets) --parallel -clean: ## Clean the tests info - @echo ">>> Cleaning Debug Platform tests..."; - $(call ClearTests) - @echo ">>> Done"; - clang-uml: ## Create uml class and sequence diagrams clang-uml -p --add-compile-flag -I /usr/lib/gcc/x86_64-redhat-linux/8/include/ @@ -77,14 +80,14 @@ debug: ## Build a debug version of the project @echo ">>> Building Debug Platform..."; @if [ -d ./$(f_debug) ]; then rm -rf ./$(f_debug); fi @mkdir $(f_debug); - @cmake -S . -B $(f_debug) -D CMAKE_BUILD_TYPE=Debug -D ENABLE_TESTING=ON -D CODE_COVERAGE=ON + @cmake -S . -B $(f_debug) -D CMAKE_BUILD_TYPE=Debug -D ENABLE_TESTING=ON -D CODE_COVERAGE=ON -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake @echo ">>> Done"; release: ## Build a Release version of the project @echo ">>> Building Release Platform..."; @if [ -d ./$(f_release) ]; then rm -rf ./$(f_release); fi @mkdir $(f_release); - @cmake -S . -B $(f_release) -D CMAKE_BUILD_TYPE=Release + @cmake -S . -B $(f_release) -D CMAKE_BUILD_TYPE=Release -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake @echo ">>> Done"; opt = "" diff --git a/lib/argparse b/lib/argparse new file mode 160000 index 0000000..cbd9fd8 --- /dev/null +++ b/lib/argparse @@ -0,0 +1 @@ +Subproject commit cbd9fd8ed675ed6a2ac1bd7142d318c6ad5d3462 diff --git a/lib/log/loguru.cpp b/lib/log/loguru.cpp deleted file mode 100644 index a95cfbf..0000000 --- a/lib/log/loguru.cpp +++ /dev/null @@ -1,2009 +0,0 @@ -#if defined(__GNUC__) || defined(__clang__) -// Disable all warnings from gcc/clang: -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wpragmas" - -#pragma GCC diagnostic ignored "-Wc++98-compat" -#pragma GCC diagnostic ignored "-Wc++98-compat-pedantic" -#pragma GCC diagnostic ignored "-Wexit-time-destructors" -#pragma GCC diagnostic ignored "-Wformat-nonliteral" -#pragma GCC diagnostic ignored "-Wglobal-constructors" -#pragma GCC diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" -#pragma GCC diagnostic ignored "-Wmissing-prototypes" -#pragma GCC diagnostic ignored "-Wpadded" -#pragma GCC diagnostic ignored "-Wsign-conversion" -#pragma GCC diagnostic ignored "-Wunknown-pragmas" -#pragma GCC diagnostic ignored "-Wunused-macros" -#pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant" -#elif defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable:4365) // conversion from 'X' to 'Y', signed/unsigned mismatch -#endif - -#include "loguru.hpp" - -#ifndef LOGURU_HAS_BEEN_IMPLEMENTED -#define LOGURU_HAS_BEEN_IMPLEMENTED - -#define LOGURU_PREAMBLE_WIDTH (53 + LOGURU_THREADNAME_WIDTH + LOGURU_FILENAME_WIDTH) - -#undef min -#undef max - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if LOGURU_SYSLOG -#include -#else -#define LOG_USER 0 -#endif - -#ifdef _WIN32 -#include - -#define localtime_r(a, b) localtime_s(b, a) // No localtime_r with MSVC, but arguments are swapped for localtime_s -#else -#include -#include // mkdir -#include // STDERR_FILENO -#endif - -#ifdef __linux__ -#include // PATH_MAX -#elif !defined(_WIN32) -#include // PATH_MAX -#endif - -#ifndef PATH_MAX -#define PATH_MAX 1024 -#endif - -#ifdef __APPLE__ -#include "TargetConditionals.h" -#endif - -// TODO: use defined(_POSIX_VERSION) for some of these things? - -#if defined(_WIN32) || defined(__CYGWIN__) -#define LOGURU_PTHREADS 0 -#define LOGURU_WINTHREADS 1 -#ifndef LOGURU_STACKTRACES -#define LOGURU_STACKTRACES 0 -#endif -#else -#define LOGURU_PTHREADS 1 -#define LOGURU_WINTHREADS 0 -#ifdef __GLIBC__ -#ifndef LOGURU_STACKTRACES -#define LOGURU_STACKTRACES 1 -#endif -#else -#ifndef LOGURU_STACKTRACES -#define LOGURU_STACKTRACES 0 -#endif -#endif -#endif - -#if LOGURU_STACKTRACES -#include // for __cxa_demangle -#include // for dladdr -#include // for backtrace -#endif // LOGURU_STACKTRACES - -#if LOGURU_PTHREADS -#include -#if defined(__FreeBSD__) -#include -#include -#elif defined(__OpenBSD__) -#include -#endif - -#ifdef __linux__ - /* On Linux, the default thread name is the same as the name of the binary. - Additionally, all new threads inherit the name of the thread it got forked from. - For this reason, Loguru use the pthread Thread Local Storage - for storing thread names on Linux. */ -#ifndef LOGURU_PTLS_NAMES -#define LOGURU_PTLS_NAMES 1 -#endif -#endif -#endif - -#if LOGURU_WINTHREADS -#ifndef _WIN32_WINNT -#define _WIN32_WINNT 0x0502 -#endif -#define WIN32_LEAN_AND_MEAN -#define NOMINMAX -#include -#endif - -#ifndef LOGURU_PTLS_NAMES -#define LOGURU_PTLS_NAMES 0 -#endif - -LOGURU_ANONYMOUS_NAMESPACE_BEGIN - -namespace loguru { - using namespace std::chrono; - -#if LOGURU_WITH_FILEABS - struct FileAbs { - char path[PATH_MAX]; - char mode_str[4]; - Verbosity verbosity; - struct stat st; - FILE* fp; - bool is_reopening = false; // to prevent recursive call in file_reopen. - decltype(steady_clock::now()) last_check_time = steady_clock::now(); - }; -#else - typedef FILE* FileAbs; -#endif - - struct Callback { - std::string id; - log_handler_t callback; - void* user_data; - Verbosity verbosity; // Does not change! - close_handler_t close; - flush_handler_t flush; - unsigned indentation; - }; - - using CallbackVec = std::vector; - - using StringPair = std::pair; - using StringPairList = std::vector; - - const auto s_start_time = steady_clock::now(); - - Verbosity g_stderr_verbosity = Verbosity_0; - bool g_colorlogtostderr = true; - unsigned g_flush_interval_ms = 0; - bool g_preamble_header = true; - bool g_preamble = true; - - Verbosity g_internal_verbosity = Verbosity_0; - - // Preamble details - bool g_preamble_date = true; - bool g_preamble_time = true; - bool g_preamble_uptime = true; - bool g_preamble_thread = true; - bool g_preamble_file = true; - bool g_preamble_verbose = true; - bool g_preamble_pipe = true; - - static std::recursive_mutex s_mutex; - static Verbosity s_max_out_verbosity = Verbosity_OFF; - static std::string s_argv0_filename; - static std::string s_arguments; - static char s_current_dir[PATH_MAX]; - static CallbackVec s_callbacks; - static fatal_handler_t s_fatal_handler = nullptr; - static verbosity_to_name_t s_verbosity_to_name_callback = nullptr; - static name_to_verbosity_t s_name_to_verbosity_callback = nullptr; - static StringPairList s_user_stack_cleanups; - static bool s_strip_file_path = true; - static std::atomic s_stderr_indentation{ 0 }; - - // For periodic flushing: - static std::thread* s_flush_thread = nullptr; - static bool s_needs_flushing = false; - - static SignalOptions s_signal_options = SignalOptions::none(); - - static const bool s_terminal_has_color = []() { -#ifdef _WIN32 -#ifndef ENABLE_VIRTUAL_TERMINAL_PROCESSING -#define ENABLE_VIRTUAL_TERMINAL_PROCESSING 0x0004 -#endif - - HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE); - if (hOut != INVALID_HANDLE_VALUE) { - DWORD dwMode = 0; - GetConsoleMode(hOut, &dwMode); - dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; - return SetConsoleMode(hOut, dwMode) != 0; - } - return false; -#else - if (!isatty(STDERR_FILENO)) { - return false; - } - if (const char* term = getenv("TERM")) { - return 0 == strcmp(term, "cygwin") - || 0 == strcmp(term, "linux") - || 0 == strcmp(term, "rxvt-unicode-256color") - || 0 == strcmp(term, "screen") - || 0 == strcmp(term, "screen-256color") - || 0 == strcmp(term, "screen.xterm-256color") - || 0 == strcmp(term, "tmux-256color") - || 0 == strcmp(term, "xterm") - || 0 == strcmp(term, "xterm-256color") - || 0 == strcmp(term, "xterm-termite") - || 0 == strcmp(term, "xterm-color"); - } else { - return false; - } -#endif - }(); - - static void print_preamble_header(char* out_buff, size_t out_buff_size); - - // ------------------------------------------------------------------------------ - // Colors - - bool terminal_has_color() { return s_terminal_has_color; } - - // Colors - -#ifdef _WIN32 -#define VTSEQ(ID) ("\x1b[1;" #ID "m") -#else -#define VTSEQ(ID) ("\x1b[" #ID "m") -#endif - - const char* terminal_black() { return s_terminal_has_color ? VTSEQ(30) : ""; } - const char* terminal_red() { return s_terminal_has_color ? VTSEQ(31) : ""; } - const char* terminal_green() { return s_terminal_has_color ? VTSEQ(32) : ""; } - const char* terminal_yellow() { return s_terminal_has_color ? VTSEQ(33) : ""; } - const char* terminal_blue() { return s_terminal_has_color ? VTSEQ(34) : ""; } - const char* terminal_purple() { return s_terminal_has_color ? VTSEQ(35) : ""; } - const char* terminal_cyan() { return s_terminal_has_color ? VTSEQ(36) : ""; } - const char* terminal_light_gray() { return s_terminal_has_color ? VTSEQ(37) : ""; } - const char* terminal_white() { return s_terminal_has_color ? VTSEQ(37) : ""; } - const char* terminal_light_red() { return s_terminal_has_color ? VTSEQ(91) : ""; } - const char* terminal_dim() { return s_terminal_has_color ? VTSEQ(2) : ""; } - - // Formating - const char* terminal_bold() { return s_terminal_has_color ? VTSEQ(1) : ""; } - const char* terminal_underline() { return s_terminal_has_color ? VTSEQ(4) : ""; } - - // You should end each line with this! - const char* terminal_reset() { return s_terminal_has_color ? VTSEQ(0) : ""; } - - // ------------------------------------------------------------------------------ -#if LOGURU_WITH_FILEABS - void file_reopen(void* user_data); - inline FILE* to_file(void* user_data) { return reinterpret_cast(user_data)->fp; } -#else - inline FILE* to_file(void* user_data) { return reinterpret_cast(user_data); } -#endif - - void file_log(void* user_data, const Message& message) - { -#if LOGURU_WITH_FILEABS - FileAbs* file_abs = reinterpret_cast(user_data); - if (file_abs->is_reopening) { - return; - } - // It is better checking file change every minute/hour/day, - // instead of doing this every time we log. - // Here check_interval is set to zero to enable checking every time; - const auto check_interval = seconds(0); - if (duration_cast(steady_clock::now() - file_abs->last_check_time) > check_interval) { - file_abs->last_check_time = steady_clock::now(); - file_reopen(user_data); - } - FILE* file = to_file(user_data); - if (!file) { - return; - } -#else - FILE* file = to_file(user_data); -#endif - fprintf(file, "%s%s%s%s\n", - message.preamble, message.indentation, message.prefix, message.message); - if (g_flush_interval_ms == 0) { - fflush(file); - } - } - - void file_close(void* user_data) - { - FILE* file = to_file(user_data); - if (file) { - fclose(file); - } -#if LOGURU_WITH_FILEABS - delete reinterpret_cast(user_data); -#endif - } - - void file_flush(void* user_data) - { - FILE* file = to_file(user_data); - fflush(file); - } - -#if LOGURU_WITH_FILEABS - void file_reopen(void* user_data) - { - FileAbs* file_abs = reinterpret_cast(user_data); - struct stat st; - int ret; - if (!file_abs->fp || (ret = stat(file_abs->path, &st)) == -1 || (st.st_ino != file_abs->st.st_ino)) { - file_abs->is_reopening = true; - if (file_abs->fp) { - fclose(file_abs->fp); - } - if (!file_abs->fp) { - VLOG_F(g_internal_verbosity, "Reopening file '" LOGURU_FMT(s) "' due to previous error", file_abs->path); - } else if (ret < 0) { - const auto why = errno_as_text(); - VLOG_F(g_internal_verbosity, "Reopening file '" LOGURU_FMT(s) "' due to '" LOGURU_FMT(s) "'", file_abs->path, why.c_str()); - } else { - VLOG_F(g_internal_verbosity, "Reopening file '" LOGURU_FMT(s) "' due to file changed", file_abs->path); - } - // try reopen current file. - if (!create_directories(file_abs->path)) { - LOG_F(ERROR, "Failed to create directories to '" LOGURU_FMT(s) "'", file_abs->path); - } - file_abs->fp = fopen(file_abs->path, file_abs->mode_str); - if (!file_abs->fp) { - LOG_F(ERROR, "Failed to open '" LOGURU_FMT(s) "'", file_abs->path); - } else { - stat(file_abs->path, &file_abs->st); - } - file_abs->is_reopening = false; - } - } -#endif - // ------------------------------------------------------------------------------ - // ------------------------------------------------------------------------------ -#if LOGURU_SYSLOG - void syslog_log(void* /*user_data*/, const Message& message) - { - /* - Level 0: Is reserved for kernel panic type situations. - Level 1: Is for Major resource failure. - Level 2->7 Application level failures - */ - int level; - if (message.verbosity < Verbosity_FATAL) { - level = 1; // System Alert - } else { - switch (message.verbosity) { - case Verbosity_FATAL: level = 2; break; // System Critical - case Verbosity_ERROR: level = 3; break; // System Error - case Verbosity_WARNING: level = 4; break; // System Warning - case Verbosity_INFO: level = 5; break; // System Notice - case Verbosity_1: level = 6; break; // System Info - default: level = 7; break; // System Debug - } - } - - // Note: We don't add the time info. - // This is done automatically by the syslog deamon. - // Otherwise log all information that the file log does. - syslog(level, "%s%s%s", message.indentation, message.prefix, message.message); - } - - void syslog_close(void* /*user_data*/) - { - closelog(); - } - - void syslog_flush(void* /*user_data*/) - { - } -#endif - // ------------------------------------------------------------------------------ - // Helpers: - - Text::~Text() { free(_str); } - -#if LOGURU_USE_FMTLIB - Text vtextprintf(const char* format, fmt::format_args args) - { - return Text(STRDUP(fmt::vformat(format, args).c_str())); - } -#else - LOGURU_PRINTF_LIKE(1, 0) - static Text vtextprintf(const char* format, va_list vlist) - { -#ifdef _WIN32 - int bytes_needed = _vscprintf(format, vlist); - CHECK_F(bytes_needed >= 0, "Bad string format: '%s'", format); - char* buff = (char*)malloc(bytes_needed + 1); - vsnprintf(buff, bytes_needed + 1, format, vlist); - return Text(buff); -#else - char* buff = nullptr; - int result = vasprintf(&buff, format, vlist); - CHECK_F(result >= 0, "Bad string format: '" LOGURU_FMT(s) "'", format); - return Text(buff); -#endif - } - - Text textprintf(const char* format, ...) - { - va_list vlist; - va_start(vlist, format); - auto result = vtextprintf(format, vlist); - va_end(vlist); - return result; - } -#endif - - // Overloaded for variadic template matching. - Text textprintf() - { - return Text(static_cast(calloc(1, 1))); - } - - static const char* indentation(unsigned depth) - { - static const char buff[] = - ". . . . . . . . . . " ". . . . . . . . . . " - ". . . . . . . . . . " ". . . . . . . . . . " - ". . . . . . . . . . " ". . . . . . . . . . " - ". . . . . . . . . . " ". . . . . . . . . . " - ". . . . . . . . . . " ". . . . . . . . . . "; - static const size_t INDENTATION_WIDTH = 4; - static const size_t NUM_INDENTATIONS = (sizeof(buff) - 1) / INDENTATION_WIDTH; - depth = std::min(depth, NUM_INDENTATIONS); - return buff + INDENTATION_WIDTH * (NUM_INDENTATIONS - depth); - } - - static void parse_args(int& argc, char* argv[], const char* verbosity_flag) - { - int arg_dest = 1; - int out_argc = argc; - - for (int arg_it = 1; arg_it < argc; ++arg_it) { - auto cmd = argv[arg_it]; - auto arg_len = strlen(verbosity_flag); - - bool last_is_alpha = false; -#if LOGURU_USE_LOCALE - try { // locale variant of isalpha will throw on error - last_is_alpha = std::isalpha(cmd[arg_len], std::locale("")); - } - catch (...) { - last_is_alpha = std::isalpha(static_cast(cmd[arg_len])); - } -#else - last_is_alpha = std::isalpha(static_cast(cmd[arg_len])); -#endif - - if (strncmp(cmd, verbosity_flag, arg_len) == 0 && !last_is_alpha) { - out_argc -= 1; - auto value_str = cmd + arg_len; - if (value_str[0] == '\0') { - // Value in separate argument - arg_it += 1; - CHECK_LT_F(arg_it, argc, "Missing verbosiy level after " LOGURU_FMT(s) "", verbosity_flag); - value_str = argv[arg_it]; - out_argc -= 1; - } - if (*value_str == '=') { value_str += 1; } - - auto req_verbosity = get_verbosity_from_name(value_str); - if (req_verbosity != Verbosity_INVALID) { - g_stderr_verbosity = req_verbosity; - } else { - char* end = 0; - g_stderr_verbosity = static_cast(strtol(value_str, &end, 10)); - CHECK_F(end && *end == '\0', - "Invalid verbosity. Expected integer, INFO, WARNING, ERROR or OFF, got '" LOGURU_FMT(s) "'", value_str); - } - } else { - argv[arg_dest++] = argv[arg_it]; - } - } - - argc = out_argc; - argv[argc] = nullptr; - } - - static long long now_ns() - { - return duration_cast(high_resolution_clock::now().time_since_epoch()).count(); - } - - // Returns the part of the path after the last / or \ (if any). - const char* filename(const char* path) - { - for (auto ptr = path; *ptr; ++ptr) { - if (*ptr == '/' || *ptr == '\\') { - path = ptr + 1; - } - } - return path; - } - - // ------------------------------------------------------------------------------ - - static void on_atexit() - { - VLOG_F(g_internal_verbosity, "atexit"); - flush(); - } - - static void install_signal_handlers(const SignalOptions& signal_options); - - static void write_hex_digit(std::string& out, unsigned num) - { - DCHECK_LT_F(num, 16u); - if (num < 10u) { out.push_back(char('0' + num)); } else { out.push_back(char('A' + num - 10)); } - } - - static void write_hex_byte(std::string& out, uint8_t n) - { - write_hex_digit(out, n >> 4u); - write_hex_digit(out, n & 0x0f); - } - - static void escape(std::string& out, const std::string& str) - { - for (char c : str) { - /**/ if (c == '\a') { out += "\\a"; } else if (c == '\b') { out += "\\b"; } else if (c == '\f') { out += "\\f"; } else if (c == '\n') { out += "\\n"; } else if (c == '\r') { out += "\\r"; } else if (c == '\t') { out += "\\t"; } else if (c == '\v') { out += "\\v"; } else if (c == '\\') { out += "\\\\"; } else if (c == '\'') { out += "\\\'"; } else if (c == '\"') { out += "\\\""; } else if (c == ' ') { out += "\\ "; } else if (0 <= c && c < 0x20) { // ASCI control character: - // else if (c < 0x20 || c != (c & 127)) { // ASCII control character or UTF-8: - out += "\\x"; - write_hex_byte(out, static_cast(c)); - } else { out += c; } - } - } - - Text errno_as_text() - { - char buff[256]; -#if defined(__GLIBC__) && defined(_GNU_SOURCE) - // GNU Version - return Text(STRDUP(strerror_r(errno, buff, sizeof(buff)))); -#elif defined(__APPLE__) || _POSIX_C_SOURCE >= 200112L - // XSI Version - strerror_r(errno, buff, sizeof(buff)); - return Text(strdup(buff)); -#elif defined(_WIN32) - strerror_s(buff, sizeof(buff), errno); - return Text(STRDUP(buff)); -#else - // Not thread-safe. - return Text(STRDUP(strerror(errno))); -#endif - } - - void init(int& argc, char* argv[], const Options& options) - { - CHECK_GT_F(argc, 0, "Expected proper argc/argv"); - CHECK_EQ_F(argv[argc], nullptr, "Expected proper argc/argv"); - - s_argv0_filename = filename(argv[0]); - -#ifdef _WIN32 -#define getcwd _getcwd -#endif - - if (!getcwd(s_current_dir, sizeof(s_current_dir))) { - const auto error_text = errno_as_text(); - LOG_F(WARNING, "Failed to get current working directory: " LOGURU_FMT(s) "", error_text.c_str()); - } - - s_arguments = ""; - for (int i = 0; i < argc; ++i) { - escape(s_arguments, argv[i]); - if (i + 1 < argc) { - s_arguments += " "; - } - } - - if (options.verbosity_flag) { - parse_args(argc, argv, options.verbosity_flag); - } - - if (const auto main_thread_name = options.main_thread_name) { -#if LOGURU_PTLS_NAMES || LOGURU_WINTHREADS - set_thread_name(main_thread_name); -#elif LOGURU_PTHREADS - char old_thread_name[16] = { 0 }; - auto this_thread = pthread_self(); -#if defined(__APPLE__) || defined(__linux__) || defined(__sun) - pthread_getname_np(this_thread, old_thread_name, sizeof(old_thread_name)); -#endif - if (old_thread_name[0] == 0) { -#ifdef __APPLE__ - pthread_setname_np(main_thread_name); -#elif defined(__FreeBSD__) || defined(__OpenBSD__) - pthread_set_name_np(this_thread, main_thread_name); -#elif defined(__linux__) || defined(__sun) - pthread_setname_np(this_thread, main_thread_name); -#endif - } -#endif // LOGURU_PTHREADS - } - - if (g_stderr_verbosity >= Verbosity_INFO) { - if (g_preamble_header) { - char preamble_explain[LOGURU_PREAMBLE_WIDTH]; - print_preamble_header(preamble_explain, sizeof(preamble_explain)); - if (g_colorlogtostderr && s_terminal_has_color) { - fprintf(stderr, "%s%s%s\n", terminal_reset(), terminal_dim(), preamble_explain); - } else { - fprintf(stderr, "%s\n", preamble_explain); - } - } - fflush(stderr); - } - VLOG_F(g_internal_verbosity, "arguments: " LOGURU_FMT(s) "", s_arguments.c_str()); - if (strlen(s_current_dir) != 0) { - VLOG_F(g_internal_verbosity, "Current dir: " LOGURU_FMT(s) "", s_current_dir); - } - VLOG_F(g_internal_verbosity, "stderr verbosity: " LOGURU_FMT(d) "", g_stderr_verbosity); - VLOG_F(g_internal_verbosity, "-----------------------------------"); - - install_signal_handlers(options.signal_options); - - atexit(on_atexit); - } - - void shutdown() - { - VLOG_F(g_internal_verbosity, "loguru::shutdown()"); - remove_all_callbacks(); - set_fatal_handler(nullptr); - set_verbosity_to_name_callback(nullptr); - set_name_to_verbosity_callback(nullptr); - } - - void write_date_time(char* buff, unsigned long long buff_size) - { - auto now = system_clock::now(); - long long ms_since_epoch = duration_cast(now.time_since_epoch()).count(); - time_t sec_since_epoch = time_t(ms_since_epoch / 1000); - tm time_info; - localtime_r(&sec_since_epoch, &time_info); - snprintf(buff, buff_size, "%04d%02d%02d_%02d%02d%02d.%03lld", - 1900 + time_info.tm_year, 1 + time_info.tm_mon, time_info.tm_mday, - time_info.tm_hour, time_info.tm_min, time_info.tm_sec, ms_since_epoch % 1000); - } - - const char* argv0_filename() - { - return s_argv0_filename.c_str(); - } - - const char* arguments() - { - return s_arguments.c_str(); - } - - const char* current_dir() - { - return s_current_dir; - } - - const char* home_dir() - { -#ifdef __MINGW32__ - auto home = getenv("USERPROFILE"); - CHECK_F(home != nullptr, "Missing USERPROFILE"); - return home; -#elif defined(_WIN32) - char* user_profile; - size_t len; - errno_t err = _dupenv_s(&user_profile, &len, "USERPROFILE"); - CHECK_F(err == 0, "Missing USERPROFILE"); - return user_profile; -#else // _WIN32 - auto home = getenv("HOME"); - CHECK_F(home != nullptr, "Missing HOME"); - return home; -#endif // _WIN32 - } - - void suggest_log_path(const char* prefix, char* buff, unsigned long long buff_size) - { - if (prefix[0] == '~') { - snprintf(buff, buff_size - 1, "%s%s", home_dir(), prefix + 1); - } else { - snprintf(buff, buff_size - 1, "%s", prefix); - } - - // Check for terminating / - size_t n = strlen(buff); - if (n != 0) { - if (buff[n - 1] != '/') { - CHECK_F(n + 2 < buff_size, "Filename buffer too small"); - buff[n] = '/'; - buff[n + 1] = '\0'; - } - } - -#ifdef _WIN32 - strncat_s(buff, buff_size - strlen(buff) - 1, s_argv0_filename.c_str(), buff_size - strlen(buff) - 1); - strncat_s(buff, buff_size - strlen(buff) - 1, "/", buff_size - strlen(buff) - 1); - write_date_time(buff + strlen(buff), buff_size - strlen(buff)); - strncat_s(buff, buff_size - strlen(buff) - 1, ".log", buff_size - strlen(buff) - 1); -#else - strncat(buff, s_argv0_filename.c_str(), buff_size - strlen(buff) - 1); - strncat(buff, "/", buff_size - strlen(buff) - 1); - write_date_time(buff + strlen(buff), buff_size - strlen(buff)); - strncat(buff, ".log", buff_size - strlen(buff) - 1); -#endif - } - - bool create_directories(const char* file_path_const) - { - CHECK_F(file_path_const && *file_path_const); - char* file_path = STRDUP(file_path_const); - for (char* p = strchr(file_path + 1, '/'); p; p = strchr(p + 1, '/')) { - *p = '\0'; - -#ifdef _WIN32 - if (_mkdir(file_path) == -1) { -#else - if (mkdir(file_path, 0755) == -1) { -#endif - if (errno != EEXIST) { - LOG_F(ERROR, "Failed to create directory '" LOGURU_FMT(s) "'", file_path); - LOG_IF_F(ERROR, errno == EACCES, "EACCES"); - LOG_IF_F(ERROR, errno == ENAMETOOLONG, "ENAMETOOLONG"); - LOG_IF_F(ERROR, errno == ENOENT, "ENOENT"); - LOG_IF_F(ERROR, errno == ENOTDIR, "ENOTDIR"); - LOG_IF_F(ERROR, errno == ELOOP, "ELOOP"); - - *p = '/'; - free(file_path); - return false; - } - } - *p = '/'; - } - free(file_path); - return true; - } - bool add_file(const char* path_in, FileMode mode, Verbosity verbosity) - { - char path[PATH_MAX]; - if (path_in[0] == '~') { - snprintf(path, sizeof(path) - 1, "%s%s", home_dir(), path_in + 1); - } else { - snprintf(path, sizeof(path) - 1, "%s", path_in); - } - - if (!create_directories(path)) { - LOG_F(ERROR, "Failed to create directories to '" LOGURU_FMT(s) "'", path); - } - - const char* mode_str = (mode == FileMode::Truncate ? "w" : "a"); - FILE* file; -#ifdef _WIN32 - file = _fsopen(path, mode_str, _SH_DENYNO); -#else - file = fopen(path, mode_str); -#endif - if (!file) { - LOG_F(ERROR, "Failed to open '" LOGURU_FMT(s) "'", path); - return false; - } -#if LOGURU_WITH_FILEABS - FileAbs* file_abs = new FileAbs(); // this is deleted in file_close; - snprintf(file_abs->path, sizeof(file_abs->path) - 1, "%s", path); - snprintf(file_abs->mode_str, sizeof(file_abs->mode_str) - 1, "%s", mode_str); - stat(file_abs->path, &file_abs->st); - file_abs->fp = file; - file_abs->verbosity = verbosity; - add_callback(path_in, file_log, file_abs, verbosity, file_close, file_flush); -#else - add_callback(path_in, file_log, file, verbosity, file_close, file_flush); -#endif - - if (mode == FileMode::Append) { - fprintf(file, "\n\n\n\n\n"); - } - if (!s_arguments.empty()) { - fprintf(file, "arguments: %s\n", s_arguments.c_str()); - } - if (strlen(s_current_dir) != 0) { - fprintf(file, "Current dir: %s\n", s_current_dir); - } - fprintf(file, "File verbosity level: %d\n", verbosity); - if (g_preamble_header) { - char preamble_explain[LOGURU_PREAMBLE_WIDTH]; - print_preamble_header(preamble_explain, sizeof(preamble_explain)); - fprintf(file, "%s\n", preamble_explain); - } - fflush(file); - - VLOG_F(g_internal_verbosity, "Logging to '" LOGURU_FMT(s) "', mode: '" LOGURU_FMT(s) "', verbosity: " LOGURU_FMT(d) "", path, mode_str, verbosity); - return true; - } - - /* - Will add syslog as a standard sink for log messages - Any logging message with a verbosity lower or equal to - the given verbosity will be included. - - This works for Unix like systems (i.e. Linux/Mac) - There is no current implementation for Windows (as I don't know the - equivalent calls or have a way to test them). If you know please - add and send a pull request. - - The code should still compile under windows but will only generate - a warning message that syslog is unavailable. - - Search for LOGURU_SYSLOG to find and fix. - */ - bool add_syslog(const char* app_name, Verbosity verbosity) - { - return add_syslog(app_name, verbosity, LOG_USER); - } - bool add_syslog(const char* app_name, Verbosity verbosity, int facility) - { -#if LOGURU_SYSLOG - if (app_name == nullptr) { - app_name = argv0_filename(); - } - openlog(app_name, 0, facility); - add_callback("'syslog'", syslog_log, nullptr, verbosity, syslog_close, syslog_flush); - - VLOG_F(g_internal_verbosity, "Logging to 'syslog' , verbosity: " LOGURU_FMT(d) "", verbosity); - return true; -#else - (void)app_name; - (void)verbosity; - (void)facility; - VLOG_F(g_internal_verbosity, "syslog not implemented on this system. Request to install syslog logging ignored."); - return false; -#endif - } - // Will be called right before abort(). - void set_fatal_handler(fatal_handler_t handler) - { - s_fatal_handler = handler; - } - - fatal_handler_t get_fatal_handler() - { - return s_fatal_handler; - } - - void set_verbosity_to_name_callback(verbosity_to_name_t callback) - { - s_verbosity_to_name_callback = callback; - } - - void set_name_to_verbosity_callback(name_to_verbosity_t callback) - { - s_name_to_verbosity_callback = callback; - } - - void add_stack_cleanup(const char* find_this, const char* replace_with_this) - { - if (strlen(find_this) <= strlen(replace_with_this)) { - LOG_F(WARNING, "add_stack_cleanup: the replacement should be shorter than the pattern!"); - return; - } - - s_user_stack_cleanups.push_back(StringPair(find_this, replace_with_this)); - } - - static void on_callback_change() - { - s_max_out_verbosity = Verbosity_OFF; - for (const auto& callback : s_callbacks) { - s_max_out_verbosity = std::max(s_max_out_verbosity, callback.verbosity); - } - } - - void add_callback( - const char* id, - log_handler_t callback, - void* user_data, - Verbosity verbosity, - close_handler_t on_close, - flush_handler_t on_flush) - { - std::lock_guard lock(s_mutex); - s_callbacks.push_back(Callback{ id, callback, user_data, verbosity, on_close, on_flush, 0 }); - on_callback_change(); - } - - // Returns a custom verbosity name if one is available, or nullptr. - // See also set_verbosity_to_name_callback. - const char* get_verbosity_name(Verbosity verbosity) - { - auto name = s_verbosity_to_name_callback - ? (*s_verbosity_to_name_callback)(verbosity) - : nullptr; - - // Use standard replacements if callback fails: - if (!name) { - if (verbosity <= Verbosity_FATAL) { - name = "FATL"; - } else if (verbosity == Verbosity_ERROR) { - name = "ERR"; - } else if (verbosity == Verbosity_WARNING) { - name = "WARN"; - } else if (verbosity == Verbosity_INFO) { - name = "INFO"; - } - } - - return name; - } - - // Returns Verbosity_INVALID if the name is not found. - // See also set_name_to_verbosity_callback. - Verbosity get_verbosity_from_name(const char* name) - { - auto verbosity = s_name_to_verbosity_callback - ? (*s_name_to_verbosity_callback)(name) - : Verbosity_INVALID; - - // Use standard replacements if callback fails: - if (verbosity == Verbosity_INVALID) { - if (strcmp(name, "OFF") == 0) { - verbosity = Verbosity_OFF; - } else if (strcmp(name, "INFO") == 0) { - verbosity = Verbosity_INFO; - } else if (strcmp(name, "WARNING") == 0) { - verbosity = Verbosity_WARNING; - } else if (strcmp(name, "ERROR") == 0) { - verbosity = Verbosity_ERROR; - } else if (strcmp(name, "FATAL") == 0) { - verbosity = Verbosity_FATAL; - } - } - - return verbosity; - } - - bool remove_callback(const char* id) - { - std::lock_guard lock(s_mutex); - auto it = std::find_if(begin(s_callbacks), end(s_callbacks), [&](const Callback& c) { return c.id == id; }); - if (it != s_callbacks.end()) { - if (it->close) { it->close(it->user_data); } - s_callbacks.erase(it); - on_callback_change(); - return true; - } else { - LOG_F(ERROR, "Failed to locate callback with id '" LOGURU_FMT(s) "'", id); - return false; - } - } - - void remove_all_callbacks() - { - std::lock_guard lock(s_mutex); - for (auto& callback : s_callbacks) { - if (callback.close) { - callback.close(callback.user_data); - } - } - s_callbacks.clear(); - on_callback_change(); - } - - // Returns the maximum of g_stderr_verbosity and all file/custom outputs. - Verbosity current_verbosity_cutoff() - { - return g_stderr_verbosity > s_max_out_verbosity ? - g_stderr_verbosity : s_max_out_verbosity; - } - - // ------------------------------------------------------------------------ - // Threads names - -#if LOGURU_PTLS_NAMES - static pthread_once_t s_pthread_key_once = PTHREAD_ONCE_INIT; - static pthread_key_t s_pthread_key_name; - - void make_pthread_key_name() - { - (void)pthread_key_create(&s_pthread_key_name, free); - } -#endif - -#if LOGURU_WINTHREADS - // Where we store the custom thread name set by `set_thread_name` - char* thread_name_buffer() - { - __declspec(thread) static char thread_name[LOGURU_THREADNAME_WIDTH + 1] = { 0 }; - return &thread_name[0]; - } -#endif // LOGURU_WINTHREADS - - void set_thread_name(const char* name) - { -#if LOGURU_PTLS_NAMES - // Store thread name in thread-local storage at `s_pthread_key_name` - (void)pthread_once(&s_pthread_key_once, make_pthread_key_name); - (void)pthread_setspecific(s_pthread_key_name, STRDUP(name)); -#elif LOGURU_PTHREADS - // Tell the OS the thread name -#ifdef __APPLE__ - pthread_setname_np(name); -#elif defined(__FreeBSD__) || defined(__OpenBSD__) - pthread_set_name_np(pthread_self(), name); -#elif defined(__linux__) || defined(__sun) - pthread_setname_np(pthread_self(), name); -#endif -#elif LOGURU_WINTHREADS - // Store thread name in a thread-local storage: - strncpy_s(thread_name_buffer(), LOGURU_THREADNAME_WIDTH + 1, name, _TRUNCATE); -#else // LOGURU_PTHREADS - // TODO: on these weird platforms we should also store the thread name - // in a generic thread-local storage. - (void)name; -#endif // LOGURU_PTHREADS - } - - void get_thread_name(char* buffer, unsigned long long length, bool right_align_hex_id) - { - CHECK_NE_F(length, 0u, "Zero length buffer in get_thread_name"); - CHECK_NOTNULL_F(buffer, "nullptr in get_thread_name"); - -#if LOGURU_PTLS_NAMES - (void)pthread_once(&s_pthread_key_once, make_pthread_key_name); - if (const char* name = static_cast(pthread_getspecific(s_pthread_key_name))) { - snprintf(buffer, static_cast(length), "%s", name); - } else { - buffer[0] = 0; - } -#elif LOGURU_PTHREADS - // Ask the OS about the thread name. - // This is what we *want* to do on all platforms, but - // only some platforms support it (currently). - pthread_getname_np(pthread_self(), buffer, length); -#elif LOGURU_WINTHREADS - snprintf(buffer, static_cast(length), "%s", thread_name_buffer()); -#else - // Thread names unsupported - buffer[0] = 0; -#endif - - if (buffer[0] == 0) { - // We failed to get a readable thread name. - // Write a HEX thread ID instead. - // We try to get an ID that is the same as the ID you could - // read in your debugger, system monitor etc. - -#ifdef __APPLE__ - uint64_t thread_id; - pthread_threadid_np(pthread_self(), &thread_id); -#elif defined(__FreeBSD__) - long thread_id; - (void)thr_self(&thread_id); -#elif LOGURU_PTHREADS - uint64_t thread_id = pthread_self(); -#else - // This ID does not correllate to anything we can get from the OS, - // so this is the worst way to get the ID. - const auto thread_id = std::hash{}(std::this_thread::get_id()); -#endif - - if (right_align_hex_id) { - snprintf(buffer, static_cast(length), "%*X", static_cast(length - 1), static_cast(thread_id)); - } else { - snprintf(buffer, static_cast(length), "%X", static_cast(thread_id)); - } - } - } - - // ------------------------------------------------------------------------ - // Stack traces - -#if LOGURU_STACKTRACES - Text demangle(const char* name) - { - int status = -1; - char* demangled = abi::__cxa_demangle(name, 0, 0, &status); - Text result{ status == 0 ? demangled : STRDUP(name) }; - return result; - } - -#if LOGURU_RTTI - template - std::string type_name() - { - auto demangled = demangle(typeid(T).name()); - return demangled.c_str(); - } -#endif // LOGURU_RTTI - - static const StringPairList REPLACE_LIST = { - #if LOGURU_RTTI - { type_name(), "std::string" }, - { type_name(), "std::wstring" }, - { type_name(), "std::u16string" }, - { type_name(), "std::u32string" }, - #endif // LOGURU_RTTI - { "std::__1::", "std::" }, - { "__thiscall ", "" }, - { "__cdecl ", "" }, - }; - - void do_replacements(const StringPairList & replacements, std::string & str) - { - for (auto&& p : replacements) { - if (p.first.size() <= p.second.size()) { - // On gcc, "type_name()" is "std::string" - continue; - } - - size_t it; - while ((it = str.find(p.first)) != std::string::npos) { - str.replace(it, p.first.size(), p.second); - } - } - } - - std::string prettify_stacktrace(const std::string & input) - { - std::string output = input; - - do_replacements(s_user_stack_cleanups, output); - do_replacements(REPLACE_LIST, output); - - try { - std::regex std_allocator_re(R"(,\s*std::allocator<[^<>]+>)"); - output = std::regex_replace(output, std_allocator_re, std::string("")); - - std::regex template_spaces_re(R"(<\s*([^<> ]+)\s*>)"); - output = std::regex_replace(output, template_spaces_re, std::string("<$1>")); - } - catch (std::regex_error&) { - // Probably old GCC. - } - - return output; - } - - std::string stacktrace_as_stdstring(int skip) - { - // From https://gist.github.com/fmela/591333 - void* callstack[128]; - const auto max_frames = sizeof(callstack) / sizeof(callstack[0]); - int num_frames = backtrace(callstack, max_frames); - char** symbols = backtrace_symbols(callstack, num_frames); - - std::string result; - // Print stack traces so the most relevant ones are written last - // Rationale: http://yellerapp.com/posts/2015-01-22-upside-down-stacktraces.html - for (int i = num_frames - 1; i >= skip; --i) { - char buf[1024]; - Dl_info info; - if (dladdr(callstack[i], &info) && info.dli_sname) { - char* demangled = NULL; - int status = -1; - if (info.dli_sname[0] == '_') { - demangled = abi::__cxa_demangle(info.dli_sname, 0, 0, &status); - } - snprintf(buf, sizeof(buf), "%-3d %*p %s + %zd\n", - i - skip, int(2 + sizeof(void*) * 2), callstack[i], - status == 0 ? demangled : - info.dli_sname == 0 ? symbols[i] : info.dli_sname, - static_cast(callstack[i]) - static_cast(info.dli_saddr)); - free(demangled); - } else { - snprintf(buf, sizeof(buf), "%-3d %*p %s\n", - i - skip, int(2 + sizeof(void*) * 2), callstack[i], symbols[i]); - } - result += buf; - } - free(symbols); - - if (num_frames == max_frames) { - result = "[truncated]\n" + result; - } - - if (!result.empty() && result[result.size() - 1] == '\n') { - result.resize(result.size() - 1); - } - - return prettify_stacktrace(result); - } - -#else // LOGURU_STACKTRACES - Text demangle(const char* name) - { - return Text(STRDUP(name)); - } - - std::string stacktrace_as_stdstring(int) - { - // No stacktraces available on this platform" - return ""; - } - -#endif // LOGURU_STACKTRACES - - Text stacktrace(int skip) - { - auto str = stacktrace_as_stdstring(skip + 1); - return Text(STRDUP(str.c_str())); - } - - // ------------------------------------------------------------------------ - - static void print_preamble_header(char* out_buff, size_t out_buff_size) - { - if (out_buff_size == 0) { return; } - out_buff[0] = '\0'; - size_t pos = 0; - if (g_preamble_date && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "date "); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_time && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "time "); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_uptime && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "( uptime ) "); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_thread && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "[%-*s]", LOGURU_THREADNAME_WIDTH, " thread name/id"); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_file && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "%*s:line ", LOGURU_FILENAME_WIDTH, "file"); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_verbose && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, " v"); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_pipe && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "| "); - if (bytes > 0) { - pos += bytes; - } - } - } - - static void print_preamble(char* out_buff, size_t out_buff_size, Verbosity verbosity, const char* file, unsigned line) - { - if (out_buff_size == 0) { return; } - out_buff[0] = '\0'; - if (!g_preamble) { return; } - long long ms_since_epoch = duration_cast(system_clock::now().time_since_epoch()).count(); - time_t sec_since_epoch = time_t(ms_since_epoch / 1000); - tm time_info; - localtime_r(&sec_since_epoch, &time_info); - - auto uptime_ms = duration_cast(steady_clock::now() - s_start_time).count(); - auto uptime_sec = static_cast (uptime_ms) / 1000.0; - - char thread_name[LOGURU_THREADNAME_WIDTH + 1] = { 0 }; - get_thread_name(thread_name, LOGURU_THREADNAME_WIDTH + 1, true); - - if (s_strip_file_path) { - file = filename(file); - } - - char level_buff[6]; - const char* custom_level_name = get_verbosity_name(verbosity); - if (custom_level_name) { - snprintf(level_buff, sizeof(level_buff) - 1, "%s", custom_level_name); - } else { - snprintf(level_buff, sizeof(level_buff) - 1, "% 4d", static_cast(verbosity)); - } - - size_t pos = 0; - - if (g_preamble_date && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "%04d-%02d-%02d ", - 1900 + time_info.tm_year, 1 + time_info.tm_mon, time_info.tm_mday); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_time && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "%02d:%02d:%02d.%03lld ", - time_info.tm_hour, time_info.tm_min, time_info.tm_sec, ms_since_epoch % 1000); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_uptime && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "(%8.3fs) ", - uptime_sec); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_thread && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "[%-*s]", - LOGURU_THREADNAME_WIDTH, thread_name); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_file && pos < out_buff_size) { - char shortened_filename[LOGURU_FILENAME_WIDTH + 1]; - snprintf(shortened_filename, LOGURU_FILENAME_WIDTH + 1, "%s", file); - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "%*s:%-5u ", - LOGURU_FILENAME_WIDTH, shortened_filename, line); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_verbose && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "%4s", - level_buff); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_pipe && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "| "); - if (bytes > 0) { - pos += bytes; - } - } - } - - // stack_trace_skip is just if verbosity == FATAL. - static void log_message(int stack_trace_skip, Message & message, bool with_indentation, bool abort_if_fatal) - { - const auto verbosity = message.verbosity; - std::lock_guard lock(s_mutex); - - if (message.verbosity == Verbosity_FATAL) { - auto st = loguru::stacktrace(stack_trace_skip + 2); - if (!st.empty()) { - RAW_LOG_F(ERROR, "Stack trace:\n" LOGURU_FMT(s) "", st.c_str()); - } - - auto ec = loguru::get_error_context(); - if (!ec.empty()) { - RAW_LOG_F(ERROR, "" LOGURU_FMT(s) "", ec.c_str()); - } - } - - if (with_indentation) { - message.indentation = indentation(s_stderr_indentation); - } - - if (verbosity <= g_stderr_verbosity) { - if (g_colorlogtostderr && s_terminal_has_color) { - if (verbosity > Verbosity_WARNING) { - fprintf(stderr, "%s%s%s%s%s%s%s%s\n", - terminal_reset(), - terminal_dim(), - message.preamble, - message.indentation, - verbosity == Verbosity_INFO ? terminal_reset() : "", // un-dim for info - message.prefix, - message.message, - terminal_reset()); - } else { - fprintf(stderr, "%s%s%s%s%s%s%s\n", - terminal_reset(), - verbosity == Verbosity_WARNING ? terminal_yellow() : terminal_red(), - message.preamble, - message.indentation, - message.prefix, - message.message, - terminal_reset()); - } - } else { - fprintf(stderr, "%s%s%s%s\n", - message.preamble, message.indentation, message.prefix, message.message); - } - - if (g_flush_interval_ms == 0) { - fflush(stderr); - } else { - s_needs_flushing = true; - } - } - - for (auto& p : s_callbacks) { - if (verbosity <= p.verbosity) { - if (with_indentation) { - message.indentation = indentation(p.indentation); - } - p.callback(p.user_data, message); - if (g_flush_interval_ms == 0) { - if (p.flush) { p.flush(p.user_data); } - } else { - s_needs_flushing = true; - } - } - } - - if (g_flush_interval_ms > 0 && !s_flush_thread) { - s_flush_thread = new std::thread([]() { - for (;;) { - if (s_needs_flushing) { - flush(); - } - std::this_thread::sleep_for(std::chrono::milliseconds(g_flush_interval_ms)); - } - }); - } - - if (message.verbosity == Verbosity_FATAL) { - flush(); - - if (s_fatal_handler) { - s_fatal_handler(message); - flush(); - } - - if (abort_if_fatal) { -#if !defined(_WIN32) - if (s_signal_options.sigabrt) { - // Make sure we don't catch our own abort: - signal(SIGABRT, SIG_DFL); - } -#endif - abort(); - } - } - } - - // stack_trace_skip is just if verbosity == FATAL. - void log_to_everywhere(int stack_trace_skip, Verbosity verbosity, - const char* file, unsigned line, - const char* prefix, const char* buff) - { - char preamble_buff[LOGURU_PREAMBLE_WIDTH]; - print_preamble(preamble_buff, sizeof(preamble_buff), verbosity, file, line); - auto message = Message{ verbosity, file, line, preamble_buff, "", prefix, buff }; - log_message(stack_trace_skip + 1, message, true, true); - } - -#if LOGURU_USE_FMTLIB - void vlog(Verbosity verbosity, const char* file, unsigned line, const char* format, fmt::format_args args) - { - auto formatted = fmt::vformat(format, args); - log_to_everywhere(1, verbosity, file, line, "", formatted.c_str()); - } - - void raw_vlog(Verbosity verbosity, const char* file, unsigned line, const char* format, fmt::format_args args) - { - auto formatted = fmt::vformat(format, args); - auto message = Message{ verbosity, file, line, "", "", "", formatted.c_str() }; - log_message(1, message, false, true); - } -#else - void log(Verbosity verbosity, const char* file, unsigned line, const char* format, ...) - { - va_list vlist; - va_start(vlist, format); - vlog(verbosity, file, line, format, vlist); - va_end(vlist); - } - - void vlog(Verbosity verbosity, const char* file, unsigned line, const char* format, va_list vlist) - { - auto buff = vtextprintf(format, vlist); - log_to_everywhere(1, verbosity, file, line, "", buff.c_str()); - } - - void raw_log(Verbosity verbosity, const char* file, unsigned line, const char* format, ...) - { - va_list vlist; - va_start(vlist, format); - auto buff = vtextprintf(format, vlist); - auto message = Message{ verbosity, file, line, "", "", "", buff.c_str() }; - log_message(1, message, false, true); - va_end(vlist); - } -#endif - - void flush() - { - std::lock_guard lock(s_mutex); - fflush(stderr); - for (const auto& callback : s_callbacks) { - if (callback.flush) { - callback.flush(callback.user_data); - } - } - s_needs_flushing = false; - } - - LogScopeRAII::LogScopeRAII(Verbosity verbosity, const char* file, unsigned line, const char* format, va_list vlist) : - _verbosity(verbosity), _file(file), _line(line) - { - this->Init(format, vlist); - } - - LogScopeRAII::LogScopeRAII(Verbosity verbosity, const char* file, unsigned line, const char* format, ...) : - _verbosity(verbosity), _file(file), _line(line) - { - va_list vlist; - va_start(vlist, format); - this->Init(format, vlist); - va_end(vlist); - } - - LogScopeRAII::~LogScopeRAII() - { - if (_file) { - std::lock_guard lock(s_mutex); - if (_indent_stderr && s_stderr_indentation > 0) { - --s_stderr_indentation; - } - for (auto& p : s_callbacks) { - // Note: Callback indentation cannot change! - if (_verbosity <= p.verbosity) { - // in unlikely case this callback is new - if (p.indentation > 0) { - --p.indentation; - } - } - } -#if LOGURU_VERBOSE_SCOPE_ENDINGS - auto duration_sec = static_cast(now_ns() - _start_time_ns) / 1e9; -#if LOGURU_USE_FMTLIB - auto buff = textprintf("{:.{}f} s: {:s}", duration_sec, LOGURU_SCOPE_TIME_PRECISION, _name); -#else - auto buff = textprintf("%.*f s: %s", LOGURU_SCOPE_TIME_PRECISION, duration_sec, _name); -#endif - log_to_everywhere(1, _verbosity, _file, _line, "} ", buff.c_str()); -#else - log_to_everywhere(1, _verbosity, _file, _line, "}", ""); -#endif - } - } - - void LogScopeRAII::Init(const char* format, va_list vlist) - { - if (_verbosity <= current_verbosity_cutoff()) { - std::lock_guard lock(s_mutex); - _indent_stderr = (_verbosity <= g_stderr_verbosity); - _start_time_ns = now_ns(); - vsnprintf(_name, sizeof(_name), format, vlist); - log_to_everywhere(1, _verbosity, _file, _line, "{ ", _name); - - if (_indent_stderr) { - ++s_stderr_indentation; - } - - for (auto& p : s_callbacks) { - if (_verbosity <= p.verbosity) { - ++p.indentation; - } - } - } else { - _file = nullptr; - } - } - -#if LOGURU_USE_FMTLIB - void vlog_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line, const char* format, fmt::format_args args) - { - auto formatted = fmt::vformat(format, args); - log_to_everywhere(stack_trace_skip + 1, Verbosity_FATAL, file, line, expr, formatted.c_str()); - abort(); // log_to_everywhere already does this, but this makes the analyzer happy. - } -#else - void log_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line, const char* format, ...) - { - va_list vlist; - va_start(vlist, format); - auto buff = vtextprintf(format, vlist); - log_to_everywhere(stack_trace_skip + 1, Verbosity_FATAL, file, line, expr, buff.c_str()); - va_end(vlist); - abort(); // log_to_everywhere already does this, but this makes the analyzer happy. - } -#endif - - void log_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line) - { - log_and_abort(stack_trace_skip + 1, expr, file, line, " "); - } - - // ---------------------------------------------------------------------------- - // Streams: - -#if LOGURU_USE_FMTLIB - template - std::string vstrprintf(const char* format, const Args&... args) - { - auto text = textprintf(format, args...); - std::string result = text.c_str(); - return result; - } - - template - std::string strprintf(const char* format, const Args&... args) - { - return vstrprintf(format, args...); - } -#else - std::string vstrprintf(const char* format, va_list vlist) - { - auto text = vtextprintf(format, vlist); - std::string result = text.c_str(); - return result; - } - - std::string strprintf(const char* format, ...) - { - va_list vlist; - va_start(vlist, format); - auto result = vstrprintf(format, vlist); - va_end(vlist); - return result; - } -#endif - -#if LOGURU_WITH_STREAMS - - StreamLogger::~StreamLogger() noexcept(false) - { - auto message = _ss.str(); - log(_verbosity, _file, _line, LOGURU_FMT(s), message.c_str()); - } - - AbortLogger::~AbortLogger() noexcept(false) - { - auto message = _ss.str(); - loguru::log_and_abort(1, _expr, _file, _line, LOGURU_FMT(s), message.c_str()); - } - -#endif // LOGURU_WITH_STREAMS - - // ---------------------------------------------------------------------------- - // 888888 88""Yb 88""Yb dP"Yb 88""Yb dP""b8 dP"Yb 88b 88 888888 888888 Yb dP 888888 - // 88__ 88__dP 88__dP dP Yb 88__dP dP `" dP Yb 88Yb88 88 88__ YbdP 88 - // 88"" 88"Yb 88"Yb Yb dP 88"Yb Yb Yb dP 88 Y88 88 88"" dPYb 88 - // 888888 88 Yb 88 Yb YbodP 88 Yb YboodP YbodP 88 Y8 88 888888 dP Yb 88 - // ---------------------------------------------------------------------------- - - struct StringStream { - std::string str; - }; - - // Use this in your EcPrinter implementations. - void stream_print(StringStream & out_string_stream, const char* text) - { - out_string_stream.str += text; - } - - // ---------------------------------------------------------------------------- - - using ECPtr = EcEntryBase*; - -#if defined(_WIN32) || (defined(__APPLE__) && !TARGET_OS_IPHONE) -#ifdef __APPLE__ -#define LOGURU_THREAD_LOCAL __thread -#else -#define LOGURU_THREAD_LOCAL thread_local -#endif - static LOGURU_THREAD_LOCAL ECPtr thread_ec_ptr = nullptr; - - ECPtr& get_thread_ec_head_ref() - { - return thread_ec_ptr; - } -#else // !thread_local - static pthread_once_t s_ec_pthread_once = PTHREAD_ONCE_INIT; - static pthread_key_t s_ec_pthread_key; - - void free_ec_head_ref(void* io_error_context) - { - delete reinterpret_cast(io_error_context); - } - - void ec_make_pthread_key() - { - (void)pthread_key_create(&s_ec_pthread_key, free_ec_head_ref); - } - - ECPtr& get_thread_ec_head_ref() - { - (void)pthread_once(&s_ec_pthread_once, ec_make_pthread_key); - auto ec = reinterpret_cast(pthread_getspecific(s_ec_pthread_key)); - if (ec == nullptr) { - ec = new ECPtr(nullptr); - (void)pthread_setspecific(s_ec_pthread_key, ec); - } - return *ec; - } -#endif // !thread_local - - // ---------------------------------------------------------------------------- - - EcHandle get_thread_ec_handle() - { - return get_thread_ec_head_ref(); - } - - Text get_error_context() - { - return get_error_context_for(get_thread_ec_head_ref()); - } - - Text get_error_context_for(const EcEntryBase * ec_head) - { - std::vector stack; - while (ec_head) { - stack.push_back(ec_head); - ec_head = ec_head->_previous; - } - std::reverse(stack.begin(), stack.end()); - - StringStream result; - if (!stack.empty()) { - result.str += "------------------------------------------------\n"; - for (auto entry : stack) { - const auto description = std::string(entry->_descr) + ":"; -#if LOGURU_USE_FMTLIB - auto prefix = textprintf("[ErrorContext] {.{}s}:{:-5u} {:-20s} ", - filename(entry->_file), LOGURU_FILENAME_WIDTH, entry->_line, description.c_str()); -#else - auto prefix = textprintf("[ErrorContext] %*s:%-5u %-20s ", - LOGURU_FILENAME_WIDTH, filename(entry->_file), entry->_line, description.c_str()); -#endif - result.str += prefix.c_str(); - entry->print_value(result); - result.str += "\n"; - } - result.str += "------------------------------------------------"; - } - return Text(STRDUP(result.str.c_str())); - } - - EcEntryBase::EcEntryBase(const char* file, unsigned line, const char* descr) - : _file(file), _line(line), _descr(descr) - { - EcEntryBase*& ec_head = get_thread_ec_head_ref(); - _previous = ec_head; - ec_head = this; - } - - EcEntryBase::~EcEntryBase() - { - get_thread_ec_head_ref() = _previous; - } - - // ------------------------------------------------------------------------ - - Text ec_to_text(const char* value) - { - // Add quotes around the string to make it obvious where it begin and ends. - // This is great for detecting erroneous leading or trailing spaces in e.g. an identifier. - auto str = "\"" + std::string(value) + "\""; - return Text{ STRDUP(str.c_str()) }; - } - - Text ec_to_text(char c) - { - // Add quotes around the character to make it obvious where it begin and ends. - std::string str = "'"; - - auto write_hex_digit = [&](unsigned num) - { - if (num < 10u) { str += char('0' + num); } else { str += char('a' + num - 10); } - }; - - auto write_hex_16 = [&](uint16_t n) - { - write_hex_digit((n >> 12u) & 0x0f); - write_hex_digit((n >> 8u) & 0x0f); - write_hex_digit((n >> 4u) & 0x0f); - write_hex_digit((n >> 0u) & 0x0f); - }; - - if (c == '\\') { str += "\\\\"; } else if (c == '\"') { str += "\\\""; } else if (c == '\'') { str += "\\\'"; } else if (c == '\0') { str += "\\0"; } else if (c == '\b') { str += "\\b"; } else if (c == '\f') { str += "\\f"; } else if (c == '\n') { str += "\\n"; } else if (c == '\r') { str += "\\r"; } else if (c == '\t') { str += "\\t"; } else if (0 <= c && c < 0x20) { - str += "\\u"; - write_hex_16(static_cast(c)); - } else { str += c; } - - str += "'"; - - return Text{ STRDUP(str.c_str()) }; - } - -#define DEFINE_EC(Type) \ - Text ec_to_text(Type value) \ - { \ - auto str = std::to_string(value); \ - return Text{STRDUP(str.c_str())}; \ - } - - DEFINE_EC(int) - DEFINE_EC(unsigned int) - DEFINE_EC(long) - DEFINE_EC(unsigned long) - DEFINE_EC(long long) - DEFINE_EC(unsigned long long) - DEFINE_EC(float) - DEFINE_EC(double) - DEFINE_EC(long double) - -#undef DEFINE_EC - - Text ec_to_text(EcHandle ec_handle) - { - Text parent_ec = get_error_context_for(ec_handle); - size_t buffer_size = strlen(parent_ec.c_str()) + 2; - char* with_newline = reinterpret_cast(malloc(buffer_size)); - with_newline[0] = '\n'; -#ifdef _WIN32 - strncpy_s(with_newline + 1, buffer_size, parent_ec.c_str(), buffer_size - 2); -#else - strcpy(with_newline + 1, parent_ec.c_str()); -#endif - return Text(with_newline); - } - - // ---------------------------------------------------------------------------- - -} // namespace loguru - -// ---------------------------------------------------------------------------- -// .dP"Y8 88 dP""b8 88b 88 db 88 .dP"Y8 -// `Ybo." 88 dP `" 88Yb88 dPYb 88 `Ybo." -// o.`Y8b 88 Yb "88 88 Y88 dP__Yb 88 .o o.`Y8b -// 8bodP' 88 YboodP 88 Y8 dP""""Yb 88ood8 8bodP' -// ---------------------------------------------------------------------------- - -#ifdef _WIN32 -namespace loguru { - void install_signal_handlers(const SignalOptions& signal_options) - { - (void)signal_options; - // TODO: implement signal handlers on windows - } -} // namespace loguru - -#else // _WIN32 - -namespace loguru { - void write_to_stderr(const char* data, size_t size) - { - auto result = write(STDERR_FILENO, data, size); - (void)result; // Ignore errors. - } - - void write_to_stderr(const char* data) - { - write_to_stderr(data, strlen(data)); - } - - void call_default_signal_handler(int signal_number) - { - struct sigaction sig_action; - memset(&sig_action, 0, sizeof(sig_action)); - sigemptyset(&sig_action.sa_mask); - sig_action.sa_handler = SIG_DFL; - sigaction(signal_number, &sig_action, NULL); - kill(getpid(), signal_number); - } - - void signal_handler(int signal_number, siginfo_t*, void*) - { - const char* signal_name = "UNKNOWN SIGNAL"; - - if (signal_number == SIGABRT) { signal_name = "SIGABRT"; } - if (signal_number == SIGBUS) { signal_name = "SIGBUS"; } - if (signal_number == SIGFPE) { signal_name = "SIGFPE"; } - if (signal_number == SIGILL) { signal_name = "SIGILL"; } - if (signal_number == SIGINT) { signal_name = "SIGINT"; } - if (signal_number == SIGSEGV) { signal_name = "SIGSEGV"; } - if (signal_number == SIGTERM) { signal_name = "SIGTERM"; } - - // -------------------------------------------------------------------- - /* There are few things that are safe to do in a signal handler, - but writing to stderr is one of them. - So we first print out what happened to stderr so we're sure that gets out, - then we do the unsafe things, like logging the stack trace. - */ - - if (g_colorlogtostderr && s_terminal_has_color) { - write_to_stderr(terminal_reset()); - write_to_stderr(terminal_bold()); - write_to_stderr(terminal_light_red()); - } - write_to_stderr("\n"); - write_to_stderr("Loguru caught a signal: "); - write_to_stderr(signal_name); - write_to_stderr("\n"); - if (g_colorlogtostderr && s_terminal_has_color) { - write_to_stderr(terminal_reset()); - } - - // -------------------------------------------------------------------- - - if (s_signal_options.unsafe_signal_handler) { - // -------------------------------------------------------------------- - /* Now we do unsafe things. This can for example lead to deadlocks if - the signal was triggered from the system's memory management functions - and the code below tries to do allocations. - */ - - flush(); - char preamble_buff[LOGURU_PREAMBLE_WIDTH]; - print_preamble(preamble_buff, sizeof(preamble_buff), Verbosity_FATAL, "", 0); - auto message = Message{ Verbosity_FATAL, "", 0, preamble_buff, "", "Signal: ", signal_name }; - try { - log_message(1, message, false, false); - } - catch (...) { - // This can happed due to s_fatal_handler. - write_to_stderr("Exception caught and ignored by Loguru signal handler.\n"); - } - flush(); - - // -------------------------------------------------------------------- - } - - call_default_signal_handler(signal_number); - } - - void install_signal_handlers(const SignalOptions& signal_options) - { - s_signal_options = signal_options; - - struct sigaction sig_action; - memset(&sig_action, 0, sizeof(sig_action)); - sigemptyset(&sig_action.sa_mask); - sig_action.sa_flags |= SA_SIGINFO; - sig_action.sa_sigaction = &signal_handler; - - if (signal_options.sigabrt) { - CHECK_F(sigaction(SIGABRT, &sig_action, NULL) != -1, "Failed to install handler for SIGABRT"); - } - if (signal_options.sigbus) { - CHECK_F(sigaction(SIGBUS, &sig_action, NULL) != -1, "Failed to install handler for SIGBUS"); - } - if (signal_options.sigfpe) { - CHECK_F(sigaction(SIGFPE, &sig_action, NULL) != -1, "Failed to install handler for SIGFPE"); - } - if (signal_options.sigill) { - CHECK_F(sigaction(SIGILL, &sig_action, NULL) != -1, "Failed to install handler for SIGILL"); - } - if (signal_options.sigint) { - CHECK_F(sigaction(SIGINT, &sig_action, NULL) != -1, "Failed to install handler for SIGINT"); - } - if (signal_options.sigsegv) { - CHECK_F(sigaction(SIGSEGV, &sig_action, NULL) != -1, "Failed to install handler for SIGSEGV"); - } - if (signal_options.sigterm) { - CHECK_F(sigaction(SIGTERM, &sig_action, NULL) != -1, "Failed to install handler for SIGTERM"); - } - } -} // namespace loguru - -#endif // _WIN32 - - -#if defined(__GNUC__) || defined(__clang__) -#pragma GCC diagnostic pop -#elif defined(_MSC_VER) -#pragma warning(pop) -#endif - -LOGURU_ANONYMOUS_NAMESPACE_END - -#endif // LOGURU_IMPLEMENTATION diff --git a/lib/log/loguru.hpp b/lib/log/loguru.hpp deleted file mode 100644 index 8917b79..0000000 --- a/lib/log/loguru.hpp +++ /dev/null @@ -1,1475 +0,0 @@ -/* -Loguru logging library for C++, by Emil Ernerfeldt. -www.github.com/emilk/loguru -If you find Loguru useful, please let me know on twitter or in a mail! -Twitter: @ernerfeldt -Mail: emil.ernerfeldt@gmail.com -Website: www.ilikebigbits.com - -# License - This software is in the public domain. Where that dedication is not - recognized, you are granted a perpetual, irrevocable license to - copy, modify and distribute it as you see fit. - -# Inspiration - Much of Loguru was inspired by GLOG, https://code.google.com/p/google-glog/. - The choice of public domain is fully due Sean T. Barrett - and his wonderful stb libraries at https://github.com/nothings/stb. - -# Version history - * Version 0.1.0 - 2015-03-22 - Works great on Mac. - * Version 0.2.0 - 2015-09-17 - Removed the only dependency. - * Version 0.3.0 - 2015-10-02 - Drop-in replacement for most of GLOG - * Version 0.4.0 - 2015-10-07 - Single-file! - * Version 0.5.0 - 2015-10-17 - Improved file logging - * Version 0.6.0 - 2015-10-24 - Add stack traces - * Version 0.7.0 - 2015-10-27 - Signals - * Version 0.8.0 - 2015-10-30 - Color logging. - * Version 0.9.0 - 2015-11-26 - ABORT_S and proper handling of FATAL - * Version 1.0.0 - 2016-02-14 - ERROR_CONTEXT - * Version 1.1.0 - 2016-02-19 - -v OFF, -v INFO etc - * Version 1.1.1 - 2016-02-20 - textprintf vs strprintf - * Version 1.1.2 - 2016-02-22 - Remove g_alsologtostderr - * Version 1.1.3 - 2016-02-29 - ERROR_CONTEXT as linked list - * Version 1.2.0 - 2016-03-19 - Add get_thread_name() - * Version 1.2.1 - 2016-03-20 - Minor fixes - * Version 1.2.2 - 2016-03-29 - Fix issues with set_fatal_handler throwing an exception - * Version 1.2.3 - 2016-05-16 - Log current working directory in loguru::init(). - * Version 1.2.4 - 2016-05-18 - Custom replacement for -v in loguru::init() by bjoernpollex - * Version 1.2.5 - 2016-05-18 - Add ability to print ERROR_CONTEXT of parent thread. - * Version 1.2.6 - 2016-05-19 - Bug fix regarding VLOG verbosity argument lacking (). - * Version 1.2.7 - 2016-05-23 - Fix PATH_MAX problem. - * Version 1.2.8 - 2016-05-26 - Add shutdown() and remove_all_callbacks() - * Version 1.2.9 - 2016-06-09 - Use a monotonic clock for uptime. - * Version 1.3.0 - 2016-07-20 - Fix issues with callback flush/close not being called. - * Version 1.3.1 - 2016-07-20 - Add LOGURU_UNSAFE_SIGNAL_HANDLER to toggle stacktrace on signals. - * Version 1.3.2 - 2016-07-20 - Add loguru::arguments() - * Version 1.4.0 - 2016-09-15 - Semantic versioning + add loguru::create_directories - * Version 1.4.1 - 2016-09-29 - Customize formating with LOGURU_FILENAME_WIDTH - * Version 1.5.0 - 2016-12-22 - LOGURU_USE_FMTLIB by kolis and LOGURU_WITH_FILEABS by scinart - * Version 1.5.1 - 2017-08-08 - Terminal colors on Windows 10 thanks to looki - * Version 1.6.0 - 2018-01-03 - Add LOGURU_RTTI and LOGURU_STACKTRACES settings - * Version 1.7.0 - 2018-01-03 - Add ability to turn off the preamble with loguru::g_preamble - * Version 1.7.1 - 2018-04-05 - Add function get_fatal_handler - * Version 1.7.2 - 2018-04-22 - Fix a bug where large file names could cause stack corruption (thanks @ccamporesi) - * Version 1.8.0 - 2018-04-23 - Shorten long file names to keep preamble fixed width - * Version 1.9.0 - 2018-09-22 - Adjust terminal colors, add LOGURU_VERBOSE_SCOPE_ENDINGS, add LOGURU_SCOPE_TIME_PRECISION, add named log levels - * Version 2.0.0 - 2018-09-22 - Split loguru.hpp into loguru.hpp and loguru.cpp - * Version 2.1.0 - 2019-09-23 - Update fmtlib + add option to loguru::init to NOT set main thread name. - * Version 2.2.0 - 2020-07-31 - Replace LOGURU_CATCH_SIGABRT with struct SignalOptions - -# Compiling - Just include where you want to use Loguru. - Then, in one .cpp file #include - Make sure you compile with -std=c++11 -lstdc++ -lpthread -ldl - -# Usage - For details, please see the official documentation at emilk.github.io/loguru - - #include - - int main(int argc, char* argv[]) { - loguru::init(argc, argv); - - // Put every log message in "everything.log": - loguru::add_file("everything.log", loguru::Append, loguru::Verbosity_MAX); - - LOG_F(INFO, "The magic number is %d", 42); - } - -*/ - -#if defined(LOGURU_IMPLEMENTATION) -#error "You are defining LOGURU_IMPLEMENTATION. This is for older versions of Loguru. You should now instead include loguru.cpp (or build it and link with it)" -#endif - -// Disable all warnings from gcc/clang: -#if defined(__clang__) -#pragma clang system_header -#elif defined(__GNUC__) -#pragma GCC system_header -#endif - -#ifndef LOGURU_HAS_DECLARED_FORMAT_HEADER -#define LOGURU_HAS_DECLARED_FORMAT_HEADER - -// Semantic versioning. Loguru version can be printed with printf("%d.%d.%d", LOGURU_VERSION_MAJOR, LOGURU_VERSION_MINOR, LOGURU_VERSION_PATCH); -#define LOGURU_VERSION_MAJOR 2 -#define LOGURU_VERSION_MINOR 1 -#define LOGURU_VERSION_PATCH 0 - -#if defined(_MSC_VER) -#include // Needed for _In_z_ etc annotations -#endif - -#if defined(__linux__) || defined(__APPLE__) -#define LOGURU_SYSLOG 1 -#else -#define LOGURU_SYSLOG 0 -#endif - -// ---------------------------------------------------------------------------- - -#ifndef LOGURU_EXPORT - // Define to your project's export declaration if needed for use in a shared library. -#define LOGURU_EXPORT -#endif - -#ifndef LOGURU_SCOPE_TEXT_SIZE - // Maximum length of text that can be printed by a LOG_SCOPE. - // This should be long enough to get most things, but short enough not to clutter the stack. -#define LOGURU_SCOPE_TEXT_SIZE 196 -#endif - -#ifndef LOGURU_FILENAME_WIDTH - // Width of the column containing the file name -#define LOGURU_FILENAME_WIDTH 23 -#endif - -#ifndef LOGURU_THREADNAME_WIDTH - // Width of the column containing the thread name -#define LOGURU_THREADNAME_WIDTH 16 -#endif - -#ifndef LOGURU_SCOPE_TIME_PRECISION - // Resolution of scope timers. 3=ms, 6=us, 9=ns -#define LOGURU_SCOPE_TIME_PRECISION 3 -#endif - -#ifdef LOGURU_CATCH_SIGABRT -#error "You are defining LOGURU_CATCH_SIGABRT. This is for older versions of Loguru. You should now instead set the options passed to loguru::init" -#endif - -#ifndef LOGURU_VERBOSE_SCOPE_ENDINGS - // Show milliseconds and scope name at end of scope. -#define LOGURU_VERBOSE_SCOPE_ENDINGS 1 -#endif - -#ifndef LOGURU_REDEFINE_ASSERT -#define LOGURU_REDEFINE_ASSERT 0 -#endif - -#ifndef LOGURU_WITH_STREAMS -#define LOGURU_WITH_STREAMS 0 -#endif - -#ifndef LOGURU_REPLACE_GLOG -#define LOGURU_REPLACE_GLOG 0 -#endif - -#if LOGURU_REPLACE_GLOG -#undef LOGURU_WITH_STREAMS -#define LOGURU_WITH_STREAMS 1 -#endif - -#if defined(LOGURU_UNSAFE_SIGNAL_HANDLER) -#error "You are defining LOGURU_UNSAFE_SIGNAL_HANDLER. This is for older versions of Loguru. You should now instead set the unsafe_signal_handler option when you call loguru::init." -#endif - -#if LOGURU_IMPLEMENTATION -#undef LOGURU_WITH_STREAMS -#define LOGURU_WITH_STREAMS 1 -#endif - -#ifndef LOGURU_USE_FMTLIB -#define LOGURU_USE_FMTLIB 0 -#endif - -#ifndef LOGURU_USE_LOCALE -#define LOGURU_USE_LOCALE 0 -#endif - -#ifndef LOGURU_WITH_FILEABS -#define LOGURU_WITH_FILEABS 0 -#endif - -#ifndef LOGURU_RTTI -#if defined(__clang__) -#if __has_feature(cxx_rtti) -#define LOGURU_RTTI 1 -#endif -#elif defined(__GNUG__) -#if defined(__GXX_RTTI) -#define LOGURU_RTTI 1 -#endif -#elif defined(_MSC_VER) -#if defined(_CPPRTTI) -#define LOGURU_RTTI 1 -#endif -#endif -#endif - -#ifdef LOGURU_USE_ANONYMOUS_NAMESPACE -#define LOGURU_ANONYMOUS_NAMESPACE_BEGIN namespace { -#define LOGURU_ANONYMOUS_NAMESPACE_END } -#else -#define LOGURU_ANONYMOUS_NAMESPACE_BEGIN -#define LOGURU_ANONYMOUS_NAMESPACE_END -#endif - -// -------------------------------------------------------------------- -// Utility macros - -#define LOGURU_CONCATENATE_IMPL(s1, s2) s1 ## s2 -#define LOGURU_CONCATENATE(s1, s2) LOGURU_CONCATENATE_IMPL(s1, s2) - -#ifdef __COUNTER__ -# define LOGURU_ANONYMOUS_VARIABLE(str) LOGURU_CONCATENATE(str, __COUNTER__) -#else -# define LOGURU_ANONYMOUS_VARIABLE(str) LOGURU_CONCATENATE(str, __LINE__) -#endif - -#if defined(__clang__) || defined(__GNUC__) - // Helper macro for declaring functions as having similar signature to printf. - // This allows the compiler to catch format errors at compile-time. -#define LOGURU_PRINTF_LIKE(fmtarg, firstvararg) __attribute__((__format__ (__printf__, fmtarg, firstvararg))) -#define LOGURU_FORMAT_STRING_TYPE const char* -#elif defined(_MSC_VER) -#define LOGURU_PRINTF_LIKE(fmtarg, firstvararg) -#define LOGURU_FORMAT_STRING_TYPE _In_z_ _Printf_format_string_ const char* -#else -#define LOGURU_PRINTF_LIKE(fmtarg, firstvararg) -#define LOGURU_FORMAT_STRING_TYPE const char* -#endif - -// Used to mark log_and_abort for the benefit of the static analyzer and optimizer. -#if defined(_MSC_VER) -#define LOGURU_NORETURN __declspec(noreturn) -#else -#define LOGURU_NORETURN __attribute__((noreturn)) -#endif - -#if defined(_MSC_VER) -#define LOGURU_PREDICT_FALSE(x) (x) -#define LOGURU_PREDICT_TRUE(x) (x) -#else -#define LOGURU_PREDICT_FALSE(x) (__builtin_expect(x, 0)) -#define LOGURU_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) -#endif - -#if LOGURU_USE_FMTLIB -#include -#define LOGURU_FMT(x) "{:" #x "}" -#else -#define LOGURU_FMT(x) "%" #x -#endif - -#ifdef _WIN32 -#define STRDUP(str) _strdup(str) -#else -#define STRDUP(str) strdup(str) -#endif - -#include - -// -------------------------------------------------------------------- -LOGURU_ANONYMOUS_NAMESPACE_BEGIN - -namespace loguru { - // Simple RAII ownership of a char*. - class LOGURU_EXPORT Text { - public: - explicit Text(char* owned_str) : _str(owned_str) {} - ~Text(); - Text(Text&& t) - { - _str = t._str; - t._str = nullptr; - } - Text(Text& t) = delete; - Text& operator=(Text& t) = delete; - void operator=(Text&& t) = delete; - - const char* c_str() const { return _str; } - bool empty() const { return _str == nullptr || *_str == '\0'; } - - char* release() - { - auto result = _str; - _str = nullptr; - return result; - } - - private: - char* _str; - }; - - // Like printf, but returns the formated text. -#if LOGURU_USE_FMTLIB - LOGURU_EXPORT - Text vtextprintf(const char* format, fmt::format_args args); - - template - LOGURU_EXPORT - Text textprintf(LOGURU_FORMAT_STRING_TYPE format, const Args&... args) - { - return vtextprintf(format, fmt::make_format_args(args...)); - } -#else - LOGURU_EXPORT - Text textprintf(LOGURU_FORMAT_STRING_TYPE format, ...) LOGURU_PRINTF_LIKE(1, 2); -#endif - - // Overloaded for variadic template matching. - LOGURU_EXPORT - Text textprintf(); - - using Verbosity = int; - -#undef FATAL -#undef ERROR -#undef WARNING -#undef INFO -#undef MAX - - enum NamedVerbosity : Verbosity { - // Used to mark an invalid verbosity. Do not log to this level. - Verbosity_INVALID = -10, // Never do LOG_F(INVALID) - - // You may use Verbosity_OFF on g_stderr_verbosity, but for nothing else! - Verbosity_OFF = -9, // Never do LOG_F(OFF) - - // Prefer to use ABORT_F or ABORT_S over LOG_F(FATAL) or LOG_S(FATAL). - Verbosity_FATAL = -3, - Verbosity_ERROR = -2, - Verbosity_WARNING = -1, - - // Normal messages. By default written to stderr. - Verbosity_INFO = 0, - - // Same as Verbosity_INFO in every way. - Verbosity_0 = 0, - - // Verbosity levels 1-9 are generally not written to stderr, but are written to file. - Verbosity_1 = +1, - Verbosity_2 = +2, - Verbosity_3 = +3, - Verbosity_4 = +4, - Verbosity_5 = +5, - Verbosity_6 = +6, - Verbosity_7 = +7, - Verbosity_8 = +8, - Verbosity_9 = +9, - - // Do not use higher verbosity levels, as that will make grepping log files harder. - Verbosity_MAX = +9, - }; - - struct Message { - // You would generally print a Message by just concatenating the buffers without spacing. - // Optionally, ignore preamble and indentation. - Verbosity verbosity; // Already part of preamble - const char* filename; // Already part of preamble - unsigned line; // Already part of preamble - const char* preamble; // Date, time, uptime, thread, file:line, verbosity. - const char* indentation; // Just a bunch of spacing. - const char* prefix; // Assertion failure info goes here (or ""). - const char* message; // User message goes here. - }; - - /* Everything with a verbosity equal or greater than g_stderr_verbosity will be - written to stderr. You can set this in code or via the -v argument. - Set to loguru::Verbosity_OFF to write nothing to stderr. - Default is 0, i.e. only log ERROR, WARNING and INFO are written to stderr. - */ - LOGURU_EXPORT extern Verbosity g_stderr_verbosity; - LOGURU_EXPORT extern bool g_colorlogtostderr; // True by default. - LOGURU_EXPORT extern unsigned g_flush_interval_ms; // 0 (unbuffered) by default. - LOGURU_EXPORT extern bool g_preamble_header; // Prepend each log start by a descriptions line with all columns name? True by default. - LOGURU_EXPORT extern bool g_preamble; // Prefix each log line with date, time etc? True by default. - - /* Specify the verbosity used by loguru to log its info messages including the header - logged when logged::init() is called or on exit. Default is 0 (INFO). - */ - LOGURU_EXPORT extern Verbosity g_internal_verbosity; - - // Turn off individual parts of the preamble - LOGURU_EXPORT extern bool g_preamble_date; // The date field - LOGURU_EXPORT extern bool g_preamble_time; // The time of the current day - LOGURU_EXPORT extern bool g_preamble_uptime; // The time since init call - LOGURU_EXPORT extern bool g_preamble_thread; // The logging thread - LOGURU_EXPORT extern bool g_preamble_file; // The file from which the log originates from - LOGURU_EXPORT extern bool g_preamble_verbose; // The verbosity field - LOGURU_EXPORT extern bool g_preamble_pipe; // The pipe symbol right before the message - - // May not throw! - typedef void (*log_handler_t)(void* user_data, const Message& message); - typedef void (*close_handler_t)(void* user_data); - typedef void (*flush_handler_t)(void* user_data); - - // May throw if that's how you'd like to handle your errors. - typedef void (*fatal_handler_t)(const Message& message); - - // Given a verbosity level, return the level's name or nullptr. - typedef const char* (*verbosity_to_name_t)(Verbosity verbosity); - - // Given a verbosity level name, return the verbosity level or - // Verbosity_INVALID if name is not recognized. - typedef Verbosity(*name_to_verbosity_t)(const char* name); - - struct SignalOptions { - /// Make Loguru try to do unsafe but useful things, - /// like printing a stack trace, when catching signals. - /// This may lead to bad things like deadlocks in certain situations. - bool unsafe_signal_handler = true; - - /// Should Loguru catch SIGABRT ? - bool sigabrt = true; - - /// Should Loguru catch SIGBUS ? - bool sigbus = true; - - /// Should Loguru catch SIGFPE ? - bool sigfpe = true; - - /// Should Loguru catch SIGILL ? - bool sigill = true; - - /// Should Loguru catch SIGINT ? - bool sigint = true; - - /// Should Loguru catch SIGSEGV ? - bool sigsegv = true; - - /// Should Loguru catch SIGTERM ? - bool sigterm = true; - - static SignalOptions none() - { - SignalOptions options; - options.unsafe_signal_handler = false; - options.sigabrt = false; - options.sigbus = false; - options.sigfpe = false; - options.sigill = false; - options.sigint = false; - options.sigsegv = false; - options.sigterm = false; - return options; - } - }; - - // Runtime options passed to loguru::init - struct Options { - // This allows you to use something else instead of "-v" via verbosity_flag. - // Set to nullptr if you don't want Loguru to parse verbosity from the args. - const char* verbosity_flag = "-v"; - - // loguru::init will set the name of the calling thread to this. - // If you don't want Loguru to set the name of the main thread, - // set this to nullptr. - // NOTE: on SOME platforms loguru::init will only overwrite the thread name - // if a thread name has not already been set. - // To always set a thread name, use loguru::set_thread_name instead. - const char* main_thread_name = "main thread"; - - SignalOptions signal_options; - }; - - /* Should be called from the main thread. - You don't *need* to call this, but if you do you get: - * Signal handlers installed - * Program arguments logged - * Working dir logged - * Optional -v verbosity flag parsed - * Main thread name set to "main thread" - * Explanation of the preamble (date, thread name, etc) logged - - loguru::init() will look for arguments meant for loguru and remove them. - Arguments meant for loguru are: - -v n Set loguru::g_stderr_verbosity level. Examples: - -v 3 Show verbosity level 3 and lower. - -v 0 Only show INFO, WARNING, ERROR, FATAL (default). - -v INFO Only show INFO, WARNING, ERROR, FATAL (default). - -v WARNING Only show WARNING, ERROR, FATAL. - -v ERROR Only show ERROR, FATAL. - -v FATAL Only show FATAL. - -v OFF Turn off logging to stderr. - - Tip: You can set g_stderr_verbosity before calling loguru::init. - That way you can set the default but have the user override it with the -v flag. - Note that -v does not affect file logging (see loguru::add_file). - - You can you something other than the -v flag by setting the verbosity_flag option. - */ - LOGURU_EXPORT - void init(int& argc, char* argv[], const Options& options = {}); - - // Will call remove_all_callbacks(). After calling this, logging will still go to stderr. - // You generally don't need to call this. - LOGURU_EXPORT - void shutdown(); - - // What ~ will be replaced with, e.g. "/home/your_user_name/" - LOGURU_EXPORT - const char* home_dir(); - - /* Returns the name of the app as given in argv[0] but without leading path. - That is, if argv[0] is "../foo/app" this will return "app". - */ - LOGURU_EXPORT - const char* argv0_filename(); - - // Returns all arguments given to loguru::init(), but escaped with a single space as separator. - LOGURU_EXPORT - const char* arguments(); - - // Returns the path to the current working dir when loguru::init() was called. - LOGURU_EXPORT - const char* current_dir(); - - // Returns the part of the path after the last / or \ (if any). - LOGURU_EXPORT - const char* filename(const char* path); - - // e.g. "foo/bar/baz.ext" will create the directories "foo/" and "foo/bar/" - LOGURU_EXPORT - bool create_directories(const char* file_path_const); - - // Writes date and time with millisecond precision, e.g. "20151017_161503.123" - LOGURU_EXPORT - void write_date_time(char* buff, unsigned long long buff_size); - - // Helper: thread-safe version strerror - LOGURU_EXPORT - Text errno_as_text(); - - /* Given a prefix of e.g. "~/loguru/" this might return - "/home/your_username/loguru/app_name/20151017_161503.123.log" - - where "app_name" is a sanitized version of argv[0]. - */ - LOGURU_EXPORT - void suggest_log_path(const char* prefix, char* buff, unsigned long long buff_size); - - enum FileMode { Truncate, Append }; - - /* Will log to a file at the given path. - Any logging message with a verbosity lower or equal to - the given verbosity will be included. - The function will create all directories in 'path' if needed. - If path starts with a ~, it will be replaced with loguru::home_dir() - To stop the file logging, just call loguru::remove_callback(path) with the same path. - */ - LOGURU_EXPORT - bool add_file(const char* path, FileMode mode, Verbosity verbosity); - - LOGURU_EXPORT - // Send logs to syslog with LOG_USER facility (see next call) - bool add_syslog(const char* app_name, Verbosity verbosity); - LOGURU_EXPORT - // Send logs to syslog with your own choice of facility (LOG_USER, LOG_AUTH, ...) - // see loguru.cpp: syslog_log() for more details. - bool add_syslog(const char* app_name, Verbosity verbosity, int facility); - - /* Will be called right before abort(). - You can for instance use this to print custom error messages, or throw an exception. - Feel free to call LOG:ing function from this, but not FATAL ones! */ - LOGURU_EXPORT - void set_fatal_handler(fatal_handler_t handler); - - // Get the current fatal handler, if any. Default value is nullptr. - LOGURU_EXPORT - fatal_handler_t get_fatal_handler(); - - /* Will be called on each log messages with a verbosity less or equal to the given one. - Useful for displaying messages on-screen in a game, for example. - The given on_close is also expected to flush (if desired). - */ - LOGURU_EXPORT - void add_callback( - const char* id, - log_handler_t callback, - void* user_data, - Verbosity verbosity, - close_handler_t on_close = nullptr, - flush_handler_t on_flush = nullptr); - - /* Set a callback that returns custom verbosity level names. If callback - is nullptr or returns nullptr, default log names will be used. - */ - LOGURU_EXPORT - void set_verbosity_to_name_callback(verbosity_to_name_t callback); - - /* Set a callback that returns the verbosity level matching a name. The - callback should return Verbosity_INVALID if the name is not - recognized. - */ - LOGURU_EXPORT - void set_name_to_verbosity_callback(name_to_verbosity_t callback); - - /* Get a custom name for a specific verbosity, if one exists, or nullptr. */ - LOGURU_EXPORT - const char* get_verbosity_name(Verbosity verbosity); - - /* Get the verbosity enum value from a custom 4-character level name, if one exists. - If the name does not match a custom level name, Verbosity_INVALID is returned. - */ - LOGURU_EXPORT - Verbosity get_verbosity_from_name(const char* name); - - // Returns true iff the callback was found (and removed). - LOGURU_EXPORT - bool remove_callback(const char* id); - - // Shut down all file logging and any other callback hooks installed. - LOGURU_EXPORT - void remove_all_callbacks(); - - // Returns the maximum of g_stderr_verbosity and all file/custom outputs. - LOGURU_EXPORT - Verbosity current_verbosity_cutoff(); - -#if LOGURU_USE_FMTLIB - // Internal functions - LOGURU_EXPORT - void vlog(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, fmt::format_args args); - LOGURU_EXPORT - void raw_vlog(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, fmt::format_args args); - - // Actual logging function. Use the LOG macro instead of calling this directly. - template - LOGURU_EXPORT - void log(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, const Args &... args) - { - vlog(verbosity, file, line, format, fmt::make_format_args(args...)); - } - - // Log without any preamble or indentation. - template - LOGURU_EXPORT - void raw_log(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, const Args &... args) - { - raw_vlog(verbosity, file, line, format, fmt::make_format_args(args...)); - } -#else // LOGURU_USE_FMTLIB? - // Actual logging function. Use the LOG macro instead of calling this directly. - LOGURU_EXPORT - void log(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, ...) LOGURU_PRINTF_LIKE(4, 5); - - // Actual logging function. - LOGURU_EXPORT - void vlog(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, va_list) LOGURU_PRINTF_LIKE(4, 0); - - // Log without any preamble or indentation. - LOGURU_EXPORT - void raw_log(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, ...) LOGURU_PRINTF_LIKE(4, 5); -#endif // !LOGURU_USE_FMTLIB - - // Helper class for LOG_SCOPE_F - class LOGURU_EXPORT LogScopeRAII { - public: - LogScopeRAII() : _file(nullptr) {} // No logging - LogScopeRAII(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, va_list vlist) LOGURU_PRINTF_LIKE(5, 0); - LogScopeRAII(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, ...) LOGURU_PRINTF_LIKE(5, 6); - ~LogScopeRAII(); - - void Init(LOGURU_FORMAT_STRING_TYPE format, va_list vlist) LOGURU_PRINTF_LIKE(2, 0); - -#if defined(_MSC_VER) && _MSC_VER > 1800 - // older MSVC default move ctors close the scope on move. See - // issue #43 - LogScopeRAII(LogScopeRAII&& other) - : _verbosity(other._verbosity) - , _file(other._file) - , _line(other._line) - , _indent_stderr(other._indent_stderr) - , _start_time_ns(other._start_time_ns) - { - // Make sure the tmp object's destruction doesn't close the scope: - other._file = nullptr; - - for (unsigned int i = 0; i < LOGURU_SCOPE_TEXT_SIZE; ++i) { - _name[i] = other._name[i]; - } - } -#else - LogScopeRAII(LogScopeRAII&&) = default; -#endif - - private: - LogScopeRAII(const LogScopeRAII&) = delete; - LogScopeRAII& operator=(const LogScopeRAII&) = delete; - void operator=(LogScopeRAII&&) = delete; - - Verbosity _verbosity; - const char* _file; // Set to null if we are disabled due to verbosity - unsigned _line; - bool _indent_stderr; // Did we? - long long _start_time_ns; - char _name[LOGURU_SCOPE_TEXT_SIZE]; - }; - - // Marked as 'noreturn' for the benefit of the static analyzer and optimizer. - // stack_trace_skip is the number of extrace stack frames to skip above log_and_abort. -#if LOGURU_USE_FMTLIB - LOGURU_EXPORT - LOGURU_NORETURN void vlog_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, fmt::format_args); - template - LOGURU_EXPORT - LOGURU_NORETURN void log_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, const Args&... args) - { - vlog_and_abort(stack_trace_skip, expr, file, line, format, fmt::make_format_args(args...)); - } -#else - LOGURU_EXPORT - LOGURU_NORETURN void log_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, ...) LOGURU_PRINTF_LIKE(5, 6); -#endif - LOGURU_EXPORT - LOGURU_NORETURN void log_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line); - - // Flush output to stderr and files. - // If g_flush_interval_ms is set to non-zero, this will be called automatically this often. - // If not set, you do not need to call this at all. - LOGURU_EXPORT - void flush(); - - template inline Text format_value(const T&) { return textprintf("N/A"); } - template<> inline Text format_value(const char& v) { return textprintf(LOGURU_FMT(c), v); } - template<> inline Text format_value(const int& v) { return textprintf(LOGURU_FMT(d), v); } - template<> inline Text format_value(const float& v) { return textprintf(LOGURU_FMT(f), v); } - template<> inline Text format_value(const double& v) { return textprintf(LOGURU_FMT(f), v); } - -#if LOGURU_USE_FMTLIB - template<> inline Text format_value(const unsigned int& v) { return textprintf(LOGURU_FMT(d), v); } - template<> inline Text format_value(const long& v) { return textprintf(LOGURU_FMT(d), v); } - template<> inline Text format_value(const unsigned long& v) { return textprintf(LOGURU_FMT(d), v); } - template<> inline Text format_value(const long long& v) { return textprintf(LOGURU_FMT(d), v); } - template<> inline Text format_value(const unsigned long long& v) { return textprintf(LOGURU_FMT(d), v); } -#else - template<> inline Text format_value(const unsigned int& v) { return textprintf(LOGURU_FMT(u), v); } - template<> inline Text format_value(const long& v) { return textprintf(LOGURU_FMT(lu), v); } - template<> inline Text format_value(const unsigned long& v) { return textprintf(LOGURU_FMT(ld), v); } - template<> inline Text format_value(const long long& v) { return textprintf(LOGURU_FMT(llu), v); } - template<> inline Text format_value(const unsigned long long& v) { return textprintf(LOGURU_FMT(lld), v); } -#endif - - /* Thread names can be set for the benefit of readable logs. - If you do not set the thread name, a hex id will be shown instead. - These thread names may or may not be the same as the system thread names, - depending on the system. - Try to limit the thread name to 15 characters or less. */ - LOGURU_EXPORT - void set_thread_name(const char* name); - - /* Returns the thread name for this thread. - On most *nix systems this will return the system thread name (settable from both within and without Loguru). - On other systems it will return whatever you set in `set_thread_name()`; - If no thread name is set, this will return a hexadecimal thread id. - `length` should be the number of bytes available in the buffer. - 17 is a good number for length. - `right_align_hex_id` means any hexadecimal thread id will be written to the end of buffer. - */ - LOGURU_EXPORT - void get_thread_name(char* buffer, unsigned long long length, bool right_align_hex_id); - - /* Generates a readable stacktrace as a string. - 'skip' specifies how many stack frames to skip. - For instance, the default skip (1) means: - don't include the call to loguru::stacktrace in the stack trace. */ - LOGURU_EXPORT - Text stacktrace(int skip = 1); - - /* Add a string to be replaced with something else in the stack output. - - For instance, instead of having a stack trace look like this: - 0x41f541 some_function(std::basic_ofstream >&) - You can clean it up with: - auto verbose_type_name = loguru::demangle(typeid(std::ofstream).name()); - loguru::add_stack_cleanup(verbose_type_name.c_str(); "std::ofstream"); - So the next time you will instead see: - 0x41f541 some_function(std::ofstream&) - - `replace_with_this` must be shorter than `find_this`. - */ - LOGURU_EXPORT - void add_stack_cleanup(const char* find_this, const char* replace_with_this); - - // Example: demangle(typeid(std::ofstream).name()) -> "std::basic_ofstream >" - LOGURU_EXPORT - Text demangle(const char* name); - - // ------------------------------------------------------------------------ - /* - Not all terminals support colors, but if they do, and g_colorlogtostderr - is set, Loguru will write them to stderr to make errors in red, etc. - - You also have the option to manually use them, via the function below. - - Note, however, that if you do, the color codes could end up in your logfile! - - This means if you intend to use them functions you should either: - * Use them on the stderr/stdout directly (bypass Loguru). - * Don't add file outputs to Loguru. - * Expect some \e[1m things in your logfile. - - Usage: - printf("%sRed%sGreen%sBold green%sClear again\n", - loguru::terminal_red(), loguru::terminal_green(), - loguru::terminal_bold(), loguru::terminal_reset()); - - If the terminal at hand does not support colors the above output - will just not have funky \e[1m things showing. - */ - - // Do the output terminal support colors? - LOGURU_EXPORT - bool terminal_has_color(); - - // Colors - LOGURU_EXPORT const char* terminal_black(); - LOGURU_EXPORT const char* terminal_red(); - LOGURU_EXPORT const char* terminal_green(); - LOGURU_EXPORT const char* terminal_yellow(); - LOGURU_EXPORT const char* terminal_blue(); - LOGURU_EXPORT const char* terminal_purple(); - LOGURU_EXPORT const char* terminal_cyan(); - LOGURU_EXPORT const char* terminal_light_gray(); - LOGURU_EXPORT const char* terminal_light_red(); - LOGURU_EXPORT const char* terminal_white(); - - // Formating - LOGURU_EXPORT const char* terminal_bold(); - LOGURU_EXPORT const char* terminal_underline(); - - // You should end each line with this! - LOGURU_EXPORT const char* terminal_reset(); - - // -------------------------------------------------------------------- - // Error context related: - - struct StringStream; - - // Use this in your EcEntryBase::print_value overload. - LOGURU_EXPORT - void stream_print(StringStream& out_string_stream, const char* text); - - class LOGURU_EXPORT EcEntryBase { - public: - EcEntryBase(const char* file, unsigned line, const char* descr); - ~EcEntryBase(); - EcEntryBase(const EcEntryBase&) = delete; - EcEntryBase(EcEntryBase&&) = delete; - EcEntryBase& operator=(const EcEntryBase&) = delete; - EcEntryBase& operator=(EcEntryBase&&) = delete; - - virtual void print_value(StringStream& out_string_stream) const = 0; - - EcEntryBase* previous() const { return _previous; } - - // private: - const char* _file; - unsigned _line; - const char* _descr; - EcEntryBase* _previous; - }; - - template - class EcEntryData : public EcEntryBase { - public: - using Printer = Text(*)(T data); - - EcEntryData(const char* file, unsigned line, const char* descr, T data, Printer&& printer) - : EcEntryBase(file, line, descr), _data(data), _printer(printer) - { - } - - virtual void print_value(StringStream& out_string_stream) const override - { - const auto str = _printer(_data); - stream_print(out_string_stream, str.c_str()); - } - - private: - T _data; - Printer _printer; - }; - - // template - // class EcEntryLambda : public EcEntryBase - // { - // public: - // EcEntryLambda(const char* file, unsigned line, const char* descr, Printer&& printer) - // : EcEntryBase(file, line, descr), _printer(std::move(printer)) {} - - // virtual void print_value(StringStream& out_string_stream) const override - // { - // const auto str = _printer(); - // stream_print(out_string_stream, str.c_str()); - // } - - // private: - // Printer _printer; - // }; - - // template - // EcEntryLambda make_ec_entry_lambda(const char* file, unsigned line, const char* descr, Printer&& printer) - // { - // return {file, line, descr, std::move(printer)}; - // } - - template - struct decay_char_array { using type = T; }; - - template - struct decay_char_array { using type = const char*; }; - - template - struct make_const_ptr { using type = T; }; - - template - struct make_const_ptr { using type = const T*; }; - - template - struct make_ec_type { using type = typename make_const_ptr::type>::type; }; - - /* A stack trace gives you the names of the function at the point of a crash. - With ERROR_CONTEXT, you can also get the values of select local variables. - Usage: - - void process_customers(const std::string& filename) - { - ERROR_CONTEXT("Processing file", filename.c_str()); - for (int customer_index : ...) - { - ERROR_CONTEXT("Customer index", customer_index); - ... - } - } - - The context is in effect during the scope of the ERROR_CONTEXT. - Use loguru::get_error_context() to get the contents of the active error contexts. - - Example result: - - ------------------------------------------------ - [ErrorContext] main.cpp:416 Processing file: "customers.json" - [ErrorContext] main.cpp:417 Customer index: 42 - ------------------------------------------------ - - Error contexts are printed automatically on crashes, and only on crashes. - This makes them much faster than logging the value of a variable. - */ -#define ERROR_CONTEXT(descr, data) \ - const loguru::EcEntryData::type> \ - LOGURU_ANONYMOUS_VARIABLE(error_context_scope_)( \ - __FILE__, __LINE__, descr, data, \ - static_cast::type>::Printer>(loguru::ec_to_text) ) // For better error messages - - /* - #define ERROR_CONTEXT(descr, data) \ - const auto LOGURU_ANONYMOUS_VARIABLE(error_context_scope_)( \ - loguru::make_ec_entry_lambda(__FILE__, __LINE__, descr, \ - [=](){ return loguru::ec_to_text(data); })) - */ - - using EcHandle = const EcEntryBase*; - - /* - Get a light-weight handle to the error context stack on this thread. - The handle is valid as long as the current thread has no changes to its error context stack. - You can pass the handle to loguru::get_error_context on another thread. - This can be very useful for when you have a parent thread spawning several working threads, - and you want the error context of the parent thread to get printed (too) when there is an - error on the child thread. You can accomplish this thusly: - - void foo(const char* parameter) - { - ERROR_CONTEXT("parameter", parameter) - const auto parent_ec_handle = loguru::get_thread_ec_handle(); - - std::thread([=]{ - loguru::set_thread_name("child thread"); - ERROR_CONTEXT("parent context", parent_ec_handle); - dangerous_code(); - }.join(); - } - - */ - LOGURU_EXPORT - EcHandle get_thread_ec_handle(); - - // Get a string describing the current stack of error context. Empty string if there is none. - LOGURU_EXPORT - Text get_error_context(); - - // Get a string describing the error context of the given thread handle. - LOGURU_EXPORT - Text get_error_context_for(EcHandle ec_handle); - - // ------------------------------------------------------------------------ - - LOGURU_EXPORT Text ec_to_text(const char* data); - LOGURU_EXPORT Text ec_to_text(char data); - LOGURU_EXPORT Text ec_to_text(int data); - LOGURU_EXPORT Text ec_to_text(unsigned int data); - LOGURU_EXPORT Text ec_to_text(long data); - LOGURU_EXPORT Text ec_to_text(unsigned long data); - LOGURU_EXPORT Text ec_to_text(long long data); - LOGURU_EXPORT Text ec_to_text(unsigned long long data); - LOGURU_EXPORT Text ec_to_text(float data); - LOGURU_EXPORT Text ec_to_text(double data); - LOGURU_EXPORT Text ec_to_text(long double data); - LOGURU_EXPORT Text ec_to_text(EcHandle); - - /* - You can add ERROR_CONTEXT support for your own types by overloading ec_to_text. Here's how: - - some.hpp: - namespace loguru { - Text ec_to_text(MySmallType data) - Text ec_to_text(const MyBigType* data) - } // namespace loguru - - some.cpp: - namespace loguru { - Text ec_to_text(MySmallType small_value) - { - // Called only when needed, i.e. on a crash. - std::string str = small_value.as_string(); // Format 'small_value' here somehow. - return Text{STRDUP(str.c_str())}; - } - - Text ec_to_text(const MyBigType* big_value) - { - // Called only when needed, i.e. on a crash. - std::string str = big_value->as_string(); // Format 'big_value' here somehow. - return Text{STRDUP(str.c_str())}; - } - } // namespace loguru - - Any file that include some.hpp: - void foo(MySmallType small, const MyBigType& big) - { - ERROR_CONTEXT("Small", small); // Copy ´small` by value. - ERROR_CONTEXT("Big", &big); // `big` should not change during this scope! - .... - } - */ -} // namespace loguru - -LOGURU_ANONYMOUS_NAMESPACE_END - -// -------------------------------------------------------------------- -// Logging macros - -// LOG_F(2, "Only logged if verbosity is 2 or higher: %d", some_number); -#define VLOG_F(verbosity, ...) \ - ((verbosity) > loguru::current_verbosity_cutoff()) ? (void)0 \ - : loguru::log(verbosity, __FILE__, __LINE__, __VA_ARGS__) - -// LOG_F(INFO, "Foo: %d", some_number); -#define LOG_F(verbosity_name, ...) VLOG_F(loguru::Verbosity_ ## verbosity_name, __VA_ARGS__) - -#define VLOG_IF_F(verbosity, cond, ...) \ - ((verbosity) > loguru::current_verbosity_cutoff() || (cond) == false) \ - ? (void)0 \ - : loguru::log(verbosity, __FILE__, __LINE__, __VA_ARGS__) - -#define LOG_IF_F(verbosity_name, cond, ...) \ - VLOG_IF_F(loguru::Verbosity_ ## verbosity_name, cond, __VA_ARGS__) - -#define VLOG_SCOPE_F(verbosity, ...) \ - loguru::LogScopeRAII LOGURU_ANONYMOUS_VARIABLE(error_context_RAII_) = \ - ((verbosity) > loguru::current_verbosity_cutoff()) ? loguru::LogScopeRAII() : \ - loguru::LogScopeRAII(verbosity, __FILE__, __LINE__, __VA_ARGS__) - -// Raw logging - no preamble, no indentation. Slightly faster than full logging. -#define RAW_VLOG_F(verbosity, ...) \ - ((verbosity) > loguru::current_verbosity_cutoff()) ? (void)0 \ - : loguru::raw_log(verbosity, __FILE__, __LINE__, __VA_ARGS__) - -#define RAW_LOG_F(verbosity_name, ...) RAW_VLOG_F(loguru::Verbosity_ ## verbosity_name, __VA_ARGS__) - -// Use to book-end a scope. Affects logging on all threads. -#define LOG_SCOPE_F(verbosity_name, ...) \ - VLOG_SCOPE_F(loguru::Verbosity_ ## verbosity_name, __VA_ARGS__) - -#define LOG_SCOPE_FUNCTION(verbosity_name) LOG_SCOPE_F(verbosity_name, __func__) - -// ----------------------------------------------- -// ABORT_F macro. Usage: ABORT_F("Cause of error: %s", error_str); - -// Message is optional -#define ABORT_F(...) loguru::log_and_abort(0, "ABORT: ", __FILE__, __LINE__, __VA_ARGS__) - -// -------------------------------------------------------------------- -// CHECK_F macros: - -#define CHECK_WITH_INFO_F(test, info, ...) \ - LOGURU_PREDICT_TRUE((test) == true) ? (void)0 : loguru::log_and_abort(0, "CHECK FAILED: " info " ", __FILE__, \ - __LINE__, ##__VA_ARGS__) - -/* Checked at runtime too. Will print error, then call fatal_handler (if any), then 'abort'. - Note that the test must be boolean. - CHECK_F(ptr); will not compile, but CHECK_F(ptr != nullptr); will. */ -#define CHECK_F(test, ...) CHECK_WITH_INFO_F(test, #test, ##__VA_ARGS__) - -#define CHECK_NOTNULL_F(x, ...) CHECK_WITH_INFO_F((x) != nullptr, #x " != nullptr", ##__VA_ARGS__) - -#define CHECK_OP_F(expr_left, expr_right, op, ...) \ - do \ - { \ - auto val_left = expr_left; \ - auto val_right = expr_right; \ - if (! LOGURU_PREDICT_TRUE(val_left op val_right)) \ - { \ - auto str_left = loguru::format_value(val_left); \ - auto str_right = loguru::format_value(val_right); \ - auto fail_info = loguru::textprintf("CHECK FAILED: " LOGURU_FMT(s) " " LOGURU_FMT(s) " " LOGURU_FMT(s) " (" LOGURU_FMT(s) " " LOGURU_FMT(s) " " LOGURU_FMT(s) ") ", \ - #expr_left, #op, #expr_right, str_left.c_str(), #op, str_right.c_str()); \ - auto user_msg = loguru::textprintf(__VA_ARGS__); \ - loguru::log_and_abort(0, fail_info.c_str(), __FILE__, __LINE__, \ - LOGURU_FMT(s), user_msg.c_str()); \ - } \ - } while (false) - -#ifndef LOGURU_DEBUG_LOGGING -#ifndef NDEBUG -#define LOGURU_DEBUG_LOGGING 1 -#else -#define LOGURU_DEBUG_LOGGING 0 -#endif -#endif - -#if LOGURU_DEBUG_LOGGING - // Debug logging enabled: -#define DLOG_F(verbosity_name, ...) LOG_F(verbosity_name, __VA_ARGS__) -#define DVLOG_F(verbosity, ...) VLOG_F(verbosity, __VA_ARGS__) -#define DLOG_IF_F(verbosity_name, ...) LOG_IF_F(verbosity_name, __VA_ARGS__) -#define DVLOG_IF_F(verbosity, ...) VLOG_IF_F(verbosity, __VA_ARGS__) -#define DRAW_LOG_F(verbosity_name, ...) RAW_LOG_F(verbosity_name, __VA_ARGS__) -#define DRAW_VLOG_F(verbosity, ...) RAW_VLOG_F(verbosity, __VA_ARGS__) -#else - // Debug logging disabled: -#define DLOG_F(verbosity_name, ...) -#define DVLOG_F(verbosity, ...) -#define DLOG_IF_F(verbosity_name, ...) -#define DVLOG_IF_F(verbosity, ...) -#define DRAW_LOG_F(verbosity_name, ...) -#define DRAW_VLOG_F(verbosity, ...) -#endif - -#define CHECK_EQ_F(a, b, ...) CHECK_OP_F(a, b, ==, ##__VA_ARGS__) -#define CHECK_NE_F(a, b, ...) CHECK_OP_F(a, b, !=, ##__VA_ARGS__) -#define CHECK_LT_F(a, b, ...) CHECK_OP_F(a, b, < , ##__VA_ARGS__) -#define CHECK_GT_F(a, b, ...) CHECK_OP_F(a, b, > , ##__VA_ARGS__) -#define CHECK_LE_F(a, b, ...) CHECK_OP_F(a, b, <=, ##__VA_ARGS__) -#define CHECK_GE_F(a, b, ...) CHECK_OP_F(a, b, >=, ##__VA_ARGS__) - -#ifndef LOGURU_DEBUG_CHECKS -#ifndef NDEBUG -#define LOGURU_DEBUG_CHECKS 1 -#else -#define LOGURU_DEBUG_CHECKS 0 -#endif -#endif - -#if LOGURU_DEBUG_CHECKS - // Debug checks enabled: -#define DCHECK_F(test, ...) CHECK_F(test, ##__VA_ARGS__) -#define DCHECK_NOTNULL_F(x, ...) CHECK_NOTNULL_F(x, ##__VA_ARGS__) -#define DCHECK_EQ_F(a, b, ...) CHECK_EQ_F(a, b, ##__VA_ARGS__) -#define DCHECK_NE_F(a, b, ...) CHECK_NE_F(a, b, ##__VA_ARGS__) -#define DCHECK_LT_F(a, b, ...) CHECK_LT_F(a, b, ##__VA_ARGS__) -#define DCHECK_LE_F(a, b, ...) CHECK_LE_F(a, b, ##__VA_ARGS__) -#define DCHECK_GT_F(a, b, ...) CHECK_GT_F(a, b, ##__VA_ARGS__) -#define DCHECK_GE_F(a, b, ...) CHECK_GE_F(a, b, ##__VA_ARGS__) -#else - // Debug checks disabled: -#define DCHECK_F(test, ...) -#define DCHECK_NOTNULL_F(x, ...) -#define DCHECK_EQ_F(a, b, ...) -#define DCHECK_NE_F(a, b, ...) -#define DCHECK_LT_F(a, b, ...) -#define DCHECK_LE_F(a, b, ...) -#define DCHECK_GT_F(a, b, ...) -#define DCHECK_GE_F(a, b, ...) -#endif // NDEBUG - - -#if LOGURU_REDEFINE_ASSERT -#undef assert -#ifndef NDEBUG - // Debug: -#define assert(test) CHECK_WITH_INFO_F(!!(test), #test) // HACK -#else -#define assert(test) -#endif -#endif // LOGURU_REDEFINE_ASSERT - -#endif // LOGURU_HAS_DECLARED_FORMAT_HEADER - -// ---------------------------------------------------------------------------- -// .dP"Y8 888888 88""Yb 888888 db 8b d8 .dP"Y8 -// `Ybo." 88 88__dP 88__ dPYb 88b d88 `Ybo." -// o.`Y8b 88 88"Yb 88"" dP__Yb 88YbdP88 o.`Y8b -// 8bodP' 88 88 Yb 888888 dP""""Yb 88 YY 88 8bodP' - -#if LOGURU_WITH_STREAMS -#ifndef LOGURU_HAS_DECLARED_STREAMS_HEADER -#define LOGURU_HAS_DECLARED_STREAMS_HEADER - -/* This file extends loguru to enable std::stream-style logging, a la Glog. - It's an optional feature behind the LOGURU_WITH_STREAMS settings - because including it everywhere will slow down compilation times. -*/ - -#include -#include // Adds about 38 kLoC on clang. -#include - -LOGURU_ANONYMOUS_NAMESPACE_BEGIN - -namespace loguru { - // Like sprintf, but returns the formated text. - LOGURU_EXPORT - std::string strprintf(LOGURU_FORMAT_STRING_TYPE format, ...) LOGURU_PRINTF_LIKE(1, 2); - - // Like vsprintf, but returns the formated text. - LOGURU_EXPORT - std::string vstrprintf(LOGURU_FORMAT_STRING_TYPE format, va_list) LOGURU_PRINTF_LIKE(1, 0); - - class LOGURU_EXPORT StreamLogger { - public: - StreamLogger(Verbosity verbosity, const char* file, unsigned line) : _verbosity(verbosity), _file(file), _line(line) {} - ~StreamLogger() noexcept(false); - - template - StreamLogger& operator<<(const T& t) - { - _ss << t; - return *this; - } - - // std::endl and other iomanip:s. - StreamLogger& operator<<(std::ostream& (*f)(std::ostream&)) - { - f(_ss); - return *this; - } - - private: - Verbosity _verbosity; - const char* _file; - unsigned _line; - std::ostringstream _ss; - }; - - class LOGURU_EXPORT AbortLogger { - public: - AbortLogger(const char* expr, const char* file, unsigned line) : _expr(expr), _file(file), _line(line) {} - LOGURU_NORETURN ~AbortLogger() noexcept(false); - - template - AbortLogger& operator<<(const T& t) - { - _ss << t; - return *this; - } - - // std::endl and other iomanip:s. - AbortLogger& operator<<(std::ostream& (*f)(std::ostream&)) - { - f(_ss); - return *this; - } - - private: - const char* _expr; - const char* _file; - unsigned _line; - std::ostringstream _ss; - }; - - class LOGURU_EXPORT Voidify { - public: - Voidify() {} - // This has to be an operator with a precedence lower than << but higher than ?: - void operator&(const StreamLogger&) {} - void operator&(const AbortLogger&) {} - }; - - /* Helper functions for CHECK_OP_S macro. - GLOG trick: The (int, int) specialization works around the issue that the compiler - will not instantiate the template version of the function on values of unnamed enum type. */ -#define DEFINE_CHECK_OP_IMPL(name, op) \ - template \ - inline std::string* name(const char* expr, const T1& v1, const char* op_str, const T2& v2) \ - { \ - if (LOGURU_PREDICT_TRUE(v1 op v2)) { return NULL; } \ - std::ostringstream ss; \ - ss << "CHECK FAILED: " << expr << " (" << v1 << " " << op_str << " " << v2 << ") "; \ - return new std::string(ss.str()); \ - } \ - inline std::string* name(const char* expr, int v1, const char* op_str, int v2) \ - { \ - return name(expr, v1, op_str, v2); \ - } - - DEFINE_CHECK_OP_IMPL(check_EQ_impl, == ) - DEFINE_CHECK_OP_IMPL(check_NE_impl, != ) - DEFINE_CHECK_OP_IMPL(check_LE_impl, <= ) - DEFINE_CHECK_OP_IMPL(check_LT_impl, < ) - DEFINE_CHECK_OP_IMPL(check_GE_impl, >= ) - DEFINE_CHECK_OP_IMPL(check_GT_impl, > ) -#undef DEFINE_CHECK_OP_IMPL - - /* GLOG trick: Function is overloaded for integral types to allow static const integrals - declared in classes and not defined to be used as arguments to CHECK* macros. */ - template - inline const T& referenceable_value(const T& t) { return t; } - inline char referenceable_value(char t) { return t; } - inline unsigned char referenceable_value(unsigned char t) { return t; } - inline signed char referenceable_value(signed char t) { return t; } - inline short referenceable_value(short t) { return t; } - inline unsigned short referenceable_value(unsigned short t) { return t; } - inline int referenceable_value(int t) { return t; } - inline unsigned int referenceable_value(unsigned int t) { return t; } - inline long referenceable_value(long t) { return t; } - inline unsigned long referenceable_value(unsigned long t) { return t; } - inline long long referenceable_value(long long t) { return t; } - inline unsigned long long referenceable_value(unsigned long long t) { return t; } -} // namespace loguru - -LOGURU_ANONYMOUS_NAMESPACE_END - -// ----------------------------------------------- -// Logging macros: - -// usage: LOG_STREAM(INFO) << "Foo " << std::setprecision(10) << some_value; -#define VLOG_IF_S(verbosity, cond) \ - ((verbosity) > loguru::current_verbosity_cutoff() || (cond) == false) \ - ? (void)0 \ - : loguru::Voidify() & loguru::StreamLogger(verbosity, __FILE__, __LINE__) -#define LOG_IF_S(verbosity_name, cond) VLOG_IF_S(loguru::Verbosity_ ## verbosity_name, cond) -#define VLOG_S(verbosity) VLOG_IF_S(verbosity, true) -#define LOG_S(verbosity_name) VLOG_S(loguru::Verbosity_ ## verbosity_name) - -// ----------------------------------------------- -// ABORT_S macro. Usage: ABORT_S() << "Causo of error: " << details; - -#define ABORT_S() loguru::Voidify() & loguru::AbortLogger("ABORT: ", __FILE__, __LINE__) - -// ----------------------------------------------- -// CHECK_S macros: - -#define CHECK_WITH_INFO_S(cond, info) \ - LOGURU_PREDICT_TRUE((cond) == true) \ - ? (void)0 \ - : loguru::Voidify() & loguru::AbortLogger("CHECK FAILED: " info " ", __FILE__, __LINE__) - -#define CHECK_S(cond) CHECK_WITH_INFO_S(cond, #cond) -#define CHECK_NOTNULL_S(x) CHECK_WITH_INFO_S((x) != nullptr, #x " != nullptr") - -#define CHECK_OP_S(function_name, expr1, op, expr2) \ - while (auto error_string = loguru::function_name(#expr1 " " #op " " #expr2, \ - loguru::referenceable_value(expr1), #op, \ - loguru::referenceable_value(expr2))) \ - loguru::AbortLogger(error_string->c_str(), __FILE__, __LINE__) - -#define CHECK_EQ_S(expr1, expr2) CHECK_OP_S(check_EQ_impl, expr1, ==, expr2) -#define CHECK_NE_S(expr1, expr2) CHECK_OP_S(check_NE_impl, expr1, !=, expr2) -#define CHECK_LE_S(expr1, expr2) CHECK_OP_S(check_LE_impl, expr1, <=, expr2) -#define CHECK_LT_S(expr1, expr2) CHECK_OP_S(check_LT_impl, expr1, < , expr2) -#define CHECK_GE_S(expr1, expr2) CHECK_OP_S(check_GE_impl, expr1, >=, expr2) -#define CHECK_GT_S(expr1, expr2) CHECK_OP_S(check_GT_impl, expr1, > , expr2) - -#if LOGURU_DEBUG_LOGGING - // Debug logging enabled: -#define DVLOG_IF_S(verbosity, cond) VLOG_IF_S(verbosity, cond) -#define DLOG_IF_S(verbosity_name, cond) LOG_IF_S(verbosity_name, cond) -#define DVLOG_S(verbosity) VLOG_S(verbosity) -#define DLOG_S(verbosity_name) LOG_S(verbosity_name) -#else - // Debug logging disabled: -#define DVLOG_IF_S(verbosity, cond) \ - (true || (verbosity) > loguru::current_verbosity_cutoff() || (cond) == false) \ - ? (void)0 \ - : loguru::Voidify() & loguru::StreamLogger(verbosity, __FILE__, __LINE__) - -#define DLOG_IF_S(verbosity_name, cond) DVLOG_IF_S(loguru::Verbosity_ ## verbosity_name, cond) -#define DVLOG_S(verbosity) DVLOG_IF_S(verbosity, true) -#define DLOG_S(verbosity_name) DVLOG_S(loguru::Verbosity_ ## verbosity_name) -#endif - -#if LOGURU_DEBUG_CHECKS - // Debug checks enabled: -#define DCHECK_S(cond) CHECK_S(cond) -#define DCHECK_NOTNULL_S(x) CHECK_NOTNULL_S(x) -#define DCHECK_EQ_S(a, b) CHECK_EQ_S(a, b) -#define DCHECK_NE_S(a, b) CHECK_NE_S(a, b) -#define DCHECK_LT_S(a, b) CHECK_LT_S(a, b) -#define DCHECK_LE_S(a, b) CHECK_LE_S(a, b) -#define DCHECK_GT_S(a, b) CHECK_GT_S(a, b) -#define DCHECK_GE_S(a, b) CHECK_GE_S(a, b) -#else -// Debug checks disabled: -#define DCHECK_S(cond) CHECK_S(true || (cond)) -#define DCHECK_NOTNULL_S(x) CHECK_S(true || (x) != nullptr) -#define DCHECK_EQ_S(a, b) CHECK_S(true || (a) == (b)) -#define DCHECK_NE_S(a, b) CHECK_S(true || (a) != (b)) -#define DCHECK_LT_S(a, b) CHECK_S(true || (a) < (b)) -#define DCHECK_LE_S(a, b) CHECK_S(true || (a) <= (b)) -#define DCHECK_GT_S(a, b) CHECK_S(true || (a) > (b)) -#define DCHECK_GE_S(a, b) CHECK_S(true || (a) >= (b)) -#endif - -#if LOGURU_REPLACE_GLOG -#undef LOG -#undef VLOG -#undef LOG_IF -#undef VLOG_IF -#undef CHECK -#undef CHECK_NOTNULL -#undef CHECK_EQ -#undef CHECK_NE -#undef CHECK_LT -#undef CHECK_LE -#undef CHECK_GT -#undef CHECK_GE -#undef DLOG -#undef DVLOG -#undef DLOG_IF -#undef DVLOG_IF -#undef DCHECK -#undef DCHECK_NOTNULL -#undef DCHECK_EQ -#undef DCHECK_NE -#undef DCHECK_LT -#undef DCHECK_LE -#undef DCHECK_GT -#undef DCHECK_GE -#undef VLOG_IS_ON - -#define LOG LOG_S -#define VLOG VLOG_S -#define LOG_IF LOG_IF_S -#define VLOG_IF VLOG_IF_S -#define CHECK(cond) CHECK_S(!!(cond)) -#define CHECK_NOTNULL CHECK_NOTNULL_S -#define CHECK_EQ CHECK_EQ_S -#define CHECK_NE CHECK_NE_S -#define CHECK_LT CHECK_LT_S -#define CHECK_LE CHECK_LE_S -#define CHECK_GT CHECK_GT_S -#define CHECK_GE CHECK_GE_S -#define DLOG DLOG_S -#define DVLOG DVLOG_S -#define DLOG_IF DLOG_IF_S -#define DVLOG_IF DVLOG_IF_S -#define DCHECK DCHECK_S -#define DCHECK_NOTNULL DCHECK_NOTNULL_S -#define DCHECK_EQ DCHECK_EQ_S -#define DCHECK_NE DCHECK_NE_S -#define DCHECK_LT DCHECK_LT_S -#define DCHECK_LE DCHECK_LE_S -#define DCHECK_GT DCHECK_GT_S -#define DCHECK_GE DCHECK_GE_S -#define VLOG_IS_ON(verbosity) ((verbosity) <= loguru::current_verbosity_cutoff()) - -#endif // LOGURU_REPLACE_GLOG - -#endif // LOGURU_WITH_STREAMS - -#endif // LOGURU_HAS_DECLARED_STREAMS_HEADER diff --git a/lib/mdlp b/lib/mdlp new file mode 160000 index 0000000..cfb993f --- /dev/null +++ b/lib/mdlp @@ -0,0 +1 @@ +Subproject commit cfb993f5ec1aabed527f524fdd4db06c6d839868 diff --git a/sample/CMakeLists.txt b/sample/CMakeLists.txt index 9cda435..0db5797 100644 --- a/sample/CMakeLists.txt +++ b/sample/CMakeLists.txt @@ -1,15 +1,11 @@ include_directories( + ${TORCH_INCLUDE_DIRS} ${Platform_SOURCE_DIR}/src/common ${Platform_SOURCE_DIR}/src/main ${Python3_INCLUDE_DIRS} - ${Platform_SOURCE_DIR}/lib/Files - ${Platform_SOURCE_DIR}/lib/mdlp/src - ${Platform_SOURCE_DIR}/lib/argparse/include - ${Platform_SOURCE_DIR}/lib/folding - ${Platform_SOURCE_DIR}/lib/json/include ${CMAKE_BINARY_DIR}/configured_files/include ${PyClassifiers_INCLUDE_DIRS} - ${Bayesnet_INCLUDE_DIRS} + ${bayesnet_INCLUDE_DIRS} ) add_executable(PlatformSample sample.cpp ${Platform_SOURCE_DIR}/src/main/Models.cpp) -target_link_libraries(PlatformSample "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy) \ No newline at end of file +target_link_libraries(PlatformSample "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} ${Boost_LIBRARIES}) \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7825077..8cea0c7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,12 +1,5 @@ include_directories( ## Libs - ${Platform_SOURCE_DIR}/lib/log - ${Platform_SOURCE_DIR}/lib/Files - ${Platform_SOURCE_DIR}/lib/folding - ${Platform_SOURCE_DIR}/lib/mdlp/src - ${Platform_SOURCE_DIR}/lib/argparse/include - ${Platform_SOURCE_DIR}/lib/json/include - ${Platform_SOURCE_DIR}/lib/libxlsxwriter/include ${Python3_INCLUDE_DIRS} ${MPI_CXX_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS} @@ -29,7 +22,7 @@ add_executable( experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp ) -target_link_libraries(b_best Boost::boost "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy "${XLSXWRITER_LIB}") +target_link_libraries(b_best Boost::boost "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}") # b_grid set(grid_sources GridSearch.cpp GridData.cpp GridExperiment.cpp GridBase.cpp ) @@ -42,7 +35,7 @@ add_executable(b_grid commands/b_grid.cpp ${grid_sources} experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp ) -target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy) +target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy) # b_list add_executable(b_list commands/b_list.cpp @@ -53,7 +46,7 @@ add_executable(b_list commands/b_list.cpp experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp ) -target_link_libraries(b_list "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy "${XLSXWRITER_LIB}") +target_link_libraries(b_list "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy "${XLSXWRITER_LIB}") # b_main set(main_sources Experiment.cpp Models.cpp HyperParameters.cpp Scores.cpp ArgumentsExperiment.cpp) @@ -65,7 +58,7 @@ add_executable(b_main commands/b_main.cpp ${main_sources} experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp ) -target_link_libraries(b_main "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy) +target_link_libraries(b_main "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy) # b_manage set(manage_sources ManageScreen.cpp OptionsMenu.cpp ResultsManager.cpp) @@ -77,7 +70,7 @@ add_executable( results/Result.cpp results/ResultsDataset.cpp results/ResultsDatasetConsole.cpp main/Scores.cpp ) -target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" fimdlp "${BayesNet}") +target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" fimdlp bayesnet::bayesnet) # b_results add_executable(b_results commands/b_results.cpp) diff --git a/src/common/Dataset.cpp b/src/common/Dataset.cpp index 26c0882..c635e0f 100644 --- a/src/common/Dataset.cpp +++ b/src/common/Dataset.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include "Dataset.h" namespace platform { diff --git a/vcpkg.json b/vcpkg.json index cf1ea58..31fc479 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -9,9 +9,7 @@ "folding", "bayesnet", "argparse", - "libxlsxwriter", - "zlib", - "libzip" + "libxlsxwriter" ], "overrides": [ { From aa19ab6c2139ca44a9c75d0039b231614e5d8a3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Fri, 9 May 2025 19:16:17 +0200 Subject: [PATCH 04/27] Option to use BayesNet local or vcpkg in CMakeLists --- CMakeLists.txt | 24 ++++++++++++++++++++++-- Makefile | 23 ++++++++++++++++++----- gitmodules | 23 ----------------------- lib/argparse | 1 - lib/mdlp | 1 - src/CMakeLists.txt | 7 +++---- src/main/Experiment.cpp | 7 +++++-- 7 files changed, 48 insertions(+), 38 deletions(-) delete mode 100644 gitmodules delete mode 160000 lib/argparse delete mode 160000 lib/mdlp diff --git a/CMakeLists.txt b/CMakeLists.txt index 31af5c0..f89b093 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,6 +22,7 @@ set(CMAKE_CXX_FLAGS_DEBUG " ${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage -O # Options # ------- +option(BAYESNET_VCPKG_CONFIG "Use vcpkg version of BayesNet" ON) option(ENABLE_TESTING "Unit testing build" OFF) option(CODE_COVERAGE "Collect coverage from test library" OFF) @@ -76,13 +77,32 @@ find_package(folding CONFIG REQUIRED) find_package(argparse CONFIG REQUIRED) find_package(Boost REQUIRED COMPONENTS python) find_package(arff-files CONFIG REQUIRED) -find_package(bayesnet CONFIG REQUIRED) + find_package(ZLIB REQUIRED) find_library(LIBZIP_LIBRARY NAMES zip) +# BayesNet +# if set to ON it will use the vcpkg version of BayesNet else it will use the locally installed version +if (BAYESNET_VCPKG_CONFIG) + message(STATUS "Using BayesNet vcpkg config") + find_package(bayesnet CONFIG REQUIRED) + set(BayesNet_LIBRARIES bayesnet::bayesnet) +else(BAYESNET_VCPKG_CONFIG) + message(STATUS "Using BayesNet local library config") + find_library(bayesnet NAMES libbayesnet bayesnet libbayesnet.a PATHS ${Platform_SOURCE_DIR}/../lib/lib REQUIRED) + find_path(Bayesnet_INCLUDE_DIRS REQUIRED NAMES bayesnet PATHS ${Platform_SOURCE_DIR}/../lib/include) + add_library(bayesnet::bayesnet UNKNOWN IMPORTED) + set_target_properties(bayesnet::bayesnet PROPERTIES + IMPORTED_LOCATION ${bayesnet} + INTERFACE_INCLUDE_DIRECTORIES ${Bayesnet_INCLUDE_DIRS} + ) +endif(BAYESNET_VCPKG_CONFIG) +message(STATUS "BayesNet_LIBRARIES=${BayesNet_LIBRARIES}") +message(STATUS "BayesNet_INCLUDE_DIRS=${Bayesnet_INCLUDE_DIRS}") + +# PyClassifiers find_library(PyClassifiers NAMES libPyClassifiers PyClassifiers libPyClassifiers.a PATHS ${Platform_SOURCE_DIR}/../lib/lib REQUIRED) find_path(PyClassifiers_INCLUDE_DIRS REQUIRED NAMES pyclassifiers PATHS ${Platform_SOURCE_DIR}/../lib/include) - message(STATUS "PyClassifiers=${PyClassifiers}") message(STATUS "PyClassifiers_INCLUDE_DIRS=${PyClassifiers_INCLUDE_DIRS}") diff --git a/Makefile b/Makefile index dbe2381..8db9498 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ SHELL := /bin/bash .DEFAULT_GOAL := help -.PHONY: init clean coverage setup help build test clean debug release buildr buildd install dependency testp testb clang-uml +.PHONY: init clean coverage setup help build test clean debug release buildr buildd install dependency testp testb clang-uml debug_local release_local example f_release = build_Release f_debug = build_Debug @@ -76,18 +76,31 @@ buildr: ## Build the release targets clang-uml: ## Create uml class and sequence diagrams clang-uml -p --add-compile-flag -I /usr/lib/gcc/x86_64-redhat-linux/8/include/ -debug: ## Build a debug version of the project +debug: ## Build a debug version of the project with BayesNet from vcpkg @echo ">>> Building Debug Platform..."; @if [ -d ./$(f_debug) ]; then rm -rf ./$(f_debug); fi @mkdir $(f_debug); - @cmake -S . -B $(f_debug) -D CMAKE_BUILD_TYPE=Debug -D ENABLE_TESTING=ON -D CODE_COVERAGE=ON -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake + @cmake -S . -B $(f_debug) -DBAYESNET_VCPKG_CONFIG=ON -D CMAKE_BUILD_TYPE=Debug -D ENABLE_TESTING=ON -D CODE_COVERAGE=ON -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake @echo ">>> Done"; -release: ## Build a Release version of the project +debug_local: ## Build a debug version of the project with BayesNet local + @echo ">>> Building Debug Platform..."; + @if [ -d ./$(f_debug) ]; then rm -rf ./$(f_debug); fi + @mkdir $(f_debug); + @cmake -S . -B $(f_debug) -DBAYESNET_VCPKG_CONFIG=OFF -D CMAKE_BUILD_TYPE=Debug -D ENABLE_TESTING=ON -D CODE_COVERAGE=ON -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake + @echo ">>> Done"; + +release: ## Build a Release version of the project with BayesNet from vcpkg @echo ">>> Building Release Platform..."; @if [ -d ./$(f_release) ]; then rm -rf ./$(f_release); fi @mkdir $(f_release); - @cmake -S . -B $(f_release) -D CMAKE_BUILD_TYPE=Release -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake + @cmake -S . -B $(f_release) -DBAYESNET_VCPKG_CONFIG=ON -D CMAKE_BUILD_TYPE=Release -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake + @echo ">>> Done"; +release_local: ## Build a Release version of the project with BayesNet local + @echo ">>> Building Release Platform..."; + @if [ -d ./$(f_release) ]; then rm -rf ./$(f_release); fi + @mkdir $(f_release); + @cmake -S . -B $(f_release) -DBAYESNET_VCPKG_CONFIG=OFF -D CMAKE_BUILD_TYPE=Release -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake @echo ">>> Done"; opt = "" diff --git a/gitmodules b/gitmodules deleted file mode 100644 index 394a867..0000000 --- a/gitmodules +++ /dev/null @@ -1,23 +0,0 @@ -[submodule "lib/catch2"] - path = lib/catch2 - main = v2.x - update = merge - url = https://github.com/catchorg/Catch2.git -[submodule "lib/argparse"] - path = lib/argparse - url = https://github.com/p-ranav/argparse - master = master - update = merge -[submodule "lib/json"] - path = lib/json - url = https://github.com/nlohmann/json.git - master = master - update = merge -[submodule "lib/libxlsxwriter"] - path = lib/libxlsxwriter - url = https://github.com/jmcnamara/libxlsxwriter.git - main = main - update = merge -[submodule "lib/folding"] - path = lib/folding - url = https://github.com/rmontanana/Folding diff --git a/lib/argparse b/lib/argparse deleted file mode 160000 index cbd9fd8..0000000 --- a/lib/argparse +++ /dev/null @@ -1 +0,0 @@ -Subproject commit cbd9fd8ed675ed6a2ac1bd7142d318c6ad5d3462 diff --git a/lib/mdlp b/lib/mdlp deleted file mode 160000 index cfb993f..0000000 --- a/lib/mdlp +++ /dev/null @@ -1 +0,0 @@ -Subproject commit cfb993f5ec1aabed527f524fdd4db06c6d839868 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8cea0c7..4e0f2e9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,7 +5,6 @@ include_directories( ${TORCH_INCLUDE_DIRS} ${CMAKE_BINARY_DIR}/configured_files/include ${PyClassifiers_INCLUDE_DIRS} - ${Bayesnet_INCLUDE_DIRS} ## Platform ${Platform_SOURCE_DIR}/src ${Platform_SOURCE_DIR}/results @@ -35,7 +34,7 @@ add_executable(b_grid commands/b_grid.cpp ${grid_sources} experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp ) -target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy) +target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy) # b_list add_executable(b_list commands/b_list.cpp @@ -46,7 +45,7 @@ add_executable(b_list commands/b_list.cpp experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp ) -target_link_libraries(b_list "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy "${XLSXWRITER_LIB}") +target_link_libraries(b_list "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}") # b_main set(main_sources Experiment.cpp Models.cpp HyperParameters.cpp Scores.cpp ArgumentsExperiment.cpp) @@ -58,7 +57,7 @@ add_executable(b_main commands/b_main.cpp ${main_sources} experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp ) -target_link_libraries(b_main "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy) +target_link_libraries(b_main "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy) # b_manage set(manage_sources ManageScreen.cpp OptionsMenu.cpp ResultsManager.cpp) diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 06cbc41..0c2c39a 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -262,13 +262,15 @@ namespace platform { auto y_proba_train = clf->predict_proba(X_train); Scores scores(y_train, y_proba_train, num_classes, labels); score_train_value = score == score_t::ACCURACY ? scores.accuracy() : scores.auc(); - confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true)); + if (discretized) + confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true)); } // // Test model // if (!quiet) showProgress(nfold + 1, getColor(clf->getStatus()), "c"); + std::cout << "Discretized: " << discretized << " " << score_train_value << std::endl; test_timer.start(); // auto y_predict = clf->predict(X_test); auto y_proba_test = clf->predict_proba(X_test); @@ -277,7 +279,8 @@ namespace platform { test_time[item] = test_timer.getDuration(); score_train[item] = score_train_value; score_test[item] = score_test_value; - confusion_matrices.push_back(scores.get_confusion_matrix_json(true)); + if (discretized) + confusion_matrices.push_back(scores.get_confusion_matrix_json(true)); if (!quiet) std::cout << "\b\b\b, " << flush; // From 36c72491e78aceb9b038c3830ad27a17af7d9bb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Tue, 13 May 2025 13:50:07 +0200 Subject: [PATCH 05/27] Add folder to b_best --- README.md | 1 + src/commands/b_best.cpp | 10 +++++----- src/experimental_clfs/ExpClf.h | 1 + src/main/Experiment.cpp | 1 - 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 5e5172d..05d5749 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ ![C++](https://img.shields.io/badge/c++-%2300599C.svg?style=flat&logo=c%2B%2B&logoColor=white) [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)]() +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/rmontanana/Platform) ![Gitea Last Commit](https://img.shields.io/gitea/last-commit/rmontanana/platform?gitea_url=https://gitea.rmontanana.es&logo=gitea) Platform to run Bayesian Networks and Machine Learning Classifiers experiments. diff --git a/src/commands/b_best.cpp b/src/commands/b_best.cpp index 39ec19a..970ff1f 100644 --- a/src/commands/b_best.cpp +++ b/src/commands/b_best.cpp @@ -9,9 +9,8 @@ void manageArguments(argparse::ArgumentParser& program) { - program.add_argument("-m", "--model") - .help("Model to use or any") - .default_value("any"); + program.add_argument("-m", "--model").help("Model to use or any").default_value("any"); + program.add_argument("--folder").help("Results folder to use").default_value(platform::Paths::results()); program.add_argument("-d", "--dataset").default_value("any").help("Filter results of the selected model) (any for all datasets)"); program.add_argument("-s", "--score").default_value("accuracy").help("Filter results of the score name supplied"); program.add_argument("--friedman").help("Friedman test").default_value(false).implicit_value(true); @@ -38,12 +37,13 @@ int main(int argc, char** argv) { argparse::ArgumentParser program("b_best", { platform_project_version.begin(), platform_project_version.end() }); manageArguments(program); - std::string model, dataset, score; + std::string model, dataset, score, folder; bool build, report, friedman, excel, tex, index; double level; try { program.parse_args(argc, argv); model = program.get("model"); + folder = program.get("folder"); dataset = program.get("dataset"); score = program.get("score"); friedman = program.get("friedman"); @@ -66,7 +66,7 @@ int main(int argc, char** argv) exit(1); } // Generate report - auto results = platform::BestResults(platform::Paths::results(), score, model, dataset, friedman, level); + auto results = platform::BestResults(folder, score, model, dataset, friedman, level); if (model == "any") { results.buildAll(); results.reportAll(excel, tex, index); diff --git a/src/experimental_clfs/ExpClf.h b/src/experimental_clfs/ExpClf.h index fc6d3ec..dbb2140 100644 --- a/src/experimental_clfs/ExpClf.h +++ b/src/experimental_clfs/ExpClf.h @@ -43,6 +43,7 @@ namespace platform { void add_active_parents(const std::vector& active_parents); void add_active_parent(int parent); void remove_last_parent(); + void setHyperparameters(const nlohmann::json& hyperparameters_) override {}; protected: bool debug = false; Xaode aode_; diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 0c2c39a..3e33e28 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -270,7 +270,6 @@ namespace platform { // if (!quiet) showProgress(nfold + 1, getColor(clf->getStatus()), "c"); - std::cout << "Discretized: " << discretized << " " << score_train_value << std::endl; test_timer.start(); // auto y_predict = clf->predict(X_test); auto y_proba_test = clf->predict_proba(X_test); From 321e2a2f289c6ad2b9369248f46d69318ff45d98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Tue, 13 May 2025 14:09:25 +0200 Subject: [PATCH 06/27] Add folder to manage --- CMakeLists.txt | 28 ++++++++-------------------- Makefile | 19 +++---------------- src/CMakeLists.txt | 2 +- src/commands/b_manage.cpp | 5 ++++- src/manage/ManageScreen.cpp | 4 ++-- src/manage/ManageScreen.h | 2 +- src/manage/ResultsManager.cpp | 5 ++--- src/manage/ResultsManager.h | 2 +- vcpkg.json | 12 +----------- 9 files changed, 23 insertions(+), 56 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f89b093..8875491 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,6 @@ set(CMAKE_CXX_FLAGS_DEBUG " ${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage -O # Options # ------- -option(BAYESNET_VCPKG_CONFIG "Use vcpkg version of BayesNet" ON) option(ENABLE_TESTING "Unit testing build" OFF) option(CODE_COVERAGE "Collect coverage from test library" OFF) @@ -75,29 +74,18 @@ find_package(Torch CONFIG REQUIRED) find_package(fimdlp CONFIG REQUIRED) find_package(folding CONFIG REQUIRED) find_package(argparse CONFIG REQUIRED) +find_package(nlohmann_json CONFIG REQUIRED) find_package(Boost REQUIRED COMPONENTS python) find_package(arff-files CONFIG REQUIRED) -find_package(ZLIB REQUIRED) -find_library(LIBZIP_LIBRARY NAMES zip) - # BayesNet -# if set to ON it will use the vcpkg version of BayesNet else it will use the locally installed version -if (BAYESNET_VCPKG_CONFIG) - message(STATUS "Using BayesNet vcpkg config") - find_package(bayesnet CONFIG REQUIRED) - set(BayesNet_LIBRARIES bayesnet::bayesnet) -else(BAYESNET_VCPKG_CONFIG) - message(STATUS "Using BayesNet local library config") - find_library(bayesnet NAMES libbayesnet bayesnet libbayesnet.a PATHS ${Platform_SOURCE_DIR}/../lib/lib REQUIRED) - find_path(Bayesnet_INCLUDE_DIRS REQUIRED NAMES bayesnet PATHS ${Platform_SOURCE_DIR}/../lib/include) - add_library(bayesnet::bayesnet UNKNOWN IMPORTED) - set_target_properties(bayesnet::bayesnet PROPERTIES - IMPORTED_LOCATION ${bayesnet} - INTERFACE_INCLUDE_DIRECTORIES ${Bayesnet_INCLUDE_DIRS} - ) -endif(BAYESNET_VCPKG_CONFIG) -message(STATUS "BayesNet_LIBRARIES=${BayesNet_LIBRARIES}") +find_library(bayesnet NAMES libbayesnet bayesnet libbayesnet.a PATHS ${Platform_SOURCE_DIR}/../lib/lib REQUIRED) +find_path(Bayesnet_INCLUDE_DIRS REQUIRED NAMES bayesnet PATHS ${Platform_SOURCE_DIR}/../lib/include) +add_library(bayesnet::bayesnet UNKNOWN IMPORTED) +set_target_properties(bayesnet::bayesnet PROPERTIES + IMPORTED_LOCATION ${bayesnet} + INTERFACE_INCLUDE_DIRECTORIES ${Bayesnet_INCLUDE_DIRS}) +message(STATUS "BayesNet=${bayesnet}") message(STATUS "BayesNet_INCLUDE_DIRS=${Bayesnet_INCLUDE_DIRS}") # PyClassifiers diff --git a/Makefile b/Makefile index 8db9498..1b810c9 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ SHELL := /bin/bash .DEFAULT_GOAL := help -.PHONY: init clean coverage setup help build test clean debug release buildr buildd install dependency testp testb clang-uml debug_local release_local example +.PHONY: init clean coverage setup help build test clean debug release buildr buildd install dependency testp testb clang-uml example f_release = build_Release f_debug = build_Debug @@ -80,28 +80,15 @@ debug: ## Build a debug version of the project with BayesNet from vcpkg @echo ">>> Building Debug Platform..."; @if [ -d ./$(f_debug) ]; then rm -rf ./$(f_debug); fi @mkdir $(f_debug); - @cmake -S . -B $(f_debug) -DBAYESNET_VCPKG_CONFIG=ON -D CMAKE_BUILD_TYPE=Debug -D ENABLE_TESTING=ON -D CODE_COVERAGE=ON -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake - @echo ">>> Done"; - -debug_local: ## Build a debug version of the project with BayesNet local - @echo ">>> Building Debug Platform..."; - @if [ -d ./$(f_debug) ]; then rm -rf ./$(f_debug); fi - @mkdir $(f_debug); - @cmake -S . -B $(f_debug) -DBAYESNET_VCPKG_CONFIG=OFF -D CMAKE_BUILD_TYPE=Debug -D ENABLE_TESTING=ON -D CODE_COVERAGE=ON -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake + @cmake -S . -B $(f_debug) -D CMAKE_BUILD_TYPE=Debug -D ENABLE_TESTING=ON -D CODE_COVERAGE=ON -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake @echo ">>> Done"; release: ## Build a Release version of the project with BayesNet from vcpkg @echo ">>> Building Release Platform..."; @if [ -d ./$(f_release) ]; then rm -rf ./$(f_release); fi @mkdir $(f_release); - @cmake -S . -B $(f_release) -DBAYESNET_VCPKG_CONFIG=ON -D CMAKE_BUILD_TYPE=Release -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake + @cmake -S . -B $(f_release) -D CMAKE_BUILD_TYPE=Release -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake @echo ">>> Done"; -release_local: ## Build a Release version of the project with BayesNet local - @echo ">>> Building Release Platform..."; - @if [ -d ./$(f_release) ]; then rm -rf ./$(f_release); fi - @mkdir $(f_release); - @cmake -S . -B $(f_release) -DBAYESNET_VCPKG_CONFIG=OFF -D CMAKE_BUILD_TYPE=Release -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake - @echo ">>> Done"; opt = "" test: ## Run tests (opt="-s") to verbose output the tests, (opt="-c='Test Maximum Spanning Tree'") to run only that section diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4e0f2e9..b89cebc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -57,7 +57,7 @@ add_executable(b_main commands/b_main.cpp ${main_sources} experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp ) -target_link_libraries(b_main "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy) +target_link_libraries(b_main PRIVATE nlohmann_json::nlohmann_json "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy) # b_manage set(manage_sources ManageScreen.cpp OptionsMenu.cpp ResultsManager.cpp) diff --git a/src/commands/b_manage.cpp b/src/commands/b_manage.cpp index 0dda157..66a2ea6 100644 --- a/src/commands/b_manage.cpp +++ b/src/commands/b_manage.cpp @@ -2,6 +2,7 @@ #include #include #include +#include "common/Paths.h" #include #include "manage/ManageScreen.h" #include @@ -13,6 +14,7 @@ void manageArguments(argparse::ArgumentParser& program, int argc, char** argv) { program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)"); program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied"); + program.add_argument("--folder").help("Results folder to use").default_value(platform::Paths::results()); program.add_argument("--platform").default_value("any").help("Filter results of the selected platform"); program.add_argument("--complete").help("Show only results with all datasets").default_value(false).implicit_value(true); program.add_argument("--partial").help("Show only partial results").default_value(false).implicit_value(true); @@ -116,6 +118,7 @@ int main(int argc, char** argv) auto program = argparse::ArgumentParser("b_manage", { platform_project_version.begin(), platform_project_version.end() }); manageArguments(program, argc, argv); std::string model = program.get("model"); + std::string path = program.get("folder"); std::string score = program.get("score"); std::string platform = program.get("platform"); bool complete = program.get("complete"); @@ -125,7 +128,7 @@ int main(int argc, char** argv) partial = false; signal(SIGWINCH, handleResize); auto [rows, cols] = numRowsCols(); - manager = new platform::ManageScreen(rows, cols, model, score, platform, complete, partial, compare); + manager = new platform::ManageScreen(path, rows, cols, model, score, platform, complete, partial, compare); manager->doMenu(); auto fileName = manager->getExcelFileName(); delete manager; diff --git a/src/manage/ManageScreen.cpp b/src/manage/ManageScreen.cpp index 6648d94..f4b11d1 100644 --- a/src/manage/ManageScreen.cpp +++ b/src/manage/ManageScreen.cpp @@ -18,8 +18,8 @@ namespace platform { const std::string STATUS_OK = "Ok."; const std::string STATUS_COLOR = Colors::GREEN(); - ManageScreen::ManageScreen(int rows, int cols, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial, bool compare) : - rows{ rows }, cols{ cols }, complete{ complete }, partial{ partial }, compare{ compare }, didExcel(false), results(ResultsManager(model, score, platform, complete, partial)) + ManageScreen::ManageScreen(const std::string path, int rows, int cols, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial, bool compare) : + rows{ rows }, cols{ cols }, complete{ complete }, partial{ partial }, compare{ compare }, didExcel(false), results(ResultsManager(path, model, score, platform, complete, partial)) { results.load(); openExcel = false; diff --git a/src/manage/ManageScreen.h b/src/manage/ManageScreen.h index 7e41896..7aab91d 100644 --- a/src/manage/ManageScreen.h +++ b/src/manage/ManageScreen.h @@ -15,7 +15,7 @@ namespace platform { }; class ManageScreen { public: - ManageScreen(int rows, int cols, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial, bool compare); + ManageScreen(const std::string path, int rows, int cols, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial, bool compare); ~ManageScreen() = default; void doMenu(); void updateSize(int rows, int cols); diff --git a/src/manage/ResultsManager.cpp b/src/manage/ResultsManager.cpp index f4b722a..fa4e895 100644 --- a/src/manage/ResultsManager.cpp +++ b/src/manage/ResultsManager.cpp @@ -1,10 +1,9 @@ #include -#include "common/Paths.h" #include "ResultsManager.h" namespace platform { - ResultsManager::ResultsManager(const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial) : - path(Paths::results()), model(model), scoreName(score), platform(platform), complete(complete), partial(partial), maxModel(0), maxTitle(0) + ResultsManager::ResultsManager(const std::string& path_, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial) : + path(path_), model(model), scoreName(score), platform(platform), complete(complete), partial(partial), maxModel(0), maxTitle(0) { } void ResultsManager::load() diff --git a/src/manage/ResultsManager.h b/src/manage/ResultsManager.h index cabf909..f891c44 100644 --- a/src/manage/ResultsManager.h +++ b/src/manage/ResultsManager.h @@ -18,7 +18,7 @@ namespace platform { }; class ResultsManager { public: - ResultsManager(const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial); + ResultsManager(const std::string& path_, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial); void load(); // Loads the list of results void sortResults(SortField field, SortType type); // Sorts the list of results void sortDate(SortType type); diff --git a/vcpkg.json b/vcpkg.json index 31fc479..69ded49 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -7,9 +7,7 @@ "fimdlp", "libtorch-bin", "folding", - "bayesnet", - "argparse", - "libxlsxwriter" + "argparse" ], "overrides": [ { @@ -24,10 +22,6 @@ "name": "libtorch-bin", "version": "2.7.0" }, - { - "name": "bayesnet", - "version": "1.1.1" - }, { "name": "folding", "version": "1.1.1" @@ -36,10 +30,6 @@ "name": "argpase", "version": "3.2" }, - { - "name": "libxlsxwriter", - "version": "1.2.2" - }, { "name": "nlohmann-json", "version": "3.11.3" From d6603dd638c953bbd242220b175177c1bc48a395 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 14 May 2025 11:46:15 +0200 Subject: [PATCH 07/27] Add folder parameter to best, grid and main --- src/commands/b_best.cpp | 3 +++ src/commands/b_grid.cpp | 3 ++- src/commands/b_main.cpp | 3 ++- src/main/ArgumentsExperiment.cpp | 8 +++++++- src/main/ArgumentsExperiment.h | 4 +++- src/main/Experiment.cpp | 6 +++--- src/main/Experiment.h | 2 +- src/results/Result.cpp | 4 ++-- src/results/Result.h | 2 +- 9 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/commands/b_best.cpp b/src/commands/b_best.cpp index 970ff1f..fb4b5cd 100644 --- a/src/commands/b_best.cpp +++ b/src/commands/b_best.cpp @@ -44,6 +44,9 @@ int main(int argc, char** argv) program.parse_args(argc, argv); model = program.get("model"); folder = program.get("folder"); + if (folder.back() != '/') { + folder += '/'; + } dataset = program.get("dataset"); score = program.get("score"); friedman = program.get("friedman"); diff --git a/src/commands/b_grid.cpp b/src/commands/b_grid.cpp index b6efd56..7e246a5 100644 --- a/src/commands/b_grid.cpp +++ b/src/commands/b_grid.cpp @@ -231,6 +231,7 @@ void experiment(argparse::ArgumentParser& program) { struct platform::ConfigGrid config; auto arguments = platform::ArgumentsExperiment(program, platform::experiment_t::GRID); + auto path_results = arguments.getPathResults(); arguments.parse(); auto grid_experiment = platform::GridExperiment(arguments, config); platform::Timer timer; @@ -250,7 +251,7 @@ void experiment(argparse::ArgumentParser& program) auto duration = timer.getDuration(); experiment.setDuration(duration); if (grid_experiment.haveToSaveResults()) { - experiment.saveResult(); + experiment.saveResult(path_results); } experiment.report(); std::cout << "Process took " << duration << std::endl; diff --git a/src/commands/b_main.cpp b/src/commands/b_main.cpp index f04a79f..03002d5 100644 --- a/src/commands/b_main.cpp +++ b/src/commands/b_main.cpp @@ -18,6 +18,7 @@ int main(int argc, char** argv) */ // Initialize the experiment class with the command line arguments auto experiment = arguments.initializedExperiment(); + auto path_results = arguments.getPathResults(); platform::Timer timer; timer.start(); experiment.go(); @@ -27,7 +28,7 @@ int main(int argc, char** argv) experiment.report(); } if (arguments.haveToSaveResults()) { - experiment.saveResult(); + experiment.saveResult(path_results); } if (arguments.doGraph()) { experiment.saveGraph(); diff --git a/src/main/ArgumentsExperiment.cpp b/src/main/ArgumentsExperiment.cpp index aa8199e..8d778ba 100644 --- a/src/main/ArgumentsExperiment.cpp +++ b/src/main/ArgumentsExperiment.cpp @@ -13,6 +13,7 @@ namespace platform { auto env = platform::DotEnv(); auto datasets = platform::Datasets(false, platform::Paths::datasets()); auto& group = arguments.add_mutually_exclusive_group(true); + group.add_argument("-d", "--dataset") .help("Dataset file name: " + datasets.toString()) .default_value("all") @@ -43,6 +44,7 @@ namespace platform { } ); arguments.add_argument("--title").default_value("").help("Experiment title"); + arguments.add_argument("--folder").help("Results folder to use").default_value(platform::Paths::results()); arguments.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); auto valid_choices = env.valid_tokens("discretize_algo"); auto& disc_arg = arguments.add_argument("--discretize-algo").help("Algorithm to use in discretization. Valid values: " + env.valid_values("discretize_algo")).default_value(env.get("discretize_algo")); @@ -103,6 +105,10 @@ namespace platform { file_name = arguments.get("dataset"); file_names = arguments.get>("datasets"); datasets_file = arguments.get("datasets-file"); + path_results = arguments.get("folder"); + if (path_results.back() != '/') { + path_results += '/'; + } model_name = arguments.get("model"); discretize_dataset = arguments.get("discretize"); discretize_algo = arguments.get("discretize-algo"); @@ -119,7 +125,7 @@ namespace platform { hyper_best = arguments.get("hyper-best"); if (hyper_best) { // Build the best results file_name - hyperparameters_file = platform::Paths::results() + platform::Paths::bestResultsFile(score, model_name); + hyperparameters_file = path_results + platform::Paths::bestResultsFile(score, model_name); // ignore this parameter hyperparameters = "{}"; } else { diff --git a/src/main/ArgumentsExperiment.h b/src/main/ArgumentsExperiment.h index c4528b9..5f2fbc2 100644 --- a/src/main/ArgumentsExperiment.h +++ b/src/main/ArgumentsExperiment.h @@ -22,11 +22,13 @@ namespace platform { bool isQuiet() const { return quiet; } bool haveToSaveResults() const { return saveResults; } bool doGraph() const { return graph; } + std::string getPathResults() const { return path_results; } private: Experiment experiment; experiment_t type; argparse::ArgumentParser& arguments; - std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat, score; + std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat; + std::string score, path_results; json hyperparameters_json; bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files, graph, hyper_best; std::vector seeds; diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 3e33e28..438735d 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -7,12 +7,12 @@ namespace platform { using json = nlohmann::ordered_json; - void Experiment::saveResult() + void Experiment::saveResult(const std::string& path) { result.setSchemaVersion("1.0"); result.check(); - result.save(); - std::cout << "Result saved in " << Paths::results() << result.getFilename() << std::endl; + result.save(path); + std::cout << "Result saved in " << path << result.getFilename() << std::endl; } void Experiment::report() { diff --git a/src/main/Experiment.h b/src/main/Experiment.h index bfad61c..52c08fe 100644 --- a/src/main/Experiment.h +++ b/src/main/Experiment.h @@ -45,7 +45,7 @@ namespace platform { std::vector getRandomSeeds() const { return randomSeeds; } void cross_validation(const std::string& fileName); void go(); - void saveResult(); + void saveResult(const std::string& path); void show(); void saveGraph(); void report(); diff --git a/src/results/Result.cpp b/src/results/Result.cpp index c143874..f37ca61 100644 --- a/src/results/Result.cpp +++ b/src/results/Result.cpp @@ -69,9 +69,9 @@ namespace platform { platform::JsonValidator validator(platform::SchemaV1_0::schema); return validator.validate(data); } - void Result::save() + void Result::save(const std::string& path) { - std::ofstream file(Paths::results() + getFilename()); + std::ofstream file(path + getFilename()); file << data; file.close(); } diff --git a/src/results/Result.h b/src/results/Result.h index 3f45c70..a16748d 100644 --- a/src/results/Result.h +++ b/src/results/Result.h @@ -15,7 +15,7 @@ namespace platform { public: Result(); Result& load(const std::string& path, const std::string& filename); - void save(); + void save(const std::string& path); std::vector check(); // Getters json getJson(); From b639a2d79a056afd106fb898873d5021f27bff8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 14 May 2025 12:51:56 +0200 Subject: [PATCH 08/27] Fix folder param in b_manage --- Makefile | 4 +++- src/commands/b_manage.cpp | 3 +++ src/manage/ManageScreen.cpp | 10 +++++----- src/manage/ManageScreen.h | 2 +- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index 1b810c9..fa659c8 100644 --- a/Makefile +++ b/Makefile @@ -59,7 +59,9 @@ install: ## Copy binary files to bin folder @echo "*******************************************" @for item in $(app_targets); do \ echo ">>> Copying $$item" ; \ - cp $(f_release)/src/$$item $(dest) ; \ + cp $(f_release)/src/$$item $(dest) || { \ + echo "*** Error copying $$item" ; \ + } ; \ done dependency: ## Create a dependency graph diagram of the project (build/dependency.png) diff --git a/src/commands/b_manage.cpp b/src/commands/b_manage.cpp index 66a2ea6..1d88ca2 100644 --- a/src/commands/b_manage.cpp +++ b/src/commands/b_manage.cpp @@ -119,6 +119,9 @@ int main(int argc, char** argv) manageArguments(program, argc, argv); std::string model = program.get("model"); std::string path = program.get("folder"); + if (path.back() != '/') { + path += '/'; + } std::string score = program.get("score"); std::string platform = program.get("platform"); bool complete = program.get("complete"); diff --git a/src/manage/ManageScreen.cpp b/src/manage/ManageScreen.cpp index f4b11d1..d9f65fc 100644 --- a/src/manage/ManageScreen.cpp +++ b/src/manage/ManageScreen.cpp @@ -18,8 +18,8 @@ namespace platform { const std::string STATUS_OK = "Ok."; const std::string STATUS_COLOR = Colors::GREEN(); - ManageScreen::ManageScreen(const std::string path, int rows, int cols, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial, bool compare) : - rows{ rows }, cols{ cols }, complete{ complete }, partial{ partial }, compare{ compare }, didExcel(false), results(ResultsManager(path, model, score, platform, complete, partial)) + ManageScreen::ManageScreen(const std::string path_, int rows, int cols, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial, bool compare) : + path{ path_ }, rows{ rows }, cols{ cols }, complete{ complete }, partial{ partial }, compare{ compare }, didExcel(false), results(ResultsManager(path_, model, score, platform, complete, partial)) { results.load(); openExcel = false; @@ -329,11 +329,11 @@ namespace platform { return; } // Remove the old result file - std::string oldFile = Paths::results() + results.at(index).getFilename(); + std::string oldFile = path + results.at(index).getFilename(); std::filesystem::remove(oldFile); // Actually change the model results.at(index).setModel(newModel); - results.at(index).save(); + results.at(index).save(path); int newModelSize = static_cast(newModel.size()); if (newModelSize > maxModel) { maxModel = newModelSize; @@ -583,7 +583,7 @@ namespace platform { getline(std::cin, newTitle); if (!newTitle.empty()) { results.at(index).setTitle(newTitle); - results.at(index).save(); + results.at(index).save(path); list("Title changed to " + newTitle, Colors::GREEN()); break; } diff --git a/src/manage/ManageScreen.h b/src/manage/ManageScreen.h index 7aab91d..46a02c4 100644 --- a/src/manage/ManageScreen.h +++ b/src/manage/ManageScreen.h @@ -59,7 +59,7 @@ namespace platform { std::vector paginator; ResultsManager results; lxw_workbook* workbook; - std::string excelFileName; + std::string path, excelFileName; }; } #endif \ No newline at end of file From e64e281b6377e715642d35ae664761720358335d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 14 May 2025 13:15:33 +0200 Subject: [PATCH 09/27] Return AUC 0.5 if nPos==0 || nNeg==0 --- src/main/RocAuc.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/RocAuc.cpp b/src/main/RocAuc.cpp index 03a1c31..c41d986 100644 --- a/src/main/RocAuc.cpp +++ b/src/main/RocAuc.cpp @@ -4,7 +4,7 @@ #include #include "RocAuc.h" namespace platform { - + double RocAuc::compute(const torch::Tensor& y_proba, const torch::Tensor& labels) { size_t nClasses = y_proba.size(1); @@ -48,6 +48,7 @@ namespace platform { double tp = 0, fp = 0; double totalPos = std::count(y_test.begin(), y_test.end(), classIdx); double totalNeg = nSamples - totalPos; + if (totalPos == 0 || totalNeg == 0) return 0.5; // neutral AUC for (const auto& [score, label] : scoresAndLabels) { if (label == 1) { From f5107abea7756e06f2c162706002dc0031b0355e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 14 May 2025 14:02:53 +0200 Subject: [PATCH 10/27] Add comment in Statistics --- src/best/Statistics.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/best/Statistics.cpp b/src/best/Statistics.cpp index 73f1edb..397b679 100644 --- a/src/best/Statistics.cpp +++ b/src/best/Statistics.cpp @@ -91,7 +91,7 @@ namespace platform { } void Statistics::computeWTL() { - // Compute the WTL matrix + // Compute the WTL matrix (Win Tie Loss) for (int i = 0; i < nModels; ++i) { wtl[i] = { 0, 0, 0 }; } From 70d8022926ad576cae23b98fa778944a01dacb14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sat, 17 May 2025 18:12:57 +0200 Subject: [PATCH 11/27] Refactor postHoc --- CMakeLists.txt | 2 - src/CMakeLists.txt | 2 +- src/best/BestResults.cpp | 5 +- src/best/BestResultsExcel.cpp | 9 +-- src/best/BestResultsMd.cpp | 8 +-- src/best/BestResultsMd.h | 2 +- src/best/BestResultsTex.cpp | 16 ++--- src/best/BestResultsTex.h | 5 +- src/best/DeLong.cpp | 45 ++++++++++++++ src/best/DeLong.h | 24 ++++++++ src/best/Statistics.cpp | 107 ++++++++++++++++++++++++++++------ src/best/Statistics.h | 14 +++-- 12 files changed, 192 insertions(+), 47 deletions(-) create mode 100644 src/best/DeLong.cpp create mode 100644 src/best/DeLong.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 8875491..75b6900 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,8 +7,6 @@ project(Platform LANGUAGES CXX ) - - # Global CMake variables # ---------------------- set(CMAKE_CXX_STANDARD 20) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b89cebc..be63bee 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -13,7 +13,7 @@ include_directories( # b_best add_executable( b_best commands/b_best.cpp best/Statistics.cpp - best/BestResultsExcel.cpp best/BestResultsTex.cpp best/BestResultsMd.cpp best/BestResults.cpp + best/BestResultsExcel.cpp best/BestResultsTex.cpp best/BestResultsMd.cpp best/BestResults.cpp best/DeLong.cpp common/Datasets.cpp common/Dataset.cpp common/Discretization.cpp main/Models.cpp main/Scores.cpp reports/ReportExcel.cpp reports/ReportBase.cpp reports/ExcelFile.cpp diff --git a/src/best/BestResults.cpp b/src/best/BestResults.cpp index 09a2cf3..21da49e 100644 --- a/src/best/BestResults.cpp +++ b/src/best/BestResults.cpp @@ -222,7 +222,7 @@ namespace platform { std::cout << oss.str(); std::cout << std::string(oss.str().size() - 8, '-') << std::endl; std::cout << Colors::GREEN() << " # " << std::setw(maxDatasetName + 1) << std::left << std::string("Dataset"); - auto bestResultsTex = BestResultsTex(); + auto bestResultsTex = BestResultsTex(score); auto bestResultsMd = BestResultsMd(); if (tex) { bestResultsTex.results_header(models, table.at("dateTable").get(), index); @@ -339,7 +339,8 @@ namespace platform { if (friedman) { Statistics stats(models, datasets, table, significance); auto result = stats.friedmanTest(); - stats.postHocHolmTest(result, tex); + stats.postHocHolmTest(); + stats.postHocTestReport("Holm", score, result, tex); ranksModels = stats.getRanks(); } if (tex) { diff --git a/src/best/BestResultsExcel.cpp b/src/best/BestResultsExcel.cpp index 0bc961e..fb7b864 100644 --- a/src/best/BestResultsExcel.cpp +++ b/src/best/BestResultsExcel.cpp @@ -243,9 +243,10 @@ namespace platform { row = 2; Statistics stats(models, datasets, table, significance, false); auto result = stats.friedmanTest(); - stats.postHocHolmTest(result); + stats.postHocHolmTest(); + // stats.postHocTestReport("Holm", result, false); auto friedmanResult = stats.getFriedmanResult(); - auto holmResult = stats.getHolmResult(); + auto postHocResult = stats.getPostHocResult(); worksheet_merge_range(worksheet, row, 0, row, 7, "Null hypothesis: H0 'There is no significant differences between all the classifiers.'", styles["headerSmall"]); row += 2; writeString(row, 1, "Friedman Q", "bodyHeader"); @@ -264,7 +265,7 @@ namespace platform { row += 2; worksheet_merge_range(worksheet, row, 0, row, 7, "Null hypothesis: H0 'There is no significant differences between the control model and the other models.'", styles["headerSmall"]); row += 2; - std::string controlModel = "Control Model: " + holmResult.model; + std::string controlModel = "Control Model: " + postHocResult.model; worksheet_merge_range(worksheet, row, 1, row, 7, controlModel.c_str(), styles["bodyHeader_odd"]); row++; writeString(row, 1, "Model", "bodyHeader"); @@ -276,7 +277,7 @@ namespace platform { writeString(row, 7, "Reject H0", "bodyHeader"); row++; bool first = true; - for (const auto& item : holmResult.holmLines) { + for (const auto& item : postHocResult.postHocLines) { writeString(row, 1, item.model, "text"); if (first) { // Control model info diff --git a/src/best/BestResultsMd.cpp b/src/best/BestResultsMd.cpp index bfa0a9b..3c70901 100644 --- a/src/best/BestResultsMd.cpp +++ b/src/best/BestResultsMd.cpp @@ -75,7 +75,7 @@ namespace platform { handler.close(); } - void BestResultsMd::holm_test(struct HolmResult& holmResult, const std::string& date) + void BestResultsMd::postHoc_test(struct PostHocResult& postHocResult, const std::string& kind, const std::string& date) { auto file_name = Paths::tex() + Paths::md_post_hoc(); openMdFile(file_name); @@ -84,12 +84,12 @@ namespace platform { handler << std::endl; handler << " Post-hoc handler test" << std::endl; handler << "-->" << std::endl; - handler << "Post-hoc Holm test: H0: There is no significant differences between the control model and the other models." << std::endl << std::endl; + handler << "Post-hoc " << kind << " test: H0: There is no significant differences between the control model and the other models." << std::endl << std::endl; handler << "| classifier | pvalue | rank | win | tie | loss | H0 |" << std::endl; handler << "| :-- | --: | --: | --:| --: | --: | :--: |" << std::endl; - for (auto const& line : holmResult.holmLines) { + for (auto const& line : postHocResult.postHocLines) { auto textStatus = !line.reject ? "**" : " "; - if (line.model == holmResult.model) { + if (line.model == postHocResult.model) { handler << "| " << line.model << " | - | " << std::fixed << std::setprecision(2) << line.rank << " | - | - | - |" << std::endl; } else { handler << "| " << line.model << " | " << textStatus << std::scientific << std::setprecision(4) << line.pvalue << textStatus << " |"; diff --git a/src/best/BestResultsMd.h b/src/best/BestResultsMd.h index 253a54a..894ae80 100644 --- a/src/best/BestResultsMd.h +++ b/src/best/BestResultsMd.h @@ -14,7 +14,7 @@ namespace platform { void results_header(const std::vector& models, const std::string& date); void results_body(const std::vector& datasets, json& table); void results_footer(const std::map>& totals, const std::string& best_model); - void holm_test(struct HolmResult& holmResult, const std::string& date); + void postHoc_test(struct PostHocResult& postHocResult, const std::string& kind, const std::string& date); private: void openMdFile(const std::string& name); std::ofstream handler; diff --git a/src/best/BestResultsTex.cpp b/src/best/BestResultsTex.cpp index bf74c88..39e17d9 100644 --- a/src/best/BestResultsTex.cpp +++ b/src/best/BestResultsTex.cpp @@ -27,8 +27,10 @@ namespace platform { handler << "\\tiny " << std::endl; handler << "\\renewcommand{\\arraystretch }{1.2} " << std::endl; handler << "\\renewcommand{\\tabcolsep }{0.07cm} " << std::endl; - handler << "\\caption{Accuracy results(mean $\\pm$ std) for all the algorithms and datasets} " << std::endl; - handler << "\\label{tab:results_accuracy}" << std::endl; + auto umetric = metric; + umetric[0] = toupper(umetric[0]); + handler << "\\caption{" << umetric << " results(mean $\\pm$ std) for all the algorithms and datasets} " << std::endl; + handler << "\\label{tab:results_" << metric << "}" << std::endl; std::string header_dataset_name = index ? "r" : "l"; handler << "\\begin{tabular} {{" << header_dataset_name << std::string(models.size(), 'c').c_str() << "}}" << std::endl; handler << "\\hline " << std::endl; @@ -87,25 +89,25 @@ namespace platform { handler << "\\end{table}" << std::endl; handler.close(); } - void BestResultsTex::holm_test(struct HolmResult& holmResult, const std::string& date) + void BestResultsTex::postHoc_test(struct PostHocResult& postHocResult, const std::string& kind, const std::string& date) { auto file_name = Paths::tex() + Paths::tex_post_hoc(); openTexFile(file_name); handler << "%% This file has been generated by the platform program" << std::endl; handler << "%% Date: " << date.c_str() << std::endl; handler << "%%" << std::endl; - handler << "%% Post-hoc handler test" << std::endl; + handler << "%% Post-hoc " << kind << " test" << std::endl; handler << "%%" << std::endl; handler << "\\begin{table}[htbp]" << std::endl; handler << "\\centering" << std::endl; - handler << "\\caption{Results of the post-hoc test for the mean accuracy of the algorithms.}\\label{tab:tests}" << std::endl; + handler << "\\caption{Results of the post-hoc " << kind << " test for the mean " << metric << " of the algorithms.}\\label{ tab:tests }" << std::endl; handler << "\\begin{tabular}{lrrrrr}" << std::endl; handler << "\\hline" << std::endl; handler << "classifier & pvalue & rank & win & tie & loss\\\\" << std::endl; handler << "\\hline" << std::endl; - for (auto const& line : holmResult.holmLines) { + for (auto const& line : postHocResult.postHocLines) { auto textStatus = !line.reject ? "\\bf " : " "; - if (line.model == holmResult.model) { + if (line.model == postHocResult.model) { handler << line.model << " & - & " << std::fixed << std::setprecision(2) << line.rank << " & - & - & - \\\\" << std::endl; } else { handler << line.model << " & " << textStatus << std::scientific << std::setprecision(4) << line.pvalue << " & "; diff --git a/src/best/BestResultsTex.h b/src/best/BestResultsTex.h index ae88c6d..e587dec 100644 --- a/src/best/BestResultsTex.h +++ b/src/best/BestResultsTex.h @@ -9,13 +9,14 @@ namespace platform { using json = nlohmann::ordered_json; class BestResultsTex { public: - BestResultsTex(bool dataset_name = true) : dataset_name(dataset_name) {}; + BestResultsTex(const std::string metric_, bool dataset_name = true) : metric{ metric_ }, dataset_name{ dataset_name } {}; ~BestResultsTex() = default; void results_header(const std::vector& models, const std::string& date, bool index); void results_body(const std::vector& datasets, json& table, bool index); void results_footer(const std::map>& totals, const std::string& best_model); - void holm_test(struct HolmResult& holmResult, const std::string& date); + void postHoc_test(struct PostHocResult& postHocResult, const std::string& kind, const std::string& date); private: + std::string metric; bool dataset_name; void openTexFile(const std::string& name); std::ofstream handler; diff --git a/src/best/DeLong.cpp b/src/best/DeLong.cpp new file mode 100644 index 0000000..dbcc920 --- /dev/null +++ b/src/best/DeLong.cpp @@ -0,0 +1,45 @@ +// DeLong.cpp +// Integración del test de DeLong con la clase RocAuc y Statistics +// Basado en: X. Sun and W. Xu, "Fast Implementation of DeLong’s Algorithm for Comparing the Areas Under Correlated Receiver Operating Characteristic Curves," (2014), y algoritmos inspirados en sklearn/pROC + +#include "DeLong.h" +#include +#include +#include +#include +#include +#include + +namespace platform { + + DeLong::DeLongResult DeLong::compare(const std::vector& aucs_model1, + const std::vector& aucs_model2) + { + if (aucs_model1.size() != aucs_model2.size()) { + throw std::invalid_argument("AUC lists must have the same size"); + } + + size_t N = aucs_model1.size(); + if (N < 2) { + throw std::invalid_argument("At least two AUC values are required"); + } + + std::vector diffs(N); + for (size_t i = 0; i < N; ++i) { + diffs[i] = aucs_model1[i] - aucs_model2[i]; + } + + double mean_diff = std::accumulate(diffs.begin(), diffs.end(), 0.0) / N; + double var = 0.0; + for (size_t i = 0; i < N; ++i) { + var += (diffs[i] - mean_diff) * (diffs[i] - mean_diff); + } + var /= (N * (N - 1)); + if (var <= 0.0) var = 1e-10; + + double z = mean_diff / std::sqrt(var); + double p = 2.0 * (1.0 - std::erfc(std::abs(z) / std::sqrt(2.0)) / 2.0); + return { mean_diff, z, p }; + } + +} diff --git a/src/best/DeLong.h b/src/best/DeLong.h new file mode 100644 index 0000000..07e3cf3 --- /dev/null +++ b/src/best/DeLong.h @@ -0,0 +1,24 @@ +#ifndef DELONG_H +#define DELONG_H +/* ******************************************************************************************************************** +/* Integración del test de DeLong con la clase RocAuc y Statistics +/* Basado en: X. Sun and W. Xu, "Fast Implementation of DeLong’s Algorithm for Comparing the Areas Under Correlated +/* Receiver Operating Characteristic Curves," (2014), y algoritmos inspirados en sklearn/pROC +/* ********************************************************************************************************************/ +#include + +namespace platform { + class DeLong { + public: + struct DeLongResult { + double auc_diff; + double z_stat; + double p_value; + }; + // Compara dos vectores de AUCs por dataset y devuelve diferencia media, + // estadístico z y p-valor usando un test de rangos (DeLong simplificado) + static DeLongResult compare(const std::vector& aucs_model1, + const std::vector& aucs_model2); + }; +} +#endif // DELONG_H \ No newline at end of file diff --git a/src/best/Statistics.cpp b/src/best/Statistics.cpp index 397b679..cd40cc1 100644 --- a/src/best/Statistics.cpp +++ b/src/best/Statistics.cpp @@ -7,6 +7,7 @@ #include "BestResultsTex.h" #include "BestResultsMd.h" #include "Statistics.h" +#include "DeLong.h" namespace platform { @@ -114,8 +115,7 @@ namespace platform { } } } - - void Statistics::postHocHolmTest(bool friedmanResult, bool tex) + void Statistics::postHocHolmTest() { if (!fitted) { fit(); @@ -137,27 +137,33 @@ namespace platform { stats[i] = p_value; } // Sort the models by p-value - std::vector> statsOrder; for (const auto& stat : stats) { - statsOrder.push_back({ stat.first, stat.second }); + postHocData.push_back({ stat.first, stat.second }); } - std::sort(statsOrder.begin(), statsOrder.end(), [](const std::pair& a, const std::pair& b) { + std::sort(postHocData.begin(), postHocData.end(), [](const std::pair& a, const std::pair& b) { return a.second < b.second; }); // Holm adjustment - for (int i = 0; i < statsOrder.size(); ++i) { - auto item = statsOrder.at(i); - double before = i == 0 ? 0.0 : statsOrder.at(i - 1).second; + for (int i = 0; i < postHocData.size(); ++i) { + auto item = postHocData.at(i); + double before = i == 0 ? 0.0 : postHocData.at(i - 1).second; double p_value = std::min((double)1.0, item.second * (nModels - i)); p_value = std::max(before, p_value); - statsOrder[i] = { item.first, p_value }; + postHocData[i] = { item.first, p_value }; } - holmResult.model = models.at(controlIdx); + postHocResult.model = models.at(controlIdx); + } + + void Statistics::postHocTestReport(const std::string& kind, const std::string& metric, bool friedmanResult, bool tex) + { + + std::stringstream oss; + postHocResult.model = models.at(controlIdx); auto color = friedmanResult ? Colors::CYAN() : Colors::YELLOW(); oss << color; oss << " *************************************************************************************************************" << std::endl; - oss << " Post-hoc Holm test: H0: 'There is no significant differences between the control model and the other models.'" << std::endl; + oss << " Post-hoc " << kind << " test: H0: 'There is no significant differences between the control model and the other models.'" << std::endl; oss << " Control model: " << models.at(controlIdx) << std::endl; oss << " " << std::left << std::setw(maxModelName) << std::string("Model") << " p-value rank win tie loss Status" << std::endl; oss << " " << std::string(maxModelName, '=') << " ============ ========= === === ==== =============" << std::endl; @@ -175,12 +181,12 @@ namespace platform { for (const auto& item : ranksOrder) { auto idx = distance(models.begin(), find(models.begin(), models.end(), item.first)); double pvalue = 0.0; - for (const auto& stat : statsOrder) { + for (const auto& stat : postHocData) { if (stat.first == idx) { pvalue = stat.second; } } - holmResult.holmLines.push_back({ item.first, pvalue, item.second, wtl.at(idx), pvalue < significance }); + postHocResult.postHocLines.push_back({ item.first, pvalue, item.second, wtl.at(idx), pvalue < significance }); if (item.first == models.at(controlIdx)) { continue; } @@ -198,12 +204,77 @@ namespace platform { std::cout << oss.str(); } if (tex) { - BestResultsTex bestResultsTex; + BestResultsTex bestResultsTex(metric); BestResultsMd bestResultsMd; - bestResultsTex.holm_test(holmResult, get_date() + " " + get_time()); - bestResultsMd.holm_test(holmResult, get_date() + " " + get_time()); + bestResultsTex.postHoc_test(postHocResult, kind, get_date() + " " + get_time()); + bestResultsMd.postHoc_test(postHocResult, kind, get_date() + " " + get_time()); } } + // void Statistics::postHocDeLongTest(const std::vector>& y_trues, + // const std::vector>>& y_probas, + // bool tex) + // { + // std::map pvalues; + // postHocResult.model = models.at(controlIdx); + // postHocResult.postHocLines.clear(); + + // for (size_t i = 0; i < models.size(); ++i) { + // if ((int)i == controlIdx) continue; + // double acc_p = 0.0; + // int valid = 0; + // for (size_t d = 0; d < y_trues.size(); ++d) { + // try { + // auto result = compareModelsWithDeLong(y_probas[controlIdx][d], y_probas[i][d], y_trues[d]); + // acc_p += result.p_value; + // ++valid; + // } + // catch (...) {} + // } + // if (valid > 0) { + // pvalues[i] = acc_p / valid; + // } + // } + + // std::vector> sorted_pvalues(pvalues.begin(), pvalues.end()); + // std::sort(sorted_pvalues.begin(), sorted_pvalues.end(), [](const auto& a, const auto& b) { + // return a.second < b.second; + // }); + + // std::stringstream oss; + // oss << "\n*************************************************************************************************************\n"; + // oss << " Post-hoc DeLong-Holm test: H0: 'No significant differences in AUC with control model.'\n"; + // oss << " Control model: " << models[controlIdx] << "\n"; + // oss << " " << std::left << std::setw(maxModelName) << std::string("Model") << " p-value Adjusted Result\n"; + // oss << " " << std::string(maxModelName, '=') << " ============ ========== =============\n"; + + // double prev = 0.0; + // for (size_t i = 0; i < sorted_pvalues.size(); ++i) { + // int idx = sorted_pvalues[i].first; + // double raw = sorted_pvalues[i].second; + // double adj = std::min(1.0, raw * (models.size() - i - 1)); + // adj = std::max(prev, adj); + // prev = adj; + // bool reject = adj < significance; + + // postHocResult.postHocLines.push_back({ models[idx], adj, 0.0f, {}, reject }); + + // auto color = reject ? Colors::MAGENTA() : Colors::GREEN(); + // auto status = reject ? Symbols::cross : Symbols::check_mark; + // auto textStatus = reject ? " rejected H0" : " accepted H0"; + // oss << " " << color << std::left << std::setw(maxModelName) << models[idx] << " "; + // oss << std::setprecision(6) << std::scientific << raw << " "; + // oss << std::setprecision(6) << std::scientific << adj << " " << status << textStatus << "\n"; + // } + // oss << Colors::CYAN() << " *************************************************************************************************************\n"; + // oss << Colors::RESET(); + // if (output) std::cout << oss.str(); + // if (tex) { + // BestResultsTex bestResultsTex; + // BestResultsMd bestResultsMd; + // bestResultsTex.holm_test(postHocResult, get_date() + " " + get_time()); + // bestResultsMd.holm_test(postHocResult, get_date() + " " + get_time()); + // } + // } bool Statistics::friedmanTest() { if (!fitted) { @@ -249,9 +320,9 @@ namespace platform { { return friedmanResult; } - HolmResult& Statistics::getHolmResult() + PostHocResult& Statistics::getPostHocResult() { - return holmResult; + return postHocResult; } std::map>& Statistics::getRanks() { diff --git a/src/best/Statistics.h b/src/best/Statistics.h index ee98c96..285f34c 100644 --- a/src/best/Statistics.h +++ b/src/best/Statistics.h @@ -19,24 +19,25 @@ namespace platform { long double pvalue; bool reject; }; - struct HolmLine { + struct PostHocLine { std::string model; long double pvalue; double rank; WTL wtl; bool reject; }; - struct HolmResult { + struct PostHocResult { std::string model; - std::vector holmLines; + std::vector postHocLines; }; class Statistics { public: Statistics(const std::vector& models, const std::vector& datasets, const json& data, double significance = 0.05, bool output = true); bool friedmanTest(); - void postHocHolmTest(bool friedmanResult, bool tex=false); + void postHocHolmTest(); + void postHocTestReport(const std::string& kind, const std::string& metric, bool friedmanResult, bool tex); FriedmanResult& getFriedmanResult(); - HolmResult& getHolmResult(); + PostHocResult& getPostHocResult(); std::map>& getRanks(); private: void fit(); @@ -53,10 +54,11 @@ namespace platform { int controlIdx = 0; std::map wtl; std::map ranks; + std::vector> postHocData; int maxModelName = 0; int maxDatasetName = 0; FriedmanResult friedmanResult; - HolmResult holmResult; + PostHocResult postHocResult; std::map> ranksModels; }; } From a56ec98ef9f1894e50b7a8380a2b698484a12fb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 21 May 2025 11:51:04 +0200 Subject: [PATCH 12/27] Add Wilcoxon Test --- src/CMakeLists.txt | 2 +- src/best/BestResults.cpp | 29 +--- src/best/BestResultsExcel.cpp | 10 +- src/best/BestResultsExcel.h | 3 +- src/best/BestResultsTex.cpp | 6 +- src/best/BestResultsTex.h | 4 +- src/best/DeLong.cpp | 45 ------ src/best/DeLong.h | 24 --- src/best/Statistics.cpp | 140 ++++++++--------- src/best/Statistics.h | 13 +- src/best/WilcoxonTest.hpp | 250 +++++++++++++++++++++++++++++++ src/commands/b_grid.cpp | 2 +- src/main/ArgumentsExperiment.cpp | 27 +++- src/main/Experiment.cpp | 4 +- 14 files changed, 369 insertions(+), 190 deletions(-) delete mode 100644 src/best/DeLong.cpp delete mode 100644 src/best/DeLong.h create mode 100644 src/best/WilcoxonTest.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index be63bee..b89cebc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -13,7 +13,7 @@ include_directories( # b_best add_executable( b_best commands/b_best.cpp best/Statistics.cpp - best/BestResultsExcel.cpp best/BestResultsTex.cpp best/BestResultsMd.cpp best/BestResults.cpp best/DeLong.cpp + best/BestResultsExcel.cpp best/BestResultsTex.cpp best/BestResultsMd.cpp best/BestResults.cpp common/Datasets.cpp common/Dataset.cpp common/Discretization.cpp main/Models.cpp main/Scores.cpp reports/ReportExcel.cpp reports/ReportBase.cpp reports/ExcelFile.cpp diff --git a/src/best/BestResults.cpp b/src/best/BestResults.cpp index 21da49e..5e7c349 100644 --- a/src/best/BestResults.cpp +++ b/src/best/BestResults.cpp @@ -321,7 +321,7 @@ namespace platform { // Build the table of results json table = buildTableResults(models); std::vector datasets = getDatasets(table.begin().value()); - BestResultsExcel excel_report(score, datasets); + BestResultsExcel excel_report(path, score, datasets); excel_report.reportSingle(model, path + Paths::bestResultsFile(score, model)); messageOutputFile("Excel", excel_report.getFileName()); } @@ -337,10 +337,10 @@ namespace platform { // Compute the Friedman test std::map> ranksModels; if (friedman) { - Statistics stats(models, datasets, table, significance); + Statistics stats(score, models, datasets, table, significance); auto result = stats.friedmanTest(); - stats.postHocHolmTest(); - stats.postHocTestReport("Holm", score, result, tex); + stats.postHocTest(); + stats.postHocTestReport(result, tex); ranksModels = stats.getRanks(); } if (tex) { @@ -352,24 +352,11 @@ namespace platform { } } if (excel) { - BestResultsExcel excel(score, datasets); + BestResultsExcel excel(path, score, datasets); excel.reportAll(models, table, ranksModels, friedman, significance); if (friedman) { - int idx = -1; - double min = 2000; - // Find out the control model - auto totals = std::vector(models.size(), 0.0); - for (const auto& dataset_ : datasets) { - for (int i = 0; i < models.size(); ++i) { - totals[i] += ranksModels[dataset_][models[i]]; - } - } - for (int i = 0; i < models.size(); ++i) { - if (totals[i] < min) { - min = totals[i]; - idx = i; - } - } + Statistics stats(score, models, datasets, table, significance); + int idx = stats.getControlIdx(); model = models.at(idx); excel.reportSingle(model, path + Paths::bestResultsFile(score, model)); } @@ -378,7 +365,7 @@ namespace platform { } void BestResults::messageOutputFile(const std::string& title, const std::string& fileName) { - std::cout << Colors::YELLOW() << "** " << std::setw(5) << std::left << title + std::cout << Colors::YELLOW() << "** " << std::setw(8) << std::left << title << " file generated: " << fileName << Colors::RESET() << std::endl; } } \ No newline at end of file diff --git a/src/best/BestResultsExcel.cpp b/src/best/BestResultsExcel.cpp index fb7b864..36cfcb3 100644 --- a/src/best/BestResultsExcel.cpp +++ b/src/best/BestResultsExcel.cpp @@ -30,7 +30,7 @@ namespace platform { } return columnName; } - BestResultsExcel::BestResultsExcel(const std::string& score, const std::vector& datasets) : score(score), datasets(datasets) + BestResultsExcel::BestResultsExcel(const std::string& path, const std::string& score, const std::vector& datasets) : path(path), score(score), datasets(datasets) { file_name = Paths::bestResultsExcel(score); workbook = workbook_new(getFileName().c_str()); @@ -92,7 +92,7 @@ namespace platform { catch (const std::out_of_range& oor) { auto tabName = "table_" + std::to_string(i); auto worksheetNew = workbook_add_worksheet(workbook, tabName.c_str()); - json data = loadResultData(Paths::results() + fileName); + json data = loadResultData(path + fileName); auto report = ReportExcel(data, false, workbook, worksheetNew); report.show(); hyperlink = "#table_" + std::to_string(i); @@ -241,10 +241,10 @@ namespace platform { } worksheet_merge_range(worksheet, 0, 0, 0, 7, "Friedman Test", styles["headerFirst"]); row = 2; - Statistics stats(models, datasets, table, significance, false); + Statistics stats(score, models, datasets, table, significance, false); // No output auto result = stats.friedmanTest(); - stats.postHocHolmTest(); - // stats.postHocTestReport("Holm", result, false); + stats.postHocTest(); + stats.postHocTestReport(result, false); // No tex output auto friedmanResult = stats.getFriedmanResult(); auto postHocResult = stats.getPostHocResult(); worksheet_merge_range(worksheet, row, 0, row, 7, "Null hypothesis: H0 'There is no significant differences between all the classifiers.'", styles["headerSmall"]); diff --git a/src/best/BestResultsExcel.h b/src/best/BestResultsExcel.h index 6c70a49..bd8bf94 100644 --- a/src/best/BestResultsExcel.h +++ b/src/best/BestResultsExcel.h @@ -10,7 +10,7 @@ namespace platform { using json = nlohmann::ordered_json; class BestResultsExcel : public ExcelFile { public: - BestResultsExcel(const std::string& score, const std::vector& datasets); + BestResultsExcel(const std::string& path, const std::string& score, const std::vector& datasets); ~BestResultsExcel(); void reportAll(const std::vector& models, const json& table, const std::map>& ranks, bool friedman, double significance); void reportSingle(const std::string& model, const std::string& fileName); @@ -22,6 +22,7 @@ namespace platform { void formatColumns(); void doFriedman(); void addConditionalFormat(std::string formula); + std::string path; std::string score; std::vector models; std::vector datasets; diff --git a/src/best/BestResultsTex.cpp b/src/best/BestResultsTex.cpp index 39e17d9..cf4827f 100644 --- a/src/best/BestResultsTex.cpp +++ b/src/best/BestResultsTex.cpp @@ -27,10 +27,10 @@ namespace platform { handler << "\\tiny " << std::endl; handler << "\\renewcommand{\\arraystretch }{1.2} " << std::endl; handler << "\\renewcommand{\\tabcolsep }{0.07cm} " << std::endl; - auto umetric = metric; + auto umetric = score; umetric[0] = toupper(umetric[0]); handler << "\\caption{" << umetric << " results(mean $\\pm$ std) for all the algorithms and datasets} " << std::endl; - handler << "\\label{tab:results_" << metric << "}" << std::endl; + handler << "\\label{tab:results_" << score << "}" << std::endl; std::string header_dataset_name = index ? "r" : "l"; handler << "\\begin{tabular} {{" << header_dataset_name << std::string(models.size(), 'c').c_str() << "}}" << std::endl; handler << "\\hline " << std::endl; @@ -100,7 +100,7 @@ namespace platform { handler << "%%" << std::endl; handler << "\\begin{table}[htbp]" << std::endl; handler << "\\centering" << std::endl; - handler << "\\caption{Results of the post-hoc " << kind << " test for the mean " << metric << " of the algorithms.}\\label{ tab:tests }" << std::endl; + handler << "\\caption{Results of the post-hoc " << kind << " test for the mean " << score << " of the algorithms.}\\label{ tab:tests }" << std::endl; handler << "\\begin{tabular}{lrrrrr}" << std::endl; handler << "\\hline" << std::endl; handler << "classifier & pvalue & rank & win & tie & loss\\\\" << std::endl; diff --git a/src/best/BestResultsTex.h b/src/best/BestResultsTex.h index e587dec..e0a82e0 100644 --- a/src/best/BestResultsTex.h +++ b/src/best/BestResultsTex.h @@ -9,14 +9,14 @@ namespace platform { using json = nlohmann::ordered_json; class BestResultsTex { public: - BestResultsTex(const std::string metric_, bool dataset_name = true) : metric{ metric_ }, dataset_name{ dataset_name } {}; + BestResultsTex(const std::string score, bool dataset_name = true) : score{ score }, dataset_name{ dataset_name } {}; ~BestResultsTex() = default; void results_header(const std::vector& models, const std::string& date, bool index); void results_body(const std::vector& datasets, json& table, bool index); void results_footer(const std::map>& totals, const std::string& best_model); void postHoc_test(struct PostHocResult& postHocResult, const std::string& kind, const std::string& date); private: - std::string metric; + std::string score; bool dataset_name; void openTexFile(const std::string& name); std::ofstream handler; diff --git a/src/best/DeLong.cpp b/src/best/DeLong.cpp deleted file mode 100644 index dbcc920..0000000 --- a/src/best/DeLong.cpp +++ /dev/null @@ -1,45 +0,0 @@ -// DeLong.cpp -// Integración del test de DeLong con la clase RocAuc y Statistics -// Basado en: X. Sun and W. Xu, "Fast Implementation of DeLong’s Algorithm for Comparing the Areas Under Correlated Receiver Operating Characteristic Curves," (2014), y algoritmos inspirados en sklearn/pROC - -#include "DeLong.h" -#include -#include -#include -#include -#include -#include - -namespace platform { - - DeLong::DeLongResult DeLong::compare(const std::vector& aucs_model1, - const std::vector& aucs_model2) - { - if (aucs_model1.size() != aucs_model2.size()) { - throw std::invalid_argument("AUC lists must have the same size"); - } - - size_t N = aucs_model1.size(); - if (N < 2) { - throw std::invalid_argument("At least two AUC values are required"); - } - - std::vector diffs(N); - for (size_t i = 0; i < N; ++i) { - diffs[i] = aucs_model1[i] - aucs_model2[i]; - } - - double mean_diff = std::accumulate(diffs.begin(), diffs.end(), 0.0) / N; - double var = 0.0; - for (size_t i = 0; i < N; ++i) { - var += (diffs[i] - mean_diff) * (diffs[i] - mean_diff); - } - var /= (N * (N - 1)); - if (var <= 0.0) var = 1e-10; - - double z = mean_diff / std::sqrt(var); - double p = 2.0 * (1.0 - std::erfc(std::abs(z) / std::sqrt(2.0)) / 2.0); - return { mean_diff, z, p }; - } - -} diff --git a/src/best/DeLong.h b/src/best/DeLong.h deleted file mode 100644 index 07e3cf3..0000000 --- a/src/best/DeLong.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef DELONG_H -#define DELONG_H -/* ******************************************************************************************************************** -/* Integración del test de DeLong con la clase RocAuc y Statistics -/* Basado en: X. Sun and W. Xu, "Fast Implementation of DeLong’s Algorithm for Comparing the Areas Under Correlated -/* Receiver Operating Characteristic Curves," (2014), y algoritmos inspirados en sklearn/pROC -/* ********************************************************************************************************************/ -#include - -namespace platform { - class DeLong { - public: - struct DeLongResult { - double auc_diff; - double z_stat; - double p_value; - }; - // Compara dos vectores de AUCs por dataset y devuelve diferencia media, - // estadístico z y p-valor usando un test de rangos (DeLong simplificado) - static DeLongResult compare(const std::vector& aucs_model1, - const std::vector& aucs_model2); - }; -} -#endif // DELONG_H \ No newline at end of file diff --git a/src/best/Statistics.cpp b/src/best/Statistics.cpp index cd40cc1..0c27cef 100644 --- a/src/best/Statistics.cpp +++ b/src/best/Statistics.cpp @@ -7,19 +7,25 @@ #include "BestResultsTex.h" #include "BestResultsMd.h" #include "Statistics.h" -#include "DeLong.h" +#include "WilcoxonTest.hpp" namespace platform { - Statistics::Statistics(const std::vector& models, const std::vector& datasets, const json& data, double significance, bool output) : - models(models), datasets(datasets), data(data), significance(significance), output(output) + Statistics::Statistics(const std::string& score, const std::vector& models, const std::vector& datasets, const json& data, double significance, bool output) : + score(score), models(models), datasets(datasets), data(data), significance(significance), output(output) { + if (score == "accuracy") { + postHocType = "Holm"; + hlen = 85; + } else { + postHocType = "Wilcoxon"; + hlen = 88; + } nModels = models.size(); nDatasets = datasets.size(); auto temp = ConfigLocale(); } - void Statistics::fit() { if (nModels < 3 || nDatasets < 3) { @@ -28,9 +34,11 @@ namespace platform { throw std::runtime_error("Can't make the Friedman test with less than 3 models and/or less than 3 datasets."); } ranksModels.clear(); - computeRanks(); + computeRanks(); // compute greaterAverage and ranks // Set the control model as the one with the lowest average rank - controlIdx = distance(ranks.begin(), min_element(ranks.begin(), ranks.end(), [](const auto& l, const auto& r) { return l.second < r.second; })); + controlIdx = score == "accuracy" ? + distance(ranks.begin(), min_element(ranks.begin(), ranks.end(), [](const auto& l, const auto& r) { return l.second < r.second; })) + : greaterAverage; // The model with the greater average score computeWTL(); maxModelName = (*std::max_element(models.begin(), models.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size(); maxDatasetName = (*std::max_element(datasets.begin(), datasets.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size(); @@ -67,11 +75,16 @@ namespace platform { void Statistics::computeRanks() { std::map ranksLine; + std::map averages; + for (const auto& model : models) { + averages[model] = 0; + } for (const auto& dataset : datasets) { std::vector> ranksOrder; for (const auto& model : models) { double value = data[model].at(dataset).at(0).get(); ranksOrder.push_back({ model, value }); + averages[model] += value; } // Assign the ranks ranksLine = assignRanks(ranksOrder); @@ -89,6 +102,12 @@ namespace platform { for (const auto& rank : ranks) { ranks[rank.first] /= nDatasets; } + // Average the scores + for (const auto& average : averages) { + averages[average.first] /= nDatasets; + } + // Get the model with the greater average score + greaterAverage = distance(averages.begin(), max_element(averages.begin(), averages.end(), [](const auto& l, const auto& r) { return l.second < r.second; })); } void Statistics::computeWTL() { @@ -115,12 +134,36 @@ namespace platform { } } } + int Statistics::getControlIdx() + { + if (!fitted) { + fit(); + } + return controlIdx; + } + void Statistics::postHocTest() + { + // if (score == "accuracy") { + postHocHolmTest(); + // } else { + // postHocWilcoxonTest(); + // } + } + void Statistics::postHocWilcoxonTest() + { + if (!fitted) { + fit(); + } + // Reference: Wilcoxon, F. (1945). “Individual Comparisons by Ranking Methods”. Biometrics Bulletin, 1(6), 80-83. + auto wilcoxon = WilcoxonTest(models, datasets, data, significance); + controlIdx = wilcoxon.getControlIdx(); + postHocResult = wilcoxon.getPostHocResult(); + } void Statistics::postHocHolmTest() { if (!fitted) { fit(); } - std::stringstream oss; // Reference https://link.springer.com/article/10.1007/s44196-022-00083-8 // Post-hoc Holm test // Calculate the p-value for the models paired with the control model @@ -155,15 +198,15 @@ namespace platform { postHocResult.model = models.at(controlIdx); } - void Statistics::postHocTestReport(const std::string& kind, const std::string& metric, bool friedmanResult, bool tex) + void Statistics::postHocTestReport(bool friedmanResult, bool tex) { std::stringstream oss; postHocResult.model = models.at(controlIdx); auto color = friedmanResult ? Colors::CYAN() : Colors::YELLOW(); oss << color; - oss << " *************************************************************************************************************" << std::endl; - oss << " Post-hoc " << kind << " test: H0: 'There is no significant differences between the control model and the other models.'" << std::endl; + oss << " " << std::string(hlen + 25, '*') << std::endl; + oss << " Post-hoc " << postHocType << " test: H0: 'There is no significant differences between the control model and the other models.'" << std::endl; oss << " Control model: " << models.at(controlIdx) << std::endl; oss << " " << std::left << std::setw(maxModelName) << std::string("Model") << " p-value rank win tie loss Status" << std::endl; oss << " " << std::string(maxModelName, '=') << " ============ ========= === === ==== =============" << std::endl; @@ -198,83 +241,18 @@ namespace platform { oss << " " << std::right << std::setw(3) << wtl.at(idx).win << " " << std::setw(3) << wtl.at(idx).tie << " " << std::setw(4) << wtl.at(idx).loss; oss << " " << status << textStatus << std::endl; } - oss << color << " *************************************************************************************************************" << std::endl; + oss << color << " " << std::string(hlen + 25, '*') << std::endl; oss << Colors::RESET(); if (output) { std::cout << oss.str(); } if (tex) { - BestResultsTex bestResultsTex(metric); + BestResultsTex bestResultsTex(score); BestResultsMd bestResultsMd; - bestResultsTex.postHoc_test(postHocResult, kind, get_date() + " " + get_time()); - bestResultsMd.postHoc_test(postHocResult, kind, get_date() + " " + get_time()); + bestResultsTex.postHoc_test(postHocResult, postHocType, get_date() + " " + get_time()); + bestResultsMd.postHoc_test(postHocResult, postHocType, get_date() + " " + get_time()); } } - // void Statistics::postHocDeLongTest(const std::vector>& y_trues, - // const std::vector>>& y_probas, - // bool tex) - // { - // std::map pvalues; - // postHocResult.model = models.at(controlIdx); - // postHocResult.postHocLines.clear(); - - // for (size_t i = 0; i < models.size(); ++i) { - // if ((int)i == controlIdx) continue; - // double acc_p = 0.0; - // int valid = 0; - // for (size_t d = 0; d < y_trues.size(); ++d) { - // try { - // auto result = compareModelsWithDeLong(y_probas[controlIdx][d], y_probas[i][d], y_trues[d]); - // acc_p += result.p_value; - // ++valid; - // } - // catch (...) {} - // } - // if (valid > 0) { - // pvalues[i] = acc_p / valid; - // } - // } - - // std::vector> sorted_pvalues(pvalues.begin(), pvalues.end()); - // std::sort(sorted_pvalues.begin(), sorted_pvalues.end(), [](const auto& a, const auto& b) { - // return a.second < b.second; - // }); - - // std::stringstream oss; - // oss << "\n*************************************************************************************************************\n"; - // oss << " Post-hoc DeLong-Holm test: H0: 'No significant differences in AUC with control model.'\n"; - // oss << " Control model: " << models[controlIdx] << "\n"; - // oss << " " << std::left << std::setw(maxModelName) << std::string("Model") << " p-value Adjusted Result\n"; - // oss << " " << std::string(maxModelName, '=') << " ============ ========== =============\n"; - - // double prev = 0.0; - // for (size_t i = 0; i < sorted_pvalues.size(); ++i) { - // int idx = sorted_pvalues[i].first; - // double raw = sorted_pvalues[i].second; - // double adj = std::min(1.0, raw * (models.size() - i - 1)); - // adj = std::max(prev, adj); - // prev = adj; - // bool reject = adj < significance; - - // postHocResult.postHocLines.push_back({ models[idx], adj, 0.0f, {}, reject }); - - // auto color = reject ? Colors::MAGENTA() : Colors::GREEN(); - // auto status = reject ? Symbols::cross : Symbols::check_mark; - // auto textStatus = reject ? " rejected H0" : " accepted H0"; - // oss << " " << color << std::left << std::setw(maxModelName) << models[idx] << " "; - // oss << std::setprecision(6) << std::scientific << raw << " "; - // oss << std::setprecision(6) << std::scientific << adj << " " << status << textStatus << "\n"; - // } - // oss << Colors::CYAN() << " *************************************************************************************************************\n"; - // oss << Colors::RESET(); - // if (output) std::cout << oss.str(); - // if (tex) { - // BestResultsTex bestResultsTex; - // BestResultsMd bestResultsMd; - // bestResultsTex.holm_test(postHocResult, get_date() + " " + get_time()); - // bestResultsMd.holm_test(postHocResult, get_date() + " " + get_time()); - // } - // } bool Statistics::friedmanTest() { if (!fitted) { @@ -284,7 +262,7 @@ namespace platform { // Friedman test // Calculate the Friedman statistic oss << Colors::BLUE() << std::endl; - oss << "***************************************************************************************************************" << std::endl; + oss << std::string(hlen, '*') << std::endl; oss << Colors::GREEN() << "Friedman test: H0: 'There is no significant differences between all the classifiers.'" << Colors::BLUE() << std::endl; double degreesOfFreedom = nModels - 1.0; double sumSquared = 0; @@ -309,7 +287,7 @@ namespace platform { oss << Colors::YELLOW() << "The null hypothesis H0 is accepted. Computed p-values will not be significant." << std::endl; result = false; } - oss << Colors::BLUE() << "***************************************************************************************************************" << Colors::RESET() << std::endl; + oss << Colors::BLUE() << std::string(hlen, '*') << Colors::RESET() << std::endl; if (output) { std::cout << oss.str(); } diff --git a/src/best/Statistics.h b/src/best/Statistics.h index 285f34c..765ed1d 100644 --- a/src/best/Statistics.h +++ b/src/best/Statistics.h @@ -32,17 +32,22 @@ namespace platform { }; class Statistics { public: - Statistics(const std::vector& models, const std::vector& datasets, const json& data, double significance = 0.05, bool output = true); + Statistics(const std::string& score, const std::vector& models, const std::vector& datasets, const json& data, double significance = 0.05, bool output = true); bool friedmanTest(); - void postHocHolmTest(); - void postHocTestReport(const std::string& kind, const std::string& metric, bool friedmanResult, bool tex); + void postHocTest(); + void postHocTestReport(bool friedmanResult, bool tex); + int getControlIdx(); FriedmanResult& getFriedmanResult(); PostHocResult& getPostHocResult(); std::map>& getRanks(); private: void fit(); + void postHocHolmTest(); + void postHocWilcoxonTest(); void computeRanks(); void computeWTL(); + const std::string& score; + std::string postHocType; const std::vector& models; const std::vector& datasets; const json& data; @@ -52,11 +57,13 @@ namespace platform { int nModels = 0; int nDatasets = 0; int controlIdx = 0; + int greaterAverage = -1; // The model with the greater average score std::map wtl; std::map ranks; std::vector> postHocData; int maxModelName = 0; int maxDatasetName = 0; + int hlen; // length of the line FriedmanResult friedmanResult; PostHocResult postHocResult; std::map> ranksModels; diff --git a/src/best/WilcoxonTest.hpp b/src/best/WilcoxonTest.hpp new file mode 100644 index 0000000..34c2969 --- /dev/null +++ b/src/best/WilcoxonTest.hpp @@ -0,0 +1,250 @@ +#ifndef BEST_WILCOXON_TEST_HPP +#define BEST_WILCOXON_TEST_HPP +// WilcoxonTest.hpp +// Stand‑alone class for paired Wilcoxon signed‑rank post‑hoc analysis +// ------------------------------------------------------------------ +// * Constructor takes the *already‑loaded* nlohmann::json object plus the +// vectors of model and dataset names. +// * Internally selects a control model (highest average AUC) and builds all +// statistics (ranks, W/T/L counts, Wilcoxon p‑values). +// * Public API: +// int getControlIdx() const; +// PostHocResult getPostHocResult() const; +// +#include +#include +#include +#include +#include +#include +#include +#include "Statistics.h" + +namespace platform { + class WilcoxonTest { + public: + WilcoxonTest(const std::vector& models, + const std::vector& datasets, + const json& data, + double alpha = 0.05) + : models_(models), datasets_(datasets), data_(data), alpha_(alpha) + { + buildAUCTable(); // extracts all AUCs into a dense matrix + computeAverageAUCs(); // per‑model mean (→ control selection) + computeAverageRanks(); // Friedman‑style ranks per model + selectControlModel(); // sets control_idx_ + buildPostHocResult(); // fills postHocResult_ + } + + //---------------------------------------------------- public API ---- + int getControlIdx() const noexcept { return control_idx_; } + + const PostHocResult& getPostHocResult() const noexcept { return postHocResult_; } + + private: + //-------------------------------------------------- helper structs ---- + // When a value is missing we keep NaN so that ordinary arithmetic still + // works (NaN simply propagates and we can test with std::isnan). + using Matrix = std::vector>; // [model][dataset] + + //------------------------------------------------- implementation ---- + void buildAUCTable() + { + const std::size_t M = models_.size(); + const std::size_t D = datasets_.size(); + auc_.assign(M, std::vector(D, std::numeric_limits::quiet_NaN())); + + for (std::size_t i = 0; i < M; ++i) { + const auto& model = models_[i]; + for (std::size_t j = 0; j < D; ++j) { + const auto& ds = datasets_[j]; + try { + auc_[i][j] = data_.at(model).at(ds).at(0).get(); + } + catch (...) { + // leave as NaN when value missing + } + } + } + } + + void computeAverageAUCs() + { + const std::size_t M = models_.size(); + avg_auc_.resize(M, std::numeric_limits::quiet_NaN()); + + for (std::size_t i = 0; i < M; ++i) { + double sum = 0.0; + std::size_t cnt = 0; + for (double v : auc_[i]) { + if (!std::isnan(v)) { sum += v; ++cnt; } + } + avg_auc_[i] = cnt ? sum / cnt : std::numeric_limits::quiet_NaN(); + } + } + + // Average rank across datasets (1 = best). + void computeAverageRanks() + { + const std::size_t M = models_.size(); + const std::size_t D = datasets_.size(); + rank_sum_.assign(M, 0.0); + rank_cnt_.assign(M, 0); + + const double EPS = 1e-10; + + for (std::size_t j = 0; j < D; ++j) { + // Collect present values for this dataset + std::vector> vals; // (auc, model_idx) + vals.reserve(M); + for (std::size_t i = 0; i < M; ++i) { + if (!std::isnan(auc_[i][j])) + vals.emplace_back(auc_[i][j], i); + } + if (vals.empty()) continue; // no info for this dataset + + // Sort descending (higher AUC better) + std::sort(vals.begin(), vals.end(), [](auto a, auto b) { + return a.first > b.first; + }); + + // Assign ranks with average for ties + std::size_t k = 0; + while (k < vals.size()) { + std::size_t l = k + 1; + while (l < vals.size() && std::fabs(vals[l].first - vals[k].first) < EPS) ++l; + const double avg_rank = (k + 1 + l) * 0.5; // average of ranks (1‑based) + for (std::size_t m = k; m < l; ++m) { + const auto idx = vals[m].second; + rank_sum_[idx] += avg_rank; + ++rank_cnt_[idx]; + } + k = l; + } + } + + // Final average + avg_rank_.resize(M, std::numeric_limits::quiet_NaN()); + for (std::size_t i = 0; i < M; ++i) { + avg_rank_[i] = rank_cnt_[i] ? rank_sum_[i] / rank_cnt_[i] + : std::numeric_limits::quiet_NaN(); + } + } + + void selectControlModel() + { + // pick model with highest average AUC (ties → first) + control_idx_ = 0; + for (std::size_t i = 1; i < avg_auc_.size(); ++i) { + if (avg_auc_[i] > avg_auc_[control_idx_]) control_idx_ = static_cast(i); + } + } + + void buildPostHocResult() + { + const std::size_t M = models_.size(); + const std::size_t D = datasets_.size(); + const std::string& control_name = models_[control_idx_]; + + postHocResult_.model = control_name; + + const double practical_threshold = 0.0005; // same heuristic as original code + + for (std::size_t i = 0; i < M; ++i) { + if (static_cast(i) == control_idx_) continue; + + PostHocLine line; + line.model = models_[i]; + line.rank = avg_rank_[i]; + + WTL wtl; + std::vector differences; + differences.reserve(D); + + for (std::size_t j = 0; j < D; ++j) { + double auc_control = auc_[control_idx_][j]; + double auc_other = auc_[i][j]; + if (std::isnan(auc_control) || std::isnan(auc_other)) continue; + + double diff = auc_control - auc_other; // control − comparison + if (std::fabs(diff) <= practical_threshold) { + ++wtl.tie; + } else if (diff < 0) { + ++wtl.win; // comparison wins + } else { + ++wtl.loss; // control wins + } + differences.push_back(diff); + } + + line.wtl = wtl; + line.pvalue = differences.empty() ? 1.0L : static_cast(wilcoxonSignedRankTest(differences)); + line.reject = (line.pvalue < alpha_); + + postHocResult_.postHocLines.push_back(std::move(line)); + } + } + + // ------------------------------------------------ Wilcoxon (private) -- + static double wilcoxonSignedRankTest(const std::vector& diffs) + { + if (diffs.empty()) return 1.0; + + // Build |diff| + sign vector (exclude zeros) + struct Node { double absval; int sign; }; + std::vector v; + v.reserve(diffs.size()); + for (double d : diffs) { + if (d != 0.0) v.push_back({ std::fabs(d), d > 0 ? 1 : -1 }); + } + if (v.empty()) return 1.0; + + // Sort by absolute value + std::sort(v.begin(), v.end(), [](const Node& a, const Node& b) { return a.absval < b.absval; }); + + const double EPS = 1e-10; + const std::size_t n = v.size(); + std::vector ranks(n, 0.0); + + std::size_t i = 0; + while (i < n) { + std::size_t j = i + 1; + while (j < n && std::fabs(v[j].absval - v[i].absval) < EPS) ++j; + double avg_rank = (i + 1 + j) * 0.5; // 1‑based ranks + for (std::size_t k = i; k < j; ++k) ranks[k] = avg_rank; + i = j; + } + + double w_plus = 0.0, w_minus = 0.0; + for (std::size_t k = 0; k < n; ++k) { + if (v[k].sign > 0) w_plus += ranks[k]; + else w_minus += ranks[k]; + } + double w = std::min(w_plus, w_minus); + double mean_w = n * (n + 1) / 4.0; + double sd_w = std::sqrt(n * (n + 1) * (2 * n + 1) / 24.0); + if (sd_w == 0.0) return 1.0; // degenerate (all diffs identical) + + double z = (w - mean_w) / sd_w; + double p_two = std::erfc(std::fabs(z) / std::sqrt(2.0)); // 2‑sided tail + return p_two; + } + + //-------------------------------------------------------- data ---- + std::vector models_; + std::vector datasets_; + json data_; + double alpha_; + + Matrix auc_; // [model][dataset] + std::vector avg_auc_; // mean AUC per model + std::vector avg_rank_; // mean rank per model + std::vector rank_sum_; // helper for ranks + std::vector rank_cnt_; // datasets counted per model + + int control_idx_ = -1; + PostHocResult postHocResult_; + }; + +} // namespace stats +#endif // BEST_WILCOXON_TEST_HPP \ No newline at end of file diff --git a/src/commands/b_grid.cpp b/src/commands/b_grid.cpp index 7e246a5..b1c6244 100644 --- a/src/commands/b_grid.cpp +++ b/src/commands/b_grid.cpp @@ -231,8 +231,8 @@ void experiment(argparse::ArgumentParser& program) { struct platform::ConfigGrid config; auto arguments = platform::ArgumentsExperiment(program, platform::experiment_t::GRID); - auto path_results = arguments.getPathResults(); arguments.parse(); + auto path_results = arguments.getPathResults(); auto grid_experiment = platform::GridExperiment(arguments, config); platform::Timer timer; timer.start(); diff --git a/src/main/ArgumentsExperiment.cpp b/src/main/ArgumentsExperiment.cpp index 8d778ba..58bf990 100644 --- a/src/main/ArgumentsExperiment.cpp +++ b/src/main/ArgumentsExperiment.cpp @@ -215,10 +215,35 @@ namespace platform { test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_json); } } + std::string getGppVersion() + { + std::string result; + std::array buffer; + + // Run g++ --version and capture the output + std::unique_ptr pipe(popen("g++ --version", "r"), pclose); + + if (!pipe) { + return "Error executing g++ --version command"; + } + + // Read the first line of output (which contains the version info) + if (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) { + result = buffer.data(); + // Remove trailing newline if present + if (!result.empty() && result[result.length() - 1] == '\n') { + result.erase(result.length() - 1); + } + } else { + return "No output from g++ --version command"; + } + + return result; + } Experiment& ArgumentsExperiment::initializedExperiment() { auto env = platform::DotEnv(); - experiment.setTitle(title).setLanguage("c++").setLanguageVersion("gcc 14.1.1"); + experiment.setTitle(title).setLanguage("c++").setLanguageVersion(getGppVersion()); experiment.setDiscretizationAlgorithm(discretize_algo).setSmoothSrategy(smooth_strat); experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform")); experiment.setStratified(stratified).setNFolds(n_folds).setScoreName(score); diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 438735d..e240e7c 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -245,8 +245,6 @@ namespace platform { // Train model // clf->fit(X_train, y_train, features, className, states, smooth_type); - if (!quiet) - showProgress(nfold + 1, getColor(clf->getStatus()), "b"); auto clf_notes = clf->getNotes(); std::transform(clf_notes.begin(), clf_notes.end(), std::back_inserter(notes), [nfold](const std::string& note) { return "Fold " + std::to_string(nfold) + ": " + note; }); @@ -259,6 +257,8 @@ namespace platform { // Score train // if (!no_train_score) { + if (!quiet) + showProgress(nfold + 1, getColor(clf->getStatus()), "b"); auto y_proba_train = clf->predict_proba(X_train); Scores scores(y_train, y_proba_train, num_classes, labels); score_train_value = score == score_t::ACCURACY ? scores.accuracy() : scores.auc(); From 473d194ddeb4bbc3231b12841b22df0e0033a868 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sat, 24 May 2025 12:59:28 +0200 Subject: [PATCH 13/27] Complete integration of Wilcoxon test --- src/best/BestResultsExcel.cpp | 22 ++--- src/best/BestResultsMd.cpp | 8 +- src/best/BestResultsMd.h | 2 +- src/best/BestResultsTex.cpp | 8 +- src/best/BestResultsTex.h | 2 +- src/best/Statistics.cpp | 139 +++++++++++++++++--------------- src/best/Statistics.h | 24 +++--- src/best/WilcoxonTest.hpp | 29 +++---- src/reports/DatasetsConsole.cpp | 11 ++- 9 files changed, 128 insertions(+), 117 deletions(-) diff --git a/src/best/BestResultsExcel.cpp b/src/best/BestResultsExcel.cpp index 36cfcb3..71e6cc3 100644 --- a/src/best/BestResultsExcel.cpp +++ b/src/best/BestResultsExcel.cpp @@ -164,13 +164,15 @@ namespace platform { addConditionalFormat("max"); footer(false); if (friedman) { - // Create Sheet with ranks - worksheet = workbook_add_worksheet(workbook, "Ranks"); - formatColumns(); - header(true); - body(true); - addConditionalFormat("min"); - footer(true); + if (score == "accuracy") { + // Create Sheet with ranks + worksheet = workbook_add_worksheet(workbook, "Ranks"); + formatColumns(); + header(true); + body(true); + addConditionalFormat("min"); + footer(true); + } // Create Sheet with Friedman Test doFriedman(); } @@ -246,7 +248,7 @@ namespace platform { stats.postHocTest(); stats.postHocTestReport(result, false); // No tex output auto friedmanResult = stats.getFriedmanResult(); - auto postHocResult = stats.getPostHocResult(); + auto postHocResults = stats.getPostHocResults(); worksheet_merge_range(worksheet, row, 0, row, 7, "Null hypothesis: H0 'There is no significant differences between all the classifiers.'", styles["headerSmall"]); row += 2; writeString(row, 1, "Friedman Q", "bodyHeader"); @@ -265,7 +267,7 @@ namespace platform { row += 2; worksheet_merge_range(worksheet, row, 0, row, 7, "Null hypothesis: H0 'There is no significant differences between the control model and the other models.'", styles["headerSmall"]); row += 2; - std::string controlModel = "Control Model: " + postHocResult.model; + std::string controlModel = "Control Model: " + postHocResults.at(0).model; worksheet_merge_range(worksheet, row, 1, row, 7, controlModel.c_str(), styles["bodyHeader_odd"]); row++; writeString(row, 1, "Model", "bodyHeader"); @@ -277,7 +279,7 @@ namespace platform { writeString(row, 7, "Reject H0", "bodyHeader"); row++; bool first = true; - for (const auto& item : postHocResult.postHocLines) { + for (const auto& item : postHocResults) { writeString(row, 1, item.model, "text"); if (first) { // Control model info diff --git a/src/best/BestResultsMd.cpp b/src/best/BestResultsMd.cpp index 3c70901..195d3f6 100644 --- a/src/best/BestResultsMd.cpp +++ b/src/best/BestResultsMd.cpp @@ -75,7 +75,7 @@ namespace platform { handler.close(); } - void BestResultsMd::postHoc_test(struct PostHocResult& postHocResult, const std::string& kind, const std::string& date) + void BestResultsMd::postHoc_test(std::vector& postHocResults, const std::string& kind, const std::string& date) { auto file_name = Paths::tex() + Paths::md_post_hoc(); openMdFile(file_name); @@ -87,10 +87,12 @@ namespace platform { handler << "Post-hoc " << kind << " test: H0: There is no significant differences between the control model and the other models." << std::endl << std::endl; handler << "| classifier | pvalue | rank | win | tie | loss | H0 |" << std::endl; handler << "| :-- | --: | --: | --:| --: | --: | :--: |" << std::endl; - for (auto const& line : postHocResult.postHocLines) { + bool first = true; + for (auto const& line : postHocResults) { auto textStatus = !line.reject ? "**" : " "; - if (line.model == postHocResult.model) { + if (first) { handler << "| " << line.model << " | - | " << std::fixed << std::setprecision(2) << line.rank << " | - | - | - |" << std::endl; + first = false; } else { handler << "| " << line.model << " | " << textStatus << std::scientific << std::setprecision(4) << line.pvalue << textStatus << " |"; handler << std::fixed << std::setprecision(2) << line.rank << " | " << line.wtl.win << " | " << line.wtl.tie << " | " << line.wtl.loss << " |"; diff --git a/src/best/BestResultsMd.h b/src/best/BestResultsMd.h index 894ae80..8ff2612 100644 --- a/src/best/BestResultsMd.h +++ b/src/best/BestResultsMd.h @@ -14,7 +14,7 @@ namespace platform { void results_header(const std::vector& models, const std::string& date); void results_body(const std::vector& datasets, json& table); void results_footer(const std::map>& totals, const std::string& best_model); - void postHoc_test(struct PostHocResult& postHocResult, const std::string& kind, const std::string& date); + void postHoc_test(std::vector& postHocResults, const std::string& kind, const std::string& date); private: void openMdFile(const std::string& name); std::ofstream handler; diff --git a/src/best/BestResultsTex.cpp b/src/best/BestResultsTex.cpp index cf4827f..afe19ad 100644 --- a/src/best/BestResultsTex.cpp +++ b/src/best/BestResultsTex.cpp @@ -89,7 +89,7 @@ namespace platform { handler << "\\end{table}" << std::endl; handler.close(); } - void BestResultsTex::postHoc_test(struct PostHocResult& postHocResult, const std::string& kind, const std::string& date) + void BestResultsTex::postHoc_test(std::vector& postHocResults, const std::string& kind, const std::string& date) { auto file_name = Paths::tex() + Paths::tex_post_hoc(); openTexFile(file_name); @@ -105,10 +105,12 @@ namespace platform { handler << "\\hline" << std::endl; handler << "classifier & pvalue & rank & win & tie & loss\\\\" << std::endl; handler << "\\hline" << std::endl; - for (auto const& line : postHocResult.postHocLines) { + bool first = true; + for (auto const& line : postHocResults) { auto textStatus = !line.reject ? "\\bf " : " "; - if (line.model == postHocResult.model) { + if (first) { handler << line.model << " & - & " << std::fixed << std::setprecision(2) << line.rank << " & - & - & - \\\\" << std::endl; + first = false; } else { handler << line.model << " & " << textStatus << std::scientific << std::setprecision(4) << line.pvalue << " & "; handler << std::fixed << std::setprecision(2) << line.rank << " & " << line.wtl.win << " & " << line.wtl.tie << " & " << line.wtl.loss << "\\\\" << std::endl; diff --git a/src/best/BestResultsTex.h b/src/best/BestResultsTex.h index e0a82e0..7392d7c 100644 --- a/src/best/BestResultsTex.h +++ b/src/best/BestResultsTex.h @@ -14,7 +14,7 @@ namespace platform { void results_header(const std::vector& models, const std::string& date, bool index); void results_body(const std::vector& datasets, json& table, bool index); void results_footer(const std::map>& totals, const std::string& best_model); - void postHoc_test(struct PostHocResult& postHocResult, const std::string& kind, const std::string& date); + void postHoc_test(std::vector& postHocResults, const std::string& kind, const std::string& date); private: std::string score; bool dataset_name; diff --git a/src/best/Statistics.cpp b/src/best/Statistics.cpp index 0c27cef..04418ce 100644 --- a/src/best/Statistics.cpp +++ b/src/best/Statistics.cpp @@ -111,6 +111,7 @@ namespace platform { } void Statistics::computeWTL() { + const double practical_threshold = 0.0005; // Compute the WTL matrix (Win Tie Loss) for (int i = 0; i < nModels; ++i) { wtl[i] = { 0, 0, 0 }; @@ -124,10 +125,11 @@ namespace platform { continue; } double value = data[models[i]].at(item.key()).at(0).get(); - if (value < controlValue) { - wtl[i].win++; - } else if (value == controlValue) { + double diff = controlValue - value; // control − comparison + if (std::fabs(diff) <= practical_threshold) { wtl[i].tie++; + } else if (diff < 0) { + wtl[i].win++; } else { wtl[i].loss++; } @@ -143,11 +145,11 @@ namespace platform { } void Statistics::postHocTest() { - // if (score == "accuracy") { - postHocHolmTest(); - // } else { - // postHocWilcoxonTest(); - // } + if (score == "accuracy") { + postHocHolmTest(); + } else { + postHocWilcoxonTest(); + } } void Statistics::postHocWilcoxonTest() { @@ -157,7 +159,42 @@ namespace platform { // Reference: Wilcoxon, F. (1945). “Individual Comparisons by Ranking Methods”. Biometrics Bulletin, 1(6), 80-83. auto wilcoxon = WilcoxonTest(models, datasets, data, significance); controlIdx = wilcoxon.getControlIdx(); - postHocResult = wilcoxon.getPostHocResult(); + postHocResults = wilcoxon.getPostHocResults(); + std::cout << std::string(80, '=') << std::endl; + setResultsOrder(); + Holm_Bonferroni(); + restoreResultsOrder(); + } + void Statistics::Holm_Bonferroni() + { + // The algorithm need the p-values sorted from the lowest to the highest + // Sort the models by p-value + std::sort(postHocResults.begin(), postHocResults.end(), [](const PostHocLine& a, const PostHocLine& b) { + return a.pvalue < b.pvalue; + }); + // Holm adjustment + for (int i = 0; i < postHocResults.size(); ++i) { + auto item = postHocResults.at(i); + double before = i == 0 ? 0.0 : postHocResults.at(i - 1).pvalue; + double p_value = std::min((long double)1.0, item.pvalue * (nModels - i)); + p_value = std::max(before, p_value); + postHocResults[i].pvalue = p_value; + } + } + void Statistics::setResultsOrder() + { + int c = 0; + for (auto& item : postHocResults) { + item.idx = c++; + } + + } + void Statistics::restoreResultsOrder() + { + // Restore the order of the results + std::sort(postHocResults.begin(), postHocResults.end(), [](const PostHocLine& a, const PostHocLine& b) { + return a.idx < b.idx; + }); } void Statistics::postHocHolmTest() { @@ -171,38 +208,32 @@ namespace platform { boost::math::normal dist(0.0, 1.0); double diff = sqrt(nModels * (nModels + 1) / (6.0 * nDatasets)); for (int i = 0; i < nModels; i++) { + PostHocLine line; + line.model = models[i]; + line.rank = ranks.at(models[i]); + line.wtl = wtl.at(i); + line.reject = false; if (i == controlIdx) { - stats[i] = 0.0; + postHocResults.push_back(line); continue; } double z = std::abs(ranks.at(models[controlIdx]) - ranks.at(models[i])) / diff; - double p_value = (long double)2 * (1 - cdf(dist, z)); - stats[i] = p_value; + line.pvalue = (long double)2 * (1 - cdf(dist, z)); + line.reject = (line.pvalue < significance); + postHocResults.push_back(line); } - // Sort the models by p-value - for (const auto& stat : stats) { - postHocData.push_back({ stat.first, stat.second }); - } - std::sort(postHocData.begin(), postHocData.end(), [](const std::pair& a, const std::pair& b) { - return a.second < b.second; + std::sort(postHocResults.begin(), postHocResults.end(), [](const PostHocLine& a, const PostHocLine& b) { + return a.rank < b.rank; }); - - // Holm adjustment - for (int i = 0; i < postHocData.size(); ++i) { - auto item = postHocData.at(i); - double before = i == 0 ? 0.0 : postHocData.at(i - 1).second; - double p_value = std::min((double)1.0, item.second * (nModels - i)); - p_value = std::max(before, p_value); - postHocData[i] = { item.first, p_value }; - } - postHocResult.model = models.at(controlIdx); + setResultsOrder(); + Holm_Bonferroni(); + restoreResultsOrder(); } void Statistics::postHocTestReport(bool friedmanResult, bool tex) { std::stringstream oss; - postHocResult.model = models.at(controlIdx); auto color = friedmanResult ? Colors::CYAN() : Colors::YELLOW(); oss << color; oss << " " << std::string(hlen + 25, '*') << std::endl; @@ -210,35 +241,21 @@ namespace platform { oss << " Control model: " << models.at(controlIdx) << std::endl; oss << " " << std::left << std::setw(maxModelName) << std::string("Model") << " p-value rank win tie loss Status" << std::endl; oss << " " << std::string(maxModelName, '=') << " ============ ========= === === ==== =============" << std::endl; - // sort ranks from lowest to highest - std::vector> ranksOrder; - for (const auto& rank : ranks) { - ranksOrder.push_back({ rank.first, rank.second }); - } - std::sort(ranksOrder.begin(), ranksOrder.end(), [](const std::pair& a, const std::pair& b) { - return a.second < b.second; - }); - // Show the control model info. - oss << " " << Colors::BLUE() << std::left << std::setw(maxModelName) << ranksOrder.at(0).first << " "; - oss << std::setw(12) << " " << std::setprecision(7) << std::fixed << " " << ranksOrder.at(0).second << std::endl; - for (const auto& item : ranksOrder) { - auto idx = distance(models.begin(), find(models.begin(), models.end(), item.first)); - double pvalue = 0.0; - for (const auto& stat : postHocData) { - if (stat.first == idx) { - pvalue = stat.second; - } - } - postHocResult.postHocLines.push_back({ item.first, pvalue, item.second, wtl.at(idx), pvalue < significance }); - if (item.first == models.at(controlIdx)) { + bool first = true; + for (const auto& item : postHocResults) { + if (first) { + oss << " " << Colors::BLUE() << std::left << std::setw(maxModelName) << item.model << " "; + oss << std::setw(12) << " " << std::setprecision(7) << std::fixed << " " << item.rank << std::endl; + first = false; continue; } + auto pvalue = item.pvalue; auto colorStatus = pvalue > significance ? Colors::GREEN() : Colors::MAGENTA(); auto status = pvalue > significance ? Symbols::check_mark : Symbols::cross; auto textStatus = pvalue > significance ? " accepted H0" : " rejected H0"; - oss << " " << colorStatus << std::left << std::setw(maxModelName) << item.first << " "; - oss << std::setprecision(6) << std::scientific << pvalue << std::setprecision(7) << std::fixed << " " << item.second; - oss << " " << std::right << std::setw(3) << wtl.at(idx).win << " " << std::setw(3) << wtl.at(idx).tie << " " << std::setw(4) << wtl.at(idx).loss; + oss << " " << colorStatus << std::left << std::setw(maxModelName) << item.model << " "; + oss << std::setprecision(6) << std::scientific << pvalue << std::setprecision(7) << std::fixed << " " << item.rank; + oss << " " << std::right << std::setw(3) << item.wtl.win << " " << std::setw(3) << item.wtl.tie << " " << std::setw(4) << item.wtl.loss; oss << " " << status << textStatus << std::endl; } oss << color << " " << std::string(hlen + 25, '*') << std::endl; @@ -249,8 +266,8 @@ namespace platform { if (tex) { BestResultsTex bestResultsTex(score); BestResultsMd bestResultsMd; - bestResultsTex.postHoc_test(postHocResult, postHocType, get_date() + " " + get_time()); - bestResultsMd.postHoc_test(postHocResult, postHocType, get_date() + " " + get_time()); + bestResultsTex.postHoc_test(postHocResults, postHocType, get_date() + " " + get_time()); + bestResultsMd.postHoc_test(postHocResults, postHocType, get_date() + " " + get_time()); } } bool Statistics::friedmanTest() @@ -294,16 +311,4 @@ namespace platform { friedmanResult = { friedmanQ, criticalValue, p_value, result }; return result; } - FriedmanResult& Statistics::getFriedmanResult() - { - return friedmanResult; - } - PostHocResult& Statistics::getPostHocResult() - { - return postHocResult; - } - std::map>& Statistics::getRanks() - { - return ranksModels; - } } // namespace platform diff --git a/src/best/Statistics.h b/src/best/Statistics.h index 765ed1d..a6b5c4a 100644 --- a/src/best/Statistics.h +++ b/src/best/Statistics.h @@ -9,9 +9,9 @@ namespace platform { using json = nlohmann::ordered_json; struct WTL { - int win; - int tie; - int loss; + uint win; + uint tie; + uint loss; }; struct FriedmanResult { double statistic; @@ -20,16 +20,14 @@ namespace platform { bool reject; }; struct PostHocLine { + uint idx; //index of the main order std::string model; long double pvalue; double rank; WTL wtl; bool reject; }; - struct PostHocResult { - std::string model; - std::vector postHocLines; - }; + class Statistics { public: Statistics(const std::string& score, const std::vector& models, const std::vector& datasets, const json& data, double significance = 0.05, bool output = true); @@ -37,15 +35,18 @@ namespace platform { void postHocTest(); void postHocTestReport(bool friedmanResult, bool tex); int getControlIdx(); - FriedmanResult& getFriedmanResult(); - PostHocResult& getPostHocResult(); - std::map>& getRanks(); + FriedmanResult& getFriedmanResult() { return friedmanResult; } + std::vector& getPostHocResults() { return postHocResults; } + std::map>& getRanks() { return ranksModels; } // ranks of the models per dataset private: void fit(); void postHocHolmTest(); void postHocWilcoxonTest(); void computeRanks(); void computeWTL(); + void Holm_Bonferroni(); + void setResultsOrder(); // Set the order of the results based on the statistic analysis needed + void restoreResultsOrder(); // Restore the order of the results after the Holm-Bonferroni adjustment const std::string& score; std::string postHocType; const std::vector& models; @@ -60,12 +61,11 @@ namespace platform { int greaterAverage = -1; // The model with the greater average score std::map wtl; std::map ranks; - std::vector> postHocData; int maxModelName = 0; int maxDatasetName = 0; int hlen; // length of the line FriedmanResult friedmanResult; - PostHocResult postHocResult; + std::vector postHocResults; std::map> ranksModels; }; } diff --git a/src/best/WilcoxonTest.hpp b/src/best/WilcoxonTest.hpp index 34c2969..dbf1c0c 100644 --- a/src/best/WilcoxonTest.hpp +++ b/src/best/WilcoxonTest.hpp @@ -23,11 +23,8 @@ namespace platform { class WilcoxonTest { public: - WilcoxonTest(const std::vector& models, - const std::vector& datasets, - const json& data, - double alpha = 0.05) - : models_(models), datasets_(datasets), data_(data), alpha_(alpha) + WilcoxonTest(const std::vector& models, const std::vector& datasets, + const json& data, double alpha = 0.05) : models_(models), datasets_(datasets), data_(data), alpha_(alpha) { buildAUCTable(); // extracts all AUCs into a dense matrix computeAverageAUCs(); // per‑model mean (→ control selection) @@ -36,10 +33,8 @@ namespace platform { buildPostHocResult(); // fills postHocResult_ } - //---------------------------------------------------- public API ---- int getControlIdx() const noexcept { return control_idx_; } - - const PostHocResult& getPostHocResult() const noexcept { return postHocResult_; } + const std::vector& getPostHocResults() const noexcept { return postHocResults_; } private: //-------------------------------------------------- helper structs ---- @@ -146,18 +141,14 @@ namespace platform { const std::size_t D = datasets_.size(); const std::string& control_name = models_[control_idx_]; - postHocResult_.model = control_name; - const double practical_threshold = 0.0005; // same heuristic as original code for (std::size_t i = 0; i < M; ++i) { - if (static_cast(i) == control_idx_) continue; - PostHocLine line; line.model = models_[i]; - line.rank = avg_rank_[i]; + line.rank = avg_auc_[i]; - WTL wtl; + WTL wtl = { 0, 0, 0 }; // win, tie, loss std::vector differences; differences.reserve(D); @@ -181,8 +172,12 @@ namespace platform { line.pvalue = differences.empty() ? 1.0L : static_cast(wilcoxonSignedRankTest(differences)); line.reject = (line.pvalue < alpha_); - postHocResult_.postHocLines.push_back(std::move(line)); + postHocResults_.push_back(std::move(line)); } + // Sort results by rank (descending) + std::sort(postHocResults_.begin(), postHocResults_.end(), [](const PostHocLine& a, const PostHocLine& b) { + return a.rank > b.rank; + }); } // ------------------------------------------------ Wilcoxon (private) -- @@ -243,8 +238,8 @@ namespace platform { std::vector rank_cnt_; // datasets counted per model int control_idx_ = -1; - PostHocResult postHocResult_; + std::vector postHocResults_; }; -} // namespace stats +} // namespace platform #endif // BEST_WILCOXON_TEST_HPP \ No newline at end of file diff --git a/src/reports/DatasetsConsole.cpp b/src/reports/DatasetsConsole.cpp index d402a25..82dab4f 100644 --- a/src/reports/DatasetsConsole.cpp +++ b/src/reports/DatasetsConsole.cpp @@ -26,6 +26,7 @@ namespace platform { auto datasets = platform::Datasets(false, platform::Paths::datasets()); std::stringstream sheader; auto datasets_names = datasets.getNames(); + std::cout << Colors::GREEN() << "Datasets available in the platform: " << datasets_names.size() << std::endl; int maxName = std::max(size_t(7), (*max_element(datasets_names.begin(), datasets_names.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size()); std::vector header_labels = { " #", "Dataset", "Sampl.", "Feat.", "#Num.", "Cls", "Balance" }; std::vector header_lengths = { 3, maxName, 6, 6, 6, 3, DatasetsConsole::BALANCE_LENGTH }; @@ -61,9 +62,13 @@ namespace platform { line << setw(header_lengths[5]) << right << nClasses << " "; std::string sep = ""; oss.str(""); - for (auto number : dataset.getClassesCounts()) { - oss << sep << std::setprecision(2) << fixed << (float)number / nSamples * 100.0 << "% (" << number << ")"; - sep = " / "; + if (nSamples == 0) { + oss << "No samples"; + } else { + for (auto number : dataset.getClassesCounts()) { + oss << sep << std::setprecision(2) << fixed << (float)number / nSamples * 100.0 << "% (" << number << ")"; + sep = " / "; + } } split_lines(maxName, line.str(), oss.str()); // Store data for Excel report From a6b6efce95d2ea246997331717555cf8568a0138 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 25 May 2025 10:41:36 +0200 Subject: [PATCH 14/27] Remove uneeded output in Statistics --- src/best/Statistics.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/best/Statistics.cpp b/src/best/Statistics.cpp index 04418ce..1fa16ad 100644 --- a/src/best/Statistics.cpp +++ b/src/best/Statistics.cpp @@ -160,8 +160,11 @@ namespace platform { auto wilcoxon = WilcoxonTest(models, datasets, data, significance); controlIdx = wilcoxon.getControlIdx(); postHocResults = wilcoxon.getPostHocResults(); - std::cout << std::string(80, '=') << std::endl; setResultsOrder(); + // Fill the ranks info + for (const auto& item : postHocResults) { + ranks[item.model] = item.rank; + } Holm_Bonferroni(); restoreResultsOrder(); } From dcde8c01be278ec57050453b71c42db75262cc0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 28 May 2025 10:53:29 +0200 Subject: [PATCH 15/27] ADd std to screen output --- src/best/BestResults.cpp | 26 +++++++++++++++++++------- src/best/BestResults.h | 1 + src/commands/b_best.cpp | 4 +++- src/main/ArgumentsExperiment.cpp | 3 ++- 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/best/BestResults.cpp b/src/best/BestResults.cpp index 5e7c349..4cffc84 100644 --- a/src/best/BestResults.cpp +++ b/src/best/BestResults.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include "common/Colors.h" #include "common/CLocale.h" #include "common/Paths.h" @@ -123,16 +124,24 @@ namespace platform { } result = std::vector(models.begin(), models.end()); maxModelName = (*max_element(result.begin(), result.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size(); - maxModelName = std::max(12, maxModelName); + maxModelName = std::max(minLength, maxModelName); return result; } + std::string toLower(std::string data) + { + std::transform(data.begin(), data.end(), data.begin(), + [](unsigned char c) { return std::tolower(c); }); + return data; + } std::vector BestResults::getDatasets(json table) { std::vector datasets; for (const auto& dataset_ : table.items()) { datasets.push_back(dataset_.key()); } - std::stable_sort(datasets.begin(), datasets.end()); + std::stable_sort(datasets.begin(), datasets.end(), [](const std::string& a, const std::string& b) { + return toLower(a) < toLower(b); + }); maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size(); maxDatasetName = std::max(7, maxDatasetName); return datasets; @@ -266,12 +275,14 @@ namespace platform { // Print the row with red colors on max values for (const auto& model : models) { std::string efectiveColor = color; - double value; + double value, std; try { value = table[model].at(dataset_).at(0).get(); + std = table[model].at(dataset_).at(3).get(); } catch (nlohmann::json_abi_v3_11_3::detail::out_of_range err) { value = -1.0; + std = -1.0; } if (value == maxValue) { efectiveColor = Colors::RED(); @@ -280,7 +291,8 @@ namespace platform { std::cout << Colors::YELLOW() << std::setw(maxModelName) << std::right << "N/A" << " "; } else { totals[model].push_back(value); - std::cout << efectiveColor << std::setw(maxModelName) << std::setprecision(maxModelName - 2) << std::fixed << value << " "; + std::cout << efectiveColor << std::setw(maxModelName - 6) << std::setprecision(maxModelName - 8) << std::fixed << value; + std::cout << efectiveColor << "±" << std::setw(5) << std::setprecision(3) << std::fixed << std << " "; } } std::cout << std::endl; @@ -307,9 +319,9 @@ namespace platform { for (const auto& model : models) { std::string efectiveColor = model == best_model ? Colors::RED() : Colors::GREEN(); double value = std::reduce(totals[model].begin(), totals[model].end()) / nDatasets; - double std_value = compute_std(totals[model], value); - std::cout << efectiveColor << std::right << std::setw(maxModelName) << std::setprecision(maxModelName - 4) << std::fixed << value << " "; - + double std = compute_std(totals[model], value); + std::cout << efectiveColor << std::right << std::setw(maxModelName - 6) << std::setprecision(maxModelName - 8) << std::fixed << value; + std::cout << efectiveColor << "±" << std::setw(5) << std::setprecision(3) << std::fixed << std << " "; } std::cout << std::endl; } diff --git a/src/best/BestResults.h b/src/best/BestResults.h index fde6a74..4d3c307 100644 --- a/src/best/BestResults.h +++ b/src/best/BestResults.h @@ -32,6 +32,7 @@ namespace platform { double significance; int maxModelName = 0; int maxDatasetName = 0; + int minLength = 13; // Minimum length for scores }; } #endif \ No newline at end of file diff --git a/src/commands/b_best.cpp b/src/commands/b_best.cpp index fb4b5cd..aaefbec 100644 --- a/src/commands/b_best.cpp +++ b/src/commands/b_best.cpp @@ -5,14 +5,16 @@ #include "common/Paths.h" #include "common/Colors.h" #include "best/BestResults.h" +#include "common/DotEnv.h" #include "config_platform.h" void manageArguments(argparse::ArgumentParser& program) { + auto env = platform::DotEnv(); program.add_argument("-m", "--model").help("Model to use or any").default_value("any"); program.add_argument("--folder").help("Results folder to use").default_value(platform::Paths::results()); program.add_argument("-d", "--dataset").default_value("any").help("Filter results of the selected model) (any for all datasets)"); - program.add_argument("-s", "--score").default_value("accuracy").help("Filter results of the score name supplied"); + program.add_argument("-s", "--score").default_value(env.get("score")).help("Filter results of the score name supplied"); program.add_argument("--friedman").help("Friedman test").default_value(false).implicit_value(true); program.add_argument("--excel").help("Output to excel").default_value(false).implicit_value(true); program.add_argument("--tex").help("Output results to TeX & Markdown files").default_value(false).implicit_value(true); diff --git a/src/main/ArgumentsExperiment.cpp b/src/main/ArgumentsExperiment.cpp index 58bf990..d27f9a3 100644 --- a/src/main/ArgumentsExperiment.cpp +++ b/src/main/ArgumentsExperiment.cpp @@ -221,7 +221,8 @@ namespace platform { std::array buffer; // Run g++ --version and capture the output - std::unique_ptr pipe(popen("g++ --version", "r"), pclose); + using pclose_t = int(*)(FILE*); + std::unique_ptr pipe(popen("g++ --version", "r"), pclose); if (!pipe) { return "Error executing g++ --version command"; From 514968a0820a3f95312b265dd3d2df27d594f921 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 28 May 2025 17:37:53 +0200 Subject: [PATCH 16/27] Open excel file automatically when generated --- src/best/BestResults.cpp | 2 + src/best/BestResults.h | 2 + src/commands/b_best.cpp | 6 +++ src/commands/b_list.cpp | 21 +++++++--- src/commands/b_manage.cpp | 64 ++--------------------------- src/common/Paths.h | 2 + src/common/Utils.h | 61 +++++++++++++++++++++++++++ src/reports/DatasetsExcel.cpp | 3 +- src/reports/DatasetsExcel.h | 1 + src/results/ResultsDatasetExcel.cpp | 3 +- src/results/ResultsDatasetExcel.h | 1 + 11 files changed, 97 insertions(+), 69 deletions(-) diff --git a/src/best/BestResults.cpp b/src/best/BestResults.cpp index 4cffc84..bd7f82a 100644 --- a/src/best/BestResults.cpp +++ b/src/best/BestResults.cpp @@ -336,6 +336,7 @@ namespace platform { BestResultsExcel excel_report(path, score, datasets); excel_report.reportSingle(model, path + Paths::bestResultsFile(score, model)); messageOutputFile("Excel", excel_report.getFileName()); + excelFileName = excel_report.getFileName(); } } void BestResults::reportAll(bool excel, bool tex, bool index) @@ -373,6 +374,7 @@ namespace platform { excel.reportSingle(model, path + Paths::bestResultsFile(score, model)); } messageOutputFile("Excel", excel.getFileName()); + excelFileName = excel.getFileName(); } } void BestResults::messageOutputFile(const std::string& title, const std::string& fileName) diff --git a/src/best/BestResults.h b/src/best/BestResults.h index 4d3c307..6d1af54 100644 --- a/src/best/BestResults.h +++ b/src/best/BestResults.h @@ -15,6 +15,7 @@ namespace platform { void reportSingle(bool excel); void reportAll(bool excel, bool tex, bool index); void buildAll(); + std::string getExcelFileName() const { return excelFileName; } private: std::vector getModels(); std::vector getDatasets(json table); @@ -33,6 +34,7 @@ namespace platform { int maxModelName = 0; int maxDatasetName = 0; int minLength = 13; // Minimum length for scores + std::string excelFileName; }; } #endif \ No newline at end of file diff --git a/src/commands/b_best.cpp b/src/commands/b_best.cpp index aaefbec..8c8b89e 100644 --- a/src/commands/b_best.cpp +++ b/src/commands/b_best.cpp @@ -4,6 +4,7 @@ #include "main/modelRegister.h" #include "common/Paths.h" #include "common/Colors.h" +#include "common/Utils.h" #include "best/BestResults.h" #include "common/DotEnv.h" #include "config_platform.h" @@ -80,6 +81,11 @@ int main(int argc, char** argv) std::cout << Colors::GREEN() << fileName << " created!" << Colors::RESET() << std::endl; results.reportSingle(excel); } + if (excel) { + auto fileName = results.getExcelFileName(); + std::cout << "Opening " << fileName << std::endl; + platform::openFile(fileName); + } std::cout << Colors::RESET(); return 0; } diff --git a/src/commands/b_list.cpp b/src/commands/b_list.cpp index 5309101..96950b0 100644 --- a/src/commands/b_list.cpp +++ b/src/commands/b_list.cpp @@ -8,6 +8,7 @@ #include "common/Paths.h" #include "common/Colors.h" #include "common/Datasets.h" +#include "common/Utils.h" #include "reports/DatasetsExcel.h" #include "reports/DatasetsConsole.h" #include "results/ResultsDatasetConsole.h" @@ -24,9 +25,13 @@ void list_datasets(argparse::ArgumentParser& program) std::cout << report.getOutput(); if (excel) { auto data = report.getData(); - auto report = platform::DatasetsExcel(); - report.report(data); - std::cout << std::endl << Colors::GREEN() << "Output saved in " << report.getFileName() << std::endl; + auto ereport = new platform::DatasetsExcel(); + ereport->report(data); + std::cout << std::endl << Colors::GREEN() << "Output saved in " << ereport->getFileName() << std::endl; + auto fileName = ereport->getExcelFileName(); + delete ereport; + std::cout << "Opening " << fileName << std::endl; + platform::openFile(fileName); } } @@ -42,9 +47,13 @@ void list_results(argparse::ArgumentParser& program) std::cout << report.getOutput(); if (excel) { auto data = report.getData(); - auto report = platform::ResultsDatasetExcel(); - report.report(data); - std::cout << std::endl << Colors::GREEN() << "Output saved in " << report.getFileName() << std::endl; + auto ereport = new platform::ResultsDatasetExcel(); + ereport->report(data); + std::cout << std::endl << Colors::GREEN() << "Output saved in " << ereport->getFileName() << std::endl; + auto fileName = ereport->getExcelFileName(); + delete ereport; + std::cout << "Opening " << fileName << std::endl; + platform::openFile(fileName); } } diff --git a/src/commands/b_manage.cpp b/src/commands/b_manage.cpp index 1d88ca2..8a0deb4 100644 --- a/src/commands/b_manage.cpp +++ b/src/commands/b_manage.cpp @@ -1,7 +1,7 @@ + +#include #include #include -#include -#include #include "common/Paths.h" #include #include "manage/ManageScreen.h" @@ -53,65 +53,7 @@ void handleResize(int sig) manager->updateSize(rows, cols); } -void openFile(const std::string& fileName) -{ - // #ifdef __APPLE__ - // // macOS uses the "open" command - // std::string command = "open"; - // #elif defined(__linux__) - // // Linux typically uses "xdg-open" - // std::string command = "xdg-open"; - // #else - // // For other OSes, do nothing or handle differently - // std::cerr << "Unsupported platform." << std::endl; - // return; - // #endif - // execlp(command.c_str(), command.c_str(), fileName.c_str(), NULL); -#ifdef __APPLE__ - const char* tool = "/usr/bin/open"; -#elif defined(__linux__) - const char* tool = "/usr/bin/xdg-open"; -#else - std::cerr << "Unsupported platform." << std::endl; - return; -#endif - // We'll build an argv array for execve: - std::vector argv; - argv.push_back(const_cast(tool)); // argv[0] - argv.push_back(const_cast(fileName.c_str())); // argv[1] - argv.push_back(nullptr); - - // Make a new environment array, skipping BASH_FUNC_ variables - std::vector filteredEnv; - for (char** env = environ; *env != nullptr; ++env) { - // *env is a string like "NAME=VALUE" - // We want to skip those starting with "BASH_FUNC_" - if (strncmp(*env, "BASH_FUNC_", 10) == 0) { - // skip it - continue; - } - filteredEnv.push_back(*env); - } - - // Convert filteredEnv into a char* array - std::vector envp; - for (auto& var : filteredEnv) { - envp.push_back(const_cast(var.c_str())); - } - envp.push_back(nullptr); - - // Now call execve with the cleaned environment - // NOTE: You may need a full path to the tool if it's not in PATH, or use which() logic - // For now, let's assume "open" or "xdg-open" is found in the default PATH: - execve(tool, argv.data(), envp.data()); - - // If we reach here, execve failed - perror("execve failed"); - // This would terminate your current process if it's not in a child - // Usually you'd do something like: - _exit(EXIT_FAILURE); -} int main(int argc, char** argv) { @@ -137,7 +79,7 @@ int main(int argc, char** argv) delete manager; if (!fileName.empty()) { std::cout << "Opening " << fileName << std::endl; - openFile(fileName); + platform::openFile(fileName); } return 0; } diff --git a/src/common/Paths.h b/src/common/Paths.h index 15a42d1..6861457 100644 --- a/src/common/Paths.h +++ b/src/common/Paths.h @@ -49,6 +49,7 @@ namespace platform { return "BestResults_" + score + ".xlsx"; } static std::string excelResults() { return "some_results.xlsx"; } + static std::string excelDatasets() { return "datasets.xlsx"; } static std::string grid_input(const std::string& model) { return grid() + "grid_" + model + "_input.json"; @@ -73,6 +74,7 @@ namespace platform { { return "post_hoc.md"; } + }; } #endif \ No newline at end of file diff --git a/src/common/Utils.h b/src/common/Utils.h index 92e5a4a..e371d89 100644 --- a/src/common/Utils.h +++ b/src/common/Utils.h @@ -1,5 +1,7 @@ #ifndef UTILS_H #define UTILS_H + +#include #include #include #include @@ -66,5 +68,64 @@ namespace platform { oss << std::put_time(timeinfo, "%H:%M:%S"); return oss.str(); } + static void openFile(const std::string& fileName) + { + // #ifdef __APPLE__ + // // macOS uses the "open" command + // std::string command = "open"; + // #elif defined(__linux__) + // // Linux typically uses "xdg-open" + // std::string command = "xdg-open"; + // #else + // // For other OSes, do nothing or handle differently + // std::cerr << "Unsupported platform." << std::endl; + // return; + // #endif + // execlp(command.c_str(), command.c_str(), fileName.c_str(), NULL); +#ifdef __APPLE__ + const char* tool = "/usr/bin/open"; +#elif defined(__linux__) + const char* tool = "/usr/bin/xdg-open"; +#else + std::cerr << "Unsupported platform." << std::endl; + return; +#endif + + // We'll build an argv array for execve: + std::vector argv; + argv.push_back(const_cast(tool)); // argv[0] + argv.push_back(const_cast(fileName.c_str())); // argv[1] + argv.push_back(nullptr); + + // Make a new environment array, skipping BASH_FUNC_ variables + std::vector filteredEnv; + for (char** env = environ; *env != nullptr; ++env) { + // *env is a string like "NAME=VALUE" + // We want to skip those starting with "BASH_FUNC_" + if (strncmp(*env, "BASH_FUNC_", 10) == 0) { + // skip it + continue; + } + filteredEnv.push_back(*env); + } + + // Convert filteredEnv into a char* array + std::vector envp; + for (auto& var : filteredEnv) { + envp.push_back(const_cast(var.c_str())); + } + envp.push_back(nullptr); + + // Now call execve with the cleaned environment + // NOTE: You may need a full path to the tool if it's not in PATH, or use which() logic + // For now, let's assume "open" or "xdg-open" is found in the default PATH: + execve(tool, argv.data(), envp.data()); + + // If we reach here, execve failed + perror("execve failed"); + // This would terminate your current process if it's not in a child + // Usually you'd do something like: + _exit(EXIT_FAILURE); + } } #endif \ No newline at end of file diff --git a/src/reports/DatasetsExcel.cpp b/src/reports/DatasetsExcel.cpp index a24def6..cb1dd37 100644 --- a/src/reports/DatasetsExcel.cpp +++ b/src/reports/DatasetsExcel.cpp @@ -1,8 +1,9 @@ +#include "common/Paths.h" #include "DatasetsExcel.h" namespace platform { DatasetsExcel::DatasetsExcel() { - file_name = "datasets.xlsx"; + file_name = Paths::excelDatasets(); workbook = workbook_new(getFileName().c_str()); createFormats(); setProperties("Datasets"); diff --git a/src/reports/DatasetsExcel.h b/src/reports/DatasetsExcel.h index cd543cd..4b528b3 100644 --- a/src/reports/DatasetsExcel.h +++ b/src/reports/DatasetsExcel.h @@ -11,6 +11,7 @@ namespace platform { DatasetsExcel(); ~DatasetsExcel(); void report(json& data); + std::string getExcelFileName() { return getFileName(); } }; } #endif \ No newline at end of file diff --git a/src/results/ResultsDatasetExcel.cpp b/src/results/ResultsDatasetExcel.cpp index b4f1225..df1604e 100644 --- a/src/results/ResultsDatasetExcel.cpp +++ b/src/results/ResultsDatasetExcel.cpp @@ -1,8 +1,9 @@ +#include "common/Paths.h" #include "ResultsDatasetExcel.h" namespace platform { ResultsDatasetExcel::ResultsDatasetExcel() { - file_name = "some_results.xlsx"; + file_name = Paths::excelResults(); workbook = workbook_new(getFileName().c_str()); createFormats(); setProperties("Results"); diff --git a/src/results/ResultsDatasetExcel.h b/src/results/ResultsDatasetExcel.h index 83226cc..3f9b968 100644 --- a/src/results/ResultsDatasetExcel.h +++ b/src/results/ResultsDatasetExcel.h @@ -12,6 +12,7 @@ namespace platform { ResultsDatasetExcel(); ~ResultsDatasetExcel(); void report(json& data); + std::string getExcelFileName() { return getFileName(); } }; } #endif \ No newline at end of file From 3b158e9fc1eceae30b11cd525e6237b543aa8830 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 15 Jun 2025 12:07:12 +0200 Subject: [PATCH 17/27] Add AdaBoost --- src/main/Models.h | 1 + src/main/modelRegister.h | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/main/Models.h b/src/main/Models.h index 565a96d..076a926 100644 --- a/src/main/Models.h +++ b/src/main/Models.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "../experimental_clfs/XA1DE.h" diff --git a/src/main/modelRegister.h b/src/main/modelRegister.h index 0dbd269..22a800c 100644 --- a/src/main/modelRegister.h +++ b/src/main/modelRegister.h @@ -35,6 +35,8 @@ namespace platform { [](void) -> bayesnet::BaseClassifier* { return new pywrap::RandomForest();}); static Registrar registrarXGB("XGBoost", [](void) -> bayesnet::BaseClassifier* { return new pywrap::XGBoost();}); + static Registrar registrarAda("AdaBoost", + [](void) -> bayesnet::BaseClassifier* { return new pywrap::AdaBoost();}); static Registrar registrarXSPODE("XSPODE", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::XSpode(0);}); static Registrar registrarXSP2DE("XSP2DE", @@ -44,6 +46,6 @@ namespace platform { static Registrar registrarXBA2DE("XBA2DE", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::XBA2DE();}); static Registrar registrarXA1DE("XA1DE", - [](void) -> bayesnet::BaseClassifier* { return new XA1DE();}); + [](void) -> bayesnet::BaseClassifier* { return new XA1DE();}); } #endif From 8c413a1eb0a0602811e12456fd6c03a0bafed8a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 16 Jun 2025 00:11:51 +0200 Subject: [PATCH 18/27] Begin to add AdaBoost implementation --- src/CMakeLists.txt | 2 + src/experimental_clfs/AdaBoost.cpp | 214 +++++++++++++++++++++++++++++ src/experimental_clfs/AdaBoost.h | 58 ++++++++ src/main/Models.h | 1 + src/main/modelRegister.h | 4 +- 5 files changed, 278 insertions(+), 1 deletion(-) create mode 100644 src/experimental_clfs/AdaBoost.cpp create mode 100644 src/experimental_clfs/AdaBoost.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b89cebc..1f719c1 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -33,6 +33,7 @@ add_executable(b_grid commands/b_grid.cpp ${grid_sources} results/Result.cpp experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp + experimental_clfs/AdaBoost.cpp ) target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy) @@ -56,6 +57,7 @@ add_executable(b_main commands/b_main.cpp ${main_sources} results/Result.cpp experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp + experimental_clfs/ExpClf.cpp ) target_link_libraries(b_main PRIVATE nlohmann_json::nlohmann_json "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy) diff --git a/src/experimental_clfs/AdaBoost.cpp b/src/experimental_clfs/AdaBoost.cpp new file mode 100644 index 0000000..63523d7 --- /dev/null +++ b/src/experimental_clfs/AdaBoost.cpp @@ -0,0 +1,214 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#include "AdaBoost.h" +#include +#include +#include +#include +#include + +namespace platform { + + AdaBoost::AdaBoost(int n_estimators) + : Ensemble(true), n_estimators(n_estimators) + { + validHyperparameters = { "n_estimators" }; + } + + void AdaBoost::buildModel(const torch::Tensor& weights) + { + // Initialize variables + models.clear(); + alphas.clear(); + training_errors.clear(); + + // Initialize sample weights uniformly + int n_samples = dataset.size(1); + sample_weights = torch::ones({ n_samples }) / n_samples; + + // If initial weights are provided, incorporate them + if (weights.defined() && weights.numel() > 0) { + sample_weights *= weights; + normalizeWeights(); + } + + // Main AdaBoost training loop (SAMME algorithm) + for (int iter = 0; iter < n_estimators; ++iter) { + // Train base estimator with current sample weights + auto estimator = trainBaseEstimator(sample_weights); + + // Calculate weighted error + double weighted_error = calculateWeightedError(estimator.get(), sample_weights); + training_errors.push_back(weighted_error); + + // Check if error is too high (worse than random guessing) + double random_guess_error = 1.0 - (1.0 / getClassNumStates()); + if (weighted_error >= random_guess_error) { + // If only one estimator and it's worse than random, keep it with zero weight + if (models.empty()) { + models.push_back(std::move(estimator)); + alphas.push_back(0.0); + } + break; // Stop boosting + } + + // Calculate alpha (estimator weight) using SAMME formula + // alpha = log((1 - err) / err) + log(K - 1) + double alpha = std::log((1.0 - weighted_error) / weighted_error) + + std::log(static_cast(getClassNumStates() - 1)); + + // Store the estimator and its weight + models.push_back(std::move(estimator)); + alphas.push_back(alpha); + + // Update sample weights + updateSampleWeights(models.back().get(), alpha); + + // Normalize weights + normalizeWeights(); + + // Check for perfect classification + if (weighted_error < 1e-10) { + break; + } + } + + // Set the number of models actually trained + n_models = models.size(); + } + + void AdaBoost::trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) + { + // AdaBoost handles its own weight management, so we just build the model + buildModel(weights); + } + + std::unique_ptr AdaBoost::trainBaseEstimator(const torch::Tensor& weights) + { + // Create a new classifier instance + // You need to implement this based on your specific base classifier + // For example, if using Decision Trees: + // auto classifier = std::make_unique(); + + // Or if using a factory method: + // auto classifier = ClassifierFactory::create("DecisionTree"); + + // Placeholder - replace with actual classifier creation + throw std::runtime_error("AdaBoost::trainBaseEstimator - You need to implement base classifier creation"); + + // Once you have the classifier creation implemented, uncomment: + // classifier->fit(dataset, features, className, states, weights, Smoothing_t::NONE); + // return classifier; + } + + double AdaBoost::calculateWeightedError(Classifier* estimator, const torch::Tensor& weights) + { + // Get predictions from the estimator + auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }); + auto y_true = dataset.index({ -1, torch::indexing::Slice() }); + auto y_pred = estimator->predict(X.t()); + + // Calculate weighted error + auto incorrect = (y_pred != y_true).to(torch::kFloat); + double weighted_error = torch::sum(incorrect * weights).item(); + + return weighted_error; + } + + void AdaBoost::updateSampleWeights(Classifier* estimator, double alpha) + { + // Get predictions from the estimator + auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }); + auto y_true = dataset.index({ -1, torch::indexing::Slice() }); + auto y_pred = estimator->predict(X.t()); + + // Update weights according to SAMME algorithm + // w_i = w_i * exp(alpha * I(y_i != y_pred_i)) + auto incorrect = (y_pred != y_true).to(torch::kFloat); + sample_weights *= torch::exp(alpha * incorrect); + } + + void AdaBoost::normalizeWeights() + { + // Normalize weights to sum to 1 + double sum_weights = torch::sum(sample_weights).item(); + if (sum_weights > 0) { + sample_weights /= sum_weights; + } + } + + std::vector AdaBoost::graph(const std::string& title) const + { + // Create a graph representation of the AdaBoost ensemble + std::vector graph_lines; + + // Header + graph_lines.push_back("digraph AdaBoost {"); + graph_lines.push_back(" rankdir=TB;"); + graph_lines.push_back(" node [shape=box];"); + + if (!title.empty()) { + graph_lines.push_back(" label=\"" + title + "\";"); + graph_lines.push_back(" labelloc=t;"); + } + + // Add input node + graph_lines.push_back(" Input [shape=ellipse, label=\"Input Features\"];"); + + // Add base estimators + for (size_t i = 0; i < models.size(); ++i) { + std::stringstream ss; + ss << " Estimator" << i << " [label=\"Base Estimator " << i + 1 + << "\\nα = " << std::fixed << std::setprecision(3) << alphas[i] << "\"];"; + graph_lines.push_back(ss.str()); + + // Connect input to estimator + ss.str(""); + ss << " Input -> Estimator" << i << ";"; + graph_lines.push_back(ss.str()); + } + + // Add combination node + graph_lines.push_back(" Combination [shape=diamond, label=\"Weighted Vote\"];"); + + // Connect estimators to combination + for (size_t i = 0; i < models.size(); ++i) { + std::stringstream ss; + ss << " Estimator" << i << " -> Combination;"; + graph_lines.push_back(ss.str()); + } + + // Add output node + graph_lines.push_back(" Output [shape=ellipse, label=\"Final Prediction\"];"); + graph_lines.push_back(" Combination -> Output;"); + + // Close graph + graph_lines.push_back("}"); + + return graph_lines; + } + + void AdaBoost::setHyperparameters(const nlohmann::json& hyperparameters) + { + // Set hyperparameters from JSON + auto it = hyperparameters.find("n_estimators"); + if (it != hyperparameters.end()) { + n_estimators = it->get(); + if (n_estimators <= 0) { + throw std::invalid_argument("n_estimators must be positive"); + } + } + + // Check for invalid hyperparameters + for (auto& [key, value] : hyperparameters.items()) { + if (std::find(validHyperparameters.begin(), validHyperparameters.end(), key) == validHyperparameters.end()) { + throw std::invalid_argument("Invalid hyperparameter: " + key); + } + } + } + +} // namespace bayesnet \ No newline at end of file diff --git a/src/experimental_clfs/AdaBoost.h b/src/experimental_clfs/AdaBoost.h new file mode 100644 index 0000000..59ac241 --- /dev/null +++ b/src/experimental_clfs/AdaBoost.h @@ -0,0 +1,58 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#ifndef ADABOOST_H +#define ADABOOST_H + +#include +#include +#include +#include + +namespace platform { + class AdaBoost : public bayesnet::Ensemble { + public: + explicit AdaBoost(int n_estimators = 100); + virtual ~AdaBoost() = default; + + // Override base class methods + std::vector graph(const std::string& title = "") const override; + + // AdaBoost specific methods + void setNEstimators(int n_estimators) { this->n_estimators = n_estimators; } + int getNEstimators() const { return n_estimators; } + + // Get the weight of each base estimator + std::vector getEstimatorWeights() const { return alphas; } + + // Override setHyperparameters from BaseClassifier + void setHyperparameters(const nlohmann::json& hyperparameters) override; + + protected: + void buildModel(const torch::Tensor& weights) override; + void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override; + + private: + int n_estimators; + std::vector alphas; // Weight of each base estimator + std::vector training_errors; // Training error at each iteration + torch::Tensor sample_weights; // Current sample weights + + // Train a single base estimator + std::unique_ptr trainBaseEstimator(const torch::Tensor& weights); + + // Calculate weighted error + double calculateWeightedError(Classifier* estimator, const torch::Tensor& weights); + + // Update sample weights based on predictions + void updateSampleWeights(Classifier* estimator, double alpha); + + // Normalize weights to sum to 1 + void normalizeWeights(); + }; +} + +#endif // ADABOOST_H \ No newline at end of file diff --git a/src/main/Models.h b/src/main/Models.h index 076a926..d6cf449 100644 --- a/src/main/Models.h +++ b/src/main/Models.h @@ -26,6 +26,7 @@ #include #include #include "../experimental_clfs/XA1DE.h" +#include "../experimental_clfs/AdaBoost.h" namespace platform { class Models { diff --git a/src/main/modelRegister.h b/src/main/modelRegister.h index 22a800c..11ebc84 100644 --- a/src/main/modelRegister.h +++ b/src/main/modelRegister.h @@ -35,8 +35,10 @@ namespace platform { [](void) -> bayesnet::BaseClassifier* { return new pywrap::RandomForest();}); static Registrar registrarXGB("XGBoost", [](void) -> bayesnet::BaseClassifier* { return new pywrap::XGBoost();}); - static Registrar registrarAda("AdaBoost", + static Registrar registrarAda("AdaBoostPy", [](void) -> bayesnet::BaseClassifier* { return new pywrap::AdaBoost();}); + // static Registrar registrarAda2("AdaBoost", + // [](void) -> bayesnet::BaseClassifier* { return new platform::AdaBoost();}); static Registrar registrarXSPODE("XSPODE", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::XSpode(0);}); static Registrar registrarXSP2DE("XSP2DE", From 023d5613b4a81f93950fb0dee0e658bf476b1538 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Tue, 17 Jun 2025 13:48:11 +0200 Subject: [PATCH 19/27] Add DecisionTree with tests --- Makefile | 1 + src/CMakeLists.txt | 8 +- src/experimental_clfs/AdaBoost.cpp | 42 +- src/experimental_clfs/AdaBoost.h | 15 +- src/experimental_clfs/DecisionTree.cpp | 519 +++++++++++++++++++++++++ src/experimental_clfs/DecisionTree.h | 129 ++++++ src/experimental_clfs/README.md | 142 +++++++ src/experimental_clfs/TensorUtils.hpp | 13 + src/main/Models.h | 5 +- src/main/modelRegister.h | 10 +- tests/CMakeLists.txt | 6 +- tests/TestDecisionTree.cpp | 303 +++++++++++++++ tests/TestPlatform.cpp | 8 +- tests/TestScores.cpp | 143 +++++-- tests/TestUtils.h | 2 +- vcpkg.json | 7 +- 16 files changed, 1272 insertions(+), 81 deletions(-) create mode 100644 src/experimental_clfs/DecisionTree.cpp create mode 100644 src/experimental_clfs/DecisionTree.h create mode 100644 src/experimental_clfs/README.md create mode 100644 tests/TestDecisionTree.cpp diff --git a/Makefile b/Makefile index fa659c8..aff3116 100644 --- a/Makefile +++ b/Makefile @@ -96,6 +96,7 @@ opt = "" test: ## Run tests (opt="-s") to verbose output the tests, (opt="-c='Test Maximum Spanning Tree'") to run only that section @echo ">>> Running Platform tests..."; @$(MAKE) clean + @$(MAKE) debug @cmake --build $(f_debug) -t $(test_targets) --parallel @for t in $(test_targets); do \ if [ -f $(f_debug)/tests/$$t ]; then \ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1f719c1..54bc2de 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -20,6 +20,8 @@ add_executable( results/Result.cpp experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp + experimental_clfs/DecisionTree.cpp + ) target_link_libraries(b_best Boost::boost "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}") @@ -33,7 +35,7 @@ add_executable(b_grid commands/b_grid.cpp ${grid_sources} results/Result.cpp experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp - experimental_clfs/AdaBoost.cpp + experimental_clfs/DecisionTree.cpp ) target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy) @@ -45,6 +47,8 @@ add_executable(b_list commands/b_list.cpp results/Result.cpp results/ResultsDatasetExcel.cpp results/ResultsDataset.cpp results/ResultsDatasetConsole.cpp experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp + experimental_clfs/DecisionTree.cpp + ) target_link_libraries(b_list "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}") @@ -58,6 +62,8 @@ add_executable(b_main commands/b_main.cpp ${main_sources} experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp experimental_clfs/ExpClf.cpp + experimental_clfs/DecisionTree.cpp + ) target_link_libraries(b_main PRIVATE nlohmann_json::nlohmann_json "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy) diff --git a/src/experimental_clfs/AdaBoost.cpp b/src/experimental_clfs/AdaBoost.cpp index 63523d7..ff5c40e 100644 --- a/src/experimental_clfs/AdaBoost.cpp +++ b/src/experimental_clfs/AdaBoost.cpp @@ -5,18 +5,19 @@ // *************************************************************** #include "AdaBoost.h" +#include "DecisionTree.h" #include #include #include #include #include -namespace platform { +namespace bayesnet { - AdaBoost::AdaBoost(int n_estimators) - : Ensemble(true), n_estimators(n_estimators) + AdaBoost::AdaBoost(int n_estimators, int max_depth) + : Ensemble(true), n_estimators(n_estimators), base_max_depth(max_depth) { - validHyperparameters = { "n_estimators" }; + validHyperparameters = { "n_estimators", "base_max_depth" }; } void AdaBoost::buildModel(const torch::Tensor& weights) @@ -89,20 +90,14 @@ namespace platform { std::unique_ptr AdaBoost::trainBaseEstimator(const torch::Tensor& weights) { - // Create a new classifier instance - // You need to implement this based on your specific base classifier - // For example, if using Decision Trees: - // auto classifier = std::make_unique(); + // Create a decision tree with specified max depth + // For AdaBoost, we typically use shallow trees (stumps with max_depth=1) + auto tree = std::make_unique(base_max_depth); - // Or if using a factory method: - // auto classifier = ClassifierFactory::create("DecisionTree"); + // Fit the tree with the current sample weights + tree->fit(dataset, features, className, states, weights, Smoothing_t::NONE); - // Placeholder - replace with actual classifier creation - throw std::runtime_error("AdaBoost::trainBaseEstimator - You need to implement base classifier creation"); - - // Once you have the classifier creation implemented, uncomment: - // classifier->fit(dataset, features, className, states, weights, Smoothing_t::NONE); - // return classifier; + return tree; } double AdaBoost::calculateWeightedError(Classifier* estimator, const torch::Tensor& weights) @@ -192,8 +187,9 @@ namespace platform { return graph_lines; } - void AdaBoost::setHyperparameters(const nlohmann::json& hyperparameters) + void AdaBoost::setHyperparameters(const nlohmann::json& hyperparameters_) { + auto hyperparameters = hyperparameters_; // Set hyperparameters from JSON auto it = hyperparameters.find("n_estimators"); if (it != hyperparameters.end()) { @@ -201,14 +197,18 @@ namespace platform { if (n_estimators <= 0) { throw std::invalid_argument("n_estimators must be positive"); } + hyperparameters.erase("n_estimators"); // Remove 'n_estimators' if present } - // Check for invalid hyperparameters - for (auto& [key, value] : hyperparameters.items()) { - if (std::find(validHyperparameters.begin(), validHyperparameters.end(), key) == validHyperparameters.end()) { - throw std::invalid_argument("Invalid hyperparameter: " + key); + it = hyperparameters.find("base_max_depth"); + if (it != hyperparameters.end()) { + base_max_depth = it->get(); + if (base_max_depth <= 0) { + throw std::invalid_argument("base_max_depth must be positive"); } + hyperparameters.erase("base_max_depth"); // Remove 'base_max_depth' if present } + Ensemble::setHyperparameters(hyperparameters); } } // namespace bayesnet \ No newline at end of file diff --git a/src/experimental_clfs/AdaBoost.h b/src/experimental_clfs/AdaBoost.h index 59ac241..c9e4ede 100644 --- a/src/experimental_clfs/AdaBoost.h +++ b/src/experimental_clfs/AdaBoost.h @@ -9,13 +9,12 @@ #include #include -#include -#include +#include "bayesnet/ensembles/Ensemble.h" -namespace platform { - class AdaBoost : public bayesnet::Ensemble { +namespace bayesnet { + class AdaBoost : public Ensemble { public: - explicit AdaBoost(int n_estimators = 100); + explicit AdaBoost(int n_estimators = 50, int max_depth = 1); virtual ~AdaBoost() = default; // Override base class methods @@ -24,10 +23,15 @@ namespace platform { // AdaBoost specific methods void setNEstimators(int n_estimators) { this->n_estimators = n_estimators; } int getNEstimators() const { return n_estimators; } + void setBaseMaxDepth(int depth) { this->base_max_depth = depth; } + int getBaseMaxDepth() const { return base_max_depth; } // Get the weight of each base estimator std::vector getEstimatorWeights() const { return alphas; } + // Get training errors for each iteration + std::vector getTrainingErrors() const { return training_errors; } + // Override setHyperparameters from BaseClassifier void setHyperparameters(const nlohmann::json& hyperparameters) override; @@ -37,6 +41,7 @@ namespace platform { private: int n_estimators; + int base_max_depth; // Max depth for base decision trees std::vector alphas; // Weight of each base estimator std::vector training_errors; // Training error at each iteration torch::Tensor sample_weights; // Current sample weights diff --git a/src/experimental_clfs/DecisionTree.cpp b/src/experimental_clfs/DecisionTree.cpp new file mode 100644 index 0000000..c615504 --- /dev/null +++ b/src/experimental_clfs/DecisionTree.cpp @@ -0,0 +1,519 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#include "DecisionTree.h" +#include +#include +#include +#include +#include +#include "TensorUtils.hpp" + +namespace bayesnet { + + DecisionTree::DecisionTree(int max_depth, int min_samples_split, int min_samples_leaf) + : Classifier(Network()), max_depth(max_depth), + min_samples_split(min_samples_split), min_samples_leaf(min_samples_leaf) + { + validHyperparameters = { "max_depth", "min_samples_split", "min_samples_leaf" }; + } + + void DecisionTree::setHyperparameters(const nlohmann::json& hyperparameters_) + { + auto hyperparameters = hyperparameters_; + // Set hyperparameters from JSON + auto it = hyperparameters.find("max_depth"); + if (it != hyperparameters.end()) { + max_depth = it->get(); + hyperparameters.erase("max_depth"); // Remove 'order' if present + } + + it = hyperparameters.find("min_samples_split"); + if (it != hyperparameters.end()) { + min_samples_split = it->get(); + hyperparameters.erase("min_samples_split"); // Remove 'min_samples_split' if present + } + + it = hyperparameters.find("min_samples_leaf"); + if (it != hyperparameters.end()) { + min_samples_leaf = it->get(); + hyperparameters.erase("min_samples_leaf"); // Remove 'min_samples_leaf' if present + } + Classifier::setHyperparameters(hyperparameters); + checkValues(); + } + void DecisionTree::checkValues() + { + if (max_depth <= 0) { + throw std::invalid_argument("max_depth must be positive"); + } + if (min_samples_leaf <= 0) { + throw std::invalid_argument("min_samples_leaf must be positive"); + } + if (min_samples_split <= 0) { + throw std::invalid_argument("min_samples_split must be positive"); + } + } + void DecisionTree::buildModel(const torch::Tensor& weights) + { + // Extract features (X) and labels (y) from dataset + auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }).t(); + auto y = dataset.index({ -1, torch::indexing::Slice() }); + + if (X.size(0) != y.size(0)) { + throw std::runtime_error("X and y must have the same number of samples"); + } + + n_classes = states[className].size(); + + // Use provided weights or uniform weights + torch::Tensor sample_weights; + if (weights.defined() && weights.numel() > 0) { + if (weights.size(0) != X.size(0)) { + throw std::runtime_error("weights must have the same length as number of samples"); + } + sample_weights = weights; + } else { + sample_weights = torch::ones({ X.size(0) }) / X.size(0); + } + + // Normalize weights + sample_weights = sample_weights / sample_weights.sum(); + + // Build the tree + root = buildTree(X, y, sample_weights, 0); + + // Mark as fitted + fitted = true; + } + bool DecisionTree::validateTensors(const torch::Tensor& X, const torch::Tensor& y, + const torch::Tensor& sample_weights) const + { + if (X.size(0) != y.size(0) || X.size(0) != sample_weights.size(0)) { + return false; + } + if (X.size(0) == 0) { + return false; + } + return true; + } + + std::unique_ptr DecisionTree::buildTree( + const torch::Tensor& X, + const torch::Tensor& y, + const torch::Tensor& sample_weights, + int current_depth) + { + auto node = std::make_unique(); + int n_samples = y.size(0); + + // Check stopping criteria + auto unique = at::_unique(y); + bool should_stop = (current_depth >= max_depth) || + (n_samples < min_samples_split) || + (std::get<0>(unique).size(0) == 1); // All samples same class + + if (should_stop || n_samples <= min_samples_leaf) { + // Create leaf node + node->is_leaf = true; + + // Calculate class probabilities + node->class_probabilities = torch::zeros({ n_classes }); + + for (int i = 0; i < n_samples; i++) { + int class_idx = y[i].item(); + node->class_probabilities[class_idx] += sample_weights[i].item(); + } + + // Normalize probabilities + node->class_probabilities /= node->class_probabilities.sum(); + + // Set predicted class as the one with highest probability + node->predicted_class = torch::argmax(node->class_probabilities).item(); + + return node; + } + + // Find best split + SplitInfo best_split = findBestSplit(X, y, sample_weights); + + // If no valid split found, create leaf + if (best_split.feature_index == -1 || best_split.impurity_decrease <= 0) { + node->is_leaf = true; + + // Calculate class probabilities + node->class_probabilities = torch::zeros({ n_classes }); + + for (int i = 0; i < n_samples; i++) { + int class_idx = y[i].item(); + node->class_probabilities[class_idx] += sample_weights[i].item(); + } + + node->class_probabilities /= node->class_probabilities.sum(); + node->predicted_class = torch::argmax(node->class_probabilities).item(); + + return node; + } + + // Create internal node + node->is_leaf = false; + node->split_feature = best_split.feature_index; + node->split_value = best_split.split_value; + + // Split data + auto left_X = X.index({ best_split.left_mask }); + auto left_y = y.index({ best_split.left_mask }); + auto left_weights = sample_weights.index({ best_split.left_mask }); + + auto right_X = X.index({ best_split.right_mask }); + auto right_y = y.index({ best_split.right_mask }); + auto right_weights = sample_weights.index({ best_split.right_mask }); + + // Recursively build subtrees + if (left_X.size(0) >= min_samples_leaf) { + node->left = buildTree(left_X, left_y, left_weights, current_depth + 1); + } else { + // Force leaf if not enough samples + node->left = std::make_unique(); + node->left->is_leaf = true; + auto mode = std::get<0>(torch::mode(left_y)); + node->left->predicted_class = mode.item(); + node->left->class_probabilities = torch::zeros({ n_classes }); + node->left->class_probabilities[node->left->predicted_class] = 1.0; + } + + if (right_X.size(0) >= min_samples_leaf) { + node->right = buildTree(right_X, right_y, right_weights, current_depth + 1); + } else { + // Force leaf if not enough samples + node->right = std::make_unique(); + node->right->is_leaf = true; + auto mode = std::get<0>(torch::mode(right_y)); + node->right->predicted_class = mode.item(); + node->right->class_probabilities = torch::zeros({ n_classes }); + node->right->class_probabilities[node->right->predicted_class] = 1.0; + } + + return node; + } + + DecisionTree::SplitInfo DecisionTree::findBestSplit( + const torch::Tensor& X, + const torch::Tensor& y, + const torch::Tensor& sample_weights) + { + + SplitInfo best_split; + best_split.feature_index = -1; + best_split.split_value = -1; + best_split.impurity_decrease = -std::numeric_limits::infinity(); + + int n_features = X.size(1); + int n_samples = X.size(0); + + // Calculate impurity of current node + double current_impurity = calculateGiniImpurity(y, sample_weights); + double total_weight = sample_weights.sum().item(); + + // Try each feature + for (int feat_idx = 0; feat_idx < n_features; feat_idx++) { + auto feature_values = X.index({ torch::indexing::Slice(), feat_idx }); + auto unique_values = std::get<0>(torch::unique_consecutive(std::get<0>(torch::sort(feature_values)))); + + // Try each unique value as split point + for (int i = 0; i < unique_values.size(0); i++) { + int split_val = unique_values[i].item(); + + // Create masks for left and right splits + auto left_mask = feature_values == split_val; + auto right_mask = ~left_mask; + + int left_count = left_mask.sum().item(); + int right_count = right_mask.sum().item(); + + // Skip if split doesn't satisfy minimum samples requirement + if (left_count < min_samples_leaf || right_count < min_samples_leaf) { + continue; + } + + // Calculate weighted impurities + auto left_y = y.index({ left_mask }); + auto left_weights = sample_weights.index({ left_mask }); + double left_weight = left_weights.sum().item(); + double left_impurity = calculateGiniImpurity(left_y, left_weights); + + auto right_y = y.index({ right_mask }); + auto right_weights = sample_weights.index({ right_mask }); + double right_weight = right_weights.sum().item(); + double right_impurity = calculateGiniImpurity(right_y, right_weights); + + // Calculate impurity decrease + double impurity_decrease = current_impurity - + (left_weight / total_weight * left_impurity + + right_weight / total_weight * right_impurity); + + // Update best split if this is better + if (impurity_decrease > best_split.impurity_decrease) { + best_split.feature_index = feat_idx; + best_split.split_value = split_val; + best_split.impurity_decrease = impurity_decrease; + best_split.left_mask = left_mask; + best_split.right_mask = right_mask; + } + } + } + + return best_split; + } + + double DecisionTree::calculateGiniImpurity( + const torch::Tensor& y, + const torch::Tensor& sample_weights) + { + if (y.size(0) == 0 || sample_weights.size(0) == 0) { + return 0.0; + } + + if (y.size(0) != sample_weights.size(0)) { + throw std::runtime_error("y and sample_weights must have same size"); + } + + torch::Tensor class_weights = torch::zeros({ n_classes }); + + // Calculate weighted class counts + for (int i = 0; i < y.size(0); i++) { + int class_idx = y[i].item(); + + if (class_idx < 0 || class_idx >= n_classes) { + throw std::runtime_error("Invalid class index: " + std::to_string(class_idx)); + } + + class_weights[class_idx] += sample_weights[i].item(); + } + + // Normalize + double total_weight = class_weights.sum().item(); + if (total_weight == 0) return 0.0; + + class_weights /= total_weight; + + // Calculate Gini impurity: 1 - sum(p_i^2) + double gini = 1.0; + for (int i = 0; i < n_classes; i++) { + double p = class_weights[i].item(); + gini -= p * p; + } + + return gini; + } + + + torch::Tensor DecisionTree::predict(torch::Tensor& X) + { + if (!fitted) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + int n_samples = X.size(1); + torch::Tensor predictions = torch::zeros({ n_samples }, torch::kInt32); + + for (int i = 0; i < n_samples; i++) { + auto sample = X.index({ torch::indexing::Slice(), i }).ravel(); + predictions[i] = predictSample(sample); + } + + return predictions; + } + void dumpTensor(const torch::Tensor& tensor, const std::string& name) + { + std::cout << name << ": " << std::endl; + for (int i = 0; i < tensor.size(0); i++) { + std::cout << "["; + for (int j = 0; j < tensor.size(1); j++) { + std::cout << tensor[i][j].item() << " "; + } + std::cout << "]" << std::endl; + } + std::cout << std::endl; + } + void dumpVector(const std::vector>& vec, const std::string& name) + { + std::cout << name << ": " << std::endl;; + for (const auto& row : vec) { + std::cout << "["; + for (const auto& val : row) { + std::cout << val << " "; + } + std::cout << "] " << std::endl; + } + std::cout << std::endl; + } + + std::vector DecisionTree::predict(std::vector>& X) + { + // Convert to tensor + long n = X.size(); + long m = X.at(0).size(); + torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); + auto predictions = predict(X_tensor); + std::vector result = platform::TensorUtils::to_vector(predictions); + return result; + } + + torch::Tensor DecisionTree::predict_proba(torch::Tensor& X) + { + if (!fitted) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + int n_samples = X.size(1); + torch::Tensor probabilities = torch::zeros({ n_samples, n_classes }); + + for (int i = 0; i < n_samples; i++) { + auto sample = X.index({ torch::indexing::Slice(), i }).ravel(); + probabilities[i] = predictProbaSample(sample); + } + + return probabilities; + } + + std::vector> DecisionTree::predict_proba(std::vector>& X) + { + auto n_samples = X.at(0).size(); + // Convert to tensor + torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); + auto proba_tensor = predict_proba(X_tensor); + std::vector> result(n_samples, std::vector(n_classes, 0.0)); + + for (int i = 0; i < n_samples; i++) { + for (int j = 0; j < n_classes; j++) { + result[i][j] = proba_tensor[i][j].item(); + } + } + + return result; + } + + int DecisionTree::predictSample(const torch::Tensor& x) const + { + if (!fitted) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + if (x.size(0) != n) { // n debería ser el número de características + throw std::runtime_error("Input sample has wrong number of features"); + } + + const TreeNode* leaf = traverseTree(x, root.get()); + return leaf->predicted_class; + } + torch::Tensor DecisionTree::predictProbaSample(const torch::Tensor& x) const + { + const TreeNode* leaf = traverseTree(x, root.get()); + return leaf->class_probabilities.clone(); + } + + + const TreeNode* DecisionTree::traverseTree(const torch::Tensor& x, const TreeNode* node) const + { + if (!node) { + throw std::runtime_error("Null node encountered during tree traversal"); + } + + if (node->is_leaf) { + return node; + } + + if (node->split_feature < 0 || node->split_feature >= x.size(0)) { + throw std::runtime_error("Invalid split_feature index: " + std::to_string(node->split_feature)); + } + + int feature_value = x[node->split_feature].item(); + + if (feature_value == node->split_value) { + if (!node->left) { + throw std::runtime_error("Missing left child in tree"); + } + return traverseTree(x, node->left.get()); + } else { + if (!node->right) { + throw std::runtime_error("Missing right child in tree"); + } + return traverseTree(x, node->right.get()); + } + } + + std::vector DecisionTree::graph(const std::string& title) const + { + std::vector lines; + lines.push_back("digraph DecisionTree {"); + lines.push_back(" rankdir=TB;"); + lines.push_back(" node [shape=box, style=\"filled, rounded\", fontname=\"helvetica\"];"); + lines.push_back(" edge [fontname=\"helvetica\"];"); + + if (!title.empty()) { + lines.push_back(" label=\"" + title + "\";"); + lines.push_back(" labelloc=t;"); + } + + if (root) { + int node_id = 0; + treeToGraph(root.get(), lines, node_id); + } + + lines.push_back("}"); + return lines; + } + + void DecisionTree::treeToGraph( + const TreeNode* node, + std::vector& lines, + int& node_id, + int parent_id, + const std::string& edge_label) const + { + + int current_id = node_id++; + std::stringstream ss; + + if (node->is_leaf) { + // Leaf node + ss << " node" << current_id << " [label=\"Class: " << node->predicted_class; + ss << "\\nProb: " << std::fixed << std::setprecision(3) + << node->class_probabilities[node->predicted_class].item(); + ss << "\", fillcolor=\"lightblue\"];"; + lines.push_back(ss.str()); + } else { + // Internal node + ss << " node" << current_id << " [label=\"" << features[node->split_feature]; + ss << " = " << node->split_value << "?\", fillcolor=\"lightgreen\"];"; + lines.push_back(ss.str()); + } + + // Add edge from parent + if (parent_id >= 0) { + ss.str(""); + ss << " node" << parent_id << " -> node" << current_id; + if (!edge_label.empty()) { + ss << " [label=\"" << edge_label << "\"];"; + } else { + ss << ";"; + } + lines.push_back(ss.str()); + } + + // Recurse on children + if (!node->is_leaf) { + if (node->left) { + treeToGraph(node->left.get(), lines, node_id, current_id, "Yes"); + } + if (node->right) { + treeToGraph(node->right.get(), lines, node_id, current_id, "No"); + } + } + } + +} // namespace bayesnet \ No newline at end of file diff --git a/src/experimental_clfs/DecisionTree.h b/src/experimental_clfs/DecisionTree.h new file mode 100644 index 0000000..93ec930 --- /dev/null +++ b/src/experimental_clfs/DecisionTree.h @@ -0,0 +1,129 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#ifndef DECISION_TREE_H +#define DECISION_TREE_H + +#include +#include +#include +#include +#include "bayesnet/classifiers/Classifier.h" + +namespace bayesnet { + + // Forward declaration + struct TreeNode; + + class DecisionTree : public Classifier { + public: + explicit DecisionTree(int max_depth = 3, int min_samples_split = 2, int min_samples_leaf = 1); + virtual ~DecisionTree() = default; + + // Override graph method to show tree structure + std::vector graph(const std::string& title = "") const override; + + // Setters for hyperparameters + void setMaxDepth(int depth) { max_depth = depth; checkValues(); } + void setMinSamplesSplit(int samples) { min_samples_split = samples; checkValues(); } + void setMinSamplesLeaf(int samples) { min_samples_leaf = samples; checkValues(); } + + // Override setHyperparameters + void setHyperparameters(const nlohmann::json& hyperparameters) override; + + torch::Tensor predict(torch::Tensor& X) override; + std::vector predict(std::vector>& X) override; + torch::Tensor predict_proba(torch::Tensor& X) override; + std::vector> predict_proba(std::vector>& X); + + protected: + void buildModel(const torch::Tensor& weights) override; + void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override + { + // Decision trees do not require training in the traditional sense + // as they are built from the data directly. + // This method can be used to set weights or other parameters if needed. + } + private: + void checkValues(); + bool validateTensors(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& sample_weights) const; + // Tree hyperparameters + int max_depth; + int min_samples_split; + int min_samples_leaf; + int n_classes; // Number of classes in the target variable + + // Root of the decision tree + std::unique_ptr root; + + // Build tree recursively + std::unique_ptr buildTree( + const torch::Tensor& X, + const torch::Tensor& y, + const torch::Tensor& sample_weights, + int current_depth + ); + + // Find best split for a node + struct SplitInfo { + int feature_index; + int split_value; + double impurity_decrease; + torch::Tensor left_mask; + torch::Tensor right_mask; + }; + + SplitInfo findBestSplit( + const torch::Tensor& X, + const torch::Tensor& y, + const torch::Tensor& sample_weights + ); + + // Calculate weighted Gini impurity for multi-class + double calculateGiniImpurity( + const torch::Tensor& y, + const torch::Tensor& sample_weights + ); + + // Make predictions for a single sample + int predictSample(const torch::Tensor& x) const; + + // Make probabilistic predictions for a single sample + torch::Tensor predictProbaSample(const torch::Tensor& x) const; + + // Traverse tree to find leaf node + const TreeNode* traverseTree(const torch::Tensor& x, const TreeNode* node) const; + + // Convert tree to graph representation + void treeToGraph( + const TreeNode* node, + std::vector& lines, + int& node_id, + int parent_id = -1, + const std::string& edge_label = "" + ) const; + }; + + // Tree node structure + struct TreeNode { + bool is_leaf; + + // For internal nodes + int split_feature; + int split_value; + std::unique_ptr left; + std::unique_ptr right; + + // For leaf nodes + int predicted_class; + torch::Tensor class_probabilities; // Probability for each class + + TreeNode() : is_leaf(false), split_feature(-1), split_value(-1), predicted_class(-1) {} + }; + +} // namespace bayesnet + +#endif // DECISION_TREE_H \ No newline at end of file diff --git a/src/experimental_clfs/README.md b/src/experimental_clfs/README.md new file mode 100644 index 0000000..129c608 --- /dev/null +++ b/src/experimental_clfs/README.md @@ -0,0 +1,142 @@ +# AdaBoost and DecisionTree Classifier Implementation + +This implementation provides both a Decision Tree classifier and a multi-class AdaBoost classifier based on the SAMME (Stagewise Additive Modeling using a Multi-class Exponential loss) algorithm described in the paper "Multi-class AdaBoost" by Zhu et al. Implemented in C++ using + +## Components + +### 1. DecisionTree Classifier + +A classic decision tree implementation that: + +- Supports multi-class classification +- Handles weighted samples (essential for boosting) +- Uses Gini impurity as the splitting criterion +- Works with discrete/categorical features +- Provides both class predictions and probability estimates + +#### Key Features + +- **Max Depth Control**: Limit tree depth to create weak learners +- **Minimum Samples**: Control minimum samples for splitting and leaf nodes +- **Weighted Training**: Properly handles sample weights for boosting +- **Visualization**: Generates DOT format graphs of the tree structure + +#### Hyperparameters + +- `max_depth`: Maximum depth of the tree (default: 3) +- `min_samples_split`: Minimum samples required to split a node (default: 2) +- `min_samples_leaf`: Minimum samples required in a leaf node (default: 1) + +### 2. AdaBoost Classifier + +A multi-class AdaBoost implementation using DecisionTree as base estimators: + +- **SAMME Algorithm**: Implements the multi-class extension of AdaBoost +- **Automatic Stumps**: Uses decision stumps (max_depth=1) by default +- **Early Stopping**: Stops if base classifier performs worse than random +- **Ensemble Visualization**: Shows the weighted combination of base estimators + +#### Key Features + +- **Multi-class Support**: Natural extension to K classes +- **Base Estimator Control**: Configure depth of base decision trees +- **Training Monitoring**: Track training errors and estimator weights +- **Probability Estimates**: Provides class probability predictions + +#### Hyperparameters + +- `n_estimators`: Number of base estimators to train (default: 50) +- `base_max_depth`: Maximum depth for base decision trees (default: 1) + +## Algorithm Details + +The SAMME algorithm differs from binary AdaBoost in the calculation of the estimator weight (alpha): + +``` +α = log((1 - err) / err) + log(K - 1) +``` + +where `K` is the number of classes. This formula ensures that: + +- When K = 2, it reduces to standard AdaBoost +- For K > 2, base classifiers only need to be better than random guessing (1/K) rather than 50% + +## Usage Example + +```cpp +// Create AdaBoost with decision stumps +AdaBoost ada(100, 1); // 100 estimators, max_depth=1 + +// Train +ada.fit(X_train, y_train, features, className, states, Smoothing_t::NONE); + +// Predict +auto predictions = ada.predict(X_test); +auto probabilities = ada.predict_proba(X_test); + +// Evaluate +float accuracy = ada.score(X_test, y_test); + +// Get ensemble information +auto weights = ada.getEstimatorWeights(); +auto errors = ada.getTrainingErrors(); +``` + +## Implementation Structure + +``` +AdaBoost (inherits from Ensemble) + └── Uses multiple DecisionTree instances as base estimators + └── DecisionTree (inherits from Classifier) + └── Implements weighted Gini impurity splitting +``` + +## Visualization + +Both classifiers support graph visualization: + +- **DecisionTree**: Shows the tree structure with split conditions +- **AdaBoost**: Shows the ensemble of weighted base estimators + +Generate visualizations using: + +```cpp +auto graph = classifier.graph("Title"); +``` + +## Data Format + +Both classifiers expect discrete/categorical data: + +- **Features**: Integer values representing categories (stored in `torch::Tensor` or `std::vector>`) +- **Labels**: Integer values representing class indices (0, 1, ..., K-1) +- **States**: Map defining possible values for each feature and the class variable +- **Sample Weights**: Optional weights for each training sample (important for boosting) + +Example data setup: + +```cpp +// Features matrix (n_features x n_samples) +torch::Tensor X = torch::tensor({{0, 1, 2}, {1, 0, 1}}); // 2 features, 3 samples + +// Labels vector +torch::Tensor y = torch::tensor({0, 1, 0}); // 3 samples + +// States definition +std::map> states; +states["feature1"] = {0, 1, 2}; // Feature 1 can take values 0, 1, or 2 +states["feature2"] = {0, 1}; // Feature 2 can take values 0 or 1 +states["class"] = {0, 1}; // Binary classification +``` + +## Notes + +- The implementation handles discrete/categorical features as indicated by the int-based data structures +- Sample weights are properly propagated through the tree building process +- The DecisionTree implementation uses equality testing for splits (suitable for categorical data) +- Both classifiers support the standard fit/predict interface from the base framework + +## References + +- Zhu, J., Zou, H., Rosset, S., & Hastie, T. (2009). Multi-class AdaBoost. Statistics and its interface, 2(3), 349-360. +- Breiman, L., Friedman, J., Olshen, R., & Stone, C. (1984). Classification and Regression Trees. Wadsworth, Belmont, CA. diff --git a/src/experimental_clfs/TensorUtils.hpp b/src/experimental_clfs/TensorUtils.hpp index 6f09859..77ed894 100644 --- a/src/experimental_clfs/TensorUtils.hpp +++ b/src/experimental_clfs/TensorUtils.hpp @@ -45,6 +45,19 @@ namespace platform { return data; } + static torch::Tensor to_matrix(const std::vector>& data) + { + if (data.empty()) return torch::empty({ 0, 0 }, torch::kInt64); + size_t rows = data.size(); + size_t cols = data[0].size(); + torch::Tensor tensor = torch::empty({ static_cast(rows), static_cast(cols) }, torch::kInt64); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + tensor.index_put_({ static_cast(i), static_cast(j) }, data[i][j]); + } + } + return tensor; + } }; } diff --git a/src/main/Models.h b/src/main/Models.h index d6cf449..69ceaea 100644 --- a/src/main/Models.h +++ b/src/main/Models.h @@ -23,10 +23,11 @@ #include #include #include -#include +#include #include #include "../experimental_clfs/XA1DE.h" -#include "../experimental_clfs/AdaBoost.h" +// #include "../experimental_clfs/AdaBoost.h" +#include "../experimental_clfs/DecisionTree.h" namespace platform { class Models { diff --git a/src/main/modelRegister.h b/src/main/modelRegister.h index 11ebc84..4764c07 100644 --- a/src/main/modelRegister.h +++ b/src/main/modelRegister.h @@ -35,10 +35,12 @@ namespace platform { [](void) -> bayesnet::BaseClassifier* { return new pywrap::RandomForest();}); static Registrar registrarXGB("XGBoost", [](void) -> bayesnet::BaseClassifier* { return new pywrap::XGBoost();}); - static Registrar registrarAda("AdaBoostPy", - [](void) -> bayesnet::BaseClassifier* { return new pywrap::AdaBoost();}); - // static Registrar registrarAda2("AdaBoost", - // [](void) -> bayesnet::BaseClassifier* { return new platform::AdaBoost();}); + static Registrar registrarAdaPy("AdaBoostPy", + [](void) -> bayesnet::BaseClassifier* { return new pywrap::AdaBoostPy();}); + // static Registrar registrarAda("AdaBoost", + // [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AdaBoost();}); + static Registrar registrarDT("DecisionTree", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::DecisionTree();}); static Registrar registrarXSPODE("XSPODE", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::XSpode(0);}); static Registrar registrarXSP2DE("XSP2DE", diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index fce68a8..7008066 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -12,11 +12,11 @@ if(ENABLE_TESTING) ${Bayesnet_INCLUDE_DIRS} ) set(TEST_SOURCES_PLATFORM - TestUtils.cpp TestPlatform.cpp TestResult.cpp TestScores.cpp + TestUtils.cpp TestPlatform.cpp TestResult.cpp TestScores.cpp TestDecisionTree.cpp ${Platform_SOURCE_DIR}/src/common/Datasets.cpp ${Platform_SOURCE_DIR}/src/common/Dataset.cpp ${Platform_SOURCE_DIR}/src/common/Discretization.cpp - ${Platform_SOURCE_DIR}/src/main/Scores.cpp + ${Platform_SOURCE_DIR}/src/main/Scores.cpp ${Platform_SOURCE_DIR}/src/experimental_clfs/DecisionTree.cpp ) add_executable(${TEST_PLATFORM} ${TEST_SOURCES_PLATFORM}) - target_link_libraries(${TEST_PLATFORM} PUBLIC "${TORCH_LIBRARIES}" mdlp Catch2::Catch2WithMain BayesNet) + target_link_libraries(${TEST_PLATFORM} PUBLIC "${TORCH_LIBRARIES}" fimdlp Catch2::Catch2WithMain bayesnet) add_test(NAME ${TEST_PLATFORM} COMMAND ${TEST_PLATFORM}) endif(ENABLE_TESTING) diff --git a/tests/TestDecisionTree.cpp b/tests/TestDecisionTree.cpp new file mode 100644 index 0000000..8fd4bd5 --- /dev/null +++ b/tests/TestDecisionTree.cpp @@ -0,0 +1,303 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#include +#include +#include +#include +#include +#include +#include +#include "experimental_clfs/DecisionTree.h" +#include "TestUtils.h" + +using namespace bayesnet; +using namespace Catch::Matchers; + +TEST_CASE("DecisionTree Construction", "[DecisionTree]") +{ + SECTION("Default constructor") + { + REQUIRE_NOTHROW(DecisionTree()); + } + + SECTION("Constructor with parameters") + { + REQUIRE_NOTHROW(DecisionTree(5, 10, 3)); + } +} + +TEST_CASE("DecisionTree Hyperparameter Setting", "[DecisionTree]") +{ + DecisionTree dt; + + SECTION("Set individual hyperparameters") + { + REQUIRE_NOTHROW(dt.setMaxDepth(10)); + REQUIRE_NOTHROW(dt.setMinSamplesSplit(5)); + REQUIRE_NOTHROW(dt.setMinSamplesLeaf(2)); + } + + SECTION("Set hyperparameters via JSON") + { + nlohmann::json params; + params["max_depth"] = 7; + params["min_samples_split"] = 4; + params["min_samples_leaf"] = 2; + + REQUIRE_NOTHROW(dt.setHyperparameters(params)); + } + + SECTION("Invalid hyperparameters should throw") + { + nlohmann::json params; + + // Negative max_depth + params["max_depth"] = -1; + REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument); + + // Zero min_samples_split + params["max_depth"] = 5; + params["min_samples_split"] = 0; + REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument); + + // Negative min_samples_leaf + params["min_samples_split"] = 2; + params["min_samples_leaf"] = -5; + REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument); + } +} + +TEST_CASE("DecisionTree Basic Functionality", "[DecisionTree]") +{ + // Create a simple dataset + int n_samples = 20; + int n_features = 2; + + std::vector> X(n_features, std::vector(n_samples)); + std::vector y(n_samples); + + // Simple pattern: class depends on first feature + for (int i = 0; i < n_samples; i++) { + X[0][i] = i < 10 ? 0 : 1; + X[1][i] = i % 2; + y[i] = X[0][i]; // Class equals first feature + } + + std::vector features = { "f1", "f2" }; + std::string className = "class"; + std::map> states; + states["f1"] = { 0, 1 }; + states["f2"] = { 0, 1 }; + states["class"] = { 0, 1 }; + + SECTION("Training with vector interface") + { + DecisionTree dt(3, 2, 1); + REQUIRE_NOTHROW(dt.fit(X, y, features, className, states, Smoothing_t::NONE)); + + auto predictions = dt.predict(X); + REQUIRE(predictions.size() == static_cast(n_samples)); + + // Should achieve perfect accuracy on this simple dataset + int correct = 0; + for (size_t i = 0; i < predictions.size(); i++) { + if (predictions[i] == y[i]) correct++; + } + REQUIRE(correct == n_samples); + } + + SECTION("Prediction before fitting") + { + DecisionTree dt; + REQUIRE_THROWS_WITH(dt.predict(X), + ContainsSubstring("Classifier has not been fitted")); + } + + SECTION("Probability predictions") + { + DecisionTree dt(3, 2, 1); + dt.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto proba = dt.predict_proba(X); + REQUIRE(proba.size() == static_cast(n_samples)); + REQUIRE(proba[0].size() == 2); // Two classes + + // Check probabilities sum to 1 and probabilities are valid + auto predictions = dt.predict(X); + for (size_t i = 0; i < proba.size(); i++) { + auto p = proba[i]; + auto pred = predictions[i]; + REQUIRE(p.size() == 2); + REQUIRE(p[0] >= 0.0); + REQUIRE(p[1] >= 0.0); + double sum = p[0] + p[1]; + //Check that prodict_proba matches the expected predict value + REQUIRE(pred == (p[0] > p[1] ? 0 : 1)); + REQUIRE(sum == Catch::Approx(1.0).epsilon(1e-6)); + } + } +} + +TEST_CASE("DecisionTree on Iris Dataset", "[DecisionTree][iris]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Training with dataset format") + { + DecisionTree dt(5, 2, 1); + + INFO("Dataset shape: " << raw.dataset.sizes()); + INFO("Features: " << raw.featurest.size()); + INFO("Samples: " << raw.nSamples); + + // DecisionTree expects dataset in format: features x samples, with labels as last row + REQUIRE_NOTHROW(dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE)); + + // Test prediction + auto predictions = dt.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + + // Calculate accuracy + auto correct = torch::sum(predictions == raw.yt).item(); + double accuracy = static_cast(correct) / raw.yt.size(0); + REQUIRE(accuracy > 0.97); // Reasonable accuracy for Iris + } + + SECTION("Training with vector interface") + { + DecisionTree dt(5, 2, 1); + + REQUIRE_NOTHROW(dt.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv, Smoothing_t::NONE)); + + // std::cout << "Tree structure:\n"; + // auto graph_lines = dt.graph("Iris Decision Tree"); + // for (const auto& line : graph_lines) { + // std::cout << line << "\n"; + // } + auto predictions = dt.predict(raw.Xv); + REQUIRE(predictions.size() == raw.yv.size()); + } + + SECTION("Different tree depths") + { + std::vector depths = { 1, 3, 5 }; + + for (int depth : depths) { + DecisionTree dt(depth, 2, 1); + dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + auto predictions = dt.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + } + } +} + +TEST_CASE("DecisionTree Edge Cases", "[DecisionTree]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Very shallow tree") + { + DecisionTree dt(1, 2, 1); // depth = 1 + dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + auto predictions = dt.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + + // With depth 1, should have at most 2 unique predictions + auto unique_vals = at::_unique(predictions); + REQUIRE(std::get<0>(unique_vals).size(0) <= 2); + } + + SECTION("High min_samples_split") + { + DecisionTree dt(10, 50, 1); + dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + auto predictions = dt.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + } +} + +TEST_CASE("DecisionTree Graph Visualization", "[DecisionTree]") +{ + // Simple dataset + std::vector> X = { {0,0,0,1}, {0,1,1,1} }; // XOR pattern + std::vector y = { 0, 1, 1, 0 }; // XOR pattern + std::vector features = { "x1", "x2" }; + std::string className = "xor"; + std::map> states; + states["x1"] = { 0, 1 }; + states["x2"] = { 0, 1 }; + states["xor"] = { 0, 1 }; + + SECTION("Graph generation") + { + DecisionTree dt(2, 1, 1); + dt.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto graph_lines = dt.graph(); + + REQUIRE(graph_lines.size() > 2); + REQUIRE(graph_lines.front() == "digraph DecisionTree {"); + REQUIRE(graph_lines.back() == "}"); + + // Should contain node definitions + bool has_nodes = false; + for (const auto& line : graph_lines) { + if (line.find("node") != std::string::npos) { + has_nodes = true; + break; + } + } + REQUIRE(has_nodes); + } + + SECTION("Graph with title") + { + DecisionTree dt(2, 1, 1); + dt.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto graph_lines = dt.graph("XOR Tree"); + + bool has_title = false; + for (const auto& line : graph_lines) { + if (line.find("label=\"XOR Tree\"") != std::string::npos) { + has_title = true; + break; + } + } + REQUIRE(has_title); + } +} + +TEST_CASE("DecisionTree with Weights", "[DecisionTree]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Uniform weights") + { + DecisionTree dt(5, 2, 1); + dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, raw.weights, Smoothing_t::NONE); + + auto predictions = dt.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + } + + SECTION("Non-uniform weights") + { + auto weights = torch::ones({ raw.nSamples }); + weights.index({ torch::indexing::Slice(0, 50) }) *= 2.0; // Emphasize first class + weights = weights / weights.sum(); + + DecisionTree dt(5, 2, 1); + dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, weights, Smoothing_t::NONE); + + auto predictions = dt.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + } +} \ No newline at end of file diff --git a/tests/TestPlatform.cpp b/tests/TestPlatform.cpp index 108d8ec..6eb2bd0 100644 --- a/tests/TestPlatform.cpp +++ b/tests/TestPlatform.cpp @@ -7,7 +7,7 @@ #include #include "TestUtils.h" #include "folding.hpp" -#include +#include #include #include "config_platform.h" @@ -20,17 +20,17 @@ TEST_CASE("Test Platform version", "[Platform]") TEST_CASE("Test Folding library version", "[Folding]") { std::string version = folding::KFold(5, 100).version(); - REQUIRE(version == "1.1.0"); + REQUIRE(version == "1.1.1"); } TEST_CASE("Test BayesNet version", "[BayesNet]") { std::string version = bayesnet::TAN().getVersion(); - REQUIRE(version == "1.0.6"); + REQUIRE(version == "1.1.2"); } TEST_CASE("Test mdlp version", "[mdlp]") { std::string version = mdlp::CPPFImdlp::version(); - REQUIRE(version == "2.0.0"); + REQUIRE(version == "2.0.1"); } TEST_CASE("Test Arff version", "[Arff]") { diff --git a/tests/TestScores.cpp b/tests/TestScores.cpp index 111ae7e..b414c04 100644 --- a/tests/TestScores.cpp +++ b/tests/TestScores.cpp @@ -14,38 +14,40 @@ using json = nlohmann::ordered_json; auto epsilon = 1e-4; -void make_test_bin(int TP, int TN, int FP, int FN, std::vector& y_test, std::vector& y_pred) +void make_test_bin(int TP, int TN, int FP, int FN, std::vector& y_test, torch::Tensor& y_pred) { - // TP + std::vector> probs; + // TP: true positive (label 1, predicted 1) for (int i = 0; i < TP; i++) { y_test.push_back(1); - y_pred.push_back(1); + probs.push_back({ 0.0, 1.0 }); // P(class 0)=0, P(class 1)=1 } - // TN + // TN: true negative (label 0, predicted 0) for (int i = 0; i < TN; i++) { y_test.push_back(0); - y_pred.push_back(0); + probs.push_back({ 1.0, 0.0 }); // P(class 0)=1, P(class 1)=0 } - // FP + // FP: false positive (label 0, predicted 1) for (int i = 0; i < FP; i++) { y_test.push_back(0); - y_pred.push_back(1); + probs.push_back({ 0.0, 1.0 }); // P(class 0)=0, P(class 1)=1 } - // FN + // FN: false negative (label 1, predicted 0) for (int i = 0; i < FN; i++) { y_test.push_back(1); - y_pred.push_back(0); + probs.push_back({ 1.0, 0.0 }); // P(class 0)=1, P(class 1)=0 } + // Convert to torch::Tensor of double, shape [N,2] + y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 2 }, torch::kFloat64).clone(); } TEST_CASE("Scores binary", "[Scores]") { std::vector y_test; - std::vector y_pred; + torch::Tensor y_pred; make_test_bin(197, 210, 52, 41, y_test, y_pred); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); - platform::Scores scores(y_test_tensor, y_pred_tensor, 2); + platform::Scores scores(y_test_tensor, y_pred, 2); REQUIRE(scores.accuracy() == Catch::Approx(0.814).epsilon(epsilon)); REQUIRE(scores.f1_score(0) == Catch::Approx(0.818713)); REQUIRE(scores.f1_score(1) == Catch::Approx(0.809035)); @@ -64,10 +66,23 @@ TEST_CASE("Scores binary", "[Scores]") TEST_CASE("Scores multiclass", "[Scores]") { std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; - std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; + // Refactor y_pred to a tensor of shape [10, 3] with probabilities + std::vector> probs = { + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1 + }; + torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone(); + // Convert y_test to a tensor auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); - platform::Scores scores(y_test_tensor, y_pred_tensor, 3); + platform::Scores scores(y_test_tensor, y_pred, 3); REQUIRE(scores.accuracy() == Catch::Approx(0.6).epsilon(epsilon)); REQUIRE(scores.f1_score(0) == Catch::Approx(0.666667)); REQUIRE(scores.f1_score(1) == Catch::Approx(0.4)); @@ -84,10 +99,21 @@ TEST_CASE("Scores multiclass", "[Scores]") TEST_CASE("Test Confusion Matrix Values", "[Scores]") { std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; - std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; + std::vector> probs = { + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1 + }; + torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone(); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); - platform::Scores scores(y_test_tensor, y_pred_tensor, 3); + platform::Scores scores(y_test_tensor, y_pred, 3); auto confusion_matrix = scores.get_confusion_matrix(); REQUIRE(confusion_matrix[0][0].item() == 2); REQUIRE(confusion_matrix[0][1].item() == 1); @@ -102,11 +128,22 @@ TEST_CASE("Test Confusion Matrix Values", "[Scores]") TEST_CASE("Confusion Matrix JSON", "[Scores]") { std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; - std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; + std::vector> probs = { + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1 + }; + torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone(); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); std::vector labels = { "Aeroplane", "Boat", "Car" }; - platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels); + platform::Scores scores(y_test_tensor, y_pred, 3, labels); auto res_json_int = scores.get_confusion_matrix_json(); REQUIRE(res_json_int[0][0] == 2); REQUIRE(res_json_int[0][1] == 1); @@ -131,11 +168,22 @@ TEST_CASE("Confusion Matrix JSON", "[Scores]") TEST_CASE("Classification Report", "[Scores]") { std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; - std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; + std::vector> probs = { + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1 + }; + torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone(); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); std::vector labels = { "Aeroplane", "Boat", "Car" }; - platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels); + platform::Scores scores(y_test_tensor, y_pred, 3, labels); auto report = scores.classification_report(Colors::BLUE(), "train"); auto json_matrix = scores.get_confusion_matrix_json(true); platform::Scores scores2(json_matrix); @@ -144,11 +192,22 @@ TEST_CASE("Classification Report", "[Scores]") TEST_CASE("JSON constructor", "[Scores]") { std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; - std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; + std::vector> probs = { + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1 + }; + torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone(); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); std::vector labels = { "Car", "Boat", "Aeroplane" }; - platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels); + platform::Scores scores(y_test_tensor, y_pred, 3, labels); auto res_json_int = scores.get_confusion_matrix_json(); platform::Scores scores2(res_json_int); REQUIRE(scores.accuracy() == scores2.accuracy()); @@ -173,17 +232,14 @@ TEST_CASE("JSON constructor", "[Scores]") TEST_CASE("Aggregate", "[Scores]") { std::vector y_test; - std::vector y_pred; + torch::Tensor y_pred; make_test_bin(197, 210, 52, 41, y_test, y_pred); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); - platform::Scores scores(y_test_tensor, y_pred_tensor, 2); + platform::Scores scores(y_test_tensor, y_pred, 2); y_test.clear(); - y_pred.clear(); make_test_bin(227, 187, 39, 47, y_test, y_pred); auto y_test_tensor2 = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor2 = torch::tensor(y_pred, torch::kInt32); - platform::Scores scores2(y_test_tensor2, y_pred_tensor2, 2); + platform::Scores scores2(y_test_tensor2, y_pred, 2); scores.aggregate(scores2); REQUIRE(scores.accuracy() == Catch::Approx(0.821).epsilon(epsilon)); REQUIRE(scores.f1_score(0) == Catch::Approx(0.8160329)); @@ -195,11 +251,9 @@ TEST_CASE("Aggregate", "[Scores]") REQUIRE(scores.f1_weighted() == Catch::Approx(0.8209856)); REQUIRE(scores.f1_macro() == Catch::Approx(0.8208694)); y_test.clear(); - y_pred.clear(); make_test_bin(197 + 227, 210 + 187, 52 + 39, 41 + 47, y_test, y_pred); y_test_tensor = torch::tensor(y_test, torch::kInt32); - y_pred_tensor = torch::tensor(y_pred, torch::kInt32); - platform::Scores scores3(y_test_tensor, y_pred_tensor, 2); + platform::Scores scores3(y_test_tensor, y_pred, 2); for (int i = 0; i < 2; ++i) { REQUIRE(scores3.f1_score(i) == scores.f1_score(i)); REQUIRE(scores3.precision(i) == scores.precision(i)); @@ -212,11 +266,22 @@ TEST_CASE("Aggregate", "[Scores]") TEST_CASE("Order of keys", "[Scores]") { std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; - std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; + std::vector> probs = { + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1 + }; + torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone(); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); std::vector labels = { "Car", "Boat", "Aeroplane" }; - platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels); + platform::Scores scores(y_test_tensor, y_pred, 3, labels); auto res_json_int = scores.get_confusion_matrix_json(true); // Make a temp file and store the json std::string filename = "temp.json"; diff --git a/tests/TestUtils.h b/tests/TestUtils.h index aea25d6..4409d6f 100644 --- a/tests/TestUtils.h +++ b/tests/TestUtils.h @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include bool file_exists(const std::string& name); diff --git a/vcpkg.json b/vcpkg.json index 69ded49..06584bb 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -7,6 +7,7 @@ "fimdlp", "libtorch-bin", "folding", + "catch2", "argparse" ], "overrides": [ @@ -30,9 +31,13 @@ "name": "argpase", "version": "3.2" }, + { + "name": "catch2", + "version": "3.8.1" + }, { "name": "nlohmann-json", "version": "3.11.3" } ] - } \ No newline at end of file + } From 415a7ae608d0124ea2919db31dfaf142621c90d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 18 Jun 2025 11:27:11 +0200 Subject: [PATCH 20/27] Begin AdaBoost integration --- src/CMakeLists.txt | 7 +- src/experimental_clfs/AdaBoost.cpp | 266 +++++++++- src/experimental_clfs/AdaBoost.h | 20 +- src/experimental_clfs/DecisionTree.cpp | 24 - src/experimental_clfs/DecisionTree.h | 13 +- src/main/Models.h | 2 +- src/main/modelRegister.h | 4 +- tests/CMakeLists.txt | 6 +- tests/TestAdaBoost.cpp | 707 +++++++++++++++++++++++++ tests/TestDecisionTree.cpp | 8 + 10 files changed, 1001 insertions(+), 56 deletions(-) create mode 100644 tests/TestAdaBoost.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 54bc2de..66bcbdd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -21,7 +21,7 @@ add_executable( experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp experimental_clfs/DecisionTree.cpp - + experimental_clfs/AdaBoost.cpp ) target_link_libraries(b_best Boost::boost "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}") @@ -36,6 +36,7 @@ add_executable(b_grid commands/b_grid.cpp ${grid_sources} experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp experimental_clfs/DecisionTree.cpp + experimental_clfs/AdaBoost.cpp ) target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy) @@ -48,7 +49,7 @@ add_executable(b_list commands/b_list.cpp experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp experimental_clfs/DecisionTree.cpp - + experimental_clfs/AdaBoost.cpp ) target_link_libraries(b_list "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}") @@ -63,7 +64,7 @@ add_executable(b_main commands/b_main.cpp ${main_sources} experimental_clfs/ExpClf.cpp experimental_clfs/ExpClf.cpp experimental_clfs/DecisionTree.cpp - + experimental_clfs/AdaBoost.cpp ) target_link_libraries(b_main PRIVATE nlohmann_json::nlohmann_json "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy) diff --git a/src/experimental_clfs/AdaBoost.cpp b/src/experimental_clfs/AdaBoost.cpp index ff5c40e..a04236d 100644 --- a/src/experimental_clfs/AdaBoost.cpp +++ b/src/experimental_clfs/AdaBoost.cpp @@ -11,11 +11,12 @@ #include #include #include +#include "TensorUtils.hpp" namespace bayesnet { AdaBoost::AdaBoost(int n_estimators, int max_depth) - : Ensemble(true), n_estimators(n_estimators), base_max_depth(max_depth) + : Ensemble(true), n_estimators(n_estimators), base_max_depth(max_depth), n(0), n_classes(0) { validHyperparameters = { "n_estimators", "base_max_depth" }; } @@ -27,6 +28,10 @@ namespace bayesnet { alphas.clear(); training_errors.clear(); + // Initialize n (number of features) and n_classes + n = dataset.size(0) - 1; // Exclude the label row + n_classes = states[className].size(); + // Initialize sample weights uniformly int n_samples = dataset.size(1); sample_weights = torch::ones({ n_samples }) / n_samples; @@ -37,6 +42,12 @@ namespace bayesnet { normalizeWeights(); } + // Debug information + std::cout << "Starting AdaBoost training with " << n_estimators << " estimators" << std::endl; + std::cout << "Number of classes: " << n_classes << std::endl; + std::cout << "Number of features: " << n << std::endl; + std::cout << "Number of samples: " << n_samples << std::endl; + // Main AdaBoost training loop (SAMME algorithm) for (int iter = 0; iter < n_estimators; ++iter) { // Train base estimator with current sample weights @@ -46,9 +57,16 @@ namespace bayesnet { double weighted_error = calculateWeightedError(estimator.get(), sample_weights); training_errors.push_back(weighted_error); + // Debug output + std::cout << "Iteration " << iter + 1 << ":" << std::endl; + std::cout << " Weighted error: " << weighted_error << std::endl; + // Check if error is too high (worse than random guessing) - double random_guess_error = 1.0 - (1.0 / getClassNumStates()); + double random_guess_error = 1.0 - (1.0 / n_classes); + + // According to SAMME, we need error < random_guess_error if (weighted_error >= random_guess_error) { + std::cout << " Error >= random guess (" << random_guess_error << "), stopping" << std::endl; // If only one estimator and it's worse than random, keep it with zero weight if (models.empty()) { models.push_back(std::move(estimator)); @@ -60,7 +78,9 @@ namespace bayesnet { // Calculate alpha (estimator weight) using SAMME formula // alpha = log((1 - err) / err) + log(K - 1) double alpha = std::log((1.0 - weighted_error) / weighted_error) + - std::log(static_cast(getClassNumStates() - 1)); + std::log(static_cast(n_classes - 1)); + + std::cout << " Alpha: " << alpha << std::endl; // Store the estimator and its weight models.push_back(std::move(estimator)); @@ -74,42 +94,54 @@ namespace bayesnet { // Check for perfect classification if (weighted_error < 1e-10) { + std::cout << " Perfect classification achieved, stopping" << std::endl; break; } } // Set the number of models actually trained n_models = models.size(); + std::cout << "AdaBoost training completed with " << n_models << " models" << std::endl; } void AdaBoost::trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) { - // AdaBoost handles its own weight management, so we just build the model + // Call buildModel which does the actual training buildModel(weights); + fitted = true; } std::unique_ptr AdaBoost::trainBaseEstimator(const torch::Tensor& weights) { // Create a decision tree with specified max depth - // For AdaBoost, we typically use shallow trees (stumps with max_depth=1) auto tree = std::make_unique(base_max_depth); + // Ensure weights are properly normalized + auto normalized_weights = weights / weights.sum(); + // Fit the tree with the current sample weights - tree->fit(dataset, features, className, states, weights, Smoothing_t::NONE); + tree->fit(dataset, features, className, states, normalized_weights, Smoothing_t::NONE); return tree; } double AdaBoost::calculateWeightedError(Classifier* estimator, const torch::Tensor& weights) { - // Get predictions from the estimator + // Get features and labels from dataset auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }); auto y_true = dataset.index({ -1, torch::indexing::Slice() }); - auto y_pred = estimator->predict(X.t()); + + // Get predictions from the estimator + auto y_pred = estimator->predict(X); // Calculate weighted error auto incorrect = (y_pred != y_true).to(torch::kFloat); - double weighted_error = torch::sum(incorrect * weights).item(); + + // Ensure weights are normalized + auto normalized_weights = weights / weights.sum(); + + // Calculate weighted error + double weighted_error = torch::sum(incorrect * normalized_weights).item(); return weighted_error; } @@ -119,7 +151,7 @@ namespace bayesnet { // Get predictions from the estimator auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }); auto y_true = dataset.index({ -1, torch::indexing::Slice() }); - auto y_pred = estimator->predict(X.t()); + auto y_pred = estimator->predict(X); // Update weights according to SAMME algorithm // w_i = w_i * exp(alpha * I(y_i != y_pred_i)) @@ -187,6 +219,16 @@ namespace bayesnet { return graph_lines; } + void AdaBoost::checkValues() const + { + if (n_estimators <= 0) { + throw std::invalid_argument("n_estimators must be positive"); + } + if (base_max_depth <= 0) { + throw std::invalid_argument("base_max_depth must be positive"); + } + } + void AdaBoost::setHyperparameters(const nlohmann::json& hyperparameters_) { auto hyperparameters = hyperparameters_; @@ -194,21 +236,209 @@ namespace bayesnet { auto it = hyperparameters.find("n_estimators"); if (it != hyperparameters.end()) { n_estimators = it->get(); - if (n_estimators <= 0) { - throw std::invalid_argument("n_estimators must be positive"); - } - hyperparameters.erase("n_estimators"); // Remove 'n_estimators' if present + hyperparameters.erase("n_estimators"); } it = hyperparameters.find("base_max_depth"); if (it != hyperparameters.end()) { base_max_depth = it->get(); - if (base_max_depth <= 0) { - throw std::invalid_argument("base_max_depth must be positive"); - } - hyperparameters.erase("base_max_depth"); // Remove 'base_max_depth' if present + hyperparameters.erase("base_max_depth"); } + checkValues(); Ensemble::setHyperparameters(hyperparameters); } + torch::Tensor AdaBoost::predict(torch::Tensor& X) + { + if (!fitted) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + if (models.empty()) { + throw std::runtime_error("No models have been trained"); + } + + // X should be (n_features, n_samples) + if (X.size(0) != n) { + throw std::runtime_error("Input has wrong number of features. Expected " + + std::to_string(n) + " but got " + std::to_string(X.size(0))); + } + + int n_samples = X.size(1); + torch::Tensor predictions = torch::zeros({ n_samples }, torch::kInt32); + + for (int i = 0; i < n_samples; i++) { + auto sample = X.index({ torch::indexing::Slice(), i }); + predictions[i] = predictSample(sample); + } + + return predictions; + } + + torch::Tensor AdaBoost::predict_proba(torch::Tensor& X) + { + if (!fitted) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + if (models.empty()) { + throw std::runtime_error("No models have been trained"); + } + + // X should be (n_features, n_samples) + if (X.size(0) != n) { + throw std::runtime_error("Input has wrong number of features. Expected " + + std::to_string(n) + " but got " + std::to_string(X.size(0))); + } + + int n_samples = X.size(1); + torch::Tensor probabilities = torch::zeros({ n_samples, n_classes }); + + for (int i = 0; i < n_samples; i++) { + auto sample = X.index({ torch::indexing::Slice(), i }); + probabilities[i] = predictProbaSample(sample); + } + + return probabilities; + } + + std::vector AdaBoost::predict(std::vector>& X) + { + // Convert to tensor - X is samples x features, need to transpose + torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X).t(); + auto predictions = predict(X_tensor); + std::vector result = platform::TensorUtils::to_vector(predictions); + return result; + } + + std::vector> AdaBoost::predict_proba(std::vector>& X) + { + auto n_samples = X.size(); + // Convert to tensor - X is samples x features, need to transpose + torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X).t(); + auto proba_tensor = predict_proba(X_tensor); + + std::vector> result(n_samples, std::vector(n_classes, 0.0)); + + for (size_t i = 0; i < n_samples; i++) { + for (int j = 0; j < n_classes; j++) { + result[i][j] = proba_tensor[i][j].item(); + } + } + + return result; + } + + int AdaBoost::predictSample(const torch::Tensor& x) const + { + if (!fitted) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + if (models.empty()) { + throw std::runtime_error("No models have been trained"); + } + + // x should be a 1D tensor with n features + if (x.size(0) != n) { + throw std::runtime_error("Input sample has wrong number of features. Expected " + + std::to_string(n) + " but got " + std::to_string(x.size(0))); + } + + // Initialize class votes + std::vector class_votes(n_classes, 0.0); + + // Accumulate weighted votes from all estimators + for (size_t i = 0; i < models.size(); i++) { + if (alphas[i] <= 0) continue; // Skip estimators with zero or negative weight + + try { + // Create a matrix with the sample as a column vector + auto x_matrix = x.unsqueeze(1); // Shape: (n_features, 1) + + // Get prediction from this estimator + auto prediction = models[i]->predict(x_matrix); + int predicted_class = prediction[0].item(); + + // Add weighted vote for this class + if (predicted_class >= 0 && predicted_class < n_classes) { + class_votes[predicted_class] += alphas[i]; + } + } + catch (const std::exception& e) { + std::cerr << "Error in estimator " << i << ": " << e.what() << std::endl; + continue; + } + } + + // Return class with highest weighted vote + return std::distance(class_votes.begin(), + std::max_element(class_votes.begin(), class_votes.end())); + } + + torch::Tensor AdaBoost::predictProbaSample(const torch::Tensor& x) const + { + if (!fitted) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + if (models.empty()) { + throw std::runtime_error("No models have been trained"); + } + + // x should be a 1D tensor with n features + if (x.size(0) != n) { + throw std::runtime_error("Input sample has wrong number of features. Expected " + + std::to_string(n) + " but got " + std::to_string(x.size(0))); + } + + // Initialize probability accumulator + torch::Tensor class_probs = torch::zeros({ n_classes }, torch::kDouble); + + // Sum weighted probabilities from all estimators + double total_alpha = 0.0; + + for (size_t i = 0; i < models.size(); i++) { + if (alphas[i] <= 0) continue; // Skip estimators with zero or negative weight + + try { + // Create a matrix with the sample as a column vector + auto x_matrix = x.unsqueeze(1); // Shape: (n_features, 1) + + // Get probability predictions from this estimator + auto proba = models[i]->predict_proba(x_matrix); + + // Add weighted probabilities + for (int j = 0; j < n_classes; j++) { + class_probs[j] += alphas[i] * proba[0][j].item(); + } + + total_alpha += alphas[i]; + } + catch (const std::exception& e) { + std::cerr << "Error in estimator " << i << ": " << e.what() << std::endl; + continue; + } + } + + // Normalize probabilities + if (total_alpha > 0) { + class_probs = class_probs / total_alpha; + } else { + // If no valid estimators, return uniform distribution + class_probs.fill_(1.0 / n_classes); + } + + // Ensure probabilities are valid (non-negative and sum to 1) + class_probs = torch::clamp(class_probs, 0.0, 1.0); + double sum_probs = torch::sum(class_probs).item(); + if (sum_probs > 1e-15) { + class_probs = class_probs / sum_probs; + } else { + class_probs.fill_(1.0 / n_classes); + } + + return class_probs.to(torch::kFloat); // Convert back to float for consistency + } + } // namespace bayesnet \ No newline at end of file diff --git a/src/experimental_clfs/AdaBoost.h b/src/experimental_clfs/AdaBoost.h index c9e4ede..5d1bc37 100644 --- a/src/experimental_clfs/AdaBoost.h +++ b/src/experimental_clfs/AdaBoost.h @@ -21,9 +21,9 @@ namespace bayesnet { std::vector graph(const std::string& title = "") const override; // AdaBoost specific methods - void setNEstimators(int n_estimators) { this->n_estimators = n_estimators; } + void setNEstimators(int n_estimators) { this->n_estimators = n_estimators; checkValues(); } int getNEstimators() const { return n_estimators; } - void setBaseMaxDepth(int depth) { this->base_max_depth = depth; } + void setBaseMaxDepth(int depth) { this->base_max_depth = depth; checkValues(); } int getBaseMaxDepth() const { return base_max_depth; } // Get the weight of each base estimator @@ -35,6 +35,11 @@ namespace bayesnet { // Override setHyperparameters from BaseClassifier void setHyperparameters(const nlohmann::json& hyperparameters) override; + torch::Tensor predict(torch::Tensor& X) override; + std::vector predict(std::vector>& X) override; + torch::Tensor predict_proba(torch::Tensor& X) override; + std::vector> predict_proba(std::vector>& X); + protected: void buildModel(const torch::Tensor& weights) override; void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override; @@ -45,6 +50,8 @@ namespace bayesnet { std::vector alphas; // Weight of each base estimator std::vector training_errors; // Training error at each iteration torch::Tensor sample_weights; // Current sample weights + int n_classes; // Number of classes in the target variable + int n; // Number of features // Train a single base estimator std::unique_ptr trainBaseEstimator(const torch::Tensor& weights); @@ -57,6 +64,15 @@ namespace bayesnet { // Normalize weights to sum to 1 void normalizeWeights(); + + // Check if hyperparameters values are valid + void checkValues() const; + + // Make predictions for a single sample + int predictSample(const torch::Tensor& x) const; + + // Make probabilistic predictions for a single sample + torch::Tensor predictProbaSample(const torch::Tensor& x) const; }; } diff --git a/src/experimental_clfs/DecisionTree.cpp b/src/experimental_clfs/DecisionTree.cpp index c615504..307186a 100644 --- a/src/experimental_clfs/DecisionTree.cpp +++ b/src/experimental_clfs/DecisionTree.cpp @@ -327,30 +327,6 @@ namespace bayesnet { return predictions; } - void dumpTensor(const torch::Tensor& tensor, const std::string& name) - { - std::cout << name << ": " << std::endl; - for (int i = 0; i < tensor.size(0); i++) { - std::cout << "["; - for (int j = 0; j < tensor.size(1); j++) { - std::cout << tensor[i][j].item() << " "; - } - std::cout << "]" << std::endl; - } - std::cout << std::endl; - } - void dumpVector(const std::vector>& vec, const std::string& name) - { - std::cout << name << ": " << std::endl;; - for (const auto& row : vec) { - std::cout << "["; - for (const auto& val : row) { - std::cout << val << " "; - } - std::cout << "] " << std::endl; - } - std::cout << std::endl; - } std::vector DecisionTree::predict(std::vector>& X) { diff --git a/src/experimental_clfs/DecisionTree.h b/src/experimental_clfs/DecisionTree.h index 93ec930..8a1c337 100644 --- a/src/experimental_clfs/DecisionTree.h +++ b/src/experimental_clfs/DecisionTree.h @@ -30,6 +30,9 @@ namespace bayesnet { void setMaxDepth(int depth) { max_depth = depth; checkValues(); } void setMinSamplesSplit(int samples) { min_samples_split = samples; checkValues(); } void setMinSamplesLeaf(int samples) { min_samples_leaf = samples; checkValues(); } + int getMaxDepth() const { return max_depth; } + int getMinSamplesSplit() const { return min_samples_split; } + int getMinSamplesLeaf() const { return min_samples_leaf; } // Override setHyperparameters void setHyperparameters(const nlohmann::json& hyperparameters) override; @@ -39,6 +42,12 @@ namespace bayesnet { torch::Tensor predict_proba(torch::Tensor& X) override; std::vector> predict_proba(std::vector>& X); + // Make predictions for a single sample + int predictSample(const torch::Tensor& x) const; + + // Make probabilistic predictions for a single sample + torch::Tensor predictProbaSample(const torch::Tensor& x) const; + protected: void buildModel(const torch::Tensor& weights) override; void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override @@ -88,11 +97,7 @@ namespace bayesnet { const torch::Tensor& sample_weights ); - // Make predictions for a single sample - int predictSample(const torch::Tensor& x) const; - // Make probabilistic predictions for a single sample - torch::Tensor predictProbaSample(const torch::Tensor& x) const; // Traverse tree to find leaf node const TreeNode* traverseTree(const torch::Tensor& x, const TreeNode* node) const; diff --git a/src/main/Models.h b/src/main/Models.h index 69ceaea..3640a01 100644 --- a/src/main/Models.h +++ b/src/main/Models.h @@ -26,7 +26,7 @@ #include #include #include "../experimental_clfs/XA1DE.h" -// #include "../experimental_clfs/AdaBoost.h" +#include "../experimental_clfs/AdaBoost.h" #include "../experimental_clfs/DecisionTree.h" namespace platform { diff --git a/src/main/modelRegister.h b/src/main/modelRegister.h index 4764c07..5f44728 100644 --- a/src/main/modelRegister.h +++ b/src/main/modelRegister.h @@ -37,8 +37,8 @@ namespace platform { [](void) -> bayesnet::BaseClassifier* { return new pywrap::XGBoost();}); static Registrar registrarAdaPy("AdaBoostPy", [](void) -> bayesnet::BaseClassifier* { return new pywrap::AdaBoostPy();}); - // static Registrar registrarAda("AdaBoost", - // [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AdaBoost();}); + static Registrar registrarAda("AdaBoost", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AdaBoost();}); static Registrar registrarDT("DecisionTree", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::DecisionTree();}); static Registrar registrarXSPODE("XSPODE", diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7008066..18317bb 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -12,9 +12,11 @@ if(ENABLE_TESTING) ${Bayesnet_INCLUDE_DIRS} ) set(TEST_SOURCES_PLATFORM - TestUtils.cpp TestPlatform.cpp TestResult.cpp TestScores.cpp TestDecisionTree.cpp + TestUtils.cpp TestPlatform.cpp TestResult.cpp TestScores.cpp TestDecisionTree.cpp TestAdaBoost.cpp ${Platform_SOURCE_DIR}/src/common/Datasets.cpp ${Platform_SOURCE_DIR}/src/common/Dataset.cpp ${Platform_SOURCE_DIR}/src/common/Discretization.cpp - ${Platform_SOURCE_DIR}/src/main/Scores.cpp ${Platform_SOURCE_DIR}/src/experimental_clfs/DecisionTree.cpp + ${Platform_SOURCE_DIR}/src/main/Scores.cpp + ${Platform_SOURCE_DIR}/src/experimental_clfs/DecisionTree.cpp + ${Platform_SOURCE_DIR}/src/experimental_clfs/AdaBoost.cpp ) add_executable(${TEST_PLATFORM} ${TEST_SOURCES_PLATFORM}) target_link_libraries(${TEST_PLATFORM} PUBLIC "${TORCH_LIBRARIES}" fimdlp Catch2::Catch2WithMain bayesnet) diff --git a/tests/TestAdaBoost.cpp b/tests/TestAdaBoost.cpp new file mode 100644 index 0000000..6c2453d --- /dev/null +++ b/tests/TestAdaBoost.cpp @@ -0,0 +1,707 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#include +#include +#include +#include +#include +#include +#include +#include "experimental_clfs/AdaBoost.h" +#include "experimental_clfs/DecisionTree.h" +#include "TestUtils.h" + +using namespace bayesnet; +using namespace Catch::Matchers; + +TEST_CASE("AdaBoost Construction", "[AdaBoost]") +{ + SECTION("Default constructor") + { + REQUIRE_NOTHROW(AdaBoost()); + } + + SECTION("Constructor with parameters") + { + REQUIRE_NOTHROW(AdaBoost(100, 2)); + } + + SECTION("Constructor parameter access") + { + AdaBoost ada(75, 3); + REQUIRE(ada.getNEstimators() == 75); + REQUIRE(ada.getBaseMaxDepth() == 3); + } +} + +TEST_CASE("AdaBoost Hyperparameter Setting", "[AdaBoost]") +{ + AdaBoost ada; + + SECTION("Set individual hyperparameters") + { + REQUIRE_NOTHROW(ada.setNEstimators(100)); + REQUIRE_NOTHROW(ada.setBaseMaxDepth(5)); + + REQUIRE(ada.getNEstimators() == 100); + REQUIRE(ada.getBaseMaxDepth() == 5); + } + + SECTION("Set hyperparameters via JSON") + { + nlohmann::json params; + params["n_estimators"] = 80; + params["base_max_depth"] = 4; + + REQUIRE_NOTHROW(ada.setHyperparameters(params)); + } + + SECTION("Invalid hyperparameters should throw") + { + nlohmann::json params; + + // Negative n_estimators + params["n_estimators"] = -1; + REQUIRE_THROWS_AS(ada.setHyperparameters(params), std::invalid_argument); + + // Zero n_estimators + params["n_estimators"] = 0; + REQUIRE_THROWS_AS(ada.setHyperparameters(params), std::invalid_argument); + + // Negative base_max_depth + params["n_estimators"] = 50; + params["base_max_depth"] = -1; + REQUIRE_THROWS_AS(ada.setHyperparameters(params), std::invalid_argument); + + // Zero base_max_depth + params["base_max_depth"] = 0; + REQUIRE_THROWS_AS(ada.setHyperparameters(params), std::invalid_argument); + } +} + +TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]") +{ + // Create a simple dataset + int n_samples = 20; + int n_features = 2; + + std::vector> X(n_features, std::vector(n_samples)); + std::vector y(n_samples); + + // Simple pattern: class depends on first feature + for (int i = 0; i < n_samples; i++) { + X[0][i] = i < 10 ? 0 : 1; + X[1][i] = i % 2; + y[i] = X[0][i]; // Class equals first feature + } + + std::vector features = { "f1", "f2" }; + std::string className = "class"; + std::map> states; + states["f1"] = { 0, 1 }; + states["f2"] = { 0, 1 }; + states["class"] = { 0, 1 }; + + SECTION("Training with vector interface") + { + AdaBoost ada(10, 3); // 10 estimators, max_depth = 3 + REQUIRE_NOTHROW(ada.fit(X, y, features, className, states, Smoothing_t::NONE)); + + // Check that we have the expected number of models + auto weights = ada.getEstimatorWeights(); + REQUIRE(weights.size() <= 10); // Should be <= n_estimators + REQUIRE(weights.size() > 0); // Should have at least one model + + // Check training errors + auto errors = ada.getTrainingErrors(); + REQUIRE(errors.size() == weights.size()); + + // All training errors should be less than 0.5 for this simple dataset + for (double error : errors) { + REQUIRE(error < 0.5); + REQUIRE(error >= 0.0); + } + } + + SECTION("Prediction before fitting") + { + AdaBoost ada; + REQUIRE_THROWS_WITH(ada.predict(X), + ContainsSubstring("not been fitted")); + REQUIRE_THROWS_WITH(ada.predict_proba(X), + ContainsSubstring("not been fitted")); + } + + SECTION("Prediction with vector interface") + { + AdaBoost ada(10, 3); + ada.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto predictions = ada.predict(X); + REQUIRE(predictions.size() == static_cast(n_samples)); + + } + + SECTION("Probability predictions with vector interface") + { + AdaBoost ada(10, 3); + ada.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto proba = ada.predict_proba(X); + REQUIRE(proba.size() == static_cast(n_samples)); + REQUIRE(proba[0].size() == 2); // Two classes + + // Check probabilities sum to 1 and are valid + auto predictions = ada.predict(X); + for (size_t i = 0; i < proba.size(); i++) { + auto p = proba[i]; + auto pred = predictions[i]; + REQUIRE(p.size() == 2); + REQUIRE(p[0] >= 0.0); + REQUIRE(p[1] >= 0.0); + double sum = p[0] + p[1]; + REQUIRE(sum == Catch::Approx(1.0).epsilon(1e-6)); + + // Check that predict_proba matches the expected predict value + REQUIRE(pred == (p[0] > p[1] ? 0 : 1)); + } + } +} + +TEST_CASE("AdaBoost Tensor Interface", "[AdaBoost]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Training with tensor format") + { + AdaBoost ada(20, 3); + + INFO("Dataset shape: " << raw.dataset.sizes()); + INFO("Features: " << raw.featurest.size()); + INFO("Samples: " << raw.nSamples); + + // AdaBoost expects dataset in format: features x samples, with labels as last row + REQUIRE_NOTHROW(ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE)); + + // Test prediction with tensor + auto predictions = ada.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + + // Calculate accuracy + auto correct = torch::sum(predictions == raw.yt).item(); + double accuracy = static_cast(correct) / raw.yt.size(0); + REQUIRE(accuracy > 0.85); // Should achieve good accuracy on Iris + + // Test probability predictions with tensor + auto proba = ada.predict_proba(raw.Xt); + REQUIRE(proba.size(0) == raw.yt.size(0)); + REQUIRE(proba.size(1) == 3); // Three classes in Iris + + // Check probabilities sum to 1 + auto prob_sums = torch::sum(proba, 1); + for (int i = 0; i < prob_sums.size(0); i++) { + REQUIRE(prob_sums[i].item() == Catch::Approx(1.0).epsilon(1e-6)); + } + } +} + +TEST_CASE("AdaBoost on Iris Dataset", "[AdaBoost][iris]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Training with vector interface") + { + AdaBoost ada(30, 3); + + REQUIRE_NOTHROW(ada.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv, Smoothing_t::NONE)); + + auto predictions = ada.predict(raw.Xv); + REQUIRE(predictions.size() == raw.yv.size()); + + // Calculate accuracy + int correct = 0; + for (size_t i = 0; i < predictions.size(); i++) { + if (predictions[i] == raw.yv[i]) correct++; + } + double accuracy = static_cast(correct) / raw.yv.size(); + REQUIRE(accuracy > 0.85); // Should achieve good accuracy + + // Test probability predictions + auto proba = ada.predict_proba(raw.Xv); + REQUIRE(proba.size() == raw.yv.size()); + REQUIRE(proba[0].size() == 3); // Three classes + + // Verify estimator weights and errors + auto weights = ada.getEstimatorWeights(); + auto errors = ada.getTrainingErrors(); + + REQUIRE(weights.size() == errors.size()); + REQUIRE(weights.size() > 0); + + // All weights should be positive (for non-zero error estimators) + for (double w : weights) { + REQUIRE(w >= 0.0); + } + + // All errors should be less than 0.5 (better than random) + for (double e : errors) { + REQUIRE(e < 0.5); + REQUIRE(e >= 0.0); + } + } + + SECTION("Different number of estimators") + { + std::vector n_estimators = { 5, 15, 25 }; + + for (int n_est : n_estimators) { + AdaBoost ada(n_est, 2); + ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + auto predictions = ada.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + + // Check that we don't exceed the specified number of estimators + auto weights = ada.getEstimatorWeights(); + REQUIRE(static_cast(weights.size()) <= n_est); + } + } + + SECTION("Different base estimator depths") + { + std::vector depths = { 1, 2, 4 }; + + for (int depth : depths) { + AdaBoost ada(15, depth); + ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + auto predictions = ada.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + } + } +} + +TEST_CASE("AdaBoost Edge Cases", "[AdaBoost]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Single estimator (depth 1 stump)") + { + AdaBoost ada(1, 1); // Single decision stump + ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + auto predictions = ada.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + + auto weights = ada.getEstimatorWeights(); + REQUIRE(weights.size() == 1); + } + + SECTION("Perfect classifier scenario") + { + // Create a perfectly separable dataset + std::vector> X = { {0,0,1,1}, {0,1,0,1} }; + std::vector y = { 0, 0, 1, 1 }; + std::vector features = { "f1", "f2" }; + std::string className = "class"; + std::map> states; + states["f1"] = { 0, 1 }; + states["f2"] = { 0, 1 }; + states["class"] = { 0, 1 }; + + AdaBoost ada(10, 3); + ada.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto predictions = ada.predict(X); + REQUIRE(predictions.size() == 4); + + // Should achieve perfect accuracy + int correct = 0; + for (size_t i = 0; i < predictions.size(); i++) { + if (predictions[i] == y[i]) correct++; + } + REQUIRE(correct == 4); + + // Should stop early due to perfect classification + auto errors = ada.getTrainingErrors(); + if (errors.size() > 0) { + REQUIRE(errors.back() < 1e-10); // Very low error + } + } + + SECTION("Small dataset") + { + // Very small dataset + std::vector> X = { {0,1}, {1,0} }; + std::vector y = { 0, 1 }; + std::vector features = { "f1", "f2" }; + std::string className = "class"; + std::map> states; + states["f1"] = { 0, 1 }; + states["f2"] = { 0, 1 }; + states["class"] = { 0, 1 }; + + AdaBoost ada(5, 1); + REQUIRE_NOTHROW(ada.fit(X, y, features, className, states, Smoothing_t::NONE)); + + auto predictions = ada.predict(X); + REQUIRE(predictions.size() == 2); + } +} + +TEST_CASE("AdaBoost Graph Visualization", "[AdaBoost]") +{ + // Simple dataset for visualization + std::vector> X = { {0,0,1,1}, {0,1,0,1} }; + std::vector y = { 0, 1, 1, 0 }; // XOR pattern + std::vector features = { "x1", "x2" }; + std::string className = "xor"; + std::map> states; + states["x1"] = { 0, 1 }; + states["x2"] = { 0, 1 }; + states["xor"] = { 0, 1 }; + + SECTION("Graph generation") + { + AdaBoost ada(5, 2); + ada.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto graph_lines = ada.graph(); + + REQUIRE(graph_lines.size() > 2); + REQUIRE(graph_lines.front() == "digraph AdaBoost {"); + REQUIRE(graph_lines.back() == "}"); + + // Should contain base estimator references + bool has_estimators = false; + for (const auto& line : graph_lines) { + if (line.find("Estimator") != std::string::npos) { + has_estimators = true; + break; + } + } + REQUIRE(has_estimators); + + // Should contain alpha values + bool has_alpha = false; + for (const auto& line : graph_lines) { + if (line.find("α") != std::string::npos || line.find("alpha") != std::string::npos) { + has_alpha = true; + break; + } + } + REQUIRE(has_alpha); + } + + SECTION("Graph with title") + { + AdaBoost ada(3, 1); + ada.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto graph_lines = ada.graph("XOR AdaBoost"); + + bool has_title = false; + for (const auto& line : graph_lines) { + if (line.find("label=\"XOR AdaBoost\"") != std::string::npos) { + has_title = true; + break; + } + } + REQUIRE(has_title); + } +} + +TEST_CASE("AdaBoost with Weights", "[AdaBoost]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Uniform weights") + { + AdaBoost ada(20, 3); + ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, raw.weights, Smoothing_t::NONE); + + auto predictions = ada.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + + auto weights = ada.getEstimatorWeights(); + REQUIRE(weights.size() > 0); + } + + SECTION("Non-uniform weights") + { + auto weights = torch::ones({ raw.nSamples }); + weights.index({ torch::indexing::Slice(0, 50) }) *= 3.0; // Emphasize first class + weights = weights / weights.sum(); + + AdaBoost ada(15, 2); + ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, weights, Smoothing_t::NONE); + + auto predictions = ada.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + + // Check that training completed successfully + auto estimator_weights = ada.getEstimatorWeights(); + auto errors = ada.getTrainingErrors(); + + REQUIRE(estimator_weights.size() == errors.size()); + REQUIRE(estimator_weights.size() > 0); + } +} + +TEST_CASE("AdaBoost Input Dimension Validation", "[AdaBoost]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Correct input dimensions") + { + AdaBoost ada(10, 2); + ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + // Test with correct tensor dimensions (features x samples) + REQUIRE_NOTHROW(ada.predict(raw.Xt)); + REQUIRE_NOTHROW(ada.predict_proba(raw.Xt)); + + // Test with correct vector dimensions (features x samples) + REQUIRE_NOTHROW(ada.predict(raw.Xv)); + REQUIRE_NOTHROW(ada.predict_proba(raw.Xv)); + } + + SECTION("Dimension consistency between interfaces") + { + AdaBoost ada(10, 2); + ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + // Get predictions from both interfaces + auto tensor_predictions = ada.predict(raw.Xt); + auto vector_predictions = ada.predict(raw.Xv); + + // Should have same number of predictions + REQUIRE(tensor_predictions.size(0) == static_cast(vector_predictions.size())); + + // Test probability predictions + auto tensor_proba = ada.predict_proba(raw.Xt); + auto vector_proba = ada.predict_proba(raw.Xv); + + REQUIRE(tensor_proba.size(0) == static_cast(vector_proba.size())); + REQUIRE(tensor_proba.size(1) == static_cast(vector_proba[0].size())); + + // Verify predictions match between interfaces + for (int i = 0; i < tensor_predictions.size(0); i++) { + REQUIRE(tensor_predictions[i].item() == vector_predictions[i]); + + // Verify probabilities match between interfaces + for (int j = 0; j < tensor_proba.size(1); j++) { + REQUIRE(tensor_proba[i][j].item() == Catch::Approx(vector_proba[i][j]).epsilon(1e-10)); + } + } + } +} + +TEST_CASE("AdaBoost Debug - Simple Dataset Analysis", "[AdaBoost][debug]") +{ + // Create the exact same simple dataset that was failing + int n_samples = 20; + int n_features = 2; + + std::vector> X(n_features, std::vector(n_samples)); + std::vector y(n_samples); + + // Simple pattern: class depends on first feature + for (int i = 0; i < n_samples; i++) { + X[0][i] = i < 10 ? 0 : 1; + X[1][i] = i % 2; + y[i] = X[0][i]; // Class equals first feature + } + + std::vector features = { "f1", "f2" }; + std::string className = "class"; + std::map> states; + states["f1"] = { 0, 1 }; + states["f2"] = { 0, 1 }; + states["class"] = { 0, 1 }; + + SECTION("Debug training process") + { + AdaBoost ada(5, 3); // Few estimators for debugging + + // This should work perfectly on this simple dataset + REQUIRE_NOTHROW(ada.fit(X, y, features, className, states, Smoothing_t::NONE)); + + // Get training details + auto weights = ada.getEstimatorWeights(); + auto errors = ada.getTrainingErrors(); + + INFO("Number of models trained: " << weights.size()); + INFO("Training errors: "); + for (size_t i = 0; i < errors.size(); i++) { + INFO(" Model " << i << ": error=" << errors[i] << ", weight=" << weights[i]); + } + + // Should have at least one model + REQUIRE(weights.size() > 0); + REQUIRE(errors.size() == weights.size()); + + // All training errors should be reasonable for this simple dataset + for (double error : errors) { + REQUIRE(error >= 0.0); + REQUIRE(error < 0.5); // Should be better than random + } + + // Test predictions + auto predictions = ada.predict(X); + REQUIRE(predictions.size() == static_cast(n_samples)); + + // Calculate accuracy + int correct = 0; + for (size_t i = 0; i < predictions.size(); i++) { + if (predictions[i] == y[i]) correct++; + INFO("Sample " << i << ": predicted=" << predictions[i] << ", actual=" << y[i]); + } + double accuracy = static_cast(correct) / n_samples; + INFO("Accuracy: " << accuracy); + + // Should achieve high accuracy on this perfectly separable dataset + REQUIRE(accuracy >= 0.9); // Lower threshold for debugging + + // Test probability predictions + auto proba = ada.predict_proba(X); + REQUIRE(proba.size() == static_cast(n_samples)); + + // Verify probabilities are valid + for (size_t i = 0; i < proba.size(); i++) { + auto p = proba[i]; + REQUIRE(p.size() == 2); + REQUIRE(p[0] >= 0.0); + REQUIRE(p[1] >= 0.0); + double sum = p[0] + p[1]; + REQUIRE(sum == Catch::Approx(1.0).epsilon(1e-6)); + + // Predicted class should match highest probability + int pred_class = predictions[i]; + REQUIRE(pred_class == (p[0] > p[1] ? 0 : 1)); + } + } + + SECTION("Compare with single DecisionTree") + { + // Test that AdaBoost performs at least as well as a single tree + DecisionTree single_tree(3, 2, 1); + single_tree.fit(X, y, features, className, states, Smoothing_t::NONE); + auto tree_predictions = single_tree.predict(X); + + int tree_correct = 0; + for (size_t i = 0; i < tree_predictions.size(); i++) { + if (tree_predictions[i] == y[i]) tree_correct++; + } + double tree_accuracy = static_cast(tree_correct) / n_samples; + + AdaBoost ada(5, 3); + ada.fit(X, y, features, className, states, Smoothing_t::NONE); + auto ada_predictions = ada.predict(X); + + int ada_correct = 0; + for (size_t i = 0; i < ada_predictions.size(); i++) { + if (ada_predictions[i] == y[i]) ada_correct++; + } + double ada_accuracy = static_cast(ada_correct) / n_samples; + + INFO("DecisionTree accuracy: " << tree_accuracy); + INFO("AdaBoost accuracy: " << ada_accuracy); + + // AdaBoost should perform at least as well as single tree + // (allowing small tolerance for numerical differences) + REQUIRE(ada_accuracy >= tree_accuracy - 0.1); + } +} + +TEST_CASE("AdaBoost SAMME Algorithm Validation", "[AdaBoost]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Prediction consistency with probabilities") + { + AdaBoost ada(15, 3); + ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + auto predictions = ada.predict(raw.Xt); + auto probabilities = ada.predict_proba(raw.Xt); + + REQUIRE(predictions.size(0) == probabilities.size(0)); + REQUIRE(probabilities.size(1) == 3); // Three classes in Iris + + // For each sample, predicted class should correspond to highest probability + for (int i = 0; i < predictions.size(0); i++) { + int predicted_class = predictions[i].item(); + auto probs = probabilities[i]; + + // Find class with highest probability + auto max_prob_idx = torch::argmax(probs).item(); + + // Predicted class should match class with highest probability + REQUIRE(predicted_class == max_prob_idx); + + // Probabilities should sum to 1 + double sum_probs = torch::sum(probs).item(); + REQUIRE(sum_probs == Catch::Approx(1.0).epsilon(1e-6)); + + // All probabilities should be non-negative + for (int j = 0; j < 3; j++) { + REQUIRE(probs[j].item() >= 0.0); + REQUIRE(probs[j].item() <= 1.0); + } + } + } + + SECTION("Weighted voting verification") + { + // Simple dataset where we can verify the weighted voting + std::vector> X = { {0,0,1,1}, {0,1,0,1} }; + std::vector y = { 0, 1, 1, 0 }; + std::vector features = { "f1", "f2" }; + std::string className = "class"; + std::map> states; + states["f1"] = { 0, 1 }; + states["f2"] = { 0, 1 }; + states["class"] = { 0, 1 }; + + AdaBoost ada(5, 2); + ada.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto predictions = ada.predict(X); + auto probabilities = ada.predict_proba(X); + auto alphas = ada.getEstimatorWeights(); + + REQUIRE(predictions.size() == 4); + REQUIRE(probabilities.size() == 4); + REQUIRE(probabilities[0].size() == 2); // Two classes + REQUIRE(alphas.size() > 0); + + // Verify that estimator weights are reasonable + for (double alpha : alphas) { + REQUIRE(alpha >= 0.0); // Alphas should be non-negative + } + + // Verify prediction-probability consistency + for (size_t i = 0; i < predictions.size(); i++) { + int pred = predictions[i]; + auto probs = probabilities[i]; + + REQUIRE(pred == (probs[0] > probs[1] ? 0 : 1)); + REQUIRE(probs[0] + probs[1] == Catch::Approx(1.0).epsilon(1e-6)); + } + } + + SECTION("Empty models edge case") + { + AdaBoost ada(1, 1); + + // Try to predict before fitting + std::vector> X = { {0}, {1} }; + REQUIRE_THROWS_WITH(ada.predict(X), ContainsSubstring("not been fitted")); + REQUIRE_THROWS_WITH(ada.predict_proba(X), ContainsSubstring("not been fitted")); + } +} \ No newline at end of file diff --git a/tests/TestDecisionTree.cpp b/tests/TestDecisionTree.cpp index 8fd4bd5..7b5ef76 100644 --- a/tests/TestDecisionTree.cpp +++ b/tests/TestDecisionTree.cpp @@ -39,6 +39,9 @@ TEST_CASE("DecisionTree Hyperparameter Setting", "[DecisionTree]") REQUIRE_NOTHROW(dt.setMaxDepth(10)); REQUIRE_NOTHROW(dt.setMinSamplesSplit(5)); REQUIRE_NOTHROW(dt.setMinSamplesLeaf(2)); + REQUIRE(dt.getMaxDepth() == 10); + REQUIRE(dt.getMinSamplesSplit() == 5); + REQUIRE(dt.getMinSamplesLeaf() == 2); } SECTION("Set hyperparameters via JSON") @@ -49,6 +52,9 @@ TEST_CASE("DecisionTree Hyperparameter Setting", "[DecisionTree]") params["min_samples_leaf"] = 2; REQUIRE_NOTHROW(dt.setHyperparameters(params)); + REQUIRE(dt.getMaxDepth() == 7); + REQUIRE(dt.getMinSamplesSplit() == 4); + REQUIRE(dt.getMinSamplesLeaf() == 2); } SECTION("Invalid hyperparameters should throw") @@ -164,7 +170,9 @@ TEST_CASE("DecisionTree on Iris Dataset", "[DecisionTree][iris]") // Calculate accuracy auto correct = torch::sum(predictions == raw.yt).item(); double accuracy = static_cast(correct) / raw.yt.size(0); + double acurracy_computed = dt.score(raw.Xt, raw.yt); REQUIRE(accuracy > 0.97); // Reasonable accuracy for Iris + REQUIRE(acurracy_computed == Catch::Approx(accuracy).epsilon(1e-6)); } SECTION("Training with vector interface") From 56af1a5f850eb163f73bed04cf82ac65befd6e46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 18 Jun 2025 13:59:23 +0200 Subject: [PATCH 21/27] AdaBoost a falta de predict_proba --- src/experimental_clfs/AdaBoost.cpp | 93 ++++++++++----------- src/experimental_clfs/AdaBoost.h | 2 + src/experimental_clfs/TensorUtils.hpp | 33 ++++++++ tests/TestAdaBoost.cpp | 112 +++++++++++++++++++++++++- 4 files changed, 191 insertions(+), 49 deletions(-) diff --git a/src/experimental_clfs/AdaBoost.cpp b/src/experimental_clfs/AdaBoost.cpp index a04236d..5af7f31 100644 --- a/src/experimental_clfs/AdaBoost.cpp +++ b/src/experimental_clfs/AdaBoost.cpp @@ -43,12 +43,15 @@ namespace bayesnet { } // Debug information - std::cout << "Starting AdaBoost training with " << n_estimators << " estimators" << std::endl; - std::cout << "Number of classes: " << n_classes << std::endl; - std::cout << "Number of features: " << n << std::endl; - std::cout << "Number of samples: " << n_samples << std::endl; + if (debug) { + std::cout << "Starting AdaBoost training with " << n_estimators << " estimators" << std::endl; + std::cout << "Number of classes: " << n_classes << std::endl; + std::cout << "Number of features: " << n << std::endl; + std::cout << "Number of samples: " << n_samples << std::endl; + } - // Main AdaBoost training loop (SAMME algorithm) + // Main AdaBoost training loop (SAMME algorithm) + // (Stagewise Additive Modeling using a Multi - class Exponential loss) for (int iter = 0; iter < n_estimators; ++iter) { // Train base estimator with current sample weights auto estimator = trainBaseEstimator(sample_weights); @@ -57,16 +60,12 @@ namespace bayesnet { double weighted_error = calculateWeightedError(estimator.get(), sample_weights); training_errors.push_back(weighted_error); - // Debug output - std::cout << "Iteration " << iter + 1 << ":" << std::endl; - std::cout << " Weighted error: " << weighted_error << std::endl; - // Check if error is too high (worse than random guessing) double random_guess_error = 1.0 - (1.0 / n_classes); // According to SAMME, we need error < random_guess_error if (weighted_error >= random_guess_error) { - std::cout << " Error >= random guess (" << random_guess_error << "), stopping" << std::endl; + if (debug) std::cout << " Error >= random guess (" << random_guess_error << "), stopping" << std::endl; // If only one estimator and it's worse than random, keep it with zero weight if (models.empty()) { models.push_back(std::move(estimator)); @@ -80,8 +79,6 @@ namespace bayesnet { double alpha = std::log((1.0 - weighted_error) / weighted_error) + std::log(static_cast(n_classes - 1)); - std::cout << " Alpha: " << alpha << std::endl; - // Store the estimator and its weight models.push_back(std::move(estimator)); alphas.push_back(alpha); @@ -92,16 +89,23 @@ namespace bayesnet { // Normalize weights normalizeWeights(); + if (debug) { + std::cout << "Iteration " << iter << ":" << std::endl; + std::cout << " Weighted error: " << weighted_error << std::endl; + std::cout << " Alpha: " << alpha << std::endl; + std::cout << " Random guess error: " << random_guess_error << std::endl; + } + // Check for perfect classification if (weighted_error < 1e-10) { - std::cout << " Perfect classification achieved, stopping" << std::endl; + if (debug) std::cout << " Perfect classification achieved, stopping" << std::endl; break; } } // Set the number of models actually trained n_models = models.size(); - std::cout << "AdaBoost training completed with " << n_models << " models" << std::endl; + if (debug) std::cout << "AdaBoost training completed with " << n_models << " models" << std::endl; } void AdaBoost::trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) @@ -305,7 +309,7 @@ namespace bayesnet { std::vector AdaBoost::predict(std::vector>& X) { // Convert to tensor - X is samples x features, need to transpose - torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X).t(); + torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); auto predictions = predict(X_tensor); std::vector result = platform::TensorUtils::to_vector(predictions); return result; @@ -313,9 +317,9 @@ namespace bayesnet { std::vector> AdaBoost::predict_proba(std::vector>& X) { - auto n_samples = X.size(); + auto n_samples = X[0].size(); // Convert to tensor - X is samples x features, need to transpose - torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X).t(); + torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); auto proba_tensor = predict_proba(X_tensor); std::vector> result(n_samples, std::vector(n_classes, 0.0)); @@ -351,14 +355,9 @@ namespace bayesnet { // Accumulate weighted votes from all estimators for (size_t i = 0; i < models.size(); i++) { if (alphas[i] <= 0) continue; // Skip estimators with zero or negative weight - try { - // Create a matrix with the sample as a column vector - auto x_matrix = x.unsqueeze(1); // Shape: (n_features, 1) - // Get prediction from this estimator - auto prediction = models[i]->predict(x_matrix); - int predicted_class = prediction[0].item(); + int predicted_class = static_cast(models[i].get())->predictSample(x); // Add weighted vote for this class if (predicted_class >= 0 && predicted_class < n_classes) { @@ -392,28 +391,23 @@ namespace bayesnet { std::to_string(n) + " but got " + std::to_string(x.size(0))); } - // Initialize probability accumulator - torch::Tensor class_probs = torch::zeros({ n_classes }, torch::kDouble); + // Initialize class votes (same logic as predictSample) + std::vector class_votes(n_classes, 0.0); - // Sum weighted probabilities from all estimators + // Accumulate weighted votes from all estimators (SAMME voting) double total_alpha = 0.0; - for (size_t i = 0; i < models.size(); i++) { if (alphas[i] <= 0) continue; // Skip estimators with zero or negative weight try { - // Create a matrix with the sample as a column vector - auto x_matrix = x.unsqueeze(1); // Shape: (n_features, 1) + // Get class prediction from this estimator (not probabilities!) + int predicted_class = static_cast(models[i].get())->predictSample(x); - // Get probability predictions from this estimator - auto proba = models[i]->predict_proba(x_matrix); - - // Add weighted probabilities - for (int j = 0; j < n_classes; j++) { - class_probs[j] += alphas[i] * proba[0][j].item(); + // Add weighted vote for this class (SAMME algorithm) + if (predicted_class >= 0 && predicted_class < n_classes) { + class_votes[predicted_class] += alphas[i]; + total_alpha += alphas[i]; } - - total_alpha += alphas[i]; } catch (const std::exception& e) { std::cerr << "Error in estimator " << i << ": " << e.what() << std::endl; @@ -421,24 +415,31 @@ namespace bayesnet { } } - // Normalize probabilities + // Convert votes to probabilities + torch::Tensor class_probs = torch::zeros({ n_classes }, torch::kFloat); + if (total_alpha > 0) { - class_probs = class_probs / total_alpha; + // Normalize votes to get probabilities + for (int j = 0; j < n_classes; j++) { + class_probs[j] = static_cast(class_votes[j] / total_alpha); + } } else { // If no valid estimators, return uniform distribution - class_probs.fill_(1.0 / n_classes); + class_probs.fill_(1.0f / n_classes); } - // Ensure probabilities are valid (non-negative and sum to 1) - class_probs = torch::clamp(class_probs, 0.0, 1.0); - double sum_probs = torch::sum(class_probs).item(); - if (sum_probs > 1e-15) { + // Ensure probabilities are valid (they should be already, but just in case) + class_probs = torch::clamp(class_probs, 0.0f, 1.0f); + + // Verify they sum to 1 (they should, but normalize if needed due to floating point errors) + float sum_probs = torch::sum(class_probs).item(); + if (sum_probs > 1e-15f) { class_probs = class_probs / sum_probs; } else { - class_probs.fill_(1.0 / n_classes); + class_probs.fill_(1.0f / n_classes); } - return class_probs.to(torch::kFloat); // Convert back to float for consistency + return class_probs; } } // namespace bayesnet \ No newline at end of file diff --git a/src/experimental_clfs/AdaBoost.h b/src/experimental_clfs/AdaBoost.h index 5d1bc37..0c7e08b 100644 --- a/src/experimental_clfs/AdaBoost.h +++ b/src/experimental_clfs/AdaBoost.h @@ -39,6 +39,7 @@ namespace bayesnet { std::vector predict(std::vector>& X) override; torch::Tensor predict_proba(torch::Tensor& X) override; std::vector> predict_proba(std::vector>& X); + void setDebug(bool debug) { this->debug = debug; } protected: void buildModel(const torch::Tensor& weights) override; @@ -73,6 +74,7 @@ namespace bayesnet { // Make probabilistic predictions for a single sample torch::Tensor predictProbaSample(const torch::Tensor& x) const; + bool debug = false; // Enable debug mode for debug output }; } diff --git a/src/experimental_clfs/TensorUtils.hpp b/src/experimental_clfs/TensorUtils.hpp index 77ed894..2efdf7d 100644 --- a/src/experimental_clfs/TensorUtils.hpp +++ b/src/experimental_clfs/TensorUtils.hpp @@ -59,6 +59,39 @@ namespace platform { return tensor; } }; + static void dumpVector(const std::vector>& vec, const std::string& name) + { + std::cout << name << ": " << std::endl; + for (const auto& row : vec) { + std::cout << "["; + for (const auto& val : row) { + std::cout << val << " "; + } + std::cout << "]" << std::endl; + } + std::cout << std::endl; + } + static void dumpTensor(const torch::Tensor& tensor, const std::string& name) + { + std::cout << name << ": " << std::endl; + for (auto i = 0; i < tensor.size(0); i++) { + std::cout << "["; + for (auto j = 0; j < tensor.size(1); j++) { + std::cout << tensor[i][j].item() << " "; + } + std::cout << "]" << std::endl; + } + std::cout << std::endl; + } + static void dumpTensorV(const torch::Tensor& tensor, const std::string& name) + { + std::cout << name << ": " << std::endl; + std::cout << "["; + for (int i = 0; i < tensor.size(0); i++) { + std::cout << tensor[i].item() << " "; + } + std::cout << "]" << std::endl; + } } #endif // TENSORUTILS_HPP \ No newline at end of file diff --git a/tests/TestAdaBoost.cpp b/tests/TestAdaBoost.cpp index 6c2453d..301ebb2 100644 --- a/tests/TestAdaBoost.cpp +++ b/tests/TestAdaBoost.cpp @@ -13,11 +13,13 @@ #include #include "experimental_clfs/AdaBoost.h" #include "experimental_clfs/DecisionTree.h" +#include "experimental_clfs/TensorUtils.hpp" #include "TestUtils.h" using namespace bayesnet; using namespace Catch::Matchers; + TEST_CASE("AdaBoost Construction", "[AdaBoost]") { SECTION("Default constructor") @@ -143,7 +145,15 @@ TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]") auto predictions = ada.predict(X); REQUIRE(predictions.size() == static_cast(n_samples)); - + // Check accuracy + int correct = 0; + for (size_t i = 0; i < predictions.size(); i++) { + if (predictions[i] == y[i]) correct++; + } + double accuracy = static_cast(correct) / n_samples; + REQUIRE(accuracy > 0.99); // Should achieve good accuracy on this simple dataset + auto accuracy_computed = ada.score(X, y); + REQUIRE(accuracy_computed == Catch::Approx(accuracy).epsilon(1e-6)); } SECTION("Probability predictions with vector interface") @@ -157,6 +167,7 @@ TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]") // Check probabilities sum to 1 and are valid auto predictions = ada.predict(X); + int correct = 0; for (size_t i = 0; i < proba.size(); i++) { auto p = proba[i]; auto pred = predictions[i]; @@ -165,10 +176,19 @@ TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]") REQUIRE(p[1] >= 0.0); double sum = p[0] + p[1]; REQUIRE(sum == Catch::Approx(1.0).epsilon(1e-6)); + // compute the predicted class based on probabilities + auto predicted_class = (p[0] > p[1]) ? 0 : 1; + // compute accuracy based on predictions + if (predicted_class == y[i]) { + correct++; + } // Check that predict_proba matches the expected predict value - REQUIRE(pred == (p[0] > p[1] ? 0 : 1)); + // REQUIRE(pred == (p[0] > p[1] ? 0 : 1)); } + double accuracy = static_cast(correct) / n_samples; + std::cout << "Probability accuracy: " << accuracy << std::endl; + REQUIRE(accuracy > 0.99); // Should achieve good accuracy on this simple dataset } } @@ -194,7 +214,9 @@ TEST_CASE("AdaBoost Tensor Interface", "[AdaBoost]") // Calculate accuracy auto correct = torch::sum(predictions == raw.yt).item(); double accuracy = static_cast(correct) / raw.yt.size(0); - REQUIRE(accuracy > 0.85); // Should achieve good accuracy on Iris + auto accuracy_computed = ada.score(raw.Xt, raw.yt); + REQUIRE(accuracy_computed == Catch::Approx(accuracy).epsilon(1e-6)); + REQUIRE(accuracy > 0.97); // Should achieve good accuracy on Iris // Test probability predictions with tensor auto proba = ada.predict_proba(raw.Xt); @@ -704,4 +726,88 @@ TEST_CASE("AdaBoost SAMME Algorithm Validation", "[AdaBoost]") REQUIRE_THROWS_WITH(ada.predict(X), ContainsSubstring("not been fitted")); REQUIRE_THROWS_WITH(ada.predict_proba(X), ContainsSubstring("not been fitted")); } +} +TEST_CASE("AdaBoost Predict-Proba Consistency Fix", "[AdaBoost][consistency]") +{ + // Simple binary classification dataset + std::vector> X = { {0,0,1,1}, {0,1,0,1} }; + std::vector y = { 0, 0, 1, 1 }; + std::vector features = { "f1", "f2" }; + std::string className = "class"; + std::map> states; + states["f1"] = { 0, 1 }; + states["f2"] = { 0, 1 }; + states["class"] = { 0, 1 }; + + SECTION("Binary classification consistency") + { + AdaBoost ada(3, 2); + ada.setDebug(true); // Enable debug output + ada.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto predictions = ada.predict(X); + auto probabilities = ada.predict_proba(X); + + INFO("=== Debugging predict vs predict_proba consistency ==="); + + // Verify consistency for each sample + for (size_t i = 0; i < predictions.size(); i++) { + int predicted_class = predictions[i]; + auto probs = probabilities[i]; + + INFO("Sample " << i << ":"); + INFO(" True class: " << y[i]); + INFO(" Predicted class: " << predicted_class); + INFO(" Probabilities: [" << probs[0] << ", " << probs[1] << "]"); + + // The predicted class should be the one with highest probability + int max_prob_class = (probs[0] > probs[1]) ? 0 : 1; + INFO(" Max prob class: " << max_prob_class); + + REQUIRE(predicted_class == max_prob_class); + + // Probabilities should sum to 1 + double sum_probs = probs[0] + probs[1]; + REQUIRE(sum_probs == Catch::Approx(1.0).epsilon(1e-6)); + + // All probabilities should be valid + REQUIRE(probs[0] >= 0.0); + REQUIRE(probs[1] >= 0.0); + REQUIRE(probs[0] <= 1.0); + REQUIRE(probs[1] <= 1.0); + } + } + + SECTION("Multi-class consistency") + { + auto raw = RawDatasets("iris", true); + + AdaBoost ada(5, 2); + ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + auto predictions = ada.predict(raw.Xt); + auto probabilities = ada.predict_proba(raw.Xt); + + // Check consistency for first 10 samples + for (int i = 0; i < std::min(static_cast(10), predictions.size(0)); i++) { + int predicted_class = predictions[i].item(); + auto probs = probabilities[i]; + + // Find class with maximum probability + auto max_prob_idx = torch::argmax(probs).item(); + + INFO("Sample " << i << ":"); + INFO(" Predicted class: " << predicted_class); + INFO(" Max prob class: " << max_prob_idx); + INFO(" Probabilities: [" << probs[0].item() << ", " + << probs[1].item() << ", " << probs[2].item() << "]"); + + // They must match + REQUIRE(predicted_class == max_prob_idx); + + // Probabilities should sum to 1 + double sum_probs = torch::sum(probs).item(); + REQUIRE(sum_probs == Catch::Approx(1.0).epsilon(1e-6)); + } + } } \ No newline at end of file From 4e18dc87be7aa9bbb2163d590ca535e62f85f87d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 18 Jun 2025 14:18:15 +0200 Subject: [PATCH 22/27] Fix predict_proba in AdaBoost --- src/experimental_clfs/AdaBoost.cpp | 45 ++++++++++++++++++++++-------- tests/TestAdaBoost.cpp | 5 ++-- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/src/experimental_clfs/AdaBoost.cpp b/src/experimental_clfs/AdaBoost.cpp index 5af7f31..407c587 100644 --- a/src/experimental_clfs/AdaBoost.cpp +++ b/src/experimental_clfs/AdaBoost.cpp @@ -74,32 +74,53 @@ namespace bayesnet { break; // Stop boosting } + // Check for perfect classification BEFORE calculating alpha + if (weighted_error <= 1e-10) { + if (debug) std::cout << " Perfect classification achieved (error=" << weighted_error << ")" << std::endl; + + // For perfect classification, use a large but finite alpha + double alpha = 10.0 + std::log(static_cast(n_classes - 1)); + + // Store the estimator and its weight + models.push_back(std::move(estimator)); + alphas.push_back(alpha); + + if (debug) { + std::cout << "Iteration " << iter << ":" << std::endl; + std::cout << " Weighted error: " << weighted_error << std::endl; + std::cout << " Alpha (finite): " << alpha << std::endl; + std::cout << " Random guess error: " << random_guess_error << std::endl; + } + + break; // Stop training as we have a perfect classifier + } + // Calculate alpha (estimator weight) using SAMME formula // alpha = log((1 - err) / err) + log(K - 1) - double alpha = std::log((1.0 - weighted_error) / weighted_error) + + // Clamp weighted_error to avoid division by zero and infinite alpha + double clamped_error = std::max(1e-15, std::min(1.0 - 1e-15, weighted_error)); + double alpha = std::log((1.0 - clamped_error) / clamped_error) + std::log(static_cast(n_classes - 1)); + // Clamp alpha to reasonable bounds to avoid numerical issues + alpha = std::max(-10.0, std::min(10.0, alpha)); + // Store the estimator and its weight models.push_back(std::move(estimator)); alphas.push_back(alpha); - // Update sample weights - updateSampleWeights(models.back().get(), alpha); - - // Normalize weights - normalizeWeights(); + // Update sample weights (only if this is not the last iteration) + if (iter < n_estimators - 1) { + updateSampleWeights(models.back().get(), alpha); + normalizeWeights(); + } if (debug) { std::cout << "Iteration " << iter << ":" << std::endl; std::cout << " Weighted error: " << weighted_error << std::endl; std::cout << " Alpha: " << alpha << std::endl; std::cout << " Random guess error: " << random_guess_error << std::endl; - } - - // Check for perfect classification - if (weighted_error < 1e-10) { - if (debug) std::cout << " Perfect classification achieved, stopping" << std::endl; - break; + std::cout << " Random guess error: " << random_guess_error << std::endl; } } diff --git a/tests/TestAdaBoost.cpp b/tests/TestAdaBoost.cpp index 301ebb2..7fb887b 100644 --- a/tests/TestAdaBoost.cpp +++ b/tests/TestAdaBoost.cpp @@ -184,10 +184,9 @@ TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]") } // Check that predict_proba matches the expected predict value - // REQUIRE(pred == (p[0] > p[1] ? 0 : 1)); + REQUIRE(pred == (p[0] > p[1] ? 0 : 1)); } double accuracy = static_cast(correct) / n_samples; - std::cout << "Probability accuracy: " << accuracy << std::endl; REQUIRE(accuracy > 0.99); // Should achieve good accuracy on this simple dataset } } @@ -711,6 +710,8 @@ TEST_CASE("AdaBoost SAMME Algorithm Validation", "[AdaBoost]") for (size_t i = 0; i < predictions.size(); i++) { int pred = predictions[i]; auto probs = probabilities[i]; + INFO("Sample " << i << ": predicted=" << pred + << ", probabilities=[" << probs[0] << ", " << probs[1] << "]"); REQUIRE(pred == (probs[0] > probs[1] ? 0 : 1)); REQUIRE(probs[0] + probs[1] == Catch::Approx(1.0).epsilon(1e-6)); From 41afa1b8883b68c8d531d4c06bc2237f508dc411 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 18 Jun 2025 17:33:56 +0200 Subject: [PATCH 23/27] Enhance predictProbaSample --- src/experimental_clfs/AdaBoost.cpp | 73 +++++++++++++++++++++--------- 1 file changed, 51 insertions(+), 22 deletions(-) diff --git a/src/experimental_clfs/AdaBoost.cpp b/src/experimental_clfs/AdaBoost.cpp index 407c587..3e276a2 100644 --- a/src/experimental_clfs/AdaBoost.cpp +++ b/src/experimental_clfs/AdaBoost.cpp @@ -412,52 +412,81 @@ namespace bayesnet { std::to_string(n) + " but got " + std::to_string(x.size(0))); } - // Initialize class votes (same logic as predictSample) + // Initialize class votes with zeros std::vector class_votes(n_classes, 0.0); + double total_votes = 0.0; - // Accumulate weighted votes from all estimators (SAMME voting) - double total_alpha = 0.0; + if (debug) { + std::cout << "=== predictProbaSample Debug ===" << std::endl; + std::cout << "Number of models: " << models.size() << std::endl; + std::cout << "Number of classes: " << n_classes << std::endl; + } + + // Accumulate votes from all estimators for (size_t i = 0; i < models.size(); i++) { - if (alphas[i] <= 0) continue; // Skip estimators with zero or negative weight + double alpha = alphas[i]; + + // Skip invalid estimators + if (alpha <= 0 || !std::isfinite(alpha)) { + if (debug) std::cout << "Skipping model " << i << " (alpha=" << alpha << ")" << std::endl; + continue; + } try { - // Get class prediction from this estimator (not probabilities!) + // Get class prediction from this estimator int predicted_class = static_cast(models[i].get())->predictSample(x); - // Add weighted vote for this class (SAMME algorithm) + if (debug) { + std::cout << "Model " << i << ": predicts class " << predicted_class + << " with alpha " << alpha << std::endl; + } + + // Add weighted vote for this class if (predicted_class >= 0 && predicted_class < n_classes) { - class_votes[predicted_class] += alphas[i]; - total_alpha += alphas[i]; + class_votes[predicted_class] += alpha; + total_votes += alpha; + } else { + if (debug) std::cout << "Invalid class prediction: " << predicted_class << std::endl; } } catch (const std::exception& e) { - std::cerr << "Error in estimator " << i << ": " << e.what() << std::endl; + if (debug) std::cout << "Error in model " << i << ": " << e.what() << std::endl; continue; } } + if (debug) { + std::cout << "Total votes: " << total_votes << std::endl; + std::cout << "Class votes: ["; + for (int j = 0; j < n_classes; j++) { + std::cout << class_votes[j]; + if (j < n_classes - 1) std::cout << ", "; + } + std::cout << "]" << std::endl; + } + // Convert votes to probabilities torch::Tensor class_probs = torch::zeros({ n_classes }, torch::kFloat); - if (total_alpha > 0) { - // Normalize votes to get probabilities + if (total_votes > 0) { + // Simple division to get probabilities for (int j = 0; j < n_classes; j++) { - class_probs[j] = static_cast(class_votes[j] / total_alpha); + class_probs[j] = static_cast(class_votes[j] / total_votes); } } else { - // If no valid estimators, return uniform distribution + // If no valid votes, uniform distribution + if (debug) std::cout << "No valid votes, using uniform distribution" << std::endl; class_probs.fill_(1.0f / n_classes); } - // Ensure probabilities are valid (they should be already, but just in case) - class_probs = torch::clamp(class_probs, 0.0f, 1.0f); - - // Verify they sum to 1 (they should, but normalize if needed due to floating point errors) - float sum_probs = torch::sum(class_probs).item(); - if (sum_probs > 1e-15f) { - class_probs = class_probs / sum_probs; - } else { - class_probs.fill_(1.0f / n_classes); + if (debug) { + std::cout << "Final probabilities: ["; + for (int j = 0; j < n_classes; j++) { + std::cout << class_probs[j].item(); + if (j < n_classes - 1) std::cout << ", "; + } + std::cout << "]" << std::endl; + std::cout << "=== End predictProbaSample Debug ===" << std::endl; } return class_probs; From dda9740e837562289efdbe94d4cf7407fad8446c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 18 Jun 2025 18:03:19 +0200 Subject: [PATCH 24/27] Test AdaBoost fine but unoptimized --- src/experimental_clfs/AdaBoost.cpp | 238 ++++++++++++-- tests/TestAdaBoost.cpp | 499 +++++++---------------------- 2 files changed, 322 insertions(+), 415 deletions(-) diff --git a/src/experimental_clfs/AdaBoost.cpp b/src/experimental_clfs/AdaBoost.cpp index 3e276a2..4f21a77 100644 --- a/src/experimental_clfs/AdaBoost.cpp +++ b/src/experimental_clfs/AdaBoost.cpp @@ -300,6 +300,101 @@ namespace bayesnet { return predictions; } + // torch::Tensor AdaBoost::predict_proba(torch::Tensor& X) + // { + // if (!fitted) { + // throw std::runtime_error(CLASSIFIER_NOT_FITTED); + // } + + // if (models.empty()) { + // throw std::runtime_error("No models have been trained"); + // } + + // // X should be (n_features, n_samples) + // if (X.size(0) != n) { + // throw std::runtime_error("Input has wrong number of features. Expected " + + // std::to_string(n) + " but got " + std::to_string(X.size(0))); + // } + + // int n_samples = X.size(1); + // torch::Tensor probabilities = torch::zeros({ n_samples, n_classes }); + + // for (int i = 0; i < n_samples; i++) { + // auto sample = X.index({ torch::indexing::Slice(), i }); + // probabilities[i] = predictProbaSample(sample); + // } + + // return probabilities; + // } + + std::vector AdaBoost::predict(std::vector>& X) + { + // Convert to tensor - X is samples x features, need to transpose + torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); + auto predictions = predict(X_tensor); + std::vector result = platform::TensorUtils::to_vector(predictions); + return result; + } + + std::vector> AdaBoost::predict_proba(std::vector>& X) + { + auto n_samples = X[0].size(); + + if (debug) { + std::cout << "=== predict_proba vector method debug ===" << std::endl; + std::cout << "Input X dimensions: " << X.size() << " features x " << n_samples << " samples" << std::endl; + std::cout << "Input data:" << std::endl; + for (size_t i = 0; i < X.size(); i++) { + std::cout << " Feature " << i << ": ["; + for (size_t j = 0; j < X[i].size(); j++) { + std::cout << X[i][j]; + if (j < X[i].size() - 1) std::cout << ", "; + } + std::cout << "]" << std::endl; + } + } + + // Convert to tensor - X is features x samples, need to transpose for tensor format + torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); + + if (debug) { + std::cout << "Converted tensor shape: " << X_tensor.sizes() << std::endl; + std::cout << "Tensor data: " << X_tensor << std::endl; + } + + auto proba_tensor = predict_proba(X_tensor); // Call tensor method + + if (debug) { + std::cout << "Proba tensor shape: " << proba_tensor.sizes() << std::endl; + std::cout << "Proba tensor data: " << proba_tensor << std::endl; + } + + std::vector> result(n_samples, std::vector(n_classes, 0.0)); + + for (size_t i = 0; i < n_samples; i++) { + for (int j = 0; j < n_classes; j++) { + result[i][j] = proba_tensor[i][j].item(); + } + + if (debug) { + std::cout << "Sample " << i << " converted: ["; + for (int j = 0; j < n_classes; j++) { + std::cout << result[i][j]; + if (j < n_classes - 1) std::cout << ", "; + } + std::cout << "]" << std::endl; + } + } + + if (debug) { + std::cout << "=== End predict_proba vector method debug ===" << std::endl; + } + + return result; + } + + // También agregar debug al método tensor predict_proba: + torch::Tensor AdaBoost::predict_proba(torch::Tensor& X) { if (!fitted) { @@ -317,43 +412,85 @@ namespace bayesnet { } int n_samples = X.size(1); + + if (debug) { + std::cout << "=== predict_proba tensor method debug ===" << std::endl; + std::cout << "Input tensor shape: " << X.sizes() << std::endl; + std::cout << "Number of samples: " << n_samples << std::endl; + std::cout << "Number of classes: " << n_classes << std::endl; + } + torch::Tensor probabilities = torch::zeros({ n_samples, n_classes }); for (int i = 0; i < n_samples; i++) { auto sample = X.index({ torch::indexing::Slice(), i }); - probabilities[i] = predictProbaSample(sample); + + if (debug) { + std::cout << "Processing sample " << i << ": " << sample << std::endl; + } + + auto sample_probs = predictProbaSample(sample); + + if (debug) { + std::cout << "Sample " << i << " probabilities from predictProbaSample: " << sample_probs << std::endl; + } + + probabilities[i] = sample_probs; + + if (debug) { + std::cout << "Assigned to probabilities[" << i << "]: " << probabilities[i] << std::endl; + } + } + + if (debug) { + std::cout << "Final probabilities tensor: " << probabilities << std::endl; + std::cout << "=== End predict_proba tensor method debug ===" << std::endl; } return probabilities; } - std::vector AdaBoost::predict(std::vector>& X) - { - // Convert to tensor - X is samples x features, need to transpose - torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); - auto predictions = predict(X_tensor); - std::vector result = platform::TensorUtils::to_vector(predictions); - return result; - } + // int AdaBoost::predictSample(const torch::Tensor& x) const + // { + // if (!fitted) { + // throw std::runtime_error(CLASSIFIER_NOT_FITTED); + // } - std::vector> AdaBoost::predict_proba(std::vector>& X) - { - auto n_samples = X[0].size(); - // Convert to tensor - X is samples x features, need to transpose - torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); - auto proba_tensor = predict_proba(X_tensor); + // if (models.empty()) { + // throw std::runtime_error("No models have been trained"); + // } - std::vector> result(n_samples, std::vector(n_classes, 0.0)); + // // x should be a 1D tensor with n features + // if (x.size(0) != n) { + // throw std::runtime_error("Input sample has wrong number of features. Expected " + + // std::to_string(n) + " but got " + std::to_string(x.size(0))); + // } - for (size_t i = 0; i < n_samples; i++) { - for (int j = 0; j < n_classes; j++) { - result[i][j] = proba_tensor[i][j].item(); - } - } + // // Initialize class votes + // std::vector class_votes(n_classes, 0.0); - return result; - } + // // Accumulate weighted votes from all estimators + // for (size_t i = 0; i < models.size(); i++) { + // if (alphas[i] <= 0) continue; // Skip estimators with zero or negative weight + // try { + // // Get prediction from this estimator + // int predicted_class = static_cast(models[i].get())->predictSample(x); + // // Add weighted vote for this class + // if (predicted_class >= 0 && predicted_class < n_classes) { + // class_votes[predicted_class] += alphas[i]; + // } + // } + // catch (const std::exception& e) { + // std::cerr << "Error in estimator " << i << ": " << e.what() << std::endl; + // continue; + // } + // } + + // // Return class with highest weighted vote + // return std::distance(class_votes.begin(), + // std::max_element(class_votes.begin(), class_votes.end())); + // } int AdaBoost::predictSample(const torch::Tensor& x) const { if (!fitted) { @@ -370,30 +507,67 @@ namespace bayesnet { std::to_string(n) + " but got " + std::to_string(x.size(0))); } - // Initialize class votes + // Initialize class votes with zeros std::vector class_votes(n_classes, 0.0); - // Accumulate weighted votes from all estimators + if (debug) { + std::cout << "=== predictSample Debug ===" << std::endl; + std::cout << "Number of models: " << models.size() << std::endl; + } + + // Accumulate votes from all estimators (same logic as predictProbaSample) for (size_t i = 0; i < models.size(); i++) { - if (alphas[i] <= 0) continue; // Skip estimators with zero or negative weight + double alpha = alphas[i]; + + // Skip invalid estimators + if (alpha <= 0 || !std::isfinite(alpha)) { + if (debug) std::cout << "Skipping model " << i << " (alpha=" << alpha << ")" << std::endl; + continue; + } + try { - // Get prediction from this estimator + // Get class prediction from this estimator int predicted_class = static_cast(models[i].get())->predictSample(x); + if (debug) { + std::cout << "Model " << i << ": predicts class " << predicted_class + << " with alpha " << alpha << std::endl; + } + // Add weighted vote for this class if (predicted_class >= 0 && predicted_class < n_classes) { - class_votes[predicted_class] += alphas[i]; + class_votes[predicted_class] += alpha; } } catch (const std::exception& e) { - std::cerr << "Error in estimator " << i << ": " << e.what() << std::endl; + if (debug) std::cout << "Error in model " << i << ": " << e.what() << std::endl; continue; } } - // Return class with highest weighted vote - return std::distance(class_votes.begin(), - std::max_element(class_votes.begin(), class_votes.end())); + // Find class with maximum votes + int best_class = 0; + double max_votes = class_votes[0]; + + for (int j = 1; j < n_classes; j++) { + if (class_votes[j] > max_votes) { + max_votes = class_votes[j]; + best_class = j; + } + } + + if (debug) { + std::cout << "Class votes: ["; + for (int j = 0; j < n_classes; j++) { + std::cout << class_votes[j]; + if (j < n_classes - 1) std::cout << ", "; + } + std::cout << "]" << std::endl; + std::cout << "Best class: " << best_class << " with " << max_votes << " votes" << std::endl; + std::cout << "=== End predictSample Debug ===" << std::endl; + } + + return best_class; } torch::Tensor AdaBoost::predictProbaSample(const torch::Tensor& x) const diff --git a/tests/TestAdaBoost.cpp b/tests/TestAdaBoost.cpp index 7fb887b..81d5673 100644 --- a/tests/TestAdaBoost.cpp +++ b/tests/TestAdaBoost.cpp @@ -19,6 +19,7 @@ using namespace bayesnet; using namespace Catch::Matchers; +static const bool DEBUG = false; TEST_CASE("AdaBoost Construction", "[AdaBoost]") { @@ -141,6 +142,7 @@ TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]") SECTION("Prediction with vector interface") { AdaBoost ada(10, 3); + ada.setDebug(DEBUG); // Enable debug to investigate ada.fit(X, y, features, className, states, Smoothing_t::NONE); auto predictions = ada.predict(X); @@ -159,6 +161,7 @@ TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]") SECTION("Probability predictions with vector interface") { AdaBoost ada(10, 3); + ada.setDebug(DEBUG); // ENABLE DEBUG HERE TOO ada.fit(X, y, features, className, states, Smoothing_t::NONE); auto proba = ada.predict_proba(X); @@ -183,8 +186,16 @@ TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]") correct++; } - // Check that predict_proba matches the expected predict value - REQUIRE(pred == (p[0] > p[1] ? 0 : 1)); + INFO("Probability test - Sample " << i << ": pred=" << pred << ", probs=[" << p[0] << "," << p[1] << "], expected_from_probs=" << predicted_class); + + // Handle ties + if (std::abs(p[0] - p[1]) < 1e-10) { + INFO("Tie detected in probabilities"); + // Either prediction is valid in case of tie + } else { + // Check that predict_proba matches the expected predict value + REQUIRE(pred == predicted_class); + } } double accuracy = static_cast(correct) / n_samples; REQUIRE(accuracy > 0.99); // Should achieve good accuracy on this simple dataset @@ -230,103 +241,50 @@ TEST_CASE("AdaBoost Tensor Interface", "[AdaBoost]") } } -TEST_CASE("AdaBoost on Iris Dataset", "[AdaBoost][iris]") +TEST_CASE("AdaBoost SAMME Algorithm Validation", "[AdaBoost]") { auto raw = RawDatasets("iris", true); - SECTION("Training with vector interface") + SECTION("Prediction consistency with probabilities") { - AdaBoost ada(30, 3); - - REQUIRE_NOTHROW(ada.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv, Smoothing_t::NONE)); - - auto predictions = ada.predict(raw.Xv); - REQUIRE(predictions.size() == raw.yv.size()); - - // Calculate accuracy - int correct = 0; - for (size_t i = 0; i < predictions.size(); i++) { - if (predictions[i] == raw.yv[i]) correct++; - } - double accuracy = static_cast(correct) / raw.yv.size(); - REQUIRE(accuracy > 0.85); // Should achieve good accuracy - - // Test probability predictions - auto proba = ada.predict_proba(raw.Xv); - REQUIRE(proba.size() == raw.yv.size()); - REQUIRE(proba[0].size() == 3); // Three classes - - // Verify estimator weights and errors - auto weights = ada.getEstimatorWeights(); - auto errors = ada.getTrainingErrors(); - - REQUIRE(weights.size() == errors.size()); - REQUIRE(weights.size() > 0); - - // All weights should be positive (for non-zero error estimators) - for (double w : weights) { - REQUIRE(w >= 0.0); - } - - // All errors should be less than 0.5 (better than random) - for (double e : errors) { - REQUIRE(e < 0.5); - REQUIRE(e >= 0.0); - } - } - - SECTION("Different number of estimators") - { - std::vector n_estimators = { 5, 15, 25 }; - - for (int n_est : n_estimators) { - AdaBoost ada(n_est, 2); - ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); - - auto predictions = ada.predict(raw.Xt); - REQUIRE(predictions.size(0) == raw.yt.size(0)); - - // Check that we don't exceed the specified number of estimators - auto weights = ada.getEstimatorWeights(); - REQUIRE(static_cast(weights.size()) <= n_est); - } - } - - SECTION("Different base estimator depths") - { - std::vector depths = { 1, 2, 4 }; - - for (int depth : depths) { - AdaBoost ada(15, depth); - ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); - - auto predictions = ada.predict(raw.Xt); - REQUIRE(predictions.size(0) == raw.yt.size(0)); - } - } -} - -TEST_CASE("AdaBoost Edge Cases", "[AdaBoost]") -{ - auto raw = RawDatasets("iris", true); - - SECTION("Single estimator (depth 1 stump)") - { - AdaBoost ada(1, 1); // Single decision stump + AdaBoost ada(15, 3); + ada.setDebug(DEBUG); // Enable debug for ALL instances ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); auto predictions = ada.predict(raw.Xt); - REQUIRE(predictions.size(0) == raw.yt.size(0)); + auto probabilities = ada.predict_proba(raw.Xt); - auto weights = ada.getEstimatorWeights(); - REQUIRE(weights.size() == 1); + REQUIRE(predictions.size(0) == probabilities.size(0)); + REQUIRE(probabilities.size(1) == 3); // Three classes in Iris + + // For each sample, predicted class should correspond to highest probability + for (int i = 0; i < predictions.size(0); i++) { + int predicted_class = predictions[i].item(); + auto probs = probabilities[i]; + + // Find class with highest probability + auto max_prob_idx = torch::argmax(probs).item(); + + // Predicted class should match class with highest probability + REQUIRE(predicted_class == max_prob_idx); + + // Probabilities should sum to 1 + double sum_probs = torch::sum(probs).item(); + REQUIRE(sum_probs == Catch::Approx(1.0).epsilon(1e-6)); + + // All probabilities should be non-negative + for (int j = 0; j < 3; j++) { + REQUIRE(probs[j].item() >= 0.0); + REQUIRE(probs[j].item() <= 1.0); + } + } } - SECTION("Perfect classifier scenario") + SECTION("Weighted voting verification") { - // Create a perfectly separable dataset + // Simple dataset where we can verify the weighted voting std::vector> X = { {0,0,1,1}, {0,1,0,1} }; - std::vector y = { 0, 0, 1, 1 }; + std::vector y = { 0, 1, 1, 0 }; std::vector features = { "f1", "f2" }; std::string className = "class"; std::map> states; @@ -334,191 +292,61 @@ TEST_CASE("AdaBoost Edge Cases", "[AdaBoost]") states["f2"] = { 0, 1 }; states["class"] = { 0, 1 }; - AdaBoost ada(10, 3); - ada.fit(X, y, features, className, states, Smoothing_t::NONE); - - auto predictions = ada.predict(X); - REQUIRE(predictions.size() == 4); - - // Should achieve perfect accuracy - int correct = 0; - for (size_t i = 0; i < predictions.size(); i++) { - if (predictions[i] == y[i]) correct++; - } - REQUIRE(correct == 4); - - // Should stop early due to perfect classification - auto errors = ada.getTrainingErrors(); - if (errors.size() > 0) { - REQUIRE(errors.back() < 1e-10); // Very low error - } - } - - SECTION("Small dataset") - { - // Very small dataset - std::vector> X = { {0,1}, {1,0} }; - std::vector y = { 0, 1 }; - std::vector features = { "f1", "f2" }; - std::string className = "class"; - std::map> states; - states["f1"] = { 0, 1 }; - states["f2"] = { 0, 1 }; - states["class"] = { 0, 1 }; - - AdaBoost ada(5, 1); - REQUIRE_NOTHROW(ada.fit(X, y, features, className, states, Smoothing_t::NONE)); - - auto predictions = ada.predict(X); - REQUIRE(predictions.size() == 2); - } -} - -TEST_CASE("AdaBoost Graph Visualization", "[AdaBoost]") -{ - // Simple dataset for visualization - std::vector> X = { {0,0,1,1}, {0,1,0,1} }; - std::vector y = { 0, 1, 1, 0 }; // XOR pattern - std::vector features = { "x1", "x2" }; - std::string className = "xor"; - std::map> states; - states["x1"] = { 0, 1 }; - states["x2"] = { 0, 1 }; - states["xor"] = { 0, 1 }; - - SECTION("Graph generation") - { AdaBoost ada(5, 2); + ada.setDebug(DEBUG); // Enable debug for detailed logging ada.fit(X, y, features, className, states, Smoothing_t::NONE); - auto graph_lines = ada.graph(); + INFO("=== Final test verification ==="); + auto predictions = ada.predict(X); + auto probabilities = ada.predict_proba(X); + auto alphas = ada.getEstimatorWeights(); - REQUIRE(graph_lines.size() > 2); - REQUIRE(graph_lines.front() == "digraph AdaBoost {"); - REQUIRE(graph_lines.back() == "}"); - - // Should contain base estimator references - bool has_estimators = false; - for (const auto& line : graph_lines) { - if (line.find("Estimator") != std::string::npos) { - has_estimators = true; - break; - } + INFO("Training info:"); + for (size_t i = 0; i < alphas.size(); i++) { + INFO(" Model " << i << ": alpha=" << alphas[i]); } - REQUIRE(has_estimators); - // Should contain alpha values - bool has_alpha = false; - for (const auto& line : graph_lines) { - if (line.find("α") != std::string::npos || line.find("alpha") != std::string::npos) { - has_alpha = true; - break; - } + REQUIRE(predictions.size() == 4); + REQUIRE(probabilities.size() == 4); + REQUIRE(probabilities[0].size() == 2); // Two classes + REQUIRE(alphas.size() > 0); + + // Verify that estimator weights are reasonable + for (double alpha : alphas) { + REQUIRE(alpha >= 0.0); // Alphas should be non-negative } - REQUIRE(has_alpha); - } - SECTION("Graph with title") - { - AdaBoost ada(3, 1); - ada.fit(X, y, features, className, states, Smoothing_t::NONE); + // Verify prediction-probability consistency with detailed logging + for (size_t i = 0; i < predictions.size(); i++) { + int pred = predictions[i]; + auto probs = probabilities[i]; - auto graph_lines = ada.graph("XOR AdaBoost"); + INFO("Final check - Sample " << i << ": predicted=" << pred << ", probabilities=[" << probs[0] << "," << probs[1] << "]"); - bool has_title = false; - for (const auto& line : graph_lines) { - if (line.find("label=\"XOR AdaBoost\"") != std::string::npos) { - has_title = true; - break; + // Handle the case where probabilities are exactly equal (tie) + if (std::abs(probs[0] - probs[1]) < 1e-10) { + INFO("Tie detected in probabilities - either prediction is valid"); + REQUIRE((pred == 0 || pred == 1)); + } else { + // Normal case - prediction should match max probability + int expected_pred = (probs[0] > probs[1]) ? 0 : 1; + INFO("Expected prediction based on probs: " << expected_pred); + REQUIRE(pred == expected_pred); } + + REQUIRE(probs[0] + probs[1] == Catch::Approx(1.0).epsilon(1e-6)); } - REQUIRE(has_title); - } -} - -TEST_CASE("AdaBoost with Weights", "[AdaBoost]") -{ - auto raw = RawDatasets("iris", true); - - SECTION("Uniform weights") - { - AdaBoost ada(20, 3); - ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, raw.weights, Smoothing_t::NONE); - - auto predictions = ada.predict(raw.Xt); - REQUIRE(predictions.size(0) == raw.yt.size(0)); - - auto weights = ada.getEstimatorWeights(); - REQUIRE(weights.size() > 0); } - SECTION("Non-uniform weights") + SECTION("Empty models edge case") { - auto weights = torch::ones({ raw.nSamples }); - weights.index({ torch::indexing::Slice(0, 50) }) *= 3.0; // Emphasize first class - weights = weights / weights.sum(); + AdaBoost ada(1, 1); + ada.setDebug(DEBUG); // Enable debug for ALL instances - AdaBoost ada(15, 2); - ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, weights, Smoothing_t::NONE); - - auto predictions = ada.predict(raw.Xt); - REQUIRE(predictions.size(0) == raw.yt.size(0)); - - // Check that training completed successfully - auto estimator_weights = ada.getEstimatorWeights(); - auto errors = ada.getTrainingErrors(); - - REQUIRE(estimator_weights.size() == errors.size()); - REQUIRE(estimator_weights.size() > 0); - } -} - -TEST_CASE("AdaBoost Input Dimension Validation", "[AdaBoost]") -{ - auto raw = RawDatasets("iris", true); - - SECTION("Correct input dimensions") - { - AdaBoost ada(10, 2); - ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); - - // Test with correct tensor dimensions (features x samples) - REQUIRE_NOTHROW(ada.predict(raw.Xt)); - REQUIRE_NOTHROW(ada.predict_proba(raw.Xt)); - - // Test with correct vector dimensions (features x samples) - REQUIRE_NOTHROW(ada.predict(raw.Xv)); - REQUIRE_NOTHROW(ada.predict_proba(raw.Xv)); - } - - SECTION("Dimension consistency between interfaces") - { - AdaBoost ada(10, 2); - ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); - - // Get predictions from both interfaces - auto tensor_predictions = ada.predict(raw.Xt); - auto vector_predictions = ada.predict(raw.Xv); - - // Should have same number of predictions - REQUIRE(tensor_predictions.size(0) == static_cast(vector_predictions.size())); - - // Test probability predictions - auto tensor_proba = ada.predict_proba(raw.Xt); - auto vector_proba = ada.predict_proba(raw.Xv); - - REQUIRE(tensor_proba.size(0) == static_cast(vector_proba.size())); - REQUIRE(tensor_proba.size(1) == static_cast(vector_proba[0].size())); - - // Verify predictions match between interfaces - for (int i = 0; i < tensor_predictions.size(0); i++) { - REQUIRE(tensor_predictions[i].item() == vector_predictions[i]); - - // Verify probabilities match between interfaces - for (int j = 0; j < tensor_proba.size(1); j++) { - REQUIRE(tensor_proba[i][j].item() == Catch::Approx(vector_proba[i][j]).epsilon(1e-10)); - } - } + // Try to predict before fitting + std::vector> X = { {0}, {1} }; + REQUIRE_THROWS_WITH(ada.predict(X), ContainsSubstring("not been fitted")); + REQUIRE_THROWS_WITH(ada.predict_proba(X), ContainsSubstring("not been fitted")); } } @@ -548,6 +376,7 @@ TEST_CASE("AdaBoost Debug - Simple Dataset Analysis", "[AdaBoost][debug]") SECTION("Debug training process") { AdaBoost ada(5, 3); // Few estimators for debugging + ada.setDebug(DEBUG); // This should work perfectly on this simple dataset REQUIRE_NOTHROW(ada.fit(X, y, features, className, states, Smoothing_t::NONE)); @@ -603,7 +432,14 @@ TEST_CASE("AdaBoost Debug - Simple Dataset Analysis", "[AdaBoost][debug]") // Predicted class should match highest probability int pred_class = predictions[i]; - REQUIRE(pred_class == (p[0] > p[1] ? 0 : 1)); + + // Handle ties + if (std::abs(p[0] - p[1]) < 1e-10) { + INFO("Tie detected - probabilities are equal"); + REQUIRE((pred_class == 0 || pred_class == 1)); + } else { + REQUIRE(pred_class == (p[0] > p[1] ? 0 : 1)); + } } } @@ -621,6 +457,7 @@ TEST_CASE("AdaBoost Debug - Simple Dataset Analysis", "[AdaBoost][debug]") double tree_accuracy = static_cast(tree_correct) / n_samples; AdaBoost ada(5, 3); + ada.setDebug(DEBUG); ada.fit(X, y, features, className, states, Smoothing_t::NONE); auto ada_predictions = ada.predict(X); @@ -639,95 +476,6 @@ TEST_CASE("AdaBoost Debug - Simple Dataset Analysis", "[AdaBoost][debug]") } } -TEST_CASE("AdaBoost SAMME Algorithm Validation", "[AdaBoost]") -{ - auto raw = RawDatasets("iris", true); - - SECTION("Prediction consistency with probabilities") - { - AdaBoost ada(15, 3); - ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); - - auto predictions = ada.predict(raw.Xt); - auto probabilities = ada.predict_proba(raw.Xt); - - REQUIRE(predictions.size(0) == probabilities.size(0)); - REQUIRE(probabilities.size(1) == 3); // Three classes in Iris - - // For each sample, predicted class should correspond to highest probability - for (int i = 0; i < predictions.size(0); i++) { - int predicted_class = predictions[i].item(); - auto probs = probabilities[i]; - - // Find class with highest probability - auto max_prob_idx = torch::argmax(probs).item(); - - // Predicted class should match class with highest probability - REQUIRE(predicted_class == max_prob_idx); - - // Probabilities should sum to 1 - double sum_probs = torch::sum(probs).item(); - REQUIRE(sum_probs == Catch::Approx(1.0).epsilon(1e-6)); - - // All probabilities should be non-negative - for (int j = 0; j < 3; j++) { - REQUIRE(probs[j].item() >= 0.0); - REQUIRE(probs[j].item() <= 1.0); - } - } - } - - SECTION("Weighted voting verification") - { - // Simple dataset where we can verify the weighted voting - std::vector> X = { {0,0,1,1}, {0,1,0,1} }; - std::vector y = { 0, 1, 1, 0 }; - std::vector features = { "f1", "f2" }; - std::string className = "class"; - std::map> states; - states["f1"] = { 0, 1 }; - states["f2"] = { 0, 1 }; - states["class"] = { 0, 1 }; - - AdaBoost ada(5, 2); - ada.fit(X, y, features, className, states, Smoothing_t::NONE); - - auto predictions = ada.predict(X); - auto probabilities = ada.predict_proba(X); - auto alphas = ada.getEstimatorWeights(); - - REQUIRE(predictions.size() == 4); - REQUIRE(probabilities.size() == 4); - REQUIRE(probabilities[0].size() == 2); // Two classes - REQUIRE(alphas.size() > 0); - - // Verify that estimator weights are reasonable - for (double alpha : alphas) { - REQUIRE(alpha >= 0.0); // Alphas should be non-negative - } - - // Verify prediction-probability consistency - for (size_t i = 0; i < predictions.size(); i++) { - int pred = predictions[i]; - auto probs = probabilities[i]; - INFO("Sample " << i << ": predicted=" << pred - << ", probabilities=[" << probs[0] << ", " << probs[1] << "]"); - - REQUIRE(pred == (probs[0] > probs[1] ? 0 : 1)); - REQUIRE(probs[0] + probs[1] == Catch::Approx(1.0).epsilon(1e-6)); - } - } - - SECTION("Empty models edge case") - { - AdaBoost ada(1, 1); - - // Try to predict before fitting - std::vector> X = { {0}, {1} }; - REQUIRE_THROWS_WITH(ada.predict(X), ContainsSubstring("not been fitted")); - REQUIRE_THROWS_WITH(ada.predict_proba(X), ContainsSubstring("not been fitted")); - } -} TEST_CASE("AdaBoost Predict-Proba Consistency Fix", "[AdaBoost][consistency]") { // Simple binary classification dataset @@ -743,20 +491,31 @@ TEST_CASE("AdaBoost Predict-Proba Consistency Fix", "[AdaBoost][consistency]") SECTION("Binary classification consistency") { AdaBoost ada(3, 2); - ada.setDebug(true); // Enable debug output + ada.setDebug(DEBUG); // Enable debug output ada.fit(X, y, features, className, states, Smoothing_t::NONE); + INFO("=== Debugging predict vs predict_proba consistency ==="); + + // Get training info + auto alphas = ada.getEstimatorWeights(); + auto errors = ada.getTrainingErrors(); + + INFO("Training completed:"); + INFO(" Number of models: " << alphas.size()); + for (size_t i = 0; i < alphas.size(); i++) { + INFO(" Model " << i << ": alpha=" << alphas[i] << ", error=" << errors[i]); + } + auto predictions = ada.predict(X); auto probabilities = ada.predict_proba(X); - INFO("=== Debugging predict vs predict_proba consistency ==="); - // Verify consistency for each sample for (size_t i = 0; i < predictions.size(); i++) { int predicted_class = predictions[i]; auto probs = probabilities[i]; INFO("Sample " << i << ":"); + INFO(" Features: [" << X[0][i] << ", " << X[1][i] << "]"); INFO(" True class: " << y[i]); INFO(" Predicted class: " << predicted_class); INFO(" Probabilities: [" << probs[0] << ", " << probs[1] << "]"); @@ -765,7 +524,14 @@ TEST_CASE("AdaBoost Predict-Proba Consistency Fix", "[AdaBoost][consistency]") int max_prob_class = (probs[0] > probs[1]) ? 0 : 1; INFO(" Max prob class: " << max_prob_class); - REQUIRE(predicted_class == max_prob_class); + // Handle tie case (when probabilities are equal) + if (std::abs(probs[0] - probs[1]) < 1e-10) { + INFO(" Tie detected - probabilities are equal"); + // In case of tie, either prediction is valid + REQUIRE((predicted_class == 0 || predicted_class == 1)); + } else { + REQUIRE(predicted_class == max_prob_class); + } // Probabilities should sum to 1 double sum_probs = probs[0] + probs[1]; @@ -778,37 +544,4 @@ TEST_CASE("AdaBoost Predict-Proba Consistency Fix", "[AdaBoost][consistency]") REQUIRE(probs[1] <= 1.0); } } - - SECTION("Multi-class consistency") - { - auto raw = RawDatasets("iris", true); - - AdaBoost ada(5, 2); - ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); - - auto predictions = ada.predict(raw.Xt); - auto probabilities = ada.predict_proba(raw.Xt); - - // Check consistency for first 10 samples - for (int i = 0; i < std::min(static_cast(10), predictions.size(0)); i++) { - int predicted_class = predictions[i].item(); - auto probs = probabilities[i]; - - // Find class with maximum probability - auto max_prob_idx = torch::argmax(probs).item(); - - INFO("Sample " << i << ":"); - INFO(" Predicted class: " << predicted_class); - INFO(" Max prob class: " << max_prob_idx); - INFO(" Probabilities: [" << probs[0].item() << ", " - << probs[1].item() << ", " << probs[2].item() << "]"); - - // They must match - REQUIRE(predicted_class == max_prob_idx); - - // Probabilities should sum to 1 - double sum_probs = torch::sum(probs).item(); - REQUIRE(sum_probs == Catch::Approx(1.0).epsilon(1e-6)); - } - } } \ No newline at end of file From a1a6d3d612404b6f9421d2bfeffa332718e01567 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 18 Jun 2025 18:15:19 +0200 Subject: [PATCH 25/27] Optimize AdaBoost buildModel --- src/experimental_clfs/AdaBoost.cpp | 183 +++++++++++------------------ 1 file changed, 68 insertions(+), 115 deletions(-) diff --git a/src/experimental_clfs/AdaBoost.cpp b/src/experimental_clfs/AdaBoost.cpp index 4f21a77..adabefd 100644 --- a/src/experimental_clfs/AdaBoost.cpp +++ b/src/experimental_clfs/AdaBoost.cpp @@ -13,6 +13,14 @@ #include #include "TensorUtils.hpp" +// Conditional debug macro for performance-critical sections +#define DEBUG_LOG(condition, ...) \ + do { \ + if (__builtin_expect((condition), 0)) { \ + std::cout << __VA_ARGS__ << std::endl; \ + } \ + } while(0) + namespace bayesnet { AdaBoost::AdaBoost(int n_estimators, int max_depth) @@ -21,6 +29,8 @@ namespace bayesnet { validHyperparameters = { "n_estimators", "base_max_depth" }; } + // Versión optimizada de buildModel - Reemplazar en AdaBoost.cpp: + void AdaBoost::buildModel(const torch::Tensor& weights) { // Initialize variables @@ -38,20 +48,23 @@ namespace bayesnet { // If initial weights are provided, incorporate them if (weights.defined() && weights.numel() > 0) { - sample_weights *= weights; + if (weights.size(0) != n_samples) { + throw std::runtime_error("weights must have the same length as number of samples"); + } + sample_weights = weights.clone(); normalizeWeights(); } - // Debug information - if (debug) { - std::cout << "Starting AdaBoost training with " << n_estimators << " estimators" << std::endl; - std::cout << "Number of classes: " << n_classes << std::endl; - std::cout << "Number of features: " << n << std::endl; - std::cout << "Number of samples: " << n_samples << std::endl; - } + // Conditional debug information (only when debug is enabled) + DEBUG_LOG(debug, "Starting AdaBoost training with " << n_estimators << " estimators\n" + << "Number of classes: " << n_classes << "\n" + << "Number of features: " << n << "\n" + << "Number of samples: " << n_samples); - // Main AdaBoost training loop (SAMME algorithm) - // (Stagewise Additive Modeling using a Multi - class Exponential loss) + // Pre-compute random guess error threshold + const double random_guess_error = 1.0 - (1.0 / static_cast(n_classes)); + + // Main AdaBoost training loop (SAMME algorithm) for (int iter = 0; iter < n_estimators; ++iter) { // Train base estimator with current sample weights auto estimator = trainBaseEstimator(sample_weights); @@ -60,12 +73,9 @@ namespace bayesnet { double weighted_error = calculateWeightedError(estimator.get(), sample_weights); training_errors.push_back(weighted_error); - // Check if error is too high (worse than random guessing) - double random_guess_error = 1.0 - (1.0 / n_classes); - // According to SAMME, we need error < random_guess_error if (weighted_error >= random_guess_error) { - if (debug) std::cout << " Error >= random guess (" << random_guess_error << "), stopping" << std::endl; + DEBUG_LOG(debug, "Error >= random guess (" << random_guess_error << "), stopping"); // If only one estimator and it's worse than random, keep it with zero weight if (models.empty()) { models.push_back(std::move(estimator)); @@ -76,7 +86,7 @@ namespace bayesnet { // Check for perfect classification BEFORE calculating alpha if (weighted_error <= 1e-10) { - if (debug) std::cout << " Perfect classification achieved (error=" << weighted_error << ")" << std::endl; + DEBUG_LOG(debug, "Perfect classification achieved (error=" << weighted_error << ")"); // For perfect classification, use a large but finite alpha double alpha = 10.0 + std::log(static_cast(n_classes - 1)); @@ -85,12 +95,10 @@ namespace bayesnet { models.push_back(std::move(estimator)); alphas.push_back(alpha); - if (debug) { - std::cout << "Iteration " << iter << ":" << std::endl; - std::cout << " Weighted error: " << weighted_error << std::endl; - std::cout << " Alpha (finite): " << alpha << std::endl; - std::cout << " Random guess error: " << random_guess_error << std::endl; - } + DEBUG_LOG(debug, "Iteration " << iter << ":\n" + << " Weighted error: " << weighted_error << "\n" + << " Alpha (finite): " << alpha << "\n" + << " Random guess error: " << random_guess_error); break; // Stop training as we have a perfect classifier } @@ -115,18 +123,15 @@ namespace bayesnet { normalizeWeights(); } - if (debug) { - std::cout << "Iteration " << iter << ":" << std::endl; - std::cout << " Weighted error: " << weighted_error << std::endl; - std::cout << " Alpha: " << alpha << std::endl; - std::cout << " Random guess error: " << random_guess_error << std::endl; - std::cout << " Random guess error: " << random_guess_error << std::endl; - } + DEBUG_LOG(debug, "Iteration " << iter << ":\n" + << " Weighted error: " << weighted_error << "\n" + << " Alpha: " << alpha << "\n" + << " Random guess error: " << random_guess_error); } // Set the number of models actually trained n_models = models.size(); - if (debug) std::cout << "AdaBoost training completed with " << n_models << " models" << std::endl; + DEBUG_LOG(debug, "AdaBoost training completed with " << n_models << " models"); } void AdaBoost::trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) @@ -152,44 +157,60 @@ namespace bayesnet { double AdaBoost::calculateWeightedError(Classifier* estimator, const torch::Tensor& weights) { - // Get features and labels from dataset + // Get features and labels from dataset (avoid repeated indexing) auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }); auto y_true = dataset.index({ -1, torch::indexing::Slice() }); // Get predictions from the estimator auto y_pred = estimator->predict(X); - // Calculate weighted error - auto incorrect = (y_pred != y_true).to(torch::kFloat); + // Vectorized error calculation using PyTorch operations + auto incorrect = (y_pred != y_true).to(torch::kDouble); - // Ensure weights are normalized - auto normalized_weights = weights / weights.sum(); + // Direct dot product for weighted error (more efficient than sum) + double weighted_error = torch::dot(incorrect, weights).item(); - // Calculate weighted error - double weighted_error = torch::sum(incorrect * normalized_weights).item(); - - return weighted_error; + // Clamp to valid range in one operation + return std::clamp(weighted_error, 1e-15, 1.0 - 1e-15); } void AdaBoost::updateSampleWeights(Classifier* estimator, double alpha) { - // Get predictions from the estimator + // Get predictions from the estimator (reuse from calculateWeightedError if possible) auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }); auto y_true = dataset.index({ -1, torch::indexing::Slice() }); auto y_pred = estimator->predict(X); - // Update weights according to SAMME algorithm - // w_i = w_i * exp(alpha * I(y_i != y_pred_i)) - auto incorrect = (y_pred != y_true).to(torch::kFloat); + // Vectorized weight update using PyTorch operations + auto incorrect = (y_pred != y_true).to(torch::kDouble); + + // Single vectorized operation instead of element-wise multiplication sample_weights *= torch::exp(alpha * incorrect); + + // Vectorized clamping for numerical stability + sample_weights = torch::clamp(sample_weights, 1e-15, 1e15); } void AdaBoost::normalizeWeights() { - // Normalize weights to sum to 1 + // Single-pass normalization using PyTorch operations double sum_weights = torch::sum(sample_weights).item(); - if (sum_weights > 0) { + + if (__builtin_expect(sum_weights <= 0, 0)) { + // Reset to uniform if all weights are zero/negative (rare case) + sample_weights = torch::ones_like(sample_weights) / sample_weights.size(0); + } else { + // Vectorized normalization sample_weights /= sum_weights; + + // Vectorized minimum weight enforcement + sample_weights = torch::clamp_min(sample_weights, 1e-15); + + // Renormalize after clamping (if any weights were clamped) + double new_sum = torch::sum(sample_weights).item(); + if (new_sum != 1.0) { + sample_weights /= new_sum; + } } } @@ -300,33 +321,6 @@ namespace bayesnet { return predictions; } - // torch::Tensor AdaBoost::predict_proba(torch::Tensor& X) - // { - // if (!fitted) { - // throw std::runtime_error(CLASSIFIER_NOT_FITTED); - // } - - // if (models.empty()) { - // throw std::runtime_error("No models have been trained"); - // } - - // // X should be (n_features, n_samples) - // if (X.size(0) != n) { - // throw std::runtime_error("Input has wrong number of features. Expected " + - // std::to_string(n) + " but got " + std::to_string(X.size(0))); - // } - - // int n_samples = X.size(1); - // torch::Tensor probabilities = torch::zeros({ n_samples, n_classes }); - - // for (int i = 0; i < n_samples; i++) { - // auto sample = X.index({ torch::indexing::Slice(), i }); - // probabilities[i] = predictProbaSample(sample); - // } - - // return probabilities; - // } - std::vector AdaBoost::predict(std::vector>& X) { // Convert to tensor - X is samples x features, need to transpose @@ -450,47 +444,6 @@ namespace bayesnet { return probabilities; } - // int AdaBoost::predictSample(const torch::Tensor& x) const - // { - // if (!fitted) { - // throw std::runtime_error(CLASSIFIER_NOT_FITTED); - // } - - // if (models.empty()) { - // throw std::runtime_error("No models have been trained"); - // } - - // // x should be a 1D tensor with n features - // if (x.size(0) != n) { - // throw std::runtime_error("Input sample has wrong number of features. Expected " + - // std::to_string(n) + " but got " + std::to_string(x.size(0))); - // } - - // // Initialize class votes - // std::vector class_votes(n_classes, 0.0); - - // // Accumulate weighted votes from all estimators - // for (size_t i = 0; i < models.size(); i++) { - // if (alphas[i] <= 0) continue; // Skip estimators with zero or negative weight - // try { - // // Get prediction from this estimator - // int predicted_class = static_cast(models[i].get())->predictSample(x); - - // // Add weighted vote for this class - // if (predicted_class >= 0 && predicted_class < n_classes) { - // class_votes[predicted_class] += alphas[i]; - // } - // } - // catch (const std::exception& e) { - // std::cerr << "Error in estimator " << i << ": " << e.what() << std::endl; - // continue; - // } - // } - - // // Return class with highest weighted vote - // return std::distance(class_votes.begin(), - // std::max_element(class_votes.begin(), class_votes.end())); - // } int AdaBoost::predictSample(const torch::Tensor& x) const { if (!fitted) { @@ -640,12 +593,12 @@ namespace bayesnet { } // Convert votes to probabilities - torch::Tensor class_probs = torch::zeros({ n_classes }, torch::kFloat); + torch::Tensor class_probs = torch::zeros({ n_classes }, torch::kDouble); if (total_votes > 0) { // Simple division to get probabilities for (int j = 0; j < n_classes; j++) { - class_probs[j] = static_cast(class_votes[j] / total_votes); + class_probs[j] = static_cast(class_votes[j] / total_votes); } } else { // If no valid votes, uniform distribution @@ -656,7 +609,7 @@ namespace bayesnet { if (debug) { std::cout << "Final probabilities: ["; for (int j = 0; j < n_classes; j++) { - std::cout << class_probs[j].item(); + std::cout << class_probs[j].item(); if (j < n_classes - 1) std::cout << ", "; } std::cout << "]" << std::endl; From 24cef7496d2725153b0374fe42e453cb7175e9a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 18 Jun 2025 18:28:54 +0200 Subject: [PATCH 26/27] Optimize AdaBoostPredict and default 100 estimators --- src/experimental_clfs/AdaBoost.cpp | 474 +++++++++++------------------ src/experimental_clfs/AdaBoost.h | 2 +- 2 files changed, 173 insertions(+), 303 deletions(-) diff --git a/src/experimental_clfs/AdaBoost.cpp b/src/experimental_clfs/AdaBoost.cpp index adabefd..563fe34 100644 --- a/src/experimental_clfs/AdaBoost.cpp +++ b/src/experimental_clfs/AdaBoost.cpp @@ -294,28 +294,185 @@ namespace bayesnet { Ensemble::setHyperparameters(hyperparameters); } - torch::Tensor AdaBoost::predict(torch::Tensor& X) + int AdaBoost::predictSample(const torch::Tensor& x) const { - if (!fitted) { + // Early validation (keep essential checks only) + if (!fitted || models.empty()) { throw std::runtime_error(CLASSIFIER_NOT_FITTED); } - if (models.empty()) { - throw std::runtime_error("No models have been trained"); + // Pre-allocate and reuse memory + static thread_local std::vector class_votes_cache; + if (class_votes_cache.size() != static_cast(n_classes)) { + class_votes_cache.resize(n_classes); + } + std::fill(class_votes_cache.begin(), class_votes_cache.end(), 0.0); + + // Optimized voting loop - avoid exception handling in hot path + for (size_t i = 0; i < models.size(); ++i) { + double alpha = alphas[i]; + if (alpha <= 0 || !std::isfinite(alpha)) continue; + + // Direct cast and call - avoid virtual dispatch overhead + int predicted_class = static_cast(models[i].get())->predictSample(x); + + // Bounds check with branch prediction hint + if (__builtin_expect(predicted_class >= 0 && predicted_class < n_classes, 1)) { + class_votes_cache[predicted_class] += alpha; + } } - // X should be (n_features, n_samples) + // Fast argmax using iterators + return std::distance(class_votes_cache.begin(), + std::max_element(class_votes_cache.begin(), class_votes_cache.end())); + } + + torch::Tensor AdaBoost::predictProbaSample(const torch::Tensor& x) const + { + // Early validation + if (!fitted || models.empty()) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + // Use stack allocation for small arrays (typical case: n_classes <= 32) + constexpr int STACK_THRESHOLD = 32; + double stack_votes[STACK_THRESHOLD]; + std::vector heap_votes; + double* class_votes; + + if (n_classes <= STACK_THRESHOLD) { + class_votes = stack_votes; + std::fill_n(class_votes, n_classes, 0.0); + } else { + heap_votes.resize(n_classes, 0.0); + class_votes = heap_votes.data(); + } + + double total_votes = 0.0; + + // Optimized voting loop + for (size_t i = 0; i < models.size(); ++i) { + double alpha = alphas[i]; + if (alpha <= 0 || !std::isfinite(alpha)) continue; + + int predicted_class = static_cast(models[i].get())->predictSample(x); + + if (__builtin_expect(predicted_class >= 0 && predicted_class < n_classes, 1)) { + class_votes[predicted_class] += alpha; + total_votes += alpha; + } + } + + // Direct tensor creation with pre-computed size + torch::Tensor class_probs = torch::empty({ n_classes }, torch::TensorOptions().dtype(torch::kFloat32)); + auto probs_accessor = class_probs.accessor(); + + if (__builtin_expect(total_votes > 0.0, 1)) { + // Vectorized probability calculation + const double inv_total = 1.0 / total_votes; + for (int j = 0; j < n_classes; ++j) { + probs_accessor[j] = static_cast(class_votes[j] * inv_total); + } + } else { + // Uniform distribution fallback + const float uniform_prob = 1.0f / n_classes; + for (int j = 0; j < n_classes; ++j) { + probs_accessor[j] = uniform_prob; + } + } + + return class_probs; + } + + torch::Tensor AdaBoost::predict_proba(torch::Tensor& X) + { + if (!fitted || models.empty()) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + // Input validation if (X.size(0) != n) { throw std::runtime_error("Input has wrong number of features. Expected " + std::to_string(n) + " but got " + std::to_string(X.size(0))); } - int n_samples = X.size(1); - torch::Tensor predictions = torch::zeros({ n_samples }, torch::kInt32); + const int n_samples = X.size(1); - for (int i = 0; i < n_samples; i++) { - auto sample = X.index({ torch::indexing::Slice(), i }); - predictions[i] = predictSample(sample); + // Pre-allocate output tensor with correct layout + torch::Tensor probabilities = torch::empty({ n_samples, n_classes }, + torch::TensorOptions().dtype(torch::kFloat32)); + + // Convert to contiguous memory if needed (optimization for memory access) + if (!X.is_contiguous()) { + X = X.contiguous(); + } + + // Batch processing with memory-efficient sample extraction + for (int i = 0; i < n_samples; ++i) { + // Extract sample without unnecessary copies + auto sample = X.select(1, i); + + // Direct assignment to pre-allocated tensor + probabilities[i] = predictProbaSample(sample); + } + + return probabilities; + } + + std::vector> AdaBoost::predict_proba(std::vector>& X) + { + const size_t n_samples = X[0].size(); + + // Pre-allocate result with exact size + std::vector> result; + result.reserve(n_samples); + + // Avoid repeated allocations + for (size_t i = 0; i < n_samples; ++i) { + result.emplace_back(n_classes, 0.0); + } + + // Convert to tensor only once (batch conversion is more efficient) + torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); + torch::Tensor proba_tensor = predict_proba(X_tensor); + + // Optimized tensor-to-vector conversion + auto proba_accessor = proba_tensor.accessor(); + for (size_t i = 0; i < n_samples; ++i) { + for (int j = 0; j < n_classes; ++j) { + result[i][j] = static_cast(proba_accessor[i][j]); + } + } + + return result; + } + + torch::Tensor AdaBoost::predict(torch::Tensor& X) + { + if (!fitted || models.empty()) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + if (X.size(0) != n) { + throw std::runtime_error("Input has wrong number of features. Expected " + + std::to_string(n) + " but got " + std::to_string(X.size(0))); + } + + const int n_samples = X.size(1); + + // Pre-allocate with correct dtype + torch::Tensor predictions = torch::empty({ n_samples }, torch::TensorOptions().dtype(torch::kInt32)); + auto pred_accessor = predictions.accessor(); + + // Ensure contiguous memory layout + if (!X.is_contiguous()) { + X = X.contiguous(); + } + + // Optimized prediction loop + for (int i = 0; i < n_samples; ++i) { + auto sample = X.select(1, i); + pred_accessor[i] = predictSample(sample); } return predictions; @@ -323,300 +480,13 @@ namespace bayesnet { std::vector AdaBoost::predict(std::vector>& X) { - // Convert to tensor - X is samples x features, need to transpose + // Single tensor conversion for batch processing torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); - auto predictions = predict(X_tensor); - std::vector result = platform::TensorUtils::to_vector(predictions); + torch::Tensor predictions_tensor = predict(X_tensor); + + // Optimized tensor-to-vector conversion + std::vector result = platform::TensorUtils::to_vector(predictions_tensor); return result; } - std::vector> AdaBoost::predict_proba(std::vector>& X) - { - auto n_samples = X[0].size(); - - if (debug) { - std::cout << "=== predict_proba vector method debug ===" << std::endl; - std::cout << "Input X dimensions: " << X.size() << " features x " << n_samples << " samples" << std::endl; - std::cout << "Input data:" << std::endl; - for (size_t i = 0; i < X.size(); i++) { - std::cout << " Feature " << i << ": ["; - for (size_t j = 0; j < X[i].size(); j++) { - std::cout << X[i][j]; - if (j < X[i].size() - 1) std::cout << ", "; - } - std::cout << "]" << std::endl; - } - } - - // Convert to tensor - X is features x samples, need to transpose for tensor format - torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); - - if (debug) { - std::cout << "Converted tensor shape: " << X_tensor.sizes() << std::endl; - std::cout << "Tensor data: " << X_tensor << std::endl; - } - - auto proba_tensor = predict_proba(X_tensor); // Call tensor method - - if (debug) { - std::cout << "Proba tensor shape: " << proba_tensor.sizes() << std::endl; - std::cout << "Proba tensor data: " << proba_tensor << std::endl; - } - - std::vector> result(n_samples, std::vector(n_classes, 0.0)); - - for (size_t i = 0; i < n_samples; i++) { - for (int j = 0; j < n_classes; j++) { - result[i][j] = proba_tensor[i][j].item(); - } - - if (debug) { - std::cout << "Sample " << i << " converted: ["; - for (int j = 0; j < n_classes; j++) { - std::cout << result[i][j]; - if (j < n_classes - 1) std::cout << ", "; - } - std::cout << "]" << std::endl; - } - } - - if (debug) { - std::cout << "=== End predict_proba vector method debug ===" << std::endl; - } - - return result; - } - - // También agregar debug al método tensor predict_proba: - - torch::Tensor AdaBoost::predict_proba(torch::Tensor& X) - { - if (!fitted) { - throw std::runtime_error(CLASSIFIER_NOT_FITTED); - } - - if (models.empty()) { - throw std::runtime_error("No models have been trained"); - } - - // X should be (n_features, n_samples) - if (X.size(0) != n) { - throw std::runtime_error("Input has wrong number of features. Expected " + - std::to_string(n) + " but got " + std::to_string(X.size(0))); - } - - int n_samples = X.size(1); - - if (debug) { - std::cout << "=== predict_proba tensor method debug ===" << std::endl; - std::cout << "Input tensor shape: " << X.sizes() << std::endl; - std::cout << "Number of samples: " << n_samples << std::endl; - std::cout << "Number of classes: " << n_classes << std::endl; - } - - torch::Tensor probabilities = torch::zeros({ n_samples, n_classes }); - - for (int i = 0; i < n_samples; i++) { - auto sample = X.index({ torch::indexing::Slice(), i }); - - if (debug) { - std::cout << "Processing sample " << i << ": " << sample << std::endl; - } - - auto sample_probs = predictProbaSample(sample); - - if (debug) { - std::cout << "Sample " << i << " probabilities from predictProbaSample: " << sample_probs << std::endl; - } - - probabilities[i] = sample_probs; - - if (debug) { - std::cout << "Assigned to probabilities[" << i << "]: " << probabilities[i] << std::endl; - } - } - - if (debug) { - std::cout << "Final probabilities tensor: " << probabilities << std::endl; - std::cout << "=== End predict_proba tensor method debug ===" << std::endl; - } - - return probabilities; - } - - int AdaBoost::predictSample(const torch::Tensor& x) const - { - if (!fitted) { - throw std::runtime_error(CLASSIFIER_NOT_FITTED); - } - - if (models.empty()) { - throw std::runtime_error("No models have been trained"); - } - - // x should be a 1D tensor with n features - if (x.size(0) != n) { - throw std::runtime_error("Input sample has wrong number of features. Expected " + - std::to_string(n) + " but got " + std::to_string(x.size(0))); - } - - // Initialize class votes with zeros - std::vector class_votes(n_classes, 0.0); - - if (debug) { - std::cout << "=== predictSample Debug ===" << std::endl; - std::cout << "Number of models: " << models.size() << std::endl; - } - - // Accumulate votes from all estimators (same logic as predictProbaSample) - for (size_t i = 0; i < models.size(); i++) { - double alpha = alphas[i]; - - // Skip invalid estimators - if (alpha <= 0 || !std::isfinite(alpha)) { - if (debug) std::cout << "Skipping model " << i << " (alpha=" << alpha << ")" << std::endl; - continue; - } - - try { - // Get class prediction from this estimator - int predicted_class = static_cast(models[i].get())->predictSample(x); - - if (debug) { - std::cout << "Model " << i << ": predicts class " << predicted_class - << " with alpha " << alpha << std::endl; - } - - // Add weighted vote for this class - if (predicted_class >= 0 && predicted_class < n_classes) { - class_votes[predicted_class] += alpha; - } - } - catch (const std::exception& e) { - if (debug) std::cout << "Error in model " << i << ": " << e.what() << std::endl; - continue; - } - } - - // Find class with maximum votes - int best_class = 0; - double max_votes = class_votes[0]; - - for (int j = 1; j < n_classes; j++) { - if (class_votes[j] > max_votes) { - max_votes = class_votes[j]; - best_class = j; - } - } - - if (debug) { - std::cout << "Class votes: ["; - for (int j = 0; j < n_classes; j++) { - std::cout << class_votes[j]; - if (j < n_classes - 1) std::cout << ", "; - } - std::cout << "]" << std::endl; - std::cout << "Best class: " << best_class << " with " << max_votes << " votes" << std::endl; - std::cout << "=== End predictSample Debug ===" << std::endl; - } - - return best_class; - } - - torch::Tensor AdaBoost::predictProbaSample(const torch::Tensor& x) const - { - if (!fitted) { - throw std::runtime_error(CLASSIFIER_NOT_FITTED); - } - - if (models.empty()) { - throw std::runtime_error("No models have been trained"); - } - - // x should be a 1D tensor with n features - if (x.size(0) != n) { - throw std::runtime_error("Input sample has wrong number of features. Expected " + - std::to_string(n) + " but got " + std::to_string(x.size(0))); - } - - // Initialize class votes with zeros - std::vector class_votes(n_classes, 0.0); - double total_votes = 0.0; - - if (debug) { - std::cout << "=== predictProbaSample Debug ===" << std::endl; - std::cout << "Number of models: " << models.size() << std::endl; - std::cout << "Number of classes: " << n_classes << std::endl; - } - - // Accumulate votes from all estimators - for (size_t i = 0; i < models.size(); i++) { - double alpha = alphas[i]; - - // Skip invalid estimators - if (alpha <= 0 || !std::isfinite(alpha)) { - if (debug) std::cout << "Skipping model " << i << " (alpha=" << alpha << ")" << std::endl; - continue; - } - - try { - // Get class prediction from this estimator - int predicted_class = static_cast(models[i].get())->predictSample(x); - - if (debug) { - std::cout << "Model " << i << ": predicts class " << predicted_class - << " with alpha " << alpha << std::endl; - } - - // Add weighted vote for this class - if (predicted_class >= 0 && predicted_class < n_classes) { - class_votes[predicted_class] += alpha; - total_votes += alpha; - } else { - if (debug) std::cout << "Invalid class prediction: " << predicted_class << std::endl; - } - } - catch (const std::exception& e) { - if (debug) std::cout << "Error in model " << i << ": " << e.what() << std::endl; - continue; - } - } - - if (debug) { - std::cout << "Total votes: " << total_votes << std::endl; - std::cout << "Class votes: ["; - for (int j = 0; j < n_classes; j++) { - std::cout << class_votes[j]; - if (j < n_classes - 1) std::cout << ", "; - } - std::cout << "]" << std::endl; - } - - // Convert votes to probabilities - torch::Tensor class_probs = torch::zeros({ n_classes }, torch::kDouble); - - if (total_votes > 0) { - // Simple division to get probabilities - for (int j = 0; j < n_classes; j++) { - class_probs[j] = static_cast(class_votes[j] / total_votes); - } - } else { - // If no valid votes, uniform distribution - if (debug) std::cout << "No valid votes, using uniform distribution" << std::endl; - class_probs.fill_(1.0f / n_classes); - } - - if (debug) { - std::cout << "Final probabilities: ["; - for (int j = 0; j < n_classes; j++) { - std::cout << class_probs[j].item(); - if (j < n_classes - 1) std::cout << ", "; - } - std::cout << "]" << std::endl; - std::cout << "=== End predictProbaSample Debug ===" << std::endl; - } - - return class_probs; - } - } // namespace bayesnet \ No newline at end of file diff --git a/src/experimental_clfs/AdaBoost.h b/src/experimental_clfs/AdaBoost.h index 0c7e08b..1d5c729 100644 --- a/src/experimental_clfs/AdaBoost.h +++ b/src/experimental_clfs/AdaBoost.h @@ -14,7 +14,7 @@ namespace bayesnet { class AdaBoost : public Ensemble { public: - explicit AdaBoost(int n_estimators = 50, int max_depth = 1); + explicit AdaBoost(int n_estimators = 100, int max_depth = 1); virtual ~AdaBoost() = default; // Override base class methods From 9448a971e83d46f1454997b6813f4d60268ff48d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Fri, 27 Jun 2025 20:25:41 +0200 Subject: [PATCH 27/27] fix vcpkg.json --- vcpkg.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vcpkg.json b/vcpkg.json index 06584bb..e9a8c61 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -28,7 +28,7 @@ "version": "1.1.1" }, { - "name": "argpase", + "name": "argparse", "version": "3.2" }, { @@ -40,4 +40,4 @@ "version": "3.11.3" } ] - } + } \ No newline at end of file