Complete integration with BayesNet

This commit is contained in:
2023-07-07 19:19:52 +02:00
parent 5866e19fae
commit 4bad5ccfee
4 changed files with 620 additions and 269 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -2,13 +2,8 @@
# cython: language_level = 3
from libcpp.vector cimport vector
from libcpp.string cimport string
from libcpp.pair cimport pair
from libcpp cimport bool
cdef extern from "Node.h" namespace "bayesnet":
cdef cppclass Node:
pass
cdef extern from "Network.h" namespace "bayesnet":
cdef cppclass Network:
Network(float, float) except +
@@ -39,9 +34,9 @@ cdef class BayesNetwork:
def score(self, X, y):
return self.thisptr.score(X, y)
def addNode(self, name, states):
self.thisptr.addNode(name, states)
self.thisptr.addNode(str.encode(name), states)
def addEdge(self, source, destination):
self.thisptr.addEdge(source, destination)
self.thisptr.addEdge(str.encode(source), str.encode(destination))
def getFeatures(self):
return self.thisptr.getFeatures()
def getClassName(self):

View File

@@ -39,9 +39,7 @@ classifiers = [
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
]
[project.optional-dependencies]
@@ -61,7 +59,7 @@ show_missing = true
[tool.black]
line-length = 79
target_version = ['py38', 'py39', 'py310']
target_version = ['py311']
include = '\.pyi?$'
exclude = '''
/(

View File

@@ -5,7 +5,12 @@
"""
from setuptools import Extension, setup
from torch.utils import cpp_extension
from torch.utils.cpp_extension import (
BuildExtension,
CppExtension,
include_paths,
)
setup(
ext_modules=[
@@ -21,18 +26,15 @@ setup(
"-std=c++17",
],
),
Extension(
name="bayesclass.cppBayesNetwork",
CppExtension(
name="bayesclass.BayesNet",
sources=[
"bayesclass/BayesNetwork.pyx",
"bayesclass/Network.cc",
"bayesclass/Node.cc",
],
include_dirs=cpp_extension.include_paths(),
language="c++",
extra_compile_args=[
"-std=c++17",
],
include_dirs=include_paths(),
),
]
],
cmdclass={"build_ext": BuildExtension},
)