Complete proposal
This commit is contained in:
@@ -17,14 +17,90 @@ namespace bayesnet {
|
||||
Network::Network() : fitted{ false }, classNumStates{ 0 }
|
||||
{
|
||||
}
|
||||
Network::Network(const Network& other) : features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
|
||||
fitted(other.fitted), samples(other.samples)
|
||||
Network::Network(const Network& other)
|
||||
: features(other.features), className(other.className), classNumStates(other.classNumStates),
|
||||
fitted(other.fitted)
|
||||
{
|
||||
if (samples.defined())
|
||||
samples = samples.clone();
|
||||
// Deep copy the samples tensor
|
||||
if (other.samples.defined()) {
|
||||
samples = other.samples.clone();
|
||||
}
|
||||
|
||||
// First, create all nodes (without relationships)
|
||||
for (const auto& node : other.nodes) {
|
||||
nodes[node.first] = std::make_unique<Node>(*node.second);
|
||||
}
|
||||
|
||||
// Second, reconstruct the relationships between nodes
|
||||
for (const auto& node : other.nodes) {
|
||||
const std::string& nodeName = node.first;
|
||||
Node* originalNode = node.second.get();
|
||||
Node* newNode = nodes[nodeName].get();
|
||||
|
||||
// Reconstruct parent relationships
|
||||
for (Node* parent : originalNode->getParents()) {
|
||||
const std::string& parentName = parent->getName();
|
||||
if (nodes.find(parentName) != nodes.end()) {
|
||||
newNode->addParent(nodes[parentName].get());
|
||||
}
|
||||
}
|
||||
|
||||
// Reconstruct child relationships
|
||||
for (Node* child : originalNode->getChildren()) {
|
||||
const std::string& childName = child->getName();
|
||||
if (nodes.find(childName) != nodes.end()) {
|
||||
newNode->addChild(nodes[childName].get());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Network& Network::operator=(const Network& other)
|
||||
{
|
||||
if (this != &other) {
|
||||
// Clear existing state
|
||||
nodes.clear();
|
||||
features = other.features;
|
||||
className = other.className;
|
||||
classNumStates = other.classNumStates;
|
||||
fitted = other.fitted;
|
||||
|
||||
// Deep copy the samples tensor
|
||||
if (other.samples.defined()) {
|
||||
samples = other.samples.clone();
|
||||
} else {
|
||||
samples = torch::Tensor();
|
||||
}
|
||||
|
||||
// First, create all nodes (without relationships)
|
||||
for (const auto& node : other.nodes) {
|
||||
nodes[node.first] = std::make_unique<Node>(*node.second);
|
||||
}
|
||||
|
||||
// Second, reconstruct the relationships between nodes
|
||||
for (const auto& node : other.nodes) {
|
||||
const std::string& nodeName = node.first;
|
||||
Node* originalNode = node.second.get();
|
||||
Node* newNode = nodes[nodeName].get();
|
||||
|
||||
// Reconstruct parent relationships
|
||||
for (Node* parent : originalNode->getParents()) {
|
||||
const std::string& parentName = parent->getName();
|
||||
if (nodes.find(parentName) != nodes.end()) {
|
||||
newNode->addParent(nodes[parentName].get());
|
||||
}
|
||||
}
|
||||
|
||||
// Reconstruct child relationships
|
||||
for (Node* child : originalNode->getChildren()) {
|
||||
const std::string& childName = child->getName();
|
||||
if (nodes.find(childName) != nodes.end()) {
|
||||
newNode->addChild(nodes[childName].get());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
void Network::initialize()
|
||||
{
|
||||
@@ -503,4 +579,41 @@ namespace bayesnet {
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
bool Network::operator==(const Network& other) const
|
||||
{
|
||||
// Compare number of nodes
|
||||
if (nodes.size() != other.nodes.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compare if all node names exist in both networks
|
||||
for (const auto& node : nodes) {
|
||||
if (other.nodes.find(node.first) == other.nodes.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Compare edges (topology)
|
||||
auto thisEdges = getEdges();
|
||||
auto otherEdges = other.getEdges();
|
||||
|
||||
// Compare number of edges
|
||||
if (thisEdges.size() != otherEdges.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Sort both edge lists for comparison
|
||||
std::sort(thisEdges.begin(), thisEdges.end());
|
||||
std::sort(otherEdges.begin(), otherEdges.end());
|
||||
|
||||
// Compare each edge
|
||||
for (size_t i = 0; i < thisEdges.size(); ++i) {
|
||||
if (thisEdges[i] != otherEdges[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user