Complete proposal
This commit is contained in:
@@ -17,14 +17,90 @@ namespace bayesnet {
|
||||
Network::Network() : fitted{ false }, classNumStates{ 0 }
|
||||
{
|
||||
}
|
||||
Network::Network(const Network& other) : features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
|
||||
fitted(other.fitted), samples(other.samples)
|
||||
Network::Network(const Network& other)
|
||||
: features(other.features), className(other.className), classNumStates(other.classNumStates),
|
||||
fitted(other.fitted)
|
||||
{
|
||||
if (samples.defined())
|
||||
samples = samples.clone();
|
||||
// Deep copy the samples tensor
|
||||
if (other.samples.defined()) {
|
||||
samples = other.samples.clone();
|
||||
}
|
||||
|
||||
// First, create all nodes (without relationships)
|
||||
for (const auto& node : other.nodes) {
|
||||
nodes[node.first] = std::make_unique<Node>(*node.second);
|
||||
}
|
||||
|
||||
// Second, reconstruct the relationships between nodes
|
||||
for (const auto& node : other.nodes) {
|
||||
const std::string& nodeName = node.first;
|
||||
Node* originalNode = node.second.get();
|
||||
Node* newNode = nodes[nodeName].get();
|
||||
|
||||
// Reconstruct parent relationships
|
||||
for (Node* parent : originalNode->getParents()) {
|
||||
const std::string& parentName = parent->getName();
|
||||
if (nodes.find(parentName) != nodes.end()) {
|
||||
newNode->addParent(nodes[parentName].get());
|
||||
}
|
||||
}
|
||||
|
||||
// Reconstruct child relationships
|
||||
for (Node* child : originalNode->getChildren()) {
|
||||
const std::string& childName = child->getName();
|
||||
if (nodes.find(childName) != nodes.end()) {
|
||||
newNode->addChild(nodes[childName].get());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Network& Network::operator=(const Network& other)
|
||||
{
|
||||
if (this != &other) {
|
||||
// Clear existing state
|
||||
nodes.clear();
|
||||
features = other.features;
|
||||
className = other.className;
|
||||
classNumStates = other.classNumStates;
|
||||
fitted = other.fitted;
|
||||
|
||||
// Deep copy the samples tensor
|
||||
if (other.samples.defined()) {
|
||||
samples = other.samples.clone();
|
||||
} else {
|
||||
samples = torch::Tensor();
|
||||
}
|
||||
|
||||
// First, create all nodes (without relationships)
|
||||
for (const auto& node : other.nodes) {
|
||||
nodes[node.first] = std::make_unique<Node>(*node.second);
|
||||
}
|
||||
|
||||
// Second, reconstruct the relationships between nodes
|
||||
for (const auto& node : other.nodes) {
|
||||
const std::string& nodeName = node.first;
|
||||
Node* originalNode = node.second.get();
|
||||
Node* newNode = nodes[nodeName].get();
|
||||
|
||||
// Reconstruct parent relationships
|
||||
for (Node* parent : originalNode->getParents()) {
|
||||
const std::string& parentName = parent->getName();
|
||||
if (nodes.find(parentName) != nodes.end()) {
|
||||
newNode->addParent(nodes[parentName].get());
|
||||
}
|
||||
}
|
||||
|
||||
// Reconstruct child relationships
|
||||
for (Node* child : originalNode->getChildren()) {
|
||||
const std::string& childName = child->getName();
|
||||
if (nodes.find(childName) != nodes.end()) {
|
||||
newNode->addChild(nodes[childName].get());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
void Network::initialize()
|
||||
{
|
||||
@@ -503,4 +579,41 @@ namespace bayesnet {
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
bool Network::operator==(const Network& other) const
|
||||
{
|
||||
// Compare number of nodes
|
||||
if (nodes.size() != other.nodes.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compare if all node names exist in both networks
|
||||
for (const auto& node : nodes) {
|
||||
if (other.nodes.find(node.first) == other.nodes.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Compare edges (topology)
|
||||
auto thisEdges = getEdges();
|
||||
auto otherEdges = other.getEdges();
|
||||
|
||||
// Compare number of edges
|
||||
if (thisEdges.size() != otherEdges.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Sort both edge lists for comparison
|
||||
std::sort(thisEdges.begin(), thisEdges.end());
|
||||
std::sort(otherEdges.begin(), otherEdges.end());
|
||||
|
||||
// Compare each edge
|
||||
for (size_t i = 0; i < thisEdges.size(); ++i) {
|
||||
if (thisEdges[i] != otherEdges[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@@ -17,7 +17,8 @@ namespace bayesnet {
|
||||
class Network {
|
||||
public:
|
||||
Network();
|
||||
explicit Network(const Network&);
|
||||
Network(const Network& other);
|
||||
Network& operator=(const Network& other);
|
||||
~Network() = default;
|
||||
torch::Tensor& getSamples();
|
||||
void addNode(const std::string&);
|
||||
@@ -47,6 +48,7 @@ namespace bayesnet {
|
||||
void initialize();
|
||||
std::string dump_cpt() const;
|
||||
inline std::string version() { return { project_version.begin(), project_version.end() }; }
|
||||
bool operator==(const Network& other) const;
|
||||
private:
|
||||
std::map<std::string, std::unique_ptr<Node>> nodes;
|
||||
bool fitted;
|
||||
|
@@ -13,6 +13,41 @@ namespace bayesnet {
|
||||
: name(name)
|
||||
{
|
||||
}
|
||||
|
||||
Node::Node(const Node& other)
|
||||
: name(other.name), numStates(other.numStates), dimensions(other.dimensions)
|
||||
{
|
||||
// Deep copy the CPT tensor
|
||||
if (other.cpTable.defined()) {
|
||||
cpTable = other.cpTable.clone();
|
||||
}
|
||||
// Note: parent and children pointers are NOT copied here
|
||||
// They will be reconstructed by the Network copy constructor
|
||||
// to maintain proper object relationships
|
||||
}
|
||||
|
||||
Node& Node::operator=(const Node& other)
|
||||
{
|
||||
if (this != &other) {
|
||||
name = other.name;
|
||||
numStates = other.numStates;
|
||||
dimensions = other.dimensions;
|
||||
|
||||
// Deep copy the CPT tensor
|
||||
if (other.cpTable.defined()) {
|
||||
cpTable = other.cpTable.clone();
|
||||
} else {
|
||||
cpTable = torch::Tensor();
|
||||
}
|
||||
|
||||
// Clear existing relationships
|
||||
parents.clear();
|
||||
children.clear();
|
||||
// Note: parent and children pointers are NOT copied here
|
||||
// They must be reconstructed to maintain proper object relationships
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
void Node::clear()
|
||||
{
|
||||
parents.clear();
|
||||
|
@@ -14,6 +14,9 @@ namespace bayesnet {
|
||||
class Node {
|
||||
public:
|
||||
explicit Node(const std::string&);
|
||||
Node(const Node& other);
|
||||
Node& operator=(const Node& other);
|
||||
~Node() = default;
|
||||
void clear();
|
||||
void addParent(Node*);
|
||||
void addChild(Node*);
|
||||
|
Reference in New Issue
Block a user