First commit
This commit is contained in:
500
libsvm-3.36/tools/grid.py
Executable file
500
libsvm-3.36/tools/grid.py
Executable file
@@ -0,0 +1,500 @@
|
||||
#!/usr/bin/env python
|
||||
__all__ = ['find_parameters']
|
||||
|
||||
import os, sys, traceback, getpass, time, re
|
||||
from threading import Thread
|
||||
from subprocess import *
|
||||
|
||||
if sys.version_info[0] < 3:
|
||||
from Queue import Queue
|
||||
else:
|
||||
from queue import Queue
|
||||
|
||||
telnet_workers = []
|
||||
ssh_workers = []
|
||||
nr_local_worker = 1
|
||||
|
||||
class GridOption:
|
||||
def __init__(self, dataset_pathname, options):
|
||||
dirname = os.path.dirname(__file__)
|
||||
if sys.platform != 'win32':
|
||||
self.svmtrain_pathname = os.path.join(dirname, '../svm-train')
|
||||
self.gnuplot_pathname = '/usr/bin/gnuplot'
|
||||
else:
|
||||
# example for windows
|
||||
self.svmtrain_pathname = os.path.join(dirname, r'..\windows\svm-train.exe')
|
||||
# svmtrain_pathname = r'c:\Program Files\libsvm\windows\svm-train.exe'
|
||||
self.gnuplot_pathname = r'c:\tmp\gnuplot\binary\pgnuplot.exe'
|
||||
self.fold = 5
|
||||
self.c_begin, self.c_end, self.c_step = -5, 15, 2
|
||||
self.g_begin, self.g_end, self.g_step = 3, -15, -2
|
||||
self.grid_with_c, self.grid_with_g = True, True
|
||||
self.dataset_pathname = dataset_pathname
|
||||
self.dataset_title = os.path.split(dataset_pathname)[1]
|
||||
self.out_pathname = '{0}.out'.format(self.dataset_title)
|
||||
self.png_pathname = '{0}.png'.format(self.dataset_title)
|
||||
self.pass_through_string = ' '
|
||||
self.resume_pathname = None
|
||||
self.parse_options(options)
|
||||
|
||||
def parse_options(self, options):
|
||||
if type(options) == str:
|
||||
options = options.split()
|
||||
i = 0
|
||||
pass_through_options = []
|
||||
|
||||
while i < len(options):
|
||||
if options[i] == '-log2c':
|
||||
i = i + 1
|
||||
if options[i] == 'null':
|
||||
self.grid_with_c = False
|
||||
else:
|
||||
self.c_begin, self.c_end, self.c_step = map(float,options[i].split(','))
|
||||
elif options[i] == '-log2g':
|
||||
i = i + 1
|
||||
if options[i] == 'null':
|
||||
self.grid_with_g = False
|
||||
else:
|
||||
self.g_begin, self.g_end, self.g_step = map(float,options[i].split(','))
|
||||
elif options[i] == '-v':
|
||||
i = i + 1
|
||||
self.fold = options[i]
|
||||
elif options[i] in ('-c','-g'):
|
||||
raise ValueError('Use -log2c and -log2g.')
|
||||
elif options[i] == '-svmtrain':
|
||||
i = i + 1
|
||||
self.svmtrain_pathname = options[i]
|
||||
elif options[i] == '-gnuplot':
|
||||
i = i + 1
|
||||
if options[i] == 'null':
|
||||
self.gnuplot_pathname = None
|
||||
else:
|
||||
self.gnuplot_pathname = options[i]
|
||||
elif options[i] == '-out':
|
||||
i = i + 1
|
||||
if options[i] == 'null':
|
||||
self.out_pathname = None
|
||||
else:
|
||||
self.out_pathname = options[i]
|
||||
elif options[i] == '-png':
|
||||
i = i + 1
|
||||
self.png_pathname = options[i]
|
||||
elif options[i] == '-resume':
|
||||
if i == (len(options)-1) or options[i+1].startswith('-'):
|
||||
self.resume_pathname = self.dataset_title + '.out'
|
||||
else:
|
||||
i = i + 1
|
||||
self.resume_pathname = options[i]
|
||||
else:
|
||||
pass_through_options.append(options[i])
|
||||
i = i + 1
|
||||
|
||||
self.pass_through_string = ' '.join(pass_through_options)
|
||||
if not os.path.exists(self.svmtrain_pathname):
|
||||
raise IOError('svm-train executable not found')
|
||||
if not os.path.exists(self.dataset_pathname):
|
||||
raise IOError('dataset not found')
|
||||
if self.resume_pathname and not os.path.exists(self.resume_pathname):
|
||||
raise IOError('file for resumption not found')
|
||||
if not self.grid_with_c and not self.grid_with_g:
|
||||
raise ValueError('-log2c and -log2g should not be null simultaneously')
|
||||
if self.gnuplot_pathname and not os.path.exists(self.gnuplot_pathname):
|
||||
sys.stderr.write('gnuplot executable not found\n')
|
||||
self.gnuplot_pathname = None
|
||||
|
||||
def redraw(db,best_param,gnuplot,options,tofile=False):
|
||||
if len(db) == 0: return
|
||||
begin_level = round(max(x[2] for x in db)) - 3
|
||||
step_size = 0.5
|
||||
|
||||
best_log2c,best_log2g,best_rate = best_param
|
||||
|
||||
# if newly obtained c, g, or cv values are the same,
|
||||
# then stop redrawing the contour.
|
||||
if all(x[0] == db[0][0] for x in db): return
|
||||
if all(x[1] == db[0][1] for x in db): return
|
||||
if all(x[2] == db[0][2] for x in db): return
|
||||
|
||||
if tofile:
|
||||
gnuplot.write(b"set term png transparent small linewidth 2 medium enhanced\n")
|
||||
gnuplot.write("set output \"{0}\"\n".format(options.png_pathname.replace('\\','\\\\')).encode())
|
||||
#gnuplot.write(b"set term postscript color solid\n")
|
||||
#gnuplot.write("set output \"{0}.ps\"\n".format(options.dataset_title).encode().encode())
|
||||
elif sys.platform == 'win32':
|
||||
gnuplot.write(b"set term windows\n")
|
||||
else:
|
||||
gnuplot.write( b"set term x11\n")
|
||||
gnuplot.write(b"set xlabel \"log2(C)\"\n")
|
||||
gnuplot.write(b"set ylabel \"log2(gamma)\"\n")
|
||||
gnuplot.write("set xrange [{0}:{1}]\n".format(options.c_begin,options.c_end).encode())
|
||||
gnuplot.write("set yrange [{0}:{1}]\n".format(options.g_begin,options.g_end).encode())
|
||||
gnuplot.write(b"set contour\n")
|
||||
gnuplot.write("set cntrparam levels incremental {0},{1},100\n".format(begin_level,step_size).encode())
|
||||
gnuplot.write(b"unset surface\n")
|
||||
gnuplot.write(b"unset ztics\n")
|
||||
gnuplot.write(b"set view 0,0\n")
|
||||
gnuplot.write("set title \"{0}\"\n".format(options.dataset_title).encode())
|
||||
gnuplot.write(b"unset label\n")
|
||||
gnuplot.write("set label \"Best log2(C) = {0} log2(gamma) = {1} accuracy = {2}%\" \
|
||||
at screen 0.5,0.85 center\n". \
|
||||
format(best_log2c, best_log2g, best_rate).encode())
|
||||
gnuplot.write("set label \"C = {0} gamma = {1}\""
|
||||
" at screen 0.5,0.8 center\n".format(2**best_log2c, 2**best_log2g).encode())
|
||||
gnuplot.write(b"set key at screen 0.9,0.9\n")
|
||||
gnuplot.write(b"splot \"-\" with lines\n")
|
||||
|
||||
db.sort(key = lambda x:(x[0], -x[1]))
|
||||
|
||||
prevc = db[0][0]
|
||||
for line in db:
|
||||
if prevc != line[0]:
|
||||
gnuplot.write(b"\n")
|
||||
prevc = line[0]
|
||||
gnuplot.write("{0[0]} {0[1]} {0[2]}\n".format(line).encode())
|
||||
gnuplot.write(b"e\n")
|
||||
gnuplot.write(b"\n") # force gnuplot back to prompt when term set failure
|
||||
gnuplot.flush()
|
||||
|
||||
|
||||
def calculate_jobs(options):
|
||||
|
||||
def range_f(begin,end,step):
|
||||
# like range, but works on non-integer too
|
||||
seq = []
|
||||
while True:
|
||||
if step > 0 and begin > end: break
|
||||
if step < 0 and begin < end: break
|
||||
seq.append(begin)
|
||||
begin = begin + step
|
||||
return seq
|
||||
|
||||
def permute_sequence(seq):
|
||||
n = len(seq)
|
||||
if n <= 1: return seq
|
||||
|
||||
mid = int(n/2)
|
||||
left = permute_sequence(seq[:mid])
|
||||
right = permute_sequence(seq[mid+1:])
|
||||
|
||||
ret = [seq[mid]]
|
||||
while left or right:
|
||||
if left: ret.append(left.pop(0))
|
||||
if right: ret.append(right.pop(0))
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
c_seq = permute_sequence(range_f(options.c_begin,options.c_end,options.c_step))
|
||||
g_seq = permute_sequence(range_f(options.g_begin,options.g_end,options.g_step))
|
||||
|
||||
if not options.grid_with_c:
|
||||
c_seq = [None]
|
||||
if not options.grid_with_g:
|
||||
g_seq = [None]
|
||||
|
||||
nr_c = float(len(c_seq))
|
||||
nr_g = float(len(g_seq))
|
||||
i, j = 0, 0
|
||||
jobs = []
|
||||
|
||||
while i < nr_c or j < nr_g:
|
||||
if i/nr_c < j/nr_g:
|
||||
# increase C resolution
|
||||
line = []
|
||||
for k in range(0,j):
|
||||
line.append((c_seq[i],g_seq[k]))
|
||||
i = i + 1
|
||||
jobs.append(line)
|
||||
else:
|
||||
# increase g resolution
|
||||
line = []
|
||||
for k in range(0,i):
|
||||
line.append((c_seq[k],g_seq[j]))
|
||||
j = j + 1
|
||||
jobs.append(line)
|
||||
|
||||
resumed_jobs = {}
|
||||
|
||||
if options.resume_pathname is None:
|
||||
return jobs, resumed_jobs
|
||||
|
||||
for line in open(options.resume_pathname, 'r'):
|
||||
line = line.strip()
|
||||
rst = re.findall(r'rate=([0-9.]+)',line)
|
||||
if not rst:
|
||||
continue
|
||||
rate = float(rst[0])
|
||||
|
||||
c, g = None, None
|
||||
rst = re.findall(r'log2c=([0-9.-]+)',line)
|
||||
if rst:
|
||||
c = float(rst[0])
|
||||
rst = re.findall(r'log2g=([0-9.-]+)',line)
|
||||
if rst:
|
||||
g = float(rst[0])
|
||||
|
||||
resumed_jobs[(c,g)] = rate
|
||||
|
||||
return jobs, resumed_jobs
|
||||
|
||||
|
||||
class WorkerStopToken: # used to notify the worker to stop or if a worker is dead
|
||||
pass
|
||||
|
||||
class Worker(Thread):
|
||||
def __init__(self,name,job_queue,result_queue,options):
|
||||
Thread.__init__(self)
|
||||
self.name = name
|
||||
self.job_queue = job_queue
|
||||
self.result_queue = result_queue
|
||||
self.options = options
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
(cexp,gexp) = self.job_queue.get()
|
||||
if cexp is WorkerStopToken:
|
||||
self.job_queue.put((cexp,gexp))
|
||||
# print('worker {0} stop.'.format(self.name))
|
||||
break
|
||||
try:
|
||||
c, g = None, None
|
||||
if cexp != None:
|
||||
c = 2.0**cexp
|
||||
if gexp != None:
|
||||
g = 2.0**gexp
|
||||
rate = self.run_one(c,g)
|
||||
if rate is None: raise RuntimeError('get no rate')
|
||||
except:
|
||||
# we failed, let others do that and we just quit
|
||||
|
||||
traceback.print_exception(sys.exc_info()[0], sys.exc_info()[1], sys.exc_info()[2])
|
||||
|
||||
self.job_queue.put((cexp,gexp))
|
||||
sys.stderr.write('worker {0} quit.\n'.format(self.name))
|
||||
break
|
||||
else:
|
||||
self.result_queue.put((self.name,cexp,gexp,rate))
|
||||
|
||||
def get_cmd(self,c,g):
|
||||
options=self.options
|
||||
cmdline = '"' + options.svmtrain_pathname + '"'
|
||||
if options.grid_with_c:
|
||||
cmdline += ' -c {0} '.format(c)
|
||||
if options.grid_with_g:
|
||||
cmdline += ' -g {0} '.format(g)
|
||||
cmdline += ' -v {0} {1} {2} '.format\
|
||||
(options.fold,options.pass_through_string,'"' + options.dataset_pathname + '"')
|
||||
return cmdline
|
||||
|
||||
class LocalWorker(Worker):
|
||||
def run_one(self,c,g):
|
||||
cmdline = self.get_cmd(c,g)
|
||||
result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
|
||||
for line in result.readlines():
|
||||
if str(line).find('Cross') != -1:
|
||||
return float(line.split()[-1][0:-1])
|
||||
|
||||
class SSHWorker(Worker):
|
||||
def __init__(self,name,job_queue,result_queue,host,options):
|
||||
Worker.__init__(self,name,job_queue,result_queue,options)
|
||||
self.host = host
|
||||
self.cwd = os.getcwd()
|
||||
def run_one(self,c,g):
|
||||
cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"'.format\
|
||||
(self.host,self.cwd,self.get_cmd(c,g))
|
||||
result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
|
||||
for line in result.readlines():
|
||||
if str(line).find('Cross') != -1:
|
||||
return float(line.split()[-1][0:-1])
|
||||
|
||||
class TelnetWorker(Worker):
|
||||
def __init__(self,name,job_queue,result_queue,host,username,password,options):
|
||||
Worker.__init__(self,name,job_queue,result_queue,options)
|
||||
self.host = host
|
||||
self.username = username
|
||||
self.password = password
|
||||
def run(self):
|
||||
import telnetlib
|
||||
self.tn = tn = telnetlib.Telnet(self.host)
|
||||
tn.read_until('login: ')
|
||||
tn.write(self.username + '\n')
|
||||
tn.read_until('Password: ')
|
||||
tn.write(self.password + '\n')
|
||||
|
||||
# XXX: how to know whether login is successful?
|
||||
tn.read_until(self.username)
|
||||
#
|
||||
print('login ok', self.host)
|
||||
tn.write('cd '+os.getcwd()+'\n')
|
||||
Worker.run(self)
|
||||
tn.write('exit\n')
|
||||
def run_one(self,c,g):
|
||||
cmdline = self.get_cmd(c,g)
|
||||
result = self.tn.write(cmdline+'\n')
|
||||
(idx,matchm,output) = self.tn.expect(['Cross.*\n'])
|
||||
for line in output.split('\n'):
|
||||
if str(line).find('Cross') != -1:
|
||||
return float(line.split()[-1][0:-1])
|
||||
|
||||
def find_parameters(dataset_pathname, options=''):
|
||||
|
||||
def update_param(c,g,rate,best_c,best_g,best_rate,worker,resumed):
|
||||
if (rate > best_rate) or (rate==best_rate and g==best_g and c<best_c):
|
||||
best_rate,best_c,best_g = rate,c,g
|
||||
stdout_str = '[{0}] {1} {2} (best '.format\
|
||||
(worker,' '.join(str(x) for x in [c,g] if x is not None),rate)
|
||||
output_str = ''
|
||||
if c != None:
|
||||
stdout_str += 'c={0}, '.format(2.0**best_c)
|
||||
output_str += 'log2c={0} '.format(c)
|
||||
if g != None:
|
||||
stdout_str += 'g={0}, '.format(2.0**best_g)
|
||||
output_str += 'log2g={0} '.format(g)
|
||||
stdout_str += 'rate={0})'.format(best_rate)
|
||||
print(stdout_str)
|
||||
if options.out_pathname and not resumed:
|
||||
output_str += 'rate={0}\n'.format(rate)
|
||||
result_file.write(output_str)
|
||||
result_file.flush()
|
||||
|
||||
return best_c,best_g,best_rate
|
||||
|
||||
options = GridOption(dataset_pathname, options);
|
||||
|
||||
if options.gnuplot_pathname:
|
||||
gnuplot = Popen(options.gnuplot_pathname,stdin = PIPE,stdout=PIPE,stderr=PIPE).stdin
|
||||
else:
|
||||
gnuplot = None
|
||||
|
||||
# put jobs in queue
|
||||
|
||||
jobs,resumed_jobs = calculate_jobs(options)
|
||||
job_queue = Queue(0)
|
||||
result_queue = Queue(0)
|
||||
|
||||
for (c,g) in resumed_jobs:
|
||||
result_queue.put(('resumed',c,g,resumed_jobs[(c,g)]))
|
||||
|
||||
for line in jobs:
|
||||
for (c,g) in line:
|
||||
if (c,g) not in resumed_jobs:
|
||||
job_queue.put((c,g))
|
||||
|
||||
# hack the queue to become a stack --
|
||||
# this is important when some thread
|
||||
# failed and re-put a job. It we still
|
||||
# use FIFO, the job will be put
|
||||
# into the end of the queue, and the graph
|
||||
# will only be updated in the end
|
||||
|
||||
job_queue._put = job_queue.queue.appendleft
|
||||
|
||||
# fire telnet workers
|
||||
|
||||
if telnet_workers:
|
||||
nr_telnet_worker = len(telnet_workers)
|
||||
username = getpass.getuser()
|
||||
password = getpass.getpass()
|
||||
for host in telnet_workers:
|
||||
worker = TelnetWorker(host,job_queue,result_queue,
|
||||
host,username,password,options)
|
||||
worker.start()
|
||||
|
||||
# fire ssh workers
|
||||
|
||||
if ssh_workers:
|
||||
for host in ssh_workers:
|
||||
worker = SSHWorker(host,job_queue,result_queue,host,options)
|
||||
worker.start()
|
||||
|
||||
# fire local workers
|
||||
|
||||
for i in range(nr_local_worker):
|
||||
worker = LocalWorker('local',job_queue,result_queue,options)
|
||||
worker.start()
|
||||
|
||||
# gather results
|
||||
|
||||
done_jobs = {}
|
||||
|
||||
if options.out_pathname:
|
||||
if options.resume_pathname:
|
||||
result_file = open(options.out_pathname, 'a')
|
||||
else:
|
||||
result_file = open(options.out_pathname, 'w')
|
||||
|
||||
|
||||
db = []
|
||||
best_rate = -1
|
||||
best_c,best_g = None,None
|
||||
|
||||
for (c,g) in resumed_jobs:
|
||||
rate = resumed_jobs[(c,g)]
|
||||
best_c,best_g,best_rate = update_param(c,g,rate,best_c,best_g,best_rate,'resumed',True)
|
||||
|
||||
for line in jobs:
|
||||
for (c,g) in line:
|
||||
while (c,g) not in done_jobs:
|
||||
(worker,c1,g1,rate1) = result_queue.get()
|
||||
done_jobs[(c1,g1)] = rate1
|
||||
if (c1,g1) not in resumed_jobs:
|
||||
best_c,best_g,best_rate = update_param(c1,g1,rate1,best_c,best_g,best_rate,worker,False)
|
||||
db.append((c,g,done_jobs[(c,g)]))
|
||||
if gnuplot and options.grid_with_c and options.grid_with_g:
|
||||
redraw(db,[best_c, best_g, best_rate],gnuplot,options)
|
||||
redraw(db,[best_c, best_g, best_rate],gnuplot,options,True)
|
||||
|
||||
|
||||
if options.out_pathname:
|
||||
result_file.close()
|
||||
job_queue.put((WorkerStopToken,None))
|
||||
best_param, best_cg = {}, []
|
||||
if best_c != None:
|
||||
best_param['c'] = 2.0**best_c
|
||||
best_cg += [2.0**best_c]
|
||||
if best_g != None:
|
||||
best_param['g'] = 2.0**best_g
|
||||
best_cg += [2.0**best_g]
|
||||
print('{0} {1}'.format(' '.join(map(str,best_cg)), best_rate))
|
||||
|
||||
return best_rate, best_param
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
def exit_with_help():
|
||||
print("""\
|
||||
Usage: grid.py [grid_options] [svm_options] dataset
|
||||
|
||||
grid_options :
|
||||
-log2c {begin,end,step | "null"} : set the range of c (default -5,15,2)
|
||||
begin,end,step -- c_range = 2^{begin,...,begin+k*step,...,end}
|
||||
"null" -- do not grid with c
|
||||
-log2g {begin,end,step | "null"} : set the range of g (default 3,-15,-2)
|
||||
begin,end,step -- g_range = 2^{begin,...,begin+k*step,...,end}
|
||||
"null" -- do not grid with g
|
||||
-v n : n-fold cross validation (default 5)
|
||||
-svmtrain pathname : set svm executable path and name
|
||||
-gnuplot {pathname | "null"} :
|
||||
pathname -- set gnuplot executable path and name
|
||||
"null" -- do not plot
|
||||
-out {pathname | "null"} : (default dataset.out)
|
||||
pathname -- set output file path and name
|
||||
"null" -- do not output file
|
||||
-png pathname : set graphic output file path and name (default dataset.png)
|
||||
-resume [pathname] : resume the grid task using an existing output file (default pathname is dataset.out)
|
||||
This is experimental. Try this option only if some parameters have been checked for the SAME data.
|
||||
|
||||
svm_options : additional options for svm-train""")
|
||||
sys.exit(1)
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
exit_with_help()
|
||||
dataset_pathname = sys.argv[-1]
|
||||
options = sys.argv[1:-1]
|
||||
try:
|
||||
find_parameters(dataset_pathname, options)
|
||||
except (IOError,ValueError) as e:
|
||||
sys.stderr.write(str(e) + '\n')
|
||||
sys.stderr.write('Try "grid.py" for more information.\n')
|
||||
sys.exit(1)
|
Reference in New Issue
Block a user