Add Cuda iniitialization in Classifier
This commit is contained in:
@@ -9,7 +9,15 @@
|
||||
#include "Classifier.h"
|
||||
|
||||
namespace bayesnet {
|
||||
Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
|
||||
Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false), device(torch::kCPU)
|
||||
{
|
||||
if (torch::cuda::is_available()) {
|
||||
device = torch::Device(torch::kCUDA);
|
||||
std::cout << "CUDA is available! Using GPU." << std::endl;
|
||||
} else {
|
||||
std::cout << "CUDA is not available. Using CPU." << std::endl;
|
||||
}
|
||||
}
|
||||
const std::string CLASSIFIER_NOT_FITTED = "Classifier has not been fitted";
|
||||
Classifier& Classifier::build(const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const Smoothing_t smoothing)
|
||||
{
|
||||
@@ -31,7 +39,7 @@ namespace bayesnet {
|
||||
{
|
||||
try {
|
||||
auto yresized = torch::transpose(ytmp.view({ ytmp.size(0), 1 }), 0, 1);
|
||||
dataset = torch::cat({ dataset, yresized }, 0);
|
||||
dataset = torch::cat({ dataset, yresized }, 0).to(device);
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
std::stringstream oss;
|
||||
@@ -50,7 +58,7 @@ namespace bayesnet {
|
||||
{
|
||||
dataset = X;
|
||||
buildDataset(y);
|
||||
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble);
|
||||
const torch::Tensor weights = torch::full({ dataset.size(1) }, 1.0 / dataset.size(1), torch::kDouble).to(device);
|
||||
return build(features, className, states, weights, smoothing);
|
||||
}
|
||||
// X is nxm where n is the number of features and m the number of samples
|
||||
|
@@ -38,6 +38,7 @@ namespace bayesnet {
|
||||
std::string dump_cpt() const override;
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
|
||||
protected:
|
||||
torch::Device device;
|
||||
bool fitted;
|
||||
unsigned int m, n; // m: number of samples, n: number of features
|
||||
Network model;
|
||||
|
@@ -97,7 +97,7 @@ namespace bayesnet {
|
||||
dimensions.push_back(numStates);
|
||||
transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
|
||||
// Create a tensor of zeros with the dimensions of the CPT
|
||||
cpTable = torch::zeros(dimensions, torch::kDouble) + smoothing;
|
||||
cpTable = torch::zeros(dimensions, torch::kDouble).to(device) + smoothing;
|
||||
// Fill table with counts
|
||||
auto pos = find(features.begin(), features.end(), name);
|
||||
if (pos == features.end()) {
|
||||
|
Reference in New Issue
Block a user