Add nodes/leaves/depth to STree & ODTE
This commit is contained in:
parent
4fefe9a1d2
commit
82964190f6
@ -5,6 +5,18 @@ namespace pywrap {
|
|||||||
{
|
{
|
||||||
validHyperparameters = { "n_jobs", "n_estimators", "random_state" };
|
validHyperparameters = { "n_jobs", "n_estimators", "random_state" };
|
||||||
}
|
}
|
||||||
|
int ODTE::getNumberOfNodes() const
|
||||||
|
{
|
||||||
|
return callMethodInt("get_nodes");
|
||||||
|
}
|
||||||
|
int ODTE::getNumberOfEdges() const
|
||||||
|
{
|
||||||
|
return callMethodInt("get_leaves");
|
||||||
|
}
|
||||||
|
int ODTE::getNumberOfStates() const
|
||||||
|
{
|
||||||
|
return callMethodInt("get_depth");
|
||||||
|
}
|
||||||
std::string ODTE::graph()
|
std::string ODTE::graph()
|
||||||
{
|
{
|
||||||
return callMethodString("graph");
|
return callMethodString("graph");
|
||||||
|
@ -8,6 +8,9 @@ namespace pywrap {
|
|||||||
public:
|
public:
|
||||||
ODTE();
|
ODTE();
|
||||||
~ODTE() = default;
|
~ODTE() = default;
|
||||||
|
int getNumberOfNodes() const override;
|
||||||
|
int getNumberOfEdges() const override;
|
||||||
|
int getNumberOfStates() const override;
|
||||||
std::string graph();
|
std::string graph();
|
||||||
};
|
};
|
||||||
} /* namespace pywrap */
|
} /* namespace pywrap */
|
||||||
|
@ -38,6 +38,10 @@ namespace pywrap {
|
|||||||
{
|
{
|
||||||
return pyWrap->callMethodString(id, method);
|
return pyWrap->callMethodString(id, method);
|
||||||
}
|
}
|
||||||
|
int PyClassifier::callMethodInt(const std::string& method) const
|
||||||
|
{
|
||||||
|
return pyWrap->callMethodInt(id, method);
|
||||||
|
}
|
||||||
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y)
|
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y)
|
||||||
{
|
{
|
||||||
if (!fitted && hyperparameters.size() > 0) {
|
if (!fitted && hyperparameters.size() > 0) {
|
||||||
|
@ -29,6 +29,7 @@ namespace pywrap {
|
|||||||
float score(torch::Tensor& X, torch::Tensor& y) override;
|
float score(torch::Tensor& X, torch::Tensor& y) override;
|
||||||
std::string version();
|
std::string version();
|
||||||
std::string callMethodString(const std::string& method);
|
std::string callMethodString(const std::string& method);
|
||||||
|
int callMethodInt(const std::string& method) const;
|
||||||
std::string getVersion() override { return this->version(); };
|
std::string getVersion() override { return this->version(); };
|
||||||
// TODO: Implement these 3 methods
|
// TODO: Implement these 3 methods
|
||||||
int getNumberOfNodes() const override { return 0; };
|
int getNumberOfNodes() const override { return 0; };
|
||||||
|
@ -110,6 +110,21 @@ namespace pywrap {
|
|||||||
Py_XDECREF(result);
|
Py_XDECREF(result);
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
int PyWrap::callMethodInt(const clfId_t id, const std::string& method)
|
||||||
|
{
|
||||||
|
PyObject* instance = getClass(id);
|
||||||
|
PyObject* result;
|
||||||
|
try {
|
||||||
|
if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL)))
|
||||||
|
errorAbort("Couldn't call method " + method);
|
||||||
|
}
|
||||||
|
catch (const std::exception& e) {
|
||||||
|
errorAbort(e.what());
|
||||||
|
}
|
||||||
|
int value = PyLong_AsLong(result);
|
||||||
|
Py_XDECREF(result);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
std::string PyWrap::sklearnVersion()
|
std::string PyWrap::sklearnVersion()
|
||||||
{
|
{
|
||||||
return "1.0";
|
return "1.0";
|
||||||
|
@ -24,6 +24,7 @@ namespace pywrap {
|
|||||||
void operator=(const PyWrap&) = delete;
|
void operator=(const PyWrap&) = delete;
|
||||||
~PyWrap() = default;
|
~PyWrap() = default;
|
||||||
std::string callMethodString(const clfId_t id, const std::string& method);
|
std::string callMethodString(const clfId_t id, const std::string& method);
|
||||||
|
int callMethodInt(const clfId_t id, const std::string& method);
|
||||||
std::string sklearnVersion();
|
std::string sklearnVersion();
|
||||||
std::string version(const clfId_t id);
|
std::string version(const clfId_t id);
|
||||||
void setHyperparameters(const clfId_t id, const json& hyperparameters);
|
void setHyperparameters(const clfId_t id, const json& hyperparameters);
|
||||||
|
@ -5,6 +5,18 @@ namespace pywrap {
|
|||||||
{
|
{
|
||||||
validHyperparameters = { "C", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy", "gamma", "max_features", "degree" };
|
validHyperparameters = { "C", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy", "gamma", "max_features", "degree" };
|
||||||
};
|
};
|
||||||
|
int STree::getNumberOfNodes() const
|
||||||
|
{
|
||||||
|
return callMethodInt("get_nodes");
|
||||||
|
}
|
||||||
|
int STree::getNumberOfEdges() const
|
||||||
|
{
|
||||||
|
return callMethodInt("get_leaves");
|
||||||
|
}
|
||||||
|
int STree::getNumberOfStates() const
|
||||||
|
{
|
||||||
|
return callMethodInt("get_depth");
|
||||||
|
}
|
||||||
std::string STree::graph()
|
std::string STree::graph()
|
||||||
{
|
{
|
||||||
return callMethodString("graph");
|
return callMethodString("graph");
|
||||||
|
@ -8,6 +8,9 @@ namespace pywrap {
|
|||||||
public:
|
public:
|
||||||
STree();
|
STree();
|
||||||
~STree() = default;
|
~STree() = default;
|
||||||
|
int getNumberOfNodes() const override;
|
||||||
|
int getNumberOfEdges() const override;
|
||||||
|
int getNumberOfStates() const override;
|
||||||
std::string graph();
|
std::string graph();
|
||||||
};
|
};
|
||||||
} /* namespace pywrap */
|
} /* namespace pywrap */
|
||||||
|
Loading…
Reference in New Issue
Block a user