#include "ArffFiles.h" #include "CPPFImdlp.h" #include "BoostAODE.h" std::vector discretizeDataset(std::vector& X, mdlp::labels_t& y) { std::vector Xd; auto fimdlp = mdlp::CPPFImdlp(); for (int i = 0; i < X.size(); i++) { fimdlp.fit(X[i], y); mdlp::labels_t& xd = fimdlp.transform(X[i]); Xd.push_back(xd); } return Xd; } tuple, std::string, map>> loadDataset(const std::string& name, bool class_last) { auto handler = ArffFiles(); handler.load(name, class_last); // Get Dataset X, y std::vector& X = handler.getX(); mdlp::labels_t& y = handler.getY(); // Get className & Features auto className = handler.getClassName(); std::vector features; auto attributes = handler.getAttributes(); transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& pair) { return pair.first; }); torch::Tensor Xd; auto states = map>(); auto Xr = discretizeDataset(X, y); Xd = torch::zeros({ static_cast(Xr.size()), static_cast(Xr[0].size()) }, torch::kInt32); for (int i = 0; i < features.size(); ++i) { states[features[i]] = std::vector(*max_element(Xr[i].begin(), Xr[i].end()) + 1); auto item = states.at(features[i]); iota(begin(item), end(item), 0); Xd.index_put_({ i, "..." }, torch::tensor(Xr[i], torch::kInt32)); } states[className] = std::vector(*max_element(y.begin(), y.end()) + 1); iota(begin(states.at(className)), end(states.at(className)), 0); return { Xd, torch::tensor(y, torch::kInt32), features, className, states }; } int main(int argc, char* argv[]) { if (argc < 2) { std::cerr << "Usage: " << argv[0] << " " << std::endl; return 1; } std::string file_name = argv[1]; torch::Tensor X, y; std::vector features; std::string className; map> states; auto clf = bayesnet::BoostAODE(false); // false for not using voting in predict std::cout << "Library version: " << clf.getVersion() << std::endl; tie(X, y, features, className, states) = loadDataset(file_name, true); clf.fit(X, y, features, className, states); auto score = clf.score(X, y); std::cout << "File: " << file_name << " score: " << score << std::endl; return 0; }