Fixed normal classifiers

This commit is contained in:
Ricardo Montañana Gómez 2023-08-07 13:50:11 +02:00
parent 06db8f51ce
commit ef1bffcac3
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
5 changed files with 22 additions and 12 deletions

3
.vscode/launch.json vendored
View File

@ -25,7 +25,8 @@
"program": "${workspaceFolder}/build/src/Platform/main", "program": "${workspaceFolder}/build/src/Platform/main",
"args": [ "args": [
"-m", "-m",
"AODELd", "AODE",
"--discretize",
"-p", "-p",
"/Users/rmontanana/Code/discretizbench/datasets", "/Users/rmontanana/Code/discretizbench/datasets",
"--stratified", "--stratified",

View File

@ -8,7 +8,7 @@ namespace bayesnet {
using namespace std; using namespace std;
class AODELd : public Ensemble, public Proposal { class AODELd : public Ensemble, public Proposal {
private: private:
void trainModel(); void trainModel() override;
void buildModel() override; void buildModel() override;
public: public:
AODELd(); AODELd();

View File

@ -10,26 +10,35 @@ namespace bayesnet {
this->features = features; this->features = features;
this->className = className; this->className = className;
this->states = states; this->states = states;
m = dataset.size(1);
n = dataset.size(0) - 1;
checkFitParameters(); checkFitParameters();
auto n_classes = states[className].size(); auto n_classes = states[className].size();
metrics = Metrics(dataset, features, className, n_classes); metrics = Metrics(dataset, features, className, n_classes);
model.initialize(); model.initialize();
buildModel(); buildModel();
m = dataset.size(1);
n = dataset.size(0);
trainModel(); trainModel();
fitted = true; fitted = true;
return *this; return *this;
} }
void Classifier::buildDataset(Tensor& ytmp)
{
try {
auto yresized = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1);
dataset = torch::cat({ dataset, yresized }, 0);
}
catch (const std::exception& e) {
std::cerr << e.what() << '\n';
cout << "X dimensions: " << dataset.sizes() << "\n";
cout << "y dimensions: " << ytmp.sizes() << "\n";
exit(1);
}
}
void Classifier::trainModel() void Classifier::trainModel()
{ {
model.fit(dataset, features, className); model.fit(dataset, features, className);
} }
void Classifier::buildDataset(Tensor& ytmp)
{
ytmp = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1);
dataset = torch::cat({ dataset, ytmp }, 0);
}
// X is nxm where n is the number of features and m the number of samples // X is nxm where n is the number of features and m the number of samples
Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
{ {
@ -56,7 +65,7 @@ namespace bayesnet {
void Classifier::checkFitParameters() void Classifier::checkFitParameters()
{ {
if (n != features.size()) { if (n != features.size()) {
throw invalid_argument("X and features must have the same number of features"); throw invalid_argument("X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features");
} }
if (states.find(className) == states.end()) { if (states.find(className) == states.end()) {
throw invalid_argument("className not found in states"); throw invalid_argument("className not found in states");

View File

@ -23,7 +23,7 @@ namespace bayesnet {
map<string, vector<int>> states; map<string, vector<int>> states;
void checkFitParameters(); void checkFitParameters();
virtual void buildModel() = 0; virtual void buildModel() = 0;
void trainModel(); virtual void trainModel();
public: public:
Classifier(Network model); Classifier(Network model);
virtual ~Classifier() = default; virtual ~Classifier() = default;

View File

@ -14,7 +14,7 @@ namespace bayesnet {
protected: protected:
unsigned n_models; unsigned n_models;
vector<unique_ptr<Classifier>> models; vector<unique_ptr<Classifier>> models;
void trainModel(); void trainModel() override;
vector<int> voting(Tensor& y_pred); vector<int> voting(Tensor& y_pred);
public: public:
Ensemble(); Ensemble();