Fixed normal classifiers

This commit is contained in:
Ricardo Montañana Gómez 2023-08-07 13:50:11 +02:00
parent 06db8f51ce
commit ef1bffcac3
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
5 changed files with 22 additions and 12 deletions

3
.vscode/launch.json vendored
View File

@ -25,7 +25,8 @@
"program": "${workspaceFolder}/build/src/Platform/main",
"args": [
"-m",
"AODELd",
"AODE",
"--discretize",
"-p",
"/Users/rmontanana/Code/discretizbench/datasets",
"--stratified",

View File

@ -8,7 +8,7 @@ namespace bayesnet {
using namespace std;
class AODELd : public Ensemble, public Proposal {
private:
void trainModel();
void trainModel() override;
void buildModel() override;
public:
AODELd();

View File

@ -10,26 +10,35 @@ namespace bayesnet {
this->features = features;
this->className = className;
this->states = states;
m = dataset.size(1);
n = dataset.size(0) - 1;
checkFitParameters();
auto n_classes = states[className].size();
metrics = Metrics(dataset, features, className, n_classes);
model.initialize();
buildModel();
m = dataset.size(1);
n = dataset.size(0);
trainModel();
fitted = true;
return *this;
}
void Classifier::buildDataset(Tensor& ytmp)
{
try {
auto yresized = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1);
dataset = torch::cat({ dataset, yresized }, 0);
}
catch (const std::exception& e) {
std::cerr << e.what() << '\n';
cout << "X dimensions: " << dataset.sizes() << "\n";
cout << "y dimensions: " << ytmp.sizes() << "\n";
exit(1);
}
}
void Classifier::trainModel()
{
model.fit(dataset, features, className);
}
void Classifier::buildDataset(Tensor& ytmp)
{
ytmp = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1);
dataset = torch::cat({ dataset, ytmp }, 0);
}
// X is nxm where n is the number of features and m the number of samples
Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
{
@ -56,7 +65,7 @@ namespace bayesnet {
void Classifier::checkFitParameters()
{
if (n != features.size()) {
throw invalid_argument("X and features must have the same number of features");
throw invalid_argument("X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features");
}
if (states.find(className) == states.end()) {
throw invalid_argument("className not found in states");

View File

@ -23,7 +23,7 @@ namespace bayesnet {
map<string, vector<int>> states;
void checkFitParameters();
virtual void buildModel() = 0;
void trainModel();
virtual void trainModel();
public:
Classifier(Network model);
virtual ~Classifier() = default;

View File

@ -14,7 +14,7 @@ namespace bayesnet {
protected:
unsigned n_models;
vector<unique_ptr<Classifier>> models;
void trainModel();
void trainModel() override;
vector<int> voting(Tensor& y_pred);
public:
Ensemble();