mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-17 16:45:54 +00:00
Update BayesNetwork class
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -7,10 +7,10 @@ from libcpp.string cimport string
|
|||||||
cdef extern from "Network.h" namespace "bayesnet":
|
cdef extern from "Network.h" namespace "bayesnet":
|
||||||
cdef cppclass Network:
|
cdef cppclass Network:
|
||||||
Network(float, float) except +
|
Network(float, float) except +
|
||||||
void fit(vector[vector[int]], vector[int], vector[string], string)
|
void fit(vector[vector[int]]&, vector[int]&, vector[string]&, string)
|
||||||
vector[int] predict(vector[vector[int]])
|
vector[int] predict(vector[vector[int]]&)
|
||||||
vector[vector[double]] predict_proba(vector[vector[int]])
|
vector[vector[double]] predict_proba(vector[vector[int]]&)
|
||||||
float score(const vector[vector[int]], const vector[int])
|
float score(const vector[vector[int]]&, const vector[int]&)
|
||||||
void addNode(string, int);
|
void addNode(string, int);
|
||||||
void addEdge(string, string);
|
void addEdge(string, string);
|
||||||
vector[string] getFeatures();
|
vector[string] getFeatures();
|
||||||
@@ -25,7 +25,8 @@ cdef class BayesNetwork:
|
|||||||
def __dealloc__(self):
|
def __dealloc__(self):
|
||||||
del self.thisptr
|
del self.thisptr
|
||||||
def fit(self, X, y, features, className):
|
def fit(self, X, y, features, className):
|
||||||
self.thisptr.fit(X, y, features, className)
|
features_bytes = [x.encode() for x in features]
|
||||||
|
self.thisptr.fit(X, y, features_bytes, className.encode())
|
||||||
return self
|
return self
|
||||||
def predict(self, X):
|
def predict(self, X):
|
||||||
return self.thisptr.predict(X)
|
return self.thisptr.predict(X)
|
||||||
@@ -38,9 +39,10 @@ cdef class BayesNetwork:
|
|||||||
def addEdge(self, source, destination):
|
def addEdge(self, source, destination):
|
||||||
self.thisptr.addEdge(str.encode(source), str.encode(destination))
|
self.thisptr.addEdge(str.encode(source), str.encode(destination))
|
||||||
def getFeatures(self):
|
def getFeatures(self):
|
||||||
return self.thisptr.getFeatures()
|
res = self.thisptr.getFeatures()
|
||||||
|
return [x.decode() for x in res]
|
||||||
def getClassName(self):
|
def getClassName(self):
|
||||||
return self.thisptr.getClassName()
|
return self.thisptr.getClassName().decode()
|
||||||
def getClassNumStates(self):
|
def getClassNumStates(self):
|
||||||
return self.thisptr.getClassNumStates()
|
return self.thisptr.getClassNumStates()
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
|
Reference in New Issue
Block a user