// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#include <thread>
#include <mutex>
#include <sstream>
#include <numeric>
#include <algorithm>
#include "Network.h"
#include "bayesnet/utils/bayesnetUtils.h"
#include "bayesnet/utils/CountingSemaphore.h"
#include <pthread.h>
#include <fstream>
namespace bayesnet {
    Network::Network() : fitted{ false }, classNumStates{ 0 }
    {
    }
    Network::Network(const Network& other) : features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
        fitted(other.fitted), samples(other.samples)
    {
        if (samples.defined())
            samples = samples.clone();
        for (const auto& node : other.nodes) {
            nodes[node.first] = std::make_unique<Node>(*node.second);
        }
    }
    void Network::initialize()
    {
        features.clear();
        className = "";
        classNumStates = 0;
        fitted = false;
        nodes.clear();
        samples = torch::Tensor();
    }
    torch::Tensor& Network::getSamples()
    {
        return samples;
    }
    void Network::addNode(const std::string& name)
    {
        if (fitted) {
            throw std::invalid_argument("Cannot add node to a fitted network. Initialize first.");
        }
        if (name == "") {
            throw std::invalid_argument("Node name cannot be empty");
        }
        if (nodes.find(name) != nodes.end()) {
            return;
        }
        if (find(features.begin(), features.end(), name) == features.end()) {
            features.push_back(name);
        }
        nodes[name] = std::make_unique<Node>(name);
    }
    std::vector<std::string> Network::getFeatures() const
    {
        return features;
    }
    int Network::getClassNumStates() const
    {
        return classNumStates;
    }
    int Network::getStates() const
    {
        int result = 0;
        for (auto& node : nodes) {
            result += node.second->getNumStates();
        }
        return result;
    }
    std::string Network::getClassName() const
    {
        return className;
    }
    bool Network::isCyclic(const std::string& nodeId, std::unordered_set<std::string>& visited, std::unordered_set<std::string>& recStack)
    {
        if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet
        {
            visited.insert(nodeId);
            recStack.insert(nodeId);
            for (Node* child : nodes[nodeId]->getChildren()) {
                if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack))
                    return true;
                if (recStack.find(child->getName()) != recStack.end())
                    return true;
            }
        }
        recStack.erase(nodeId); // remove node from recursion stack before function ends
        return false;
    }
    void Network::addEdge(const std::string& parent, const std::string& child)
    {
        if (fitted) {
            throw std::invalid_argument("Cannot add edge to a fitted network. Initialize first.");
        }
        if (nodes.find(parent) == nodes.end()) {
            throw std::invalid_argument("Parent node " + parent + " does not exist");
        }
        if (nodes.find(child) == nodes.end()) {
            throw std::invalid_argument("Child node " + child + " does not exist");
        }
        // Check if the edge is already in the graph
        for (auto& node : nodes[parent]->getChildren()) {
            if (node->getName() == child) {
                throw std::invalid_argument("Edge " + parent + " -> " + child + " already exists");
            }
        }
        // Temporarily add edge to check for cycles
        nodes[parent]->addChild(nodes[child].get());
        nodes[child]->addParent(nodes[parent].get());
        std::unordered_set<std::string> visited;
        std::unordered_set<std::string> recStack;
        if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle
        {
            // remove problematic edge
            nodes[parent]->removeChild(nodes[child].get());
            nodes[child]->removeParent(nodes[parent].get());
            throw std::invalid_argument("Adding this edge forms a cycle in the graph.");
        }
    }
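    // Illustrative sketch (not part of the library): one way to build a structure with the
    // methods above before fitting. The node names are hypothetical.
    //
    //   bayesnet::Network net;
    //   net.addNode("class");
    //   net.addNode("feature_a");
    //   net.addNode("feature_b");
    //   net.addEdge("class", "feature_a");    // class as parent of every feature (naive Bayes layout)
    //   net.addEdge("class", "feature_b");
    //   // net.addEdge("feature_b", "class"); // would throw: this edge closes a cycle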
    std::map<std::string, std::unique_ptr<Node>>& Network::getNodes()
    {
        return nodes;
    }
    void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
    {
        if (weights.size(0) != n_samples) {
            throw std::invalid_argument("Weights (" + std::to_string(weights.size(0)) + ") must have the same number of elements as samples (" + std::to_string(n_samples) + ") in Network::fit");
        }
        if (n_samples != n_samples_y) {
            throw std::invalid_argument("X and y must have the same number of samples in Network::fit (" + std::to_string(n_samples) + " != " + std::to_string(n_samples_y) + ")");
        }
        if (n_features != featureNames.size()) {
            throw std::invalid_argument("X and features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(featureNames.size()) + ")");
        }
        if (features.size() == 0) {
            throw std::invalid_argument("The network has not been initialized. You must call addNode() before calling fit()");
        }
        if (n_features != features.size() - 1) {
            throw std::invalid_argument("X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(features.size() - 1) + ")");
        }
        if (find(features.begin(), features.end(), className) == features.end()) {
            throw std::invalid_argument("Class Name not found in Network::features");
        }
        for (auto& feature : featureNames) {
            if (find(features.begin(), features.end(), feature) == features.end()) {
                throw std::invalid_argument("Feature " + feature + " not found in Network::features");
            }
            if (states.find(feature) == states.end()) {
                throw std::invalid_argument("Feature " + feature + " not found in states");
            }
        }
    }
    void Network::setStates(const std::map<std::string, std::vector<int>>& states)
    {
        // Set states to every Node in the network
        for_each(features.begin(), features.end(), [this, &states](const std::string& feature) {
            nodes.at(feature)->setNumStates(states.at(feature).size());
            });
        classNumStates = nodes.at(className)->getNumStates();
    }
    // X comes in nxm, where n is the number of features and m the number of samples
    void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing)
    {
        checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights);
        this->className = className;
        torch::Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
        samples = torch::cat({ X, ytmp }, 0);
        for (int i = 0; i < featureNames.size(); ++i) {
            auto row_feature = X.index({ i, "..." });
        }
        completeFit(states, weights, smoothing);
    }
    void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing)
    {
        checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights);
        this->className = className;
        this->samples = samples;
        completeFit(states, weights, smoothing);
    }
    // input_data comes in nxm, where n is the number of features and m the number of samples
    void Network::fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights_, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing)
    {
        const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
        checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights);
        this->className = className;
        // Build tensor of samples (nxm) (n+1 because of the class)
        samples = torch::zeros({ static_cast<int>(input_data.size() + 1), static_cast<int>(input_data[0].size()) }, torch::kInt32);
        for (int i = 0; i < featureNames.size(); ++i) {
            samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32));
        }
        samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
        completeFit(states, weights, smoothing);
    }
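    // Illustrative sketch (assumptions, not library code): the vector overload above expects
    // input_data laid out as n features x m samples, labels of length m, one weight per sample,
    // and a "states" map covering every feature name plus the class name. Given a Network `net`
    // whose nodes/edges are already defined:
    //
    //   std::vector<std::vector<int>> X = { { 0, 1, 1 }, { 1, 0, 1 } };   // 2 features, 3 samples
    //   std::vector<int> y = { 0, 1, 1 };
    //   std::vector<double> w = { 1.0, 1.0, 1.0 };
    //   std::map<std::string, std::vector<int>> states = {
    //       { "feature_a", { 0, 1 } }, { "feature_b", { 0, 1 } }, { "class", { 0, 1 } }
    //   };
    //   net.fit(X, y, w, { "feature_a", "feature_b" }, "class", states, bayesnet::Smoothing_t::LAPLACE);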
    void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const Smoothing_t smoothing)
    {
        setStates(states);
        std::vector<std::thread> threads;
        auto& semaphore = CountingSemaphore::getInstance();
        const double n_samples = static_cast<double>(samples.size(1));
        auto worker = [&](std::pair<const std::string, std::unique_ptr<Node>>& node, int i) {
            std::string threadName = "FitWorker-" + std::to_string(i);
#if defined(__linux__)
            pthread_setname_np(pthread_self(), threadName.c_str());
#else
            pthread_setname_np(threadName.c_str());
#endif
            double numStates = static_cast<double>(node.second->getNumStates());
            double smoothing_factor = 0.0;
            switch (smoothing) {
                case Smoothing_t::ORIGINAL:
                    smoothing_factor = 1.0 / n_samples;
                    break;
                case Smoothing_t::LAPLACE:
                    smoothing_factor = 1.0;
                    break;
                case Smoothing_t::CESTNIK:
                    smoothing_factor = 1 / numStates;
                    break;
                default:
                    throw std::invalid_argument("Smoothing method not recognized " + std::to_string(static_cast<int>(smoothing)));
            }
            node.second->computeCPT(samples, features, smoothing_factor, weights);
            semaphore.release();
            };
        int i = 0;
        for (auto& node : nodes) {
            semaphore.acquire();
            threads.emplace_back(worker, std::ref(node), i++);
        }
        for (auto& thread : threads) {
            thread.join();
        }
        // std::fstream file;
        // file.open("cpt.txt", std::fstream::out | std::fstream::app);
        // file << std::string(80, '*') << std::endl;
        // for (const auto& item : graph("Test")) {
        //     file << item << std::endl;
        // }
        // file << std::string(80, '-') << std::endl;
        // file << dump_cpt() << std::endl;
        // file << std::string(80, '=') << std::endl;
        // file.close();
        fitted = true;
    }
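    // Note on completeFit (informal): each node's CPT is estimated in its own thread, with
    // CountingSemaphore bounding how many workers run concurrently. The smoothing_factor is
    // handed to Node::computeCPT, where it is presumably applied as an additive pseudo-count:
    // ORIGINAL adds 1/m (m = number of samples), LAPLACE adds 1, and CESTNIK adds 1/k
    // (k = number of states of the node).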
    torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
    {
        if (!fitted) {
            throw std::logic_error("You must call fit() before calling predict()");
        }
        // Ensure the sample size is equal to the number of features
        if (samples.size(0) != features.size() - 1) {
            throw std::invalid_argument("(T) Sample size (" + std::to_string(samples.size(0)) +
                ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
        }
        torch::Tensor result;
        std::vector<std::thread> threads;
        std::mutex mtx;
        auto& semaphore = CountingSemaphore::getInstance();
        result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
        auto worker = [&](const torch::Tensor& sample, int i) {
            std::string threadName = "PredictWorker-" + std::to_string(i);
#if defined(__linux__)
            pthread_setname_np(pthread_self(), threadName.c_str());
#else
            pthread_setname_np(threadName.c_str());
#endif
            auto psample = predict_sample(sample);
            auto temp = torch::tensor(psample, torch::kFloat64);
            {
                std::lock_guard<std::mutex> lock(mtx);
                result.index_put_({ i, "..." }, temp);
            }
            semaphore.release();
            };
        for (int i = 0; i < samples.size(1); ++i) {
            semaphore.acquire();
            const torch::Tensor sample = samples.index({ "...", i });
            threads.emplace_back(worker, sample, i);
        }
        for (auto& thread : threads) {
            thread.join();
        }
        if (proba)
            return result;
        return result.argmax(1);
    }
    // Return mxn tensor of probabilities
    torch::Tensor Network::predict_proba(const torch::Tensor& samples)
    {
        return predict_tensor(samples, true);
    }
    // Return mx1 tensor of predictions
    torch::Tensor Network::predict(const torch::Tensor& samples)
    {
        return predict_tensor(samples, false);
    }
    // Return mx1 std::vector of predictions
    // tsamples is nxm std::vector of samples
    std::vector<int> Network::predict(const std::vector<std::vector<int>>& tsamples)
    {
        if (!fitted) {
            throw std::logic_error("You must call fit() before calling predict()");
        }
        // Ensure the sample size is equal to the number of features
        if (tsamples.size() != features.size() - 1) {
            throw std::invalid_argument("(V) Sample size (" + std::to_string(tsamples.size()) +
                ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
        }
        std::vector<int> predictions(tsamples[0].size(), 0);
        std::vector<int> sample;
        std::vector<std::thread> threads;
        auto& semaphore = CountingSemaphore::getInstance();
        auto worker = [&](const std::vector<int>& sample, const int row, int& prediction) {
            std::string threadName = "(V)PWorker-" + std::to_string(row);
#if defined(__linux__)
            pthread_setname_np(pthread_self(), threadName.c_str());
#else
            pthread_setname_np(threadName.c_str());
#endif
            auto classProbabilities = predict_sample(sample);
            auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
            int predictedClass = distance(classProbabilities.begin(), maxElem);
            prediction = predictedClass;
            semaphore.release();
            };
        for (int row = 0; row < tsamples[0].size(); ++row) {
            sample.clear();
            for (int col = 0; col < tsamples.size(); ++col) {
                sample.push_back(tsamples[col][row]);
            }
            semaphore.acquire();
            threads.emplace_back(worker, sample, row, std::ref(predictions[row]));
        }
        for (auto& thread : threads) {
            thread.join();
        }
        return predictions;
    }
    // Return mxn std::vector of probabilities
    // tsamples is nxm std::vector of samples
    std::vector<std::vector<double>> Network::predict_proba(const std::vector<std::vector<int>>& tsamples)
    {
        if (!fitted) {
            throw std::logic_error("You must call fit() before calling predict_proba()");
        }
        // Ensure the sample size is equal to the number of features
        if (tsamples.size() != features.size() - 1) {
            throw std::invalid_argument("(V) Sample size (" + std::to_string(tsamples.size()) +
                ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
        }
        std::vector<std::vector<double>> predictions(tsamples[0].size(), std::vector<double>(classNumStates, 0.0));
        std::vector<int> sample;
        std::vector<std::thread> threads;
        auto& semaphore = CountingSemaphore::getInstance();
        auto worker = [&](const std::vector<int>& sample, int row, std::vector<double>& predictions) {
            std::string threadName = "(V)PWorker-" + std::to_string(row);
#if defined(__linux__)
            pthread_setname_np(pthread_self(), threadName.c_str());
#else
            pthread_setname_np(threadName.c_str());
#endif
            std::vector<double> classProbabilities = predict_sample(sample);
            predictions = classProbabilities;
            semaphore.release();
            };
        for (int row = 0; row < tsamples[0].size(); ++row) {
            sample.clear();
            for (int col = 0; col < tsamples.size(); ++col) {
                sample.push_back(tsamples[col][row]);
            }
            semaphore.acquire();
            threads.emplace_back(worker, sample, row, std::ref(predictions[row]));
        }
        for (auto& thread : threads) {
            thread.join();
        }
        return predictions;
    }
    double Network::score(const std::vector<std::vector<int>>& tsamples, const std::vector<int>& labels)
    {
        std::vector<int> y_pred = predict(tsamples);
        int correct = 0;
        for (int i = 0; i < y_pred.size(); ++i) {
            if (y_pred[i] == labels[i]) {
                correct++;
            }
        }
        return (double)correct / y_pred.size();
    }
    // Return 1xn std::vector of probabilities
    std::vector<double> Network::predict_sample(const std::vector<int>& sample)
    {
        std::map<std::string, int> evidence;
        for (int i = 0; i < sample.size(); ++i) {
            evidence[features[i]] = sample[i];
        }
        return exactInference(evidence);
    }
    // Return 1xn std::vector of probabilities
    std::vector<double> Network::predict_sample(const torch::Tensor& sample)
    {
        std::map<std::string, int> evidence;
        for (int i = 0; i < sample.size(0); ++i) {
            evidence[features[i]] = sample[i].item<int>();
        }
        return exactInference(evidence);
    }
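    // exactInference (below) computes, for each class state c, the joint probability under the
    // network factorization, P(c, x) = product over nodes n of P(n | parents(n)) evaluated at the
    // complete evidence (assuming each Node factor value is its CPT entry for that evidence), and
    // then normalizes the results so they sum to 1, yielding P(c | x).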
    std::vector<double> Network::exactInference(std::map<std::string, int>& evidence)
    {
        std::vector<double> result(classNumStates, 0.0);
        auto completeEvidence = std::map<std::string, int>(evidence);
        for (int i = 0; i < classNumStates; ++i) {
            completeEvidence[getClassName()] = i;
            double partial = 1.0;
            for (auto& node : getNodes()) {
                partial *= node.second->getFactorValue(completeEvidence);
            }
            result[i] = partial;
        }
        // Normalize result
        double sum = std::accumulate(result.begin(), result.end(), 0.0);
        transform(result.begin(), result.end(), result.begin(), [sum](const double& value) { return value / sum; });
        return result;
    }
    std::vector<std::string> Network::show() const
    {
        std::vector<std::string> result;
        // Draw the network
        for (auto& node : nodes) {
            std::string line = node.first + " -> ";
            for (auto child : node.second->getChildren()) {
                line += child->getName() + ", ";
            }
            result.push_back(line);
        }
        return result;
    }
    std::vector<std::string> Network::graph(const std::string& title) const
    {
        auto output = std::vector<std::string>();
        auto prefix = "digraph BayesNet {\nlabel=<BayesNet ";
        auto suffix = ">\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n";
        std::string header = prefix + title + suffix;
        output.push_back(header);
        for (auto& node : nodes) {
            auto result = node.second->graph(className);
            output.insert(output.end(), result.begin(), result.end());
        }
        output.push_back("}\n");
        return output;
    }
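    // The lines returned by graph() form a Graphviz DOT document. A possible way to render it
    // (illustrative sketch; variable and file names are hypothetical):
    //
    //   std::ofstream dot("network.dot");
    //   for (const auto& line : net.graph("My network")) dot << line;
    //
    // and then, from a shell with Graphviz installed:  dot -Tpng network.dot -o network.png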
    std::vector<std::pair<std::string, std::string>> Network::getEdges() const
    {
        auto edges = std::vector<std::pair<std::string, std::string>>();
        for (const auto& node : nodes) {
            auto head = node.first;
            for (const auto& child : node.second->getChildren()) {
                auto tail = child->getName();
                edges.push_back({ head, tail });
            }
        }
        return edges;
    }
    int Network::getNumEdges() const
    {
        return getEdges().size();
    }
    std::vector<std::string> Network::topological_sort()
    {
        /* Check if all the fathers of every node are before the node */
        auto result = features;
        result.erase(remove(result.begin(), result.end(), className), result.end());
        bool ending{ false };
        while (!ending) {
            ending = true;
            for (auto feature : features) {
                auto fathers = nodes[feature]->getParents();
                for (const auto& father : fathers) {
                    auto fatherName = father->getName();
                    if (fatherName == className) {
                        continue;
                    }
                    // Check if father is placed before the actual feature
                    auto it = find(result.begin(), result.end(), fatherName);
                    if (it != result.end()) {
                        auto it2 = find(result.begin(), result.end(), feature);
                        if (it2 != result.end()) {
                            if (distance(it, it2) < 0) {
                                // if it is not, insert it before the feature
                                result.erase(remove(result.begin(), result.end(), fatherName), result.end());
                                result.insert(it2, fatherName);
                                ending = false;
                            }
                        }
                    }
                }
            }
        }
        return result;
    }
    std::string Network::dump_cpt() const
    {
        std::stringstream oss;
        for (auto& node : nodes) {
            oss << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl;
            oss << node.second->getCPT() << std::endl;
        }
        return oss.str();
    }
}