From c7e2042c6ed22ac7914369310b5f089b3d74aeb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 11 Jul 2023 17:42:20 +0200 Subject: [PATCH] Add entropy, conditionalEntropy, mutualInformation and conditionalEdgeWeight methods --- .vscode/settings.json | 4 +- data/ecoli.arff | 428 ++++++++++++++++++++++++++++++++++++++++++ sample/main.cc | 2 + sample/test.cc | 166 +++++++++++++--- src/Network.cc | 104 ++++++++-- src/Network.h | 7 + src/Node.cc | 12 +- src/Node.h | 3 +- 8 files changed, 683 insertions(+), 43 deletions(-) create mode 100755 data/ecoli.arff diff --git a/.vscode/settings.json b/.vscode/settings.json index ef91e92..86e20ee 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -88,7 +88,9 @@ "iterator": "cpp", "memory_resource": "cpp", "format": "cpp", - "valarray": "cpp" + "valarray": "cpp", + "regex": "cpp", + "span": "cpp" }, "cmake.configureOnOpen": false, "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools" diff --git a/data/ecoli.arff b/data/ecoli.arff new file mode 100755 index 0000000..6008975 --- /dev/null +++ b/data/ecoli.arff @@ -0,0 +1,428 @@ +% +% 1. Title: Protein Localization Sites +% +% +% 2. Creator and Maintainer: +% Kenta Nakai +% Institue of Molecular and Cellular Biology +% Osaka, University +% 1-3 Yamada-oka, Suita 565 Japan +% nakai@imcb.osaka-u.ac.jp +% http://www.imcb.osaka-u.ac.jp/nakai/psort.html +% Donor: Paul Horton (paulh@cs.berkeley.edu) +% Date: September, 1996 +% See also: yeast database +% +% 3. Past Usage. +% Reference: "A Probablistic Classification System for Predicting the Cellular +% Localization Sites of Proteins", Paul Horton & Kenta Nakai, +% Intelligent Systems in Molecular Biology, 109-115. +% St. Louis, USA 1996. +% Results: 81% for E.coli with an ad hoc structured +% probability model. Also similar accuracy for Binary Decision Tree and +% Bayesian Classifier methods applied by the same authors in +% unpublished results. +% +% Predicted Attribute: Localization site of protein. ( non-numeric ). +% +% +% 4. The references below describe a predecessor to this dataset and its +% development. They also give results (not cross-validated) for classification +% by a rule-based expert system with that version of the dataset. +% +% Reference: "Expert Sytem for Predicting Protein Localization Sites in +% Gram-Negative Bacteria", Kenta Nakai & Minoru Kanehisa, +% PROTEINS: Structure, Function, and Genetics 11:95-110, 1991. +% +% Reference: "A Knowledge Base for Predicting Protein Localization Sites in +% Eukaryotic Cells", Kenta Nakai & Minoru Kanehisa, +% Genomics 14:897-911, 1992. +% +% +% 5. Number of Instances: 336 for the E.coli dataset and +% +% +% 6. Number of Attributes. +% for E.coli dataset: 8 ( 7 predictive, 1 name ) +% +% 7. Attribute Information. +% +% 1. Sequence Name: Accession number for the SWISS-PROT database +% 2. mcg: McGeoch's method for signal sequence recognition. +% 3. gvh: von Heijne's method for signal sequence recognition. +% 4. lip: von Heijne's Signal Peptidase II consensus sequence score. +% Binary attribute. +% 5. chg: Presence of charge on N-terminus of predicted lipoproteins. +% Binary attribute. +% 6. aac: score of discriminant analysis of the amino acid content of +% outer membrane and periplasmic proteins. +% 7. alm1: score of the ALOM membrane spanning region prediction program. +% 8. alm2: score of ALOM program after excluding putative cleavable signal +% regions from the sequence. +% +% NOTE - the sequence name has been removed +% +% 8. Missing Attribute Values: None. +% +% +% 9. Class Distribution. The class is the localization site. Please see Nakai & +% Kanehisa referenced above for more details. +% +% cp (cytoplasm) 143 +% im (inner membrane without signal sequence) 77 +% pp (perisplasm) 52 +% imU (inner membrane, uncleavable signal sequence) 35 +% om (outer membrane) 20 +% omL (outer membrane lipoprotein) 5 +% imL (inner membrane lipoprotein) 2 +% imS (inner membrane, cleavable signal sequence) 2 + +@relation ecoli + +@attribute mcg numeric +@attribute gvh numeric +@attribute lip numeric +@attribute chg numeric +@attribute aac numeric +@attribute alm1 numeric +@attribute alm2 numeric +@attribute class {cp,im,pp,imU,om,omL,imL,imS} + +@data + +0.49,0.29,0.48,0.5,0.56,0.24,0.35,cp +0.07,0.4,0.48,0.5,0.54,0.35,0.44,cp +0.56,0.4,0.48,0.5,0.49,0.37,0.46,cp +0.59,0.49,0.48,0.5,0.52,0.45,0.36,cp +0.23,0.32,0.48,0.5,0.55,0.25,0.35,cp +0.67,0.39,0.48,0.5,0.36,0.38,0.46,cp +0.29,0.28,0.48,0.5,0.44,0.23,0.34,cp +0.21,0.34,0.48,0.5,0.51,0.28,0.39,cp +0.2,0.44,0.48,0.5,0.46,0.51,0.57,cp +0.42,0.4,0.48,0.5,0.56,0.18,0.3,cp +0.42,0.24,0.48,0.5,0.57,0.27,0.37,cp +0.25,0.48,0.48,0.5,0.44,0.17,0.29,cp +0.39,0.32,0.48,0.5,0.46,0.24,0.35,cp +0.51,0.5,0.48,0.5,0.46,0.32,0.35,cp +0.22,0.43,0.48,0.5,0.48,0.16,0.28,cp +0.25,0.4,0.48,0.5,0.46,0.44,0.52,cp +0.34,0.45,0.48,0.5,0.38,0.24,0.35,cp +0.44,0.27,0.48,0.5,0.55,0.52,0.58,cp +0.23,0.4,0.48,0.5,0.39,0.28,0.38,cp +0.41,0.57,0.48,0.5,0.39,0.21,0.32,cp +0.4,0.45,0.48,0.5,0.38,0.22,0,cp +0.31,0.23,0.48,0.5,0.73,0.05,0.14,cp +0.51,0.54,0.48,0.5,0.41,0.34,0.43,cp +0.3,0.16,0.48,0.5,0.56,0.11,0.23,cp +0.36,0.39,0.48,0.5,0.48,0.22,0.23,cp +0.29,0.37,0.48,0.5,0.48,0.44,0.52,cp +0.25,0.4,0.48,0.5,0.47,0.33,0.42,cp +0.21,0.51,0.48,0.5,0.5,0.32,0.41,cp +0.43,0.37,0.48,0.5,0.53,0.35,0.44,cp +0.43,0.39,0.48,0.5,0.47,0.31,0.41,cp +0.53,0.38,0.48,0.5,0.44,0.26,0.36,cp +0.34,0.33,0.48,0.5,0.38,0.35,0.44,cp +0.56,0.51,0.48,0.5,0.34,0.37,0.46,cp +0.4,0.29,0.48,0.5,0.42,0.35,0.44,cp +0.24,0.35,0.48,0.5,0.31,0.19,0.31,cp +0.36,0.54,0.48,0.5,0.41,0.38,0.46,cp +0.29,0.52,0.48,0.5,0.42,0.29,0.39,cp +0.65,0.47,0.48,0.5,0.59,0.3,0.4,cp +0.32,0.42,0.48,0.5,0.35,0.28,0.38,cp +0.38,0.46,0.48,0.5,0.48,0.22,0.29,cp +0.33,0.45,0.48,0.5,0.52,0.32,0.41,cp +0.3,0.37,0.48,0.5,0.59,0.41,0.49,cp +0.4,0.5,0.48,0.5,0.45,0.39,0.47,cp +0.28,0.38,0.48,0.5,0.5,0.33,0.42,cp +0.61,0.45,0.48,0.5,0.48,0.35,0.41,cp +0.17,0.38,0.48,0.5,0.45,0.42,0.5,cp +0.44,0.35,0.48,0.5,0.55,0.55,0.61,cp +0.43,0.4,0.48,0.5,0.39,0.28,0.39,cp +0.42,0.35,0.48,0.5,0.58,0.15,0.27,cp +0.23,0.33,0.48,0.5,0.43,0.33,0.43,cp +0.37,0.52,0.48,0.5,0.42,0.42,0.36,cp +0.29,0.3,0.48,0.5,0.45,0.03,0.17,cp +0.22,0.36,0.48,0.5,0.35,0.39,0.47,cp +0.23,0.58,0.48,0.5,0.37,0.53,0.59,cp +0.47,0.47,0.48,0.5,0.22,0.16,0.26,cp +0.54,0.47,0.48,0.5,0.28,0.33,0.42,cp +0.51,0.37,0.48,0.5,0.35,0.36,0.45,cp +0.4,0.35,0.48,0.5,0.45,0.33,0.42,cp +0.44,0.34,0.48,0.5,0.3,0.33,0.43,cp +0.42,0.38,0.48,0.5,0.54,0.34,0.43,cp +0.44,0.56,0.48,0.5,0.5,0.46,0.54,cp +0.52,0.36,0.48,0.5,0.41,0.28,0.38,cp +0.36,0.41,0.48,0.5,0.48,0.47,0.54,cp +0.18,0.3,0.48,0.5,0.46,0.24,0.35,cp +0.47,0.29,0.48,0.5,0.51,0.33,0.43,cp +0.24,0.43,0.48,0.5,0.54,0.52,0.59,cp +0.25,0.37,0.48,0.5,0.41,0.33,0.42,cp +0.52,0.57,0.48,0.5,0.42,0.47,0.54,cp +0.25,0.37,0.48,0.5,0.43,0.26,0.36,cp +0.35,0.48,0.48,0.5,0.56,0.4,0.48,cp +0.26,0.26,0.48,0.5,0.34,0.25,0.35,cp +0.44,0.51,0.48,0.5,0.47,0.26,0.36,cp +0.37,0.5,0.48,0.5,0.42,0.36,0.45,cp +0.44,0.42,0.48,0.5,0.42,0.25,0.2,cp +0.24,0.43,0.48,0.5,0.37,0.28,0.38,cp +0.42,0.3,0.48,0.5,0.48,0.26,0.36,cp +0.48,0.42,0.48,0.5,0.45,0.25,0.35,cp +0.41,0.48,0.48,0.5,0.51,0.44,0.51,cp +0.44,0.28,0.48,0.5,0.43,0.27,0.37,cp +0.29,0.41,0.48,0.5,0.48,0.38,0.46,cp +0.34,0.28,0.48,0.5,0.41,0.35,0.44,cp +0.41,0.43,0.48,0.5,0.45,0.31,0.41,cp +0.29,0.47,0.48,0.5,0.41,0.23,0.34,cp +0.34,0.55,0.48,0.5,0.58,0.31,0.41,cp +0.36,0.56,0.48,0.5,0.43,0.45,0.53,cp +0.4,0.46,0.48,0.5,0.52,0.49,0.56,cp +0.5,0.49,0.48,0.5,0.49,0.46,0.53,cp +0.52,0.44,0.48,0.5,0.37,0.36,0.42,cp +0.5,0.51,0.48,0.5,0.27,0.23,0.34,cp +0.53,0.42,0.48,0.5,0.16,0.29,0.39,cp +0.34,0.46,0.48,0.5,0.52,0.35,0.44,cp +0.4,0.42,0.48,0.5,0.37,0.27,0.27,cp +0.41,0.43,0.48,0.5,0.5,0.24,0.25,cp +0.3,0.45,0.48,0.5,0.36,0.21,0.32,cp +0.31,0.47,0.48,0.5,0.29,0.28,0.39,cp +0.64,0.76,0.48,0.5,0.45,0.35,0.38,cp +0.35,0.37,0.48,0.5,0.3,0.34,0.43,cp +0.57,0.54,0.48,0.5,0.37,0.28,0.33,cp +0.65,0.55,0.48,0.5,0.34,0.37,0.28,cp +0.51,0.46,0.48,0.5,0.58,0.31,0.41,cp +0.38,0.4,0.48,0.5,0.63,0.25,0.35,cp +0.24,0.57,0.48,0.5,0.63,0.34,0.43,cp +0.38,0.26,0.48,0.5,0.54,0.16,0.28,cp +0.33,0.47,0.48,0.5,0.53,0.18,0.29,cp +0.24,0.34,0.48,0.5,0.38,0.3,0.4,cp +0.26,0.5,0.48,0.5,0.44,0.32,0.41,cp +0.44,0.49,0.48,0.5,0.39,0.38,0.4,cp +0.43,0.32,0.48,0.5,0.33,0.45,0.52,cp +0.49,0.43,0.48,0.5,0.49,0.3,0.4,cp +0.47,0.28,0.48,0.5,0.56,0.2,0.25,cp +0.32,0.33,0.48,0.5,0.6,0.06,0.2,cp +0.34,0.35,0.48,0.5,0.51,0.49,0.56,cp +0.35,0.34,0.48,0.5,0.46,0.3,0.27,cp +0.38,0.3,0.48,0.5,0.43,0.29,0.39,cp +0.38,0.44,0.48,0.5,0.43,0.2,0.31,cp +0.41,0.51,0.48,0.5,0.58,0.2,0.31,cp +0.34,0.42,0.48,0.5,0.41,0.34,0.43,cp +0.51,0.49,0.48,0.5,0.53,0.14,0.26,cp +0.25,0.51,0.48,0.5,0.37,0.42,0.5,cp +0.29,0.28,0.48,0.5,0.5,0.42,0.5,cp +0.25,0.26,0.48,0.5,0.39,0.32,0.42,cp +0.24,0.41,0.48,0.5,0.49,0.23,0.34,cp +0.17,0.39,0.48,0.5,0.53,0.3,0.39,cp +0.04,0.31,0.48,0.5,0.41,0.29,0.39,cp +0.61,0.36,0.48,0.5,0.49,0.35,0.44,cp +0.34,0.51,0.48,0.5,0.44,0.37,0.46,cp +0.28,0.33,0.48,0.5,0.45,0.22,0.33,cp +0.4,0.46,0.48,0.5,0.42,0.35,0.44,cp +0.23,0.34,0.48,0.5,0.43,0.26,0.37,cp +0.37,0.44,0.48,0.5,0.42,0.39,0.47,cp +0,0.38,0.48,0.5,0.42,0.48,0.55,cp +0.39,0.31,0.48,0.5,0.38,0.34,0.43,cp +0.3,0.44,0.48,0.5,0.49,0.22,0.33,cp +0.27,0.3,0.48,0.5,0.71,0.28,0.39,cp +0.17,0.52,0.48,0.5,0.49,0.37,0.46,cp +0.36,0.42,0.48,0.5,0.53,0.32,0.41,cp +0.3,0.37,0.48,0.5,0.43,0.18,0.3,cp +0.26,0.4,0.48,0.5,0.36,0.26,0.37,cp +0.4,0.41,0.48,0.5,0.55,0.22,0.33,cp +0.22,0.34,0.48,0.5,0.42,0.29,0.39,cp +0.44,0.35,0.48,0.5,0.44,0.52,0.59,cp +0.27,0.42,0.48,0.5,0.37,0.38,0.43,cp +0.16,0.43,0.48,0.5,0.54,0.27,0.37,cp +0.06,0.61,0.48,0.5,0.49,0.92,0.37,im +0.44,0.52,0.48,0.5,0.43,0.47,0.54,im +0.63,0.47,0.48,0.5,0.51,0.82,0.84,im +0.23,0.48,0.48,0.5,0.59,0.88,0.89,im +0.34,0.49,0.48,0.5,0.58,0.85,0.8,im +0.43,0.4,0.48,0.5,0.58,0.75,0.78,im +0.46,0.61,0.48,0.5,0.48,0.86,0.87,im +0.27,0.35,0.48,0.5,0.51,0.77,0.79,im +0.52,0.39,0.48,0.5,0.65,0.71,0.73,im +0.29,0.47,0.48,0.5,0.71,0.65,0.69,im +0.55,0.47,0.48,0.5,0.57,0.78,0.8,im +0.12,0.67,0.48,0.5,0.74,0.58,0.63,im +0.4,0.5,0.48,0.5,0.65,0.82,0.84,im +0.73,0.36,0.48,0.5,0.53,0.91,0.92,im +0.84,0.44,0.48,0.5,0.48,0.71,0.74,im +0.48,0.45,0.48,0.5,0.6,0.78,0.8,im +0.54,0.49,0.48,0.5,0.4,0.87,0.88,im +0.48,0.41,0.48,0.5,0.51,0.9,0.88,im +0.5,0.66,0.48,0.5,0.31,0.92,0.92,im +0.72,0.46,0.48,0.5,0.51,0.66,0.7,im +0.47,0.55,0.48,0.5,0.58,0.71,0.75,im +0.33,0.56,0.48,0.5,0.33,0.78,0.8,im +0.64,0.58,0.48,0.5,0.48,0.78,0.73,im +0.54,0.57,0.48,0.5,0.56,0.81,0.83,im +0.47,0.59,0.48,0.5,0.52,0.76,0.79,im +0.63,0.5,0.48,0.5,0.59,0.85,0.86,im +0.49,0.42,0.48,0.5,0.53,0.79,0.81,im +0.31,0.5,0.48,0.5,0.57,0.84,0.85,im +0.74,0.44,0.48,0.5,0.55,0.88,0.89,im +0.33,0.45,0.48,0.5,0.45,0.88,0.89,im +0.45,0.4,0.48,0.5,0.61,0.74,0.77,im +0.71,0.4,0.48,0.5,0.71,0.7,0.74,im +0.5,0.37,0.48,0.5,0.66,0.64,0.69,im +0.66,0.53,0.48,0.5,0.59,0.66,0.66,im +0.6,0.61,0.48,0.5,0.54,0.67,0.71,im +0.83,0.37,0.48,0.5,0.61,0.71,0.74,im +0.34,0.51,0.48,0.5,0.67,0.9,0.9,im +0.63,0.54,0.48,0.5,0.65,0.79,0.81,im +0.7,0.4,0.48,0.5,0.56,0.86,0.83,im +0.6,0.5,1,0.5,0.54,0.77,0.8,im +0.16,0.51,0.48,0.5,0.33,0.39,0.48,im +0.74,0.7,0.48,0.5,0.66,0.65,0.69,im +0.2,0.46,0.48,0.5,0.57,0.78,0.81,im +0.89,0.55,0.48,0.5,0.51,0.72,0.76,im +0.7,0.46,0.48,0.5,0.56,0.78,0.73,im +0.12,0.43,0.48,0.5,0.63,0.7,0.74,im +0.61,0.52,0.48,0.5,0.54,0.67,0.52,im +0.33,0.37,0.48,0.5,0.46,0.65,0.69,im +0.63,0.65,0.48,0.5,0.66,0.67,0.71,im +0.41,0.51,0.48,0.5,0.53,0.75,0.78,im +0.34,0.67,0.48,0.5,0.52,0.76,0.79,im +0.58,0.34,0.48,0.5,0.56,0.87,0.81,im +0.59,0.56,0.48,0.5,0.55,0.8,0.82,im +0.51,0.4,0.48,0.5,0.57,0.62,0.67,im +0.5,0.57,0.48,0.5,0.71,0.61,0.66,im +0.6,0.46,0.48,0.5,0.45,0.81,0.83,im +0.37,0.47,0.48,0.5,0.39,0.76,0.79,im +0.58,0.55,0.48,0.5,0.57,0.7,0.74,im +0.36,0.47,0.48,0.5,0.51,0.69,0.72,im +0.39,0.41,0.48,0.5,0.52,0.72,0.75,im +0.35,0.51,0.48,0.5,0.61,0.71,0.74,im +0.31,0.44,0.48,0.5,0.5,0.79,0.82,im +0.61,0.66,0.48,0.5,0.46,0.87,0.88,im +0.48,0.49,0.48,0.5,0.52,0.77,0.71,im +0.11,0.5,0.48,0.5,0.58,0.72,0.68,im +0.31,0.36,0.48,0.5,0.58,0.94,0.94,im +0.68,0.51,0.48,0.5,0.71,0.75,0.78,im +0.69,0.39,0.48,0.5,0.57,0.76,0.79,im +0.52,0.54,0.48,0.5,0.62,0.76,0.79,im +0.46,0.59,0.48,0.5,0.36,0.76,0.23,im +0.36,0.45,0.48,0.5,0.38,0.79,0.17,im +0,0.51,0.48,0.5,0.35,0.67,0.44,im +0.1,0.49,0.48,0.5,0.41,0.67,0.21,im +0.3,0.51,0.48,0.5,0.42,0.61,0.34,im +0.61,0.47,0.48,0.5,0,0.8,0.32,im +0.63,0.75,0.48,0.5,0.64,0.73,0.66,im +0.71,0.52,0.48,0.5,0.64,1,0.99,im +0.85,0.53,0.48,0.5,0.53,0.52,0.35,imS +0.63,0.49,0.48,0.5,0.54,0.76,0.79,imS +0.75,0.55,1,1,0.4,0.47,0.3,imL +0.7,0.39,1,0.5,0.51,0.82,0.84,imL +0.72,0.42,0.48,0.5,0.65,0.77,0.79,imU +0.79,0.41,0.48,0.5,0.66,0.81,0.83,imU +0.83,0.48,0.48,0.5,0.65,0.76,0.79,imU +0.69,0.43,0.48,0.5,0.59,0.74,0.77,imU +0.79,0.36,0.48,0.5,0.46,0.82,0.7,imU +0.78,0.33,0.48,0.5,0.57,0.77,0.79,imU +0.75,0.37,0.48,0.5,0.64,0.7,0.74,imU +0.59,0.29,0.48,0.5,0.64,0.75,0.77,imU +0.67,0.37,0.48,0.5,0.54,0.64,0.68,imU +0.66,0.48,0.48,0.5,0.54,0.7,0.74,imU +0.64,0.46,0.48,0.5,0.48,0.73,0.76,imU +0.76,0.71,0.48,0.5,0.5,0.71,0.75,imU +0.84,0.49,0.48,0.5,0.55,0.78,0.74,imU +0.77,0.55,0.48,0.5,0.51,0.78,0.74,imU +0.81,0.44,0.48,0.5,0.42,0.67,0.68,imU +0.58,0.6,0.48,0.5,0.59,0.73,0.76,imU +0.63,0.42,0.48,0.5,0.48,0.77,0.8,imU +0.62,0.42,0.48,0.5,0.58,0.79,0.81,imU +0.86,0.39,0.48,0.5,0.59,0.89,0.9,imU +0.81,0.53,0.48,0.5,0.57,0.87,0.88,imU +0.87,0.49,0.48,0.5,0.61,0.76,0.79,imU +0.47,0.46,0.48,0.5,0.62,0.74,0.77,imU +0.76,0.41,0.48,0.5,0.5,0.59,0.62,imU +0.7,0.53,0.48,0.5,0.7,0.86,0.87,imU +0.64,0.45,0.48,0.5,0.67,0.61,0.66,imU +0.81,0.52,0.48,0.5,0.57,0.78,0.8,imU +0.73,0.26,0.48,0.5,0.57,0.75,0.78,imU +0.49,0.61,1,0.5,0.56,0.71,0.74,imU +0.88,0.42,0.48,0.5,0.52,0.73,0.75,imU +0.84,0.54,0.48,0.5,0.75,0.92,0.7,imU +0.63,0.51,0.48,0.5,0.64,0.72,0.76,imU +0.86,0.55,0.48,0.5,0.63,0.81,0.83,imU +0.79,0.54,0.48,0.5,0.5,0.66,0.68,imU +0.57,0.38,0.48,0.5,0.06,0.49,0.33,imU +0.78,0.44,0.48,0.5,0.45,0.73,0.68,imU +0.78,0.68,0.48,0.5,0.83,0.4,0.29,om +0.63,0.69,0.48,0.5,0.65,0.41,0.28,om +0.67,0.88,0.48,0.5,0.73,0.5,0.25,om +0.61,0.75,0.48,0.5,0.51,0.33,0.33,om +0.67,0.84,0.48,0.5,0.74,0.54,0.37,om +0.74,0.9,0.48,0.5,0.57,0.53,0.29,om +0.73,0.84,0.48,0.5,0.86,0.58,0.29,om +0.75,0.76,0.48,0.5,0.83,0.57,0.3,om +0.77,0.57,0.48,0.5,0.88,0.53,0.2,om +0.74,0.78,0.48,0.5,0.75,0.54,0.15,om +0.68,0.76,0.48,0.5,0.84,0.45,0.27,om +0.56,0.68,0.48,0.5,0.77,0.36,0.45,om +0.65,0.51,0.48,0.5,0.66,0.54,0.33,om +0.52,0.81,0.48,0.5,0.72,0.38,0.38,om +0.64,0.57,0.48,0.5,0.7,0.33,0.26,om +0.6,0.76,1,0.5,0.77,0.59,0.52,om +0.69,0.59,0.48,0.5,0.77,0.39,0.21,om +0.63,0.49,0.48,0.5,0.79,0.45,0.28,om +0.71,0.71,0.48,0.5,0.68,0.43,0.36,om +0.68,0.63,0.48,0.5,0.73,0.4,0.3,om +0.77,0.57,1,0.5,0.37,0.54,0.01,omL +0.66,0.49,1,0.5,0.54,0.56,0.36,omL +0.71,0.46,1,0.5,0.52,0.59,0.3,omL +0.67,0.55,1,0.5,0.66,0.58,0.16,omL +0.68,0.49,1,0.5,0.62,0.55,0.28,omL +0.74,0.49,0.48,0.5,0.42,0.54,0.36,pp +0.7,0.61,0.48,0.5,0.56,0.52,0.43,pp +0.66,0.86,0.48,0.5,0.34,0.41,0.36,pp +0.73,0.78,0.48,0.5,0.58,0.51,0.31,pp +0.65,0.57,0.48,0.5,0.47,0.47,0.51,pp +0.72,0.86,0.48,0.5,0.17,0.55,0.21,pp +0.67,0.7,0.48,0.5,0.46,0.45,0.33,pp +0.67,0.81,0.48,0.5,0.54,0.49,0.23,pp +0.67,0.61,0.48,0.5,0.51,0.37,0.38,pp +0.63,1,0.48,0.5,0.35,0.51,0.49,pp +0.57,0.59,0.48,0.5,0.39,0.47,0.33,pp +0.71,0.71,0.48,0.5,0.4,0.54,0.39,pp +0.66,0.74,0.48,0.5,0.31,0.38,0.43,pp +0.67,0.81,0.48,0.5,0.25,0.42,0.25,pp +0.64,0.72,0.48,0.5,0.49,0.42,0.19,pp +0.68,0.82,0.48,0.5,0.38,0.65,0.56,pp +0.32,0.39,0.48,0.5,0.53,0.28,0.38,pp +0.7,0.64,0.48,0.5,0.47,0.51,0.47,pp +0.63,0.57,0.48,0.5,0.49,0.7,0.2,pp +0.74,0.82,0.48,0.5,0.49,0.49,0.41,pp +0.63,0.86,0.48,0.5,0.39,0.47,0.34,pp +0.63,0.83,0.48,0.5,0.4,0.39,0.19,pp +0.63,0.71,0.48,0.5,0.6,0.4,0.39,pp +0.71,0.86,0.48,0.5,0.4,0.54,0.32,pp +0.68,0.78,0.48,0.5,0.43,0.44,0.42,pp +0.64,0.84,0.48,0.5,0.37,0.45,0.4,pp +0.74,0.47,0.48,0.5,0.5,0.57,0.42,pp +0.75,0.84,0.48,0.5,0.35,0.52,0.33,pp +0.63,0.65,0.48,0.5,0.39,0.44,0.35,pp +0.69,0.67,0.48,0.5,0.3,0.39,0.24,pp +0.7,0.71,0.48,0.5,0.42,0.84,0.85,pp +0.69,0.8,0.48,0.5,0.46,0.57,0.26,pp +0.64,0.66,0.48,0.5,0.41,0.39,0.2,pp +0.63,0.8,0.48,0.5,0.46,0.31,0.29,pp +0.66,0.71,0.48,0.5,0.41,0.5,0.35,pp +0.69,0.59,0.48,0.5,0.46,0.44,0.52,pp +0.68,0.67,0.48,0.5,0.49,0.4,0.34,pp +0.64,0.78,0.48,0.5,0.5,0.36,0.38,pp +0.62,0.78,0.48,0.5,0.47,0.49,0.54,pp +0.76,0.73,0.48,0.5,0.44,0.39,0.39,pp +0.64,0.81,0.48,0.5,0.37,0.39,0.44,pp +0.29,0.39,0.48,0.5,0.52,0.4,0.48,pp +0.62,0.83,0.48,0.5,0.46,0.36,0.4,pp +0.56,0.54,0.48,0.5,0.43,0.37,0.3,pp +0.69,0.66,0.48,0.5,0.41,0.5,0.25,pp +0.69,0.65,0.48,0.5,0.63,0.48,0.41,pp +0.43,0.59,0.48,0.5,0.52,0.49,0.56,pp +0.74,0.56,0.48,0.5,0.47,0.68,0.3,pp +0.71,0.57,0.48,0.5,0.48,0.35,0.32,pp +0.61,0.6,0.48,0.5,0.44,0.39,0.38,pp +0.59,0.61,0.48,0.5,0.42,0.42,0.37,pp +0.74,0.74,0.48,0.5,0.31,0.53,0.52,pp diff --git a/sample/main.cc b/sample/main.cc index 8d65ba0..4f79a9b 100644 --- a/sample/main.cc +++ b/sample/main.cc @@ -138,6 +138,7 @@ pair get_options(int argc, char** argv) { map datasets = { {"diabetes", true}, + {"ecoli", true}, {"glass", true}, {"iris", true}, {"kdd_JapaneseVowels", false}, @@ -229,5 +230,6 @@ int main(int argc, char** argv) cout << "BayesNet version: " << network.version() << endl; unsigned int nthreads = std::thread::hardware_concurrency(); cout << "Computer has " << nthreads << " cores." << endl; + cout << "conditionalEdgeWeight " << endl << network.conditionalEdgeWeight() << endl; return 0; } \ No newline at end of file diff --git a/sample/test.cc b/sample/test.cc index 334e6ba..e6b946f 100644 --- a/sample/test.cc +++ b/sample/test.cc @@ -25,36 +25,154 @@ #include #include using namespace std; +double entropy(torch::Tensor feature) +{ + torch::Tensor counts = feature.bincount(); + int totalWeight = counts.sum().item(); + torch::Tensor probs = counts.to(torch::kFloat) / totalWeight; + torch::Tensor logProbs = torch::log2(probs); + torch::Tensor entropy = -probs * logProbs; + return entropy.sum().item(); +} +// H(Y|X) = sum_{x in X} p(x) H(Y|X=x) +double conditionalEntropy(torch::Tensor firstFeature, torch::Tensor secondFeature) +{ + int numSamples = firstFeature.sizes()[0]; + torch::Tensor featureCounts = secondFeature.bincount(); + unordered_map> jointCounts; + double totalWeight = 0; + for (auto i = 0; i < numSamples; i++) { + jointCounts[secondFeature[i].item()][firstFeature[i].item()] += 1; + totalWeight += 1; + } + if (totalWeight == 0) + throw invalid_argument("Total weight should not be zero"); + double entropy = 0; + for (int value = 0; value < featureCounts.sizes()[0]; ++value) { + double p_f = featureCounts[value].item() / totalWeight; + double entropy_f = 0; + for (auto& [label, jointCount] : jointCounts[value]) { + double p_l_f = jointCount / featureCounts[value].item(); + if (p_l_f > 0) { + entropy_f -= p_l_f * log2(p_l_f); + } else { + entropy_f = 0; + } + } + entropy += p_f * entropy_f; + } + return entropy; +} + +// I(X;Y) = H(Y) - H(Y|X) +double mutualInformation(torch::Tensor firstFeature, torch::Tensor secondFeature) +{ + return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature); +} +double entropy2(torch::Tensor feature) +{ + return torch::special::entr(feature).sum().item(); +} int main() { //int i = 3, j = 1, k = 2; // Indices for the cell you want to update // Print original tensor // torch::Tensor t = torch::tensor({ {1, 2, 3}, {4, 5, 6} }); // 3D tensor for this example - auto variables = vector{ "A", "B" }; - auto cardinalities = vector{ 5, 4 }; - torch::Tensor values = torch::rand({ 5, 4 }); - auto candidate = "B"; - vector newVariables; - vector newCardinalities; - for (int i = 0; i < variables.size(); i++) { - if (variables[i] != candidate) { - newVariables.push_back(variables[i]); - newCardinalities.push_back(cardinalities[i]); - } - } - torch::Tensor newValues = values.sum(1); - cout << "original values" << endl; - cout << values << endl; - cout << "newValues" << endl; - cout << newValues << endl; - cout << "newVariables" << endl; - for (auto& variable : newVariables) { - cout << variable << endl; - } - cout << "newCardinalities" << endl; - for (auto& cardinality : newCardinalities) { - cout << cardinality << endl; + // auto variables = vector{ "A", "B" }; + // auto cardinalities = vector{ 5, 4 }; + // torch::Tensor values = torch::rand({ 5, 4 }); + // auto candidate = "B"; + // vector newVariables; + // vector newCardinalities; + // for (int i = 0; i < variables.size(); i++) { + // if (variables[i] != candidate) { + // newVariables.push_back(variables[i]); + // newCardinalities.push_back(cardinalities[i]); + // } + // } + // torch::Tensor newValues = values.sum(1); + // cout << "original values" << endl; + // cout << values << endl; + // cout << "newValues" << endl; + // cout << newValues << endl; + // cout << "newVariables" << endl; + // for (auto& variable : newVariables) { + // cout << variable << endl; + // } + // cout << "newCardinalities" << endl; + // for (auto& cardinality : newCardinalities) { + // cout << cardinality << endl; + // } + // auto row2 = values.index({ torch::tensor(1) }); // + // cout << "row2" << endl; + // cout << row2 << endl; + // auto col2 = values.index({ "...", 1 }); + // cout << "col2" << endl; + // cout << col2 << endl; + // auto col_last = values.index({ "...", -1 }); + // cout << "col_last" << endl; + // cout << col_last << endl; + // values.index_put_({ "...", -1 }, torch::tensor({ 1,2,3,4,5 })); + // cout << "col_last" << endl; + // cout << col_last << endl; + // auto slice2 = values.index({ torch::indexing::Slice(1, torch::indexing::None) }); + // cout << "slice2" << endl; + // cout << slice2 << endl; + // auto mask = values.index({ "...", -1 }) % 2 == 0; + // auto filter = values.index({ mask, 2 }); // Filter values + // cout << "filter" << endl; + // cout << filter << endl; + // torch::Tensor dataset = torch::tensor({ {1,0,0,1},{1,1,1,2},{0,0,0,1},{1,0,2,0},{0,0,3,0} }); + // cout << "dataset" << endl; + // cout << dataset << endl; + // cout << "entropy(dataset.indices('...', 2))" << endl; + // cout << dataset.index({ "...", 2 }) << endl; + // cout << "*********************************" << endl; + // for (int i = 0; i < 4; i++) { + // cout << "datset(" << i << ")" << endl; + // cout << dataset.index({ "...", i }) << endl; + // cout << "entropy(" << i << ")" << endl; + // cout << entropy(dataset.index({ "...", i })) << endl; + // } + // cout << "......................................" << endl; + // //cout << entropy2(dataset.index({ "...", 2 })); + // cout << "conditional entropy 0 2" << endl; + // cout << conditionalEntropy(dataset.index({ "...", 0 }), dataset.index({ "...", 2 })) << endl; + // cout << "mutualInformation(dataset.index({ '...', 0 }), dataset.index({ '...', 2 }))" << endl; + // cout << mutualInformation(dataset.index({ "...", 0 }), dataset.index({ "...", 2 })) << endl; + // auto test = torch::tensor({ .1, .2, .3 }, torch::kFloat); + // auto result = torch::zeros({ 3, 3 }, torch::kFloat); + // result.index_put_({ indices }, test); + // cout << "indices" << endl; + // cout << indices << endl; + // cout << "result" << endl; + // cout << result << endl; + // cout << "Test" << endl; + // cout << torch::triu(test.reshape(3, 3), torch::kFloat)) << endl; + + + // Create a 3x3 tensor with zeros + torch::Tensor tensor_3x3 = torch::zeros({ 3, 3 }, torch::kFloat); + + // Create a 1D tensor with the three elements you want to set in the upper corner + torch::Tensor tensor_1d = torch::tensor({ 10, 11, 12 }, torch::kFloat); + + // Set the upper corner of the 3x3 tensor + auto indices = torch::triu_indices(3, 3, 1); + for (auto i = 0; i < tensor_1d.sizes()[0]; ++i) { + auto x = indices[0][i]; + auto y = indices[1][i]; + tensor_3x3[x][y] = tensor_1d[i]; + tensor_3x3[y][x] = tensor_1d[i]; } + // Print the resulting 3x3 tensor + std::cout << tensor_3x3 << std::endl; + + + + + + // std::cout << t << std::endl; // std::cout << "sum(0)" << std::endl; // std::cout << t.sum(0) << std::endl; diff --git a/src/Network.cc b/src/Network.cc index 4c039b4..8d336ac 100644 --- a/src/Network.cc +++ b/src/Network.cc @@ -98,11 +98,14 @@ namespace bayesnet { this->className = className; dataset.clear(); - // Build dataset + // Build dataset & tensor of samples + samples = torch::zeros({ static_cast(input_data[0].size()), static_cast(input_data.size() + 1) }, torch::kInt64); for (int i = 0; i < featureNames.size(); ++i) { dataset[featureNames[i]] = input_data[i]; + samples.index_put_({ "...", i }, torch::tensor(input_data[i], torch::kInt64)); } dataset[className] = labels; + samples.index_put_({ "...", -1 }, torch::tensor(labels, torch::kInt64)); classNumStates = *max_element(labels.begin(), labels.end()) + 1; int maxThreadsRunning = static_cast(std::thread::hardware_concurrency() * maxThreads); if (maxThreadsRunning < 1) { @@ -150,14 +153,14 @@ namespace bayesnet { } } - vector Network::predict(const vector>& samples) + vector Network::predict(const vector>& tsamples) { vector predictions; vector sample; - for (int row = 0; row < samples[0].size(); ++row) { + for (int row = 0; row < tsamples[0].size(); ++row) { sample.clear(); - for (int col = 0; col < samples.size(); ++col) { - sample.push_back(samples[col][row]); + for (int col = 0; col < tsamples.size(); ++col) { + sample.push_back(tsamples[col][row]); } vector classProbabilities = predict_sample(sample); // Find the class with the maximum posterior probability @@ -167,22 +170,22 @@ namespace bayesnet { } return predictions; } - vector> Network::predict_proba(const vector>& samples) + vector> Network::predict_proba(const vector>& tsamples) { vector> predictions; vector sample; - for (int row = 0; row < samples[0].size(); ++row) { + for (int row = 0; row < tsamples[0].size(); ++row) { sample.clear(); - for (int col = 0; col < samples.size(); ++col) { - sample.push_back(samples[col][row]); + for (int col = 0; col < tsamples.size(); ++col) { + sample.push_back(tsamples[col][row]); } predictions.push_back(predict_sample(sample)); } return predictions; } - double Network::score(const vector>& samples, const vector& labels) + double Network::score(const vector>& tsamples, const vector& labels) { - vector y_pred = predict(samples); + vector y_pred = predict(tsamples); int correct = 0; for (int i = 0; i < y_pred.size(); ++i) { if (y_pred[i] == labels[i]) { @@ -238,4 +241,83 @@ namespace bayesnet { } return result; } + double Network::mutual_info(torch::Tensor& first, torch::Tensor& second) + { + return 1; + } + torch::Tensor Network::conditionalEdgeWeight() + { + auto result = vector(); + auto source = vector(features); + source.push_back(className); + auto combinations = nodes[className]->combinations(source); + auto margin = nodes[className]->getCPT(); + for (auto [first, second] : combinations) { + int64_t index_first = find(features.begin(), features.end(), first) - features.begin(); + int64_t index_second = find(features.begin(), features.end(), second) - features.begin(); + double accumulated = 0; + for (int value = 0; value < classNumStates; ++value) { + auto mask = samples.index({ "...", -1 }) == value; + auto first_dataset = samples.index({ mask, index_first }); + auto second_dataset = samples.index({ mask, index_second }); + auto mi = mutualInformation(first_dataset, second_dataset); + auto pb = margin[value].item(); + accumulated += pb * mi; + } + result.push_back(accumulated); + } + long n_vars = source.size(); + auto matrix = torch::zeros({ n_vars, n_vars }); + auto indices = torch::triu_indices(n_vars, n_vars, 1); + for (auto i = 0; i < result.size(); ++i) { + auto x = indices[0][i]; + auto y = indices[1][i]; + matrix[x][y] = result[i]; + matrix[y][x] = result[i]; + } + return matrix; + } + double Network::entropy(torch::Tensor& feature) + { + torch::Tensor counts = feature.bincount(); + int totalWeight = counts.sum().item(); + torch::Tensor probs = counts.to(torch::kFloat) / totalWeight; + torch::Tensor logProbs = torch::log(probs); + torch::Tensor entropy = -probs * logProbs; + return entropy.nansum().item(); + } + // H(Y|X) = sum_{x in X} p(x) H(Y|X=x) + double Network::conditionalEntropy(torch::Tensor& firstFeature, torch::Tensor& secondFeature) + { + int numSamples = firstFeature.sizes()[0]; + torch::Tensor featureCounts = secondFeature.bincount(); + unordered_map> jointCounts; + double totalWeight = 0; + for (auto i = 0; i < numSamples; i++) { + jointCounts[secondFeature[i].item()][firstFeature[i].item()] += 1; + totalWeight += 1; + } + if (totalWeight == 0) + throw invalid_argument("Total weight should not be zero"); + double entropyValue = 0; + for (int value = 0; value < featureCounts.sizes()[0]; ++value) { + double p_f = featureCounts[value].item() / totalWeight; + double entropy_f = 0; + for (auto& [label, jointCount] : jointCounts[value]) { + double p_l_f = jointCount / featureCounts[value].item(); + if (p_l_f > 0) { + entropy_f -= p_l_f * log(p_l_f); + } else { + entropy_f = 0; + } + } + entropyValue += p_f * entropy_f; + } + return entropyValue; + } + // I(X;Y) = H(Y) - H(Y|X) + double Network::mutualInformation(torch::Tensor& firstFeature, torch::Tensor& secondFeature) + { + return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature); + } } diff --git a/src/Network.h b/src/Network.h index 7db04f9..0ba6783 100644 --- a/src/Network.h +++ b/src/Network.h @@ -19,7 +19,12 @@ namespace bayesnet { vector predict_sample(const vector&); vector exactInference(map&); double computeFactor(map&); + double mutual_info(torch::Tensor&, torch::Tensor&); + double entropy(torch::Tensor&); + double conditionalEntropy(torch::Tensor&, torch::Tensor&); + double mutualInformation(torch::Tensor&, torch::Tensor&); public: + torch::Tensor samples; Network(); Network(float, int); Network(float); @@ -35,6 +40,8 @@ namespace bayesnet { string getClassName(); void fit(const vector>&, const vector&, const vector&, const string&); vector predict(const vector>&); + //Computes the conditional edge weight of variable index u and v conditioned on class_node + torch::Tensor conditionalEdgeWeight(); vector> predict_proba(const vector>&); double score(const vector>&, const vector&); inline string version() { return "0.1.0"; } diff --git a/src/Node.cc b/src/Node.cc index 075353d..2c5a04d 100644 --- a/src/Node.cc +++ b/src/Node.cc @@ -57,23 +57,23 @@ namespace bayesnet { */ unsigned Node::minFill() { - set neighbors; + unordered_set neighbors; for (auto child : children) { neighbors.emplace(child->getName()); } for (auto parent : parents) { neighbors.emplace(parent->getName()); } - return combinations(neighbors).size(); + auto source = vector(neighbors.begin(), neighbors.end()); + return combinations(source).size(); } - vector Node::combinations(const set& neighbors) + vector> Node::combinations(const vector& source) { - vector source(neighbors.begin(), neighbors.end()); - vector result; + vector> result; for (int i = 0; i < source.size(); ++i) { string temp = source[i]; for (int j = i + 1; j < source.size(); ++j) { - result.push_back(temp + source[j]); + result.push_back({ temp, source[j] }); } } return result; diff --git a/src/Node.h b/src/Node.h index 39189ce..c7961aa 100644 --- a/src/Node.h +++ b/src/Node.h @@ -1,6 +1,7 @@ #ifndef NODE_H #define NODE_H #include +#include #include #include namespace bayesnet { @@ -13,8 +14,8 @@ namespace bayesnet { int numStates; // number of states of the variable torch::Tensor cpTable; // Order of indices is 0-> node variable, 1-> 1st parent, 2-> 2nd parent, ... vector dimensions; // dimensions of the cpTable - vector combinations(const set&); public: + vector> combinations(const vector&); Node(const std::string&, int); void addParent(Node*); void addChild(Node*);