digits_adjust.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. #!/usr/bin/env python
  2. '''
  3. Digit recognition adjustment.
  4. Grid search is used to find the best parameters for SVM and KNearest classifiers.
  5. SVM adjustment follows the guidelines given in
  6. http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf
  7. Usage:
  8. digits_adjust.py [--model {svm|knearest}]
  9. --model {svm|knearest} - select the classifier (SVM is the default)
  10. '''
  11. # Python 2/3 compatibility
  12. from __future__ import print_function
  13. import sys
  14. PY3 = sys.version_info[0] == 3
  15. if PY3:
  16. xrange = range
  17. import numpy as np
  18. import cv2 as cv
  19. from multiprocessing.pool import ThreadPool
  20. from digits import *
  21. def cross_validate(model_class, params, samples, labels, kfold = 3, pool = None):
  22. n = len(samples)
  23. folds = np.array_split(np.arange(n), kfold)
  24. def f(i):
  25. model = model_class(**params)
  26. test_idx = folds[i]
  27. train_idx = list(folds)
  28. train_idx.pop(i)
  29. train_idx = np.hstack(train_idx)
  30. train_samples, train_labels = samples[train_idx], labels[train_idx]
  31. test_samples, test_labels = samples[test_idx], labels[test_idx]
  32. model.train(train_samples, train_labels)
  33. resp = model.predict(test_samples)
  34. score = (resp != test_labels).mean()
  35. print(".", end='')
  36. return score
  37. if pool is None:
  38. scores = list(map(f, xrange(kfold)))
  39. else:
  40. scores = pool.map(f, xrange(kfold))
  41. return np.mean(scores)
  42. class App(object):
  43. def __init__(self):
  44. self._samples, self._labels = self.preprocess()
  45. def preprocess(self):
  46. digits, labels = load_digits(DIGITS_FN)
  47. shuffle = np.random.permutation(len(digits))
  48. digits, labels = digits[shuffle], labels[shuffle]
  49. digits2 = list(map(deskew, digits))
  50. samples = preprocess_hog(digits2)
  51. return samples, labels
  52. def get_dataset(self):
  53. return self._samples, self._labels
  54. def run_jobs(self, f, jobs):
  55. pool = ThreadPool(processes=cv.getNumberOfCPUs())
  56. ires = pool.imap_unordered(f, jobs)
  57. return ires
  58. def adjust_SVM(self):
  59. Cs = np.logspace(0, 10, 15, base=2)
  60. gammas = np.logspace(-7, 4, 15, base=2)
  61. scores = np.zeros((len(Cs), len(gammas)))
  62. scores[:] = np.nan
  63. print('adjusting SVM (may take a long time) ...')
  64. def f(job):
  65. i, j = job
  66. samples, labels = self.get_dataset()
  67. params = dict(C = Cs[i], gamma=gammas[j])
  68. score = cross_validate(SVM, params, samples, labels)
  69. return i, j, score
  70. ires = self.run_jobs(f, np.ndindex(*scores.shape))
  71. for count, (i, j, score) in enumerate(ires):
  72. scores[i, j] = score
  73. print('%d / %d (best error: %.2f %%, last: %.2f %%)' %
  74. (count+1, scores.size, np.nanmin(scores)*100, score*100))
  75. print(scores)
  76. print('writing score table to "svm_scores.npz"')
  77. np.savez('svm_scores.npz', scores=scores, Cs=Cs, gammas=gammas)
  78. i, j = np.unravel_index(scores.argmin(), scores.shape)
  79. best_params = dict(C = Cs[i], gamma=gammas[j])
  80. print('best params:', best_params)
  81. print('best error: %.2f %%' % (scores.min()*100))
  82. return best_params
  83. def adjust_KNearest(self):
  84. print('adjusting KNearest ...')
  85. def f(k):
  86. samples, labels = self.get_dataset()
  87. err = cross_validate(KNearest, dict(k=k), samples, labels)
  88. return k, err
  89. best_err, best_k = np.inf, -1
  90. for k, err in self.run_jobs(f, xrange(1, 9)):
  91. if err < best_err:
  92. best_err, best_k = err, k
  93. print('k = %d, error: %.2f %%' % (k, err*100))
  94. best_params = dict(k=best_k)
  95. print('best params:', best_params, 'err: %.2f' % (best_err*100))
  96. return best_params
  97. if __name__ == '__main__':
  98. import getopt
  99. import sys
  100. print(__doc__)
  101. args, _ = getopt.getopt(sys.argv[1:], '', ['model='])
  102. args = dict(args)
  103. args.setdefault('--model', 'svm')
  104. args.setdefault('--env', '')
  105. if args['--model'] not in ['svm', 'knearest']:
  106. print('unknown model "%s"' % args['--model'])
  107. sys.exit(1)
  108. t = clock()
  109. app = App()
  110. if args['--model'] == 'knearest':
  111. app.adjust_KNearest()
  112. else:
  113. app.adjust_SVM()
  114. print('work time: %f s' % (clock() - t))