From d1b235261ec71155d9595e87ad5cdcb5a204895a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 10 Mar 2025 14:21:01 +0100 Subject: [PATCH] Fix XSpode --- bayesnet/classifiers/.XSPODE.h.swp | Bin 0 -> 12288 bytes bayesnet/classifiers/XSPODE.cc | 18 +++++++++++++----- bayesnet/classifiers/XSPODE.h | 2 +- 3 files changed, 14 insertions(+), 6 deletions(-) create mode 100644 bayesnet/classifiers/.XSPODE.h.swp diff --git a/bayesnet/classifiers/.XSPODE.h.swp b/bayesnet/classifiers/.XSPODE.h.swp new file mode 100644 index 0000000000000000000000000000000000000000..024d03cb7f1e6ccd02db405cb17f4cb0787a22d0 GIT binary patch literal 12288 zcmeHNO>Epm6rMt9`D;ON0tsnI#IoBayGh}~rd=rMhD2(9=%!VrDq3dk@p_Qgw#Kt* z($d0(BU})01@Y4x65@yu5(RPV9fSl3Zd`inp+(}GS$pl>Zb%|}DYBz4o_OB8dGDL| z-q^~trp_!Z(ih6p1lOa4eDKS)_RY60lds1Kk)6O-%*T)J-1?gDw)q^l1Ifx>U@>ni z$;+w2%4}fs%FAq%M@w8)TDsw*VR;EO6r?=mI_h#(+;AgdSi8r~}UcPXfO^K*)PQ1Gs)aA-@A3165!W_;WuY-vXZj z?*I{SYabyufh)l0z`MW_upih5{Ba*49|Chg9XJSlx0jF*m<65yu7jH&fbW4TKmg1G z$AK4t8ld?*1mxHMmVt5Z+rpuh3hkXtY0eXJiBMdZX3R}DXM?V?RQN&(=7|ejHfOYECy>kZVgSVVnZf|DlOwjWjPJ^{Uv(Gtelto4hD3J=h*?eO{Kk%~} zjjyp?R7A8T+LZf2x9w6DP&e>w$~-TbjWgAiJfeKV;-P}Q)vPXUTUgv^8l7O;+~E;@ zgKjjV5Kj^ZOQRPcVKoZpdNak$QoYU*>%dc&g?xfCM{%rrAUiq|m=Zzi)Tm|ZD@;~E zD?Decp@($LC?3t~b5u=wm|AP2oJ#G7JJvQ1m@`=`QUnsKZLB|Mt?k}yqo!}NvDv7x ztZt|4!R>XPHPv`PEuBtLO0X695td5Fd88!rjLx1s-e@e$E!0mnmzGzXa|^4h^%Gc~ zF{8j7C8I8a#}5>3aqd$`D8=oue@mMlhuYcXG7KALn+{fUYJ!%_WyE;;sJJvaU0S0> z+}HQ5X^Bf24%D@<}5~J9cE?IHqBHM7T9#=HxaKPlPGu)kK*v~ zy}<1-->+6*j{?6y-9f45C_arwWbd~A>%1J?2n%!*yT!{PiXBToD5a*iyHw++rNP_}&s3(5$!iM?@|+ODMEW;Ndk zJz)ujSK1JZQ`8hh;evj>O6omcRE5`YT-}ya?yOhu)XV?pP|F-Ngqwj+$hynBQf$mX0{uTdf+Vw7DaEp1!P)3kQBV z=>vj#cIdxdT18{!#M#M{!opU^Cl_pN%;v45N*5Pa`~7pm> XSpode::predict_proba(const std::vector>& test_data) + std::vector> XSpode::predict_proba(std::vector>& test_data) { int test_size = test_data[0].size(); int sample_size = test_data.size(); @@ -390,14 +390,22 @@ namespace bayesnet { torch::Tensor XSpode::predict(torch::Tensor& X) { auto X_ = TensorUtils::to_matrix(X); - auto result = predict(X_); - return TensorUtils::to_tensor(result); + auto result_v = predict(X_); + torch::Tensor result; + for (int i = 0; i < result_v.size(); ++i) { + result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kInt32)); + } + return result; } torch::Tensor XSpode::predict_proba(torch::Tensor& X) { auto X_ = TensorUtils::to_matrix(X); - auto result = predict_proba(X_); - return TensorUtils::to_tensor(result); + auto result_v = predict_proba(X_); + torch::Tensor result; + for (int i = 0; i < result_v.size(); ++i) { + result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kDouble)); + } + return result; } torch::Tensor XSpode::predict(torch::Tensor& X) { diff --git a/bayesnet/classifiers/XSPODE.h b/bayesnet/classifiers/XSPODE.h index 7f57b22..fe29f34 100644 --- a/bayesnet/classifiers/XSPODE.h +++ b/bayesnet/classifiers/XSPODE.h @@ -18,7 +18,7 @@ namespace bayesnet { public: explicit XSpode(int spIndex); std::vector predict_proba(const std::vector& instance) const; - std::vector> predict_proba(const std::vector>& test_data); + std::vector> predict_proba(std::vector>& X) override; int predict(const std::vector& instance) const; void normalize(std::vector& v) const; std::string to_string() const;