Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include "KDBLd.h"
8 :
9 : namespace bayesnet {
10 68 : KDBLd::KDBLd(int k) : KDB(k), Proposal(dataset, features, className) {}
11 20 : KDBLd& KDBLd::fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_)
12 : {
13 20 : checkInput(X_, y_);
14 20 : features = features_;
15 20 : className = className_;
16 20 : Xf = X_;
17 20 : y = y_;
18 : // Fills std::vectors Xv & yv with the data from tensors X_ (discretized) & y
19 20 : states = fit_local_discretization(y);
20 : // We have discretized the input data
21 : // 1st we need to fit the model to build the normal KDB structure, KDB::fit initializes the base Bayesian network
22 20 : KDB::fit(dataset, features, className, states);
23 20 : states = localDiscretizationProposal(states, model);
24 20 : return *this;
25 : }
26 16 : torch::Tensor KDBLd::predict(torch::Tensor& X)
27 : {
28 16 : auto Xt = prepareX(X);
29 32 : return KDB::predict(Xt);
30 16 : }
31 4 : std::vector<std::string> KDBLd::graph(const std::string& name) const
32 : {
33 4 : return KDB::graph(name);
34 : }
35 : }
|