Complete Adding weights to Models

This commit is contained in:
Ricardo Montañana Gómez 2023-08-15 15:59:56 +02:00
parent 24b68f9ae2
commit fa612c531e
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 15 additions and 8 deletions

View File

@ -13,6 +13,7 @@ namespace bayesnet {
// X is nxm tensor, y is nx1 tensor
virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
virtual BaseClassifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
virtual BaseClassifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights) = 0;
virtual ~BaseClassifier() = default;
torch::Tensor virtual predict(torch::Tensor& X) = 0;
vector<int> virtual predict(vector<vector<int>>& X) = 0;

View File

@ -5,7 +5,7 @@ namespace bayesnet {
using namespace torch;
Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
Classifier& Classifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
Classifier& Classifier::build(vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights)
{
this->features = features;
this->className = className;
@ -16,14 +16,11 @@ namespace bayesnet {
auto n_classes = states[className].size();
metrics = Metrics(dataset, features, className, n_classes);
model.initialize();
// TODO weights can't be ones
const torch::Tensor weights = torch::ones({ m }, torch::kFloat);
buildModel(weights);
trainModel(weights);
fitted = true;
return *this;
}
void Classifier::buildDataset(Tensor& ytmp)
{
try {
@ -46,7 +43,8 @@ namespace bayesnet {
{
dataset = X;
buildDataset(y);
return build(features, className, states);
const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat);
return build(features, className, states, weights);
}
// X is nxm where n is the number of features and m the number of samples
Classifier& Classifier::fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states)
@ -57,12 +55,19 @@ namespace bayesnet {
}
auto ytmp = torch::tensor(y, kInt32);
buildDataset(ytmp);
return build(features, className, states);
const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat);
return build(features, className, states, weights);
}
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states)
{
this->dataset = dataset;
return build(features, className, states);
const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat);
return build(features, className, states, weights);
}
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights)
{
this->dataset = dataset;
return build(features, className, states, weights);
}
void Classifier::checkFitParameters()
{

View File

@ -11,7 +11,7 @@ namespace bayesnet {
class Classifier : public BaseClassifier {
private:
void buildDataset(torch::Tensor& y);
Classifier& build(vector<string>& features, string className, map<string, vector<int>>& states);
Classifier& build(vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights);
protected:
bool fitted;
int m, n; // m: number of samples, n: number of features
@ -30,6 +30,7 @@ namespace bayesnet {
Classifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
Classifier& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
Classifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states) override;
Classifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights) override;
void addNodes();
int getNumberOfNodes() const override;
int getNumberOfEdges() const override;