Fix XSpode
This commit is contained in:
BIN
bayesnet/classifiers/.XSPODE.h.swp
Normal file
BIN
bayesnet/classifiers/.XSPODE.h.swp
Normal file
Binary file not shown.
@@ -264,7 +264,7 @@ namespace bayesnet {
|
|||||||
normalize(probs);
|
normalize(probs);
|
||||||
return probs;
|
return probs;
|
||||||
}
|
}
|
||||||
std::vector<std::vector<double>> XSpode::predict_proba(const std::vector<std::vector<int>>& test_data)
|
std::vector<std::vector<double>> XSpode::predict_proba(std::vector<std::vector<int>>& test_data)
|
||||||
{
|
{
|
||||||
int test_size = test_data[0].size();
|
int test_size = test_data[0].size();
|
||||||
int sample_size = test_data.size();
|
int sample_size = test_data.size();
|
||||||
@@ -390,14 +390,22 @@ namespace bayesnet {
|
|||||||
torch::Tensor XSpode::predict(torch::Tensor& X)
|
torch::Tensor XSpode::predict(torch::Tensor& X)
|
||||||
{
|
{
|
||||||
auto X_ = TensorUtils::to_matrix(X);
|
auto X_ = TensorUtils::to_matrix(X);
|
||||||
auto result = predict(X_);
|
auto result_v = predict(X_);
|
||||||
return TensorUtils::to_tensor(result);
|
torch::Tensor result;
|
||||||
|
for (int i = 0; i < result_v.size(); ++i) {
|
||||||
|
result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kInt32));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
torch::Tensor XSpode::predict_proba(torch::Tensor& X)
|
torch::Tensor XSpode::predict_proba(torch::Tensor& X)
|
||||||
{
|
{
|
||||||
auto X_ = TensorUtils::to_matrix(X);
|
auto X_ = TensorUtils::to_matrix(X);
|
||||||
auto result = predict_proba(X_);
|
auto result_v = predict_proba(X_);
|
||||||
return TensorUtils::to_tensor<double>(result);
|
torch::Tensor result;
|
||||||
|
for (int i = 0; i < result_v.size(); ++i) {
|
||||||
|
result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kDouble));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
torch::Tensor XSpode::predict(torch::Tensor& X)
|
torch::Tensor XSpode::predict(torch::Tensor& X)
|
||||||
{
|
{
|
||||||
|
@@ -18,7 +18,7 @@ namespace bayesnet {
|
|||||||
public:
|
public:
|
||||||
explicit XSpode(int spIndex);
|
explicit XSpode(int spIndex);
|
||||||
std::vector<double> predict_proba(const std::vector<int>& instance) const;
|
std::vector<double> predict_proba(const std::vector<int>& instance) const;
|
||||||
std::vector<std::vector<double>> predict_proba(const std::vector<std::vector<int>>& test_data);
|
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
|
||||||
int predict(const std::vector<int>& instance) const;
|
int predict(const std::vector<int>& instance) const;
|
||||||
void normalize(std::vector<double>& v) const;
|
void normalize(std::vector<double>& v) const;
|
||||||
std::string to_string() const;
|
std::string to_string() const;
|
||||||
|
Reference in New Issue
Block a user