package com.testpbc4; import java.io.File; import java.io.FileWriter; import java.io.PrintWriter; import java.util.ArrayList; import java.util.Iterator; import java.util.Random; import java.util.Scanner; import org.apache.commons.lang3.StringUtils; import weka.classifiers.trees.PBC4cip; import weka.core.Instances; import weka.core.converters.ConverterUtils.DataSource; public class App { static int numFolds = 5; static String path = "data/"; static String datasetsFile = path + "datasets.txt"; static String outputFile = "results_pbc4.txt"; static boolean debug = false; public static void printAndExit(String s) { System.out.println(s); System.exit(0); } private static Instances getInstances(String fileName) { DataSource source = null; try { source = new DataSource(String.format(path + "%s.arff", fileName)); return source.getDataSet(); } catch (Exception e) { printAndExit( String.format("*** Error trying to read " + path + "%s.arff file... (%s)", fileName, e.getMessage())); } return null; } private static void validation(Instances instancesIn, int seed, ArrayList accuracy, ArrayList timeSpent, ArrayList complexity) throws Exception { PBC4cip pb4 = new PBC4cip(); Random rand = new Random(seed); pb4.setSeed(seed); Instances instances; Instances training = null; Instances test = null; instances = instancesIn; instances.randomize(rand); instances.stratify(numFolds); for (int i = 0; i < numFolds; i++) { long startTime = System.currentTimeMillis(); training = instances.trainCV(numFolds, i); test = instances.testCV(numFolds, i); pb4.buildClassifier(training); // let's classify int aciertos = 0; for (int j = 0; j < test.numInstances(); j++) { double v = pb4.classifyInstance(test.instance(j)); double trueValue = test.instance(j).value(test.classIndex()); if (trueValue == v) aciertos++; } double acc = aciertos / (double) test.numInstances(); double tspent = (System.currentTimeMillis() - startTime) / 1000.0; if (debug) { System.err.println(String.format("%f, %f", acc, tspent)); } accuracy.add(acc); timeSpent.add(tspent); complexity.add(0.0); } } private static double[] meanAndDeviation(ArrayList input) { double sum = 0.0; int i = 0; double[] results = new double[4]; // 0 -> mean, 1 -> std Iterator it = input.iterator(); // Compute mean while (it.hasNext()) { i++; sum += (double) it.next(); } results[0] = i != 0 ? sum / (double) i : 0; it = input.iterator(); sum = 0.0; while (it.hasNext()) { sum += Math.pow((double) it.next() - results[0], 2); } results[1] = i != 0 ? Math.sqrt(sum / (double) i) : 0; return results; } private static void initializeOutput() { try { File f = new File(outputFile); f.delete(); } catch (Exception e) { printAndExit( String.format("*** Error trying to delete the output file %s... (%s)", outputFile, e.getMessage())); } } private static void store(String dataset, double[] accuracy, double[] timeSpent, double[] complexity) throws Exception { // Append output FileWriter file = new FileWriter(outputFile, true); PrintWriter linePrint = new PrintWriter(file); linePrint.printf("%s; %f; %f; %f; %f; %f; %f%n", dataset, accuracy[0], accuracy[1], timeSpent[0], timeSpent[1], complexity[0], complexity[1]); linePrint.close(); } public static void main(String[] args) throws Exception { Instances instances = null; Scanner sc; File file = null; Random gen; double[] accuracyStat = new double[2]; double[] timeStat = new double[2]; double[] complexityStat = new double[2]; int[] seeds = { 57, 31, 1714, 17, 23, 79, 83, 97, 7, 1 }; String dataset; ArrayList timeSpent; ArrayList accuracy; ArrayList complexity; try { file = new File(datasetsFile); } catch (Exception e) { printAndExit(String.format("*** Error trying to read datasets file... (%s)", e.getMessage())); } System.out.println(StringUtils.center(String.format("%d fold cross validation stratified", numFolds), 64)); initializeOutput(); sc = new Scanner(file); System.out.println(String.format(" # %-30s %-29s", "Dataset", "Seeds")); System.out.println("--- " + "-".repeat(30) + " " + "-".repeat(29)); int number = 0; while (sc.hasNextLine()) { dataset = sc.nextLine(); System.out.print(String.format("%3d %-30s ", number++, dataset)); if (debug && dataset.equals("balloons")) { printAndExit("* Check error output. Debug End."); } timeSpent = new ArrayList<>(); accuracy = new ArrayList<>(); complexity = new ArrayList<>(); for (int seed : seeds) { System.out.print(String.format("%d ", seed)); // Establece la semilla gen = new Random(seed); // Obtiene los datos instances = getInstances(dataset); instances.randomize(gen); instances.setClassIndex(instances.numAttributes() - 1); try { validation(instances, seed, accuracy, timeSpent, complexity); } catch (Exception e) { printAndExit(String.format("*** Error training dataset %s... (%s)", dataset, e.getMessage())); } } System.out.println(""); accuracyStat = meanAndDeviation(accuracy); timeStat = meanAndDeviation(timeSpent); complexityStat = meanAndDeviation(complexity); try { store(dataset, accuracyStat, timeStat, complexityStat); } catch (Exception e) { printAndExit(String.format("*** Error storing results of dataset %s... (%s)", dataset, e.getMessage())); } } sc.close(); printAndExit("* End."); } }