First approx with const 1 weights
This commit is contained in:
parent
90c92e5c56
commit
af0419c9da
5
.vscode/launch.json
vendored
5
.vscode/launch.json
vendored
@ -25,12 +25,13 @@
|
||||
"program": "${workspaceFolder}/build/src/Platform/main",
|
||||
"args": [
|
||||
"-m",
|
||||
"SPODELd",
|
||||
"SPODE",
|
||||
"-p",
|
||||
"/Users/rmontanana/Code/discretizbench/datasets",
|
||||
"--stratified",
|
||||
"--discretize",
|
||||
"-d",
|
||||
"iris"
|
||||
"letter"
|
||||
],
|
||||
"cwd": "/Users/rmontanana/Code/discretizbench",
|
||||
},
|
||||
|
@ -37,7 +37,8 @@ namespace bayesnet {
|
||||
}
|
||||
void Classifier::trainModel()
|
||||
{
|
||||
model.fit(dataset, features, className, states);
|
||||
const torch::Tensor weights = torch::ones({ m });
|
||||
model.fit(dataset, weights, features, className, states);
|
||||
}
|
||||
// X is nxm where n is the number of features and m the number of samples
|
||||
Classifier& Classifier::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
|
||||
|
@ -104,8 +104,11 @@ namespace bayesnet {
|
||||
{
|
||||
return nodes;
|
||||
}
|
||||
void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
|
||||
void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states, const torch::Tensor& weights)
|
||||
{
|
||||
if (weights.size(0) != n_samples) {
|
||||
throw invalid_argument("Weights must have the same number of elements as samples in Network::fit");
|
||||
}
|
||||
if (n_samples != n_samples_y) {
|
||||
throw invalid_argument("X and y must have the same number of samples in Network::fit (" + to_string(n_samples) + " != " + to_string(n_samples_y) + ")");
|
||||
}
|
||||
@ -136,28 +139,29 @@ namespace bayesnet {
|
||||
classNumStates = nodes[className]->getNumStates();
|
||||
}
|
||||
// X comes in nxm, where n is the number of features and m the number of samples
|
||||
void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
|
||||
void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
|
||||
{
|
||||
checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states);
|
||||
checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights);
|
||||
this->className = className;
|
||||
Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
|
||||
samples = torch::cat({ X , ytmp }, 0);
|
||||
for (int i = 0; i < featureNames.size(); ++i) {
|
||||
auto row_feature = X.index({ i, "..." });
|
||||
}
|
||||
completeFit(states);
|
||||
completeFit(states, weights);
|
||||
}
|
||||
void Network::fit(const torch::Tensor& samples, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
|
||||
void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
|
||||
{
|
||||
checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states);
|
||||
checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights);
|
||||
this->className = className;
|
||||
this->samples = samples;
|
||||
completeFit(states);
|
||||
completeFit(states, weights);
|
||||
}
|
||||
// input_data comes in nxm, where n is the number of features and m the number of samples
|
||||
void Network::fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
|
||||
void Network::fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<float>& weights_, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
|
||||
{
|
||||
checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states);
|
||||
const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
|
||||
checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights);
|
||||
this->className = className;
|
||||
// Build tensor of samples (nxm) (n+1 because of the class)
|
||||
samples = torch::zeros({ static_cast<int>(input_data.size() + 1), static_cast<int>(input_data[0].size()) }, torch::kInt32);
|
||||
@ -165,9 +169,9 @@ namespace bayesnet {
|
||||
samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32));
|
||||
}
|
||||
samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
|
||||
completeFit(states);
|
||||
completeFit(states, weights);
|
||||
}
|
||||
void Network::completeFit(const map<string, vector<int>>& states)
|
||||
void Network::completeFit(const map<string, vector<int>>& states, const torch::Tensor& weights)
|
||||
{
|
||||
setStates(states);
|
||||
int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
|
||||
@ -182,7 +186,7 @@ namespace bayesnet {
|
||||
while (nextNodeIndex < nodes.size()) {
|
||||
unique_lock<mutex> lock(mtx);
|
||||
cv.wait(lock, [&activeThreads, &maxThreadsRunning]() { return activeThreads < maxThreadsRunning; });
|
||||
threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads]() {
|
||||
threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads, &weights]() {
|
||||
while (true) {
|
||||
unique_lock<mutex> lock(mtx);
|
||||
if (nextNodeIndex >= nodes.size()) {
|
||||
@ -191,7 +195,7 @@ namespace bayesnet {
|
||||
auto& pair = *std::next(nodes.begin(), nextNodeIndex);
|
||||
++nextNodeIndex;
|
||||
lock.unlock();
|
||||
pair.second->computeCPT(samples, features, laplaceSmoothing);
|
||||
pair.second->computeCPT(samples, features, laplaceSmoothing, weights);
|
||||
lock.lock();
|
||||
nodes[pair.first] = std::move(pair.second);
|
||||
lock.unlock();
|
||||
|
@ -20,8 +20,8 @@ namespace bayesnet {
|
||||
vector<double> predict_sample(const torch::Tensor&);
|
||||
vector<double> exactInference(map<string, int>&);
|
||||
double computeFactor(map<string, int>&);
|
||||
void completeFit(const map<string, vector<int>>&);
|
||||
void checkFitData(int n_features, int n_samples, int n_samples_y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>&);
|
||||
void completeFit(const map<string, vector<int>>& states, const torch::Tensor& weights);
|
||||
void checkFitData(int n_features, int n_samples, int n_samples_y, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states, const torch::Tensor& weights);
|
||||
void setStates(const map<string, vector<int>>&);
|
||||
public:
|
||||
Network();
|
||||
@ -39,9 +39,9 @@ namespace bayesnet {
|
||||
int getNumEdges() const;
|
||||
int getClassNumStates() const;
|
||||
string getClassName() const;
|
||||
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const map<string, vector<int>>&);
|
||||
void fit(const torch::Tensor&, const torch::Tensor&, const vector<string>&, const string&, const map<string, vector<int>>&);
|
||||
void fit(const torch::Tensor&, const vector<string>&, const string&, const map<string, vector<int>>&);
|
||||
void fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<float>& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states);
|
||||
void fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states);
|
||||
void fit(const torch::Tensor& samples, const torch::Tensor& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states);
|
||||
vector<int> predict(const vector<vector<int>>&); // Return mx1 vector of predictions
|
||||
torch::Tensor predict(const torch::Tensor&); // Return mx1 tensor of predictions
|
||||
torch::Tensor predict_tensor(const torch::Tensor& samples, const bool proba);
|
||||
|
@ -84,7 +84,7 @@ namespace bayesnet {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
void Node::computeCPT(const torch::Tensor& dataset, const vector<string>& features, const int laplaceSmoothing)
|
||||
void Node::computeCPT(const torch::Tensor& dataset, const vector<string>& features, const int laplaceSmoothing, const torch::Tensor& weights)
|
||||
{
|
||||
dimensions.clear();
|
||||
// Get dimensions of the CPT
|
||||
@ -111,7 +111,7 @@ namespace bayesnet {
|
||||
coordinates.push_back(dataset.index({ parent_index, n_sample }));
|
||||
}
|
||||
// Increment the count of the corresponding coordinate
|
||||
cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + 1);
|
||||
cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<float>());
|
||||
}
|
||||
// Normalize the counts
|
||||
cpTable = cpTable / cpTable.sum(0);
|
||||
|
@ -26,7 +26,7 @@ namespace bayesnet {
|
||||
vector<Node*>& getParents();
|
||||
vector<Node*>& getChildren();
|
||||
torch::Tensor& getCPT();
|
||||
void computeCPT(const torch::Tensor&, const vector<string>&, const int);
|
||||
void computeCPT(const torch::Tensor& dataset, const vector<string>& features, const int laplaceSmoothing, const torch::Tensor& weights);
|
||||
int getNumStates() const;
|
||||
void setNumStates(int);
|
||||
unsigned minFill();
|
||||
|
@ -65,7 +65,8 @@ namespace bayesnet {
|
||||
//Update new states of the feature/node
|
||||
states[pFeatures[index]] = xStates;
|
||||
}
|
||||
model.fit(pDataset, pFeatures, pClassName, states);
|
||||
const torch::Tensor weights = torch::ones({ pDataset.size(1) }, torch::kFloat);
|
||||
model.fit(pDataset, weights, pFeatures, pClassName, states);
|
||||
}
|
||||
return states;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user