Complete Adding weights to Models
This commit is contained in:
parent
24b68f9ae2
commit
fa612c531e
@ -13,6 +13,7 @@ namespace bayesnet {
|
||||
// X is nxm tensor, y is nx1 tensor
|
||||
virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
|
||||
virtual BaseClassifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
|
||||
virtual BaseClassifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights) = 0;
|
||||
virtual ~BaseClassifier() = default;
|
||||
torch::Tensor virtual predict(torch::Tensor& X) = 0;
|
||||
vector<int> virtual predict(vector<vector<int>>& X) = 0;
|
||||
|
@ -5,7 +5,7 @@ namespace bayesnet {
|
||||
using namespace torch;
|
||||
|
||||
Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
|
||||
Classifier& Classifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
|
||||
Classifier& Classifier::build(vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights)
|
||||
{
|
||||
this->features = features;
|
||||
this->className = className;
|
||||
@ -16,14 +16,11 @@ namespace bayesnet {
|
||||
auto n_classes = states[className].size();
|
||||
metrics = Metrics(dataset, features, className, n_classes);
|
||||
model.initialize();
|
||||
// TODO weights can't be ones
|
||||
const torch::Tensor weights = torch::ones({ m }, torch::kFloat);
|
||||
buildModel(weights);
|
||||
trainModel(weights);
|
||||
fitted = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
void Classifier::buildDataset(Tensor& ytmp)
|
||||
{
|
||||
try {
|
||||
@ -46,7 +43,8 @@ namespace bayesnet {
|
||||
{
|
||||
dataset = X;
|
||||
buildDataset(y);
|
||||
return build(features, className, states);
|
||||
const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat);
|
||||
return build(features, className, states, weights);
|
||||
}
|
||||
// X is nxm where n is the number of features and m the number of samples
|
||||
Classifier& Classifier::fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states)
|
||||
@ -57,12 +55,19 @@ namespace bayesnet {
|
||||
}
|
||||
auto ytmp = torch::tensor(y, kInt32);
|
||||
buildDataset(ytmp);
|
||||
return build(features, className, states);
|
||||
const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat);
|
||||
return build(features, className, states, weights);
|
||||
}
|
||||
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states)
|
||||
{
|
||||
this->dataset = dataset;
|
||||
return build(features, className, states);
|
||||
const torch::Tensor weights = torch::ones({ dataset.size(1) }, torch::kFloat);
|
||||
return build(features, className, states, weights);
|
||||
}
|
||||
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights)
|
||||
{
|
||||
this->dataset = dataset;
|
||||
return build(features, className, states, weights);
|
||||
}
|
||||
void Classifier::checkFitParameters()
|
||||
{
|
||||
|
@ -11,7 +11,7 @@ namespace bayesnet {
|
||||
class Classifier : public BaseClassifier {
|
||||
private:
|
||||
void buildDataset(torch::Tensor& y);
|
||||
Classifier& build(vector<string>& features, string className, map<string, vector<int>>& states);
|
||||
Classifier& build(vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights);
|
||||
protected:
|
||||
bool fitted;
|
||||
int m, n; // m: number of samples, n: number of features
|
||||
@ -30,6 +30,7 @@ namespace bayesnet {
|
||||
Classifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
||||
Classifier& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
||||
Classifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
||||
Classifier& fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights) override;
|
||||
void addNodes();
|
||||
int getNumberOfNodes() const override;
|
||||
int getNumberOfEdges() const override;
|
||||
|
Loading…
Reference in New Issue
Block a user