Compare commits: 167 commits in range `67f1feb71f...alphablock`

ba455bb934, a65955248a, 84930b0537, 10c65f44a0, 6d112f01e7, 401296293b,
9566ae4cf6, 55187ee521, 68ea06d129, 6c1d1d0d32, b0853d169b, 26f8e07774,
315dfb104f, 381f226d53, ea13835701, d75468cf78, c58bd9d60d, 148a3b831a,
69063badbb, 6ae2b2182a, 4dbd76df55, 4545f76667, 8372987dae, d72943c749,
800246acd2, 0ea967dd9d, 97abec8b69, 17c9522e77, 45af550cf9, 5d5f49777e,
540a8ea06d, 1924c4392b, f2556a30af, 2f2ed00ca1, 28f6a0d7a7, 028522f180,
84adf13a79, 26dfe6d056, 3acc34e4c6, 8f92b74260, 3d900f8c81, e628d80f4c,
0f06f8971e, f800772149, b8a8ddaf8c, 90555489ff, 080f3cee34, 643633e6dd,
361c51d864, 5dd3deca1a, 2202a81782, c4f4e332f6, a7ec930fa0, 6858b3d89a,
5fb176d78a, f5d5c35002, b34af13eea, e3a06264a9, df82f82e88, 886dde7a06,
88468434e7, ad5c3319bd, 594adb0534, b9e0c92334, 25bd7a42c6, c165a4bdda,
49a36904dc, 577351eda5, a3c4bde460, 696c0564a7, 30a6d5e60d, f8f3ca28dc,
5c190d7c66, 99c9c6731f, 8d20545fd2, 2b480cdcb7, ebaddf1a6c, 07a2efb298,
f88b223c46, 69b9609154, 6d4117d188, ec0268c514, dd94fd51f7, 009ed037b8,
6d1b78ada7, 3882ebd6e4, 423242d280, b9381aa453, 33cfb78554, 1caa39c071,
018c94bfe6, a54d6b8716, 6cde09d81e, 7be95d889d, 42d61c6fc4, e5e947779f,
ad168d13ba, 78b8a8ae66, 7ed9073d15, ee93789ca3, 375ed437ed, 5ec7fe8d00,
72ea62f783, 4b91f2bde0, 3bc51cb7b0, cf83d1f8f4, 0dd10bcbe4, 622b36b2c7,
ea29a96ca1, 673a41fc4d, 634ea36169, 20fef5b6b3, 7cf864c3f3, 4a0fa33917,
d47da27571, faccb09c43, fa4f47ff35, 106a36109e, 37eba57765, 67487ffce1,
9c11dee019, 58ae2c7690, fa366a4c22, b9af086c29, 6a285b149b, ad402ac21e,
38978aa7b7, 3691363b8e, fe24aa0b3e, 175e0eb591, 1912d17498, 54249e5304,
d7f92c9682, 00bb7f4680, bf5dabb169, cdf339856a, 3ceea5677c, 260fd122eb,
eff0be1c1c, 0ade72a37a, 72cda3784a, 52d689666a, 26e87c9cb1, 03cd6e5a51,
cd9ff89b52, 05d05e25c2, 5cd6e3d1a5, d9e9356d92, 0010c840d1, 51f32113c0,
b3b3d9f1b9, 4c847fc3f6, 7e4ee0a9a9, b7398db9b1, 0a9bd0d9c4, 7a3adaf4a9,
5c4efa08db, 576016bbd9, e26b3c0970, 183cf12300, 4eb08cd281, 4f5f629124,
df011f7e6b, 42648f3125, d2832ed2b3, ec323d86ab, e4a6575722

Files changed:
```diff
@@ -4,8 +4,8 @@ diagrams:
   Platform:
     type: class
     glob:
-      - src/*.cc
-      - src/modules/*.cc
+      - src/*.cpp
+      - src/modules/*.cpp
     using_namespace: platform
     include:
       namespaces:
@@ -17,7 +17,7 @@ diagrams:
   sequence:
     type: sequence
     glob:
-      - src/b_main.cc
+      - src/b_main.cpp
     combine_free_functions_into_file_participants: true
     using_namespace:
      - std
```
.gitmodules (vendored): 10 lines changed

```diff
@@ -10,10 +10,12 @@
 [submodule "lib/libxlsxwriter"]
 	path = lib/libxlsxwriter
 	url = https://github.com/jmcnamara/libxlsxwriter.git
+[submodule "lib/folding"]
+	path = lib/folding
+	url = https://github.com/rmontanana/folding
+[submodule "lib/Files"]
+	path = lib/Files
+	url = https://github.com/rmontanana/ArffFiles
 [submodule "lib/mdlp"]
 	path = lib/mdlp
 	url = https://github.com/rmontanana/mdlp
-	update = merge
-[submodule "lib/PyClassifiers"]
-	path = lib/PyClassifiers
-	url = git@github.com:rmontanana/PyClassifiers
```
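After this change a fresh checkout needs the new submodules initialized; a minimal sketch using standard git commands (not part of the diff itself):

```bash
# Fetch the newly declared submodules (lib/folding, lib/Files) along with the rest
git submodule update --init --recursive

# If lib/PyClassifiers was initialized under the old .gitmodules, drop its working copy
git submodule deinit -f lib/PyClassifiers
```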
.vscode/c_cpp_properties.json (vendored): 13 lines changed

```diff
@@ -11,7 +11,18 @@
             ],
             "cStandard": "c17",
             "cppStandard": "c++17",
-            "compileCommands": "${workspaceFolder}/cmake-build-release/compile_commands.json"
+            "compileCommands": "${workspaceFolder}/cmake-build-release/compile_commands.json",
+            "configurationProvider": "ms-vscode.cmake-tools"
+        },
+        {
+            "name": "Linux",
+            "includePath": [
+                "${workspaceFolder}/**"
+            ],
+            "defines": [],
+            "cStandard": "c17",
+            "cppStandard": "c++17",
+            "configurationProvider": "ms-vscode.cmake-tools"
         }
     ],
     "version": 4
```
.vscode/launch.json (vendored): 15 lines changed

```diff
@@ -62,9 +62,9 @@
             "--stratified",
             "--discretize",
             "-d",
-            "iris",
+            "glass",
             "--hyperparameters",
-            "{\"repeatSparent\": true, \"maxModels\": 12}"
+            "{\"block_update\": true}"
         ],
         "cwd": "/home/rmontanana/Code/discretizbench",
     },
@@ -99,7 +99,9 @@
         "request": "launch",
         "program": "${workspaceFolder}/build_debug/src/b_list",
         "args": [
-            "--excel"
+            "results",
+            "-d",
+            "mfeat-morphological"
         ],
         //"cwd": "/Users/rmontanana/Code/discretizbench",
         "cwd": "${workspaceFolder}/../discretizbench",
@@ -108,12 +110,13 @@
         "name": "test",
         "type": "lldb",
         "request": "launch",
-        "program": "${workspaceFolder}/build_debug/tests/unit_tests",
+        "program": "${workspaceFolder}/build_debug/tests/unit_tests_platform",
         "args": [
-            "-c=\"Metrics Test\"",
+            "[Scores]",
+            // "-c=\"Metrics Test\"",
             // "-s",
         ],
-        "cwd": "${workspaceFolder}/build/tests",
+        "cwd": "${workspaceFolder}/build_debug/tests",
     },
     {
         "name": "Build & debug active file",
```
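The test launch configuration now selects Catch2 tests by tag rather than by name. The same filters work from a shell; a sketch assuming the binary lands in build_debug/tests as configured above:

```bash
cd build_debug/tests
./unit_tests_platform "[Scores]"      # run only the tests tagged [Scores]
./unit_tests_platform -s "[Scores]"   # same, also reporting successful assertions
```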
CMakeLists.txt

```diff
@@ -1,16 +1,12 @@
 cmake_minimum_required(VERSION 3.20)
 
 project(Platform
-  VERSION 1.0.2
+  VERSION 1.1.0
   DESCRIPTION "Platform to run Experiments with classifiers."
   HOMEPAGE_URL "https://github.com/rmontanana/platform"
   LANGUAGES CXX
 )
 
-if (CODE_COVERAGE AND NOT ENABLE_TESTING)
-  MESSAGE(FATAL_ERROR "Code coverage requires testing enabled")
-endif (CODE_COVERAGE AND NOT ENABLE_TESTING)
-
 find_package(Torch REQUIRED)
 
 if (POLICY CMP0135)
@@ -25,6 +21,8 @@ set(CMAKE_CXX_EXTENSIONS OFF)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
 SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
+set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Ofast")
+set(CMAKE_CXX_FLAGS_DEBUG " ${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage -O0 -g")
 
 # Options
 # -------
@@ -48,7 +46,7 @@ if(Boost_FOUND)
 endif()
 
 # Python
-find_package(Python3 3.11...3.11.9 COMPONENTS Interpreter Development REQUIRED)
+find_package(Python3 3.11 COMPONENTS Interpreter Development REQUIRED)
 message("Python3_LIBRARIES=${Python3_LIBRARIES}")
 
 # CMakes modules
@@ -60,7 +58,6 @@ if (CODE_COVERAGE)
   enable_testing()
   include(CodeCoverage)
   MESSAGE("Code coverage enabled")
-  set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage -O0 -g")
   SET(GCC_COVERAGE_LINK_FLAGS " ${GCC_COVERAGE_LINK_FLAGS} -lgcov --coverage")
 endif (CODE_COVERAGE)
 
@@ -70,18 +67,31 @@ endif (ENABLE_CLANG_TIDY)
 
 # External libraries - dependencies of Platform
 # ---------------------------------------------
-add_git_submodule("lib/PyClassifiers")
 add_git_submodule("lib/argparse")
+add_git_submodule("lib/mdlp")
 
 find_library(XLSXWRITER_LIB NAMES libxlsxwriter.dylib libxlsxwriter.so PATHS ${Platform_SOURCE_DIR}/lib/libxlsxwriter/lib)
 message("XLSXWRITER_LIB=${XLSXWRITER_LIB}")
 
+find_library(PyClassifiers NAMES libPyClassifiers PyClassifiers libPyClassifiers.a PATHS ${Platform_SOURCE_DIR}/../lib/lib REQUIRED)
+find_path(PyClassifiers_INCLUDE_DIRS REQUIRED NAMES pyclassifiers PATHS ${Platform_SOURCE_DIR}/../lib/include)
+find_library(BayesNet NAMES libBayesNet BayesNet libBayesNet.a PATHS ${Platform_SOURCE_DIR}/../lib/lib REQUIRED)
+find_path(Bayesnet_INCLUDE_DIRS REQUIRED NAMES bayesnet PATHS ${Platform_SOURCE_DIR}/../lib/include)
+
+message(STATUS "PyClassifiers=${PyClassifiers}")
+message(STATUS "PyClassifiers_INCLUDE_DIRS=${PyClassifiers_INCLUDE_DIRS}")
+message(STATUS "BayesNet=${BayesNet}")
+message(STATUS "Bayesnet_INCLUDE_DIRS=${Bayesnet_INCLUDE_DIRS}")
+
 # Subdirectories
 # --------------
+## Configure test data path
+cmake_path(SET TEST_DATA_PATH "${CMAKE_CURRENT_SOURCE_DIR}/tests/data")
+configure_file(src/common/SourceData.h.in "${CMAKE_BINARY_DIR}/configured_files/include/SourceData.h")
 add_subdirectory(config)
 add_subdirectory(src)
-add_subdirectory(sample)
-file(GLOB Platform_SOURCES CONFIGURE_DEPENDS ${Platform_SOURCE_DIR}/src/*.cc)
+# add_subdirectory(sample)
+file(GLOB Platform_SOURCES CONFIGURE_DEPENDS ${Platform_SOURCE_DIR}/src/*.cpp)
 
 # Testing
 # -------
```
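The new find_library/find_path calls expect prebuilt BayesNet and PyClassifiers artifacts in a directory next to the checkout instead of in-tree submodules. A sketch of the layout those PATHS hints imply (directory names inferred from the calls above, not spelled out anywhere in the diff):

```
<parent>/
├── platform/                  # this repository
└── lib/
    ├── lib/
    │   ├── libBayesNet.a
    │   └── libPyClassifiers.a
    └── include/
        ├── bayesnet/
        └── pyclassifiers/
```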
Doxyfile: 4 lines changed

```diff
@@ -976,7 +976,7 @@ INPUT_FILE_ENCODING =
 # Note the list of default checked file patterns might differ from the list of
 # default file extension mappings.
 #
-# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp,
+# If left blank the following patterns are tested:*.c, *.cpp, *.cxx, *.cpp,
 # *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h,
 # *.hh, *.hxx, *.hpp, *.h++, *.l, *.cs, *.d, *.php, *.php4, *.php5, *.phtml,
 # *.inc, *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C
@@ -984,7 +984,7 @@ INPUT_FILE_ENCODING =
 # *.vhdl, *.ucf, *.qsf and *.ice.
 
 FILE_PATTERNS          = *.c \
-                         *.cc \
+                         *.cpp \
                          *.cxx \
                          *.cpp \
                          *.c++ \
```
LICENSE: 2 lines changed

```diff
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2024 rmontanana
+Copyright (c) 2024 Ricardo Montañana Gómez
 
 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
 
```
Makefile: 21 lines changed

```diff
@@ -5,8 +5,7 @@ SHELL := /bin/bash
 f_release = build_release
 f_debug = build_debug
 app_targets = b_best b_list b_main b_manage b_grid
-test_targets = unit_tests_bayesnet unit_tests_platform
-n_procs = -j 16
+test_targets = unit_tests_platform
 
 define ClearTests
 	@for t in $(test_targets); do \
@@ -41,7 +40,7 @@ setup: ## Install dependencies for tests and coverage
 dest ?= ${HOME}/bin
 install: ## Copy binary files to bin folder
 	@echo "Destination folder: $(dest)"
-	make buildr
+	@make buildr
 	@echo "*******************************************"
 	@echo ">>> Copying files to $(dest)"
 	@echo "*******************************************"
@@ -56,10 +55,10 @@ dependency: ## Create a dependency graph diagram of the project (build/dependenc
 	cd $(f_debug) && cmake .. --graphviz=dependency.dot && dot -Tpng dependency.dot -o dependency.png
 
 buildd: ## Build the debug targets
-	cmake --build $(f_debug) -t $(app_targets) PlatformSample $(n_procs)
+	cmake --build $(f_debug) -t $(app_targets) PlatformSample --parallel
 
 buildr: ## Build the release targets
-	cmake --build $(f_release) -t $(app_targets) $(n_procs)
+	cmake --build $(f_release) -t $(app_targets) --parallel
 
 clean: ## Clean the tests info
 	@echo ">>> Cleaning Debug Platform tests...";
@@ -87,7 +86,7 @@ opt = ""
 test: ## Run tests (opt="-s") to verbose output the tests, (opt="-c='Test Maximum Spanning Tree'") to run only that section
 	@echo ">>> Running Platform tests...";
 	@$(MAKE) clean
-	@cmake --build $(f_debug) -t $(test_targets) $(n_procs)
+	@cmake --build $(f_debug) -t $(test_targets) --parallel
 	@for t in $(test_targets); do \
 		if [ -f $(f_debug)/tests/$$t ]; then \
 			cd $(f_debug)/tests ; \
@@ -96,6 +95,14 @@ test: ## Run tests (opt="-s") to verbose output the tests, (opt="-c='Test Maximu
 	done
 	@echo ">>> Done";
 
+fname = iris
+example: ## Build sample
+	@echo ">>> Building Sample...";
+	@cmake --build build_debug -t sample
+	build_debug/sample/PlatformSample --model BoostAODE --dataset $(fname) --discretize --stratified
+	@echo ">>> Done";
+
+
 coverage: ## Run tests and generate coverage report (build/index.html)
 	@echo ">>> Building tests with coverage..."
 	@$(MAKE) test
@@ -105,7 +112,7 @@ coverage: ## Run tests and generate coverage report (build/index.html)
 
 help: ## Show help message
 	@IFS=$$'\n' ; \
-	help_lines=(`fgrep -h "##" $(MAKEFILE_LIST) | fgrep -v fgrep | sed -e 's/\\$$//' | sed -e 's/##/:/'`); \
+	help_lines=(`grep -Fh "##" $(MAKEFILE_LIST) | grep -Fv fgrep | sed -e 's/\\$$//' | sed -e 's/##/:/'`); \
 	printf "%s\n\n" "Usage: make [task]"; \
 	printf "%-20s %s\n" "task" "help" ; \
 	printf "%-20s %s\n" "------" "----" ; \
```
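The new example target builds the sample and runs it with a make-variable override for the dataset; for instance (fname defaults to iris, as defined above):

```bash
make example              # runs PlatformSample on the iris dataset
make example fname=glass  # override the dataset on the command line
```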
README.md: 83 lines changed

````diff
@@ -1,10 +1,8 @@
-# Platform
+# <img src="logo.png" alt="logo" width="50"/> Platform
 
-Platform to run Bayesian Networks and Machine Learning Classifiers experiments.
-
-# Platform
-
-[](https://opensource.org/licenses/MIT)
+[](<https://opensource.org/licenses/MIT>)
 
 Platform to run Bayesian Networks and Machine Learning Classifiers experiments.
 
@@ -22,11 +20,18 @@ In Linux sometimes the library libstdc++ is mistaken from the miniconda installa
 libstdc++.so.6: version `GLIBCXX_3.4.32' not found (required by b_xxxx)
 ```
 
-The solution is to erase the libstdc++ library from the miniconda installation:
+The solution is to erase the libstdc++ library from the miniconda installation and no further compilation is needed.
 
 ### MPI
 
-In Linux just install openmpi & openmpi-devel packages. Only if cmake can't find openmpi installation (like in Oracle Linux) set the following variable:
+In Linux just install openmpi & openmpi-devel packages.
+
+```bash
+source /etc/profile.d/modules.sh
+module load mpi/openmpi-x86_64
+```
+
+If cmake can't find openmpi installation (like in Oracle Linux) set the following variable:
 
 ```bash
 export MPI_HOME="/usr/lib64/openmpi"
@@ -86,4 +91,64 @@ make release
 make debug
 ```
 
-## 1. Introduction
+### Configuration
+
+The configuration file is named .env and it should be located in the folder where the experiments are to be run. In the root folder of the project there is a file named .env.example that can be used as a template.
+
+## 1. Commands
+
+### b_list
+
+List all the datasets and their properties. The datasets are located in the _datasets_ folder under the experiments root folder. A special file called all.txt with the names of the datasets has to be created. This file is built with lines of the form:
+
+<name>,<class_name>,<real_features>
+
+where <real_features> can be either the word _all_ or a list of numbers separated by commas, i.e. [0,3,6,7]
+
+### b_grid
+
+Run a grid search over the parameters of the classifiers. The parameters are defined in the file _grid.txt_ located in the grid folder of the experiments. The file has to be created with the following format:
+
+```json
+{
+    "all": [
+        <set of hyperparams>, ...
+    ],
+    "<dataset_name>": [
+        <specific set of hyperparams for <dataset_name>>, ...
+    ],
+}
+```
+
+The file has to be named _grid_<model_name>_input.json_
+
+As a result it builds a file named _grid_<model_name>_output.json_ with the results of the grid search.
+
+The computation is done in parallel using MPI.
+
+### b_main
+
+Run the main experiment. There are several hyperparameters that can be set on the command line:
+
+- -d, -\-dataset <dataset_name> : Name of the dataset to run the experiment with. If no dataset is specified the experiment will run with all the datasets in the all.txt file.
+- -m, -\-model <classifier_name> : Name of the classifier to run the experiment with (i.e. BoostAODE, TAN, Odte, etc.).
+- -\-discretize: Discretize the dataset before running the experiment.
+- -\-stratified: Use stratified cross validation.
+- -\-folds <folds>: Number of folds for cross validation (optional, default value is in the .env file).
+- -s, -\-seeds <seed>: Seeds for the random number generator (optional, default values are in the .env file).
+- -\-no-train-score: Do not calculate the train score (optional); this is useful when the dataset is big and the training score is not needed.
+- -\-hyperparameters <hyperparameters>: Hyperparameters for the experiment in json format.
+- -\-hyper-file <hyperparameters_file>: File with the hyperparameters for the experiment in json format. This file uses the output format of the b_grid command.
+- -\-title <title_text>: Title of the experiment (optional if only one dataset is specified).
+- -\-quiet: Don't display detailed progress and results of the experiment.
+
+### b_manage
+
+Manage the results of the experiments.
+
+### b_best
+
+Get and optionally compare the best results of the experiments. The results can be stored in an MS Excel file.
+
````
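For illustration, a couple of all.txt lines in the format the README describes; the dataset and class names here are examples, not contents of the repository:

```
iris,class,all
glass,Type,[0,3,6,7]
```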
config/CMakeLists.txt

```diff
@@ -1,4 +1,4 @@
 configure_file(
     "config.h.in"
-    "${CMAKE_BINARY_DIR}/configured_files/include/config.h" ESCAPE_QUOTES
+    "${CMAKE_BINARY_DIR}/configured_files/include/config_platform.h" ESCAPE_QUOTES
 )
```
config/config.h.in

```diff
@@ -1,14 +1,11 @@
-#pragma once
+#ifndef PLATFORM_H
+#define PLATFORM_H
 #include <string>
 #include <string_view>
 
-#define PROJECT_VERSION_MAJOR @PROJECT_VERSION_MAJOR @
-#define PROJECT_VERSION_MINOR @PROJECT_VERSION_MINOR @
-#define PROJECT_VERSION_PATCH @PROJECT_VERSION_PATCH @
-static constexpr std::string_view project_name = "@PROJECT_NAME@";
-static constexpr std::string_view project_version = "@PROJECT_VERSION@";
-static constexpr std::string_view project_description = "@PROJECT_DESCRIPTION@";
-static constexpr std::string_view git_sha = "@GIT_SHA@";
-static constexpr std::string_view data_path = "@Platform_SOURCE_DIR@/tests/data/";
+static constexpr std::string_view platform_project_name = "@PROJECT_NAME@";
+static constexpr std::string_view platform_project_version = "@PROJECT_VERSION@";
+static constexpr std::string_view platform_project_description = "@PROJECT_DESCRIPTION@";
+static constexpr std::string_view platform_git_sha = "@GIT_SHA@";
+static constexpr std::string_view platform_data_path = "@Platform_SOURCE_DIR@/tests/data/";
+#endif
```
```diff
@@ -1,4 +1,4 @@
 filter = src/
-exclude-directories = build/lib/
+exclude-directories = build_debug/lib/
 print-summary = yes
 sort-percentage = yes
```
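These keys match gcovr's configuration-file syntax; assuming this file is gcovr's default gcovr.cfg at the project root (the capture does not preserve its name), a coverage report can be produced with:

```bash
# gcovr reads gcovr.cfg from the working directory automatically
gcovr
# or point at an explicitly named configuration file
gcovr --config gcovr.cfg --html --output build_debug/index.html
```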
```diff
@@ -1,8 +1,3 @@
-[submodule "lib/mdlp"]
-	path = lib/mdlp
-	url = https://github.com/rmontanana/mdlp
-	main = main
-	update = merge
 [submodule "lib/catch2"]
 	path = lib/catch2
 	main = v2.x
@@ -23,9 +18,6 @@
 	url = https://github.com/jmcnamara/libxlsxwriter.git
 	main = main
 	update = merge
-[submodule "lib/PyClassifiers"]
-	path = lib/PyClassifiers
-	url = https://github.com/rmontanana/PyClassifiers
 [submodule "lib/folding"]
 	path = lib/folding
 	url = https://github.com/rmontanana/Folding
```
BIN img/bbest.gif (new file, 1.9 MiB; binary file not shown)
BIN img/bgrid.gif (new file, 349 KiB; binary file not shown)
BIN img/blist.gif (new file, 1.7 MiB; binary file not shown)
BIN img/bmain.gif (new file, 3.3 MiB; binary file not shown)
BIN img/bmanage.gif (new file, 8.7 MiB; binary file not shown)
Submodule lib/Files added at a4329f5f9d
lib/Files/ArffFiles.cc deleted (168 lines):

```cpp
#include "ArffFiles.h"
#include <fstream>
#include <sstream>
#include <map>
#include <iostream>

ArffFiles::ArffFiles() = default;

std::vector<std::string> ArffFiles::getLines() const
{
    return lines;
}

unsigned long int ArffFiles::getSize() const
{
    return lines.size();
}

std::vector<std::pair<std::string, std::string>> ArffFiles::getAttributes() const
{
    return attributes;
}

std::string ArffFiles::getClassName() const
{
    return className;
}

std::string ArffFiles::getClassType() const
{
    return classType;
}

std::vector<std::vector<float>>& ArffFiles::getX()
{
    return X;
}

std::vector<int>& ArffFiles::getY()
{
    return y;
}

void ArffFiles::loadCommon(std::string fileName)
{
    std::ifstream file(fileName);
    if (!file.is_open()) {
        throw std::invalid_argument("Unable to open file");
    }
    std::string line;
    std::string keyword;
    std::string attribute;
    std::string type;
    std::string type_w;
    while (getline(file, line)) {
        if (line.empty() || line[0] == '%' || line == "\r" || line == " ") {
            continue;
        }
        if (line.find("@attribute") != std::string::npos || line.find("@ATTRIBUTE") != std::string::npos) {
            std::stringstream ss(line);
            ss >> keyword >> attribute;
            type = "";
            while (ss >> type_w)
                type += type_w + " ";
            attributes.emplace_back(trim(attribute), trim(type));
            continue;
        }
        if (line[0] == '@') {
            continue;
        }
        lines.push_back(line);
    }
    file.close();
    if (attributes.empty())
        throw std::invalid_argument("No attributes found");
}

void ArffFiles::load(const std::string& fileName, bool classLast)
{
    int labelIndex;
    loadCommon(fileName);
    if (classLast) {
        className = std::get<0>(attributes.back());
        classType = std::get<1>(attributes.back());
        attributes.pop_back();
        labelIndex = static_cast<int>(attributes.size());
    } else {
        className = std::get<0>(attributes.front());
        classType = std::get<1>(attributes.front());
        attributes.erase(attributes.begin());
        labelIndex = 0;
    }
    generateDataset(labelIndex);
}

void ArffFiles::load(const std::string& fileName, const std::string& name)
{
    int labelIndex;
    loadCommon(fileName);
    bool found = false;
    for (int i = 0; i < attributes.size(); ++i) {
        if (attributes[i].first == name) {
            className = std::get<0>(attributes[i]);
            classType = std::get<1>(attributes[i]);
            attributes.erase(attributes.begin() + i);
            labelIndex = i;
            found = true;
            break;
        }
    }
    if (!found) {
        throw std::invalid_argument("Class name not found");
    }
    generateDataset(labelIndex);
}

void ArffFiles::generateDataset(int labelIndex)
{
    X = std::vector<std::vector<float>>(attributes.size(), std::vector<float>(lines.size()));
    auto yy = std::vector<std::string>(lines.size(), "");
    auto removeLines = std::vector<int>(); // Lines with missing values
    for (size_t i = 0; i < lines.size(); i++) {
        std::stringstream ss(lines[i]);
        std::string value;
        int pos = 0;
        int xIndex = 0;
        while (getline(ss, value, ',')) {
            if (pos++ == labelIndex) {
                yy[i] = value;
            } else {
                if (value == "?") {
                    X[xIndex++][i] = -1;
                    removeLines.push_back(i);
                } else
                    X[xIndex++][i] = stof(value);
            }
        }
    }
    for (auto i : removeLines) {
        yy.erase(yy.begin() + i);
        for (auto& x : X) {
            x.erase(x.begin() + i);
        }
    }
    y = factorize(yy);
}

std::string ArffFiles::trim(const std::string& source)
{
    std::string s(source);
    s.erase(0, s.find_first_not_of(" '\n\r\t"));
    s.erase(s.find_last_not_of(" '\n\r\t") + 1);
    return s;
}

std::vector<int> ArffFiles::factorize(const std::vector<std::string>& labels_t)
{
    std::vector<int> yy;
    yy.reserve(labels_t.size());
    std::map<std::string, int> labelMap;
    int i = 0;
    for (const std::string& label : labels_t) {
        if (labelMap.find(label) == labelMap.end()) {
            labelMap[label] = i++;
        }
        yy.push_back(labelMap[label]);
    }
    return yy;
}
```

lib/Files/ArffFiles.h deleted (32 lines):

```cpp
#ifndef ARFFFILES_H
#define ARFFFILES_H

#include <string>
#include <vector>

class ArffFiles {
private:
    std::vector<std::string> lines;
    std::vector<std::pair<std::string, std::string>> attributes;
    std::string className;
    std::string classType;
    std::vector<std::vector<float>> X;
    std::vector<int> y;
    void generateDataset(int);
    void loadCommon(std::string);
public:
    ArffFiles();
    void load(const std::string&, bool = true);
    void load(const std::string&, const std::string&);
    std::vector<std::string> getLines() const;
    unsigned long int getSize() const;
    std::string getClassName() const;
    std::string getClassType() const;
    static std::string trim(const std::string&);
    std::vector<std::vector<float>>& getX();
    std::vector<int>& getY();
    std::vector<std::pair<std::string, std::string>> getAttributes() const;
    static std::vector<int> factorize(const std::vector<std::string>& labels_t);
};

#endif
```

lib/Files/CMakeLists.txt deleted (1 line):

```cmake
add_library(ArffFiles ArffFiles.cc)
```
Submodule lib/PyClassifiers deleted from 0608c0a52a
Submodule lib/argparse updated: 1b3abd9b92...cbd9fd8ed6
Submodule lib/catch2 updated: ed6ac8a629...0321d2fce3
Submodule lib/folding added at 2ac43e32ac
Submodule lib/json updated: 0457de21cf...620034ecec
Submodule lib/libxlsxwriter updated: b0c76b3396...8206bda64a
Submodule lib/mdlp updated: 5708dc3de9...cfb993f5ec
sample/CMakeLists.txt

```diff
@@ -1,15 +1,15 @@
 include_directories(
     ${Platform_SOURCE_DIR}/src/common
     ${Platform_SOURCE_DIR}/src/main
-    ${Platform_SOURCE_DIR}/lib/PyClassifiers/src
     ${Python3_INCLUDE_DIRS}
     ${Platform_SOURCE_DIR}/lib/Files
+    ${Platform_SOURCE_DIR}/lib/mdlp/src
     ${Platform_SOURCE_DIR}/lib/argparse/include
-    ${Platform_SOURCE_DIR}/lib/PyClassifiers/lib/BayesNet/src
-    ${Platform_SOURCE_DIR}/lib/PyClassifiers/lib/BayesNet/lib/folding
-    ${Platform_SOURCE_DIR}/lib/PyClassifiers/lib/BayesNet/lib/mdlp
-    ${Platform_SOURCE_DIR}/lib/PyClassifiers/lib/BayesNet/lib/json/include
+    ${Platform_SOURCE_DIR}/lib/folding
+    ${Platform_SOURCE_DIR}/lib/json/include
     ${CMAKE_BINARY_DIR}/configured_files/include
+    ${PyClassifiers_INCLUDE_DIRS}
+    ${Bayesnet_INCLUDE_DIRS}
 )
-add_executable(PlatformSample sample.cc ${Platform_SOURCE_DIR}/src/main/Models.cc)
-target_link_libraries(PlatformSample PyClassifiers ArffFiles mdlp "${TORCH_LIBRARIES}")
+add_executable(PlatformSample sample.cpp ${Platform_SOURCE_DIR}/src/main/Models.cpp)
+target_link_libraries(PlatformSample "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy)
```
sample/sample.cc: deleted (240 lines). Its contents were carried over to the new sample/sample.cpp shown in full below; apart from the rename, the differences are the updated include set (<ArffFiles.hpp>, <fimdlp/CPPFImdlp.h>, <folding.hpp>, <bayesnet/utils/BayesMetrics.h>, and "config_platform.h" in place of the old in-tree headers and "config.h"), data_path renamed to platform_data_path, tostring() renamed to toString(), and the fit() calls gaining a bayesnet::Smoothing_t argument.
241
sample/sample.cpp
Normal file
241
sample/sample.cpp
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
#include <fstream>
|
||||||
|
#include <torch/torch.h>
|
||||||
|
#include <argparse/argparse.hpp>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
#include <ArffFiles.hpp>
|
||||||
|
#include <fimdlp/CPPFImdlp.h>
|
||||||
|
#include <folding.hpp>
|
||||||
|
#include <bayesnet/utils/BayesMetrics.h>
|
||||||
|
#include "Models.h"
|
||||||
|
#include "modelRegister.h"
|
||||||
|
#include "config_platform.h"
|
||||||
|
|
||||||
|
const std::string PATH = { platform_data_path.begin(), platform_data_path.end() };
|
||||||
|
|
||||||
|
pair<std::vector<mdlp::labels_t>, map<std::string, int>> discretize(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y, std::vector<std::string> features)
|
||||||
|
{
|
||||||
|
std::vector<mdlp::labels_t>Xd;
|
||||||
|
map<std::string, int> maxes;
|
||||||
|
|
||||||
|
auto fimdlp = mdlp::CPPFImdlp();
|
||||||
|
for (int i = 0; i < X.size(); i++) {
|
||||||
|
fimdlp.fit(X[i], y);
|
||||||
|
mdlp::labels_t& xd = fimdlp.transform(X[i]);
|
||||||
|
maxes[features[i]] = *max_element(xd.begin(), xd.end()) + 1;
|
||||||
|
Xd.push_back(xd);
|
||||||
|
}
|
||||||
|
return { Xd, maxes };
|
||||||
|
}
|
||||||
|
|
||||||
|
bool file_exists(const std::string& name)
|
||||||
|
{
|
||||||
|
if (FILE* file = fopen(name.c_str(), "r")) {
|
||||||
|
fclose(file);
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pair<std::vector<std::vector<int>>, std::vector<int>> extract_indices(std::vector<int> indices, std::vector<std::vector<int>> X, std::vector<int> y)
|
||||||
|
{
|
||||||
|
std::vector<std::vector<int>> Xr; // nxm
|
||||||
|
std::vector<int> yr;
|
||||||
|
for (int col = 0; col < X.size(); ++col) {
|
||||||
|
Xr.push_back(std::vector<int>());
|
||||||
|
}
|
||||||
|
for (auto index : indices) {
|
||||||
|
for (int col = 0; col < X.size(); ++col) {
|
||||||
|
Xr[col].push_back(X[col][index]);
|
||||||
|
}
|
||||||
|
yr.push_back(y[index]);
|
||||||
|
}
|
||||||
|
return { Xr, yr };
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char** argv)
|
||||||
|
{
|
||||||
|
map<std::string, bool> datasets = {
|
||||||
|
{"diabetes", true},
|
||||||
|
{"ecoli", true},
|
||||||
|
{"glass", true},
|
||||||
|
{"iris", true},
|
||||||
|
{"kdd_JapaneseVowels", false},
|
||||||
|
{"letter", true},
|
||||||
|
{"liver-disorders", true},
|
||||||
|
{"mfeat-factors", true},
|
||||||
|
};
|
||||||
|
auto valid_datasets = std::vector<std::string>();
|
||||||
|
transform(datasets.begin(), datasets.end(), back_inserter(valid_datasets),
|
||||||
|
[](const pair<std::string, bool>& pair) { return pair.first; });
|
||||||
|
argparse::ArgumentParser program("PlatformSample");
|
||||||
|
program.add_argument("-d", "--dataset")
|
||||||
|
.help("Dataset file name")
|
||||||
|
.action([valid_datasets](const std::string& value) {
|
||||||
|
if (find(valid_datasets.begin(), valid_datasets.end(), value) != valid_datasets.end()) {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
throw runtime_error("file must be one of {diabetes, ecoli, glass, iris, kdd_JapaneseVowels, letter, liver-disorders, mfeat-factors}");
|
||||||
|
}
|
||||||
|
);
|
||||||
|
program.add_argument("-p", "--path")
|
||||||
|
.help(" folder where the data files are located, default")
|
||||||
|
.default_value(std::string{ PATH }
|
||||||
|
);
|
||||||
|
program.add_argument("-m", "--model")
|
||||||
|
.help("Model to use " + platform::Models::instance()->toString())
|
||||||
|
.action([](const std::string& value) {
|
||||||
|
static const std::vector<std::string> choices = platform::Models::instance()->getNames();
|
||||||
|
if (find(choices.begin(), choices.end(), value) != choices.end()) {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
throw runtime_error("Model must be one of " + platform::Models::instance()->toString());
|
||||||
|
}
|
||||||
|
);
|
||||||
|
program.add_argument("--discretize").help("Discretize input dataset").default_value(false).implicit_value(true);
|
||||||
|
program.add_argument("--dumpcpt").help("Dump CPT Tables").default_value(false).implicit_value(true);
|
||||||
|
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value(false).implicit_value(true);
|
||||||
|
program.add_argument("--tensors").help("Use tensors to store samples").default_value(false).implicit_value(true);
|
||||||
|
program.add_argument("-f", "--folds").help("Number of folds").default_value(5).scan<'i', int>().action([](const std::string& value) {
|
||||||
|
try {
|
||||||
|
auto k = stoi(value);
|
||||||
|
if (k < 2) {
|
||||||
|
throw runtime_error("Number of folds must be greater than 1");
|
||||||
|
}
|
||||||
|
return k;
|
||||||
|
}
|
||||||
|
catch (const runtime_error& err) {
|
||||||
|
throw runtime_error(err.what());
|
||||||
|
}
|
||||||
|
catch (...) {
|
||||||
|
throw runtime_error("Number of folds must be an integer");
|
||||||
|
}});
|
||||||
|
program.add_argument("-s", "--seed").help("Random seed").default_value(-1).scan<'i', int>();
|
||||||
|
bool class_last, stratified, tensors, dump_cpt;
|
||||||
|
std::string model_name, file_name, path, complete_file_name;
|
||||||
|
int nFolds, seed;
|
||||||
|
try {
|
||||||
|
program.parse_args(argc, argv);
|
||||||
|
file_name = program.get<std::string>("dataset");
|
||||||
|
path = program.get<std::string>("path");
|
||||||
|
model_name = program.get<std::string>("model");
|
||||||
|
complete_file_name = path + file_name + ".arff";
|
||||||
|
stratified = program.get<bool>("stratified");
|
||||||
|
tensors = program.get<bool>("tensors");
|
||||||
|
nFolds = program.get<int>("folds");
|
+        seed = program.get<int>("seed");
+        dump_cpt = program.get<bool>("dumpcpt");
+        class_last = datasets[file_name];
+        if (!file_exists(complete_file_name)) {
+            throw runtime_error("Data File " + path + file_name + ".arff" + " does not exist");
+        }
+    }
+    catch (const exception& err) {
+        cerr << err.what() << std::endl;
+        cerr << program;
+        exit(1);
+    }
+
+    /*
+    * Begin Processing
+    */
+    auto handler = ArffFiles();
+    handler.load(complete_file_name, class_last);
+    // Get Dataset X, y
+    std::vector<mdlp::samples_t>& X = handler.getX();
+    mdlp::labels_t& y = handler.getY();
+    // Get className & Features
+    auto className = handler.getClassName();
+    std::vector<std::string> features;
+    auto attributes = handler.getAttributes();
+    transform(attributes.begin(), attributes.end(), back_inserter(features),
+        [](const pair<std::string, std::string>& item) { return item.first; });
+    // Discretize Dataset
+    auto [Xd, maxes] = discretize(X, y, features);
+    maxes[className] = *max_element(y.begin(), y.end()) + 1;
+    map<std::string, std::vector<int>> states;
+    for (auto feature : features) {
+        states[feature] = std::vector<int>(maxes[feature]);
+    }
+    states[className] = std::vector<int>(maxes[className]);
+    auto clf = platform::Models::instance()->create(model_name);
+    bayesnet::Smoothing_t smoothing = bayesnet::Smoothing_t::ORIGINAL;
+    clf->fit(Xd, y, features, className, states, smoothing);
+    if (dump_cpt) {
+        std::cout << "--- CPT Tables ---" << std::endl;
+        clf->dump_cpt();
+    }
+    auto lines = clf->show();
+    for (auto line : lines) {
+        std::cout << line << std::endl;
+    }
+    std::cout << "--- Topological Order ---" << std::endl;
+    auto order = clf->topological_order();
+    for (auto name : order) {
+        std::cout << name << ", ";
+    }
+    std::cout << "end." << std::endl;
+    auto score = clf->score(Xd, y);
+    std::cout << "Score: " << score << std::endl;
+    auto graph = clf->graph();
+    auto dot_file = model_name + "_" + file_name;
+    ofstream file(dot_file + ".dot");
+    file << graph;
+    file.close();
+    std::cout << "Graph saved in " << model_name << "_" << file_name << ".dot" << std::endl;
+    std::cout << "dot -Tpng -o " + dot_file + ".png " + dot_file + ".dot " << std::endl;
+    std::string stratified_string = stratified ? " Stratified" : "";
+    std::cout << nFolds << " Folds" << stratified_string << " Cross validation" << std::endl;
+    std::cout << "==========================================" << std::endl;
+    torch::Tensor Xt = torch::zeros({ static_cast<int>(Xd.size()), static_cast<int>(Xd[0].size()) }, torch::kInt32);
+    torch::Tensor yt = torch::tensor(y, torch::kInt32);
+    for (int i = 0; i < features.size(); ++i) {
+        Xt.index_put_({ i, "..." }, torch::tensor(Xd[i], torch::kInt32));
+    }
+    float total_score = 0, total_score_train = 0, score_train, score_test;
+    folding::Fold* fold;
+    double nodes = 0.0;
+    if (stratified)
+        fold = new folding::StratifiedKFold(nFolds, y, seed);
+    else
+        fold = new folding::KFold(nFolds, y.size(), seed);
+    for (auto i = 0; i < nFolds; ++i) {
+        auto [train, test] = fold->getFold(i);
+        std::cout << "Fold: " << i + 1 << std::endl;
+        if (tensors) {
+            auto ttrain = torch::tensor(train, torch::kInt64);
+            auto ttest = torch::tensor(test, torch::kInt64);
+            torch::Tensor Xtraint = torch::index_select(Xt, 1, ttrain);
+            torch::Tensor ytraint = yt.index({ ttrain });
+            torch::Tensor Xtestt = torch::index_select(Xt, 1, ttest);
+            torch::Tensor ytestt = yt.index({ ttest });
+            clf->fit(Xtraint, ytraint, features, className, states, smoothing);
+            auto temp = clf->predict(Xtraint);
+            score_train = clf->score(Xtraint, ytraint);
+            score_test = clf->score(Xtestt, ytestt);
+        } else {
+            auto [Xtrain, ytrain] = extract_indices(train, Xd, y);
+            auto [Xtest, ytest] = extract_indices(test, Xd, y);
+            clf->fit(Xtrain, ytrain, features, className, states, smoothing);
+            std::cout << "Nodes: " << clf->getNumberOfNodes() << std::endl;
+            nodes += clf->getNumberOfNodes();
+            score_train = clf->score(Xtrain, ytrain);
+            score_test = clf->score(Xtest, ytest);
+        }
+        if (dump_cpt) {
+            std::cout << "--- CPT Tables ---" << std::endl;
+            clf->dump_cpt();
+        }
+        total_score_train += score_train;
+        total_score += score_test;
+        std::cout << "Score Train: " << score_train << std::endl;
+        std::cout << "Score Test : " << score_test << std::endl;
+        std::cout << "-------------------------------------------------------------------------------" << std::endl;
+    }
+    std::cout << "Nodes: " << nodes / nFolds << std::endl;
+    std::cout << "**********************************************************************************" << std::endl;
+    std::cout << "Average Score Train: " << total_score_train / nFolds << std::endl;
+    std::cout << "Average Score Test : " << total_score / nFolds << std::endl;
+    return 0;
+}
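The else branch of the fold loop above calls an extract_indices helper that this hunk does not show. A minimal sketch of what such a helper could look like — the name and signature are inferred from the call sites, the feature-major layout (Xd[feature][sample]) from the tensor branch, and the body is an assumption, not the project's actual code:

    #include <utility>
    #include <vector>

    // Hypothetical helper: gather the samples listed in `indices` from a
    // feature-major dataset (Xd[f][s]) together with their labels.
    static std::pair<std::vector<std::vector<int>>, std::vector<int>>
    extract_indices(const std::vector<int>& indices,
                    const std::vector<std::vector<int>>& Xd,
                    const std::vector<int>& y)
    {
        std::vector<std::vector<int>> Xr(Xd.size());
        std::vector<int> yr;
        yr.reserve(indices.size());
        for (int idx : indices) {
            for (size_t f = 0; f < Xd.size(); ++f) {
                Xr[f].push_back(Xd[f][idx]);   // copy sample idx of every feature row
            }
            yr.push_back(y[idx]);
        }
        return { Xr, yr };
    }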
@@ -1,53 +1,69 @@
 include_directories(
     ## Libs
-    ${Platform_SOURCE_DIR}/lib/PyClassifiers/lib/BayesNet/src
-    ${Platform_SOURCE_DIR}/lib/PyClassifiers/lib/BayesNet/lib/folding
-    ${Platform_SOURCE_DIR}/lib/PyClassifiers/lib/BayesNet/lib/mdlp
-    ${Platform_SOURCE_DIR}/lib/PyClassifiers/lib/BayesNet/lib/json/include
-    ${Platform_SOURCE_DIR}/lib/PyClassifiers/src
     ${Platform_SOURCE_DIR}/lib/Files
-    ${Platform_SOURCE_DIR}/lib/mdlp
+    ${Platform_SOURCE_DIR}/lib/folding
+    ${Platform_SOURCE_DIR}/lib/mdlp/src
     ${Platform_SOURCE_DIR}/lib/argparse/include
     ${Platform_SOURCE_DIR}/lib/json/include
     ${Platform_SOURCE_DIR}/lib/libxlsxwriter/include
     ${Python3_INCLUDE_DIRS}
     ${MPI_CXX_INCLUDE_DIRS}
+    ${TORCH_INCLUDE_DIRS}
     ${CMAKE_BINARY_DIR}/configured_files/include
+    ${PyClassifiers_INCLUDE_DIRS}
+    ${Bayesnet_INCLUDE_DIRS}
     ## Platform
-    ${Platform_SOURCE_DIR}/src/common
-    ${Platform_SOURCE_DIR}/src/best
-    ${Platform_SOURCE_DIR}/src/grid
-    ${Platform_SOURCE_DIR}/src/main
-    ${Platform_SOURCE_DIR}/src/manage
-    ${Platform_SOURCE_DIR}/src/reports
+    ${Platform_SOURCE_DIR}/src
+    ${Platform_SOURCE_DIR}/results
 )
 
 # b_best
-set(best_sources b_best.cc BestResults.cc Statistics.cc BestResultsExcel.cc)
-list(TRANSFORM best_sources PREPEND best/)
-add_executable(b_best ${best_sources} main/Result.cc reports/ReportExcel.cc reports/ReportBase.cc reports/ExcelFile.cc common/Datasets.cc common/Dataset.cc)
-target_link_libraries(b_best Boost::boost "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp)
+add_executable(
+    b_best commands/b_best.cpp best/Statistics.cpp
+    best/BestResultsExcel.cpp best/BestResultsTex.cpp best/BestResultsMd.cpp best/BestResults.cpp
+    common/Datasets.cpp common/Dataset.cpp common/Discretization.cpp
+    main/Models.cpp main/Scores.cpp
+    reports/ReportExcel.cpp reports/ReportBase.cpp reports/ExcelFile.cpp
+    results/Result.cpp
+)
+target_link_libraries(b_best Boost::boost "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy "${XLSXWRITER_LIB}")
 
 # b_grid
-set(grid_sources b_grid.cc GridSearch.cc GridData.cc)
+set(grid_sources GridSearch.cpp GridData.cpp)
 list(TRANSFORM grid_sources PREPEND grid/)
-add_executable(b_grid ${grid_sources} main/HyperParameters.cc main/Models.cc common/Datasets.cc common/Dataset.cc)
-target_link_libraries(b_grid PyClassifiers ${MPI_CXX_LIBRARIES} ArffFiles)
+add_executable(b_grid commands/b_grid.cpp ${grid_sources}
+    common/Datasets.cpp common/Dataset.cpp common/Discretization.cpp
+    main/HyperParameters.cpp main/Models.cpp
+)
+target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy)
 
 # b_list
-set(list_sources b_list.cc DatasetsExcel.cc)
-list(TRANSFORM list_sources PREPEND list/)
-add_executable(b_list ${list_sources} common/Datasets.cc common/Dataset.cc reports/ReportExcel.cc reports/ExcelFile.cc reports/ReportBase.cc)
-target_link_libraries(b_list "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp)
+add_executable(b_list commands/b_list.cpp
+    common/Datasets.cpp common/Dataset.cpp common/Discretization.cpp
+    main/Models.cpp main/Scores.cpp
+    reports/ReportExcel.cpp reports/ExcelFile.cpp reports/ReportBase.cpp reports/DatasetsExcel.cpp reports/DatasetsConsole.cpp reports/ReportsPaged.cpp
+    results/Result.cpp results/ResultsDatasetExcel.cpp results/ResultsDataset.cpp results/ResultsDatasetConsole.cpp
+)
+target_link_libraries(b_list "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy "${XLSXWRITER_LIB}")
 
 # b_main
-set(main_sources b_main.cc Experiment.cc Models.cc HyperParameters.cc)
+set(main_sources Experiment.cpp Models.cpp HyperParameters.cpp Scores.cpp)
 list(TRANSFORM main_sources PREPEND main/)
-add_executable(b_main ${main_sources} common/Datasets.cc common/Dataset.cc reports/ReportConsole.cc reports/ReportBase.cc main/Result.cc)
-target_link_libraries(b_main PyClassifiers BayesNet ArffFiles mdlp)
+add_executable(b_main commands/b_main.cpp ${main_sources}
+    common/Datasets.cpp common/Dataset.cpp common/Discretization.cpp
+    reports/ReportConsole.cpp reports/ReportBase.cpp
+    results/Result.cpp
+)
+target_link_libraries(b_main "${PyClassifiers}" "${BayesNet}" fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy)
 
 # b_manage
-set(manage_sources b_manage.cc ManageResults.cc CommandParser.cc Results.cc)
+set(manage_sources ManageScreen.cpp OptionsMenu.cpp ResultsManager.cpp)
 list(TRANSFORM manage_sources PREPEND manage/)
-add_executable(b_manage ${manage_sources} main/Result.cc reports/ReportConsole.cc reports/ReportExcel.cc reports/ReportBase.cc reports/ExcelFile.cc common/Datasets.cc common/Dataset.cc)
-target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp)
+add_executable(
+    b_manage commands/b_manage.cpp ${manage_sources}
+    common/Datasets.cpp common/Dataset.cpp common/Discretization.cpp
+    reports/ReportConsole.cpp reports/ReportExcel.cpp reports/ReportExcelCompared.cpp reports/ReportBase.cpp reports/ExcelFile.cpp reports/DatasetsConsole.cpp reports/ReportsPaged.cpp
+    results/Result.cpp results/ResultsDataset.cpp results/ResultsDatasetConsole.cpp
+    main/Scores.cpp
+)
+target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" fimdlp "${BayesNet}")
@@ -4,12 +4,16 @@
 #include <iostream>
 #include <sstream>
 #include <algorithm>
-#include "BestResults.h"
-#include "Result.h"
-#include "Colors.h"
-#include "Statistics.h"
+#include "common/Colors.h"
+#include "common/CLocale.h"
+#include "common/Paths.h"
+#include "common/Utils.h" // compute_std
+#include "results/Result.h"
 #include "BestResultsExcel.h"
-#include "CLocale.h"
+#include "BestResultsTex.h"
+#include "BestResultsMd.h"
+#include "best/Statistics.h"
+#include "BestResults.h"
 
 
 namespace fs = std::filesystem;
@@ -42,26 +46,29 @@ namespace platform {
         for (auto const& item : data.at("results")) {
             bool update = true;
             auto datasetName = item.at("dataset").get<std::string>();
+            if (dataset != "any" && dataset != datasetName) {
+                continue;
+            }
             if (bests.contains(datasetName)) {
                 if (item.at("score").get<double>() < bests[datasetName].at(0).get<double>()) {
                     update = false;
                 }
             }
             if (update) {
-                bests[datasetName] = { item.at("score").get<double>(), item.at("hyperparameters"), file };
+                bests[datasetName] = { item.at("score").get<double>(), item.at("hyperparameters"), file, item.at("score_std").get<double>() };
             }
         }
     }
-    std::string bestFileName = path + bestResultFile();
+    if (bests.empty()) {
+        std::cerr << Colors::MAGENTA() << "No results found for model " << model << " and score " << score << Colors::RESET() << std::endl;
+        exit(1);
+    }
+    std::string bestFileName = path + Paths::bestResultsFile(score, model);
     std::ofstream file(bestFileName);
     file << bests;
     file.close();
     return bestFileName;
 }
-std::string BestResults::bestResultFile()
-{
-    return "best_results_" + score + "_" + model + ".json";
-}
 std::pair<std::string, std::string> getModelScore(std::string name)
 {
     // results_accuracy_BoostAODE_MacBookpro16_2023-09-06_12:27:00_1.json
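With the change above, each entry of the bests JSON object is now a four-element array instead of three. A short illustration of how one entry would be read back, assuming a hypothetical dataset key "iris" (the index layout follows the initializer in build() and the .at(3) reads added later in this changeset):

    json entry = bests["iris"];
    double best_score = entry.at(0).get<double>();  // mean score
    json hyperparameters = entry.at(1);             // hyperparameters used
    std::string result_file = entry.at(2);          // result file it came from
    double score_std = entry.at(3).get<double>();   // new: score standard deviation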
@@ -122,8 +129,8 @@ namespace platform {
 std::vector<std::string> BestResults::getDatasets(json table)
 {
     std::vector<std::string> datasets;
-    for (const auto& dataset : table.items()) {
-        datasets.push_back(dataset.key());
+    for (const auto& dataset_ : table.items()) {
+        datasets.push_back(dataset_.key());
     }
     maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size();
     maxDatasetName = std::max(7, maxDatasetName);
@@ -143,7 +150,7 @@ namespace platform {
 }
 void BestResults::listFile()
 {
-    std::string bestFileName = path + bestResultFile();
+    std::string bestFileName = path + Paths::bestResultsFile(score, model);
     if (FILE* fileTest = fopen(bestFileName.c_str(), "r")) {
         fclose(fileTest);
     } else {
@@ -167,10 +174,9 @@ namespace platform {
     std::cout << Colors::GREEN() << " # " << std::setw(maxDatasetName + 1) << std::left << "Dataset" << "Score " << std::setw(maxFileName) << "File" << " Hyperparameters" << std::endl;
     std::cout << "=== " << std::string(maxDatasetName, '=') << " =========== " << std::string(maxFileName, '=') << " " << std::string(maxHyper, '=') << std::endl;
     auto i = 0;
-    bool odd = true;
     double total = 0;
     for (auto const& item : data.items()) {
-        auto color = odd ? Colors::BLUE() : Colors::CYAN();
+        auto color = (i % 2) ? Colors::BLUE() : Colors::CYAN();
         double value = item.value().at(0).get<double>();
         std::cout << color << std::setw(3) << std::fixed << std::right << i++ << " ";
         std::cout << std::setw(maxDatasetName) << std::left << item.key() << " ";
@@ -179,7 +185,6 @@ namespace platform {
         std::cout << item.value().at(1) << " ";
         std::cout << std::endl;
         total += value;
-        odd = !odd;
     }
     std::cout << Colors::GREEN() << "=== " << std::string(maxDatasetName, '=') << " ===========" << std::endl;
     std::cout << Colors::GREEN() << " Total" << std::string(maxDatasetName - 5, '.') << " " << std::setw(11) << std::setprecision(8) << std::fixed << total << std::endl;
@@ -191,7 +196,7 @@ namespace platform {
     auto maxDate = std::filesystem::file_time_type::max();
     for (const auto& model : models) {
         this->model = model;
-        std::string bestFileName = path + bestResultFile();
+        std::string bestFileName = path + Paths::bestResultsFile(score, model);
         if (FILE* fileTest = fopen(bestFileName.c_str(), "r")) {
             fclose(fileTest);
         } else {
@@ -208,13 +213,20 @@ namespace platform {
     table["dateTable"] = ftime_to_string(maxDate);
     return table;
 }
-void BestResults::printTableResults(std::vector<std::string> models, json table)
+void BestResults::printTableResults(std::vector<std::string> models, json table, bool tex)
 {
     std::stringstream oss;
     oss << Colors::GREEN() << "Best results for " << score << " as of " << table.at("dateTable").get<std::string>() << std::endl;
     std::cout << oss.str();
     std::cout << std::string(oss.str().size() - 8, '-') << std::endl;
     std::cout << Colors::GREEN() << " # " << std::setw(maxDatasetName + 1) << std::left << std::string("Dataset");
+    auto bestResultsTex = BestResultsTex();
+    auto bestResultsMd = BestResultsMd();
+    if (tex) {
+        bestResultsTex.results_header(models, table.at("dateTable").get<std::string>());
+        bestResultsMd.results_header(models, table.at("dateTable").get<std::string>());
+    }
     for (const auto& model : models) {
         std::cout << std::setw(maxModelName) << std::left << model << " ";
     }
@@ -225,23 +237,23 @@ namespace platform {
     }
     std::cout << std::endl;
     auto i = 0;
-    bool odd = true;
-    std::map<std::string, double> totals;
+    std::map<std::string, std::vector<double>> totals;
     int nDatasets = table.begin().value().size();
-    for (const auto& model : models) {
-        totals[model] = 0.0;
-    }
     auto datasets = getDatasets(table.begin().value());
-    for (auto const& dataset : datasets) {
-        auto color = odd ? Colors::BLUE() : Colors::CYAN();
+    if (tex) {
+        bestResultsTex.results_body(datasets, table);
+        bestResultsMd.results_body(datasets, table);
+    }
+    for (auto const& dataset_ : datasets) {
+        auto color = (i % 2) ? Colors::BLUE() : Colors::CYAN();
         std::cout << color << std::setw(3) << std::fixed << std::right << i++ << " ";
-        std::cout << std::setw(maxDatasetName) << std::left << dataset << " ";
+        std::cout << std::setw(maxDatasetName) << std::left << dataset_ << " ";
         double maxValue = 0;
         // Find out the max value for this dataset
         for (const auto& model : models) {
             double value;
             try {
-                value = table[model].at(dataset).at(0).get<double>();
+                value = table[model].at(dataset_).at(0).get<double>();
             }
             catch (nlohmann::json_abi_v3_11_3::detail::out_of_range err) {
                 value = -1.0;
@@ -255,7 +267,7 @@ namespace platform {
             std::string efectiveColor = color;
             double value;
             try {
-                value = table[model].at(dataset).at(0).get<double>();
+                value = table[model].at(dataset_).at(0).get<double>();
             }
             catch (nlohmann::json_abi_v3_11_3::detail::out_of_range err) {
                 value = -1.0;
@@ -266,31 +278,37 @@ namespace platform {
             if (value == -1) {
                 std::cout << Colors::YELLOW() << std::setw(maxModelName) << std::right << "N/A" << " ";
             } else {
-                totals[model] += value;
+                totals[model].push_back(value);
                 std::cout << efectiveColor << std::setw(maxModelName) << std::setprecision(maxModelName - 2) << std::fixed << value << " ";
             }
         }
         std::cout << std::endl;
-        odd = !odd;
     }
     std::cout << Colors::GREEN() << "=== " << std::string(maxDatasetName, '=') << " ";
     for (const auto& model : models) {
         std::cout << std::string(maxModelName, '=') << " ";
     }
     std::cout << std::endl;
-    std::cout << Colors::GREEN() << " Totals" << std::string(maxDatasetName - 6, '.') << " ";
+    std::cout << Colors::GREEN() << " Average" << std::string(maxDatasetName - 7, '.') << " ";
     double max_value = 0.0;
+    std::string best_model = "";
     for (const auto& total : totals) {
-        if (total.second > max_value) {
-            max_value = total.second;
+        auto actual = std::reduce(total.second.begin(), total.second.end());
+        if (actual > max_value) {
+            max_value = actual;
+            best_model = total.first;
         }
     }
+    if (tex) {
+        bestResultsTex.results_footer(totals, best_model);
+        bestResultsMd.results_footer(totals, best_model);
+    }
     for (const auto& model : models) {
-        std::string efectiveColor = Colors::GREEN();
-        if (totals[model] == max_value) {
-            efectiveColor = Colors::RED();
-        }
-        std::cout << efectiveColor << std::right << std::setw(maxModelName) << std::setprecision(maxModelName - 4) << std::fixed << totals[model] << " ";
+        std::string efectiveColor = model == best_model ? Colors::RED() : Colors::GREEN();
+        double value = std::reduce(totals[model].begin(), totals[model].end()) / nDatasets;
+        double std_value = compute_std(totals[model], value);
+        std::cout << efectiveColor << std::right << std::setw(maxModelName) << std::setprecision(maxModelName - 4) << std::fixed << value << " ";
     }
     std::cout << std::endl;
 }
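The averages row now divides the accumulated per-dataset scores by nDatasets and calls compute_std, pulled in from common/Utils.h, which this diff does not show. A sketch consistent with its call sites — compute_std(values, mean) returning a standard deviation — assuming the population form (dividing by N); the real helper may use the sample form instead:

    #include <cmath>
    #include <vector>

    // Hypothetical stand-in for compute_std from common/Utils.h: standard
    // deviation of `values` around a precomputed `mean`.
    double compute_std(const std::vector<double>& values, double mean)
    {
        double acc = 0.0;
        for (double v : values) {
            acc += (v - mean) * (v - mean);
        }
        return std::sqrt(acc / values.size());
    }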
@@ -303,26 +321,34 @@ namespace platform {
         json table = buildTableResults(models);
         std::vector<std::string> datasets = getDatasets(table.begin().value());
         BestResultsExcel excel_report(score, datasets);
-        excel_report.reportSingle(model, path + bestResultFile());
-        messageExcelFile(excel_report.getFileName());
+        excel_report.reportSingle(model, path + Paths::bestResultsFile(score, model));
+        messageOutputFile("Excel", excel_report.getFileName());
     }
 }
-void BestResults::reportAll(bool excel)
+void BestResults::reportAll(bool excel, bool tex)
 {
     auto models = getModels();
     // Build the table of results
     json table = buildTableResults(models);
     std::vector<std::string> datasets = getDatasets(table.begin().value());
     // Print the table of results
-    printTableResults(models, table);
+    printTableResults(models, table, tex);
     // Compute the Friedman test
     std::map<std::string, std::map<std::string, float>> ranksModels;
     if (friedman) {
         Statistics stats(models, datasets, table, significance);
         auto result = stats.friedmanTest();
-        stats.postHocHolmTest(result);
+        stats.postHocHolmTest(result, tex);
         ranksModels = stats.getRanks();
     }
+    if (tex) {
+        messageOutputFile("TeX", Paths::tex() + Paths::tex_output());
+        messageOutputFile("MarkDown", Paths::tex() + Paths::md_output());
+        if (friedman) {
+            messageOutputFile("TeX", Paths::tex() + Paths::tex_post_hoc());
+            messageOutputFile("MarkDown", Paths::tex() + Paths::md_post_hoc());
+        }
+    }
     if (excel) {
         BestResultsExcel excel(score, datasets);
         excel.reportAll(models, table, ranksModels, friedman, significance);
@@ -331,9 +357,9 @@ namespace platform {
             double min = 2000;
             // Find out the control model
             auto totals = std::vector<double>(models.size(), 0.0);
-            for (const auto& dataset : datasets) {
+            for (const auto& dataset_ : datasets) {
                 for (int i = 0; i < models.size(); ++i) {
-                    totals[i] += ranksModels[dataset][models[i]];
+                    totals[i] += ranksModels[dataset_][models[i]];
                 }
             }
             for (int i = 0; i < models.size(); ++i) {
@@ -343,13 +369,14 @@ namespace platform {
                 }
             }
             model = models.at(idx);
-            excel.reportSingle(model, path + bestResultFile());
+            excel.reportSingle(model, path + Paths::bestResultsFile(score, model));
         }
-        messageExcelFile(excel.getFileName());
+        messageOutputFile("Excel", excel.getFileName());
     }
 }
-void BestResults::messageExcelFile(const std::string& fileName)
+void BestResults::messageOutputFile(const std::string& title, const std::string& fileName)
 {
-    std::cout << Colors::YELLOW() << "** Excel file generated: " << fileName << Colors::RESET() << std::endl;
+    std::cout << Colors::YELLOW() << "** " << std::setw(5) << std::left << title
+        << " file generated: " << fileName << Colors::RESET() << std::endl;
 }
 }
@@ -2,35 +2,36 @@
 #define BESTRESULTS_H
 #include <string>
 #include <nlohmann/json.hpp>
-using json = nlohmann::json;
 namespace platform {
+    using json = nlohmann::ordered_json;
+
     class BestResults {
     public:
-        explicit BestResults(const std::string& path, const std::string& score, const std::string& model, bool friedman, double significance = 0.05)
-            : path(path), score(score), model(model), friedman(friedman), significance(significance)
+        explicit BestResults(const std::string& path, const std::string& score, const std::string& model, const std::string& dataset, bool friedman, double significance = 0.05)
+            : path(path), score(score), model(model), dataset(dataset), friedman(friedman), significance(significance)
         {
         }
         std::string build();
        void reportSingle(bool excel);
-        void reportAll(bool excel);
+        void reportAll(bool excel, bool tex);
        void buildAll();
    private:
        std::vector<std::string> getModels();
        std::vector<std::string> getDatasets(json table);
        std::vector<std::string> loadResultFiles();
-        void messageExcelFile(const std::string& fileName);
+        void messageOutputFile(const std::string& title, const std::string& fileName);
        json buildTableResults(std::vector<std::string> models);
-        void printTableResults(std::vector<std::string> models, json table);
-        std::string bestResultFile();
+        void printTableResults(std::vector<std::string> models, json table, bool tex);
        json loadFile(const std::string& fileName);
        void listFile();
        std::string path;
        std::string score;
        std::string model;
+        std::string dataset;
        bool friedman;
        double significance;
        int maxModelName = 0;
        int maxDatasetName = 0;
    };
 }
-#endif //BESTRESULTS_H
+#endif
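The json alias moves inside the namespace and switches to nlohmann::ordered_json. The difference matters for every table this class prints: plain nlohmann::json stores object keys alphabetically sorted, while ordered_json preserves insertion order, so datasets and models come out in the order they were added. A minimal illustration:

    #include <iostream>
    #include <nlohmann/json.hpp>

    int main()
    {
        nlohmann::json sorted;           // keys end up alphabetically sorted
        nlohmann::ordered_json ordered;  // keys keep their insertion order
        for (auto key : { "zoo", "iris", "wine" }) {
            sorted[key] = 0;
            ordered[key] = 0;
        }
        std::cout << sorted.dump() << std::endl;   // {"iris":0,"wine":0,"zoo":0}
        std::cout << ordered.dump() << std::endl;  // {"zoo":0,"iris":0,"wine":0}
    }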
@@ -1,10 +1,10 @@
 #include <sstream>
-#include "BestResultsExcel.h"
-#include "Paths.h"
 #include <map>
 #include <nlohmann/json.hpp>
-#include "Statistics.h"
-#include "ReportExcel.h"
+#include "common/Paths.h"
+#include "reports/ReportExcel.h"
+#include "best/Statistics.h"
+#include "BestResultsExcel.h"
 
 namespace platform {
 json loadResultData(const std::string& fileName)
@@ -32,7 +32,7 @@ namespace platform {
 }
 BestResultsExcel::BestResultsExcel(const std::string& score, const std::vector<std::string>& datasets) : score(score), datasets(datasets)
 {
-    file_name = "BestResults.xlsx";
+    file_name = Paths::bestResultsExcel(score);
     workbook = workbook_new(getFileName().c_str());
     setProperties("Best Results");
     int maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size();
@@ -64,19 +64,21 @@ namespace platform {
     json data = loadResultData(fileName);
 
     std::string title = "Best results for " + model;
-    worksheet_merge_range(worksheet, 0, 0, 0, 4, title.c_str(), styles["headerFirst"]);
+    worksheet_merge_range(worksheet, 0, 0, 0, 5, title.c_str(), styles["headerFirst"]);
     // Body header
     row = 3;
     int col = 1;
-    writeString(row, 0, "Nº", "bodyHeader");
+    writeString(row, 0, "#", "bodyHeader");
     writeString(row, 1, "Dataset", "bodyHeader");
     writeString(row, 2, "Score", "bodyHeader");
     writeString(row, 3, "File", "bodyHeader");
     writeString(row, 4, "Hyperparameters", "bodyHeader");
+    writeString(row, 5, "F", "bodyHeader");
     auto i = 0;
     std::string hyperparameters;
     int hypSize = 22;
     std::map<std::string, std::string> files; // map of files imported and their tabs
+    int numLines = data.size();
     for (auto const& item : data.items()) {
         row++;
         writeInt(row, 0, i++, "ints");
@@ -104,6 +106,8 @@ namespace platform {
             hypSize = hyperparameters.size();
         }
         writeString(row, 4, hyperparameters, "text");
+        std::string countHyperparameters = "=COUNTIF(e5:e" + std::to_string(numLines + 4) + ", e" + std::to_string(row + 1) + ")";
+        worksheet_write_formula(worksheet, row, 5, countHyperparameters.c_str(), efectiveStyle("ints"));
     }
     row++;
     // Set Totals
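The new column F holds a spreadsheet formula rather than a computed value. With libxlsxwriter's zero-based rows, the header sits at worksheet row 3 (spreadsheet row 4) and the first data row at worksheet row 4 (spreadsheet row 5), so the range e5:e(numLines + 4) spans exactly the hyperparameters column and e(row + 1) is the current row's own cell; the COUNTIF therefore shows, for each dataset, how many datasets ended up with the same best hyperparameter set. For instance, with 10 result lines the first data row would get =COUNTIF(e5:e14, e5).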
@@ -180,7 +184,7 @@ namespace platform {
     // Body header
     row = 3;
     int col = 1;
-    writeString(row, 0, "Nº", "bodyHeader");
+    writeString(row, 0, "#", "bodyHeader");
     writeString(row, 1, "Dataset", "bodyHeader");
     for (const auto& model : models) {
         writeString(row, ++col, model.c_str(), "bodyHeader");
@@ -1,14 +1,13 @@
-#ifndef BESTRESULTS_EXCEL_H
-#define BESTRESULTS_EXCEL_H
-#include "ExcelFile.h"
+#ifndef BESTRESULTSEXCEL_H
+#define BESTRESULTSEXCEL_H
 #include <vector>
 #include <map>
 #include <nlohmann/json.hpp>
-
-using json = nlohmann::json;
+#include "reports/ExcelFile.h"
 
 namespace platform {
+    using json = nlohmann::ordered_json;
     class BestResultsExcel : public ExcelFile {
     public:
         BestResultsExcel(const std::string& score, const std::vector<std::string>& datasets);
@@ -34,4 +33,4 @@ namespace platform {
         int datasetNameSize = 25; // Min size of the column
     };
 }
-#endif //BESTRESULTS_EXCEL_H
+#endif
src/best/BestResultsMd.cpp (new file, 103 lines)
@@ -0,0 +1,103 @@
+#include <iostream>
+#include "BestResultsMd.h"
+#include "common/Utils.h" // compute_std
+
+namespace platform {
+    using json = nlohmann::ordered_json;
+    void BestResultsMd::openMdFile(const std::string& name)
+    {
+        handler.open(name);
+        if (!handler.is_open()) {
+            std::cerr << "Error opening file " << name << std::endl;
+            exit(1);
+        }
+    }
+    void BestResultsMd::results_header(const std::vector<std::string>& models, const std::string& date)
+    {
+        this->models = models;
+        auto file_name = Paths::tex() + Paths::md_output();
+        openMdFile(file_name);
+        handler << "<!-- This file has been generated by the platform program" << std::endl;
+        handler << " Date: " << date.c_str() << std::endl;
+        handler << "" << std::endl;
+        handler << " Table of results" << std::endl;
+        handler << "-->" << std::endl;
+        handler << "| # | Dataset |";
+        for (const auto& model : models) {
+            handler << " " << model.c_str() << " |";
+        }
+        handler << std::endl;
+        handler << "|--: | :--- |";
+        for (const auto& model : models) {
+            handler << " :---: |";
+        }
+        handler << std::endl;
+    }
+    void BestResultsMd::results_body(const std::vector<std::string>& datasets, json& table)
+    {
+        int i = 0;
+        for (auto const& dataset : datasets) {
+            // Find out the max value for this dataset
+            double max_value = 0;
+            for (const auto& model : models) {
+                double value;
+                try {
+                    value = table[model].at(dataset).at(0).get<double>();
+                }
+                catch (nlohmann::json_abi_v3_11_3::detail::out_of_range err) {
+                    value = -1.0;
+                }
+                if (value > max_value) {
+                    max_value = value;
+                }
+            }
+            handler << "| " << ++i << " | " << dataset.c_str() << " | ";
+            for (const auto& model : models) {
+                double value = table[model].at(dataset).at(0).get<double>();
+                double std_value = table[model].at(dataset).at(3).get<double>();
+                const char* bold = value == max_value ? "**" : "";
+                handler << bold << std::setprecision(4) << std::fixed << value << "±" << std::setprecision(3) << std_value << bold << " | ";
+            }
+            handler << std::endl;
+        }
+    }
+    void BestResultsMd::results_footer(const std::map<std::string, std::vector<double>>& totals, const std::string& best_model)
+    {
+        handler << "| | **Average Score** | ";
+        int nDatasets = totals.begin()->second.size();
+        for (const auto& model : models) {
+            double value = std::reduce(totals.at(model).begin(), totals.at(model).end()) / nDatasets;
+            double std_value = compute_std(totals.at(model), value);
+            const char* bold = model == best_model ? "**" : "";
+            handler << bold << std::setprecision(4) << std::fixed << value << "±" << std::setprecision(3) << std::fixed << std_value << bold << " | ";
+        }
+
+        handler.close();
+    }
+    void BestResultsMd::holm_test(struct HolmResult& holmResult, const std::string& date)
+    {
+        auto file_name = Paths::tex() + Paths::md_post_hoc();
+        openMdFile(file_name);
+        handler << "<!-- This file has been generated by the platform program" << std::endl;
+        handler << " Date: " << date.c_str() << std::endl;
+        handler << std::endl;
+        handler << " Post-hoc handler test" << std::endl;
+        handler << "-->" << std::endl;
+        handler << "Post-hoc Holm test: H<sub>0</sub>: There is no significant differences between the control model and the other models." << std::endl << std::endl;
+        handler << "| classifier | pvalue | rank | win | tie | loss | H<sub>0</sub> |" << std::endl;
+        handler << "| :-- | --: | --: | --:| --: | --: | :--: |" << std::endl;
+        for (auto const& line : holmResult.holmLines) {
+            auto textStatus = !line.reject ? "**" : " ";
+            if (line.model == holmResult.model) {
+                handler << "| " << line.model << " | - | " << std::fixed << std::setprecision(2) << line.rank << " | - | - | - |" << std::endl;
+            } else {
+                handler << "| " << line.model << " | " << textStatus << std::scientific << std::setprecision(4) << line.pvalue << textStatus << " |";
+                handler << std::fixed << std::setprecision(2) << line.rank << " | " << line.wtl.win << " | " << line.wtl.tie << " | " << line.wtl.loss << " |";
+                handler << (line.reject ? "rejected" : "**accepted**") << " |" << std::endl;
+            }
+        }
+        handler << std::endl;
+        handler.close();
+    }
+}
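The three results_* methods are designed to be called in order while the caller walks the results table; BestResults::printTableResults drives them exactly this way. A hedged usage sketch — models, datasets, table, totals and best_model stand for values the caller already holds, and the date string is a placeholder:

    BestResultsMd md;
    md.results_header(models, "2024-01-01 12:00:00");  // opens the .md file, writes the table head
    md.results_body(datasets, table);                  // one Markdown row per dataset
    md.results_footer(totals, best_model);             // averages row, closes the file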
src/best/BestResultsMd.h (new file, 24 lines)
@@ -0,0 +1,24 @@
+#ifndef BEST_RESULTS_MD_H
+#define BEST_RESULTS_MD_H
+#include <map>
+#include <vector>
+#include <nlohmann/json.hpp>
+#include "common/Paths.h"
+#include "Statistics.h"
+namespace platform {
+    using json = nlohmann::ordered_json;
+    class BestResultsMd {
+    public:
+        BestResultsMd() = default;
+        ~BestResultsMd() = default;
+        void results_header(const std::vector<std::string>& models, const std::string& date);
+        void results_body(const std::vector<std::string>& datasets, json& table);
+        void results_footer(const std::map<std::string, std::vector<double>>& totals, const std::string& best_model);
+        void holm_test(struct HolmResult& holmResult, const std::string& date);
+    private:
+        void openMdFile(const std::string& name);
+        std::ofstream handler;
+        std::vector<std::string> models;
+    };
+}
+#endif
src/best/BestResultsTex.cpp (new file, 117 lines)
@@ -0,0 +1,117 @@
+#include <iostream>
+#include "BestResultsTex.h"
+#include "common/Utils.h" // compute_std
+
+namespace platform {
+    using json = nlohmann::ordered_json;
+    void BestResultsTex::openTexFile(const std::string& name)
+    {
+        handler.open(name);
+        if (!handler.is_open()) {
+            std::cerr << "Error opening file " << name << std::endl;
+            exit(1);
+        }
+    }
+    void BestResultsTex::results_header(const std::vector<std::string>& models, const std::string& date)
+    {
+        this->models = models;
+        auto file_name = Paths::tex() + Paths::tex_output();
+        openTexFile(file_name);
+        handler << "%% This file has been generated by the platform program" << std::endl;
+        handler << "%% Date: " << date.c_str() << std::endl;
+        handler << "%%" << std::endl;
+        handler << "%% Table of results" << std::endl;
+        handler << "%%" << std::endl;
+        handler << "\\begin{table}[htbp] " << std::endl;
+        handler << "\\centering " << std::endl;
+        handler << "\\tiny " << std::endl;
+        handler << "\\renewcommand{\\arraystretch }{1.2} " << std::endl;
+        handler << "\\renewcommand{\\tabcolsep }{0.07cm} " << std::endl;
+        handler << "\\caption{Accuracy results(mean $\\pm$ std) for all the algorithms and datasets} " << std::endl;
+        handler << "\\label{tab:results_accuracy}" << std::endl;
+        handler << "\\begin{tabular} {{r" << std::string(models.size(), 'c').c_str() << "}}" << std::endl;
+        handler << "\\hline " << std::endl;
+        handler << "" << std::endl;
+        for (const auto& model : models) {
+            handler << "& " << model.c_str();
+        }
+        handler << "\\\\" << std::endl;
+        handler << "\\hline" << std::endl;
+    }
+    void BestResultsTex::results_body(const std::vector<std::string>& datasets, json& table)
+    {
+        int i = 0;
+        for (auto const& dataset : datasets) {
+            // Find out the max value for this dataset
+            double max_value = 0;
+            for (const auto& model : models) {
+                double value;
+                try {
+                    value = table[model].at(dataset).at(0).get<double>();
+                }
+                catch (nlohmann::json_abi_v3_11_3::detail::out_of_range err) {
+                    value = -1.0;
+                }
+                if (value > max_value) {
+                    max_value = value;
+                }
+            }
+            handler << ++i << " ";
+            for (const auto& model : models) {
+                double value = table[model].at(dataset).at(0).get<double>();
+                double std_value = table[model].at(dataset).at(3).get<double>();
+                const char* bold = value == max_value ? "\\bfseries" : "";
+                handler << "& " << bold << std::setprecision(4) << std::fixed << value << "$\\pm$" << std::setprecision(3) << std_value;
+            }
+            handler << "\\\\" << std::endl;
+        }
+    }
+    void BestResultsTex::results_footer(const std::map<std::string, std::vector<double>>& totals, const std::string& best_model)
+    {
+        handler << "\\hline" << std::endl;
+        handler << "Average ";
+        int nDatasets = totals.begin()->second.size();
+        for (const auto& model : models) {
+            double value = std::reduce(totals.at(model).begin(), totals.at(model).end()) / nDatasets;
+            double std_value = compute_std(totals.at(model), value);
+            const char* bold = model == best_model ? "\\bfseries" : "";
+            handler << "& " << bold << std::setprecision(4) << std::fixed << value << "$\\pm$" << std::setprecision(3) << std::fixed << std_value;
+        }
+        handler << "\\\\" << std::endl;
+        handler << "\\hline " << std::endl;
+        handler << "\\end{tabular}" << std::endl;
+        handler << "\\end{table}" << std::endl;
+        handler.close();
+    }
+    void BestResultsTex::holm_test(struct HolmResult& holmResult, const std::string& date)
+    {
+        auto file_name = Paths::tex() + Paths::tex_post_hoc();
+        openTexFile(file_name);
+        handler << "%% This file has been generated by the platform program" << std::endl;
+        handler << "%% Date: " << date.c_str() << std::endl;
+        handler << "%%" << std::endl;
+        handler << "%% Post-hoc handler test" << std::endl;
+        handler << "%%" << std::endl;
+        handler << "\\begin{table}[htbp]" << std::endl;
+        handler << "\\centering" << std::endl;
+        handler << "\\caption{Results of the post-hoc test for the mean accuracy of the algorithms.}\\label{tab:tests}" << std::endl;
+        handler << "\\begin{tabular}{lrrrrr}" << std::endl;
+        handler << "\\hline" << std::endl;
+        handler << "classifier & pvalue & rank & win & tie & loss\\\\" << std::endl;
+        handler << "\\hline" << std::endl;
+        for (auto const& line : holmResult.holmLines) {
+            auto textStatus = !line.reject ? "\\bf " : " ";
+            if (line.model == holmResult.model) {
+                handler << line.model << " & - & " << std::fixed << std::setprecision(2) << line.rank << " & - & - & - \\\\" << std::endl;
+            } else {
+                handler << line.model << " & " << textStatus << std::scientific << std::setprecision(4) << line.pvalue << " & ";
+                handler << std::fixed << std::setprecision(2) << line.rank << " & " << line.wtl.win << " & " << line.wtl.tie << " & " << line.wtl.loss << "\\\\" << std::endl;
+            }
+        }
+        handler << "\\hline " << std::endl;
+        handler << "\\end{tabular}" << std::endl;
+        handler << "\\end{table}" << std::endl;
+        handler.close();
+    }
+}
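For orientation, the fragment those handler << lines assemble looks roughly like the following for two models (model names and numbers are placeholders for illustration; the doubled braces around the column spec come straight from the format string above):

    \begin{table}[htbp]
    \centering
    \tiny
    \renewcommand{\arraystretch }{1.2}
    \renewcommand{\tabcolsep }{0.07cm}
    \caption{Accuracy results(mean $\pm$ std) for all the algorithms and datasets}
    \label{tab:results_accuracy}
    \begin{tabular} {{rcc}}
    \hline
    & ModelA & ModelB\\
    \hline
    1 & 0.9533$\pm$0.021 & \bfseries 0.9667$\pm$0.018\\
    \hline
    Average & 0.8412$\pm$0.021 & \bfseries 0.8575$\pm$0.019\\
    \hline
    \end{tabular}
    \end{table}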
src/best/BestResultsTex.h (new file, 24 lines)
@@ -0,0 +1,24 @@
+#ifndef BEST_RESULTS_TEX_H
+#define BEST_RESULTS_TEX_H
+#include <map>
+#include <vector>
+#include <nlohmann/json.hpp>
+#include "common/Paths.h"
+#include "Statistics.h"
+namespace platform {
+    using json = nlohmann::ordered_json;
+    class BestResultsTex {
+    public:
+        BestResultsTex() = default;
+        ~BestResultsTex() = default;
+        void results_header(const std::vector<std::string>& models, const std::string& date);
+        void results_body(const std::vector<std::string>& datasets, json& table);
+        void results_footer(const std::map<std::string, std::vector<double>>& totals, const std::string& best_model);
+        void holm_test(struct HolmResult& holmResult, const std::string& date);
+    private:
+        void openTexFile(const std::string& name);
+        std::ofstream handler;
+        std::vector<std::string> models;
+    };
+}
+#endif
@@ -3,7 +3,7 @@
 #include <string>
 #include <map>
 #include <utility>
-#include "DotEnv.h"
+#include "common/DotEnv.h"
 namespace platform {
     class BestScore {
     public:
@@ -24,5 +24,4 @@ namespace platform {
     }
     };
 }
-
 #endif
@@ -1,10 +1,12 @@
 #include <sstream>
-#include "Statistics.h"
-#include "Colors.h"
-#include "Symbols.h"
 #include <boost/math/distributions/chi_squared.hpp>
 #include <boost/math/distributions/normal.hpp>
-#include "CLocale.h"
+#include "common/Colors.h"
+#include "common/Symbols.h"
+#include "common/CLocale.h"
+#include "BestResultsTex.h"
+#include "BestResultsMd.h"
+#include "Statistics.h"
 
 
 namespace platform {
@@ -113,7 +115,7 @@ namespace platform {
     }
 }
 
-void Statistics::postHocHolmTest(bool friedmanResult)
+void Statistics::postHocHolmTest(bool friedmanResult, bool tex)
 {
     if (!fitted) {
         fit();
@@ -130,7 +132,7 @@ namespace platform {
             stats[i] = 0.0;
             continue;
         }
-        double z = abs(ranks.at(models[controlIdx]) - ranks.at(models[i])) / diff;
+        double z = std::abs(ranks.at(models[controlIdx]) - ranks.at(models[i])) / diff;
         double p_value = (long double)2 * (1 - cdf(dist, z));
         stats[i] = p_value;
     }
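The abs → std::abs change in the z computation is a real fix, not style: with <cstdlib> in scope, an unqualified abs may resolve to the C int abs(int), silently truncating the double rank difference before the division. A minimal reproduction of the pitfall:

    #include <cstdlib>
    #include <iostream>

    int main()
    {
        double d = -0.75;
        // std::cout << abs(d);                 // may pick int abs(int) and print 0
        std::cout << std::abs(d) << std::endl;  // always the double overload: 0.75
    }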
@@ -195,6 +197,12 @@ namespace platform {
     if (output) {
         std::cout << oss.str();
     }
+    if (tex) {
+        BestResultsTex bestResultsTex;
+        BestResultsMd bestResultsMd;
+        bestResultsTex.holm_test(holmResult, get_date() + " " + get_time());
+        bestResultsMd.holm_test(holmResult, get_date() + " " + get_time());
+    }
 }
 bool Statistics::friedmanTest()
 {
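For context on what postHocHolmTest computes: with k models ranked on N datasets, comparing a candidate's average Friedman rank R_i against the control's R_0 follows Demsar's procedure, so the z in the loop above presumably corresponds to

    z_i = |R_0 - R_i| / sqrt( k(k+1) / (6N) ),    p_i = 2 * (1 - Phi(z_i)),

with Phi the standard normal CDF (the boost::math normal distribution included earlier) and the divisor `diff` presumably holding that standard error; Holm's step-down correction then compares the sorted p-values against significance/(k-1), significance/(k-2), ... to decide which lines are reported as rejected.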
@@ -5,9 +5,9 @@
 #include <map>
 #include <nlohmann/json.hpp>
 
-using json = nlohmann::json;
 
 namespace platform {
+    using json = nlohmann::ordered_json;
+
     struct WTL {
         int win;
         int tie;
@@ -34,7 +34,7 @@ namespace platform {
     public:
         Statistics(const std::vector<std::string>& models, const std::vector<std::string>& datasets, const json& data, double significance = 0.05, bool output = true);
         bool friedmanTest();
-        void postHocHolmTest(bool friedmanResult);
+        void postHocHolmTest(bool friedmanResult, bool tex=false);
         FriedmanResult& getFriedmanResult();
         HolmResult& getHolmResult();
         std::map<std::string, std::map<std::string, float>>& getRanks();
@@ -60,4 +60,4 @@ namespace platform {
         std::map<std::string, std::map<std::string, float>> ranksModels;
     };
 }
-#endif // !STATISTICS_H
+#endif
@@ -1,16 +1,22 @@
 #include <iostream>
 #include <argparse/argparse.hpp>
-#include "Paths.h"
-#include "BestResults.h"
-#include "Colors.h"
-#include "config.h"
+#include "main/Models.h"
+#include "main/modelRegister.h"
+#include "common/Paths.h"
+#include "common/Colors.h"
+#include "best/BestResults.h"
+#include "config_platform.h"
 
 void manageArguments(argparse::ArgumentParser& program)
 {
-    program.add_argument("-m", "--model").default_value("").help("Filter results of the selected model) (any for all models)");
+    program.add_argument("-m", "--model")
+        .help("Model to use or any")
+        .default_value("any");
+    program.add_argument("-d", "--dataset").default_value("any").help("Filter results of the selected model) (any for all datasets)");
     program.add_argument("-s", "--score").default_value("accuracy").help("Filter results of the score name supplied");
     program.add_argument("--friedman").help("Friedman test").default_value(false).implicit_value(true);
     program.add_argument("--excel").help("Output to excel").default_value(false).implicit_value(true);
+    program.add_argument("--tex").help("Output result table to TeX file").default_value(false).implicit_value(true);
     program.add_argument("--level").help("significance level").default_value(0.05).scan<'g', double>().action([](const std::string& value) {
         try {
             auto k = std::stod(value);
@@ -29,23 +35,25 @@ void manageArguments(argparse::ArgumentParser& program)
|
|||||||
|
|
||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
argparse::ArgumentParser program("b_best", { project_version.begin(), project_version.end() });
|
argparse::ArgumentParser program("b_best", { platform_project_version.begin(), platform_project_version.end() });
|
||||||
manageArguments(program);
|
manageArguments(program);
|
||||||
std::string model, score;
|
std::string model, dataset, score;
|
||||||
bool build, report, friedman, excel;
|
bool build, report, friedman, excel, tex;
|
||||||
double level;
|
double level;
|
||||||
try {
|
try {
|
||||||
program.parse_args(argc, argv);
|
program.parse_args(argc, argv);
|
||||||
model = program.get<std::string>("model");
|
model = program.get<std::string>("model");
|
||||||
|
dataset = program.get<std::string>("dataset");
|
||||||
score = program.get<std::string>("score");
|
score = program.get<std::string>("score");
|
||||||
friedman = program.get<bool>("friedman");
|
friedman = program.get<bool>("friedman");
|
||||||
excel = program.get<bool>("excel");
|
excel = program.get<bool>("excel");
|
||||||
|
tex = program.get<bool>("tex");
|
||||||
level = program.get<double>("level");
|
level = program.get<double>("level");
|
||||||
if (model == "" || score == "") {
|
if (model == "" || score == "") {
|
||||||
throw std::runtime_error("Model and score name must be supplied");
|
throw std::runtime_error("Model and score name must be supplied");
|
||||||
}
|
}
|
||||||
if (friedman && model != "any") {
|
if (friedman && (model != "any" || dataset != "any")) {
|
||||||
std::cerr << "Friedman test can only be used with all models" << std::endl;
|
std::cerr << "Friedman test can only be used with all models and all the datasets" << std::endl;
|
||||||
std::cerr << program;
|
std::cerr << program;
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
@@ -56,10 +64,10 @@ int main(int argc, char** argv)
|
|||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
// Generate report
|
// Generate report
|
||||||
auto results = platform::BestResults(platform::Paths::results(), score, model, friedman, level);
|
auto results = platform::BestResults(platform::Paths::results(), score, model, dataset, friedman, level);
|
||||||
if (model == "any") {
|
if (model == "any") {
|
||||||
results.buildAll();
|
results.buildAll();
|
||||||
results.reportAll(excel);
|
results.reportAll(excel, tex);
|
||||||
} else {
|
} else {
|
||||||
std::string fileName = results.build();
|
std::string fileName = results.build();
|
||||||
std::cout << Colors::GREEN() << fileName << " created!" << Colors::RESET() << std::endl;
|
std::cout << Colors::GREEN() << fileName << " created!" << Colors::RESET() << std::endl;
|
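For reference, the two argparse idioms this hunk leans on can be exercised standalone: the boolean-switch pattern behind --tex (default_value(false).implicit_value(true)) and a validating action on --level. A minimal sketch assuming only the p-ranav argparse header; the (0, 1) range check is an assumption, since the body of the real --level action is elided by the hunk above:

#include <argparse/argparse.hpp>
#include <iostream>
#include <stdexcept>
#include <string>

int main(int argc, char** argv)
{
    argparse::ArgumentParser demo("demo");
    // Boolean switch: absent -> false, present -> true.
    demo.add_argument("--tex").default_value(false).implicit_value(true);
    // Numeric option whose action validates and converts the raw string;
    // the value the action returns is what get<double>() later yields.
    demo.add_argument("--level").default_value(0.05).scan<'g', double>()
        .action([](const std::string& value) {
            auto k = std::stod(value);
            if (k <= 0.0 || k >= 1.0) {
                throw std::runtime_error("level must be in (0, 1)");
            }
            return k;
        });
    try {
        demo.parse_args(argc, argv);
    }
    catch (const std::exception& err) {
        std::cerr << err.what() << std::endl;
        return 1;
    }
    std::cout << "tex=" << demo.get<bool>("tex") << " level=" << demo.get<double>("level") << std::endl;
    return 0;
}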
@@ -4,30 +4,30 @@
 #include <tuple>
 #include <nlohmann/json.hpp>
 #include <mpi.h>
-#include "DotEnv.h"
-#include "Models.h"
-#include "modelRegister.h"
-#include "GridSearch.h"
-#include "Paths.h"
-#include "Timer.h"
-#include "Colors.h"
-#include "config.h"
+#include "main/Models.h"
+#include "main/modelRegister.h"
+#include "common/Paths.h"
+#include "common/Timer.h"
+#include "common/Colors.h"
+#include "common/DotEnv.h"
+#include "grid/GridSearch.h"
+#include "config_platform.h"

-using json = nlohmann::json;
+using json = nlohmann::ordered_json;
 const int MAXL = 133;

 void assignModel(argparse::ArgumentParser& parser)
 {
     auto models = platform::Models::instance();
     parser.add_argument("-m", "--model")
-        .help("Model to use " + models->tostring())
+        .help("Model to use " + models->toString())
         .required()
         .action([models](const std::string& value) {
             static const std::vector<std::string> choices = models->getNames();
             if (find(choices.begin(), choices.end(), value) != choices.end()) {
                 return value;
             }
-            throw std::runtime_error("Model must be one of " + models->tostring());
+            throw std::runtime_error("Model must be one of " + models->toString());
         }
     );
 }
@@ -93,21 +93,27 @@ void list_dump(std::string& model)
         if (item.first.size() > max_dataset) {
             max_dataset = item.first.size();
         }
-        if (item.second.dump().size() > max_hyper) {
-            max_hyper = item.second.dump().size();
+        for (auto const& [key, value] : item.second.items()) {
+            if (value.dump().size() > max_hyper) {
+                max_hyper = value.dump().size();
+            }
         }
     }
     std::cout << Colors::GREEN() << left << " # " << left << setw(max_dataset) << "Dataset" << " #Com. "
         << setw(max_hyper) << "Hyperparameters" << std::endl;
     std::cout << "=== " << string(max_dataset, '=') << " ===== " << string(max_hyper, '=') << std::endl;
-    bool odd = true;
+    int i = 0;
     for (auto const& item : combinations) {
-        auto color = odd ? Colors::CYAN() : Colors::BLUE();
+        auto color = (i++ % 2) ? Colors::CYAN() : Colors::BLUE();
         std::cout << color;
         auto num_combinations = data.getNumCombinations(item.first);
         std::cout << setw(3) << fixed << right << ++index << left << " " << setw(max_dataset) << item.first
-            << " " << setw(5) << right << num_combinations << " " << setw(max_hyper) << left << item.second.dump() << std::endl;
-        odd = !odd;
+            << " " << setw(5) << right << num_combinations << " ";
+        std::string prefix = "";
+        for (auto const& [key, value] : item.second.items()) {
+            std::cout << prefix << setw(max_hyper) << std::left << value.dump() << std::endl;
+            prefix = string(11 + max_dataset, ' ');
+        }
     }
     std::cout << Colors::RESET() << std::endl;
 }
@@ -141,17 +147,15 @@ void list_results(json& results, std::string& model)
         << "Duration " << setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
     std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
         << string(8, '=') << " " << string(hyperparameters_spaces, '=') << std::endl;
-    bool odd = true;
     int index = 0;
     for (const auto& item : results["results"].items()) {
-        auto color = odd ? Colors::CYAN() : Colors::BLUE();
+        auto color = (index % 2) ? Colors::CYAN() : Colors::BLUE();
         auto value = item.value();
         std::cout << color;
         std::cout << std::setw(3) << std::right << index++ << " ";
         std::cout << left << setw(spaces) << item.key() << " " << value["date"].get<string>()
             << " " << setw(8) << right << value["duration"].get<string>() << " " << setw(8) << setprecision(6)
             << fixed << right << value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
-        odd = !odd;
     }
     std::cout << Colors::RESET() << std::endl;
 }
@@ -223,7 +227,7 @@ int main(int argc, char** argv)
     //
     // Manage arguments
     //
-    argparse::ArgumentParser program("b_grid", { project_version.begin(), project_version.end() });
+    argparse::ArgumentParser program("b_grid", { platform_project_version.begin(), platform_project_version.end() });
     // grid dump subparser
     argparse::ArgumentParser dump_command("dump");
     dump_command.add_description("Dump the combinations of hyperparameters of a model.");
@@ -259,7 +263,7 @@ int main(int argc, char** argv)
         }
     }
     if (!found) {
-        throw std::runtime_error("You must specify one of the following commands: dump, report, compute, export\n");
+        throw std::runtime_error("You must specify one of the following commands: dump, report, compute\n");
     }
 }
 catch (const exception& err) {
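The tostring -> toString rename runs through the validation idiom used by every command here: an action that accepts a value only if it belongs to a known list. A self-contained sketch of just that idiom; the model names are hypothetical placeholders for what platform::Models::instance()->getNames() returns:

#include <algorithm>
#include <argparse/argparse.hpp>
#include <stdexcept>
#include <string>
#include <vector>

void assignModel(argparse::ArgumentParser& parser)
{
    // Placeholder list; in the platform this comes from the model registry.
    static const std::vector<std::string> choices = { "ModelA", "ModelB" };
    parser.add_argument("-m", "--model")
        .help("Model to use")
        .required()
        .action([](const std::string& value) {
            if (std::find(choices.begin(), choices.end(), value) != choices.end()) {
                return value;   // accepted: becomes the stored argument value
            }
            throw std::runtime_error("Model must be one of ModelA, ModelB");
        });
}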
src/commands/b_list.cpp (new file, 110 lines)
@@ -0,0 +1,110 @@
#include <iostream>
#include <locale>
#include <map>
#include <argparse/argparse.hpp>
#include <nlohmann/json.hpp>
#include "main/Models.h"
#include "main/modelRegister.h"
#include "common/Paths.h"
#include "common/Colors.h"
#include "common/Datasets.h"
#include "reports/DatasetsExcel.h"
#include "reports/DatasetsConsole.h"
#include "results/ResultsDatasetConsole.h"
#include "results/ResultsDataset.h"
#include "results/ResultsDatasetExcel.h"
#include "config_platform.h"


void list_datasets(argparse::ArgumentParser& program)
{
    auto excel = program.get<bool>("excel");
    auto report = platform::DatasetsConsole();
    report.report();
    std::cout << report.getOutput();
    if (excel) {
        auto data = report.getData();
        auto report = platform::DatasetsExcel();
        report.report(data);
        std::cout << std::endl << Colors::GREEN() << "Output saved in " << report.getFileName() << std::endl;
    }
}

void list_results(argparse::ArgumentParser& program)
{
    auto dataset = program.get<string>("dataset");
    auto score = program.get<string>("score");
    auto model = program.get<string>("model");
    auto excel = program.get<bool>("excel");
    auto report = platform::ResultsDatasetsConsole();
    if (!report.report(dataset, score, model))
        return;
    std::cout << report.getOutput();
    if (excel) {
        auto data = report.getData();
        auto report = platform::ResultsDatasetExcel();
        report.report(data);
        std::cout << std::endl << Colors::GREEN() << "Output saved in " << report.getFileName() << std::endl;
    }
}

int main(int argc, char** argv)
{
    argparse::ArgumentParser program("b_list", { platform_project_version.begin(), platform_project_version.end() });
    //
    // datasets subparser
    //
    argparse::ArgumentParser datasets_command("datasets");
    datasets_command.add_description("List datasets available in the platform.");
    datasets_command.add_argument("--excel").help("Output in Excel format").default_value(false).implicit_value(true);
    //
    // results subparser
    //
    argparse::ArgumentParser results_command("results");
    results_command.add_description("List the results of a given dataset.");
    auto datasets = platform::Datasets(false, platform::Paths::datasets());
    results_command.add_argument("-d", "--dataset")
        .help("Dataset to use " + datasets.toString())
        .required()
        .action([](const std::string& value) {
            auto datasets = platform::Datasets(false, platform::Paths::datasets());
            static const std::vector<std::string> choices = datasets.getNames();
            if (find(choices.begin(), choices.end(), value) != choices.end()) {
                return value;
            }
            throw std::runtime_error("Dataset must be one of " + datasets.toString());
        }
    );
    results_command.add_argument("-m", "--model")
        .help("Model to use or any")
        .default_value("any");
    results_command.add_argument("--excel").help("Output in Excel format").default_value(false).implicit_value(true);
    results_command.add_argument("-s", "--score").default_value("accuracy").help("Filter results of the score name supplied");

    // Add subparsers
    program.add_subparser(datasets_command);
    program.add_subparser(results_command);
    // Parse command line and execute
    try {
        program.parse_args(argc, argv);
        bool found = false;
        map<std::string, void(*)(argparse::ArgumentParser&)> commands = { {"datasets", &list_datasets}, {"results", &list_results} };
        for (const auto& command : commands) {
            if (program.is_subcommand_used(command.first)) {
                std::invoke(command.second, program.at<argparse::ArgumentParser>(command.first));
                found = true;
                break;
            }
        }
        if (!found) {
            throw std::runtime_error("You must specify one of the following commands: {datasets, results}\n");
        }
    }
    catch (const exception& err) {
        cerr << err.what() << std::endl;
        cerr << program;
        exit(1);
    }
    std::cout << Colors::RESET() << std::endl;
    return 0;
}
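b_list wires its subcommands through a name-to-handler table rather than an if/else chain. A trimmed, runnable sketch of that dispatch, with the handlers reduced to stubs; all of the argparse calls used (add_subparser, is_subcommand_used, at<ArgumentParser>) appear in the file above:

#include <argparse/argparse.hpp>
#include <functional>
#include <iostream>
#include <map>
#include <string>

void list_datasets(argparse::ArgumentParser&) { std::cout << "datasets" << std::endl; }
void list_results(argparse::ArgumentParser&) { std::cout << "results" << std::endl; }

int main(int argc, char** argv)
{
    argparse::ArgumentParser program("demo");
    argparse::ArgumentParser datasets_command("datasets");
    argparse::ArgumentParser results_command("results");
    program.add_subparser(datasets_command);
    program.add_subparser(results_command);
    try {
        program.parse_args(argc, argv);
        // Dispatch table: subcommand name -> handler that receives its own parser.
        std::map<std::string, void(*)(argparse::ArgumentParser&)> commands =
            { {"datasets", &list_datasets}, {"results", &list_results} };
        for (const auto& [name, handler] : commands) {
            if (program.is_subcommand_used(name)) {
                std::invoke(handler, program.at<argparse::ArgumentParser>(name));
                return 0;
            }
        }
        std::cerr << "You must specify one of: datasets, results" << std::endl;
    }
    catch (const std::exception& err) {
        std::cerr << err.what() << std::endl;
    }
    return 1;
}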
src/commands/b_main.cpp (new file, 234 lines)
@@ -0,0 +1,234 @@
#include <iostream>
#include <argparse/argparse.hpp>
#include <nlohmann/json.hpp>
#include "main/Experiment.h"
#include "common/Datasets.h"
#include "common/DotEnv.h"
#include "common/Paths.h"
#include "main/Models.h"
#include "main/modelRegister.h"
#include "config_platform.h"


using json = nlohmann::ordered_json;


void manageArguments(argparse::ArgumentParser& program)
{
    auto env = platform::DotEnv();
    auto datasets = platform::Datasets(false, platform::Paths::datasets());
    auto& group = program.add_mutually_exclusive_group(true);
    group.add_argument("-d", "--dataset")
        .help("Dataset file name: " + datasets.toString())
        .default_value("all")
        .action([](const std::string& value) {
            auto datasets = platform::Datasets(false, platform::Paths::datasets());
            static std::vector<std::string> choices_datasets(datasets.getNames());
            choices_datasets.push_back("all");
            if (find(choices_datasets.begin(), choices_datasets.end(), value) != choices_datasets.end()) {
                return value;
            }
            throw std::runtime_error("Dataset must be one of: " + datasets.toString());
        }
    );
    group.add_argument("--datasets").nargs(1, 50).help("Datasets file names 1..50 separated by spaces").default_value(std::vector<std::string>());
    group.add_argument("--datasets-file").default_value("").help("Datasets file name. Mutually exclusive with dataset. This file should contain a list of datasets to test.");
    program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
    program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name. Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format.");
    program.add_argument("--hyper-best").default_value(false).help("Use best results of the model as source of hyperparameters").implicit_value(true);
    program.add_argument("-m", "--model")
        .help("Model to use: " + platform::Models::instance()->toString())
        .action([](const std::string& value) {
            static const std::vector<std::string> choices = platform::Models::instance()->getNames();
            if (find(choices.begin(), choices.end(), value) != choices.end()) {
                return value;
            }
            throw std::runtime_error("Model must be one of " + platform::Models::instance()->toString());
        }
    );
    program.add_argument("--title").default_value("").help("Experiment title");
    program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
    auto valid_choices = env.valid_tokens("discretize_algo");
    auto& disc_arg = program.add_argument("--discretize-algo").help("Algorithm to use in discretization. Valid values: " + env.valid_values("discretize_algo")).default_value(env.get("discretize_algo"));
    for (auto choice : valid_choices) {
        disc_arg.choices(choice);
    }
    valid_choices = env.valid_tokens("smooth_strat");
    auto& smooth_arg = program.add_argument("--smooth-strat").help("Smooth strategy used in Bayes Network node initialization. Valid values: " + env.valid_values("smooth_strat")).default_value(env.get("smooth_strat"));
    for (auto choice : valid_choices) {
        smooth_arg.choices(choice);
    }
    auto& score_arg = program.add_argument("-s", "--score").help("Score to use. Valid values: " + env.valid_values("score")).default_value(env.get("score"));
    valid_choices = env.valid_tokens("score");
    for (auto choice : valid_choices) {
        score_arg.choices(choice);
    }
    program.add_argument("--generate-fold-files").help("generate fold information in datasets_experiment folder").default_value(false).implicit_value(true);
    program.add_argument("--graph").help("generate graphviz dot files with the model").default_value(false).implicit_value(true);
    program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true);
    program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
    program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
    program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
    program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
        try {
            auto k = stoi(value);
            if (k < 2) {
                throw std::runtime_error("Number of folds must be greater than 1");
            }
            return k;
        }
        catch (const runtime_error& err) {
            throw std::runtime_error(err.what());
        }
        catch (...) {
            throw std::runtime_error("Number of folds must be an integer");
        }});
    auto seed_values = env.getSeeds();
    program.add_argument("--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
}

int main(int argc, char** argv)
{
    argparse::ArgumentParser program("b_main", { platform_project_version.begin(), platform_project_version.end() });
    manageArguments(program);
    std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat, score;
    json hyperparameters_json;
    bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files, graph, hyper_best;
    std::vector<int> seeds;
    std::vector<std::string> file_names;
    std::vector<std::string> filesToTest;
    int n_folds;
    try {
        program.parse_args(argc, argv);
        file_name = program.get<std::string>("dataset");
        file_names = program.get<std::vector<std::string>>("datasets");
        datasets_file = program.get<std::string>("datasets-file");
        model_name = program.get<std::string>("model");
        discretize_dataset = program.get<bool>("discretize");
        discretize_algo = program.get<std::string>("discretize-algo");
        smooth_strat = program.get<std::string>("smooth-strat");
        stratified = program.get<bool>("stratified");
        quiet = program.get<bool>("quiet");
        graph = program.get<bool>("graph");
        n_folds = program.get<int>("folds");
        score = program.get<std::string>("score");
        seeds = program.get<std::vector<int>>("seeds");
        auto hyperparameters = program.get<std::string>("hyperparameters");
        hyperparameters_json = json::parse(hyperparameters);
        hyperparameters_file = program.get<std::string>("hyper-file");
        no_train_score = program.get<bool>("no-train-score");
        hyper_best = program.get<bool>("hyper-best");
        generate_fold_files = program.get<bool>("generate-fold-files");
        if (hyper_best) {
            // Build the best results file_name
            hyperparameters_file = platform::Paths::results() + platform::Paths::bestResultsFile(score, model_name);
            // ignore this parameter
            hyperparameters = "{}";
        } else {
            if (hyperparameters_file != "" && hyperparameters != "{}") {
                throw runtime_error("hyperparameters and hyper_file are mutually exclusive");
            }
        }
        title = program.get<std::string>("title");
        if (title == "" && file_name == "all") {
            throw runtime_error("title is mandatory if all datasets are to be tested");
        }
        saveResults = program.get<bool>("save");
    }
    catch (const exception& err) {
        cerr << err.what() << std::endl;
        cerr << program;
        exit(1);
    }
    auto datasets = platform::Datasets(false, platform::Paths::datasets());
    if (datasets_file != "") {
        ifstream catalog(datasets_file);
        if (catalog.is_open()) {
            std::string line;
            while (getline(catalog, line)) {
                if (line.empty() || line[0] == '#') {
                    continue;
                }
                if (!datasets.isDataset(line)) {
                    cerr << "Dataset " << line << " not found" << std::endl;
                    exit(1);
                }
                filesToTest.push_back(line);
            }
            catalog.close();
            saveResults = true;
            if (title == "") {
                title = "Test " + to_string(filesToTest.size()) + " datasets (" + datasets_file + ") " + model_name + " " + to_string(n_folds) + " folds";
            }
        } else {
            throw std::invalid_argument("Unable to open catalog file. [" + datasets_file + "]");
        }
    } else {
        if (file_names.size() > 0) {
            for (auto file : file_names) {
                if (!datasets.isDataset(file)) {
                    cerr << "Dataset " << file << " not found" << std::endl;
                    exit(1);
                }
            }
            filesToTest = file_names;
            saveResults = true;
            if (title == "") {
                title = "Test " + to_string(file_names.size()) + " datasets " + model_name + " " + to_string(n_folds) + " folds";
            }
        } else {
            if (file_name != "all") {
                if (!datasets.isDataset(file_name)) {
                    cerr << "Dataset " << file_name << " not found" << std::endl;
                    exit(1);
                }
                if (title == "") {
                    title = "Test " + file_name + " " + model_name + " " + to_string(n_folds) + " folds";
                }
                filesToTest.push_back(file_name);
            } else {
                filesToTest = datasets.getNames();
                saveResults = true;
            }
        }
    }

    platform::HyperParameters test_hyperparams;
    if (hyperparameters_file != "") {
        test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_file, hyper_best);
    } else {
        test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_json);
    }

    /*
     * Begin Processing
     */
    auto env = platform::DotEnv();
    auto experiment = platform::Experiment();
    experiment.setTitle(title).setLanguage("c++").setLanguageVersion("gcc 14.1.1");
    experiment.setDiscretizationAlgorithm(discretize_algo).setSmoothSrategy(smooth_strat);
    experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));
    experiment.setStratified(stratified).setNFolds(n_folds).setScoreName(score);
    experiment.setHyperparameters(test_hyperparams);
    for (auto seed : seeds) {
        experiment.addRandomSeed(seed);
    }
    platform::Timer timer;
    timer.start();
    experiment.go(filesToTest, quiet, no_train_score, generate_fold_files, graph);
    experiment.setDuration(timer.getDuration());
    if (!quiet) {
        // Classification report if only one dataset is tested
        experiment.report(filesToTest.size() == 1);
    }
    if (saveResults) {
        experiment.saveResult();
    }
    if (graph) {
        experiment.saveGraph();
    }
    std::cout << "Done!" << std::endl;
    return 0;
}
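b_main is the first of these commands to use argparse's mutually exclusive groups; passing true to add_mutually_exclusive_group makes supplying exactly one member mandatory, which is how -d, --datasets, and --datasets-file exclude each other above. A minimal sketch of just that mechanism:

#include <argparse/argparse.hpp>
#include <iostream>
#include <string>
#include <vector>

int main(int argc, char** argv)
{
    argparse::ArgumentParser program("demo");
    // true -> the group is required as well as mutually exclusive.
    auto& group = program.add_mutually_exclusive_group(true);
    group.add_argument("-d", "--dataset").default_value(std::string("all"));
    group.add_argument("--datasets").nargs(1, 50).default_value(std::vector<std::string>());
    group.add_argument("--datasets-file").default_value(std::string(""));
    try {
        // Throws if two group members are given together (or none, since
        // the group is required).
        program.parse_args(argc, argv);
    }
    catch (const std::exception& err) {
        std::cerr << err.what() << std::endl;
        return 1;
    }
    std::cout << "dataset=" << program.get<std::string>("dataset") << std::endl;
    return 0;
}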
@@ -1,23 +1,25 @@
 #include <iostream>
+#include <sys/ioctl.h>
+#include <utility>
+#include <unistd.h>
 #include <argparse/argparse.hpp>
-#include "ManageResults.h"
-#include "config.h"
+#include "manage/ManageScreen.h"
+#include <signal.h>
+#include "config_platform.h"
+
+platform::ManageScreen* manager = nullptr;
+
 void manageArguments(argparse::ArgumentParser& program, int argc, char** argv)
 {
-    program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>();
     program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model");
     program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
+    program.add_argument("--platform").default_value("any").help("Filter results of the selected platform");
     program.add_argument("--complete").help("Show only results with all datasets").default_value(false).implicit_value(true);
     program.add_argument("--partial").help("Show only partial results").default_value(false).implicit_value(true);
     program.add_argument("--compare").help("Compare with best results").default_value(false).implicit_value(true);
     try {
         program.parse_args(argc, argv);
-        auto number = program.get<int>("number");
-        if (number < 0) {
-            throw std::runtime_error("Number of results must be greater than or equal to 0");
-        }
+        auto platform = program.get<std::string>("platform");
         auto model = program.get<std::string>("model");
         auto score = program.get<std::string>("score");
         auto complete = program.get<bool>("complete");
@@ -31,19 +33,40 @@ void manageArguments(argparse::ArgumentParser& program, int argc, char** argv)
     }
 }

+std::pair<int, int> numRowsCols()
+{
+#ifdef TIOCGSIZE
+    struct ttysize ts;
+    ioctl(STDIN_FILENO, TIOCGSIZE, &ts);
+    return { ts.ts_lines, ts.ts_cols };
+#elif defined(TIOCGWINSZ)
+    struct winsize ts;
+    ioctl(STDIN_FILENO, TIOCGWINSZ, &ts);
+    return { ts.ws_row, ts.ws_col };
+#endif /* TIOCGSIZE */
+}
+void handleResize(int sig)
+{
+    auto [rows, cols] = numRowsCols();
+    manager->updateSize(rows, cols);
+}
+
 int main(int argc, char** argv)
 {
-    auto program = argparse::ArgumentParser("b_manage", { project_version.begin(), project_version.end() });
+    auto program = argparse::ArgumentParser("b_manage", { platform_project_version.begin(), platform_project_version.end() });
     manageArguments(program, argc, argv);
-    int number = program.get<int>("number");
     std::string model = program.get<std::string>("model");
     std::string score = program.get<std::string>("score");
-    auto complete = program.get<bool>("complete");
-    auto partial = program.get<bool>("partial");
-    auto compare = program.get<bool>("compare");
+    std::string platform = program.get<std::string>("platform");
+    bool complete = program.get<bool>("complete");
+    bool partial = program.get<bool>("partial");
+    bool compare = program.get<bool>("compare");
     if (complete)
         partial = false;
-    auto manager = platform::ManageResults(number, model, score, complete, partial, compare);
-    manager.doMenu();
+    signal(SIGWINCH, handleResize);
+    auto [rows, cols] = numRowsCols();
+    manager = new platform::ManageScreen(rows, cols, model, score, platform, complete, partial, compare);
+    manager->doMenu();
+    delete manager;
     return 0;
 }
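The new resize handling combines a SIGWINCH handler with a TIOCGWINSZ/TIOCGSIZE ioctl query, exactly as in the hunk above. A standalone POSIX-only sketch; note that, like the handler above, it does work inside the signal handler that is not async-signal-safe, which a stricter version would avoid by setting a flag and acting on it in the main loop:

#include <csignal>
#include <iostream>
#include <sys/ioctl.h>
#include <unistd.h>
#include <utility>

std::pair<int, int> numRowsCols()
{
#ifdef TIOCGSIZE
    struct ttysize ts;
    ioctl(STDIN_FILENO, TIOCGSIZE, &ts);
    return { ts.ts_lines, ts.ts_cols };
#elif defined(TIOCGWINSZ)
    struct winsize ts;
    ioctl(STDIN_FILENO, TIOCGWINSZ, &ts);
    return { ts.ws_row, ts.ws_col };
#endif
}

void handleResize(int)
{
    auto [rows, cols] = numRowsCols();
    std::cout << "resized to " << rows << "x" << cols << std::endl;
}

int main()
{
    signal(SIGWINCH, handleResize);     // fires every time the terminal is resized
    auto [rows, cols] = numRowsCols();
    std::cout << "initial size " << rows << "x" << cols << std::endl;
    pause();                            // block until a signal arrives
    return 0;
}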
@@ -1,5 +1,5 @@
-#ifndef LOCALE_H
-#define LOCALE_H
+#ifndef CLOCALE_H
+#define CLOCALE_H
 #include <locale>
 #include <iostream>
 #include <string>
@@ -1,15 +1,30 @@
 #ifndef COLORS_H
 #define COLORS_H
+#include <string>
 class Colors {
 public:
-    static std::string MAGENTA() { return "\033[1;35m"; }
+    static std::string BLACK() { return "\033[1;30m"; }
+    static std::string IBLACK() { return "\033[0;90m"; }
     static std::string BLUE() { return "\033[1;34m"; }
-    static std::string CYAN() { return "\033[1;36m"; }
-    static std::string GREEN() { return "\033[1;32m"; }
-    static std::string YELLOW() { return "\033[1;33m"; }
-    static std::string RED() { return "\033[1;31m"; }
-    static std::string WHITE() { return "\033[1;37m"; }
     static std::string IBLUE() { return "\033[0;94m"; }
+    static std::string CYAN() { return "\033[1;36m"; }
+    static std::string ICYAN() { return "\033[0;96m"; }
+    static std::string GREEN() { return "\033[1;32m"; }
+    static std::string IGREEN() { return "\033[0;92m"; }
+    static std::string MAGENTA() { return "\033[1;35m"; }
+    static std::string IMAGENTA() { return "\033[0;95m"; }
+    static std::string RED() { return "\033[1;31m"; }
+    static std::string IRED() { return "\033[0;91m"; }
+    static std::string YELLOW() { return "\033[1;33m"; }
+    static std::string IYELLOW() { return "\033[0;93m"; }
+    static std::string WHITE() { return "\033[1;37m"; }
+    static std::string IWHITE() { return "\033[0;97m"; }
     static std::string RESET() { return "\033[0m"; }
+    static std::string BOLD() { return "\033[1m"; }
+    static std::string UNDERLINE() { return "\033[4m"; }
+    static std::string BLINK() { return "\033[5m"; }
+    static std::string REVERSE() { return "\033[7m"; }
+    static std::string CONCEALED() { return "\033[8m"; }
+    static std::string CLRSCR() { return "\033[2J\033[1;1H"; }
 };
-#endif // COLORS_H
+#endif
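The expanded palette follows the standard SGR escape scheme: \033[1;3Xm is the bold color X, \033[0;9Xm the high-intensity variant the new I-prefixed methods return, and \033[0m resets. A two-line check that needs nothing beyond a VT100-compatible terminal:

#include <iostream>

int main()
{
    std::cout << "\033[1;32m" << "bold green " << "\033[0;92m" << "bright green"
              << "\033[0m" << std::endl;   // reset so later output is unstyled
    return 0;
}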
@@ -1,215 +0,0 @@ (file removed)
#include "Dataset.h"
#include "ArffFiles.h"
#include <fstream>
namespace platform {
    Dataset::Dataset(const Dataset& dataset) : path(dataset.path), name(dataset.name), className(dataset.className), n_samples(dataset.n_samples), n_features(dataset.n_features), features(dataset.features), states(dataset.states), loaded(dataset.loaded), discretize(dataset.discretize), X(dataset.X), y(dataset.y), Xv(dataset.Xv), Xd(dataset.Xd), yv(dataset.yv), fileType(dataset.fileType)
    {
    }
    std::string Dataset::getName() const
    {
        return name;
    }
    std::string Dataset::getClassName() const
    {
        return className;
    }
    std::vector<std::string> Dataset::getFeatures() const
    {
        if (loaded) {
            return features;
        } else {
            throw std::invalid_argument("Dataset not loaded.");
        }
    }
    int Dataset::getNFeatures() const
    {
        if (loaded) {
            return n_features;
        } else {
            throw std::invalid_argument("Dataset not loaded.");
        }
    }
    int Dataset::getNSamples() const
    {
        if (loaded) {
            return n_samples;
        } else {
            throw std::invalid_argument("Dataset not loaded.");
        }
    }
    std::map<std::string, std::vector<int>> Dataset::getStates() const
    {
        if (loaded) {
            return states;
        } else {
            throw std::invalid_argument("Dataset not loaded.");
        }
    }
    pair<std::vector<std::vector<float>>&, std::vector<int>&> Dataset::getVectors()
    {
        if (loaded) {
            return { Xv, yv };
        } else {
            throw std::invalid_argument("Dataset not loaded.");
        }
    }
    pair<std::vector<std::vector<int>>&, std::vector<int>&> Dataset::getVectorsDiscretized()
    {
        if (loaded) {
            return { Xd, yv };
        } else {
            throw std::invalid_argument("Dataset not loaded.");
        }
    }
    pair<torch::Tensor&, torch::Tensor&> Dataset::getTensors()
    {
        if (loaded) {
            buildTensors();
            return { X, y };
        } else {
            throw std::invalid_argument("Dataset not loaded.");
        }
    }
    void Dataset::load_csv()
    {
        ifstream file(path + "/" + name + ".csv");
        if (file.is_open()) {
            std::string line;
            getline(file, line);
            std::vector<std::string> tokens = split(line, ',');
            features = std::vector<std::string>(tokens.begin(), tokens.end() - 1);
            if (className == "-1") {
                className = tokens.back();
            }
            for (auto i = 0; i < features.size(); ++i) {
                Xv.push_back(std::vector<float>());
            }
            while (getline(file, line)) {
                tokens = split(line, ',');
                for (auto i = 0; i < features.size(); ++i) {
                    Xv[i].push_back(stof(tokens[i]));
                }
                yv.push_back(stoi(tokens.back()));
            }
            file.close();
        } else {
            throw std::invalid_argument("Unable to open dataset file.");
        }
    }
    void Dataset::computeStates()
    {
        for (int i = 0; i < features.size(); ++i) {
            states[features[i]] = std::vector<int>(*max_element(Xd[i].begin(), Xd[i].end()) + 1);
            auto item = states.at(features[i]);
            iota(begin(item), end(item), 0);
        }
        states[className] = std::vector<int>(*max_element(yv.begin(), yv.end()) + 1);
        iota(begin(states.at(className)), end(states.at(className)), 0);
    }
    void Dataset::load_arff()
    {
        auto arff = ArffFiles();
        arff.load(path + "/" + name + ".arff", className);
        // Get Dataset X, y
        Xv = arff.getX();
        yv = arff.getY();
        // Get className & Features
        className = arff.getClassName();
        auto attributes = arff.getAttributes();
        transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& attribute) { return attribute.first; });
    }
    std::vector<std::string> tokenize(std::string line)
    {
        std::vector<std::string> tokens;
        for (auto i = 0; i < line.size(); ++i) {
            if (line[i] == ' ' || line[i] == '\t' || line[i] == '\n') {
                std::string token = line.substr(0, i);
                tokens.push_back(token);
                line.erase(line.begin(), line.begin() + i + 1);
                i = 0;
                while (line[i] == ' ' || line[i] == '\t' || line[i] == '\n')
                    line.erase(line.begin(), line.begin() + i + 1);
            }
        }
        if (line.size() > 0) {
            tokens.push_back(line);
        }
        return tokens;
    }
    void Dataset::load_rdata()
    {
        ifstream file(path + "/" + name + "_R.dat");
        if (file.is_open()) {
            std::string line;
            getline(file, line);
            line = ArffFiles::trim(line);
            std::vector<std::string> tokens = tokenize(line);
            transform(tokens.begin(), tokens.end() - 1, back_inserter(features), [](const auto& attribute) { return ArffFiles::trim(attribute); });
            if (className == "-1") {
                className = ArffFiles::trim(tokens.back());
            }
            for (auto i = 0; i < features.size(); ++i) {
                Xv.push_back(std::vector<float>());
            }
            while (getline(file, line)) {
                tokens = tokenize(line);
                // We have to skip the first token, which is the instance number.
                for (auto i = 1; i < features.size() + 1; ++i) {
                    const float value = stof(tokens[i]);
                    Xv[i - 1].push_back(value);
                }
                yv.push_back(stoi(tokens.back()));
            }
            file.close();
        } else {
            throw std::invalid_argument("Unable to open dataset file.");
        }
    }
    void Dataset::load()
    {
        if (loaded) {
            return;
        }
        if (fileType == CSV) {
            load_csv();
        } else if (fileType == ARFF) {
            load_arff();
        } else if (fileType == RDATA) {
            load_rdata();
        }
        if (discretize) {
            Xd = discretizeDataset(Xv, yv);
            computeStates();
        }
        n_samples = Xv[0].size();
        n_features = Xv.size();
        loaded = true;
    }
    void Dataset::buildTensors()
    {
        if (discretize) {
            X = torch::zeros({ static_cast<int>(n_features), static_cast<int>(n_samples) }, torch::kInt32);
        } else {
            X = torch::zeros({ static_cast<int>(n_features), static_cast<int>(n_samples) }, torch::kFloat32);
        }
        for (int i = 0; i < features.size(); ++i) {
            if (discretize) {
                X.index_put_({ i, "..." }, torch::tensor(Xd[i], torch::kInt32));
            } else {
                X.index_put_({ i, "..." }, torch::tensor(Xv[i], torch::kFloat32));
            }
        }
        y = torch::tensor(yv, torch::kInt32);
    }
    std::vector<mdlp::labels_t> Dataset::discretizeDataset(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y)
    {
        std::vector<mdlp::labels_t> Xd;
        auto fimdlp = mdlp::CPPFImdlp();
        for (int i = 0; i < X.size(); i++) {
            fimdlp.fit(X[i], y);
            mdlp::labels_t& xd = fimdlp.transform(X[i]);
            Xd.push_back(xd);
        }
        return Xd;
    }
}
src/common/Dataset.cpp (new file, 278 lines)
@@ -0,0 +1,278 @@
#include <ArffFiles.hpp>
#include <fstream>
#include "Dataset.h"
namespace platform {
    const std::string message_dataset_not_loaded = "Dataset not loaded.";
    Dataset::Dataset(const Dataset& dataset) :
        path(dataset.path), name(dataset.name), className(dataset.className), n_samples(dataset.n_samples),
        n_features(dataset.n_features), numericFeatures(dataset.numericFeatures), features(dataset.features),
        states(dataset.states), loaded(dataset.loaded), discretize(dataset.discretize), X(dataset.X), y(dataset.y),
        X_train(dataset.X_train), X_test(dataset.X_test), Xv(dataset.Xv), yv(dataset.yv),
        fileType(dataset.fileType)
    {
    }
    std::string Dataset::getName() const
    {
        return name;
    }
    std::vector<std::string> Dataset::getFeatures() const
    {
        if (loaded) {
            return features;
        } else {
            throw std::invalid_argument(message_dataset_not_loaded);
        }
    }
    int Dataset::getNFeatures() const
    {
        if (loaded) {
            return n_features;
        } else {
            throw std::invalid_argument(message_dataset_not_loaded);
        }
    }
    int Dataset::getNSamples() const
    {
        if (loaded) {
            return n_samples;
        } else {
            throw std::invalid_argument(message_dataset_not_loaded);
        }
    }
    std::string Dataset::getClassName() const
    {
        return className;
    }
    int Dataset::getNClasses() const
    {
        if (loaded) {
            return *std::max_element(yv.begin(), yv.end()) + 1;
        } else {
            throw std::invalid_argument(message_dataset_not_loaded);
        }
    }
    std::vector<std::string> Dataset::getLabels() const
    {
        // Return the labels factorization result
        if (loaded) {
            return labels;
        } else {
            throw std::invalid_argument(message_dataset_not_loaded);
        }
    }
    std::vector<int> Dataset::getClassesCounts() const
    {
        if (loaded) {
            std::vector<int> counts(*std::max_element(yv.begin(), yv.end()) + 1);
            for (auto y : yv) {
                counts[y]++;
            }
            return counts;
        } else {
            throw std::invalid_argument(message_dataset_not_loaded);
        }
    }
    std::map<std::string, std::vector<int>> Dataset::getStates() const
    {
        if (loaded) {
            return states;
        } else {
            throw std::invalid_argument(message_dataset_not_loaded);
        }
    }
    pair<std::vector<std::vector<float>>&, std::vector<int>&> Dataset::getVectors()
    {
        if (loaded) {
            return { Xv, yv };
        } else {
            throw std::invalid_argument(message_dataset_not_loaded);
        }
    }
    pair<torch::Tensor&, torch::Tensor&> Dataset::getTensors()
    {
        if (loaded) {
            return { X, y };
        } else {
            throw std::invalid_argument(message_dataset_not_loaded);
        }
    }
    void Dataset::load_csv()
    {
        ifstream file(path + "/" + name + ".csv");
        if (!file.is_open()) {
            throw std::invalid_argument("Unable to open dataset file.");
        }
        labels.clear();
        std::string line;
        getline(file, line);
        std::vector<std::string> tokens = split(line, ',');
        features = std::vector<std::string>(tokens.begin(), tokens.end() - 1);
        if (className == "-1") {
            className = tokens.back();
        }
        for (auto i = 0; i < features.size(); ++i) {
            Xv.push_back(std::vector<float>());
        }
        while (getline(file, line)) {
            tokens = split(line, ',');
            for (auto i = 0; i < features.size(); ++i) {
                Xv[i].push_back(stof(tokens[i]));
            }
            auto label = trim(tokens.back());
            if (find(labels.begin(), labels.end(), label) == labels.end()) {
                labels.push_back(label);
            }
            yv.push_back(stoi(label));
        }
        file.close();
    }
    void Dataset::computeStates()
    {
        for (int i = 0; i < features.size(); ++i) {
            auto [max_value, idx] = torch::max(X_train.index({ i, "..." }), 0);
            states[features[i]] = std::vector<int>(max_value.item<int>() + 1);
            iota(begin(states.at(features[i])), end(states.at(features[i])), 0);
        }
        auto [max_value, idx] = torch::max(y_train, 0);
        states[className] = std::vector<int>(max_value.item<int>() + 1);
        iota(begin(states.at(className)), end(states.at(className)), 0);
    }
    void Dataset::load_arff()
    {
        auto arff = ArffFiles();
        arff.load(path + "/" + name + ".arff", className);
        // Get Dataset X, y
        Xv = arff.getX();
        yv = arff.getY();
        // Get className & Features
        className = arff.getClassName();
        auto attributes = arff.getAttributes();
        transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& attribute) { return attribute.first; });
        labels = arff.getLabels();
    }
    std::vector<std::string> tokenize(std::string line)
    {
        std::vector<std::string> tokens;
        for (auto i = 0; i < line.size(); ++i) {
            if (line[i] == ' ' || line[i] == '\t' || line[i] == '\n') {
                std::string token = line.substr(0, i);
                tokens.push_back(token);
                line.erase(line.begin(), line.begin() + i + 1);
                i = 0;
                while (line[i] == ' ' || line[i] == '\t' || line[i] == '\n')
                    line.erase(line.begin(), line.begin() + i + 1);
            }
        }
        if (line.size() > 0) {
            tokens.push_back(line);
        }
        return tokens;
    }
    void Dataset::load_rdata()
    {
        ifstream file(path + "/" + name + "_R.dat");
        if (!file.is_open()) {
            throw std::invalid_argument("Unable to open dataset file.");
        }
        std::string line;
        labels.clear();
        getline(file, line);
        line = ArffFiles::trim(line);
        std::vector<std::string> tokens = tokenize(line);
        transform(tokens.begin(), tokens.end() - 1, back_inserter(features), [](const auto& attribute) { return ArffFiles::trim(attribute); });
        if (className == "-1") {
            className = ArffFiles::trim(tokens.back());
        }
        for (auto i = 0; i < features.size(); ++i) {
            Xv.push_back(std::vector<float>());
        }
        while (getline(file, line)) {
            tokens = tokenize(line);
            // We have to skip the first token, which is the instance number.
            for (auto i = 1; i < features.size() + 1; ++i) {
                const float value = stof(tokens[i]);
                Xv[i - 1].push_back(value);
            }
            auto label = trim(tokens.back());
            if (find(labels.begin(), labels.end(), label) == labels.end()) {
                labels.push_back(label);
            }
            yv.push_back(stoi(label));
        }
        file.close();
    }
    void Dataset::load()
    {
        if (loaded) {
            return;
        }
        if (fileType == CSV) {
            load_csv();
        } else if (fileType == ARFF) {
            load_arff();
        } else if (fileType == RDATA) {
            load_rdata();
        }
        n_samples = Xv[0].size();
        n_features = Xv.size();
        if (numericFeaturesIdx.size() == 0) {
            numericFeatures = std::vector<bool>(n_features, false);
        } else {
            if (numericFeaturesIdx.at(0) == -1) {
                numericFeatures = std::vector<bool>(n_features, true);
            } else {
                numericFeatures = std::vector<bool>(n_features, false);
                for (auto i : numericFeaturesIdx) {
                    numericFeatures[i] = true;
                }
            }
        }
        // Build Tensors
        X = torch::zeros({ n_features, n_samples }, torch::kFloat32);
        for (int i = 0; i < features.size(); ++i) {
            X.index_put_({ i, "..." }, torch::tensor(Xv[i], torch::kFloat32));
        }
        y = torch::tensor(yv, torch::kInt32);
        loaded = true;
    }
    std::tuple<torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&> Dataset::getTrainTestTensors(std::vector<int>& train, std::vector<int>& test)
    {
        if (!loaded) {
            throw std::invalid_argument(message_dataset_not_loaded);
        }
        auto train_t = torch::tensor(train);
        int samples_train = train.size();
        int samples_test = test.size();
        auto test_t = torch::tensor(test);
        X_train = X.index({ "...", train_t });
        y_train = y.index({ train_t });
        X_test = X.index({ "...", test_t });
        y_test = y.index({ test_t });
        if (discretize) {
            auto discretizer = Discretization::instance()->create(discretizer_algorithm);
            auto X_train_d = torch::zeros({ n_features, samples_train }, torch::kInt32);
            auto X_test_d = torch::zeros({ n_features, samples_test }, torch::kInt32);
            for (auto feature = 0; feature < n_features; ++feature) {
                if (numericFeatures[feature]) {
                    auto feature_train = X_train.index({ feature, "..." });
                    auto feature_test = X_test.index({ feature, "..." });
                    auto feature_train_disc = discretizer->fit_transform_t(feature_train, y_train);
                    auto feature_test_disc = discretizer->transform_t(feature_test);
                    X_train_d.index_put_({ feature, "..." }, feature_train_disc);
                    X_test_d.index_put_({ feature, "..." }, feature_test_disc);
                } else {
                    X_train_d.index_put_({ feature, "..." }, X_train.index({ feature, "..." }).to(torch::kInt32));
                    X_test_d.index_put_({ feature, "..." }, X_test.index({ feature, "..." }).to(torch::kInt32));
                }
            }
            X_train = X_train_d;
            X_test = X_test_d;
            assert(X_train.dtype() == torch::kInt32);
            assert(X_test.dtype() == torch::kInt32);
            computeStates();
        }
        assert(y_train.dtype() == torch::kInt32);
        assert(y_test.dtype() == torch::kInt32);
        return { X_train, X_test, y_train, y_test };
    }
}
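getTrainTestTensors() above fits the discretizer on the training split only and then applies the same cut points to the test split, unlike the removed Dataset.cc, which discretized every column over all samples at load time. A self-contained illustration of that rule, with equal-width binning standing in for the MDLP-style discretizers the Discretization factory returns:

#include <algorithm>
#include <vector>

// Thresholds are learned from the training column alone, then the SAME
// thresholds are applied to both splits, so no test information leaks
// into the cut points. Assumes a non-empty train column and bins >= 1.
std::vector<float> fit_cutpoints(const std::vector<float>& train, int bins)
{
    auto [lo, hi] = std::minmax_element(train.begin(), train.end());
    std::vector<float> cuts;
    float width = (*hi - *lo) / bins;
    for (int b = 1; b < bins; ++b) {
        cuts.push_back(*lo + b * width);
    }
    return cuts;
}

std::vector<int> apply_cutpoints(const std::vector<float>& column, const std::vector<float>& cuts)
{
    std::vector<int> out;
    for (float v : column) {
        // Bin index = number of cut points less than or equal to v.
        out.push_back(static_cast<int>(std::upper_bound(cuts.begin(), cuts.end(), v) - cuts.begin()));
    }
    return out;
}

// Usage: fit on the train split only, transform both splits.
//   auto cuts = fit_cutpoints(feature_train, 4);
//   auto train_d = apply_cutpoints(feature_train, cuts);
//   auto test_d = apply_cutpoints(feature_test, cuts);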
@@ -4,75 +4,57 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "CPPFImdlp.h"
|
#include <tuple>
|
||||||
|
#include <common/DiscretizationRegister.h>
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
|
#include "SourceData.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
enum fileType_t { CSV, ARFF, RDATA };
|
|
||||||
class SourceData {
|
|
||||||
public:
|
|
||||||
SourceData(std::string source)
|
|
||||||
{
|
|
||||||
if (source == "Surcov") {
|
|
||||||
path = "datasets/";
|
|
||||||
fileType = CSV;
|
|
||||||
} else if (source == "Arff") {
|
|
||||||
path = "datasets/";
|
|
||||||
fileType = ARFF;
|
|
||||||
} else if (source == "Tanveer") {
|
|
||||||
path = "data/";
|
|
||||||
fileType = RDATA;
|
|
||||||
} else {
|
|
||||||
throw std::invalid_argument("Unknown source.");
|
|
||||||
}
|
|
||||||
}
|
|
-    std::string getPath()
-    {
-        return path;
-    }
-    fileType_t getFileType()
-    {
-        return fileType;
-    }
-private:
-    std::string path;
-    fileType_t fileType;
-};
 class Dataset {
+public:
+    Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType, std::vector<int> numericFeaturesIdx, std::string discretizer_algo = "none") :
+        path(path), name(name), className(className), discretize(discretize),
+        loaded(false), fileType(fileType), numericFeaturesIdx(numericFeaturesIdx), discretizer_algorithm(discretizer_algo)
+    {
+    };
+    explicit Dataset(const Dataset&);
+    std::string getName() const;
+    std::string getClassName() const;
+    int getNClasses() const;
+    std::vector<std::string> getLabels() const; // return the labels factorization result
+    std::vector<int> getClassesCounts() const;
+    std::vector<string> getFeatures() const;
+    std::map<std::string, std::vector<int>> getStates() const;
+    std::pair<vector<std::vector<float>>&, std::vector<int>&> getVectors();
+    std::pair<torch::Tensor&, torch::Tensor&> getTensors();
+    std::tuple<torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&> getTrainTestTensors(std::vector<int>& train, std::vector<int>& test);
+    int getNFeatures() const;
+    int getNSamples() const;
+    std::vector<bool>& getNumericFeatures() { return numericFeatures; }
+    void load();
+    const bool inline isLoaded() const { return loaded; };
 private:
     std::string path;
     std::string name;
     fileType_t fileType;
     std::string className;
     int n_samples{ 0 }, n_features{ 0 };
+    std::vector<int> numericFeaturesIdx;
+    std::string discretizer_algorithm;
+    std::vector<bool> numericFeatures; // true if feature is numeric
     std::vector<std::string> features;
+    std::vector<std::string> labels;
     std::map<std::string, std::vector<int>> states;
     bool loaded;
     bool discretize;
     torch::Tensor X, y;
+    torch::Tensor X_train, X_test, y_train, y_test;
     std::vector<std::vector<float>> Xv;
-    std::vector<std::vector<int>> Xd;
     std::vector<int> yv;
-    void buildTensors();
     void load_csv();
     void load_arff();
     void load_rdata();
     void computeStates();
     std::vector<mdlp::labels_t> discretizeDataset(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y);
-public:
-    Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {};
-    explicit Dataset(const Dataset&);
-    std::string getName() const;
-    std::string getClassName() const;
-    std::vector<string> getFeatures() const;
-    std::map<std::string, std::vector<int>> getStates() const;
-    std::pair<vector<std::vector<float>>&, std::vector<int>&> getVectors();
-    std::pair<vector<std::vector<int>>&, std::vector<int>&> getVectorsDiscretized();
-    std::pair<torch::Tensor&, torch::Tensor&> getTensors();
-    int getNFeatures() const;
-    int getNSamples() const;
-    void load();
-    const bool inline isLoaded() const { return loaded; };
 };
 };

 #endif
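Note: getTrainTestTensors above returns a std::tuple of torch::Tensor references to members filled in by load(), so callers can use structured bindings without copying the split tensors. A minimal standalone miniature of that reference-tuple idiom, using std::vector in place of torch::Tensor (Holder and getTrainTest are illustrative names, not repository API):

// Miniature of the reference-returning tuple idiom: the class keeps
// ownership, callers bind references, nothing is copied.
#include <iostream>
#include <tuple>
#include <vector>

class Holder {
public:
    std::tuple<std::vector<int>&, std::vector<int>&> getTrainTest()
    {
        return std::tie(train, test); // references to members, valid while Holder lives
    }
private:
    std::vector<int> train{ 1, 2, 3 };
    std::vector<int> test{ 4, 5 };
};

int main()
{
    Holder h;
    auto [train, test] = h.getTrainTest(); // structured bindings: references, not copies
    train.push_back(99);                   // mutates h's member
    std::cout << std::get<0>(h.getTrainTest()).back() << "\n"; // prints 99
}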
@@ -1,129 +0,0 @@
-#include "Datasets.h"
-#include <fstream>
-namespace platform {
-    void Datasets::load()
-    {
-        auto sd = SourceData(sfileType);
-        fileType = sd.getFileType();
-        path = sd.getPath();
-        ifstream catalog(path + "all.txt");
-        if (catalog.is_open()) {
-            std::string line;
-            while (getline(catalog, line)) {
-                if (line.empty() || line[0] == '#') {
-                    continue;
-                }
-                std::vector<std::string> tokens = split(line, ',');
-                std::string name = tokens[0];
-                std::string className;
-                if (tokens.size() == 1) {
-                    className = "-1";
-                } else {
-                    className = tokens[1];
-                }
-                datasets[name] = make_unique<Dataset>(path, name, className, discretize, fileType);
-            }
-            catalog.close();
-        } else {
-            throw std::invalid_argument("Unable to open catalog file. [" + path + "all.txt" + "]");
-        }
-    }
-    std::vector<std::string> Datasets::getNames()
-    {
-        std::vector<std::string> result;
-        transform(datasets.begin(), datasets.end(), back_inserter(result), [](const auto& d) { return d.first; });
-        return result;
-    }
-    std::vector<std::string> Datasets::getFeatures(const std::string& name) const
-    {
-        if (datasets.at(name)->isLoaded()) {
-            return datasets.at(name)->getFeatures();
-        } else {
-            throw std::invalid_argument("Dataset not loaded.");
-        }
-    }
-    map<std::string, std::vector<int>> Datasets::getStates(const std::string& name) const
-    {
-        if (datasets.at(name)->isLoaded()) {
-            return datasets.at(name)->getStates();
-        } else {
-            throw std::invalid_argument("Dataset not loaded.");
-        }
-    }
-    void Datasets::loadDataset(const std::string& name) const
-    {
-        if (datasets.at(name)->isLoaded()) {
-            return;
-        } else {
-            datasets.at(name)->load();
-        }
-    }
-    std::string Datasets::getClassName(const std::string& name) const
-    {
-        if (datasets.at(name)->isLoaded()) {
-            return datasets.at(name)->getClassName();
-        } else {
-            throw std::invalid_argument("Dataset not loaded.");
-        }
-    }
-    int Datasets::getNSamples(const std::string& name) const
-    {
-        if (datasets.at(name)->isLoaded()) {
-            return datasets.at(name)->getNSamples();
-        } else {
-            throw std::invalid_argument("Dataset not loaded.");
-        }
-    }
-    int Datasets::getNClasses(const std::string& name)
-    {
-        if (datasets.at(name)->isLoaded()) {
-            auto className = datasets.at(name)->getClassName();
-            if (discretize) {
-                auto states = getStates(name);
-                return states.at(className).size();
-            }
-            auto [Xv, yv] = getVectors(name);
-            return *std::max_element(yv.begin(), yv.end()) + 1;
-        } else {
-            throw std::invalid_argument("Dataset not loaded.");
-        }
-    }
-    std::vector<int> Datasets::getClassesCounts(const std::string& name) const
-    {
-        if (datasets.at(name)->isLoaded()) {
-            auto [Xv, yv] = datasets.at(name)->getVectors();
-            std::vector<int> counts(*std::max_element(yv.begin(), yv.end()) + 1);
-            for (auto y : yv) {
-                counts[y]++;
-            }
-            return counts;
-        } else {
-            throw std::invalid_argument("Dataset not loaded.");
-        }
-    }
-    pair<std::vector<std::vector<float>>&, std::vector<int>&> Datasets::getVectors(const std::string& name)
-    {
-        if (!datasets[name]->isLoaded()) {
-            datasets[name]->load();
-        }
-        return datasets[name]->getVectors();
-    }
-    pair<std::vector<std::vector<int>>&, std::vector<int>&> Datasets::getVectorsDiscretized(const std::string& name)
-    {
-        if (!datasets[name]->isLoaded()) {
-            datasets[name]->load();
-        }
-        return datasets[name]->getVectorsDiscretized();
-    }
-    pair<torch::Tensor&, torch::Tensor&> Datasets::getTensors(const std::string& name)
-    {
-        if (!datasets[name]->isLoaded()) {
-            datasets[name]->load();
-        }
-        return datasets[name]->getTensors();
-    }
-    bool Datasets::isDataset(const std::string& name) const
-    {
-        return datasets.find(name) != datasets.end();
-    }
-}
89 src/common/Datasets.cpp Normal file
@@ -0,0 +1,89 @@
+#include <fstream>
+#include "Datasets.h"
+#include <nlohmann/json.hpp>
+
+namespace platform {
+    using json = nlohmann::ordered_json;
+    const std::string message_dataset_not_loaded = "dataset not loaded.";
+    Datasets::Datasets(bool discretize, std::string sfileType, std::string discretizer_algorithm) :
+        discretize(discretize), sfileType(sfileType), discretizer_algorithm(discretizer_algorithm)
+    {
+        if ((discretizer_algorithm == "none" || discretizer_algorithm == "") && discretize) {
+            throw std::runtime_error("Can't discretize without discretization algorithm");
+        }
+        load();
+    }
+    void Datasets::load()
+    {
+        auto sd = SourceData(sfileType);
+        fileType = sd.getFileType();
+        path = sd.getPath();
+        ifstream catalog(path + "all.txt");
+        std::vector<int> numericFeaturesIdx;
+        if (!catalog.is_open()) {
+            throw std::invalid_argument("Unable to open catalog file. [" + path + "all.txt" + "]");
+        }
+        std::string line;
+        while (getline(catalog, line)) {
+            if (line.empty() || line[0] == '#') {
+                continue;
+            }
+            std::vector<std::string> tokens = split(line, ';');
+            std::string name = tokens[0];
+            std::string className;
+            numericFeaturesIdx.clear();
+            int size = tokens.size();
+            switch (size) {
+                case 1:
+                    className = "-1";
+                    numericFeaturesIdx.push_back(-1);
+                    break;
+                case 2:
+                    className = tokens[1];
+                    numericFeaturesIdx.push_back(-1);
+                    break;
+                case 3:
+                    {
+                        className = tokens[1];
+                        auto numericFeatures = tokens[2];
+                        if (numericFeatures == "all") {
+                            numericFeaturesIdx.push_back(-1);
+                        } else {
+                            if (numericFeatures != "none") {
+                                auto features = json::parse(numericFeatures);
+                                for (auto& f : features) {
+                                    numericFeaturesIdx.push_back(f);
+                                }
+                            }
+                        }
+                    }
+                    break;
+                default:
+                    throw std::invalid_argument("Invalid catalog file format.");
+            }
+            datasets[name] = make_unique<Dataset>(path, name, className, discretize, fileType, numericFeaturesIdx, discretizer_algorithm);
+        }
+        catalog.close();
+    }
+    std::vector<std::string> Datasets::getNames()
+    {
+        std::vector<std::string> result;
+        transform(datasets.begin(), datasets.end(), back_inserter(result), [](const auto& d) { return d.first; });
+        return result;
+    }
+    bool Datasets::isDataset(const std::string& name) const
+    {
+        return datasets.find(name) != datasets.end();
+    }
+    std::string Datasets::toString() const
+    {
+        std::string result;
+        std::string sep = "";
+        for (const auto& d : datasets) {
+            result += sep + d.first;
+            sep = ", ";
+        }
+        return "{" + result + "}";
+    }
+}
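Note: the new catalog format packs up to three ';'-separated fields per line (dataset name, class column, numeric-feature list). A minimal standalone sketch of the per-line dispatch, with no dependency on the repository's Dataset class; parseCatalogLine and splitLine are illustrative names, not repository API:

// Standalone sketch of the catalog-line dispatch used in Datasets::load().
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

std::vector<std::string> splitLine(const std::string& text, char delimiter)
{
    std::vector<std::string> tokens;
    std::stringstream ss(text);
    std::string token;
    while (std::getline(ss, token, delimiter))
        tokens.push_back(token);
    return tokens;
}

void parseCatalogLine(const std::string& line)
{
    auto tokens = splitLine(line, ';');
    std::string name = tokens.at(0), className = "-1", numeric = "all";
    switch (tokens.size()) {
        case 1: break;                                             // name only
        case 2: className = tokens[1]; break;                      // name;class
        case 3: className = tokens[1]; numeric = tokens[2]; break; // name;class;numeric
        default: throw std::invalid_argument("Invalid catalog file format.");
    }
    std::cout << name << " -> class=" << className << " numeric=" << numeric << "\n";
}

int main()
{
    parseCatalogLine("iris;class;all");
    parseCatalogLine("glass;Type;[0, 1, 2]");
    parseCatalogLine("mfeat-factors");
}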
@@ -3,28 +3,20 @@
 #include "Dataset.h"
 namespace platform {
     class Datasets {
+    public:
+        explicit Datasets(bool discretize, std::string sfileType, std::string discretizer_algorithm = "none");
+        std::vector<std::string> getNames();
+        bool isDataset(const std::string& name) const;
+        Dataset& getDataset(const std::string& name) const { return *datasets.at(name); }
+        std::string toString() const;
     private:
         std::string path;
         fileType_t fileType;
         std::string sfileType;
+        std::string discretizer_algorithm;
         std::map<std::string, std::unique_ptr<Dataset>> datasets;
         bool discretize;
         void load(); // Loads the list of datasets
-    public:
-        explicit Datasets(bool discretize, std::string sfileType) : discretize(discretize), sfileType(sfileType) { load(); };
-        std::vector<string> getNames();
-        std::vector<string> getFeatures(const std::string& name) const;
-        int getNSamples(const std::string& name) const;
-        std::string getClassName(const std::string& name) const;
-        int getNClasses(const std::string& name);
-        std::vector<int> getClassesCounts(const std::string& name) const;
-        std::map<std::string, std::vector<int>> getStates(const std::string& name) const;
-        std::pair<std::vector<std::vector<float>>&, std::vector<int>&> getVectors(const std::string& name);
-        std::pair<std::vector<std::vector<int>>&, std::vector<int>&> getVectorsDiscretized(const std::string& name);
-        std::pair<torch::Tensor&, torch::Tensor&> getTensors(const std::string& name);
-        bool isDataset(const std::string& name) const;
-        void loadDataset(const std::string& name) const;
     };
 };

 #endif
55 src/common/Discretization.cpp Normal file
@@ -0,0 +1,55 @@
+#include "Discretization.h"
+
+namespace platform {
+    // Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
+    Discretization* Discretization::factory = nullptr;
+    Discretization* Discretization::instance()
+    {
+        //manages singleton
+        if (factory == nullptr)
+            factory = new Discretization();
+        return factory;
+    }
+    void Discretization::registerFactoryFunction(const std::string& name,
+        function<mdlp::Discretizer* (void)> classFactoryFunction)
+    {
+        // register the class factory function
+        functionRegistry[name] = classFactoryFunction;
+    }
+    std::shared_ptr<mdlp::Discretizer> Discretization::create(const std::string& name)
+    {
+        mdlp::Discretizer* instance = nullptr;
+
+        // find name in the registry and call factory method.
+        auto it = functionRegistry.find(name);
+        if (it != functionRegistry.end())
+            instance = it->second();
+        // wrap instance in a shared ptr and return
+        if (instance != nullptr)
+            return std::unique_ptr<mdlp::Discretizer>(instance);
+        else
+            throw std::runtime_error("Discretizer not found: " + name);
+    }
+    std::vector<std::string> Discretization::getNames()
+    {
+        std::vector<std::string> names;
+        transform(functionRegistry.begin(), functionRegistry.end(), back_inserter(names),
+            [](const pair<std::string, function<mdlp::Discretizer* (void)>>& pair) { return pair.first; });
+        return names;
+    }
+    std::string Discretization::toString()
+    {
+        std::string result = "";
+        std::string sep = "";
+        for (const auto& pair : functionRegistry) {
+            result += sep + pair.first;
+            sep = ", ";
+        }
+        return "{" + result + "}";
+    }
+    RegistrarDiscretization::RegistrarDiscretization(const std::string& name, function<mdlp::Discretizer* (void)> classFactoryFunction)
+    {
+        // register the class factory function
+        Discretization::instance()->registerFactoryFunction(name, classFactoryFunction);
+    }
+}
33 src/common/Discretization.h Normal file
@@ -0,0 +1,33 @@
+#ifndef DISCRETIZATION_H
+#define DISCRETIZATION_H
+#include <map>
+#include <memory>
+#include <string>
+#include <functional>
+#include <vector>
+#include <fimdlp/Discretizer.h>
+#include <fimdlp/BinDisc.h>
+#include <fimdlp/CPPFImdlp.h>
+namespace platform {
+    class Discretization {
+    public:
+        Discretization(Discretization&) = delete;
+        void operator=(const Discretization&) = delete;
+        // Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
+        static Discretization* instance();
+        std::shared_ptr<mdlp::Discretizer> create(const std::string& name);
+        void registerFactoryFunction(const std::string& name,
+            function<mdlp::Discretizer* (void)> classFactoryFunction);
+        std::vector<string> getNames();
+        std::string toString();
+    private:
+        map<std::string, function<mdlp::Discretizer* (void)>> functionRegistry;
+        static Discretization* factory; //singleton
+        Discretization() {};
+    };
+    class RegistrarDiscretization {
+    public:
+        RegistrarDiscretization(const std::string& className, function<mdlp::Discretizer* (void)> classFactoryFunction);
+    };
+}
+#endif
38 src/common/DiscretizationRegister.h Normal file
@@ -0,0 +1,38 @@
+#ifndef DISCRETIZATIONREGISTER_H
+#define DISCRETIZATIONREGISTER_H
+#include <common/Discretization.h>
+static platform::RegistrarDiscretization registrarM("mdlp",
+    [](void) -> mdlp::Discretizer* { return new mdlp::CPPFImdlp();});
+static platform::RegistrarDiscretization registrarBU3("bin3u",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(3, mdlp::strategy_t::UNIFORM);});
+static platform::RegistrarDiscretization registrarBQ3("bin3q",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(3, mdlp::strategy_t::QUANTILE);});
+static platform::RegistrarDiscretization registrarBU4("bin4u",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(4, mdlp::strategy_t::UNIFORM);});
+static platform::RegistrarDiscretization registrarBQ4("bin4q",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(4, mdlp::strategy_t::QUANTILE);});
+static platform::RegistrarDiscretization registrarBU5("bin5u",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(5, mdlp::strategy_t::UNIFORM);});
+static platform::RegistrarDiscretization registrarBQ5("bin5q",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(5, mdlp::strategy_t::QUANTILE);});
+static platform::RegistrarDiscretization registrarBU6("bin6u",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(6, mdlp::strategy_t::UNIFORM);});
+static platform::RegistrarDiscretization registrarBQ6("bin6q",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(6, mdlp::strategy_t::QUANTILE);});
+static platform::RegistrarDiscretization registrarBU7("bin7u",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(7, mdlp::strategy_t::UNIFORM);});
+static platform::RegistrarDiscretization registrarBQ7("bin7q",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(7, mdlp::strategy_t::QUANTILE);});
+static platform::RegistrarDiscretization registrarBU8("bin8u",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(8, mdlp::strategy_t::UNIFORM);});
+static platform::RegistrarDiscretization registrarBQ8("bin8q",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(8, mdlp::strategy_t::QUANTILE);});
+static platform::RegistrarDiscretization registrarBU9("bin9u",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(9, mdlp::strategy_t::UNIFORM);});
+static platform::RegistrarDiscretization registrarBQ9("bin9q",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(9, mdlp::strategy_t::QUANTILE);});
+static platform::RegistrarDiscretization registrarBU10("bin10u",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(10, mdlp::strategy_t::UNIFORM);});
+static platform::RegistrarDiscretization registrarBQ10("bin10q",
+    [](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(10, mdlp::strategy_t::QUANTILE);});
+#endif
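Note: the three files above implement a classic self-registering factory. Each file-scope RegistrarDiscretization constructor runs at static-initialization time and adds a lambda to the singleton's registry, so create("bin3u") can build a discretizer by name with no central switch statement. A minimal self-contained sketch of the same pattern; Quantizer, Registry, and Registrar are illustrative stand-ins, not the repository's types:

// Minimal sketch of the self-registering factory pattern used above.
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

struct Quantizer {                        // stand-in for mdlp::Discretizer
    virtual ~Quantizer() = default;
    virtual const char* name() const = 0;
};

class Registry {                          // stand-in for platform::Discretization
public:
    static Registry& instance() { static Registry r; return r; }
    void add(const std::string& key, std::function<Quantizer*()> make) { registry[key] = make; }
    std::shared_ptr<Quantizer> create(const std::string& key)
    {
        auto it = registry.find(key);
        if (it == registry.end()) throw std::runtime_error("Discretizer not found: " + key);
        return std::shared_ptr<Quantizer>(it->second());
    }
private:
    std::map<std::string, std::function<Quantizer*()>> registry;
};

struct Registrar {                        // runs at static-init time, like RegistrarDiscretization
    Registrar(const std::string& key, std::function<Quantizer*()> make)
    { Registry::instance().add(key, make); }
};

struct Uniform3 : Quantizer { const char* name() const override { return "bin3u"; } };
static Registrar reg3u("bin3u", [] { return new Uniform3(); });

int main()
{
    std::cout << Registry::instance().create("bin3u")->name() << "\n"; // prints bin3u
}

This sketch uses a Meyers singleton for brevity; the repository keeps an explicit static pointer instead, as its comment on the CodeProject article notes.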
@@ -13,9 +13,55 @@ namespace platform {
     class DotEnv {
     private:
         std::map<std::string, std::string> env;
+        std::map<std::string, std::vector<std::string>> valid;
     public:
-        DotEnv()
+        DotEnv(bool create = false)
         {
+            valid =
+            {
+                {"depth", {"any"}},
+                {"discretize", {"0", "1"}},
+                {"discretize_algo", {"mdlp", "bin3u", "bin3q", "bin4u", "bin4q", "bin5q", "bin5u", "bin6q", "bin6u", "bin7q", "bin7u", "bin8q", "bin8u", "bin9q", "bin9u", "bin10q", "bin10u"}},
+                {"experiment", {"discretiz", "odte", "covid", "Test"}},
+                {"fit_features", {"0", "1"}},
+                {"framework", {"bulma", "bootstrap"}},
+                {"ignore_nan", {"0", "1"}},
+                {"leaves", {"any"}},
+                {"margin", {"0.1", "0.2", "0.3"}},
+                {"model", {"any"}},
+                {"n_folds", {"5", "10"}},
+                {"nodes", {"any"}},
+                {"platform", {"any"}},
+                {"stratified", {"0", "1"}},
+                {"score", {"accuracy", "roc-auc-ovr"}},
+                {"seeds", {"any"}},
+                {"smooth_strat", {"ORIGINAL", "LAPLACE", "CESTNIK"}},
+                {"source_data", {"Arff", "Tanveer", "Surcov", "Test"}},
+            };
+            if (create) {
+                // For testing purposes
+                std::ofstream file(".env");
+                file << "experiment=Test" << std::endl;
+                file << "source_data=Test" << std::endl;
+                file << "margin=0.1" << std::endl;
+                file << "score=accuracy" << std::endl;
+                file << "platform=um790Linux" << std::endl;
+                file << "n_folds=5" << std::endl;
+                file << "discretize_algo=mdlp" << std::endl;
+                file << "smooth_strat=ORIGINAL" << std::endl;
+                file << "stratified=0" << std::endl;
+                file << "model=TAN" << std::endl;
+                file << "seeds=[271]" << std::endl;
+                file << "discretize=0" << std::endl;
+                file << "ignore_nan=0" << std::endl;
+                file << "nodes=Nodes" << std::endl;
+                file << "leaves=Edges" << std::endl;
+                file << "depth=States" << std::endl;
+                file << "fit_features=0" << std::endl;
+                file << "framework=bulma" << std::endl;
+                file << "margin=0.1" << std::endl;
+                file.close();
+            }
             std::ifstream file(".env");
             if (!file.is_open()) {
                 std::cerr << "File .env not found" << std::endl;
@@ -30,12 +76,62 @@ namespace platform {
                 std::istringstream iss(line);
                 std::string key, value;
                 if (std::getline(iss, key, '=') && std::getline(iss, value)) {
+                    key = trim(key);
+                    value = trim(value);
+                    parse(key, value);
                     env[key] = value;
                 }
             }
+            parseEnv();
+        }
+        void parse(const std::string& key, const std::string& value)
+        {
+            if (valid.find(key) == valid.end()) {
+                std::cerr << "Invalid key in .env: " << key << std::endl;
+                exit(1);
+            }
+            if (valid[key].front() == "any") {
+                return;
+            }
+            if (std::find(valid[key].begin(), valid[key].end(), value) == valid[key].end()) {
+                std::cerr << "Invalid value in .env: " << key << " = " << value << std::endl;
+                exit(1);
+            }
+        }
+        std::vector<std::string> valid_tokens(const std::string& key)
+        {
+            if (valid.find(key) == valid.end()) {
+                return {};
+            }
+            return valid.at(key);
+        }
+        std::string valid_values(const std::string& key)
+        {
+            std::string valid_values = "{", sep = "";
+            if (valid.find(key) == valid.end()) {
+                return "{}";
+            }
+            for (const auto& value : valid.at(key)) {
+                valid_values += sep + value;
+                sep = ", ";
+            }
+            return valid_values + "}";
+        }
+        void parseEnv()
+        {
+            for (auto& [key, values] : valid) {
+                if (env.find(key) == env.end()) {
+                    std::cerr << "Key not found in .env: " << key << ", valid values: " << valid_values(key) << std::endl;
+                    exit(1);
+                }
+            }
         }
         std::string get(const std::string& key)
         {
+            if (env.find(key) == env.end()) {
+                std::cerr << "Key not found in .env: " << key << std::endl;
+                exit(1);
+            }
             return env.at(key);
         }
         std::vector<int> getSeeds()
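Note: DotEnv now rejects unknown keys and out-of-range values instead of silently accepting any pair, and parseEnv() additionally requires every known key to be present. A compact standalone sketch of that validate-while-parsing idea, reading from an in-memory stream so it runs without a .env file; the key table is a reduced, illustrative subset:

// Sketch of validated key=value parsing, in the spirit of the new DotEnv.
#include <algorithm>
#include <iostream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

std::map<std::string, std::vector<std::string>> valid = {
    {"n_folds", {"5", "10"}},
    {"score", {"accuracy", "roc-auc-ovr"}},
    {"model", {"any"}}, // "any" disables value checking for the key
};

void parse(const std::string& key, const std::string& value)
{
    auto it = valid.find(key);
    if (it == valid.end())
        throw std::runtime_error("Invalid key in .env: " + key);
    if (it->second.front() == "any")
        return;
    if (std::find(it->second.begin(), it->second.end(), value) == it->second.end())
        throw std::runtime_error("Invalid value in .env: " + key + " = " + value);
}

int main()
{
    std::istringstream env("n_folds=5\nscore=accuracy\nmodel=TAN\n");
    std::string line;
    while (std::getline(env, line)) {
        auto pos = line.find('=');
        parse(line.substr(0, pos), line.substr(pos + 1));
    }
    std::cout << "all keys valid\n";
}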
@@ -6,15 +6,30 @@
 namespace platform {
     class Paths {
     public:
-        static std::string results() { return "results/"; }
-        static std::string hiddenResults() { return "hidden_results/"; }
-        static std::string excel() { return "excel/"; }
-        static std::string grid() { return "grid/"; }
+        static std::string createIfNotExists(const std::string& folder)
+        {
+            if (!std::filesystem::exists(folder)) {
+                std::filesystem::create_directory(folder);
+            }
+            return folder;
+        }
+        static std::string results() { return createIfNotExists("results/"); }
+        static std::string hiddenResults() { return createIfNotExists("hidden_results/"); }
+        static std::string excel() { return createIfNotExists("excel/"); }
+        static std::string grid() { return createIfNotExists("grid/"); }
+        static std::string graphs() { return createIfNotExists("graphs/"); }
+        static std::string tex() { return createIfNotExists("tex/"); }
         static std::string datasets()
         {
             auto env = platform::DotEnv();
             return env.get("source_data");
         }
+        static std::string experiment_file(const std::string& fileName, bool discretize, bool stratified, int seed, int nfold)
+        {
+            std::string disc = discretize ? "_disc_" : "_ndisc_";
+            std::string strat = stratified ? "strat_" : "nstrat_";
+            return "datasets_experiment/" + fileName + disc + strat + std::to_string(seed) + "_" + std::to_string(nfold) + ".json";
+        }
         static void createPath(const std::string& path)
         {
             // Create directory if it does not exist
@@ -25,6 +40,14 @@ namespace platform {
                 throw std::runtime_error("Could not create directory " + path);
             }
         }
+        static std::string bestResultsFile(const std::string& score, const std::string& model)
+        {
+            return "best_results_" + score + "_" + model + ".json";
+        }
+        static std::string bestResultsExcel(const std::string& score)
+        {
+            return "BestResults_" + score + ".xlsx";
+        }
         static std::string excelResults() { return "some_results.xlsx"; }
         static std::string grid_input(const std::string& model)
         {
@@ -34,6 +57,22 @@ namespace platform {
         {
             return grid() + "grid_" + model + "_output.json";
         }
+        static std::string tex_output()
+        {
+            return "results.tex";
+        }
+        static std::string md_output()
+        {
+            return "results.md";
+        }
+        static std::string tex_post_hoc()
+        {
+            return "post_hoc.tex";
+        }
+        static std::string md_post_hoc()
+        {
+            return "post_hoc.md";
+        }
     };
 }
 #endif
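Note: path accessors now create their directory on first use, so callers no longer need a separate mkdir step before writing reports. The idiom in miniature with std::filesystem (standalone, C++17; the directory name is illustrative):

// Sketch: return a folder path, creating the directory on first use.
#include <filesystem>
#include <iostream>
#include <string>

std::string createIfNotExists(const std::string& folder)
{
    if (!std::filesystem::exists(folder)) {
        std::filesystem::create_directory(folder); // single level, like the header above
    }
    return folder;
}

int main()
{
    // Both calls return "demo_results/"; only the first touches the filesystem.
    std::cout << createIfNotExists("demo_results/") << "\n";
    std::cout << createIfNotExists("demo_results/") << "\n";
}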
38 src/common/SourceData.h.in Normal file
@@ -0,0 +1,38 @@
+#ifndef SOURCEDATA_H
+#define SOURCEDATA_H
+namespace platform {
+    enum fileType_t { CSV, ARFF, RDATA };
+    class SourceData {
+    public:
+        SourceData(std::string source)
+        {
+            if (source == "Surcov") {
+                path = "datasets/";
+                fileType = CSV;
+            } else if (source == "Arff") {
+                path = "datasets/";
+                fileType = ARFF;
+            } else if (source == "Tanveer") {
+                path = "data/";
+                fileType = RDATA;
+            } else if (source == "Test") {
+                path = "@TEST_DATA_PATH@/";
+                fileType = ARFF;
+            } else {
+                throw std::invalid_argument("Unknown source.");
+            }
+        }
+        std::string getPath()
+        {
+            return path;
+        }
+        fileType_t getFileType()
+        {
+            return fileType;
+        }
+    private:
+        std::string path;
+        fileType_t fileType;
+    };
+}
+#endif
@@ -9,10 +9,13 @@ namespace platform {
     inline static const std::string black_star{ "\u2605" };
     inline static const std::string cross{ "\u2717" };
     inline static const std::string upward_arrow{ "\u27B6" };
-    inline static const std::string down_arrow{ "\u27B4" };
+    inline static const std::string downward_arrow{ "\u27B4" };
+    inline static const std::string up_arrow{ "\u2B06" };
+    inline static const std::string down_arrow{ "\u2B07" };
+    inline static const std::string ellipsis{ "\u2026" };
     inline static const std::string equal_best{ check_mark };
     inline static const std::string better_best{ black_star };
     inline static const std::string notebook{ "\U0001F5C8" };
 };
 }
-#endif // !SYMBOLS_H
+#endif
@@ -40,4 +40,4 @@ namespace platform {
     }
 };
 } /* namespace platform */
-#endif /* TIMER_H */
+#endif
@@ -3,16 +3,16 @@
 #include <sstream>
 #include <string>
 #include <vector>
+#include <algorithm>
+#include <torch/torch.h>
 namespace platform {
-    //static std::vector<std::string> split(const std::string& text, char delimiter);
-    static std::vector<std::string> split(const std::string& text, char delimiter)
+    template <typename T>
+    std::vector<T> tensorToVector(const torch::Tensor& tensor)
     {
-        std::vector<std::string> result;
-        std::stringstream ss(text);
-        std::string token;
-        while (std::getline(ss, token, delimiter)) {
-            result.push_back(token);
-        }
+        torch::Tensor contig_tensor = tensor.contiguous();
+        auto num_elements = contig_tensor.numel();
+        const T* tensor_data = contig_tensor.data_ptr<T>();
+        std::vector<T> result(tensor_data, tensor_data + num_elements);
         return result;
     }
     static std::string trim(const std::string& str)
@@ -26,5 +26,45 @@ namespace platform {
         }).base(), result.end());
         return result;
     }
+    static std::vector<std::string> split(const std::string& text, char delimiter)
+    {
+        std::vector<std::string> result;
+        std::stringstream ss(text);
+        std::string token;
+        while (std::getline(ss, token, delimiter)) {
+            result.push_back(trim(token));
+        }
+        return result;
+    }
+    inline double compute_std(std::vector<double> values, double mean)
+    {
+        // Compute standard devation of the values
+        double sum = 0.0;
+        for (const auto& value : values) {
+            sum += std::pow(value - mean, 2);
+        }
+        double variance = sum / values.size();
+        return std::sqrt(variance);
+    }
+    inline std::string get_date()
+    {
+        time_t rawtime;
+        tm* timeinfo;
+        time(&rawtime);
+        timeinfo = std::localtime(&rawtime);
+        std::ostringstream oss;
+        oss << std::put_time(timeinfo, "%Y-%m-%d");
+        return oss.str();
+    }
+    inline std::string get_time()
+    {
+        time_t rawtime;
+        tm* timeinfo;
+        time(&rawtime);
+        timeinfo = std::localtime(&rawtime);
+        std::ostringstream oss;
+        oss << std::put_time(timeinfo, "%H:%M:%S");
+        return oss.str();
+    }
 }
 #endif
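Note: compute_std above takes the mean as a parameter and returns the population standard deviation, sqrt(sum((x - mean)^2) / n), dividing by n rather than n - 1. A quick standalone check of that formula:

// Standalone check of the population standard deviation computed above.
#include <cmath>
#include <iostream>
#include <numeric>
#include <vector>

double compute_std(std::vector<double> values, double mean)
{
    double sum = 0.0;
    for (const auto& value : values)
        sum += std::pow(value - mean, 2);
    return std::sqrt(sum / values.size());
}

int main()
{
    std::vector<double> scores{ 0.90, 0.92, 0.94 };
    double mean = std::accumulate(scores.begin(), scores.end(), 0.0) / scores.size(); // 0.92
    std::cout << compute_std(scores, mean) << "\n"; // ~0.01633 (population std, divides by n)
}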
@@ -1,5 +1,5 @@
-#include "GridData.h"
 #include <fstream>
+#include "GridData.h"
+
 namespace platform {
     GridData::GridData(const std::string& fileName)
@@ -6,7 +6,7 @@
 #include <nlohmann/json.hpp>

 namespace platform {
-    using json = nlohmann::json;
+    using json = nlohmann::ordered_json;
     const std::string ALL_DATASETS = "all";
     class GridData {
     public:
@@ -23,4 +23,4 @@ namespace platform {
         std::map<std::string, json> grid;
     };
 } /* namespace platform */
-#endif /* GRIDDATA_H */
+#endif
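Note: several files in this comparison switch from nlohmann::json to nlohmann::ordered_json. The difference is key ordering: nlohmann::json stores objects sorted by key, while ordered_json keeps insertion order, which makes dumped task lists and reports stable and human-diffable. A small sketch of the behavioral difference (requires the nlohmann/json single-header library):

// Why ordered_json: json sorts object keys, ordered_json keeps insertion order.
#include <iostream>
#include <nlohmann/json.hpp>

int main()
{
    nlohmann::json sorted;
    nlohmann::ordered_json ordered;
    for (auto key : { "model", "dataset", "seed" }) {
        sorted[key] = 1;
        ordered[key] = 1;
    }
    std::cout << sorted.dump() << "\n";  // {"dataset":1,"model":1,"seed":1}
    std::cout << ordered.dump() << "\n"; // {"model":1,"dataset":1,"seed":1}
}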
@@ -1,33 +1,15 @@
 #include <iostream>
 #include <cstddef>
 #include <torch/torch.h>
+#include <folding.hpp>
+#include "main/Models.h"
+#include "common/Paths.h"
+#include "common/Colors.h"
+#include "common/Utils.h"
 #include "GridSearch.h"
-#include "Models.h"
-#include "Paths.h"
-#include "folding.hpp"
-#include "Colors.h"

 namespace platform {
-    std::string get_date()
-    {
-        time_t rawtime;
-        tm* timeinfo;
-        time(&rawtime);
-        timeinfo = std::localtime(&rawtime);
-        std::ostringstream oss;
-        oss << std::put_time(timeinfo, "%Y-%m-%d");
-        return oss.str();
-    }
-    std::string get_time()
-    {
-        time_t rawtime;
-        tm* timeinfo;
-        time(&rawtime);
-        timeinfo = std::localtime(&rawtime);
-        std::ostringstream oss;
-        oss << std::put_time(timeinfo, "%H:%M:%S");
-        return oss.str();
-    }
     std::string get_color_rank(int rank)
     {
         auto colors = { Colors::WHITE(), Colors::RED(), Colors::GREEN(), Colors::BLUE(), Colors::MAGENTA(), Colors::CYAN() };
@@ -103,11 +85,11 @@ namespace platform {
         std::mt19937 g{ 271 }; // Use fixed seed to obtain the same shuffle
         std::shuffle(tasks.begin(), tasks.end(), g);
         std::cout << get_color_rank(rank) << "* Number of tasks: " << tasks.size() << std::endl;
-        std::cout << "|";
+        std::cout << separator;
         for (int i = 0; i < tasks.size(); ++i) {
            std::cout << (i + 1) % 10;
         }
-        std::cout << "|" << std::endl << "|" << std::flush;
+        std::cout << separator << std::endl << separator << std::flush;
         return tasks;
     }
     void process_task_mpi_consumer(struct ConfigGrid& config, struct ConfigMPI& config_mpi, json& tasks, int n_task, Datasets& datasets, Task_Result* result)
@@ -118,17 +100,18 @@ namespace platform {
         json task = tasks[n_task];
         auto model = config.model;
         auto grid = GridData(Paths::grid_input(model));
-        auto dataset = task["dataset"].get<std::string>();
+        auto dataset_name = task["dataset"].get<std::string>();
         auto idx_dataset = task["idx_dataset"].get<int>();
         auto seed = task["seed"].get<int>();
         auto n_fold = task["fold"].get<int>();
         bool stratified = config.stratified;
         // Generate the hyperparamters combinations
-        auto combinations = grid.getGrid(dataset);
-        auto [X, y] = datasets.getTensors(dataset);
-        auto states = datasets.getStates(dataset);
-        auto features = datasets.getFeatures(dataset);
-        auto className = datasets.getClassName(dataset);
+        auto& dataset = datasets.getDataset(dataset_name);
+        auto combinations = grid.getGrid(dataset_name);
+        dataset.load();
+        auto [X, y] = dataset.getTensors();
+        auto features = dataset.getFeatures();
+        auto className = dataset.getClassName();
         //
         // Start working on task
         //
@@ -138,14 +121,11 @@ namespace platform {
         else
             fold = new folding::KFold(config.n_folds, y.size(0), seed);
         auto [train, test] = fold->getFold(n_fold);
-        auto train_t = torch::tensor(train);
-        auto test_t = torch::tensor(test);
-        auto X_train = X.index({ "...", train_t });
-        auto y_train = y.index({ train_t });
-        auto X_test = X.index({ "...", test_t });
-        auto y_test = y.index({ test_t });
+        auto [X_train, X_test, y_train, y_test] = dataset.getTrainTestTensors(train, test);
+        auto states = dataset.getStates(); // Get the states of the features Once they are discretized
         double best_fold_score = 0.0;
         int best_idx_combination = -1;
+        bayesnet::Smoothing_t smoothing = bayesnet::Smoothing_t::NONE;
         json best_fold_hyper;
         for (int idx_combination = 0; idx_combination < combinations.size(); ++idx_combination) {
             auto hyperparam_line = combinations[idx_combination];
@@ -168,10 +148,10 @@ namespace platform {
             // Build Classifier with selected hyperparameters
             auto clf = Models::instance()->create(config.model);
             auto valid = clf->getValidHyperparameters();
-            hyperparameters.check(valid, dataset);
-            clf->setHyperparameters(hyperparameters.get(dataset));
+            hyperparameters.check(valid, dataset_name);
+            clf->setHyperparameters(hyperparameters.get(dataset_name));
             // Train model
-            clf->fit(X_nested_train, y_nested_train, features, className, states);
+            clf->fit(X_nested_train, y_nested_train, features, className, states, smoothing);
             // Test model
             score += clf->score(X_nested_test, y_nested_test);
         }
@@ -188,9 +168,9 @@ namespace platform {
         auto hyperparameters = platform::HyperParameters(datasets.getNames(), best_fold_hyper);
         auto clf = Models::instance()->create(config.model);
         auto valid = clf->getValidHyperparameters();
-        hyperparameters.check(valid, dataset);
+        hyperparameters.check(valid, dataset_name);
         clf->setHyperparameters(best_fold_hyper);
-        clf->fit(X_train, y_train, features, className, states);
+        clf->fit(X_train, y_train, features, className, states, smoothing);
         best_fold_score = clf->score(X_test, y_test);
         // Return the result
         result->idx_dataset = task["idx_dataset"].get<int>();
@@ -373,14 +353,16 @@ namespace platform {
         MPI_Bcast(msg, tasks_size + 1, MPI_CHAR, config_mpi.manager, MPI_COMM_WORLD);
         tasks = json::parse(msg);
         delete[] msg;
-        auto datasets = Datasets(config.discretize, Paths::datasets());
+        auto env = platform::DotEnv();
+        auto datasets = Datasets(config.discretize, Paths::datasets(), env.get("discretize_algo"));
+
         if (config_mpi.rank == config_mpi.manager) {
             //
             // 2a. Producer delivers the tasks to the consumers
             //
             auto datasets_names = filterDatasets(datasets);
             json all_results = producer(datasets_names, tasks, config_mpi, MPI_Result);
-            std::cout << get_color_rank(config_mpi.rank) << "|" << std::endl;
+            std::cout << get_color_rank(config_mpi.rank) << separator << std::endl;
             //
             // 3. Manager select the bests sccores for each dataset
             //
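Note: in the hunks above, tasks are shuffled with a fixed std::mt19937 seed (271), so every MPI rank derives the same task order without exchanging a message. A standalone illustration; keep in mind std::shuffle's output is implementation-defined, so the sequences only match across processes running the same binary and standard library, which is the MPI case here:

// Illustration: a fixed mt19937 seed gives each process the same shuffle.
#include <algorithm>
#include <iostream>
#include <random>
#include <vector>

int main()
{
    std::vector<int> tasks{ 0, 1, 2, 3, 4, 5, 6, 7 };
    std::mt19937 g{ 271 };                // same fixed seed on every rank
    std::shuffle(tasks.begin(), tasks.end(), g);
    for (int t : tasks) std::cout << t << ' '; // same order in every identical process
    std::cout << '\n';
}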
@@ -4,13 +4,13 @@
 #include <map>
 #include <mpi.h>
 #include <nlohmann/json.hpp>
-#include "Datasets.h"
-#include "HyperParameters.h"
+#include "common/Datasets.h"
+#include "common/Timer.h"
+#include "main/HyperParameters.h"
 #include "GridData.h"
-#include "Timer.h"

 namespace platform {
-    using json = nlohmann::json;
+    using json = nlohmann::ordered_json;
     struct ConfigGrid {
         std::string model;
         std::string score;
@@ -55,6 +55,7 @@ namespace platform {
         struct ConfigGrid config;
         json build_tasks_mpi(int rank);
         Timer timer; // used to measure the time of the whole process
+        const std::string separator = "|";
     };
 } /* namespace platform */
-#endif /* GRIDSEARCH_H */
+#endif
@@ -1,23 +0,0 @@
-#ifndef DATASETS_EXCEL_H
-#define DATASETS_EXCEL_H
-#include "ExcelFile.h"
-#include <vector>
-#include <map>
-#include <nlohmann/json.hpp>
-
-using json = nlohmann::json;
-
-namespace platform {
-
-    class DatasetsExcel : public ExcelFile {
-    public:
-        explicit DatasetsExcel(json& data);
-        ~DatasetsExcel();
-        void report();
-    private:
-        void formatColumns(int dataset, int balance);
-        json data;
-
-    };
-}
-#endif //DATASETS_EXCEL_H
@@ -1,80 +0,0 @@
-#include <iostream>
-#include <locale>
-#include <argparse/argparse.hpp>
-#include <nlohmann/json.hpp>
-#include "Paths.h"
-#include "Colors.h"
-#include "Datasets.h"
-#include "DatasetsExcel.h"
-#include "config.h"
-
-const int BALANCE_LENGTH = 75;
-
-struct separated : numpunct<char> {
-    char do_decimal_point() const { return ','; }
-    char do_thousands_sep() const { return '.'; }
-    std::string do_grouping() const { return "\03"; }
-};
-
-std::string outputBalance(const std::string& balance)
-{
-    auto temp = std::string(balance);
-    while (temp.size() > BALANCE_LENGTH - 1) {
-        auto part = temp.substr(0, BALANCE_LENGTH);
-        std::cout << part << std::endl;
-        std::cout << setw(52) << " ";
-        temp = temp.substr(BALANCE_LENGTH);
-    }
-    return temp;
-}
-
-int main(int argc, char** argv)
-{
-    auto datasets = platform::Datasets(false, platform::Paths::datasets());
-    argparse::ArgumentParser program("b_list", { project_version.begin(), project_version.end() });
-    program.add_argument("--excel")
-        .help("Output in Excel format")
-        .default_value(false)
-        .implicit_value(true);
-    program.parse_args(argc, argv);
-    auto excel = program.get<bool>("--excel");
-    locale mylocale(std::cout.getloc(), new separated);
-    locale::global(mylocale);
-    std::cout.imbue(mylocale);
-    std::cout << Colors::GREEN() << " # Dataset Sampl. Feat. Cls Balance" << std::endl;
-    std::string balanceBars = std::string(BALANCE_LENGTH, '=');
-    std::cout << "=== ============================== ====== ===== === " << balanceBars << std::endl;
-    int num = 0;
-    json data;
-    for (const auto& dataset : datasets.getNames()) {
-        auto color = num % 2 ? Colors::CYAN() : Colors::BLUE();
-        std::cout << color << setw(3) << right << num++ << " ";
-        std::cout << setw(30) << left << dataset << " ";
-        datasets.loadDataset(dataset);
-        auto nSamples = datasets.getNSamples(dataset);
-        std::cout << setw(6) << right << nSamples << " ";
-        std::cout << setw(5) << right << datasets.getFeatures(dataset).size() << " ";
-        std::cout << setw(3) << right << datasets.getNClasses(dataset) << " ";
-        std::stringstream oss;
-        std::string sep = "";
-        for (auto number : datasets.getClassesCounts(dataset)) {
-            oss << sep << std::setprecision(2) << fixed << (float)number / nSamples * 100.0 << "% (" << number << ")";
-            sep = " / ";
-        }
-        auto balance = outputBalance(oss.str());
-        std::cout << balance << std::endl;
-        // Store data for Excel report
-        data[dataset] = json::object();
-        data[dataset]["samples"] = nSamples;
-        data[dataset]["features"] = datasets.getFeatures(dataset).size();
-        data[dataset]["classes"] = datasets.getNClasses(dataset);
-        data[dataset]["balance"] = oss.str();
-    }
-    std::cout << Colors::RESET() << std::endl;
-    if (excel) {
-        auto report = platform::DatasetsExcel(data);
-        report.report();
-        std::cout << "Output saved in " << report.getFileName() << std::endl;
-    }
-    return 0;
-}
@@ -1,175 +0,0 @@
-#include "Experiment.h"
-#include "Datasets.h"
-#include "Models.h"
-#include "ReportConsole.h"
-#include "Paths.h"
-namespace platform {
-    using json = nlohmann::json;
-
-    void Experiment::saveResult()
-    {
-        result.save();
-    }
-    void Experiment::report()
-    {
-        ReportConsole report(result.getJson());
-        report.show();
-    }
-    void Experiment::show()
-    {
-        std::cout << result.getJson().dump(4) << std::endl;
-    }
-    void Experiment::go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score)
-    {
-        for (auto fileName : filesToProcess) {
-            if (fileName.size() > max_name)
-                max_name = fileName.size();
-        }
-        std::cout << Colors::MAGENTA() << "*** Starting experiment: " << result.getTitle() << " ***" << Colors::RESET() << std::endl << std::endl;
-        if (!quiet) {
-            std::cout << Colors::GREEN() << " Status Meaning" << std::endl;
-            std::cout << " ------ --------------------------------" << Colors::RESET() << std::endl;
-            std::cout << " ( " << Colors::GREEN() << "a" << Colors::RESET() << " ) Fitting model with train dataset" << std::endl;
-            std::cout << " ( " << Colors::GREEN() << "b" << Colors::RESET() << " ) Scoring train dataset" << std::endl;
-            std::cout << " ( " << Colors::GREEN() << "c" << Colors::RESET() << " ) Scoring test dataset" << std::endl << std::endl;
-            std::cout << Colors::YELLOW() << "Note: fold number in this color means fitting had issues such as not using all features in BoostAODE classifier" << std::endl << std::endl;
-            std::cout << Colors::GREEN() << left << " # " << setw(max_name) << "Dataset" << " #Samp #Feat Seed Status" << std::endl;
-            std::cout << " --- " << string(max_name, '-') << " ----- ----- ---- " << string(4 + 3 * nfolds, '-') << Colors::RESET() << std::endl;
-        }
-        int num = 0;
-        for (auto fileName : filesToProcess) {
-            if (!quiet)
-                std::cout << " " << setw(3) << right << num++ << " " << setw(max_name) << left << fileName << right << flush;
-            cross_validation(fileName, quiet, no_train_score);
-            if (!quiet)
-                std::cout << std::endl;
-        }
-        if (!quiet)
-            std::cout << std::endl;
-    }
-    std::string getColor(bayesnet::status_t status)
-    {
-        switch (status) {
-            case bayesnet::NORMAL:
-                return Colors::GREEN();
-            case bayesnet::WARNING:
-                return Colors::YELLOW();
-            case bayesnet::ERROR:
-                return Colors::RED();
-            default:
-                return Colors::RESET();
-        }
-    }
-
-    void showProgress(int fold, const std::string& color, const std::string& phase)
-    {
-        std::string prefix = phase == "a" ? "" : "\b\b\b\b";
-        std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
-
-    }
-    void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score)
-    {
-        auto datasets = Datasets(discretized, Paths::datasets());
-        // Get dataset
-        auto [X, y] = datasets.getTensors(fileName);
-        auto states = datasets.getStates(fileName);
-        auto features = datasets.getFeatures(fileName);
-        auto samples = datasets.getNSamples(fileName);
-        auto className = datasets.getClassName(fileName);
-        if (!quiet) {
-            std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush;
-        }
-        // Prepare Result
-        auto partial_result = PartialResult();
-        auto [values, counts] = at::_unique(y);
-        partial_result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0));
-        partial_result.setHyperparameters(hyperparameters.get(fileName));
-        // Initialize results std::vectors
-        int nResults = nfolds * static_cast<int>(randomSeeds.size());
-        auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
-        auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64);
-        auto train_time = torch::zeros({ nResults }, torch::kFloat64);
-        auto test_time = torch::zeros({ nResults }, torch::kFloat64);
-        auto nodes = torch::zeros({ nResults }, torch::kFloat64);
-        auto edges = torch::zeros({ nResults }, torch::kFloat64);
-        auto num_states = torch::zeros({ nResults }, torch::kFloat64);
-        std::vector<std::string> notes;
-        Timer train_timer, test_timer;
-        int item = 0;
-        bool first_seed = true;
-        for (auto seed : randomSeeds) {
-            if (!quiet) {
-                string prefix = " ";
-                if (!first_seed) {
-                    prefix = "\n" + string(18 + max_name, ' ');
-                }
-                std::cout << prefix << setw(4) << right << seed << " " << flush;
-                first_seed = false;
-            }
-            folding::Fold* fold;
-            if (stratified)
-                fold = new folding::StratifiedKFold(nfolds, y, seed);
-            else
-                fold = new folding::KFold(nfolds, y.size(0), seed);
-            for (int nfold = 0; nfold < nfolds; nfold++) {
-                auto clf = Models::instance()->create(result.getModel());
-                setModelVersion(clf->getVersion());
-                auto valid = clf->getValidHyperparameters();
-                hyperparameters.check(valid, fileName);
-                clf->setHyperparameters(hyperparameters.get(fileName));
-                // Split train - test dataset
-                train_timer.start();
-                auto [train, test] = fold->getFold(nfold);
-                auto train_t = torch::tensor(train);
-                auto test_t = torch::tensor(test);
-                auto X_train = X.index({ "...", train_t });
-                auto y_train = y.index({ train_t });
-                auto X_test = X.index({ "...", test_t });
-                auto y_test = y.index({ test_t });
-                if (!quiet)
-                    showProgress(nfold + 1, getColor(clf->getStatus()), "a");
-                // Train model
-                clf->fit(X_train, y_train, features, className, states);
-                if (!quiet)
-                    showProgress(nfold + 1, getColor(clf->getStatus()), "b");
-                auto clf_notes = clf->getNotes();
-                std::transform(clf_notes.begin(), clf_notes.end(), std::back_inserter(notes), [nfold](const std::string& note)
-                    { return "Fold " + std::to_string(nfold) + ": " + note; });
-                nodes[item] = clf->getNumberOfNodes();
-                edges[item] = clf->getNumberOfEdges();
-                num_states[item] = clf->getNumberOfStates();
-                train_time[item] = train_timer.getDuration();
-                double accuracy_train_value = 0.0;
-                // Score train
-                if (!no_train_score)
-                    accuracy_train_value = clf->score(X_train, y_train);
-                // Test model
-                if (!quiet)
-                    showProgress(nfold + 1, getColor(clf->getStatus()), "c");
-                test_timer.start();
-                auto accuracy_test_value = clf->score(X_test, y_test);
-                test_time[item] = test_timer.getDuration();
-                accuracy_train[item] = accuracy_train_value;
-                accuracy_test[item] = accuracy_test_value;
-                if (!quiet)
-                    std::cout << "\b\b\b, " << flush;
-                // Store results and times in std::vector
-                partial_result.addScoreTrain(accuracy_train_value);
-                partial_result.addScoreTest(accuracy_test_value);
-                partial_result.addTimeTrain(train_time[item].item<double>());
-                partial_result.addTimeTest(test_time[item].item<double>());
-                item++;
-            }
-            if (!quiet)
-                std::cout << "end. " << flush;
-            delete fold;
-        }
-        partial_result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
-        partial_result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
-        partial_result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
-        partial_result.setTestTimeStd(torch::std(test_time).item<double>()).setTrainTimeStd(torch::std(train_time).item<double>());
-        partial_result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
-        partial_result.setDataset(fileName).setNotes(notes);
-        addResult(partial_result);
-    }
-}
299
src/main/Experiment.cpp
Normal file
299
src/main/Experiment.cpp
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
#include "common/Datasets.h"
|
||||||
|
#include "reports/ReportConsole.h"
|
||||||
|
#include "common/Paths.h"
|
||||||
|
#include "Models.h"
|
||||||
|
#include "Scores.h"
|
||||||
|
#include "Experiment.h"
|
||||||
|
namespace platform {
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
void Experiment::saveResult()
|
||||||
|
{
|
||||||
|
result.save();
|
||||||
|
std::cout << "Result saved in " << Paths::results() << result.getFilename() << std::endl;
|
||||||
|
}
|
||||||
|
void Experiment::report(bool classification_report)
|
||||||
|
{
|
||||||
|
ReportConsole report(result.getJson());
|
||||||
|
report.show();
|
||||||
|
if (classification_report) {
|
||||||
|
std::cout << report.showClassificationReport(Colors::BLUE());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void Experiment::show()
|
||||||
|
{
|
||||||
|
std::cout << result.getJson().dump(4) << std::endl;
|
||||||
|
}
|
||||||
|
void Experiment::saveGraph()
|
||||||
|
{
|
||||||
|
std::cout << "Saving graphs..." << std::endl;
|
||||||
|
auto data = result.getJson();
|
||||||
|
for (const auto& item : data["results"]) {
|
||||||
|
auto graphs = item["graph"];
|
||||||
|
int i = 0;
|
||||||
|
for (const auto& graph : graphs) {
|
||||||
|
i++;
|
||||||
|
auto fileName = Paths::graphs() + result.getFilename() + "_graph_" + item["dataset"].get<std::string>() + "_" + std::to_string(i) + ".dot";
|
||||||
|
auto file = std::ofstream(fileName);
|
||||||
|
file << graph.get<std::string>();
|
||||||
|
file.close();
|
||||||
|
std::cout << "Graph saved in " << fileName << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score, bool generate_fold_files, bool graph)
|
||||||
|
{
|
||||||
|
for (auto fileName : filesToProcess) {
|
||||||
|
if (fileName.size() > max_name)
|
||||||
|
max_name = fileName.size();
|
||||||
|
}
|
||||||
|
std::cout << Colors::MAGENTA() << "*** Starting experiment: " << result.getTitle() << " ***" << Colors::RESET() << std::endl << std::endl;
|
||||||
|
auto clf = Models::instance()->create(result.getModel());
|
||||||
|
auto version = clf->getVersion();
|
||||||
|
std::cout << Colors::BLUE() << " Using " << result.getModel() << " ver. " << version << std::endl << std::endl;
|
||||||
|
if (!quiet) {
|
||||||
|
std::cout << Colors::GREEN() << " Status Meaning" << std::endl;
|
||||||
|
std::cout << " ------ --------------------------------" << Colors::RESET() << std::endl;
|
||||||
|
std::cout << " ( " << Colors::GREEN() << "a" << Colors::RESET() << " ) Fitting model with train dataset" << std::endl;
|
||||||
|
std::cout << " ( " << Colors::GREEN() << "b" << Colors::RESET() << " ) Scoring train dataset" << std::endl;
|
||||||
|
std::cout << " ( " << Colors::GREEN() << "c" << Colors::RESET() << " ) Scoring test dataset" << std::endl << std::endl;
|
||||||
|
std::cout << Colors::YELLOW() << "Note: fold number in this color means fitting had issues such as not using all features in BoostAODE classifier" << std::endl << std::endl;
|
||||||
|
std::cout << Colors::GREEN() << left << " # " << setw(max_name) << "Dataset" << " #Samp #Feat Seed Status" << string(3 * nfolds - 2, ' ') << " Time" << std::endl;
|
||||||
|
std::cout << " --- " << string(max_name, '-') << " ----- ----- ---- " << string(4 + 3 * nfolds, '-') << " ----------" << Colors::RESET() << std::endl;
|
||||||
|
}
|
||||||
|
int num = 0;
|
||||||
|
for (auto fileName : filesToProcess) {
|
||||||
|
if (!quiet)
|
||||||
|
std::cout << " " << setw(3) << right << num++ << " " << setw(max_name) << left << fileName << right << flush;
|
||||||
|
cross_validation(fileName, quiet, no_train_score, generate_fold_files, graph);
|
||||||
|
if (!quiet)
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
if (!quiet)
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
std::string getColor(bayesnet::status_t status)
|
||||||
|
{
|
||||||
|
switch (status) {
|
||||||
|
case bayesnet::NORMAL:
|
||||||
|
return Colors::GREEN();
|
||||||
|
case bayesnet::WARNING:
|
||||||
|
return Colors::YELLOW();
|
||||||
|
case bayesnet::ERROR:
|
||||||
|
return Colors::RED();
|
||||||
|
default:
|
||||||
|
return Colors::RESET();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
score_t Experiment::parse_score() const
|
||||||
|
{
|
||||||
|
if (result.getScoreName() == "accuracy")
|
||||||
|
return score_t::ACCURACY;
|
||||||
|
if (result.getScoreName() == "roc-auc-ovr")
|
||||||
|
return score_t::ROC_AUC_OVR;
|
||||||
|
throw std::runtime_error("Unknown score: " + result.getScoreName());
|
||||||
|
}
|
||||||
|
void showProgress(int fold, const std::string& color, const std::string& phase)
|
||||||
|
{
|
||||||
|
std::string prefix = phase == "-" ? "" : "\b\b\b\b";
|
||||||
|
std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
|
||||||
|
|
||||||
|
}
|
||||||
|
void generate_files(const std::string& fileName, bool discretize, bool stratified, int seed, int nfold, torch::Tensor X_train, torch::Tensor y_train, torch::Tensor X_test, torch::Tensor y_test, std::vector<int>& train, std::vector<int>& test)
|
||||||
|
{
|
||||||
|
std::string file_name = Paths::experiment_file(fileName, discretize, stratified, seed, nfold);
|
||||||
|
auto file = std::ofstream(file_name);
|
||||||
|
json output;
|
||||||
|
output["seed"] = seed;
|
||||||
|
output["nfold"] = nfold;
|
||||||
|
output["X_train"] = json::array();
|
||||||
|
auto n = X_train.size(1);
|
||||||
|
for (int i = 0; i < X_train.size(0); i++) {
|
||||||
|
if (X_train.dtype() == torch::kFloat32) {
|
||||||
|
auto xvf_ptr = X_train.index({ i }).data_ptr<float>();
|
||||||
|
auto feature = std::vector<float>(xvf_ptr, xvf_ptr + n);
|
||||||
|
output["X_train"].push_back(feature);
|
||||||
|
} else {
|
||||||
|
auto feature = std::vector<int>(X_train.index({ i }).data_ptr<int>(), X_train.index({ i }).data_ptr<int>() + n);
|
||||||
|
output["X_train"].push_back(feature);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output["y_train"] = std::vector<int>(y_train.data_ptr<int>(), y_train.data_ptr<int>() + n);
|
||||||
|
output["X_test"] = json::array();
|
||||||
|
n = X_test.size(1);
|
||||||
|
for (int i = 0; i < X_test.size(0); i++) {
|
||||||
|
if (X_train.dtype() == torch::kFloat32) {
|
||||||
|
auto xvf_ptr = X_test.index({ i }).data_ptr<float>();
|
||||||
|
auto feature = std::vector<float>(xvf_ptr, xvf_ptr + n);
|
||||||
|
output["X_test"].push_back(feature);
|
||||||
|
} else {
|
||||||
|
auto feature = std::vector<int>(X_test.index({ i }).data_ptr<int>(), X_test.index({ i }).data_ptr<int>() + n);
|
||||||
|
output["X_test"].push_back(feature);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output["y_test"] = std::vector<int>(y_test.data_ptr<int>(), y_test.data_ptr<int>() + n);
|
||||||
|
output["train"] = train;
|
||||||
|
output["test"] = test;
|
||||||
|
file << output.dump(4);
|
||||||
|
file.close();
|
||||||
|
}
|
||||||
|
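    // A sketch of the per-fold JSON that generate_files() writes (hypothetical
    // values; the keys mirror the assignments above, with X stored one row per
    // feature and train/test holding the sample indices of each split):
    //
    //   {
    //       "seed": 271,
    //       "nfold": 0,
    //       "X_train": [ [ 5.1, 4.9, ... ], [ 3.5, 3.0, ... ] ],
    //       "y_train": [ 0, 0, ... ],
    //       "X_test":  [ [ 4.7, ... ], [ 3.2, ... ] ],
    //       "y_test":  [ 0, ... ],
    //       "train": [ 2, 5, 8, ... ],
    //       "test":  [ 0, 1, ... ]
    //   }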
    void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files, bool graph)
    {
        //
        // Load dataset and prepare data
        //
        auto datasets = Datasets(discretized, Paths::datasets(), discretization_algo);
        auto& dataset = datasets.getDataset(fileName);
        dataset.load();
        auto [X, y] = dataset.getTensors(); // Only need y for folding
        auto features = dataset.getFeatures();
        auto n_features = dataset.getNFeatures();
        auto n_samples = dataset.getNSamples();
        auto className = dataset.getClassName();
        auto labels = dataset.getLabels();
        int num_classes = dataset.getNClasses();
        if (!quiet) {
            std::cout << " " << setw(5) << n_samples << " " << setw(5) << n_features << flush;
        }
        //
        // Prepare Result
        //
        auto partial_result = PartialResult();
        partial_result.setSamples(n_samples).setFeatures(n_features).setClasses(num_classes);
        partial_result.setHyperparameters(hyperparameters.get(fileName));
        //
        // Initialize results std::vectors
        //
        int nResults = nfolds * static_cast<int>(randomSeeds.size());
        auto score_test = torch::zeros({ nResults }, torch::kFloat64);
        auto score_train = torch::zeros({ nResults }, torch::kFloat64);
        auto train_time = torch::zeros({ nResults }, torch::kFloat64);
        auto test_time = torch::zeros({ nResults }, torch::kFloat64);
        auto nodes = torch::zeros({ nResults }, torch::kFloat64);
        auto edges = torch::zeros({ nResults }, torch::kFloat64);
        auto num_states = torch::zeros({ nResults }, torch::kFloat64);
        json confusion_matrices = json::array();
        json confusion_matrices_train = json::array();
        std::vector<std::string> notes;
        std::vector<std::string> graphs;
        Timer train_timer, test_timer, seed_timer;
        int item = 0;
        bool first_seed = true;
        //
        // Loop over random seeds
        //
        auto score = parse_score();
        for (auto seed : randomSeeds) {
            seed_timer.start();
            if (!quiet) {
                string prefix = " ";
                if (!first_seed) {
                    prefix = "\n" + string(18 + max_name, ' ');
                }
                std::cout << prefix << setw(4) << right << seed << " " << flush;
                first_seed = false;
            }
            folding::Fold* fold;
            if (stratified)
                fold = new folding::StratifiedKFold(nfolds, y, seed);
            else
                fold = new folding::KFold(nfolds, n_samples, seed);
            //
            // Loop over folds
            //
            for (int nfold = 0; nfold < nfolds; nfold++) {
                auto clf = Models::instance()->create(result.getModel());
                if (!quiet)
                    showProgress(nfold + 1, getColor(clf->getStatus()), "-");
                setModelVersion(clf->getVersion());
                auto valid = clf->getValidHyperparameters();
                hyperparameters.check(valid, fileName);
                clf->setHyperparameters(hyperparameters.get(fileName));
                //
                // Split train - test dataset
                //
                train_timer.start();
                auto [train, test] = fold->getFold(nfold);
                auto [X_train, X_test, y_train, y_test] = dataset.getTrainTestTensors(train, test);
                auto states = dataset.getStates(); // Get the states of the features once they are discretized
                if (generate_fold_files)
                    generate_files(fileName, discretized, stratified, seed, nfold, X_train, y_train, X_test, y_test, train, test);
                if (!quiet)
                    showProgress(nfold + 1, getColor(clf->getStatus()), "a");
                //
                // Train model
                //
                clf->fit(X_train, y_train, features, className, states, smooth_type);
                if (!quiet)
                    showProgress(nfold + 1, getColor(clf->getStatus()), "b");
                auto clf_notes = clf->getNotes();
                std::transform(clf_notes.begin(), clf_notes.end(), std::back_inserter(notes), [nfold](const std::string& note)
                    { return "Fold " + std::to_string(nfold) + ": " + note; });
                nodes[item] = clf->getNumberOfNodes();
                edges[item] = clf->getNumberOfEdges();
                num_states[item] = clf->getNumberOfStates();
                train_time[item] = train_timer.getDuration();
                double score_train_value = 0.0;
                //
                // Score train
                //
                if (!no_train_score) {
                    auto y_proba_train = clf->predict_proba(X_train);
                    Scores scores(y_train, y_proba_train, num_classes, labels);
                    score_train_value = score == score_t::ACCURACY ? scores.accuracy() : scores.auc();
                    confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true));
                }
                //
                // Test model
                //
                if (!quiet)
                    showProgress(nfold + 1, getColor(clf->getStatus()), "c");
                test_timer.start();
                // auto y_predict = clf->predict(X_test);
                auto y_proba_test = clf->predict_proba(X_test);
                Scores scores(y_test, y_proba_test, num_classes, labels);
                auto score_test_value = score == score_t::ACCURACY ? scores.accuracy() : scores.auc();
                test_time[item] = test_timer.getDuration();
                score_train[item] = score_train_value;
                score_test[item] = score_test_value;
                confusion_matrices.push_back(scores.get_confusion_matrix_json(true));
                if (!quiet)
                    std::cout << "\b\b\b, " << flush;
                //
                // Store results and times in std::vector
                //
                partial_result.addScoreTrain(score_train_value);
                partial_result.addScoreTest(score_test_value);
                partial_result.addTimeTrain(train_time[item].item<double>());
                partial_result.addTimeTest(test_time[item].item<double>());
                item++;
                if (graph) {
                    std::string result = "";
                    for (const auto& line : clf->graph()) {
                        result += line + "\n";
                    }
                    graphs.push_back(result);
                }
            }
            if (!quiet) {
                seed_timer.stop();
                std::cout << "end. [" << seed_timer.getDurationString() << "]" << std::endl;
            }
            delete fold;
        }
        //
        // Store result totals in Result
        //
        partial_result.setGraph(graphs);
        partial_result.setScoreTest(torch::mean(score_test).item<double>()).setScoreTrain(torch::mean(score_train).item<double>());
        partial_result.setScoreTestStd(torch::std(score_test).item<double>()).setScoreTrainStd(torch::std(score_train).item<double>());
        partial_result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
        partial_result.setTestTimeStd(torch::std(test_time).item<double>()).setTrainTimeStd(torch::std(train_time).item<double>());
        partial_result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
        partial_result.setDataset(fileName).setNotes(notes);
        partial_result.setConfusionMatrices(confusion_matrices);
        if (!no_train_score)
            partial_result.setConfusionMatricesTrain(confusion_matrices_train);
        addResult(partial_result);
    }
}
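For orientation, a minimal driver of the new Experiment API could look like the sketch below. It is an illustration only: "mdlp" as a discretization algorithm name and "iris" as a dataset are assumptions, while the setter names and the five-argument go() come from the header diff that follows.

    auto experiment = platform::Experiment();
    experiment.setTitle("demo").setModel("BoostAODE").setScoreName("accuracy");
    experiment.setDiscretized(true).setStratified(true).setNFolds(5);
    experiment.setDiscretizationAlgorithm("mdlp");   // assumed algorithm name
    experiment.setSmoothSrategy("LAPLACE");          // setter name as declared in Experiment.h
    experiment.addRandomSeed(271);
    // arguments: files, quiet, no_train_score, generate_fold_files, graph
    experiment.go({ "iris" }, false, false, false, false);
    experiment.report(true);                         // include the classification report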
Experiment.h
@@ -3,14 +3,15 @@
 #include <torch/torch.h>
 #include <nlohmann/json.hpp>
 #include <string>
-#include "folding.hpp"
-#include "BaseClassifier.h"
+#include <folding.hpp>
+#include "bayesnet/BaseClassifier.h"
 #include "HyperParameters.h"
-#include "Result.h"
+#include "results/Result.h"
+#include "bayesnet/network/Network.h"
 
 namespace platform {
-    using json = nlohmann::json;
+    using json = nlohmann::ordered_json;
+    enum class score_t { NONE, ACCURACY, ROC_AUC_OVR };
     class Experiment {
     public:
         Experiment() = default;
@@ -20,6 +21,25 @@ namespace platform {
         Experiment& setModelVersion(const std::string& model_version) { this->result.setModelVersion(model_version); return *this; }
         Experiment& setModel(const std::string& model) { this->result.setModel(model); return *this; }
         Experiment& setLanguage(const std::string& language) { this->result.setLanguage(language); return *this; }
+        Experiment& setDiscretizationAlgorithm(const std::string& discretization_algo)
+        {
+            this->discretization_algo = discretization_algo; this->result.setDiscretizationAlgorithm(discretization_algo); return *this;
+        }
+        Experiment& setSmoothSrategy(const std::string& smooth_strategy)
+        {
+            this->smooth_strategy = smooth_strategy; this->result.setSmoothStrategy(smooth_strategy);
+            if (smooth_strategy == "ORIGINAL")
+                smooth_type = bayesnet::Smoothing_t::ORIGINAL;
+            else if (smooth_strategy == "LAPLACE")
+                smooth_type = bayesnet::Smoothing_t::LAPLACE;
+            else if (smooth_strategy == "CESTNIK")
+                smooth_type = bayesnet::Smoothing_t::CESTNIK;
+            else {
+                std::cerr << "Experiment: Unknown smoothing strategy: " << smooth_strategy << std::endl;
+                exit(1);
+            }
+            return *this;
+        }
         Experiment& setLanguageVersion(const std::string& language_version) { this->result.setLanguageVersion(language_version); return *this; }
         Experiment& setDiscretized(bool discretized) { this->discretized = discretized; result.setDiscretized(discretized); return *this; }
         Experiment& setStratified(bool stratified) { this->stratified = stratified; result.setStratified(stratified); return *this; }
@@ -28,16 +48,21 @@ namespace platform {
         Experiment& addRandomSeed(int randomSeed) { randomSeeds.push_back(randomSeed); result.addSeed(randomSeed); return *this; }
         Experiment& setDuration(float duration) { this->result.setDuration(duration); return *this; }
         Experiment& setHyperparameters(const HyperParameters& hyperparameters_) { this->hyperparameters = hyperparameters_; return *this; }
-        void cross_validation(const std::string& fileName, bool quiet, bool no_train_score);
-        void go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score);
+        void cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files, bool graph);
+        void go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score, bool generate_fold_files, bool graph);
         void saveResult();
         void show();
-        void report();
+        void saveGraph();
+        void report(bool classification_report = false);
     private:
+        score_t parse_score() const;
         Result result;
         bool discretized{ false }, stratified{ false };
         std::vector<PartialResult> results;
         std::vector<int> randomSeeds;
+        std::string discretization_algo;
+        std::string smooth_strategy;
+        bayesnet::Smoothing_t smooth_type{ bayesnet::Smoothing_t::NONE };
         HyperParameters hyperparameters;
         int nfolds{ 0 };
         int max_name{ 7 }; // max length of dataset name for formatting (default 7)
HyperParameters.cpp
@@ -1,7 +1,7 @@
-#include "HyperParameters.h"
 #include <fstream>
 #include <sstream>
 #include <iostream>
+#include "HyperParameters.h"
 
 namespace platform {
     HyperParameters::HyperParameters(const std::vector<std::string>& datasets, const json& hyperparameters_)
@@ -10,16 +10,9 @@ namespace platform {
         for (const auto& item : datasets) {
             hyperparameters[item] = hyperparameters_;
         }
+        normalize_nested(datasets);
     }
-    // https://www.techiedelight.com/implode-a-vector-of-strings-into-a-comma-separated-string-in-cpp/
-    std::string join(std::vector<std::string> const& strings, std::string delim)
-    {
-        std::stringstream ss;
-        std::copy(strings.begin(), strings.end(),
-            std::ostream_iterator<std::string>(ss, delim.c_str()));
-        return ss.str();
-    }
-    HyperParameters::HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file)
+    HyperParameters::HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file, bool best)
     {
         // Check if file exists
         std::ifstream file(hyperparameters_file);
@@ -28,7 +21,14 @@ namespace platform {
         }
         // Check if file is a json
         json file_hyperparameters = json::parse(file);
-        auto input_hyperparameters = file_hyperparameters["results"];
+        json input_hyperparameters;
+        if (best) {
+            for (const auto& [key, value] : file_hyperparameters.items()) {
+                input_hyperparameters[key]["hyperparameters"] = value[1];
+            }
+        } else {
+            input_hyperparameters = file_hyperparameters["results"];
+        }
         // Check if hyperparameters are valid
         for (const auto& dataset : datasets) {
             if (!input_hyperparameters.contains(dataset)) {
@@ -38,6 +38,24 @@ namespace platform {
             }
             hyperparameters[dataset] = input_hyperparameters[dataset]["hyperparameters"].get<json>();
         }
+        normalize_nested(datasets);
     }
+    void HyperParameters::normalize_nested(const std::vector<std::string>& datasets)
+    {
+        // for (const auto& dataset : datasets) {
+        //     if (hyperparameters[dataset].contains("be_hyperparams")) {
+        //         // Odte has base estimator hyperparameters set this way
+        //         hyperparameters[dataset]["be_hyperparams"] = hyperparameters[dataset]["be_hyperparams"].dump();
+        //     }
+        // }
+    }
+    // https://www.techiedelight.com/implode-a-vector-of-strings-into-a-comma-separated-string-in-cpp/
+    std::string join(std::vector<std::string> const& strings, std::string delim)
+    {
+        std::stringstream ss;
+        std::copy(strings.begin(), strings.end(),
+            std::ostream_iterator<std::string>(ss, delim.c_str()));
+        return ss.str();
+    }
     void HyperParameters::check(const std::vector<std::string>& valid, const std::string& fileName)
     {
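The constructor now accepts two on-disk layouts. A hedged sketch of each follows; dataset names and values are hypothetical, and the best-results layout is inferred from the `value[1]` access above:

    // grid-style file (best = false): hyperparameters nested under "results"
    // { "results": { "iris": { "hyperparameters": { "alpha": 0.5 } } } }
    //
    // best-results file (best = true): a per-dataset array whose second
    // element carries the hyperparameters
    // { "iris": [ 0.973, { "alpha": 0.5 }, "..." ] }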
HyperParameters.h
@@ -6,18 +6,22 @@
 #include <nlohmann/json.hpp>
 
 namespace platform {
-    using json = nlohmann::json;
+    using json = nlohmann::ordered_json;
     class HyperParameters {
     public:
         HyperParameters() = default;
+        // Constructor to use command line hyperparameters
         explicit HyperParameters(const std::vector<std::string>& datasets, const json& hyperparameters_);
-        explicit HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file);
+        // Constructor to use a hyperparameters file generated by grid search or by best results
+        explicit HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file, bool best = false);
         ~HyperParameters() = default;
         bool notEmpty(const std::string& key) const { return !hyperparameters.at(key).empty(); }
         void check(const std::vector<std::string>& valid, const std::string& fileName);
         json get(const std::string& fileName);
     private:
+        void normalize_nested(const std::vector<std::string>& datasets);
         std::map<std::string, json> hyperparameters;
+        bool best = false; // Used to separate grid/best hyperparameters, as the formats of those files differ
     };
 } /* namespace platform */
-#endif /* HYPERPARAMETERS_H */
+#endif
Models.cpp
@@ -36,13 +36,15 @@ namespace platform {
         [](const pair<std::string, function<bayesnet::BaseClassifier* (void)>>& pair) { return pair.first; });
         return names;
     }
-    std::string Models::tostring()
+    std::string Models::toString()
     {
         std::string result = "";
+        std::string sep = "";
         for (const auto& pair : functionRegistry) {
-            result += pair.first + ", ";
+            result += sep + pair.first;
+            sep = ", ";
         }
-        return "{" + result.substr(0, result.size() - 2) + "}";
+        return "{" + result + "}";
     }
     Registrar::Registrar(const std::string& name, function<bayesnet::BaseClassifier* (void)> classFactoryFunction)
    {
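The rewritten toString() uses the separator-first join idiom, so there is no trailing ", " to trim off with substr. A self-contained sketch of the same pattern:

    #include <iostream>
    #include <string>
    #include <vector>

    // Separator-first join, as in Models::toString(): sep is empty on the
    // first pass, so no trailing separator ever needs trimming.
    int main() {
        std::vector<std::string> names = { "TAN", "KDB", "AODE" };
        std::string result, sep;
        for (const auto& name : names) {
            result += sep + name;
            sep = ", ";
        }
        std::cout << "{" + result + "}" << std::endl; // prints {TAN, KDB, AODE}
    }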
Models.h
@@ -1,27 +1,27 @@
 #ifndef MODELS_H
 #define MODELS_H
 #include <map>
-#include "BaseClassifier.h"
-#include "AODE.h"
-#include "TAN.h"
-#include "KDB.h"
-#include "SPODE.h"
-#include "TANLd.h"
-#include "KDBLd.h"
-#include "SPODELd.h"
-#include "AODELd.h"
-#include "BoostAODE.h"
-#include "STree.h"
-#include "ODTE.h"
-#include "SVC.h"
-#include "XGBoost.h"
-#include "RandomForest.h"
+#include <bayesnet/BaseClassifier.h>
+#include <bayesnet/ensembles/AODE.h>
+#include <bayesnet/ensembles/A2DE.h>
+#include <bayesnet/ensembles/AODELd.h>
+#include <bayesnet/ensembles/BoostAODE.h>
+#include <bayesnet/ensembles/BoostA2DE.h>
+#include <bayesnet/classifiers/TAN.h>
+#include <bayesnet/classifiers/KDB.h>
+#include <bayesnet/classifiers/SPODE.h>
+#include <bayesnet/classifiers/SPnDE.h>
+#include <bayesnet/classifiers/TANLd.h>
+#include <bayesnet/classifiers/KDBLd.h>
+#include <bayesnet/classifiers/SPODELd.h>
+#include <bayesnet/classifiers/SPODELd.h>
+#include <pyclassifiers/STree.h>
+#include <pyclassifiers/ODTE.h>
+#include <pyclassifiers/SVC.h>
+#include <pyclassifiers/XGBoost.h>
+#include <pyclassifiers/RandomForest.h>
 namespace platform {
     class Models {
-    private:
-        map<std::string, function<bayesnet::BaseClassifier* (void)>> functionRegistry;
-        static Models* factory; //singleton
-        Models() {};
     public:
         Models(Models&) = delete;
         void operator=(const Models&) = delete;
@@ -31,8 +31,11 @@ namespace platform {
         void registerFactoryFunction(const std::string& name,
             function<bayesnet::BaseClassifier* (void)> classFactoryFunction);
         std::vector<string> getNames();
-        std::string tostring();
+        std::string toString();
+    private:
+        map<std::string, function<bayesnet::BaseClassifier* (void)>> functionRegistry;
+        static Models* factory; //singleton
+        Models() {};
     };
     class Registrar {
     public:
PartialResult.h
@@ -1,10 +1,10 @@
-#pragma once
+#ifndef PARTIAL_RESULT_H
+#define PARTIAL_RESULT_H
 #include <string>
 #include <nlohmann/json.hpp>
 
 namespace platform {
-    using json = nlohmann::json;
+    using json = nlohmann::ordered_json;
     class PartialResult {
 
     public:
@@ -15,6 +15,7 @@ namespace platform {
         data["times_train"] = json::array();
         data["times_test"] = json::array();
         data["notes"] = json::array();
+        data["graph"] = json::array();
         data["train_time"] = 0.0;
         data["train_time_std"] = 0.0;
         data["test_time"] = 0.0;
@@ -27,6 +28,14 @@ namespace platform {
             data["notes"].insert(data["notes"].end(), notes_.begin(), notes_.end());
             return *this;
         }
+        PartialResult& setGraph(const std::vector<std::string>& graph)
+        {
+            json graph_ = graph;
+            data["graph"].insert(data["graph"].end(), graph_.begin(), graph_.end());
+            return *this;
+        }
+        PartialResult& setConfusionMatrices(const json& confusion_matrices) { data["confusion_matrices"] = confusion_matrices; return *this; }
+        PartialResult& setConfusionMatricesTrain(const json& confusion_matrices) { data["confusion_matrices_train"] = confusion_matrices; return *this; }
         PartialResult& setHyperparameters(const json& hyperparameters) { data["hyperparameters"] = hyperparameters; return *this; }
         PartialResult& setSamples(int samples) { data["samples"] = samples; return *this; }
         PartialResult& setFeatures(int features) { data["features"] = features; return *this; }
@@ -71,3 +80,4 @@ namespace platform {
         json data;
     };
 }
+#endif
src/main/RocAuc.cpp (new file, 67 lines)
@@ -0,0 +1,67 @@
#include <sstream>
#include <algorithm>
#include <numeric>
#include <utility>
#include "RocAuc.h"
namespace platform {

    double RocAuc::compute(const torch::Tensor& y_proba, const torch::Tensor& labels)
    {
        size_t nClasses = y_proba.size(1);
        // In a binary classification problem there is no need to average the per-class AUCs
        if (nClasses == 2)
            nClasses = 1;
        size_t nSamples = y_proba.size(0);
        y_test = tensorToVector(labels);
        std::vector<double> aucScores(nClasses, 0.0);
        for (size_t classIdx = 0; classIdx < nClasses; ++classIdx) {
            scoresAndLabels.clear();
            for (size_t i = 0; i < nSamples; ++i) {
                scoresAndLabels.emplace_back(y_proba[i][classIdx].item<float>(), y_test[i] == classIdx ? 1 : 0);
            }
            aucScores[classIdx] = compute_common(nSamples, classIdx);
        }
        return std::accumulate(aucScores.begin(), aucScores.end(), 0.0) / nClasses;
    }
    double RocAuc::compute(const std::vector<std::vector<double>>& y_proba, const std::vector<int>& labels)
    {
        y_test = labels;
        size_t nClasses = y_proba[0].size();
        // In a binary classification problem there is no need to average the per-class AUCs
        if (nClasses == 2)
            nClasses = 1;
        size_t nSamples = y_proba.size();
        std::vector<double> aucScores(nClasses, 0.0);
        for (size_t classIdx = 0; classIdx < nClasses; ++classIdx) {
            scoresAndLabels.clear();
            for (size_t i = 0; i < nSamples; ++i) {
                scoresAndLabels.emplace_back(y_proba[i][classIdx], labels[i] == classIdx ? 1 : 0);
            }
            aucScores[classIdx] = compute_common(nSamples, classIdx);
        }
        return std::accumulate(aucScores.begin(), aucScores.end(), 0.0) / nClasses;
    }
    double RocAuc::compute_common(size_t nSamples, size_t classIdx)
    {
        std::sort(scoresAndLabels.begin(), scoresAndLabels.end(), std::greater<>());
        std::vector<double> tpr, fpr;
        double tp = 0, fp = 0;
        double totalPos = std::count(y_test.begin(), y_test.end(), classIdx);
        double totalNeg = nSamples - totalPos;

        for (const auto& [score, label] : scoresAndLabels) {
            if (label == 1) {
                tp += 1;
            } else {
                fp += 1;
            }
            tpr.push_back(tp / totalPos);
            fpr.push_back(fp / totalNeg);
        }
        double auc = 0.0;
        for (size_t i = 1; i < tpr.size(); ++i) {
            auc += 0.5 * (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]);
        }
        return auc;
    }
}
src/main/RocAuc.h (new file, 21 lines)
@@ -0,0 +1,21 @@
#ifndef ROCAUC_H
#define ROCAUC_H
#include <torch/torch.h>
#include <vector>
#include <string>
#include <nlohmann/json.hpp>

namespace platform {
    using json = nlohmann::ordered_json;
    class RocAuc {
    public:
        RocAuc() = default;
        double compute(const std::vector<std::vector<double>>& y_proba, const std::vector<int>& y_test);
        double compute(const torch::Tensor& y_proba, const torch::Tensor& y_test);
    private:
        double compute_common(size_t nSamples, size_t classIdx);
        std::vector<std::pair<double, int>> scoresAndLabels;
        std::vector<int> y_test;
    };
}
#endif
src/main/Scores.cpp (new file, 270 lines)
@@ -0,0 +1,270 @@
#include <sstream>
#include <algorithm>
#include <numeric>
#include "Scores.h"
#include "common/Utils.h" // tensorToVector
#include "common/Colors.h"
namespace platform {
    Scores::Scores(torch::Tensor& y_test, torch::Tensor& y_proba, int num_classes, std::vector<std::string> labels) : num_classes(num_classes), labels(labels), y_test(y_test), y_proba(y_proba)
    {
        if (labels.size() == 0) {
            init_default_labels();
        }
        total = y_test.size(0);
        auto y_pred = y_proba.argmax(1);
        accuracy_value = (y_pred == y_test).sum().item<float>() / total;
        init_confusion_matrix();
        for (int i = 0; i < total; i++) {
            int actual = y_test[i].item<int>();
            int predicted = y_pred[i].item<int>();
            confusion_matrix[actual][predicted] += 1;
        }
    }
    Scores::Scores(const json& confusion_matrix_)
    {
        json values;
        total = 0;
        num_classes = confusion_matrix_.size();
        init_confusion_matrix();
        int i = 0;
        for (const auto& item : confusion_matrix_.items()) {
            values = item.value();
            json key = item.key();
            if (key.is_number_integer()) {
                labels.push_back("Class " + std::to_string(key.get<int>()));
            } else {
                labels.push_back(key.get<std::string>());
            }
            for (int j = 0; j < num_classes; ++j) {
                int value_int = values[j].get<int>();
                confusion_matrix[i][j] = value_int;
                total += value_int;
            }
            i++;
        }
        compute_accuracy_value();
    }
    float Scores::auc()
    {
        size_t nSamples = y_test.numel();
        if (nSamples == 0) return 0;
        // In a binary classification problem there is no need to average the per-class AUCs
        auto nClasses = num_classes;
        if (num_classes == 2)
            nClasses = 1;
        auto y_testv = tensorToVector<int>(y_test);
        std::vector<double> aucScores(nClasses, 0.0);
        std::vector<std::pair<double, int>> scoresAndLabels;
        for (size_t classIdx = 0; classIdx < nClasses; ++classIdx) {
            if (classIdx >= y_proba.size(1)) {
                std::cerr << "AUC warning - class index out of range" << std::endl;
                return 0;
            }
            scoresAndLabels.clear();
            for (size_t i = 0; i < nSamples; ++i) {
                scoresAndLabels.emplace_back(y_proba[i][classIdx].item<float>(), y_testv[i] == classIdx ? 1 : 0);
            }
            std::sort(scoresAndLabels.begin(), scoresAndLabels.end(), std::greater<>());
            std::vector<double> tpr, fpr;
            double tp = 0, fp = 0;
            double totalPos = std::count(y_testv.begin(), y_testv.end(), classIdx);
            double totalNeg = nSamples - totalPos;
            for (const auto& [score, label] : scoresAndLabels) {
                if (label == 1) {
                    tp += 1;
                } else {
                    fp += 1;
                }
                tpr.push_back(tp / totalPos);
                fpr.push_back(fp / totalNeg);
            }
            double auc = 0.0;
            for (size_t i = 1; i < tpr.size(); ++i) {
                auc += 0.5 * (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]);
            }
            aucScores[classIdx] = auc;
        }
        return std::accumulate(aucScores.begin(), aucScores.end(), 0.0) / nClasses;
    }
    Scores Scores::create_aggregate(const json& data, const std::string key)
    {
        auto scores = Scores(data[key][0]);
        for (int i = 1; i < data[key].size(); i++) {
            auto score = Scores(data[key][i]);
            scores.aggregate(score);
        }
        return scores;
    }
    void Scores::compute_accuracy_value()
    {
        accuracy_value = 0;
        for (int i = 0; i < num_classes; i++) {
            accuracy_value += confusion_matrix[i][i].item<int>();
        }
        accuracy_value /= total;
        accuracy_value = std::min(accuracy_value, 1.0f);
    }
    void Scores::init_confusion_matrix()
    {
        confusion_matrix = torch::zeros({ num_classes, num_classes }, torch::kInt32);
    }
    void Scores::init_default_labels()
    {
        for (int i = 0; i < num_classes; i++) {
            labels.push_back("Class " + std::to_string(i));
        }
    }
    void Scores::aggregate(const Scores& a)
    {
        if (a.num_classes != num_classes)
            throw std::invalid_argument("The number of classes must be the same");
        confusion_matrix += a.confusion_matrix;
        total += a.total;
        compute_accuracy_value();
    }
    float Scores::accuracy()
    {
        return accuracy_value;
    }
    float Scores::f1_score(int num_class)
    {
        // Compute f1_score in a one vs rest fashion
        auto precision_value = precision(num_class);
        auto recall_value = recall(num_class);
        if (precision_value + recall_value == 0) return 0; // Avoid division by zero (0/0 = 0)
        return 2 * precision_value * recall_value / (precision_value + recall_value);
    }
    float Scores::f1_weighted()
    {
        float f1_weighted = 0;
        for (int i = 0; i < num_classes; i++) {
            f1_weighted += confusion_matrix[i].sum().item<int>() * f1_score(i);
        }
        return f1_weighted / total;
    }
    float Scores::f1_macro()
    {
        float f1_macro = 0;
        for (int i = 0; i < num_classes; i++) {
            f1_macro += f1_score(i);
        }
        return f1_macro / num_classes;
    }
    float Scores::precision(int num_class)
    {
        int tp = confusion_matrix[num_class][num_class].item<int>();
        int fp = confusion_matrix.index({ "...", num_class }).sum().item<int>() - tp;
        int fn = confusion_matrix[num_class].sum().item<int>() - tp;
        if (tp + fp == 0) return 0; // Avoid division by zero (0/0 = 0)
        return float(tp) / (tp + fp);
    }
    float Scores::recall(int num_class)
    {
        int tp = confusion_matrix[num_class][num_class].item<int>();
        int fp = confusion_matrix.index({ "...", num_class }).sum().item<int>() - tp;
        int fn = confusion_matrix[num_class].sum().item<int>() - tp;
        if (tp + fn == 0) return 0; // Avoid division by zero (0/0 = 0)
        return float(tp) / (tp + fn);
    }
    std::string Scores::classification_report_line(std::string label, float precision, float recall, float f1_score, int support)
    {
        std::stringstream oss;
        oss << std::right << std::setw(label_len) << label << " ";
        if (precision == 0) {
            oss << std::string(dlen, ' ') << " ";
        } else {
            oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << precision << " ";
        }
        if (recall == 0) {
            oss << std::string(dlen, ' ') << " ";
        } else {
            oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << recall << " ";
        }
        oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << f1_score << " "
            << std::setw(dlen) << std::right << support;
        return oss.str();
    }
    std::tuple<float, float, float, float> Scores::compute_averages()
    {
        float precision_avg = 0;
        float recall_avg = 0;
        float precision_wavg = 0;
        float recall_wavg = 0;
        for (int i = 0; i < num_classes; i++) {
            int support = confusion_matrix[i].sum().item<int>();
            precision_avg += precision(i);
            precision_wavg += precision(i) * support;
            recall_avg += recall(i);
            recall_wavg += recall(i) * support;
        }
        precision_wavg /= total;
        recall_wavg /= total;
        precision_avg /= num_classes;
        recall_avg /= num_classes;
        return { precision_avg, recall_avg, precision_wavg, recall_wavg };
    }
    std::vector<std::string> Scores::classification_report(std::string color, std::string title)
    {
        std::stringstream oss;
        std::vector<std::string> report;
        for (int i = 0; i < num_classes; i++) {
            label_len = std::max(label_len, (int)labels[i].size());
        }
        report.push_back("Classification Report using " + title + " dataset");
        report.push_back("=========================================");
        oss << std::string(label_len, ' ') << " precision recall f1-score support";
        report.push_back(oss.str()); oss.str("");
        oss << std::string(label_len, ' ') << " ========= ========= ========= =========";
        report.push_back(oss.str()); oss.str("");
        for (int i = 0; i < num_classes; i++) {
            report.push_back(classification_report_line(labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item<int>()));
        }
        report.push_back(" ");
        oss << classification_report_line("accuracy", 0, 0, accuracy(), total);
        report.push_back(oss.str()); oss.str("");
        auto [precision_avg, recall_avg, precision_wavg, recall_wavg] = compute_averages();
        report.push_back(classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total));
        report.push_back(classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total));
        report.push_back("");
        report.push_back("Confusion Matrix");
        report.push_back("================");
        auto number = total > 1000 ? 4 : 3;
        for (int i = 0; i < num_classes; i++) {
            oss << std::right << std::setw(label_len) << labels[i] << " ";
            for (int j = 0; j < num_classes; j++) {
                if (i == j) oss << Colors::GREEN();
                oss << std::setw(number) << confusion_matrix[i][j].item<int>() << " ";
                if (i == j) oss << color;
            }
            report.push_back(oss.str()); oss.str("");
        }
        return report;
    }
    json Scores::classification_report_json(std::string title)
    {
        json output;
        output["title"] = "Classification Report using " + title + " dataset";
        output["headers"] = { " ", "precision", "recall", "f1-score", "support" };
        output["body"] = {};
        for (int i = 0; i < num_classes; i++) {
            output["body"].push_back({ labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item<int>() });
        }
        output["accuracy"] = { "accuracy", 0, 0, accuracy(), total };
        auto [precision_avg, recall_avg, precision_wavg, recall_wavg] = compute_averages();
        output["averages"] = { "macro avg", precision_avg, recall_avg, f1_macro(), total };
        output["weighted"] = { "weighted avg", precision_wavg, recall_wavg, f1_weighted(), total };
        output["confusion_matrix"] = get_confusion_matrix_json();
        return output;
    }
    json Scores::get_confusion_matrix_json(bool labels_as_keys)
    {
        json output;
        for (int i = 0; i < num_classes; i++) {
            auto r_ptr = confusion_matrix[i].data_ptr<int>();
            if (labels_as_keys) {
                output[labels[i]] = std::vector<int>(r_ptr, r_ptr + num_classes);
            } else {
                output[i] = std::vector<int>(r_ptr, r_ptr + num_classes);
            }
        }
        return output;
    }
}
src/main/Scores.h (new file, 46 lines)
@@ -0,0 +1,46 @@
#ifndef SCORES_H
#define SCORES_H
#include <torch/torch.h>
#include <vector>
#include <string>
#include <nlohmann/json.hpp>

namespace platform {
    using json = nlohmann::ordered_json;
    class Scores {
    public:
        Scores(torch::Tensor& y_test, torch::Tensor& y_proba, int num_classes, std::vector<std::string> labels = {});
        explicit Scores(const json& confusion_matrix_);
        static Scores create_aggregate(const json& data, const std::string key);
        float accuracy();
        float auc();
        float f1_score(int num_class);
        float f1_weighted();
        float f1_macro();
        float precision(int num_class);
        float recall(int num_class);
        torch::Tensor get_confusion_matrix() { return confusion_matrix; }
        std::vector<std::string> classification_report(std::string color = "", std::string title = "");
        json classification_report_json(std::string title = "");
        json get_confusion_matrix_json(bool labels_as_keys = false);
        void aggregate(const Scores& a);
    private:
        std::string classification_report_line(std::string label, float precision, float recall, float f1_score, int support);
        void init_confusion_matrix();
        void init_default_labels();
        void compute_accuracy_value();
        std::tuple<float, float, float, float> compute_averages();
        int num_classes;
        float accuracy_value;
        int total;
        std::vector<std::string> labels;
        torch::Tensor confusion_matrix; // Rows are actual, columns are predicted
        torch::Tensor null_t; // Convenient null tensor needed when the confusion_matrix constructor is used
        torch::Tensor& y_test = null_t; // for ROC AUC
        torch::Tensor& y_proba = null_t; // for ROC AUC
        int label_len = 16;
        int dlen = 9;
        int ndec = 7;
    };
}
#endif
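A minimal sketch of driving Scores with raw tensors (hypothetical data; each row of y_proba holds the per-sample class probabilities, which the constructor reduces with argmax):

    // Three samples, two classes; sample 2 is misclassified, so accuracy is 2/3.
    auto y_test = torch::tensor({ 0, 1, 1 }, torch::kInt32);
    auto y_proba = torch::tensor({ { 0.9f, 0.1f },
                                   { 0.2f, 0.8f },
                                   { 0.6f, 0.4f } });
    platform::Scores scores(y_test, y_proba, 2);
    std::cout << scores.accuracy() << std::endl;
    for (const auto& line : scores.classification_report())
        std::cout << line << std::endl;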
@@ -1,137 +0,0 @@
|
|||||||
#include <iostream>
|
|
||||||
#include <argparse/argparse.hpp>
|
|
||||||
#include <nlohmann/json.hpp>
|
|
||||||
#include "Experiment.h"
|
|
||||||
#include "Datasets.h"
|
|
||||||
#include "DotEnv.h"
|
|
||||||
#include "Models.h"
|
|
||||||
#include "modelRegister.h"
|
|
||||||
#include "Paths.h"
|
|
||||||
#include "config.h"
|
|
||||||
|
|
||||||
|
|
||||||
using json = nlohmann::json;
|
|
||||||
|
|
||||||
void manageArguments(argparse::ArgumentParser& program)
|
|
||||||
{
|
|
||||||
auto env = platform::DotEnv();
|
|
||||||
program.add_argument("-d", "--dataset").default_value("").help("Dataset file name");
|
|
||||||
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
|
|
||||||
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
|
|
||||||
"Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format.");
|
|
||||||
program.add_argument("-m", "--model")
|
|
||||||
.help("Model to use " + platform::Models::instance()->tostring())
|
|
||||||
.action([](const std::string& value) {
|
|
||||||
static const std::vector<std::string> choices = platform::Models::instance()->getNames();
|
|
||||||
if (find(choices.begin(), choices.end(), value) != choices.end()) {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring());
|
|
||||||
}
|
|
||||||
);
|
|
||||||
program.add_argument("--title").default_value("").help("Experiment title");
|
|
||||||
program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
|
||||||
program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true);
|
|
||||||
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
|
||||||
program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
|
|
||||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
|
||||||
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
|
|
||||||
try {
|
|
||||||
auto k = stoi(value);
|
|
||||||
if (k < 2) {
|
|
||||||
throw std::runtime_error("Number of folds must be greater than 1");
|
|
||||||
}
|
|
||||||
return k;
|
|
||||||
}
|
|
||||||
catch (const runtime_error& err) {
|
|
||||||
throw std::runtime_error(err.what());
|
|
||||||
}
|
|
||||||
catch (...) {
|
|
||||||
throw std::runtime_error("Number of folds must be an integer");
|
|
||||||
}});
|
|
||||||
auto seed_values = env.getSeeds();
|
|
||||||
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
|
|
}

int main(int argc, char** argv)
{
    argparse::ArgumentParser program("b_main", { project_version.begin(), project_version.end() });
    manageArguments(program);
    std::string file_name, model_name, title, hyperparameters_file;
    json hyperparameters_json;
    bool discretize_dataset, stratified, saveResults, quiet, no_train_score;
    std::vector<int> seeds;
    std::vector<std::string> filesToTest;
    int n_folds;
    try {
        program.parse_args(argc, argv);
        file_name = program.get<std::string>("dataset");
        model_name = program.get<std::string>("model");
        discretize_dataset = program.get<bool>("discretize");
        stratified = program.get<bool>("stratified");
        quiet = program.get<bool>("quiet");
        n_folds = program.get<int>("folds");
        seeds = program.get<std::vector<int>>("seeds");
        auto hyperparameters = program.get<std::string>("hyperparameters");
        hyperparameters_json = json::parse(hyperparameters);
        hyperparameters_file = program.get<std::string>("hyper-file");
        no_train_score = program.get<bool>("no-train-score");
        if (hyperparameters_file != "" && hyperparameters != "{}") {
            throw runtime_error("hyperparameters and hyper_file are mutually exclusive");
        }
        title = program.get<std::string>("title");
        if (title == "" && file_name == "") {
            throw runtime_error("title is mandatory if dataset is not provided");
        }
        saveResults = program.get<bool>("save");
    }
    catch (const exception& err) {
        cerr << err.what() << std::endl;
        cerr << program;
        exit(1);
    }
    auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets());
    if (file_name != "") {
        if (!datasets.isDataset(file_name)) {
            cerr << "Dataset " << file_name << " not found" << std::endl;
            exit(1);
        }
        if (title == "") {
            title = "Test " + file_name + " " + model_name + " " + to_string(n_folds) + " folds";
        }
        filesToTest.push_back(file_name);
    } else {
        filesToTest = datasets.getNames();
        saveResults = true;
    }
    platform::HyperParameters test_hyperparams;
    if (hyperparameters_file != "") {
        test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_file);
    } else {
        test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_json);
    }

    /*
    * Begin Processing
    */
    auto env = platform::DotEnv();
    auto experiment = platform::Experiment();
    experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3");
    experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));
    experiment.setStratified(stratified).setNFolds(n_folds).setScoreName("accuracy");
    experiment.setHyperparameters(test_hyperparams);
    for (auto seed : seeds) {
        experiment.addRandomSeed(seed);
    }
    platform::Timer timer;
    timer.start();
    experiment.go(filesToTest, quiet, no_train_score);
    experiment.setDuration(timer.getDuration());
    if (saveResults) {
        experiment.saveResult();
    }
    if (!quiet)
        experiment.report();
    std::cout << "Done!" << std::endl;
    return 0;
}
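For orientation, here is a condensed, hypothetical sketch of the experiment pipeline that main() wires up above. The dataset name, seed, and header paths are invented for illustration; the calls themselves (setTitle, setModel, setNFolds, setScoreName, addRandomSeed, go, setDuration, report) all appear in the code above.

#include "Experiment.h" // assumed header for platform::Experiment
#include "Timer.h"      // assumed header for platform::Timer

int main()
{
    platform::Experiment e;
    // Fluent configuration, as used in b_main above
    e.setTitle("Test iris TAN 5 folds").setModel("TAN").setNFolds(5).setScoreName("accuracy");
    e.addRandomSeed(271); // hypothetical seed
    platform::Timer t;
    t.start();
    e.go({ "iris" }, /*quiet=*/true, /*no_train_score=*/false); // hypothetical dataset
    e.setDuration(t.getDuration());
    e.report();
    return 0;
}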
@@ -1,11 +1,14 @@
-#ifndef MODEL_REGISTER_H
-#define MODEL_REGISTER_H
+#ifndef MODELREGISTER_H
+#define MODELREGISTER_H

static platform::Registrar registrarT("TAN",
    [](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
static platform::Registrar registrarTLD("TANLd",
    [](void) -> bayesnet::BaseClassifier* { return new bayesnet::TANLd();});
static platform::Registrar registrarS("SPODE",
    [](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
+static platform::Registrar registrarSn("SPnDE",
+    [](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPnDE({ 0, 1 });});
static platform::Registrar registrarSLD("SPODELd",
    [](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODELd(2);});
static platform::Registrar registrarK("KDB",
@@ -14,10 +17,14 @@ static platform::Registrar registrarKLD("KDBLd",
    [](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDBLd(2);});
static platform::Registrar registrarA("AODE",
    [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
+static platform::Registrar registrarA2("A2DE",
+    [](void) -> bayesnet::BaseClassifier* { return new bayesnet::A2DE();});
static platform::Registrar registrarALD("AODELd",
    [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODELd();});
static platform::Registrar registrarBA("BoostAODE",
    [](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostAODE();});
+static platform::Registrar registrarBA2("BoostA2DE",
+    [](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostA2DE();});
static platform::Registrar registrarSt("STree",
    [](void) -> bayesnet::BaseClassifier* { return new pywrap::STree();});
static platform::Registrar registrarOdte("Odte",
@@ -28,4 +35,5 @@ static platform::Registrar registrarRaF("RandomForest",
    [](void) -> bayesnet::BaseClassifier* { return new pywrap::RandomForest();});
static platform::Registrar registrarXGB("XGBoost",
    [](void) -> bayesnet::BaseClassifier* { return new pywrap::XGBoost();});

#endif
@@ -1,87 +0,0 @@
#include "CommandParser.h"
#include <iostream>
#include <sstream>
#include <algorithm>
#include "Colors.h"
#include "Utils.h"

namespace platform {
    void CommandParser::messageError(const std::string& message)
    {
        std::cout << Colors::RED() << message << Colors::RESET() << std::endl;
    }
    std::pair<char, int> CommandParser::parse(const std::string& color, const std::vector<std::tuple<std::string, char, bool>>& options, const char defaultCommand, const int maxIndex)
    {
        bool finished = false;
        while (!finished) {
            std::stringstream oss;
            std::string line;
            oss << color << "Choose option (";
            bool first = true;
            for (auto& option : options) {
                if (first) {
                    first = false;
                } else {
                    oss << ", ";
                }
                oss << std::get<char>(option) << "=" << std::get<std::string>(option);
            }
            oss << "): ";
            std::cout << oss.str();
            getline(std::cin, line);
            std::cout << Colors::RESET();
            line = trim(line);
            if (line.size() == 0)
                continue;
            if (all_of(line.begin(), line.end(), ::isdigit)) {
                command = defaultCommand;
                index = stoi(line);
                if (index > maxIndex || index < 0) {
                    messageError("Index out of range");
                    continue;
                }
                finished = true;
                break;
            }
            bool found = false;
            for (auto& option : options) {
                if (line[0] == std::get<char>(option)) {
                    found = true;
                    // it's a match
                    line.erase(line.begin());
                    line = trim(line);
                    if (std::get<bool>(option)) {
                        // The option requires a value
                        if (line.size() == 0) {
                            messageError("Option " + std::get<std::string>(option) + " requires a value");
                            break;
                        }
                        try {
                            index = stoi(line);
                            if (index > maxIndex || index < 0) {
                                messageError("Index out of range");
                                break;
                            }
                        }
                        catch (const std::invalid_argument& ia) {
                            messageError("Invalid value: " + line);
                            break;
                        }
                    } else {
                        if (line.size() > 0) {
                            messageError("option " + std::get<std::string>(option) + " doesn't accept values");
                            break;
                        }
                    }
                    command = std::get<char>(option);
                    finished = true;
                    break;
                }
            }
            if (!found) {
                messageError("I don't know " + line);
            }
        }
        return { command, index };
    }
} /* namespace platform */
@@ -1,20 +0,0 @@
#ifndef COMMAND_PARSER_H
#define COMMAND_PARSER_H
#include <string>
#include <vector>
#include <tuple>

namespace platform {
    class CommandParser {
    public:
        CommandParser() = default;
        std::pair<char, int> parse(const std::string& color, const std::vector<std::tuple<std::string, char, bool>>& options, const char defaultCommand, const int maxIndex);
        char getCommand() const { return command; };
        int getIndex() const { return index; };
    private:
        void messageError(const std::string& message);
        char command;
        int index;
    };
} /* namespace platform */
#endif /* COMMAND_PARSER_H */
@@ -1,225 +0,0 @@
#include "ManageResults.h"
#include "CommandParser.h"
#include <filesystem>
#include <tuple>
#include "Colors.h"
#include "CLocale.h"
#include "Paths.h"
#include "ReportConsole.h"
#include "ReportExcel.h"

namespace platform {

    ManageResults::ManageResults(int numFiles, const std::string& model, const std::string& score, bool complete, bool partial, bool compare) :
        numFiles{ numFiles }, complete{ complete }, partial{ partial }, compare{ compare }, results(Results(Paths::results(), model, score, complete, partial))
    {
        indexList = true;
        openExcel = false;
        workbook = NULL;
        if (numFiles == 0) {
            this->numFiles = results.size();
        }
    }
    void ManageResults::doMenu()
    {
        if (results.empty()) {
            std::cout << Colors::MAGENTA() << "No results found!" << Colors::RESET() << std::endl;
            return;
        }
        results.sortDate();
        list();
        menu();
        if (openExcel) {
            workbook_close(workbook);
        }
        std::cout << Colors::RESET() << "Done!" << std::endl;
    }
    void ManageResults::list()
    {
        auto temp = ConfigLocale();
        std::string suffix = numFiles != results.size() ? " of " + std::to_string(results.size()) : "";
        std::stringstream oss;
        oss << "Results on screen: " << numFiles << suffix;
        std::cout << Colors::GREEN() << oss.str() << std::endl;
        std::cout << std::string(oss.str().size(), '-') << std::endl;
        if (complete) {
            std::cout << Colors::MAGENTA() << "Only listing complete results" << std::endl;
        }
        if (partial) {
            std::cout << Colors::MAGENTA() << "Only listing partial results" << std::endl;
        }
        auto i = 0;
        int maxModel = results.maxModelSize();
        std::cout << Colors::GREEN() << " # Date " << std::setw(maxModel) << std::left << "Model" << " Score Name Score C/P Duration Title" << std::endl;
        std::cout << "=== ========== " << std::string(maxModel, '=') << " =========== =========== === ========= =============================================================" << std::endl;
        bool odd = true;
        for (auto& result : results) {
            auto color = odd ? Colors::BLUE() : Colors::CYAN();
            std::cout << color << std::setw(3) << std::fixed << std::right << i++ << " ";
            std::cout << result.to_string(maxModel) << std::endl;
            if (i == numFiles) {
                break;
            }
            odd = !odd;
        }
    }
    bool ManageResults::confirmAction(const std::string& intent, const std::string& fileName) const
    {
        std::string color;
        if (intent == "delete") {
            color = Colors::RED();
        } else {
            color = Colors::YELLOW();
        }
        std::string line;
        bool finished = false;
        while (!finished) {
            std::cout << color << "Really want to " << intent << " " << fileName << "? (y/n): ";
            getline(std::cin, line);
            finished = line.size() == 1 && (tolower(line[0]) == 'y' || tolower(line[0]) == 'n');
        }
        if (tolower(line[0]) == 'y') {
            return true;
        }
        std::cout << "Not done!" << std::endl;
        return false;
    }
    void ManageResults::report(const int index, const bool excelReport)
    {
        std::cout << Colors::YELLOW() << "Reporting " << results.at(index).getFilename() << std::endl;
        auto data = results.at(index).getJson();
        if (excelReport) {
            ReportExcel reporter(data, compare, workbook);
            reporter.show();
            openExcel = true;
            workbook = reporter.getWorkbook();
            std::cout << "Adding sheet to " << Paths::excel() + Paths::excelResults() << std::endl;
        } else {
            ReportConsole reporter(data, compare);
            reporter.show();
        }
    }
    void ManageResults::showIndex(const int index, const int idx)
    {
        // Show a dataset result inside a report
        auto data = results.at(index).getJson();
        std::cout << Colors::YELLOW() << "Showing " << results.at(index).getFilename() << std::endl;
        ReportConsole reporter(data, compare, idx);
        reporter.show();
    }
    void ManageResults::sortList()
    {
        std::cout << Colors::YELLOW() << "Choose sorting field (date='d', score='s', duration='u', model='m'): ";
        std::string line;
        char option;
        getline(std::cin, line);
        if (line.size() == 0)
            return;
        if (line.size() > 1) {
            std::cout << "Invalid option" << std::endl;
            return;
        }
        option = line[0];
        switch (option) {
            case 'd':
                results.sortDate();
                break;
            case 's':
                results.sortScore();
                break;
            case 'u':
                results.sortDuration();
                break;
            case 'm':
                results.sortModel();
                break;
            default:
                std::cout << "Invalid option" << std::endl;
        }
    }
    void ManageResults::menu()
    {
        char option;
        int index, subIndex;
        bool finished = false;
        std::string filename;
        // tuple<Option, digit, requires value>
        std::vector<std::tuple<std::string, char, bool>> mainOptions = {
            {"quit", 'q', false},
            {"list", 'l', false},
            {"delete", 'd', true},
            {"hide", 'h', true},
            {"sort", 's', false},
            {"report", 'r', true},
            {"excel", 'e', true},
            {"title", 't', true}
        };
        std::vector<std::tuple<std::string, char, bool>> listOptions = {
            {"report", 'r', true},
            {"list", 'l', false},
            {"quit", 'q', false}
        };
        auto parser = CommandParser();
        while (!finished) {
            if (indexList) {
                std::tie(option, index) = parser.parse(Colors::GREEN(), mainOptions, 'r', numFiles - 1);
            } else {
                std::tie(option, subIndex) = parser.parse(Colors::CYAN(), listOptions, 'r', results.at(index).getJson()["results"].size() - 1);
            }
            switch (option) {
                case 'q':
                    finished = true;
                    break;
                case 'l':
                    list();
                    indexList = true;
                    break;
                case 'd':
                    filename = results.at(index).getFilename();
                    if (!confirmAction("delete", filename))
                        break;
                    std::cout << "Deleting " << filename << std::endl;
                    results.deleteResult(index);
                    std::cout << "File: " + filename + " deleted!" << std::endl;
                    list();
                    break;
                case 'h':
                    filename = results.at(index).getFilename();
                    if (!confirmAction("hide", filename))
                        break;
                    filename = results.at(index).getFilename();
                    std::cout << "Hiding " << filename << std::endl;
                    results.hideResult(index, Paths::hiddenResults());
                    std::cout << "File: " + filename + " hidden! (moved to " << Paths::hiddenResults() << ")" << std::endl;
                    list();
                    break;
                case 's':
                    sortList();
                    list();
                    break;
                case 'r':
                    if (indexList) {
                        report(index, false);
                        indexList = false;
                    } else {
                        showIndex(index, subIndex);
                    }
                    break;
                case 'e':
                    report(index, true);
                    break;
                case 't':
                    std::cout << "Title: " << results.at(index).getTitle() << std::endl;
                    std::cout << "New title: ";
                    std::string newTitle;
                    getline(std::cin, newTitle);
                    if (!newTitle.empty()) {
                        results.at(index).setTitle(newTitle);
                        results.at(index).save();
                        std::cout << "Title changed to " << newTitle << std::endl;
                    }
                    break;
            }
        }
    }
} /* namespace platform */
@@ -1,31 +0,0 @@
#ifndef MANAGE_RESULTS_H
#define MANAGE_RESULTS_H
#include "Results.h"
#include "xlsxwriter.h"

namespace platform {
    class ManageResults {
    public:
        ManageResults(int numFiles, const std::string& model, const std::string& score, bool complete, bool partial, bool compare);
        ~ManageResults() = default;
        void doMenu();
    private:
        void list();
        bool confirmAction(const std::string& intent, const std::string& fileName) const;
        void report(const int index, const bool excelReport);
        void showIndex(const int index, const int idx);
        void sortList();
        void menu();
        int numFiles;
        bool indexList;
        bool openExcel;
        bool complete;
        bool partial;
        bool compare;
        Results results;
        lxw_workbook* workbook;
    };
}

#endif /* MANAGE_RESULTS_H */
src/manage/ManageScreen.cpp (new file, 564 lines)
@@ -0,0 +1,564 @@
#include <filesystem>
#include <tuple>
#include <string>
#include <algorithm>
#include <numeric>
#include "folding.hpp"
#include "common/CLocale.h"
#include "common/Paths.h"
#include "OptionsMenu.h"
#include "ManageScreen.h"
#include "reports/DatasetsConsole.h"
#include "reports/ReportConsole.h"
#include "reports/ReportExcel.h"
#include "reports/ReportExcelCompared.h"
#include <bayesnet/classifiers/TAN.h>
#include <fimdlp/CPPFImdlp.h>

namespace platform {
    const std::string STATUS_OK = "Ok.";
    const std::string STATUS_COLOR = Colors::GREEN();

    ManageScreen::ManageScreen(int rows, int cols, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial, bool compare) :
        rows{ rows }, cols{ cols }, complete{ complete }, partial{ partial }, compare{ compare }, didExcel(false), results(ResultsManager(model, score, platform, complete, partial))
    {
        results.load();
        openExcel = false;
        workbook = NULL;
        maxModel = results.maxModelSize();
        maxTitle = results.maxTitleSize();
        header_lengths = { 3, 10, maxModel, 11, 10, 12, 2, 3, 7, maxTitle };
        header_labels = { " #", "Date", "Model", "Score Name", "Score", "Platform", "SD", "C/P", "Time", "Title" };
        sort_fields = { "Date", "Model", "Score", "Time" };
        updateSize(rows, cols);
        // Initializes the paginator for each output type (experiments, datasets, result)
        for (int i = 0; i < static_cast<int>(OutputType::Count); i++) {
            paginator.push_back(Paginator(this->rows, results.size()));
        }
        index_A = -1;
        index_B = -1;
        index = -1;
        subIndex = -1;
        output_type = OutputType::EXPERIMENTS;
    }
    void ManageScreen::computeSizes()
    {
        int minTitle = 10;
        // set 10 chars as minimum for Title
        auto header_title = header_lengths[header_lengths.size() - 1];
        min_columns = std::accumulate(header_lengths.begin(), header_lengths.end(), 0) + header_lengths.size() - header_title + minTitle;
        maxTitle = minTitle + cols - min_columns;
        header_lengths[header_lengths.size() - 1] = maxTitle;
        cols = std::min(cols, min_columns + maxTitle);
        for (auto& paginator_ : paginator) {
            paginator_.setPageSize(rows);
        }
    }
    bool ManageScreen::checkWrongColumns()
    {
        if (min_columns > cols) {
            std::cerr << Colors::MAGENTA() << "Make screen bigger to fit the results! " + std::to_string(min_columns - cols) + " columns needed! " << std::endl;
            return true;
        }
        return false;
    }
    void ManageScreen::updateSize(int rows_, int cols_)
    {
        rows = std::max(6, rows_ - 6); // 6 is the number of lines used by the menu & header
        cols = cols_;
        computeSizes();
    }
    void ManageScreen::doMenu()
    {
        if (results.empty()) {
            std::cerr << Colors::MAGENTA() << "No results found!" << Colors::RESET() << std::endl;
            return;
        }
        if (checkWrongColumns())
            return;
        results.sortResults(sort_field, sort_type);
        list(STATUS_OK, STATUS_COLOR);
        menu();
        if (openExcel) {
            workbook_close(workbook);
        }
        if (didExcel) {
            std::cout << Colors::MAGENTA() << "Excel file created: " << Paths::excel() + Paths::excelResults() << std::endl;
        }
        std::cout << Colors::RESET() << "Done!" << std::endl;
    }
    std::string ManageScreen::getVersions()
    {
        std::string kfold_version = folding::KFold(5, 100).version();
        std::string bayesnet_version = bayesnet::TAN().getVersion();
        std::string mdlp_version = mdlp::CPPFImdlp::version();
        return " BayesNet: " + bayesnet_version + " Folding: " + kfold_version + " MDLP: " + mdlp_version + " ";
    }
    void ManageScreen::header()
    {
        auto [index_from, index_to] = paginator[static_cast<int>(output_type)].getOffset();
        std::string suffix = "";
        if (complete) {
            suffix = " Only listing complete results ";
        }
        if (partial) {
            suffix = " Only listing partial results ";
        }
        auto page = paginator[static_cast<int>(output_type)].getPage();
        auto pages = paginator[static_cast<int>(output_type)].getPages();
        auto lines = paginator[static_cast<int>(output_type)].getLines();
        auto total = paginator[static_cast<int>(output_type)].getTotal();
        std::string header = " Lines " + std::to_string(lines) + " of "
            + std::to_string(total) + " - Page " + std::to_string(page) + " of "
            + std::to_string(pages) + " ";
        std::string versions = getVersions();
        int filler = std::max(cols - versions.size() - suffix.size() - header.size(), size_t(0));
        std::string prefix = std::string(filler, ' ');
        std::cout << Colors::CLRSCR() << Colors::REVERSE() << Colors::WHITE() << header
            << prefix << Colors::GREEN() << versions << Colors::MAGENTA() << suffix << Colors::RESET() << std::endl;
    }
    void ManageScreen::footer(const std::string& status, const std::string& status_color)
    {
        std::stringstream oss;
        oss << " A: " << (index_A == -1 ? "<notset>" : std::to_string(index_A)) <<
            " B: " << (index_B == -1 ? "<notset>" : std::to_string(index_B)) << " ";
        int status_length = std::max(oss.str().size(), cols - oss.str().size());
        auto status_message = status.substr(0, status_length - 1);
        std::string status_line = status_message + std::string(std::max(size_t(0), status_length - status_message.size() - 1), ' ');
        auto color = (index_A != -1 && index_B != -1) ? Colors::IGREEN() : Colors::IYELLOW();
        std::cout << color << Colors::REVERSE() << oss.str() << Colors::RESET() << Colors::WHITE()
            << Colors::REVERSE() << status_color << " " << status_line << Colors::IWHITE()
            << Colors::RESET() << std::endl;
    }
    void ManageScreen::list(const std::string& status_message, const std::string& status_color)
    {
        switch (static_cast<int>(output_type)) {
            case static_cast<int>(OutputType::RESULT):
                list_result(status_message, status_color);
                break;
            case static_cast<int>(OutputType::DETAIL):
                list_detail(status_message, status_color);
                break;
            case static_cast<int>(OutputType::DATASETS):
                list_datasets(status_message, status_color);
                break;
            case static_cast<int>(OutputType::EXPERIMENTS):
                list_experiments(status_message, status_color);
                break;
        }
    }
    void ManageScreen::list_result(const std::string& status_message, const std::string& status_color)
    {
        auto data = results.at(index).getJson();
        ReportConsole report(data, compare);
        auto header_text = report.getHeader();
        auto body = report.getBody();
        paginator[static_cast<int>(output_type)].setTotal(body.size());
        // We need to subtract 8 from the page size to make room for the extra header in report
        auto page_size = paginator[static_cast<int>(OutputType::EXPERIMENTS)].getPageSize();
        paginator[static_cast<int>(output_type)].setPageSize(page_size - 8);
        //
        // header
        //
        header();
        //
        // Results
        //
        std::cout << header_text;
        auto [index_from, index_to] = paginator[static_cast<int>(output_type)].getOffset();
        for (int i = index_from; i <= index_to; i++) {
            std::cout << body[i];
        }
        //
        // Status Area
        //
        footer(status_message, status_color);
    }
    void ManageScreen::list_detail(const std::string& status_message, const std::string& status_color)
    {
        auto data = results.at(index).getJson();
        ReportConsole report(data, compare, subIndex);
        auto header_text = report.getHeader();
        auto body = report.getBody();
        paginator[static_cast<int>(output_type)].setTotal(body.size());
        // We need to subtract 8 from the page size to make room for the extra header in report
        auto page_size = paginator[static_cast<int>(OutputType::EXPERIMENTS)].getPageSize();
        paginator[static_cast<int>(output_type)].setPageSize(page_size - 8);
        //
        // header
        //
        header();
        //
        // Results
        //
        std::cout << header_text;
        auto [index_from, index_to] = paginator[static_cast<int>(output_type)].getOffset();
        for (int i = index_from; i <= index_to; i++) {
            std::cout << body[i];
        }
        //
        // Status Area
        //
        footer(status_message, status_color);
    }
    void ManageScreen::list_datasets(const std::string& status_message, const std::string& status_color)
    {
        auto report = DatasetsConsole();
        report.report();
        paginator[static_cast<int>(output_type)].setTotal(report.getNumLines());
        //
        // header
        //
        header();
        //
        // Results
        //
        auto body = report.getBody();
        std::cout << report.getHeader();
        auto [index_from, index_to] = paginator[static_cast<int>(output_type)].getOffset();
        for (int i = index_from; i <= index_to; i++) {
            std::cout << body[i];
        }
        //
        // Status Area
        //
        footer(status_message, status_color);
    }
    void ManageScreen::list_experiments(const std::string& status_message, const std::string& status_color)
    {
        //
        // header
        //
        header();
        std::cout << Colors::RESET();
        std::string arrow_dn = Symbols::down_arrow + " ";
        std::string arrow_up = Symbols::up_arrow + " ";
        for (int i = 0; i < header_labels.size(); i++) {
            std::string suffix = "", color = Colors::GREEN();
            int diff = 0;
            if (header_labels[i] == sort_fields[static_cast<int>(sort_field)]) {
                color = Colors::YELLOW();
                diff = 2;
                suffix = sort_type == SortType::ASC ? arrow_up : arrow_dn;
            }
            std::cout << color << std::setw(header_lengths[i] + diff) << std::left << std::string(header_labels[i] + suffix) << " ";
        }
        std::cout << std::endl;
        for (int i = 0; i < header_labels.size(); i++) {
            std::cout << std::string(header_lengths[i], '=') << " ";
        }
        std::cout << Colors::RESET() << std::endl;
        //
        // Results
        //
        if (results.empty()) {
            std::cout << "No results found!" << std::endl;
            return;
        }
        auto [index_from, index_to] = paginator[static_cast<int>(output_type)].getOffset();
        for (int i = index_from; i <= index_to; i++) {
            auto color = (i % 2) ? Colors::BLUE() : Colors::CYAN();
            std::cout << color << std::setw(3) << std::fixed << std::right << i << " ";
            std::cout << results.at(i).to_string(maxModel, maxTitle) << std::endl;
        }
        //
        // Status Area
        //
        footer(status_message, status_color);
    }
    bool ManageScreen::confirmAction(const std::string& intent, const std::string& fileName) const
    {
        std::string color;
        if (intent == "delete") {
            color = Colors::RED();
        } else {
            color = Colors::YELLOW();
        }
        std::string line;
        bool finished = false;
        while (!finished) {
            std::cout << color << "Really want to " << intent << " " << fileName << "? (y/n): ";
            getline(std::cin, line);
            finished = line.size() == 1 && (tolower(line[0]) == 'y' || tolower(line[0]) == 'n');
        }
        if (tolower(line[0]) == 'y') {
            return true;
        }
        std::cout << "Not done!" << std::endl;
        return false;
    }
    std::string ManageScreen::report_compared()
    {
        auto data_A = results.at(index_A).getJson();
        auto data_B = results.at(index_B).getJson();
        ReportExcelCompared reporter(data_A, data_B);
        reporter.report();
        didExcel = true;
        return results.at(index_A).getFilename() + " Vs " + results.at(index_B).getFilename();
    }
    std::string ManageScreen::report(const int index, const bool excelReport)
    {
        auto data = results.at(index).getJson();
        if (excelReport) {
            didExcel = true;
            ReportExcel reporter(data, compare, workbook);
            reporter.show();
            openExcel = true;
            workbook = reporter.getWorkbook();
            return results.at(index).getFilename() + "->" + Paths::excel() + Paths::excelResults();
        } else {
            ReportConsole reporter(data, compare);
            std::cout << Colors::CLRSCR() << reporter.fileReport();
            return "Reporting " + results.at(index).getFilename();
        }
    }
    std::pair<std::string, std::string> ManageScreen::sortList()
    {
        std::vector<std::tuple<std::string, char, bool>> sortOptions = {
            {"date", 'd', false},
            {"score", 's', false},
            {"time", 't', false},
            {"model", 'm', false},
            {"ascending+", '+', false},
            {"descending-", '-', false}
        };
        auto sortMenu = OptionsMenu(sortOptions, Colors::YELLOW(), Colors::RED(), cols);
        std::string invalid_option = "Invalid sorting option";
        char option;
        bool parserError = true; // force the first iteration
        while (parserError) {
            if (checkWrongColumns())
                return { Colors::RED(), "Invalid column size" };
            auto [min_index, max_index] = paginator[static_cast<int>(output_type)].getOffset();
            std::tie(option, index, parserError) = sortMenu.parse(' ', 0, 0);
            sortMenu.updateColumns(cols);
            if (parserError) {
                return { Colors::RED(), invalid_option };
            }
        }
        switch (option) {
            case 'd':
                sort_field = SortField::DATE;
                break;
            case 's':
                sort_field = SortField::SCORE;
                break;
            case 't':
                sort_field = SortField::DURATION;
                break;
            case 'm':
                sort_field = SortField::MODEL;
                break;
            case '+':
                sort_type = SortType::ASC;
                break;
            case '-':
                sort_type = SortType::DESC;
                break;
            default:
                return { Colors::RED(), invalid_option };
        }
        results.sortResults(sort_field, sort_type);
        return { Colors::GREEN(), "Sorted by " + sort_fields[static_cast<int>(sort_field)] + " " + (sort_type == SortType::ASC ? "ascending" : "descending") };
    }
    void ManageScreen::menu()
    {
        char option;
        bool finished = false;
        std::string filename;
        // tuple<Option, digit, requires value>
        std::vector<std::tuple<std::string, char, bool>> mainOptions = {
            {"quit", 'q', false},
            {"list", 'l', false},
            {"Delete", 'D', true},
            {"datasets", 'd', false},
            {"hide", 'h', true},
            {"sort", 's', false},
            {"report", 'r', true},
            {"excel", 'e', true},
            {"title", 't', true},
            {"set A", 'A', true},
            {"set B", 'B', true},
            {"compare A~B", 'c', false},
            {"page", 'p', true},
            {"Page+", '+', false },
            {"Page-", '-', false}
        };
        // tuple<Option, digit, requires value>
        std::vector<std::tuple<std::string, char, bool>> listOptions = {
            {"quit", 'q', false},
            {"report", 'r', true},
            {"list", 'l', false},
            {"excel", 'e', true},
            {"back", 'b', false},
            {"page", 'p', true},
            {"Page+", '+', false},
            {"Page-", '-', false}
        };
        while (!finished) {
            auto main_menu = OptionsMenu(mainOptions, Colors::IGREEN(), Colors::YELLOW(), cols);
            auto list_menu = OptionsMenu(listOptions, Colors::IBLUE(), Colors::YELLOW(), cols);
            OptionsMenu& menu = output_type == OutputType::EXPERIMENTS ? main_menu : list_menu;
            bool parserError = true; // force the first iteration
            while (parserError) {
                int index_menu;
                auto [min_index, max_index] = paginator[static_cast<int>(output_type)].getOffset();
                std::tie(option, index_menu, parserError) = menu.parse('r', min_index, max_index);
                if (output_type == OutputType::EXPERIMENTS) {
                    index = index_menu;
                } else {
                    subIndex = index_menu;
                }
                if (min_columns > cols) {
                    std::cerr << "Make screen bigger to fit the results! " + std::to_string(min_columns - cols) + " columns needed! " << std::endl;
                    return;
                }
                menu.updateColumns(cols);
                if (parserError) {
                    list(menu.getErrorMessage(), Colors::RED());
                }
            }
            switch (option) {
                case 'd':
                    output_type = OutputType::DATASETS;
                    list_datasets(STATUS_OK, STATUS_COLOR);
                    break;
                case 'p':
                {
                    auto page = output_type == OutputType::EXPERIMENTS ? index : subIndex;
                    if (paginator[static_cast<int>(output_type)].setPage(page)) {
                        list(STATUS_OK, STATUS_COLOR);
                    } else {
                        list("Invalid page! (" + std::to_string(page) + ")", Colors::RED());
                    }
                }
                break;
                case '+':
                    if (paginator[static_cast<int>(output_type)].addPage()) {
                        list(STATUS_OK, STATUS_COLOR);
                    } else {
                        list("No more pages!", Colors::RED());
                    }
                    break;
                case '-':
                    if (paginator[static_cast<int>(output_type)].subPage()) {
                        list(STATUS_OK, STATUS_COLOR);
                    } else {
                        list("First page already!", Colors::RED());
                    }
                    break;
                case 'q':
                    finished = true;
                    break;
                case 'A':
                    if (index == index_B) {
                        list("A and B cannot be the same!", Colors::RED());
                        break;
                    }
                    index_A = index;
                    list("A set to " + std::to_string(index), Colors::GREEN());
                    break;
                case 'B': // set_b or back to list
                    if (output_type == OutputType::EXPERIMENTS) {
                        if (index == index_A) {
                            list("A and B cannot be the same!", Colors::RED());
                            break;
                        }
                        index_B = index;
                        list("B set to " + std::to_string(index), Colors::GREEN());
                    } else {
                        // back to show the report
                        output_type = OutputType::RESULT;
                        paginator[static_cast<int>(OutputType::DETAIL)].setPage(1);
                        list(STATUS_OK, STATUS_COLOR);
                    }
                    break;
                case 'c':
                    if (index_A == -1 || index_B == -1) {
                        list("Need to set A and B first!", Colors::RED());
                        break;
                    }
                    list(report_compared(), Colors::GREEN());
                    break;
                case 'l':
                    output_type = OutputType::EXPERIMENTS;
                    paginator[static_cast<int>(OutputType::DATASETS)].setPage(1);
                    paginator[static_cast<int>(OutputType::RESULT)].setPage(1);
                    paginator[static_cast<int>(OutputType::DETAIL)].setPage(1);
                    list(STATUS_OK, STATUS_COLOR);
                    break;
                case 'D':
                    filename = results.at(index).getFilename();
                    if (!confirmAction("delete", filename)) {
                        list(filename + " not deleted!", Colors::YELLOW());
                        break;
                    }
                    std::cout << "Deleting " << filename << std::endl;
                    results.deleteResult(index);
                    paginator[static_cast<int>(OutputType::EXPERIMENTS)].setTotal(results.size());
                    list(filename + " deleted!", Colors::RED());
                    break;
                case 'h':
                {
                    std::string status_message;
                    filename = results.at(index).getFilename();
                    if (!confirmAction("hide", filename)) {
                        list(filename + " not hidden!", Colors::YELLOW());
                        break;
                    }
                    filename = results.at(index).getFilename();
                    std::cout << "Hiding " << filename << std::endl;
                    results.hideResult(index, Paths::hiddenResults());
                    status_message = filename + " hidden! (moved to " + Paths::hiddenResults() + ")";
                    paginator[static_cast<int>(OutputType::EXPERIMENTS)].setTotal(results.size());
                    list(status_message, Colors::YELLOW());
                }
                break;
                case 's':
                {
                    std::string status_message, status_color;
                    tie(status_color, status_message) = sortList();
                    list(status_message, status_color);
                }
                break;
                case 'r':
                    if (output_type == OutputType::DATASETS) {
                        list(STATUS_OK, STATUS_COLOR);
                        break;
                    }
                    if (output_type == OutputType::EXPERIMENTS) {
                        output_type = OutputType::RESULT;
                        paginator[static_cast<int>(OutputType::DETAIL)].setPage(1);
                        list(STATUS_OK, STATUS_COLOR);
                    } else {
                        output_type = OutputType::DETAIL;
                        list(STATUS_OK, STATUS_COLOR);
                    }
                    break;
                case 'e':
                    if (output_type == OutputType::EXPERIMENTS) {
                        list(report(index, true), Colors::GREEN());
                        break;
                    }
                    list(report(subIndex, true), Colors::GREEN());
                    break;
                case 't':
                {
                    std::string status_message;
                    std::cout << "Title: " << results.at(index).getTitle() << std::endl;
                    std::cout << "New title: ";
                    std::string newTitle;
                    getline(std::cin, newTitle);
                    if (!newTitle.empty()) {
                        results.at(index).setTitle(newTitle);
                        results.at(index).save();
                        status_message = "Title changed to " + newTitle;
                        list(status_message, Colors::GREEN());
                        break;
                    }
                    list("No title change!", Colors::YELLOW());
                }
                break;
            }
        }
    }
} /* namespace platform */
src/manage/ManageScreen.h (new file, 62 lines)
@@ -0,0 +1,62 @@
#ifndef MANAGE_SCREEN_H
#define MANAGE_SCREEN_H
#include <xlsxwriter.h>
#include "ResultsManager.h"
#include "common/Colors.h"
#include "Paginator.hpp"

namespace platform {
    enum class OutputType {
        EXPERIMENTS = 0,
        DATASETS = 1,
        RESULT = 2,
        DETAIL = 3,
        Count
    };
    class ManageScreen {
    public:
        ManageScreen(int rows, int cols, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial, bool compare);
        ~ManageScreen() = default;
        void doMenu();
        void updateSize(int rows, int cols);
    private:
        void list(const std::string& status, const std::string& color);
        void list_experiments(const std::string& status, const std::string& color);
        void list_result(const std::string& status, const std::string& color);
        void list_detail(const std::string& status, const std::string& color);
        void list_datasets(const std::string& status, const std::string& color);
        bool confirmAction(const std::string& intent, const std::string& fileName) const;
        std::string report(const int index, const bool excelReport);
        std::string report_compared();
        std::pair<std::string, std::string> sortList();
        std::string getVersions();
        void computeSizes();
        bool checkWrongColumns();
        void menu();
        void header();
        void footer(const std::string& status, const std::string& color);
        OutputType output_type;
        int rows;
        int cols;
        int min_columns;
        int index;
        int subIndex;
        int index_A, index_B; // used for comparison of experiments
        bool indexList;
        bool openExcel;
        bool didExcel;
        bool complete;
        bool partial;
        bool compare;
        int maxModel, maxTitle;
        std::vector<std::string> header_labels;
        std::vector<int> header_lengths;
        std::vector<std::string> sort_fields;
        SortField sort_field = SortField::DATE;
        SortType sort_type = SortType::DESC;
        std::vector<Paginator> paginator;
        ResultsManager results;
        lxw_workbook* workbook;
    };
}
#endif
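A hypothetical driver for this class; the real entry point that constructs ManageScreen is not shown in this diff, and terminal size detection is only hinted at in a comment:

#include "ManageScreen.h"

int main()
{
    // rows/cols would normally come from the terminal (e.g. ioctl with TIOCGWINSZ)
    platform::ManageScreen screen(24, 80, "any", "any", "any", false, false, false);
    screen.doMenu();
    return 0;
}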
src/manage/OptionsMenu.cpp (new file, 102 lines)
@@ -0,0 +1,102 @@
#include "OptionsMenu.h"
#include <iostream>
#include <sstream>
#include <algorithm>
#include "common/Utils.h"

namespace platform {
    std::string OptionsMenu::to_string()
    {
        bool first = true;
        std::string result = color_normal + "Options: (";
        size_t size = 10; // Size of "Options: ("
        for (auto& option : options) {
            if (!first) {
                result += ", ";
                size += 2;
            }
            std::string title = std::get<0>(option);
            auto pos = title.find(std::get<1>(option));
            result += color_normal + title.substr(0, pos) + color_bold + title.substr(pos, 1) + color_normal + title.substr(pos + 1);
            size += title.size();
            first = false;
        }
        if (size + 3 > cols) { // 3 is the size of the "): " at the end
            result = "";
            first = true;
            for (auto& option : options) {
                if (!first) {
                    result += color_normal + ", ";
                }
                result += color_bold + std::get<1>(option);
                first = false;
            }
        }
        result += "): ";
        return result;
    }
    std::tuple<char, int, bool> OptionsMenu::parse(char defaultCommand, int minIndex, int maxIndex)
    {
        bool finished = false;
        while (!finished) {
            std::cout << to_string();
            std::string line;
            getline(std::cin, line);
            line = trim(line);
            if (line.size() == 0) {
                errorMessage = "No command";
                return { defaultCommand, 0, true };
            }
            if (all_of(line.begin(), line.end(), ::isdigit)) {
                command = defaultCommand;
                index = stoi(line);
                if (index > maxIndex || index < minIndex) {
                    errorMessage = "Index out of range";
                    return { ' ', -1, true };
                }
                finished = true;
                break;
            }
            bool found = false;
            for (auto& option : options) {
                if (line[0] == std::get<char>(option)) {
                    found = true;
                    // it's a match
                    line.erase(line.begin());
                    line = trim(line);
                    if (std::get<bool>(option)) {
                        // The option requires a value
                        if (line.size() == 0) {
                            errorMessage = "Option " + std::get<std::string>(option) + " requires a value";
                            return { command, index, true };
                        }
                        try {
                            index = stoi(line);
                            if (index > maxIndex || index < 0) {
                                errorMessage = "Index out of range";
                                return { command, index, true };
                            }
                        }
                        catch (const std::invalid_argument& ia) {
                            errorMessage = "Invalid value: " + line;
                            return { command, index, true };
                        }
                    } else {
                        if (line.size() > 0) {
                            errorMessage = "option " + std::get<std::string>(option) + " doesn't accept values";
                            return { command, index, true };
                        }
                    }
                    command = std::get<char>(option);
                    finished = true;
                    break;
                }
            }
            if (!found) {
                errorMessage = "I don't know " + line;
                return { command, index, true };
            }
        }
        return { command, index, false };
    }
} /* namespace platform */
src/manage/OptionsMenu.h (new file, 26 lines)
@@ -0,0 +1,26 @@
#ifndef OPTIONS_MENU_H
#define OPTIONS_MENU_H
#include <string>
#include <vector>
#include <tuple>

namespace platform {
    class OptionsMenu {
    public:
        OptionsMenu(std::vector<std::tuple<std::string, char, bool>>& options, std::string color_normal, std::string color_bold, int cols) : options(options), color_normal(color_normal), color_bold(color_bold), cols(cols) {}
        std::string to_string();
        std::tuple<char, int, bool> parse(char defaultCommand, int minIndex, int maxIndex);
        char getCommand() const { return command; };
        int getIndex() const { return index; };
        std::string getErrorMessage() const { return errorMessage; };
        void updateColumns(int cols) { this->cols = cols; }
    private:
        std::vector<std::tuple<std::string, char, bool>>& options;
        std::string color_normal, color_bold;
        int cols;
        std::string errorMessage;
        char command;
        int index;
    };
} /* namespace platform */
#endif
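A small usage example of OptionsMenu, grounded in the declaration above; the raw ANSI escape strings stand in for the repo's Colors helpers:

#include "OptionsMenu.h"
#include <iostream>

int main()
{
    std::vector<std::tuple<std::string, char, bool>> options = {
        {"quit", 'q', false},
        {"report", 'r', true}, // true = the option expects an index value
    };
    platform::OptionsMenu menu(options, "\033[0m", "\033[1m", 80);
    // A bare number on stdin selects the default command 'r' with that index
    auto [command, index, error] = menu.parse('r', 0, 9);
    if (error)
        std::cout << menu.getErrorMessage() << std::endl;
    else
        std::cout << "command=" << command << " index=" << index << std::endl;
    return 0;
}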
src/manage/Paginator.hpp (new file, 57 lines)
@@ -0,0 +1,57 @@
#ifndef PAGINATOR_HPP
#define PAGINATOR_HPP
#include <algorithm>
#include <string>
#include <utility>

class Paginator {
public:
    Paginator() = default;
    Paginator(int pageSize, int total, int page = 1) : pageSize(pageSize), total(total), page(page)
    {
        computePages();
    };
    ~Paginator() = default;
    // Getters
    int getPageSize() const { return pageSize; }
    int getLines() const
    {
        auto [start, end] = getOffset();
        return std::min(pageSize, end - start + 1);
    }
    int getPage() const { return page; }
    int getTotal() const { return total; }
    int getPages() const { return numPages; }
    std::pair<int, int> getOffset() const
    {
        return { (page - 1) * pageSize, std::min(total - 1, page * pageSize - 1) };
    }
    // Setters
    void setTotal(int total) { this->total = total; computePages(); }
    void setPageSize(int page) { this->pageSize = page; computePages(); }
    bool setPage(int page) { return valid(page) ? this->page = page, true : false; }
    // Utils
    bool valid(int page) const { return page > 0 && page <= numPages; }
    bool hasPrev(int page) const { return page > 1; }
    bool hasNext(int page) const { return page < getPages(); }
    bool addPage() { return page < numPages ? ++page, true : false; }
    bool subPage() { return page > 1 ? --page, true : false; }
    std::string to_string() const
    {
        auto offset = getOffset();
        return "Paginator: { pageSize: " + std::to_string(pageSize) + ", total: " + std::to_string(total)
            + ", page: " + std::to_string(page) + ", numPages: " + std::to_string(numPages)
            + " Offset [" + std::to_string(offset.first) + ", " + std::to_string(offset.second) + "]}";
    }
private:
    void computePages()
    {
        numPages = pageSize > 0 ? (total + pageSize - 1) / pageSize : 0;
        if (page > numPages) {
            page = numPages;
        }
    }
    int pageSize;
    int total;
    int page;
    int numPages;
};
#endif
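A quick usage example of Paginator, following the header above:

#include <iostream>
#include "Paginator.hpp"

int main()
{
    Paginator p(20, 45); // 20 lines per page, 45 items -> 3 pages
    auto [from, to] = p.getOffset();
    std::cout << "page 1 covers [" << from << ", " << to << "]\n"; // [0, 19]
    p.addPage();  // advance to page 2 -> offset [20, 39]
    p.setPage(3); // last, partial page -> offset [40, 44]
    std::cout << p.to_string() << std::endl;
    return 0;
}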
@@ -1,75 +0,0 @@
#include "Results.h"
#include <algorithm>

namespace platform {
    Results::Results(const std::string& path, const std::string& model, const std::string& score, bool complete, bool partial) :
        path(path), model(model), scoreName(score), complete(complete), partial(partial)
    {
        load();
        if (!files.empty()) {
            maxModel = (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getModel().size() < b.getModel().size(); })).getModel().size();
        } else {
            maxModel = 0;
        }
    }
    void Results::load()
    {
        using std::filesystem::directory_iterator;
        for (const auto& file : directory_iterator(path)) {
            auto filename = file.path().filename().string();
            if (filename.find(".json") != std::string::npos && filename.find("results_") == 0) {
                auto result = Result();
                result.load(path, filename);
                bool addResult = true;
                if (model != "any" && result.getModel() != model || scoreName != "any" && scoreName != result.getScoreName() || complete && !result.isComplete() || partial && result.isComplete())
                    addResult = false;
                if (addResult)
                    files.push_back(result);
            }
        }
    }
    void Results::hideResult(int index, const std::string& pathHidden)
    {
        auto filename = files.at(index).getFilename();
        rename((path + "/" + filename).c_str(), (pathHidden + "/" + filename).c_str());
        files.erase(files.begin() + index);
    }
    void Results::deleteResult(int index)
    {
        auto filename = files.at(index).getFilename();
        remove((path + "/" + filename).c_str());
        files.erase(files.begin() + index);
    }
    int Results::size() const
    {
        return files.size();
    }
    void Results::sortDate()
    {
        sort(files.begin(), files.end(), [](const Result& a, const Result& b) {
            return a.getDate() > b.getDate();
            });
    }
    void Results::sortModel()
    {
        sort(files.begin(), files.end(), [](const Result& a, const Result& b) {
            return a.getModel() > b.getModel();
            });
    }
    void Results::sortDuration()
    {
        sort(files.begin(), files.end(), [](const Result& a, const Result& b) {
            return a.getDuration() > b.getDuration();
            });
    }
    void Results::sortScore()
    {
        sort(files.begin(), files.end(), [](const Result& a, const Result& b) {
            return a.getScore() > b.getScore();
            });
    }
    bool Results::empty() const
    {
        return files.empty();
    }
}
130
src/manage/ResultsManager.cpp
Normal file
130
src/manage/ResultsManager.cpp
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include "common/Paths.h"
|
||||||
|
#include "ResultsManager.h"
|
||||||
|
|
||||||
|
namespace platform {
|
||||||
|
ResultsManager::ResultsManager(const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial) :
|
||||||
|
path(Paths::results()), model(model), scoreName(score), platform(platform), complete(complete), partial(partial), maxModel(0), maxTitle(0)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
void ResultsManager::load()
|
||||||
|
{
|
||||||
|
using std::filesystem::directory_iterator;
|
||||||
|
bool found = false;
|
||||||
|
for (const auto& file : directory_iterator(path)) {
|
||||||
|
auto filename = file.path().filename().string();
|
||||||
|
if (filename.find(".json") != std::string::npos && filename.find("results_") == 0) {
|
||||||
|
auto result = Result();
|
||||||
|
result.load(path, filename);
|
||||||
|
bool addResult = true;
|
||||||
|
if (platform != "any" && result.getPlatform() != platform
|
||||||
|
|| model != "any" && result.getModel() != model
|
||||||
|
|| scoreName != "any" && scoreName != result.getScoreName()
|
||||||
|
|| complete && !result.isComplete()
|
||||||
|
|| partial && result.isComplete())
|
||||||
|
addResult = false;
|
||||||
|
if (addResult) {
|
||||||
|
files.push_back(result);
|
||||||
|
found = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (found) {
|
||||||
|
maxModel = std::max(size_t(5), (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getModel().size() < b.getModel().size(); })).getModel().size());
|
||||||
|
maxTitle = std::max(size_t(5), (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getTitle().size() < b.getTitle().size(); })).getTitle().size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
    void ResultsManager::hideResult(int index, const std::string& pathHidden)
    {
        auto filename = files.at(index).getFilename();
        // Move the file out of the results folder instead of deleting it.
        std::rename((path + "/" + filename).c_str(), (pathHidden + "/" + filename).c_str());
        files.erase(files.begin() + index);
    }
    void ResultsManager::deleteResult(int index)
    {
        auto filename = files.at(index).getFilename();
        std::remove((path + "/" + filename).c_str());
        files.erase(files.begin() + index);
    }
    int ResultsManager::size() const
    {
        return static_cast<int>(files.size());
    }
    void ResultsManager::sortDate(SortType type)
    {
        if (empty())
            return;
        std::sort(files.begin(), files.end(), [type](const Result& a, const Result& b) {
            if (a.getDate() == b.getDate()) {
                if (type == SortType::ASC)
                    return a.getModel() < b.getModel();
                return a.getModel() > b.getModel();
            }
            if (type == SortType::ASC)
                return a.getDate() < b.getDate();
            return a.getDate() > b.getDate();
            });
    }
    void ResultsManager::sortModel(SortType type)
    {
        if (empty())
            return;
        std::sort(files.begin(), files.end(), [type](const Result& a, const Result& b) {
            if (a.getModel() == b.getModel()) {
                if (type == SortType::ASC)
                    return a.getDate() < b.getDate();
                return a.getDate() > b.getDate();
            }
            if (type == SortType::ASC)
                return a.getModel() < b.getModel();
            return a.getModel() > b.getModel();
            });
    }
    void ResultsManager::sortDuration(SortType type)
    {
        if (empty())
            return;
        std::sort(files.begin(), files.end(), [type](const Result& a, const Result& b) {
            if (type == SortType::ASC)
                return a.getDuration() < b.getDuration();
            return a.getDuration() > b.getDuration();
            });
    }
    void ResultsManager::sortScore(SortType type)
    {
        if (empty())
            return;
        std::sort(files.begin(), files.end(), [type](const Result& a, const Result& b) {
            if (a.getScore() == b.getScore()) {
                if (type == SortType::ASC)
                    return a.getDate() < b.getDate();
                return a.getDate() > b.getDate();
            }
            if (type == SortType::ASC)
                return a.getScore() < b.getScore();
            return a.getScore() > b.getScore();
            });
    }
    void ResultsManager::sortResults(SortField field, SortType type)
    {
        switch (field) {
            case SortField::DATE:
                sortDate(type);
                break;
            case SortField::MODEL:
                sortModel(type);
                break;
            case SortField::SCORE:
                sortScore(type);
                break;
            case SortField::DURATION:
                sortDuration(type);
                break;
        }
    }
    bool ResultsManager::empty() const
    {
        return files.empty();
    }
}
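For orientation, a minimal usage sketch of the ResultsManager added above. It uses only members shown in this diff; the include path assumes src/ as the include root, and the surrounding main() is illustrative:

#include <iostream>
#include "manage/ResultsManager.h"

int main() {
    // Wildcard filters: "any" disables the platform/model/score checks in load().
    platform::ResultsManager manager("any", "any", "any", false, false);
    manager.load();
    if (manager.empty()) {
        std::cout << "no results_*.json files found" << std::endl;
        return 1;
    }
    // Dispatches to sortScore(SortType::DESC): best score first, ties broken by date.
    manager.sortResults(platform::SortField::SCORE, platform::SortType::DESC);
    std::cout << manager.size() << " results, top model: "
              << manager.at(0).getModel() << std::endl;
    return 0;
}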
49
src/manage/ResultsManager.h
Normal file
@@ -0,0 +1,49 @@
#ifndef RESULTSMANAGER_H
#define RESULTSMANAGER_H
#include <vector>
#include <string>
#include <nlohmann/json.hpp>
#include "results/Result.h"
namespace platform {
    using json = nlohmann::ordered_json;
    enum class SortType {
        ASC = 0,
        DESC = 1,
    };
    enum class SortField {
        DATE = 0,
        MODEL = 1,
        SCORE = 2,
        DURATION = 3,
    };
    class ResultsManager {
    public:
        ResultsManager(const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial);
        void load(); // Loads the list of results
        void sortResults(SortField field, SortType type); // Sorts the list of results
        void sortDate(SortType type);
        void sortScore(SortType type);
        void sortModel(SortType type);
        void sortDuration(SortType type);
        int maxModelSize() const { return maxModel; }
        int maxTitleSize() const { return maxTitle; }
        void hideResult(int index, const std::string& pathHidden);
        void deleteResult(int index);
        int size() const;
        bool empty() const;
        std::vector<Result>::iterator begin() { return files.begin(); }
        std::vector<Result>::iterator end() { return files.end(); }
        Result& at(int index) { return files.at(index); }
    private:
        std::string path;
        std::string model;
        std::string scoreName;
        std::string platform;
        bool complete;
        bool partial;
        int maxModel;
        int maxTitle;
        std::vector<Result> files;
    };
}
#endif // RESULTSMANAGER_H
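The maxModelSize() / maxTitleSize() accessors expose the column widths computed in load(), which suggests they feed fixed-width listings. A hedged sketch of that pattern follows; the std::setw layout is an assumption, not taken from this diff:

#include <iomanip>
#include <iostream>
#include "manage/ResultsManager.h"

int main() {
    platform::ResultsManager manager("any", "any", "any", false, false);
    manager.load();
    for (auto& result : manager) {
        // Pad the model/title columns to the widest entry (minimum 5, per load()).
        std::cout << std::left
                  << std::setw(manager.maxModelSize()) << result.getModel() << "  "
                  << std::setw(manager.maxTitleSize()) << result.getTitle() << "\n";
    }
    return 0;
}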
Some files were not shown because too many files have changed in this diff.