diff --git a/.vscode/launch.json b/.vscode/launch.json index 7241ae2..e3c35bb 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -25,12 +25,12 @@ "program": "${workspaceFolder}/build/src/Platform/main", "args": [ "-m", - "AODELd", + "TANLd", "-p", "/Users/rmontanana/Code/discretizbench/datasets", "--stratified", "-d", - "iris" + "vehicle" ], "cwd": "/Users/rmontanana/Code/discretizbench", }, diff --git a/src/BayesNet/Classifier.cc b/src/BayesNet/Classifier.cc index 7f41839..b3317f4 100644 --- a/src/BayesNet/Classifier.cc +++ b/src/BayesNet/Classifier.cc @@ -37,7 +37,7 @@ namespace bayesnet { } void Classifier::trainModel() { - model.fit(dataset, features, className); + model.fit(dataset, features, className, states); } // X is nxm where n is the number of features and m the number of samples Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, vector& features, string className, map>& states) diff --git a/src/BayesNet/Network.cc b/src/BayesNet/Network.cc index 59903f3..8a4106c 100644 --- a/src/BayesNet/Network.cc +++ b/src/BayesNet/Network.cc @@ -104,7 +104,7 @@ namespace bayesnet { { return nodes; } - void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector& featureNames, const string& className) + void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector& featureNames, const string& className, const map>& states) { if (n_samples != n_samples_y) { throw invalid_argument("X and y must have the same number of samples in Network::fit (" + to_string(n_samples) + " != " + to_string(n_samples_y) + ")"); @@ -122,39 +122,42 @@ namespace bayesnet { if (find(features.begin(), features.end(), feature) == features.end()) { throw invalid_argument("Feature " + feature + " not found in Network::features"); } + if (states.find(feature) == states.end()) { + throw invalid_argument("Feature " + feature + " not found in states"); + } } } - void Network::setStates() + void Network::setStates(const map>& states) { // Set states to every Node in the network for (int i = 0; i < features.size(); ++i) { - nodes[features[i]]->setNumStates(static_cast(torch::max(samples.index({ i, "..." })).item()) + 1); + nodes[features[i]]->setNumStates(states.at(features[i]).size()); } classNumStates = nodes[className]->getNumStates(); } // X comes in nxm, where n is the number of features and m the number of samples - void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const vector& featureNames, const string& className) + void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const vector& featureNames, const string& className, const map>& states) { - checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className); + checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states); this->className = className; Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1); samples = torch::cat({ X , ytmp }, 0); for (int i = 0; i < featureNames.size(); ++i) { auto row_feature = X.index({ i, "..." }); } - completeFit(); + completeFit(states); } - void Network::fit(const torch::Tensor& samples, const vector& featureNames, const string& className) + void Network::fit(const torch::Tensor& samples, const vector& featureNames, const string& className, const map>& states) { - checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className); + checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states); this->className = className; this->samples = samples; - completeFit(); + completeFit(states); } // input_data comes in nxm, where n is the number of features and m the number of samples - void Network::fit(const vector>& input_data, const vector& labels, const vector& featureNames, const string& className) + void Network::fit(const vector>& input_data, const vector& labels, const vector& featureNames, const string& className, const map>& states) { - checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className); + checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states); this->className = className; // Build tensor of samples (nxm) (n+1 because of the class) samples = torch::zeros({ static_cast(input_data.size() + 1), static_cast(input_data[0].size()) }, torch::kInt32); @@ -162,11 +165,11 @@ namespace bayesnet { samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32)); } samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32)); - completeFit(); + completeFit(states); } - void Network::completeFit() + void Network::completeFit(const map>& states) { - setStates(); + setStates(states); int maxThreadsRunning = static_cast(std::thread::hardware_concurrency() * maxThreads); if (maxThreadsRunning < 1) { maxThreadsRunning = 1; @@ -212,7 +215,7 @@ namespace bayesnet { torch::Tensor result; result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64); for (int i = 0; i < samples.size(1); ++i) { - auto sample = samples.index({ "...", i }); + const Tensor sample = samples.index({ "...", i }); auto psample = predict_sample(sample); auto temp = torch::tensor(psample, torch::kFloat64); // result.index_put_({ i, "..." }, torch::tensor(predict_sample(sample), torch::kFloat64)); diff --git a/src/BayesNet/Network.h b/src/BayesNet/Network.h index eb65957..d8db620 100644 --- a/src/BayesNet/Network.h +++ b/src/BayesNet/Network.h @@ -20,13 +20,9 @@ namespace bayesnet { vector predict_sample(const torch::Tensor&); vector exactInference(map&); double computeFactor(map&); - double mutual_info(torch::Tensor&, torch::Tensor&); - double entropy(torch::Tensor&); - double conditionalEntropy(torch::Tensor&, torch::Tensor&); - double mutualInformation(torch::Tensor&, torch::Tensor&); - void completeFit(); - void checkFitData(int n_features, int n_samples, int n_samples_y, const vector& featureNames, const string& className); - void setStates(); + void completeFit(const map>&); + void checkFitData(int n_features, int n_samples, int n_samples_y, const vector& featureNames, const string& className, const map>&); + void setStates(const map>&); public: Network(); explicit Network(float, int); @@ -43,13 +39,11 @@ namespace bayesnet { int getNumEdges() const; int getClassNumStates() const; string getClassName() const; - void fit(const vector>&, const vector&, const vector&, const string&); - void fit(const torch::Tensor&, const torch::Tensor&, const vector&, const string&); - void fit(const torch::Tensor&, const vector&, const string&); + void fit(const vector>&, const vector&, const vector&, const string&, const map>&); + void fit(const torch::Tensor&, const torch::Tensor&, const vector&, const string&, const map>&); + void fit(const torch::Tensor&, const vector&, const string&, const map>&); vector predict(const vector>&); // Return mx1 vector of predictions torch::Tensor predict(const torch::Tensor&); // Return mx1 tensor of predictions - //Computes the conditional edge weight of variable index u and v conditioned on class_node - torch::Tensor conditionalEdgeWeight(); torch::Tensor predict_tensor(const torch::Tensor& samples, const bool proba); vector> predict_proba(const vector>&); // Return mxn vector of probabilities torch::Tensor predict_proba(const torch::Tensor&); // Return mxn tensor of probabilities diff --git a/src/BayesNet/Proposal.cc b/src/BayesNet/Proposal.cc index 80cb7ee..78d5225 100644 --- a/src/BayesNet/Proposal.cc +++ b/src/BayesNet/Proposal.cc @@ -64,7 +64,7 @@ namespace bayesnet { //Update new states of the feature/node states[pFeatures[index]] = xStates; } - model.fit(pDataset, pFeatures, pClassName); + model.fit(pDataset, pFeatures, pClassName, states); } } void Proposal::fit_local_discretization(map>& states, torch::Tensor& y) diff --git a/src/Platform/Report.cc b/src/Platform/Report.cc index 90aad2b..3693248 100644 --- a/src/Platform/Report.cc +++ b/src/Platform/Report.cc @@ -4,6 +4,7 @@ namespace platform { string headerLine(const string& text) { int n = MAXL - text.length() - 3; + n = n < 0 ? 0 : n; return "* " + text + string(n, ' ') + "*\n"; } string Report::fromVector(const string& key) @@ -13,7 +14,7 @@ namespace platform { for (auto& item : data[key]) { result += to_string(item) + ", "; } - return "[" + result.substr(0, result.length() - 2) + "]"; + return "[" + result.substr(0, result.size() - 2) + "]"; } string fVector(const json& data) { @@ -21,7 +22,7 @@ namespace platform { for (const auto& item : data) { result += to_string(item) + ", "; } - return "[" + result.substr(0, result.length() - 2) + "]"; + return "[" + result.substr(0, result.size() - 2) + "]"; } void Report::show() {