Add nodes/leaves/depth to STree & ODTE
This commit is contained in:
parent
4fefe9a1d2
commit
82964190f6
@ -5,6 +5,18 @@ namespace pywrap {
|
||||
{
|
||||
validHyperparameters = { "n_jobs", "n_estimators", "random_state" };
|
||||
}
|
||||
int ODTE::getNumberOfNodes() const
|
||||
{
|
||||
return callMethodInt("get_nodes");
|
||||
}
|
||||
int ODTE::getNumberOfEdges() const
|
||||
{
|
||||
return callMethodInt("get_leaves");
|
||||
}
|
||||
int ODTE::getNumberOfStates() const
|
||||
{
|
||||
return callMethodInt("get_depth");
|
||||
}
|
||||
std::string ODTE::graph()
|
||||
{
|
||||
return callMethodString("graph");
|
||||
|
@ -8,6 +8,9 @@ namespace pywrap {
|
||||
public:
|
||||
ODTE();
|
||||
~ODTE() = default;
|
||||
int getNumberOfNodes() const override;
|
||||
int getNumberOfEdges() const override;
|
||||
int getNumberOfStates() const override;
|
||||
std::string graph();
|
||||
};
|
||||
} /* namespace pywrap */
|
||||
|
@ -38,6 +38,10 @@ namespace pywrap {
|
||||
{
|
||||
return pyWrap->callMethodString(id, method);
|
||||
}
|
||||
int PyClassifier::callMethodInt(const std::string& method) const
|
||||
{
|
||||
return pyWrap->callMethodInt(id, method);
|
||||
}
|
||||
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y)
|
||||
{
|
||||
if (!fitted && hyperparameters.size() > 0) {
|
||||
|
@ -29,10 +29,11 @@ namespace pywrap {
|
||||
float score(torch::Tensor& X, torch::Tensor& y) override;
|
||||
std::string version();
|
||||
std::string callMethodString(const std::string& method);
|
||||
int callMethodInt(const std::string& method) const;
|
||||
std::string getVersion() override { return this->version(); };
|
||||
// TODO: Implement these 3 methods
|
||||
int getNumberOfNodes()const override { return 0; };
|
||||
int getNumberOfEdges()const override { return 0; };
|
||||
int getNumberOfNodes() const override { return 0; };
|
||||
int getNumberOfEdges() const override { return 0; };
|
||||
int getNumberOfStates() const override { return 0; };
|
||||
std::vector<std::string> show() const override { return std::vector<std::string>(); }
|
||||
std::vector<std::string> graph(const std::string& title = "") const override { return std::vector<std::string>(); }
|
||||
|
@ -110,6 +110,21 @@ namespace pywrap {
|
||||
Py_XDECREF(result);
|
||||
return value;
|
||||
}
|
||||
int PyWrap::callMethodInt(const clfId_t id, const std::string& method)
|
||||
{
|
||||
PyObject* instance = getClass(id);
|
||||
PyObject* result;
|
||||
try {
|
||||
if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL)))
|
||||
errorAbort("Couldn't call method " + method);
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
errorAbort(e.what());
|
||||
}
|
||||
int value = PyLong_AsLong(result);
|
||||
Py_XDECREF(result);
|
||||
return value;
|
||||
}
|
||||
std::string PyWrap::sklearnVersion()
|
||||
{
|
||||
return "1.0";
|
||||
|
@ -24,6 +24,7 @@ namespace pywrap {
|
||||
void operator=(const PyWrap&) = delete;
|
||||
~PyWrap() = default;
|
||||
std::string callMethodString(const clfId_t id, const std::string& method);
|
||||
int callMethodInt(const clfId_t id, const std::string& method);
|
||||
std::string sklearnVersion();
|
||||
std::string version(const clfId_t id);
|
||||
void setHyperparameters(const clfId_t id, const json& hyperparameters);
|
||||
|
@ -5,6 +5,18 @@ namespace pywrap {
|
||||
{
|
||||
validHyperparameters = { "C", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy", "gamma", "max_features", "degree" };
|
||||
};
|
||||
int STree::getNumberOfNodes() const
|
||||
{
|
||||
return callMethodInt("get_nodes");
|
||||
}
|
||||
int STree::getNumberOfEdges() const
|
||||
{
|
||||
return callMethodInt("get_leaves");
|
||||
}
|
||||
int STree::getNumberOfStates() const
|
||||
{
|
||||
return callMethodInt("get_depth");
|
||||
}
|
||||
std::string STree::graph()
|
||||
{
|
||||
return callMethodString("graph");
|
||||
|
@ -8,6 +8,9 @@ namespace pywrap {
|
||||
public:
|
||||
STree();
|
||||
~STree() = default;
|
||||
int getNumberOfNodes() const override;
|
||||
int getNumberOfEdges() const override;
|
||||
int getNumberOfStates() const override;
|
||||
std::string graph();
|
||||
};
|
||||
} /* namespace pywrap */
|
||||
|
Loading…
Reference in New Issue
Block a user