First implemented aproximation

This commit is contained in:
2025-01-31 13:55:46 +01:00
parent b90e558238
commit fb957ac3fe
4 changed files with 353 additions and 1 deletions

View File

@@ -0,0 +1,53 @@
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#ifndef WA2DE_H
#define WA2DE_H
#include "Ensemble.h"
#include <torch/torch.h>
#include <vector>
#include <map>
#include <nlohmann/json.hpp>
namespace bayesnet {
/**
* Geoffrey I. Webb's A2DE (Averaged 2-Dependence Estimators) classifier
* Implements the A2DE algorithm as an ensemble of SPODE models.
*/
class WA2DE : public Ensemble {
public:
explicit WA2DE(bool predict_voting = false);
virtual ~WA2DE() {};
// Override method to set hyperparameters
void setHyperparameters(const nlohmann::json& hyperparameters) override;
// Graph visualization function
std::vector<std::string> graph(const std::string& title = "A2DE") const override;
torch::Tensor computeProbabilities(const torch::Tensor& data) const;
double score(const torch::Tensor& X, const torch::Tensor& y);
protected:
// Model-building function
void buildModel(const torch::Tensor& weights) override;
private:
int num_classes_; // Number of classes
int num_attributes_; // Number of attributes
std::vector<int> attribute_cardinalities_; // Cardinalities of attributes
// Frequency counts (similar to Java implementation)
std::vector<double> class_counts_; // Class frequency
std::vector<std::vector<std::vector<double>>> freq_attr_class_; // P(A | C)
std::vector<std::vector<std::vector<std::vector<std::vector<double>>>>> freq_pair_class_; // P(A_i, A_j | C)
double total_count_; // Total instance count
bool weighted_a2de_; // Whether to use weighted A2DE
double smoothing_factor_; // Smoothing parameter (default: Laplace)
torch::Tensor AODEConditionalProb(const torch::Tensor& data);
void trainModel(const torch::Tensor& data, const Smoothing_t smoothing);
int toIntValue(int attributeIndex, float value) const;
};
}
#endif