diff --git a/.gitignore b/.gitignore index 42fbee0..f09d914 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,4 @@ puml/** diagrams/html/** diagrams/latex/** .cache +vcpkg_installed diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 4b52f53..0000000 --- a/.gitmodules +++ /dev/null @@ -1,21 +0,0 @@ -[submodule "lib/catch2"] - path = lib/catch2 - url = https://github.com/catchorg/Catch2.git -[submodule "lib/argparse"] - path = lib/argparse - url = https://github.com/p-ranav/argparse -[submodule "lib/json"] - path = lib/json - url = https://github.com/nlohmann/json -[submodule "lib/libxlsxwriter"] - path = lib/libxlsxwriter - url = https://github.com/jmcnamara/libxlsxwriter.git -[submodule "lib/folding"] - path = lib/folding - url = https://github.com/rmontanana/folding -[submodule "lib/Files"] - path = lib/Files - url = https://github.com/rmontanana/ArffFiles -[submodule "lib/mdlp"] - path = lib/mdlp - url = https://github.com/rmontanana/mdlp diff --git a/CMakeLists.txt b/CMakeLists.txt index 178e84b..75b6900 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,12 +7,6 @@ project(Platform LANGUAGES CXX ) -find_package(Torch REQUIRED) - -if (POLICY CMP0135) - cmake_policy(SET CMP0135 NEW) -endif () - # Global CMake variables # ---------------------- set(CMAKE_CXX_STANDARD 20) @@ -26,62 +20,77 @@ set(CMAKE_CXX_FLAGS_DEBUG " ${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage -O # Options # ------- -option(ENABLE_CLANG_TIDY "Enable to add clang tidy." 
OFF) option(ENABLE_TESTING "Unit testing build" OFF) option(CODE_COVERAGE "Collect coverage from test library" OFF) +# CMakes modules +# -------------- +set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules ${CMAKE_MODULE_PATH}) + # MPI find_package(MPI REQUIRED) message("MPI_CXX_LIBRARIES=${MPI_CXX_LIBRARIES}") message("MPI_CXX_INCLUDE_DIRS=${MPI_CXX_INCLUDE_DIRS}") # Boost Library +cmake_policy(SET CMP0135 NEW) +cmake_policy(SET CMP0167 NEW) # For FindBoost set(Boost_USE_STATIC_LIBS OFF) set(Boost_USE_MULTITHREADED ON) set(Boost_USE_STATIC_RUNTIME OFF) + find_package(Boost 1.66.0 REQUIRED COMPONENTS python3 numpy3) + +# # Python +find_package(Python3 REQUIRED COMPONENTS Development) +# # target_include_directories(MyTarget SYSTEM PRIVATE ${Python3_INCLUDE_DIRS}) +# message("Python_LIBRARIES=${Python_LIBRARIES}") + +# # Boost Python +# find_package(boost_python${Python3_VERSION_MAJOR}${Python3_VERSION_MINOR} CONFIG REQUIRED COMPONENTS python${Python3_VERSION_MAJOR}${Python3_VERSION_MINOR}) +# # target_link_libraries(MyTarget PRIVATE Boost::python${Python3_VERSION_MAJOR}${Python3_VERSION_MINOR}) + + if(Boost_FOUND) message("Boost_INCLUDE_DIRS=${Boost_INCLUDE_DIRS}") + message("Boost_LIBRARIES=${Boost_LIBRARIES}") + message("Boost_VERSION=${Boost_VERSION}") include_directories(${Boost_INCLUDE_DIRS}) endif() -# Python -find_package(Python3 3.11 COMPONENTS Interpreter Development REQUIRED) -message("Python3_LIBRARIES=${Python3_LIBRARIES}") -# CMakes modules -# -------------- -set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules ${CMAKE_MODULE_PATH}) -include(AddGitSubmodule) - -if (CODE_COVERAGE) - enable_testing() - include(CodeCoverage) - MESSAGE("Code coverage enabled") - SET(GCC_COVERAGE_LINK_FLAGS " ${GCC_COVERAGE_LINK_FLAGS} -lgcov --coverage") -endif (CODE_COVERAGE) - -if (ENABLE_CLANG_TIDY) - include(StaticAnalyzers) # clang-tidy -endif (ENABLE_CLANG_TIDY) # External libraries - dependencies of Platform # 
--------------------------------------------- -add_git_submodule("lib/argparse") -add_git_submodule("lib/mdlp") find_library(XLSXWRITER_LIB NAMES libxlsxwriter.dylib libxlsxwriter.so PATHS ${Platform_SOURCE_DIR}/lib/libxlsxwriter/lib) -message("XLSXWRITER_LIB=${XLSXWRITER_LIB}") +# find_path(XLSXWRITER_INCLUDE_DIR xlsxwriter.h) +# find_library(XLSXWRITER_LIBRARY xlsxwriter) +# message("XLSXWRITER_INCLUDE_DIR=${XLSXWRITER_INCLUDE_DIR}") +# message("XLSXWRITER_LIBRARY=${XLSXWRITER_LIBRARY}") +find_package(Torch CONFIG REQUIRED) +find_package(fimdlp CONFIG REQUIRED) +find_package(folding CONFIG REQUIRED) +find_package(argparse CONFIG REQUIRED) +find_package(nlohmann_json CONFIG REQUIRED) +find_package(Boost REQUIRED COMPONENTS python) +find_package(arff-files CONFIG REQUIRED) +# BayesNet +find_library(bayesnet NAMES libbayesnet bayesnet libbayesnet.a PATHS ${Platform_SOURCE_DIR}/../lib/lib REQUIRED) +find_path(Bayesnet_INCLUDE_DIRS REQUIRED NAMES bayesnet PATHS ${Platform_SOURCE_DIR}/../lib/include) +add_library(bayesnet::bayesnet UNKNOWN IMPORTED) +set_target_properties(bayesnet::bayesnet PROPERTIES + IMPORTED_LOCATION ${bayesnet} + INTERFACE_INCLUDE_DIRECTORIES ${Bayesnet_INCLUDE_DIRS}) +message(STATUS "BayesNet=${bayesnet}") +message(STATUS "BayesNet_INCLUDE_DIRS=${Bayesnet_INCLUDE_DIRS}") + +# PyClassifiers find_library(PyClassifiers NAMES libPyClassifiers PyClassifiers libPyClassifiers.a PATHS ${Platform_SOURCE_DIR}/../lib/lib REQUIRED) find_path(PyClassifiers_INCLUDE_DIRS REQUIRED NAMES pyclassifiers PATHS ${Platform_SOURCE_DIR}/../lib/include) -find_library(BayesNet NAMES libBayesNet BayesNet libBayesNet.a PATHS ${Platform_SOURCE_DIR}/../lib/lib REQUIRED) -find_path(Bayesnet_INCLUDE_DIRS REQUIRED NAMES bayesnet PATHS ${Platform_SOURCE_DIR}/../lib/include) - message(STATUS "PyClassifiers=${PyClassifiers}") message(STATUS "PyClassifiers_INCLUDE_DIRS=${PyClassifiers_INCLUDE_DIRS}") -message(STATUS "BayesNet=${BayesNet}") -message(STATUS 
"Bayesnet_INCLUDE_DIRS=${Bayesnet_INCLUDE_DIRS}") # Subdirectories # -------------- @@ -90,16 +99,20 @@ cmake_path(SET TEST_DATA_PATH "${CMAKE_CURRENT_SOURCE_DIR}/tests/data") configure_file(src/common/SourceData.h.in "${CMAKE_BINARY_DIR}/configured_files/include/SourceData.h") add_subdirectory(config) add_subdirectory(src) -add_subdirectory(sample) +# add_subdirectory(sample) file(GLOB Platform_SOURCES CONFIGURE_DEPENDS ${Platform_SOURCE_DIR}/src/*.cpp) # Testing # ------- if (ENABLE_TESTING) + enable_testing() MESSAGE("Testing enabled") - if (NOT TARGET Catch2::Catch2) - add_git_submodule("lib/catch2") - endif (NOT TARGET Catch2::Catch2) + find_package(Catch2 CONFIG REQUIRED) include(CTest) add_subdirectory(tests) endif (ENABLE_TESTING) +if (CODE_COVERAGE) + include(CodeCoverage) + MESSAGE("Code coverage enabled") + SET(GCC_COVERAGE_LINK_FLAGS " ${GCC_COVERAGE_LINK_FLAGS} -lgcov --coverage") +endif (CODE_COVERAGE) diff --git a/Makefile b/Makefile index 59c603b..aff3116 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,9 @@ SHELL := /bin/bash .DEFAULT_GOAL := help -.PHONY: coverage setup help build test clean debug release submodules buildr buildd install dependency testp testb clang-uml +.PHONY: init clean coverage setup help build test clean debug release buildr buildd install dependency testp testb clang-uml example -f_release = build_release -f_debug = build_debug +f_release = build_Release +f_debug = build_Debug app_targets = b_best b_list b_main b_manage b_grid b_results test_targets = unit_tests_platform @@ -20,14 +20,22 @@ define ClearTests fi ; endef +init: ## Initialize the project installing dependencies + @echo ">>> Installing dependencies" + @vcpkg install + @echo ">>> Done"; -sub-init: ## Initialize submodules - @git submodule update --init --recursive - -sub-update: ## Initialize submodules - @git submodule update --remote --merge - @git submodule foreach git pull origin master - +clean: ## Clean the project + @echo ">>> Cleaning the project..." 
+ @if test -f CMakeCache.txt ; then echo "- Deleting CMakeCache.txt"; rm -f CMakeCache.txt; fi + @for folder in $(f_release) $(f_debug) vcpkg_installed install_test ; do \ + if test -d "$$folder" ; then \ + echo "- Deleting $$folder folder" ; \ + rm -rf "$$folder"; \ + fi; \ + done + $(call ClearTests) + @echo ">>> Done"; setup: ## Install dependencies for tests and coverage @if [ "$(shell uname)" = "Darwin" ]; then \ brew install gcovr; \ @@ -51,7 +59,9 @@ install: ## Copy binary files to bin folder @echo "*******************************************" @for item in $(app_targets); do \ echo ">>> Copying $$item" ; \ - cp $(f_release)/src/$$item $(dest) ; \ + cp $(f_release)/src/$$item $(dest) || { \ + echo "*** Error copying $$item" ; \ + } ; \ done dependency: ## Create a dependency graph diagram of the project (build/dependency.png) @@ -60,37 +70,33 @@ dependency: ## Create a dependency graph diagram of the project (build/dependenc cd $(f_debug) && cmake .. --graphviz=dependency.dot && dot -Tpng dependency.dot -o dependency.png buildd: ## Build the debug targets - @cmake --build $(f_debug) -t $(app_targets) PlatformSample --parallel + @cmake --build $(f_debug) -t $(app_targets) PlatformSample --parallel buildr: ## Build the release targets @cmake --build $(f_release) -t $(app_targets) --parallel -clean: ## Clean the tests info - @echo ">>> Cleaning Debug Platform tests..."; - $(call ClearTests) - @echo ">>> Done"; - clang-uml: ## Create uml class and sequence diagrams clang-uml -p --add-compile-flag -I /usr/lib/gcc/x86_64-redhat-linux/8/include/ -debug: ## Build a debug version of the project +debug: ## Build a debug version of the project with BayesNet from vcpkg @echo ">>> Building Debug Platform..."; @if [ -d ./$(f_debug) ]; then rm -rf ./$(f_debug); fi @mkdir $(f_debug); - @cmake -S . -B $(f_debug) -D CMAKE_BUILD_TYPE=Debug -D ENABLE_TESTING=ON -D CODE_COVERAGE=ON + @cmake -S . 
-B $(f_debug) -D CMAKE_BUILD_TYPE=Debug -D ENABLE_TESTING=ON -D CODE_COVERAGE=ON -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake @echo ">>> Done"; -release: ## Build a Release version of the project +release: ## Build a Release version of the project with BayesNet from vcpkg @echo ">>> Building Release Platform..."; @if [ -d ./$(f_release) ]; then rm -rf ./$(f_release); fi @mkdir $(f_release); - @cmake -S . -B $(f_release) -D CMAKE_BUILD_TYPE=Release - @echo ">>> Done"; + @cmake -S . -B $(f_release) -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake + @echo ">>> Done"; opt = "" test: ## Run tests (opt="-s") to verbose output the tests, (opt="-c='Test Maximum Spanning Tree'") to run only that section @echo ">>> Running Platform tests..."; @$(MAKE) clean + @$(MAKE) debug @cmake --build $(f_debug) -t $(test_targets) --parallel @for t in $(test_targets); do \ if [ -f $(f_debug)/tests/$$t ]; then \ diff --git a/README.md b/README.md index 5e5172d..05d5749 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ ![C++](https://img.shields.io/badge/c++-%2300599C.svg?style=flat&logo=c%2B%2B&logoColor=white) [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)]() +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/rmontanana/Platform) ![Gitea Last Commit](https://img.shields.io/gitea/last-commit/rmontanana/platform?gitea_url=https://gitea.rmontanana.es&logo=gitea) Platform to run Bayesian Networks and Machine Learning Classifiers experiments. 
diff --git a/gitmodules b/gitmodules deleted file mode 100644 index 394a867..0000000 --- a/gitmodules +++ /dev/null @@ -1,23 +0,0 @@ -[submodule "lib/catch2"] - path = lib/catch2 - main = v2.x - update = merge - url = https://github.com/catchorg/Catch2.git -[submodule "lib/argparse"] - path = lib/argparse - url = https://github.com/p-ranav/argparse - master = master - update = merge -[submodule "lib/json"] - path = lib/json - url = https://github.com/nlohmann/json.git - master = master - update = merge -[submodule "lib/libxlsxwriter"] - path = lib/libxlsxwriter - url = https://github.com/jmcnamara/libxlsxwriter.git - main = main - update = merge -[submodule "lib/folding"] - path = lib/folding - url = https://github.com/rmontanana/Folding diff --git a/lib/Files b/lib/Files deleted file mode 160000 index 18c79f6..0000000 --- a/lib/Files +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 18c79f6d4894d6b7a6cbfad0239bf9bfd68d3bb4 diff --git a/lib/argparse b/lib/argparse deleted file mode 160000 index cbd9fd8..0000000 --- a/lib/argparse +++ /dev/null @@ -1 +0,0 @@ -Subproject commit cbd9fd8ed675ed6a2ac1bd7142d318c6ad5d3462 diff --git a/lib/catch2 b/lib/catch2 deleted file mode 160000 index 914aeec..0000000 --- a/lib/catch2 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 914aeecfe23b1e16af6ea675a4fb5dbd5a5b8d0a diff --git a/lib/folding b/lib/folding deleted file mode 160000 index 9652853..0000000 --- a/lib/folding +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 9652853d692ed3b8a38d89f70559209ffb988020 diff --git a/lib/json b/lib/json deleted file mode 160000 index 48e7b4c..0000000 --- a/lib/json +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 48e7b4c23b089c088c11e51c824d78d0f0949b40 diff --git a/lib/libxlsxwriter b/lib/libxlsxwriter deleted file mode 160000 index 14f1351..0000000 --- a/lib/libxlsxwriter +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 14f13513cb140092a913a91fce719ff7dc36e332 diff --git a/lib/log/loguru.cpp b/lib/log/loguru.cpp deleted file mode 100644 index 
a95cfbf..0000000 --- a/lib/log/loguru.cpp +++ /dev/null @@ -1,2009 +0,0 @@ -#if defined(__GNUC__) || defined(__clang__) -// Disable all warnings from gcc/clang: -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wpragmas" - -#pragma GCC diagnostic ignored "-Wc++98-compat" -#pragma GCC diagnostic ignored "-Wc++98-compat-pedantic" -#pragma GCC diagnostic ignored "-Wexit-time-destructors" -#pragma GCC diagnostic ignored "-Wformat-nonliteral" -#pragma GCC diagnostic ignored "-Wglobal-constructors" -#pragma GCC diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" -#pragma GCC diagnostic ignored "-Wmissing-prototypes" -#pragma GCC diagnostic ignored "-Wpadded" -#pragma GCC diagnostic ignored "-Wsign-conversion" -#pragma GCC diagnostic ignored "-Wunknown-pragmas" -#pragma GCC diagnostic ignored "-Wunused-macros" -#pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant" -#elif defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable:4365) // conversion from 'X' to 'Y', signed/unsigned mismatch -#endif - -#include "loguru.hpp" - -#ifndef LOGURU_HAS_BEEN_IMPLEMENTED -#define LOGURU_HAS_BEEN_IMPLEMENTED - -#define LOGURU_PREAMBLE_WIDTH (53 + LOGURU_THREADNAME_WIDTH + LOGURU_FILENAME_WIDTH) - -#undef min -#undef max - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if LOGURU_SYSLOG -#include -#else -#define LOG_USER 0 -#endif - -#ifdef _WIN32 -#include - -#define localtime_r(a, b) localtime_s(b, a) // No localtime_r with MSVC, but arguments are swapped for localtime_s -#else -#include -#include // mkdir -#include // STDERR_FILENO -#endif - -#ifdef __linux__ -#include // PATH_MAX -#elif !defined(_WIN32) -#include // PATH_MAX -#endif - -#ifndef PATH_MAX -#define PATH_MAX 1024 -#endif - -#ifdef __APPLE__ -#include "TargetConditionals.h" -#endif - -// TODO: use defined(_POSIX_VERSION) for some of these things? 
- -#if defined(_WIN32) || defined(__CYGWIN__) -#define LOGURU_PTHREADS 0 -#define LOGURU_WINTHREADS 1 -#ifndef LOGURU_STACKTRACES -#define LOGURU_STACKTRACES 0 -#endif -#else -#define LOGURU_PTHREADS 1 -#define LOGURU_WINTHREADS 0 -#ifdef __GLIBC__ -#ifndef LOGURU_STACKTRACES -#define LOGURU_STACKTRACES 1 -#endif -#else -#ifndef LOGURU_STACKTRACES -#define LOGURU_STACKTRACES 0 -#endif -#endif -#endif - -#if LOGURU_STACKTRACES -#include // for __cxa_demangle -#include // for dladdr -#include // for backtrace -#endif // LOGURU_STACKTRACES - -#if LOGURU_PTHREADS -#include -#if defined(__FreeBSD__) -#include -#include -#elif defined(__OpenBSD__) -#include -#endif - -#ifdef __linux__ - /* On Linux, the default thread name is the same as the name of the binary. - Additionally, all new threads inherit the name of the thread it got forked from. - For this reason, Loguru use the pthread Thread Local Storage - for storing thread names on Linux. */ -#ifndef LOGURU_PTLS_NAMES -#define LOGURU_PTLS_NAMES 1 -#endif -#endif -#endif - -#if LOGURU_WINTHREADS -#ifndef _WIN32_WINNT -#define _WIN32_WINNT 0x0502 -#endif -#define WIN32_LEAN_AND_MEAN -#define NOMINMAX -#include -#endif - -#ifndef LOGURU_PTLS_NAMES -#define LOGURU_PTLS_NAMES 0 -#endif - -LOGURU_ANONYMOUS_NAMESPACE_BEGIN - -namespace loguru { - using namespace std::chrono; - -#if LOGURU_WITH_FILEABS - struct FileAbs { - char path[PATH_MAX]; - char mode_str[4]; - Verbosity verbosity; - struct stat st; - FILE* fp; - bool is_reopening = false; // to prevent recursive call in file_reopen. - decltype(steady_clock::now()) last_check_time = steady_clock::now(); - }; -#else - typedef FILE* FileAbs; -#endif - - struct Callback { - std::string id; - log_handler_t callback; - void* user_data; - Verbosity verbosity; // Does not change! 
- close_handler_t close; - flush_handler_t flush; - unsigned indentation; - }; - - using CallbackVec = std::vector; - - using StringPair = std::pair; - using StringPairList = std::vector; - - const auto s_start_time = steady_clock::now(); - - Verbosity g_stderr_verbosity = Verbosity_0; - bool g_colorlogtostderr = true; - unsigned g_flush_interval_ms = 0; - bool g_preamble_header = true; - bool g_preamble = true; - - Verbosity g_internal_verbosity = Verbosity_0; - - // Preamble details - bool g_preamble_date = true; - bool g_preamble_time = true; - bool g_preamble_uptime = true; - bool g_preamble_thread = true; - bool g_preamble_file = true; - bool g_preamble_verbose = true; - bool g_preamble_pipe = true; - - static std::recursive_mutex s_mutex; - static Verbosity s_max_out_verbosity = Verbosity_OFF; - static std::string s_argv0_filename; - static std::string s_arguments; - static char s_current_dir[PATH_MAX]; - static CallbackVec s_callbacks; - static fatal_handler_t s_fatal_handler = nullptr; - static verbosity_to_name_t s_verbosity_to_name_callback = nullptr; - static name_to_verbosity_t s_name_to_verbosity_callback = nullptr; - static StringPairList s_user_stack_cleanups; - static bool s_strip_file_path = true; - static std::atomic s_stderr_indentation{ 0 }; - - // For periodic flushing: - static std::thread* s_flush_thread = nullptr; - static bool s_needs_flushing = false; - - static SignalOptions s_signal_options = SignalOptions::none(); - - static const bool s_terminal_has_color = []() { -#ifdef _WIN32 -#ifndef ENABLE_VIRTUAL_TERMINAL_PROCESSING -#define ENABLE_VIRTUAL_TERMINAL_PROCESSING 0x0004 -#endif - - HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE); - if (hOut != INVALID_HANDLE_VALUE) { - DWORD dwMode = 0; - GetConsoleMode(hOut, &dwMode); - dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; - return SetConsoleMode(hOut, dwMode) != 0; - } - return false; -#else - if (!isatty(STDERR_FILENO)) { - return false; - } - if (const char* term = getenv("TERM")) { - 
return 0 == strcmp(term, "cygwin") - || 0 == strcmp(term, "linux") - || 0 == strcmp(term, "rxvt-unicode-256color") - || 0 == strcmp(term, "screen") - || 0 == strcmp(term, "screen-256color") - || 0 == strcmp(term, "screen.xterm-256color") - || 0 == strcmp(term, "tmux-256color") - || 0 == strcmp(term, "xterm") - || 0 == strcmp(term, "xterm-256color") - || 0 == strcmp(term, "xterm-termite") - || 0 == strcmp(term, "xterm-color"); - } else { - return false; - } -#endif - }(); - - static void print_preamble_header(char* out_buff, size_t out_buff_size); - - // ------------------------------------------------------------------------------ - // Colors - - bool terminal_has_color() { return s_terminal_has_color; } - - // Colors - -#ifdef _WIN32 -#define VTSEQ(ID) ("\x1b[1;" #ID "m") -#else -#define VTSEQ(ID) ("\x1b[" #ID "m") -#endif - - const char* terminal_black() { return s_terminal_has_color ? VTSEQ(30) : ""; } - const char* terminal_red() { return s_terminal_has_color ? VTSEQ(31) : ""; } - const char* terminal_green() { return s_terminal_has_color ? VTSEQ(32) : ""; } - const char* terminal_yellow() { return s_terminal_has_color ? VTSEQ(33) : ""; } - const char* terminal_blue() { return s_terminal_has_color ? VTSEQ(34) : ""; } - const char* terminal_purple() { return s_terminal_has_color ? VTSEQ(35) : ""; } - const char* terminal_cyan() { return s_terminal_has_color ? VTSEQ(36) : ""; } - const char* terminal_light_gray() { return s_terminal_has_color ? VTSEQ(37) : ""; } - const char* terminal_white() { return s_terminal_has_color ? VTSEQ(37) : ""; } - const char* terminal_light_red() { return s_terminal_has_color ? VTSEQ(91) : ""; } - const char* terminal_dim() { return s_terminal_has_color ? VTSEQ(2) : ""; } - - // Formating - const char* terminal_bold() { return s_terminal_has_color ? VTSEQ(1) : ""; } - const char* terminal_underline() { return s_terminal_has_color ? VTSEQ(4) : ""; } - - // You should end each line with this! 
- const char* terminal_reset() { return s_terminal_has_color ? VTSEQ(0) : ""; } - - // ------------------------------------------------------------------------------ -#if LOGURU_WITH_FILEABS - void file_reopen(void* user_data); - inline FILE* to_file(void* user_data) { return reinterpret_cast(user_data)->fp; } -#else - inline FILE* to_file(void* user_data) { return reinterpret_cast(user_data); } -#endif - - void file_log(void* user_data, const Message& message) - { -#if LOGURU_WITH_FILEABS - FileAbs* file_abs = reinterpret_cast(user_data); - if (file_abs->is_reopening) { - return; - } - // It is better checking file change every minute/hour/day, - // instead of doing this every time we log. - // Here check_interval is set to zero to enable checking every time; - const auto check_interval = seconds(0); - if (duration_cast(steady_clock::now() - file_abs->last_check_time) > check_interval) { - file_abs->last_check_time = steady_clock::now(); - file_reopen(user_data); - } - FILE* file = to_file(user_data); - if (!file) { - return; - } -#else - FILE* file = to_file(user_data); -#endif - fprintf(file, "%s%s%s%s\n", - message.preamble, message.indentation, message.prefix, message.message); - if (g_flush_interval_ms == 0) { - fflush(file); - } - } - - void file_close(void* user_data) - { - FILE* file = to_file(user_data); - if (file) { - fclose(file); - } -#if LOGURU_WITH_FILEABS - delete reinterpret_cast(user_data); -#endif - } - - void file_flush(void* user_data) - { - FILE* file = to_file(user_data); - fflush(file); - } - -#if LOGURU_WITH_FILEABS - void file_reopen(void* user_data) - { - FileAbs* file_abs = reinterpret_cast(user_data); - struct stat st; - int ret; - if (!file_abs->fp || (ret = stat(file_abs->path, &st)) == -1 || (st.st_ino != file_abs->st.st_ino)) { - file_abs->is_reopening = true; - if (file_abs->fp) { - fclose(file_abs->fp); - } - if (!file_abs->fp) { - VLOG_F(g_internal_verbosity, "Reopening file '" LOGURU_FMT(s) "' due to previous error", 
file_abs->path); - } else if (ret < 0) { - const auto why = errno_as_text(); - VLOG_F(g_internal_verbosity, "Reopening file '" LOGURU_FMT(s) "' due to '" LOGURU_FMT(s) "'", file_abs->path, why.c_str()); - } else { - VLOG_F(g_internal_verbosity, "Reopening file '" LOGURU_FMT(s) "' due to file changed", file_abs->path); - } - // try reopen current file. - if (!create_directories(file_abs->path)) { - LOG_F(ERROR, "Failed to create directories to '" LOGURU_FMT(s) "'", file_abs->path); - } - file_abs->fp = fopen(file_abs->path, file_abs->mode_str); - if (!file_abs->fp) { - LOG_F(ERROR, "Failed to open '" LOGURU_FMT(s) "'", file_abs->path); - } else { - stat(file_abs->path, &file_abs->st); - } - file_abs->is_reopening = false; - } - } -#endif - // ------------------------------------------------------------------------------ - // ------------------------------------------------------------------------------ -#if LOGURU_SYSLOG - void syslog_log(void* /*user_data*/, const Message& message) - { - /* - Level 0: Is reserved for kernel panic type situations. - Level 1: Is for Major resource failure. - Level 2->7 Application level failures - */ - int level; - if (message.verbosity < Verbosity_FATAL) { - level = 1; // System Alert - } else { - switch (message.verbosity) { - case Verbosity_FATAL: level = 2; break; // System Critical - case Verbosity_ERROR: level = 3; break; // System Error - case Verbosity_WARNING: level = 4; break; // System Warning - case Verbosity_INFO: level = 5; break; // System Notice - case Verbosity_1: level = 6; break; // System Info - default: level = 7; break; // System Debug - } - } - - // Note: We don't add the time info. - // This is done automatically by the syslog deamon. - // Otherwise log all information that the file log does. 
- syslog(level, "%s%s%s", message.indentation, message.prefix, message.message); - } - - void syslog_close(void* /*user_data*/) - { - closelog(); - } - - void syslog_flush(void* /*user_data*/) - { - } -#endif - // ------------------------------------------------------------------------------ - // Helpers: - - Text::~Text() { free(_str); } - -#if LOGURU_USE_FMTLIB - Text vtextprintf(const char* format, fmt::format_args args) - { - return Text(STRDUP(fmt::vformat(format, args).c_str())); - } -#else - LOGURU_PRINTF_LIKE(1, 0) - static Text vtextprintf(const char* format, va_list vlist) - { -#ifdef _WIN32 - int bytes_needed = _vscprintf(format, vlist); - CHECK_F(bytes_needed >= 0, "Bad string format: '%s'", format); - char* buff = (char*)malloc(bytes_needed + 1); - vsnprintf(buff, bytes_needed + 1, format, vlist); - return Text(buff); -#else - char* buff = nullptr; - int result = vasprintf(&buff, format, vlist); - CHECK_F(result >= 0, "Bad string format: '" LOGURU_FMT(s) "'", format); - return Text(buff); -#endif - } - - Text textprintf(const char* format, ...) - { - va_list vlist; - va_start(vlist, format); - auto result = vtextprintf(format, vlist); - va_end(vlist); - return result; - } -#endif - - // Overloaded for variadic template matching. - Text textprintf() - { - return Text(static_cast(calloc(1, 1))); - } - - static const char* indentation(unsigned depth) - { - static const char buff[] = - ". . . . . . . . . . " ". . . . . . . . . . " - ". . . . . . . . . . " ". . . . . . . . . . " - ". . . . . . . . . . " ". . . . . . . . . . " - ". . . . . . . . . . " ". . . . . . . . . . " - ". . . . . . . . . . " ". . . . . . . . . . 
"; - static const size_t INDENTATION_WIDTH = 4; - static const size_t NUM_INDENTATIONS = (sizeof(buff) - 1) / INDENTATION_WIDTH; - depth = std::min(depth, NUM_INDENTATIONS); - return buff + INDENTATION_WIDTH * (NUM_INDENTATIONS - depth); - } - - static void parse_args(int& argc, char* argv[], const char* verbosity_flag) - { - int arg_dest = 1; - int out_argc = argc; - - for (int arg_it = 1; arg_it < argc; ++arg_it) { - auto cmd = argv[arg_it]; - auto arg_len = strlen(verbosity_flag); - - bool last_is_alpha = false; -#if LOGURU_USE_LOCALE - try { // locale variant of isalpha will throw on error - last_is_alpha = std::isalpha(cmd[arg_len], std::locale("")); - } - catch (...) { - last_is_alpha = std::isalpha(static_cast(cmd[arg_len])); - } -#else - last_is_alpha = std::isalpha(static_cast(cmd[arg_len])); -#endif - - if (strncmp(cmd, verbosity_flag, arg_len) == 0 && !last_is_alpha) { - out_argc -= 1; - auto value_str = cmd + arg_len; - if (value_str[0] == '\0') { - // Value in separate argument - arg_it += 1; - CHECK_LT_F(arg_it, argc, "Missing verbosiy level after " LOGURU_FMT(s) "", verbosity_flag); - value_str = argv[arg_it]; - out_argc -= 1; - } - if (*value_str == '=') { value_str += 1; } - - auto req_verbosity = get_verbosity_from_name(value_str); - if (req_verbosity != Verbosity_INVALID) { - g_stderr_verbosity = req_verbosity; - } else { - char* end = 0; - g_stderr_verbosity = static_cast(strtol(value_str, &end, 10)); - CHECK_F(end && *end == '\0', - "Invalid verbosity. Expected integer, INFO, WARNING, ERROR or OFF, got '" LOGURU_FMT(s) "'", value_str); - } - } else { - argv[arg_dest++] = argv[arg_it]; - } - } - - argc = out_argc; - argv[argc] = nullptr; - } - - static long long now_ns() - { - return duration_cast(high_resolution_clock::now().time_since_epoch()).count(); - } - - // Returns the part of the path after the last / or \ (if any). 
- const char* filename(const char* path) - { - for (auto ptr = path; *ptr; ++ptr) { - if (*ptr == '/' || *ptr == '\\') { - path = ptr + 1; - } - } - return path; - } - - // ------------------------------------------------------------------------------ - - static void on_atexit() - { - VLOG_F(g_internal_verbosity, "atexit"); - flush(); - } - - static void install_signal_handlers(const SignalOptions& signal_options); - - static void write_hex_digit(std::string& out, unsigned num) - { - DCHECK_LT_F(num, 16u); - if (num < 10u) { out.push_back(char('0' + num)); } else { out.push_back(char('A' + num - 10)); } - } - - static void write_hex_byte(std::string& out, uint8_t n) - { - write_hex_digit(out, n >> 4u); - write_hex_digit(out, n & 0x0f); - } - - static void escape(std::string& out, const std::string& str) - { - for (char c : str) { - /**/ if (c == '\a') { out += "\\a"; } else if (c == '\b') { out += "\\b"; } else if (c == '\f') { out += "\\f"; } else if (c == '\n') { out += "\\n"; } else if (c == '\r') { out += "\\r"; } else if (c == '\t') { out += "\\t"; } else if (c == '\v') { out += "\\v"; } else if (c == '\\') { out += "\\\\"; } else if (c == '\'') { out += "\\\'"; } else if (c == '\"') { out += "\\\""; } else if (c == ' ') { out += "\\ "; } else if (0 <= c && c < 0x20) { // ASCI control character: - // else if (c < 0x20 || c != (c & 127)) { // ASCII control character or UTF-8: - out += "\\x"; - write_hex_byte(out, static_cast(c)); - } else { out += c; } - } - } - - Text errno_as_text() - { - char buff[256]; -#if defined(__GLIBC__) && defined(_GNU_SOURCE) - // GNU Version - return Text(STRDUP(strerror_r(errno, buff, sizeof(buff)))); -#elif defined(__APPLE__) || _POSIX_C_SOURCE >= 200112L - // XSI Version - strerror_r(errno, buff, sizeof(buff)); - return Text(strdup(buff)); -#elif defined(_WIN32) - strerror_s(buff, sizeof(buff), errno); - return Text(STRDUP(buff)); -#else - // Not thread-safe. 
- return Text(STRDUP(strerror(errno))); -#endif - } - - void init(int& argc, char* argv[], const Options& options) - { - CHECK_GT_F(argc, 0, "Expected proper argc/argv"); - CHECK_EQ_F(argv[argc], nullptr, "Expected proper argc/argv"); - - s_argv0_filename = filename(argv[0]); - -#ifdef _WIN32 -#define getcwd _getcwd -#endif - - if (!getcwd(s_current_dir, sizeof(s_current_dir))) { - const auto error_text = errno_as_text(); - LOG_F(WARNING, "Failed to get current working directory: " LOGURU_FMT(s) "", error_text.c_str()); - } - - s_arguments = ""; - for (int i = 0; i < argc; ++i) { - escape(s_arguments, argv[i]); - if (i + 1 < argc) { - s_arguments += " "; - } - } - - if (options.verbosity_flag) { - parse_args(argc, argv, options.verbosity_flag); - } - - if (const auto main_thread_name = options.main_thread_name) { -#if LOGURU_PTLS_NAMES || LOGURU_WINTHREADS - set_thread_name(main_thread_name); -#elif LOGURU_PTHREADS - char old_thread_name[16] = { 0 }; - auto this_thread = pthread_self(); -#if defined(__APPLE__) || defined(__linux__) || defined(__sun) - pthread_getname_np(this_thread, old_thread_name, sizeof(old_thread_name)); -#endif - if (old_thread_name[0] == 0) { -#ifdef __APPLE__ - pthread_setname_np(main_thread_name); -#elif defined(__FreeBSD__) || defined(__OpenBSD__) - pthread_set_name_np(this_thread, main_thread_name); -#elif defined(__linux__) || defined(__sun) - pthread_setname_np(this_thread, main_thread_name); -#endif - } -#endif // LOGURU_PTHREADS - } - - if (g_stderr_verbosity >= Verbosity_INFO) { - if (g_preamble_header) { - char preamble_explain[LOGURU_PREAMBLE_WIDTH]; - print_preamble_header(preamble_explain, sizeof(preamble_explain)); - if (g_colorlogtostderr && s_terminal_has_color) { - fprintf(stderr, "%s%s%s\n", terminal_reset(), terminal_dim(), preamble_explain); - } else { - fprintf(stderr, "%s\n", preamble_explain); - } - } - fflush(stderr); - } - VLOG_F(g_internal_verbosity, "arguments: " LOGURU_FMT(s) "", s_arguments.c_str()); - if 
(strlen(s_current_dir) != 0) { - VLOG_F(g_internal_verbosity, "Current dir: " LOGURU_FMT(s) "", s_current_dir); - } - VLOG_F(g_internal_verbosity, "stderr verbosity: " LOGURU_FMT(d) "", g_stderr_verbosity); - VLOG_F(g_internal_verbosity, "-----------------------------------"); - - install_signal_handlers(options.signal_options); - - atexit(on_atexit); - } - - void shutdown() - { - VLOG_F(g_internal_verbosity, "loguru::shutdown()"); - remove_all_callbacks(); - set_fatal_handler(nullptr); - set_verbosity_to_name_callback(nullptr); - set_name_to_verbosity_callback(nullptr); - } - - void write_date_time(char* buff, unsigned long long buff_size) - { - auto now = system_clock::now(); - long long ms_since_epoch = duration_cast(now.time_since_epoch()).count(); - time_t sec_since_epoch = time_t(ms_since_epoch / 1000); - tm time_info; - localtime_r(&sec_since_epoch, &time_info); - snprintf(buff, buff_size, "%04d%02d%02d_%02d%02d%02d.%03lld", - 1900 + time_info.tm_year, 1 + time_info.tm_mon, time_info.tm_mday, - time_info.tm_hour, time_info.tm_min, time_info.tm_sec, ms_since_epoch % 1000); - } - - const char* argv0_filename() - { - return s_argv0_filename.c_str(); - } - - const char* arguments() - { - return s_arguments.c_str(); - } - - const char* current_dir() - { - return s_current_dir; - } - - const char* home_dir() - { -#ifdef __MINGW32__ - auto home = getenv("USERPROFILE"); - CHECK_F(home != nullptr, "Missing USERPROFILE"); - return home; -#elif defined(_WIN32) - char* user_profile; - size_t len; - errno_t err = _dupenv_s(&user_profile, &len, "USERPROFILE"); - CHECK_F(err == 0, "Missing USERPROFILE"); - return user_profile; -#else // _WIN32 - auto home = getenv("HOME"); - CHECK_F(home != nullptr, "Missing HOME"); - return home; -#endif // _WIN32 - } - - void suggest_log_path(const char* prefix, char* buff, unsigned long long buff_size) - { - if (prefix[0] == '~') { - snprintf(buff, buff_size - 1, "%s%s", home_dir(), prefix + 1); - } else { - snprintf(buff, buff_size - 
1, "%s", prefix); - } - - // Check for terminating / - size_t n = strlen(buff); - if (n != 0) { - if (buff[n - 1] != '/') { - CHECK_F(n + 2 < buff_size, "Filename buffer too small"); - buff[n] = '/'; - buff[n + 1] = '\0'; - } - } - -#ifdef _WIN32 - strncat_s(buff, buff_size - strlen(buff) - 1, s_argv0_filename.c_str(), buff_size - strlen(buff) - 1); - strncat_s(buff, buff_size - strlen(buff) - 1, "/", buff_size - strlen(buff) - 1); - write_date_time(buff + strlen(buff), buff_size - strlen(buff)); - strncat_s(buff, buff_size - strlen(buff) - 1, ".log", buff_size - strlen(buff) - 1); -#else - strncat(buff, s_argv0_filename.c_str(), buff_size - strlen(buff) - 1); - strncat(buff, "/", buff_size - strlen(buff) - 1); - write_date_time(buff + strlen(buff), buff_size - strlen(buff)); - strncat(buff, ".log", buff_size - strlen(buff) - 1); -#endif - } - - bool create_directories(const char* file_path_const) - { - CHECK_F(file_path_const && *file_path_const); - char* file_path = STRDUP(file_path_const); - for (char* p = strchr(file_path + 1, '/'); p; p = strchr(p + 1, '/')) { - *p = '\0'; - -#ifdef _WIN32 - if (_mkdir(file_path) == -1) { -#else - if (mkdir(file_path, 0755) == -1) { -#endif - if (errno != EEXIST) { - LOG_F(ERROR, "Failed to create directory '" LOGURU_FMT(s) "'", file_path); - LOG_IF_F(ERROR, errno == EACCES, "EACCES"); - LOG_IF_F(ERROR, errno == ENAMETOOLONG, "ENAMETOOLONG"); - LOG_IF_F(ERROR, errno == ENOENT, "ENOENT"); - LOG_IF_F(ERROR, errno == ENOTDIR, "ENOTDIR"); - LOG_IF_F(ERROR, errno == ELOOP, "ELOOP"); - - *p = '/'; - free(file_path); - return false; - } - } - *p = '/'; - } - free(file_path); - return true; - } - bool add_file(const char* path_in, FileMode mode, Verbosity verbosity) - { - char path[PATH_MAX]; - if (path_in[0] == '~') { - snprintf(path, sizeof(path) - 1, "%s%s", home_dir(), path_in + 1); - } else { - snprintf(path, sizeof(path) - 1, "%s", path_in); - } - - if (!create_directories(path)) { - LOG_F(ERROR, "Failed to create directories to 
'" LOGURU_FMT(s) "'", path); - } - - const char* mode_str = (mode == FileMode::Truncate ? "w" : "a"); - FILE* file; -#ifdef _WIN32 - file = _fsopen(path, mode_str, _SH_DENYNO); -#else - file = fopen(path, mode_str); -#endif - if (!file) { - LOG_F(ERROR, "Failed to open '" LOGURU_FMT(s) "'", path); - return false; - } -#if LOGURU_WITH_FILEABS - FileAbs* file_abs = new FileAbs(); // this is deleted in file_close; - snprintf(file_abs->path, sizeof(file_abs->path) - 1, "%s", path); - snprintf(file_abs->mode_str, sizeof(file_abs->mode_str) - 1, "%s", mode_str); - stat(file_abs->path, &file_abs->st); - file_abs->fp = file; - file_abs->verbosity = verbosity; - add_callback(path_in, file_log, file_abs, verbosity, file_close, file_flush); -#else - add_callback(path_in, file_log, file, verbosity, file_close, file_flush); -#endif - - if (mode == FileMode::Append) { - fprintf(file, "\n\n\n\n\n"); - } - if (!s_arguments.empty()) { - fprintf(file, "arguments: %s\n", s_arguments.c_str()); - } - if (strlen(s_current_dir) != 0) { - fprintf(file, "Current dir: %s\n", s_current_dir); - } - fprintf(file, "File verbosity level: %d\n", verbosity); - if (g_preamble_header) { - char preamble_explain[LOGURU_PREAMBLE_WIDTH]; - print_preamble_header(preamble_explain, sizeof(preamble_explain)); - fprintf(file, "%s\n", preamble_explain); - } - fflush(file); - - VLOG_F(g_internal_verbosity, "Logging to '" LOGURU_FMT(s) "', mode: '" LOGURU_FMT(s) "', verbosity: " LOGURU_FMT(d) "", path, mode_str, verbosity); - return true; - } - - /* - Will add syslog as a standard sink for log messages - Any logging message with a verbosity lower or equal to - the given verbosity will be included. - - This works for Unix like systems (i.e. Linux/Mac) - There is no current implementation for Windows (as I don't know the - equivalent calls or have a way to test them). If you know please - add and send a pull request. 
- - The code should still compile under windows but will only generate - a warning message that syslog is unavailable. - - Search for LOGURU_SYSLOG to find and fix. - */ - bool add_syslog(const char* app_name, Verbosity verbosity) - { - return add_syslog(app_name, verbosity, LOG_USER); - } - bool add_syslog(const char* app_name, Verbosity verbosity, int facility) - { -#if LOGURU_SYSLOG - if (app_name == nullptr) { - app_name = argv0_filename(); - } - openlog(app_name, 0, facility); - add_callback("'syslog'", syslog_log, nullptr, verbosity, syslog_close, syslog_flush); - - VLOG_F(g_internal_verbosity, "Logging to 'syslog' , verbosity: " LOGURU_FMT(d) "", verbosity); - return true; -#else - (void)app_name; - (void)verbosity; - (void)facility; - VLOG_F(g_internal_verbosity, "syslog not implemented on this system. Request to install syslog logging ignored."); - return false; -#endif - } - // Will be called right before abort(). - void set_fatal_handler(fatal_handler_t handler) - { - s_fatal_handler = handler; - } - - fatal_handler_t get_fatal_handler() - { - return s_fatal_handler; - } - - void set_verbosity_to_name_callback(verbosity_to_name_t callback) - { - s_verbosity_to_name_callback = callback; - } - - void set_name_to_verbosity_callback(name_to_verbosity_t callback) - { - s_name_to_verbosity_callback = callback; - } - - void add_stack_cleanup(const char* find_this, const char* replace_with_this) - { - if (strlen(find_this) <= strlen(replace_with_this)) { - LOG_F(WARNING, "add_stack_cleanup: the replacement should be shorter than the pattern!"); - return; - } - - s_user_stack_cleanups.push_back(StringPair(find_this, replace_with_this)); - } - - static void on_callback_change() - { - s_max_out_verbosity = Verbosity_OFF; - for (const auto& callback : s_callbacks) { - s_max_out_verbosity = std::max(s_max_out_verbosity, callback.verbosity); - } - } - - void add_callback( - const char* id, - log_handler_t callback, - void* user_data, - Verbosity verbosity, - 
close_handler_t on_close, - flush_handler_t on_flush) - { - std::lock_guard lock(s_mutex); - s_callbacks.push_back(Callback{ id, callback, user_data, verbosity, on_close, on_flush, 0 }); - on_callback_change(); - } - - // Returns a custom verbosity name if one is available, or nullptr. - // See also set_verbosity_to_name_callback. - const char* get_verbosity_name(Verbosity verbosity) - { - auto name = s_verbosity_to_name_callback - ? (*s_verbosity_to_name_callback)(verbosity) - : nullptr; - - // Use standard replacements if callback fails: - if (!name) { - if (verbosity <= Verbosity_FATAL) { - name = "FATL"; - } else if (verbosity == Verbosity_ERROR) { - name = "ERR"; - } else if (verbosity == Verbosity_WARNING) { - name = "WARN"; - } else if (verbosity == Verbosity_INFO) { - name = "INFO"; - } - } - - return name; - } - - // Returns Verbosity_INVALID if the name is not found. - // See also set_name_to_verbosity_callback. - Verbosity get_verbosity_from_name(const char* name) - { - auto verbosity = s_name_to_verbosity_callback - ? 
(*s_name_to_verbosity_callback)(name) - : Verbosity_INVALID; - - // Use standard replacements if callback fails: - if (verbosity == Verbosity_INVALID) { - if (strcmp(name, "OFF") == 0) { - verbosity = Verbosity_OFF; - } else if (strcmp(name, "INFO") == 0) { - verbosity = Verbosity_INFO; - } else if (strcmp(name, "WARNING") == 0) { - verbosity = Verbosity_WARNING; - } else if (strcmp(name, "ERROR") == 0) { - verbosity = Verbosity_ERROR; - } else if (strcmp(name, "FATAL") == 0) { - verbosity = Verbosity_FATAL; - } - } - - return verbosity; - } - - bool remove_callback(const char* id) - { - std::lock_guard lock(s_mutex); - auto it = std::find_if(begin(s_callbacks), end(s_callbacks), [&](const Callback& c) { return c.id == id; }); - if (it != s_callbacks.end()) { - if (it->close) { it->close(it->user_data); } - s_callbacks.erase(it); - on_callback_change(); - return true; - } else { - LOG_F(ERROR, "Failed to locate callback with id '" LOGURU_FMT(s) "'", id); - return false; - } - } - - void remove_all_callbacks() - { - std::lock_guard lock(s_mutex); - for (auto& callback : s_callbacks) { - if (callback.close) { - callback.close(callback.user_data); - } - } - s_callbacks.clear(); - on_callback_change(); - } - - // Returns the maximum of g_stderr_verbosity and all file/custom outputs. - Verbosity current_verbosity_cutoff() - { - return g_stderr_verbosity > s_max_out_verbosity ? 
- g_stderr_verbosity : s_max_out_verbosity; - } - - // ------------------------------------------------------------------------ - // Threads names - -#if LOGURU_PTLS_NAMES - static pthread_once_t s_pthread_key_once = PTHREAD_ONCE_INIT; - static pthread_key_t s_pthread_key_name; - - void make_pthread_key_name() - { - (void)pthread_key_create(&s_pthread_key_name, free); - } -#endif - -#if LOGURU_WINTHREADS - // Where we store the custom thread name set by `set_thread_name` - char* thread_name_buffer() - { - __declspec(thread) static char thread_name[LOGURU_THREADNAME_WIDTH + 1] = { 0 }; - return &thread_name[0]; - } -#endif // LOGURU_WINTHREADS - - void set_thread_name(const char* name) - { -#if LOGURU_PTLS_NAMES - // Store thread name in thread-local storage at `s_pthread_key_name` - (void)pthread_once(&s_pthread_key_once, make_pthread_key_name); - (void)pthread_setspecific(s_pthread_key_name, STRDUP(name)); -#elif LOGURU_PTHREADS - // Tell the OS the thread name -#ifdef __APPLE__ - pthread_setname_np(name); -#elif defined(__FreeBSD__) || defined(__OpenBSD__) - pthread_set_name_np(pthread_self(), name); -#elif defined(__linux__) || defined(__sun) - pthread_setname_np(pthread_self(), name); -#endif -#elif LOGURU_WINTHREADS - // Store thread name in a thread-local storage: - strncpy_s(thread_name_buffer(), LOGURU_THREADNAME_WIDTH + 1, name, _TRUNCATE); -#else // LOGURU_PTHREADS - // TODO: on these weird platforms we should also store the thread name - // in a generic thread-local storage. 
- (void)name; -#endif // LOGURU_PTHREADS - } - - void get_thread_name(char* buffer, unsigned long long length, bool right_align_hex_id) - { - CHECK_NE_F(length, 0u, "Zero length buffer in get_thread_name"); - CHECK_NOTNULL_F(buffer, "nullptr in get_thread_name"); - -#if LOGURU_PTLS_NAMES - (void)pthread_once(&s_pthread_key_once, make_pthread_key_name); - if (const char* name = static_cast(pthread_getspecific(s_pthread_key_name))) { - snprintf(buffer, static_cast(length), "%s", name); - } else { - buffer[0] = 0; - } -#elif LOGURU_PTHREADS - // Ask the OS about the thread name. - // This is what we *want* to do on all platforms, but - // only some platforms support it (currently). - pthread_getname_np(pthread_self(), buffer, length); -#elif LOGURU_WINTHREADS - snprintf(buffer, static_cast(length), "%s", thread_name_buffer()); -#else - // Thread names unsupported - buffer[0] = 0; -#endif - - if (buffer[0] == 0) { - // We failed to get a readable thread name. - // Write a HEX thread ID instead. - // We try to get an ID that is the same as the ID you could - // read in your debugger, system monitor etc. - -#ifdef __APPLE__ - uint64_t thread_id; - pthread_threadid_np(pthread_self(), &thread_id); -#elif defined(__FreeBSD__) - long thread_id; - (void)thr_self(&thread_id); -#elif LOGURU_PTHREADS - uint64_t thread_id = pthread_self(); -#else - // This ID does not correllate to anything we can get from the OS, - // so this is the worst way to get the ID. 
- const auto thread_id = std::hash{}(std::this_thread::get_id()); -#endif - - if (right_align_hex_id) { - snprintf(buffer, static_cast(length), "%*X", static_cast(length - 1), static_cast(thread_id)); - } else { - snprintf(buffer, static_cast(length), "%X", static_cast(thread_id)); - } - } - } - - // ------------------------------------------------------------------------ - // Stack traces - -#if LOGURU_STACKTRACES - Text demangle(const char* name) - { - int status = -1; - char* demangled = abi::__cxa_demangle(name, 0, 0, &status); - Text result{ status == 0 ? demangled : STRDUP(name) }; - return result; - } - -#if LOGURU_RTTI - template - std::string type_name() - { - auto demangled = demangle(typeid(T).name()); - return demangled.c_str(); - } -#endif // LOGURU_RTTI - - static const StringPairList REPLACE_LIST = { - #if LOGURU_RTTI - { type_name(), "std::string" }, - { type_name(), "std::wstring" }, - { type_name(), "std::u16string" }, - { type_name(), "std::u32string" }, - #endif // LOGURU_RTTI - { "std::__1::", "std::" }, - { "__thiscall ", "" }, - { "__cdecl ", "" }, - }; - - void do_replacements(const StringPairList & replacements, std::string & str) - { - for (auto&& p : replacements) { - if (p.first.size() <= p.second.size()) { - // On gcc, "type_name()" is "std::string" - continue; - } - - size_t it; - while ((it = str.find(p.first)) != std::string::npos) { - str.replace(it, p.first.size(), p.second); - } - } - } - - std::string prettify_stacktrace(const std::string & input) - { - std::string output = input; - - do_replacements(s_user_stack_cleanups, output); - do_replacements(REPLACE_LIST, output); - - try { - std::regex std_allocator_re(R"(,\s*std::allocator<[^<>]+>)"); - output = std::regex_replace(output, std_allocator_re, std::string("")); - - std::regex template_spaces_re(R"(<\s*([^<> ]+)\s*>)"); - output = std::regex_replace(output, template_spaces_re, std::string("<$1>")); - } - catch (std::regex_error&) { - // Probably old GCC. 
- } - - return output; - } - - std::string stacktrace_as_stdstring(int skip) - { - // From https://gist.github.com/fmela/591333 - void* callstack[128]; - const auto max_frames = sizeof(callstack) / sizeof(callstack[0]); - int num_frames = backtrace(callstack, max_frames); - char** symbols = backtrace_symbols(callstack, num_frames); - - std::string result; - // Print stack traces so the most relevant ones are written last - // Rationale: http://yellerapp.com/posts/2015-01-22-upside-down-stacktraces.html - for (int i = num_frames - 1; i >= skip; --i) { - char buf[1024]; - Dl_info info; - if (dladdr(callstack[i], &info) && info.dli_sname) { - char* demangled = NULL; - int status = -1; - if (info.dli_sname[0] == '_') { - demangled = abi::__cxa_demangle(info.dli_sname, 0, 0, &status); - } - snprintf(buf, sizeof(buf), "%-3d %*p %s + %zd\n", - i - skip, int(2 + sizeof(void*) * 2), callstack[i], - status == 0 ? demangled : - info.dli_sname == 0 ? symbols[i] : info.dli_sname, - static_cast(callstack[i]) - static_cast(info.dli_saddr)); - free(demangled); - } else { - snprintf(buf, sizeof(buf), "%-3d %*p %s\n", - i - skip, int(2 + sizeof(void*) * 2), callstack[i], symbols[i]); - } - result += buf; - } - free(symbols); - - if (num_frames == max_frames) { - result = "[truncated]\n" + result; - } - - if (!result.empty() && result[result.size() - 1] == '\n') { - result.resize(result.size() - 1); - } - - return prettify_stacktrace(result); - } - -#else // LOGURU_STACKTRACES - Text demangle(const char* name) - { - return Text(STRDUP(name)); - } - - std::string stacktrace_as_stdstring(int) - { - // No stacktraces available on this platform" - return ""; - } - -#endif // LOGURU_STACKTRACES - - Text stacktrace(int skip) - { - auto str = stacktrace_as_stdstring(skip + 1); - return Text(STRDUP(str.c_str())); - } - - // ------------------------------------------------------------------------ - - static void print_preamble_header(char* out_buff, size_t out_buff_size) - { - if 
(out_buff_size == 0) { return; } - out_buff[0] = '\0'; - size_t pos = 0; - if (g_preamble_date && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "date "); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_time && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "time "); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_uptime && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "( uptime ) "); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_thread && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "[%-*s]", LOGURU_THREADNAME_WIDTH, " thread name/id"); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_file && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "%*s:line ", LOGURU_FILENAME_WIDTH, "file"); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_verbose && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, " v"); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_pipe && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "| "); - if (bytes > 0) { - pos += bytes; - } - } - } - - static void print_preamble(char* out_buff, size_t out_buff_size, Verbosity verbosity, const char* file, unsigned line) - { - if (out_buff_size == 0) { return; } - out_buff[0] = '\0'; - if (!g_preamble) { return; } - long long ms_since_epoch = duration_cast(system_clock::now().time_since_epoch()).count(); - time_t sec_since_epoch = time_t(ms_since_epoch / 1000); - tm time_info; - localtime_r(&sec_since_epoch, &time_info); - - auto uptime_ms = duration_cast(steady_clock::now() - s_start_time).count(); - auto uptime_sec = static_cast (uptime_ms) / 1000.0; - - char thread_name[LOGURU_THREADNAME_WIDTH + 1] = { 0 }; - get_thread_name(thread_name, LOGURU_THREADNAME_WIDTH + 1, true); - - if 
(s_strip_file_path) { - file = filename(file); - } - - char level_buff[6]; - const char* custom_level_name = get_verbosity_name(verbosity); - if (custom_level_name) { - snprintf(level_buff, sizeof(level_buff) - 1, "%s", custom_level_name); - } else { - snprintf(level_buff, sizeof(level_buff) - 1, "% 4d", static_cast(verbosity)); - } - - size_t pos = 0; - - if (g_preamble_date && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "%04d-%02d-%02d ", - 1900 + time_info.tm_year, 1 + time_info.tm_mon, time_info.tm_mday); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_time && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "%02d:%02d:%02d.%03lld ", - time_info.tm_hour, time_info.tm_min, time_info.tm_sec, ms_since_epoch % 1000); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_uptime && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "(%8.3fs) ", - uptime_sec); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_thread && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "[%-*s]", - LOGURU_THREADNAME_WIDTH, thread_name); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_file && pos < out_buff_size) { - char shortened_filename[LOGURU_FILENAME_WIDTH + 1]; - snprintf(shortened_filename, LOGURU_FILENAME_WIDTH + 1, "%s", file); - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "%*s:%-5u ", - LOGURU_FILENAME_WIDTH, shortened_filename, line); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_verbose && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "%4s", - level_buff); - if (bytes > 0) { - pos += bytes; - } - } - if (g_preamble_pipe && pos < out_buff_size) { - int bytes = snprintf(out_buff + pos, out_buff_size - pos, "| "); - if (bytes > 0) { - pos += bytes; - } - } - } - - // stack_trace_skip is just if verbosity == FATAL. 
- static void log_message(int stack_trace_skip, Message & message, bool with_indentation, bool abort_if_fatal) - { - const auto verbosity = message.verbosity; - std::lock_guard lock(s_mutex); - - if (message.verbosity == Verbosity_FATAL) { - auto st = loguru::stacktrace(stack_trace_skip + 2); - if (!st.empty()) { - RAW_LOG_F(ERROR, "Stack trace:\n" LOGURU_FMT(s) "", st.c_str()); - } - - auto ec = loguru::get_error_context(); - if (!ec.empty()) { - RAW_LOG_F(ERROR, "" LOGURU_FMT(s) "", ec.c_str()); - } - } - - if (with_indentation) { - message.indentation = indentation(s_stderr_indentation); - } - - if (verbosity <= g_stderr_verbosity) { - if (g_colorlogtostderr && s_terminal_has_color) { - if (verbosity > Verbosity_WARNING) { - fprintf(stderr, "%s%s%s%s%s%s%s%s\n", - terminal_reset(), - terminal_dim(), - message.preamble, - message.indentation, - verbosity == Verbosity_INFO ? terminal_reset() : "", // un-dim for info - message.prefix, - message.message, - terminal_reset()); - } else { - fprintf(stderr, "%s%s%s%s%s%s%s\n", - terminal_reset(), - verbosity == Verbosity_WARNING ? 
terminal_yellow() : terminal_red(), - message.preamble, - message.indentation, - message.prefix, - message.message, - terminal_reset()); - } - } else { - fprintf(stderr, "%s%s%s%s\n", - message.preamble, message.indentation, message.prefix, message.message); - } - - if (g_flush_interval_ms == 0) { - fflush(stderr); - } else { - s_needs_flushing = true; - } - } - - for (auto& p : s_callbacks) { - if (verbosity <= p.verbosity) { - if (with_indentation) { - message.indentation = indentation(p.indentation); - } - p.callback(p.user_data, message); - if (g_flush_interval_ms == 0) { - if (p.flush) { p.flush(p.user_data); } - } else { - s_needs_flushing = true; - } - } - } - - if (g_flush_interval_ms > 0 && !s_flush_thread) { - s_flush_thread = new std::thread([]() { - for (;;) { - if (s_needs_flushing) { - flush(); - } - std::this_thread::sleep_for(std::chrono::milliseconds(g_flush_interval_ms)); - } - }); - } - - if (message.verbosity == Verbosity_FATAL) { - flush(); - - if (s_fatal_handler) { - s_fatal_handler(message); - flush(); - } - - if (abort_if_fatal) { -#if !defined(_WIN32) - if (s_signal_options.sigabrt) { - // Make sure we don't catch our own abort: - signal(SIGABRT, SIG_DFL); - } -#endif - abort(); - } - } - } - - // stack_trace_skip is just if verbosity == FATAL. 
- void log_to_everywhere(int stack_trace_skip, Verbosity verbosity, - const char* file, unsigned line, - const char* prefix, const char* buff) - { - char preamble_buff[LOGURU_PREAMBLE_WIDTH]; - print_preamble(preamble_buff, sizeof(preamble_buff), verbosity, file, line); - auto message = Message{ verbosity, file, line, preamble_buff, "", prefix, buff }; - log_message(stack_trace_skip + 1, message, true, true); - } - -#if LOGURU_USE_FMTLIB - void vlog(Verbosity verbosity, const char* file, unsigned line, const char* format, fmt::format_args args) - { - auto formatted = fmt::vformat(format, args); - log_to_everywhere(1, verbosity, file, line, "", formatted.c_str()); - } - - void raw_vlog(Verbosity verbosity, const char* file, unsigned line, const char* format, fmt::format_args args) - { - auto formatted = fmt::vformat(format, args); - auto message = Message{ verbosity, file, line, "", "", "", formatted.c_str() }; - log_message(1, message, false, true); - } -#else - void log(Verbosity verbosity, const char* file, unsigned line, const char* format, ...) - { - va_list vlist; - va_start(vlist, format); - vlog(verbosity, file, line, format, vlist); - va_end(vlist); - } - - void vlog(Verbosity verbosity, const char* file, unsigned line, const char* format, va_list vlist) - { - auto buff = vtextprintf(format, vlist); - log_to_everywhere(1, verbosity, file, line, "", buff.c_str()); - } - - void raw_log(Verbosity verbosity, const char* file, unsigned line, const char* format, ...) 
- { - va_list vlist; - va_start(vlist, format); - auto buff = vtextprintf(format, vlist); - auto message = Message{ verbosity, file, line, "", "", "", buff.c_str() }; - log_message(1, message, false, true); - va_end(vlist); - } -#endif - - void flush() - { - std::lock_guard lock(s_mutex); - fflush(stderr); - for (const auto& callback : s_callbacks) { - if (callback.flush) { - callback.flush(callback.user_data); - } - } - s_needs_flushing = false; - } - - LogScopeRAII::LogScopeRAII(Verbosity verbosity, const char* file, unsigned line, const char* format, va_list vlist) : - _verbosity(verbosity), _file(file), _line(line) - { - this->Init(format, vlist); - } - - LogScopeRAII::LogScopeRAII(Verbosity verbosity, const char* file, unsigned line, const char* format, ...) : - _verbosity(verbosity), _file(file), _line(line) - { - va_list vlist; - va_start(vlist, format); - this->Init(format, vlist); - va_end(vlist); - } - - LogScopeRAII::~LogScopeRAII() - { - if (_file) { - std::lock_guard lock(s_mutex); - if (_indent_stderr && s_stderr_indentation > 0) { - --s_stderr_indentation; - } - for (auto& p : s_callbacks) { - // Note: Callback indentation cannot change! 
- if (_verbosity <= p.verbosity) { - // in unlikely case this callback is new - if (p.indentation > 0) { - --p.indentation; - } - } - } -#if LOGURU_VERBOSE_SCOPE_ENDINGS - auto duration_sec = static_cast(now_ns() - _start_time_ns) / 1e9; -#if LOGURU_USE_FMTLIB - auto buff = textprintf("{:.{}f} s: {:s}", duration_sec, LOGURU_SCOPE_TIME_PRECISION, _name); -#else - auto buff = textprintf("%.*f s: %s", LOGURU_SCOPE_TIME_PRECISION, duration_sec, _name); -#endif - log_to_everywhere(1, _verbosity, _file, _line, "} ", buff.c_str()); -#else - log_to_everywhere(1, _verbosity, _file, _line, "}", ""); -#endif - } - } - - void LogScopeRAII::Init(const char* format, va_list vlist) - { - if (_verbosity <= current_verbosity_cutoff()) { - std::lock_guard lock(s_mutex); - _indent_stderr = (_verbosity <= g_stderr_verbosity); - _start_time_ns = now_ns(); - vsnprintf(_name, sizeof(_name), format, vlist); - log_to_everywhere(1, _verbosity, _file, _line, "{ ", _name); - - if (_indent_stderr) { - ++s_stderr_indentation; - } - - for (auto& p : s_callbacks) { - if (_verbosity <= p.verbosity) { - ++p.indentation; - } - } - } else { - _file = nullptr; - } - } - -#if LOGURU_USE_FMTLIB - void vlog_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line, const char* format, fmt::format_args args) - { - auto formatted = fmt::vformat(format, args); - log_to_everywhere(stack_trace_skip + 1, Verbosity_FATAL, file, line, expr, formatted.c_str()); - abort(); // log_to_everywhere already does this, but this makes the analyzer happy. - } -#else - void log_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line, const char* format, ...) - { - va_list vlist; - va_start(vlist, format); - auto buff = vtextprintf(format, vlist); - log_to_everywhere(stack_trace_skip + 1, Verbosity_FATAL, file, line, expr, buff.c_str()); - va_end(vlist); - abort(); // log_to_everywhere already does this, but this makes the analyzer happy. 
- } -#endif - - void log_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line) - { - log_and_abort(stack_trace_skip + 1, expr, file, line, " "); - } - - // ---------------------------------------------------------------------------- - // Streams: - -#if LOGURU_USE_FMTLIB - template - std::string vstrprintf(const char* format, const Args&... args) - { - auto text = textprintf(format, args...); - std::string result = text.c_str(); - return result; - } - - template - std::string strprintf(const char* format, const Args&... args) - { - return vstrprintf(format, args...); - } -#else - std::string vstrprintf(const char* format, va_list vlist) - { - auto text = vtextprintf(format, vlist); - std::string result = text.c_str(); - return result; - } - - std::string strprintf(const char* format, ...) - { - va_list vlist; - va_start(vlist, format); - auto result = vstrprintf(format, vlist); - va_end(vlist); - return result; - } -#endif - -#if LOGURU_WITH_STREAMS - - StreamLogger::~StreamLogger() noexcept(false) - { - auto message = _ss.str(); - log(_verbosity, _file, _line, LOGURU_FMT(s), message.c_str()); - } - - AbortLogger::~AbortLogger() noexcept(false) - { - auto message = _ss.str(); - loguru::log_and_abort(1, _expr, _file, _line, LOGURU_FMT(s), message.c_str()); - } - -#endif // LOGURU_WITH_STREAMS - - // ---------------------------------------------------------------------------- - // 888888 88""Yb 88""Yb dP"Yb 88""Yb dP""b8 dP"Yb 88b 88 888888 888888 Yb dP 888888 - // 88__ 88__dP 88__dP dP Yb 88__dP dP `" dP Yb 88Yb88 88 88__ YbdP 88 - // 88"" 88"Yb 88"Yb Yb dP 88"Yb Yb Yb dP 88 Y88 88 88"" dPYb 88 - // 888888 88 Yb 88 Yb YbodP 88 Yb YboodP YbodP 88 Y8 88 888888 dP Yb 88 - // ---------------------------------------------------------------------------- - - struct StringStream { - std::string str; - }; - - // Use this in your EcPrinter implementations. 
- void stream_print(StringStream & out_string_stream, const char* text) - { - out_string_stream.str += text; - } - - // ---------------------------------------------------------------------------- - - using ECPtr = EcEntryBase*; - -#if defined(_WIN32) || (defined(__APPLE__) && !TARGET_OS_IPHONE) -#ifdef __APPLE__ -#define LOGURU_THREAD_LOCAL __thread -#else -#define LOGURU_THREAD_LOCAL thread_local -#endif - static LOGURU_THREAD_LOCAL ECPtr thread_ec_ptr = nullptr; - - ECPtr& get_thread_ec_head_ref() - { - return thread_ec_ptr; - } -#else // !thread_local - static pthread_once_t s_ec_pthread_once = PTHREAD_ONCE_INIT; - static pthread_key_t s_ec_pthread_key; - - void free_ec_head_ref(void* io_error_context) - { - delete reinterpret_cast(io_error_context); - } - - void ec_make_pthread_key() - { - (void)pthread_key_create(&s_ec_pthread_key, free_ec_head_ref); - } - - ECPtr& get_thread_ec_head_ref() - { - (void)pthread_once(&s_ec_pthread_once, ec_make_pthread_key); - auto ec = reinterpret_cast(pthread_getspecific(s_ec_pthread_key)); - if (ec == nullptr) { - ec = new ECPtr(nullptr); - (void)pthread_setspecific(s_ec_pthread_key, ec); - } - return *ec; - } -#endif // !thread_local - - // ---------------------------------------------------------------------------- - - EcHandle get_thread_ec_handle() - { - return get_thread_ec_head_ref(); - } - - Text get_error_context() - { - return get_error_context_for(get_thread_ec_head_ref()); - } - - Text get_error_context_for(const EcEntryBase * ec_head) - { - std::vector stack; - while (ec_head) { - stack.push_back(ec_head); - ec_head = ec_head->_previous; - } - std::reverse(stack.begin(), stack.end()); - - StringStream result; - if (!stack.empty()) { - result.str += "------------------------------------------------\n"; - for (auto entry : stack) { - const auto description = std::string(entry->_descr) + ":"; -#if LOGURU_USE_FMTLIB - auto prefix = textprintf("[ErrorContext] {.{}s}:{:-5u} {:-20s} ", - filename(entry->_file), 
LOGURU_FILENAME_WIDTH, entry->_line, description.c_str()); -#else - auto prefix = textprintf("[ErrorContext] %*s:%-5u %-20s ", - LOGURU_FILENAME_WIDTH, filename(entry->_file), entry->_line, description.c_str()); -#endif - result.str += prefix.c_str(); - entry->print_value(result); - result.str += "\n"; - } - result.str += "------------------------------------------------"; - } - return Text(STRDUP(result.str.c_str())); - } - - EcEntryBase::EcEntryBase(const char* file, unsigned line, const char* descr) - : _file(file), _line(line), _descr(descr) - { - EcEntryBase*& ec_head = get_thread_ec_head_ref(); - _previous = ec_head; - ec_head = this; - } - - EcEntryBase::~EcEntryBase() - { - get_thread_ec_head_ref() = _previous; - } - - // ------------------------------------------------------------------------ - - Text ec_to_text(const char* value) - { - // Add quotes around the string to make it obvious where it begin and ends. - // This is great for detecting erroneous leading or trailing spaces in e.g. an identifier. - auto str = "\"" + std::string(value) + "\""; - return Text{ STRDUP(str.c_str()) }; - } - - Text ec_to_text(char c) - { - // Add quotes around the character to make it obvious where it begin and ends. 
- std::string str = "'"; - - auto write_hex_digit = [&](unsigned num) - { - if (num < 10u) { str += char('0' + num); } else { str += char('a' + num - 10); } - }; - - auto write_hex_16 = [&](uint16_t n) - { - write_hex_digit((n >> 12u) & 0x0f); - write_hex_digit((n >> 8u) & 0x0f); - write_hex_digit((n >> 4u) & 0x0f); - write_hex_digit((n >> 0u) & 0x0f); - }; - - if (c == '\\') { str += "\\\\"; } else if (c == '\"') { str += "\\\""; } else if (c == '\'') { str += "\\\'"; } else if (c == '\0') { str += "\\0"; } else if (c == '\b') { str += "\\b"; } else if (c == '\f') { str += "\\f"; } else if (c == '\n') { str += "\\n"; } else if (c == '\r') { str += "\\r"; } else if (c == '\t') { str += "\\t"; } else if (0 <= c && c < 0x20) { - str += "\\u"; - write_hex_16(static_cast(c)); - } else { str += c; } - - str += "'"; - - return Text{ STRDUP(str.c_str()) }; - } - -#define DEFINE_EC(Type) \ - Text ec_to_text(Type value) \ - { \ - auto str = std::to_string(value); \ - return Text{STRDUP(str.c_str())}; \ - } - - DEFINE_EC(int) - DEFINE_EC(unsigned int) - DEFINE_EC(long) - DEFINE_EC(unsigned long) - DEFINE_EC(long long) - DEFINE_EC(unsigned long long) - DEFINE_EC(float) - DEFINE_EC(double) - DEFINE_EC(long double) - -#undef DEFINE_EC - - Text ec_to_text(EcHandle ec_handle) - { - Text parent_ec = get_error_context_for(ec_handle); - size_t buffer_size = strlen(parent_ec.c_str()) + 2; - char* with_newline = reinterpret_cast(malloc(buffer_size)); - with_newline[0] = '\n'; -#ifdef _WIN32 - strncpy_s(with_newline + 1, buffer_size, parent_ec.c_str(), buffer_size - 2); -#else - strcpy(with_newline + 1, parent_ec.c_str()); -#endif - return Text(with_newline); - } - - // ---------------------------------------------------------------------------- - -} // namespace loguru - -// ---------------------------------------------------------------------------- -// .dP"Y8 88 dP""b8 88b 88 db 88 .dP"Y8 -// `Ybo." 88 dP `" 88Yb88 dPYb 88 `Ybo." 
-// o.`Y8b 88 Yb "88 88 Y88 dP__Yb 88 .o o.`Y8b -// 8bodP' 88 YboodP 88 Y8 dP""""Yb 88ood8 8bodP' -// ---------------------------------------------------------------------------- - -#ifdef _WIN32 -namespace loguru { - void install_signal_handlers(const SignalOptions& signal_options) - { - (void)signal_options; - // TODO: implement signal handlers on windows - } -} // namespace loguru - -#else // _WIN32 - -namespace loguru { - void write_to_stderr(const char* data, size_t size) - { - auto result = write(STDERR_FILENO, data, size); - (void)result; // Ignore errors. - } - - void write_to_stderr(const char* data) - { - write_to_stderr(data, strlen(data)); - } - - void call_default_signal_handler(int signal_number) - { - struct sigaction sig_action; - memset(&sig_action, 0, sizeof(sig_action)); - sigemptyset(&sig_action.sa_mask); - sig_action.sa_handler = SIG_DFL; - sigaction(signal_number, &sig_action, NULL); - kill(getpid(), signal_number); - } - - void signal_handler(int signal_number, siginfo_t*, void*) - { - const char* signal_name = "UNKNOWN SIGNAL"; - - if (signal_number == SIGABRT) { signal_name = "SIGABRT"; } - if (signal_number == SIGBUS) { signal_name = "SIGBUS"; } - if (signal_number == SIGFPE) { signal_name = "SIGFPE"; } - if (signal_number == SIGILL) { signal_name = "SIGILL"; } - if (signal_number == SIGINT) { signal_name = "SIGINT"; } - if (signal_number == SIGSEGV) { signal_name = "SIGSEGV"; } - if (signal_number == SIGTERM) { signal_name = "SIGTERM"; } - - // -------------------------------------------------------------------- - /* There are few things that are safe to do in a signal handler, - but writing to stderr is one of them. - So we first print out what happened to stderr so we're sure that gets out, - then we do the unsafe things, like logging the stack trace. 
- */ - - if (g_colorlogtostderr && s_terminal_has_color) { - write_to_stderr(terminal_reset()); - write_to_stderr(terminal_bold()); - write_to_stderr(terminal_light_red()); - } - write_to_stderr("\n"); - write_to_stderr("Loguru caught a signal: "); - write_to_stderr(signal_name); - write_to_stderr("\n"); - if (g_colorlogtostderr && s_terminal_has_color) { - write_to_stderr(terminal_reset()); - } - - // -------------------------------------------------------------------- - - if (s_signal_options.unsafe_signal_handler) { - // -------------------------------------------------------------------- - /* Now we do unsafe things. This can for example lead to deadlocks if - the signal was triggered from the system's memory management functions - and the code below tries to do allocations. - */ - - flush(); - char preamble_buff[LOGURU_PREAMBLE_WIDTH]; - print_preamble(preamble_buff, sizeof(preamble_buff), Verbosity_FATAL, "", 0); - auto message = Message{ Verbosity_FATAL, "", 0, preamble_buff, "", "Signal: ", signal_name }; - try { - log_message(1, message, false, false); - } - catch (...) { - // This can happed due to s_fatal_handler. 
- write_to_stderr("Exception caught and ignored by Loguru signal handler.\n"); - } - flush(); - - // -------------------------------------------------------------------- - } - - call_default_signal_handler(signal_number); - } - - void install_signal_handlers(const SignalOptions& signal_options) - { - s_signal_options = signal_options; - - struct sigaction sig_action; - memset(&sig_action, 0, sizeof(sig_action)); - sigemptyset(&sig_action.sa_mask); - sig_action.sa_flags |= SA_SIGINFO; - sig_action.sa_sigaction = &signal_handler; - - if (signal_options.sigabrt) { - CHECK_F(sigaction(SIGABRT, &sig_action, NULL) != -1, "Failed to install handler for SIGABRT"); - } - if (signal_options.sigbus) { - CHECK_F(sigaction(SIGBUS, &sig_action, NULL) != -1, "Failed to install handler for SIGBUS"); - } - if (signal_options.sigfpe) { - CHECK_F(sigaction(SIGFPE, &sig_action, NULL) != -1, "Failed to install handler for SIGFPE"); - } - if (signal_options.sigill) { - CHECK_F(sigaction(SIGILL, &sig_action, NULL) != -1, "Failed to install handler for SIGILL"); - } - if (signal_options.sigint) { - CHECK_F(sigaction(SIGINT, &sig_action, NULL) != -1, "Failed to install handler for SIGINT"); - } - if (signal_options.sigsegv) { - CHECK_F(sigaction(SIGSEGV, &sig_action, NULL) != -1, "Failed to install handler for SIGSEGV"); - } - if (signal_options.sigterm) { - CHECK_F(sigaction(SIGTERM, &sig_action, NULL) != -1, "Failed to install handler for SIGTERM"); - } - } -} // namespace loguru - -#endif // _WIN32 - - -#if defined(__GNUC__) || defined(__clang__) -#pragma GCC diagnostic pop -#elif defined(_MSC_VER) -#pragma warning(pop) -#endif - -LOGURU_ANONYMOUS_NAMESPACE_END - -#endif // LOGURU_IMPLEMENTATION diff --git a/lib/log/loguru.hpp b/lib/log/loguru.hpp deleted file mode 100644 index 8917b79..0000000 --- a/lib/log/loguru.hpp +++ /dev/null @@ -1,1475 +0,0 @@ -/* -Loguru logging library for C++, by Emil Ernerfeldt. 
-www.github.com/emilk/loguru -If you find Loguru useful, please let me know on twitter or in a mail! -Twitter: @ernerfeldt -Mail: emil.ernerfeldt@gmail.com -Website: www.ilikebigbits.com - -# License - This software is in the public domain. Where that dedication is not - recognized, you are granted a perpetual, irrevocable license to - copy, modify and distribute it as you see fit. - -# Inspiration - Much of Loguru was inspired by GLOG, https://code.google.com/p/google-glog/. - The choice of public domain is fully due Sean T. Barrett - and his wonderful stb libraries at https://github.com/nothings/stb. - -# Version history - * Version 0.1.0 - 2015-03-22 - Works great on Mac. - * Version 0.2.0 - 2015-09-17 - Removed the only dependency. - * Version 0.3.0 - 2015-10-02 - Drop-in replacement for most of GLOG - * Version 0.4.0 - 2015-10-07 - Single-file! - * Version 0.5.0 - 2015-10-17 - Improved file logging - * Version 0.6.0 - 2015-10-24 - Add stack traces - * Version 0.7.0 - 2015-10-27 - Signals - * Version 0.8.0 - 2015-10-30 - Color logging. - * Version 0.9.0 - 2015-11-26 - ABORT_S and proper handling of FATAL - * Version 1.0.0 - 2016-02-14 - ERROR_CONTEXT - * Version 1.1.0 - 2016-02-19 - -v OFF, -v INFO etc - * Version 1.1.1 - 2016-02-20 - textprintf vs strprintf - * Version 1.1.2 - 2016-02-22 - Remove g_alsologtostderr - * Version 1.1.3 - 2016-02-29 - ERROR_CONTEXT as linked list - * Version 1.2.0 - 2016-03-19 - Add get_thread_name() - * Version 1.2.1 - 2016-03-20 - Minor fixes - * Version 1.2.2 - 2016-03-29 - Fix issues with set_fatal_handler throwing an exception - * Version 1.2.3 - 2016-05-16 - Log current working directory in loguru::init(). - * Version 1.2.4 - 2016-05-18 - Custom replacement for -v in loguru::init() by bjoernpollex - * Version 1.2.5 - 2016-05-18 - Add ability to print ERROR_CONTEXT of parent thread. - * Version 1.2.6 - 2016-05-19 - Bug fix regarding VLOG verbosity argument lacking (). - * Version 1.2.7 - 2016-05-23 - Fix PATH_MAX problem. 
- * Version 1.2.8 - 2016-05-26 - Add shutdown() and remove_all_callbacks() - * Version 1.2.9 - 2016-06-09 - Use a monotonic clock for uptime. - * Version 1.3.0 - 2016-07-20 - Fix issues with callback flush/close not being called. - * Version 1.3.1 - 2016-07-20 - Add LOGURU_UNSAFE_SIGNAL_HANDLER to toggle stacktrace on signals. - * Version 1.3.2 - 2016-07-20 - Add loguru::arguments() - * Version 1.4.0 - 2016-09-15 - Semantic versioning + add loguru::create_directories - * Version 1.4.1 - 2016-09-29 - Customize formating with LOGURU_FILENAME_WIDTH - * Version 1.5.0 - 2016-12-22 - LOGURU_USE_FMTLIB by kolis and LOGURU_WITH_FILEABS by scinart - * Version 1.5.1 - 2017-08-08 - Terminal colors on Windows 10 thanks to looki - * Version 1.6.0 - 2018-01-03 - Add LOGURU_RTTI and LOGURU_STACKTRACES settings - * Version 1.7.0 - 2018-01-03 - Add ability to turn off the preamble with loguru::g_preamble - * Version 1.7.1 - 2018-04-05 - Add function get_fatal_handler - * Version 1.7.2 - 2018-04-22 - Fix a bug where large file names could cause stack corruption (thanks @ccamporesi) - * Version 1.8.0 - 2018-04-23 - Shorten long file names to keep preamble fixed width - * Version 1.9.0 - 2018-09-22 - Adjust terminal colors, add LOGURU_VERBOSE_SCOPE_ENDINGS, add LOGURU_SCOPE_TIME_PRECISION, add named log levels - * Version 2.0.0 - 2018-09-22 - Split loguru.hpp into loguru.hpp and loguru.cpp - * Version 2.1.0 - 2019-09-23 - Update fmtlib + add option to loguru::init to NOT set main thread name. - * Version 2.2.0 - 2020-07-31 - Replace LOGURU_CATCH_SIGABRT with struct SignalOptions - -# Compiling - Just include where you want to use Loguru. 
- Then, in one .cpp file #include - Make sure you compile with -std=c++11 -lstdc++ -lpthread -ldl - -# Usage - For details, please see the official documentation at emilk.github.io/loguru - - #include - - int main(int argc, char* argv[]) { - loguru::init(argc, argv); - - // Put every log message in "everything.log": - loguru::add_file("everything.log", loguru::Append, loguru::Verbosity_MAX); - - LOG_F(INFO, "The magic number is %d", 42); - } - -*/ - -#if defined(LOGURU_IMPLEMENTATION) -#error "You are defining LOGURU_IMPLEMENTATION. This is for older versions of Loguru. You should now instead include loguru.cpp (or build it and link with it)" -#endif - -// Disable all warnings from gcc/clang: -#if defined(__clang__) -#pragma clang system_header -#elif defined(__GNUC__) -#pragma GCC system_header -#endif - -#ifndef LOGURU_HAS_DECLARED_FORMAT_HEADER -#define LOGURU_HAS_DECLARED_FORMAT_HEADER - -// Semantic versioning. Loguru version can be printed with printf("%d.%d.%d", LOGURU_VERSION_MAJOR, LOGURU_VERSION_MINOR, LOGURU_VERSION_PATCH); -#define LOGURU_VERSION_MAJOR 2 -#define LOGURU_VERSION_MINOR 1 -#define LOGURU_VERSION_PATCH 0 - -#if defined(_MSC_VER) -#include // Needed for _In_z_ etc annotations -#endif - -#if defined(__linux__) || defined(__APPLE__) -#define LOGURU_SYSLOG 1 -#else -#define LOGURU_SYSLOG 0 -#endif - -// ---------------------------------------------------------------------------- - -#ifndef LOGURU_EXPORT - // Define to your project's export declaration if needed for use in a shared library. -#define LOGURU_EXPORT -#endif - -#ifndef LOGURU_SCOPE_TEXT_SIZE - // Maximum length of text that can be printed by a LOG_SCOPE. - // This should be long enough to get most things, but short enough not to clutter the stack. 
-#define LOGURU_SCOPE_TEXT_SIZE 196 -#endif - -#ifndef LOGURU_FILENAME_WIDTH - // Width of the column containing the file name -#define LOGURU_FILENAME_WIDTH 23 -#endif - -#ifndef LOGURU_THREADNAME_WIDTH - // Width of the column containing the thread name -#define LOGURU_THREADNAME_WIDTH 16 -#endif - -#ifndef LOGURU_SCOPE_TIME_PRECISION - // Resolution of scope timers. 3=ms, 6=us, 9=ns -#define LOGURU_SCOPE_TIME_PRECISION 3 -#endif - -#ifdef LOGURU_CATCH_SIGABRT -#error "You are defining LOGURU_CATCH_SIGABRT. This is for older versions of Loguru. You should now instead set the options passed to loguru::init" -#endif - -#ifndef LOGURU_VERBOSE_SCOPE_ENDINGS - // Show milliseconds and scope name at end of scope. -#define LOGURU_VERBOSE_SCOPE_ENDINGS 1 -#endif - -#ifndef LOGURU_REDEFINE_ASSERT -#define LOGURU_REDEFINE_ASSERT 0 -#endif - -#ifndef LOGURU_WITH_STREAMS -#define LOGURU_WITH_STREAMS 0 -#endif - -#ifndef LOGURU_REPLACE_GLOG -#define LOGURU_REPLACE_GLOG 0 -#endif - -#if LOGURU_REPLACE_GLOG -#undef LOGURU_WITH_STREAMS -#define LOGURU_WITH_STREAMS 1 -#endif - -#if defined(LOGURU_UNSAFE_SIGNAL_HANDLER) -#error "You are defining LOGURU_UNSAFE_SIGNAL_HANDLER. This is for older versions of Loguru. You should now instead set the unsafe_signal_handler option when you call loguru::init." 
-#endif - -#if LOGURU_IMPLEMENTATION -#undef LOGURU_WITH_STREAMS -#define LOGURU_WITH_STREAMS 1 -#endif - -#ifndef LOGURU_USE_FMTLIB -#define LOGURU_USE_FMTLIB 0 -#endif - -#ifndef LOGURU_USE_LOCALE -#define LOGURU_USE_LOCALE 0 -#endif - -#ifndef LOGURU_WITH_FILEABS -#define LOGURU_WITH_FILEABS 0 -#endif - -#ifndef LOGURU_RTTI -#if defined(__clang__) -#if __has_feature(cxx_rtti) -#define LOGURU_RTTI 1 -#endif -#elif defined(__GNUG__) -#if defined(__GXX_RTTI) -#define LOGURU_RTTI 1 -#endif -#elif defined(_MSC_VER) -#if defined(_CPPRTTI) -#define LOGURU_RTTI 1 -#endif -#endif -#endif - -#ifdef LOGURU_USE_ANONYMOUS_NAMESPACE -#define LOGURU_ANONYMOUS_NAMESPACE_BEGIN namespace { -#define LOGURU_ANONYMOUS_NAMESPACE_END } -#else -#define LOGURU_ANONYMOUS_NAMESPACE_BEGIN -#define LOGURU_ANONYMOUS_NAMESPACE_END -#endif - -// -------------------------------------------------------------------- -// Utility macros - -#define LOGURU_CONCATENATE_IMPL(s1, s2) s1 ## s2 -#define LOGURU_CONCATENATE(s1, s2) LOGURU_CONCATENATE_IMPL(s1, s2) - -#ifdef __COUNTER__ -# define LOGURU_ANONYMOUS_VARIABLE(str) LOGURU_CONCATENATE(str, __COUNTER__) -#else -# define LOGURU_ANONYMOUS_VARIABLE(str) LOGURU_CONCATENATE(str, __LINE__) -#endif - -#if defined(__clang__) || defined(__GNUC__) - // Helper macro for declaring functions as having similar signature to printf. - // This allows the compiler to catch format errors at compile-time. -#define LOGURU_PRINTF_LIKE(fmtarg, firstvararg) __attribute__((__format__ (__printf__, fmtarg, firstvararg))) -#define LOGURU_FORMAT_STRING_TYPE const char* -#elif defined(_MSC_VER) -#define LOGURU_PRINTF_LIKE(fmtarg, firstvararg) -#define LOGURU_FORMAT_STRING_TYPE _In_z_ _Printf_format_string_ const char* -#else -#define LOGURU_PRINTF_LIKE(fmtarg, firstvararg) -#define LOGURU_FORMAT_STRING_TYPE const char* -#endif - -// Used to mark log_and_abort for the benefit of the static analyzer and optimizer. 
-#if defined(_MSC_VER) -#define LOGURU_NORETURN __declspec(noreturn) -#else -#define LOGURU_NORETURN __attribute__((noreturn)) -#endif - -#if defined(_MSC_VER) -#define LOGURU_PREDICT_FALSE(x) (x) -#define LOGURU_PREDICT_TRUE(x) (x) -#else -#define LOGURU_PREDICT_FALSE(x) (__builtin_expect(x, 0)) -#define LOGURU_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) -#endif - -#if LOGURU_USE_FMTLIB -#include -#define LOGURU_FMT(x) "{:" #x "}" -#else -#define LOGURU_FMT(x) "%" #x -#endif - -#ifdef _WIN32 -#define STRDUP(str) _strdup(str) -#else -#define STRDUP(str) strdup(str) -#endif - -#include - -// -------------------------------------------------------------------- -LOGURU_ANONYMOUS_NAMESPACE_BEGIN - -namespace loguru { - // Simple RAII ownership of a char*. - class LOGURU_EXPORT Text { - public: - explicit Text(char* owned_str) : _str(owned_str) {} - ~Text(); - Text(Text&& t) - { - _str = t._str; - t._str = nullptr; - } - Text(Text& t) = delete; - Text& operator=(Text& t) = delete; - void operator=(Text&& t) = delete; - - const char* c_str() const { return _str; } - bool empty() const { return _str == nullptr || *_str == '\0'; } - - char* release() - { - auto result = _str; - _str = nullptr; - return result; - } - - private: - char* _str; - }; - - // Like printf, but returns the formated text. -#if LOGURU_USE_FMTLIB - LOGURU_EXPORT - Text vtextprintf(const char* format, fmt::format_args args); - - template - LOGURU_EXPORT - Text textprintf(LOGURU_FORMAT_STRING_TYPE format, const Args&... args) - { - return vtextprintf(format, fmt::make_format_args(args...)); - } -#else - LOGURU_EXPORT - Text textprintf(LOGURU_FORMAT_STRING_TYPE format, ...) LOGURU_PRINTF_LIKE(1, 2); -#endif - - // Overloaded for variadic template matching. - LOGURU_EXPORT - Text textprintf(); - - using Verbosity = int; - -#undef FATAL -#undef ERROR -#undef WARNING -#undef INFO -#undef MAX - - enum NamedVerbosity : Verbosity { - // Used to mark an invalid verbosity. Do not log to this level. 
- Verbosity_INVALID = -10, // Never do LOG_F(INVALID) - - // You may use Verbosity_OFF on g_stderr_verbosity, but for nothing else! - Verbosity_OFF = -9, // Never do LOG_F(OFF) - - // Prefer to use ABORT_F or ABORT_S over LOG_F(FATAL) or LOG_S(FATAL). - Verbosity_FATAL = -3, - Verbosity_ERROR = -2, - Verbosity_WARNING = -1, - - // Normal messages. By default written to stderr. - Verbosity_INFO = 0, - - // Same as Verbosity_INFO in every way. - Verbosity_0 = 0, - - // Verbosity levels 1-9 are generally not written to stderr, but are written to file. - Verbosity_1 = +1, - Verbosity_2 = +2, - Verbosity_3 = +3, - Verbosity_4 = +4, - Verbosity_5 = +5, - Verbosity_6 = +6, - Verbosity_7 = +7, - Verbosity_8 = +8, - Verbosity_9 = +9, - - // Do not use higher verbosity levels, as that will make grepping log files harder. - Verbosity_MAX = +9, - }; - - struct Message { - // You would generally print a Message by just concatenating the buffers without spacing. - // Optionally, ignore preamble and indentation. - Verbosity verbosity; // Already part of preamble - const char* filename; // Already part of preamble - unsigned line; // Already part of preamble - const char* preamble; // Date, time, uptime, thread, file:line, verbosity. - const char* indentation; // Just a bunch of spacing. - const char* prefix; // Assertion failure info goes here (or ""). - const char* message; // User message goes here. - }; - - /* Everything with a verbosity equal or greater than g_stderr_verbosity will be - written to stderr. You can set this in code or via the -v argument. - Set to loguru::Verbosity_OFF to write nothing to stderr. - Default is 0, i.e. only log ERROR, WARNING and INFO are written to stderr. - */ - LOGURU_EXPORT extern Verbosity g_stderr_verbosity; - LOGURU_EXPORT extern bool g_colorlogtostderr; // True by default. - LOGURU_EXPORT extern unsigned g_flush_interval_ms; // 0 (unbuffered) by default. 
- LOGURU_EXPORT extern bool g_preamble_header; // Prepend each log start by a descriptions line with all columns name? True by default. - LOGURU_EXPORT extern bool g_preamble; // Prefix each log line with date, time etc? True by default. - - /* Specify the verbosity used by loguru to log its info messages including the header - logged when logged::init() is called or on exit. Default is 0 (INFO). - */ - LOGURU_EXPORT extern Verbosity g_internal_verbosity; - - // Turn off individual parts of the preamble - LOGURU_EXPORT extern bool g_preamble_date; // The date field - LOGURU_EXPORT extern bool g_preamble_time; // The time of the current day - LOGURU_EXPORT extern bool g_preamble_uptime; // The time since init call - LOGURU_EXPORT extern bool g_preamble_thread; // The logging thread - LOGURU_EXPORT extern bool g_preamble_file; // The file from which the log originates from - LOGURU_EXPORT extern bool g_preamble_verbose; // The verbosity field - LOGURU_EXPORT extern bool g_preamble_pipe; // The pipe symbol right before the message - - // May not throw! - typedef void (*log_handler_t)(void* user_data, const Message& message); - typedef void (*close_handler_t)(void* user_data); - typedef void (*flush_handler_t)(void* user_data); - - // May throw if that's how you'd like to handle your errors. - typedef void (*fatal_handler_t)(const Message& message); - - // Given a verbosity level, return the level's name or nullptr. - typedef const char* (*verbosity_to_name_t)(Verbosity verbosity); - - // Given a verbosity level name, return the verbosity level or - // Verbosity_INVALID if name is not recognized. - typedef Verbosity(*name_to_verbosity_t)(const char* name); - - struct SignalOptions { - /// Make Loguru try to do unsafe but useful things, - /// like printing a stack trace, when catching signals. - /// This may lead to bad things like deadlocks in certain situations. - bool unsafe_signal_handler = true; - - /// Should Loguru catch SIGABRT ? 
- bool sigabrt = true; - - /// Should Loguru catch SIGBUS ? - bool sigbus = true; - - /// Should Loguru catch SIGFPE ? - bool sigfpe = true; - - /// Should Loguru catch SIGILL ? - bool sigill = true; - - /// Should Loguru catch SIGINT ? - bool sigint = true; - - /// Should Loguru catch SIGSEGV ? - bool sigsegv = true; - - /// Should Loguru catch SIGTERM ? - bool sigterm = true; - - static SignalOptions none() - { - SignalOptions options; - options.unsafe_signal_handler = false; - options.sigabrt = false; - options.sigbus = false; - options.sigfpe = false; - options.sigill = false; - options.sigint = false; - options.sigsegv = false; - options.sigterm = false; - return options; - } - }; - - // Runtime options passed to loguru::init - struct Options { - // This allows you to use something else instead of "-v" via verbosity_flag. - // Set to nullptr if you don't want Loguru to parse verbosity from the args. - const char* verbosity_flag = "-v"; - - // loguru::init will set the name of the calling thread to this. - // If you don't want Loguru to set the name of the main thread, - // set this to nullptr. - // NOTE: on SOME platforms loguru::init will only overwrite the thread name - // if a thread name has not already been set. - // To always set a thread name, use loguru::set_thread_name instead. - const char* main_thread_name = "main thread"; - - SignalOptions signal_options; - }; - - /* Should be called from the main thread. - You don't *need* to call this, but if you do you get: - * Signal handlers installed - * Program arguments logged - * Working dir logged - * Optional -v verbosity flag parsed - * Main thread name set to "main thread" - * Explanation of the preamble (date, thread name, etc) logged - - loguru::init() will look for arguments meant for loguru and remove them. - Arguments meant for loguru are: - -v n Set loguru::g_stderr_verbosity level. Examples: - -v 3 Show verbosity level 3 and lower. - -v 0 Only show INFO, WARNING, ERROR, FATAL (default). 
- -v INFO Only show INFO, WARNING, ERROR, FATAL (default). - -v WARNING Only show WARNING, ERROR, FATAL. - -v ERROR Only show ERROR, FATAL. - -v FATAL Only show FATAL. - -v OFF Turn off logging to stderr. - - Tip: You can set g_stderr_verbosity before calling loguru::init. - That way you can set the default but have the user override it with the -v flag. - Note that -v does not affect file logging (see loguru::add_file). - - You can you something other than the -v flag by setting the verbosity_flag option. - */ - LOGURU_EXPORT - void init(int& argc, char* argv[], const Options& options = {}); - - // Will call remove_all_callbacks(). After calling this, logging will still go to stderr. - // You generally don't need to call this. - LOGURU_EXPORT - void shutdown(); - - // What ~ will be replaced with, e.g. "/home/your_user_name/" - LOGURU_EXPORT - const char* home_dir(); - - /* Returns the name of the app as given in argv[0] but without leading path. - That is, if argv[0] is "../foo/app" this will return "app". - */ - LOGURU_EXPORT - const char* argv0_filename(); - - // Returns all arguments given to loguru::init(), but escaped with a single space as separator. - LOGURU_EXPORT - const char* arguments(); - - // Returns the path to the current working dir when loguru::init() was called. - LOGURU_EXPORT - const char* current_dir(); - - // Returns the part of the path after the last / or \ (if any). - LOGURU_EXPORT - const char* filename(const char* path); - - // e.g. "foo/bar/baz.ext" will create the directories "foo/" and "foo/bar/" - LOGURU_EXPORT - bool create_directories(const char* file_path_const); - - // Writes date and time with millisecond precision, e.g. "20151017_161503.123" - LOGURU_EXPORT - void write_date_time(char* buff, unsigned long long buff_size); - - // Helper: thread-safe version strerror - LOGURU_EXPORT - Text errno_as_text(); - - /* Given a prefix of e.g. 
"~/loguru/" this might return - "/home/your_username/loguru/app_name/20151017_161503.123.log" - - where "app_name" is a sanitized version of argv[0]. - */ - LOGURU_EXPORT - void suggest_log_path(const char* prefix, char* buff, unsigned long long buff_size); - - enum FileMode { Truncate, Append }; - - /* Will log to a file at the given path. - Any logging message with a verbosity lower or equal to - the given verbosity will be included. - The function will create all directories in 'path' if needed. - If path starts with a ~, it will be replaced with loguru::home_dir() - To stop the file logging, just call loguru::remove_callback(path) with the same path. - */ - LOGURU_EXPORT - bool add_file(const char* path, FileMode mode, Verbosity verbosity); - - LOGURU_EXPORT - // Send logs to syslog with LOG_USER facility (see next call) - bool add_syslog(const char* app_name, Verbosity verbosity); - LOGURU_EXPORT - // Send logs to syslog with your own choice of facility (LOG_USER, LOG_AUTH, ...) - // see loguru.cpp: syslog_log() for more details. - bool add_syslog(const char* app_name, Verbosity verbosity, int facility); - - /* Will be called right before abort(). - You can for instance use this to print custom error messages, or throw an exception. - Feel free to call LOG:ing function from this, but not FATAL ones! */ - LOGURU_EXPORT - void set_fatal_handler(fatal_handler_t handler); - - // Get the current fatal handler, if any. Default value is nullptr. - LOGURU_EXPORT - fatal_handler_t get_fatal_handler(); - - /* Will be called on each log messages with a verbosity less or equal to the given one. - Useful for displaying messages on-screen in a game, for example. - The given on_close is also expected to flush (if desired). 
- */ - LOGURU_EXPORT - void add_callback( - const char* id, - log_handler_t callback, - void* user_data, - Verbosity verbosity, - close_handler_t on_close = nullptr, - flush_handler_t on_flush = nullptr); - - /* Set a callback that returns custom verbosity level names. If callback - is nullptr or returns nullptr, default log names will be used. - */ - LOGURU_EXPORT - void set_verbosity_to_name_callback(verbosity_to_name_t callback); - - /* Set a callback that returns the verbosity level matching a name. The - callback should return Verbosity_INVALID if the name is not - recognized. - */ - LOGURU_EXPORT - void set_name_to_verbosity_callback(name_to_verbosity_t callback); - - /* Get a custom name for a specific verbosity, if one exists, or nullptr. */ - LOGURU_EXPORT - const char* get_verbosity_name(Verbosity verbosity); - - /* Get the verbosity enum value from a custom 4-character level name, if one exists. - If the name does not match a custom level name, Verbosity_INVALID is returned. - */ - LOGURU_EXPORT - Verbosity get_verbosity_from_name(const char* name); - - // Returns true iff the callback was found (and removed). - LOGURU_EXPORT - bool remove_callback(const char* id); - - // Shut down all file logging and any other callback hooks installed. - LOGURU_EXPORT - void remove_all_callbacks(); - - // Returns the maximum of g_stderr_verbosity and all file/custom outputs. - LOGURU_EXPORT - Verbosity current_verbosity_cutoff(); - -#if LOGURU_USE_FMTLIB - // Internal functions - LOGURU_EXPORT - void vlog(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, fmt::format_args args); - LOGURU_EXPORT - void raw_vlog(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, fmt::format_args args); - - // Actual logging function. Use the LOG macro instead of calling this directly. 
- template - LOGURU_EXPORT - void log(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, const Args &... args) - { - vlog(verbosity, file, line, format, fmt::make_format_args(args...)); - } - - // Log without any preamble or indentation. - template - LOGURU_EXPORT - void raw_log(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, const Args &... args) - { - raw_vlog(verbosity, file, line, format, fmt::make_format_args(args...)); - } -#else // LOGURU_USE_FMTLIB? - // Actual logging function. Use the LOG macro instead of calling this directly. - LOGURU_EXPORT - void log(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, ...) LOGURU_PRINTF_LIKE(4, 5); - - // Actual logging function. - LOGURU_EXPORT - void vlog(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, va_list) LOGURU_PRINTF_LIKE(4, 0); - - // Log without any preamble or indentation. - LOGURU_EXPORT - void raw_log(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, ...) LOGURU_PRINTF_LIKE(4, 5); -#endif // !LOGURU_USE_FMTLIB - - // Helper class for LOG_SCOPE_F - class LOGURU_EXPORT LogScopeRAII { - public: - LogScopeRAII() : _file(nullptr) {} // No logging - LogScopeRAII(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, va_list vlist) LOGURU_PRINTF_LIKE(5, 0); - LogScopeRAII(Verbosity verbosity, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, ...) LOGURU_PRINTF_LIKE(5, 6); - ~LogScopeRAII(); - - void Init(LOGURU_FORMAT_STRING_TYPE format, va_list vlist) LOGURU_PRINTF_LIKE(2, 0); - -#if defined(_MSC_VER) && _MSC_VER > 1800 - // older MSVC default move ctors close the scope on move. 
See - // issue #43 - LogScopeRAII(LogScopeRAII&& other) - : _verbosity(other._verbosity) - , _file(other._file) - , _line(other._line) - , _indent_stderr(other._indent_stderr) - , _start_time_ns(other._start_time_ns) - { - // Make sure the tmp object's destruction doesn't close the scope: - other._file = nullptr; - - for (unsigned int i = 0; i < LOGURU_SCOPE_TEXT_SIZE; ++i) { - _name[i] = other._name[i]; - } - } -#else - LogScopeRAII(LogScopeRAII&&) = default; -#endif - - private: - LogScopeRAII(const LogScopeRAII&) = delete; - LogScopeRAII& operator=(const LogScopeRAII&) = delete; - void operator=(LogScopeRAII&&) = delete; - - Verbosity _verbosity; - const char* _file; // Set to null if we are disabled due to verbosity - unsigned _line; - bool _indent_stderr; // Did we? - long long _start_time_ns; - char _name[LOGURU_SCOPE_TEXT_SIZE]; - }; - - // Marked as 'noreturn' for the benefit of the static analyzer and optimizer. - // stack_trace_skip is the number of extrace stack frames to skip above log_and_abort. -#if LOGURU_USE_FMTLIB - LOGURU_EXPORT - LOGURU_NORETURN void vlog_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, fmt::format_args); - template - LOGURU_EXPORT - LOGURU_NORETURN void log_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, const Args&... args) - { - vlog_and_abort(stack_trace_skip, expr, file, line, format, fmt::make_format_args(args...)); - } -#else - LOGURU_EXPORT - LOGURU_NORETURN void log_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line, LOGURU_FORMAT_STRING_TYPE format, ...) LOGURU_PRINTF_LIKE(5, 6); -#endif - LOGURU_EXPORT - LOGURU_NORETURN void log_and_abort(int stack_trace_skip, const char* expr, const char* file, unsigned line); - - // Flush output to stderr and files. - // If g_flush_interval_ms is set to non-zero, this will be called automatically this often. 
- // If not set, you do not need to call this at all. - LOGURU_EXPORT - void flush(); - - template inline Text format_value(const T&) { return textprintf("N/A"); } - template<> inline Text format_value(const char& v) { return textprintf(LOGURU_FMT(c), v); } - template<> inline Text format_value(const int& v) { return textprintf(LOGURU_FMT(d), v); } - template<> inline Text format_value(const float& v) { return textprintf(LOGURU_FMT(f), v); } - template<> inline Text format_value(const double& v) { return textprintf(LOGURU_FMT(f), v); } - -#if LOGURU_USE_FMTLIB - template<> inline Text format_value(const unsigned int& v) { return textprintf(LOGURU_FMT(d), v); } - template<> inline Text format_value(const long& v) { return textprintf(LOGURU_FMT(d), v); } - template<> inline Text format_value(const unsigned long& v) { return textprintf(LOGURU_FMT(d), v); } - template<> inline Text format_value(const long long& v) { return textprintf(LOGURU_FMT(d), v); } - template<> inline Text format_value(const unsigned long long& v) { return textprintf(LOGURU_FMT(d), v); } -#else - template<> inline Text format_value(const unsigned int& v) { return textprintf(LOGURU_FMT(u), v); } - template<> inline Text format_value(const long& v) { return textprintf(LOGURU_FMT(lu), v); } - template<> inline Text format_value(const unsigned long& v) { return textprintf(LOGURU_FMT(ld), v); } - template<> inline Text format_value(const long long& v) { return textprintf(LOGURU_FMT(llu), v); } - template<> inline Text format_value(const unsigned long long& v) { return textprintf(LOGURU_FMT(lld), v); } -#endif - - /* Thread names can be set for the benefit of readable logs. - If you do not set the thread name, a hex id will be shown instead. - These thread names may or may not be the same as the system thread names, - depending on the system. - Try to limit the thread name to 15 characters or less. 
*/ - LOGURU_EXPORT - void set_thread_name(const char* name); - - /* Returns the thread name for this thread. - On most *nix systems this will return the system thread name (settable from both within and without Loguru). - On other systems it will return whatever you set in `set_thread_name()`; - If no thread name is set, this will return a hexadecimal thread id. - `length` should be the number of bytes available in the buffer. - 17 is a good number for length. - `right_align_hex_id` means any hexadecimal thread id will be written to the end of buffer. - */ - LOGURU_EXPORT - void get_thread_name(char* buffer, unsigned long long length, bool right_align_hex_id); - - /* Generates a readable stacktrace as a string. - 'skip' specifies how many stack frames to skip. - For instance, the default skip (1) means: - don't include the call to loguru::stacktrace in the stack trace. */ - LOGURU_EXPORT - Text stacktrace(int skip = 1); - - /* Add a string to be replaced with something else in the stack output. - - For instance, instead of having a stack trace look like this: - 0x41f541 some_function(std::basic_ofstream >&) - You can clean it up with: - auto verbose_type_name = loguru::demangle(typeid(std::ofstream).name()); - loguru::add_stack_cleanup(verbose_type_name.c_str(); "std::ofstream"); - So the next time you will instead see: - 0x41f541 some_function(std::ofstream&) - - `replace_with_this` must be shorter than `find_this`. - */ - LOGURU_EXPORT - void add_stack_cleanup(const char* find_this, const char* replace_with_this); - - // Example: demangle(typeid(std::ofstream).name()) -> "std::basic_ofstream >" - LOGURU_EXPORT - Text demangle(const char* name); - - // ------------------------------------------------------------------------ - /* - Not all terminals support colors, but if they do, and g_colorlogtostderr - is set, Loguru will write them to stderr to make errors in red, etc. - - You also have the option to manually use them, via the function below. 
- - Note, however, that if you do, the color codes could end up in your logfile! - - This means if you intend to use them functions you should either: - * Use them on the stderr/stdout directly (bypass Loguru). - * Don't add file outputs to Loguru. - * Expect some \e[1m things in your logfile. - - Usage: - printf("%sRed%sGreen%sBold green%sClear again\n", - loguru::terminal_red(), loguru::terminal_green(), - loguru::terminal_bold(), loguru::terminal_reset()); - - If the terminal at hand does not support colors the above output - will just not have funky \e[1m things showing. - */ - - // Do the output terminal support colors? - LOGURU_EXPORT - bool terminal_has_color(); - - // Colors - LOGURU_EXPORT const char* terminal_black(); - LOGURU_EXPORT const char* terminal_red(); - LOGURU_EXPORT const char* terminal_green(); - LOGURU_EXPORT const char* terminal_yellow(); - LOGURU_EXPORT const char* terminal_blue(); - LOGURU_EXPORT const char* terminal_purple(); - LOGURU_EXPORT const char* terminal_cyan(); - LOGURU_EXPORT const char* terminal_light_gray(); - LOGURU_EXPORT const char* terminal_light_red(); - LOGURU_EXPORT const char* terminal_white(); - - // Formating - LOGURU_EXPORT const char* terminal_bold(); - LOGURU_EXPORT const char* terminal_underline(); - - // You should end each line with this! - LOGURU_EXPORT const char* terminal_reset(); - - // -------------------------------------------------------------------- - // Error context related: - - struct StringStream; - - // Use this in your EcEntryBase::print_value overload. 
- LOGURU_EXPORT - void stream_print(StringStream& out_string_stream, const char* text); - - class LOGURU_EXPORT EcEntryBase { - public: - EcEntryBase(const char* file, unsigned line, const char* descr); - ~EcEntryBase(); - EcEntryBase(const EcEntryBase&) = delete; - EcEntryBase(EcEntryBase&&) = delete; - EcEntryBase& operator=(const EcEntryBase&) = delete; - EcEntryBase& operator=(EcEntryBase&&) = delete; - - virtual void print_value(StringStream& out_string_stream) const = 0; - - EcEntryBase* previous() const { return _previous; } - - // private: - const char* _file; - unsigned _line; - const char* _descr; - EcEntryBase* _previous; - }; - - template - class EcEntryData : public EcEntryBase { - public: - using Printer = Text(*)(T data); - - EcEntryData(const char* file, unsigned line, const char* descr, T data, Printer&& printer) - : EcEntryBase(file, line, descr), _data(data), _printer(printer) - { - } - - virtual void print_value(StringStream& out_string_stream) const override - { - const auto str = _printer(_data); - stream_print(out_string_stream, str.c_str()); - } - - private: - T _data; - Printer _printer; - }; - - // template - // class EcEntryLambda : public EcEntryBase - // { - // public: - // EcEntryLambda(const char* file, unsigned line, const char* descr, Printer&& printer) - // : EcEntryBase(file, line, descr), _printer(std::move(printer)) {} - - // virtual void print_value(StringStream& out_string_stream) const override - // { - // const auto str = _printer(); - // stream_print(out_string_stream, str.c_str()); - // } - - // private: - // Printer _printer; - // }; - - // template - // EcEntryLambda make_ec_entry_lambda(const char* file, unsigned line, const char* descr, Printer&& printer) - // { - // return {file, line, descr, std::move(printer)}; - // } - - template - struct decay_char_array { using type = T; }; - - template - struct decay_char_array { using type = const char*; }; - - template - struct make_const_ptr { using type = T; }; - - template 
- struct make_const_ptr { using type = const T*; }; - - template - struct make_ec_type { using type = typename make_const_ptr::type>::type; }; - - /* A stack trace gives you the names of the function at the point of a crash. - With ERROR_CONTEXT, you can also get the values of select local variables. - Usage: - - void process_customers(const std::string& filename) - { - ERROR_CONTEXT("Processing file", filename.c_str()); - for (int customer_index : ...) - { - ERROR_CONTEXT("Customer index", customer_index); - ... - } - } - - The context is in effect during the scope of the ERROR_CONTEXT. - Use loguru::get_error_context() to get the contents of the active error contexts. - - Example result: - - ------------------------------------------------ - [ErrorContext] main.cpp:416 Processing file: "customers.json" - [ErrorContext] main.cpp:417 Customer index: 42 - ------------------------------------------------ - - Error contexts are printed automatically on crashes, and only on crashes. - This makes them much faster than logging the value of a variable. - */ -#define ERROR_CONTEXT(descr, data) \ - const loguru::EcEntryData::type> \ - LOGURU_ANONYMOUS_VARIABLE(error_context_scope_)( \ - __FILE__, __LINE__, descr, data, \ - static_cast::type>::Printer>(loguru::ec_to_text) ) // For better error messages - - /* - #define ERROR_CONTEXT(descr, data) \ - const auto LOGURU_ANONYMOUS_VARIABLE(error_context_scope_)( \ - loguru::make_ec_entry_lambda(__FILE__, __LINE__, descr, \ - [=](){ return loguru::ec_to_text(data); })) - */ - - using EcHandle = const EcEntryBase*; - - /* - Get a light-weight handle to the error context stack on this thread. - The handle is valid as long as the current thread has no changes to its error context stack. - You can pass the handle to loguru::get_error_context on another thread. 
- This can be very useful for when you have a parent thread spawning several working threads, - and you want the error context of the parent thread to get printed (too) when there is an - error on the child thread. You can accomplish this thusly: - - void foo(const char* parameter) - { - ERROR_CONTEXT("parameter", parameter) - const auto parent_ec_handle = loguru::get_thread_ec_handle(); - - std::thread([=]{ - loguru::set_thread_name("child thread"); - ERROR_CONTEXT("parent context", parent_ec_handle); - dangerous_code(); - }.join(); - } - - */ - LOGURU_EXPORT - EcHandle get_thread_ec_handle(); - - // Get a string describing the current stack of error context. Empty string if there is none. - LOGURU_EXPORT - Text get_error_context(); - - // Get a string describing the error context of the given thread handle. - LOGURU_EXPORT - Text get_error_context_for(EcHandle ec_handle); - - // ------------------------------------------------------------------------ - - LOGURU_EXPORT Text ec_to_text(const char* data); - LOGURU_EXPORT Text ec_to_text(char data); - LOGURU_EXPORT Text ec_to_text(int data); - LOGURU_EXPORT Text ec_to_text(unsigned int data); - LOGURU_EXPORT Text ec_to_text(long data); - LOGURU_EXPORT Text ec_to_text(unsigned long data); - LOGURU_EXPORT Text ec_to_text(long long data); - LOGURU_EXPORT Text ec_to_text(unsigned long long data); - LOGURU_EXPORT Text ec_to_text(float data); - LOGURU_EXPORT Text ec_to_text(double data); - LOGURU_EXPORT Text ec_to_text(long double data); - LOGURU_EXPORT Text ec_to_text(EcHandle); - - /* - You can add ERROR_CONTEXT support for your own types by overloading ec_to_text. Here's how: - - some.hpp: - namespace loguru { - Text ec_to_text(MySmallType data) - Text ec_to_text(const MyBigType* data) - } // namespace loguru - - some.cpp: - namespace loguru { - Text ec_to_text(MySmallType small_value) - { - // Called only when needed, i.e. on a crash. - std::string str = small_value.as_string(); // Format 'small_value' here somehow. 
- return Text{STRDUP(str.c_str())}; - } - - Text ec_to_text(const MyBigType* big_value) - { - // Called only when needed, i.e. on a crash. - std::string str = big_value->as_string(); // Format 'big_value' here somehow. - return Text{STRDUP(str.c_str())}; - } - } // namespace loguru - - Any file that include some.hpp: - void foo(MySmallType small, const MyBigType& big) - { - ERROR_CONTEXT("Small", small); // Copy ´small` by value. - ERROR_CONTEXT("Big", &big); // `big` should not change during this scope! - .... - } - */ -} // namespace loguru - -LOGURU_ANONYMOUS_NAMESPACE_END - -// -------------------------------------------------------------------- -// Logging macros - -// LOG_F(2, "Only logged if verbosity is 2 or higher: %d", some_number); -#define VLOG_F(verbosity, ...) \ - ((verbosity) > loguru::current_verbosity_cutoff()) ? (void)0 \ - : loguru::log(verbosity, __FILE__, __LINE__, __VA_ARGS__) - -// LOG_F(INFO, "Foo: %d", some_number); -#define LOG_F(verbosity_name, ...) VLOG_F(loguru::Verbosity_ ## verbosity_name, __VA_ARGS__) - -#define VLOG_IF_F(verbosity, cond, ...) \ - ((verbosity) > loguru::current_verbosity_cutoff() || (cond) == false) \ - ? (void)0 \ - : loguru::log(verbosity, __FILE__, __LINE__, __VA_ARGS__) - -#define LOG_IF_F(verbosity_name, cond, ...) \ - VLOG_IF_F(loguru::Verbosity_ ## verbosity_name, cond, __VA_ARGS__) - -#define VLOG_SCOPE_F(verbosity, ...) \ - loguru::LogScopeRAII LOGURU_ANONYMOUS_VARIABLE(error_context_RAII_) = \ - ((verbosity) > loguru::current_verbosity_cutoff()) ? loguru::LogScopeRAII() : \ - loguru::LogScopeRAII(verbosity, __FILE__, __LINE__, __VA_ARGS__) - -// Raw logging - no preamble, no indentation. Slightly faster than full logging. -#define RAW_VLOG_F(verbosity, ...) \ - ((verbosity) > loguru::current_verbosity_cutoff()) ? (void)0 \ - : loguru::raw_log(verbosity, __FILE__, __LINE__, __VA_ARGS__) - -#define RAW_LOG_F(verbosity_name, ...) 
RAW_VLOG_F(loguru::Verbosity_ ## verbosity_name, __VA_ARGS__) - -// Use to book-end a scope. Affects logging on all threads. -#define LOG_SCOPE_F(verbosity_name, ...) \ - VLOG_SCOPE_F(loguru::Verbosity_ ## verbosity_name, __VA_ARGS__) - -#define LOG_SCOPE_FUNCTION(verbosity_name) LOG_SCOPE_F(verbosity_name, __func__) - -// ----------------------------------------------- -// ABORT_F macro. Usage: ABORT_F("Cause of error: %s", error_str); - -// Message is optional -#define ABORT_F(...) loguru::log_and_abort(0, "ABORT: ", __FILE__, __LINE__, __VA_ARGS__) - -// -------------------------------------------------------------------- -// CHECK_F macros: - -#define CHECK_WITH_INFO_F(test, info, ...) \ - LOGURU_PREDICT_TRUE((test) == true) ? (void)0 : loguru::log_and_abort(0, "CHECK FAILED: " info " ", __FILE__, \ - __LINE__, ##__VA_ARGS__) - -/* Checked at runtime too. Will print error, then call fatal_handler (if any), then 'abort'. - Note that the test must be boolean. - CHECK_F(ptr); will not compile, but CHECK_F(ptr != nullptr); will. */ -#define CHECK_F(test, ...) CHECK_WITH_INFO_F(test, #test, ##__VA_ARGS__) - -#define CHECK_NOTNULL_F(x, ...) CHECK_WITH_INFO_F((x) != nullptr, #x " != nullptr", ##__VA_ARGS__) - -#define CHECK_OP_F(expr_left, expr_right, op, ...) \ - do \ - { \ - auto val_left = expr_left; \ - auto val_right = expr_right; \ - if (! 
LOGURU_PREDICT_TRUE(val_left op val_right)) \ - { \ - auto str_left = loguru::format_value(val_left); \ - auto str_right = loguru::format_value(val_right); \ - auto fail_info = loguru::textprintf("CHECK FAILED: " LOGURU_FMT(s) " " LOGURU_FMT(s) " " LOGURU_FMT(s) " (" LOGURU_FMT(s) " " LOGURU_FMT(s) " " LOGURU_FMT(s) ") ", \ - #expr_left, #op, #expr_right, str_left.c_str(), #op, str_right.c_str()); \ - auto user_msg = loguru::textprintf(__VA_ARGS__); \ - loguru::log_and_abort(0, fail_info.c_str(), __FILE__, __LINE__, \ - LOGURU_FMT(s), user_msg.c_str()); \ - } \ - } while (false) - -#ifndef LOGURU_DEBUG_LOGGING -#ifndef NDEBUG -#define LOGURU_DEBUG_LOGGING 1 -#else -#define LOGURU_DEBUG_LOGGING 0 -#endif -#endif - -#if LOGURU_DEBUG_LOGGING - // Debug logging enabled: -#define DLOG_F(verbosity_name, ...) LOG_F(verbosity_name, __VA_ARGS__) -#define DVLOG_F(verbosity, ...) VLOG_F(verbosity, __VA_ARGS__) -#define DLOG_IF_F(verbosity_name, ...) LOG_IF_F(verbosity_name, __VA_ARGS__) -#define DVLOG_IF_F(verbosity, ...) VLOG_IF_F(verbosity, __VA_ARGS__) -#define DRAW_LOG_F(verbosity_name, ...) RAW_LOG_F(verbosity_name, __VA_ARGS__) -#define DRAW_VLOG_F(verbosity, ...) RAW_VLOG_F(verbosity, __VA_ARGS__) -#else - // Debug logging disabled: -#define DLOG_F(verbosity_name, ...) -#define DVLOG_F(verbosity, ...) -#define DLOG_IF_F(verbosity_name, ...) -#define DVLOG_IF_F(verbosity, ...) -#define DRAW_LOG_F(verbosity_name, ...) -#define DRAW_VLOG_F(verbosity, ...) -#endif - -#define CHECK_EQ_F(a, b, ...) CHECK_OP_F(a, b, ==, ##__VA_ARGS__) -#define CHECK_NE_F(a, b, ...) CHECK_OP_F(a, b, !=, ##__VA_ARGS__) -#define CHECK_LT_F(a, b, ...) CHECK_OP_F(a, b, < , ##__VA_ARGS__) -#define CHECK_GT_F(a, b, ...) CHECK_OP_F(a, b, > , ##__VA_ARGS__) -#define CHECK_LE_F(a, b, ...) CHECK_OP_F(a, b, <=, ##__VA_ARGS__) -#define CHECK_GE_F(a, b, ...) 
CHECK_OP_F(a, b, >=, ##__VA_ARGS__) - -#ifndef LOGURU_DEBUG_CHECKS -#ifndef NDEBUG -#define LOGURU_DEBUG_CHECKS 1 -#else -#define LOGURU_DEBUG_CHECKS 0 -#endif -#endif - -#if LOGURU_DEBUG_CHECKS - // Debug checks enabled: -#define DCHECK_F(test, ...) CHECK_F(test, ##__VA_ARGS__) -#define DCHECK_NOTNULL_F(x, ...) CHECK_NOTNULL_F(x, ##__VA_ARGS__) -#define DCHECK_EQ_F(a, b, ...) CHECK_EQ_F(a, b, ##__VA_ARGS__) -#define DCHECK_NE_F(a, b, ...) CHECK_NE_F(a, b, ##__VA_ARGS__) -#define DCHECK_LT_F(a, b, ...) CHECK_LT_F(a, b, ##__VA_ARGS__) -#define DCHECK_LE_F(a, b, ...) CHECK_LE_F(a, b, ##__VA_ARGS__) -#define DCHECK_GT_F(a, b, ...) CHECK_GT_F(a, b, ##__VA_ARGS__) -#define DCHECK_GE_F(a, b, ...) CHECK_GE_F(a, b, ##__VA_ARGS__) -#else - // Debug checks disabled: -#define DCHECK_F(test, ...) -#define DCHECK_NOTNULL_F(x, ...) -#define DCHECK_EQ_F(a, b, ...) -#define DCHECK_NE_F(a, b, ...) -#define DCHECK_LT_F(a, b, ...) -#define DCHECK_LE_F(a, b, ...) -#define DCHECK_GT_F(a, b, ...) -#define DCHECK_GE_F(a, b, ...) -#endif // NDEBUG - - -#if LOGURU_REDEFINE_ASSERT -#undef assert -#ifndef NDEBUG - // Debug: -#define assert(test) CHECK_WITH_INFO_F(!!(test), #test) // HACK -#else -#define assert(test) -#endif -#endif // LOGURU_REDEFINE_ASSERT - -#endif // LOGURU_HAS_DECLARED_FORMAT_HEADER - -// ---------------------------------------------------------------------------- -// .dP"Y8 888888 88""Yb 888888 db 8b d8 .dP"Y8 -// `Ybo." 88 88__dP 88__ dPYb 88b d88 `Ybo." -// o.`Y8b 88 88"Yb 88"" dP__Yb 88YbdP88 o.`Y8b -// 8bodP' 88 88 Yb 888888 dP""""Yb 88 YY 88 8bodP' - -#if LOGURU_WITH_STREAMS -#ifndef LOGURU_HAS_DECLARED_STREAMS_HEADER -#define LOGURU_HAS_DECLARED_STREAMS_HEADER - -/* This file extends loguru to enable std::stream-style logging, a la Glog. - It's an optional feature behind the LOGURU_WITH_STREAMS settings - because including it everywhere will slow down compilation times. -*/ - -#include -#include // Adds about 38 kLoC on clang. 
-#include - -LOGURU_ANONYMOUS_NAMESPACE_BEGIN - -namespace loguru { - // Like sprintf, but returns the formated text. - LOGURU_EXPORT - std::string strprintf(LOGURU_FORMAT_STRING_TYPE format, ...) LOGURU_PRINTF_LIKE(1, 2); - - // Like vsprintf, but returns the formated text. - LOGURU_EXPORT - std::string vstrprintf(LOGURU_FORMAT_STRING_TYPE format, va_list) LOGURU_PRINTF_LIKE(1, 0); - - class LOGURU_EXPORT StreamLogger { - public: - StreamLogger(Verbosity verbosity, const char* file, unsigned line) : _verbosity(verbosity), _file(file), _line(line) {} - ~StreamLogger() noexcept(false); - - template - StreamLogger& operator<<(const T& t) - { - _ss << t; - return *this; - } - - // std::endl and other iomanip:s. - StreamLogger& operator<<(std::ostream& (*f)(std::ostream&)) - { - f(_ss); - return *this; - } - - private: - Verbosity _verbosity; - const char* _file; - unsigned _line; - std::ostringstream _ss; - }; - - class LOGURU_EXPORT AbortLogger { - public: - AbortLogger(const char* expr, const char* file, unsigned line) : _expr(expr), _file(file), _line(line) {} - LOGURU_NORETURN ~AbortLogger() noexcept(false); - - template - AbortLogger& operator<<(const T& t) - { - _ss << t; - return *this; - } - - // std::endl and other iomanip:s. - AbortLogger& operator<<(std::ostream& (*f)(std::ostream&)) - { - f(_ss); - return *this; - } - - private: - const char* _expr; - const char* _file; - unsigned _line; - std::ostringstream _ss; - }; - - class LOGURU_EXPORT Voidify { - public: - Voidify() {} - // This has to be an operator with a precedence lower than << but higher than ?: - void operator&(const StreamLogger&) {} - void operator&(const AbortLogger&) {} - }; - - /* Helper functions for CHECK_OP_S macro. - GLOG trick: The (int, int) specialization works around the issue that the compiler - will not instantiate the template version of the function on values of unnamed enum type. 
*/ -#define DEFINE_CHECK_OP_IMPL(name, op) \ - template \ - inline std::string* name(const char* expr, const T1& v1, const char* op_str, const T2& v2) \ - { \ - if (LOGURU_PREDICT_TRUE(v1 op v2)) { return NULL; } \ - std::ostringstream ss; \ - ss << "CHECK FAILED: " << expr << " (" << v1 << " " << op_str << " " << v2 << ") "; \ - return new std::string(ss.str()); \ - } \ - inline std::string* name(const char* expr, int v1, const char* op_str, int v2) \ - { \ - return name(expr, v1, op_str, v2); \ - } - - DEFINE_CHECK_OP_IMPL(check_EQ_impl, == ) - DEFINE_CHECK_OP_IMPL(check_NE_impl, != ) - DEFINE_CHECK_OP_IMPL(check_LE_impl, <= ) - DEFINE_CHECK_OP_IMPL(check_LT_impl, < ) - DEFINE_CHECK_OP_IMPL(check_GE_impl, >= ) - DEFINE_CHECK_OP_IMPL(check_GT_impl, > ) -#undef DEFINE_CHECK_OP_IMPL - - /* GLOG trick: Function is overloaded for integral types to allow static const integrals - declared in classes and not defined to be used as arguments to CHECK* macros. */ - template - inline const T& referenceable_value(const T& t) { return t; } - inline char referenceable_value(char t) { return t; } - inline unsigned char referenceable_value(unsigned char t) { return t; } - inline signed char referenceable_value(signed char t) { return t; } - inline short referenceable_value(short t) { return t; } - inline unsigned short referenceable_value(unsigned short t) { return t; } - inline int referenceable_value(int t) { return t; } - inline unsigned int referenceable_value(unsigned int t) { return t; } - inline long referenceable_value(long t) { return t; } - inline unsigned long referenceable_value(unsigned long t) { return t; } - inline long long referenceable_value(long long t) { return t; } - inline unsigned long long referenceable_value(unsigned long long t) { return t; } -} // namespace loguru - -LOGURU_ANONYMOUS_NAMESPACE_END - -// ----------------------------------------------- -// Logging macros: - -// usage: LOG_STREAM(INFO) << "Foo " << std::setprecision(10) << some_value; 
-#define VLOG_IF_S(verbosity, cond) \ - ((verbosity) > loguru::current_verbosity_cutoff() || (cond) == false) \ - ? (void)0 \ - : loguru::Voidify() & loguru::StreamLogger(verbosity, __FILE__, __LINE__) -#define LOG_IF_S(verbosity_name, cond) VLOG_IF_S(loguru::Verbosity_ ## verbosity_name, cond) -#define VLOG_S(verbosity) VLOG_IF_S(verbosity, true) -#define LOG_S(verbosity_name) VLOG_S(loguru::Verbosity_ ## verbosity_name) - -// ----------------------------------------------- -// ABORT_S macro. Usage: ABORT_S() << "Causo of error: " << details; - -#define ABORT_S() loguru::Voidify() & loguru::AbortLogger("ABORT: ", __FILE__, __LINE__) - -// ----------------------------------------------- -// CHECK_S macros: - -#define CHECK_WITH_INFO_S(cond, info) \ - LOGURU_PREDICT_TRUE((cond) == true) \ - ? (void)0 \ - : loguru::Voidify() & loguru::AbortLogger("CHECK FAILED: " info " ", __FILE__, __LINE__) - -#define CHECK_S(cond) CHECK_WITH_INFO_S(cond, #cond) -#define CHECK_NOTNULL_S(x) CHECK_WITH_INFO_S((x) != nullptr, #x " != nullptr") - -#define CHECK_OP_S(function_name, expr1, op, expr2) \ - while (auto error_string = loguru::function_name(#expr1 " " #op " " #expr2, \ - loguru::referenceable_value(expr1), #op, \ - loguru::referenceable_value(expr2))) \ - loguru::AbortLogger(error_string->c_str(), __FILE__, __LINE__) - -#define CHECK_EQ_S(expr1, expr2) CHECK_OP_S(check_EQ_impl, expr1, ==, expr2) -#define CHECK_NE_S(expr1, expr2) CHECK_OP_S(check_NE_impl, expr1, !=, expr2) -#define CHECK_LE_S(expr1, expr2) CHECK_OP_S(check_LE_impl, expr1, <=, expr2) -#define CHECK_LT_S(expr1, expr2) CHECK_OP_S(check_LT_impl, expr1, < , expr2) -#define CHECK_GE_S(expr1, expr2) CHECK_OP_S(check_GE_impl, expr1, >=, expr2) -#define CHECK_GT_S(expr1, expr2) CHECK_OP_S(check_GT_impl, expr1, > , expr2) - -#if LOGURU_DEBUG_LOGGING - // Debug logging enabled: -#define DVLOG_IF_S(verbosity, cond) VLOG_IF_S(verbosity, cond) -#define DLOG_IF_S(verbosity_name, cond) LOG_IF_S(verbosity_name, cond) -#define 
DVLOG_S(verbosity) VLOG_S(verbosity) -#define DLOG_S(verbosity_name) LOG_S(verbosity_name) -#else - // Debug logging disabled: -#define DVLOG_IF_S(verbosity, cond) \ - (true || (verbosity) > loguru::current_verbosity_cutoff() || (cond) == false) \ - ? (void)0 \ - : loguru::Voidify() & loguru::StreamLogger(verbosity, __FILE__, __LINE__) - -#define DLOG_IF_S(verbosity_name, cond) DVLOG_IF_S(loguru::Verbosity_ ## verbosity_name, cond) -#define DVLOG_S(verbosity) DVLOG_IF_S(verbosity, true) -#define DLOG_S(verbosity_name) DVLOG_S(loguru::Verbosity_ ## verbosity_name) -#endif - -#if LOGURU_DEBUG_CHECKS - // Debug checks enabled: -#define DCHECK_S(cond) CHECK_S(cond) -#define DCHECK_NOTNULL_S(x) CHECK_NOTNULL_S(x) -#define DCHECK_EQ_S(a, b) CHECK_EQ_S(a, b) -#define DCHECK_NE_S(a, b) CHECK_NE_S(a, b) -#define DCHECK_LT_S(a, b) CHECK_LT_S(a, b) -#define DCHECK_LE_S(a, b) CHECK_LE_S(a, b) -#define DCHECK_GT_S(a, b) CHECK_GT_S(a, b) -#define DCHECK_GE_S(a, b) CHECK_GE_S(a, b) -#else -// Debug checks disabled: -#define DCHECK_S(cond) CHECK_S(true || (cond)) -#define DCHECK_NOTNULL_S(x) CHECK_S(true || (x) != nullptr) -#define DCHECK_EQ_S(a, b) CHECK_S(true || (a) == (b)) -#define DCHECK_NE_S(a, b) CHECK_S(true || (a) != (b)) -#define DCHECK_LT_S(a, b) CHECK_S(true || (a) < (b)) -#define DCHECK_LE_S(a, b) CHECK_S(true || (a) <= (b)) -#define DCHECK_GT_S(a, b) CHECK_S(true || (a) > (b)) -#define DCHECK_GE_S(a, b) CHECK_S(true || (a) >= (b)) -#endif - -#if LOGURU_REPLACE_GLOG -#undef LOG -#undef VLOG -#undef LOG_IF -#undef VLOG_IF -#undef CHECK -#undef CHECK_NOTNULL -#undef CHECK_EQ -#undef CHECK_NE -#undef CHECK_LT -#undef CHECK_LE -#undef CHECK_GT -#undef CHECK_GE -#undef DLOG -#undef DVLOG -#undef DLOG_IF -#undef DVLOG_IF -#undef DCHECK -#undef DCHECK_NOTNULL -#undef DCHECK_EQ -#undef DCHECK_NE -#undef DCHECK_LT -#undef DCHECK_LE -#undef DCHECK_GT -#undef DCHECK_GE -#undef VLOG_IS_ON - -#define LOG LOG_S -#define VLOG VLOG_S -#define LOG_IF LOG_IF_S -#define VLOG_IF 
VLOG_IF_S -#define CHECK(cond) CHECK_S(!!(cond)) -#define CHECK_NOTNULL CHECK_NOTNULL_S -#define CHECK_EQ CHECK_EQ_S -#define CHECK_NE CHECK_NE_S -#define CHECK_LT CHECK_LT_S -#define CHECK_LE CHECK_LE_S -#define CHECK_GT CHECK_GT_S -#define CHECK_GE CHECK_GE_S -#define DLOG DLOG_S -#define DVLOG DVLOG_S -#define DLOG_IF DLOG_IF_S -#define DVLOG_IF DVLOG_IF_S -#define DCHECK DCHECK_S -#define DCHECK_NOTNULL DCHECK_NOTNULL_S -#define DCHECK_EQ DCHECK_EQ_S -#define DCHECK_NE DCHECK_NE_S -#define DCHECK_LT DCHECK_LT_S -#define DCHECK_LE DCHECK_LE_S -#define DCHECK_GT DCHECK_GT_S -#define DCHECK_GE DCHECK_GE_S -#define VLOG_IS_ON(verbosity) ((verbosity) <= loguru::current_verbosity_cutoff()) - -#endif // LOGURU_REPLACE_GLOG - -#endif // LOGURU_WITH_STREAMS - -#endif // LOGURU_HAS_DECLARED_STREAMS_HEADER diff --git a/lib/mdlp b/lib/mdlp deleted file mode 160000 index cfb993f..0000000 --- a/lib/mdlp +++ /dev/null @@ -1 +0,0 @@ -Subproject commit cfb993f5ec1aabed527f524fdd4db06c6d839868 diff --git a/remove_submodules.sh b/remove_submodules.sh new file mode 100644 index 0000000..e5b4ce6 --- /dev/null +++ b/remove_submodules.sh @@ -0,0 +1,14 @@ +git config --file .gitmodules --get-regexp path | awk '{ print $2 }' | while read line; do + echo "Removing $line" + # Deinit the submodule + git submodule deinit -f "$line" + + # Remove the submodule from the working tree + git rm -f "$line" + + # Remove the submodule from .git/modules + rm -rf ".git/modules/$line" +done + +# Remove the .gitmodules file +git rm -f .gitmodules diff --git a/sample/CMakeLists.txt b/sample/CMakeLists.txt index 9cda435..0db5797 100644 --- a/sample/CMakeLists.txt +++ b/sample/CMakeLists.txt @@ -1,15 +1,11 @@ include_directories( + ${TORCH_INCLUDE_DIRS} ${Platform_SOURCE_DIR}/src/common ${Platform_SOURCE_DIR}/src/main ${Python3_INCLUDE_DIRS} - ${Platform_SOURCE_DIR}/lib/Files - ${Platform_SOURCE_DIR}/lib/mdlp/src - ${Platform_SOURCE_DIR}/lib/argparse/include - ${Platform_SOURCE_DIR}/lib/folding - 
${Platform_SOURCE_DIR}/lib/json/include ${CMAKE_BINARY_DIR}/configured_files/include ${PyClassifiers_INCLUDE_DIRS} - ${Bayesnet_INCLUDE_DIRS} + ${bayesnet_INCLUDE_DIRS} ) add_executable(PlatformSample sample.cpp ${Platform_SOURCE_DIR}/src/main/Models.cpp) -target_link_libraries(PlatformSample "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy) \ No newline at end of file +target_link_libraries(PlatformSample "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} ${Boost_LIBRARIES}) \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7825077..66bcbdd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,18 +1,10 @@ include_directories( ## Libs - ${Platform_SOURCE_DIR}/lib/log - ${Platform_SOURCE_DIR}/lib/Files - ${Platform_SOURCE_DIR}/lib/folding - ${Platform_SOURCE_DIR}/lib/mdlp/src - ${Platform_SOURCE_DIR}/lib/argparse/include - ${Platform_SOURCE_DIR}/lib/json/include - ${Platform_SOURCE_DIR}/lib/libxlsxwriter/include ${Python3_INCLUDE_DIRS} ${MPI_CXX_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS} ${CMAKE_BINARY_DIR}/configured_files/include ${PyClassifiers_INCLUDE_DIRS} - ${Bayesnet_INCLUDE_DIRS} ## Platform ${Platform_SOURCE_DIR}/src ${Platform_SOURCE_DIR}/results @@ -28,8 +20,10 @@ add_executable( results/Result.cpp experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp + experimental_clfs/DecisionTree.cpp + experimental_clfs/AdaBoost.cpp ) -target_link_libraries(b_best Boost::boost "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy "${XLSXWRITER_LIB}") +target_link_libraries(b_best Boost::boost "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}") # b_grid set(grid_sources GridSearch.cpp GridData.cpp GridExperiment.cpp GridBase.cpp ) @@ -41,8 +35,10 
@@ add_executable(b_grid commands/b_grid.cpp ${grid_sources} results/Result.cpp experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp + experimental_clfs/DecisionTree.cpp + experimental_clfs/AdaBoost.cpp ) -target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy) +target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy) # b_list add_executable(b_list commands/b_list.cpp @@ -52,8 +48,10 @@ add_executable(b_list commands/b_list.cpp results/Result.cpp results/ResultsDatasetExcel.cpp results/ResultsDataset.cpp results/ResultsDatasetConsole.cpp experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp + experimental_clfs/DecisionTree.cpp + experimental_clfs/AdaBoost.cpp ) -target_link_libraries(b_list "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy "${XLSXWRITER_LIB}") +target_link_libraries(b_list "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}") # b_main set(main_sources Experiment.cpp Models.cpp HyperParameters.cpp Scores.cpp ArgumentsExperiment.cpp) @@ -64,8 +62,11 @@ add_executable(b_main commands/b_main.cpp ${main_sources} results/Result.cpp experimental_clfs/XA1DE.cpp experimental_clfs/ExpClf.cpp + experimental_clfs/ExpClf.cpp + experimental_clfs/DecisionTree.cpp + experimental_clfs/AdaBoost.cpp ) -target_link_libraries(b_main "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy) +target_link_libraries(b_main PRIVATE nlohmann_json::nlohmann_json "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy) # b_manage set(manage_sources 
ManageScreen.cpp OptionsMenu.cpp ResultsManager.cpp) @@ -77,7 +78,7 @@ add_executable( results/Result.cpp results/ResultsDataset.cpp results/ResultsDatasetConsole.cpp main/Scores.cpp ) -target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" fimdlp "${BayesNet}") +target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" fimdlp bayesnet::bayesnet) # b_results add_executable(b_results commands/b_results.cpp) diff --git a/src/best/BestResults.cpp b/src/best/BestResults.cpp index 09a2cf3..bd7f82a 100644 --- a/src/best/BestResults.cpp +++ b/src/best/BestResults.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include "common/Colors.h" #include "common/CLocale.h" #include "common/Paths.h" @@ -123,16 +124,24 @@ namespace platform { } result = std::vector(models.begin(), models.end()); maxModelName = (*max_element(result.begin(), result.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size(); - maxModelName = std::max(12, maxModelName); + maxModelName = std::max(minLength, maxModelName); return result; } + std::string toLower(std::string data) + { + std::transform(data.begin(), data.end(), data.begin(), + [](unsigned char c) { return std::tolower(c); }); + return data; + } std::vector BestResults::getDatasets(json table) { std::vector datasets; for (const auto& dataset_ : table.items()) { datasets.push_back(dataset_.key()); } - std::stable_sort(datasets.begin(), datasets.end()); + std::stable_sort(datasets.begin(), datasets.end(), [](const std::string& a, const std::string& b) { + return toLower(a) < toLower(b); + }); maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size(); maxDatasetName = std::max(7, maxDatasetName); return datasets; @@ -222,7 +231,7 @@ namespace platform { std::cout << oss.str(); std::cout << std::string(oss.str().size() - 8, '-') << std::endl; std::cout << Colors::GREEN() << " # 
" << std::setw(maxDatasetName + 1) << std::left << std::string("Dataset"); - auto bestResultsTex = BestResultsTex(); + auto bestResultsTex = BestResultsTex(score); auto bestResultsMd = BestResultsMd(); if (tex) { bestResultsTex.results_header(models, table.at("dateTable").get(), index); @@ -266,12 +275,14 @@ namespace platform { // Print the row with red colors on max values for (const auto& model : models) { std::string efectiveColor = color; - double value; + double value, std; try { value = table[model].at(dataset_).at(0).get(); + std = table[model].at(dataset_).at(3).get(); } catch (nlohmann::json_abi_v3_11_3::detail::out_of_range err) { value = -1.0; + std = -1.0; } if (value == maxValue) { efectiveColor = Colors::RED(); @@ -280,7 +291,8 @@ namespace platform { std::cout << Colors::YELLOW() << std::setw(maxModelName) << std::right << "N/A" << " "; } else { totals[model].push_back(value); - std::cout << efectiveColor << std::setw(maxModelName) << std::setprecision(maxModelName - 2) << std::fixed << value << " "; + std::cout << efectiveColor << std::setw(maxModelName - 6) << std::setprecision(maxModelName - 8) << std::fixed << value; + std::cout << efectiveColor << "±" << std::setw(5) << std::setprecision(3) << std::fixed << std << " "; } } std::cout << std::endl; @@ -307,9 +319,9 @@ namespace platform { for (const auto& model : models) { std::string efectiveColor = model == best_model ? 
Colors::RED() : Colors::GREEN(); double value = std::reduce(totals[model].begin(), totals[model].end()) / nDatasets; - double std_value = compute_std(totals[model], value); - std::cout << efectiveColor << std::right << std::setw(maxModelName) << std::setprecision(maxModelName - 4) << std::fixed << value << " "; - + double std = compute_std(totals[model], value); + std::cout << efectiveColor << std::right << std::setw(maxModelName - 6) << std::setprecision(maxModelName - 8) << std::fixed << value; + std::cout << efectiveColor << "±" << std::setw(5) << std::setprecision(3) << std::fixed << std << " "; } std::cout << std::endl; } @@ -321,9 +333,10 @@ namespace platform { // Build the table of results json table = buildTableResults(models); std::vector datasets = getDatasets(table.begin().value()); - BestResultsExcel excel_report(score, datasets); + BestResultsExcel excel_report(path, score, datasets); excel_report.reportSingle(model, path + Paths::bestResultsFile(score, model)); messageOutputFile("Excel", excel_report.getFileName()); + excelFileName = excel_report.getFileName(); } } void BestResults::reportAll(bool excel, bool tex, bool index) @@ -337,9 +350,10 @@ namespace platform { // Compute the Friedman test std::map> ranksModels; if (friedman) { - Statistics stats(models, datasets, table, significance); + Statistics stats(score, models, datasets, table, significance); auto result = stats.friedmanTest(); - stats.postHocHolmTest(result, tex); + stats.postHocTest(); + stats.postHocTestReport(result, tex); ranksModels = stats.getRanks(); } if (tex) { @@ -351,33 +365,21 @@ namespace platform { } } if (excel) { - BestResultsExcel excel(score, datasets); + BestResultsExcel excel(path, score, datasets); excel.reportAll(models, table, ranksModels, friedman, significance); if (friedman) { - int idx = -1; - double min = 2000; - // Find out the control model - auto totals = std::vector(models.size(), 0.0); - for (const auto& dataset_ : datasets) { - for (int i = 0; i < 
models.size(); ++i) { - totals[i] += ranksModels[dataset_][models[i]]; - } - } - for (int i = 0; i < models.size(); ++i) { - if (totals[i] < min) { - min = totals[i]; - idx = i; - } - } + Statistics stats(score, models, datasets, table, significance); + int idx = stats.getControlIdx(); model = models.at(idx); excel.reportSingle(model, path + Paths::bestResultsFile(score, model)); } messageOutputFile("Excel", excel.getFileName()); + excelFileName = excel.getFileName(); } } void BestResults::messageOutputFile(const std::string& title, const std::string& fileName) { - std::cout << Colors::YELLOW() << "** " << std::setw(5) << std::left << title + std::cout << Colors::YELLOW() << "** " << std::setw(8) << std::left << title << " file generated: " << fileName << Colors::RESET() << std::endl; } } \ No newline at end of file diff --git a/src/best/BestResults.h b/src/best/BestResults.h index fde6a74..6d1af54 100644 --- a/src/best/BestResults.h +++ b/src/best/BestResults.h @@ -15,6 +15,7 @@ namespace platform { void reportSingle(bool excel); void reportAll(bool excel, bool tex, bool index); void buildAll(); + std::string getExcelFileName() const { return excelFileName; } private: std::vector getModels(); std::vector getDatasets(json table); @@ -32,6 +33,8 @@ namespace platform { double significance; int maxModelName = 0; int maxDatasetName = 0; + int minLength = 13; // Minimum length for scores + std::string excelFileName; }; } #endif \ No newline at end of file diff --git a/src/best/BestResultsExcel.cpp b/src/best/BestResultsExcel.cpp index 0bc961e..71e6cc3 100644 --- a/src/best/BestResultsExcel.cpp +++ b/src/best/BestResultsExcel.cpp @@ -30,7 +30,7 @@ namespace platform { } return columnName; } - BestResultsExcel::BestResultsExcel(const std::string& score, const std::vector& datasets) : score(score), datasets(datasets) + BestResultsExcel::BestResultsExcel(const std::string& path, const std::string& score, const std::vector& datasets) : path(path), score(score), 
datasets(datasets) { file_name = Paths::bestResultsExcel(score); workbook = workbook_new(getFileName().c_str()); @@ -92,7 +92,7 @@ namespace platform { catch (const std::out_of_range& oor) { auto tabName = "table_" + std::to_string(i); auto worksheetNew = workbook_add_worksheet(workbook, tabName.c_str()); - json data = loadResultData(Paths::results() + fileName); + json data = loadResultData(path + fileName); auto report = ReportExcel(data, false, workbook, worksheetNew); report.show(); hyperlink = "#table_" + std::to_string(i); @@ -164,13 +164,15 @@ namespace platform { addConditionalFormat("max"); footer(false); if (friedman) { - // Create Sheet with ranks - worksheet = workbook_add_worksheet(workbook, "Ranks"); - formatColumns(); - header(true); - body(true); - addConditionalFormat("min"); - footer(true); + if (score == "accuracy") { + // Create Sheet with ranks + worksheet = workbook_add_worksheet(workbook, "Ranks"); + formatColumns(); + header(true); + body(true); + addConditionalFormat("min"); + footer(true); + } // Create Sheet with Friedman Test doFriedman(); } @@ -241,11 +243,12 @@ namespace platform { } worksheet_merge_range(worksheet, 0, 0, 0, 7, "Friedman Test", styles["headerFirst"]); row = 2; - Statistics stats(models, datasets, table, significance, false); + Statistics stats(score, models, datasets, table, significance, false); // No output auto result = stats.friedmanTest(); - stats.postHocHolmTest(result); + stats.postHocTest(); + stats.postHocTestReport(result, false); // No tex output auto friedmanResult = stats.getFriedmanResult(); - auto holmResult = stats.getHolmResult(); + auto postHocResults = stats.getPostHocResults(); worksheet_merge_range(worksheet, row, 0, row, 7, "Null hypothesis: H0 'There is no significant differences between all the classifiers.'", styles["headerSmall"]); row += 2; writeString(row, 1, "Friedman Q", "bodyHeader"); @@ -264,7 +267,7 @@ namespace platform { row += 2; worksheet_merge_range(worksheet, row, 0, row, 7, "Null 
hypothesis: H0 'There is no significant differences between the control model and the other models.'", styles["headerSmall"]); row += 2; - std::string controlModel = "Control Model: " + holmResult.model; + std::string controlModel = "Control Model: " + postHocResults.at(0).model; worksheet_merge_range(worksheet, row, 1, row, 7, controlModel.c_str(), styles["bodyHeader_odd"]); row++; writeString(row, 1, "Model", "bodyHeader"); @@ -276,7 +279,7 @@ namespace platform { writeString(row, 7, "Reject H0", "bodyHeader"); row++; bool first = true; - for (const auto& item : holmResult.holmLines) { + for (const auto& item : postHocResults) { writeString(row, 1, item.model, "text"); if (first) { // Control model info diff --git a/src/best/BestResultsExcel.h b/src/best/BestResultsExcel.h index 6c70a49..bd8bf94 100644 --- a/src/best/BestResultsExcel.h +++ b/src/best/BestResultsExcel.h @@ -10,7 +10,7 @@ namespace platform { using json = nlohmann::ordered_json; class BestResultsExcel : public ExcelFile { public: - BestResultsExcel(const std::string& score, const std::vector& datasets); + BestResultsExcel(const std::string& path, const std::string& score, const std::vector& datasets); ~BestResultsExcel(); void reportAll(const std::vector& models, const json& table, const std::map>& ranks, bool friedman, double significance); void reportSingle(const std::string& model, const std::string& fileName); @@ -22,6 +22,7 @@ namespace platform { void formatColumns(); void doFriedman(); void addConditionalFormat(std::string formula); + std::string path; std::string score; std::vector models; std::vector datasets; diff --git a/src/best/BestResultsMd.cpp b/src/best/BestResultsMd.cpp index bfa0a9b..195d3f6 100644 --- a/src/best/BestResultsMd.cpp +++ b/src/best/BestResultsMd.cpp @@ -75,7 +75,7 @@ namespace platform { handler.close(); } - void BestResultsMd::holm_test(struct HolmResult& holmResult, const std::string& date) + void BestResultsMd::postHoc_test(std::vector& postHocResults, const 
std::string& kind, const std::string& date) { auto file_name = Paths::tex() + Paths::md_post_hoc(); openMdFile(file_name); @@ -84,13 +84,15 @@ namespace platform { handler << std::endl; handler << " Post-hoc handler test" << std::endl; handler << "-->" << std::endl; - handler << "Post-hoc Holm test: H0: There is no significant differences between the control model and the other models." << std::endl << std::endl; + handler << "Post-hoc " << kind << " test: H0: There is no significant differences between the control model and the other models." << std::endl << std::endl; handler << "| classifier | pvalue | rank | win | tie | loss | H0 |" << std::endl; handler << "| :-- | --: | --: | --:| --: | --: | :--: |" << std::endl; - for (auto const& line : holmResult.holmLines) { + bool first = true; + for (auto const& line : postHocResults) { auto textStatus = !line.reject ? "**" : " "; - if (line.model == holmResult.model) { + if (first) { handler << "| " << line.model << " | - | " << std::fixed << std::setprecision(2) << line.rank << " | - | - | - |" << std::endl; + first = false; } else { handler << "| " << line.model << " | " << textStatus << std::scientific << std::setprecision(4) << line.pvalue << textStatus << " |"; handler << std::fixed << std::setprecision(2) << line.rank << " | " << line.wtl.win << " | " << line.wtl.tie << " | " << line.wtl.loss << " |"; diff --git a/src/best/BestResultsMd.h b/src/best/BestResultsMd.h index 253a54a..8ff2612 100644 --- a/src/best/BestResultsMd.h +++ b/src/best/BestResultsMd.h @@ -14,7 +14,7 @@ namespace platform { void results_header(const std::vector& models, const std::string& date); void results_body(const std::vector& datasets, json& table); void results_footer(const std::map>& totals, const std::string& best_model); - void holm_test(struct HolmResult& holmResult, const std::string& date); + void postHoc_test(std::vector& postHocResults, const std::string& kind, const std::string& date); private: void openMdFile(const 
std::string& name); std::ofstream handler; diff --git a/src/best/BestResultsTex.cpp b/src/best/BestResultsTex.cpp index bf74c88..afe19ad 100644 --- a/src/best/BestResultsTex.cpp +++ b/src/best/BestResultsTex.cpp @@ -27,8 +27,10 @@ namespace platform { handler << "\\tiny " << std::endl; handler << "\\renewcommand{\\arraystretch }{1.2} " << std::endl; handler << "\\renewcommand{\\tabcolsep }{0.07cm} " << std::endl; - handler << "\\caption{Accuracy results(mean $\\pm$ std) for all the algorithms and datasets} " << std::endl; - handler << "\\label{tab:results_accuracy}" << std::endl; + auto umetric = score; + umetric[0] = toupper(umetric[0]); + handler << "\\caption{" << umetric << " results(mean $\\pm$ std) for all the algorithms and datasets} " << std::endl; + handler << "\\label{tab:results_" << score << "}" << std::endl; std::string header_dataset_name = index ? "r" : "l"; handler << "\\begin{tabular} {{" << header_dataset_name << std::string(models.size(), 'c').c_str() << "}}" << std::endl; handler << "\\hline " << std::endl; @@ -87,26 +89,28 @@ namespace platform { handler << "\\end{table}" << std::endl; handler.close(); } - void BestResultsTex::holm_test(struct HolmResult& holmResult, const std::string& date) + void BestResultsTex::postHoc_test(std::vector& postHocResults, const std::string& kind, const std::string& date) { auto file_name = Paths::tex() + Paths::tex_post_hoc(); openTexFile(file_name); handler << "%% This file has been generated by the platform program" << std::endl; handler << "%% Date: " << date.c_str() << std::endl; handler << "%%" << std::endl; - handler << "%% Post-hoc handler test" << std::endl; + handler << "%% Post-hoc " << kind << " test" << std::endl; handler << "%%" << std::endl; handler << "\\begin{table}[htbp]" << std::endl; handler << "\\centering" << std::endl; - handler << "\\caption{Results of the post-hoc test for the mean accuracy of the algorithms.}\\label{tab:tests}" << std::endl; + handler << "\\caption{Results of the 
post-hoc " << kind << " test for the mean " << score << " of the algorithms.}\\label{ tab:tests }" << std::endl; handler << "\\begin{tabular}{lrrrrr}" << std::endl; handler << "\\hline" << std::endl; handler << "classifier & pvalue & rank & win & tie & loss\\\\" << std::endl; handler << "\\hline" << std::endl; - for (auto const& line : holmResult.holmLines) { + bool first = true; + for (auto const& line : postHocResults) { auto textStatus = !line.reject ? "\\bf " : " "; - if (line.model == holmResult.model) { + if (first) { handler << line.model << " & - & " << std::fixed << std::setprecision(2) << line.rank << " & - & - & - \\\\" << std::endl; + first = false; } else { handler << line.model << " & " << textStatus << std::scientific << std::setprecision(4) << line.pvalue << " & "; handler << std::fixed << std::setprecision(2) << line.rank << " & " << line.wtl.win << " & " << line.wtl.tie << " & " << line.wtl.loss << "\\\\" << std::endl; diff --git a/src/best/BestResultsTex.h b/src/best/BestResultsTex.h index ae88c6d..7392d7c 100644 --- a/src/best/BestResultsTex.h +++ b/src/best/BestResultsTex.h @@ -9,13 +9,14 @@ namespace platform { using json = nlohmann::ordered_json; class BestResultsTex { public: - BestResultsTex(bool dataset_name = true) : dataset_name(dataset_name) {}; + BestResultsTex(const std::string score, bool dataset_name = true) : score{ score }, dataset_name{ dataset_name } {}; ~BestResultsTex() = default; void results_header(const std::vector& models, const std::string& date, bool index); void results_body(const std::vector& datasets, json& table, bool index); void results_footer(const std::map>& totals, const std::string& best_model); - void holm_test(struct HolmResult& holmResult, const std::string& date); + void postHoc_test(std::vector& postHocResults, const std::string& kind, const std::string& date); private: + std::string score; bool dataset_name; void openTexFile(const std::string& name); std::ofstream handler; diff --git 
a/src/best/Statistics.cpp b/src/best/Statistics.cpp index 73f1edb..1fa16ad 100644 --- a/src/best/Statistics.cpp +++ b/src/best/Statistics.cpp @@ -7,18 +7,25 @@ #include "BestResultsTex.h" #include "BestResultsMd.h" #include "Statistics.h" +#include "WilcoxonTest.hpp" namespace platform { - Statistics::Statistics(const std::vector& models, const std::vector& datasets, const json& data, double significance, bool output) : - models(models), datasets(datasets), data(data), significance(significance), output(output) + Statistics::Statistics(const std::string& score, const std::vector& models, const std::vector& datasets, const json& data, double significance, bool output) : + score(score), models(models), datasets(datasets), data(data), significance(significance), output(output) { + if (score == "accuracy") { + postHocType = "Holm"; + hlen = 85; + } else { + postHocType = "Wilcoxon"; + hlen = 88; + } nModels = models.size(); nDatasets = datasets.size(); auto temp = ConfigLocale(); } - void Statistics::fit() { if (nModels < 3 || nDatasets < 3) { @@ -27,9 +34,11 @@ namespace platform { throw std::runtime_error("Can't make the Friedman test with less than 3 models and/or less than 3 datasets."); } ranksModels.clear(); - computeRanks(); + computeRanks(); // compute greaterAverage and ranks // Set the control model as the one with the lowest average rank - controlIdx = distance(ranks.begin(), min_element(ranks.begin(), ranks.end(), [](const auto& l, const auto& r) { return l.second < r.second; })); + controlIdx = score == "accuracy" ? 
+ distance(ranks.begin(), min_element(ranks.begin(), ranks.end(), [](const auto& l, const auto& r) { return l.second < r.second; })) + : greaterAverage; // The model with the greater average score computeWTL(); maxModelName = (*std::max_element(models.begin(), models.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size(); maxDatasetName = (*std::max_element(datasets.begin(), datasets.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size(); @@ -66,11 +75,16 @@ namespace platform { void Statistics::computeRanks() { std::map ranksLine; + std::map averages; + for (const auto& model : models) { + averages[model] = 0; + } for (const auto& dataset : datasets) { std::vector> ranksOrder; for (const auto& model : models) { double value = data[model].at(dataset).at(0).get(); ranksOrder.push_back({ model, value }); + averages[model] += value; } // Assign the ranks ranksLine = assignRanks(ranksOrder); @@ -88,10 +102,17 @@ namespace platform { for (const auto& rank : ranks) { ranks[rank.first] /= nDatasets; } + // Average the scores + for (const auto& average : averages) { + averages[average.first] /= nDatasets; + } + // Get the model with the greater average score + greaterAverage = distance(averages.begin(), max_element(averages.begin(), averages.end(), [](const auto& l, const auto& r) { return l.second < r.second; })); } void Statistics::computeWTL() { - // Compute the WTL matrix + const double practical_threshold = 0.0005; + // Compute the WTL matrix (Win Tie Loss) for (int i = 0; i < nModels; ++i) { wtl[i] = { 0, 0, 0 }; } @@ -104,23 +125,85 @@ namespace platform { continue; } double value = data[models[i]].at(item.key()).at(0).get(); - if (value < controlValue) { - wtl[i].win++; - } else if (value == controlValue) { + double diff = controlValue - value; // control − comparison + if (std::fabs(diff) <= practical_threshold) { wtl[i].tie++; + } else if (diff < 0) { + wtl[i].win++; } else { 
wtl[i].loss++; } } } } - - void Statistics::postHocHolmTest(bool friedmanResult, bool tex) + int Statistics::getControlIdx() + { + if (!fitted) { + fit(); + } + return controlIdx; + } + void Statistics::postHocTest() + { + if (score == "accuracy") { + postHocHolmTest(); + } else { + postHocWilcoxonTest(); + } + } + void Statistics::postHocWilcoxonTest() + { + if (!fitted) { + fit(); + } + // Reference: Wilcoxon, F. (1945). “Individual Comparisons by Ranking Methods”. Biometrics Bulletin, 1(6), 80-83. + auto wilcoxon = WilcoxonTest(models, datasets, data, significance); + controlIdx = wilcoxon.getControlIdx(); + postHocResults = wilcoxon.getPostHocResults(); + setResultsOrder(); + // Fill the ranks info + for (const auto& item : postHocResults) { + ranks[item.model] = item.rank; + } + Holm_Bonferroni(); + restoreResultsOrder(); + } + void Statistics::Holm_Bonferroni() + { + // The algorithm need the p-values sorted from the lowest to the highest + // Sort the models by p-value + std::sort(postHocResults.begin(), postHocResults.end(), [](const PostHocLine& a, const PostHocLine& b) { + return a.pvalue < b.pvalue; + }); + // Holm adjustment + for (int i = 0; i < postHocResults.size(); ++i) { + auto item = postHocResults.at(i); + double before = i == 0 ? 
0.0 : postHocResults.at(i - 1).pvalue; + double p_value = std::min((long double)1.0, item.pvalue * (nModels - i)); + p_value = std::max(before, p_value); + postHocResults[i].pvalue = p_value; + } + } + void Statistics::setResultsOrder() + { + int c = 0; + for (auto& item : postHocResults) { + item.idx = c++; + } + + } + void Statistics::restoreResultsOrder() + { + // Restore the order of the results + std::sort(postHocResults.begin(), postHocResults.end(), [](const PostHocLine& a, const PostHocLine& b) { + return a.idx < b.idx; + }); + } + void Statistics::postHocHolmTest() { if (!fitted) { fit(); } - std::stringstream oss; // Reference https://link.springer.com/article/10.1007/s44196-022-00083-8 // Post-hoc Holm test // Calculate the p-value for the models paired with the control model @@ -128,80 +211,66 @@ namespace platform { boost::math::normal dist(0.0, 1.0); double diff = sqrt(nModels * (nModels + 1) / (6.0 * nDatasets)); for (int i = 0; i < nModels; i++) { + PostHocLine line; + line.model = models[i]; + line.rank = ranks.at(models[i]); + line.wtl = wtl.at(i); + line.reject = false; if (i == controlIdx) { - stats[i] = 0.0; + postHocResults.push_back(line); continue; } double z = std::abs(ranks.at(models[controlIdx]) - ranks.at(models[i])) / diff; - double p_value = (long double)2 * (1 - cdf(dist, z)); - stats[i] = p_value; + line.pvalue = (long double)2 * (1 - cdf(dist, z)); + line.reject = (line.pvalue < significance); + postHocResults.push_back(line); } - // Sort the models by p-value - std::vector> statsOrder; - for (const auto& stat : stats) { - statsOrder.push_back({ stat.first, stat.second }); - } - std::sort(statsOrder.begin(), statsOrder.end(), [](const std::pair& a, const std::pair& b) { - return a.second < b.second; + std::sort(postHocResults.begin(), postHocResults.end(), [](const PostHocLine& a, const PostHocLine& b) { + return a.rank < b.rank; }); + setResultsOrder(); + Holm_Bonferroni(); + restoreResultsOrder(); + } - // Holm adjustment - for 
(int i = 0; i < statsOrder.size(); ++i) { - auto item = statsOrder.at(i); - double before = i == 0 ? 0.0 : statsOrder.at(i - 1).second; - double p_value = std::min((double)1.0, item.second * (nModels - i)); - p_value = std::max(before, p_value); - statsOrder[i] = { item.first, p_value }; - } - holmResult.model = models.at(controlIdx); + void Statistics::postHocTestReport(bool friedmanResult, bool tex) + { + + std::stringstream oss; auto color = friedmanResult ? Colors::CYAN() : Colors::YELLOW(); oss << color; - oss << " *************************************************************************************************************" << std::endl; - oss << " Post-hoc Holm test: H0: 'There is no significant differences between the control model and the other models.'" << std::endl; + oss << " " << std::string(hlen + 25, '*') << std::endl; + oss << " Post-hoc " << postHocType << " test: H0: 'There is no significant differences between the control model and the other models.'" << std::endl; oss << " Control model: " << models.at(controlIdx) << std::endl; oss << " " << std::left << std::setw(maxModelName) << std::string("Model") << " p-value rank win tie loss Status" << std::endl; oss << " " << std::string(maxModelName, '=') << " ============ ========= === === ==== =============" << std::endl; - // sort ranks from lowest to highest - std::vector> ranksOrder; - for (const auto& rank : ranks) { - ranksOrder.push_back({ rank.first, rank.second }); - } - std::sort(ranksOrder.begin(), ranksOrder.end(), [](const std::pair& a, const std::pair& b) { - return a.second < b.second; - }); - // Show the control model info. 
- oss << " " << Colors::BLUE() << std::left << std::setw(maxModelName) << ranksOrder.at(0).first << " "; - oss << std::setw(12) << " " << std::setprecision(7) << std::fixed << " " << ranksOrder.at(0).second << std::endl; - for (const auto& item : ranksOrder) { - auto idx = distance(models.begin(), find(models.begin(), models.end(), item.first)); - double pvalue = 0.0; - for (const auto& stat : statsOrder) { - if (stat.first == idx) { - pvalue = stat.second; - } - } - holmResult.holmLines.push_back({ item.first, pvalue, item.second, wtl.at(idx), pvalue < significance }); - if (item.first == models.at(controlIdx)) { + bool first = true; + for (const auto& item : postHocResults) { + if (first) { + oss << " " << Colors::BLUE() << std::left << std::setw(maxModelName) << item.model << " "; + oss << std::setw(12) << " " << std::setprecision(7) << std::fixed << " " << item.rank << std::endl; + first = false; continue; } + auto pvalue = item.pvalue; auto colorStatus = pvalue > significance ? Colors::GREEN() : Colors::MAGENTA(); auto status = pvalue > significance ? Symbols::check_mark : Symbols::cross; auto textStatus = pvalue > significance ? 
" accepted H0" : " rejected H0"; - oss << " " << colorStatus << std::left << std::setw(maxModelName) << item.first << " "; - oss << std::setprecision(6) << std::scientific << pvalue << std::setprecision(7) << std::fixed << " " << item.second; - oss << " " << std::right << std::setw(3) << wtl.at(idx).win << " " << std::setw(3) << wtl.at(idx).tie << " " << std::setw(4) << wtl.at(idx).loss; + oss << " " << colorStatus << std::left << std::setw(maxModelName) << item.model << " "; + oss << std::setprecision(6) << std::scientific << pvalue << std::setprecision(7) << std::fixed << " " << item.rank; + oss << " " << std::right << std::setw(3) << item.wtl.win << " " << std::setw(3) << item.wtl.tie << " " << std::setw(4) << item.wtl.loss; oss << " " << status << textStatus << std::endl; } - oss << color << " *************************************************************************************************************" << std::endl; + oss << color << " " << std::string(hlen + 25, '*') << std::endl; oss << Colors::RESET(); if (output) { std::cout << oss.str(); } if (tex) { - BestResultsTex bestResultsTex; + BestResultsTex bestResultsTex(score); BestResultsMd bestResultsMd; - bestResultsTex.holm_test(holmResult, get_date() + " " + get_time()); - bestResultsMd.holm_test(holmResult, get_date() + " " + get_time()); + bestResultsTex.postHoc_test(postHocResults, postHocType, get_date() + " " + get_time()); + bestResultsMd.postHoc_test(postHocResults, postHocType, get_date() + " " + get_time()); } } bool Statistics::friedmanTest() @@ -213,7 +282,7 @@ namespace platform { // Friedman test // Calculate the Friedman statistic oss << Colors::BLUE() << std::endl; - oss << "***************************************************************************************************************" << std::endl; + oss << std::string(hlen, '*') << std::endl; oss << Colors::GREEN() << "Friedman test: H0: 'There is no significant differences between all the classifiers.'" << Colors::BLUE() << std::endl; 
double degreesOfFreedom = nModels - 1.0; double sumSquared = 0; @@ -238,23 +307,11 @@ namespace platform { oss << Colors::YELLOW() << "The null hypothesis H0 is accepted. Computed p-values will not be significant." << std::endl; result = false; } - oss << Colors::BLUE() << "***************************************************************************************************************" << Colors::RESET() << std::endl; + oss << Colors::BLUE() << std::string(hlen, '*') << Colors::RESET() << std::endl; if (output) { std::cout << oss.str(); } friedmanResult = { friedmanQ, criticalValue, p_value, result }; return result; } - FriedmanResult& Statistics::getFriedmanResult() - { - return friedmanResult; - } - HolmResult& Statistics::getHolmResult() - { - return holmResult; - } - std::map>& Statistics::getRanks() - { - return ranksModels; - } } // namespace platform diff --git a/src/best/Statistics.h b/src/best/Statistics.h index ee98c96..a6b5c4a 100644 --- a/src/best/Statistics.h +++ b/src/best/Statistics.h @@ -9,9 +9,9 @@ namespace platform { using json = nlohmann::ordered_json; struct WTL { - int win; - int tie; - int loss; + uint win; + uint tie; + uint loss; }; struct FriedmanResult { double statistic; @@ -19,29 +19,36 @@ namespace platform { long double pvalue; bool reject; }; - struct HolmLine { + struct PostHocLine { + uint idx; //index of the main order std::string model; long double pvalue; double rank; WTL wtl; bool reject; }; - struct HolmResult { - std::string model; - std::vector holmLines; - }; + class Statistics { public: - Statistics(const std::vector& models, const std::vector& datasets, const json& data, double significance = 0.05, bool output = true); + Statistics(const std::string& score, const std::vector& models, const std::vector& datasets, const json& data, double significance = 0.05, bool output = true); bool friedmanTest(); - void postHocHolmTest(bool friedmanResult, bool tex=false); - FriedmanResult& getFriedmanResult(); - HolmResult& 
getHolmResult(); - std::map>& getRanks(); + void postHocTest(); + void postHocTestReport(bool friedmanResult, bool tex); + int getControlIdx(); + FriedmanResult& getFriedmanResult() { return friedmanResult; } + std::vector& getPostHocResults() { return postHocResults; } + std::map>& getRanks() { return ranksModels; } // ranks of the models per dataset private: void fit(); + void postHocHolmTest(); + void postHocWilcoxonTest(); void computeRanks(); void computeWTL(); + void Holm_Bonferroni(); + void setResultsOrder(); // Set the order of the results based on the statistic analysis needed + void restoreResultsOrder(); // Restore the order of the results after the Holm-Bonferroni adjustment + const std::string& score; + std::string postHocType; const std::vector& models; const std::vector& datasets; const json& data; @@ -51,12 +58,14 @@ namespace platform { int nModels = 0; int nDatasets = 0; int controlIdx = 0; + int greaterAverage = -1; // The model with the greater average score std::map wtl; std::map ranks; int maxModelName = 0; int maxDatasetName = 0; + int hlen; // length of the line FriedmanResult friedmanResult; - HolmResult holmResult; + std::vector postHocResults; std::map> ranksModels; }; } diff --git a/src/best/WilcoxonTest.hpp b/src/best/WilcoxonTest.hpp new file mode 100644 index 0000000..dbf1c0c --- /dev/null +++ b/src/best/WilcoxonTest.hpp @@ -0,0 +1,245 @@ +#ifndef BEST_WILCOXON_TEST_HPP +#define BEST_WILCOXON_TEST_HPP +// WilcoxonTest.hpp +// Stand‑alone class for paired Wilcoxon signed‑rank post‑hoc analysis +// ------------------------------------------------------------------ +// * Constructor takes the *already‑loaded* nlohmann::json object plus the +// vectors of model and dataset names. +// * Internally selects a control model (highest average AUC) and builds all +// statistics (ranks, W/T/L counts, Wilcoxon p‑values). 
+// * Public API: +// int getControlIdx() const; +// PostHocResult getPostHocResult() const; +// +#include +#include +#include +#include +#include +#include +#include +#include "Statistics.h" + +namespace platform { + class WilcoxonTest { + public: + WilcoxonTest(const std::vector& models, const std::vector& datasets, + const json& data, double alpha = 0.05) : models_(models), datasets_(datasets), data_(data), alpha_(alpha) + { + buildAUCTable(); // extracts all AUCs into a dense matrix + computeAverageAUCs(); // per‑model mean (→ control selection) + computeAverageRanks(); // Friedman‑style ranks per model + selectControlModel(); // sets control_idx_ + buildPostHocResult(); // fills postHocResult_ + } + + int getControlIdx() const noexcept { return control_idx_; } + const std::vector& getPostHocResults() const noexcept { return postHocResults_; } + + private: + //-------------------------------------------------- helper structs ---- + // When a value is missing we keep NaN so that ordinary arithmetic still + // works (NaN simply propagates and we can test with std::isnan). + using Matrix = std::vector>; // [model][dataset] + + //------------------------------------------------- implementation ---- + void buildAUCTable() + { + const std::size_t M = models_.size(); + const std::size_t D = datasets_.size(); + auc_.assign(M, std::vector(D, std::numeric_limits::quiet_NaN())); + + for (std::size_t i = 0; i < M; ++i) { + const auto& model = models_[i]; + for (std::size_t j = 0; j < D; ++j) { + const auto& ds = datasets_[j]; + try { + auc_[i][j] = data_.at(model).at(ds).at(0).get(); + } + catch (...) { + // leave as NaN when value missing + } + } + } + } + + void computeAverageAUCs() + { + const std::size_t M = models_.size(); + avg_auc_.resize(M, std::numeric_limits::quiet_NaN()); + + for (std::size_t i = 0; i < M; ++i) { + double sum = 0.0; + std::size_t cnt = 0; + for (double v : auc_[i]) { + if (!std::isnan(v)) { sum += v; ++cnt; } + } + avg_auc_[i] = cnt ? 
sum / cnt : std::numeric_limits::quiet_NaN(); + } + } + + // Average rank across datasets (1 = best). + void computeAverageRanks() + { + const std::size_t M = models_.size(); + const std::size_t D = datasets_.size(); + rank_sum_.assign(M, 0.0); + rank_cnt_.assign(M, 0); + + const double EPS = 1e-10; + + for (std::size_t j = 0; j < D; ++j) { + // Collect present values for this dataset + std::vector> vals; // (auc, model_idx) + vals.reserve(M); + for (std::size_t i = 0; i < M; ++i) { + if (!std::isnan(auc_[i][j])) + vals.emplace_back(auc_[i][j], i); + } + if (vals.empty()) continue; // no info for this dataset + + // Sort descending (higher AUC better) + std::sort(vals.begin(), vals.end(), [](auto a, auto b) { + return a.first > b.first; + }); + + // Assign ranks with average for ties + std::size_t k = 0; + while (k < vals.size()) { + std::size_t l = k + 1; + while (l < vals.size() && std::fabs(vals[l].first - vals[k].first) < EPS) ++l; + const double avg_rank = (k + 1 + l) * 0.5; // average of ranks (1‑based) + for (std::size_t m = k; m < l; ++m) { + const auto idx = vals[m].second; + rank_sum_[idx] += avg_rank; + ++rank_cnt_[idx]; + } + k = l; + } + } + + // Final average + avg_rank_.resize(M, std::numeric_limits::quiet_NaN()); + for (std::size_t i = 0; i < M; ++i) { + avg_rank_[i] = rank_cnt_[i] ? 
rank_sum_[i] / rank_cnt_[i] + : std::numeric_limits::quiet_NaN(); + } + } + + void selectControlModel() + { + // pick model with highest average AUC (ties → first) + control_idx_ = 0; + for (std::size_t i = 1; i < avg_auc_.size(); ++i) { + if (avg_auc_[i] > avg_auc_[control_idx_]) control_idx_ = static_cast(i); + } + } + + void buildPostHocResult() + { + const std::size_t M = models_.size(); + const std::size_t D = datasets_.size(); + const std::string& control_name = models_[control_idx_]; + + const double practical_threshold = 0.0005; // same heuristic as original code + + for (std::size_t i = 0; i < M; ++i) { + PostHocLine line; + line.model = models_[i]; + line.rank = avg_auc_[i]; + + WTL wtl = { 0, 0, 0 }; // win, tie, loss + std::vector differences; + differences.reserve(D); + + for (std::size_t j = 0; j < D; ++j) { + double auc_control = auc_[control_idx_][j]; + double auc_other = auc_[i][j]; + if (std::isnan(auc_control) || std::isnan(auc_other)) continue; + + double diff = auc_control - auc_other; // control − comparison + if (std::fabs(diff) <= practical_threshold) { + ++wtl.tie; + } else if (diff < 0) { + ++wtl.win; // comparison wins + } else { + ++wtl.loss; // control wins + } + differences.push_back(diff); + } + + line.wtl = wtl; + line.pvalue = differences.empty() ? 
1.0L : static_cast(wilcoxonSignedRankTest(differences)); + line.reject = (line.pvalue < alpha_); + + postHocResults_.push_back(std::move(line)); + } + // Sort results by rank (descending) + std::sort(postHocResults_.begin(), postHocResults_.end(), [](const PostHocLine& a, const PostHocLine& b) { + return a.rank > b.rank; + }); + } + + // ------------------------------------------------ Wilcoxon (private) -- + static double wilcoxonSignedRankTest(const std::vector& diffs) + { + if (diffs.empty()) return 1.0; + + // Build |diff| + sign vector (exclude zeros) + struct Node { double absval; int sign; }; + std::vector v; + v.reserve(diffs.size()); + for (double d : diffs) { + if (d != 0.0) v.push_back({ std::fabs(d), d > 0 ? 1 : -1 }); + } + if (v.empty()) return 1.0; + + // Sort by absolute value + std::sort(v.begin(), v.end(), [](const Node& a, const Node& b) { return a.absval < b.absval; }); + + const double EPS = 1e-10; + const std::size_t n = v.size(); + std::vector ranks(n, 0.0); + + std::size_t i = 0; + while (i < n) { + std::size_t j = i + 1; + while (j < n && std::fabs(v[j].absval - v[i].absval) < EPS) ++j; + double avg_rank = (i + 1 + j) * 0.5; // 1‑based ranks + for (std::size_t k = i; k < j; ++k) ranks[k] = avg_rank; + i = j; + } + + double w_plus = 0.0, w_minus = 0.0; + for (std::size_t k = 0; k < n; ++k) { + if (v[k].sign > 0) w_plus += ranks[k]; + else w_minus += ranks[k]; + } + double w = std::min(w_plus, w_minus); + double mean_w = n * (n + 1) / 4.0; + double sd_w = std::sqrt(n * (n + 1) * (2 * n + 1) / 24.0); + if (sd_w == 0.0) return 1.0; // degenerate (all diffs identical) + + double z = (w - mean_w) / sd_w; + double p_two = std::erfc(std::fabs(z) / std::sqrt(2.0)); // 2‑sided tail + return p_two; + } + + //-------------------------------------------------------- data ---- + std::vector models_; + std::vector datasets_; + json data_; + double alpha_; + + Matrix auc_; // [model][dataset] + std::vector avg_auc_; // mean AUC per model + std::vector 
avg_rank_; // mean rank per model + std::vector rank_sum_; // helper for ranks + std::vector rank_cnt_; // datasets counted per model + + int control_idx_ = -1; + std::vector postHocResults_; + }; + +} // namespace platform +#endif // BEST_WILCOXON_TEST_HPP \ No newline at end of file diff --git a/src/commands/b_best.cpp b/src/commands/b_best.cpp index 39ec19a..8c8b89e 100644 --- a/src/commands/b_best.cpp +++ b/src/commands/b_best.cpp @@ -4,16 +4,18 @@ #include "main/modelRegister.h" #include "common/Paths.h" #include "common/Colors.h" +#include "common/Utils.h" #include "best/BestResults.h" +#include "common/DotEnv.h" #include "config_platform.h" void manageArguments(argparse::ArgumentParser& program) { - program.add_argument("-m", "--model") - .help("Model to use or any") - .default_value("any"); + auto env = platform::DotEnv(); + program.add_argument("-m", "--model").help("Model to use or any").default_value("any"); + program.add_argument("--folder").help("Results folder to use").default_value(platform::Paths::results()); program.add_argument("-d", "--dataset").default_value("any").help("Filter results of the selected model) (any for all datasets)"); - program.add_argument("-s", "--score").default_value("accuracy").help("Filter results of the score name supplied"); + program.add_argument("-s", "--score").default_value(env.get("score")).help("Filter results of the score name supplied"); program.add_argument("--friedman").help("Friedman test").default_value(false).implicit_value(true); program.add_argument("--excel").help("Output to excel").default_value(false).implicit_value(true); program.add_argument("--tex").help("Output results to TeX & Markdown files").default_value(false).implicit_value(true); @@ -38,12 +40,16 @@ int main(int argc, char** argv) { argparse::ArgumentParser program("b_best", { platform_project_version.begin(), platform_project_version.end() }); manageArguments(program); - std::string model, dataset, score; + std::string model, dataset, score, 
folder; bool build, report, friedman, excel, tex, index; double level; try { program.parse_args(argc, argv); model = program.get("model"); + folder = program.get("folder"); + if (folder.back() != '/') { + folder += '/'; + } dataset = program.get("dataset"); score = program.get("score"); friedman = program.get("friedman"); @@ -66,7 +72,7 @@ int main(int argc, char** argv) exit(1); } // Generate report - auto results = platform::BestResults(platform::Paths::results(), score, model, dataset, friedman, level); + auto results = platform::BestResults(folder, score, model, dataset, friedman, level); if (model == "any") { results.buildAll(); results.reportAll(excel, tex, index); @@ -75,6 +81,11 @@ int main(int argc, char** argv) std::cout << Colors::GREEN() << fileName << " created!" << Colors::RESET() << std::endl; results.reportSingle(excel); } + if (excel) { + auto fileName = results.getExcelFileName(); + std::cout << "Opening " << fileName << std::endl; + platform::openFile(fileName); + } std::cout << Colors::RESET(); return 0; } diff --git a/src/commands/b_grid.cpp b/src/commands/b_grid.cpp index b6efd56..b1c6244 100644 --- a/src/commands/b_grid.cpp +++ b/src/commands/b_grid.cpp @@ -232,6 +232,7 @@ void experiment(argparse::ArgumentParser& program) struct platform::ConfigGrid config; auto arguments = platform::ArgumentsExperiment(program, platform::experiment_t::GRID); arguments.parse(); + auto path_results = arguments.getPathResults(); auto grid_experiment = platform::GridExperiment(arguments, config); platform::Timer timer; timer.start(); @@ -250,7 +251,7 @@ void experiment(argparse::ArgumentParser& program) auto duration = timer.getDuration(); experiment.setDuration(duration); if (grid_experiment.haveToSaveResults()) { - experiment.saveResult(); + experiment.saveResult(path_results); } experiment.report(); std::cout << "Process took " << duration << std::endl; diff --git a/src/commands/b_list.cpp b/src/commands/b_list.cpp index 5309101..96950b0 100644 --- 
a/src/commands/b_list.cpp +++ b/src/commands/b_list.cpp @@ -8,6 +8,7 @@ #include "common/Paths.h" #include "common/Colors.h" #include "common/Datasets.h" +#include "common/Utils.h" #include "reports/DatasetsExcel.h" #include "reports/DatasetsConsole.h" #include "results/ResultsDatasetConsole.h" @@ -24,9 +25,13 @@ void list_datasets(argparse::ArgumentParser& program) std::cout << report.getOutput(); if (excel) { auto data = report.getData(); - auto report = platform::DatasetsExcel(); - report.report(data); - std::cout << std::endl << Colors::GREEN() << "Output saved in " << report.getFileName() << std::endl; + auto ereport = new platform::DatasetsExcel(); + ereport->report(data); + std::cout << std::endl << Colors::GREEN() << "Output saved in " << ereport->getFileName() << std::endl; + auto fileName = ereport->getExcelFileName(); + delete ereport; + std::cout << "Opening " << fileName << std::endl; + platform::openFile(fileName); } } @@ -42,9 +47,13 @@ void list_results(argparse::ArgumentParser& program) std::cout << report.getOutput(); if (excel) { auto data = report.getData(); - auto report = platform::ResultsDatasetExcel(); - report.report(data); - std::cout << std::endl << Colors::GREEN() << "Output saved in " << report.getFileName() << std::endl; + auto ereport = new platform::ResultsDatasetExcel(); + ereport->report(data); + std::cout << std::endl << Colors::GREEN() << "Output saved in " << ereport->getFileName() << std::endl; + auto fileName = ereport->getExcelFileName(); + delete ereport; + std::cout << "Opening " << fileName << std::endl; + platform::openFile(fileName); } } diff --git a/src/commands/b_main.cpp b/src/commands/b_main.cpp index f04a79f..03002d5 100644 --- a/src/commands/b_main.cpp +++ b/src/commands/b_main.cpp @@ -18,6 +18,7 @@ int main(int argc, char** argv) */ // Initialize the experiment class with the command line arguments auto experiment = arguments.initializedExperiment(); + auto path_results = arguments.getPathResults(); 
platform::Timer timer; timer.start(); experiment.go(); @@ -27,7 +28,7 @@ int main(int argc, char** argv) experiment.report(); } if (arguments.haveToSaveResults()) { - experiment.saveResult(); + experiment.saveResult(path_results); } if (arguments.doGraph()) { experiment.saveGraph(); diff --git a/src/commands/b_manage.cpp b/src/commands/b_manage.cpp index 0dda157..8a0deb4 100644 --- a/src/commands/b_manage.cpp +++ b/src/commands/b_manage.cpp @@ -1,7 +1,8 @@ + +#include #include #include -#include -#include +#include "common/Paths.h" #include #include "manage/ManageScreen.h" #include @@ -13,6 +14,7 @@ void manageArguments(argparse::ArgumentParser& program, int argc, char** argv) { program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)"); program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied"); + program.add_argument("--folder").help("Results folder to use").default_value(platform::Paths::results()); program.add_argument("--platform").default_value("any").help("Filter results of the selected platform"); program.add_argument("--complete").help("Show only results with all datasets").default_value(false).implicit_value(true); program.add_argument("--partial").help("Show only partial results").default_value(false).implicit_value(true); @@ -51,71 +53,17 @@ void handleResize(int sig) manager->updateSize(rows, cols); } -void openFile(const std::string& fileName) -{ - // #ifdef __APPLE__ - // // macOS uses the "open" command - // std::string command = "open"; - // #elif defined(__linux__) - // // Linux typically uses "xdg-open" - // std::string command = "xdg-open"; - // #else - // // For other OSes, do nothing or handle differently - // std::cerr << "Unsupported platform." 
<< std::endl; - // return; - // #endif - // execlp(command.c_str(), command.c_str(), fileName.c_str(), NULL); -#ifdef __APPLE__ - const char* tool = "/usr/bin/open"; -#elif defined(__linux__) - const char* tool = "/usr/bin/xdg-open"; -#else - std::cerr << "Unsupported platform." << std::endl; - return; -#endif - // We'll build an argv array for execve: - std::vector argv; - argv.push_back(const_cast(tool)); // argv[0] - argv.push_back(const_cast(fileName.c_str())); // argv[1] - argv.push_back(nullptr); - - // Make a new environment array, skipping BASH_FUNC_ variables - std::vector filteredEnv; - for (char** env = environ; *env != nullptr; ++env) { - // *env is a string like "NAME=VALUE" - // We want to skip those starting with "BASH_FUNC_" - if (strncmp(*env, "BASH_FUNC_", 10) == 0) { - // skip it - continue; - } - filteredEnv.push_back(*env); - } - - // Convert filteredEnv into a char* array - std::vector envp; - for (auto& var : filteredEnv) { - envp.push_back(const_cast(var.c_str())); - } - envp.push_back(nullptr); - - // Now call execve with the cleaned environment - // NOTE: You may need a full path to the tool if it's not in PATH, or use which() logic - // For now, let's assume "open" or "xdg-open" is found in the default PATH: - execve(tool, argv.data(), envp.data()); - - // If we reach here, execve failed - perror("execve failed"); - // This would terminate your current process if it's not in a child - // Usually you'd do something like: - _exit(EXIT_FAILURE); -} int main(int argc, char** argv) { auto program = argparse::ArgumentParser("b_manage", { platform_project_version.begin(), platform_project_version.end() }); manageArguments(program, argc, argv); std::string model = program.get("model"); + std::string path = program.get("folder"); + if (path.back() != '/') { + path += '/'; + } std::string score = program.get("score"); std::string platform = program.get("platform"); bool complete = program.get("complete"); @@ -125,13 +73,13 @@ int main(int argc, 
char** argv) partial = false; signal(SIGWINCH, handleResize); auto [rows, cols] = numRowsCols(); - manager = new platform::ManageScreen(rows, cols, model, score, platform, complete, partial, compare); + manager = new platform::ManageScreen(path, rows, cols, model, score, platform, complete, partial, compare); manager->doMenu(); auto fileName = manager->getExcelFileName(); delete manager; if (!fileName.empty()) { std::cout << "Opening " << fileName << std::endl; - openFile(fileName); + platform::openFile(fileName); } return 0; } diff --git a/src/common/Dataset.cpp b/src/common/Dataset.cpp index 26c0882..c635e0f 100644 --- a/src/common/Dataset.cpp +++ b/src/common/Dataset.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include "Dataset.h" namespace platform { diff --git a/src/common/Paths.h b/src/common/Paths.h index 15a42d1..6861457 100644 --- a/src/common/Paths.h +++ b/src/common/Paths.h @@ -49,6 +49,7 @@ namespace platform { return "BestResults_" + score + ".xlsx"; } static std::string excelResults() { return "some_results.xlsx"; } + static std::string excelDatasets() { return "datasets.xlsx"; } static std::string grid_input(const std::string& model) { return grid() + "grid_" + model + "_input.json"; @@ -73,6 +74,7 @@ namespace platform { { return "post_hoc.md"; } + }; } #endif \ No newline at end of file diff --git a/src/common/Utils.h b/src/common/Utils.h index 92e5a4a..e371d89 100644 --- a/src/common/Utils.h +++ b/src/common/Utils.h @@ -1,5 +1,7 @@ #ifndef UTILS_H #define UTILS_H + +#include #include #include #include @@ -66,5 +68,64 @@ namespace platform { oss << std::put_time(timeinfo, "%H:%M:%S"); return oss.str(); } + static void openFile(const std::string& fileName) + { + // #ifdef __APPLE__ + // // macOS uses the "open" command + // std::string command = "open"; + // #elif defined(__linux__) + // // Linux typically uses "xdg-open" + // std::string command = "xdg-open"; + // #else + // // For other OSes, do nothing or handle differently + // std::cerr << 
"Unsupported platform." << std::endl; + // return; + // #endif + // execlp(command.c_str(), command.c_str(), fileName.c_str(), NULL); +#ifdef __APPLE__ + const char* tool = "/usr/bin/open"; +#elif defined(__linux__) + const char* tool = "/usr/bin/xdg-open"; +#else + std::cerr << "Unsupported platform." << std::endl; + return; +#endif + + // We'll build an argv array for execve: + std::vector argv; + argv.push_back(const_cast(tool)); // argv[0] + argv.push_back(const_cast(fileName.c_str())); // argv[1] + argv.push_back(nullptr); + + // Make a new environment array, skipping BASH_FUNC_ variables + std::vector filteredEnv; + for (char** env = environ; *env != nullptr; ++env) { + // *env is a string like "NAME=VALUE" + // We want to skip those starting with "BASH_FUNC_" + if (strncmp(*env, "BASH_FUNC_", 10) == 0) { + // skip it + continue; + } + filteredEnv.push_back(*env); + } + + // Convert filteredEnv into a char* array + std::vector envp; + for (auto& var : filteredEnv) { + envp.push_back(const_cast(var.c_str())); + } + envp.push_back(nullptr); + + // Now call execve with the cleaned environment + // NOTE: You may need a full path to the tool if it's not in PATH, or use which() logic + // For now, let's assume "open" or "xdg-open" is found in the default PATH: + execve(tool, argv.data(), envp.data()); + + // If we reach here, execve failed + perror("execve failed"); + // This would terminate your current process if it's not in a child + // Usually you'd do something like: + _exit(EXIT_FAILURE); + } } #endif \ No newline at end of file diff --git a/src/experimental_clfs/AdaBoost.cpp b/src/experimental_clfs/AdaBoost.cpp new file mode 100644 index 0000000..563fe34 --- /dev/null +++ b/src/experimental_clfs/AdaBoost.cpp @@ -0,0 +1,492 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// 
*************************************************************** + +#include "AdaBoost.h" +#include "DecisionTree.h" +#include +#include +#include +#include +#include +#include "TensorUtils.hpp" + +// Conditional debug macro for performance-critical sections +#define DEBUG_LOG(condition, ...) \ + do { \ + if (__builtin_expect((condition), 0)) { \ + std::cout << __VA_ARGS__ << std::endl; \ + } \ + } while(0) + +namespace bayesnet { + + AdaBoost::AdaBoost(int n_estimators, int max_depth) + : Ensemble(true), n_estimators(n_estimators), base_max_depth(max_depth), n(0), n_classes(0) + { + validHyperparameters = { "n_estimators", "base_max_depth" }; + } + + // Versión optimizada de buildModel - Reemplazar en AdaBoost.cpp: + + void AdaBoost::buildModel(const torch::Tensor& weights) + { + // Initialize variables + models.clear(); + alphas.clear(); + training_errors.clear(); + + // Initialize n (number of features) and n_classes + n = dataset.size(0) - 1; // Exclude the label row + n_classes = states[className].size(); + + // Initialize sample weights uniformly + int n_samples = dataset.size(1); + sample_weights = torch::ones({ n_samples }) / n_samples; + + // If initial weights are provided, incorporate them + if (weights.defined() && weights.numel() > 0) { + if (weights.size(0) != n_samples) { + throw std::runtime_error("weights must have the same length as number of samples"); + } + sample_weights = weights.clone(); + normalizeWeights(); + } + + // Conditional debug information (only when debug is enabled) + DEBUG_LOG(debug, "Starting AdaBoost training with " << n_estimators << " estimators\n" + << "Number of classes: " << n_classes << "\n" + << "Number of features: " << n << "\n" + << "Number of samples: " << n_samples); + + // Pre-compute random guess error threshold + const double random_guess_error = 1.0 - (1.0 / static_cast(n_classes)); + + // Main AdaBoost training loop (SAMME algorithm) + for (int iter = 0; iter < n_estimators; ++iter) { + // Train base estimator 
with current sample weights + auto estimator = trainBaseEstimator(sample_weights); + + // Calculate weighted error + double weighted_error = calculateWeightedError(estimator.get(), sample_weights); + training_errors.push_back(weighted_error); + + // According to SAMME, we need error < random_guess_error + if (weighted_error >= random_guess_error) { + DEBUG_LOG(debug, "Error >= random guess (" << random_guess_error << "), stopping"); + // If only one estimator and it's worse than random, keep it with zero weight + if (models.empty()) { + models.push_back(std::move(estimator)); + alphas.push_back(0.0); + } + break; // Stop boosting + } + + // Check for perfect classification BEFORE calculating alpha + if (weighted_error <= 1e-10) { + DEBUG_LOG(debug, "Perfect classification achieved (error=" << weighted_error << ")"); + + // For perfect classification, use a large but finite alpha + double alpha = 10.0 + std::log(static_cast(n_classes - 1)); + + // Store the estimator and its weight + models.push_back(std::move(estimator)); + alphas.push_back(alpha); + + DEBUG_LOG(debug, "Iteration " << iter << ":\n" + << " Weighted error: " << weighted_error << "\n" + << " Alpha (finite): " << alpha << "\n" + << " Random guess error: " << random_guess_error); + + break; // Stop training as we have a perfect classifier + } + + // Calculate alpha (estimator weight) using SAMME formula + // alpha = log((1 - err) / err) + log(K - 1) + // Clamp weighted_error to avoid division by zero and infinite alpha + double clamped_error = std::max(1e-15, std::min(1.0 - 1e-15, weighted_error)); + double alpha = std::log((1.0 - clamped_error) / clamped_error) + + std::log(static_cast(n_classes - 1)); + + // Clamp alpha to reasonable bounds to avoid numerical issues + alpha = std::max(-10.0, std::min(10.0, alpha)); + + // Store the estimator and its weight + models.push_back(std::move(estimator)); + alphas.push_back(alpha); + + // Update sample weights (only if this is not the last iteration) + if 
(iter < n_estimators - 1) { + updateSampleWeights(models.back().get(), alpha); + normalizeWeights(); + } + + DEBUG_LOG(debug, "Iteration " << iter << ":\n" + << " Weighted error: " << weighted_error << "\n" + << " Alpha: " << alpha << "\n" + << " Random guess error: " << random_guess_error); + } + + // Set the number of models actually trained + n_models = models.size(); + DEBUG_LOG(debug, "AdaBoost training completed with " << n_models << " models"); + } + + void AdaBoost::trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) + { + // Call buildModel which does the actual training + buildModel(weights); + fitted = true; + } + + std::unique_ptr AdaBoost::trainBaseEstimator(const torch::Tensor& weights) + { + // Create a decision tree with specified max depth + auto tree = std::make_unique(base_max_depth); + + // Ensure weights are properly normalized + auto normalized_weights = weights / weights.sum(); + + // Fit the tree with the current sample weights + tree->fit(dataset, features, className, states, normalized_weights, Smoothing_t::NONE); + + return tree; + } + + double AdaBoost::calculateWeightedError(Classifier* estimator, const torch::Tensor& weights) + { + // Get features and labels from dataset (avoid repeated indexing) + auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }); + auto y_true = dataset.index({ -1, torch::indexing::Slice() }); + + // Get predictions from the estimator + auto y_pred = estimator->predict(X); + + // Vectorized error calculation using PyTorch operations + auto incorrect = (y_pred != y_true).to(torch::kDouble); + + // Direct dot product for weighted error (more efficient than sum) + double weighted_error = torch::dot(incorrect, weights).item(); + + // Clamp to valid range in one operation + return std::clamp(weighted_error, 1e-15, 1.0 - 1e-15); + } + + void AdaBoost::updateSampleWeights(Classifier* estimator, double alpha) + { + // Get predictions from the estimator 
(reuse from calculateWeightedError if possible) + auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }); + auto y_true = dataset.index({ -1, torch::indexing::Slice() }); + auto y_pred = estimator->predict(X); + + // Vectorized weight update using PyTorch operations + auto incorrect = (y_pred != y_true).to(torch::kDouble); + + // Single vectorized operation instead of element-wise multiplication + sample_weights *= torch::exp(alpha * incorrect); + + // Vectorized clamping for numerical stability + sample_weights = torch::clamp(sample_weights, 1e-15, 1e15); + } + + void AdaBoost::normalizeWeights() + { + // Single-pass normalization using PyTorch operations + double sum_weights = torch::sum(sample_weights).item(); + + if (__builtin_expect(sum_weights <= 0, 0)) { + // Reset to uniform if all weights are zero/negative (rare case) + sample_weights = torch::ones_like(sample_weights) / sample_weights.size(0); + } else { + // Vectorized normalization + sample_weights /= sum_weights; + + // Vectorized minimum weight enforcement + sample_weights = torch::clamp_min(sample_weights, 1e-15); + + // Renormalize after clamping (if any weights were clamped) + double new_sum = torch::sum(sample_weights).item(); + if (new_sum != 1.0) { + sample_weights /= new_sum; + } + } + } + + std::vector AdaBoost::graph(const std::string& title) const + { + // Create a graph representation of the AdaBoost ensemble + std::vector graph_lines; + + // Header + graph_lines.push_back("digraph AdaBoost {"); + graph_lines.push_back(" rankdir=TB;"); + graph_lines.push_back(" node [shape=box];"); + + if (!title.empty()) { + graph_lines.push_back(" label=\"" + title + "\";"); + graph_lines.push_back(" labelloc=t;"); + } + + // Add input node + graph_lines.push_back(" Input [shape=ellipse, label=\"Input Features\"];"); + + // Add base estimators + for (size_t i = 0; i < models.size(); ++i) { + std::stringstream ss; + ss << " Estimator" << i << " [label=\"Base 
Estimator " << i + 1 + << "\\nα = " << std::fixed << std::setprecision(3) << alphas[i] << "\"];"; + graph_lines.push_back(ss.str()); + + // Connect input to estimator + ss.str(""); + ss << " Input -> Estimator" << i << ";"; + graph_lines.push_back(ss.str()); + } + + // Add combination node + graph_lines.push_back(" Combination [shape=diamond, label=\"Weighted Vote\"];"); + + // Connect estimators to combination + for (size_t i = 0; i < models.size(); ++i) { + std::stringstream ss; + ss << " Estimator" << i << " -> Combination;"; + graph_lines.push_back(ss.str()); + } + + // Add output node + graph_lines.push_back(" Output [shape=ellipse, label=\"Final Prediction\"];"); + graph_lines.push_back(" Combination -> Output;"); + + // Close graph + graph_lines.push_back("}"); + + return graph_lines; + } + + void AdaBoost::checkValues() const + { + if (n_estimators <= 0) { + throw std::invalid_argument("n_estimators must be positive"); + } + if (base_max_depth <= 0) { + throw std::invalid_argument("base_max_depth must be positive"); + } + } + + void AdaBoost::setHyperparameters(const nlohmann::json& hyperparameters_) + { + auto hyperparameters = hyperparameters_; + // Set hyperparameters from JSON + auto it = hyperparameters.find("n_estimators"); + if (it != hyperparameters.end()) { + n_estimators = it->get(); + hyperparameters.erase("n_estimators"); + } + + it = hyperparameters.find("base_max_depth"); + if (it != hyperparameters.end()) { + base_max_depth = it->get(); + hyperparameters.erase("base_max_depth"); + } + checkValues(); + Ensemble::setHyperparameters(hyperparameters); + } + + int AdaBoost::predictSample(const torch::Tensor& x) const + { + // Early validation (keep essential checks only) + if (!fitted || models.empty()) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + // Pre-allocate and reuse memory + static thread_local std::vector class_votes_cache; + if (class_votes_cache.size() != static_cast(n_classes)) { + class_votes_cache.resize(n_classes); + 
} + std::fill(class_votes_cache.begin(), class_votes_cache.end(), 0.0); + + // Optimized voting loop - avoid exception handling in hot path + for (size_t i = 0; i < models.size(); ++i) { + double alpha = alphas[i]; + if (alpha <= 0 || !std::isfinite(alpha)) continue; + + // Direct cast and call - avoid virtual dispatch overhead + int predicted_class = static_cast(models[i].get())->predictSample(x); + + // Bounds check with branch prediction hint + if (__builtin_expect(predicted_class >= 0 && predicted_class < n_classes, 1)) { + class_votes_cache[predicted_class] += alpha; + } + } + + // Fast argmax using iterators + return std::distance(class_votes_cache.begin(), + std::max_element(class_votes_cache.begin(), class_votes_cache.end())); + } + + torch::Tensor AdaBoost::predictProbaSample(const torch::Tensor& x) const + { + // Early validation + if (!fitted || models.empty()) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + // Use stack allocation for small arrays (typical case: n_classes <= 32) + constexpr int STACK_THRESHOLD = 32; + double stack_votes[STACK_THRESHOLD]; + std::vector heap_votes; + double* class_votes; + + if (n_classes <= STACK_THRESHOLD) { + class_votes = stack_votes; + std::fill_n(class_votes, n_classes, 0.0); + } else { + heap_votes.resize(n_classes, 0.0); + class_votes = heap_votes.data(); + } + + double total_votes = 0.0; + + // Optimized voting loop + for (size_t i = 0; i < models.size(); ++i) { + double alpha = alphas[i]; + if (alpha <= 0 || !std::isfinite(alpha)) continue; + + int predicted_class = static_cast(models[i].get())->predictSample(x); + + if (__builtin_expect(predicted_class >= 0 && predicted_class < n_classes, 1)) { + class_votes[predicted_class] += alpha; + total_votes += alpha; + } + } + + // Direct tensor creation with pre-computed size + torch::Tensor class_probs = torch::empty({ n_classes }, torch::TensorOptions().dtype(torch::kFloat32)); + auto probs_accessor = class_probs.accessor(); + + if 
(__builtin_expect(total_votes > 0.0, 1)) { + // Vectorized probability calculation + const double inv_total = 1.0 / total_votes; + for (int j = 0; j < n_classes; ++j) { + probs_accessor[j] = static_cast(class_votes[j] * inv_total); + } + } else { + // Uniform distribution fallback + const float uniform_prob = 1.0f / n_classes; + for (int j = 0; j < n_classes; ++j) { + probs_accessor[j] = uniform_prob; + } + } + + return class_probs; + } + + torch::Tensor AdaBoost::predict_proba(torch::Tensor& X) + { + if (!fitted || models.empty()) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + // Input validation + if (X.size(0) != n) { + throw std::runtime_error("Input has wrong number of features. Expected " + + std::to_string(n) + " but got " + std::to_string(X.size(0))); + } + + const int n_samples = X.size(1); + + // Pre-allocate output tensor with correct layout + torch::Tensor probabilities = torch::empty({ n_samples, n_classes }, + torch::TensorOptions().dtype(torch::kFloat32)); + + // Convert to contiguous memory if needed (optimization for memory access) + if (!X.is_contiguous()) { + X = X.contiguous(); + } + + // Batch processing with memory-efficient sample extraction + for (int i = 0; i < n_samples; ++i) { + // Extract sample without unnecessary copies + auto sample = X.select(1, i); + + // Direct assignment to pre-allocated tensor + probabilities[i] = predictProbaSample(sample); + } + + return probabilities; + } + + std::vector> AdaBoost::predict_proba(std::vector>& X) + { + const size_t n_samples = X[0].size(); + + // Pre-allocate result with exact size + std::vector> result; + result.reserve(n_samples); + + // Avoid repeated allocations + for (size_t i = 0; i < n_samples; ++i) { + result.emplace_back(n_classes, 0.0); + } + + // Convert to tensor only once (batch conversion is more efficient) + torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); + torch::Tensor proba_tensor = predict_proba(X_tensor); + + // Optimized tensor-to-vector 
conversion + auto proba_accessor = proba_tensor.accessor(); + for (size_t i = 0; i < n_samples; ++i) { + for (int j = 0; j < n_classes; ++j) { + result[i][j] = static_cast(proba_accessor[i][j]); + } + } + + return result; + } + + torch::Tensor AdaBoost::predict(torch::Tensor& X) + { + if (!fitted || models.empty()) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + if (X.size(0) != n) { + throw std::runtime_error("Input has wrong number of features. Expected " + + std::to_string(n) + " but got " + std::to_string(X.size(0))); + } + + const int n_samples = X.size(1); + + // Pre-allocate with correct dtype + torch::Tensor predictions = torch::empty({ n_samples }, torch::TensorOptions().dtype(torch::kInt32)); + auto pred_accessor = predictions.accessor(); + + // Ensure contiguous memory layout + if (!X.is_contiguous()) { + X = X.contiguous(); + } + + // Optimized prediction loop + for (int i = 0; i < n_samples; ++i) { + auto sample = X.select(1, i); + pred_accessor[i] = predictSample(sample); + } + + return predictions; + } + + std::vector AdaBoost::predict(std::vector>& X) + { + // Single tensor conversion for batch processing + torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); + torch::Tensor predictions_tensor = predict(X_tensor); + + // Optimized tensor-to-vector conversion + std::vector result = platform::TensorUtils::to_vector(predictions_tensor); + return result; + } + +} // namespace bayesnet \ No newline at end of file diff --git a/src/experimental_clfs/AdaBoost.h b/src/experimental_clfs/AdaBoost.h new file mode 100644 index 0000000..1d5c729 --- /dev/null +++ b/src/experimental_clfs/AdaBoost.h @@ -0,0 +1,81 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#ifndef ADABOOST_H +#define ADABOOST_H + +#include +#include 
+#include "bayesnet/ensembles/Ensemble.h" + +namespace bayesnet { + class AdaBoost : public Ensemble { + public: + explicit AdaBoost(int n_estimators = 100, int max_depth = 1); + virtual ~AdaBoost() = default; + + // Override base class methods + std::vector graph(const std::string& title = "") const override; + + // AdaBoost specific methods + void setNEstimators(int n_estimators) { this->n_estimators = n_estimators; checkValues(); } + int getNEstimators() const { return n_estimators; } + void setBaseMaxDepth(int depth) { this->base_max_depth = depth; checkValues(); } + int getBaseMaxDepth() const { return base_max_depth; } + + // Get the weight of each base estimator + std::vector getEstimatorWeights() const { return alphas; } + + // Get training errors for each iteration + std::vector getTrainingErrors() const { return training_errors; } + + // Override setHyperparameters from BaseClassifier + void setHyperparameters(const nlohmann::json& hyperparameters) override; + + torch::Tensor predict(torch::Tensor& X) override; + std::vector predict(std::vector>& X) override; + torch::Tensor predict_proba(torch::Tensor& X) override; + std::vector> predict_proba(std::vector>& X); + void setDebug(bool debug) { this->debug = debug; } + + protected: + void buildModel(const torch::Tensor& weights) override; + void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override; + + private: + int n_estimators; + int base_max_depth; // Max depth for base decision trees + std::vector alphas; // Weight of each base estimator + std::vector training_errors; // Training error at each iteration + torch::Tensor sample_weights; // Current sample weights + int n_classes; // Number of classes in the target variable + int n; // Number of features + + // Train a single base estimator + std::unique_ptr trainBaseEstimator(const torch::Tensor& weights); + + // Calculate weighted error + double calculateWeightedError(Classifier* estimator, const torch::Tensor& weights); + + // 
Update sample weights based on predictions + void updateSampleWeights(Classifier* estimator, double alpha); + + // Normalize weights to sum to 1 + void normalizeWeights(); + + // Check if hyperparameters values are valid + void checkValues() const; + + // Make predictions for a single sample + int predictSample(const torch::Tensor& x) const; + + // Make probabilistic predictions for a single sample + torch::Tensor predictProbaSample(const torch::Tensor& x) const; + bool debug = false; // Enable debug mode for debug output + }; +} + +#endif // ADABOOST_H \ No newline at end of file diff --git a/src/experimental_clfs/DecisionTree.cpp b/src/experimental_clfs/DecisionTree.cpp new file mode 100644 index 0000000..307186a --- /dev/null +++ b/src/experimental_clfs/DecisionTree.cpp @@ -0,0 +1,495 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#include "DecisionTree.h" +#include +#include +#include +#include +#include +#include "TensorUtils.hpp" + +namespace bayesnet { + + DecisionTree::DecisionTree(int max_depth, int min_samples_split, int min_samples_leaf) + : Classifier(Network()), max_depth(max_depth), + min_samples_split(min_samples_split), min_samples_leaf(min_samples_leaf) + { + validHyperparameters = { "max_depth", "min_samples_split", "min_samples_leaf" }; + } + + void DecisionTree::setHyperparameters(const nlohmann::json& hyperparameters_) + { + auto hyperparameters = hyperparameters_; + // Set hyperparameters from JSON + auto it = hyperparameters.find("max_depth"); + if (it != hyperparameters.end()) { + max_depth = it->get(); + hyperparameters.erase("max_depth"); // Remove 'order' if present + } + + it = hyperparameters.find("min_samples_split"); + if (it != hyperparameters.end()) { + min_samples_split = it->get(); + 
hyperparameters.erase("min_samples_split"); // Remove 'min_samples_split' if present + } + + it = hyperparameters.find("min_samples_leaf"); + if (it != hyperparameters.end()) { + min_samples_leaf = it->get(); + hyperparameters.erase("min_samples_leaf"); // Remove 'min_samples_leaf' if present + } + Classifier::setHyperparameters(hyperparameters); + checkValues(); + } + void DecisionTree::checkValues() + { + if (max_depth <= 0) { + throw std::invalid_argument("max_depth must be positive"); + } + if (min_samples_leaf <= 0) { + throw std::invalid_argument("min_samples_leaf must be positive"); + } + if (min_samples_split <= 0) { + throw std::invalid_argument("min_samples_split must be positive"); + } + } + void DecisionTree::buildModel(const torch::Tensor& weights) + { + // Extract features (X) and labels (y) from dataset + auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }).t(); + auto y = dataset.index({ -1, torch::indexing::Slice() }); + + if (X.size(0) != y.size(0)) { + throw std::runtime_error("X and y must have the same number of samples"); + } + + n_classes = states[className].size(); + + // Use provided weights or uniform weights + torch::Tensor sample_weights; + if (weights.defined() && weights.numel() > 0) { + if (weights.size(0) != X.size(0)) { + throw std::runtime_error("weights must have the same length as number of samples"); + } + sample_weights = weights; + } else { + sample_weights = torch::ones({ X.size(0) }) / X.size(0); + } + + // Normalize weights + sample_weights = sample_weights / sample_weights.sum(); + + // Build the tree + root = buildTree(X, y, sample_weights, 0); + + // Mark as fitted + fitted = true; + } + bool DecisionTree::validateTensors(const torch::Tensor& X, const torch::Tensor& y, + const torch::Tensor& sample_weights) const + { + if (X.size(0) != y.size(0) || X.size(0) != sample_weights.size(0)) { + return false; + } + if (X.size(0) == 0) { + return false; + } + return true; + } + + 
std::unique_ptr DecisionTree::buildTree( + const torch::Tensor& X, + const torch::Tensor& y, + const torch::Tensor& sample_weights, + int current_depth) + { + auto node = std::make_unique(); + int n_samples = y.size(0); + + // Check stopping criteria + auto unique = at::_unique(y); + bool should_stop = (current_depth >= max_depth) || + (n_samples < min_samples_split) || + (std::get<0>(unique).size(0) == 1); // All samples same class + + if (should_stop || n_samples <= min_samples_leaf) { + // Create leaf node + node->is_leaf = true; + + // Calculate class probabilities + node->class_probabilities = torch::zeros({ n_classes }); + + for (int i = 0; i < n_samples; i++) { + int class_idx = y[i].item(); + node->class_probabilities[class_idx] += sample_weights[i].item(); + } + + // Normalize probabilities + node->class_probabilities /= node->class_probabilities.sum(); + + // Set predicted class as the one with highest probability + node->predicted_class = torch::argmax(node->class_probabilities).item(); + + return node; + } + + // Find best split + SplitInfo best_split = findBestSplit(X, y, sample_weights); + + // If no valid split found, create leaf + if (best_split.feature_index == -1 || best_split.impurity_decrease <= 0) { + node->is_leaf = true; + + // Calculate class probabilities + node->class_probabilities = torch::zeros({ n_classes }); + + for (int i = 0; i < n_samples; i++) { + int class_idx = y[i].item(); + node->class_probabilities[class_idx] += sample_weights[i].item(); + } + + node->class_probabilities /= node->class_probabilities.sum(); + node->predicted_class = torch::argmax(node->class_probabilities).item(); + + return node; + } + + // Create internal node + node->is_leaf = false; + node->split_feature = best_split.feature_index; + node->split_value = best_split.split_value; + + // Split data + auto left_X = X.index({ best_split.left_mask }); + auto left_y = y.index({ best_split.left_mask }); + auto left_weights = sample_weights.index({ 
best_split.left_mask }); + + auto right_X = X.index({ best_split.right_mask }); + auto right_y = y.index({ best_split.right_mask }); + auto right_weights = sample_weights.index({ best_split.right_mask }); + + // Recursively build subtrees + if (left_X.size(0) >= min_samples_leaf) { + node->left = buildTree(left_X, left_y, left_weights, current_depth + 1); + } else { + // Force leaf if not enough samples + node->left = std::make_unique(); + node->left->is_leaf = true; + auto mode = std::get<0>(torch::mode(left_y)); + node->left->predicted_class = mode.item(); + node->left->class_probabilities = torch::zeros({ n_classes }); + node->left->class_probabilities[node->left->predicted_class] = 1.0; + } + + if (right_X.size(0) >= min_samples_leaf) { + node->right = buildTree(right_X, right_y, right_weights, current_depth + 1); + } else { + // Force leaf if not enough samples + node->right = std::make_unique(); + node->right->is_leaf = true; + auto mode = std::get<0>(torch::mode(right_y)); + node->right->predicted_class = mode.item(); + node->right->class_probabilities = torch::zeros({ n_classes }); + node->right->class_probabilities[node->right->predicted_class] = 1.0; + } + + return node; + } + + DecisionTree::SplitInfo DecisionTree::findBestSplit( + const torch::Tensor& X, + const torch::Tensor& y, + const torch::Tensor& sample_weights) + { + + SplitInfo best_split; + best_split.feature_index = -1; + best_split.split_value = -1; + best_split.impurity_decrease = -std::numeric_limits::infinity(); + + int n_features = X.size(1); + int n_samples = X.size(0); + + // Calculate impurity of current node + double current_impurity = calculateGiniImpurity(y, sample_weights); + double total_weight = sample_weights.sum().item(); + + // Try each feature + for (int feat_idx = 0; feat_idx < n_features; feat_idx++) { + auto feature_values = X.index({ torch::indexing::Slice(), feat_idx }); + auto unique_values = 
std::get<0>(torch::unique_consecutive(std::get<0>(torch::sort(feature_values)))); + + // Try each unique value as split point + for (int i = 0; i < unique_values.size(0); i++) { + int split_val = unique_values[i].item(); + + // Create masks for left and right splits + auto left_mask = feature_values == split_val; + auto right_mask = ~left_mask; + + int left_count = left_mask.sum().item(); + int right_count = right_mask.sum().item(); + + // Skip if split doesn't satisfy minimum samples requirement + if (left_count < min_samples_leaf || right_count < min_samples_leaf) { + continue; + } + + // Calculate weighted impurities + auto left_y = y.index({ left_mask }); + auto left_weights = sample_weights.index({ left_mask }); + double left_weight = left_weights.sum().item(); + double left_impurity = calculateGiniImpurity(left_y, left_weights); + + auto right_y = y.index({ right_mask }); + auto right_weights = sample_weights.index({ right_mask }); + double right_weight = right_weights.sum().item(); + double right_impurity = calculateGiniImpurity(right_y, right_weights); + + // Calculate impurity decrease + double impurity_decrease = current_impurity - + (left_weight / total_weight * left_impurity + + right_weight / total_weight * right_impurity); + + // Update best split if this is better + if (impurity_decrease > best_split.impurity_decrease) { + best_split.feature_index = feat_idx; + best_split.split_value = split_val; + best_split.impurity_decrease = impurity_decrease; + best_split.left_mask = left_mask; + best_split.right_mask = right_mask; + } + } + } + + return best_split; + } + + double DecisionTree::calculateGiniImpurity( + const torch::Tensor& y, + const torch::Tensor& sample_weights) + { + if (y.size(0) == 0 || sample_weights.size(0) == 0) { + return 0.0; + } + + if (y.size(0) != sample_weights.size(0)) { + throw std::runtime_error("y and sample_weights must have same size"); + } + + torch::Tensor class_weights = torch::zeros({ n_classes }); + + // Calculate 
weighted class counts + for (int i = 0; i < y.size(0); i++) { + int class_idx = y[i].item(); + + if (class_idx < 0 || class_idx >= n_classes) { + throw std::runtime_error("Invalid class index: " + std::to_string(class_idx)); + } + + class_weights[class_idx] += sample_weights[i].item(); + } + + // Normalize + double total_weight = class_weights.sum().item(); + if (total_weight == 0) return 0.0; + + class_weights /= total_weight; + + // Calculate Gini impurity: 1 - sum(p_i^2) + double gini = 1.0; + for (int i = 0; i < n_classes; i++) { + double p = class_weights[i].item(); + gini -= p * p; + } + + return gini; + } + + + torch::Tensor DecisionTree::predict(torch::Tensor& X) + { + if (!fitted) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + int n_samples = X.size(1); + torch::Tensor predictions = torch::zeros({ n_samples }, torch::kInt32); + + for (int i = 0; i < n_samples; i++) { + auto sample = X.index({ torch::indexing::Slice(), i }).ravel(); + predictions[i] = predictSample(sample); + } + + return predictions; + } + + std::vector DecisionTree::predict(std::vector>& X) + { + // Convert to tensor + long n = X.size(); + long m = X.at(0).size(); + torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); + auto predictions = predict(X_tensor); + std::vector result = platform::TensorUtils::to_vector(predictions); + return result; + } + + torch::Tensor DecisionTree::predict_proba(torch::Tensor& X) + { + if (!fitted) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + int n_samples = X.size(1); + torch::Tensor probabilities = torch::zeros({ n_samples, n_classes }); + + for (int i = 0; i < n_samples; i++) { + auto sample = X.index({ torch::indexing::Slice(), i }).ravel(); + probabilities[i] = predictProbaSample(sample); + } + + return probabilities; + } + + std::vector> DecisionTree::predict_proba(std::vector>& X) + { + auto n_samples = X.at(0).size(); + // Convert to tensor + torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X); + auto 
proba_tensor = predict_proba(X_tensor); + std::vector> result(n_samples, std::vector(n_classes, 0.0)); + + for (int i = 0; i < n_samples; i++) { + for (int j = 0; j < n_classes; j++) { + result[i][j] = proba_tensor[i][j].item(); + } + } + + return result; + } + + int DecisionTree::predictSample(const torch::Tensor& x) const + { + if (!fitted) { + throw std::runtime_error(CLASSIFIER_NOT_FITTED); + } + + if (x.size(0) != n) { // n debería ser el número de características + throw std::runtime_error("Input sample has wrong number of features"); + } + + const TreeNode* leaf = traverseTree(x, root.get()); + return leaf->predicted_class; + } + torch::Tensor DecisionTree::predictProbaSample(const torch::Tensor& x) const + { + const TreeNode* leaf = traverseTree(x, root.get()); + return leaf->class_probabilities.clone(); + } + + + const TreeNode* DecisionTree::traverseTree(const torch::Tensor& x, const TreeNode* node) const + { + if (!node) { + throw std::runtime_error("Null node encountered during tree traversal"); + } + + if (node->is_leaf) { + return node; + } + + if (node->split_feature < 0 || node->split_feature >= x.size(0)) { + throw std::runtime_error("Invalid split_feature index: " + std::to_string(node->split_feature)); + } + + int feature_value = x[node->split_feature].item(); + + if (feature_value == node->split_value) { + if (!node->left) { + throw std::runtime_error("Missing left child in tree"); + } + return traverseTree(x, node->left.get()); + } else { + if (!node->right) { + throw std::runtime_error("Missing right child in tree"); + } + return traverseTree(x, node->right.get()); + } + } + + std::vector DecisionTree::graph(const std::string& title) const + { + std::vector lines; + lines.push_back("digraph DecisionTree {"); + lines.push_back(" rankdir=TB;"); + lines.push_back(" node [shape=box, style=\"filled, rounded\", fontname=\"helvetica\"];"); + lines.push_back(" edge [fontname=\"helvetica\"];"); + + if (!title.empty()) { + lines.push_back(" label=\"" + 
title + "\";"); + lines.push_back(" labelloc=t;"); + } + + if (root) { + int node_id = 0; + treeToGraph(root.get(), lines, node_id); + } + + lines.push_back("}"); + return lines; + } + + void DecisionTree::treeToGraph( + const TreeNode* node, + std::vector& lines, + int& node_id, + int parent_id, + const std::string& edge_label) const + { + + int current_id = node_id++; + std::stringstream ss; + + if (node->is_leaf) { + // Leaf node + ss << " node" << current_id << " [label=\"Class: " << node->predicted_class; + ss << "\\nProb: " << std::fixed << std::setprecision(3) + << node->class_probabilities[node->predicted_class].item(); + ss << "\", fillcolor=\"lightblue\"];"; + lines.push_back(ss.str()); + } else { + // Internal node + ss << " node" << current_id << " [label=\"" << features[node->split_feature]; + ss << " = " << node->split_value << "?\", fillcolor=\"lightgreen\"];"; + lines.push_back(ss.str()); + } + + // Add edge from parent + if (parent_id >= 0) { + ss.str(""); + ss << " node" << parent_id << " -> node" << current_id; + if (!edge_label.empty()) { + ss << " [label=\"" << edge_label << "\"];"; + } else { + ss << ";"; + } + lines.push_back(ss.str()); + } + + // Recurse on children + if (!node->is_leaf) { + if (node->left) { + treeToGraph(node->left.get(), lines, node_id, current_id, "Yes"); + } + if (node->right) { + treeToGraph(node->right.get(), lines, node_id, current_id, "No"); + } + } + } + +} // namespace bayesnet \ No newline at end of file diff --git a/src/experimental_clfs/DecisionTree.h b/src/experimental_clfs/DecisionTree.h new file mode 100644 index 0000000..8a1c337 --- /dev/null +++ b/src/experimental_clfs/DecisionTree.h @@ -0,0 +1,134 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#ifndef DECISION_TREE_H +#define 
DECISION_TREE_H + +#include +#include +#include +#include +#include "bayesnet/classifiers/Classifier.h" + +namespace bayesnet { + + // Forward declaration + struct TreeNode; + + class DecisionTree : public Classifier { + public: + explicit DecisionTree(int max_depth = 3, int min_samples_split = 2, int min_samples_leaf = 1); + virtual ~DecisionTree() = default; + + // Override graph method to show tree structure + std::vector graph(const std::string& title = "") const override; + + // Setters for hyperparameters + void setMaxDepth(int depth) { max_depth = depth; checkValues(); } + void setMinSamplesSplit(int samples) { min_samples_split = samples; checkValues(); } + void setMinSamplesLeaf(int samples) { min_samples_leaf = samples; checkValues(); } + int getMaxDepth() const { return max_depth; } + int getMinSamplesSplit() const { return min_samples_split; } + int getMinSamplesLeaf() const { return min_samples_leaf; } + + // Override setHyperparameters + void setHyperparameters(const nlohmann::json& hyperparameters) override; + + torch::Tensor predict(torch::Tensor& X) override; + std::vector predict(std::vector>& X) override; + torch::Tensor predict_proba(torch::Tensor& X) override; + std::vector> predict_proba(std::vector>& X); + + // Make predictions for a single sample + int predictSample(const torch::Tensor& x) const; + + // Make probabilistic predictions for a single sample + torch::Tensor predictProbaSample(const torch::Tensor& x) const; + + protected: + void buildModel(const torch::Tensor& weights) override; + void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override + { + // Decision trees do not require training in the traditional sense + // as they are built from the data directly. + // This method can be used to set weights or other parameters if needed. 
+ } + private: + void checkValues(); + bool validateTensors(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& sample_weights) const; + // Tree hyperparameters + int max_depth; + int min_samples_split; + int min_samples_leaf; + int n_classes; // Number of classes in the target variable + + // Root of the decision tree + std::unique_ptr root; + + // Build tree recursively + std::unique_ptr buildTree( + const torch::Tensor& X, + const torch::Tensor& y, + const torch::Tensor& sample_weights, + int current_depth + ); + + // Find best split for a node + struct SplitInfo { + int feature_index; + int split_value; + double impurity_decrease; + torch::Tensor left_mask; + torch::Tensor right_mask; + }; + + SplitInfo findBestSplit( + const torch::Tensor& X, + const torch::Tensor& y, + const torch::Tensor& sample_weights + ); + + // Calculate weighted Gini impurity for multi-class + double calculateGiniImpurity( + const torch::Tensor& y, + const torch::Tensor& sample_weights + ); + + + + // Traverse tree to find leaf node + const TreeNode* traverseTree(const torch::Tensor& x, const TreeNode* node) const; + + // Convert tree to graph representation + void treeToGraph( + const TreeNode* node, + std::vector& lines, + int& node_id, + int parent_id = -1, + const std::string& edge_label = "" + ) const; + }; + + // Tree node structure + struct TreeNode { + bool is_leaf; + + // For internal nodes + int split_feature; + int split_value; + std::unique_ptr left; + std::unique_ptr right; + + // For leaf nodes + int predicted_class; + torch::Tensor class_probabilities; // Probability for each class + + TreeNode() : is_leaf(false), split_feature(-1), split_value(-1), predicted_class(-1) {} + }; + +} // namespace bayesnet + +#endif // DECISION_TREE_H \ No newline at end of file diff --git a/src/experimental_clfs/ExpClf.h b/src/experimental_clfs/ExpClf.h index fc6d3ec..dbb2140 100644 --- a/src/experimental_clfs/ExpClf.h +++ b/src/experimental_clfs/ExpClf.h @@ -43,6 +43,7 @@ 
namespace platform { void add_active_parents(const std::vector& active_parents); void add_active_parent(int parent); void remove_last_parent(); + void setHyperparameters(const nlohmann::json& hyperparameters_) override {}; protected: bool debug = false; Xaode aode_; diff --git a/src/experimental_clfs/README.md b/src/experimental_clfs/README.md new file mode 100644 index 0000000..129c608 --- /dev/null +++ b/src/experimental_clfs/README.md @@ -0,0 +1,142 @@ +# AdaBoost and DecisionTree Classifier Implementation + +This implementation provides both a Decision Tree classifier and a multi-class AdaBoost classifier based on the SAMME (Stagewise Additive Modeling using a Multi-class Exponential loss) algorithm described in the paper "Multi-class AdaBoost" by Zhu et al. Implemented in C++ using + +## Components + +### 1. DecisionTree Classifier + +A classic decision tree implementation that: + +- Supports multi-class classification +- Handles weighted samples (essential for boosting) +- Uses Gini impurity as the splitting criterion +- Works with discrete/categorical features +- Provides both class predictions and probability estimates + +#### Key Features + +- **Max Depth Control**: Limit tree depth to create weak learners +- **Minimum Samples**: Control minimum samples for splitting and leaf nodes +- **Weighted Training**: Properly handles sample weights for boosting +- **Visualization**: Generates DOT format graphs of the tree structure + +#### Hyperparameters + +- `max_depth`: Maximum depth of the tree (default: 3) +- `min_samples_split`: Minimum samples required to split a node (default: 2) +- `min_samples_leaf`: Minimum samples required in a leaf node (default: 1) + +### 2. 
AdaBoost Classifier + +A multi-class AdaBoost implementation using DecisionTree as base estimators: + +- **SAMME Algorithm**: Implements the multi-class extension of AdaBoost +- **Automatic Stumps**: Uses decision stumps (max_depth=1) by default +- **Early Stopping**: Stops if base classifier performs worse than random +- **Ensemble Visualization**: Shows the weighted combination of base estimators + +#### Key Features + +- **Multi-class Support**: Natural extension to K classes +- **Base Estimator Control**: Configure depth of base decision trees +- **Training Monitoring**: Track training errors and estimator weights +- **Probability Estimates**: Provides class probability predictions + +#### Hyperparameters + +- `n_estimators`: Number of base estimators to train (default: 50) +- `base_max_depth`: Maximum depth for base decision trees (default: 1) + +## Algorithm Details + +The SAMME algorithm differs from binary AdaBoost in the calculation of the estimator weight (alpha): + +``` +α = log((1 - err) / err) + log(K - 1) +``` + +where `K` is the number of classes. 
This formula ensures that: + +- When K = 2, it reduces to standard AdaBoost +- For K > 2, base classifiers only need to be better than random guessing (1/K) rather than 50% + +## Usage Example + +```cpp +// Create AdaBoost with decision stumps +AdaBoost ada(100, 1); // 100 estimators, max_depth=1 + +// Train +ada.fit(X_train, y_train, features, className, states, Smoothing_t::NONE); + +// Predict +auto predictions = ada.predict(X_test); +auto probabilities = ada.predict_proba(X_test); + +// Evaluate +float accuracy = ada.score(X_test, y_test); + +// Get ensemble information +auto weights = ada.getEstimatorWeights(); +auto errors = ada.getTrainingErrors(); +``` + +## Implementation Structure + +``` +AdaBoost (inherits from Ensemble) + └── Uses multiple DecisionTree instances as base estimators + └── DecisionTree (inherits from Classifier) + └── Implements weighted Gini impurity splitting +``` + +## Visualization + +Both classifiers support graph visualization: + +- **DecisionTree**: Shows the tree structure with split conditions +- **AdaBoost**: Shows the ensemble of weighted base estimators + +Generate visualizations using: + +```cpp +auto graph = classifier.graph("Title"); +``` + +## Data Format + +Both classifiers expect discrete/categorical data: + +- **Features**: Integer values representing categories (stored in `torch::Tensor` or `std::vector>`) +- **Labels**: Integer values representing class indices (0, 1, ..., K-1) +- **States**: Map defining possible values for each feature and the class variable +- **Sample Weights**: Optional weights for each training sample (important for boosting) + +Example data setup: + +```cpp +// Features matrix (n_features x n_samples) +torch::Tensor X = torch::tensor({{0, 1, 2}, {1, 0, 1}}); // 2 features, 3 samples + +// Labels vector +torch::Tensor y = torch::tensor({0, 1, 0}); // 3 samples + +// States definition +std::map> states; +states["feature1"] = {0, 1, 2}; // Feature 1 can take values 0, 1, or 2 +states["feature2"] = 
{0, 1}; // Feature 2 can take values 0 or 1 +states["class"] = {0, 1}; // Binary classification +``` + +## Notes + +- The implementation handles discrete/categorical features as indicated by the int-based data structures +- Sample weights are properly propagated through the tree building process +- The DecisionTree implementation uses equality testing for splits (suitable for categorical data) +- Both classifiers support the standard fit/predict interface from the base framework + +## References + +- Zhu, J., Zou, H., Rosset, S., & Hastie, T. (2009). Multi-class AdaBoost. Statistics and its interface, 2(3), 349-360. +- Breiman, L., Friedman, J., Olshen, R., & Stone, C. (1984). Classification and Regression Trees. Wadsworth, Belmont, CA. diff --git a/src/experimental_clfs/TensorUtils.hpp b/src/experimental_clfs/TensorUtils.hpp index 6f09859..2efdf7d 100644 --- a/src/experimental_clfs/TensorUtils.hpp +++ b/src/experimental_clfs/TensorUtils.hpp @@ -45,7 +45,53 @@ namespace platform { return data; } + static torch::Tensor to_matrix(const std::vector>& data) + { + if (data.empty()) return torch::empty({ 0, 0 }, torch::kInt64); + size_t rows = data.size(); + size_t cols = data[0].size(); + torch::Tensor tensor = torch::empty({ static_cast(rows), static_cast(cols) }, torch::kInt64); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + tensor.index_put_({ static_cast(i), static_cast(j) }, data[i][j]); + } + } + return tensor; + } }; + static void dumpVector(const std::vector>& vec, const std::string& name) + { + std::cout << name << ": " << std::endl; + for (const auto& row : vec) { + std::cout << "["; + for (const auto& val : row) { + std::cout << val << " "; + } + std::cout << "]" << std::endl; + } + std::cout << std::endl; + } + static void dumpTensor(const torch::Tensor& tensor, const std::string& name) + { + std::cout << name << ": " << std::endl; + for (auto i = 0; i < tensor.size(0); i++) { + std::cout << "["; + for (auto j = 0; j < 
tensor.size(1); j++) { + std::cout << tensor[i][j].item() << " "; + } + std::cout << "]" << std::endl; + } + std::cout << std::endl; + } + static void dumpTensorV(const torch::Tensor& tensor, const std::string& name) + { + std::cout << name << ": " << std::endl; + std::cout << "["; + for (int i = 0; i < tensor.size(0); i++) { + std::cout << tensor[i].item() << " "; + } + std::cout << "]" << std::endl; + } } #endif // TENSORUTILS_HPP \ No newline at end of file diff --git a/src/main/ArgumentsExperiment.cpp b/src/main/ArgumentsExperiment.cpp index aa8199e..d27f9a3 100644 --- a/src/main/ArgumentsExperiment.cpp +++ b/src/main/ArgumentsExperiment.cpp @@ -13,6 +13,7 @@ namespace platform { auto env = platform::DotEnv(); auto datasets = platform::Datasets(false, platform::Paths::datasets()); auto& group = arguments.add_mutually_exclusive_group(true); + group.add_argument("-d", "--dataset") .help("Dataset file name: " + datasets.toString()) .default_value("all") @@ -43,6 +44,7 @@ namespace platform { } ); arguments.add_argument("--title").default_value("").help("Experiment title"); + arguments.add_argument("--folder").help("Results folder to use").default_value(platform::Paths::results()); arguments.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); auto valid_choices = env.valid_tokens("discretize_algo"); auto& disc_arg = arguments.add_argument("--discretize-algo").help("Algorithm to use in discretization. 
Valid values: " + env.valid_values("discretize_algo")).default_value(env.get("discretize_algo")); @@ -103,6 +105,10 @@ namespace platform { file_name = arguments.get("dataset"); file_names = arguments.get>("datasets"); datasets_file = arguments.get("datasets-file"); + path_results = arguments.get("folder"); + if (path_results.back() != '/') { + path_results += '/'; + } model_name = arguments.get("model"); discretize_dataset = arguments.get("discretize"); discretize_algo = arguments.get("discretize-algo"); @@ -119,7 +125,7 @@ namespace platform { hyper_best = arguments.get("hyper-best"); if (hyper_best) { // Build the best results file_name - hyperparameters_file = platform::Paths::results() + platform::Paths::bestResultsFile(score, model_name); + hyperparameters_file = path_results + platform::Paths::bestResultsFile(score, model_name); // ignore this parameter hyperparameters = "{}"; } else { @@ -209,10 +215,36 @@ namespace platform { test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_json); } } + std::string getGppVersion() + { + std::string result; + std::array buffer; + + // Run g++ --version and capture the output + using pclose_t = int(*)(FILE*); + std::unique_ptr pipe(popen("g++ --version", "r"), pclose); + + if (!pipe) { + return "Error executing g++ --version command"; + } + + // Read the first line of output (which contains the version info) + if (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) { + result = buffer.data(); + // Remove trailing newline if present + if (!result.empty() && result[result.length() - 1] == '\n') { + result.erase(result.length() - 1); + } + } else { + return "No output from g++ --version command"; + } + + return result; + } Experiment& ArgumentsExperiment::initializedExperiment() { auto env = platform::DotEnv(); - experiment.setTitle(title).setLanguage("c++").setLanguageVersion("gcc 14.1.1"); + experiment.setTitle(title).setLanguage("c++").setLanguageVersion(getGppVersion()); 
experiment.setDiscretizationAlgorithm(discretize_algo).setSmoothSrategy(smooth_strat); experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform")); experiment.setStratified(stratified).setNFolds(n_folds).setScoreName(score); diff --git a/src/main/ArgumentsExperiment.h b/src/main/ArgumentsExperiment.h index c4528b9..5f2fbc2 100644 --- a/src/main/ArgumentsExperiment.h +++ b/src/main/ArgumentsExperiment.h @@ -22,11 +22,13 @@ namespace platform { bool isQuiet() const { return quiet; } bool haveToSaveResults() const { return saveResults; } bool doGraph() const { return graph; } + std::string getPathResults() const { return path_results; } private: Experiment experiment; experiment_t type; argparse::ArgumentParser& arguments; - std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat, score; + std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat; + std::string score, path_results; json hyperparameters_json; bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files, graph, hyper_best; std::vector seeds; diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 06cbc41..e240e7c 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -7,12 +7,12 @@ namespace platform { using json = nlohmann::ordered_json; - void Experiment::saveResult() + void Experiment::saveResult(const std::string& path) { result.setSchemaVersion("1.0"); result.check(); - result.save(); - std::cout << "Result saved in " << Paths::results() << result.getFilename() << std::endl; + result.save(path); + std::cout << "Result saved in " << path << result.getFilename() << std::endl; } void Experiment::report() { @@ -245,8 +245,6 @@ namespace platform { // Train model // clf->fit(X_train, y_train, features, className, states, smooth_type); - if (!quiet) - showProgress(nfold + 1, getColor(clf->getStatus()), 
"b"); auto clf_notes = clf->getNotes(); std::transform(clf_notes.begin(), clf_notes.end(), std::back_inserter(notes), [nfold](const std::string& note) { return "Fold " + std::to_string(nfold) + ": " + note; }); @@ -259,10 +257,13 @@ namespace platform { // Score train // if (!no_train_score) { + if (!quiet) + showProgress(nfold + 1, getColor(clf->getStatus()), "b"); auto y_proba_train = clf->predict_proba(X_train); Scores scores(y_train, y_proba_train, num_classes, labels); score_train_value = score == score_t::ACCURACY ? scores.accuracy() : scores.auc(); - confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true)); + if (discretized) + confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true)); } // // Test model @@ -277,7 +278,8 @@ namespace platform { test_time[item] = test_timer.getDuration(); score_train[item] = score_train_value; score_test[item] = score_test_value; - confusion_matrices.push_back(scores.get_confusion_matrix_json(true)); + if (discretized) + confusion_matrices.push_back(scores.get_confusion_matrix_json(true)); if (!quiet) std::cout << "\b\b\b, " << flush; // diff --git a/src/main/Experiment.h b/src/main/Experiment.h index bfad61c..52c08fe 100644 --- a/src/main/Experiment.h +++ b/src/main/Experiment.h @@ -45,7 +45,7 @@ namespace platform { std::vector getRandomSeeds() const { return randomSeeds; } void cross_validation(const std::string& fileName); void go(); - void saveResult(); + void saveResult(const std::string& path); void show(); void saveGraph(); void report(); diff --git a/src/main/Models.h b/src/main/Models.h index 565a96d..3640a01 100644 --- a/src/main/Models.h +++ b/src/main/Models.h @@ -23,8 +23,11 @@ #include #include #include +#include #include #include "../experimental_clfs/XA1DE.h" +#include "../experimental_clfs/AdaBoost.h" +#include "../experimental_clfs/DecisionTree.h" namespace platform { class Models { diff --git a/src/main/RocAuc.cpp b/src/main/RocAuc.cpp index 03a1c31..c41d986 100644 --- 
a/src/main/RocAuc.cpp +++ b/src/main/RocAuc.cpp @@ -4,7 +4,7 @@ #include #include "RocAuc.h" namespace platform { - + double RocAuc::compute(const torch::Tensor& y_proba, const torch::Tensor& labels) { size_t nClasses = y_proba.size(1); @@ -48,6 +48,7 @@ namespace platform { double tp = 0, fp = 0; double totalPos = std::count(y_test.begin(), y_test.end(), classIdx); double totalNeg = nSamples - totalPos; + if (totalPos == 0 || totalNeg == 0) return 0.5; // neutral AUC for (const auto& [score, label] : scoresAndLabels) { if (label == 1) { diff --git a/src/main/modelRegister.h b/src/main/modelRegister.h index 0dbd269..5f44728 100644 --- a/src/main/modelRegister.h +++ b/src/main/modelRegister.h @@ -35,6 +35,12 @@ namespace platform { [](void) -> bayesnet::BaseClassifier* { return new pywrap::RandomForest();}); static Registrar registrarXGB("XGBoost", [](void) -> bayesnet::BaseClassifier* { return new pywrap::XGBoost();}); + static Registrar registrarAdaPy("AdaBoostPy", + [](void) -> bayesnet::BaseClassifier* { return new pywrap::AdaBoostPy();}); + static Registrar registrarAda("AdaBoost", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AdaBoost();}); + static Registrar registrarDT("DecisionTree", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::DecisionTree();}); static Registrar registrarXSPODE("XSPODE", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::XSpode(0);}); static Registrar registrarXSP2DE("XSP2DE", @@ -44,6 +50,6 @@ namespace platform { static Registrar registrarXBA2DE("XBA2DE", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::XBA2DE();}); static Registrar registrarXA1DE("XA1DE", - [](void) -> bayesnet::BaseClassifier* { return new XA1DE();}); + [](void) -> bayesnet::BaseClassifier* { return new XA1DE();}); } #endif diff --git a/src/manage/ManageScreen.cpp b/src/manage/ManageScreen.cpp index 6648d94..d9f65fc 100644 --- a/src/manage/ManageScreen.cpp +++ b/src/manage/ManageScreen.cpp @@ -18,8 +18,8 @@ 
namespace platform { const std::string STATUS_OK = "Ok."; const std::string STATUS_COLOR = Colors::GREEN(); - ManageScreen::ManageScreen(int rows, int cols, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial, bool compare) : - rows{ rows }, cols{ cols }, complete{ complete }, partial{ partial }, compare{ compare }, didExcel(false), results(ResultsManager(model, score, platform, complete, partial)) + ManageScreen::ManageScreen(const std::string path_, int rows, int cols, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial, bool compare) : + path{ path_ }, rows{ rows }, cols{ cols }, complete{ complete }, partial{ partial }, compare{ compare }, didExcel(false), results(ResultsManager(path_, model, score, platform, complete, partial)) { results.load(); openExcel = false; @@ -329,11 +329,11 @@ namespace platform { return; } // Remove the old result file - std::string oldFile = Paths::results() + results.at(index).getFilename(); + std::string oldFile = path + results.at(index).getFilename(); std::filesystem::remove(oldFile); // Actually change the model results.at(index).setModel(newModel); - results.at(index).save(); + results.at(index).save(path); int newModelSize = static_cast(newModel.size()); if (newModelSize > maxModel) { maxModel = newModelSize; @@ -583,7 +583,7 @@ namespace platform { getline(std::cin, newTitle); if (!newTitle.empty()) { results.at(index).setTitle(newTitle); - results.at(index).save(); + results.at(index).save(path); list("Title changed to " + newTitle, Colors::GREEN()); break; } diff --git a/src/manage/ManageScreen.h b/src/manage/ManageScreen.h index 7e41896..46a02c4 100644 --- a/src/manage/ManageScreen.h +++ b/src/manage/ManageScreen.h @@ -15,7 +15,7 @@ namespace platform { }; class ManageScreen { public: - ManageScreen(int rows, int cols, const std::string& model, const std::string& score, const std::string& platform, bool 
complete, bool partial, bool compare); + ManageScreen(const std::string path, int rows, int cols, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial, bool compare); ~ManageScreen() = default; void doMenu(); void updateSize(int rows, int cols); @@ -59,7 +59,7 @@ namespace platform { std::vector paginator; ResultsManager results; lxw_workbook* workbook; - std::string excelFileName; + std::string path, excelFileName; }; } #endif \ No newline at end of file diff --git a/src/manage/ResultsManager.cpp b/src/manage/ResultsManager.cpp index f4b722a..fa4e895 100644 --- a/src/manage/ResultsManager.cpp +++ b/src/manage/ResultsManager.cpp @@ -1,10 +1,9 @@ #include -#include "common/Paths.h" #include "ResultsManager.h" namespace platform { - ResultsManager::ResultsManager(const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial) : - path(Paths::results()), model(model), scoreName(score), platform(platform), complete(complete), partial(partial), maxModel(0), maxTitle(0) + ResultsManager::ResultsManager(const std::string& path_, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial) : + path(path_), model(model), scoreName(score), platform(platform), complete(complete), partial(partial), maxModel(0), maxTitle(0) { } void ResultsManager::load() diff --git a/src/manage/ResultsManager.h b/src/manage/ResultsManager.h index cabf909..f891c44 100644 --- a/src/manage/ResultsManager.h +++ b/src/manage/ResultsManager.h @@ -18,7 +18,7 @@ namespace platform { }; class ResultsManager { public: - ResultsManager(const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial); + ResultsManager(const std::string& path_, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial); void load(); // Loads the list of results void 
sortResults(SortField field, SortType type); // Sorts the list of results void sortDate(SortType type); diff --git a/src/reports/DatasetsConsole.cpp b/src/reports/DatasetsConsole.cpp index d402a25..82dab4f 100644 --- a/src/reports/DatasetsConsole.cpp +++ b/src/reports/DatasetsConsole.cpp @@ -26,6 +26,7 @@ namespace platform { auto datasets = platform::Datasets(false, platform::Paths::datasets()); std::stringstream sheader; auto datasets_names = datasets.getNames(); + std::cout << Colors::GREEN() << "Datasets available in the platform: " << datasets_names.size() << std::endl; int maxName = std::max(size_t(7), (*max_element(datasets_names.begin(), datasets_names.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size()); std::vector header_labels = { " #", "Dataset", "Sampl.", "Feat.", "#Num.", "Cls", "Balance" }; std::vector header_lengths = { 3, maxName, 6, 6, 6, 3, DatasetsConsole::BALANCE_LENGTH }; @@ -61,9 +62,13 @@ namespace platform { line << setw(header_lengths[5]) << right << nClasses << " "; std::string sep = ""; oss.str(""); - for (auto number : dataset.getClassesCounts()) { - oss << sep << std::setprecision(2) << fixed << (float)number / nSamples * 100.0 << "% (" << number << ")"; - sep = " / "; + if (nSamples == 0) { + oss << "No samples"; + } else { + for (auto number : dataset.getClassesCounts()) { + oss << sep << std::setprecision(2) << fixed << (float)number / nSamples * 100.0 << "% (" << number << ")"; + sep = " / "; + } } split_lines(maxName, line.str(), oss.str()); // Store data for Excel report diff --git a/src/reports/DatasetsExcel.cpp b/src/reports/DatasetsExcel.cpp index a24def6..cb1dd37 100644 --- a/src/reports/DatasetsExcel.cpp +++ b/src/reports/DatasetsExcel.cpp @@ -1,8 +1,9 @@ +#include "common/Paths.h" #include "DatasetsExcel.h" namespace platform { DatasetsExcel::DatasetsExcel() { - file_name = "datasets.xlsx"; + file_name = Paths::excelDatasets(); workbook = workbook_new(getFileName().c_str()); 
createFormats(); setProperties("Datasets"); diff --git a/src/reports/DatasetsExcel.h b/src/reports/DatasetsExcel.h index cd543cd..4b528b3 100644 --- a/src/reports/DatasetsExcel.h +++ b/src/reports/DatasetsExcel.h @@ -11,6 +11,7 @@ namespace platform { DatasetsExcel(); ~DatasetsExcel(); void report(json& data); + std::string getExcelFileName() { return getFileName(); } }; } #endif \ No newline at end of file diff --git a/src/results/Result.cpp b/src/results/Result.cpp index c143874..f37ca61 100644 --- a/src/results/Result.cpp +++ b/src/results/Result.cpp @@ -69,9 +69,9 @@ namespace platform { platform::JsonValidator validator(platform::SchemaV1_0::schema); return validator.validate(data); } - void Result::save() + void Result::save(const std::string& path) { - std::ofstream file(Paths::results() + getFilename()); + std::ofstream file(path + getFilename()); file << data; file.close(); } diff --git a/src/results/Result.h b/src/results/Result.h index 3f45c70..a16748d 100644 --- a/src/results/Result.h +++ b/src/results/Result.h @@ -15,7 +15,7 @@ namespace platform { public: Result(); Result& load(const std::string& path, const std::string& filename); - void save(); + void save(const std::string& path); std::vector check(); // Getters json getJson(); diff --git a/src/results/ResultsDatasetExcel.cpp b/src/results/ResultsDatasetExcel.cpp index b4f1225..df1604e 100644 --- a/src/results/ResultsDatasetExcel.cpp +++ b/src/results/ResultsDatasetExcel.cpp @@ -1,8 +1,9 @@ +#include "common/Paths.h" #include "ResultsDatasetExcel.h" namespace platform { ResultsDatasetExcel::ResultsDatasetExcel() { - file_name = "some_results.xlsx"; + file_name = Paths::excelResults(); workbook = workbook_new(getFileName().c_str()); createFormats(); setProperties("Results"); diff --git a/src/results/ResultsDatasetExcel.h b/src/results/ResultsDatasetExcel.h index 83226cc..3f9b968 100644 --- a/src/results/ResultsDatasetExcel.h +++ b/src/results/ResultsDatasetExcel.h @@ -12,6 +12,7 @@ namespace 
platform { ResultsDatasetExcel(); ~ResultsDatasetExcel(); void report(json& data); + std::string getExcelFileName() { return getFileName(); } }; } #endif \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index fce68a8..18317bb 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -12,11 +12,13 @@ if(ENABLE_TESTING) ${Bayesnet_INCLUDE_DIRS} ) set(TEST_SOURCES_PLATFORM - TestUtils.cpp TestPlatform.cpp TestResult.cpp TestScores.cpp + TestUtils.cpp TestPlatform.cpp TestResult.cpp TestScores.cpp TestDecisionTree.cpp TestAdaBoost.cpp ${Platform_SOURCE_DIR}/src/common/Datasets.cpp ${Platform_SOURCE_DIR}/src/common/Dataset.cpp ${Platform_SOURCE_DIR}/src/common/Discretization.cpp - ${Platform_SOURCE_DIR}/src/main/Scores.cpp + ${Platform_SOURCE_DIR}/src/main/Scores.cpp + ${Platform_SOURCE_DIR}/src/experimental_clfs/DecisionTree.cpp + ${Platform_SOURCE_DIR}/src/experimental_clfs/AdaBoost.cpp ) add_executable(${TEST_PLATFORM} ${TEST_SOURCES_PLATFORM}) - target_link_libraries(${TEST_PLATFORM} PUBLIC "${TORCH_LIBRARIES}" mdlp Catch2::Catch2WithMain BayesNet) + target_link_libraries(${TEST_PLATFORM} PUBLIC "${TORCH_LIBRARIES}" fimdlp Catch2::Catch2WithMain bayesnet) add_test(NAME ${TEST_PLATFORM} COMMAND ${TEST_PLATFORM}) endif(ENABLE_TESTING) diff --git a/tests/TestAdaBoost.cpp b/tests/TestAdaBoost.cpp new file mode 100644 index 0000000..81d5673 --- /dev/null +++ b/tests/TestAdaBoost.cpp @@ -0,0 +1,547 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#include +#include +#include +#include +#include +#include +#include +#include "experimental_clfs/AdaBoost.h" +#include "experimental_clfs/DecisionTree.h" +#include "experimental_clfs/TensorUtils.hpp" +#include "TestUtils.h" + +using namespace bayesnet; +using namespace 
Catch::Matchers; + +static const bool DEBUG = false; + +TEST_CASE("AdaBoost Construction", "[AdaBoost]") +{ + SECTION("Default constructor") + { + REQUIRE_NOTHROW(AdaBoost()); + } + + SECTION("Constructor with parameters") + { + REQUIRE_NOTHROW(AdaBoost(100, 2)); + } + + SECTION("Constructor parameter access") + { + AdaBoost ada(75, 3); + REQUIRE(ada.getNEstimators() == 75); + REQUIRE(ada.getBaseMaxDepth() == 3); + } +} + +TEST_CASE("AdaBoost Hyperparameter Setting", "[AdaBoost]") +{ + AdaBoost ada; + + SECTION("Set individual hyperparameters") + { + REQUIRE_NOTHROW(ada.setNEstimators(100)); + REQUIRE_NOTHROW(ada.setBaseMaxDepth(5)); + + REQUIRE(ada.getNEstimators() == 100); + REQUIRE(ada.getBaseMaxDepth() == 5); + } + + SECTION("Set hyperparameters via JSON") + { + nlohmann::json params; + params["n_estimators"] = 80; + params["base_max_depth"] = 4; + + REQUIRE_NOTHROW(ada.setHyperparameters(params)); + } + + SECTION("Invalid hyperparameters should throw") + { + nlohmann::json params; + + // Negative n_estimators + params["n_estimators"] = -1; + REQUIRE_THROWS_AS(ada.setHyperparameters(params), std::invalid_argument); + + // Zero n_estimators + params["n_estimators"] = 0; + REQUIRE_THROWS_AS(ada.setHyperparameters(params), std::invalid_argument); + + // Negative base_max_depth + params["n_estimators"] = 50; + params["base_max_depth"] = -1; + REQUIRE_THROWS_AS(ada.setHyperparameters(params), std::invalid_argument); + + // Zero base_max_depth + params["base_max_depth"] = 0; + REQUIRE_THROWS_AS(ada.setHyperparameters(params), std::invalid_argument); + } +} + +TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]") +{ + // Create a simple dataset + int n_samples = 20; + int n_features = 2; + + std::vector> X(n_features, std::vector(n_samples)); + std::vector y(n_samples); + + // Simple pattern: class depends on first feature + for (int i = 0; i < n_samples; i++) { + X[0][i] = i < 10 ? 
0 : 1; + X[1][i] = i % 2; + y[i] = X[0][i]; // Class equals first feature + } + + std::vector features = { "f1", "f2" }; + std::string className = "class"; + std::map> states; + states["f1"] = { 0, 1 }; + states["f2"] = { 0, 1 }; + states["class"] = { 0, 1 }; + + SECTION("Training with vector interface") + { + AdaBoost ada(10, 3); // 10 estimators, max_depth = 3 + REQUIRE_NOTHROW(ada.fit(X, y, features, className, states, Smoothing_t::NONE)); + + // Check that we have the expected number of models + auto weights = ada.getEstimatorWeights(); + REQUIRE(weights.size() <= 10); // Should be <= n_estimators + REQUIRE(weights.size() > 0); // Should have at least one model + + // Check training errors + auto errors = ada.getTrainingErrors(); + REQUIRE(errors.size() == weights.size()); + + // All training errors should be less than 0.5 for this simple dataset + for (double error : errors) { + REQUIRE(error < 0.5); + REQUIRE(error >= 0.0); + } + } + + SECTION("Prediction before fitting") + { + AdaBoost ada; + REQUIRE_THROWS_WITH(ada.predict(X), + ContainsSubstring("not been fitted")); + REQUIRE_THROWS_WITH(ada.predict_proba(X), + ContainsSubstring("not been fitted")); + } + + SECTION("Prediction with vector interface") + { + AdaBoost ada(10, 3); + ada.setDebug(DEBUG); // Enable debug to investigate + ada.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto predictions = ada.predict(X); + REQUIRE(predictions.size() == static_cast(n_samples)); + // Check accuracy + int correct = 0; + for (size_t i = 0; i < predictions.size(); i++) { + if (predictions[i] == y[i]) correct++; + } + double accuracy = static_cast(correct) / n_samples; + REQUIRE(accuracy > 0.99); // Should achieve good accuracy on this simple dataset + auto accuracy_computed = ada.score(X, y); + REQUIRE(accuracy_computed == Catch::Approx(accuracy).epsilon(1e-6)); + } + + SECTION("Probability predictions with vector interface") + { + AdaBoost ada(10, 3); + ada.setDebug(DEBUG); // ENABLE DEBUG HERE TOO 
+ ada.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto proba = ada.predict_proba(X); + REQUIRE(proba.size() == static_cast(n_samples)); + REQUIRE(proba[0].size() == 2); // Two classes + + // Check probabilities sum to 1 and are valid + auto predictions = ada.predict(X); + int correct = 0; + for (size_t i = 0; i < proba.size(); i++) { + auto p = proba[i]; + auto pred = predictions[i]; + REQUIRE(p.size() == 2); + REQUIRE(p[0] >= 0.0); + REQUIRE(p[1] >= 0.0); + double sum = p[0] + p[1]; + REQUIRE(sum == Catch::Approx(1.0).epsilon(1e-6)); + // compute the predicted class based on probabilities + auto predicted_class = (p[0] > p[1]) ? 0 : 1; + // compute accuracy based on predictions + if (predicted_class == y[i]) { + correct++; + } + + INFO("Probability test - Sample " << i << ": pred=" << pred << ", probs=[" << p[0] << "," << p[1] << "], expected_from_probs=" << predicted_class); + + // Handle ties + if (std::abs(p[0] - p[1]) < 1e-10) { + INFO("Tie detected in probabilities"); + // Either prediction is valid in case of tie + } else { + // Check that predict_proba matches the expected predict value + REQUIRE(pred == predicted_class); + } + } + double accuracy = static_cast(correct) / n_samples; + REQUIRE(accuracy > 0.99); // Should achieve good accuracy on this simple dataset + } +} + +TEST_CASE("AdaBoost Tensor Interface", "[AdaBoost]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Training with tensor format") + { + AdaBoost ada(20, 3); + + INFO("Dataset shape: " << raw.dataset.sizes()); + INFO("Features: " << raw.featurest.size()); + INFO("Samples: " << raw.nSamples); + + // AdaBoost expects dataset in format: features x samples, with labels as last row + REQUIRE_NOTHROW(ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE)); + + // Test prediction with tensor + auto predictions = ada.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + + // Calculate accuracy + auto correct = 
torch::sum(predictions == raw.yt).item(); + double accuracy = static_cast(correct) / raw.yt.size(0); + auto accuracy_computed = ada.score(raw.Xt, raw.yt); + REQUIRE(accuracy_computed == Catch::Approx(accuracy).epsilon(1e-6)); + REQUIRE(accuracy > 0.97); // Should achieve good accuracy on Iris + + // Test probability predictions with tensor + auto proba = ada.predict_proba(raw.Xt); + REQUIRE(proba.size(0) == raw.yt.size(0)); + REQUIRE(proba.size(1) == 3); // Three classes in Iris + + // Check probabilities sum to 1 + auto prob_sums = torch::sum(proba, 1); + for (int i = 0; i < prob_sums.size(0); i++) { + REQUIRE(prob_sums[i].item() == Catch::Approx(1.0).epsilon(1e-6)); + } + } +} + +TEST_CASE("AdaBoost SAMME Algorithm Validation", "[AdaBoost]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Prediction consistency with probabilities") + { + AdaBoost ada(15, 3); + ada.setDebug(DEBUG); // Enable debug for ALL instances + ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + auto predictions = ada.predict(raw.Xt); + auto probabilities = ada.predict_proba(raw.Xt); + + REQUIRE(predictions.size(0) == probabilities.size(0)); + REQUIRE(probabilities.size(1) == 3); // Three classes in Iris + + // For each sample, predicted class should correspond to highest probability + for (int i = 0; i < predictions.size(0); i++) { + int predicted_class = predictions[i].item(); + auto probs = probabilities[i]; + + // Find class with highest probability + auto max_prob_idx = torch::argmax(probs).item(); + + // Predicted class should match class with highest probability + REQUIRE(predicted_class == max_prob_idx); + + // Probabilities should sum to 1 + double sum_probs = torch::sum(probs).item(); + REQUIRE(sum_probs == Catch::Approx(1.0).epsilon(1e-6)); + + // All probabilities should be non-negative + for (int j = 0; j < 3; j++) { + REQUIRE(probs[j].item() >= 0.0); + REQUIRE(probs[j].item() <= 1.0); + } + } + } + + SECTION("Weighted voting 
verification") + { + // Simple dataset where we can verify the weighted voting + std::vector> X = { {0,0,1,1}, {0,1,0,1} }; + std::vector y = { 0, 1, 1, 0 }; + std::vector features = { "f1", "f2" }; + std::string className = "class"; + std::map> states; + states["f1"] = { 0, 1 }; + states["f2"] = { 0, 1 }; + states["class"] = { 0, 1 }; + + AdaBoost ada(5, 2); + ada.setDebug(DEBUG); // Enable debug for detailed logging + ada.fit(X, y, features, className, states, Smoothing_t::NONE); + + INFO("=== Final test verification ==="); + auto predictions = ada.predict(X); + auto probabilities = ada.predict_proba(X); + auto alphas = ada.getEstimatorWeights(); + + INFO("Training info:"); + for (size_t i = 0; i < alphas.size(); i++) { + INFO(" Model " << i << ": alpha=" << alphas[i]); + } + + REQUIRE(predictions.size() == 4); + REQUIRE(probabilities.size() == 4); + REQUIRE(probabilities[0].size() == 2); // Two classes + REQUIRE(alphas.size() > 0); + + // Verify that estimator weights are reasonable + for (double alpha : alphas) { + REQUIRE(alpha >= 0.0); // Alphas should be non-negative + } + + // Verify prediction-probability consistency with detailed logging + for (size_t i = 0; i < predictions.size(); i++) { + int pred = predictions[i]; + auto probs = probabilities[i]; + + INFO("Final check - Sample " << i << ": predicted=" << pred << ", probabilities=[" << probs[0] << "," << probs[1] << "]"); + + // Handle the case where probabilities are exactly equal (tie) + if (std::abs(probs[0] - probs[1]) < 1e-10) { + INFO("Tie detected in probabilities - either prediction is valid"); + REQUIRE((pred == 0 || pred == 1)); + } else { + // Normal case - prediction should match max probability + int expected_pred = (probs[0] > probs[1]) ? 
0 : 1; + INFO("Expected prediction based on probs: " << expected_pred); + REQUIRE(pred == expected_pred); + } + + REQUIRE(probs[0] + probs[1] == Catch::Approx(1.0).epsilon(1e-6)); + } + } + + SECTION("Empty models edge case") + { + AdaBoost ada(1, 1); + ada.setDebug(DEBUG); // Enable debug for ALL instances + + // Try to predict before fitting + std::vector> X = { {0}, {1} }; + REQUIRE_THROWS_WITH(ada.predict(X), ContainsSubstring("not been fitted")); + REQUIRE_THROWS_WITH(ada.predict_proba(X), ContainsSubstring("not been fitted")); + } +} + +TEST_CASE("AdaBoost Debug - Simple Dataset Analysis", "[AdaBoost][debug]") +{ + // Create the exact same simple dataset that was failing + int n_samples = 20; + int n_features = 2; + + std::vector> X(n_features, std::vector(n_samples)); + std::vector y(n_samples); + + // Simple pattern: class depends on first feature + for (int i = 0; i < n_samples; i++) { + X[0][i] = i < 10 ? 0 : 1; + X[1][i] = i % 2; + y[i] = X[0][i]; // Class equals first feature + } + + std::vector features = { "f1", "f2" }; + std::string className = "class"; + std::map> states; + states["f1"] = { 0, 1 }; + states["f2"] = { 0, 1 }; + states["class"] = { 0, 1 }; + + SECTION("Debug training process") + { + AdaBoost ada(5, 3); // Few estimators for debugging + ada.setDebug(DEBUG); + + // This should work perfectly on this simple dataset + REQUIRE_NOTHROW(ada.fit(X, y, features, className, states, Smoothing_t::NONE)); + + // Get training details + auto weights = ada.getEstimatorWeights(); + auto errors = ada.getTrainingErrors(); + + INFO("Number of models trained: " << weights.size()); + INFO("Training errors: "); + for (size_t i = 0; i < errors.size(); i++) { + INFO(" Model " << i << ": error=" << errors[i] << ", weight=" << weights[i]); + } + + // Should have at least one model + REQUIRE(weights.size() > 0); + REQUIRE(errors.size() == weights.size()); + + // All training errors should be reasonable for this simple dataset + for (double error : errors) { + 
REQUIRE(error >= 0.0); + REQUIRE(error < 0.5); // Should be better than random + } + + // Test predictions + auto predictions = ada.predict(X); + REQUIRE(predictions.size() == static_cast(n_samples)); + + // Calculate accuracy + int correct = 0; + for (size_t i = 0; i < predictions.size(); i++) { + if (predictions[i] == y[i]) correct++; + INFO("Sample " << i << ": predicted=" << predictions[i] << ", actual=" << y[i]); + } + double accuracy = static_cast(correct) / n_samples; + INFO("Accuracy: " << accuracy); + + // Should achieve high accuracy on this perfectly separable dataset + REQUIRE(accuracy >= 0.9); // Lower threshold for debugging + + // Test probability predictions + auto proba = ada.predict_proba(X); + REQUIRE(proba.size() == static_cast(n_samples)); + + // Verify probabilities are valid + for (size_t i = 0; i < proba.size(); i++) { + auto p = proba[i]; + REQUIRE(p.size() == 2); + REQUIRE(p[0] >= 0.0); + REQUIRE(p[1] >= 0.0); + double sum = p[0] + p[1]; + REQUIRE(sum == Catch::Approx(1.0).epsilon(1e-6)); + + // Predicted class should match highest probability + int pred_class = predictions[i]; + + // Handle ties + if (std::abs(p[0] - p[1]) < 1e-10) { + INFO("Tie detected - probabilities are equal"); + REQUIRE((pred_class == 0 || pred_class == 1)); + } else { + REQUIRE(pred_class == (p[0] > p[1] ? 
0 : 1)); + } + } + } + + SECTION("Compare with single DecisionTree") + { + // Test that AdaBoost performs at least as well as a single tree + DecisionTree single_tree(3, 2, 1); + single_tree.fit(X, y, features, className, states, Smoothing_t::NONE); + auto tree_predictions = single_tree.predict(X); + + int tree_correct = 0; + for (size_t i = 0; i < tree_predictions.size(); i++) { + if (tree_predictions[i] == y[i]) tree_correct++; + } + double tree_accuracy = static_cast(tree_correct) / n_samples; + + AdaBoost ada(5, 3); + ada.setDebug(DEBUG); + ada.fit(X, y, features, className, states, Smoothing_t::NONE); + auto ada_predictions = ada.predict(X); + + int ada_correct = 0; + for (size_t i = 0; i < ada_predictions.size(); i++) { + if (ada_predictions[i] == y[i]) ada_correct++; + } + double ada_accuracy = static_cast(ada_correct) / n_samples; + + INFO("DecisionTree accuracy: " << tree_accuracy); + INFO("AdaBoost accuracy: " << ada_accuracy); + + // AdaBoost should perform at least as well as single tree + // (allowing small tolerance for numerical differences) + REQUIRE(ada_accuracy >= tree_accuracy - 0.1); + } +} + +TEST_CASE("AdaBoost Predict-Proba Consistency Fix", "[AdaBoost][consistency]") +{ + // Simple binary classification dataset + std::vector> X = { {0,0,1,1}, {0,1,0,1} }; + std::vector y = { 0, 0, 1, 1 }; + std::vector features = { "f1", "f2" }; + std::string className = "class"; + std::map> states; + states["f1"] = { 0, 1 }; + states["f2"] = { 0, 1 }; + states["class"] = { 0, 1 }; + + SECTION("Binary classification consistency") + { + AdaBoost ada(3, 2); + ada.setDebug(DEBUG); // Enable debug output + ada.fit(X, y, features, className, states, Smoothing_t::NONE); + + INFO("=== Debugging predict vs predict_proba consistency ==="); + + // Get training info + auto alphas = ada.getEstimatorWeights(); + auto errors = ada.getTrainingErrors(); + + INFO("Training completed:"); + INFO(" Number of models: " << alphas.size()); + for (size_t i = 0; i < alphas.size(); 
i++) { + INFO(" Model " << i << ": alpha=" << alphas[i] << ", error=" << errors[i]); + } + + auto predictions = ada.predict(X); + auto probabilities = ada.predict_proba(X); + + // Verify consistency for each sample + for (size_t i = 0; i < predictions.size(); i++) { + int predicted_class = predictions[i]; + auto probs = probabilities[i]; + + INFO("Sample " << i << ":"); + INFO(" Features: [" << X[0][i] << ", " << X[1][i] << "]"); + INFO(" True class: " << y[i]); + INFO(" Predicted class: " << predicted_class); + INFO(" Probabilities: [" << probs[0] << ", " << probs[1] << "]"); + + // The predicted class should be the one with highest probability + int max_prob_class = (probs[0] > probs[1]) ? 0 : 1; + INFO(" Max prob class: " << max_prob_class); + + // Handle tie case (when probabilities are equal) + if (std::abs(probs[0] - probs[1]) < 1e-10) { + INFO(" Tie detected - probabilities are equal"); + // In case of tie, either prediction is valid + REQUIRE((predicted_class == 0 || predicted_class == 1)); + } else { + REQUIRE(predicted_class == max_prob_class); + } + + // Probabilities should sum to 1 + double sum_probs = probs[0] + probs[1]; + REQUIRE(sum_probs == Catch::Approx(1.0).epsilon(1e-6)); + + // All probabilities should be valid + REQUIRE(probs[0] >= 0.0); + REQUIRE(probs[1] >= 0.0); + REQUIRE(probs[0] <= 1.0); + REQUIRE(probs[1] <= 1.0); + } + } +} \ No newline at end of file diff --git a/tests/TestDecisionTree.cpp b/tests/TestDecisionTree.cpp new file mode 100644 index 0000000..7b5ef76 --- /dev/null +++ b/tests/TestDecisionTree.cpp @@ -0,0 +1,311 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#include +#include +#include +#include +#include +#include +#include +#include "experimental_clfs/DecisionTree.h" +#include "TestUtils.h" + +using 
namespace bayesnet; +using namespace Catch::Matchers; + +TEST_CASE("DecisionTree Construction", "[DecisionTree]") +{ + SECTION("Default constructor") + { + REQUIRE_NOTHROW(DecisionTree()); + } + + SECTION("Constructor with parameters") + { + REQUIRE_NOTHROW(DecisionTree(5, 10, 3)); + } +} + +TEST_CASE("DecisionTree Hyperparameter Setting", "[DecisionTree]") +{ + DecisionTree dt; + + SECTION("Set individual hyperparameters") + { + REQUIRE_NOTHROW(dt.setMaxDepth(10)); + REQUIRE_NOTHROW(dt.setMinSamplesSplit(5)); + REQUIRE_NOTHROW(dt.setMinSamplesLeaf(2)); + REQUIRE(dt.getMaxDepth() == 10); + REQUIRE(dt.getMinSamplesSplit() == 5); + REQUIRE(dt.getMinSamplesLeaf() == 2); + } + + SECTION("Set hyperparameters via JSON") + { + nlohmann::json params; + params["max_depth"] = 7; + params["min_samples_split"] = 4; + params["min_samples_leaf"] = 2; + + REQUIRE_NOTHROW(dt.setHyperparameters(params)); + REQUIRE(dt.getMaxDepth() == 7); + REQUIRE(dt.getMinSamplesSplit() == 4); + REQUIRE(dt.getMinSamplesLeaf() == 2); + } + + SECTION("Invalid hyperparameters should throw") + { + nlohmann::json params; + + // Negative max_depth + params["max_depth"] = -1; + REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument); + + // Zero min_samples_split + params["max_depth"] = 5; + params["min_samples_split"] = 0; + REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument); + + // Negative min_samples_leaf + params["min_samples_split"] = 2; + params["min_samples_leaf"] = -5; + REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument); + } +} + +TEST_CASE("DecisionTree Basic Functionality", "[DecisionTree]") +{ + // Create a simple dataset + int n_samples = 20; + int n_features = 2; + + std::vector> X(n_features, std::vector(n_samples)); + std::vector y(n_samples); + + // Simple pattern: class depends on first feature + for (int i = 0; i < n_samples; i++) { + X[0][i] = i < 10 ? 
0 : 1; + X[1][i] = i % 2; + y[i] = X[0][i]; // Class equals first feature + } + + std::vector features = { "f1", "f2" }; + std::string className = "class"; + std::map> states; + states["f1"] = { 0, 1 }; + states["f2"] = { 0, 1 }; + states["class"] = { 0, 1 }; + + SECTION("Training with vector interface") + { + DecisionTree dt(3, 2, 1); + REQUIRE_NOTHROW(dt.fit(X, y, features, className, states, Smoothing_t::NONE)); + + auto predictions = dt.predict(X); + REQUIRE(predictions.size() == static_cast(n_samples)); + + // Should achieve perfect accuracy on this simple dataset + int correct = 0; + for (size_t i = 0; i < predictions.size(); i++) { + if (predictions[i] == y[i]) correct++; + } + REQUIRE(correct == n_samples); + } + + SECTION("Prediction before fitting") + { + DecisionTree dt; + REQUIRE_THROWS_WITH(dt.predict(X), + ContainsSubstring("Classifier has not been fitted")); + } + + SECTION("Probability predictions") + { + DecisionTree dt(3, 2, 1); + dt.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto proba = dt.predict_proba(X); + REQUIRE(proba.size() == static_cast(n_samples)); + REQUIRE(proba[0].size() == 2); // Two classes + + // Check probabilities sum to 1 and probabilities are valid + auto predictions = dt.predict(X); + for (size_t i = 0; i < proba.size(); i++) { + auto p = proba[i]; + auto pred = predictions[i]; + REQUIRE(p.size() == 2); + REQUIRE(p[0] >= 0.0); + REQUIRE(p[1] >= 0.0); + double sum = p[0] + p[1]; + //Check that prodict_proba matches the expected predict value + REQUIRE(pred == (p[0] > p[1] ? 
0 : 1)); + REQUIRE(sum == Catch::Approx(1.0).epsilon(1e-6)); + } + } +} + +TEST_CASE("DecisionTree on Iris Dataset", "[DecisionTree][iris]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Training with dataset format") + { + DecisionTree dt(5, 2, 1); + + INFO("Dataset shape: " << raw.dataset.sizes()); + INFO("Features: " << raw.featurest.size()); + INFO("Samples: " << raw.nSamples); + + // DecisionTree expects dataset in format: features x samples, with labels as last row + REQUIRE_NOTHROW(dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE)); + + // Test prediction + auto predictions = dt.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + + // Calculate accuracy + auto correct = torch::sum(predictions == raw.yt).item(); + double accuracy = static_cast(correct) / raw.yt.size(0); + double acurracy_computed = dt.score(raw.Xt, raw.yt); + REQUIRE(accuracy > 0.97); // Reasonable accuracy for Iris + REQUIRE(acurracy_computed == Catch::Approx(accuracy).epsilon(1e-6)); + } + + SECTION("Training with vector interface") + { + DecisionTree dt(5, 2, 1); + + REQUIRE_NOTHROW(dt.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv, Smoothing_t::NONE)); + + // std::cout << "Tree structure:\n"; + // auto graph_lines = dt.graph("Iris Decision Tree"); + // for (const auto& line : graph_lines) { + // std::cout << line << "\n"; + // } + auto predictions = dt.predict(raw.Xv); + REQUIRE(predictions.size() == raw.yv.size()); + } + + SECTION("Different tree depths") + { + std::vector depths = { 1, 3, 5 }; + + for (int depth : depths) { + DecisionTree dt(depth, 2, 1); + dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + auto predictions = dt.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + } + } +} + +TEST_CASE("DecisionTree Edge Cases", "[DecisionTree]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Very shallow tree") + { + DecisionTree dt(1, 2, 1); // 
depth = 1 + dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + auto predictions = dt.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + + // With depth 1, should have at most 2 unique predictions + auto unique_vals = at::_unique(predictions); + REQUIRE(std::get<0>(unique_vals).size(0) <= 2); + } + + SECTION("High min_samples_split") + { + DecisionTree dt(10, 50, 1); + dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE); + + auto predictions = dt.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + } +} + +TEST_CASE("DecisionTree Graph Visualization", "[DecisionTree]") +{ + // Simple dataset + std::vector> X = { {0,0,0,1}, {0,1,1,1} }; // XOR pattern + std::vector y = { 0, 1, 1, 0 }; // XOR pattern + std::vector features = { "x1", "x2" }; + std::string className = "xor"; + std::map> states; + states["x1"] = { 0, 1 }; + states["x2"] = { 0, 1 }; + states["xor"] = { 0, 1 }; + + SECTION("Graph generation") + { + DecisionTree dt(2, 1, 1); + dt.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto graph_lines = dt.graph(); + + REQUIRE(graph_lines.size() > 2); + REQUIRE(graph_lines.front() == "digraph DecisionTree {"); + REQUIRE(graph_lines.back() == "}"); + + // Should contain node definitions + bool has_nodes = false; + for (const auto& line : graph_lines) { + if (line.find("node") != std::string::npos) { + has_nodes = true; + break; + } + } + REQUIRE(has_nodes); + } + + SECTION("Graph with title") + { + DecisionTree dt(2, 1, 1); + dt.fit(X, y, features, className, states, Smoothing_t::NONE); + + auto graph_lines = dt.graph("XOR Tree"); + + bool has_title = false; + for (const auto& line : graph_lines) { + if (line.find("label=\"XOR Tree\"") != std::string::npos) { + has_title = true; + break; + } + } + REQUIRE(has_title); + } +} + +TEST_CASE("DecisionTree with Weights", "[DecisionTree]") +{ + auto raw = RawDatasets("iris", true); + + SECTION("Uniform 
weights") + { + DecisionTree dt(5, 2, 1); + dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, raw.weights, Smoothing_t::NONE); + + auto predictions = dt.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + } + + SECTION("Non-uniform weights") + { + auto weights = torch::ones({ raw.nSamples }); + weights.index({ torch::indexing::Slice(0, 50) }) *= 2.0; // Emphasize first class + weights = weights / weights.sum(); + + DecisionTree dt(5, 2, 1); + dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, weights, Smoothing_t::NONE); + + auto predictions = dt.predict(raw.Xt); + REQUIRE(predictions.size(0) == raw.yt.size(0)); + } +} \ No newline at end of file diff --git a/tests/TestPlatform.cpp b/tests/TestPlatform.cpp index 108d8ec..6eb2bd0 100644 --- a/tests/TestPlatform.cpp +++ b/tests/TestPlatform.cpp @@ -7,7 +7,7 @@ #include #include "TestUtils.h" #include "folding.hpp" -#include +#include #include #include "config_platform.h" @@ -20,17 +20,17 @@ TEST_CASE("Test Platform version", "[Platform]") TEST_CASE("Test Folding library version", "[Folding]") { std::string version = folding::KFold(5, 100).version(); - REQUIRE(version == "1.1.0"); + REQUIRE(version == "1.1.1"); } TEST_CASE("Test BayesNet version", "[BayesNet]") { std::string version = bayesnet::TAN().getVersion(); - REQUIRE(version == "1.0.6"); + REQUIRE(version == "1.1.2"); } TEST_CASE("Test mdlp version", "[mdlp]") { std::string version = mdlp::CPPFImdlp::version(); - REQUIRE(version == "2.0.0"); + REQUIRE(version == "2.0.1"); } TEST_CASE("Test Arff version", "[Arff]") { diff --git a/tests/TestScores.cpp b/tests/TestScores.cpp index 111ae7e..b414c04 100644 --- a/tests/TestScores.cpp +++ b/tests/TestScores.cpp @@ -14,38 +14,40 @@ using json = nlohmann::ordered_json; auto epsilon = 1e-4; -void make_test_bin(int TP, int TN, int FP, int FN, std::vector& y_test, std::vector& y_pred) +void make_test_bin(int TP, int TN, int FP, int FN, std::vector& y_test, torch::Tensor& y_pred) 
{ - // TP + std::vector> probs; + // TP: true positive (label 1, predicted 1) for (int i = 0; i < TP; i++) { y_test.push_back(1); - y_pred.push_back(1); + probs.push_back({ 0.0, 1.0 }); // P(class 0)=0, P(class 1)=1 } - // TN + // TN: true negative (label 0, predicted 0) for (int i = 0; i < TN; i++) { y_test.push_back(0); - y_pred.push_back(0); + probs.push_back({ 1.0, 0.0 }); // P(class 0)=1, P(class 1)=0 } - // FP + // FP: false positive (label 0, predicted 1) for (int i = 0; i < FP; i++) { y_test.push_back(0); - y_pred.push_back(1); + probs.push_back({ 0.0, 1.0 }); // P(class 0)=0, P(class 1)=1 } - // FN + // FN: false negative (label 1, predicted 0) for (int i = 0; i < FN; i++) { y_test.push_back(1); - y_pred.push_back(0); + probs.push_back({ 1.0, 0.0 }); // P(class 0)=1, P(class 1)=0 } + // Convert to torch::Tensor of double, shape [N,2] + y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 2 }, torch::kFloat64).clone(); } TEST_CASE("Scores binary", "[Scores]") { std::vector y_test; - std::vector y_pred; + torch::Tensor y_pred; make_test_bin(197, 210, 52, 41, y_test, y_pred); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); - platform::Scores scores(y_test_tensor, y_pred_tensor, 2); + platform::Scores scores(y_test_tensor, y_pred, 2); REQUIRE(scores.accuracy() == Catch::Approx(0.814).epsilon(epsilon)); REQUIRE(scores.f1_score(0) == Catch::Approx(0.818713)); REQUIRE(scores.f1_score(1) == Catch::Approx(0.809035)); @@ -64,10 +66,23 @@ TEST_CASE("Scores binary", "[Scores]") TEST_CASE("Scores multiclass", "[Scores]") { std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; - std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; + // Refactor y_pred to a tensor of shape [10, 3] with probabilities + std::vector> probs = { + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, 
P(class 1)=0, P(class 2)=1 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1 + }; + torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone(); + // Convert y_test to a tensor auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); - platform::Scores scores(y_test_tensor, y_pred_tensor, 3); + platform::Scores scores(y_test_tensor, y_pred, 3); REQUIRE(scores.accuracy() == Catch::Approx(0.6).epsilon(epsilon)); REQUIRE(scores.f1_score(0) == Catch::Approx(0.666667)); REQUIRE(scores.f1_score(1) == Catch::Approx(0.4)); @@ -84,10 +99,21 @@ TEST_CASE("Scores multiclass", "[Scores]") TEST_CASE("Test Confusion Matrix Values", "[Scores]") { std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; - std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; + std::vector> probs = { + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1 + }; + torch::Tensor y_pred = torch::from_blob(probs.data(), 
{ (long)probs.size(), 3 }, torch::kFloat64).clone(); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); - platform::Scores scores(y_test_tensor, y_pred_tensor, 3); + platform::Scores scores(y_test_tensor, y_pred, 3); auto confusion_matrix = scores.get_confusion_matrix(); REQUIRE(confusion_matrix[0][0].item() == 2); REQUIRE(confusion_matrix[0][1].item() == 1); @@ -102,11 +128,22 @@ TEST_CASE("Test Confusion Matrix Values", "[Scores]") TEST_CASE("Confusion Matrix JSON", "[Scores]") { std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; - std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; + std::vector> probs = { + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1 + }; + torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone(); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); std::vector labels = { "Aeroplane", "Boat", "Car" }; - platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels); + platform::Scores scores(y_test_tensor, y_pred, 3, labels); auto res_json_int = scores.get_confusion_matrix_json(); REQUIRE(res_json_int[0][0] == 2); REQUIRE(res_json_int[0][1] == 1); @@ -131,11 +168,22 @@ TEST_CASE("Confusion Matrix JSON", "[Scores]") TEST_CASE("Classification Report", "[Scores]") { 
std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; - std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; + std::vector> probs = { + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1 + }; + torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone(); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); std::vector labels = { "Aeroplane", "Boat", "Car" }; - platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels); + platform::Scores scores(y_test_tensor, y_pred, 3, labels); auto report = scores.classification_report(Colors::BLUE(), "train"); auto json_matrix = scores.get_confusion_matrix_json(true); platform::Scores scores2(json_matrix); @@ -144,11 +192,22 @@ TEST_CASE("Classification Report", "[Scores]") TEST_CASE("JSON constructor", "[Scores]") { std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; - std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; + std::vector> probs = { + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, 
P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1 + }; + torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone(); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); std::vector labels = { "Car", "Boat", "Aeroplane" }; - platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels); + platform::Scores scores(y_test_tensor, y_pred, 3, labels); auto res_json_int = scores.get_confusion_matrix_json(); platform::Scores scores2(res_json_int); REQUIRE(scores.accuracy() == scores2.accuracy()); @@ -173,17 +232,14 @@ TEST_CASE("JSON constructor", "[Scores]") TEST_CASE("Aggregate", "[Scores]") { std::vector y_test; - std::vector y_pred; + torch::Tensor y_pred; make_test_bin(197, 210, 52, 41, y_test, y_pred); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); - platform::Scores scores(y_test_tensor, y_pred_tensor, 2); + platform::Scores scores(y_test_tensor, y_pred, 2); y_test.clear(); - y_pred.clear(); make_test_bin(227, 187, 39, 47, y_test, y_pred); auto y_test_tensor2 = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor2 = torch::tensor(y_pred, torch::kInt32); - platform::Scores scores2(y_test_tensor2, y_pred_tensor2, 2); + platform::Scores scores2(y_test_tensor2, y_pred, 2); scores.aggregate(scores2); REQUIRE(scores.accuracy() == Catch::Approx(0.821).epsilon(epsilon)); REQUIRE(scores.f1_score(0) == Catch::Approx(0.8160329)); @@ -195,11 +251,9 @@ TEST_CASE("Aggregate", "[Scores]") REQUIRE(scores.f1_weighted() == Catch::Approx(0.8209856)); REQUIRE(scores.f1_macro() == Catch::Approx(0.8208694)); y_test.clear(); - y_pred.clear(); make_test_bin(197 + 227, 210 + 187, 
52 + 39, 41 + 47, y_test, y_pred); y_test_tensor = torch::tensor(y_test, torch::kInt32); - y_pred_tensor = torch::tensor(y_pred, torch::kInt32); - platform::Scores scores3(y_test_tensor, y_pred_tensor, 2); + platform::Scores scores3(y_test_tensor, y_pred, 2); for (int i = 0; i < 2; ++i) { REQUIRE(scores3.f1_score(i) == scores.f1_score(i)); REQUIRE(scores3.precision(i) == scores.precision(i)); @@ -212,11 +266,22 @@ TEST_CASE("Aggregate", "[Scores]") TEST_CASE("Order of keys", "[Scores]") { std::vector y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; - std::vector y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; + std::vector> probs = { + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0 + { 0.0, 0.0, 1.0 } // P(class 0)=0, P(class 1)=0, P(class 2)=1 + }; + torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone(); auto y_test_tensor = torch::tensor(y_test, torch::kInt32); - auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32); std::vector labels = { "Car", "Boat", "Aeroplane" }; - platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels); + platform::Scores scores(y_test_tensor, y_pred, 3, labels); auto res_json_int = scores.get_confusion_matrix_json(true); // Make a temp file and store the json std::string filename = "temp.json"; diff --git a/tests/TestUtils.h b/tests/TestUtils.h index aea25d6..4409d6f 100644 --- a/tests/TestUtils.h +++ b/tests/TestUtils.h @@ -5,7 +5,7 @@ #include #include 
#include -#include +#include #include bool file_exists(const std::string& name); diff --git a/vcpkg-configuration.json b/vcpkg-configuration.json new file mode 100644 index 0000000..b241242 --- /dev/null +++ b/vcpkg-configuration.json @@ -0,0 +1,21 @@ +{ + "default-registry": { + "kind": "git", + "baseline": "760bfd0c8d7c89ec640aec4df89418b7c2745605", + "repository": "https://github.com/microsoft/vcpkg" + }, + "registries": [ + { + "kind": "git", + "repository": "https://github.com/rmontanana/vcpkg-stash", + "baseline": "1ea69243c0e8b0de77c9d1dd6e1d7593ae7f3627", + "packages": [ + "arff-files", + "bayesnet", + "fimdlp", + "folding", + "libtorch-bin" + ] + } + ] +} \ No newline at end of file diff --git a/vcpkg.json b/vcpkg.json new file mode 100644 index 0000000..e9a8c61 --- /dev/null +++ b/vcpkg.json @@ -0,0 +1,43 @@ + { + "name": "platform", + "version-string": "1.1.0", + "dependencies": [ + "arff-files", + "nlohmann-json", + "fimdlp", + "libtorch-bin", + "folding", + "catch2", + "argparse" + ], + "overrides": [ + { + "name": "arff-files", + "version": "1.1.0" + }, + { + "name": "fimdlp", + "version": "2.0.1" + }, + { + "name": "libtorch-bin", + "version": "2.7.0" + }, + { + "name": "folding", + "version": "1.1.1" + }, + { + "name": "argparse", + "version": "3.2" + }, + { + "name": "catch2", + "version": "3.8.1" + }, + { + "name": "nlohmann-json", + "version": "3.11.3" + } + ] + } \ No newline at end of file