Add Cuda iniitialization in Classifier
This commit is contained in:
@@ -9,7 +9,15 @@
|
||||
#include "Classifier.h"
|
||||
|
||||
namespace bayesnet {
|
||||
Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
|
||||
Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false), device(torch::kCPU)
|
||||
{
|
||||
if (torch::cuda::is_available()) {
|
||||
device = torch::Device(torch::kCUDA);
|
||||
std::cout << "CUDA is available! Using GPU." << std::endl;
|
||||
} else {
|
||||
std::cout << "CUDA is not available. Using CPU." << std::endl;
|
||||
}
|
||||
}
|
||||
const std::string CLASSIFIER_NOT_FITTED = "Classifier has not been fitted";
|
||||
Classifier& Classifier::build(const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const Smoothing_t smoothing)
|
||||
{
|
||||
@@ -31,7 +39,7 @@ namespace bayesnet {
|
||||
{
|
||||
try {
|
||||
auto yresized = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1);
|
||||
dataset = torch::cat({ dataset, yresized }, 0);
|
||||
dataset = torch::cat({ dataset, yresized }, 0).to(device);
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
std::stringstream oss;
|
||||
@@ -50,7 +58,7 @@ namespace bayesnet {
|
||||
{
|
||||
dataset = X;
|
||||
buildDataset(y);
|
||||
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
|
||||
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble).to(device);
|
||||
return build(features, className, states, weights, smoothing);
|
||||
}
|
||||
// X is nxm where n is the number of features and m the number of samples
|
||||
|
@@ -38,6 +38,7 @@ namespace bayesnet {
|
||||
std::string dump_cpt() const override;
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
|
||||
protected:
|
||||
torch::Device device;
|
||||
bool fitted;
|
||||
unsigned int m, n; // m: number of samples, n: number of features
|
||||
Network model;
|
||||
|
Reference in New Issue
Block a user