Complete integration with BayesNet

This commit is contained in:
2023-07-07 19:19:52 +02:00
parent 5866e19fae
commit 4bad5ccfee
4 changed files with 620 additions and 269 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -2,13 +2,8 @@
# cython: language_level = 3 # cython: language_level = 3
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libcpp.string cimport string from libcpp.string cimport string
from libcpp.pair cimport pair
from libcpp cimport bool
cdef extern from "Node.h" namespace "bayesnet":
cdef cppclass Node:
pass
cdef extern from "Network.h" namespace "bayesnet": cdef extern from "Network.h" namespace "bayesnet":
cdef cppclass Network: cdef cppclass Network:
Network(float, float) except + Network(float, float) except +
@@ -39,9 +34,9 @@ cdef class BayesNetwork:
def score(self, X, y): def score(self, X, y):
return self.thisptr.score(X, y) return self.thisptr.score(X, y)
def addNode(self, name, states): def addNode(self, name, states):
self.thisptr.addNode(name, states) self.thisptr.addNode(str.encode(name), states)
def addEdge(self, source, destination): def addEdge(self, source, destination):
self.thisptr.addEdge(source, destination) self.thisptr.addEdge(str.encode(source), str.encode(destination))
def getFeatures(self): def getFeatures(self):
return self.thisptr.getFeatures() return self.thisptr.getFeatures()
def getClassName(self): def getClassName(self):

View File

@@ -39,9 +39,7 @@ classifiers = [
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Programming Language :: Python", "Programming Language :: Python",
"Programming Language :: Python", "Programming Language :: Python",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
] ]
[project.optional-dependencies] [project.optional-dependencies]
@@ -61,7 +59,7 @@ show_missing = true
[tool.black] [tool.black]
line-length = 79 line-length = 79
target_version = ['py38', 'py39', 'py310'] target_version = ['py311']
include = '\.pyi?$' include = '\.pyi?$'
exclude = ''' exclude = '''
/( /(

View File

@@ -5,7 +5,12 @@
""" """
from setuptools import Extension, setup from setuptools import Extension, setup
from torch.utils import cpp_extension from torch.utils.cpp_extension import (
BuildExtension,
CppExtension,
include_paths,
)
setup( setup(
ext_modules=[ ext_modules=[
@@ -21,18 +26,15 @@ setup(
"-std=c++17", "-std=c++17",
], ],
), ),
Extension( CppExtension(
name="bayesclass.cppBayesNetwork", name="bayesclass.BayesNet",
sources=[ sources=[
"bayesclass/BayesNetwork.pyx", "bayesclass/BayesNetwork.pyx",
"bayesclass/Network.cc", "bayesclass/Network.cc",
"bayesclass/Node.cc", "bayesclass/Node.cc",
], ],
include_dirs=cpp_extension.include_paths(), include_dirs=include_paths(),
language="c++",
extra_compile_args=[
"-std=c++17",
],
), ),
] ],
cmdclass={"build_ext": BuildExtension},
) )