From 619276a5ea945e68e7478a277efd897f415aacb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 10 Mar 2025 21:44:12 +0100 Subject: [PATCH] Update sample_xpode --- bayesnet/classifiers/.XSPODE.h.swp | Bin 12288 -> 0 bytes lib/folding | 2 +- sample/sample_xspode.cc | 14 +------------- tests/lib/catch2 | 2 +- 4 files changed, 3 insertions(+), 15 deletions(-) delete mode 100644 bayesnet/classifiers/.XSPODE.h.swp diff --git a/bayesnet/classifiers/.XSPODE.h.swp b/bayesnet/classifiers/.XSPODE.h.swp deleted file mode 100644 index 024d03cb7f1e6ccd02db405cb17f4cb0787a22d0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12288 zcmeHNO>Epm6rMt9`D;ON0tsnI#IoBayGh}~rd=rMhD2(9=%!VrDq3dk@p_Qgw#Kt* z($d0(BU})01@Y4x65@yu5(RPV9fSl3Zd`inp+(}GS$pl>Zb%|}DYBz4o_OB8dGDL| z-q^~trp_!Z(ih6p1lOa4eDKS)_RY60lds1Kk)6O-%*T)J-1?gDw)q^l1Ifx>U@>ni z$;+w2%4}fs%FAq%M@w8)TDsw*VR;EO6r?=mI_h#(+;AgdSi8r~}UcPXfO^K*)PQ1Gs)aA-@A3165!W_;WuY-vXZj z?*I{SYabyufh)l0z`MW_upih5{Ba*49|Chg9XJSlx0jF*m<65yu7jH&fbW4TKmg1G z$AK4t8ld?*1mxHMmVt5Z+rpuh3hkXtY0eXJiBMdZX3R}DXM?V?RQN&(=7|ejHfOYECy>kZVgSVVnZf|DlOwjWjPJ^{Uv(Gtelto4hD3J=h*?eO{Kk%~} zjjyp?R7A8T+LZf2x9w6DP&e>w$~-TbjWgAiJfeKV;-P}Q)vPXUTUgv^8l7O;+~E;@ zgKjjV5Kj^ZOQRPcVKoZpdNak$QoYU*>%dc&g?xfCM{%rrAUiq|m=Zzi)Tm|ZD@;~E zD?Decp@($LC?3t~b5u=wm|AP2oJ#G7JJvQ1m@`=`QUnsKZLB|Mt?k}yqo!}NvDv7x ztZt|4!R>XPHPv`PEuBtLO0X695td5Fd88!rjLx1s-e@e$E!0mnmzGzXa|^4h^%Gc~ zF{8j7C8I8a#}5>3aqd$`D8=oue@mMlhuYcXG7KALn+{fUYJ!%_WyE;;sJJvaU0S0> z+}HQ5X^Bf24%D@<}5~J9cE?IHqBHM7T9#=HxaKPlPGu)kK*v~ zy}<1-->+6*j{?6y-9f45C_arwWbd~A>%1J?2n%!*yT!{PiXBToD5a*iyHw++rNP_}&s3(5$!iM?@|+ODMEW;Ndk zJz)ujSK1JZQ`8hh;evj>O6omcRE5`YT-}ya?yOhu)XV?pP|F-Ngqwj+$hynBQf$mX0{uTdf+Vw7DaEp1!P)3kQBV z=>vj#cIdxdT18{!#M#M{!opU^Cl_pN%;v45N*5Pa`~7pm>, std::vector, std::vector, states[className] = std::vector(*max_element(y.begin(), y.end()) + 1); iota(begin(states.at(className)), end(states.at(className)), 0); return { Xr, y, features, className, states }; - // Xd = torch::zeros({ static_cast(Xr.size()), static_cast(Xr[0].size()) }, torch::kInt32); - // for (int i = 0; i < features.size(); ++i) { - // states[features[i]] = std::vector(*max_element(Xr[i].begin(), Xr[i].end()) + 1); - // auto item = states.at(features[i]); - // iota(begin(item), end(item), 0); - // Xd.index_put_({ i, "..." }, torch::tensor(Xr[i], torch::kInt32)); - // } - // states[className] = std::vector(*max_element(y.begin(), y.end()) + 1); - // iota(begin(states.at(className)), end(states.at(className)), 0); - // return { Xd, torch::tensor(y, torch::kInt32), features, className, states }; } int main(int argc, char* argv[]) @@ -62,13 +52,11 @@ int main(int argc, char* argv[]) return 1; } std::string file_name = argv[1]; - // auto clf = bayesnet::BoostAODE(false); // false for not using voting in predict - bayesnet::BaseClassifier* clf = new bayesnet::XSpode(0); // false for not using voting in predict + bayesnet::BaseClassifier* clf = new bayesnet::XSpode(0); std::cout << "Library version: " << clf->getVersion() << std::endl; auto [X, y, features, className, states] = loadDataset(file_name, true); torch::Tensor weights = torch::full({ static_cast(X[0].size()) }, 1.0 / X[0].size(), torch::kDouble); clf->fit(X, y, features, className, states, bayesnet::Smoothing_t::ORIGINAL); - // auto score = clf.score(X, y); auto score = clf->score(X, y); std::cout << "File: " << file_name << " Model: XSpode(0) score: " << score << std::endl; delete clf; diff --git a/tests/lib/catch2 b/tests/lib/catch2 index 506276c..0321d2f 160000 --- a/tests/lib/catch2 +++ b/tests/lib/catch2 @@ -1 +1 @@ -Subproject commit 506276c59217429c93abd2fe9507c7f45eb81072 +Subproject commit 0321d2fce328b5e2ad106a8230ff20e0d5bf5501