Update BayesNetwork class

This commit is contained in:
2023-07-08 00:39:10 +02:00
parent 4bad5ccfee
commit 8a9c86a22d
2 changed files with 419 additions and 251 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -7,10 +7,10 @@ from libcpp.string cimport string
cdef extern from "Network.h" namespace "bayesnet":
cdef cppclass Network:
Network(float, float) except +
void fit(vector[vector[int]], vector[int], vector[string], string)
vector[int] predict(vector[vector[int]])
vector[vector[double]] predict_proba(vector[vector[int]])
float score(const vector[vector[int]], const vector[int])
void fit(vector[vector[int]]&, vector[int]&, vector[string]&, string)
vector[int] predict(vector[vector[int]]&)
vector[vector[double]] predict_proba(vector[vector[int]]&)
float score(const vector[vector[int]]&, const vector[int]&)
void addNode(string, int);
void addEdge(string, string);
vector[string] getFeatures();
@@ -25,7 +25,8 @@ cdef class BayesNetwork:
def __dealloc__(self):
del self.thisptr
def fit(self, X, y, features, className):
self.thisptr.fit(X, y, features, className)
features_bytes = [x.encode() for x in features]
self.thisptr.fit(X, y, features_bytes, className.encode())
return self
def predict(self, X):
return self.thisptr.predict(X)
@@ -38,9 +39,10 @@ cdef class BayesNetwork:
def addEdge(self, source, destination):
self.thisptr.addEdge(str.encode(source), str.encode(destination))
def getFeatures(self):
return self.thisptr.getFeatures()
res = self.thisptr.getFeatures()
return [x.decode() for x in res]
def getClassName(self):
return self.thisptr.getClassName()
return self.thisptr.getClassName().decode()
def getClassNumStates(self):
return self.thisptr.getClassNumStates()
def __reduce__(self):