diff --git a/src/experimental_clfs/Xaode.hpp b/src/experimental_clfs/Xaode.hpp index 37a2101..cb021e0 100644 --- a/src/experimental_clfs/Xaode.hpp +++ b/src/experimental_clfs/Xaode.hpp @@ -38,43 +38,6 @@ namespace platform { // Classifiers interface // all parameter decide if the model is initialized with all the parents active or none of them // - void fit(std::vector>& X, std::vector& y, const std::vector& features, const std::string& className, std::map>& states, const torch::Tensor& weights, const bool all_parents) - { - int num_instances = X[0].size(); - int n_features_ = X.size(); - - significance_models_.resize(n_features_, (all_parents ? 1.0 : 0.0)); - std::vector statesv; - for (int i = 0; i < n_features_; i++) { - if (all_parents) active_parents.push_back(i); - statesv.push_back(*max_element(X[i].begin(), X[i].end()) + 1); - } - statesv.push_back(*max_element(y.begin(), y.end()) + 1); - // std::cout << "* States: " << statesv << std::endl; - // std::cout << "* Weights: " << weights_ << std::endl; - // std::cout << "* Instances: " << num_instances << std::endl; - // std::cout << "* Attributes: " << n_features_ +1 << std::endl; - // std::cout << "* y: " << y << std::endl; - // std::cout << "* x shape: " << X.size() << "x" << X[0].size() << std::endl; - // for (int i = 0; i < n_features_; i++) { - // std::cout << "* " << features[i] << ": " << instances[i] << std::endl; - // } - // std::cout << "Starting to build the model" << std::endl; - init(statesv); - std::vector instance(n_features_ + 1); - for (int n_instance = 0; n_instance < num_instances; n_instance++) { - for (int feature = 0; feature < n_features_; feature++) { - instance[feature] = X[feature][n_instance]; - } - instance[n_features_] = y[n_instance]; - addSample(instance, weights[n_instance].item()); - } - computeProbabilities(); - } - // ------------------------------------------------------- - // init - // ------------------------------------------------------- - // // states.size() = nFeatures + 1, // where states.back() = number of class states. // @@ -84,29 +47,24 @@ namespace platform { // // Internally, in COUNTS mode, data_ accumulates raw counts, then // computeProbabilities(...) normalizes them into conditionals. - // - void init(const std::vector& states) + void fit(std::vector>& X, std::vector& y, const std::vector& features, const std::string& className, std::map>& states, const torch::Tensor& weights, const bool all_parents) { - // - // Check Valid input data - // - if (matrixState_ != MatrixState::EMPTY) { - throw std::logic_error("Xaode: already initialized."); - } - states_ = states; - nFeatures_ = static_cast(states_.size()) - 1; - if (nFeatures_ < 1) { - throw std::invalid_argument("Xaode: need at least 1 feature plus class states."); + int num_instances = X[0].size(); + nFeatures_ = X.size(); + + significance_models_.resize(nFeatures_, (all_parents ? 1.0 : 0.0)); + for (int i = 0; i < nFeatures_; i++) { + if (all_parents) active_parents.push_back(i); + states_.push_back(*max_element(X[i].begin(), X[i].end()) + 1); } + states_.push_back(*max_element(y.begin(), y.end()) + 1); + // statesClass_ = states_.back(); - if (statesClass_ <= 0) { - throw std::invalid_argument("Xaode: class states must be > 0."); - } // // Initialize data structures // active_parents.resize(nFeatures_); - int totalStates = std::accumulate(states.begin(), states.end(), 0) - statesClass_; + int totalStates = std::accumulate(states_.begin(), states_.end(), 0) - statesClass_; // For p(x_i=si | c), we store them in a 1D array classFeatureProbs_ after we compute. // We'll need the offsets for each feature i in featureClassOffset_. @@ -140,14 +98,19 @@ namespace platform { classCounts_.resize(statesClass_, 0.0); matrixState_ = MatrixState::COUNTS; + // + // Add samples + // + std::vector instance(nFeatures_ + 1); + for (int n_instance = 0; n_instance < num_instances; n_instance++) { + for (int feature = 0; feature < nFeatures_; feature++) { + instance[feature] = X[feature][n_instance]; + } + instance[nFeatures_] = y[n_instance]; + addSample(instance, weights[n_instance].item()); + } + computeProbabilities(); } - - // Returns current mode: INIT, COUNTS or PROBS - MatrixState state() const - { - return matrixState_; - } - // Optional: print a quick summary void show() const { @@ -171,7 +134,6 @@ namespace platform { for (double d : data_) std::cout << d << " "; std::cout << std::endl; std::cout << "--------------------------------" << std::endl; } - // ------------------------------------------------------- // addSample (only in COUNTS mode) // ------------------------------------------------------- @@ -424,7 +386,11 @@ namespace platform { probs[i] /= sum; } } - + // Returns current mode: INIT, COUNTS or PROBS + MatrixState state() const + { + return matrixState_; + } int statesClass() const { return statesClass_;