Add nodes/leaves/depth to STree & ODTE

This commit is contained in:
Ricardo Montañana Gómez 2023-11-27 10:57:57 +01:00
parent 4fefe9a1d2
commit 82964190f6
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
8 changed files with 53 additions and 2 deletions

View File

@ -5,6 +5,18 @@ namespace pywrap {
{ {
validHyperparameters = { "n_jobs", "n_estimators", "random_state" }; validHyperparameters = { "n_jobs", "n_estimators", "random_state" };
} }
int ODTE::getNumberOfNodes() const
{
return callMethodInt("get_nodes");
}
int ODTE::getNumberOfEdges() const
{
return callMethodInt("get_leaves");
}
int ODTE::getNumberOfStates() const
{
return callMethodInt("get_depth");
}
std::string ODTE::graph() std::string ODTE::graph()
{ {
return callMethodString("graph"); return callMethodString("graph");

View File

@ -8,6 +8,9 @@ namespace pywrap {
public: public:
ODTE(); ODTE();
~ODTE() = default; ~ODTE() = default;
int getNumberOfNodes() const override;
int getNumberOfEdges() const override;
int getNumberOfStates() const override;
std::string graph(); std::string graph();
}; };
} /* namespace pywrap */ } /* namespace pywrap */

View File

@ -38,6 +38,10 @@ namespace pywrap {
{ {
return pyWrap->callMethodString(id, method); return pyWrap->callMethodString(id, method);
} }
int PyClassifier::callMethodInt(const std::string& method) const
{
return pyWrap->callMethodInt(id, method);
}
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y) PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y)
{ {
if (!fitted && hyperparameters.size() > 0) { if (!fitted && hyperparameters.size() > 0) {

View File

@ -29,6 +29,7 @@ namespace pywrap {
float score(torch::Tensor& X, torch::Tensor& y) override; float score(torch::Tensor& X, torch::Tensor& y) override;
std::string version(); std::string version();
std::string callMethodString(const std::string& method); std::string callMethodString(const std::string& method);
int callMethodInt(const std::string& method) const;
std::string getVersion() override { return this->version(); }; std::string getVersion() override { return this->version(); };
// TODO: Implement these 3 methods // TODO: Implement these 3 methods
int getNumberOfNodes() const override { return 0; }; int getNumberOfNodes() const override { return 0; };

View File

@ -110,6 +110,21 @@ namespace pywrap {
Py_XDECREF(result); Py_XDECREF(result);
return value; return value;
} }
int PyWrap::callMethodInt(const clfId_t id, const std::string& method)
{
PyObject* instance = getClass(id);
PyObject* result;
try {
if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL)))
errorAbort("Couldn't call method " + method);
}
catch (const std::exception& e) {
errorAbort(e.what());
}
int value = PyLong_AsLong(result);
Py_XDECREF(result);
return value;
}
std::string PyWrap::sklearnVersion() std::string PyWrap::sklearnVersion()
{ {
return "1.0"; return "1.0";

View File

@ -24,6 +24,7 @@ namespace pywrap {
void operator=(const PyWrap&) = delete; void operator=(const PyWrap&) = delete;
~PyWrap() = default; ~PyWrap() = default;
std::string callMethodString(const clfId_t id, const std::string& method); std::string callMethodString(const clfId_t id, const std::string& method);
int callMethodInt(const clfId_t id, const std::string& method);
std::string sklearnVersion(); std::string sklearnVersion();
std::string version(const clfId_t id); std::string version(const clfId_t id);
void setHyperparameters(const clfId_t id, const json& hyperparameters); void setHyperparameters(const clfId_t id, const json& hyperparameters);

View File

@ -5,6 +5,18 @@ namespace pywrap {
{ {
validHyperparameters = { "C", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy", "gamma", "max_features", "degree" }; validHyperparameters = { "C", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy", "gamma", "max_features", "degree" };
}; };
int STree::getNumberOfNodes() const
{
return callMethodInt("get_nodes");
}
int STree::getNumberOfEdges() const
{
return callMethodInt("get_leaves");
}
int STree::getNumberOfStates() const
{
return callMethodInt("get_depth");
}
std::string STree::graph() std::string STree::graph()
{ {
return callMethodString("graph"); return callMethodString("graph");

View File

@ -8,6 +8,9 @@ namespace pywrap {
public: public:
STree(); STree();
~STree() = default; ~STree() = default;
int getNumberOfNodes() const override;
int getNumberOfEdges() const override;
int getNumberOfStates() const override;
std::string graph(); std::string graph();
}; };
} /* namespace pywrap */ } /* namespace pywrap */