Refactor Hyperparameters management

This commit is contained in:
2023-11-19 22:36:27 +01:00
parent 89c4613591
commit 4f3a04058f
21 changed files with 1070 additions and 78 deletions

View File

@@ -1,15 +1,12 @@
#include "ODTE.h"
namespace pywrap {
ODTE::ODTE() : PyClassifier("odte", "Odte")
{
validHyperparameters = { "n_jobs", "n_estimators", "random_state" };
}
std::string ODTE::graph()
{
return callMethodString("graph");
}
void ODTE::setHyperparameters(const nlohmann::json& hyperparameters)
{
// Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "n_jobs", "n_estimators", "random_state" };
checkHyperparameters(validKeys, hyperparameters);
this->hyperparameters = hyperparameters;
}
} /* namespace pywrap */

View File

@@ -6,10 +6,9 @@
namespace pywrap {
class ODTE : public PyClassifier {
public:
ODTE() : PyClassifier("odte", "Odte") {};
ODTE();
~ODTE() = default;
std::string graph();
void setHyperparameters(const nlohmann::json& hyperparameters) override;
};
} /* namespace pywrap */
#endif /* ODTE_H */

View File

@@ -83,17 +83,6 @@ namespace pywrap {
}
void PyClassifier::setHyperparameters(const nlohmann::json& hyperparameters)
{
// Check if hyperparameters are valid, default is no hyperparameters
const std::vector<std::string> validKeys = { };
checkHyperparameters(validKeys, hyperparameters);
this->hyperparameters = hyperparameters;
}
void PyClassifier::checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters)
{
for (const auto& item : hyperparameters.items()) {
if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
throw std::invalid_argument("Hyperparameter " + item.key() + " is not valid");
}
}
}
} /* namespace pywrap */

View File

@@ -40,7 +40,6 @@ namespace pywrap {
void dump_cpt() const override {};
void setHyperparameters(const nlohmann::json& hyperparameters) override;
protected:
void checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters);
nlohmann::json hyperparameters;
void trainModel(const torch::Tensor& weights) override {};
private:

View File

@@ -1,11 +1,8 @@
#include "RandomForest.h"
namespace pywrap {
void RandomForest::setHyperparameters(const nlohmann::json& hyperparameters)
RandomForest::RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier", true)
{
// Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "n_estimators", "n_jobs", "random_state" };
checkHyperparameters(validKeys, hyperparameters);
this->hyperparameters = hyperparameters;
validHyperparameters = { "n_estimators", "n_jobs", "random_state" };
}
} /* namespace pywrap */

View File

@@ -5,9 +5,8 @@
namespace pywrap {
class RandomForest : public PyClassifier {
public:
RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier", true) {};
RandomForest();
~RandomForest() = default;
void setHyperparameters(const nlohmann::json& hyperparameters) override;
};
} /* namespace pywrap */
#endif /* RANDOMFOREST_H */

View File

@@ -1,15 +1,12 @@
#include "STree.h"
namespace pywrap {
STree::STree() : PyClassifier("stree", "Stree")
{
validHyperparameters = { "C", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy", "gamma", "max_features", "degree" };
};
std::string STree::graph()
{
return callMethodString("graph");
}
void STree::setHyperparameters(const nlohmann::json& hyperparameters)
{
// Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "C", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy" };
checkHyperparameters(validKeys, hyperparameters);
this->hyperparameters = hyperparameters;
}
} /* namespace pywrap */

View File

@@ -6,10 +6,9 @@
namespace pywrap {
class STree : public PyClassifier {
public:
STree() : PyClassifier("stree", "Stree") {};
STree();
~STree() = default;
std::string graph();
void setHyperparameters(const nlohmann::json& hyperparameters) override;
};
} /* namespace pywrap */
#endif /* STREE_H */

View File

@@ -1,11 +1,8 @@
#include "SVC.h"
namespace pywrap {
void SVC::setHyperparameters(const nlohmann::json& hyperparameters)
SVC::SVC() : PyClassifier("sklearn.svm", "SVC", true)
{
// Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "C", "gamma", "kernel", "random_state" };
checkHyperparameters(validKeys, hyperparameters);
this->hyperparameters = hyperparameters;
validHyperparameters = { "C", "gamma", "kernel", "random_state" };
}
} /* namespace pywrap */

View File

@@ -5,10 +5,9 @@
namespace pywrap {
class SVC : public PyClassifier {
public:
SVC() : PyClassifier("sklearn.svm", "SVC", true) {};
SVC();
~SVC() = default;
void setHyperparameters(const nlohmann::json& hyperparameters) override;
};
} /* namespace pywrap */
#endif /* STREE_H */
#endif /* SVC_H */