Complete Adding weights to Models
This commit is contained in:
parent
24b68f9ae2
commit
fa612c531e
@ -13,6 +13,7 @@ namespace bayesnet {
|
|||||||
// X is nxm tensor, y is nx1 tensor
|
// X is nxm tensor, y is nx1 tensor
|
||||||
virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
|
virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
|
||||||
virtual BaseClassifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
|
virtual BaseClassifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
|
||||||
|
virtual BaseClassifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights) = 0;
|
||||||
virtual ~BaseClassifier() = default;
|
virtual ~BaseClassifier() = default;
|
||||||
torch::Tensor virtual predict(torch::Tensor& X) = 0;
|
torch::Tensor virtual predict(torch::Tensor& X) = 0;
|
||||||
vector<int> virtual predict(vector<vector<int>>& X) = 0;
|
vector<int> virtual predict(vector<vector<int>>& X) = 0;
|
||||||
|
@ -5,7 +5,7 @@ namespace bayesnet {
|
|||||||
using namespace torch;
|
using namespace torch;
|
||||||
|
|
||||||
Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
|
Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
|
||||||
Classifier& Classifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
|
Classifier& Classifier::build(vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
this->features = features;
|
this->features = features;
|
||||||
this->className = className;
|
this->className = className;
|
||||||
@ -16,14 +16,11 @@ namespace bayesnet {
|
|||||||
auto n_classes = states[className].size();
|
auto n_classes = states[className].size();
|
||||||
metrics = Metrics(dataset, features, className, n_classes);
|
metrics = Metrics(dataset, features, className, n_classes);
|
||||||
model.initialize();
|
model.initialize();
|
||||||
// TODO weights can't be ones
|
|
||||||
const torch::Tensor weights = torch::ones({ m }, torch::kFloat);
|
|
||||||
buildModel(weights);
|
buildModel(weights);
|
||||||
trainModel(weights);
|
trainModel(weights);
|
||||||
fitted = true;
|
fitted = true;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Classifier::buildDataset(Tensor& ytmp)
|
void Classifier::buildDataset(Tensor& ytmp)
|
||||||
{
|
{
|
||||||
try {
|
try {
|
||||||
@ -46,7 +43,8 @@ namespace bayesnet {
|
|||||||
{
|
{
|
||||||
dataset = X;
|
dataset = X;
|
||||||
buildDataset(y);
|
buildDataset(y);
|
||||||
return build(features, className, states);
|
const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat);
|
||||||
|
return build(features, className, states, weights);
|
||||||
}
|
}
|
||||||
// X is nxm where n is the number of features and m the number of samples
|
// X is nxm where n is the number of features and m the number of samples
|
||||||
Classifier& Classifier::fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states)
|
Classifier& Classifier::fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states)
|
||||||
@ -57,12 +55,19 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
auto ytmp = torch::tensor(y, kInt32);
|
auto ytmp = torch::tensor(y, kInt32);
|
||||||
buildDataset(ytmp);
|
buildDataset(ytmp);
|
||||||
return build(features, className, states);
|
const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat);
|
||||||
|
return build(features, className, states, weights);
|
||||||
}
|
}
|
||||||
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states)
|
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states)
|
||||||
{
|
{
|
||||||
this->dataset = dataset;
|
this->dataset = dataset;
|
||||||
return build(features, className, states);
|
const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat);
|
||||||
|
return build(features, className, states, weights);
|
||||||
|
}
|
||||||
|
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights)
|
||||||
|
{
|
||||||
|
this->dataset = dataset;
|
||||||
|
return build(features, className, states, weights);
|
||||||
}
|
}
|
||||||
void Classifier::checkFitParameters()
|
void Classifier::checkFitParameters()
|
||||||
{
|
{
|
||||||
|
@ -11,7 +11,7 @@ namespace bayesnet {
|
|||||||
class Classifier : public BaseClassifier {
|
class Classifier : public BaseClassifier {
|
||||||
private:
|
private:
|
||||||
void buildDataset(torch::Tensor& y);
|
void buildDataset(torch::Tensor& y);
|
||||||
Classifier& build(vector<string>& features, string className, map<string, vector<int>>& states);
|
Classifier& build(vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights);
|
||||||
protected:
|
protected:
|
||||||
bool fitted;
|
bool fitted;
|
||||||
int m, n; // m: number of samples, n: number of features
|
int m, n; // m: number of samples, n: number of features
|
||||||
@ -30,6 +30,7 @@ namespace bayesnet {
|
|||||||
Classifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
Classifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
||||||
Classifier& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
Classifier& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
||||||
Classifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
Classifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
||||||
|
Classifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights) override;
|
||||||
void addNodes();
|
void addNodes();
|
||||||
int getNumberOfNodes() const override;
|
int getNumberOfNodes() const override;
|
||||||
int getNumberOfEdges() const override;
|
int getNumberOfEdges() const override;
|
||||||
|
Loading…
Reference in New Issue
Block a user