Add weigths as parameter
This commit is contained in:
parent
a062ebf445
commit
24b68f9ae2
5
.vscode/launch.json
vendored
5
.vscode/launch.json
vendored
@ -25,13 +25,12 @@
|
||||
"program": "${workspaceFolder}/build/src/Platform/main",
|
||||
"args": [
|
||||
"-m",
|
||||
"SPODE",
|
||||
"TANLd",
|
||||
"-p",
|
||||
"/Users/rmontanana/Code/discretizbench/datasets",
|
||||
"--stratified",
|
||||
"--discretize",
|
||||
"-d",
|
||||
"letter"
|
||||
"iris"
|
||||
],
|
||||
"cwd": "/Users/rmontanana/Code/discretizbench",
|
||||
},
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
namespace bayesnet {
|
||||
AODE::AODE() : Ensemble() {}
|
||||
void AODE::buildModel()
|
||||
void AODE::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
models.clear();
|
||||
for (int i = 0; i < features.size(); ++i) {
|
||||
|
@ -5,7 +5,7 @@
|
||||
namespace bayesnet {
|
||||
class AODE : public Ensemble {
|
||||
protected:
|
||||
void buildModel() override;
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
public:
|
||||
AODE();
|
||||
virtual ~AODE() {};
|
||||
|
@ -19,7 +19,7 @@ namespace bayesnet {
|
||||
return *this;
|
||||
|
||||
}
|
||||
void AODELd::buildModel()
|
||||
void AODELd::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
models.clear();
|
||||
for (int i = 0; i < features.size(); ++i) {
|
||||
@ -27,7 +27,7 @@ namespace bayesnet {
|
||||
}
|
||||
n_models = models.size();
|
||||
}
|
||||
void AODELd::trainModel()
|
||||
void AODELd::trainModel(const torch::Tensor& weights)
|
||||
{
|
||||
for (const auto& model : models) {
|
||||
model->fit(Xf, y, features, className, states);
|
||||
|
@ -8,8 +8,8 @@ namespace bayesnet {
|
||||
using namespace std;
|
||||
class AODELd : public Ensemble, public Proposal {
|
||||
protected:
|
||||
void trainModel() override;
|
||||
void buildModel() override;
|
||||
void trainModel(const torch::Tensor& weights) override;
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
public:
|
||||
AODELd();
|
||||
AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, vector<string>& features_, string className_, map<string, vector<int>>& states_) override;
|
||||
|
@ -6,7 +6,7 @@ namespace bayesnet {
|
||||
using namespace std;
|
||||
class BaseClassifier {
|
||||
protected:
|
||||
virtual void trainModel() = 0;
|
||||
virtual void trainModel(const torch::Tensor& weights) = 0;
|
||||
public:
|
||||
// X is nxm vector, y is nx1 vector
|
||||
virtual BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
|
||||
|
@ -52,7 +52,8 @@ namespace bayesnet {
|
||||
auto mask = samples.index({ -1, "..." }) == value;
|
||||
auto first_dataset = samples.index({ index_first, mask });
|
||||
auto second_dataset = samples.index({ index_second, mask });
|
||||
auto mi = mutualInformation(first_dataset, second_dataset, weights);
|
||||
auto weights_dataset = weights.index({ mask });
|
||||
auto mi = mutualInformation(first_dataset, second_dataset, weights_dataset);
|
||||
auto pb = margin[value].item<float>();
|
||||
accumulated += pb * mi;
|
||||
}
|
||||
|
@ -16,8 +16,10 @@ namespace bayesnet {
|
||||
auto n_classes = states[className].size();
|
||||
metrics = Metrics(dataset, features, className, n_classes);
|
||||
model.initialize();
|
||||
buildModel();
|
||||
trainModel();
|
||||
// TODO weights can't be ones
|
||||
const torch::Tensor weights = torch::ones({ m }, torch::kFloat);
|
||||
buildModel(weights);
|
||||
trainModel(weights);
|
||||
fitted = true;
|
||||
return *this;
|
||||
}
|
||||
@ -35,9 +37,8 @@ namespace bayesnet {
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
void Classifier::trainModel()
|
||||
void Classifier::trainModel(const torch::Tensor& weights)
|
||||
{
|
||||
const torch::Tensor weights = torch::ones({ m });
|
||||
model.fit(dataset, weights, features, className, states);
|
||||
}
|
||||
// X is nxm where n is the number of features and m the number of samples
|
||||
|
@ -21,10 +21,9 @@ namespace bayesnet {
|
||||
string className;
|
||||
map<string, vector<int>> states;
|
||||
Tensor dataset; // (n+1)xm tensor
|
||||
Tensor weights;
|
||||
void checkFitParameters();
|
||||
virtual void buildModel() = 0;
|
||||
void trainModel() override;
|
||||
virtual void buildModel(const torch::Tensor& weights) = 0;
|
||||
void trainModel(const torch::Tensor& weights) override;
|
||||
public:
|
||||
Classifier(Network model);
|
||||
virtual ~Classifier() = default;
|
||||
|
@ -5,7 +5,7 @@ namespace bayesnet {
|
||||
|
||||
Ensemble::Ensemble() : Classifier(Network()) {}
|
||||
|
||||
void Ensemble::trainModel()
|
||||
void Ensemble::trainModel(const torch::Tensor& weights)
|
||||
{
|
||||
n_models = models.size();
|
||||
for (auto i = 0; i < n_models; ++i) {
|
||||
|
@ -14,7 +14,7 @@ namespace bayesnet {
|
||||
protected:
|
||||
unsigned n_models;
|
||||
vector<unique_ptr<Classifier>> models;
|
||||
void trainModel() override;
|
||||
void trainModel(const torch::Tensor& weights) override;
|
||||
vector<int> voting(Tensor& y_pred);
|
||||
public:
|
||||
Ensemble();
|
||||
|
@ -4,7 +4,7 @@ namespace bayesnet {
|
||||
using namespace torch;
|
||||
|
||||
KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta) {}
|
||||
void KDB::buildModel()
|
||||
void KDB::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
/*
|
||||
1. For each feature Xi, compute mutual information, I(X;C),
|
||||
|
@ -1,5 +1,6 @@
|
||||
#ifndef KDB_H
|
||||
#define KDB_H
|
||||
#include <torch/torch.h>
|
||||
#include "Classifier.h"
|
||||
#include "bayesnetUtils.h"
|
||||
namespace bayesnet {
|
||||
@ -11,7 +12,7 @@ namespace bayesnet {
|
||||
float theta;
|
||||
void add_m_edges(int idx, vector<int>& S, Tensor& weights);
|
||||
protected:
|
||||
void buildModel() override;
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
public:
|
||||
explicit KDB(int k, float theta = 0.03);
|
||||
virtual ~KDB() {};
|
||||
|
@ -107,7 +107,7 @@ namespace bayesnet {
|
||||
void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states, const torch::Tensor& weights)
|
||||
{
|
||||
if (weights.size(0) != n_samples) {
|
||||
throw invalid_argument("Weights must have the same number of elements as samples in Network::fit");
|
||||
throw invalid_argument("Weights (" + to_string(weights.size(0)) + ") must have the same number of elements as samples (" + to_string(n_samples) + ") in Network::fit");
|
||||
}
|
||||
if (n_samples != n_samples_y) {
|
||||
throw invalid_argument("X and y must have the same number of samples in Network::fit (" + to_string(n_samples) + " != " + to_string(n_samples_y) + ")");
|
||||
|
@ -65,6 +65,7 @@ namespace bayesnet {
|
||||
//Update new states of the feature/node
|
||||
states[pFeatures[index]] = xStates;
|
||||
}
|
||||
// TODO weights can't be ones
|
||||
const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat);
|
||||
model.fit(pDataset, weights, pFeatures, pClassName, states);
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ namespace bayesnet {
|
||||
|
||||
SPODE::SPODE(int root) : Classifier(Network()), root(root) {}
|
||||
|
||||
void SPODE::buildModel()
|
||||
void SPODE::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
// 0. Add all nodes to the model
|
||||
addNodes();
|
||||
|
@ -7,7 +7,7 @@ namespace bayesnet {
|
||||
private:
|
||||
int root;
|
||||
protected:
|
||||
void buildModel() override;
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
public:
|
||||
explicit SPODE(int root);
|
||||
virtual ~SPODE() {};
|
||||
|
@ -5,7 +5,7 @@ namespace bayesnet {
|
||||
|
||||
TAN::TAN() : Classifier(Network()) {}
|
||||
|
||||
void TAN::buildModel()
|
||||
void TAN::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
// 0. Add all nodes to the model
|
||||
addNodes();
|
||||
|
@ -7,7 +7,7 @@ namespace bayesnet {
|
||||
class TAN : public Classifier {
|
||||
private:
|
||||
protected:
|
||||
void buildModel() override;
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
public:
|
||||
TAN();
|
||||
virtual ~TAN() {};
|
||||
|
@ -6,7 +6,7 @@
|
||||
#include "Colors.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
const int MAXL = 121;
|
||||
const int MAXL = 122;
|
||||
namespace platform {
|
||||
using namespace std;
|
||||
class Report {
|
||||
|
@ -103,7 +103,7 @@ int main(int argc, char** argv)
|
||||
*/
|
||||
auto env = platform::DotEnv();
|
||||
auto experiment = platform::Experiment();
|
||||
experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("1.0.0");
|
||||
experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3");
|
||||
experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));
|
||||
experiment.setStratified(stratified).setNFolds(n_folds).setScoreName("accuracy");
|
||||
for (auto seed : seeds) {
|
||||
|
Loading…
Reference in New Issue
Block a user