Fixed normal classifiers
This commit is contained in:
parent
06db8f51ce
commit
ef1bffcac3
3
.vscode/launch.json
vendored
3
.vscode/launch.json
vendored
@ -25,7 +25,8 @@
|
||||
"program": "${workspaceFolder}/build/src/Platform/main",
|
||||
"args": [
|
||||
"-m",
|
||||
"AODELd",
|
||||
"AODE",
|
||||
"--discretize",
|
||||
"-p",
|
||||
"/Users/rmontanana/Code/discretizbench/datasets",
|
||||
"--stratified",
|
||||
|
@ -8,7 +8,7 @@ namespace bayesnet {
|
||||
using namespace std;
|
||||
class AODELd : public Ensemble, public Proposal {
|
||||
private:
|
||||
void trainModel();
|
||||
void trainModel() override;
|
||||
void buildModel() override;
|
||||
public:
|
||||
AODELd();
|
||||
|
@ -10,26 +10,35 @@ namespace bayesnet {
|
||||
this->features = features;
|
||||
this->className = className;
|
||||
this->states = states;
|
||||
m = dataset.size(1);
|
||||
n = dataset.size(0) - 1;
|
||||
checkFitParameters();
|
||||
auto n_classes = states[className].size();
|
||||
metrics = Metrics(dataset, features, className, n_classes);
|
||||
model.initialize();
|
||||
buildModel();
|
||||
m = dataset.size(1);
|
||||
n = dataset.size(0);
|
||||
trainModel();
|
||||
fitted = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
void Classifier::buildDataset(Tensor& ytmp)
|
||||
{
|
||||
try {
|
||||
auto yresized = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1);
|
||||
dataset = torch::cat({ dataset, yresized }, 0);
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
std::cerr << e.what() << '\n';
|
||||
cout << "X dimensions: " << dataset.sizes() << "\n";
|
||||
cout << "y dimensions: " << ytmp.sizes() << "\n";
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
void Classifier::trainModel()
|
||||
{
|
||||
model.fit(dataset, features, className);
|
||||
}
|
||||
void Classifier::buildDataset(Tensor& ytmp)
|
||||
{
|
||||
ytmp = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1);
|
||||
dataset = torch::cat({ dataset, ytmp }, 0);
|
||||
}
|
||||
// X is nxm where n is the number of features and m the number of samples
|
||||
Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
|
||||
{
|
||||
@ -56,7 +65,7 @@ namespace bayesnet {
|
||||
void Classifier::checkFitParameters()
|
||||
{
|
||||
if (n != features.size()) {
|
||||
throw invalid_argument("X and features must have the same number of features");
|
||||
throw invalid_argument("X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features");
|
||||
}
|
||||
if (states.find(className) == states.end()) {
|
||||
throw invalid_argument("className not found in states");
|
||||
|
@ -23,7 +23,7 @@ namespace bayesnet {
|
||||
map<string, vector<int>> states;
|
||||
void checkFitParameters();
|
||||
virtual void buildModel() = 0;
|
||||
void trainModel();
|
||||
virtual void trainModel();
|
||||
public:
|
||||
Classifier(Network model);
|
||||
virtual ~Classifier() = default;
|
||||
|
@ -14,7 +14,7 @@ namespace bayesnet {
|
||||
protected:
|
||||
unsigned n_models;
|
||||
vector<unique_ptr<Classifier>> models;
|
||||
void trainModel();
|
||||
void trainModel() override;
|
||||
vector<int> voting(Tensor& y_pred);
|
||||
public:
|
||||
Ensemble();
|
||||
|
Loading…
Reference in New Issue
Block a user