Update sample_xpode

This commit is contained in:
2025-03-10 21:44:12 +01:00
parent e681099360
commit 619276a5ea
4 changed files with 3 additions and 15 deletions

View File

@@ -43,16 +43,6 @@ tuple<std::vector<std::vector<int>>, std::vector<int>, std::vector<std::string>,
states[className] = std::vector<int>(*max_element(y.begin(), y.end()) + 1);
iota(begin(states.at(className)), end(states.at(className)), 0);
return { Xr, y, features, className, states };
// Xd = torch::zeros({ static_cast<int>(Xr.size()), static_cast<int>(Xr[0].size()) }, torch::kInt32);
// for (int i = 0; i < features.size(); ++i) {
// states[features[i]] = std::vector<int>(*max_element(Xr[i].begin(), Xr[i].end()) + 1);
// auto item = states.at(features[i]);
// iota(begin(item), end(item), 0);
// Xd.index_put_({ i, "..." }, torch::tensor(Xr[i], torch::kInt32));
// }
// states[className] = std::vector<int>(*max_element(y.begin(), y.end()) + 1);
// iota(begin(states.at(className)), end(states.at(className)), 0);
// return { Xd, torch::tensor(y, torch::kInt32), features, className, states };
}
int main(int argc, char* argv[])
@@ -62,13 +52,11 @@ int main(int argc, char* argv[])
return 1;
}
std::string file_name = argv[1];
// auto clf = bayesnet::BoostAODE(false); // false for not using voting in predict
bayesnet::BaseClassifier* clf = new bayesnet::XSpode(0); // false for not using voting in predict
bayesnet::BaseClassifier* clf = new bayesnet::XSpode(0);
std::cout << "Library version: " << clf->getVersion() << std::endl;
auto [X, y, features, className, states] = loadDataset(file_name, true);
torch::Tensor weights = torch::full({ static_cast<long>(X[0].size()) }, 1.0 / X[0].size(), torch::kDouble);
clf->fit(X, y, features, className, states, bayesnet::Smoothing_t::ORIGINAL);
// auto score = clf.score(X, y);
auto score = clf->score(X, y);
std::cout << "File: " << file_name << " Model: XSpode(0) score: " << score << std::endl;
delete clf;