Continue refactoring
This commit is contained in:
@@ -46,7 +46,6 @@ namespace platform {
|
|||||||
bool debug = false;
|
bool debug = false;
|
||||||
Xaode aode_;
|
Xaode aode_;
|
||||||
torch::Tensor weights_;
|
torch::Tensor weights_;
|
||||||
bool fitted = false;
|
|
||||||
const std::string CLASSIFIER_NOT_FITTED = "Classifier has not been fitted";
|
const std::string CLASSIFIER_NOT_FITTED = "Classifier has not been fitted";
|
||||||
inline void normalize_weights(int num_instances)
|
inline void normalize_weights(int num_instances)
|
||||||
{
|
{
|
||||||
|
@@ -182,7 +182,6 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------
|
// -------------------------------------------------------
|
||||||
// computeProbabilities
|
// computeProbabilities
|
||||||
// -------------------------------------------------------
|
// -------------------------------------------------------
|
||||||
@@ -257,7 +256,6 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
matrixState_ = MatrixState::PROBS;
|
matrixState_ = MatrixState::PROBS;
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------
|
// -------------------------------------------------------
|
||||||
// predict_proba_spode
|
// predict_proba_spode
|
||||||
// -------------------------------------------------------
|
// -------------------------------------------------------
|
||||||
@@ -372,10 +370,7 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
void normalize(std::vector<double>& probs) const
|
void normalize(std::vector<double>& probs) const
|
||||||
{
|
{
|
||||||
double sum = 0;
|
double sum = std::accumulate(probs.begin(), probs.end(), 0.0);
|
||||||
for (double d : probs) {
|
|
||||||
sum += d;
|
|
||||||
}
|
|
||||||
if (std::isnan(sum)) {
|
if (std::isnan(sum)) {
|
||||||
throw std::runtime_error("Can't normalize array. Sum is NaN.");
|
throw std::runtime_error("Can't normalize array. Sum is NaN.");
|
||||||
}
|
}
|
||||||
@@ -420,7 +415,6 @@ namespace platform {
|
|||||||
active_parents.pop_back();
|
active_parents.pop_back();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// -----------
|
// -----------
|
||||||
// MEMBER DATA
|
// MEMBER DATA
|
||||||
|
Reference in New Issue
Block a user