# k_means.py — cluster annotation box (width, height) pairs with K-Means
# to derive anchor shapes, and plot the resulting cluster regions.
  1. import argparse
  2. import sys
  3. import os
  4. import time
  5. import numpy as np
  6. from sklearn.cluster import KMeans
  7. import matplotlib.pyplot as plt
  8. def k_means(K, data, max_iter, n_jobs, image_file):
  9. X = np.array(data)
  10. np.random.shuffle(X)
  11. begin = time.time()
  12. print 'Running kmeans'
  13. kmeans = KMeans(n_clusters=K, max_iter=max_iter, n_jobs=n_jobs, verbose=1).fit(X)
  14. print 'K-Means took {} seconds to complete'.format(time.time()-begin)
  15. step_size = 0.2
  16. xmin, xmax = X[:, 0].min()-1, X[:, 0].max()+1
  17. ymin, ymax = X[:, 1].min()-1, X[:, 1].max()+1
  18. xx, yy = np.meshgrid(np.arange(xmin, xmax, step_size), np.arange(ymin, ymax, step_size))
  19. preds = kmeans.predict(np.c_[xx.ravel(), yy.ravel()])
  20. preds = preds.reshape(xx.shape)
  21. plt.figure()
  22. plt.clf()
  23. plt.imshow(preds, interpolation='nearest', extent=(xx.min(), xx.max(), yy.min(), yy.max()), cmap=plt.cm.Paired, aspect='auto', origin='lower')
  24. plt.plot(X[:, 0], X[:, 1], 'k.', markersize=2)
  25. centroids = kmeans.cluster_centers_
  26. plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', s=169, linewidths=5, color='r', zorder=10)
  27. plt.title("Anchor shapes generated using K-Means")
  28. plt.xlim(xmin, xmax)
  29. plt.ylim(ymin, ymax)
  30. print 'Mean centroids are:'
  31. for i, center in enumerate(centroids):
  32. print '{}: {}, {}'.format(i, center[0], center[1])
  33. # plt.xticks(())
  34. # plt.yticks(())
  35. plt.show()
  36. def pre_process(directory, data_list):
  37. if not os.path.exists(directory):
  38. print "Path {} doesn't exist".format(directory)
  39. return
  40. files = os.listdir(directory)
  41. print 'Loading data...'
  42. for i, f in enumerate(files):
  43. # Progress bar
  44. sys.stdout.write('\r')
  45. percentage = (i+1.0) / len(files)
  46. progress = int(percentage * 30)
  47. bar = [progress*'=', ' '*(29-progress), percentage*100]
  48. sys.stdout.write('[{}>{}] {:.0f}%'.format(*bar))
  49. sys.stdout.flush()
  50. with open(directory+"/"+f, 'r') as ann:
  51. l = ann.readline()
  52. l = l.rstrip()
  53. l = l.split(' ')
  54. l = [float(i) for i in l]
  55. if len(l) % 5 != 0:
  56. sys.stderr.write('File {} contains incorrect number of annotations'.format(f))
  57. return
  58. num_objs = len(l) / 5
  59. for obj in range(num_objs):
  60. xmin = l[obj * 5 + 0]
  61. ymin = l[obj * 5 + 1]
  62. xmax = l[obj * 5 + 2]
  63. ymax = l[obj * 5 + 3]
  64. w = xmax - xmin
  65. h = ymax - ymin
  66. data_list.append([w, h])
  67. if w > 1000 or h > 1000:
  68. sys.stdout.write("[{}, {}]".format(w, h))
  69. sys.stdout.write('\nProcessed {} files containing {} objects'.format(len(files), len(data_list)))
  70. return data_list
  71. def main():
  72. parser = argparse.ArgumentParser("Parse hyperparameters")
  73. parser.add_argument("clusters", help="Number of clusters", type=int)
  74. parser.add_argument("dir", help="Directory containing annotations")
  75. parser.add_argument("image_file", help="File to generate the final cluster of image")
  76. parser.add_argument('-jobs', help="Number of jobs for parallel computation", default=1)
  77. parser.add_argument('-iter', help="Max Iterations to run algorithm for", default=1000)
  78. p = parser.parse_args(sys.argv[1:])
  79. K = p.clusters
  80. directory = p.dir
  81. data_list = []
  82. pre_process(directory, data_list )
  83. sys.stdout.write('\nDone collecting data\n')
  84. k_means(K, data_list, int(p.iter), int(p.jobs), p.image_file)
  85. print 'Done !'
  86. if __name__=='__main__':
  87. try:
  88. main()
  89. except Exception as E:
  90. print E