Add tests for Classifier class

This commit is contained in:
2024-04-08 01:25:14 +02:00
parent 9014649a0d
commit 50543e7929
6 changed files with 73 additions and 9 deletions

View File

@@ -75,11 +75,11 @@ namespace bayesnet {
if (torch::is_floating_point(dataset)) {
throw std::invalid_argument("dataset (X, y) must be of type Integer");
}
if (n != features.size()) {
throw std::invalid_argument("Classifier: X " + std::to_string(n) + " and features " + std::to_string(features.size()) + " must have the same number of features");
if (dataset.size(0) - 1 != features.size()) {
throw std::invalid_argument("Classifier: X " + std::to_string(dataset.size(0) - 1) + " and features " + std::to_string(features.size()) + " must have the same number of features");
}
if (states.find(className) == states.end()) {
throw std::invalid_argument("className not found in states");
throw std::invalid_argument("class name not found in states");
}
for (auto feature : features) {
if (states.find(feature) == states.end()) {
@@ -175,9 +175,9 @@ namespace bayesnet {
{
return model.topological_sort();
}
void Classifier::dump_cpt() const
std::string Classifier::dump_cpt() const
{
model.dump_cpt();
return model.dump_cpt();
}
void Classifier::setHyperparameters(const nlohmann::json& hyperparameters)
{