Add weights as parameter

This commit is contained in:
Ricardo Montañana Gómez 2023-08-15 15:04:56 +02:00
parent a062ebf445
commit 24b68f9ae2
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
21 changed files with 31 additions and 29 deletions

5
.vscode/launch.json vendored
View File

@ -25,13 +25,12 @@
"program": "${workspaceFolder}/build/src/Platform/main", "program": "${workspaceFolder}/build/src/Platform/main",
"args": [ "args": [
"-m", "-m",
"SPODE", "TANLd",
"-p", "-p",
"/Users/rmontanana/Code/discretizbench/datasets", "/Users/rmontanana/Code/discretizbench/datasets",
"--stratified", "--stratified",
"--discretize",
"-d", "-d",
"letter" "iris"
], ],
"cwd": "/Users/rmontanana/Code/discretizbench", "cwd": "/Users/rmontanana/Code/discretizbench",
}, },

View File

@ -2,7 +2,7 @@
namespace bayesnet { namespace bayesnet {
AODE::AODE() : Ensemble() {} AODE::AODE() : Ensemble() {}
void AODE::buildModel() void AODE::buildModel(const torch::Tensor& weights)
{ {
models.clear(); models.clear();
for (int i = 0; i < features.size(); ++i) { for (int i = 0; i < features.size(); ++i) {

View File

@ -5,7 +5,7 @@
namespace bayesnet { namespace bayesnet {
class AODE : public Ensemble { class AODE : public Ensemble {
protected: protected:
void buildModel() override; void buildModel(const torch::Tensor& weights) override;
public: public:
AODE(); AODE();
virtual ~AODE() {}; virtual ~AODE() {};

View File

@ -19,7 +19,7 @@ namespace bayesnet {
return *this; return *this;
} }
void AODELd::buildModel() void AODELd::buildModel(const torch::Tensor& weights)
{ {
models.clear(); models.clear();
for (int i = 0; i < features.size(); ++i) { for (int i = 0; i < features.size(); ++i) {
@ -27,7 +27,7 @@ namespace bayesnet {
} }
n_models = models.size(); n_models = models.size();
} }
void AODELd::trainModel() void AODELd::trainModel(const torch::Tensor& weights)
{ {
for (const auto& model : models) { for (const auto& model : models) {
model->fit(Xf, y, features, className, states); model->fit(Xf, y, features, className, states);

View File

@ -8,8 +8,8 @@ namespace bayesnet {
using namespace std; using namespace std;
class AODELd : public Ensemble, public Proposal { class AODELd : public Ensemble, public Proposal {
protected: protected:
void trainModel() override; void trainModel(const torch::Tensor& weights) override;
void buildModel() override; void buildModel(const torch::Tensor& weights) override;
public: public:
AODELd(); AODELd();
AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, vector<string>& features_, string className_, map<string, vector<int>>& states_) override; AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, vector<string>& features_, string className_, map<string, vector<int>>& states_) override;

View File

@ -6,7 +6,7 @@ namespace bayesnet {
using namespace std; using namespace std;
class BaseClassifier { class BaseClassifier {
protected: protected:
virtual void trainModel() = 0; virtual void trainModel(const torch::Tensor& weights) = 0;
public: public:
// X is nxm vector, y is nx1 vector // X is nxm vector, y is nx1 vector
virtual BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0; virtual BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;

View File

@ -52,7 +52,8 @@ namespace bayesnet {
auto mask = samples.index({ -1, "..." }) == value; auto mask = samples.index({ -1, "..." }) == value;
auto first_dataset = samples.index({ index_first, mask }); auto first_dataset = samples.index({ index_first, mask });
auto second_dataset = samples.index({ index_second, mask }); auto second_dataset = samples.index({ index_second, mask });
auto mi = mutualInformation(first_dataset, second_dataset, weights); auto weights_dataset = weights.index({ mask });
auto mi = mutualInformation(first_dataset, second_dataset, weights_dataset);
auto pb = margin[value].item<float>(); auto pb = margin[value].item<float>();
accumulated += pb * mi; accumulated += pb * mi;
} }

View File

@ -16,8 +16,10 @@ namespace bayesnet {
auto n_classes = states[className].size(); auto n_classes = states[className].size();
metrics = Metrics(dataset, features, className, n_classes); metrics = Metrics(dataset, features, className, n_classes);
model.initialize(); model.initialize();
buildModel(); // TODO weights can't be ones
trainModel(); const torch::Tensor weights = torch::ones({ m }, torch::kFloat);
buildModel(weights);
trainModel(weights);
fitted = true; fitted = true;
return *this; return *this;
} }
@ -35,9 +37,8 @@ namespace bayesnet {
exit(1); exit(1);
} }
} }
void Classifier::trainModel() void Classifier::trainModel(const torch::Tensor& weights)
{ {
const torch::Tensor weights = torch::ones({ m });
model.fit(dataset, weights, features, className, states); model.fit(dataset, weights, features, className, states);
} }
// X is nxm where n is the number of features and m the number of samples // X is nxm where n is the number of features and m the number of samples

View File

@ -21,10 +21,9 @@ namespace bayesnet {
string className; string className;
map<string, vector<int>> states; map<string, vector<int>> states;
Tensor dataset; // (n+1)xm tensor Tensor dataset; // (n+1)xm tensor
Tensor weights;
void checkFitParameters(); void checkFitParameters();
virtual void buildModel() = 0; virtual void buildModel(const torch::Tensor& weights) = 0;
void trainModel() override; void trainModel(const torch::Tensor& weights) override;
public: public:
Classifier(Network model); Classifier(Network model);
virtual ~Classifier() = default; virtual ~Classifier() = default;

View File

@ -5,7 +5,7 @@ namespace bayesnet {
Ensemble::Ensemble() : Classifier(Network()) {} Ensemble::Ensemble() : Classifier(Network()) {}
void Ensemble::trainModel() void Ensemble::trainModel(const torch::Tensor& weights)
{ {
n_models = models.size(); n_models = models.size();
for (auto i = 0; i < n_models; ++i) { for (auto i = 0; i < n_models; ++i) {

View File

@ -14,7 +14,7 @@ namespace bayesnet {
protected: protected:
unsigned n_models; unsigned n_models;
vector<unique_ptr<Classifier>> models; vector<unique_ptr<Classifier>> models;
void trainModel() override; void trainModel(const torch::Tensor& weights) override;
vector<int> voting(Tensor& y_pred); vector<int> voting(Tensor& y_pred);
public: public:
Ensemble(); Ensemble();

View File

@ -4,7 +4,7 @@ namespace bayesnet {
using namespace torch; using namespace torch;
KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta) {} KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta) {}
void KDB::buildModel() void KDB::buildModel(const torch::Tensor& weights)
{ {
/* /*
1. For each feature Xi, compute mutual information, I(X;C), 1. For each feature Xi, compute mutual information, I(X;C),

View File

@ -1,5 +1,6 @@
#ifndef KDB_H #ifndef KDB_H
#define KDB_H #define KDB_H
#include <torch/torch.h>
#include "Classifier.h" #include "Classifier.h"
#include "bayesnetUtils.h" #include "bayesnetUtils.h"
namespace bayesnet { namespace bayesnet {
@ -11,7 +12,7 @@ namespace bayesnet {
float theta; float theta;
void add_m_edges(int idx, vector<int>& S, Tensor& weights); void add_m_edges(int idx, vector<int>& S, Tensor& weights);
protected: protected:
void buildModel() override; void buildModel(const torch::Tensor& weights) override;
public: public:
explicit KDB(int k, float theta = 0.03); explicit KDB(int k, float theta = 0.03);
virtual ~KDB() {}; virtual ~KDB() {};

View File

@ -107,7 +107,7 @@ namespace bayesnet {
void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states, const torch::Tensor& weights) void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states, const torch::Tensor& weights)
{ {
if (weights.size(0) != n_samples) { if (weights.size(0) != n_samples) {
throw invalid_argument("Weights must have the same number of elements as samples in Network::fit"); throw invalid_argument("Weights (" + to_string(weights.size(0)) + ") must have the same number of elements as samples (" + to_string(n_samples) + ") in Network::fit");
} }
if (n_samples != n_samples_y) { if (n_samples != n_samples_y) {
throw invalid_argument("X and y must have the same number of samples in Network::fit (" + to_string(n_samples) + " != " + to_string(n_samples_y) + ")"); throw invalid_argument("X and y must have the same number of samples in Network::fit (" + to_string(n_samples) + " != " + to_string(n_samples_y) + ")");

View File

@ -65,6 +65,7 @@ namespace bayesnet {
//Update new states of the feature/node //Update new states of the feature/node
states[pFeatures[index]] = xStates; states[pFeatures[index]] = xStates;
} }
// TODO weights can't be ones
const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat); const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat);
model.fit(pDataset, weights, pFeatures, pClassName, states); model.fit(pDataset, weights, pFeatures, pClassName, states);
} }

View File

@ -4,7 +4,7 @@ namespace bayesnet {
SPODE::SPODE(int root) : Classifier(Network()), root(root) {} SPODE::SPODE(int root) : Classifier(Network()), root(root) {}
void SPODE::buildModel() void SPODE::buildModel(const torch::Tensor& weights)
{ {
// 0. Add all nodes to the model // 0. Add all nodes to the model
addNodes(); addNodes();

View File

@ -7,7 +7,7 @@ namespace bayesnet {
private: private:
int root; int root;
protected: protected:
void buildModel() override; void buildModel(const torch::Tensor& weights) override;
public: public:
explicit SPODE(int root); explicit SPODE(int root);
virtual ~SPODE() {}; virtual ~SPODE() {};

View File

@ -5,7 +5,7 @@ namespace bayesnet {
TAN::TAN() : Classifier(Network()) {} TAN::TAN() : Classifier(Network()) {}
void TAN::buildModel() void TAN::buildModel(const torch::Tensor& weights)
{ {
// 0. Add all nodes to the model // 0. Add all nodes to the model
addNodes(); addNodes();

View File

@ -7,7 +7,7 @@ namespace bayesnet {
class TAN : public Classifier { class TAN : public Classifier {
private: private:
protected: protected:
void buildModel() override; void buildModel(const torch::Tensor& weights) override;
public: public:
TAN(); TAN();
virtual ~TAN() {}; virtual ~TAN() {};

View File

@ -6,7 +6,7 @@
#include "Colors.h" #include "Colors.h"
using json = nlohmann::json; using json = nlohmann::json;
const int MAXL = 121; const int MAXL = 122;
namespace platform { namespace platform {
using namespace std; using namespace std;
class Report { class Report {

View File

@ -103,7 +103,7 @@ int main(int argc, char** argv)
*/ */
auto env = platform::DotEnv(); auto env = platform::DotEnv();
auto experiment = platform::Experiment(); auto experiment = platform::Experiment();
experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("1.0.0"); experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3");
experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform")); experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));
experiment.setStratified(stratified).setNFolds(n_folds).setScoreName("accuracy"); experiment.setStratified(stratified).setNFolds(n_folds).setScoreName("accuracy");
for (auto seed : seeds) { for (auto seed : seeds) {