stitcher.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. from types import SimpleNamespace
  2. from .image_handler import ImageHandler
  3. from .feature_detector import FeatureDetector
  4. from .feature_matcher import FeatureMatcher
  5. from .subsetter import Subsetter
  6. from .camera_estimator import CameraEstimator
  7. from .camera_adjuster import CameraAdjuster
  8. from .camera_wave_corrector import WaveCorrector
  9. from .warper import Warper
  10. from .cropper import Cropper
  11. from .exposure_error_compensator import ExposureErrorCompensator
  12. from .seam_finder import SeamFinder
  13. from .blender import Blender
  14. from .timelapser import Timelapser
  15. from .stitching_error import StitchingError
  16. class Stitcher:
  17. DEFAULT_SETTINGS = {
  18. "medium_megapix": ImageHandler.DEFAULT_MEDIUM_MEGAPIX,
  19. "detector": FeatureDetector.DEFAULT_DETECTOR,
  20. "nfeatures": 500,
  21. "matcher_type": FeatureMatcher.DEFAULT_MATCHER,
  22. "range_width": FeatureMatcher.DEFAULT_RANGE_WIDTH,
  23. "try_use_gpu": False,
  24. "match_conf": None,
  25. "confidence_threshold": Subsetter.DEFAULT_CONFIDENCE_THRESHOLD,
  26. "matches_graph_dot_file": Subsetter.DEFAULT_MATCHES_GRAPH_DOT_FILE,
  27. "estimator": CameraEstimator.DEFAULT_CAMERA_ESTIMATOR,
  28. "adjuster": CameraAdjuster.DEFAULT_CAMERA_ADJUSTER,
  29. "refinement_mask": CameraAdjuster.DEFAULT_REFINEMENT_MASK,
  30. "wave_correct_kind": WaveCorrector.DEFAULT_WAVE_CORRECTION,
  31. "warper_type": Warper.DEFAULT_WARP_TYPE,
  32. "low_megapix": ImageHandler.DEFAULT_LOW_MEGAPIX,
  33. "crop": Cropper.DEFAULT_CROP,
  34. "compensator": ExposureErrorCompensator.DEFAULT_COMPENSATOR,
  35. "nr_feeds": ExposureErrorCompensator.DEFAULT_NR_FEEDS,
  36. "block_size": ExposureErrorCompensator.DEFAULT_BLOCK_SIZE,
  37. "finder": SeamFinder.DEFAULT_SEAM_FINDER,
  38. "final_megapix": ImageHandler.DEFAULT_FINAL_MEGAPIX,
  39. "blender_type": Blender.DEFAULT_BLENDER,
  40. "blend_strength": Blender.DEFAULT_BLEND_STRENGTH,
  41. "timelapse": Timelapser.DEFAULT_TIMELAPSE}
  42. def __init__(self, **kwargs):
  43. self.initialize_stitcher(**kwargs)
  44. def initialize_stitcher(self, **kwargs):
  45. self.settings = Stitcher.DEFAULT_SETTINGS.copy()
  46. self.validate_kwargs(kwargs)
  47. self.settings.update(kwargs)
  48. args = SimpleNamespace(**self.settings)
  49. self.img_handler = ImageHandler(args.medium_megapix,
  50. args.low_megapix,
  51. args.final_megapix)
  52. self.detector = \
  53. FeatureDetector(args.detector, nfeatures=args.nfeatures)
  54. match_conf = \
  55. FeatureMatcher.get_match_conf(args.match_conf, args.detector)
  56. self.matcher = FeatureMatcher(args.matcher_type, args.range_width,
  57. try_use_gpu=args.try_use_gpu,
  58. match_conf=match_conf)
  59. self.subsetter = \
  60. Subsetter(args.confidence_threshold, args.matches_graph_dot_file)
  61. self.camera_estimator = CameraEstimator(args.estimator)
  62. self.camera_adjuster = \
  63. CameraAdjuster(args.adjuster, args.refinement_mask)
  64. self.wave_corrector = WaveCorrector(args.wave_correct_kind)
  65. self.warper = Warper(args.warper_type)
  66. self.cropper = Cropper(args.crop)
  67. self.compensator = \
  68. ExposureErrorCompensator(args.compensator, args.nr_feeds,
  69. args.block_size)
  70. self.seam_finder = SeamFinder(args.finder)
  71. self.blender = Blender(args.blender_type, args.blend_strength)
  72. self.timelapser = Timelapser(args.timelapse)
  73. def stitch(self, img_names):
  74. self.initialize_registration(img_names)
  75. imgs = self.resize_medium_resolution()
  76. features = self.find_features(imgs)
  77. matches = self.match_features(features)
  78. imgs, features, matches = self.subset(imgs, features, matches)
  79. cameras = self.estimate_camera_parameters(features, matches)
  80. cameras = self.refine_camera_parameters(features, matches, cameras)
  81. cameras = self.perform_wave_correction(cameras)
  82. self.estimate_scale(cameras)
  83. imgs = self.resize_low_resolution(imgs)
  84. imgs, masks, corners, sizes = self.warp_low_resolution(imgs, cameras)
  85. self.prepare_cropper(imgs, masks, corners, sizes)
  86. imgs, masks, corners, sizes = \
  87. self.crop_low_resolution(imgs, masks, corners, sizes)
  88. self.estimate_exposure_errors(corners, imgs, masks)
  89. seam_masks = self.find_seam_masks(imgs, corners, masks)
  90. imgs = self.resize_final_resolution()
  91. imgs, masks, corners, sizes = self.warp_final_resolution(imgs, cameras)
  92. imgs, masks, corners, sizes = \
  93. self.crop_final_resolution(imgs, masks, corners, sizes)
  94. self.set_masks(masks)
  95. imgs = self.compensate_exposure_errors(corners, imgs)
  96. seam_masks = self.resize_seam_masks(seam_masks)
  97. self.initialize_composition(corners, sizes)
  98. self.blend_images(imgs, seam_masks, corners)
  99. return self.create_final_panorama()
  100. def initialize_registration(self, img_names):
  101. self.img_handler.set_img_names(img_names)
  102. def resize_medium_resolution(self):
  103. return list(self.img_handler.resize_to_medium_resolution())
  104. def find_features(self, imgs):
  105. return [self.detector.detect_features(img) for img in imgs]
  106. def match_features(self, features):
  107. return self.matcher.match_features(features)
  108. def subset(self, imgs, features, matches):
  109. names, sizes, imgs, features, matches = \
  110. self.subsetter.subset(self.img_handler.img_names,
  111. self.img_handler.img_sizes,
  112. imgs, features, matches)
  113. self.img_handler.img_names, self.img_handler.img_sizes = names, sizes
  114. return imgs, features, matches
  115. def estimate_camera_parameters(self, features, matches):
  116. return self.camera_estimator.estimate(features, matches)
  117. def refine_camera_parameters(self, features, matches, cameras):
  118. return self.camera_adjuster.adjust(features, matches, cameras)
  119. def perform_wave_correction(self, cameras):
  120. return self.wave_corrector.correct(cameras)
  121. def estimate_scale(self, cameras):
  122. self.warper.set_scale(cameras)
  123. def resize_low_resolution(self, imgs=None):
  124. return list(self.img_handler.resize_to_low_resolution(imgs))
  125. def warp_low_resolution(self, imgs, cameras):
  126. sizes = self.img_handler.get_low_img_sizes()
  127. camera_aspect = self.img_handler.get_medium_to_low_ratio()
  128. imgs, masks, corners, sizes = \
  129. self.warp(imgs, cameras, sizes, camera_aspect)
  130. return list(imgs), list(masks), corners, sizes
  131. def warp_final_resolution(self, imgs, cameras):
  132. sizes = self.img_handler.get_final_img_sizes()
  133. camera_aspect = self.img_handler.get_medium_to_final_ratio()
  134. return self.warp(imgs, cameras, sizes, camera_aspect)
  135. def warp(self, imgs, cameras, sizes, aspect=1):
  136. imgs = self.warper.warp_images(imgs, cameras, aspect)
  137. masks = self.warper.create_and_warp_masks(sizes, cameras, aspect)
  138. corners, sizes = self.warper.warp_rois(sizes, cameras, aspect)
  139. return imgs, masks, corners, sizes
  140. def prepare_cropper(self, imgs, masks, corners, sizes):
  141. self.cropper.prepare(imgs, masks, corners, sizes)
  142. def crop_low_resolution(self, imgs, masks, corners, sizes):
  143. imgs, masks, corners, sizes = self.crop(imgs, masks, corners, sizes)
  144. return list(imgs), list(masks), corners, sizes
  145. def crop_final_resolution(self, imgs, masks, corners, sizes):
  146. lir_aspect = self.img_handler.get_low_to_final_ratio()
  147. return self.crop(imgs, masks, corners, sizes, lir_aspect)
  148. def crop(self, imgs, masks, corners, sizes, aspect=1):
  149. masks = self.cropper.crop_images(masks, aspect)
  150. imgs = self.cropper.crop_images(imgs, aspect)
  151. corners, sizes = self.cropper.crop_rois(corners, sizes, aspect)
  152. return imgs, masks, corners, sizes
  153. def estimate_exposure_errors(self, corners, imgs, masks):
  154. self.compensator.feed(corners, imgs, masks)
  155. def find_seam_masks(self, imgs, corners, masks):
  156. return self.seam_finder.find(imgs, corners, masks)
  157. def resize_final_resolution(self):
  158. return self.img_handler.resize_to_final_resolution()
  159. def compensate_exposure_errors(self, corners, imgs):
  160. for idx, (corner, img) in enumerate(zip(corners, imgs)):
  161. yield self.compensator.apply(idx, corner, img, self.get_mask(idx))
  162. def resize_seam_masks(self, seam_masks):
  163. for idx, seam_mask in enumerate(seam_masks):
  164. yield SeamFinder.resize(seam_mask, self.get_mask(idx))
  165. def set_masks(self, mask_generator):
  166. self.masks = mask_generator
  167. self.mask_index = -1
  168. def get_mask(self, idx):
  169. if idx == self.mask_index + 1:
  170. self.mask_index += 1
  171. self.mask = next(self.masks)
  172. return self.mask
  173. elif idx == self.mask_index:
  174. return self.mask
  175. else:
  176. raise StitchingError("Invalid Mask Index!")
  177. def initialize_composition(self, corners, sizes):
  178. if self.timelapser.do_timelapse:
  179. self.timelapser.initialize(corners, sizes)
  180. else:
  181. self.blender.prepare(corners, sizes)
  182. def blend_images(self, imgs, masks, corners):
  183. for idx, (img, mask, corner) in enumerate(zip(imgs, masks, corners)):
  184. if self.timelapser.do_timelapse:
  185. self.timelapser.process_and_save_frame(
  186. self.img_handler.img_names[idx], img, corner
  187. )
  188. else:
  189. self.blender.feed(img, mask, corner)
  190. def create_final_panorama(self):
  191. if not self.timelapser.do_timelapse:
  192. panorama, _ = self.blender.blend()
  193. return panorama
  194. @staticmethod
  195. def validate_kwargs(kwargs):
  196. for arg in kwargs:
  197. if arg not in Stitcher.DEFAULT_SETTINGS:
  198. raise StitchingError("Invalid Argument: " + arg)