If ! convergence don't predict test
This commit is contained in:
parent
506369e46b
commit
41257ed566
@ -32,6 +32,8 @@ namespace bayesnet {
|
|||||||
void BoostAODE::validationInit()
|
void BoostAODE::validationInit()
|
||||||
{
|
{
|
||||||
auto y_ = dataset.index({ -1, "..." });
|
auto y_ = dataset.index({ -1, "..." });
|
||||||
|
if (convergence) {
|
||||||
|
// Prepare train & validation sets from train data
|
||||||
auto fold = platform::StratifiedKFold(5, y_, 271);
|
auto fold = platform::StratifiedKFold(5, y_, 271);
|
||||||
dataset_ = torch::clone(dataset);
|
dataset_ = torch::clone(dataset);
|
||||||
// save input dataset
|
// save input dataset
|
||||||
@ -49,6 +51,12 @@ namespace bayesnet {
|
|||||||
metrics = Metrics(dataset, features, className, n_classes);
|
metrics = Metrics(dataset, features, className, n_classes);
|
||||||
// Build dataset with train data
|
// Build dataset with train data
|
||||||
buildDataset(y_train);
|
buildDataset(y_train);
|
||||||
|
} else {
|
||||||
|
// Use all data to train
|
||||||
|
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
|
||||||
|
y_train = y_;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
@ -64,7 +72,7 @@ namespace bayesnet {
|
|||||||
double priorAccuracy = 0.0;
|
double priorAccuracy = 0.0;
|
||||||
double delta = 1.0;
|
double delta = 1.0;
|
||||||
double threshold = 1e-4;
|
double threshold = 1e-4;
|
||||||
int tolerance = convergence ? 5 : INT_MAX; // number of times the accuracy can be lower than the threshold
|
int tolerance = 5; // number of times the accuracy can be lower than the threshold
|
||||||
int count = 0; // number of times the accuracy is lower than the threshold
|
int count = 0; // number of times the accuracy is lower than the threshold
|
||||||
fitted = true; // to enable predict
|
fitted = true; // to enable predict
|
||||||
// Step 0: Set the finish condition
|
// Step 0: Set the finish condition
|
||||||
@ -115,6 +123,7 @@ namespace bayesnet {
|
|||||||
models.push_back(std::move(model));
|
models.push_back(std::move(model));
|
||||||
significanceModels.push_back(alpha_t);
|
significanceModels.push_back(alpha_t);
|
||||||
n_models++;
|
n_models++;
|
||||||
|
if (convergence) {
|
||||||
auto y_val_predict = predict(X_test);
|
auto y_val_predict = predict(X_test);
|
||||||
double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0);
|
double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0);
|
||||||
if (priorAccuracy == 0) {
|
if (priorAccuracy == 0) {
|
||||||
@ -125,6 +134,7 @@ namespace bayesnet {
|
|||||||
if (delta < threshold) {
|
if (delta < threshold) {
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
|
exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
|
||||||
}
|
}
|
||||||
if (featuresUsed.size() != features.size()) {
|
if (featuresUsed.size() != features.size()) {
|
||||||
|
@ -132,9 +132,9 @@ namespace bayesnet {
|
|||||||
void Network::setStates(const map<string, vector<int>>& states)
|
void Network::setStates(const map<string, vector<int>>& states)
|
||||||
{
|
{
|
||||||
// Set states to every Node in the network
|
// Set states to every Node in the network
|
||||||
for (int i = 0; i < features.size(); ++i) {
|
for_each(features.begin(), features.end(), [this, &states](const string& feature) {
|
||||||
nodes.at(features.at(i))->setNumStates(states.at(features[i]).size());
|
nodes.at(feature)->setNumStates(states.at(feature).size());
|
||||||
}
|
});
|
||||||
classNumStates = nodes.at(className)->getNumStates();
|
classNumStates = nodes.at(className)->getNumStates();
|
||||||
}
|
}
|
||||||
// X comes in nxm, where n is the number of features and m the number of samples
|
// X comes in nxm, where n is the number of features and m the number of samples
|
||||||
|
Loading…
Reference in New Issue
Block a user