diff --git a/src/main.cc b/src/main.cc index 56b0015..82f44de 100644 --- a/src/main.cc +++ b/src/main.cc @@ -44,19 +44,19 @@ int main(int argc, char* argv[]) { cout << "* Begin." << endl; { - auto [X, y, features, className, states] = loadDataset("iris", true); + auto [X, y, features, className, states] = loadDataset("wine", true); cout << "X: " << X.sizes() << endl; cout << "y: " << y.sizes() << endl; auto clf = pywrap::STree(); cout << "STree Version: " << clf.version() << endl; - // if (true) { - // auto svc = pywrap::SVC(); - // svc.fit(X, y, features, className, states); - // cout << "SVC Score: " << svc.score(X, y) << endl; - // } - // cout << "Graph: " << endl << clf.graph() << endl; + if (true) { + auto svc = pywrap::SVC(); + svc.fit(X, y, features, className, states); + cout << "SVC Score: " << svc.score(X, y) << endl; + } + cout << "Graph: " << endl << clf.graph() << endl; clf.fit(X, y, features, className, states); - // cout << "STree Score: " << clf.score(X, y) << endl; + cout << "STree Score: " << clf.score(X, y) << endl; auto prediction = clf.predict(X); cout << "Prediction: " << endl << "{"; for (int i = 0; i < prediction.size(0); ++i) {