// *************************************************************** // SPDX-FileCopyrightText: Copyright 2025 Ricardo Montañana Gómez // SPDX-FileType: SOURCE // SPDX-License-Identifier: MIT // *************************************************************** #ifndef EXPCLF_H #define EXPCLF_H #include #include #include #include #include #include #include #include "common/Timer.hpp" #include "CountingSemaphore.hpp" #include "Xaode.hpp" namespace platform { class ExpClf : public bayesnet::Boost { public: ExpClf(); virtual ~ExpClf() = default; std::vector predict(std::vector>& X) override; torch::Tensor predict(torch::Tensor& X) override; torch::Tensor predict_proba(torch::Tensor& X) override; std::vector predict_spode(std::vector>& test_data, int parent); std::vector> predict_proba(const std::vector>& X); float score(std::vector>& X, std::vector& y) override; float score(torch::Tensor& X, torch::Tensor& y) override; int getNumberOfNodes() const override; int getNumberOfEdges() const override; int getNumberOfStates() const override; int getClassNumStates() const override; std::vector show() const override { return {}; } std::vector topological_order() override { return {}; } std::string dump_cpt() const override { return ""; } void setDebug(bool debug) { this->debug = debug; } bayesnet::status_t getStatus() const override { return status; } std::vector getNotes() const override { return notes; } std::vector graph(const std::string& title = "") const override { return {}; } void add_active_parents(const std::vector& active_parents); void add_active_parent(int parent); void remove_last_parent(); protected: bool debug = false; Xaode aode_; torch::Tensor weights_; const std::string CLASSIFIER_NOT_FITTED = "Classifier has not been fitted"; inline void normalize_weights(int num_instances) { double sum = weights_.sum().item(); if (sum == 0) { weights_ = torch::full({ num_instances }, 1.0); } else { for (int i = 0; i < weights_.size(0); ++i) { weights_[i] = weights_[i].item() * num_instances / sum; } } } private: CountingSemaphore& semaphore_; }; } #endif // EXPCLF_H