Files
Platform/src/experimental_clfs/XA1DE.cpp

21 lines
819 B
C++

// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2025 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#include "XA1DE.h"
#include <stdexcept>
#include "TensorUtils.hpp"
namespace platform {
    /**
     * @brief Fits the wrapped model (aode_) on the data stored in `dataset`.
     *
     * Every row of `dataset` except the last one is converted into the matrix X,
     * and the last row is converted into the vector y. Both are then handed to
     * aode_.fit together with `features`, `className` and `states`.
     * NOTE(review): presumably rows are features and columns are instances —
     * confirm against TensorUtils::to_matrix.
     *
     * @param weights   Currently NOT used: `weights_` is overwritten with a tensor of
     *                  ones (one entry per element of the first row of X) and that
     *                  tensor is what is passed to aode_.fit.
     *                  NOTE(review): confirm that ignoring the caller's weights is
     *                  intended; the normalization call below is disabled as well.
     * @param smoothing Forwarded unchanged to aode_.fit.
     * @throws std::invalid_argument if `dataset` yields no rows for X (previously
     *         this case indexed an empty vector, which is undefined behaviour).
     */
    void XA1DE::trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing)
    {
        auto X = TensorUtils::to_matrix(dataset.slice(0, 0, dataset.size(0) - 1));
        // Guard before touching the first row: an empty X cannot be indexed.
        if (X.empty()) {
            throw std::invalid_argument("XA1DE::trainModel: dataset has no feature rows");
        }
        auto y = TensorUtils::to_vector<int>(dataset.index({ -1, "..." }));
        // Length of the first row of X; the size_t -> int conversion is made explicit.
        const int num_instances = static_cast<int>(X.front().size());
        // Uniform weights: every entry starts at 1.0.
        weights_ = torch::full({ num_instances }, 1.0);
        //normalize_weights(num_instances);
        // NOTE(review): the meaning of the literal `true` depends on aode_.fit's
        // signature, which is not visible here.
        aode_.fit(X, y, features, className, states, weights_, true, smoothing);
    }
}