diff --git a/src/experimental_clfs/XA1DE.cpp b/src/experimental_clfs/XA1DE.cpp index 72d2ac2..1ff388b 100644 --- a/src/experimental_clfs/XA1DE.cpp +++ b/src/experimental_clfs/XA1DE.cpp @@ -22,22 +22,21 @@ namespace platform { throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump()); } } - void XA1DE::fit(std::vector> X, std::vector y, std::vector weights) + XA1DE& XA1DE::fit(std::vector>& X, std::vector& y, const std::vector& features, const std::string& className, std::map>& states, const bayesnet::Smoothing_t smoothing) { Timer timer, timert; timer.start(); timert.start(); - weights_ = weights; std::vector> instances = X; instances.push_back(y); int num_instances = instances[0].size(); int num_attributes = instances.size(); normalize_weights(num_instances); - std::vector states; + std::vector statesv; for (int i = 0; i < num_attributes; i++) { - states.push_back(*max_element(instances[i].begin(), instances[i].end()) + 1); + statesv.push_back(*max_element(instances[i].begin(), instances[i].end()) + 1); } - aode_.init(states); + aode_.init(statesv); aode_.duration_first += timer.getDuration(); timer.start(); std::vector instance; for (int n_instance = 0; n_instance < num_instances; n_instance++) { @@ -62,6 +61,7 @@ namespace platform { std::cout << "* Time to build the model: " << timert.getDuration() << " seconds" << std::endl; // exit(1); } + return *this; } std::vector> XA1DE::predict_proba(std::vector>& test_data) { diff --git a/src/experimental_clfs/XA1DE.h b/src/experimental_clfs/XA1DE.h index 6839f96..2f7c59f 100644 --- a/src/experimental_clfs/XA1DE.h +++ b/src/experimental_clfs/XA1DE.h @@ -21,34 +21,34 @@ namespace platform { public: XA1DE(); virtual ~XA1DE() = default; - void setDebug(bool debug) { this->debug = debug; } - std::vector> predict_proba_threads(const std::vector>& test_data); + std::vector> predict_proba_threads(const std::vector>& test_data); + std::vector> predict_proba(std::vector>& X) override; + float score(std::vector>& X, std::vector& y) override; XA1DE& fit(std::vector>& X, std::vector& y, const std::vector& features, const std::string& className, std::map>& states, const bayesnet::Smoothing_t smoothing) override; - XA1DE& fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states, const bayesnet::Smoothing_t smoothing) override; - XA1DE& fit(torch::Tensor& dataset, const std::vector& features, const std::string& className, std::map>& states, const bayesnet::Smoothing_t smoothing) override; - XA1DE& fit(torch::Tensor& dataset, const std::vector& features, const std::string& className, std::map>& states, const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override; + XA1DE& fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states, const bayesnet::Smoothing_t smoothing) override { return *this; }; + XA1DE& fit(torch::Tensor& dataset, const std::vector& features, const std::string& className, std::map>& states, const bayesnet::Smoothing_t smoothing) override { return *this; }; + XA1DE& fit(torch::Tensor& dataset, const std::vector& features, const std::string& className, std::map>& states, const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override { return *this; }; + torch::Tensor predict(torch::Tensor& X) override { return torch::zeros(0); }; + std::vector predict(std::vector>& X) override; + torch::Tensor predict_proba(torch::Tensor& X) override { return torch::zeros(0); }; + int getNumberOfNodes() const override { return 0; }; int getNumberOfEdges() const override { return 0; }; int getNumberOfStates() const override { return 0; }; int getClassNumStates() const override { return 0; }; - torch::Tensor predict(torch::Tensor& X) override { return torch::zeros(0); }; - std::vector predict(std::vector>& X) override; - torch::Tensor predict_proba(torch::Tensor& X) override { return torch::zeros(0); }; - std::vector> predict_proba(std::vector>& X) override; bayesnet::status_t getStatus() const override { return status; } - std::string getVersion() override { return { project_version.begin(), project_version.end() }; }; + std::string getVersion() override { return version; }; float score(torch::Tensor& X, torch::Tensor& y) override { return 0; }; - float score(std::vector>& X, std::vector& y) override; std::vector show() const override { return {}; } std::vector topological_order() override { return {}; } std::vector getNotes() const override { return notes; } std::string dump_cpt() const override { return ""; } void setHyperparameters(const nlohmann::json& hyperparameters) override; - std::vector& getValidHyperparameters() { return validHyperparameters; } + void setDebug(bool debug) { this->debug = debug; } protected: - void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override; + void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override {}; private: inline void normalize_weights(int num_instances) @@ -61,7 +61,6 @@ namespace platform { w = w * num_instances / sum; } } - // The instances of the dataset Xaode aode_; std::vector weights_; CountingSemaphore& semaphore_; @@ -69,6 +68,7 @@ namespace platform { bayesnet::status_t status = bayesnet::NORMAL; std::vector notes; bool use_threads = false; + std::string version = "0.9.7"; }; } #endif // XA1DE_H \ No newline at end of file