Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include "SPODELd.h"
8 :
9 : namespace bayesnet {
// Constructs an SPODELd: forwards `root` to the SPODE base and initialises the
// Proposal base with `dataset`, `features` and `className`.
// NOTE(review): those three names are not constructor parameters, so they are
// presumably members inherited through SPODE; confirm against the header how
// Proposal holds them (by reference or by copy).
10 220 : SPODELd::SPODELd(int root) : SPODE(root), Proposal(dataset, features, className) {}
// Fits the model from separate feature and label tensors.
//
// Runs checkInput(X_, y_), stores X_ in Xf and y_ in y, then delegates the
// rest of the fitting to commonFit().
//
// @param X_         feature tensor, stored as-is in Xf (no clone is made here)
// @param y_         label tensor, stored as-is in y (no clone is made here)
// @param features_  feature names, forwarded to commonFit()
// @param className_ class variable name, forwarded to commonFit()
// @param states_    forwarded to commonFit()
// @return           the reference returned by commonFit() (this object)
11 168 : SPODELd& SPODELd::fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_)
12 : {
// NOTE(review): checkInput is defined elsewhere; presumably it validates the
// tensors and throws on bad input. Confirm.
13 168 : checkInput(X_, y_);
14 168 : Xf = X_;
15 168 : y = y_;
16 168 : return commonFit(features_, className_, states_);
17 : }
18 :
// Fits the model from a single tensor that holds both features and labels.
//
// Every index along dimension 0 except the last is cloned into Xf; the last
// index along dimension 0 is cloned, converted to torch::kInt32 and stored in y.
// The remaining work is delegated to commonFit().
//
// @param dataset    floating point tensor; within this function the name refers
//                   to this parameter, not to the `dataset` identifier used by
//                   the constructor and commonFit()
// @param features_  feature names, forwarded to commonFit()
// @param className_ class variable name, forwarded to commonFit()
// @param states_    forwarded to commonFit()
// @return           the reference returned by commonFit() (this object)
// @throws std::runtime_error if `dataset` is not a floating point tensor
19 8 : SPODELd& SPODELd::fit(torch::Tensor& dataset, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_)
20 : {
21 8 : if (!torch::is_floating_point(dataset)) {
22 4 : throw std::runtime_error("Dataset must be a floating point tensor");
23 : }
// Slice(0, size(0) - 1) keeps indices [0, size(0) - 1) of dimension 0, i.e. all but the last.
24 16 : Xf = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." }).clone();
// Index -1 selects the last entry of dimension 0; it is cast to 32-bit integers.
25 12 : y = dataset.index({ -1, "..." }).clone().to(torch::kInt32);
26 4 : return commonFit(features_, className_, states_);
27 12 : }
28 :
// Shared tail of both fit() overloads; expects Xf and y to have been set by the caller.
//
// Stores the feature names and class name, assigns the result of
// fit_local_discretization(y) to `states`, calls SPODE::fit with
// dataset/features/className/states, and finally replaces `states` with the
// result of localDiscretizationProposal(states, model).
//
// @param features_  feature names, copied into `features`
// @param className_ class variable name, copied into `className`
// @param states_    NOTE(review): never read in this body; `states` is taken
//                   from fit_local_discretization(y) instead
// @return           reference to this object
29 172 : SPODELd& SPODELd::commonFit(const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_)
30 : {
31 172 : features = features_;
32 172 : className = className_;
33 : // Fills std::vectors Xv & yv with the data from tensors X_ (discretized) & y
34 172 : states = fit_local_discretization(y);
35 : // We have discretized the input data
36 : // 1st we need to fit the model to build the normal SPODE structure, SPODE::fit initializes the base Bayesian network
// NOTE(review): `dataset` is not assigned in this function; presumably
// fit_local_discretization populates it. Confirm.
37 172 : SPODE::fit(dataset, features, className, states);
// The states are recomputed using the model that SPODE::fit has just built.
38 172 : states = localDiscretizationProposal(states, model);
39 172 : return *this;
40 : }
// Predicts for the samples in X.
//
// Passes X through prepareX() and hands the result to SPODE::predict().
// NOTE(review): prepareX is defined elsewhere; presumably it applies the fitted
// discretization to the raw input. Confirm.
//
// @param X  input tensor
// @return   the tensor returned by SPODE::predict on the prepared input
41 136 : torch::Tensor SPODELd::predict(torch::Tensor& X)
42 : {
43 136 : auto Xt = prepareX(X);
44 272 : return SPODE::predict(Xt);
45 136 : }
// Returns the graph description of the model by forwarding to SPODE::graph()
// with the same `name`; nothing is added here.
//
// @param name  passed unchanged to SPODE::graph()
// @return      the lines returned by SPODE::graph(name)
46 36 : std::vector<std::string> SPODELd::graph(const std::string& name) const
47 : {
48 36 : return SPODE::graph(name);
49 : }
50 : }
|