Add tests for Classifier class

This commit is contained in:
2024-04-08 01:25:14 +02:00
parent 9014649a0d
commit 50543e7929
6 changed files with 73 additions and 9 deletions

View File

@@ -75,11 +75,11 @@ namespace bayesnet {
if (torch::is_floating_point(dataset)) {
throw std::invalid_argument("dataset (X, y) must be of type Integer");
}
if (n != features.size()) {
throw std::invalid_argument("Classifier: X " + std::to_string(n) + " and features " + std::to_string(features.size()) + " must have the same number of features");
if (dataset.size(0) - 1 != features.size()) {
throw std::invalid_argument("Classifier: X " + std::to_string(dataset.size(0) - 1) + " and features " + std::to_string(features.size()) + " must have the same number of features");
}
if (states.find(className) == states.end()) {
throw std::invalid_argument("className not found in states");
throw std::invalid_argument("class name not found in states");
}
for (auto feature : features) {
if (states.find(feature) == states.end()) {
@@ -175,9 +175,9 @@ namespace bayesnet {
{
return model.topological_sort();
}
void Classifier::dump_cpt() const
std::string Classifier::dump_cpt() const
{
model.dump_cpt();
return model.dump_cpt();
}
void Classifier::setHyperparameters(const nlohmann::json& hyperparameters)
{

View File

@@ -30,7 +30,7 @@ namespace bayesnet {
std::vector<std::string> show() const override;
std::vector<std::string> topological_order() override;
std::vector<std::string> getNotes() const override { return notes; }
void dump_cpt() const override;
std::string dump_cpt() const override;
void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
protected:
bool fitted;