Fix some mistakes in tensors treatment
This commit is contained in:
@@ -18,8 +18,13 @@
|
||||
|
||||
using namespace std;
|
||||
|
||||
Result cross_validation(Fold* fold, bayesnet::BaseClassifier* model, Tensor& X, Tensor& y, vector<string> features, string className, map<string, vector<int>> states)
|
||||
Result cross_validation(Fold* fold, string model_name, Tensor& X, Tensor& y, vector<string> features, string className, map<string, vector<int>> states)
|
||||
{
|
||||
auto classifiers = map<string, bayesnet::BaseClassifier*>({
|
||||
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
||||
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
||||
}
|
||||
);
|
||||
auto result = Result();
|
||||
auto k = fold->getNumberOfFolds();
|
||||
auto accuracy = torch::zeros({ k }, kFloat64);
|
||||
@@ -27,6 +32,7 @@ Result cross_validation(Fold* fold, bayesnet::BaseClassifier* model, Tensor& X,
|
||||
auto test_time = torch::zeros({ k }, kFloat64);
|
||||
Timer train_timer, test_timer;
|
||||
for (int i = 0; i < k; i++) {
|
||||
bayesnet::BaseClassifier* model = classifiers[model_name];
|
||||
train_timer.start();
|
||||
auto [train, test] = fold->getFold(i);
|
||||
auto train_t = torch::tensor(train);
|
||||
@@ -43,8 +49,7 @@ Result cross_validation(Fold* fold, bayesnet::BaseClassifier* model, Tensor& X,
|
||||
cout << "y_test: " << y_test.sizes() << endl;
|
||||
train_time[i] = train_timer.getDuration();
|
||||
test_timer.start();
|
||||
//auto acc = model->score(X_test, y_test);
|
||||
auto acc = 7;
|
||||
auto acc = model->score(X_test, y_test);
|
||||
test_time[i] = test_timer.getDuration();
|
||||
accuracy[i] = acc;
|
||||
}
|
||||
@@ -140,18 +145,16 @@ int main(int argc, char** argv)
|
||||
fold = new StratifiedKFold(n_folds, y, -1);
|
||||
else
|
||||
fold = new KFold(n_folds, y.numel(), -1);
|
||||
auto classifiers = map<string, bayesnet::BaseClassifier*>({
|
||||
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
||||
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
||||
}
|
||||
);
|
||||
|
||||
auto experiment = Experiment();
|
||||
experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform("cpp");
|
||||
experiment.setStratified(stratified).setNFolds(5).addRandomSeed(271).setScoreName("accuracy");
|
||||
bayesnet::BaseClassifier* model = classifiers[model_name];
|
||||
auto result = cross_validation(fold, model, X, y, features, className, states);
|
||||
auto result = cross_validation(fold, model_name, X, y, features, className, states);
|
||||
result.setDataset(file_name);
|
||||
experiment.addResult(result);
|
||||
experiment.save(path);
|
||||
for (auto& item : states) {
|
||||
cout << item.first << ": " << item.second.size() << endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
Reference in New Issue
Block a user