Add nodes/leaves/depth to STree & ODTE

This commit is contained in:
Ricardo Montañana Gómez 2023-11-27 10:57:57 +01:00
parent 4fefe9a1d2
commit 82964190f6
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
8 changed files with 53 additions and 2 deletions

View File

@ -5,6 +5,18 @@ namespace pywrap {
{
validHyperparameters = { "n_jobs", "n_estimators", "random_state" };
}
int ODTE::getNumberOfNodes() const
{
return callMethodInt("get_nodes");
}
int ODTE::getNumberOfEdges() const
{
return callMethodInt("get_leaves");
}
int ODTE::getNumberOfStates() const
{
return callMethodInt("get_depth");
}
std::string ODTE::graph()
{
return callMethodString("graph");

View File

@ -8,6 +8,9 @@ namespace pywrap {
public:
ODTE();
~ODTE() = default;
int getNumberOfNodes() const override;
int getNumberOfEdges() const override;
int getNumberOfStates() const override;
std::string graph();
};
} /* namespace pywrap */

View File

@ -38,6 +38,10 @@ namespace pywrap {
{
return pyWrap->callMethodString(id, method);
}
int PyClassifier::callMethodInt(const std::string& method) const
{
return pyWrap->callMethodInt(id, method);
}
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y)
{
if (!fitted && hyperparameters.size() > 0) {

View File

@ -29,10 +29,11 @@ namespace pywrap {
float score(torch::Tensor& X, torch::Tensor& y) override;
std::string version();
std::string callMethodString(const std::string& method);
int callMethodInt(const std::string& method) const;
std::string getVersion() override { return this->version(); };
// TODO: Implement these 3 methods
int getNumberOfNodes()const override { return 0; };
int getNumberOfEdges()const override { return 0; };
int getNumberOfNodes() const override { return 0; };
int getNumberOfEdges() const override { return 0; };
int getNumberOfStates() const override { return 0; };
std::vector<std::string> show() const override { return std::vector<std::string>(); }
std::vector<std::string> graph(const std::string& title = "") const override { return std::vector<std::string>(); }

View File

@ -110,6 +110,21 @@ namespace pywrap {
Py_XDECREF(result);
return value;
}
int PyWrap::callMethodInt(const clfId_t id, const std::string& method)
{
PyObject* instance = getClass(id);
PyObject* result;
try {
if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL)))
errorAbort("Couldn't call method " + method);
}
catch (const std::exception& e) {
errorAbort(e.what());
}
int value = PyLong_AsLong(result);
Py_XDECREF(result);
return value;
}
std::string PyWrap::sklearnVersion()
{
return "1.0";

View File

@ -24,6 +24,7 @@ namespace pywrap {
void operator=(const PyWrap&) = delete;
~PyWrap() = default;
std::string callMethodString(const clfId_t id, const std::string& method);
int callMethodInt(const clfId_t id, const std::string& method);
std::string sklearnVersion();
std::string version(const clfId_t id);
void setHyperparameters(const clfId_t id, const json& hyperparameters);

View File

@ -5,6 +5,18 @@ namespace pywrap {
{
validHyperparameters = { "C", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy", "gamma", "max_features", "degree" };
};
int STree::getNumberOfNodes() const
{
return callMethodInt("get_nodes");
}
int STree::getNumberOfEdges() const
{
return callMethodInt("get_leaves");
}
int STree::getNumberOfStates() const
{
return callMethodInt("get_depth");
}
std::string STree::graph()
{
return callMethodString("graph");

View File

@ -8,6 +8,9 @@ namespace pywrap {
public:
STree();
~STree() = default;
int getNumberOfNodes() const override;
int getNumberOfEdges() const override;
int getNumberOfStates() const override;
std::string graph();
};
} /* namespace pywrap */