Fix XSpode
This commit is contained in:
@@ -390,11 +390,7 @@ namespace bayesnet {
|
||||
{
|
||||
auto X_ = TensorUtils::to_matrix(X);
|
||||
auto result_v = predict(X_);
|
||||
torch::Tensor result;
|
||||
for (int i = 0; i < result_v.size(); ++i) {
|
||||
result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kInt32));
|
||||
}
|
||||
return result;
|
||||
return torch::tensor(result_v, torch::kInt32);
|
||||
}
|
||||
torch::Tensor XSpode::predict_proba(torch::Tensor& X)
|
||||
{
|
||||
|
@@ -27,16 +27,16 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
|
||||
{
|
||||
map <pair<std::string, std::string>, float> scores{
|
||||
// Diabetes
|
||||
{{"diabetes", "AODE"}, 0.82161}, {{"diabetes", "KDB"}, 0.852865}, {{"diabetes", "XSPODE"}, 0.802083}, {{"diabetes", "SPODE"}, 0.802083}, {{"diabetes", "TAN"}, 0.821615},
|
||||
{{"diabetes", "AODE"}, 0.82161}, {{"diabetes", "KDB"}, 0.852865}, {{"diabetes", "XSPODE"}, 0.631510437f}, {{"diabetes", "SPODE"}, 0.802083}, {{"diabetes", "TAN"}, 0.821615},
|
||||
{{"diabetes", "AODELd"}, 0.8125f}, {{"diabetes", "KDBLd"}, 0.80208f}, {{"diabetes", "SPODELd"}, 0.7890625f}, {{"diabetes", "TANLd"}, 0.803385437f}, {{"diabetes", "BoostAODE"}, 0.83984f},
|
||||
// Ecoli
|
||||
{{"ecoli", "AODE"}, 0.889881}, {{"ecoli", "KDB"}, 0.889881}, {{"ecoli", "XSPODE"}, 0.880952}, {{"ecoli", "SPODE"}, 0.880952}, {{"ecoli", "TAN"}, 0.892857},
|
||||
{{"ecoli", "AODE"}, 0.889881}, {{"ecoli", "KDB"}, 0.889881}, {{"ecoli", "XSPODE"}, 0.696428597f}, {{"ecoli", "SPODE"}, 0.880952}, {{"ecoli", "TAN"}, 0.892857},
|
||||
{{"ecoli", "AODELd"}, 0.875f}, {{"ecoli", "KDBLd"}, 0.880952358f}, {{"ecoli", "SPODELd"}, 0.839285731f}, {{"ecoli", "TANLd"}, 0.848214269f}, {{"ecoli", "BoostAODE"}, 0.89583f},
|
||||
// Glass
|
||||
{{"glass", "AODE"}, 0.79439}, {{"glass", "KDB"}, 0.827103}, {{"glass", "XSPODE"}, 0.775701}, {{"glass", "SPODE"}, 0.775701}, {{"glass", "TAN"}, 0.827103},
|
||||
{{"glass", "AODE"}, 0.79439}, {{"glass", "KDB"}, 0.827103}, {{"glass", "XSPODE"}, 0.775701}, {{"glass", "SPODE"}, 0.775701}, {{"glass", "TAN"}, 0.827103},
|
||||
{{"glass", "AODELd"}, 0.799065411f}, {{"glass", "KDBLd"}, 0.82710278f}, {{"glass", "SPODELd"}, 0.780373812f}, {{"glass", "TANLd"}, 0.869158864f}, {{"glass", "BoostAODE"}, 0.84579f},
|
||||
// Iris
|
||||
{{"iris", "AODE"}, 0.973333}, {{"iris", "KDB"}, 0.973333}, {{"iris", "XSPODE"}, 0.973333}, {{"iris", "SPODE"}, 0.973333}, {{"iris", "TAN"}, 0.973333},
|
||||
{{"iris", "AODE"}, 0.973333}, {{"iris", "KDB"}, 0.973333}, {{"iris", "XSPODE"}, 0.853333354f}, {{"iris", "SPODE"}, 0.973333}, {{"iris", "TAN"}, 0.973333},
|
||||
{{"iris", "AODELd"}, 0.973333}, {{"iris", "KDBLd"}, 0.973333}, {{"iris", "SPODELd"}, 0.96f}, {{"iris", "TANLd"}, 0.97333f}, {{"iris", "BoostAODE"}, 0.98f}
|
||||
};
|
||||
std::map<std::string, bayesnet::BaseClassifier*> models{
|
||||
@@ -46,8 +46,7 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
|
||||
{"XSPODE", new bayesnet::XSpode(1)}, {"SPODE", new bayesnet::SPODE(1)}, {"SPODELd", new bayesnet::SPODELd(1)},
|
||||
{"TAN", new bayesnet::TAN()}, {"TANLd", new bayesnet::TANLd()}
|
||||
};
|
||||
// std::string name = GENERATE("AODE", "AODELd", "KDB", "KDBLd", "SPODE", "XSPODE", "SPODELd", "TAN", "TANLd");
|
||||
std::string name = GENERATE("XSPODE");
|
||||
std::string name = GENERATE("AODE", "AODELd", "KDB", "KDBLd", "SPODE", "XSPODE", "SPODELd", "TAN", "TANLd");
|
||||
auto clf = models[name];
|
||||
|
||||
SECTION("Test " + name + " classifier")
|
||||
@@ -56,14 +55,9 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
|
||||
auto clf = models[name];
|
||||
auto discretize = name.substr(name.length() - 2) != "Ld";
|
||||
auto raw = RawDatasets(file_name, discretize);
|
||||
if (name == "XSPODE") {
|
||||
std::cout << "Fitting XSPODE" << std::endl;
|
||||
} else {
|
||||
std::cout << "Fitting something else [" << name << "]" << std::endl;
|
||||
}
|
||||
clf->fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
clf->fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
auto score = clf->score(raw.Xt, raw.yt);
|
||||
std::cout << "Classifier: " << name << " File: " << file_name << " Score: " << score << " expected = " << scores[{file_name, name}] << std::endl;
|
||||
// std::cout << "Classifier: " << name << " File: " << file_name << " Score: " << score << " expected = " << scores[{file_name, name}] << std::endl;
|
||||
INFO("Classifier: " << name << " File: " << file_name);
|
||||
REQUIRE(score == Catch::Approx(scores[{file_name, name}]).epsilon(raw.epsilon));
|
||||
REQUIRE(clf->getStatus() == bayesnet::NORMAL);
|
||||
|
Reference in New Issue
Block a user