subsetter.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. from itertools import chain
  2. import math
  3. import cv2 as cv
  4. import numpy as np
  5. from .feature_matcher import FeatureMatcher
  6. from .stitching_error import StitchingError
  7. class Subsetter:
  8. DEFAULT_CONFIDENCE_THRESHOLD = 1
  9. DEFAULT_MATCHES_GRAPH_DOT_FILE = None
  10. def __init__(self,
  11. confidence_threshold=DEFAULT_CONFIDENCE_THRESHOLD,
  12. matches_graph_dot_file=DEFAULT_MATCHES_GRAPH_DOT_FILE):
  13. self.confidence_threshold = confidence_threshold
  14. self.save_file = matches_graph_dot_file
  15. def subset(self, img_names, img_sizes, imgs, features, matches):
  16. self.save_matches_graph_dot_file(img_names, matches)
  17. indices = self.get_indices_to_keep(features, matches)
  18. img_names = Subsetter.subset_list(img_names, indices)
  19. img_sizes = Subsetter.subset_list(img_sizes, indices)
  20. imgs = Subsetter.subset_list(imgs, indices)
  21. features = Subsetter.subset_list(features, indices)
  22. matches = Subsetter.subset_matches(matches, indices)
  23. return img_names, img_sizes, imgs, features, matches
  24. def save_matches_graph_dot_file(self, img_names, pairwise_matches):
  25. if self.save_file:
  26. with open(self.save_file, 'w') as filehandler:
  27. filehandler.write(self.get_matches_graph(img_names,
  28. pairwise_matches)
  29. )
  30. def get_matches_graph(self, img_names, pairwise_matches):
  31. return cv.detail.matchesGraphAsString(img_names, pairwise_matches,
  32. self.confidence_threshold)
  33. def get_indices_to_keep(self, features, pairwise_matches):
  34. indices = cv.detail.leaveBiggestComponent(features,
  35. pairwise_matches,
  36. self.confidence_threshold)
  37. if len(indices) < 2:
  38. raise StitchingError("No match exceeds the "
  39. "given confidence theshold.")
  40. return indices
  41. @staticmethod
  42. def subset_list(list_to_subset, indices):
  43. return [list_to_subset[i] for i in indices]
  44. @staticmethod
  45. def subset_matches(pairwise_matches, indices):
  46. indices_to_delete = Subsetter.get_indices_to_delete(
  47. math.sqrt(len(pairwise_matches)),
  48. indices
  49. )
  50. matches_matrix = FeatureMatcher.get_matches_matrix(pairwise_matches)
  51. matches_matrix_subset = Subsetter.subset_matrix(matches_matrix,
  52. indices_to_delete)
  53. matches_subset = Subsetter.matrix_rows_to_list(matches_matrix_subset)
  54. return matches_subset
  55. @staticmethod
  56. def get_indices_to_delete(nr_elements, indices_to_keep):
  57. return list(set(range(int(nr_elements))) - set(indices_to_keep))
  58. @staticmethod
  59. def subset_matrix(matrix_to_subset, indices_to_delete):
  60. for idx, idx_to_delete in enumerate(indices_to_delete):
  61. matrix_to_subset = Subsetter.delete_index_from_matrix(
  62. matrix_to_subset,
  63. idx_to_delete-idx # matrix shape reduced by one at each step
  64. )
  65. return matrix_to_subset
  66. @staticmethod
  67. def delete_index_from_matrix(matrix, idx):
  68. mask = np.ones(matrix.shape[0], bool)
  69. mask[idx] = 0
  70. return matrix[mask, :][:, mask]
  71. @staticmethod
  72. def matrix_rows_to_list(matrix):
  73. return list(chain.from_iterable(matrix.tolist()))