tf_text_graph_mask_rcnn.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. import argparse
  2. import numpy as np
  3. from tf_text_graph_common import *
  4. parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
  5. 'Mask-RCNN model from TensorFlow Object Detection API. '
  6. 'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
  7. parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
  8. parser.add_argument('--output', required=True, help='Path to output text graph.')
  9. parser.add_argument('--config', required=True, help='Path to a *.config file is used for training.')
  10. args = parser.parse_args()
  11. scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
  12. 'FirstStageBoxPredictor/BoxEncodingPredictor',
  13. 'FirstStageBoxPredictor/ClassPredictor',
  14. 'CropAndResize',
  15. 'MaxPool2D',
  16. 'SecondStageFeatureExtractor',
  17. 'SecondStageBoxPredictor',
  18. 'Preprocessor/sub',
  19. 'Preprocessor/mul',
  20. 'image_tensor')
  21. scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
  22. 'FirstStageFeatureExtractor/Shape',
  23. 'FirstStageFeatureExtractor/strided_slice',
  24. 'FirstStageFeatureExtractor/GreaterEqual',
  25. 'FirstStageFeatureExtractor/LogicalAnd',
  26. 'Conv/required_space_to_batch_paddings')
  27. # Load a config file.
  28. config = readTextMessage(args.config)
  29. config = config['model'][0]['faster_rcnn'][0]
  30. num_classes = int(config['num_classes'][0])
  31. grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
  32. scales = [float(s) for s in grid_anchor_generator['scales']]
  33. aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
  34. width_stride = float(grid_anchor_generator['width_stride'][0])
  35. height_stride = float(grid_anchor_generator['height_stride'][0])
  36. features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])
  37. first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0])
  38. first_stage_max_proposals = int(config['first_stage_max_proposals'][0])
  39. print('Number of classes: %d' % num_classes)
  40. print('Scales: %s' % str(scales))
  41. print('Aspect ratios: %s' % str(aspect_ratios))
  42. print('Width stride: %f' % width_stride)
  43. print('Height stride: %f' % height_stride)
  44. print('Features stride: %f' % features_stride)
# Read the graph.
# Dump the frozen graph (args.input) as a text graph at args.output, then parse
# that text file back into graph_def, which the rest of the script edits in
# place. The name list presumably selects the output nodes the dump is built
# around — TODO confirm against writeTextGraph in tf_text_graph_common.
writeTextGraph(args.input, args.output, ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes', 'detection_masks'])
graph_def = parseTextGraph(args.output)
# Presumably drops Identity nodes and rewires their consumers — TODO confirm.
removeIdentity(graph_def)
  49. nodesToKeep = []
  50. def to_remove(name, op):
  51. if name in nodesToKeep:
  52. return False
  53. return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \
  54. (name.startswith('CropAndResize') and op != 'CropAndResize')
  55. # Fuse atrous convolutions (with dilations).
  56. nodesMap = {node.name: node for node in graph_def.node}
  57. for node in reversed(graph_def.node):
  58. if node.op == 'BatchToSpaceND':
  59. del node.input[2]
  60. conv = nodesMap[node.input[0]]
  61. spaceToBatchND = nodesMap[conv.input[0]]
  62. paddingsNode = NodeDef()
  63. paddingsNode.name = conv.name + '/paddings'
  64. paddingsNode.op = 'Const'
  65. paddingsNode.addAttr('value', [2, 2, 2, 2])
  66. graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode)
  67. nodesToKeep.append(paddingsNode.name)
  68. spaceToBatchND.input[2] = paddingsNode.name
  69. removeUnusedNodesAndAttrs(to_remove, graph_def)
  70. # Connect input node to the first layer
  71. assert(graph_def.node[0].op == 'Placeholder')
  72. graph_def.node[1].input.insert(0, graph_def.node[0].name)
  73. # Temporarily remove top nodes.
  74. topNodes = []
  75. numCropAndResize = 0
  76. while True:
  77. node = graph_def.node.pop()
  78. topNodes.append(node)
  79. if node.op == 'CropAndResize':
  80. numCropAndResize += 1
  81. if numCropAndResize == 2:
  82. break
# First-stage heads. The class logits are reshaped to [0, -1, 2] (two scores
# per anchor), soft-maxed and flattened. The box encodings are flattened as
# well. Both flattened tensors feed the 'detection_out' node built below.
# Statement order matters: the helpers presumably append nodes to graph_def.
addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
           'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def)
addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
           'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def)  # Compare with Reshape_4
addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
           'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def)

# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
           'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def)
  92. proposals = NodeDef()
  93. proposals.name = 'proposals' # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
  94. proposals.op = 'PriorBox'
  95. proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
  96. proposals.input.append(graph_def.node[0].name) # image_tensor
  97. proposals.addAttr('flip', False)
  98. proposals.addAttr('clip', True)
  99. proposals.addAttr('step', features_stride)
  100. proposals.addAttr('offset', 0.0)
  101. proposals.addAttr('variance', [0.1, 0.1, 0.2, 0.2])
  102. widths = []
  103. heights = []
  104. for a in aspect_ratios:
  105. for s in scales:
  106. ar = np.sqrt(a)
  107. heights.append((height_stride**2) * s / ar)
  108. widths.append((width_stride**2) * s * ar)
  109. proposals.addAttr('width', widths)
  110. proposals.addAttr('height', heights)
  111. graph_def.node.extend([proposals])
# Compare with Reshape_5
# First-stage DetectionOutput node ('detection_out'). Inputs: the flattened box
# encodings, the flattened 2-way softmax scores and the 'proposals' priors.
# The NMS IoU threshold and the number of kept boxes come from the training
# config; top_k is hard-coded to 6000.
# NOTE(review): exact semantics of top_k vs keep_top_k are those of OpenCV's
# DetectionOutput layer — confirm there before changing these values.
detectionOut = NodeDef()
detectionOut.name = 'detection_out'
detectionOut.op = 'DetectionOutput'
detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
detectionOut.input.append('proposals')
detectionOut.addAttr('num_classes', 2)  # consistent with the [0, -1, 2] score reshape
detectionOut.addAttr('share_location', True)
detectionOut.addAttr('background_label_id', 0)
detectionOut.addAttr('nms_threshold', first_stage_nms_iou_threshold)
detectionOut.addAttr('top_k', 6000)
detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k', first_stage_max_proposals)
detectionOut.addAttr('clip', True)
graph_def.node.extend([detectionOut])
  128. # Save as text.
  129. cropAndResizeNodesNames = []
  130. for node in reversed(topNodes):
  131. if node.op != 'CropAndResize':
  132. graph_def.node.extend([node])
  133. topNodes.pop()
  134. else:
  135. cropAndResizeNodesNames.append(node.name)
  136. if numCropAndResize == 1:
  137. break
  138. else:
  139. graph_def.node.extend([node])
  140. topNodes.pop()
  141. numCropAndResize -= 1
# Second-stage class scores: softmax over SecondStageBoxPredictor/Reshape_1,
# slice starting at index 1 of the last axis (presumably dropping the
# background score — TODO confirm addSlice's begin/size convention), then
# reshape to [1, -1]. The result is the confidence input of
# 'detection_out_final'.
addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax', graph_def)
addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
         'SecondStageBoxPredictor/Reshape_1/slice',
         [0, 0, 1], [-1, -1, -1], graph_def)
addReshape('SecondStageBoxPredictor/Reshape_1/slice',
           'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1], graph_def)
# Collapse each Flatten subgraph into a single node.
# Walk the graph backwards by index, so that deleting node i does not shift
# the indices still to be visited, and:
#  * give every CropAndResize node 'detection_out' as its second input;
#  * replace the last input of SecondStageBoxPredictor/Reshape (presumably
#    its shape) with a new constant [1, -1, 4];
#  * delete the Shape / strided_slice / Reshape-shape helpers of both
#    SecondStageBoxPredictor Flatten subgraphs; the Reshape nodes they fed are
#    converted to Flatten ops in the next loop.
# NOTE(review): the index range is fixed up front, so nodes added by
# addConstNode are not visited (assuming it appends at the end — confirm).
for i in reversed(range(len(graph_def.node))):
    if graph_def.node[i].op == 'CropAndResize':
        graph_def.node[i].input.insert(1, 'detection_out')

    if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
        addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def)

        graph_def.node[i].input.pop()
        graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2')

    if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
                                  'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
                                  'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape',
                                  'SecondStageBoxPredictor/Flatten_1/flatten/Shape',
                                  'SecondStageBoxPredictor/Flatten_1/flatten/strided_slice',
                                  'SecondStageBoxPredictor/Flatten_1/flatten/Reshape/shape']:
        del graph_def.node[i]
  163. for node in graph_def.node:
  164. if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape' or \
  165. node.name == 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape':
  166. node.op = 'Flatten'
  167. node.input.pop()
  168. if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
  169. 'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
  170. node.addAttr('loc_pred_transposed', True)
  171. if node.name.startswith('MaxPool2D'):
  172. assert(node.op == 'MaxPool')
  173. assert(len(cropAndResizeNodesNames) == 2)
  174. node.input = [cropAndResizeNodesNames[0]]
  175. del cropAndResizeNodesNames[0]
################################################################################
### Postprocessing
################################################################################
# Slice the first-stage detections along the last axis starting at index 3
# (presumably the box coordinates of each detection — TODO confirm the
# DetectionOutput layout and addSlice's begin/size convention).
addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)

# Constant variance vector; same values as the 'proposals' PriorBox variance.
variance = NodeDef()
variance.name = 'proposals/variance'
variance.op = 'Const'
variance.addAttr('value', [0.1, 0.1, 0.2, 0.2])
graph_def.node.extend([variance])

# Multiply the second-stage box regression (SecondStageBoxPredictor/Reshape)
# by the variance constant along axis 2. The final DetectionOutput below is
# told variance_encoded_in_target=True.
varianceEncoder = NodeDef()
varianceEncoder.name = 'variance_encoded'
varianceEncoder.op = 'Mul'
varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
varianceEncoder.input.append(variance.name)
varianceEncoder.addAttr('axis', 2)
graph_def.node.extend([varianceEncoder])

# Shape the inputs for 'detection_out_final': priors as [1, 1, -1] and the
# encoded boxes flattened.
addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)
# Final DetectionOutput node. Inputs: the flattened variance-encoded
# second-stage box regressions (per class, share_location=False), the
# second-stage class scores and the first-stage boxes reshaped to [1, 1, -1].
# background_label_id is num_classes + 1, presumably an id that never occurs
# because the background score was sliced off earlier — TODO confirm.
# NOTE(review): nms_threshold (0.6), keep_top_k (100) and confidence_threshold
# (0.3) are hard-coded rather than read from the training config.
detectionOut = NodeDef()
detectionOut.name = 'detection_out_final'
detectionOut.op = 'DetectionOutput'
detectionOut.input.append('variance_encoded/flatten')
detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
detectionOut.input.append('detection_out/slice/reshape')
detectionOut.addAttr('num_classes', num_classes)
detectionOut.addAttr('share_location', False)
detectionOut.addAttr('background_label_id', num_classes + 1)
detectionOut.addAttr('nms_threshold', 0.6)
detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k',100)
detectionOut.addAttr('clip', True)
detectionOut.addAttr('variance_encoded_in_target', True)
detectionOut.addAttr('confidence_threshold', 0.3)
detectionOut.addAttr('group_by_classes', False)
graph_def.node.extend([detectionOut])
  211. for node in reversed(topNodes):
  212. graph_def.node.extend([node])
  213. if node.name.startswith('MaxPool2D'):
  214. assert(node.op == 'MaxPool')
  215. assert(len(cropAndResizeNodesNames) == 1)
  216. node.input = [cropAndResizeNodesNames[0]]
  217. for i in reversed(range(len(graph_def.node))):
  218. if graph_def.node[i].op == 'CropAndResize':
  219. graph_def.node[i].input.insert(1, 'detection_out_final')
  220. break
# The last node of the graph (presumably the final op of the mask branch —
# TODO confirm) becomes the 'detection_masks' output: it is renamed, turned
# into a Sigmoid, and its last input is dropped.
graph_def.node[-1].name = 'detection_masks'
graph_def.node[-1].op = 'Sigmoid'
graph_def.node[-1].input.pop()
  224. def getUnconnectedNodes():
  225. unconnected = [node.name for node in graph_def.node]
  226. for node in graph_def.node:
  227. for inp in node.input:
  228. if inp in unconnected:
  229. unconnected.remove(inp)
  230. return unconnected
  231. while True:
  232. unconnectedNodes = getUnconnectedNodes()
  233. unconnectedNodes.remove(graph_def.node[-1].name)
  234. if not unconnectedNodes:
  235. break
  236. for name in unconnectedNodes:
  237. for i in range(len(graph_def.node)):
  238. if graph_def.node[i].name == name:
  239. del graph_def.node[i]
  240. break
# Save as text.
# Written to args.output, the same path writeTextGraph produced earlier, so the
# intermediate text graph is presumably overwritten with the edited one.
graph_def.save(args.output)