Compare commits

..

9 Commits

24 changed files with 98 additions and 83 deletions

28
.vscode/launch.json vendored
View File

@@ -5,7 +5,7 @@
"type": "lldb", "type": "lldb",
"request": "launch", "request": "launch",
"name": "sample", "name": "sample",
"program": "${workspaceFolder}/build/sample/BayesNetSample", "program": "${workspaceFolder}/build_debug/sample/BayesNetSample",
"args": [ "args": [
"-d", "-d",
"iris", "iris",
@@ -14,7 +14,7 @@
"-s", "-s",
"271", "271",
"-p", "-p",
"/Users/rmontanana/Code/discretizbench/datasets/", "/home/rmontanana/Code/discretizbench/datasets/",
], ],
//"cwd": "${workspaceFolder}/build/sample/", //"cwd": "${workspaceFolder}/build/sample/",
}, },
@@ -22,24 +22,24 @@
"type": "lldb", "type": "lldb",
"request": "launch", "request": "launch",
"name": "experiment", "name": "experiment",
"program": "${workspaceFolder}/build/src/Platform/b_main", "program": "${workspaceFolder}/build_debug/src/Platform/b_main",
"args": [ "args": [
"-m", "-m",
"TAN", "STree",
"--stratified", "--stratified",
"-d", "-d",
"zoo", "iris",
"--discretize" //"--discretize"
// "--hyperparameters", // "--hyperparameters",
// "{\"repeatSparent\": true, \"maxModels\": 12}" // "{\"repeatSparent\": true, \"maxModels\": 12}"
], ],
"cwd": "/Users/rmontanana/Code/odtebench", "cwd": "/home/rmontanana/Code/discretizbench",
}, },
{ {
"type": "lldb", "type": "lldb",
"request": "launch", "request": "launch",
"name": "best", "name": "best",
"program": "${workspaceFolder}/build/src/Platform/b_best", "program": "${workspaceFolder}/build_debug/src/Platform/b_best",
"args": [ "args": [
"-m", "-m",
"BoostAODE", "BoostAODE",
@@ -47,24 +47,24 @@
"accuracy", "accuracy",
"--build", "--build",
], ],
"cwd": "/Users/rmontanana/Code/discretizbench", "cwd": "/home/rmontanana/Code/discretizbench",
}, },
{ {
"type": "lldb", "type": "lldb",
"request": "launch", "request": "launch",
"name": "manage", "name": "manage",
"program": "${workspaceFolder}/build/src/Platform/b_manage", "program": "${workspaceFolder}/build_debug/src/Platform/b_manage",
"args": [ "args": [
"-n", "-n",
"20" "20"
], ],
"cwd": "/Users/rmontanana/Code/discretizbench", "cwd": "/home/rmontanana/Code/discretizbench",
}, },
{ {
"type": "lldb", "type": "lldb",
"request": "launch", "request": "launch",
"name": "list", "name": "list",
"program": "${workspaceFolder}/build/src/Platform/b_list", "program": "${workspaceFolder}/build_debug/src/Platform/b_list",
"args": [], "args": [],
//"cwd": "/Users/rmontanana/Code/discretizbench", //"cwd": "/Users/rmontanana/Code/discretizbench",
"cwd": "/home/rmontanana/Code/covbench", "cwd": "/home/rmontanana/Code/covbench",
@@ -73,7 +73,7 @@
"type": "lldb", "type": "lldb",
"request": "launch", "request": "launch",
"name": "test", "name": "test",
"program": "${workspaceFolder}/build/tests/unit_tests", "program": "${workspaceFolder}/build_debug/tests/unit_tests",
"args": [ "args": [
"-c=\"Metrics Test\"", "-c=\"Metrics Test\"",
// "-s", // "-s",
@@ -84,7 +84,7 @@
"name": "Build & debug active file", "name": "Build & debug active file",
"type": "cppdbg", "type": "cppdbg",
"request": "launch", "request": "launch",
"program": "${workspaceFolder}/build/bayesnet", "program": "${workspaceFolder}/build_debug/bayesnet",
"args": [], "args": [],
"stopAtEntry": false, "stopAtEntry": false,
"cwd": "${workspaceFolder}", "cwd": "${workspaceFolder}",

View File

@@ -36,12 +36,16 @@ option(CODE_COVERAGE "Collect coverage from test library" OFF)
set(Boost_USE_STATIC_LIBS OFF) set(Boost_USE_STATIC_LIBS OFF)
set(Boost_USE_MULTITHREADED ON) set(Boost_USE_MULTITHREADED ON)
set(Boost_USE_STATIC_RUNTIME OFF) set(Boost_USE_STATIC_RUNTIME OFF)
find_package(Boost 1.66.0 REQUIRED) find_package(Boost 1.66.0 REQUIRED COMPONENTS python3 numpy3)
if(Boost_FOUND) if(Boost_FOUND)
message("Boost_INCLUDE_DIRS=${Boost_INCLUDE_DIRS}") message("Boost_INCLUDE_DIRS=${Boost_INCLUDE_DIRS}")
include_directories(${Boost_INCLUDE_DIRS}) include_directories(${Boost_INCLUDE_DIRS})
endif() endif()
# Python
find_package(Python3 3.11...3.11.9 COMPONENTS Interpreter Development REQUIRED)
message("Python3_LIBRARIES=${Python3_LIBRARIES}")
# CMakes modules # CMakes modules
# -------------- # --------------
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules ${CMAKE_MODULE_PATH}) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules ${CMAKE_MODULE_PATH})
@@ -77,6 +81,7 @@ add_subdirectory(config)
add_subdirectory(lib/Files) add_subdirectory(lib/Files)
add_subdirectory(src/BayesNet) add_subdirectory(src/BayesNet)
add_subdirectory(src/Platform) add_subdirectory(src/Platform)
add_subdirectory(src/PyClassifiers)
add_subdirectory(sample) add_subdirectory(sample)
file(GLOB BayesNet_HEADERS CONFIGURE_DEPENDS ${BayesNet_SOURCE_DIR}/src/BayesNet/*.h ${BayesNet_SOURCE_DIR}/BayesNet/*.h) file(GLOB BayesNet_HEADERS CONFIGURE_DEPENDS ${BayesNet_SOURCE_DIR}/src/BayesNet/*.h ${BayesNet_SOURCE_DIR}/BayesNet/*.h)

View File

@@ -1,5 +1,4 @@
#include "AODELd.h" #include "AODELd.h"
#include "Models.h"
namespace bayesnet { namespace bayesnet {
AODELd::AODELd() : Ensemble(), Proposal(dataset, features, className) {} AODELd::AODELd() : Ensemble(), Proposal(dataset, features, className) {}

View File

@@ -26,7 +26,7 @@ namespace bayesnet {
int virtual getNumberOfStates() const = 0; int virtual getNumberOfStates() const = 0;
std::vector<std::string> virtual show() const = 0; std::vector<std::string> virtual show() const = 0;
std::vector<std::string> virtual graph(const std::string& title = "") const = 0; std::vector<std::string> virtual graph(const std::string& title = "") const = 0;
const std::string inline getVersion() const { return "0.2.0"; }; virtual std::string getVersion() = 0;
std::vector<std::string> virtual topological_order() = 0; std::vector<std::string> virtual topological_order() = 0;
void virtual dump_cpt()const = 0; void virtual dump_cpt()const = 0;
virtual void setHyperparameters(nlohmann::json& hyperparameters) = 0; virtual void setHyperparameters(nlohmann::json& hyperparameters) = 0;

View File

@@ -108,8 +108,10 @@ namespace bayesnet {
void BoostAODE::trainModel(const torch::Tensor& weights) void BoostAODE::trainModel(const torch::Tensor& weights)
{ {
std::unordered_set<int> featuresUsed; std::unordered_set<int> featuresUsed;
int tolerance = 5; // number of times the accuracy can be lower than the threshold
if (selectFeatures) { if (selectFeatures) {
featuresUsed = initializeModels(); featuresUsed = initializeModels();
tolerance = 0; // Remove tolerance if features are selected
} }
if (maxModels == 0) if (maxModels == 0)
maxModels = .1 * n > 10 ? .1 * n : n; maxModels = .1 * n > 10 ? .1 * n : n;
@@ -119,7 +121,6 @@ namespace bayesnet {
double priorAccuracy = 0.0; double priorAccuracy = 0.0;
double delta = 1.0; double delta = 1.0;
double threshold = 1e-4; double threshold = 1e-4;
int tolerance = 5; // number of times the accuracy can be lower than the threshold
int count = 0; // number of times the accuracy is lower than the threshold int count = 0; // number of times the accuracy is lower than the threshold
fitted = true; // to enable predict fitted = true; // to enable predict
// Step 0: Set the finish condition // Step 0: Set the finish condition

View File

@@ -3,6 +3,9 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet) include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet)
include_directories(${BayesNet_SOURCE_DIR}/src/Platform) include_directories(${BayesNet_SOURCE_DIR}/src/Platform)
include_directories(${BayesNet_SOURCE_DIR}/src/PyClassifiers)
include_directories(${Python3_INCLUDE_DIRS})
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc
KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANLd.cc KDBLd.cc SPODELd.cc AODELd.cc BoostAODE.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANLd.cc KDBLd.cc SPODELd.cc AODELd.cc BoostAODE.cc
Mst.cc Proposal.cc CFS.cc FCBF.cc IWSS.cc FeatureSelect.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc) Mst.cc Proposal.cc CFS.cc FCBF.cc IWSS.cc FeatureSelect.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc)

View File

@@ -37,6 +37,7 @@ namespace bayesnet {
int getNumberOfStates() const override; int getNumberOfStates() const override;
torch::Tensor predict(torch::Tensor& X) override; torch::Tensor predict(torch::Tensor& X) override;
status_t getStatus() const override { return status; } status_t getStatus() const override { return status; }
std::string getVersion() override { return "0.2.0"; };
std::vector<int> predict(std::vector<std::vector<int>>& X) override; std::vector<int> predict(std::vector<std::vector<int>>& X) override;
float score(torch::Tensor& X, torch::Tensor& y) override; float score(torch::Tensor& X, torch::Tensor& y) override;
float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override; float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;

View File

@@ -1,17 +1,19 @@
include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet) include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet)
include_directories(${BayesNet_SOURCE_DIR}/src/Platform) include_directories(${BayesNet_SOURCE_DIR}/src/Platform)
include_directories(${BayesNet_SOURCE_DIR}/src/PyClassifiers)
include_directories(${BayesNet_SOURCE_DIR}/lib/Files) include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include) include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include)
include_directories(${Python3_INCLUDE_DIRS})
add_executable(b_main b_main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc ReportConsole.cc ReportBase.cc) add_executable(b_main b_main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc ReportConsole.cc ReportBase.cc)
add_executable(b_manage b_manage.cc Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) add_executable(b_manage b_manage.cc Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
add_executable(b_list b_list.cc Datasets.cc Dataset.cc) add_executable(b_list b_list.cc Datasets.cc Dataset.cc)
add_executable(b_best b_best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) add_executable(b_best b_best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}" PyWrap)
target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp) target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp)
target_link_libraries(b_best Boost::boost "${XLSXWRITER_LIB}" "${TORCH_LIBRARIES}" ArffFiles mdlp) target_link_libraries(b_best Boost::boost "${XLSXWRITER_LIB}" "${TORCH_LIBRARIES}" ArffFiles mdlp)
target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}") target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}")

View File

@@ -211,7 +211,6 @@ namespace platform {
result.addTimeTrain(train_time[item].item<double>()); result.addTimeTrain(train_time[item].item<double>());
result.addTimeTest(test_time[item].item<double>()); result.addTimeTest(test_time[item].item<double>());
item++; item++;
clf.reset();
} }
if (!quiet) if (!quiet)
std::cout << "end. " << flush; std::cout << "end. " << flush;

View File

@@ -11,6 +11,10 @@
#include "SPODELd.h" #include "SPODELd.h"
#include "AODELd.h" #include "AODELd.h"
#include "BoostAODE.h" #include "BoostAODE.h"
#include "STree.h"
#include "ODTE.h"
#include "SVC.h"
#include "RandomForest.h"
namespace platform { namespace platform {
class Models { class Models {
private: private:

View File

@@ -18,4 +18,12 @@ static platform::Registrar registrarALD("AODELd",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODELd();}); [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODELd();});
static platform::Registrar registrarBA("BoostAODE", static platform::Registrar registrarBA("BoostAODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostAODE();}); [](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostAODE();});
static platform::Registrar registrarSt("STree",
[](void) -> bayesnet::BaseClassifier* { return new pywrap::STree();});
static platform::Registrar registrarOdte("Odte",
[](void) -> bayesnet::BaseClassifier* { return new pywrap::ODTE();});
static platform::Registrar registrarSvc("SVC",
[](void) -> bayesnet::BaseClassifier* { return new pywrap::SVC();});
static platform::Registrar registrarRaF("RandomForest",
[](void) -> bayesnet::BaseClassifier* { return new pywrap::RandomForest();});
#endif #endif

View File

@@ -1,5 +1,6 @@
include_directories(${PyWrap_SOURCE_DIR}/lib/Files) include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
include_directories(${PyWrap_SOURCE_DIR}/lib/json/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet)
include_directories(${Python3_INCLUDE_DIRS}) include_directories(${Python3_INCLUDE_DIRS})
include_directories(${TORCH_INCLUDE_DIRS}) include_directories(${TORCH_INCLUDE_DIRS})

View File

@@ -1,25 +0,0 @@
#ifndef CLASSIFIER_H
#define CLASSIFIER_H
#include <torch/torch.h>
#include <nlohmann/json.hpp>
#include <string>
#include <map>
#include <vector>
namespace pywrap {
class Classifier {
public:
Classifier() = default;
virtual ~Classifier() = default;
virtual Classifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0;
virtual Classifier& fit(torch::Tensor& X, torch::Tensor& y) = 0;
virtual torch::Tensor predict(torch::Tensor& X) = 0;
virtual double score(torch::Tensor& X, torch::Tensor& y) = 0;
virtual std::string version() = 0;
virtual std::string sklearnVersion() = 0;
virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0;
protected:
virtual void checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters) = 0;
};
} /* namespace pywrap */
#endif /* CLASSIFIER_H */

View File

@@ -5,7 +5,7 @@ namespace pywrap {
{ {
return callMethodString("graph"); return callMethodString("graph");
} }
void ODTE::setHyperparameters(const nlohmann::json& hyperparameters) void ODTE::setHyperparameters(nlohmann::json& hyperparameters)
{ {
// Check if hyperparameters are valid // Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "n_jobs", "n_estimators", "random_state" }; const std::vector<std::string> validKeys = { "n_jobs", "n_estimators", "random_state" };

View File

@@ -9,7 +9,7 @@ namespace pywrap {
ODTE() : PyClassifier("odte", "Odte") {}; ODTE() : PyClassifier("odte", "Odte") {};
~ODTE() = default; ~ODTE() = default;
std::string graph(); std::string graph();
void setHyperparameters(const nlohmann::json& hyperparameters) override; void setHyperparameters(nlohmann::json& hyperparameters) override;
}; };
} /* namespace pywrap */ } /* namespace pywrap */
#endif /* ODTE_H */ #endif /* ODTE_H */

View File

@@ -2,7 +2,7 @@
namespace pywrap { namespace pywrap {
namespace bp = boost::python; namespace bp = boost::python;
namespace np = boost::python::numpy; namespace np = boost::python::numpy;
PyClassifier::PyClassifier(const std::string& module, const std::string& className) : module(module), className(className), fitted(false) PyClassifier::PyClassifier(const std::string& module, const std::string& className, bool sklearn) : module(module), className(className), sklearn(sklearn), fitted(false)
{ {
// This id allows to have more than one instance of the same module/class // This id allows to have more than one instance of the same module/class
id = reinterpret_cast<clfId_t>(this); id = reinterpret_cast<clfId_t>(this);
@@ -29,12 +29,11 @@ namespace pywrap {
} }
std::string PyClassifier::version() std::string PyClassifier::version()
{ {
if (sklearn) {
return pyWrap->sklearnVersion();
}
return pyWrap->version(id); return pyWrap->version(id);
} }
std::string PyClassifier::sklearnVersion()
{
return pyWrap->sklearnVersion();
}
std::string PyClassifier::callMethodString(const std::string& method) std::string PyClassifier::callMethodString(const std::string& method)
{ {
return pyWrap->callMethodString(id, method); return pyWrap->callMethodString(id, method);
@@ -74,15 +73,15 @@ namespace pywrap {
Py_XDECREF(incoming); Py_XDECREF(incoming);
return resultTensor; return resultTensor;
} }
double PyClassifier::score(torch::Tensor& X, torch::Tensor& y) float PyClassifier::score(torch::Tensor& X, torch::Tensor& y)
{ {
auto [Xn, yn] = tensors2numpy(X, y); auto [Xn, yn] = tensors2numpy(X, y);
CPyObject Xp = bp::incref(bp::object(Xn).ptr()); CPyObject Xp = bp::incref(bp::object(Xn).ptr());
CPyObject yp = bp::incref(bp::object(yn).ptr()); CPyObject yp = bp::incref(bp::object(yn).ptr());
auto result = pyWrap->score(id, Xp, yp); float result = pyWrap->score(id, Xp, yp);
return result; return result;
} }
void PyClassifier::setHyperparameters(const nlohmann::json& hyperparameters) void PyClassifier::setHyperparameters(nlohmann::json& hyperparameters)
{ {
// Check if hyperparameters are valid, default is no hyperparameters // Check if hyperparameters are valid, default is no hyperparameters
const std::vector<std::string> validKeys = { }; const std::vector<std::string> validKeys = { };

View File

@@ -13,25 +13,41 @@
#include "TypeId.h" #include "TypeId.h"
namespace pywrap { namespace pywrap {
class PyClassifier : public Classifier { class PyClassifier : public bayesnet::BaseClassifier {
public: public:
PyClassifier(const std::string& module, const std::string& className); PyClassifier(const std::string& module, const std::string& className, const bool sklearn = false);
virtual ~PyClassifier(); virtual ~PyClassifier();
PyClassifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override { return *this; };
// X is nxm tensor, y is nx1 tensor
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override; PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override;
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y) override; PyClassifier& fit(torch::Tensor& X, torch::Tensor& y);
PyClassifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override { return *this; };
PyClassifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) override { return *this; };
torch::Tensor predict(torch::Tensor& X) override; torch::Tensor predict(torch::Tensor& X) override;
double score(torch::Tensor& X, torch::Tensor& y) override; std::vector<int> predict(std::vector<std::vector<int >>& X) override { return std::vector<int>(); };
std::string version() override; float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override { return 0.0; };
std::string sklearnVersion() override; float score(torch::Tensor& X, torch::Tensor& y) override;
void setHyperparameters(nlohmann::json& hyperparameters) override;
std::string version();
std::string callMethodString(const std::string& method); std::string callMethodString(const std::string& method);
void setHyperparameters(const nlohmann::json& hyperparameters) override; std::string getVersion() override { return this->version(); };
int getNumberOfNodes()const override { return 0; };
int getNumberOfEdges()const override { return 0; };
int getNumberOfStates() const override { return 0; };
std::vector<std::string> show() const override { return std::vector<std::string>(); }
std::vector<std::string> graph(const std::string& title = "") const override { return std::vector<std::string>(); }
bayesnet::status_t getStatus() const override { return bayesnet::NORMAL; };
std::vector<std::string> topological_order() override { return std::vector<std::string>(); }
void dump_cpt() const override {};
protected: protected:
void checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters) override; void checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters);
nlohmann::json hyperparameters; nlohmann::json hyperparameters;
void trainModel(const torch::Tensor& weights) override {};
private: private:
PyWrap* pyWrap; PyWrap* pyWrap;
std::string module; std::string module;
std::string className; std::string className;
bool sklearn;
clfId_t id; clfId_t id;
bool fitted; bool fitted;
}; };

View File

@@ -5,6 +5,7 @@
#include <map> #include <map>
#include <sstream> #include <sstream>
#include <boost/python/numpy.hpp> #include <boost/python/numpy.hpp>
#include <iostream>
namespace pywrap { namespace pywrap {
namespace np = boost::python::numpy; namespace np = boost::python::numpy;
@@ -19,6 +20,7 @@ namespace pywrap {
if (wrapper == nullptr) { if (wrapper == nullptr) {
wrapper = new PyWrap(); wrapper = new PyWrap();
pyInstance = new CPyInstance(); pyInstance = new CPyInstance();
PyRun_SimpleString("import warnings;warnings.filterwarnings('ignore')");
} }
return wrapper; return wrapper;
} }
@@ -72,9 +74,11 @@ namespace pywrap {
PyErr_Print(); PyErr_Print();
errorAbort("Error cleaning module "); errorAbort("Error cleaning module ");
} }
if (moduleClassMap.empty()) { // With boost you can't remove the interpreter
RemoveInstance(); // https://www.boost.org/doc/libs/1_83_0/libs/python/doc/html/tutorial/tutorial/embedding.html#tutorial.embedding.getting_started
} // if (moduleClassMap.empty()) {
// RemoveInstance();
// }
} }
void PyWrap::errorAbort(const std::string& message) void PyWrap::errorAbort(const std::string& message)
{ {

View File

@@ -1,8 +1,11 @@
#include "RandomForest.h" #include "RandomForest.h"
namespace pywrap { namespace pywrap {
std::string RandomForest::version() void RandomForest::setHyperparameters(nlohmann::json& hyperparameters)
{ {
return sklearnVersion(); // Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "n_estimators", "n_jobs", "random_state" };
checkHyperparameters(validKeys, hyperparameters);
this->hyperparameters = hyperparameters;
} }
} /* namespace pywrap */ } /* namespace pywrap */

View File

@@ -5,9 +5,9 @@
namespace pywrap { namespace pywrap {
class RandomForest : public PyClassifier { class RandomForest : public PyClassifier {
public: public:
RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier") {}; RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier", true) {};
~RandomForest() = default; ~RandomForest() = default;
std::string version(); void setHyperparameters(nlohmann::json& hyperparameters) override;
}; };
} /* namespace pywrap */ } /* namespace pywrap */
#endif /* RANDOMFOREST_H */ #endif /* RANDOMFOREST_H */

View File

@@ -5,10 +5,10 @@ namespace pywrap {
{ {
return callMethodString("graph"); return callMethodString("graph");
} }
void STree::setHyperparameters(const nlohmann::json& hyperparameters) void STree::setHyperparameters(nlohmann::json& hyperparameters)
{ {
// Check if hyperparameters are valid // Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "C", "n_jobs", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy" }; const std::vector<std::string> validKeys = { "C", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy" };
checkHyperparameters(validKeys, hyperparameters); checkHyperparameters(validKeys, hyperparameters);
this->hyperparameters = hyperparameters; this->hyperparameters = hyperparameters;
} }

View File

@@ -9,7 +9,7 @@ namespace pywrap {
STree() : PyClassifier("stree", "Stree") {}; STree() : PyClassifier("stree", "Stree") {};
~STree() = default; ~STree() = default;
std::string graph(); std::string graph();
void setHyperparameters(const nlohmann::json& hyperparameters) override; void setHyperparameters(nlohmann::json& hyperparameters) override;
}; };
} /* namespace pywrap */ } /* namespace pywrap */
#endif /* STREE_H */ #endif /* STREE_H */

View File

@@ -1,11 +1,7 @@
#include "SVC.h" #include "SVC.h"
namespace pywrap { namespace pywrap {
std::string SVC::version() void SVC::setHyperparameters(nlohmann::json& hyperparameters)
{
return sklearnVersion();
}
void SVC::setHyperparameters(const nlohmann::json& hyperparameters)
{ {
// Check if hyperparameters are valid // Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "C", "gamma", "kernel", "random_state" }; const std::vector<std::string> validKeys = { "C", "gamma", "kernel", "random_state" };

View File

@@ -5,10 +5,9 @@
namespace pywrap { namespace pywrap {
class SVC : public PyClassifier { class SVC : public PyClassifier {
public: public:
SVC() : PyClassifier("sklearn.svm", "SVC") {}; SVC() : PyClassifier("sklearn.svm", "SVC", true) {};
~SVC() = default; ~SVC() = default;
std::string version(); void setHyperparameters(nlohmann::json& hyperparameters) override;
void setHyperparameters(const nlohmann::json& hyperparameters) override;
}; };
} /* namespace pywrap */ } /* namespace pywrap */