make weights double
This commit is contained in:
parent
918a7b4180
commit
a3e665eed6
@ -43,7 +43,7 @@ namespace bayesnet {
|
|||||||
{
|
{
|
||||||
dataset = X;
|
dataset = X;
|
||||||
buildDataset(y);
|
buildDataset(y);
|
||||||
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat);
|
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
|
||||||
return build(features, className, states, weights);
|
return build(features, className, states, weights);
|
||||||
}
|
}
|
||||||
// X is nxm where n is the number of features and m the number of samples
|
// X is nxm where n is the number of features and m the number of samples
|
||||||
@ -55,13 +55,13 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
auto ytmp = torch::tensor(y, kInt32);
|
auto ytmp = torch::tensor(y, kInt32);
|
||||||
buildDataset(ytmp);
|
buildDataset(ytmp);
|
||||||
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat);
|
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
|
||||||
return build(features, className, states, weights);
|
return build(features, className, states, weights);
|
||||||
}
|
}
|
||||||
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states)
|
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states)
|
||||||
{
|
{
|
||||||
this->dataset = dataset;
|
this->dataset = dataset;
|
||||||
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kFloat);
|
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
|
||||||
return build(features, className, states, weights);
|
return build(features, className, states, weights);
|
||||||
}
|
}
|
||||||
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights)
|
Classifier& Classifier::fit(torch::Tensor& dataset, vector<string>& features, string className, map<string, vector<int>>& states, const torch::Tensor& weights)
|
||||||
|
@ -65,8 +65,7 @@ namespace bayesnet {
|
|||||||
//Update new states of the feature/node
|
//Update new states of the feature/node
|
||||||
states[pFeatures[index]] = xStates;
|
states[pFeatures[index]] = xStates;
|
||||||
}
|
}
|
||||||
// TODO weights can't be ones
|
const torch::Tensor weights = torch::full({ pDataset.size(1) }, 1.0 / pDataset.size(1), torch::kDouble);
|
||||||
const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat);
|
|
||||||
model.fit(pDataset, weights, pFeatures, pClassName, states);
|
model.fit(pDataset, weights, pFeatures, pClassName, states);
|
||||||
}
|
}
|
||||||
return states;
|
return states;
|
||||||
|
Loading…
Reference in New Issue
Block a user