Implement 3 types of smoothing
This commit is contained in:
@@ -37,6 +37,7 @@ namespace bayesnet {
|
||||
void AODELd::trainModel(const torch::Tensor& weights)
|
||||
{
|
||||
for (const auto& model : models) {
|
||||
model->setSmoothing(smoothing);
|
||||
model->fit(Xf, y, features, className, states);
|
||||
}
|
||||
}
|
||||
|
@@ -32,6 +32,7 @@ namespace bayesnet {
|
||||
for (int j = i + 1; j < featuresSelected.size(); j++) {
|
||||
auto parents = { featuresSelected[i], featuresSelected[j] };
|
||||
std::unique_ptr<Classifier> model = std::make_unique<SPnDE>(parents);
|
||||
model->setSmoothing(smoothing);
|
||||
model->fit(dataset, features, className, states, weights_);
|
||||
models.push_back(std::move(model));
|
||||
significanceModels.push_back(1.0); // They will be updated later in trainModel
|
||||
@@ -96,6 +97,7 @@ namespace bayesnet {
|
||||
pairSelection.erase(pairSelection.begin());
|
||||
std::unique_ptr<Classifier> model;
|
||||
model = std::make_unique<SPnDE>(std::vector<int>({ feature_pair.first, feature_pair.second }));
|
||||
model->setSmoothing(smoothing);
|
||||
model->fit(dataset, features, className, states, weights_);
|
||||
alpha_t = 0.0;
|
||||
if (!block_update) {
|
||||
|
@@ -22,6 +22,7 @@ namespace bayesnet {
|
||||
std::vector<int> featuresSelected = featureSelection(weights_);
|
||||
for (const int& feature : featuresSelected) {
|
||||
std::unique_ptr<Classifier> model = std::make_unique<SPODE>(feature);
|
||||
model->setSmoothing(smoothing);
|
||||
model->fit(dataset, features, className, states, weights_);
|
||||
models.push_back(std::move(model));
|
||||
significanceModels.push_back(1.0); // They will be updated later in trainModel
|
||||
@@ -89,6 +90,7 @@ namespace bayesnet {
|
||||
featureSelection.erase(featureSelection.begin());
|
||||
std::unique_ptr<Classifier> model;
|
||||
model = std::make_unique<SPODE>(feature);
|
||||
model->setSmoothing(smoothing);
|
||||
model->fit(dataset, features, className, states, weights_);
|
||||
alpha_t = 0.0;
|
||||
if (!block_update) {
|
||||
|
@@ -18,6 +18,7 @@ namespace bayesnet {
|
||||
n_models = models.size();
|
||||
for (auto i = 0; i < n_models; ++i) {
|
||||
// fit with std::vectors
|
||||
models[i]->setSmoothing(smoothing);
|
||||
models[i]->fit(dataset, features, className, states);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user