Refactor Hyperparameters management
This commit is contained in:
@@ -6,8 +6,6 @@
|
||||
namespace bayesnet {
|
||||
enum status_t { NORMAL, WARNING, ERROR };
|
||||
class BaseClassifier {
|
||||
protected:
|
||||
virtual void trainModel(const torch::Tensor& weights) = 0;
|
||||
public:
|
||||
// X is nxm std::vector, y is nx1 std::vector
|
||||
virtual BaseClassifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0;
|
||||
@@ -30,6 +28,10 @@ namespace bayesnet {
|
||||
std::vector<std::string> virtual topological_order() = 0;
|
||||
void virtual dump_cpt()const = 0;
|
||||
virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0;
|
||||
std::vector<std::string>& getValidHyperparameters() { return validHyperparameters; }
|
||||
protected:
|
||||
virtual void trainModel(const torch::Tensor& weights) = 0;
|
||||
std::vector<std::string> validHyperparameters;
|
||||
};
|
||||
}
|
||||
#endif
|
@@ -10,7 +10,11 @@
|
||||
#include "IWSS.h"
|
||||
|
||||
namespace bayesnet {
|
||||
BoostAODE::BoostAODE() : Ensemble() {}
|
||||
BoostAODE::BoostAODE() : Ensemble()
|
||||
{
|
||||
validHyperparameters = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features" };
|
||||
|
||||
}
|
||||
void BoostAODE::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
// Models shall be built in trainModel
|
||||
@@ -45,9 +49,6 @@ namespace bayesnet {
|
||||
}
|
||||
void BoostAODE::setHyperparameters(const nlohmann::json& hyperparameters)
|
||||
{
|
||||
// Check if hyperparameters are valid
|
||||
const std::vector<std::string> validKeys = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features" };
|
||||
checkHyperparameters(validKeys, hyperparameters);
|
||||
if (hyperparameters.contains("repeatSparent")) {
|
||||
repeatSparent = hyperparameters["repeatSparent"];
|
||||
}
|
||||
|
@@ -153,18 +153,8 @@ namespace bayesnet {
|
||||
{
|
||||
model.dump_cpt();
|
||||
}
|
||||
void Classifier::checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters)
|
||||
{
|
||||
for (const auto& item : hyperparameters.items()) {
|
||||
if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
|
||||
throw std::invalid_argument("Hyperparameter " + item.key() + " is not valid");
|
||||
}
|
||||
}
|
||||
}
|
||||
void Classifier::setHyperparameters(const nlohmann::json& hyperparameters)
|
||||
{
|
||||
// Check if hyperparameters are valid, default is no hyperparameters
|
||||
const std::vector<std::string> validKeys = { };
|
||||
checkHyperparameters(validKeys, hyperparameters);
|
||||
//For classifiers that don't have hyperparameters
|
||||
}
|
||||
}
|
@@ -22,7 +22,6 @@ namespace bayesnet {
|
||||
void checkFitParameters();
|
||||
virtual void buildModel(const torch::Tensor& weights) = 0;
|
||||
void trainModel(const torch::Tensor& weights) override;
|
||||
void checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters);
|
||||
void buildDataset(torch::Tensor& y);
|
||||
public:
|
||||
Classifier(Network model);
|
||||
@@ -44,7 +43,7 @@ namespace bayesnet {
|
||||
std::vector<std::string> show() const override;
|
||||
std::vector<std::string> topological_order() override;
|
||||
void dump_cpt() const override;
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
|
||||
};
|
||||
}
|
||||
#endif
|
||||
|
@@ -1,12 +1,13 @@
|
||||
#include "KDB.h"
|
||||
|
||||
namespace bayesnet {
|
||||
KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta) {}
|
||||
KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta)
|
||||
{
|
||||
validHyperparameters = { "k", "theta" };
|
||||
|
||||
}
|
||||
void KDB::setHyperparameters(const nlohmann::json& hyperparameters)
|
||||
{
|
||||
// Check if hyperparameters are valid
|
||||
const std::vector<std::string> validKeys = { "k", "theta" };
|
||||
checkHyperparameters(validKeys, hyperparameters);
|
||||
if (hyperparameters.contains("k")) {
|
||||
k = hyperparameters["k"];
|
||||
}
|
||||
|
@@ -170,9 +170,9 @@ namespace platform {
|
||||
for (int nfold = 0; nfold < nfolds; nfold++) {
|
||||
auto clf = Models::instance()->create(model);
|
||||
setModelVersion(clf->getVersion());
|
||||
if (hyperparameters.notEmpty(fileName)) {
|
||||
clf->setHyperparameters(hyperparameters.get(fileName));
|
||||
}
|
||||
auto valid = clf->getValidHyperparameters();
|
||||
hyperparameters.check(valid, fileName);
|
||||
clf->setHyperparameters(hyperparameters.get(fileName));
|
||||
// Split train - test dataset
|
||||
train_timer.start();
|
||||
auto [train, test] = fold->getFold(nfold);
|
||||
|
@@ -1,5 +1,6 @@
|
||||
#include "HyperParameters.h"
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
namespace platform {
|
||||
HyperParameters::HyperParameters(const std::vector<std::string>& datasets, const json& hyperparameters_)
|
||||
@@ -21,13 +22,24 @@ namespace platform {
|
||||
// Check if hyperparameters are valid
|
||||
for (const auto& dataset : datasets) {
|
||||
if (!input_hyperparameters.contains(dataset)) {
|
||||
throw std::runtime_error("Dataset " + dataset + " not found in hyperparameters file");
|
||||
std::cerr << "*Warning: Dataset " << dataset << " not found in hyperparameters file" << " assuming default hyperparameters" << std::endl;
|
||||
hyperparameters[dataset] = json({});
|
||||
continue;
|
||||
}
|
||||
hyperparameters[dataset] = input_hyperparameters[dataset];
|
||||
hyperparameters[dataset] = input_hyperparameters[dataset].get<json>();
|
||||
}
|
||||
}
|
||||
json HyperParameters::get(const std::string& key)
|
||||
void HyperParameters::check(const std::vector<std::string>& valid, const std::string& fileName)
|
||||
{
|
||||
return hyperparameters.at(key);
|
||||
json result = hyperparameters.at(fileName);
|
||||
for (const auto& item : result.items()) {
|
||||
if (find(valid.begin(), valid.end(), item.key()) == valid.end()) {
|
||||
throw std::invalid_argument("Hyperparameter " + item.key() + " is not valid. Passed Hyperparameters are: " + result.dump(4));
|
||||
}
|
||||
}
|
||||
}
|
||||
json HyperParameters::get(const std::string& fileName)
|
||||
{
|
||||
return hyperparameters.at(fileName);
|
||||
}
|
||||
} /* namespace platform */
|
@@ -13,8 +13,9 @@ namespace platform {
|
||||
explicit HyperParameters(const std::vector<std::string>& datasets, const json& hyperparameters_);
|
||||
explicit HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file);
|
||||
~HyperParameters() = default;
|
||||
bool notEmpty(const std::string& key) const { return hyperparameters.at(key) != json(); }
|
||||
json get(const std::string& key);
|
||||
bool notEmpty(const std::string& key) const { return !hyperparameters.at(key).empty(); }
|
||||
void check(const std::vector<std::string>& valid, const std::string& fileName);
|
||||
json get(const std::string& fileName);
|
||||
private:
|
||||
std::map<std::string, json> hyperparameters;
|
||||
};
|
||||
|
@@ -1,15 +1,12 @@
|
||||
#include "ODTE.h"
|
||||
|
||||
namespace pywrap {
|
||||
ODTE::ODTE() : PyClassifier("odte", "Odte")
|
||||
{
|
||||
validHyperparameters = { "n_jobs", "n_estimators", "random_state" };
|
||||
}
|
||||
std::string ODTE::graph()
|
||||
{
|
||||
return callMethodString("graph");
|
||||
}
|
||||
void ODTE::setHyperparameters(const nlohmann::json& hyperparameters)
|
||||
{
|
||||
// Check if hyperparameters are valid
|
||||
const std::vector<std::string> validKeys = { "n_jobs", "n_estimators", "random_state" };
|
||||
checkHyperparameters(validKeys, hyperparameters);
|
||||
this->hyperparameters = hyperparameters;
|
||||
}
|
||||
} /* namespace pywrap */
|
@@ -6,10 +6,9 @@
|
||||
namespace pywrap {
|
||||
class ODTE : public PyClassifier {
|
||||
public:
|
||||
ODTE() : PyClassifier("odte", "Odte") {};
|
||||
ODTE();
|
||||
~ODTE() = default;
|
||||
std::string graph();
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||
};
|
||||
} /* namespace pywrap */
|
||||
#endif /* ODTE_H */
|
@@ -83,17 +83,6 @@ namespace pywrap {
|
||||
}
|
||||
void PyClassifier::setHyperparameters(const nlohmann::json& hyperparameters)
|
||||
{
|
||||
// Check if hyperparameters are valid, default is no hyperparameters
|
||||
const std::vector<std::string> validKeys = { };
|
||||
checkHyperparameters(validKeys, hyperparameters);
|
||||
this->hyperparameters = hyperparameters;
|
||||
}
|
||||
void PyClassifier::checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters)
|
||||
{
|
||||
for (const auto& item : hyperparameters.items()) {
|
||||
if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
|
||||
throw std::invalid_argument("Hyperparameter " + item.key() + " is not valid");
|
||||
}
|
||||
}
|
||||
}
|
||||
} /* namespace pywrap */
|
@@ -40,7 +40,6 @@ namespace pywrap {
|
||||
void dump_cpt() const override {};
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||
protected:
|
||||
void checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters);
|
||||
nlohmann::json hyperparameters;
|
||||
void trainModel(const torch::Tensor& weights) override {};
|
||||
private:
|
||||
|
@@ -1,11 +1,8 @@
|
||||
#include "RandomForest.h"
|
||||
|
||||
namespace pywrap {
|
||||
void RandomForest::setHyperparameters(const nlohmann::json& hyperparameters)
|
||||
RandomForest::RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier", true)
|
||||
{
|
||||
// Check if hyperparameters are valid
|
||||
const std::vector<std::string> validKeys = { "n_estimators", "n_jobs", "random_state" };
|
||||
checkHyperparameters(validKeys, hyperparameters);
|
||||
this->hyperparameters = hyperparameters;
|
||||
validHyperparameters = { "n_estimators", "n_jobs", "random_state" };
|
||||
}
|
||||
} /* namespace pywrap */
|
@@ -5,9 +5,8 @@
|
||||
namespace pywrap {
|
||||
class RandomForest : public PyClassifier {
|
||||
public:
|
||||
RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier", true) {};
|
||||
RandomForest();
|
||||
~RandomForest() = default;
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||
};
|
||||
} /* namespace pywrap */
|
||||
#endif /* RANDOMFOREST_H */
|
@@ -1,15 +1,12 @@
|
||||
#include "STree.h"
|
||||
|
||||
namespace pywrap {
|
||||
STree::STree() : PyClassifier("stree", "Stree")
|
||||
{
|
||||
validHyperparameters = { "C", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy", "gamma", "max_features", "degree" };
|
||||
};
|
||||
std::string STree::graph()
|
||||
{
|
||||
return callMethodString("graph");
|
||||
}
|
||||
void STree::setHyperparameters(const nlohmann::json& hyperparameters)
|
||||
{
|
||||
// Check if hyperparameters are valid
|
||||
const std::vector<std::string> validKeys = { "C", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy" };
|
||||
checkHyperparameters(validKeys, hyperparameters);
|
||||
this->hyperparameters = hyperparameters;
|
||||
}
|
||||
} /* namespace pywrap */
|
@@ -6,10 +6,9 @@
|
||||
namespace pywrap {
|
||||
class STree : public PyClassifier {
|
||||
public:
|
||||
STree() : PyClassifier("stree", "Stree") {};
|
||||
STree();
|
||||
~STree() = default;
|
||||
std::string graph();
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||
};
|
||||
} /* namespace pywrap */
|
||||
#endif /* STREE_H */
|
@@ -1,11 +1,8 @@
|
||||
#include "SVC.h"
|
||||
|
||||
namespace pywrap {
|
||||
void SVC::setHyperparameters(const nlohmann::json& hyperparameters)
|
||||
SVC::SVC() : PyClassifier("sklearn.svm", "SVC", true)
|
||||
{
|
||||
// Check if hyperparameters are valid
|
||||
const std::vector<std::string> validKeys = { "C", "gamma", "kernel", "random_state" };
|
||||
checkHyperparameters(validKeys, hyperparameters);
|
||||
this->hyperparameters = hyperparameters;
|
||||
validHyperparameters = { "C", "gamma", "kernel", "random_state" };
|
||||
}
|
||||
} /* namespace pywrap */
|
@@ -5,10 +5,9 @@
|
||||
namespace pywrap {
|
||||
class SVC : public PyClassifier {
|
||||
public:
|
||||
SVC() : PyClassifier("sklearn.svm", "SVC", true) {};
|
||||
SVC();
|
||||
~SVC() = default;
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||
};
|
||||
|
||||
} /* namespace pywrap */
|
||||
#endif /* STREE_H */
|
||||
#endif /* SVC_H */
|
Reference in New Issue
Block a user