diff --git a/tests/TestPythonClassifiers.cc b/tests/TestPythonClassifiers.cc index 9c42f67..8c2b22c 100644 --- a/tests/TestPythonClassifiers.cc +++ b/tests/TestPythonClassifiers.cc @@ -17,11 +17,11 @@ TEST_CASE("Test Python Classifiers score", "[PyClassifiers]") { map , float> scores = { // Diabetes - {{"diabetes", "STree"}, 0.81641}, {{"diabetes", "ODTE"}, 0.84635}, {{"diabetes", "SVC"}, 0.76823}, {{"diabetes", "RandomForest"}, 1.0}, + {{"diabetes", "STree"}, 0.81641}, {{"diabetes", "ODTE"}, 0.854166687}, {{"diabetes", "SVC"}, 0.76823}, {{"diabetes", "RandomForest"}, 1.0}, // Ecoli - {{"ecoli", "STree"}, 0.8125}, {{"ecoli", "ODTE"}, 0.84821}, {{"ecoli", "SVC"}, 0.89583}, {{"ecoli", "RandomForest"}, 1.0}, + {{"ecoli", "STree"}, 0.8125}, {{"ecoli", "ODTE"}, 0.875}, {{"ecoli", "SVC"}, 0.89583}, {{"ecoli", "RandomForest"}, 1.0}, // Glass - {{"glass", "STree"}, 0.57009}, {{"glass", "ODTE"}, 0.77103}, {{"glass", "SVC"}, 0.35514}, {{"glass", "RandomForest"}, 1.0}, + {{"glass", "STree"}, 0.57009}, {{"glass", "ODTE"}, 0.76168227}, {{"glass", "SVC"}, 0.35514}, {{"glass", "RandomForest"}, 1.0}, // Iris {{"iris", "STree"}, 0.99333}, {{"iris", "ODTE"}, 0.98667}, {{"iris", "SVC"}, 0.97333}, {{"iris", "RandomForest"}, 1.0}, }; @@ -33,10 +33,10 @@ TEST_CASE("Test Python Classifiers score", "[PyClassifiers]") {"RandomForest", new pywrap::RandomForest()} }; map versions = { - {"ODTE", "0.3.6"}, + {"ODTE", "1.0.0"}, {"STree", "1.3.2"}, - {"SVC", "1.5.0"}, - {"RandomForest", "1.5.0"} + {"SVC", "1.5.1"}, + {"RandomForest", "1.5.1"} }; auto clf = models[name]; @@ -68,8 +68,10 @@ TEST_CASE("Classifiers features", "[PyClassifiers]") } TEST_CASE("Get num features & num edges", "[PyClassifiers]") { + auto estimators = nlohmann::json::parse("{ \"n_estimators\": 10 }"); auto raw = RawDatasets("iris", false); auto clf = pywrap::ODTE(); + clf.setHyperparameters(estimators); clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest); REQUIRE(clf.getNumberOfNodes() == 50); REQUIRE(clf.getNumberOfEdges() == 30); @@ -115,22 +117,22 @@ TEST_CASE("XGBoost", "[PyClassifiers]") auto score = clf.score(raw.Xt, raw.yt); REQUIRE(score == Catch::Approx(0.98).epsilon(raw.epsilon)); } -TEST_CASE("XGBoost predict proba", "[PyClassifiers]") -{ - auto raw = RawDatasets("iris", true); - auto clf = pywrap::XGBoost(); - clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest); - nlohmann::json hyperparameters = { "n_jobs=1" }; - clf.setHyperparameters(hyperparameters); - auto predict = clf.predict(raw.Xt); - // for (int row = 0; row < predict.size(0); row++) { - // auto sum = 0.0; - // for (int col = 0; col < predict.size(1); col++) { - // std::cout << std::setw(12) << std::setprecision(10) << predict[row][col].item() << " "; - // sum += predict[row][col].item(); - // } - // std::cout << std::endl; - // // REQUIRE(sum == Catch::Approx(1.0).epsilon(raw.epsilon)); - // } - std::cout << predict << std::endl; -} \ No newline at end of file +// TEST_CASE("XGBoost predict proba", "[PyClassifiers]") +// { +// auto raw = RawDatasets("iris", true); +// auto clf = pywrap::XGBoost(); +// clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest); +// // nlohmann::json hyperparameters = { "n_jobs=1" }; +// // clf.setHyperparameters(hyperparameters); +// auto predict = clf.predict(raw.Xt); +// for (int row = 0; row < predict.size(0); row++) { +// auto sum = 0.0; +// for (int col = 0; col < predict.size(1); col++) { +// std::cout << std::setw(12) << std::setprecision(10) << predict[row][col].item() << " "; +// sum += predict[row][col].item(); +// } +// std::cout << std::endl; +// // REQUIRE(sum == Catch::Approx(1.0).epsilon(raw.epsilon)); +// } +// std::cout << predict << std::endl; +// } \ No newline at end of file diff --git a/tests/data/glass.arff b/tests/data/glass.arff index 3bcb091..abd9e3c 100755 --- a/tests/data/glass.arff +++ b/tests/data/glass.arff @@ -114,7 +114,7 @@ @attribute 'Ca' real @attribute 'Ba' real @attribute 'Fe' real -@attribute 'Type' {'build wind float', 'build wind non-float', 'vehic wind float', 'vehic wind non-float', containers, tableware, headlamps} +@attribute 'Type' { 'build wind float', 'build wind non-float', 'vehic wind float', 'vehic wind non-float', containers, tableware, headlamps} @data 1.51793,12.79,3.5,1.12,73.03,0.64,8.77,0,0,'build wind float' 1.51643,12.16,3.52,1.35,72.89,0.57,8.53,0,0,'vehic wind float'