Fix tests

This commit is contained in:
2024-06-21 13:58:42 +02:00
parent 02bcab01be
commit 8e9090d283
5 changed files with 76 additions and 75 deletions

View File

@@ -224,14 +224,34 @@ namespace bayesnet {
if (!fitted) {
throw std::logic_error("You must call fit() before calling predict()");
}
// Ensure the sample size is equal to the number of features
if (samples.size(0) != features.size() - 1) {
throw std::invalid_argument("(T) Sample size (" + std::to_string(samples.size(0)) +
") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
}
torch::Tensor result;
std::vector<std::thread> threads;
std::mutex mtx;
auto& semaphore = CountingSemaphore::getInstance();
result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
for (int i = 0; i < samples.size(1); ++i) {
const torch::Tensor sample = samples.index({ "...", i });
auto worker = [&](const torch::Tensor& sample, int i) {
std::string threadName = "PredictWorker-" + std::to_string(i);
pthread_setname_np(pthread_self(), threadName.c_str());
semaphore.acquire();
auto psample = predict_sample(sample);
auto temp = torch::tensor(psample, torch::kFloat64);
// result.index_put_({ i, "..." }, torch::tensor(predict_sample(sample), torch::kFloat64));
result.index_put_({ i, "..." }, temp);
{
std::lock_guard<std::mutex> lock(mtx);
result.index_put_({ i, "..." }, temp);
}
semaphore.release();
};
for (int i = 0; i < samples.size(1); ++i) {
const torch::Tensor sample = samples.index({ "...", i });
threads.emplace_back(worker, sample, i);
}
for (auto& thread : threads) {
thread.join();
}
if (proba)
return result;
@@ -256,18 +276,36 @@ namespace bayesnet {
if (!fitted) {
throw std::logic_error("You must call fit() before calling predict()");
}
std::vector<int> predictions;
// Ensure the sample size is equal to the number of features
if (tsamples.size() != features.size() - 1) {
throw std::invalid_argument("(V) Sample size (" + std::to_string(tsamples.size()) +
") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
}
std::vector<int> predictions(tsamples[0].size(), 0);
std::vector<int> sample;
std::vector<std::thread> threads;
std::mutex mtx;
auto& semaphore = CountingSemaphore::getInstance();
auto worker = [&](const std::vector<int>& sample, const int row, std::vector<int>& predictions) {
semaphore.acquire();
auto classProbabilities = predict_sample(sample);
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
int predictedClass = distance(classProbabilities.begin(), maxElem);
{
std::lock_guard<std::mutex> lock(mtx);
predictions[row] = predictedClass;
}
semaphore.release();
};
for (int row = 0; row < tsamples[0].size(); ++row) {
sample.clear();
for (int col = 0; col < tsamples.size(); ++col) {
sample.push_back(tsamples[col][row]);
}
std::vector<double> classProbabilities = predict_sample(sample);
// Find the class with the maximum posterior probability
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
int predictedClass = distance(classProbabilities.begin(), maxElem);
predictions.push_back(predictedClass);
threads.emplace_back(worker, sample, row, std::ref(predictions));
}
for (auto& thread : threads) {
thread.join();
}
return predictions;
}
@@ -278,6 +316,11 @@ namespace bayesnet {
if (!fitted) {
throw std::logic_error("You must call fit() before calling predict_proba()");
}
// Ensure the sample size is equal to the number of features
if (tsamples.size() != features.size() - 1) {
throw std::invalid_argument("(V) Sample size (" + std::to_string(tsamples.size()) +
") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
}
std::vector<std::vector<double>> predictions;
std::vector<int> sample;
for (int row = 0; row < tsamples[0].size(); ++row) {
@@ -303,11 +346,6 @@ namespace bayesnet {
// Return 1xn std::vector of probabilities
std::vector<double> Network::predict_sample(const std::vector<int>& sample)
{
// Ensure the sample size is equal to the number of features
if (sample.size() != features.size() - 1) {
throw std::invalid_argument("Sample size (" + std::to_string(sample.size()) +
") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
}
std::map<std::string, int> evidence;
for (int i = 0; i < sample.size(); ++i) {
evidence[features[i]] = sample[i];
@@ -317,56 +355,23 @@ namespace bayesnet {
// Return 1xn std::vector of probabilities
std::vector<double> Network::predict_sample(const torch::Tensor& sample)
{
// Ensure the sample size is equal to the number of features
if (sample.size(0) != features.size() - 1) {
throw std::invalid_argument("Sample size (" + std::to_string(sample.size(0)) +
") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
}
std::map<std::string, int> evidence;
for (int i = 0; i < sample.size(0); ++i) {
evidence[features[i]] = sample[i].item<int>();
}
return exactInference(evidence);
}
double Network::computeFactor(std::map<std::string, int>& completeEvidence)
{
double result = 1.0;
for (auto& node : getNodes()) {
result *= node.second->getFactorValue(completeEvidence);
}
return result;
}
std::vector<double> Network::exactInference(std::map<std::string, int>& evidence)
{
//Implementar una cache para acelerar la inferencia.
// Cambiar la estrategia de crear hilos en la inferencia (por nodos como en fit?)
std::vector<double> result(classNumStates, 0.0);
std::vector<std::thread> threads;
std::mutex mtx;
auto& semaphore = CountingSemaphore::getInstance();
auto worker = [&](int i) {
semaphore.acquire();
std::string threadName = "InferenceWorker-" + std::to_string(i);
pthread_setname_np(pthread_self(), threadName.c_str());
auto completeEvidence = std::map<std::string, int>(evidence);
completeEvidence[getClassName()] = i;
double factor = computeFactor(completeEvidence);
{
std::lock_guard<std::mutex> lock(mtx);
result[i] = factor;
}
semaphore.release();
};
auto completeEvidence = std::map<std::string, int>(evidence);
for (int i = 0; i < classNumStates; ++i) {
threads.emplace_back(worker, i);
}
for (auto& thread : threads) {
thread.join();
completeEvidence[getClassName()] = i;
double partial = 1.0;
for (auto& node : getNodes()) {
partial *= node.second->getFactorValue(completeEvidence);
}
result[i] = partial;
}
// Normalize result
double sum = std::accumulate(result.begin(), result.end(), 0.0);

View File

@@ -21,11 +21,9 @@ namespace bayesnet {
class Network {
public:
Network();
explicit Network(float);
explicit Network(const Network&);
~Network() = default;
torch::Tensor& getSamples();
float getMaxThreads() const;
void addNode(const std::string&);
void addEdge(const std::string&, const std::string&);
std::map<std::string, std::unique_ptr<Node>>& getNodes();
@@ -64,7 +62,6 @@ namespace bayesnet {
std::vector<double> predict_sample(const std::vector<int>&);
std::vector<double> predict_sample(const torch::Tensor&);
std::vector<double> exactInference(std::map<std::string, int>&);
double computeFactor(std::map<std::string, int>&);
void completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const Smoothing_t smoothing);
void checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
void setStates(const std::map<std::string, std::vector<int>>&);