Fixed normal classifiers
This commit is contained in:
parent
06db8f51ce
commit
ef1bffcac3
3
.vscode/launch.json
vendored
3
.vscode/launch.json
vendored
@ -25,7 +25,8 @@
|
|||||||
"program": "${workspaceFolder}/build/src/Platform/main",
|
"program": "${workspaceFolder}/build/src/Platform/main",
|
||||||
"args": [
|
"args": [
|
||||||
"-m",
|
"-m",
|
||||||
"AODELd",
|
"AODE",
|
||||||
|
"--discretize",
|
||||||
"-p",
|
"-p",
|
||||||
"/Users/rmontanana/Code/discretizbench/datasets",
|
"/Users/rmontanana/Code/discretizbench/datasets",
|
||||||
"--stratified",
|
"--stratified",
|
||||||
|
@ -8,7 +8,7 @@ namespace bayesnet {
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
class AODELd : public Ensemble, public Proposal {
|
class AODELd : public Ensemble, public Proposal {
|
||||||
private:
|
private:
|
||||||
void trainModel();
|
void trainModel() override;
|
||||||
void buildModel() override;
|
void buildModel() override;
|
||||||
public:
|
public:
|
||||||
AODELd();
|
AODELd();
|
||||||
|
@ -10,26 +10,35 @@ namespace bayesnet {
|
|||||||
this->features = features;
|
this->features = features;
|
||||||
this->className = className;
|
this->className = className;
|
||||||
this->states = states;
|
this->states = states;
|
||||||
|
m = dataset.size(1);
|
||||||
|
n = dataset.size(0) - 1;
|
||||||
checkFitParameters();
|
checkFitParameters();
|
||||||
auto n_classes = states[className].size();
|
auto n_classes = states[className].size();
|
||||||
metrics = Metrics(dataset, features, className, n_classes);
|
metrics = Metrics(dataset, features, className, n_classes);
|
||||||
model.initialize();
|
model.initialize();
|
||||||
buildModel();
|
buildModel();
|
||||||
m = dataset.size(1);
|
|
||||||
n = dataset.size(0);
|
|
||||||
trainModel();
|
trainModel();
|
||||||
fitted = true;
|
fitted = true;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Classifier::buildDataset(Tensor& ytmp)
|
||||||
|
{
|
||||||
|
try {
|
||||||
|
auto yresized = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1);
|
||||||
|
dataset = torch::cat({ dataset, yresized }, 0);
|
||||||
|
}
|
||||||
|
catch (const std::exception& e) {
|
||||||
|
std::cerr << e.what() << '\n';
|
||||||
|
cout << "X dimensions: " << dataset.sizes() << "\n";
|
||||||
|
cout << "y dimensions: " << ytmp.sizes() << "\n";
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
void Classifier::trainModel()
|
void Classifier::trainModel()
|
||||||
{
|
{
|
||||||
model.fit(dataset, features, className);
|
model.fit(dataset, features, className);
|
||||||
}
|
}
|
||||||
void Classifier::buildDataset(Tensor& ytmp)
|
|
||||||
{
|
|
||||||
ytmp = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1);
|
|
||||||
dataset = torch::cat({ dataset, ytmp }, 0);
|
|
||||||
}
|
|
||||||
// X is nxm where n is the number of features and m the number of samples
|
// X is nxm where n is the number of features and m the number of samples
|
||||||
Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
|
Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
|
||||||
{
|
{
|
||||||
@ -56,7 +65,7 @@ namespace bayesnet {
|
|||||||
void Classifier::checkFitParameters()
|
void Classifier::checkFitParameters()
|
||||||
{
|
{
|
||||||
if (n != features.size()) {
|
if (n != features.size()) {
|
||||||
throw invalid_argument("X and features must have the same number of features");
|
throw invalid_argument("X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features");
|
||||||
}
|
}
|
||||||
if (states.find(className) == states.end()) {
|
if (states.find(className) == states.end()) {
|
||||||
throw invalid_argument("className not found in states");
|
throw invalid_argument("className not found in states");
|
||||||
|
@ -23,7 +23,7 @@ namespace bayesnet {
|
|||||||
map<string, vector<int>> states;
|
map<string, vector<int>> states;
|
||||||
void checkFitParameters();
|
void checkFitParameters();
|
||||||
virtual void buildModel() = 0;
|
virtual void buildModel() = 0;
|
||||||
void trainModel();
|
virtual void trainModel();
|
||||||
public:
|
public:
|
||||||
Classifier(Network model);
|
Classifier(Network model);
|
||||||
virtual ~Classifier() = default;
|
virtual ~Classifier() = default;
|
||||||
|
@ -14,7 +14,7 @@ namespace bayesnet {
|
|||||||
protected:
|
protected:
|
||||||
unsigned n_models;
|
unsigned n_models;
|
||||||
vector<unique_ptr<Classifier>> models;
|
vector<unique_ptr<Classifier>> models;
|
||||||
void trainModel();
|
void trainModel() override;
|
||||||
vector<int> voting(Tensor& y_pred);
|
vector<int> voting(Tensor& y_pred);
|
||||||
public:
|
public:
|
||||||
Ensemble();
|
Ensemble();
|
||||||
|
Loading…
Reference in New Issue
Block a user