123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
def tokenize(s):
    """Split protobuf text-format content into a flat list of string tokens.

    Rules:
      * Outside quoted strings, spaces, tabs, CR/LF, ':', ';' and ',' only
        separate tokens and are dropped.
      * '{', '}', '[' and ']' outside strings are emitted as one-character
        tokens.
      * '#' outside a string starts a comment that runs to the end of the line.
      * A string opened with ' or " runs until the next unescaped quote of the
        SAME kind. Its contents (without the quotes) form one token, which may
        be empty. Backslash escape sequences are kept verbatim (not decoded).

    :param s: text to tokenize.
    :return: list of ``str`` tokens.
    """
    tokens = []
    token = ""
    quote = None      # quote char that opened the current string; None outside strings
    escaped = False   # previous in-string character was a backslash
    isComment = False
    for symbol in s:
        isComment = (isComment and symbol != '\n') or (quote is None and symbol == '#')
        if isComment:
            continue

        if quote is not None:
            # Inside a string every character is literal. Braces must not be
            # split out, only the matching quote closes the string, and an
            # escaped quote does not close it.
            if escaped:
                token += symbol
                escaped = False
            elif symbol == '\\':
                token += symbol
                escaped = True
            elif symbol == quote:
                tokens.append(token)
                token = ""
                quote = None
            else:
                token += symbol
        elif symbol in ' \t\r\n:;,':
            if token:
                tokens.append(token)
                token = ""
        elif symbol in '\'"':
            if token:
                tokens.append(token)
                token = ""
            quote = symbol
        elif symbol in '{}[]':
            if token:
                tokens.append(token)
                token = ""
            tokens.append(symbol)
        else:
            token += symbol
    if token:
        tokens.append(token)
    return tokens
def parseMessage(tokens, idx):
    """Parse one brace-delimited message from a token list.

    :param tokens: tokens as produced by ``tokenize``.
    :param idx: index of the opening '{' of the message (asserted).
    :return: ``(msg, end)`` where ``end`` is the index of the matching '}'
        and ``msg`` maps each field name to the list of its values in order
        (nested dicts for sub-messages, strings otherwise; every element of a
        bracketed ``[a, b]`` list is appended). Returns ``None`` if the tokens
        run out before the message (or any nested message) is closed.
    """
    msg = {}
    assert(tokens[idx] == '{')

    isArray = False
    while True:
        # Inside [...] each token is another value for the same field name.
        if not isArray:
            idx += 1
            if idx >= len(tokens):
                return None  # unterminated message
            fieldName = tokens[idx]
            if fieldName == '}':
                break

        idx += 1
        if idx >= len(tokens):
            return None  # field name without a value
        fieldValue = tokens[idx]

        if fieldValue == '{':
            parsed = parseMessage(tokens, idx)
            if parsed is None:
                return None  # propagate instead of failing to unpack None
            embeddedMsg, idx = parsed
            msg.setdefault(fieldName, []).append(embeddedMsg)
        elif fieldValue == '[':
            isArray = True
        elif fieldValue == ']':
            isArray = False
        else:
            msg.setdefault(fieldName, []).append(fieldValue)
    return msg, idx
def readTextMessage(filePath):
    """Read a protobuf text-format file and parse it into nested dicts.

    The file content is wrapped in '{' ... '}' so the top level parses as a
    single message. Each field maps to a list of its values.

    :param filePath: path to the text file. A falsy value yields ``{}``
        without touching the filesystem.
    :return: the parsed top-level message dict, or ``{}`` if parsing fails.
    """
    if not filePath:
        return {}

    with open(filePath, 'rt') as f:
        content = f.read()

    parsed = parseMessage(tokenize('{' + content + '}'), 0)
    if not parsed:
        return {}
    return parsed[0]
def listToTensor(values):
    """Wrap a flat list of numbers in a tensor-attribute dict.

    The result mirrors the text-format layout: ``dtype``,
    ``tensor_shape.dim.size`` (the list length) and the values under
    ``float_val`` or ``int_val``. The caller's list object is stored as-is.

    :param values: list of all-float or all-int values. An empty list is
        classified as float.
    :raises Exception: 'Wrong values types' for any other mix of types.
    """
    if all(isinstance(v, float) for v in values):
        dtype, field = 'DT_FLOAT', 'float_val'
    elif all(isinstance(v, int) for v in values):
        dtype, field = 'DT_INT32', 'int_val'
    else:
        raise Exception('Wrong values types')

    tensor = {
        'dtype': dtype,
        'tensor_shape': {'dim': {'size': len(values)}},
        field: values,
    }
    return {'tensor': tensor}
def addConstNode(name, values, graph_def):
    """Append a ``Const`` node named *name* to *graph_def*.

    *values* is stored as the node's ``value`` attribute through
    ``NodeDef.addAttr``, so its Python type selects the attribute encoding.
    """
    const = NodeDef()
    const.name, const.op = name, 'Const'
    const.addAttr('value', values)
    graph_def.node.extend([const])
def addSlice(inp, out, begins, sizes, graph_def):
    """Append a ``Slice`` node named *out* that slices *inp*.

    Two Const nodes, ``<out>/begins`` and ``<out>/sizes``, are appended
    first. They hold *begins* and *sizes* and become the Slice's second and
    third inputs.
    """
    beginsName = out + '/begins'
    sizesName = out + '/sizes'
    addConstNode(beginsName, begins, graph_def)
    addConstNode(sizesName, sizes, graph_def)

    sliced = NodeDef()
    sliced.name = out
    sliced.op = 'Slice'
    sliced.input.extend([inp, beginsName, sizesName])
    graph_def.node.extend([sliced])
def addReshape(inp, out, shape, graph_def):
    """Append a ``Reshape`` node named *out* that reshapes *inp*.

    The target *shape* is stored first in a Const node named
    ``<out>/shape``, which becomes the Reshape's second input.
    """
    shapeName = out + '/shape'
    addConstNode(shapeName, shape, graph_def)

    reshape = NodeDef()
    reshape.name = out
    reshape.op = 'Reshape'
    reshape.input.extend([inp, shapeName])
    graph_def.node.extend([reshape])
def addSoftMax(inp, out, graph_def):
    """Append a ``Softmax`` node named *out* fed by *inp*, with ``axis`` set to -1."""
    node = NodeDef()
    node.name, node.op = out, 'Softmax'
    node.input = [inp]
    node.addAttr('axis', -1)
    graph_def.node.extend([node])
def addFlatten(inp, out, graph_def):
    """Append a ``Flatten`` node named *out* whose only input is *inp*."""
    flattenNode = NodeDef()
    flattenNode.name, flattenNode.op = out, 'Flatten'
    flattenNode.input = [inp]
    graph_def.node.extend([flattenNode])
class NodeDef:
    """Lightweight, dict-backed stand-in for a graph node.

    Fields:
      input -- list of input node names
      name  -- node name
      op    -- operation type string
      attr  -- attribute name -> value dict (as built by ``addAttr``)
    """

    def __init__(self):
        self.Clear()

    def addAttr(self, key, value):
        """Attach attribute *key*, encoded by the Python type of *value*.

        bool -> {'b': v}, int -> {'i': v}, float -> {'f': v},
        str -> {'s': v}, list -> tensor dict from ``listToTensor``.

        :raises AssertionError: if *key* is already set.
        :raises Exception: for any other value type.
        """
        assert key not in self.attr
        # bool must be checked before int because bool is a subclass of int.
        for kind, tag in ((bool, 'b'), (int, 'i'), (float, 'f'), (str, 's')):
            if isinstance(value, kind):
                self.attr[key] = {tag: value}
                return
        if isinstance(value, list):
            self.attr[key] = listToTensor(value)
        else:
            raise Exception('Unknown type of attribute ' + key)

    def Clear(self):
        """Reset every field to its empty, freshly constructed state."""
        self.input = []
        self.name = ""
        self.op = ""
        self.attr = {}
class GraphDef:
    """Minimal in-memory graph container.

    ``node`` is an ordered list of node objects (each with ``name``, ``op``,
    ``input`` and ``attr``). ``save`` writes them in protobuf text format.
    """

    def __init__(self):
        self.node = []

    @staticmethod
    def _formatScalar(v):
        """Return the text-format spelling of one non-message attribute value.

        Python bools and the strings 'true'/'false' become bare true/false.
        Other strings are quoted unless they start with 'DT_' or parse as a
        float. Everything else is written with ``str``.
        """
        if isinstance(v, bool):
            return 'true' if v else 'false'
        if v == 'true' or v == 'false':
            return 'true' if v == 'true' else 'false'
        if isinstance(v, str) and not v.startswith('DT_'):
            try:
                float(v)
            except ValueError:  # float(str) raises only ValueError; was a bare except
                return '\"%s\"' % v
        return str(v)

    def save(self, filePath):
        """Write every node in ``self.node`` to *filePath* in text format.

        Nodes are written in list order. Attribute names, and keys of nested
        attribute dicts, are sorted case-insensitively. A list value produces
        one repeated entry per element, and a dict value produces a nested
        block.
        """
        with open(filePath, 'wt') as f:
            def printAttr(d, indent):
                pad = ' ' * indent
                for key, value in sorted(d.items(), key=lambda x: x[0].lower()):
                    for v in (value if isinstance(value, list) else [value]):
                        if isinstance(v, dict):
                            f.write(pad + key + ' {\n')
                            printAttr(v, indent + 2)
                            f.write(pad + '}\n')
                        else:
                            f.write(pad + key + ': ' + self._formatScalar(v) + '\n')

            for node in self.node:
                f.write('node {\n')
                f.write(' name: \"%s\"\n' % node.name)
                f.write(' op: \"%s\"\n' % node.op)
                for inp in node.input:
                    f.write(' input: \"%s\"\n' % inp)
                for key, value in sorted(node.attr.items(), key=lambda x: x[0].lower()):
                    f.write(' attr {\n')
                    f.write(' key: \"%s\"\n' % key)
                    f.write(' value {\n')
                    printAttr(value, 6)
                    f.write(' }\n')
                    f.write(' }\n')
                f.write('}\n')
def parseTextGraph(filePath):
    """Parse a text-format graph file into a ``GraphDef`` of ``NodeDef`` objects.

    Each parsed ``node`` entry gives one NodeDef: ``name`` and ``op`` take the
    first parsed value, ``input`` is the parsed list (empty if absent), and
    every ``attr`` entry stores its first ``value`` dict under its first
    ``key``. A message without any ``node`` field, including the ``{}``
    returned for an empty *filePath*, yields an empty graph instead of
    raising ``KeyError``.
    """
    msg = readTextMessage(filePath)

    graph = GraphDef()
    for node in msg.get('node', []):
        graphNode = NodeDef()
        graphNode.name = node['name'][0]
        graphNode.op = node['op'][0]
        graphNode.input = node.get('input', [])
        for attr in node.get('attr', []):
            graphNode.attr[attr['key'][0]] = attr['value'][0]
        graph.node.append(graphNode)
    return graph
# Removes Identity nodes
def removeIdentity(graph_def):
    """Remove ``Identity``/``IdentityN`` nodes and rewire their consumers.

    Any input that named a removed node is redirected to that node's first
    input. Chains of identities are followed to the first non-identity name,
    whatever order the nodes appear in.

    :param graph_def: object with a mutable ``node`` sequence of node objects
        (``name``, ``op``, ``input``). It is modified in place.
    """
    identities = {}
    # Iterate over a snapshot. Removing from the list being iterated skips
    # the element after each removed node, which left back-to-back
    # identities in the graph.
    for node in list(graph_def.node):
        if node.op == 'Identity' or node.op == 'IdentityN':
            # NOTE(review): only input[0] is mapped, also for IdentityN,
            # matching the previous behavior.
            identities[node.name] = node.input[0]
            graph_def.node.remove(node)

    def resolve(name):
        # Follow identity chains to their source. `seen` stops on a
        # malformed identity cycle instead of looping forever.
        seen = set()
        while name in identities and name not in seen:
            seen.add(name)
            name = identities[name]
        return name

    for node in graph_def.node:
        for i, inp in enumerate(node.input):
            if inp in identities:
                node.input[i] = resolve(inp)
def removeUnusedNodesAndAttrs(to_remove, graph_def):
    """Delete selected nodes and strip a fixed set of attributes from the rest.

    :param to_remove: predicate ``to_remove(name, op)``. Nodes for which it
        returns a truthy value are deleted from ``graph_def.node``.
    :param graph_def: object with a mutable ``node`` sequence. It is modified
        in place.

    Surviving nodes lose any attribute listed in ``unusedAttrs``. Inputs that
    name a deleted non-Const node are dropped. References to deleted Const
    nodes are kept.
    """
    unusedAttrs = ('T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
                   'Index', 'Tperm', 'is_training', 'Tpaddings')

    removedNodes = set()  # set, not list: O(1) lookups in the scrubbing loop below
    # Walk backwards so deleting by index does not shift entries not yet visited.
    for i in reversed(range(len(graph_def.node))):
        node = graph_def.node[i]
        if to_remove(node.name, node.op):
            if node.op != 'Const':
                removedNodes.add(node.name)
            del graph_def.node[i]
        else:
            for attr in unusedAttrs:
                if attr in node.attr:
                    del node.attr[attr]

    # Remove references to removed nodes except Const nodes.
    for node in graph_def.node:
        for i in reversed(range(len(node.input))):
            if node.input[i] in removedNodes:
                del node.input[i]
def writeTextGraph(modelPath, outputPath, outNodes):
    """Write a text version of the binary graph at *modelPath* to *outputPath*.

    First tries OpenCV's ``cv.dnn.writeTextGraph``. If importing OpenCV or
    that call fails, it falls back to TensorFlow: the graph is transformed
    with ``sort_by_execution_order`` (inputs ``['image_tensor']``, outputs
    *outNodes*), ``tensor_content`` of every Const node is cleared, and the
    result is written with ``as_text=True``.

    :param modelPath: path to the binary model file.
    :param outputPath: destination path for the text graph.
    :param outNodes: output node names. Used only by the TensorFlow fallback.
    """
    try:
        import cv2 as cv
        cv.dnn.writeTextGraph(modelPath, outputPath)
    except Exception:
        # `except Exception` rather than a bare except, so KeyboardInterrupt
        # and SystemExit propagate instead of triggering the fallback.
        # NOTE(review): the fallback uses TF1-style APIs (tf.gfile, tf.GraphDef,
        # tensorflow.tools.graph_transforms); confirm against the installed TF.
        import tensorflow as tf
        from tensorflow.tools.graph_transforms import TransformGraph

        with tf.gfile.FastGFile(modelPath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes, ['sort_by_execution_order'])

        # Clear the raw tensor bytes of Const nodes so they are not written out.
        for node in graph_def.node:
            if node.op == 'Const':
                if 'value' in node.attr and node.attr['value'].tensor.tensor_content:
                    node.attr['value'].tensor.tensor_content = b''

        tf.train.write_graph(graph_def, "", outputPath, as_text=True)
|