tf_text_graph_ssd.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. # This file is a part of OpenCV project.
  2. # It is a subject to the license terms in the LICENSE file found in the top-level directory
  3. # of this distribution and at http://opencv.org/license.html.
  4. #
  5. # Copyright (C) 2018, Intel Corporation, all rights reserved.
  6. # Third party copyrights are property of their respective owners.
  7. #
  8. # Use this script to get the text graph representation (.pbtxt) of SSD-based
  9. # deep learning network trained in TensorFlow Object Detection API.
  10. # Then you can import it with a binary frozen graph (.pb) using readNetFromTensorflow() function.
  11. # See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API
  12. import argparse
  13. import re
  14. from math import sqrt
  15. from tf_text_graph_common import *
  16. class SSDAnchorGenerator:
  17. def __init__(self, min_scale, max_scale, num_layers, aspect_ratios,
  18. reduce_boxes_in_lowest_layer, image_width, image_height):
  19. self.min_scale = min_scale
  20. self.aspect_ratios = aspect_ratios
  21. self.reduce_boxes_in_lowest_layer = reduce_boxes_in_lowest_layer
  22. self.image_width = image_width
  23. self.image_height = image_height
  24. self.scales = [min_scale + (max_scale - min_scale) * i / (num_layers - 1)
  25. for i in range(num_layers)] + [1.0]
  26. def get(self, layer_id):
  27. if layer_id == 0 and self.reduce_boxes_in_lowest_layer:
  28. widths = [0.1, self.min_scale * sqrt(2.0), self.min_scale * sqrt(0.5)]
  29. heights = [0.1, self.min_scale / sqrt(2.0), self.min_scale / sqrt(0.5)]
  30. else:
  31. widths = [self.scales[layer_id] * sqrt(ar) for ar in self.aspect_ratios]
  32. heights = [self.scales[layer_id] / sqrt(ar) for ar in self.aspect_ratios]
  33. widths += [sqrt(self.scales[layer_id] * self.scales[layer_id + 1])]
  34. heights += [sqrt(self.scales[layer_id] * self.scales[layer_id + 1])]
  35. min_size = min(self.image_width, self.image_height)
  36. widths = [w * min_size for w in widths]
  37. heights = [h * min_size for h in heights]
  38. return widths, heights
  39. class MultiscaleAnchorGenerator:
  40. def __init__(self, min_level, aspect_ratios, scales_per_octave, anchor_scale):
  41. self.min_level = min_level
  42. self.aspect_ratios = aspect_ratios
  43. self.anchor_scale = anchor_scale
  44. self.scales = [2**(float(s) / scales_per_octave) for s in range(scales_per_octave)]
  45. def get(self, layer_id):
  46. widths = []
  47. heights = []
  48. for a in self.aspect_ratios:
  49. for s in self.scales:
  50. base_anchor_size = 2**(self.min_level + layer_id) * self.anchor_scale
  51. ar = sqrt(a)
  52. heights.append(base_anchor_size * s / ar)
  53. widths.append(base_anchor_size * s * ar)
  54. return widths, heights
  55. def createSSDGraph(modelPath, configPath, outputPath):
  56. # Nodes that should be kept.
  57. keepOps = ['Conv2D', 'BiasAdd', 'Add', 'AddV2', 'Relu', 'Relu6', 'Placeholder', 'FusedBatchNorm',
  58. 'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity',
  59. 'Sub', 'ResizeNearestNeighbor', 'Pad', 'FusedBatchNormV3', 'Mean']
  60. # Node with which prefixes should be removed
  61. prefixesToRemove = ('MultipleGridAnchorGenerator/', 'Concatenate/', 'Postprocessor/', 'Preprocessor/map')
  62. # Load a config file.
  63. config = readTextMessage(configPath)
  64. config = config['model'][0]['ssd'][0]
  65. num_classes = int(config['num_classes'][0])
  66. fixed_shape_resizer = config['image_resizer'][0]['fixed_shape_resizer'][0]
  67. image_width = int(fixed_shape_resizer['width'][0])
  68. image_height = int(fixed_shape_resizer['height'][0])
  69. box_predictor = 'convolutional' if 'convolutional_box_predictor' in config['box_predictor'][0] else 'weight_shared_convolutional'
  70. anchor_generator = config['anchor_generator'][0]
  71. if 'ssd_anchor_generator' in anchor_generator:
  72. ssd_anchor_generator = anchor_generator['ssd_anchor_generator'][0]
  73. min_scale = float(ssd_anchor_generator['min_scale'][0])
  74. max_scale = float(ssd_anchor_generator['max_scale'][0])
  75. num_layers = int(ssd_anchor_generator['num_layers'][0])
  76. aspect_ratios = [float(ar) for ar in ssd_anchor_generator['aspect_ratios']]
  77. reduce_boxes_in_lowest_layer = True
  78. if 'reduce_boxes_in_lowest_layer' in ssd_anchor_generator:
  79. reduce_boxes_in_lowest_layer = ssd_anchor_generator['reduce_boxes_in_lowest_layer'][0] == 'true'
  80. priors_generator = SSDAnchorGenerator(min_scale, max_scale, num_layers,
  81. aspect_ratios, reduce_boxes_in_lowest_layer,
  82. image_width, image_height)
  83. print('Scale: [%f-%f]' % (min_scale, max_scale))
  84. print('Aspect ratios: %s' % str(aspect_ratios))
  85. print('Reduce boxes in the lowest layer: %s' % str(reduce_boxes_in_lowest_layer))
  86. elif 'multiscale_anchor_generator' in anchor_generator:
  87. multiscale_anchor_generator = anchor_generator['multiscale_anchor_generator'][0]
  88. min_level = int(multiscale_anchor_generator['min_level'][0])
  89. max_level = int(multiscale_anchor_generator['max_level'][0])
  90. anchor_scale = float(multiscale_anchor_generator['anchor_scale'][0])
  91. aspect_ratios = [float(ar) for ar in multiscale_anchor_generator['aspect_ratios']]
  92. scales_per_octave = int(multiscale_anchor_generator['scales_per_octave'][0])
  93. num_layers = max_level - min_level + 1
  94. priors_generator = MultiscaleAnchorGenerator(min_level, aspect_ratios,
  95. scales_per_octave, anchor_scale)
  96. print('Levels: [%d-%d]' % (min_level, max_level))
  97. print('Anchor scale: %f' % anchor_scale)
  98. print('Scales per octave: %d' % scales_per_octave)
  99. print('Aspect ratios: %s' % str(aspect_ratios))
  100. else:
  101. print('Unknown anchor_generator')
  102. exit(0)
  103. print('Number of classes: %d' % num_classes)
  104. print('Number of layers: %d' % num_layers)
  105. print('box predictor: %s' % box_predictor)
  106. print('Input image size: %dx%d' % (image_width, image_height))
  107. # Read the graph.
  108. outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']
  109. writeTextGraph(modelPath, outputPath, outNames)
  110. graph_def = parseTextGraph(outputPath)
  111. def getUnconnectedNodes():
  112. unconnected = []
  113. for node in graph_def.node:
  114. unconnected.append(node.name)
  115. for inp in node.input:
  116. if inp in unconnected:
  117. unconnected.remove(inp)
  118. return unconnected
  119. def fuse_nodes(nodesToKeep):
  120. # Detect unfused batch normalization nodes and fuse them.
  121. # Add_0 <-- moving_variance, add_y
  122. # Rsqrt <-- Add_0
  123. # Mul_0 <-- Rsqrt, gamma
  124. # Mul_1 <-- input, Mul_0
  125. # Mul_2 <-- moving_mean, Mul_0
  126. # Sub_0 <-- beta, Mul_2
  127. # Add_1 <-- Mul_1, Sub_0
  128. nodesMap = {node.name: node for node in graph_def.node}
  129. subgraphBatchNorm = ['Add',
  130. ['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
  131. ['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
  132. subgraphBatchNormV2 = ['AddV2',
  133. ['Mul', 'input', ['Mul', ['Rsqrt', ['AddV2', 'moving_variance', 'add_y']], 'gamma']],
  134. ['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
  135. # Detect unfused nearest neighbor resize.
  136. subgraphResizeNN = ['Reshape',
  137. ['Mul', ['Reshape', 'input', ['Pack', 'shape_1', 'shape_2', 'shape_3', 'shape_4', 'shape_5']],
  138. 'ones'],
  139. ['Pack', ['StridedSlice', ['Shape', 'input'], 'stack', 'stack_1', 'stack_2'],
  140. 'out_height', 'out_width', 'out_channels']]
  141. def checkSubgraph(node, targetNode, inputs, fusedNodes):
  142. op = targetNode[0]
  143. if node.op == op and (len(node.input) >= len(targetNode) - 1):
  144. fusedNodes.append(node)
  145. for i, inpOp in enumerate(targetNode[1:]):
  146. if isinstance(inpOp, list):
  147. if not node.input[i] in nodesMap or \
  148. not checkSubgraph(nodesMap[node.input[i]], inpOp, inputs, fusedNodes):
  149. return False
  150. else:
  151. inputs[inpOp] = node.input[i]
  152. return True
  153. else:
  154. return False
  155. nodesToRemove = []
  156. for node in graph_def.node:
  157. inputs = {}
  158. fusedNodes = []
  159. if checkSubgraph(node, subgraphBatchNorm, inputs, fusedNodes) or \
  160. checkSubgraph(node, subgraphBatchNormV2, inputs, fusedNodes):
  161. name = node.name
  162. node.Clear()
  163. node.name = name
  164. node.op = 'FusedBatchNorm'
  165. node.input.append(inputs['input'])
  166. node.input.append(inputs['gamma'])
  167. node.input.append(inputs['beta'])
  168. node.input.append(inputs['moving_mean'])
  169. node.input.append(inputs['moving_variance'])
  170. node.addAttr('epsilon', 0.001)
  171. nodesToRemove += fusedNodes[1:]
  172. inputs = {}
  173. fusedNodes = []
  174. if checkSubgraph(node, subgraphResizeNN, inputs, fusedNodes):
  175. name = node.name
  176. node.Clear()
  177. node.name = name
  178. node.op = 'ResizeNearestNeighbor'
  179. node.input.append(inputs['input'])
  180. node.input.append(name + '/output_shape')
  181. out_height_node = nodesMap[inputs['out_height']]
  182. out_width_node = nodesMap[inputs['out_width']]
  183. out_height = int(out_height_node.attr['value']['tensor'][0]['int_val'][0])
  184. out_width = int(out_width_node.attr['value']['tensor'][0]['int_val'][0])
  185. shapeNode = NodeDef()
  186. shapeNode.name = name + '/output_shape'
  187. shapeNode.op = 'Const'
  188. shapeNode.addAttr('value', [out_height, out_width])
  189. graph_def.node.insert(graph_def.node.index(node), shapeNode)
  190. nodesToKeep.append(shapeNode.name)
  191. nodesToRemove += fusedNodes[1:]
  192. for node in nodesToRemove:
  193. graph_def.node.remove(node)
  194. nodesToKeep = []
  195. fuse_nodes(nodesToKeep)
  196. removeIdentity(graph_def)
  197. def to_remove(name, op):
  198. return (not name in nodesToKeep) and \
  199. (op == 'Const' or (not op in keepOps) or name.startswith(prefixesToRemove))
  200. removeUnusedNodesAndAttrs(to_remove, graph_def)
  201. # Connect input node to the first layer
  202. assert(graph_def.node[0].op == 'Placeholder')
  203. try:
  204. input_shape = graph_def.node[0].attr['shape']['shape'][0]['dim']
  205. input_shape[1]['size'] = image_height
  206. input_shape[2]['size'] = image_width
  207. except:
  208. print("Input shapes are undefined")
  209. # assert(graph_def.node[1].op == 'Conv2D')
  210. weights = graph_def.node[1].input[-1]
  211. for i in range(len(graph_def.node[1].input)):
  212. graph_def.node[1].input.pop()
  213. graph_def.node[1].input.append(graph_def.node[0].name)
  214. graph_def.node[1].input.append(weights)
  215. # check and correct the case when preprocessing block is after input
  216. preproc_id = "Preprocessor/"
  217. if graph_def.node[2].name.startswith(preproc_id) and \
  218. graph_def.node[2].input[0].startswith(preproc_id):
  219. if not any(preproc_id in inp for inp in graph_def.node[3].input):
  220. graph_def.node[3].input.insert(0, graph_def.node[2].name)
  221. # Create SSD postprocessing head ###############################################
  222. # Concatenate predictions of classes, predictions of bounding boxes and proposals.
  223. def addConcatNode(name, inputs, axisNodeName):
  224. concat = NodeDef()
  225. concat.name = name
  226. concat.op = 'ConcatV2'
  227. for inp in inputs:
  228. concat.input.append(inp)
  229. concat.input.append(axisNodeName)
  230. graph_def.node.extend([concat])
  231. addConstNode('concat/axis_flatten', [-1], graph_def)
  232. addConstNode('PriorBox/concat/axis', [-2], graph_def)
  233. for label in ['ClassPredictor', 'BoxEncodingPredictor' if box_predictor == 'convolutional' else 'BoxPredictor']:
  234. concatInputs = []
  235. for i in range(num_layers):
  236. # Flatten predictions
  237. flatten = NodeDef()
  238. if box_predictor == 'convolutional':
  239. inpName = 'BoxPredictor_%d/%s/BiasAdd' % (i, label)
  240. else:
  241. if i == 0:
  242. inpName = 'WeightSharedConvolutionalBoxPredictor/%s/BiasAdd' % label
  243. else:
  244. inpName = 'WeightSharedConvolutionalBoxPredictor_%d/%s/BiasAdd' % (i, label)
  245. flatten.input.append(inpName)
  246. flatten.name = inpName + '/Flatten'
  247. flatten.op = 'Flatten'
  248. concatInputs.append(flatten.name)
  249. graph_def.node.extend([flatten])
  250. addConcatNode('%s/concat' % label, concatInputs, 'concat/axis_flatten')
  251. num_matched_layers = 0
  252. for node in graph_def.node:
  253. if re.match('BoxPredictor_\d/BoxEncodingPredictor/convolution', node.name) or \
  254. re.match('BoxPredictor_\d/BoxEncodingPredictor/Conv2D', node.name) or \
  255. re.match('WeightSharedConvolutionalBoxPredictor(_\d)*/BoxPredictor/Conv2D', node.name):
  256. node.addAttr('loc_pred_transposed', True)
  257. num_matched_layers += 1
  258. assert(num_matched_layers == num_layers)
  259. # Add layers that generate anchors (bounding boxes proposals).
  260. priorBoxes = []
  261. boxCoder = config['box_coder'][0]
  262. fasterRcnnBoxCoder = boxCoder['faster_rcnn_box_coder'][0]
  263. boxCoderVariance = [1.0/float(fasterRcnnBoxCoder['x_scale'][0]), 1.0/float(fasterRcnnBoxCoder['y_scale'][0]), 1.0/float(fasterRcnnBoxCoder['width_scale'][0]), 1.0/float(fasterRcnnBoxCoder['height_scale'][0])]
  264. for i in range(num_layers):
  265. priorBox = NodeDef()
  266. priorBox.name = 'PriorBox_%d' % i
  267. priorBox.op = 'PriorBox'
  268. if box_predictor == 'convolutional':
  269. priorBox.input.append('BoxPredictor_%d/BoxEncodingPredictor/BiasAdd' % i)
  270. else:
  271. if i == 0:
  272. priorBox.input.append('WeightSharedConvolutionalBoxPredictor/BoxPredictor/Conv2D')
  273. else:
  274. priorBox.input.append('WeightSharedConvolutionalBoxPredictor_%d/BoxPredictor/BiasAdd' % i)
  275. priorBox.input.append(graph_def.node[0].name) # image_tensor
  276. priorBox.addAttr('flip', False)
  277. priorBox.addAttr('clip', False)
  278. widths, heights = priors_generator.get(i)
  279. priorBox.addAttr('width', widths)
  280. priorBox.addAttr('height', heights)
  281. priorBox.addAttr('variance', boxCoderVariance)
  282. graph_def.node.extend([priorBox])
  283. priorBoxes.append(priorBox.name)
  284. # Compare this layer's output with Postprocessor/Reshape
  285. addConcatNode('PriorBox/concat', priorBoxes, 'concat/axis_flatten')
  286. # Sigmoid for classes predictions and DetectionOutput layer
  287. addReshape('ClassPredictor/concat', 'ClassPredictor/concat3d', [0, -1, num_classes + 1], graph_def)
  288. sigmoid = NodeDef()
  289. sigmoid.name = 'ClassPredictor/concat/sigmoid'
  290. sigmoid.op = 'Sigmoid'
  291. sigmoid.input.append('ClassPredictor/concat3d')
  292. graph_def.node.extend([sigmoid])
  293. addFlatten(sigmoid.name, sigmoid.name + '/Flatten', graph_def)
  294. detectionOut = NodeDef()
  295. detectionOut.name = 'detection_out'
  296. detectionOut.op = 'DetectionOutput'
  297. if box_predictor == 'convolutional':
  298. detectionOut.input.append('BoxEncodingPredictor/concat')
  299. else:
  300. detectionOut.input.append('BoxPredictor/concat')
  301. detectionOut.input.append(sigmoid.name + '/Flatten')
  302. detectionOut.input.append('PriorBox/concat')
  303. detectionOut.addAttr('num_classes', num_classes + 1)
  304. detectionOut.addAttr('share_location', True)
  305. detectionOut.addAttr('background_label_id', 0)
  306. postProcessing = config['post_processing'][0]
  307. batchNMS = postProcessing['batch_non_max_suppression'][0]
  308. if 'iou_threshold' in batchNMS:
  309. detectionOut.addAttr('nms_threshold', float(batchNMS['iou_threshold'][0]))
  310. else:
  311. detectionOut.addAttr('nms_threshold', 0.6)
  312. if 'score_threshold' in batchNMS:
  313. detectionOut.addAttr('confidence_threshold', float(batchNMS['score_threshold'][0]))
  314. else:
  315. detectionOut.addAttr('confidence_threshold', 0.01)
  316. if 'max_detections_per_class' in batchNMS:
  317. detectionOut.addAttr('top_k', int(batchNMS['max_detections_per_class'][0]))
  318. else:
  319. detectionOut.addAttr('top_k', 100)
  320. if 'max_total_detections' in batchNMS:
  321. detectionOut.addAttr('keep_top_k', int(batchNMS['max_total_detections'][0]))
  322. else:
  323. detectionOut.addAttr('keep_top_k', 100)
  324. detectionOut.addAttr('code_type', "CENTER_SIZE")
  325. graph_def.node.extend([detectionOut])
  326. while True:
  327. unconnectedNodes = getUnconnectedNodes()
  328. unconnectedNodes.remove(detectionOut.name)
  329. if not unconnectedNodes:
  330. break
  331. for name in unconnectedNodes:
  332. for i in range(len(graph_def.node)):
  333. if graph_def.node[i].name == name:
  334. del graph_def.node[i]
  335. break
  336. # Save as text.
  337. graph_def.save(outputPath)
  338. if __name__ == "__main__":
  339. parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
  340. 'SSD model from TensorFlow Object Detection API. '
  341. 'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
  342. parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
  343. parser.add_argument('--output', required=True, help='Path to output text graph.')
  344. parser.add_argument('--config', required=True, help='Path to a *.config file is used for training.')
  345. args = parser.parse_args()
  346. createSSDGraph(args.input, args.config, args.output)