184 lines
6.7 KiB
Java
184 lines
6.7 KiB
Java
package com.testpbc4;
|
|
|
|
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
import java.util.Scanner;

import org.apache.commons.lang3.StringUtils;

import weka.classifiers.trees.PBC4cip;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
|
|
|
|
public class App {
|
|
static int numFolds = 5;
|
|
static String path = "data/";
|
|
static String datasetsFile = path + "datasets.txt";
|
|
static String outputFile = "results_pbc4.txt";
|
|
static boolean debug = false;
|
|
|
|
public static void printAndExit(String s) {
|
|
System.out.println(s);
|
|
System.exit(0);
|
|
}
|
|
|
|
private static Instances getInstances(String fileName) {
|
|
DataSource source = null;
|
|
try {
|
|
source = new DataSource(String.format(path + "%s.arff", fileName));
|
|
return source.getDataSet();
|
|
} catch (Exception e) {
|
|
printAndExit(
|
|
String.format("*** Error trying to read " + path + "%s.arff file... (%s)", fileName,
|
|
e.getMessage()));
|
|
}
|
|
return null;
|
|
}
|
|
|
|
private static void validation(Instances instancesIn, int seed, ArrayList<Double> accuracy,
|
|
ArrayList<Double> timeSpent, ArrayList<Double> complexity) throws Exception {
|
|
PBC4cip pb4 = new PBC4cip();
|
|
Random rand = new Random(seed);
|
|
pb4.setSeed(seed);
|
|
Instances instances;
|
|
Instances training = null;
|
|
Instances test = null;
|
|
|
|
instances = instancesIn;
|
|
instances.randomize(rand);
|
|
instances.stratify(numFolds);
|
|
for (int i = 0; i < numFolds; i++) {
|
|
long startTime = System.currentTimeMillis();
|
|
training = instances.trainCV(numFolds, i);
|
|
test = instances.testCV(numFolds, i);
|
|
pb4.buildClassifier(training);
|
|
// let's classify
|
|
int aciertos = 0;
|
|
for (int j = 0; j < test.numInstances(); j++) {
|
|
double v = pb4.classifyInstance(test.instance(j));
|
|
double trueValue = test.instance(j).value(test.classIndex());
|
|
if (trueValue == v)
|
|
aciertos++;
|
|
}
|
|
double acc = aciertos / (double) test.numInstances();
|
|
double tspent = (System.currentTimeMillis() - startTime) / 1000.0;
|
|
if (debug) {
|
|
System.err.println(String.format("%f, %f", acc, tspent));
|
|
}
|
|
accuracy.add(acc);
|
|
timeSpent.add(tspent);
|
|
complexity.add(0.0);
|
|
}
|
|
}
|
|
|
|
private static double[] meanAndDeviation(ArrayList<Double> input) {
|
|
double sum = 0.0;
|
|
int i = 0;
|
|
double[] results = new double[4]; // 0 -> mean, 1 -> std
|
|
Iterator it = input.iterator();
|
|
// Compute mean
|
|
while (it.hasNext()) {
|
|
i++;
|
|
sum += (double) it.next();
|
|
}
|
|
results[0] = i != 0 ? sum / (double) i : 0;
|
|
it = input.iterator();
|
|
sum = 0.0;
|
|
while (it.hasNext()) {
|
|
sum += Math.pow((double) it.next() - results[0], 2);
|
|
}
|
|
results[1] = i != 0 ? Math.sqrt(sum / (double) i) : 0;
|
|
return results;
|
|
}
|
|
|
|
private static void initializeOutput() {
|
|
try {
|
|
File f = new File(outputFile);
|
|
f.delete();
|
|
} catch (Exception e) {
|
|
printAndExit(
|
|
String.format("*** Error trying to delete the output file %s... (%s)", outputFile, e.getMessage()));
|
|
}
|
|
}
|
|
|
|
private static void store(String dataset, double[] accuracy, double[] timeSpent, double[] complexity)
|
|
throws Exception {
|
|
// Append output
|
|
FileWriter file = new FileWriter(outputFile, true);
|
|
PrintWriter linePrint = new PrintWriter(file);
|
|
linePrint.printf("%s; %f; %f; %f; %f; %f; %f%n", dataset, accuracy[0], accuracy[1], timeSpent[0], timeSpent[1],
|
|
complexity[0], complexity[1]);
|
|
linePrint.close();
|
|
}
|
|
|
|
public static void main(String[] args) throws Exception {
|
|
Instances instances = null;
|
|
Scanner sc;
|
|
File file = null;
|
|
Random gen;
|
|
double[] accuracyStat = new double[2];
|
|
double[] timeStat = new double[2];
|
|
double[] complexityStat = new double[2];
|
|
int[] seeds = { 57, 31, 1714, 17, 23, 79, 83, 97, 7, 1 };
|
|
String dataset;
|
|
ArrayList<Double> timeSpent;
|
|
ArrayList<Double> accuracy;
|
|
ArrayList<Double> complexity;
|
|
|
|
try {
|
|
file = new File(datasetsFile);
|
|
} catch (Exception e) {
|
|
printAndExit(String.format("*** Error trying to read datasets file... (%s)",
|
|
e.getMessage()));
|
|
}
|
|
System.out.println(StringUtils.center(String.format("%d fold cross validation stratified", numFolds), 64));
|
|
initializeOutput();
|
|
sc = new Scanner(file);
|
|
System.out.println(String.format(" # %-30s %-29s", "Dataset", "Seeds"));
|
|
System.out.println("--- " + "-".repeat(30) + " " + "-".repeat(29));
|
|
int number = 0;
|
|
while (sc.hasNextLine()) {
|
|
dataset = sc.nextLine();
|
|
System.out.print(String.format("%3d %-30s ", number++, dataset));
|
|
if (debug && dataset.equals("balloons")) {
|
|
printAndExit("* Check error output. Debug End.");
|
|
}
|
|
timeSpent = new ArrayList<>();
|
|
accuracy = new ArrayList<>();
|
|
complexity = new ArrayList<>();
|
|
for (int seed : seeds) {
|
|
System.out.print(String.format("%d ", seed));
|
|
// Establece la semilla
|
|
gen = new Random(seed);
|
|
// Obtiene los datos
|
|
instances = getInstances(dataset);
|
|
instances.randomize(gen);
|
|
instances.setClassIndex(instances.numAttributes() - 1);
|
|
try {
|
|
validation(instances, seed, accuracy, timeSpent, complexity);
|
|
} catch (Exception e) {
|
|
printAndExit(String.format("*** Error training dataset %s... (%s)", dataset,
|
|
e.getMessage()));
|
|
}
|
|
|
|
}
|
|
System.out.println("");
|
|
accuracyStat = meanAndDeviation(accuracy);
|
|
timeStat = meanAndDeviation(timeSpent);
|
|
complexityStat = meanAndDeviation(complexity);
|
|
try {
|
|
store(dataset, accuracyStat, timeStat, complexityStat);
|
|
} catch (Exception e) {
|
|
printAndExit(String.format("*** Error storing results of dataset %s... (%s)",
|
|
dataset, e.getMessage()));
|
|
}
|
|
}
|
|
sc.close();
|
|
printAndExit("* End.");
|
|
}
|
|
}
|