Fit PyWrap into BayesNet

This commit is contained in:
Ricardo Montañana Gómez 2023-11-13 11:13:32 +01:00
parent 6a23e2cc26
commit 431b3a3aa5
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
15 changed files with 48 additions and 40 deletions

View File

@ -26,7 +26,7 @@ namespace bayesnet {
int virtual getNumberOfStates() const = 0; int virtual getNumberOfStates() const = 0;
std::vector<std::string> virtual show() const = 0; std::vector<std::string> virtual show() const = 0;
std::vector<std::string> virtual graph(const std::string& title = "") const = 0; std::vector<std::string> virtual graph(const std::string& title = "") const = 0;
const std::string inline getVersion() const { return "0.2.0"; }; virtual std::string getVersion() = 0;
std::vector<std::string> virtual topological_order() = 0; std::vector<std::string> virtual topological_order() = 0;
void virtual dump_cpt()const = 0; void virtual dump_cpt()const = 0;
virtual void setHyperparameters(nlohmann::json& hyperparameters) = 0; virtual void setHyperparameters(nlohmann::json& hyperparameters) = 0;

View File

@ -3,6 +3,9 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet) include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet)
include_directories(${BayesNet_SOURCE_DIR}/src/Platform) include_directories(${BayesNet_SOURCE_DIR}/src/Platform)
include_directories(${BayesNet_SOURCE_DIR}/src/PyClassifiers)
include_directories(${Python3_INCLUDE_DIRS})
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc
KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANLd.cc KDBLd.cc SPODELd.cc AODELd.cc BoostAODE.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANLd.cc KDBLd.cc SPODELd.cc AODELd.cc BoostAODE.cc
Mst.cc Proposal.cc CFS.cc FCBF.cc IWSS.cc FeatureSelect.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc) Mst.cc Proposal.cc CFS.cc FCBF.cc IWSS.cc FeatureSelect.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc)

View File

@ -37,6 +37,7 @@ namespace bayesnet {
int getNumberOfStates() const override; int getNumberOfStates() const override;
torch::Tensor predict(torch::Tensor& X) override; torch::Tensor predict(torch::Tensor& X) override;
status_t getStatus() const override { return status; } status_t getStatus() const override { return status; }
std::string getVersion() override { return "0.2.0"; };
std::vector<int> predict(std::vector<std::vector<int>>& X) override; std::vector<int> predict(std::vector<std::vector<int>>& X) override;
float score(torch::Tensor& X, torch::Tensor& y) override; float score(torch::Tensor& X, torch::Tensor& y) override;
float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override; float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;

View File

@ -6,6 +6,7 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include) include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include)
include_directories(${Python3_INCLUDE_DIRS})
add_executable(b_main b_main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc ReportConsole.cc ReportBase.cc) add_executable(b_main b_main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc ReportConsole.cc ReportBase.cc)
add_executable(b_manage b_manage.cc Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) add_executable(b_manage b_manage.cc Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)

View File

@ -12,6 +12,9 @@
#include "AODELd.h" #include "AODELd.h"
#include "BoostAODE.h" #include "BoostAODE.h"
#include "STree.h" #include "STree.h"
#include "ODTE.h"
#include "SVC.h"
#include "RandomForest.h"
namespace platform { namespace platform {
class Models { class Models {
private: private:

View File

@ -18,6 +18,12 @@ static platform::Registrar registrarALD("AODELd",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODELd();}); [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODELd();});
static platform::Registrar registrarBA("BoostAODE", static platform::Registrar registrarBA("BoostAODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostAODE();}); [](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostAODE();});
static pywrap::Registrar registrarSt("STree", static platform::Registrar registrarSt("STree",
[](void) -> bayesnet::BaseClassifier* { return new pywrap::STree();}); [](void) -> bayesnet::BaseClassifier* { return new pywrap::STree();});
static platform::Registrar registrarOdte("Odte",
[](void) -> bayesnet::BaseClassifier* { return new pywrap::ODTE();});
static platform::Registrar registrarSvc("SVC",
[](void) -> bayesnet::BaseClassifier* { return new pywrap::SVC();});
static platform::Registrar registrarRaF("RandomForest",
[](void) -> bayesnet::BaseClassifier* { return new pywrap::RandomForest();});
#endif #endif

View File

@ -1,22 +0,0 @@
#ifndef CLASSIFIER_H
#define CLASSIFIER_H
#include <torch/torch.h>
#include "BaseClassifier.h"
#include <nlohmann/json.hpp>
#include <string>
#include <map>
#include <vector>
namespace pywrap {
class Classifier : bayesnet::BaseClassifier {
public:
Classifier() = default;
virtual ~Classifier() = default;
virtual Classifier& fit(torch::Tensor& X, torch::Tensor& y) = 0;
virtual std::string version() = 0;
virtual std::string sklearnVersion() = 0;
protected:
virtual void checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters) = 0;
};
} /* namespace pywrap */
#endif /* CLASSIFIER_H */

View File

@ -5,7 +5,7 @@ namespace pywrap {
{ {
return callMethodString("graph"); return callMethodString("graph");
} }
void ODTE::setHyperparameters(const nlohmann::json& hyperparameters) void ODTE::setHyperparameters(nlohmann::json& hyperparameters)
{ {
// Check if hyperparameters are valid // Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "n_jobs", "n_estimators", "random_state" }; const std::vector<std::string> validKeys = { "n_jobs", "n_estimators", "random_state" };

View File

@ -9,7 +9,7 @@ namespace pywrap {
ODTE() : PyClassifier("odte", "Odte") {}; ODTE() : PyClassifier("odte", "Odte") {};
~ODTE() = default; ~ODTE() = default;
std::string graph(); std::string graph();
void setHyperparameters(const nlohmann::json& hyperparameters) override; void setHyperparameters(nlohmann::json& hyperparameters) override;
}; };
} /* namespace pywrap */ } /* namespace pywrap */
#endif /* ODTE_H */ #endif /* ODTE_H */

View File

@ -74,15 +74,15 @@ namespace pywrap {
Py_XDECREF(incoming); Py_XDECREF(incoming);
return resultTensor; return resultTensor;
} }
double PyClassifier::score(torch::Tensor& X, torch::Tensor& y) float PyClassifier::score(torch::Tensor& X, torch::Tensor& y)
{ {
auto [Xn, yn] = tensors2numpy(X, y); auto [Xn, yn] = tensors2numpy(X, y);
CPyObject Xp = bp::incref(bp::object(Xn).ptr()); CPyObject Xp = bp::incref(bp::object(Xn).ptr());
CPyObject yp = bp::incref(bp::object(yn).ptr()); CPyObject yp = bp::incref(bp::object(yn).ptr());
auto result = pyWrap->score(id, Xp, yp); float result = pyWrap->score(id, Xp, yp);
return result; return result;
} }
void PyClassifier::setHyperparameters(const nlohmann::json& hyperparameters) void PyClassifier::setHyperparameters(nlohmann::json& hyperparameters)
{ {
// Check if hyperparameters are valid, default is no hyperparameters // Check if hyperparameters are valid, default is no hyperparameters
const std::vector<std::string> validKeys = { }; const std::vector<std::string> validKeys = { };

View File

@ -13,21 +13,37 @@
#include "TypeId.h" #include "TypeId.h"
namespace pywrap { namespace pywrap {
class PyClassifier : public Classifier { class PyClassifier : public bayesnet::BaseClassifier {
public: public:
PyClassifier(const std::string& module, const std::string& className); PyClassifier(const std::string& module, const std::string& className);
virtual ~PyClassifier(); virtual ~PyClassifier();
PyClassifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override { return *this; };
// X is nxm tensor, y is nx1 tensor
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override; PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override;
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y) override; PyClassifier& fit(torch::Tensor& X, torch::Tensor& y);
PyClassifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override { return *this; };
PyClassifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) { return *this; };
torch::Tensor predict(torch::Tensor& X) override; torch::Tensor predict(torch::Tensor& X) override;
double score(torch::Tensor& X, torch::Tensor& y) override; std::vector<int> predict(std::vector<std::vector<int >>& X) override { return std::vector<int>(); };
std::string version() override; float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override { return 0.0; };
std::string sklearnVersion() override; float score(torch::Tensor& X, torch::Tensor& y) override;
void setHyperparameters(nlohmann::json& hyperparameters) override;
std::string version();
std::string sklearnVersion();
std::string callMethodString(const std::string& method); std::string callMethodString(const std::string& method);
void setHyperparameters(const nlohmann::json& hyperparameters) override; std::string getVersion() override { return this->version(); };
int getNumberOfNodes()const override { return 0; };
int getNumberOfEdges()const override { return 0; };
int getNumberOfStates() const override { return 0; };
std::vector<std::string> show() const override { return std::vector<std::string>(); }
std::vector<std::string> graph(const std::string& title = "") const override { return std::vector<std::string>(); }
bayesnet::status_t getStatus() const override { return bayesnet::NORMAL; };
std::vector<std::string> topological_order() override { return std::vector<std::string>(); }
void dump_cpt() const override {};
protected: protected:
void checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters) override; void checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters);
nlohmann::json hyperparameters; nlohmann::json hyperparameters;
void trainModel(const torch::Tensor& weights) override {};
private: private:
PyWrap* pyWrap; PyWrap* pyWrap;
std::string module; std::string module;

View File

@ -5,7 +5,7 @@ namespace pywrap {
{ {
return callMethodString("graph"); return callMethodString("graph");
} }
void STree::setHyperparameters(const nlohmann::json& hyperparameters) void STree::setHyperparameters(nlohmann::json& hyperparameters)
{ {
// Check if hyperparameters are valid // Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "C", "n_jobs", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy" }; const std::vector<std::string> validKeys = { "C", "n_jobs", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy" };

View File

@ -9,7 +9,7 @@ namespace pywrap {
STree() : PyClassifier("stree", "Stree") {}; STree() : PyClassifier("stree", "Stree") {};
~STree() = default; ~STree() = default;
std::string graph(); std::string graph();
void setHyperparameters(const nlohmann::json& hyperparameters) override; void setHyperparameters(nlohmann::json& hyperparameters) override;
}; };
} /* namespace pywrap */ } /* namespace pywrap */
#endif /* STREE_H */ #endif /* STREE_H */

View File

@ -5,7 +5,7 @@ namespace pywrap {
{ {
return sklearnVersion(); return sklearnVersion();
} }
void SVC::setHyperparameters(const nlohmann::json& hyperparameters) void SVC::setHyperparameters(nlohmann::json& hyperparameters)
{ {
// Check if hyperparameters are valid // Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "C", "gamma", "kernel", "random_state" }; const std::vector<std::string> validKeys = { "C", "gamma", "kernel", "random_state" };

View File

@ -8,7 +8,7 @@ namespace pywrap {
SVC() : PyClassifier("sklearn.svm", "SVC") {}; SVC() : PyClassifier("sklearn.svm", "SVC") {};
~SVC() = default; ~SVC() = default;
std::string version(); std::string version();
void setHyperparameters(const nlohmann::json& hyperparameters) override; void setHyperparameters(nlohmann::json& hyperparameters) override;
}; };
} /* namespace pywrap */ } /* namespace pywrap */