Add getStates

This commit is contained in:
2023-07-11 21:28:29 +02:00
parent 36cc875615
commit 8b6624e08a
5 changed files with 306 additions and 197 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -15,6 +15,7 @@ cdef extern from "Network.h" namespace "bayesnet":
void addEdge(string, string); void addEdge(string, string);
vector[string] getFeatures(); vector[string] getFeatures();
int getClassNumStates(); int getClassNumStates();
int getStates();
string getClassName(); string getClassName();
string version() string version()
@@ -45,6 +46,8 @@ cdef class BayesNetwork:
def getFeatures(self): def getFeatures(self):
res = self.thisptr.getFeatures() res = self.thisptr.getFeatures()
return [x.decode() for x in res] return [x.decode() for x in res]
def getStates(self):
return self.thisptr.getStates()
def getClassName(self): def getClassName(self):
return self.thisptr.getClassName().decode() return self.thisptr.getClassName().decode()
def getClassNumStates(self): def getClassNumStates(self):

View File

@@ -38,6 +38,14 @@ namespace bayesnet {
{ {
return classNumStates; return classNumStates;
} }
// Returns the total number of discrete states in the network, i.e. the
// sum of the cardinalities of every node (features and class alike,
// since it walks the whole `nodes` map).
int Network::getStates()
{
    int result = 0;
    // Iterate by const reference: copying the map's value_type would
    // copy the node-name string key on every iteration
    // (clang-tidy performance-for-range-copy).
    for (const auto& node : nodes) {
        result += node.second->getNumStates();
    }
    return result;
}
string Network::getClassName() string Network::getClassName()
{ {
return className; return className;

View File

@@ -30,6 +30,7 @@ namespace bayesnet {
void addEdge(const string, const string); void addEdge(const string, const string);
map<string, Node*>& getNodes(); map<string, Node*>& getNodes();
vector<string> getFeatures(); vector<string> getFeatures();
int getStates();
int getClassNumStates(); int getClassNumStates();
string getClassName(); string getClassName();
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&); void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);

View File

@@ -95,7 +95,7 @@ class BayesBase(BaseEstimator, ClassifierMixin):
@property @property
def states_(self): def states_(self):
if hasattr(self, "fitted_"): if hasattr(self, "fitted_"):
return sum([len(item) for _, item in self.model_.states.items()]) return self.states_computed_
return 0 return 0
@property @property
@@ -180,14 +180,15 @@ class BayesBase(BaseEstimator, ClassifierMixin):
# ) # )
self.model_ = BayesNetwork() self.model_ = BayesNetwork()
features = kwargs["features"] features = kwargs["features"]
for i, feature in enumerate(features): states = kwargs["state_names"]
maxf = max(self.X_[:, i] + 1) for feature in features:
self.model_.addNode(feature, maxf) self.model_.addNode(feature, len(states[feature]))
class_name = kwargs["class_name"] class_name = kwargs["class_name"]
self.model_.addNode(class_name, max(self.y_) + 1) self.model_.addNode(class_name, max(self.y_) + 1)
for source, destination in self.dag_.edges(): for source, destination in self.dag_.edges():
self.model_.addEdge(source, destination) self.model_.addEdge(source, destination)
self.model_.fit(self.X_, self.y_, features, class_name) self.model_.fit(self.X_, self.y_, features, class_name)
self.states_computed_ = self.model_.getStates()
def predict(self, X): def predict(self, X):
"""A reference implementation of a prediction for a classifier. """A reference implementation of a prediction for a classifier.
@@ -381,7 +382,7 @@ class KDB(BayesBase):
def _build(self): def _build(self):
""" """
1. For each feature Xi, compute mutual information, I(X;;C), 1. For each feature Xi, compute mutual information, I(X;C),
where C is the class. where C is the class.
2. Compute class conditional mutual information I(Xi;Xj|C), for each 2. Compute class conditional mutual information I(Xi;Xj|C), for each
pair of features Xi and Xj, where i≠j. pair of features Xi and Xj, where i≠j.
@@ -407,6 +408,37 @@ class KDB(BayesBase):
)._get_conditional_weights( )._get_conditional_weights(
self.dataset_, self.class_name_, show_progress=self.show_progress self.dataset_, self.class_name_, show_progress=self.show_progress
) )
'''
# Step 1: Compute edge weights for a fully connected graph.
n_vars = len(data.columns)
pbar = combinations(data.columns, 2)
if show_progress and SHOW_PROGRESS:
pbar = tqdm(pbar, total=(n_vars * (n_vars - 1) / 2), desc="Building tree")
def _conditional_edge_weights_fn(u, v):
"""
Computes the conditional edge weight of variable index u and v conditioned on class_node
"""
cond_marginal = data.loc[:, class_node].value_counts() / data.shape[0]
cond_edge_weight = 0
for index, marg_prob in cond_marginal.items():
df_cond_subset = data[data.loc[:, class_node] == index]
cond_edge_weight += marg_prob * edge_weights_fn(
df_cond_subset.loc[:, u], df_cond_subset.loc[:, v]
)
return cond_edge_weight
vals = Parallel(n_jobs=1, prefer="threads")(
delayed(_conditional_edge_weights_fn)(u, v) for u, v in pbar
)
weights = np.zeros((n_vars, n_vars))
indices = np.triu_indices(n_vars, k=1)
weights[indices] = vals
weights.T[indices] = vals
return weights
'''
# 3. Let the used variable list, S, be empty. # 3. Let the used variable list, S, be empty.
S_nodes = [] S_nodes = []
# 4. Let the DAG being constructed, BN, begin with a single class node # 4. Let the DAG being constructed, BN, begin with a single class node