Update BayesNetwork class

This commit is contained in:
2023-07-08 00:39:10 +02:00
parent 4bad5ccfee
commit 8a9c86a22d
2 changed files with 419 additions and 251 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -7,10 +7,10 @@ from libcpp.string cimport string
cdef extern from "Network.h" namespace "bayesnet": cdef extern from "Network.h" namespace "bayesnet":
cdef cppclass Network: cdef cppclass Network:
Network(float, float) except + Network(float, float) except +
void fit(vector[vector[int]], vector[int], vector[string], string) void fit(vector[vector[int]]&, vector[int]&, vector[string]&, string)
vector[int] predict(vector[vector[int]]) vector[int] predict(vector[vector[int]]&)
vector[vector[double]] predict_proba(vector[vector[int]]) vector[vector[double]] predict_proba(vector[vector[int]]&)
float score(const vector[vector[int]], const vector[int]) float score(const vector[vector[int]]&, const vector[int]&)
void addNode(string, int); void addNode(string, int);
void addEdge(string, string); void addEdge(string, string);
vector[string] getFeatures(); vector[string] getFeatures();
@@ -25,7 +25,8 @@ cdef class BayesNetwork:
def __dealloc__(self): def __dealloc__(self):
del self.thisptr del self.thisptr
def fit(self, X, y, features, className): def fit(self, X, y, features, className):
self.thisptr.fit(X, y, features, className) features_bytes = [x.encode() for x in features]
self.thisptr.fit(X, y, features_bytes, className.encode())
return self return self
def predict(self, X): def predict(self, X):
return self.thisptr.predict(X) return self.thisptr.predict(X)
@@ -38,9 +39,10 @@ cdef class BayesNetwork:
def addEdge(self, source, destination): def addEdge(self, source, destination):
self.thisptr.addEdge(str.encode(source), str.encode(destination)) self.thisptr.addEdge(str.encode(source), str.encode(destination))
def getFeatures(self): def getFeatures(self):
return self.thisptr.getFeatures() res = self.thisptr.getFeatures()
return [x.decode() for x in res]
def getClassName(self): def getClassName(self):
return self.thisptr.getClassName() return self.thisptr.getClassName().decode()
def getClassNumStates(self): def getClassNumStates(self):
return self.thisptr.getClassNumStates() return self.thisptr.getClassNumStates()
def __reduce__(self): def __reduce__(self):