Compare commits: 676637fb1b ... main (34 commits)

89142f8997 · 17ee6a909a · 56d85b1a43 · 481c702302 · 3e0b790cfe · e2a0c5f4a5 · aa77745e55 · e5227c5f4b · ed380b1494 · 2c7352ac38 · 0ce7f664b4 · 62fa85a1b3 · 97894cc49c · 090172c6c5 · 3048244a27 · c142ff2c4a · a5841000d3 · e7e80cfa9c · 1d58cea276 · 189d314990 · dfa74056f5 · 839be5335d · 28be43db02 · 55a24fbaf0 · 3b170324f4 · 8ccc7e263c · b1e25a7d05 · 3cb454d4aa · 3178bcbda9 · 32d231cdaf · 526d036d75 · 7a9d4178d9 · 3e94d400e2 · 31fa9cd498
.gitignore, vendored (1 change)

```diff
@@ -46,3 +46,4 @@ docs/man
 docs/Doxyfile
 .cache
 vcpkg_installed
+CMakeUserPresets.json
```
CHANGELOG.md (16 changes)

```diff
@@ -5,9 +5,17 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

 ## [Unreleased]
 ## [1.2.1] - 2025-07-19

 ## [1.2.0] - 2025-06-30
 ### Internal

 - Update Libtorch to version 2.7.1
 - Update libraries versions:
   - mdlp: 2.1.1
   - Folding: 1.1.2
   - ArffFiles: 1.2.1

 ## [1.2.0] - 2025-07-08

 ### Internal

@@ -17,6 +25,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - *ld_proposed_cuts*: number of cut points to return.
 - *mdlp_min_length*: minimum length of a partition in MDLP algorithm to be evaluated for partition.
 - *mdlp_max_depth*: maximum level of recursion in MDLP algorithm.
+- *max_iterations*: maximum number of iterations of discretization-build model loop.
+- *verbose_convergence*: display status messages during the convergence process.
+- Remove vcpkg as a dependency manager, now the library is built with Conan package manager and CMake.
+- Add `build_type` option to the sample target in the Makefile to allow building in *Debug* or *Release* mode. Default is *Debug*.

 ## [1.1.1] - 2025-05-20
```
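The discretization and convergence hyperparameters listed above are validated through `validHyperparameters_ld` (see the `Proposal.h` diff further down). A minimal sketch of setting them from client code; the header path and the choice of `TANLd` are assumptions, not shown on this page:

```cpp
#include <nlohmann/json.hpp>
#include <bayesnet/classifiers/TANLd.h>  // header path assumed

int main() {
    bayesnet::TANLd clf;
    nlohmann::json hyperparameters = {
        { "ld_proposed_cuts", 4 },       // cut points the discretizer may return
        { "mdlp_min_length", 3 },        // minimum partition length MDLP evaluates
        { "mdlp_max_depth", 2 },         // maximum MDLP recursion depth
        { "max_iterations", 10 },        // discretization-build model loop limit
        { "verbose_convergence", true }  // print convergence status messages
    };
    clf.setHyperparameters(hyperparameters);
    return 0;
}
```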
CLAUDE.md (95 changes)

````diff
@@ -9,11 +9,26 @@ BayesNet is a C++ library implementing Bayesian Network Classifiers. It provides
 ## Build System & Dependencies

 ### Dependency Management
-- Uses **vcpkg** for package management with private registry at https://github.com/rmontanana/vcpkg-stash
+
+The project supports **two package managers**:
+
+#### vcpkg (Default)
+
+- Uses vcpkg with private registry at <https://github.com/rmontanana/vcpkg-stash>
 - Core dependencies: libtorch, nlohmann-json, folding, fimdlp, arff-files, catch2
 - All dependencies defined in `vcpkg.json` with version overrides
+
+#### Conan (Alternative)
+
+- Modern C++ package manager with better dependency resolution
+- Configured via `conanfile.py` for packaging and distribution
+- Supports subset of dependencies (libtorch, nlohmann-json, catch2)
+- Custom dependencies (folding, fimdlp, arff-files) need custom Conan recipes

 ### Build Commands

+#### Using vcpkg (Default)
+
 ```bash
 # Initialize dependencies
 make init
@@ -37,11 +52,40 @@ make viewcoverage
 make clean
 ```

+#### Using Conan
+
+```bash
+# Install Conan first: pip install conan
+
+# Initialize dependencies
+make conan-init
+
+# Build debug version (with tests and coverage)
+make conan-debug
+make buildd
+
+# Build release version
+make conan-release
+make buildr
+
+# Create and test Conan package
+make conan-create
+
+# Upload to Conan remote
+make conan-upload remote=myremote
+
+# Clean Conan cache and builds
+make conan-clean
+```
+
 ### CMake Configuration

 - Uses CMake 3.27+ with C++17 standard
 - Debug builds automatically enable testing and coverage
 - Release builds optimize with `-Ofast`
-- Supports both static library and vcpkg package installation
+- **Automatic package manager detection**: CMake detects whether Conan or vcpkg is being used
+- Supports both static library and package manager installation
+- Conditional dependency linking based on availability

 ## Testing Framework

@@ -51,6 +95,7 @@ make clean
 - Coverage reporting with lcov/genhtml

 ### Test Categories

 - A2DE, BoostA2DE, BoostAODE, XSPODE, XSPnDE, XBAODE, XBA2DE
 - Classifier, Ensemble, FeatureSelection, Metrics, Models
 - Network, Node, MST, Modules
@@ -58,6 +103,7 @@ make clean
 ## Code Architecture

 ### Core Structure

 ```
 bayesnet/
 ├── BaseClassifier.h      # Abstract base for all classifiers
@@ -69,12 +115,14 @@ bayesnet/
 ```

 ### Key Design Patterns

 - **BaseClassifier** abstract interface for all algorithms
 - Template-based design with both std::vector and torch::Tensor support
 - Network/Node abstraction for Bayesian network representation
 - Feature selection as separate, composable modules

 ### Data Handling

 - Supports both discrete integer data and continuous data with discretization
 - ARFF file format support through arff-files library
 - Tensor operations via PyTorch C++ (libtorch)
@@ -90,13 +138,54 @@ bayesnet/
 ## Sample Applications

 Sample code in `sample/` directory demonstrates library usage:

 ```bash
 make sample fname=tests/data/iris.arff model=TANLd
 ```

+## Package Distribution
+
+### Creating Conan Packages
+
+```bash
+# Create package locally
+make conan-create
+
+# Test package installation
+cd test_package
+conan create ..
+
+# Upload to remote repository
+make conan-upload remote=myremote profile=myprofile
+```
+
+### Using the Library
+
+With Conan:
+
+```python
+# conanfile.txt or conanfile.py
+[requires]
+bayesnet/1.1.2@user/channel
+
+[generators]
+cmake
+```
+
+With vcpkg:
+
+```json
+{
+  "dependencies": ["bayesnet"]
+}
+```
+
 ## Common Development Tasks

 - **Add new classifier**: Extend BaseClassifier, implement in appropriate subdirectory
 - **Add new test**: Update `tests/CMakeLists.txt` and create test in `tests/`
 - **Modify build**: Edit main `CMakeLists.txt` or use Makefile targets
-- **Update dependencies**: Modify `vcpkg.json` and run `make init`
+- **Update dependencies**:
+  - vcpkg: Modify `vcpkg.json` and run `make init`
+  - Conan: Modify `conanfile.py` and run `make conan-init`
+- **Package for distribution**: Use `make conan-create` for Conan packaging
````
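As a complement to the package-manager snippets above, here is a hedged sketch of driving the library directly once it is installed. The header layout, the `Smoothing_t` enumerator, and the toy data are assumptions; only the `fit`/`predict` signatures appear in the classifier diffs below:

```cpp
#include <map>
#include <string>
#include <vector>
#include <torch/torch.h>
#include <bayesnet/classifiers/TAN.h>  // header path assumed

int main() {
    // Two discrete features (one row per feature) and binary labels,
    // matching the fit(torch::Tensor& X, torch::Tensor& y, ...) overload.
    torch::Tensor X = torch::randint(0, 3, {2, 100}, torch::kInt32);
    torch::Tensor y = torch::randint(0, 2, {100}, torch::kInt32);
    std::vector<std::string> features{"f0", "f1"};
    std::string className{"class"};
    std::map<std::string, std::vector<int>> states{
        {"f0", {0, 1, 2}}, {"f1", {0, 1, 2}}, {"class", {0, 1}}};

    bayesnet::TAN clf;
    clf.fit(X, y, features, className, states,
            bayesnet::Smoothing_t::ORIGINAL);  // enumerator value assumed
    torch::Tensor y_pred = clf.predict(X);
    return 0;
}
```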
CMakeLists.txt

```diff
@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.27)

 project(bayesnet
-  VERSION 1.1.2
+  VERSION 1.2.1
   DESCRIPTION "Bayesian Network and basic classifiers Library."
   HOMEPAGE_URL "https://github.com/rmontanana/bayesnet"
   LANGUAGES CXX
@@ -10,11 +10,6 @@ project(bayesnet
 set(CMAKE_CXX_STANDARD 17)
 cmake_policy(SET CMP0135 NEW)

-find_package(Torch CONFIG REQUIRED)
-find_package(fimdlp CONFIG REQUIRED)
-find_package(nlohmann_json CONFIG REQUIRED)
-find_package(folding CONFIG REQUIRED)

 # Global CMake variables
 # ----------------------
 set(CMAKE_CXX_STANDARD 17)
@@ -23,25 +18,34 @@ set(CMAKE_CXX_EXTENSIONS OFF)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
 SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
-set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fprofile-arcs -ftest-coverage -fno-elide-constructors")
-set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Ofast")
-if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
-    set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fno-default-inline")
-endif()
+set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")


-if (CMAKE_BUILD_TYPE STREQUAL "Debug")
-    MESSAGE("Debug mode")
-else(CMAKE_BUILD_TYPE STREQUAL "Debug")
-    MESSAGE("Release mode")
-endif (CMAKE_BUILD_TYPE STREQUAL "Debug")

 # Options
 # -------
 option(ENABLE_CLANG_TIDY "Enable to add clang tidy" OFF)
 option(ENABLE_TESTING "Unit testing build" OFF)
 option(CODE_COVERAGE "Collect coverage from test library" OFF)
 option(INSTALL_GTEST "Enable installation of googletest" OFF)

+find_package(Torch CONFIG REQUIRED)
+if(NOT TARGET torch::torch)
+    add_library(torch::torch INTERFACE IMPORTED GLOBAL)
+    # expose include paths and libraries that the find-module discovered
+    set_target_properties(torch::torch PROPERTIES
+        INTERFACE_INCLUDE_DIRECTORIES "${TORCH_INCLUDE_DIRS}"
+        INTERFACE_LINK_LIBRARIES "${TORCH_LIBRARIES}")
+endif()
+
+find_package(fimdlp CONFIG REQUIRED)
+find_package(folding CONFIG REQUIRED)
+find_package(nlohmann_json REQUIRED)

 add_subdirectory(config)

 if (ENABLE_CLANG_TIDY)
     include(StaticAnalyzers) # clang-tidy
 endif (ENABLE_CLANG_TIDY)

 # Add the library
 # ---------------
 include_directories(
@@ -52,24 +56,30 @@ include_directories(
 file(GLOB_RECURSE Sources "bayesnet/*.cc")

 add_library(bayesnet ${Sources})
-target_link_libraries(bayesnet fimdlp::fimdlp folding::folding "${TORCH_LIBRARIES}")
+target_link_libraries(bayesnet
+    nlohmann_json::nlohmann_json
+    folding::folding
+    fimdlp::fimdlp
+    torch::torch
+    arff-files::arff-files
+)

 # Testing
 # -------
 if (CMAKE_BUILD_TYPE STREQUAL "Debug")
     MESSAGE("Debug mode")
     set(ENABLE_TESTING ON)
     set(CODE_COVERAGE ON)
 endif (CMAKE_BUILD_TYPE STREQUAL "Debug")
 if (ENABLE_TESTING)
     MESSAGE(STATUS "Testing enabled")
+    set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fprofile-arcs -ftest-coverage -fno-elide-constructors")
+    if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+        set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fno-default-inline")
+    endif()
     find_package(Catch2 CONFIG REQUIRED)
+    find_package(arff-files CONFIG REQUIRED)
     enable_testing()
     include(CTest)
     add_subdirectory(tests)
-else(ENABLE_TESTING)
-    message("Release mode")
 endif (ENABLE_TESTING)

 # Installation
@@ -89,17 +99,14 @@ configure_package_config_file(
 install(TARGETS bayesnet
     EXPORT bayesnetTargets
     ARCHIVE DESTINATION lib
-    LIBRARY DESTINATION lib
-    CONFIGURATIONS Release)
+    LIBRARY DESTINATION lib)

 install(DIRECTORY bayesnet/
     DESTINATION include/bayesnet
     FILES_MATCHING
-    CONFIGURATIONS Release
     PATTERN "*.h")
 install(FILES ${CMAKE_BINARY_DIR}/configured_files/include/bayesnet/config.h
-    DESTINATION include/bayesnet
-    CONFIGURATIONS Release)
+    DESTINATION include/bayesnet)

 install(EXPORT bayesnetTargets
     FILE bayesnetTargets.cmake
```
CONAN_README.md (new file, 86 lines)

# Using BayesNet with Conan

This document explains how to use Conan as an alternative package manager for BayesNet.

## Prerequisites

```bash
pip install conan
conan remote add Cimmeria https://conan.rmontanana.es/artifactory/api/conan/Cimmeria
conan profile new default --detect
```

## Quick Start

### As a Consumer

1. Create a `conanfile.txt` in your project:

```ini
[requires]
libtorch/2.7.0
bayesnet/1.2.0

[generators]
CMakeDeps
CMakeToolchain
```

2. Install dependencies:

```bash
conan install . --build=missing
```

3. In your CMakeLists.txt:

```cmake
find_package(bayesnet REQUIRED)
target_link_libraries(your_target bayesnet::bayesnet)
```

### Building BayesNet with Conan

```bash
# Install dependencies
make conan-init

# Build debug version
make debug
make buildd

# Build release version
make release
make buildr

# Create package
make conan-create
```

## Current Limitations

- Custom dependencies (folding, fimdlp, arff-files) are not available in ConanCenter
- These need to be built as custom Conan packages or replaced with alternatives
- The conanfile.py currently comments out these dependencies

## Creating Custom Dependency Packages

For the custom dependencies, you'll need to create Conan recipes:

1. **folding**: Cross-validation library
2. **fimdlp**: Discretization library
3. **arff-files**: ARFF file format parser

Contact the maintainer or create custom recipes for these packages.

## Package Distribution

Once custom dependencies are resolved:

```bash
# Create and test package
make conan-create

# Upload to your remote
conan upload bayesnet/1.2.0 -r myremote
```
Makefile (166 changes)

```diff
@@ -1,6 +1,6 @@
 SHELL := /bin/bash
 .DEFAULT_GOAL := help
-.PHONY: viewcoverage coverage setup help install uninstall diagrams buildr buildd test clean debug release sample updatebadge doc doc-install init clean-test
+.PHONY: viewcoverage coverage setup help install uninstall diagrams buildr buildd test clean updatebadge doc doc-install init clean-test debug release conan-create conan-upload conan-clean sample

 f_release = build_Release
 f_debug = build_Debug
@@ -17,6 +17,14 @@ mansrcdir = docs/man3
 mandestdir = /usr/local/share/man
 sed_command_link = 's/e">LCOV -/e"><a href="https:\/\/rmontanana.github.io\/bayesnet">Back to manual<\/a> LCOV -/g'
 sed_command_diagram = 's/Diagram"/Diagram" width="100%" height="100%" /g'
+# Set the number of parallel jobs to the number of available processors minus 7
+CPUS := $(shell getconf _NPROCESSORS_ONLN 2>/dev/null \
+          || nproc --all 2>/dev/null \
+          || sysctl -n hw.ncpu)
+
+# --- Your desired job count: CPUs – 7, but never less than 1 --------------
+JOBS := $(shell n=$(CPUS); [ $${n} -gt 7 ] && echo $$((n-7)) || echo 1)

 define ClearTests
 	@for t in $(test_targets); do \
@@ -31,6 +39,14 @@ define ClearTests
 	fi ;
 endef

+define setup_target
+	@echo ">>> Setup the project for $(1)..."
+	@if [ -d $(2) ]; then rm -fr $(2); fi
+	@conan install . --build=missing -of $(2) -s build_type=$(1)
+	@cmake -S . -B $(2) -DCMAKE_TOOLCHAIN_FILE=$(2)/build/$(1)/generators/conan_toolchain.cmake -DCMAKE_BUILD_TYPE=$(1) -D$(3)
+	@echo ">>> Will build using $(JOBS) parallel jobs"
+	@echo ">>> Done"
+endef

 setup: ## Install dependencies for tests and coverage
 	@if [ "$(shell uname)" = "Darwin" ]; then \
@@ -43,30 +59,36 @@ setup: ## Install dependencies for tests and coverage
 	fi
 	@echo "* You should install plantuml & graphviz for the diagrams"

-diagrams: ## Create an UML class diagram & dependency of the project (diagrams/BayesNet.png)
-	@which $(plantuml) || (echo ">>> Please install plantuml"; exit 1)
-	@which $(dot) || (echo ">>> Please install graphviz"; exit 1)
-	@which $(clang-uml) || (echo ">>> Please install clang-uml"; exit 1)
-	@export PLANTUML_LIMIT_SIZE=16384
-	@echo ">>> Creating UML class diagram of the project...";
-	@$(clang-uml) -p
-	@cd $(f_diagrams); \
-	$(plantuml) -tsvg BayesNet.puml
-	@echo ">>> Creating dependency graph diagram of the project...";
-	$(MAKE) debug
-	cd $(f_debug) && cmake .. --graphviz=dependency.dot
-	@$(dot) -Tsvg $(f_debug)/dependency.dot.BayesNet -o $(f_diagrams)/dependency.svg
+clean: ## Clean the project
+	@echo ">>> Cleaning the project..."
+	@if test -f CMakeCache.txt ; then echo "- Deleting CMakeCache.txt"; rm -f CMakeCache.txt; fi
+	@for folder in $(f_release) $(f_debug) vpcpkg_installed install_test ; do \
+		if test -d "$$folder" ; then \
+			echo "- Deleting $$folder folder" ; \
+			rm -rf "$$folder"; \
+		fi; \
+	done
+	@$(MAKE) clean-test
+	@echo ">>> Done";
+
+# Build targets
+# =============
+
+debug: ## Setup debug version using Conan
+	@$(call setup_target,"Debug","$(f_debug)","ENABLE_TESTING=ON")
+
+release: ## Setup release version using Conan
+	@$(call setup_target,"Release","$(f_release)","ENABLE_TESTING=OFF")

 buildd: ## Build the debug targets
-	cmake --build $(f_debug) -t $(app_targets) --parallel $(CMAKE_BUILD_PARALLEL_LEVEL)
+	cmake --build $(f_debug) --config Debug -t $(app_targets) --parallel $(JOBS)

 buildr: ## Build the release targets
-	cmake --build $(f_release) -t $(app_targets) --parallel $(CMAKE_BUILD_PARALLEL_LEVEL)
+	cmake --build $(f_release) --config Release -t $(app_targets) --parallel $(JOBS)

+clean-test: ## Clean the tests info
+	@echo ">>> Cleaning Debug BayesNet tests...";
+	$(call ClearTests)
+	@echo ">>> Done";

 # Install targets
 # ===============

 uninstall: ## Uninstall library
 	@echo ">>> Uninstalling BayesNet...";
@@ -79,60 +101,20 @@ install: ## Install library
 	@cmake --install $(f_release) --prefix $(prefix)
 	@echo ">>> Done";

-init: ## Initialize the project installing dependencies
-	@echo ">>> Installing dependencies"
-	@vcpkg install
-	@echo ">>> Done";
-
-clean: ## Clean the project
-	@echo ">>> Cleaning the project..."
-	@if test -f CMakeCache.txt ; then echo "- Deleting CMakeCache.txt"; rm -f CMakeCache.txt; fi
-	@for folder in $(f_release) $(f_debug) vpcpkg_installed install_test ; do \
-		if test -d "$$folder" ; then \
-			echo "- Deleting $$folder folder" ; \
-			rm -rf "$$folder"; \
-		fi; \
-	done
-	@$(MAKE) clean-test
-	@echo ">>> Done";
 # Test targets
 # ============

-debug: ## Build a debug version of the project
-	@echo ">>> Building Debug BayesNet...";
-	@if [ -d ./$(f_debug) ]; then rm -rf ./$(f_debug); fi
-	@mkdir $(f_debug);
-	@cmake -S . -B $(f_debug) -D CMAKE_BUILD_TYPE=Debug -D ENABLE_TESTING=ON -D CODE_COVERAGE=ON -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake
-	@echo ">>> Done";
-
-release: ## Build a Release version of the project
-	@echo ">>> Building Release BayesNet...";
-	@if [ -d ./$(f_release) ]; then rm -rf ./$(f_release); fi
-	@mkdir $(f_release);
-	@cmake -S . -B $(f_release) -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake
-	@echo ">>> Done";
-
-fname = "tests/data/iris.arff"
-model = "TANLd"
-sample: ## Build sample
-	@echo ">>> Building Sample...";
-	@if [ -d ./sample/build ]; then rm -rf ./sample/build; fi
-	@cd sample && cmake -B build -S . -D CMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake && \
-	cmake --build build -t bayesnet_sample
-	sample/build/bayesnet_sample $(fname) $(model)
-	@echo ">>> Done";
-
-fname = "tests/data/iris.arff"
-sample2: ## Build sample2
-	@echo ">>> Building Sample...";
-	@if [ -d ./sample/build ]; then rm -rf ./sample/build; fi
-	@cd sample && cmake -B build -S . -D CMAKE_BUILD_TYPE=Debug && cmake --build build -t bayesnet_sample_xspode
-	sample/build/bayesnet_sample_xspode $(fname)
-
-clean-test: ## Clean the tests info
-	@echo ">>> Cleaning Debug BayesNet tests...";
-	$(call ClearTests)
-	@echo ">>> Done";

 opt = ""
 test: ## Run tests (opt="-s") to verbose output the tests, (opt="-c='Test Maximum Spanning Tree'") to run only that section
 	@echo ">>> Running BayesNet tests...";
 	@$(MAKE) clean-test
-	@cmake --build $(f_debug) -t $(test_targets) --parallel $(CMAKE_BUILD_PARALLEL_LEVEL)
+	@cmake --build $(f_debug) -t $(test_targets) --parallel $(JOBS)
 	@for t in $(test_targets); do \
 		echo ">>> Running $$t...";\
 		if [ -f $(f_debug)/tests/$$t ]; then \
@@ -157,6 +139,7 @@ coverage: ## Run tests and generate coverage report (build/index.html)
 	$(lcov) --remove coverage.info 'tests/*' --output-file coverage.info >/dev/null 2>&1; \
 	$(lcov) --remove coverage.info 'bayesnet/utils/loguru.*' --ignore-errors unused --output-file coverage.info >/dev/null 2>&1; \
 	$(lcov) --remove coverage.info '/opt/miniconda/*' --ignore-errors unused --output-file coverage.info >/dev/null 2>&1; \
+	$(lcov) --remove coverage.info '*/.conan2/*' --ignore-errors unused --output-file coverage.info >/dev/null 2>&1; \
 	$(lcov) --summary coverage.info
 	@$(MAKE) updatebadge
 	@echo ">>> Done";
@@ -182,6 +165,9 @@ updatebadge: ## Update the coverage badge in README.md
 	@env python update_coverage.py $(f_debug)/tests
 	@echo ">>> Done";

+# Documentation targets
+# =====================
+
 doc: ## Generate documentation
 	@echo ">>> Generating documentation..."
 	@cmake --build $(f_release) -t doxygen
@@ -196,6 +182,22 @@ doc: ## Generate documentation
 	fi
 	@echo ">>> Done";

+diagrams: ## Create an UML class diagram & dependency of the project (diagrams/BayesNet.png)
+	@echo ">>> Creating diagrams..."
+	@which $(plantuml) || (echo ">>> Please install plantuml"; exit 1)
+	@which $(dot) || (echo ">>> Please install graphviz"; exit 1)
+	@which $(clang-uml) || (echo ">>> Please install clang-uml"; exit 1)
+	@export PLANTUML_LIMIT_SIZE=16384
+	@echo ">>> Creating UML class diagram of the project...";
+	@$(clang-uml) -p
+	@cd $(f_diagrams); \
+	$(plantuml) -tsvg BayesNet.puml
+	@echo ">>> Creating dependency graph diagram of the project...";
+	$(MAKE) debug
+	cd $(f_debug) && cmake .. --graphviz=dependency.dot
+	@$(dot) -Tsvg $(f_debug)/dependency.dot.BayesNet -o $(f_diagrams)/dependency.svg
+	@echo ">>> Done";
+
 docdir = ""
 doc-install: ## Install documentation
 	@echo ">>> Installing documentation..."
@@ -210,6 +212,38 @@ doc-install: ## Install documentation
 	@sudo cp -rp $(mansrcdir) $(mandestdir)
 	@echo ">>> Done";

+# Conan package manager targets
+# =============================
+
+conan-create: ## Create Conan package
+	@echo ">>> Creating Conan package..."
+	@conan create . --build=missing -tf "" -s:a build_type=Release
+	@conan create . --build=missing -tf "" -s:a build_type=Debug -o "&:enable_coverage=False" -o "&:enable_testing=False"
+	@echo ">>> Done"
+
+conan-clean: ## Clean Conan cache and build folders
+	@echo ">>> Cleaning Conan cache and build folders..."
+	@conan remove "*" --confirm
+	@conan cache clean
+	@if test -d "$(f_release)" ; then rm -rf "$(f_release)"; fi
+	@if test -d "$(f_debug)" ; then rm -rf "$(f_debug)"; fi
+	@echo ">>> Done"
+
+fname = "tests/data/iris.arff"
+model = "TANLd"
+build_type = "Debug"
+sample: ## Build sample with Conan
+	@echo ">>> Building Sample with Conan...";
+	@if [ -d ./sample/build ]; then rm -rf ./sample/build; fi
+	@cd sample && conan install . --output-folder=build --build=missing -s build_type=$(build_type) -o "&:enable_coverage=False" -o "&:enable_testing=False"
+	@cd sample && cmake -B build -S . -DCMAKE_BUILD_TYPE=$(build_type) -DCMAKE_TOOLCHAIN_FILE=build/conan_toolchain.cmake && \
+	cmake --build build -t bayesnet_sample --parallel $(JOBS)
+	sample/build/bayesnet_sample $(fname) $(model)
+	@echo ">>> Done";
+
 # Help target
 # ===========

 help: ## Show help message
 	@IFS=$$'\n' ; \
 	help_lines=(`fgrep -h "##" $(MAKEFILE_LIST) | fgrep -v fgrep | sed -e 's/\\$$//' | sed -e 's/##/:/'`); \
```
README.md (120 changes)

````diff
@@ -8,119 +8,119 @@
 [](https://sonarcloud.io/summary/new_code?id=rmontanana_BayesNet)
 [](https://deepwiki.com/Doctorado-ML/BayesNet)
 [](https://gitea.rmontanana.es/rmontanana/BayesNet)
 [](https://gitea.rmontanana.es/rmontanana/BayesNet)
 [](https://doi.org/10.5281/zenodo.14210344)

 Bayesian Network Classifiers library

-## Setup
+## Using the Library

-### Using the vcpkg library
+### Using Conan Package Manager

-You can use the library with the vcpkg library manager. In your project you have to add the following files:
+You can use the library with the [Conan](https://conan.io/) package manager. In your project you need to add the following files:

-#### vcpkg.json
+#### conanfile.txt

-```json
-{
-  "name": "sample-project",
-  "version-string": "0.1.0",
-  "dependencies": [
-    "bayesnet"
-  ]
-}
-```
+```txt
+[requires]
+bayesnet/1.1.2
+
-#### vcpkg-configuration.json
-
-```json
-{
-  "registries": [
-    {
-      "kind": "git",
-      "repository": "https://github.com/rmontanana/vcpkg-stash",
-      "baseline": "393efa4e74e053b6f02c4ab03738c8fe796b28e5",
-      "packages": [
-        "folding",
-        "bayesnet",
-        "arff-files",
-        "fimdlp",
-        "libtorch-bin"
-      ]
-    }
-  ],
-  "default-registry": {
-    "kind": "git",
-    "repository": "https://github.com/microsoft/vcpkg",
-    "baseline": "760bfd0c8d7c89ec640aec4df89418b7c2745605"
-  }
-}
+[generators]
+CMakeDeps
+CMakeToolchain
 ```

 #### CMakeLists.txt

-You have to include the following lines in your `CMakeLists.txt` file:
+Include the following lines in your `CMakeLists.txt` file:

 ```cmake
-find_package(bayesnet CONFIG REQUIRED)
+find_package(bayesnet REQUIRED)

 add_executable(myapp main.cpp)

 target_link_libraries(myapp PRIVATE bayesnet::bayesnet)
 ```

-After that, you can use the `vcpkg` command to install the dependencies:
+Then install the dependencies and build your project:

 ```bash
-vcpkg install
+conan install . --output-folder=build --build=missing
+cmake -B build -S . -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=build/conan_toolchain.cmake
+cmake --build build
 ```

 **Note: In the `sample` folder you can find a sample application that uses the library. You can use it as a reference to create your own application.**

-## Playing with the library
+## Building and Testing

-The dependencies are managed with [vcpkg](https://vcpkg.io/) and supported by a private vcpkg repository in [https://github.com/rmontanana/vcpkg-stash](https://github.com/rmontanana/vcpkg-stash).
+The project uses [Conan](https://conan.io/) for dependency management and provides convenient Makefile targets for common tasks.

+### Prerequisites
+
+- [Conan](https://conan.io/) package manager (`pip install conan`)
+- CMake 3.27+
+- C++17 compatible compiler
+
 ### Getting the code

 ```bash
 git clone https://github.com/doctorado-ml/bayesnet
 cd bayesnet
 ```

-Once you have the code, you can use the `make` command to build the project. The `Makefile` is used to manage the build process and it will automatically download and install the dependencies.
+### Build Commands

-### Release
+#### Release Build

 ```bash
-make init    # Install dependencies
-make release # Build the release version
-make buildr  # compile and link the release version
+make release # Configure release build with Conan
+make buildr  # Build the release version
 ```

-### Debug & Tests
+#### Debug Build & Tests

 ```bash
-make init  # Install dependencies
-make debug # Build the debug version
-make test  # Run the tests
+make debug  # Configure debug build with Conan
+make buildd # Build the debug version
+make test   # Run the tests
 ```

-### Coverage
+#### Coverage Analysis

 ```bash
-make coverage     # Run the tests with coverage
-make viewcoverage # View the coverage report in the browser
+make coverage     # Run tests with coverage analysis
+make viewcoverage # View coverage report in browser
 ```

-### Sample app
+#### Sample Application

-After building and installing the release version, you can run the sample app with the following commands:
+Run the sample application with different datasets and models:

 ```bash
-make sample
-make sample fname=tests/data/glass.arff
+make sample                                        # Run with default settings
+make sample fname=tests/data/glass.arff            # Use glass dataset
+make sample fname=tests/data/iris.arff model=AODE  # Use specific model
 ```

+### Available Makefile Targets
+
+- `debug` - Configure debug build using Conan
+- `release` - Configure release build using Conan
+- `buildd` - Build debug targets
+- `buildr` - Build release targets
+- `test` - Run all tests (use `opt="-s"` for verbose output)
+- `coverage` - Generate test coverage report
+- `viewcoverage` - Open coverage report in browser
+- `sample` - Build and run sample application
+- `conan-create` - Create Conan package
+- `conan-upload` - Upload package to Conan remote
+- `conan-clean` - Clean Conan cache and build folders
+- `clean` - Clean all build artifacts
+- `doc` - Generate documentation
+- `diagrams` - Generate UML diagrams
+- `help` - Show all available targets
+
 ## Models

 #### - TAN
````
Deleted CMake file (13 lines):

```cmake
include_directories(
    ${BayesNet_SOURCE_DIR}/lib/log
    ${BayesNet_SOURCE_DIR}/lib/mdlp/src
    ${BayesNet_SOURCE_DIR}/lib/folding
    ${BayesNet_SOURCE_DIR}/lib/json/include
    ${BayesNet_SOURCE_DIR}
    ${CMAKE_BINARY_DIR}/configured_files/include
)

file(GLOB_RECURSE Sources "*.cc")

add_library(BayesNet ${Sources})
target_link_libraries(BayesNet fimdlp "${TORCH_LIBRARIES}")
```
Classifier header (1 change, file name inferred from content):

```diff
@@ -37,6 +37,7 @@ namespace bayesnet {
         std::vector<std::string> getNotes() const override { return notes; }
         std::string dump_cpt() const override;
         void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
+        Network& getModel() { return model; }
     protected:
         bool fitted;
         unsigned int m, n; // m: number of samples, n: number of features
```
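The added `getModel()` accessor is what allows `Proposal::iterativeLocalDiscretization` (in the `Proposal.cc` diff below) to fetch the classifier's network on each iteration and compare it with the previous iteration's copy to detect convergence.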
KDBLd.cc

```diff
@@ -5,40 +5,38 @@
 // ***************************************************************

 #include "KDBLd.h"
 #include <memory>

 namespace bayesnet {
-    KDBLd::KDBLd(int k) : KDB(k), Proposal(dataset, features, className)
+    KDBLd::KDBLd(int k) : KDB(k), Proposal(dataset, features, className, KDB::notes)
     {
         validHyperparameters = validHyperparameters_ld;
         validHyperparameters.push_back("k");
         validHyperparameters.push_back("theta");
     }
-    void KDBLd::setHyperparameters(const nlohmann::json& hyperparameters_)
-    {
-        auto hyperparameters = hyperparameters_;
-        if (hyperparameters.contains("k")) {
-            k = hyperparameters["k"];
-            hyperparameters.erase("k");
-        }
-        if (hyperparameters.contains("theta")) {
-            theta = hyperparameters["theta"];
-            hyperparameters.erase("theta");
-        }
-        Proposal::setHyperparameters(hyperparameters);
-    }
     KDBLd& KDBLd::fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_, const Smoothing_t smoothing)
     {
         checkInput(X_, y_);
-        features = features_;
-        className = className_;
         Xf = X_;
         y = y_;
-        // Fills std::vectors Xv & yv with the data from tensors X_ (discretized) & y
-        states = fit_local_discretization(y);
-        // We have discretized the input data
-        // 1st we need to fit the model to build the normal KDB structure, KDB::fit initializes the base Bayesian network
+        return commonFit(features_, className_, states_, smoothing);
     }
+    KDBLd& KDBLd::fit(torch::Tensor& dataset, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_, const Smoothing_t smoothing)
+    {
+        if (!torch::is_floating_point(dataset)) {
+            throw std::runtime_error("Dataset must be a floating point tensor");
+        }
+        Xf = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." }).clone();
+        y = dataset.index({ -1, "..." }).clone().to(torch::kInt32);
+        return commonFit(features_, className_, states_, smoothing);
+    }
+
+    KDBLd& KDBLd::commonFit(const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_, const Smoothing_t smoothing)
+    {
+        features = features_;
+        className = className_;
+        states = iterativeLocalDiscretization(y, static_cast<KDB*>(this), dataset, features, className, states_, smoothing);
+        KDB::fit(dataset, features, className, states, smoothing);
+        states = localDiscretizationProposal(states, model);
+        return *this;
+    }
     torch::Tensor KDBLd::predict(torch::Tensor& X)
```
KDBLd.h

```diff
@@ -15,8 +15,15 @@
         explicit KDBLd(int k);
         virtual ~KDBLd() = default;
         KDBLd& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, map<std::string, std::vector<int>>& states, const Smoothing_t smoothing) override;
+        KDBLd& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, map<std::string, std::vector<int>>& states, const Smoothing_t smoothing) override;
+        KDBLd& commonFit(const std::vector<std::string>& features, const std::string& className, map<std::string, std::vector<int>>& states, const Smoothing_t smoothing);
         std::vector<std::string> graph(const std::string& name = "KDB") const override;
-        void setHyperparameters(const nlohmann::json& hyperparameters_) override;
+        void setHyperparameters(const nlohmann::json& hyperparameters_) override
+        {
+            auto hyperparameters = hyperparameters_;
+            Proposal::setHyperparameters(hyperparameters);
+            KDB::setHyperparameters(hyperparameters);
+        }
         torch::Tensor predict(torch::Tensor& X) override;
         torch::Tensor predict_proba(torch::Tensor& X) override;
         static inline std::string version() { return "0.0.1"; };
```
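This inline `setHyperparameters` pattern, repeated in `SPODELd.h` and `TANLd.h` below, splits one JSON object between the two base classes: `Proposal::setHyperparameters` now takes a mutable reference and erases the discretization and convergence keys it consumes, so whatever remains (for `KDBLd`, the `k` and `theta` keys registered in its constructor) is forwarded untouched to the underlying classifier's own `setHyperparameters`.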
Proposal.cc

```diff
@@ -5,14 +5,22 @@
 // ***************************************************************

 #include "Proposal.h"
 #include <iostream>
 #include <cmath>
 #include <limits>
 #include "Classifier.h"
+#include "KDB.h"
+#include "TAN.h"
+#include "SPODE.h"
 #include "KDBLd.h"
 #include "TANLd.h"

 namespace bayesnet {
-    Proposal::Proposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_) : pDataset(dataset_), pFeatures(features_), pClassName(className_)
+    Proposal::Proposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_, std::vector<std::string>& notes_) : pDataset(dataset_), pFeatures(features_), pClassName(className_), notes(notes_)
     {
     }
-    void Proposal::setHyperparameters(const nlohmann::json& hyperparameters_)
+    void Proposal::setHyperparameters(nlohmann::json& hyperparameters)
     {
-        auto hyperparameters = hyperparameters_;
         if (hyperparameters.contains("ld_proposed_cuts")) {
             ld_params.proposed_cuts = hyperparameters["ld_proposed_cuts"];
             hyperparameters.erase("ld_proposed_cuts");
@@ -38,8 +46,14 @@ namespace bayesnet {
                 throw std::invalid_argument("Invalid discretization algorithm: " + algorithm.get<std::string>());
             }
         }
-        if (!hyperparameters.empty()) {
-            throw std::invalid_argument("Invalid hyperparameters for Proposal: " + hyperparameters.dump());
+        // Convergence parameters
+        if (hyperparameters.contains("max_iterations")) {
+            convergence_params.maxIterations = hyperparameters["max_iterations"];
+            hyperparameters.erase("max_iterations");
+        }
+        if (hyperparameters.contains("verbose_convergence")) {
+            convergence_params.verbose = hyperparameters["verbose_convergence"];
+            hyperparameters.erase("verbose_convergence");
         }
     }
@@ -163,4 +177,65 @@ namespace bayesnet {
         }
         return yy;
     }
+
+    template<typename Classifier>
+    map<std::string, std::vector<int>> Proposal::iterativeLocalDiscretization(
+        const torch::Tensor& y,
+        Classifier* classifier,
+        torch::Tensor& dataset,
+        const std::vector<std::string>& features,
+        const std::string& className,
+        const map<std::string, std::vector<int>>& initialStates,
+        Smoothing_t smoothing
+    )
+    {
+        // Phase 1: Initial discretization (same as original)
+        auto currentStates = fit_local_discretization(y);
+        auto previousModel = Network();
+
+        if (convergence_params.verbose) {
+            std::cout << "Starting iterative local discretization with "
+                      << convergence_params.maxIterations << " max iterations" << std::endl;
+        }
+
+        const torch::Tensor weights = torch::full({ pDataset.size(1) }, 1.0 / pDataset.size(1), torch::kDouble);
+        for (int iteration = 0; iteration < convergence_params.maxIterations; ++iteration) {
+            if (convergence_params.verbose) {
+                std::cout << "Iteration " << (iteration + 1) << "/" << convergence_params.maxIterations << std::endl;
+            }
+
+            // Phase 2: Build model with current discretization
+            classifier->fit(dataset, features, className, currentStates, weights, smoothing);
+
+            // Phase 3: Network-aware discretization refinement
+            currentStates = localDiscretizationProposal(currentStates, classifier->getModel());
+
+            // Check convergence
+            if (iteration > 0 && previousModel == classifier->getModel()) {
+                if (convergence_params.verbose) {
+                    std::cout << "Converged after " << (iteration + 1) << " iterations" << std::endl;
+                }
+                notes.push_back("Converged after " + std::to_string(iteration + 1) + " of "
+                                + std::to_string(convergence_params.maxIterations) + " iterations");
+                break;
+            }
+
+            // Update for next iteration
+            previousModel = classifier->getModel();
+        }
+
+        return currentStates;
+    }
+
+    // Explicit template instantiation for common classifier types
+    template map<std::string, std::vector<int>> Proposal::iterativeLocalDiscretization<KDB>(
+        const torch::Tensor&, KDB*, torch::Tensor&, const std::vector<std::string>&,
+        const std::string&, const map<std::string, std::vector<int>>&, Smoothing_t);
+
+    template map<std::string, std::vector<int>> Proposal::iterativeLocalDiscretization<TAN>(
+        const torch::Tensor&, TAN*, torch::Tensor&, const std::vector<std::string>&,
+        const std::string&, const map<std::string, std::vector<int>>&, Smoothing_t);
+
+    template map<std::string, std::vector<int>> Proposal::iterativeLocalDiscretization<SPODE>(
+        const torch::Tensor&, SPODE*, torch::Tensor&, const std::vector<std::string>&,
+        const std::string&, const map<std::string, std::vector<int>>&, Smoothing_t);
 }
```
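Defining `iterativeLocalDiscretization` in the `.cc` file and instantiating it explicitly for `KDB`, `TAN`, and `SPODE` keeps the template body out of the header; the trade-off is that any additional classifier type that wants the iterative loop needs its own instantiation added here.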
Proposal.h

```diff
@@ -18,25 +18,50 @@
 namespace bayesnet {
     class Proposal {
     public:
-        Proposal(torch::Tensor& pDataset, std::vector<std::string>& features_, std::string& className_);
-        void setHyperparameters(const nlohmann::json& hyperparameters_);
+        Proposal(torch::Tensor& pDataset, std::vector<std::string>& features_, std::string& className_, std::vector<std::string>& notes);
+        void setHyperparameters(nlohmann::json& hyperparameters_);
     protected:
         void checkInput(const torch::Tensor& X, const torch::Tensor& y);
         torch::Tensor prepareX(torch::Tensor& X);
         map<std::string, std::vector<int>> localDiscretizationProposal(const map<std::string, std::vector<int>>& states, Network& model);
         map<std::string, std::vector<int>> fit_local_discretization(const torch::Tensor& y);

+        // Iterative discretization method
+        template<typename Classifier>
+        map<std::string, std::vector<int>> iterativeLocalDiscretization(
+            const torch::Tensor& y,
+            Classifier* classifier,
+            torch::Tensor& dataset,
+            const std::vector<std::string>& features,
+            const std::string& className,
+            const map<std::string, std::vector<int>>& initialStates,
+            const Smoothing_t smoothing
+        );
+
         torch::Tensor Xf; // X continuous nxm tensor
         torch::Tensor y; // y discrete nx1 tensor
         map<std::string, std::unique_ptr<mdlp::Discretizer>> discretizers;

         // MDLP parameters
         struct {
             size_t min_length = 3; // Minimum length of the interval to consider it in mdlp
             float proposed_cuts = 0.0; // Proposed cuts for the Discretization algorithm
             int max_depth = std::numeric_limits<int>::max(); // Maximum depth of the MDLP tree
         } ld_params;
-        nlohmann::json validHyperparameters_ld = { "ld_algorithm", "ld_proposed_cuts", "mdlp_min_length", "mdlp_max_depth" };
+
+        // Convergence parameters
+        struct {
+            int maxIterations = 10;
+            bool verbose = false;
+        } convergence_params;
+
+        nlohmann::json validHyperparameters_ld = {
+            "ld_algorithm", "ld_proposed_cuts", "mdlp_min_length", "mdlp_max_depth",
+            "max_iterations", "verbose_convergence"
+        };
     private:
         std::vector<int> factorize(const std::vector<std::string>& labels_t);
+        std::vector<std::string>& notes; // Notes during fit from BaseClassifier
         torch::Tensor& pDataset; // (n+1)xm tensor
         std::vector<std::string>& pFeatures;
         std::string& pClassName;
```
SPODELd.cc

```diff
@@ -7,7 +7,7 @@
 #include "SPODELd.h"

 namespace bayesnet {
-    SPODELd::SPODELd(int root) : SPODE(root), Proposal(dataset, features, className)
+    SPODELd::SPODELd(int root) : SPODE(root), Proposal(dataset, features, className, SPODE::notes)
     {
         validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal
     }
@@ -34,12 +34,8 @@ namespace bayesnet {
     {
         features = features_;
         className = className_;
-        // Fills std::vectors Xv & yv with the data from tensors X_ (discretized) & y
-        states = fit_local_discretization(y);
-        // We have discretized the input data
-        // 1st we need to fit the model to build the normal SPODE structure, SPODE::fit initializes the base Bayesian network
+        states = iterativeLocalDiscretization(y, static_cast<SPODE*>(this), dataset, features, className, states_, smoothing);
         SPODE::fit(dataset, features, className, states, smoothing);
-        states = localDiscretizationProposal(states, model);
         return *this;
     }
     torch::Tensor SPODELd::predict(torch::Tensor& X)
```
SPODELd.h

```diff
@@ -18,6 +18,12 @@ namespace bayesnet {
         SPODELd& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, map<std::string, std::vector<int>>& states, const Smoothing_t smoothing) override;
         SPODELd& commonFit(const std::vector<std::string>& features, const std::string& className, map<std::string, std::vector<int>>& states, const Smoothing_t smoothing);
         std::vector<std::string> graph(const std::string& name = "SPODELd") const override;
+        void setHyperparameters(const nlohmann::json& hyperparameters_) override
+        {
+            auto hyperparameters = hyperparameters_;
+            Proposal::setHyperparameters(hyperparameters);
+            SPODE::setHyperparameters(hyperparameters);
+        }
         torch::Tensor predict(torch::Tensor& X) override;
         torch::Tensor predict_proba(torch::Tensor& X) override;
         static inline std::string version() { return "0.0.1"; };
```
TANLd.cc

```diff
@@ -5,24 +5,37 @@
 // ***************************************************************

 #include "TANLd.h"
 #include <memory>

 namespace bayesnet {
-    TANLd::TANLd() : TAN(), Proposal(dataset, features, className) {}
+    TANLd::TANLd() : TAN(), Proposal(dataset, features, className, TAN::notes)
+    {
+        validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal
+    }
     TANLd& TANLd::fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_, const Smoothing_t smoothing)
     {
         checkInput(X_, y_);
-        features = features_;
-        className = className_;
         Xf = X_;
         y = y_;
-        // Fills std::vectors Xv & yv with the data from tensors X_ (discretized) & y
-        states = fit_local_discretization(y);
-        // We have discretized the input data
-        // 1st we need to fit the model to build the normal TAN structure, TAN::fit initializes the base Bayesian network
-        TAN::fit(dataset, features, className, states, smoothing);
-        states = localDiscretizationProposal(states, model);
-        return *this;
+        return commonFit(features_, className_, states_, smoothing);
     }
+    TANLd& TANLd::fit(torch::Tensor& dataset, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_, const Smoothing_t smoothing)
+    {
+        if (!torch::is_floating_point(dataset)) {
+            throw std::runtime_error("Dataset must be a floating point tensor");
+        }
+        Xf = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." }).clone();
+        y = dataset.index({ -1, "..." }).clone().to(torch::kInt32);
+        return commonFit(features_, className_, states_, smoothing);
+    }
+
+    TANLd& TANLd::commonFit(const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_, const Smoothing_t smoothing)
+    {
+        features = features_;
+        className = className_;
+        states = iterativeLocalDiscretization(y, static_cast<TAN*>(this), dataset, features, className, states_, smoothing);
+        TAN::fit(dataset, features, className, states, smoothing);
+        return *this;
+    }
     torch::Tensor TANLd::predict(torch::Tensor& X)
     {
```
TANLd.h

```diff
@@ -16,7 +16,15 @@ namespace bayesnet {
         TANLd();
         virtual ~TANLd() = default;
         TANLd& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, map<std::string, std::vector<int>>& states, const Smoothing_t smoothing) override;
+        TANLd& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, map<std::string, std::vector<int>>& states, const Smoothing_t smoothing) override;
+        TANLd& commonFit(const std::vector<std::string>& features, const std::string& className, map<std::string, std::vector<int>>& states, const Smoothing_t smoothing);
         std::vector<std::string> graph(const std::string& name = "TANLd") const override;
+        void setHyperparameters(const nlohmann::json& hyperparameters_) override
+        {
+            auto hyperparameters = hyperparameters_;
+            Proposal::setHyperparameters(hyperparameters);
+            TAN::setHyperparameters(hyperparameters);
+        }
         torch::Tensor predict(torch::Tensor& X) override;
         torch::Tensor predict_proba(torch::Tensor& X) override;
     };
```
AODELd.cc

```diff
@@ -7,7 +7,7 @@
 #include "AODELd.h"

 namespace bayesnet {
-    AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className)
+    AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className, Ensemble::notes)
     {
         validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal
     }
```
AODELd.h

```diff
@@ -17,6 +17,10 @@ namespace bayesnet {
         virtual ~AODELd() = default;
         AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_, const Smoothing_t smoothing) override;
         std::vector<std::string> graph(const std::string& name = "AODELd") const override;
+        void setHyperparameters(const nlohmann::json& hyperparameters_) override
+        {
+            hyperparameters = hyperparameters_;
+        }
     protected:
         void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override;
         void buildModel(const torch::Tensor& weights) override;
```
Network.cc

```diff
@@ -17,14 +17,90 @@ namespace bayesnet {
     Network::Network() : fitted{ false }, classNumStates{ 0 }
     {
     }
-    Network::Network(const Network& other) : features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
-        fitted(other.fitted), samples(other.samples)
+    Network::Network(const Network& other)
+        : features(other.features), className(other.className), classNumStates(other.classNumStates),
+        fitted(other.fitted)
     {
-        if (samples.defined())
-            samples = samples.clone();
+        // Deep copy the samples tensor
+        if (other.samples.defined()) {
+            samples = other.samples.clone();
+        }
+
+        // First, create all nodes (without relationships)
+        for (const auto& node : other.nodes) {
+            nodes[node.first] = std::make_unique<Node>(*node.second);
+        }
+
+        // Second, reconstruct the relationships between nodes
+        for (const auto& node : other.nodes) {
+            const std::string& nodeName = node.first;
+            Node* originalNode = node.second.get();
+            Node* newNode = nodes[nodeName].get();
+
+            // Reconstruct parent relationships
+            for (Node* parent : originalNode->getParents()) {
+                const std::string& parentName = parent->getName();
+                if (nodes.find(parentName) != nodes.end()) {
+                    newNode->addParent(nodes[parentName].get());
+                }
+            }
+
+            // Reconstruct child relationships
+            for (Node* child : originalNode->getChildren()) {
+                const std::string& childName = child->getName();
+                if (nodes.find(childName) != nodes.end()) {
+                    newNode->addChild(nodes[childName].get());
+                }
+            }
+        }
+    }
+
+    Network& Network::operator=(const Network& other)
+    {
+        if (this != &other) {
+            // Clear existing state
+            nodes.clear();
+            features = other.features;
+            className = other.className;
+            classNumStates = other.classNumStates;
+            fitted = other.fitted;
+
+            // Deep copy the samples tensor
+            if (other.samples.defined()) {
+                samples = other.samples.clone();
+            } else {
+                samples = torch::Tensor();
+            }
+
+            // First, create all nodes (without relationships)
+            for (const auto& node : other.nodes) {
+                nodes[node.first] = std::make_unique<Node>(*node.second);
+            }
+
+            // Second, reconstruct the relationships between nodes
+            for (const auto& node : other.nodes) {
+                const std::string& nodeName = node.first;
+                Node* originalNode = node.second.get();
+                Node* newNode = nodes[nodeName].get();
+
+                // Reconstruct parent relationships
+                for (Node* parent : originalNode->getParents()) {
+                    const std::string& parentName = parent->getName();
+                    if (nodes.find(parentName) != nodes.end()) {
+                        newNode->addParent(nodes[parentName].get());
+                    }
+                }
+
+                // Reconstruct child relationships
+                for (Node* child : originalNode->getChildren()) {
+                    const std::string& childName = child->getName();
+                    if (nodes.find(childName) != nodes.end()) {
+                        newNode->addChild(nodes[childName].get());
+                    }
+                }
+            }
+        }
+        return *this;
+    }
     void Network::initialize()
     {
@@ -503,4 +579,41 @@ namespace bayesnet {
         }
         return oss.str();
     }
+
+    bool Network::operator==(const Network& other) const
+    {
+        // Compare number of nodes
+        if (nodes.size() != other.nodes.size()) {
+            return false;
+        }
+
+        // Compare if all node names exist in both networks
+        for (const auto& node : nodes) {
+            if (other.nodes.find(node.first) == other.nodes.end()) {
+                return false;
+            }
+        }
+
+        // Compare edges (topology)
+        auto thisEdges = getEdges();
+        auto otherEdges = other.getEdges();
+
+        // Compare number of edges
+        if (thisEdges.size() != otherEdges.size()) {
+            return false;
+        }
+
+        // Sort both edge lists for comparison
+        std::sort(thisEdges.begin(), thisEdges.end());
+        std::sort(otherEdges.begin(), otherEdges.end());
+
+        // Compare each edge
+        for (size_t i = 0; i < thisEdges.size(); ++i) {
+            if (thisEdges[i] != otherEdges[i]) {
+                return false;
+            }
+        }
+
+        return true;
+    }
 }
```
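A minimal sketch of what the new copy semantics guarantee; the header path and the demo function are assumptions, while `addNode`, the copy operations, and `operator==` all appear in the diffs here:

```cpp
#include <cassert>
#include <bayesnet/network/Network.h>  // header path assumed

void copy_semantics_demo() {
    bayesnet::Network original;
    original.addNode("A");
    original.addNode("B");

    bayesnet::Network copy = original;  // deep copy: samples tensor cloned, nodes
                                        // recreated, parent/child links rebuilt
    assert(copy == original);           // operator== compares node names and the
                                        // sorted edge lists (topology only)

    bayesnet::Network assigned;
    assigned = original;                // operator= clears state, then deep-copies
    assert(assigned == original);
}
```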
Network.h

```diff
@@ -17,7 +17,8 @@ namespace bayesnet {
     class Network {
     public:
         Network();
-        explicit Network(const Network&);
+        Network(const Network& other);
+        Network& operator=(const Network& other);
         ~Network() = default;
         torch::Tensor& getSamples();
         void addNode(const std::string&);
@@ -47,6 +48,7 @@ namespace bayesnet {
         void initialize();
         std::string dump_cpt() const;
         inline std::string version() { return { project_version.begin(), project_version.end() }; }
+        bool operator==(const Network& other) const;
     private:
         std::map<std::string, std::unique_ptr<Node>> nodes;
         bool fitted;
```
Node.cc

```diff
@@ -13,6 +13,41 @@ namespace bayesnet {
         : name(name)
     {
     }
+
+    Node::Node(const Node& other)
+        : name(other.name), numStates(other.numStates), dimensions(other.dimensions)
+    {
+        // Deep copy the CPT tensor
+        if (other.cpTable.defined()) {
+            cpTable = other.cpTable.clone();
+        }
+        // Note: parent and children pointers are NOT copied here
+        // They will be reconstructed by the Network copy constructor
+        // to maintain proper object relationships
+    }
+
+    Node& Node::operator=(const Node& other)
+    {
+        if (this != &other) {
+            name = other.name;
+            numStates = other.numStates;
+            dimensions = other.dimensions;
+
+            // Deep copy the CPT tensor
+            if (other.cpTable.defined()) {
+                cpTable = other.cpTable.clone();
+            } else {
+                cpTable = torch::Tensor();
+            }
+
+            // Clear existing relationships
+            parents.clear();
+            children.clear();
+            // Note: parent and children pointers are NOT copied here
+            // They must be reconstructed to maintain proper object relationships
+        }
+        return *this;
+    }
     void Node::clear()
     {
         parents.clear();
```
@@ -14,6 +14,9 @@ namespace bayesnet {
    class Node {
    public:
        explicit Node(const std::string&);
        Node(const Node& other);
        Node& operator=(const Node& other);
        ~Node() = default;
        void clear();
        void addParent(Node*);
        void addChild(Node*);
@@ -1,12 +0,0 @@

function(add_git_submodule dir)
  find_package(Git REQUIRED)

  if(NOT EXISTS ${dir}/CMakeLists.txt)
    message(STATUS "🚨 Adding git submodule => ${dir}")
    execute_process(COMMAND ${GIT_EXECUTABLE}
      submodule update --init --recursive -- ${dir}
      WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
  endif()
  add_subdirectory(${dir})
endfunction(add_git_submodule)
@@ -1,746 +0,0 @@
# Copyright (c) 2012 - 2017, Lars Bilke
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
#    may be used to endorse or promote products derived from this software without
#    specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# CHANGES:
#
# 2012-01-31, Lars Bilke
# - Enable Code Coverage
#
# 2013-09-17, Joakim Söderberg
# - Added support for Clang.
# - Some additional usage instructions.
#
# 2016-02-03, Lars Bilke
# - Refactored functions to use named parameters
#
# 2017-06-02, Lars Bilke
# - Merged with modified version from github.com/ufz/ogs
#
# 2019-05-06, Anatolii Kurotych
# - Remove unnecessary --coverage flag
#
# 2019-12-13, FeRD (Frank Dana)
# - Deprecate COVERAGE_LCOVR_EXCLUDES and COVERAGE_GCOVR_EXCLUDES lists in favor
#   of tool-agnostic COVERAGE_EXCLUDES variable, or EXCLUDE setup arguments.
# - CMake 3.4+: All excludes can be specified relative to BASE_DIRECTORY
# - All setup functions: accept BASE_DIRECTORY, EXCLUDE list
# - Set lcov basedir with -b argument
# - Add automatic --demangle-cpp in lcovr, if 'c++filt' is available (can be
#   overridden with NO_DEMANGLE option in setup_target_for_coverage_lcovr().)
# - Delete output dir, .info file on 'make clean'
# - Remove Python detection, since version mismatches will break gcovr
# - Minor cleanup (lowercase function names, update examples...)
#
# 2019-12-19, FeRD (Frank Dana)
# - Rename Lcov outputs, make filtered file canonical, fix cleanup for targets
#
# 2020-01-19, Bob Apthorpe
# - Added gfortran support
#
# 2020-02-17, FeRD (Frank Dana)
# - Make all add_custom_target()s VERBATIM to auto-escape wildcard characters
#   in EXCLUDEs, and remove manual escaping from gcovr targets
#
# 2021-01-19, Robin Mueller
# - Add CODE_COVERAGE_VERBOSE option which will allow to print out commands which are run
# - Added the option for users to set the GCOVR_ADDITIONAL_ARGS variable to supply additional
#   flags to the gcovr command
#
# 2020-05-04, Mihchael Davis
# - Add -fprofile-abs-path to make gcno files contain absolute paths
# - Fix BASE_DIRECTORY not working when defined
# - Change BYPRODUCT from folder to index.html to stop ninja from complaining about double defines
#
# 2021-05-10, Martin Stump
# - Check if the generator is multi-config before warning about non-Debug builds
#
# 2022-02-22, Marko Wehle
# - Change gcovr output from -o <filename> for --xml <filename> and --html <filename> output respectively.
#   This will allow for Multiple Output Formats at the same time by making use of GCOVR_ADDITIONAL_ARGS, e.g. GCOVR_ADDITIONAL_ARGS "--txt".
#
# 2022-09-28, Sebastian Mueller
# - fix append_coverage_compiler_flags_to_target to correctly add flags
# - replace "-fprofile-arcs -ftest-coverage" with "--coverage" (equivalent)
#
# USAGE:
#
# 1. Copy this file into your cmake modules path.
#
# 2. Add the following line to your CMakeLists.txt (best inside an if-condition
#    using a CMake option() to enable it just optionally):
#      include(CodeCoverage)
#
# 3. Append necessary compiler flags for all supported source files:
#      append_coverage_compiler_flags()
#    Or for specific target:
#      append_coverage_compiler_flags_to_target(YOUR_TARGET_NAME)
#
# 3.a (OPTIONAL) Set appropriate optimization flags, e.g. -O0, -O1 or -Og
#
# 4. If you need to exclude additional directories from the report, specify them
#    using full paths in the COVERAGE_EXCLUDES variable before calling
#    setup_target_for_coverage_*().
#    Example:
#      set(COVERAGE_EXCLUDES
#          '${PROJECT_SOURCE_DIR}/src/dir1/*'
#          '/path/to/my/src/dir2/*')
#    Or, use the EXCLUDE argument to setup_target_for_coverage_*().
#    Example:
#      setup_target_for_coverage_lcov(
#          NAME coverage
#          EXECUTABLE testrunner
#          EXCLUDE "${PROJECT_SOURCE_DIR}/src/dir1/*" "/path/to/my/src/dir2/*")
#
# 4.a NOTE: With CMake 3.4+, COVERAGE_EXCLUDES or EXCLUDE can also be set
#     relative to the BASE_DIRECTORY (default: PROJECT_SOURCE_DIR)
#     Example:
#       set(COVERAGE_EXCLUDES "dir1/*")
#       setup_target_for_coverage_gcovr_html(
#           NAME coverage
#           EXECUTABLE testrunner
#           BASE_DIRECTORY "${PROJECT_SOURCE_DIR}/src"
#           EXCLUDE "dir2/*")
#
# 5. Use the functions described below to create a custom make target which
#    runs your test executable and produces a code coverage report.
#
# 6. Build a Debug build:
#      cmake -DCMAKE_BUILD_TYPE=Debug ..
#      make
#      make my_coverage_target
#

include(CMakeParseArguments)

option(CODE_COVERAGE_VERBOSE "Verbose information" TRUE)

# Check prereqs
find_program( GCOV_PATH gcov )
find_program( LCOV_PATH NAMES lcov lcov.bat lcov.exe lcov.perl)
find_program( FASTCOV_PATH NAMES fastcov fastcov.py )
find_program( GENHTML_PATH NAMES genhtml genhtml.perl genhtml.bat )
find_program( GCOVR_PATH gcovr PATHS ${CMAKE_SOURCE_DIR}/scripts/test)
find_program( CPPFILT_PATH NAMES c++filt )

if(NOT GCOV_PATH)
    message(FATAL_ERROR "gcov not found! Aborting...")
endif() # NOT GCOV_PATH

# Check supported compiler (Clang, GNU and Flang)
get_property(LANGUAGES GLOBAL PROPERTY ENABLED_LANGUAGES)
foreach(LANG ${LANGUAGES})
    if("${CMAKE_${LANG}_COMPILER_ID}" MATCHES "(Apple)?[Cc]lang")
        if("${CMAKE_${LANG}_COMPILER_VERSION}" VERSION_LESS 3)
            message(FATAL_ERROR "Clang version must be 3.0.0 or greater! Aborting...")
        endif()
    elseif(NOT "${CMAKE_${LANG}_COMPILER_ID}" MATCHES "GNU"
           AND NOT "${CMAKE_${LANG}_COMPILER_ID}" MATCHES "(LLVM)?[Ff]lang")
        if ("${LANG}" MATCHES "CUDA")
            message(STATUS "Ignoring CUDA")
        else()
            message(FATAL_ERROR "Compiler is not GNU or Flang! Aborting...")
        endif()
    endif()
endforeach()

set(COVERAGE_COMPILER_FLAGS "-g --coverage"
    CACHE INTERNAL "")
if(CMAKE_CXX_COMPILER_ID MATCHES "(GNU|Clang)")
    include(CheckCXXCompilerFlag)
    check_cxx_compiler_flag(-fprofile-abs-path HAVE_fprofile_abs_path)
    if(HAVE_fprofile_abs_path)
        set(COVERAGE_COMPILER_FLAGS "${COVERAGE_COMPILER_FLAGS} -fprofile-abs-path")
    endif()
endif()

set(CMAKE_Fortran_FLAGS_COVERAGE
    ${COVERAGE_COMPILER_FLAGS}
    CACHE STRING "Flags used by the Fortran compiler during coverage builds."
    FORCE )
set(CMAKE_CXX_FLAGS_COVERAGE
    ${COVERAGE_COMPILER_FLAGS}
    CACHE STRING "Flags used by the C++ compiler during coverage builds."
    FORCE )
set(CMAKE_C_FLAGS_COVERAGE
    ${COVERAGE_COMPILER_FLAGS}
    CACHE STRING "Flags used by the C compiler during coverage builds."
    FORCE )
set(CMAKE_EXE_LINKER_FLAGS_COVERAGE
    ""
    CACHE STRING "Flags used for linking binaries during coverage builds."
    FORCE )
set(CMAKE_SHARED_LINKER_FLAGS_COVERAGE
    ""
    CACHE STRING "Flags used by the shared libraries linker during coverage builds."
    FORCE )
mark_as_advanced(
    CMAKE_Fortran_FLAGS_COVERAGE
    CMAKE_CXX_FLAGS_COVERAGE
    CMAKE_C_FLAGS_COVERAGE
    CMAKE_EXE_LINKER_FLAGS_COVERAGE
    CMAKE_SHARED_LINKER_FLAGS_COVERAGE )

get_property(GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
if(NOT (CMAKE_BUILD_TYPE STREQUAL "Debug" OR GENERATOR_IS_MULTI_CONFIG))
    message(WARNING "Code coverage results with an optimised (non-Debug) build may be misleading")
endif() # NOT (CMAKE_BUILD_TYPE STREQUAL "Debug" OR GENERATOR_IS_MULTI_CONFIG)

if(CMAKE_C_COMPILER_ID STREQUAL "GNU" OR CMAKE_Fortran_COMPILER_ID STREQUAL "GNU")
    link_libraries(gcov)
endif()

# Defines a target for running and collection code coverage information
# Builds dependencies, runs the given executable and outputs reports.
# NOTE! The executable should always have a ZERO as exit code otherwise
# the coverage generation will not complete.
#
# setup_target_for_coverage_lcov(
#     NAME testrunner_coverage                    # New target name
#     EXECUTABLE testrunner -j ${PROCESSOR_COUNT} # Executable in PROJECT_BINARY_DIR
#     DEPENDENCIES testrunner                     # Dependencies to build first
#     BASE_DIRECTORY "../"                        # Base directory for report
#                                                 #  (defaults to PROJECT_SOURCE_DIR)
#     EXCLUDE "src/dir1/*" "src/dir2/*"           # Patterns to exclude (can be relative
#                                                 #  to BASE_DIRECTORY, with CMake 3.4+)
#     NO_DEMANGLE                                 # Don't demangle C++ symbols
#                                                 #  even if c++filt is found
# )
function(setup_target_for_coverage_lcov)

    set(options NO_DEMANGLE SONARQUBE)
    set(oneValueArgs BASE_DIRECTORY NAME)
    set(multiValueArgs EXCLUDE EXECUTABLE EXECUTABLE_ARGS DEPENDENCIES LCOV_ARGS GENHTML_ARGS)
    cmake_parse_arguments(Coverage "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    if(NOT LCOV_PATH)
        message(FATAL_ERROR "lcov not found! Aborting...")
    endif() # NOT LCOV_PATH

    if(NOT GENHTML_PATH)
        message(FATAL_ERROR "genhtml not found! Aborting...")
    endif() # NOT GENHTML_PATH

    # Set base directory (as absolute path), or default to PROJECT_SOURCE_DIR
    if(DEFINED Coverage_BASE_DIRECTORY)
        get_filename_component(BASEDIR ${Coverage_BASE_DIRECTORY} ABSOLUTE)
    else()
        set(BASEDIR ${PROJECT_SOURCE_DIR})
    endif()

    # Collect excludes (CMake 3.4+: Also compute absolute paths)
    set(LCOV_EXCLUDES "")
    foreach(EXCLUDE ${Coverage_EXCLUDE} ${COVERAGE_EXCLUDES} ${COVERAGE_LCOV_EXCLUDES})
        if(CMAKE_VERSION VERSION_GREATER 3.4)
            get_filename_component(EXCLUDE ${EXCLUDE} ABSOLUTE BASE_DIR ${BASEDIR})
        endif()
        list(APPEND LCOV_EXCLUDES "${EXCLUDE}")
    endforeach()
    list(REMOVE_DUPLICATES LCOV_EXCLUDES)

    # Conditional arguments
    if(CPPFILT_PATH AND NOT ${Coverage_NO_DEMANGLE})
        set(GENHTML_EXTRA_ARGS "--demangle-cpp")
    endif()

    # Setting up commands which will be run to generate coverage data.
    # Cleanup lcov
    set(LCOV_CLEAN_CMD
        ${LCOV_PATH} ${Coverage_LCOV_ARGS} --gcov-tool ${GCOV_PATH} -directory .
        -b ${BASEDIR} --zerocounters
    )
    # Create baseline to make sure untouched files show up in the report
    set(LCOV_BASELINE_CMD
        ${LCOV_PATH} ${Coverage_LCOV_ARGS} --gcov-tool ${GCOV_PATH} -c -i -d . -b
        ${BASEDIR} -o ${Coverage_NAME}.base
    )
    # Run tests
    set(LCOV_EXEC_TESTS_CMD
        ${Coverage_EXECUTABLE} ${Coverage_EXECUTABLE_ARGS}
    )
    # Capturing lcov counters and generating report
    set(LCOV_CAPTURE_CMD
        ${LCOV_PATH} ${Coverage_LCOV_ARGS} --gcov-tool ${GCOV_PATH} --directory . -b
        ${BASEDIR} --capture --output-file ${Coverage_NAME}.capture
    )
    # add baseline counters
    set(LCOV_BASELINE_COUNT_CMD
        ${LCOV_PATH} ${Coverage_LCOV_ARGS} --gcov-tool ${GCOV_PATH} -a ${Coverage_NAME}.base
        -a ${Coverage_NAME}.capture --output-file ${Coverage_NAME}.total
    )
    # filter collected data to final coverage report
    set(LCOV_FILTER_CMD
        ${LCOV_PATH} ${Coverage_LCOV_ARGS} --gcov-tool ${GCOV_PATH} --remove
        ${Coverage_NAME}.total ${LCOV_EXCLUDES} --output-file ${Coverage_NAME}.info
    )
    # Generate HTML output
    set(LCOV_GEN_HTML_CMD
        ${GENHTML_PATH} ${GENHTML_EXTRA_ARGS} ${Coverage_GENHTML_ARGS} -o
        ${Coverage_NAME} ${Coverage_NAME}.info
    )
    if(${Coverage_SONARQUBE})
        # Generate SonarQube output
        set(GCOVR_XML_CMD
            ${GCOVR_PATH} --sonarqube ${Coverage_NAME}_sonarqube.xml -r ${BASEDIR} ${GCOVR_ADDITIONAL_ARGS}
            ${GCOVR_EXCLUDE_ARGS} --object-directory=${PROJECT_BINARY_DIR}
        )
        set(GCOVR_XML_CMD_COMMAND
            COMMAND ${GCOVR_XML_CMD}
        )
        set(GCOVR_XML_CMD_BYPRODUCTS ${Coverage_NAME}_sonarqube.xml)
        set(GCOVR_XML_CMD_COMMENT COMMENT "SonarQube code coverage info report saved in ${Coverage_NAME}_sonarqube.xml.")
    endif()


    if(CODE_COVERAGE_VERBOSE)
        message(STATUS "Executed command report")
        message(STATUS "Command to clean up lcov: ")
        string(REPLACE ";" " " LCOV_CLEAN_CMD_SPACED "${LCOV_CLEAN_CMD}")
        message(STATUS "${LCOV_CLEAN_CMD_SPACED}")

        message(STATUS "Command to create baseline: ")
        string(REPLACE ";" " " LCOV_BASELINE_CMD_SPACED "${LCOV_BASELINE_CMD}")
        message(STATUS "${LCOV_BASELINE_CMD_SPACED}")

        message(STATUS "Command to run the tests: ")
        string(REPLACE ";" " " LCOV_EXEC_TESTS_CMD_SPACED "${LCOV_EXEC_TESTS_CMD}")
        message(STATUS "${LCOV_EXEC_TESTS_CMD_SPACED}")

        message(STATUS "Command to capture counters and generate report: ")
        string(REPLACE ";" " " LCOV_CAPTURE_CMD_SPACED "${LCOV_CAPTURE_CMD}")
        message(STATUS "${LCOV_CAPTURE_CMD_SPACED}")

        message(STATUS "Command to add baseline counters: ")
        string(REPLACE ";" " " LCOV_BASELINE_COUNT_CMD_SPACED "${LCOV_BASELINE_COUNT_CMD}")
        message(STATUS "${LCOV_BASELINE_COUNT_CMD_SPACED}")

        message(STATUS "Command to filter collected data: ")
        string(REPLACE ";" " " LCOV_FILTER_CMD_SPACED "${LCOV_FILTER_CMD}")
        message(STATUS "${LCOV_FILTER_CMD_SPACED}")

        message(STATUS "Command to generate lcov HTML output: ")
        string(REPLACE ";" " " LCOV_GEN_HTML_CMD_SPACED "${LCOV_GEN_HTML_CMD}")
        message(STATUS "${LCOV_GEN_HTML_CMD_SPACED}")

        if(${Coverage_SONARQUBE})
            message(STATUS "Command to generate SonarQube XML output: ")
            string(REPLACE ";" " " GCOVR_XML_CMD_SPACED "${GCOVR_XML_CMD}")
            message(STATUS "${GCOVR_XML_CMD_SPACED}")
        endif()
    endif()

    # Setup target
    add_custom_target(${Coverage_NAME}
        COMMAND ${LCOV_CLEAN_CMD}
        COMMAND ${LCOV_BASELINE_CMD}
        COMMAND ${LCOV_EXEC_TESTS_CMD}
        COMMAND ${LCOV_CAPTURE_CMD}
        COMMAND ${LCOV_BASELINE_COUNT_CMD}
        COMMAND ${LCOV_FILTER_CMD}
        COMMAND ${LCOV_GEN_HTML_CMD}
        ${GCOVR_XML_CMD_COMMAND}

        # Set output files as GENERATED (will be removed on 'make clean')
        BYPRODUCTS
            ${Coverage_NAME}.base
            ${Coverage_NAME}.capture
            ${Coverage_NAME}.total
            ${Coverage_NAME}.info
            ${GCOVR_XML_CMD_BYPRODUCTS}
            ${Coverage_NAME}/index.html
        WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
        DEPENDS ${Coverage_DEPENDENCIES}
        VERBATIM # Protect arguments to commands
        COMMENT "Resetting code coverage counters to zero.\nProcessing code coverage counters and generating report."
    )

    # Show where to find the lcov info report
    add_custom_command(TARGET ${Coverage_NAME} POST_BUILD
        COMMAND ;
        COMMENT "Lcov code coverage info report saved in ${Coverage_NAME}.info."
        ${GCOVR_XML_CMD_COMMENT}
    )

    # Show info where to find the report
    add_custom_command(TARGET ${Coverage_NAME} POST_BUILD
        COMMAND ;
        COMMENT "Open ./${Coverage_NAME}/index.html in your browser to view the coverage report."
    )

endfunction() # setup_target_for_coverage_lcov

# Defines a target for running and collection code coverage information
# Builds dependencies, runs the given executable and outputs reports.
# NOTE! The executable should always have a ZERO as exit code otherwise
# the coverage generation will not complete.
#
# setup_target_for_coverage_gcovr_xml(
#     NAME ctest_coverage                    # New target name
#     EXECUTABLE ctest -j ${PROCESSOR_COUNT} # Executable in PROJECT_BINARY_DIR
#     DEPENDENCIES executable_target         # Dependencies to build first
#     BASE_DIRECTORY "../"                   # Base directory for report
#                                            #  (defaults to PROJECT_SOURCE_DIR)
#     EXCLUDE "src/dir1/*" "src/dir2/*"      # Patterns to exclude (can be relative
#                                            #  to BASE_DIRECTORY, with CMake 3.4+)
# )
# The user can set the variable GCOVR_ADDITIONAL_ARGS to supply additional flags to the
# GCVOR command.
function(setup_target_for_coverage_gcovr_xml)

    set(options NONE)
    set(oneValueArgs BASE_DIRECTORY NAME)
    set(multiValueArgs EXCLUDE EXECUTABLE EXECUTABLE_ARGS DEPENDENCIES)
    cmake_parse_arguments(Coverage "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    if(NOT GCOVR_PATH)
        message(FATAL_ERROR "gcovr not found! Aborting...")
    endif() # NOT GCOVR_PATH

    # Set base directory (as absolute path), or default to PROJECT_SOURCE_DIR
    if(DEFINED Coverage_BASE_DIRECTORY)
        get_filename_component(BASEDIR ${Coverage_BASE_DIRECTORY} ABSOLUTE)
    else()
        set(BASEDIR ${PROJECT_SOURCE_DIR})
    endif()

    # Collect excludes (CMake 3.4+: Also compute absolute paths)
    set(GCOVR_EXCLUDES "")
    foreach(EXCLUDE ${Coverage_EXCLUDE} ${COVERAGE_EXCLUDES} ${COVERAGE_GCOVR_EXCLUDES})
        if(CMAKE_VERSION VERSION_GREATER 3.4)
            get_filename_component(EXCLUDE ${EXCLUDE} ABSOLUTE BASE_DIR ${BASEDIR})
        endif()
        list(APPEND GCOVR_EXCLUDES "${EXCLUDE}")
    endforeach()
    list(REMOVE_DUPLICATES GCOVR_EXCLUDES)

    # Combine excludes to several -e arguments
    set(GCOVR_EXCLUDE_ARGS "")
    foreach(EXCLUDE ${GCOVR_EXCLUDES})
        list(APPEND GCOVR_EXCLUDE_ARGS "-e")
        list(APPEND GCOVR_EXCLUDE_ARGS "${EXCLUDE}")
    endforeach()

    # Set up commands which will be run to generate coverage data
    # Run tests
    set(GCOVR_XML_EXEC_TESTS_CMD
        ${Coverage_EXECUTABLE} ${Coverage_EXECUTABLE_ARGS}
    )
    # Running gcovr
    set(GCOVR_XML_CMD
        ${GCOVR_PATH} --xml ${Coverage_NAME}.xml -r ${BASEDIR} ${GCOVR_ADDITIONAL_ARGS}
        ${GCOVR_EXCLUDE_ARGS} --object-directory=${PROJECT_BINARY_DIR}
    )

    if(CODE_COVERAGE_VERBOSE)
        message(STATUS "Executed command report")

        message(STATUS "Command to run tests: ")
        string(REPLACE ";" " " GCOVR_XML_EXEC_TESTS_CMD_SPACED "${GCOVR_XML_EXEC_TESTS_CMD}")
        message(STATUS "${GCOVR_XML_EXEC_TESTS_CMD_SPACED}")

        message(STATUS "Command to generate gcovr XML coverage data: ")
        string(REPLACE ";" " " GCOVR_XML_CMD_SPACED "${GCOVR_XML_CMD}")
        message(STATUS "${GCOVR_XML_CMD_SPACED}")
    endif()

    add_custom_target(${Coverage_NAME}
        COMMAND ${GCOVR_XML_EXEC_TESTS_CMD}
        COMMAND ${GCOVR_XML_CMD}

        BYPRODUCTS ${Coverage_NAME}.xml
        WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
        DEPENDS ${Coverage_DEPENDENCIES}
        VERBATIM # Protect arguments to commands
        COMMENT "Running gcovr to produce Cobertura code coverage report."
    )

    # Show info where to find the report
    add_custom_command(TARGET ${Coverage_NAME} POST_BUILD
        COMMAND ;
        COMMENT "Cobertura code coverage report saved in ${Coverage_NAME}.xml."
    )
endfunction() # setup_target_for_coverage_gcovr_xml

# Defines a target for running and collection code coverage information
# Builds dependencies, runs the given executable and outputs reports.
# NOTE! The executable should always have a ZERO as exit code otherwise
# the coverage generation will not complete.
#
# setup_target_for_coverage_gcovr_html(
#     NAME ctest_coverage                    # New target name
#     EXECUTABLE ctest -j ${PROCESSOR_COUNT} # Executable in PROJECT_BINARY_DIR
#     DEPENDENCIES executable_target         # Dependencies to build first
#     BASE_DIRECTORY "../"                   # Base directory for report
#                                            #  (defaults to PROJECT_SOURCE_DIR)
#     EXCLUDE "src/dir1/*" "src/dir2/*"      # Patterns to exclude (can be relative
#                                            #  to BASE_DIRECTORY, with CMake 3.4+)
# )
# The user can set the variable GCOVR_ADDITIONAL_ARGS to supply additional flags to the
# GCVOR command.
function(setup_target_for_coverage_gcovr_html)

    set(options NONE)
    set(oneValueArgs BASE_DIRECTORY NAME)
    set(multiValueArgs EXCLUDE EXECUTABLE EXECUTABLE_ARGS DEPENDENCIES)
    cmake_parse_arguments(Coverage "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    if(NOT GCOVR_PATH)
        message(FATAL_ERROR "gcovr not found! Aborting...")
    endif() # NOT GCOVR_PATH

    # Set base directory (as absolute path), or default to PROJECT_SOURCE_DIR
    if(DEFINED Coverage_BASE_DIRECTORY)
        get_filename_component(BASEDIR ${Coverage_BASE_DIRECTORY} ABSOLUTE)
    else()
        set(BASEDIR ${PROJECT_SOURCE_DIR})
    endif()

    # Collect excludes (CMake 3.4+: Also compute absolute paths)
    set(GCOVR_EXCLUDES "")
    foreach(EXCLUDE ${Coverage_EXCLUDE} ${COVERAGE_EXCLUDES} ${COVERAGE_GCOVR_EXCLUDES})
        if(CMAKE_VERSION VERSION_GREATER 3.4)
            get_filename_component(EXCLUDE ${EXCLUDE} ABSOLUTE BASE_DIR ${BASEDIR})
        endif()
        list(APPEND GCOVR_EXCLUDES "${EXCLUDE}")
    endforeach()
    list(REMOVE_DUPLICATES GCOVR_EXCLUDES)

    # Combine excludes to several -e arguments
    set(GCOVR_EXCLUDE_ARGS "")
    foreach(EXCLUDE ${GCOVR_EXCLUDES})
        list(APPEND GCOVR_EXCLUDE_ARGS "-e")
        list(APPEND GCOVR_EXCLUDE_ARGS "${EXCLUDE}")
    endforeach()

    # Set up commands which will be run to generate coverage data
    # Run tests
    set(GCOVR_HTML_EXEC_TESTS_CMD
        ${Coverage_EXECUTABLE} ${Coverage_EXECUTABLE_ARGS}
    )
    # Create folder
    set(GCOVR_HTML_FOLDER_CMD
        ${CMAKE_COMMAND} -E make_directory ${PROJECT_BINARY_DIR}/${Coverage_NAME}
    )
    # Running gcovr
    set(GCOVR_HTML_CMD
        ${GCOVR_PATH} --html ${Coverage_NAME}/index.html --html-details -r ${BASEDIR} ${GCOVR_ADDITIONAL_ARGS}
        ${GCOVR_EXCLUDE_ARGS} --object-directory=${PROJECT_BINARY_DIR}
    )

    if(CODE_COVERAGE_VERBOSE)
        message(STATUS "Executed command report")

        message(STATUS "Command to run tests: ")
        string(REPLACE ";" " " GCOVR_HTML_EXEC_TESTS_CMD_SPACED "${GCOVR_HTML_EXEC_TESTS_CMD}")
        message(STATUS "${GCOVR_HTML_EXEC_TESTS_CMD_SPACED}")

        message(STATUS "Command to create a folder: ")
        string(REPLACE ";" " " GCOVR_HTML_FOLDER_CMD_SPACED "${GCOVR_HTML_FOLDER_CMD}")
        message(STATUS "${GCOVR_HTML_FOLDER_CMD_SPACED}")

        message(STATUS "Command to generate gcovr HTML coverage data: ")
        string(REPLACE ";" " " GCOVR_HTML_CMD_SPACED "${GCOVR_HTML_CMD}")
        message(STATUS "${GCOVR_HTML_CMD_SPACED}")
    endif()

    add_custom_target(${Coverage_NAME}
        COMMAND ${GCOVR_HTML_EXEC_TESTS_CMD}
        COMMAND ${GCOVR_HTML_FOLDER_CMD}
        COMMAND ${GCOVR_HTML_CMD}

        BYPRODUCTS ${PROJECT_BINARY_DIR}/${Coverage_NAME}/index.html # report directory
        WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
        DEPENDS ${Coverage_DEPENDENCIES}
        VERBATIM # Protect arguments to commands
        COMMENT "Running gcovr to produce HTML code coverage report."
    )

    # Show info where to find the report
    add_custom_command(TARGET ${Coverage_NAME} POST_BUILD
        COMMAND ;
        COMMENT "Open ./${Coverage_NAME}/index.html in your browser to view the coverage report."
    )

endfunction() # setup_target_for_coverage_gcovr_html

# Defines a target for running and collection code coverage information
# Builds dependencies, runs the given executable and outputs reports.
# NOTE! The executable should always have a ZERO as exit code otherwise
# the coverage generation will not complete.
#
# setup_target_for_coverage_fastcov(
#     NAME testrunner_coverage                    # New target name
#     EXECUTABLE testrunner -j ${PROCESSOR_COUNT} # Executable in PROJECT_BINARY_DIR
#     DEPENDENCIES testrunner                     # Dependencies to build first
#     BASE_DIRECTORY "../"                        # Base directory for report
#                                                 #  (defaults to PROJECT_SOURCE_DIR)
#     EXCLUDE "src/dir1/" "src/dir2/"             # Patterns to exclude.
#     NO_DEMANGLE                                 # Don't demangle C++ symbols
#                                                 #  even if c++filt is found
#     SKIP_HTML                                   # Don't create html report
#     POST_CMD perl -i -pe s!${PROJECT_SOURCE_DIR}/!!g ctest_coverage.json # E.g. for stripping source dir from file paths
# )
function(setup_target_for_coverage_fastcov)

    set(options NO_DEMANGLE SKIP_HTML)
    set(oneValueArgs BASE_DIRECTORY NAME)
    set(multiValueArgs EXCLUDE EXECUTABLE EXECUTABLE_ARGS DEPENDENCIES FASTCOV_ARGS GENHTML_ARGS POST_CMD)
    cmake_parse_arguments(Coverage "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    if(NOT FASTCOV_PATH)
        message(FATAL_ERROR "fastcov not found! Aborting...")
    endif()

    if(NOT Coverage_SKIP_HTML AND NOT GENHTML_PATH)
        message(FATAL_ERROR "genhtml not found! Aborting...")
    endif()

    # Set base directory (as absolute path), or default to PROJECT_SOURCE_DIR
    if(Coverage_BASE_DIRECTORY)
        get_filename_component(BASEDIR ${Coverage_BASE_DIRECTORY} ABSOLUTE)
    else()
        set(BASEDIR ${PROJECT_SOURCE_DIR})
    endif()

    # Collect excludes (Patterns, not paths, for fastcov)
    set(FASTCOV_EXCLUDES "")
    foreach(EXCLUDE ${Coverage_EXCLUDE} ${COVERAGE_EXCLUDES} ${COVERAGE_FASTCOV_EXCLUDES})
        list(APPEND FASTCOV_EXCLUDES "${EXCLUDE}")
    endforeach()
    list(REMOVE_DUPLICATES FASTCOV_EXCLUDES)

    # Conditional arguments
    if(CPPFILT_PATH AND NOT ${Coverage_NO_DEMANGLE})
        set(GENHTML_EXTRA_ARGS "--demangle-cpp")
    endif()

    # Set up commands which will be run to generate coverage data
    set(FASTCOV_EXEC_TESTS_CMD ${Coverage_EXECUTABLE} ${Coverage_EXECUTABLE_ARGS})

    set(FASTCOV_CAPTURE_CMD ${FASTCOV_PATH} ${Coverage_FASTCOV_ARGS} --gcov ${GCOV_PATH}
        --search-directory ${BASEDIR}
        --process-gcno
        --output ${Coverage_NAME}.json
        --exclude ${FASTCOV_EXCLUDES}
    )

    set(FASTCOV_CONVERT_CMD ${FASTCOV_PATH}
        -C ${Coverage_NAME}.json --lcov --output ${Coverage_NAME}.info
    )

    if(Coverage_SKIP_HTML)
        set(FASTCOV_HTML_CMD ";")
    else()
        set(FASTCOV_HTML_CMD ${GENHTML_PATH} ${GENHTML_EXTRA_ARGS} ${Coverage_GENHTML_ARGS}
            -o ${Coverage_NAME} ${Coverage_NAME}.info
        )
    endif()

    set(FASTCOV_POST_CMD ";")
    if(Coverage_POST_CMD)
        set(FASTCOV_POST_CMD ${Coverage_POST_CMD})
    endif()

    if(CODE_COVERAGE_VERBOSE)
        message(STATUS "Code coverage commands for target ${Coverage_NAME} (fastcov):")

        message("   Running tests:")
        string(REPLACE ";" " " FASTCOV_EXEC_TESTS_CMD_SPACED "${FASTCOV_EXEC_TESTS_CMD}")
        message("     ${FASTCOV_EXEC_TESTS_CMD_SPACED}")

        message("   Capturing fastcov counters and generating report:")
        string(REPLACE ";" " " FASTCOV_CAPTURE_CMD_SPACED "${FASTCOV_CAPTURE_CMD}")
        message("     ${FASTCOV_CAPTURE_CMD_SPACED}")

        message("   Converting fastcov .json to lcov .info:")
        string(REPLACE ";" " " FASTCOV_CONVERT_CMD_SPACED "${FASTCOV_CONVERT_CMD}")
        message("     ${FASTCOV_CONVERT_CMD_SPACED}")

        if(NOT Coverage_SKIP_HTML)
            message("   Generating HTML report: ")
            string(REPLACE ";" " " FASTCOV_HTML_CMD_SPACED "${FASTCOV_HTML_CMD}")
            message("     ${FASTCOV_HTML_CMD_SPACED}")
        endif()
        if(Coverage_POST_CMD)
            message("   Running post command: ")
            string(REPLACE ";" " " FASTCOV_POST_CMD_SPACED "${FASTCOV_POST_CMD}")
            message("     ${FASTCOV_POST_CMD_SPACED}")
        endif()
    endif()

    # Setup target
    add_custom_target(${Coverage_NAME}

        # Cleanup fastcov
        COMMAND ${FASTCOV_PATH} ${Coverage_FASTCOV_ARGS} --gcov ${GCOV_PATH}
            --search-directory ${BASEDIR}
            --zerocounters

        COMMAND ${FASTCOV_EXEC_TESTS_CMD}
        COMMAND ${FASTCOV_CAPTURE_CMD}
        COMMAND ${FASTCOV_CONVERT_CMD}
        COMMAND ${FASTCOV_HTML_CMD}
        COMMAND ${FASTCOV_POST_CMD}

        # Set output files as GENERATED (will be removed on 'make clean')
        BYPRODUCTS
            ${Coverage_NAME}.info
            ${Coverage_NAME}.json
            ${Coverage_NAME}/index.html # report directory

        WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
        DEPENDS ${Coverage_DEPENDENCIES}
        VERBATIM # Protect arguments to commands
        COMMENT "Resetting code coverage counters to zero. Processing code coverage counters and generating report."
    )

    set(INFO_MSG "fastcov code coverage info report saved in ${Coverage_NAME}.info and ${Coverage_NAME}.json.")
    if(NOT Coverage_SKIP_HTML)
        string(APPEND INFO_MSG " Open ${PROJECT_BINARY_DIR}/${Coverage_NAME}/index.html in your browser to view the coverage report.")
    endif()
    # Show where to find the fastcov info report
    add_custom_command(TARGET ${Coverage_NAME} POST_BUILD
        COMMAND ${CMAKE_COMMAND} -E echo ${INFO_MSG}
    )

endfunction() # setup_target_for_coverage_fastcov

function(append_coverage_compiler_flags)
    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${COVERAGE_COMPILER_FLAGS}" PARENT_SCOPE)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${COVERAGE_COMPILER_FLAGS}" PARENT_SCOPE)
    set(CMAKE_Fortran_FLAGS "${CMAKE_Fortran_FLAGS} ${COVERAGE_COMPILER_FLAGS}" PARENT_SCOPE)
    message(STATUS "Appending code coverage compiler flags: ${COVERAGE_COMPILER_FLAGS}")
endfunction() # append_coverage_compiler_flags

# Setup coverage for specific library
function(append_coverage_compiler_flags_to_target name)
    separate_arguments(_flag_list NATIVE_COMMAND "${COVERAGE_COMPILER_FLAGS}")
    target_compile_options(${name} PRIVATE ${_flag_list})
    if(CMAKE_C_COMPILER_ID STREQUAL "GNU" OR CMAKE_Fortran_COMPILER_ID STREQUAL "GNU")
        target_link_libraries(${name} PRIVATE gcov)
    endif()
endfunction()
@@ -1,22 +0,0 @@
if(ENABLE_CLANG_TIDY)
    find_program(CLANG_TIDY_COMMAND NAMES clang-tidy)

    if(NOT CLANG_TIDY_COMMAND)
        message(WARNING "🔴 CMake_RUN_CLANG_TIDY is ON but clang-tidy is not found!")
        set(CMAKE_CXX_CLANG_TIDY "" CACHE STRING "" FORCE)
    else()

        message(STATUS "🟢 CMake_RUN_CLANG_TIDY is ON")
        set(CLANGTIDY_EXTRA_ARGS
            "-extra-arg=-Wno-unknown-warning-option"
        )
        set(CMAKE_CXX_CLANG_TIDY "${CLANG_TIDY_COMMAND};-p=${CMAKE_BINARY_DIR};${CLANGTIDY_EXTRA_ARGS}" CACHE STRING "" FORCE)

        add_custom_target(clang-tidy
            COMMAND ${CMAKE_COMMAND} --build ${CMAKE_BINARY_DIR} --target ${CMAKE_PROJECT_NAME}
            COMMAND ${CMAKE_COMMAND} --build ${CMAKE_BINARY_DIR} --target clang-tidy
            COMMENT "Running clang-tidy..."
        )
        set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
    endif()
endif(ENABLE_CLANG_TIDY)
10
conandata.yml
Normal file
@@ -0,0 +1,10 @@
sources:
  "1.1.2":
    url: "https://github.com/rmontanana/BayesNet/archive/v1.1.2.tar.gz"
    sha256: "placeholder_sha256" # Replace with actual SHA256 when releasing
  "1.0.7":
    url: "https://github.com/rmontanana/BayesNet/archive/v1.0.7.tar.gz"
    sha256: "placeholder_sha256" # Replace with actual SHA256 when releasing

patches:
  # Add patches here if needed for specific versions
108
conanfile.py
Normal file
@@ -0,0 +1,108 @@
import os, re, pathlib
from conan import ConanFile
from conan.tools.cmake import CMakeToolchain, CMake, cmake_layout, CMakeDeps
from conan.tools.files import copy


class BayesNetConan(ConanFile):
    name = "bayesnet"
    settings = "os", "compiler", "build_type", "arch"
    options = {
        "shared": [True, False],
        "fPIC": [True, False],
        "enable_testing": [True, False],
        "enable_coverage": [True, False],
    }
    default_options = {
        "shared": False,
        "fPIC": True,
        "enable_testing": False,
        "enable_coverage": False,
    }

    # Sources are located in the same place as this recipe, copy them to the recipe
    exports_sources = (
        "CMakeLists.txt",
        "bayesnet/*",
        "config/*",
        "cmake/*",
        "docs/*",
        "tests/*",
        "bayesnetConfig.cmake.in",
    )

    def set_version(self) -> None:
        cmake = pathlib.Path(self.recipe_folder) / "CMakeLists.txt"
        text = cmake.read_text(encoding="utf-8")

        # Accept either: project(foo VERSION 1.2.3) or set(foo_VERSION 1.2.3)
        match = re.search(
            r"""project\s*\([^\)]*VERSION\s+([0-9]+\.[0-9]+\.[0-9]+)""",
            text,
            re.IGNORECASE | re.VERBOSE,
        )
        if match:
            self.version = match.group(1)
        else:
            raise Exception("Version not found in CMakeLists.txt")

    def config_options(self):
        if self.settings.os == "Windows":
            del self.options.fPIC

    def configure(self):
        if self.options.shared:
            self.options.rm_safe("fPIC")

    def requirements(self):
        # Core dependencies
        self.requires("libtorch/2.7.1")
        self.requires("nlohmann_json/3.11.3")
        self.requires("folding/1.1.2")  # Custom package
        self.requires("fimdlp/2.1.1")  # Custom package

    def build_requirements(self):
        self.build_requires("cmake/[>=3.27]")
        self.test_requires("arff-files/1.2.1")  # Custom package
        self.test_requires("catch2/3.8.1")

    def layout(self):
        cmake_layout(self)

    def generate(self):
        deps = CMakeDeps(self)
        deps.generate()
        tc = CMakeToolchain(self)
        tc.variables["ENABLE_TESTING"] = self.options.enable_testing
        tc.variables["CODE_COVERAGE"] = self.options.enable_coverage
        tc.generate()

    def build(self):
        cmake = CMake(self)
        cmake.configure()
        cmake.build()

        if self.options.enable_testing:
            # Run tests only if we're building with testing enabled
            self.run("ctest --output-on-failure", cwd=self.build_folder)

    def package(self):
        copy(
            self,
            "LICENSE",
            src=self.source_folder,
            dst=os.path.join(self.package_folder, "licenses"),
        )
        cmake = CMake(self)
        cmake.install()

    def package_info(self):
        self.cpp_info.libs = ["bayesnet"]
        self.cpp_info.includedirs = ["include"]
        self.cpp_info.set_property("cmake_find_mode", "both")
        self.cpp_info.set_property("cmake_target_name", "bayesnet::bayesnet")

        # Add compiler flags that might be needed
        if self.settings.os == "Linux":
            self.cpp_info.system_libs = ["pthread"]
@@ -3,10 +3,6 @@
#include <string>
#include <string_view>

#define PROJECT_VERSION_MAJOR @PROJECT_VERSION_MAJOR@
#define PROJECT_VERSION_MINOR @PROJECT_VERSION_MINOR@
#define PROJECT_VERSION_PATCH @PROJECT_VERSION_PATCH@

static constexpr std::string_view project_name = "@PROJECT_NAME@";
static constexpr std::string_view project_version = "@PROJECT_VERSION@";
static constexpr std::string_view project_description = "@PROJECT_DESCRIPTION@";
235
local_discretization_analysis.md
Normal file
@@ -0,0 +1,235 @@
# Local Discretization Analysis - BayesNet Library

## Overview

This document analyzes the local discretization implementation in the BayesNet library, specifically focusing on the `Proposal.cc` implementation, and evaluates the feasibility of implementing an iterative discretization approach.

## Current Local Discretization Implementation

### Core Architecture

The local discretization functionality is implemented through a **Proposal class** (`bayesnet/classifiers/Proposal.h`) that serves as a mixin/base class for creating "Ld" (Local Discretization) variants of existing classifiers.

### Key Components

#### 1. The Proposal Class
- **Purpose**: Handles continuous data by applying local discretization using discretization algorithms
- **Dependencies**: Uses the `fimdlp` library for discretization algorithms
- **Supported Algorithms**:
  - **MDLP** (Minimum Description Length Principle) - Default
  - **BINQ** - Quantile-based binning
  - **BINU** - Uniform binning

#### 2. Local Discretization Variants

The codebase implements Ld variants using multiple inheritance:

**Individual Classifiers:**
- `TANLd` - Tree Augmented Naive Bayes with Local Discretization
- `KDBLd` - K-Dependence Bayesian with Local Discretization
- `SPODELd` - Super-Parent One-Dependence Estimator with Local Discretization

**Ensemble Classifiers:**
- `AODELd` - Averaged One-Dependence Estimator with Local Discretization

### Implementation Pattern

All Ld variants follow a consistent pattern using **multiple inheritance**:

```cpp
class TANLd : public TAN, public Proposal {
    // Inherits from both the base classifier and Proposal
};
```

### Two-Phase Discretization Process

#### Phase 1: Initial Discretization (`fit_local_discretization`)
- Each continuous feature is discretized independently using the chosen algorithm
- Creates initial discrete dataset
- Uses only class labels for discretization decisions

#### Phase 2: Network-Aware Refinement (`localDiscretizationProposal`)
- After building the initial Bayesian network structure
- Features are re-discretized considering their parent nodes in the network
- Uses topological ordering to ensure proper dependency handling
- Creates more informed discretization boundaries based on network relationships
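
Seen end to end, both phases run inside a single `fit` call: the caller passes continuous tensors and the Ld classifier derives the discrete states itself. The following is a minimal usage sketch, assuming the `fit` signature exercised by the library's tests; the dataset here is synthetic and purely illustrative:

```cpp
#include <bayesnet/classifiers/TANLd.h>
#include <torch/torch.h>
#include <map>
#include <string>
#include <vector>

int main() {
    // Synthetic continuous dataset: 4 features x 100 samples, 3 classes.
    torch::Tensor X = torch::rand({4, 100});
    torch::Tensor y = torch::randint(0, 3, {100}).to(torch::kInt32);
    std::vector<std::string> features = {"f0", "f1", "f2", "f3"};
    std::map<std::string, std::vector<int>> states;  // filled in during fit
    bayesnet::TANLd clf;
    // Phase 1 (class-only cuts) and Phase 2 (network-aware refinement)
    // both happen inside this call.
    clf.fit(X, y, features, "class", states, bayesnet::Smoothing_t::ORIGINAL);
    auto proba = clf.predict_proba(X);
    return 0;
}
```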

### Hyperparameter Support

The Proposal class supports several configurable hyperparameters:
- `ld_algorithm`: Choice of discretization algorithm (MDLP, BINQ, BINU)
- `ld_proposed_cuts`: Number of proposed cuts for discretization
- `mdlp_min_length`: Minimum interval length for MDLP
- `mdlp_max_depth`: Maximum depth for MDLP tree
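
As an illustration, these keys are passed as JSON through the classifier's `setHyperparameters` interface; the values below are arbitrary examples rather than recommended settings:

```cpp
#include <bayesnet/classifiers/TANLd.h>
#include <nlohmann/json.hpp>

int main() {
    bayesnet::TANLd clf;
    nlohmann::json hyperparameters;
    hyperparameters["ld_algorithm"] = "BINQ";  // or "MDLP" (default) / "BINU"
    hyperparameters["ld_proposed_cuts"] = 4;
    hyperparameters["mdlp_min_length"] = 3;
    hyperparameters["mdlp_max_depth"] = 10;
    clf.setHyperparameters(hyperparameters);
    return 0;
}
```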

## Current Implementation Strengths

1. **Sophisticated Approach**: Considers network structure in discretization decisions
2. **Modular Design**: Clean separation through the Proposal class mixin
3. **Multiple Algorithm Support**: Flexible discretization strategies
4. **Proper Dependency Handling**: Topological ordering ensures correct processing
5. **Well-Integrated**: Seamless integration with the existing classifier architecture

## Areas for Improvement

### Code Quality Issues

1. **Dead Code**: Line 161 in `Proposal.cc` contains the unused variable `allDigits`
2. **Performance Issues**:
   - String concatenation in a tight loop (lines 82-84) using the `+=` operator
   - Memory allocations could be optimized
   - Tensor operations could be batched better
3. **Error Handling**: Could be more robust, with better exception handling

### Algorithm Clarity

1. **Logic Clarity**: The `upgrade` flag logic could be more descriptive
2. **Variable Naming**: Some variables need more descriptive names
3. **Documentation**: Better inline documentation of the two-phase process
4. **Method Complexity**: The `localDiscretizationProposal` method is quite long and complex

### Suggested Code Improvements

```cpp
// Instead of string concatenation in loop:
for (auto idx : indices) {
    for (int i = 0; i < Xf.size(1); ++i) {
        yJoinParents[i] += to_string(pDataset.index({ idx, i }).item<int>());
    }
}

// Consider accumulating into per-feature streams instead; note that each
// stream must persist across the idx loop so the joined string is not lost:
std::vector<std::stringstream> streams(Xf.size(1));
for (auto idx : indices) {
    for (int i = 0; i < Xf.size(1); ++i) {
        streams[i] << pDataset.index({ idx, i }).item<int>();
    }
}
for (int i = 0; i < Xf.size(1); ++i) {
    yJoinParents[i] = streams[i].str();
}
```

## Iterative Discretization Proposal

### Concept

Implement an iterative process: discretize → build model → re-discretize → rebuild model → repeat until convergence.

### Feasibility Assessment

**Highly Feasible** - The current implementation already provides a solid foundation with its two-phase approach, making extension straightforward.

### Proposed Implementation Strategy

```cpp
class IterativeProposal : public Proposal {
public:
    struct ConvergenceParams {
        int max_iterations = 10;
        double tolerance = 1e-6;
        bool check_network_structure = true;
        bool check_discretization_stability = true;
    };

private:
    map<string, vector<int>> iterativeLocalDiscretization(const torch::Tensor& y) {
        auto states = fit_local_discretization(y);  // Initial discretization
        Network previousModel, currentModel;
        int iteration = 0;

        do {
            previousModel = currentModel;

            // Build model with current discretization
            const torch::Tensor weights = torch::full({ pDataset.size(1) }, 1.0 / pDataset.size(1), torch::kDouble);
            currentModel.fit(pDataset, weights, pFeatures, pClassName, states, Smoothing_t::ORIGINAL);

            // Apply local discretization based on current model
            auto newStates = localDiscretizationProposal(states, currentModel);

            // Check for convergence
            if (hasConverged(previousModel, currentModel, states, newStates)) {
                break;
            }

            states = newStates;
            iteration++;

        } while (iteration < convergenceParams.max_iterations);

        return states;
    }

    bool hasConverged(const Network& prev, const Network& curr,
                      const map<string, vector<int>>& oldStates,
                      const map<string, vector<int>>& newStates) {
        // Implementation of convergence criteria
        return checkNetworkStructureConvergence(prev, curr) &&
               checkDiscretizationStability(oldStates, newStates);
    }
};
```

### Convergence Criteria Options

1. **Network Structure Comparison**: Compare edge sets between iterations
   ```cpp
   bool checkNetworkStructureConvergence(const Network& prev, const Network& curr) {
       // Compare adjacency matrices or edge lists
       return prev.getEdges() == curr.getEdges();
   }
   ```

2. **Discretization Stability**: Check if cut points change significantly
   ```cpp
   bool checkDiscretizationStability(const map<string, vector<int>>& oldStates,
                                     const map<string, vector<int>>& newStates) {
       for (const auto& [feature, states] : oldStates) {
           if (states != newStates.at(feature)) {
               return false;
           }
       }
       return true;
   }
   ```

3. **Performance Metrics**: Monitor accuracy/likelihood convergence
4. **Maximum Iterations**: Prevent infinite loops
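
Criteria 3 and 4 are simpler to combine into a single stopping rule. A hedged sketch, reusing the `ConvergenceParams` struct defined above (the score values would come from whatever metric the caller tracks; `shouldStop` is a hypothetical helper, not library API):

```cpp
#include <cmath>

// Stop when the tracked metric stabilizes (criterion 3)
// or the iteration budget is exhausted (criterion 4).
bool shouldStop(double previousScore, double currentScore,
                int iteration, const ConvergenceParams& params) {
    bool scoreStable = std::fabs(currentScore - previousScore) < params.tolerance;
    bool budgetSpent = iteration >= params.max_iterations;
    return scoreStable || budgetSpent;
}
```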

### Expected Benefits

1. **Better Discretization Quality**: Each iteration refines boundaries based on learned dependencies
2. **Improved Model Accuracy**: More informed discretization leads to better classification
3. **Adaptive Process**: Automatically finds optimal discretization-model combination
4. **Principled Approach**: Theoretically sound iterative refinement
5. **Reduced Manual Tuning**: Less need for hyperparameter optimization

### Implementation Considerations

1. **Convergence Detection**: Need robust criteria to detect when to stop
2. **Performance Impact**: Multiple iterations increase computational cost
3. **Overfitting Prevention**: May need regularization to avoid over-discretization
4. **Stability Guarantees**: Ensure the process doesn't oscillate indefinitely
5. **Memory Management**: Handle multiple model instances efficiently

### Integration Strategy

1. **Backward Compatibility**: Keep existing two-phase approach as default
2. **Optional Feature**: Add iterative mode as configurable option
3. **Hyperparameter Extension**: Add convergence-related parameters
4. **Testing Framework**: Comprehensive testing on standard datasets

## Conclusion

The current local discretization implementation in BayesNet is well-designed and functional, providing a solid foundation for the proposed iterative enhancement. The iterative approach would significantly improve the quality of discretization by creating a feedback loop between model structure and discretization decisions.

The implementation is highly feasible given the existing architecture, and the expected benefits justify the additional computational complexity. The key to success will be implementing robust convergence criteria and maintaining the modularity of the current design.

## Recommendations

1. **Immediate Improvements**: Fix code quality issues and optimize performance bottlenecks
2. **Iterative Implementation**: Develop the iterative approach as an optional enhancement
3. **Comprehensive Testing**: Validate improvements on standard benchmark datasets
4. **Documentation**: Enhance inline documentation and user guides
5. **Performance Profiling**: Monitor computational overhead and optimize where necessary
@@ -4,37 +4,19 @@ project(bayesnet_sample VERSION 0.1.0 LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)

set(CMAKE_BUILD_TYPE Release)

find_package(Torch CONFIG REQUIRED)
find_package(fimdlp CONFIG REQUIRED)
find_package(folding CONFIG REQUIRED)
find_package(arff-files CONFIG REQUIRED)
find_package(nlohmann_json CONFIG REQUIRED)

option(BAYESNET_VCPKG_CONFIG "Use vcpkg config for BayesNet" ON)

if (BAYESNET_VCPKG_CONFIG)
    message(STATUS "Using BayesNet vcpkg config")
    find_package(bayesnet CONFIG REQUIRED)
    set(BayesNet_LIBRARIES bayesnet::bayesnet)
else(BAYESNET_VCPKG_CONFIG)
    message(STATUS "Using BayesNet local library config")
    find_library(bayesnet NAMES libbayesnet bayesnet libbayesnet.a PATHS ${Platform_SOURCE_DIR}/../lib/lib REQUIRED)
    find_path(Bayesnet_INCLUDE_DIRS REQUIRED NAMES bayesnet PATHS ${Platform_SOURCE_DIR}/../lib/include)
    add_library(bayesnet::bayesnet UNKNOWN IMPORTED)
    set_target_properties(bayesnet::bayesnet PROPERTIES
        IMPORTED_LOCATION ${bayesnet}
        INTERFACE_INCLUDE_DIRECTORIES ${Bayesnet_INCLUDE_DIRS}
    )
endif(BAYESNET_VCPKG_CONFIG)
message(STATUS "BayesNet: ${bayesnet}")
find_package(nlohmann_json REQUIRED)
find_package(bayesnet CONFIG REQUIRED)

add_executable(bayesnet_sample sample.cc)
target_link_libraries(bayesnet_sample PRIVATE
    fimdlp::fimdlp
    arff-files::arff-files
    "${TORCH_LIBRARIES}"
    torch::torch
    bayesnet::bayesnet
    folding::folding
    nlohmann_json::nlohmann_json
)
9
sample/CMakeUserPresets.json
Normal file
@@ -0,0 +1,9 @@
{
    "version": 4,
    "vendor": {
        "conan": {}
    },
    "include": [
        "build/CMakePresets.json"
    ]
}
14
sample/conanfile.txt
Normal file
@@ -0,0 +1,14 @@
[requires]
libtorch/2.7.0
arff-files/1.2.0
fimdlp/2.1.0
folding/1.1.1
bayesnet/1.2.0
nlohmann_json/3.11.3

[generators]
CMakeToolchain
CMakeDeps

[options]
libtorch/2.7.0:shared=True
@@ -6,7 +6,7 @@

#include <map>
#include <string>
#include <ArffFiles/ArffFiles.hpp>
#include <ArffFiles.hpp>
#include <fimdlp/CPPFImdlp.h>
#include <bayesnet/classifiers/TANLd.h>
#include <bayesnet/classifiers/KDBLd.h>
@@ -2,12 +2,13 @@ if(ENABLE_TESTING)
    include_directories(
        ${BayesNet_SOURCE_DIR}
        ${CMAKE_BINARY_DIR}/configured_files/include
        ${nlohmann_json_INCLUDE_DIRS}
    )
    file(GLOB_RECURSE BayesNet_SOURCES "${bayesnet_SOURCE_DIR}/bayesnet/*.cc")
    add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesClassifier.cc TestXSPnDE.cc TestXBA2DE.cc
        TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestBoostAODE.cc TestXBAODE.cc TestA2DE.cc
        TestUtils.cc TestBayesEnsemble.cc TestModulesVersions.cc TestBoostA2DE.cc TestMST.cc TestXSPODE.cc ${BayesNet_SOURCES})
    target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" fimdlp::fimdlp PRIVATE Catch2::Catch2WithMain)
    target_link_libraries(TestBayesNet PRIVATE torch::torch fimdlp::fimdlp Catch2::Catch2WithMain folding::folding)
    add_test(NAME BayesNetworkTest COMMAND TestBayesNet)
    add_test(NAME A2DE COMMAND TestBayesNet "[A2DE]")
    add_test(NAME BoostA2DE COMMAND TestBayesNet "[BoostA2DE]")
@@ -20,7 +20,7 @@
|
||||
#include "bayesnet/ensembles/AODELd.h"
|
||||
#include "bayesnet/ensembles/BoostAODE.h"
|
||||
|
||||
const std::string ACTUAL_VERSION = "1.1.2";
|
||||
const std::string ACTUAL_VERSION = "1.2.1";
|
||||
|
||||
TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
|
||||
{
|
||||
@@ -31,9 +31,9 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
|
||||
{{"diabetes", "SPODE"}, 0.802083},
|
||||
{{"diabetes", "TAN"}, 0.821615},
|
||||
{{"diabetes", "AODELd"}, 0.8125f},
|
||||
{{"diabetes", "KDBLd"}, 0.80208f},
|
||||
{{"diabetes", "KDBLd"}, 0.804688f},
|
||||
{{"diabetes", "SPODELd"}, 0.7890625f},
|
||||
{{"diabetes", "TANLd"}, 0.803385437f},
|
||||
{{"diabetes", "TANLd"}, 0.8125f},
|
||||
{{"diabetes", "BoostAODE"}, 0.83984f},
|
||||
// Ecoli
|
||||
{{"ecoli", "AODE"}, 0.889881},
|
||||
@@ -42,9 +42,9 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
|
||||
{{"ecoli", "SPODE"}, 0.880952},
|
||||
{{"ecoli", "TAN"}, 0.892857},
|
||||
{{"ecoli", "AODELd"}, 0.875f},
|
||||
{{"ecoli", "KDBLd"}, 0.880952358f},
|
||||
{{"ecoli", "KDBLd"}, 0.872024f},
|
||||
{{"ecoli", "SPODELd"}, 0.839285731f},
|
||||
{{"ecoli", "TANLd"}, 0.848214269f},
|
||||
{{"ecoli", "TANLd"}, 0.869047642f},
|
||||
{{"ecoli", "BoostAODE"}, 0.89583f},
|
||||
// Glass
|
||||
{{"glass", "AODE"}, 0.79439},
|
||||
@@ -53,9 +53,9 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
|
||||
{{"glass", "SPODE"}, 0.775701},
|
||||
{{"glass", "TAN"}, 0.827103},
|
||||
{{"glass", "AODELd"}, 0.799065411f},
|
||||
{{"glass", "KDBLd"}, 0.82710278f},
|
||||
{{"glass", "KDBLd"}, 0.864485979f},
|
||||
{{"glass", "SPODELd"}, 0.780373812f},
|
||||
{{"glass", "TANLd"}, 0.869158864f},
|
||||
{{"glass", "TANLd"}, 0.831775725f},
|
||||
{{"glass", "BoostAODE"}, 0.84579f},
|
||||
// Iris
|
||||
{{"iris", "AODE"}, 0.973333},
|
||||
@@ -68,29 +68,29 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
     {{"iris", "SPODELd"}, 0.96f},
     {{"iris", "TANLd"}, 0.97333f},
     {{"iris", "BoostAODE"}, 0.98f} };
-    std::map<std::string, bayesnet::BaseClassifier*> models{ {"AODE", new bayesnet::AODE()},
-        {"AODELd", new bayesnet::AODELd()},
-        {"BoostAODE", new bayesnet::BoostAODE()},
-        {"KDB", new bayesnet::KDB(2)},
-        {"KDBLd", new bayesnet::KDBLd(2)},
-        {"XSPODE", new bayesnet::XSpode(1)},
-        {"SPODE", new bayesnet::SPODE(1)},
-        {"SPODELd", new bayesnet::SPODELd(1)},
-        {"TAN", new bayesnet::TAN()},
-        {"TANLd", new bayesnet::TANLd()} };
+    std::map<std::string, std::unique_ptr<bayesnet::BaseClassifier>> models;
+    models["AODE"] = std::make_unique<bayesnet::AODE>();
+    models["AODELd"] = std::make_unique<bayesnet::AODELd>();
+    models["BoostAODE"] = std::make_unique<bayesnet::BoostAODE>();
+    models["KDB"] = std::make_unique<bayesnet::KDB>(2);
+    models["KDBLd"] = std::make_unique<bayesnet::KDBLd>(2);
+    models["XSPODE"] = std::make_unique<bayesnet::XSpode>(1);
+    models["SPODE"] = std::make_unique<bayesnet::SPODE>(1);
+    models["SPODELd"] = std::make_unique<bayesnet::SPODELd>(1);
+    models["TAN"] = std::make_unique<bayesnet::TAN>();
+    models["TANLd"] = std::make_unique<bayesnet::TANLd>();
     std::string name = GENERATE("AODE", "AODELd", "KDB", "KDBLd", "SPODE", "XSPODE", "SPODELd", "TAN", "TANLd");
-    auto clf = models[name];
+    auto clf = std::move(models[name]);
 
     SECTION("Test " + name + " classifier")
     {
         for (const std::string& file_name : { "glass", "iris", "ecoli", "diabetes" }) {
-            auto clf = models[name];
             auto discretize = name.substr(name.length() - 2) != "Ld";
             auto raw = RawDatasets(file_name, discretize);
             clf->fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
             auto score = clf->score(raw.Xt, raw.yt);
             // std::cout << "Classifier: " << name << " File: " << file_name << " Score: " << score << " expected = " <<
             // scores[{file_name, name}] << std::endl;
             INFO("Classifier: " << name << " File: " << file_name);
             REQUIRE(score == Catch::Approx(scores[{file_name, name}]).epsilon(raw.epsilon));
             REQUIRE(clf->getStatus() == bayesnet::NORMAL);
@@ -101,7 +101,6 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
         INFO("Checking version of " << name << " classifier");
         REQUIRE(clf->getVersion() == ACTUAL_VERSION);
     }
-    delete clf;
 }
 TEST_CASE("Models features & Graph", "[Models]")
 {
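The two hunks above swap the raw-pointer model registry, and its trailing `delete clf;`, for `std::unique_ptr` ownership. A minimal sketch of the pattern, with hypothetical `Base`/`A`/`B` types standing in for `bayesnet::BaseClassifier` and its concrete models:

```cpp
#include <map>
#include <memory>
#include <string>

struct Base {
    virtual ~Base() = default;
    virtual int id() const = 0;
};
struct A : Base { int id() const override { return 1; } };
struct B : Base { int id() const override { return 2; } };

int main()
{
    // The map owns every instance, so there is no trailing delete.
    std::map<std::string, std::unique_ptr<Base>> models;
    models["A"] = std::make_unique<A>();
    models["B"] = std::make_unique<B>();
    // Move the selected model out of the map; the entry left behind is null.
    auto clf = std::move(models["A"]);
    return clf->id() == 1 ? 0 : 1;
}
```

Once the map owns the instances, the `delete` at the end of the test goes away, as does the leak that a failing `REQUIRE` would have caused by skipping it.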
@@ -133,7 +132,7 @@ TEST_CASE("Models features & Graph", "[Models]")
     clf.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
     REQUIRE(clf.getNumberOfNodes() == 5);
     REQUIRE(clf.getNumberOfEdges() == 7);
-    REQUIRE(clf.getNumberOfStates() == 27);
+    REQUIRE(clf.getNumberOfStates() == 26);
     REQUIRE(clf.getClassNumStates() == 3);
     REQUIRE(clf.show() == std::vector<std::string>{"class -> sepallength, sepalwidth, petallength, petalwidth, ",
                                                    "petallength -> sepallength, ", "petalwidth -> ",
@@ -149,7 +148,6 @@ TEST_CASE("Get num features & num edges", "[Models]")
     REQUIRE(clf.getNumberOfNodes() == 5);
     REQUIRE(clf.getNumberOfEdges() == 8);
 }
-
 TEST_CASE("Model predict_proba", "[Models]")
 {
     std::string model = GENERATE("TAN", "SPODE", "BoostAODEproba", "BoostAODEvoting", "TANLd", "SPODELd", "KDBLd");
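This test, like the score test above, relies on Catch2's `GENERATE` to run the enclosing `TEST_CASE` once per listed value. A minimal sketch of that mechanism (assuming Catch2 v3 linked against `Catch2WithMain`):

```cpp
#include <catch2/catch_test_macros.hpp>
#include <catch2/generators/catch_generators.hpp>

TEST_CASE("One run per generated value")
{
    // Catch2 re-enters this TEST_CASE once for each value listed here,
    // so every model name gets its own independent test run.
    auto value = GENERATE(1, 2, 3);
    REQUIRE(value >= 1);
}
```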
@@ -180,15 +178,15 @@ TEST_CASE("Model predict_proba", "[Models]")
         {0.0284828, 0.770524, 0.200993},
         {0.0213182, 0.857189, 0.121493},
         {0.00868436, 0.949494, 0.0418215} });
-    auto res_prob_tanld = std::vector<std::vector<double>>({ {0.000544493, 0.995796, 0.00365992 },
-        {0.000908092, 0.997268, 0.00182429 },
-        {0.000908092, 0.997268, 0.00182429 },
-        {0.000908092, 0.997268, 0.00182429 },
-        {0.00228423, 0.994645, 0.00307078 },
-        {0.00120539, 0.0666788, 0.932116 },
-        {0.00361847, 0.979203, 0.017179 },
-        {0.00483293, 0.985326, 0.00984064 },
-        {0.000595606, 0.9977, 0.00170441 } });
+    auto res_prob_tanld = std::vector<std::vector<double>>({ {0.000597557, 0.9957, 0.00370254},
+        {0.000731377, 0.997914, 0.0013544},
+        {0.000731377, 0.997914, 0.0013544},
+        {0.000731377, 0.997914, 0.0013544},
+        {0.000838614, 0.998122, 0.00103923},
+        {0.00130852, 0.0659492, 0.932742},
+        {0.00365946, 0.979412, 0.0169281},
+        {0.00435035, 0.986248, 0.00940212},
+        {0.000583815, 0.997746, 0.00167066} });
     auto res_prob_spodeld = std::vector<std::vector<double>>({ {0.000908024, 0.993742, 0.00535024 },
         {0.00187726, 0.99167, 0.00645308 },
         {0.00187726, 0.99167, 0.00645308 },
@@ -216,29 +214,33 @@ TEST_CASE("Model predict_proba", "[Models]")
         {"TANLd", res_prob_tanld},
         {"SPODELd", res_prob_spodeld},
         {"KDBLd", res_prob_kdbld} };
-    std::map<std::string, bayesnet::BaseClassifier*> models{ {"TAN", new bayesnet::TAN()},
-        {"SPODE", new bayesnet::SPODE(0)},
-        {"BoostAODEproba", new bayesnet::BoostAODE(false)},
-        {"BoostAODEvoting", new bayesnet::BoostAODE(true)},
-        {"TANLd", new bayesnet::TANLd()},
-        {"SPODELd", new bayesnet::SPODELd(0)},
-        {"KDBLd", new bayesnet::KDBLd(2)} };
+    std::map<std::string, std::unique_ptr<bayesnet::BaseClassifier>> models;
+    models["TAN"] = std::make_unique<bayesnet::TAN>();
+    models["SPODE"] = std::make_unique<bayesnet::SPODE>(0);
+    models["BoostAODEproba"] = std::make_unique<bayesnet::BoostAODE>(false);
+    models["BoostAODEvoting"] = std::make_unique<bayesnet::BoostAODE>(true);
+    models["TANLd"] = std::make_unique<bayesnet::TANLd>();
+    models["SPODELd"] = std::make_unique<bayesnet::SPODELd>(0);
+    models["KDBLd"] = std::make_unique<bayesnet::KDBLd>(2);
+
     int init_index = 78;
+
     SECTION("Test " + model + " predict_proba")
     {
         INFO("Testing " << model << " predict_proba");
         auto ld_model = model.substr(model.length() - 2) == "Ld";
         auto discretize = !ld_model;
         auto raw = RawDatasets("iris", discretize);
-        auto clf = models[model];
-        clf->fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
-        auto yt_pred_proba = clf->predict_proba(raw.Xt);
-        auto yt_pred = clf->predict(raw.Xt);
+        auto& clf = *models[model];
+        clf.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
+        auto yt_pred_proba = clf.predict_proba(raw.Xt);
+        auto yt_pred = clf.predict(raw.Xt);
         std::vector<int> y_pred;
         std::vector<std::vector<double>> y_pred_proba;
         if (!ld_model) {
-            y_pred = clf->predict(raw.Xv);
-            y_pred_proba = clf->predict_proba(raw.Xv);
+            y_pred = clf.predict(raw.Xv);
+            y_pred_proba = clf.predict_proba(raw.Xv);
             REQUIRE(y_pred.size() == y_pred_proba.size());
             REQUIRE(y_pred.size() == yt_pred.size(0));
             REQUIRE(y_pred.size() == yt_pred_proba.size(0));
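Unlike the score test, this one leaves ownership inside the map and only borrows the selected model: `auto& clf = *models[model]` binds a reference to the pointee, so the call sites switch from `clf->` to `clf.`. A minimal sketch of the idiom (the `Model` type is a hypothetical stand-in):

```cpp
#include <map>
#include <memory>
#include <string>

struct Model {
    int predict() const { return 42; }
};

int main()
{
    std::map<std::string, std::unique_ptr<Model>> models;
    models["m"] = std::make_unique<Model>();
    // Borrow the pointee: ownership stays in the map, and the call site
    // uses value syntax (clf.predict()) instead of pointer syntax.
    auto& clf = *models["m"];
    return clf.predict() == 42 ? 0 : 1;
}
```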
@@ -267,18 +269,20 @@ TEST_CASE("Model predict_proba", "[Models]")
         } else {
             // Check predict_proba values for vectors and tensors
             auto predictedClasses = yt_pred_proba.argmax(1);
+            // std::cout << model << std::endl;
            for (int i = 0; i < 9; i++) {
                 REQUIRE(predictedClasses[i].item<int>() == yt_pred[i].item<int>());
+                // std::cout << "{";
                 for (int j = 0; j < 3; j++) {
+                    // std::cout << yt_pred_proba[i + init_index][j].item<double>() << ", ";
                     REQUIRE(res_prob[model][i][j] ==
                             Catch::Approx(yt_pred_proba[i + init_index][j].item<double>()).epsilon(raw.epsilon));
                 }
+                // std::cout << "\b\b}," << std::endl;
             }
         }
-        delete clf;
     }
 }
 
 TEST_CASE("AODE voting-proba", "[Models]")
 {
     auto raw = RawDatasets("glass", true);
@@ -297,17 +301,30 @@ TEST_CASE("AODE voting-proba", "[Models]")
     REQUIRE(pred_proba[67][0] == Catch::Approx(0.702184).epsilon(raw.epsilon));
     REQUIRE(clf.topological_order() == std::vector<std::string>());
 }
-TEST_CASE("SPODELd dataset", "[Models]")
+TEST_CASE("Ld models with dataset", "[Models]")
 {
     auto raw = RawDatasets("iris", false);
     auto clf = bayesnet::SPODELd(0);
-    // raw.dataset.to(torch::kFloat32);
     clf.fit(raw.dataset, raw.features, raw.className, raw.states, raw.smoothing);
     auto score = clf.score(raw.Xt, raw.yt);
     clf.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
     auto scoret = clf.score(raw.Xt, raw.yt);
     REQUIRE(score == Catch::Approx(0.97333f).epsilon(raw.epsilon));
     REQUIRE(scoret == Catch::Approx(0.97333f).epsilon(raw.epsilon));
+    auto clf2 = bayesnet::TANLd();
+    clf2.fit(raw.dataset, raw.features, raw.className, raw.states, raw.smoothing);
+    auto score2 = clf2.score(raw.Xt, raw.yt);
+    clf2.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
+    auto score2t = clf2.score(raw.Xt, raw.yt);
+    REQUIRE(score2 == Catch::Approx(0.97333f).epsilon(raw.epsilon));
+    REQUIRE(score2t == Catch::Approx(0.97333f).epsilon(raw.epsilon));
+    auto clf3 = bayesnet::KDBLd(2);
+    clf3.fit(raw.dataset, raw.features, raw.className, raw.states, raw.smoothing);
+    auto score3 = clf3.score(raw.Xt, raw.yt);
+    clf3.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
+    auto score3t = clf3.score(raw.Xt, raw.yt);
+    REQUIRE(score3 == Catch::Approx(0.97333f).epsilon(raw.epsilon));
+    REQUIRE(score3t == Catch::Approx(0.97333f).epsilon(raw.epsilon));
 }
 TEST_CASE("KDB with hyperparameters", "[Models]")
 {
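All of these score checks go through Catch2's `Approx` with a relative tolerance. A minimal sketch of the comparison idiom (assuming Catch2 v3; `raw.epsilon` in the tests plays the role of the literal tolerance here):

```cpp
#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>

TEST_CASE("Approx with a relative epsilon")
{
    double score = 0.973334; // e.g. an accuracy returned by score()
    // epsilon() sets a *relative* tolerance: the two values may differ by
    // up to 0.01% of the expected value and still compare equal.
    REQUIRE(score == Catch::Approx(0.97333).epsilon(1e-4));
}
```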
@@ -324,11 +341,15 @@ TEST_CASE("KDB with hyperparameters", "[Models]")
     REQUIRE(score == Catch::Approx(0.827103).epsilon(raw.epsilon));
     REQUIRE(scoret == Catch::Approx(0.761682).epsilon(raw.epsilon));
 }
-TEST_CASE("Incorrect type of data for SPODELd", "[Models]")
+TEST_CASE("Incorrect type of data for Ld models", "[Models]")
 {
     auto raw = RawDatasets("iris", true);
-    auto clf = bayesnet::SPODELd(0);
-    REQUIRE_THROWS_AS(clf.fit(raw.dataset, raw.features, raw.className, raw.states, raw.smoothing), std::runtime_error);
+    auto clfs = bayesnet::SPODELd(0);
+    REQUIRE_THROWS_AS(clfs.fit(raw.dataset, raw.features, raw.className, raw.states, raw.smoothing), std::runtime_error);
+    auto clft = bayesnet::TANLd();
+    REQUIRE_THROWS_AS(clft.fit(raw.dataset, raw.features, raw.className, raw.states, raw.smoothing), std::runtime_error);
+    auto clfk = bayesnet::KDBLd(0);
+    REQUIRE_THROWS_AS(clfk.fit(raw.dataset, raw.features, raw.className, raw.states, raw.smoothing), std::runtime_error);
 }
 TEST_CASE("Predict, predict_proba & score without fitting", "[Models]")
 {
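The renamed test now verifies that every Ld variant rejects pre-discretized input the same way. `REQUIRE_THROWS_AS` asserts both that the call throws and that the exception has the expected type; a minimal sketch (the function and its message are stand-ins, not the library's):

```cpp
#include <catch2/catch_test_macros.hpp>
#include <stdexcept>

// Stand-in for an Ld fit() that rejects wrongly typed input.
static void fit_with_wrong_dtype()
{
    throw std::runtime_error("wrong data type");
}

TEST_CASE("REQUIRE_THROWS_AS checks the exception type")
{
    // Passes only if the call throws and the exception is, or derives
    // from, std::runtime_error.
    REQUIRE_THROWS_AS(fit_with_wrong_dtype(), std::runtime_error);
}
```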
@@ -386,14 +407,15 @@ TEST_CASE("Check proposal checkInput", "[Models]")
 {
     class testProposal : public bayesnet::Proposal {
     public:
-        testProposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_)
-            : Proposal(dataset_, features_, className_)
+        testProposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_, std::vector<std::string>& notes_)
+            : Proposal(dataset_, features_, className_, notes_)
         {
         }
         void test_X_y(const torch::Tensor& X, const torch::Tensor& y) { checkInput(X, y); }
     };
     auto raw = RawDatasets("iris", true);
-    auto clf = testProposal(raw.dataset, raw.features, raw.className);
+    std::vector<std::string> notes;
+    auto clf = testProposal(raw.dataset, raw.features, raw.className, notes);
     torch::Tensor X = torch::randint(0, 3, { 10, 4 });
     torch::Tensor y = torch::rand({ 10 });
     INFO("Check X is not float");
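`checkInput` is a protected member of `Proposal`, so the test reaches it through a tiny subclass that forwards to it (`test_X_y` above). A minimal sketch of that pattern with a hypothetical `Base` class:

```cpp
// Hypothetical base with a protected validator, mirroring Proposal::checkInput.
class Base {
protected:
    bool validate(int x) const { return x >= 0; }
};

// The test derives a small class whose only job is to re-expose the member.
class TestableBase : public Base {
public:
    bool test_validate(int x) const { return validate(x); } // forwarding wrapper
};

int main()
{
    TestableBase t;
    return t.test_validate(3) ? 0 : 1;
}
```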
@@ -428,3 +450,49 @@ TEST_CASE("Check KDB loop detection", "[Models]")
     REQUIRE_NOTHROW(clf.test_add_m_edges(features, 0, S, weights));
     REQUIRE_NOTHROW(clf.test_add_m_edges(features, 1, S, weights));
 }
+TEST_CASE("Local discretization hyperparameters", "[Models]")
+{
+    auto raw = RawDatasets("iris", false);
+    auto clfs = bayesnet::SPODELd(0);
+    clfs.setHyperparameters({
+        {"max_iterations", 7},
+        {"verbose_convergence", true},
+    });
+    REQUIRE_NOTHROW(clfs.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing));
+    REQUIRE(clfs.getStatus() == bayesnet::NORMAL);
+    auto clfk = bayesnet::KDBLd(0);
+    clfk.setHyperparameters({
+        {"k", 3},
+        {"theta", 1e-4},
+    });
+    REQUIRE_NOTHROW(clfk.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing));
+    REQUIRE(clfk.getStatus() == bayesnet::NORMAL);
+    auto clfa = bayesnet::AODELd();
+    clfa.setHyperparameters({
+        {"ld_proposed_cuts", 9},
+        {"ld_algorithm", "BINQ"},
+    });
+    REQUIRE_NOTHROW(clfa.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing));
+    REQUIRE(clfa.getStatus() == bayesnet::NORMAL);
+    auto clft = bayesnet::TANLd();
+    clft.setHyperparameters({
+        {"ld_proposed_cuts", 7},
+        {"mdlp_max_depth", 5},
+        {"mdlp_min_length", 3},
+        {"ld_algorithm", "MDLP"},
+    });
+    REQUIRE_NOTHROW(clft.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing));
+    REQUIRE(clft.getStatus() == bayesnet::NORMAL);
+    clft.setHyperparameters({
+        {"ld_proposed_cuts", 9},
+        {"ld_algorithm", "BINQ"},
+    });
+    REQUIRE_NOTHROW(clft.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing));
+    REQUIRE(clft.getStatus() == bayesnet::NORMAL);
+    clft.setHyperparameters({
+        {"ld_proposed_cuts", 5},
+        {"ld_algorithm", "BINU"},
+    });
+    REQUIRE_NOTHROW(clft.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing));
+    REQUIRE(clft.getStatus() == bayesnet::NORMAL);
+}
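The braced initializers passed to `setHyperparameters` are `nlohmann::json` objects; the keys exercised here are the local-discretization hyperparameters under test. A minimal sketch of how such an object is built and walked (the validation loop is illustrative, not the library's implementation):

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

int main()
{
    // The braced initializer in setHyperparameters({...}) builds one of these.
    nlohmann::json hyperparameters = {
        {"ld_proposed_cuts", 7},
        {"mdlp_max_depth", 5},
        {"mdlp_min_length", 3},
        {"ld_algorithm", "MDLP"},
    };
    // A classifier would iterate the object and validate each known key.
    for (const auto& [key, value] : hyperparameters.items())
        std::cout << key << " = " << value << '\n';
}
```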
@@ -338,6 +338,190 @@ TEST_CASE("Test Bayesian Network", "[Network]")
         REQUIRE_THROWS_AS(net5.addEdge("A", "B"), std::logic_error);
         REQUIRE_THROWS_WITH(net5.addEdge("A", "B"), "Cannot add edge to a fitted network. Initialize first.");
     }
+    SECTION("Test assignment operator")
+    {
+        INFO("Test assignment operator");
+        // Create original network
+        auto net1 = bayesnet::Network();
+        buildModel(net1, raw.features, raw.className);
+        net1.fit(raw.Xv, raw.yv, raw.weightsv, raw.features, raw.className, raw.states, raw.smoothing);
+
+        // Create empty network and assign
+        auto net2 = bayesnet::Network();
+        net2.addNode("TempNode"); // Add something to make sure it gets cleared
+        net2 = net1;
+
+        // Verify they are equal
+        REQUIRE(net1.getFeatures() == net2.getFeatures());
+        REQUIRE(net1.getEdges() == net2.getEdges());
+        REQUIRE(net1.getNumEdges() == net2.getNumEdges());
+        REQUIRE(net1.getStates() == net2.getStates());
+        REQUIRE(net1.getClassName() == net2.getClassName());
+        REQUIRE(net1.getClassNumStates() == net2.getClassNumStates());
+        REQUIRE(net1.getSamples().size(0) == net2.getSamples().size(0));
+        REQUIRE(net1.getSamples().size(1) == net2.getSamples().size(1));
+        REQUIRE(net1.getNodes().size() == net2.getNodes().size());
+
+        // Verify topology equality
+        REQUIRE(net1 == net2);
+
+        // Verify they are separate objects by modifying one
+        net2.initialize();
+        net2.addNode("OnlyInNet2");
+        REQUIRE(net1.getNodes().size() != net2.getNodes().size());
+        REQUIRE_FALSE(net1 == net2);
+    }
SECTION("Test self assignment")
|
||||
{
|
||||
INFO("Test self assignment");
|
||||
buildModel(net, raw.features, raw.className);
|
||||
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
|
||||
int original_edges = net.getNumEdges();
|
||||
int original_nodes = net.getNodes().size();
|
||||
|
||||
// Self assignment should not corrupt the network
|
||||
net = net;
|
||||
auto all_features = raw.features;
|
||||
all_features.push_back(raw.className);
|
||||
REQUIRE(net.getNumEdges() == original_edges);
|
||||
REQUIRE(net.getNodes().size() == original_nodes);
|
||||
REQUIRE(net.getFeatures() == all_features);
|
||||
REQUIRE(net.getClassName() == raw.className);
|
||||
}
|
||||
SECTION("Test operator== topology comparison")
|
||||
{
|
||||
INFO("Test operator== topology comparison");
|
||||
|
||||
// Test 1: Two identical networks
|
||||
auto net1 = bayesnet::Network();
|
||||
auto net2 = bayesnet::Network();
|
||||
|
||||
net1.addNode("A");
|
||||
net1.addNode("B");
|
||||
net1.addNode("C");
|
||||
net1.addEdge("A", "B");
|
||||
net1.addEdge("B", "C");
|
||||
|
||||
net2.addNode("A");
|
||||
net2.addNode("B");
|
||||
net2.addNode("C");
|
||||
net2.addEdge("A", "B");
|
||||
net2.addEdge("B", "C");
|
||||
|
||||
REQUIRE(net1 == net2);
|
||||
|
||||
// Test 2: Different nodes
|
||||
auto net3 = bayesnet::Network();
|
||||
net3.addNode("A");
|
||||
net3.addNode("D"); // Different node
|
||||
REQUIRE_FALSE(net1 == net3);
|
||||
|
||||
// Test 3: Same nodes, different edges
|
||||
auto net4 = bayesnet::Network();
|
||||
net4.addNode("A");
|
||||
net4.addNode("B");
|
||||
net4.addNode("C");
|
||||
net4.addEdge("A", "C"); // Different topology
|
||||
net4.addEdge("B", "C");
|
||||
REQUIRE_FALSE(net1 == net4);
|
||||
|
||||
// Test 4: Empty networks
|
||||
auto net5 = bayesnet::Network();
|
||||
auto net6 = bayesnet::Network();
|
||||
REQUIRE(net5 == net6);
|
||||
|
||||
// Test 5: Same topology, different edge order
|
||||
auto net7 = bayesnet::Network();
|
||||
net7.addNode("A");
|
||||
net7.addNode("B");
|
||||
net7.addNode("C");
|
||||
net7.addEdge("B", "C"); // Add edges in different order
|
||||
net7.addEdge("A", "B");
|
||||
REQUIRE(net1 == net7); // Should still be equal
|
||||
}
|
||||
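Test 5 pins down an important property: `operator==` must compare the edge *sets*, not the insertion order. A minimal sketch of an order-insensitive comparison (hypothetical helper, not the library's code):

```cpp
#include <set>
#include <string>
#include <utility>
#include <vector>

using Edge = std::pair<std::string, std::string>;

// Normalizing both edge lists into sets makes the comparison independent
// of insertion order, which is exactly what Test 5 above demands.
bool sameTopology(const std::vector<Edge>& a, const std::vector<Edge>& b)
{
    return std::set<Edge>(a.begin(), a.end()) == std::set<Edge>(b.begin(), b.end());
}

int main()
{
    std::vector<Edge> e1{ {"A", "B"}, {"B", "C"} };
    std::vector<Edge> e2{ {"B", "C"}, {"A", "B"} }; // same edges, other order
    return sameTopology(e1, e2) ? 0 : 1;
}
```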
SECTION("Test RAII compliance with smart pointers")
|
||||
{
|
||||
INFO("Test RAII compliance with smart pointers");
|
||||
|
||||
std::unique_ptr<bayesnet::Network> net1 = std::make_unique<bayesnet::Network>();
|
||||
buildModel(*net1, raw.features, raw.className);
|
||||
net1->fit(raw.Xv, raw.yv, raw.weightsv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
|
||||
// Test that copy constructor works with smart pointers
|
||||
std::unique_ptr<bayesnet::Network> net2 = std::make_unique<bayesnet::Network>(*net1);
|
||||
|
||||
REQUIRE(*net1 == *net2);
|
||||
REQUIRE(net1->getNumEdges() == net2->getNumEdges());
|
||||
REQUIRE(net1->getNodes().size() == net2->getNodes().size());
|
||||
|
||||
// Destroy original
|
||||
net1.reset();
|
||||
|
||||
// Test predictions still work
|
||||
std::vector<std::vector<int>> test = { {1}, {2}, {0}, {1} };
|
||||
REQUIRE_NOTHROW(net2->predict(test));
|
||||
|
||||
// net2 should still be valid and functional
|
||||
net2->initialize();
|
||||
REQUIRE_NOTHROW(net2->addNode("NewNode"));
|
||||
REQUIRE(net2->getNodes().count("NewNode") == 1);
|
||||
}
|
||||
SECTION("Test complex topology copy")
|
||||
{
|
||||
INFO("Test complex topology copy");
|
||||
|
||||
auto original = bayesnet::Network();
|
||||
|
||||
// Create a more complex network
|
||||
original.addNode("Root");
|
||||
original.addNode("Child1");
|
||||
original.addNode("Child2");
|
||||
original.addNode("Grandchild1");
|
||||
original.addNode("Grandchild2");
|
||||
original.addNode("Grandchild3");
|
||||
|
||||
original.addEdge("Root", "Child1");
|
||||
original.addEdge("Root", "Child2");
|
||||
original.addEdge("Child1", "Grandchild1");
|
||||
original.addEdge("Child1", "Grandchild2");
|
||||
original.addEdge("Child2", "Grandchild3");
|
||||
|
||||
// Copy it
|
||||
auto copy = original;
|
||||
|
||||
// Verify topology is identical
|
||||
REQUIRE(original == copy);
|
||||
REQUIRE(original.getNodes().size() == copy.getNodes().size());
|
||||
REQUIRE(original.getNumEdges() == copy.getNumEdges());
|
||||
|
||||
// Verify edges are properly reconstructed
|
||||
auto originalEdges = original.getEdges();
|
||||
auto copyEdges = copy.getEdges();
|
||||
REQUIRE(originalEdges.size() == copyEdges.size());
|
||||
|
||||
// Verify node relationships are properly copied
|
||||
for (const auto& nodePair : original.getNodes()) {
|
||||
const std::string& nodeName = nodePair.first;
|
||||
auto* originalNode = nodePair.second.get();
|
||||
auto* copyNode = copy.getNodes().at(nodeName).get();
|
||||
|
||||
REQUIRE(originalNode->getParents().size() == copyNode->getParents().size());
|
||||
REQUIRE(originalNode->getChildren().size() == copyNode->getChildren().size());
|
||||
|
||||
// Verify parent names match
|
||||
for (size_t i = 0; i < originalNode->getParents().size(); ++i) {
|
||||
REQUIRE(originalNode->getParents()[i]->getName() ==
|
||||
copyNode->getParents()[i]->getName());
|
||||
}
|
||||
|
||||
// Verify child names match
|
||||
for (size_t i = 0; i < originalNode->getChildren().size(); ++i) {
|
||||
REQUIRE(originalNode->getChildren()[i]->getName() ==
|
||||
copyNode->getChildren()[i]->getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
TEST_CASE("Test and empty Node", "[Network]")
|
||||
|
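The copy, assignment, and RAII sections all hinge on one requirement: a copied network must clone the nodes it owns and then rewire every parent/child pointer into the clone, never back into the original. A minimal sketch of that two-pass deep copy plus copy-and-swap assignment, with hypothetical `Graph`/`Node` types (not the library's API):

```cpp
#include <map>
#include <memory>
#include <string>
#include <vector>

struct Node {
    std::string name;
    std::vector<Node*> parents, children;
    explicit Node(std::string n) : name(std::move(n)) {}
};

struct Graph {
    std::map<std::string, std::unique_ptr<Node>> nodes;
    Graph() = default;
    Graph(const Graph& other)
    {
        // 1) Clone the owned nodes without their links.
        for (const auto& [name, node] : other.nodes)
            nodes[name] = std::make_unique<Node>(name);
        // 2) Rebuild the links so they point into *this* copy, never back
        //    into the original's nodes.
        for (const auto& [name, node] : other.nodes) {
            for (auto* p : node->parents)
                nodes[name]->parents.push_back(nodes[p->name].get());
            for (auto* c : node->children)
                nodes[name]->children.push_back(nodes[c->name].get());
        }
    }
    Graph& operator=(const Graph& other)
    {
        // Copy-and-swap makes self-assignment (net = net) a safe no-op.
        if (this != &other) {
            Graph tmp(other);
            nodes.swap(tmp.nodes);
        }
        return *this;
    }
};

int main()
{
    Graph g;
    g.nodes["A"] = std::make_unique<Node>("A");
    g.nodes["B"] = std::make_unique<Node>("B");
    g.nodes["A"]->children.push_back(g.nodes["B"].get());
    g.nodes["B"]->parents.push_back(g.nodes["A"].get());
    Graph copy = g;
    g = g; // self assignment must not corrupt anything
    return copy.nodes["B"]->parents[0] == copy.nodes["A"].get() ? 0 : 1;
}
```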
@@ -159,3 +159,47 @@ TEST_CASE("TEST MinFill method", "[Node]")
     REQUIRE(node_3.minFill() == 3);
     REQUIRE(node_4.minFill() == 1);
 }
TEST_CASE("Test operator =", "[Node]")
|
||||
{
|
||||
// Generate a test to test the operator = of the Node class
|
||||
// Create a node with 3 parents and 2 children
|
||||
auto node = bayesnet::Node("N1");
|
||||
auto parent_1 = bayesnet::Node("P1");
|
||||
parent_1.setNumStates(3);
|
||||
auto child_1 = bayesnet::Node("H1");
|
||||
child_1.setNumStates(2);
|
||||
node.addParent(&parent_1);
|
||||
node.addChild(&child_1);
|
||||
// Create a cpt in the node using computeCPT
|
||||
auto dataset = torch::tensor({ {1, 0, 0, 1}, {0, 1, 2, 1}, {0, 1, 1, 0} });
|
||||
auto states = std::vector<int>({ 2, 3, 3 });
|
||||
auto features = std::vector<std::string>{ "N1", "P1", "H1" };
|
||||
auto className = std::string("Class");
|
||||
auto weights = torch::tensor({ 1.0, 1.0, 1.0, 1.0 }, torch::kDouble);
|
||||
node.setNumStates(2);
|
||||
node.computeCPT(dataset, features, 0.0, weights);
|
||||
// Get the cpt of the node
|
||||
auto cpt = node.getCPT();
|
||||
// Check that the cpt is not empty
|
||||
REQUIRE(cpt.numel() > 0);
|
||||
// Check that the cpt has the correct dimensions
|
||||
auto dimensions = cpt.sizes();
|
||||
REQUIRE(dimensions.size() == 2);
|
||||
REQUIRE(dimensions[0] == 2); // Number of states of the node
|
||||
REQUIRE(dimensions[1] == 3); // Number of states of the first parent
|
||||
// Create a copy of the node
|
||||
bayesnet::Node node_copy("XX");
|
||||
node_copy = node;
|
||||
// Check that the copy has not any parents or children
|
||||
auto parents = node_copy.getParents();
|
||||
auto children = node_copy.getChildren();
|
||||
REQUIRE(parents.size() == 0);
|
||||
REQUIRE(children.size() == 0);
|
||||
// Check that the copy has the same name
|
||||
REQUIRE(node_copy.getName() == "N1");
|
||||
// Check that the copy has the same cpt
|
||||
auto cpt_copy = node_copy.getCPT();
|
||||
REQUIRE(cpt_copy.equal(cpt));
|
||||
// Check that the copy has the same number of states
|
||||
REQUIRE(node_copy.getNumStates() == node.getNumStates());
|
||||
}
|
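Note the semantics this test locks in: `Node::operator=` copies the value state (name, number of states, CPT) but deliberately drops the parent/child links, since those raw pointers only make sense inside the owning network, which rewires them itself, as the complex-topology test above checks. A minimal sketch of an assignment operator with those semantics (hypothetical `Node`, not the library's):

```cpp
#include <string>
#include <vector>

class Node {
    std::string name;
    int numStates = 0;
    std::vector<Node*> parents, children;
public:
    explicit Node(std::string n) : name(std::move(n)) {}
    Node& operator=(const Node& other)
    {
        if (this != &other) {
            name = other.name;           // value state is copied
            numStates = other.numStates; // value state is copied
            parents.clear();             // links are dropped on purpose:
            children.clear();            // they belong to the owning network
        }
        return *this;
    }
    const std::string& getName() const { return name; }
};

int main()
{
    Node a("N1"), b("XX");
    b = a;
    return b.getName() == "N1" ? 0 : 1;
}
```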
@@ -16,10 +16,10 @@
 #include "TestUtils.h"
 
 std::map<std::string, std::string> modules = {
-    { "mdlp", "2.0.1" },
-    { "Folding", "1.1.1" },
+    { "mdlp", "2.1.1" },
+    { "Folding", "1.1.2" },
     { "json", "3.11" },
-    { "ArffFiles", "1.1.0" }
+    { "ArffFiles", "1.2.1" }
 };
 
 TEST_CASE("MDLP", "[Modules]")
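Each `TEST_CASE` in this file compares a dependency's reported version against the pinned map above, so a silent upgrade fails the suite. A minimal sketch of the pattern (the `somelib_version` accessor is hypothetical):

```cpp
#include <catch2/catch_test_macros.hpp>
#include <map>
#include <string>

static const std::map<std::string, std::string> pinned = {
    { "somelib", "2.1.1" },
};

// Stand-in for a dependency's version accessor (hypothetical).
static std::string somelib_version() { return "2.1.1"; }

TEST_CASE("Pinned dependency version", "[Modules]")
{
    // Fails as soon as the installed library reports a different version
    // than the one this suite was written against.
    REQUIRE(somelib_version() == pinned.at("somelib"));
}
```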
@@ -11,7 +11,7 @@
 #include <vector>
 #include <map>
 #include <tuple>
-#include <ArffFiles/ArffFiles.hpp>
+#include <ArffFiles.hpp>
 #include <fimdlp/CPPFImdlp.h>
 #include <folding.hpp>
 #include <bayesnet/network/Network.h>
21 vcpkg-configuration.json
@@ -1,21 +0,0 @@
-{
-  "default-registry": {
-    "kind": "git",
-    "baseline": "760bfd0c8d7c89ec640aec4df89418b7c2745605",
-    "repository": "https://github.com/microsoft/vcpkg"
-  },
-  "registries": [
-    {
-      "kind": "git",
-      "repository": "https://github.com/rmontanana/vcpkg-stash",
-      "baseline": "1ea69243c0e8b0de77c9d1dd6e1d7593ae7f3627",
-      "packages": [
-        "arff-files",
-        "fimdlp",
-        "libtorch-bin",
-        "bayesnet",
-        "folding"
-      ]
-    }
-  ]
-}
40 vcpkg.json
@@ -1,40 +0,0 @@
-{
-  "name": "bayesnet",
-  "version": "1.0.7",
-  "description": "Bayesian Network C++ Library",
-  "license": "MIT",
-  "dependencies": [
-    "arff-files",
-    "folding",
-    "fimdlp",
-    "libtorch-bin",
-    "nlohmann-json",
-    "catch2"
-  ],
-  "overrides": [
-    {
-      "name": "arff-files",
-      "version": "1.1.0"
-    },
-    {
-      "name": "fimdlp",
-      "version": "2.0.1"
-    },
-    {
-      "name": "libtorch-bin",
-      "version": "2.7.0"
-    },
-    {
-      "name": "folding",
-      "version": "1.1.1"
-    },
-    {
-      "name": "nlohmann-json",
-      "version": "3.11.3"
-    },
-    {
-      "name": "catch2",
-      "version": "3.8.1"
-    }
-  ]
-}