Fix tests
This commit is contained in:
@@ -224,14 +224,34 @@ namespace bayesnet {
|
||||
if (!fitted) {
|
||||
throw std::logic_error("You must call fit() before calling predict()");
|
||||
}
|
||||
// Ensure the sample size is equal to the number of features
|
||||
if (samples.size(0) != features.size() - 1) {
|
||||
throw std::invalid_argument("(T) Sample size (" + std::to_string(samples.size(0)) +
|
||||
") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
|
||||
}
|
||||
torch::Tensor result;
|
||||
std::vector<std::thread> threads;
|
||||
std::mutex mtx;
|
||||
auto& semaphore = CountingSemaphore::getInstance();
|
||||
result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
|
||||
for (int i = 0; i < samples.size(1); ++i) {
|
||||
const torch::Tensor sample = samples.index({ "...", i });
|
||||
auto worker = [&](const torch::Tensor& sample, int i) {
|
||||
std::string threadName = "PredictWorker-" + std::to_string(i);
|
||||
pthread_setname_np(pthread_self(), threadName.c_str());
|
||||
semaphore.acquire();
|
||||
auto psample = predict_sample(sample);
|
||||
auto temp = torch::tensor(psample, torch::kFloat64);
|
||||
// result.index_put_({ i, "..." }, torch::tensor(predict_sample(sample), torch::kFloat64));
|
||||
result.index_put_({ i, "..." }, temp);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
result.index_put_({ i, "..." }, temp);
|
||||
}
|
||||
semaphore.release();
|
||||
};
|
||||
for (int i = 0; i < samples.size(1); ++i) {
|
||||
const torch::Tensor sample = samples.index({ "...", i });
|
||||
threads.emplace_back(worker, sample, i);
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
if (proba)
|
||||
return result;
|
||||
@@ -256,18 +276,36 @@ namespace bayesnet {
|
||||
if (!fitted) {
|
||||
throw std::logic_error("You must call fit() before calling predict()");
|
||||
}
|
||||
std::vector<int> predictions;
|
||||
// Ensure the sample size is equal to the number of features
|
||||
if (tsamples.size() != features.size() - 1) {
|
||||
throw std::invalid_argument("(V) Sample size (" + std::to_string(tsamples.size()) +
|
||||
") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
|
||||
}
|
||||
std::vector<int> predictions(tsamples[0].size(), 0);
|
||||
std::vector<int> sample;
|
||||
std::vector<std::thread> threads;
|
||||
std::mutex mtx;
|
||||
auto& semaphore = CountingSemaphore::getInstance();
|
||||
auto worker = [&](const std::vector<int>& sample, const int row, std::vector<int>& predictions) {
|
||||
semaphore.acquire();
|
||||
auto classProbabilities = predict_sample(sample);
|
||||
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
|
||||
int predictedClass = distance(classProbabilities.begin(), maxElem);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
predictions[row] = predictedClass;
|
||||
}
|
||||
semaphore.release();
|
||||
};
|
||||
for (int row = 0; row < tsamples[0].size(); ++row) {
|
||||
sample.clear();
|
||||
for (int col = 0; col < tsamples.size(); ++col) {
|
||||
sample.push_back(tsamples[col][row]);
|
||||
}
|
||||
std::vector<double> classProbabilities = predict_sample(sample);
|
||||
// Find the class with the maximum posterior probability
|
||||
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
|
||||
int predictedClass = distance(classProbabilities.begin(), maxElem);
|
||||
predictions.push_back(predictedClass);
|
||||
threads.emplace_back(worker, sample, row, std::ref(predictions));
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
return predictions;
|
||||
}
|
||||
@@ -278,6 +316,11 @@ namespace bayesnet {
|
||||
if (!fitted) {
|
||||
throw std::logic_error("You must call fit() before calling predict_proba()");
|
||||
}
|
||||
// Ensure the sample size is equal to the number of features
|
||||
if (tsamples.size() != features.size() - 1) {
|
||||
throw std::invalid_argument("(V) Sample size (" + std::to_string(tsamples.size()) +
|
||||
") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
|
||||
}
|
||||
std::vector<std::vector<double>> predictions;
|
||||
std::vector<int> sample;
|
||||
for (int row = 0; row < tsamples[0].size(); ++row) {
|
||||
@@ -303,11 +346,6 @@ namespace bayesnet {
|
||||
// Return 1xn std::vector of probabilities
|
||||
std::vector<double> Network::predict_sample(const std::vector<int>& sample)
|
||||
{
|
||||
// Ensure the sample size is equal to the number of features
|
||||
if (sample.size() != features.size() - 1) {
|
||||
throw std::invalid_argument("Sample size (" + std::to_string(sample.size()) +
|
||||
") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
|
||||
}
|
||||
std::map<std::string, int> evidence;
|
||||
for (int i = 0; i < sample.size(); ++i) {
|
||||
evidence[features[i]] = sample[i];
|
||||
@@ -317,56 +355,23 @@ namespace bayesnet {
|
||||
// Return 1xn std::vector of probabilities
|
||||
std::vector<double> Network::predict_sample(const torch::Tensor& sample)
|
||||
{
|
||||
// Ensure the sample size is equal to the number of features
|
||||
if (sample.size(0) != features.size() - 1) {
|
||||
throw std::invalid_argument("Sample size (" + std::to_string(sample.size(0)) +
|
||||
") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
|
||||
}
|
||||
std::map<std::string, int> evidence;
|
||||
for (int i = 0; i < sample.size(0); ++i) {
|
||||
evidence[features[i]] = sample[i].item<int>();
|
||||
}
|
||||
return exactInference(evidence);
|
||||
}
|
||||
double Network::computeFactor(std::map<std::string, int>& completeEvidence)
|
||||
{
|
||||
double result = 1.0;
|
||||
for (auto& node : getNodes()) {
|
||||
result *= node.second->getFactorValue(completeEvidence);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
std::vector<double> Network::exactInference(std::map<std::string, int>& evidence)
|
||||
{
|
||||
|
||||
|
||||
//Implementar una cache para acelerar la inferencia.
|
||||
// Cambiar la estrategia de crear hilos en la inferencia (por nodos como en fit?)
|
||||
|
||||
|
||||
|
||||
std::vector<double> result(classNumStates, 0.0);
|
||||
std::vector<std::thread> threads;
|
||||
std::mutex mtx;
|
||||
auto& semaphore = CountingSemaphore::getInstance();
|
||||
auto worker = [&](int i) {
|
||||
semaphore.acquire();
|
||||
std::string threadName = "InferenceWorker-" + std::to_string(i);
|
||||
pthread_setname_np(pthread_self(), threadName.c_str());
|
||||
auto completeEvidence = std::map<std::string, int>(evidence);
|
||||
completeEvidence[getClassName()] = i;
|
||||
double factor = computeFactor(completeEvidence);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
result[i] = factor;
|
||||
}
|
||||
semaphore.release();
|
||||
};
|
||||
auto completeEvidence = std::map<std::string, int>(evidence);
|
||||
for (int i = 0; i < classNumStates; ++i) {
|
||||
threads.emplace_back(worker, i);
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
completeEvidence[getClassName()] = i;
|
||||
double partial = 1.0;
|
||||
for (auto& node : getNodes()) {
|
||||
partial *= node.second->getFactorValue(completeEvidence);
|
||||
}
|
||||
result[i] = partial;
|
||||
}
|
||||
// Normalize result
|
||||
double sum = std::accumulate(result.begin(), result.end(), 0.0);
|
||||
|
@@ -21,11 +21,9 @@ namespace bayesnet {
|
||||
class Network {
|
||||
public:
|
||||
Network();
|
||||
explicit Network(float);
|
||||
explicit Network(const Network&);
|
||||
~Network() = default;
|
||||
torch::Tensor& getSamples();
|
||||
float getMaxThreads() const;
|
||||
void addNode(const std::string&);
|
||||
void addEdge(const std::string&, const std::string&);
|
||||
std::map<std::string, std::unique_ptr<Node>>& getNodes();
|
||||
@@ -64,7 +62,6 @@ namespace bayesnet {
|
||||
std::vector<double> predict_sample(const std::vector<int>&);
|
||||
std::vector<double> predict_sample(const torch::Tensor&);
|
||||
std::vector<double> exactInference(std::map<std::string, int>&);
|
||||
double computeFactor(std::map<std::string, int>&);
|
||||
void completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const Smoothing_t smoothing);
|
||||
void checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
|
||||
void setStates(const std::map<std::string, std::vector<int>>&);
|
||||
|
Reference in New Issue
Block a user