Refactor into library

This commit is contained in:
2023-11-10 11:24:27 +01:00
parent 55286168cb
commit 74fb0968c7
5 changed files with 7 additions and 24 deletions

View File

@@ -3,6 +3,9 @@ include_directories(${PyWrap_SOURCE_DIR}/lib/json/include)
include_directories(${Python3_INCLUDE_DIRS})
include_directories(${TORCH_INCLUDE_DIRS})
add_executable(main main.cc STree.cc SVC.cc RandomForest.cc PyClassifier.cc PyWrap.cc)
target_link_libraries(main ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::boost Boost::python Boost::numpy ArffFiles)
add_library(PyWrap SHARED PyWrap.cc STree.cc SVC.cc RandomForest.cc PyClassifier.cc)
target_link_libraries(PyWrap ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::boost Boost::python Boost::numpy ArffFiles)
add_executable(example example.cc)
target_link_libraries(example PyWrap)

View File

@@ -1,5 +1,4 @@
#include "PyClassifier.h"
#include <iostream>
namespace pywrap {
namespace bp = boost::python;
namespace np = boost::python::numpy;
@@ -7,7 +6,6 @@ namespace pywrap {
{
// This id allows to have more than one instance of the same module/class
id = reinterpret_cast<clfId_t>(this);
std::cout << "PyClassifier: Creating instance of " << module << " and class " << className << " id " << id << std::endl;
pyWrap = PyWrap::GetInstance();
pyWrap->importClass(id, module, className);
}
@@ -40,7 +38,6 @@ namespace pywrap {
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
{
if (!fitted && hyperparameters.size() > 0) {
std::cout << "PyClassifier: Setting hyperparameters" << std::endl;
pyWrap->setHyperparameters(id, hyperparameters);
}
auto [Xn, yn] = tensors2numpy(X, y);

View File

@@ -3,7 +3,6 @@
#include "PyWrap.h"
#include <string>
#include <map>
#include <iostream>
#include <sstream>
#include <boost/python/numpy.hpp>
@@ -117,7 +116,6 @@ namespace pywrap {
void PyWrap::setHyperparameters(const clfId_t id, const json& hyperparameters)
{
// Set hyperparameters as attributes of the class
std::cout << "Building dictionary of arguments" << std::endl;
PyObject* pValue;
PyObject* instance = getClass(id);
for (const auto& [key, value] : hyperparameters.items()) {