Refactor Library renaming Base classes
This commit is contained in:
parent
41cceece20
commit
9981ad1811
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@ -99,7 +99,8 @@
|
|||||||
"typeindex": "cpp",
|
"typeindex": "cpp",
|
||||||
"shared_mutex": "cpp",
|
"shared_mutex": "cpp",
|
||||||
"*.ipp": "cpp",
|
"*.ipp": "cpp",
|
||||||
"cassert": "cpp"
|
"cassert": "cpp",
|
||||||
|
"charconv": "cpp"
|
||||||
},
|
},
|
||||||
"cmake.configureOnOpen": false,
|
"cmake.configureOnOpen": false,
|
||||||
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"
|
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"
|
||||||
|
@ -2,7 +2,9 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
#include <map>
|
||||||
#include <argparse/argparse.hpp>
|
#include <argparse/argparse.hpp>
|
||||||
|
#include "BaseClassifier.h"
|
||||||
#include "ArffFiles.h"
|
#include "ArffFiles.h"
|
||||||
#include "Network.h"
|
#include "Network.h"
|
||||||
#include "BayesMetrics.h"
|
#include "BayesMetrics.h"
|
||||||
@ -143,38 +145,12 @@ int main(int argc, char** argv)
|
|||||||
states[className] = vector<int>(
|
states[className] = vector<int>(
|
||||||
maxes[className]);
|
maxes[className]);
|
||||||
double score;
|
double score;
|
||||||
vector<string> lines;
|
auto classifiers = map<string, bayesnet::BaseClassifier*>({ { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) }, { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() } });
|
||||||
vector<string> graph;
|
bayesnet::BaseClassifier* clf = classifiers[model_name];
|
||||||
auto kdb = bayesnet::KDB(2);
|
clf->fit(Xd, y, features, className, states);
|
||||||
auto aode = bayesnet::AODE();
|
score = clf->score(Xd, y);
|
||||||
auto spode = bayesnet::SPODE(2);
|
auto lines = clf->show();
|
||||||
auto tan = bayesnet::TAN();
|
auto graph = clf->graph();
|
||||||
switch (hash_conv(model_name)) {
|
|
||||||
case "AODE"_sh:
|
|
||||||
aode.fit(Xd, y, features, className, states);
|
|
||||||
lines = aode.show();
|
|
||||||
score = aode.score(Xd, y);
|
|
||||||
graph = aode.graph();
|
|
||||||
break;
|
|
||||||
case "KDB"_sh:
|
|
||||||
kdb.fit(Xd, y, features, className, states);
|
|
||||||
lines = kdb.show();
|
|
||||||
score = kdb.score(Xd, y);
|
|
||||||
graph = kdb.graph();
|
|
||||||
break;
|
|
||||||
case "SPODE"_sh:
|
|
||||||
spode.fit(Xd, y, features, className, states);
|
|
||||||
lines = spode.show();
|
|
||||||
score = spode.score(Xd, y);
|
|
||||||
graph = spode.graph();
|
|
||||||
break;
|
|
||||||
case "TAN"_sh:
|
|
||||||
tan.fit(Xd, y, features, className, states);
|
|
||||||
lines = tan.show();
|
|
||||||
score = tan.score(Xd, y);
|
|
||||||
graph = tan.graph();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
for (auto line : lines) {
|
for (auto line : lines) {
|
||||||
cout << line << endl;
|
cout << line << endl;
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ namespace bayesnet {
|
|||||||
void train() override;
|
void train() override;
|
||||||
public:
|
public:
|
||||||
AODE();
|
AODE();
|
||||||
vector<string> graph(string title = "AODE");
|
vector<string> graph(string title = "AODE") override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -1,48 +1,17 @@
|
|||||||
#ifndef CLASSIFIERS_H
|
#ifndef BASE_H
|
||||||
#define CLASSIFIERS_H
|
#define BASE_H
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
#include "Network.h"
|
#include <vector>
|
||||||
#include "BayesMetrics.h"
|
|
||||||
using namespace std;
|
|
||||||
using namespace torch;
|
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
|
using namespace std;
|
||||||
class BaseClassifier {
|
class BaseClassifier {
|
||||||
private:
|
|
||||||
bool fitted;
|
|
||||||
BaseClassifier& build(vector<string>& features, string className, map<string, vector<int>>& states);
|
|
||||||
protected:
|
|
||||||
Network model;
|
|
||||||
int m, n; // m: number of samples, n: number of features
|
|
||||||
Tensor X;
|
|
||||||
vector<vector<int>> Xv;
|
|
||||||
Tensor y;
|
|
||||||
vector<int> yv;
|
|
||||||
Tensor dataset;
|
|
||||||
Metrics metrics;
|
|
||||||
vector<string> features;
|
|
||||||
string className;
|
|
||||||
map<string, vector<int>> states;
|
|
||||||
void checkFitParameters();
|
|
||||||
virtual void train() = 0;
|
|
||||||
public:
|
public:
|
||||||
BaseClassifier(Network model);
|
virtual BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
|
||||||
|
vector<int> virtual predict(vector<vector<int>>& X) = 0;
|
||||||
|
float virtual score(vector<vector<int>>& X, vector<int>& y) = 0;
|
||||||
|
vector<string> virtual show() = 0;
|
||||||
|
vector<string> virtual graph(string title = "") = 0;
|
||||||
virtual ~BaseClassifier() = default;
|
virtual ~BaseClassifier() = default;
|
||||||
BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
|
|
||||||
void addNodes();
|
|
||||||
int getNumberOfNodes();
|
|
||||||
int getNumberOfEdges();
|
|
||||||
Tensor predict(Tensor& X);
|
|
||||||
vector<int> predict(vector<vector<int>>& X);
|
|
||||||
float score(Tensor& X, Tensor& y);
|
|
||||||
float score(vector<vector<int>>& X, vector<int>& y);
|
|
||||||
vector<string> show();
|
|
||||||
virtual vector<string> graph(string title) = 0;
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,2 +1,2 @@
|
|||||||
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc Mst.cc)
|
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc Mst.cc)
|
||||||
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
|
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
|
@ -1,12 +1,12 @@
|
|||||||
#include "BaseClassifier.h"
|
#include "Classifier.h"
|
||||||
#include "bayesnetUtils.h"
|
#include "bayesnetUtils.h"
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace torch;
|
using namespace torch;
|
||||||
|
|
||||||
BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
|
Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
|
||||||
BaseClassifier& BaseClassifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
|
Classifier& Classifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
|
||||||
{
|
{
|
||||||
dataset = torch::cat({ X, y.view({y.size(0), 1}) }, 1);
|
dataset = torch::cat({ X, y.view({y.size(0), 1}) }, 1);
|
||||||
this->features = features;
|
this->features = features;
|
||||||
@ -20,7 +20,7 @@ namespace bayesnet {
|
|||||||
fitted = true;
|
fitted = true;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
BaseClassifier& BaseClassifier::fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states)
|
Classifier& Classifier::fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states)
|
||||||
{
|
{
|
||||||
this->X = torch::zeros({ static_cast<int64_t>(X[0].size()), static_cast<int64_t>(X.size()) }, kInt64);
|
this->X = torch::zeros({ static_cast<int64_t>(X[0].size()), static_cast<int64_t>(X.size()) }, kInt64);
|
||||||
Xv = X;
|
Xv = X;
|
||||||
@ -31,7 +31,7 @@ namespace bayesnet {
|
|||||||
yv = y;
|
yv = y;
|
||||||
return build(features, className, states);
|
return build(features, className, states);
|
||||||
}
|
}
|
||||||
void BaseClassifier::checkFitParameters()
|
void Classifier::checkFitParameters()
|
||||||
{
|
{
|
||||||
auto sizes = X.sizes();
|
auto sizes = X.sizes();
|
||||||
m = sizes[0];
|
m = sizes[0];
|
||||||
@ -52,7 +52,7 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor BaseClassifier::predict(Tensor& X)
|
Tensor Classifier::predict(Tensor& X)
|
||||||
{
|
{
|
||||||
if (!fitted) {
|
if (!fitted) {
|
||||||
throw logic_error("Classifier has not been fitted");
|
throw logic_error("Classifier has not been fitted");
|
||||||
@ -68,7 +68,7 @@ namespace bayesnet {
|
|||||||
auto ypred = torch::tensor(yp, torch::kInt64);
|
auto ypred = torch::tensor(yp, torch::kInt64);
|
||||||
return ypred;
|
return ypred;
|
||||||
}
|
}
|
||||||
vector<int> BaseClassifier::predict(vector<vector<int>>& X)
|
vector<int> Classifier::predict(vector<vector<int>>& X)
|
||||||
{
|
{
|
||||||
if (!fitted) {
|
if (!fitted) {
|
||||||
throw logic_error("Classifier has not been fitted");
|
throw logic_error("Classifier has not been fitted");
|
||||||
@ -82,7 +82,7 @@ namespace bayesnet {
|
|||||||
auto yp = model.predict(Xd);
|
auto yp = model.predict(Xd);
|
||||||
return yp;
|
return yp;
|
||||||
}
|
}
|
||||||
float BaseClassifier::score(Tensor& X, Tensor& y)
|
float Classifier::score(Tensor& X, Tensor& y)
|
||||||
{
|
{
|
||||||
if (!fitted) {
|
if (!fitted) {
|
||||||
throw logic_error("Classifier has not been fitted");
|
throw logic_error("Classifier has not been fitted");
|
||||||
@ -90,7 +90,7 @@ namespace bayesnet {
|
|||||||
Tensor y_pred = predict(X);
|
Tensor y_pred = predict(X);
|
||||||
return (y_pred == y).sum().item<float>() / y.size(0);
|
return (y_pred == y).sum().item<float>() / y.size(0);
|
||||||
}
|
}
|
||||||
float BaseClassifier::score(vector<vector<int>>& X, vector<int>& y)
|
float Classifier::score(vector<vector<int>>& X, vector<int>& y)
|
||||||
{
|
{
|
||||||
if (!fitted) {
|
if (!fitted) {
|
||||||
throw logic_error("Classifier has not been fitted");
|
throw logic_error("Classifier has not been fitted");
|
||||||
@ -103,11 +103,11 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
return model.score(Xd, y);
|
return model.score(Xd, y);
|
||||||
}
|
}
|
||||||
vector<string> BaseClassifier::show()
|
vector<string> Classifier::show()
|
||||||
{
|
{
|
||||||
return model.show();
|
return model.show();
|
||||||
}
|
}
|
||||||
void BaseClassifier::addNodes()
|
void Classifier::addNodes()
|
||||||
{
|
{
|
||||||
// Add all nodes to the network
|
// Add all nodes to the network
|
||||||
for (auto feature : features) {
|
for (auto feature : features) {
|
||||||
@ -115,12 +115,12 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
model.addNode(className, states[className].size());
|
model.addNode(className, states[className].size());
|
||||||
}
|
}
|
||||||
int BaseClassifier::getNumberOfNodes()
|
int Classifier::getNumberOfNodes()
|
||||||
{
|
{
|
||||||
// Features does not include class
|
// Features does not include class
|
||||||
return fitted ? model.getFeatures().size() + 1 : 0;
|
return fitted ? model.getFeatures().size() + 1 : 0;
|
||||||
}
|
}
|
||||||
int BaseClassifier::getNumberOfEdges()
|
int Classifier::getNumberOfEdges()
|
||||||
{
|
{
|
||||||
return fitted ? model.getEdges().size() : 0;
|
return fitted ? model.getEdges().size() : 0;
|
||||||
}
|
}
|
48
src/BayesNet/Classifier.h
Normal file
48
src/BayesNet/Classifier.h
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
#ifndef CLASSIFIER_H
|
||||||
|
#define CLASSIFIER_H
|
||||||
|
#include <torch/torch.h>
|
||||||
|
#include "BaseClassifier.h"
|
||||||
|
#include "Network.h"
|
||||||
|
#include "BayesMetrics.h"
|
||||||
|
using namespace std;
|
||||||
|
using namespace torch;
|
||||||
|
|
||||||
|
namespace bayesnet {
|
||||||
|
class Classifier : public BaseClassifier {
|
||||||
|
private:
|
||||||
|
bool fitted;
|
||||||
|
Classifier& build(vector<string>& features, string className, map<string, vector<int>>& states);
|
||||||
|
protected:
|
||||||
|
Network model;
|
||||||
|
int m, n; // m: number of samples, n: number of features
|
||||||
|
Tensor X;
|
||||||
|
vector<vector<int>> Xv;
|
||||||
|
Tensor y;
|
||||||
|
vector<int> yv;
|
||||||
|
Tensor dataset;
|
||||||
|
Metrics metrics;
|
||||||
|
vector<string> features;
|
||||||
|
string className;
|
||||||
|
map<string, vector<int>> states;
|
||||||
|
void checkFitParameters();
|
||||||
|
virtual void train() = 0;
|
||||||
|
public:
|
||||||
|
Classifier(Network model);
|
||||||
|
virtual ~Classifier() = default;
|
||||||
|
Classifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
|
||||||
|
void addNodes();
|
||||||
|
int getNumberOfNodes();
|
||||||
|
int getNumberOfEdges();
|
||||||
|
Tensor predict(Tensor& X);
|
||||||
|
vector<int> predict(vector<vector<int>>& X);
|
||||||
|
float score(Tensor& X, Tensor& y);
|
||||||
|
float score(vector<vector<int>>& X, vector<int>& y);
|
||||||
|
vector<string> show();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,20 +1,20 @@
|
|||||||
#ifndef ENSEMBLE_H
|
#ifndef ENSEMBLE_H
|
||||||
#define ENSEMBLE_H
|
#define ENSEMBLE_H
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
#include "BaseClassifier.h"
|
#include "Classifier.h"
|
||||||
#include "BayesMetrics.h"
|
#include "BayesMetrics.h"
|
||||||
#include "bayesnetUtils.h"
|
#include "bayesnetUtils.h"
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace torch;
|
using namespace torch;
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
class Ensemble {
|
class Ensemble : public BaseClassifier {
|
||||||
private:
|
private:
|
||||||
bool fitted;
|
bool fitted;
|
||||||
long n_models;
|
long n_models;
|
||||||
Ensemble& build(vector<string>& features, string className, map<string, vector<int>>& states);
|
Ensemble& build(vector<string>& features, string className, map<string, vector<int>>& states);
|
||||||
protected:
|
protected:
|
||||||
vector<unique_ptr<BaseClassifier>> models;
|
vector<unique_ptr<Classifier>> models;
|
||||||
int m, n; // m: number of samples, n: number of features
|
int m, n; // m: number of samples, n: number of features
|
||||||
Tensor X;
|
Tensor X;
|
||||||
vector<vector<int>> Xv;
|
vector<vector<int>> Xv;
|
||||||
@ -30,13 +30,13 @@ namespace bayesnet {
|
|||||||
public:
|
public:
|
||||||
Ensemble();
|
Ensemble();
|
||||||
virtual ~Ensemble() = default;
|
virtual ~Ensemble() = default;
|
||||||
Ensemble& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
|
Ensemble& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
||||||
Tensor predict(Tensor& X);
|
Tensor predict(Tensor& X);
|
||||||
vector<int> predict(vector<vector<int>>& X);
|
vector<int> predict(vector<vector<int>>& X) override;
|
||||||
float score(Tensor& X, Tensor& y);
|
float score(Tensor& X, Tensor& y);
|
||||||
float score(vector<vector<int>>& X, vector<int>& y);
|
float score(vector<vector<int>>& X, vector<int>& y) override;
|
||||||
vector<string> show();
|
vector<string> show() override;
|
||||||
vector<string> graph(string title);
|
vector<string> graph(string title) override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -4,7 +4,7 @@ namespace bayesnet {
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace torch;
|
using namespace torch;
|
||||||
|
|
||||||
KDB::KDB(int k, float theta) : BaseClassifier(Network()), k(k), theta(theta) {}
|
KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta) {}
|
||||||
void KDB::train()
|
void KDB::train()
|
||||||
{
|
{
|
||||||
/*
|
/*
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
#ifndef KDB_H
|
#ifndef KDB_H
|
||||||
#define KDB_H
|
#define KDB_H
|
||||||
#include "BaseClassifier.h"
|
#include "Classifier.h"
|
||||||
#include "bayesnetUtils.h"
|
#include "bayesnetUtils.h"
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace torch;
|
using namespace torch;
|
||||||
class KDB : public BaseClassifier {
|
class KDB : public Classifier {
|
||||||
private:
|
private:
|
||||||
int k;
|
int k;
|
||||||
float theta;
|
float theta;
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
|
|
||||||
SPODE::SPODE(int root) : BaseClassifier(Network()), root(root) {}
|
SPODE::SPODE(int root) : Classifier(Network()), root(root) {}
|
||||||
|
|
||||||
void SPODE::train()
|
void SPODE::train()
|
||||||
{
|
{
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
#ifndef SPODE_H
|
#ifndef SPODE_H
|
||||||
#define SPODE_H
|
#define SPODE_H
|
||||||
#include "BaseClassifier.h"
|
#include "Classifier.h"
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
class SPODE : public BaseClassifier {
|
class SPODE : public Classifier {
|
||||||
private:
|
private:
|
||||||
int root;
|
int root;
|
||||||
protected:
|
protected:
|
||||||
|
@ -4,7 +4,7 @@ namespace bayesnet {
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace torch;
|
using namespace torch;
|
||||||
|
|
||||||
TAN::TAN() : BaseClassifier(Network()) {}
|
TAN::TAN() : Classifier(Network()) {}
|
||||||
|
|
||||||
void TAN::train()
|
void TAN::train()
|
||||||
{
|
{
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
#ifndef TAN_H
|
#ifndef TAN_H
|
||||||
#define TAN_H
|
#define TAN_H
|
||||||
#include "BaseClassifier.h"
|
#include "Classifier.h"
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace torch;
|
using namespace torch;
|
||||||
class TAN : public BaseClassifier {
|
class TAN : public Classifier {
|
||||||
private:
|
private:
|
||||||
protected:
|
protected:
|
||||||
void train() override;
|
void train() override;
|
||||||
|
@ -11,12 +11,11 @@
|
|||||||
#include "SPODE.h"
|
#include "SPODE.h"
|
||||||
#include "AODE.h"
|
#include "AODE.h"
|
||||||
#include "TAN.h"
|
#include "TAN.h"
|
||||||
|
#include "platformUtils.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
const string PATH = "../../data/";
|
|
||||||
|
|
||||||
inline constexpr auto hash_conv(const std::string_view sv)
|
inline constexpr auto hash_conv(const std::string_view sv)
|
||||||
{
|
{
|
||||||
unsigned long hash{ 5381 };
|
unsigned long hash{ 5381 };
|
||||||
@ -31,31 +30,6 @@ inline constexpr auto operator"" _sh(const char* str, size_t len)
|
|||||||
return hash_conv(std::string_view{ str, len });
|
return hash_conv(std::string_view{ str, len });
|
||||||
}
|
}
|
||||||
|
|
||||||
pair<vector<mdlp::labels_t>, map<string, int>> discretize(vector<mdlp::samples_t>& X, mdlp::labels_t& y, vector<string> features)
|
|
||||||
{
|
|
||||||
vector<mdlp::labels_t>Xd;
|
|
||||||
map<string, int> maxes;
|
|
||||||
|
|
||||||
auto fimdlp = mdlp::CPPFImdlp();
|
|
||||||
for (int i = 0; i < X.size(); i++) {
|
|
||||||
fimdlp.fit(X[i], y);
|
|
||||||
mdlp::labels_t& xd = fimdlp.transform(X[i]);
|
|
||||||
maxes[features[i]] = *max_element(xd.begin(), xd.end()) + 1;
|
|
||||||
Xd.push_back(xd);
|
|
||||||
}
|
|
||||||
return { Xd, maxes };
|
|
||||||
}
|
|
||||||
|
|
||||||
bool file_exists(const std::string& name)
|
|
||||||
{
|
|
||||||
if (FILE* file = fopen(name.c_str(), "r")) {
|
|
||||||
fclose(file);
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
map<string, bool> datasets = {
|
map<string, bool> datasets = {
|
||||||
|
@ -63,7 +63,7 @@ StratifiedKFold::StratifiedKFold(int k, const vector<int>& y, int seed) :
|
|||||||
class_indices[label].erase(class_indices[label].begin(), it);
|
class_indices[label].erase(class_indices[label].begin(), it);
|
||||||
}
|
}
|
||||||
while (remainder_samples_to_take > 0) {
|
while (remainder_samples_to_take > 0) {
|
||||||
int fold = (rand() % static_cast<int>(k));
|
int fold = (arc4random() % static_cast<int>(k));
|
||||||
if (stratified_indices[fold].size() == fold_size) {
|
if (stratified_indices[fold].size() == fold_size) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user