tf_text_graph_common.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. def tokenize(s):
  2. tokens = []
  3. token = ""
  4. isString = False
  5. isComment = False
  6. for symbol in s:
  7. isComment = (isComment and symbol != '\n') or (not isString and symbol == '#')
  8. if isComment:
  9. continue
  10. if symbol == ' ' or symbol == '\t' or symbol == '\r' or symbol == '\'' or \
  11. symbol == '\n' or symbol == ':' or symbol == '\"' or symbol == ';' or \
  12. symbol == ',':
  13. if (symbol == '\"' or symbol == '\'') and isString:
  14. tokens.append(token)
  15. token = ""
  16. else:
  17. if isString:
  18. token += symbol
  19. elif token:
  20. tokens.append(token)
  21. token = ""
  22. isString = (symbol == '\"' or symbol == '\'') ^ isString
  23. elif symbol == '{' or symbol == '}' or symbol == '[' or symbol == ']':
  24. if token:
  25. tokens.append(token)
  26. token = ""
  27. tokens.append(symbol)
  28. else:
  29. token += symbol
  30. if token:
  31. tokens.append(token)
  32. return tokens
  33. def parseMessage(tokens, idx):
  34. msg = {}
  35. assert(tokens[idx] == '{')
  36. isArray = False
  37. while True:
  38. if not isArray:
  39. idx += 1
  40. if idx < len(tokens):
  41. fieldName = tokens[idx]
  42. else:
  43. return None
  44. if fieldName == '}':
  45. break
  46. idx += 1
  47. fieldValue = tokens[idx]
  48. if fieldValue == '{':
  49. embeddedMsg, idx = parseMessage(tokens, idx)
  50. if fieldName in msg:
  51. msg[fieldName].append(embeddedMsg)
  52. else:
  53. msg[fieldName] = [embeddedMsg]
  54. elif fieldValue == '[':
  55. isArray = True
  56. elif fieldValue == ']':
  57. isArray = False
  58. else:
  59. if fieldName in msg:
  60. msg[fieldName].append(fieldValue)
  61. else:
  62. msg[fieldName] = [fieldValue]
  63. return msg, idx
  64. def readTextMessage(filePath):
  65. if not filePath:
  66. return {}
  67. with open(filePath, 'rt') as f:
  68. content = f.read()
  69. tokens = tokenize('{' + content + '}')
  70. msg = parseMessage(tokens, 0)
  71. return msg[0] if msg else {}
  72. def listToTensor(values):
  73. if all([isinstance(v, float) for v in values]):
  74. dtype = 'DT_FLOAT'
  75. field = 'float_val'
  76. elif all([isinstance(v, int) for v in values]):
  77. dtype = 'DT_INT32'
  78. field = 'int_val'
  79. else:
  80. raise Exception('Wrong values types')
  81. msg = {
  82. 'tensor': {
  83. 'dtype': dtype,
  84. 'tensor_shape': {
  85. 'dim': {
  86. 'size': len(values)
  87. }
  88. }
  89. }
  90. }
  91. msg['tensor'][field] = values
  92. return msg
  93. def addConstNode(name, values, graph_def):
  94. node = NodeDef()
  95. node.name = name
  96. node.op = 'Const'
  97. node.addAttr('value', values)
  98. graph_def.node.extend([node])
  99. def addSlice(inp, out, begins, sizes, graph_def):
  100. beginsNode = NodeDef()
  101. beginsNode.name = out + '/begins'
  102. beginsNode.op = 'Const'
  103. beginsNode.addAttr('value', begins)
  104. graph_def.node.extend([beginsNode])
  105. sizesNode = NodeDef()
  106. sizesNode.name = out + '/sizes'
  107. sizesNode.op = 'Const'
  108. sizesNode.addAttr('value', sizes)
  109. graph_def.node.extend([sizesNode])
  110. sliced = NodeDef()
  111. sliced.name = out
  112. sliced.op = 'Slice'
  113. sliced.input.append(inp)
  114. sliced.input.append(beginsNode.name)
  115. sliced.input.append(sizesNode.name)
  116. graph_def.node.extend([sliced])
  117. def addReshape(inp, out, shape, graph_def):
  118. shapeNode = NodeDef()
  119. shapeNode.name = out + '/shape'
  120. shapeNode.op = 'Const'
  121. shapeNode.addAttr('value', shape)
  122. graph_def.node.extend([shapeNode])
  123. reshape = NodeDef()
  124. reshape.name = out
  125. reshape.op = 'Reshape'
  126. reshape.input.append(inp)
  127. reshape.input.append(shapeNode.name)
  128. graph_def.node.extend([reshape])
  129. def addSoftMax(inp, out, graph_def):
  130. softmax = NodeDef()
  131. softmax.name = out
  132. softmax.op = 'Softmax'
  133. softmax.addAttr('axis', -1)
  134. softmax.input.append(inp)
  135. graph_def.node.extend([softmax])
  136. def addFlatten(inp, out, graph_def):
  137. flatten = NodeDef()
  138. flatten.name = out
  139. flatten.op = 'Flatten'
  140. flatten.input.append(inp)
  141. graph_def.node.extend([flatten])
  142. class NodeDef:
  143. def __init__(self):
  144. self.input = []
  145. self.name = ""
  146. self.op = ""
  147. self.attr = {}
  148. def addAttr(self, key, value):
  149. assert(not key in self.attr)
  150. if isinstance(value, bool):
  151. self.attr[key] = {'b': value}
  152. elif isinstance(value, int):
  153. self.attr[key] = {'i': value}
  154. elif isinstance(value, float):
  155. self.attr[key] = {'f': value}
  156. elif isinstance(value, str):
  157. self.attr[key] = {'s': value}
  158. elif isinstance(value, list):
  159. self.attr[key] = listToTensor(value)
  160. else:
  161. raise Exception('Unknown type of attribute ' + key)
  162. def Clear(self):
  163. self.input = []
  164. self.name = ""
  165. self.op = ""
  166. self.attr = {}
  167. class GraphDef:
  168. def __init__(self):
  169. self.node = []
  170. def save(self, filePath):
  171. with open(filePath, 'wt') as f:
  172. def printAttr(d, indent):
  173. indent = ' ' * indent
  174. for key, value in sorted(d.items(), key=lambda x:x[0].lower()):
  175. value = value if isinstance(value, list) else [value]
  176. for v in value:
  177. if isinstance(v, dict):
  178. f.write(indent + key + ' {\n')
  179. printAttr(v, len(indent) + 2)
  180. f.write(indent + '}\n')
  181. else:
  182. isString = False
  183. if isinstance(v, str) and not v.startswith('DT_'):
  184. try:
  185. float(v)
  186. except:
  187. isString = True
  188. if isinstance(v, bool):
  189. printed = 'true' if v else 'false'
  190. elif v == 'true' or v == 'false':
  191. printed = 'true' if v == 'true' else 'false'
  192. elif isString:
  193. printed = '\"%s\"' % v
  194. else:
  195. printed = str(v)
  196. f.write(indent + key + ': ' + printed + '\n')
  197. for node in self.node:
  198. f.write('node {\n')
  199. f.write(' name: \"%s\"\n' % node.name)
  200. f.write(' op: \"%s\"\n' % node.op)
  201. for inp in node.input:
  202. f.write(' input: \"%s\"\n' % inp)
  203. for key, value in sorted(node.attr.items(), key=lambda x:x[0].lower()):
  204. f.write(' attr {\n')
  205. f.write(' key: \"%s\"\n' % key)
  206. f.write(' value {\n')
  207. printAttr(value, 6)
  208. f.write(' }\n')
  209. f.write(' }\n')
  210. f.write('}\n')
  211. def parseTextGraph(filePath):
  212. msg = readTextMessage(filePath)
  213. graph = GraphDef()
  214. for node in msg['node']:
  215. graphNode = NodeDef()
  216. graphNode.name = node['name'][0]
  217. graphNode.op = node['op'][0]
  218. graphNode.input = node['input'] if 'input' in node else []
  219. if 'attr' in node:
  220. for attr in node['attr']:
  221. graphNode.attr[attr['key'][0]] = attr['value'][0]
  222. graph.node.append(graphNode)
  223. return graph
  224. # Removes Identity nodes
  225. def removeIdentity(graph_def):
  226. identities = {}
  227. for node in graph_def.node:
  228. if node.op == 'Identity' or node.op == 'IdentityN':
  229. inp = node.input[0]
  230. if inp in identities:
  231. identities[node.name] = identities[inp]
  232. else:
  233. identities[node.name] = inp
  234. graph_def.node.remove(node)
  235. for node in graph_def.node:
  236. for i in range(len(node.input)):
  237. if node.input[i] in identities:
  238. node.input[i] = identities[node.input[i]]
  239. def removeUnusedNodesAndAttrs(to_remove, graph_def):
  240. unusedAttrs = ['T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
  241. 'Index', 'Tperm', 'is_training', 'Tpaddings']
  242. removedNodes = []
  243. for i in reversed(range(len(graph_def.node))):
  244. op = graph_def.node[i].op
  245. name = graph_def.node[i].name
  246. if to_remove(name, op):
  247. if op != 'Const':
  248. removedNodes.append(name)
  249. del graph_def.node[i]
  250. else:
  251. for attr in unusedAttrs:
  252. if attr in graph_def.node[i].attr:
  253. del graph_def.node[i].attr[attr]
  254. # Remove references to removed nodes except Const nodes.
  255. for node in graph_def.node:
  256. for i in reversed(range(len(node.input))):
  257. if node.input[i] in removedNodes:
  258. del node.input[i]
  259. def writeTextGraph(modelPath, outputPath, outNodes):
  260. try:
  261. import cv2 as cv
  262. cv.dnn.writeTextGraph(modelPath, outputPath)
  263. except:
  264. import tensorflow as tf
  265. from tensorflow.tools.graph_transforms import TransformGraph
  266. with tf.gfile.FastGFile(modelPath, 'rb') as f:
  267. graph_def = tf.GraphDef()
  268. graph_def.ParseFromString(f.read())
  269. graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes, ['sort_by_execution_order'])
  270. for node in graph_def.node:
  271. if node.op == 'Const':
  272. if 'value' in node.attr and node.attr['value'].tensor.tensor_content:
  273. node.attr['value'].tensor.tensor_content = b''
  274. tf.train.write_graph(graph_def, "", outputPath, as_text=True)