tolerance <- 3
This commit is contained in:
@@ -13,8 +13,8 @@ namespace platform {
|
|||||||
auto X = TensorUtils::to_matrix(dataset.slice(0, 0, dataset.size(0) - 1));
|
auto X = TensorUtils::to_matrix(dataset.slice(0, 0, dataset.size(0) - 1));
|
||||||
auto y = TensorUtils::to_vector<int>(dataset.index({ -1, "..." }));
|
auto y = TensorUtils::to_vector<int>(dataset.index({ -1, "..." }));
|
||||||
int num_instances = X[0].size();
|
int num_instances = X[0].size();
|
||||||
weights_ = weights;
|
weights_ = torch::full({ num_instances }, 1.0);
|
||||||
normalize_weights(num_instances);
|
normalize_weights(num_instances);
|
||||||
aode_.fit(X, y, features, className, states, weights_, true);
|
aode_.fit(X, y, features, className, states, weights_, true, smoothing);
|
||||||
}
|
}
|
||||||
}
|
}
|
@@ -26,7 +26,7 @@ namespace platform {
|
|||||||
y_train_ = TensorUtils::to_vector<int>(y_train);
|
y_train_ = TensorUtils::to_vector<int>(y_train);
|
||||||
X_test_ = TensorUtils::to_matrix(X_test);
|
X_test_ = TensorUtils::to_matrix(X_test);
|
||||||
y_test_ = TensorUtils::to_vector<int>(y_test);
|
y_test_ = TensorUtils::to_vector<int>(y_test);
|
||||||
maxTolerance = 5;
|
maxTolerance = 3;
|
||||||
//
|
//
|
||||||
// Logging setup
|
// Logging setup
|
||||||
//
|
//
|
||||||
|
@@ -15,7 +15,10 @@
|
|||||||
#include "ExpClf.h"
|
#include "ExpClf.h"
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
class XBAODE : public ExpClf {
|
class XBAODE {
|
||||||
|
|
||||||
|
// Hay que hacer un vector de modelos entrenados y hacer un predict ensemble con todos ellos
|
||||||
|
// Probar XA1DE con smooth original y laplace y comprobar diferencias si se pasan pesos a 1 o a 1/m
|
||||||
public:
|
public:
|
||||||
XBAODE();
|
XBAODE();
|
||||||
std::string getVersion() override { return version; };
|
std::string getVersion() override { return version; };
|
||||||
|
@@ -18,6 +18,7 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
|
#include <bayesnet/network/Smoothing.h>
|
||||||
|
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
@@ -49,7 +50,7 @@ namespace platform {
|
|||||||
//
|
//
|
||||||
// Internally, in COUNTS mode, data_ accumulates raw counts, then
|
// Internally, in COUNTS mode, data_ accumulates raw counts, then
|
||||||
// computeProbabilities(...) normalizes them into conditionals.
|
// computeProbabilities(...) normalizes them into conditionals.
|
||||||
void fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const bool all_parents)
|
void fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const bool all_parents, const bayesnet::Smoothing_t smoothing)
|
||||||
{
|
{
|
||||||
int num_instances = X[0].size();
|
int num_instances = X[0].size();
|
||||||
nFeatures_ = X.size();
|
nFeatures_ = X.size();
|
||||||
@@ -110,8 +111,16 @@ namespace platform {
|
|||||||
instance[nFeatures_] = y[n_instance];
|
instance[nFeatures_] = y[n_instance];
|
||||||
addSample(instance, weights[n_instance].item<double>());
|
addSample(instance, weights[n_instance].item<double>());
|
||||||
}
|
}
|
||||||
// alpha_ Laplace smoothing adapted to the number of instances
|
switch (smoothing) {
|
||||||
alpha_ = 1.0 / static_cast<double>(num_instances);
|
case bayesnet::Smoothing_t::ORIGINAL:
|
||||||
|
alpha_ = 1.0 / num_instances;
|
||||||
|
break;
|
||||||
|
case bayesnet::Smoothing_t::LAPLACE:
|
||||||
|
alpha_ = 1.0;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
alpha_ = 0.0; // No smoothing
|
||||||
|
}
|
||||||
initializer_ = std::numeric_limits<double>::max() / (nFeatures_ * nFeatures_);
|
initializer_ = std::numeric_limits<double>::max() / (nFeatures_ * nFeatures_);
|
||||||
computeProbabilities();
|
computeProbabilities();
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user