First commit
This commit is contained in:
210
libsvm-3.36/tools/README
Normal file
210
libsvm-3.36/tools/README
Normal file
@@ -0,0 +1,210 @@
|
||||
This directory includes some useful codes:
|
||||
|
||||
1. subset selection tools.
|
||||
2. parameter selection tools.
|
||||
3. LIBSVM format checking tools
|
||||
|
||||
Part I: Subset selection tools
|
||||
|
||||
Introduction
|
||||
============
|
||||
|
||||
Training large data is time consuming. Sometimes one should work on a
|
||||
smaller subset first. The python script subset.py randomly selects a
|
||||
specified number of samples. For classification data, we provide a
|
||||
stratified selection to ensure the same class distribution in the
|
||||
subset.
|
||||
|
||||
Usage: subset.py [options] dataset number [output1] [output2]
|
||||
|
||||
This script selects a subset of the given data set.
|
||||
|
||||
options:
|
||||
-s method : method of selection (default 0)
|
||||
0 -- stratified selection (classification only)
|
||||
1 -- random selection
|
||||
|
||||
output1 : the subset (optional)
|
||||
output2 : the rest of data (optional)
|
||||
|
||||
If output1 is omitted, the subset will be printed on the screen.
|
||||
|
||||
Example
|
||||
=======
|
||||
|
||||
> python subset.py heart_scale 100 file1 file2
|
||||
|
||||
From heart_scale 100 samples are randomly selected and stored in
|
||||
file1. All remaining instances are stored in file2.
|
||||
|
||||
|
||||
Part II: Parameter Selection Tools
|
||||
|
||||
Introduction
|
||||
============
|
||||
|
||||
grid.py is a parameter selection tool for C-SVM classification using
|
||||
the RBF (radial basis function) kernel. It uses cross validation (CV)
|
||||
technique to estimate the accuracy of each parameter combination in
|
||||
the specified range and helps you to decide the best parameters for
|
||||
your problem.
|
||||
|
||||
grid.py directly executes libsvm binaries (so no python binding is needed)
|
||||
for cross validation and then draw contour of CV accuracy using gnuplot.
|
||||
You must have libsvm and gnuplot installed before using it. The package
|
||||
gnuplot is available at http://www.gnuplot.info/
|
||||
|
||||
On Mac OSX, the precompiled gnuplot file needs the library Aquarterm,
|
||||
which thus must be installed as well. In addition, this version of
|
||||
gnuplot does not support png, so you need to change "set term png
|
||||
transparent small" and use other image formats. For example, you may
|
||||
have "set term pbm small color".
|
||||
|
||||
Usage: grid.py [grid_options] [svm_options] dataset
|
||||
|
||||
grid_options :
|
||||
-log2c {begin,end,step | "null"} : set the range of c (default -5,15,2)
|
||||
begin,end,step -- c_range = 2^{begin,...,begin+k*step,...,end}
|
||||
"null" -- do not grid with c
|
||||
-log2g {begin,end,step | "null"} : set the range of g (default 3,-15,-2)
|
||||
begin,end,step -- g_range = 2^{begin,...,begin+k*step,...,end}
|
||||
"null" -- do not grid with g
|
||||
-v n : n-fold cross validation (default 5)
|
||||
-svmtrain pathname : set svm executable path and name
|
||||
-gnuplot {pathname | "null"} :
|
||||
pathname -- set gnuplot executable path and name
|
||||
"null" -- do not plot
|
||||
-out {pathname | "null"} : (default dataset.out)
|
||||
pathname -- set output file path and name
|
||||
"null" -- do not output file
|
||||
-png pathname : set graphic output file path and name (default dataset.png)
|
||||
-resume [pathname] : resume the grid task using an existing output file (default pathname is dataset.out)
|
||||
Use this option only if some parameters have been checked for the SAME data.
|
||||
|
||||
svm_options : additional options for svm-train
|
||||
|
||||
The program conducts v-fold cross validation using parameter C (and gamma)
|
||||
= 2^begin, 2^(begin+step), ..., 2^end.
|
||||
|
||||
You can specify where the libsvm executable and gnuplot are using the
|
||||
-svmtrain and -gnuplot parameters.
|
||||
|
||||
For windows users, please use pgnuplot.exe. If you are using gnuplot
|
||||
3.7.1, please upgrade to version 3.7.3 or higher. The version 3.7.1
|
||||
has a bug. If you use cygwin on windows, please use gunplot-x11.
|
||||
|
||||
If the task is terminated accidentally or you would like to change the
|
||||
range of parameters, you can apply '-resume' to save time by re-using
|
||||
previous results. You may specify the output file of a previous run
|
||||
or use the default (i.e., dataset.out) without giving a name. Please
|
||||
note that the same condition must be used in two runs. For example,
|
||||
you cannot use '-v 10' earlier and resume the task with '-v 5'.
|
||||
|
||||
The value of some options can be "null." For example, `-log2c -1,0,1
|
||||
-log2 "null"' means that C=2^-1,2^0,2^1 and g=LIBSVM's default gamma
|
||||
value. That is, you do not conduct parameter selection on gamma.
|
||||
|
||||
Example
|
||||
=======
|
||||
|
||||
> python grid.py -log2c -5,5,1 -log2g -4,0,1 -v 5 -m 300 heart_scale
|
||||
|
||||
Users (in particular MS Windows users) may need to specify the path of
|
||||
executable files. You can either change paths in the beginning of
|
||||
grid.py or specify them in the command line. For example,
|
||||
|
||||
> grid.py -log2c -5,5,1 -svmtrain "c:\Program Files\libsvm\windows\svm-train.exe" -gnuplot c:\tmp\gnuplot\binary\pgnuplot.exe -v 10 heart_scale
|
||||
|
||||
Output: two files
|
||||
dataset.png: the CV accuracy contour plot generated by gnuplot
|
||||
dataset.out: the CV accuracy at each (log2(C),log2(gamma))
|
||||
|
||||
The following example saves running time by loading the output file of a previous run.
|
||||
|
||||
> python grid.py -log2c -7,7,1 -log2g -5,2,1 -v 5 -resume heart_scale.out heart_scale
|
||||
|
||||
Parallel grid search
|
||||
====================
|
||||
|
||||
You can conduct a parallel grid search by dispatching jobs to a
|
||||
cluster of computers which share the same file system. First, you add
|
||||
machine names in grid.py:
|
||||
|
||||
ssh_workers = ["linux1", "linux5", "linux5"]
|
||||
|
||||
and then setup your ssh so that the authentication works without
|
||||
asking a password.
|
||||
|
||||
The same machine (e.g., linux5 here) can be listed more than once if
|
||||
it has multiple CPUs or has more RAM. If the local machine is the
|
||||
best, you can also enlarge the nr_local_worker. For example:
|
||||
|
||||
nr_local_worker = 2
|
||||
|
||||
Example:
|
||||
|
||||
> python grid.py heart_scale
|
||||
[local] -1 -1 78.8889 (best c=0.5, g=0.5, rate=78.8889)
|
||||
[linux5] -1 -7 83.3333 (best c=0.5, g=0.0078125, rate=83.3333)
|
||||
[linux5] 5 -1 77.037 (best c=0.5, g=0.0078125, rate=83.3333)
|
||||
[linux1] 5 -7 83.3333 (best c=0.5, g=0.0078125, rate=83.3333)
|
||||
.
|
||||
.
|
||||
.
|
||||
|
||||
If -log2c, -log2g, or -v is not specified, default values are used.
|
||||
|
||||
If your system uses telnet instead of ssh, you list the computer names
|
||||
in telnet_workers.
|
||||
|
||||
Calling grid in Python
|
||||
======================
|
||||
|
||||
In addition to using grid.py as a command-line tool, you can use it as a
|
||||
Python module.
|
||||
|
||||
>>> rate, param = find_parameters(dataset, options)
|
||||
|
||||
You need to specify `dataset' and `options' (default ''). See the following example.
|
||||
|
||||
> python
|
||||
|
||||
>>> from grid import *
|
||||
>>> rate, param = find_parameters('../heart_scale', '-log2c -1,1,1 -log2g -1,1,1')
|
||||
[local] 0.0 0.0 rate=74.8148 (best c=1.0, g=1.0, rate=74.8148)
|
||||
[local] 0.0 -1.0 rate=77.037 (best c=1.0, g=0.5, rate=77.037)
|
||||
.
|
||||
.
|
||||
[local] -1.0 -1.0 rate=78.8889 (best c=0.5, g=0.5, rate=78.8889)
|
||||
.
|
||||
.
|
||||
>>> rate
|
||||
78.8889
|
||||
>>> param
|
||||
{'c': 0.5, 'g': 0.5}
|
||||
|
||||
|
||||
Part III: LIBSVM format checking tools
|
||||
|
||||
Introduction
|
||||
============
|
||||
|
||||
`svm-train' conducts only a simple check of the input data. To do a
|
||||
detailed check, we provide a python script `checkdata.py.'
|
||||
|
||||
Usage: checkdata.py dataset
|
||||
|
||||
Exit status (returned value): 1 if there are errors, 0 otherwise.
|
||||
|
||||
This tool is written by Rong-En Fan at National Taiwan University.
|
||||
|
||||
Example
|
||||
=======
|
||||
|
||||
> cat bad_data
|
||||
1 3:1 2:4
|
||||
> python checkdata.py bad_data
|
||||
line 1: feature indices must be in an ascending order, previous/current features 3:1 2:4
|
||||
Found 1 lines with error.
|
||||
|
||||
|
108
libsvm-3.36/tools/checkdata.py
Executable file
108
libsvm-3.36/tools/checkdata.py
Executable file
@@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
#
|
||||
# A format checker for LIBSVM
|
||||
#
|
||||
|
||||
#
|
||||
# Copyright (c) 2007, Rong-En Fan
|
||||
#
|
||||
# All rights reserved.
|
||||
#
|
||||
# This program is distributed under the same license of the LIBSVM package.
|
||||
#
|
||||
|
||||
from sys import argv, exit
|
||||
import os.path
|
||||
|
||||
def err(line_no, msg):
|
||||
print("line {0}: {1}".format(line_no, msg))
|
||||
|
||||
# works like float() but does not accept nan and inf
|
||||
def my_float(x):
|
||||
if x.lower().find("nan") != -1 or x.lower().find("inf") != -1:
|
||||
raise ValueError
|
||||
|
||||
return float(x)
|
||||
|
||||
def main():
|
||||
if len(argv) != 2:
|
||||
print("Usage: {0} dataset".format(argv[0]))
|
||||
exit(1)
|
||||
|
||||
dataset = argv[1]
|
||||
|
||||
if not os.path.exists(dataset):
|
||||
print("dataset {0} not found".format(dataset))
|
||||
exit(1)
|
||||
|
||||
line_no = 1
|
||||
error_line_count = 0
|
||||
for line in open(dataset, 'r'):
|
||||
line_error = False
|
||||
|
||||
# each line must end with a newline character
|
||||
if line[-1] != '\n':
|
||||
err(line_no, "missing a newline character in the end")
|
||||
line_error = True
|
||||
|
||||
nodes = line.split()
|
||||
|
||||
# check label
|
||||
try:
|
||||
label = nodes.pop(0)
|
||||
|
||||
if label.find(',') != -1:
|
||||
# multi-label format
|
||||
try:
|
||||
for l in label.split(','):
|
||||
l = my_float(l)
|
||||
except:
|
||||
err(line_no, "label {0} is not a valid multi-label form".format(label))
|
||||
line_error = True
|
||||
else:
|
||||
try:
|
||||
label = my_float(label)
|
||||
except:
|
||||
err(line_no, "label {0} is not a number".format(label))
|
||||
line_error = True
|
||||
except:
|
||||
err(line_no, "missing label, perhaps an empty line?")
|
||||
line_error = True
|
||||
|
||||
# check features
|
||||
prev_index = -1
|
||||
for i in range(len(nodes)):
|
||||
try:
|
||||
(index, value) = nodes[i].split(':')
|
||||
|
||||
index = int(index)
|
||||
value = my_float(value)
|
||||
|
||||
# precomputed kernel's index starts from 0 and LIBSVM
|
||||
# checks it. Hence, don't treat index 0 as an error.
|
||||
if index < 0:
|
||||
err(line_no, "feature index must be positive; wrong feature {0}".format(nodes[i]))
|
||||
line_error = True
|
||||
elif index <= prev_index:
|
||||
err(line_no, "feature indices must be in an ascending order, previous/current features {0} {1}".format(nodes[i-1], nodes[i]))
|
||||
line_error = True
|
||||
prev_index = index
|
||||
except:
|
||||
err(line_no, "feature '{0}' not an <index>:<value> pair, <index> integer, <value> real number ".format(nodes[i]))
|
||||
line_error = True
|
||||
|
||||
line_no += 1
|
||||
|
||||
if line_error:
|
||||
error_line_count += 1
|
||||
|
||||
if error_line_count > 0:
|
||||
print("Found {0} lines with error.".format(error_line_count))
|
||||
return 1
|
||||
else:
|
||||
print("No error.")
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
79
libsvm-3.36/tools/easy.py
Executable file
79
libsvm-3.36/tools/easy.py
Executable file
@@ -0,0 +1,79 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import sys
|
||||
import os
|
||||
from subprocess import *
|
||||
|
||||
if len(sys.argv) <= 1:
|
||||
print('Usage: {0} training_file [testing_file]'.format(sys.argv[0]))
|
||||
raise SystemExit
|
||||
|
||||
# svm, grid, and gnuplot executable files
|
||||
|
||||
is_win32 = (sys.platform == 'win32')
|
||||
if not is_win32:
|
||||
svmscale_exe = "../svm-scale"
|
||||
svmtrain_exe = "../svm-train"
|
||||
svmpredict_exe = "../svm-predict"
|
||||
grid_py = "./grid.py"
|
||||
gnuplot_exe = "/usr/bin/gnuplot"
|
||||
else:
|
||||
# example for windows
|
||||
svmscale_exe = r"..\windows\svm-scale.exe"
|
||||
svmtrain_exe = r"..\windows\svm-train.exe"
|
||||
svmpredict_exe = r"..\windows\svm-predict.exe"
|
||||
gnuplot_exe = r"c:\tmp\gnuplot\binary\pgnuplot.exe"
|
||||
grid_py = r".\grid.py"
|
||||
|
||||
assert os.path.exists(svmscale_exe),"svm-scale executable not found"
|
||||
assert os.path.exists(svmtrain_exe),"svm-train executable not found"
|
||||
assert os.path.exists(svmpredict_exe),"svm-predict executable not found"
|
||||
assert os.path.exists(gnuplot_exe),"gnuplot executable not found"
|
||||
assert os.path.exists(grid_py),"grid.py not found"
|
||||
|
||||
train_pathname = sys.argv[1]
|
||||
assert os.path.exists(train_pathname),"training file not found"
|
||||
file_name = os.path.split(train_pathname)[1]
|
||||
scaled_file = file_name + ".scale"
|
||||
model_file = file_name + ".model"
|
||||
range_file = file_name + ".range"
|
||||
|
||||
if len(sys.argv) > 2:
|
||||
test_pathname = sys.argv[2]
|
||||
file_name = os.path.split(test_pathname)[1]
|
||||
assert os.path.exists(test_pathname),"testing file not found"
|
||||
scaled_test_file = file_name + ".scale"
|
||||
predict_test_file = file_name + ".predict"
|
||||
|
||||
cmd = '{0} -s "{1}" "{2}" > "{3}"'.format(svmscale_exe, range_file, train_pathname, scaled_file)
|
||||
print('Scaling training data...')
|
||||
Popen(cmd, shell = True, stdout = PIPE).communicate()
|
||||
|
||||
cmd = '{0} -svmtrain "{1}" -gnuplot "{2}" "{3}"'.format(grid_py, svmtrain_exe, gnuplot_exe, scaled_file)
|
||||
print('Cross validation...')
|
||||
f = Popen(cmd, shell = True, stdout = PIPE).stdout
|
||||
|
||||
line = ''
|
||||
while True:
|
||||
last_line = line
|
||||
line = f.readline()
|
||||
if not line: break
|
||||
c,g,rate = map(float,last_line.split())
|
||||
|
||||
print('Best c={0}, g={1} CV rate={2}'.format(c,g,rate))
|
||||
|
||||
cmd = '{0} -c {1} -g {2} "{3}" "{4}"'.format(svmtrain_exe,c,g,scaled_file,model_file)
|
||||
print('Training...')
|
||||
Popen(cmd, shell = True, stdout = PIPE).communicate()
|
||||
|
||||
print('Output model: {0}'.format(model_file))
|
||||
if len(sys.argv) > 2:
|
||||
cmd = '{0} -r "{1}" "{2}" > "{3}"'.format(svmscale_exe, range_file, test_pathname, scaled_test_file)
|
||||
print('Scaling testing data...')
|
||||
Popen(cmd, shell = True, stdout = PIPE).communicate()
|
||||
|
||||
cmd = '{0} "{1}" "{2}" "{3}"'.format(svmpredict_exe, scaled_test_file, model_file, predict_test_file)
|
||||
print('Testing...')
|
||||
Popen(cmd, shell = True).communicate()
|
||||
|
||||
print('Output prediction: {0}'.format(predict_test_file))
|
500
libsvm-3.36/tools/grid.py
Executable file
500
libsvm-3.36/tools/grid.py
Executable file
@@ -0,0 +1,500 @@
|
||||
#!/usr/bin/env python
|
||||
__all__ = ['find_parameters']
|
||||
|
||||
import os, sys, traceback, getpass, time, re
|
||||
from threading import Thread
|
||||
from subprocess import *
|
||||
|
||||
if sys.version_info[0] < 3:
|
||||
from Queue import Queue
|
||||
else:
|
||||
from queue import Queue
|
||||
|
||||
telnet_workers = []
|
||||
ssh_workers = []
|
||||
nr_local_worker = 1
|
||||
|
||||
class GridOption:
|
||||
def __init__(self, dataset_pathname, options):
|
||||
dirname = os.path.dirname(__file__)
|
||||
if sys.platform != 'win32':
|
||||
self.svmtrain_pathname = os.path.join(dirname, '../svm-train')
|
||||
self.gnuplot_pathname = '/usr/bin/gnuplot'
|
||||
else:
|
||||
# example for windows
|
||||
self.svmtrain_pathname = os.path.join(dirname, r'..\windows\svm-train.exe')
|
||||
# svmtrain_pathname = r'c:\Program Files\libsvm\windows\svm-train.exe'
|
||||
self.gnuplot_pathname = r'c:\tmp\gnuplot\binary\pgnuplot.exe'
|
||||
self.fold = 5
|
||||
self.c_begin, self.c_end, self.c_step = -5, 15, 2
|
||||
self.g_begin, self.g_end, self.g_step = 3, -15, -2
|
||||
self.grid_with_c, self.grid_with_g = True, True
|
||||
self.dataset_pathname = dataset_pathname
|
||||
self.dataset_title = os.path.split(dataset_pathname)[1]
|
||||
self.out_pathname = '{0}.out'.format(self.dataset_title)
|
||||
self.png_pathname = '{0}.png'.format(self.dataset_title)
|
||||
self.pass_through_string = ' '
|
||||
self.resume_pathname = None
|
||||
self.parse_options(options)
|
||||
|
||||
def parse_options(self, options):
|
||||
if type(options) == str:
|
||||
options = options.split()
|
||||
i = 0
|
||||
pass_through_options = []
|
||||
|
||||
while i < len(options):
|
||||
if options[i] == '-log2c':
|
||||
i = i + 1
|
||||
if options[i] == 'null':
|
||||
self.grid_with_c = False
|
||||
else:
|
||||
self.c_begin, self.c_end, self.c_step = map(float,options[i].split(','))
|
||||
elif options[i] == '-log2g':
|
||||
i = i + 1
|
||||
if options[i] == 'null':
|
||||
self.grid_with_g = False
|
||||
else:
|
||||
self.g_begin, self.g_end, self.g_step = map(float,options[i].split(','))
|
||||
elif options[i] == '-v':
|
||||
i = i + 1
|
||||
self.fold = options[i]
|
||||
elif options[i] in ('-c','-g'):
|
||||
raise ValueError('Use -log2c and -log2g.')
|
||||
elif options[i] == '-svmtrain':
|
||||
i = i + 1
|
||||
self.svmtrain_pathname = options[i]
|
||||
elif options[i] == '-gnuplot':
|
||||
i = i + 1
|
||||
if options[i] == 'null':
|
||||
self.gnuplot_pathname = None
|
||||
else:
|
||||
self.gnuplot_pathname = options[i]
|
||||
elif options[i] == '-out':
|
||||
i = i + 1
|
||||
if options[i] == 'null':
|
||||
self.out_pathname = None
|
||||
else:
|
||||
self.out_pathname = options[i]
|
||||
elif options[i] == '-png':
|
||||
i = i + 1
|
||||
self.png_pathname = options[i]
|
||||
elif options[i] == '-resume':
|
||||
if i == (len(options)-1) or options[i+1].startswith('-'):
|
||||
self.resume_pathname = self.dataset_title + '.out'
|
||||
else:
|
||||
i = i + 1
|
||||
self.resume_pathname = options[i]
|
||||
else:
|
||||
pass_through_options.append(options[i])
|
||||
i = i + 1
|
||||
|
||||
self.pass_through_string = ' '.join(pass_through_options)
|
||||
if not os.path.exists(self.svmtrain_pathname):
|
||||
raise IOError('svm-train executable not found')
|
||||
if not os.path.exists(self.dataset_pathname):
|
||||
raise IOError('dataset not found')
|
||||
if self.resume_pathname and not os.path.exists(self.resume_pathname):
|
||||
raise IOError('file for resumption not found')
|
||||
if not self.grid_with_c and not self.grid_with_g:
|
||||
raise ValueError('-log2c and -log2g should not be null simultaneously')
|
||||
if self.gnuplot_pathname and not os.path.exists(self.gnuplot_pathname):
|
||||
sys.stderr.write('gnuplot executable not found\n')
|
||||
self.gnuplot_pathname = None
|
||||
|
||||
def redraw(db,best_param,gnuplot,options,tofile=False):
|
||||
if len(db) == 0: return
|
||||
begin_level = round(max(x[2] for x in db)) - 3
|
||||
step_size = 0.5
|
||||
|
||||
best_log2c,best_log2g,best_rate = best_param
|
||||
|
||||
# if newly obtained c, g, or cv values are the same,
|
||||
# then stop redrawing the contour.
|
||||
if all(x[0] == db[0][0] for x in db): return
|
||||
if all(x[1] == db[0][1] for x in db): return
|
||||
if all(x[2] == db[0][2] for x in db): return
|
||||
|
||||
if tofile:
|
||||
gnuplot.write(b"set term png transparent small linewidth 2 medium enhanced\n")
|
||||
gnuplot.write("set output \"{0}\"\n".format(options.png_pathname.replace('\\','\\\\')).encode())
|
||||
#gnuplot.write(b"set term postscript color solid\n")
|
||||
#gnuplot.write("set output \"{0}.ps\"\n".format(options.dataset_title).encode().encode())
|
||||
elif sys.platform == 'win32':
|
||||
gnuplot.write(b"set term windows\n")
|
||||
else:
|
||||
gnuplot.write( b"set term x11\n")
|
||||
gnuplot.write(b"set xlabel \"log2(C)\"\n")
|
||||
gnuplot.write(b"set ylabel \"log2(gamma)\"\n")
|
||||
gnuplot.write("set xrange [{0}:{1}]\n".format(options.c_begin,options.c_end).encode())
|
||||
gnuplot.write("set yrange [{0}:{1}]\n".format(options.g_begin,options.g_end).encode())
|
||||
gnuplot.write(b"set contour\n")
|
||||
gnuplot.write("set cntrparam levels incremental {0},{1},100\n".format(begin_level,step_size).encode())
|
||||
gnuplot.write(b"unset surface\n")
|
||||
gnuplot.write(b"unset ztics\n")
|
||||
gnuplot.write(b"set view 0,0\n")
|
||||
gnuplot.write("set title \"{0}\"\n".format(options.dataset_title).encode())
|
||||
gnuplot.write(b"unset label\n")
|
||||
gnuplot.write("set label \"Best log2(C) = {0} log2(gamma) = {1} accuracy = {2}%\" \
|
||||
at screen 0.5,0.85 center\n". \
|
||||
format(best_log2c, best_log2g, best_rate).encode())
|
||||
gnuplot.write("set label \"C = {0} gamma = {1}\""
|
||||
" at screen 0.5,0.8 center\n".format(2**best_log2c, 2**best_log2g).encode())
|
||||
gnuplot.write(b"set key at screen 0.9,0.9\n")
|
||||
gnuplot.write(b"splot \"-\" with lines\n")
|
||||
|
||||
db.sort(key = lambda x:(x[0], -x[1]))
|
||||
|
||||
prevc = db[0][0]
|
||||
for line in db:
|
||||
if prevc != line[0]:
|
||||
gnuplot.write(b"\n")
|
||||
prevc = line[0]
|
||||
gnuplot.write("{0[0]} {0[1]} {0[2]}\n".format(line).encode())
|
||||
gnuplot.write(b"e\n")
|
||||
gnuplot.write(b"\n") # force gnuplot back to prompt when term set failure
|
||||
gnuplot.flush()
|
||||
|
||||
|
||||
def calculate_jobs(options):
|
||||
|
||||
def range_f(begin,end,step):
|
||||
# like range, but works on non-integer too
|
||||
seq = []
|
||||
while True:
|
||||
if step > 0 and begin > end: break
|
||||
if step < 0 and begin < end: break
|
||||
seq.append(begin)
|
||||
begin = begin + step
|
||||
return seq
|
||||
|
||||
def permute_sequence(seq):
|
||||
n = len(seq)
|
||||
if n <= 1: return seq
|
||||
|
||||
mid = int(n/2)
|
||||
left = permute_sequence(seq[:mid])
|
||||
right = permute_sequence(seq[mid+1:])
|
||||
|
||||
ret = [seq[mid]]
|
||||
while left or right:
|
||||
if left: ret.append(left.pop(0))
|
||||
if right: ret.append(right.pop(0))
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
c_seq = permute_sequence(range_f(options.c_begin,options.c_end,options.c_step))
|
||||
g_seq = permute_sequence(range_f(options.g_begin,options.g_end,options.g_step))
|
||||
|
||||
if not options.grid_with_c:
|
||||
c_seq = [None]
|
||||
if not options.grid_with_g:
|
||||
g_seq = [None]
|
||||
|
||||
nr_c = float(len(c_seq))
|
||||
nr_g = float(len(g_seq))
|
||||
i, j = 0, 0
|
||||
jobs = []
|
||||
|
||||
while i < nr_c or j < nr_g:
|
||||
if i/nr_c < j/nr_g:
|
||||
# increase C resolution
|
||||
line = []
|
||||
for k in range(0,j):
|
||||
line.append((c_seq[i],g_seq[k]))
|
||||
i = i + 1
|
||||
jobs.append(line)
|
||||
else:
|
||||
# increase g resolution
|
||||
line = []
|
||||
for k in range(0,i):
|
||||
line.append((c_seq[k],g_seq[j]))
|
||||
j = j + 1
|
||||
jobs.append(line)
|
||||
|
||||
resumed_jobs = {}
|
||||
|
||||
if options.resume_pathname is None:
|
||||
return jobs, resumed_jobs
|
||||
|
||||
for line in open(options.resume_pathname, 'r'):
|
||||
line = line.strip()
|
||||
rst = re.findall(r'rate=([0-9.]+)',line)
|
||||
if not rst:
|
||||
continue
|
||||
rate = float(rst[0])
|
||||
|
||||
c, g = None, None
|
||||
rst = re.findall(r'log2c=([0-9.-]+)',line)
|
||||
if rst:
|
||||
c = float(rst[0])
|
||||
rst = re.findall(r'log2g=([0-9.-]+)',line)
|
||||
if rst:
|
||||
g = float(rst[0])
|
||||
|
||||
resumed_jobs[(c,g)] = rate
|
||||
|
||||
return jobs, resumed_jobs
|
||||
|
||||
|
||||
class WorkerStopToken: # used to notify the worker to stop or if a worker is dead
|
||||
pass
|
||||
|
||||
class Worker(Thread):
|
||||
def __init__(self,name,job_queue,result_queue,options):
|
||||
Thread.__init__(self)
|
||||
self.name = name
|
||||
self.job_queue = job_queue
|
||||
self.result_queue = result_queue
|
||||
self.options = options
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
(cexp,gexp) = self.job_queue.get()
|
||||
if cexp is WorkerStopToken:
|
||||
self.job_queue.put((cexp,gexp))
|
||||
# print('worker {0} stop.'.format(self.name))
|
||||
break
|
||||
try:
|
||||
c, g = None, None
|
||||
if cexp != None:
|
||||
c = 2.0**cexp
|
||||
if gexp != None:
|
||||
g = 2.0**gexp
|
||||
rate = self.run_one(c,g)
|
||||
if rate is None: raise RuntimeError('get no rate')
|
||||
except:
|
||||
# we failed, let others do that and we just quit
|
||||
|
||||
traceback.print_exception(sys.exc_info()[0], sys.exc_info()[1], sys.exc_info()[2])
|
||||
|
||||
self.job_queue.put((cexp,gexp))
|
||||
sys.stderr.write('worker {0} quit.\n'.format(self.name))
|
||||
break
|
||||
else:
|
||||
self.result_queue.put((self.name,cexp,gexp,rate))
|
||||
|
||||
def get_cmd(self,c,g):
|
||||
options=self.options
|
||||
cmdline = '"' + options.svmtrain_pathname + '"'
|
||||
if options.grid_with_c:
|
||||
cmdline += ' -c {0} '.format(c)
|
||||
if options.grid_with_g:
|
||||
cmdline += ' -g {0} '.format(g)
|
||||
cmdline += ' -v {0} {1} {2} '.format\
|
||||
(options.fold,options.pass_through_string,'"' + options.dataset_pathname + '"')
|
||||
return cmdline
|
||||
|
||||
class LocalWorker(Worker):
|
||||
def run_one(self,c,g):
|
||||
cmdline = self.get_cmd(c,g)
|
||||
result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
|
||||
for line in result.readlines():
|
||||
if str(line).find('Cross') != -1:
|
||||
return float(line.split()[-1][0:-1])
|
||||
|
||||
class SSHWorker(Worker):
|
||||
def __init__(self,name,job_queue,result_queue,host,options):
|
||||
Worker.__init__(self,name,job_queue,result_queue,options)
|
||||
self.host = host
|
||||
self.cwd = os.getcwd()
|
||||
def run_one(self,c,g):
|
||||
cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"'.format\
|
||||
(self.host,self.cwd,self.get_cmd(c,g))
|
||||
result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
|
||||
for line in result.readlines():
|
||||
if str(line).find('Cross') != -1:
|
||||
return float(line.split()[-1][0:-1])
|
||||
|
||||
class TelnetWorker(Worker):
|
||||
def __init__(self,name,job_queue,result_queue,host,username,password,options):
|
||||
Worker.__init__(self,name,job_queue,result_queue,options)
|
||||
self.host = host
|
||||
self.username = username
|
||||
self.password = password
|
||||
def run(self):
|
||||
import telnetlib
|
||||
self.tn = tn = telnetlib.Telnet(self.host)
|
||||
tn.read_until('login: ')
|
||||
tn.write(self.username + '\n')
|
||||
tn.read_until('Password: ')
|
||||
tn.write(self.password + '\n')
|
||||
|
||||
# XXX: how to know whether login is successful?
|
||||
tn.read_until(self.username)
|
||||
#
|
||||
print('login ok', self.host)
|
||||
tn.write('cd '+os.getcwd()+'\n')
|
||||
Worker.run(self)
|
||||
tn.write('exit\n')
|
||||
def run_one(self,c,g):
|
||||
cmdline = self.get_cmd(c,g)
|
||||
result = self.tn.write(cmdline+'\n')
|
||||
(idx,matchm,output) = self.tn.expect(['Cross.*\n'])
|
||||
for line in output.split('\n'):
|
||||
if str(line).find('Cross') != -1:
|
||||
return float(line.split()[-1][0:-1])
|
||||
|
||||
def find_parameters(dataset_pathname, options=''):
|
||||
|
||||
def update_param(c,g,rate,best_c,best_g,best_rate,worker,resumed):
|
||||
if (rate > best_rate) or (rate==best_rate and g==best_g and c<best_c):
|
||||
best_rate,best_c,best_g = rate,c,g
|
||||
stdout_str = '[{0}] {1} {2} (best '.format\
|
||||
(worker,' '.join(str(x) for x in [c,g] if x is not None),rate)
|
||||
output_str = ''
|
||||
if c != None:
|
||||
stdout_str += 'c={0}, '.format(2.0**best_c)
|
||||
output_str += 'log2c={0} '.format(c)
|
||||
if g != None:
|
||||
stdout_str += 'g={0}, '.format(2.0**best_g)
|
||||
output_str += 'log2g={0} '.format(g)
|
||||
stdout_str += 'rate={0})'.format(best_rate)
|
||||
print(stdout_str)
|
||||
if options.out_pathname and not resumed:
|
||||
output_str += 'rate={0}\n'.format(rate)
|
||||
result_file.write(output_str)
|
||||
result_file.flush()
|
||||
|
||||
return best_c,best_g,best_rate
|
||||
|
||||
options = GridOption(dataset_pathname, options);
|
||||
|
||||
if options.gnuplot_pathname:
|
||||
gnuplot = Popen(options.gnuplot_pathname,stdin = PIPE,stdout=PIPE,stderr=PIPE).stdin
|
||||
else:
|
||||
gnuplot = None
|
||||
|
||||
# put jobs in queue
|
||||
|
||||
jobs,resumed_jobs = calculate_jobs(options)
|
||||
job_queue = Queue(0)
|
||||
result_queue = Queue(0)
|
||||
|
||||
for (c,g) in resumed_jobs:
|
||||
result_queue.put(('resumed',c,g,resumed_jobs[(c,g)]))
|
||||
|
||||
for line in jobs:
|
||||
for (c,g) in line:
|
||||
if (c,g) not in resumed_jobs:
|
||||
job_queue.put((c,g))
|
||||
|
||||
# hack the queue to become a stack --
|
||||
# this is important when some thread
|
||||
# failed and re-put a job. It we still
|
||||
# use FIFO, the job will be put
|
||||
# into the end of the queue, and the graph
|
||||
# will only be updated in the end
|
||||
|
||||
job_queue._put = job_queue.queue.appendleft
|
||||
|
||||
# fire telnet workers
|
||||
|
||||
if telnet_workers:
|
||||
nr_telnet_worker = len(telnet_workers)
|
||||
username = getpass.getuser()
|
||||
password = getpass.getpass()
|
||||
for host in telnet_workers:
|
||||
worker = TelnetWorker(host,job_queue,result_queue,
|
||||
host,username,password,options)
|
||||
worker.start()
|
||||
|
||||
# fire ssh workers
|
||||
|
||||
if ssh_workers:
|
||||
for host in ssh_workers:
|
||||
worker = SSHWorker(host,job_queue,result_queue,host,options)
|
||||
worker.start()
|
||||
|
||||
# fire local workers
|
||||
|
||||
for i in range(nr_local_worker):
|
||||
worker = LocalWorker('local',job_queue,result_queue,options)
|
||||
worker.start()
|
||||
|
||||
# gather results
|
||||
|
||||
done_jobs = {}
|
||||
|
||||
if options.out_pathname:
|
||||
if options.resume_pathname:
|
||||
result_file = open(options.out_pathname, 'a')
|
||||
else:
|
||||
result_file = open(options.out_pathname, 'w')
|
||||
|
||||
|
||||
db = []
|
||||
best_rate = -1
|
||||
best_c,best_g = None,None
|
||||
|
||||
for (c,g) in resumed_jobs:
|
||||
rate = resumed_jobs[(c,g)]
|
||||
best_c,best_g,best_rate = update_param(c,g,rate,best_c,best_g,best_rate,'resumed',True)
|
||||
|
||||
for line in jobs:
|
||||
for (c,g) in line:
|
||||
while (c,g) not in done_jobs:
|
||||
(worker,c1,g1,rate1) = result_queue.get()
|
||||
done_jobs[(c1,g1)] = rate1
|
||||
if (c1,g1) not in resumed_jobs:
|
||||
best_c,best_g,best_rate = update_param(c1,g1,rate1,best_c,best_g,best_rate,worker,False)
|
||||
db.append((c,g,done_jobs[(c,g)]))
|
||||
if gnuplot and options.grid_with_c and options.grid_with_g:
|
||||
redraw(db,[best_c, best_g, best_rate],gnuplot,options)
|
||||
redraw(db,[best_c, best_g, best_rate],gnuplot,options,True)
|
||||
|
||||
|
||||
if options.out_pathname:
|
||||
result_file.close()
|
||||
job_queue.put((WorkerStopToken,None))
|
||||
best_param, best_cg = {}, []
|
||||
if best_c != None:
|
||||
best_param['c'] = 2.0**best_c
|
||||
best_cg += [2.0**best_c]
|
||||
if best_g != None:
|
||||
best_param['g'] = 2.0**best_g
|
||||
best_cg += [2.0**best_g]
|
||||
print('{0} {1}'.format(' '.join(map(str,best_cg)), best_rate))
|
||||
|
||||
return best_rate, best_param
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
def exit_with_help():
|
||||
print("""\
|
||||
Usage: grid.py [grid_options] [svm_options] dataset
|
||||
|
||||
grid_options :
|
||||
-log2c {begin,end,step | "null"} : set the range of c (default -5,15,2)
|
||||
begin,end,step -- c_range = 2^{begin,...,begin+k*step,...,end}
|
||||
"null" -- do not grid with c
|
||||
-log2g {begin,end,step | "null"} : set the range of g (default 3,-15,-2)
|
||||
begin,end,step -- g_range = 2^{begin,...,begin+k*step,...,end}
|
||||
"null" -- do not grid with g
|
||||
-v n : n-fold cross validation (default 5)
|
||||
-svmtrain pathname : set svm executable path and name
|
||||
-gnuplot {pathname | "null"} :
|
||||
pathname -- set gnuplot executable path and name
|
||||
"null" -- do not plot
|
||||
-out {pathname | "null"} : (default dataset.out)
|
||||
pathname -- set output file path and name
|
||||
"null" -- do not output file
|
||||
-png pathname : set graphic output file path and name (default dataset.png)
|
||||
-resume [pathname] : resume the grid task using an existing output file (default pathname is dataset.out)
|
||||
This is experimental. Try this option only if some parameters have been checked for the SAME data.
|
||||
|
||||
svm_options : additional options for svm-train""")
|
||||
sys.exit(1)
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
exit_with_help()
|
||||
dataset_pathname = sys.argv[-1]
|
||||
options = sys.argv[1:-1]
|
||||
try:
|
||||
find_parameters(dataset_pathname, options)
|
||||
except (IOError,ValueError) as e:
|
||||
sys.stderr.write(str(e) + '\n')
|
||||
sys.stderr.write('Try "grid.py" for more information.\n')
|
||||
sys.exit(1)
|
120
libsvm-3.36/tools/subset.py
Executable file
120
libsvm-3.36/tools/subset.py
Executable file
@@ -0,0 +1,120 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import os, sys, math, random
|
||||
from collections import defaultdict
|
||||
|
||||
if sys.version_info[0] >= 3:
|
||||
xrange = range
|
||||
|
||||
def exit_with_help(argv):
|
||||
print("""\
|
||||
Usage: {0} [options] dataset subset_size [output1] [output2]
|
||||
|
||||
This script randomly selects a subset of the dataset.
|
||||
|
||||
options:
|
||||
-s method : method of selection (default 0)
|
||||
0 -- stratified selection (classification only)
|
||||
1 -- random selection
|
||||
|
||||
output1 : the subset (optional)
|
||||
output2 : rest of the data (optional)
|
||||
If output1 is omitted, the subset will be printed on the screen.""".format(argv[0]))
|
||||
exit(1)
|
||||
|
||||
def process_options(argv):
|
||||
argc = len(argv)
|
||||
if argc < 3:
|
||||
exit_with_help(argv)
|
||||
|
||||
# default method is stratified selection
|
||||
method = 0
|
||||
subset_file = sys.stdout
|
||||
rest_file = None
|
||||
|
||||
i = 1
|
||||
while i < argc:
|
||||
if argv[i][0] != "-":
|
||||
break
|
||||
if argv[i] == "-s":
|
||||
i = i + 1
|
||||
method = int(argv[i])
|
||||
if method not in [0,1]:
|
||||
print("Unknown selection method {0}".format(method))
|
||||
exit_with_help(argv)
|
||||
i = i + 1
|
||||
|
||||
dataset = argv[i]
|
||||
subset_size = int(argv[i+1])
|
||||
if i+2 < argc:
|
||||
subset_file = open(argv[i+2],'w')
|
||||
if i+3 < argc:
|
||||
rest_file = open(argv[i+3],'w')
|
||||
|
||||
return dataset, subset_size, method, subset_file, rest_file
|
||||
|
||||
def random_selection(dataset, subset_size):
|
||||
l = sum(1 for line in open(dataset,'r'))
|
||||
return sorted(random.sample(xrange(l), subset_size))
|
||||
|
||||
def stratified_selection(dataset, subset_size):
|
||||
labels = [line.split(None,1)[0] for line in open(dataset)]
|
||||
label_linenums = defaultdict(list)
|
||||
for i, label in enumerate(labels):
|
||||
label_linenums[label] += [i]
|
||||
|
||||
l = len(labels)
|
||||
remaining = subset_size
|
||||
ret = []
|
||||
|
||||
# classes with fewer data are sampled first; otherwise
|
||||
# some rare classes may not be selected
|
||||
for label in sorted(label_linenums, key=lambda x: len(label_linenums[x])):
|
||||
linenums = label_linenums[label]
|
||||
label_size = len(linenums)
|
||||
# at least one instance per class
|
||||
s = int(min(remaining, max(1, math.ceil(label_size*(float(subset_size)/l)))))
|
||||
if s == 0:
|
||||
sys.stderr.write('''\
|
||||
Error: failed to have at least one instance per class
|
||||
1. You may have regression data.
|
||||
2. Your classification data is unbalanced or too small.
|
||||
Please use -s 1.
|
||||
''')
|
||||
sys.exit(-1)
|
||||
remaining -= s
|
||||
ret += [linenums[i] for i in random.sample(xrange(label_size), s)]
|
||||
return sorted(ret)
|
||||
|
||||
def main(argv=sys.argv):
|
||||
dataset, subset_size, method, subset_file, rest_file = process_options(argv)
|
||||
#uncomment the following line to fix the random seed
|
||||
#random.seed(0)
|
||||
selected_lines = []
|
||||
|
||||
if method == 0:
|
||||
selected_lines = stratified_selection(dataset, subset_size)
|
||||
elif method == 1:
|
||||
selected_lines = random_selection(dataset, subset_size)
|
||||
|
||||
#select instances based on selected_lines
|
||||
dataset = open(dataset,'r')
|
||||
prev_selected_linenum = -1
|
||||
for i in xrange(len(selected_lines)):
|
||||
for cnt in xrange(selected_lines[i]-prev_selected_linenum-1):
|
||||
line = dataset.readline()
|
||||
if rest_file:
|
||||
rest_file.write(line)
|
||||
subset_file.write(dataset.readline())
|
||||
prev_selected_linenum = selected_lines[i]
|
||||
subset_file.close()
|
||||
|
||||
if rest_file:
|
||||
for line in dataset:
|
||||
rest_file.write(line)
|
||||
rest_file.close()
|
||||
dataset.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv)
|
||||
|
Reference in New Issue
Block a user