diff --git a/bayesclass/BaseClassifier.cc b/bayesclass/BaseClassifier.cc index e560db6..c3d05f8 100644 --- a/bayesclass/BaseClassifier.cc +++ b/bayesclass/BaseClassifier.cc @@ -8,7 +8,7 @@ namespace bayesnet { BaseClassifier& BaseClassifier::build(vector& features, string className, map>& states) { - dataset = torch::cat({ X, y.view({150, 1}) }, 1); + dataset = torch::cat({ X, y.view({y.size(0), 1}) }, 1); this->features = features; this->className = className; this->states = states; @@ -86,8 +86,8 @@ namespace bayesnet { Tensor y_pred = predict(X); return (y_pred == y).sum().item() / y.size(0); } - void BaseClassifier::show() + vector BaseClassifier::show() { - model.show(); + return model.show(); } } \ No newline at end of file diff --git a/bayesclass/BaseClassifier.h b/bayesclass/BaseClassifier.h index e2e5080..149636e 100644 --- a/bayesclass/BaseClassifier.h +++ b/bayesclass/BaseClassifier.h @@ -28,9 +28,8 @@ namespace bayesnet { BaseClassifier& fit(vector>& X, vector& y, vector& features, string className, map>& states); Tensor predict(Tensor& X); float score(Tensor& X, Tensor& y); - void show(); + vector show(); }; - } #endif diff --git a/bayesclass/KDB.h b/bayesclass/KDB.h index 64f1f7b..a8d154c 100644 --- a/bayesclass/KDB.h +++ b/bayesclass/KDB.h @@ -10,9 +10,9 @@ namespace bayesnet { float theta; void add_m_edges(int idx, vector& S, Tensor& weights); protected: - void train(); + void train() override; public: - KDB(int k, float theta=0.03); + KDB(int k, float theta = 0.03); }; } #endif \ No newline at end of file diff --git a/bayesclass/Network.cc b/bayesclass/Network.cc index 495ee4f..52414c3 100644 --- a/bayesclass/Network.cc +++ b/bayesclass/Network.cc @@ -245,16 +245,18 @@ namespace bayesnet { } return result; } - void Network::show() + vector Network::show() { + vector result; // Draw the network for (auto node : nodes) { - cout << node.first << " -> "; + string line = node.first + " -> "; for (auto child : node.second->getChildren()) { - cout << child->getName() << ", "; + line += child->getName() + ", "; } - cout << endl; + result.push_back(line); } + return result; } } diff --git a/bayesclass/Network.h b/bayesclass/Network.h index bc6402d..0459bd1 100644 --- a/bayesclass/Network.h +++ b/bayesclass/Network.h @@ -44,7 +44,7 @@ namespace bayesnet { torch::Tensor conditionalEdgeWeight(); vector> predict_proba(const vector>&); double score(const vector>&, const vector&); - void show(); + vector show(); inline string version() { return "0.1.0"; } }; }