Add weigths as parameter
This commit is contained in:
parent
a062ebf445
commit
24b68f9ae2
5
.vscode/launch.json
vendored
5
.vscode/launch.json
vendored
@ -25,13 +25,12 @@
|
|||||||
"program": "${workspaceFolder}/build/src/Platform/main",
|
"program": "${workspaceFolder}/build/src/Platform/main",
|
||||||
"args": [
|
"args": [
|
||||||
"-m",
|
"-m",
|
||||||
"SPODE",
|
"TANLd",
|
||||||
"-p",
|
"-p",
|
||||||
"/Users/rmontanana/Code/discretizbench/datasets",
|
"/Users/rmontanana/Code/discretizbench/datasets",
|
||||||
"--stratified",
|
"--stratified",
|
||||||
"--discretize",
|
|
||||||
"-d",
|
"-d",
|
||||||
"letter"
|
"iris"
|
||||||
],
|
],
|
||||||
"cwd": "/Users/rmontanana/Code/discretizbench",
|
"cwd": "/Users/rmontanana/Code/discretizbench",
|
||||||
},
|
},
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
AODE::AODE() : Ensemble() {}
|
AODE::AODE() : Ensemble() {}
|
||||||
void AODE::buildModel()
|
void AODE::buildModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
models.clear();
|
models.clear();
|
||||||
for (int i = 0; i < features.size(); ++i) {
|
for (int i = 0; i < features.size(); ++i) {
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
class AODE : public Ensemble {
|
class AODE : public Ensemble {
|
||||||
protected:
|
protected:
|
||||||
void buildModel() override;
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
public:
|
public:
|
||||||
AODE();
|
AODE();
|
||||||
virtual ~AODE() {};
|
virtual ~AODE() {};
|
||||||
|
@ -19,7 +19,7 @@ namespace bayesnet {
|
|||||||
return *this;
|
return *this;
|
||||||
|
|
||||||
}
|
}
|
||||||
void AODELd::buildModel()
|
void AODELd::buildModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
models.clear();
|
models.clear();
|
||||||
for (int i = 0; i < features.size(); ++i) {
|
for (int i = 0; i < features.size(); ++i) {
|
||||||
@ -27,7 +27,7 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
n_models = models.size();
|
n_models = models.size();
|
||||||
}
|
}
|
||||||
void AODELd::trainModel()
|
void AODELd::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
for (const auto& model : models) {
|
for (const auto& model : models) {
|
||||||
model->fit(Xf, y, features, className, states);
|
model->fit(Xf, y, features, className, states);
|
||||||
|
@ -8,8 +8,8 @@ namespace bayesnet {
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
class AODELd : public Ensemble, public Proposal {
|
class AODELd : public Ensemble, public Proposal {
|
||||||
protected:
|
protected:
|
||||||
void trainModel() override;
|
void trainModel(const torch::Tensor& weights) override;
|
||||||
void buildModel() override;
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
public:
|
public:
|
||||||
AODELd();
|
AODELd();
|
||||||
AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, vector<string>& features_, string className_, map<string, vector<int>>& states_) override;
|
AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, vector<string>& features_, string className_, map<string, vector<int>>& states_) override;
|
||||||
|
@ -6,7 +6,7 @@ namespace bayesnet {
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
class BaseClassifier {
|
class BaseClassifier {
|
||||||
protected:
|
protected:
|
||||||
virtual void trainModel() = 0;
|
virtual void trainModel(const torch::Tensor& weights) = 0;
|
||||||
public:
|
public:
|
||||||
// X is nxm vector, y is nx1 vector
|
// X is nxm vector, y is nx1 vector
|
||||||
virtual BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
|
virtual BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
|
||||||
|
@ -52,7 +52,8 @@ namespace bayesnet {
|
|||||||
auto mask = samples.index({ -1, "..." }) == value;
|
auto mask = samples.index({ -1, "..." }) == value;
|
||||||
auto first_dataset = samples.index({ index_first, mask });
|
auto first_dataset = samples.index({ index_first, mask });
|
||||||
auto second_dataset = samples.index({ index_second, mask });
|
auto second_dataset = samples.index({ index_second, mask });
|
||||||
auto mi = mutualInformation(first_dataset, second_dataset, weights);
|
auto weights_dataset = weights.index({ mask });
|
||||||
|
auto mi = mutualInformation(first_dataset, second_dataset, weights_dataset);
|
||||||
auto pb = margin[value].item<float>();
|
auto pb = margin[value].item<float>();
|
||||||
accumulated += pb * mi;
|
accumulated += pb * mi;
|
||||||
}
|
}
|
||||||
|
@ -16,8 +16,10 @@ namespace bayesnet {
|
|||||||
auto n_classes = states[className].size();
|
auto n_classes = states[className].size();
|
||||||
metrics = Metrics(dataset, features, className, n_classes);
|
metrics = Metrics(dataset, features, className, n_classes);
|
||||||
model.initialize();
|
model.initialize();
|
||||||
buildModel();
|
// TODO weights can't be ones
|
||||||
trainModel();
|
const torch::Tensor weights = torch::ones({ m }, torch::kFloat);
|
||||||
|
buildModel(weights);
|
||||||
|
trainModel(weights);
|
||||||
fitted = true;
|
fitted = true;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@ -35,9 +37,8 @@ namespace bayesnet {
|
|||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void Classifier::trainModel()
|
void Classifier::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
const torch::Tensor weights = torch::ones({ m });
|
|
||||||
model.fit(dataset, weights, features, className, states);
|
model.fit(dataset, weights, features, className, states);
|
||||||
}
|
}
|
||||||
// X is nxm where n is the number of features and m the number of samples
|
// X is nxm where n is the number of features and m the number of samples
|
||||||
|
@ -21,10 +21,9 @@ namespace bayesnet {
|
|||||||
string className;
|
string className;
|
||||||
map<string, vector<int>> states;
|
map<string, vector<int>> states;
|
||||||
Tensor dataset; // (n+1)xm tensor
|
Tensor dataset; // (n+1)xm tensor
|
||||||
Tensor weights;
|
|
||||||
void checkFitParameters();
|
void checkFitParameters();
|
||||||
virtual void buildModel() = 0;
|
virtual void buildModel(const torch::Tensor& weights) = 0;
|
||||||
void trainModel() override;
|
void trainModel(const torch::Tensor& weights) override;
|
||||||
public:
|
public:
|
||||||
Classifier(Network model);
|
Classifier(Network model);
|
||||||
virtual ~Classifier() = default;
|
virtual ~Classifier() = default;
|
||||||
|
@ -5,7 +5,7 @@ namespace bayesnet {
|
|||||||
|
|
||||||
Ensemble::Ensemble() : Classifier(Network()) {}
|
Ensemble::Ensemble() : Classifier(Network()) {}
|
||||||
|
|
||||||
void Ensemble::trainModel()
|
void Ensemble::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
n_models = models.size();
|
n_models = models.size();
|
||||||
for (auto i = 0; i < n_models; ++i) {
|
for (auto i = 0; i < n_models; ++i) {
|
||||||
|
@ -14,7 +14,7 @@ namespace bayesnet {
|
|||||||
protected:
|
protected:
|
||||||
unsigned n_models;
|
unsigned n_models;
|
||||||
vector<unique_ptr<Classifier>> models;
|
vector<unique_ptr<Classifier>> models;
|
||||||
void trainModel() override;
|
void trainModel(const torch::Tensor& weights) override;
|
||||||
vector<int> voting(Tensor& y_pred);
|
vector<int> voting(Tensor& y_pred);
|
||||||
public:
|
public:
|
||||||
Ensemble();
|
Ensemble();
|
||||||
|
@ -4,7 +4,7 @@ namespace bayesnet {
|
|||||||
using namespace torch;
|
using namespace torch;
|
||||||
|
|
||||||
KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta) {}
|
KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta) {}
|
||||||
void KDB::buildModel()
|
void KDB::buildModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
/*
|
/*
|
||||||
1. For each feature Xi, compute mutual information, I(X;C),
|
1. For each feature Xi, compute mutual information, I(X;C),
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
#ifndef KDB_H
|
#ifndef KDB_H
|
||||||
#define KDB_H
|
#define KDB_H
|
||||||
|
#include <torch/torch.h>
|
||||||
#include "Classifier.h"
|
#include "Classifier.h"
|
||||||
#include "bayesnetUtils.h"
|
#include "bayesnetUtils.h"
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
@ -11,7 +12,7 @@ namespace bayesnet {
|
|||||||
float theta;
|
float theta;
|
||||||
void add_m_edges(int idx, vector<int>& S, Tensor& weights);
|
void add_m_edges(int idx, vector<int>& S, Tensor& weights);
|
||||||
protected:
|
protected:
|
||||||
void buildModel() override;
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
public:
|
public:
|
||||||
explicit KDB(int k, float theta = 0.03);
|
explicit KDB(int k, float theta = 0.03);
|
||||||
virtual ~KDB() {};
|
virtual ~KDB() {};
|
||||||
|
@ -107,7 +107,7 @@ namespace bayesnet {
|
|||||||
void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states, const torch::Tensor& weights)
|
void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states, const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
if (weights.size(0) != n_samples) {
|
if (weights.size(0) != n_samples) {
|
||||||
throw invalid_argument("Weights must have the same number of elements as samples in Network::fit");
|
throw invalid_argument("Weights (" + to_string(weights.size(0)) + ") must have the same number of elements as samples (" + to_string(n_samples) + ") in Network::fit");
|
||||||
}
|
}
|
||||||
if (n_samples != n_samples_y) {
|
if (n_samples != n_samples_y) {
|
||||||
throw invalid_argument("X and y must have the same number of samples in Network::fit (" + to_string(n_samples) + " != " + to_string(n_samples_y) + ")");
|
throw invalid_argument("X and y must have the same number of samples in Network::fit (" + to_string(n_samples) + " != " + to_string(n_samples_y) + ")");
|
||||||
|
@ -65,6 +65,7 @@ namespace bayesnet {
|
|||||||
//Update new states of the feature/node
|
//Update new states of the feature/node
|
||||||
states[pFeatures[index]] = xStates;
|
states[pFeatures[index]] = xStates;
|
||||||
}
|
}
|
||||||
|
// TODO weights can't be ones
|
||||||
const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat);
|
const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat);
|
||||||
model.fit(pDataset, weights, pFeatures, pClassName, states);
|
model.fit(pDataset, weights, pFeatures, pClassName, states);
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,7 @@ namespace bayesnet {
|
|||||||
|
|
||||||
SPODE::SPODE(int root) : Classifier(Network()), root(root) {}
|
SPODE::SPODE(int root) : Classifier(Network()), root(root) {}
|
||||||
|
|
||||||
void SPODE::buildModel()
|
void SPODE::buildModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
// 0. Add all nodes to the model
|
// 0. Add all nodes to the model
|
||||||
addNodes();
|
addNodes();
|
||||||
|
@ -7,7 +7,7 @@ namespace bayesnet {
|
|||||||
private:
|
private:
|
||||||
int root;
|
int root;
|
||||||
protected:
|
protected:
|
||||||
void buildModel() override;
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
public:
|
public:
|
||||||
explicit SPODE(int root);
|
explicit SPODE(int root);
|
||||||
virtual ~SPODE() {};
|
virtual ~SPODE() {};
|
||||||
|
@ -5,7 +5,7 @@ namespace bayesnet {
|
|||||||
|
|
||||||
TAN::TAN() : Classifier(Network()) {}
|
TAN::TAN() : Classifier(Network()) {}
|
||||||
|
|
||||||
void TAN::buildModel()
|
void TAN::buildModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
// 0. Add all nodes to the model
|
// 0. Add all nodes to the model
|
||||||
addNodes();
|
addNodes();
|
||||||
|
@ -7,7 +7,7 @@ namespace bayesnet {
|
|||||||
class TAN : public Classifier {
|
class TAN : public Classifier {
|
||||||
private:
|
private:
|
||||||
protected:
|
protected:
|
||||||
void buildModel() override;
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
public:
|
public:
|
||||||
TAN();
|
TAN();
|
||||||
virtual ~TAN() {};
|
virtual ~TAN() {};
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
#include "Colors.h"
|
#include "Colors.h"
|
||||||
|
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
const int MAXL = 121;
|
const int MAXL = 122;
|
||||||
namespace platform {
|
namespace platform {
|
||||||
using namespace std;
|
using namespace std;
|
||||||
class Report {
|
class Report {
|
||||||
|
@ -103,7 +103,7 @@ int main(int argc, char** argv)
|
|||||||
*/
|
*/
|
||||||
auto env = platform::DotEnv();
|
auto env = platform::DotEnv();
|
||||||
auto experiment = platform::Experiment();
|
auto experiment = platform::Experiment();
|
||||||
experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("1.0.0");
|
experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3");
|
||||||
experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));
|
experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));
|
||||||
experiment.setStratified(stratified).setNFolds(n_folds).setScoreName("accuracy");
|
experiment.setStratified(stratified).setNFolds(n_folds).setScoreName("accuracy");
|
||||||
for (auto seed : seeds) {
|
for (auto seed : seeds) {
|
||||||
|
Loading…
Reference in New Issue
Block a user