Initial commit

This commit is contained in:
2024-12-23 15:06:56 +01:00
commit 43fe2ae2bc
67 changed files with 28332 additions and 0 deletions

View File

@@ -0,0 +1,197 @@
package com.testpbc4;
import weka.core.*;
import java.util.Random;
import java.io.File;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.util.Scanner;
import weka.classifiers.trees.PBC4cip;
import java.util.ArrayList;
import java.util.Iterator;
import weka.core.converters.ConverterUtils.DataSource;
/**
 * Command-line driver that runs repeated k-fold cross validation of the
 * {@code weka.classifiers.trees.PBC4cip} classifier over a list of ARFF
 * datasets and appends one summary line per dataset to {@link #outputFile}.
 *
 * <p>Dataset names are read, one per line, from {@link #datasetsFile}; each
 * name is resolved to {@code path + name + ".arff"}. Every dataset is
 * evaluated once per seed in a fixed seed list, and the per-fold accuracy and
 * elapsed time of all runs are summarised as mean and standard deviation.
 */
public class App {
    /** Number of cross-validation folds. */
    static int numFolds = 5;
    /** Directory prefix (including trailing separator) of the data files. */
    static String path = "data/";
    /** Text file listing the dataset names, one per line. */
    static String datasetsFile = path + "datasets.txt";
    /** File the summary lines are appended to. */
    static String outputFile = "results_pbc4.txt";
    /** When true, per-fold results are echoed to stderr and the run stops at the "balloons" dataset. */
    static boolean debug = true;
    /** When true, the data is stratified before the folds are built. */
    static boolean stratify = false;
    /** Only reported in the start-up banner; no normalization filter is applied by this class. */
    static boolean normalizeData = false;

    /**
     * Prints a message to stdout and terminates the JVM with exit status 0.
     *
     * @param s message to print
     */
    public static void printAndExit(String s) {
        System.out.println(s);
        System.exit(0);
    }

    /**
     * Prints an error message to stdout and terminates the JVM with a
     * non-zero exit status so that calling scripts can detect the failure.
     *
     * @param message error message to print
     */
    private static void failAndExit(String message) {
        System.out.println(message);
        System.exit(1);
    }

    /**
     * Loads the dataset {@code path + fileName + ".arff"}.
     *
     * <p>Any failure is reported and terminates the program.
     *
     * @param fileName dataset name without directory and without extension
     * @return the loaded instances
     */
    private static Instances getInstances(String fileName) {
        // Build the location by concatenation so that a '%' in the path or
        // name can never be misread as a format specifier.
        String arffPath = path + fileName + ".arff";
        System.out.println("Reading file " + arffPath);
        try {
            DataSource source = new DataSource(arffPath);
            Instances data = source.getDataSet();
            // Announce success only once the data has actually been loaded.
            System.out.println("File read.");
            return data;
        } catch (Exception e) {
            failAndExit(String.format("*** Error trying to read %s file... (%s)", arffPath,
                    e.getMessage()));
        }
        return null; // not reached: failAndExit terminates the JVM
    }

    /**
     * Runs one {@link #numFolds}-fold cross validation and appends one entry
     * per fold to each of the three result lists.
     *
     * <p>Note that {@code instancesIn} is shuffled (and optionally stratified)
     * in place.
     *
     * @param instancesIn data set with its class index already set
     * @param seed        seed used for the shuffle and for the classifier
     * @param accuracy    receives the fraction of correctly classified test instances of each fold
     * @param timeSpent   receives the seconds spent splitting, training and testing each fold
     * @param complexity  receives 0.0 for each fold (no complexity measure is computed here)
     * @throws Exception if splitting, training or classification fails
     */
    private static void validation(Instances instancesIn, int seed, ArrayList<Double> accuracy,
            ArrayList<Double> timeSpent, ArrayList<Double> complexity) throws Exception {
        PBC4cip pb4 = new PBC4cip();
        pb4.setSeed(seed);

        // The normalizeData flag is not honoured: the data is used as given.
        Instances instances = instancesIn;
        instances.randomize(new Random(seed));
        if (stratify) {
            instances.stratify(numFolds);
        }

        for (int fold = 0; fold < numFolds; fold++) {
            long startTime = System.currentTimeMillis();
            Instances training = instances.trainCV(numFolds, fold);
            Instances test = instances.testCV(numFolds, fold);
            pb4.buildClassifier(training);

            // Count the test instances whose predicted class matches the true one.
            int hits = 0;
            for (int j = 0; j < test.numInstances(); j++) {
                double predicted = pb4.classifyInstance(test.instance(j));
                double trueValue = test.instance(j).value(test.classIndex());
                if (trueValue == predicted) {
                    hits++;
                }
            }

            double acc = hits / (double) test.numInstances();
            double tspent = (System.currentTimeMillis() - startTime) / 1000.0;
            if (debug) {
                System.err.println(String.format("%f, %f", acc, tspent));
            }
            accuracy.add(acc);
            timeSpent.add(tspent);
            complexity.add(0.0);
        }
    }

    /**
     * Computes the arithmetic mean and the population standard deviation
     * (divisor {@code n}) of the given values.
     *
     * @param input values to summarise
     * @return a two-element array: index 0 is the mean, index 1 the standard
     *         deviation; both are 0 when {@code input} is empty
     */
    private static double[] meanAndDeviation(ArrayList<Double> input) {
        double[] results = new double[2]; // 0 -> mean, 1 -> std
        int n = input.size();
        if (n == 0) {
            return results;
        }

        double sum = 0.0;
        for (double value : input) {
            sum += value;
        }
        results[0] = sum / (double) n;

        double squaredDiffs = 0.0;
        for (double value : input) {
            squaredDiffs += Math.pow(value - results[0], 2);
        }
        results[1] = Math.sqrt(squaredDiffs / (double) n);
        return results;
    }

    /**
     * Removes a previous output file so that the results of this run are not
     * appended to stale data. Terminates the program if an existing file
     * cannot be removed.
     */
    private static void initializeOutput() {
        File f = new File(outputFile);
        try {
            if (f.exists() && !f.delete()) {
                failAndExit(String.format("*** Error trying to delete the output file %s... (%s)",
                        outputFile, "file could not be deleted"));
            }
        } catch (SecurityException e) {
            failAndExit(String.format("*** Error trying to delete the output file %s... (%s)",
                    outputFile, e.getMessage()));
        }
    }

    /**
     * Appends one semicolon-separated summary line for a dataset to
     * {@link #outputFile}.
     *
     * @param dataset    dataset name
     * @param accuracy   mean (index 0) and deviation (index 1) of the accuracy
     * @param timeSpent  mean (index 0) and deviation (index 1) of the elapsed time
     * @param complexity mean (index 0) and deviation (index 1) of the complexity
     * @throws Exception if the file cannot be opened or written
     */
    private static void store(String dataset, double[] accuracy, double[] timeSpent, double[] complexity)
            throws Exception {
        // Append output; try-with-resources closes the file even if formatting fails.
        try (PrintWriter linePrint = new PrintWriter(new FileWriter(outputFile, true))) {
            linePrint.printf("%s; %f; %f; %f; %f; %f; %f%n", dataset, accuracy[0], accuracy[1],
                    timeSpent[0], timeSpent[1], complexity[0], complexity[1]);
            // PrintWriter never throws on write errors, so ask for them explicitly.
            if (linePrint.checkError()) {
                throw new java.io.IOException("could not write to " + outputFile);
            }
        }
    }

    /**
     * Opens the dataset list for reading. Terminates the program if the file
     * cannot be opened.
     *
     * @return a scanner positioned at the start of {@link #datasetsFile}
     */
    private static Scanner openDatasetList() {
        try {
            return new Scanner(new File(datasetsFile));
        } catch (java.io.FileNotFoundException e) {
            failAndExit(String.format("*** Error trying to read datasets file... (%s)",
                    e.getMessage()));
        }
        return null; // not reached: failAndExit terminates the JVM
    }

    /**
     * Program entry point: evaluates every dataset listed in
     * {@link #datasetsFile} with every seed and stores one summary line per
     * dataset.
     *
     * @param args ignored
     * @throws Exception declared for compatibility; failures are reported and end the program
     */
    public static void main(String[] args) throws Exception {
        int[] seeds = { 57, 31, 1714, 17, 23, 79, 83, 97, 7, 1 };

        String debugMsg = debug ? "With debug output" : "";
        String stratMsg = stratify ? "stratified" : "without stratification";
        String normMsg = normalizeData ? " with normalization" : " without normalization";
        System.out.println(String.format("%d fold cross validation %s %s %s",
                numFolds, stratMsg, debugMsg, normMsg));

        initializeOutput();

        try (Scanner sc = openDatasetList()) {
            while (sc.hasNextLine()) {
                String dataset = sc.nextLine().trim();
                if (dataset.isEmpty()) {
                    continue; // tolerate blank lines in the dataset list
                }
                if (debug && dataset.equals("balloons")) {
                    printAndExit("* Check error output. Debug End.");
                }

                ArrayList<Double> timeSpent = new ArrayList<>();
                ArrayList<Double> accuracy = new ArrayList<>();
                ArrayList<Double> complexity = new ArrayList<>();

                for (int seed : seeds) {
                    // Set up the seeded generator
                    Random gen = new Random(seed);
                    // Load the data afresh for every seed
                    Instances instances = getInstances(dataset);
                    instances.randomize(gen);
                    instances.setClassIndex(instances.numAttributes() - 1);
                    System.out.println("Instances:");
                    System.out.println(instances.toString());
                    try {
                        validation(instances, seed, accuracy, timeSpent, complexity);
                    } catch (Exception e) {
                        failAndExit(String.format("*** Error training dataset %s... (%s)", dataset,
                                e.getMessage()));
                    }
                }

                double[] accuracyStat = meanAndDeviation(accuracy);
                double[] timeStat = meanAndDeviation(timeSpent);
                double[] complexityStat = meanAndDeviation(complexity);
                try {
                    store(dataset, accuracyStat, timeStat, complexityStat);
                } catch (Exception e) {
                    failAndExit(String.format("*** Error storing results of dataset %s... (%s)",
                            dataset, e.getMessage()));
                }
            }
        }
        printAndExit("* End.");
    }
}