From ac89cefab375cc72ca42f9a7c131be6fb6c0c14d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 18 Feb 2025 12:07:56 +0100 Subject: [PATCH] Add conversion methods --- src/experimental_clfs/XA1DE.cpp | 47 +++++++++++++++++++++++++++++++++ src/experimental_clfs/XA1DE.h | 7 ++++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/src/experimental_clfs/XA1DE.cpp b/src/experimental_clfs/XA1DE.cpp index 1ff388b..cb113d6 100644 --- a/src/experimental_clfs/XA1DE.cpp +++ b/src/experimental_clfs/XA1DE.cpp @@ -61,6 +61,7 @@ namespace platform { std::cout << "* Time to build the model: " << timert.getDuration() << " seconds" << std::endl; // exit(1); } + fitted = true; return *this; } std::vector> XA1DE::predict_proba(std::vector>& test_data) @@ -115,6 +116,9 @@ namespace platform { } std::vector XA1DE::predict(std::vector>& test_data) { + if (!fitted) { + throw std::logic_error(CLASSIFIER_NOT_FITTED); + } auto probabilities = predict_proba(test_data); std::vector predictions(probabilities.size(), 0); @@ -147,4 +151,47 @@ namespace platform { } return static_cast(correct) / predictions.size(); } + std::vector> to_matrix(const torch::Tensor& X) + { + // Ensure tensor is contiguous in memory + auto X_contig = X.contiguous(); + + // Access tensor data pointer directly + auto data_ptr = X_contig.data_ptr(); + + // IF you are using int64_t as the data type, use the following line + //auto data_ptr = X_contig.data_ptr(); + //std::vector> data(X.size(0), std::vector(X.size(1))); + + // Prepare output container + std::vector> data(X.size(0), std::vector(X.size(1))); + + // Fill the 2D vector in a single loop using pointer arithmetic + int rows = X.size(0); + int cols = X.size(1); + for (int i = 0; i < rows; ++i) { + std::copy(data_ptr + i * cols, data_ptr + (i + 1) * cols, data[i].begin()); + } + return data; + } + std::vector to_vector(const torch::Tensor& y) + { + // Ensure the tensor is contiguous in memory + auto y_contig = y.contiguous(); + + // Access data pointer + auto data_ptr = y_contig.data_ptr(); + + // Prepare output container + std::vector data(y.size(0)); + + // Copy data efficiently + std::copy(data_ptr, data_ptr + y.size(0), data.begin()); + + return data; + } + XA1DE& XA1DE::fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states, const bayesnet::Smoothing_t smoothing) + { + return fit(to_matrix(X), to_vector(y), features, className, states, smoothing); + } } \ No newline at end of file diff --git a/src/experimental_clfs/XA1DE.h b/src/experimental_clfs/XA1DE.h index 2f7c59f..640a50d 100644 --- a/src/experimental_clfs/XA1DE.h +++ b/src/experimental_clfs/XA1DE.h @@ -21,16 +21,18 @@ namespace platform { public: XA1DE(); virtual ~XA1DE() = default; + const std::string CLASSIFIER_NOT_FITTED = "Classifier has not been fitted"; std::vector> predict_proba_threads(const std::vector>& test_data); std::vector> predict_proba(std::vector>& X) override; float score(std::vector>& X, std::vector& y) override; + std::vector predict(std::vector>& X) override; XA1DE& fit(std::vector>& X, std::vector& y, const std::vector& features, const std::string& className, std::map>& states, const bayesnet::Smoothing_t smoothing) override; + XA1DE& fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states, const bayesnet::Smoothing_t smoothing) override { return *this; }; XA1DE& fit(torch::Tensor& dataset, const std::vector& features, const std::string& className, std::map>& states, const bayesnet::Smoothing_t smoothing) override { return *this; }; XA1DE& fit(torch::Tensor& dataset, const std::vector& features, const std::string& className, std::map>& states, const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override { return *this; }; torch::Tensor predict(torch::Tensor& X) override { return torch::zeros(0); }; - std::vector predict(std::vector>& X) override; torch::Tensor predict_proba(torch::Tensor& X) override { return torch::zeros(0); }; int getNumberOfNodes() const override { return 0; }; @@ -61,6 +63,8 @@ namespace platform { w = w * num_instances / sum; } } + std::vector to_vector(const torch::Tensor& y); + std::vector> to_matrix(const torch::Tensor& X); Xaode aode_; std::vector weights_; CountingSemaphore& semaphore_; @@ -69,6 +73,7 @@ namespace platform { std::vector notes; bool use_threads = false; std::string version = "0.9.7"; + bool fitted = false; }; } #endif // XA1DE_H \ No newline at end of file